piekenius123 commited on
Commit
eaf430a
·
verified ·
1 Parent(s): 553a2bb

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +438 -0
app.py ADDED
@@ -0,0 +1,438 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import io
3
+ import json
4
+ import base64
5
+ import random
6
+ from typing import Optional, Dict, Any, List, Tuple
7
+
8
+ import pandas as pd
9
+ from PIL import Image
10
+ import gradio as gr
11
+
12
+ from huggingface_hub import HfApi, hf_hub_download
13
+
14
+ # =========================
15
+ # Hugging Face Dataset Repo
16
+ # =========================
17
+ DATASET_REPO_ID = "piekenius123/Amaze" # your dataset
18
+ REPO_TYPE = "dataset"
19
+
20
+ SHAPES = ["circle", "hexagon", "square", "triangle"]
21
+ SPLITS = ["train", "val", "test"]
22
+
23
+ IMAGE_COLS = ["original_img", "m_original_img", "sol_img", "mask_img", "cell_map"]
24
+
25
+
26
+ # =========================
27
+ # Helpers
28
+ # =========================
29
+ def infer_shape_from_repo_path(path: str) -> Optional[str]:
30
+ p = path.replace("\\", "/").lower()
31
+ for s in SHAPES:
32
+ if p.startswith(f"{s}/") or f"/{s}/" in p:
33
+ return s
34
+ return None
35
+
36
+
37
+ def infer_split_from_repo_path(path: str) -> Optional[str]:
38
+ """
39
+ Rules (based on your dataset description):
40
+ - .../maze_dataset_train.parquet => train
41
+ - .../maze_dataset_test.parquet:
42
+ * if under .../maze-dataset_train/ => val
43
+ * else if under .../maze-dataset/ => test
44
+ """
45
+ p = path.replace("\\", "/").lower()
46
+ fn = p.split("/")[-1]
47
+
48
+ if fn == "maze_dataset_train.parquet":
49
+ return "train"
50
+
51
+ if fn == "maze_dataset_test.parquet":
52
+ if "/maze-dataset_train/" in p:
53
+ return "val"
54
+ if "/maze-dataset/" in p:
55
+ return "test"
56
+
57
+ return None
58
+
59
+
60
+ def decode_base64_image(base64_str: Any) -> Optional[Image.Image]:
61
+ if base64_str is None:
62
+ return None
63
+ if isinstance(base64_str, float) and pd.isna(base64_str):
64
+ return None
65
+ if isinstance(base64_str, str) and (base64_str.strip() == "" or base64_str.strip().lower() == "null"):
66
+ return None
67
+ if not isinstance(base64_str, str):
68
+ return None
69
+
70
+ s = base64_str.strip()
71
+ try:
72
+ # Remove data URL prefix if present
73
+ if s.startswith("data:"):
74
+ s = s.split(",", 1)[1]
75
+ img_bytes = base64.b64decode(s)
76
+ img = Image.open(io.BytesIO(img_bytes))
77
+ img.load()
78
+ return img
79
+ except Exception:
80
+ return None
81
+
82
+
83
+ def safe_json_loads(s: Any) -> Tuple[Optional[Dict[str, Any]], Optional[str]]:
84
+ if s is None:
85
+ return None, None
86
+ if isinstance(s, float) and pd.isna(s):
87
+ return None, None
88
+ if not isinstance(s, str):
89
+ return None, f"metadata is not a string, got type={type(s)}"
90
+ ss = s.strip()
91
+ if ss == "" or ss.lower() == "null":
92
+ return None, None
93
+ try:
94
+ return json.loads(ss), None
95
+ except Exception as e:
96
+ return None, str(e)
97
+
98
+
99
+ def summarize_df(df: pd.DataFrame) -> str:
100
+ cols = list(df.columns)
101
+ return f"Rows: {len(df)}\nCols: {len(cols)}\nColumns: {', '.join(cols)}"
102
+
103
+
104
+ def row_to_kv_table(row: pd.Series) -> pd.DataFrame:
105
+ records = []
106
+ for k, v in row.items():
107
+ if k in IMAGE_COLS:
108
+ records.append((k, f"<base64 image str> len={len(v) if isinstance(v, str) else 'NA'}"))
109
+ elif k == "metadata":
110
+ records.append((k, f"<json str> len={len(v) if isinstance(v, str) else 'NA'}"))
111
+ else:
112
+ if isinstance(v, str) and len(v) > 500:
113
+ vv = v[:500] + " ... (truncated)"
114
+ else:
115
+ vv = v
116
+ records.append((k, vv))
117
+ return pd.DataFrame(records, columns=["field", "value"])
118
+
119
+
120
+ def render_sample(df: pd.DataFrame, index: int):
121
+ if len(df) == 0:
122
+ return (
123
+ 0, "Empty dataframe.", "",
124
+ None, None, None, None, None,
125
+ {}, "", pd.DataFrame(columns=["field", "value"])
126
+ )
127
+
128
+ index = max(0, min(int(index), len(df) - 1))
129
+ row = df.iloc[index]
130
+
131
+ sample_id = str(row.get("id", f"maze_{index}"))
132
+ instruction = str(row.get("instruction", ""))
133
+
134
+ imgs = {col: decode_base64_image(row.get(col, None)) for col in IMAGE_COLS}
135
+
136
+ meta_dict, meta_err = safe_json_loads(row.get("metadata", None))
137
+ meta_raw = row.get("metadata", "")
138
+ meta_json = {"_parse_error": meta_err} if meta_err else (meta_dict if meta_dict is not None else {})
139
+
140
+ kv_df = row_to_kv_table(row)
141
+
142
+ status = f"Index: {index} / {len(df)-1} | id: {sample_id}"
143
+ return (
144
+ index,
145
+ status,
146
+ instruction,
147
+ imgs["original_img"],
148
+ imgs["m_original_img"],
149
+ imgs["sol_img"],
150
+ imgs["mask_img"],
151
+ imgs["cell_map"],
152
+ meta_json,
153
+ meta_raw if isinstance(meta_raw, str) else str(meta_raw),
154
+ kv_df,
155
+ )
156
+
157
+
158
+ def find_index_by_id(df: pd.DataFrame, sample_id: str) -> Optional[int]:
159
+ if "id" not in df.columns or not sample_id:
160
+ return None
161
+
162
+ # exact match
163
+ try:
164
+ mask = df["id"] == sample_id
165
+ if mask.any():
166
+ return int(df.index[mask][0]) if not isinstance(df.index, pd.RangeIndex) else int(mask.idxmax())
167
+ except Exception:
168
+ pass
169
+
170
+ # substring match
171
+ try:
172
+ mask = df["id"].astype(str).str.contains(sample_id, na=False)
173
+ if mask.any():
174
+ # return first match position
175
+ pos = df[mask].index[0]
176
+ # convert label to positional index
177
+ return int(df.index.get_loc(pos))
178
+ except Exception:
179
+ pass
180
+
181
+ return None
182
+
183
+
184
+ # =========================
185
+ # HF repo indexing + caching
186
+ # =========================
187
+ def build_repo_index() -> List[Dict[str, str]]:
188
+ """
189
+ List all files in dataset repo, keep parquet only, infer shape/split.
190
+ """
191
+ api = HfApi()
192
+ files = api.list_repo_files(repo_id=DATASET_REPO_ID, repo_type=REPO_TYPE)
193
+ # list_repo_files is part of HfApi; repo_type supports "dataset". :contentReference[oaicite:3]{index=3}
194
+ records: List[Dict[str, str]] = []
195
+ for f in files:
196
+ if not f.lower().endswith(".parquet"):
197
+ continue
198
+ shape = infer_shape_from_repo_path(f)
199
+ split = infer_split_from_repo_path(f)
200
+ if shape and split:
201
+ records.append({"repo_path": f, "shape": shape, "split": split})
202
+ records.sort(key=lambda r: r["repo_path"])
203
+ return records
204
+
205
+
206
+ # cache dataframes per local downloaded file path
207
+ _DF_CACHE: Dict[str, pd.DataFrame] = {}
208
+
209
+
210
+ def download_and_load_df(repo_path: str) -> pd.DataFrame:
211
+ """
212
+ Download parquet from dataset repo (cached by hf_hub_download), then read to pandas.
213
+ """
214
+ local_path = hf_hub_download(
215
+ repo_id=DATASET_REPO_ID,
216
+ repo_type=REPO_TYPE,
217
+ filename=repo_path,
218
+ )
219
+ # hf_hub_download caches files and returns local path; do not modify cached file. :contentReference[oaicite:4]{index=4}
220
+ if local_path in _DF_CACHE:
221
+ return _DF_CACHE[local_path]
222
+ df = pd.read_parquet(local_path)
223
+ _DF_CACHE[local_path] = df
224
+ return df
225
+
226
+
227
+ def get_repo_paths(records: List[Dict[str, str]], shape: str, split: str) -> List[str]:
228
+ out = [r["repo_path"] for r in (records or []) if r["shape"] == shape and r["split"] == split]
229
+ out.sort()
230
+ return out
231
+
232
+
233
+ # =========================
234
+ # Gradio callbacks
235
+ # =========================
236
+ def init_app():
237
+ try:
238
+ recs = build_repo_index()
239
+ info = f"Dataset: {DATASET_REPO_ID}\nParquet files indexed: {len(recs)}"
240
+ return recs, info
241
+ except Exception as e:
242
+ return [], f"Failed to index dataset repo: {e}"
243
+
244
+
245
+ def on_shape_split_change(records: List[Dict[str, str]], shape: str, split: str):
246
+ choices = get_repo_paths(records, shape, split)
247
+ value = choices[0] if choices else None
248
+ tip = f"Matched parquet files: {len(choices)}"
249
+ return gr.Dropdown(choices=choices, value=value), tip
250
+
251
+
252
+ def on_select_parquet(repo_path: str):
253
+ if not repo_path:
254
+ return "No parquet selected.", 0, 0
255
+ df = download_and_load_df(repo_path)
256
+ summary = summarize_df(df)
257
+ max_idx = max(0, len(df) - 1)
258
+ return summary, max_idx, 0
259
+
260
+
261
+ def on_show(repo_path: str, index: int):
262
+ if not repo_path:
263
+ return (
264
+ 0, "No parquet selected.", "",
265
+ None, None, None, None, None,
266
+ {}, "", pd.DataFrame(columns=["field", "value"])
267
+ )
268
+ df = download_and_load_df(repo_path)
269
+ return render_sample(df, index)
270
+
271
+
272
+ def on_random(repo_path: str):
273
+ if not repo_path:
274
+ return on_show(repo_path, 0)
275
+ df = download_and_load_df(repo_path)
276
+ if len(df) == 0:
277
+ return on_show(repo_path, 0)
278
+ idx = random.randint(0, len(df) - 1)
279
+ return render_sample(df, idx)
280
+
281
+
282
+ def on_find_id(repo_path: str, query_id: str):
283
+ if not repo_path:
284
+ return on_show(repo_path, 0)
285
+ df = download_and_load_df(repo_path)
286
+ pos = find_index_by_id(df, query_id.strip() if isinstance(query_id, str) else "")
287
+ if pos is None:
288
+ out = list(render_sample(df, 0))
289
+ out[1] = out[1] + f" | id search '{query_id}' NOT FOUND"
290
+ return tuple(out)
291
+ return render_sample(df, pos)
292
+
293
+
294
+ # =========================
295
+ # UI
296
+ # =========================
297
+ def build_ui():
298
+ with gr.Blocks(title="Amaze Parquet Viewer (HF Dataset)") as demo:
299
+ gr.Markdown(
300
+ "# Amaze Benchmark Parquet Viewer (HF Space)\n"
301
+ f"数据来自 Hugging Face Dataset:`{DATASET_REPO_ID}`。\n\n"
302
+ "选择 **shape / split(train/val/test)** 后,Space 会按需下载对应 parquet 并可视化每条样本。"
303
+ )
304
+
305
+ records_state = gr.State([])
306
+
307
+ scan_info = gr.Textbox(label="Repo index status", interactive=False)
308
+
309
+ with gr.Row():
310
+ shape_dd = gr.Dropdown(label="Shape", choices=SHAPES, value=SHAPES[0])
311
+ split_dd = gr.Dropdown(label="Split", choices=SPLITS, value="test")
312
+
313
+ parquet_tip = gr.Markdown(value="Matched parquet files: 0")
314
+ parquet_dd = gr.Dropdown(label="Matched parquet files (repo path)", choices=[], value=None, interactive=True)
315
+
316
+ with gr.Row():
317
+ file_summary = gr.Textbox(label="Selected parquet summary", interactive=False)
318
+ idx_slider = gr.Slider(label="Row index", minimum=0, maximum=0, value=0, step=1, interactive=True)
319
+
320
+ with gr.Row():
321
+ show_btn = gr.Button("Show")
322
+ random_btn = gr.Button("Random")
323
+ id_query = gr.Textbox(label="Find by id (exact or substring)", placeholder="paste UUID or substring")
324
+ find_btn = gr.Button("Find")
325
+
326
+ status = gr.Textbox(label="Status", interactive=False)
327
+ instruction = gr.Textbox(label="Instruction", lines=4, interactive=False)
328
+
329
+ with gr.Tabs():
330
+ with gr.Tab("Images"):
331
+ with gr.Row():
332
+ original_img = gr.Image(label="original_img", type="pil")
333
+ m_original_img = gr.Image(label="m_original_img", type="pil")
334
+ with gr.Row():
335
+ sol_img = gr.Image(label="sol_img", type="pil")
336
+ mask_img = gr.Image(label="mask_img", type="pil")
337
+ with gr.Row():
338
+ cell_map = gr.Image(label="cell_map", type="pil")
339
+
340
+ with gr.Tab("Metadata"):
341
+ meta_json = gr.JSON(label="metadata (parsed)")
342
+ meta_raw = gr.Textbox(label="metadata (raw)", lines=8, interactive=False)
343
+
344
+ with gr.Tab("Row fields"):
345
+ kv_table = gr.Dataframe(
346
+ label="All fields (base64 summarized)",
347
+ headers=["field", "value"],
348
+ wrap=True,
349
+ interactive=False,
350
+ )
351
+
352
+ # Events
353
+ demo.load(
354
+ fn=init_app,
355
+ inputs=None,
356
+ outputs=[records_state, scan_info],
357
+ ).then(
358
+ fn=on_shape_split_change,
359
+ inputs=[records_state, shape_dd, split_dd],
360
+ outputs=[parquet_dd, parquet_tip],
361
+ ).then(
362
+ fn=lambda p: on_select_parquet(p) if p else ("No parquet selected.", 0, 0),
363
+ inputs=[parquet_dd],
364
+ outputs=[file_summary, idx_slider, idx_slider],
365
+ ).then(
366
+ fn=lambda p: on_show(p, 0) if p else (
367
+ 0, "No parquet selected.", "",
368
+ None, None, None, None, None,
369
+ {}, "", pd.DataFrame(columns=["field", "value"])
370
+ ),
371
+ inputs=[parquet_dd],
372
+ outputs=[
373
+ idx_slider, status, instruction,
374
+ original_img, m_original_img, sol_img, mask_img, cell_map,
375
+ meta_json, meta_raw, kv_table
376
+ ],
377
+ )
378
+
379
+ shape_dd.change(
380
+ fn=on_shape_split_change,
381
+ inputs=[records_state, shape_dd, split_dd],
382
+ outputs=[parquet_dd, parquet_tip],
383
+ )
384
+ split_dd.change(
385
+ fn=on_shape_split_change,
386
+ inputs=[records_state, shape_dd, split_dd],
387
+ outputs=[parquet_dd, parquet_tip],
388
+ )
389
+
390
+ parquet_dd.change(
391
+ fn=on_select_parquet,
392
+ inputs=[parquet_dd],
393
+ outputs=[file_summary, idx_slider, idx_slider],
394
+ )
395
+
396
+ show_btn.click(
397
+ fn=on_show,
398
+ inputs=[parquet_dd, idx_slider],
399
+ outputs=[
400
+ idx_slider, status, instruction,
401
+ original_img, m_original_img, sol_img, mask_img, cell_map,
402
+ meta_json, meta_raw, kv_table
403
+ ],
404
+ )
405
+ idx_slider.release(
406
+ fn=on_show,
407
+ inputs=[parquet_dd, idx_slider],
408
+ outputs=[
409
+ idx_slider, status, instruction,
410
+ original_img, m_original_img, sol_img, mask_img, cell_map,
411
+ meta_json, meta_raw, kv_table
412
+ ],
413
+ )
414
+ random_btn.click(
415
+ fn=on_random,
416
+ inputs=[parquet_dd],
417
+ outputs=[
418
+ idx_slider, status, instruction,
419
+ original_img, m_original_img, sol_img, mask_img, cell_map,
420
+ meta_json, meta_raw, kv_table
421
+ ],
422
+ )
423
+ find_btn.click(
424
+ fn=on_find_id,
425
+ inputs=[parquet_dd, id_query],
426
+ outputs=[
427
+ idx_slider, status, instruction,
428
+ original_img, m_original_img, sol_img, mask_img, cell_map,
429
+ meta_json, meta_raw, kv_table
430
+ ],
431
+ )
432
+
433
+ return demo
434
+
435
+
436
+ if __name__ == "__main__":
437
+ demo = build_ui()
438
+ demo.launch()