# Training script by Reju983: SwinV2 + SRM + DCT + FFT AI image detector
# (upload commit 8540127, verified)
"""
AI-Generated Image Detector
Architecture: SwinV2 backbone + SRM High-Pass Filter + DCT Frequency Analysis + FFT Spectral Analysis
Dataset: OwensLab/CommunityForensics-Small (556K images, 4803 generators)
Based on: AIDE paper (arxiv:2406.19435) + CommunityForensics (CVPR 2025)
"""
import os
import io
import math
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import (
AutoImageProcessor,
Swinv2Model,
TrainingArguments,
Trainer,
DefaultDataCollator,
)
from torchvision.transforms import (
Compose, Normalize, Resize, CenterCrop, RandomResizedCrop,
RandomHorizontalFlip, ToTensor, ColorJitter,
)
import evaluate
import datasets as hf_datasets
import trackio
# ============================================================
# 1. SRM HIGH-PASS FILTER BANK (30 fixed kernels, no gradient)
# ============================================================
def get_srm_kernels():
    """Generate 30 SRM high-pass filter kernels (5x5).

    Returns a list of exactly 30 float32 numpy arrays of shape (5, 5):
    classic Spatial Rich Model residual filters (first/second/third-order
    differences, SQUARE/edge predictors, SPAM-style kernels). The kernels
    are meant to be used as fixed (non-trainable) convolution weights.
    """
    # --- First-order difference predictors (horizontal/vertical/diagonal) ---
    f1 = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,0,-1,1,0],[0,0,0,0,0],[0,0,0,0,0]], dtype=np.float32)
    f2 = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,0,-1,0,0],[0,0,1,0,0],[0,0,0,0,0]], dtype=np.float32)
    # --- Second-order differences ---
    f3 = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,1,-2,1,0],[0,0,0,0,0],[0,0,0,0,0]], dtype=np.float32)
    f4 = np.array([[0,0,0,0,0],[0,0,1,0,0],[0,0,-2,0,0],[0,0,1,0,0],[0,0,0,0,0]], dtype=np.float32)
    f5 = np.array([[0,0,0,0,0],[0,1,0,0,0],[0,0,-2,0,0],[0,0,0,1,0],[0,0,0,0,0]], dtype=np.float32)
    f6 = np.array([[0,0,0,0,0],[0,0,0,1,0],[0,0,-2,0,0],[0,1,0,0,0],[0,0,0,0,0]], dtype=np.float32)
    # --- Third-order differences ---
    f7 = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,-1,3,-3,1],[0,0,0,0,0],[0,0,0,0,0]], dtype=np.float32)
    f8 = np.array([[0,0,0,0,0],[0,0,-1,0,0],[0,0,3,0,0],[0,0,-3,0,0],[0,0,1,0,0]], dtype=np.float32)
    # --- Laplacian-style SQUARE predictors ---
    f9 = np.array([[0,0,0,0,0],[0,0,1,0,0],[0,1,-4,1,0],[0,0,1,0,0],[0,0,0,0,0]], dtype=np.float32)
    f10 = np.array([[0,0,0,0,0],[0,1,1,1,0],[0,1,-8,1,0],[0,1,1,1,0],[0,0,0,0,0]], dtype=np.float32) / 3.0
    f11 = np.array([[0,0,0,0,0],[0,-1,2,-1,0],[0,2,-4,2,0],[0,-1,2,-1,0],[0,0,0,0,0]], dtype=np.float32)
    # --- Larger-support edge / cross predictors ---
    f12 = np.array([[0,0,0,0,0],[0,0,0,0,0],[-1,2,-2,2,-1],[0,0,0,0,0],[0,0,0,0,0]], dtype=np.float32) / 2.0
    f13 = np.array([[0,0,-1,0,0],[0,0,2,0,0],[0,0,-2,0,0],[0,0,2,0,0],[0,0,-1,0,0]], dtype=np.float32) / 2.0
    f14 = np.array([[0,0,-1,0,0],[0,0,2,0,0],[-1,2,-4,2,-1],[0,0,2,0,0],[0,0,-1,0,0]], dtype=np.float32) / 4.0
    # 5x5 SQUARE residual (the classic SRM "KV" kernel).
    f15 = np.array([[-1,2,-2,2,-1],[2,-6,8,-6,2],[-2,8,-12,8,-2],[2,-6,8,-6,2],[-1,2,-2,2,-1]], dtype=np.float32) / 12.0
    # --- SPAM-style second/third-order kernels along 4 orientations ---
    spam_h = np.array([[0,0,0,0,0],[0,0,0,0,0],[0,-1,2,-1,0],[0,0,0,0,0],[0,0,0,0,0]], dtype=np.float32)
    spam_v = np.array([[0,0,0,0,0],[0,0,-1,0,0],[0,0,2,0,0],[0,0,-1,0,0],[0,0,0,0,0]], dtype=np.float32)
    spam_d1 = np.array([[0,0,0,0,0],[0,-1,0,0,0],[0,0,2,0,0],[0,0,0,-1,0],[0,0,0,0,0]], dtype=np.float32)
    spam_d2 = np.array([[0,0,0,0,0],[0,0,0,-1,0],[0,0,2,0,0],[0,-1,0,0,0],[0,0,0,0,0]], dtype=np.float32)
    spam3_h = np.array([[0,0,0,0,0],[0,0,0,0,0],[1,-3,3,-1,0],[0,0,0,0,0],[0,0,0,0,0]], dtype=np.float32)
    spam3_v = np.array([[0,0,1,0,0],[0,0,-3,0,0],[0,0,3,0,0],[0,0,-1,0,0],[0,0,0,0,0]], dtype=np.float32)
    # --- Scaled square/cross/edge variants ---
    sq5_1 = np.array([[0,0,0,0,0],[0,0,1,0,0],[0,1,-4,1,0],[0,0,1,0,0],[0,0,0,0,0]], dtype=np.float32) / 2.0
    sq5_2 = np.array([[0,0,0,0,0],[0,1,0,1,0],[0,0,-4,0,0],[0,1,0,1,0],[0,0,0,0,0]], dtype=np.float32) / 2.0
    # NOTE(review): cross1 is numerically identical to f13 and cross2 to f12;
    # the bank therefore contains duplicated filters — confirm this is intended.
    cross1 = np.array([[0,0,-1,0,0],[0,0,2,0,0],[0,0,-2,0,0],[0,0,2,0,0],[0,0,-1,0,0]], dtype=np.float32) / 2.0
    cross2 = np.array([[0,0,0,0,0],[0,0,0,0,0],[-1,2,-2,2,-1],[0,0,0,0,0],[0,0,0,0,0]], dtype=np.float32) / 2.0
    edge_d1 = np.array([[-1,0,0,0,0],[0,2,0,0,0],[0,0,-2,0,0],[0,0,0,2,0],[0,0,0,0,-1]], dtype=np.float32) / 2.0
    edge_d2 = np.array([[0,0,0,0,-1],[0,0,0,2,0],[0,0,-2,0,0],[0,2,0,0,0],[-1,0,0,0,0]], dtype=np.float32) / 2.0
    gabor_h = np.array([[0,0,0,0,0],[1,-1,0,-1,1],[0,0,0,0,0],[-1,1,0,1,-1],[0,0,0,0,0]], dtype=np.float32) / 4.0
    gabor_v = np.array([[0,1,0,-1,0],[0,-1,0,1,0],[0,0,0,0,0],[0,-1,0,1,0],[0,1,0,-1,0]], dtype=np.float32) / 4.0
    # 29 distinct entries + f15 repeated once to pad the bank to exactly 30
    # (matches the 30 input channels expected by the SRM encoder).
    all_filters = [f1,f2,f3,f4,f5,f6,f7,f8,f9,f10,f11,f12,f13,f14,f15,
                   spam_h,spam_v,spam_d1,spam_d2,spam3_h,spam3_v,
                   sq5_1,sq5_2,cross1,cross2,edge_d1,edge_d2,gabor_h,gabor_v,f15]
    # Slice is a safety cap; the list already holds exactly 30 kernels.
    return all_filters[:30]
