import logging
import os
from pathlib import Path

import timm
import torch
import torch.nn as nn

# NOTE(review): configuring the root logger at import time is a module-level
# side effect; kept for backward compatibility, but library code normally
# leaves logging configuration to the application entry point.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class EfficientNetDeepFakeDetector(nn.Module):
    """Frame-level EfficientNet-B0 with temporal mean-pooling.

    Every frame of a clip is encoded independently by an EfficientNet-B0
    backbone (classifier removed, global average pooling kept). The per-frame
    feature vectors are averaged over the time axis, and a small MLP head maps
    the pooled vector to a single logit per clip.
    """

    # Width of the pooled backbone feature vector; forward() reshapes the
    # backbone output with it, so it must match the backbone's output size.
    FEAT_DIM = 1280

    def __init__(self, dropout: float = 0.4):
        """Build the backbone and the classifier head.

        Args:
            dropout: Dropout probability before the first head layer; the
                second head dropout uses half of this value.
        """
        super().__init__()

        # Backbone: num_classes=0 removes timm's classifier so the model
        # returns pooled features. pretrained=False because the weights are
        # expected to come from a checkpoint loaded afterwards.
        backbone = timm.create_model(
            'efficientnet_b0', pretrained=False, num_classes=0, global_pool='avg'
        )

        # Freeze BatchNorm layers: their affine parameters get no gradients.
        # Note that m.eval() only sets the *current* mode; a later .train()
        # call on the parent module puts these layers back into training mode.
        for m in backbone.modules():
            if isinstance(m, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                m.eval()
                for p in m.parameters():
                    p.requires_grad = False
        self.backbone = backbone

        # Classifier head: pooled features -> single logit.
        self.head = nn.Sequential(
            nn.LayerNorm(self.FEAT_DIM),
            nn.Dropout(dropout),
            nn.Linear(self.FEAT_DIM, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return one raw logit per clip.

        Args:
            x: Tensor of shape (B, T, C, H, W): B clips of T frames each.

        Returns:
            Tensor of shape (B,) holding pre-sigmoid logits.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs -- e.g. the result of
        # permuting a (B, C, T, H, W) tensor -- are accepted; for contiguous
        # inputs this is still a zero-copy view.
        x = x.reshape(B * T, C, H, W)
        feat = self.backbone(x)                  # (B*T, FEAT_DIM)
        feat = feat.reshape(B, T, self.FEAT_DIM)
        feat = feat.mean(dim=1)                  # temporal mean-pool -> (B, FEAT_DIM)
        logit = self.head(feat).squeeze(-1)      # (B, 1) -> (B,)
        return logit


class DeepFakeModel:
    """Inference wrapper: loads a trained detector and scores a single clip."""

    def __init__(self, model_path: str, device: str = "cpu"):
        """Build the detector on *device*, load weights from *model_path*, set eval mode.

        Args:
            model_path: Path to a checkpoint saved with a 'model_state_dict' entry.
            device: Any string accepted by torch.device (e.g. "cpu", "cuda:0").
        """
        self.device = torch.device(device)
        self.model = EfficientNetDeepFakeDetector(dropout=0.4).to(self.device)
        self._load_model(model_path)
        self.model.eval()
        logger.info("Model loaded on %s", self.device)

    def _load_model(self, model_path: str):
        """Load detector weights from the checkpoint's 'model_state_dict' entry.

        Raises:
            KeyError: if the checkpoint has no 'model_state_dict' key.
        """
        # SECURITY: weights_only=False unpickles arbitrary Python objects, which
        # can execute code on load. Only load checkpoints from trusted sources;
        # switch to weights_only=True if the checkpoint contains only tensors
        # and primitive containers.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        logger.info("Loaded checkpoint from epoch %s", checkpoint.get('epoch', 'unknown'))

    @torch.no_grad()
    def predict(self, video_tensor: torch.Tensor, threshold: float = 0.4) -> dict:
        """
        Predict if video is real or fake.

        Args:
            video_tensor: Tensor of shape (T, 3, H, W) or (1, T, 3, H, W)
            threshold: Decision threshold (default: 0.4 from notebook testing)

        Returns:
            dict with prediction, confidence, and probabilities

        Raises:
            ValueError: if video_tensor is not 4-D or 5-D, or is a 5-D batch
                that does not contain exactly one clip.
        """
        if video_tensor.dim() == 4:
            video_tensor = video_tensor.unsqueeze(0)
        # Only a single clip can be scored: the result is reduced with .item().
        if video_tensor.dim() != 5 or video_tensor.size(0) != 1:
            raise ValueError(
                "video_tensor must have shape (T, 3, H, W) or (1, T, 3, H, W), "
                f"got {tuple(video_tensor.shape)}"
            )
        video_tensor = video_tensor.to(self.device)

        logit = self.model(video_tensor)
        prob = torch.sigmoid(logit).item()

        # prob = P(REAL), because training used label 1=REAL, 0=FAKE
        prediction = "REAL" if prob >= threshold else "FAKE"
        confidence = prob if prediction == "REAL" else 1 - prob

        return {
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probability_real": round(prob, 4),
            "probability_fake": round(1 - prob, 4),
            "threshold": threshold
        }