| """ |
| AI-Generated Image Detector |
| Architecture: SwinV2 backbone + SRM High-Pass Filter + DCT Frequency Analysis + FFT Spectral Analysis |
| Dataset: OwensLab/CommunityForensics-Small (556K images, 4803 generators) |
| Based on: AIDE paper (arxiv:2406.19435) + CommunityForensics (CVPR 2025) |
| """ |
| import os |
| import io |
| import math |
| import random |
| import numpy as np |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from PIL import Image |
| from torch.utils.data import Dataset, DataLoader |
| from transformers import ( |
| AutoImageProcessor, |
| Swinv2Model, |
| TrainingArguments, |
| Trainer, |
| DefaultDataCollator, |
| ) |
| from torchvision.transforms import ( |
| Compose, Normalize, Resize, CenterCrop, RandomResizedCrop, |
| RandomHorizontalFlip, ToTensor, ColorJitter, |
| ) |
| import evaluate |
| import datasets as hf_datasets |
| import trackio |
|
|
|
|
| |
| |
| |
|
|
def get_srm_kernels():
    """Generate 30 SRM high-pass filter kernels (5x5).

    Returns a list of 30 float32 numpy arrays, each shape (5, 5).  These are
    the classic steganalysis residual filters (SRM): first/second/third-order
    directional differences, Laplacian-like "square" kernels, edge and SPAM
    residuals, plus two Gabor-like checkerboard kernels.  The last entry
    repeats the 5x5 "square" kernel to pad the bank to exactly 30 filters.
    """
    def _k(rows, scale=1.0):
        # Every kernel is a fixed 5x5 float32 high-pass residual filter,
        # optionally normalized by `scale`.
        return np.array(rows, dtype=np.float32) / scale

    bank = [
        # --- first/second/third-order directional residuals (f1-f15) ---
        _k([[0,0,0,0,0],[0,0,0,0,0],[0,0,-1,1,0],[0,0,0,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,0,0,0],[0,0,-1,0,0],[0,0,1,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,0,0,0],[0,1,-2,1,0],[0,0,0,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,1,0,0],[0,0,-2,0,0],[0,0,1,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,1,0,0,0],[0,0,-2,0,0],[0,0,0,1,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,0,1,0],[0,0,-2,0,0],[0,1,0,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,0,0,0],[0,-1,3,-3,1],[0,0,0,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,-1,0,0],[0,0,3,0,0],[0,0,-3,0,0],[0,0,1,0,0]]),
        _k([[0,0,0,0,0],[0,0,1,0,0],[0,1,-4,1,0],[0,0,1,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,1,1,1,0],[0,1,-8,1,0],[0,1,1,1,0],[0,0,0,0,0]], 3.0),
        _k([[0,0,0,0,0],[0,-1,2,-1,0],[0,2,-4,2,0],[0,-1,2,-1,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,0,0,0],[-1,2,-2,2,-1],[0,0,0,0,0],[0,0,0,0,0]], 2.0),
        _k([[0,0,-1,0,0],[0,0,2,0,0],[0,0,-2,0,0],[0,0,2,0,0],[0,0,-1,0,0]], 2.0),
        _k([[0,0,-1,0,0],[0,0,2,0,0],[-1,2,-4,2,-1],[0,0,2,0,0],[0,0,-1,0,0]], 4.0),
        _k([[-1,2,-2,2,-1],[2,-6,8,-6,2],[-2,8,-12,8,-2],[2,-6,8,-6,2],[-1,2,-2,2,-1]], 12.0),
        # --- SPAM-style residuals (horizontal / vertical / diagonals) ---
        _k([[0,0,0,0,0],[0,0,0,0,0],[0,-1,2,-1,0],[0,0,0,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,-1,0,0],[0,0,2,0,0],[0,0,-1,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,-1,0,0,0],[0,0,2,0,0],[0,0,0,-1,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,0,-1,0],[0,0,2,0,0],[0,-1,0,0,0],[0,0,0,0,0]]),
        _k([[0,0,0,0,0],[0,0,0,0,0],[1,-3,3,-1,0],[0,0,0,0,0],[0,0,0,0,0]]),
        _k([[0,0,1,0,0],[0,0,-3,0,0],[0,0,3,0,0],[0,0,-1,0,0],[0,0,0,0,0]]),
        # --- square / cross / edge / Gabor-like residuals ---
        _k([[0,0,0,0,0],[0,0,1,0,0],[0,1,-4,1,0],[0,0,1,0,0],[0,0,0,0,0]], 2.0),
        _k([[0,0,0,0,0],[0,1,0,1,0],[0,0,-4,0,0],[0,1,0,1,0],[0,0,0,0,0]], 2.0),
        _k([[0,0,-1,0,0],[0,0,2,0,0],[0,0,-2,0,0],[0,0,2,0,0],[0,0,-1,0,0]], 2.0),
        _k([[0,0,0,0,0],[0,0,0,0,0],[-1,2,-2,2,-1],[0,0,0,0,0],[0,0,0,0,0]], 2.0),
        _k([[-1,0,0,0,0],[0,2,0,0,0],[0,0,-2,0,0],[0,0,0,2,0],[0,0,0,0,-1]], 2.0),
        _k([[0,0,0,0,-1],[0,0,0,2,0],[0,0,-2,0,0],[0,2,0,0,0],[-1,0,0,0,0]], 2.0),
        _k([[0,0,0,0,0],[1,-1,0,-1,1],[0,0,0,0,0],[-1,1,0,1,-1],[0,0,0,0,0]], 4.0),
        _k([[0,1,0,-1,0],[0,-1,0,1,0],[0,0,0,0,0],[0,-1,0,1,0],[0,1,0,-1,0]], 4.0),
        # f15 repeated to bring the bank to exactly 30 filters.
        _k([[-1,2,-2,2,-1],[2,-6,8,-6,2],[-2,8,-12,8,-2],[2,-6,8,-6,2],[-1,2,-2,2,-1]], 12.0),
    ]
    return bank[:30]
