deepfake-backend / model_utils.py
Devanshu2025's picture
directory structure
1d2e071
import torch
import torch.nn as nn
from torchvision import models
import glob
import os
from typing import Optional
class Model(nn.Module):
    """
    Video deepfake detection model using ResNeXt50 + LSTM architecture.
    Ported from reference code for production use.

    Every frame is run through a ResNeXt50 (32x4d) backbone whose last two
    children (average pool and fully connected layer) are removed. The
    resulting feature maps are pooled to one vector per frame, fed to an
    LSTM, and classified by a linear layer.

    Args:
        num_classes: Number of output logits.
        latent_dim: LSTM input size. ``forward`` reshapes the pooled
            backbone features to 2048 values per frame, so this has to
            stay at 2048 for the LSTM to accept them.
        lstm_layers: Number of stacked LSTM layers.
        hidden_dim: LSTM hidden size; also the classifier's input size.
        bidirectional: Despite its name, this value ends up as the LSTM's
            ``bias`` flag (see the note in ``__init__``). The LSTM is
            never bidirectional.
    """

    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
        super().__init__()
        # Load pretrained ResNeXt50
        backbone = models.resnext50_32x4d(pretrained=True)
        # Remove the last two layers (avgpool and fc)
        self.model = nn.Sequential(*list(backbone.children())[:-2])
        # NOTE(review): the original call passed `bidirectional` as the
        # fourth positional argument, which nn.LSTM interprets as `bias`.
        # With the default (False) the LSTM therefore has no bias tensors.
        # This is spelled out as a keyword here and deliberately kept:
        # checkpoints saved from this architecture presumably lack LSTM
        # bias entries, so "fixing" it would change the state_dict keys
        # and break loading -- confirm before changing.
        self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bias=bidirectional)
        self.relu = nn.LeakyReLU()
        self.dp = nn.Dropout(0.4)
        # The classifier consumes LSTM outputs, so its width follows
        # hidden_dim (previously hard-coded to 2048, which only worked for
        # the default hidden size).
        self.linear1 = nn.Linear(hidden_dim, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """
        Classify a batch of frame sequences.

        Args:
            x: Tensor unpacked as (batch_size, seq_length, c, h, w).

        Returns:
            Tuple ``(fmap, logits)``: ``fmap`` is the backbone feature map
            for all ``batch_size * seq_length`` frames, ``logits`` is the
            dropout-regularised classifier output.
        """
        batch_size, seq_length, c, h, w = x.shape
        # Fold the sequence axis into the batch axis so the 2D backbone
        # sees individual frames.
        x = x.view(batch_size * seq_length, c, h, w)
        fmap = self.model(x)
        x = self.avgpool(fmap)
        # One 2048-dimensional feature vector per frame.
        x = x.view(batch_size, seq_length, 2048)
        # NOTE(review): the LSTM is built without batch_first=True, so
        # PyTorch treats dim 0 of `x` as the time axis. Left unchanged to
        # keep outputs identical to the reference implementation --
        # confirm whether this is intended.
        x_lstm, _ = self.lstm(x, None)
        return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))
def get_accurate_model(sequence_length: int, models_dir: str = "models") -> Optional[str]:
    """
    Select the best model based on sequence length (frame count).

    Model files are expected to follow the naming pattern
    ``model_{accuracy}_acc_{frames}_frames_*.pt``. Among the files whose
    frame count equals ``sequence_length``, the one with the highest
    accuracy is chosen. Accuracies are compared numerically, so ``100``
    beats ``97`` and ``89`` beats ``9``.

    Args:
        sequence_length: Number of frames to sample from video (10, 20, 40, 60, 80, 100)
        models_dir: Directory containing the model files

    Returns:
        Full path to the selected model file, or None if no model found
    """
    # Get all .pt model files
    model_paths = glob.glob(os.path.join(models_dir, "*.pt"))
    if not model_paths:
        print(f"No models found in {models_dir}")
        return None

    best_name = None
    best_accuracy = float("-inf")
    for model_path in model_paths:
        filename = os.path.basename(model_path)
        parts = filename.split("_")
        try:
            frames = int(parts[3])  # frames count is at index 3
        except (IndexError, ValueError):
            # File does not follow the naming pattern; ignore it.
            continue
        if frames != sequence_length:
            continue

        try:
            accuracy = float(parts[1])  # accuracy is at index 1
        except ValueError:
            # Keep the file as a candidate, but rank it below every model
            # with a parsable accuracy.
            accuracy = float("-inf")

        # Strict comparison keeps the first file seen on ties.
        if best_name is None or accuracy > best_accuracy:
            best_name = filename
            best_accuracy = accuracy

    if best_name is None:
        print(f"No model found for sequence length {sequence_length}")
        return None
    return os.path.join(models_dir, best_name)
# Global model cache to avoid reloading.
# Keys are (sequence_length, resolved_device, models_dir) tuples.
_model_cache = {}


def load_model(sequence_length: int, device: str = "cpu", models_dir: str = "models") -> Optional["Model"]:
    """
    Load the model for the specified sequence length.
    Uses caching to avoid reloading the same model.

    Args:
        sequence_length: Number of frames (10, 20, 40, 60, 80, 100)
        device: 'cpu' or 'cuda'. 'cuda' silently falls back to CPU when
            CUDA is not available; any other value is treated as CPU.
        models_dir: Directory searched for checkpoint files.

    Returns:
        Loaded model ready for inference, or None if loading fails
    """
    # Resolve the device actually used, so that a "cuda" request on a
    # CPU-only host shares the cache entry with an explicit "cpu" request
    # instead of loading a second CPU copy.
    use_cuda = device == "cuda" and torch.cuda.is_available()
    target = "cuda" if use_cuda else "cpu"
    cache_key = (sequence_length, target, models_dir)

    # Check cache first
    cached = _model_cache.get(cache_key)
    if cached is not None:
        print(f"Using cached model for {sequence_length} frames")
        return cached

    # Get the model path
    model_path = get_accurate_model(sequence_length, models_dir)
    if not model_path:
        return None

    print(f"Loading model: {model_path}")
    try:
        # Initialize model
        model = Model(num_classes=2).to(target)
        # NOTE(security): torch.load unpickles the file; only load
        # checkpoints from trusted sources.
        state_dict = torch.load(model_path, map_location=torch.device(target))
        model.load_state_dict(state_dict)
        model.eval()
    except Exception as e:
        # Best-effort API: callers get None rather than an exception.
        print(f"Error loading model: {e}")
        return None

    # Cache the model
    _model_cache[cache_key] = model
    print(f"Model loaded successfully for {sequence_length} frames")
    return model
def get_device() -> str:
    """Return "cuda" when torch reports CUDA as available, otherwise "cpu"."""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"