usingcolor commited on
Commit
9d42ecb
·
1 Parent(s): 0f440c7

feat: initialize local git repository and update project configurations

Browse files
Files changed (1) hide show
  1. app.py +147 -47
app.py CHANGED
@@ -3,29 +3,20 @@ import os
3
  import subprocess
4
  import time
5
 
6
- # --- Dynamic Repository Clone ---
7
- # If the MambaEye source code isn't deployed directly alongside this app.py,
8
- # we clone it from GitHub before trying to import it.
9
  mamba_dir = os.path.join(os.path.dirname(__file__), "MambaEye")
10
  if not os.path.exists(mamba_dir) or not os.path.exists(os.path.join(mamba_dir, "mambaeye")):
11
  print("Cloning MambaEye repository from GitHub...", flush=True)
12
- # Ensure any empty/partial directory is removed before cloning
13
  if os.path.exists(mamba_dir):
14
  import shutil
15
  shutil.rmtree(mamba_dir)
16
  subprocess.check_call(["git", "clone", "https://github.com/usingcolor/MambaEye.git", mamba_dir])
17
 
18
- # --- Dynamic Dependency Injection for HuggingFace Spaces ---
19
- # HuggingFace ZeroGPU builder environments lack `nvcc`.
20
- # We intercept the import and softly compile mamba-ssm using CPU-fallback PyTorch natives
21
- # so we pass the build requirements perfectly.
22
  try:
23
  import mamba_ssm
24
  import causal_conv1d
25
  except ImportError:
26
  print("Installing mamba_ssm and causal_conv1d in backend...", flush=True)
27
  env = os.environ.copy()
28
- # Bypass CUDA extensions because we don't have nvcc locally or in standard Hub build container
29
  env["MAMBA_SKIP_CUDA_BUILD"] = "TRUE"
30
  env["CAUSAL_CONV1D_SKIP_CUDA_BUILD"] = "TRUE"
31
  subprocess.check_call(
@@ -33,7 +24,6 @@ except ImportError:
33
  env=env
34
  )
35
 
36
- # Add the cloned MambaEye repository to the Python path
37
  sys.path.append(os.path.join(os.path.dirname(__file__), "MambaEye"))
38
 
39
  import gradio as gr
@@ -46,13 +36,11 @@ from torchvision.models import ResNet50_Weights
46
  from huggingface_hub import hf_hub_download
47
  import spaces
48
 
49
- # MambaEye Imports
50
  from mambaeye.model import MambaEye
51
  from mambaeye.scan import generate_scan_positions
52
  from mambaeye.positional_encoding import sinusoidal_position_encoding_2d
53
  from mamba_ssm.utils.generation import InferenceParams
54
 
55
- # Global Configuration
56
  TARGET_CANVAS_SIZE = 512
57
  PATCH_SIZE = 16
58
  CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
