Create evaluator.py
Browse files- evaluator.py +443 -0
evaluator.py
ADDED
|
@@ -0,0 +1,443 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# evaluator.py
"""Rule-based and model-based response evaluation utilities.

Optional heavy dependencies (LanguageTool, unieval, transformers,
sentence-transformers) are imported defensively so the module still loads
in constrained environments such as HF Spaces.
"""
import re
import math
import os
import numpy as np
import pandas as pd
import textstat
from typing import Tuple, Dict

# Use LanguageTool public API to avoid Java dependency in Spaces.
# BUG FIX: guard the *import* as well — previously only the client
# construction was inside the try, so a missing language_tool_python
# package crashed the module instead of falling back to `tool = None`.
try:
    import language_tool_python
    tool = language_tool_python.LanguageToolPublicAPI('en-US')
except Exception:
    # final fallback: grammar checks return a coarse default score
    tool = None

# Import heavy dependencies lazily inside the hallucination detector to avoid startup OOM.
# This flag records whether the optional stack imported as a whole; note it is
# all-or-nothing — if any one package is missing, none of these names are bound.
HALLUCINATION_AVAILABLE = True
try:
    # 'unieval' import may fail if package not installed; guard it
    from unieval.metric.evaluator import get_evaluator  # optional
    import evaluate  # required by hallucination detector
    import torch
    from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
    from sentence_transformers import SentenceTransformer, util
except Exception:
    HALLUCINATION_AVAILABLE = False
|
| 29 |
+
|
| 30 |
+
# -------------------------
|
| 31 |
+
# Rule-based metrics
|
| 32 |
+
# -------------------------
|
| 33 |
+
def check_instruction_following(prompt: str, response: str) -> float:
    """Fraction of distinct prompt words that reappear as words in the response.

    Both texts are lowercased and tokenized on word boundaries.  Returns a
    value in [0, 1] rounded to 3 decimals; 0.0 for an empty or word-free
    prompt.
    """
    prompt_words = set(re.findall(r"\b\w+\b", (prompt or "").lower()))
    if not prompt_words:
        return 0.0
    # BUG FIX: the original tested `k in response` on the raw string, so a
    # short keyword like "a" or "is" matched inside unrelated words
    # ("cat" in "concatenate").  Compare whole tokens instead.
    response_words = set(re.findall(r"\b\w+\b", (response or "").lower()))
    overlap = len(prompt_words & response_words)
    return round(overlap / len(prompt_words), 3)
|
| 41 |
+
|
| 42 |
+
def check_grammar(response: str) -> Tuple[int, float]:
    """Count LanguageTool issues and map them onto a [0, 1] grammar score.

    Returns (num_matches, score) with score = max(0, 1 - num_matches/10),
    rounded to 3 decimals.  Empty text scores (0, 0.0); when the tool is
    unavailable or errors at runtime, a coarse default (0, 0.8) is used.
    """
    if not response:
        return 0, 0.0
    if tool is None:
        return 0, 0.8
    try:
        issues = tool.check(response)
    except Exception:
        # Network/API failure — same coarse default as "tool missing".
        return 0, 0.8
    issue_count = len(issues)
    penalty_score = round(max(0.0, 1 - issue_count / 10), 3)
    return issue_count, penalty_score
|
| 59 |
+
|
| 60 |
+
def check_coherence(response: str) -> float:
    """Length-based coherence proxy (no model involved).

    Combines sentence and word counts into a value clamped to
    [0.5, 0.98], rounded to 3 decimals; empty text scores 0.0.
    """
    if not response:
        return 0.0
    sentence_count = max(1, len(re.split(r"[.!?]+", response)) - 1)
    word_count = max(1, len(re.findall(r"\w+", response)))
    raw = min(1.0, word_count / 50.0 + sentence_count / 5.0)
    # clamp raw * 0.9 into [0.5, 0.98]
    clamped = min(max(raw * 0.9, 0.5), 0.98)
    return round(clamped, 3)
