Spaces:
Running
Running
| import gradio as gr | |
| import pandas as pd | |
| import matplotlib.pyplot as plt | |
| from src.data_loader import DataLoader | |
| from src.leaderboard import Leaderboard | |
| from src.plotter import Plotter | |
| from src.radar_plotter import RadarPlotter | |
| from src.styling import dataframe_to_html, get_academic_css | |
| from src.utils import get_metric_choices, clean_metric_names | |
# Initial selection for the ranking-metric dropdown and the comparison-plot
# metric radio.
DEFAULT_METRIC = "Average β"

# Static HTML bar of external project links, rendered right below the title.
TITLE_RESOURCE_LINKS = """
<div class="project-links-bar">
<a class="pl-link pl-project" href="https://iworld-bench.com/" target="_blank" rel="noopener noreferrer"><i class="fa-solid fa-globe" aria-hidden="true"></i><em>Project Page</em></a>
<a class="pl-link pl-dataset" href="https://huggingface.co/datasets/EmbodiedCity/iWorld-Bench-Dataset" target="_blank" rel="noopener noreferrer"><i class="fa-solid fa-database" aria-hidden="true"></i><em>Dataset</em></a>
<a class="pl-link pl-code" href="https://github.com/EmbodiedCity/iWorld-Bench" target="_blank" rel="noopener noreferrer"><i class="fa-brands fa-github" aria-hidden="true"></i><em>Code</em></a>
<a class="pl-link pl-leaderboard" href="https://huggingface.co/spaces/EmbodiedCity/iWorld-Bench" target="_blank" rel="noopener noreferrer"><i class="fa-solid fa-trophy" aria-hidden="true"></i><em>Leaderboard</em></a>
</div>
"""

# Module-level service objects shared by every handler below. The leaderboard
# and both plotters are built on top of the same DataLoader instance.
data_loader = DataLoader(results_dir="./data")
leaderboard = Leaderboard(data_loader)
plotter = Plotter(data_loader)
radar_plotter = RadarPlotter(data_loader)
def reload_data():
    """Reload the benchmark results and rebuild the default leaderboard view.

    Used as the handler for both the reload button and the page-load event.

    Returns:
        A 4-tuple matched positionally to the bound outputs: the status text
        (always an empty string), an update for the category dropdown, the
        leaderboard table as HTML, and a figure for the radar plot. When no
        data is available the table is a placeholder ``<div>`` and the figure
        only shows the message returned by the data loader.
    """
    msg = data_loader.reload_data()
    if data_loader.df_all is None or data_loader.df_all.empty:
        # Build the placeholder without pyplot. A figure created through
        # plt.subplots() stays registered in pyplot's global state until it is
        # explicitly closed, and this handler cannot close a figure it hands
        # back to Gradio, so every reload with no data would leak one figure.
        from matplotlib.figure import Figure

        dummy_fig = Figure(figsize=(6, 3))
        ax = dummy_fig.subplots()
        ax.text(0.5, 0.5, msg, ha="center", va="center")
        ax.axis("off")
        placeholder_html = "<div class='placeholder'>No data available</div>"
        # Empty status, dropdown reset to "All", placeholder table, dummy figure.
        return "", gr.update(choices=["All"], value="All"), placeholder_html, dummy_fig

    # Only the category filter is populated from the data.
    category_choices = data_loader.get_category_choices()

    all_metrics_with_markers = [m for m in get_metric_choices() if m != "Average β"]
    # Ensure the Average column is always included.
    selected = ["Average"] + clean_metric_names(all_metrics_with_markers)

    table_df = leaderboard.update_leaderboard(
        metric="Average",
        top_k=25,
        model_filter="",
        open_source_filter="All",
        year_filter="All",
        category_filter="All",
        sort_mode="Auto",
        selected_metrics=selected,
    )
    html_table = dataframe_to_html(table_df)
    radar_fig = radar_plotter.create_radar_chart()
    return (
        "",
        gr.update(choices=category_choices, value="All"),
        html_table,
        radar_fig,
    )
