import logging
import re

import nltk
from nltk.tokenize import RegexpTokenizer
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from keybert import KeyBERT

logger = logging.getLogger(__name__)

# --- Download NLTK resources if needed ---
# Only the stopwords corpus is probed; when it is missing, all four
# resources below are fetched in one go.
try:
    stopwords.words('english')
except LookupError:
    nltk.download('stopwords', quiet=True)
    nltk.download('punkt', quiet=True)
    nltk.download('averaged_perceptron_tagger', quiet=True)
    nltk.download('wordnet', quiet=True)

# --- Initialize tools ---
tokenizer = RegexpTokenizer(r'\w+')
lemmatizer = WordNetLemmatizer()
# English stop words, minus negations and relational words that are
# deliberately kept out of the stop-word set.
custom_stopwords = set(stopwords.words('english')) - {
    'no', 'not', 'without', 'due', 'to', 'with', 'on', 'in',
}

# --- Medical synonym expansion ---
medical_synonyms = {
    "flu": ["influenza"],
    "cold": ["common cold", "rhinitis"],
    "heart attack": ["myocardial infarction"],
    "diabetes": ["high blood sugar", "hyperglycemia"],
    "bp": ["blood pressure", "hypertension"],
    "hypertension": ["high blood pressure"],
    "asthma": ["respiratory disease"],
    "cough": ["dry cough", "wet cough"],
    "fever": ["temperature", "high fever"],
}


def expand_medical_terms(text: str) -> str:
    """Append known synonyms after each medical term found in *text*.

    Every whole-word, case-insensitive occurrence of a key of
    ``medical_synonyms`` is replaced by the key followed by all of its
    synonyms, separated by single spaces (e.g. ``"cold"`` becomes
    ``"cold common cold rhinitis"``).

    The substitution is done in a single pass over the input, so text
    inserted by one expansion is never expanded again. This avoids
    duplicated synonyms for keys whose synonyms contain the key itself
    ("cold" -> "common cold") and avoids chained expansions between keys
    ("bp" -> "hypertension" -> "high blood pressure").

    Args:
        text: The text to expand.

    Returns:
        The text with synonyms inserted after each matched term. The
        matched term itself is emitted as spelled in ``medical_synonyms``.
    """
    if not medical_synonyms:
        return text

    # Case-insensitive lookup from the matched text back to the table entry.
    lookup = {key.lower(): (key, syns) for key, syns in medical_synonyms.items()}

    # Longest keys first so a multi-word term wins over any shorter key
    # that could match at the same position. Keys are escaped so they are
    # always treated as literal text.
    alternation = "|".join(
        re.escape(key) for key in sorted(lookup, key=len, reverse=True)
    )
    pattern = re.compile(rf"\b(?:{alternation})\b", flags=re.IGNORECASE)

    def _with_synonyms(match):
        key, syns = lookup[match.group(0).lower()]
        return " ".join([key, *syns])

    return pattern.sub(_with_synonyms, text)


def preprocess_text(text: str) -> str:
    """Minimal preprocessing: lowercase, remove punctuation, collapse spaces.

    Args:
        text: Input value; it is converted with ``str()`` first.

    Returns:
        The lowercased text with every character that is neither a word
        character nor whitespace replaced by a space, runs of whitespace
        collapsed to one space, and leading/trailing whitespace removed.
    """
    text = str(text).lower()
    text = re.sub(r'[^\w\s]', ' ', text)
    text = re.sub(r'\s+', ' ', text).strip()
    return text


class QueryEnhancer:
    """
    Wrapper class to handle query enhancement with local SapBERT + KeyBERT.
    """

    def __init__(self, sentence_transformer_model):
        """
        sentence_transformer_model: the already-loaded local SapBERT SentenceTransformer
        """
        self.kw_model = KeyBERT(sentence_transformer_model)

    def extract_keywords(self, text: str, top_n: int = 5) -> list:
        """Extract the top keywords/keyphrases from *text* using KeyBERT.

        Args:
            text: The text to extract keywords from.
            top_n: Maximum number of keywords requested from KeyBERT.

        Returns:
            A list of keyword strings (unigrams and bigrams). An empty
            list is returned when no keyword model is set or when
            extraction fails; failures are logged rather than raised so
            query enhancement stays best-effort.
        """
        if not self.kw_model:
            return []
        try:
            keywords = self.kw_model.extract_keywords(
                text,
                keyphrase_ngram_range=(1, 2),
                stop_words='english',
                top_n=top_n,
            )
        except Exception:
            # Best-effort: keep the query usable without keywords, but
            # leave a trace of what went wrong instead of hiding it.
            logger.exception(
                "KeyBERT keyword extraction failed; continuing without keywords"
            )
            return []
        # KeyBERT yields (keyword, score) pairs; only the keyword is kept.
        return [kw[0] for kw in keywords]

    def enhance_query(self, user_query: str) -> str:
        """
        Full query enhancement pipeline:
        - Preprocess text
        - Expand medical synonyms
        - Extract keywords (from the raw, un-preprocessed query)
        - Return combined enhanced query string
        """
        preprocessed = preprocess_text(user_query)
        expanded = expand_medical_terms(preprocessed)
        keywords = self.extract_keywords(user_query)
        enhanced_query = f"{expanded} {' '.join(keywords)}".strip()
        return enhanced_query