# Reju983's picture
# Add inference script for detecting AI-generated images
# f53f978 verified
"""
AI-Generated Image Detector - Inference Script

Detects whether an image is real or AI-generated using frequency analysis + deep learning.

Usage:
    python inference.py --image path/to/image.jpg
    python inference.py --image https://example.com/image.png
    python inference.py --image_dir path/to/folder/
"""
import os
import io
import math
import argparse
import json
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torchvision.transforms import Compose, Resize, CenterCrop, ToTensor, Normalize
from pathlib import Path
# Import model architecture from train.py
from train import FrequencyAwareDetector
def load_model(model_dir=".", device="cuda" if torch.cuda.is_available() else "cpu"):
    """Load a FrequencyAwareDetector and return it ready for inference.

    The architecture hyper-parameters are read from
    ``<model_dir>/detector_config.json`` when that file exists; any key the
    file does not provide falls back to the built-in default below. Weights
    are read from ``<model_dir>/model_state_dict.pt`` when present, otherwise
    the model keeps its initial weights and a warning is printed.

    Args:
        model_dir: Directory holding the config and weights files.
        device: Device string the weights are mapped to and the model is
            moved to.

    Returns:
        A ``(model, config)`` tuple: the model in eval mode on ``device`` and
        the effective config dict used to build it.

    Raises:
        Whatever ``json.load``, ``torch.load`` or ``load_state_dict`` raise
        for a malformed config, unreadable weights file or mismatched
        state dict.
    """
    default_config = {
        "backbone_name": "microsoft/swinv2-tiny-patch4-window8-256",
        "num_labels": 2,
        "dct_patch_size": 32,
        "num_freq_bands": 8,
        "fft_bins": 32,
    }
    config_path = os.path.join(model_dir, "detector_config.json")
    weights_path = os.path.join(model_dir, "model_state_dict.pt")

    # Start from the defaults so a partial config file cannot cause a
    # KeyError when the constructor arguments are looked up below.
    config = dict(default_config)
    if os.path.exists(config_path):
        with open(config_path, encoding="utf-8") as f:
            config.update(json.load(f))

    model = FrequencyAwareDetector(
        backbone_name=config["backbone_name"],
        num_labels=config["num_labels"],
        dct_patch_size=config["dct_patch_size"],
        num_freq_bands=config["num_freq_bands"],
        fft_bins=config["fft_bins"],
    )

    if os.path.exists(weights_path):
        # weights_only=True restricts unpickling to tensors and plain
        # containers, so a tampered weights file cannot execute code.
        state_dict = torch.load(weights_path, map_location=device, weights_only=True)
        model.load_state_dict(state_dict)
        print(f"✓ Loaded weights from {weights_path}")
    else:
        print("⚠ No weights found, using random initialization")

    model.to(device)
    model.eval()
    return model, config
def get_transform(image_size=256):
    """Build the evaluation-time preprocessing pipeline.

    The image is resized to a square of ``image_size + 32`` pixels,
    center-cropped to ``image_size`` x ``image_size``, converted to a tensor
    and normalized per channel.

    Args:
        image_size: Edge length of the final square crop.

    Returns:
        A torchvision ``Compose`` applying the steps above in order.
    """
    resize_edge = image_size + 32
    # These values match the standard ImageNet channel statistics.
    channel_mean = [0.485, 0.456, 0.406]
    channel_std = [0.229, 0.224, 0.225]

    steps = [
        Resize((resize_edge, resize_edge)),
        CenterCrop((image_size, image_size)),
        ToTensor(),
        Normalize(mean=channel_mean, std=channel_std),
    ]
    return Compose(steps)
