farwew committed on
Commit
e231c61
·
verified ·
1 Parent(s): aa11893

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -14
app.py CHANGED
@@ -1,19 +1,20 @@
1
- import gradio as gr
2
  import torch
3
  import torch.nn as nn
4
  import numpy as np
5
  from torchvision import transforms as T
6
  from torchvision.transforms.v2 import ToDtype
7
  from decord import VideoReader, cpu
8
- from trainers import vificlip
9
- from utils.config import get_config
10
- from utils.logger import create_logger
11
 
12
  # -------------------------
13
- # Setup Device & Seed
14
  # -------------------------
15
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
16
- torch.manual_seed(42)
 
 
 
17
 
18
  # -------------------------
19
  # Transform
@@ -34,16 +35,22 @@ class ClassificationHead(nn.Module):
34
  def __init__(self, input_dim=512, num_classes=1):
35
  super().__init__()
36
  self.dense = nn.Linear(input_dim, num_classes)
 
37
  def forward(self, x):
38
  return self.dense(x)
39
 
40
  # -------------------------
41
- # Load ViFi-CLIP + Classifier
42
  # -------------------------
 
 
 
 
43
  cfgpth = 'configs/zero_shot/train/k400/16_16_vifi_clip.yaml'
44
- model_path = 'vifi_clip_30_epochs_k400_full_finetuned.pth'
45
  classifier_path = 'best_detector_model.pt'
46
 
 
 
47
  class parse_option:
48
  def __init__(self):
49
  self.config = cfgpth
@@ -69,7 +76,7 @@ classifier.to(device)
69
  classifier.eval()
70
 
71
  # -------------------------
72
- # Inference Function (with threshold)
73
  # -------------------------
74
  def predict_video(video_path, threshold=0.5):
75
  preprocess = _transform(224)
@@ -104,15 +111,15 @@ def predict_video(video_path, threshold=0.5):
104
  return f"❌ Error: {str(e)}"
105
 
106
  # -------------------------
107
- # Gradio UI (with slider)
108
  # -------------------------
109
  gr.Interface(
110
  fn=predict_video,
111
  inputs=[
112
- gr.Video(type="filepath", label="Upload Video"),
113
  gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold (Real ≥ Threshold)")
114
  ],
115
  outputs="text",
116
- title="Fake Video Detection with Threshold Control",
117
- description="Upload a video to classify it as Real or Fake. Adjust the threshold to tune sensitivity."
118
  ).launch()
 
1
+ import os
2
  import torch
3
  import torch.nn as nn
4
  import numpy as np
5
  from torchvision import transforms as T
6
  from torchvision.transforms.v2 import ToDtype
7
  from decord import VideoReader, cpu
8
+ import gradio as gr
 
 
9
 
10
  # -------------------------
11
+ # Step 0: Download model from Google Drive if not exists
12
  # -------------------------
13
+ model_path = 'vifi_clip_30_epochs_k400_full_finetuned.pth'
14
+ if not os.path.exists(model_path):
15
+ print(f"🔽 Downloading model to {model_path}...")
16
+ os.system("pip install -q gdown")
17
+ os.system("gdown --id 1Nx30Kbu5xnv6dPwz4I3Ivy380LCdp1Md -O vifi_clip_30_epochs_k400_full_finetuned.pth")
18
 
19
  # -------------------------
20
  # Transform
 
35
  def __init__(self, input_dim=512, num_classes=1):
36
  super().__init__()
37
  self.dense = nn.Linear(input_dim, num_classes)
38
+
39
  def forward(self, x):
40
  return self.dense(x)
41
 
42
  # -------------------------
43
+ # Load ViFi-CLIP Model
44
  # -------------------------
45
+ from trainers import vificlip
46
+ from utils.config import get_config
47
+ from utils.logger import create_logger
48
+
49
  cfgpth = 'configs/zero_shot/train/k400/16_16_vifi_clip.yaml'
 
50
  classifier_path = 'best_detector_model.pt'
51
 
52
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
53
+
54
  class parse_option:
55
  def __init__(self):
56
  self.config = cfgpth
 
76
  classifier.eval()
77
 
78
  # -------------------------
79
+ # Inference Function
80
  # -------------------------
81
  def predict_video(video_path, threshold=0.5):
82
  preprocess = _transform(224)
 
111
  return f"❌ Error: {str(e)}"
112
 
113
  # -------------------------
114
+ # Gradio UI
115
  # -------------------------
116
  gr.Interface(
117
  fn=predict_video,
118
  inputs=[
119
+ gr.Video(type="filepath", label="Upload Video (.mp4)"),
120
  gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold (Real ≥ Threshold)")
121
  ],
122
  outputs="text",
123
+ title="🧠 Deepfake Detection with ViFi-CLIP",
124
+ description="Upload a video to classify it as Real or Fake. Threshold slider lets you adjust sensitivity."
125
  ).launch()