Spaces:
Sleeping
Sleeping
| import torch | |
| import torch.nn as nn | |
| import timm | |
| from pathlib import Path | |
| import logging | |
| import os | |
# Module-level logger for this file. basicConfig attaches an INFO-level handler
# to the root logger at import time; per the stdlib docs it does nothing if the
# root logger already has handlers configured.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)
class EfficientNetDeepFakeDetector(nn.Module):
    """Frame-level EfficientNet-B0 with temporal mean-pooling.

    Each frame of a clip is embedded independently by an EfficientNet-B0
    backbone (timm classifier removed, global average pooling kept). The
    per-frame embeddings are averaged over time, and a small MLP head maps the
    pooled embedding to a single logit per clip.

    The backbone's BatchNorm layers are frozen:
    - their parameters never receive gradients;
    - they stay in eval mode (using stored running statistics) even after the
      model is switched to training mode.
    """

    # Width of the pooled EfficientNet-B0 feature vector.
    FEAT_DIM = 1280
    # BatchNorm flavours inside the backbone that are kept frozen.
    _FROZEN_BN_TYPES = (nn.BatchNorm2d, nn.SyncBatchNorm)

    def __init__(self, dropout: float = 0.4):
        """Build the backbone and the classifier head.

        Args:
            dropout: Dropout probability applied before the first head layer.
                The second head dropout uses half of this value.
        """
        super().__init__()
        # Backbone. With pretrained=False no ImageNet weights are fetched;
        # weights are expected to be loaded separately.
        self.backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=False,
            num_classes=0,      # no timm classifier -> pooled features out
            global_pool='avg',
        )
        # Freeze BatchNorm layers: stop gradients and use running statistics.
        for bn in self._frozen_batchnorms():
            bn.eval()
            for p in bn.parameters():
                p.requires_grad = False

        # Classifier head (layer order defines the checkpoint's state-dict keys).
        self.head = nn.Sequential(
            nn.LayerNorm(self.FEAT_DIM),
            nn.Dropout(dropout),
            nn.Linear(self.FEAT_DIM, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1),
        )

    def _frozen_batchnorms(self):
        """Yield the backbone BatchNorm modules that must stay frozen."""
        return (
            m for m in self.backbone.modules()
            if isinstance(m, self._FROZEN_BN_TYPES)
        )

    def train(self, mode: bool = True):
        """Set training mode while keeping the frozen backbone BatchNorms in eval.

        ``nn.Module.train`` recursively resets every submodule. Without this
        override it would undo the ``eval()`` applied in ``__init__``, and the
        BatchNorm running statistics would be updated during training.
        """
        super().train(mode)
        for bn in self._frozen_batchnorms():
            bn.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of clips.

        Args:
            x: Clip tensor of shape (B, T, C, H, W).

        Returns:
            Logits of shape (B,).

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected a 5-D (B, T, C, H, W) tensor, got shape {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted clips, work.
        frames = x.reshape(B * T, C, H, W)
        feat = self.backbone(frames)                      # (B*T, FEAT_DIM)
        feat = feat.reshape(B, T, self.FEAT_DIM).mean(dim=1)  # temporal mean-pool
        # Head ends in Linear(256, 1) -> (B, 1); drop the trailing unit dim.
        return self.head(feat).squeeze(-1)
class DeepFakeModel:
    """Inference wrapper that loads a trained detector checkpoint and scores clips."""

    def __init__(self, model_path: str, device: str = "cpu"):
        """Build the detector, load weights from ``model_path`` and switch to eval mode.

        Args:
            model_path: Path to a checkpoint file readable by ``torch.load``.
            device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        """
        self.device = torch.device(device)
        self.model = EfficientNetDeepFakeDetector(dropout=0.4).to(self.device)
        self._load_model(model_path)
        self.model.eval()
        logger.info("Model loaded on %s", self.device)

    def _load_model(self, model_path: str):
        """Load detector weights from ``model_path`` into ``self.model``.

        Accepts either a training checkpoint dict with a ``'model_state_dict'``
        entry (and an optional ``'epoch'``) or a bare state dict.
        """
        # SECURITY: weights_only=False unpickles arbitrary objects and can run
        # code embedded in the file. Only load checkpoints from trusted sources.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
            epoch = checkpoint.get('epoch', 'unknown')
        else:
            # Bare state dict, e.g. saved via torch.save(model.state_dict(), path).
            state_dict = checkpoint
            epoch = 'unknown'
        self.model.load_state_dict(state_dict)
        logger.info("Loaded checkpoint from epoch %s", epoch)

    def predict(self, video_tensor: torch.Tensor, threshold: float = 0.4) -> dict:
        """
        Predict if video is real or fake.

        Args:
            video_tensor: Tensor of shape (T, 3, H, W) or (1, T, 3, H, W).
                A 4-D tensor is treated as a single clip. Only one clip is
                supported, because the score is read with ``.item()``.
            threshold: Decision threshold (default: 0.4 from notebook testing).
                Clips with P(REAL) >= threshold are labelled "REAL".

        Returns:
            dict with keys ``prediction`` ("REAL"/"FAKE"), ``confidence`` (the
            probability of the predicted class), ``probability_real``,
            ``probability_fake`` (all rounded to 4 decimals) and ``threshold``.
        """
        if video_tensor.dim() == 4:
            video_tensor = video_tensor.unsqueeze(0)
        video_tensor = video_tensor.to(self.device)
        # Pure inference: don't build an autograd graph (saves memory and time).
        with torch.no_grad():
            logit = self.model(video_tensor)
        prob = torch.sigmoid(logit).item()
        # prob = P(REAL), because training used label 1=REAL, 0=FAKE
        prediction = "REAL" if prob >= threshold else "FAKE"
        confidence = prob if prediction == "REAL" else 1 - prob
        return {
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probability_real": round(prob, 4),
            "probability_fake": round(1 - prob, 4),
            "threshold": threshold,
        }