"""Deepfake detection via a weighted ensemble of three timm backbones.

Loads three binary classifiers (ConvNeXt-L, ViT-L, Swin-L) from safetensors
checkpoints, runs a single image through each, and combines the sigmoid
probabilities with fixed per-model weights.
"""

import torch
import timm
from torchvision import transforms
from PIL import Image
from safetensors.torch import load_file


class DeepfakeDetector(torch.nn.Module):
    """Binary deepfake classifier: a timm backbone plus an MLP head emitting one logit."""

    def __init__(self, backbone_name, dropout=0.3):
        super().__init__()
        # num_classes=0 strips the classification head so the backbone
        # returns pooled feature vectors.
        self.backbone = timm.create_model(
            backbone_name, pretrained=False, num_classes=0
        )
        if hasattr(self.backbone, 'num_features'):
            feat_dim = self.backbone.num_features
        else:
            # Fallback: probe the feature dimension with a dummy forward pass.
            # Assumes 224x224 RGB input — TODO confirm for exotic backbones.
            with torch.no_grad():
                feat_dim = self.backbone(torch.randn(1, 3, 224, 224)).shape[1]
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(feat_dim, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(512, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout * 0.5),  # lighter dropout deeper in the head
            torch.nn.Linear(128, 1),
        )

    def forward(self, x):
        """Return a (batch,) tensor of real-vs-fake logits for image batch x."""
        features = self.backbone(x)
        return self.classifier(features).squeeze(-1)


# (backbone name, head dropout, checkpoint file, ensemble weight).
# Ensemble weights sum to 1.0, so the combined score stays in [0, 1].
ENSEMBLE_CONFIGS = [
    ('convnext_large', 0.3, 'model_1.safetensors', 0.25),
    ('vit_large_patch16_224', 0.35, 'model_2.safetensors', 0.35),
    ('swin_large_patch4_window7_224', 0.3, 'model_3.safetensors', 0.40),
]

# ImageNet mean/std normalization, matching the backbones' pretraining stats.
TRANSFORM = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def load_ensemble(device, configs=ENSEMBLE_CONFIGS):
    """Build each detector, load its checkpoint, and return [(model, weight), ...].

    Models are moved to *device* and switched to eval mode — necessary here
    because BatchNorm1d must use running statistics for batch-size-1 inference.
    """
    models = []
    for backbone, dropout, filename, weight in configs:
        model = DeepfakeDetector(backbone, dropout)
        model.load_state_dict(load_file(filename))
        model = model.to(device)
        model.eval()
        models.append((model, weight))
    return models


def predict(image_path, models, device):
    """Return the weighted-ensemble fake probability for the image at *image_path*."""
    image = Image.open(image_path).convert('RGB')
    input_tensor = TRANSFORM(image).unsqueeze(0).to(device)
    final_prob = 0.0
    with torch.no_grad():
        for model, weight in models:
            logits = model(input_tensor)
            final_prob += torch.sigmoid(logits).item() * weight
    return final_prob


def main():
    """Load the ensemble and classify 'test.jpg', printing the verdict."""
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    models = load_ensemble(device)
    final_prob = predict('test.jpg', models, device)
    # NOTE(review): "Confidence" prints the fake-probability even when the
    # verdict is REAL — consider reporting max(p, 1-p) instead.
    prediction = 'FAKE' if final_prob > 0.5 else 'REAL'
    print(f"Prediction: {prediction}")
    print(f"Confidence: {final_prob:.2%}")


if __name__ == '__main__':
    main()