hskwon7 commited on
Commit
34e567f
·
verified ·
1 Parent(s): 28a4d62

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -159
app.py CHANGED
@@ -14,8 +14,8 @@ from transformers import AutoTokenizer, AutoModelForTokenClassification
14
  def load_etf_data():
15
  """
16
  Load ETF data with a persistent 'doc' column.
17
- - On first run: reads raw CSV, computes 'doc', saves enriched CSV.
18
- - On subsequent runs: loads enriched CSV directly.
19
  """
20
  enriched_path = "etf_general_info_enriched_doc_added.csv"
21
  raw_path = "etf_general_info_enriched.csv"
@@ -27,29 +27,21 @@ def load_etf_data():
27
  df_info["doc"] = df_info.apply(modules.make_doc_text, axis=1)
28
  df_info.to_csv(enriched_path, index=False)
29
 
30
- # Split into DataFrame and ticker list
31
  df_etf, available_tickers = modules.set_etf_data(df_info)
32
-
33
- # Load other supporting DataFrames
34
  df_analyst_report = pd.read_csv("etf_analyst_report_full.csv")
35
- df_annual_return_master = (
36
- pd.read_csv("annual_return.csv").rename(columns={"ticker": "Ticker"})
37
- )
38
  return df_etf, df_analyst_report, available_tickers, df_annual_return_master
39
 
40
  @st.cache_resource(show_spinner=False)
41
  def build_search_resources():
42
  """
43
- Loads (or builds) SentenceTransformer + FAISS index + ticker list.
44
- - On first run: computes embeddings, builds index, writes to disk.
45
- - On subsequent runs: loads FAISS index from disk.
46
  """
47
  df_etf, *_ = load_etf_data()
48
-
49
- # Load SentenceTransformer
50
- repo_name = "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
51
- model = SentenceTransformer(repo_name)
52
-
53
  ticker_list = df_etf["Ticker"].tolist()
54
 
55
  index_path = "etf_faiss.index"
@@ -68,28 +60,25 @@ def build_search_resources():
68
  @st.cache_resource(show_spinner=False)
69
  def load_ner_models():
70
  """
71
- Loads two NER models for ticker extraction, and builds the valid ticker set.
72
  """
73
  repo1 = "hskwon7/distilbert-base-uncased-for-etf-ticker"
74
  repo2 = "hskwon7/albert-base-v2-for-etf-ticker"
75
-
76
- tok1 = AutoTokenizer.from_pretrained(repo1)
77
- m1 = AutoModelForTokenClassification.from_pretrained(repo1)
78
- tok2 = AutoTokenizer.from_pretrained(repo2)
79
- m2 = AutoModelForTokenClassification.from_pretrained(repo2)
80
 
81
  df_etf, *_ = load_etf_data()
82
  valid_ticker_set = set(t.upper() for t in df_etf["Ticker"].unique())
83
 
84
  return (tok1, m1), (tok2, m2), valid_ticker_set
85
 
86
- # ─── INITIALIZE CACHED RESOURCES ─────────────────────────────────────────────
87
 
88
  df_etf, df_analyst_report, available_tickers, df_annual_return_master = load_etf_data()
89
  s2_model, faiss_index, etf_list = build_search_resources()
90
  (tok1, m1), (tok2, m2), valid_ticker_set = load_ner_models()
91
 
92
- # ─── CORE SEARCH & EXTRACTION ─────────────────────────────────────────────────
93
 
94
  def semantic_search(query: str, top_k: int = 100):
95
  q_emb = s2_model.encode([query], convert_to_numpy=True)
@@ -100,54 +89,94 @@ def semantic_search(query: str, top_k: int = 100):
100
  def ensemble_ticker_extraction(query: str):
101
  preds = set()
102
  for tok, mdl in ((tok1, m1), (tok2, m2)):
103
- enc = tok(query, return_tensors="pt")
104
  with torch.no_grad():
105
  logits = mdl(**enc).logits
106
- ids = logits.argmax(dim=-1)[0].tolist()
107
- toks = tok.convert_ids_to_tokens(enc["input_ids"][0])
108
- labs = [mdl.config.id2label[i] for i in ids]
109
  preds |= modules.extract_valid_tickers(toks, labs, tok, valid_ticker_set)
110
  return preds
111
 
112
- # ─── HELPERS ────────────────────────���───────────────────────────────────────
113
 
