vla4ad / app.py
worldbench's picture
Update app.py
b6cc9f3 verified
import os
import json
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import gradio as gr
# Load data
def load_json(filename):
    """Load and parse a JSON file, searching a fixed list of candidate paths.

    Candidates are tried in this order: ``results/<filename>`` and
    ``<filename>`` relative to the current working directory, then the same
    two locations relative to the directory that contains this module.

    Args:
        filename: Name (or relative path) of the JSON file to load.

    Returns:
        The object produced by ``json.load`` for the first candidate that
        exists.

    Raises:
        FileNotFoundError: If none of the candidate paths exists.
    """
    module_dir = os.path.dirname(__file__)
    candidates = (
        f'results/{filename}',
        filename,
        os.path.join(module_dir, 'results', filename),
        os.path.join(module_dir, filename),
    )
    # First existing candidate wins; None means nothing was found.
    found = next((p for p in candidates if os.path.exists(p)), None)
    if found is None:
        raise FileNotFoundError(f"Cannot find {filename}")
    with open(found, 'r', encoding='utf-8') as handle:
        return json.load(handle)
# Load all data: every JSON file is read once at import time and wrapped in a
# DataFrame that the table/plot callbacks copy and filter.
va_data = load_json('va_models.json')
vla_data = load_json('vla_models.json')
datasets_data = load_json('datasets.json')
nuscenes_data = load_json('nuscenes_results.json')
wod_data = load_json('wod_results.json')
navsim_data = load_json('navsim_results.json')
bench2drive_data = load_json('bench2drive_results.json')

df_va = pd.DataFrame(va_data)
df_vla = pd.DataFrame(vla_data)
df_datasets = pd.DataFrame(datasets_data)
df_nuscenes = pd.DataFrame(nuscenes_data)
df_wod = pd.DataFrame(wod_data)
df_navsim = pd.DataFrame(navsim_data)
df_bench2drive = pd.DataFrame(bench2drive_data)

# Icon rendered for each input-modality key. The values were mojibake (UTF-8
# emoji bytes decoded with the wrong code page); they are restored to the
# emoji those byte sequences encode.
INPUT_ICONS = {
    'camera': '📷',
    'lidar': '🔍',
    'status': '⚙️',
    'prompt': '💭',
    'instruction': '📝',
    'scene': '🌆',
    'traffic': '🚦',
    'context': '📚',
}

# Human-readable label for each action-type code.
ACTION_LABELS = {
    'RL': 'Policy w/ RL',
    'REG': 'Decoder+MLP',
    'SEL': 'Traj. Selection',
    'GEN': 'Generative Model',
    'LH': 'Language Head',
}
# Formatting helpers
def format_inputs(inputs):
    """Render input modalities as a space-separated string of icons.

    Args:
        inputs: A list of modality keys, or any other value.

    Returns:
        For a list, each element replaced by its ``INPUT_ICONS`` entry
        (elements without an entry are kept unchanged), joined with single
        spaces. For any non-list value, ``str(inputs)``.
    """
    if not isinstance(inputs, list):
        return str(inputs)
    return ' '.join(INPUT_ICONS.get(key, key) for key in inputs)
def format_datasets(datasets):
    """Render a dataset list compactly for table display.

    Args:
        datasets: A list of dataset names, or any other value.

    Returns:
        For a list, at most its first three entries joined with ``', '``.
        For any non-list value, ``str(datasets)``.
    """
    if not isinstance(datasets, list):
        return str(datasets)
    leading = datasets[:3]
    return ', '.join(leading)
# VA Models
def update_va_table(category, venue, action, search):
    """Filter the Vision-Action model table and build its display version.

    Args:
        category: Exact category to keep, or "All" to skip this filter.
        venue: Text looked up (case-insensitively) in the ``venue`` column;
            an empty value skips this filter.
        action: Exact action code to keep, or "All" to skip this filter.
        search: Text looked up (case-insensitively) in the ``model`` and
            ``vision`` columns; an empty value skips this filter.

    Returns:
        A ``(display_df, stats)`` tuple: the filtered table with formatted
        ``input``/``dataset``/``action`` columns, and a Markdown summary line.
    """
    df = df_va.copy()
    if category != "All":
        df = df[df['category'] == category]
    if venue:
        # regex=False: text typed into the UI is matched literally, so input
        # such as "C++" or "(" cannot raise a regex error or act as a pattern.
        df = df[df['venue'].str.contains(venue, case=False, na=False, regex=False)]
    if action != "All":
        df = df[df['action'] == action]
    if search:
        hit = (
            df['model'].str.contains(search, case=False, na=False, regex=False)
            | df['vision'].str.contains(search, case=False, na=False, regex=False)
        )
        df = df[hit]

    columns = ['id', 'model', 'venue', 'input', 'dataset', 'vision', 'action', 'output', 'category']
    display_df = df[columns].copy()
    display_df['input'] = display_df['input'].apply(format_inputs)
    display_df['dataset'] = display_df['dataset'].apply(format_datasets)
    # Keep the raw code for actions missing from ACTION_LABELS instead of
    # turning them into NaN (same fallback the action-type chart uses).
    display_df['action'] = display_df['action'].map(lambda code: ACTION_LABELS.get(code, code))

    stats = f"**📊 Statistics:** Total {len(df)} models | Categories: {df['category'].nunique()}"
    return display_df, stats
