| |
| import os |
| import sys |
| import time |
| from pathlib import Path |
| from typing import List, Optional, Tuple |
|
|
| import gradio as gr |
| import numpy as np |
| from huggingface_hub import hf_hub_download, HfApi |
| from unigaze.infer_runtime import UniGazeRuntime |
|
|
| |
| |
| |
# --- Hugging Face Hub source for the model weights -------------------------
DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"
DEFAULT_REVISION = "main"

# Checkpoint filenames offered in the UI; the first entry is the default.
DEFAULT_CKPT_FILE = [
    "unigaze_h14_joint.pth.tar",
    "unigaze_l16_joint.pth.tar",
    "unigaze_b16_joint.pth.tar",
]

# Model config files offered in the UI; the first entry is the default.
DEFAULT_CFGS = [
    "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze/configs/model/mae_L_16_gaze.yaml",
    "unigaze/configs/model/mae_b_16_gaze.yaml",
]

# --- Page text -------------------------------------------------------------
TITLE = "UniGaze Demo (Video + Image)"
DESC = """
Upload a short video or a single image. The app downloads a checkpoint from the Hub,
runs UniGaze in-process (no subprocess, no permanent writes), and returns results.
"""

# Put the directory containing this file on the import path.
_APP_DIR = os.path.dirname(__file__)
sys.path.append(_APP_DIR)
|
|
| |
| |
| |
def resolve_cfg_abs(cfg_str: str) -> Path:
    """Return an absolute path to the YAML config.

    Resolution order:
      1. An absolute path is used as-is (it must exist).
      2. A relative path is resolved against the current working directory.
      3. A relative path beginning with ``configs/`` is additionally tried
         under ``<cwd>/unigaze/``.

    Args:
        cfg_str: Absolute or relative path to the config file.

    Returns:
        The first existing candidate, as an absolute ``Path``.

    Raises:
        FileNotFoundError: If no candidate exists. For relative paths the
            message lists every location that was tried.
    """
    cfg = Path(cfg_str)
    if cfg.is_absolute():
        if cfg.exists():
            return cfg
        raise FileNotFoundError(f"Config not found: {cfg}")

    cwd = Path.cwd()
    candidates = [(cwd / cfg).resolve()]
    # as_posix() keeps the prefix check working when the platform's path
    # separator is a backslash.
    if cfg.as_posix().startswith("configs/"):
        candidates.append((cwd / "unigaze" / cfg).resolve())

    for candidate in candidates:
        if candidate.exists():
            return candidate

    # Report every location that was checked, not just the first one.
    tried = ", ".join(str(c) for c in candidates)
    raise FileNotFoundError(f"Config not found. Tried: {tried}")
|
|
def list_weight_files(repo_id: str, revision: str = "main") -> List[str]:
    """List files in a Hub model repo whose names look like weight files.

    This is best-effort: any failure while querying or filtering the
    listing yields an empty list instead of raising.

    Args:
        repo_id: Hub model repository id.
        revision: Revision of the repository to list.

    Returns:
        Repo file names ending (case-insensitively) in a known weight
        suffix, or ``[]`` if the listing could not be obtained.
    """
    weight_suffixes = (".pth", ".pt", ".safetensors", ".tar", ".pth.tar")
    try:
        repo_files = HfApi().list_repo_files(
            repo_id=repo_id,
            repo_type="model",
            revision=revision,
        )
        matches = []
        for name in repo_files:
            if name.lower().endswith(weight_suffixes):
                matches.append(name)
        return matches
    except Exception:
        return []
|
|
def get_ckpt_path(repo_id: str, filename: str, revision: str = "main") -> str:
    """Download a checkpoint from the Hub and return the downloader's result.

    The repo listing is consulted first. If it is non-empty and does not
    contain ``filename``, the download is not attempted. An empty listing
    (including a failed lookup) skips this check.

    Args:
        repo_id: Hub model repository id.
        filename: Name of the checkpoint file inside the repo.
        revision: Revision of the repository to use.

    Returns:
        Whatever ``hf_hub_download`` returns for the requested file.

    Raises:
        FileNotFoundError: If the listing is available and lacks ``filename``.
    """
    available = list_weight_files(repo_id, revision)
    known_missing = bool(available) and filename not in available
    if known_missing:
        raise FileNotFoundError(
            f"File '{filename}' not found in model repo '{repo_id}' at rev '{revision}'. "
            f"Available weights: {available}"
        )

    local_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        revision=revision,
        repo_type="model",
    )
    return local_path
