Supastrikas-004 commited on
Commit
a095974
Β·
verified Β·
1 Parent(s): badc270

Update evaluator.py

Browse files
Files changed (1) hide show
  1. evaluator.py +270 -37
evaluator.py CHANGED
@@ -1,38 +1,271 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import pandas as pd
2
- import textstat
3
- import language_tool_python
4
-
5
- # Use Public API version (does NOT require Java)
6
- tool = language_tool_python.LanguageToolPublicAPI('en-US')
7
-
8
- def evaluate_responses(df, use_llm_judge=False):
9
- scores = []
10
- for _, row in df.iterrows():
11
- response = row["response"]
12
-
13
- # Rule-based metrics
14
- grammar_matches = len(tool.check(response))
15
- readability = textstat.flesch_reading_ease(response)
16
-
17
- # Simple scoring
18
- instruction_follow = 1 if row["instruction"].lower() in response.lower() else 0
19
- coherence = 1 if readability > 40 else 0
20
- grammar_score = max(0, 1 - grammar_matches / 10)
21
-
22
- final_score = (instruction_follow + coherence + grammar_score) / 3
23
-
24
- # Optional LLM judge (stubbed for now, can hook Hugging Face API later)
25
- if use_llm_judge:
26
- final_score = (final_score + 0.8) / 2 # Example: trust LLM judge
27
-
28
- scores.append({
29
- "agent": row["agent"],
30
- "instruction": row["instruction"],
31
- "response": response,
32
- "score_instruction": instruction_follow,
33
- "score_coherence": coherence,
34
- "score_grammar": grammar_score,
35
- "final_score": final_score
36
- })
37
-
38
- return pd.DataFrame(scores)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # evaluator.py
2
+ """
3
+ Evaluation module: loads models (lightweight), computes metrics, and creates visualizations.
4
+ No Java required.
5
+ """
6
+
7
+ import re
8
+ import math
9
+ import uuid
10
+ import os
11
+ from typing import List, Dict, Tuple
12
+
13
+ import numpy as np
14
  import pandas as pd
