File size: 3,486 Bytes
dbe3f97
b636403
 
dbe3f97
b636403
dbe3f97
b636403
fa3b1cb
b636403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fa3b1cb
b636403
 
 
dbe3f97
b636403
 
 
 
 
 
 
 
 
dbe3f97
 
 
 
 
b636403
dbe3f97
 
b636403
 
dbe3f97
b636403
dbe3f97
 
 
 
 
 
 
b636403
 
 
 
dbe3f97
b636403
dbe3f97
 
b636403
 
 
dbe3f97
b636403
dbe3f97
 
 
 
 
 
 
 
b636403
 
dbe3f97
b636403
 
 
dbe3f97
 
b636403
 
dbe3f97
b636403
 
 
dbe3f97
 
 
 
 
 
b636403
dbe3f97
b636403
 
dbe3f97
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
import cv2
import numpy as np
import gradio as gr

# ------------------------------------------------------------------
# 1. Define the GenConViT Model Architecture (Minimal Version)
# ------------------------------------------------------------------

class GenConViT(nn.Module):
    """Lightweight convolutional stand-in for the GenConViT classifier.

    Three stride-2 3x3 convolutions (3 -> 32 -> 64 -> 128 channels), with a
    ReLU after the first two, followed by global average pooling and a linear
    head that emits ``num_classes`` logits. Adjust to the real architecture
    if the checkpoint was trained on it.

    The layer order is significant: parameter names
    (``feature_extractor.0/2/4.*`` and ``fc.*``) must match the keys in the
    checkpoint passed to ``load_state_dict``.
    """

    def __init__(self, num_classes=2):
        super().__init__()
        layers = []
        # First two stages: conv followed by ReLU.
        for in_ch, out_ch in ((3, 32), (32, 64)):
            layers.extend([nn.Conv2d(in_ch, out_ch, 3, stride=2, padding=1), nn.ReLU()])
        # Last stage has no activation before pooling to a 1x1 spatial map.
        layers.extend([
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            nn.AdaptiveAvgPool2d((1, 1)),
        ])
        self.feature_extractor = nn.Sequential(*layers)
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Map an image batch to raw class logits of shape (batch, num_classes)."""
        pooled = self.feature_extractor(x)  # (batch, 128, 1, 1) after pooling
        return self.fc(torch.flatten(pooled, 1))


# ------------------------------------------------------------------
# 2. Load Model From genconvit_ed_inference.pth
# ------------------------------------------------------------------

model_path = "genconvit_ed_inference.pth"

model = GenConViT(num_classes=2)
# weights_only=True restricts the unpickler to tensors, primitive types and
# plain containers, so a tampered .pth file cannot smuggle in arbitrary
# Python objects. A bare state_dict loads fine under this restriction.
checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
# Some training scripts save {"state_dict": ..., "epoch": ..., ...} instead of
# a bare state_dict; unwrap that layout so both forms load.
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    checkpoint = checkpoint["state_dict"]
model.load_state_dict(checkpoint)
model.eval()  # inference mode for layers whose behavior differs in training

# ------------------------------------------------------------------
# 3. Preprocessing
# ------------------------------------------------------------------

# Preprocessing pipeline applied to each sampled PIL frame:
#   1. resize to a fixed 224x224 (aspect ratio is not preserved),
#   2. convert to a float tensor with values in [0, 1],
#   3. normalize each RGB channel with mean 0.5 / std 0.5, mapping to [-1, 1].
# NOTE(review): presumably matches the normalization used in training — confirm.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

# ------------------------------------------------------------------
# 4. Video Deepfake Detection Function
# ------------------------------------------------------------------

def _fake_probability(pil_image):
    """Return the model's softmax probability for class index 1 ("fake")."""
    batch = transform(pil_image).unsqueeze(0)  # add a leading batch dimension
    with torch.no_grad():
        logits = model(batch)
    return torch.softmax(logits, dim=1)[0, 1].item()


def detect_deepfake(video, *, frame_interval=10):
    """Classify a video as real or deepfake from averaged per-frame scores.

    Frames 0, ``frame_interval``, ``2 * frame_interval``, ... are converted
    from OpenCV's BGR layout to an RGB PIL image and scored by the
    module-level ``model``. The mean fake probability is labelled
    "Deepfake" when it exceeds 0.5.

    Args:
        video: Source accepted by ``cv2.VideoCapture`` (the Gradio ``Video``
            input passes the uploaded file's path), or ``None`` when nothing
            was uploaded.
        frame_interval: Score one frame out of every ``frame_interval``
            frames. Must be >= 1.

    Returns:
        A ``(markdown, image)`` tuple. On success, ``image`` is the last
        scored frame as a PIL image. On failure, it is an error/status
        string and ``None``.

    Raises:
        ValueError: If ``frame_interval`` is less than 1.
    """
    if video is None:
        return "Error: No video provided", None
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")

    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        cap.release()
        return "Error: Cannot open video", None

    scores = []
    sample_frame = None
    try:
        index = 0
        # grab() only advances the stream; retrieve() (which produces the BGR
        # image) is called just for sampled frames, so skipped frames are not
        # converted to images at all.
        while cap.grab():
            if index % frame_interval == 0:
                ok, frame = cap.retrieve()
                if not ok:
                    break
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                # Keep the last processed frame for display.
                sample_frame = Image.fromarray(rgb)
                scores.append(_fake_probability(sample_frame))
            index += 1
    finally:
        # Release the decoder even if preprocessing or inference raises.
        cap.release()

    if not scores:
        return "No frames processed", None

    avg = float(np.mean(scores))
    label = "🔴 Deepfake" if avg > 0.5 else "🟢 Real"

    output = f"""
### **Prediction: {label}**
**Fake confidence: {avg:.4f}**
"""

    return output, sample_frame


# ------------------------------------------------------------------
# 5. Gradio App UI
# ------------------------------------------------------------------

# One video input and two outputs: detect_deepfake returns a
# (markdown, image) tuple that maps onto the Markdown and Image components.
app = gr.Interface(
    fn=detect_deepfake,
    inputs=gr.Video(label="Upload a video"),
    outputs=[
        gr.Markdown(label="Prediction"),
        gr.Image(label="Sample Frame"),
    ],
    title="GenConViT Deepfake Detector (Local .pth Model)",
    description="Upload a video. The system loads genconvit_ed_inference.pth and predicts deepfake probability.",
)

# Start the web server only when this file is run as a script, so importing
# the module (e.g. to reuse `app` or `detect_deepfake`) does not block here.
if __name__ == "__main__":
    app.launch()