# Provenance: uploaded by ash12321 via huggingface_hub ("Upload model.py", commit 121343d, verified)
"""
Fake Image Detection Ensemble - Model Definitions
9 specialized models for detecting AI-generated/fake images
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from scipy.ndimage import sobel
from sklearn.svm import OneClassSVM
from sklearn.ensemble import IsolationForest
from sklearn.neighbors import LocalOutlierFactor
from sklearn.mixture import GaussianMixture
from sklearn.preprocessing import StandardScaler
class EnhancedFreqVAE(nn.Module):
    """Frequency-domain convolutional VAE for fake-image anomaly scoring.

    Images are mapped to a frequency representation (a 0.8/0.2 blend of
    log-magnitude and phase of the 2-D FFT) and the VAE reconstructs that
    representation.  The anomaly score is the frequency-domain
    reconstruction error plus a weighted KL divergence term.
    """

    def __init__(self, ld=256):
        """ld: dimensionality of the latent bottleneck."""
        super().__init__()
        # Five stride-2 conv stages: a 256x256 input ends as an 8x8x512 map
        # (the 512*8*8 Linear layers below assume exactly this input size).
        self.enc = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(512, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2),
        )
        self.mu = nn.Linear(512 * 8 * 8, ld)
        self.lv = nn.Linear(512 * 8 * 8, ld)
        self.dec_fc = nn.Linear(ld, 512 * 8 * 8)
        # Mirror of the encoder: 8x8x512 back up to 256x256x3.
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1),
        )

    @staticmethod
    def _freq_transform(x):
        """Map an image batch to the frequency representation the VAE models.

        Previously duplicated verbatim in encode() and score().
        """
        xf = torch.fft.fft2(x)
        mag = torch.log(torch.abs(xf) + 1e-8)  # epsilon avoids log(0)
        phase = torch.angle(xf)
        return mag * 0.8 + phase * 0.2

    def encode(self, x):
        """Return (mu, logvar) of the latent posterior for images x."""
        h = self.enc(self._freq_transform(x)).view(x.size(0), -1)
        return self.mu(h), self.lv(h)

    def forward(self, x):
        """Sample the posterior (reparameterization trick) and decode.

        Returns (reconstruction, mu, logvar); the reconstruction lives in
        the frequency representation, not pixel space.
        """
        mu, lv = self.encode(x)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * lv)
        return self.dec(self.dec_fc(z).view(x.size(0), 512, 8, 8)), mu, lv

    def score(self, img, dev):
        """Anomaly score for one image (C,H,W) or a batch (B,C,H,W); higher = more anomalous.

        Fix: decodes from the posterior mean rather than calling forward(),
        which samples z even under eval()/no_grad() and made repeated
        scores of the same image non-deterministic.
        """
        self.eval()
        img = img.to(dev)
        with torch.no_grad():
            if img.dim() == 3:
                img = img.unsqueeze(0)
            mu, lv = self.encode(img)
            rc = self.dec(self.dec_fc(mu).view(img.size(0), 512, 8, 8))
            recon = F.mse_loss(rc, self._freq_transform(img), reduction='sum')
            kl = -0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp())
            return (recon + 0.15 * kl).item()
class EdgeNormalizingFlow(nn.Module):
    """Residual-flow density model over multi-scale Sobel edge statistics.

    NOTE(review): forward() applies residual updates x + f(x) but reports a
    log-determinant of 0, so log_prob() is only the base-density term — the
    Jacobian of the flow is not accounted for.  Kept as-is because trained
    weights depend on this scoring contract.
    """

    def __init__(self, feature_dim=32):
        super().__init__()
        self.feature_dim = feature_dim

        def _stage():
            # one residual transform: dim -> 2*dim -> 2*dim -> dim
            return nn.Sequential(
                nn.Linear(feature_dim, feature_dim * 2), nn.ReLU(),
                nn.Linear(feature_dim * 2, feature_dim * 2), nn.ReLU(),
                nn.Linear(feature_dim * 2, feature_dim),
            )

        self.flows = nn.ModuleList(_stage() for _ in range(4))
        # learnable diagonal-Gaussian base distribution
        self.base_mean = nn.Parameter(torch.zeros(feature_dim))
        self.base_logstd = nn.Parameter(torch.zeros(feature_dim))

    def extract_edge_features(self, img):
        """Return the first `feature_dim` edge statistics at 4 decimation scales.

        Per scale: mean, std, max, 50/75/90/95th percentiles, and the
        fraction of pixels whose edge magnitude exceeds 0.1.
        """
        if torch.is_tensor(img):
            arr = img.permute(1, 2, 0).cpu().numpy()
            # undo ImageNet normalization — assumes the tensor was
            # normalized upstream (NOTE(review): confirm against caller)
            arr = arr * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            arr = np.clip(arr, 0, 1)
        else:
            arr = np.array(img)
        gray = np.mean(arr, 2)
        stats = []
        for step in (1, 2, 4, 8):
            sub = gray if step == 1 else gray[::step, ::step]
            gx = sobel(sub, 0)
            gy = sobel(sub, 1)
            mag = np.sqrt(gx**2 + gy**2)
            stats.extend([np.mean(mag), np.std(mag), np.max(mag)])
            stats.extend(np.percentile(mag, q) for q in (50, 75, 90, 95))
            stats.append(np.sum(mag > 0.1) / mag.size)
        return torch.tensor(stats[:self.feature_dim], dtype=torch.float32)

    def forward(self, x):
        """Push x through the residual stages; log-det is not tracked (always 0)."""
        for transform in self.flows:
            x = x + transform(x)
        return x, 0

    def log_prob(self, x):
        """Log-density of x under the (Jacobian-free, see class note) flow."""
        z, log_det = self.forward(x)
        variance = torch.exp(2 * self.base_logstd)
        log_pz = -0.5 * torch.sum((z - self.base_mean)**2 / variance + 2 * self.base_logstd, dim=-1)
        return log_pz + log_det

    def score(self, img, dev):
        """Negative log-likelihood of the image's edge features; higher = more anomalous."""
        self.eval()
        self.to(dev)
        with torch.no_grad():
            feats = self.extract_edge_features(img).unsqueeze(0).to(dev)
            return -self.log_prob(feats).item()