# Initialize the VA table data
def init_va_table():
    """Return the VA table and statistics text with every filter switched off."""
    return update_va_table(category="All", venue="", action="All", search="")
def create_va_plot(category):
    """Build the two-panel statistics figure for the Vision-Action tab.

    The left panel is a horizontal bar chart of model counts per category;
    the right panel is a bar chart of model counts per action type.

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib ``Figure``. It has already been closed in pyplot, so
        pyplot holds no reference to it.
    """
    df = df_va.copy()
    if category != "All":
        df = df[df['category'] == category]

    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3', '#4361ee', '#4895ef', '#560bad']
    text_color, grid_color = "#c9d1d9", "#30363d"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        fig.patch.set_facecolor(bg_color)

        # Left panel: number of models per category.
        category_counts = df['category'].value_counts()
        ax1.barh(category_counts.index, category_counts.values, color=bar_colors[:len(category_counts)])
        ax1.set_facecolor(panel_color)
        ax1.set_xlabel('Count', color=text_color, fontsize=11)
        ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
        ax1.tick_params(colors=text_color)
        ax1.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
        ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right', 'left']:
            ax1.spines[spine].set_visible(False)
        ax1.spines['bottom'].set_color(grid_color)

        # Right panel: number of models per action type.
        action_counts = df['action'].value_counts()
        ax2.bar(range(len(action_counts)), action_counts.values, color=bar_colors[:len(action_counts)])
        ax2.set_facecolor(panel_color)
        ax2.set_xticks(range(len(action_counts)))
        ax2.set_xticklabels([ACTION_LABELS.get(a, a) for a in action_counts.index], rotation=45, ha='right')
        ax2.set_ylabel('Count', color=text_color, fontsize=11)
        ax2.set_title('Models by Action Type', color=text_color, fontsize=13, fontweight='bold')
        ax2.tick_params(colors=text_color)
        ax2.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
        ax2.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right']:
            ax2.spines[spine].set_visible(False)
        ax2.spines['left'].set_color(grid_color)
        ax2.spines['bottom'].set_color(grid_color)

        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
    finally:
        # Close only this figure (also when drawing failed). plt.close('all')
        # would close every figure registered with pyplot, including ones
        # that belong to other callers.
        plt.close(fig)
    return fig
# VLA Models
def update_vla_table(category, venue, language, search):
    """Filter the Vision-Language-Action model table and build its display version.

    Args:
        category: Exact category to keep, or "All" to skip this filter.
        venue: Text looked up (case-insensitively) in the ``venue`` column;
            an empty value skips this filter.
        language: Text looked up (case-insensitively) in the ``language``
            column; an empty value skips this filter.
        search: Text looked up (case-insensitively) in the ``model``,
            ``vision`` and ``language`` columns; an empty value skips this
            filter.

    Returns:
        A ``(display_df, stats)`` tuple: the filtered table with formatted
        ``input``/``dataset``/``action`` columns, and a Markdown summary line.
    """
    df = df_vla.copy()
    if category != "All":
        df = df[df['category'] == category]
    # regex=False throughout: text typed into the UI is matched literally, so
    # input such as "C++" or "(" cannot raise a regex error.
    if venue:
        df = df[df['venue'].str.contains(venue, case=False, na=False, regex=False)]
    if language:
        df = df[df['language'].str.contains(language, case=False, na=False, regex=False)]
    if search:
        hit = (
            df['model'].str.contains(search, case=False, na=False, regex=False)
            | df['vision'].str.contains(search, case=False, na=False, regex=False)
            | df['language'].str.contains(search, case=False, na=False, regex=False)
        )
        df = df[hit]

    columns = ['id', 'model', 'venue', 'input', 'dataset', 'vision', 'language', 'action', 'output', 'category']
    display_df = df[columns].copy()
    display_df['input'] = display_df['input'].apply(format_inputs)
    display_df['dataset'] = display_df['dataset'].apply(format_datasets)
    # Keep the raw code for actions missing from ACTION_LABELS instead of
    # turning them into NaN.
    display_df['action'] = display_df['action'].map(lambda code: ACTION_LABELS.get(code, code))

    stats = f"**📊 Statistics:** Total {len(df)} models | Categories: {df['category'].nunique()}"
    return display_df, stats
