Spaces:
Running
Running
| import os | |
| import json | |
| import pandas as pd | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| import matplotlib.pyplot as plt | |
| import gradio as gr | |
# Data loading
def load_json(filename):
    """Load and return the parsed contents of a JSON results file.

    Candidate locations are tried in this order:
    ``results/<filename>`` and ``<filename>`` relative to the current working
    directory, then ``<module dir>/results/<filename>`` and
    ``<module dir>/<filename>``. The first candidate that is an existing
    regular file is opened as UTF-8 and parsed.

    Args:
        filename: Name (or path) of the JSON file to load.

    Returns:
        The deserialized JSON value.

    Raises:
        FileNotFoundError: If no candidate path is an existing file.
    """
    module_dir = os.path.dirname(__file__)
    possible_paths = [
        f'results/{filename}',
        filename,
        os.path.join(module_dir, 'results', filename),
        os.path.join(module_dir, filename),
    ]
    for path in possible_paths:
        # isfile rather than exists: a directory that happens to match a
        # candidate name must not shadow a real file later in the list
        # (open() on it would raise IsADirectoryError).
        if os.path.isfile(path):
            with open(path, 'r', encoding='utf-8') as f:
                return json.load(f)
    raise FileNotFoundError(f"Cannot find {filename}")
# Load all data: every results file is read once at import time. The raw JSON
# records are kept under the *_data names and a DataFrame of each under df_*.
(va_data, vla_data, datasets_data, nuscenes_data,
 wod_data, navsim_data, bench2drive_data) = (
    load_json(json_name)
    for json_name in (
        'va_models.json',
        'vla_models.json',
        'datasets.json',
        'nuscenes_results.json',
        'wod_results.json',
        'navsim_results.json',
        'bench2drive_results.json',
    )
)
(df_va, df_vla, df_datasets, df_nuscenes,
 df_wod, df_navsim, df_bench2drive) = (
    pd.DataFrame(records)
    for records in (va_data, vla_data, datasets_data, nuscenes_data,
                    wod_data, navsim_data, bench2drive_data)
)
# Icon substituted for each input-modality name in the displayed "input"
# column (see format_inputs); names without an entry are shown as-is.
# NOTE(review): the emoji literals below look mis-encoded in this copy of the
# file (UTF-8 bytes rendered through another codepage) — confirm against the
# original source before relying on them.
INPUT_ICONS = {'camera': '๐ท', 'lidar': '๐', 'status': 'โ๏ธ', 'prompt': '๐ญ',
               'instruction': '๐', 'scene': '๐', 'traffic': '๐ฆ', 'context': '๐'}
# Human-readable labels for the short action-type codes stored in the
# "action" column of the model tables.
ACTION_LABELS = {'RL': 'Policy w/ RL', 'REG': 'Decoder+MLP', 'SEL': 'Traj. Selection',
                 'GEN': 'Generative Model', 'LH': 'Language Head'}
# Formatting helpers
def format_inputs(inputs):
    """Render an input-modality cell for display.

    A list of modality names becomes a single space-separated string in which
    each name is replaced by its INPUT_ICONS icon when one exists (unknown
    names are kept verbatim). Anything that is not a list is returned as
    ``str(inputs)``.
    """
    if not isinstance(inputs, list):
        return str(inputs)
    rendered = (INPUT_ICONS.get(modality, modality) for modality in inputs)
    return ' '.join(rendered)
def format_datasets(datasets, limit=3):
    """Render a dataset cell for display.

    Lists are shortened to their first ``limit`` entries (3 by default) and
    joined with ", ". Each item is passed through ``str`` first, so lists that
    contain non-string values (numbers, ``None``) no longer make ``str.join``
    raise ``TypeError``. Any non-list value is returned as ``str(datasets)``.

    Args:
        datasets: The raw cell value, typically a list of dataset names.
        limit: Maximum number of list entries to show.

    Returns:
        The display string.
    """
    if isinstance(datasets, list):
        return ', '.join(str(item) for item in datasets[:limit])
    return str(datasets)