class SemanticDeepSVDD(nn.Module):
    """Deep SVDD over ImageNet-pretrained ResNet-50 embeddings.

    Score = squared distance of the projected embedding from a learned
    hypersphere center; `center` must be assigned (during training/
    calibration) before score() is usable.
    """

    def __init__(self):
        super().__init__()
        from torchvision.models import resnet50
        backbone = resnet50(weights='IMAGENET1K_V1')
        # drop the final FC; keep conv trunk + global average pool
        self.features = nn.Sequential(*list(backbone.children())[:-1])
        # freeze early backbone parameters; only index >= 100 stay trainable
        for idx, param in enumerate(self.features.parameters()):
            param.requires_grad = idx >= 100
        # projection head: 2048-d pooled features -> 256-d SVDD space
        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256),
        )
        self.center = None  # hypersphere center; set externally after training

    def forward(self, x):
        embedding = self.features(x)
        return self.proj(embedding)

    def score(self, img, dev):
        """Mean squared distance to the SVDD center; higher = more anomalous."""
        self.eval()
        img = img.to(dev)
        with torch.no_grad():
            if img.dim() == 3:
                img = img.unsqueeze(0)
            squared_dist = (self(img) - self.center)**2
            return squared_dist.sum(dim=1).mean().item()
class Ensemble:
    """Weighted 9-detector ensemble with z-score normalization and a threshold.

    Raw per-model scores are normalized with per-model mean/std (`norms`,
    populated during calibration), combined with fixed weights, and
    compared against `thresh`.
    """

    def __init__(self, models_dict):
        self.models = models_dict
        # fixed contribution of each detector to the combined score
        self.wts = {
            'freq_vae': 0.18,
            'texture_ocsvm': 0.13,
            'color_model': 0.09,
            'edge_flow': 0.13,
            'semantic_svdd': 0.17,
            'stat': 0.09,
            'iforest': 0.09,
            'lof': 0.07,
            'gmm': 0.05,
        }
        self.norms = None   # per-model {'mean': .., 'std': ..}; set by calibration
        self.thresh = 0.0   # decision boundary on the combined z-score

    def get_scores(self, img, dev):
        """Raw anomaly score from every member model (torch models get `dev`)."""
        on_device = {'freq_vae', 'edge_flow', 'semantic_svdd'}
        names = ('freq_vae', 'texture_ocsvm', 'color_model', 'edge_flow',
                 'semantic_svdd', 'stat', 'iforest', 'lof', 'gmm')
        return {
            name: (self.models[name].score(img, dev) if name in on_device
                   else self.models[name].score(img))
            for name in names
        }

    def predict(self, img, dev):
        """Return (is_fake, combined_score, raw_scores) for one image."""
        raw = self.get_scores(img, dev)
        normed = {}
        for name, value in raw.items():
            stats = self.norms[name]
            # epsilon guards a zero calibration std
            normed[name] = (value - stats['mean']) / (stats['std'] + 1e-8)
        combined = sum(self.wts[name] * normed[name] for name in raw)
        return combined > self.thresh, combined, raw