class SRMFilterBank(nn.Module):
    """SRM High-Pass Filter Bank - 30 fixed forensic filters.

    Applies the 30 fixed 5x5 SRM residual kernels to an RGB batch and
    returns 30 noise-residual maps at the input resolution. The weights
    are registered as a buffer, so they never receive gradients.
    """

    def __init__(self):
        super().__init__()
        # Build a (30, 3, 5, 5) weight tensor: each kernel is replicated
        # across the three color channels and averaged (divide by 3).
        bank = torch.stack([torch.tensor(k) for k in get_srm_kernels()])
        bank = bank.unsqueeze(1).repeat(1, 3, 1, 1) / 3.0
        self.register_buffer('weight', bank)

    def forward(self, x):
        # padding=2 keeps the spatial size unchanged for 5x5 kernels.
        return F.conv2d(x, self.weight, padding=2)
# ============================================================
# 2. DCT FREQUENCY ANALYSIS MODULE
# ============================================================
class DCTFrequencyAnalyzer(nn.Module):
    """Extracts frequency-domain features using 2D DCT on image patches.

    The image is converted to grayscale, tiled into non-overlapping
    patch_size x patch_size patches, and each patch is transformed with an
    orthonormal 2D DCT-II. Per-image statistics are aggregated over patches,
    giving a (B, num_freq_bands * 2 + 6) feature vector:
    band energy mean/std per band, spectral-centroid mean/std,
    high/low energy ratio mean/std, and DC-coefficient mean/std.
    """

    def __init__(self, patch_size=32, num_freq_bands=8):
        super().__init__()
        self.patch_size = patch_size        # side length of square DCT patches
        self.num_freq_bands = num_freq_bands  # radial bands for energy stats
        # Precompute the orthonormal DCT-II basis matrix once.
        N = patch_size
        dct_mat = torch.zeros(N, N)
        for k in range(N):
            for n in range(N):
                if k == 0:
                    dct_mat[k, n] = math.sqrt(1.0 / N)
                else:
                    dct_mat[k, n] = math.sqrt(2.0 / N) * math.cos(math.pi * (2*n + 1) * k / (2*N))
        # Buffers (not parameters): fixed transform, moves with .to(device).
        self.register_buffer('dct_mat', dct_mat)
        self.register_buffer('dct_mat_t', dct_mat.t())

    def dct2d(self, x):
        # Separable 2D DCT: D @ x @ D^T applied to the last two dims.
        return torch.matmul(torch.matmul(self.dct_mat, x), self.dct_mat_t)

    def forward(self, x):
        # x: (B, 3, H, W); assumes H, W >= patch_size — TODO confirm upstream resize.
        B, C, H, W = x.shape
        ps = self.patch_size
        # ITU-R BT.601 luma conversion to grayscale.
        gray = 0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]
        # Crop to a whole number of patches, then tile into (B, hp, wp, ps, ps).
        h_patches = H // ps
        w_patches = W // ps
        gray = gray[:, :h_patches*ps, :w_patches*ps]
        patches = gray.unfold(1, ps, ps).unfold(2, ps, ps)
        B_p, hp, wp = patches.shape[:3]
        # Flatten all patches to batch the DCT, then restore the patch axis.
        patches = patches.reshape(B_p * hp * wp, ps, ps)
        dct_patches = self.dct2d(patches)
        dct_patches = dct_patches.reshape(B_p, hp * wp, ps, ps)
        features = []
        # Radial distance of each DCT coefficient from the DC corner (0, 0).
        freq_y = torch.arange(ps, device=x.device).float()
        freq_x = torch.arange(ps, device=x.device).float()
        fy, fx = torch.meshgrid(freq_y, freq_x, indexing='ij')
        freq_dist = torch.sqrt(fy**2 + fx**2)
        max_freq = math.sqrt(2) * ps
        # Energy per radial band, aggregated over patches (mean + std).
        # NOTE(review): .std over the patch axis is NaN when an image yields
        # only one patch; 256px inputs with ps=32 give 64 patches, so fine here.
        for band in range(self.num_freq_bands):
            lo = band * max_freq / self.num_freq_bands
            hi = (band + 1) * max_freq / self.num_freq_bands
            mask = ((freq_dist >= lo) & (freq_dist < hi)).float()
            band_energy = (dct_patches ** 2 * mask.unsqueeze(0).unsqueeze(0)).sum(dim=(-2, -1))
            features.append(band_energy.mean(dim=1, keepdim=True))
            features.append(band_energy.std(dim=1, keepdim=True))
        # Spectral centroid: energy-weighted mean frequency per patch.
        total_energy = (dct_patches ** 2).sum(dim=(-2, -1))
        weighted_freq = (dct_patches ** 2 * freq_dist.unsqueeze(0).unsqueeze(0)).sum(dim=(-2, -1))
        spectral_centroid = weighted_freq / (total_energy + 1e-8)
        features.append(spectral_centroid.mean(dim=1, keepdim=True))
        features.append(spectral_centroid.std(dim=1, keepdim=True))
        # High/low frequency energy ratio, split at radius ps // 2.
        mid = ps // 2
        low_mask = (freq_dist < mid).float()
        high_mask = (freq_dist >= mid).float()
        low_energy = (dct_patches ** 2 * low_mask).sum(dim=(-2, -1))
        high_energy = (dct_patches ** 2 * high_mask).sum(dim=(-2, -1))
        hl_ratio = high_energy / (low_energy + 1e-8)
        features.append(hl_ratio.mean(dim=1, keepdim=True))
        features.append(hl_ratio.std(dim=1, keepdim=True))
        # DC coefficient (proportional to patch mean brightness) statistics.
        dc_values = dct_patches[:, :, 0, 0]
        features.append(dc_values.mean(dim=1, keepdim=True))
        features.append(dc_values.std(dim=1, keepdim=True))
        # (B, num_freq_bands * 2 + 6)
        return torch.cat(features, dim=1)
