# Source: Hugging Face Space "JustscrAPIng" — app.py (commit e45a23a, verified)
import gradio as gr
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from sentence_transformers import CrossEncoder
import torch
from rank_bm25 import BM25Okapi
import string
import os
import sys
import numpy as np # Needed for normalization
# --- 1. SETUP & MODEL LOADING ---
print("⏳ Loading models...")
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

# Embedding function used at query time; it must match the model that the
# persisted Chroma collection was originally built with.
ef = SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-m3",
    device=device
)
# Cross-encoder reranker that scores (query, document) pairs after retrieval.
reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    device=device,
    trust_remote_code=True,
    # Half precision only when a GPU is available.
    # NOTE(review): newer transformers accept "dtype", older versions expect
    # "torch_dtype" — confirm against the pinned transformers version.
    model_kwargs={"dtype": "float16"} if device == "cuda" else {}
)
print("✅ Models loaded!")
# --- 2. LOAD PERSISTENT VECTOR DB ---
DB_PATH = "./vector_db"
if not os.path.exists(DB_PATH):
    # Fail fast: previously this branch only printed and execution fell
    # through, so `collection` was never defined and the later
    # build_indices_from_db() call crashed with a NameError. Exit with the
    # same style as the collection-load failure below.
    print(f"❌ Error: The folder '{DB_PATH}' was not found.")
    sys.exit(1)
print(f"wd: {os.getcwd()}")
client = chromadb.PersistentClient(path=DB_PATH)
try:
    collection = client.get_collection(name='ct_data', embedding_function=ef)
    print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
except Exception as e:
    # Any Chroma failure (missing collection, schema mismatch) is fatal.
    print(f"❌ Error loading collection: {e}")
    sys.exit(1)
# --- 3. BUILD IN-MEMORY INDICES (BM25) ---
bm25_index = None
doc_id_map = {}
all_metadatas = {}


def build_indices_from_db():
    """Fetch the whole Chroma collection and build the lexical (BM25) index.

    Populates three module-level caches:
      * bm25_index    — BM25Okapi over the lower-cased, punctuation-stripped corpus
      * doc_id_map    — corpus position -> Chroma document id
      * all_metadatas — Chroma id -> {"document": text, "meta": metadata dict}
    """
    global bm25_index, doc_id_map, all_metadatas
    print("⏳ Fetching data to build BM25 index...")
    data = collection.get()
    ids, documents, metadatas = data['ids'], data['documents'], data['metadatas']
    if not documents:
        return
    strip_punct = str.maketrans('', '', string.punctuation)
    corpus = [text.lower().translate(strip_punct).split() for text in documents]
    bm25_index = BM25Okapi(corpus)
    for position, doc_id in enumerate(ids):
        doc_id_map[position] = doc_id
        meta = metadatas[position]
        all_metadatas[doc_id] = {
            "document": documents[position],
            "meta": meta if meta else {},
        }
    print("✅ Hybrid Index Ready.")


build_indices_from_db()
# --- 4. NEW: WEIGHTED FUSION LOGIC ---
def sigmoid(x):
    """Logistic squashing: map a real number (or numpy array) into (0, 1)."""
    return 1.0 / (1.0 + np.exp(-x))
def weighted_score_fusion(vector_results, vector_scores, bm25_results, bm25_scores, alpha=0.65):
    """Fuse vector and BM25 rankings by weighted score combination.

    Final Score = alpha * VectorScore + (1 - alpha) * NormalizedBM25

    Vector scores are assumed to already live in roughly [0, 1] (cosine
    similarity with bge-m3; an L2-distance space would need inverting first).
    BM25 scores are unbounded, so they are min-max normalized here. A document
    missing from one of the two result lists contributes 0.0 for that side.

    Returns the document ids sorted by fused score, best first.
    """
    # Min-max normalize the unbounded BM25 scores into [0, 1].
    if bm25_scores:
        lo, hi = min(bm25_scores), max(bm25_scores)
        if hi == lo:
            normalized = [1.0] * len(bm25_scores)
        else:
            span = hi - lo
            normalized = [(s - lo) / span for s in bm25_scores]
    else:
        normalized = []

    vec_lookup = dict(zip(vector_results, vector_scores))
    lex_lookup = dict(zip(bm25_results, normalized))

    # Union of candidates from both retrievers, each scored by the alpha blend.
    combined = {
        doc_id: (alpha * vec_lookup.get(doc_id, 0.0))
                + ((1.0 - alpha) * lex_lookup.get(doc_id, 0.0))
        for doc_id in set(vector_results) | set(bm25_results)
    }
    return sorted(combined, key=combined.get, reverse=True)
