Spaces:
Running
Running
| import os | |
| import glob | |
| import json | |
| from typing import Dict, Literal, Tuple, List, Optional | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
# Directory scanned for per-model result files (one *.json per model).
RESULTS_DIR = "./worldlens-results"

# Direction in which each metric improves.
# Lower is better: errors / discrepancies.
METRICS_MIN_BETTER = [
    "Depth Discrepancy",
    "Perceptual Discrepancy",
    "Photometric Error",
    "Geometric Discrepancy",
    "Novel-View Discrepancy",
    "Displacement Error",
]
# Higher is better.
METRICS_MAX_BETTER = [
    "Subject Fidelity",
    "Subject Coherence",
    "Subject Consistency",
    "Temporal Consistency",
    "Semantic Consistency",
    "View Consistency",  # appears in the result JSONs; assumed higher-is-better (TODO confirm)
    "Novel-View Quality",
    "Open-Loop Adherence",
    "Route Completion",
    "Closed-Loop Adherence",
    "Map Segmentation",
    "3D Object Detection",
    "3D Object Tracking",
    "Occupancy Prediction",
]

# metric name -> "min" | "max"; a name listed in both tables ends up as "max".
METRIC_BETTER: Dict[str, Literal["min", "max"]] = {
    **dict.fromkeys(METRICS_MIN_BETTER, "min"),
    **dict.fromkeys(METRICS_MAX_BETTER, "max"),
}

# Every known metric, de-duplicated and sorted, for the dropdown / checkbox choices.
METRIC_CHOICES: List[str] = sorted({*METRICS_MIN_BETTER, *METRICS_MAX_BETTER})

# Initial ranking metric: "Subject Fidelity" when known, else the first choice.
if "Subject Fidelity" in METRIC_CHOICES:
    DEFAULT_METRIC = "Subject Fidelity"
else:
    DEFAULT_METRIC = METRIC_CHOICES[0]

# Global table of all loaded models (one row per model); filled by reload_data().
df_all: Optional[pd.DataFrame] = None
def load_results(results_dir: Optional[str] = None) -> pd.DataFrame:
    """Load every ``*.json`` result file into one wide table.

    Each readable JSON object becomes one row:

    - ``Model`` is the file name without its ``.json`` extension,
    - ``venue`` is the top-level ``"venue"`` value (default ``""``),
    - ``date`` is the top-level ``"date"`` value, falling back to the legacy
      ``"data"`` key used by existing result files (default ``""``),
    - every entry of every sub-dict under ``"Metrics"`` becomes its own column.
      Metric names are assumed unique across categories; a later category
      overwrites an earlier one with the same name.

    Files that cannot be read or parsed, or whose top-level value is not a
    JSON object, are skipped with a warning instead of aborting the load.

    Args:
        results_dir: Directory to scan; defaults to the module's ``RESULTS_DIR``.

    Returns:
        A DataFrame with columns ``Model, venue, date`` followed by the metric
        columns in first-seen order (files visited in sorted path order), or
        an empty DataFrame when no usable result file is found.
    """
    import warnings

    if results_dir is None:
        results_dir = RESULTS_DIR

    meta_cols = ["Model", "venue", "date"]
    rows = []
    for path in sorted(glob.glob(os.path.join(results_dir, "*.json"))):
        try:
            # Explicit UTF-8: result files may contain non-ASCII text and the
            # platform default encoding is not guaranteed to be UTF-8.
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as err:
            # One broken file should not take the whole leaderboard down.
            warnings.warn(f"Skipping unreadable result file {path}: {err}")
            continue
        if not isinstance(data, dict):
            warnings.warn(f"Skipping {path}: top-level JSON value is not an object")
            continue

        row = {
            "Model": os.path.splitext(os.path.basename(path))[0],
            "venue": data.get("venue", ""),
            # Prefer the correctly spelled key; existing files use "data".
            "date": data.get("date", data.get("data", "")),
        }
        metrics = data.get("Metrics", {})
        if isinstance(metrics, dict):
            # Flatten {category: {metric: value}} into top-level columns.
            for metric_dict in metrics.values():
                if isinstance(metric_dict, dict):
                    row.update(metric_dict)
        rows.append(row)

    if not rows:
        # Also covers "files exist but none were usable"; avoids a KeyError
        # when selecting the meta columns from a column-less frame.
        return pd.DataFrame()

    df = pd.DataFrame(rows)
    # Fixed column order: meta columns first, then metrics.
    metric_cols = [c for c in df.columns if c not in meta_cols]
    return df[meta_cols + metric_cols]
def get_venue_choices(df: pd.DataFrame) -> List[str]:
    """Return the options for the venue filter dropdown.

    The list always starts with ``"All"``, followed by the distinct non-null,
    non-empty values of ``df["venue"]`` in sorted order. When ``df`` has no
    ``venue`` column, only ``["All"]`` is returned.
    """
    choices = ["All"]
    if "venue" in df.columns:
        distinct = df["venue"].dropna().unique()
        choices.extend(sorted(v for v in distinct if v != ""))
    return choices
