piekenius123 commited on
Commit
eb2300a
·
verified ·
1 Parent(s): a949f42

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +326 -136
app.py CHANGED
@@ -1,25 +1,34 @@
1
- # app.py
2
  import io
3
  import json
4
- import base64
5
  import random
6
- from typing import Optional, Dict, Any, List, Tuple
 
7
 
8
- import pandas as pd
9
- from PIL import Image
10
  import gradio as gr
 
11
  from huggingface_hub import HfApi, hf_hub_download
 
12
 
13
  DATASET_REPO_ID = "piekenius123/Amaze"
14
  REPO_TYPE = "dataset"
15
 
 
 
 
 
16
  SHAPES = ["circle", "hexagon", "square", "triangle"]
17
  SPLITS = ["train", "val", "test"]
18
 
19
- MAZE_SIZE_MIN, MAZE_SIZE_MAX = 3, 16
20
- MAZE_SIZE_CHOICES = ["All"] + [f"{n}×{n}" for n in range(MAZE_SIZE_MIN, MAZE_SIZE_MAX + 1)]
21
-
22
  IMAGE_COLS = ["original_img", "m_original_img", "sol_img", "mask_img", "cell_map"]
 
 
 
 
 
 
 
 
23
 
24
 
25
  # -------------------------
@@ -63,74 +72,167 @@ def decode_base64_image(base64_str: Any) -> Optional[Image.Image]:
63
  return None
64
 
65
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
  def infer_shape_from_repo_path(path: str) -> Optional[str]:
67
- p = path.replace("\\", "/").lower()
68
- for s in SHAPES:
69
- if p.startswith(f"{s}/") or f"/{s}/" in p:
70
- return s
71
  return None
72
 
73
 
74
  def infer_split_from_repo_path(path: str) -> Optional[str]:
75
- p = path.replace("\\", "/").lower()
76
  fn = p.split("/")[-1]
77
-
78
- if fn == "maze_dataset_train.parquet":
79
- return "train"
80
-
81
- if fn == "maze_dataset_test.parquet":
82
- if "/maze-dataset_train/" in p:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83
  return "val"
84
- if "/maze-dataset/" in p:
85
  return "test"
86
 
87
  return None
88
 
89
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
90
  def get_metadata_size(meta_str: Any) -> Optional[Tuple[int, int]]:
91
- """
92
- Your metadata structure says width/height are under maze_config (for non-circle).
93
- Some datasets also duplicate width/height at top-level; we support both.
94
- """
95
- d, err = safe_json_loads(meta_str)
96
- if not d or err:
97
  return None
98
 
99
- mc = d.get("maze_config") if isinstance(d, dict) else None
100
- if isinstance(mc, dict) and ("width" in mc) and ("height" in mc):
101
- try:
102
- return int(mc["width"]), int(mc["height"])
103
- except Exception:
104
- pass
 
105
 
106
- if ("width" in d) and ("height" in d):
107
- try:
108
- return int(d["width"]), int(d["height"])
109
- except Exception:
110
- pass
111
 
112
  return None
113
 
114
 
115
- def filter_df_by_maze_size(df: pd.DataFrame, size_str: Optional[str]) -> pd.DataFrame:
116
- if not size_str or size_str == "All":
117
- return df
118
- try:
119
- a, b = size_str.split("×")
120
- w, h = int(a), int(b)
121
- except Exception:
122
- return df
 
 
 
 
 
 
 
123
  if "metadata" not in df.columns:
124
- return df
 
 
 
 
 
 
 
 
 
125
 
126
- mask = df["metadata"].apply(lambda m: get_metadata_size(m) == (w, h))
 
 
 
 
 
127
  return df.loc[mask].reset_index(drop=True)
128
 
129
 
130
  def summarize_df(df: pd.DataFrame, filtered_len: Optional[int] = None) -> str:
131
- base = f"{len(df)} rows · {len(df.columns)} cols"
132
  if filtered_len is not None and filtered_len != len(df):
133
- base += f" · filtered: {filtered_len}"
134
  return base
135
 
136
 
@@ -156,22 +258,36 @@ def find_index_by_id(df: pd.DataFrame, sample_id: str) -> Optional[int]:
156
  return None
157
 
158
 
 
 
 
 
159
  # -------------------------
160
  # HF repo index + cache
161
  # -------------------------
162
- def build_repo_index() -> List[Dict[str, str]]:
163
  api = HfApi()
164
  files = api.list_repo_files(repo_id=DATASET_REPO_ID, repo_type=REPO_TYPE)
165
 
166
- records: List[Dict[str, str]] = []
167
- for f in files:
168
- if not f.lower().endswith(".parquet"):
169
  continue
170
- shape = infer_shape_from_repo_path(f)
171
- split = infer_split_from_repo_path(f)
172
- if shape and split:
173
- records.append({"repo_path": f, "shape": shape, "split": split})
174
- records.sort(key=lambda r: r["repo_path"])
 
 
 
 
 
 
 
 
 
 
175
  return records
176
 
177
 
@@ -187,14 +303,43 @@ def download_and_load_df(repo_path: str) -> pd.DataFrame:
187
  if local_path in _DF_CACHE:
188
  return _DF_CACHE[local_path]
189
 
