hskwon7 committed on
Commit
2d69da8
·
verified ·
1 Parent(s): 34e567f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +104 -153
app.py CHANGED
@@ -12,64 +12,54 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification
12
 
13
  @st.cache_data(show_spinner=False)
14
  def load_etf_data():
15
- """
16
- Load ETF data with a persistent 'doc' column.
17
- - First run: reads raw CSV, computes 'doc', saves enriched CSV.
18
- - Subsequent runs: loads enriched CSV directly.
19
- """
20
  enriched_path = "etf_general_info_enriched_doc_added.csv"
21
  raw_path = "etf_general_info_enriched.csv"
22
-
23
  if os.path.exists(enriched_path):
24
  df_info = pd.read_csv(enriched_path)
25
  else:
26
  df_info = pd.read_csv(raw_path).rename(columns={"ticker": "Ticker"})
27
  df_info["doc"] = df_info.apply(modules.make_doc_text, axis=1)
28
  df_info.to_csv(enriched_path, index=False)
29
-
30
  df_etf, available_tickers = modules.set_etf_data(df_info)
31
  df_analyst_report = pd.read_csv("etf_analyst_report_full.csv")
32
- df_annual_return_master = pd.read_csv("annual_return.csv").rename(columns={"ticker": "Ticker"})
 
 
 
33
  return df_etf, df_analyst_report, available_tickers, df_annual_return_master
34
 
35
  @st.cache_resource(show_spinner=False)
36
  def build_search_resources():
37
- """
38
- Load or build SentenceTransformer + FAISS index + ticker list.
39
- - First run: computes embeddings, builds index, writes to disk.
40
- - Subsequent runs: loads FAISS index from disk.
41
- """
42
  df_etf, *_ = load_etf_data()
43
- repo_name = "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
44
- model = SentenceTransformer(repo_name)
 
45
  ticker_list = df_etf["Ticker"].tolist()
46
 
47
- index_path = "etf_faiss.index"
48
- if os.path.exists(index_path):
49
- index = faiss.read_index(index_path)
50
  else:
51
- docs = df_etf["doc"].tolist()
52
- embs = model.encode(docs, convert_to_numpy=True)
53
  faiss.normalize_L2(embs)
54
  index = faiss.IndexFlatIP(embs.shape[1])
55
  index.add(embs)
56
- faiss.write_index(index, index_path)
57
 
58
  return model, index, ticker_list
59
 
60
  @st.cache_resource(show_spinner=False)
61
  def load_ner_models():
62
- """
63
- Loads two NER models for ticker extraction and builds valid_ticker_set.
64
- """
65
- repo1 = "hskwon7/distilbert-base-uncased-for-etf-ticker"
66
- repo2 = "hskwon7/albert-base-v2-for-etf-ticker"
67
- tok1, m1 = AutoTokenizer.from_pretrained(repo1), AutoModelForTokenClassification.from_pretrained(repo1)
68
- tok2, m2 = AutoTokenizer.from_pretrained(repo2), AutoModelForTokenClassification.from_pretrained(repo2)
69
-
70
  df_etf, *_ = load_etf_data()
71
- valid_ticker_set = set(t.upper() for t in df_etf["Ticker"].unique())
72
-
73
  return (tok1, m1), (tok2, m2), valid_ticker_set
74
 
75
  # ─── INITIALIZE ─────────────────────────────────────────────────────────────
@@ -78,18 +68,18 @@ df_etf, df_analyst_report, available_tickers, df_annual_return_master = load_etf
78
  s2_model, faiss_index, etf_list = build_search_resources()
79
  (tok1, m1), (tok2, m2), valid_ticker_set = load_ner_models()
80
 
81
- # ─── CORE SEARCH & EXTRACTION ───────────────────────────────────────────────
82
 
