| |
| """HF Space entry point for 3DReflecNet dataset preview. |
| |
| Loads only data/preview/preview.parquet so the Space exposes the configured |
| preview instance subset instead of the full dataset metadata. |
| """ |
| from __future__ import annotations |
|
|
| import atexit |
| import io |
| import os |
| import shutil |
| import tempfile |
| from pathlib import Path |
| from typing import Any |
|
|
| import gradio as gr |
| import pandas as pd |
| from datasets import load_dataset |
| from huggingface_hub import hf_hub_download |
| from PIL import Image |
|
|
| from utils import ( |
| BOOL_FILTER_CHOICES, |
| FILTER_ALL, |
| filter_dataframe_advanced, |
| get_distinct_text_choices, |
| logger, |
| require_bool_columns, |
| require_columns, |
| require_text_columns, |
| setup_logging, |
| ) |
|
|
# Dataset repository on the Hugging Face Hub; overridable via the environment.
DATASET_REPO = os.getenv("DATASET_REPO", "3DReflecNet/3DReflecNet")
# Optional Hub access token; stays None when the variable is unset.
HF_TOKEN = os.getenv("HF_TOKEN")
# Upper bound on the number of instance rows taken with ``head()`` for display.
MAX_RESULTS = 300
# Columns validated as boolean flags on the preview data.
BOOL_COLUMNS = ["hasGlass", "isGenerated", "transparent", "near_light"]


# Scratch directory for GLB copies; removed (errors ignored) at interpreter exit.
_GLB_CACHE_DIR = Path(tempfile.mkdtemp(prefix="glb_cache_"))
atexit.register(shutil.rmtree, str(_GLB_CACHE_DIR), ignore_errors=True)
|
|
|
|
| |
| |
| |
|
|
def load_preview_dataframe() -> pd.DataFrame:
    """Fetch ``data/preview/preview.parquet`` from the Hub and return it validated.

    The Parquet file is loaded non-streaming, narrowed to the expected columns
    and materialised as a DataFrame. The frame is then checked for required
    columns, text columns, boolean columns, an integer ``frame_id`` without
    nulls, and non-empty binary payloads in every image column.

    Returns:
        The preview rows as a pandas DataFrame.

    Raises:
        TypeError: If ``frame_id`` has nulls or is not an integer dtype, or if
            an image column holds a value that is not non-empty bytes.
            (The ``utils`` validators are defined elsewhere and presumably
            raise on their own failures — verify against that module.)
    """
    expected_columns = [
        "instance_id", "split", "frame_id", "rgb", "mask",
        "depth_preview", "normal_preview",
        "main_category", "sub_category", "model_name",
        "material_name", "env_name", "glb_path",
        "hasGlass", "isGenerated", "transparent", "near_light",
    ]
    text_columns = [
        "instance_id", "split", "main_category", "sub_category",
        "model_name", "material_name", "env_name", "glb_path",
    ]
    image_columns = ("rgb", "mask", "depth_preview", "normal_preview")
    source = "preview parquet"

    def is_bad_payload(value: Any) -> bool:
        # Anything that is not a non-empty bytes/bytearray object is invalid.
        return not isinstance(value, (bytes, bytearray)) or len(value) == 0

    dataset = load_dataset(
        DATASET_REPO,
        data_files="data/preview/preview.parquet",
        split="train",
        streaming=False,
        token=HF_TOKEN,
    )
    dataset = dataset.select_columns(expected_columns)
    df = pd.DataFrame(list(dataset))

    require_columns(df, expected_columns, source)
    require_text_columns(df, text_columns, source)
    require_bool_columns(df, BOOL_COLUMNS, source)

    frame_ids = df["frame_id"]
    if frame_ids.isna().any() or not pd.api.types.is_integer_dtype(frame_ids):
        raise TypeError(
            f"Expected integer dtype for column 'frame_id' in preview parquet, got {frame_ids.dtype}."
        )

    for col in image_columns:
        if df[col].map(is_bad_payload).any():
            raise TypeError(f"Expected non-empty binary values for column {col!r} in preview parquet.")
    return df