190
- wanted_cols = ["id", "instruction", "metadata"] + IMAGE_COLS
191
- df = pd.read_parquet(local_path, columns=[c for c in wanted_cols if c is not None])
192
  _DF_CACHE[local_path] = df
193
  return df
194
 
195
 
196
- def get_repo_paths(records: List[Dict[str, str]], shape: str, split: str) -> List[str]:
197
- out = [r["repo_path"] for r in (records or []) if r["shape"] == shape and r["split"] == split]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
198
  out.sort()
199
  return out
200
 
@@ -204,19 +349,12 @@ def get_repo_paths(records: List[Dict[str, str]], shape: str, split: str) -> Lis
204
  # -------------------------
205
  def render_sample_view(df_filtered: pd.DataFrame, index: int):
206
  if len(df_filtered) == 0:
207
- return (
208
- 0,
209
- gr.update(value="No samples (after filtering)."),
210
- "",
211
- [],
212
- {},
213
- "",
214
- )
215
 
216
  index = max(0, min(int(index), len(df_filtered) - 1))
217
  row = df_filtered.iloc[index]
218
 
219
- sid = str(row.get("id", f"maze_{index}"))
220
  instruction = str(row.get("instruction", ""))
221
 
222
  original = decode_base64_image(row.get("original_img"))
@@ -241,9 +379,9 @@ def render_sample_view(df_filtered: pd.DataFrame, index: int):
241
  (mask, "Mask"),
242
  (cell_map, "Cell map"),
243
  ]
244
- gallery_items = [(img, cap) for (img, cap) in gallery_items if img is not None]
245
 
246
- status_md = f"**Sample** `{sid}` \n**Index** `{index}` / `{len(df_filtered)-1}`"
247
  return index, status_md, instruction, gallery_items, meta_json, meta_raw
248
 
249
 
@@ -252,61 +390,98 @@ def render_sample_view(df_filtered: pd.DataFrame, index: int):
252
  # -------------------------
253
  def init_app():
254
  try:
255
- recs = build_repo_index()
256
- info_html = f"<div id='badges'><span class='badge'>✅ Indexed <b>{DATASET_REPO_ID}</b></span><span class='badge'>{len(recs)} parquet files</span></div>"
257
- return recs, info_html
 
 
 
 
 
258
  except Exception as e:
259
- return [], f"<div id='badges'><span class='badge'>Failed to index: {e}</span></div>"
260
 
261
 
262
- def on_shape_split_change(records: List[Dict[str, str]], shape: str, split: str):
263
- choices = get_repo_paths(records, shape, split)
264
  value = choices[0] if choices else None
265
- tip_html = f"<div id='badges'><span class='badge'>Found <b>{len(choices)}</b> parquet file(s) for <b>{shape}</b> / <b>{split}</b></span></div>"
266
- return gr.Dropdown(choices=choices, value=value), tip_html
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
 
268
 
269
  def get_filtered_df(repo_path: str, size_str: str) -> Tuple[pd.DataFrame, str]:
270
  df = download_and_load_df(repo_path)
271
- filtered = filter_df_by_maze_size(df, size_str)
272
  summary = summarize_df(df, filtered_len=len(filtered))
273
  return filtered, summary
274
 
275
 
276
  def on_select_parquet(repo_path: str, size_str: str):
277
  if not repo_path:
278
- return gr.update(value="<div id='badges'><span class='badge'>No parquet selected</span></div>"), gr.update(maximum=0, value=0)
 
 
 
 
279
 
280
- filtered, summary = get_filtered_df(repo_path, size_str)
 
 
 
281
  max_idx = max(0, len(filtered) - 1)
282
- summary_html = f"<div id='badges'><span class='badge'>{summary}</span></div>"
283
- return gr.update(value=summary_html), gr.update(maximum=max_idx, value=0)
 
 
 
 
284
 
285
 
286
  def on_prev(repo_path: str, index: int, size_str: str):
287
  if not repo_path:
288
- return 0, "No parquet selected.", "", [], {}, ""
289
  filtered, _ = get_filtered_df(repo_path, size_str)
290
  return render_sample_view(filtered, max(0, int(index) - 1))
291
 
292
 
293
  def on_next(repo_path: str, index: int, size_str: str):
294
  if not repo_path:
295
- return 0, "No parquet selected.", "", [], {}, ""
296
  filtered, _ = get_filtered_df(repo_path, size_str)
297
  return render_sample_view(filtered, min(len(filtered) - 1, int(index) + 1))
298
 
299
 
300
  def on_show(repo_path: str, index: int, size_str: str):
301
  if not repo_path:
302
- return 0, "No parquet selected.", "", [], {}, ""
303
  filtered, _ = get_filtered_df(repo_path, size_str)
304
  return render_sample_view(filtered, index)
305
 
306
 
307
  def on_random(repo_path: str, size_str: str):
308
  if not repo_path:
309
- return 0, "No parquet selected.", "", [], {}, ""
310
  filtered, _ = get_filtered_df(repo_path, size_str)
311
  if len(filtered) == 0:
312
  return render_sample_view(filtered, 0)
@@ -315,12 +490,12 @@ def on_random(repo_path: str, size_str: str):
315
 
316
  def on_find_id(repo_path: str, query_id: str, size_str: str):
317
  if not repo_path:
318
- return 0, "No parquet selected.", "", [], {}, ""
319
  filtered, _ = get_filtered_df(repo_path, size_str)
