Spaces:
Sleeping
Sleeping
from pathlib import Path
from typing import Any, Dict

import gradio as gr
import pandas as pd
import plotly.express as px

from config import METADATA_COLUMNS, DATA_FOLDER
from data_loader import load_csv_from_folder, get_available_datasets
# In-memory store shared by all UI callbacks:
#   DB["data"]      -> aggregate metrics DataFrame (set by load_data)
#   DB["responses"] -> dict of domain display name -> per-question DataFrame
# The values are not all DataFrames, so the annotation uses Any.
DB: Dict[str, Any] = {}

# --- 1. DATA PROCESSING FUNCTIONS ---
def analyze_domain_configs(df_subset):
    """Split a domain's configuration columns into constants and variables.

    Columns listed in ``METADATA_COLUMNS`` and any column whose name contains
    "failed" (case-insensitive) are skipped. Every other column is compared by
    its string form.

    Returns:
        ``(constants, variables)``. ``constants`` maps each column with at most
        one distinct value to that value as a string, or to "N/A" when the
        subset has no rows. ``variables`` lists the columns with more than one
        distinct value, in frame order.
    """
    constants = {}
    variables = []
    for column in df_subset.columns:
        if column in METADATA_COLUMNS or "failed" in column.lower():
            continue
        distinct = df_subset[column].astype(str).unique()
        if len(distinct) > 1:
            variables.append(column)
        elif len(distinct) == 1:
            constants[column] = distinct[0]
        else:
            constants[column] = "N/A"
    return constants, variables
def load_data() -> str:
    """Load aggregate metrics and per-question responses into ``DB``.

    Reads the metrics CSVs from ``DATA_FOLDER`` via ``load_csv_from_folder``
    and the response CSVs via ``load_response_data``.

    Returns:
        A status message for the UI status box. Any exception is caught and
        reported in the message rather than propagated, so a bad file does
        not crash the Gradio callback.
    """
    try:
        df, status_msg = load_csv_from_folder(DATA_FOLDER)
        if df.empty:
            # Forget previously loaded data so a refresh that finds nothing
            # does not keep showing stale results.
            DB.pop("data", None)
        else:
            # The failed-sample counter is neither a config nor a metric column.
            DB["data"] = df.drop(columns=["failed_samples"], errors="ignore")

        responses = load_response_data()
        DB["responses"] = responses
        response_count = sum(len(frame) for frame in responses.values())
        return (
            f"{status_msg}\n"
            f"Loaded {len(responses)} response datasets with {response_count} total responses."
        )
    except Exception as e:  # UI boundary: report the error instead of raising
        return f"Error loading data: {e}"
def load_response_data(responses_folder="./responses") -> Dict[str, pd.DataFrame]:
    """Load the per-question response CSVs, keyed by domain display name.

    Args:
        responses_folder: Directory (``str`` or ``Path``) that holds the
            ``*_checkpoint_100.csv`` files. Defaults to ``./responses``,
            resolved against the current working directory.

    Returns:
        Dict mapping a display name (e.g. "Finance (FinQA)") to that file's
        DataFrame. Files that do not exist are skipped. Any ``trace_*`` metric
        column that is present is coerced to numeric, and unparsable or missing
        values become 0.0.
    """
    folder = Path(responses_folder)
    domain_mapping = {
        "Biomedical_pubmedqa_checkpoint_100.csv": "Biomedical (PubMedQA)",
        "Customer_Support_techqa_checkpoint_100.csv": "Customer Support (TechQA)",
        "Finance_finqa_checkpoint_100.csv": "Finance (FinQA)",
        "General_msmarco_checkpoint_100.csv": "General (MS MARCO)",
        "Legal_cuad_checkpoint_100.csv": "Legal (CUAD)",
    }
    metric_columns = ("trace_relevance", "trace_utilization", "trace_completeness", "trace_adherence")

    response_db = {}
    for filename, domain_name in domain_mapping.items():
        filepath = folder / filename
        if not filepath.exists():
            continue
        df = pd.read_csv(filepath)
        for col in metric_columns:
            if col in df.columns:
                df[col] = pd.to_numeric(df[col], errors="coerce").fillna(0.0)
        response_db[domain_name] = df
    return response_db
