Spaces:
Sleeping
Sleeping
File size: 2,303 Bytes
"""
MediGuard AI — Grade Documents Node
Uses the LLM to judge whether each retrieved document is relevant to the query.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from src.services.agents.prompts import GRADING_SYSTEM
logger = logging.getLogger(__name__)
def _parse_grading_verdict(raw: str) -> bool:
    """Turn the grader LLM's raw reply into a relevance verdict.

    The reply is expected to be a JSON object carrying a ``relevant`` key,
    optionally wrapped in a Markdown code fence (e.g. a ```json fence). The
    document counts as relevant only when ``str(value).lower() == "true"``;
    a missing key counts as not relevant.

    Raises:
        json.JSONDecodeError: if the remaining payload is not valid JSON.
        AttributeError: if the reply is not a string, or the JSON is not an
            object (has no ``.get``).
    """
    content = raw.strip()
    if "```" in content:
        # Keep the text after the first fence, up to the next fence (if any).
        content = content.split("```")[1].strip()
    if content.startswith("json"):
        # Drop a leftover ```json language tag.
        content = content[4:].strip()
    data = json.loads(content)
    return str(data.get("relevant", "false")).lower() == "true"


def grade_documents_node(state: dict, *, context: Any) -> dict:
    """Grade each retrieved document for relevance to the query using the LLM.

    Reads from ``state``: ``rewritten_query`` (falling back to ``query``),
    ``retrieved_documents``, ``retrieval_attempts`` (default 1) and
    ``max_retrieval_attempts`` (default 2).

    Args:
        state: The graph state mapping.
        context: Object exposing ``llm`` (whose ``invoke(messages)`` returns an
            object with a ``content`` attribute) and an optional ``tracer``
            (with ``trace(name=..., metadata=...)``).

    Returns:
        A partial state update with:
          - ``grading_results``: one ``{"doc_id", "relevant"}`` entry per document;
          - ``relevant_documents``: documents judged relevant, in input order;
          - ``needs_rewrite``: True when fewer than 2 documents are relevant
            AND retrieval attempts remain (this also holds when nothing was
            retrieved, so an empty retrieval cannot request rewrites forever).

    If grading a document fails for any reason (LLM error, unparseable
    reply), the document is kept as relevant and a warning is logged.
    """
    query = state.get("rewritten_query") or state.get("query", "")
    documents = state.get("retrieved_documents") or []

    tracer = getattr(context, "tracer", None)
    if tracer:
        tracer.trace(name="grade_documents_node", metadata={"query": query})

    # Only the first N characters of each document are sent to the grader.
    max_doc_chars = 2000
    relevant: list[dict] = []
    grading_results: list[dict] = []
    for doc in documents:
        doc_id = doc.get("id", doc.get("_id"))
        # `or ""` also covers a text field that is present but None, which
        # previously crashed the slice below outside the try block.
        text = doc.get("content") or doc.get("text") or ""
        user_msg = f"Query: {query}\n\nDocument:\n{text[:max_doc_chars]}"
        try:
            response = context.llm.invoke(
                [
                    {"role": "system", "content": GRADING_SYSTEM},
                    {"role": "user", "content": user_msg},
                ]
            )
            is_relevant = _parse_grading_verdict(response.content)
        except Exception as exc:  # deliberate best-effort: never drop a doc on grader failure
            logger.warning("Grading LLM failed for doc %s: %s — marking relevant", doc_id, exc)
            is_relevant = True  # benefit of the doubt
        grading_results.append({"doc_id": doc_id, "relevant": is_relevant})
        if is_relevant:
            relevant.append(doc)

    attempts = state.get("retrieval_attempts", 1)
    max_attempts = state.get("max_retrieval_attempts", 2)
    # Request a rewrite only while the retry budget lasts, including when no
    # documents were retrieved at all.
    needs_rewrite = len(relevant) < 2 and attempts < max_attempts
    return {
        "grading_results": grading_results,
        "relevant_documents": relevant,
        "needs_rewrite": needs_rewrite,
    }
|