# ============================================================
# 3. FFT SPECTRAL ANALYSIS MODULE
# ============================================================
class FFTSpectralAnalyzer(nn.Module):
    """Azimuthally averaged power spectrum + 1/f deviation analysis.

    Computes the radially binned log power spectrum of the (Hann-windowed)
    grayscale image and fits a least-squares line in log-log space. Natural
    images roughly follow a 1/f power law; generative models often deviate,
    so the fit residuals are discriminative. Output: (B, num_bins + 4) =
    [log radial spectrum | slope | intercept | residual std | residual max].
    """

    def __init__(self, num_bins=32):
        super().__init__()
        self.num_bins = num_bins  # number of radial frequency bins

    def forward(self, x):
        # x: (B, 3, H, W) RGB batch.
        B, C, H, W = x.shape
        # ITU-R BT.601 luma conversion.
        gray = 0.299 * x[:, 0] + 0.587 * x[:, 1] + 0.114 * x[:, 2]
        # 2D Hann window (outer product of 1D windows) suppresses the
        # spectral leakage caused by non-periodic image borders.
        hann_y = torch.hann_window(H, device=x.device)
        hann_x = torch.hann_window(W, device=x.device)
        window = hann_y.unsqueeze(1) * hann_x.unsqueeze(0)
        gray = gray * window.unsqueeze(0)
        fft = torch.fft.fft2(gray)
        fft_shift = torch.fft.fftshift(fft)  # move DC to the center
        power = torch.abs(fft_shift) ** 2
        # Radial distance of every spectrum pixel from the center.
        cy, cx = H // 2, W // 2
        y = torch.arange(H, device=x.device).float() - cy
        xx = torch.arange(W, device=x.device).float() - cx
        yy, xx = torch.meshgrid(y, xx, indexing='ij')
        radius = torch.sqrt(yy**2 + xx**2)
        max_radius = min(cy, cx)
        bin_width = max_radius / self.num_bins
        # Azimuthal average: mean power inside each radial annulus.
        features = []
        for i in range(self.num_bins):
            r_lo = i * bin_width
            r_hi = (i + 1) * bin_width
            mask = ((radius >= r_lo) & (radius < r_hi)).float()
            count = mask.sum() + 1e-8  # guard against an empty annulus
            bin_power = (power * mask.unsqueeze(0)).sum(dim=(-2, -1)) / count
            features.append(bin_power.unsqueeze(1))
        radial_spectrum = torch.cat(features, dim=1)  # (B, num_bins)
        log_spectrum = torch.log1p(radial_spectrum)
        # Per-image least-squares line fit of log-power vs log-frequency.
        # NOTE(review): log1p(i + 1) = log(i + 2) — the +1 offset avoids
        # log(0) but shifts the frequency axis; confirm this is intended.
        log_freq = torch.log1p(torch.arange(self.num_bins, device=x.device).float() + 1)
        log_freq = log_freq.unsqueeze(0).expand(B, -1)
        xm = log_freq - log_freq.mean(dim=1, keepdim=True)
        ym = log_spectrum - log_spectrum.mean(dim=1, keepdim=True)
        slope = (xm * ym).sum(dim=1, keepdim=True) / ((xm**2).sum(dim=1, keepdim=True) + 1e-8)
        intercept = log_spectrum.mean(dim=1, keepdim=True) - slope * log_freq.mean(dim=1, keepdim=True)
        # Deviation of the actual spectrum from the fitted power law.
        predicted = slope * log_freq + intercept
        residuals = log_spectrum - predicted
        residual_std = residuals.std(dim=1, keepdim=True)
        residual_max = residuals.max(dim=1, keepdim=True)[0]
        # (B, num_bins + 4)
        return torch.cat([log_spectrum, slope, intercept, residual_std, residual_max], dim=1)
