import torch
import timm
from torchvision import transforms
from PIL import Image
from safetensors.torch import load_file

class DeepfakeDetector(torch.nn.Module):
    """Binary deepfake classifier: a timm backbone followed by an MLP head."""

    def __init__(self, backbone_name, dropout=0.3):
        super().__init__()
        # num_classes=0 strips the backbone's classification head so it
        # returns pooled features instead of class logits.
        self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0)

        # Infer the feature dimension, falling back to a dummy forward pass
        # for backbones that do not expose num_features.
        if hasattr(self.backbone, 'num_features'):
            feat_dim = self.backbone.num_features
        else:
            with torch.no_grad():
                feat_dim = self.backbone(torch.randn(1, 3, 224, 224)).shape[1]
        
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(feat_dim, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(512, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout * 0.5),
            torch.nn.Linear(128, 1)
        )
    
    def forward(self, x):
        features = self.backbone(x)
        # Squeeze the trailing dim so the output is one logit per image
        return self.classifier(features).squeeze(-1)

# Ensemble configuration: (backbone name, dropout, checkpoint file, ensemble weight).
# The weights sum to 1.0, so the combined score stays in [0, 1].
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

configs = [
    ('convnext_large', 0.3, 'model_1.safetensors', 0.25),
    ('vit_large_patch16_224', 0.35, 'model_2.safetensors', 0.35),
    ('swin_large_patch4_window7_224', 0.3, 'model_3.safetensors', 0.40)
]

models = []
for backbone, dropout, filename, weight in configs:
    model = DeepfakeDetector(backbone, dropout)
    state_dict = load_file(filename)
    model.load_state_dict(state_dict)
    model = model.to(device)
    model.eval()  # disable dropout; BatchNorm uses running statistics
    models.append((model, weight))

# Preprocess image (resize to the backbones' 224x224 input, then apply
# standard ImageNet normalization statistics)
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

# Predict: the ensemble score is a weighted average of each model's
# sigmoid probability of the image being fake.
image = Image.open('test.jpg').convert('RGB')
input_tensor = transform(image).unsqueeze(0).to(device)

with torch.no_grad():
    predictions = []
    for model, weight in models:
        logits = model(input_tensor)
        prob = torch.sigmoid(logits).item()
        predictions.append(prob * weight)

    final_prob = sum(predictions)
    prediction = 'FAKE' if final_prob > 0.5 else 'REAL'
    # Confidence in the predicted label (not just the raw FAKE probability,
    # which would understate confident REAL predictions)
    confidence = final_prob if final_prob > 0.5 else 1 - final_prob

print(f"Prediction: {prediction}")
print(f"Confidence: {confidence:.2%}")