|
|
|
|
class SRMFilterBank(nn.Module):
    """SRM High-Pass Filter Bank - 30 fixed forensic filters.

    The filter weights are registered as a (non-trainable) buffer, so they
    move with the module across devices/dtypes but are never updated by the
    optimizer.
    """
    def __init__(self):
        super().__init__()
        bank = torch.stack([torch.tensor(k) for k in get_srm_kernels()])
        # Shape (30, 3, 5, 5): each 5x5 kernel is replicated across the three
        # RGB input channels; the /3.0 averages the per-channel responses.
        self.register_buffer('weight', bank.unsqueeze(1).repeat(1, 3, 1, 1) / 3.0)

    def forward(self, x):
        # padding=2 preserves spatial size for the 5x5 kernels:
        # (B, 3, H, W) -> (B, 30, H, W).
        return F.conv2d(x, self.weight, padding=2)
|
|
|
|
| |
| |
| |
|
|
class DCTFrequencyAnalyzer(nn.Module):
    """Extracts frequency-domain features using 2D DCT on image patches.

    The image is converted to luminance, tiled into non-overlapping
    ``patch_size`` x ``patch_size`` patches, and each patch is transformed
    with an orthonormal DCT-II.  Per-image statistics (mean/std across
    patches) are computed for band energies, spectral centroid, high/low
    energy ratio, and the DC coefficient, giving a feature vector of size
    ``num_freq_bands * 2 + 6``.
    """
    def __init__(self, patch_size=32, num_freq_bands=8):
        super().__init__()
        self.patch_size = patch_size
        self.num_freq_bands = num_freq_bands
        n = patch_size
        # Orthonormal DCT-II basis matrix; row 0 uses the sqrt(1/N) scaling
        # (cos term is 1 there), all other rows sqrt(2/N).
        basis = torch.zeros(n, n)
        for row in range(n):
            norm = math.sqrt((1.0 if row == 0 else 2.0) / n)
            for col in range(n):
                basis[row, col] = norm * math.cos(math.pi * (2 * col + 1) * row / (2 * n))
        self.register_buffer('dct_mat', basis)
        self.register_buffer('dct_mat_t', basis.t())

    def dct2d(self, patches):
        # Separable 2D DCT applied to the last two dims: D @ X @ D^T.
        return torch.matmul(torch.matmul(self.dct_mat, patches), self.dct_mat_t)

    def forward(self, x):
        batch, _, height, width = x.shape
        ps = self.patch_size
        # ITU-R BT.601 luma weights; all statistics are luminance-only.
        luma = 0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]
        rows, cols = height // ps, width // ps
        # Drop any right/bottom remainder that doesn't fill a whole patch.
        luma = luma[:, :rows * ps, :cols * ps]
        tiles = luma.unfold(1, ps, ps).unfold(2, ps, ps)
        b, hp, wp = tiles.shape[:3]
        coeffs = self.dct2d(tiles.reshape(b * hp * wp, ps, ps))
        coeffs = coeffs.reshape(b, hp * wp, ps, ps)

        feats = []
        fy, fx = torch.meshgrid(
            torch.arange(ps, device=x.device).float(),
            torch.arange(ps, device=x.device).float(),
            indexing='ij',
        )
        # Radial frequency index of each DCT coefficient.
        freq_dist = torch.sqrt(fy ** 2 + fx ** 2)
        max_freq = math.sqrt(2) * ps

        # Per-band energy statistics (mean/std across patches).
        for band in range(self.num_freq_bands):
            lo = band * max_freq / self.num_freq_bands
            hi = (band + 1) * max_freq / self.num_freq_bands
            mask = ((freq_dist >= lo) & (freq_dist < hi)).float()
            energy = (coeffs ** 2 * mask.unsqueeze(0).unsqueeze(0)).sum(dim=(-2, -1))
            feats.append(energy.mean(dim=1, keepdim=True))
            feats.append(energy.std(dim=1, keepdim=True))

        # Spectral centroid: energy-weighted mean radial frequency.
        total = (coeffs ** 2).sum(dim=(-2, -1))
        weighted = (coeffs ** 2 * freq_dist.unsqueeze(0).unsqueeze(0)).sum(dim=(-2, -1))
        centroid = weighted / (total + 1e-8)
        feats.append(centroid.mean(dim=1, keepdim=True))
        feats.append(centroid.std(dim=1, keepdim=True))

        # High/low frequency energy ratio, split at half the patch size.
        mid = ps // 2
        low_e = (coeffs ** 2 * (freq_dist < mid).float()).sum(dim=(-2, -1))
        high_e = (coeffs ** 2 * (freq_dist >= mid).float()).sum(dim=(-2, -1))
        ratio = high_e / (low_e + 1e-8)
        feats.append(ratio.mean(dim=1, keepdim=True))
        feats.append(ratio.std(dim=1, keepdim=True))

        # DC (mean brightness) coefficient statistics.
        dc = coeffs[:, :, 0, 0]
        feats.append(dc.mean(dim=1, keepdim=True))
        feats.append(dc.std(dim=1, keepdim=True))

        return torch.cat(feats, dim=1)
