# app.py
"""Gradio demo for UniGaze: run gaze estimation on an uploaded image or video.

The checkpoint is fetched from the Hugging Face Hub and the model is run
in-process through ``UniGazeRuntime``.
"""
import os
import sys
import time
from functools import lru_cache
from pathlib import Path
from typing import List, Optional, Tuple

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download, HfApi

# Ensure imports of local packages work. This has to run *before* the
# ``unigaze`` import below, otherwise it cannot influence that import.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))

from unigaze.infer_runtime import UniGazeRuntime  # noqa: E402  in-process runtime (no subprocess)

# --------------------------------------------------------------------------------------
# Defaults
# --------------------------------------------------------------------------------------
DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"
DEFAULT_CKPT_FILE = [
    "unigaze_h14_joint.pth.tar",
    "unigaze_l16_joint.pth.tar",
    "unigaze_b16_joint.pth.tar",
]
DEFAULT_REVISION = "main"
DEFAULT_CFGS = [
    "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze/configs/model/mae_L_16_gaze.yaml",
    "unigaze/configs/model/mae_b_16_gaze.yaml",
]

TITLE = "UniGaze Demo (Video + Image)"
DESC = """
Upload a short video or a single image. The app downloads a checkpoint from the Hub,
runs UniGaze in-process (no subprocess, no permanent writes), and returns results.
"""

# Directory containing this file; used as a fallback base for relative configs.
_APP_DIR = Path(__file__).resolve().parent

# File suffixes treated as model weights (".tar" also covers ".pth.tar").
_WEIGHT_SUFFIXES = (".pth", ".pt", ".safetensors", ".tar")


# --------------------------------------------------------------------------------------
# Helpers
# --------------------------------------------------------------------------------------
def resolve_cfg_abs(cfg_str: str) -> Path:
    """Return an absolute path to the YAML config.

    An absolute ``cfg_str`` is returned as-is if it exists. A relative one is
    looked up under the current working directory and then under the directory
    containing this file; paths whose first component is ``configs`` are
    additionally tried below a ``unigaze`` sub-directory of each base.

    Raises:
        FileNotFoundError: if no candidate exists. The message lists every
            location that was tried.
    """
    p = Path(cfg_str)
    if p.is_absolute():
        if p.exists():
            return p
        raise FileNotFoundError(f"Config not found: {p}")

    # Compare path components rather than a "configs/" string prefix so the
    # check does not depend on the platform's path separator.
    in_configs_dir = len(p.parts) > 1 and p.parts[0] == "configs"

    candidates: List[Path] = []
    for base in (Path.cwd(), _APP_DIR):
        candidates.append((base / p).resolve())
        if in_configs_dir:
            candidates.append((base / "unigaze" / p).resolve())

    tried: List[Path] = []
    for candidate in candidates:
        if candidate in tried:
            # cwd and the app directory are often the same; skip duplicates.
            continue
        if candidate.exists():
            return candidate
        tried.append(candidate)

    raise FileNotFoundError(
        "Config not found. Tried: " + ", ".join(str(t) for t in tried)
    )


def list_weight_files(repo_id: str, revision: str = "main") -> List[str]:
    """List files in a Hub model repo whose names look like model weights.

    This is a best-effort lookup: any failure while querying the Hub yields an
    empty list, which callers treat as "listing unavailable".
    """
    try:
        api = HfApi()
        files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
    except Exception:
        # Deliberately broad: the listing is only used for a nicer error
        # message, so it must never block the actual download.
        return []
    return [f for f in files if f.lower().endswith(_WEIGHT_SUFFIXES)]


def get_ckpt_path(repo_id: str, filename: str, revision: str = "main") -> str:
    """Download ``filename`` from the Hub model repo and return the path given by ``hf_hub_download``.

    Raises:
        FileNotFoundError: if the repo listing is available and does not
            contain ``filename``.
    """
    files = list_weight_files(repo_id, revision)
    # An empty listing means the lookup failed or found nothing; in that case
    # fall through and let hf_hub_download report any problem itself.
    if files and filename not in files:
        raise FileNotFoundError(
            f"File '{filename}' not found in model repo '{repo_id}' at rev '{revision}'. "
            f"Available weights: {files}"
        )
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        repo_type="model",
    )


# Cache the runtime so we load model/FA only once
@lru_cache(maxsize=3)
def get_runtime(cfg_abs_str: str, ckpt_path: str, device: str = "cpu") -> UniGazeRuntime:
    """Return a ``UniGazeRuntime`` for the given config/checkpoint/device, reusing cached instances."""
    return UniGazeRuntime(cfg_abs_str, ckpt_path, device=device)


def _load_runtime(
    hf_repo: str,
    ckpt_filename: str,
    cfg_path_user: str,
    logs: List[str],
) -> Tuple[Optional[UniGazeRuntime], Optional[str]]:
    """Fetch the checkpoint, resolve the config and build the runtime.

    Returns ``(runtime, None)`` on success or ``(None, message)`` on failure,
    where ``message`` is the text to show in the UI log box. Progress lines
    are appended to ``logs``.
    """
    try:
        ckpt_path = get_ckpt_path(hf_repo, ckpt_filename, revision=DEFAULT_REVISION)
    except Exception as e:
        return None, f"[hub] ERROR: {e}"
    logs.append(f"[hub] downloaded: {ckpt_path}")

    try:
        cfg_abs = resolve_cfg_abs(cfg_path_user)
    except Exception as e:
        return None, f"[cfg] {e}"

    try:
        return get_runtime(str(cfg_abs), ckpt_path, device="cpu"), None
    except Exception as e:
        # Report model-construction failures in the log box, like the
        # hub/cfg stages above, instead of letting them escape to Gradio.
        return None, "\n".join([*logs, f"[model] ERROR: {e}"])


