Spaces:
Sleeping
Sleeping
| import numpy as np | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| import faiss | |
| import re | |
# --- Global caches --- module-level state populated by the init_*/set_* helpers below
tfidf_vectorizer = None  # fitted TfidfVectorizer (set by init_tfidf)
tfidf_matrix = None  # TF-IDF matrix over corpus_texts (set by init_tfidf)
corpus_texts = None  # raw corpus strings the TF-IDF model was fitted on
faiss_index = None  # FAISS IndexFlatIP used for semantic search (set by init_faiss)
embeddings_array = None # FAISS requires float32 | copy of the embeddings kept by init_faiss
description_texts = None # For exact/fuzzy match | lowercased Description column
patient_texts = None # For exact/fuzzy match | lowercased Patient column
description_norm_texts = None # Normalized (punctuation stripped) | Description column
patient_norm_texts = None # Normalized (punctuation stripped) | Patient column
def encode_question(model, user_question):
    """Embed *user_question* with *model* and return a float32 vector.

    Returns None when no model is supplied or the question is blank.
    """
    if model is None:
        return None
    if not user_question.strip():
        return None
    vectors = model.encode([user_question], show_progress_bar=False)
    return vectors[0].astype('float32')
def init_tfidf(data_texts):
    """Fit the TF-IDF artifacts used by the lexical half of hybrid_search.

    Caches the corpus, the fitted vectorizer, and the document-term matrix
    in module-level globals.
    """
    global tfidf_vectorizer, tfidf_matrix, corpus_texts
    corpus_texts = data_texts
    tfidf_vectorizer = TfidfVectorizer(
        stop_words='english',
        max_features=10000,
    )
    tfidf_matrix = tfidf_vectorizer.fit_transform(corpus_texts)
def init_faiss(embeddings):
    """Build the module-level FAISS index for fast semantic search.

    embeddings: (num_samples, dim) array, assumed already normalized so that
    inner product behaves as cosine similarity. A float32 copy is kept
    because FAISS only accepts float32.
    """
    global faiss_index, embeddings_array
    embeddings_array = embeddings.astype('float32')
    # Vectors arrive pre-normalized; renormalizing here would be wrong.
    dim = embeddings_array.shape[1]
    faiss_index = faiss.IndexFlatIP(dim)  # inner product == cosine on unit vectors
    faiss_index.add(embeddings_array)
def set_description_texts(texts):
    """
    Provide the Description column for exact/fuzzy match search.

    Materializes *texts* once before building both cached views: the previous
    implementation iterated the input twice, so a one-shot iterator/generator
    would leave the normalized list empty and misaligned with the lowercased
    list.
    """
    global description_texts, description_norm_texts
    items = list(texts)
    description_texts = [str(t).lower() for t in items]
    description_norm_texts = [preprocess_text_for_embeddings(t) for t in items]
def set_patient_texts(texts):
    """
    Provide the Patient column for exact/fuzzy match search.

    Materializes *texts* once before building both cached views: the previous
    implementation iterated the input twice, so a one-shot iterator/generator
    would leave the normalized list empty and misaligned with the lowercased
    list.
    """
    global patient_texts, patient_norm_texts
    items = list(texts)
    patient_texts = [str(t).lower() for t in items]
    patient_norm_texts = [preprocess_text_for_embeddings(t) for t in items]
