Upload model.py with huggingface_hub
Browse files
model.py
ADDED
|
@@ -0,0 +1,198 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Fake Image Detection Ensemble - Model Definitions
|
| 3 |
+
9 specialized models for detecting AI-generated/fake images
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import torch.nn as nn
|
| 8 |
+
import torch.nn.functional as F
|
| 9 |
+
import numpy as np
|
| 10 |
+
from scipy.ndimage import sobel
|
| 11 |
+
from sklearn.svm import OneClassSVM
|
| 12 |
+
from sklearn.ensemble import IsolationForest
|
| 13 |
+
from sklearn.neighbors import LocalOutlierFactor
|
| 14 |
+
from sklearn.mixture import GaussianMixture
|
| 15 |
+
from sklearn.preprocessing import StandardScaler
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class EnhancedFreqVAE(nn.Module):
    """Enhanced frequency-domain VAE with multi-scale analysis.

    The input image is mapped to a frequency representation (log-magnitude
    blended with phase) that both the encoder consumes and the decoder is
    scored against.  The anomaly score is the (summed) reconstruction error
    in that frequency space plus a weighted KL divergence term.

    Fix: the frequency-feature computation was duplicated verbatim in
    ``encode`` and ``score``; it is now a single ``_freq_features`` helper
    so the two code paths cannot drift apart.
    """

    def __init__(self, ld=256):
        """ld: dimensionality of the latent space."""
        super().__init__()
        # 5 stride-2 convs: a 256x256 input reaches an 8x8 spatial map,
        # matching the 512*8*8 flattened size assumed by the linear heads.
        self.enc = nn.Sequential(
            nn.Conv2d(3, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(64, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(128, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(256, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.Dropout2d(0.1),
            nn.Conv2d(512, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2),
        )
        self.mu = nn.Linear(512 * 8 * 8, ld)
        self.lv = nn.Linear(512 * 8 * 8, ld)
        self.dec_fc = nn.Linear(ld, 512 * 8 * 8)
        self.dec = nn.Sequential(
            nn.ConvTranspose2d(512, 512, 4, 2, 1), nn.BatchNorm2d(512), nn.ReLU(),
            nn.ConvTranspose2d(512, 256, 4, 2, 1), nn.BatchNorm2d(256), nn.ReLU(),
            nn.ConvTranspose2d(256, 128, 4, 2, 1), nn.BatchNorm2d(128), nn.ReLU(),
            nn.ConvTranspose2d(128, 64, 4, 2, 1), nn.BatchNorm2d(64), nn.ReLU(),
            nn.ConvTranspose2d(64, 3, 4, 2, 1)
        )

    def _freq_features(self, x):
        """Return the blended frequency representation of ``x``.

        Log-magnitude carries most of the signal (weight 0.8); phase is
        mixed in with weight 0.2.  The 1e-8 guards log(0).
        """
        xf = torch.fft.fft2(x)
        xf_mag = torch.log(torch.abs(xf) + 1e-8)
        xf_phase = torch.angle(xf)
        return xf_mag * 0.8 + xf_phase * 0.2

    def encode(self, x):
        """Encode ``x`` into (mu, logvar) of the latent Gaussian."""
        h = self.enc(self._freq_features(x)).view(x.size(0), -1)
        return self.mu(h), self.lv(h)

    def forward(self, x):
        """Return (reconstruction, mu, logvar).

        NOTE(review): the reparameterization noise is sampled even in eval
        mode, so forward/score are stochastic by design — confirm intended.
        """
        mu, lv = self.encode(x)
        z = mu + torch.randn_like(mu) * torch.exp(0.5 * lv)
        return self.dec(self.dec_fc(z).view(x.size(0), 512, 8, 8)), mu, lv

    def score(self, img, dev):
        """Anomaly score = frequency-space reconstruction error + 0.15 * KL.

        Accepts a single CHW image or an NCHW batch; higher = more anomalous.
        """
        self.eval()
        img = img.to(dev)
        with torch.no_grad():
            if img.dim() == 3:
                img = img.unsqueeze(0)
            rc, mu, lv = self(img)
            # The decoder's target is the same frequency representation the
            # encoder saw, so score reconstruction in that space.
            target = self._freq_features(img)
            recon = F.mse_loss(rc, target, reduction='sum')
            kl = -0.5 * torch.sum(1 + lv - mu.pow(2) - lv.exp())
            return (recon + 0.15 * kl).item()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
class EdgeNormalizingFlow(nn.Module):
    """Normalizing flow over hand-crafted multi-scale edge statistics.

    Edge features are extracted with Sobel filters at several decimation
    scales, then pushed through a stack of residual MLP transforms; the
    anomaly score is the negative log-probability under a learned diagonal
    Gaussian base distribution.
    """

    def __init__(self, feature_dim=32):
        super().__init__()
        self.feature_dim = feature_dim
        # Four residual transform blocks, applied in sequence.
        self.flows = nn.ModuleList([
            nn.Sequential(
                nn.Linear(feature_dim, feature_dim * 2), nn.ReLU(),
                nn.Linear(feature_dim * 2, feature_dim * 2), nn.ReLU(),
                nn.Linear(feature_dim * 2, feature_dim),
            )
            for _ in range(4)
        ])
        # Learnable diagonal-Gaussian base distribution.
        self.base_mean = nn.Parameter(torch.zeros(feature_dim))
        self.base_logstd = nn.Parameter(torch.zeros(feature_dim))

    def extract_edge_features(self, img):
        """Build a fixed-length vector of multi-scale edge statistics.

        Tensor inputs are assumed ImageNet-normalized CHW and are
        de-normalized back to [0, 1]; anything else goes through np.array.
        """
        if torch.is_tensor(img):
            arr = img.permute(1, 2, 0).cpu().numpy()
            arr = arr * np.array([0.229, 0.224, 0.225]) + np.array([0.485, 0.456, 0.406])
            arr = np.clip(arr, 0, 1)
        else:
            arr = np.array(img)

        gray = np.mean(arr, 2)
        gx, gy = sobel(gray, 0), sobel(gray, 1)
        mag = np.sqrt(gx ** 2 + gy ** 2)

        stats = []
        for step in (1, 2, 4, 8):
            if step == 1:
                m = mag
            else:
                # Decimate, then recompute edge magnitude at the new scale.
                sub = gray[::step, ::step]
                sx, sy = sobel(sub, 0), sobel(sub, 1)
                m = np.sqrt(sx ** 2 + sy ** 2)
            stats.extend([
                np.mean(m), np.std(m), np.max(m),
                np.percentile(m, 50), np.percentile(m, 75),
                np.percentile(m, 90), np.percentile(m, 95),
                np.sum(m > 0.1) / m.size,
            ])

        # 4 scales x 8 stats = 32 values; truncate defensively to feature_dim.
        return torch.tensor(stats[:self.feature_dim], dtype=torch.float32)

    def forward(self, x):
        """Apply the residual transforms; returns (z, log_det).

        NOTE(review): the Jacobian log-determinant of the residual maps is
        not tracked (always 0), so log_prob is an approximation — confirm
        this is intentional.
        """
        log_det = 0
        for layer in self.flows:
            x = x + layer(x)
        return x, log_det

    def log_prob(self, x):
        """Log-density of ``x`` under the flow + Gaussian base."""
        z, log_det = self.forward(x)
        variance = torch.exp(2 * self.base_logstd)
        log_pz = -0.5 * torch.sum(
            (z - self.base_mean) ** 2 / variance + 2 * self.base_logstd, dim=-1
        )
        return log_pz + log_det

    def score(self, img, dev):
        """Anomaly score: negative log-probability of the edge features."""
        self.eval()
        self.to(dev)
        with torch.no_grad():
            feats = self.extract_edge_features(img).unsqueeze(0).to(dev)
            return -self.log_prob(feats).item()
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
class SemanticDeepSVDD(nn.Module):
    """Deep SVDD over semantic features from a pretrained ResNet-50.

    Images are embedded by a (mostly frozen) ResNet-50 backbone followed by
    a projection MLP; the anomaly score is the squared distance of the
    embedding from a hypersphere center learned elsewhere.
    """

    def __init__(self):
        super().__init__()
        from torchvision.models import resnet50
        backbone = resnet50(weights='IMAGENET1K_V1')
        # Everything up to (and including) global average pooling.
        self.features = nn.Sequential(*list(backbone.children())[:-1])

        # Freeze the early backbone parameters; only fine-tune the deeper
        # ones (parameter index >= 100).
        for idx, param in enumerate(self.features.parameters()):
            param.requires_grad = idx >= 100

        self.proj = nn.Sequential(
            nn.Flatten(),
            nn.Linear(2048, 1024), nn.BatchNorm1d(1024), nn.ReLU(), nn.Dropout(0.4),
            nn.Linear(1024, 512), nn.BatchNorm1d(512), nn.ReLU(), nn.Dropout(0.3),
            nn.Linear(512, 256)
        )
        # Hypersphere center; must be assigned (e.g. after fitting on real
        # images) before score() is usable.
        self.center = None

    def forward(self, x):
        """Project a batch of images to 256-d SVDD embeddings."""
        return self.proj(self.features(x))

    def score(self, img, dev):
        """Mean squared distance of the embedding(s) from the center."""
        self.eval()
        img = img.to(dev)
        with torch.no_grad():
            if img.dim() == 3:
                img = img.unsqueeze(0)
            diff = self(img) - self.center
            return diff.pow(2).sum(1).mean().item()
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
class Ensemble:
    """9-model ensemble with adaptive threshold.

    Each detector produces a raw anomaly score; scores are z-normalized
    with per-model statistics (``self.norms``) and blended with fixed
    weights into one fused score compared against ``self.thresh``.
    """

    def __init__(self, models_dict):
        self.models = models_dict
        # Fixed blend weight per detector (weights sum to 1.0).
        self.wts = {
            'freq_vae': 0.18,
            'texture_ocsvm': 0.13,
            'color_model': 0.09,
            'edge_flow': 0.13,
            'semantic_svdd': 0.17,
            'stat': 0.09,
            'iforest': 0.09,
            'lof': 0.07,
            'gmm': 0.05
        }
        self.norms = None   # per-model {'mean', 'std'}; set during calibration
        self.thresh = 0.0   # decision threshold on the fused score

    def get_scores(self, img, dev):
        """Collect the raw anomaly score from every detector.

        Torch-based detectors take the device as a second argument; the
        sklearn/statistical ones only take the image.
        """
        torch_based = ('freq_vae', 'edge_flow', 'semantic_svdd')
        raw = {}
        for name in ('freq_vae', 'texture_ocsvm', 'color_model', 'edge_flow',
                     'semantic_svdd', 'stat', 'iforest', 'lof', 'gmm'):
            detector = self.models[name]
            raw[name] = detector.score(img, dev) if name in torch_based else detector.score(img)
        return raw

    def predict(self, img, dev):
        """Return (is_fake, fused_score, raw_scores) for one image."""
        raw = self.get_scores(img, dev)
        fused = 0
        for name, value in raw.items():
            stats = self.norms[name]
            z = (value - stats['mean']) / (stats['std'] + 1e-8)
            fused += self.wts[name] * z
        return fused > self.thresh, fused, raw
|