83
- def semantic_search(query: str, top_k: int = 100):
84
- q_emb = s2_model.encode([query], convert_to_numpy=True)
85
- faiss.normalize_L2(q_emb)
86
- D, I = faiss_index.search(q_emb, top_k)
87
- return [(etf_list[idx], float(D[0][i])) for i, idx in enumerate(I[0])]
88
 
89
- def ensemble_ticker_extraction(query: str):
90
  preds = set()
91
- for tok, mdl in ((tok1, m1), (tok2, m2)):
92
- enc = tok(query, return_tensors="pt")
93
  with torch.no_grad():
94
  logits = mdl(**enc).logits
95
  ids = logits.argmax(dim=-1)[0].tolist()
@@ -98,89 +88,78 @@ def ensemble_ticker_extraction(query: str):
98
  preds |= modules.extract_valid_tickers(toks, labs, tok, valid_ticker_set)
99
  return preds
100
 
101
- # ─── UI HELPERS ──────────────────────────────────────────────────────────────
102
 
103
  def display_sample_query_boxes(key_prefix=""):
104
  sample_queries = {
105
  "search_etf": {
106
  "title": "AI ETF Search",
107
  "description": "Explore ETFs by dividend, expense ratio, sector, etc.",
108
- "query": ['Find me some high-dividend ETFs in the tech sector', 'Any ETFs holding Apple?']
 
 
 
 
109
  },
110
  "comparison": {
111
  "title": "ETF Performance Comparison",
112
  "description": "Compare two ETFs side by side on performance and risk.",
113
- "query": "I'd like to compare performance of QQQ with SPY."
 
 
 
 
114
  },
115
  "portfolio_projection": {
116
  "title": "Portfolio Projection",
117
- "description": "Project a portfolio with multiple ETFs over a set number of years.",
118
- "query": "I want to invest in SPY, QQQ, SCHD, and IAU for 20 years."
 
 
 
119
  },
120
  }
121
 
122
  cols = st.columns(len(sample_queries))
123
- title_h = "60px"
124
- desc_h = "100px"
125
- query_h = "80px"
126
-
127
- st.markdown("""
128
- <style>
129
- .small-link-button {
130
- font-size: 12px;
131
- color: #0073e6;
132
- text-decoration: underline;
133
- cursor: pointer;
134
- }
135
- </style>
136
- """, unsafe_allow_html=True)
137
 
138
  for idx, (key, details) in enumerate(sample_queries.items()):
139
  with cols[idx]:
140
- box_html = f"""
141
  <div style="
142
- width: 100%; height: 300px; border: 1px solid #ddd;
143
- border-radius: 10px; padding: 15px; margin: 10px auto;
144
- background-color: #fff; box-shadow: 2px 2px 8px rgba(0,0,0,0.1);
145
- display: flex; flex-direction: column; justify-content: space-between;
146
  ">
147
- <div style="height: {title_h}; text-align: center;">
148
- <p style="margin:5px 0; color:#2c3e50; font-size:16px;">
149
- <b>{details['title']}</b>
150
- </p>
151
  </div>
152
- <div style="height: {desc_h}; text-align: center; overflow:auto;">
153
- <p style="margin:5px 0; color:#7f8c8d; font-size:14px;">
154
- {details['description']}
155
- </p>
156
  </div>
157
- <div style="height: {query_h}; text-align: center; overflow:auto;">
158
- <p style="margin:5px 0; font-style:italic; color:#34495e; font-size:13px;">
159
- {"<br>".join(f'&quot;{q}&quot;' for q in details['query'])
160
- if isinstance(details['query'], list)
161
- else f'&quot;{details["query"]}&quot;'}
162
- </p>
163
  </div>
164
  </div>
165
- """
166
- st.markdown(box_html, unsafe_allow_html=True)
167
-
168
- # center the button
169
- l, c, r = st.columns([1,2,1])
170
- with c:
171
- if st.button("Try this app", key=key_prefix + key):
172
- # just switch page
173
- page_map = {
174
- "search_etf": "ETF Search",
175
- "comparison": "ETF Comparison",
176
- "portfolio_projection": "ETF Portfolio"
177
- }
178
- st.session_state["page"] = page_map[key]
179
- st.rerun()
180
 
