File size: 11,810 Bytes
f589dab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
"""
agents.py — Prefect-style flow (run here via asyncio.gather) orchestrating two concurrent LiteLLM tasks.

Flow topology:
  run_analysis_flow(claim, platform, rag_context)
       │
       ├── misinformation_task()  ←── Groq / mixtral-8x7b-32768
       │     verdict: green | yellow | red
       │
       └── hallucination_task()   ←── Anthropic Claude Haiku (AI platforms only)
             verdict: purple | green
       │
       └── _merge_results() → AnalysisFlowResult (higher severity wins)

Severity order: red > purple > yellow > green

Why LiteLLM as abstraction:
  - Single .completion() call works across Groq, Anthropic, OpenAI, Ollama
  - Automatic retry with provider-level fallbacks
  - No code change to swap providers
"""

from __future__ import annotations

import asyncio
import os
from dataclasses import dataclass
from typing import Literal

import litellm
import structlog
from pydantic import BaseModel, Field, field_validator
from tenacity import retry, stop_after_attempt, wait_exponential

from rag_pipeline import RagContext

# Module-level structured logger, named after this module.
log = structlog.get_logger(__name__)

# Silence LiteLLM's verbose logs unless explicitly enabled:
# verbose mode is on only when LITELLM_VERBOSE equals "true" (case-insensitive).
litellm.set_verbose = os.getenv("LITELLM_VERBOSE", "false").lower() == "true"

# ── Color severity ordering ────────────────────────────────────────────────────
# Higher number = more severe. Used to validate colors and to pick the winning
# verdict when the two agent outputs are merged.
SEVERITY: dict[str, int] = {"green": 0, "yellow": 1, "purple": 2, "red": 3}

# Closed set of verdict colors accepted by the result models.
COLOR_TYPE = Literal["green", "yellow", "red", "purple"]

# AI-interface platforms that trigger the hallucination agent
# (compared against the lower-cased platform name).
AI_PLATFORMS = {"chatgpt", "claude", "gemini", "openai", "ai_chat", "bard", "copilot"}


# ── Result models ─────────────────────────────────────────────────────────────
# Size limits shared by the Field constraints and the truncating validators
# below, so the two can never drift apart.
_VERDICT_MAX_CHARS = 120
_EXPLANATION_MAX_CHARS = 600
_SOURCES_MAX_ITEMS = 5


class AgentOutput(BaseModel):
    """Validated verdict produced by a single agent.

    LLM output is noisy, so every field is coerced *before* the declared
    constraints are checked: unknown colors become "yellow", confidence is
    parsed leniently and clamped to 0-100, over-long text is truncated and
    surplus sources are dropped. Only values that cannot be interpreted at
    all (e.g. a non-numeric confidence) fail validation.
    """

    color: COLOR_TYPE
    confidence: int = Field(ge=0, le=100)
    verdict: str = Field(max_length=_VERDICT_MAX_CHARS)
    explanation: str = Field(max_length=_EXPLANATION_MAX_CHARS)
    sources: list[str] = Field(default_factory=list, max_length=_SOURCES_MAX_ITEMS)

    @field_validator("color", mode="before")
    @classmethod
    def normalize_color(cls, v: str) -> str:
        """Coerce LLM output to valid color string."""
        v = str(v).lower().strip()
        if v not in SEVERITY:
            # Unknown label: treat as "unverified" rather than rejecting.
            return "yellow"
        return v

    @field_validator("confidence", mode="before")
    @classmethod
    def clamp_confidence(cls, v) -> int:
        """Parse a confidence value and clamp it to the 0-100 range.

        Accepts ints, floats and numeric strings such as "87", "87.5" or
        "90%". Raises ValueError (reported by pydantic as a validation
        error) when the value is not numeric.
        """
        if isinstance(v, str):
            v = v.strip().rstrip("%").strip()
        try:
            try:
                number = int(v)
            except ValueError:
                # e.g. "87.5": int() rejects decimal strings, float() accepts.
                number = int(float(v))
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"confidence is not numeric: {v!r}") from exc
        return max(0, min(100, number))

    @field_validator("verdict", mode="before")
    @classmethod
    def truncate_verdict(cls, v):
        """Cut an over-long verdict instead of rejecting the whole output."""
        return v[:_VERDICT_MAX_CHARS] if isinstance(v, str) else v

    @field_validator("explanation", mode="before")
    @classmethod
    def truncate_explanation(cls, v):
        """Cut an over-long explanation instead of rejecting the whole output."""
        return v[:_EXPLANATION_MAX_CHARS] if isinstance(v, str) else v

    @field_validator("sources", mode="before")
    @classmethod
    def limit_sources(cls, v):
        """Normalize `sources` to a list of at most _SOURCES_MAX_ITEMS items.

        None becomes an empty list, a bare string becomes a one-item list
        (empty strings are dropped), and longer sequences are truncated.
        Anything else is passed through for pydantic to validate.
        """
        if v is None:
            return []
        if isinstance(v, str):
            stripped = v.strip()
            return [stripped] if stripped else []
        if isinstance(v, (list, tuple)):
            return list(v)[:_SOURCES_MAX_ITEMS]
        return v


