---
tags:
- image-classification
- ai-image-detection
- deepfake-detection
- frequency-analysis
- computer-vision
- pytorch
- swinv2
- srm
- dct
- fft
license: apache-2.0
datasets:
- OwensLab/CommunityForensics-Small
metrics:
- accuracy
pipeline_tag: image-classification
---

# 🔍 AI-Generated Image Detector

**Multi-Branch Frequency-Aware Detector: SwinV2 + SRM + DCT + FFT**

An AI-generated image detector that combines **semantic understanding** with **frequency-domain forensic analysis**. It is designed to detect AI-generated images from a wide range of sources, including high-quality outputs from Stable Diffusion, DALL-E, Midjourney, Flux, and 4,800+ other generators.

## 🏗️ Architecture

This model uses a **4-branch fusion architecture** designed for detection robustness:

```
                  Input Image (256×256)
                            │
    ┌───────────────┬───────┴───────┬───────────────┐
    ▼               ▼               ▼               ▼
 SwinV2          SRM HPF      DCT Analyzer    FFT Analyzer
 (768d)          (256d)           (22d)           (36d)
    │               │               │               │
    │               └───────────────┼───────────────┘
    │                               │
    │                     Freq Features (314d)
    │                               │
    │                    Freq Projection (128d)
    │                               │
    └───────────────┬───────────────┘
                    │
       Concat (768d + 128d = 896d)
                    │
    Fusion MLP (896d → 512 → 128 → 2)
                    │
           Real / AI-Generated
```

### Branch 1: SwinV2-Tiny Backbone (Semantic Features)

- Pretrained `microsoft/swinv2-tiny-patch4-window8-256`
- Captures high-level semantic inconsistencies (e.g., unnatural textures, impossible geometry)
- 768-dimensional feature vector

### Branch 2: SRM High-Pass Filter Bank (Forensic Residuals)

- **30 fixed Spatial Rich Model (SRM) filters** from the image forensics literature
- Based on [Fridrich & Kodovsky (2012)](https://ieeexplore.ieee.org/document/6197267), a standard feature set in steganalysis
- Includes 1st/2nd/3rd-order derivatives, Laplacians, SPAM filters, edge detectors, and Gabor-like directional filters
- Exposes subtle noise-residual artifacts that are **invisible in RGB space**
- **Zero learnable parameters** in the filter bank, which is intended to improve generalization to unseen generators
- Residuals are processed by a lightweight CNN encoder (30→64→128→256 channels)

### Branch 3: DCT Frequency Band Analysis

- **2D Discrete Cosine Transform** on 32×32 image patches
- Extracts energy statistics (mean + std) for 8 frequency bands
- Computes the **spectral centroid** (center of mass of the frequency distribution)
- Measures the **high-to-low frequency energy ratio**; AI images often have anomalous ratios
- Captures **DC component statistics** across patches
- 22-dimensional feature vector

### Branch 4: FFT Radial Power Spectrum

- **2D Fast Fourier Transform** with a Hann (Hanning) window to reduce spectral leakage
- Azimuthally averaged power spectrum in 32 radial bins
- Measures **deviation from the 1/f² power law**: natural images approximately follow it, while AI-generated images tend to deviate
- Extracts the log spectrum, spectral slope, intercept, residual std, and residual max
- Detects **upsampling artifacts** and periodic patterns left by generator architectures
- 36-dimensional feature vector (32 radial bins + 4 fit statistics)

### Fusion

- Frequency features (SRM + DCT + FFT = 256 + 22 + 36 = 314d) are projected to 128d
- The projection is concatenated with the SwinV2 semantic features (768d) → 896d
- MLP classifier with dropout (0.3, 0.1), trained with label smoothing (0.1)

**Total parameters: ~28.6M** (compact enough for fast inference)

## 📊 Training Dataset

**[OwensLab/CommunityForensics-Small](https://huggingface.co/datasets/OwensLab/CommunityForensics-Small)** (CVPR 2025)

- **556,000 images** (278K real + 278K AI-generated)
- **4,803 different AI generators**, one of the most generator-diverse training sets available for this task
- Real images from LAION, ImageNet, COCO, FFHQ, CelebA, MetFaces, AFHQ, and more
- AI images from Stable Diffusion variants, DeepFloyd, StyleGAN 1/2/3, BigGAN, VQDM, and thousands of community models

### Social Media Robustness Augmentation

During training, images are augmented with:

- **Random JPEG compression** (QF 30-95): simulates Instagram/Twitter/WhatsApp compression
- **Gaussian blur** (σ 0.1-2.0): simulates softening from re-encoding
- **Downscale-upscale** (0.5x-0.9x): simulates quality loss from re-uploads
- Standard color jitter, random crops, and horizontal flips

## 🚀 Training
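The social-media robustness augmentation described in the training-dataset section can be approximated with Pillow. This is a minimal sketch, not the code in `train.py`: the function name `social_media_augment`, the per-transform probability `p`, the order in which the degradations are applied, and the bilinear resampling are all assumptions. Only the parameter ranges (JPEG QF 30-95, blur σ 0.1-2.0, scale 0.5x-0.9x) come from this README.

```python
import io
import random
from typing import Optional

from PIL import Image, ImageFilter


def social_media_augment(
    img: Image.Image,
    p: float = 0.5,  # hypothetical per-transform probability (not stated in this README)
    rng: Optional[random.Random] = None,
) -> Image.Image:
    """Sketch: apply JPEG, blur, and down/up-scale degradations, each with probability p.

    The order of the three steps is an assumption; the output keeps the input size.
    """
    rng = rng or random.Random()
    img = img.convert("RGB")
    w, h = img.size

    # Random JPEG re-compression (quality factor 30-95), done in memory.
    if rng.random() < p:
        buf = io.BytesIO()
        img.save(buf, format="JPEG", quality=rng.randint(30, 95))
        buf.seek(0)
        img = Image.open(buf).convert("RGB")

    # Gaussian blur (sigma 0.1-2.0); Pillow documents `radius` as the std. dev.
    if rng.random() < p:
        img = img.filter(ImageFilter.GaussianBlur(radius=rng.uniform(0.1, 2.0)))

    # Downscale by 0.5x-0.9x, then upscale back to the original size.
    if rng.random() < p:
        s = rng.uniform(0.5, 0.9)
        small = img.resize((max(1, int(w * s)), max(1, int(h * s))), Image.BILINEAR)
        img = small.resize((w, h), Image.BILINEAR)

    return img
```

For example, `social_media_augment(Image.open("photo.jpg"), p=1.0)` applies all three degradations and returns an RGB image of the original size, so it could sit ahead of the usual crop/flip/color-jitter transforms without changing their expected input.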
### Requirements

```bash
pip install transformers torch torchvision datasets evaluate accelerate trackio pillow scikit-learn
```

### Run Training

```bash
# Full training on GPU (recommended: A10G 24GB or better)
python train.py \
  --num_train_epochs 5 \
  --per_device_train_batch_size 16 \
  --gradient_accumulation_steps 4 \
  --learning_rate 2e-5 \
  --hub_model_id your-username/ai-image-detector

# Quick test run
python train.py --test_mode

# Custom settings
python train.py \
  --max_train_samples 50000 \
  --num_train_epochs 3 \
  --per_device_train_batch_size 8 \
  --image_size 256
```

### Training Hyperparameters

| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning rate | 2e-5 |
| Weight decay | 0.01 |
| Warmup ratio | 0.1 |
| Batch size | 16 per device × 4 gradient accumulation steps = 64 effective |
| Epochs | 5 |
| Precision | bf16 |
| Label smoothing | 0.1 |
| Gradient checkpointing | ✓ |
| Image size | 256×256 |

## 🔮 Inference

### Single Image

```python
import torch
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from train import FrequencyAwareDetector

# Load the trained weights on CPU (weights_only avoids unpickling arbitrary objects)
model = FrequencyAwareDetector()
state_dict = torch.load("model_state_dict.pt", map_location="cpu", weights_only=True)
model.load_state_dict(state_dict)
model.eval()

# Preprocess: resize to 288, center-crop to the 256×256 training size, ImageNet normalization
transform = Compose([
    Resize((288, 288)),
    CenterCrop((256, 256)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

img = Image.open("test.jpg").convert("RGB")
pixel_values = transform(img).unsqueeze(0)  # add batch dimension: (1, 3, 256, 256)

# Predict
with torch.no_grad():
    output = model(pixel_values=pixel_values)
    probs = torch.softmax(output["logits"], dim=1)
    pred = probs.argmax(dim=1).item()

labels = {0: "Real", 1: "AI-Generated"}
print(f"Prediction: {labels[pred]} ({probs[0, pred].item():.2%} confidence)")
```

### Command Line

```bash
# Single image
python inference.py --image photo.jpg

# URL
python inference.py --image https://example.com/image.png

# Batch (entire directory)
python inference.py --image_dir ./photos/
```

## 📚 Scientific Background

### Why Frequency Analysis?

AI-generated images contain subtle artifacts that are invisible to the human eye but detectable in the frequency domain:

1. **Upsampling artifacts**: Diffusion models and GANs use upsampling layers (e.g., transposed convolutions) that leave periodic patterns in the frequency spectrum.
2. **1/f² deviation**: The Fourier power spectrum of natural images falls off approximately as 1/f² with spatial frequency f. AI images deviate from this, especially at mid-to-high frequencies.
3. **DCT coefficient statistics**: The generation process produces non-natural distributions of DCT coefficients across image patches.
4. **Noise residuals**: SRM filters reveal that AI images have noise patterns that differ from those of camera-captured images.

### Key References

1. **AIDE** (2024): "A Sanity Check for AI-generated Image Detection", [arXiv:2406.19435](https://arxiv.org/abs/2406.19435). Fuses DCT-based patch selection, SRM filters, and CLIP features; reports 92.77% on AIGCDetectBenchmark.
2. **CommunityForensics** (CVPR 2025): "Using Thousands of Generators to Train Fake Image Detectors", [arXiv:2411.04125](https://arxiv.org/abs/2411.04125). Shows that training on diverse generators (4,803) substantially improves cross-generator generalization.
3. **SRM filters**: Fridrich & Kodovsky (2012), "Rich Models for Steganalysis of Digital Images". A widely used high-pass filter bank in image forensics.
4. **UnivFD**: Ojha et al. (2023), "Towards Universal Fake Image Detectors that Generalize Across Generative Models". Classifies images in a frozen CLIP feature space (nearest neighbor or linear probing) and generalizes to unseen generators.

## 📁 Repository Structure

```
├── train.py              # Full training script with model architecture
├── inference.py          # Easy-to-use inference script
├── detector_config.json  # Model configuration
├── model_state_dict.pt   # Trained weights (after training)
└── README.md             # This file
```

## License

Apache 2.0