|
|
| |
from functools import lru_cache


@lru_cache(maxsize=3)
def get_runtime(cfg_abs_str: str, ckpt_path: str, device: str = "cpu") -> UniGazeRuntime:
    """Build a ``UniGazeRuntime``, memoised on its arguments.

    Up to three distinct (config, checkpoint, device) combinations are kept
    by the LRU cache, so repeated requests with the same selection reuse the
    same runtime object instead of constructing a new one.
    """
    runtime = UniGazeRuntime(cfg_abs_str, ckpt_path, device=device)
    return runtime
|
|
| |
| |
| |
def run_unigaze_on_video(
    video_path: str,
    hf_repo: str,
    ckpt_filename: str,
    cfg_path_user: str,
    extra_args: str = "",
) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
    """Run UniGaze on a video and return the results plus a log string.

    Every stage (input check, checkpoint download, config resolution,
    inference) reports failure through the returned log text rather than by
    raising, and log lines collected before a failure are preserved.

    Args:
        video_path: Path of the input video; empty/None means nothing was
            provided.
        hf_repo: Hub model repository id holding the checkpoint.
        ckpt_filename: Checkpoint file name inside ``hf_repo``.
        cfg_path_user: Model config path, resolved via ``resolve_cfg_abs``.
        extra_args: Accepted for interface compatibility; currently unused.

    Returns:
        ``(last_rgb, mp4_path, None, logs)``. The first two values come from
        the runtime's ``predict_video``; the third slot is always ``None``.
        On any failure the first three values are ``None`` and ``logs``
        ends with the error line.
    """
    logs: List[str] = []

    def _fail(message: str) -> Tuple[None, None, None, str]:
        # Keep whatever was logged so far and append the error as the last line.
        logs.append(message)
        return None, None, None, "\n".join(logs)

    # Nothing to process (e.g. the button was pressed without an upload).
    if not video_path:
        return _fail("[input] ERROR: no video provided")

    try:
        ckpt_path = get_ckpt_path(hf_repo, ckpt_filename, revision=DEFAULT_REVISION)
    except Exception as e:
        return _fail(f"[hub] ERROR: {e}")
    logs.append(f"[hub] downloaded: {ckpt_path}")

    try:
        cfg_abs = resolve_cfg_abs(cfg_path_user)
    except Exception as e:
        return _fail(f"[cfg] {e}")

    # Model construction and inference can fail (for instance on a
    # checkpoint/config mismatch); surface that in the logs like the
    # earlier stages do.
    try:
        rt = get_runtime(str(cfg_abs), ckpt_path, device="cpu")
        mp4_path, last_rgb, run_sec = rt.predict_video(video_path)
    except Exception as e:
        return _fail(f"[run] ERROR: {e}")

    logs.append(f"[time] total runtime: {run_sec:.2f} seconds")
    return last_rgb, mp4_path, None, "\n".join(logs)