def update_leaderboard_wrapper(metric, top_k, model_filter,
                               category_filter, sort_mode, selected_metrics):
    """Rebuild the leaderboard table and the radar chart from the UI controls.

    The metric labels coming from the UI are passed through
    ``clean_metric_names`` before use. The open-source and year filters are
    fixed to "All".

    Returns:
        A pair ``(table_html, radar_figure)``. The radar chart is built only
        from the models that appear in the table; when the table is empty or
        no data is loaded it is built from an empty DataFrame.
    """
    ranking_metric = clean_metric_names([metric])[0]
    # The Average column is always shown, whatever the user ticked.
    shown_columns = ["Average"] + clean_metric_names(selected_metrics)

    ranked = leaderboard.update_leaderboard(
        ranking_metric,
        top_k,
        model_filter,
        open_source_filter="All",
        year_filter="All",
        category_filter=category_filter,
        sort_mode=sort_mode,
        selected_metrics=shown_columns,
    )
    table_markup = dataframe_to_html(ranked)

    shown_models = [] if ranked.empty else ranked["Model"].tolist()
    if not shown_models or data_loader.df_all is None:
        radar_source = pd.DataFrame()
    else:
        everything = data_loader.df_all
        in_table = everything["Model"].isin(shown_models)
        radar_source = everything[in_table].copy()

    return table_markup, radar_plotter.create_radar_chart(radar_source)
def create_comparison_plot_wrapper(model_filter, category_filter,
                                   selected_plot_metric, plot_sort_mode):
    """Build the model-comparison plot for the metric chosen in the UI.

    The metric label is passed through ``clean_metric_names`` first. The
    open-source and year filters are fixed to "All".

    Returns:
        Whatever ``plotter.create_comparison_plot`` returns for these settings.
    """
    metric_name = clean_metric_names([selected_plot_metric])[0]
    # Filters that are no longer exposed in the UI.
    fixed_filters = {"open_source_filter": "All", "year_filter": "All"}
    figure = plotter.create_comparison_plot(
        model_filter,
        category_filter=category_filter,
        metric=metric_name,
        sort_mode=plot_sort_mode,
        **fixed_filters,
    )
    return figure
