# NOTE: removed Hugging Face Spaces status banner ("Spaces: Sleeping") left over from page scraping.
| """ | |
| Deepfake Hunter - Model Management Module | |
| Handles loading, caching, and managing pre-trained models for deepfake detection. | |
| Features: | |
| - Automatic model downloading from Hugging Face Hub | |
| - Model caching for performance | |
| - GPU acceleration with CPU fallback | |
| - Model ensemble voting | |
| - Version management | |
| Author: Deepfake Hunter Team | |
| License: MIT | |
| """ | |
import warnings

warnings.filterwarnings('ignore')

# Standard library
import hashlib
import json
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

# Third-party
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from huggingface_hub import hf_hub_download, list_repo_files
from loguru import logger
from torchvision import models
from tqdm import tqdm
@dataclass
class ModelConfig:
    """Configuration for a pre-trained model.

    BUGFIX: the ``@dataclass`` decorator was missing, so the keyword
    constructions in ``ModelLoader.DEFAULT_MODELS`` (``ModelConfig(name=...)``)
    raised ``TypeError`` — a plain class body of annotations generates no
    ``__init__``.
    """
    name: str        # cache key / human-readable model name
    repo_id: str     # Hugging Face Hub repository id
    filename: str    # weights filename inside the repo
    model_type: str  # architecture family: 'efficientnet', 'resnet', '3dcnn', etc.
    input_size: tuple  # expected input size; (H, W) or (T, H, W) for 3D models
    version: str     # semantic version of the released weights
    checksum: str = ""  # expected SHA-256 hex digest; empty string disables verification
    url: Optional[str] = None  # optional direct-download URL overriding the HF repo
class ModelCache:
    """
    Manages model caching to avoid re-downloading and re-loading.

    Models are cached in memory (a per-process dict) and on disk under
    ``<cache_dir>/<model_name>.pth``.

    NOTE: the same ``<name>.pth`` path may transiently hold a bare state
    dict (written by ``ModelLoader.download_model``) rather than a full
    pickled module, so ``load_from_cache`` validates what it deserializes.
    """

    def __init__(self, cache_dir: Optional[Path] = None):
        """
        Args:
            cache_dir: Directory for cached model files. Defaults to
                ``~/.cache/deepfake-hunter/models`` (created if missing).
        """
        if cache_dir is None:
            cache_dir = Path.home() / ".cache" / "deepfake-hunter" / "models"
        self.cache_dir = Path(cache_dir)
        self.cache_dir.mkdir(parents=True, exist_ok=True)
        self._memory_cache: Dict[str, nn.Module] = {}
        self._config_cache: Dict[str, "ModelConfig"] = {}
        logger.info(f"ModelCache initialized at: {self.cache_dir}")

    def get_cache_path(self, model_name: str) -> Path:
        """Return the on-disk path for a cached model file."""
        return self.cache_dir / f"{model_name}.pth"

    def is_cached(self, model_name: str) -> bool:
        """Check whether the model has a file in the disk cache."""
        return self.get_cache_path(model_name).exists()

    def load_from_cache(self, model_name: str, device: str = "cpu") -> Optional[nn.Module]:
        """
        Load a model from the memory cache, falling back to disk.

        Args:
            model_name: Cache key of the model.
            device: Passed to ``torch.load`` as ``map_location``.

        Returns:
            The cached ``nn.Module``, or ``None`` on a miss or load failure.
        """
        # Memory cache is the fast path.
        if model_name in self._memory_cache:
            logger.info(f"Loading {model_name} from memory cache")
            return self._memory_cache[model_name]
        cache_path = self.get_cache_path(model_name)
        if cache_path.exists():
            try:
                logger.info(f"Loading {model_name} from disk cache")
                # NOTE: torch.load unpickles arbitrary objects; acceptable only
                # because the cache dir is written by this application itself.
                model = torch.load(cache_path, map_location=device)
                # BUGFIX: the cache path may hold a bare state dict (see
                # ModelLoader.download_model). Returning that would crash
                # callers at `.eval()`; treat it as a cache miss so the caller
                # rebuilds the architecture and loads the weights properly.
                if not isinstance(model, nn.Module):
                    logger.warning(
                        f"Cached file for {model_name} is not a full model; treating as cache miss"
                    )
                    return None
                self._memory_cache[model_name] = model
                return model
            except Exception as e:
                logger.error(f"Failed to load cached model {model_name}: {e}")
                return None
        return None

    def save_to_cache(self, model_name: str, model: nn.Module):
        """Save a model to both the memory and disk caches (best-effort)."""
        try:
            self._memory_cache[model_name] = model
            cache_path = self.get_cache_path(model_name)
            torch.save(model, cache_path)
            logger.info(f"Saved {model_name} to cache")
        except Exception as e:
            # Best-effort: a failed cache write must not break inference.
            logger.error(f"Failed to cache model {model_name}: {e}")

    def clear_cache(self, model_name: Optional[str] = None):
        """Clear the cache for one model, or for all models when name is None."""
        if model_name:
            # Drop both the in-memory entry and the on-disk file.
            self._memory_cache.pop(model_name, None)
            cache_path = self.get_cache_path(model_name)
            if cache_path.exists():
                cache_path.unlink()
            logger.info(f"Cleared cache for {model_name}")
        else:
            self._memory_cache.clear()
            for cache_file in self.cache_dir.glob("*.pth"):
                cache_file.unlink()
            logger.info("Cleared all model caches")