# VA Models
def update_va_table(category, venue, action, search):
    """Filter the Vision-Action model table and build its statistics line.

    Args:
        category: Exact category to keep, or "All" for no category filter.
        venue: Case-insensitive substring matched literally against "venue";
            an empty value disables the filter.
        action: Exact action code to keep (e.g. "RL"), or "All".
        search: Case-insensitive substring matched literally against "model"
            or "vision"; an empty value disables the filter.

    Returns:
        ``(display_df, stats_markdown)``. ``display_df`` holds the displayed
        columns with inputs, datasets and action codes formatted for humans.
    """
    df = df_va.copy()
    if category != "All":
        df = df[df['category'] == category]
    # regex=False: the text boxes hold free-form user input, so characters
    # such as "(" or "+" must match literally instead of raising re.error.
    if venue:
        df = df[df['venue'].str.contains(venue, case=False, na=False, regex=False)]
    if action != "All":
        df = df[df['action'] == action]
    if search:
        mask = (df['model'].str.contains(search, case=False, na=False, regex=False)
                | df['vision'].str.contains(search, case=False, na=False, regex=False))
        df = df[mask]
    display_df = df[['id', 'model', 'venue', 'input', 'dataset', 'vision', 'action', 'output', 'category']].copy()
    display_df['input'] = display_df['input'].apply(format_inputs)
    display_df['dataset'] = display_df['dataset'].apply(format_datasets)
    # Keep the raw code for actions without a label; .map(ACTION_LABELS)
    # alone would turn them into NaN.
    display_df['action'] = display_df['action'].map(lambda code: ACTION_LABELS.get(code, code))
    return display_df, f"**๐ Statistics:** Total {len(df)} models | Categories: {df['category'].nunique()}"
# Initialize VA table data
def init_va_table():
    """Return the initial (unfiltered) VA table and its statistics line."""
    return update_va_table(category="All", venue="", action="All", search="")
def create_va_plot(category):
    """Build the statistics figure for the Vision-Action model tab.

    Left panel: number of models per category (horizontal bars).
    Right panel: number of models per action type, labelled through
    ACTION_LABELS (the raw code is shown when no label exists).

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib Figure. It is closed in pyplot before being returned,
        so pyplot's figure registry does not grow; the object is still usable.
    """
    df = df_va.copy()
    if category != "All":
        df = df[df['category'] == category]
    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3', '#4361ee', '#4895ef', '#560bad']
    text_color, grid_color = "#c9d1d9", "#30363d"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    fig.patch.set_facecolor(bg_color)
    # 1. Model count per category
    category_counts = df['category'].value_counts()
    ax1.barh(category_counts.index, category_counts.values, color=bar_colors[:len(category_counts)])
    ax1.set_facecolor(panel_color)
    ax1.set_xlabel('Count', color=text_color, fontsize=11)
    ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
    ax1.tick_params(colors=text_color)
    ax1.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
    ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right', 'left']:
        ax1.spines[spine].set_visible(False)
    ax1.spines['bottom'].set_color(grid_color)
    # 2. Model count per action type
    action_counts = df['action'].value_counts()
    ax2.bar(range(len(action_counts)), action_counts.values, color=bar_colors[:len(action_counts)])
    ax2.set_facecolor(panel_color)
    ax2.set_xticks(range(len(action_counts)))
    ax2.set_xticklabels([ACTION_LABELS.get(a, a) for a in action_counts.index], rotation=45, ha='right')
    ax2.set_ylabel('Count', color=text_color, fontsize=11)
    ax2.set_title('Models by Action Type', color=text_color, fontsize=13, fontweight='bold')
    ax2.tick_params(colors=text_color)
    ax2.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
    ax2.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right']:
        ax2.spines[spine].set_visible(False)
    ax2.spines['left'].set_color(grid_color)
    ax2.spines['bottom'].set_color(grid_color)
    # Work on this Figure explicitly instead of pyplot's global "current
    # figure": plt.close('all') would also discard figures another request
    # may be drawing at the same time.
    fig.tight_layout()
    plt.close(fig)
    return fig
# VLA Models
def update_vla_table(category, venue, language, search):
    """Filter the Vision-Language-Action model table and build its stats line.

    Args:
        category: Exact category to keep, or "All" for no category filter.
        venue: Case-insensitive substring matched literally against "venue";
            an empty value disables the filter.
        language: Case-insensitive substring matched literally against
            "language"; an empty value disables the filter.
        search: Case-insensitive substring matched literally against "model",
            "vision" or "language"; an empty value disables the filter.

    Returns:
        ``(display_df, stats_markdown)``. ``display_df`` holds the displayed
        columns with inputs, datasets and action codes formatted for humans.
    """
    df = df_vla.copy()
    if category != "All":
        df = df[df['category'] == category]
    # regex=False: free-form user text must be matched literally; characters
    # such as "(" or "+" would otherwise raise re.error.
    if venue:
        df = df[df['venue'].str.contains(venue, case=False, na=False, regex=False)]
    if language:
        df = df[df['language'].str.contains(language, case=False, na=False, regex=False)]
    if search:
        mask = (df['model'].str.contains(search, case=False, na=False, regex=False)
                | df['vision'].str.contains(search, case=False, na=False, regex=False)
                | df['language'].str.contains(search, case=False, na=False, regex=False))
        df = df[mask]
    display_df = df[['id', 'model', 'venue', 'input', 'dataset', 'vision', 'language', 'action', 'output', 'category']].copy()
    display_df['input'] = display_df['input'].apply(format_inputs)
    display_df['dataset'] = display_df['dataset'].apply(format_datasets)
    # Keep the raw code for actions without a label instead of producing NaN.
    display_df['action'] = display_df['action'].map(lambda code: ACTION_LABELS.get(code, code))
    return display_df, f"**๐ Statistics:** Total {len(df)} models | Categories: {df['category'].nunique()}"