# Page layout and event wiring for the leaderboard app.
# NOTE(review): indentation was lost in this copy of the file, so the nesting
# of the `with gr.Row()` / `with gr.Column()` blocks below cannot be read off
# the text - confirm against the original before reformatting this section.
| academic_css = get_academic_css() | |
| with gr.Blocks(css=academic_css) as demo: | |
| gr.Markdown( | |
| """ | |
| # <span class="emoji">π</span> iWorld-Bench Leaderboard | |
| <span class="subtitle">A Benchmark for Interactive World Models with a Unified Action Generation Framework</span> | |
| """, | |
| elem_id="title", | |
| ) | |
| gr.HTML(TITLE_RESOURCE_LINKS) | |
| # Hidden status box | |
| status_box = gr.Markdown(visible=False) | |
| with gr.Row(): | |
| with gr.Column(scale=2): | |
# Fetched once and reused as the choices for the ranking dropdown, the
# metric checkboxes and the plot-metric radio.
| metric_choices = get_metric_choices() | |
| metric_dropdown = gr.Dropdown( | |
| label="Primary Ranking Metric", | |
| choices=metric_choices, | |
| value=DEFAULT_METRIC, | |
| interactive=True, | |
| ) | |
| with gr.Column(scale=1): | |
| sort_mode_radio = gr.Radio( | |
| label="Sort Order", | |
| choices=["Auto", "Ascending (low β high)", "Descending (high β low)"], | |
| value="Auto", | |
| interactive=True, | |
| ) | |
| topk_slider = gr.Slider( | |
| label="Display Top-K Models", | |
| minimum=3, maximum=50, value=25, step=1, | |
| interactive=True, | |
| ) | |
| with gr.Row(): | |
# The "Average β" entry is left out of the optional columns; both handlers
# prepend an "Average" column themselves. All remaining metrics start ticked.
| metrics_select = gr.CheckboxGroup( | |
| label="Additional Metrics to Display (π indicates dimension metrics)", | |
| choices=[m for m in metric_choices if m != "Average β"], | |
| value=[m for m in metric_choices if m != "Average β"], | |
| interactive=True, | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| model_filter_box = gr.Textbox( | |
| label="Filter by Model Name", | |
| placeholder="Enter model name (partial match)", | |
| interactive=True, | |
| ) | |
| # Removed Open Source and Year filters | |
| with gr.Column(scale=1): | |
# Starts with "All" only; reload_data() supplies the real category choices
# on page load and on every reload.
| category_dropdown = gr.Dropdown( | |
| label="Filter by Category", | |
| choices=["All"], | |
| value="All", | |
| interactive=True, | |
| ) | |
| with gr.Row(): | |
| reload_button = gr.Button("π Reload Data", variant="secondary", size="sm") | |
| update_button = gr.Button("β Update Leaderboard", variant="primary", size="sm") | |
| leaderboard_html = gr.HTML( | |
| label="Leaderboard Table", | |
| value="<div class='placeholder'>Leaderboard will be displayed here...</div>" | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| radar_plot = gr.Plot(label="Performance Radar (8 metrics)", format="png") | |
| with gr.Column(scale=1): | |
| plot_metric_radio = gr.Radio( | |
| label="Select Metric for Comparison Plot", | |
| choices=metric_choices, | |
| value=DEFAULT_METRIC, | |
| interactive=True, | |
| ) | |
| plot_sort_radio = gr.Radio( | |
| label="Plot Sort Order", | |
| choices=["Ascending (low β high)", "Descending (high β low)"], | |
| value="Descending (high β low)", | |
| interactive=True, | |
| ) | |
| plot_update_button = gr.Button("π Generate Comparison Plot", variant="primary", size="sm") | |
| comparison_plot = gr.Plot(label="Model Comparison Visualization", format="png") | |
| # Event bindings β adjusted inputs/outputs | |
# reload_data() returns four values, mapped in order onto these four outputs.
| reload_button.click( | |
| fn=reload_data, | |
| inputs=[], | |
| outputs=[status_box, category_dropdown, leaderboard_html, radar_plot], | |
| ) | |
# Inputs are listed in the order of update_leaderboard_wrapper's parameters.
| update_button.click( | |
| fn=update_leaderboard_wrapper, | |
| inputs=[ | |
| metric_dropdown, topk_slider, model_filter_box, | |
| category_dropdown, sort_mode_radio, metrics_select, | |
| ], | |
| outputs=[leaderboard_html, radar_plot], | |
| ) | |
# Inputs are listed in the order of create_comparison_plot_wrapper's
# parameters; the plot reuses the model and category filters of the table.
| plot_update_button.click( | |
| fn=create_comparison_plot_wrapper, | |
| inputs=[ | |
| model_filter_box, category_dropdown, | |
| plot_metric_radio, plot_sort_radio, | |
| ], | |
| outputs=[comparison_plot], | |
| ) | |
# Run the reload handler once on page load so the table, radar plot and
# category choices are filled in without a click.
| demo.load( | |
| fn=reload_data, | |
| inputs=[], | |
| outputs=[status_box, category_dropdown, leaderboard_html, radar_plot], | |
| ) | |
if __name__ == "__main__":
    import os

    # Launch settings come from the environment, with defaults for a
    # container-style deployment (listen on all interfaces, port 7860).
    env = os.environ
    host = env.get("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(env.get("GRADIO_SERVER_PORT", "7860"))

    # HF Spaces: leave share off (default). Docker / locked-down hosts: set GRADIO_SHARE=true.
    share_flag = env.get("GRADIO_SHARE", "false").strip().lower()
    share_enabled = share_flag in ("1", "true", "yes")

    demo.launch(server_name=host, server_port=port, share=share_enabled)