|
|
""" |
|
|
Ensemble Deepfake Detector |
|
|
Combines Deep SVDD and an autoencoder by averaging their normalized anomaly scores (equal 50/50 weights)
|
|
|
|
|
Usage: |
|
|
from ensemble_model import EnsembleDeepfakeDetector |
|
|
|
|
|
detector = EnsembleDeepfakeDetector.from_pretrained() |
|
|
score, is_fake = detector.predict('image.jpg') |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
import json |
|
|
from pathlib import Path |
|
|
from huggingface_hub import snapshot_download |
|
|
import pickle |
|
|
|
|
|
|
|
|
|
|
|
class SVDDResidualBlock(nn.Module):
    """Two-convolution residual block used by the Deep SVDD encoder.

    Main path: conv3x3 (given stride) -> BN -> ReLU -> conv3x3 -> BN.
    The skip path is a 1x1 strided conv + BN when the block changes stride or
    channel count, and the identity otherwise. The sum goes through a ReLU.

    The attribute names (``conv1``, ``bn1``, ``conv2``, ``bn2``, ``shortcut``)
    define the state-dict keys that pretrained checkpoints are loaded into.
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(in_ch, out_ch, 3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(out_ch, out_ch, 3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(out_ch)

        projection_needed = stride != 1 or in_ch != out_ch
        if projection_needed:
            # Reshape the input so it can be summed with the main path.
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            # An empty Sequential returns its input unchanged.
            self.shortcut = nn.Sequential()

    def forward(self, x):
        hidden = F.relu(self.bn1(self.conv1(x)))
        main_path = self.bn2(self.conv2(hidden))
        return F.relu(main_path + self.shortcut(x))
|
|
|
|
|
|
|
|
class DeepSVDDEncoder(nn.Module):
    """ResNet-style encoder mapping an RGB image to a ``latent_dim`` vector.

    Structure: a 7x7 stride-2 stem (conv + BN + ReLU), then four residual
    stages that each halve the spatial size. A bias-free linear projection
    follows. The projection expects a 512 x 4 x 4 feature map. With five
    stride-2 reductions, that means 128 x 128 inputs (128 / 2**5 == 4).
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        # Stem
        self.conv1 = nn.Conv2d(3, 64, 7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        # Residual stages (names are checkpoint state-dict keys)
        self.layer1 = self._make_layer(64, 128, stride=2)
        self.layer2 = self._make_layer(128, 256, stride=2)
        self.layer3 = self._make_layer(256, 512, stride=2)
        self.layer4 = self._make_layer(512, 512, stride=2)
        # Projection to the embedding space
        self.fc = nn.Linear(512 * 4 * 4, latent_dim, bias=False)

    def _make_layer(self, in_ch, out_ch, stride=1):
        """Build a stage of two residual blocks.

        Only the first block changes stride and channel count.
        """
        entry_block = SVDDResidualBlock(in_ch, out_ch, stride)
        follow_block = SVDDResidualBlock(out_ch, out_ch, 1)
        return nn.Sequential(entry_block, follow_block)

    def forward(self, x):
        features = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            features = stage(features)
        flat = features.view(features.size(0), -1)
        return self.fc(flat)
|
|
|
|
|
|
|
|
|
|
|
class AEResidualBlock(nn.Module):
    """Channel- and resolution-preserving residual block for the autoencoder.

    Main path: conv3x3 -> BN -> ReLU -> Dropout2d -> conv3x3 -> BN.
    The input is then added back and passed through a ReLU. With
    ``dropout <= 0`` the dropout layer is replaced by ``nn.Identity``.
    """

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        hidden = self.dropout(self.relu(self.bn1(self.conv1(x))))
        hidden = self.bn2(self.conv2(hidden))
        # Out-of-place sum: the in-place ReLU then acts on a fresh tensor, never on x.
        return self.relu(hidden + x)
|
|
|
|
|
|
|
|
class ResidualConvAutoencoder(nn.Module):
    """Convolutional autoencoder with residual blocks and a dense bottleneck.

    Encoder: four (stride-2 conv -> BN -> ReLU -> residual block) stages,
    64/128/256/512 channels, then a final stride-2 512->512 conv + BN + ReLU.
    For 128 x 128 inputs the result is a 512 x 4 x 4 map. It is flattened and
    projected to ``latent_dim``.

    Decoder: mirrors the encoder with transposed convolutions, ending in a
    3-channel output squashed by ``tanh`` into [-1, 1].

    The loops below reproduce the exact module order of the ``nn.Sequential``
    containers. The resulting indices (e.g. ``encoder.3``, ``decoder.16``) are
    the state-dict keys of pretrained checkpoints.
    """

    def __init__(self, latent_dim=512, dropout=0.1):
        super().__init__()

        encoder_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256), (256, 512)):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                AEResidualBlock(c_out, dropout),
            ]
        encoder_layers += [
            nn.Conv2d(512, 512, 4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.fc_encoder = nn.Linear(512 * 4 * 4, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, 512 * 4 * 4)

        decoder_layers = []
        for c_in, c_out in ((512, 512), (512, 256), (256, 128), (128, 64)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, 4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                AEResidualBlock(c_out, dropout),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(64, 3, 4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(reconstruction, latent)`` for a batch ``x``."""
        feature_map = self.encoder(x)
        latent = self.fc_encoder(feature_map.view(feature_map.size(0), -1))
        seed = self.fc_decoder(latent).view(latent.size(0), 512, 4, 4)
        return self.decoder(seed), latent

    def reconstruction_error(self, x):
        """Per-sample mean squared reconstruction error, shape ``(batch,)``."""
        reconstructed, _latent = self.forward(x)
        squared_diff = (reconstructed - x) ** 2
        return squared_diff.view(x.size(0), -1).mean(dim=1)
|
|
|
|
|
|
|
|
|
|
|
class EnsembleDeepfakeDetector:
    """
    Ensemble Deepfake Detector combining Deep SVDD + Autoencoder

    Each model yields a non-negative anomaly measure:
      * Deep SVDD: squared Euclidean distance of the image embedding to the
        stored center.
      * Autoencoder: per-image mean squared reconstruction error.
    Each measure is divided by 3x that model's calibrated threshold and clipped
    at 1.0. The two scores are averaged with equal weights, and the average is
    compared against ``ensemble_threshold``.

    Usage:
        detector = EnsembleDeepfakeDetector.from_pretrained()
        score, is_fake = detector.predict('image.jpg')
    """

    def __init__(self, device='cuda'):
        """Select the compute device; models are attached by the load_* methods.

        Args:
            device: Requested torch device. Whenever CUDA is unavailable, CPU
                is used instead, whatever was requested.
        """
        self.device = torch.device(device if torch.cuda.is_available() else 'cpu')
        # NOTE(review): presumably calibrated for the checkpoints loaded by
        # from_pretrained(); its provenance is not visible here. Adjust via
        # set_threshold().
        self.ensemble_threshold = 0.1163

    @classmethod
    def from_pretrained(cls, device='cuda'):
        """Load ensemble from HuggingFace"""
        detector = cls(device=device)

        print("Loading ensemble models...")

        print("  [1/2] Loading Deep SVDD...")
        svdd_path = snapshot_download(repo_id="ash12321/deep-svdd-anomaly-detection")
        detector.load_svdd(svdd_path)

        print("  [2/2] Loading Autoencoder...")
        ae_path = snapshot_download(repo_id="ash12321/deepfake-autoencoder-cifar10-v2")
        detector.load_autoencoder(ae_path)

        print("✓ Ensemble loaded successfully!")
        return detector

    def load_svdd(self, model_dir):
        """Load the Deep SVDD encoder, center, threshold and preprocessing.

        Args:
            model_dir: Directory containing ``deepsvdd_model.pth`` and
                ``thresholds.pkl``.
        """
        model_dir = Path(model_dir)
        checkpoint = torch.load(model_dir / 'deepsvdd_model.pth', map_location=self.device)
        self.svdd_encoder = DeepSVDDEncoder(checkpoint['latent_dim']).to(self.device)
        self.svdd_encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.svdd_encoder.eval()
        self.svdd_center = checkpoint['center'].to(self.device)

        # SECURITY(review): pickle.load executes arbitrary code embedded in the
        # file. This file comes from a downloaded model repo; only load from
        # trusted sources.
        with open(model_dir / 'thresholds.pkl', 'rb') as f:
            thresholds = pickle.load(f)
        self.svdd_threshold = thresholds['optimal_f1']

        # 128x128 is required by the encoder's 512*4*4 linear head;
        # ImageNet mean/std normalization.
        self.svdd_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    def load_autoencoder(self, model_dir):
        """Load the residual autoencoder, its threshold and preprocessing.

        Args:
            model_dir: Directory containing ``model_universal_best.ckpt`` and
                ``thresholds_calibrated.json``.
        """
        model_dir = Path(model_dir)
        # SECURITY(review): weights_only=False unpickles arbitrary objects
        # (arbitrary code execution). Only load checkpoints from trusted
        # sources.
        checkpoint = torch.load(model_dir / 'model_universal_best.ckpt',
                                map_location=self.device, weights_only=False)

        config = checkpoint.get('config', {})
        self.ae_model = ResidualConvAutoencoder(
            latent_dim=config.get('latent_dim', 512),
            dropout=config.get('dropout', 0.1)
        ).to(self.device)
        self.ae_model.load_state_dict(checkpoint['model_state_dict'])
        self.ae_model.eval()

        with open(model_dir / 'thresholds_calibrated.json', 'r', encoding='utf-8') as f:
            thresholds = json.load(f)
        self.ae_threshold = thresholds['reconstruction_thresholds']['thresholds']['balanced']['value']

        # Map [0, 1] pixels to [-1, 1] to match the decoder's tanh output range.
        self.ae_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Lambda(lambda x: x * 2 - 1)
        ])

    @torch.no_grad()
    def predict(self, image):
        """
        Predict if image is deepfake

        Args:
            image: PIL Image or path (str / pathlib.Path) to an image file.
                Images in a non-RGB mode (grayscale, RGBA, palette, ...) are
                converted to RGB, because both models take 3-channel input.

        Returns:
            score: Ensemble score (0-1 for positive thresholds, higher = more
                likely fake)
            is_fake: True when ``score`` is strictly greater than
                ``self.ensemble_threshold``
        """
        image = self._prepare_image(image)

        # Deep SVDD: squared distance of the embedding to the center.
        svdd_img = self.svdd_transform(image).unsqueeze(0).to(self.device)
        svdd_embedding = self.svdd_encoder(svdd_img)
        svdd_distance = torch.sum((svdd_embedding - self.svdd_center) ** 2, dim=1).item()
        # A distance of 3x the calibrated threshold saturates the score at 1.0.
        svdd_score = min(svdd_distance / (self.svdd_threshold * 3), 1.0)

        # Autoencoder: per-image mean squared reconstruction error.
        ae_img = self.ae_transform(image).unsqueeze(0).to(self.device)
        ae_error = self.ae_model.reconstruction_error(ae_img).item()
        ae_score = min(ae_error / (self.ae_threshold * 3), 1.0)

        # Equal-weight average of the two normalized scores.
        ensemble_score = (svdd_score + ae_score) / 2.0
        is_fake = ensemble_score > self.ensemble_threshold

        return ensemble_score, is_fake

    @staticmethod
    def _prepare_image(image):
        """Return an RGB image from a file path or an image object.

        Paths are opened in a ``with`` block so the file handle is closed.
        ``convert`` returns an in-memory copy, so the result stays valid.
        Objects without a ``mode`` attribute are passed through unchanged.
        """
        if isinstance(image, (str, Path)):
            with Image.open(image) as opened:
                return opened.convert('RGB')
        if getattr(image, 'mode', 'RGB') != 'RGB':
            return image.convert('RGB')
        return image

    def set_threshold(self, threshold):
        """Set ensemble threshold (0-1)"""
        self.ensemble_threshold = threshold
|
|
|
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    import sys

    # Optional first CLI argument selects the image; defaults to 'test.jpg'.
    image_path = sys.argv[1] if len(sys.argv) > 1 else 'test.jpg'
    detector = EnsembleDeepfakeDetector.from_pretrained()
    score, is_fake = detector.predict(image_path)
    print(f"Score: {score:.4f}, Fake: {is_fake}")
|
|
|