# app.py
"""Gradio viewer for the Amaze maze dataset on the Hugging Face Hub.

The app lists the parquet files of ``DATASET_REPO_ID``, lets the user pick a
shape / split / parquet file / maze size, and renders one sample at a time:
its images, instruction text and metadata.
"""

import base64
import html
import io
import json
import random
from typing import Any, Dict, List, Optional, Tuple

import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, hf_hub_download
from PIL import Image

DATASET_REPO_ID = "piekenius123/Amaze"
REPO_TYPE = "dataset"

SHAPES = ["circle", "hexagon", "square", "triangle"]
SPLITS = ["train", "val", "test"]

MAZE_SIZE_MIN, MAZE_SIZE_MAX = 3, 16
MAZE_SIZE_CHOICES = ["All"] + [f"{n}×{n}" for n in range(MAZE_SIZE_MIN, MAZE_SIZE_MAX + 1)]

IMAGE_COLS = ["original_img", "m_original_img", "sol_img", "mask_img", "cell_map"]

# Columns loaded from each parquet file; image columns are decoded with
# decode_base64_image at render time.
WANTED_COLS = ["id", "instruction", "metadata"] + IMAGE_COLS


# -------------------------
# Decode / parse helpers
# -------------------------


def _is_blank(value: Any) -> bool:
    """Return True for None, NaN, and empty / "null" strings (i.e. "no value")."""
    if value is None:
        return True
    if isinstance(value, float) and pd.isna(value):
        return True
    if isinstance(value, str):
        text = value.strip()
        return text == "" or text.lower() == "null"
    return False


def _cell_text(value: Any, default: str = "") -> str:
    """Convert a dataframe cell to display text, mapping None/NaN to *default*."""
    if value is None or (isinstance(value, float) and pd.isna(value)):
        return default
    return str(value)


def _badge(text_html: str) -> str:
    """Wrap already-escaped HTML in a pill styled by the ``.badge`` CSS rule."""
    return f'<span class="badge">{text_html}</span>'


def safe_json_loads(s: Any) -> Tuple[Optional[Any], Optional[str]]:
    """Parse a JSON string without raising.

    Returns ``(value, None)`` on success and ``(None, error_message)`` on
    failure. Blank input (None, NaN, "", "null") yields ``(None, None)``.
    The parsed value is whatever JSON holds (usually a dict, but not enforced).
    """
    if _is_blank(s):
        return None, None
    if not isinstance(s, str):
        return None, f"metadata is not a string, got type={type(s)}"
    try:
        return json.loads(s.strip()), None
    except (ValueError, RecursionError) as e:  # JSONDecodeError subclasses ValueError
        return None, str(e)


def decode_base64_image(base64_str: Any) -> Optional[Image.Image]:
    """Decode a base64 string (optionally a ``data:`` URL) into a PIL image.

    Returns None for blank or non-string input, or when decoding fails.
    """
    if not isinstance(base64_str, str) or _is_blank(base64_str):
        return None
    payload = base64_str.strip()
    if payload.startswith("data:"):
        # "data:image/png;base64,<payload>" -> keep only the payload.
        payload = payload.partition(",")[2]
    try:
        img = Image.open(io.BytesIO(base64.b64decode(payload)))
        img.load()  # Image.open is lazy; force decoding so failures surface here
    except Exception:
        # Best effort: any undecodable payload is shown as "missing".
        return None
    return img


def infer_shape_from_repo_path(path: str) -> Optional[str]:
    """Return the maze shape named by a directory component of *path*, if any."""
    p = path.replace("\\", "/").lower()
    for s in SHAPES:
        if p.startswith(f"{s}/") or f"/{s}/" in p:
            return s
    return None


def infer_split_from_repo_path(path: str) -> Optional[str]:
    """Infer the split ("train" / "val" / "test") from a repo file path.

    Train files are named ``maze_dataset_train.parquet``. ``maze_dataset_test.parquet``
    is the validation split when it lives under ``maze-dataset_train/`` and the
    test split when it lives under ``maze-dataset/``.
    """
    p = path.replace("\\", "/").lower()
    fn = p.split("/")[-1]
    if fn == "maze_dataset_train.parquet":
        return "train"
    if fn == "maze_dataset_test.parquet":
        if "/maze-dataset_train/" in p:
            return "val"
        if "/maze-dataset/" in p:
            return "test"
    return None