114
- def display_sample_query_box(task: str):
115
  sample_queries = {
116
  "search_etf": {
117
- "title": "ETF Search",
118
- "description": "Explore ETFs based on dividend, expense ratio, sector, etc.",
119
- "examples": ['Find me technology ETFs', 'Show me Cryptocurrency ETFs']
120
  },
121
  "comparison": {
122
- "title": "ETF Comparison",
123
- "description": "Compare two ETFs side by side on performance, risk, etc.",
124
- "examples": ["QQQ vs. SPY", "Compare performance of QQQ with SPY"]
125
  },
126
  "portfolio_projection": {
127
- "title": "ETF Portfolio",
128
- "description": "Project a multi-ETF portfolio out over 30 years.",
129
- "examples": ["SPY, GLD, BND", "I want to invest in SCHD and IAU"]
130
- }
131
  }
132
- details = sample_queries[task]
133
- box_html = f"""
134
- <div style='border:1px solid #ddd;padding:1rem;border-radius:8px;'>
135
- <h4>{details['title']}</h4>
136
- <p style='color:#555;margin-bottom:.5rem;'>{details['description']}</p>
137
- <p style='font-style:italic;color:#333;'>
138
- Examples:<br/>{'<br/>'.join(details['examples'])}
139
- </p>
140
- </div>
141
- """
142
- st.markdown(box_html, unsafe_allow_html=True)
143
- if st.button("Try this app", key=f"try_{task}"):
144
- page_map = {
145
- "search_etf": "ETF Search",
146
- "comparison": "ETF Comparison",
147
- "portfolio_projection": "ETF Portfolio"
148
- }
149
- st.session_state["page"] = page_map[task]
150
- st.experimental_rerun()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
  def display_chat_history(task: str):
153
  hist = st.session_state.get(f"all_chat_history_{task}", [])
@@ -161,122 +190,110 @@ def display_chat_history(task: str):
161
  def process_query(task: str, query: str):
162
  top_k, top_n = 100, 30
163
 
164
- if task == 'search_etf':
165
  with st.spinner("Searching ETFs..."):
166
- fetched = semantic_search(query, top_k=top_k)
167
  df_out = modules.get_etf_recommendations_from_list(
168
- fetched, df_etf, top_n=top_n
169
  )
 
170
  st.session_state[f"all_chat_history_{task}"].append(
171
- modules.form_d_chat_history(
172
- result_id=str(uuid.uuid4()),
173
- response=f"{len(df_out)} ETFs found.",
174
- task=task,
175
- df=df_out
176
- )
177
  )
178
 
179
- elif task == 'comparison':
180
  with st.spinner("Running comparison..."):
181
- tickers = ensemble_ticker_extraction(query)
182
- if len(tickers) != 2:
183
- response, fig, df_out = (
184
- "Please specify exactly two tickers.", None, None
185
- )
186
  else:
187
  df_out = modules.get_etf_recommendations_from_list(
188
- tickers,
189
- df_etf, top_n=2
190
  )
191
- fig = modules.compare_etfs_interactive(*tickers)
192
- response = f"Compared {tickers[0]} vs. {tickers[1]}."
193
  st.session_state[f"all_chat_history_{task}"].append(
194
- modules.form_d_chat_history(
195
- result_id=str(uuid.uuid4()),
196
- response=response,
197
- task=task,
198
- fig=fig,
199
- df=df_out
200
- )
201
  )
202
 
203
- elif task == 'portfolio_projection':
204
  with st.spinner("Projecting portfolio..."):
205
- fetched = semantic_search(query, top_k=top_k)
206
- df_port = modules.run_portfolio_analysis(
207
- fetched, df_etf, df_annual_return_master
208
- )
209
  fig = modules.portfolio_interactive_chart(df_port)
210
- response = "30-year projection generated."
211
  st.session_state[f"all_chat_history_{task}"].append(
212
- modules.form_d_chat_history(
213
- result_id=str(uuid.uuid4()),
214
- response=response,
215
- task=task,
216
- fig=fig
217
- )
218
  )
219
 
