saptak21 committed on
Commit
6308f07
·
verified ·
1 Parent(s): 24a1a93

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +124 -173
app.py CHANGED
@@ -1,4 +1,5 @@
1
  # app.py
 
2
  import os
3
  import sys
4
  import time
@@ -8,203 +9,153 @@ from typing import List, Optional, Tuple
8
  import gradio as gr
9
  import numpy as np
10
  from huggingface_hub import hf_hub_download, HfApi
11
- from unigaze.infer_runtime import UniGazeRuntime # in-process runtime (no subprocess)
12
 
13
  # --------------------------------------------------------------------------------------
 
14
  # Defaults
 
15
  # --------------------------------------------------------------------------------------
 
16
  DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"
17
- DEFAULT_CKPT_FILE = [
18
- "unigaze_h14_joint.pth.tar",
19
- "unigaze_l16_joint.pth.tar",
20
- "unigaze_b16_joint.pth.tar",
21
- ]
22
- DEFAULT_REVISION = "main"
23
 
24
- DEFAULT_CFGS = [
25
- "unigaze/configs/model/mae_h_14_gaze.yaml",
26
- "unigaze/configs/model/mae_L_16_gaze.yaml",
27
- "unigaze/configs/model/mae_b_16_gaze.yaml",
28
  ]
29
 
30
- TITLE = "UniGaze Demo (Video + Image)"
 
 
 
 
 
 
 
 
31
  DESC = """
32
- Upload a short video or a single image. The app downloads a checkpoint from the Hub,
33
- runs UniGaze in-process (no subprocess, no permanent writes), and returns results.
34
  """
35
 
36
- # Ensure imports of local packages work
37
- sys.path.append(os.path.dirname(__file__))
38
 
39
  # --------------------------------------------------------------------------------------
 
40
  # Helpers
 
41
  # --------------------------------------------------------------------------------------
 
42
  def resolve_cfg_abs(cfg_str: str) -> Path:
43
- """Return an absolute path to the YAML config."""
44
- p = Path(cfg_str)
45
- if p.is_absolute():
46
- if p.exists():
47
- return p
48
- raise FileNotFoundError(f"Config not found: {p}")
49
-
50
- p2 = (Path.cwd() / p).resolve()
51
- if p2.exists():
52
- return p2
53
-
54
- if str(p).startswith("configs/"):
55
- p3 = (Path.cwd() / "unigaze" / p).resolve()
56
- if p3.exists():
57
- return p3
58
-
59
- raise FileNotFoundError(f"Config not found. Tried: {p2}")
60
-
61
- def list_weight_files(repo_id: str, revision: str = "main") -> List[str]:
62
- try:
63
- api = HfApi()
64
- files = api.list_repo_files(repo_id=repo_id, repo_type="model", revision=revision)
65
- return [f for f in files if f.lower().endswith((".pth", ".pt", ".safetensors", ".tar", ".pth.tar"))]
66
- except Exception:
67
- return []
68
-
69
- def get_ckpt_path(repo_id: str, filename: str, revision: str = "main") -> str:
70
- files = list_weight_files(repo_id, revision)
71
- if files and filename not in files:
72
- raise FileNotFoundError(
73
- f"File '{filename}' not found in model repo '{repo_id}' at rev '{revision}'. "
74
- f"Available weights: {files}"
75
- )
76
- return hf_hub_download(
77
- repo_id=repo_id,
78
- filename=filename,
79
- revision=revision,
80
- repo_type="model",
81
- )
82
 
83
- # Cache the runtime so we load model/FA only once
84
  from functools import lru_cache
85
  @lru_cache(maxsize=3)
86
- def get_runtime(cfg_abs_str: str, ckpt_path: str, device: str = "cpu") -> UniGazeRuntime:
87
- return UniGazeRuntime(cfg_abs_str, ckpt_path, device=device)
88
 
89
  # --------------------------------------------------------------------------------------
90
- # Runners (in-process)
 
 
91
  # --------------------------------------------------------------------------------------