class EfficientNetDetector(nn.Module):
    """
    Spatial artifact detector built on EfficientNet-B4.

    Fine-tuned on the FaceForensics++ dataset for deepfake detection;
    B4 offers a good balance between speed and accuracy.
    """

    def __init__(self, num_classes: int = 2, pretrained: bool = True):
        super().__init__()
        # Backbone: EfficientNet-B4, optionally with ImageNet weights.
        self.backbone = models.efficientnet_b4(pretrained=pretrained)
        # Swap the stock ImageNet head for a small real/fake classifier,
        # keeping the backbone's feature dimensionality.
        feature_dim = self.backbone.classifier[1].in_features
        head = nn.Sequential(
            nn.Dropout(p=0.4, inplace=True),
            nn.Linear(feature_dim, num_classes),
        )
        self.backbone.classifier = head

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return raw classification logits for a batch of images."""
        return self.backbone(x)

    def predict_proba(self, x: torch.Tensor) -> torch.Tensor:
        """Return class probabilities (softmax over the forward logits)."""
        return torch.softmax(self.forward(x), dim=1)
class CNN3DTemporalDetector(nn.Module):
    """
    3D CNN for temporal inconsistency detection.

    Analyzes short frame sequences shaped (batch, channels, time, height,
    width) to detect unnatural temporal patterns in generated video.
    """

    def __init__(self, num_classes: int = 2, input_channels: int = 3):
        """
        Args:
            num_classes: Number of output classes (default 2: real/fake).
            input_channels: Channels per frame (default 3 for RGB).
        """
        super().__init__()
        # 3D convolutional feature extractor.
        self.conv1 = nn.Conv3d(input_channels, 64, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn1 = nn.BatchNorm3d(64)
        self.pool1 = nn.MaxPool3d(kernel_size=(1, 2, 2))  # spatial-only downsample
        self.conv2 = nn.Conv3d(64, 128, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn2 = nn.BatchNorm3d(128)
        self.pool2 = nn.MaxPool3d(kernel_size=(2, 2, 2))
        self.conv3 = nn.Conv3d(128, 256, kernel_size=(3, 3, 3), padding=(1, 1, 1))
        self.bn3 = nn.BatchNorm3d(256)
        self.pool3 = nn.MaxPool3d(kernel_size=(2, 2, 2))
        # BUGFIX: the original flattened straight into a hard-coded
        # Linear(256*2*7*7, ...), which does not match the configured
        # (16, 112, 112) clips (those flatten to 256*4*14*14) and crashed
        # forward(). Adaptive pooling pins the feature map to (2, 7, 7) for
        # ANY input size, keeping fc1's weight shape — and therefore any
        # cached checkpoints — unchanged (AdaptiveAvgPool3d has no params).
        self.adaptive_pool = nn.AdaptiveAvgPool3d((2, 7, 7))
        # Classifier head.
        self.fc1 = nn.Linear(256 * 2 * 7 * 7, 512)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(512, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return logits for a batch of clips.

        Args:
            x: Tensor of shape (batch, channels, time, height, width).

        Returns:
            Logits of shape (batch, num_classes).
        """
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = F.relu(self.bn2(self.conv2(x)))
        x = self.pool2(x)
        x = F.relu(self.bn3(self.conv3(x)))
        x = self.pool3(x)
        # Normalize the feature map to a fixed size before flattening.
        x = self.adaptive_pool(x)
        x = x.flatten(1)
        x = F.relu(self.fc1(x))
        x = self.dropout(x)
        return self.fc2(x)