220
- def display_explore_etfs_chat():
221
- st.markdown("""
222
- <style>
223
- .main .block-container {
224
- max-width: 90% !important;
225
- padding: 1rem 2rem;
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
226
  }
227
- </style>
228
- """, unsafe_allow_html=True)
229
- st.header("How can I assist you today?")
230
- st.write("Find ETFs that align with your investment goals and sector interests, compare performance, estimate your portfolio, and get insights.")
231
- display_sample_query_box("search_etf")
232
- display_sample_query_box("comparison")
233
- display_sample_query_box("portfolio_projection")
234
-
235
- # ─── MAIN ────────────────────────────────────────────────────────────────
236
-
237
- st.set_page_config(layout="wide")
238
-
239
- # initialize page and histories
240
- if "page" not in st.session_state:
241
- st.session_state["page"] = "Home"
242
- for t in ["search_etf", "comparison", "portfolio_projection"]:
243
- st.session_state.setdefault(f"all_chat_history_{t}", [])
244
-
245
- # sidebar navigation buttons
246
- st.sidebar.title("Navigation")
247
- if st.sidebar.button("🏠 Home"):
248
- st.session_state["page"] = "Home"
249
- if st.sidebar.button("🔎 ETF Search"):
250
- st.session_state["page"] = "ETF Search"
251
- if st.sidebar.button("⚖️ ETF Comparison"):
252
- st.session_state["page"] = "ETF Comparison"
253
- if st.sidebar.button("💼 ETF Portfolio"):
254
- st.session_state["page"] = "ETF Portfolio"
255
-
256
- # render page
257
- page = st.session_state["page"]
258
- st.title(page)
259
-
260
- if page == "Home":
261
- display_explore_etfs_chat()
262
-
263
- elif page == "ETF Search":
264
- display_sample_query_box("search_etf")
265
- display_chat_history("search_etf")
266
- q = st.chat_input("Search for ETFs…", key="in_search")
267
- if q:
268
- process_query("search_etf", q)
269
-
270
- elif page == "ETF Comparison":
271
- display_sample_query_box("comparison")
272
- display_chat_history("comparison")
273
- q = st.chat_input("Compare ETFs…", key="in_comp")
274
- if q:
275
- process_query("comparison", q)
276
-
277
- elif page == "ETF Portfolio":
278
- display_sample_query_box("portfolio_projection")
279
- display_chat_history("portfolio_projection")
280
- q = st.chat_input("Project portfolio…", key="in_port")
281
- if q:
282
- process_query("portfolio_projection", q)
 
14
  def load_etf_data():
15
  """
16
  Load ETF data with a persistent 'doc' column.
17
+ - First run: reads raw CSV, computes 'doc', saves enriched CSV.
18
+ - Subsequent runs: loads enriched CSV directly.
19
  """
20
  enriched_path = "etf_general_info_enriched_doc_added.csv"
21
  raw_path = "etf_general_info_enriched.csv"
 
27
  df_info["doc"] = df_info.apply(modules.make_doc_text, axis=1)
28
  df_info.to_csv(enriched_path, index=False)
29
 
 
30
  df_etf, available_tickers = modules.set_etf_data(df_info)
 
 
31
  df_analyst_report = pd.read_csv("etf_analyst_report_full.csv")
32
+ df_annual_return_master = pd.read_csv("annual_return.csv").rename(columns={"ticker": "Ticker"})
 
 
33
  return df_etf, df_analyst_report, available_tickers, df_annual_return_master
34
 
35
  @st.cache_resource(show_spinner=False)
36
  def build_search_resources():
37
  """
38
+ Load or build SentenceTransformer + FAISS index + ticker list.
39
+ - First run: computes embeddings, builds index, writes to disk.
40
+ - Subsequent runs: loads FAISS index from disk.
41
  """
42
  df_etf, *_ = load_etf_data()
43
+ repo_name = "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
44
+ model = SentenceTransformer(repo_name)
 
 
 
45
  ticker_list = df_etf["Ticker"].tolist()
46
 
47
  index_path = "etf_faiss.index"
 
60
  @st.cache_resource(show_spinner=False)
61
  def load_ner_models():
62
  """
63
+ Loads two NER models for ticker extraction and builds valid_ticker_set.
64
  """
65
  repo1 = "hskwon7/distilbert-base-uncased-for-etf-ticker"
66
  repo2 = "hskwon7/albert-base-v2-for-etf-ticker"
67
+ tok1, m1 = AutoTokenizer.from_pretrained(repo1), AutoModelForTokenClassification.from_pretrained(repo1)
68
+ tok2, m2 = AutoTokenizer.from_pretrained(repo2), AutoModelForTokenClassification.from_pretrained(repo2)
 
 
 
69
 
70
  df_etf, *_ = load_etf_data()
71
  valid_ticker_set = set(t.upper() for t in df_etf["Ticker"].unique())
72
 
73
  return (tok1, m1), (tok2, m2), valid_ticker_set
74
 
75
+ # ─── INITIALIZE ─────────────────────────────────────────────────────────────
76
 
77
  df_etf, df_analyst_report, available_tickers, df_annual_return_master = load_etf_data()
78
  s2_model, faiss_index, etf_list = build_search_resources()
79
  (tok1, m1), (tok2, m2), valid_ticker_set = load_ner_models()
80
 