def strong_recall_indices(user_query_raw: str, top_k: int = 10):
    """
    Scan the entire dataset (Description + Patient) for strong matches:
      1) Exact equality on normalized text
      2) Exact substring presence (lowercased)
      3) High-threshold fuzzy match (if rapidfuzz is installed)

    Returns a list of unique row indices, in priority order, up to top_k.

    Fix over the previous version: only a missing rapidfuzz is tolerated in
    stage 3 (ImportError); any other error now surfaces instead of being
    silently swallowed by a blanket `except Exception`.
    """
    global description_texts, patient_texts, description_norm_texts, patient_norm_texts
    if not user_query_raw:
        return []

    q_lower = str(user_query_raw).lower()
    q_norm = preprocess_text_for_embeddings(user_query_raw)

    n_desc = len(description_texts) if description_texts is not None else 0
    n_pat = len(patient_texts) if patient_texts is not None else 0
    n_rows = max(n_desc, n_pat)
    if n_rows == 0:
        return []

    seen = set()
    ordered = []

    def _take(indices):
        # Append unseen indices in priority order; True once top_k is reached.
        for idx in indices:
            if idx not in seen:
                seen.add(idx)
                ordered.append(idx)
                if len(ordered) >= top_k:
                    return True
        return False

    # 1) Exact equality on normalized text.
    equal_hits = []
    if description_norm_texts is not None:
        equal_hits += [i for i, t in enumerate(description_norm_texts) if t == q_norm]
    if patient_norm_texts is not None:
        equal_hits += [i for i, t in enumerate(patient_norm_texts) if t == q_norm]
    if _take(equal_hits):
        return ordered[:top_k]

    # 2) Exact substring presence (lowercased).
    sub_hits = []
    if description_texts is not None:
        sub_hits += [i for i, t in enumerate(description_texts) if q_lower in t]
    if patient_texts is not None:
        sub_hits += [i for i, t in enumerate(patient_texts) if q_lower in t]
    if _take(sub_hits):
        return ordered[:top_k]

    # 3) High-threshold fuzzy matches (optional dependency).
    try:
        from rapidfuzz import fuzz
    except ImportError:
        return ordered[:top_k]

    # Use partial_ratio and token_set_ratio; keep the max as the row score.
    scored = []
    for i in range(n_rows):
        s_desc = description_texts[i] if (description_texts is not None and i < n_desc) else ""
        s_pat = patient_texts[i] if (patient_texts is not None and i < n_pat) else ""
        score_desc = max(fuzz.partial_ratio(q_lower, s_desc), fuzz.token_set_ratio(q_lower, s_desc)) if s_desc else 0
        score_pat = max(fuzz.partial_ratio(q_lower, s_pat), fuzz.token_set_ratio(q_lower, s_pat)) if s_pat else 0
        score = max(score_desc, score_pat)
        if score >= 90:  # high threshold: only near-duplicates qualify
            scored.append((i, score))
    scored.sort(key=lambda pair: pair[1], reverse=True)
    _take(idx for idx, _ in scored)
    return ordered[:top_k]
