JustscrAPIng committed on
Commit
aa15689
·
verified ·
1 Parent(s): 7523ee1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +190 -190
app.py CHANGED
@@ -1,191 +1,191 @@
1
- import gradio as gr
2
- import chromadb
3
- from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
4
- from sentence_transformers import CrossEncoder
5
- import torch
6
- from rank_bm25 import BM25Okapi
7
- import string
8
- import os
9
- import sys
10
-
11
- # --- 1. SETUP & MODEL LOADING ---
12
- print("⏳ Loading models...")
13
-
14
- # Detect Hardware (GPU vs CPU)
15
- device = "cuda" if torch.cuda.is_available() else "cpu"
16
- print(f"Running on: {device}")
17
-
18
- # Embedding Function (Must match what you used to create the DB)
19
- ef = SentenceTransformerEmbeddingFunction(
20
- model_name="BAAI/bge-m3",
21
- device=device
22
- )
23
-
24
- # Reranker Model
25
- reranker = CrossEncoder(
26
- "BAAI/bge-reranker-v2-m3",
27
- device=device,
28
- trust_remote_code=True,
29
- model_kwargs={"dtype": "float16"} if device == "cuda" else {}
30
- )
31
-
32
- print("✅ Models loaded!")
33
-
34
- # --- 2. LOAD PERSISTENT VECTOR DB ---
35
- DB_PATH = "./vector_db" # This must match the folder name you uploaded
36
-
37
- if not os.path.exists(DB_PATH):
38
- print(f"❌ Error: The folder '{DB_PATH}' was not found in the Space.")
39
- print("Please upload your local 'vector_db' folder to the Files tab.")
40
- # We don't exit here so you can see the error in logs, but the app will fail later.
41
- else:
42
- print(f"wd: {os.getcwd()}") # Print working directory for debugging
43
-
44
- # Initialize Persistent Client
45
- client = chromadb.PersistentClient(path=DB_PATH)
46
-
47
- # Get the existing collection
48
- # Note: We use get_collection because we assume it already exists.
49
- try:
50
- collection = client.get_collection(name='ct_data', embedding_function=ef)
51
- print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
52
- except Exception as e:
53
- print(f"❌ Error loading collection: {e}")
54
- # Fallback for debugging if name is wrong
55
- print(f"Available collections: {[c.name for c in client.list_collections()]}")
56
- sys.exit(1)
57
-
58
- # --- 3. BUILD IN-MEMORY INDICES (BM25) ---
59
- # We need to fetch all documents from the DB to build the BM25 index
60
- # and the metadata cache. This avoids needing the CSV files.
61
-
62
- bm25_index = None
63
- doc_id_map = {}
64
- all_metadatas = {}
65
-
66
- def build_indices_from_db():
67
- global bm25_index, doc_id_map, all_metadatas
68
-
69
- print("⏳ Fetching data from ChromaDB to build BM25 index...")
70
-
71
- # Fetch all data (IDs, Documents, Metadatas)
72
- # If you have >100k docs, you might want to paginate this, but for typical RAG it's fine.
73
- data = collection.get()
74
-
75
- ids = data['ids']
76
- documents = data['documents']
77
- metadatas = data['metadatas']
78
-
79
- if not documents:
80
- print("⚠️ Warning: Collection is empty!")
81
- return
82
-
83
- # Build BM25 Corpus
84
- print(f"Processing {len(documents)} documents for Keyword Search...")
85
- tokenized_corpus = [
86
- doc.lower().translate(str.maketrans('', '', string.punctuation)).split()
87
- for doc in documents
88
- ]
89
- bm25_index = BM25Okapi(tokenized_corpus)
90
-
91
- # Reconstruct Cache
92
- for idx, (doc_id, doc_text, meta) in enumerate(zip(ids, documents, metadatas)):
93
- # Map integer index (from BM25) back to string ID
94
- doc_id_map[idx] = doc_id
95
-
96
- # Store in fast lookup dict
97
- all_metadatas[doc_id] = {
98
- "document": doc_text,
99
- "meta": meta if meta else {}
100
- }
101
-
102
- print("✅ Hybrid Index Ready.")
103
-
104
- # Run this immediately
105
- build_indices_from_db()
106
-
107
- # --- 4. SEARCH LOGIC ---
108
- def reciprocal_rank_fusion(vector_results, bm25_results, k=60):
109
- fused_scores = {}
110
- for rank, doc_id in enumerate(vector_results):
111
- fused_scores[doc_id] = fused_scores.get(doc_id, 0) + (1 / (k + rank))
112
- for rank, doc_id in enumerate(bm25_results):
113
- fused_scores[doc_id] = fused_scores.get(doc_id, 0) + (1 / (k + rank))
114
- return sorted(fused_scores.keys(), key=lambda x: fused_scores[x], reverse=True)
115
-
116
- def granular_search(query: str, initial_k: int = 15, final_k: int = 3):
117
- try:
118
- # A. Vector Search
119
- # Querying the persistent DB
120
- vec_res = collection.query(query_texts=[query], n_results=initial_k)
121
- vector_ids = vec_res['ids'][0] if vec_res['ids'] else []
122
-
123
- # B. BM25 Search
124
- bm25_ids = []
125
- if bm25_index:
126
- tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
127
- scores = bm25_index.get_scores(tokenized)
128
- top_indices = scores.argsort()[::-1][:initial_k]
129
- bm25_ids = [doc_id_map[i] for i in top_indices if scores[i] > 0]
130
-
131
- # C. Fusion
132
- candidates_ids = reciprocal_rank_fusion(vector_ids, bm25_ids)[:initial_k]
133
-
134
- if not candidates_ids:
135
- return {"data": [], "message": "No results found"}
136
-
137
- # D. Fetch Text (from Cache)
138
- docs = []
139
- metas = []
140
- for did in candidates_ids:
141
- item = all_metadatas.get(did)
142
- if item:
143
- docs.append(item['document'])
144
- metas.append(item['meta'])
145
-
146
- # E. Rerank
147
- if not docs:
148
- return {"data": []}
149
-
150
- pairs = [[query, doc] for doc in docs]
151
- scores = reranker.predict(pairs)
152
-
153
- # F. Format
154
- results = sorted(zip(scores, docs, metas), key=lambda x: x[0], reverse=True)[:final_k]
155
-
156
- formatted_data = []
157
- for score, doc, meta in results:
158
- formatted_data.append({
159
- "name": meta.get('name', 'Unknown'),
160
- "description": doc,
161
- "image_id": meta.get('image id', ''),
162
- "relevance_score": float(score),
163
- "building_type": meta.get('building_type', 'unknown')
164
- })
165
-
166
- return {
167
- "data": formatted_data,
168
- "meta": {
169
- "query": query,
170
- "count": len(formatted_data)
171
- }
172
- }
173
-
174
- except Exception as e:
175
- return {"error": str(e)}
176
-
177
- # --- 5. GRADIO UI ---
178
- demo = gr.Interface(
179
- fn=granular_search,
180
- inputs=[
181
- gr.Textbox(label="Query", placeholder="Search for Vietnamese architecture..."),
182
- gr.Number(value=15, label="Initial K", visible=False),
183
- gr.Number(value=3, label="Final K", visible=False)
184
- ],
185
- outputs=gr.JSON(label="Results"),
186
- title="Granular Search API (Persistent)",
187
- allow_flagging="never"
188
- )
189
-
190
- if __name__ == "__main__":
191
  demo.queue().launch(server_name="0.0.0.0", server_port=7860)
 
