File size: 2,281 Bytes
239017e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import warnings
warnings.filterwarnings("ignore", category=FutureWarning)

import torch
from transformers import AutoImageProcessor, AutoModelForImageClassification
from PIL import Image
from facenet_pytorch import MTCNN

# Run on the GPU when PyTorch reports CUDA as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
print("Using:", device)

# MTCNN face detector created on the selected device with keep_all=False.
# NOTE(review): nothing in the visible part of this file calls `mtcnn`.
mtcnn = MTCNN(device=device, keep_all=False)

# Hugging Face Hub checkpoint (a Vision Transformer image classifier) used
# for deepfake detection; processor and model come from the same id.
model_id = "prithivMLmods/Deep-Fake-Detector-v2-Model"
print(f"Loading Hugging Face ViT [{model_id}]... (This may take a moment to download weights on first run)")
processor = AutoImageProcessor.from_pretrained(model_id)
model = AutoModelForImageClassification.from_pretrained(model_id)
model = model.to(device)
# Switch the network to evaluation mode before it is used for inference.
model.eval()

def predict(image_path):
    """Classify one image as real or fake and print the verdict.

    The image is opened with PIL, converted to RGB, passed through the
    module-level ``processor`` and ``model``, and the softmax probabilities
    are summed into a "fake" bucket and a "real" bucket according to
    keywords found in the model's ``id2label`` names.

    Args:
        image_path: Location of the image to classify, passed straight to
            ``PIL.Image.open``.

    Returns:
        None. The per-bucket percentages and the final prediction are
        printed to stdout. If no probability mass lands in either bucket
        (for example because no label name contains a known keyword), the
        raw label of the highest-scoring class is printed instead.
    """
    # Image.open keeps the underlying file open; the context manager closes
    # it as soon as convert() has produced an independent in-memory copy.
    with Image.open(image_path) as source_image:
        image = source_image.convert("RGB")

    # The full image is fed to the ViT (global image context), not a
    # cropped face bounding box.
    inputs = processor(images=image, return_tensors="pt").to(device)

    with torch.no_grad():
        logits = model(**inputs).logits
        model_probs = torch.nn.functional.softmax(logits, dim=1)

    labels = model.config.id2label
    # Pull the single batch row to Python floats once instead of calling
    # .item() for every class.
    class_probs = model_probs[0].tolist()

    # Resolve real/fake by label name rather than by fixed index, in case
    # the hub model orders its classes differently.
    fake_prob = 0.0
    real_prob = 0.0
    for idx, label_name in labels.items():
        prob = class_probs[int(idx)]
        label_lower = label_name.lower()
        # 'fake' also matches 'deepfake', so no separate check is needed.
        if 'fake' in label_lower or 'spoof' in label_lower:
            fake_prob += prob
        elif 'real' in label_lower or 'pristine' in label_lower:
            real_prob += prob

    # Fallback to direct argmax if labels are ambiguous
    if fake_prob == 0 and real_prob == 0:
        pred_idx = torch.argmax(model_probs, dim=1).item()
        print(f"\nRaw Hub Label mapped: {labels[pred_idx]}")
        return

    print(f"\nFake: {fake_prob*100:.2f}%")
    print(f"Real: {real_prob*100:.2f}%")

    # Ties are reported as FAKE: REAL requires a strictly larger score.
    if real_prob > fake_prob:
        print("Prediction: REAL ✅")
    else:
        print("Prediction: FAKE ⚠️")

if __name__ == "__main__":
    import os
    import sys

    # Image used when no usable path is supplied on the command line;
    # resolved relative to the current working directory.
    default_image = "test.jpg"
    requested = sys.argv[1] if len(sys.argv) > 1 else None

    if requested is not None and os.path.exists(requested):
        predict(requested)
    else:
        if requested is not None:
            # A mistyped path used to be ignored without any feedback.
            print(f"Image not found: {requested}", file=sys.stderr)
        if os.path.exists(default_image):
            predict(default_image)
        else:
            # Nothing to classify: say so instead of exiting silently.
            print(
                f"No image to classify: pass an image path as the first "
                f"argument or place '{default_image}' in the current directory.",
                file=sys.stderr,
            )