|
| 68 |
+
|
| 69 |
+
def check_accuracy_embeddings(reference: str, response: str, embed_model=None) -> float:
    """Cosine similarity between reference and response sentence embeddings.

    Returns 0.0 when the reference, response, or model is missing, or on
    any runtime error; otherwise the similarity clipped to [0, 1] and
    rounded to 3 decimals.
    """
    if not reference or not response or embed_model is None:
        return 0.0
    try:
        # BUG FIX: resolve `util` locally.  The module-level import lives in
        # an all-or-nothing optional try block (it also pulls unieval, torch,
        # transformers), so `util` could be unbound even with
        # sentence-transformers installed — the NameError was swallowed below
        # and silently degraded this metric to 0.0.
        from sentence_transformers import util
        ref_emb = embed_model.encode(reference, convert_to_tensor=True)
        resp_emb = embed_model.encode(response, convert_to_tensor=True)
        sim = float(util.cos_sim(ref_emb, resp_emb))
        sim = max(0.0, min(1.0, sim))
        return round(sim, 3)
    except Exception:
        return 0.0
|
| 84 |
+
|
| 85 |
+
# -------------------------
|
| 86 |
+
# Hallucination Detector wrapper
|
| 87 |
+
# -------------------------
|
| 88 |
+
class HallucinationDetectorWrapper:
    """
    Wraps the ComprehensiveHallucinationDetector logic. Loads heavy models lazily and sets
    DETECTOR_AVAILABLE flag depending on success. If loading fails, methods return neutral stubs.

    Attributes set on successful init: rouge/sacrebleu/bertscore metric
    objects, optional unieval evaluator, QG/QA/NLI/judge tokenizers+models,
    device, and `ready` (True only when everything loaded).
    """
    def __init__(self):
        # Starts in stub mode; _init_detector flips `ready` on success.
        self.ready = False
        self._init_detector()

    def _init_detector(self):
        """Attempt to load all metric and model dependencies.

        Any failure leaves the wrapper in stub mode (ready == False)
        rather than raising — callers only ever see placeholder metrics.
        """
        global HALLUCINATION_AVAILABLE
        if not HALLUCINATION_AVAILABLE:
            self.ready = False
            return
        try:
            # Import inside to isolate errors
            import evaluate
            import torch
            from transformers import AutoTokenizer, T5ForConditionalGeneration, AutoModelForQuestionAnswering, AutoModelForSequenceClassification, AutoModelForSeq2SeqLM
            from unieval.metric.evaluator import get_evaluator
            # Minimal lightweight choices could be substituted here if you want smaller models
            self.device = "cuda" if torch.cuda.is_available() else "cpu"

            # Load metrics (HF `evaluate` downloads these on first use)
            self.rouge = evaluate.load('rouge')
            self.sacrebleu = evaluate.load('sacrebleu')
            self.bertscore = evaluate.load('bertscore')

            # load unieval if available; detector still works without it
            try:
                self.unieval_evaluator = get_evaluator('fact')
            except Exception:
                self.unieval_evaluator = None

            # Load QG / QA / NLI / knowledge gen models
            # Note: These models may be large; this is inside try/except
            try:
                self.qg_tokenizer = AutoTokenizer.from_pretrained("mrm8488/t5-base-finetuned-question-generation")
                self.qg_model = T5ForConditionalGeneration.from_pretrained("mrm8488/t5-base-finetuned-question-generation").to(self.device)
                self.qa_tokenizer = AutoTokenizer.from_pretrained("deepset/roberta-base-squad2")
                self.qa_model = AutoModelForQuestionAnswering.from_pretrained("deepset/roberta-base-squad2").to(self.device)
                nli_model_name = "ynie/roberta-large-snli_mnli_fever_anli_R1_R2_R3-nli"
                self.nli_tokenizer = AutoTokenizer.from_pretrained(nli_model_name)
                self.nli_model = AutoModelForSequenceClassification.from_pretrained(nli_model_name).to(self.device)
                judge_model_name = "google/flan-t5-large"
                self.judge_tokenizer = AutoTokenizer.from_pretrained(judge_model_name)
                self.judge_model = AutoModelForSeq2SeqLM.from_pretrained(judge_model_name).to(self.device)
                self.ready = True
            except Exception:
                # If any heavy-model loading fails, disable the detector
                self.ready = False
        except Exception:
            self.ready = False

    def is_ready(self):
        """True only when every model loaded successfully."""
        return self.ready

    def detect(self, prompt: str, output: str) -> Dict:
        """
        If ready, run the comprehensive detector and return dict of metrics.
        If not ready, return neutral placeholder dict.

        NOTE: the ready path uses the module-level `torch`/`np` globals;
        `ready == True` implies the guarded top-level import succeeded.
        """
        if not self.ready:
            # Neutral placeholders (so hallucination_score = 0.5 later)
            return {
                "knowledge_source": "",
                "rouge_l": 0.0,
                "sacrebleu": 0.0,
                "bertscore_f1": 0.0,
                "unieval_consistency": 0.0,
                "q_squared_nli_contradiction": 0.5,
                "critic_contradiction": 0.5
            }
        # Actual detection implementation (mirrors the code you provided)
        try:
            # generate knowledge source using judge model: the judge's answer
            # to the prompt serves as the grounding reference for all metrics
            input_text = f"Provide a factual answer: {prompt}"
            input_ids = self.judge_tokenizer(input_text, return_tensors="pt").input_ids.to(self.device)
            outputs = self.judge_model.generate(input_ids, max_length=384, num_beams=5, early_stopping=True)
            knowledge_source = self.judge_tokenizer.decode(outputs[0], skip_special_tokens=True)

            # n-gram & semantic overlap between the output and the reference
            rouge_l = self.rouge.compute(predictions=[output], references=[knowledge_source])['rougeL']
            sacre = self.sacrebleu.compute(predictions=[output], references=[[knowledge_source]])['score'] / 100.0
            bert_results = self.bertscore.compute(predictions=[output], references=[knowledge_source], lang='en')
            bert_f1 = np.mean(bert_results.get('f1', [0.0]))

            # unieval factual-consistency score (0.0 when unavailable)
            if self.unieval_evaluator:
                try:
                    ue = self.unieval_evaluator.evaluate([{'source': knowledge_source, 'system_output': output}])[0]['consistency']
                except Exception:
                    ue = 0.0
            else:
                ue = 0.0

            # q^2: generate a question from the output, answer it from the
            # knowledge source, then NLI-compare output vs. that answer.
            qg_input = f"generate question: {output}"
            qg_input_ids = self.qg_tokenizer(qg_input, return_tensors="pt").input_ids.to(self.device)
            qg_out = self.qg_model.generate(qg_input_ids, max_length=64, num_beams=4)
            question = self.qg_tokenizer.decode(qg_out[0], skip_special_tokens=True)
            if not question:
                q2_contra = 0.5
            else:
                try:
                    qa_inputs = self.qa_tokenizer(question, knowledge_source, return_tensors="pt").to(self.device)
                    with torch.no_grad():
                        qa_output = self.qa_model(**qa_inputs)
                    answer_start = torch.argmax(qa_output.start_logits)
                    answer_end = torch.argmax(qa_output.end_logits) + 1
                    answer_from_knowledge = self.qa_tokenizer.decode(qa_inputs["input_ids"][0][answer_start:answer_end])
                    if not answer_from_knowledge:
                        q2_contra = 0.5
                    else:
                        # NLI: output vs answer_from_knowledge
                        tokenized = self.nli_tokenizer(output, answer_from_knowledge, return_tensors='pt', truncation=True, max_length=512).to(self.device)
                        with torch.no_grad():
                            out = self.nli_model(**tokenized)
                        probs = torch.softmax(out.logits, dim=1)[0].tolist()
                        # index 0 assumed to be the contradiction class for
                        # this NLI checkpoint — TODO confirm label order
                        q2_contra = probs[0]  # contradiction prob
                except Exception:
                    q2_contra = 0.5

            # critic contradiction: NLI with knowledge source as premise
            try:
                tokenized2 = self.nli_tokenizer(knowledge_source, output, return_tensors='pt', truncation=True, max_length=512).to(self.device)
                with torch.no_grad():
                    out2 = self.nli_model(**tokenized2)
                probs2 = torch.softmax(out2.logits, dim=1)[0].tolist()
                critic_contra = probs2[0]
            except Exception:
                critic_contra = 0.5

            return {
                "knowledge_source": knowledge_source,
                "rouge_l": rouge_l,
                "sacrebleu": sacre,
                "bertscore_f1": bert_f1,
                "unieval_consistency": ue,
                "q_squared_nli_contradiction": q2_contra,
                "critic_contradiction": critic_contra
            }
        except Exception:
            # On any runtime failure, return neutral placeholders
            return {
                "knowledge_source": "",
                "rouge_l": 0.0,
                "sacrebleu": 0.0,
                "bertscore_f1": 0.0,
                "unieval_consistency": 0.0,
                "q_squared_nli_contradiction": 0.5,
                "critic_contradiction": 0.5
            }
