| import torch |
| import torch.nn as nn |
| import timm |
| from pathlib import Path |
| import logging |
| import os |
|
|
# Import-time logging setup: configure the root logger at INFO level and
# create this module's named logger.
# NOTE(review): calling basicConfig from a library module changes the host
# application's root-logger defaults; consider moving it to the entry point.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
|
|
|
|
class EfficientNetDeepFakeDetector(nn.Module):
    """Frame-level EfficientNet-B0 with temporal mean-pooling.

    Every frame of a clip is encoded independently by an EfficientNet-B0
    backbone created through ``timm`` with the classifier removed
    (``num_classes=0``) and average pooling. ``forward`` expects it to return
    ``FEAT_DIM`` features per frame. Frame features are averaged over time
    and fed to a small MLP head that emits one logit per clip.

    Args:
        dropout: Dropout probability before the first head layer; the second
            dropout layer uses half this value.
    """

    # Width of the pooled backbone features; forward() reshapes the backbone
    # output to this size, so it must match what efficientnet_b0 produces.
    FEAT_DIM = 1280

    def __init__(self, dropout: float = 0.4):
        super().__init__()

        # No weights are downloaded here (pretrained=False); presumably they
        # come from a checkpoint loaded afterwards — TODO confirm for training.
        backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=False,
            num_classes=0,
            global_pool='avg'
        )

        # Freeze the backbone's BatchNorm2d / SyncBatchNorm layers: switch them
        # to eval mode and stop gradients to their affine parameters.
        # NOTE(review): nn.Module.train() recursively re-enables training mode,
        # so after a later model.train() these layers update their running
        # statistics again; only requires_grad=False persists. Override
        # train() if the running stats are meant to stay frozen.
        for module in backbone.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                module.eval()
                for param in module.parameters():
                    param.requires_grad = False

        self.backbone = backbone

        # Clip-level classifier on the time-averaged frame features.
        self.head = nn.Sequential(
            nn.LayerNorm(self.FEAT_DIM),
            nn.Dropout(dropout),
            nn.Linear(self.FEAT_DIM, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of clips.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``: ``B`` clips of ``T``
                frames each.

        Returns:
            Tensor of shape ``(B,)`` with one raw (pre-sigmoid) logit per clip.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        # reshape (not view): accepts inputs whose batch/time dims cannot be
        # merged in place, e.g. after permute/transpose; copies only if needed.
        frames = x.reshape(B * T, C, H, W)
        feat = self.backbone(frames)
        feat = feat.reshape(B, T, self.FEAT_DIM)
        feat = feat.mean(dim=1)  # temporal mean-pooling over the T frames
        logit = self.head(feat).squeeze(-1)
        return logit
|
|
|
|
class DeepFakeModel:
    """Inference wrapper around ``EfficientNetDeepFakeDetector``.

    Builds the detector, loads its weights from a checkpoint file and turns
    the model's single logit into a REAL/FAKE verdict. The sigmoid of the
    logit is interpreted as the probability that the clip is REAL.

    Args:
        model_path: Path to a checkpoint file readable by ``torch.load``.
        device: Torch device string the model and inputs are moved to.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        self.device = torch.device(device)
        self.model = EfficientNetDeepFakeDetector(dropout=0.4).to(self.device)
        self._load_model(model_path)
        self.model.eval()
        logger.info("Model loaded on %s", self.device)

    def _load_model(self, model_path: str):
        """Load weights from ``model_path`` into ``self.model``.

        Accepts either a dict containing a ``'model_state_dict'`` entry (its
        optional ``'epoch'`` / ``'val_f1_macro'`` entries are logged) or a
        bare state dict. Loading is strict, so keys must match the model.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code embedded in the file. Only load trusted checkpoints; prefer
        # weights_only=True if the checkpoint holds only tensors and plain
        # Python values.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # Guard the membership test: a non-dict checkpoint (e.g. a pickled
        # module) falls through to load_state_dict, which reports it clearly.
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            epoch = checkpoint.get('epoch', 'unknown')
            val_f1 = checkpoint.get('val_f1_macro', 'unknown')
            logger.info("Loaded checkpoint from epoch %s (val_f1=%s)", epoch, val_f1)
        else:
            # Bare state dict saved directly via torch.save(model.state_dict()).
            self.model.load_state_dict(checkpoint)
            logger.info("Loaded model state dict")

    @torch.no_grad()
    def predict(self, video_tensor: torch.Tensor, threshold: float = 0.5) -> dict:
        """
        Predict if video is real or fake.

        Args:
            video_tensor: Tensor of shape (T, 3, H, W) or (1, T, 3, H, W).
            threshold: Decision threshold on P(REAL) (default: 0.5); clips with
                P(REAL) >= threshold are labelled "REAL".

        Returns:
            dict with keys ``prediction`` ("REAL" or "FAKE"), ``confidence``
            (probability of the predicted label), ``probability_real`` and
            ``probability_fake`` (all rounded to 4 decimals), and the
            ``threshold`` that was used.

        Raises:
            ValueError: If ``video_tensor`` is neither 4- nor 5-dimensional.
        """
        if video_tensor.dim() == 4:
            video_tensor = video_tensor.unsqueeze(0)  # add the batch dimension
        elif video_tensor.dim() != 5:
            raise ValueError(
                "video_tensor must have shape (T, 3, H, W) or (1, T, 3, H, W), "
                f"got {tuple(video_tensor.shape)}"
            )

        video_tensor = video_tensor.to(self.device)
        logit = self.model(video_tensor)
        # .item() needs exactly one logit, i.e. a batch of a single clip.
        prob_real = torch.sigmoid(logit).item()

        prediction = "REAL" if prob_real >= threshold else "FAKE"
        confidence = prob_real if prediction == "REAL" else 1 - prob_real

        return {
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probability_real": round(prob_real, 4),
            "probability_fake": round(1 - prob_real, 4),
            "threshold": threshold
        }

    @torch.no_grad()
    def predict_from_video_path(
        self,
        video_path: str,
        threshold: float = 0.5,
        *,
        num_frames: int = 16,
        img_size: int = 224,
    ) -> dict:
        """
        Convenience method to predict directly from video file path.

        Args:
            video_path: Path to video file.
            threshold: Decision threshold on P(REAL), forwarded to predict().
            num_frames: Number of frames passed to ``video_to_tensor``.
            img_size: Frame size passed to ``video_to_tensor``.

        Returns:
            Prediction result dictionary, as returned by predict().
        """
        # Local import, as in the original (package-relative helper module).
        from .utils import video_to_tensor

        video_tensor = video_to_tensor(
            video_path,
            num_frames=num_frames,
            img_size=img_size
        )

        return self.predict(video_tensor, threshold)