# Initialize the VLA table data
def init_vla_table():
    """Return the VLA table and statistics text with every filter switched off."""
    return update_vla_table(category="All", venue="", language="", search="")
def create_vla_plot(category):
    """Build the two-panel statistics figure for the Vision-Language-Action tab.

    The left panel is a horizontal bar chart of model counts per category;
    the right panel is a bar chart of the ten most frequent ``language``
    values.

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib ``Figure``. It has already been closed in pyplot, so
        pyplot holds no reference to it.
    """
    df = df_vla.copy()
    if category != "All":
        df = df[df['category'] == category]

    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        fig.patch.set_facecolor(bg_color)

        # Left panel: number of models per category.
        category_counts = df['category'].value_counts()
        ax1.barh(category_counts.index, category_counts.values, color=bar_colors[:len(category_counts)])
        ax1.set_facecolor(panel_color)
        ax1.set_xlabel('Count', color=text_color, fontsize=11)
        ax1.set_title('VLA Models by Category', color=text_color, fontsize=13, fontweight='bold')
        ax1.tick_params(colors=text_color)
        ax1.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
        ax1.xaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right', 'left']:
            ax1.spines[spine].set_visible(False)
        ax1.spines['bottom'].set_color(grid_color)

        # Right panel: the ten most frequent language-model values.
        language_counts = df['language'].value_counts().head(10)
        ax2.bar(range(len(language_counts)), language_counts.values, color=bar_colors[:len(language_counts)])
        ax2.set_facecolor(panel_color)
        ax2.set_xticks(range(len(language_counts)))
        ax2.set_xticklabels(language_counts.index, rotation=45, ha='right')
        ax2.set_ylabel('Count', color=text_color, fontsize=11)
        ax2.set_title('Top Language Models', color=text_color, fontsize=13, fontweight='bold')
        ax2.tick_params(colors=text_color)
        ax2.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
        ax2.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right']:
            ax2.spines[spine].set_visible(False)
        ax2.spines['left'].set_color(grid_color)
        ax2.spines['bottom'].set_color(grid_color)

        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
    finally:
        # Close only this figure (also when drawing failed). plt.close('all')
        # would close every figure registered with pyplot, including ones
        # that belong to other callers.
        plt.close(fig)
    return fig
# Datasets
def update_datasets_table(category, year, search):
    """Filter the datasets/benchmarks table.

    Args:
        category: Exact category to keep, or "All" to skip this filter.
        year: Year to keep, as text; an empty value skips this filter, and
            text that is not an integer is ignored.
        search: Text looked up (case-insensitively) in the ``dataset`` and
            ``sensor`` columns; an empty value skips this filter.

    Returns:
        A ``(df, stats)`` tuple: the filtered table and a Markdown summary
        line.
    """
    df = df_datasets.copy()
    if category != "All":
        df = df[df['category'] == category]
    if year:
        try:
            wanted_year = int(year)
        except (TypeError, ValueError):
            # Best effort: a non-numeric year leaves the filter off. Only the
            # conversion is guarded, so unrelated errors are no longer
            # swallowed by a bare ``except``.
            wanted_year = None
        if wanted_year is not None:
            df = df[df['year'] == wanted_year]
    if search:
        # regex=False: text typed into the UI is matched literally.
        hit = (
            df['dataset'].str.contains(search, case=False, na=False, regex=False)
            | df['sensor'].str.contains(search, case=False, na=False, regex=False)
        )
        df = df[hit]
    return df, f"**📊 Statistics:** Total {len(df)} datasets | Years: {df['year'].nunique()}"
# Initialize the Datasets table data
def init_datasets_table():
    """Return the datasets table and statistics text with every filter switched off."""
    return update_datasets_table(category="All", year="", search="")
# Benchmark tables update functions
def update_nuscenes_table(category, search):
    """Filter the nuScenes results table.

    Args:
        category: Exact category to keep, or "All" to skip this filter.
        search: Text looked up (case-insensitively) in the ``model`` column;
            an empty value skips this filter.

    Returns:
        A ``(df, stats)`` tuple: the filtered table and a Markdown summary
        line.
    """
    df = df_nuscenes.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: text typed into the UI is matched literally, so input
        # such as "C++" or "(" cannot raise a regex error.
        df = df[df['model'].str.contains(search, case=False, na=False, regex=False)]
    return df, f"**📊 Statistics:** Total {len(df)} models"
