File size: 9,999 Bytes
690bcb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
import numpy as np
from sklearn.feature_extraction.text import TfidfVectorizer
import faiss
import re

# --- Global caches ---
# Module-level state shared by the search functions below. Populated by
# init_tfidf / init_faiss / set_description_texts / set_patient_texts and
# read by strong_recall_indices / hybrid_search. Not thread-safe.
tfidf_vectorizer = None   # Fitted TfidfVectorizer (set by init_tfidf)
tfidf_matrix = None       # Sparse TF-IDF document-term matrix (init_tfidf)
corpus_texts = None       # Raw corpus passed to init_tfidf
faiss_index = None        # FAISS IndexFlatIP over the embeddings (init_faiss)
embeddings_array = None  # FAISS requires float32
description_texts = None  # For exact/fuzzy match
patient_texts = None      # For exact/fuzzy match
description_norm_texts = None  # Normalized (punctuation stripped)
patient_norm_texts = None      # Normalized (punctuation stripped)


def encode_question(model, user_question):
    """Encode the user's question with the embedding model.

    Args:
        model: Embedding model exposing ``encode`` (sentence-transformers style).
        user_question: Raw question string; may be None or empty.

    Returns:
        1-D float32 numpy vector, or None when the model is missing or the
        question is None/empty/whitespace-only.
    """
    # BUGFIX: guard None explicitly — `user_question.strip()` raised
    # AttributeError when a caller passed None instead of "".
    if model is None or not user_question or not user_question.strip():
        return None
    return model.encode([user_question], show_progress_bar=False)[0].astype('float32')


def init_tfidf(data_texts):
    """Build and cache the TF-IDF model used by hybrid search.

    Fits a TfidfVectorizer (English stop words, 10k feature cap) on the
    corpus and stores the vectorizer, document-term matrix and raw texts
    in module-level globals.
    """
    global tfidf_vectorizer, tfidf_matrix, corpus_texts
    corpus_texts = data_texts
    vectorizer = TfidfVectorizer(stop_words='english', max_features=10000)
    matrix = vectorizer.fit_transform(corpus_texts)
    tfidf_vectorizer = vectorizer
    tfidf_matrix = matrix


def init_faiss(embeddings):
    """Build and cache a FAISS inner-product index over *embeddings*.

    embeddings: np.array (num_samples x 768), assumed already normalized so
    that inner product equals cosine similarity — no renormalization here.
    """
    global faiss_index, embeddings_array
    # FAISS only accepts float32 vectors.
    vectors = embeddings.astype('float32')
    embeddings_array = vectors
    index = faiss.IndexFlatIP(vectors.shape[1])
    index.add(vectors)
    faiss_index = index



def set_description_texts(texts):
    """Cache the Description column for exact/fuzzy match search.

    Stores a lowercased copy and a punctuation-stripped normalized copy
    in module-level globals.
    """
    global description_texts, description_norm_texts
    lowered = []
    normalized = []
    for entry in texts:
        lowered.append(str(entry).lower())
        normalized.append(preprocess_text_for_embeddings(entry))
    description_texts = lowered
    description_norm_texts = normalized


def set_patient_texts(texts):
    """Cache the Patient column for exact/fuzzy match search.

    Stores a lowercased copy and a punctuation-stripped normalized copy
    in module-level globals.
    """
    global patient_texts, patient_norm_texts
    lowered = []
    normalized = []
    for entry in texts:
        lowered.append(str(entry).lower())
        normalized.append(preprocess_text_for_embeddings(entry))
    patient_texts = lowered
    patient_norm_texts = normalized