1
+ import gradio as gr
2
+ import chromadb
3
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
4
+ from sentence_transformers import CrossEncoder
5
+ import torch
6
+ from rank_bm25 import BM25Okapi
7
+ import string
8
+ import os
9
+ import sys
10
# --- 1. SETUP & MODEL LOADING ---
# Import-time model loading: both the embedder and reranker are downloaded
# (or read from cache) once when the Space boots.
print("⏳ Loading models...")

# Detect Hardware (GPU vs CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

# Embedding Function (Must match what you used to create the DB,
# otherwise query vectors won't live in the same space as stored ones.)
ef = SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-m3",
    device=device
)

# Reranker Model (cross-encoder scores (query, doc) pairs directly).
reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    device=device,
    trust_remote_code=True,
    # Half precision only on GPU. NOTE(review): "dtype" is forwarded to
    # transformers' from_pretrained — confirm the installed transformers
    # version accepts the string form ("float16") under this key.
    model_kwargs={"dtype": "float16"} if device == "cuda" else {}
)

print("✅ Models loaded!")
33
# --- 2. LOAD PERSISTENT VECTOR DB ---
DB_PATH = "./vector_db"  # This must match the folder name you uploaded

if not os.path.exists(DB_PATH):
    print(f"❌ Error: The folder '{DB_PATH}' was not found in the Space.")
    print("Please upload your local 'vector_db' folder to the Files tab.")
    # We don't exit here so you can see the error in logs, but the app will fail later.
    # (PersistentClient below will then create an empty DB and get_collection
    # will fail, which triggers the sys.exit in the except branch.)
