Spaces:
Running
Running
| import os | |
| import glob | |
| import json | |
| from typing import Dict, Literal, Tuple, List, Optional | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
# Directory scanned for per-model result files (one ``*.json`` per model).
RESULTS_DIR = "./worldlens-results"

# Metrics for which a smaller value ranks higher.
METRICS_MIN_BETTER = [
    "Depth Discrepancy",
    "Perceptual Discrepancy",
    "Photometric Error",
    "Geometric Discrepancy",
    "Novel-View Discrepancy",
    "Displacement Error",
]

# Metrics for which a larger value ranks higher.
METRICS_MAX_BETTER = [
    "Subject Fidelity",
    "Subject Coherence",
    "Subject Consistency",
    "Temporal Consistency",
    "Semantic Consistency",
    "View Consistency",
    "Novel-View Quality",
    "Open-Loop Adherence",
    "Route Completion",
    "Closed-Loop Adherence",
    "Map Segmentation",
    "3D Object Detection",
    "3D Object Tracking",
    "Occupancy Prediction",
]

# Metric name -> preferred direction; "min" entries come first, then "max".
METRIC_BETTER: Dict[str, Literal["min", "max"]] = {
    **dict.fromkeys(METRICS_MIN_BETTER, "min"),
    **dict.fromkeys(METRICS_MAX_BETTER, "max"),
}

# Alphabetically sorted, de-duplicated list of every known metric name.
METRIC_CHOICES: List[str] = sorted({*METRICS_MIN_BETTER, *METRICS_MAX_BETTER})

# Metric used for the initial ranking; falls back to the first known metric.
if "Subject Fidelity" in METRIC_CHOICES:
    DEFAULT_METRIC = "Subject Fidelity"
else:
    DEFAULT_METRIC = METRIC_CHOICES[0]

# Results table shared by the callbacks; stays None until the first load.
df_all: Optional[pd.DataFrame] = None
def load_results(results_dir: Optional[str] = None) -> pd.DataFrame:
    """Load every ``*.json`` result file in a directory into one table.

    Each file contributes one row. The model name is the file name without
    its extension; ``venue`` and ``date`` are read from the top level of the
    JSON object (defaulting to ``""``). ``"Metrics"`` is expected to map a
    category name to a ``{metric name: value}`` dict; the categories are
    flattened so each metric becomes its own column.

    Args:
        results_dir: Directory to scan. Defaults to the module-level
            ``RESULTS_DIR`` when omitted.

    Returns:
        A DataFrame whose first columns are ``Model``, ``venue`` and
        ``date``, followed by one numeric column per metric (NaN where a
        model does not report the metric or the value is not numeric).
        An empty DataFrame is returned when no usable file is found.
    """
    # Local import so this function does not depend on a new file-level import.
    import warnings

    if results_dir is None:
        results_dir = RESULTS_DIR

    meta_cols = ["Model", "venue", "date"]
    rows = []
    for path in sorted(glob.glob(os.path.join(results_dir, "*.json"))):
        try:
            # Explicit encoding: the platform default is not always UTF-8.
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, ValueError) as exc:
            # ValueError covers both JSONDecodeError and UnicodeDecodeError.
            # One broken upload must not take the whole leaderboard down.
            warnings.warn(f"Skipping unreadable result file {path!r}: {exc}")
            continue
        if not isinstance(data, dict):
            warnings.warn(f"Skipping {path!r}: top-level JSON value is not an object")
            continue

        row = {
            "Model": os.path.splitext(os.path.basename(path))[0],
            "venue": data.get("venue", ""),
            "date": data.get("date", ""),
        }
        metrics = data.get("Metrics", {})
        if isinstance(metrics, dict):
            for metric_dict in metrics.values():
                if not isinstance(metric_dict, dict):
                    continue
                for metric_name, value in metric_dict.items():
                    # Never let a metric overwrite the metadata columns.
                    if metric_name not in meta_cols:
                        row[metric_name] = value
        rows.append(row)

    if not rows:
        return pd.DataFrame()

    df = pd.DataFrame(rows)
    metric_cols = [c for c in df.columns if c not in meta_cols]
    if metric_cols:
        # Sorting, rounding and label formatting downstream need numbers;
        # anything that cannot be parsed becomes NaN.
        df[metric_cols] = df[metric_cols].apply(pd.to_numeric, errors="coerce")
    return df[meta_cols + metric_cols]
