File size: 4,749 Bytes
1d2e071
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
import torch
import torch.nn as nn
from torchvision import models
import glob
import os
from typing import Optional

class Model(nn.Module):
    """
    Video deepfake detection model using ResNeXt50 + LSTM architecture.
    Ported from reference code for production use.

    Each frame is encoded by a ResNeXt50 backbone (avgpool and fc removed),
    global-average-pooled to one feature vector per frame, passed through an
    LSTM, and the LSTM output at the last index of dim 1 is classified by a
    linear layer.

    NOTE: the attribute names (``model``, ``lstm``, ``linear1``) and the LSTM
    configuration determine the ``state_dict`` keys, so they must stay in sync
    with any saved checkpoints loaded into this class.
    """

    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
        """
        Args:
            num_classes: Number of output logits.
            latent_dim: LSTM input size. The ResNeXt50 backbone emits 2048
                channels per frame, so this must be 2048 with this backbone.
            lstm_layers: Number of stacked LSTM layers.
            hidden_dim: LSTM hidden size; also the input size of ``linear1``.
            bidirectional: Historically forwarded into the LSTM's ``bias``
                parameter (see the note below); it does NOT make the LSTM
                bidirectional.
        """
        super().__init__()
        # pretrained=True loads ImageNet weights (torchvision may download them).
        # NOTE(review): `pretrained` is deprecated in newer torchvision in favor of `weights=`.
        model = models.resnext50_32x4d(pretrained=True)
        # Remove the last two children (avgpool and fc), keeping the conv feature extractor.
        self.model = nn.Sequential(*list(model.children())[:-2])
        # NOTE(review): the reference code called
        #   nn.LSTM(latent_dim, hidden_dim, lstm_layers, bidirectional)
        # and nn.LSTM's 4th positional parameter is `bias`, not `bidirectional`.
        # With the default bidirectional=False this yields a bias-free LSTM (no
        # bias_ih/bias_hh entries in the state_dict), which the saved checkpoints
        # were presumably trained with, so that mapping is kept here explicitly.
        self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bias=bidirectional)
        # Not used in forward(); kept so external references to `.relu` still work.
        self.relu = nn.LeakyReLU()
        self.dp = nn.Dropout(0.4)
        # The LSTM output width is hidden_dim (previously hard-coded to 2048,
        # which crashed for any other hidden_dim). Same value for the default.
        self.linear1 = nn.Linear(hidden_dim, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """
        Args:
            x: 5-D tensor of shape (batch, seq_len, channels, height, width).

        Returns:
            Tuple ``(fmap, logits)``: ``fmap`` is the backbone feature map for all
            batch*seq_len frames, ``logits`` has shape (batch, num_classes).
        """
        batch_size, seq_length, c, h, w = x.shape
        # Fold time into the batch so the 2D backbone processes every frame at once.
        fmap = self.model(x.view(batch_size * seq_length, c, h, w))
        # One pooled feature vector per frame, restored to (batch, seq_len, features).
        x = self.avgpool(fmap).view(batch_size, seq_length, -1)
        # NOTE(review): the LSTM is not batch_first, so it treats dim 0 (batch) as
        # time and dim 1 (frames) as its batch. Kept as-is to match the trained
        # weights; with batch_size == 1 every frame is encoded independently and
        # the last frame's output is classified.
        x_lstm, _ = self.lstm(x, None)
        # Dropout is applied to the logits; it is a no-op in eval() mode.
        return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))


def get_accurate_model(sequence_length: int, models_dir: str = "models") -> Optional[str]:
    """
    Select the best model based on sequence length (frame count).

    Model files are expected to be named
    ``model_{accuracy}_acc_{frames}_frames_*.pt``. Among the ``.pt`` files in
    *models_dir* whose ``{frames}`` equals *sequence_length*, the one with the
    highest numeric ``{accuracy}`` wins. Ties go to the alphabetically first
    filename. Files whose frame field is not an integer are ignored. Files
    whose accuracy field is not a number can still be selected but rank lowest.

    Args:
        sequence_length: Number of frames to sample from video (10, 20, 40, 60, 80, 100)
        models_dir: Directory containing the model files

    Returns:
        Full path to the selected model file, or None if no model found
    """
    # Sorted so the choice is deterministic: glob's order is arbitrary.
    list_models = sorted(glob.glob(os.path.join(models_dir, "*.pt")))

    if not list_models:
        print(f"No models found in {models_dir}")
        return None

    # (accuracy, filename) for every model trained on `sequence_length` frames.
    candidates = []
    for model_path in list_models:
        filename = os.path.basename(model_path)
        parts = filename.split("_")
        try:
            frames = int(parts[3])  # frames count is at index 3
        except (IndexError, ValueError):
            continue
        if frames != sequence_length:
            continue
        try:
            # Compare numerically: as strings "97" > "100" and "9.5" > "87".
            accuracy = float(parts[1])  # accuracy is at index 1
        except ValueError:
            accuracy = float("-inf")
        candidates.append((accuracy, filename))

    if not candidates:
        print(f"No model found for sequence length {sequence_length}")
        return None

    # max() returns the first maximal item, i.e. the alphabetically first on ties.
    _, best_filename = max(candidates, key=lambda item: item[0])
    return os.path.join(models_dir, best_filename)


# Global model cache to avoid reloading
_model_cache = {}


def load_model(sequence_length: int, device: str = "cpu") -> Optional[Model]:
    """
    Load the model for the specified sequence length.
    Uses caching to avoid reloading the same model.

    The checkpoint path comes from ``get_accurate_model`` (default ``models``
    directory). The checkpoint is always deserialized onto the CPU and the
    model is then moved to the requested device.

    Args:
        sequence_length: Number of frames (10, 20, 40, 60, 80, 100)
        device: 'cpu', 'cuda', or an indexed CUDA device such as 'cuda:1'.
            CUDA requests fall back to the CPU when CUDA is unavailable.

    Returns:
        Loaded model in eval mode, ready for inference, or None if no checkpoint
        matches or loading fails (the error is printed, not raised).
    """
    cache_key = f"{sequence_length}_{device}"

    # Check cache first
    if cache_key in _model_cache:
        print(f"Using cached model for {sequence_length} frames")
        return _model_cache[cache_key]

    # Get the model path
    model_path = get_accurate_model(sequence_length)
    if not model_path:
        return None

    print(f"Loading model: {model_path}")

    try:
        # Honor 'cuda' and 'cuda:N' only when CUDA is actually available.
        if str(device).startswith("cuda") and torch.cuda.is_available():
            target = torch.device(device)
        else:
            target = torch.device("cpu")

        # NOTE(review): Model() builds its backbone with pretrained=True, which may
        # download ImageNet weights that are overwritten by load_state_dict below.
        model = Model(num_classes=2)

        # Map to CPU first so a checkpoint saved on any GPU index loads on any machine.
        # torch.load unpickles the file: only load checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location=torch.device("cpu"))
        model.load_state_dict(state_dict)
        model = model.to(target)
        model.eval()
    except Exception as e:
        # Best-effort contract: report the failure and return None instead of raising.
        print(f"Error loading model: {e}")
        return None

    # Cache the model
    _model_cache[cache_key] = model
    print(f"Model loaded successfully for {sequence_length} frames")

    return model


def get_device() -> str:
    """Return ``"cuda"`` when PyTorch reports a usable CUDA device, else ``"cpu"``."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"