# app.py — Hugging Face Space source.
# Last commit: dbe3f97 "Update app.py" (verified), by vu0018; file size 2.35 kB.
import gradio as gr
import torch
import cv2
from transformers import AutoModelForImageClassification, AutoImageProcessor
from PIL import Image
import numpy as np
# ----------------------------------------------------------
# Load Hugging Face GenConViT Model
# ----------------------------------------------------------
# Hub repository that provides both the classifier weights and its preprocessor.
_MODEL_REPO = "Thanuja2109/GenConViT"

# The image-classification model and the matching preprocessor, loaded once at
# import time so every request reuses them.
model = AutoModelForImageClassification.from_pretrained(_MODEL_REPO)
processor = AutoImageProcessor.from_pretrained(_MODEL_REPO)

# Put the model in evaluation mode: the app only runs inference, never trains.
model.eval()
# ----------------------------------------------------------
# Deepfake detection function
# ----------------------------------------------------------
def _fake_probability(pil_img):
    """Return the model's softmax probability for class index 1 on one image.

    NOTE(review): index 1 is assumed to be the "fake" class; confirm against
    ``model.config.id2label`` for the loaded checkpoint.
    """
    inputs = processor(images=pil_img, return_tensors="pt")
    with torch.no_grad():
        logits = model(**inputs).logits
    return torch.softmax(logits, dim=1)[0][1].item()


def detect_deepfake(video, *, frame_interval=10):
    """Classify a video as deepfake or real by averaging per-frame scores.

    Every ``frame_interval``-th frame (starting with frame 0) is converted
    from OpenCV's BGR layout to RGB and scored by the classifier. The verdict
    is the mean of those per-frame fake probabilities, labelled "Deepfake"
    when it exceeds 0.5 and "Real" otherwise.

    Args:
        video: Path to the uploaded video, as handed over by ``gr.Video``.
            A falsy value (presumably ``None`` when nothing was uploaded)
            is reported as an error instead of being opened.
        frame_interval: Keyword-only. Score one frame out of every
            ``frame_interval`` frames. Must be >= 1. Defaults to 10.

    Returns:
        A ``(markdown_text, frame)`` tuple. On success, ``frame`` is the last
        sampled frame as a PIL image. On failure, ``markdown_text`` is an
        error message and ``frame`` is ``None``.

    Raises:
        ValueError: If ``frame_interval`` is less than 1.
    """
    if frame_interval < 1:
        raise ValueError(f"frame_interval must be >= 1, got {frame_interval}")
    if not video:
        return "Error: no video provided", None

    cap = cv2.VideoCapture(video)
    if not cap.isOpened():
        cap.release()
        return "Error: cannot open video", None

    scores = []
    last_frame = None
    try:
        frame_index = 0
        while True:
            ok, frame = cap.read()
            if not ok:
                break  # end of stream (or unreadable frame)
            if frame_index % frame_interval == 0:
                # OpenCV decodes frames as BGR; the image processor expects RGB.
                rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
                last_frame = Image.fromarray(rgb)
                scores.append(_fake_probability(last_frame))
            frame_index += 1
    finally:
        # Release the capture even if preprocessing or inference raises mid-video.
        cap.release()

    if not scores:
        return "No frames processed", None

    avg_score = np.mean(scores)
    label = "🔴 Deepfake" if avg_score > 0.5 else "🟢 Real"
    result_text = (
        "\n"
        f"### Prediction: **{label}**\n"
        f"**Confidence (fake probability): {avg_score:.4f}**\n"
    )
    return result_text, last_frame
# ----------------------------------------------------------
# Gradio Interface
# ----------------------------------------------------------
# Input and output components, created in the same order as the original
# inline construction: the video upload first, then the two result panes.
_video_input = gr.Video(label="Upload a video")
_result_outputs = [
    gr.Markdown(label="Prediction"),      # markdown verdict text
    gr.Image(label="Analyzed Frame"),     # last frame that was scored
]

# Single-function UI: upload a video, get a verdict and a sample frame back.
app = gr.Interface(
    fn=detect_deepfake,
    inputs=_video_input,
    outputs=_result_outputs,
    title="GenConViT Deepfake Video Detector",
    description="Upload a video. The app samples frames and uses GenConViT to detect deepfakes.",
)
app.launch()