Harshasnade's picture
Deploy Backend (No Frontend)
0966609
import argparse
import torch
import cv2
import os
import glob
import numpy as np
import ssl
# Disable SSL verification.
# SECURITY NOTE(review): this replaces the process-wide default HTTPS context
# factory with one that skips certificate verification. Stdlib HTTPS clients
# that rely on the default context (e.g. urllib) in this process become
# vulnerable to man-in-the-middle attacks.
# Presumably added to work around a local certificate problem when
# downloading weights. TODO confirm and remove if possible: this script
# itself only builds models with pretrained=False.
ssl._create_default_https_context = ssl._create_unverified_context
import albumentations as A
from albumentations.pytorch import ToTensorV2
from src.models import DeepfakeDetector
from src.config import Config
try:
from safetensors.torch import load_file
SAFETENSORS_AVAILABLE = True
except ImportError:
SAFETENSORS_AVAILABLE = False
def get_transform():
    """Build the preprocessing pipeline applied to every input image.

    The pipeline does three things:
    - resizes the image to a ``Config.IMAGE_SIZE`` x ``Config.IMAGE_SIZE``
      square;
    - normalises the channels with the standard ImageNet mean/std values;
    - converts the result to a PyTorch tensor.

    Returns:
        An ``albumentations.Compose`` object. Call it as ``transform(image=...)``;
        it returns a dict whose ``'image'`` entry holds the tensor.
    """
    side = Config.IMAGE_SIZE
    imagenet_mean = (0.485, 0.456, 0.406)
    imagenet_std = (0.229, 0.224, 0.225)
    steps = [
        A.Resize(side, side),
        A.Normalize(mean=imagenet_mean, std=imagenet_std),
        ToTensorV2(),
    ]
    return A.Compose(steps)
def _resolve_checkpoint_paths(checkpoints_arg):
    """Expand the ``--checkpoints`` argument into a list of checkpoint paths.

    A directory yields its ``*.safetensors`` files, or its ``*.pth`` files when
    it contains no ``.safetensors`` files. Any other value is split on commas;
    surrounding whitespace is stripped and empty entries are dropped.
    """
    if os.path.isdir(checkpoints_arg):
        # glob returns files in arbitrary filesystem order; sorting keeps the
        # ensemble order and the log output reproducible.
        paths = sorted(glob.glob(os.path.join(checkpoints_arg, "*.safetensors")))
        if not paths:
            paths = sorted(glob.glob(os.path.join(checkpoints_arg, "*.pth")))
        return paths
    return [p.strip() for p in checkpoints_arg.split(",") if p.strip()]


def _load_state_dict(path, device):
    """Read a state dict from a ``.safetensors`` file or a ``torch.save`` file."""
    if path.endswith(".safetensors"):
        if not SAFETENSORS_AVAILABLE:
            # Without this check the file would fall through to torch.load,
            # which fails with an unhelpful error on safetensors data.
            raise ImportError(
                "safetensors is not installed; cannot read .safetensors checkpoint"
            )
        return load_file(path)
    # SECURITY NOTE(review): torch.load unpickles the file, so a malicious
    # checkpoint can execute arbitrary code. Only load trusted files, or pass
    # weights_only=True on torch versions that support it (the default since
    # torch 2.6).
    return torch.load(path, map_location=device)


def load_models(checkpoints_arg, device):
    """
    Load one or multiple models for ensemble inference.

    Args:
        checkpoints_arg: A single checkpoint path, a comma-separated list of
            paths, or a directory. A directory contributes its
            ``*.safetensors`` files, falling back to ``*.pth`` only when it
            has none.
        device: ``torch.device`` the models are moved to.

    Returns:
        A non-empty list of ``DeepfakeDetector`` models in eval mode.
        Checkpoints that fail to load are reported and skipped. If none load,
        a single randomly initialised model is returned so the pipeline can
        still be exercised; its predictions are meaningless.
    """
    import traceback

    paths = _resolve_checkpoint_paths(checkpoints_arg)
    models = []
    print(f"Loading {len(paths)} model(s) for ensemble inference...")
    for path in paths:
        print(f"Loading: {path}")
        model = DeepfakeDetector(pretrained=False)  # Structure only
        model.to(device)
        model.eval()
        try:
            model.load_state_dict(_load_state_dict(path, device))
        except Exception as e:
            # Best effort: report the bad checkpoint and continue with the rest.
            print(f"❌ Failed to load {path}: {e}")
            traceback.print_exc()
            continue
        models.append(model)
        print(f"✅ Successfully loaded: {os.path.basename(path)}")

    if not models:
        # Fallback for testing if no checkpoint exists yet.
        print("Warning: No valid checkpoints loaded. Using random initialization for testing flow.")
        model = DeepfakeDetector(pretrained=False).to(device)
        model.eval()
        models.append(model)
    return models
