File size: 4,320 Bytes
3b237c2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import torch
import torch.nn as nn
import timm
from pathlib import Path
import logging
import os

# Configure root logging at import time so INFO messages are shown; per the
# stdlib docs this call does nothing if the root logger already has handlers.
# NOTE(review): library modules usually leave root configuration to the app.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(name=__name__)


class EfficientNetDeepFakeDetector(nn.Module):
    """Frame-level EfficientNet-B0 with temporal mean-pooling.

    Every frame of a clip is encoded independently by the backbone, the
    per-frame feature vectors are averaged over time, and a small MLP head
    maps the pooled vector to a single logit per clip.

    BatchNorm layers inside the backbone are frozen. Their parameters get
    ``requires_grad=False``, and they are kept in eval mode (running
    statistics are not updated) even after ``train()`` is called.
    """

    # Width of the pooled backbone feature; forward() reshapes to this size.
    FEAT_DIM = 1280
    # Layer types treated as BatchNorm for freezing purposes.
    _BN_TYPES = (nn.BatchNorm2d, nn.SyncBatchNorm)

    def __init__(self, dropout: float = 0.4):
        """
        Args:
            dropout: Dropout rate before the first head layer; half of it is
                used before the final linear layer.
        """
        super().__init__()

        # Backbone: classifier removed (num_classes=0), global average pooling
        # kept, so it returns one feature vector per input image.
        self.backbone = timm.create_model(
            'efficientnet_b0',
            pretrained=False,
            num_classes=0,
            global_pool='avg'
        )

        # Freeze BatchNorm layers: no gradient for their affine parameters...
        for m in self.backbone.modules():
            if isinstance(m, self._BN_TYPES):
                for p in m.parameters():
                    p.requires_grad = False
        # ...and no running-statistics updates (re-applied in train()).
        self._set_bn_eval()

        # Classifier head
        self.head = nn.Sequential(
            nn.LayerNorm(self.FEAT_DIM),
            nn.Dropout(dropout),
            nn.Linear(self.FEAT_DIM, 256),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(256, 1)
        )

    def _set_bn_eval(self) -> None:
        """Put every BatchNorm layer of the backbone into eval mode."""
        for m in self.backbone.modules():
            if isinstance(m, self._BN_TYPES):
                m.eval()

    def train(self, mode: bool = True) -> "EfficientNetDeepFakeDetector":
        """Set training mode while keeping backbone BatchNorm layers frozen.

        ``nn.Module.train`` recurses into every submodule. Without this
        override, calling ``model.train()`` would switch the BatchNorm layers
        back to training mode and silently undo the freeze from ``__init__``.
        """
        super().train(mode)
        self._set_bn_eval()
        return self

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Score a batch of clips.

        Args:
            x: Tensor of shape (B, T, C, H, W): B clips of T frames each.

        Returns:
            Tensor of shape (B,) holding one raw logit per clip.
        """
        B, T, C, H, W = x.shape
        # reshape (not view) so non-contiguous inputs, e.g. permuted frames,
        # are accepted; it copies only when a view is impossible.
        feat = self.backbone(x.reshape(B * T, C, H, W))
        # Regroup per clip and mean-pool over the time axis.
        feat = feat.reshape(B, T, self.FEAT_DIM).mean(dim=1)
        return self.head(feat).squeeze(-1)


class DeepFakeModel:
    """Inference wrapper around ``EfficientNetDeepFakeDetector``.

    Builds the detector, loads trained weights from a checkpoint file, puts
    the network in eval mode, and exposes ``predict`` helpers that return a
    plain result dict.
    """

    def __init__(self, model_path: str, device: str = "cpu"):
        """
        Args:
            model_path: Path to a checkpoint file (formats: see ``_load_model``).
            device: Torch device string, e.g. ``"cpu"`` or ``"cuda:0"``.

        Raises:
            FileNotFoundError: If ``model_path`` is not an existing file.
        """
        self.device = torch.device(device)
        self.model = EfficientNetDeepFakeDetector(dropout=0.4).to(self.device)
        self._load_model(model_path)
        self.model.eval()
        logger.info(f"Model loaded on {self.device}")

    def _load_model(self, model_path: str) -> None:
        """Load trained weights from ``model_path`` into ``self.model``.

        Accepted checkpoint formats:
          * a dict with a ``'model_state_dict'`` entry; the optional
            ``'epoch'`` and ``'val_f1_macro'`` entries are logged;
          * a bare state dict;
          * a whole pickled ``nn.Module``, whose ``state_dict()`` is used.

        Raises:
            FileNotFoundError: If ``model_path`` is missing or is not a
                regular file.
        """
        # isfile (not exists) so that a directory path fails here with a clear
        # message instead of an opaque error from torch.load.
        if not os.path.isfile(model_path):
            raise FileNotFoundError(f"Model file not found at {model_path}")

        # SECURITY: weights_only=False unpickles arbitrary objects, and loading
        # a malicious file can execute code. Only load trusted checkpoints;
        # prefer weights_only=True if the checkpoint holds only tensors and
        # primitive values.
        checkpoint = torch.load(model_path, map_location=self.device, weights_only=False)

        # A fully pickled module is not a mapping; use its weights instead.
        if isinstance(checkpoint, nn.Module):
            checkpoint = checkpoint.state_dict()

        # Handle different checkpoint formats
        if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
            epoch = checkpoint.get('epoch', 'unknown')
            val_f1 = checkpoint.get('val_f1_macro', 'unknown')
            logger.info(f"Loaded checkpoint from epoch {epoch} (val_f1={val_f1})")
        else:
            # The checkpoint is the state dict itself
            self.model.load_state_dict(checkpoint)
            logger.info("Loaded model state dict")

    @torch.no_grad()
    def predict(self, video_tensor: torch.Tensor, threshold: float = 0.5) -> dict:
        """
        Predict if video is real or fake.

        Args:
            video_tensor: Tensor of shape (T, 3, H, W) or (1, T, 3, H, W).
            threshold: Decision threshold on P(REAL) (default: 0.5). Clips with
                P(REAL) >= threshold are labelled "REAL".

        Returns:
            dict with ``prediction`` ("REAL" or "FAKE"), ``confidence`` (the
            probability of the predicted label), ``probability_real`` and
            ``probability_fake`` (all rounded to 4 decimals), and the
            ``threshold`` used.

        Raises:
            ValueError: If ``video_tensor`` is neither 4-D nor 5-D with a
                leading batch dimension of 1.
        """
        if video_tensor.dim() == 4:
            video_tensor = video_tensor.unsqueeze(0)
        # Exactly one clip is supported: .item() below needs a single logit.
        if video_tensor.dim() != 5 or video_tensor.shape[0] != 1:
            raise ValueError(
                "video_tensor must have shape (T, C, H, W) or (1, T, C, H, W), "
                f"got {tuple(video_tensor.shape)}"
            )

        video_tensor = video_tensor.to(self.device)
        logit = self.model(video_tensor)
        prob = torch.sigmoid(logit).item()

        # prob = P(REAL), because training used label 1=REAL, 0=FAKE
        prediction = "REAL" if prob >= threshold else "FAKE"
        confidence = prob if prediction == "REAL" else 1 - prob

        return {
            "prediction": prediction,
            "confidence": round(confidence, 4),
            "probability_real": round(prob, 4),
            "probability_fake": round(1 - prob, 4),
            "threshold": threshold
        }

    @torch.no_grad()
    def predict_from_video_path(
        self,
        video_path: str,
        threshold: float = 0.5,
        num_frames: int = 16,
        img_size: int = 224,
    ) -> dict:
        """
        Convenience method to predict directly from a video file path.

        Args:
            video_path: Path to video file.
            threshold: Decision threshold forwarded to ``predict``.
            num_frames: ``num_frames`` forwarded to ``video_to_tensor``.
            img_size: ``img_size`` forwarded to ``video_to_tensor``.
                NOTE(review): presumably these should match the preprocessing
                used during training — confirm before changing the defaults.

        Returns:
            Prediction result dictionary (see ``predict``).
        """
        from .utils import video_to_tensor

        video_tensor = video_to_tensor(
            video_path,
            num_frames=num_frames,
            img_size=img_size
        )

        return self.predict(video_tensor, threshold)