81
+ # ─── CORE SEARCH & EXTRACTION ───────────────────────────────────────────────
82
 
83
  def semantic_search(query: str, top_k: int = 100):
84
  q_emb = s2_model.encode([query], convert_to_numpy=True)
 
89
  def ensemble_ticker_extraction(query: str):
90
  preds = set()
91
  for tok, mdl in ((tok1, m1), (tok2, m2)):
92
+ enc = tok(query, return_tensors="pt")
93
  with torch.no_grad():
94
  logits = mdl(**enc).logits
95
+ ids = logits.argmax(dim=-1)[0].tolist()
96
+ toks = tok.convert_ids_to_tokens(enc["input_ids"][0])
97
+ labs = [mdl.config.id2label[i] for i in ids]
98
  preds |= modules.extract_valid_tickers(toks, labs, tok, valid_ticker_set)
99
  return preds
100
 
101
+ # ─── UI HELPERS ──────────────────────────────────────────────────────────────
102
 
103
+ def display_sample_query_boxes(key_prefix=""):
104
  sample_queries = {
105
  "search_etf": {
106
+ "title": "AI ETF Search",
107
+ "description": "Explore ETFs by dividend, expense ratio, sector, etc.",
108
+ "query": ['Find me some high-dividend ETFs in the tech sector', 'Any ETFs holding Apple?']
109
  },
110
  "comparison": {
111
+ "title": "ETF Performance Comparison",
112
+ "description": "Compare two ETFs side by side on performance and risk.",
113
+ "query": "I'd like to compare performance of QQQ with SPY."
114
  },
115
  "portfolio_projection": {
116
+ "title": "Portfolio Projection",
117
+ "description": "Project a portfolio with multiple ETFs over a set number of years.",
118
+ "query": "I want to invest in SPY, QQQ, SCHD, and IAU for 20 years."
119
+ },
120
  }
121
+
122
+ cols = st.columns(len(sample_queries))
123
+ title_h = "60px"
124
+ desc_h = "100px"
125
+ query_h = "80px"
126
+
127
+ st.markdown("""
128
+ <style>
129
+ .small-link-button {
130
+ font-size: 12px;
131
+ color: #0073e6;
132
+ text-decoration: underline;
133
+ cursor: pointer;
134
+ }
135
+ </style>
136
+ """, unsafe_allow_html=True)
137
+
138
+ for idx, (key, details) in enumerate(sample_queries.items()):
139
+ with cols[idx]:
140
+ box_html = f"""
141
+ <div style="
142
+ width: 100%; height: 300px; border: 1px solid #ddd;
143
+ border-radius: 10px; padding: 15px; margin: 10px auto;
144
+ background-color: #fff; box-shadow: 2px 2px 8px rgba(0,0,0,0.1);
145
+ display: flex; flex-direction: column; justify-content: space-between;
146
+ ">
147
+ <div style="height: {title_h}; text-align: center;">
148
+ <p style="margin:5px 0; color:#2c3e50; font-size:16px;">
149
+ <b>{details['title']}</b>
150
+ </p>
151
+ </div>
152
+ <div style="height: {desc_h}; text-align: center; overflow:auto;">
153
+ <p style="margin:5px 0; color:#7f8c8d; font-size:14px;">
154
+ {details['description']}
155
+ </p>
156
+ </div>
157
+ <div style="height: {query_h}; text-align: center; overflow:auto;">
158
+ <p style="margin:5px 0; font-style:italic; color:#34495e; font-size:13px;">
159
+ {"<br>".join(f'&quot;{q}&quot;' for q in details['query'])
160
+ if isinstance(details['query'], list)
161
+ else f'&quot;{details["query"]}&quot;'}
162
+ </p>
163
+ </div>
164
+ </div>
165
+ """
166
+ st.markdown(box_html, unsafe_allow_html=True)
167
+
168
+ # center the button
169
+ l, c, r = st.columns([1,2,1])
170
+ with c:
171
+ if st.button("Try this app", key=key_prefix + key):
172
+ # just switch page
173
+ page_map = {
174
+ "search_etf": "ETF Search",
175
+ "comparison": "ETF Comparison",
176
+ "portfolio_projection": "ETF Portfolio"
177
+ }
178
+ st.session_state["page"] = page_map[key]
179
+ st.rerun()
180
 
181
  def display_chat_history(task: str):
182
  hist = st.session_state.get(f"all_chat_history_{task}", [])
 
190
  def process_query(task: str, query: str):
191
  top_k, top_n = 100, 30
192
 
193
+ if task == "search_etf":
194
  with st.spinner("Searching ETFs..."):