@dataclass
class AnalysisFlowResult:
    """Final merged verdict returned by the analysis flow.

    Carries the same five fields as the winning agent output, as a plain
    dataclass.
    """

    # Verdict color copied from the winning agent output.
    color: str
    # Confidence of the winning agent output.
    confidence: int
    # Short verdict label.
    verdict: str
    # Free-text justification for the verdict.
    explanation: str
    # Source strings attached to the winning agent output.
    sources: list[str]


# ── System prompts ─────────────────────────────────────────────────────────────
# System prompt for the misinformation agent. It pins the JSON schema the
# response must follow and the meaning of each verdict color.
MISINFORMATION_SYSTEM = """You are a veteran fact-checking analyst. Given a claim and retrieved evidence, determine whether the claim is true, misleading, or false.

You MUST output ONLY valid JSON matching this exact schema:
{"color": "green"|"yellow"|"red", "confidence": 0-100, "verdict": "<10-word label>", "explanation": "<2-3 sentences>", "sources": ["<url1>", "<url2>", "<url3>"]}

Color logic:
- green: Widely corroborated, verified, factually sound
- yellow: Breaking/unverified, weak evidence, contested but not proven false
- red: Debunked by multiple independent sources, intentional deceit, or contradicts established consensus

Base your confidence on the quality and quantity of retrieved evidence. Use only URLs from the evidence — never fabricate URLs."""

# System prompt for the hallucination agent. It restricts the verdict to
# purple (hallucination detected) or green (none detected) and asks for an
# empty `sources` list.
HALLUCINATION_SYSTEM = """You are an AI output auditor specializing in detecting LLM hallucinations. Analyze AI-generated text for:
1. Fabricated citations (URLs, paper titles, author names that don't match real publications)
2. Statistical impossibilities (numbers that cannot logically be correct)
3. Internal contradictions (statements that contradict each other in the same passage)
4. Knowledge cutoff violations (claiming events that postdate the model's training)

You MUST output ONLY valid JSON:
{"color": "purple"|"green", "confidence": 0-100, "verdict": "<10-word label>", "explanation": "<2-3 sentences describing the specific hallucination type>", "sources": []}

purple = hallucination detected with high probability
green = no hallucination detected"""


# ── LiteLLM call wrapper ───────────────────────────────────────────────────────
@retry(
    stop=stop_after_attempt(2),
    wait=wait_exponential(multiplier=0.1, min=0.1, max=1.0),
    reraise=True,
)
async def _acompletion_text(
    model: str,
    system: str,
    user_content: str,
    max_tokens: int,
) -> str | None:
    """Issue one chat completion and return the message text.

    Exceptions are deliberately NOT caught here: tenacity can only retry a
    call that raises. With reraise=True the last underlying error (rather
    than tenacity's RetryError) surfaces once both attempts have failed.
    """
    response = await litellm.acompletion(
        model=model,
        messages=[
            {"role": "system", "content": system},
            {"role": "user", "content": user_content},
        ],
        temperature=0.1,
        max_tokens=max_tokens,
        response_format={"type": "json_object"},
    )
    return response.choices[0].message.content