92
- def run_unigaze_on_video(
93
- video_path: str,
94
- hf_repo: str,
95
- ckpt_filename: str,
96
- cfg_path_user: str,
97
- extra_args: str = "",
98
- ) -> Tuple[Optional[np.ndarray], Optional[str], Optional[str], str]:
99
- logs: List[str] = []
100
- t0 = time.time()
101
-
102
- try:
103
- ckpt_path = get_ckpt_path(hf_repo, ckpt_filename, revision=DEFAULT_REVISION)
104
- logs.append(f"[hub] downloaded: {ckpt_path}")
105
- except Exception as e:
106
- return None, None, None, f"[hub] ERROR: {e}"
107
-
108
- try:
109
- cfg_abs = resolve_cfg_abs(cfg_path_user)
110
- except Exception as e:
111
- return None, None, None, f"[cfg] {e}"
112
-
113
- rt = get_runtime(str(cfg_abs), ckpt_path, device="cpu")
114
- mp4_path, last_rgb, run_sec = rt.predict_video(video_path)
115
- logs.append(f"[time] total runtime: {run_sec:.2f} seconds")
116
-
117
- return (last_rgb if last_rgb is not None else None), mp4_path, None, "\n".join(logs)
118
-
119
- def run_unigaze_on_image(
120
- image_array: np.ndarray,
121
- hf_repo: str,
122
- ckpt_filename: str,
123
- cfg_path_user: str,
124
- extra_args: str = "",
125
- ) -> Tuple[Optional[np.ndarray], str]:
126
- logs: List[str] = []
127
- t0 = time.time()
128
-
129
- try:
130
- ckpt_path = get_ckpt_path(hf_repo, ckpt_filename, revision=DEFAULT_REVISION)
131
- logs.append(f"[hub] downloaded: {ckpt_path}")
132
- except Exception as e:
133
- return None, f"[hub] ERROR: {e}"
134
-
135
- try:
136
- cfg_abs = resolve_cfg_abs(cfg_path_user)
137
- except Exception as e:
138
- return None, f"[cfg] {e}"
139
-
140
- rt = get_runtime(str(cfg_abs), ckpt_path, device="cpu")
141
- out_rgb = rt.predict_image(image_array)
142
- logs.append(f"[time] total runtime: {time.time() - t0:.2f} seconds")
143
-
144
- return out_rgb, "\n".join(logs)
145
 
146
  # --------------------------------------------------------------------------------------
 
147
  # UI
 
148
  # --------------------------------------------------------------------------------------
 
149
  with gr.Blocks(title=TITLE) as demo:
150
- gr.Markdown(f"# {TITLE}\n{DESC}")
151
-
152
- with gr.Row():
153
- ckpt_file = gr.Dropdown(choices=DEFAULT_CKPT_FILE, value=DEFAULT_CKPT_FILE[0], label="Checkpoint filename")
154
- cfg_choice = gr.Dropdown(choices=DEFAULT_CFGS, value=DEFAULT_CFGS[0], label="Model config")
155
-
156
- # IMAGE TAB
157
- with gr.Tab("Image"):
158
- in_img = gr.Image(type="numpy", label="Input image")
159
- run_img = gr.Button("Run on Image", variant="primary")
160
- out_img = gr.Image(label="Output image")
161
- out_logs = gr.Textbox(label="Logs", interactive=False, lines=18)
162
-
163
- def ui_predict_image(image, ckpt, cfg_use):
164
- return run_unigaze_on_image(
165
- image_array=image,
166
- hf_repo=DEFAULT_HF_REPO,
167
- ckpt_filename=ckpt,
168
- cfg_path_user=cfg_use,
169
- )
170
-
171
- run_img.click(
172
- fn=ui_predict_image,
173
- inputs=[in_img, ckpt_file, cfg_choice],
174
- outputs=[out_img, out_logs],
175
- )
176
-
177
- # Example image
178
- gr.Examples(
179
- examples=[["examples/The_Night_Watch_Frans_Banninck_Cocq.png", DEFAULT_CKPT_FILE[0], DEFAULT_CFGS[0]]],
180
- inputs=[in_img, ckpt_file, cfg_choice],
181
- outputs=[out_img, out_logs],
182
- fn=ui_predict_image,
183
- cache_examples=False,
184
- )
185
-
186
- # VIDEO TAB
187
- with gr.Tab("Video"):
188
- in_vid = gr.Video(label="Input video", sources=["upload"])
189
- run_vid = gr.Button("Run on Video", variant="primary")
190
- out_img_v = gr.Image(label="Annotated image (last frame)")
191
- out_vid_v = gr.Video(label="Output video")
192
- out_zip_v = gr.File(label="All artifacts as ZIP") # always None now
193
- out_logs_v = gr.Textbox(label="Logs", interactive=False, lines=18)
194
-
195
- def ui_predict_video(video, ckpt, cfg_use):
196
- return run_unigaze_on_video(
197
- video_path=video,
198
- hf_repo=DEFAULT_HF_REPO,
199
- ckpt_filename=ckpt,
200
- cfg_path_user=cfg_use,
201
- )
202
-
203
- run_vid.click(
204
- fn=ui_predict_video,
205
- inputs=[in_vid, ckpt_file, cfg_choice],
206
- outputs=[out_img_v, out_vid_v, out_zip_v, out_logs_v],
207
- )
208
-
209
- if __name__ == "__main__":
210
- demo.launch()
 
