# CC
# Deploy DeepFake video classifier to Hugging Face Spaces
# 198f874
import torch
import torch.nn as nn
import timm
from pathlib import Path
import logging
import os
# Configure the root logger at import time so INFO-level messages are emitted.
logging.basicConfig(level=logging.INFO)
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class EfficientNetDeepFakeDetector(nn.Module):
    """Frame-level EfficientNet-B0 with temporal mean-pooling.

    Every frame of a clip is passed through the backbone independently, the
    per-frame feature vectors are averaged over the time axis, and a small
    MLP head maps the pooled feature to one logit per clip.

    Args:
        dropout: Dropout probability of the first head layer; the second
            dropout layer uses half of this value.
    """

    # Width of the per-frame feature vector the head expects from the backbone.
    FEAT_DIM = 1280

    def __init__(self, dropout: float = 0.4):
        super().__init__()

        # Backbone: created without pretrained weights and with
        # num_classes=0 / global_pool='avg' so it is used as a feature
        # extractor feeding the custom head below.
        backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=False,
            num_classes=0,
            global_pool='avg',
        )
        self.backbone = backbone

        # Freeze BatchNorm layers: eval mode plus no gradient for their
        # affine parameters. The eval mode is re-applied in ``train()``.
        for bn in self._backbone_batchnorms():
            bn.eval()
            for p in bn.parameters():
                p.requires_grad = False

        # Classifier head: pooled feature -> single logit.
        self.head = nn.Sequential(
            nn.LayerNorm(self.FEAT_DIM),
            nn.Dropout(dropout),
            nn.Linear(self.FEAT_DIM, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1),
        )

    def _backbone_batchnorms(self):
        """Yield every BatchNorm layer found inside the backbone."""
        for module in self.backbone.modules():
            if isinstance(module, (nn.BatchNorm2d, nn.SyncBatchNorm)):
                yield module

    def train(self, mode: bool = True):
        """Set train/eval mode while keeping backbone BatchNorm frozen.

        ``nn.Module.train`` sets the mode on every submodule, which would
        undo the ``eval()`` applied to the BatchNorm layers in ``__init__``.
        After delegating to the parent implementation, those layers are put
        back into eval mode so the freeze actually persists.

        Args:
            mode: ``True`` for training mode, ``False`` for evaluation mode.

        Returns:
            ``self``, matching ``nn.Module.train``.
        """
        super().train(mode)
        if mode:
            for bn in self._backbone_batchnorms():
                bn.eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Compute one logit per clip.

        Args:
            x: Tensor of shape ``(B, T, C, H, W)``.

        Returns:
            Tensor of shape ``(B,)`` holding the raw (pre-sigmoid) logits.

        Raises:
            ValueError: If ``x`` is not 5-dimensional.
        """
        if x.dim() != 5:
            raise ValueError(
                f"expected input of shape (B, T, C, H, W), got {tuple(x.shape)}"
            )
        B, T, C, H, W = x.shape

        # Fold time into the batch axis so the 2D backbone sees plain images.
        # reshape (not view) so non-contiguous clips, e.g. after a
        # permute/transpose, are accepted as well.
        frames = x.reshape(B * T, C, H, W)
        feat = self.backbone(frames)

        # Restore the time axis and average the per-frame features over it.
        feat = feat.reshape(B, T, self.FEAT_DIM)
        feat = feat.mean(dim=1)

        # Head output is (B, 1); drop the trailing singleton dimension.
        logit = self.head(feat).squeeze(-1)
        return logit
class DeepFakeModel:
    """Inference wrapper: builds the detector, loads a checkpoint, predicts.

    Args:
        model_path: Path to a checkpoint file containing a
            ``'model_state_dict'`` entry (and optionally ``'epoch'``).
        device: Device string passed to ``torch.device`` (default ``"cpu"``).
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        self.device = torch.device(device)
        self.model = EfficientNetDeepFakeDetector(dropout=0.4).to(self.device)
        self._load_model(model_path)
        # Inference only: disable dropout and use BatchNorm running statistics.
        self.model.eval()
        logger.info("Model loaded on %s", self.device)

    def _load_model(self, model_path: str):
        """Load the weights stored under ``'model_state_dict'`` into the model.

        Raises:
            KeyError: If the checkpoint has no ``'model_state_dict'`` entry.
        """
        # NOTE(security): weights_only=False unpickles arbitrary Python
        # objects, which can execute code. Only load checkpoints from a
        # trusted source; switch to weights_only=True if the checkpoint
        # contains nothing but tensors and plain containers.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)
        self.model.load_state_dict(checkpoint['model_state_dict'])
        logger.info("Loaded checkpoint from epoch %s", checkpoint.get('epoch', 'unknown'))

    @torch.no_grad()
    def predict(self, video_tensor: torch.Tensor, threshold: float = 0.4) -> dict:
        """Predict if a single video is real or fake.

        Args:
            video_tensor: Tensor of shape (T, 3, H, W) or (1, T, 3, H, W).
            threshold: Decision threshold on P(REAL)
                (default: 0.4 from notebook testing).

        Returns:
            dict with the keys ``prediction`` ("REAL" or "FAKE"),
            ``confidence``, ``probability_real``, ``probability_fake``
            (each rounded to 4 decimals) and ``threshold``.

        Raises:
            ValueError: If the tensor is not 4-D or 5-D, holds more than one
                video, or contains no frames.
        """
        given_shape = tuple(video_tensor.shape)
        if video_tensor.dim() == 4:
            # Add the batch axis for a single un-batched video.
            video_tensor = video_tensor.unsqueeze(0)

        # The result dict describes exactly one video; reject anything else
        # up front rather than failing later inside `.item()`.
        if video_tensor.dim() != 5 or video_tensor.shape[0] != 1:
            raise ValueError(
                "video_tensor must have shape (T, 3, H, W) or (1, T, 3, H, W), "
                f"got {given_shape}"
            )
        # Averaging over zero frames would give NaN and a silent "FAKE".
        if video_tensor.shape[1] == 0:
            raise ValueError("video_tensor contains no frames")

        video_tensor = video_tensor.to(self.device)
        logit = self.model(video_tensor)
        prob = torch.sigmoid(logit).item()

        # prob = P(REAL), because training used label 1=REAL, 0=FAKE
        prediction = "REAL" if prob >= threshold else "FAKE"
        # Confidence is the probability assigned to the predicted class.
        confidence = prob if prediction == "REAL" else 1 - prob

        return {
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probability_real": round(prob, 4),
            "probability_fake": round(1 - prob, 4),
            "threshold": threshold,
        }