File size: 2,303 Bytes
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
696f787
 
 
1e732dd
 
 
 
 
 
 
 
 
 
 
fd5543a
1e732dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fd5543a
1e732dd
 
 
fd5543a
 
 
1e732dd
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
"""
MediGuard AI — Grade Documents Node

Uses the LLM to judge whether each retrieved document is relevant to the query.
"""

from __future__ import annotations

import json
import logging
from typing import Any

from src.services.agents.prompts import GRADING_SYSTEM

logger = logging.getLogger(__name__)


def _strip_code_fence(content: str) -> str:
    """Return the body of the first ``` fenced block in *content*, else *content* as-is.

    A leading ``json`` language tag on the fence (```json) is dropped.
    """
    if "```" not in content:
        return content
    # Element [1] is the text between the first and second fence markers.
    body = content.split("```")[1].strip()
    if body.startswith("json"):
        body = body[4:].strip()
    return body


def _parse_relevance(raw: str) -> bool:
    """Parse the grader LLM's JSON reply and return its ``relevant`` flag.

    Accepts a JSON bool or the string ``"true"`` (case-insensitive); anything
    else, including a missing key, counts as not relevant.

    Raises:
        ValueError: if the reply is not valid JSON.
        AttributeError: if the JSON payload is not an object.
    """
    data = json.loads(_strip_code_fence(raw.strip()))
    # str(True) == "True", so bools and "true"/"True" strings are both accepted.
    return str(data.get("relevant", "false")).lower() == "true"


def grade_documents_node(
    state: dict,
    *,
    context: Any,
    min_relevant: int = 2,
    max_doc_chars: int = 2000,
) -> dict:
    """Grade each retrieved document for relevance to the (possibly rewritten) query.

    Reads from *state*: ``rewritten_query`` (preferred) or ``query``,
    ``retrieved_documents``, ``retrieval_attempts`` (default 1) and
    ``max_retrieval_attempts`` (default 2).

    Args:
        state: Graph state mapping.
        context: Object exposing ``llm`` (with ``invoke(messages)`` returning an
            object with ``.content``) and an optional ``tracer``.
        min_relevant: A rewrite is requested when fewer documents than this
            are judged relevant.
        max_doc_chars: Number of leading characters of each document sent
            to the grader.

    Returns:
        A dict with ``grading_results`` (``{"doc_id", "relevant"}`` per doc),
        ``relevant_documents`` (the docs judged relevant, in input order) and
        ``needs_rewrite``. ``needs_rewrite`` is only ever True while
        ``retrieval_attempts < max_retrieval_attempts``.
    """
    query = state.get("rewritten_query") or state.get("query", "")
    documents = state.get("retrieved_documents") or []

    if context.tracer:
        context.tracer.trace(name="grade_documents_node", metadata={"query": query})

    attempts = state.get("retrieval_attempts", 1)
    max_attempts = state.get("max_retrieval_attempts", 2)
    # The same retry cap applies to the empty and non-empty paths so the
    # graph cannot loop on rewrite forever when retrieval returns nothing.
    can_retry = attempts < max_attempts

    if not documents:
        return {
            "grading_results": [],
            "relevant_documents": [],
            "needs_rewrite": can_retry,
        }

    relevant: list[dict] = []
    grading_results: list[dict] = []

    for doc in documents:
        doc_id = doc.get("id", doc.get("_id"))
        # The trailing `or ""` also covers keys present with a None value, which
        # would otherwise make the slice below raise outside the try block.
        text = doc.get("content") or doc.get("text") or ""
        user_msg = f"Query: {query}\n\nDocument:\n{text[:max_doc_chars]}"
        try:
            response = context.llm.invoke(
                [
                    {"role": "system", "content": GRADING_SYSTEM},
                    {"role": "user", "content": user_msg},
                ]
            )
            is_relevant = _parse_relevance(response.content)
        except Exception as exc:  # best-effort: a grading failure must not drop a document
            logger.warning("Grading LLM failed for doc %s: %s — marking relevant", doc_id, exc)
            is_relevant = True  # benefit of the doubt

        grading_results.append({"doc_id": doc_id, "relevant": is_relevant})
        if is_relevant:
            relevant.append(doc)

    needs_rewrite = len(relevant) < min_relevant and can_retry

    return {
        "grading_results": grading_results,
        "relevant_documents": relevant,
        "needs_rewrite": needs_rewrite,
    }