File size: 4,822 Bytes
e0c75d6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5728d33
e0c75d6
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
import os
import traceback
import urllib.request

import cv2
import gradio as gr
import torch

from model.config import load_config
from model.pred_func import load_genconvit, df_face, pred_vid, real_or_fake

# --- Model Download ---
def download_models():
    """
    Download the pre-trained GenConViT weights if they are not already
    present in the local ``weight`` directory.

    Fetches both the ED (encoder-decoder) and VAE inference checkpoints
    from the Hugging Face Hub. Files that already exist on disk are
    never re-downloaded.
    """
    weight_dir = 'weight'
    base_url = 'https://huggingface.co/Deressa/GenConViT/resolve/main/'

    # exist_ok avoids a check-then-create race if the directory appears
    # between the existence test and makedirs.
    os.makedirs(weight_dir, exist_ok=True)

    for label, filename in (('ED', 'genconvit_ed_inference.pth'),
                            ('VAE', 'genconvit_vae_inference.pth')):
        dest = os.path.join(weight_dir, filename)
        if not os.path.exists(dest):
            print(f"Downloading {label} model weights...")
            urllib.request.urlretrieve(base_url + filename, dest)
            print("Download complete.")

# --- Global Variables ---
# Project-defined configuration, loaded once at import time.
config = load_config()
# Lazily initialized by load_model_once(); stays None until first load.
model = None

def load_model_once():
    """
    Lazily initialize the global GenConViT model.

    The first call downloads the pretrained weights (if missing) and
    builds the combined ED + VAE model; every later call is a no-op.
    """
    global model
    if model is not None:
        return

    download_models()
    print("Loading GenConViT model...")
    # net='genconvit' uses both the ED and VAE branches, matching the
    # prediction.py logic for best results.
    model = load_genconvit(
        config,
        net='genconvit',
        ed_weight='genconvit_ed_inference',
        vae_weight='genconvit_vae_inference',
        fp16=False,
    )
    print("Model loaded successfully.")

def get_video_duration(video_path):
    """
    Return the duration of *video_path* in seconds.

    Returns 0 when the file cannot be opened or when it reports a zero
    frame rate (so callers can treat 0 as "unknown/invalid").
    """
    capture = cv2.VideoCapture(video_path)
    if not capture.isOpened():
        return 0

    frames_per_second = capture.get(cv2.CAP_PROP_FPS)
    total_frames = capture.get(cv2.CAP_PROP_FRAME_COUNT)
    capture.release()

    # Guard against division by zero on streams with no FPS metadata.
    return total_frames / frames_per_second if frames_per_second != 0 else 0

# --- Prediction Function ---
def detect_deepfake(video_path, model_type, num_frames):
    """
    Run deepfake detection on an uploaded video.

    Parameters
    ----------
    video_path : str or None
        Path to the uploaded video (None when nothing was uploaded).
    model_type : str
        UI label — "GenConViT", "AE" or "VAE"; unknown values fall back
        to the combined "genconvit" network.
    num_frames : int
        Number of frames to sample for face extraction.

    Returns
    -------
    dict or str
        {"FAKE": p, "REAL": 1 - p} for the gr.Label output on success,
        or a human-readable error message string on failure.
    """
    if video_path is None:
        return "❌ Please upload a video file."

    # ===== VIDEO DURATION VALIDATION (max 60 seconds) =====
    duration = get_video_duration(video_path)
    if duration > 60:
        return "❌ Video terlalu besar. Durasi maksimal adalah 1 menit (60 detik)."

    try:
        print(f"Processing video: {video_path} with model: {model_type}")

        # Map the UI label to the internal net identifier expected by
        # pred_vid(); default to the combined model.
        net_mapping = {
            "GenConViT": "genconvit",
            "AE": "ed",
            "VAE": "vae"
        }
        net_val = net_mapping.get(model_type, "genconvit")

        # Extract face crops from sampled frames.
        faces = df_face(video_path, num_frames)

        if len(faces) == 0:
            return "No faces were detected in the video. Please try another video."

        # y is the predicted class index, y_val its associated score
        # (see model.pred_func.pred_vid).
        y, y_val = pred_vid(faces, model, net=net_val)

        # Get the label (REAL or FAKE)
        label = real_or_fake(y)

        # y_val is the score of the predicted class; convert it into a
        # probability of FAKE so the Label widget shows both classes.
        confidence = y_val if label == 'FAKE' else 1 - y_val

        return { "FAKE": confidence, "REAL": 1 - confidence }

    except Exception as e:
        # Log the full traceback so server-side failures remain
        # debuggable, while the user only sees a friendly message.
        print(f"An error occurred: {e}")
        traceback.print_exc()
        return "An error occurred during processing. The video might be corrupted or in an unsupported format."

# --- Gradio Interface ---
# Static UI text shown at the top of the Gradio page.
title = "GenConViT: Deepfake Video Detection"
description = """
Upload a video file to detect if it's a deepfake. This application uses the Generative Convolutional Vision Transformer (GenConViT)
to analyze the video. The model achieves an average accuracy of 95.8% and an AUC of 99.3% across multiple datasets.
"""

# Load the model once when the app starts (downloads weights on first run).
load_model_once()

# Assemble the Gradio interface: video upload, model choice, and frame
# count in; a two-class probability label out.
iface = gr.Interface(
    fn=detect_deepfake,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Radio(["GenConViT", "AE", "VAE"], label="Pilih Model", value="GenConViT"),
        gr.Slider(1, 200, value=15, step=1, label="Number of Frames")
    ],
    outputs=gr.Label(num_top_classes=2, label="Prediction Result"),
    title=title,
    description=description,
    flagging_mode="never",  # disable Gradio's flagging UI entirely
    examples=[
        # NOTE(review): these sample files must exist relative to the app's
        # working directory; the doubled ".mp4.mp4" suffixes look odd — verify.
        ["sample_prediction_data/aajsqyyjni.mp4", "GenConViT", 15],
        ["sample_prediction_data/anndvqgoko.mp4", "GenConViT", 15],
        ["sample_prediction_data/0017_fake.mp4.mp4", "GenConViT", 15],
        ["sample_prediction_data/0048_fake.mp4.mp4", "GenConViT", 15]
    ]
)

if __name__ == "__main__":
    # Enable request queuing before launching the Gradio server.
    iface.queue().launch()