# Initialize VLA table data
def init_vla_table():
    """Return the initial (unfiltered) VLA table and its statistics line."""
    return update_vla_table(category="All", venue="", language="", search="")
def create_vla_plot(category):
    """Build the statistics figure for the VLA model tab.

    Left panel: number of models per category (horizontal bars).
    Right panel: the 10 most frequent values of the "language" column.

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib Figure, closed in pyplot but still usable.
    """
    df = df_vla.copy()
    if category != "All":
        df = df[df['category'] == category]
    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    fig.patch.set_facecolor(bg_color)
    # 1. Model count per category
    category_counts = df['category'].value_counts()
    ax1.barh(category_counts.index, category_counts.values, color=bar_colors[:len(category_counts)])
    ax1.set_facecolor(panel_color)
    ax1.set_xlabel('Count', color=text_color, fontsize=11)
    ax1.set_title('VLA Models by Category', color=text_color, fontsize=13, fontweight='bold')
    ax1.tick_params(colors=text_color)
    ax1.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
    ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right', 'left']:
        ax1.spines[spine].set_visible(False)
    ax1.spines['bottom'].set_color(grid_color)
    # 2. Ten most frequent language-model values
    language_counts = df['language'].value_counts().head(10)
    ax2.bar(range(len(language_counts)), language_counts.values, color=bar_colors[:len(language_counts)])
    ax2.set_facecolor(panel_color)
    ax2.set_xticks(range(len(language_counts)))
    ax2.set_xticklabels(language_counts.index, rotation=45, ha='right')
    ax2.set_ylabel('Count', color=text_color, fontsize=11)
    ax2.set_title('Top Language Models', color=text_color, fontsize=13, fontweight='bold')
    ax2.tick_params(colors=text_color)
    ax2.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
    ax2.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right']:
        ax2.spines[spine].set_visible(False)
    ax2.spines['left'].set_color(grid_color)
    ax2.spines['bottom'].set_color(grid_color)
    # Lay out and close only this figure; plt.close('all') would also discard
    # figures other concurrent requests are still drawing.
    fig.tight_layout()
    plt.close(fig)
    return fig
# Datasets
def update_datasets_table(category, year, search):
    """Filter the datasets table and build its statistics line.

    Args:
        category: Exact category to keep, or "All" for no category filter.
        year: Year as text (e.g. "2024"). A value that ``int()`` cannot parse
            is ignored (best-effort), leaving the table unfiltered by year.
        search: Case-insensitive substring matched literally against
            "dataset" or "sensor"; an empty value disables the filter.

    Returns:
        ``(df, stats_markdown)`` with the filtered rows and a summary line.
    """
    df = df_datasets.copy()
    if category != "All":
        df = df[df['category'] == category]
    if year:
        # Narrow handler (was a bare except): only a non-integer year is
        # ignored; unrelated errors are no longer silently swallowed.
        try:
            year_value = int(year)
        except (TypeError, ValueError):
            pass
        else:
            df = df[df['year'] == year_value]
    if search:
        # regex=False: match the user's text literally.
        mask = (df['dataset'].str.contains(search, case=False, na=False, regex=False)
                | df['sensor'].str.contains(search, case=False, na=False, regex=False))
        df = df[mask]
    return df, f"**๐ Statistics:** Total {len(df)} datasets | Years: {df['year'].nunique()}"
# Initialize Datasets table data
def init_datasets_table():
    """Return the initial (unfiltered) datasets table and its statistics line."""
    return update_datasets_table(category="All", year="", search="")
# Benchmark tables update functions
def update_nuscenes_table(category, search):
    """Filter the nuScenes results table and build its statistics line.

    Args:
        category: Exact category to keep, or "All" for no category filter.
        search: Case-insensitive substring matched literally against "model";
            an empty value disables the filter.

    Returns:
        ``(df, stats_markdown)`` with the filtered rows and a summary line.
    """
    df = df_nuscenes.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: free-form user text must not be parsed as a regex.
        mask = df['model'].str.contains(search, case=False, na=False, regex=False)
        df = df[mask]
    return df, f"**๐ Statistics:** Total {len(df)} models"