|
|
def run_unigaze_on_image(
    image_array: np.ndarray,
    hf_repo: str,
    ckpt_filename: str,
    cfg_path_user: str,
    extra_args: str = "",
) -> Tuple[Optional[np.ndarray], str]:
    """Run UniGaze on a single image and return the result plus a log string.

    Every stage (input check, checkpoint download, config resolution,
    inference) reports failure through the returned log text rather than by
    raising, and log lines collected before a failure are preserved.

    Args:
        image_array: Input image; ``None`` means nothing was provided.
        hf_repo: Hub model repository id holding the checkpoint.
        ckpt_filename: Checkpoint file name inside ``hf_repo``.
        cfg_path_user: Model config path, resolved via ``resolve_cfg_abs``.
        extra_args: Accepted for interface compatibility; currently unused.

    Returns:
        ``(out_rgb, logs)`` where ``out_rgb`` is the value returned by the
        runtime's ``predict_image``. On any failure ``out_rgb`` is ``None``
        and ``logs`` ends with the error line.
    """
    logs: List[str] = []
    t0 = time.time()

    def _fail(message: str) -> Tuple[None, str]:
        # Keep whatever was logged so far and append the error as the last line.
        logs.append(message)
        return None, "\n".join(logs)

    # Explicit None check: truth-testing a numpy array is ambiguous.
    if image_array is None:
        return _fail("[input] ERROR: no image provided")

    try:
        ckpt_path = get_ckpt_path(hf_repo, ckpt_filename, revision=DEFAULT_REVISION)
    except Exception as e:
        return _fail(f"[hub] ERROR: {e}")
    logs.append(f"[hub] downloaded: {ckpt_path}")

    try:
        cfg_abs = resolve_cfg_abs(cfg_path_user)
    except Exception as e:
        return _fail(f"[cfg] {e}")

    # Model construction and inference can fail (for instance on a
    # checkpoint/config mismatch); surface that in the logs like the
    # earlier stages do.
    try:
        rt = get_runtime(str(cfg_abs), ckpt_path, device="cpu")
        out_rgb = rt.predict_image(image_array)
    except Exception as e:
        return _fail(f"[run] ERROR: {e}")

    logs.append(f"[time] total runtime: {time.time() - t0:.2f} seconds")
    return out_rgb, "\n".join(logs)
|
|
| |
| |
| |
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n{DESC}")

    # Model selection shared by the Image and Video tabs.
    with gr.Row():
        ckpt_file = gr.Dropdown(
            choices=DEFAULT_CKPT_FILE,
            value=DEFAULT_CKPT_FILE[0],
            label="Checkpoint filename",
        )
        cfg_choice = gr.Dropdown(
            choices=DEFAULT_CFGS,
            value=DEFAULT_CFGS[0],
            label="Model config",
        )

    # ---- Single-image inference ----
    with gr.Tab("Image"):
        in_img = gr.Image(type="numpy", label="Input image")
        run_img = gr.Button("Run on Image", variant="primary")
        out_img = gr.Image(label="Output image")
        out_logs = gr.Textbox(label="Logs", interactive=False, lines=18)

        def ui_predict_image(image, ckpt, cfg_use):
            """Bridge the Image tab inputs to ``run_unigaze_on_image``."""
            return run_unigaze_on_image(
                image_array=image,
                hf_repo=DEFAULT_HF_REPO,
                ckpt_filename=ckpt,
                cfg_path_user=cfg_use,
            )

        run_img.click(
            fn=ui_predict_image,
            inputs=[in_img, ckpt_file, cfg_choice],
            outputs=[out_img, out_logs],
        )

        # One bundled example; results are computed on demand, not cached.
        gr.Examples(
            examples=[
                [
                    "examples/The_Night_Watch_Frans_Banninck_Cocq.png",
                    DEFAULT_CKPT_FILE[0],
                    DEFAULT_CFGS[0],
                ]
            ],
            inputs=[in_img, ckpt_file, cfg_choice],
            outputs=[out_img, out_logs],
            fn=ui_predict_image,
            cache_examples=False,
        )

    # ---- Video inference ----
    with gr.Tab("Video"):
        in_vid = gr.Video(label="Input video", sources=["upload"])
        run_vid = gr.Button("Run on Video", variant="primary")
        out_img_v = gr.Image(label="Annotated image (last frame)")
        out_vid_v = gr.Video(label="Output video")
        out_zip_v = gr.File(label="All artifacts as ZIP")
        out_logs_v = gr.Textbox(label="Logs", interactive=False, lines=18)

        def ui_predict_video(video, ckpt, cfg_use):
            """Bridge the Video tab inputs to ``run_unigaze_on_video``."""
            return run_unigaze_on_video(
                video_path=video,
                hf_repo=DEFAULT_HF_REPO,
                ckpt_filename=ckpt,
                cfg_path_user=cfg_use,
            )

        run_vid.click(
            fn=ui_predict_video,
            inputs=[in_vid, ckpt_file, cfg_choice],
            outputs=[out_img_v, out_vid_v, out_zip_v, out_logs_v],
        )


# Start the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()