|
|
import torch |
|
|
import timm |
|
|
from torchvision import transforms |
|
|
from PIL import Image |
|
|
from safetensors.torch import load_file |
|
|
|
|
|
class DeepfakeDetector(torch.nn.Module):
    """Binary real/fake image classifier: a timm backbone plus an MLP head.

    The backbone is created with ``num_classes=0`` and ``pretrained=False``;
    its output features are mapped by a two-hidden-layer MLP
    (Linear -> BatchNorm1d -> GELU -> Dropout, twice) to a single logit.

    Args:
        backbone_name: timm architecture name passed to ``timm.create_model``.
        dropout: Dropout rate after the first hidden layer; the second hidden
            layer uses half of this rate.
    """

    def __init__(self, backbone_name, dropout=0.3):
        super().__init__()
        self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0)

        feat_dim = self._infer_feat_dim(self.backbone)

        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(feat_dim, 512),
            torch.nn.BatchNorm1d(512),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(512, 128),
            torch.nn.BatchNorm1d(128),
            torch.nn.GELU(),
            torch.nn.Dropout(dropout * 0.5),
            torch.nn.Linear(128, 1),
        )

    @staticmethod
    def _infer_feat_dim(backbone, input_size=(3, 224, 224)):
        """Return the width of the feature vector produced by *backbone*.

        Uses ``backbone.num_features`` when it is present and not None.
        Otherwise, one random probe image of shape ``(1, *input_size)`` is run
        through the backbone, and dimension 1 of the output is returned.

        The probe runs in eval mode because ``torch.no_grad()`` alone does not
        stop BatchNorm layers from updating their running statistics with the
        random input. The backbone's original train/eval mode is restored
        afterwards.
        """
        num_features = getattr(backbone, 'num_features', None)
        if num_features is not None:
            return num_features

        was_training = backbone.training
        backbone.eval()
        try:
            with torch.no_grad():
                return backbone(torch.randn(1, *input_size)).shape[1]
        finally:
            backbone.train(was_training)

    def forward(self, x):
        """Return one logit per input sample.

        The head's final ``Linear(128, 1)`` output has its trailing size-1 axis
        squeezed away. This assumes the backbone returns ``(batch, feat_dim)``
        features, which gives ``(batch,)`` logits.
        """
        features = self.backbone(x)
        return self.classifier(features).squeeze(-1)
|
|
|
|
|
|
|
|
# Use CUDA when PyTorch reports it as available; otherwise fall back to CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# One entry per ensemble member:
#   (timm backbone name, head dropout rate, safetensors weights file,
#    ensemble weight applied to that member's sigmoid output).
# The ensemble weights currently sum to 1.0.
configs = [
    ('convnext_large',                0.3,  'model_1.safetensors', 0.25),
    ('vit_large_patch16_224',         0.35, 'model_2.safetensors', 0.35),
    ('swin_large_patch4_window7_224', 0.3,  'model_3.safetensors', 0.40),
]
|
|
|
|
|
# Build every ensemble member listed in `configs`: instantiate the network,
# load its safetensors checkpoint, move it to `device`, and switch it to eval
# mode (Dropout disabled, BatchNorm uses its stored running statistics).
# `models` collects (model, ensemble_weight) pairs for the inference step.
models = []
for backbone, dropout, filename, weight in configs:
    model = DeepfakeDetector(backbone, dropout)
    # Pass the loaded tensors straight to load_state_dict instead of binding
    # them to a module-level name. Otherwise the CPU copy of the last
    # checkpoint stays referenced for the rest of the script, and each new
    # checkpoint is loaded while the previous one is still alive.
    # load_state_dict is strict by default, so key mismatches raise.
    model.load_state_dict(load_file(filename))
    model = model.to(device)
    model.eval()
    models.append((model, weight))
|
|
|
|
|
|
|
|
# Preprocessing shared by all ensemble members:
#   1. resize to exactly 224x224 (a (h, w) tuple, so aspect ratio is not kept),
#   2. convert the PIL image to a float tensor,
#   3. normalize each channel. These constants match the usual ImageNet
#      mean/std; presumably the models were trained with the same
#      normalization (TODO confirm).
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
|
|
|
|
|
|
|
|
# Load the query image. The context manager closes the underlying file handle;
# convert() returns a new in-memory image, so `image` stays usable afterwards.
with Image.open('test.jpg') as img:
    image = img.convert('RGB')

# Add a batch dimension of 1 and move the tensor to the inference device.
input_tensor = transform(image).unsqueeze(0).to(device)

# Weighted soft voting: each member's sigmoid probability is scaled by its
# ensemble weight.
with torch.no_grad():
    predictions = []
    for model, weight in models:
        logits = model(input_tensor)
        prob = torch.sigmoid(logits).item()
        predictions.append(prob * weight)

total_weight = sum(weight for _, weight in models)
if total_weight <= 0:
    raise ValueError("ensemble weights must sum to a positive value")
# Dividing by the total weight keeps final_prob a probability in [0, 1] even
# if the configured weights do not sum to exactly 1.
final_prob = sum(predictions) / total_weight

prediction = 'FAKE' if final_prob > 0.5 else 'REAL'
# final_prob is P(fake). Report confidence in the class actually predicted;
# printing P(fake) for a REAL verdict would show a misleadingly low number.
confidence = final_prob if prediction == 'FAKE' else 1.0 - final_prob

print(f"Prediction: {prediction}")
print(f"Confidence: {confidence:.2%}")
|
|
|