Spaces:
Running
Running
| import gradio as gr | |
| import json | |
| import pandas as pd | |
| import numpy as np | |
| import plotly.express as px | |
| import plotly.graph_objects as go | |
| from plotly.subplots import make_subplots | |
| import os | |
| import traceback | |
| from datetime import datetime | |
| from packaging import version | |
# Plotly's default qualitative palette.
# NOTE(review): COLORS is not referenced anywhere else in this file — confirm before removing.
COLORS = px.colors.qualitative.Plotly

# Outline colours for radar-chart traces; the i-th drawn model uses entry i % len(line_colors).
line_colors = ["#EE4266", "#00a6ed", "#ECA72C", "#B42318", "#3CBBB1"]

# Matching translucent fills: each outline colour's RGB at 5% opacity,
# e.g. "#EE4266" -> "rgba(238,66,102,0.05)".
fill_colors = [
    "rgba({},{},{},0.05)".format(*(int(code[pos:pos + 2], 16) for pos in (1, 3, 5)))
    for code in line_colors
]

# Question types reported by the benchmark
# (simple, set-based, multi-hop, conditional, comparison).
QUESTION_CATEGORIES = "simple set mh cond comp".split()
# Score families computed for every question type.
METRIC_TYPES = ["retrieval", "generation"]
def load_results():
    """Read ``results.json`` from the directory containing this script.

    Returns:
        The decoded JSON object. If the file does not exist, or reading /
        decoding it fails for any other reason (the error and traceback are
        printed), a fresh placeholder
        ``{"items": {}, "last_version": "1.0", "n_questions": "0"}`` is
        returned instead so the UI can still start.
    """
    def _placeholder():
        # A new dict on every call, so callers may mutate it freely.
        return {"items": {}, "last_version": "1.0", "n_questions": "0"}

    try:
        results_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'results.json')
        print(f"Loading results from: {results_path}")
        with open(results_path, 'r', encoding='utf-8') as handle:
            loaded = json.load(handle)
        print(f"Successfully loaded results with {len(loaded.get('items', {}))} version(s)")
        return loaded
    except FileNotFoundError:
        print("Results file not found, creating empty structure")
        return _placeholder()
    except Exception as exc:
        # Top-level boundary: report and fall back rather than crash the app.
        print(f"Error loading results: {exc}")
        print(traceback.format_exc())
        return _placeholder()
def _group_items_by_model(all_items):
    """Return ``{model_name: [item, ...]}`` with each item tagged by its dataset version.

    Items are shallow-copied before ``version_str`` / ``version_obj`` are added,
    so the caller's ``results`` structure is never mutated.
    """
    model_groups = {}
    for version_str, version_items in all_items.items():
        version_obj = version.parse(version_str)
        for item in version_items.values():
            tagged = {**item, "version_str": version_str, "version_obj": version_obj}
            model_groups.setdefault(item.get("model_name", "Unknown"), []).append(tagged)
    return model_groups


def _format_timestamp(raw):
    """Render an ISO-8601 timestamp as ``YYYY-MM-DD``; return *raw* unchanged if it can't be parsed."""
    if not raw:
        return raw
    try:
        # fromisoformat() on older Pythons does not accept a trailing 'Z'.
        return datetime.fromisoformat(raw.replace('Z', '+00:00')).strftime("%Y-%m-%d")
    except (AttributeError, TypeError, ValueError):
        return raw


def _metric_columns(items):
    """Average the per-category scores of *items* into leaderboard columns.

    For every item, each ``metrics[category][metric_type]`` dict is averaged
    over its values. Those per-item means are then averaged across items
    into ``"<category>_<metric_type>"`` columns. The ``"<metric_type>_avg"``
    columns average the per-item means across all categories. Missing or
    empty metric dicts are skipped. All values are rounded to 4 places.
    """
    # (category, metric_type) -> [sum of per-item means, number of items]
    sums = {(c, m): [0.0, 0] for c in QUESTION_CATEGORIES for m in METRIC_TYPES}
    for item in items:
        metrics = item.get("metrics") or {}
        for category in QUESTION_CATEGORIES:
            category_block = metrics.get(category) or {}
            for metric_type in METRIC_TYPES:
                values = category_block.get(metric_type)
                if not values:
                    continue  # absent or empty: nothing to average (avoids division by zero)
                acc = sums[(category, metric_type)]
                acc[0] += sum(values.values()) / len(values)
                acc[1] += 1

    columns = {}
    for (category, metric_type), (total, count) in sums.items():
        if count:
            columns[f"{category}_{metric_type}"] = round(total / count, 4)
    for metric_type in METRIC_TYPES:
        total = sum(sums[(c, metric_type)][0] for c in QUESTION_CATEGORIES)
        count = sum(sums[(c, metric_type)][1] for c in QUESTION_CATEGORIES)
        if count:
            columns[f"{metric_type}_avg"] = round(total / count, 4)
    return columns