def strong_recall_indices(user_query_raw: str, top_k: int = 10):
    """Scan the whole dataset (Description + Patient) in priority order.

    Stages, strongest first:
    1) Exact equality on normalized text
    2) Exact substring presence (lowercased)
    3) High-threshold fuzzy match (only if rapidfuzz is importable)

    Returns a list of unique indices, priority-ordered, capped at top_k.
    """
    global description_texts, patient_texts, description_norm_texts, patient_norm_texts

    if not user_query_raw:
        return []

    q_lower = str(user_query_raw).lower()
    q_norm = preprocess_text_for_embeddings(user_query_raw)

    n_desc = len(description_texts) if description_texts is not None else 0
    n_pat = len(patient_texts) if patient_texts is not None else 0
    n_rows = max(n_desc, n_pat)
    if n_rows == 0:
        return []

    seen = set()
    ordered = []

    def absorb(candidates):
        """Append unseen indices preserving order; True once top_k reached."""
        for idx in candidates:
            if idx not in seen:
                seen.add(idx)
                ordered.append(idx)
        return len(ordered) >= top_k

    # Stage 1: exact equality on normalized text (Description first).
    equal_hits = []
    if description_norm_texts is not None:
        equal_hits.extend(i for i, t in enumerate(description_norm_texts) if t == q_norm)
    if patient_norm_texts is not None:
        equal_hits.extend(i for i, t in enumerate(patient_norm_texts) if t == q_norm)
    if absorb(equal_hits):
        return ordered[:top_k]

    # Stage 2: exact substring presence on the lowercased columns.
    substring_hits = []
    if description_texts is not None:
        substring_hits.extend(i for i, t in enumerate(description_texts) if q_lower in t)
    if patient_texts is not None:
        substring_hits.extend(i for i, t in enumerate(patient_texts) if q_lower in t)
    if absorb(substring_hits):
        return ordered[:top_k]

    # Stage 3: high-threshold fuzzy matches; silently skipped when
    # rapidfuzz is not installed (best-effort recall only).
    try:
        from rapidfuzz import fuzz

        def best_score(text):
            # Max of partial_ratio and token_set_ratio; 0 for empty text.
            if not text:
                return 0
            return max(fuzz.partial_ratio(q_lower, text), fuzz.token_set_ratio(q_lower, text))

        scored = []
        for i in range(n_rows):
            desc = description_texts[i] if (description_texts is not None and i < len(description_texts)) else ""
            pat = patient_texts[i] if (patient_texts is not None and i < len(patient_texts)) else ""
            score = max(best_score(desc), best_score(pat))
            if score >= 90:
                scored.append((i, score))
        scored.sort(key=lambda pair: pair[1], reverse=True)
        absorb(idx for idx, _ in scored)
    except Exception:
        pass

    return ordered[:top_k]


