""" MediaBiasGroup — Model Comparator (Gradio Space) - Discovers org models by pipeline_tag - Lets users pick a task, select multiple models, and compare outputs on the same input - Uses a full local snapshot for robustness (avoids NoneType path issues) - Falls back to base_model's tokenizer if a fine-tuned repo lacks tokenizer files - Canonicalizes label names across models (LABEL_0 -> neutral, etc.) - "Select all" button to quickly select all models for the chosen task Requirements (see requirements.txt): gradio>=4.31.4 transformers>=4.42.0 huggingface_hub>=0.23.0 torch>=2.2.0 pandas>=2.0.0 """ from __future__ import annotations import os from functools import lru_cache from typing import Any, Dict, List, Tuple import gradio as gr import pandas as pd from huggingface_hub import HfApi, list_repo_files, snapshot_download from transformers import pipeline # ========================= # Configuration # ========================= ORG = "mediabiasgroup" DEFAULT_TASK = "text-classification" MAX_MODELS = 10 # safety cap to avoid loading too many models at once on CPU Spaces HF_TOKEN = ( os.environ.get("HF_TOKEN") or os.environ.get("HUGGING_FACE_HUB_TOKEN") or os.environ.get("HUGGINGFACEHUB_API_TOKEN") ) api = HfApi(token=HF_TOKEN) # Canonical label mapping (extend if needed) CANON = { "LABEL_0": "neutral", "LABEL_1": "lexical_bias", "NEGATIVE": "neutral", "POSITIVE": "lexical_bias", "neutral": "neutral", "not_biased": "neutral", "non-biased": "neutral", "unbiased": "neutral", "biased": "lexical_bias", "lexical_bias": "lexical_bias", } # ========================= # Discovery & card helpers # ========================= @lru_cache(maxsize=1) def list_org_models() -> List[Any]: # full=True to fetch pipeline_tag & tags return list(api.list_models(author=ORG, full=True)) def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]: infos = list_org_models() task2models: Dict[str, List[str]] = {} for info in infos: task = getattr(info, "pipeline_tag", None) if not task: # Heuristic fallback via tags if pipeline_tag is missing tags = set(getattr(info, "tags", []) or []) if "text-classification" in tags: task = "text-classification" if task: task2models.setdefault(task, []).append(info.modelId) tasks = sorted(task2models.keys()) or [DEFAULT_TASK] for t in task2models: task2models[t] = sorted(task2models[t]) return tasks, task2models @lru_cache(maxsize=256) def get_card_data(repo_id: str) -> Dict[str, Any]: try: info = api.model_info(repo_id, token=HF_TOKEN) data = getattr(info, "cardData", None) if hasattr(data, "data"): # ModelCardData -> dict return dict(data.data) return data or {} except Exception: return {} # ========================= # Tokenizer fallback logic # ========================= def _has_tokenizer_files(repo_id: str) -> bool: try: files = set(list_repo_files(repo_id, repo_type="model", token=HF_TOKEN)) except Exception: return False if "tokenizer.json" in files: return True if {"vocab.json", "merges.txt"}.issubset(files): return True if "spiece.model" in files: return True return False def _base_model_from_card(repo_id: str) -> str | None: data = get_card_data(repo_id) or {} base = data.get("base_model") if isinstance(base, list): base = base[0] if base else None return base def _tokenizer_source(repo_id: str) -> str: # Prefer repo tokenizer; else fall back to base_model; else repo_id if _has_tokenizer_files(repo_id): return repo_id base = _base_model_from_card(repo_id) return base or repo_id # ========================= # Pipelines & prediction # ========================= 
# =========================
# Pipelines & prediction
# =========================
PIPE_CACHE: Dict[str, Any] = {}


def get_pipeline(repo_id: str, task: str):
    key = f"{task}::{repo_id}"
    if key in PIPE_CACHE:
        return PIPE_CACHE[key]

    tok_src = _tokenizer_source(repo_id)

    # Robust path: download a full local snapshot (no restrictive allow_patterns)
    try:
        local_dir = snapshot_download(
            repo_id=repo_id,
            repo_type="model",
            token=HF_TOKEN,  # works for public and gated/private (if token has access)
            local_files_only=False,
        )
        if not isinstance(local_dir, str) or not local_dir:
            # extremely defensive: fall back to the remote id
            local_dir = repo_id
    except Exception:
        local_dir = repo_id  # fall back to remote if the snapshot fails

    if task == "text-classification":
        pipe = pipeline(
            task,
            model=local_dir,
            tokenizer=tok_src,
            top_k=None,  # scores for all labels (replaces the deprecated return_all_scores=True)
            truncation=True,
            token=HF_TOKEN,
        )
    else:
        # Add more tasks if you release them later
        pipe = pipeline(task, model=local_dir, tokenizer=tok_src, token=HF_TOKEN)

    PIPE_CACHE[key] = pipe
    return pipe


def _canonicalize(scores: Dict[str, float]) -> Dict[str, float]:
    out: Dict[str, float] = {}
    for raw_label, sc in scores.items():
        lab = CANON.get(raw_label, raw_label)
        out[lab] = max(sc, out.get(lab, 0.0))
    return out
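# A quick sketch of how _canonicalize merges labels (scores here are made up):
#     _canonicalize({"LABEL_1": 0.7, "biased": 0.4})  ->  {"lexical_bias": 0.7}
# Duplicate canonical names keep the *maximum* score rather than summing, so a
# model exposing both raw and human-readable aliases is not double-counted.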
def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame]:
    if not text.strip():
        return "Please enter some text.", pd.DataFrame()
    if not models:
        return f"Please select 1–{MAX_MODELS} models.", pd.DataFrame()
    if len(models) > MAX_MODELS:
        models = models[:MAX_MODELS]

    table_rows: List[Dict[str, Any]] = []
    label_union: set[str] = set()
    per_model_outputs: Dict[str, Dict[str, float]] = {}
    errors: Dict[str, str] = {}

    for rid in models:
        try:
            pipe = get_pipeline(rid, task)
            out = pipe(text)
            # text-classification pipeline typical shapes:
            #     [[{label, score}, ...]]  or  [{label, score}, ...]
            if isinstance(out, list) and out and isinstance(out[0], list):
                scores = {d["label"]: float(d["score"]) for d in out[0]}
            elif isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
                scores = {d["label"]: float(d["score"]) for d in out}
            else:
                scores = {}
            # Sentinel empty-label entry keeps the row renderable if nothing parsed
            scores = _canonicalize(scores) or {"": 1.0}
            per_model_outputs[rid] = scores
            label_union.update(scores.keys())
        except Exception as e:
            per_model_outputs[rid] = {"": 1.0}
            label_union.add("")
            errors[rid] = str(e)

    # Build the table with the union of canonical labels as columns
    label_cols = sorted(label_union)
    for rid in models:
        row = {"model": rid}
        scores = per_model_outputs.get(rid, {})
        for lab in label_cols:
            row[lab] = scores.get(lab, 0.0)
        if scores:
            pred = max(scores.items(), key=lambda kv: kv[1])[0]
            row["predicted_label"] = pred
        else:
            row["predicted_label"] = ""
        table_rows.append(row)

    pred_df = pd.DataFrame(table_rows, columns=["model"] + label_cols + ["predicted_label"])

    msg = f"✓ Done. Compared {len(models)} model(s) on task: `{task}`"
    if errors:
        msg += "\n\n**Errors**:\n" + "\n".join(f"- {k}: {v}" for k, v in errors.items())
    return msg, pred_df


# =========================
# UI wiring
# =========================
def refresh_models(selected_task: str) -> Tuple[List[str], List[str]]:
    # Unused by the current UI, but handy if you later add a manual "Refresh" button
    tasks, task2models = discover_tasks_and_models()
    models = task2models.get(selected_task, [])
    return tasks, models


def on_task_change(selected_task: str) -> List[str]:
    _, task2models = discover_tasks_and_models()
    return task2models.get(selected_task, [])


def select_all_models(selected_task: str) -> List[str]:
    _, task2models = discover_tasks_and_models()
    return task2models.get(selected_task, [])


def build_ui() -> gr.Blocks:
    with gr.Blocks(fill_height=True, title="MediaBiasGroup — Model Comparator") as demo:
        gr.Markdown(
            "# MediaBiasGroup — Model Comparator\n"
            "Select a **task**, choose multiple models, enter text, and compare outputs side-by-side."
        )

        with gr.Row():
            with gr.Column(scale=1):
                tasks, task2models = discover_tasks_and_models()
                task_choices = tasks or [DEFAULT_TASK]
                task_default = task_choices[0] if task_choices else DEFAULT_TASK

                task_dd = gr.Dropdown(
                    choices=task_choices,
                    value=task_default,
                    label="Task",
                )
                model_ms = gr.Dropdown(
                    choices=task2models.get(task_default, []),
                    multiselect=True,
                    label="Models",
                )
                select_all_btn = gr.Button("Select all")
                gr.Markdown(f"**Organization:** `{ORG}`  \n**Max models per run:** {MAX_MODELS}")

            with gr.Column(scale=2):
                text_in = gr.Textbox(lines=4, placeholder="Paste a sentence…", label="Input text")
                gr.Examples(
                    examples=[
                        ["The bill passed the House on Tuesday in a 220–210 vote."],  # unbiased/factual
                        ["Lawmakers shamelessly rammed the bill through the House on Tuesday."],  # biased/loaded
                        ["Unemployment fell from 5.2% to 5.0% in July, according to government figures."],
                        ["The corrupt regime bragged unemployment fell, but it's just cooking the books."],
                    ],
                    inputs=[text_in],
                    label="Examples",
                )
                run_btn = gr.Button("Compare")
                status = gr.Markdown("")

        # Single wide results table
        gr.Markdown("### Predictions")
        pred_df = gr.Dataframe(interactive=False)

        # Events
        task_dd.change(fn=on_task_change, inputs=[task_dd], outputs=[model_ms])
        select_all_btn.click(fn=select_all_models, inputs=[task_dd], outputs=[model_ms])
        run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df])

    return demo


if __name__ == "__main__":
    demo = build_ui()
    demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
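# Local run (a sketch; assumes this file is saved as app.py, the usual Space entrypoint):
#     pip install -r requirements.txt
#     HF_TOKEN=... python app.py   # token only needed for gated/private org models
# Then open http://localhost:7860 (or the port set via the PORT env var).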