def get_venue_choices(df: pd.DataFrame) -> List[str]:
    """Return the venue filter options for ``df``.

    The list always starts with ``"All"``. When ``df`` has a ``venue``
    column, its distinct non-null, non-empty values follow in sorted order.
    """
    choices = ["All"]
    if "venue" in df.columns:
        distinct = df["venue"].dropna().unique()
        choices.extend(sorted(name for name in distinct if name != ""))
    return choices
def _message_figure(message: str) -> plt.Figure:
    """Build a small axis-free figure whose only content is ``message``."""
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.text(0.5, 0.5, message, ha="center", va="center")
    ax.axis("off")
    # pyplot keeps a reference to every figure it creates until it is closed.
    # In a long-running server that is a leak, so release it here; the Figure
    # object itself is still returned to the caller.
    plt.close(fig)
    return fig


def update_leaderboard(
    metric: str,
    top_k: int,
    model_filter: str,
    venue_filter: str,
    sort_mode: str,
    selected_metrics: Optional[List[str]],
) -> Tuple[pd.DataFrame, plt.Figure]:
    """Build the leaderboard table and bar chart for the current controls.

    Reads the module-level ``df_all`` table, applies the filters, ranks the
    remaining models by ``metric`` and keeps the first ``top_k`` rows.

    Args:
        metric: Column used for ranking and plotted in the bar chart.
        top_k: Maximum number of rows to keep (coerced to an int >= 1).
        model_filter: Case-insensitive literal substring that model names
            must contain; empty means no filtering.
        venue_filter: Exact venue to keep, or ``"All"``/empty for every venue.
        sort_mode: ``"Auto"`` sorts by the metric's preferred direction from
            ``METRIC_BETTER`` (unknown metrics are treated as "max");
            ``"Ascending (small → large)"`` forces ascending order; any other
            value sorts descending.
        selected_metrics: Extra metric columns to show in the table; names
            missing from the data are ignored. ``None`` means none.

    Returns:
        ``(table, figure)``. The table holds ``Model``, ``venue``, ``date``,
        the selected metrics and the ranking metric, rounded to 3 decimals.
        When there is nothing to rank, the table is empty and the figure
        carries an explanatory message instead of bars.
    """
    if df_all is None or df_all.empty:
        return pd.DataFrame(), _message_figure(f"No results found in {RESULTS_DIR}")

    df = df_all.copy()
    if model_filter:
        # na=False keeps the mask strictly boolean even if a name is missing.
        name_hit = df["Model"].str.contains(
            model_filter, case=False, regex=False, na=False
        )
        df = df[name_hit]
    if venue_filter and venue_filter != "All":
        df = df[df["venue"] == venue_filter]

    if metric not in df.columns:
        return pd.DataFrame(), _message_figure(
            f"Metric '{metric}' not found in current data."
        )

    # Models that do not report the ranking metric cannot be ranked; leaving
    # them in would let them occupy Top-K slots and be labelled "nan".
    df = df.dropna(subset=[metric])
    if df.empty:
        return pd.DataFrame(), _message_figure(
            f"No models with '{metric}' match the current filters."
        )

    better = METRIC_BETTER.get(metric, "max")
    if sort_mode == "Auto":
        ascending = better == "min"
    else:
        ascending = sort_mode == "Ascending (small → large)"

    # UI sliders may deliver the value as a float; DataFrame.head needs an int.
    top_k = max(1, int(top_k))
    df_top = df.sort_values(metric, ascending=ascending).head(top_k).copy()

    cols = ["Model", "venue", "date"]
    for m in selected_metrics or []:
        if m in df_top.columns and m not in cols:
            cols.append(m)
    if metric not in cols:
        cols.append(metric)
    table_df = df_top[cols].round(3)

    # =========================
    # Dark-theme leaderboard plot
    # =========================
    bg_color = "#0e1117"     # page background (Hugging Face dark theme)
    panel_color = "#161b22"  # axes/panel background
    bar_color = "#4cc9f0"    # primary colour (cyan-blue)
    grid_color = "#30363d"
    text_color = "#c9d1d9"

    # Grow the figure with the number of bars so large Top-K values stay
    # readable; small selections keep the original 4.5 in height.
    fig_height = max(4.5, 0.3 * len(df_top) + 1.0)
    fig, ax = plt.subplots(figsize=(10, fig_height))
    fig.patch.set_facecolor(bg_color)
    ax.set_facecolor(panel_color)

    values = df_top[metric].to_numpy()
    models = df_top["Model"].to_numpy()
    bars = ax.barh(models, values, color=bar_color, height=0.6)
    # barh draws the first row at the bottom; flipping the axis for an
    # ascending sort puts the smallest value at the top of the chart.
    if ascending:
        ax.invert_yaxis()

    ax.set_xlabel(metric, color=text_color, fontsize=11, labelpad=6)
    ax.set_title(
        f"Leaderboard · {metric}",
        fontsize=13,
        color=text_color,
        pad=10,
        fontweight="bold",
    )
    ax.xaxis.grid(True, linestyle="--", linewidth=0.6, color=grid_color, alpha=0.7)
    ax.yaxis.grid(False)
    for spine in ["top", "right", "left"]:
        ax.spines[spine].set_visible(False)
    ax.spines["bottom"].set_color(grid_color)
    ax.tick_params(axis="x", colors=text_color, labelsize=10)
    ax.tick_params(axis="y", colors=text_color, labelsize=10)

    # Numeric label just past the end of each bar.
    for bar, value in zip(bars, values):
        ax.text(
            bar.get_width() * 1.01,
            bar.get_y() + bar.get_height() / 2,
            f"{value:.2f}",
            va="center",
            ha="left",
            fontsize=9.5,
            color=text_color,
        )

    # Lay out this figure explicitly rather than pyplot's global "current"
    # figure, then release pyplot's reference so repeated calls do not
    # accumulate open figures.
    fig.tight_layout()
    plt.close(fig)
    return table_df, fig
