Spaces:
Running
Running
| import os | |
| import re | |
| import functools | |
| import pandas as pd | |
| import gradio as gr | |
# Locations of the workbook and the banner image; each can be overridden
# through an environment variable.
DATA_PATH = os.getenv("LB_DATA_PATH", "gradio (2).xlsx")
BANNER_PATH = os.getenv("LB_BANNER_PATH", "banner.png")

# Switch pandas to copy-on-write mode for the whole process.
pd.set_option("mode.copy_on_write", True)

# Worksheets shown as tabs, in display order. The strings are used verbatim
# as sheet names when reading the workbook, so they must match it exactly.
TAB_ORDER = [
    "Main Leader Board",
    "TTS Difficultly Level Per Model",
    "Performance On Codecs",
    "Best Model per Attack",
]

# HTML blurb rendered above each tab's table, keyed by sheet name.
SHEET_DESCRIPTIONS = {
    "Main Leader Board": (
        "<b>Overview.</b> Comparison of models across all datasets. "
        "Lower <b>EER</b> is better."
    ),
    "TTS Difficultly Level Per Model": (
        "<b>TTS stress-test.</b> Lower TNR means harder; "
        "higher TNR means easier."
    ),
    "Performance On Codecs": (
        "<b>Codec robustness.</b> Compare models under compression/bitrates."
    ),
    "Best Model per Attack": (
        "<b>Per-attack winners.</b> Shows top-performing model per attack."
    ),
}
# Extra keyword arguments forwarded to every pandas Excel read. openpyxl is
# requested explicitly when it can be imported; otherwise pandas is left to
# choose its default engine.
try:
    import openpyxl  # noqa: F401  (imported only to probe availability)
except Exception:
    READ_KW = {}
else:
    READ_KW = {"engine": "openpyxl"}
# ─────────────────────────────────────────────────────────────────────────────
# Excel helpers
# ─────────────────────────────────────────────────────────────────────────────
def _ensure_file():
    """Raise FileNotFoundError unless a file exists at DATA_PATH."""
    if os.path.exists(DATA_PATH):
        return
    raise FileNotFoundError(f"Excel file not found at '{DATA_PATH}'")
def load_sheet(name: str):
    """
    Read one worksheet of the workbook at DATA_PATH.

    Raises FileNotFoundError (via _ensure_file) when the workbook is missing;
    any error raised by pandas while reading propagates unchanged.
    """
    _ensure_file()
    sheet_frame = pd.read_excel(DATA_PATH, sheet_name=name, **READ_KW)
    return sheet_frame
| def _find_model_col(headers: list[str]) -> str: | |
| for cand in headers: | |
| if str(cand).lower().strip() in ("ssl model", "ssl_model", "model"): | |
| return cand | |
| return headers[0] | |
| def _base_dataset_name(colname: str) -> str: | |
| return re.sub(r"\.\d+$", "", str(colname)).strip() | |
def read_tts_header():
    """
    Read the header information of the 'TTS all' sheet.

    Returns:
        A tuple ``(model_col, columns, first_row)`` where
        - ``model_col`` is the header chosen by _find_model_col,
        - ``columns`` is the list of column headers (Excel header row),
        - ``first_row`` is the first data row as a Series indexed by header;
          it carries the per-column labels (attacks / TTS names / means).
          When the sheet has no data row, an empty-valued object Series
          indexed by ``columns`` is returned instead.

    Errors raised by pandas while opening or parsing the workbook propagate.
    """
    # Open the workbook once and parse it twice, instead of letting two
    # separate pd.read_excel calls each load the file from disk. This runs
    # on every explorer interaction, so the saving is repeated often.
    with pd.ExcelFile(DATA_PATH, **READ_KW) as workbook:
        # The header-only parse is kept separate from the one-row parse so
        # the column list is derived exactly as before (from the header row
        # alone).
        columns = workbook.parse("TTS all", nrows=0).columns.tolist()
        df_head = workbook.parse("TTS all", nrows=1, header=0)

    if len(df_head):
        first_row = df_head.iloc[0]
    else:
        first_row = pd.Series(index=columns, dtype=object)

    model_col = _find_model_col(columns)
    return model_col, columns, first_row
