"""
AI-Generated Image Detector

Architecture: SwinV2 backbone + SRM High-Pass Filter + DCT Frequency Analysis
              + FFT Spectral Analysis
Dataset: OwensLab/CommunityForensics-Small (556K images, 4803 generators)
Based on: AIDE paper (arxiv:2406.19435) + CommunityForensics (CVPR 2025)
"""
import os
import io
import math
import random

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import (
    AutoImageProcessor,
    Swinv2Model,
    TrainingArguments,
    Trainer,
    DefaultDataCollator,
)
from torchvision.transforms import (
    Compose,
    Normalize,
    Resize,
    CenterCrop,
    RandomResizedCrop,
    RandomHorizontalFlip,
    ToTensor,
    ColorJitter,
)
import evaluate
import datasets as hf_datasets
import trackio


# ============================================================
# 1. SRM HIGH-PASS FILTER BANK (30 fixed kernels, no gradient)
# ============================================================
def get_srm_kernels():
    """Return the 30 fixed 5x5 high-pass kernels used by ``SRMFilterBank``.

    Returns:
        list of 30 ``np.float32`` arrays, each of shape (5, 5).

    NOTE(review): these are not 30 *distinct* filters. ``f15`` appears twice
    (positions 15 and 30); ``cross1`` / ``cross2`` are identical to ``f13`` /
    ``f12``; ``spam_h``, ``spam_v``, ``spam_d1``, ``spam_d2`` are the negatives
    of ``f3``, ``f4``, ``f5``, ``f6``; ``sq5_1`` is ``f9 / 2``. The SRM
    responses feed a linear Conv2d first, so these repeats add channels but no
    new information. The list is left unchanged so the SRM feature layout stays
    the one the model is trained with.
    """
    def kernel(rows, divisor=1.0):
        # Build one 5x5 float32 kernel, optionally normalised by ``divisor``.
        return np.array(rows, dtype=np.float32) / divisor

    # First- to third-order directional residuals and classic SRM kernels.
    f1 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, -1, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
    f2 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, -1, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]])
    f3 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 1, -2, 1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
    f4 = kernel([[0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, -2, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]])
    f5 = kernel([[0, 0, 0, 0, 0], [0, 1, 0, 0, 0], [0, 0, -2, 0, 0], [0, 0, 0, 1, 0], [0, 0, 0, 0, 0]])
    f6 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 1, 0], [0, 0, -2, 0, 0], [0, 1, 0, 0, 0], [0, 0, 0, 0, 0]])
    f7 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, -1, 3, -3, 1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
    f8 = kernel([[0, 0, 0, 0, 0], [0, 0, -1, 0, 0], [0, 0, 3, 0, 0], [0, 0, -3, 0, 0], [0, 0, 1, 0, 0]])
    f9 = kernel([[0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, -4, 1, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]])
    f10 = kernel([[0, 0, 0, 0, 0], [0, 1, 1, 1, 0], [0, 1, -8, 1, 0], [0, 1, 1, 1, 0], [0, 0, 0, 0, 0]], 3.0)
    f11 = kernel([[0, 0, 0, 0, 0], [0, -1, 2, -1, 0], [0, 2, -4, 2, 0], [0, -1, 2, -1, 0], [0, 0, 0, 0, 0]])
    f12 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [-1, 2, -2, 2, -1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 2.0)
    f13 = kernel([[0, 0, -1, 0, 0], [0, 0, 2, 0, 0], [0, 0, -2, 0, 0], [0, 0, 2, 0, 0], [0, 0, -1, 0, 0]], 2.0)
    f14 = kernel([[0, 0, -1, 0, 0], [0, 0, 2, 0, 0], [-1, 2, -4, 2, -1], [0, 0, 2, 0, 0], [0, 0, -1, 0, 0]], 4.0)
    f15 = kernel([[-1, 2, -2, 2, -1], [2, -6, 8, -6, 2], [-2, 8, -12, 8, -2], [2, -6, 8, -6, 2], [-1, 2, -2, 2, -1]], 12.0)

    # SPAM-style second/third-order residuals.
    spam_h = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, -1, 2, -1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
    spam_v = kernel([[0, 0, 0, 0, 0], [0, 0, -1, 0, 0], [0, 0, 2, 0, 0], [0, 0, -1, 0, 0], [0, 0, 0, 0, 0]])
    spam_d1 = kernel([[0, 0, 0, 0, 0], [0, -1, 0, 0, 0], [0, 0, 2, 0, 0], [0, 0, 0, -1, 0], [0, 0, 0, 0, 0]])
    spam_d2 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, -1, 0], [0, 0, 2, 0, 0], [0, -1, 0, 0, 0], [0, 0, 0, 0, 0]])
    spam3_h = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [1, -3, 3, -1, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]])
    spam3_v = kernel([[0, 0, 1, 0, 0], [0, 0, -3, 0, 0], [0, 0, 3, 0, 0], [0, 0, -1, 0, 0], [0, 0, 0, 0, 0]])

    # Square, cross, diagonal-edge and Gabor-like patterns.
    sq5_1 = kernel([[0, 0, 0, 0, 0], [0, 0, 1, 0, 0], [0, 1, -4, 1, 0], [0, 0, 1, 0, 0], [0, 0, 0, 0, 0]], 2.0)
    sq5_2 = kernel([[0, 0, 0, 0, 0], [0, 1, 0, 1, 0], [0, 0, -4, 0, 0], [0, 1, 0, 1, 0], [0, 0, 0, 0, 0]], 2.0)
    cross1 = kernel([[0, 0, -1, 0, 0], [0, 0, 2, 0, 0], [0, 0, -2, 0, 0], [0, 0, 2, 0, 0], [0, 0, -1, 0, 0]], 2.0)
    cross2 = kernel([[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [-1, 2, -2, 2, -1], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]], 2.0)
    edge_d1 = kernel([[-1, 0, 0, 0, 0], [0, 2, 0, 0, 0], [0, 0, -2, 0, 0], [0, 0, 0, 2, 0], [0, 0, 0, 0, -1]], 2.0)
    edge_d2 = kernel([[0, 0, 0, 0, -1], [0, 0, 0, 2, 0], [0, 0, -2, 0, 0], [0, 2, 0, 0, 0], [-1, 0, 0, 0, 0]], 2.0)
    gabor_h = kernel([[0, 0, 0, 0, 0], [1, -1, 0, -1, 1], [0, 0, 0, 0, 0], [-1, 1, 0, 1, -1], [0, 0, 0, 0, 0]], 4.0)
    gabor_v = kernel([[0, 1, 0, -1, 0], [0, -1, 0, 1, 0], [0, 0, 0, 0, 0], [0, -1, 0, 1, 0], [0, 1, 0, -1, 0]], 4.0)

    # Order matters: it fixes the channel layout the SRM encoder is trained on.
    # The trailing f15 pads the bank to exactly 30 kernels (see NOTE above).
    return [
        f1, f2, f3, f4, f5, f6, f7, f8, f9, f10, f11, f12, f13, f14, f15,
        spam_h, spam_v, spam_d1, spam_d2, spam3_h, spam3_v,
        sq5_1, sq5_2, cross1, cross2, edge_d1, edge_d2, gabor_h, gabor_v, f15,
    ]