def predict_ensemble(models, image_path, device, transform):
    """Score one image with every model and average the fake probabilities.

    Args:
        models: Models to ensemble. Each is called on a batch of one image
            tensor and must return a one-element output, because ``.item()``
            is applied after the sigmoid.
        image_path: Path of the image file to classify.
        device: ``torch.device`` the input tensor is moved to.
        transform: Preprocessing callable. It is invoked as
            ``transform(image=rgb_array)`` and must return a dict with an
            ``'image'`` tensor.

    Returns:
        On success, ``(avg_prob, None)``, where ``avg_prob`` is the mean
        sigmoid output across models. On failure, ``(None, error_message)``;
        this happens when there are no models or the image cannot be read or
        preprocessed.
    """
    if not models:
        # Guard against ZeroDivisionError when averaging below.
        return None, "Error: No models loaded"
    try:
        image = cv2.imread(image_path)
        if image is None:
            # cv2.imread reports missing/undecodable files by returning None.
            return None, "Error: Could not read image"
        # OpenCV decodes to BGR; convert to RGB before preprocessing.
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        # Preprocessing stays inside the try so that a single malformed image
        # produces an error result instead of aborting a whole batch.
        image_tensor = transform(image=image)['image'].unsqueeze(0).to(device)
    except Exception as e:
        return None, str(e)

    with torch.no_grad():
        probs = [torch.sigmoid(model(image_tensor)).item() for model in models]
    # Ensemble Strategy: Average Probability
    return sum(probs) / len(probs), None
def main():
    """Command-line entry point: classify one image or a directory of images.

    Loads the checkpoint ensemble and scores each image with
    ``predict_ensemble``. It prints one table row per image containing the
    REAL/FAKE label and the confidence of that label.
    """
    parser = argparse.ArgumentParser(description="Deepfake Detection Inference (Ensemble Support)")
    parser.add_argument("--source", type=str, required=True, help="Path to image or directory")
    parser.add_argument("--checkpoints", type=str, default="results/checkpoints", help="Path to checkpoint file, list of files (comma-separated), or directory")
    parser.add_argument("--device", type=str, default=Config.DEVICE, help="Device to use (cuda/mps/cpu)")
    parser.add_argument("--threshold", type=float, default=0.5, help="Fake probability above which an image is labelled FAKE (default: 0.5)")
    args = parser.parse_args()
    if not 0.0 <= args.threshold <= 1.0:
        parser.error("--threshold must be between 0 and 1")

    device = torch.device(args.device)
    print(f"Using device: {device}")

    # Load Models
    models = load_models(args.checkpoints, device)
    transform = get_transform()

    # Process Source
    if os.path.isdir(args.source):
        # Keep only image files, sorted so output order is reproducible
        # (glob returns entries in arbitrary filesystem order).
        files = sorted(
            f for f in glob.glob(os.path.join(args.source, "*.*"))
            if f.lower().endswith(('.png', '.jpg', '.jpeg', '.webp'))
        )
    else:
        files = [args.source]

    print(f"Processing {len(files)} images with {len(models)} model(s)...")
    print("-" * 65)
    print(f"{'Image Name':<40} | {'Prediction':<10} | {'Confidence':<10}")
    print("-" * 65)
    for file_path in files:
        prob, error = predict_ensemble(models, file_path, device, transform)
        if error:
            print(f"{os.path.basename(file_path):<40} | ERROR: {error}")
            continue
        is_fake = prob > args.threshold
        label = "FAKE" if is_fake else "REAL"
        # Report the probability of the chosen label, not always P(fake).
        confidence = prob if is_fake else 1 - prob
        print(f"{os.path.basename(file_path):<40} | {label:<10} | {confidence:.2%}")
if __name__ == "__main__":
    # Run the CLI only when executed as a script, not when imported as a module.
    main()