import logging
import os
from pathlib import Path

import timm
import torch
import torch.nn as nn

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EfficientNetDeepFakeDetector(nn.Module):
    """Frame-level EfficientNet-B0 with temporal mean-pooling.

    Every frame of a clip is passed through the backbone independently, the
    per-frame feature vectors are averaged over the time axis, and a small
    MLP head maps the pooled vector to a single logit per clip.
    """

    # Width of the feature vector the head expects from the backbone.
    FEAT_DIM = 1280

    # Normalisation layer types inside the backbone that are kept frozen.
    _FROZEN_NORM_TYPES = (nn.BatchNorm2d, nn.SyncBatchNorm)

    def __init__(self, dropout: float = 0.4):
        """Build the backbone and the classifier head.

        Args:
            dropout: Dropout probability of the first head layer; the second
                dropout layer uses half of this value.
        """
        super().__init__()

        # Backbone: num_classes=0 drops timm's classifier so the model
        # returns globally average-pooled features instead of logits.
        backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=False,
            num_classes=0,
            global_pool='avg',
        )

        # Freeze BatchNorm layers: no gradient for their affine parameters.
        for m in backbone.modules():
            if isinstance(m, self._FROZEN_NORM_TYPES):
                for p in m.parameters():
                    p.requires_grad = False
        self.backbone = backbone
        # ...and eval mode so their running statistics are not updated.
        self._set_backbone_batchnorm_eval()

        # Classifier head
        self.head = nn.Sequential(
            nn.LayerNorm(self.FEAT_DIM),
            nn.Dropout(dropout),
            nn.Linear(self.FEAT_DIM, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1),
        )

    def _set_backbone_batchnorm_eval(self) -> None:
        """Put every BatchNorm layer of the backbone into eval mode."""
        for m in self.backbone.modules():
            if isinstance(m, self._FROZEN_NORM_TYPES):
                m.eval()

    def train(self, mode: bool = True) -> "EfficientNetDeepFakeDetector":
        """Switch train/eval mode while keeping backbone BatchNorm frozen.

        ``nn.Module.train`` recursively flips every sub-module, which would
        silently undo the ``eval()`` applied to the BatchNorm layers at
        construction time. Re-applying it here keeps the freeze effective
        across ``model.train()`` calls.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching the ``nn.Module.train`` contract.
        """
        super().train(mode)
        if mode:
            self._set_backbone_batchnorm_eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one logit per clip.

        Args:
            x: Tensor of shape (B, T, C, H, W).

        Returns:
            Tensor of shape (B,) holding the raw (pre-sigmoid) logits.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"Expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape

        # reshape (not view) so non-contiguous inputs, e.g. permuted frames,
        # are accepted; it only copies when a view is impossible.
        frames = x.reshape(B * T, C, H, W)
        feat = self.backbone(frames)
        feat = feat.reshape(B, T, self.FEAT_DIM)

        # Temporal mean-pooling over the T frames of each clip.
        feat = feat.mean(dim=1)
        logit = self.head(feat).squeeze(-1)
        return logit


class DeepFakeModel:
    """Inference wrapper around ``EfficientNetDeepFakeDetector``.

    Loads a checkpoint onto the requested device, puts the network into
    eval mode and exposes single-clip prediction helpers.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """Create the network and load its weights.

        Args:
            model_path: Path to the checkpoint file.
            device: Device string understood by ``torch.device``.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        self.device = torch.device(device)
        self.model = EfficientNetDeepFakeDetector(dropout=0.4).to(self.device)
        self._load_model(model_path)
        self.model.eval()
        logger.info(f"Model loaded on {self.device}")

    def _load_model(self, model_path: str):
        """Load model checkpoint from file.

        Two layouts are accepted: a dict carrying the weights under
        ``'model_state_dict'`` (optionally with ``'epoch'`` and
        ``'val_f1_macro'`` entries, which are only logged), or a bare
        state dict.

        Args:
            model_path: Path to the checkpoint file.

        Raises:
            FileNotFoundError: If ``model_path`` does not exist.
        """
        if not os.path.exists(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        # NOTE(security): weights_only=False performs a full unpickle and can
        # execute arbitrary code embedded in the file. Only load checkpoints
        # from trusted sources; switch to weights_only=True once the
        # checkpoints are confirmed to contain only tensors and plain types.
        checkpoint = torch.load(
            model_path, map_location=self.device, weights_only=False
        )

        # Handle different checkpoint formats
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            epoch = checkpoint.get('epoch', 'unknown')
            val_f1 = checkpoint.get('val_f1_macro', 'unknown')
            logger.info(f"Loaded checkpoint from epoch {epoch} (val_f1={val_f1})")
        else:
            # If checkpoint is just the state dict
            self.model.load_state_dict(checkpoint)
            logger.info("Loaded model state dict")

    @torch.no_grad()
    def predict(self, video_tensor: torch.Tensor, threshold: float = 0.5) -> dict:
        """
        Predict if video is real or fake.

        Args:
            video_tensor: Tensor of shape (T, 3, H, W) or (1, T, 3, H, W)
            threshold: Decision threshold (default: 0.5)

        Returns:
            dict with prediction, confidence, and probabilities

        Raises:
            ValueError: If ``video_tensor`` is not a single clip of one of
                the shapes above.
        """
        if video_tensor.dim() == 4:
            video_tensor = video_tensor.unsqueeze(0)

        # Only one clip can be reduced to a scalar probability below, so
        # reject anything else here with an explicit message.
        if video_tensor.dim() != 5 or video_tensor.shape[0] != 1:
            raise ValueError(
                "video_tensor must have shape (T, 3, H, W) or (1, T, 3, H, W), "
                f"got {tuple(video_tensor.shape)}"
            )

        video_tensor = video_tensor.to(self.device)
        logit = self.model(video_tensor)
        prob = torch.sigmoid(logit).item()

        # prob = P(REAL), because training used label 1=REAL, 0=FAKE
        prediction = "REAL" if prob >= threshold else "FAKE"
        # Confidence is the probability of whichever class was chosen.
        confidence = prob if prediction == "REAL" else 1 - prob

        return {
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probability_real": round(prob, 4),
            "probability_fake": round(1 - prob, 4),
            "threshold": threshold,
        }

    @torch.no_grad()
    def predict_from_video_path(
        self,
        video_path: str,
        threshold: float = 0.5,
        num_frames: int = 16,
        img_size: int = 224,
    ) -> dict:
        """
        Convenience method to predict directly from video file path.

        Args:
            video_path: Path to video file
            threshold: Decision threshold
            num_frames: Forwarded to ``video_to_tensor`` as ``num_frames``
                (default: 16)
            img_size: Forwarded to ``video_to_tensor`` as ``img_size``
                (default: 224)

        Returns:
            Prediction result dictionary
        """
        # Imported here, as in the original, so the helper module is only
        # loaded when this method is actually used.
        from .utils import video_to_tensor

        video_tensor = video_to_tensor(
            video_path,
            num_frames=num_frames,
            img_size=img_size,
        )
        return self.predict(video_tensor, threshold)