def hybrid_search(
    model,
    embeddings,
    user_query_raw,
    user_query_enhanced,
    top_k=5,
    weight_semantic=0.7,
    faiss_top_candidates=256,
    use_exact_match=True,
    use_fuzzy_match=True
):
    """
    Hybrid search combining:
    1. FAISS semantic similarity (numpy dot-product fallback without index)
    2. TF-IDF boosting
    3. Optional exact / fuzzy / token-overlap boosts on Description & Patient

    Args:
        model: Embedding model with an ``encode`` method.
        embeddings: (num_samples x dim) matrix of normalized embeddings.
        user_query_raw: Original user text, used for literal matching.
        user_query_enhanced: Preprocessed query for semantic/TF-IDF stages.
        top_k: Number of indices to return.
        weight_semantic: Blend weight between semantic and TF-IDF scores.
        faiss_top_candidates: Candidate pool size before re-ranking.
        use_exact_match: Apply exact-substring boosts when True.
        use_fuzzy_match: Apply rapidfuzz partial-ratio boosts when True.

    Returns:
        Array of top indices into the dataset, or [] when inputs are unusable.
    """
    global tfidf_vectorizer, tfidf_matrix, corpus_texts, faiss_index, embeddings_array, description_texts

    if model is None or embeddings is None or len(embeddings) == 0:
        return []

    # Encode enhanced query for semantic/TF-IDF stages
    question_embedding = encode_question(model, user_query_enhanced)
    if question_embedding is None:
        return []

    n_total = embeddings.shape[0]
    n_candidates = min(faiss_top_candidates, n_total)

    # --- 1. FAISS semantic search ---
    if faiss_index is not None:
        D, I = faiss_index.search(np.array([question_embedding]), k=n_candidates)
        top_candidates = I[0]
        semantic_sim_top = D[0]
    else:
        semantic_sim_full = np.dot(embeddings, question_embedding)
        # BUGFIX: clamp the partition size — np.argpartition raised
        # ValueError when faiss_top_candidates exceeded the row count
        # (the FAISS branch clamped via min(); this fallback did not).
        if n_candidates < n_total:
            top_candidates = np.argpartition(semantic_sim_full, -n_candidates)[-n_candidates:]
        else:
            top_candidates = np.arange(n_total)
        top_candidates = top_candidates[np.argsort(semantic_sim_full[top_candidates])[::-1]]
        semantic_sim_top = semantic_sim_full[top_candidates]

    # --- 2. TF-IDF similarity over the candidate pool ---
    if tfidf_vectorizer is not None and tfidf_matrix is not None:
        tfidf_vec = tfidf_vectorizer.transform([user_query_enhanced])
        tfidf_sim_top = (tfidf_matrix[top_candidates] @ tfidf_vec.T).toarray().ravel()
    else:
        tfidf_sim_top = np.zeros(len(top_candidates))

    # --- 3. Optional exact + fuzzy match across Description & Patient ---
    combined_sim_top = weight_semantic * semantic_sim_top + (1 - weight_semantic) * tfidf_sim_top

    if use_exact_match or use_fuzzy_match:
        # str() guard keeps parity with strong_recall_indices, which also
        # tolerates non-string queries.
        query_lower = str(user_query_raw).lower()

        # Exact substring presence boosts
        exact_desc = np.zeros(len(top_candidates))
        exact_pat = np.zeros(len(top_candidates))
        # BUGFIX: exact boosts were previously applied even when
        # use_exact_match=False; honor the flag like use_fuzzy_match.
        if use_exact_match:
            if description_texts is not None:
                exact_desc = np.array([1.0 if query_lower in description_texts[i] else 0.0 for i in top_candidates])
            if patient_texts is not None:
                exact_pat = np.array([1.0 if query_lower in patient_texts[i] else 0.0 for i in top_candidates])

        # Fuzzy partial ratio via rapidfuzz (graceful fallback when missing)
        fuzzy_desc = np.zeros(len(top_candidates))
        fuzzy_pat = np.zeros(len(top_candidates))
        if use_fuzzy_match:
            try:
                from rapidfuzz import fuzz
                if description_texts is not None:
                    fuzzy_desc = np.array([
                        fuzz.partial_ratio(query_lower, description_texts[i]) / 100.0 for i in top_candidates
                    ])
                if patient_texts is not None:
                    fuzzy_pat = np.array([
                        fuzz.partial_ratio(query_lower, patient_texts[i]) / 100.0 for i in top_candidates
                    ])
            except Exception:
                pass

        # Token overlap (Jaccard) as an additional weak signal
        def jaccard(a: str, b: str) -> float:
            sa = set(a.split())
            sb = set(b.split())
            if not sa or not sb:
                return 0.0
            union = len(sa | sb)
            return len(sa & sb) / union if union else 0.0

        token_desc = np.zeros(len(top_candidates))
        token_pat = np.zeros(len(top_candidates))
        if description_texts is not None:
            token_desc = np.array([jaccard(query_lower, description_texts[i]) for i in top_candidates])
        if patient_texts is not None:
            token_pat = np.array([jaccard(query_lower, patient_texts[i]) for i in top_candidates])

        # Combine boosters with gentle weights; exact match is strongest
        booster = 0.20 * exact_desc + 0.20 * exact_pat + 0.10 * fuzzy_desc + 0.10 * fuzzy_pat + 0.05 * token_desc + 0.05 * token_pat
        combined_sim_top = combined_sim_top + booster

    # --- 4. Select final top-k indices ---
    return top_candidates[np.argsort(combined_sim_top)[::-1][:top_k]]


# --- Minimal preprocessing for embeddings ---
# --- Minimal preprocessing for embeddings ---
def preprocess_text_for_embeddings(text: str) -> str:
    """Normalize for embeddings: lowercase, drop punctuation, collapse spaces."""
    cleaned = re.sub(r'[^\w\s]', ' ', str(text).lower())
    return re.sub(r'\s+', ' ', cleaned).strip()


# --- Minimal preprocessing for keywords ---
# --- Minimal preprocessing for keywords ---
def preprocess_text_for_keywords(text: str) -> str:
    """Normalize for keywords: lowercase, drop punctuation, collapse spaces."""
    cleaned = re.sub(r'[^\w\s]', ' ', str(text).lower())
    return re.sub(r'\s+', ' ', cleaned).strip()