farwew committed
Commit bb7eaf5 · verified · 1 Parent(s): c152631

Create app.py

Files changed (1):
  1. app.py +118 -0
app.py ADDED:
import gradio as gr
import torch
import torch.nn as nn
import numpy as np
from torchvision import transforms as T
from torchvision.transforms.v2 import ToDtype
from decord import VideoReader, cpu
from trainers import vificlip
from utils.config import get_config
from utils.logger import create_logger

# -------------------------
# Setup Device & Seed
# -------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)

# -------------------------
# Transform: CLIP-style preprocessing (scale uint8 frames to [0, 1],
# resize, center-crop, normalize with the CLIP mean/std)
# -------------------------
def _transform(n_px=224):
    return T.Compose([
        ToDtype(torch.float32, scale=True),
        T.Resize(n_px, antialias=True),
        T.CenterCrop(n_px),
        T.Normalize((0.48145466, 0.4578275, 0.40821073),
                    (0.26862954, 0.26130258, 0.27577711)),
    ])

# -------------------------
# Classifier Head: a single linear layer mapping the 512-d pooled
# video feature to one logit for the binary real/fake decision
# -------------------------
class ClassificationHead(nn.Module):
    def __init__(self, input_dim=512, num_classes=1):
        super().__init__()
        self.dense = nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.dense(x)

# -------------------------
# Load ViFi-CLIP + Classifier
# -------------------------
cfgpth = 'configs/zero_shot/train/k400/16_16_vifi_clip.yaml'
model_path = 'vifi_clip_30_epochs_k400_full_finetuned.pth'
classifier_path = 'best_detector_model.pt'

# Minimal stand-in for the training repo's argparse options; it only
# needs the attributes that get_config() reads.
class parse_option:
    def __init__(self):
        self.config = cfgpth
        self.output = "exp"
        self.resume = model_path
        self.only_test = True
        self.opts = None
        self.batch_size = None
        self.pretrained = None
        self.accumulation_steps = None
        self.local_rank = 0

args = parse_option()
config = get_config(args)
logger = create_logger(output_dir=args.output, name=f"{config.MODEL.ARCH}")
model = vificlip.returnCLIP(config, logger, class_names=["true", "false"])
model = model.float().to(device)
feature_extractor = model

classifier = ClassificationHead()
classifier.load_state_dict(torch.load(classifier_path, map_location=device))
classifier.to(device)
classifier.eval()

# -------------------------
# Inference Function (with threshold)
# -------------------------
def predict_video(video_path, threshold=0.5):
    preprocess = _transform(224)
    try:
        vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(vr)
        num_frames = 16

        # Sample a random contiguous 16-frame window; NumPy's RNG is not
        # seeded above, so this sampling varies between runs. Short videos
        # are padded by repeating the last frame.
        if total_frames > num_frames:
            start = np.random.randint(0, total_frames - num_frames)
            indices = list(range(start, start + num_frames))
        else:
            indices = list(range(total_frames))
            indices += [total_frames - 1] * (num_frames - len(indices))

        frames = vr.get_batch(indices).asnumpy()                      # (T, H, W, C) uint8
        video_tensor = torch.from_numpy(frames).permute(0, 3, 1, 2)   # (T, C, H, W)
        video_tensor = preprocess(video_tensor).unsqueeze(0).to(device)  # (1, T, C, H, W)

        # Use n_frames for the time dimension instead of T, which would
        # shadow the torchvision.transforms alias imported above.
        B, n_frames, C, H, W = video_tensor.shape
        input_clip = video_tensor.view(B * n_frames, C, H, W)

        with torch.no_grad():
            # Encode each frame with the CLIP image encoder, then
            # mean-pool over time to get one feature vector per video.
            features = feature_extractor.image_encoder(input_clip)
            features = features.view(B, n_frames, -1).mean(dim=1)
            logits = classifier(features)
            prob = torch.sigmoid(logits).item()
            label = "Real" if prob >= threshold else "Fake"

        return f"{label} (prob: {prob:.4f}, threshold: {threshold})"
    except Exception as e:
        return f"❌ Error: {str(e)}"

# -------------------------
# Gradio UI (with slider)
# -------------------------
gr.Interface(
    fn=predict_video,
    inputs=[
        # gr.Video passes the upload to fn as a filepath by default; the
        # 'type' kwarg is not part of the Video API in recent Gradio
        # releases, so it is omitted here.
        gr.Video(label="Upload Video"),
        gr.Slider(0.0, 1.0, value=0.5, step=0.01, label="Threshold (Real ≥ Threshold)")
    ],
    outputs="text",
    title="Fake Video Detection with Threshold Control",
    description="Upload a video to classify it as Real or Fake. Adjust the threshold to tune sensitivity."
).launch()
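Note on the sampling strategy: because predict_video draws a random 16-frame window with NumPy's unseeded RNG, repeated predictions on the same clip can differ. A deterministic alternative is to sample frames uniformly across the whole clip; the sketch below is a hypothetical helper (uniform_indices is not part of the committed app) that could replace the index-selection block:

# Hypothetical helper: evenly spaced frame indices over the full clip.
# For clips shorter than num_frames this naturally repeats indices,
# mirroring the last-frame padding behavior above.
def uniform_indices(total_frames, num_frames=16):
    idx = np.linspace(0, max(total_frames - 1, 0), num_frames)
    return idx.round().astype(int).tolist()

Whether a random window or uniform coverage is the better test-time choice depends on how the detector was trained; the committed app uses the random window.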
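For a quick sanity check outside the browser, one can bypass the UI by commenting out the gr.Interface(...).launch() call and invoking predict_video directly. A minimal sketch, assuming a local test clip at samples/clip.mp4 (hypothetical path) and the checkpoints referenced above:

# Smoke test (hypothetical sample path): prints either
# "Real (prob: ..., threshold: 0.5)" / "Fake (...)" or an error string.
print(predict_video("samples/clip.mp4", threshold=0.5))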