import gradio as gr
import torch
import os
import cv2
import urllib.request
from model.pred_func import load_genconvit, df_face, pred_vid, real_or_fake
from model.config import load_config
# --- Model Download ---
def download_models():
    """
    Download the pre-trained GenConViT model weights if not already present.

    Fetches the ED and VAE inference checkpoints from the Hugging Face Hub
    into the local ``weight`` directory, skipping any file that already
    exists on disk.
    """
    weight_dir = 'weight'
    # (filename, url) pairs for each checkpoint to fetch — data-driven so
    # the download logic isn't duplicated per checkpoint.
    checkpoints = [
        ('genconvit_ed_inference.pth',
         'https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_ed_inference.pth'),
        ('genconvit_vae_inference.pth',
         'https://huggingface.co/Deressa/GenConViT/resolve/main/genconvit_vae_inference.pth'),
    ]
    # exist_ok avoids the check-then-create race of the exists()/makedirs pair.
    os.makedirs(weight_dir, exist_ok=True)
    for filename, url in checkpoints:
        dest = os.path.join(weight_dir, filename)
        if not os.path.exists(dest):
            print(f"Downloading {filename}...")
            urllib.request.urlretrieve(url, dest)
            print("Download complete.")
# --- Global Variables ---
config = load_config()  # project configuration, loaded once at import time
model = None  # lazily initialized by load_model_once()
def load_model_once():
    """
    Lazily initialize the global GenConViT model.

    Ensures the weight files are present on disk, then builds the network.
    Safe to call repeatedly: once the model exists this is a no-op.
    """
    global model
    if model is not None:
        return
    download_models()
    print("Loading GenConViT model...")
    # Set net='genconvit' to use both ED and VAE as per prediction.py logic for best results
    model = load_genconvit(
        config,
        net='genconvit',
        ed_weight='genconvit_ed_inference',
        vae_weight='genconvit_vae_inference',
        fp16=False,
    )
    print("Model loaded successfully.")
def get_video_duration(video_path):
    """
    Return the duration of *video_path* in seconds.

    Returns 0 when the file cannot be opened or reports a zero frame rate.
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return 0
    frames_per_second = capture.get(cv2.CAP_PROP_FPS)
    total_frames = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()
    # Guard against a zero FPS report to avoid division by zero.
    if frames_per_second == 0:
        return 0
    return total_frames / frames_per_second
# --- Prediction Function ---
def detect_deepfake(video_path, model_type, num_frames):
    """
    Run deepfake detection on an uploaded video.

    Args:
        video_path: Path to the uploaded video (None when nothing uploaded).
        model_type: One of "GenConViT", "AE", "VAE" — selects the sub-network.
        num_frames: Number of frames to sample for face extraction.

    Returns:
        A dict mapping "FAKE"/"REAL" to scores on success, otherwise a
        user-facing error message string.
    """
    if video_path is None:
        return "❌ Please upload a video file."
    # ===== video duration validation (max 60 seconds) =====
    if get_video_duration(video_path) > 60:
        return "❌ Video terlalu besar. Durasi maksimal adalah 1 menit (60 detik)."
    try:
        print(f"Processing video: {video_path} with model: {model_type}")
        # Translate the UI choice into the internal network identifier,
        # falling back to the combined network for unknown values.
        net_val = {"GenConViT": "genconvit", "AE": "ed", "VAE": "vae"}.get(
            model_type, "genconvit"
        )
        # Extract face crops from the sampled frames.
        faces = df_face(video_path, num_frames)
        if len(faces) == 0:
            return "No faces were detected in the video. Please try another video."
        prediction, raw_score = pred_vid(faces, model, net=net_val)
        label = real_or_fake(prediction)
        # The raw score leans toward FAKE when high; flip it when the label
        # is REAL so `confidence` always expresses P(FAKE).
        confidence = raw_score if label == 'FAKE' else 1 - raw_score
        return {"FAKE": confidence, "REAL": 1 - confidence}
    except Exception as e:
        print(f"An error occurred: {e}")
        return "An error occurred during processing. The video might be corrupted or in an unsupported format."
# --- Gradio Interface ---
title = "GenConViT: Deepfake Video Detection"
# Shown beneath the title in the Gradio UI; rendered as Markdown.
description = """
Upload a video file to detect if it's a deepfake. This application uses the Generative Convolutional Vision Transformer (GenConViT)
to analyze the video. The model achieves an average accuracy of 95.8% and an AUC of 99.3% across multiple datasets.
"""
# Load the model once when the app starts
load_model_once()
# Wire inputs/outputs to the prediction function. The Label component
# renders the {"FAKE": score, "REAL": score} dict that detect_deepfake
# returns on success (error strings are displayed as plain text).
iface = gr.Interface(
    fn=detect_deepfake,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Radio(["GenConViT", "AE", "VAE"], label="Pilih Model", value="GenConViT"),
        gr.Slider(1, 200, value=15, step=1, label="Number of Frames")
    ],
    outputs=gr.Label(num_top_classes=2, label="Prediction Result"),
    title=title,
    description=description,
    flagging_mode="never",
    # Bundled sample videos; paths are relative to the app's working directory.
    examples=[
        ["sample_prediction_data/aajsqyyjni.mp4", "GenConViT", 15],
        ["sample_prediction_data/anndvqgoko.mp4", "GenConViT", 15],
        ["sample_prediction_data/0017_fake.mp4.mp4", "GenConViT", 15],
        ["sample_prediction_data/0048_fake.mp4.mp4", "GenConViT", 15]
    ]
)
if __name__ == "__main__":
    # queue() serializes long-running inference requests instead of
    # running them all concurrently.
    iface.queue().launch()