def reload_data():
    """Re-read the result files and rebuild the default leaderboard view.

    Replaces the module-level ``df_all`` with the output of
    ``load_results()``.

    Returns:
        A 4-tuple ``(status message, venue dropdown update, table, figure)``.
        When no results are loaded, the table is empty, the venue dropdown is
        reset to ``["All"]`` and the figure shows the status message.
        Otherwise the table and figure come from ``update_leaderboard``
        called with ``DEFAULT_METRIC``, top 10, no filters and "Auto" sort.
    """
    global df_all
    df_all = load_results()

    if df_all is None or df_all.empty:
        # Report the directory that was actually scanned.
        msg = f"No JSON files found in {RESULTS_DIR}. Please upload some results."
        dummy_fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, msg, ha="center", va="center")
        ax.axis("off")
        # Release pyplot's reference so repeated reloads do not pile up
        # open figures; the Figure object is still returned below.
        plt.close(dummy_fig)
        venue_update = gr.update(choices=["All"], value="All")
        return msg, venue_update, pd.DataFrame(), dummy_fig

    venue_choices = get_venue_choices(df_all)
    msg = f"Loaded {len(df_all)} models from {RESULTS_DIR}"

    # Table columns shown by default, limited to metrics that are known.
    default_selected = [
        m
        for m in ("Subject Fidelity", "Temporal Consistency", "Map Segmentation")
        if m in METRIC_CHOICES
    ]
    table_df, fig = update_leaderboard(
        metric=DEFAULT_METRIC,
        top_k=10,
        model_filter="",
        venue_filter="All",
        sort_mode="Auto",
        selected_metrics=default_selected,
    )
    venue_update = gr.update(
        choices=venue_choices,
        value="All",
        interactive=True,
    )
    return msg, venue_update, table_df, fig
