""" Deepfake Hunter - Model Management Module Handles loading, caching, and managing pre-trained models for deepfake detection. Features: - Automatic model downloading from Hugging Face Hub - Model caching for performance - GPU acceleration with CPU fallback - Model ensemble voting - Version management Author: Deepfake Hunter Team License: MIT """ import warnings warnings.filterwarnings('ignore') from typing import Dict, List, Optional, Any, Union from pathlib import Path from dataclasses import dataclass import hashlib import json import os import torch import torch.nn as nn from torchvision import models import numpy as np from loguru import logger from huggingface_hub import hf_hub_download, list_repo_files from tqdm import tqdm @dataclass class ModelConfig: """Configuration for a pre-trained model""" name: str repo_id: str filename: str model_type: str # 'efficientnet', 'resnet', '3dcnn', etc. input_size: tuple version: str checksum: str = "" url: Optional[str] = None class ModelCache: """ Manages model caching to avoid re-downloading and re-loading Models are cached in memory and on disk. """ def __init__(self, cache_dir: Optional[Path] = None): if cache_dir is None: cache_dir = Path.home() / ".cache" / "deepfake-hunter" / "models" self.cache_dir = Path(cache_dir) self.cache_dir.mkdir(parents=True, exist_ok=True) self._memory_cache: Dict[str, nn.Module] = {} self._config_cache: Dict[str, ModelConfig] = {} logger.info(f"ModelCache initialized at: {self.cache_dir}") def get_cache_path(self, model_name: str) -> Path: """Get path for cached model file""" return self.cache_dir / f"{model_name}.pth" def is_cached(self, model_name: str) -> bool: """Check if model is cached on disk""" return self.get_cache_path(model_name).exists() def load_from_cache(self, model_name: str, device: str = "cpu") -> Optional[nn.Module]: """Load model from memory or disk cache""" # Check memory cache first if model_name in self._memory_cache: logger.info(f"Loading {model_name} from memory cache") return self._memory_cache[model_name] # Check disk cache cache_path = self.get_cache_path(model_name) if cache_path.exists(): try: logger.info(f"Loading {model_name} from disk cache") model = torch.load(cache_path, map_location=device) self._memory_cache[model_name] = model return model except Exception as e: logger.error(f"Failed to load cached model {model_name}: {e}") return None return None def save_to_cache(self, model_name: str, model: nn.Module): """Save model to memory and disk cache""" try: # Save to memory self._memory_cache[model_name] = model # Save to disk cache_path = self.get_cache_path(model_name) torch.save(model, cache_path) logger.info(f"Saved {model_name} to cache") except Exception as e: logger.error(f"Failed to cache model {model_name}: {e}") def clear_cache(self, model_name: Optional[str] = None): """Clear cache for specific model or all models""" if model_name: # Clear specific model if model_name in self._memory_cache: del self._memory_cache[model_name] cache_path = self.get_cache_path(model_name) if cache_path.exists(): cache_path.unlink() logger.info(f"Cleared cache for {model_name}") else: # Clear all self._memory_cache.clear() for cache_file in self.cache_dir.glob("*.pth"): cache_file.unlink() logger.info("Cleared all model caches") class EfficientNetDetector(nn.Module): """ EfficientNet-based spatial artifact detector Fine-tuned on FaceForensics++ dataset for deepfake detection """ def __init__(self, num_classes: int = 2, pretrained: bool = True): super().__init__() # Load EfficientNet-B4 


class EfficientNetDetector(nn.Module):
    """
    EfficientNet-based spatial artifact detector.

    Fine-tuned on the FaceForensics++ dataset for deepfake detection.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        super().__init__()

        # Load EfficientNet-B4 (good balance of speed and accuracy),
        # using the torchvision >= 0.13 "weights" API
        weights = models.EfficientNet_B4_Weights.DEFAULT if pretrained else None
        self.backbone = models.efficientnet_b4(weights=weights)

        # Replace classifier
        in_features = self.backbone.classifier[1].in_features
        self.backbone.classifier = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(in_features, num_classes),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.backbone(x)

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Get probabilities instead of logits"""
        logits = self.forward(x)
        return torch.softmax(logits, dim=1)


class CNN3DTemporalDetector(nn.Module):
    """
    3D CNN for temporal inconsistency detection.

    Analyzes sequences of frames to detect unnatural temporal patterns.
    """

    def __init__(self, num_classes: int = 2, input_channels: int = 3):
        super().__init__()

        # 3D convolutional layers
        self.conv1 = nn.Conv3d(input_channels, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn1 = nn.BatchNorm3d(64)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2))

        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn2 = nn.BatchNorm3d(128)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        self.conv3 = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn3 = nn.BatchNorm3d(256)
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))

        # Pool to a fixed (time, height, width) so the flatten size below is
        # correct regardless of the exact input clip size (the default
        # (16, 112, 112) clip would otherwise reach the flatten as
        # 256 * 4 * 14 * 14 and break the first linear layer)
        self.adaptive_pool = nn.AdaptiveAvgPool3d((2, 7, 7))

        # Fully connected layers
        self.fc1 = nn.Linear(256 * 2 * 7 * 7, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # x shape: (batch, channels, time, height, width)
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)

        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)

        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)

        x = self.adaptive_pool(x)

        # Flatten
        x = x.view(x.size(0), -1)

        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        x = self.fc2(x)

        return x
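
# Shape sketch for the temporal detector (assumes the default clip size of
# (16, 112, 112) declared in DEFAULT_MODELS below; batch size 2 is arbitrary):
#
#   detector = CNN3DTemporalDetector()
#   clip = torch.randn(2, 3, 16, 112, 112)  # (batch, channels, time, H, W)
#   logits = detector(clip)                 # -> shape (2, 2)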


class ModelLoader:
    """
    Main model loader and manager.

    Handles downloading, loading, and managing all detection models.

    Usage:
        loader = ModelLoader(use_gpu=True)
        models = loader.load_all_models()
        spatial_model = models['spatial_efficientnet']
    """

    # Default model configurations
    DEFAULT_MODELS = {
        'spatial_efficientnet': ModelConfig(
            name='spatial_efficientnet',
            repo_id='deepfake-hunter/efficientnet-b4-ff++',
            filename='efficientnet_b4_ffpp.pth',
            model_type='efficientnet',
            input_size=(224, 224),
            version='1.0.0',
        ),
        'temporal_3dcnn': ModelConfig(
            name='temporal_3dcnn',
            repo_id='deepfake-hunter/3dcnn-temporal',
            filename='3dcnn_temporal.pth',
            model_type='3dcnn',
            input_size=(16, 112, 112),  # (time, height, width)
            version='1.0.0',
        ),
    }

    def __init__(self,
                 use_gpu: bool = True,
                 cache_dir: Optional[Path] = None,
                 download_if_missing: bool = True):
        """
        Initialize model loader

        Args:
            use_gpu: Use GPU if available
            cache_dir: Directory for model cache
            download_if_missing: Auto-download models if not cached
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.cache = ModelCache(cache_dir)
        self.download_if_missing = download_if_missing

        logger.info(f"ModelLoader initialized on {self.device}")

    def download_model(self, config: ModelConfig) -> Path:
        """
        Download model from Hugging Face Hub

        Args:
            config: ModelConfig with download information

        Returns:
            Path to downloaded model file
        """
        try:
            logger.info(f"Downloading {config.name} from {config.repo_id}")

            # For now, we create placeholder models since the HF repos do not
            # exist yet. In production this would actually download, e.g. via
            # hf_hub_download(repo_id=config.repo_id, filename=config.filename).
            model_path = self.cache.get_cache_path(config.name)

            if not model_path.exists():
                logger.warning(f"Model {config.name} not available on HF Hub (placeholder)")

                # Create a randomly initialized model as placeholder
                if config.model_type == 'efficientnet':
                    model = EfficientNetDetector(pretrained=True)
                elif config.model_type == '3dcnn':
                    model = CNN3DTemporalDetector()
                else:
                    raise ValueError(f"Unknown model type: {config.model_type}")

                # Save placeholder weights (state dict only)
                torch.save(model.state_dict(), model_path)
                logger.info(f"Created placeholder model: {model_path}")

            return model_path

        except Exception as e:
            logger.error(f"Failed to download model {config.name}: {e}")
            raise

    def load_model(self, model_name: str, config: Optional[ModelConfig] = None) -> nn.Module:
        """
        Load a specific model

        Args:
            model_name: Name of the model to load
            config: Optional custom ModelConfig

        Returns:
            Loaded PyTorch model
        """
        # Check cache first
        cached_model = self.cache.load_from_cache(model_name, self.device)
        if cached_model is not None:
            cached_model.eval()
            return cached_model

        # Get config
        if config is None:
            config = self.DEFAULT_MODELS.get(model_name)
            if config is None:
                raise ValueError(f"Unknown model: {model_name}")

        # Download if needed
        if self.download_if_missing:
            model_path = self.download_model(config)
        else:
            model_path = self.cache.get_cache_path(model_name)
            if not model_path.exists():
                raise FileNotFoundError(f"Model not cached: {model_name}")

        # Create model architecture
        if config.model_type == 'efficientnet':
            model = EfficientNetDetector(pretrained=False)
        elif config.model_type == '3dcnn':
            model = CNN3DTemporalDetector()
        else:
            raise ValueError(f"Unknown model type: {config.model_type}")

        # Load weights
        try:
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
            logger.info(f"Loaded {model_name} from {model_path}")
        except Exception as e:
            logger.warning(f"Failed to load state dict: {e}, using initialized model")

        # Move to device
        model = model.to(self.device)
        model.eval()

        # Cache in memory
        self.cache.save_to_cache(model_name, model)

        return model
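
    # Illustrative use of load_model (a sketch; 'org/repo' and 'custom.pth'
    # are hypothetical placeholders, not real artifacts):
    #
    #   loader = ModelLoader(use_gpu=False)
    #   spatial = loader.load_model('spatial_efficientnet')
    #   cfg = ModelConfig(name='custom', repo_id='org/repo',
    #                     filename='custom.pth', model_type='efficientnet',
    #                     input_size=(224, 224), version='0.1.0')
    #   custom = loader.load_model('custom', config=cfg)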
    def load_all_models(self) -> Dict[str, nn.Module]:
        """
        Load all default models

        Returns:
            Dictionary mapping model names to loaded models
        """
        loaded = {}
        for model_name in self.DEFAULT_MODELS:
            try:
                loaded[model_name] = self.load_model(model_name)
            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
        return loaded

    def verify_model(self, model_name: str) -> bool:
        """
        Verify model integrity using a SHA-256 checksum

        Args:
            model_name: Name of model to verify

        Returns:
            True if model passes verification
        """
        model_path = self.cache.get_cache_path(model_name)

        if not model_path.exists():
            return False

        # Compute checksum
        sha256 = hashlib.sha256()
        with open(model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(4096), b''):
                sha256.update(chunk)
        checksum = sha256.hexdigest()

        # Compare with the expected checksum (if one is configured)
        config = self.DEFAULT_MODELS.get(model_name)
        if config and config.checksum:
            if checksum != config.checksum:
                logger.warning(f"Checksum mismatch for {model_name}")
                return False

        return True

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about available models

        Returns:
            Dictionary with model information
        """
        info = {
            'device': self.device,
            'cache_dir': str(self.cache.cache_dir),
            'models': {},
        }

        for model_name, config in self.DEFAULT_MODELS.items():
            is_cached = self.cache.is_cached(model_name)
            is_verified = self.verify_model(model_name) if is_cached else False

            info['models'][model_name] = {
                'version': config.version,
                'type': config.model_type,
                'cached': is_cached,
                'verified': is_verified,
                'input_size': config.input_size,
            }

        return info


class EnsemblePredictor:
    """
    Ensemble multiple models for more robust predictions.

    Uses probability averaging or majority voting to combine predictions
    from multiple models.
    """

    def __init__(self, models: Dict[str, nn.Module], device: str = "cuda"):
        self.models = models
        self.device = device

        # Set all models to eval mode
        for model in self.models.values():
            model.eval()

        logger.info(f"EnsemblePredictor initialized with {len(models)} models")

    def predict(self, x: torch.Tensor, method: str = "average") -> torch.Tensor:
        """
        Make ensemble prediction

        Args:
            x: Input tensor
            method: "average" or "voting"

        Returns:
            Averaged class probabilities for "average", or majority-vote
            class indices for "voting"
        """
        predictions = []

        with torch.no_grad():
            for model_name, model in self.models.items():
                try:
                    if hasattr(model, 'predict_proba'):
                        pred = model.predict_proba(x)
                    else:
                        pred = torch.softmax(model(x), dim=1)
                    predictions.append(pred)
                except Exception as e:
                    logger.warning(f"Model {model_name} prediction failed: {e}")

        if not predictions:
            raise RuntimeError("All models failed to predict")

        # Combine predictions
        if method == "average":
            # Average probabilities
            ensemble_pred = torch.mean(torch.stack(predictions), dim=0)
        elif method == "voting":
            # Majority voting over per-model argmax classes
            votes = torch.stack([torch.argmax(p, dim=1) for p in predictions])
            ensemble_pred = torch.mode(votes, dim=0).values
        else:
            raise ValueError(f"Unknown ensemble method: {method}")

        return ensemble_pred
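
# End-to-end sketch (untested wiring, for orientation only): the default
# spatial and temporal models expect different input shapes, so a real
# ensemble would be built from models sharing one preprocessing contract.
#
#   loader = ModelLoader(use_gpu=True)
#   spatial_only = {'spatial_efficientnet': loader.load_model('spatial_efficientnet')}
#   ensemble = EnsemblePredictor(spatial_only, device=loader.device)
#   faces = torch.randn(8, 3, 224, 224)  # preprocessed face crops
#   probs = ensemble.predict(faces, method="average")  # -> shape (8, 2)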


# CLI for downloading models
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Deepfake Hunter Model Loader")
    parser.add_argument("--download-all", action="store_true", help="Download all models")
    parser.add_argument("--list", action="store_true", help="List available models")
    parser.add_argument("--verify", action="store_true", help="Verify all cached models")
    parser.add_argument("--clear-cache", action="store_true", help="Clear model cache")
    parser.add_argument("--gpu", action="store_true", help="Use GPU if available")

    args = parser.parse_args()

    loader = ModelLoader(use_gpu=args.gpu)

    if args.list:
        info = loader.get_model_info()
        print("\n=== Model Information ===")
        print(f"Device: {info['device']}")
        print(f"Cache Directory: {info['cache_dir']}\n")
        for model_name, model_info in info['models'].items():
            print(f"{model_name}:")
            print(f"  Version: {model_info['version']}")
            print(f"  Type: {model_info['type']}")
            print(f"  Cached: {model_info['cached']}")
            print(f"  Verified: {model_info['verified']}")
            print()

    if args.download_all:
        print("\n=== Downloading All Models ===")
        downloaded = loader.load_all_models()
        print(f"\nSuccessfully loaded {len(downloaded)} models")

    if args.verify:
        print("\n=== Verifying Models ===")
        for model_name in ModelLoader.DEFAULT_MODELS:
            verified = loader.verify_model(model_name)
            status = "✓" if verified else "✗"
            print(f"{status} {model_name}")

    if args.clear_cache:
        print("\n=== Clearing Cache ===")
        loader.cache.clear_cache()
        print("Cache cleared")

    if not any([args.list, args.download_all, args.verify, args.clear_cache]):
        parser.print_help()
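
# Typical invocations (assuming this file is saved as model_loader.py):
#
#   python model_loader.py --list
#   python model_loader.py --download-all --gpu
#   python model_loader.py --verify
#   python model_loader.py --clear-cache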