deepfake-ensemble-detector / ensemble_model.py
ash12321's picture
Upload ensemble deepfake detector (Deep SVDD + Autoencoder)
3539678 verified
"""
Ensemble Deepfake Detector
Combines Deep SVDD + Autoencoder with 50/50 voting
Usage:
from ensemble_model import EnsembleDeepfakeDetector
detector = EnsembleDeepfakeDetector.from_pretrained()
score, is_fake = detector.predict('image.jpg')
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
import json
from pathlib import Path
from huggingface_hub import snapshot_download
import pickle
# Deep SVDD Components
class SVDDResidualBlock(nn.Module):
    """Two-convolution residual block used by the Deep SVDD encoder.

    The main path is conv3x3 -> BN -> ReLU -> conv3x3 -> BN. Its output is
    summed with the shortcut path and passed through a final ReLU. All
    convolutions are bias-free.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        stride: Stride of the first convolution (and of the projection
            shortcut, when one is needed).
    """

    def __init__(self, in_ch, out_ch, stride=1):
        super().__init__()
        self.conv1 = nn.Conv2d(
            in_ch, out_ch, kernel_size=3, stride=stride, padding=1, bias=False
        )
        self.bn1 = nn.BatchNorm2d(out_ch)
        self.conv2 = nn.Conv2d(
            out_ch, out_ch, kernel_size=3, stride=1, padding=1, bias=False
        )
        self.bn2 = nn.BatchNorm2d(out_ch)

        # A 1x1 projection is only required when the main path changes the
        # tensor shape; otherwise an empty Sequential acts as the identity.
        shape_changes = stride != 1 or in_ch != out_ch
        if shape_changes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_ch),
            )
        else:
            self.shortcut = nn.Sequential()

    def forward(self, x):
        """Return ReLU(main_path(x) + shortcut(x))."""
        main = self.conv1(x)
        main = F.relu(self.bn1(main))
        main = self.bn2(self.conv2(main))
        skip = self.shortcut(x)
        return F.relu(main + skip)
class DeepSVDDEncoder(nn.Module):
    """Residual CNN that maps an RGB image to a Deep SVDD embedding.

    A stride-2 7x7 stem is followed by four stride-2 residual stages, so the
    spatial size shrinks by a factor of 32 overall. The final bias-free
    linear layer expects a flattened 512 x 4 x 4 feature map, which
    corresponds to 128 x 128 inputs.

    Args:
        latent_dim: Size of the output embedding.
    """

    def __init__(self, latent_dim=512):
        super().__init__()
        # Stem: 3 -> 64 channels, halves the resolution.
        self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
        self.bn1 = nn.BatchNorm2d(64)

        # Residual stages layer1..layer4; each one halves the resolution.
        stage_channels = ((64, 128), (128, 256), (256, 512), (512, 512))
        for index, (c_in, c_out) in enumerate(stage_channels, start=1):
            setattr(self, f"layer{index}", self._make_layer(c_in, c_out, stride=2))

        self.fc = nn.Linear(512 * 4 * 4, latent_dim, bias=False)

    def _make_layer(self, in_ch, out_ch, stride=1):
        """Build one stage: a (possibly strided) block followed by a stride-1 block."""
        first = SVDDResidualBlock(in_ch, out_ch, stride)
        second = SVDDResidualBlock(out_ch, out_ch, 1)
        return nn.Sequential(first, second)

    def forward(self, x):
        """Encode a batch of images into embeddings of size ``latent_dim``."""
        feats = F.relu(self.bn1(self.conv1(x)))
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            feats = stage(feats)
        flat = feats.view(feats.size(0), -1)
        return self.fc(flat)
# Autoencoder Components
class AEResidualBlock(nn.Module):
    """Shape-preserving residual block used inside the autoencoder.

    Computes ReLU(x + BN(conv(dropout(ReLU(BN(conv(x))))))) with 3x3
    convolutions that keep the channel count and spatial size unchanged.

    Args:
        channels: Number of input and output channels.
        dropout: Channel-wise dropout probability applied after the first
            activation; values <= 0 disable dropout.
    """

    def __init__(self, channels, dropout=0.1):
        super().__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
        self.relu = nn.ReLU(inplace=True)
        if dropout > 0:
            self.dropout = nn.Dropout2d(dropout)
        else:
            self.dropout = nn.Identity()

    def forward(self, x):
        """Apply the residual transformation; the output has the shape of ``x``."""
        hidden = self.conv1(x)
        hidden = self.bn1(hidden)
        hidden = self.relu(hidden)
        hidden = self.dropout(hidden)
        hidden = self.conv2(hidden)
        hidden = self.bn2(hidden)
        # The sum is a fresh tensor, so the in-place ReLU never touches ``x``.
        return self.relu(hidden + x)
class ResidualConvAutoencoder(nn.Module):
    """Convolutional autoencoder with residual blocks and a linear bottleneck.

    The encoder applies five stride-2 4x4 convolutions (each halving the
    spatial size), with a residual block after each of the first four. The
    resulting feature map is flattened into 512 * 4 * 4 values, which
    corresponds to 128 x 128 inputs, and projected to ``latent_dim``. The
    decoder mirrors this with transposed convolutions and ends in ``Tanh``,
    so reconstructions lie in [-1, 1].

    Args:
        latent_dim: Size of the bottleneck vector.
        dropout: Dropout probability forwarded to every residual block.
    """

    def __init__(self, latent_dim=512, dropout=0.1):
        super().__init__()
        bottleneck_features = 512 * 4 * 4

        # Encoder: four (conv, BN, ReLU, residual) stages plus a final
        # (conv, BN, ReLU) stage. Layer order fixes the Sequential indices
        # that checkpoints refer to.
        encoder_layers = []
        for c_in, c_out in ((3, 64), (64, 128), (128, 256), (256, 512)):
            encoder_layers += [
                nn.Conv2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                AEResidualBlock(c_out, dropout),
            ]
        encoder_layers += [
            nn.Conv2d(512, 512, kernel_size=4, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.ReLU(inplace=True),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.fc_encoder = nn.Linear(bottleneck_features, latent_dim)
        self.fc_decoder = nn.Linear(latent_dim, bottleneck_features)

        # Decoder: four (deconv, BN, ReLU, residual) stages plus the output
        # (deconv, Tanh) stage.
        decoder_layers = []
        for c_in, c_out in ((512, 512), (512, 256), (256, 128), (128, 64)):
            decoder_layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                AEResidualBlock(c_out, dropout),
            ]
        decoder_layers += [
            nn.ConvTranspose2d(64, 3, kernel_size=4, stride=2, padding=1),
            nn.Tanh(),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x):
        """Return ``(reconstructed, latent)`` for a batch of images."""
        feats = self.encoder(x)
        latent = self.fc_encoder(feats.view(feats.size(0), -1))
        seed = self.fc_decoder(latent)
        seed = seed.view(seed.size(0), 512, 4, 4)
        return self.decoder(seed), latent

    def reconstruction_error(self, x):
        """Return the per-sample mean squared reconstruction error (1-D tensor)."""
        reconstructed = self.forward(x)[0]
        squared_diff = (reconstructed - x).pow(2)
        return squared_diff.view(x.size(0), -1).mean(dim=1)
# Ensemble Detector
class EnsembleDeepfakeDetector:
    """
    Ensemble deepfake detector combining Deep SVDD + Autoencoder.

    Each member produces a raw anomaly value (squared distance to the SVDD
    center, mean squared reconstruction error). Each value is divided by
    ``SCORE_SCALE`` times the member's own threshold, clipped to 1.0, and
    the two normalized scores are averaged with equal weight.

    Usage:
        detector = EnsembleDeepfakeDetector.from_pretrained()
        score, is_fake = detector.predict('image.jpg')
    """

    # Default HuggingFace repositories for the two pretrained members.
    SVDD_REPO_ID = "ash12321/deep-svdd-anomaly-detection"
    AE_REPO_ID = "ash12321/deepfake-autoencoder-cifar10-v2"
    # A raw value of SCORE_SCALE * threshold (or more) maps to a score of 1.0.
    SCORE_SCALE = 3

    def __init__(self, device='cuda'):
        """
        Create an empty detector; call ``load_svdd`` / ``load_autoencoder``
        (or use ``from_pretrained``) before ``predict``.

        Args:
            device: Device string or ``torch.device``. A CUDA request falls
                back to CPU when CUDA is unavailable; any other device is
                honored as given.
        """
        requested = torch.device(device)
        # Only a CUDA request needs a fallback; previously every non-CUDA
        # request (e.g. 'mps') was silently replaced by CPU on CUDA-less hosts.
        if requested.type == 'cuda' and not torch.cuda.is_available():
            requested = torch.device('cpu')
        self.device = requested
        self.ensemble_threshold = 0.1163  # Optimized threshold

        # Populated by load_svdd().
        self.svdd_encoder = None
        self.svdd_center = None
        self.svdd_threshold = None
        self.svdd_transform = None
        # Populated by load_autoencoder().
        self.ae_model = None
        self.ae_threshold = None
        self.ae_transform = None

    @classmethod
    def from_pretrained(cls, device='cuda', svdd_repo=None, ae_repo=None):
        """
        Load ensemble from HuggingFace.

        Args:
            device: Forwarded to the constructor.
            svdd_repo: Repo id of the Deep SVDD model; defaults to
                ``SVDD_REPO_ID``.
            ae_repo: Repo id of the autoencoder model; defaults to
                ``AE_REPO_ID``.

        Returns:
            A detector with both members loaded.
        """
        detector = cls(device=device)
        print("Loading ensemble models...")
        # Load Deep SVDD
        print(" [1/2] Loading Deep SVDD...")
        svdd_path = snapshot_download(repo_id=svdd_repo or cls.SVDD_REPO_ID)
        detector.load_svdd(svdd_path)
        # Load Autoencoder
        print(" [2/2] Loading Autoencoder...")
        ae_path = snapshot_download(repo_id=ae_repo or cls.AE_REPO_ID)
        detector.load_autoencoder(ae_path)
        print("✓ Ensemble loaded successfully!")
        return detector

    def load_svdd(self, model_dir):
        """
        Load the Deep SVDD encoder, center, threshold and preprocessing.

        Args:
            model_dir: Directory containing ``deepsvdd_model.pth`` and
                ``thresholds.pkl``.
        """
        model_dir = Path(model_dir)
        checkpoint = torch.load(model_dir / 'deepsvdd_model.pth', map_location=self.device)
        self.svdd_encoder = DeepSVDDEncoder(checkpoint['latent_dim']).to(self.device)
        self.svdd_encoder.load_state_dict(checkpoint['encoder_state_dict'])
        self.svdd_encoder.eval()
        self.svdd_center = checkpoint['center'].to(self.device)
        # SECURITY NOTE: pickle.load executes arbitrary code from the file.
        # Only point this at model directories from a trusted source.
        with open(model_dir / 'thresholds.pkl', 'rb') as f:
            thresholds = pickle.load(f)
        self.svdd_threshold = thresholds['optimal_f1']
        self.svdd_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def load_autoencoder(self, model_dir):
        """
        Load the autoencoder, its threshold and preprocessing.

        Args:
            model_dir: Directory containing ``model_universal_best.ckpt``
                and ``thresholds_calibrated.json``.
        """
        model_dir = Path(model_dir)
        # SECURITY NOTE: weights_only=False allows the checkpoint to unpickle
        # arbitrary objects. Only load checkpoints from a trusted source.
        checkpoint = torch.load(
            model_dir / 'model_universal_best.ckpt',
            map_location=self.device,
            weights_only=False,
        )
        config = checkpoint.get('config', {})
        self.ae_model = ResidualConvAutoencoder(
            latent_dim=config.get('latent_dim', 512),
            dropout=config.get('dropout', 0.1),
        ).to(self.device)
        self.ae_model.load_state_dict(checkpoint['model_state_dict'])
        self.ae_model.eval()
        with open(model_dir / 'thresholds_calibrated.json', 'r') as f:
            thresholds = json.load(f)
        self.ae_threshold = (
            thresholds['reconstruction_thresholds']['thresholds']['balanced']['value']
        )
        self.ae_transform = transforms.Compose([
            transforms.Resize((128, 128)),
            transforms.ToTensor(),
            # Rescale [0, 1] pixels to [-1, 1], the range of the decoder's Tanh.
            transforms.Lambda(lambda x: x * 2 - 1),
        ])

    def _ensure_loaded(self):
        """Raise RuntimeError naming any ensemble member that is not loaded."""
        members = (("Deep SVDD", "svdd_encoder"), ("Autoencoder", "ae_model"))
        missing = [
            label for label, attr in members if getattr(self, attr, None) is None
        ]
        if missing:
            raise RuntimeError(
                "Model(s) not loaded: " + ", ".join(missing)
                + ". Call from_pretrained() or load_svdd()/load_autoencoder() first."
            )

    @classmethod
    def _scaled_score(cls, raw, threshold):
        """Normalize a raw anomaly value by SCORE_SCALE * threshold, capped at 1.0."""
        return min(raw / (threshold * cls.SCORE_SCALE), 1.0)

    @torch.no_grad()
    def predict(self, image):
        """
        Predict if image is deepfake.

        Args:
            image: PIL Image or path to image. Images that are not already
                in RGB mode are converted to RGB.

        Returns:
            score: Ensemble score as a float (0-1, higher = more likely fake)
            is_fake: Boolean prediction (score above the ensemble threshold)

        Raises:
            RuntimeError: If the models have not been loaded yet.
        """
        self._ensure_loaded()

        if isinstance(image, (str, Path)):
            # The context manager closes the file handle once the pixel data
            # has been copied by convert().
            with Image.open(image) as opened:
                image = opened.convert('RGB')
        elif getattr(image, 'mode', 'RGB') != 'RGB':
            # Both models expect exactly three channels; RGBA / grayscale
            # inputs would otherwise fail inside the transforms or convs.
            image = image.convert('RGB')

        # Deep SVDD prediction: squared distance of the embedding to the center.
        svdd_img = self.svdd_transform(image).unsqueeze(0).to(self.device)
        svdd_embedding = self.svdd_encoder(svdd_img)
        svdd_distance = torch.sum((svdd_embedding - self.svdd_center) ** 2, dim=1).item()
        svdd_score = self._scaled_score(svdd_distance, self.svdd_threshold)

        # Autoencoder prediction: mean squared reconstruction error.
        ae_img = self.ae_transform(image).unsqueeze(0).to(self.device)
        ae_error = self.ae_model.reconstruction_error(ae_img).item()
        ae_score = self._scaled_score(ae_error, self.ae_threshold)

        # Ensemble (50/50 average). Cast so callers always get plain Python
        # types even when the stored thresholds are numpy scalars.
        ensemble_score = float((svdd_score + ae_score) / 2.0)
        is_fake = bool(ensemble_score > self.ensemble_threshold)
        return ensemble_score, is_fake

    def set_threshold(self, threshold):
        """
        Set ensemble threshold (0-1).

        Raises:
            ValueError: If ``threshold`` is outside [0, 1] (including NaN).
        """
        if not 0.0 <= threshold <= 1.0:
            raise ValueError(f"threshold must be within [0, 1], got {threshold!r}")
        self.ensemble_threshold = threshold
# Example usage
if __name__ == '__main__':
    import sys

    # Optional first command-line argument selects the image to score;
    # without it the script falls back to 'test.jpg' as before.
    image_path = sys.argv[1] if len(sys.argv) > 1 else 'test.jpg'
    detector = EnsembleDeepfakeDetector.from_pretrained()
    score, is_fake = detector.predict(image_path)
    print(f"Score: {score:.4f}, Fake: {is_fake}")