BioRAG / src /bio_rag /query_processor.py
aseelflihan's picture
Deploy Bio-RAG
2a2c039
from __future__ import annotations
import json
import logging
import re
from typing import List, Tuple
from .config import DIABETES_KEYWORDS
logger = logging.getLogger(__name__)
# Phrases signalling that a question concerns people who do NOT have diabetes.
# QueryProcessor.validate_domain looks for these as substrings of the
# lower-cased question.
NON_DIABETES_PATTERNS = [
    # Spelling / hyphenation variants of "non-diabetic"
    "non-diabetic",
    "non diabetic",
    "nondiabetic",
    # Explicit negations
    "without diabetes",
    "no diabetes",
    "not diabetic",
    # Healthy cohorts
    "healthy individuals",
    "healthy subjects",
    "healthy patients",
    # Longer forms (already covered by "non-diabetic" under substring matching)
    "non-diabetic patients",
    "non-diabetic individuals",
]

# Frequently seen misspellings of diabetes-related vocabulary, accepted as
# in-domain by QueryProcessor.validate_domain (also via substring matching).
# NOTE(review): "metformin" is the correct spelling and "glucos" also matches
# the correctly spelled "glucose" -- confirm whether that is intentional.
DIABETES_MISSPELLINGS = [
    # diabetes
    "diabeties",
    "diabtes",
    "dibeties",
    "diabetis",
    "diabeets",
    "diebetes",
    "diabeetus",
    "diebeties",
    # metformin
    "metformn",
    "metformine",
    "metformin",
    # insulin
    "insuln",
    "insuline",
    # glucose
    "glucos",
    "glocose",
]
class QueryProcessor:
    """Handles domain validation and query expansion using LLM.

    The processor gates questions to the diabetes domain with simple
    substring/keyword checks and asks the injected ``generator`` to produce
    alternative search queries for an accepted question.
    """

    # Cues that a question compares diabetic with non-diabetic people. "vs" is
    # matched as a whole word so unrelated words containing "vs" (e.g. "TVs")
    # do not bypass the non-diabetic rejection.
    _COMPARISON_CUE = re.compile(r"compared to|versus|\bvs\b|comparison")
    # Opening Markdown code fence with an optional language tag (```json, ```).
    _FENCE_OPEN = re.compile(r"^```[A-Za-z0-9_+-]*\s*")
    # Closing Markdown code fence.
    _FENCE_CLOSE = re.compile(r"\s*```$")

    def __init__(self, generator) -> None:
        """Store the generator used for LLM calls.

        Args:
            generator: Object exposing
                ``generate_direct(text, max_tokens=..., is_json=...)``.
        """
        self.generator = generator

    def validate_domain(self, question: str) -> Tuple[bool, str]:
        """Decide whether ``question`` belongs to the diabetes domain.

        Args:
            question: Raw user question.

        Returns:
            ``(True, "")`` when the question is accepted, otherwise
            ``(False, reason)`` with a user-facing explanation.
        """
        q_lower = question.lower()

        # Reject questions that are explicitly about non-diabetic people,
        # unless they compare diabetic with non-diabetic groups.
        about_non_diabetics = any(
            pattern in q_lower for pattern in NON_DIABETES_PATTERNS
        )
        if about_non_diabetics and not self._COMPARISON_CUE.search(q_lower):
            return False, (
                "This system is designed for diabetes patients only. "
                "Your question appears to be about non-diabetic patients."
            )

        # Standard domain keywords.
        if any(keyword in q_lower for keyword in DIABETES_KEYWORDS):
            return True, ""

        # Common misspellings of domain terms.
        if any(misspelling in q_lower for misspelling in DIABETES_MISSPELLINGS):
            return True, ""

        return False, (
            "This system is strict to Diabetes. "
            "Your question appears to be outside this domain."
        )

    def expand_queries(self, question: str) -> List[str]:
        """Ask the LLM for search-query variants of ``question``.

        This is best-effort: any generation or parsing failure is logged and
        the original question is returned on its own.

        Args:
            question: The validated user question.

        Returns:
            A list of query strings that always contains ``question``; it is
            inserted at the front when the model did not echo it back.
        """
        prompt = (
            "You are a medical query engineer. Given a user question about diabetes, produce 4 search query variants:\n"
            "1 BM25-optimized with MeSH terms\n"
            "1 Dense-optimized\n"
            "2 semantic variants\n\n"
            "Return as JSON array of query strings. Do NOT include Markdown formatting like ``json.\n\n"
            f"Question: '{question}'\n\n"
            "JSON Output:"
        )
        try:
            output = self._generate_with_model(prompt, is_json=True)
            queries = self._parse_queries(output)
        except Exception as e:
            # Deliberately broad: expansion must never break the pipeline.
            logger.warning("Error during query expansion: %s", e)
            return [question]

        if not queries:
            logger.warning(
                "Failed to parse JSON for query expansion. "
                "Returning original query. Output was: %s",
                output,
            )
            return [question]

        if question not in queries:
            queries.insert(0, question)
        logger.debug("Generated queries: %s", queries)
        return queries

    @classmethod
    def _parse_queries(cls, raw) -> List[str]:
        """Extract query strings from the model's raw reply.

        Accepts a JSON array of strings, a JSON array of objects carrying a
        ``"query"`` field, or a JSON object wrapping such an array (the first
        list-valued entry is used). Surrounding whitespace and Markdown code
        fences are removed before parsing.

        Returns:
            The usable, non-blank query strings in reply order; an empty list
            when nothing could be parsed.
        """
        if not isinstance(raw, str):
            return []

        # Strip whitespace first so the ^ / $ anchored fence patterns apply.
        text = cls._FENCE_OPEN.sub("", raw.strip())
        text = cls._FENCE_CLOSE.sub("", text).strip()

        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            return []

        # JSON-mode replies may wrap the array in an object.
        if isinstance(payload, dict):
            payload = next(
                (value for value in payload.values() if isinstance(value, list)),
                None,
            )
        if not isinstance(payload, list):
            return []

        queries: List[str] = []
        for item in payload:
            # Handle each item on its own so mixed lists do not crash.
            if isinstance(item, dict):
                item = item.get("query")
            if isinstance(item, str) and item.strip():
                queries.append(item)
        return queries

    def _generate_with_model(self, text: str, is_json: bool = False) -> str:
        """Send ``text`` to the generator and return its reply.

        Delegates to ``self.generator.generate_direct`` (the centralized Groq
        API generation method, per the original author's note) with a fixed
        budget of 300 tokens.
        """
        return self.generator.generate_direct(text, max_tokens=300, is_json=is_json)