# --------------------------------------------------------------------------------------
# Runners (in-process)
# --------------------------------------------------------------------------------------
def run_unigaze_on_video(
    video_path: str,
    hf_repo: str,
    ckpt_filename: str,
    cfg_path_user: str,
    extra_args: str = "",
) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
    """Run UniGaze on a video file.

    Args:
        video_path: path of the uploaded video.
        hf_repo: Hub model repo to download the checkpoint from.
        ckpt_filename: checkpoint file name inside ``hf_repo``.
        cfg_path_user: model config path (absolute or relative).
        extra_args: unused; kept for interface compatibility.

    Returns:
        ``(last_frame, output_video_path, zip_path, logs)``. ``zip_path`` is
        always ``None``. On any failure the first three items are ``None``
        and ``logs`` carries the error message.
    """
    if not video_path:
        return None, None, None, "[input] ERROR: no video provided"

    logs: List[str] = []
    rt, error = _load_runtime(hf_repo, ckpt_filename, cfg_path_user, logs)
    if rt is None:
        return None, None, None, error or ""

    try:
        mp4_path, last_rgb, run_sec = rt.predict_video(video_path)
    except Exception as e:
        logs.append(f"[infer] ERROR: {e}")
        return None, None, None, "\n".join(logs)

    logs.append(f"[time] total runtime: {run_sec:.2f} seconds")
    return last_rgb, mp4_path, None, "\n".join(logs)


def run_unigaze_on_image(
    image_array: np.ndarray,
    hf_repo: str,
    ckpt_filename: str,
    cfg_path_user: str,
    extra_args: str = "",
) -> Tuple[Optional[np.ndarray], str]:
    """Run UniGaze on a single image.

    Args:
        image_array: the input image as delivered by ``gr.Image(type="numpy")``.
        hf_repo: Hub model repo to download the checkpoint from.
        ckpt_filename: checkpoint file name inside ``hf_repo``.
        cfg_path_user: model config path (absolute or relative).
        extra_args: unused; kept for interface compatibility.

    Returns:
        ``(output_image, logs)``. On any failure ``output_image`` is ``None``
        and ``logs`` carries the error message.
    """
    if image_array is None:
        return None, "[input] ERROR: no image provided"

    logs: List[str] = []
    t0 = time.time()
    rt, error = _load_runtime(hf_repo, ckpt_filename, cfg_path_user, logs)
    if rt is None:
        return None, error or ""

    try:
        out_rgb = rt.predict_image(image_array)
    except Exception as e:
        logs.append(f"[infer] ERROR: {e}")
        return None, "\n".join(logs)

    # Measured from function entry, so it includes download/model-load time.
    logs.append(f"[time] total runtime: {time.time() - t0:.2f} seconds")
    return out_rgb, "\n".join(logs)


# --------------------------------------------------------------------------------------
# UI
# --------------------------------------------------------------------------------------
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n{DESC}")

    with gr.Row():
        ckpt_file = gr.Dropdown(
            choices=DEFAULT_CKPT_FILE,
            value=DEFAULT_CKPT_FILE[0],
            label="Checkpoint filename",
        )
        cfg_choice = gr.Dropdown(
            choices=DEFAULT_CFGS,
            value=DEFAULT_CFGS[0],
            label="Model config",
        )

    # IMAGE TAB
    with gr.Tab("Image"):
        in_img = gr.Image(type="numpy", label="Input image")
        run_img = gr.Button("Run on Image", variant="primary")
        out_img = gr.Image(label="Output image")
        out_logs = gr.Textbox(label="Logs", interactive=False, lines=18)

        def ui_predict_image(image, ckpt, cfg_use):
            """Gradio callback: run the image pipeline with the default Hub repo."""
            return run_unigaze_on_image(
                image_array=image,
                hf_repo=DEFAULT_HF_REPO,
                ckpt_filename=ckpt,
                cfg_path_user=cfg_use,
            )

        run_img.click(
            fn=ui_predict_image,
            inputs=[in_img, ckpt_file, cfg_choice],
            outputs=[out_img, out_logs],
        )

        # Example image
        gr.Examples(
            examples=[
                [
                    "examples/The_Night_Watch_Frans_Banninck_Cocq.png",
                    DEFAULT_CKPT_FILE[0],
                    DEFAULT_CFGS[0],
                ]
            ],
            inputs=[in_img, ckpt_file, cfg_choice],
            outputs=[out_img, out_logs],
            fn=ui_predict_image,
            cache_examples=False,
        )

    # VIDEO TAB
    with gr.Tab("Video"):
        in_vid = gr.Video(label="Input video", sources=["upload"])
        run_vid = gr.Button("Run on Video", variant="primary")
        out_img_v = gr.Image(label="Annotated image (last frame)")
        out_vid_v = gr.Video(label="Output video")
        out_zip_v = gr.File(label="All artifacts as ZIP")  # always None now
        out_logs_v = gr.Textbox(label="Logs", interactive=False, lines=18)

        def ui_predict_video(video, ckpt, cfg_use):
            """Gradio callback: run the video pipeline with the default Hub repo."""
            return run_unigaze_on_video(
                video_path=video,
                hf_repo=DEFAULT_HF_REPO,
                ckpt_filename=ckpt,
                cfg_path_user=cfg_use,
            )

        run_vid.click(
            fn=ui_predict_video,
            inputs=[in_vid, ckpt_file, cfg_choice],
            outputs=[out_img_v, out_vid_v, out_zip_v, out_logs_v],
        )

if __name__ == "__main__":
    demo.launch()