bitwise31337 committed
Commit 904efce · verified · 1 Parent(s): 0479029

Create app.py

Files changed (1)
  1. app.py +197 -0
app.py ADDED
@@ -0,0 +1,197 @@
import gradio as gr
import pandas as pd
from typing import List, Dict, Any, Tuple
from functools import lru_cache

from huggingface_hub import HfApi
from transformers import pipeline

ORG = "mediabiasgroup"
DEFAULT_TASK = "text-classification"
MAX_MODELS = 10  # safety cap to avoid loading too many models at once on CPU Spaces

api = HfApi()
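# Note: HfApi() works anonymously for public repos; a token is only needed if the
# organization's models are private or gated (assumption: these repos are public).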

@lru_cache(maxsize=1)
def list_org_models() -> List[Any]:
    # full=True to fetch pipeline_tag & tags
    return list(api.list_models(author=ORG, full=True))

def discover_tasks_and_models() -> Tuple[List[str], Dict[str, List[str]]]:
    infos = list_org_models()
    task2models: Dict[str, List[str]] = {}
    for info in infos:
        task = getattr(info, "pipeline_tag", None)
        if not task:
            # Try to infer from tags if missing
            tags = set(getattr(info, "tags", []) or [])
            # Very light heuristic; expand if you add other task types later
            if "text-classification" in tags:
                task = "text-classification"
        if task:
            task2models.setdefault(task, []).append(info.modelId)
    tasks = sorted(task2models.keys())
    # Keep deterministic sorting of model ids within each task
    for t in task2models:
        task2models[t] = sorted(task2models[t])
    return tasks, task2models
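
# For reference, the returned structures look roughly like this (model ids are
# illustrative placeholders, not actual repo names):
#   tasks       -> ["text-classification"]
#   task2models -> {"text-classification": ["mediabiasgroup/<model-a>", "mediabiasgroup/<model-b>"]}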

@lru_cache(maxsize=256)
def get_card_data(repo_id: str) -> Dict[str, Any]:
    try:
        info = api.model_info(repo_id)
        # In recent huggingface_hub versions the card data is a ModelCardData object,
        # not a plain dict, so convert it before the .get() lookups downstream.
        data = getattr(info, "card_data", None) or getattr(info, "cardData", None)
        if data is None:
            return {}
        return data.to_dict() if hasattr(data, "to_dict") else dict(data)
    except Exception:
        return {}

def extract_model_index_metrics(repo_id: str) -> pd.DataFrame:
    data = get_card_data(repo_id)
    rows = []
    if not data:
        return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value"])
    mi = data.get("model-index") or data.get("model_index") or []
    for entry in mi:
        name = entry.get("name", repo_id)
        for res in entry.get("results", []):
            task = res.get("task", {})
            task_type = task.get("type", task.get("name", ""))
            dset = res.get("dataset", {})
            dname = dset.get("name", dset.get("type", ""))
            for m in res.get("metrics", []):
                rows.append({
                    "model": name,
                    "dataset": dname,
                    "task": task_type,
                    "metric": m.get("name", ""),
                    "value": m.get("value", None),
                    "repo_id": repo_id
                })
    if not rows:
        return pd.DataFrame(columns=["model", "dataset", "task", "metric", "value"])
    df = pd.DataFrame(rows)
    # Optional: pivot for nicer viewing in the UI
    return df
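
# The parser above expects the standard `model-index` card metadata, roughly of the
# following shape (dataset and metric names below are illustrative):
#
#   model-index:
#   - name: my-model
#     results:
#     - task:
#         type: text-classification
#       dataset:
#         name: some-dataset
#         type: some-dataset
#       metrics:
#       - name: F1
#         type: f1
#         value: 0.80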

# Lazy-loaded pipelines cache
PIPE_CACHE: Dict[str, Any] = {}

def get_pipeline(repo_id: str, task: str):
    key = f"{task}::{repo_id}"
    if key in PIPE_CACHE:
        return PIPE_CACHE[key]
    # Use return_all_scores=True so we can compare per-label scores
    if task == "text-classification":
        pipe = pipeline(task, model=repo_id, tokenizer=repo_id, return_all_scores=True, truncation=True)
    else:
        # Add more pipelines if you start supporting other tasks
        pipe = pipeline(task, model=repo_id, tokenizer=repo_id)
    PIPE_CACHE[key] = pipe
    return pipe
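
# With return_all_scores=True, a text-classification pipeline returns one list of
# {label, score} dicts per input, e.g. (labels depend on each model's config):
#   [[{"label": "BIASED", "score": 0.91}, {"label": "NON_BIASED", "score": 0.09}]]
# Newer transformers releases prefer top_k=None, which yields a flat list instead;
# predict() below handles both shapes.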

def predict(models: List[str], task: str, text: str) -> Tuple[str, pd.DataFrame, pd.DataFrame]:
    if not text.strip():
        return "Please enter some text.", pd.DataFrame(), pd.DataFrame()
    if not models:
        return "Please select 1–{} models.".format(MAX_MODELS), pd.DataFrame(), pd.DataFrame()
    if len(models) > MAX_MODELS:
        models = models[:MAX_MODELS]

    # Run inference
    table_rows = []
    label_union = set()
    per_model_outputs = {}

    for rid in models:
        try:
            pipe = get_pipeline(rid, task)
            out = pipe(text)
            # text-classification with return_all_scores=True returns: [ [ {label, score}, ... ] ]
            if isinstance(out, list) and len(out) and isinstance(out[0], list):
                scores = {d["label"]: float(d["score"]) for d in out[0]}
            elif isinstance(out, list) and len(out) and isinstance(out[0], dict) and "label" in out[0]:
                # Flat list of {label, score} dicts (top-1 only, or top_k=None output)
                scores = {d["label"]: float(d["score"]) for d in out}
            else:
                scores = {}
            per_model_outputs[rid] = scores
            label_union.update(scores.keys())
        except Exception:
            per_model_outputs[rid] = {"<error>": 0.0}
            label_union.add("<error>")

    # Build a nice table with union of labels as columns
    label_cols = sorted(label_union)
    for rid in models:
        row = {"model": rid}
        scores = per_model_outputs.get(rid, {})
        for lab in label_cols:
            row[lab] = scores.get(lab, 0.0)
        # Also record the predicted (argmax) label if present
        if scores:
            pred = max(scores.items(), key=lambda kv: kv[1])[0]
            row["predicted_label"] = pred
        else:
            row["predicted_label"] = ""
        table_rows.append(row)
    pred_df = pd.DataFrame(table_rows, columns=["model"] + label_cols + ["predicted_label"])

    # Collect reported metrics if present
    metrics_frames = []
    for rid in models:
        df = extract_model_index_metrics(rid)
        if not df.empty:
            df = df.copy()
            # extract_model_index_metrics() already adds a repo_id column; inserting
            # it again would raise, so only add it when missing, else move it first.
            if "repo_id" not in df.columns:
                df.insert(0, "repo_id", rid)
            else:
                df = df[["repo_id"] + [c for c in df.columns if c != "repo_id"]]
            metrics_frames.append(df)
    metrics_df = pd.concat(metrics_frames, ignore_index=True) if metrics_frames else pd.DataFrame()

    msg = "✓ Done. Compared {} model(s) on task: `{}`".format(len(models), task)
    return msg, pred_df, metrics_df
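
# The resulting predictions table has one row per model and the union of labels as
# columns, e.g. (labels and values illustrative):
#   model                      BIASED  NON_BIASED  predicted_label
#   mediabiasgroup/<model-a>     0.91        0.09  BIASED
#   mediabiasgroup/<model-b>     0.34        0.66  NON_BIASED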

def refresh_models(selected_task: str):
    # Clear the cached Hub listing so the refresh button really re-queries the Hub.
    list_org_models.cache_clear()
    tasks, task2models = discover_tasks_and_models()
    models = task2models.get(selected_task, [])
    # Return component updates so the dropdowns' choices (not just values) change.
    return gr.Dropdown(choices=tasks or [DEFAULT_TASK]), gr.Dropdown(choices=models, value=[])

def on_task_change(selected_task: str):
    _, task2models = discover_tasks_and_models()
    return gr.Dropdown(choices=task2models.get(selected_task, []), value=[])

with gr.Blocks(fill_height=True, title="MediaBiasGroup — Model Comparator") as demo:
    gr.Markdown(
        "# MediaBiasGroup — Model Comparator\n"
        "Select a **task**, choose multiple models, enter text, and compare outputs side-by-side. "
        "If models provide a `model-index` in their cards, reported metrics are shown below."
    )
    with gr.Row():
        with gr.Column(scale=1):
            tasks, task2models = discover_tasks_and_models()
            task_dd = gr.Dropdown(choices=tasks or [DEFAULT_TASK], value=(tasks[0] if tasks else DEFAULT_TASK), label="Task")
            model_ms = gr.Dropdown(choices=task2models.get(tasks[0], []) if tasks else [], multiselect=True, label="Models")
            refresh_btn = gr.Button("🔄 Refresh list from Hub")
            gr.Markdown(
                f"**Organization:** `{ORG}`  \n"
                f"**Max models per run:** {MAX_MODELS}"
            )
        with gr.Column(scale=2):
            text_in = gr.Textbox(lines=4, placeholder="Paste a sentence…", label="Input text")
            run_btn = gr.Button("Compare")
            status = gr.Markdown("")
    with gr.Row():
        with gr.Column():
            gr.Markdown("### Predictions")
            pred_df = gr.Dataframe(wrap=True)
        with gr.Column():
            gr.Markdown("### Reported metrics (from model cards)")
            metrics_df = gr.Dataframe(wrap=True)

    # Events wiring
    task_dd.change(fn=on_task_change, inputs=[task_dd], outputs=[model_ms])
    refresh_btn.click(fn=refresh_models, inputs=[task_dd], outputs=[task_dd, model_ms])
    run_btn.click(fn=predict, inputs=[model_ms, task_dd, text_in], outputs=[status, pred_df, metrics_df])

if __name__ == "__main__":
    demo.launch()