Nyha15 committed on
Commit
66b667c
·
1 Parent(s): fe05c77

Refactored

Browse files
Files changed (1) hide show
  1. app.py +64 -10
app.py CHANGED
@@ -169,24 +169,78 @@ class KnowledgeBase:
169
  with self.lock:
170
  if country in self.vstores:
171
  return
172
- texts = []
173
- if self.domain == "banking": texts = self.collector.get_banking_data(country)
174
- elif self.domain == "credit": texts = self.collector.get_credit_data(country)
 
 
175
  elif self.domain == "tax":
176
  ti = self.collector.tax_db.get_tax_information(country)
177
- texts = ti["regulations"] + ti["treaty"]
178
- # else: other domains...
179
- splits = self.splitter.split_text("\n\n".join(texts))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
180
  store = Chroma.from_texts(splits, self.embeddings, collection_name=f"{self.domain}_{country}")
 
181
  with self.lock:
182
  self.vstores[country] = store
183
- self.retrievers[country] = store.as_retriever(search_kwargs={"k":3})
 
184
 
185
  def retrieve(self, query: str, country: str) -> List[str]:
186
- log_workflow(f"Retrieving {self.domain} for {country}")
187
  self._init_country(country)
188
- docs = self.retrievers[country].get_relevant_documents(query)
189
- return [d.page_content for d in docs]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
190
 
191
  # =======================================
192
  # Specialist Agents
 
169
  with self.lock:
170
  if country in self.vstores:
171
  return
172
+ # Retrieve raw texts
173
+ if self.domain == "banking":
174
+ texts = self.collector.get_banking_data(country)
175
+ elif self.domain == "credit":
176
+ texts = self.collector.get_credit_data(country)
177
  elif self.domain == "tax":
178
  ti = self.collector.tax_db.get_tax_information(country)
179
+ texts = ti.get("regulations", []) + ti.get("treaty", [])
180
+ else:
181
+ texts = []
182
+
183
+ if not texts:
184
+ log_workflow(f"No texts available for domain '{self.domain}' and country '{country}'")
185
+ with self.lock:
186
+ self.vstores[country] = None
187
+ self.retrievers[country] = None
188
+ return
189
+
190
+ # Split texts into chunks
191
+ # Split texts into chunks
192
+ splits = self.splitter.split_text("".join(texts))
193
+ if not splits:
194
+ log_workflow(f"No splits generated for domain '{self.domain}' and country '{country}'")
195
+ with self.lock:
196
+ self.vstores[country] = None
197
+ self.retrievers[country] = None
198
+ return
199
+
200
+ # Build vector store
201
  store = Chroma.from_texts(splits, self.embeddings, collection_name=f"{self.domain}_{country}")
202
+ retr = store.as_retriever(search_kwargs={"k": 3})
203
  with self.lock:
204
  self.vstores[country] = store
205
+ self.retrievers[country] = retr
206
+ log_workflow(f"Vector store ready for domain '{self.domain}' and country '{country}'")
207
 
208
  def retrieve(self, query: str, country: str) -> List[str]:
209
+ log_workflow(f"Retrieving domain '{self.domain}' for {country}")
210
  self._init_country(country)
211
+ retr = self.retrievers.get(country)
212
+ if not retr:
213
+ # Fallback to direct collector methods
214
+ log_workflow(f"Falling back to direct retrieval for domain '{self.domain}' and country '{country}'")
215
+ if self.domain == "banking":
216
+ return self.collector.get_banking_data(country)
217
+ if self.domain == "credit":
218
+ return self.collector.get_credit_data(country)
219
+ if self.domain == "tax":
220
+ info = self.collector.tax_db.get_tax_information(country)
221
+ return info.get("regulations", []) + info.get("treaty", [])
222
+ return []
223
+
224
+ # Perform similarity search
225
+ docs = retr.get_relevant_documents(query)
226
+ results = [d.page_content for d in docs]
227
+ log_workflow(f"Retrieved {len(results)} docs for domain '{self.domain}' and country '{country}'")
228
+ return results
229
+
230
+ # Pre-initialize KnowledgeBase for common domains and countries
231
+ COMMON_COUNTRIES = ["India", "China", "South Korea", "Brazil", "Saudi Arabia", "Canada", "Mexico", "Taiwan", "Japan", "Vietnam"]
232
+ DOMAINS = ["banking", "credit", "tax"]
233
+
234
+ # Initialize and preload vector stores at startup
235
+ def preload_kbs():
236
+ for domain in DOMAINS:
237
+ kb = KnowledgeBase(domain)
238
+ for country in COMMON_COUNTRIES:
239
+ # Launch in background to avoid blocking
240
+ threading.Thread(target=kb._init_country, args=(country,), daemon=True).start()
241
+
242
+ # Trigger preload
243
+ preload_kbs()
244
 
245
  # =======================================
246
  # Specialist Agents