|
|
|
|
def decode_image_bytes(img_bytes: bytes | bytearray, context: str) -> Image.Image:
    """Decode an encoded image payload into a PIL image.

    Args:
        img_bytes: Non-empty ``bytes`` or ``bytearray`` holding the encoded image.
        context: Short description of the payload, used in the error message.

    Returns:
        A copy of the opened image, taken before the source image is closed.

    Raises:
        TypeError: If ``img_bytes`` is not bytes-like or is empty.
    """
    is_binary = isinstance(img_bytes, (bytes, bytearray))
    if not (is_binary and img_bytes):
        raise TypeError(f"Expected non-empty image bytes for {context}.")

    buffer = io.BytesIO(img_bytes)
    with Image.open(buffer) as opened:
        decoded = opened.copy()
    return decoded
|
|
|
|
def build_preview_instance_dataframe(preview_df: pd.DataFrame) -> pd.DataFrame:
    """Collapse per-frame preview rows into one metadata row per instance.

    Rows are grouped by ``instance_id`` (sorted). Every instance-level column
    must hold exactly one distinct value within a group.

    Args:
        preview_df: Frame-level preview rows containing the instance columns.

    Returns:
        A DataFrame with one row per instance and the instance-level columns.

    Raises:
        ValueError: If a column has more than one distinct value within an
            instance, or if any ``glb_path`` is blank after stripping.
    """
    instance_cols = [
        "instance_id", "main_category", "sub_category", "model_name",
        "material_name", "env_name", "hasGlass", "isGenerated",
        "transparent", "near_light", "glb_path",
    ]
    text_cols = [
        "instance_id", "main_category", "sub_category",
        "model_name", "material_name", "env_name", "glb_path",
    ]
    require_columns(preview_df, instance_cols, "preview parquet")

    def single_value(group: pd.DataFrame, col: str, instance_id: Any) -> Any:
        # The column must agree across all frames of the instance.
        distinct = group[col].drop_duplicates().tolist()
        if len(distinct) != 1:
            raise ValueError(f"Inconsistent {col!r} values for preview instance {instance_id!r}.")
        return distinct[0]

    rows: list[dict[str, Any]] = [
        {col: single_value(group, col, instance_id) for col in instance_cols}
        for instance_id, group in preview_df.groupby("instance_id", sort=True)
    ]

    df = pd.DataFrame(rows, columns=instance_cols)
    require_text_columns(df, text_cols, "preview instance dataframe")
    require_bool_columns(df, BOOL_COLUMNS, "preview instance dataframe")

    def is_blank(path: str) -> bool:
        return not path.strip()

    if df["glb_path"].map(is_blank).any():
        raise ValueError("Preview instance dataframe contains empty GLB paths.")
    return df
|
|
|
|
def train_frame_rows(preview_df: pd.DataFrame, instance_id: str, max_frames: int | None = None) -> list[dict[str, Any]]:
    """Return the train-split rows of one instance, ordered by ``frame_id``.

    Args:
        preview_df: Frame-level preview rows with ``instance_id``, ``split``
            and ``frame_id`` columns.
        instance_id: Instance to select; compared as a string.
        max_frames: When given, keep only the first ``max_frames`` rows after
            sorting.

    Returns:
        The selected rows as a list of record dictionaries.

    Raises:
        ValueError: If the instance has no rows in the train split.
    """
    wanted_id = str(instance_id)
    is_instance = preview_df["instance_id"].astype(str) == wanted_id
    is_train = preview_df["split"].astype(str) == "train"

    matches = preview_df[is_instance & is_train]
    if matches.empty:
        raise ValueError(f"Preview instance {instance_id!r} has no train split rows.")

    ordered = matches.sort_values("frame_id")
    if max_frames is not None:
        ordered = ordered.head(max_frames)
    return ordered.to_dict(orient="records")
|
|
|
|
def get_instance_thumbnail(preview_df: pd.DataFrame, instance_id: str) -> Image.Image:
    """Decode the RGB image of the instance's first train frame as its thumbnail.

    Args:
        preview_df: Frame-level preview rows.
        instance_id: Instance whose thumbnail is wanted.

    Returns:
        The decoded RGB image of the lowest-``frame_id`` train row.
    """
    first_rows = train_frame_rows(preview_df, instance_id, max_frames=1)
    rgb_payload = first_rows[0]["rgb"]
    context = f"{instance_id} thumbnail RGB"
    return decode_image_bytes(rgb_payload, context)
|
|
|
|
def instance_caption(row: dict[str, Any]) -> str:
    """Return the gallery caption ``model | material | environment`` for a row."""
    model = row["model_name"]
    material = row["material_name"]
    environment = row["env_name"]
    return f"{model} | {material} | {environment}"