class SRMFilterBank(nn.Module):
    """Fixed (non-trainable) bank of the 30 kernels from ``get_srm_kernels``.

    Input ``x`` is (B, 3, H, W); the output is (B, 30, H, W). ``padding=2``
    keeps the spatial size for the 5x5 kernels. Each kernel is copied onto all
    three input channels and divided by 3, so every output map is that kernel
    applied to the per-pixel channel mean.
    """

    def __init__(self):
        super().__init__()
        kernels = get_srm_kernels()
        weight = torch.stack([torch.tensor(k) for k in kernels]).unsqueeze(1)  # (30, 1, 5, 5)
        weight = weight.repeat(1, 3, 1, 1) / 3.0                              # (30, 3, 5, 5)
        # A buffer, not a Parameter: it follows .to()/state_dict but the optimizer never updates it.
        self.register_buffer('weight', weight)

    def forward(self, x):
        return F.conv2d(x, self.weight, padding=2)


# ============================================================
# 2. DCT FREQUENCY ANALYSIS MODULE
# ============================================================
class DCTFrequencyAnalyzer(nn.Module):
    """Frequency-domain statistics from a 2D DCT of non-overlapping luma patches.

    The luma image is cut into ``patch_size`` x ``patch_size`` tiles; any
    right/bottom remainder is dropped. Each tile is DCT-transformed, and
    per-image mean and std over tiles are reported for:

    * energy in each of ``num_freq_bands`` radial bands (2 * num_freq_bands values),
    * the spectral centroid (2),
    * the high/low energy ratio split at radius ``patch_size // 2`` (2),
    * the DC coefficient (2).

    Output: (B, 2 * num_freq_bands + 6).
    """

    def __init__(self, patch_size=32, num_freq_bands=8):
        super().__init__()
        self.patch_size = patch_size
        self.num_freq_bands = num_freq_bands
        N = patch_size
        # Orthonormal DCT-II basis: row k is the k-th cosine basis vector.
        k = torch.arange(N, dtype=torch.float64).unsqueeze(1)
        n = torch.arange(N, dtype=torch.float64).unsqueeze(0)
        dct_mat = math.sqrt(2.0 / N) * torch.cos(math.pi * (2 * n + 1) * k / (2 * N))
        dct_mat[0, :] = math.sqrt(1.0 / N)
        dct_mat = dct_mat.float()
        self.register_buffer('dct_mat', dct_mat)
        # .contiguous() gives the transpose its own storage. A bare .t() view would alias
        # dct_mat inside the state dict, and safetensors refuses to save shared tensors.
        self.register_buffer('dct_mat_t', dct_mat.t().contiguous())

    def dct2d(self, x):
        """Apply the 2D DCT to the last two dims of ``x`` (each of size ``patch_size``).

        NOTE(review): under the Trainer's bf16 autocast this matmul presumably runs in
        bfloat16, which may blur the small high-frequency coefficients measured
        here. Consider disabling autocast for this module; verify on your setup.
        """
        return torch.matmul(torch.matmul(self.dct_mat, x), self.dct_mat_t)

    def forward(self, x):
        """Return DCT band statistics for ``x`` of shape (B, 3, H, W).

        Raises:
            ValueError: if H or W is smaller than ``patch_size`` (no full patch exists).
        """
        _, _, H, W = x.shape
        ps = self.patch_size
        h_patches = H // ps
        w_patches = W // ps
        if h_patches == 0 or w_patches == 0:
            raise ValueError(
                f"DCTFrequencyAnalyzer needs inputs of at least {ps}x{ps} pixels, got {H}x{W}"
            )

        gray = 0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]  # BT.601 luma
        gray = gray[:, :h_patches * ps, :w_patches * ps]            # drop partial tiles
        patches = gray.unfold(1, ps, ps).unfold(2, ps, ps)          # (B, hp, wp, ps, ps)
        B_p, hp, wp = patches.shape[:3]
        patches = patches.reshape(B_p * hp * wp, ps, ps)
        dct_patches = self.dct2d(patches).reshape(B_p, hp * wp, ps, ps)
        energy = dct_patches ** 2  # per-coefficient energy, reused by every statistic below

        coords = torch.arange(ps, device=x.device).float()
        fy, fx = torch.meshgrid(coords, coords, indexing='ij')
        freq_dist = torch.sqrt(fy ** 2 + fx ** 2)  # distance of each coefficient from DC
        max_freq = math.sqrt(2) * ps               # exceeds the largest distance, so bands cover all

        features = []
        for band in range(self.num_freq_bands):
            lo = band * max_freq / self.num_freq_bands
            hi = (band + 1) * max_freq / self.num_freq_bands
            mask = ((freq_dist >= lo) & (freq_dist < hi)).float()
            features.extend(self._patch_stats((energy * mask).sum(dim=(-2, -1))))

        total_energy = energy.sum(dim=(-2, -1))
        spectral_centroid = (energy * freq_dist).sum(dim=(-2, -1)) / (total_energy + 1e-8)
        features.extend(self._patch_stats(spectral_centroid))

        mid = ps // 2
        low_energy = (energy * (freq_dist < mid).float()).sum(dim=(-2, -1))
        high_energy = (energy * (freq_dist >= mid).float()).sum(dim=(-2, -1))
        features.extend(self._patch_stats(high_energy / (low_energy + 1e-8)))

        features.extend(self._patch_stats(dct_patches[:, :, 0, 0]))
        return torch.cat(features, dim=1)

    @staticmethod
    def _patch_stats(values):
        """Return ``[mean, std]`` over the patch axis (dim 1) of a (B, P) tensor, each (B, 1).

        The unbiased std of a single patch is NaN; 0 is reported instead.
        """
        mean = values.mean(dim=1, keepdim=True)
        if values.shape[1] > 1:
            std = values.std(dim=1, keepdim=True)
        else:
            std = torch.zeros_like(mean)
        return [mean, std]
