Files changed (1) hide show
  1. evaluator.py +130 -135
evaluator.py CHANGED
@@ -1,13 +1,11 @@
1
- # evaluator.py
2
  """
3
- Evaluation module: loads models (lightweight), computes metrics, and creates visualizations.
4
- No Java required.
5
  """
6
 
7
  import re
8
  import math
9
  import uuid
10
- import os
11
  from typing import List, Dict, Tuple
12
 
13
  import numpy as np
@@ -19,13 +17,12 @@ from transformers import AutoTokenizer, AutoModelForSequenceClassification
19
  from sentence_transformers import SentenceTransformer, util
20
 
21
  # --------------------------
22
- # MODEL LOADING (CPU-friendly)
23
  # --------------------------
24
- # Use small/medium models appropriate for Spaces.
25
  NLI_MODEL = "textattack/roberta-base-MNLI"
26
  EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
27
 
28
- # Load NLI model & tokenizer (on CPU)
29
  nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL)
30
  nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL)
31
  nli_model.to("cpu")
@@ -34,66 +31,72 @@ nli_model.eval()
34
  # Load embedding model
35
  embed_model = SentenceTransformer(EMBED_MODEL)
36
 
37
- # get label mapping from model config (e.g., {0: 'CONTRADICTION', 1:'NEUTRAL', 2:'ENTAILMENT'})
38
  id2label = {int(k): v.upper() for k, v in nli_model.config.id2label.items()}
39
 
 
40
  # --------------------------
41
  # METRIC FUNCTIONS
42
  # --------------------------
43
  def check_instruction_following(prompt: str, response: str) -> float:
44
- """Keyword-overlap heuristic (normalized)."""
45
- prompt = (prompt or "").lower()
46
- response = (response or "").lower()
47
- keywords = re.findall(r"\b\w+\b", prompt)
48
- if len(keywords) == 0:
49
  return 0.0
50
- matches = sum(1 for k in set(keywords) if k in response)
51
- return round(matches / len(set(keywords)), 3)
 
 
 
52
 
53
- def check_hallucination(reference: str, response: str) -> Tuple[float, float]:
54
  """
55
- Use NLI to get entailment and contradiction probabilities.
56
- Returns (entail_prob, contra_prob) in [0,1].
57
- If no reference provided, returns (0.0, 0.0).
58
  """
59
  if not reference or not response:
60
- return 0.0, 0.0
61
  with torch.no_grad():
62
  inputs = nli_tokenizer.encode_plus(reference, response, return_tensors="pt", truncation=True)
63
  outputs = nli_model(**inputs)
64
  probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
65
- # Map probabilities to labels using id2label
66
- entail_prob = 0.0
67
- contra_prob = 0.0
68
  for idx, p in enumerate(probs):
69
- label = id2label.get(idx, "").upper()
70
  if "ENTAIL" in label:
71
  entail_prob = float(p)
72
- if "CONTRA" in label or "CONTRADICTION" in label:
73
  contra_prob = float(p)
74
- return round(entail_prob, 3), round(contra_prob, 3)
 
 
 
 
75
 
76
  def check_assumption(response: str) -> float:
77
- """Penalize speculative language (hedges)."""
78
  if not response:
79
  return 0.0
80
  speculative_terms = ["maybe", "probably", "might", "perhaps", "i guess", "seems", "could"]
81
  count = sum(1 for t in speculative_terms if t in response.lower())
82
- score = 1.0 - min(count / 3.0, 1.0)
83
  return round(score, 3)
84
 
 
85
  def check_coherence(response: str) -> float:
86
- """Placeholder coherence metric β€” using a bounded random or simple heuristic.
87
- Replace with grammar/perplexity later. Returns in [0,1]."""
88
  if not response:
89
  return 0.0
90
- # simple heuristic: longer responses that have many sentences get slightly higher
91
  sents = max(1, len(re.split(r"[.!?]+", response)) - 1)
92
- words = max(1, len(re.findall(r"\w+", response)))
 
 
 
93
  base = min(1.0, (words / 50.0) + (sents / 5.0))
