| """ | |
| MediaBiasGroup — Model Comparator (Gradio Space) | |
| - Discovers org models by pipeline_tag | |
| - Lets users pick a task, select multiple models, and compare outputs on the same input | |
| - Uses a full local snapshot for robustness (avoids NoneType path issues) | |
| - Falls back to base_model's tokenizer if a fine-tuned repo lacks tokenizer files | |
| - Canonicalizes label names across models (LABEL_0 -> neutral, etc.) | |
| - "Select all" button to quickly select all models for the chosen task | |
| Requirements (see requirements.txt): | |
| gradio>=4.31.4 | |
| transformers>=4.42.0 | |
| huggingface_hub>=0.23.0 | |
| torch>=2.2.0 | |
| pandas>=2.0.0 | |
| """ | |
from __future__ import annotations

import os
from typing import Any, Dict, List, Tuple

import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, list_repo_files, snapshot_download
from transformers import pipeline
# =========================
# Configuration
# =========================
ORG = "mediabiasgroup"
DEFAULT_TASK = "text-classification"
MAX_MODELS = 10  # safety cap to avoid loading too many models at once on CPU Spaces

HF_TOKEN = (
    os.environ.get("HF_TOKEN")
    or os.environ.get("HUGGING_FACE_HUB_TOKEN")
    or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)

api = HfApi(token=HF_TOKEN)
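# HF_TOKEN may legitimately be None: public repos load without it; a token is only
# needed for gated/private models that the token's account can access.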
# Canonical label mapping (extend if needed)
CANON = {
    "LABEL_0": "neutral",
    "LABEL_1": "lexical_bias",
    "NEGATIVE": "neutral",
    "POSITIVE": "lexical_bias",
    "neutral": "neutral",
    "not_biased": "neutral",
    "non-biased": "neutral",
    "unbiased": "neutral",
    "biased": "lexical_bias",
    "lexical_bias": "lexical_bias",
}
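# Labels not listed above pass through _canonicalize() unchanged, so models with
# other label schemes still show up in the results table (just without renaming).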
# =========================
# Discovery & card helpers
# =========================
def list_org_models() -> List[Any]:
    # full=True to fetch pipeline_tag & tags
    return list(api.list_models(author=ORG, full=True))


def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]:
    infos = list_org_models()
    task2models: Dict[str, List[str]] = {}
    for info in infos:
        task = getattr(info, "pipeline_tag", None)
        if not task:
            # Heuristic fallback via tags if pipeline_tag is missing
            tags = set(getattr(info, "tags", []) or [])
            if "text-classification" in tags:
                task = "text-classification"
        if task:
            task2models.setdefault(task, []).append(info.modelId)
    tasks = sorted(task2models.keys()) or [DEFAULT_TASK]
    for t in task2models:
        task2models[t] = sorted(task2models[t])
    return tasks, task2models
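# Illustrative return shape (repo ids below are placeholders, not actual org models):
#   (["text-classification"],
#    {"text-classification": ["mediabiasgroup/model-a", "mediabiasgroup/model-b"]})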
def get_card_data(repo_id: str) -> Dict[str, Any]:
    try:
        info = api.model_info(repo_id, token=HF_TOKEN)
        # Recent huggingface_hub exposes the card as `card_data`; older versions used `cardData`
        data = getattr(info, "card_data", None) or getattr(info, "cardData", None)
        if hasattr(data, "to_dict"):  # ModelCardData -> plain dict
            return dict(data.to_dict())
        return data or {}
    except Exception:
        return {}
# =========================
# Tokenizer fallback logic
# =========================
def _has_tokenizer_files(repo_id: str) -> bool:
    try:
        files = set(list_repo_files(repo_id, repo_type="model", token=HF_TOKEN))
    except Exception:
        return False
    if "tokenizer.json" in files:
        return True
    if {"vocab.json", "merges.txt"}.issubset(files):
        return True
    if "spiece.model" in files:
        return True
    return False
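# The three checks above cover fast tokenizers (tokenizer.json), slow BPE tokenizers
# (vocab.json + merges.txt), and SentencePiece-based models (spiece.model).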
def _base_model_from_card(repo_id: str) -> str | None:
    data = get_card_data(repo_id) or {}
    base = data.get("base_model")
    if isinstance(base, list):
        base = base[0] if base else None
    return base


def _tokenizer_source(repo_id: str) -> str:
    # Prefer repo tokenizer; else fall back to base_model; else repo_id
    if _has_tokenizer_files(repo_id):
        return repo_id
    base = _base_model_from_card(repo_id)
    return base or repo_id
# =========================
# Pipelines & prediction
# =========================
PIPE_CACHE: Dict[str, Any] = {}


def get_pipeline(repo_id: str, task: str):
    key = f"{task}::{repo_id}"
    if key in PIPE_CACHE:
        return PIPE_CACHE[key]
    tok_src = _tokenizer_source(repo_id)
    # Robust path: download a full local snapshot (no restrictive allow_patterns)
    try:
        local_dir = snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            token=HF_TOKEN,  # works for public and gated/private (if token has access)
            local_files_only=False,
        )
        if not isinstance(local_dir, str) or not local_dir:
            # extremely defensive: fall back to remote id
            local_dir = repo_id
    except Exception:
        local_dir = repo_id  # fall back to remote if snapshot fails
    if task == "text-classification":
        pipe = pipeline(
            task,
            model=local_dir,
            tokenizer=tok_src,
            top_k=None,  # return scores for all labels (replaces deprecated return_all_scores=True)
            truncation=True,
            token=HF_TOKEN,
        )
    else:
        # Add more tasks if you release them later
        pipe = pipeline(task, model=local_dir, tokenizer=tok_src, token=HF_TOKEN)
    PIPE_CACHE[key] = pipe
    return pipe
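# Pipelines are cached per "<task>::<repo_id>" key, so repeated comparisons reuse
# already-loaded weights instead of re-downloading and re-instantiating each model.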
def _canonicalize(scores: Dict[str, float]) -> Dict[str, float]:
    out: Dict[str, float] = {}
    for raw_label, sc in scores.items():
        lab = CANON.get(raw_label, raw_label)
        out[lab] = max(sc, out.get(lab, 0.0))
    return out
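# Example: {"LABEL_0": 0.72, "LABEL_1": 0.28} -> {"neutral": 0.72, "lexical_bias": 0.28}
# If two raw labels map to the same canonical name, the higher score wins.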
def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame]:
    if not text.strip():
        return "Please enter some text.", pd.DataFrame()
    if not models:
        return f"Please select 1–{MAX_MODELS} models.", pd.DataFrame()
    if len(models) > MAX_MODELS:
        models = models[:MAX_MODELS]
    table_rows: List[Dict[str, Any]] = []
    label_union: set[str] = set()
    per_model_outputs: Dict[str, Dict[str, float]] = {}
    errors: Dict[str, str] = {}
    for rid in models:
        try:
            pipe = get_pipeline(rid, task)
            out = pipe(text)
            # text-classification pipeline typical shapes:
            # [[{label, score}, ...]] or [{label, score}, ...]
            if isinstance(out, list) and out and isinstance(out[0], list):
                scores = {d["label"]: float(d["score"]) for d in out[0]}
            elif isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
                scores = {d["label"]: float(d["score"]) for d in out}
            else:
                scores = {}
            scores = _canonicalize(scores) or {"<no_output>": 1.0}
            per_model_outputs[rid] = scores
            label_union.update(scores.keys())
        except Exception as e:
            per_model_outputs[rid] = {"<error>": 1.0}
            label_union.add("<error>")
            errors[rid] = str(e)
    # Build table with union of labels as columns
    label_cols = sorted(label_union)
    for rid in models:
        row = {"model": rid}
        scores = per_model_outputs.get(rid, {})
        for lab in label_cols:
            row[lab] = scores.get(lab, 0.0)
        if scores:
            pred = max(scores.items(), key=lambda kv: kv[1])[0]
            row["predicted_label"] = pred
        else:
            row["predicted_label"] = ""
        table_rows.append(row)
    pred_df = pd.DataFrame(table_rows, columns=["model"] + label_cols + ["predicted_label"])
    msg = f"✓ Done. Compared {len(models)} model(s) on task: `{task}`"
    if errors:
        msg += "\n\n**Errors**:\n" + "\n".join(f"- {k}: {v}" for k, v in errors.items())
    return msg, pred_df
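# The returned DataFrame has one row per selected model and one column per canonical
# label (union across all models), plus an argmax "predicted_label" column.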
# =========================
# UI wiring
# =========================
def refresh_models(selected_task: str) -> Tuple[List[str], List[str]]:
    tasks, task2models = discover_tasks_and_models()
    models = task2models.get(selected_task, [])
    return tasks, models


def on_task_change(selected_task: str) -> List[str]:
    _, task2models = discover_tasks_and_models()
    return task2models.get(selected_task, [])


def select_all_models(selected_task: str) -> List[str]:
    _, task2models = discover_tasks_and_models()
    return task2models.get(selected_task, [])
def build_ui() -> gr.Blocks:
    with gr.Blocks(fill_height=True, title="MediaBiasGroup — Model Comparator") as demo:
        gr.Markdown(
            "# MediaBiasGroup — Model Comparator\n"
            "Select a **task**, choose multiple models, enter text, and compare outputs side-by-side."
        )
        with gr.Row():
            with gr.Column(scale=1):
                tasks, task2models = discover_tasks_and_models()
                task_choices = tasks or [DEFAULT_TASK]
                task_default = task_choices[0] if task_choices else DEFAULT_TASK
                task_dd = gr.Dropdown(
                    choices=task_choices,
                    value=task_default,
                    label="Task",
                )
                model_ms = gr.Dropdown(
                    choices=task2models.get(task_default, []),
                    multiselect=True,
                    label="Models",
                )
                select_all_btn = gr.Button("Select all")
                gr.Markdown(f"**Organization:** `{ORG}` \n**Max models per run:** {MAX_MODELS}")
            with gr.Column(scale=2):
                text_in = gr.Textbox(lines=4, placeholder="Paste a sentence…", label="Input text")
                gr.Examples(
                    examples=[
                        ["The bill passed the House on Tuesday in a 220–210 vote."],  # unbiased/factual
                        ["Lawmakers shamelessly rammed the bill through the House on Tuesday."],  # biased/loaded
                        ["Unemployment fell from 5.2% to 5.0% in July, according to government figures."],
                        ["The corrupt regime bragged unemployment fell, but it's just cooking the books."],
                    ],
                    inputs=[text_in],
                    label="Examples",
                )
                run_btn = gr.Button("Compare")
        status = gr.Markdown("")
        # Single wide results table
        gr.Markdown("### Predictions")
        pred_df = gr.Dataframe(interactive=False)
        # Events
        task_dd.change(fn=on_task_change, inputs=[task_dd], outputs=[model_ms])
        select_all_btn.click(fn=select_all_models, inputs=[task_dd], outputs=[model_ms])
        run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df])
    return demo
if __name__ == "__main__":
    demo = build_ui()
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))