# Initialize the nuScenes table data
def init_nuscenes_table():
    """Return the nuScenes table and statistics text with every filter switched off."""
    return update_nuscenes_table(category="All", search="")
def create_nuscenes_plot(category):
    """Create the nuScenes statistics figure.

    The left panel is a bar chart of model counts per category; the right
    panel is a horizontal bar chart of the ten rows with the smallest
    ``l2_avg``.

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib ``Figure``. It has already been closed in pyplot, so
        pyplot holds no reference to it.
    """
    df = df_nuscenes.copy()
    if category != "All":
        df = df[df['category'] == category]

    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        fig.patch.set_facecolor(bg_color)

        # 1. Counts per category.
        category_counts = df['category'].value_counts()
        ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
        ax1.set_facecolor(panel_color)
        ax1.set_xticks(range(len(category_counts)))
        ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
        ax1.set_ylabel('Count', color=text_color, fontsize=11)
        ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
        ax1.tick_params(colors=text_color)
        ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
        ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right']:
            ax1.spines[spine].set_visible(False)
        ax1.spines['left'].set_color(grid_color)
        ax1.spines['bottom'].set_color(grid_color)

        # 2. Top 10 models ordered by L2 Avg (lower is better).
        top_models = df.nsmallest(10, 'l2_avg')[['model', 'l2_avg']].dropna()
        ax2.barh(range(len(top_models)), top_models['l2_avg'].values, color=bar_colors[2])
        ax2.set_facecolor(panel_color)
        ax2.set_yticks(range(len(top_models)))
        ax2.set_yticklabels(top_models['model'].values, fontsize=9)
        ax2.set_xlabel('L2 Avg (m)', color=text_color, fontsize=11)
        ax2.set_title('Top 10 Models by L2 Error', color=text_color, fontsize=13, fontweight='bold')
        ax2.tick_params(colors=text_color)
        ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
        # Put the best (first) row at the top of the chart.
        ax2.invert_yaxis()
        for spine in ['top', 'right', 'left']:
            ax2.spines[spine].set_visible(False)
        ax2.spines['bottom'].set_color(grid_color)

        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
    finally:
        # Close only this figure (also when drawing failed). plt.close('all')
        # would close every figure registered with pyplot, including ones
        # that belong to other callers.
        plt.close(fig)
    return fig
def update_wod_table(category, search):
    """Filter the WOD-E2E results table.

    Args:
        category: Exact category to keep, or "All" to skip this filter.
        search: Text looked up (case-insensitively) in the ``model`` column;
            an empty value skips this filter.

    Returns:
        A ``(df, stats)`` tuple: the filtered table and a Markdown summary
        line.
    """
    df = df_wod.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: text typed into the UI is matched literally, so input
        # such as "C++" or "(" cannot raise a regex error.
        df = df[df['model'].str.contains(search, case=False, na=False, regex=False)]
    return df, f"**📊 Statistics:** Total {len(df)} models"
# Initialize the WOD table data
def init_wod_table():
    """Return the WOD-E2E table and statistics text with every filter switched off."""
    return update_wod_table(category="All", search="")
def create_wod_plot(category):
    """Create the WOD statistics figure.

    The left panel is a bar chart of model counts per category; the right
    panel is a horizontal bar chart of up to ten rows with the largest
    ``rfs_overall``, or a "No data available" note when no row has a value.

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib ``Figure``. It has already been closed in pyplot, so
        pyplot holds no reference to it.
    """
    df = df_wod.copy()
    if category != "All":
        df = df[df['category'] == category]

    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        fig.patch.set_facecolor(bg_color)

        # 1. Counts per category.
        category_counts = df['category'].value_counts()
        ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
        ax1.set_facecolor(panel_color)
        ax1.set_xticks(range(len(category_counts)))
        ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
        ax1.set_ylabel('Count', color=text_color, fontsize=11)
        ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
        ax1.tick_params(colors=text_color)
        ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
        ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right']:
            ax1.spines[spine].set_visible(False)
        ax1.spines['left'].set_color(grid_color)
        ax1.spines['bottom'].set_color(grid_color)

        # 2. Top models ordered by RFS Overall (higher is better).
        top_models = df.nlargest(10, 'rfs_overall')[['model', 'rfs_overall']].dropna()
        if len(top_models) > 0:
            ax2.barh(range(len(top_models)), top_models['rfs_overall'].values, color=bar_colors[1])
            ax2.set_facecolor(panel_color)
            ax2.set_yticks(range(len(top_models)))
            ax2.set_yticklabels(top_models['model'].values, fontsize=9)
            ax2.set_xlabel('RFS Overall', color=text_color, fontsize=11)
            ax2.set_title('Top Models by RFS Overall', color=text_color, fontsize=13, fontweight='bold')
            ax2.tick_params(colors=text_color)
            ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
            # Put the best (first) row at the top of the chart.
            ax2.invert_yaxis()
        else:
            ax2.text(0.5, 0.5, 'No data available', ha='center', va='center',
                     color=text_color, fontsize=12, transform=ax2.transAxes)
            ax2.set_facecolor(panel_color)
        for spine in ['top', 'right', 'left']:
            ax2.spines[spine].set_visible(False)
        ax2.spines['bottom'].set_color(grid_color)

        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
    finally:
        # Close only this figure (also when drawing failed). plt.close('all')
        # would close every figure registered with pyplot, including ones
        # that belong to other callers.
        plt.close(fig)
    return fig