|
|
|
|
| |
| |
| |
|
|
class FFTSpectralAnalyzer(nn.Module):
    """Azimuthally averaged power spectrum + 1/f deviation analysis.

    Produces, per image, the log radial power spectrum (``num_bins`` values)
    plus four scalars describing a linear fit of log-power vs log-frequency
    (slope, intercept, residual std, residual max) — natural images roughly
    follow a 1/f power law, so deviations are a useful forensic signal.
    Output size: ``num_bins + 4``.
    """
    def __init__(self, num_bins=32):
        super().__init__()
        self.num_bins = num_bins

    def forward(self, x):
        batch, _, height, width = x.shape
        # BT.601 luminance; spectrum statistics are computed on luma only.
        luma = 0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]
        # Separable 2D Hann window suppresses spectral leakage from borders.
        taper = (torch.hann_window(height, device=x.device).unsqueeze(1)
                 * torch.hann_window(width, device=x.device).unsqueeze(0))
        luma = luma * taper.unsqueeze(0)

        # Centered power spectrum.
        power = torch.abs(torch.fft.fftshift(torch.fft.fft2(luma))) ** 2

        cy, cx = height // 2, width // 2
        ys = torch.arange(height, device=x.device).float() - cy
        xs = torch.arange(width, device=x.device).float() - cx
        gy, gx = torch.meshgrid(ys, xs, indexing='ij')
        radius = torch.sqrt(gy ** 2 + gx ** 2)

        # Azimuthal average: mean power within each radial ring.
        bin_width = min(cy, cx) / self.num_bins
        bins = []
        for i in range(self.num_bins):
            ring = ((radius >= i * bin_width) & (radius < (i + 1) * bin_width)).float()
            count = ring.sum() + 1e-8
            bins.append(((power * ring.unsqueeze(0)).sum(dim=(-2, -1)) / count).unsqueeze(1))

        log_spectrum = torch.log1p(torch.cat(bins, dim=1))

        # NOTE: log1p(i + 1) == log(i + 2); kept as-is so features match any
        # previously trained weights.
        log_freq = torch.log1p(torch.arange(self.num_bins, device=x.device).float() + 1)
        log_freq = log_freq.unsqueeze(0).expand(batch, -1)

        # Closed-form least-squares fit of log-power vs log-frequency.
        xm = log_freq - log_freq.mean(dim=1, keepdim=True)
        ym = log_spectrum - log_spectrum.mean(dim=1, keepdim=True)
        slope = (xm * ym).sum(dim=1, keepdim=True) / ((xm ** 2).sum(dim=1, keepdim=True) + 1e-8)
        intercept = log_spectrum.mean(dim=1, keepdim=True) - slope * log_freq.mean(dim=1, keepdim=True)

        # Deviation of the spectrum from the fitted 1/f line.
        residuals = log_spectrum - (slope * log_freq + intercept)
        return torch.cat([
            log_spectrum, slope, intercept,
            residuals.std(dim=1, keepdim=True),
            residuals.max(dim=1, keepdim=True)[0],
        ], dim=1)
