JustscrAPIng committed on
Commit
7523ee1
·
verified ·
1 Parent(s): 6f540a7

initial commit

Browse files
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ vector_db/chroma.sqlite3 filter=lfs diff=lfs merge=lfs -text
app.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import chromadb
3
+ from chromadb.utils.embedding_functions import SentenceTransformerEmbeddingFunction
4
+ from sentence_transformers import CrossEncoder
5
+ import torch
6
+ from rank_bm25 import BM25Okapi
7
+ import string
8
+ import os
9
+ import sys
10
+
11
# --- 1. SETUP & MODEL LOADING ---
# Load the embedding function and the cross-encoder reranker once at module
# import time, so every request reuses the same in-memory models.
print("⏳ Loading models...")

# Detect Hardware (GPU vs CPU)
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Running on: {device}")

# Embedding Function (Must match what you used to create the DB)
ef = SentenceTransformerEmbeddingFunction(
    model_name="BAAI/bge-m3",
    device=device
)

# Reranker Model
# NOTE(review): model_kwargs={"dtype": "float16"} assumes a transformers
# release that accepts "dtype" (older versions use "torch_dtype") — confirm
# against the installed transformers version; requirements.txt pins nothing.
reranker = CrossEncoder(
    "BAAI/bge-reranker-v2-m3",
    device=device,
    trust_remote_code=True,
    model_kwargs={"dtype": "float16"} if device == "cuda" else {}
)

print("✅ Models loaded!")
33
+
34
# --- 2. LOAD PERSISTENT VECTOR DB ---
DB_PATH = "./vector_db"  # This must match the folder name you uploaded

if not os.path.exists(DB_PATH):
    print(f"❌ Error: The folder '{DB_PATH}' was not found in the Space.")
    print("Please upload your local 'vector_db' folder to the Files tab.")
    # We don't exit here so you can see the error in logs, but the app will fail later.
else:
    print(f"wd: {os.getcwd()}")  # Print working directory for debugging

# Initialize Persistent Client
# NOTE(review): if DB_PATH was missing above, PersistentClient creates an
# empty store here and the get_collection below raises — confirm intended.
client = chromadb.PersistentClient(path=DB_PATH)

# Get the existing collection
# Note: We use get_collection because we assume it already exists.
try:
    collection = client.get_collection(name='ct_data', embedding_function=ef)
    print(f"✅ Loaded collection 'ct_data' with {collection.count()} documents.")
except Exception as e:
    print(f"❌ Error loading collection: {e}")
    # Fallback for debugging if name is wrong
    print(f"Available collections: {[c.name for c in client.list_collections()]}")
    sys.exit(1)
57
+
58
# --- 3. BUILD IN-MEMORY INDICES (BM25) ---
# We need to fetch all documents from the DB to build the BM25 index
# and the metadata cache. This avoids needing the CSV files.

bm25_index = None   # BM25Okapi instance, populated by build_indices_from_db()
doc_id_map = {}     # int BM25 corpus position -> Chroma string document id
all_metadatas = {}  # document id -> {"document": text, "meta": metadata dict}
65
+
66
def build_indices_from_db():
    """Fetch every record from the Chroma collection and build the in-memory
    BM25 keyword index plus the id -> document/metadata lookup cache.

    Populates the module globals ``bm25_index``, ``doc_id_map`` and
    ``all_metadatas``. Returns early (with a warning) if the collection
    holds no documents.
    """
    global bm25_index, doc_id_map, all_metadatas

    print("⏳ Fetching data from ChromaDB to build BM25 index...")

    # One bulk fetch of IDs, documents and metadatas. For very large
    # collections (>100k docs) this could be paginated; a single call is
    # fine for a typical RAG corpus.
    payload = collection.get()
    ids = payload['ids']
    documents = payload['documents']
    metadatas = payload['metadatas']

    if not documents:
        print("⚠️ Warning: Collection is empty!")
        return

    print(f"Processing {len(documents)} documents for Keyword Search...")

    # Lower-case, strip punctuation, and whitespace-split each document to
    # form the BM25 corpus.
    strip_punct = str.maketrans('', '', string.punctuation)
    corpus = []
    for text in documents:
        corpus.append(text.lower().translate(strip_punct).split())
    bm25_index = BM25Okapi(corpus)

    # BM25 scores are keyed by integer corpus position, so remember which
    # string ID sits at each position, and cache text + metadata for fast
    # lookup at query time.
    for position, record_id in enumerate(ids):
        doc_id_map[position] = record_id
        meta = metadatas[position]
        all_metadatas[record_id] = {
            "document": documents[position],
            "meta": meta if meta else {},
        }

    print("✅ Hybrid Index Ready.")
103
+
104
+ # Run this immediately
105
+ build_indices_from_db()
106
+
107
+ # --- 4. SEARCH LOGIC ---
108
def reciprocal_rank_fusion(vector_results, bm25_results, k=60):
    """Merge two ranked ID lists using Reciprocal Rank Fusion.

    Each list contributes 1 / (k + rank) per document (rank is 0-based);
    documents present in both lists accumulate both contributions. Returns
    every ID seen, sorted by fused score, best first.
    """
    fused = {}
    for ranking in (vector_results, bm25_results):
        for position, identifier in enumerate(ranking):
            fused[identifier] = fused.get(identifier, 0) + 1 / (k + position)
    return sorted(fused, key=fused.get, reverse=True)
