File size: 2,977 Bytes
690bcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
import re
import nltk
from nltk.tokenize import RegexpTokenizer
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
from keybert import KeyBERT

# --- Download NLTK resources if needed ---
def _ensure_nltk_resources() -> None:
    """Download each required NLTK resource that is not already installed.

    Every resource is probed on its own with ``nltk.data.find``. Probing only
    ``stopwords`` (and downloading everything when it is missing) skips the
    other downloads whenever stopwords alone is already present, which leaves
    e.g. WordNet missing for ``WordNetLemmatizer``.
    """
    required = (
        ('corpora/stopwords', 'stopwords'),
        ('tokenizers/punkt', 'punkt'),
        ('taggers/averaged_perceptron_tagger', 'averaged_perceptron_tagger'),
        ('corpora/wordnet', 'wordnet'),
    )
    for resource_path, package in required:
        try:
            nltk.data.find(resource_path)
        except LookupError:
            nltk.download(package, quiet=True)


_ensure_nltk_resources()

# --- Initialize tools ---
# Words kept OUT of the stop list so they survive filtering: negations
# ('no', 'not', 'without') and a few connecting words.
custom_stopwords = set(stopwords.words('english')).difference(
    {'no', 'not', 'without', 'due', 'to', 'with', 'on', 'in'}
)
# Tokenizer that keeps runs of word characters (punctuation is dropped).
tokenizer = RegexpTokenizer(r'\w+')
lemmatizer = WordNetLemmatizer()

# --- Medical synonym expansion ---
# Lookup table used by expand_medical_terms(): each term maps to the list of
# synonyms that get inserted after it in a query.
medical_synonyms = {
    "flu": ["influenza"],
    "cold": [
        "common cold",
        "rhinitis",
    ],
    "heart attack": ["myocardial infarction"],
    "diabetes": [
        "high blood sugar",
        "hyperglycemia",
    ],
    "bp": [
        "blood pressure",
        "hypertension",
    ],
    "hypertension": ["high blood pressure"],
    "asthma": ["respiratory disease"],
    "cough": [
        "dry cough",
        "wet cough",
    ],
    "fever": [
        "temperature",
        "high fever",
    ],
}

def expand_medical_terms(text: str, synonyms: "dict[str, list[str]] | None" = None) -> str:
    """Expands known medical terms with their synonyms for better recall.

    Every whole-word, case-insensitive occurrence of a term is rewritten as
    ``"<term> <syn1> <syn2> ..."``. Matching happens in a single pass, so text
    inserted by one expansion is never re-matched. The previous approach
    (one ``re.sub`` per synonym) re-expanded its own output, e.g. turning
    "cough" into "cough wet cough dry cough wet cough" and cascading
    "bp" -> "hypertension" -> "high blood pressure".

    Args:
        text: Input text, typically already normalized by ``preprocess_text``.
        synonyms: Optional mapping of term -> list of synonyms. Defaults to
            the module-level ``medical_synonyms`` table.

    Returns:
        The text with synonyms inserted after each matched term. The matched
        term keeps its original spelling. The text is returned unchanged when
        it is empty or nothing matches.
    """
    if synonyms is None:
        synonyms = medical_synonyms
    if not text or not synonyms:
        return text

    # Lowercased keys so a case-insensitive match can look up its synonyms.
    lookup = {key.lower(): list(syns) for key, syns in synonyms.items() if key}
    if not lookup:
        return text

    # Longest terms first so multi-word terms ("heart attack") win over any
    # shorter term that is a prefix; re.escape guards regex metacharacters.
    alternation = "|".join(
        re.escape(key) for key in sorted(lookup, key=len, reverse=True)
    )
    pattern = re.compile(rf"\b(?:{alternation})\b", flags=re.IGNORECASE)

    def _append_synonyms(match):
        # Append this term's synonyms right after it; re.sub does not rescan
        # the replacement, so inserted synonyms are never expanded again.
        term = match.group(0)
        return " ".join([term, *lookup.get(term.lower(), ())])

    return pattern.sub(_append_synonyms, text)

def preprocess_text(text: str) -> str:
    """Normalize text minimally for matching.

    The input is coerced to ``str`` and lowercased. Every character that is
    neither a word character nor whitespace becomes a space. Whitespace runs
    collapse to single spaces and the ends are trimmed.
    """
    lowered = str(text).lower()
    depunctuated = re.sub(r'[^\w\s]', ' ', lowered)
    # split()/join() collapses whitespace runs and trims both ends in one go.
    return ' '.join(depunctuated.split())

class QueryEnhancer:
    """
    Wrapper class to handle query enhancement with local SapBERT + KeyBERT.

    Combines light normalization (``preprocess_text``), medical synonym
    expansion (``expand_medical_terms``) and KeyBERT keyphrase extraction
    into a single enhanced query string.
    """
    def __init__(self, sentence_transformer_model):
        """
        sentence_transformer_model: the already-loaded local SapBERT SentenceTransformer,
        passed directly to ``KeyBERT`` as its embedding model.
        """
        self.kw_model = KeyBERT(sentence_transformer_model)

    def extract_keywords(self, text: str, top_n: int = 5) -> list:
        """Extracts top keywords using KeyBERT.

        Asks KeyBERT for up to ``top_n`` keyphrases of 1-2 words, with English
        stop words removed. Extraction is best-effort: a failure is logged as
        a warning and yields an empty list, so query enhancement still
        proceeds without keywords.

        Args:
            text: Text to extract keyphrases from.
            top_n: Maximum number of keyphrases to request.

        Returns:
            The keyphrase strings in the order KeyBERT returns them. The list
            is empty when no model is set, the input is blank, or extraction
            fails.
        """
        if not self.kw_model:
            return []
        # Blank input has nothing to extract; skip the model call entirely.
        if not text or not str(text).strip():
            return []
        try:
            keywords = self.kw_model.extract_keywords(
                text,
                keyphrase_ngram_range=(1, 2),
                stop_words='english',
                top_n=top_n,
            )
            # Each result is a (keyphrase, score) pair; keep only the phrase.
            return [kw[0] for kw in keywords]
        except Exception:
            # Deliberately best-effort, but leave a trace instead of failing
            # silently so model/config problems are visible in the logs.
            import logging
            logging.getLogger(__name__).warning(
                "Keyword extraction failed; continuing without keywords",
                exc_info=True,
            )
            return []

    def enhance_query(self, user_query: str) -> str:
        """
        Full query enhancement pipeline:
        - Preprocess text (lowercase, strip punctuation, collapse spaces)
        - Expand medical synonyms on the preprocessed text
        - Extract keywords from the ORIGINAL (unpreprocessed) query
        - Return the expanded text followed by the keywords, space-separated
        """
        preprocessed = preprocess_text(user_query)
        expanded = expand_medical_terms(preprocessed)
        # KeyBERT receives the raw query; the expanded text is used only for
        # the returned string.
        keywords = self.extract_keywords(user_query)
        enhanced_query = f"{expanded} {' '.join(keywords)}".strip()
        return enhanced_query