def predict_single(model, image_path, transform, device="cpu", timeout=30.0):
    """Classify one image as real or AI-generated.

    Args:
        model: Callable accepting ``pixel_values=`` and returning a mapping
            with a ``"logits"`` entry of shape (1, 2) for a single image.
        image_path: Local file path (``str`` or ``os.PathLike``) or an
            ``http://`` / ``https://`` URL.
        transform: Callable turning a PIL RGB image into a tensor.
        device: Device the input tensor is moved to before the forward pass.
        timeout: Seconds to wait for the HTTP download; only used for URLs.

    Returns:
        A dict with keys ``"prediction"`` ("Real" or "AI-Generated"),
        ``"confidence"`` (probability of the predicted class),
        ``"real_probability"`` and ``"ai_generated_probability"``.

    Raises:
        requests.HTTPError: The URL answered with an error status.
        OSError: The file or downloaded payload could not be opened as an
            image.
    """
    source = os.fspath(image_path)

    # Match the full scheme so a local file such as "http_cat.jpg" is not
    # mistaken for a URL.
    if source.lower().startswith(("http://", "https://")):
        import requests

        # Without a timeout a stalled server would hang the run forever.
        response = requests.get(source, timeout=timeout)
        # Fail on 4xx/5xx instead of trying to decode an error page as an image.
        response.raise_for_status()
        with Image.open(io.BytesIO(response.content)) as raw:
            img = raw.convert("RGB")
    else:
        # convert() returns an independent, fully loaded image, so the
        # underlying file handle can be closed right away.
        with Image.open(source) as raw:
            img = raw.convert("RGB")

    pixel_values = transform(img).unsqueeze(0).to(device)
    with torch.no_grad():
        output = model(pixel_values=pixel_values)

    logits = output["logits"]
    probs = torch.softmax(logits, dim=1)
    pred = probs.argmax(dim=1).item()
    confidence = probs[0][pred].item()

    labels = {0: "Real", 1: "AI-Generated"}
    return {
        "prediction": labels[pred],
        "confidence": confidence,
        "real_probability": probs[0][0].item(),
        "ai_generated_probability": probs[0][1].item(),
    }
def main():
    """Command-line entry point: classify one image or a folder of images.

    ``--image`` takes precedence over ``--image_dir`` when both are given.
    With neither, the help text is printed and nothing is loaded. In folder
    mode each file is classified independently; a failure on one file is
    reported and does not stop the remaining files.
    """
    parser = argparse.ArgumentParser(description="Detect AI-generated images")
    parser.add_argument("--image", type=str, help="Path or URL to single image")
    parser.add_argument("--image_dir", type=str, help="Directory of images to analyze")
    parser.add_argument("--model_dir", type=str, default=".", help="Directory containing model weights")
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--device", type=str, default="auto")
    args = parser.parse_args()

    # Validate the request before the (potentially slow) model load so that a
    # bare invocation or a bad directory fails fast.
    if not args.image and not args.image_dir:
        parser.print_help()
        return
    if not args.image and not Path(args.image_dir).is_dir():
        parser.error(f"--image_dir is not a directory: {args.image_dir}")

    if args.device == "auto":
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    print(f"Device: {device}")

    model, config = load_model(args.model_dir, device)
    transform = get_transform(args.image_size)

    if args.image:
        result = predict_single(model, args.image, transform, device)
        print(f"\n{'='*50}")
        print(f"Image: {args.image}")
        print(f"Prediction: {result['prediction']}")
        print(f"Confidence: {result['confidence']:.2%}")
        print(f" Real probability: {result['real_probability']:.4f}")
        print(f" AI-generated probability: {result['ai_generated_probability']:.4f}")
        print(f"{'='*50}")
        return

    extensions = {'.jpg', '.jpeg', '.png', '.webp', '.bmp', '.tif', '.tiff'}
    # is_file() skips sub-directories that happen to carry an image suffix.
    image_files = sorted(
        f for f in Path(args.image_dir).iterdir()
        if f.is_file() and f.suffix.lower() in extensions
    )
    print(f"\nAnalyzing {len(image_files)} images from {args.image_dir}...\n")

    results = []
    for img_path in image_files:
        # Broad catch is deliberate: one unreadable file must not abort the
        # whole batch; the error is reported per file.
        try:
            result = predict_single(model, str(img_path), transform, device)
        except Exception as e:
            print(f" ❌ {img_path.name}: Error - {e}")
            continue
        results.append(result)
        status = "🤖" if result["prediction"] == "AI-Generated" else "📷"
        print(f" {status} {img_path.name}: {result['prediction']} ({result['confidence']:.1%})")

    real_count = sum(1 for r in results if r["prediction"] == "Real")
    ai_count = sum(1 for r in results if r["prediction"] == "AI-Generated")
    print(f"\nSummary: {real_count} Real, {ai_count} AI-Generated out of {len(results)} images")


if __name__ == "__main__":
    main()