def update_navsim_table(category, search):
    """Filter the NAVSIM results table.

    Args:
        category: Exact category to keep, or "All" to skip this filter.
        search: Text looked up (case-insensitively) in the ``model`` column;
            an empty value skips this filter.

    Returns:
        A ``(df, stats)`` tuple: the filtered table and a Markdown summary
        line.
    """
    df = df_navsim.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: text typed into the UI is matched literally, so input
        # such as "C++" or "(" cannot raise a regex error.
        df = df[df['model'].str.contains(search, case=False, na=False, regex=False)]
    return df, f"**📊 Statistics:** Total {len(df)} models"
# Initialize the NAVSIM table data
def init_navsim_table():
    """Return the NAVSIM table and statistics text with every filter switched off."""
    return update_navsim_table(category="All", search="")
def create_navsim_plot(category):
    """Create the NAVSIM statistics figure.

    The left panel is a bar chart of model counts per category; the right
    panel is a horizontal bar chart of up to ten rows with the largest
    ``pdms``.

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib ``Figure``. It has already been closed in pyplot, so
        pyplot holds no reference to it.
    """
    df = df_navsim.copy()
    if category != "All":
        df = df[df['category'] == category]

    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        fig.patch.set_facecolor(bg_color)

        # 1. Counts per category.
        category_counts = df['category'].value_counts()
        ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
        ax1.set_facecolor(panel_color)
        ax1.set_xticks(range(len(category_counts)))
        ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
        ax1.set_ylabel('Count', color=text_color, fontsize=11)
        ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
        ax1.tick_params(colors=text_color)
        ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
        ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right']:
            ax1.spines[spine].set_visible(False)
        ax1.spines['left'].set_color(grid_color)
        ax1.spines['bottom'].set_color(grid_color)

        # 2. Top 10 models ordered by PDMS (higher is better). Rows with a
        # missing model or score are dropped, as in the nuScenes and WOD
        # charts, so the ranking never shows a label without a bar.
        top_models = df.nlargest(10, 'pdms')[['model', 'pdms']].dropna()
        ax2.barh(range(len(top_models)), top_models['pdms'].values, color=bar_colors[0])
        ax2.set_facecolor(panel_color)
        ax2.set_yticks(range(len(top_models)))
        ax2.set_yticklabels(top_models['model'].values, fontsize=9)
        ax2.set_xlabel('PDMS Score', color=text_color, fontsize=11)
        ax2.set_title('Top 10 Models by PDMS', color=text_color, fontsize=13, fontweight='bold')
        ax2.tick_params(colors=text_color)
        ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
        # Put the best (first) row at the top of the chart.
        ax2.invert_yaxis()
        for spine in ['top', 'right', 'left']:
            ax2.spines[spine].set_visible(False)
        ax2.spines['bottom'].set_color(grid_color)

        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
    finally:
        # Close only this figure (also when drawing failed). plt.close('all')
        # would close every figure registered with pyplot, including ones
        # that belong to other callers.
        plt.close(fig)
    return fig
def update_bench2drive_table(category, search):
    """Filter the Bench2Drive results table.

    Args:
        category: Exact category to keep, or "All" to skip this filter.
        search: Text looked up (case-insensitively) in the ``model`` column;
            an empty value skips this filter.

    Returns:
        A ``(df, stats)`` tuple: the filtered table and a Markdown summary
        line.
    """
    df = df_bench2drive.copy()
    if category != "All":
        df = df[df['category'] == category]
    if search:
        # regex=False: text typed into the UI is matched literally, so input
        # such as "C++" or "(" cannot raise a regex error.
        df = df[df['model'].str.contains(search, case=False, na=False, regex=False)]
    return df, f"**📊 Statistics:** Total {len(df)} models"