|
|
|
|
def build_instance_gallery_items(
    rows: list[dict[str, Any]],
    preview_df: pd.DataFrame,
) -> list[tuple[Image.Image, str]]:
    """Pair each instance row with its thumbnail and caption for the gallery.

    Args:
        rows: Instance records, each carrying ``instance_id`` and the caption
            fields.
        preview_df: Frame-level preview rows used to look up thumbnails.

    Returns:
        ``(thumbnail, caption)`` tuples in the same order as ``rows``.
    """
    items: list[tuple[Image.Image, str]] = []
    for row in rows:
        thumbnail = get_instance_thumbnail(preview_df, row["instance_id"])
        caption = instance_caption(row)
        items.append((thumbnail, caption))
    return items
|
|
|
|
def load_instance_frames(
    preview_df: pd.DataFrame, instance_id: str, max_frames: int = 50,
) -> list[dict[str, Any]]:
    """Decode the train preview images of one instance.

    Args:
        preview_df: Frame-level preview rows.
        instance_id: Instance whose frames are loaded.
        max_frames: Maximum number of frames to decode, lowest ``frame_id`` first.

    Returns:
        One dictionary per frame with ``frame_id`` plus decoded ``rgb``,
        ``mask``, ``depth_preview`` and ``normal_preview`` images.
    """
    image_keys = ("rgb", "mask", "depth_preview", "normal_preview")
    frames: list[dict[str, Any]] = []
    for record in train_frame_rows(preview_df, instance_id, max_frames=max_frames):
        frame_id = int(record["frame_id"])
        decoded = {
            key: decode_image_bytes(record[key], f"{key} frame {frame_id}")
            for key in image_keys
        }
        frames.append({"frame_id": frame_id, **decoded})
    return frames
|
|
|
|
def render_frame_images(frame_items: list[dict[str, Any]], frame_index: float) -> list[Any | None]:
    """Build the four image-panel updates for one frame.

    Args:
        frame_items: Decoded frames as produced per instance; may be empty.
        frame_index: 1-based frame position; rounded and clamped to the
            available range.

    Returns:
        Four ``gr.update`` objects in RGB, Mask, Depth, Normal order. With no
        frames, each clears its panel and restores the plain label.
    """
    # (frame key, panel title) in the order the panels are wired up.
    panels = (
        ("rgb", "RGB"),
        ("mask", "Mask"),
        ("depth_preview", "Depth"),
        ("normal_preview", "Normal"),
    )
    if not frame_items:
        return [gr.update(value=None, label=title) for _, title in panels]

    last_position = len(frame_items) - 1
    position = int(round(frame_index)) - 1
    position = min(max(position, 0), last_position)

    frame = frame_items[position]
    frame_id = int(frame["frame_id"])
    suffix = f"frame_{frame_id:05d}"
    return [gr.update(value=frame[key], label=f"{title} {suffix}") for key, title in panels]
|
|
|
|
| |
| |
| |
|
|
def download_glb(glb_path: str) -> str:
    """Return a local copy of a pre-converted GLB file from the dataset repo.

    The file is fetched with ``hf_hub_download`` and copied into the
    process-wide GLB cache directory; later calls for the same path reuse the
    copy. Cache entries are keyed by a digest of the full repository path, so
    two files that share a basename in different folders (``a/model.glb`` and
    ``b/model.glb``) never shadow each other.

    Args:
        glb_path: Path of the GLB file inside the dataset repository.

    Returns:
        Filesystem path of the cached GLB copy.

    Raises:
        ValueError: If ``glb_path`` is empty.
    """
    # Local import keeps this block self-contained.
    import hashlib

    if not glb_path:
        raise ValueError("GLB path is required.")

    # Key by the whole repo path, not just the basename, to avoid collisions;
    # the original filename is kept so the extension is preserved.
    digest = hashlib.sha256(glb_path.encode("utf-8")).hexdigest()[:16]
    local = _GLB_CACHE_DIR / digest / Path(glb_path).name
    if local.exists():
        return str(local)

    downloaded = hf_hub_download(
        repo_id=DATASET_REPO,
        filename=glb_path,
        repo_type="dataset",
        token=HF_TOKEN,
    )
    local.parent.mkdir(parents=True, exist_ok=True)
    shutil.copy2(downloaded, str(local))
    logger.info("GLB ready: %s", local)
    return str(local)