181
  def display_chat_history(task: str):
182
- hist = st.session_state.get(f"all_chat_history_{task}", [])
183
- for entry in hist:
184
  st.chat_message("assistant").write(entry["response"])
185
  if entry.get("fig"):
186
  st.plotly_chart(entry["fig"], use_container_width=True)
@@ -189,8 +168,7 @@ def display_chat_history(task: str):
189
 
190
  def process_query(task: str, query: str):
191
  top_k, top_n = 100, 30
192
-
193
- if task == "search_etf":
194
  with st.spinner("Searching ETFs..."):
195
  fetched = semantic_search(query, top_k)
196
  df_out = modules.get_etf_recommendations_from_list(
@@ -201,14 +179,14 @@ def process_query(task: str, query: str):
201
  modules.form_d_chat_history(str(uuid.uuid4()), msg, task, df=df_out)
202
  )
203
 
204
- elif task == "comparison":
205
  with st.spinner("Running comparison..."):
206
  tk = ensemble_ticker_extraction(query)
207
- if len(tk) != 2:
208
  resp, fig, df_out = "Please specify exactly two tickers.", None, None
209
  else:
210
  df_out = modules.get_etf_recommendations_from_list(
211
- [(t, None) for t in tk],
212
  modules.get_cols_to_display(), df_etf, top_n=2
213
  )
214
  fig = modules.compare_etfs_interactive(*tk)
@@ -217,7 +195,7 @@ def process_query(task: str, query: str):
217
  modules.form_d_chat_history(str(uuid.uuid4()), resp, task, fig=fig, df=df_out)
218
  )
219
 
220
- elif task == "portfolio_projection":
221
  with st.spinner("Projecting portfolio..."):
222
  fetched = semantic_search(query, top_k)
223
  df_port = modules.run_portfolio_analysis(fetched, df_etf, df_annual_return_master)
@@ -227,73 +205,46 @@ def process_query(task: str, query: str):
227
  modules.form_d_chat_history(str(uuid.uuid4()), resp, task, fig=fig)
228
  )
229
 
 
 
230
  def main():
231
  st.set_page_config(layout="wide")
232
-
233
- # init state
234
  if "page" not in st.session_state:
235
- st.session_state["page"] = "Home"
236
- if "user_query" not in st.session_state:
237
- st.session_state["user_query"] = ""
238
- if "auto_query_sent" not in st.session_state:
239
- st.session_state["auto_query_sent"] = False
240
  for t in ["search_etf","comparison","portfolio_projection"]:
241
  st.session_state.setdefault(f"all_chat_history_{t}", [])
242
 
243
  # sidebar
244
  st.sidebar.title("ETF Assistant")
245
- if st.sidebar.button("Home"):
246
- st.session_state["page"] = "Home"
247
- if st.sidebar.button("ETF Search"):
248
- st.session_state["page"] = "ETF Search"
249
- if st.sidebar.button("ETF Comparison"):
250
- st.session_state["page"] = "ETF Comparison"
251
- if st.sidebar.button("ETF Portfolio"):
252
- st.session_state["page"] = "ETF Portfolio"
253
 
254
- # render
255
  page = st.session_state["page"]
256
- if page == 'Home':
257
- st.title("ETF Assistant")
258
- else:
259
- st.title(page)
260
 
261
- # Home splash
262
- if page == "Home":
263
  display_sample_query_boxes(key_prefix="home_")
264
-
265
- # Sub-apps
266
  else:
267
- # auto-run if launched from Home
268
- if st.session_state["user_query"] and not st.session_state["auto_query_sent"]:
269
- process_query(
270
- {"ETF Search":"search_etf",
271
- "ETF Comparison":"comparison",
272
- "ETF Portfolio":"portfolio_projection"}[page],
273
- st.session_state["user_query"]
274
- )
275
- st.session_state["auto_query_sent"] = True
276
-
277
- task_map = {
278
- "ETF Search": "search_etf",
279
- "ETF Comparison": "comparison",
280
- "ETF Portfolio": "portfolio_projection",
281
- }
282
- task = task_map[page]
283
-
284
- display_sample_query_boxes(key_prefix="sub_")
285
- display_chat_history(task)
286
-
287
- # chat input
288
- prompt = {
289
- "ETF Search": "Search for ETFs…",
290
- "ETF Comparison": "Compare ETFs…",
291
- "ETF Portfolio": "Project portfolio…",
292
  }[page]
293
-
294
- q = st.chat_input(prompt, key="inp_" + task)
 
 
 
 
295
  if q:
296
  process_query(task, q)
297
 
298
- if __name__ == "__main__":
299
  main()
 
12
 
13
  @st.cache_data(show_spinner=False)
14
  def load_etf_data():
 
 
 
 
 
15
  enriched_path = "etf_general_info_enriched_doc_added.csv"
16
  raw_path = "etf_general_info_enriched.csv"
 
17
  if os.path.exists(enriched_path):
18
  df_info = pd.read_csv(enriched_path)
19
  else:
20
  df_info = pd.read_csv(raw_path).rename(columns={"ticker": "Ticker"})
21
  df_info["doc"] = df_info.apply(modules.make_doc_text, axis=1)
22
  df_info.to_csv(enriched_path, index=False)
 
23
  df_etf, available_tickers = modules.set_etf_data(df_info)
24
  df_analyst_report = pd.read_csv("etf_analyst_report_full.csv")
25
+ df_annual_return_master = (
26
+ pd.read_csv("annual_return.csv")
27
+ .rename(columns={"ticker": "Ticker"})
28
+ )
29
  return df_etf, df_analyst_report, available_tickers, df_annual_return_master
30
 
31
  @st.cache_resource(show_spinner=False)
32
  def build_search_resources():
 
 
 
 
 
33
  df_etf, *_ = load_etf_data()
34
+ model = SentenceTransformer(
35
+ "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
36
+ )
37
  ticker_list = df_etf["Ticker"].tolist()
38
 
39
+ idx_path = "etf_faiss.index"
40
+ if os.path.exists(idx_path):
41
+ index = faiss.read_index(idx_path)
42
  else:
43
+ embs = model.encode(df_etf["doc"].tolist(), convert_to_numpy=True)
 
44
  faiss.normalize_L2(embs)
45
  index = faiss.IndexFlatIP(embs.shape[1])
46
  index.add(embs)
47
+ faiss.write_index(index, idx_path)
48
 
49
  return model, index, ticker_list
50
 
51
  @st.cache_resource(show_spinner=False)
52
  def load_ner_models():
53
+ tok1, m1 = (
54
+ AutoTokenizer.from_pretrained("hskwon7/distilbert-base-uncased-for-etf-ticker"),
55
+ AutoModelForTokenClassification.from_pretrained("hskwon7/distilbert-base-uncased-for-etf-ticker")
56
+ )
57
+ tok2, m2 = (
58
+ AutoTokenizer.from_pretrained("hskwon7/albert-base-v2-for-etf-ticker"),
59
+ AutoModelForTokenClassification.from_pretrained("hskwon7/albert-base-v2-for-etf-ticker")
60
+ )
61
  df_etf, *_ = load_etf_data()
62
+ valid_ticker_set = set(df_etf["Ticker"].str.upper())
 
63
  return (tok1, m1), (tok2, m2), valid_ticker_set
64
 
65
  # ─── INITIALIZE ─────────────────────────────────────────────────────────────
 
68
  s2_model, faiss_index, etf_list = build_search_resources()
69
  (tok1, m1), (tok2, m2), valid_ticker_set = load_ner_models()