15
+ import matplotlib.pyplot as plt
16
+ import seaborn as sns
17
+ import torch
18
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
19
+ from sentence_transformers import SentenceTransformer, util
20
+
21
+ # --------------------------
22
+ # MODEL LOADING (CPU-friendly)
23
+ # --------------------------
24
+ # Use small/medium models appropriate for Spaces.
25
+ NLI_MODEL = "textattack/roberta-base-MNLI"
26
+ EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
27
+
28
+ # Load NLI model & tokenizer (on CPU)
29
+ nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL)
30
+ nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL)
31
+ nli_model.to("cpu")
32
+ nli_model.eval()
33
+
34
+ # Load embedding model
35
+ embed_model = SentenceTransformer(EMBED_MODEL)
36
+
37
+ # get label mapping from model config (e.g., {0: 'CONTRADICTION', 1:'NEUTRAL', 2:'ENTAILMENT'})
38
+ id2label = {int(k): v.upper() for k, v in nli_model.config.id2label.items()}
39
+
40
+ # --------------------------
41
+ # METRIC FUNCTIONS
42
+ # --------------------------
43
+ def check_instruction_following(prompt: str, response: str) -> float:
44
+ """Keyword-overlap heuristic (normalized)."""
45
+ prompt = (prompt or "").lower()
46
+ response = (response or "").lower()
47
+ keywords = re.findall(r"\b\w+\b", prompt)
48
+ if len(keywords) == 0:
49
+ return 0.0
50
+ matches = sum(1 for k in set(keywords) if k in response)
51
+ return round(matches / len(set(keywords)), 3)
52
+
53
+ def check_hallucination(reference: str, response: str) -> Tuple[float, float]:
54
+ """
55
+ Use NLI to get entailment and contradiction probabilities.
56
+ Returns (entail_prob, contra_prob) in [0,1].
57
+ If no reference provided, returns (0.0, 0.0).
58
+ """
59
+ if not reference or not response:
60
+ return 0.0, 0.0
61
+ with torch.no_grad():
62
+ inputs = nli_tokenizer.encode_plus(reference, response, return_tensors="pt", truncation=True)
63
+ outputs = nli_model(**inputs)
64
+ probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
65
+ # Map probabilities to labels using id2label
66
+ entail_prob = 0.0
67
+ contra_prob = 0.0
68
+ for idx, p in enumerate(probs):
69
+ label = id2label.get(idx, "").upper()
70
+ if "ENTAIL" in label:
71
+ entail_prob = float(p)
72
+ if "CONTRA" in label or "CONTRADICTION" in label:
73
+ contra_prob = float(p)
74
+ return round(entail_prob, 3), round(contra_prob, 3)
75
+
76
+ def check_assumption(response: str) -> float:
77
+ """Penalize speculative language (hedges)."""
78
+ if not response:
79
+ return 0.0
80
+ speculative_terms = ["maybe", "probably", "might", "perhaps", "i guess", "seems", "could"]
81
+ count = sum(1 for t in speculative_terms if t in response.lower())
82
+ score = 1.0 - min(count / 3.0, 1.0)
83
+ return round(score, 3)
84
+
85
+ def check_coherence(response: str) -> float:
86
+ """Placeholder coherence metric β€” using a bounded random or simple heuristic.
87
+ Replace with grammar/perplexity later. Returns in [0,1]."""
88
+ if not response:
89
+ return 0.0
90
+ # simple heuristic: longer responses that have many sentences get slightly higher
91
+ sents = max(1, len(re.split(r"[.!?]+", response)) - 1)
92
+ words = max(1, len(re.findall(r"\w+", response)))
93
+ base = min(1.0, (words / 50.0) + (sents / 5.0))
94
+ # clamp to [0.5, 0.98] to avoid extreme
95
+ val = max(0.5, min(base * 0.9, 0.98))
96
+ return round(val, 3)
97
+
98
+ def check_accuracy(reference: str, response: str) -> float:
99
+ """Semantic similarity between reference and response via embeddings (cosine)."""
100
+ if not reference or not response:
101
+ return 0.0
102
+ ref_emb = embed_model.encode(reference, convert_to_tensor=True)
103
+ resp_emb = embed_model.encode(response, convert_to_tensor=True)
104
+ sim = float(util.cos_sim(ref_emb, resp_emb).item())
105
+ # cosine similarity in [-1,1] but for sentences usually [0,1]
106
+ sim = max(0.0, min(1.0, sim))
107
+ return round(sim, 3)
108
+
109
+ # --------------------------
110
+ # AGGREGATION & SCORING
111
+ # --------------------------
112
+ def compute_row_scores(prompt, response, reference) -> Dict:
113
+ instr = check_instruction_following(prompt, response)
114
+ entail, contra = check_hallucination(reference, response)
115
+ assum = check_assumption(response)
116
+ coh = check_coherence(response)
117
+ acc = check_accuracy(reference, response)
118
+
119
+ # Combine hallucination metrics into single positive metric: entail good, contra bad
120
+ hyst = entail * (1 - contra)
121
+ hyst = round(max(0.0, min(1.0, hyst)), 3)
122
+
123
+ # final_score: simple average of six components (all in [0,1])
124
+ components = [instr, hyst, assum, coh, acc]
125
+ final = round(float(sum(components) / len(components)), 3)
126
+
127
+ return {
128
+ "InstructionFollowing": instr,
129
+ "Hallucination_Entail": entail,
130
+ "Hallucination_Contra": contra,
131
+ "Hallucination_Metric": hyst,
132
+ "AssumptionControl": assum,
133
+ "Coherence": coh,
134
+ "Accuracy": acc,
135
+ "FinalScore": final
136
+ }
137
+
138
+ # --------------------------
139
+ # VISUALIZATION HELPERS
140
+ # --------------------------
141
+ def spider_net_multi(labels: List[str], rows: List[Dict], title: str = "Spider (Radar) Chart", fill_alpha: float = 0.12):
142
+ """
143
+ Create and return Matplotlib figure for radar chart.
144
+ rows: list of {"name": str, "values": [v1,...,vN]} values assumed on 0-100 scale for visibility.
145
+ """
146
+ N = len(labels)
147
+ angles = [n / float(N) * 2 * math.pi for n in range(N)]
148
+ angles += angles[:1]
149
+
150
+ fig = plt.figure(figsize=(6.5, 6.5))
151
+ ax = plt.subplot(111, polar=True)
152
+
153
+ ax.set_xticks(angles[:-1])
154
+ ax.set_xticklabels(labels, fontsize=9)
155
+ # radial limits: 0 to 100
156
+ ax.set_ylim(0, 100)
157
+ ax.set_yticks([0, 25, 50, 75, 100])
158
+
159
+ for r in rows:
160
+ values = r["values"]
161
+ values_closed = values + values[:1]
162
+ ax.plot(angles, values_closed, linewidth=1.5, label=r["name"])
163
+ ax.fill(angles, values_closed, alpha=fill_alpha)
164
+
165
+ ax.set_title(title, y=1.08, fontsize=12)
166
+ ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.1))
167
+ return fig
168
+
169
+ def heatmap_plot(df: pd.DataFrame, metric_cols: List[str], title: str = "Metric Correlations"):
170
+ fig, ax = plt.subplots(figsize=(7, 5))
171
+ sns.heatmap(df[metric_cols].corr(), annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
172
+ ax.set_title(title)
173
+ return fig
174
+
175
+ def bar_plot_avg(df: pd.DataFrame, metric_cols: List[str], title: str = "Average Metric Scores per Agent"):
176
+ agg = df.groupby("Agent")[metric_cols].mean().reset_index()
177
+ fig, ax = plt.subplots(figsize=(10, 5))
178
+ agg.set_index("Agent")[metric_cols].plot(kind="bar", ax=ax)
179
+ ax.set_title(title)
180
+ ax.set_ylabel("Score (0 - 1)")
181
+ plt.xticks(rotation=45)
182
+ plt.tight_layout()
183
+ return fig
184
+
185
+ # --------------------------
186
+ # HIGH-LEVEL EVALUATION (batch)
187
+ # --------------------------
188
+ def evaluate_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[Tuple[str,str]]]:
189
+ """
190
+ df must contain columns: prompt, response, task, agent, reference (reference optional)
191
+ Returns: metrics_df, list of (image_path, caption) for visualizations
192
+ """
193
+ # Normalize columns
194
+ df = df.rename(columns={c: c.strip() for c in df.columns})
195
+ # try to extract agent from metadata if not present
196
+ if "agent" not in df.columns and "metadata" in df.columns:
197
+ df["agent"] = df["metadata"].apply(lambda m: m.get("agent") if isinstance(m, dict) else None)
198
+
199
+ rows = []
200
+ for _, r in df.iterrows():
201
+ prompt = r.get("prompt", "")
202
+ response = r.get("response", "")
203
+ reference = r.get("reference", "") if "reference" in r else ""
204
+ agent = r.get("agent", "Unknown")
205
+ task = r.get("task", "Unknown")
206
+ scores = compute_row_scores(prompt, response, reference)
207
+ entry = {
208
+ "Task": str(task).strip(),
209
+ "Agent": str(agent),
210
+ "Prompt": prompt,
211
+ "Response": response,
212
+ "Reference": reference
213
+ }
214
+ entry.update(scores)
215
+ rows.append(entry)
216
+ metrics_df = pd.DataFrame(rows)
217
+
218
+ # Visualization artifacts
219
+ images = []
220
+
221
+ # Per-task spider charts
222
+ metric_labels = ["InstructionFollowing", "Hallucination_Metric", "AssumptionControl", "Coherence", "Accuracy"]
223
+ for task, g in metrics_df.groupby("Task"):
224
+ agents = g["Agent"].unique().tolist()
225
+ series = []
226
+ for a in agents:
227
+ subset = g[g["Agent"] == a]
228
+ vals = []
229
+ # convert to 0-100 scale for plot
230
+ for m in metric_labels:
231
+ vals.append(round(float(subset[m].mean()) * 100, 2))
232
+ series.append({"name": a, "values": vals})
233
+ if len(series) == 0:
234
+ continue
235
+ fig = spider_net_multi(metric_labels, series, title=f"{task} β€” Agent Comparison")
236
+ fname = f"/tmp/{uuid.uuid4().hex}_{task}_radar.png"
237
+ fig.savefig(fname, bbox_inches="tight")
238
+ plt.close(fig)
239
+ images.append((fname, f"{task} - radar"))
240
+
241
+ # also bar plot (averages) per task
242
+ try:
243
+ fig2, ax = plt.subplots(figsize=(8, 4))
244
+ avg = g.groupby("Agent")[["InstructionFollowing", "Hallucination_Metric", "AssumptionControl", "Coherence", "Accuracy"]].mean()
245
+ avg.plot(kind="bar", ax=ax)
246
+ ax.set_title(f"{task} β€” Average Metrics by Agent")
247
+ ax.set_ylabel("Score (0-1)")
248
+ plt.xticks(rotation=45)
249
+ fname2 = f"/tmp/{uuid.uuid4().hex}_{task}_bar.png"
250
+ fig2.savefig(fname2, bbox_inches="tight")
251
+ plt.close(fig2)
252
+ images.append((fname2, f"{task} - bar"))
253
+ except Exception:
254
+ pass
255
+
256
+ # Global heatmap
257
+ metric_cols = ["InstructionFollowing", "Hallucination_Metric", "AssumptionControl", "Coherence", "Accuracy", "FinalScore"]
258
+ try:
259
+ figh = heatmap_plot(metrics_df, metric_cols)
260
+ fnameh = f"/tmp/{uuid.uuid4().hex}_heatmap.png"
261
+ figh.savefig(fnameh, bbox_inches="tight")
262
+ plt.close(figh)
263
+ images.append((fnameh, "Metric Correlations Heatmap"))
264
+ except Exception:
265
+ pass
266
+
267
+ # Leaderboard: average final score per agent (global)
268
+ lb = metrics_df.groupby(["Agent", "Task"])["FinalScore"].mean().reset_index()
269
+ lb = lb.sort_values(["FinalScore"], ascending=False)
270
+
271
+ return metrics_df, images, lb