File size: 7,240 Bytes
aa15689
 
 
 
 
 
 
 
 
e45a23a
aa15689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e45a23a
aa15689
 
e45a23a
aa15689
e45a23a
aa15689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e45a23a
aa15689
 
 
 
 
e45a23a
aa15689
 
 
 
 
 
 
 
 
e45a23a
aa15689
 
 
 
 
e45a23a
 
 
 
 
 
 
 
 
aa15689
e45a23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa15689
 
e45a23a
 
aa15689
e45a23a
 
 
 
aa15689
e45a23a
 
 
 
 
aa15689
e45a23a
aa15689
e45a23a
aa15689
 
e45a23a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
aa15689
 
 
 
 
 
 
 
 
 
 
 
 
 
e45a23a
aa15689
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e45a23a
ac37ae3
e45a23a
 
aa15689
 
e45a23a
aa15689
 
 
 
7523ee1
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
import gradio as gr
import chromadb
from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
from sentence_transformers import CrossEncoder
import torch
from rank_bm25 import BM25Okapi
import string
import os
import sys
import numpy as np # Needed for normalization

# --- 1. SETUP & MODEL LOADING ---
print("⏳ Loading models...")

# Prefer GPU when available; both models below accept a device string.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

# Embedding function handed to Chroma: embeds queries with BAAI/bge-m3 so
# they live in the same vector space as the stored documents.
ef = SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-m3",
    device=device
)

# Cross-encoder reranker: scores (query, document) pairs jointly; used as the
# final-stage ranker in granular_search.
# NOTE(review): fp16 is requested only on CUDA. The string key "dtype" in
# model_kwargs relies on a recent transformers version (older ones expect
# "torch_dtype") — confirm against the pinned dependency versions.
reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    device=device,
    trust_remote_code=True,
    model_kwargs={"dtype": "float16"} if device == "cuda" else {}
)

print("✅ Models loaded!")

# --- 2. LOAD PERSISTENT VECTOR DB ---
DB_PATH = "./vector_db"

# Fail fast if the data folder is absent. Previously the script only printed
# an error and carried on; PersistentClient would then silently CREATE an
# empty DB at DB_PATH, and the run died later in get_collection with a
# misleading "collection not found" error instead of pointing at the real
# problem (missing data folder / wrong working directory).
if not os.path.exists(DB_PATH):
    print(f"❌ Error: The folder '{DB_PATH}' was not found.")
    print(f"wd: {os.getcwd()}")  # show where we looked, to aid debugging
    sys.exit(1)
print(f"wd: {os.getcwd()}")

client = chromadb.PersistentClient(path=DB_PATH)

try:
    collection = client.get_collection(name='ct_data', embedding_function=ef)
    print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
except Exception as e:
    # Broad catch is deliberate at this startup boundary: any failure to open
    # the collection is fatal, so report and exit.
    print(f"❌ Error loading collection: {e}")
    sys.exit(1)

# --- 3. BUILD IN-MEMORY INDICES (BM25) ---
# Populated by build_indices_from_db() below.
bm25_index = None     # BM25Okapi over the whole corpus, or None if DB is empty
doc_id_map = {}       # BM25 corpus position -> Chroma document id
all_metadatas = {}    # document id -> {"document": text, "meta": metadata dict}

def build_indices_from_db():
    """Populate the module-level BM25 index and document caches from Chroma.

    Fills `bm25_index` (BM25Okapi over every stored document), `doc_id_map`
    (BM25 corpus position -> Chroma id) and `all_metadatas` (id -> cached
    text + metadata). No-op when the collection holds no documents.
    """
    global bm25_index, doc_id_map, all_metadatas
    print("⏳ Fetching data to build BM25 index...")
    snapshot = collection.get()
    ids = snapshot['ids']
    documents = snapshot['documents']
    metadatas = snapshot['metadatas']

    if not documents:
        return

    # Same tokenization as query time: lowercase, strip punctuation, split.
    strip_punct = str.maketrans('', '', string.punctuation)
    bm25_index = BM25Okapi(
        [text.lower().translate(strip_punct).split() for text in documents]
    )

    for position, (doc_id, text, meta) in enumerate(zip(ids, documents, metadatas)):
        doc_id_map[position] = doc_id
        all_metadatas[doc_id] = {"document": text, "meta": meta or {}}

    print("✅ Hybrid Index Ready.")

build_indices_from_db()

# --- 4. NEW: WEIGHTED FUSION LOGIC ---
def sigmoid(x):
    """Numerically stable logistic function 1 / (1 + exp(-x)).

    Accepts a scalar or a numpy array. The naive form overflows (RuntimeWarning,
    intermediate inf) in np.exp(-x) for large-magnitude negative inputs; the
    two-branch form below only ever exponentiates a non-positive value, so
    exp stays in (0, 1] and never overflows.
    """
    x = np.asarray(x, dtype=float)
    z = np.exp(-np.abs(x))  # always in (0, 1]
    # For x >= 0: 1/(1+e^-x). For x < 0: e^x/(1+e^x) — algebraically identical.
    return np.where(x >= 0, 1.0 / (1.0 + z), z / (1.0 + z))