94
- # clamp to [0.5, 0.98] to avoid extreme
95
- val = max(0.5, min(base * 0.9, 0.98))
96
- return round(val, 3)
97
 
98
  def check_accuracy(reference: str, response: str) -> float:
99
  """Semantic similarity between reference and response via embeddings (cosine)."""
@@ -102,146 +105,118 @@ def check_accuracy(reference: str, response: str) -> float:
102
  ref_emb = embed_model.encode(reference, convert_to_tensor=True)
103
  resp_emb = embed_model.encode(response, convert_to_tensor=True)
104
  sim = float(util.cos_sim(ref_emb, resp_emb).item())
105
- # cosine similarity in [-1,1] but for sentences usually [0,1]
106
- sim = max(0.0, min(1.0, sim))
107
- return round(sim, 3)
108
 
109
  # --------------------------
110
- # AGGREGATION & SCORING
111
  # --------------------------
112
  def compute_row_scores(prompt, response, reference) -> Dict:
113
  instr = check_instruction_following(prompt, response)
114
- entail, contra = check_hallucination(reference, response)
115
  assum = check_assumption(response)
116
  coh = check_coherence(response)
117
  acc = check_accuracy(reference, response)
118
 
119
- # Combine hallucination metrics into single positive metric: entail good, contra bad
120
- hyst = entail * (1 - contra)
121
- hyst = round(max(0.0, min(1.0, hyst)), 3)
122
-
123
- # final_score: simple average of six components (all in [0,1])
124
- components = [instr, hyst, assum, coh, acc]
125
  final = round(float(sum(components) / len(components)), 3)
126
 
127
  return {
128
  "InstructionFollowing": instr,
129
- "Hallucination_Entail": entail,
130
- "Hallucination_Contra": contra,
131
- "Hallucination_Metric": hyst,
132
  "AssumptionControl": assum,
133
  "Coherence": coh,
134
  "Accuracy": acc,
135
- "FinalScore": final
136
  }
137
 
 
138
  # --------------------------
139
  # VISUALIZATION HELPERS
140
  # --------------------------
141
- def spider_net_multi(labels: List[str], rows: List[Dict], title: str = "Spider (Radar) Chart", fill_alpha: float = 0.12):
142
- """
143
- Create and return Matplotlib figure for radar chart.
144
- rows: list of {"name": str, "values": [v1,...,vN]} values assumed on 0-100 scale for visibility.
145
- """
146
- N = len(labels)
147
- angles = [n / float(N) * 2 * math.pi for n in range(N)]
148
- angles += angles[:1]
149
-
150
- fig = plt.figure(figsize=(6.5, 6.5))
151
- ax = plt.subplot(111, polar=True)
152
-
153
- ax.set_xticks(angles[:-1])
154
- ax.set_xticklabels(labels, fontsize=9)
155
- # radial limits: 0 to 100
156
- ax.set_ylim(0, 100)
157
- ax.set_yticks([0, 25, 50, 75, 100])
158
-
159
- for r in rows:
160
- values = r["values"]
161
- values_closed = values + values[:1]
162
- ax.plot(angles, values_closed, linewidth=1.5, label=r["name"])
163
- ax.fill(angles, values_closed, alpha=fill_alpha)
164
-
165
- ax.set_title(title, y=1.08, fontsize=12)
166
- ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.1))
167
- return fig
168
-
169
- def heatmap_plot(df: pd.DataFrame, metric_cols: List[str], title: str = "Metric Correlations"):
170
- fig, ax = plt.subplots(figsize=(7, 5))
171
- sns.heatmap(df[metric_cols].corr(), annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
172
- ax.set_title(title)
173
- return fig
174
-
175
- def bar_plot_avg(df: pd.DataFrame, metric_cols: List[str], title: str = "Average Metric Scores per Agent"):
176
- agg = df.groupby("Agent")[metric_cols].mean().reset_index()
177
- fig, ax = plt.subplots(figsize=(10, 5))
178
- agg.set_index("Agent")[metric_cols].plot(kind="bar", ax=ax)
179
- ax.set_title(title)
180
- ax.set_ylabel("Score (0 - 1)")
181
- plt.xticks(rotation=45)
182
- plt.tight_layout()
183
- return fig
184
 
185
  # --------------------------
186
- # HIGH-LEVEL EVALUATION (batch)
187
  # --------------------------
188
- def evaluate_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[Tuple[str,str]]]:
189
  """
190
- df must contain columns: prompt, response, task, agent, reference (reference optional)
191
- Returns: metrics_df, list of (image_path, caption) for visualizations
192
  """