# Initialize nuScenes table data
def init_nuscenes_table():
    """Return the initial (unfiltered) nuScenes table and its statistics line."""
    return update_nuscenes_table(category="All", search="")
def create_nuscenes_plot(category):
    """Create the nuScenes statistics chart.

    Left panel: number of models per category.
    Right panel: up to 10 models with the lowest ``l2_avg`` (best on top), or
    a "No data available" note when no row has a usable value.

    Args:
        category: Category to restrict the data to, or "All".

    Returns:
        The matplotlib Figure, closed in pyplot but still usable.
    """
    df = df_nuscenes.copy()
    if category != "All":
        df = df[df['category'] == category]
    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    fig.patch.set_facecolor(bg_color)
    # 1. Model count per category
    category_counts = df['category'].value_counts()
    ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
    ax1.set_facecolor(panel_color)
    ax1.set_xticks(range(len(category_counts)))
    ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
    ax1.set_ylabel('Count', color=text_color, fontsize=11)
    ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
    ax1.tick_params(colors=text_color)
    ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
    ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right']:
        ax1.spines[spine].set_visible(False)
    ax1.spines['left'].set_color(grid_color)
    ax1.spines['bottom'].set_color(grid_color)
    # 2. Top 10 models by L2 Avg (lower is better). Coerce to numbers so any
    # non-numeric placeholder becomes NaN instead of making nsmallest() fail
    # on an object column; drop missing rows BEFORE ranking so the top 10 is
    # not shortened by rows that are discarded afterwards.
    scored = df.assign(l2_avg=pd.to_numeric(df['l2_avg'], errors='coerce'))
    top_models = scored.dropna(subset=['model', 'l2_avg']).nsmallest(10, 'l2_avg')[['model', 'l2_avg']]
    ax2.set_facecolor(panel_color)
    if top_models.empty:
        ax2.text(0.5, 0.5, 'No data available', ha='center', va='center',
                 color=text_color, fontsize=12, transform=ax2.transAxes)
    else:
        ax2.barh(range(len(top_models)), top_models['l2_avg'].values, color=bar_colors[2])
        ax2.set_yticks(range(len(top_models)))
        ax2.set_yticklabels(top_models['model'].values, fontsize=9)
        ax2.invert_yaxis()  # smallest error at the top
    ax2.set_xlabel('L2 Avg (m)', color=text_color, fontsize=11)
    ax2.set_title('Top 10 Models by L2 Error', color=text_color, fontsize=13, fontweight='bold')
    ax2.tick_params(colors=text_color)
    ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
    for spine in ['top', 'right', 'left']:
        ax2.spines[spine].set_visible(False)
    ax2.spines['bottom'].set_color(grid_color)
    # Lay out and close only this figure (not pyplot's global state).
    fig.tight_layout()
    plt.close(fig)
    return fig
def update_wod_table(category, search):
    """Filter the WOD-E2E results table and build its statistics line.

    Args:
        category: Exact category to keep, or "All" for no category filter.
        search: Case-insensitive substring matched literally against "model";
            an empty value disables the filter.

    Returns:
        ``(df, stats_markdown)`` with the filtered rows and a summary line.
    """
    df = df_wod.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: free-form user text must not be parsed as a regex.
        mask = df['model'].str.contains(search, case=False, na=False, regex=False)
        df = df[mask]
    return df, f"**๐ Statistics:** Total {len(df)} models"
# Initialize WOD table data
def init_wod_table():
    """Return the initial (unfiltered) WOD-E2E table and its statistics line."""
    return update_wod_table(category="All", search="")
