"""
Critic Agent

Validates generated answers for hallucinations, citation accuracy, and
factual consistency. Each LLM-based check degrades to a cheap heuristic
when the backend is unavailable, so the agent stays usable in production
RAG pipelines.

Key Features:
- Hallucination detection
- Citation verification
- Factual consistency checking
- Confidence scoring
- Actionable feedback for self-correction
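
Example (illustrative sketch; construction of `synthesis_result` and
`ranked_sources` by the sibling synthesizer/reranker modules is assumed):

    critic = CriticAgent(CriticConfig(model="llama3.2:3b"))
    report = critic.validate(synthesis_result, ranked_sources)
    if report.needs_revision:
        for tip in report.revision_suggestions:
            print(tip)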
"""

from typing import List, Optional, Dict, Any, Tuple
from pydantic import BaseModel, Field
from loguru import logger
from enum import Enum
import json
import re

try:
    import httpx
    HTTPX_AVAILABLE = True
except ImportError:
    HTTPX_AVAILABLE = False

from .synthesizer import SynthesisResult, Citation
from .reranker import RankedResult


class IssueType(str, Enum):
    """Types of validation issues."""
    HALLUCINATION = "hallucination"            # Information not in sources
    UNSUPPORTED_CLAIM = "unsupported_claim"    # Claim without citation
    INCORRECT_CITATION = "incorrect_citation"  # Citation doesn't support claim
    CONTRADICTION = "contradiction"            # Contradicts source material
    INCOMPLETE = "incomplete"                  # Missing important information
    FACTUAL_ERROR = "factual_error"            # Verifiable factual mistake


class ValidationIssue(BaseModel):
    """A single validation issue found."""
    issue_type: IssueType
    severity: float = Field(ge=0.0, le=1.0)  # 0 = minor, 1 = critical
    description: str
    problematic_text: Optional[str] = None
    suggestion: Optional[str] = None
    citation_index: Optional[int] = None


class CriticResult(BaseModel):
    """Result of answer validation."""
    is_valid: bool
    confidence: float
    issues: List[ValidationIssue]

    # Detailed scores
    hallucination_score: float = Field(ge=0.0, le=1.0)  # 0 = no hallucination
    citation_accuracy: float = Field(ge=0.0, le=1.0)
    factual_consistency: float = Field(ge=0.0, le=1.0)

    # For self-correction
    needs_revision: bool = False
    revision_suggestions: List[str] = Field(default_factory=list)


class CriticConfig(BaseModel):
    """Configuration for critic agent."""
    # LLM settings
    model: str = Field(default="llama3.2:3b")
    base_url: str = Field(default="http://localhost:11434")
    temperature: float = Field(default=0.1)

    # Validation thresholds (all scores lie in [0, 1])
    hallucination_threshold: float = Field(default=0.3)       # fail if score >= this
    citation_accuracy_threshold: float = Field(default=0.7)   # fail if accuracy < this
    overall_confidence_threshold: float = Field(default=0.6)  # fail if confidence < this

    # Validation options
    check_hallucination: bool = Field(default=True)
    check_citations: bool = Field(default=True)
    check_consistency: bool = Field(default=True)


class CriticAgent:
    """
    Validates generated answers for quality and accuracy.

    Capabilities:
    1. Hallucination detection
    2. Citation verification
    3. Factual consistency checking
    4. Actionable revision suggestions
    """

    HALLUCINATION_PROMPT = """Analyze this answer for hallucination - information NOT supported by the provided sources.

SOURCES:
{sources}

ANSWER:
{answer}

For each claim in the answer, determine if it is:
1. SUPPORTED - Directly supported by the sources
2. PARTIALLY_SUPPORTED - Somewhat supported but with additions
3. UNSUPPORTED - Not found in sources (hallucination)

Respond with JSON:
{{
    "claims": [
        {{"text": "claim text", "status": "SUPPORTED|PARTIALLY_SUPPORTED|UNSUPPORTED", "source_index": 1 or null}}
    ],
    "hallucination_score": 0.0-1.0,
    "issues": ["list of specific issues found"]
}}"""

    CITATION_PROMPT = """Verify that each citation in this answer correctly references the source material.

SOURCES:
{sources}

ANSWER WITH CITATIONS:
{answer}

For each citation [N], check if the claim it supports is actually in source N.

Respond with JSON:
{{
    "citation_checks": [
        {{"citation_index": 1, "is_accurate": true/false, "reason": "explanation"}}
    ],
    "overall_accuracy": 0.0-1.0
}}"""

    def __init__(self, config: Optional[CriticConfig] = None):
        """
        Initialize Critic Agent.

        Args:
            config: Critic configuration
        """
        self.config = config or CriticConfig()
        logger.info(f"CriticAgent initialized (model={self.config.model})")

    def validate(
        self,
        synthesis_result: SynthesisResult,
        sources: List[RankedResult],
    ) -> CriticResult:
        """
        Validate a synthesized answer.

        Args:
            synthesis_result: The generated answer with citations
            sources: Source chunks used for generation

        Returns:
            CriticResult with validation details
        """
        issues = []
        hallucination_score = 0.0
        citation_accuracy = 1.0
        factual_consistency = 1.0

        # Skip validation for abstained answers
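        # (an abstained answer asserts nothing, so there is nothing to check)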
        if synthesis_result.abstained:
            return CriticResult(
                is_valid=True,
                confidence=1.0,
                issues=[],
                hallucination_score=0.0,
                citation_accuracy=1.0,
                factual_consistency=1.0,
            )

        # Check for hallucination (LLM-based when httpx is available,
        # otherwise the entity-overlap heuristic)
        if self.config.check_hallucination:
            if HTTPX_AVAILABLE:
                h_score, h_issues = self._check_hallucination(
                    synthesis_result.answer,
                    sources,
                )
            else:
                h_score, h_issues = self._heuristic_hallucination_check(
                    synthesis_result.answer,
                    sources,
                )
            hallucination_score = h_score
            issues.extend(h_issues)

        # Check citation accuracy
        if self.config.check_citations and synthesis_result.citations:
            c_accuracy, c_issues = self._check_citations(
                synthesis_result.answer,
                synthesis_result.citations,
                sources,
            )
            citation_accuracy = c_accuracy
            issues.extend(c_issues)

        # Check factual consistency
        if self.config.check_consistency:
            f_score, f_issues = self._check_consistency(
                synthesis_result.answer,
                sources,
            )
            factual_consistency = f_score
            issues.extend(f_issues)

        # Calculate overall confidence
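        # Weights are heuristic and sum to 1.0: hallucination and citation
        # accuracy dominate; consistency is a weaker, noisier signal.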
        confidence = (
            0.4 * (1 - hallucination_score) +
            0.4 * citation_accuracy +
            0.2 * factual_consistency
        )

        # Determine if valid
        is_valid = (
            hallucination_score < self.config.hallucination_threshold and
            citation_accuracy >= self.config.citation_accuracy_threshold and
            confidence >= self.config.overall_confidence_threshold
        )

        # Generate revision suggestions if needed
        needs_revision = not is_valid and len(issues) > 0
        revision_suggestions = self._generate_revision_suggestions(issues) if needs_revision else []

        return CriticResult(
            is_valid=is_valid,
            confidence=confidence,
            issues=issues,
            hallucination_score=hallucination_score,
            citation_accuracy=citation_accuracy,
            factual_consistency=factual_consistency,
            needs_revision=needs_revision,
            revision_suggestions=revision_suggestions,
        )

    def _check_hallucination(
        self,
        answer: str,
        sources: List[RankedResult],
    ) -> Tuple[float, List[ValidationIssue]]:
        """Check for hallucination using LLM."""
        # Build source context
        source_text = self._format_sources(sources)

        prompt = self.HALLUCINATION_PROMPT.format(
            sources=source_text,
            answer=answer,
        )

        try:
            with httpx.Client(timeout=30.0) as client:
                response = client.post(
                    f"{self.config.base_url}/api/generate",
                    json={
                        "model": self.config.model,
                        "prompt": prompt,
                        "stream": False,
                        "options": {
                            "temperature": self.config.temperature,
                            "num_predict": 1024,
                        },
                    },
                )
                response.raise_for_status()
                result = response.json()

            # Parse response
            response_text = result.get("response", "")
            data = self._parse_json_response(response_text)

            hallucination_score = data.get("hallucination_score", 0.0)

            issues = []
            for claim in data.get("claims", []):
                if claim.get("status") == "UNSUPPORTED":
                    issues.append(ValidationIssue(
                        issue_type=IssueType.HALLUCINATION,
                        severity=0.8,
                        description=f"Unsupported claim: {claim.get('text', '')}",
                        problematic_text=claim.get("text"),
                        suggestion="Remove or find supporting source",
                    ))
                elif claim.get("status") == "PARTIALLY_SUPPORTED":
                    issues.append(ValidationIssue(
                        issue_type=IssueType.UNSUPPORTED_CLAIM,
                        severity=0.4,
                        description=f"Partially supported: {claim.get('text', '')}",
                        problematic_text=claim.get("text"),
                        suggestion="Verify claim against source",
                    ))

            return hallucination_score, issues

        except Exception as e:
            logger.warning(f"Hallucination check failed: {e}")
            # Fall back to heuristic check
            return self._heuristic_hallucination_check(answer, sources)

    def _heuristic_hallucination_check(
        self,
        answer: str,
        sources: List[RankedResult],
    ) -> Tuple[float, List[ValidationIssue]]:
        """Simple heuristic hallucination check."""
        # Combine all source text
        source_text = " ".join(s.text.lower() for s in sources)
        answer_lower = answer.lower()

        # Check for proper nouns/entities not in sources
        # Simple approach: look for capitalized words
        answer_words = set(re.findall(r'\b[A-Z][a-z]+\b', answer))
        source_words = set(re.findall(r'\b[A-Z][a-z]+\b', " ".join(s.text for s in sources)))

        unsupported_entities = answer_words - source_words
        # Filter out common words
        common_words = {"The", "This", "That", "However", "Therefore", "Additionally", "Based", "According"}
        unsupported_entities = unsupported_entities - common_words

        issues = []
        for entity in list(unsupported_entities)[:3]:  # Limit issues
            issues.append(ValidationIssue(
                issue_type=IssueType.HALLUCINATION,
                severity=0.5,
                description=f"Entity '{entity}' not found in sources",
                problematic_text=entity,
            ))

        # Calculate score based on unsupported entities
        if answer_words:
            score = len(unsupported_entities) / len(answer_words)
        else:
            score = 0.0

        return min(score, 1.0), issues

    def _check_citations(
        self,
        answer: str,
        citations: List[Citation],
        sources: List[RankedResult],
    ) -> Tuple[float, List[ValidationIssue]]:
        """Verify citation accuracy."""
        if not citations:
            # No citations when expected
            return 0.0, [ValidationIssue(
                issue_type=IssueType.UNSUPPORTED_CLAIM,
                severity=0.6,
                description="Answer contains no citations",
                suggestion="Add citations to support claims",
            )]

        # Build source context
        source_text = self._format_sources(sources)

        if HTTPX_AVAILABLE:
            try:
                prompt = self.CITATION_PROMPT.format(
                    sources=source_text,
                    answer=answer,
                )

                with httpx.Client(timeout=30.0) as client:
                    response = client.post(
                        f"{self.config.base_url}/api/generate",
                        json={
                            "model": self.config.model,
                            "prompt": prompt,
                            "stream": False,
                            "options": {
                                "temperature": self.config.temperature,
                                "num_predict": 512,
                            },
                        },
                    )
                    response.raise_for_status()
                    result = response.json()

                response_text = result.get("response", "")
                data = self._parse_json_response(response_text)

                accuracy = data.get("overall_accuracy", 1.0)

                issues = []
                for check in data.get("citation_checks", []):
                    if not check.get("is_accurate", True):
                        issues.append(ValidationIssue(
                            issue_type=IssueType.INCORRECT_CITATION,
                            severity=0.6,
                            description=f"Citation [{check.get('citation_index')}]: {check.get('reason', 'Inaccurate')}",
                            citation_index=check.get("citation_index"),
                            suggestion="Verify citation matches source",
                        ))

                return accuracy, issues

            except Exception as e:
                logger.warning(f"Citation check failed: {e}")

        # Fallback: basic citation presence check
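        # (structure only: verifies that [N] markers point at real sources,
        # not that the cited source actually supports the claim)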
        citation_pattern = r'\[(\d+)\]'
        used_citations = set(int(m) for m in re.findall(citation_pattern, answer))

        if not used_citations:
            # Citation objects exist but no [N] markers appear in the text.
            return 0.5, []

        # Check if citation indices are valid
        valid_indices = set(range(1, len(sources) + 1))
        invalid = used_citations - valid_indices

        issues = []
        for idx in invalid:
            issues.append(ValidationIssue(
                issue_type=IssueType.INCORRECT_CITATION,
                severity=0.7,
                description=f"Citation [{idx}] references non-existent source",
                citation_index=idx,
            ))

        # used_citations is guaranteed non-empty here (checked above)
        accuracy = 1.0 - (len(invalid) / len(used_citations))
        return accuracy, issues

    def _check_consistency(
        self,
        answer: str,
        sources: List[RankedResult],
    ) -> Tuple[float, List[ValidationIssue]]:
        """Check for internal and external consistency."""
        issues = []

        # Check for contradictory statements (simplified)
        contradictions = self._detect_contradictions(answer)
        for contradiction in contradictions:
            issues.append(ValidationIssue(
                issue_type=IssueType.CONTRADICTION,
                severity=0.7,
                description=contradiction,
            ))

        # Check for completeness (are key source points addressed?)
        # Simplified: just check answer isn't too short
        if len(answer) < 50 and len(sources) > 0:
            issues.append(ValidationIssue(
                issue_type=IssueType.INCOMPLETE,
                severity=0.4,
                description="Answer may be incomplete given available sources",
                suggestion="Expand answer to include more relevant information",
            ))

        score = 1.0 - (0.2 * len(issues))
        return max(score, 0.0), issues

    def _detect_contradictions(self, text: str) -> List[str]:
        """Placeholder contradiction detection; currently flags nothing.

        Contrastive conjunctions ("however", "but", "although") usually mark
        legitimate contrast rather than contradiction, so they are not
        flagged. A robust implementation would need NLI-style entailment
        checks between sentence pairs.
        """
        return []

    def _format_sources(self, sources: List[RankedResult]) -> str:
        """Format sources for prompt."""
        parts = []
        for i, source in enumerate(sources, 1):
            parts.append(f"[{i}] {source.text[:500]}")
        return "\n\n".join(parts)

    def _parse_json_response(self, text: str) -> Dict[str, Any]:
        """Parse JSON from LLM response."""
        try:
            json_match = re.search(r'\{[\s\S]*\}', text)
            if json_match:
                return json.loads(json_match.group())
        except json.JSONDecodeError:
            pass
        return {}

    def _generate_revision_suggestions(
        self,
        issues: List[ValidationIssue],
    ) -> List[str]:
        """Generate actionable revision suggestions."""
        suggestions = []

        for issue in issues:
            if issue.suggestion:
                suggestions.append(issue.suggestion)
            elif issue.issue_type == IssueType.HALLUCINATION:
                suggestions.append(
                    f"Remove or verify: {issue.problematic_text or 'unsupported claim'}"
                )
            elif issue.issue_type == IssueType.INCORRECT_CITATION:
                suggestions.append(
                    f"Fix citation [{issue.citation_index}] to match source"
                )

        return list(dict.fromkeys(suggestions))[:5]  # Order-preserving dedupe, max 5
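

# Minimal smoke-test sketch: exercises the heuristic hallucination check
# without an LLM backend. SimpleNamespace stands in for RankedResult, of
# which this module only reads `.text`; the example strings are made up.
if __name__ == "__main__":
    from types import SimpleNamespace

    agent = CriticAgent()
    stub_sources = [SimpleNamespace(text="Paris is the capital of France.")]
    score, found_issues = agent._heuristic_hallucination_check(
        "Paris is the capital of France. Berlin is not in the sources.",
        stub_sources,  # type: ignore[arg-type]
    )
    print(f"heuristic hallucination score: {score:.2f}")
    for issue in found_issues:
        print(f"- {issue.issue_type.value}: {issue.description}")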