@@ -70,10 +58,105 @@ MODEL_CONFIG = {
70
 
71
  MODEL_REPO = "usingcolor/MambaEye-base"
72
  MODEL_FILENAME = "mambaeye_base_ft.pt"
73
-
74
- # Global Model Cache
75
  _GLOBAL_MODEL = None
76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
  def get_model():
78
  global _GLOBAL_MODEL
79
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -82,9 +165,7 @@ def get_model():
82
  try:
83
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
84
  model = MambaEye(**MODEL_CONFIG)
85
-
86
- # Since this runs inside ZeroGPU worker, load directly to device
87
- model.load_state_dict(torch.load(checkpoint_path, map_location=device))
88
  model.to(device)
89
  model.eval()
90
  _GLOBAL_MODEL = model
@@ -95,7 +176,6 @@ def get_model():
95
  return _GLOBAL_MODEL, device
96
 
97
  def transfer_inference_params(params, device):
98
- """Recursively moves the KV cache state of MambaEye InferenceParams to CPU or CUDA."""
99
  if params is None or getattr(params, "key_value_memory_dict", None) is None:
100
  return params
101
 
@@ -106,7 +186,7 @@ def transfer_inference_params(params, device):
106
  params.key_value_memory_dict[k] = tuple(x.to(device) if isinstance(x, torch.Tensor) else x for x in v)
107
  elif isinstance(v, list):
108
  params.key_value_memory_dict[k] = [x.to(device) if isinstance(x, torch.Tensor) else x for x in v]
109
- elif isinstance(v, dict): # E.g., layers map
110
  for k2, v2 in v.items():
111
  if hasattr(v2, "to"):
112
  params.key_value_memory_dict[k][k2] = v2.to(device)
@@ -155,10 +235,18 @@ def extract_patch(canvas_tensor, px, py):
155
  return patch.flatten()
156
 
157
  def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
158
- img = Image.fromarray(image_arr).convert("RGB")
159
- draw = ImageDraw.Draw(img)
 
 
 
 
 
 
 
 
160
 
161
- orig_w, orig_h = img.size
162
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
163
 
164
  for i, (px, py) in enumerate(positions):
@@ -166,8 +254,14 @@ def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
166
  orig_x = (px - x_offset) / ratio
167
  orig_px_size = PATCH_SIZE / ratio
168
 
 
 
 
 
 
 
169
  color = "red" if i == len(positions) - 1 else "blue"
170
- draw.rectangle([orig_y, orig_x, orig_y + orig_px_size, orig_x + orig_px_size], outline=color, width=2)
171
 
172
  if i > 0:
173
  prev_py, prev_px = positions[i-1]
@@ -178,7 +272,7 @@ def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
178
  center_curr = (orig_y + orig_px_size / 2, orig_x + orig_px_size / 2)
179
  draw.line([center_prev, center_curr], fill="blue", width=2)
180
 
181
- return np.array(img), positions
182
 
183
  def init_state_for_image(image):
184
  canvas_tensor, x_offset, y_offset, h, w = preprocess_image(image)
@@ -244,7 +338,6 @@ def run_auto_scan(image, scan_pattern, sequence_length):
244
  state['drawn_positions'] = positions
245
  state['sequence_length'] = sequence_length
246
 
247
- # On ZeroGPU spaces securely move Tensors back to CPU State
248
  state['canvas_tensor'] = state['canvas_tensor'].cpu()
249
  state['inference_params'] = transfer_inference_params(inference_params, torch.device('cpu'))
250
 
@@ -256,7 +349,7 @@ def run_auto_scan(image, scan_pattern, sequence_length):
256
  return img_display, format_predictions(final_probs), state, f"Auto Scan Complete. Extracted {sequence_length} patches. Click to add more!"
257
 
258
  @spaces.GPU
259
- def on_click(evt: gr.SelectData, original_image, state):
260
  if original_image is None:
261
  return None, {"Upload Image": 1.0}, state, "Upload Image"
262
 
@@ -266,10 +359,8 @@ def on_click(evt: gr.SelectData, original_image, state):
266
  state = init_state_for_image(original_image)
267
  state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
268
 
269
- # Move InferenceParams back to the functional device correctly!
270
  state['inference_params'] = transfer_inference_params(state['inference_params'], device)
271
 
272
- x_orig, y_orig = evt.index
273
  orig_h, orig_w = state['original_image'].shape[:2]
274
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
275
 
@@ -285,8 +376,8 @@ def on_click(evt: gr.SelectData, original_image, state):
285
 
286
  patch = extract_patch(state['canvas_tensor'], px, py).to(device)
287
 
288
- img_seq = patch.unsqueeze(0).unsqueeze(0) # (1, 1, 768)
289
- move_seq = move_emb.unsqueeze(0) # (1, 1, 512)
290
 
291
  with torch.no_grad():
292
  out = model(img_seq, move_seq, inference_params=state['inference_params'])
@@ -297,7 +388,6 @@ def on_click(evt: gr.SelectData, original_image, state):
297
  state['drawn_positions'].append((px, py))
298
  state['sequence_length'] += 1
299
 
300
- # Strip back to CPU for Gradio Session Memory
301
  state['inference_params'] = transfer_inference_params(state['inference_params'], torch.device('cpu'))
302
 
303
  img_display, _ = draw_patches_on_image(
@@ -305,27 +395,39 @@ def on_click(evt: gr.SelectData, original_image, state):
305
  state['x_offset'], state['y_offset'], state['h'], state['w']
306
  )
307
 
308
- return img_display, format_predictions(final_probs), state, f"Added patch {state['sequence_length']} (Total {state['inference_params'].seqlen_offset} inference steps)."
 
 
 
 
309
 
310
  def on_upload(image):
311
  if image is None:
312
- return None, {"Waiting...": 1.0}, None, "Upload Image"
313
- # Delay model load until auto-scan triggers, saving memory overhead in preloads
314
- return image, {"Click Auto Scan or click the image": 1.0}, None, "Ready. You can Auto Scan or click."
 
 
 
 
315
 
316
  def on_clear(original_image):
317
  if original_image is None:
318
  return None, {"Cleared": 1.0}, None, "Cleared"
319
- return original_image, {"Cleared": 1.0}, init_state_for_image(original_image), "Selections cleared. Ready for new patch sequence."
 
 
 
 
320
 
321
- # Build the Gradio App Blocks
322
- with gr.Blocks(title="MambaEye Interactive Demo", theme=gr.themes.Soft()) as demo:
323
- gr.Markdown("# MambaEye Interactive inference Demo")
324
- gr.Markdown("This interface incorporates the full **MambaEye-base** model inference natively. Using **ZeroGPU** inference via PyTorch equivalents.")
325
 
326
  with gr.Row():
327
  with gr.Column(scale=2):
328
- input_image = gr.Image(type="numpy", label="Upload and Select Patches", interactive=True)
 
329
 
330
  with gr.Row():
331
  scan_pattern = gr.Dropdown(
@@ -343,17 +445,13 @@ with gr.Blocks(title="MambaEye Interactive Demo", theme=gr.themes.Soft()) as dem
343
  model_output_label = gr.Label(label="MambaEye Output Predictions", num_top_classes=5)
344
  status_text = gr.Markdown("Status: Waiting for image upload...")
345
 
346
- # Application State
347
  state = gr.State(None)
348
  original_image_state = gr.State(None)
349
 
350
- # Event wiring
351
  input_image.upload(
352
  fn=on_upload,
353
  inputs=[input_image],
354
- outputs=[input_image, model_output_label, state, status_text]
355
- ).then(
356
- fn=lambda img: img, inputs=[input_image], outputs=[original_image_state]
357
  )
358
 
359
  auto_btn.click(
@@ -373,6 +471,8 @@ with gr.Blocks(title="MambaEye Interactive Demo", theme=gr.themes.Soft()) as dem
373
  inputs=[original_image_state],
374
  outputs=[input_image, model_output_label, state, status_text]
375
  )
 
 
376
 
377
  if __name__ == "__main__":
378
- demo.launch()
 
3
  import subprocess
4
  import time
5
 
 
 
 
6
  mamba_dir = os.path.join(os.path.dirname(__file__), "MambaEye")
7
  if not os.path.exists(mamba_dir) or not os.path.exists(os.path.join(mamba_dir, "mambaeye")):
8
  print("Cloning MambaEye repository from GitHub...", flush=True)
 
9
  if os.path.exists(mamba_dir):
10
  import shutil
11
  shutil.rmtree(mamba_dir)
12
  subprocess.check_call(["git", "clone", "https://github.com/usingcolor/MambaEye.git", mamba_dir])
13
 
 
 
 
 
14
  try:
15
  import mamba_ssm
16
  import causal_conv1d
17
  except ImportError:
18
  print("Installing mamba_ssm and causal_conv1d in backend...", flush=True)
19
  env = os.environ.copy()
 
20
  env["MAMBA_SKIP_CUDA_BUILD"] = "TRUE"
21
  env["CAUSAL_CONV1D_SKIP_CUDA_BUILD"] = "TRUE"
22
  subprocess.check_call(
 
24
  env=env
25
  )
26
 
 
27
  sys.path.append(os.path.join(os.path.dirname(__file__), "MambaEye"))
28
 
29
  import gradio as gr
 
36
  from huggingface_hub import hf_hub_download
37
  import spaces
38
 
 
39
  from mambaeye.model import MambaEye
40
  from mambaeye.scan import generate_scan_positions
41
  from mambaeye.positional_encoding import sinusoidal_position_encoding_2d
42
  from mamba_ssm.utils.generation import InferenceParams
43
 
 
44
  TARGET_CANVAS_SIZE = 512
45
  PATCH_SIZE = 16
46
  CATEGORIES = ResNet50_Weights.IMAGENET1K_V1.meta["categories"]
 
58
 
59
  MODEL_REPO = "usingcolor/MambaEye-base"
60
  MODEL_FILENAME = "mambaeye_base_ft.pt"
 
 
61
  _GLOBAL_MODEL = None
62
 
63
+
64
+ # --- HOVER SCRIPT INJECTION ---
65
+ JS_HOVER_SCRIPT = """
66
+ function() {
67
+ let overlay = document.getElementById('mamba-hover-overlay');
68
+ if (!overlay) {
69
+ overlay = document.createElement('div');
70
+ overlay.id = 'mamba-hover-overlay';
71
+ overlay.style.position = 'fixed';
72
+ overlay.style.pointerEvents = 'none';
73
+ overlay.style.border = '2px solid rgba(0, 102, 255, 0.8)';
74
+ overlay.style.backgroundColor = 'rgba(0, 102, 255, 0.2)';
75
+ overlay.style.zIndex = '99999';
76
+ overlay.style.display = 'none';
77
+ document.body.appendChild(overlay);
78
+ }
79
+
80
+ document.addEventListener('mousemove', (e) => {
81
+ let imgs = document.querySelectorAll('img');
82
+ let targetImg = null;
83
+ for (let img of imgs) {
84
+ if (img.closest('.gradio-image-hook')) {
85
+ if (img.src && !img.src.includes('data:image/svg')) {
86
+ targetImg = img;
87
+ }
88
+ }
89
+ }
90
+ if (!targetImg) { overlay.style.display = 'none'; return; }
91
+
92
+ let rect = targetImg.getBoundingClientRect();
93
+ if (e.clientX >= rect.left && e.clientX <= rect.right && e.clientY >= rect.top && e.clientY <= rect.bottom) {
94
+ let nw = targetImg.naturalWidth;
95
+ let nh = targetImg.naturalHeight;
96
+ if (nw === 0 || nh === 0) return;
97
+
98
+ let cw = rect.width;
99
+ let ch = rect.height;
100
+ let imgRatio = nw / nh;
101
+ let containerRatio = cw / ch;
102
+
103
+ let renderW, renderH, renderX, renderY;
104
+ if (imgRatio > containerRatio) {
105
+ renderW = cw;
106
+ renderH = cw / imgRatio;
107
+ renderX = 0;
108
+ renderY = (ch - renderH) / 2;
109
+ } else {
110
+ renderH = ch;
111
+ renderW = ch * imgRatio;
112
+ renderY = 0;
113
+ renderX = (cw - renderW) / 2;
114
+ }
115
+
116
+ let relX = e.clientX - rect.left - renderX;
117
+ let relY = e.clientY - rect.top - renderY;
118
+
119
+ if (relX >= 0 && relX <= renderW && relY >= 0 && relY <= renderH) {
120
+ let scale = renderW / nw;
121
+ let TARGET_CANVAS_SIZE = 512;
122
+ let ratio = Math.min(TARGET_CANVAS_SIZE / nw, TARGET_CANVAS_SIZE / nh);
123
+
124
+ let origX = relX / scale;
125
+ let origY = relY / scale;
126
+
127
+ let y_offset = (TARGET_CANVAS_SIZE - nw * ratio) / 2;
128
+ let x_offset = (TARGET_CANVAS_SIZE - nh * ratio) / 2;
129
+
130
+ let canvas_y = origX * ratio + y_offset;
131
+ let canvas_x = origY * ratio + x_offset;
132
+
133
+ let px = Math.floor(canvas_x / 16) * 16;
134
+ let py = Math.floor(canvas_y / 16) * 16;
135
+
136
+ let start_orig_y = (py - y_offset) / ratio;
137
+ let start_orig_x = (px - x_offset) / ratio;
138
+
139
+ let render_box_x = rect.left + renderX + start_orig_y * scale;
140
+ let render_box_y = rect.top + renderY + start_orig_x * scale;
141
+
142
+ let size_scale = (16 / ratio) * scale;
143
+
144
+ overlay.style.left = render_box_x + "px";
145
+ overlay.style.top = render_box_y + "px";
146
+ overlay.style.width = size_scale + "px";
147
+ overlay.style.height = size_scale + "px";
148
+ overlay.style.display = 'block';
149
+ } else {
150
+ overlay.style.display = 'none';
151
+ }
152
+ } else {
153
+ overlay.style.display = 'none';
154
+ }
155
+ });
156
+ }
157
+ """
158
+ # -----------------------------
159
+
160
  def get_model():
161
  global _GLOBAL_MODEL
162
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
165
  try:
166
  checkpoint_path = hf_hub_download(repo_id=MODEL_REPO, filename=MODEL_FILENAME)
167
  model = MambaEye(**MODEL_CONFIG)
168
+ model.load_state_dict(torch.load(checkpoint_path, map_location=device, weights_only=True))
 
 
169
  model.to(device)
170
  model.eval()
171
  _GLOBAL_MODEL = model
 
176
  return _GLOBAL_MODEL, device
177
 
178
  def transfer_inference_params(params, device):
 
179
  if params is None or getattr(params, "key_value_memory_dict", None) is None:
180
  return params
181
 
 
186
  params.key_value_memory_dict[k] = tuple(x.to(device) if isinstance(x, torch.Tensor) else x for x in v)
187
  elif isinstance(v, list):
188
  params.key_value_memory_dict[k] = [x.to(device) if isinstance(x, torch.Tensor) else x for x in v]
189
+ elif isinstance(v, dict):
190
  for k2, v2 in v.items():
191
  if hasattr(v2, "to"):
192
  params.key_value_memory_dict[k][k2] = v2.to(device)
 
235
  return patch.flatten()
236
 
237
  def draw_patches_on_image(image_arr, positions, x_offset, y_offset, h, w):
238
+ img = np.array(image_arr)
239
+
240
+ # Create the greyed-out ambient background
241
+ grey_base = Image.fromarray(img).convert("L").convert("RGB")
242
+ grey_np = np.array(grey_base) * 0.4 + np.full_like(grey_np, 160) # Note: broadcasting handles full_like internally safely via float math
243
+ grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
244
+
245
+ temp_img = Image.fromarray(grey_base_np)
246
+ orig_pil = Image.fromarray(img)
247
+ draw = ImageDraw.Draw(temp_img)
248
 
249
+ orig_w, orig_h = orig_pil.size
250
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
251
 
252
  for i, (px, py) in enumerate(positions):
 
254
  orig_x = (px - x_offset) / ratio
255
  orig_px_size = PATCH_SIZE / ratio
256
 
257
+ box = (int(orig_y), int(orig_x), int(orig_y + orig_px_size), int(orig_x + orig_px_size))
258
+
259
+ # Paste original color into the highlighted region
260
+ patch_crop = orig_pil.crop(box)
261
+ temp_img.paste(patch_crop, box)
262
+
263
  color = "red" if i == len(positions) - 1 else "blue"
264
+ draw.rectangle(box, outline=color, width=2)
265
 
266
  if i > 0:
267
  prev_py, prev_px = positions[i-1]
 
272
  center_curr = (orig_y + orig_px_size / 2, orig_x + orig_px_size / 2)
273
  draw.line([center_prev, center_curr], fill="blue", width=2)
274
 
275
+ return np.array(temp_img), positions
276
 
277
  def init_state_for_image(image):
278
  canvas_tensor, x_offset, y_offset, h, w = preprocess_image(image)
 
338
  state['drawn_positions'] = positions
339
  state['sequence_length'] = sequence_length
340
 
 
341
  state['canvas_tensor'] = state['canvas_tensor'].cpu()
342
  state['inference_params'] = transfer_inference_params(inference_params, torch.device('cpu'))
343
 
 
349
  return img_display, format_predictions(final_probs), state, f"Auto Scan Complete. Extracted {sequence_length} patches. Click to add more!"
350
 
351
  @spaces.GPU
352
+ def process_click_inference(x_orig, y_orig, original_image, state):
353
  if original_image is None:
354
  return None, {"Upload Image": 1.0}, state, "Upload Image"
355
 
 
359
  state = init_state_for_image(original_image)
360
  state['inference_params'] = InferenceParams(max_seqlen=4000, max_batch_size=1)
361
 
 
362
  state['inference_params'] = transfer_inference_params(state['inference_params'], device)
363
 
 
364
  orig_h, orig_w = state['original_image'].shape[:2]
365
  ratio = min(TARGET_CANVAS_SIZE / orig_w, TARGET_CANVAS_SIZE / orig_h)
366
 
 
376
 
377
  patch = extract_patch(state['canvas_tensor'], px, py).to(device)
378
 
379
+ img_seq = patch.unsqueeze(0).unsqueeze(0)
380
+ move_seq = move_emb.unsqueeze(0)
381
 
382
  with torch.no_grad():
383
  out = model(img_seq, move_seq, inference_params=state['inference_params'])
 
388
  state['drawn_positions'].append((px, py))
389
  state['sequence_length'] += 1
390
 
 
391
  state['inference_params'] = transfer_inference_params(state['inference_params'], torch.device('cpu'))
392
 
393
  img_display, _ = draw_patches_on_image(
 
395
  state['x_offset'], state['y_offset'], state['h'], state['w']
396
  )
397
 
398
+ return img_display, format_predictions(final_probs), state, f"Added patch {state['sequence_length']} (Total {state['inference_params'].seqlen_offset} steps)."
399
+
400
+ def on_click(evt: gr.SelectData, original_image, state):
401
+ x_orig, y_orig = evt.index
402
+ return process_click_inference(x_orig, y_orig, original_image, state)
403
 
404
  def on_upload(image):
405
  if image is None:
406
+ return None, None, {"Waiting...": 1.0}, None, "Upload Image"
407
+
408
+ # Pre-render the grey background immediately on upload
409
+ grey_base = Image.fromarray(image).convert("L").convert("RGB")
410
+ grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
411
+
412
+ return grey_base_np, image, {"Click Auto Scan or click the image": 1.0}, None, "Ready. You can Auto Scan or click."
413
 
414
  def on_clear(original_image):
415
  if original_image is None:
416
  return None, {"Cleared": 1.0}, None, "Cleared"
417
+
418
+ grey_base = Image.fromarray(original_image).convert("L").convert("RGB")
419
+ grey_base_np = (np.array(grey_base).astype(float) * 0.4 + 160).clip(0, 255).astype(np.uint8)
420
+
421
+ return grey_base_np, {"Cleared": 1.0}, init_state_for_image(original_image), "Selections cleared. Ready for new patch sequence."
422
 
423
+ with gr.Blocks(title="MambaEye Interactive Demo") as demo:
424
+ gr.Markdown("# MambaEye Interactive Inference Demo")
425
+ gr.Markdown("This interface incorporates the full **MambaEye-base** model natively.")
 
426
 
427
  with gr.Row():
428
  with gr.Column(scale=2):
429
+ # elem_classes targets the JS overlay script correctly
430
+ input_image = gr.Image(type="numpy", label="Upload and Select Patches", interactive=True, elem_classes="gradio-image-hook")
431
 
432
  with gr.Row():
433
  scan_pattern = gr.Dropdown(
 
445
  model_output_label = gr.Label(label="MambaEye Output Predictions", num_top_classes=5)
446
  status_text = gr.Markdown("Status: Waiting for image upload...")
447
 
 
448
  state = gr.State(None)
449
  original_image_state = gr.State(None)
450
 
 
451
  input_image.upload(
452
  fn=on_upload,
453
  inputs=[input_image],
454
+ outputs=[input_image, original_image_state, model_output_label, state, status_text]
 
 
455
  )
456
 
457
  auto_btn.click(
 
471
  inputs=[original_image_state],
472
  outputs=[input_image, model_output_label, state, status_text]
473
  )
474
+
475
+ demo.load(js=JS_HOVER_SCRIPT)
476
 
477
  if __name__ == "__main__":
478
+ demo.launch(theme=gr.themes.Soft(), ssr_mode=False)