def create_wod_plot(category):
    """Create the WOD-E2E statistics chart.

    Left panel: number of models per category.
    Right panel: up to 10 models with the highest ``rfs_overall`` (best on
    top), or a "No data available" note when no row has a usable value.

    Args:
        category: Category to restrict the data to, or "All".

    Returns:
        The matplotlib Figure, closed in pyplot but still usable.
    """
    df = df_wod.copy()
    if category != "All":
        df = df[df['category'] == category]
    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    fig.patch.set_facecolor(bg_color)
    # 1. Model count per category
    category_counts = df['category'].value_counts()
    ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
    ax1.set_facecolor(panel_color)
    ax1.set_xticks(range(len(category_counts)))
    ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
    ax1.set_ylabel('Count', color=text_color, fontsize=11)
    ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
    ax1.tick_params(colors=text_color)
    ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
    ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right']:
        ax1.spines[spine].set_visible(False)
    ax1.spines['left'].set_color(grid_color)
    ax1.spines['bottom'].set_color(grid_color)
    # 2. Top models by RFS Overall (higher is better). Coerce to numbers so
    # non-numeric placeholders become NaN, and drop missing rows before
    # ranking so they cannot shorten the top-10 list.
    scored = df.assign(rfs_overall=pd.to_numeric(df['rfs_overall'], errors='coerce'))
    top_models = scored.dropna(subset=['model', 'rfs_overall']).nlargest(10, 'rfs_overall')[['model', 'rfs_overall']]
    ax2.set_facecolor(panel_color)
    if top_models.empty:
        ax2.text(0.5, 0.5, 'No data available', ha='center', va='center',
                 color=text_color, fontsize=12, transform=ax2.transAxes)
    else:
        ax2.barh(range(len(top_models)), top_models['rfs_overall'].values, color=bar_colors[1])
        ax2.set_yticks(range(len(top_models)))
        ax2.set_yticklabels(top_models['model'].values, fontsize=9)
        ax2.set_xlabel('RFS Overall', color=text_color, fontsize=11)
        ax2.set_title('Top Models by RFS Overall', color=text_color, fontsize=13, fontweight='bold')
        ax2.tick_params(colors=text_color)
        ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
        ax2.invert_yaxis()  # highest score at the top
    for spine in ['top', 'right', 'left']:
        ax2.spines[spine].set_visible(False)
    ax2.spines['bottom'].set_color(grid_color)
    # Lay out and close only this figure (not pyplot's global state).
    fig.tight_layout()
    plt.close(fig)
    return fig
def update_navsim_table(category, search):
    """Filter the NAVSIM results table and build its statistics line.

    Args:
        category: Exact category to keep, or "All" for no category filter.
        search: Case-insensitive substring matched literally against "model";
            an empty value disables the filter.

    Returns:
        ``(df, stats_markdown)`` with the filtered rows and a summary line.
    """
    df = df_navsim.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: free-form user text must not be parsed as a regex.
        mask = df['model'].str.contains(search, case=False, na=False, regex=False)
        df = df[mask]
    return df, f"**๐ Statistics:** Total {len(df)} models"
# Initialize NAVSIM table data
def init_navsim_table():
    """Return the initial (unfiltered) NAVSIM table and its statistics line."""
    return update_navsim_table(category="All", search="")
def create_navsim_plot(category):
    """Create the NAVSIM statistics chart.

    Left panel: number of models per category.
    Right panel: up to 10 models with the highest ``pdms`` (best on top), or
    a "No data available" note when no row has a usable value.

    Args:
        category: Category to restrict the data to, or "All".

    Returns:
        The matplotlib Figure, closed in pyplot but still usable.
    """
    df = df_navsim.copy()
    if category != "All":
        df = df[df['category'] == category]
    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    fig.patch.set_facecolor(bg_color)
    # 1. Model count per category
    category_counts = df['category'].value_counts()
    ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
    ax1.set_facecolor(panel_color)
    ax1.set_xticks(range(len(category_counts)))
    ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
    ax1.set_ylabel('Count', color=text_color, fontsize=11)
    ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
    ax1.tick_params(colors=text_color)
    ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
    ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right']:
        ax1.spines[spine].set_visible(False)
    ax1.spines['left'].set_color(grid_color)
    ax1.spines['bottom'].set_color(grid_color)
    # 2. Top 10 models by PDMS (higher is better). Coerce to numbers and drop
    # missing rows before ranking (the original code had no NaN handling).
    scored = df.assign(pdms=pd.to_numeric(df['pdms'], errors='coerce'))
    top_models = scored.dropna(subset=['model', 'pdms']).nlargest(10, 'pdms')[['model', 'pdms']]
    ax2.set_facecolor(panel_color)
    if top_models.empty:
        ax2.text(0.5, 0.5, 'No data available', ha='center', va='center',
                 color=text_color, fontsize=12, transform=ax2.transAxes)
    else:
        ax2.barh(range(len(top_models)), top_models['pdms'].values, color=bar_colors[0])
        ax2.set_yticks(range(len(top_models)))
        ax2.set_yticklabels(top_models['model'].values, fontsize=9)
        ax2.invert_yaxis()  # highest score at the top
    ax2.set_xlabel('PDMS Score', color=text_color, fontsize=11)
    ax2.set_title('Top 10 Models by PDMS', color=text_color, fontsize=13, fontweight='bold')
    ax2.tick_params(colors=text_color)
    ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
    for spine in ['top', 'right', 'left']:
        ax2.spines[spine].set_visible(False)
    ax2.spines['bottom'].set_color(grid_color)
    # Lay out and close only this figure (not pyplot's global state).
    fig.tight_layout()
    plt.close(fig)
    return fig
