# Hugging Face Spaces page banner ("Spaces: Sleeping") captured with the source;
# kept as a comment so the file parses as Python.
| import gradio as gr | |
| import chromadb | |
| from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction | |
| from sentence_transformers import CrossEncoder | |
| import torch | |
| from rank_bm25 import BM25Okapi | |
| import string | |
| import os | |
| import sys | |
| import numpy as np # Needed for normalization | |
# --- 1. SETUP & MODEL LOADING ---
print("⏳ Loading models...")

# Prefer GPU when available; both models accept a device string.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

# Dense embedding function handed to Chroma for query encoding.
ef = SentenceTransformerEmbeddingFunction(model_name="BAAI/bge-m3", device=device)

# Cross-encoder reranker; request half precision only on GPU.
_reranker_model_kwargs = {"dtype": "float16"} if device == "cuda" else {}
reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    device=device,
    trust_remote_code=True,
    model_kwargs=_reranker_model_kwargs,
)
print("✅ Models loaded!")
# --- 2. LOAD PERSISTENT VECTOR DB ---
DB_PATH = "./vector_db"

# Bug fix: the original only printed an error when the DB folder was missing and
# then fell through, crashing later with a NameError on `collection`. Exit early
# instead, mirroring the error handling of the collection-load failure below.
if not os.path.exists(DB_PATH):
    print(f"❌ Error: The folder '{DB_PATH}' was not found.")
    sys.exit(1)

print(f"wd: {os.getcwd()}")
client = chromadb.PersistentClient(path=DB_PATH)
try:
    # Attach the same embedding function used at index time so queries encode consistently.
    collection = client.get_collection(name='ct_data', embedding_function=ef)
    print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
except Exception as e:
    print(f"❌ Error loading collection: {e}")
    sys.exit(1)
# --- 3. BUILD IN-MEMORY INDICES (BM25) ---
bm25_index = None      # BM25Okapi over the whole corpus (None until built)
doc_id_map = {}        # corpus position -> Chroma document id
all_metadatas = {}     # document id -> {"document": text, "meta": metadata dict}


def build_indices_from_db():
    """Fetch every document from Chroma and build the BM25 index plus lookup caches."""
    global bm25_index, doc_id_map, all_metadatas
    print("⏳ Fetching data to build BM25 index...")
    data = collection.get()
    ids, documents, metadatas = data['ids'], data['documents'], data['metadatas']
    if not documents:
        return
    # Tokenize: lowercase, strip punctuation, split on whitespace.
    strip_punct = str.maketrans('', '', string.punctuation)
    bm25_index = BM25Okapi([text.lower().translate(strip_punct).split() for text in documents])
    for position, (doc_id, text, meta) in enumerate(zip(ids, documents, metadatas)):
        doc_id_map[position] = doc_id
        all_metadatas[doc_id] = {"document": text, "meta": meta or {}}
    print("✅ Hybrid Index Ready.")


build_indices_from_db()
# --- 4. NEW: WEIGHTED FUSION LOGIC ---
def sigmoid(x):
    """Map *x* to (0, 1) via the logistic function. (Helper; not used by the fusion below.)"""
    exp_neg = np.exp(-x)
    return 1 / (1 + exp_neg)
def weighted_score_fusion(vector_results, vector_scores, bm25_results, bm25_scores, alpha=0.65):
    """Fuse dense and sparse rankings into one ordered id list.

    Final score per document: alpha * vector_score + (1 - alpha) * min-max
    normalized BM25 score. Vector scores are used as given (assumed already
    similarity-like, higher = better); BM25 scores are unbounded and therefore
    min-max scaled into [0, 1] first. Documents missing from one retriever get
    0.0 for that component.

    Returns the union of ids, sorted by fused score, best first.
    """
    # Min-max normalize the BM25 scores (constant lists collapse to 1.0).
    if bm25_scores:
        lo, hi = min(bm25_scores), max(bm25_scores)
        if hi == lo:
            normalized_bm25 = [1.0] * len(bm25_scores)
        else:
            span = hi - lo
            normalized_bm25 = [(s - lo) / span for s in bm25_scores]
    else:
        normalized_bm25 = []

    vec_lookup = dict(zip(vector_results, vector_scores))
    bm25_lookup = dict(zip(bm25_results, normalized_bm25))

    # Blend over the union of candidates from both retrievers.
    fused = {
        doc_id: alpha * vec_lookup.get(doc_id, 0.0) + (1.0 - alpha) * bm25_lookup.get(doc_id, 0.0)
        for doc_id in set(vector_results) | set(bm25_results)
    }
    return sorted(fused, key=fused.get, reverse=True)