195
+ fetched = semantic_search(query, top_k)
196
  df_out = modules.get_etf_recommendations_from_list(
197
+ fetched, modules.get_cols_to_display(), df_etf, top_n
198
  )
199
+ msg = f"{len(df_out)} ETFs found."
200
  st.session_state[f"all_chat_history_{task}"].append(
201
+ modules.form_d_chat_history(str(uuid.uuid4()), msg, task, df=df_out)
 
 
 
 
 
202
  )
203
 
204
+ elif task == "comparison":
205
  with st.spinner("Running comparison..."):
206
+ tk = ensemble_ticker_extraction(query)
207
+ if len(tk) != 2:
208
+ resp, fig, df_out = "Please specify exactly two tickers.", None, None
 
 
209
  else:
210
  df_out = modules.get_etf_recommendations_from_list(
211
+ [(t, None) for t in tk],
212
+ modules.get_cols_to_display(), df_etf, top_n=2
213
  )
214
+ fig = modules.compare_etfs_interactive(*tk)
215
+ resp = f"Compared {tk[0]} vs. {tk[1]}."
216
  st.session_state[f"all_chat_history_{task}"].append(
217
+ modules.form_d_chat_history(str(uuid.uuid4()), resp, task, fig=fig, df=df_out)
 
 
 
 
 
 
218
  )
219
 
220
+ elif task == "portfolio_projection":
221
  with st.spinner("Projecting portfolio..."):
222
+ fetched = semantic_search(query, top_k)
223
+ df_port = modules.run_portfolio_analysis(fetched, df_etf, df_annual_return_master)
 
 
224
  fig = modules.portfolio_interactive_chart(df_port)
225
+ resp = "30-year projection generated."
226
  st.session_state[f"all_chat_history_{task}"].append(
227
+ modules.form_d_chat_history(str(uuid.uuid4()), resp, task, fig=fig)
 
 
 
 
 
228
  )
229
 
230
+ def main():
231
+ st.set_page_config(layout="wide")
232
+
233
+ # init state
234
+ if "page" not in st.session_state:
235
+ st.session_state["page"] = "Home"
236
+ if "user_query" not in st.session_state:
237
+ st.session_state["user_query"] = ""
238
+ if "auto_query_sent" not in st.session_state:
239
+ st.session_state["auto_query_sent"] = False
240
+ for t in ["search_etf","comparison","portfolio_projection"]:
241
+ st.session_state.setdefault(f"all_chat_history_{t}", [])
242
+
243
+ # sidebar
244
+ st.sidebar.title("ETF Assistant")
245
+ if st.sidebar.button("Home"):
246
+ st.session_state["page"] = "Home"
247
+ if st.sidebar.button("ETF Search"):
248
+ st.session_state["page"] = "ETF Search"
249
+ if st.sidebar.button("ETF Comparison"):
250
+ st.session_state["page"] = "ETF Comparison"
251
+ if st.sidebar.button("ETF Portfolio"):
252
+ st.session_state["page"] = "ETF Portfolio"
253
+
254
+ # render
255
+ page = st.session_state["page"]
256
+ if page == 'Home':
257
+ st.title("ETF Assistant")
258
+ else:
259
+ st.title(page)
260
+
261
+ # Home splash
262
+ if page == "Home":
263
+ display_sample_query_boxes(key_prefix="home_")
264
+
265
+ # Sub-apps
266
+ else:
267
+ # auto-run if launched from Home
268
+ if st.session_state["user_query"] and not st.session_state["auto_query_sent"]:
269
+ process_query(
270
+ {"ETF Search":"search_etf",
271
+ "ETF Comparison":"comparison",
272
+ "ETF Portfolio":"portfolio_projection"}[page],
273
+ st.session_state["user_query"]
274
+ )
275
+ st.session_state["auto_query_sent"] = True
276
+
277
+ task_map = {
278
+ "ETF Search": "search_etf",
279
+ "ETF Comparison": "comparison",
280
+ "ETF Portfolio": "portfolio_projection",
281
  }
282
+ task = task_map[page]
283
+
284
+ display_sample_query_boxes(key_prefix="sub_")
285
+ display_chat_history(task)
286
+
287
+ # chat input
288
+ prompt = {
289
+ "ETF Search": "Search for ETFs…",
290
+ "ETF Comparison": "Compare ETFs…",
291
+ "ETF Portfolio": "Project portfolio…",
292
+ }[page]
293
+
294
+ q = st.chat_input(prompt, key="inp_" + task)
295
+ if q:
296
+ process_query(task, q)
297
+
298
+ if __name__ == "__main__":
299
+ main()