def get_questions_for_domain(domain):
    """Return a Gradio update listing the questions of ``domain`` (selection reset to None).

    The choice list is empty when responses are not loaded, the domain is
    unknown, or its CSV has no ``question`` column. Missing (NaN) questions
    are skipped: ``get_response_details`` matches rows with ``==``, so a NaN
    choice could never be resolved to a row.
    """
    frame = DB.get("responses", {}).get(domain)
    if frame is None or "question" not in frame.columns:
        return gr.update(choices=[], value=None)
    questions = frame["question"].dropna().unique().tolist()
    return gr.update(choices=questions, value=None)
def get_response_details(domain, question):
    """Return ``(llm_answer, gold_answer, metrics_fig)`` for one question.

    Uses the first row of ``DB["responses"][domain]`` whose ``question``
    equals ``question``. Returns ``("", "", None)`` when any of these hold:

    - responses are not loaded;
    - the domain is unknown;
    - the ``question`` column is missing;
    - no row matches (e.g. ``question`` is None right after the dropdown resets).

    Answer columns that are missing, or cells that are empty, are shown as
    "N/A". Missing metric columns are plotted as 0.0.
    """
    frame = DB.get("responses", {}).get(domain)
    if frame is None or "question" not in frame.columns:
        return "", "", None
    matches = frame[frame["question"] == question]
    if matches.empty:
        return "", "", None
    row = matches.iloc[0]

    def as_text(column):
        value = row.get(column)
        # An empty CSV cell is NaN and would otherwise render as the text "nan".
        return "N/A" if value is None or pd.isna(value) else str(value)

    metric_columns = {
        "Relevance": "trace_relevance",
        "Utilization": "trace_utilization",
        "Completeness": "trace_completeness",
        "Adherence": "trace_adherence",
    }
    metrics_df = pd.DataFrame({
        "Metric": list(metric_columns),
        "Score": [row.get(col, 0.0) for col in metric_columns.values()],
    })

    fig = px.bar(
        metrics_df,
        x="Metric",
        y="Score",
        title="Quality Metrics for Selected Response",
        text_auto=".3f",
        color="Metric",
        range_y=[0, 1],
    )
    # With a fixed [0, 1] range, "outside" labels on bars near 1.0 would be
    # clipped by the plot area; cliponaxis=False keeps them visible.
    fig.update_traces(textposition="outside", cliponaxis=False)
    return as_text("answer"), as_text("gold_answer"), fig
# --- 2. UI LOGIC ---
def get_dataset_choices():
    """Return the dataset names offered by the domain dropdown.

    Yields an empty list when nothing is loaded or the loaded frame is empty.
    If anything raises, the error is printed and an empty list is returned.
    """
    try:
        if "data" not in DB:
            return []
        loaded = DB["data"]
        return [] if loaded.empty else get_available_datasets(loaded)
    except Exception as e:
        print(f"Error getting dataset choices: {e}")
        return []
# Dataset names in the order of the five preview tables in the "Data Preview" tab.
_PREVIEW_DOMAIN_ORDER = ("pubmedqa", "techqa", "finqa", "msmarco", "cuad")


def _order_preview_columns(domain_df):
    """Reorder columns as Metadata -> Constants -> Variables -> Results -> anything else."""
    metadata_cols = ["test_id", "config_purpose", "dataset_name"]
    result_cols = ["rmse_relevance", "rmse_utilization", "rmse_completeness", "f1_score", "aucroc"]
    present = set(domain_df.columns)
    consts, variables = analyze_domain_configs(domain_df)

    # Metadata and result columns have fixed slots. Keep them out of the
    # constant/variable groups so no column is selected twice (which would
    # create duplicate labels if METADATA_COLUMNS does not list them).
    fixed = set(metadata_cols) | set(result_cols)
    ordered = [c for c in metadata_cols if c in present]
    ordered += sorted(c for c in consts if c in present and c not in fixed)
    ordered += sorted(c for c in variables if c in present and c not in fixed)
    ordered += [c for c in result_cols if c in present]
    seen = set(ordered)
    ordered += [c for c in domain_df.columns if c not in seen]
    return domain_df[ordered]