# ============================================================
# 3. FFT SPECTRAL ANALYSIS MODULE
# ============================================================
class FFTSpectralAnalyzer(nn.Module):
    """Azimuthally averaged power spectrum plus a log-log (1/f) trend fit.

    For ``x`` of shape (B, 3, H, W):

    1. The luma image is multiplied by a separable 2D Hann window and passed
       through a 2D FFT.
    2. The centred power spectrum is averaged over ``num_bins`` equal-width
       rings out to ``min(H // 2, W // 2)`` from the centre.
    3. A least-squares line is fitted to log ring power against a log
       frequency index.

    Output: (B, num_bins + 4), laid out as
    ``[log ring power x num_bins, slope, intercept, residual std, residual max]``.
    """

    def __init__(self, num_bins=32):
        super().__init__()
        self.num_bins = num_bins

    def forward(self, x):
        batch, _, height, width = x.shape
        dev = x.device
        nbins = self.num_bins

        # BT.601 luma, then a 2D Hann taper to reduce spectral leakage at the borders.
        luma = 0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]
        taper = torch.hann_window(height, device=dev).unsqueeze(1) * torch.hann_window(width, device=dev).unsqueeze(0)
        luma = luma * taper.unsqueeze(0)

        # After fftshift the DC term sits at (height // 2, width // 2).
        power = torch.abs(torch.fft.fftshift(torch.fft.fft2(luma))) ** 2

        cy, cx = height // 2, width // 2
        dy_grid, dx_grid = torch.meshgrid(
            torch.arange(height, device=dev).float() - cy,
            torch.arange(width, device=dev).float() - cx,
            indexing='ij',
        )
        radius = torch.sqrt(dy_grid**2 + dx_grid**2)
        ring_width = min(cy, cx) / nbins

        def ring_mean(i):
            # Mean power inside ring i; the epsilon keeps empty rings at 0 instead of NaN.
            ring = ((radius >= i * ring_width) & (radius < (i + 1) * ring_width)).float()
            total = (power * ring.unsqueeze(0)).sum(dim=(-2, -1))
            return (total / (ring.sum() + 1e-8)).unsqueeze(1)

        log_spectrum = torch.log1p(torch.cat([ring_mean(i) for i in range(nbins)], dim=1))

        # Per-sample least-squares line through (log frequency index, log power).
        log_freq = torch.log1p(torch.arange(nbins, device=dev).float() + 1).unsqueeze(0).expand(batch, -1)
        freq_mean = log_freq.mean(dim=1, keepdim=True)
        power_mean = log_spectrum.mean(dim=1, keepdim=True)
        freq_c = log_freq - freq_mean
        power_c = log_spectrum - power_mean
        slope = (freq_c * power_c).sum(dim=1, keepdim=True) / ((freq_c**2).sum(dim=1, keepdim=True) + 1e-8)
        intercept = power_mean - slope * freq_mean

        residuals = log_spectrum - (slope * log_freq + intercept)
        return torch.cat(
            [
                log_spectrum,
                slope,
                intercept,
                residuals.std(dim=1, keepdim=True),
                residuals.max(dim=1, keepdim=True).values,
            ],
            dim=1,
        )
