JustscrAPIng committed on
Commit
e45a23a
·
verified Β·
1 Parent(s): ac37ae3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -58
app.py CHANGED
@@ -7,21 +7,19 @@ from rank_bm25 import BM25Okapi
7
  import string
8
  import os
9
  import sys
 
10
 
11
  # --- 1. SETUP & MODEL LOADING ---
12
  print("⏳ Loading models...")
13
 
14
- # Detect Hardware (GPU vs CPU)
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(f"Running on: {device}")
17
 
18
- # Embedding Function (Must match what you used to create the DB)
19
  ef = SentenceTransformerEmbeddingFunction(
20
  model_name="BAAI/bge-m3",
21
  device=device
22
  )
23
 
24
- # Reranker Model
25
  reranker = CrossEncoder(
26
  "BAAI/bge-reranker-v2-m3",
27
  device=device,
@@ -32,104 +30,131 @@ reranker = CrossEncoder(
32
  print("βœ… Models loaded!")
33
 
34
  # --- 2. LOAD PERSISTENT VECTOR DB ---
35
- DB_PATH = "./vector_db" # This must match the folder name you uploaded
36
 
37
  if not os.path.exists(DB_PATH):
38
- print(f"❌ Error: The folder '{DB_PATH}' was not found in the Space.")
39
- print("Please upload your local 'vector_db' folder to the Files tab.")
40
- # We don't exit here so you can see the error in logs, but the app will fail later.
41
  else:
42
- print(f"wd: {os.getcwd()}") # Print working directory for debugging
43
 
44
- # Initialize Persistent Client
45
  client = chromadb.PersistentClient(path=DB_PATH)
46
 
47
- # Get the existing collection
48
- # Note: We use get_collection because we assume it already exists.
49
  try:
50
  collection = client.get_collection(name='ct_data', embedding_function=ef)
51
  print(f"βœ… Loaded collection 'ct_data' with {collection.count()} documents.")
52
  except Exception as e:
53
  print(f"❌ Error loading collection: {e}")
54
- # Fallback for debugging if name is wrong
55
- print(f"Available collections: {[c.name for c in client.list_collections()]}")
56
  sys.exit(1)
57
 
58
  # --- 3. BUILD IN-MEMORY INDICES (BM25) ---
59
- # We need to fetch all documents from the DB to build the BM25 index
60
- # and the metadata cache. This avoids needing the CSV files.
61
-
62
  bm25_index = None
63
  doc_id_map = {}
64
  all_metadatas = {}
65
 
66
  def build_indices_from_db():
67
  global bm25_index, doc_id_map, all_metadatas
68
-
69
- print("⏳ Fetching data from ChromaDB to build BM25 index...")
70
-
71
- # Fetch all data (IDs, Documents, Metadatas)
72
- # If you have >100k docs, you might want to paginate this, but for typical RAG it's fine.
73
  data = collection.get()
74
-
75
  ids = data['ids']
76
  documents = data['documents']
77
  metadatas = data['metadatas']
78
 
79
- if not documents:
80
- print("⚠️ Warning: Collection is empty!")
81
- return
82
 
83
- # Build BM25 Corpus
84
- print(f"Processing {len(documents)} documents for Keyword Search...")
85
  tokenized_corpus = [
86
  doc.lower().translate(str.maketrans('', '', string.punctuation)).split()
87
  for doc in documents
88
  ]
89
  bm25_index = BM25Okapi(tokenized_corpus)
90
 
91
- # Reconstruct Cache
92
  for idx, (doc_id, doc_text, meta) in enumerate(zip(ids, documents, metadatas)):
93
- # Map integer index (from BM25) back to string ID
94
  doc_id_map[idx] = doc_id
95
-
96
- # Store in fast lookup dict
97
- all_metadatas[doc_id] = {
98
- "document": doc_text,
99
- "meta": meta if meta else {}
100
- }
101
 
102
  print("βœ… Hybrid Index Ready.")
103
 
104
- # Run this immediately
105
  build_indices_from_db()
106
 
107
- # --- 4. SEARCH LOGIC ---
108
# --- 4. SEARCH LOGIC ---
def reciprocal_rank_fusion(vector_results, bm25_results, k=60):
    """Merge two ranked lists of doc ids with Reciprocal Rank Fusion.

    Each id earns 1 / (k + rank) from every list it appears in; ids found by
    both retrievers therefore accumulate a higher score. `k` damps the
    advantage of the very top ranks (60 is the value from the RRF paper).

    Returns all ids, best fused score first.
    """
    fused_scores = {}
    for ranked_ids in (vector_results, bm25_results):
        for rank, doc_id in enumerate(ranked_ids):
            fused_scores[doc_id] = fused_scores.get(doc_id, 0) + (1 / (k + rank))
    return sorted(fused_scores.keys(), key=lambda d: fused_scores[d], reverse=True)
115
 
116
- def granular_search(query: str, initial_k: int = 15, final_k: int = 3):
 
117
  try:
118
- # A. Vector Search
119
- # Querying the persistent DB
120
- vec_res = collection.query(query_texts=[query], n_results=initial_k)
 
121
  vector_ids = vec_res['ids'][0] if vec_res['ids'] else []
 
 
 
 
 
122
 
123
- # B. BM25 Search
124
  bm25_ids = []
 
125
  if bm25_index:
126
  tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
127
- scores = bm25_index.get_scores(tokenized)
128
- top_indices = scores.argsort()[::-1][:initial_k]
129
- bm25_ids = [doc_id_map[i] for i in top_indices if scores[i] > 0]
130
-
131
- # C. Fusion
132
- candidates_ids = reciprocal_rank_fusion(vector_ids, bm25_ids)[:initial_k]
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  if not candidates_ids:
135
  return {"data": [], "message": "No results found"}
@@ -144,8 +169,7 @@ def granular_search(query: str, initial_k: int = 15, final_k: int = 3):
144
  metas.append(item['meta'])
145
 
146
  # E. Rerank
147
- if not docs:
148
- return {"data": []}
149
 
150
  pairs = [[query, doc] for doc in docs]
151
  scores = reranker.predict(pairs)
@@ -178,12 +202,13 @@ def granular_search(query: str, initial_k: int = 15, final_k: int = 3):
178
  demo = gr.Interface(
179
  fn=granular_search,
180
  inputs=[
181
- gr.Textbox(label="Query", placeholder="Search for Vietnamese architecture..."),
182
  gr.Number(value=5, label="Initial K", visible=False),
183
- gr.Number(value=1, label="Final K", visible=False)
 
184
  ],
185
  outputs=gr.JSON(label="Results"),
186
- title="Granular Search API (Persistent)",
187
  flagging_mode="never"
188
  )
189
 
 
7
  import string
8
  import os
9
  import sys
10
+ import numpy as np # Needed for normalization
11
 
12
  # --- 1. SETUP & MODEL LOADING ---
13
  print("⏳ Loading models...")
14
 
 
15
  device = "cuda" if torch.cuda.is_available() else "cpu"
16
  print(f"Running on: {device}")
17
 
 
18
  ef = SentenceTransformerEmbeddingFunction(
19
  model_name="BAAI/bge-m3",
20
  device=device
21
  )
22
 
 
23
  reranker = CrossEncoder(
24
  "BAAI/bge-reranker-v2-m3",
25
  device=device,
 
30
  print("βœ… Models loaded!")
31
 
32
# --- 2. LOAD PERSISTENT VECTOR DB ---
DB_PATH = "./vector_db"  # must match the folder uploaded to the Space

if not os.path.exists(DB_PATH):
    # NOTE(review): execution deliberately continues so this message shows up
    # in the logs; `collection` stays undefined and queries will fail later.
    print(f"❌ Error: The folder '{DB_PATH}' was not found.")
else:
    print(f"wd: {os.getcwd()}")

    client = chromadb.PersistentClient(path=DB_PATH)

    try:
        collection = client.get_collection(name='ct_data', embedding_function=ef)
        print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
    except Exception as e:
        # A missing/renamed collection is fatal for the whole app.
        print(f"❌ Error loading collection: {e}")
        sys.exit(1)
48
 
49
# --- 3. BUILD IN-MEMORY INDICES (BM25) ---
bm25_index = None      # BM25Okapi over the whole collection (None until built)
doc_id_map = {}        # BM25 integer position -> Chroma string id
all_metadatas = {}     # Chroma id -> {"document": text, "meta": dict}

def build_indices_from_db():
    """Fetch every record from the Chroma collection and build the in-memory
    BM25 keyword index plus an id -> document/metadata lookup cache."""
    global bm25_index, doc_id_map, all_metadatas
    print("⏳ Fetching data to build BM25 index...")

    payload = collection.get()
    doc_ids = payload['ids']
    texts = payload['documents']
    metas = payload['metadatas']
    if not texts:
        return  # empty collection: leave bm25_index as None

    # Lowercase, strip punctuation, whitespace-tokenise — this must mirror
    # the query-side tokenisation used at search time.
    strip_punct = str.maketrans('', '', string.punctuation)
    corpus = [text.lower().translate(strip_punct).split() for text in texts]
    bm25_index = BM25Okapi(corpus)

    for position, (doc_id, text, meta) in enumerate(zip(doc_ids, texts, metas)):
        doc_id_map[position] = doc_id
        all_metadatas[doc_id] = {"document": text, "meta": meta if meta else {}}

    print("✅ Hybrid Index Ready.")

build_indices_from_db()
77
 
78
# --- 4. NEW: WEIGHTED FUSION LOGIC ---
def sigmoid(x):
    """Logistic squashing function: maps any real (scalar or array) into (0, 1)."""
    return 1.0 / (np.exp(-x) + 1.0)
81
+
82
def weighted_score_fusion(vector_results, vector_scores, bm25_results, bm25_scores, alpha=0.65):
    """Blend dense and sparse retrieval scores into a single ranking.

    final = alpha * vector_score + (1 - alpha) * normalised_bm25_score

    Vector scores are used as-is (assumed to already lie roughly in [0, 1] —
    cosine similarity for bge-m3; TODO confirm against the Chroma distance
    metric). BM25 scores are unbounded, so they are min-max scaled onto [0, 1]
    first. Returns the union of both id lists, best fused score first.
    """
    # Min-max scale BM25; a degenerate (constant) range maps everything to 1.0.
    if bm25_scores:
        lo, hi = min(bm25_scores), max(bm25_scores)
        if hi == lo:
            scaled_bm25 = [1.0] * len(bm25_scores)
        else:
            scaled_bm25 = [(s - lo) / (hi - lo) for s in bm25_scores]
    else:
        scaled_bm25 = []

    dense_by_id = dict(zip(vector_results, vector_scores))
    sparse_by_id = dict(zip(bm25_results, scaled_bm25))

    # An id missing from one channel contributes 0.0 there, so documents found
    # by only one retriever still participate in the ranking.
    fused_scores = {
        doc_id: (alpha * dense_by_id.get(doc_id, 0.0))
                + ((1.0 - alpha) * sparse_by_id.get(doc_id, 0.0))
        for doc_id in set(vector_results) | set(bm25_results)
    }
    return sorted(fused_scores.keys(), key=lambda d: fused_scores[d], reverse=True)
121
 
122
+
123
+ def granular_search(query: str, initial_k: int = 15, final_k: int = 3, alpha: float = 0.65):
124
  try:
125
+ # A. Vector Search (Get Scores too)
126
+ # include=['documents', 'distances'] tells Chroma to return scores
127
+ vec_res = collection.query(query_texts=[query], n_results=initial_k, include=['documents', 'distances'])
128
+
129
  vector_ids = vec_res['ids'][0] if vec_res['ids'] else []
130
+ # Chroma returns Distances (Lower is better for L2/Cosine Distance)
131
+ # But BGE-M3 is usually Cosine Similarity.
132
+ # If score is Distance: Sim = 1 - Distance
133
+ vector_dists = vec_res['distances'][0] if vec_res['distances'] else []
134
+ vector_scores = [1 - d for d in vector_dists] # Convert distance to similarity
135
 
136
+ # B. BM25 Search (Get Scores too)
137
  bm25_ids = []
138
+ bm25_scores = []
139
  if bm25_index:
140
  tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
141
+ # Get all scores
142
+ all_scores = bm25_index.get_scores(tokenized)
143
+ # Sort top K
144
+ top_indices = all_scores.argsort()[::-1][:initial_k]
145
+
146
+ for i in top_indices:
147
+ score = all_scores[i]
148
+ if score > 0:
149
+ bm25_ids.append(doc_id_map[i])
150
+ bm25_scores.append(score)
151
+
152
+ # C. Weighted Fusion (USING ALPHA)
153
+ candidates_ids = weighted_score_fusion(
154
+ vector_ids, vector_scores,
155
+ bm25_ids, bm25_scores,
156
+ alpha=alpha
157
+ )[:initial_k] # Keep top K after fusion
158
 
159
  if not candidates_ids:
160
  return {"data": [], "message": "No results found"}
 
169
  metas.append(item['meta'])
170
 
171
  # E. Rerank
172
+ if not docs: return {"data": []}
 
173
 
174
  pairs = [[query, doc] for doc in docs]
175
  scores = reranker.predict(pairs)
 
202
demo = gr.Interface(
    fn=granular_search,
    inputs=[
        gr.Textbox(label="Query", placeholder="Search..."),
        # Hidden knobs: retrieval breadth, result count, and the fusion weight
        # of the vector channel (alpha) in weighted_score_fusion.
        gr.Number(value=5, label="Initial K", visible=False),
        gr.Number(value=1, label="Final K", visible=False),
        gr.Number(value=0.65, label="Alpha (Vector Weight)", visible=False),
    ],
    outputs=gr.JSON(label="Results"),
    title="Granular Search API (Weighted)",
    flagging_mode="never",
)
  )
214