def _message_figure(message: str) -> plt.Figure:
    """Return a small axis-less figure that only displays ``message``.

    The figure is released from pyplot's registry before it is returned, so
    callers that build one per request do not accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(6, 3))
    ax.text(0.5, 0.5, message, ha="center", va="center")
    ax.axis("off")
    plt.close(fig)
    return fig


def update_leaderboard(
    metric: str,
    top_k: int,
    model_filter: str,
    venue_filter: str,
    sort_mode: str,
    selected_metrics: Optional[List[str]],
) -> Tuple[pd.DataFrame, plt.Figure]:
    """Build the leaderboard table and bar chart for the current UI selection.

    Args:
        metric: Metric used both for ranking and for the bar chart.
        top_k: Number of rows kept after sorting (cast to ``int``, since UI
            sliders may deliver numbers as floats).
        model_filter: Case-insensitive plain-substring filter on ``Model``
            (not a regex); empty or ``None`` disables it.
        venue_filter: Exact venue to keep; ``"All"`` or empty disables it.
        sort_mode: ``"Auto"`` follows the metric's direction in
            ``METRIC_BETTER`` (unknown metrics are treated as
            higher-is-better); ``"Ascending (small → large)"`` forces
            ascending; any other value sorts descending.
        selected_metrics: Extra metric columns to show in the table; names
            absent from the data are ignored. ``None`` means none.

    Returns:
        ``(table, figure)``. The table holds ``Model, venue, date``, the
        selected metrics and the ranking metric, with numeric columns rounded
        to 3 decimals. When no data is loaded, or ``metric`` is not a column,
        an empty DataFrame and a message figure are returned instead.
    """
    if df_all is None or df_all.empty:
        return pd.DataFrame(), _message_figure("No results found in ./worldlens-results")

    df = df_all  # only filtered/sorted below, never mutated in place
    if model_filter:
        df = df[df["Model"].str.contains(model_filter, case=False, regex=False)]
    if venue_filter and venue_filter != "All":
        df = df[df["venue"] == venue_filter]

    if metric not in df.columns:
        return pd.DataFrame(), _message_figure(f"Metric '{metric}' not found in current data.")

    # Sort direction.
    better = METRIC_BETTER.get(metric, "max")
    if sort_mode == "Auto":
        ascending = better == "min"
    elif sort_mode == "Ascending (small → large)":
        ascending = True
    else:  # "Descending (large → small)"
        ascending = False

    # Models missing the metric (NaN) sort last in either direction (pandas default).
    df_top = df.sort_values(metric, ascending=ascending).head(int(top_k))

    # Table columns: fixed meta columns, then the requested metrics, then the
    # ranking metric if not already included. Metrics absent from the data
    # (not computed in any JSON) are skipped.
    cols = ["Model", "venue", "date"]
    for m in [*(selected_metrics or []), metric]:
        if m in df_top.columns and m not in cols:
            cols.append(m)
    table_df = df_top[cols].round(3)

    # Bar chart of the ranking metric only.
    fig, ax = plt.subplots(figsize=(9, 4))
    ax.barh(df_top["Model"], df_top[metric])
    ax.set_xlabel(metric)
    ax.set_ylabel("Model")
    ax.set_title(f"Leaderboard by {metric}")
    # barh draws the first row at the bottom; invert so the chart reads in
    # table order (first-ranked on top) for every sort direction.
    ax.invert_yaxis()
    fig.tight_layout()
    # Release pyplot's reference; the Figure object stays renderable, but
    # repeated updates no longer pile up open figures in the server process.
    plt.close(fig)
    return table_df, fig
def reload_data():
    """Reload all result JSONs and rebuild the initial leaderboard view.

    Called by the "Reload JSONs" button and on page load. Replaces the
    module-level ``df_all`` with a fresh ``load_results()`` table and returns,
    in order:

    - a status message (Markdown text),
    - a ``gr.update`` for the venue dropdown (choices refreshed, value "All"),
    - the leaderboard table,
    - the matching matplotlib figure.

    NOTE(review): the table/figure are always built with the default control
    values (DEFAULT_METRIC, top 10, "Auto", no filters), not with whatever the
    controls currently show.
    """
    global df_all
    df_all = load_results()

    if df_all is None or df_all.empty:
        msg = "No JSON files found in ./worldlens-results. Please upload some results."
        dummy_fig, ax = plt.subplots(figsize=(6, 3))
        ax.text(0.5, 0.5, msg, ha="center", va="center")
        ax.axis("off")
        # Drop pyplot's global reference: the Figure stays renderable, but
        # repeated reloads no longer accumulate open figures.
        plt.close(dummy_fig)
        venue_update = gr.update(choices=["All"], value="All")
        return msg, venue_update, pd.DataFrame(), dummy_fig

    venue_choices = get_venue_choices(df_all)
    msg = f"Loaded {len(df_all)} models from {RESULTS_DIR}"

    # Render once with the default metric and a simple default set of extra
    # table columns (restricted to known metrics).
    default_selected = ["Subject Fidelity", "Temporal Consistency", "Map Segmentation"]
    default_selected = [m for m in default_selected if m in METRIC_CHOICES]
    table_df, fig = update_leaderboard(
        metric=DEFAULT_METRIC,
        top_k=10,
        model_filter="",
        venue_filter="All",
        sort_mode="Auto",
        selected_metrics=default_selected,
    )

    venue_update = gr.update(
        choices=venue_choices,
        value="All",
        interactive=True,
    )
    return msg, venue_update, table_df, fig
with gr.Blocks(css="""
#title {
text-align: center;
}
""") as demo:
    # Header text; centred by the #title CSS rule passed to gr.Blocks.
    gr.Markdown(
        """
