Spaces:
Sleeping
Sleeping
File size: 7,240 Bytes
aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a aa15689 e45a23a ac37ae3 e45a23a aa15689 e45a23a aa15689 7523ee1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 |
import gradio as gr
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from sentence_transformers import CrossEncoder
import torch
from rank_bm25 import BM25Okapi
import string
import os
import sys
import numpy as np # Needed for normalization
# --- 1. SETUP & MODEL LOADING ---
# Loads the two models used throughout the app:
#   * ef       — SentenceTransformer embedding function the Chroma collection queries with
#   * reranker — cross-encoder that rescores (query, document) pairs in granular_search
print("⏳ Loading models...")
# Prefer GPU when available; both models accept a `device` argument.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")
# Embedding function attached to the Chroma collection (see get_collection below).
ef = SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-m3",
    device=device
)
# Cross-encoder reranker; runs in half precision only on GPU.
# NOTE(review): {"dtype": "float16"} assumes a transformers version whose
# from_pretrained accepts `dtype` — older releases expect `torch_dtype`;
# confirm against the pinned transformers version.
reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    device=device,
    trust_remote_code=True,
    model_kwargs={"dtype": "float16"} if device == "cuda" else {}
)
print("✅ Models loaded!")
# --- 2. LOAD PERSISTENT VECTOR DB ---
# Opens the pre-built Chroma store and binds the 'ct_data' collection to the
# embedding function loaded above. Both failure paths terminate the process:
# continuing without a valid collection would only defer the crash.
DB_PATH = "./vector_db"
if not os.path.exists(DB_PATH):
    print(f"❌ Error: The folder '{DB_PATH}' was not found.")
    # Fix: previously execution fell through here, letting PersistentClient
    # silently create an empty DB (or crash later on an undefined name).
    sys.exit(1)
else:
    print(f"wd: {os.getcwd()}")
client = chromadb.PersistentClient(path=DB_PATH)
try:
    collection = client.get_collection(name='ct_data', embedding_function=ef)
    print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
except Exception as e:
    # Broad catch is deliberate: any failure to open the collection is fatal.
    print(f"❌ Error loading collection: {e}")
    sys.exit(1)
# --- 3. BUILD IN-MEMORY INDICES (BM25) ---
# Module-level state populated once by build_indices_from_db():
bm25_index = None   # BM25Okapi index over the whole corpus; None until built
doc_id_map = {}     # corpus row position (int) -> Chroma document id
all_metadatas = {}  # Chroma document id -> {"document": text, "meta": dict}
def build_indices_from_db():
    """Snapshot the Chroma collection into the in-memory BM25 index.

    Populates the module globals: `bm25_index` (lexical index over every
    document), `doc_id_map` (corpus position -> id) and `all_metadatas`
    (id -> cached text + metadata). No-op when the collection is empty.
    """
    global bm25_index, doc_id_map, all_metadatas
    print("⏳ Fetching data to build BM25 index...")
    snapshot = collection.get()
    corpus_ids = snapshot['ids']
    corpus_docs = snapshot['documents']
    corpus_metas = snapshot['metadatas']
    if not corpus_docs:
        return
    # Same preprocessing as query time: lowercase, strip punctuation, split.
    strip_punct = str.maketrans('', '', string.punctuation)
    bm25_index = BM25Okapi(
        [text.lower().translate(strip_punct).split() for text in corpus_docs]
    )
    for position, doc_id in enumerate(corpus_ids):
        doc_id_map[position] = doc_id
    for doc_id, text, meta in zip(corpus_ids, corpus_docs, corpus_metas):
        all_metadatas[doc_id] = {"document": text, "meta": meta or {}}
    print("✅ Hybrid Index Ready.")


build_indices_from_db()
# --- 4. NEW: WEIGHTED FUSION LOGIC ---
def sigmoid(z):
    """Map a real value (or numpy array) into (0, 1) via the logistic curve."""
    return 1.0 / (1.0 + np.exp(-z))
def weighted_score_fusion(vector_results, vector_scores, bm25_results, bm25_scores, alpha=0.65):
    """Fuse dense-vector and BM25 rankings into one ordered list of ids.

    BM25 scores (unbounded) are min-max normalized into [0, 1]; vector scores
    are used as-is (assumed already similarity-like, roughly 0-1). Each id's
    fused score is ``alpha * vector + (1 - alpha) * bm25``; an id missing
    from one ranking contributes 0.0 for that component.

    Returns the union of all ids sorted by fused score, best first.
    """
    # Min-max normalize the BM25 scores; a constant score list maps to all 1.0.
    if bm25_scores:
        lo, hi = min(bm25_scores), max(bm25_scores)
        if hi == lo:
            normalized_bm25 = [1.0] * len(bm25_scores)
        else:
            span = hi - lo
            normalized_bm25 = [(s - lo) / span for s in bm25_scores]
    else:
        normalized_bm25 = []

    # Per-id score lookups for each ranker.
    dense_lookup = dict(zip(vector_results, vector_scores))
    sparse_lookup = dict(zip(bm25_results, normalized_bm25))

    # Alpha-weighted combination over the union of both candidate sets.
    fused = {
        doc_id: (alpha * dense_lookup.get(doc_id, 0.0))
                + ((1.0 - alpha) * sparse_lookup.get(doc_id, 0.0))
        for doc_id in set(vector_results) | set(bm25_results)
    }
    return sorted(fused, key=fused.get, reverse=True)