|
| 241 |
+
|
| 242 |
+
# Singleton detector instance
|
| 243 |
+
# Singleton detector instance, created on first request.
_DETECTOR = None

def get_detector():
    """Return the process-wide detector, constructing it lazily."""
    global _DETECTOR
    if _DETECTOR is not None:
        return _DETECTOR
    _DETECTOR = HallucinationDetectorWrapper()
    return _DETECTOR
|
| 249 |
+
|
| 250 |
+
def hallucination_score(prompt: str, output: str) -> float:
    """Blend detector metrics into a single score in [0, 1].

    Higher means more hallucination (worse).  Grounding-style metrics
    (overlap/consistency) are inverted so every weighted term points in
    the "more hallucination" direction.
    """
    metrics = get_detector().detect(prompt, output)
    raw_weights = {
        "rouge_l": 0.2, "sacrebleu": 0.05, "bertscore_f1": 0.25,
        "unieval_consistency": 0.25,
        "q_squared_nli_contradiction": 0.15,
        "critic_contradiction": 0.10
    }
    # Normalize defensively in case the weights are edited later.
    norm = sum(raw_weights.values())
    grounding_metrics = {"rouge_l", "sacrebleu", "bertscore_f1", "unieval_consistency"}
    score = 0.0
    for name, weight in raw_weights.items():
        value = metrics.get(name, 0.0)
        if name in grounding_metrics:
            value = 1 - value
        score += (weight / norm) * value
    # score is in [0,1]; higher -> more hallucination (worse)
    return float(score)
