| | import argparse |
| | import torch |
| | import cv2 |
| | import os |
| | import glob |
| | import numpy as np |
| | import ssl |
| | |
# NOTE(security): this swaps the factory used for default HTTPS contexts with
# the unverified one, so HTTPS connections created through that default in this
# process skip certificate verification. Nothing in this file shows why it is
# needed (presumably to work around certificate errors during downloads --
# TODO confirm); prefer fixing the local CA bundle over keeping this override.
ssl._create_default_https_context = ssl._create_unverified_context
| |
|
| | import albumentations as A |
| | from albumentations.pytorch import ToTensorV2 |
| | from src.models import DeepfakeDetector |
| | from src.config import Config |
| |
|
# safetensors is an optional dependency: record whether ``load_file`` could be
# imported so checkpoint loading can decide how to read ``.safetensors`` files.
try:
    from safetensors.torch import load_file
except ImportError:
    SAFETENSORS_AVAILABLE = False
else:
    SAFETENSORS_AVAILABLE = True
| |
|
def get_transform():
    """Build the preprocessing pipeline applied to every image before inference.

    The pipeline resizes the image to ``Config.IMAGE_SIZE`` on both sides,
    normalizes it with the fixed per-channel mean/std given below, and converts
    the result to a tensor with ``ToTensorV2``.

    Returns:
        The composed albumentations transform; call it as
        ``transform(image=...)`` and read the ``'image'`` entry of the result.
    """
    side = Config.IMAGE_SIZE
    pipeline_steps = [
        A.Resize(side, side),
        A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ToTensorV2(),
    ]
    return A.Compose(pipeline_steps)
| |
|
def _resolve_checkpoint_paths(checkpoints_arg):
    """Expand the checkpoints argument into a list of non-empty file paths."""
    if os.path.isdir(checkpoints_arg):
        # Prefer safetensors files; fall back to .pth only when none exist.
        # Sorted so that load order and log output are deterministic.
        found = sorted(glob.glob(os.path.join(checkpoints_arg, "*.safetensors")))
        if not found:
            found = sorted(glob.glob(os.path.join(checkpoints_arg, "*.pth")))
        return found

    candidates = (piece.strip() for piece in checkpoints_arg.split(','))
    return [piece for piece in candidates if piece]


def _read_state_dict(path, device):
    """Read the raw state dict stored at *path* (safetensors or torch format)."""
    if path.endswith(".safetensors"):
        if not SAFETENSORS_AVAILABLE:
            # torch.load cannot parse this format, so fail with a clear reason
            # instead of a confusing unpickling error.
            raise RuntimeError(
                "the 'safetensors' package is not installed; "
                "cannot read .safetensors checkpoints"
            )
        return load_file(path)

    # NOTE(security): torch.load unpickles the file and can execute arbitrary
    # code depending on the torch version/defaults; only load checkpoints from
    # trusted sources.
    return torch.load(path, map_location=device)


def _remap_legacy_keys(state_dict):
    """Return a copy of *state_dict* with legacy ``rgb_branch.*`` prefixes renamed."""
    from collections import OrderedDict

    legacy_prefixes = (
        ('rgb_branch.features.', 'rgb_branch.net.features.'),
        ('rgb_branch.avgpool.', 'rgb_branch.net.avgpool.'),
    )
    remapped = OrderedDict()
    for key, value in state_dict.items():
        for old_prefix, new_prefix in legacy_prefixes:
            if key.startswith(old_prefix):
                key = new_prefix + key[len(old_prefix):]
                break
        remapped[key] = value
    return remapped


def load_models(checkpoints_arg, device):
    """Load one or multiple models for ensemble inference.

    Args:
        checkpoints_arg: A directory (all ``*.safetensors`` files in it, or all
            ``*.pth`` files when there are no safetensors files), a single
            checkpoint path, or a comma-separated list of checkpoint paths.
        device: Device the models are moved to; also used as ``map_location``
            when reading torch-format checkpoints.

    Returns:
        A non-empty list of models in eval mode. Checkpoints that cannot be
        read or applied are reported and skipped. If none can be loaded, the
        list holds a single randomly initialized model so the rest of the
        pipeline can still be exercised.
    """
    paths = _resolve_checkpoint_paths(checkpoints_arg)

    models = []
    print(f"Loading {len(paths)} model(s) for ensemble inference...")

    for path in paths:
        print(f"Loading: {path}")
        name = os.path.basename(path)

        # Read the file first and on its own: if this fails there is no state
        # dict to remap, and a state dict left over from a previous checkpoint
        # must never be applied to this model.
        try:
            state_dict = _read_state_dict(path, device)
        except Exception as e:  # best-effort: report and try the next checkpoint
            print(f"❌ Failed to load {path}: {e}")
            continue

        model = DeepfakeDetector(pretrained=False)
        model.to(device)
        model.eval()

        try:
            model.load_state_dict(state_dict)
        except Exception as e:
            print(f"⚠️ Initial load failed. Attempting legacy key remapping for {name}...")
            try:
                result = model.load_state_dict(_remap_legacy_keys(state_dict), strict=False)
            except Exception as e2:
                print(f"❌ Failed to load {path}: {e}")
                print(f"❌ Remapping also failed: {e2}")
                continue
            # strict=False silently ignores key mismatches; surface them so a
            # partially initialized model is not mistaken for a clean load.
            if result.missing_keys or result.unexpected_keys:
                print(
                    f"⚠️ Non-strict load for {name}: "
                    f"{len(result.missing_keys)} missing key(s), "
                    f"{len(result.unexpected_keys)} unexpected key(s)"
                )
            print(f"✅ Successfully loaded (with remapping): {name}")
        else:
            print(f"✅ Successfully loaded: {name}")

        models.append(model)

    if not models:
        # Deliberate fallback: keep the pipeline runnable even without weights.
        print("Warning: No valid checkoints loaded. Using random initialization for testing flow.")
        model = DeepfakeDetector(pretrained=False).to(device)
        model.eval()
        models.append(model)

    return models
