File size: 13,896 Bytes
c7256ee
 
 
 
 
 
 
 
 
 
 
 
c27a4e3
c7256ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c27a4e3
 
 
 
c7256ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6da8267
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
import re
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from openai import OpenAI
from concurrent.futures import ThreadPoolExecutor, as_completed


# ------------------------------------------------------------------
# OpenRouter Judge Wrapper
# ------------------------------------------------------------------

class GroqJudge:
    """Chat-completion judge with ordered model fallback.

    Despite the class name, requests go to OpenRouter's OpenAI-compatible
    endpoint (see ``base_url`` below). The object exposes the
    ``.generate(prompt)`` interface that RAGEvaluator calls.
    """

    def __init__(self, api_key: str, model: str = "stepfun/step-3.5-flash:free"):
        """
        Wraps OpenRouter's chat completions to match the .generate(prompt) interface
        expected by RAGEvaluator.

        Args:
            api_key: Your OpenRouter API key (https://openrouter.ai)
            model:   OpenRouter model to try first; the fallback list is used after it
        """
        self.client = OpenAI(
            base_url="https://openrouter.ai/api/v1",
            api_key=api_key,
        )
        self.model = model

        # Fallback models in order of preference
        self.fallback_models = [
            "stepfun/step-3.5-flash:free",
            "nvidia/nemotron-3-super-120b-a12b:free",
            "z-ai/glm-4.5-air:free",
            "nvidia/nemotron-3-nano-30b-a3b:free",
            "arcee-ai/trinity-mini:free",
            "xiaomi/mimo-v2-flash",
        ]

    @staticmethod
    def _is_retryable(error: Exception) -> bool:
        """Return True when the error text suggests trying the next model.

        Matches rate limiting ("429" / "rate_limit") and any message that
        mentions "model" (e.g. model unavailable). The last check is
        deliberately broad and is kept as the original heuristic.
        """
        text = str(error).lower()
        return "429" in text or "rate_limit" in text or "model" in text

    def generate(self, prompt: str) -> str:
        """Send *prompt* as a single user message and return the stripped reply.

        The primary model is tried first, then each fallback in order
        (skipping a duplicate of the primary). A model is skipped when the
        request fails with a retryable error or when it returns no usable
        content (no choices, or a ``None`` message content).

        Raises:
            Exception: a non-retryable request error is re-raised immediately;
                if every model fails, the last recorded error is raised
                (a ValueError when the last model returned empty output).
        """
        last_error = None

        # Try primary model first, then fallbacks
        models_to_try = [self.model] + [m for m in self.fallback_models if m != self.model]

        for model_name in models_to_try:
            try:
                response = self.client.chat.completions.create(
                    model=model_name,
                    messages=[{"role": "user", "content": prompt}],
                )
            except Exception as e:
                last_error = e
                # If rate limited or model unavailable, try next model
                if self._is_retryable(e):
                    continue
                # For other errors, raise immediately
                raise

            # An empty reply is handled explicitly instead of relying on an
            # IndexError (which aborted the chain) or on the wording of an
            # exception message (which is how None content used to fall through).
            choices = getattr(response, "choices", None)
            if not choices:
                last_error = ValueError(f"Model {model_name} returned no choices")
                continue
            content = choices[0].message.content
            if content is None:
                last_error = ValueError(f"Model {model_name} returned None content")
                continue
            return content.strip()

        # If all models fail, raise the last error
        raise last_error


# ------------------------------------------------------------------
# RAG Evaluator
# ------------------------------------------------------------------