def get_data_preview():
    """Return one DataFrame per preview table, in the UI order PubMedQA, TechQA, FinQA, MS MARCO, CUAD.

    Every column whose name contains "failed" (case-insensitive) is removed,
    matching how analyze_domain_configs and the inter-domain table treat
    them. A domain with no rows, or every domain when no usable data is
    loaded, yields an empty DataFrame. Empty DataFrames (rather than dicts)
    keep the gr.Dataframe outputs consistent.
    """
    empty = tuple(pd.DataFrame() for _ in _PREVIEW_DOMAIN_ORDER)
    if "data" not in DB or "dataset_name" not in DB["data"].columns:
        return empty

    df = DB["data"]
    df = df.drop(columns=[c for c in df.columns if "failed" in str(c).lower()])

    previews = []
    for ds in _PREVIEW_DOMAIN_ORDER:
        domain_df = df[df["dataset_name"] == ds]
        previews.append(pd.DataFrame() if domain_df.empty else _order_preview_columns(domain_df))
    return tuple(previews)
def get_domain_state(dataset):
    """Build the constants summary and the five filter-dropdown updates for a domain.

    Returns a 6-tuple:

    - the text for ``constants_box``;
    - one ``gr.update`` per filter column (reranker_model, chunking_strategy,
      summarization, repacking, gpt_label).

    A filter that exists in the data is shown, offering "All" followed by the
    column's distinct string values, with "All" selected. Otherwise it is
    hidden and cleared.
    """
    hidden = gr.update(visible=False, value=None, choices=[])
    if "data" not in DB:
        return ("",) + (hidden,) * 5

    frame = DB["data"]
    rows = frame[frame["dataset_name"] == dataset]
    if rows.empty:
        return ("No data for this domain.",) + (hidden,) * 5

    constants, _ = analyze_domain_configs(rows)
    summary_lines = [f"{name}: {value}" for name, value in constants.items()]
    summary = "CONSTANTS (Fixed for this domain):\n" + "\n".join(summary_lines)

    # Same columns, in the same order, that plot_metrics_on_x_axis filters on.
    filter_columns = ("reranker_model", "chunking_strategy", "summarization", "repacking", "gpt_label")
    filter_updates = []
    for column in filter_columns:
        if column not in rows.columns:
            filter_updates.append(hidden)
            continue
        choices = ["All", *rows[column].astype(str).unique()]
        filter_updates.append(gr.update(
            label=f"Filter by {column}",
            choices=choices,
            value="All",
            visible=True,
            interactive=True,
        ))
    return (summary, *filter_updates)
# Metric column -> display label for the two intra-domain charts.
_RMSE_LABELS = {
    "rmse_relevance": "Relevance",
    "rmse_utilization": "Utilization",
    "rmse_completeness": "Completeness",
}
_PERF_LABELS = {"f1_score": "F1 Score", "aucroc": "AUC-ROC"}


def _grouped_metric_bar(subset, labels, title):
    """Grouped bar chart (one colour per test) of the ``labels`` columns present in ``subset``; None if none are."""
    value_cols = [col for col in labels if col in subset.columns]
    if not value_cols:
        return None
    melted = subset.melt(
        id_vars=["Legend", "test_id"],
        value_vars=value_cols,
        var_name="Metric Name",
        value_name="Score",
    )
    # Scores may be non-numeric after CSV loading; unparsable/missing values are plotted as 0.0.
    melted["Score"] = pd.to_numeric(melted["Score"], errors="coerce").fillna(0.0).astype(float)
    melted["Metric Name"] = melted["Metric Name"].map(labels)
    fig = px.bar(
        melted,
        x="Metric Name",
        y="Score",
        color="Legend",
        barmode="group",
        title=title,
        text_auto=".3f",
    )
    fig.update_traces(textposition="outside")
    fig.update_layout(legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1))
    return fig


