# Scraped Hugging Face page header (not code):
# saptak21 — "Update app.py" — commit 8cc640f (verified)
import os
import sys
import time
from pathlib import Path
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from unigaze.infer_runtime import UniGazeRuntime
# Hugging Face model repo the checkpoint files are downloaded from.
DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"

# Checkpoint filename (in DEFAULT_HF_REPO) -> model config YAML. The config
# paths are relative and are turned into absolute paths by resolve_cfg_abs.
MODEL_MAP = {
    "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze_l16_joint.pth.tar": "unigaze/configs/model/mae_l_16_gaze.yaml",
    "unigaze_b16_joint.pth.tar": "unigaze/configs/model/mae_b_16_gaze.yaml",
}

# Dropdown choices, derived from MODEL_MAP so the two lists cannot drift
# apart. Dicts preserve insertion order, so this is [h14, l16, b16].
CHECKPOINTS = list(MODEL_MAP)

# NOTE(review): this executes after `unigaze` has already been imported at the
# top of the file, so it cannot affect that import; it only matters for imports
# performed later. When launched as `python app.py`, Python already puts the
# script's directory on sys.path.
sys.path.append(os.path.dirname(__file__))
def resolve_cfg_abs(cfg_str):
    """Resolve a model-config path to an absolute ``Path``.

    Absolute inputs are returned unchanged. Relative inputs are resolved
    against the current working directory first (the original behaviour).
    If nothing exists there, they are resolved against the directory that
    contains this file instead. This lets relative config paths (such as the
    ``unigaze/configs/...`` values in MODEL_MAP, assuming they ship next to
    this script) still be found when the app is launched from a different
    working directory.

    Args:
        cfg_str: Path to a config file, as ``str`` or ``os.PathLike``.

    Returns:
        An absolute ``pathlib.Path``. Existence is not required. When neither
        candidate exists, the cwd-based path is returned, so any later
        "file not found" error names the same location as before.
    """
    p = Path(cfg_str)
    if p.is_absolute():
        return p

    from_cwd = (Path.cwd() / p).resolve()
    if from_cwd.exists():
        return from_cwd

    # Fallback: relative to this script, independent of the launch directory.
    from_app_dir = (Path(__file__).resolve().parent / p).resolve()
    if from_app_dir.exists():
        return from_app_dir

    return from_cwd
def get_ckpt_path(repo_id, filename):
    """Fetch ``filename`` from the Hugging Face model repo ``repo_id``.

    Thin wrapper around ``hf_hub_download`` with ``repo_type="model"``.
    Returns the local path it reports. Download errors propagate to
    the caller.
    """
    download_kwargs = dict(repo_id=repo_id, filename=filename, repo_type="model")
    local_path = hf_hub_download(**download_kwargs)
    return local_path
from functools import lru_cache
@lru_cache(maxsize=3)
def get_runtime(cfg_abs, ckpt_path):
    """Return a CPU ``UniGazeRuntime`` for the given config/checkpoint pair.

    Results are memoised with an LRU cache holding up to three entries, one
    per checkpoint currently offered, so repeated requests reuse the already
    constructed runtime instead of rebuilding it. Arguments must be hashable.
    """
    runtime = UniGazeRuntime(cfg_abs, ckpt_path, device="cpu")
    return runtime
def run_image(image_array, ckpt_name):
    """Gradio handler: run UniGaze on a single image.

    Args:
        image_array: Value of the ``gr.Image(type="numpy")`` input, or
            ``None`` when Run is pressed without an uploaded image.
        ckpt_name: Checkpoint filename chosen in the dropdown; expected to be
            a key of ``MODEL_MAP``.

    Returns:
        ``(output, log_text)``. ``output`` is whatever
        ``UniGazeRuntime.predict_image`` returns, or ``None`` when the request
        cannot be served. In that case ``log_text`` holds the reason instead
        of the elapsed-time line.
    """
    # Guard before any network/model work: otherwise None reaches the model.
    if image_array is None:
        return None, "No image provided."
    # Validate before downloading; an unknown name would otherwise be fetched
    # and then fail with a KeyError on MODEL_MAP.
    if ckpt_name not in MODEL_MAP:
        return None, f"Unknown checkpoint: {ckpt_name!r}"

    # Wall-clock total, including download and (first-time) model load.
    t0 = time.perf_counter()
    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:  # surface download failures in the log box
        return None, str(e)

    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])
    rt = get_runtime(str(cfg_abs), ckpt_path)
    out = rt.predict_image(image_array)
    return out, f"time: {time.perf_counter() - t0:.2f}s"
def run_video(video_path, ckpt_name):
    """Gradio handler: run UniGaze on a video.

    Args:
        video_path: Value of the ``gr.Video`` input. Empty or ``None`` when
            Run is pressed without an uploaded video.
        ckpt_name: Checkpoint filename chosen in the dropdown; expected to be
            a key of ``MODEL_MAP``.

    Returns:
        A 4-tuple matching the UI outputs:
        ``(last_frame, out_video, None, log_text)``. The third slot is always
        ``None``. On failure the first three are ``None`` and ``log_text``
        holds the reason.

    Note:
        The reported time is the ``runtime`` value returned by
        ``predict_video``. Unlike ``run_image``, it does not include the
        checkpoint download or model construction.
    """
    if not video_path:
        return None, None, None, "No video provided."
    # Validate before downloading; an unknown name would otherwise be fetched
    # and then fail with a KeyError on MODEL_MAP.
    if ckpt_name not in MODEL_MAP:
        return None, None, None, f"Unknown checkpoint: {ckpt_name!r}"

    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:  # surface download failures in the log box
        return None, None, None, str(e)

    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])
    rt = get_runtime(str(cfg_abs), ckpt_path)
    out_video, last_frame, runtime = rt.predict_video(video_path)
    return last_frame, out_video, None, f"time: {runtime:.2f}s"
# Gradio UI: one checkpoint selector shared by an Image tab and a Video tab.
with gr.Blocks() as demo:
    ckpt = gr.Dropdown(
        choices=CHECKPOINTS,
        value=CHECKPOINTS[1],
        label="Checkpoint",
    )

    with gr.Tab("Image"):
        inp = gr.Image(type="numpy", label="Input image")
        btn = gr.Button("Run")
        out = gr.Image(label="Output image")
        logs = gr.Textbox(label="Log")
        btn.click(run_image, [inp, ckpt], [out, logs])

    with gr.Tab("Video"):
        vin = gr.Video(label="Input video")
        btn2 = gr.Button("Run")
        img_out = gr.Image(label="Last frame")
        vout = gr.Video(label="Output video")
        logs2 = gr.Textbox(label="Log")
        # run_video returns four values and its third is always None, so this
        # component stays empty. It is declared explicitly (instead of an
        # anonymous gr.File() inside .click) so it is named and labeled in the
        # layout.
        vfile = gr.File(label="Output file")
        btn2.click(run_video, [vin, ckpt], [img_out, vout, vfile, logs2])

if __name__ == "__main__":
    demo.launch()