# Hugging Face Space application file (Space status when captured: sleeping).
# File size: 2,661 bytes.
import os
import sys
import time
from pathlib import Path
import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from unigaze.infer_runtime import UniGazeRuntime
# Hub repository that the checkpoint files are downloaded from.
DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"

# Checkpoint file name -> model config path (resolved later by
# resolve_cfg_abs). Insertion order matters: it defines the order of
# CHECKPOINTS below.
MODEL_MAP = {
    "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze_l16_joint.pth.tar": "unigaze/configs/model/mae_l_16_gaze.yaml",
    "unigaze_b16_joint.pth.tar": "unigaze/configs/model/mae_b_16_gaze.yaml",
}

# Selectable checkpoint names, in the same order as the MODEL_MAP keys.
CHECKPOINTS = list(MODEL_MAP)

# Make this file's directory importable.
# NOTE(review): this runs after the `unigaze` import at the top of the file,
# so it cannot influence that import -- confirm whether it is still needed.
sys.path.append(os.path.dirname(__file__))
def resolve_cfg_abs(cfg_str):
    """Return *cfg_str* as an absolute ``Path``.

    An already-absolute path is returned unchanged (it is not normalised).
    A relative path is joined onto the current working directory and then
    resolved.
    """
    candidate = Path(cfg_str)
    if not candidate.is_absolute():
        return Path.cwd().joinpath(candidate).resolve()
    return candidate
def get_ckpt_path(repo_id, filename):
    """Fetch *filename* from the model repository *repo_id*.

    Delegates to ``hf_hub_download`` with ``repo_type="model"`` and returns
    its result unchanged; any exception it raises propagates to the caller.
    """
    request = {
        "repo_id": repo_id,
        "filename": filename,
        "repo_type": "model",
    }
    return hf_hub_download(**request)
from functools import lru_cache


@lru_cache(maxsize=3)
def get_runtime(cfg_abs, ckpt_path):
    """Return a ``UniGazeRuntime`` for the given config and checkpoint.

    Results are memoised by ``lru_cache`` (up to 3 entries), so repeated calls
    with the same arguments reuse the same runtime object. Both arguments must
    be hashable because they form the cache key. The runtime is always created
    with ``device="cpu"``.
    """
    runtime = UniGazeRuntime(cfg_abs, ckpt_path, device="cpu")
    return runtime
def run_image(image_array, ckpt_name):
    """Run gaze prediction on a single image and report the elapsed time.

    Args:
        image_array: The input image, or ``None`` when no image was supplied.
        ckpt_name: Checkpoint file name; must be a key of ``MODEL_MAP``.

    Returns:
        A ``(output, log_text)`` tuple. ``output`` is whatever
        ``predict_image`` returns. It is ``None`` when no image was supplied
        or the checkpoint download failed, in which case ``log_text`` carries
        the reason.
    """
    # Guard first: without an image there is nothing to predict, and we avoid
    # a pointless checkpoint download / model load.
    if image_array is None:
        return None, "No input image provided."

    # perf_counter is monotonic, so the elapsed time cannot be skewed by
    # system clock adjustments.
    started = time.perf_counter()
    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:
        # Report download failures in the log output instead of raising.
        return None, str(e)

    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])
    # The runtime cache is keyed on its arguments, hence the str() conversion.
    rt = get_runtime(str(cfg_abs), ckpt_path)
    out = rt.predict_image(image_array)

    logs = [f"time: {time.perf_counter() - started:.2f}s"]
    return out, "\n".join(logs)
def run_video(video_path, ckpt_name):
    """Run gaze prediction on a video and report the processing time.

    Args:
        video_path: Path of the input video, or ``None``/empty when no video
            was supplied.
        ckpt_name: Checkpoint file name; must be a key of ``MODEL_MAP``.

    Returns:
        A ``(last_frame, out_video, None, log_text)`` tuple. The third slot is
        always ``None``. When no video was supplied or the checkpoint download
        failed, the first three slots are ``None`` and ``log_text`` carries
        the reason.
    """
    # Guard first: without a video there is nothing to process, and we avoid
    # a pointless checkpoint download / model load.
    if not video_path:
        return None, None, None, "No input video provided."

    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:
        # Report download failures in the log output instead of raising.
        return None, None, None, str(e)

    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])
    # The runtime cache is keyed on its arguments, hence the str() conversion.
    rt = get_runtime(str(cfg_abs), ckpt_path)
    out_video, last_frame, runtime = rt.predict_video(video_path)

    logs = [f"time: {runtime:.2f}s"]
    return last_frame, out_video, None, "\n".join(logs)
# UI definition: one checkpoint selector above an "Image" tab and a "Video"
# tab. Components are created in the same order as before, since creation
# order determines layout.
with gr.Blocks() as demo:
    # Checkpoint selector used by both tabs; defaults to the second entry.
    ckpt = gr.Dropdown(
        choices=CHECKPOINTS,
        value=CHECKPOINTS[1],
        label="Checkpoint",
    )

    with gr.Tab("Image"):
        inp = gr.Image(type="numpy")
        btn = gr.Button("Run")
        out = gr.Image()
        logs = gr.Textbox()
        btn.click(fn=run_image, inputs=[inp, ckpt], outputs=[out, logs])

    with gr.Tab("Video"):
        vin = gr.Video()
        btn2 = gr.Button("Run")
        img_out = gr.Image()
        vout = gr.Video()
        logs2 = gr.Textbox()
        # Third output slot of run_video; that function always returns None
        # for it.
        file_out = gr.File()
        btn2.click(
            fn=run_video,
            inputs=[vin, ckpt],
            outputs=[img_out, vout, file_out, logs2],
        )

if __name__ == "__main__":
    demo.launch()