# Initialize the Bench2Drive table data
def init_bench2drive_table():
    """Return the Bench2Drive table and statistics text with every filter switched off."""
    return update_bench2drive_table(category="All", search="")
def create_bench2drive_plot(category):
    """Create the Bench2Drive statistics figure.

    The left panel is a bar chart of model counts per category; the right
    panel is a horizontal bar chart of up to ten rows with the largest
    ``ds``.

    Args:
        category: Category to restrict the data to, or "All" for every model.

    Returns:
        The matplotlib ``Figure``. It has already been closed in pyplot, so
        pyplot holds no reference to it.
    """
    df = df_bench2drive.copy()
    if category != "All":
        df = df[df['category'] == category]

    bg_color, panel_color = "#0e1117", "#161b22"
    bar_colors = ['#4cc9f0', '#f72585', '#7209b7', '#3a0ca3']
    text_color, grid_color = "#c9d1d9", "#30363d"

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))
    try:
        fig.patch.set_facecolor(bg_color)

        # 1. Counts per category.
        category_counts = df['category'].value_counts()
        ax1.bar(range(len(category_counts)), category_counts.values, color=bar_colors[:len(category_counts)])
        ax1.set_facecolor(panel_color)
        ax1.set_xticks(range(len(category_counts)))
        ax1.set_xticklabels(category_counts.index, rotation=15, ha='right')
        ax1.set_ylabel('Count', color=text_color, fontsize=11)
        ax1.set_title('Models by Category', color=text_color, fontsize=13, fontweight='bold')
        ax1.tick_params(colors=text_color)
        ax1.grid(axis='y', linestyle='--', alpha=0.3, color=grid_color)
        ax1.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
        for spine in ['top', 'right']:
            ax1.spines[spine].set_visible(False)
        ax1.spines['left'].set_color(grid_color)
        ax1.spines['bottom'].set_color(grid_color)

        # 2. Top 10 models ordered by DS (higher is better). Rows with a
        # missing model or score are dropped, as in the nuScenes and WOD
        # charts, so the ranking never shows a label without a bar.
        top_models = df.nlargest(10, 'ds')[['model', 'ds']].dropna()
        ax2.barh(range(len(top_models)), top_models['ds'].values, color=bar_colors[3])
        ax2.set_facecolor(panel_color)
        ax2.set_yticks(range(len(top_models)))
        ax2.set_yticklabels(top_models['model'].values, fontsize=9)
        ax2.set_xlabel('Driving Score', color=text_color, fontsize=11)
        ax2.set_title('Top 10 Models by DS', color=text_color, fontsize=13, fontweight='bold')
        ax2.tick_params(colors=text_color)
        ax2.grid(axis='x', linestyle='--', alpha=0.3, color=grid_color)
        # Put the best (first) row at the top of the chart.
        ax2.invert_yaxis()
        for spine in ['top', 'right', 'left']:
            ax2.spines[spine].set_visible(False)
        ax2.spines['bottom'].set_color(grid_color)

        # Lay out this figure explicitly rather than pyplot's "current" one.
        fig.tight_layout()
    finally:
        # Close only this figure (also when drawing failed). plt.close('all')
        # would close every figure registered with pyplot, including ones
        # that belong to other callers.
        plt.close(fig)
    return fig