else:
    print(f"wd: {os.getcwd()}")  # Print working directory for debugging

# Initialize Persistent Client (on-disk Chroma store, no server needed).
client = chromadb.PersistentClient(path=DB_PATH)

# Get the existing collection.
# Note: we use get_collection (not get_or_create) because the data is
# assumed to have been built offline and uploaded with the Space.
try:
    collection = client.get_collection(name='ct_data', embedding_function=ef)
    print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
except Exception as e:
    print(f"❌ Error loading collection: {e}")
    # Fallback for debugging if name is wrong
    print(f"Available collections: {[c.name for c in client.list_collections()]}")
    sys.exit(1)
57
# --- 3. BUILD IN-MEMORY INDICES (BM25) ---
# We need to fetch all documents from the DB to build the BM25 index
# and the metadata cache. This avoids needing the CSV files.

# Lexical (keyword) index over the whole corpus; populated by
# build_indices_from_db(), stays None if the collection is empty.
bm25_index = None
# Maps BM25's integer corpus position -> Chroma string document id.
doc_id_map = {}
# Fast lookup cache: document id -> {"document": text, "meta": metadata dict}.
all_metadatas = {}
65
def build_indices_from_db():
    """Populate the BM25 index and metadata cache from the Chroma collection.

    Reads every stored document out of ``collection`` and fills the
    module-level ``bm25_index``, ``doc_id_map`` and ``all_metadatas``
    structures that the hybrid search path relies on. No-op (with a
    warning) when the collection is empty.
    """
    global bm25_index, doc_id_map, all_metadatas

    print("⏳ Fetching data from ChromaDB to build BM25 index...")

    # One bulk fetch of ids, documents and metadatas. Fine for typical
    # RAG-sized collections; paginate if you ever exceed ~100k docs.
    data = collection.get()
    ids = data['ids']
    documents = data['documents']
    metadatas = data['metadatas']

    if not documents:
        print("⚠️ Warning: Collection is empty!")
        return

    print(f"Processing {len(documents)} documents for Keyword Search...")
    # Lowercase + strip punctuation + whitespace-split, mirroring the
    # query-side tokenization in granular_search.
    strip_punct = str.maketrans('', '', string.punctuation)
    bm25_index = BM25Okapi(
        [text.lower().translate(strip_punct).split() for text in documents]
    )

    # Record the position->id mapping (BM25 scores come back by corpus
    # position) and cache each document's text and metadata for lookups.
    for position, (doc_id, text, meta) in enumerate(zip(ids, documents, metadatas)):
        doc_id_map[position] = doc_id
        all_metadatas[doc_id] = {"document": text, "meta": meta or {}}

    print("✅ Hybrid Index Ready.")

