Nyha15 committed on
Commit
5c95ea1
·
1 Parent(s): 5d862db

Refactored

Browse files
Files changed (1) hide show
  1. app.py +55 -93
app.py CHANGED
@@ -53,17 +53,14 @@ def get_workflow_log() -> str:
53
  # =======================================
54
 
55
  class TaxRegulationDatabase:
56
- """Database of tax regulations for international students"""
57
-
58
  def __init__(self):
59
- self.llm = ChatOpenAI(temperature=0.1, model="gpt-3.5-turbo")
60
  self.tax_regulations: Dict[str, List[str]] = {}
61
  self.tax_treaties: Dict[str, List[str]] = {}
62
  self.lock = threading.Lock()
63
 
64
  def preload_common_countries(self):
65
- countries = ["India", "China", "South Korea", "Brazil", "Saudi Arabia",
66
- "Canada", "Mexico", "Taiwan", "Japan", "Vietnam"]
67
  log_workflow("Preloading tax regulations for common countries")
68
  for country in countries:
69
  threading.Thread(target=self._load_all, args=(country,), daemon=True).start()
@@ -75,11 +72,10 @@ class TaxRegulationDatabase:
75
  @lru_cache(maxsize=32)
76
  def _get_tax_regulations(self, country: str) -> List[str]:
77
  log_workflow(f"Loading tax regulations for {country}")
78
- prompt = (f"Provide 5 specific, factual statements about tax regulations that directly affect international students "
79
- f"from {country} studying in the US. Include form numbers, thresholds, deadlines.")
80
  try:
81
  resp = self.llm.invoke(prompt)
82
- regs = [line.strip() for line in resp.content.split('\n') if line.strip()]
83
  with self.lock:
84
  self.tax_regulations[country] = regs
85
  return regs
@@ -89,12 +85,11 @@ class TaxRegulationDatabase:
89
 
90
  @lru_cache(maxsize=32)
91
  def _get_tax_treaty(self, country: str) -> List[str]:
92
- log_workflow(f"Loading tax treaty info for {country}")
93
- prompt = (f"Provide 5 specific statements about the US-{country} tax treaty relevant to students, "
94
- f"including article numbers and exemption limits.")
95
  try:
96
  resp = self.llm.invoke(prompt)
97
- treaty = [line.strip() for line in resp.content.split('\n') if line.strip()]
98
  with self.lock:
99
  self.tax_treaties[country] = treaty
100
  return treaty
@@ -113,43 +108,42 @@ class TaxRegulationDatabase:
113
  # =======================================
114
 
115
  class InternationalStudentDataCollector:
116
- """Collects financial data for international students"""
117
-
118
  def __init__(self):
119
- self.llm = ChatOpenAI(temperature=0.1, model="gpt-3.5-turbo")
120
  self.cache: Dict[str, List[str]] = {}
121
  self.tax_db = TaxRegulationDatabase()
122
 
123
  def preload_common(self):
124
  log_workflow("Preloading data for common countries")
125
  self.tax_db.preload_common_countries()
126
- for country in ["India", "China"]:
127
  for fn in [self.get_banking_data, self.get_credit_data]:
128
- threading.Thread(target=fn, args=(country,), daemon=True).start()
129
 
130
  def _cached(self, key: str, prompt: str) -> List[str]:
131
  log_workflow(f"Collecting data for {key}")
132
  if key in self.cache:
133
- log_workflow("Using cached data")
134
  return self.cache[key]
135
  try:
136
  resp = self.llm.invoke(prompt)
137
- facts = [line.strip() for line in resp.content.split('\n') if line.strip()]
138
- self.cache[key] = facts
139
- return facts
140
  except Exception as e:
141
  log_workflow(f"Error collecting {key}", str(e))
142
  return [f"Error: {e}"]
143
 
144
  def get_banking_data(self, country: str) -> List[str]:
145
- prompt = (f"5 facts about banking for {country} students in the US: banks, fees, docs.")
146
- return self._cached(f"banking_{country}", prompt)
 
 
147
 
148
  def get_credit_data(self, country: str) -> List[str]:
149
- prompt = (f"5 facts about credit building for {country} students: cards, steps, pitfalls.")
150
- return self._cached(f"credit_{country}", prompt)
151
-
152
- # Additional domain methods omitted for brevity
153
 
154
  # =======================================
155
  # RAG Knowledge Base
@@ -169,7 +163,6 @@ class KnowledgeBase:
169
  with self.lock:
170
  if country in self.vstores:
171
  return
172
- # Retrieve raw texts
173
  if self.domain == "banking":
174
  texts = self.collector.get_banking_data(country)
175
  elif self.domain == "credit":