def hybrid_search(
    model,
    embeddings,
    user_query_raw,
    user_query_enhanced,
    top_k=5,
    weight_semantic=0.7,
    faiss_top_candidates=256,
    use_exact_match=True,
    use_fuzzy_match=True
):
    """
    Hybrid search combining:
      1. FAISS (or brute-force numpy) semantic similarity
      2. TF-IDF similarity
      3. Optional exact / fuzzy / token-overlap boosts on Description & Patient

    Returns: array of top-k dataset indices ranked by combined score.

    Fixes over the previous version:
      - exact-substring boosts are now gated on `use_exact_match` (they used
        to apply whenever *either* flag was set);
      - the brute-force fallback clamps the candidate count, so
        `faiss_top_candidates >= len(embeddings)` no longer crashes
        np.argpartition;
      - lookups into description_texts/patient_texts are bounds-guarded;
      - only a missing rapidfuzz is tolerated (ImportError), matching
        strong_recall_indices.
    """
    global tfidf_vectorizer, tfidf_matrix, corpus_texts, faiss_index, embeddings_array, description_texts
    if model is None or embeddings is None or len(embeddings) == 0:
        return []

    # Encode the enhanced query for the semantic/TF-IDF stages.
    question_embedding = encode_question(model, user_query_enhanced)
    if question_embedding is None:
        return []

    # Never request more candidates than there are rows.
    n_rows = embeddings.shape[0]
    n_candidates = min(faiss_top_candidates, n_rows)

    # --- 1. Semantic search ---
    if faiss_index is not None:
        D, I = faiss_index.search(np.array([question_embedding]), k=n_candidates)
        top_candidates = I[0]
        semantic_sim_top = D[0]
    else:
        semantic_sim_full = np.dot(embeddings, question_embedding)
        if n_candidates < n_rows:
            top_candidates = np.argpartition(semantic_sim_full, -n_candidates)[-n_candidates:]
        else:
            top_candidates = np.arange(n_rows)
        top_candidates = top_candidates[np.argsort(semantic_sim_full[top_candidates])[::-1]]
        semantic_sim_top = semantic_sim_full[top_candidates]

    # --- 2. TF-IDF similarity over the candidate set ---
    if tfidf_vectorizer is not None and tfidf_matrix is not None:
        tfidf_vec = tfidf_vectorizer.transform([user_query_enhanced])
        tfidf_sim_top = (tfidf_matrix[top_candidates] @ tfidf_vec.T).toarray().ravel()
    else:
        tfidf_sim_top = np.zeros(len(top_candidates))

    combined_sim_top = weight_semantic * semantic_sim_top + (1 - weight_semantic) * tfidf_sim_top

    # --- 3. Optional exact + fuzzy + token-overlap boosts ---
    if use_exact_match or use_fuzzy_match:
        query_lower = user_query_raw.lower()

        def _safe(texts, i):
            # Bounds-guarded lookup: candidate indices may exceed a text list
            # built from a shorter column.
            return texts[i] if (texts is not None and i < len(texts)) else ""

        exact_desc = np.zeros(len(top_candidates))
        exact_pat = np.zeros(len(top_candidates))
        if use_exact_match:
            if description_texts is not None:
                exact_desc = np.array([1.0 if query_lower in _safe(description_texts, i) else 0.0 for i in top_candidates])
            if patient_texts is not None:
                exact_pat = np.array([1.0 if query_lower in _safe(patient_texts, i) else 0.0 for i in top_candidates])

        # Fuzzy partial ratio via rapidfuzz (graceful fallback when absent).
        fuzzy_desc = np.zeros(len(top_candidates))
        fuzzy_pat = np.zeros(len(top_candidates))
        if use_fuzzy_match:
            try:
                from rapidfuzz import fuzz
            except ImportError:
                fuzz = None
            if fuzz is not None:
                if description_texts is not None:
                    fuzzy_desc = np.array([
                        fuzz.partial_ratio(query_lower, _safe(description_texts, i)) / 100.0 for i in top_candidates
                    ])
                if patient_texts is not None:
                    fuzzy_pat = np.array([
                        fuzz.partial_ratio(query_lower, _safe(patient_texts, i)) / 100.0 for i in top_candidates
                    ])

        # Token overlap (Jaccard) as an additional weak signal.
        def _jaccard(a: str, b: str) -> float:
            sa = set(a.split())
            sb = set(b.split())
            if not sa or not sb:
                return 0.0
            union = len(sa | sb)
            return len(sa & sb) / union if union else 0.0

        token_desc = np.zeros(len(top_candidates))
        token_pat = np.zeros(len(top_candidates))
        if description_texts is not None:
            token_desc = np.array([_jaccard(query_lower, _safe(description_texts, i)) for i in top_candidates])
        if patient_texts is not None:
            token_pat = np.array([_jaccard(query_lower, _safe(patient_texts, i)) for i in top_candidates])

        # Combine boosters with gentle weights; exact match is strongest.
        booster = (0.20 * exact_desc + 0.20 * exact_pat
                   + 0.10 * fuzzy_desc + 0.10 * fuzzy_pat
                   + 0.05 * token_desc + 0.05 * token_pat)
        combined_sim_top = combined_sim_top + booster

    # --- 4. Select final top-k indices ---
    order = np.argsort(combined_sim_top)[::-1][:top_k]
    return top_candidates[order]
# --- Minimal preprocessing for embeddings ---
def preprocess_text_for_embeddings(text: str) -> str:
    """Normalize text for embedding: lowercase, drop punctuation, collapse whitespace."""
    lowered = str(text).lower()
    depunctuated = re.sub(r'[^\w\s]', ' ', lowered)
    collapsed = re.sub(r'\s+', ' ', depunctuated)
    return collapsed.strip()
# --- Minimal preprocessing for keywords ---
def preprocess_text_for_keywords(text: str) -> str:
    """Normalize text for keyword matching: lowercase, drop punctuation, collapse whitespace."""
    cleaned = re.sub(r'[^\w\s]', ' ', str(text).lower())
    return re.sub(r'\s+', ' ', cleaned).strip()