115
+
116
def granular_search(query: str, initial_k: int = 15, final_k: int = 3):
    """Hybrid search endpoint: vector + BM25 retrieval, RRF fusion, reranking.

    Args:
        query: Free-text search query.
        initial_k: Candidate count fetched from each retriever before fusion.
        final_k: Number of results returned after reranking.

    Returns:
        {"data": [...], "meta": {...}} on success; {"data": [], ...} when
        nothing matches; {"error": msg} on any unexpected failure.
    """
    try:
        # FIX: Gradio Number inputs arrive as Python floats (precision=None),
        # but these values are used as slice bounds and n_results — a float
        # there raises TypeError. Coerce to int up front.
        initial_k = int(initial_k)
        final_k = int(final_k)

        # A. Vector Search
        # Querying the persistent DB
        vec_res = collection.query(query_texts=[query], n_results=initial_k)
        vector_ids = vec_res['ids'][0] if vec_res['ids'] else []

        # B. BM25 keyword search (skipped if the index failed to build)
        bm25_ids = []
        if bm25_index:
            tokenized = query.lower().translate(str.maketrans('', '', string.punctuation)).split()
            scores = bm25_index.get_scores(tokenized)
            top_indices = scores.argsort()[::-1][:initial_k]
            # Keep only positions with a positive BM25 score.
            bm25_ids = [doc_id_map[i] for i in top_indices if scores[i] > 0]

        # C. Fuse the two ranked candidate lists
        candidates_ids = reciprocal_rank_fusion(vector_ids, bm25_ids)[:initial_k]

        if not candidates_ids:
            return {"data": [], "message": "No results found"}

        # D. Fetch text + metadata from the in-memory cache
        docs = []
        metas = []
        for did in candidates_ids:
            item = all_metadatas.get(did)
            if item:
                docs.append(item['document'])
                metas.append(item['meta'])

        if not docs:
            return {"data": []}

        # E. Rerank query/document pairs with the cross-encoder
        pairs = [[query, doc] for doc in docs]
        scores = reranker.predict(pairs)

        # F. Keep the top final_k by reranker score and shape the response
        results = sorted(zip(scores, docs, metas), key=lambda x: x[0], reverse=True)[:final_k]

        formatted_data = []
        for score, doc, meta in results:
            formatted_data.append({
                "name": meta.get('name', 'Unknown'),
                "description": doc,
                "image_id": meta.get('image id', ''),
                "relevance_score": float(score),
                "building_type": meta.get('building_type', 'unknown')
            })

        return {
            "data": formatted_data,
            "meta": {
                "query": query,
                "count": len(formatted_data)
            }
        }

    except Exception as e:
        # API boundary: surface the failure as JSON instead of a raw 500.
        return {"error": str(e)}
176
+
177
# --- 5. GRADIO UI ---
# Thin JSON-API wrapper: the two Number inputs are hidden, so the UI only
# shows the query box; API clients can still override the k values.
demo = gr.Interface(
    fn=granular_search,
    inputs=[
        gr.Textbox(label="Query", placeholder="Search for Vietnamese architecture..."),
        gr.Number(value=15, label="Initial K", visible=False),
        gr.Number(value=3, label="Final K", visible=False)
    ],
    outputs=gr.JSON(label="Results"),
    title="Granular Search API (Persistent)",
    # NOTE(review): allow_flagging is deprecated in newer Gradio releases
    # (renamed flagging_mode) — confirm against the unpinned gradio version.
    allow_flagging="never"
)

if __name__ == "__main__":
    # Standard HF Spaces binding: all interfaces, port 7860. queue() buffers
    # concurrent requests instead of running them all at once.
    demo.queue().launch(server_name="0.0.0.0", server_port=7860)
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ gradio
2
+ chromadb
3
+ sentence-transformers
4
+ rank-bm25
5
+ torch
6
+ pandas
vector_db/chroma.sqlite3 ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a19a14fda1005c8076e5c60c0dc0e278d4ab4dd1fb5f887cba0963ce8e6a52f6
3
+ size 9961472
vector_db/d52b05ce-3bda-4661-8b17-113da6931a95/data_level0.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:00f5b57a41e0f68a4c1d5b16ce971b6eb7d7c5daee0babf5a25083b5a366fe0c
3
+ size 423600
vector_db/d52b05ce-3bda-4661-8b17-113da6931a95/header.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bf12d4486518c7addf488cb4854526902c78e91951990e1e2f4e055cec814e5d
3
+ size 100
vector_db/d52b05ce-3bda-4661-8b17-113da6931a95/length.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a8979607f7bee3b8d1f84a23664624c5ef19c4b96f0fccb1fd1fb45e1e962e37
3
+ size 400
vector_db/d52b05ce-3bda-4661-8b17-113da6931a95/link_lists.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
3
+ size 0