def granular_search(query: str, initial_k: int = 15, final_k: int = 3, alpha: float = 0.65):
    """Hybrid retrieval: dense + BM25 candidates, weighted fusion, cross-encoder rerank.

    Args:
        query: Free-text search string.
        initial_k: Candidate pool size per retriever and after fusion.
        final_k: Number of reranked results returned.
        alpha: Weight of the dense (vector) score in the fusion (0..1).

    Returns:
        {"data": [...], "meta": {...}} on success, or {"error": str} on failure.
    """
    try:
        # A. Dense retrieval — Chroma returns distances (lower = closer).
        dense = collection.query(
            query_texts=[query],
            n_results=initial_k,
            include=['documents', 'distances'],
        )
        dense_ids = dense['ids'][0] if dense['ids'] else []
        dense_dists = dense['distances'][0] if dense['distances'] else []
        # Flip distance into a similarity-like score (higher = better).
        dense_sims = [1 - dist for dist in dense_dists]

        # B. Sparse retrieval — BM25 over the cached corpus, keep positive top-K.
        sparse_ids, sparse_scores = [], []
        if bm25_index:
            terms = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
            corpus_scores = bm25_index.get_scores(terms)
            for idx in corpus_scores.argsort()[::-1][:initial_k]:
                if corpus_scores[idx] > 0:
                    sparse_ids.append(doc_id_map[idx])
                    sparse_scores.append(corpus_scores[idx])

        # C. Weighted fusion of both candidate lists, capped at initial_k.
        fused_ids = weighted_score_fusion(
            dense_ids, dense_sims, sparse_ids, sparse_scores, alpha=alpha
        )[:initial_k]
        if not fused_ids:
            return {"data": [], "message": "No results found"}

        # D. Resolve texts and metadata from the in-memory cache.
        docs, metas = [], []
        for doc_id in fused_ids:
            cached = all_metadatas.get(doc_id)
            if cached:
                docs.append(cached['document'])
                metas.append(cached['meta'])
        if not docs:
            return {"data": []}

        # E. Cross-encoder rerank of (query, doc) pairs.
        rerank_scores = reranker.predict([[query, doc] for doc in docs])

        # F. Keep the best final_k and shape the response payload.
        top = sorted(
            zip(rerank_scores, docs, metas), key=lambda item: item[0], reverse=True
        )[:final_k]
        formatted_data = [
            {
                "name": meta.get('name', 'Unknown'),
                "description": doc,
                "image_id": meta.get('image id', ''),
                "relevance_score": float(score),
                "building_type": meta.get('building_type', 'unknown'),
            }
            for score, doc, meta in top
        ]
        return {
            "data": formatted_data,
            "meta": {
                "query": query,
                "count": len(formatted_data),
            },
        }
    except Exception as e:
        # API boundary: surface the failure as JSON instead of crashing the UI.
        return {"error": str(e)}
# --- 5. GRADIO UI ---
# Only the query box is visible; K values and alpha are hidden tuning knobs
# (still reachable through the API so callers can override the defaults).
search_inputs = [
    gr.Textbox(label="Query", placeholder="Search..."),
    gr.Number(value=5, label="Initial K", visible=False),
    gr.Number(value=1, label="Final K", visible=False),
    gr.Number(value=0.65, label="Alpha (Vector Weight)", visible=False),
]

demo = gr.Interface(
    fn=granular_search,
    inputs=search_inputs,
    outputs=gr.JSON(label="Results"),
    title="Granular Search API (Weighted)",
    flagging_mode="never",
)

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)