def weighted_score_fusion(vector_results, vector_scores, bm25_results, bm25_scores, alpha=0.65):
    """
    Blend dense (vector) and sparse (BM25) retrieval into one ranking.

    Each document's fused score is:
        alpha * vector_similarity + (1 - alpha) * min-max-normalized BM25
    A document missing from one retriever contributes 0 for that component.
    Vector similarities are assumed to already lie roughly in [0, 1]
    (cosine similarity); BM25 scores are unbounded, so they are rescaled
    per-query with min-max normalization (all 1.0 when min == max).

    Returns the document ids sorted best-first by fused score.
    """
    # Rescale the unbounded BM25 scores into [0, 1].
    if not bm25_scores:
        scaled_bm25 = []
    else:
        lo = min(bm25_scores)
        hi = max(bm25_scores)
        span = hi - lo
        scaled_bm25 = [1.0 if span == 0 else (s - lo) / span for s in bm25_scores]

    # Per-retriever lookup tables keyed by document id.
    dense = dict(zip(vector_results, vector_scores))
    sparse = dict(zip(bm25_results, scaled_bm25))

    # Fuse over the union of everything either retriever found.
    fused = {
        doc_id: alpha * dense.get(doc_id, 0.0) + (1.0 - alpha) * sparse.get(doc_id, 0.0)
        for doc_id in set(vector_results) | set(bm25_results)
    }
    return sorted(fused, key=fused.get, reverse=True)


def granular_search(query: str, initial_k: int = 15, final_k: int = 3, alpha: float = 0.65) -> dict:
    """Hybrid retrieve-then-rerank search over the 'ct_data' collection.

    Pipeline: (A) dense vector search via Chroma, (B) sparse BM25 search over
    the in-memory index, (C) weighted score fusion controlled by `alpha`,
    (D) text/metadata lookup from the module-level cache, (E) cross-encoder
    reranking, (F) formatting of the top `final_k` hits.

    Args:
        query: Free-text search query.
        initial_k: Candidate pool size per retriever and after fusion.
        final_k: Number of results returned after reranking.
        alpha: Weight of the vector component in fusion; (1 - alpha) weights BM25.

    Returns:
        {"data": [...], "meta": {...}} on success, {"data": [], ...} when
        nothing matches, or {"error": str} if any step raises.
    """
    try:
        # A. Vector Search (Get Scores too)
        # include=['documents', 'distances'] tells Chroma to return scores
        vec_res = collection.query(query_texts=[query], n_results=initial_k, include=['documents', 'distances'])
        
        vector_ids = vec_res['ids'][0] if vec_res['ids'] else []
        # Chroma returns Distances (Lower is better for L2/Cosine Distance)
        # But BGE-M3 is usually Cosine Similarity.
        # If score is Distance: Sim = 1 - Distance
        # NOTE(review): 1 - d assumes the collection uses cosine *distance*;
        # with L2 this mapping is wrong — confirm the collection's metric.
        vector_dists = vec_res['distances'][0] if vec_res['distances'] else []
        vector_scores = [1 - d for d in vector_dists] # Convert distance to similarity

        # B. BM25 Search (Get Scores too)
        bm25_ids = []
        bm25_scores = []
        if bm25_index:
            # Same tokenization used when the index was built.
            tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
            # Get all scores
            all_scores = bm25_index.get_scores(tokenized)
            # Sort top K
            top_indices = all_scores.argsort()[::-1][:initial_k]
            
            # Keep only genuinely matching docs (BM25 score 0 means no term overlap).
            for i in top_indices:
                score = all_scores[i]
                if score > 0:
                    bm25_ids.append(doc_id_map[i])
                    bm25_scores.append(score)

        # C. Weighted Fusion (USING ALPHA)
        candidates_ids = weighted_score_fusion(
            vector_ids, vector_scores, 
            bm25_ids, bm25_scores, 
            alpha=alpha
        )[:initial_k] # Keep top K after fusion
        
        if not candidates_ids:
            return {"data": [], "message": "No results found"}

        # D. Fetch Text (from Cache)
        # Avoids a round-trip to Chroma: texts/metadata were cached at startup.
        docs = []
        metas = []
        for did in candidates_ids:
            item = all_metadatas.get(did)
            if item:
                docs.append(item['document'])
                metas.append(item['meta'])

        # E. Rerank
        if not docs: return {"data": []}

        # Cross-encoder scores each (query, doc) pair jointly — slower but far
        # more accurate than the retrieval scores, hence applied last.
        pairs = [[query, doc] for doc in docs]
        scores = reranker.predict(pairs)
        
        # F. Format
        results = sorted(zip(scores, docs, metas), key=lambda x: x[0], reverse=True)[:final_k]
        
        formatted_data = []
        for score, doc, meta in results:
            formatted_data.append({
                "name": meta.get('name', 'Unknown'),
                "description": doc,
                "image_id": meta.get('image id', ''),
                "relevance_score": float(score),
                "building_type": meta.get('building_type', 'unknown')
            })

        return {
            "data": formatted_data,
            "meta": {
                "query": query,
                "count": len(formatted_data)
            }
        }

    except Exception as e:
        # API-style boundary: surface failures as JSON instead of crashing the UI.
        return {"error": str(e)}

# --- 5. GRADIO UI ---
# The K/alpha inputs are hidden (visible=False): the web UI shows only the
# query box, while programmatic/API callers can still override them.
demo = gr.Interface(
    fn=granular_search,
    inputs=[
        gr.Textbox(label="Query", placeholder="Search..."),
        gr.Number(value=5, label="Initial K", visible=False),
        gr.Number(value=1, label="Final K", visible=False),
        gr.Number(value=0.65, label="Alpha (Vector Weight)", visible=False) # Expose Alpha
    ],
    outputs=gr.JSON(label="Results"),
    title="Granular Search API (Weighted)",
    flagging_mode="never"
)

if __name__ == "__main__":
    # queue() serializes concurrent requests (the models are not thread-safe
    # per-request); 0.0.0.0 binds all interfaces for container deployment.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)