# app.py — DeepFake video detector (Gradio Space)
import gradio as gr
import torch
from torch import nn
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import cv2
import numpy as np
# --- 1. Define Model Architecture (Must match your training script) ---
class CLIPImageClassifier(nn.Module):
    """Binary frame classifier: pretrained CLIP vision tower + small MLP head.

    The head maps the pooled CLIP vision embedding to a single sigmoid
    output — presumably the "fake" probability (see predict_video's use
    of the score).
    """

    def __init__(self, clip_model_name="openai/clip-vit-base-patch32"):
        super().__init__()
        # Backbone: full pretrained CLIP model; only its vision tower is
        # used in forward().
        self.clip = CLIPModel.from_pretrained(clip_model_name)
        hidden = self.clip.config.vision_config.hidden_size
        # Head: hidden -> 256 -> 1, with dropout and a sigmoid squashing
        # the logit into [0, 1].
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.5),
            nn.Linear(256, 1),
            nn.Sigmoid(),
        )

    def forward(self, pixel_values):
        """Return a (batch, 1) probability tensor for preprocessed pixel values."""
        vision_out = self.clip.vision_model(pixel_values=pixel_values)
        return self.classifier(vision_out.pooler_output)
# --- 2. Load Model & Processor ---
DEVICE = "cpu"  # Force CPU for Hugging Face Free Tier
MODEL_PATH = "best_clip_finetuned_classifier.pth"
CLIP_NAME = "openai/clip-vit-base-patch32"

print("Loading model...")
model = CLIPImageClassifier()
# weights_only=True: the checkpoint is a plain state_dict, so refuse to
# unpickle arbitrary objects (torch.load on a pickle is otherwise unsafe).
# strict=False tolerates extra keys from the training script; NOTE(review):
# it also silently ignores *missing* keys — confirm the checkpoint actually
# contains the classifier-head weights.
state_dict = torch.load(MODEL_PATH, map_location=torch.device(DEVICE), weights_only=True)
model.load_state_dict(state_dict, strict=False)
model.to(DEVICE)
model.eval()  # disable dropout for inference

print("Loading processor...")
processor = CLIPProcessor.from_pretrained(CLIP_NAME)
# --- 3. Define Inference Function ---
def predict_video(video_path):
    """Classify a video as REAL or FAKE by averaging per-frame fake probabilities.

    Gradio passes ``video_path`` as a string path to a temporary file.

    Args:
        video_path: Path to the uploaded video, or None if nothing was uploaded.

    Returns:
        tuple[str, float]: human-readable verdict text and the raw average
        fake probability in [0, 1].
    """
    if video_path is None:
        return "Please upload a video.", 0.0

    print(f"Processing video: {video_path}")
    cap = cv2.VideoCapture(video_path)
    if not cap.isOpened():
        # Unreadable/unsupported file: report it explicitly instead of
        # silently falling through to an empty frame loop.
        return "Could not analyze video frames.", 0.0

    predictions = []
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        if fps == 0 or np.isnan(fps):
            fps = 30  # Default fallback when the container reports no FPS
        # Sample ~1 frame per second to keep it fast on CPU
        frames_to_sample = 1
        frame_skip = max(1, int(fps / frames_to_sample))

        frame_count = 0
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            if frame_count % frame_skip == 0:
                # Convert BGR (OpenCV) to RGB (PIL)
                frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                pil_image = Image.fromarray(frame_rgb)
                # Preprocess to CLIP's expected pixel tensor
                inputs = processor(images=pil_image, return_tensors="pt")['pixel_values'].to(DEVICE)
                # Inference
                with torch.no_grad():
                    prob = model(inputs).item()
                predictions.append(prob)
            frame_count += 1
    finally:
        # Always release the capture handle, even if decoding or
        # inference raises mid-loop.
        cap.release()

    if not predictions:
        return "Could not analyze video frames.", 0.0

    # Aggregate results: mean fake probability across sampled frames
    avg_fake_prob = sum(predictions) / len(predictions)
    # Confidence is reported relative to the chosen label
    label = "FAKE" if avg_fake_prob > 0.5 else "REAL"
    confidence = avg_fake_prob if label == "FAKE" else (1 - avg_fake_prob)
    return f"{label} (Confidence: {confidence:.2%})", avg_fake_prob
# --- 4. Create Gradio Interface ---
# The two outputs map one-to-one onto predict_video's return tuple:
# (verdict string, raw average fake probability).
interface = gr.Interface(
    fn=predict_video,
    inputs=gr.Video(label="Upload Video"),
    outputs=[
        gr.Textbox(label="Verdict"),
        gr.Number(label="Fake Probability Score (0=Real, 1=Fake)")
    ],
    title="DeepFake Video Detector",
    description="Upload a video to check if it is Real or AI-Generated. The model analyzes using a fine-tuned CLIP classifier."
)

# Launch the app only when run as a script (Spaces executes this module directly).
if __name__ == "__main__":
    interface.launch()