def plot_metrics_on_x_axis(dataset, f1_val, f2_val, f3_val, f4_val, f5_val):
    """Return ``(rmse_fig, performance_fig)`` for one domain after applying the filters.

    Args:
        dataset: ``dataset_name`` value chosen in the domain dropdown.
        f1_val..f5_val: Selected values of the five filter dropdowns, aligned
            with the filter columns below. "All" or None means "no filter".

    Either figure is None when its metric columns are absent. Both are None
    in these cases:

    - no data is loaded or no domain is selected;
    - the data cannot be accessed;
    - the filters leave no rows;
    - ``test_id``/``config_purpose`` are missing.
    """
    if "data" not in DB or not dataset:
        return None, None
    try:
        df = DB["data"]
        subset = df[df["dataset_name"] == dataset]
    except Exception as e:
        print(f"Error accessing data: {e}")
        return None, None

    # Same columns, in the same order, that get_domain_state uses for filter_1..filter_5.
    filter_columns = ("reranker_model", "chunking_strategy", "summarization", "repacking", "gpt_label")
    for col, val in zip(filter_columns, (f1_val, f2_val, f3_val, f4_val, f5_val)):
        if val is None or val == "All" or col not in subset.columns:
            continue
        # Compare as strings: the filter choices were built with astype(str).
        subset = subset[subset[col].astype(str) == str(val)]

    if subset.empty:
        return None, None
    missing = [c for c in ("test_id", "config_purpose") if c not in subset.columns]
    if missing:
        print(f"Error accessing data: missing columns {missing}")
        return None, None

    subset = subset.reset_index(drop=True)
    subset = subset.assign(
        Legend="Test " + subset["test_id"].astype(str) + ": " + subset["config_purpose"].astype(str)
    )
    n_tests = len(subset)
    fig_rmse = _grouped_metric_bar(subset, _RMSE_LABELS, f"RMSE Breakdown (Lower is Better) - {n_tests} Tests")
    fig_perf = _grouped_metric_bar(subset, _PERF_LABELS, f"Performance Metrics (Higher is Better) - {n_tests} Tests")
    return fig_rmse, fig_perf
# Metric column -> label used in the peak-performance chart.
_METRIC_DISPLAY_NAMES = {
    "rmse_relevance": "RMSE Relevance",
    "rmse_utilization": "RMSE Utilization",
    "rmse_completeness": "RMSE Completeness",
    "f1_score": "F1 Score",
    "aucroc": "AUC-ROC",
}


def _constants_table(df, datasets):
    """One row per config parameter; each domain column holds its constant value or "Variable"."""
    domain_constants = {
        ds: analyze_domain_configs(df[df["dataset_name"] == ds])[0] for ds in datasets
    }
    all_keys = set().union(*domain_constants.values())
    # analyze_domain_configs already skips "failed" columns; filter again defensively.
    keys = sorted(k for k in all_keys if "failed" not in k.lower())
    rows = []
    for key in keys:
        row = {"Configuration Parameter": key}
        for ds in datasets:
            row[ds] = domain_constants[ds].get(key, "Variable")
        rows.append(row)
    return pd.DataFrame(rows)


def _best_per_domain(df, datasets, metric, display_name, lower_is_better):
    """One record per domain with its best numeric ``metric`` value and the matching config_purpose."""
    records = []
    if metric not in df.columns:
        return records
    for ds in datasets:
        # A fresh positional index keeps idxmin/idxmax labels unique even if
        # the loaded frame has duplicate index labels.
        subset = df[df["dataset_name"] == ds].reset_index(drop=True)
        # Compare numerically; text values would otherwise be ranked lexicographically.
        scores = pd.to_numeric(subset[metric], errors="coerce").dropna()
        if scores.empty:
            continue  # all-NaN column: there is no best row
        best_idx = scores.idxmin() if lower_is_better else scores.idxmax()
        best_config = (
            subset.at[best_idx, "config_purpose"] if "config_purpose" in subset.columns else "N/A"
        )
        records.append({"Domain": ds, display_name: scores.at[best_idx], "Best Config": best_config})
    return records


def generate_inter_domain_comparison(metric='f1_score'):
    """Compare configurations and peak scores across all domains.

    Args:
        metric: Metric column to rank by (normally a value of the "Comparison
            Metric" dropdown). Names starting with "rmse" are lower-is-better;
            all others are higher-is-better. A falsy value, such as the None
            a cleared dropdown sends, falls back to "f1_score".

    Returns:
        ``(comp_df, fig)``.

        ``comp_df`` has one row per configuration parameter and one column per
        domain. Each cell is that domain's constant value, or "Variable" when
        the parameter varies within the domain.

        ``fig`` is a bar chart of each domain's best ``metric`` value, with the
        winning ``config_purpose`` on hover. It is None when no domain has a
        numeric value.

        ``(empty DataFrame, None)`` is returned when no usable data is loaded.
    """
    if "data" not in DB:
        return pd.DataFrame(), None
    df = DB["data"]
    if "dataset_name" not in df.columns:
        print("Error accessing data: 'dataset_name' column is missing")
        return pd.DataFrame(), None

    metric = metric or "f1_score"
    datasets = df["dataset_name"].unique()
    comp_df = _constants_table(df, datasets)

    metric_display = _METRIC_DISPLAY_NAMES.get(metric, metric)
    is_rmse = metric.startswith("rmse")
    direction = "Lower is Better" if is_rmse else "Higher is Better"

    best_results = _best_per_domain(df, datasets, metric, metric_display, is_rmse)
    if not best_results:
        return comp_df, None

    fig_global = px.bar(
        pd.DataFrame(best_results),
        x="Domain",
        y=metric_display,
        color="Domain",
        text_auto=".4f",
        hover_data=["Best Config"],
        title=f"Peak Performance per Domain: {metric_display} ({direction})",
    )
    fig_global.update_traces(textposition="outside")
    return comp_df, fig_global
