Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import time | |
| from pathlib import Path | |
| import gradio as gr | |
| import numpy as np | |
| from huggingface_hub import hf_hub_download | |
| from unigaze.infer_runtime import UniGazeRuntime | |
# Hugging Face Hub model repository that hosts the UniGaze checkpoints.
DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"

# Checkpoint filename -> model config YAML (relative path; turned into an
# absolute path by resolve_cfg_abs before use).
MODEL_MAP = {
    "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze_l16_joint.pth.tar": "unigaze/configs/model/mae_l_16_gaze.yaml",
    "unigaze_b16_joint.pth.tar": "unigaze/configs/model/mae_b_16_gaze.yaml",
}

# Selectable checkpoints, derived from MODEL_MAP (insertion order preserved)
# so that every choice offered in the UI is guaranteed to have a config —
# the request handlers index MODEL_MAP with the selected name.
CHECKPOINTS = list(MODEL_MAP)
# Make modules that live next to this script importable by later imports.
# abspath() matters: dirname(__file__) can be '' for a relative __file__,
# and '' on sys.path means "current working directory", not this directory.
# NOTE: the `unigaze` import at the top of the file executes before this
# line, so this cannot have helped resolve that particular import.
_app_dir = os.path.dirname(os.path.abspath(__file__))
if _app_dir not in sys.path:
    sys.path.append(_app_dir)
def resolve_cfg_abs(cfg_str, base_dir=None):
    """Return an absolute ``Path`` for the model config *cfg_str*.

    Absolute inputs are returned unchanged. A relative input is resolved:

    * against *base_dir* when one is given;
    * otherwise against the current working directory, if the file exists
      there (the original behavior);
    * otherwise against the directory containing this script, so the
      relative ``unigaze/configs/...`` paths (presumably shipped next to
      this script) still resolve when the app is launched from elsewhere.

    Args:
        cfg_str: Config path as a string or path-like object.
        base_dir: Optional directory to resolve relative paths against.

    Returns:
        pathlib.Path: The resolved path. It is not guaranteed to exist.
    """
    p = Path(cfg_str)
    if p.is_absolute():
        return p
    if base_dir is not None:
        return (Path(base_dir) / p).resolve()
    from_cwd = (Path.cwd() / p).resolve()
    if from_cwd.exists():
        return from_cwd
    return (Path(__file__).resolve().parent / p).resolve()
def get_ckpt_path(repo_id, filename):
    """Fetch *filename* from the Hub model repository *repo_id*.

    Thin wrapper over ``hf_hub_download`` that pins ``repo_type="model"``;
    returns whatever local path ``hf_hub_download`` reports.
    """
    download_kwargs = {
        "repo_id": repo_id,
        "filename": filename,
        "repo_type": "model",
    }
    local_path = hf_hub_download(**download_kwargs)
    return local_path
from functools import lru_cache


# Memoize on (cfg_abs, ckpt_path) so repeated requests reuse the same
# runtime instead of constructing a new one (and, presumably, reloading
# model weights) on every click. maxsize=1 keeps at most one runtime alive:
# selecting a different checkpoint evicts the previous one, which bounds
# memory on a CPU-only host.
@lru_cache(maxsize=1)
def get_runtime(cfg_abs, ckpt_path):
    """Return a CPU ``UniGazeRuntime`` for the given config and checkpoint.

    Args:
        cfg_abs: Model config path. Must be hashable because results are
            cached (callers in this file pass ``str``).
        ckpt_path: Local checkpoint path (hashable, e.g. ``str``).
    """
    return UniGazeRuntime(cfg_abs, ckpt_path, device="cpu")
def _report_error(exc):
    """Print *exc*'s traceback to stderr and return a short message for the UI."""
    import traceback

    traceback.print_exception(type(exc), exc, exc.__traceback__)
    return f"{type(exc).__name__}: {exc}"


def _load_runtime(ckpt_name):
    """Download *ckpt_name*, look up its config and return the runtime for it.

    Shared by the image and video handlers. Errors (download failures,
    unknown checkpoint names, model construction failures) propagate.
    """
    ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])
    return get_runtime(str(cfg_abs), ckpt_path)


def run_image(image_array, ckpt_name):
    """Gradio handler: run gaze estimation on a single image.

    Args:
        image_array: Image from ``gr.Image(type="numpy")``; ``None`` when
            the user clicks Run without uploading anything.
        ckpt_name: Checkpoint filename selected in the dropdown.

    Returns:
        tuple: ``(output_image, log_text)``. On failure the output image is
        ``None`` and ``log_text`` carries the error message.
    """
    if image_array is None:
        return None, "No image provided."
    t0 = time.time()
    try:
        rt = _load_runtime(ckpt_name)
        out = rt.predict_image(image_array)
    except Exception as e:  # UI boundary: surface the error in the log box
        return None, _report_error(e)
    # Wall-clock time includes download/model lookup, not just inference.
    return out, f"time: {time.time() - t0:.2f}s"


def run_video(video_path, ckpt_name):
    """Gradio handler: run gaze estimation on a video.

    Args:
        video_path: Value from ``gr.Video()``; empty/``None`` when no video
            was uploaded.
        ckpt_name: Checkpoint filename selected in the dropdown.

    Returns:
        tuple: ``(last_frame, out_video, download, log_text)``, matching the
        four outputs wired to the video tab's Run button. On failure the
        first three items are ``None`` and ``log_text`` holds the error.
    """
    if not video_path:
        return None, None, None, "No video provided."
    try:
        rt = _load_runtime(ckpt_name)
        out_video, last_frame, runtime = rt.predict_video(video_path)
    except Exception as e:  # UI boundary: surface the error in the log box
        return None, None, None, _report_error(e)
    # Offer the rendered video in the File (download) slot when
    # predict_video hands back a filesystem path; otherwise leave it empty
    # as before. NOTE(review): confirm predict_video's return types.
    download = out_video if isinstance(out_video, (str, os.PathLike)) else None
    return last_frame, out_video, download, f"time: {runtime:.2f}s"
with gr.Blocks() as demo:
    # Checkpoint selector shared by both tabs; each Run button passes its
    # current value to the handler.
    ckpt = gr.Dropdown(
        choices=CHECKPOINTS,
        value=CHECKPOINTS[1],
        label="Checkpoint",
    )

    with gr.Tab("Image"):
        inp = gr.Image(type="numpy", label="Input image")
        btn = gr.Button("Run")
        out = gr.Image(label="Output")
        logs = gr.Textbox(label="Logs")
        btn.click(run_image, [inp, ckpt], [out, logs])

    with gr.Tab("Video"):
        vin = gr.Video(label="Input video")
        btn2 = gr.Button("Run")
        img_out = gr.Image(label="Last frame")
        vout = gr.Video(label="Output video")
        logs2 = gr.Textbox(label="Logs")
        # Declared explicitly (previously created anonymously inside the
        # .click() call) so it is named and labelled; it is still rendered
        # in the same place, after the log box. It receives run_video's
        # third return value.
        vfile = gr.File(label="Download")
        btn2.click(run_video, [vin, ckpt], [img_out, vout, vfile, logs2])


if __name__ == "__main__":
    demo.launch()