def update_bench2drive_table(category, search):
    """Filter the Bench2Drive results table and build its statistics line.

    Args:
        category: Exact category to keep, or "All" for no category filter.
        search: Case-insensitive substring matched literally against "model";
            an empty value disables the filter.

    Returns:
        ``(df, stats_markdown)`` with the filtered rows and a summary line.
    """
    df = df_bench2drive.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: free-form user text must not be parsed as a regex.
        mask = df['model'].str.contains(search, case=False, na=False, regex=False)
        df = df[mask]
    return df, f"**๐ Statistics:** Total {len(df)} models"
# Initialize Bench2Drive table data
def init_bench2drive_table():
    """Return the initial (unfiltered) Bench2Drive table and its statistics line."""
    return update_bench2drive_table(category="All", search="")
def create_bench2drive_plot(category):
    """Create the Bench2Drive statistics chart.

    Left panel: number of models per category.
    Right panel: up to 10 models with the highest driving score ``ds`` (best
    on top), or a "No data available" note when no row has a usable value.

    Args:
        category: Category to restrict the data to, or "All".

    Returns:
        The matplotlib Figure, closed in pyplot but still usable.
    """
    df = df_bench2drive.copy()
    if category != "All":
        df = df[df['category'] == category]
    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    fig.patch.set_facecolor(bg_color)
    # 1. Model count per category
    category_counts = df['category'].value_counts()
    ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
    ax1.set_facecolor(panel_color)
    ax1.set_xticks(range(len(category_counts)))
    ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
    ax1.set_ylabel('Count', color=text_color, fontsize=11)
    ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
    ax1.tick_params(colors=text_color)
    ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
    ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    for spine in ['top', 'right']:
        ax1.spines[spine].set_visible(False)
    ax1.spines['left'].set_color(grid_color)
    ax1.spines['bottom'].set_color(grid_color)
    # 2. Top 10 models by DS (higher is better). Coerce to numbers and drop
    # missing rows before ranking (the original code had no NaN handling).
    scored = df.assign(ds=pd.to_numeric(df['ds'], errors='coerce'))
    top_models = scored.dropna(subset=['model', 'ds']).nlargest(10, 'ds')[['model', 'ds']]
    ax2.set_facecolor(panel_color)
    if top_models.empty:
        ax2.text(0.5, 0.5, 'No data available', ha='center', va='center',
                 color=text_color, fontsize=12, transform=ax2.transAxes)
    else:
        ax2.barh(range(len(top_models)), top_models['ds'].values, color=bar_colors[3])
        ax2.set_yticks(range(len(top_models)))
        ax2.set_yticklabels(top_models['model'].values, fontsize=9)
        ax2.invert_yaxis()  # highest score at the top
    ax2.set_xlabel('Driving Score', color=text_color, fontsize=11)
    ax2.set_title('Top 10 Models by DS', color=text_color, fontsize=13, fontweight='bold')
    ax2.tick_params(colors=text_color)
    ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
    for spine in ['top', 'right', 'left']:
        ax2.spines[spine].set_visible(False)
    ax2.spines['bottom'].set_color(grid_color)
    # Lay out and close only this figure (not pyplot's global state).
    fig.tight_layout()
    plt.close(fig)
    return fig