@@ -179,68 +172,48 @@ class KnowledgeBase:
179
  texts = ti.get("regulations", []) + ti.get("treaty", [])
180
  else:
181
  texts = []
182
-
183
  if not texts:
184
- log_workflow(f"No texts available for domain '{self.domain}' and country '{country}'")
185
  with self.lock:
186
  self.vstores[country] = None
187
  self.retrievers[country] = None
188
  return
189
-
190
- # Split texts into chunks
191
- # Split texts into chunks
192
  splits = self.splitter.split_text("\n\n".join(texts))
193
  if not splits:
194
- log_workflow(f"No splits generated for domain '{self.domain}' and country '{country}'")
195
  with self.lock:
196
  self.vstores[country] = None
197
  self.retrievers[country] = None
198
  return
199
-
200
- # Build vector store
201
  store = Chroma.from_texts(splits, self.embeddings, collection_name=f"{self.domain}_{country}")
202
- retr = store.as_retriever(search_kwargs={"k": 3})
203
  with self.lock:
204
  self.vstores[country] = store
205
  self.retrievers[country] = retr
206
- log_workflow(f"Vector store ready for domain '{self.domain}' and country '{country}'")
207
 
208
  def retrieve(self, query: str, country: str) -> List[str]:
209
- log_workflow(f"Retrieving domain '{self.domain}' for {country}")
210
  self._init_country(country)
211
  retr = self.retrievers.get(country)
212
  if not retr:
213
- # Fallback to direct collector methods
214
- log_workflow(f"Falling back to direct retrieval for domain '{self.domain}' and country '{country}'")
215
- if self.domain == "banking":
216
- return self.collector.get_banking_data(country)
217
- if self.domain == "credit":
218
- return self.collector.get_credit_data(country)
219
  if self.domain == "tax":
220
- info = self.collector.tax_db.get_tax_information(country)
221
- return info.get("regulations", []) + info.get("treaty", [])
222
  return []
223
-
224
- # Perform similarity search
225
  docs = retr.get_relevant_documents(query)
226
- results = [d.page_content for d in docs]
227
- log_workflow(f"Retrieved {len(results)} docs for domain '{self.domain}' and country '{country}'")
228
- return results
229
-
230
- # Pre-initialize KnowledgeBase for common domains and countries
231
- COMMON_COUNTRIES = ["India", "China", "South Korea", "Brazil", "Saudi Arabia", "Canada", "Mexico", "Taiwan", "Japan", "Vietnam"]
232
- DOMAINS = ["banking", "credit", "tax"]
233
 
234
- # Initialize and preload vector stores at startup
235
- def preload_kbs():
236
- for domain in DOMAINS:
237
- kb = KnowledgeBase(domain)
238
- for country in COMMON_COUNTRIES:
239
- # Launch in background to avoid blocking
240
- threading.Thread(target=kb._init_country, args=(country,), daemon=True).start()
241
-
242
- # Trigger preload
243
- preload_kbs()
244
 
245
  # =======================================
246
  # Specialist Agents
@@ -256,17 +229,16 @@ class SpecialistAgent:
256
  log_workflow(f"{self.name} analyzing")
257
  refs = self.kb.retrieve(query, country)
258
  context = "\n".join(f"- {r}" for r in refs)
259
- prompt = f"As {self.name} for {country}, references:\n{context}\nQuestion: {query}\nProvide detailed advice."
260
  resp = self.llm.invoke(prompt)
261
  log_workflow(f"{self.name} done")
262
  return resp.content
263
 
264
  # Specialist factories: each call builds a new SpecialistAgent
265
- BankingAdvisor = lambda llm=None: SpecialistAgent("Banking Advisor", "banking")
266
- CreditBuilder = lambda llm=None: SpecialistAgent("Credit Builder", "credit")
267
- LegalFinanceAdvisor = lambda llm=None: SpecialistAgent("Legal Advisor", "legal")
268
- TaxSpecialist = lambda llm=None: SpecialistAgent("Tax Specialist", "tax")
269
- # Additional specialists omitted
270
 
271
  # =======================================
272
  # Coordinator Agent
@@ -278,38 +250,28 @@ class CoordinatorAgent:
278
  self.specialists = {
279
  "banking": BankingAdvisor(),
280
  "credit": CreditBuilder(),
281
- "legal": LegalFinanceAdvisor(),
282
  "tax": TaxSpecialist()
283
  }
284
 
285
- def run(self, query: str, profile: Dict[str, Any]) -> str:
286
  clear_workflow_log()
287
  country = profile.get("home_country","unknown")