def filter_and_process_results(results, n_versions, only_actual_versions):
    """Aggregate raw benchmark items into one leaderboard row per model.

    Args:
        results: decoded results.json. ``results["items"]`` maps a dataset
            version string to an ``{item_id: item}`` dict.
        n_versions: how many versions to aggregate over. It is coerced to a
            non-negative int because UI sliders may deliver floats.
        only_actual_versions: if true, use the ``n_versions`` most recent
            dataset versions overall; models with no item in them are dropped.
            Otherwise use each model's own ``n_versions`` most recent items.

    Returns:
        ``(df, retrieval_metrics, generation_metrics, category_metrics)``:
        the per-model DataFrame; the ``"<category>_retrieval"`` and
        ``"<category>_generation"`` column names present in it; and a list of
        ``(category, [present metric columns])`` pairs.
    """
    if not results or "items" not in results:
        return pd.DataFrame(), [], [], []

    all_items = results["items"]
    print(f"Last version: {results.get('last_version', '1.0')}")
    n_versions = max(int(n_versions), 0)

    model_groups = _group_items_by_model(all_items)
    # The globally most recent dataset versions are the same for every model: compute once.
    recent_versions = set(sorted((version.parse(v) for v in all_items), reverse=True)[:n_versions])

    rows = []
    for model_name, items in model_groups.items():
        items.sort(key=lambda it: it["version_obj"], reverse=True)  # newest first
        if only_actual_versions:
            selected = [it for it in items if it["version_obj"] in recent_versions]
        else:
            selected = items[:n_versions]
        if not selected:
            continue

        latest = selected[0]  # display config/timestamp come from the newest selected item
        config = latest.get("config") or {}
        row = {
            'Model': model_name,
            'Embeddings': config.get('embedding_model', 'N/A'),
            'Retriever': config.get('retriever_type', 'N/A'),
            'Top-K': (config.get('retrieval_config') or {}).get('top_k', 'N/A'),
            'Versions': ", ".join(it["version_str"] for it in selected),
            'Last Updated': _format_timestamp(latest.get("timestamp", "")),
        }
        row.update(_metric_columns(selected))
        rows.append(row)

    df = pd.DataFrame(rows)

    category_metrics = []
    for category in QUESTION_CATEGORIES:
        present = [f"{category}_{m}" for m in METRIC_TYPES if f"{category}_{m}" in df.columns]
        if present:
            category_metrics.append((category, present))

    retrieval_metrics = [f"{c}_retrieval" for c in QUESTION_CATEGORIES if f"{c}_retrieval" in df.columns]
    generation_metrics = [f"{c}_generation" for c in QUESTION_CATEGORIES if f"{c}_generation" in df.columns]
    return df, retrieval_metrics, generation_metrics, category_metrics