|
|
|
|
| |
| |
| |
|
|
class FrequencyAwareDetector(nn.Module):
    """Multi-branch AI image detector: SwinV2 + SRM + DCT + FFT.

    Four branches are fused for binary real/fake classification:
    a pretrained SwinV2 backbone (semantics), an SRM residual-filter CNN
    (noise artifacts), and hand-crafted DCT / FFT statistics (frequency
    artifacts).
    """
    def __init__(self, backbone_name="microsoft/swinv2-tiny-patch4-window8-256",
                 num_labels=2, dct_patch_size=32, num_freq_bands=8, fft_bins=32):
        super().__init__()
        self.num_labels = num_labels
        # Flag checked by HF Trainer when gradient_checkpointing=True.
        self.supports_gradient_checkpointing = True

        # Semantic branch: pretrained SwinV2 transformer.
        self.backbone = Swinv2Model.from_pretrained(backbone_name)
        sem_dim = self.backbone.config.hidden_size

        # Forensic branch: fixed SRM filters followed by a small conv encoder.
        self.srm = SRMFilterBank()
        self.srm_encoder = nn.Sequential(
            nn.Conv2d(30, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256), nn.GELU(),
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
        )
        srm_dim = 256

        # Frequency branches produce fixed-size statistic vectors.
        self.dct_analyzer = DCTFrequencyAnalyzer(patch_size=dct_patch_size, num_freq_bands=num_freq_bands)
        dct_dim = num_freq_bands * 2 + 6  # per-band mean/std + 6 global stats
        self.fft_analyzer = FFTSpectralAnalyzer(num_bins=fft_bins)
        fft_dim = fft_bins + 4  # radial spectrum + slope/intercept/resid stats

        # Project concatenated frequency features to a compact embedding.
        self.freq_proj = nn.Sequential(
            nn.Linear(srm_dim + dct_dim + fft_dim, 256), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(256, 128),
        )
        # Fused classification head over semantic + frequency embeddings.
        self.classifier = nn.Sequential(
            nn.Linear(sem_dim + 128, 512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, num_labels),
        )

        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Only the transformer backbone benefits from checkpointing; the
        # hand-crafted branches are cheap.
        self.backbone.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        self.backbone.gradient_checkpointing_disable()

    def forward(self, pixel_values, labels=None, **kwargs):
        """Return {"loss", "logits"}; loss is None when labels are absent."""
        semantic = self.backbone(pixel_values=pixel_values).pooler_output

        srm_feats = self.srm_encoder(self.srm(pixel_values))
        dct_feats = self.dct_analyzer(pixel_values)
        fft_feats = self.fft_analyzer(pixel_values)
        freq = self.freq_proj(torch.cat([srm_feats, dct_feats, fft_feats], dim=1))

        logits = self.classifier(torch.cat([semantic, freq], dim=1))

        loss = self.loss_fn(logits, labels) if labels is not None else None
        return {"loss": loss, "logits": logits}