193
- # Normalize columns
194
  df = df.rename(columns={c: c.strip() for c in df.columns})
195
- # try to extract agent from metadata if not present
196
- if "agent" not in df.columns and "metadata" in df.columns:
197
- df["agent"] = df["metadata"].apply(lambda m: m.get("agent") if isinstance(m, dict) else None)
198
 
199
  rows = []
200
  for _, r in df.iterrows():
201
  prompt = r.get("prompt", "")
202
  response = r.get("response", "")
203
- reference = r.get("reference", "") if "reference" in r else ""
204
  agent = r.get("agent", "Unknown")
205
  task = r.get("task", "Unknown")
 
206
  scores = compute_row_scores(prompt, response, reference)
207
  entry = {
208
  "Task": str(task).strip(),
209
  "Agent": str(agent),
210
  "Prompt": prompt,
211
  "Response": response,
212
- "Reference": reference
213
  }
214
  entry.update(scores)
215
  rows.append(entry)
 
216
  metrics_df = pd.DataFrame(rows)
217
 
218
  # Visualization artifacts
219
  images = []
 
220
 
221
- # Per-task spider charts
222
- metric_labels = ["InstructionFollowing", "Hallucination_Metric", "AssumptionControl", "Coherence", "Accuracy"]
223
  for task, g in metrics_df.groupby("Task"):
224
- agents = g["Agent"].unique().tolist()
225
  series = []
226
- for a in agents:
227
  subset = g[g["Agent"] == a]
228
- vals = []
229
- # convert to 0-100 scale for plot
230
- for m in metric_labels:
231
- vals.append(round(float(subset[m].mean()) * 100, 2))
232
  series.append({"name": a, "values": vals})
233
- if len(series) == 0:
234
- continue
235
- fig = spider_net_multi(metric_labels, series, title=f"{task} β€” Agent Comparison")
236
- fname = f"/tmp/{uuid.uuid4().hex}_{task}_radar.png"
237
- fig.savefig(fname, bbox_inches="tight")
238
- plt.close(fig)
239
- images.append((fname, f"{task} - radar"))
240
-
241
- # also bar plot (averages) per task
242
- try:
243
  fig2, ax = plt.subplots(figsize=(8, 4))
244
- avg = g.groupby("Agent")[["InstructionFollowing", "Hallucination_Metric", "AssumptionControl", "Coherence", "Accuracy"]].mean()
245
  avg.plot(kind="bar", ax=ax)
246
  ax.set_title(f"{task} β€” Average Metrics by Agent")
247
  ax.set_ylabel("Score (0-1)")