def get_metadata_size(meta_str: Any) -> Optional[Tuple[int, int]]:
    """Return ``(width, height)`` from a sample's JSON metadata, or None.

    Width/height are looked up under ``maze_config`` first and then at the top
    level. Metadata that is missing, malformed, or not a JSON object yields None.
    """
    d, err = safe_json_loads(meta_str)
    if err or not isinstance(d, dict):
        # Non-object JSON (e.g. "5") used to raise TypeError on `"width" in d`.
        return None
    for source in (d.get("maze_config"), d):
        if isinstance(source, dict) and "width" in source and "height" in source:
            try:
                return int(source["width"]), int(source["height"])
            except (TypeError, ValueError, OverflowError):
                pass
    return None


def filter_df_by_maze_size(df: pd.DataFrame, size_str: Optional[str]) -> pd.DataFrame:
    """Keep only rows whose metadata size equals *size_str* (e.g. ``"5×5"``).

    "All", empty, or unparseable size strings return *df* unchanged.
    """
    if not size_str or size_str == "All":
        return df
    try:
        a, b = size_str.split("×")
        w, h = int(a), int(b)
    except (AttributeError, ValueError):
        return df

    if "metadata" not in df.columns:
        return df

    mask = df["metadata"].map(lambda m: get_metadata_size(m) == (w, h)).astype(bool)
    return df.loc[mask].reset_index(drop=True)


def summarize_df(df: pd.DataFrame, filtered_len: Optional[int] = None) -> str:
    """Return a short "rows · cols [· filtered]" summary of *df*."""
    base = f"{len(df)} rows · {len(df.columns)} cols"
    if filtered_len is not None and filtered_len != len(df):
        base += f" · filtered: {filtered_len}"
    return base


def find_index_by_id(df: pd.DataFrame, sample_id: str) -> Optional[int]:
    """Return the position of the first row whose ``id`` matches *sample_id*.

    An exact match wins; otherwise the first id containing *sample_id* as a
    literal substring is used. Returns None when nothing matches or *df* has
    no ``id`` column.
    """
    if "id" not in df.columns or not sample_id:
        return None
    ids = df["id"]

    try:
        exact = (ids == sample_id).fillna(False).to_numpy(dtype=bool)
    except (TypeError, ValueError):
        exact = None
    if exact is not None and exact.any():
        # Positional argmax is correct for any index, not just a 0-based RangeIndex.
        return int(exact.argmax())

    # regex=False: the query is user text (UUID or fragment), not a pattern.
    partial = ids.astype(str).str.contains(sample_id, regex=False, na=False).to_numpy(dtype=bool)
    if partial.any():
        return int(partial.argmax())
    return None


# -------------------------
# HF repo index + cache
# -------------------------


def build_repo_index() -> List[Dict[str, str]]:
    """List the repo's parquet files whose shape and split can be inferred, sorted by path."""
    files = HfApi().list_repo_files(repo_id=DATASET_REPO_ID, repo_type=REPO_TYPE)
    records: List[Dict[str, str]] = []
    for f in files:
        if not f.lower().endswith(".parquet"):
            continue
        shape = infer_shape_from_repo_path(f)
        split = infer_split_from_repo_path(f)
        if shape and split:
            records.append({"repo_path": f, "shape": shape, "split": split})
    records.sort(key=lambda r: r["repo_path"])
    return records


# Loaded parquet files, keyed by repo path. Keyed by repo path (not local path)
# so cache hits skip hf_hub_download's network round-trip; a repo update is
# therefore only picked up after an app restart.
_DF_CACHE: Dict[str, pd.DataFrame] = {}

# Filtered frames + summaries, keyed by (repo_path, size_str), so navigation
# clicks do not re-parse every row's metadata JSON.
_FILTERED_CACHE: Dict[Tuple[str, Optional[str]], Tuple[pd.DataFrame, str]] = {}


def download_and_load_df(repo_path: str) -> pd.DataFrame:
    """Download (once) and load the viewer columns of a parquet file from the repo."""
    cached = _DF_CACHE.get(repo_path)
    if cached is not None:
        return cached

    local_path = hf_hub_download(
        repo_id=DATASET_REPO_ID,
        repo_type=REPO_TYPE,
        filename=repo_path,
    )
    try:
        df = pd.read_parquet(local_path, columns=WANTED_COLS)
    except (KeyError, ValueError):
        # Some file lacks one of the wanted columns: load everything, keep what exists.
        df = pd.read_parquet(local_path)
        df = df[[c for c in WANTED_COLS if c in df.columns]]
    _DF_CACHE[repo_path] = df
    return df


