"""Gradio app for video-level deepfake detection with TALL-SWIN.

On the first request the checkpoint is downloaded from the Hugging Face Hub
and the model is built through ``infer_videos_txt``; ``predict`` then runs
inference on an uploaded video and returns a JSON report.
"""

import os
import shutil
import tempfile
import traceback
from types import SimpleNamespace

import gradio as gr
import torch
from huggingface_hub import hf_hub_download

# Import your original pipeline
import infer_videos_txt as tall_infer

# ================================
# CONFIGURE YOUR MODEL REPO HERE
# ================================
MODEL_REPO = "guard2PFE/tall4deepfake-weights"  # <-- change if needed
MODEL_FILE = "model_best.pth"  # <-- change if needed


# ================================
# Build args object (like CLI)
# ================================
def build_args(ckpt_path):
    """Return the inference settings as a namespace, mimicking parsed CLI args.

    Args:
        ckpt_path: Local path of the checkpoint, stored as ``initial_checkpoint``.

    Returns:
        A ``SimpleNamespace`` consumed by ``tall_infer.build_model_and_augmentor``
        and ``tall_infer.infer_one_video_from_tmp``. The field set presumably
        mirrors the ``infer_videos_txt`` CLI flags — confirm there before
        renaming or removing any of them.
    """
    return SimpleNamespace(
        video_list="",
        initial_checkpoint=ckpt_path,
        dataset="ffpp",
        model="TALL_SWIN",
        device="cuda" if torch.cuda.is_available() else "cpu",
        num_workers=0,
        duration=4,
        frames_per_group=1,
        num_clips=8,
        num_crops=1,
        thumbnail_rows=2,
        input_size=224,
        threshold=0.5,
        disable_scaleup=False,
        threed_data=False,
        dense_sampling=True,
        augmentor_ver="v1",
        scale_range=[256, 320],
        modality="rgb",
        use_lmdb=False,
        hpe_to_token=False,
        rel_pos=False,
        window_size=7,
        no_token_mask=False,
        drop=0.0,
        drop_path=0.1,
        drop_block=None,
        use_checkpoint=False,
        dist_url="env://",
        world_size=1,
        local_rank=None,
        output_json="",
        output_csv="",
    )


# ================================
# Load model once (global cache)
# ================================
# Both stay None until ensure_model_loaded() succeeds for the first time.
META = None
ARGS = None


def ensure_model_loaded():
    """Download the checkpoint and build the model once per process.

    Populates the module-level ``ARGS`` (from ``build_args``) and ``META``
    (the return value of ``tall_infer.build_model_and_augmentor``; ``predict``
    reads its ``"image_tmpl"`` and ``"device"`` keys). Returns immediately
    once ``META`` is set.

    Raises:
        Whatever ``hf_hub_download`` or ``build_model_and_augmentor`` raise.
    """
    global META, ARGS
    if META is not None:
        return
    print("Downloading checkpoint...")
    ckpt_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILE)
    args = build_args(ckpt_path)
    meta = tall_infer.build_model_and_augmentor(args)
    # Publish both globals together, only after a successful build, so a
    # failure leaves the cache empty and the next request simply retries.
    ARGS, META = args, meta


def _resolve_video_path(video):
    """Extract a file path from a ``gr.Video`` value; None if there is none."""
    if video is None:
        return None
    if isinstance(video, dict):
        # "name" is the key the original code read; "path" is used by newer
        # Gradio file payloads (confirm against the installed Gradio version).
        return video.get("path") or video.get("name")
    return video


# ================================
# Inference function
# ================================
@torch.no_grad()
def predict(video, threshold):
    """Run TALL-SWIN inference on one uploaded video.

    Args:
        video: Value of the ``gr.Video`` input: a path string, or a dict
            holding the path under ``"path"`` or ``"name"``. None when nothing
            was uploaded.
        threshold: Decision threshold from the slider; written to
            ``ARGS.threshold`` before inference.

    Returns:
        A dict with the video's base name, the frame count reported by
        ``build_tmp_dataset_from_video``, every field returned by
        ``infer_one_video_from_tmp``, and the device used. On any failure the
        dict is ``{"error": message}`` instead.
    """
    # Validate the upload first so a missing video never triggers the
    # (potentially large) checkpoint download.
    video_path = _resolve_video_path(video)
    if not video_path or not os.path.isfile(video_path):
        return {"error": "Video not found"}

    try:
        ensure_model_loaded()
    except Exception as e:
        traceback.print_exc()
        return {"error": f"Model loading failed: {e}"}

    ARGS.threshold = float(threshold)

    tmp_dir = tempfile.mkdtemp(prefix="tall_space_")
    try:
        tmp_info = tall_infer.build_tmp_dataset_from_video(
            video_path, tmp_dir, image_tmpl=META["image_tmpl"]
        )
        result = tall_infer.infer_one_video_from_tmp(
            ARGS, META, tmp_dir, tmp_info["list_rel"], image_tmpl=META["image_tmpl"]
        )
        return {
            "video": os.path.basename(video_path),
            "frames": tmp_info["nframes"],
            **result,
            "device": str(META["device"]),
        }
    except Exception as e:
        # UI boundary: surface the error as JSON, but keep the full
        # traceback in the server log rather than discarding it.
        traceback.print_exc()
        return {"error": str(e)}
    finally:
        # Always remove the per-request scratch directory.
        shutil.rmtree(tmp_dir, ignore_errors=True)


# ================================
# Gradio UI
# ================================
demo = gr.Interface(
    fn=predict,
    inputs=[
        gr.Video(label="Upload Video"),
        gr.Slider(0, 1, value=0.5, step=0.01, label="Threshold"),
    ],
    outputs=gr.JSON(),
    title="TALL4Deepfake Detector",
    description="Video-level deepfake detection using TALL-SWIN",
)

if __name__ == "__main__":
    demo.launch()