async def _call_llm(
    model: str,
    system: str,
    user_content: str,
    max_tokens: int = 400,
) -> str | None:
    """
    Thin async wrapper around litellm.acompletion.

    The underlying call is attempted up to twice with a short exponential
    backoff. Returns the response text, or None when every attempt failed
    (the failure is logged as "llm.call_failed"); this function itself never
    raises for provider errors.
    """
    try:
        return await _acompletion_text(
            model=model,
            system=system,
            user_content=user_content,
            max_tokens=max_tokens,
        )
    except Exception as exc:
        # Boundary handler: callers treat None as "LLM unavailable".
        log.warning("llm.call_failed", model=model, error=str(exc))
        return None


def _parse_agent_output(raw: str | None, fallback_color: COLOR_TYPE) -> AgentOutput:
    """Parse LLM JSON response with graceful fallback.

    Args:
        raw: Text returned by the LLM, or None when the call failed.
        fallback_color: Color to report when no usable output is available.

    Returns:
        The validated AgentOutput. When `raw` is None, an "Analysis
        unavailable" placeholder (confidence 40) is returned; when `raw`
        cannot be parsed or validated, a "Parse error" placeholder
        (confidence 35) is returned. Both placeholders use `fallback_color`.
        This function never raises for malformed LLM output.
    """
    import json
    import re

    if raw is None:
        return AgentOutput(
            color=fallback_color,
            confidence=40,
            verdict="Analysis unavailable",
            explanation="LLM service temporarily unavailable. Result based on heuristics only.",
            sources=[],
        )
    try:
        # Strip any accidental markdown fences (```json / ```JSON / ```).
        cleaned = re.sub(r"```(?:json)?", "", raw, flags=re.IGNORECASE).strip()
        try:
            data = json.loads(cleaned)
        except json.JSONDecodeError:
            # Some providers wrap the object in prose; retry on the outermost
            # {...} span before giving up.
            start = cleaned.find("{")
            end = cleaned.rfind("}")
            if start == -1 or end <= start:
                raise
            data = json.loads(cleaned[start:end + 1])
        return AgentOutput.model_validate(data)
    except Exception as exc:
        # Boundary handler: malformed model output must not crash the flow.
        log.warning("agent.parse_error", error=str(exc), raw=raw[:200])
        return AgentOutput(
            color=fallback_color,
            confidence=35,
            verdict="Parse error",
            explanation=f"Could not parse agent response. Raw snippet: {raw[:100]}",
            sources=[],
        )


# ── Individual tasks ───────────────────────────────────────────────────────────
async def misinformation_task(
    claim_text: str,
    rag_context: RagContext,
) -> AgentOutput:
    """
    Classify a claim as green / yellow / red using the retrieved evidence.

    Model selection: Groq mixtral-8x7b-32768 when GROQ_API_KEY is set,
    otherwise openai/gpt-4o-mini. If the primary call yields nothing and a
    Groq key is present, groq/llama3-8b-8192 is tried once as a secondary.

    After parsing, two overrides driven by the RAG context are applied:
      - an active community note forces the verdict to red (confidence >= 75);
      - a trust score below 0.3 with zero corroborations downgrades a green
        verdict to yellow (confidence <= 55).

    Returns the resulting AgentOutput; when no model answers, the output is
    the "yellow" fallback from _parse_agent_output with the same overrides
    applied.
    """
    # Build concise evidence summary from top-3 RAG docs
    evidence_lines = [
        f"{i}. [{doc.domain}] (score:{doc.score:.2f}) {doc.text[:180]}\n   URL: {doc.source_url}"
        for i, doc in enumerate(rag_context.retrieved_docs[:3], 1)
    ]
    evidence_block = "\n".join(evidence_lines) if evidence_lines else "No retrieved evidence."

    user_content = (
        f"CLAIM: {claim_text}\n\n"
        f"TRUST SCORE: {rag_context.trust_score:.2f} "
        f"(community_note={rag_context.community_note}, "
        f"corroborations={rag_context.corroboration_count})\n\n"
        f"RETRIEVED EVIDENCE:\n{evidence_block}"
    )

    # Prefer Groq's Mixtral; fallback model chain
    groq_key = os.getenv("GROQ_API_KEY", "")
    model = "groq/mixtral-8x7b-32768" if groq_key else "openai/gpt-4o-mini"

    raw = await _call_llm(model=model, system=MISINFORMATION_SYSTEM, user_content=user_content)

    # If primary model fails, try secondary. `model` is rebound so the final
    # log line names the model that was actually attempted last.
    if raw is None and groq_key:
        model = "groq/llama3-8b-8192"
        raw = await _call_llm(
            model=model,
            system=MISINFORMATION_SYSTEM,
            user_content=user_content,
        )

    output = _parse_agent_output(raw, fallback_color="yellow")

    # Override: community notes are strong red signals
    if rag_context.community_note and output.color != "red":
        output.color = "red"
        output.confidence = max(output.confidence, 75)
        output.explanation = f"⚠ Active Community Note. {output.explanation}"

    # Override: low trust score combined with no corroboration → yellow floor
    if rag_context.trust_score < 0.3 and rag_context.corroboration_count == 0:
        if output.color == "green":
            output.color = "yellow"
            output.confidence = min(output.confidence, 55)

    log.info(
        "misinformation_task.done",
        color=output.color,
        confidence=output.confidence,
        model=model,
    )
    return output