def get_repo_paths(records: List[Dict[str, str]], shape: str, split: str) -> List[str]:
    """Return the sorted repo paths indexed for *shape* / *split*."""
    return sorted(
        r["repo_path"] for r in (records or []) if r["shape"] == shape and r["split"] == split
    )


# -------------------------
# Rendering
# -------------------------


def render_sample_view(df_filtered: pd.DataFrame, index: int):
    """Render one sample for the six viewer outputs.

    *index* is clamped into range. Returns ``(index, status_md, instruction,
    gallery_items, meta_json, meta_raw)``; ``gallery_items`` holds
    ``(image, caption)`` pairs for the images that decoded successfully.
    """
    if len(df_filtered) == 0:
        return (
            0,
            gr.update(value="No samples (after filtering)."),
            "",
            [],
            {},
            "",
        )

    index = max(0, min(int(index), len(df_filtered) - 1))
    row = df_filtered.iloc[index]

    sid = _cell_text(row.get("id"), default=f"maze_{index}")
    instruction = _cell_text(row.get("instruction"))

    original = decode_base64_image(row.get("original_img"))
    marked = decode_base64_image(row.get("m_original_img"))
    if marked is None:
        marked = original  # keep the first gallery slot filled
    cell_map = decode_base64_image(row.get("cell_map"))
    mask = decode_base64_image(row.get("mask_img"))
    sol = decode_base64_image(row.get("sol_img"))

    meta_dict, meta_err = safe_json_loads(row.get("metadata"))
    meta_json = {"_parse_error": meta_err} if meta_err else (meta_dict or {})
    meta_raw = _cell_text(row.get("metadata"))

    gallery_items = [
        (marked, "Marked / Original"),
        (original, "Original"),
        (sol, "Solution"),
        (mask, "Mask"),
        (cell_map, "Cell map"),
    ]
    gallery_items = [(img, cap) for (img, cap) in gallery_items if img is not None]

    status_md = f"**Sample** `{sid}`  \n**Index** `{index}` / `{len(df_filtered) - 1}`"
    return index, status_md, instruction, gallery_items, meta_json, meta_raw


# -------------------------
# Gradio callbacks
# -------------------------


def _no_parquet_view():
    """Viewer outputs shown while no parquet file is selected."""
    return 0, "No parquet selected.", "", [], {}, ""


def init_app():
    """Index the dataset repo; return ``(records, status_html)``."""
    try:
        recs = build_repo_index()
    except Exception as e:  # network/auth failures are reported in the UI, not raised
        return [], _badge(f"❌ Failed to index: {html.escape(str(e))}")
    info_html = _badge(
        f"✅ Indexed <b>{html.escape(DATASET_REPO_ID)}</b> · {len(recs)} parquet files"
    )
    return recs, info_html


def on_shape_split_change(records: List[Dict[str, str]], shape: str, split: str):
    """Refresh the parquet dropdown for the chosen shape/split and report the count."""
    choices = get_repo_paths(records, shape, split)
    value = choices[0] if choices else None
    tip_html = _badge(
        f"Found {len(choices)} parquet file(s) for "
        f"{html.escape(str(shape))} / {html.escape(str(split))}"
    )
    return gr.Dropdown(choices=choices, value=value), tip_html


def get_filtered_df(repo_path: str, size_str: str) -> Tuple[pd.DataFrame, str]:
    """Return the size-filtered frame for *repo_path* and its summary (cached)."""
    key = (repo_path, size_str)
    cached = _FILTERED_CACHE.get(key)
    if cached is None:
        df = download_and_load_df(repo_path)
        filtered = filter_df_by_maze_size(df, size_str)
        cached = (filtered, summarize_df(df, filtered_len=len(filtered)))
        _FILTERED_CACHE[key] = cached
    return cached


def on_select_parquet(repo_path: str, size_str: str):
    """Update the summary badge and reset the index slider for a new selection."""
    if not repo_path:
        return gr.update(value=_badge("No parquet selected")), gr.update(maximum=0, value=0)
    filtered, summary = get_filtered_df(repo_path, size_str)
    max_idx = max(0, len(filtered) - 1)
    return gr.update(value=_badge(html.escape(summary))), gr.update(maximum=max_idx, value=0)


def on_prev(repo_path: str, index: int, size_str: str):
    """Show the previous sample (clamped at 0)."""
    if not repo_path:
        return _no_parquet_view()
    filtered, _ = get_filtered_df(repo_path, size_str)
    return render_sample_view(filtered, max(0, int(index) - 1))