# ============================================================
# 4. FREQUENCY-AWARE DETECTOR MODEL
# ============================================================
class FrequencyAwareDetector(nn.Module):
    """Multi-branch AI image detector: SwinV2 + SRM + DCT + FFT.

    * Semantic branch: the SwinV2 ``pooler_output``.
    * Frequency branch: SRM residual maps encoded by a small CNN (256-d), DCT
      band statistics (``2 * num_freq_bands + 6``) and FFT radial-spectrum
      features (``fft_bins + 4``), concatenated and projected to 128-d.

    Both branches are concatenated and classified by an MLP. ``forward`` returns
    ``{"loss": ..., "logits": (B, num_labels)}``; the loss is label-smoothed
    cross-entropy when ``labels`` is given, otherwise ``None``.
    """

    def __init__(self, backbone_name="microsoft/swinv2-tiny-patch4-window8-256",
                 num_labels=2, dct_patch_size=32, num_freq_bands=8, fft_bins=32):
        super().__init__()
        self.num_labels = num_labels
        # Lets the Trainer call gradient_checkpointing_enable() (delegated to the backbone).
        self.supports_gradient_checkpointing = True

        self.backbone = Swinv2Model.from_pretrained(backbone_name)
        backbone_dim = self.backbone.config.hidden_size

        self.srm = SRMFilterBank()
        # Three stride-2 convs (1/8 resolution), then global average pooling -> 256-d.
        self.srm_encoder = nn.Sequential(
            nn.Conv2d(30, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.GELU(),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Flatten(),
        )
        srm_dim = 256

        self.dct_analyzer = DCTFrequencyAnalyzer(patch_size=dct_patch_size, num_freq_bands=num_freq_bands)
        dct_dim = num_freq_bands * 2 + 6  # mean/std per band + centroid, high/low ratio, DC

        self.fft_analyzer = FFTSpectralAnalyzer(num_bins=fft_bins)
        fft_dim = fft_bins + 4  # log ring powers + slope, intercept, residual std/max

        freq_total_dim = srm_dim + dct_dim + fft_dim
        self.freq_proj = nn.Sequential(
            nn.Linear(freq_total_dim, 256),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(256, 128),
        )

        self.classifier = nn.Sequential(
            nn.Linear(backbone_dim + 128, 512),
            nn.GELU(),
            nn.Dropout(0.3),
            nn.Linear(512, 128),
            nn.GELU(),
            nn.Dropout(0.1),
            nn.Linear(128, num_labels),
        )

        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        self.backbone.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        self.backbone.gradient_checkpointing_disable()

    def forward(self, pixel_values, labels=None, **kwargs):
        """Classify ``pixel_values``; extra keyword arguments are accepted and ignored."""
        backbone_out = self.backbone(pixel_values=pixel_values)
        semantic_feats = backbone_out.pooler_output

        srm_maps = self.srm(pixel_values)
        srm_feats = self.srm_encoder(srm_maps)
        dct_feats = self.dct_analyzer(pixel_values)
        fft_feats = self.fft_analyzer(pixel_values)

        freq_feats = torch.cat([srm_feats, dct_feats, fft_feats], dim=1)
        freq_proj = self.freq_proj(freq_feats)

        fused = torch.cat([semantic_feats, freq_proj], dim=1)
        logits = self.classifier(fused)

        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)

        return {"loss": loss, "logits": logits}