|
| 270 |
+
|
| 271 |
+
# -------------------------
|
| 272 |
+
# Main evaluation function (integrate hallucination as complementary metric)
|
| 273 |
+
# -------------------------
|
| 274 |
+
def evaluate_dataframe(df: pd.DataFrame, use_llm_judge: bool = False) -> Tuple[pd.DataFrame, list, pd.DataFrame]:
    """Score each row of a benchmark dataframe and build summary charts.

    Parameters
    ----------
    df : DataFrame with columns ``instruction`` (or ``prompt``), ``response``
        (or ``output``), and optionally ``task``, ``agent``, ``reference``,
        and ``metadata`` (a dict that may carry an ``agent`` key).
    use_llm_judge : when True, also run the hallucination detector and fold
        its consistency score into ``final_score``.

    Returns
    -------
    (metrics_df, images, leaderboard_df) where ``images`` is a list of
    (png_path, caption) tuples saved under /tmp.  Visualizations are
    best-effort: any plotting failure (including missing matplotlib or
    seaborn) just yields fewer images.
    """
    # Normalize column names (strip stray whitespace) and accept aliases.
    df = df.rename(columns={c: c.strip() for c in df.columns})
    if "instruction" not in df.columns and "prompt" in df.columns:
        df = df.rename(columns={"prompt": "instruction"})
    if "response" not in df.columns and "output" in df.columns:
        df = df.rename(columns={"output": "response"})
    if "agent" not in df.columns:
        # BUG FIX: df.get("metadata", {}) returns a plain dict when the
        # column is missing, and dicts have no .apply — that raised
        # AttributeError for any frame without a metadata column.  Only
        # derive the agent when metadata actually exists.
        if "metadata" in df.columns:
            df["agent"] = df["metadata"].apply(
                lambda x: x.get("agent") if isinstance(x, dict) else "Unknown"
            )
        else:
            df["agent"] = "Unknown"

    # Optional embedding model for reference-based accuracy; lazy import so
    # the function works without sentence-transformers installed.
    embed_model = None
    try:
        from sentence_transformers import SentenceTransformer
        embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    except Exception:
        embed_model = None

    rows = []
    for _, r in df.iterrows():
        instr = str(r.get("instruction", ""))
        response = str(r.get("response", ""))
        reference = str(r.get("reference", "")) if "reference" in r else ""
        agent = r.get("agent", "Unknown")
        task = r.get("task", "Unknown")

        inst_score = check_instruction_following(instr, response)
        num_matches, grammar_score = check_grammar(response)
        coh_score = check_coherence(response)
        acc_emb = check_accuracy_embeddings(reference, response, embed_model)

        # Base score: unweighted mean of the four rule-based components.
        base_components = [inst_score, coh_score, grammar_score, acc_emb]
        base_final = float(sum(base_components) / max(1, len(base_components)))

        row_entry = {
            "Task": str(task),
            "Agent": str(agent),
            "Instruction": instr,
            "Response": response,
            "Reference": reference,
            "score_instruction": inst_score,
            "score_grammar": grammar_score,
            "score_coherence": coh_score,
            "score_accuracy": acc_emb,
            "base_final_score": round(base_final, 4)
        }

        # Optional LLM judge: fold hallucination into the final score.
        if use_llm_judge:
            try:
                h = hallucination_score(instr, response)
                # convert to consistency (higher is better): 1 - hallucination
                consistency = round(1.0 - float(h), 4)
                row_entry["score_llm_consistency"] = consistency
                # combine base_final and consistency (simple averaging)
                row_entry["final_score"] = round((base_final + consistency) / 2.0, 4)
            except Exception:
                # fallback: neutral consistency, base score only
                row_entry["score_llm_consistency"] = 0.5
                row_entry["final_score"] = round(base_final, 4)
        else:
            row_entry["score_llm_consistency"] = np.nan
            row_entry["final_score"] = round(base_final, 4)

        rows.append(row_entry)

    metrics_df = pd.DataFrame(rows)

    # ---- Visualizations (saved to /tmp); entirely best-effort ----
    images = []
    # BUG FIX: these imports were unguarded, so a missing matplotlib/seaborn
    # crashed the whole evaluation instead of just skipping the charts.
    try:
        import matplotlib.pyplot as plt
        import seaborn as sns
        import uuid
        plotting_ok = True
    except Exception:
        plotting_ok = False

    # BUG FIX: hoisted out of the combined-radar try block.  Previously the
    # per-task section referenced metric_cols/labels/spider helper defined
    # inside that block, so an early failure there silently disabled every
    # per-task chart via NameError.
    metric_cols = ["score_instruction", "score_coherence", "score_grammar", "score_accuracy"]
    if use_llm_judge:
        metric_cols.append("score_llm_consistency")
    labels = [c.replace("score_", "").replace("_", " ").capitalize() for c in metric_cols]

    def _spider_net_multi(axis_labels, plot_rows, title="Spider Chart"):
        """Draw one radar chart (0-100 axes) with a trace per agent."""
        n_axes = len(axis_labels)
        angles = [i / float(n_axes) * 2 * math.pi for i in range(n_axes)]
        angles += angles[:1]  # close the polygon
        fig = plt.figure(figsize=(6.5, 6.5))
        ax = plt.subplot(111, polar=True)
        ax.set_xticks(angles[:-1])
        ax.set_xticklabels(axis_labels)
        ax.set_ylim(0, 100)
        for entry in plot_rows:
            vals = entry["values"] + entry["values"][:1]
            ax.plot(angles, vals, label=entry["name"])
            ax.fill(angles, vals, alpha=0.12)
        ax.set_title(title)
        ax.legend(loc="upper right", bbox_to_anchor=(1.3, 1.1))
        return fig

    def _agent_rows(agg_df):
        """Convert an Agent-aggregated frame into radar rows scaled 0-100."""
        return [
            {"name": row["Agent"], "values": [float(row[c]) * 100 for c in metric_cols]}
            for _, row in agg_df.iterrows()
        ]

    if plotting_ok:
        # Leaderboard (avg final score per agent)
        try:
            lb = metrics_df.groupby("Agent")["final_score"].mean().reset_index().sort_values("final_score", ascending=False)
            fname = f"/tmp/{uuid.uuid4().hex}_leaderboard.png"
            fig, ax = plt.subplots(figsize=(8, max(4, len(lb) * 0.4)))
            ax.barh(lb["Agent"], lb["final_score"], color="tab:blue")
            ax.invert_yaxis()  # best agent on top
            ax.set_xlabel("Average final score")
            ax.set_title("Leaderboard: Avg final score per agent")
            plt.tight_layout()
            fig.savefig(fname, bbox_inches="tight")
            plt.close(fig)
            images.append((fname, "Leaderboard (horizontal bar)"))
        except Exception:
            pass

        # Combined radar: all agents across metrics
        try:
            agg = metrics_df.groupby("Agent")[metric_cols].mean().reset_index()
            fig = _spider_net_multi(labels, _agent_rows(agg), title="All Agents Comparison (Radar)")
            fname2 = f"/tmp/{uuid.uuid4().hex}_radar.png"
            fig.savefig(fname2, bbox_inches="tight")
            plt.close(fig)
            images.append((fname2, "All agents radar chart"))
        except Exception:
            pass

        # Per-task radar charts
        try:
            for task, subset in metrics_df.groupby("Task"):
                agg = subset.groupby("Agent")[metric_cols].mean().reset_index()
                if agg.shape[0] == 0:
                    continue
                fig = _spider_net_multi(labels, _agent_rows(agg), title=f"{task} Agents (Radar)")
                fname3 = f"/tmp/{uuid.uuid4().hex}_{task}_radar.png"
                fig.savefig(fname3, bbox_inches="tight")
                plt.close(fig)
                images.append((fname3, f"{task} - radar"))
        except Exception:
            pass

        # Heatmap of metric correlations
        try:
            metric_cols2 = ["score_instruction", "score_coherence", "score_grammar", "score_accuracy", "final_score"]
            if use_llm_judge:
                metric_cols2.append("score_llm_consistency")
            fig, ax = plt.subplots(figsize=(7, 6))
            sns.heatmap(metrics_df[metric_cols2].corr(), annot=True, fmt=".2f", cmap="coolwarm", ax=ax)
            ax.set_title("Metric correlations")
            fnameh = f"/tmp/{uuid.uuid4().hex}_heatmap.png"
            fig.savefig(fnameh, bbox_inches="tight")
            plt.close(fig)
            images.append((fnameh, "Metric correlations"))
        except Exception:
            pass

    # Leaderboard dataframe (per agent+task average)
    leaderboard_df = metrics_df.groupby(["Agent", "Task"])["final_score"].mean().reset_index().sort_values("final_score", ascending=False)

    return metrics_df, images, leaderboard_df
|