class RAGEvaluator:
    def __init__(self, judge_model: str, embedding_model, api_key: str, verbose=True):
        """
        Store the encoder and verbosity flag, and build the judge client.

        judge_model:     Model name string handed to GroqJudge as its primary model,
                         e.g. "stepfun/step-3.5-flash:free", "nvidia/nemotron-3-super-120b-a12b:free"
        embedding_model: Object kept as self.encoder; evaluate_relevancy calls its .encode(...)
        api_key:         API key forwarded to GroqJudge
        verbose:         If True, progress is printed via the internal _print_* helpers
        """
        self.verbose = verbose
        self.encoder = embedding_model
        self.judge = GroqJudge(model=judge_model, api_key=api_key)

    # ------------------------------------------------------------------
    # 1. FAITHFULNESS: Claim Extraction & Verification
    # ------------------------------------------------------------------

    def evaluate_faithfulness(self, answer: str, context_list: list[str], strict: bool = True) -> dict:
        """
        Args:
            strict: If True, verifies each claim against chunks individually
                    (more API calls but catches vague batch verdicts).
                    If False, uses single batched verification call.
        """
        if self.verbose:
            self._print_extraction_header(len(answer), strict=strict)

        # --- Step A: Extraction ---
        extraction_prompt = (
            "Extract a list of independent factual claims from the following answer.\n"
            "Rules:\n"
            "- Each claim must be specific and verifiable — include numbers, names, or concrete details where present\n"
            "- Vague claims like 'the model performs well' or 'this improves results' are NOT acceptable\n"
            "- Do NOT include claims about what the context does or does not contain\n"
            "- Do NOT include introductory text, numbering, or bullet points\n"
            "- Do NOT rephrase or merge claims\n"
            "- One claim per line only\n\n"
            f"Answer: {answer}"
        )
        raw_claims = self.judge.generate(extraction_prompt)

        # Filter out short lines, preamble, and lines ending with ':'
        claims = [
            c.strip() for c in raw_claims.split('\n')
            if len(c.strip()) > 20 and not c.strip().endswith(':')
        ]

        if not claims:
            return {"score": 0, "details": []}

        # --- Step B: Verification ---
        if strict:
            # Per-chunk: claim must be explicitly supported by at least one chunk
            # Parallelize across claims as well
            def verify_claim_wrapper(args):
                i, claim = args
                return i, self._verify_claim_against_chunks(claim, context_list)
            
            with ThreadPoolExecutor(max_workers=min(len(claims), 5)) as executor:
                futures = [executor.submit(verify_claim_wrapper, (i, claim)) for i, claim in enumerate(claims)]
                verdicts = {i: result for future in as_completed(futures) for i, result in [future.result()]}
        else:
            # Batch: all chunks joined, strict burden-of-proof prompt
            combined_context = "\n".join(context_list)
            if len(combined_context) > 6000:
                combined_context = combined_context[:6000]

            claims_formatted = "\n".join([f"{i+1}. {c}" for i, c in enumerate(claims)])

            batch_prompt = (
                f"Context:\n{combined_context}\n\n"
                f"For each claim, respond YES only if the claim is EXPLICITLY and DIRECTLY "
                f"supported by the context above. Respond NO if the claim is inferred, assumed, "
                f"or not clearly stated in the context.\n\n"
                f"Format strictly as:\n"
                f"1: YES\n"
                f"2: NO\n\n"
                f"Claims:\n{claims_formatted}"
            )
            raw_verdicts = self.judge.generate(batch_prompt)

            verdicts = {}
            for line in raw_verdicts.split('\n'):
                match = re.match(r'(\d+)\s*:\s*(YES|NO)', line.strip().upper())
                if match:
                    verdicts[int(match.group(1)) - 1] = match.group(2) == "YES"

        # --- Step C: Scoring & Details ---
        verified_count = 0
        details = []
        for i, claim in enumerate(claims):
            is_supported = verdicts.get(i, False)
            if is_supported:
                verified_count += 1
            details.append({
                "claim": claim,
                "verdict": "Supported" if is_supported else "Not Supported"
            })

        score = (verified_count / len(claims)) * 100

        if self.verbose:
            self._print_faithfulness_results(claims, details, score)

        return {"score": score, "details": details}

    def _verify_claim_against_chunks(self, claim: str, context_list: list[str]) -> bool:
        """Verify a single claim against each chunk individually. Returns True if any chunk supports it."""
        def verify_single_chunk(chunk):
            prompt = (
                f"Context:\n{chunk}\n\n"
                f"Claim: {claim}\n\n"
                f"Is this claim EXPLICITLY and DIRECTLY stated in the context above? "
                f"Do not infer or assume. Respond with YES or NO only."
            )
            result = self.judge.generate(prompt)
            return "YES" in result.upper()
        
        # Use ThreadPoolExecutor for parallel verification
        with ThreadPoolExecutor(max_workers=min(len(context_list), 5)) as executor:
            futures = [executor.submit(verify_single_chunk, chunk) for chunk in context_list]
            for future in as_completed(futures):
                if future.result():
                    return True
        return False

    # ------------------------------------------------------------------
    # 2. RELEVANCY: Alternate Query Generation
    # ------------------------------------------------------------------

    def evaluate_relevancy(self, query: str, answer: str) -> dict:
        if self.verbose:
            self._print_relevancy_header()

        # --- Step A: Generation ---
        # Explicitly ask the judge NOT to rephrase the original query
        gen_prompt = (
            f"Generate 3 distinct questions that the following answer addresses.\n"
            f"Rules:\n"
            f"- Each question must end with a '?'\n"
            f"- One question per line, no numbering or bullet points\n\n"
            f"Answer: {answer}"
        )
        raw_gen = self.judge.generate(gen_prompt)

        # Filter by length rather than just '?' presence
        gen_queries = [
            q.strip() for q in raw_gen.split('\n')
            if len(q.strip()) > 10
        ][:3]

        if not gen_queries:
            return {"score": 0, "queries": []}

        # --- Step B: Similarity (single batched encode call) ---
        try:
            all_vecs = self.encoder.encode([query] + gen_queries)
        except AttributeError:
            all_vecs = np.array([self.encoder.encode(text) for text in [query] + gen_queries])
        original_vec = all_vecs[0:1]
        generated_vecs = all_vecs[1:]

        similarities = cosine_similarity(original_vec, generated_vecs)[0]
        avg_score = float(np.mean(similarities))

        if self.verbose:
            self._print_relevancy_results(query, gen_queries, similarities, avg_score)

        return {"score": avg_score, "queries": gen_queries}

    # ------------------------------------------------------------------
    # 3. DATASET-LEVEL EVALUATION
    # ------------------------------------------------------------------

    def evaluate_dataset(self, test_cases: list[dict], strict: bool = False) -> dict:
        """
        Runs faithfulness + relevancy over a full test set and aggregates results.

        Args:
            test_cases: List of dicts, each with keys:
                        - "query":    str
                        - "answer":   str
                        - "contexts": List[str]
            strict:     If True, passes strict=True to evaluate_faithfulness
                        (per-chunk verification, more API calls, harder to pass)

        Returns:
            {
                "avg_faithfulness": float,
                "avg_relevancy":    float,
                "per_query":        List[dict]
            }
        """
        faithfulness_scores = []
        relevancy_scores = []
        per_query = []

        for i, case in enumerate(test_cases):
            if self.verbose:
                print(f"\n{'='*60}")
                print(f"Query {i+1}/{len(test_cases)}: {case['query']}")
                print('='*60)

            f_result = self.evaluate_faithfulness(case['answer'], case['contexts'], strict=strict)
            r_result = self.evaluate_relevancy(case['query'], case['answer'])

            faithfulness_scores.append(f_result['score'])
            relevancy_scores.append(r_result['score'])
            per_query.append({
                "query":       case['query'],
                "faithfulness": f_result,
                "relevancy":    r_result,
            })

        results = {
            "avg_faithfulness": float(np.mean(faithfulness_scores)),
            "avg_relevancy":    float(np.mean(relevancy_scores)),
            "per_query":        per_query,
        }

        if self.verbose:
            self._print_dataset_summary(results)

        return results

    # ------------------------------------------------------------------
    # 4. PRINT HELPERS
    # ------------------------------------------------------------------

    def _print_extraction_header(self, length, strict=False):
        mode = "strict per-chunk" if strict else "batch"
        print(f"\n[EVAL] Analyzing Faithfulness ({mode})...")
        print(f"      - Extracting claims from answer ({length} chars)")

    def _print_faithfulness_results(self, claims, details, score):
        print(f"      - Verifying {len(claims)} claims against context...")
        for i, detail in enumerate(details):
            status = "✅" if "Yes" in detail['verdict'] else "❌"
            print(f"        {status} Claim {i+1}: {detail['claim'][:75]}...")
        print(f"      🎯 Faithfulness Score: {score:.1f}%")

    def _print_relevancy_header(self):
        print(f"\n[EVAL] Analyzing Relevancy...")
        print(f"      - Generating 3 distinct questions addressed by the answer")

    def _print_relevancy_results(self, query, gen_queries, similarities, avg):
        print(f"      - Comparing to original query: '{query}'")
        for i, (q, sim) in enumerate(zip(gen_queries, similarities)):
            print(f"        Q{i+1}: {q} (Sim: {sim:.2f})")
        print(f"      🎯 Average Relevancy: {avg:.2f}")

    def _print_dataset_summary(self, results):
        print(f"\n{'='*60}")
        print(f"  DATASET EVALUATION SUMMARY")
        print(f"{'='*60}")
        print(f"  Avg Faithfulness : {results['avg_faithfulness']:.1f}%")
        print(f"  Avg Relevancy    : {results['avg_relevancy']:.2f}")
        print(f"  Queries Evaluated: {len(results['per_query'])}")
        print(f"{'='*60}")