# 🌍 WorldLens Leaderboard
基于 `./worldlens-results/*.json` 的自动排行榜:
- 选择一个**排序指标**用来排名
- 勾选多个指标一起在表格中展示
- 支持模型名搜索 & venue 筛选
- 自动区分“越大越好 / 越小越好”的指标
""",
        elem_id="title",
    )
    # Status line; overwritten by reload_data() on page load / reload.
    status_box = gr.Markdown("Loading results...", elem_id="status")

    with gr.Row():
        metric_dropdown = gr.Dropdown(
            label="排序指标 / Metric (for ranking)",
            choices=METRIC_CHOICES,  # fixed choices, avoids dynamic-update compatibility issues
            value=DEFAULT_METRIC,
            interactive=True,
        )
        sort_mode_radio = gr.Radio(
            label="排序方式 / Sort mode",
            choices=[
                "Auto",
                "Ascending (small → large)",
                "Descending (large → small)",
            ],
            value="Auto",
            interactive=True,
        )
        topk_slider = gr.Slider(
            label="显示 Top-K 模型 / Top-K",
            minimum=3,
            maximum=50,
            value=10,
            step=1,
            interactive=True,
        )

    # Extra metrics shown as table columns (several may be selected).
    metrics_select = gr.CheckboxGroup(
        label="在表格中一起展示的指标 / Metrics to show in table",
        choices=METRIC_CHOICES,
        value=["Subject Fidelity", "Temporal Consistency", "Map Segmentation"],
        interactive=True,
    )

    with gr.Row():
        model_filter_box = gr.Textbox(
            label="模型名过滤(包含关系)/ Filter by model name",
            placeholder="例如: magic, dream, ...",
            interactive=True,
        )
        venue_dropdown = gr.Dropdown(
            label="按 Venue 筛选 / Filter by venue",
            choices=["All"],
            value="All",
            interactive=True,
        )

    with gr.Row():
        reload_button = gr.Button("🔄 Reload JSONs", variant="secondary")
        update_button = gr.Button("✅ Update leaderboard", variant="primary")

    leaderboard_table = gr.DataFrame(
        label="Leaderboard",
        interactive=False,
    )
    # format="png" is set explicitly to avoid problems with webp output.
    leaderboard_plot = gr.Plot(label="Metric comparison", format="png")

    # Controls read by update_leaderboard, in its positional-argument order.
    leaderboard_inputs = [
        metric_dropdown,
        topk_slider,
        model_filter_box,
        venue_dropdown,
        sort_mode_radio,
        metrics_select,
    ]

    # Reload: re-read the JSONs and refresh status, venue choices, table and
    # plot. reload_data() renders with default settings, so a follow-up step
    # re-renders from the controls' current values to keep the displayed
    # leaderboard consistent with what the controls show.
    reload_button.click(
        fn=reload_data,
        inputs=[],
        outputs=[status_box, venue_dropdown, leaderboard_table, leaderboard_plot],
    ).then(
        fn=update_leaderboard,
        inputs=leaderboard_inputs,
        outputs=[leaderboard_table, leaderboard_plot],
    )

    # Update: rebuild table and plot from the current controls.
    update_button.click(
        fn=update_leaderboard,
        inputs=leaderboard_inputs,
        outputs=[leaderboard_table, leaderboard_plot],
    )

    # Load the results once automatically when the page opens. The controls
    # are still at their initial values here, which match reload_data()'s
    # defaults.
    demo.load(
        fn=reload_data,
        inputs=[],
        outputs=[status_box, venue_dropdown, leaderboard_table, leaderboard_plot],
    )

if __name__ == "__main__":
    demo.launch()  # for public access when running locally, use demo.launch(share=True)