"""Gradio demo that runs a ``UniGazeRuntime`` on an uploaded image or video.

Checkpoints are fetched from the Hugging Face Hub via ``hf_hub_download`` and
constructed runtimes are memoised, so switching back to a previously used
checkpoint does not rebuild the model.
"""

import os
import sys
import time
from functools import lru_cache
from pathlib import Path
from typing import Any, Optional, Tuple

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download

# Directory containing this script. The relative config paths in MODEL_MAP
# are looked up here when they cannot be found under the working directory.
_APP_DIR = Path(__file__).resolve().parent

# The path tweak must run *before* the ``unigaze`` import to have any effect
# on it; previously it ran afterwards.
# NOTE(review): assumes the ``unigaze`` package lives next to this file.
if str(_APP_DIR) not in sys.path:
    sys.path.append(str(_APP_DIR))

from unigaze.infer_runtime import UniGazeRuntime  # noqa: E402

DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"

# Checkpoint file names offered in the UI dropdown.
CHECKPOINTS = [
    "unigaze_h14_joint.pth.tar",
    "unigaze_l16_joint.pth.tar",
    "unigaze_b16_joint.pth.tar",
]

# Checkpoint file name -> model config path (relative to the repo root).
MODEL_MAP = {
    "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze_l16_joint.pth.tar": "unigaze/configs/model/mae_l_16_gaze.yaml",
    "unigaze_b16_joint.pth.tar": "unigaze/configs/model/mae_b_16_gaze.yaml",
}


def resolve_cfg_abs(cfg_str):
    """Return an absolute ``Path`` for the config path *cfg_str*.

    Absolute inputs are returned unchanged. Relative inputs are resolved
    against the current working directory first (the original behaviour);
    if nothing exists there, the directory of this script is tried so the
    app also works when launched from another directory. When neither
    candidate exists, the working-directory candidate is returned so the
    caller reports the same path as before.
    """
    cfg = Path(cfg_str)
    if cfg.is_absolute():
        return cfg

    from_cwd = (Path.cwd() / cfg).resolve()
    if from_cwd.exists():
        return from_cwd

    from_app = (_APP_DIR / cfg).resolve()
    return from_app if from_app.exists() else from_cwd


def get_ckpt_path(repo_id, filename):
    """Download *filename* from the model repo *repo_id* and return its path.

    Thin wrapper around ``hf_hub_download`` with ``repo_type="model"``; any
    exception raised by the download propagates to the caller.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
    )


@lru_cache(maxsize=3)
def get_runtime(cfg_abs, ckpt_path):
    """Return a CPU ``UniGazeRuntime`` for the given config and checkpoint.

    Results are memoised on ``(cfg_abs, ckpt_path)``, so both arguments must
    be hashable (callers pass strings). At most three runtimes are kept.
    """
    return UniGazeRuntime(cfg_abs, ckpt_path, device="cpu")


def _runtime_for(ckpt_name, ckpt_path):
    """Look up the config for *ckpt_name* and return the cached runtime."""
    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])
    return get_runtime(str(cfg_abs), ckpt_path)


def run_image(
    image_array: Optional[np.ndarray], ckpt_name: str
) -> Tuple[Any, str]:
    """Gradio callback: run the selected checkpoint on a single image.

    Args:
        image_array: Image from ``gr.Image(type="numpy")``; ``None`` when the
            user has not provided one.
        ckpt_name: Checkpoint file name; expected to be a key of MODEL_MAP.

    Returns:
        ``(output, log_text)``. ``output`` is whatever
        ``UniGazeRuntime.predict_image`` returns, or ``None`` when the input
        is missing, the checkpoint is unknown, or the download fails (the
        log text then carries the reason).
    """
    if image_array is None:
        return None, "No image provided."
    if ckpt_name not in MODEL_MAP:
        return None, f"Unknown checkpoint: {ckpt_name!r}"

    # The reported time deliberately includes checkpoint download and model
    # construction, as before.
    started = time.perf_counter()
    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as exc:  # UI boundary: surface the failure as log text.
        return None, str(exc)

    runtime = _runtime_for(ckpt_name, ckpt_path)
    result = runtime.predict_image(image_array)
    return result, f"time: {time.perf_counter() - started:.2f}s"


def run_video(
    video_path: Optional[str], ckpt_name: str
) -> Tuple[Any, Any, None, str]:
    """Gradio callback: run the selected checkpoint on a video file.

    Args:
        video_path: Value of the ``gr.Video`` input; falsy when the user has
            not provided a video.
        ckpt_name: Checkpoint file name; expected to be a key of MODEL_MAP.

    Returns:
        ``(last_frame, out_video, None, log_text)``. The third slot feeds the
        ``gr.File`` output and is always ``None``. On missing input, unknown
        checkpoint, or download failure the first three slots are ``None``
        and the log text carries the reason.
    """
    if not video_path:
        return None, None, None, "No video provided."
    if ckpt_name not in MODEL_MAP:
        return None, None, None, f"Unknown checkpoint: {ckpt_name!r}"

    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as exc:  # UI boundary: surface the failure as log text.
        return None, None, None, str(exc)

    runtime = _runtime_for(ckpt_name, ckpt_path)
    # The elapsed time here is the value reported by the runtime itself.
    out_video, last_frame, elapsed = runtime.predict_video(video_path)
    return last_frame, out_video, None, f"time: {elapsed:.2f}s"


with gr.Blocks() as demo:
    ckpt = gr.Dropdown(
        choices=CHECKPOINTS,
        value=CHECKPOINTS[1],
        label="Checkpoint",
    )

    with gr.Tab("Image"):
        inp = gr.Image(type="numpy")
        btn = gr.Button("Run")
        out = gr.Image()
        logs = gr.Textbox()
        btn.click(run_image, [inp, ckpt], [out, logs])

    with gr.Tab("Video"):
        vin = gr.Video()
        btn2 = gr.Button("Run")
        img_out = gr.Image()
        vout = gr.Video()
        logs2 = gr.Textbox()
        # Named instead of being created inline in the click() call; it is
        # still rendered after the log box, as before.
        file_out = gr.File()
        btn2.click(run_video, [vin, ckpt], [img_out, vout, file_out, logs2])


if __name__ == "__main__":
    demo.launch()