---
tags:
- image-classification
- ai-image-detection
- deepfake-detection
- frequency-analysis
- computer-vision
- pytorch
- swinv2
- srm
- dct
- fft
license: apache-2.0
datasets:
- OwensLab/CommunityForensics-Small
metrics:
- accuracy
pipeline_tag: image-classification
---
# 🔍 AI-Generated Image Detector
**Multi-Branch Frequency-Aware Detector: SwinV2 + SRM + DCT + FFT**
A robust AI-generated image detector that combines **semantic understanding** with **frequency-domain forensic analysis** to detect AI-generated images from a wide range of sources, including high-quality outputs from Stable Diffusion, DALL-E, Midjourney, Flux, and 4,800+ other generators.
## 🏗️ Architecture
This model uses a **4-branch fusion architecture** that pairs one semantic branch with three frequency-domain branches:
```
               Input Image (256×256)
                         │
   ┌──────────────┬──────┴───────┬──────────────┐
   │              │              │              │
   ▼              ▼              ▼              ▼
 SwinV2        SRM HPF     DCT Analyzer   FFT Analyzer
 (768d)         (256d)         (22d)          (36d)
   │              │              │              │
   │              └──────────────┼──────────────┘
   │                   Freq Features (314d)
   │                             │
   │                  Freq Projection (128d)
   │                             │
   └──────────────┬──────────────┘
                  │
  Fusion MLP (896d → 512 → 128 → 2)
                  │
         Real / AI-Generated
```
### Branch 1: SwinV2-Tiny Backbone (Semantic Features)
- Pretrained `microsoft/swinv2-tiny-patch4-window8-256`
- Captures high-level semantic inconsistencies (e.g., unnatural textures, impossible geometry)
- 768-dimensional feature vector
### Branch 2: SRM High-Pass Filter Bank (Forensic Residuals)
- **30 fixed Spatial Rich Model (SRM) filters** from image forensics literature
- Based on [Fridrich & Kodovsky (2012)](https://ieeexplore.ieee.org/document/6197267), a standard reference in steganalysis
- Includes: 1st/2nd/3rd order derivatives, Laplacians, SPAM filters, edge detectors, Gabor-like directional filters
- Detects subtle manipulation artifacts **invisible in RGB space**
- **Zero learnable parameters** in the filter bank → maximum generalization
- Processed through a lightweight CNN encoder (30→64→128→256 channels)
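The idea of a fixed, non-trainable high-pass filter can be illustrated with a minimal sketch. It applies a single, widely used 5×5 SRM kernel (often called the "KV" kernel) instead of the full 30-filter bank. The function name `srm_residual`, the grayscale averaging, and the replicate padding are assumptions made for illustration and are not taken from `train.py`.

```python
import torch
import torch.nn.functional as F

# One classic SRM high-pass kernel ("KV"); its coefficients sum to zero,
# so smooth image content is suppressed and only residual noise remains.
KV_KERNEL = torch.tensor([
    [-1.0,  2.0,  -2.0,  2.0, -1.0],
    [ 2.0, -6.0,   8.0, -6.0,  2.0],
    [-2.0,  8.0, -12.0,  8.0, -2.0],
    [ 2.0, -6.0,   8.0, -6.0,  2.0],
    [-1.0,  2.0,  -2.0,  2.0, -1.0],
]) / 12.0


def srm_residual(images: torch.Tensor) -> torch.Tensor:
    """Filter a batch of RGB images (B, 3, H, W) with the fixed KV kernel.

    Hypothetical helper: averages channels to grayscale first (an assumption).
    """
    gray = images.mean(dim=1, keepdim=True)                 # (B, 1, H, W)
    padded = F.pad(gray, (2, 2, 2, 2), mode="replicate")    # keep H x W output
    weight = KV_KERNEL.view(1, 1, 5, 5)                     # plain tensor, never trained
    return F.conv2d(padded, weight)


torch.manual_seed(0)
flat = torch.full((1, 3, 16, 16), 0.5)          # perfectly smooth image
noisy = flat + 0.05 * torch.randn_like(flat)    # same image with added noise
flat_res = srm_residual(flat)
noisy_res = srm_residual(noisy)
print(flat_res.abs().max().item(), noisy_res.abs().max().item())
```

A smooth image yields a near-zero residual while added noise passes through, which is why such filters expose noise-level differences rather than scene content.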
### Branch 3: DCT Frequency Band Analysis
- **2D Discrete Cosine Transform** on 32×32 image patches
- Extracts 8 frequency band energy statistics (mean + std per band)
- Computes **spectral centroid** (center of mass of frequency distribution)
- Measures **high-to-low frequency energy ratio**: AI images often have anomalous ratios
- Captures **DC component statistics** across patches
- 22-dimensional feature vector
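As a rough illustration of the band-energy part of this branch, the sketch below computes an orthonormal 2D DCT-II on non-overlapping 32×32 patches and summarizes 8 frequency bands by mean and standard deviation (16 values). The band assignment by `u + v`, the `log1p` scaling, and the names `dct_matrix` and `dct_band_features` are assumptions. The remaining features listed above (centroid, ratio, DC statistics) are not covered.

```python
import numpy as np


def dct_matrix(n: int) -> np.ndarray:
    """Orthonormal DCT-II basis as an (n, n) matrix."""
    k = np.arange(n)[:, None]
    i = np.arange(n)[None, :]
    mat = np.sqrt(2.0 / n) * np.cos(np.pi * (2 * i + 1) * k / (2 * n))
    mat[0, :] = np.sqrt(1.0 / n)
    return mat


def dct_band_features(gray: np.ndarray, patch: int = 32, bands: int = 8) -> np.ndarray:
    """Mean and std of log band energy across patches: 2 * bands values."""
    h, w = gray.shape
    c = dct_matrix(patch)
    # Assign each coefficient (u, v) to a band by its diagonal index u + v
    # (an illustrative choice; the real band layout may differ).
    u, v = np.meshgrid(np.arange(patch), np.arange(patch), indexing="ij")
    band_idx = np.minimum((u + v) * bands // (2 * patch - 1), bands - 1)
    energies = []
    for y in range(0, h - patch + 1, patch):
        for x in range(0, w - patch + 1, patch):
            coeffs = c @ gray[y:y + patch, x:x + patch] @ c.T   # 2D DCT of the patch
            power = coeffs ** 2
            energies.append([power[band_idx == b].mean() for b in range(bands)])
    log_e = np.log1p(np.array(energies))        # shape (num_patches, bands)
    return np.concatenate([log_e.mean(axis=0), log_e.std(axis=0)])


rng = np.random.default_rng(0)
random_feats = dct_band_features(rng.random((256, 256)))
const_feats = dct_band_features(np.full((64, 64), 0.5))
print(random_feats.shape)
```

For a constant image all energy sits in the lowest band, which is a quick sanity check that the band assignment behaves as intended.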
### Branch 4: FFT Radial Power Spectrum
- **2D Fast Fourier Transform** with Hanning window (reduces spectral leakage)
- Azimuthally averaged power spectrum in 32 radial bins
- Measures **deviation from natural 1/f² power law**: natural images approximately follow this law, while AI-generated images tend to deviate
- Extracts: log spectrum, spectral slope, intercept, residual std, residual max
- Detects **upsampling artifacts** and periodic patterns from generator architectures
- 36-dimensional feature vector
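The steps listed for this branch can be sketched with NumPy as follows. The layout used here (32 log-power bins plus slope, intercept, residual std, and residual max) adds up to 36 values, but whether `train.py` arranges the features exactly this way is an assumption, as are the function name `radial_spectrum_features` and the bin edges.

```python
import numpy as np


def radial_spectrum_features(gray: np.ndarray, bins: int = 32) -> dict:
    """Azimuthally averaged power spectrum plus a log-log line fit."""
    h, w = gray.shape
    window = np.outer(np.hanning(h), np.hanning(w))      # reduces spectral leakage
    power = np.abs(np.fft.fftshift(np.fft.fft2(gray * window))) ** 2

    # Distance of every frequency sample from the spectrum centre.
    yy, xx = np.indices((h, w))
    radius = np.hypot(yy - h // 2, xx - w // 2)
    edges = np.linspace(1.0, min(h, w) / 2, bins + 1)    # starts at 1 to skip DC
    which = np.digitize(radius.ravel(), edges) - 1
    flat = power.ravel()
    profile = np.array([flat[which == b].mean() for b in range(bins)])

    # Fit log(power) = slope * log(f) + intercept; a 1/f^2 law gives slope near -2.
    centers = 0.5 * (edges[:-1] + edges[1:])
    log_f, log_p = np.log(centers), np.log(profile + 1e-12)
    slope, intercept = np.polyfit(log_f, log_p, 1)
    residual = log_p - (slope * log_f + intercept)
    return {
        "log_spectrum": log_p,
        "slope": float(slope),
        "intercept": float(intercept),
        "residual_std": float(residual.std()),
        "residual_max": float(np.abs(residual).max()),
    }


# Compare white noise (flat spectrum) with synthetic 1/f^2 noise.
rng = np.random.default_rng(0)
white_img = rng.standard_normal((256, 256))
freq = np.hypot(np.fft.fftfreq(256)[:, None], np.fft.fftfreq(256)[None, :])
freq[0, 0] = 1.0                                         # avoid division by zero at DC
pink_img = np.real(np.fft.ifft2(np.fft.fft2(white_img) / freq))
white = radial_spectrum_features(white_img)
pink = radial_spectrum_features(pink_img)
print(white["slope"], pink["slope"])
```

The fitted slope is much steeper for the 1/f² noise than for white noise, and the residual statistics measure how far a spectrum strays from any single power law.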
### Fusion
- Frequency features (SRM + DCT + FFT = 314d) → projected to 128d
- Concatenated with SwinV2 semantic features (768d) → 896d
- MLP classifier with dropout (0.3, 0.1) and label smoothing (0.1)
**Total parameters: ~28.6M** (compact enough for real-time inference)
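The fusion step described above can be sketched as a small PyTorch module. The dimensions, dropout rates, and label smoothing come from this README. The class name `FusionHead`, the GELU activations, and the exact placement of the dropout layers are assumptions for illustration, not the code in `train.py`.

```python
import torch
import torch.nn as nn


class FusionHead(nn.Module):
    """Hypothetical fusion head: 314d frequency features -> 128d, then 896d -> 2."""

    def __init__(self, sem_dim=768, freq_dim=314, proj_dim=128, num_classes=2):
        super().__init__()
        self.freq_proj = nn.Sequential(nn.Linear(freq_dim, proj_dim), nn.GELU())
        self.classifier = nn.Sequential(
            nn.Linear(sem_dim + proj_dim, 512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, num_classes),
        )

    def forward(self, semantic, srm, dct, fft):
        freq = torch.cat([srm, dct, fft], dim=1)                      # 256 + 22 + 36 = 314
        fused = torch.cat([semantic, self.freq_proj(freq)], dim=1)    # 768 + 128 = 896
        return self.classifier(fused)


torch.manual_seed(0)
head = FusionHead()
logits = head(
    torch.randn(4, 768),   # SwinV2 features
    torch.randn(4, 256),   # SRM encoder features
    torch.randn(4, 22),    # DCT features
    torch.randn(4, 36),    # FFT features
)
criterion = nn.CrossEntropyLoss(label_smoothing=0.1)
loss = criterion(logits, torch.tensor([0, 1, 0, 1]))
print(logits.shape, loss.item())
```

Projecting the frequency features down to 128 dimensions before concatenation keeps the 768-dimensional semantic vector from being outnumbered by raw forensic features, while still giving the classifier access to both.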
## 📊 Training Dataset
**[OwensLab/CommunityForensics-Small](https://huggingface.co/datasets/OwensLab/CommunityForensics-Small)** (CVPR 2025)
- **556,000 images** (278K real + 278K AI-generated)
- **4,803 different AI generators**, an unusually diverse set of sources for training a detector
- Real images from: LAION, ImageNet, COCO, FFHQ, CelebA, MetFaces, AFHQ, and more
- AI images from: All Stable Diffusion variants, DeepFloyd, StyleGAN 1/2/3, BigGAN, VQDM, and thousands of community models
### Social Media Robustness Augmentation
During training, images are augmented with:
- **Random JPEG compression** (QF 30-95): simulates Instagram/Twitter/WhatsApp compression
- **Gaussian blur** (σ 0.1-2.0): simulates re-encoding artifacts
- **Downscale-upscale** (0.5x-0.9x): simulates re-upload quality loss
- Standard color jitter, random crops, and horizontal flips
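A minimal sketch of the three degradation steps, using Pillow only. The parameter ranges come from the list above. The function name `social_media_degrade`, the fixed order of operations, the bicubic resampling, and applying every step to every image are assumptions; the training script may apply them with different probabilities or in a different order.

```python
import io
import random

from PIL import Image, ImageFilter


def social_media_degrade(img: Image.Image, rng: random.Random) -> Image.Image:
    """Hypothetical augmentation: downscale-upscale, blur, then JPEG re-encode."""
    w, h = img.size

    # Downscale to 0.5x-0.9x and resize back to the original resolution.
    scale = rng.uniform(0.5, 0.9)
    small = img.resize((max(1, int(w * scale)), max(1, int(h * scale))), Image.BICUBIC)
    img = small.resize((w, h), Image.BICUBIC)

    # Gaussian blur; Pillow's radius is treated as sigma here.
    img = img.filter(ImageFilter.GaussianBlur(radius=rng.uniform(0.1, 2.0)))

    # Re-encode as JPEG in memory with a random quality factor.
    buf = io.BytesIO()
    img.save(buf, format="JPEG", quality=rng.randint(30, 95))
    buf.seek(0)
    return Image.open(buf).convert("RGB")


source = Image.new("RGB", (256, 256), (120, 60, 200))
degraded = social_media_degrade(source, random.Random(0))
print(degraded.size, degraded.mode)
```

Passing an explicit `random.Random` instance keeps the augmentation reproducible, which helps when debugging a data pipeline.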
## 🚀 Training
### Requirements
```bash
pip install transformers torch torchvision datasets evaluate accelerate trackio pillow scikit-learn
```
### Run Training
```bash
# Full training on GPU (recommended: A10G 24GB or better)
python train.py \
  --num_train_epochs 5 \
  --per_device_train_batch_size 16 \
  --gradient_accumulation_steps 4 \
  --learning_rate 2e-5 \
  --hub_model_id your-username/ai-image-detector

# Quick test run
python train.py --test_mode

# Custom settings
python train.py \
  --max_train_samples 50000 \
  --num_train_epochs 3 \
  --per_device_train_batch_size 8 \
  --image_size 256
```
### Training Hyperparameters
| Parameter | Value |
|-----------|-------|
| Optimizer | AdamW |
| Learning rate | 2e-5 |
| Weight decay | 0.01 |
| Warmup ratio | 0.1 |
| Batch size | 16 per device × 4 gradient accumulation steps = 64 effective |
| Epochs | 5 |
| Precision | bf16 |
| Label smoothing | 0.1 |
| Gradient checkpointing | ✓ |
| Image size | 256×256 |
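For readers who want to reproduce the optimizer settings outside the training script, the sketch below builds AdamW with the learning rate, weight decay, and warmup ratio from the table. The linear decay after warmup and the helper name `build_optimizer_and_scheduler` are assumptions; the table does not state which schedule follows the warmup.

```python
import torch


def build_optimizer_and_scheduler(model, total_steps, lr=2e-5,
                                  weight_decay=0.01, warmup_ratio=0.1):
    """Hypothetical helper: AdamW with linear warmup, then linear decay (assumed)."""
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    warmup_steps = int(total_steps * warmup_ratio)

    def lr_lambda(step: int) -> float:
        if step < warmup_steps:
            return step / max(1, warmup_steps)                  # ramp from 0 to 1
        remaining = total_steps - step
        return max(0.0, remaining / max(1, total_steps - warmup_steps))

    scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)
    return optimizer, scheduler


toy_model = torch.nn.Linear(4, 2)
optimizer, scheduler = build_optimizer_and_scheduler(toy_model, total_steps=100)
start_lr = scheduler.get_last_lr()[0]
for _ in range(10):                     # 10 steps = the full warmup at ratio 0.1
    optimizer.step()
    scheduler.step()
peak_lr = scheduler.get_last_lr()[0]
print(start_lr, peak_lr)
```

With 100 total steps and a warmup ratio of 0.1, the learning rate climbs from zero to its peak of 2e-5 over the first 10 steps.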
## 🔮 Inference
### Single Image
```python
import torch
from PIL import Image
from torchvision.transforms import CenterCrop, Compose, Normalize, Resize, ToTensor

from train import FrequencyAwareDetector

# Load the model and its trained weights
model = FrequencyAwareDetector()
state_dict = torch.load("model_state_dict.pt", map_location="cpu")
model.load_state_dict(state_dict)
model.eval()

# Preprocess: resize, center-crop to 256x256, apply ImageNet normalization
transform = Compose([
    Resize((288, 288)),
    CenterCrop((256, 256)),
    ToTensor(),
    Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])
img = Image.open("test.jpg").convert("RGB")
pixel_values = transform(img).unsqueeze(0)  # add batch dimension

# Predict without tracking gradients
with torch.no_grad():
    output = model(pixel_values=pixel_values)
    probs = torch.softmax(output["logits"], dim=1)
    pred = probs.argmax(dim=1).item()

labels = {0: "Real", 1: "AI-Generated"}
print(f"Prediction: {labels[pred]} ({probs[0][pred]:.2%} confidence)")
```
### Command Line
```bash
# Single image
python inference.py --image photo.jpg
# URL
python inference.py --image https://example.com/image.png
# Batch (entire directory)
python inference.py --image_dir ./photos/
```
## 📚 Scientific Background
### Why Frequency Analysis?
AI-generated images contain subtle artifacts that are invisible to the human eye but detectable in the frequency domain:
1. **Upsampling Artifacts**: Diffusion models and GANs use transposed convolutions and upsampling layers that leave periodic patterns in the frequency spectrum
2. **1/f² Deviation**: The Fourier power spectrum of natural images falls off approximately as 1/f² with spatial frequency f. AI images deviate from this, especially at mid-to-high frequencies
3. **DCT Block Patterns**: The generation process creates non-natural distributions of DCT coefficients across image patches
4. **Noise Residuals**: SRM filters reveal that AI images have fundamentally different noise patterns than camera-captured images
### Key References
1. **AIDE** (2024): "A Sanity Check for AI-generated Image Detection", [arxiv:2406.19435](https://arxiv.org/abs/2406.19435). DCT patch selection + SRM + CLIP fusion achieves 92.77% on AIGCDetectBenchmark.
2. **CommunityForensics** (CVPR 2025): "Using Thousands of Generators to Train Fake Image Detectors", [arxiv:2411.04125](https://arxiv.org/abs/2411.04125). Training on diverse generators (4803+) dramatically improves cross-generator generalization.
3. **SRM Filters**: Fridrich & Kodovsky (2012), "Rich Models for Steganalysis of Digital Images". The standard filter bank for image forensics.
4. **UnivFD**: Ojha et al. (2023), "Towards Universal Fake Image Detectors". CLIP features for zero-shot detection.
## 📁 Repository Structure
```
├── train.py              # Full training script with model architecture
├── inference.py          # Easy-to-use inference script
├── detector_config.json  # Model configuration
├── model_state_dict.pt   # Trained weights (after training)
└── README.md             # This file
```
## License
Apache 2.0