320
  pos = find_index_by_id(filtered, query_id.strip() if isinstance(query_id, str) else "")
321
  if pos is None:
322
  out = list(render_sample_view(filtered, 0))
323
- out[1] = out[1] + f" \n⚠️ id search `{query_id}` not found"
324
  return tuple(out)
325
  return render_sample_view(filtered, pos)
326
 
@@ -329,12 +504,9 @@ def on_find_id(repo_path: str, query_id: str, size_str: str):
329
  # UI (styled)
330
  # -------------------------
331
  CSS = """
332
- /* 使用系统默认字体 */
333
  .gradio-container { font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important; }
334
- /* 全局:页面居中 + 不要铺满 */
335
  .gradio-container { max-width: 1200px !important; margin: 0 auto !important; }
336
 
337
- /* 顶部控制卡片:紧凑、没有大灰底空白 */
338
  #topbar {
339
  padding: 12px 14px;
340
  border-radius: 16px;
@@ -343,14 +515,9 @@ CSS = """
343
  }
344
  #topbar .gr-row { flex-wrap: wrap; gap: 10px; }
345
  #topbar .gr-form { margin-bottom: 0 !important; }
346
-
347
- /* 输入/下拉更紧凑 */
348
  #topbar input, #topbar textarea, #topbar .wrap { border-radius: 12px !important; }
349
-
350
- /* 按钮统一,不要变成右侧巨大菜单 */
351
  #topbar button { height: 42px !important; border-radius: 12px !important; }
352
 
353
- /* badges */
354
  #badges { display: flex; gap: 10px; flex-wrap: wrap; align-items: center; }