def create_radar_chart(df, selected_models, metrics, title):
    """Draw a polar (radar) plot comparing *selected_models* on *metrics*.

    Each metric column name is split at its first underscore. The leading
    part (e.g. "simple" from "simple_retrieval") becomes the angular axis
    label. Only the first five matching rows of *df* are drawn, one closed
    trace per model, coloured from ``line_colors`` / ``fill_colors``.

    If there is nothing to draw (no metrics, no selected models, or none of
    them present in *df*), a blank 500x400 figure carrying *title* is
    returned. *title* is not applied to a populated chart.
    """
    def blank_figure():
        placeholder = go.Figure()
        placeholder.update_layout(
            title=title,
            title_font_size=16,
            height=400,
            width=500,
            margin=dict(l=30, r=30, t=50, b=30),
        )
        return placeholder

    if not metrics or len(selected_models) == 0:
        return blank_figure()

    shown = df[df['Model'].isin(selected_models)]
    if shown.empty:
        return blank_figure()

    # Too many overlapping polygons become unreadable: cap at five models.
    shown = shown.head(5)

    axis_labels = [name.split('_', 1)[0] for name in metrics]

    figure = go.Figure()
    for slot, (_, record) in enumerate(shown.iterrows()):
        radii = [record[name] for name in metrics]
        figure.add_trace(
            go.Scatterpolar(
                name=record['Model'],
                # Repeat the first point so each polygon closes on itself.
                r=radii + radii[:1],
                theta=axis_labels + axis_labels[:1],
                showlegend=True,
                mode="lines",
                line=dict(width=2, color=line_colors[slot % len(line_colors)]),
                fill="toself",
                fillcolor=fill_colors[slot % len(fill_colors)],
            )
        )

    radial_axis = dict(
        visible=True,
        gridcolor="black",
        linecolor="rgba(0,0,0,0)",
        gridwidth=1,
        showticklabels=False,
        ticks="",
        range=[0, 1],  # scores are plotted on a fixed 0..1 scale
    )
    angular_axis = dict(gridcolor="black", gridwidth=1.5, linecolor="rgba(0,0,0,0)")
    legend = dict(
        orientation="h",
        yanchor="bottom",
        y=-0.35,
        xanchor="center",
        x=0.4,
        itemwidth=30,
        font=dict(size=13),
        entrywidth=0.6,
        entrywidthmode="fraction",
    )
    figure.update_layout(
        font=dict(size=13, color="black"),
        template="plotly_white",
        polar=dict(radialaxis=radial_axis, angularaxis=angular_axis),
        legend=legend,
        margin=dict(l=0, r=16, t=30, b=30),
        autosize=True,
    )
    return figure
def create_summary_df(df, retrieval_metrics, generation_metrics):
    """Build the overall leaderboard table shown in the "Общая таблица" tab.

    Args:
        df: per-model DataFrame with 'Model', 'Embeddings', 'Retriever' and
            'Top-K' columns, plus the metric columns named below.
        retrieval_metrics: retrieval score columns to average per row.
        generation_metrics: generation score columns to average per row.

    Returns:
        A new DataFrame (the input is not modified) with 'Retrieval (avg)' and
        'Generation (avg)' (each only if its metric list is non-empty). It has
        'Total Score' when both averages exist. Rows are sorted descending by
        'Total Score', or by the single available average, with ties kept in
        input order. Columns are restricted to display columns. An empty
        DataFrame is returned for empty input.
    """
    if df.empty:
        return pd.DataFrame()

    summary_df = df.copy()
    if retrieval_metrics:
        summary_df['Retrieval (avg)'] = summary_df[retrieval_metrics].mean(axis=1).round(4)
    if generation_metrics:
        summary_df['Generation (avg)'] = summary_df[generation_metrics].mean(axis=1).round(4)

    has_retrieval = 'Retrieval (avg)' in summary_df.columns
    has_generation = 'Generation (avg)' in summary_df.columns
    if has_retrieval and has_generation:
        # Round the sum too, otherwise float noise (0.1 + 0.2 -> 0.30000000000000004) is displayed.
        summary_df['Total Score'] = (summary_df['Retrieval (avg)'] + summary_df['Generation (avg)']).round(4)
        sort_col = 'Total Score'
    elif has_retrieval:
        sort_col = 'Retrieval (avg)'
    elif has_generation:
        sort_col = 'Generation (avg)'
    else:
        sort_col = None
    if sort_col is not None:
        # Stable sort: tied models keep their incoming order on every call.
        summary_df = summary_df.sort_values(sort_col, ascending=False, kind="mergesort")

    base_cols = ['Model', 'Embeddings', 'Retriever', 'Top-K']
    optional_cols = ['Retrieval (avg)', 'Generation (avg)', 'Total Score', 'Versions', 'Last Updated']
    summary_cols = base_cols + [col for col in optional_cols if col in summary_df.columns]
    return summary_df[summary_cols]