# ============================================================
# 4. FREQUENCY-AWARE DETECTOR MODEL
# ============================================================
class FrequencyAwareDetector(nn.Module):
    """Multi-branch AI image detector: SwinV2 + SRM + DCT + FFT.

    Fuses a pretrained SwinV2 semantic embedding with a projected
    frequency-forensics vector (SRM residual CNN + DCT statistics + FFT
    radial spectrum) and classifies real vs. AI-generated images.

    forward() returns ``{"loss": Tensor | None, "logits": (B, num_labels)}``,
    matching the dict interface expected by the HF Trainer.
    """

    def __init__(self, backbone_name="microsoft/swinv2-tiny-patch4-window8-256",
                 num_labels=2, dct_patch_size=32, num_freq_bands=8, fft_bins=32):
        super().__init__()
        self.num_labels = num_labels
        # Advertise gradient-checkpointing support to the HF Trainer.
        self.supports_gradient_checkpointing = True
        # Semantic branch: pretrained SwinV2; pooler_output is the embedding.
        self.backbone = Swinv2Model.from_pretrained(backbone_name)
        backbone_dim = self.backbone.config.hidden_size
        # SRM branch: 30 fixed residual maps -> small strided CNN -> 256-d.
        self.srm = SRMFilterBank()
        self.srm_encoder = nn.Sequential(
            nn.Conv2d(30, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64), nn.GELU(),
            nn.Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128), nn.GELU(),
            nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256), nn.GELU(),
            nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(),
        )
        srm_dim = 256
        # DCT branch emits num_freq_bands*2 + 6 statistics per image.
        self.dct_analyzer = DCTFrequencyAnalyzer(patch_size=dct_patch_size, num_freq_bands=num_freq_bands)
        dct_dim = num_freq_bands * 2 + 6
        # FFT branch emits fft_bins radial log-power values + 4 fit stats.
        self.fft_analyzer = FFTSpectralAnalyzer(num_bins=fft_bins)
        fft_dim = fft_bins + 4
        freq_total_dim = srm_dim + dct_dim + fft_dim
        # Project the concatenated frequency features to a compact 128-d vector.
        self.freq_proj = nn.Sequential(
            nn.Linear(freq_total_dim, 256), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(256, 128),
        )
        # Classifier head over the fused [semantic | frequency] vector.
        self.classifier = nn.Sequential(
            nn.Linear(backbone_dim + 128, 512), nn.GELU(), nn.Dropout(0.3),
            nn.Linear(512, 128), nn.GELU(), nn.Dropout(0.1),
            nn.Linear(128, num_labels),
        )
        # Label smoothing regularizes against noisy real/fake labels.
        self.loss_fn = nn.CrossEntropyLoss(label_smoothing=0.1)

    def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None):
        # Delegate to the backbone; the custom heads are cheap to backprop.
        self.backbone.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    def gradient_checkpointing_disable(self):
        self.backbone.gradient_checkpointing_disable()

    def forward(self, pixel_values, labels=None, **kwargs):
        # NOTE(review): the frequency branches operate on the same normalized
        # tensors the backbone receives (ImageNet mean/std), not raw pixels —
        # presumably intended, but worth confirming against the AIDE reference.
        backbone_out = self.backbone(pixel_values=pixel_values)
        semantic_feats = backbone_out.pooler_output
        srm_maps = self.srm(pixel_values)
        srm_feats = self.srm_encoder(srm_maps)
        dct_feats = self.dct_analyzer(pixel_values)
        fft_feats = self.fft_analyzer(pixel_values)
        # Concatenate all forensic features, project, then fuse with semantics.
        freq_feats = torch.cat([srm_feats, dct_feats, fft_feats], dim=1)
        freq_proj = self.freq_proj(freq_feats)
        fused = torch.cat([semantic_feats, freq_proj], dim=1)
        logits = self.classifier(fused)
        loss = None
        if labels is not None:
            loss = self.loss_fn(logits, labels)
        return {"loss": loss, "logits": logits}