355
  .badge {
356
  padding: 6px 10px;
@@ -361,12 +528,9 @@ CSS = """
361
  line-height: 1.2;
362
  }
363
 
364
- /* Index 一行,按钮单独一行并向下留间距 */
365
  #toolbar .gr-row { align-items: end; }
366
  #toolbar-btns { margin-top: 12px; }
367
  #toolbar-btns .gr-row { align-items: end; }
368
-
369
- /* Gallery 更像 viewer */
370
  #viewer { margin-top: 10px; }
371
  """
372
 
@@ -375,49 +539,45 @@ THEME = gr.themes.Soft(
375
  text_size=gr.themes.sizes.text_md,
376
  )
377
 
 
378
  def build_ui():
379
  with gr.Blocks(title="Amaze Viewer", theme=THEME, css=CSS) as demo:
380
  gr.Markdown(
381
  f"""
382
  # Amaze
383
- Dataset: https://huggingface.co/datasets/piekenius123/Amaze
384
-
385
- Amaze is a benchmark for Edting-as-Reasoning task (EAR). It features four maze shapes: circle, hexagon, square, and triangle. Each sample provides: an unmarked maze image (original_img), a maze image with start and end points marked (m_original_img), a blue solution path image (sol_img), a binary path mask (mask_img), a cell segmentation map (cell_map), and metadata (JSON) for describing the maze structure and difficulty.
386
-
387
- The test set covers various sizes from 3×3 to 16×16 (50 samples for each size), while the training set mainly consists of 3×3 mazes (1024 samples), and validation set consists of 3×3 mazes (256 samples).
388
-
389
- Browse samples by **shape / split / maze size**, then view images + metadata.
390
  """
391
  )
392
 
393
  records_state = gr.State([])
394
 
395
- # Top control bar (compact card)
396
  with gr.Column(elem_id="topbar"):
397
  with gr.Row():
398
  parquet_tip = gr.HTML(value="<div id='badges'></div>")
399
  summary_badge = gr.HTML(value="<div id='badges'><span class='badge'>No parquet selected</span></div>")
400
- scan_info = gr.HTML(value="<div id='badges'><span class='badge'>Indexing dataset repo</span></div>")
401
-
402
 
403
  with gr.Row():
 
404
  shape_dd = gr.Dropdown(label="Shape", choices=SHAPES, value="circle", scale=1)
405
  split_dd = gr.Dropdown(label="Split", choices=SPLITS, value="test", scale=1)
406
- size_dd = gr.Dropdown(label="Maze size", choices=MAZE_SIZE_CHOICES, value="All", scale=1)
407
  parquet_dd = gr.Dropdown(label="Parquet", choices=[], value=None, scale=2)
408
-
409
 
410
  with gr.Row(elem_id="toolbar"):
411
  id_query = gr.Textbox(label="Find by id", placeholder="UUID or substring", scale=2)
412
  idx_slider = gr.Slider(label="Index", minimum=0, maximum=0, value=0, step=1, scale=2)
 
413
  with gr.Row():
414
- prev_btn = gr.Button("Prev", variant="secondary", scale=1)
415
- next_btn = gr.Button("Next", variant="secondary", scale=1)
416
- random_btn = gr.Button("🎲 Random", variant="primary", scale=1)
417
- find_btn = gr.Button("🔎 Find", variant="secondary", scale=1)
418
  show_btn = gr.Button("Show", variant="secondary", scale=1)
419
 
420
- # Main viewer layout
421
  with gr.Row(elem_id="viewer"):
422
  with gr.Column(scale=3):
423
  status_md = gr.Markdown(elem_id="status")
@@ -436,42 +596,72 @@ def build_ui():
436
  with gr.Accordion("Metadata (raw)", open=False):
437
  meta_raw = gr.Textbox(lines=10, interactive=False)
438
 
439
- # ---- events ----
440
  demo.load(
441
  fn=init_app,
442
  inputs=None,
443
  outputs=[records_state, scan_info],
444
  ).then(
445
- fn=on_shape_split_change,
446
- inputs=[records_state, shape_dd, split_dd],
447
- outputs=[parquet_dd, parquet_tip],
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  ).then(
449
- fn=lambda p, s: on_select_parquet(p, s) if p else (gr.update(value="<div id='badges'><span class='badge'>No parquet selected</span></div>"), gr.update(maximum=0, value=0)),
450
  inputs=[parquet_dd, size_dd],
451
- outputs=[summary_badge, idx_slider],
452
  ).then(
453
- fn=lambda p, s: on_show(p, 0, s) if p else (0, "No parquet selected.", "", [], {}, ""),
454
  inputs=[parquet_dd, size_dd],
455
  outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
456
  )
457
 
458
  shape_dd.change(
459
- fn=on_shape_split_change,
460
- inputs=[records_state, shape_dd, split_dd],
461
  outputs=[parquet_dd, parquet_tip],
 
 
 
 
 
 
 
 
462
  )
 
463
  split_dd.change(
464
- fn=on_shape_split_change,
465
- inputs=[records_state, shape_dd, split_dd],
466
  outputs=[parquet_dd, parquet_tip],
 
 
 
 
 
 
 
 
467
  )
468
 
469
  parquet_dd.change(
470
  fn=on_select_parquet,
471
  inputs=[parquet_dd, size_dd],
472
- outputs=[summary_badge, idx_slider],
473
  ).then(
474
- fn=lambda p, s: on_show(p, 0, s) if p else (0, "No parquet selected.", "", [], {}, ""),
475
  inputs=[parquet_dd, size_dd],
476
  outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
477
  )
@@ -479,9 +669,9 @@ def build_ui():
479
  size_dd.change(
480
  fn=on_select_parquet,
481
  inputs=[parquet_dd, size_dd],
482
- outputs=[summary_badge, idx_slider],
483
  ).then(
484
- fn=lambda p, s: on_show(p, 0, s) if p else (0, "No parquet selected.", "", [], {}, ""),
485
  inputs=[parquet_dd, size_dd],
486
  outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
487
  )
@@ -522,4 +712,4 @@ def build_ui():
522
 
523
  if __name__ == "__main__":
524
  demo = build_ui()
525
- demo.launch()
 
1
+ import base64
2
  import io
3
  import json
 
4
  import random
5
+ import re
6
+ from typing import Any, Dict, List, Optional, Tuple
7
 
 
 
8
  import gradio as gr
9
+ import pandas as pd
10
  from huggingface_hub import HfApi, hf_hub_download
11
+ from PIL import Image
12
 
13
  DATASET_REPO_ID = "piekenius123/Amaze"
14
  REPO_TYPE = "dataset"
15
 
16
+ TASKS = ["maze", "queen"]
17
+ DEFAULT_TASK = "maze"
18
+ DEFAULT_SIZE_CHOICE = "All"
19
+
20
  SHAPES = ["circle", "hexagon", "square", "triangle"]
21
  SPLITS = ["train", "val", "test"]
22
 
 
 
 
23
  IMAGE_COLS = ["original_img", "m_original_img", "sol_img", "mask_img", "cell_map"]
24
+ SIZE_CONTAINER_KEYS = ["maze_config", "queen_config", "board_config", "config"]
25
+ SIZE_PAIR_KEYS = [
26
+ ("width", "height"),
27
+ ("cols", "rows"),
28
+ ("board_width", "board_height"),
29
+ ("board_cols", "board_rows"),
30
+ ]
31
+ SIZE_SCALAR_KEYS = ["size", "board_size", "n", "board_n"]
32
 
33
 
34
  # -------------------------
 
72
  return None
73
 
74
 
75
+ def normalize_repo_path(path: str) -> str:
76
+ return path.replace("\\", "/").lower()
77
+
78
+
79
+ def get_path_segments(path: str) -> List[str]:
80
+ return [seg for seg in normalize_repo_path(path).split("/") if seg]
81
+
82
+
83
+ def get_path_tokens(path: str) -> List[str]:
84
+ return [tok for tok in re.split(r"[/_.-]+", normalize_repo_path(path)) if tok]
85
+
86
+
87
+ def infer_task_from_repo_path(path: str) -> str:
88
+ segments = get_path_segments(path)
89
+ tokens = get_path_tokens(path)
90
+
91
+ for candidate in TASKS:
92
+ if candidate in segments:
93
+ return candidate
94
+
95
+ for token in tokens:
96
+ if token.startswith("queen"):
97
+ return "queen"
98
+ if token.startswith("maze"):
99
+ return "maze"
100
+
101
+ return DEFAULT_TASK
102
+
103
+
104
  def infer_shape_from_repo_path(path: str) -> Optional[str]:
105
+ segments = get_path_segments(path)
106
+ for shape in SHAPES:
107
+ if shape in segments:
108
+ return shape
109
  return None
110
 
111
 
112
  def infer_split_from_repo_path(path: str) -> Optional[str]:
113
+ p = normalize_repo_path(path)
114
  fn = p.split("/")[-1]
115
+ segments = get_path_segments(path)
116
+
117
+ # Backward compatibility for the original maze repo layout:
118
+ # maze-dataset_train/maze_dataset_test.parquet is the validation split.
119
+ if "/maze-dataset_train/" in p and fn == "maze_dataset_test.parquet":
120
+ return "val"
121
+ if "/maze-dataset/" in p and fn == "maze_dataset_test.parquet":
122
+ return "test"
123
+
124
+ filename_checks = [
125
+ (r"(?:^|[_-])train(?:ing)?(?:[_-]|\.|$)", "train"),
126
+ (r"(?:^|[_-])(?:val|valid|validation)(?:[_-]|\.|$)", "val"),
127
+ (r"(?:^|[_-])test(?:[_-]|\.|$)", "test"),
128
+ ]
129
+ for pattern, split in filename_checks:
130
+ if re.search(pattern, fn):
131
+ return split
132
+
133
+ for seg in reversed(segments[:-1]):
134
+ if seg in {"train", "training"}:
135
+ return "train"
136
+ if seg in {"val", "valid", "validation"}:
137
  return "val"
138
+ if seg == "test":
139
  return "test"
140
 
141
  return None
142
 
143
 
144
+ def iter_metadata_containers(meta: Dict[str, Any]) -> List[Dict[str, Any]]:
145
+ containers = [meta]
146
+ for key in SIZE_CONTAINER_KEYS:
147
+ value = meta.get(key)
148
+ if isinstance(value, dict):
149
+ containers.append(value)
150
+ return containers
151
+
152
+
153
+ def parse_square_size(value: Any) -> Optional[Tuple[int, int]]:
154
+ if isinstance(value, bool):
155
+ return None
156
+ if isinstance(value, int):
157
+ return value, value
158
+ if isinstance(value, float):
159
+ if pd.isna(value):
160
+ return None
161
+ iv = int(value)
162
+ return iv, iv
163
+ if isinstance(value, str):
164
+ text = value.strip().lower()
165
+ match = re.fullmatch(r"(\d+)\s*[x×]\s*(\d+)", text)
166
+ if match:
167
+ return int(match.group(1)), int(match.group(2))
168
+ if text.isdigit():
169
+ iv = int(text)
170
+ return iv, iv
171
+ return None
172
+
173
+
174
  def get_metadata_size(meta_str: Any) -> Optional[Tuple[int, int]]:
175
+ meta, err = safe_json_loads(meta_str)
176
+ if not meta or err:
 
 
 
 
177
  return None
178
 
179
+ for container in iter_metadata_containers(meta):
180
+ for width_key, height_key in SIZE_PAIR_KEYS:
181
+ if width_key in container and height_key in container:
182
+ try:
183
+ return int(container[width_key]), int(container[height_key])
184
+ except Exception:
185
+ pass
186
 
187
+ for scalar_key in SIZE_SCALAR_KEYS:
188
+ if scalar_key in container:
189
+ parsed = parse_square_size(container[scalar_key])
190
+ if parsed:
191
+ return parsed
192
 
193
  return None
194
 
195
 
196
+ def format_size_choice(size: Tuple[int, int]) -> str:
197
+ return f"{size[0]}x{size[1]}"
198
+
199
+
200
+ def parse_size_choice(size_str: Optional[str]) -> Optional[Tuple[int, int]]:
201
+ if not size_str or size_str == DEFAULT_SIZE_CHOICE:
202
+ return None
203
+
204
+ match = re.fullmatch(r"\s*(\d+)\s*[x×]\s*(\d+)\s*", size_str)
205
+ if not match:
206
+ return None
207
+ return int(match.group(1)), int(match.group(2))
208
+
209
+
210
+ def get_size_choices(df: pd.DataFrame) -> List[str]:
211
  if "metadata" not in df.columns:
212
+ return [DEFAULT_SIZE_CHOICE]
213
+
214
+ sizes = {
215
+ size
216
+ for size in df["metadata"].map(get_metadata_size)
217
+ if size is not None
218
+ }
219
+ ordered_sizes = sorted(sizes, key=lambda x: (x[0] * x[1], x[0], x[1]))
220
+ return [DEFAULT_SIZE_CHOICE] + [format_size_choice(size) for size in ordered_sizes]
221
+
222
 
223
+ def filter_df_by_size(df: pd.DataFrame, size_str: Optional[str]) -> pd.DataFrame:
224
+ target_size = parse_size_choice(size_str)
225
+ if target_size is None or "metadata" not in df.columns:
226
+ return df.reset_index(drop=True)
227
+
228
+ mask = df["metadata"].map(lambda meta: get_metadata_size(meta) == target_size)
229
  return df.loc[mask].reset_index(drop=True)
230
 
231
 
232
  def summarize_df(df: pd.DataFrame, filtered_len: Optional[int] = None) -> str:
233
+ base = f"{len(df)} rows | {len(df.columns)} cols"
234
  if filtered_len is not None and filtered_len != len(df):
235
+ base += f" | filtered: {filtered_len}"
236
  return base
237
 
238
 
 
258
  return None
259
 
260
 
261
+ def empty_sample_view(message: str = "No parquet selected."):
262
+ return 0, message, "", [], {}, ""
263
+
264
+
265
  # -------------------------
266
  # HF repo index + cache
267
  # -------------------------
268
+ def build_repo_index() -> List[Dict[str, Optional[str]]]:
269
  api = HfApi()
270
  files = api.list_repo_files(repo_id=DATASET_REPO_ID, repo_type=REPO_TYPE)
271
 
272
+ records: List[Dict[str, Optional[str]]] = []
273
+ for repo_path in files:
274
+ if not repo_path.lower().endswith(".parquet"):
275
  continue
276
+
277
+ task = infer_task_from_repo_path(repo_path)
278
+ shape = infer_shape_from_repo_path(repo_path)
279
+ split = infer_split_from_repo_path(repo_path)
280
+ if split:
281
+ records.append(
282
+ {
283
+ "repo_path": repo_path,
284
+ "task": task,
285
+ "shape": shape,
286
+ "split": split,
287
+ }
288
+ )
289
+
290
+ records.sort(key=lambda record: (record["task"] or "", record["shape"] or "", record["split"] or "", record["repo_path"] or ""))
291
  return records
292
 
293
 
 
303
  if local_path in _DF_CACHE:
304
  return _DF_CACHE[local_path]
305
 
306
+ df = pd.read_parquet(local_path)
 
307
  _DF_CACHE[local_path] = df
308
  return df
309
 
310
 
311
+ def get_shape_choices(records: List[Dict[str, Optional[str]]], task: str) -> List[str]:
312
+ shapes = sorted(
313
+ {record["shape"] for record in (records or []) if record.get("task") == task and record.get("shape")},
314
+ key=lambda shape: SHAPES.index(shape) if shape in SHAPES else len(SHAPES),
315
+ )
316
+ if shapes:
317
+ return shapes
318
+ if task == "maze":
319
+ return SHAPES.copy()
320
+ return []
321
+
322
+
323
+ def get_default_shape(task: str, choices: List[str]) -> str:
324
+ if not choices:
325
+ return "All"
326
+ if task == "maze" and "circle" in choices:
327
+ return "circle"
328
+ return choices[0]
329
+
330
+
331
+ def get_repo_paths(records: List[Dict[str, Optional[str]]], task: str, shape: str, split: str) -> List[str]:
332
+ out: List[str] = []
333
+ for record in records or []:
334
+ if record.get("task") != task:
335
+ continue
336
+ if record.get("split") != split:
337
+ continue
338
+ record_shape = record.get("shape")
339
+ if shape and shape != "All" and record_shape != shape:
340
+ continue
341
+ out.append(str(record["repo_path"]))
342
+
343
  out.sort()
344
  return out
345
 
 
349
  # -------------------------
350
  def render_sample_view(df_filtered: pd.DataFrame, index: int):
351
  if len(df_filtered) == 0:
352
+ return empty_sample_view("No samples after filtering.")
 
 
 
 
 
 
 
353
 
354
  index = max(0, min(int(index), len(df_filtered) - 1))
355
  row = df_filtered.iloc[index]
356
 
357
+ sid = str(row.get("id", f"sample_{index}"))
358
  instruction = str(row.get("instruction", ""))
359
 
360
  original = decode_base64_image(row.get("original_img"))
 
379
  (mask, "Mask"),
380
  (cell_map, "Cell map"),
381
  ]