def on_next(repo_path: str, index: int, size_str: str):
    """Show the next sample (clamped at the last row)."""
    if not repo_path:
        return _no_parquet_view()
    filtered, _ = get_filtered_df(repo_path, size_str)
    return render_sample_view(filtered, min(len(filtered) - 1, int(index) + 1))


def on_show(repo_path: str, index: int, size_str: str):
    """Show the sample at *index*."""
    if not repo_path:
        return _no_parquet_view()
    filtered, _ = get_filtered_df(repo_path, size_str)
    return render_sample_view(filtered, index)


def _show_first(repo_path: str, size_str: str):
    """Show the first sample of the current selection (or the empty view)."""
    return on_show(repo_path, 0, size_str)


def on_random(repo_path: str, size_str: str):
    """Show a uniformly random sample from the filtered frame."""
    if not repo_path:
        return _no_parquet_view()
    filtered, _ = get_filtered_df(repo_path, size_str)
    if len(filtered) == 0:
        return render_sample_view(filtered, 0)
    return render_sample_view(filtered, random.randint(0, len(filtered) - 1))


def on_find_id(repo_path: str, query_id: str, size_str: str):
    """Jump to the sample matching *query_id*; show row 0 plus a warning if none matches."""
    if not repo_path:
        return _no_parquet_view()
    filtered, _ = get_filtered_df(repo_path, size_str)
    query = query_id.strip() if isinstance(query_id, str) else ""
    pos = find_index_by_id(filtered, query)
    if pos is None:
        out = list(render_sample_view(filtered, 0))
        out[1] = out[1] + f"  \n⚠️ id search `{query_id}` not found"
        return tuple(out)
    return render_sample_view(filtered, pos)


# -------------------------
# UI (styled)
# -------------------------

# Layout CSS. The inline (Chinese) comments label each section: system font,
# centered fixed-width page, compact top card, rounded inputs, uniform buttons,
# badge pills, index/button toolbar rows, and viewer spacing.
CSS = """
/* 使用系统默认字体 */
.gradio-container {
  font-family: system-ui, -apple-system, BlinkMacSystemFont, "Segoe UI", sans-serif !important;
}

/* 全局:页面居中 + 不要铺满 */
.gradio-container {
  max-width: 1200px !important;
  margin: 0 auto !important;
}

/* 顶部控制卡片:紧凑、没有大灰底空白 */
#topbar {
  padding: 12px 14px;
  border-radius: 16px;
  background: var(--block-background-fill);
  border: 1px solid var(--border-color-primary);
}
#topbar .gr-row { flex-wrap: wrap; gap: 10px; }
#topbar .gr-form { margin-bottom: 0 !important; }

/* 输入/下拉更紧凑 */
#topbar input, #topbar textarea, #topbar .wrap {
  border-radius: 12px !important;
}

/* 按钮统一,不要变成右侧巨大菜单 */
#topbar button {
  height: 42px !important;
  border-radius: 12px !important;
}

/* badges */
#badges { display: flex; gap: 10px; flex-wrap: wrap; align-items: center; }
.badge {
  padding: 6px 10px;
  border-radius: 999px;
  border: 1px solid var(--border-color-primary);
  background: var(--background-fill-secondary);
  font-size: 13px;
  line-height: 1.2;
}

/* Index 一行,按钮单独一行并向下留间距 */
#toolbar .gr-row { align-items: end; }
#toolbar-btns { margin-top: 12px; }
#toolbar-btns .gr-row { align-items: end; }

/* Gallery 更像 viewer */
#viewer { margin-top: 10px; }
"""

THEME = gr.themes.Soft(
    radius_size=gr.themes.sizes.radius_lg,
    text_size=gr.themes.sizes.text_md,
)

INTRO_MD = f"""
# Amaze

Dataset: https://huggingface.co/datasets/{DATASET_REPO_ID}

Amaze is a benchmark for the Editing-as-Reasoning (EAR) task. It features four
maze shapes: circle, hexagon, square, and triangle.

Each sample provides: an unmarked maze image (`original_img`), a maze image with
start and end points marked (`m_original_img`), a blue solution path image
(`sol_img`), a binary path mask (`mask_img`), a cell segmentation map
(`cell_map`), and metadata (JSON) describing the maze structure and difficulty.

The test set covers sizes from 3×3 to 16×16 (50 samples for each size), while the
training set mainly consists of 3×3 mazes (1024 samples) and the validation set
consists of 3×3 mazes (256 samples).

Browse samples by **shape / split / maze size**, then view images + metadata.
"""


