| """ |
| AI-Generated Image Detector - Inference Script |
| Detects whether an image is real or AI-generated using frequency analysis + deep learning. |
| |
| Usage: |
| python inference.py --image path/to/image.jpg |
| python inference.py --image https://example.com/image.png |
| python inference.py --image_dir path/to/folder/ |
| """ |
| import os |
| import io |
| import math |
| import argparse |
| import json |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize |
| from pathlib import Path |
|
|
|
|
| |
| from train import FrequencyAwareDetector |
|
|
|
|
def load_model(model_dir=".", device="cuda" if torch.cuda.is_available() else "cpu"):
    """Build a FrequencyAwareDetector and load its trained weights if present.

    Args:
        model_dir: Directory that may contain ``detector_config.json`` (model
            constructor settings) and ``model_state_dict.pt`` (weights).
            Either file may be absent.
        device: Device the weights are mapped onto and the model is moved to.
            The default is evaluated once, when this module is imported.

    Returns:
        ``(model, config)``. ``model`` is in eval mode on ``device``.
        ``config`` is the built-in defaults overlaid with any keys read from
        ``detector_config.json``.

    Notes:
        If no weights file exists, the model keeps its random initialization
        and only a warning is printed.
    """
    defaults = {
        "backbone_name": "microsoft/swinv2-tiny-patch4-window8-256",
        "num_labels": 2,
        "dct_patch_size": 32,
        "num_freq_bands": 8,
        "fft_bins": 32,
    }
    config_path = os.path.join(model_dir, "detector_config.json")
    weights_path = os.path.join(model_dir, "model_state_dict.pt")

    config = dict(defaults)
    if os.path.exists(config_path):
        with open(config_path, encoding="utf-8") as f:
            # Overlay instead of replace, so a config file that omits a key
            # falls back to the default rather than raising KeyError below.
            config.update(json.load(f))

    model = FrequencyAwareDetector(
        backbone_name=config["backbone_name"],
        num_labels=config["num_labels"],
        dct_patch_size=config["dct_patch_size"],
        num_freq_bands=config["num_freq_bands"],
        fft_bins=config["fft_bins"],
    )

    if os.path.exists(weights_path):
        try:
            # weights_only=True restricts unpickling to tensors and plain
            # containers, so a tampered checkpoint cannot execute code on load.
            state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        except TypeError:
            # Older torch releases do not accept the weights_only keyword.
            state_dict = torch.load(weights_path, map_location=device)
        model.load_state_dict(state_dict)
        print(f"✓ Loaded weights from {weights_path}")
    else:
        print("⚠ No weights found, using random initialization")

    model.to(device)
    model.eval()
    return model, config