382
+ gallery_items = [(img, caption) for (img, caption) in gallery_items if img is not None]
383
 
384
+ status_md = f"**Sample** `{sid}` \n**Index** `{index}` / `{len(df_filtered) - 1}`"
385
  return index, status_md, instruction, gallery_items, meta_json, meta_raw
386
 
387
 
 
390
  # -------------------------
391
  def init_app():
392
  try:
393
+ records = build_repo_index()
394
+ info_html = (
395
+ "<div id='badges'>"
396
+ f"<span class='badge'>Indexed <b>{DATASET_REPO_ID}</b></span>"
397
+ f"<span class='badge'>{len(records)} parquet files</span>"
398
+ "</div>"
399
+ )
400
+ return records, info_html
401
  except Exception as e:
402
+ return [], f"<div id='badges'><span class='badge'>Failed to index: {e}</span></div>"
403
 
404
 
405
+ def build_parquet_dropdown(records: List[Dict[str, Optional[str]]], task: str, shape: str, split: str):
406
+ choices = get_repo_paths(records, task, shape, split)
407
  value = choices[0] if choices else None
408
+ scope = f"{task} / {split}" if shape == "All" else f"{task} / {shape} / {split}"
409
+ tip_html = (
410
+ "<div id='badges'>"
411
+ f"<span class='badge'>Found <b>{len(choices)}</b> parquet file(s) for <b>{scope}</b></span>"
412
+ "</div>"
413
+ )
414
+ return gr.update(choices=choices, value=value), tip_html
415
+
416
+
417
+ def on_task_change(records: List[Dict[str, Optional[str]]], task: str, split: str):
418
+ shape_choices = get_shape_choices(records, task)
419
+ shape_visible = task == "maze" or bool(shape_choices)
420
+ if not shape_choices:
421
+ shape_choices = ["All"]
422
+ shape_value = get_default_shape(task, shape_choices)
423
+
424
+ parquet_update, tip_html = build_parquet_dropdown(records, task, shape_value, split)
425
+ shape_update = gr.update(choices=shape_choices, value=shape_value, visible=shape_visible)
426
+ return shape_update, parquet_update, tip_html
427
+
428
+
429
+ def on_task_shape_split_change(records: List[Dict[str, Optional[str]]], task: str, shape: str, split: str):
430
+ return build_parquet_dropdown(records, task, shape or "All", split)
431
 
