# NOTE: removed Hugging Face Spaces page residue ("Spaces:" / "Sleeping")
# that was captured by the scrape — it is not part of the source code.
| import torch | |
| import torch.nn as nn | |
| from torchvision import models | |
| import glob | |
| import os | |
| from typing import Optional | |
class Model(nn.Module):
    """
    Video deepfake detection model using ResNeXt50 + LSTM architecture.
    Ported from reference code for production use.

    A pretrained ResNeXt50 backbone (with its avgpool/fc head removed)
    extracts per-frame feature maps; the 2048-d pooled frame features are
    aggregated over time by an LSTM, and a linear head produces per-video
    class logits.
    """

    def __init__(self, num_classes=2, latent_dim=2048, lstm_layers=1, hidden_dim=2048, bidirectional=False):
        super(Model, self).__init__()
        # Load pretrained ResNeXt50
        model = models.resnext50_32x4d(pretrained=True)
        # Remove the last two layers (avgpool and fc) so self.model emits
        # spatial feature maps rather than classification logits.
        self.model = nn.Sequential(*list(model.children())[:-2])
        # NOTE(review): nn.LSTM's 4th positional parameter is `bias`, not
        # `bidirectional` — this call therefore sets bias=False (and leaves
        # bidirectional at its default False). This matches the reference
        # training code; "fixing" it would change the state_dict layout and
        # break loading of existing checkpoints, so it is kept as-is.
        self.lstm = nn.LSTM(latent_dim, hidden_dim, lstm_layers, bidirectional)
        self.relu = nn.LeakyReLU()
        self.dp = nn.Dropout(0.4)
        self.linear1 = nn.Linear(2048, num_classes)
        self.avgpool = nn.AdaptiveAvgPool2d(1)

    def forward(self, x):
        """
        Run the backbone frame-by-frame, then the temporal head.

        Args:
            x: video batch of shape (batch_size, seq_length, C, H, W)

        Returns:
            Tuple (fmap, logits): the backbone feature maps for the
            flattened frames, and class logits of shape
            (batch_size, num_classes).
        """
        batch_size, seq_length, c, h, w = x.shape
        # Fold time into the batch dimension so the CNN sees single frames.
        x = x.view(batch_size * seq_length, c, h, w)
        fmap = self.model(x)
        x = self.avgpool(fmap)
        x = x.view(batch_size, seq_length, 2048)
        # NOTE(review): the LSTM was constructed with the default
        # batch_first=False (it treats dim 0 as time), yet receives
        # batch-first input here. The indexing below mirrors the reference
        # implementation exactly — kept for checkpoint parity.
        x_lstm, _ = self.lstm(x, None)
        return fmap, self.dp(self.linear1(x_lstm[:, -1, :]))
def get_accurate_model(sequence_length: int, models_dir: str = "models") -> Optional[str]:
    """
    Select the best model checkpoint for a given sequence length.

    Checkpoint filenames follow the pattern
    ``model_{accuracy}_acc_{frames}_frames_*.pt``. Among the files whose
    frame count matches ``sequence_length``, the one with the highest
    accuracy is returned.

    Args:
        sequence_length: Number of frames to sample from video
            (10, 20, 40, 60, 80, 100)
        models_dir: Directory containing the model files

    Returns:
        Full path to the selected model file, or None if no model found
    """
    # Get all .pt model files
    list_models = glob.glob(os.path.join(models_dir, "*.pt"))
    if not list_models:
        print(f"No models found in {models_dir}")
        return None

    # Find models matching the sequence length
    sequence_model = []
    for model_path in list_models:
        model_filename = os.path.basename(model_path)
        try:
            # Model naming pattern: model_{accuracy}_acc_{frames}_frames_*.pt
            parts = model_filename.split("_")
            if int(parts[3]) == sequence_length:  # frames count is at index 3
                sequence_model.append(model_filename)
        except (IndexError, ValueError):
            # Filename doesn't follow the expected pattern — skip it.
            continue

    if not sequence_model:
        print(f"No model found for sequence length {sequence_length}")
        return None

    # BUG FIX: accuracies were previously compared as strings, so e.g.
    # "97" sorted above "100" lexicographically and the best checkpoint
    # could be skipped. Compare numerically instead; malformed accuracy
    # tokens sort last rather than crashing.
    def _accuracy(filename: str) -> float:
        try:
            return float(filename.split("_")[1])  # accuracy is at index 1
        except ValueError:
            return float("-inf")

    best = max(sequence_model, key=_accuracy)
    return os.path.join(models_dir, best)
# Global model cache to avoid reloading: maps "{sequence_length}_{device}"
# cache keys to loaded Model instances; populated by load_model().
_model_cache = {}
def load_model(sequence_length: int, device: str = "cpu") -> Optional[Model]:
    """
    Load the model for the specified sequence length.

    Results are cached per (sequence_length, device) pair so repeated
    calls do not re-read the same checkpoint from disk.

    Args:
        sequence_length: Number of frames (10, 20, 40, 60, 80, 100)
        device: 'cpu' or 'cuda'

    Returns:
        Loaded model ready for inference, or None if loading fails
    """
    cache_key = f"{sequence_length}_{device}"

    # Serve from the module-level cache when possible.
    cached = _model_cache.get(cache_key)
    if cached is not None:
        print(f"Using cached model for {sequence_length} frames")
        return cached

    # Resolve which checkpoint file to load.
    model_path = get_accurate_model(sequence_length)
    if not model_path:
        return None

    print(f"Loading model: {model_path}")
    try:
        model = Model(num_classes=2)
        # NOTE: torch.load unpickles the checkpoint — only load trusted files.
        if device == "cuda" and torch.cuda.is_available():
            model = model.cuda()
            state = torch.load(model_path)
        else:
            # Fall back to CPU (also covers device == "cuda" with no GPU).
            model = model.cpu()
            state = torch.load(model_path, map_location=torch.device('cpu'))
        model.load_state_dict(state)
        model.eval()
        _model_cache[cache_key] = model
        print(f"Model loaded successfully for {sequence_length} frames")
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
def get_device() -> str:
    """Detect available device (GPU or CPU)"""
    if torch.cuda.is_available():
        return "cuda"
    return "cpu"