"""
MediaBiasGroup — Model Comparator (Gradio Space)
- Discovers org models by pipeline_tag
- Lets users pick a task, select multiple models, and compare outputs on the same input
- Uses a full local snapshot for robustness (avoids NoneType path issues)
- Falls back to base_model's tokenizer if a fine-tuned repo lacks tokenizer files
- Canonicalizes label names across models (LABEL_0 -> neutral, etc.)
- "Select all" button to quickly select all models for the chosen task
Requirements (see requirements.txt):
gradio>=4.31.4
transformers>=4.42.0
huggingface_hub>=0.23.0
torch>=2.2.0
pandas>=2.0.0
"""
from __future__ import annotations
import os
from functools import lru_cache
from typing import Any, Dict, List, Tuple
import gradio as gr
import pandas as pd
from huggingface_hub import HfApi, list_repo_files, snapshot_download
from transformers import pipeline
# =========================
# Configuration
# =========================
ORG = "mediabiasgroup"
DEFAULT_TASK = "text-classification"
MAX_MODELS = 10 # safety cap to avoid loading too many models at once on CPU Spaces
HF_TOKEN = (
os.environ.get("HF_TOKEN")
or os.environ.get("HUGGING_FACE_HUB_TOKEN")
or os.environ.get("HUGGINGFACEHUB_API_TOKEN")
)
api = HfApi(token=HF_TOKEN)
# Canonical label mapping (extend if needed)
CANON = {
"LABEL_0": "neutral",
"LABEL_1": "lexical_bias",
"NEGATIVE": "neutral",
"POSITIVE": "lexical_bias",
"neutral": "neutral",
"not_biased": "neutral",
"non-biased": "neutral",
"unbiased": "neutral",
"biased": "lexical_bias",
"lexical_bias": "lexical_bias",
}
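# Labels not listed above pass through unchanged; when several raw labels map to
# the same canonical label, _canonicalize() below keeps the higher score.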
# =========================
# Discovery & card helpers
# =========================
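# The model list is cached for the lifetime of the process (lru_cache below);
# restart the Space to pick up models added to the org after startup.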
@lru_cache(maxsize=1)
def list_org_models() -> List[Any]:
# full=True to fetch pipeline_tag & tags
return list(api.list_models(author=ORG, full=True))
def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]:
infos = list_org_models()
task2models: Dict[str, List[str]] = {}
for info in infos:
task = getattr(info, "pipeline_tag", None)
if not task:
# Heuristic fallback via tags if pipeline_tag is missing
tags = set(getattr(info, "tags", []) or [])
if "text-classification" in tags:
task = "text-classification"
if task:
task2models.setdefault(task, []).append(info.modelId)
tasks = sorted(task2models.keys()) or [DEFAULT_TASK]
for t in task2models:
task2models[t] = sorted(task2models[t])
return tasks, task2models
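# Illustrative return shape (repo names hypothetical; the actual list is fetched at runtime):
#   (["text-classification"],
#    {"text-classification": ["mediabiasgroup/example-model", ...]})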
@lru_cache(maxsize=256)
def get_card_data(repo_id: str) -> Dict[str, Any]:
try:
info = api.model_info(repo_id, token=HF_TOKEN)
        data = getattr(info, "cardData", None)
        if data is None:
            return {}
        if hasattr(data, "to_dict"):  # ModelCardData -> plain dict
            return data.to_dict()
        return dict(data)
except Exception:
return {}
# =========================
# Tokenizer fallback logic
# =========================
def _has_tokenizer_files(repo_id: str) -> bool:
try:
files = set(list_repo_files(repo_id, repo_type="model", token=HF_TOKEN))
except Exception:
return False
if "tokenizer.json" in files:
return True
if {"vocab.json", "merges.txt"}.issubset(files):
return True
if "spiece.model" in files:
return True
return False
def _base_model_from_card(repo_id: str) -> str | None:
data = get_card_data(repo_id) or {}
base = data.get("base_model")
if isinstance(base, list):
base = base[0] if base else None
return base
def _tokenizer_source(repo_id: str) -> str:
# Prefer repo tokenizer; else fall back to base_model; else repo_id
if _has_tokenizer_files(repo_id):
return repo_id
base = _base_model_from_card(repo_id)
return base or repo_id
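# Example: a fine-tuned repo that ships no tokenizer files but whose card declares
# `base_model: roberta-base` (name illustrative) resolves its tokenizer from the base repo.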
# =========================
# Pipelines & prediction
# =========================
PIPE_CACHE: Dict[str, Any] = {}
def get_pipeline(repo_id: str, task: str):
key = f"{task}::{repo_id}"
if key in PIPE_CACHE:
return PIPE_CACHE[key]
tok_src = _tokenizer_source(repo_id)
# Robust path: download a full local snapshot (no restrictive allow_patterns)
try:
local_dir = snapshot_download(
repo_id=repo_id,
repo_type="model",
token=HF_TOKEN, # works for public and gated/private (if token has access)
local_files_only=False,
)
if not isinstance(local_dir, str) or not local_dir:
# extremely defensive: fall back to remote id
local_dir = repo_id
except Exception:
local_dir = repo_id # fall back to remote if snapshot fails
if task == "text-classification":
pipe = pipeline(
task,
model=local_dir,
tokenizer=tok_src,
            top_k=None,  # scores for all labels; replaces the deprecated return_all_scores=True
truncation=True,
token=HF_TOKEN,
)
else:
# Add more tasks if you release them later
pipe = pipeline(task, model=local_dir, tokenizer=tok_src, token=HF_TOKEN)
PIPE_CACHE[key] = pipe
return pipe
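# Illustrative call (model name and scores hypothetical):
#   pipe = get_pipeline("mediabiasgroup/example-model", "text-classification")
#   pipe("Some sentence")
#   -> [{"label": "LABEL_0", "score": 0.97}, {"label": "LABEL_1", "score": 0.03}]
# predict() below handles both this flat shape and the nested [[...]] variant.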
def _canonicalize(scores: Dict[str, float]) -> Dict[str, float]:
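    """Map raw labels to canonical names, keeping the max score on collisions.

    Example: {"LABEL_0": 0.9, "NEGATIVE": 0.2} -> {"neutral": 0.9}
    """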
out: Dict[str, float] = {}
for raw_label, sc in scores.items():
lab = CANON.get(raw_label, raw_label)
out[lab] = max(sc, out.get(lab, 0.0))
return out
def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame]:
if not text.strip():
return "Please enter some text.", pd.DataFrame()
if not models:
return f"Please select 1–{MAX_MODELS} models.", pd.DataFrame()
if len(models) > MAX_MODELS:
models = models[:MAX_MODELS]
table_rows: List[Dict[str, Any]] = []
label_union: set[str] = set()
per_model_outputs: Dict[str, Dict[str, float]] = {}
errors: Dict[str, str] = {}
for rid in models:
try:
pipe = get_pipeline(rid, task)
out = pipe(text)
# text-classification pipeline typical shapes:
# [[{label, score}, ...]] or [{label, score}, ...]
if isinstance(out, list) and out and isinstance(out[0], list):
scores = {d["label"]: float(d["score"]) for d in out[0]}
elif isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
scores = {d["label"]: float(d["score"]) for d in out}
else:
scores = {}
scores = _canonicalize(scores) or {"<no_output>": 1.0}
per_model_outputs[rid] = scores
label_union.update(scores.keys())
except Exception as e:
per_model_outputs[rid] = {"<error>": 1.0}
label_union.add("<error>")
errors[rid] = str(e)
# Build table with union of labels as columns
label_cols = sorted(label_union)
for rid in models:
row = {"model": rid}
scores = per_model_outputs.get(rid, {})
for lab in label_cols:
row[lab] = scores.get(lab, 0.0)
if scores:
pred = max(scores.items(), key=lambda kv: kv[1])[0]
row["predicted_label"] = pred
else:
row["predicted_label"] = ""
table_rows.append(row)
pred_df = pd.DataFrame(table_rows, columns=["model"] + label_cols + ["predicted_label"])
msg = f"✓ Done. Compared {len(models)} model(s) on task: `{task}`"
if errors:
msg += "\n\n**Errors**:\n" + "\n".join(f"- {k}: {v}" for k, v in errors.items())
return msg, pred_df
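# Illustrative result (column set depends on the labels the selected models emit):
#   predict(["mediabiasgroup/example-model"], "text-classification", "Some sentence")
#   -> ("✓ Done. Compared 1 model(s) on task: `text-classification`",
#       DataFrame[model, lexical_bias, neutral, predicted_label])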
# =========================
# UI wiring
# =========================
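# Note: refresh_models below is not wired to any UI event; it could back an
# optional manual "Refresh" button.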
def refresh_models(selected_task: str) -> Tuple[List[str], List[str]]:
tasks, task2models = discover_tasks_and_models()
models = task2models.get(selected_task, [])
return tasks, models
def on_task_change(selected_task: str) -> List[str]:
_, task2models = discover_tasks_and_models()
return task2models.get(selected_task, [])
def select_all_models(selected_task: str) -> List[str]:
_, task2models = discover_tasks_and_models()
return task2models.get(selected_task, [])
def build_ui() -> gr.Blocks:
with gr.Blocks(fill_height=True, title="MediaBiasGroup — Model Comparator") as demo:
gr.Markdown(
"# MediaBiasGroup — Model Comparator\n"
"Select a **task**, choose multiple models, enter text, and compare outputs side-by-side."
)
with gr.Row():
with gr.Column(scale=1):
tasks, task2models = discover_tasks_and_models()
task_choices = tasks or [DEFAULT_TASK]
task_default = task_choices[0] if task_choices else DEFAULT_TASK
task_dd = gr.Dropdown(
choices=task_choices,
value=task_default,
label="Task",
)
model_ms = gr.Dropdown(
choices=task2models.get(task_default, []),
multiselect=True,
label="Models",
)
select_all_btn = gr.Button("Select all")
gr.Markdown(f"**Organization:** `{ORG}` \n**Max models per run:** {MAX_MODELS}")
with gr.Column(scale=2):
text_in = gr.Textbox(lines=4, placeholder="Paste a sentence…", label="Input text")
gr.Examples(
examples=[
["The bill passed the House on Tuesday in a 220–210 vote."], # unbiased/factual
["Lawmakers shamelessly rammed the bill through the House on Tuesday."], # biased/loaded
["Unemployment fell from 5.2% to 5.0% in July, according to government figures."],
["The corrupt regime bragged unemployment fell, but it's just cooking the books."],
],
inputs=[text_in],
label="Examples",
)
run_btn = gr.Button("Compare")
status = gr.Markdown("")
# Single wide results table
gr.Markdown("### Predictions")
pred_df = gr.Dataframe(interactive=False)
# Events
task_dd.change(fn=on_task_change, inputs=[task_dd], outputs=[model_ms])
select_all_btn.click(fn=select_all_models, inputs=[task_dd], outputs=[model_ms])
run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df])
return demo
if __name__ == "__main__":
demo = build_ui()
demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))