432
 
433
  def get_filtered_df(repo_path: str, size_str: str) -> Tuple[pd.DataFrame, str]:
434
  df = download_and_load_df(repo_path)
435
+ filtered = filter_df_by_size(df, size_str)
436
  summary = summarize_df(df, filtered_len=len(filtered))
437
  return filtered, summary
438
 
439
 
440
  def on_select_parquet(repo_path: str, size_str: str):
441
  if not repo_path:
442
+ return (
443
+ gr.update(value="<div id='badges'><span class='badge'>No parquet selected</span></div>"),
444
+ gr.update(maximum=0, value=0),
445
+ gr.update(choices=[DEFAULT_SIZE_CHOICE], value=DEFAULT_SIZE_CHOICE),
446
+ )
447
 
448
+ df = download_and_load_df(repo_path)
449
+ size_choices = get_size_choices(df)
450
+ size_value = size_str if size_str in size_choices else DEFAULT_SIZE_CHOICE
451
+ filtered = filter_df_by_size(df, size_value)
452
  max_idx = max(0, len(filtered) - 1)
453
+ summary_html = f"<div id='badges'><span class='badge'>{summarize_df(df, filtered_len=len(filtered))}</span></div>"
454
+ return (
455
+ gr.update(value=summary_html),
456
+ gr.update(maximum=max_idx, value=0),
457
+ gr.update(choices=size_choices, value=size_value),
458
+ )
459
 