70
 
71
+ # ─── CORE ROUTINES ──────────────────────────────────────────────────────────
72
 
73
+ def semantic_search(q: str, top_k: int=100):
74
+ emb = s2_model.encode([q], convert_to_numpy=True)
75
+ faiss.normalize_L2(emb)
76
+ D, I = faiss_index.search(emb, top_k)
77
+ return [(etf_list[i], float(D[0][j])) for j,i in enumerate(I[0])]
78
 
79
+ def ensemble_ticker_extraction(q: str):
80
  preds = set()
81
+ for tok, mdl in ((tok1,m1),(tok2,m2)):
82
+ enc = tok(q, return_tensors="pt")
83
  with torch.no_grad():
84
  logits = mdl(**enc).logits
85
  ids = logits.argmax(dim=-1)[0].tolist()
 
88
  preds |= modules.extract_valid_tickers(toks, labs, tok, valid_ticker_set)
89
  return preds
90
 
91
+ # ─── UI HELPERS ─────────────────────────────────────────────────────────────
92
 
93
  def display_sample_query_boxes(key_prefix=""):
94
  sample_queries = {
95
  "search_etf": {
96
  "title": "AI ETF Search",
97
  "description": "Explore ETFs by dividend, expense ratio, sector, etc.",
98
+ "query": [
99
+ 'High-dividend ETFs in the tech sector',
100
+ 'Precious metals ETFs with low expense ratio',
101
+ 'Large growth ETFs with high returns'
102
+ ]
103
  },
104
  "comparison": {
105
  "title": "ETF Performance Comparison",
106
  "description": "Compare two ETFs side by side on performance and risk.",
107
+ "query": [
108
+ "I'd like to compare performance of QQQ with GLD.",
109
+ "Compare SPY and VOO.",
110
+ "SCHD vs. VTI"
111
+ ]
112
  },
113
  "portfolio_projection": {
114
  "title": "Portfolio Projection",
115
+ "description": "Project a portfolio with multiple ETFs over 30 years.",
116
+ "query": [
117
+ "I want to invest in SPY, QQQ, SCHD, and IAU.",
118
+ "Portfolio projection for VTI, XLF, and XLY."
119
+ ]
120
  },
121
  }
122
 
123
  cols = st.columns(len(sample_queries))
124
+ title_h, desc_h, query_h = "40px", "60px", "60px"
 
 
 
 
 
 
 
 
 
 
 
 
 
125
 
126
  for idx, (key, details) in enumerate(sample_queries.items()):
127
  with cols[idx]:
128
+ st.markdown(f"""
129
  <div style="
130
+ width:100%; height:300px; border:1px solid #ddd;
131
+ border-radius:10px; padding:15px; margin:auto;
132
+ display:flex; flex-direction:column; justify-content:space-between;
133
+ box-shadow:2px 2px 8px rgba(0,0,0,0.1);
134
  ">
135
+ <div style="height:{title_h}; text-align:center;">
136
+ <b style="font-size:16px; color:#2c3e50;">
137
+ {details['title']}
138
+ </b>
139
  </div>
140
+ <div style="height:{desc_h}; text-align:center; color:#7f8c8d; font-size:14px; overflow:auto;">
141
+ {details['description']}
 
 
142
  </div>
143
+ <div style="height:{query_h}; text-align:center; color:#34495e; font-size:13px; font-style:italic; overflow:auto;">
144
+ {'<br>'.join(f'“{q}”' for q in details['query'])}
 
 
 
 
145
  </div>
146
  </div>
147
+ """, unsafe_allow_html=True)
148
+
149
+ # center the button directly under the box
150
+ st.markdown("<div style='text-align:center; margin-top:10px;'>", unsafe_allow_html=True)
151
+ if st.button("Try this app", key=key_prefix+key):
152
+ page_map = {
153
+ "search_etf": "ETF Search",
154
+ "comparison": "ETF Comparison",
155
+ "portfolio_projection": "ETF Portfolio"
156
+ }
157
+ st.session_state["page"] = page_map[key]
158
+ st.experimental_rerun()
159
+ st.markdown("</div>", unsafe_allow_html=True)
 
 
160
 