def build_dataset_to_labels():
    """
    Group the 'TTS all' columns by dataset and label.

    Every column except the model column is assigned to a dataset (its header
    with any trailing '.N' removed) and labelled with the stripped text found
    in the first data row. Columns whose label cell is empty or missing are
    skipped. All non-empty labels are kept (e.g. 'A07', 'mean TNR',
    'MeloTTS'), not only attack ids. A label that repeats within one dataset
    is disambiguated as 'label#2', 'label#3', ...

    Returns:
        ``(model_col, columns, mapping)`` with
        ``mapping[dataset][label] -> wide column name``.
    """
    model_col, columns, first_row = read_tts_header()
    mapping: dict[str, dict[str, str]] = {}

    for col in columns:
        if col == model_col:
            continue

        cell = first_row.get(col, "")
        label = "" if pd.isna(cell) else str(cell).strip()
        if label == "":
            # No label in the first data row: the column is not exposed.
            continue

        # The per-dataset dict doubles as the record of labels already taken.
        taken = mapping.setdefault(_base_dataset_name(col), {})
        candidate, suffix = label, 2
        while candidate in taken:
            candidate = f"{label}#{suffix}"
            suffix += 1
        taken[candidate] = col

    return model_col, columns, mapping
| def _indices_for_columns(columns: list[str], wanted: list[str]) -> list[int]: | |
| name_to_pos = {name: i for i, name in enumerate(columns)} | |
| positions = [] | |
| for w in wanted: | |
| if w not in name_to_pos: | |
| raise KeyError(f"Requested column '{w}' not found in header.") | |
| positions.append(name_to_pos[w]) | |
| return positions | |
def read_tts_subset_for_dataset(dataset: str, label: str | None) -> pd.DataFrame:
    """
    Read the model column plus one dataset's label columns from 'TTS all'.

    Only the needed columns are parsed (selected by position), the first data
    row (which holds the labels) is dropped, and the columns are renamed to
    ``['Model', <labels...>]``.

    Args:
        dataset: Base dataset name, as keyed by build_dataset_to_labels().
        label: A single label to keep. ``None``, an empty string or "(All)"
            selects every label of the dataset, sorted case-insensitively.

    Returns:
        A DataFrame with a string 'Model' column followed by one column per
        selected label, coerced to numeric (non-numeric cells become NaN).
        When the dataset, the label, or a header column cannot be resolved,
        a one-row DataFrame with a single 'Info' column describing the
        problem is returned instead.
    """
    model_col, columns, mapping = build_dataset_to_labels()
    if dataset not in mapping:
        return pd.DataFrame({"Info": [f"No columns found for dataset '{dataset}'."]})
    label_to_col = mapping[dataset]

    if label and label != "(All)":
        if label not in label_to_col:
            return pd.DataFrame({"Info": [f"Label '{label}' not found in dataset '{dataset}'."]})
        labels = [label]
    else:
        # Include ALL labels for this dataset (attacks, means, TTS names...).
        labels = sorted(label_to_col, key=str.lower)
    selected_cols = [label_to_col[lab] for lab in labels]
    wanted_cols = [model_col] + selected_cols

    try:
        use_positions = _indices_for_columns(columns, wanted_cols)
    except KeyError as e:
        return pd.DataFrame({"Info": [f"{e}"]})

    df = pd.read_excel(
        DATA_PATH,
        sheet_name="TTS all",
        usecols=use_positions,
        header=0,
        **READ_KW,
    )

    # usecols only decides WHICH columns are parsed; the parsed frame keeps
    # the sheet's own column order, not the order of the list passed in.
    # Re-select by header name so the positional rename below pairs every
    # label with its own data (labels are sorted alphabetically, and the
    # model column is not guaranteed to be the left-most one).
    df = df[wanted_cols]

    # First data row holds the labels; drop it.
    if len(df) > 0:
        df = df.iloc[1:].reset_index(drop=True)

    df.columns = ["Model"] + labels
    df["Model"] = df["Model"].astype(str)

    # Coerce numeric where possible.
    for lab in labels:
        df[lab] = pd.to_numeric(df[lab], errors="coerce")
    return df
