saptak21 commited on
Commit
3cb4ae2
·
verified ·
1 Parent(s): 2163def

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +19 -16
app.py CHANGED
@@ -8,7 +8,7 @@ from typing import List, Optional, Tuple
8
 
9
  import gradio as gr
10
  import numpy as np
11
- from huggingface_hub import hf_hub_download, HfApi
12
  from unigaze.infer_runtime import UniGazeRuntime
13
 
14
  # --------------------------------------------------------------------------------------
@@ -25,7 +25,7 @@ CHECKPOINTS = [
25
  "unigaze_b16_joint.pth.tar",
26
  ]
27
 
28
- # 🔥 SAFE MAPPING (fixes your bug permanently)
29
 
30
  MODEL_MAP = {
31
  "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
@@ -39,7 +39,9 @@ Upload a short video or a single image.
39
  Checkpoint automatically loads correct model config (no mismatch errors).
40
  """
41
 
42
- sys.path.append(os.path.dirname(__file__))
 
 
43
 
44
  # --------------------------------------------------------------------------------------
45
 
@@ -61,13 +63,14 @@ repo_type="model",
61
  )
62
 
63
  from functools import lru_cache
 
64
  @lru_cache(maxsize=3)
65
  def get_runtime(cfg_abs: str, ckpt_path: str):
66
  return UniGazeRuntime(cfg_abs, ckpt_path, device="cpu")
67
 
68
  # --------------------------------------------------------------------------------------
69
 
70
- # RUNNERS
71
 
72
  # --------------------------------------------------------------------------------------
73
 
@@ -88,7 +91,7 @@ cfg_abs = resolve_cfg_abs(cfg_path)
88
  rt = get_runtime(str(cfg_abs), ckpt_path)
89
  out = rt.predict_image(image_array)
90
 
91
- logs.append(f"[time] {time.time()-t0:.2f}s")
92
  return out, "\n".join(logs)
93
  ```
94
 
@@ -125,16 +128,16 @@ gr.Markdown(f"# {TITLE}\n{DESC}")
125
  ```
126
  ckpt = gr.Dropdown(
127
  choices=CHECKPOINTS,
128
- value=CHECKPOINTS[1], # default L16 (safe)
129
- label="Checkpoint"
130
  )
131
 
132
- # IMAGE
133
  with gr.Tab("Image"):
134
- inp = gr.Image(type="numpy")
135
  btn = gr.Button("Run")
136
- out = gr.Image()
137
- logs = gr.Textbox(lines=10)
138
 
139
  btn.click(
140
  fn=run_image,
@@ -142,13 +145,13 @@ with gr.Tab("Image"):
142
  outputs=[out, logs],
143
  )
144
 
145
- # VIDEO
146
  with gr.Tab("Video"):
147
- vin = gr.Video()
148
  btn2 = gr.Button("Run")
149
- vout = gr.Video()
150
- img_out = gr.Image()
151
- logs2 = gr.Textbox(lines=10)
152
 
153
  btn2.click(
154
  fn=run_video,
 
8
 
9
  import gradio as gr
10
  import numpy as np
11
+ from huggingface_hub import hf_hub_download
12
  from unigaze.infer_runtime import UniGazeRuntime
13
 
14
  # --------------------------------------------------------------------------------------
 
25
  "unigaze_b16_joint.pth.tar",
26
  ]
27
 
28
+ # 🔥 SAFE MAPPING (fixes mismatch issue)
29
 
30
  MODEL_MAP = {
31
  "unigaze_h14_joint.pth.tar": "unigaze/configs/model/mae_h_14_gaze.yaml",
 
39
  Checkpoint automatically loads correct model config (no mismatch errors).
40
  """
41
 
42
+ # Fix import path
43
+
44
+ sys.path.append(os.path.dirname(**file**))
45
 
46
  # --------------------------------------------------------------------------------------
47
 
 
63
  )
64
 
65
  from functools import lru_cache
66
+
67
  @lru_cache(maxsize=3)
68
  def get_runtime(cfg_abs: str, ckpt_path: str):
69
  return UniGazeRuntime(cfg_abs, ckpt_path, device="cpu")
70
 
71
  # --------------------------------------------------------------------------------------
72
 
73
+ # Runners
74
 
75
  # --------------------------------------------------------------------------------------
76
 
 
91
  rt = get_runtime(str(cfg_abs), ckpt_path)
92
  out = rt.predict_image(image_array)
93
 
94
+ logs.append(f"[time] {time.time() - t0:.2f}s")
95
  return out, "\n".join(logs)
96
  ```
97
 
 
128
  ```
129
  ckpt = gr.Dropdown(
130
  choices=CHECKPOINTS,
131
+ value=CHECKPOINTS[1], # default L16
132
+ label="Checkpoint",
133
  )
134
 
135
+ # IMAGE TAB
136
  with gr.Tab("Image"):
137
+ inp = gr.Image(type="numpy", label="Input Image")
138
  btn = gr.Button("Run")
139
+ out = gr.Image(label="Output Image")
140
+ logs = gr.Textbox(label="Logs", lines=10)
141
 
142
  btn.click(
143
  fn=run_image,
 
145
  outputs=[out, logs],
146
  )
147
 
148
+ # VIDEO TAB
149
  with gr.Tab("Video"):
150
+ vin = gr.Video(label="Input Video")
151
  btn2 = gr.Button("Run")
152
+ img_out = gr.Image(label="Last Frame")
153
+ vout = gr.Video(label="Output Video")
154
+ logs2 = gr.Textbox(label="Logs", lines=10)
155
 
156
  btn2.click(
157
  fn=run_video,