161
  def display_chat_history(task: str):
162
+ for entry in st.session_state.get(f"all_chat_history_{task}", []):
 
163
  st.chat_message("assistant").write(entry["response"])
164
  if entry.get("fig"):
165
  st.plotly_chart(entry["fig"], use_container_width=True)
 
168
 
169
  def process_query(task: str, query: str):
170
  top_k, top_n = 100, 30
171
+ if task=="search_etf":
 
172
  with st.spinner("Searching ETFs..."):
173
  fetched = semantic_search(query, top_k)
174
  df_out = modules.get_etf_recommendations_from_list(
 
179
  modules.form_d_chat_history(str(uuid.uuid4()), msg, task, df=df_out)
180
  )
181
 
182
+ elif task=="comparison":
183
  with st.spinner("Running comparison..."):
184
  tk = ensemble_ticker_extraction(query)
185
+ if len(tk)!=2:
186
  resp, fig, df_out = "Please specify exactly two tickers.", None, None
187
  else:
188
  df_out = modules.get_etf_recommendations_from_list(
189
+ [(t,None) for t in tk],
190
  modules.get_cols_to_display(), df_etf, top_n=2
191
  )
192
  fig = modules.compare_etfs_interactive(*tk)
 
195
  modules.form_d_chat_history(str(uuid.uuid4()), resp, task, fig=fig, df=df_out)
196
  )
197
 
198
+ elif task=="portfolio_projection":
199
  with st.spinner("Projecting portfolio..."):
200
  fetched = semantic_search(query, top_k)
201
  df_port = modules.run_portfolio_analysis(fetched, df_etf, df_annual_return_master)
 
205
  modules.form_d_chat_history(str(uuid.uuid4()), resp, task, fig=fig)
206
  )
207
 
208
+ # ─── MAIN ────────────────────────────────────────────────────────────────
209
+
210
  def main():
211
  st.set_page_config(layout="wide")
212
+ # init
 
213
  if "page" not in st.session_state:
214
+ st.session_state["page"]="Home"
 
 
 
 
215
  for t in ["search_etf","comparison","portfolio_projection"]:
216
  st.session_state.setdefault(f"all_chat_history_{t}", [])
217
 
218
  # sidebar
219
  st.sidebar.title("ETF Assistant")
220
+ if st.sidebar.button("🏠 Home"):
221
+ st.session_state["page"]="Home"
222
+ if st.sidebar.button("🔎 ETF Search"):
223
+ st.session_state["page"]="ETF Search"
224
+ if st.sidebar.button("⚖️ ETF Comparison"):
225
+ st.session_state["page"]="ETF Comparison"
226
+ if st.sidebar.button("💼 ETF Portfolio"):
227
+ st.session_state["page"]="ETF Portfolio"
228
 
 
229
  page = st.session_state["page"]
230
+ st.title(page if page!="Home" else "ETF Assistant")
 
 
 
231
 
232
+ if page=="Home":
 
233
  display_sample_query_boxes(key_prefix="home_")
 
 
234
  else:
235
+ task = {
236
+ "ETF Search":"search_etf",
237
+ "ETF Comparison":"comparison",
238
+ "ETF Portfolio":"portfolio_projection"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
239
  }[page]
240
+ display_chat_history(task)
241
+ q = st.chat_input({
242
+ "ETF Search":"Search for ETFs…",
243
+ "ETF Comparison":"Compare ETFs…",
244
+ "ETF Portfolio":"Project portfolio…"
245
+ }[page], key=task)
246
  if q:
247
  process_query(task, q)
248
 
249
+ if __name__=="__main__":
250
  main()