460
 
461
  def on_prev(repo_path: str, index: int, size_str: str):
462
  if not repo_path:
463
+ return empty_sample_view()
464
  filtered, _ = get_filtered_df(repo_path, size_str)
465
  return render_sample_view(filtered, max(0, int(index) - 1))
466
 
467
 
468
  def on_next(repo_path: str, index: int, size_str: str):
469
  if not repo_path:
470
+ return empty_sample_view()
471
  filtered, _ = get_filtered_df(repo_path, size_str)
472
  return render_sample_view(filtered, min(len(filtered) - 1, int(index) + 1))
473
 
474
 
475
  def on_show(repo_path: str, index: int, size_str: str):
476
  if not repo_path:
477
+ return empty_sample_view()
478
  filtered, _ = get_filtered_df(repo_path, size_str)
479
  return render_sample_view(filtered, index)
480
 
481
 
482
  def on_random(repo_path: str, size_str: str):
483
  if not repo_path:
484
+ return empty_sample_view()
485
  filtered, _ = get_filtered_df(repo_path, size_str)
486
  if len(filtered) == 0:
487
  return render_sample_view(filtered, 0)
 
490
 
491
  def on_find_id(repo_path: str, query_id: str, size_str: str):
492
  if not repo_path:
493
+ return empty_sample_view()
494
  filtered, _ = get_filtered_df(repo_path, size_str)
495
  pos = find_index_by_id(filtered, query_id.strip() if isinstance(query_id, str) else "")
496
  if pos is None:
497
  out = list(render_sample_view(filtered, 0))
498
+ out[1] = out[1] + f" \nID search `{query_id}` not found"
499
  return tuple(out)
500
  return render_sample_view(filtered, pos)
501
 
 
504
  # UI (styled)
505
  # -------------------------
506
  CSS = """
 
507
  .gradio-container { font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important; }
 
508
  .gradio-container { max-width: 1200px !important; margin: 0 auto !important; }
509
 
 
510
  #topbar {
511
  padding: 12px 14px;
512
  border-radius: 16px;
 
515
  }
516
  #topbar .gr-row { flex-wrap: wrap; gap: 10px; }
517
  #topbar .gr-form { margin-bottom: 0 !important; }
 
 
518
  #topbar input, #topbar textarea, #topbar .wrap { border-radius: 12px !important; }
 
 
519
  #topbar button { height: 42px !important; border-radius: 12px !important; }
520
 
 
521
  #badges { display: flex; gap: 10px; flex-wrap: wrap; align-items: center; }
522
  .badge {
523
  padding: 6px 10px;
 
528
  line-height: 1.2;
529
  }
530
 
 
531
  #toolbar .gr-row { align-items: end; }
532
  #toolbar-btns { margin-top: 12px; }
533
  #toolbar-btns .gr-row { align-items: end; }
 
 
534
  #viewer { margin-top: 10px; }
535
  """
536
 
 
539
  text_size=gr.themes.sizes.text_md,
540
  )
541
 
542
+
543
  def build_ui():
544
  with gr.Blocks(title="Amaze Viewer", theme=THEME, css=CSS) as demo:
545
  gr.Markdown(
546
  f"""
547
  # Amaze
548
+ Dataset: https://huggingface.co/datasets/piekenius123/Amaze
549
+
550
+ Browse samples by **task / shape / split / size**, then inspect the images and metadata.
551
+ Maze and Queen share the same viewer so the visualization panel stays unchanged.
 
 
 
552
  """
553
  )
554
 
555
  records_state = gr.State([])
556
 
 
557
  with gr.Column(elem_id="topbar"):
558
  with gr.Row():
559
  parquet_tip = gr.HTML(value="<div id='badges'></div>")
560
  summary_badge = gr.HTML(value="<div id='badges'><span class='badge'>No parquet selected</span></div>")