def granular_search(query: str, initial_k: int = 15, final_k: int = 3, alpha: float = 0.65):
    """Hybrid retrieval pipeline: vector + BM25 search, weighted fusion,
    then cross-encoder reranking.

    Args:
        query: Free-text search query.
        initial_k: Candidate pool size for each retriever and after fusion.
        final_k: Number of results returned after reranking.
        alpha: Weight of the vector score in the fusion; (1 - alpha) goes to BM25.

    Returns:
        {"data": [...], "meta": {...}} on success, {"data": [], ...} when
        nothing matches, or {"error": str} if any step raises.
    """
    try:
        # A. Vector Search (Get Scores too)
        # include=['documents', 'distances'] tells Chroma to return scores
        vec_res = collection.query(query_texts=[query], n_results=initial_k, include=['documents', 'distances'])
        vector_ids = vec_res['ids'][0] if vec_res['ids'] else []
        # Chroma returns Distances (Lower is better for L2/Cosine Distance)
        # But BGE-M3 is usually Cosine Similarity.
        # If score is Distance: Sim = 1 - Distance
        # NOTE(review): assumes the collection was built in cosine-distance
        # space — confirm against the indexing script.
        vector_dists = vec_res['distances'][0] if vec_res['distances'] else []
        vector_scores = [1 - d for d in vector_dists]  # Convert distance to similarity
        # B. BM25 Search (Get Scores too)
        bm25_ids = []
        bm25_scores = []
        if bm25_index:
            # Same tokenization as at index-build time: lowercase, strip punctuation.
            tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
            # Get all scores
            all_scores = bm25_index.get_scores(tokenized)
            # Sort top K
            top_indices = all_scores.argsort()[::-1][:initial_k]
            for i in top_indices:
                score = all_scores[i]
                if score > 0:  # skip documents with zero lexical overlap
                    bm25_ids.append(doc_id_map[i])
                    bm25_scores.append(score)
        # C. Weighted Fusion (USING ALPHA)
        candidates_ids = weighted_score_fusion(
            vector_ids, vector_scores,
            bm25_ids, bm25_scores,
            alpha=alpha
        )[:initial_k]  # Keep top K after fusion
        if not candidates_ids:
            return {"data": [], "message": "No results found"}
        # D. Fetch Text (from Cache)
        docs = []
        metas = []
        for did in candidates_ids:
            item = all_metadatas.get(did)
            if item:
                docs.append(item['document'])
                metas.append(item['meta'])
        # E. Rerank
        if not docs: return {"data": []}
        pairs = [[query, doc] for doc in docs]
        scores = reranker.predict(pairs)
        # F. Format
        results = sorted(zip(scores, docs, metas), key=lambda x: x[0], reverse=True)[:final_k]
        formatted_data = []
        for score, doc, meta in results:
            formatted_data.append({
                "name": meta.get('name', 'Unknown'),
                "description": doc,
                # NOTE(review): metadata key contains a space ('image id') —
                # verify it matches the ingestion script's key.
                "image_id": meta.get('image id', ''),
                "relevance_score": float(score),
                "building_type": meta.get('building_type', 'unknown')
            })
        return {
            "data": formatted_data,
            "meta": {
                "query": query,
                "count": len(formatted_data)
            }
        }
    except Exception as e:
        # API boundary: surface any failure as a JSON error payload instead of
        # crashing the Gradio worker.
        return {"error": str(e)}
# --- 5. GRADIO UI ---
demo = gr.Interface(
    fn=granular_search,
    inputs=[
        gr.Textbox(label="Query", placeholder="Search..."),
        # NOTE: these hidden Number defaults (5 / 1) override the
        # granular_search signature defaults (15 / 3) for UI-driven calls.
        gr.Number(value=5, label="Initial K", visible=False),
        gr.Number(value=1, label="Final K", visible=False),
        gr.Number(value=0.65, label="Alpha (Vector Weight)", visible=False) # Expose Alpha
    ],
    outputs=gr.JSON(label="Results"),
    title="Granular Search API (Weighted)",
    flagging_mode="never"
)

if __name__ == "__main__":
    # Bind on all interfaces on port 7860 (standard for container/Space deploys).
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)