|
|
|
|
| |
| |
| |
|
|
def build_app(instance_df: pd.DataFrame, preview_df: pd.DataFrame) -> gr.Blocks:
    """Assemble the Gradio Blocks UI for browsing preview instances.

    The page has seven filter dropdowns, a thumbnail gallery, a metadata/3D
    viewer row, four per-frame image panels and a frame slider. Changing a
    filter rebuilds the gallery and clears the detail widgets. Selecting a
    thumbnail loads that instance's frames and GLB model. Moving the slider
    switches the displayed frame.

    Args:
        instance_df: One row per preview instance; supplies the dropdown
            choices and is the table the filters are applied to.
        preview_df: Frame-level preview rows used for thumbnails and frames.

    Returns:
        The constructed ``gr.Blocks`` app, not yet launched.
    """
    model_name_choices = get_distinct_text_choices(instance_df, "model_name")
    material_name_choices = get_distinct_text_choices(instance_df, "material_name")
    env_name_choices = get_distinct_text_choices(instance_df, "env_name")

    def format_summary(matched: int, shown: int) -> str:
        """Render the match-count line shown above the gallery."""
        return f"Matched **{matched}** preview instances, showing **{shown}**."

    def cleared_detail_outputs() -> list[Any]:
        """Values that reset metadata, 3D viewer, images, slider and frame state.

        The image panels are reset through ``render_frame_images`` so their
        labels return to the plain titles instead of keeping a stale
        ``frame_XXXXX`` suffix on an empty panel.
        """
        slider_empty = gr.update(minimum=1, maximum=1, step=1, value=1, interactive=False)
        return [{}, None, *render_frame_images([], 1), slider_empty, []]

    def filtered_instance_rows(
        model_name: str,
        material_name: str,
        env_name: str,
        has_glass: str,
        is_generated: str,
        transparent: str,
        near_light: str,
    ) -> tuple[pd.DataFrame, list[dict[str, Any]]]:
        """Apply the filters; return the full match and the capped display rows."""
        filtered = filter_dataframe_advanced(
            instance_df,
            model_name=model_name,
            material_name=material_name,
            env_name=env_name,
            has_glass=has_glass,
            is_generated=is_generated,
            transparent=transparent,
            near_light=near_light,
        )
        rows = filtered.head(MAX_RESULTS).to_dict(orient="records")
        return filtered, rows

    def filter_gallery(
        model_name: str,
        material_name: str,
        env_name: str,
        has_glass: str,
        is_generated: str,
        transparent: str,
        near_light: str,
    ):
        """Rebuild summary and gallery for the new filters and clear the details."""
        filtered, rows = filtered_instance_rows(
            model_name=model_name,
            material_name=material_name,
            env_name=env_name,
            has_glass=has_glass,
            is_generated=is_generated,
            transparent=transparent,
            near_light=near_light,
        )
        summary_text = format_summary(len(filtered), len(rows))
        gallery_items = build_instance_gallery_items(rows, preview_df)
        return (summary_text, gallery_items, rows, *cleared_detail_outputs())

    def on_instance_select(rows: list[dict[str, Any]], evt: gr.SelectData):
        """Load metadata, GLB model and frames for the clicked gallery item."""
        if not rows:
            return tuple(cleared_detail_outputs())
        idx = evt.index[0] if isinstance(evt.index, tuple) else evt.index
        if not isinstance(idx, int) or idx < 0 or idx >= len(rows):
            raise IndexError(f"Selected gallery index is out of range: {evt.index!r}")

        row = rows[idx]
        instance_id = row["instance_id"]
        if not isinstance(instance_id, str) or not instance_id.strip():
            raise ValueError(f"Selected instance row has invalid instance_id: {row!r}")
        logger.info("Loading images for instance: %s", instance_id)
        frame_items = load_instance_frames(preview_df, instance_id, max_frames=50)
        slider_ready = gr.update(
            minimum=1,
            maximum=len(frame_items),
            step=1,
            value=1,
            interactive=True,
        )
        return (
            row,
            download_glb(row["glb_path"]),
            *render_frame_images(frame_items, 1),
            slider_ready,
            frame_items,
        )

    def on_frame_change(frame_idx: float, frame_items: list[dict[str, Any]]):
        """Show the frame at the slider position."""
        return render_frame_images(frame_items, frame_idx)

    # Initial, unfiltered view shown before any filter is touched.
    initial_rows = instance_df.head(MAX_RESULTS).to_dict(orient="records")
    initial_gallery = build_instance_gallery_items(initial_rows, preview_df)
    initial_summary = format_summary(len(instance_df), len(initial_rows))

    with gr.Blocks(title="3DReflecNet Dataset Explorer") as demo:
        gr.Markdown("# 3DReflecNet Dataset Explorer")
        gr.Markdown(
            "Browse the configured preview subset. Select an RGB thumbnail to inspect the instance."
        )

        with gr.Row():
            model_name = gr.Dropdown(label="model_name", choices=model_name_choices, value=FILTER_ALL)
            material_name = gr.Dropdown(label="material_name", choices=material_name_choices, value=FILTER_ALL)
            env_name = gr.Dropdown(label="env_name", choices=env_name_choices, value=FILTER_ALL)
        with gr.Row():
            has_glass = gr.Dropdown(label="hasGlass", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)
            is_generated = gr.Dropdown(label="isGenerated", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)
            transparent = gr.Dropdown(label="transparent", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)
            near_light = gr.Dropdown(label="near_light", choices=BOOL_FILTER_CHOICES, value=FILTER_ALL)

        summary = gr.Markdown(initial_summary)
        instance_gallery = gr.Gallery(
            label="Preview Instances",
            value=initial_gallery,
            columns=5,
            object_fit="contain",
            height="auto",
        )

        with gr.Row():
            instance_meta = gr.JSON(label="Instance Metadata")
            model_viewer = gr.Model3D(
                label="3D Preview (GLB)",
                clear_color=(0.35, 0.35, 0.38, 1.0),
                camera_position=(35, 70, 3.5),
            )

        with gr.Row():
            rgb_image = gr.Image(label="RGB", height=360, interactive=False, scale=1, min_width=160)
            mask_image = gr.Image(label="Mask", height=360, interactive=False, scale=1, min_width=160)
            depth_image = gr.Image(label="Depth", height=360, interactive=False, scale=1, min_width=160)
            normal_image = gr.Image(label="Normal", height=360, interactive=False, scale=1, min_width=160)
        frame_slider = gr.Slider(
            label="Frame",
            minimum=1,
            maximum=1,
            step=1,
            value=1,
            interactive=False,
        )

        # Per-session state: rows behind the gallery and the decoded frames of
        # the currently selected instance.
        instance_state = gr.State(initial_rows)
        frame_state = gr.State([])

        filter_inputs = [
            model_name,
            material_name,
            env_name,
            has_glass,
            is_generated,
            transparent,
            near_light,
        ]
        # Order must match cleared_detail_outputs() / on_instance_select returns.
        detail_outputs = [
            instance_meta,
            model_viewer,
            rgb_image,
            mask_image,
            depth_image,
            normal_image,
            frame_slider,
            frame_state,
        ]
        filter_outputs = [summary, instance_gallery, instance_state, *detail_outputs]

        for filter_component in filter_inputs:
            filter_component.change(
                fn=filter_gallery,
                inputs=filter_inputs,
                outputs=filter_outputs,
            )
        instance_gallery.select(
            fn=on_instance_select,
            inputs=[instance_state],
            outputs=detail_outputs,
        )
        frame_slider.change(
            fn=on_frame_change,
            inputs=[frame_slider, frame_state],
            outputs=[rgb_image, mask_image, depth_image, normal_image],
        )

    return demo
|
|
|
|
def main() -> None:
    """Load the preview data, build the UI and launch the app."""
    setup_logging()

    token_is_set = HF_TOKEN is not None
    token_length = len(HF_TOKEN or "")
    logger.info("DATASET_REPO = %r", DATASET_REPO)
    logger.info("HF_TOKEN set = %s, length = %d", token_is_set, token_length)

    logger.info("Loading preview subset from Hugging Face Hub...")
    preview_df = load_preview_dataframe()
    instance_df = build_preview_instance_dataframe(preview_df)
    logger.info(
        "Loaded %d preview rows for %d preview instance(s).",
        len(preview_df),
        len(instance_df),
    )

    build_app(instance_df, preview_df).launch()


if __name__ == "__main__":
    main()
|
|