Supastrikas-004 committed on
Commit
0cb4117
·
verified ·
1 Parent(s): c1a93ea

Create evaluator.py

Browse files
Files changed (1) hide show
  1. evaluator.py +443 -0
evaluator.py ADDED
@@ -0,0 +1,443 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # evaluator.py
2
+ import re
3
+ import math
4
+ import os
5
+ import numpy as np
6
+ import pandas as pd
7
+ import textstat
8
+ from typing import Tuple, Dict
9
+
10
+ # Use LanguageTool public API to avoid Java dependency in Spaces
11
+ import language_tool_python
12
+ try:
13
+ tool = language_tool_python.LanguageToolPublicAPI('en-US')
14
+ except Exception:
15
+ # final fallback: simple grammar placeholder if network issue
16
+ tool = None
17
+
18
+ # Import heavy dependencies lazily inside the hallucination detector to avoid startup OOM
19
+ HALLUCINATION_AVAILABLE = True
20
+ try:
21
+ # 'unieval' import may fail if package not installed; guard it
22
+ from unieval.metric.evaluator import get_evaluator # optional
23
+ import evaluate # required by hallucination detector
24
+ import torch
25
+ from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
26
+ from sentence_transformers import SentenceTransformer, util
27
+ except Exception:
28
+ HALLUCINATION_AVAILABLE = False
29
+
30
+ # -------------------------
31
+ # Rule-based metrics
32
+ # -------------------------
33
+ def check_instruction_following(prompt: str, response: str) -> float:
34
+ prompt = (prompt or "").lower()
35
+ response = (response or "").lower()
36
+ keywords = re.findall(r"\b\w+\b", prompt)
37
+ if not keywords:
38
+ return 0.0
39
+ matches = sum(1 for k in set(keywords) if k in response)
40
+ return round(matches / len(set(keywords)), 3)
41
+
42
+ def check_grammar(response: str) -> Tuple[int, float]:
43
+ """
44
+ Returns (num_matches, grammar_score_in_0_1)
45
+ grammar_score = 1 - num_matches/10 clipped
46
+ If language tool unavailable, returns (0, 0.8) as a coarse default.
47
+ """
48
+ if not response:
49
+ return 0, 0.0
50
+ if tool is None:
51
+ return 0, 0.8
52
+ try:
53
+ matches = tool.check(response)
54
+ num = len(matches)
55
+ score = max(0.0, 1 - num / 10)
56
+ return num, round(score, 3)
57
+ except Exception:
58
+ return 0, 0.8
59
+
60
+ def check_coherence(response: str) -> float:
61
+ if not response:
62
+ return 0.0
63
+ sents = max(1, len(re.split(r"[.!?]+", response)) - 1)
64
+ words = max(1, len(re.findall(r"\w+", response)))
65
+ base = min(1.0, (words / 50.0) + (sents / 5.0))
66
+ val = max(0.5, min(base * 0.9, 0.98))
67
+ return round(val, 3)
68
+
69
+ def check_accuracy_embeddings(reference: str, response: str, embed_model=None) -> float:
70
+ """
71
+ If embed_model passed and reference provided, compute cosine sim.
72
+ Otherwise return 0 or a neutral value.
73
+ """
74
+ if not reference or not response or embed_model is None:
75
+ return 0.0
76
+ try:
77
+ ref_emb = embed_model.encode(reference, convert_to_tensor=True)
78
+ resp_emb = embed_model.encode(response, convert_to_tensor=True)
79
+ sim = float(util.cos_sim(ref_emb, resp_emb))
80
+ sim = max(0.0, min(1.0, sim))
81
+ return round(sim, 3)
82
+ except Exception:
83
+ return 0.0
84
+
85
+ # -------------------------
86
+ # Hallucination Detector wrapper
87
+ # -------------------------
88
+ class HallucinationDetectorWrapper:
89
+ """
90
+ Wraps the ComprehensiveHallucinationDetector logic. Loads heavy models lazily and sets
91
+ DETECTOR_AVAILABLE flag depending on success. If loading fails, methods return neutral stubs.
92
+ """
93
+ def __init__(self):
94
+ self.ready = False
95
+ self._init_detector()
96
+
97
+ def _init_detector(self):
98
+ global HALLUCINATION_AVAILABLE
99
+ if not HALLUCINATION_AVAILABLE:
100
+ self.ready = False
101
+ return
102
+ try:
103
+ # Import inside to isolate errors
104
+ import evaluate
105
+ import torch
106
+ from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
107
+ from unieval.metric.evaluator import get_evaluator
108
+ # Minimal lightweight choices could be substituted here if you want smaller models
109
+ self.device = "cuda" if torch.cuda.is_available() else "cpu"
110
+
111
+ # Load metrics
112
+ self.rouge = evaluate.load('rouge')
113
+ self.sacrebleu = evaluate.load('sacrebleu')
114
+ self.bertscore = evaluate.load('bertscore')
115
+
116
+ # load unieval if available
117
+ try:
118
+ self.unieval_evaluator = get_evaluator('fact')
119
+ except Exception:
120
+ self.unieval_evaluator = None
121
+
122
+ # Load QG / QA / NLI / knowledge gen models
123
+ # Note: These models may be large; this is inside try/except
124
+ try:
125
+ self.qg_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation")
126
+ self.qg_model = T5ForConditionalGeneration.from_pretrained("mrm8488/t5-base-finetuned-question-generation").to(self.device)
127
+ self.qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
128
+ self.qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2").to(self.device)
129
+ nli_model_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
130
+ self.nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
131
+ self.nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(self.device)
132
+ judge_model_name = "google/flan-t5-large"
133
+ self.judge_tokenizer = AutoTokenizer.from_pretrained(judge_model_name)
134
+ self.judge_model = AutoModelForSeq2SeqLM.from_pretrained(judge_model_name).to(self.device)
135
+ self.ready = True
136
+ except Exception:
137
+ # If any heavy-model loading fails, disable the detector
138
+ self.ready = False
139
+ except Exception:
140
+ self.ready = False
141
+
142
+ def is_ready(self):
143
+ return self.ready
144
+
145
+ def detect(self, prompt: str, output: str) -> Dict:
146
+ """
147
+ If ready, run the comprehensive detector and return dict of metrics.
148
+ If not ready, return neutral placeholder dict.
149
+ """
150
+ if not self.ready:
151
+ # Neutral placeholders (so hallucination_score = 0.5 later)
152
+ return {
153
+ "knowledge_source": "",
154
+ "rouge_l": 0.0,
155
+ "sacrebleu": 0.0,
156
+ "bertscore_f1": 0.0,
157
+ "unieval_consistency": 0.0,
158
+ "q_squared_nli_contradiction": 0.5,
159
+ "critic_contradiction": 0.5
160
+ }
161
+ # Actual detection implementation (mirrors the code you provided)
162
+ try:
163
+ # generate knowledge source using judge model
164
+ input_text = f"Provide a factual answer: {prompt}"
165
+ input_ids = self.judge_tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
166
+ outputs = self.judge_model.generate(input_ids, max_length=384, num_beams=5, early_stopping=True)
167
+ knowledge_source = self.judge_tokenizer.decode(outputs[0], skip_special_tokens=True)
168
+
169
+ # n-gram & semantic
170
+ rouge_l = self.rouge.compute(predictions=[output], references=[knowledge_source])['rougeL']
171
+ sacre = self.sacrebleu.compute(predictions=[output], references=[[knowledge_source]])['score'] / 100.0
172
+ bert_results = self.bertscore.compute(predictions=[output], references=[knowledge_source], lang='en')
173
+ bert_f1 = np.mean(bert_results.get('f1', [0.0]))
174
+
175
+ # unieval
176
+ if self.unieval_evaluator:
177
+ try:
178
+ ue = self.unieval_evaluator.evaluate([{'source': knowledge_source, 'system_output': output}])[0]['consistency']
179
+ except Exception:
180
+ ue = 0.0
181
+ else:
182
+ ue = 0.0
183
+
184
+ # q^2
185
+ qg_input = f"generate question: {output}"
186
+ qg_input_ids = self.qg_tokenizer(qg_input, return_tensors="pt").input_ids.to(self.device)
187
+ qg_out = self.qg_model.generate(qg_input_ids, max_length=64, num_beams=4)
188
+ question = self.qg_tokenizer.decode(qg_out[0], skip_special_tokens=True)
189
+ if not question:
190
+ q2_contra = 0.5
191
+ else:
192
+ try:
193
+ qa_inputs = self.qa_tokenizer(question, knowledge_source, return_tensors="pt").to(self.device)
194
+ with torch.no_grad():
195
+ qa_output = self.qa_model(**qa_inputs)
196
+ answer_start = torch.argmax(qa_output.start_logits)
197
+ answer_end = torch.argmax(qa_output.end_logits) + 1
198
+ answer_from_knowledge = self.qa_tokenizer.decode(qa_inputs["input_ids"][0][answer_start:answer_end])
199
+ if not answer_from_knowledge:
200
+ q2_contra = 0.5
201
+ else:
202
+ # NLI: output vs answer_from_knowledge
203
+ tokenized = self.nli_tokenizer(output, answer_from_knowledge, return_tensors='pt', truncation=True, max_length=512).to(self.device)
204
+ with torch.no_grad():
205
+ out = self.nli_model(**tokenized)
206
+ probs = torch.softmax(out.logits, dim=1)[0].tolist()
207
+ q2_contra = probs[0] # contradiction prob
208
+ except Exception:
209
+ q2_contra = 0.5
210
+
211
+ # critic contradiction
212
+ try:
213
+ tokenized2 = self.nli_tokenizer(knowledge_source, output, return_tensors='pt', truncation=True, max_length=512).to(self.device)
214
+ with torch.no_grad():
215
+ out2 = self.nli_model(**tokenized2)
216
+ probs2 = torch.softmax(out2.logits, dim=1)[0].tolist()
217
+ critic_contra = probs2[0]
218
+ except Exception:
219
+ critic_contra = 0.5
220
+
221
+ return {
222
+ "knowledge_source": knowledge_source,
223
+ "rouge_l": rouge_l,
224
+ "sacrebleu": sacre,
225
+ "bertscore_f1": bert_f1,
226
+ "unieval_consistency": ue,
227
+ "q_squared_nli_contradiction": q2_contra,
228
+ "critic_contradiction": critic_contra
229
+ }
230
+ except Exception:
231
+ # On any runtime failure, return neutral placeholders
232
+ return {
233
+ "knowledge_source": "",
234
+ "rouge_l": 0.0,
235
+ "sacrebleu": 0.0,
236
+ "bertscore_f1": 0.0,
237
+ "unieval_consistency": 0.0,
238
+ "q_squared_nli_contradiction": 0.5,
239
+ "critic_contradiction": 0.5
240
+ }
241
+
242
+ # Singleton detector instance
243
+ _DETECTOR = None
244
+ def get_detector():
245
+ global _DETECTOR
246
+ if _DETECTOR is None:
247
+ _DETECTOR = HallucinationDetectorWrapper()
248
+ return _DETECTOR
249
+
250
+ def hallucination_score(prompt: str, output: str) -> float:
251
+ d = get_detector()
252
+ res = d.detect(prompt, output)
253
+ weights = {
254
+ "rouge_l": 0.2, "sacrebleu": 0.05, "bertscore_f1": 0.25,
255
+ "unieval_consistency": 0.25,
256
+ "q_squared_nli_contradiction": 0.15,
257
+ "critic_contradiction": 0.10
258
+ }
259
+ total = sum(weights.values())
260
+ weights = {k: v/total for k, v in weights.items()}
261
+ invert_metrics = {"rouge_l", "sacrebleu", "bertscore_f1", "unieval_consistency"}
262
+ final = 0.0
263
+ for m, w in weights.items():
264
+ v = res.get(m, 0.0)
265
+ if m in invert_metrics:
266
+ v = 1 - v
267
+ final += w * v
268
+ # final is in [0,1], higher -> more hallucination (worse)
269
+ return float(final)
270
+
271
+ # -------------------------
272
+ # Main evaluation function (integrate hallucination as complementary metric)
273
+ # -------------------------
274
+ def evaluate_dataframe(df: pd.DataFrame, use_llm_judge: bool = False) -> Tuple[pd.DataFrame, list, pd.DataFrame]:
275
+ """
276
+ Input: df with columns prompt (or instruction), response, task, agent, reference (opt)
277
+ Returns: metrics_df (per row), list of visualization image paths (path, caption), leaderboard_df
278
+ """
279
+ # Normalize column names
280
+ df = df.rename(columns={c: c.strip() for c in df.columns})
281
+ # Accept alternate column names
282
+ if "instruction" not in df.columns and "prompt" in df.columns:
283
+ df = df.rename(columns={"prompt": "instruction"})
284
+ if "response" not in df.columns and "output" in df.columns:
285
+ df = df.rename(columns={"output": "response"})
286
+ if "agent" not in df.columns:
287
+ df["agent"] = df.get("metadata", {}).apply(lambda x: x.get("agent") if isinstance(x, dict) else "Unknown")
288
+
289
+ # optional embed model for accuracy: lazy load sentence-transformers if available
290
+ embed_model = None
291
+ try:
292
+ from sentence_transformers import SentenceTransformer, util
293
+ embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
294
+ except Exception:
295
+ embed_model = None
296
+
297
+ rows = []
298
+ for _, r in df.iterrows():
299
+ instr = str(r.get("instruction", ""))
300
+ response = str(r.get("response", ""))
301
+ reference = str(r.get("reference", "")) if "reference" in r else ""
302
+ agent = r.get("agent", "Unknown")
303
+ task = r.get("task", "Unknown")
304
+
305
+ inst_score = check_instruction_following(instr, response)
306
+ num_matches, grammar_score = check_grammar(response)
307
+ coh_score = check_coherence(response)
308
+ acc_emb = check_accuracy_embeddings(reference, response, embed_model)
309
+
310
+ base_components = [inst_score, coh_score, grammar_score, acc_emb]
311
+ base_final = float(sum(base_components) / max(1, len(base_components)))
312
+
313
+ row_entry = {
314
+ "Task": str(task),
315
+ "Agent": str(agent),
316
+ "Instruction": instr,
317
+ "Response": response,
318
+ "Reference": reference,
319
+ "score_instruction": inst_score,
320
+ "score_grammar": grammar_score,
321
+ "score_coherence": coh_score,
322
+ "score_accuracy": acc_emb,
323
+ "base_final_score": round(base_final, 4)
324
+ }
325
+
326
+ # optional LLM judge: compute hallucination_score
327
+ if use_llm_judge:
328
+ try:
329
+ h = hallucination_score(instr, response)
330
+ # convert to consistency (higher is better): 1 - hallucination
331
+ consistency = round(1.0 - float(h), 4)
332
+ row_entry["score_llm_consistency"] = consistency
333
+ # combine base_final and consistency (simple averaging)
334
+ final_score = round((base_final + consistency) / 2.0, 4)
335
+ row_entry["final_score"] = final_score
336
+ except Exception:
337
+ # fallback
338
+ row_entry["score_llm_consistency"] = 0.5
339
+ row_entry["final_score"] = round(base_final, 4)
340
+ else:
341
+ row_entry["score_llm_consistency"] = np.nan
342
+ row_entry["final_score"] = round(base_final, 4)
343
+
344
+ rows.append(row_entry)
345
+
346
+ metrics_df = pd.DataFrame(rows)
347
+
348
+ # Create visualizations (saved to /tmp)
349
+ images = []
350
+ import matplotlib.pyplot as plt
351
+ import seaborn as sns
352
+ import uuid
353
+ # Leaderboard (avg final score per agent)
354
+ try:
355
+ lb = metrics_df.groupby("Agent")["final_score"].mean().reset_index().sort_values("final_score", ascending=False)
356
+ fname = f"/tmp/{uuid.uuid4().hex}_leaderboard.png"
357
+ fig, ax = plt.subplots(figsize=(8, max(4, len(lb)*0.4)))
358
+ ax.barh(lb["Agent"], lb["final_score"], color="tab:blue")
359
+ ax.invert_yaxis()
360
+ ax.set_xlabel("Average final score")
361
+ ax.set_title("Leaderboard: Avg final score per agent")
362
+ plt.tight_layout()
363
+ fig.savefig(fname, bbox_inches="tight")
364
+ plt.close(fig)
365
+ images.append((fname, "Leaderboard (horizontal bar)"))
366
+ except Exception:
367
+ pass
368
+
369
+ # Combined spider / radar : compare all agents across metrics
370
+ try:
371
+ metric_cols = ["score_instruction", "score_coherence", "score_grammar", "score_accuracy"]
372
+ if use_llm_judge:
373
+ metric_cols.append("score_llm_consistency")
374
+ agg = metrics_df.groupby("Agent")[metric_cols].mean().reset_index()
375
+ labels = [c.replace("score_", "").replace("_", " ").capitalize() for c in metric_cols]
376
+ # Build rows as required
377
+ rows_for_plot = []
378
+ for _, row in agg.iterrows():
379
+ vals = [float(row[c]) * 100 for c in metric_cols] # scale to 0-100
380
+ rows_for_plot.append({"name": row["Agent"], "values": vals})
381
+ # draw radar using a small internal function
382
+ def spider_net_multi(labels, rows, title="Spider Chart"):
383
+ import math
384
+ N = len(labels)
385
+ angles = [n / float(N) * 2 * math.pi for n in range(N)]
386
+ angles += angles[:1]
387
+ fig = plt.figure(figsize=(6.5,6.5))
388
+ ax = plt.subplot(111, polar=True)
389
+ ax.set_xticks(angles[:-1])
390
+ ax.set_xticklabels(labels)
391
+ ax.set_ylim(0, 100)
392
+ for r in rows:
393
+ v = r["values"] + r["values"][:1]
394
+ ax.plot(angles, v, label=r["name"])
395
+ ax.fill(angles, v, alpha=0.12)
396
+ ax.set_title(title)
397
+ ax.legend(loc="upper right", bbox_to_anchor=(1.3,1.1))
398
+ return fig
399
+ fig = spider_net_multi(labels, rows_for_plot, title="All Agents Comparison (Radar)")
400
+ fname2 = f"/tmp/{uuid.uuid4().hex}_radar.png"
401
+ fig.savefig(fname2, bbox_inches="tight")
402
+ plt.close(fig)
403
+ images.append((fname2, "All agents radar chart"))
404
+ except Exception:
405
+ pass
406
+
407
+ # Per-task spider charts
408
+ try:
409
+ for task, subset in metrics_df.groupby("Task"):
410
+ agg = subset.groupby("Agent")[metric_cols].mean().reset_index()
411
+ if agg.shape[0] == 0:
412
+ continue
413
+ rows_for_plot = []
414
+ for _, row in agg.iterrows():
415
+ vals = [float(row[c]) * 100 for c in metric_cols]
416
+ rows_for_plot.append({"name": row["Agent"], "values": vals})
417
+ fig = spider_net_multi(labels, rows_for_plot, title=f"{task} Agents (Radar)")
418
+ fname3 = f"/tmp/{uuid.uuid4().hex}_{task}_radar.png"
419
+ fig.savefig(fname3, bbox_inches="tight")
420
+ plt.close(fig)
421
+ images.append((fname3, f"{task} - radar"))
422
+ except Exception:
423
+ pass
424
+
425
+ # Heatmap for metric correlations
426
+ try:
427
+ metric_cols2 = ["score_instruction", "score_coherence", "score_grammar", "score_accuracy", "final_score"]
428
+ if use_llm_judge:
429
+ metric_cols2.append("score_llm_consistency")
430
+ fig, ax = plt.subplots(figsize=(7,6))
431
+ sns.heatmap(metrics_df[metric_cols2].corr(), annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
432
+ ax.set_title("Metric correlations")
433
+ fnameh = f"/tmp/{uuid.uuid4().hex}_heatmap.png"
434
+ fig.savefig(fnameh, bbox_inches="tight")
435
+ plt.close(fig)
436
+ images.append((fnameh, "Metric correlations"))
437
+ except Exception:
438
+ pass
439
+
440
+ # Leaderboard df return
441
+ leaderboard_df = metrics_df.groupby(["Agent", "Task"])["final_score"].mean().reset_index().sort_values("final_score", ascending=False)
442
+
443
+ return metrics_df, images, leaderboard_df