1
  # app.py
2
+
3
  import os
4
  import sys
5
  import time
 
9
  import gradio as gr
10
  import numpy as np
11
  from huggingface_hub import hf_hub_download, HfApi
12
+ from unigaze.infer_runtime import UniGazeRuntime
13
 
14
  # --------------------------------------------------------------------------------------
15
+
16
  # Defaults
17
+
18
  # --------------------------------------------------------------------------------------
19
+
20
DEFAULT_HF_REPO = "xucongzhang/UniGaze-models"

# Checkpoint filename -> model config path. Keeping both in one table means a
# checkpoint can never be paired with a config for a different architecture.
# NOTE(review): the previous revision of this file referenced
# "mae_L_16_gaze.yaml" (capital L). Confirm the on-disk filename; the lookup
# is case-sensitive on Linux.
MODEL_MAP = {
    "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
    "unigaze_l16_joint.pth.tar": "unigaze/configs/model/mae_l_16_gaze.yaml",
    "unigaze_b16_joint.pth.tar": "unigaze/configs/model/mae_b_16_gaze.yaml",
}

# Dropdown choices, derived from MODEL_MAP (insertion order) so the two
# cannot drift apart.
CHECKPOINTS = list(MODEL_MAP)

TITLE = "UniGaze Demo (Fixed Version)"
DESC = """
Upload a short video or a single image.
Checkpoint automatically loads correct model config (no mismatch errors).
"""

# Make packages that live next to this file importable. abspath() avoids
# appending an empty string when __file__ is a bare relative filename.
sys.path.append(os.path.dirname(os.path.abspath(__file__)))
 
43
 
44
  # --------------------------------------------------------------------------------------
45
+
46
  # Helpers
47
+
48
  # --------------------------------------------------------------------------------------
49
+
50
def resolve_cfg_abs(cfg_str: str) -> Path:
    """Return an absolute path for the YAML config ``cfg_str``.

    Absolute inputs are returned unchanged; relative inputs are joined onto
    the current working directory and resolved. The path is not checked for
    existence here.
    """
    p = Path(cfg_str)
    if p.is_absolute():
        return p
    return (Path.cwd() / p).resolve()