# ============================================================
# 5. DATASET & TRANSFORMS
# ============================================================
class CommunityForensicsDataset(Dataset):
    """Wraps OwensLab/CommunityForensics-Small for PyTorch training.

    Each item decodes the stored image bytes to RGB, optionally applies
    social-media-style degradations (train only), then the given transform.
    Returns ``{'pixel_values': Tensor, 'labels': int}``.
    """

    def __init__(self, hf_dataset, transform=None, is_train=True):
        self.dataset = hf_dataset    # HF rows with 'image_data' and 'label'
        self.transform = transform   # torchvision transform applied after decode
        self.is_train = is_train     # enables the augmentation pipeline

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        raw = record['image_data']
        # 'image_data' may arrive as base64 text, a list of ints, or raw bytes.
        if isinstance(raw, str):
            import base64
            raw = base64.b64decode(raw)
        elif isinstance(raw, list):
            raw = bytes(raw)
        img = Image.open(io.BytesIO(raw)).convert('RGB')
        if self.is_train:
            img = self._social_media_augment(img)
        pixel_values = self.transform(img) if self.transform else ToTensor()(img)
        return {'pixel_values': pixel_values, 'labels': record['label']}

    def _social_media_augment(self, img):
        """Simulate social media compression/resize artifacts for robustness."""
        # 10%: re-encode as JPEG at a random quality (compression artifacts).
        if random.random() < 0.10:
            quality = random.randint(30, 95)
            jpeg_buf = io.BytesIO()
            img.save(jpeg_buf, format='JPEG', quality=quality)
            jpeg_buf.seek(0)
            img = Image.open(jpeg_buf).convert('RGB')
        # 10%: gaussian blur with a random radius.
        if random.random() < 0.10:
            from PIL import ImageFilter
            radius = random.uniform(0.1, 2.0)
            img = img.filter(ImageFilter.GaussianBlur(radius=radius))
        # 5%: downscale then upscale to mimic lossy thumbnail resizing.
        if random.random() < 0.05:
            w, h = img.size
            factor = random.uniform(0.5, 0.9)
            shrunk = img.resize((int(w * factor), int(h * factor)), Image.BILINEAR)
            img = shrunk.resize((w, h), Image.BILINEAR)
        return img