288
- # 1. collect
289
- advice = {d: self.specialists[d].run(query, country) for d in self.specialists}
290
- # 2. plan (omitted)
291
- # 3. synthesis
292
- # 4. Synthesize and pretty–print
293
  lines = ["# Your Personalized Financial Advice\n"]
294
-
295
- # Add each specialist’s section
296
- for domain, text in specialist_advice.items():
297
  lines.append(f"## {domain.capitalize()}\n")
298
- # Indent each paragraph for readability
299
  for para in text.strip().split("\n\n"):
300
- lines.append(" " + para.strip().replace("\n", "\n "))
301
- lines.append("") # blank line
302
-
303
- # Append Multi-Path Plans as a JSON code block
304
- lines.append("## Multi-Path Financial Plans\n")
305
- lines.append("```json")
306
- lines.append(json.dumps(plans, indent=2))
307
  lines.append("```")
308
-
309
  formatted = "\n".join(lines)
310
  log_workflow("Synthesis complete")
311
-
312
- # 5. Return formatted advice + workflow log
313
  return formatted + "\n\n---\n" + get_workflow_log()
314
 
315
 
 
53
  # =======================================
54
 
55
  class TaxRegulationDatabase:
 
 
56
  def __init__(self):
57
+ self.llm = ChatOpenAI(temperature=0.1)
58
  self.tax_regulations: Dict[str, List[str]] = {}
59
  self.tax_treaties: Dict[str, List[str]] = {}
60
  self.lock = threading.Lock()
61
 
62
  def preload_common_countries(self):
63
+ countries = ["India", "China", "South Korea", "Brazil", "Canada", "Mexico", "Taiwan", "Japan", "Vietnam"]
 
64
  log_workflow("Preloading tax regulations for common countries")
65
  for country in countries:
66
  threading.Thread(target=self._load_all, args=(country,), daemon=True).start()
 
72
  @lru_cache(maxsize=32)
73
  def _get_tax_regulations(self, country: str) -> List[str]:
74
  log_workflow(f"Loading tax regulations for {country}")
75
+ prompt = f"Provide 5 factual statements about tax regs for {country} students in the US, incl. forms, thresholds."
 
76
  try:
77
  resp = self.llm.invoke(prompt)
78
+ regs = [line.strip() for line in resp.content.split("\n") if line.strip()]
79
  with self.lock:
80
  self.tax_regulations[country] = regs
81
  return regs
 
85
 
86
  @lru_cache(maxsize=32)
87
  def _get_tax_treaty(self, country: str) -> List[str]:
88
+ log_workflow(f"Loading tax treaty for {country}")
89
+ prompt = f"Provide 5 statements about US-{country} tax treaty for students, incl. articles, exemptions."
 
90
  try:
91
  resp = self.llm.invoke(prompt)
92
+ treaty = [line.strip() for line in resp.content.split("\n") if line.strip()]
93
  with self.lock:
94
  self.tax_treaties[country] = treaty
95
  return treaty
 
108
  # =======================================
109
 
110
  class InternationalStudentDataCollector:
 
 
111
  def __init__(self):
112
+ self.llm = ChatOpenAI(temperature=0.1)
113
  self.cache: Dict[str, List[str]] = {}
114
  self.tax_db = TaxRegulationDatabase()
115
 
116
  def preload_common(self):
117
  log_workflow("Preloading data for common countries")
118
  self.tax_db.preload_common_countries()
119
+ for c in ["India", "China"]:
120
  for fn in [self.get_banking_data, self.get_credit_data]:
121
+ threading.Thread(target=fn, args=(c,), daemon=True).start()
122
 
123
  def _cached(self, key: str, prompt: str) -> List[str]:
124
  log_workflow(f"Collecting data for {key}")
125
  if key in self.cache:
 
126
  return self.cache[key]
127
  try:
128
  resp = self.llm.invoke(prompt)
129
+ items = [line.strip() for line in resp.content.split("\n") if line.strip()]
130
+ self.cache[key] = items
131
+ return items
132
  except Exception as e:
133
  log_workflow(f"Error collecting {key}", str(e))
134
  return [f"Error: {e}"]
135
 
136
  def get_banking_data(self, country: str) -> List[str]:
137
+ return self._cached(
138
+ f"banking_{country}",
139
+ f"5 facts on banking for {country} students in the US, incl. banks, fees, docs."
140
+ )
141
 
142
  def get_credit_data(self, country: str) -> List[str]:
143
+ return self._cached(
144
+ f"credit_{country}",
145
+ f"5 facts on credit building for {country} students: cards, history, pitfalls."
146
+ )
147
 
148
  # =======================================
149
  # RAG Knowledge Base
 
163
  with self.lock:
164
  if country in self.vstores:
165
  return
 
166
  if self.domain == "banking":
167
  texts = self.collector.get_banking_data(country)
168
  elif self.domain == "credit":
 
172
  texts = ti.get("regulations", []) + ti.get("treaty", [])
173
  else:
174
  texts = []
 
175
  if not texts:
176
+ log_workflow(f"No texts for {self.domain}/{country}")
177
  with self.lock:
178
  self.vstores[country] = None
179
  self.retrievers[country] = None
180
  return
 
 
 
181
  splits = self.splitter.split_text("\n\n".join(texts))
182
  if not splits:
183
+ log_workflow(f"No splits for {self.domain}/{country}")
184
  with self.lock:
185
  self.vstores[country] = None
186
  self.retrievers[country] = None
187
  return
 
 
188
  store = Chroma.from_texts(splits, self.embeddings, collection_name=f"{self.domain}_{country}")
189
+ retr = store.as_retriever(search_kwargs={"k":3})
190
  with self.lock:
191
  self.vstores[country] = store
192
  self.retrievers[country] = retr
193
+ log_workflow(f"Vector store ready for {self.domain}/{country}")
194
 
195
  def retrieve(self, query: str, country: str) -> List[str]:
196
+ log_workflow(f"Retrieving {self.domain} for {country}")
197
  self._init_country(country)
198
  retr = self.retrievers.get(country)
199
  if not retr:
200
+ log_workflow(f"Fallback direct for {self.domain}/{country}")
201
+ if self.domain == "banking": return self.collector.get_banking_data(country)
202
+ if self.domain == "credit": return self.collector.get_credit_data(country)
 
 
 
203
  if self.domain == "tax":
204
+ ti = self.collector.tax_db.get_tax_information(country)
205
+ return ti.get("regulations",[]) + ti.get("treaty",[])
206
  return []
 
 
207
  docs = retr.get_relevant_documents(query)
208
+ return [d.page_content for d in docs]
 
 
 
 
 
 
209
 
210
+ # Preload KBs
211
+ COMMON_COUNTRIES = ["India","China"]
212
+ DOMAINS = ["banking","credit","tax"]
213
+ for dom in DOMAINS:
214
+ kb = KnowledgeBase(dom)
215
+ for c in COMMON_COUNTRIES:
216
+ threading.Thread(target=kb._init_country, args=(c,), daemon=True).start()
 
 
 
217
 
218
  # =======================================
219
  # Specialist Agents
 
229
  log_workflow(f"{self.name} analyzing")
230
  refs = self.kb.retrieve(query, country)
231
  context = "\n".join(f"- {r}" for r in refs)
232
+ prompt = f"As {self.name} for {country}, context:\n{context}\nQuestion: {query}\nProvide detailed advice."
233
  resp = self.llm.invoke(prompt)
234
  log_workflow(f"{self.name} done")
235
  return resp.content
236
 
237
  # Specialist factories: each call builds a new SpecialistAgent
238
+ BankingAdvisor = lambda: SpecialistAgent("Banking Advisor","banking")
239
+ CreditBuilder = lambda: SpecialistAgent("Credit Builder","credit")
240
+ TaxSpecialist = lambda: SpecialistAgent("Tax Specialist","tax")
241
+ # Add more as needed
 
242
 
243
  # =======================================
244
  # Coordinator Agent
 
250
  self.specialists = {
251
  "banking": BankingAdvisor(),
252
  "credit": CreditBuilder(),
 
253
  "tax": TaxSpecialist()
254
  }
255
 
256
+ def run(self, query: str, profile: Dict[str,Any]) -> str:
257
  clear_workflow_log()
258
  country = profile.get("home_country","unknown")
259
+ # 1. Gather specialist advice
260
+ advice_map = {d:agent.run(query,country) for d,agent in self.specialists.items()}
261
+ # 2. Multi-path plans placeholder
262
+ plans = {"conservative":"...","balanced":"...","growth":"..."}
263
+ # 3. Synthesis & formatting
264
  lines = ["# Your Personalized Financial Advice\n"]
265
+ for domain, text in advice_map.items():
 
 
266
  lines.append(f"## {domain.capitalize()}\n")
 
267
  for para in text.strip().split("\n\n"):
268
+ lines.append(" "+para.replace("\n","\n "))
269
+ lines.append("")
270
+ lines.append("## Multi-Path Financial Plans\n```json")
271
+ lines.append(json.dumps(plans,indent=2))
 
 
 
272
  lines.append("```")
 
273
  formatted = "\n".join(lines)
274
  log_workflow("Synthesis complete")
 
 
275
  return formatted + "\n\n---\n" + get_workflow_log()
276
 
277