vu0018 committed on
Commit
2b2808e
·
verified ·
1 Parent(s): eb30246

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -98
app.py CHANGED
@@ -1,123 +1,54 @@
 
1
  import torch
2
- import torch.nn as nn
3
- from torchvision import transforms
4
- from PIL import Image
5
  import cv2
6
  import numpy as np
7
- import gradio as gr
8
-
9
- # ------------------------------------------------------------------
10
- # 1. Define the GenConViT Model Architecture (Minimal Version)
11
- # ------------------------------------------------------------------
12
-
13
- class GenConViT(nn.Module):
14
- def __init__(self, num_classes=2):
15
- super().__init__()
16
- # Very lightweight demo backbone (adjust to your real architecture)
17
- self.feature_extractor = nn.Sequential(
18
- nn.Conv2d(3, 32, 3, stride=2, padding=1),
19
- nn.ReLU(),
20
- nn.Conv2d(32, 64, 3, stride=2, padding=1),
21
- nn.ReLU(),
22
- nn.Conv2d(64, 128, 3, stride=2, padding=1),
23
- nn.AdaptiveAvgPool2d((1, 1)),
24
- )
25
- self.fc = nn.Linear(128, num_classes)
26
-
27
- def forward(self, x):
28
- x = self.feature_extractor(x)
29
- x = x.flatten(1)
30
- return self.fc(x)
31
 
 
32
 
33
- # ------------------------------------------------------------------
34
- # 2. Load Model From genconvit_ed_inference.pth
35
- # ------------------------------------------------------------------
36
-
37
- model_path = "genconvit_ed_inference.pth"
38
-
39
- model = GenConViT(num_classes=2)
40
- checkpoint = torch.load(model_path, map_location="cpu")
41
- model.load_state_dict(checkpoint)
42
  model.eval()
43
 
44
- # ------------------------------------------------------------------
45
- # 3. Preprocessing
46
- # ------------------------------------------------------------------
47
-
48
- transform = transforms.Compose([
49
- transforms.Resize((224, 224)),
50
- transforms.ToTensor(),
51
- transforms.Normalize([0.5]*3, [0.5]*3)
52
- ])
53
-
54
- # ------------------------------------------------------------------
55
- # 4. Video Deepfake Detection Function
56
- # ------------------------------------------------------------------
57
-
58
- def detect_deepfake(video):
59
 
 
60
  cap = cv2.VideoCapture(video)
61
- if not cap.isOpened():
62
- return "Error: Cannot open video", None
63
-
64
  scores = []
65
- sample_frame = None
66
- frame_interval = 10 # Process every 10th frame
67
- i = 0
68
 
69
  while True:
70
  ret, frame = cap.read()
71
  if not ret:
72
  break
73
 
74
- if i % frame_interval == 0:
75
- rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
76
- img = Image.fromarray(rgb)
77
-
78
- # Save last processed frame for display
79
- sample_frame = img
80
-
81
- inp = transform(img).unsqueeze(0)
82
-
83
- with torch.no_grad():
84
- logits = model(inp)
85
- probs = torch.softmax(logits, dim=1)[0]
86
- fake_prob = probs[1].item()
87
-
88
- scores.append(fake_prob)
89
-
90
- i += 1
91
 
92
  cap.release()
93
 
94
  if len(scores) == 0:
95
- return "No frames processed", None
96
-
97
- avg = float(np.mean(scores))
98
- label = "🔴 Deepfake" if avg > 0.5 else "🟢 Real"
99
-
100
- output = f"""
101
- ### **Prediction: {label}**
102
- **Fake confidence: {avg:.4f}**
103
- """
104
-
105
- return output, sample_frame
106
 
 
 
107
 
108
- # ------------------------------------------------------------------
109
- # 5. Gradio App UI
110
- # ------------------------------------------------------------------
111
 
112
- app = gr.Interface(
113
- fn=detect_deepfake,
114
- inputs=gr.Video(label="Upload a video"),
115
- outputs=[
116
- gr.Markdown(label="Prediction"),
117
- gr.Image(label="Sample Frame")
118
- ],
119
- title="GenConViT Deepfake Detector (Local .pth Model)",
120
- description="Upload a video. The system loads genconvit_ed_inference.pth and predicts deepfake probability."
121
  )
122
 
123
- app.launch()
 
1
+ import gradio as gr
2
  import torch
 
 
 
3
  import cv2
4
  import numpy as np
5
+ from model import GenConViT
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Instantiate the network on the selected device and restore its weights
# from the checkpoint file (relative path, resolved against the working dir).
# NOTE(review): torch.load unpickles the file; if the checkpoint is a plain
# state dict, passing weights_only=True would be safer -- confirm first.
model = GenConViT().to(device)
state = torch.load("genconvit_ed_inference.pth", map_location=device)
model.load_state_dict(state)

# Put the module into evaluation mode before serving predictions.
model.eval()
14
 
15
def preprocess(frame):
    """Turn one decoded video frame into a batched float32 tensor.

    The frame is resized to 224x224, its last (channel) axis is reversed,
    values are divided by 255, and the result is rearranged from
    height/width/channel order to channel/height/width with a leading
    batch axis of size 1.

    Args:
        frame: Image array as returned by ``cv2.VideoCapture.read``
            (presumably uint8 BGR -- confirm against the caller).

    Returns:
        A ``torch.float32`` tensor with a leading batch dimension of 1.
    """
    resized = cv2.resize(frame, (224, 224))
    # Reverse the channel axis (BGR -> RGB for OpenCV frames) and rescale.
    scaled = resized[:, :, ::-1] / 255.0
    chw = torch.tensor(scaled, dtype=torch.float32).permute(2, 0, 1)
    return chw.unsqueeze(0)
 
 
 
 
 
 
 
 
 
 
20
 
21
def predict(video):
    """Classify a video as "Deepfake" or "Real" from its mean per-frame score.

    Every frame that can be read is passed through ``preprocess`` and the
    module-level ``model``; the softmax probability at class index 1 is
    collected per frame, averaged, and compared against 0.5.

    Args:
        video: Path of the video file to analyse (the value handed over by
            the ``gr.Video`` input). ``None`` or an empty path is treated
            like a video from which no frame could be read.

    Returns:
        ``"<label> (score: <mean>)"`` with the mean formatted to four
        decimals, or ``"No frames detected."`` when no frame was read.
    """
    # Submitting the form without a file yields no path; report that the
    # same way as an unreadable video instead of letting OpenCV raise.
    if not video:
        return "No frames detected."

    cap = cv2.VideoCapture(video)
    scores = []
    try:
        # Inference only: one no_grad scope covers the whole frame loop.
        with torch.no_grad():
            while True:
                ret, frame = cap.read()
                if not ret:
                    break

                inp = preprocess(frame).to(device)
                pred = model(inp)
                prob = torch.softmax(pred, dim=1)[0, 1].item()
                scores.append(prob)
    finally:
        # Free the capture handle even if preprocessing or the model raises.
        cap.release()

    if not scores:
        return "No frames detected."

    deepfake_prob = float(np.mean(scores))
    label = "Deepfake" if deepfake_prob > 0.5 else "Real"

    return f"{label} (score: {deepfake_prob:.4f})"
 
 
45
 
46
+ # UI
47
# Single-function Gradio app: one video input, plain-text verdict output.
demo = gr.Interface(
    title="GenConViT Deepfake Detector",
    fn=predict,
    inputs=gr.Video(),
    outputs="text",
)

# Start the web server for the interface.
demo.launch()