@@ -250,22 +225,42 @@ def evaluate_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[Tuple[str,s
250
  fig2.savefig(fname2, bbox_inches="tight")
251
  plt.close(fig2)
252
  images.append((fname2, f"{task} - bar"))
253
- except Exception:
254
- pass
255
 
256
  # Global heatmap
257
- metric_cols = ["InstructionFollowing", "Hallucination_Metric", "AssumptionControl", "Coherence", "Accuracy", "FinalScore"]
258
- try:
259
- figh = heatmap_plot(metrics_df, metric_cols)
260
- fnameh = f"/tmp/{uuid.uuid4().hex}_heatmap.png"
261
- figh.savefig(fnameh, bbox_inches="tight")
262
- plt.close(figh)
263
- images.append((fnameh, "Metric Correlations Heatmap"))
264
- except Exception:
265
- pass
266
-
267
- # Leaderboard: average final score per agent (global)
268
  lb = metrics_df.groupby(["Agent", "Task"])["FinalScore"].mean().reset_index()
269
  lb = lb.sort_values(["FinalScore"], ascending=False)
270
 
271
  return metrics_df, images, lb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Evaluation module: loads models, computes metrics, and creates visualizations.
3
+ Lightweight, CPU-friendly, no Java required.
4
  """
5
 
6
  import re
7
  import math
8
  import uuid
 
9
  from typing import List, Dict, Tuple
10
 
11
  import numpy as np
 
17
  from sentence_transformers import SentenceTransformer, util
18
 
19
  # --------------------------
20
+ # MODEL LOADING
21
  # --------------------------
 
22
  NLI_MODEL = "textattack/roberta-base-MNLI"
23
  EMBED_MODEL = "sentence-transformers/all-MiniLM-L6-v2"
24
 
25
+ # Load NLI model & tokenizer
26
  nli_tokenizer = AutoTokenizer.from_pretrained(NLI_MODEL)
27
  nli_model = AutoModelForSequenceClassification.from_pretrained(NLI_MODEL)
28
  nli_model.to("cpu")
 
31
  # Load embedding model
32
  embed_model = SentenceTransformer(EMBED_MODEL)
33
 
34
+ # Label mapping from config
35
  id2label = {int(k): v.upper() for k, v in nli_model.config.id2label.items()}
36
 
37
+
38
  # --------------------------
39
  # METRIC FUNCTIONS
40
  # --------------------------
41
  def check_instruction_following(prompt: str, response: str) -> float:
42
+ """Embedding-based similarity between prompt and response."""
43
+ if not prompt or not response:
 
 
 
44
  return 0.0
45
+ p_emb = embed_model.encode(prompt, convert_to_tensor=True)
46
+ r_emb = embed_model.encode(response, convert_to_tensor=True)
47
+ sim = float(util.cos_sim(p_emb, r_emb).item())
48
+ return round(max(0.0, min(1.0, sim)), 3)
49
+
50
 
51
+ def check_hallucination(reference: str, response: str) -> float:
52
  """
53
+ Single hallucination score:
54
+ Entailment prob - Contradiction prob (normalized to [0,1]).
55
+ Higher = less hallucination.
56
  """
57
  if not reference or not response:
58
+ return 0.0
59
  with torch.no_grad():
60
  inputs = nli_tokenizer.encode_plus(reference, response, return_tensors="pt", truncation=True)
61
  outputs = nli_model(**inputs)
62
  probs = torch.softmax(outputs.logits, dim=-1).cpu().numpy()[0]
63
+
64
+ entail_prob, contra_prob = 0.0, 0.0
 
65
  for idx, p in enumerate(probs):
66
+ label = id2label.get(idx, "")
67
  if "ENTAIL" in label:
68
  entail_prob = float(p)
69
+ elif "CONTRA" in label:
70
  contra_prob = float(p)
71
+
72
+ score = entail_prob - contra_prob
73
+ score = (score + 1) / 2 # normalize [-1,1] β†’ [0,1]
74
+ return round(max(0.0, min(1.0, score)), 3)
75
+
76
 
77
  def check_assumption(response: str) -> float:
78
+ """Detect speculative/hedging terms."""
79
  if not response:
80
  return 0.0
81
  speculative_terms = ["maybe", "probably", "might", "perhaps", "i guess", "seems", "could"]
82
  count = sum(1 for t in speculative_terms if t in response.lower())
83
+ score = 1.0 - min(count / 5.0, 1.0) # smoother decay
84
  return round(score, 3)
85
 
86
+
87
  def check_coherence(response: str) -> float:
88
+ """Heuristic coherence metric: penalizes very short/long, rewards sentence balance."""
 
89
  if not response:
90
  return 0.0
91
+ words = len(re.findall(r"\w+", response))
92
  sents = max(1, len(re.split(r"[.!?]+", response)) - 1)
93
+ if words < 5:
94
+ return 0.3
95
+ if words > 200:
96
+ return 0.5
97
  base = min(1.0, (words / 50.0) + (sents / 5.0))
98
+ return round(max(0.4, min(base, 0.95)), 3)
99
+
 
100
 
101
  def check_accuracy(reference: str, response: str) -> float:
102
  """Semantic similarity between reference and response via embeddings (cosine)."""
 
105
  ref_emb = embed_model.encode(reference, convert_to_tensor=True)
106
  resp_emb = embed_model.encode(response, convert_to_tensor=True)
107
  sim = float(util.cos_sim(ref_emb, resp_emb).item())
108
+ return round(max(0.0, min(1.0, sim)), 3)
109
+
 
110
 
111
  # --------------------------
112
+ # SCORING PIPELINE
113
  # --------------------------
114
  def compute_row_scores(prompt, response, reference) -> Dict:
115
  instr = check_instruction_following(prompt, response)
116
+ halluc = check_hallucination(reference, response)
117
  assum = check_assumption(response)
118
  coh = check_coherence(response)
119
  acc = check_accuracy(reference, response)
120
 
121
+ # Final score: average
122
+ components = [instr, halluc, assum, coh, acc]
 
 
 
 
123
  final = round(float(sum(components) / len(components)), 3)
124
 
125
  return {
126
  "InstructionFollowing": instr,
127
+ "Hallucination": halluc,
 
 
128
  "AssumptionControl": assum,
129
  "Coherence": coh,
130
  "Accuracy": acc,
131
+ "FinalScore": final,
132
  }
133
 
134
+
135
  # --------------------------
136
  # VISUALIZATION HELPERS
137
  # --------------------------
138
+ # def spider_net_multi(labels: List[str], rows: List[Dict], title: str, fill_alpha: float = 0.12):
139
+ # """Radar chart for multiple agents."""
140
+ # N = len(labels)
141
+ # angles = [n / float(N) * 2 * math.pi for n in range(N)]
142
+ # angles += angles[:1]
143
+
144
+ # fig = plt.figure(figsize=(6.5, 6.5))
145
+ # ax = plt.subplot(111, polar=True)
146
+ # ax.set_xticks(angles[:-1])
147
+ # ax.set_xticklabels(labels, fontsize=9)
148
+ # ax.set_ylim(0, 100)
149
+ # ax.set_yticks([0, 25, 50, 75, 100])
150
+
151
+ # for r in rows:
152
+ # values = r["values"]
153
+ # values_closed = values + values[:1]
154
+ # ax.plot(angles, values_closed, linewidth=1.5, label=r["name"])
155
+ # ax.fill(angles, values_closed, alpha=fill_alpha)
156
+
157
+ # ax.set_title(title, y=1.08, fontsize=12)
158
+ # ax.legend(loc="upper right", bbox_to_anchor=(1.25, 1.1))
159
+ # return fig
160
+
161
+
162
+ # def heatmap_plot(df: pd.DataFrame, metric_cols: List[str], title: str = "Metric Correlations"):
163
+ # fig, ax = plt.subplots(figsize=(7, 5))
164
+ # sns.heatmap(df[metric_cols].corr(), annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
165
+ # ax.set_title(title)
166
+ # return fig
167
+
 
 
 
 
 
 
 
 
 
 
 
 
 
168
 
169
  # --------------------------
170
+ # HIGH-LEVEL EVALUATION
171
  # --------------------------
172
+ def evaluate_dataframe(df: pd.DataFrame) -> Tuple[pd.DataFrame, List[Tuple[str, str]], pd.DataFrame]:
173
  """
174
+ df must contain: prompt, response, task, agent, reference
175
+ Returns: metrics_df, [(image_path, caption)], leaderboard_df
176
  """
 
177
  df = df.rename(columns={c: c.strip() for c in df.columns})
 
 
 
178
 
179
  rows = []
180
  for _, r in df.iterrows():
181
  prompt = r.get("prompt", "")
182
  response = r.get("response", "")
183
+ reference = r.get("reference", "")
184
  agent = r.get("agent", "Unknown")
185
  task = r.get("task", "Unknown")
186
+
187
  scores = compute_row_scores(prompt, response, reference)
188
  entry = {
189
  "Task": str(task).strip(),
190
  "Agent": str(agent),
191
  "Prompt": prompt,
192
  "Response": response,
193
+ "Reference": reference,
194
  }
195
  entry.update(scores)
196
  rows.append(entry)
197
+
198
  metrics_df = pd.DataFrame(rows)
199
 
200
  # Visualization artifacts
201
  images = []
202
+ metric_labels = ["InstructionFollowing", "Hallucination", "AssumptionControl", "Coherence", "Accuracy"]
203
 
204
+ # Per-task radar and bar charts
 
205
  for task, g in metrics_df.groupby("Task"):
 
206
  series = []
207
+ for a in g["Agent"].unique():
208
  subset = g[g["Agent"] == a]
209
+ vals = [round(float(subset[m].mean()) * 100, 2) for m in metric_labels]
 
 
 
210
  series.append({"name": a, "values": vals})
211
+ if series:
212
+ fig = spider_net_multi(metric_labels, series, title=f"{task} β€” Agent Comparison")
213
+ fname = f"/tmp/{uuid.uuid4().hex}_{task}_radar.png"
214
+ fig.savefig(fname, bbox_inches="tight")
215
+ plt.close(fig)
216
+ images.append((fname, f"{task} - radar"))
217
+
 
 
 
218
  fig2, ax = plt.subplots(figsize=(8, 4))
219
+ avg = g.groupby("Agent")[metric_labels].mean()
220
  avg.plot(kind="bar", ax=ax)
221
  ax.set_title(f"{task} β€” Average Metrics by Agent")
222
  ax.set_ylabel("Score (0-1)")
 
225
  fig2.savefig(fname2, bbox_inches="tight")
226
  plt.close(fig2)
227
  images.append((fname2, f"{task} - bar"))
 
 
228
 
229
  # Global heatmap
230
+ metric_cols = metric_labels + ["FinalScore"]
231
+ figh = heatmap_plot(metrics_df, metric_cols)
232
+ fnameh = f"/tmp/{uuid.uuid4().hex}_heatmap.png"
233
+ figh.savefig(fnameh, bbox_inches="tight")
234
+ plt.close(figh)
235
+ images.append((fnameh, "Metric Correlations Heatmap"))
236
+
237
+ # Leaderboard
 
 
 
238
  lb = metrics_df.groupby(["Agent", "Task"])["FinalScore"].mean().reset_index()
239
  lb = lb.sort_values(["FinalScore"], ascending=False)
240
 
241
  return metrics_df, images, lb
242
+
243
+
244
+ # --------------------------
245
+ # DEMO USAGE
246
+ # --------------------------
247
+ if __name__ == "__main__":
248
+ # Sample dataset
249
+ data = [
250
+ {"task": "Math QA", "agent": "AgentA", "prompt": "What is 2+2?", "response": "The answer is 4.", "reference": "2+2=4"},
251
+ {"task": "Math QA", "agent": "AgentB", "prompt": "What is 2+2?", "response": "It might be 5, but usually 4.", "reference": "2+2=4"},
252
+ {"task": "Summarization", "agent": "AgentA", "prompt": "Summarize: 'The cat sat on the mat. The dog barked.'", "response": "A cat sat while a dog barked.", "reference": "Cat on mat, dog barking."},
253
+ ]
254
+ df = pd.DataFrame(data)
255
+
256
+ metrics_df, images, leaderboard = evaluate_dataframe(df)
257
+
258
+ print("\n=== Metrics per response ===")
259
+ print(metrics_df[["Task", "Agent", "InstructionFollowing", "Hallucination", "AssumptionControl", "Coherence", "Accuracy", "FinalScore"]])
260
+
261
+ print("\n=== Leaderboard (average per task & agent) ===")
262
+ print(leaderboard)
263
+
264
+ print("\nVisualization files saved in /tmp/:")
265
+ for path, caption in images:
266
+ print(f"{caption}: {path}")