# Pre-generate the initial table data and charts once at import time, using
# the unfiltered ("All" / empty) settings; the UI below uses them as the
# components' starting values.
initial_va_df, initial_va_stats = init_va_table()
initial_vla_df, initial_vla_stats = init_vla_table()
initial_datasets_df, initial_datasets_stats = init_datasets_table()
initial_nuscenes_df, initial_nuscenes_stats = init_nuscenes_table()
initial_wod_df, initial_wod_stats = init_wod_table()
initial_navsim_df, initial_navsim_stats = init_navsim_table()
initial_bench2drive_df, initial_bench2drive_stats = init_bench2drive_table()
# One statistics figure per tab that shows a chart (the datasets tab has none).
initial_va_plot = create_va_plot("All")
initial_vla_plot = create_vla_plot("All")
initial_nuscenes_plot = create_nuscenes_plot("All")
initial_wod_plot = create_wod_plot("All")
initial_navsim_plot = create_navsim_plot("All")
initial_bench2drive_plot = create_bench2drive_plot("All")
# Gradio UI
# Page layout: a title, then one tab per table. Each tab has a row of filter
# widgets plus an "Update" button whose click re-runs that tab's update
# function (and its plot function, where the tab has a chart) with the
# current widget values.
with gr.Blocks(css="#title {text-align: center;} .gradio-container {max-width: 100% !important;}") as demo:
    gr.Markdown("# ๐ VLA for Autonomous Driving: Model Leaderboard", elem_id="title")
    gr.Markdown("### ๐ Survey: *Vision-Language-Action Models for Autonomous Driving*")
    with gr.Tabs():
        # Tab 1: VA Models
        with gr.Tab("Vision-Action Models (VA)"):
            gr.Markdown("### Table 1: Vision-Action Models in Autonomous Driving")
            va_stats = gr.Markdown(initial_va_stats)
            with gr.Row():
                # Dropdown choices are derived from the values present in the data.
                va_category = gr.Dropdown(label="Category", choices=["All"] + sorted(df_va['category'].unique().tolist()), value="All")
                va_venue = gr.Textbox(label="Filter by Venue", placeholder="e.g., CVPR, ICCV...")
                va_action = gr.Dropdown(label="Action Type", choices=["All"] + sorted(df_va['action'].unique().tolist()), value="All")
                va_search = gr.Textbox(label="Search", placeholder="e.g., TransFuser...")
                va_update_btn = gr.Button("๐ Update", variant="primary")
            va_table = gr.Dataframe(value=initial_va_df, label="VA Models", interactive=False, wrap=True)
            va_plot = gr.Plot(label="Statistics", value=initial_va_plot, format="png")
            # The lambda returns (table, stats, plot) to match `outputs`. Only the
            # category reaches the plot, so venue/action/search filter the table
            # but not the chart.
            va_update_btn.click(fn=lambda cat, ven, act, search: (*update_va_table(cat, ven, act, search), create_va_plot(cat)),
                                inputs=[va_category, va_venue, va_action, va_search], outputs=[va_table, va_stats, va_plot])
        # Tab 2: VLA Models
        with gr.Tab("Vision-Language-Action Models (VLA)"):
            gr.Markdown("### Table 3: Vision-Language-Action Models in Autonomous Driving")
            vla_stats = gr.Markdown(initial_vla_stats)
            with gr.Row():
                vla_category = gr.Dropdown(label="Category", choices=["All"] + sorted(df_vla['category'].unique().tolist()), value="All")
                vla_venue = gr.Textbox(label="Filter by Venue", placeholder="e.g., CVPR...")
                vla_language = gr.Textbox(label="Filter by Language Model", placeholder="e.g., Qwen...")
                vla_search = gr.Textbox(label="Search", placeholder="e.g., AutoVLA...")
                vla_update_btn = gr.Button("๐ Update", variant="primary")
            vla_table = gr.Dataframe(value=initial_vla_df, label="VLA Models", interactive=False, wrap=True)
            vla_plot = gr.Plot(label="Statistics", value=initial_vla_plot, format="png")
            # As in tab 1, the chart only reflects the category filter.
            vla_update_btn.click(fn=lambda cat, ven, lang, search: (*update_vla_table(cat, ven, lang, search), create_vla_plot(cat)),
                                 inputs=[vla_category, vla_venue, vla_language, vla_search], outputs=[vla_table, vla_stats, vla_plot])
        # Tab 3: Datasets (table only, no chart)
        with gr.Tab("Datasets & Benchmarks"):
            gr.Markdown("### Table 4: Summary of Datasets & Benchmarks")
            datasets_stats = gr.Markdown(initial_datasets_stats)
            with gr.Row():
                datasets_category = gr.Dropdown(label="Category", choices=["All"] + sorted(df_datasets['category'].unique().tolist()), value="All")
                datasets_year = gr.Textbox(label="Filter by Year", placeholder="e.g., 2024...")
                datasets_search = gr.Textbox(label="Search", placeholder="e.g., nuScenes...")
                datasets_update_btn = gr.Button("๐ Update", variant="primary")
            datasets_table = gr.Dataframe(value=initial_datasets_df, label="Datasets", interactive=False, wrap=True)
            datasets_update_btn.click(fn=update_datasets_table, inputs=[datasets_category, datasets_year, datasets_search], outputs=[datasets_table, datasets_stats])
        # Tab 4: nuScenes
        with gr.Tab("nuScenes Open-Loop"):
            gr.Markdown("### Table 5: Open-Loop Planning Results on nuScenes")
            gr.Markdown("**Metrics:** L2 Error (m) โ | Collision Rate โ")
            nuscenes_stats = gr.Markdown(initial_nuscenes_stats)
            with gr.Row():
                # NOTE(review): unlike tabs 1-3, these choices are hard-coded;
                # confirm they match the "category" values in the results JSON.
                nuscenes_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                nuscenes_search = gr.Textbox(label="Search Model", placeholder="e.g., UniAD, EMMA...")
                nuscenes_update_btn = gr.Button("๐ Update", variant="primary")
            nuscenes_table = gr.Dataframe(value=initial_nuscenes_df, label="nuScenes Results", interactive=False, wrap=True)
            nuscenes_plot = gr.Plot(label="Statistics", value=initial_nuscenes_plot, format="png")
            nuscenes_update_btn.click(fn=lambda cat, search: (*update_nuscenes_table(cat, search), create_nuscenes_plot(cat)),
                                      inputs=[nuscenes_category, nuscenes_search],
                                      outputs=[nuscenes_table, nuscenes_stats, nuscenes_plot])
        # Tab 5: WOD-E2E
        with gr.Tab("WOD-E2E"):
            gr.Markdown("### Table 6: Results on WOD-E2E Test Split")
            gr.Markdown("**Metrics:** RFS (Overall/Spotlight) โ | ADE (5s/3s) โ")
            wod_stats = gr.Markdown(initial_wod_stats)
            with gr.Row():
                # NOTE(review): hard-coded category choices — see tab 4.
                wod_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                wod_search = gr.Textbox(label="Search Model", placeholder="e.g., AutoVLA...")
                wod_update_btn = gr.Button("๐ Update", variant="primary")
            wod_table = gr.Dataframe(value=initial_wod_df, label="WOD-E2E Results", interactive=False, wrap=True)
            wod_plot = gr.Plot(label="Statistics", value=initial_wod_plot, format="png")
            wod_update_btn.click(fn=lambda cat, search: (*update_wod_table(cat, search), create_wod_plot(cat)),
                                 inputs=[wod_category, wod_search],
                                 outputs=[wod_table, wod_stats, wod_plot])
        # Tab 6: NAVSIM
        with gr.Tab("NAVSIM Closed-Loop"):
            gr.Markdown("### Table 7: Closed-Loop Results on NAVSIM")
            gr.Markdown("**Metrics:** NC/DAC/TTC/Comf/EP/PDMS โ")
            navsim_stats = gr.Markdown(initial_navsim_stats)
            with gr.Row():
                # NOTE(review): hard-coded category choices — see tab 4.
                navsim_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                navsim_search = gr.Textbox(label="Search Model", placeholder="e.g., ReflectDrive...")
                navsim_update_btn = gr.Button("๐ Update", variant="primary")
            navsim_table = gr.Dataframe(value=initial_navsim_df, label="NAVSIM Results", interactive=False, wrap=True)
            navsim_plot = gr.Plot(label="Statistics", value=initial_navsim_plot, format="png")
            navsim_update_btn.click(fn=lambda cat, search: (*update_navsim_table(cat, search), create_navsim_plot(cat)),
                                    inputs=[navsim_category, navsim_search],
                                    outputs=[navsim_table, navsim_stats, navsim_plot])
        # Tab 7: Bench2Drive
        with gr.Tab("Bench2Drive"):
            gr.Markdown("### Table 8: Closed-Loop & Open-Loop Results on Bench2Drive")
            gr.Markdown("**Metrics:** DS/SR โ | L2 Avg โ")
            bench2drive_stats = gr.Markdown(initial_bench2drive_stats)
            with gr.Row():
                # NOTE(review): hard-coded category choices — see tab 4.
                bench2drive_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                bench2drive_search = gr.Textbox(label="Search Model", placeholder="e.g., SimLingo...")
                bench2drive_update_btn = gr.Button("๐ Update", variant="primary")
            bench2drive_table = gr.Dataframe(value=initial_bench2drive_df, label="Bench2Drive Results", interactive=False, wrap=True)
            bench2drive_plot = gr.Plot(label="Statistics", value=initial_bench2drive_plot, format="png")
            bench2drive_update_btn.click(fn=lambda cat, search: (*update_bench2drive_table(cat, search), create_bench2drive_plot(cat)),
                                         inputs=[bench2drive_category, bench2drive_search],
                                         outputs=[bench2drive_table, bench2drive_stats, bench2drive_plot])
    # Footer legend shown below all tabs.
    # NOTE(review): the legend describes RL as "Reinforcement Learning" while
    # ACTION_LABELS maps it to "Policy w/ RL"; the emoji here also look
    # mis-encoded in this copy — confirm both against the original source.
    gr.Markdown("""
---
**Legend:** ๐ท Camera | ๐ LiDAR | โ๏ธ Status | ๐ญ Prompt | ๐ Instruction | ๐ Scene | ๐ฆ Traffic | ๐ Context
**Action Types:** RL=Reinforcement Learning | REG=Decoder+MLP | SEL=Selection | GEN=Generative | LH=Language Head
๐ **Project:** [https://worldbench.github.io/vla4ad](https://worldbench.github.io/vla4ad) | ๐ **GitHub:** [https://github.com/worldbench/awesome-vla-for-ad](https://github.com/worldbench/awesome-vla-for-ad)
""")
# Start the Gradio app only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    demo.launch()