bitwise31337 committed on
Commit a971b64 · verified · 1 Parent(s): 507f847

Update app.py

Files changed (1): app.py +227 −80
app.py CHANGED
@@ -1,56 +1,103 @@
- import gradio as gr
- import pandas as pd
- from typing import List, Dict, Any, Tuple
  from functools import lru_cache
- from huggingface_hub import HfApi
  from transformers import pipeline

  ORG = "mediabiasgroup"
  DEFAULT_TASK = "text-classification"
  MAX_MODELS = 10  # safety cap to avoid loading too many models at once on CPU Spaces

  api = HfApi()

  @lru_cache(maxsize=1)
  def list_org_models() -> List[Any]:
      # full=True to fetch pipeline_tag & tags
      return list(api.list_models(author=ORG, full=True))

  def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]:
      infos = list_org_models()
      task2models: Dict[str, List[str]] = {}
      for info in infos:
          task = getattr(info, "pipeline_tag", None)
          if not task:
-             # Try to infer from tags if missing
              tags = set(getattr(info, "tags", []) or [])
-             # Very light heuristic; expand if you add other task types later
              if "text-classification" in tags:
                  task = "text-classification"
          if task:
              task2models.setdefault(task, []).append(info.modelId)
-     tasks = sorted(task2models.keys())
-     # Keep deterministic sorting of model ids within each task
      for t in task2models:
          task2models[t] = sorted(task2models[t])
      return tasks, task2models

  @lru_cache(maxsize=256)
  def get_card_data(repo_id: str) -> Dict[str, Any]:
      try:
          info = api.model_info(repo_id)
-         # .cardData is already a parsed dict when available
          data = getattr(info, "cardData", None)
          return data or {}
      except Exception:
          return {}

  def extract_model_index_metrics(repo_id: str) -> pd.DataFrame:
      data = get_card_data(repo_id)
-     rows = []
      if not data:
-         return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value"])
      mi = data.get("model-index") or data.get("model_index") or []
      for entry in mi:
          name = entry.get("name", repo_id)
@@ -60,83 +107,147 @@ def extract_model_index_metrics(repo_id: str) -> pd.DataFrame:
          dset = res.get("dataset", {})
          dname = dset.get("name", dset.get("type", ""))
          for m in res.get("metrics", []):
-             rows.append({
-                 "model": name,
-                 "dataset": dname,
-                 "task": task_type,
-                 "metric": m.get("name", ""),
-                 "value": m.get("value", None),
-                 "repo_id": repo_id
-             })
      if not rows:
-         return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value"])
-     df = pd.DataFrame(rows)
-     # Optional: pivot for nicer viewing in the UI
-     return df

- # Lazy-loaded pipelines cache
  PIPE_CACHE: Dict[str, Any] = {}

  def get_pipeline(repo_id: str, task: str):
      key = f"{task}::{repo_id}"
      if key in PIPE_CACHE:
          return PIPE_CACHE[key]
-     # Use return_all_scores=True so we can compare per-label scores
      if task == "text-classification":
-         pipe = pipeline(task, model=repo_id, tokenizer=repo_id, return_all_scores=True, truncation=True)
      else:
-         # Add more pipelines if you start supporting other tasks
-         pipe = pipeline(task, model=repo_id, tokenizer=repo_id)
      PIPE_CACHE[key] = pipe
      return pipe

  def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
      if not text.strip():
          return "Please enter some text.", pd.DataFrame(), pd.DataFrame()
      if not models:
-         return "Please select 1–{} models.".format(MAX_MODELS), pd.DataFrame(), pd.DataFrame()
      if len(models) > MAX_MODELS:
          models = models[:MAX_MODELS]
-
-     # Run inference
-     table_rows = []
-     label_union = set()
-     per_model_outputs = {}
-
      for rid in models:
          try:
              pipe = get_pipeline(rid, task)
              out = pipe(text)
-             # text-classification returns: [ [ {label, score}, ... ] ]
-             if isinstance(out, list) and len(out) and isinstance(out[0], list):
                  scores = {d["label"]: float(d["score"]) for d in out[0]}
-             elif isinstance(out, list) and len(out) and isinstance(out[0], dict) and "label" in out[0]:
-                 # Some classifiers return top-1 only
-                 scores = {out[0]["label"]: float(out[0]["score"])}
              else:
                  scores = {}
              per_model_outputs[rid] = scores
              label_union.update(scores.keys())
          except Exception as e:
-             per_model_outputs[rid] = {"<error>": 0.0}
              label_union.add("<error>")
-
-     # Build a nice table with union of labels as columns
      label_cols = sorted(label_union)
      for rid in models:
          row = {"model": rid}
          scores = per_model_outputs.get(rid, {})
          for lab in label_cols:
              row[lab] = scores.get(lab, 0.0)
-         # Also record the predicted (argmax) label if present
          if scores:
              pred = max(scores.items(), key=lambda kv: kv[1])[0]
              row["predicted_label"] = pred
          else:
              row["predicted_label"] = ""
          table_rows.append(row)
      pred_df = pd.DataFrame(table_rows, columns=["model"] + label_cols + ["predicted_label"])
-
      # Collect reported metrics if present
      metrics_frames = []
      for rid in models:
@@ -146,51 +257,87 @@ def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame,
          df.insert(0, "repo_id", rid)
          metrics_frames.append(df)
      metrics_df = pd.concat(metrics_frames, ignore_index=True) if metrics_frames else pd.DataFrame()
-
-     msg = "✓ Done. Compared {} model(s) on task: `{}`".format(len(models), task)
      return msg, pred_df, metrics_df

  def refresh_models(selected_task: str) -> Tuple[List[str], List[str]]:
      tasks, task2models = discover_tasks_and_models()
      models = task2models.get(selected_task, [])
      return tasks, models

  def on_task_change(selected_task: str) -> List[str]:
      _, task2models = discover_tasks_and_models()
      return task2models.get(selected_task, [])

- with gr.Blocks(fill_height=True, title="MediaBiasGroup — Model Comparator") as demo:
-     gr.Markdown(
-         "# MediaBiasGroup — Model Comparator\n"
-         "Select a **task**, choose multiple models, enter text, and compare outputs side-by-side. "
-         "If models provide a `model-index` in their cards, reported metrics are shown below."
-     )
-     with gr.Row():
-         with gr.Column(scale=1):
-             tasks, task2models = discover_tasks_and_models()
-             task_dd = gr.Dropdown(choices=tasks or [DEFAULT_TASK], value=(tasks[0] if tasks else DEFAULT_TASK), label="Task")
-             model_ms = gr.Dropdown(choices=task2models.get(tasks[0], []) if tasks else [], multiselect=True, label="Models")
-             refresh_btn = gr.Button("🔄 Refresh list from Hub")
-             gr.Markdown(
-                 f"**Organization:** `{ORG}` \n"
-                 f"**Max models per run:** {MAX_MODELS}"
-             )
-         with gr.Column(scale=2):
-             text_in = gr.Textbox(lines=4, placeholder="Paste a sentence…", label="Input text")
-             run_btn = gr.Button("Compare")
-             status = gr.Markdown("")
-     with gr.Row():
-         with gr.Column():
-             gr.Markdown("### Predictions")
-             pred_df = gr.Dataframe(wrap=True)
-         with gr.Column():
-             gr.Markdown("### Reported metrics (from model cards)")
-             metrics_df = gr.Dataframe(wrap=True)
-
-     # Events wiring
-     task_dd.change(fn=on_task_change, inputs=[task_dd], outputs=[model_ms])
-     refresh_btn.click(fn=refresh_models, inputs=[task_dd], outputs=[task_dd, model_ms])
-     run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df, metrics_df])

  if __name__ == "__main__":
-     demo.launch()
+ """
2
+ MediaBiasGroup Model Comparator (Gradio Space)
3
+ - Discovers models under the org and groups them by pipeline_tag
4
+ - Lets users pick a task, select multiple models, and compare outputs on the same input
5
+ - Reads any 'model-index' metrics from model cards and shows them in a table
6
+ - Falls back to base_model's tokenizer if a fine-tuned repo lacks tokenizer files
7
+ - Canonicalizes label names across models (e.g., LABEL_0 -> neutral)
8
+
9
+ Requirements (see requirements.txt):
10
+ gradio>=4.31.4
11
+ transformers>=4.42.0
12
+ huggingface_hub>=0.23.0
13
+ torch>=2.2.0
14
+ pandas>=2.0.0
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import os
20
  from functools import lru_cache
21
+ from typing import Any, Dict, List, Tuple
22
 
23
+ import gradio as gr
24
+ import pandas as pd
25
+ from huggingface_hub import HfApi, list_repo_files
26
  from transformers import pipeline
27
 
28
+ # =========================
29
+ # Configuration
30
+ # =========================
31
  ORG = "mediabiasgroup"
32
  DEFAULT_TASK = "text-classification"
33
  MAX_MODELS = 10 # safety cap to avoid loading too many models at once on CPU Spaces
34
 
35
  api = HfApi()
36
 
37
+ # Canonical label mapping (expand as needed)
38
+ CANON = {
39
+ "LABEL_0": "neutral",
40
+ "LABEL_1": "lexical_bias",
41
+ "NEGATIVE": "neutral",
42
+ "POSITIVE": "lexical_bias",
43
+ "neutral": "neutral",
44
+ "not_biased": "neutral",
45
+ "non-biased": "neutral",
46
+ "unbiased": "neutral",
47
+ "biased": "lexical_bias",
48
+ "lexical_bias": "lexical_bias",
49
+ }
50
+
51
+
52
+ # =========================
53
+ # Discovery & metadata
54
+ # =========================
55
  @lru_cache(maxsize=1)
56
  def list_org_models() -> List[Any]:
57
  # full=True to fetch pipeline_tag & tags
58
  return list(api.list_models(author=ORG, full=True))
59
 
60
+
61
  def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]:
62
  infos = list_org_models()
63
  task2models: Dict[str, List[str]] = {}
64
  for info in infos:
65
+ # Prefer the explicit pipeline_tag
66
  task = getattr(info, "pipeline_tag", None)
67
+
68
+ # Heuristic fallback via tags if pipeline_tag is missing
69
  if not task:
 
70
  tags = set(getattr(info, "tags", []) or [])
 
71
  if "text-classification" in tags:
72
  task = "text-classification"
73
+
74
  if task:
75
  task2models.setdefault(task, []).append(info.modelId)
76
+
77
+ tasks = sorted(task2models.keys()) or [DEFAULT_TASK]
78
  for t in task2models:
79
  task2models[t] = sorted(task2models[t])
80
  return tasks, task2models
81
 
82
+
83
  @lru_cache(maxsize=256)
84
  def get_card_data(repo_id: str) -> Dict[str, Any]:
85
  try:
86
  info = api.model_info(repo_id)
 
87
  data = getattr(info, "cardData", None)
88
+ if hasattr(data, "data"): # ModelCardData -> dict
89
+ return dict(data.data)
90
  return data or {}
91
  except Exception:
92
  return {}
93
 
94
+
95
  def extract_model_index_metrics(repo_id: str) -> pd.DataFrame:
96
  data = get_card_data(repo_id)
97
+ rows: List[Dict[str, Any]] = []
98
  if not data:
99
+ return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value", "repo_id"])
100
+
101
  mi = data.get("model-index") or data.get("model_index") or []
102
  for entry in mi:
103
  name = entry.get("name", repo_id)
 
107
  dset = res.get("dataset", {})
108
  dname = dset.get("name", dset.get("type", ""))
109
  for m in res.get("metrics", []):
110
+ rows.append(
111
+ {
112
+ "model": name,
113
+ "dataset": dname,
114
+ "task": task_type,
115
+ "metric": m.get("name", ""),
116
+ "value": m.get("value", None),
117
+ "repo_id": repo_id,
118
+ }
119
+ )
120
+
121
  if not rows:
122
+ return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value", "repo_id"])
123
+ return pd.DataFrame(rows)
124
+
125
+
126
+ # =========================
127
+ # Tokenizer fallback logic
128
+ # =========================
129
+ def _has_tokenizer_files(repo_id: str) -> bool:
130
+ try:
131
+ files = set(list_repo_files(repo_id, repo_type="model"))
132
+ except Exception:
133
+ return False
134
+
135
+ if "tokenizer.json" in files:
136
+ return True
137
+ if {"vocab.json", "merges.txt"}.issubset(files):
138
+ return True
139
+ if "spiece.model" in files:
140
+ return True
141
+ return False
142
+
143
+
144
+ def _base_model_from_card(repo_id: str) -> str | None:
145
+ data = get_card_data(repo_id) or {}
146
+ base = data.get("base_model")
147
+ if isinstance(base, list):
148
+ base = base[0] if base else None
149
+ return base
150
+
151
+
152
+ def _tokenizer_source(repo_id: str) -> str:
153
+ # prefer repo tokenizer; else fall back to base_model; else repo_id
154
+ if _has_tokenizer_files(repo_id):
155
+ return repo_id
156
+ base = _base_model_from_card(repo_id)
157
+ return base or repo_id
158
+
159
 
160
+ # =========================
161
+ # Pipelines & prediction
162
+ # =========================
163
  PIPE_CACHE: Dict[str, Any] = {}
164
 
165
+
166
  def get_pipeline(repo_id: str, task: str):
167
  key = f"{task}::{repo_id}"
168
  if key in PIPE_CACHE:
169
  return PIPE_CACHE[key]
170
+
171
+ tok_src = _tokenizer_source(repo_id)
172
+
173
  if task == "text-classification":
174
+ pipe = pipeline(
175
+ task,
176
+ model=repo_id,
177
+ tokenizer=tok_src,
178
+ return_all_scores=True,
179
+ truncation=True,
180
+ )
181
  else:
182
+ # Add more tasks if you release them later
183
+ pipe = pipeline(task, model=repo_id, tokenizer=tok_src)
184
+
185
  PIPE_CACHE[key] = pipe
186
  return pipe
187
 
188
+
189
+ def _canonicalize(scores: Dict[str, float]) -> Dict[str, float]:
190
+ out: Dict[str, float] = {}
191
+ for raw_label, sc in scores.items():
192
+ lab = CANON.get(raw_label, raw_label)
193
+ out[lab] = max(sc, out.get(lab, 0.0))
194
+ return out
195
+
196
+
197
  def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
198
  if not text.strip():
199
  return "Please enter some text.", pd.DataFrame(), pd.DataFrame()
200
  if not models:
201
+ return f"Please select 1–{MAX_MODELS} models.", pd.DataFrame(), pd.DataFrame()
202
  if len(models) > MAX_MODELS:
203
  models = models[:MAX_MODELS]
204
+
205
+ table_rows: List[Dict[str, Any]] = []
206
+ label_union: set[str] = set()
207
+ per_model_outputs: Dict[str, Dict[str, float]] = {}
208
+ errors: Dict[str, str] = {}
209
+
210
  for rid in models:
211
  try:
212
  pipe = get_pipeline(rid, task)
213
  out = pipe(text)
214
+
215
+ # text-classification pipeline:
216
+ # typical shape: [ [ {label, score}, ... ] ] or [ {label, score}, ... ]
217
+ scores: Dict[str, float]
218
+ if isinstance(out, list) and out and isinstance(out[0], list):
219
  scores = {d["label"]: float(d["score"]) for d in out[0]}
220
+ elif isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
221
+ # some classifiers return flat list
222
+ scores = {d["label"]: float(d["score"]) for d in out}
223
  else:
224
  scores = {}
225
+
226
+ scores = _canonicalize(scores) or {"<no_output>": 1.0}
227
  per_model_outputs[rid] = scores
228
  label_union.update(scores.keys())
229
+
230
  except Exception as e:
231
+ per_model_outputs[rid] = {"<error>": 1.0}
232
  label_union.add("<error>")
233
+ errors[rid] = str(e)
234
+
235
+ # Build table with union of labels as columns
236
  label_cols = sorted(label_union)
237
  for rid in models:
238
  row = {"model": rid}
239
  scores = per_model_outputs.get(rid, {})
240
  for lab in label_cols:
241
  row[lab] = scores.get(lab, 0.0)
 
242
  if scores:
243
  pred = max(scores.items(), key=lambda kv: kv[1])[0]
244
  row["predicted_label"] = pred
245
  else:
246
  row["predicted_label"] = ""
247
  table_rows.append(row)
248
+
249
  pred_df = pd.DataFrame(table_rows, columns=["model"] + label_cols + ["predicted_label"])
250
+
251
  # Collect reported metrics if present
252
  metrics_frames = []
253
  for rid in models:
 
257
  df.insert(0, "repo_id", rid)
258
  metrics_frames.append(df)
259
  metrics_df = pd.concat(metrics_frames, ignore_index=True) if metrics_frames else pd.DataFrame()
260
+
261
+ msg = f"✓ Done. Compared {len(models)} model(s) on task: `{task}`"
262
+ if errors:
263
+ msg += "\n\n**Errors**:\n" + "\n".join(f"- {k}: {v}" for k, v in errors.items())
264
+
265
  return msg, pred_df, metrics_df
266
 
267
+
268
+ # =========================
269
+ # UI wiring
270
+ # =========================
271
  def refresh_models(selected_task: str) -> Tuple[List[str], List[str]]:
272
  tasks, task2models = discover_tasks_and_models()
273
  models = task2models.get(selected_task, [])
274
  return tasks, models
275
 
276
+
277
  def on_task_change(selected_task: str) -> List[str]:
278
  _, task2models = discover_tasks_and_models()
279
  return task2models.get(selected_task, [])
280
 
281
+
282
+ def build_ui() -> gr.Blocks:
283
+ with gr.Blocks(fill_height=True, title="MediaBiasGroup — Model Comparator") as demo:
284
+ gr.Markdown(
285
+ "# MediaBiasGroup Model Comparator\n"
286
+ "Select a **task**, choose multiple models, enter text, and compare outputs side-by-side. "
287
+ "If models provide a `model-index` in their cards, reported metrics appear below."
288
+ )
289
+
290
+ with gr.Row():
291
+ with gr.Column(scale=1):
292
+ tasks, task2models = discover_tasks_and_models()
293
+ task_choices = tasks or [DEFAULT_TASK]
294
+ task_default = task_choices[0] if task_choices else DEFAULT_TASK
295
+
296
+ task_dd = gr.Dropdown(
297
+ choices=task_choices,
298
+ value=task_default,
299
+ label="Task",
300
+ )
301
+ model_ms = gr.Dropdown(
302
+ choices=task2models.get(task_default, []),
303
+ multiselect=True,
304
+ label="Models",
305
+ )
306
+ refresh_btn = gr.Button("🔄 Refresh list from Hub")
307
+ gr.Markdown(f"**Organization:** `{ORG}` \n**Max models per run:** {MAX_MODELS}")
308
+
309
+ with gr.Column(scale=2):
310
+ text_in = gr.Textbox(lines=4, placeholder="Paste a sentence…", label="Input text")
311
+ examples = gr.Examples(
312
+ examples=[
313
+ ["The bill passed the House on Tuesday in a 220–210 vote."], # unbiased/factual
314
+ ["Lawmakers shamelessly rammed the bill through the House on Tuesday."], # biased/loaded
315
+ ["Unemployment fell from 5.2% to 5.0% in July, according to government figures."],
316
+ ["The corrupt regime bragged unemployment fell, but it's just cooking the books."],
317
+ ],
318
+ inputs=[text_in],
319
+ label="Examples",
320
+ )
321
+ run_btn = gr.Button("Compare")
322
+ status = gr.Markdown("")
323
+
324
+ with gr.Row():
325
+ with gr.Column():
326
+ gr.Markdown("### Predictions")
327
+ pred_df = gr.Dataframe(interactive=False)
328
+ with gr.Column():
329
+ gr.Markdown("### Reported metrics (from model cards)")
330
+ metrics_df = gr.Dataframe(interactive=False)
331
+
332
+ # Events
333
+ task_dd.change(fn=on_task_change, inputs=[task_dd], outputs=[model_ms])
334
+ refresh_btn.click(fn=refresh_models, inputs=[task_dd], outputs=[task_dd, model_ms])
335
+ run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df, metrics_df])
336
+
337
+ return demo
338
+
339
 
340
  if __name__ == "__main__":
341
+ demo = build_ui()
342
+ # queue() allows concurrent requests; adjust concurrency per Space hardware
343
+ demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))
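
A note on the label canonicalization this commit introduces: models fine-tuned at different times expose different label schemes (generic LABEL_0/LABEL_1, biased/not_biased, even sentiment-style POSITIVE/NEGATIVE), and `_canonicalize` folds them onto shared columns so the comparison table lines up. A minimal, self-contained sketch of the behaviour; the score values are invented for illustration:

    from typing import Dict

    # Trimmed copy of the mapping and helper from app.py above.
    CANON = {
        "LABEL_0": "neutral",
        "LABEL_1": "lexical_bias",
        "biased": "lexical_bias",
        "not_biased": "neutral",
    }

    def _canonicalize(scores: Dict[str, float]) -> Dict[str, float]:
        out: Dict[str, float] = {}
        for raw_label, sc in scores.items():
            lab = CANON.get(raw_label, raw_label)  # unknown labels pass through
            out[lab] = max(sc, out.get(lab, 0.0))  # keep the max on collisions
        return out

    print(_canonicalize({"LABEL_0": 0.91, "LABEL_1": 0.09}))
    # {'neutral': 0.91, 'lexical_bias': 0.09}
    print(_canonicalize({"biased": 0.72, "not_biased": 0.28}))
    # {'lexical_bias': 0.72, 'neutral': 0.28}

The `max` matters when two raw labels map to the same canonical label; taking the larger score is a conservative merge.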
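
The metrics table is driven entirely by card metadata. A hedged sketch of the round trip, with every name and value invented: if `get_card_data()` returns a parsed `model-index` like the dict below, `extract_model_index_metrics()` emits one row per metric (`task_type` is read from each result's `task` entry in the unchanged part of the function):

    # Hypothetical parsed card data; all names and values are invented.
    card = {
        "model-index": [{
            "name": "hypothetical-bias-classifier",
            "results": [{
                "task": {"type": "text-classification"},
                "dataset": {"name": "some-bias-dataset", "type": "text"},
                "metrics": [{"name": "F1", "value": 0.88}],
            }],
        }]
    }
    # extract_model_index_metrics() would yield a single DataFrame row:
    #   model=hypothetical-bias-classifier, dataset=some-bias-dataset,
    #   task=text-classification, metric=F1, value=0.88, repo_id=<the repo queried>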
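
The tokenizer fallback resolves in a fixed order: the repo's own tokenizer files if any exist (`tokenizer.json`, the `vocab.json` + `merges.txt` pair, or `spiece.model`), then the `base_model` declared in the card, then the repo id itself as a last resort. A condensed sketch of that decision, with the Hub lookups stubbed out and a hypothetical repo id:

    from typing import Optional

    def resolve_tokenizer(repo_id: str, has_tokenizer_files: bool, base_model: Optional[str]) -> str:
        # Mirrors _tokenizer_source() above.
        if has_tokenizer_files:   # 1) tokenizer shipped with the fine-tuned repo
            return repo_id
        if base_model:            # 2) base_model named in the model card
            return base_model
        return repo_id            # 3) let transformers try the repo anyway

    # A fine-tune without tokenizer files whose card says base_model: roberta-base
    # (the repo id here is hypothetical) resolves to the base model's tokenizer:
    assert resolve_tokenizer("mediabiasgroup/hypothetical-finetune", False, "roberta-base") == "roberta-base"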
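
A forward-looking caveat: `return_all_scores=True` still works on the pinned transformers line, but newer releases deprecate it in favour of `top_k`, where `top_k=None` returns all per-label scores. If the pin is ever bumped, the classification branch of `get_pipeline()` could read as the sketch below; `predict()` already tolerates both the nested and the flat output shapes, so the parsing code would not need to change.

    from transformers import pipeline

    # Sketch, assuming a transformers version where return_all_scores is
    # deprecated; top_k=None yields scores for every label.
    pipe = pipeline(
        "text-classification",
        model="mediabiasgroup/hypothetical-model",  # hypothetical repo id
        top_k=None,
        truncation=True,
    )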
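
Finally, one behaviour a reviewer might flag: the 🔄 button re-runs `discover_tasks_and_models()`, but `list_org_models()` is memoized with `@lru_cache(maxsize=1)`, so the Hub is never actually re-queried after the first call. Clearing the cache first would make the refresh live; a minimal sketch, reusing app.py's imports and helpers:

    from typing import List, Tuple

    def refresh_models(selected_task: str) -> Tuple[List[str], List[str]]:
        # Drop the memoized org listing so the next discovery call hits the Hub again.
        list_org_models.cache_clear()
        tasks, task2models = discover_tasks_and_models()
        return tasks, task2models.get(selected_task, [])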