# ============================================================
# 5. DATASET & TRANSFORMS
# ============================================================
class CommunityForensicsDataset(Dataset):
    """Wraps OwensLab/CommunityForensics-Small for PyTorch training.

    Each row must provide ``image_data`` (raw bytes, a base64 string, or a list of
    byte values) and ``label``. ``__getitem__`` returns
    ``{'pixel_values': tensor, 'labels': label}``.
    """

    def __init__(self, hf_dataset, transform=None, is_train=True):
        self.dataset = hf_dataset
        self.transform = transform
        self.is_train = is_train

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        img_data = item['image_data']
        if isinstance(img_data, str):
            import base64
            img_data = base64.b64decode(img_data)
        elif isinstance(img_data, list):
            img_data = bytes(img_data)

        # convert() returns an independent copy, so the source image can be closed right away.
        with Image.open(io.BytesIO(img_data)) as raw:
            img = raw.convert('RGB')

        if self.is_train:
            img = self._social_media_augment(img)

        if self.transform:
            pixel_values = self.transform(img)
        else:
            pixel_values = ToTensor()(img)

        return {'pixel_values': pixel_values, 'labels': item['label']}

    def _social_media_augment(self, img):
        """Simulate social media compression/resize artifacts for robustness."""
        # JPEG re-encode at a random quality (10% of samples).
        if random.random() < 0.10:
            quality = random.randint(30, 95)
            buffer = io.BytesIO()
            img.save(buffer, format='JPEG', quality=quality)
            buffer.seek(0)
            img = Image.open(buffer).convert('RGB')

        # Mild Gaussian blur (10% of samples).
        if random.random() < 0.10:
            from PIL import ImageFilter
            radius = random.uniform(0.1, 2.0)
            img = img.filter(ImageFilter.GaussianBlur(radius=radius))

        # Downscale then upscale back to the original size (5% of samples).
        if random.random() < 0.05:
            w, h = img.size
            scale = random.uniform(0.5, 0.9)
            # Clamp to 1 px: PIL rejects zero-sized targets for very small images.
            small = img.resize((max(1, int(w * scale)), max(1, int(h * scale))), Image.BILINEAR)
            img = small.resize((w, h), Image.BILINEAR)

        return img


