vu0018 committed on
Commit
b636403
·
verified ·
1 Parent(s): 7985697

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -46
app.py CHANGED
@@ -1,61 +1,91 @@
1
- import gradio as gr
2
  import torch
3
- import cv2
4
- from transformers import AutoModelForImageClassification, AutoImageProcessor
5
  from PIL import Image
 
6
  import numpy as np
 
7
 
8
- # ----------------------------------------------------------
9
- # Load Hugging Face GenConViT Model
10
- # ----------------------------------------------------------
11
- model = AutoModelForImageClassification.from_pretrained(
12
- "Thanuja2109/GenConViT"
13
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
 
15
- processor = AutoImageProcessor.from_pretrained(
16
- "Thanuja2109/GenConViT"
17
- )
18
 
19
- model.eval()
 
 
 
 
 
 
 
 
20
 
21
- # ----------------------------------------------------------
22
- # Deepfake detection function
23
- # ----------------------------------------------------------
24
  def detect_deepfake(video):
25
 
26
- # Load video
27
  cap = cv2.VideoCapture(video)
28
-
29
  if not cap.isOpened():
30
- return "Error: cannot open video", None
31
 
32
  scores = []
33
- frames_collected = 0
34
-
35
- # Sample 1 frame every 10
36
- frame_interval = 10
37
-
38
- frame_img = None
39
-
40
  i = 0
 
41
  while True:
42
  ret, frame = cap.read()
43
  if not ret:
44
  break
45
 
46
  if i % frame_interval == 0:
47
- # Convert to RGB
48
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
49
- pil_img = Image.fromarray(rgb)
 
 
 
50
 
51
- inputs = processor(images=pil_img, return_tensors="pt")
52
 
53
  with torch.no_grad():
54
- logits = model(**inputs).logits
55
- prob_fake = torch.softmax(logits, dim=1)[0][1].item()
 
56
 
57
- scores.append(prob_fake)
58
- frame_img = pil_img # save last sampled frame
59
 
60
  i += 1
61
 
@@ -64,30 +94,30 @@ def detect_deepfake(video):
64
  if len(scores) == 0:
65
  return "No frames processed", None
66
 
67
- avg_score = np.mean(scores)
 
68
 
69
- label = "🔴 Deepfake" if avg_score > 0.5 else "🟢 Real"
70
-
71
- result_text = f"""
72
- ### Prediction: **{label}**
73
- **Confidence (fake probability): {avg_score:.4f}**
74
  """
75
 
76
- return result_text, frame_img
 
77
 
 
 
 
78
 
79
- # ----------------------------------------------------------
80
- # Gradio Interface
81
- # ----------------------------------------------------------
82
  app = gr.Interface(
83
  fn=detect_deepfake,
84
  inputs=gr.Video(label="Upload a video"),
85
  outputs=[
86
  gr.Markdown(label="Prediction"),
87
- gr.Image(label="Analyzed Frame")
88
  ],
89
- title="GenConViT Deepfake Video Detector",
90
- description="Upload a video. The app samples frames and uses GenConViT to detect deepfakes."
91
  )
92
 
93
  app.launch()
 
 
1
  import torch
2
+ import torch.nn as nn
3
+ from torchvision import transforms
4
  from PIL import Image
5
+ import cv2
6
  import numpy as np
7
+ import gradio as gr
8
 
9
+ # ------------------------------------------------------------------
10
+ # 1. Define the GenConViT Model Architecture (Minimal Version)
11
+ # ------------------------------------------------------------------
12
+
13
class GenConViT(nn.Module):
    """Lightweight demo convolutional classifier standing in for GenConViT.

    Three stride-2 convolutions followed by global average pooling and a
    single linear head. This is a minimal backbone, not the full GenConViT
    architecture; adjust it to the real architecture if the checkpoint you
    load was trained with a different one.

    Args:
        num_classes: Number of output logits produced by the linear head.
        in_channels: Number of channels in the input images (3 for RGB).
    """

    def __init__(self, num_classes=2, in_channels=3):
        super().__init__()
        # The attribute names and the position of each layer inside the
        # Sequential determine the state_dict keys
        # ("feature_extractor.0.weight", ..., "fc.bias"), so keep both stable.
        self.feature_extractor = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(32, 64, 3, stride=2, padding=1),
            nn.ReLU(),
            nn.Conv2d(64, 128, 3, stride=2, padding=1),
            # Collapse the spatial dimensions so any input resolution works.
            nn.AdaptiveAvgPool2d((1, 1)),
        )
        self.fc = nn.Linear(128, num_classes)

    def forward(self, x):
        """Return class logits of shape (batch, num_classes).

        Args:
            x: Batched image tensor of shape (batch, in_channels, H, W).

        Raises:
            ValueError: If ``x`` is not a 4-D batched tensor.
        """
        # Conv2d also accepts unbatched 3-D input, which would only fail later
        # with an opaque matmul shape error after flatten(1); fail early.
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D (batch, channels, H, W) tensor, got {x.dim()}-D"
            )
        x = self.feature_extractor(x)
        x = x.flatten(1)
        return self.fc(x)