def create_category_df(df, category, retrieval_col, generation_col):
    """Build the per-question-type leaderboard table for one *category*.

    Args:
        df: per-model DataFrame containing 'Model', 'Embeddings', 'Retriever',
            *retrieval_col* and *generation_col*.
        category: category code used to name the score column ("<category> Score").
        retrieval_col: name of the retrieval score column for this category.
        generation_col: name of the generation score column for this category.

    Returns:
        A new DataFrame with columns Model, Embeddings, Retriever, Retrieval,
        Generation and "<category> Score" (retrieval + generation, rounded to
        4 places). Rows are sorted descending by that score, with ties kept in
        input order. An empty DataFrame is returned if *df* is empty or either
        score column is missing.
    """
    if df.empty or retrieval_col not in df.columns or generation_col not in df.columns:
        return pd.DataFrame()

    score_col = f'{category} Score'
    category_df = df.copy()
    # Round the sum so float noise (0.1 + 0.2 -> 0.30000000000000004) is not displayed.
    category_df[score_col] = (category_df[retrieval_col] + category_df[generation_col]).round(4)
    # Stable sort: repeated calls on the same data always produce the same row order.
    category_df = category_df.sort_values(score_col, ascending=False, kind="mergesort")

    display_cols = ['Model', 'Embeddings', 'Retriever', retrieval_col, generation_col, score_col]
    return category_df[display_cols].rename(columns={
        retrieval_col: 'Retrieval',
        generation_col: 'Generation',
    })
# Load the benchmark results once at import time; everything below derives from them.
results = load_results()

# Values shown in the version banner, with fallbacks when results.json omits them.
last_version, n_questions, date_title = (
    results.get("last_version", "1.0"),
    results.get("n_questions", "100"),
    results.get("date_title", "---"),
)

# Initial view: aggregate over the single most recent dataset version only.
df, retrieval_metrics, generation_metrics, category_metrics = filter_and_process_results(
    results, n_versions=1, only_actual_versions=True
)
summary_df = create_summary_df(df, retrieval_metrics, generation_metrics)

