gnikhilchand committed on
Commit
acce773
·
verified ·
1 Parent(s): e540bdf

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +112 -0
app.py ADDED
@@ -0,0 +1,112 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from torch import nn
4
+ from PIL import Image
5
+ from transformers import CLIPProcessor, CLIPModel
6
+ import cv2
7
+ import numpy as np
8
+
# --- 1. Define Model Architecture (Must match your training script) ---
class CLIPImageClassifier(nn.Module):
    """CLIP vision encoder with a small MLP head.

    Produces a single sigmoid probability per image; the training
    convention in this file treats higher values as "fake".
    """

    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        # Head input width follows the CLIP vision tower's hidden size.
        hidden_size = self.clip.config.vision_config.hidden_size
        self.classifier = nn.Sequential(
            nn.Linear(hidden_size, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, pixel_values):
        """Return (batch, 1) probabilities from preprocessed pixel values."""
        # Pooled embedding of the vision tower feeds the classifier head.
        pooled = self.clip.vision_model(pixel_values=pixel_values).pooler_output
        return self.classifier(pooled)
# --- 2. Load Model & Processor ---
DEVICE = "cpu"  # Force CPU for Hugging Face Free Tier
MODEL_PATH = "best_clip_finetuned_classifier.pth"
CLIP_NAME = "openai/clip-vit-base-patch32"

print("Loading model...")
model = CLIPImageClassifier()
# SECURITY: torch.load unpickles the checkpoint; weights_only=True restricts
# deserialization to tensors so a tampered file cannot execute arbitrary code.
state_dict = torch.load(
    MODEL_PATH, map_location=torch.device(DEVICE), weights_only=True
)
# strict=False tolerates key mismatches (e.g. extra keys saved during
# training), but a silent mismatch can mean the wrong checkpoint was loaded —
# report any discrepancy instead of hiding it.
load_result = model.load_state_dict(state_dict, strict=False)
if load_result.missing_keys:
    print(f"Warning: missing keys in checkpoint: {load_result.missing_keys}")
if load_result.unexpected_keys:
    print(f"Warning: unexpected keys in checkpoint: {load_result.unexpected_keys}")
model.to(DEVICE)
model.eval()  # disable dropout for inference

print("Loading processor...")
processor = CLIPProcessor.from_pretrained(CLIP_NAME)
# --- 3. Define Inference Function ---
def predict_video(video_path):
    """Classify an uploaded video as REAL or FAKE.

    Args:
        video_path: Path string to the temporary video file supplied by
            Gradio, or None when nothing was uploaded.

    Returns:
        tuple[str, float]: a human-readable verdict string and the average
        per-frame fake probability (0 = real, 1 = fake).
    """
    if video_path is None:
        return "Please upload a video.", 0.0

    print(f"Processing video: {video_path}")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Unreadable file or unsupported codec — bail out early.
        return "Could not analyze video frames.", 0.0

    fps = cap.get(cv2.CAP_PROP_FPS)
    if fps == 0 or np.isnan(fps):
        fps = 30  # Default fallback when the container lacks FPS metadata

    # Sample ~1 frame per second to keep it fast on CPU.
    frame_skip = max(1, int(fps))

    predictions = []
    frame_count = 0
    # try/finally guarantees the capture is released even if preprocessing
    # or model inference raises (the original leaked the decoder on error).
    try:
        while True:
            ret, frame = cap.read()
            if not ret:
                break

            if frame_count % frame_skip == 0:
                # Convert BGR (OpenCV) to RGB (PIL)
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)

                # Preprocess to CLIP pixel values
                inputs = processor(images=pil_image, return_tensors="pt")["pixel_values"].to(DEVICE)

                # Inference — no gradients needed
                with torch.no_grad():
                    predictions.append(model(inputs).item())

            frame_count += 1
    finally:
        cap.release()

    if not predictions:
        return "Could not analyze video frames.", 0.0

    # Aggregate per-frame probabilities into one score
    avg_fake_prob = sum(predictions) / len(predictions)

    # Report confidence in whichever label was chosen
    label = "FAKE" if avg_fake_prob > 0.5 else "REAL"
    confidence = avg_fake_prob if label == "FAKE" else (1 - avg_fake_prob)

    return f"{label} (Confidence: {confidence:.2%})", avg_fake_prob
# --- 4. Create Gradio Interface ---
# Two outputs: a textual verdict plus the raw averaged probability,
# matching the (str, float) tuple returned by predict_video.
verdict_output = gr.Textbox(label="Verdict")
score_output = gr.Number(label="Fake Probability Score (0=Real, 1=Fake)")

interface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=[verdict_output, score_output],
    title="DeepFake Video Detector",
    description="Upload a video to check if it is Real or AI-Generated. The model analyzes using a fine-tuned CLIP classifier.",
)

# Launch the app only when run as a script
if __name__ == "__main__":
    interface.launch()