# ============================================================
# 6. CUSTOM TRAINER
# ============================================================
class FreqDetectorTrainer(Trainer):
    """HF Trainer that routes labels through the model's own loss.

    The model computes its (label-smoothed) cross-entropy internally, so
    compute_loss only pops the labels out of the batch and forwards them.
    """

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        labels = inputs.pop("labels")
        outputs = model(pixel_values=inputs["pixel_values"], labels=labels)
        if return_outputs:
            return outputs["loss"], outputs
        return outputs["loss"]
# ============================================================
# 7. MAIN TRAINING SCRIPT
# ============================================================
def main():
    """Train the frequency-aware AI-image detector end to end.

    Pipeline: parse CLI args -> load CommunityForensics-Small -> build
    transforms -> build model -> train/evaluate with the HF Trainer ->
    save artifacts locally and push everything to the Hugging Face Hub.
    """
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--max_train_samples", type=int, default=None)
    parser.add_argument("--max_eval_samples", type=int, default=None)
    parser.add_argument("--num_train_epochs", type=int, default=5)
    parser.add_argument("--per_device_train_batch_size", type=int, default=16)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=2e-5)
    parser.add_argument("--output_dir", type=str, default="ai-image-detector")
    parser.add_argument("--hub_model_id", type=str, default="Reju983/ai-generated-image-detector")
    parser.add_argument("--image_size", type=int, default=256)
    parser.add_argument("--test_mode", action="store_true")
    args = parser.parse_args()

    print("=" * 60)
    print("AI-Generated Image Detector Training")
    print("Architecture: SwinV2 + SRM + DCT + FFT")
    print("Dataset: OwensLab/CommunityForensics-Small")
    print("=" * 60)
    trackio.init(project="ai-image-detector", name="swinv2-srm-dct-fft")

    print("\n[1/5] Loading dataset...")
    # NOTE(review): trust_remote_code executes code from the dataset repo;
    # acceptable here only because the source repository is trusted.
    if args.test_mode:
        # Tiny smoke-test slice: 100 train / 100 eval.
        ds = hf_datasets.load_dataset(
            "OwensLab/CommunityForensics-Small", split="train[:200]", trust_remote_code=True,
        )
        ds = ds.train_test_split(test_size=0.5, seed=42)
        train_ds = ds["train"]
        eval_ds = ds["test"]
    else:
        full_ds = hf_datasets.load_dataset(
            "OwensLab/CommunityForensics-Small", split="train", trust_remote_code=True,
        )
        # Stratified split keeps the real/fake ratio identical in eval.
        split = full_ds.train_test_split(test_size=0.05, seed=42, stratify_by_column="label")
        train_ds = split["train"]
        eval_ds = split["test"]
    if args.max_train_samples:
        train_ds = train_ds.select(range(min(args.max_train_samples, len(train_ds))))
    if args.max_eval_samples:
        eval_ds = eval_ds.select(range(min(args.max_eval_samples, len(eval_ds))))
    print(f" Train: {len(train_ds)}, Eval: {len(eval_ds)}")

    print("\n[2/5] Building transforms...")
    img_size = args.image_size
    # ImageNet mean/std matches the SwinV2 pretraining statistics.
    train_transform = Compose([
        RandomResizedCrop((img_size, img_size), scale=(0.8, 1.0)),
        RandomHorizontalFlip(),
        ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    eval_transform = Compose([
        Resize((img_size + 32, img_size + 32)),
        CenterCrop((img_size, img_size)),
        ToTensor(),
        Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
    ])
    train_dataset = CommunityForensicsDataset(train_ds, transform=train_transform, is_train=True)
    eval_dataset = CommunityForensicsDataset(eval_ds, transform=eval_transform, is_train=False)

    print("\n[3/5] Building model...")
    model = FrequencyAwareDetector(
        backbone_name="microsoft/swinv2-tiny-patch4-window8-256",
        num_labels=2, dct_patch_size=32, num_freq_bands=8, fft_bins=32,
    )
    total_params = sum(p.numel() for p in model.parameters())
    print(f" Parameters: {total_params:,}")

    print("\n[4/5] Configuring trainer...")
    accuracy_metric = evaluate.load("accuracy")

    def compute_metrics(eval_pred):
        """Overall accuracy plus per-class (real/fake) accuracy."""
        predictions, labels = eval_pred
        # The Trainer may hand back a dict, or a tuple of the model's
        # non-loss outputs (logits first) — handle both.
        if isinstance(predictions, dict):
            predictions = predictions["logits"]
        elif isinstance(predictions, (tuple, list)):
            predictions = predictions[0]
        predictions = np.argmax(predictions, axis=1)
        acc = accuracy_metric.compute(predictions=predictions, references=labels)
        real_mask = labels == 0
        fake_mask = labels == 1
        # Guard against an eval split missing one class entirely.
        real_acc = (predictions[real_mask] == labels[real_mask]).mean() if real_mask.sum() > 0 else 0
        fake_acc = (predictions[fake_mask] == labels[fake_mask]).mean() if fake_mask.sum() > 0 else 0
        return {"accuracy": acc["accuracy"], "real_accuracy": float(real_acc), "fake_accuracy": float(fake_acc)}

    def data_collator(features):
        """Stack per-sample tensors into a batched dict for the Trainer."""
        pixel_values = torch.stack([f["pixel_values"] for f in features])
        labels = torch.tensor([f["labels"] for f in features], dtype=torch.long)
        return {"pixel_values": pixel_values, "labels": labels}

    training_args = TrainingArguments(
        output_dir=args.output_dir,
        remove_unused_columns=False,  # keep 'image_data' for the custom dataset
        eval_strategy="epoch", save_strategy="epoch",
        learning_rate=args.learning_rate,
        per_device_train_batch_size=args.per_device_train_batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        per_device_eval_batch_size=args.per_device_train_batch_size,
        num_train_epochs=args.num_train_epochs,
        warmup_ratio=0.1, weight_decay=0.01, bf16=True,
        logging_steps=25, logging_strategy="steps", logging_first_step=True,
        disable_tqdm=True,
        load_best_model_at_end=True, metric_for_best_model="accuracy", greater_is_better=True,
        push_to_hub=True, hub_model_id=args.hub_model_id, hub_strategy="end",
        save_total_limit=2, dataloader_num_workers=4,
        gradient_checkpointing=True,
        report_to="trackio", run_name="swinv2-srm-dct-fft",
        label_names=["labels"],
    )
    trainer = FreqDetectorTrainer(
        model=model, args=training_args, data_collator=data_collator,
        train_dataset=train_dataset, eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    print("\n[5/5] Training...")
    trainer.train()
    metrics = trainer.evaluate()
    for k, v in metrics.items():
        print(f" {k}: {v}")

    # Save a plain state_dict + minimal config alongside the Trainer checkpoints
    # so the model can be reloaded without the Trainer machinery.
    save_dir = os.path.join(args.output_dir, "final_model")
    os.makedirs(save_dir, exist_ok=True)
    torch.save(model.state_dict(), os.path.join(save_dir, "model_state_dict.pt"))
    import json
    config = {
        "architecture": "FrequencyAwareDetector",
        "backbone_name": "microsoft/swinv2-tiny-patch4-window8-256",
        "num_labels": 2, "dct_patch_size": 32, "num_freq_bands": 8, "fft_bins": 32,
        "id2label": {"0": "real", "1": "ai_generated"},
        "label2id": {"real": 0, "ai_generated": 1},
    }
    with open(os.path.join(save_dir, "config.json"), "w") as f:
        json.dump(config, f, indent=2)

    # Push the Trainer card/checkpoint, then the raw artifacts by hand.
    trainer.push_to_hub(
        commit_message="AI-Generated Image Detector: SwinV2 + SRM + DCT + FFT",
        tags=["image-classification", "ai-image-detection", "deepfake-detection",
              "frequency-analysis", "swinv2", "srm", "dct", "fft"],
    )
    from huggingface_hub import HfApi
    api = HfApi()
    api.upload_file(
        path_or_fileobj=os.path.join(save_dir, "model_state_dict.pt"),
        path_in_repo="model_state_dict.pt", repo_id=args.hub_model_id,
    )
    api.upload_file(
        path_or_fileobj=os.path.join(save_dir, "config.json"),
        path_in_repo="detector_config.json", repo_id=args.hub_model_id,
    )
    print(f"\nDone! Model at: https://huggingface.co/{args.hub_model_id}")


if __name__ == "__main__":
    main()