31
+
32
+
33
+ # ------------------------------------------------------------------
34
+ # 2. Load Model From genconvit_ed_inference.pth
35
+ # ------------------------------------------------------------------
36
+
37
# Path of the local checkpoint, resolved relative to the working directory.
model_path = "genconvit_ed_inference.pth"

model = GenConViT(num_classes=2)

# NOTE(security): torch.load unpickles the file, which can execute arbitrary
# code. Only load checkpoints from a trusted source; where the installed
# PyTorch supports it, consider passing weights_only=True.
checkpoint = torch.load(model_path, map_location="cpu")

# Checkpoints are sometimes saved as a wrapper dict holding the weights under
# "state_dict" (next to optimizer/epoch data) rather than as the bare
# state_dict; accept both layouts.
if isinstance(checkpoint, dict) and "state_dict" in checkpoint:
    state_dict = checkpoint["state_dict"]
else:
    state_dict = checkpoint

# Strict loading: raises if the keys or shapes do not match GenConViT above.
model.load_state_dict(state_dict)
# Inference mode.
model.eval()
43
 
44
+ # ------------------------------------------------------------------
45
+ # 3. Preprocessing
46
+ # ------------------------------------------------------------------
47
 
48
# Preprocessing applied to every sampled frame: resize to 224x224, convert
# the PIL image to a float tensor, then normalise each of the three channels
# with mean 0.5 and std 0.5.
_preprocess_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
]
transform = transforms.Compose(_preprocess_steps)
53
+
54
+ # ------------------------------------------------------------------
55
+ # 4. Video Deepfake Detection Function
56
+ # ------------------------------------------------------------------
57
 
 
 
 
58
  def detect_deepfake(video):
59
 
 
60
  cap = cv2.VideoCapture(video)
 
61
  if not cap.isOpened():
62
+ return "Error: Cannot open video", None
63
 
64
  scores = []
65
+ sample_frame = None
66
+ frame_interval = 10 # Process every 10th frame
 
 
 
 
 
67
  i = 0
68
+
69
  while True:
70
  ret, frame = cap.read()
71
  if not ret:
72
  break
73
 
74
  if i % frame_interval == 0:
 
75
  rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
76
+ img = Image.fromarray(rgb)
77
+
78
+ # Save last processed frame for display
79
+ sample_frame = img
80
 
81
+ inp = transform(img).unsqueeze(0)
82
 
83
  with torch.no_grad():
84
+ logits = model(inp)
85
+ probs = torch.softmax(logits, dim=1)[0]
86
+ fake_prob = probs[1].item()
87
 
88
+ scores.append(fake_prob)
 
89
 
90
  i += 1
91
 
 
94
  if len(scores) == 0:
95
  return "No frames processed", None
96
 
97
+ avg = float(np.mean(scores))
98
+ label = "🔴 Deepfake" if avg > 0.5 else "🟢 Real"
99
 
100
+ output = f"""
101
+ ### **Prediction: {label}**
102
+ **Fake confidence: {avg:.4f}**
 
 
103
  """
104
 
105
+ return output, sample_frame
106
+
107
 
108
+ # ------------------------------------------------------------------
109
+ # 5. Gradio App UI
110
+ # ------------------------------------------------------------------
111
 
 
 
 
112
# Input and output widgets, created in the same order as they are wired up:
# the uploaded video, the markdown verdict, and the frame image that
# detect_deepfake returns alongside it.
video_input = gr.Video(label="Upload a video")
prediction_output = gr.Markdown(label="Prediction")
frame_output = gr.Image(label="Sample Frame")

app = gr.Interface(
    fn=detect_deepfake,
    inputs=video_input,
    outputs=[prediction_output, frame_output],
    title="GenConViT Deepfake Detector (Local .pth Model)",
    description="Upload a video. The system loads genconvit_ed_inference.pth and predicts deepfake probability.",
)

# Start the Gradio server.
app.launch()