| |
|
def predict_ensemble(models, image_path, device, transform):
    """Run every model on one image and average their sigmoid outputs.

    Args:
        models: Sequence of models; each is called with a single-image batch
            and its output is passed through ``torch.sigmoid(...).item()``.
        image_path: Path of the image file to read with OpenCV.
        device: Device the image tensor is moved to before inference.
        transform: Callable invoked as ``transform(image=rgb_image)`` whose
            result provides the tensor under the ``'image'`` key.

    Returns:
        ``(avg_prob, None)`` on success, where ``avg_prob`` is the mean of the
        per-model sigmoid outputs, or ``(None, message)`` with a non-empty
        message when there are no models or the image cannot be read.
    """
    # Without models the average below would divide by zero; report it through
    # the same (None, message) channel used for unreadable images.
    if not models:
        return None, "Error: No models provided"

    try:
        image = cv2.imread(image_path)
        if image is None:
            return None, "Error: Could not read image"
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    except Exception as e:
        # Fall back to repr() so the message is never empty; callers use the
        # message to detect failure.
        return None, str(e) or repr(e)

    augmented = transform(image=image)
    image_tensor = augmented['image'].unsqueeze(0).to(device)

    probs = []
    with torch.no_grad():
        for model in models:
            logits = model(image_tensor)
            probs.append(torch.sigmoid(logits).item())

    avg_prob = sum(probs) / len(probs)
    return avg_prob, None
| |
|
def main():
    """Command-line entry point: classify one image or a directory of images.

    Parses ``--source``, ``--checkpoints`` and ``--device``, loads the model
    ensemble, then prints one table row per image with the predicted label
    (FAKE when the averaged output exceeds 0.5, otherwise REAL) and the
    confidence of that label. Images that cannot be processed are printed as
    ERROR rows and skipped.
    """
    parser = argparse.ArgumentParser(description="Deepfake Detection Inference (Ensemble Support)")
    parser.add_argument("--source", type=str, required=True, help="Path to image or directory")
    parser.add_argument("--checkpoints", type=str, default=Config.ACTIVE_MODEL_PATH, help="Path to checkpoint file or directory (Default: Mark-V)")
    parser.add_argument("--device", type=str, default=Config.DEVICE, help="Device to use (cuda/mps/cpu)")
    args = parser.parse_args()

    device = torch.device(args.device)
    print(f"Using device: {device}")

    models = load_models(args.checkpoints, device)
    transform = get_transform()

    if os.path.isdir(args.source):
        image_suffixes = ('.png', '.jpg', '.jpeg', '.webp')
        # Sorted so the report order is stable across runs and filesystems.
        files = sorted(
            f for f in glob.glob(os.path.join(args.source, "*.*"))
            if f.lower().endswith(image_suffixes)
        )
    else:
        files = [args.source]

    print(f"Processing {len(files)} images with {len(models)} model(s)...")
    print("-" * 65)
    print(f"{'Image Name':<40} | {'Prediction':<10} | {'Confidence':<10}")
    print("-" * 65)

    for file_path in files:
        image_name = os.path.basename(file_path)
        prob, error = predict_ensemble(models, file_path, device, transform)
        # Check for None explicitly: an empty error string is falsy, and a
        # missing probability must never reach the comparison below.
        if error is not None or prob is None:
            print(f"{image_name:<40} | ERROR: {error}")
            continue

        is_fake = prob > 0.5
        label = "FAKE" if is_fake else "REAL"
        # Confidence is reported for the predicted label, so it is >= 50%.
        confidence = prob if is_fake else 1 - prob

        print(f"{image_name:<40} | {label:<10} | {confidence:.2%}")
| |
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
| |
|