|
|
|
|
| |
| |
| |
|
|
class CommunityForensicsDataset(Dataset):
    """Wraps OwensLab/CommunityForensics-Small for PyTorch training.

    Decodes the dataset's stored image bytes into RGB PIL images, optionally
    applies social-media degradation augmentations (train only), then the
    supplied torchvision transform.
    """
    def __init__(self, hf_dataset, transform=None, is_train=True):
        self.dataset = hf_dataset
        self.transform = transform
        self.is_train = is_train

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        raw = record['image_data']
        # The encoded image may arrive as base64 text or as a list of ints.
        if isinstance(raw, str):
            import base64
            raw = base64.b64decode(raw)
        elif isinstance(raw, list):
            raw = bytes(raw)

        img = Image.open(io.BytesIO(raw)).convert('RGB')

        if self.is_train:
            img = self._social_media_augment(img)

        tensor = self.transform(img) if self.transform else ToTensor()(img)
        return {'pixel_values': tensor, 'labels': record['label']}

    def _social_media_augment(self, img):
        """Simulate social media compression/resize artifacts for robustness."""
        # 10% chance: JPEG re-encode at a random quality level.
        if random.random() < 0.10:
            quality = random.randint(30, 95)
            buf = io.BytesIO()
            img.save(buf, format='JPEG', quality=quality)
            buf.seek(0)
            img = Image.open(buf).convert('RGB')

        # 10% chance: Gaussian blur of random strength.
        if random.random() < 0.10:
            from PIL import ImageFilter
            img = img.filter(ImageFilter.GaussianBlur(radius=random.uniform(0.1, 2.0)))

        # 5% chance: lossy downscale-then-upscale round trip.
        if random.random() < 0.05:
            w, h = img.size
            factor = random.uniform(0.5, 0.9)
            shrunk = img.resize((int(w * factor), int(h * factor)), Image.BILINEAR)
            img = shrunk.resize((w, h), Image.BILINEAR)

        return img
|
|
|
|
| |
| |
| |
|
|
class FreqDetectorTrainer(Trainer):
    """Trainer that forwards labels explicitly into the detector's forward().

    The model computes its own (label-smoothed) cross-entropy, so this only
    unpacks the batch and returns the model-reported loss.
    """
    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(pixel_values=inputs["pixel_values"], labels=labels)
        if return_outputs:
            return outputs["loss"], outputs
        return outputs["loss"]