# --- 3. UI ---
# Version string shown in the page header.
APP_VERSION = "v2.2.0"

# Markdown listing the settings shared by every experiment; rendered inside
# the "Global Experiment Configuration" accordion.
GLOBAL_CONSTANTS = "\n".join([
    "",
    "**Global Constants (Applied to All Domains):**",
    "- Generator Model: **llama-3.1-8b-instant**",
    "- Generator Max Tokens: **512**",
    "- Generator Temperature: **0.2**",
    "- Generator API Provider: **Groq**",
    "- Generation LLM Context Budget: **2000**",
    "- Judge Model: **llama-3.3-70b-versatile**",
    "- Judge Max Tokens: **1024**",
    "- Judge Temperature: **0.0**",
    "- Judge Sentence Attribution: **ENABLED**",
    "- Summarization Model: **fangyuan/nq_abstractive_compressor**",
    "",
])
with gr.Blocks(title="RAG Analytics Pro") as demo:
    gr.Markdown("## RAG Pipeline Analytics")
    gr.Markdown(f"**Data Source:** `{DATA_FOLDER}` | **Version:** {APP_VERSION}")
    with gr.Accordion("Global Experiment Configuration", open=False):
        gr.Markdown(GLOBAL_CONSTANTS)
    with gr.Row():
        refresh_data_btn = gr.Button("Load/Refresh Data", variant="primary")
        status = gr.Textbox(label="Status (Check here for debug info)", interactive=False, scale=3)

    with gr.Tabs():
        # TAB 1: Main Analytics
        with gr.TabItem("Intra-Domain Analysis"):
            with gr.Row():
                with gr.Column(scale=1):
                    ds_dropdown = gr.Dropdown(label="1. Select Domain", choices=[], interactive=True)
                    constants_box = gr.Textbox(label="Domain Constants", lines=5, interactive=False)
                    gr.Markdown("### Filter Tests")
                    # Hidden until get_domain_state configures them for the selected domain.
                    filter_1 = gr.Dropdown(visible=False)
                    filter_2 = gr.Dropdown(visible=False)
                    filter_3 = gr.Dropdown(visible=False)
                    filter_4 = gr.Dropdown(visible=False)
                    filter_5 = gr.Dropdown(visible=False)
                with gr.Column(scale=3):
                    plot_r = gr.Plot(label="RMSE Comparison")
                    plot_p = gr.Plot(label="Performance Comparison")

        # TAB 2: Data Inspector (table order must match get_data_preview's return order)
        with gr.TabItem("Data Preview"):
            gr.Markdown("### All Test Configurations by Domain")
            gr.Markdown("**Biomedical (PubMedQA)**")
            preview_table_1 = gr.Dataframe(interactive=False, wrap=True)
            gr.Markdown("**Customer Support (TechQA)**")
            preview_table_2 = gr.Dataframe(interactive=False, wrap=True)
            gr.Markdown("**Finance (FinQA)**")
            preview_table_3 = gr.Dataframe(interactive=False, wrap=True)
            gr.Markdown("**General (MS MARCO)**")
            preview_table_4 = gr.Dataframe(interactive=False, wrap=True)
            gr.Markdown("**Legal (CUAD)**")
            preview_table_5 = gr.Dataframe(interactive=False, wrap=True)
            preview_btn = gr.Button("Refresh Data Preview")

        # TAB 3: Comparison
        with gr.TabItem("Inter-Domain Comparison"):
            gr.Markdown("### Select Metric to Compare")
            metric_dropdown = gr.Dropdown(
                label="Comparison Metric",
                choices=[
                    ("F1 Score (Higher is Better)", "f1_score"),
                    ("AUC-ROC (Higher is Better)", "aucroc"),
                    ("RMSE Relevance (Lower is Better)", "rmse_relevance"),
                    ("RMSE Utilization (Lower is Better)", "rmse_utilization"),
                    ("RMSE Completeness (Lower is Better)", "rmse_completeness"),
                ],
                value="f1_score",
                interactive=True,
            )
            refresh_btn = gr.Button("Generate Comparison")
            gr.Markdown("### Configuration Differences")
            comp_table = gr.Dataframe(interactive=False)
            gr.Markdown("### Peak Performance")
            global_plot = gr.Plot()

        # TAB 4: Response Preview & Metrics
        with gr.TabItem("Response Preview & Metrics"):
            gr.Markdown("### Preview LLM Responses and Quality Metrics")
            gr.Markdown("Select a domain and question to view the generated answer, gold answer, and quality metrics.")
            with gr.Row():
                with gr.Column(scale=1):
                    # Must match the display names produced by load_response_data.
                    domain_selector = gr.Dropdown(
                        label="Select Domain",
                        choices=[
                            'Biomedical (PubMedQA)',
                            'Customer Support (TechQA)',
                            'Finance (FinQA)',
                            'General (MS MARCO)',
                            'Legal (CUAD)',
                        ],
                        interactive=True,
                    )
                    question_selector = gr.Dropdown(
                        label="Select Question",
                        choices=[],
                        interactive=True,
                    )
                with gr.Column(scale=2):
                    metrics_plot = gr.Plot(label="Quality Metrics")
            with gr.Row():
                with gr.Column():
                    gr.Markdown("#### LLM Generated Answer")
                    llm_answer_box = gr.Textbox(
                        label="LLM Answer",
                        lines=12,
                        interactive=False,
                    )
                with gr.Column():
                    gr.Markdown("#### Gold Standard Answer")
                    gold_answer_box = gr.Textbox(
                        label="Gold Answer",
                        lines=12,
                        interactive=False,
                    )

    # EVENTS
    # The module loads data before launching, but ds_dropdown is created with
    # no choices. Fill it on every page load so the user does not have to
    # click "Load/Refresh Data" before anything can be selected.
    demo.load(
        lambda: gr.update(choices=get_dataset_choices()),
        inputs=None,
        outputs=[ds_dropdown],
    )

    refresh_data_btn.click(
        load_data, inputs=None, outputs=[status]
    ).then(
        lambda: gr.Dropdown(choices=get_dataset_choices()),
        outputs=[ds_dropdown],
    )

    ds_dropdown.change(
        get_domain_state,
        inputs=[ds_dropdown],
        outputs=[constants_box, filter_1, filter_2, filter_3, filter_4, filter_5],
    ).then(
        plot_metrics_on_x_axis,
        inputs=[ds_dropdown, filter_1, filter_2, filter_3, filter_4, filter_5],
        outputs=[plot_r, plot_p],
    )

    gr.on(
        triggers=[filter_1.change, filter_2.change, filter_3.change, filter_4.change, filter_5.change],
        fn=plot_metrics_on_x_axis,
        inputs=[ds_dropdown, filter_1, filter_2, filter_3, filter_4, filter_5],
        outputs=[plot_r, plot_p],
    )

    # Debug Preview Events
    preview_btn.click(
        get_data_preview,
        inputs=None,
        outputs=[preview_table_1, preview_table_2, preview_table_3, preview_table_4, preview_table_5],
    )

    refresh_btn.click(
        generate_inter_domain_comparison,
        inputs=[metric_dropdown],
        outputs=[comp_table, global_plot],
    )

    # Response Preview Events
    domain_selector.change(
        fn=get_questions_for_domain,
        inputs=[domain_selector],
        outputs=[question_selector],
    ).then(
        fn=lambda: ("", "", None),
        outputs=[llm_answer_box, gold_answer_box, metrics_plot],
    )

    question_selector.change(
        fn=get_response_details,
        inputs=[domain_selector, question_selector],
        outputs=[llm_answer_box, gold_answer_box, metrics_plot],
    )
# --- Startup ---
# Populate DB before the server starts so callbacks (and the page-load
# handler, if any) see data immediately.
print(f"Loading data from {DATA_FOLDER}...")
startup_status = load_data()
print(startup_status)

# Launched at module level, as the original deployment note requires
# (it states this runs on import on Hugging Face Spaces).
demo.launch(ssr_mode=False)