# Page stylesheet: centres the element tagged with id "title".
_PAGE_CSS = """
#title {
text-align: center;
}
"""

# Options offered by the sort-mode radio; the labels are compared verbatim
# inside update_leaderboard.
_SORT_MODE_CHOICES = [
    "Auto",
    "Ascending (small → large)",
    "Descending (large → small)",
]

# Metric columns ticked by default in the table-column selector.
_DEFAULT_TABLE_METRICS = ["Subject Fidelity", "Temporal Consistency", "Map Segmentation"]

with gr.Blocks(css=_PAGE_CSS) as demo:
    gr.Markdown(
        """
# 🌍 WorldLens Leaderboard
""",
        elem_id="title",
    )
    status_box = gr.Markdown("Loading results...", elem_id="status")

    # --- Ranking controls -------------------------------------------------
    with gr.Row():
        # Choices are fixed up front to avoid compatibility problems with
        # updating them dynamically.
        metric_dropdown = gr.Dropdown(
            label="Metric (for ranking)",
            choices=METRIC_CHOICES,
            value=DEFAULT_METRIC,
            interactive=True,
        )
        sort_mode_radio = gr.Radio(
            label="Sort mode",
            choices=_SORT_MODE_CHOICES,
            value="Auto",
            interactive=True,
        )
        topk_slider = gr.Slider(
            label="Top-K",
            minimum=3,
            maximum=50,
            value=10,
            step=1,
            interactive=True,
        )

    # Added: the set of metrics displayed as columns in the table.
    metrics_select = gr.CheckboxGroup(
        label="Metrics to show in table",
        choices=METRIC_CHOICES,
        value=_DEFAULT_TABLE_METRICS,
        interactive=True,
    )

    # --- Row filters ------------------------------------------------------
    with gr.Row():
        model_filter_box = gr.Textbox(
            label="Filter by model name",
            placeholder="magic, dream, ...",
            interactive=True,
        )
        # Starts with only "All"; reload_data fills in the real venues.
        venue_dropdown = gr.Dropdown(
            label="Filter by venue",
            choices=["All"],
            value="All",
            interactive=True,
        )

    # --- Actions ----------------------------------------------------------
    with gr.Row():
        reload_button = gr.Button("🔄 Reload JSONs", variant="secondary")
        update_button = gr.Button("✅ Update leaderboard", variant="primary")

    # --- Outputs ----------------------------------------------------------
    leaderboard_table = gr.DataFrame(label="Leaderboard", interactive=False)
    leaderboard_plot = gr.Plot(label="Metric comparison", format="png")

    # Components refreshed by reload_data, in the order of its return tuple.
    _reload_outputs = [status_box, venue_dropdown, leaderboard_table, leaderboard_plot]
    # Components fed to update_leaderboard, in the order of its parameters.
    _update_inputs = [
        metric_dropdown,
        topk_slider,
        model_filter_box,
        venue_dropdown,
        sort_mode_radio,
        metrics_select,
    ]

    reload_button.click(fn=reload_data, inputs=[], outputs=_reload_outputs)
    update_button.click(
        fn=update_leaderboard,
        inputs=_update_inputs,
        outputs=[leaderboard_table, leaderboard_plot],
    )
    # Populate the page once when it is first opened.
    demo.load(fn=reload_data, inputs=[], outputs=_reload_outputs)

if __name__ == "__main__":
    demo.launch()