|
|
|
|
def get_transform(image_size=256):
    """Return the deterministic preprocessing pipeline used at inference time.

    Steps:
    1. Resize to an exact square ``image_size + 32`` on each side. The aspect
       ratio is not preserved.
    2. Center-crop back to ``image_size`` x ``image_size``.
    3. Convert to a tensor.
    4. Normalize with the common ImageNet channel mean/std, presumably matching
       the training-time preprocessing.
    """
    padded = image_size + 32  # oversize first so the center crop trims the border
    pipeline = [
        Resize((padded, padded)),
        CenterCrop((image_size, image_size)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ]
    return Compose(pipeline)
|
|
|
|
def _load_image(source, timeout):
    """Open a local path or http(s) URL and return it as an RGB PIL image."""
    if source.lower().startswith(("http://", "https://")):
        import requests  # only needed for URL inputs, so imported lazily

        response = requests.get(source, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an HTML error page to PIL.
        response.raise_for_status()
        stream = io.BytesIO(response.content)
    else:
        stream = source
    # Image.open keeps the underlying file open. convert() returns a new
    # image, so the source can be closed as soon as the conversion is done.
    with Image.open(stream) as img:
        return img.convert("RGB")


def predict_single(model, image_path, transform, device="cpu", timeout=30):
    """Predict whether a single image is real or AI-generated.

    Args:
        model: Called as ``model(pixel_values=...)``; must return a mapping
            with a ``"logits"`` entry. Class index 0 is reported as "Real" and
            index 1 as "AI-Generated".
        image_path: Local file path (str or path-like), or an ``http://`` or
            ``https://`` URL.
        transform: Callable that maps an RGB PIL image to a tensor for one
            image. A batch dimension is added here.
        device: Device the input tensor is moved to.
        timeout: Seconds to wait for the HTTP response when ``image_path`` is
            a URL.

    Returns:
        dict with ``prediction`` (label string), ``confidence`` (probability
        of the predicted class), ``real_probability`` and
        ``ai_generated_probability``.

    Raises:
        requests.HTTPError: if a URL responds with an error status.
        OSError: if the image cannot be opened or decoded. PIL's
            ``UnidentifiedImageError`` is a subclass of OSError.
    """
    img = _load_image(os.fspath(image_path), timeout)

    pixel_values = transform(img).unsqueeze(0).to(device)

    with torch.no_grad():
        output = model(pixel_values=pixel_values)
        logits = output["logits"]
        probs = torch.softmax(logits, dim=1)
        pred = probs.argmax(dim=1).item()
        confidence = probs[0][pred].item()

    labels = {0: "Real", 1: "AI-Generated"}
    return {
        "prediction": labels[pred],
        "confidence": confidence,
        "real_probability": probs[0][0].item(),
        "ai_generated_probability": probs[0][1].item(),
    }
|
|
|
|
def _resolve_device(requested):
    """Map the --device value to a device string; "auto" picks CUDA when available."""
    if requested == "auto":
        return "cuda" if torch.cuda.is_available() else "cpu"
    return requested


def _find_images(image_dir):
    """Return the regular image files directly inside image_dir, sorted by path."""
    extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tif', '.tiff'}
    return sorted(
        p for p in Path(image_dir).iterdir()
        # is_file() skips sub-directories that happen to end in an image suffix.
        if p.is_file() and p.suffix.lower() in extensions
    )


def main():
    """CLI entry point: classify one image (--image) or a folder (--image_dir).

    If both options are given, --image takes precedence. With neither, the
    help text is printed and the model is not loaded.
    """
    parser = argparse.ArgumentParser(description="Detect AI-generated images")
    parser.add_argument("--image", type=str, help="Path or URL to single image")
    parser.add_argument("--image_dir", type=str, help="Directory of images to analyze")
    parser.add_argument("--model_dir", type=str, default=".", help="Directory containing model weights")
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--device", type=str, default="auto")
    args = parser.parse_args()

    # Validate the arguments before the potentially slow model load.
    if not args.image and not args.image_dir:
        parser.print_help()
        return
    if not args.image and not Path(args.image_dir).is_dir():
        parser.error(f"--image_dir is not a directory: {args.image_dir}")

    device = _resolve_device(args.device)
    print(f"Device: {device}")
    model, _config = load_model(args.model_dir, device)
    transform = get_transform(args.image_size)

    if args.image:
        result = predict_single(model, args.image, transform, device)
        print(f"\n{'='*50}")
        print(f"Image: {args.image}")
        print(f"Prediction: {result['prediction']}")
        print(f"Confidence: {result['confidence']:.2%}")
        print(f" Real probability: {result['real_probability']:.4f}")
        print(f" AI-generated probability: {result['ai_generated_probability']:.4f}")
        print(f"{'='*50}")
        return

    image_files = _find_images(args.image_dir)
    print(f"\nAnalyzing {len(image_files)} images from {args.image_dir}...\n")

    results = []
    failed = 0
    for img_path in image_files:
        try:
            result = predict_single(model, str(img_path), transform, device)
        except Exception as e:  # report and continue; one bad file must not abort the batch
            failed += 1
            print(f" ❌ {img_path.name}: Error - {e}")
            continue
        results.append(result)
        status = "🤖" if result["prediction"] == "AI-Generated" else "📷"
        print(f" {status} {img_path.name}: {result['prediction']} ({result['confidence']:.1%})")

    real_count = sum(1 for r in results if r["prediction"] == "Real")
    ai_count = sum(1 for r in results if r["prediction"] == "AI-Generated")
    print(f"\nSummary: {real_count} Real, {ai_count} AI-Generated out of {len(results)} images")
    if failed:
        # results only holds successes, so report failures explicitly.
        print(f"{failed} image(s) could not be processed")
|
|
|
|
# Run the CLI only when executed as a script, not when imported.
# main() returns None, so SystemExit(None) exits with status 0.
if __name__ == "__main__":
    raise SystemExit(main())
|
|