async def hallucination_task(claim_text: str) -> AgentOutput:
    """
    Audit AI-generated text for hallucinations (purple) or none (green).

    Uses Claude Haiku when ANTHROPIC_API_KEY is set, otherwise
    groq/llama3-8b-8192. When the model is unavailable or its answer cannot
    be parsed, the fallback output is green: a failed audit must not be
    reported as a detected hallucination, and a green fallback never
    outranks the misinformation verdict when results are merged by severity.
    """
    anthropic_key = os.getenv("ANTHROPIC_API_KEY", "")
    model = "claude-haiku-4-5-20251001" if anthropic_key else "groq/llama3-8b-8192"

    raw = await _call_llm(
        model=model,
        system=HALLUCINATION_SYSTEM,
        user_content=f"Audit this AI-generated text for hallucinations:\n\n{claim_text}",
        max_tokens=300,
    )
    # Fail open: "purple" here would turn every provider outage or parse
    # error into a false hallucination alarm.
    output = _parse_agent_output(raw, fallback_color="green")
    log.info(
        "hallucination_task.done",
        color=output.color,
        confidence=output.confidence,
        model=model,
    )
    return output


def _merge_results(
    misinfo: AgentOutput,
    hallucination: AgentOutput | None,
) -> AnalysisFlowResult:
    """
    Combine the two agent outputs into the final flow result.

    The output whose color ranks higher in SEVERITY is copied into the
    result. The misinformation output is kept on a tie and whenever no
    hallucination output exists. Because SEVERITY ranks red above purple,
    a red misinformation verdict beats a purple hallucination verdict.
    """
    chosen = misinfo
    if hallucination is not None:
        # Only a strictly more severe hallucination verdict displaces misinfo.
        if SEVERITY[misinfo.color] < SEVERITY[hallucination.color]:
            chosen = hallucination

    return AnalysisFlowResult(
        color=chosen.color,
        confidence=chosen.confidence,
        verdict=chosen.verdict,
        explanation=chosen.explanation,
        sources=chosen.sources,
    )


# ── Main flow (replaces Prefect decorator for HF compatibility) ────────────────
async def run_analysis_flow(
    claim_text: str,
    claim_hash: str,
    platform: str,
    rag_context: RagContext,
) -> AnalysisFlowResult:
    """
    Run the agent tasks for one claim and return the merged verdict.

    The misinformation task always runs. When the lower-cased `platform` is
    one of AI_PLATFORMS, the hallucination task runs concurrently with it via
    asyncio.gather; otherwise it is skipped. The two outputs are then merged
    by severity, and a "flow.complete" event is logged with the first eight
    characters of `claim_hash`.
    """
    is_ai_platform = platform.lower() in AI_PLATFORMS

    halluc_result = None
    if not is_ai_platform:
        # Non-AI sources only need the misinformation verdict.
        misinfo_result = await misinformation_task(claim_text, rag_context)
    else:
        # Both agents are independent, so await them together.
        misinfo_result, halluc_result = await asyncio.gather(
            misinformation_task(claim_text, rag_context),
            hallucination_task(claim_text),
        )

    merged = _merge_results(misinfo_result, halluc_result)

    log.info(
        "flow.complete",
        claim_hash=claim_hash[:8],
        color=merged.color,
        confidence=merged.confidence,
        platform=platform,
        ai_platform=is_ai_platform,
    )
    return merged