Spaces:
Running
Running
| from __future__ import annotations | |
| import json | |
| import logging | |
| import re | |
| from typing import List, Tuple | |
| from .config import DIABETES_KEYWORDS | |
| logger = logging.getLogger(__name__) | |
# Lower-case phrases signalling that a question concerns people who do NOT
# have diabetes. QueryProcessor.validate_domain looks for each phrase as a
# substring of the lower-cased question.
NON_DIABETES_PATTERNS = [
    # spelling variants of "non-diabetic"
    "non-diabetic",
    "non diabetic",
    "nondiabetic",
    # explicit negations
    "without diabetes",
    "no diabetes",
    "not diabetic",
    # healthy cohorts
    "healthy individuals",
    "healthy subjects",
    "healthy patients",
    # longer forms (already covered by "non-diabetic" under substring matching)
    "non-diabetic patients",
    "non-diabetic individuals",
]
# Extra lower-case terms accepted as evidence of a diabetes question when none
# of the DIABETES_KEYWORDS match. Mostly common misspellings; note the list
# also contains the correctly spelled "metformin" and truncated stems such as
# "glucos", which (being matched as substrings) also hit the correct spelling.
DIABETES_MISSPELLINGS = [
    # "diabetes"
    "diabeties",
    "diabtes",
    "dibeties",
    "diabetis",
    "diabeets",
    "diebetes",
    "diabeetus",
    "diebeties",
    # "metformin"
    "metformn",
    "metformine",
    "metformin",
    # "insulin"
    "insuln",
    "insuline",
    # "glucose"
    "glucos",
    "glocose",
]
class QueryProcessor:
    """Handles domain validation and query expansion using an LLM.

    The ``generator`` passed to the constructor must expose
    ``generate_direct(text, max_tokens=..., is_json=...)`` returning the
    model's reply as a string.
    """

    # Wording that marks a diabetic-vs-non-diabetic comparison. "vs" is matched
    # as a whole word so that words merely containing it (e.g. "IVs") do not
    # count as a comparison.
    _COMPARISON_RE = re.compile(r"compared to|versus|\bvs\b|comparison")
    # Optional Markdown code fence around the model reply (``` or ```json).
    _FENCE_OPEN_RE = re.compile(r"^```(?:json)?\s*", re.IGNORECASE)
    _FENCE_CLOSE_RE = re.compile(r"\s*```$")

    def __init__(self, generator) -> None:
        self.generator = generator

    def validate_domain(self, question: str) -> Tuple[bool, str]:
        """Decide whether ``question`` belongs to the diabetes domain.

        Returns:
            ``(True, "")`` when the question is accepted, otherwise
            ``(False, reason)`` where ``reason`` is a user-facing message.
        """
        q_lower = question.lower()

        # Reject questions that are explicitly about non-diabetic people,
        # unless they compare them with diabetic patients.
        about_non_diabetics = any(p in q_lower for p in NON_DIABETES_PATTERNS)
        if about_non_diabetics and not self._COMPARISON_RE.search(q_lower):
            return False, (
                "This system is designed for diabetes patients only. "
                "Your question appears to be about non-diabetic patients."
            )

        # Accept on any standard keyword, then on any known misspelling.
        if any(keyword in q_lower for keyword in DIABETES_KEYWORDS):
            return True, ""
        if any(misspelling in q_lower for misspelling in DIABETES_MISSPELLINGS):
            return True, ""

        return False, (
            "This system is strict to Diabetes. "
            "Your question appears to be outside this domain."
        )

    def expand_queries(self, question: str) -> List[str]:
        """Ask the model for search-query variants of ``question``.

        Never raises: any generation or parsing failure is logged and the
        method falls back to ``[question]``.

        Returns:
            The usable query strings from the model reply, with the original
            question inserted at the front if it is not already present.
        """
        prompt = (
            "You are a medical query engineer. Given a user question about diabetes, produce 4 search query variants:\n"
            "1 BM25-optimized with MeSH terms\n"
            "1 Dense-optimized\n"
            "2 semantic variants\n\n"
            "Return as JSON array of query strings. Do NOT include Markdown formatting like ``json.\n\n"
            f"Question: '{question}'\n\n"
            "JSON Output:"
        )
        # Broad catch is deliberate: expansion is best-effort and must not
        # break the caller if the generator (remote API) or parsing fails.
        try:
            output = self._generate_with_model(prompt, is_json=True)
            queries = self._parse_query_list(output)
        except Exception as e:
            logger.warning("Error during query expansion: %s", e)
            return [question]

        if not queries:
            logger.warning(
                "Failed to parse JSON for query expansion. "
                "Returning original query. Output was: %s",
                output,
            )
            return [question]

        if question not in queries:
            queries.insert(0, question)
        print("Generated Queries:", queries)
        return queries

    @classmethod
    def _parse_query_list(cls, raw: str) -> List[str]:
        """Extract query strings from the raw model reply; ``[]`` if none are usable."""
        # Strip surrounding whitespace first so a fence preceded by a newline
        # or spaces is still recognised.
        text = cls._FENCE_OPEN_RE.sub("", raw.strip())
        text = cls._FENCE_CLOSE_RE.sub("", text).strip()

        try:
            payload = json.loads(text)
        except json.JSONDecodeError:
            return []

        # The reply may wrap the array in an object; use its first list value.
        if isinstance(payload, dict):
            payload = next(
                (value for value in payload.values() if isinstance(value, list)),
                None,
            )
        if not isinstance(payload, list):
            return []

        queries: List[str] = []
        for item in payload:
            # Items may be plain strings or objects of the form {"query": "..."}.
            if isinstance(item, dict):
                item = item.get("query")
            if isinstance(item, str) and item.strip():
                queries.append(item)
        return queries

    def _generate_with_model(self, text: str, is_json: bool = False) -> str:
        """Delegate to ``self.generator.generate_direct`` with a 300-token cap."""
        return self.generator.generate_direct(text, max_tokens=300, is_json=is_json)