class ModelLoader:
    """
    Main model loader and manager.

    Handles downloading, loading, and managing all detection models.

    Usage:
        loader = ModelLoader(use_gpu=True)
        models = loader.load_all_models()
        spatial_model = models['spatial_efficientnet']
    """

    # Default model configurations, keyed by cache name.
    DEFAULT_MODELS = {
        'spatial_efficientnet': ModelConfig(
            name='spatial_efficientnet',
            repo_id='deepfake-hunter/efficientnet-b4-ff++',
            filename='efficientnet_b4_ffpp.pth',
            model_type='efficientnet',
            input_size=(224, 224),
            version='1.0.0'
        ),
        'temporal_3dcnn': ModelConfig(
            name='temporal_3dcnn',
            repo_id='deepfake-hunter/3dcnn-temporal',
            filename='3dcnn_temporal.pth',
            model_type='3dcnn',
            input_size=(16, 112, 112),  # (time, height, width)
            version='1.0.0'
        ),
    }

    def __init__(self,
                 use_gpu: bool = True,
                 cache_dir: Optional[Path] = None,
                 download_if_missing: bool = True):
        """
        Initialize model loader.

        Args:
            use_gpu: Use GPU if available (falls back to CPU otherwise).
            cache_dir: Directory for model cache.
            download_if_missing: Auto-download models if not cached.
        """
        self.device = "cuda" if use_gpu and torch.cuda.is_available() else "cpu"
        self.cache = ModelCache(cache_dir)
        self.download_if_missing = download_if_missing
        logger.info(f"ModelLoader initialized on {self.device}")

    def _build_architecture(self, config: ModelConfig, pretrained: bool) -> nn.Module:
        """Instantiate the (untrained) architecture described by ``config``.

        Shared by download_model and load_model so the type dispatch lives
        in exactly one place.

        Raises:
            ValueError: for an unrecognized ``config.model_type``.
        """
        if config.model_type == 'efficientnet':
            return EfficientNetDetector(pretrained=pretrained)
        if config.model_type == '3dcnn':
            return CNN3DTemporalDetector()
        raise ValueError(f"Unknown model type: {config.model_type}")

    def download_model(self, config: ModelConfig) -> Path:
        """
        Download model from Hugging Face Hub.

        Currently writes a randomly initialized placeholder state dict
        instead of performing a real download (the HF repos do not exist
        yet); in production this would call ``hf_hub_download``.

        Args:
            config: ModelConfig with download information.

        Returns:
            Path to the model weights file on disk.

        Raises:
            ValueError: for an unknown model type (re-raised after logging).
        """
        try:
            logger.info(f"Downloading {config.name} from {config.repo_id}")
            model_path = self.cache.get_cache_path(config.name)
            if not model_path.exists():
                logger.warning(f"Model {config.name} not available on HF Hub (placeholder)")
                model = self._build_architecture(config, pretrained=True)
                # Save only the weights; load_model() rebuilds the architecture.
                torch.save(model.state_dict(), model_path)
                logger.info(f"Created placeholder model: {model_path}")
            return model_path
        except Exception as e:
            logger.error(f"Failed to download model {config.name}: {e}")
            raise

    def load_model(self,
                   model_name: str,
                   config: Optional[ModelConfig] = None) -> nn.Module:
        """
        Load a specific model.

        Args:
            model_name: Name of the model to load.
            config: Optional custom ModelConfig.

        Returns:
            Loaded PyTorch model, in eval mode, on ``self.device``.

        Raises:
            ValueError: unknown model name or model type.
            FileNotFoundError: model not cached and download disabled.
        """
        # Check cache first. BUGFIX: the cache file can transiently hold a
        # bare state dict (download_model writes one before save_to_cache
        # replaces it with the full module), and a dict has no `.eval()` —
        # only trust a genuine nn.Module from the cache.
        cached_model = self.cache.load_from_cache(model_name, self.device)
        if isinstance(cached_model, nn.Module):
            cached_model.eval()
            return cached_model
        # Resolve configuration.
        if config is None:
            config = self.DEFAULT_MODELS.get(model_name)
            if config is None:
                raise ValueError(f"Unknown model: {model_name}")
        # Download if needed.
        if self.download_if_missing:
            model_path = self.download_model(config)
        else:
            model_path = self.cache.get_cache_path(model_name)
            if not model_path.exists():
                raise FileNotFoundError(f"Model not cached: {model_name}")
        # Create model architecture; the weights come from the file below.
        model = self._build_architecture(config, pretrained=False)
        # Load weights.
        try:
            state_dict = torch.load(model_path, map_location=self.device)
            model.load_state_dict(state_dict)
            logger.info(f"Loaded {model_name} from {model_path}")
        except Exception as e:
            # Deliberate best-effort: fall back to the freshly initialized net
            # rather than failing the whole pipeline.
            logger.warning(f"Failed to load state dict: {e}, using initialized model")
        # Move to device and freeze for inference.
        model = model.to(self.device)
        model.eval()
        # Cache for subsequent calls (memory + disk).
        self.cache.save_to_cache(model_name, model)
        return model

    def load_all_models(self) -> Dict[str, nn.Module]:
        """
        Load all default models.

        Returns:
            Dictionary mapping model names to loaded models; models that
            fail to load are logged and omitted.
        """
        # Named `loaded` (not `models`) to avoid shadowing torchvision.models.
        loaded: Dict[str, nn.Module] = {}
        for model_name in self.DEFAULT_MODELS:
            try:
                loaded[model_name] = self.load_model(model_name)
            except Exception as e:
                logger.error(f"Failed to load {model_name}: {e}")
        return loaded

    def verify_model(self, model_name: str) -> bool:
        """
        Verify model integrity using a SHA-256 checksum.

        Args:
            model_name: Name of model to verify.

        Returns:
            True if the file exists and matches the configured checksum
            (models with no configured checksum pass trivially).
        """
        model_path = self.cache.get_cache_path(model_name)
        if not model_path.exists():
            return False
        # Stream in chunks so large weight files are never read at once.
        sha256 = hashlib.sha256()
        with open(model_path, 'rb') as f:
            for chunk in iter(lambda: f.read(65536), b''):
                sha256.update(chunk)
        checksum = sha256.hexdigest()
        # Compare with expected checksum, if one is configured.
        config = self.DEFAULT_MODELS.get(model_name)
        if config and config.checksum:
            if checksum != config.checksum:
                logger.warning(f"Checksum mismatch for {model_name}")
                return False
        return True

    def get_model_info(self) -> Dict[str, Any]:
        """
        Get information about available models.

        Returns:
            Dict with the active device, cache directory, and per-model
            version/type/cached/verified/input_size details.
        """
        info = {
            'device': self.device,
            'cache_dir': str(self.cache.cache_dir),
            'models': {}
        }
        for model_name, config in self.DEFAULT_MODELS.items():
            is_cached = self.cache.is_cached(model_name)
            # Only hash files that actually exist.
            is_verified = self.verify_model(model_name) if is_cached else False
            info['models'][model_name] = {
                'version': config.version,
                'type': config.model_type,
                'cached': is_cached,
                'verified': is_verified,
                'input_size': config.input_size
            }
        return info
