# Author: guard2PFE · commit 6de2f3c (verified) — "Update app.py"
import os
import tempfile
import shutil
import torch
import gradio as gr
from types import SimpleNamespace
from huggingface_hub import hf_hub_download
# Import your original pipeline
import infer_videos_txt as tall_infer
# ================================
# CONFIGURE YOUR MODEL REPO HERE
# ================================
# Hugging Face Hub repository and file name of the checkpoint downloaded at
# first use; edit these two values to serve different weights.
MODEL_REPO: str = "guard2PFE/tall4deepfake-weights"
MODEL_FILE: str = "model_best.pth"
# ================================
# Build args object (like CLI)
# ================================
def build_args(ckpt_path):
    """Assemble the argument namespace handed to the ``infer_videos_txt`` helpers.

    It stands in for the parsed command-line arguments, so the pipeline's
    inference functions can be reused from inside the Space without a CLI.

    Args:
        ckpt_path: Local path of the downloaded checkpoint; stored as
            ``initial_checkpoint``.

    Returns:
        types.SimpleNamespace with one attribute per setting below. ``device``
        is ``"cuda"`` when ``torch.cuda.is_available()`` is true, else ``"cpu"``.
    """
    compute_device = "cuda" if torch.cuda.is_available() else "cpu"
    settings = dict(
        video_list="",
        initial_checkpoint=ckpt_path,
        dataset="ffpp",
        model="TALL_SWIN",
        device=compute_device,
        num_workers=0,
        duration=4,
        frames_per_group=1,
        num_clips=8,
        num_crops=1,
        thumbnail_rows=2,
        input_size=224,
        threshold=0.5,
        disable_scaleup=False,
        threed_data=False,
        dense_sampling=True,
        augmentor_ver="v1",
        scale_range=[256, 320],  # fresh list per call, so callers may mutate it
        modality="rgb",
        use_lmdb=False,
        hpe_to_token=False,
        rel_pos=False,
        window_size=7,
        no_token_mask=False,
        drop=0.0,
        drop_path=0.1,
        drop_block=None,
        use_checkpoint=False,
        dist_url="env://",
        world_size=1,
        local_rank=None,
        output_json="",
        output_csv="",
    )
    return SimpleNamespace(**settings)
# ================================
# Load model once (global cache)
# ================================
# Lazily populated module-level cache; both stay None until the first call to
# ensure_model_loaded() succeeds.
META = None  # result of tall_infer.build_model_and_augmentor(ARGS)
ARGS = None  # SimpleNamespace produced by build_args() for the checkpoint


def ensure_model_loaded():
    """Download the checkpoint and build the model on first use.

    Sets the module globals ``ARGS`` and ``META``. Once ``META`` is non-None,
    later calls return immediately. If model building raises, ``META`` stays
    None, so the next call tries again.
    """
    global META, ARGS
    if META is None:
        print("Downloading checkpoint...")
        local_ckpt = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
        ARGS = build_args(local_ckpt)
        META = tall_infer.build_model_and_augmentor(ARGS)
# ================================
# Inference function
# ================================
@torch.no_grad()
def predict(video, threshold):
    """Run TALL-SWIN deepfake inference on a single uploaded video.

    Args:
        video: Value delivered by the ``gr.Video`` input. This is either a file
            path string or a dict carrying the path under ``"path"`` or
            ``"name"``; which one appears depends on the Gradio version (TODO
            confirm for the deployed version). It is ``None`` when the user
            submits without uploading.
        threshold: Decision threshold from the slider; converted with ``float``.

    Returns:
        dict: On success, contains the video basename (``"video"``), the frame
        count reported by ``build_tmp_dataset_from_video`` (``"frames"``),
        every key returned by ``infer_one_video_from_tmp``, and the model
        device (``"device"``). Returns ``{"error": ...}`` if the video is
        missing or any step of inference raises.
    """
    if isinstance(video, dict):
        # Newer Gradio FileData uses "path"; older payloads used "name".
        video_path = video.get("path") or video.get("name")
    else:
        video_path = video

    # Validate before touching the model: bad input should not trigger a
    # checkpoint download. os.path.isfile(None) raises TypeError rather than
    # returning False, hence the explicit falsy check.
    if not video_path or not os.path.isfile(video_path):
        return {"error": "Video not found"}

    ensure_model_loaded()
    # Work on a per-request copy: mutating the shared global ARGS would let
    # concurrent requests overwrite each other's threshold.
    args = SimpleNamespace(**vars(ARGS))
    args.threshold = float(threshold)

    tmp_dir = tempfile.mkdtemp(prefix="tall_space_")
    try:
        tmp_info = tall_infer.build_tmp_dataset_from_video(
            video_path,
            tmp_dir,
            image_tmpl=META["image_tmpl"],
        )
        result = tall_infer.infer_one_video_from_tmp(
            args,
            META,
            tmp_dir,
            tmp_info["list_rel"],
            image_tmpl=META["image_tmpl"],
        )
        return {
            "video": os.path.basename(video_path),
            "frames": tmp_info["nframes"],
            **result,
            "device": str(META["device"]),
        }
    except Exception as e:
        # UI boundary: report the failure to the user, but keep the traceback
        # in the Space logs instead of discarding it.
        import traceback

        traceback.print_exc()
        return {"error": str(e)}
    finally:
        shutil.rmtree(tmp_dir, ignore_errors=True)
# ================================
# Gradio UI
# ================================
# Single-page app: a video upload and a 0-1 threshold slider feed predict(),
# and its returned dict is rendered as JSON.
demo = gr.Interface(
    predict,
    [
        gr.Video(label="Upload Video"),
        gr.Slider(minimum=0, maximum=1, value=0.5, step=0.01, label="Threshold"),
    ],
    gr.JSON(),
    title="TALL4Deepfake Detector",
    description="Video-level deepfake detection using TALL-SWIN",
)

if __name__ == "__main__":
    demo.launch()