from __future__ import annotations import json import logging import re from typing import List, Tuple from .config import DIABETES_KEYWORDS logger = logging.getLogger(__name__) # Negation patterns that indicate the question is about non-diabetic patients NON_DIABETES_PATTERNS = [ "non-diabetic", "non diabetic", "nondiabetic", "without diabetes", "no diabetes", "not diabetic", "healthy individuals", "healthy subjects", "healthy patients", "non-diabetic patients", "non-diabetic individuals", ] # Common misspellings of diabetes-related terms DIABETES_MISSPELLINGS = [ "diabeties", "diabtes", "dibeties", "diabetis", "diabeets", "diebetes", "diabeetus", "diebeties", "metformn", "metformine", "metformin", "insuln", "insuline", "glucos", "glocose", ] class QueryProcessor: """Handles domain validation and query expansion using LLM.""" def __init__(self, generator) -> None: self.generator = generator def validate_domain(self, question: str) -> Tuple[bool, str]: q_lower = question.lower() # Check if question is explicitly about non-diabetic patients if any(pattern in q_lower for pattern in NON_DIABETES_PATTERNS): # Still allow if question compares diabetic vs non-diabetic if not any(k in q_lower for k in ["compared to", "versus", "vs", "comparison"]): return False, ( "This system is designed for diabetes patients only. " "Your question appears to be about non-diabetic patients." ) # Check standard keywords if any(keyword in q_lower for keyword in DIABETES_KEYWORDS): return True, "" # Check common misspellings if any(misspelling in q_lower for misspelling in DIABETES_MISSPELLINGS): return True, "" return False, ( "This system is strict to Diabetes. " "Your question appears to be outside this domain." ) def expand_queries(self, question: str) -> List[str]: prompt = ( "You are a medical query engineer. Given a user question about diabetes, produce 4 search query variants:\n" "1 BM25-optimized with MeSH terms\n" "1 Dense-optimized\n" "2 semantic variants\n\n" "Return as JSON array of query strings. Do NOT include Markdown formatting like ``json.\n\n" f"Question: '{question}'\n\n" "JSON Output:" ) try: output = self._generate_with_model(prompt, is_json=True) import re cleaned_json = re.sub(r'^```[jJ]son\s*', '', output) cleaned_json = re.sub(r'```$', '', cleaned_json).strip() # Handle standard Groq response format for json try: queries = json.loads(cleaned_json) if isinstance(queries, dict): # Trying to find the array in the dict for key in queries: if isinstance(queries[key], list): queries = queries[key] break # Extract string queries if it returned a list of dicts instead of list of strings if isinstance(queries, list) and len(queries) > 0 and isinstance(queries[0], dict): queries = [q.get("query", str(q)) for q in queries if "query" in q] except json.JSONDecodeError: # Fallback pattern if JSON parse fails queries = [] if isinstance(queries, list) and all(isinstance(q, str) for q in queries): if question not in queries: queries.insert(0, question) print("Generated Queries:", queries) return queries logger.warning(f"Failed to parse JSON for query expansion. Returning original query. Output was: {output}") return [question] except Exception as e: logger.warning(f"Error during query expansion: {e}") return [question] def _generate_with_model(self, text: str, is_json: bool = False) -> str: # Calls the centralized Groq API generation method return self.generator.generate_direct(text, max_tokens=300, is_json=is_json)