# Build the indices at import time so the app is ready before serving.
build_indices_from_db()
106
# --- 4. SEARCH LOGIC ---
def reciprocal_rank_fusion(vector_results, bm25_results, k=60):
    """Merge two ranked id lists with Reciprocal Rank Fusion.

    Every id collects a score of 1 / (k + rank) from each list it appears
    in (rank is 0-based). Ids are returned sorted by combined score,
    best first; ``k`` damps the influence of lower-ranked entries.
    """
    combined = {}
    for ranking in (vector_results, bm25_results):
        for position, doc_id in enumerate(ranking):
            combined[doc_id] = combined.get(doc_id, 0) + 1 / (k + position)
    return sorted(combined, key=combined.get, reverse=True)
115
def granular_search(query: str, initial_k: int = 15, final_k: int = 3):
    """Hybrid (vector + BM25) search with cross-encoder reranking.

    Args:
        query: Free-text search query.
        initial_k: Candidate pool size taken from each retriever before fusion.
        final_k: Number of reranked results returned.

    Returns:
        dict with "data" (list of result dicts) and "meta" on success,
        or {"error": str} on failure.
    """
    try:
        # BUG FIX: the hidden gr.Number inputs deliver floats (default
        # precision=None), and both Chroma's n_results and the slices below
        # require real ints — without this cast every UI call raised
        # TypeError and was swallowed into {"error": ...}.
        initial_k = int(initial_k)
        final_k = int(final_k)

        # A. Vector search against the persistent collection.
        vec_res = collection.query(query_texts=[query], n_results=initial_k)
        vector_ids = vec_res['ids'][0] if vec_res['ids'] else []

        # B. BM25 keyword search (skipped if the index never got built).
        bm25_ids = []
        if bm25_index:
            # Tokenize the query the same way the corpus was tokenized.
            tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
            scores = bm25_index.get_scores(tokenized)
            top_indices = scores.argsort()[::-1][:initial_k]
            # Keep only docs that actually matched (score > 0).
            bm25_ids = [doc_id_map[i] for i in top_indices if scores[i] > 0]

        # C. Fuse the two rankings and cap the candidate pool.
        candidates_ids = reciprocal_rank_fusion(vector_ids, bm25_ids)[:initial_k]

        if not candidates_ids:
            return {"data": [], "message": "No results found"}

        # D. Resolve candidate ids to text/metadata from the in-memory cache.
        docs = []
        metas = []
        for did in candidates_ids:
            item = all_metadatas.get(did)
            if item:
                docs.append(item['document'])
                metas.append(item['meta'])

        if not docs:
            return {"data": []}

        # E. Rerank candidates with the cross-encoder on (query, doc) pairs.
        pairs = [[query, doc] for doc in docs]
        scores = reranker.predict(pairs)

        # F. Keep the final_k best and shape the JSON payload.
        results = sorted(zip(scores, docs, metas), key=lambda x: x[0], reverse=True)[:final_k]

        formatted_data = []
        for score, doc, meta in results:
            formatted_data.append({
                "name": meta.get('name', 'Unknown'),
                "description": doc,
                "image_id": meta.get('image id', ''),
                "relevance_score": float(score),  # numpy float -> JSON-safe
                "building_type": meta.get('building_type', 'unknown')
            })

        return {
            "data": formatted_data,
            "meta": {
                "query": query,
                "count": len(formatted_data)
            }
        }

    except Exception as e:
        # API boundary: surface the error as JSON instead of crashing the UI.
        return {"error": str(e)}
176
# --- 5. GRADIO UI ---
# Thin JSON-API wrapper around granular_search: the two K inputs are hidden,
# so interactive users only type a query, while API callers can still
# override them.
demo = gr.Interface(
    fn=granular_search,
    inputs=[
        gr.Textbox(label="Query", placeholder="Search for Vietnamese architecture..."),
        gr.Number(value=15, label="Initial K", visible=False),
        gr.Number(value=3, label="Final K", visible=False)
    ],
    outputs=gr.JSON(label="Results"),
    title="Granular Search API (Persistent)",
    flagging_mode="never"  # Gradio 5.x name for the old allow_flagging kwarg
)

if __name__ == "__main__":
    # queue() enables request queuing for concurrent users;
    # 0.0.0.0:7860 is the standard Hugging Face Spaces binding.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)