class EnsemblePredictor:
    """
    Ensemble multiple models for more robust predictions.

    Combines the member models either by averaging their class
    probabilities or by majority vote over their argmax labels.
    """

    def __init__(self, models: Dict[str, nn.Module], device: str = "cuda"):
        self.models = models
        self.device = device
        # Inference only: force dropout/batch-norm into eval mode.
        for member in self.models.values():
            member.eval()
        logger.info(f"EnsemblePredictor initialized with {len(models)} models")

    def predict(self,
                x: torch.Tensor,
                method: str = "average") -> torch.Tensor:
        """
        Make an ensemble prediction.

        Args:
            x: Input tensor, forwarded to every member model.
            method: "average" (mean of per-model probabilities) or
                "voting" (majority vote over per-model argmax labels).

        Returns:
            For "average": a (batch, num_classes) probability tensor.
            For "voting": a (batch,) tensor of winning class indices.

        Raises:
            RuntimeError: if every member model fails to predict.
            ValueError: if ``method`` is not recognized.
        """
        per_model = []
        with torch.no_grad():
            for model_name, member in self.models.items():
                try:
                    if hasattr(member, 'predict_proba'):
                        probs = member.predict_proba(x)
                    else:
                        probs = torch.softmax(member(x), dim=1)
                except Exception as e:
                    # A single failing member should not sink the ensemble.
                    logger.warning(f"Model {model_name} prediction failed: {e}")
                else:
                    per_model.append(probs)
        if not per_model:
            raise RuntimeError("All models failed to predict")
        if method == "average":
            # Mean of the stacked per-model probability tensors.
            return torch.mean(torch.stack(per_model), dim=0)
        if method == "voting":
            # Majority vote over each model's hard label.
            labels = torch.stack([p.argmax(dim=1) for p in per_model])
            return torch.mode(labels, dim=0).values
        raise ValueError(f"Unknown ensemble method: {method}")