def build_ui():
    """Build the Gradio Blocks app and wire its events."""
    with gr.Blocks(title="Amaze Viewer", theme=THEME, css=CSS) as demo:
        gr.Markdown(INTRO_MD)

        records_state = gr.State([])

        # Top control bar (compact card)
        with gr.Column(elem_id="topbar"):
            with gr.Row(elem_id="badges"):
                parquet_tip = gr.HTML(value="")
                summary_badge = gr.HTML(value=_badge("No parquet selected"))
                scan_info = gr.HTML(value=_badge("Indexing dataset repo…"))

            with gr.Row():
                shape_dd = gr.Dropdown(label="Shape", choices=SHAPES, value="circle", scale=1)
                split_dd = gr.Dropdown(label="Split", choices=SPLITS, value="test", scale=1)
                size_dd = gr.Dropdown(label="Maze size", choices=MAZE_SIZE_CHOICES, value="All", scale=1)
                parquet_dd = gr.Dropdown(label="Parquet", choices=[], value=None, scale=2)

            with gr.Row(elem_id="toolbar"):
                id_query = gr.Textbox(label="Find by id", placeholder="UUID or substring", scale=2)
                idx_slider = gr.Slider(label="Index", minimum=0, maximum=0, value=0, step=1, scale=2)

            with gr.Row(elem_id="toolbar-btns"):
                prev_btn = gr.Button("⬅ Prev", variant="secondary", scale=1)
                next_btn = gr.Button("Next ➡", variant="secondary", scale=1)
                random_btn = gr.Button("🎲 Random", variant="primary", scale=1)
                find_btn = gr.Button("🔎 Find", variant="secondary", scale=1)
                show_btn = gr.Button("Show", variant="secondary", scale=1)

        # Main viewer layout
        with gr.Row(elem_id="viewer"):
            with gr.Column(scale=3):
                status_md = gr.Markdown(elem_id="status")
                gallery = gr.Gallery(
                    label="Images",
                    columns=2,
                    height=520,
                    object_fit="contain",
                    preview=True,
                )
            with gr.Column(scale=2):
                instruction = gr.Textbox(label="Instruction", lines=6, interactive=False)
                with gr.Accordion("Metadata (parsed JSON)", open=True):
                    meta_json = gr.JSON()
                with gr.Accordion("Metadata (raw)", open=False):
                    meta_raw = gr.Textbox(lines=10, interactive=False)

        # ---- events ----
        view_outputs = [idx_slider, status_md, instruction, gallery, meta_json, meta_raw]
        select_outputs = [summary_badge, idx_slider]
        shape_split_inputs = [records_state, shape_dd, split_dd]
        selection_inputs = [parquet_dd, size_dd]
        nav_inputs = [parquet_dd, idx_slider, size_dd]

        # on_select_parquet / on_show already handle an empty parquet path.
        demo.load(
            fn=init_app,
            inputs=None,
            outputs=[records_state, scan_info],
        ).then(
            fn=on_shape_split_change,
            inputs=shape_split_inputs,
            outputs=[parquet_dd, parquet_tip],
        ).then(
            fn=on_select_parquet,
            inputs=selection_inputs,
            outputs=select_outputs,
        ).then(
            fn=_show_first,
            inputs=selection_inputs,
            outputs=view_outputs,
        )

        for dropdown in (shape_dd, split_dd):
            dropdown.change(
                fn=on_shape_split_change,
                inputs=shape_split_inputs,
                outputs=[parquet_dd, parquet_tip],
            )

        for dropdown in (parquet_dd, size_dd):
            dropdown.change(
                fn=on_select_parquet,
                inputs=selection_inputs,
                outputs=select_outputs,
            ).then(
                fn=_show_first,
                inputs=selection_inputs,
                outputs=view_outputs,
            )

        show_btn.click(fn=on_show, inputs=nav_inputs, outputs=view_outputs)
        idx_slider.release(fn=on_show, inputs=nav_inputs, outputs=view_outputs)
        prev_btn.click(fn=on_prev, inputs=nav_inputs, outputs=view_outputs)
        next_btn.click(fn=on_next, inputs=nav_inputs, outputs=view_outputs)
        random_btn.click(fn=on_random, inputs=selection_inputs, outputs=view_outputs)
        find_btn.click(
            fn=on_find_id,
            inputs=[parquet_dd, id_query, size_dd],
            outputs=view_outputs,
        )

    return demo


if __name__ == "__main__":
    demo = build_ui()
    demo.launch()