# Pre-generate the initial data and charts: each table, statistics line and
# figure is computed once at import time with every filter at its neutral
# value ("All" / empty).
initial_va_df, initial_va_stats = init_va_table()
initial_vla_df, initial_vla_stats = init_vla_table()
initial_datasets_df, initial_datasets_stats = init_datasets_table()
initial_nuscenes_df, initial_nuscenes_stats = init_nuscenes_table()
initial_wod_df, initial_wod_stats = init_wod_table()
initial_navsim_df, initial_navsim_stats = init_navsim_table()
initial_bench2drive_df, initial_bench2drive_stats = init_bench2drive_table()
# Figures for the unfiltered ("All") view of each tab that has a chart.
initial_va_plot = create_va_plot("All")
initial_vla_plot = create_vla_plot("All")
initial_nuscenes_plot = create_nuscenes_plot("All")
initial_wod_plot = create_wod_plot("All")
initial_navsim_plot = create_navsim_plot("All")
initial_bench2drive_plot = create_bench2drive_plot("All")
# Gradio UI
# Gradio UI: one tab per table. Each tab shows a statistics line, its filter
# widgets, an "Update" button, the table and (except Datasets) a chart; the
# button re-runs the matching update_*/create_* functions.
# NOTE(review): the nesting below (filters inside the Row, everything else
# directly inside the Tab, footer outside the Tabs) follows the usual Gradio
# layout - confirm against the deployed app.
# User-facing emoji/arrow characters were mojibake and are restored here.
with gr.Blocks(css="#title {text-align: center;} .gradio-container {max-width: 100% !important;}") as demo:
    gr.Markdown("# 🚗 VLA for Autonomous Driving: Model Leaderboard", elem_id="title")
    gr.Markdown("### 📄 Survey: *Vision-Language-Action Models for Autonomous Driving*")

    with gr.Tabs():
        # Tab 1: VA Models
        with gr.Tab("Vision-Action Models (VA)"):
            gr.Markdown("### Table 1: Vision-Action Models in Autonomous Driving")
            va_stats = gr.Markdown(initial_va_stats)
            with gr.Row():
                va_category = gr.Dropdown(label="Category", choices=["All"] + sorted(df_va['category'].unique().tolist()), value="All")
                va_venue = gr.Textbox(label="Filter by Venue", placeholder="e.g., CVPR, ICCV...")
                va_action = gr.Dropdown(label="Action Type", choices=["All"] + sorted(df_va['action'].unique().tolist()), value="All")
                va_search = gr.Textbox(label="Search", placeholder="e.g., TransFuser...")
            va_update_btn = gr.Button("🔍 Update", variant="primary")
            va_table = gr.Dataframe(value=initial_va_df, label="VA Models", interactive=False, wrap=True)
            va_plot = gr.Plot(label="Statistics", value=initial_va_plot, format="png")
            # The chart only depends on the category filter.
            va_update_btn.click(fn=lambda cat, ven, act, search: (*update_va_table(cat, ven, act, search), create_va_plot(cat)),
                                inputs=[va_category, va_venue, va_action, va_search], outputs=[va_table, va_stats, va_plot])

        # Tab 2: VLA Models
        with gr.Tab("Vision-Language-Action Models (VLA)"):
            gr.Markdown("### Table 3: Vision-Language-Action Models in Autonomous Driving")
            vla_stats = gr.Markdown(initial_vla_stats)
            with gr.Row():
                vla_category = gr.Dropdown(label="Category", choices=["All"] + sorted(df_vla['category'].unique().tolist()), value="All")
                vla_venue = gr.Textbox(label="Filter by Venue", placeholder="e.g., CVPR...")
                vla_language = gr.Textbox(label="Filter by Language Model", placeholder="e.g., Qwen...")
                vla_search = gr.Textbox(label="Search", placeholder="e.g., AutoVLA...")
            vla_update_btn = gr.Button("🔍 Update", variant="primary")
            vla_table = gr.Dataframe(value=initial_vla_df, label="VLA Models", interactive=False, wrap=True)
            vla_plot = gr.Plot(label="Statistics", value=initial_vla_plot, format="png")
            vla_update_btn.click(fn=lambda cat, ven, lang, search: (*update_vla_table(cat, ven, lang, search), create_vla_plot(cat)),
                                 inputs=[vla_category, vla_venue, vla_language, vla_search], outputs=[vla_table, vla_stats, vla_plot])

        # Tab 3: Datasets (table only, no chart)
        with gr.Tab("Datasets & Benchmarks"):
            gr.Markdown("### Table 4: Summary of Datasets & Benchmarks")
            datasets_stats = gr.Markdown(initial_datasets_stats)
            with gr.Row():
                datasets_category = gr.Dropdown(label="Category", choices=["All"] + sorted(df_datasets['category'].unique().tolist()), value="All")
                datasets_year = gr.Textbox(label="Filter by Year", placeholder="e.g., 2024...")
                datasets_search = gr.Textbox(label="Search", placeholder="e.g., nuScenes...")
            datasets_update_btn = gr.Button("🔍 Update", variant="primary")
            datasets_table = gr.Dataframe(value=initial_datasets_df, label="Datasets", interactive=False, wrap=True)
            datasets_update_btn.click(fn=update_datasets_table, inputs=[datasets_category, datasets_year, datasets_search], outputs=[datasets_table, datasets_stats])

        # Tab 4: nuScenes
        with gr.Tab("nuScenes Open-Loop"):
            gr.Markdown("### Table 5: Open-Loop Planning Results on nuScenes")
            gr.Markdown("**Metrics:** L2 Error (m) ↓ | Collision Rate ↓")
            nuscenes_stats = gr.Markdown(initial_nuscenes_stats)
            with gr.Row():
                nuscenes_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                nuscenes_search = gr.Textbox(label="Search Model", placeholder="e.g., UniAD, EMMA...")
            nuscenes_update_btn = gr.Button("🔍 Update", variant="primary")
            nuscenes_table = gr.Dataframe(value=initial_nuscenes_df, label="nuScenes Results", interactive=False, wrap=True)
            nuscenes_plot = gr.Plot(label="Statistics", value=initial_nuscenes_plot, format="png")
            nuscenes_update_btn.click(fn=lambda cat, search: (*update_nuscenes_table(cat, search), create_nuscenes_plot(cat)),
                                      inputs=[nuscenes_category, nuscenes_search],
                                      outputs=[nuscenes_table, nuscenes_stats, nuscenes_plot])

        # Tab 5: WOD-E2E
        with gr.Tab("WOD-E2E"):
            gr.Markdown("### Table 6: Results on WOD-E2E Test Split")
            gr.Markdown("**Metrics:** RFS (Overall/Spotlight) ↑ | ADE (5s/3s) ↓")
            wod_stats = gr.Markdown(initial_wod_stats)
            with gr.Row():
                wod_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                wod_search = gr.Textbox(label="Search Model", placeholder="e.g., AutoVLA...")
            wod_update_btn = gr.Button("🔍 Update", variant="primary")
            wod_table = gr.Dataframe(value=initial_wod_df, label="WOD-E2E Results", interactive=False, wrap=True)
            wod_plot = gr.Plot(label="Statistics", value=initial_wod_plot, format="png")
            wod_update_btn.click(fn=lambda cat, search: (*update_wod_table(cat, search), create_wod_plot(cat)),
                                 inputs=[wod_category, wod_search],
                                 outputs=[wod_table, wod_stats, wod_plot])

        # Tab 6: NAVSIM
        with gr.Tab("NAVSIM Closed-Loop"):
            gr.Markdown("### Table 7: Closed-Loop Results on NAVSIM")
            gr.Markdown("**Metrics:** NC/DAC/TTC/Comf/EP/PDMS ↑")
            navsim_stats = gr.Markdown(initial_navsim_stats)
            with gr.Row():
                navsim_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                navsim_search = gr.Textbox(label="Search Model", placeholder="e.g., ReflectDrive...")
            navsim_update_btn = gr.Button("🔍 Update", variant="primary")
            navsim_table = gr.Dataframe(value=initial_navsim_df, label="NAVSIM Results", interactive=False, wrap=True)
            navsim_plot = gr.Plot(label="Statistics", value=initial_navsim_plot, format="png")
            navsim_update_btn.click(fn=lambda cat, search: (*update_navsim_table(cat, search), create_navsim_plot(cat)),
                                    inputs=[navsim_category, navsim_search],
                                    outputs=[navsim_table, navsim_stats, navsim_plot])

        # Tab 7: Bench2Drive
        with gr.Tab("Bench2Drive"):
            gr.Markdown("### Table 8: Closed-Loop & Open-Loop Results on Bench2Drive")
            gr.Markdown("**Metrics:** DS/SR ↑ | L2 Avg ↓")
            bench2drive_stats = gr.Markdown(initial_bench2drive_stats)
            with gr.Row():
                bench2drive_category = gr.Dropdown(label="Category", choices=["All", "Vision-Action", "Vision-Language-Action"], value="All")
                bench2drive_search = gr.Textbox(label="Search Model", placeholder="e.g., SimLingo...")
            bench2drive_update_btn = gr.Button("🔍 Update", variant="primary")
            bench2drive_table = gr.Dataframe(value=initial_bench2drive_df, label="Bench2Drive Results", interactive=False, wrap=True)
            bench2drive_plot = gr.Plot(label="Statistics", value=initial_bench2drive_plot, format="png")
            bench2drive_update_btn.click(fn=lambda cat, search: (*update_bench2drive_table(cat, search), create_bench2drive_plot(cat)),
                                         inputs=[bench2drive_category, bench2drive_search],
                                         outputs=[bench2drive_table, bench2drive_stats, bench2drive_plot])

    # Footer shown below all tabs. The Markdown text stays at column 0 so no
    # line is interpreted as an indented code block.
    gr.Markdown("""
---
**Legend:** 📷 Camera | 🔍 LiDAR | ⚙️ Status | 💭 Prompt | 📝 Instruction | 🌆 Scene | 🚦 Traffic | 📚 Context
**Action Types:** RL=Reinforcement Learning | REG=Decoder+MLP | SEL=Selection | GEN=Generative | LH=Language Head
📖 **Project:** [https://worldbench.github.io/vla4ad](https://worldbench.github.io/vla4ad) | 🔗 **GitHub:** [https://github.com/worldbench/awesome-vla-for-ad](https://github.com/worldbench/awesome-vla-for-ad)
""")
# Start the Gradio server only when this file is run as a script, not when it
# is imported.
if __name__ == "__main__":
    demo.launch()