# CLI for downloading and inspecting models.
def main() -> None:
    """Command-line entry point for managing detection models.

    Improvement over the original: the script body ran at module top level
    (leaking `loader`, `args`, and a `models` name that shadowed
    torchvision.models into module scope); it is now wrapped in main() per
    the standard `if __name__ == "__main__": main()` convention. Behavior
    is unchanged.
    """
    import argparse

    parser = argparse.ArgumentParser(description="Deepfake Hunter Model Loader")
    parser.add_argument("--download-all", action="store_true",
                        help="Download all models")
    parser.add_argument("--list", action="store_true",
                        help="List available models")
    parser.add_argument("--verify", action="store_true",
                        help="Verify all cached models")
    parser.add_argument("--clear-cache", action="store_true",
                        help="Clear model cache")
    parser.add_argument("--gpu", action="store_true",
                        help="Use GPU if available")
    args = parser.parse_args()

    loader = ModelLoader(use_gpu=args.gpu)

    if args.list:
        info = loader.get_model_info()
        print("\n=== Model Information ===")
        print(f"Device: {info['device']}")
        print(f"Cache Directory: {info['cache_dir']}\n")
        for model_name, model_info in info['models'].items():
            print(f"{model_name}:")
            print(f" Version: {model_info['version']}")
            print(f" Type: {model_info['type']}")
            print(f" Cached: {model_info['cached']}")
            print(f" Verified: {model_info['verified']}")
            print()

    if args.download_all:
        print("\n=== Downloading All Models ===")
        models = loader.load_all_models()
        print(f"\nSuccessfully loaded {len(models)} models")

    if args.verify:
        print("\n=== Verifying Models ===")
        for model_name in ModelLoader.DEFAULT_MODELS:
            verified = loader.verify_model(model_name)
            status = "✓" if verified else "✗"
            print(f"{status} {model_name}")

    if args.clear_cache:
        print("\n=== Clearing Cache ===")
        loader.cache.clear_cache()
        print("Cache cleared")

    # No action flag given: show usage instead of doing nothing silently.
    if not any([args.list, args.download_all, args.verify, args.clear_cache]):
        parser.print_help()


if __name__ == "__main__":
    main()