# ─────────────────────────────────────────────────────────────────────────────
# Build Gradio app
# ─────────────────────────────────────────────────────────────────────────────
def build_app():
    """
    Assemble the Gradio Blocks UI for the leaderboard.

    Layout:
        - an optional banner image (only when BANNER_PATH exists) and a title,
        - one read-only table tab per sheet in TAB_ORDER; a sheet that fails
          to load is replaced by a one-cell 'Error' table so the remaining
          tabs still render,
        - an "Explorer for Attacks" tab with a dataset dropdown, a label
          dropdown and a table backed by read_tts_subset_for_dataset().

    Returns:
        The constructed ``gr.Blocks`` instance (not yet launched).
    """
    with gr.Blocks(title="ποΈ Benchmarking Linear-Head Classifiers") as demo:
        if os.path.exists(BANNER_PATH):
            gr.Image(value=BANNER_PATH, show_label=False)
        gr.Markdown("<h1>ποΈ Benchmarking Linear-Head Classifiers Built on S3PRL Embeddings</h1>")

        with gr.Tabs():
            # Auto-load the four main tabs.
            for sheet in TAB_ORDER:
                try:
                    df = load_sheet(sheet)
                except Exception as e:
                    # Best effort: show the failure inside the tab.
                    df = pd.DataFrame({"Error": [str(e)]})
                with gr.TabItem(sheet):
                    gr.Markdown(f"<div class='sheet-card'>{SHEET_DESCRIPTIONS.get(sheet, '')}</div>")
                    gr.Dataframe(value=df, interactive=False)

            # Explorer (all labels, not just Axx).
            with gr.TabItem("Explorer for Attacks"):
                gr.Markdown(
                    "<div class='sheet-card'>"
                    "<b>Pick a Dataset</b> then (optionally) pick a specific label.<br>"
                    "Rows = Models, Columns = ALL labels for that dataset (attacks, means, TTS names, etc.)."
                    "</div>"
                )
                try:
                    _, _, mapping = build_dataset_to_labels()
                    datasets = sorted(mapping.keys())
                except Exception:
                    # Any failure while reading 'TTS all' degrades to the
                    # "no valid sheet" notice below.
                    mapping, datasets = {}, []

                if not datasets:
                    gr.Markdown("β <i>No valid 'TTS all' sheet detected (first row must contain labels; first column models).</i>")
                else:
                    default_ds = datasets[0]

                    def label_choices(ds):
                        """Dropdown choices for a dataset: '(All)' + its labels, sorted case-insensitively."""
                        return ["(All)"] + sorted(mapping.get(ds, {}), key=str.lower)

                    with gr.Row():
                        ds_dd = gr.Dropdown(datasets, value=default_ds, label="Dataset")
                        lab_dd = gr.Dropdown(label_choices(default_ds), value="(All)", label="Label (optional)")

                    # Seed the table through the constructor rather than by
                    # assigning `out.value` afterwards, so the initial frame
                    # goes through the component's own value handling like
                    # every later update does.
                    out = gr.Dataframe(
                        value=read_tts_subset_for_dataset(default_ds, None),
                        interactive=False,
                    )

                    def on_ds_change(ds):
                        # Reset the label filter and show every label of the new dataset.
                        table = read_tts_subset_for_dataset(ds, None)
                        return gr.update(choices=label_choices(ds), value="(All)"), table

                    ds_dd.change(on_ds_change, ds_dd, [lab_dd, out])

                    def on_label_change(ds, lab):
                        return read_tts_subset_for_dataset(ds, lab if lab != "(All)" else None)

                    lab_dd.change(on_label_change, [ds_dd, lab_dd], out)
    return demo
# Build the UI at import time and expose it as the module-level name `demo`.
demo = build_app()

if __name__ == "__main__":
    # Run as a script: listen on all interfaces on port 7860 and request a
    # public share link.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True,
    )