561
+ scan_info = gr.HTML(value="<div id='badges'><span class='badge'>Indexing dataset repo...</span></div>")
 
562
 
563
  with gr.Row():
564
+ task_dd = gr.Dropdown(label="Task", choices=TASKS, value=DEFAULT_TASK, scale=1)
565
  shape_dd = gr.Dropdown(label="Shape", choices=SHAPES, value="circle", scale=1)
566
  split_dd = gr.Dropdown(label="Split", choices=SPLITS, value="test", scale=1)
567
+ size_dd = gr.Dropdown(label="Size", choices=[DEFAULT_SIZE_CHOICE], value=DEFAULT_SIZE_CHOICE, scale=1)
568
  parquet_dd = gr.Dropdown(label="Parquet", choices=[], value=None, scale=2)
 
569
 
570
  with gr.Row(elem_id="toolbar"):
571
  id_query = gr.Textbox(label="Find by id", placeholder="UUID or substring", scale=2)
572
  idx_slider = gr.Slider(label="Index", minimum=0, maximum=0, value=0, step=1, scale=2)
573
+
574
  with gr.Row():
575
+ prev_btn = gr.Button("Prev", variant="secondary", scale=1)
576
+ next_btn = gr.Button("Next", variant="secondary", scale=1)
577
+ random_btn = gr.Button("Random", variant="primary", scale=1)
578
+ find_btn = gr.Button("Find", variant="secondary", scale=1)
579
  show_btn = gr.Button("Show", variant="secondary", scale=1)
580
 
 
581
  with gr.Row(elem_id="viewer"):
582
  with gr.Column(scale=3):
583
  status_md = gr.Markdown(elem_id="status")
 
596
  with gr.Accordion("Metadata (raw)", open=False):
597
  meta_raw = gr.Textbox(lines=10, interactive=False)
598
 
 
599
  demo.load(
600
  fn=init_app,
601
  inputs=None,
602
  outputs=[records_state, scan_info],
603
  ).then(
604
+ fn=on_task_change,
605
+ inputs=[records_state, task_dd, split_dd],
606
+ outputs=[shape_dd, parquet_dd, parquet_tip],
607
+ ).then(
608
+ fn=on_select_parquet,
609
+ inputs=[parquet_dd, size_dd],
610
+ outputs=[summary_badge, idx_slider, size_dd],
611
+ ).then(
612
+ fn=lambda p, s: on_show(p, 0, s) if p else empty_sample_view(),
613
+ inputs=[parquet_dd, size_dd],
614
+ outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
615
+ )
616
+
617
+ task_dd.change(
618
+ fn=on_task_change,
619
+ inputs=[records_state, task_dd, split_dd],
620
+ outputs=[shape_dd, parquet_dd, parquet_tip],
621
  ).then(
622
+ fn=on_select_parquet,
623
  inputs=[parquet_dd, size_dd],
624
+ outputs=[summary_badge, idx_slider, size_dd],
625
  ).then(
626
+ fn=lambda p, s: on_show(p, 0, s) if p else empty_sample_view(),
627
  inputs=[parquet_dd, size_dd],
628
  outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
629
  )
630
 
631
  shape_dd.change(
632
+ fn=on_task_shape_split_change,
633
+ inputs=[records_state, task_dd, shape_dd, split_dd],
634
  outputs=[parquet_dd, parquet_tip],
635
+ ).then(
636
+ fn=on_select_parquet,
637
+ inputs=[parquet_dd, size_dd],
638
+ outputs=[summary_badge, idx_slider, size_dd],
639
+ ).then(
640
+ fn=lambda p, s: on_show(p, 0, s) if p else empty_sample_view(),
641
+ inputs=[parquet_dd, size_dd],
642
+ outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
643
  )
644
+
645
  split_dd.change(
646
+ fn=on_task_shape_split_change,
647
+ inputs=[records_state, task_dd, shape_dd, split_dd],
648
  outputs=[parquet_dd, parquet_tip],
649
+ ).then(
650
+ fn=on_select_parquet,
651
+ inputs=[parquet_dd, size_dd],
652
+ outputs=[summary_badge, idx_slider, size_dd],
653
+ ).then(
654
+ fn=lambda p, s: on_show(p, 0, s) if p else empty_sample_view(),
655
+ inputs=[parquet_dd, size_dd],
656
+ outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
657
  )
658
 
659
  parquet_dd.change(
660
  fn=on_select_parquet,
661
  inputs=[parquet_dd, size_dd],
662
+ outputs=[summary_badge, idx_slider, size_dd],
663
  ).then(
664
+ fn=lambda p, s: on_show(p, 0, s) if p else empty_sample_view(),
665
  inputs=[parquet_dd, size_dd],
666
  outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
667
  )
 
669
  size_dd.change(
670
  fn=on_select_parquet,
671
  inputs=[parquet_dd, size_dd],
672
+ outputs=[summary_badge, idx_slider, size_dd],
673
  ).then(
674
+ fn=lambda p, s: on_show(p, 0, s) if p else empty_sample_view(),
675
  inputs=[parquet_dd, size_dd],
676
  outputs=[idx_slider, status_md, instruction, gallery, meta_json, meta_raw],
677
  )
 
712
 
713
  if __name__ == "__main__":
714
  demo = build_ui()
715
+ demo.launch()