# BioRAG / src / bio_rag / claim_decomposer.py
# (Hugging Face Space header: deployed by aseelflihan, commit 2a2c039 "Deploy Bio-RAG")
from __future__ import annotations
import json
import logging
import re
import sys
import os
# Add root folder to sys.path to be able to import prompts
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
# Prefer the repository-level prompt module (the repo root was appended to
# sys.path above); fall back to an inline copy so the decomposer still works
# when this file is imported outside the full project layout.
try:
    from prompts import DECOMPOSITION_PROMPT
except ImportError:
    # Fallback if import fails
    DECOMPOSITION_PROMPT = """You are an expert medical analyzer. Break down the following medical answer into a list of atomic, verifiable facts (claims).
You must inject context from the original question into every claim so it is completely self-sufficient.
RULES:
1. Each claim must be an atomic, standalone factual statement.
2. Each claim must explicitly embed the medical subject, the condition context (e.g., diabetes), and any patient constraints mentioned in the question.
3. Preserve negation: e.g., 'Metformin is NOT recommended' must remain negated.
4. Preserve uncertainty: e.g., 'Metformin may cause...' must keep 'may'.
5. Preserve conditionality: e.g., 'When kidney function is below 30...' must stay conditional.
6. Format the output as a valid JSON object with the key 'claims' containing an array of strings ONLY. Do not include markdown or explanations. NEVER output just an array directly.
7. Do NOT include any reference codes like [E1], [E2], [E3] in claims.
8. Do NOT mention study names or abstract numbers. Extract only the medical fact itself.
9. Do NOT add unnecessary filler phrases like "For a patient with no specified condition".
Original Question:
{question}
Answer to Decompose:
{answer}
JSON Output:"""
# Module-level logger; handlers and levels are configured by the application.
logger = logging.getLogger(__name__)
class ClaimDecomposer:
    """Decomposes an answer into atomic, context-injected claims using an LLM.

    The generation is delegated to ``generator`` (a wrapper around the Groq
    API — see ``_generate_with_model``); this class builds the prompt, cleans
    the raw model output, parses the expected JSON and degrades gracefully to
    a naive sentence split whenever the LLM response cannot be parsed.
    """

    # Compiled once at class-creation time instead of on every fallback call.
    _SENTENCE_SPLIT = re.compile(r"(?<=[.!?])\s+")
    # Opening markdown fence: ``` optionally followed by "json" in any case.
    # (The previous pattern missed bare ``` and uppercase "JSON" fences.)
    _FENCE_OPEN = re.compile(r"^```(?:json)?\s*", re.IGNORECASE)
    # Closing markdown fence at the end of the model output.
    _FENCE_CLOSE = re.compile(r"```\s*$")

    def __init__(self, generator) -> None:
        """Store the LLM generator used for decomposition.

        Args:
            generator: object exposing ``generate_direct(text, max_tokens, is_json)``.
        """
        self.generator = generator

    def decompose(self, question: str, answer: str) -> list[str]:
        """Break *answer* into atomic claims contextualized by *question*.

        Returns:
            A list of claim strings. On any LLM or JSON failure the method
            never raises; it falls back to a sentence split of the raw answer.
        """
        prompt = DECOMPOSITION_PROMPT.format(question=question, answer=answer)
        try:
            output = self._generate_with_model(prompt, is_json=True)
            # Strip a possible markdown code fence wrapping the JSON payload.
            cleaned_json = self._FENCE_OPEN.sub("", output)
            cleaned_json = self._FENCE_CLOSE.sub("", cleaned_json).strip()
            obj = json.loads(cleaned_json)
            # Accept both the requested {"claims": [...]} shape and a bare
            # array, which some models emit despite the prompt's rule 6.
            claims = obj.get("claims", []) if isinstance(obj, dict) else obj
            if isinstance(claims, list) and all(isinstance(c, str) for c in claims):
                return claims
            logger.warning("Failed to parse JSON for claim decomposition. Attempting simple split fallback.")
            return self._fallback_decompose(answer)
        except Exception as e:
            # Broad catch is deliberate: decomposition must never crash the
            # pipeline; the sentence-split fallback always yields something.
            logger.warning("Error during claim decomposition: %s", e)
            return self._fallback_decompose(answer)

    def _fallback_decompose(self, answer: str) -> list[str]:
        """Fallback just in case the LLM or JSON parsing fails severely.

        Splits on whitespace following sentence-ending punctuation and drops
        fragments of 10 characters or fewer (bullets, headers, stray tokens).
        """
        return [
            s.strip(" -\n\t")
            for s in self._SENTENCE_SPLIT.split(answer.strip())
            if len(s.strip()) > 10
        ]

    def _generate_with_model(self, text: str, is_json: bool = False) -> str:
        """Call the centralized Groq API generation method on the generator."""
        return self.generator.generate_direct(text, max_tokens=500, is_json=is_json)