# deepfake-detector-v14 / inference_example.py
# ash12321's picture
# Upload folder using huggingface_hub
# fcd620e verified
import torch
import timm
from torchvision import transforms
from PIL import Image
from safetensors.torch import load_file
class DeepfakeDetector(torch.nn.Module):
    """Binary classifier: a timm backbone followed by a small MLP head.

    The backbone is created with ``pretrained=False`` and ``num_classes=0``,
    so its weights are expected to come from a checkpoint loaded afterwards
    and its output is fed straight into the head defined here.

    Args:
        backbone_name: Model name passed to ``timm.create_model``.
        dropout: Dropout probability after the first hidden layer; the
            second hidden layer uses half of this value.
    """

    def __init__(self, backbone_name, dropout=0.3):
        super().__init__()
        nn = torch.nn

        self.backbone = timm.create_model(
            backbone_name,
            pretrained=False,
            num_classes=0,
        )
        width = self._feature_width()

        # Layer order (and therefore the Sequential indices) must stay as-is:
        # checkpoint keys such as ``classifier.0.weight`` depend on it.
        head_layers = [
            nn.Linear(width, 512),
            nn.BatchNorm1d(512),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(512, 128),
            nn.BatchNorm1d(128),
            nn.GELU(),
            nn.Dropout(dropout * 0.5),
            nn.Linear(128, 1),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def _feature_width(self):
        """Return the size of the backbone's output feature vector."""
        if not hasattr(self.backbone, 'num_features'):
            # No advertised width: push one random 224x224 RGB sample through
            # the backbone and read the width off the result.
            with torch.no_grad():
                probe = torch.randn(1, 3, 224, 224)
                return self.backbone(probe).shape[1]
        return self.backbone.num_features

    def forward(self, x):
        """Return one raw logit per sample (sigmoid is applied by the caller).

        The head ends in ``Linear(128, 1)``; the trailing size-1 dimension is
        squeezed away, so a batch of N inputs yields a tensor of shape (N,).
        """
        logits = self.classifier(self.backbone(x))
        return logits.squeeze(-1)
# Load models
# Run on the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# One entry per ensemble member:
# (timm backbone name, head dropout, safetensors checkpoint file, ensemble weight)
configs = [
    ('convnext_large', 0.3, 'model_1.safetensors', 0.25),
    ('vit_large_patch16_224', 0.35, 'model_2.safetensors', 0.35),
    ('swin_large_patch4_window7_224', 0.3, 'model_3.safetensors', 0.40)
]

# List of (model, ensemble weight) pairs, each model in eval mode on `device`.
models = []
for backbone, dropout, filename, weight in configs:
    model = DeepfakeDetector(backbone, dropout)
    model.load_state_dict(load_file(filename))
    # eval() switches BatchNorm/Dropout to inference behaviour.
    model = model.to(device).eval()
    models.append((model, weight))
# Preprocess image
# Resize to the 224x224 input size, convert to a float tensor, then normalize
# each channel with the given mean/std (the usual ImageNet statistics).
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Predict
# Open inside a context manager so the file handle is released right away;
# convert() returns an independent RGB copy that outlives the `with` block.
with Image.open('test.jpg') as source_image:
    image = source_image.convert('RGB')

# Add a leading batch dimension of 1 and move the tensor to the models' device.
input_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    predictions = []
    for model, weight in models:
        logits = model(input_tensor)
        # Sigmoid maps the single logit to this model's score in [0, 1].
        prob = torch.sigmoid(logits).item()
        predictions.append(prob * weight)

# Weighted sum of the per-model scores; scores above 0.5 are labelled FAKE.
# NOTE(review): this stays in [0, 1] only if the ensemble weights sum to 1.
final_prob = sum(predictions)
is_fake = final_prob > 0.5
prediction = 'FAKE' if is_fake else 'REAL'

# Report confidence in the *predicted* label. Printing final_prob directly
# would show e.g. "REAL / 3.00%" for an image the ensemble is 97% sure is real.
confidence = final_prob if is_fake else 1.0 - final_prob

print(f"Prediction: {prediction}")
print(f"Confidence: {confidence:.2%}")