def granular_search(query: str, initial_k: int = 15, final_k: int = 3, alpha: float = 0.65):
    """Hybrid search: dense vector + BM25, weighted fusion, cross-encoder rerank.

    Args:
        query: free-text search query.
        initial_k: candidates fetched per ranker and kept after fusion.
        final_k: number of reranked results returned.
        alpha: weight of the vector score in the fusion (1 - alpha goes to BM25).

    Returns:
        ``{"data": [...], "meta": {...}}`` on success,
        ``{"data": [], "message": "No results found"}`` when nothing matched,
        or ``{"error": str}`` on any internal failure.
    """
    try:
        # A. Dense vector search — ask Chroma for distances alongside ids.
        vec_res = collection.query(query_texts=[query], n_results=initial_k, include=['documents', 'distances'])
        vector_ids = vec_res['ids'][0] if vec_res['ids'] else []
        # Chroma returns distances (lower is better); similarity = 1 - distance.
        # NOTE(review): assumes the collection uses cosine distance — an L2
        # space would need a different conversion; verify at DB build time.
        vector_dists = vec_res['distances'][0] if vec_res['distances'] else []
        vector_scores = [1 - d for d in vector_dists]

        # B. Sparse BM25 search over the in-memory index.
        bm25_ids = []
        bm25_scores = []
        if bm25_index:
            # Same preprocessing the index was built with.
            tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
            all_scores = bm25_index.get_scores(tokenized)
            top_indices = all_scores.argsort()[::-1][:initial_k]
            for i in top_indices:
                score = all_scores[i]
                if score > 0:  # skip documents with no lexical overlap
                    bm25_ids.append(doc_id_map[i])
                    bm25_scores.append(score)

        # C. Weighted fusion (alpha = vector weight); keep top K candidates.
        candidates_ids = weighted_score_fusion(
            vector_ids, vector_scores,
            bm25_ids, bm25_scores,
            alpha=alpha
        )[:initial_k]
        if not candidates_ids:
            return {"data": [], "message": "No results found"}

        # D. Fetch document text/metadata from the in-memory cache.
        docs = []
        metas = []
        for did in candidates_ids:
            item = all_metadatas.get(did)
            if item:
                docs.append(item['document'])
                metas.append(item['meta'])
        if not docs:
            # Fix: previously returned a bare {"data": []}, inconsistent with
            # the empty-result payload above; callers now always get "message".
            return {"data": [], "message": "No results found"}

        # E. Cross-encoder rerank of (query, document) pairs.
        pairs = [[query, doc] for doc in docs]
        scores = reranker.predict(pairs)

        # F. Sort by rerank score, keep top final_k, format for the JSON output.
        results = sorted(zip(scores, docs, metas), key=lambda x: x[0], reverse=True)[:final_k]
        formatted_data = []
        for score, doc, meta in results:
            formatted_data.append({
                "name": meta.get('name', 'Unknown'),
                "description": doc,
                "image_id": meta.get('image id', ''),
                "relevance_score": float(score),
                "building_type": meta.get('building_type', 'unknown')
            })
        return {
            "data": formatted_data,
            "meta": {
                "query": query,
                "count": len(formatted_data)
            }
        }
    except Exception as e:
        # API boundary: surface the failure as JSON instead of crashing the UI.
        return {"error": str(e)}
# --- 5. GRADIO UI ---
# Thin API wrapper around granular_search. The numeric inputs are hidden:
# they stay in the endpoint signature for programmatic clients but are not
# shown in the browser UI.
demo = gr.Interface(
    fn=granular_search,
    inputs=[
        gr.Textbox(label="Query", placeholder="Search..."),
        gr.Number(value=5, label="Initial K", visible=False),
        gr.Number(value=1, label="Final K", visible=False),
        gr.Number(value=0.65, label="Alpha (Vector Weight)", visible=False)  # Expose Alpha
    ],
    outputs=gr.JSON(label="Results"),
    title="Granular Search API (Weighted)",
    flagging_mode="never"
)

if __name__ == "__main__":
    # Fix: removed a stray trailing '|' artifact that made this line a syntax
    # error. queue() enables request queuing; bind 0.0.0.0 for container use.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)