File size: 2,661 Bytes
6308f07
744bf47
 
 
 
 
 
 
3cb4ae2
6308f07
744bf47
 
 
6308f07
8cc640f
 
 
744bf47
 
6308f07
8cc640f
 
 
6308f07
 
af9574e
744bf47
6308f07
8cc640f
 
 
 
 
6308f07
 
8cc640f
 
 
 
 
 
6308f07
744bf47
 
3cb4ae2
744bf47
8cc640f
 
744bf47
6308f07
 
8cc640f
 
 
 
 
 
 
6308f07
8cc640f
 
6308f07
8cc640f
 
6308f07
8cc640f
 
6308f07
 
 
8cc640f
6308f07
8cc640f
 
 
 
6308f07
8cc640f
 
6308f07
8cc640f
 
6308f07
8cc640f
 
744bf47
6308f07
8cc640f
6308f07
8cc640f
 
 
 
 
6308f07
8cc640f
 
 
 
 
6308f07
8cc640f
6308f07
8cc640f
 
 
 
 
 
6308f07
8cc640f
6308f07
 
8cc640f
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115

import os
import sys
import time
from pathlib import Path

import gradio as gr
import numpy as np
from huggingface_hub import hf_hub_download
from unigaze.infer_runtime import UniGazeRuntime

# Hugging Face Hub repository the checkpoints are downloaded from.
DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"

# Checkpoint filename -> model config path (relative; see resolve_cfg_abs).
MODEL_MAP = {
    "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze_l16_joint.pth.tar": "unigaze/configs/model/mae_l_16_gaze.yaml",
    "unigaze_b16_joint.pth.tar": "unigaze/configs/model/mae_b_16_gaze.yaml",
}

# Checkpoint filenames offered in the UI dropdown, in MODEL_MAP's
# insertion order (dicts preserve insertion order).
CHECKPOINTS = list(MODEL_MAP)

# Add the directory containing this file to the module search path.
# NOTE(review): this runs after the `unigaze` import at the top of the file,
# so it cannot influence that import -- confirm whether it is still needed.
sys.path.append(os.path.dirname(__file__))


def resolve_cfg_abs(cfg_str):
    """Return *cfg_str* as an absolute ``Path``.

    An already-absolute path is returned as given (it is not normalised).
    A relative path is joined onto the current working directory and then
    resolved.
    """
    candidate = Path(cfg_str)
    return (
        candidate
        if candidate.is_absolute()
        else Path.cwd().joinpath(candidate).resolve()
    )


def get_ckpt_path(repo_id, filename):
    """Download *filename* from the Hub model repository *repo_id*.

    Delegates to ``hf_hub_download`` with ``repo_type="model"`` and returns
    its result unchanged; callers use that value as the checkpoint path.
    Any exception raised by the download propagates to the caller.
    """
    download_args = {
        "repo_id": repo_id,
        "filename": filename,
        "repo_type": "model",
    }
    return hf_hub_download(**download_args)


from functools import lru_cache

@lru_cache(maxsize=3)
def get_runtime(cfg_abs, ckpt_path):
    """Build a CPU ``UniGazeRuntime`` for a config/checkpoint pair.

    Results are memoised by ``lru_cache`` (up to 3 distinct argument pairs),
    so repeated calls with the same arguments reuse the same runtime object.
    Both arguments must be hashable because they form the cache key.
    """
    runtime = UniGazeRuntime(cfg_abs, ckpt_path, device="cpu")
    return runtime


def run_image(image_array, ckpt_name):
    """Run the selected checkpoint on a single image.

    Args:
        image_array: Image handed over by the UI, or ``None`` when the user
            pressed "Run" without providing one.
        ckpt_name: Checkpoint filename; must be a key of ``MODEL_MAP``.

    Returns:
        A ``(output, log_text)`` tuple. ``output`` is whatever
        ``predict_image`` returns, or ``None`` when no image was supplied or
        the checkpoint download failed; ``log_text`` then carries the reason.

    Raises:
        KeyError: If ``ckpt_name`` is not a key of ``MODEL_MAP``.
        Exceptions raised while building the runtime or during inference are
        not caught here and propagate to the caller.
    """
    # Bail out before any download or model load when there is nothing to
    # process. Use an identity check: array truthiness is ambiguous.
    if image_array is None:
        return None, "No image provided."

    # perf_counter is monotonic, so the reported duration cannot be skewed
    # by wall-clock adjustments.
    started = time.perf_counter()

    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:
        # UI boundary: report download problems in the log box.
        return None, str(e)

    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])

    # str() keeps the cache key of get_runtime a plain, hashable string.
    rt = get_runtime(str(cfg_abs), ckpt_path)
    out = rt.predict_image(image_array)

    # The elapsed time covers download, runtime construction and inference.
    return out, f"time: {time.perf_counter()-started:.2f}s"


def run_video(video_path, ckpt_name):
    """Run the selected checkpoint on a video file.

    Args:
        video_path: Path of the video handed over by the UI; ``None`` or an
            empty string when the user pressed "Run" without providing one.
        ckpt_name: Checkpoint filename; must be a key of ``MODEL_MAP``.

    Returns:
        A 4-tuple ``(last_frame, out_video, None, log_text)``. The third slot
        is always ``None``. On missing input or a failed checkpoint download
        the first three slots are ``None`` and ``log_text`` carries the
        reason.

    Raises:
        KeyError: If ``ckpt_name`` is not a key of ``MODEL_MAP``.
        Exceptions raised while building the runtime or during inference are
        not caught here and propagate to the caller.
    """
    # Bail out before any download or model load when there is nothing to
    # process.
    if not video_path:
        return None, None, None, "No video provided."

    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:
        # UI boundary: report download problems in the log box.
        return None, None, None, str(e)

    cfg_abs = resolve_cfg_abs(MODEL_MAP[ckpt_name])

    # str() keeps the cache key of get_runtime a plain, hashable string.
    rt = get_runtime(str(cfg_abs), ckpt_path)
    out_video, last_frame, runtime = rt.predict_video(video_path)

    # `runtime` is the duration reported by predict_video itself.
    return last_frame, out_video, None, f"time: {runtime:.2f}s"


# Gradio UI: one checkpoint selector shared by an "Image" tab and a "Video" tab.
with gr.Blocks() as demo:

    # Checkpoint choice passed to both handlers; defaults to the second
    # entry of CHECKPOINTS.
    ckpt = gr.Dropdown(
        choices=CHECKPOINTS,
        value=CHECKPOINTS[1],
        label="Checkpoint"
    )

    with gr.Tab("Image"):
        # type="numpy": the handler receives the image as a numpy array.
        inp = gr.Image(type="numpy")
        btn = gr.Button("Run")
        out = gr.Image()
        logs = gr.Textbox()

        # run_image returns (output, log_text), matching [out, logs].
        btn.click(run_image, [inp, ckpt], [out, logs])

    with gr.Tab("Video"):
        vin = gr.Video()
        btn2 = gr.Button("Run")
        img_out = gr.Image()
        vout = gr.Video()
        logs2 = gr.Textbox()

        # run_video returns four values in this order: last frame, output
        # video, None, log text. The gr.File() component is created inline
        # here and only ever receives that None.
        # NOTE(review): the File output looks like an unused placeholder --
        # confirm before removing, since run_video's return arity must match.
        btn2.click(run_video, [vin, ckpt], [img_out, vout, gr.File(), logs2])


# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()