""" AI-Generated Image Detector - Inference Script Detects whether an image is real or AI-generated using frequency analysis + deep learning. Usage: python inference.py --image path/to/image.jpg python inference.py --image https://example.com/image.png python inference.py --image_dir path/to/folder/ """ import os import io import math import argparse import json import numpy as np import torch import torch.nn as nn import torch.nn.functional as F from PIL import Image from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize from pathlib import Path # Import model architecture from train.py from train import FrequencyAwareDetector def load_model(model_dir=".", device="cuda" if torch.cuda.is_available() else "cpu"): """Load trained FrequencyAwareDetector model.""" config_path = os.path.join(model_dir, "detector_config.json") weights_path = os.path.join(model_dir, "model_state_dict.pt") if os.path.exists(config_path): with open(config_path) as f: config = json.load(f) else: config = { "backbone_name": "microsoft/swinv2-tiny-patch4-window8-256", "num_labels": 2, "dct_patch_size": 32, "num_freq_bands": 8, "fft_bins": 32, } model = FrequencyAwareDetector( backbone_name=config["backbone_name"], num_labels=config["num_labels"], dct_patch_size=config["dct_patch_size"], num_freq_bands=config["num_freq_bands"], fft_bins=config["fft_bins"], ) if os.path.exists(weights_path): state_dict = torch.load(weights_path, map_location=device) model.load_state_dict(state_dict) print(f"✓ Loaded weights from {weights_path}") else: print("⚠ No weights found, using random initialization") model.to(device) model.eval() return model, config def get_transform(image_size=256): """Standard evaluation transform.""" return Compose([ Resize((image_size + 32, image_size + 32)), CenterCrop((image_size, image_size)), ToTensor(), Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), ]) def predict_single(model, image_path, transform, device="cpu"): """Predict whether a single image is real or AI-generated.""" if image_path.startswith("http"): import requests response = requests.get(image_path) img = Image.open(io.BytesIO(response.content)).convert("RGB") else: img = Image.open(image_path).convert("RGB") pixel_values = transform(img).unsqueeze(0).to(device) with torch.no_grad(): output = model(pixel_values=pixel_values) logits = output["logits"] probs = torch.softmax(logits, dim=1) pred = probs.argmax(dim=1).item() confidence = probs[0][pred].item() labels = {0: "Real", 1: "AI-Generated"} return { "prediction": labels[pred], "confidence": confidence, "real_probability": probs[0][0].item(), "ai_generated_probability": probs[0][1].item(), } def main(): parser = argparse.ArgumentParser(description="Detect AI-generated images") parser.add_argument("--image", type=str, help="Path or URL to single image") parser.add_argument("--image_dir", type=str, help="Directory of images to analyze") parser.add_argument("--model_dir", type=str, default=".", help="Directory containing model weights") parser.add_argument("--image_size", type=int, default=256) parser.add_argument("--device", type=str, default="auto") args = parser.parse_args() if args.device == "auto": device = "cuda" if torch.cuda.is_available() else "cpu" else: device = args.device print(f"Device: {device}") model, config = load_model(args.model_dir, device) transform = get_transform(args.image_size) if args.image: result = predict_single(model, args.image, transform, device) print(f"\n{'='*50}") print(f"Image: {args.image}") print(f"Prediction: {result['prediction']}") print(f"Confidence: {result['confidence']:.2%}") print(f" Real probability: {result['real_probability']:.4f}") print(f" AI-generated probability: {result['ai_generated_probability']:.4f}") print(f"{'='*50}") elif args.image_dir: extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tiff'} image_files = [ f for f in Path(args.image_dir).iterdir() if f.suffix.lower() in extensions ] print(f"\nAnalyzing {len(image_files)} images from {args.image_dir}...\n") results = [] for img_path in sorted(image_files): try: result = predict_single(model, str(img_path), transform, device) results.append(result) status = "🤖" if result["prediction"] == "AI-Generated" else "📷" print(f" {status} {img_path.name}: {result['prediction']} ({result['confidence']:.1%})") except Exception as e: print(f" ❌ {img_path.name}: Error - {e}") real_count = sum(1 for r in results if r["prediction"] == "Real") ai_count = sum(1 for r in results if r["prediction"] == "AI-Generated") print(f"\nSummary: {real_count} Real, {ai_count} AI-Generated out of {len(results)} images") else: parser.print_help() if __name__ == "__main__": main()