bitwise31337 committed
Commit cca7ea6 · verified · 1 Parent(s): 02d6bdb

Update app.py

Files changed (1):
  1. app.py +50 -100

app.py CHANGED
@@ -1,10 +1,11 @@
 """
 MediaBiasGroup — Model Comparator (Gradio Space)
-- Discovers models under the org and groups them by pipeline_tag
+- Discovers org models by pipeline_tag
 - Lets users pick a task, select multiple models, and compare outputs on the same input
-- Reads any 'model-index' metrics from model cards and shows them in a table
+- Uses a full local snapshot for robustness (avoids NoneType path issues)
 - Falls back to base_model's tokenizer if a fine-tuned repo lacks tokenizer files
-- Canonicalizes label names across models (e.g., LABEL_0 -> neutral)
+- Canonicalizes label names across models (LABEL_0 -> neutral, etc.)
+- "Select all" button to quickly select all models for the chosen task

 Requirements (see requirements.txt):
     gradio>=4.31.4
@@ -40,7 +41,7 @@ HF_TOKEN = (

 api = HfApi(token=HF_TOKEN)

-# Canonical label mapping (expand as needed)
+# Canonical label mapping (extend if needed)
 CANON = {
     "LABEL_0": "neutral",
     "LABEL_1": "lexical_bias",
@@ -54,9 +55,8 @@ CANON = {
     "lexical_bias": "lexical_bias",
 }

-
 # =========================
-# Discovery & metadata
+# Discovery & card helpers
 # =========================
 @lru_cache(maxsize=1)
 def list_org_models() -> List[Any]:
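As context for the hunks above: at prediction time the app folds raw pipeline labels through `CANON` via `_canonicalize` (its signature is visible in a later hunk; the body is not part of this diff). A minimal sketch of that folding, using only the mapping entries shown in this diff and assuming colliding labels have their scores summed:

```python
from typing import Dict

# Entries copied from the CANON mapping shown above (trimmed).
CANON = {
    "LABEL_0": "neutral",
    "LABEL_1": "lexical_bias",
    "lexical_bias": "lexical_bias",
}

def canonicalize(scores: Dict[str, float]) -> Dict[str, float]:
    # Map each raw label to its canonical name; sum scores that collide.
    out: Dict[str, float] = {}
    for label, score in scores.items():
        canon = CANON.get(label, label)
        out[canon] = out.get(canon, 0.0) + score
    return out

print(canonicalize({"LABEL_0": 0.92, "LABEL_1": 0.08}))
# -> {'neutral': 0.92, 'lexical_bias': 0.08}
```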
@@ -68,15 +68,12 @@ def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]:
     infos = list_org_models()
     task2models: Dict[str, List[str]] = {}
     for info in infos:
-        # Prefer the explicit pipeline_tag
         task = getattr(info, "pipeline_tag", None)
-
-        # Heuristic fallback via tags if pipeline_tag is missing
         if not task:
+            # Heuristic fallback via tags if pipeline_tag is missing
            tags = set(getattr(info, "tags", []) or [])
            if "text-classification" in tags:
                task = "text-classification"
-
         if task:
             task2models.setdefault(task, []).append(info.modelId)

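The discovery loop can be exercised without Hub access; a self-contained sketch with a hypothetical `FakeInfo` standing in for the `huggingface_hub` model-info objects:

```python
from collections import defaultdict

class FakeInfo:
    # Hypothetical stand-in for huggingface_hub model-info objects.
    def __init__(self, model_id, pipeline_tag=None, tags=()):
        self.modelId = model_id
        self.pipeline_tag = pipeline_tag
        self.tags = list(tags)

infos = [
    FakeInfo("org/bias-clf", pipeline_tag="text-classification"),
    FakeInfo("org/legacy-clf", tags=["text-classification"]),  # pipeline_tag missing
    FakeInfo("org/untagged"),                                  # dropped: no task signal
]

task2models = defaultdict(list)
for info in infos:
    task = info.pipeline_tag
    if not task and "text-classification" in info.tags:
        task = "text-classification"  # same heuristic as the hunk above
    if task:
        task2models[task].append(info.modelId)

print(dict(task2models))
# -> {'text-classification': ['org/bias-clf', 'org/legacy-clf']}
```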
@@ -87,45 +84,16 @@ def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]:


 @lru_cache(maxsize=256)
-def get_card_data(repo_id: str) -> dict:
+def get_card_data(repo_id: str) -> Dict[str, Any]:
     try:
         info = api.model_info(repo_id, token=HF_TOKEN)
         data = getattr(info, "cardData", None)
-        return dict(getattr(data, "data", {})) if data else {}
+        if hasattr(data, "data"):  # ModelCardData -> dict
+            return dict(data.data)
+        return data or {}
     except Exception:
         return {}

-def extract_model_index_metrics(repo_id: str) -> pd.DataFrame:
-    data = get_card_data(repo_id)
-    rows: List[Dict[str, Any]] = []
-    if not data:
-        return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value", "repo_id"])
-
-    mi = data.get("model-index") or data.get("model_index") or []
-    for entry in mi:
-        name = entry.get("name", repo_id)
-        for res in entry.get("results", []):
-            task = res.get("task", {})
-            task_type = task.get("type", task.get("name", ""))
-            dset = res.get("dataset", {})
-            dname = dset.get("name", dset.get("type", ""))
-            for m in res.get("metrics", []):
-                rows.append(
-                    {
-                        "model": name,
-                        "dataset": dname,
-                        "task": task_type,
-                        "metric": m.get("name", ""),
-                        "value": m.get("value", None),
-                        "repo_id": repo_id,
-                    }
-                )
-
-    if not rows:
-        return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value", "repo_id"])
-    return pd.DataFrame(rows)
-
-
 # =========================
 # Tokenizer fallback logic
 # =========================
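The `get_card_data` rewrite guards against `cardData` arriving either as an object exposing a `.data` dict or as a plain dict or `None`. A duck-typing sketch of the same normalization; `FakeCardData` is a hypothetical stand-in for the object case:

```python
from typing import Any, Dict

def normalize_card_data(data: Any) -> Dict[str, Any]:
    # Mirrors the new get_card_data body: object-with-.data, dict, or None.
    if hasattr(data, "data"):
        return dict(data.data)
    return data or {}

class FakeCardData:
    # Hypothetical stand-in for a card-data object carrying a .data dict.
    data = {"base_model": "roberta-base"}

print(normalize_card_data(FakeCardData()))      # -> {'base_model': 'roberta-base'}
print(normalize_card_data({"license": "mit"}))  # -> {'license': 'mit'}
print(normalize_card_data(None))                # -> {}
```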
@@ -134,11 +102,14 @@ def _has_tokenizer_files(repo_id: str) -> bool:
         files = set(list_repo_files(repo_id, repo_type="model", token=HF_TOKEN))
     except Exception:
         return False
-    return (
-        "tokenizer.json" in files
-        or {"vocab.json", "merges.txt"}.issubset(files)
-        or "spiece.model" in files
-    )
+
+    if "tokenizer.json" in files:
+        return True
+    if {"vocab.json", "merges.txt"}.issubset(files):
+        return True
+    if "spiece.model" in files:
+        return True
+    return False


 def _base_model_from_card(repo_id: str) -> str | None:
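The three probes in the rewritten `_has_tokenizer_files` correspond to common tokenizer layouts: `tokenizer.json` (fast tokenizers), `vocab.json` plus `merges.txt` (BPE, GPT-2/RoBERTa style), and `spiece.model` (SentencePiece). A network-free sketch of the same check against an in-memory file listing:

```python
def has_tokenizer_files(files: set) -> bool:
    # Same checks as the hunk above, applied to a plain listing.
    if "tokenizer.json" in files:                     # fast tokenizer
        return True
    if {"vocab.json", "merges.txt"}.issubset(files):  # BPE vocab + merges
        return True
    if "spiece.model" in files:                       # SentencePiece model
        return True
    return False

print(has_tokenizer_files({"config.json", "model.safetensors"}))         # -> False
print(has_tokenizer_files({"config.json", "vocab.json", "merges.txt"}))  # -> True
```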
@@ -150,13 +121,12 @@ def _base_model_from_card(repo_id: str) -> str | None:


 def _tokenizer_source(repo_id: str) -> str:
-    # prefer repo tokenizer; else fall back to base_model; else repo_id
+    # Prefer repo tokenizer; else fall back to base_model; else repo_id
     if _has_tokenizer_files(repo_id):
         return repo_id
     base = _base_model_from_card(repo_id)
     return base or repo_id

-
 # =========================
 # Pipelines & prediction
 # =========================
@@ -170,27 +140,20 @@ def get_pipeline(repo_id: str, task: str):

     tok_src = _tokenizer_source(repo_id)

-    # 1) Download a local snapshot of the model repo (robust to xet/LFS)
+    # Robust path: download a full local snapshot (no restrictive allow_patterns)
     try:
         local_dir = snapshot_download(
-            repo_id,
-            allow_patterns=[
-                "config.json",
-                "*.safetensors",
-                "*.bin",
-                "tokenizer.json",
-                "vocab.json",
-                "merges.txt",
-                "tokenizer_config.json",
-                "special_tokens_map.json",
-            ],
-            token=HF_TOKEN,  # fine if None for public repos
+            repo_id=repo_id,
+            repo_type="model",
+            token=HF_TOKEN,  # works for public and gated/private (if token has access)
+            local_files_only=False,
         )
-    except Exception as e:
-        # Fallback: try remote loading if snapshot fails
-        local_dir = repo_id
+        if not isinstance(local_dir, str) or not local_dir:
+            # extremely defensive: fall back to remote id
+            local_dir = repo_id
+    except Exception:
+        local_dir = repo_id  # fall back to remote if snapshot fails

-    # 2) Build pipeline; still pass tokenizer source (repo or base_model)
     if task == "text-classification":
         pipe = pipeline(
             task,
@@ -201,6 +164,7 @@
             token=HF_TOKEN,
         )
     else:
+        # Add more tasks if you release them later
         pipe = pipeline(task, model=local_dir, tokenizer=tok_src, token=HF_TOKEN)

     PIPE_CACHE[key] = pipe
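Condensed, the new loading path is: full snapshot first, the remote repo id as fallback, tokenizer source threaded through either way. A sketch of the pattern under the same assumptions (public repo, or a token with access); the function name and repo ids are illustrative only:

```python
from huggingface_hub import snapshot_download
from transformers import pipeline

def load_text_classifier(repo_id: str, tokenizer_src: str, token=None):
    # Full snapshot (no allow_patterns), so architecture-specific files
    # such as spiece.model are never filtered out of the local copy.
    try:
        local_dir = snapshot_download(repo_id=repo_id, repo_type="model", token=token)
    except Exception:
        local_dir = repo_id  # let transformers resolve the repo remotely
    return pipeline("text-classification", model=local_dir,
                    tokenizer=tokenizer_src, token=token)

# pipe = load_text_classifier("org/some-model", "org/some-model")  # illustrative ids
```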
@@ -215,11 +179,11 @@ def _canonicalize(scores: Dict[str, float]) -> Dict[str, float]:
     return out


-def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
+def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame]:
     if not text.strip():
-        return "Please enter some text.", pd.DataFrame(), pd.DataFrame()
+        return "Please enter some text.", pd.DataFrame()
     if not models:
-        return f"Please select 1–{MAX_MODELS} models.", pd.DataFrame(), pd.DataFrame()
+        return f"Please select 1–{MAX_MODELS} models.", pd.DataFrame()
     if len(models) > MAX_MODELS:
         models = models[:MAX_MODELS]

@@ -233,13 +197,11 @@ def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame,
             pipe = get_pipeline(rid, task)
             out = pipe(text)

-            # text-classification pipeline:
-            # typical shape: [ [ {label, score}, ... ] ] or [ {label, score}, ... ]
-            scores: Dict[str, float]
+            # text-classification pipeline typical shapes:
+            # [[{label, score}, ...]] or [{label, score}, ...]
             if isinstance(out, list) and out and isinstance(out[0], list):
                 scores = {d["label"]: float(d["score"]) for d in out[0]}
             elif isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
-                # some classifiers return flat list
                 scores = {d["label"]: float(d["score"]) for d in out}
             else:
                 scores = {}
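The shape check exists because the text-classification pipeline can return either a flat list of `{label, score}` dicts or a nested list-of-lists, depending on the transformers version and call arguments. A standalone sketch of the normalization, exercised on both shapes:

```python
from typing import Dict

def to_scores(out) -> Dict[str, float]:
    # Accept both shapes: [[{label, score}, ...]] and [{label, score}, ...].
    if isinstance(out, list) and out and isinstance(out[0], list):
        return {d["label"]: float(d["score"]) for d in out[0]}
    if isinstance(out, list) and out and isinstance(out[0], dict) and "label" in out[0]:
        return {d["label"]: float(d["score"]) for d in out}
    return {}

print(to_scores([[{"label": "LABEL_0", "score": 0.91}]]))  # -> {'LABEL_0': 0.91}
print(to_scores([{"label": "LABEL_1", "score": 0.6}]))     # -> {'LABEL_1': 0.6}
```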
@@ -269,22 +231,11 @@ def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame,

     pred_df = pd.DataFrame(table_rows, columns=["model"] + label_cols + ["predicted_label"])

-    # Collect reported metrics if present
-    metrics_frames = []
-    for rid in models:
-        df = extract_model_index_metrics(rid)
-        if not df.empty:
-            df = df.copy()
-            df.insert(0, "repo_id", rid)
-            metrics_frames.append(df)
-    metrics_df = pd.concat(metrics_frames, ignore_index=True) if metrics_frames else pd.DataFrame()
-
     msg = f"✓ Done. Compared {len(models)} model(s) on task: `{task}`"
     if errors:
         msg += "\n\n**Errors**:\n" + "\n".join(f"- {k}: {v}" for k, v in errors.items())

-    return msg, pred_df, metrics_df
-
+    return msg, pred_df

 # =========================
 # UI wiring
@@ -300,12 +251,16 @@ def on_task_change(selected_task: str) -> List[str]:
     return task2models.get(selected_task, [])


+def select_all_models(selected_task: str) -> List[str]:
+    _, task2models = discover_tasks_and_models()
+    return task2models.get(selected_task, [])
+
+
 def build_ui() -> gr.Blocks:
     with gr.Blocks(fill_height=True, title="MediaBiasGroup — Model Comparator") as demo:
         gr.Markdown(
             "# MediaBiasGroup — Model Comparator\n"
-            "Select a **task**, choose multiple models, enter text, and compare outputs side-by-side. "
-            "If models provide a `model-index` in their cards, reported metrics appear below."
+            "Select a **task**, choose multiple models, enter text, and compare outputs side-by-side."
         )

         with gr.Row():
@@ -324,12 +279,12 @@ def build_ui() -> gr.Blocks:
                     multiselect=True,
                     label="Models",
                 )
-                refresh_btn = gr.Button("🔄 Refresh list from Hub")
+                select_all_btn = gr.Button("Select all")
                 gr.Markdown(f"**Organization:** `{ORG}` \n**Max models per run:** {MAX_MODELS}")

             with gr.Column(scale=2):
                 text_in = gr.Textbox(lines=4, placeholder="Paste a sentence…", label="Input text")
-                examples = gr.Examples(
+                gr.Examples(
                     examples=[
                         ["The bill passed the House on Tuesday in a 220–210 vote."],  # unbiased/factual
                         ["Lawmakers shamelessly rammed the bill through the House on Tuesday."],  # biased/loaded
@@ -342,23 +297,18 @@ def build_ui() -> gr.Blocks:
                 run_btn = gr.Button("Compare")
                 status = gr.Markdown("")

-        with gr.Row():
-            with gr.Column():
-                gr.Markdown("### Predictions")
-                pred_df = gr.Dataframe(interactive=False)
-            with gr.Column():
-                gr.Markdown("### Reported metrics (from model cards)")
-                metrics_df = gr.Dataframe(interactive=False)
+        # Single wide results table
+        gr.Markdown("### Predictions")
+        pred_df = gr.Dataframe(interactive=False)

         # Events
         task_dd.change(fn=on_task_change, inputs=[task_dd], outputs=[model_ms])
-        refresh_btn.click(fn=refresh_models, inputs=[task_dd], outputs=[task_dd, model_ms])
-        run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df, metrics_df])
+        select_all_btn.click(fn=select_all_models, inputs=[task_dd], outputs=[model_ms])
+        run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df])

     return demo


 if __name__ == "__main__":
     demo = build_ui()
-    # queue() allows concurrent requests; adjust concurrency per Space hardware
     demo.queue(max_size=16).launch(server_name="0.0.0.0", server_port=int(os.environ.get("PORT", 7860)))