import torch
import torch.nn as nn
from torchvision import models
import glob
import os
from typing import Optional
class Model(nn.Module):
    """
    Video deepfake detection model using ResNeXt50 + LSTM architecture.
    Ported from reference code for production use.

    Pipeline: per-frame ResNeXt50 feature maps -> global average pool to a
    2048-d vector per frame -> LSTM over the sequence -> dropout + linear
    classifier. ``forward`` returns the raw CNN feature maps alongside the
    classification logits.
    """
    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
        super(Model, self).__init__()
        # Load pretrained ResNeXt50
        model = models.resnext50_32x4d(pretrained=True)
        # Remove the last two layers (avgpool and fc)
        self.model = nn.Sequential(*list(model.children())[:-2])
        # NOTE(review): nn.LSTM's 4th positional argument is `bias`, not
        # `bidirectional`, so this builds a bias-free unidirectional LSTM.
        # Kept verbatim because the published checkpoints were trained with
        # this exact layout; changing it would break load_state_dict.
        self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bidirectional)
        self.relu = nn.LeakyReLU()  # defined but never called in forward()
        self.dp = nn.Dropout(0.4)
        self.linear1 = nn.Linear(2048, num_classes)
        # Global average pool: collapses each (2048, h', w') map to (2048, 1, 1)
        self.avgpool = nn.AdaptiveAvgPool2d(1)
    def forward(self, x):
        # x: (batch, seq, C, H, W) — a batch of frame sequences
        batch_size, seq_length, c, h, w = x.shape
        # Fold frames into the batch dim so the CNN sees (B*S, C, H, W)
        x = x.view(batch_size * seq_length, c, h, w)
        fmap = self.model(x)          # (B*S, 2048, h', w') feature maps
        x = self.avgpool(fmap)        # -> (B*S, 2048, 1, 1)
        x = x.view(batch_size, seq_length, 2048)
        # NOTE(review): the LSTM was constructed with batch_first=False, yet
        # receives (batch, seq, feat) here — dim 0 (batch) is treated as the
        # time axis. This matches the original training code, so it is
        # preserved byte-for-byte; do not "fix" without retraining weights.
        x_lstm, _ = self.lstm(x, None)
        # Take the last position along dim 1, then dropout + linear head.
        return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))
def get_accurate_model(sequence_length: int, models_dir: str = "models") -> Optional[str]:
    """
    Select the best model checkpoint for the given sequence length.

    Checkpoint files are expected to follow the naming pattern
    ``model_{accuracy}_acc_{frames}_frames_*.pt``; files that do not match
    are silently skipped.

    Args:
        sequence_length: Number of frames to sample from video (10, 20, 40, 60, 80, 100)
        models_dir: Directory containing the model files

    Returns:
        Full path to the selected model file, or None if no model found
    """
    # Get all .pt model files
    list_models = glob.glob(os.path.join(models_dir, "*.pt"))
    if not list_models:
        print(f"No models found in {models_dir}")
        return None

    # Collect (accuracy, filename) pairs for files matching sequence_length.
    candidates = []
    for model_path in list_models:
        filename = os.path.basename(model_path)
        parts = filename.split("_")
        try:
            # Model naming pattern: model_{accuracy}_acc_{frames}_frames_*.pt
            accuracy = float(parts[1])   # accuracy is at index 1
            frames = int(parts[3])       # frames count is at index 3
        except (IndexError, ValueError):
            continue  # filename does not follow the expected pattern
        if frames == sequence_length:
            candidates.append((accuracy, filename))

    if not candidates:
        print(f"No model found for sequence length {sequence_length}")
        return None

    # BUGFIX: accuracies must be compared numerically — the previous string
    # comparison ranked "99" above "100". On ties max() keeps the first match,
    # mirroring the old list.index() behavior.
    _, best_filename = max(candidates, key=lambda c: c[0])
    return os.path.join(models_dir, best_filename)
# Global model cache to avoid reloading
_model_cache = {}
def load_model(sequence_length: int, device: str = "cpu") -> Optional[Model]:
    """
    Load the model for the specified sequence length.
    Uses caching to avoid reloading the same model.
    Args:
        sequence_length: Number of frames (10, 20, 40, 60, 80, 100)
        device: 'cpu' or 'cuda'
    Returns:
        Loaded model ready for inference, or None if loading fails
    """
    cache_key = f"{sequence_length}_{device}"
    # Serve from the cache when this (length, device) pair was loaded before.
    cached = _model_cache.get(cache_key)
    if cached is not None:
        print(f"Using cached model for {sequence_length} frames")
        return cached
    # Resolve the checkpoint path for this sequence length.
    model_path = get_accurate_model(sequence_length)
    if not model_path:
        return None
    print(f"Loading model: {model_path}")
    try:
        model = Model(num_classes=2)
        # NOTE(review): torch.load unpickles arbitrary objects — only load
        # checkpoints from trusted sources.
        if device == "cuda" and torch.cuda.is_available():
            model = model.cuda()
            state = torch.load(model_path)
        else:
            # Fall back to CPU (also when CUDA was requested but unavailable).
            model = model.cpu()
            state = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state)
        model.eval()  # inference mode: disables dropout
        _model_cache[cache_key] = model
        print(f"Model loaded successfully for {sequence_length} frames")
        return model
    except Exception as e:
        # Boundary handler: report and signal failure rather than crash.
        print(f"Error loading model: {e}")
        return None
def get_device() -> str:
    """Detect available device (GPU or CPU)"""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"