55
+
56
def get_ckpt_path(repo_id: str, filename: str) -> str:
    """Fetch ``filename`` from the Hub model repo ``repo_id``.

    Returns whatever ``hf_hub_download`` returns for the file (a local path).
    Any exception raised by the download propagates to the caller.
    """
    return hf_hub_download(
        repo_id=repo_id,
        filename=filename,
        repo_type="model",
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
62
 
 
63
  from functools import lru_cache
64
@lru_cache(maxsize=3)
def get_runtime(cfg_abs: str, ckpt_path: str, device: str = "cpu"):
    """Build (or reuse) a ``UniGazeRuntime`` for a config/checkpoint pair.

    Results are memoized by argument values, so repeated requests with the
    same arguments reuse the already constructed runtime. At most three
    runtimes are kept.

    Args:
        cfg_abs: Path to the model YAML config, as a string so it is hashable.
        ckpt_path: Local path to the checkpoint file.
        device: Device passed through to ``UniGazeRuntime``; defaults to
            ``"cpu"``, which is what existing callers get.
    """
    return UniGazeRuntime(cfg_abs, ckpt_path, device=device)
67
 
68
  # --------------------------------------------------------------------------------------
69
+
70
+ # RUNNERS
71
+
72
  # --------------------------------------------------------------------------------------
73
+
74
def run_image(image_array, ckpt_name):
    """Run the UniGaze runtime on a single image.

    Args:
        image_array: Image handed over by the Gradio image input; ``None``
            when nothing was uploaded.
        ckpt_name: Checkpoint filename; must be a key of ``MODEL_MAP``.

    Returns:
        ``(output, log_text)``. ``output`` is the value returned by the
        runtime's ``predict_image``, or ``None`` when the request is rejected
        or the checkpoint download fails.
    """
    # Reject an empty request before touching the Hub or the model.
    if image_array is None:
        return None, "[ERROR] No input image provided."

    logs = []
    t0 = time.time()

    # An unknown checkpoint name becomes a readable error, not a KeyError.
    cfg_path = MODEL_MAP.get(ckpt_name)
    if cfg_path is None:
        return None, f"[ERROR] Unknown checkpoint: {ckpt_name!r}"

    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:  # UI boundary: report the failure in the log box
        return None, f"[ERROR] {e}"
    logs.append(f"[OK] Loaded checkpoint: {ckpt_name}")

    cfg_abs = resolve_cfg_abs(cfg_path)

    rt = get_runtime(str(cfg_abs), ckpt_path)
    out = rt.predict_image(image_array)

    logs.append(f"[time] {time.time()-t0:.2f}s")
    return out, "\n".join(logs)
94
+
95
def run_video(video_path, ckpt_name):
    """Run the UniGaze runtime on a video file.

    Args:
        video_path: Path handed over by the Gradio video input; empty or
            ``None`` when nothing was uploaded.
        ckpt_name: Checkpoint filename; must be a key of ``MODEL_MAP``.

    Returns:
        ``(last_frame, out_video, None, log_text)``. The third slot is always
        ``None``. On a rejected request or a failed checkpoint download the
        first three slots are all ``None`` and ``log_text`` holds the error.
    """
    # Reject an empty request before touching the Hub or the model.
    if not video_path:
        return None, None, None, "[ERROR] No input video provided."

    logs = []

    # An unknown checkpoint name becomes a readable error, not a KeyError.
    cfg_path = MODEL_MAP.get(ckpt_name)
    if cfg_path is None:
        return None, None, None, f"[ERROR] Unknown checkpoint: {ckpt_name!r}"

    try:
        ckpt_path = get_ckpt_path(DEFAULT_HF_REPO, ckpt_name)
    except Exception as e:  # UI boundary: report the failure in the log box
        return None, None, None, f"[ERROR] {e}"
    logs.append(f"[OK] Loaded checkpoint: {ckpt_name}")

    cfg_abs = resolve_cfg_abs(cfg_path)

    rt = get_runtime(str(cfg_abs), ckpt_path)
    # predict_video reports its own elapsed time as the third element.
    out_video, last_frame, runtime = rt.predict_video(video_path)

    logs.append(f"[time] {runtime:.2f}s")
    return last_frame, out_video, None, "\n".join(logs)
 
 
 
 
 
 
 
 
 
 
 
115
 
116
  # --------------------------------------------------------------------------------------
117
+
118
  # UI
119
+
120
  # --------------------------------------------------------------------------------------
121
+
122
# Gradio layout: one shared checkpoint dropdown plus an Image tab and a Video
# tab, each wired to its runner function.
with gr.Blocks(title=TITLE) as demo:
    gr.Markdown(f"# {TITLE}\n{DESC}")

    # Shared by both tabs; the matching config is looked up from the name.
    ckpt = gr.Dropdown(
        choices=CHECKPOINTS,
        value=CHECKPOINTS[1],  # default L16 (safe)
        label="Checkpoint",
    )

    # IMAGE
    with gr.Tab("Image"):
        inp = gr.Image(type="numpy")
        btn = gr.Button("Run")
        out = gr.Image()
        logs = gr.Textbox(lines=10)

        btn.click(
            fn=run_image,
            inputs=[inp, ckpt],
            outputs=[out, logs],
        )

    # VIDEO
    with gr.Tab("Video"):
        vin = gr.Video()
        btn2 = gr.Button("Run")
        vout = gr.Video()
        img_out = gr.Image()
        logs2 = gr.Textbox(lines=10)
        # Receives the third value returned by run_video, which is always
        # None. Named here instead of being created inside the outputs list.
        file_out = gr.File()

        btn2.click(
            fn=run_video,
            inputs=[vin, ckpt],
            outputs=[img_out, vout, file_out, logs2],
        )
159
+
160
# Start the Gradio server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()