# The first (up to) five models in df are pre-selected for the radar charts.
default_models = [] if df.empty else df['Model'].head(5).tolist()
initial_gen_chart, initial_ret_chart = (
    create_radar_chart(df, default_models, chart_metrics, chart_title)
    for chart_metrics, chart_title in (
        (generation_metrics, "Performance on Generation Tasks"),
        (retrieval_metrics, "Performance on Retrieval Tasks"),
    )
)
with gr.Blocks(css="""
.title-container {
    text-align: center;
    margin-bottom: 10px;
}
.description-text {
    text-align: left;
    padding: 10px;
    margin-bottom: 0px;
}
.version-info {
    text-align: center;
    padding: 10px;
    background-color: #f0f0f0;
    border-radius: 8px;
    margin-bottom: 15px;
}
.version-selector {
    padding: 15px;
    border: 1px solid #ddd;
    border-radius: 8px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
    height: 100%;
}
.citation-block {
    padding: 15px;
    border: 1px solid #ddd;
    border-radius: 8px;
    margin-bottom: 20px;
    background-color: #f9f9f9;
    font-family: monospace;
    font-size: 14px;
    overflow-x: auto;
    height: 100%;
}
.flex-row-container {
    display: flex;
    justify-content: space-between;
    gap: 20px;
    width: 100%;
}
.charts-container {
    display: flex;
    gap: 20px;
    margin-bottom: 20px;
}
.chart-box {
    flex: 1;
    border: 1px solid #eee;
    border-radius: 8px;
    padding: 10px;
    background-color: white;
    min-height: 550px; /* Increased height to accommodate legend at bottom */
}
.metrics-table {
    border: 1px solid #eee;
    border-radius: 8px;
    padding: 15px;
    background-color: white;
}
.info-text {
    font-size: 0.9em;
    font-style: italic;
    color: #666;
    margin-top: 5px;
}
footer {
    text-align: center;
    margin-top: 30px;
    font-size: 0.9em;
    color: #666;
}
/* Style for selected rows */
table tbody tr.selected {
    background-color: rgba(25, 118, 210, 0.1) !important;
    border-left: 3px solid #1976d2;
}
/* Add this class via JavaScript */
.gr-table tbody tr.selected td:first-child {
    font-weight: bold;
    color: #1976d2;
}
.category-tab {
    padding: 10px;
}
.chart-title {
    font-size: 1.2em;
    font-weight: bold;
    margin-bottom: 10px;
    text-align: center;
}
.clear-charts-button {
    display: flex;
    justify-content: center;
    margin-top: 10px;
    margin-bottom: 20px;
}
""") as demo:
    # Title
    with gr.Row(elem_classes=["title-container"]):
        gr.Markdown("# 🐙 Dynamic RAG Benchmark")

    # Description (user-facing text is in Russian)
    with gr.Row(elem_classes=["description-text"]):
        gr.Markdown(
            "На этом лидерборде можно сравнить RAG системы в разрезе генеративных и поисковых метрик "
            "моделей по вопросам разного типа (простые вопросы, сравнения, multi-hop, условные и др.). "
            "<ul>"
            "<li>Вопросы автоматически генерируются на основе новостных источников.</li>"
            "<li>Обновление датасета с вопросами происходит регулярно, при этом пересчитываются все "
            "метрики для открытых моделей.</li>"
            "<li>Для пользовательских сабмитов учитываются последние посчитанные для них метрики.</li>"
            "<li>Чтобы посчитать ранее отправленную конфигурацию на последней версии данных, используйте "
            "submit_id, полученный при первой отправке через клиент (см. инструкцию ниже).</li>"
            "</ul>"
        )

    # Version banner
    with gr.Row(elem_classes=["version-info"]):
        gr.Markdown(f"## Версия {last_version} → {n_questions} вопросов, сгенерированных по новостным источникам → {date_title}")

    # Radar charts
    with gr.Row(elem_classes=["charts-container"]):
        with gr.Column(elem_classes=["chart-box"]):
            gr.Markdown("### Генеративные метрики", elem_classes=["chart-title"])
            generation_chart = gr.Plot(value=initial_gen_chart)
        with gr.Column(elem_classes=["chart-box"]):
            gr.Markdown("### Метрики поиска", elem_classes=["chart-title"])
            retrieval_chart = gr.Plot(value=initial_ret_chart)

    # Clear-charts button
    with gr.Row(elem_classes=["clear-charts-button"]):
        clear_charts_btn = gr.Button("Очистить графики", variant="secondary")

    # Metrics tables
    with gr.Tabs(elem_classes=["metrics-table"]) as metrics_tabs:
        with gr.TabItem("Общая таблица"):
            # Per-session list of models currently drawn on the radar charts.
            selected_models = gr.State(default_models)
            # Per-session (n_versions, only_actual_versions) that produced the tables on screen.
            # Starts at (1, True), the same filter used for the initial module-level data.
            applied_filter = gr.State((1, True))
            if df.empty:
                gr.Markdown("No data available. Please submit some results.")
                metrics_table = gr.DataFrame()
            else:
                metrics_table = gr.DataFrame(
                    value=summary_df,
                    headers=summary_df.columns.tolist(),
                    datatype=["str"] * len(summary_df.columns),
                    row_count=(min(10, len(summary_df)) if not summary_df.empty else 0),
                    col_count=(len(summary_df.columns) if not summary_df.empty else 0),
                    interactive=False,
                    wrap=True
                )

        with gr.TabItem("По типам вопросов"):
            # One sub-tab per question category that has both retrieval and generation scores.
            category_tabs = gr.Tabs()
            category_tables = {}
            category_display_names = {
                "simple": "Simple Questions",
                "set": "Set-based",
                "mh": "Multi-hop",
                "cond": "Conditional",
                "comp": "Comparison"
            }
            with category_tabs:
                for category, _ in category_metrics:
                    if f"{category}_retrieval" in df.columns and f"{category}_generation" in df.columns:
                        with gr.TabItem(category_display_names.get(category, category.capitalize()), elem_classes=["category-tab"]):
                            category_df = create_category_df(df, category, f"{category}_retrieval", f"{category}_generation")
                            if category_df.empty:
                                gr.Markdown(f"No data available for {category_display_names.get(category, category)} category.")
                                category_tables[category] = gr.DataFrame()
                            else:
                                gr.Markdown(f"#### Performance on {category_display_names.get(category, category)}")
                                category_tables[category] = gr.DataFrame(
                                    value=category_df,
                                    headers=category_df.columns.tolist(),
                                    datatype=["str"] * len(category_df.columns),
                                    row_count=(min(10, len(category_df)) if not category_df.empty else 0),
                                    col_count=(len(category_df.columns) if not category_df.empty else 0),
                                    interactive=False,
                                    wrap=True
                                )

    # Citation block and version selector side by side
    with gr.Row():
        with gr.Column(scale=1, elem_classes=["citation-block"]):
            gr.Markdown("### Цитирование")
            gr.Markdown("""
```
@article{dynamic-rag-benchmark,
title={Dynamic RAG Benchmark},
author={RAG Benchmark Team},
journal={arXiv preprint},
year={2024},
url={https://github.com/rag-benchmark}
}
```
Шаблон для цитирования нашего бенча.
""")
        with gr.Column(scale=1, elem_classes=["version-selector"]):
            gr.Markdown("### Выбор версий")
            with gr.Column():
                with gr.Row():
                    with gr.Column(scale=3):
                        only_actual_versions = gr.Checkbox(
                            label="Только актуальные версии",
                            value=True,
                            info="Считать, начиная с актуальной версии датасета"
                        )
                    with gr.Column(scale=5):
                        n_versions_slider = gr.Slider(
                            minimum=1,
                            maximum=5,
                            value=1,
                            step=1,
                            label="Взять n последних версий",
                            info="Количество версий для подсчета метрик"
                        )
                with gr.Row():
                    filter_btn = gr.Button("Применить фильтр", variant="primary")
                gr.Markdown(
                    "Кликайте на модели в таблице, чтобы добавить их в графики",
                    elem_classes=["info-text"]
                )

    # Footer
    with gr.Row():
        gr.Markdown("""
<footer>Dynamic RAG Benchmark Leaderboard</footer>
""")

    def _model_from_table(table_value, row_idx):
        """Return the model name in row *row_idx* of a table value, or None if not found.

        Handles a pandas DataFrame (looked up by its 'Model' column), a list of
        rows, or a dict with a 'data' list of rows. For the list/dict shapes the
        model is taken from the first column; every table built above starts
        with 'Model'.
        """
        if not isinstance(row_idx, int) or row_idx < 0:
            return None
        if isinstance(table_value, pd.DataFrame):
            if row_idx >= len(table_value) or 'Model' not in table_value.columns:
                return None
            return table_value.iloc[row_idx]['Model']
        if isinstance(table_value, dict):
            table_value = table_value.get('data')
        if isinstance(table_value, list) and row_idx < len(table_value) and table_value[row_idx]:
            return table_value[row_idx][0]
        return None

    def update_charts(evt: gr.SelectData, selected_models, table_value=None, applied_filter=None):
        """Toggle the model in the clicked table row on/off the radar charts.

        Args:
            evt: Gradio select event. ``evt.index`` is the clicked cell, with
                the row first when it is a list.
            selected_models: models currently drawn (the ``selected_models`` state).
            table_value: current contents of the clicked table, passed as an
                event input so rows match what the user sees.
            applied_filter: ``(n_versions, only_actual_versions)`` that produced
                the tables on screen. Defaults to ``(1, True)``.

        Returns:
            ``(new_selection, generation_figure, retrieval_figure)``. On an
            unexpected error the selection is returned unchanged and both charts
            are left as they are.
        """
        try:
            n_versions, only_actual = applied_filter if applied_filter else (1, True)
            # Rebuild data for the filter that is actually applied (not the widgets' initial values).
            current_df, current_ret_metrics, current_gen_metrics, _ = filter_and_process_results(
                results, n_versions=n_versions, only_actual_versions=only_actual
            )
            print(f"Selection event: {evt}, type: {type(evt)}")
            row_idx = evt.index[0] if isinstance(evt.index, (list, tuple)) else evt.index
            print(f"Row index: {row_idx}")
            clicked_model = _model_from_table(table_value, row_idx)

            available_models = current_df['Model'].tolist() if not current_df.empty else []
            # Work on a copy so the incoming state object is never mutated in place.
            chosen = list(selected_models or [])
            if clicked_model:
                print(f"Selected model: {clicked_model}")
                if clicked_model not in available_models:
                    print(f"Model {clicked_model} not found in current dataframe")
                elif clicked_model in chosen:
                    chosen.remove(clicked_model)
                else:
                    chosen.append(clicked_model)

            # Keep only models present under the current filter; never leave the charts empty.
            chosen = [model for model in chosen if model in available_models]
            if not chosen and available_models:
                chosen = [available_models[0]]

            gen_chart = create_radar_chart(current_df, chosen, current_gen_metrics, "Performance on Generation Tasks")
            ret_chart = create_radar_chart(current_df, chosen, current_ret_metrics, "Performance on Retrieval Tasks")
            return chosen, gen_chart, ret_chart
        except Exception as e:
            print(f"Error in update_charts: {e}")
            print(traceback.format_exc())
            return selected_models, gr.update(), gr.update()

    # Clicking a row in any leaderboard table toggles that model on the charts.
    for clickable_table in [metrics_table, *category_tables.values()]:
        clickable_table.select(
            fn=update_charts,
            inputs=[selected_models, clickable_table, applied_filter],
            outputs=[selected_models, generation_chart, retrieval_chart],
        )

    def update_data(n_versions, only_actual, current_selected_models):
        """Recompute every table and both charts for a new version filter.

        Returns a list shaped like ``filter_outputs``: the summary table, the
        generation and retrieval charts, the new model selection, then one
        table per entry of ``category_tables`` (in its iteration order). On an
        unexpected error every table and chart is left unchanged and the
        selection is echoed back.
        """
        global df, retrieval_metrics, generation_metrics
        try:
            n_versions = max(int(n_versions), 0)  # slider values may arrive as floats
            new_df, new_ret_metrics, new_gen_metrics, _ = filter_and_process_results(
                results, n_versions=n_versions, only_actual_versions=only_actual
            )

            available_models = new_df['Model'].tolist() if not new_df.empty else []
            kept_models = [model for model in (current_selected_models or []) if model in available_models]
            if not kept_models and available_models:
                kept_models = available_models[:5]

            gen_chart = create_radar_chart(new_df, kept_models, new_gen_metrics, "Performance on Generation Tasks")
            ret_chart = create_radar_chart(new_df, kept_models, new_ret_metrics, "Performance on Retrieval Tasks")
            new_summary_df = create_summary_df(new_df, new_ret_metrics, new_gen_metrics)

            outputs = [new_summary_df, gen_chart, ret_chart, kept_models]
            # create_category_df returns an empty DataFrame when a category's columns are missing.
            for category in category_tables:
                outputs.append(
                    create_category_df(new_df, category, f"{category}_retrieval", f"{category}_generation")
                )

            df, retrieval_metrics, generation_metrics = new_df, new_ret_metrics, new_gen_metrics
            return outputs
        except Exception as e:
            print(f"Error in update_data: {e}")
            print(traceback.format_exc())
            return [gr.update(), gr.update(), gr.update(), current_selected_models] + [
                gr.update() for _ in category_tables
            ]

    def remember_filter(n_versions, only_actual):
        """Return the just-applied filter, normalised, for the ``applied_filter`` state."""
        return (max(int(n_versions), 0), bool(only_actual))

    filter_outputs = [metrics_table, generation_chart, retrieval_chart, selected_models, *category_tables.values()]
    filter_btn.click(
        fn=update_data,
        inputs=[n_versions_slider, only_actual_versions, selected_models],
        outputs=filter_outputs,
    ).then(
        fn=remember_filter,
        inputs=[n_versions_slider, only_actual_versions],
        outputs=[applied_filter],
    )

    def clear_charts():
        """Deselect every model and return blank generation/retrieval charts."""
        empty_models = []
        empty_gen_chart = create_radar_chart(df, empty_models, generation_metrics, "Performance on Generation Tasks")
        empty_ret_chart = create_radar_chart(df, empty_models, retrieval_metrics, "Performance on Retrieval Tasks")
        return empty_models, empty_gen_chart, empty_ret_chart

    clear_charts_btn.click(
        fn=clear_charts,
        inputs=[],
        outputs=[selected_models, generation_chart, retrieval_chart]
    )


if __name__ == "__main__":
    demo.launch()