|
|
|
|
| |
| |
| |
|
|
def main():
    """Train the frequency-aware detector end-to-end and publish it.

    Pipeline: parse CLI args, load and split the CommunityForensics dataset,
    build transforms and datasets, construct the model, fine-tune with the
    HF Trainer, then upload the weights and config to the Hugging Face Hub.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--max_eval_samples", type=int, default=None)
    parser.add_argument("--num_train_epochs", type=int, default=5)
    parser.add_argument("--per_device_train_batch_size", type=int, default=16)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--output_dir", type=str, default="ai-image-detector")
    parser.add_argument("--hub_model_id", type=str, default="Reju983/ai-generated-image-detector")
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--test_mode", action="store_true")
    args = parser.parse_args()

    print("=" * 60)
    print("AI-Generated Image Detector Training")
    print("Architecture: SwinV2 + SRM + DCT + FFT")
    print("Dataset: OwensLab/CommunityForensics-Small")
    print("=" * 60)

    trackio.init(project="ai-image-detector", name="swinv2-srm-dct-fft")

    print("\n[1/5] Loading dataset...")
    if args.test_mode:
        # Tiny smoke-test subset: 100 train / 100 eval images.
        ds = hf_datasets.load_dataset(
            "OwensLab/CommunityForensics-Small", split="train[:200]", trust_remote_code=True,
        )
        ds = ds.train_test_split(test_size=0.5, seed=42)
        train_ds = ds["train"]
        eval_ds = ds["test"]
    else:
        full_ds = hf_datasets.load_dataset(
            "OwensLab/CommunityForensics-Small", split="train", trust_remote_code=True,
        )
        try:
            # Stratified split keeps the real/fake ratio identical in both
            # halves, but datasets requires `label` to be a ClassLabel feature.
            split = full_ds.train_test_split(test_size=0.05, seed=42, stratify_by_column="label")
        except ValueError:
            print("  WARNING: 'label' is not a ClassLabel column; using an unstratified split.")
            split = full_ds.train_test_split(test_size=0.05, seed=42)
        train_ds = split["train"]
        eval_ds = split["test"]

    if args.max_train_samples:
        train_ds = train_ds.select(range(min(args.max_train_samples, len(train_ds))))
    if args.max_eval_samples:
        eval_ds = eval_ds.select(range(min(args.max_eval_samples, len(eval_ds))))

    print(f"  Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    print("\n[2/5] Building transforms and datasets...")
    img_size = args.image_size
    # ImageNet normalization statistics (matches SwinV2 pretraining).
    train_transform = Compose([
        RandomResizedCrop((img_size, img_size), scale=(0.8, 1.0)),
        RandomHorizontalFlip(),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    eval_transform = Compose([
        Resize((img_size + 32, img_size + 32)),
        CenterCrop((img_size, img_size)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])

    train_dataset = CommunityForensicsDataset(train_ds, transform=train_transform, is_train=True)
    eval_dataset = CommunityForensicsDataset(eval_ds, transform=eval_transform, is_train=False)

    print("\n[3/5] Building model...")
    model = FrequencyAwareDetector(
        backbone_name="microsoft/swinv2-tiny-patch4-window8-256",
        num_labels=2, dct_patch_size=32, num_freq_bands=8, fft_bins=32,
    )
    total_params = sum(p.numel() for p in model.parameters())
    print(f"  Parameters: {total_params:,}")

    accuracy_metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        # Report overall accuracy plus per-class (real vs fake) accuracy.
        predictions, labels = eval_pred
        if isinstance(predictions, dict):
            predictions = predictions["logits"]
        predictions = np.argmax(predictions, axis=1)
        acc = accuracy_metric.compute(predictions=predictions, references=labels)
        real_mask = labels == 0
        fake_mask = labels == 1
        real_acc = (predictions[real_mask] == labels[real_mask]).mean() if real_mask.sum() > 0 else 0
        fake_acc = (predictions[fake_mask] == labels[fake_mask]).mean() if fake_mask.sum() > 0 else 0
        return {"accuracy": acc["accuracy"], "real_accuracy": float(real_acc), "fake_accuracy": float(fake_acc)}

    def data_collator(features):
        # Stack per-sample tensors into a batch dict the model understands.
        pixel_values = torch.stack([f["pixel_values"] for f in features])
        labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": labels}

    print("\n[4/5] Configuring trainer...")
    # bf16 autocast is only valid on hardware that supports it (e.g. Ampere+
    # GPUs); enabling it unconditionally crashes on CPU or older GPUs.
    use_bf16 = torch.cuda.is_available() and torch.cuda.is_bf16_supported()

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        remove_unused_columns=False,
        eval_strategy="epoch", save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_device_eval_batch_size=args.per_device_train_batch_size,
        num_train_epochs=args.num_train_epochs,
        warmup_ratio=0.1, weight_decay=0.01, bf16=use_bf16,
        logging_steps=25, logging_strategy="steps", logging_first_step=True,
        disable_tqdm=True,
        load_best_model_at_end=True, metric_for_best_model="accuracy", greater_is_better=True,
        push_to_hub=True, hub_model_id=args.hub_model_id, hub_strategy="end",
        save_total_limit=2, dataloader_num_workers=4,
        gradient_checkpointing=True,
        report_to="trackio", run_name="swinv2-srm-dct-fft",
        label_names=["labels"],
    )

    trainer = FreqDetectorTrainer(
        model=model, args=training_args, data_collator=data_collator,
        train_dataset=train_dataset, eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    print("\n[5/5] Training...")
    trainer.train()

    metrics = trainer.evaluate()
    for k, v in metrics.items():
        print(f"  {k}: {v}")

    # Save raw state dict + a standalone config so the model can be rebuilt
    # without the Trainer.
    save_dir = os.path.join(args.output_dir, "final_model")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))

    import json
    config = {
        "architecture": "FrequencyAwareDetector",
        "backbone_name": "microsoft/swinv2-tiny-patch4-window8-256",
        "num_labels": 2, "dct_patch_size": 32, "num_freq_bands": 8, "fft_bins": 32,
        "id2label": {"0": "real", "1": "ai_generated"},
        "label2id": {"real": 0, "ai_generated": 1},
    }
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    trainer.push_to_hub(
        commit_message="AI-Generated Image Detector: SwinV2 + SRM + DCT + FFT",
        tags=["image-classification", "ai-image-detection", "deepfake-detection",
              "frequency-analysis", "swinv2", "srm", "dct", "fft"],
    )

    # Upload the standalone artifacts next to the Trainer-pushed files.
    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj=os.path.join(save_dir, "model_state_dict.pt"),
        path_in_repo="model_state_dict.pt", repo_id=args.hub_model_id,
    )
    api.upload_file(
        path_or_fileobj=os.path.join(save_dir, "config.json"),
        path_in_repo="detector_config.json", repo_id=args.hub_model_id,
    )

    print(f"\nDone! Model at: https://huggingface.co/{args.hub_model_id}")
|
|
# Script entry point: run the full training/evaluation/publish pipeline.
if __name__ == "__main__":
    main()
|
|