# ============================================================
# 6. CUSTOM TRAINER
# ============================================================
class FreqDetectorTrainer(Trainer):
    """Trainer that passes only ``pixel_values``/``labels`` to the model and uses its loss."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Some transformers releases infer ``model_accepts_loss_kwargs`` from the ``**kwargs``
        # in FrequencyAwareDetector.forward. They then skip dividing the loss by
        # gradient_accumulation_steps, expecting the model to normalise by
        # num_items_in_batch. This model returns a per-micro-batch mean, so force the
        # standard division. Harmless on versions that never read this attribute.
        self.model_accepts_loss_kwargs = False

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # num_items_in_batch is intentionally unused: the model's loss is already a batch mean.
        labels = inputs.pop("labels")
        outputs = model(pixel_values=inputs["pixel_values"], labels=labels)
        loss = outputs["loss"]
        return (loss, outputs) if return_outputs else loss


# ============================================================
# 7. MAIN TRAINING SCRIPT
# ============================================================
def main():
    """Parse CLI args, train on CommunityForensics-Small, save locally and push to the Hub."""
    import argparse
    import json

    parser = argparse.ArgumentParser()
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--max_eval_samples", type=int, default=None)
    parser.add_argument("--num_train_epochs", type=int, default=5)
    parser.add_argument("--per_device_train_batch_size", type=int, default=16)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--output_dir", type=str, default="ai-image-detector")
    parser.add_argument("--hub_model_id", type=str, default="Reju983/ai-generated-image-detector")
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--test_mode", action="store_true")
    args = parser.parse_args()

    print("=" * 60)
    print("AI-Generated Image Detector Training")
    print("Architecture: SwinV2 + SRM + DCT + FFT")
    print("Dataset: OwensLab/CommunityForensics-Small")
    print("=" * 60)

    trackio.init(project="ai-image-detector", name="swinv2-srm-dct-fft")

    print("\n[1/5] Loading dataset...")
    if args.test_mode:
        ds = hf_datasets.load_dataset(
            "OwensLab/CommunityForensics-Small",
            split="train[:200]",
            trust_remote_code=True,
        )
        ds = ds.train_test_split(test_size=0.5, seed=42)
        train_ds = ds["train"]
        eval_ds = ds["test"]
    else:
        full_ds = hf_datasets.load_dataset(
            "OwensLab/CommunityForensics-Small",
            split="train",
            trust_remote_code=True,
        )
        # datasets only supports stratified splits on ClassLabel columns (ValueError otherwise).
        if isinstance(full_ds.features.get("label"), hf_datasets.ClassLabel):
            stratify_col = "label"
        else:
            stratify_col = None
            print("  Note: 'label' is not a ClassLabel column; using an unstratified split.")
        split = full_ds.train_test_split(test_size=0.05, seed=42, stratify_by_column=stratify_col)
        train_ds = split["train"]
        eval_ds = split["test"]

    if args.max_train_samples:
        train_ds = train_ds.select(range(min(args.max_train_samples, len(train_ds))))
    if args.max_eval_samples:
        eval_ds = eval_ds.select(range(min(args.max_eval_samples, len(eval_ds))))

    print(f"  Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    print("\n[2/5] Building transforms...")
    img_size = args.image_size
    train_transform = Compose([
        RandomResizedCrop((img_size, img_size), scale=(0.8, 1.0)),
        RandomHorizontalFlip(),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    eval_transform = Compose([
        Resize((img_size + 32, img_size + 32)),
        CenterCrop((img_size, img_size)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = CommunityForensicsDataset(train_ds, transform=train_transform, is_train=True)
    eval_dataset = CommunityForensicsDataset(eval_ds, transform=eval_transform, is_train=False)

    print("\n[3/5] Building model...")
    # Single source of truth for the architecture, reused for config.json below.
    model_kwargs = {
        "backbone_name": "microsoft/swinv2-tiny-patch4-window8-256",
        "num_labels": 2,
        "dct_patch_size": 32,
        "num_freq_bands": 8,
        "fft_bins": 32,
    }
    model = FrequencyAwareDetector(**model_kwargs)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Parameters: {total_params:,}")

    print("\n[4/5] Configuring trainer...")
    accuracy_metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # Overall accuracy plus per-class accuracy (label 0 = real, 1 = ai_generated).
        predictions, labels = eval_pred
        if isinstance(predictions, dict):
            predictions = predictions["logits"]
        elif isinstance(predictions, (tuple, list)):
            predictions = predictions[0]
        predictions = np.argmax(predictions, axis=1)
        labels = np.asarray(labels)
        acc = accuracy_metric.compute(predictions=predictions, references=labels)
        real_mask = labels == 0
        fake_mask = labels == 1
        real_acc = (predictions[real_mask] == labels[real_mask]).mean() if real_mask.sum() > 0 else 0
        fake_acc = (predictions[fake_mask] == labels[fake_mask]).mean() if fake_mask.sum() > 0 else 0
        return {"accuracy": acc["accuracy"], "real_accuracy": float(real_acc), "fake_accuracy": float(fake_acc)}

    def data_collator(features):
        pixel_values = torch.stack([f["pixel_values"] for f in features])
        labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": labels}

    # bf16 needs hardware support; fall back to fp32 instead of failing at startup.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        remove_unused_columns=False,
        eval_strategy="epoch",
        save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_device_eval_batch_size=args.per_device_train_batch_size,
        num_train_epochs=args.num_train_epochs,
        warmup_ratio=0.1,
        weight_decay=0.01,
        bf16=use_bf16,
        logging_steps=25,
        logging_strategy="steps",
        logging_first_step=True,
        disable_tqdm=True,
        load_best_model_at_end=True,
        metric_for_best_model="accuracy",
        greater_is_better=True,
        push_to_hub=True,
        hub_model_id=args.hub_model_id,
        hub_strategy="end",
        save_total_limit=2,
        dataloader_num_workers=4,
        gradient_checkpointing=True,
        report_to="trackio",
        run_name="swinv2-srm-dct-fft",
        label_names=["labels"],
    )

    trainer = FreqDetectorTrainer(
        model=model,
        args=training_args,
        data_collator=data_collator,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    print("\n[5/5] Training...")
    trainer.train()

    metrics = trainer.evaluate()
    for k, v in metrics.items():
        print(f"  {k}: {v}")

    save_dir = os.path.join(args.output_dir, "final_model")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))

    config = {
        "architecture": "FrequencyAwareDetector",
        **model_kwargs,
        "id2label": {"0": "real", "1": "ai_generated"},
        "label2id": {"real": 0, "ai_generated": 1},
    }
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    trainer.push_to_hub(
        commit_message="AI-Generated Image Detector: SwinV2 + SRM + DCT + FFT",
        tags=["image-classification", "ai-image-detection", "deepfake-detection",
              "frequency-analysis", "swinv2", "srm", "dct", "fft"],
    )

    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj=os.path.join(save_dir, "model_state_dict.pt"),
        path_in_repo="model_state_dict.pt",
        repo_id=args.hub_model_id,
    )
    api.upload_file(
        path_or_fileobj=os.path.join(save_dir, "config.json"),
        path_in_repo="detector_config.json",
        repo_id=args.hub_model_id,
    )

    print(f"\nDone! Model at: https://huggingface.co/{args.hub_model_id}")


if __name__ == "__main__":
    main()