File size: 15,936 Bytes
f0e4a41
3718f37
f0e4a41
 
 
 
02470e1
 
 
a80ed0d
30de4d5
f0e4a41
 
 
 
 
 
 
a80ed0d
f0e4a41
 
 
 
a80ed0d
 
 
f0e4a41
a80ed0d
2d69da8
 
 
 
f0e4a41
 
30de4d5
f0e4a41
28a4d62
2d69da8
 
 
30de4d5
 
2d69da8
 
 
f0e4a41
2d69da8
f0e4a41
 
 
2d69da8
02470e1
30de4d5
02470e1
30de4d5
 
2d69da8
 
 
 
 
 
 
 
f0e4a41
2d69da8
f0e4a41
30de4d5
34e567f
30de4d5
f0e4a41
 
30de4d5
02470e1
2d69da8
f0e4a41
a80ed0d
 
2d69da8
 
 
a80ed0d
 
 
15444a5
 
 
02470e1
15444a5
2d69da8
15444a5
02470e1
 
15444a5
 
 
 
 
02470e1
 
a80ed0d
 
 
 
 
2d69da8
02470e1
34e567f
02470e1
 
d3dd799
a80ed0d
2d69da8
a80ed0d
 
 
2d69da8
02470e1
 
34e567f
a80ed0d
2d69da8
 
 
 
 
02470e1
 
34e567f
a80ed0d
2d69da8
 
 
 
34e567f
02470e1
34e567f
 
7eb1513
34e567f
 
 
2d69da8
34e567f
7eb1513
2d69da8
 
 
34e567f
2d69da8
 
 
 
34e567f
2d69da8
 
34e567f
2d69da8
 
34e567f
 
2d69da8
 
 
 
d3dd799
2d69da8
 
 
 
 
 
8900633
2d69da8
02470e1
 
2d69da8
a80ed0d
e5c69a4
28a4d62
02470e1
28a4d62
 
a80ed0d
 
02470e1
 
a80ed0d
 
2d69da8
a80ed0d
15444a5
a80ed0d
 
 
 
 
 
 
 
34e567f
a80ed0d
 
28a4d62
0411aa9
28a4d62
a80ed0d
 
 
 
 
 
 
 
 
 
 
02470e1
a80ed0d
02470e1
 
2d69da8
a80ed0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d69da8
a80ed0d
02470e1
a80ed0d
02470e1
0411aa9
02470e1
a80ed0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02470e1
a80ed0d
02470e1
 
2d69da8
a80ed0d
 
 
 
02470e1
a80ed0d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
02470e1
 
2d69da8
 
34e567f
 
2d69da8
34e567f
2d69da8
34e567f
 
 
 
 
2d69da8
 
 
 
 
 
 
 
34e567f
a80ed0d
34e567f
2d69da8
34e567f
a80ed0d
2d69da8
a80ed0d
 
 
 
 
 
 
 
 
 
 
 
34e567f
 
a80ed0d
2d69da8
 
 
 
34e567f
a80ed0d
 
 
 
 
 
 
 
 
 
2d69da8
a80ed0d
 
2d69da8
 
 
 
 
a80ed0d
 
34e567f
 
 
2d69da8
34e567f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
import os
import streamlit as st
import uuid
import pandas as pd
import modules
import torch
from sentence_transformers import SentenceTransformer
import faiss
from transformers import AutoTokenizer, AutoModelForTokenClassification
import re

# ─── CACHES ─────────────────────────────────────────────────────────────────

@st.cache_data(show_spinner=False)
def load_etf_data():
    """Load every CSV-backed table the app needs.

    The general-info table is read from the doc-enriched CSV when that file
    exists. Otherwise the raw CSV is read, a ``doc`` column is computed row by
    row with ``modules.make_doc_text``, and the result is written to the
    enriched path so later runs can skip that step.

    Returns:
        A 4-tuple ``(df_etf, df_analyst_report, available_tickers,
        df_annual_return_master)``. ``df_etf`` and ``available_tickers`` are
        whatever ``modules.set_etf_data`` returns for the info table
        left-joined with the holdings summary on ``Ticker``.
    """
    ticker_rename = {"ticker": "Ticker"}
    doc_cache_csv = "etf_general_info_enriched_doc_added.csv"

    # Pick the source once, then apply the enrichment only on a cache miss.
    cache_missing = not os.path.exists(doc_cache_csv)
    source_csv = "etf_general_info_enriched.csv" if cache_missing else doc_cache_csv
    info = pd.read_csv(source_csv).rename(columns=ticker_rename)
    if cache_missing:
        info["doc"] = info.apply(modules.make_doc_text, axis=1)
        info.to_csv(doc_cache_csv, index=False)

    holdings = pd.read_csv("etf_holdings_summarized.csv").rename(
        columns={"ticker": "Ticker", "holdingInformation": "Holdings"}
    )
    etf_table, tickers = modules.set_etf_data(
        info.merge(holdings, how="left", on="Ticker")
    )

    analyst_reports = pd.read_csv("etf_analyst_report_full.csv").rename(
        columns=ticker_rename
    )
    annual_returns = pd.read_csv("annual_return.csv").rename(columns=ticker_rename)

    return etf_table, analyst_reports, tickers, annual_returns

@st.cache_resource(show_spinner=False)
def build_search_resources():
    """Create the sentence encoder, the FAISS index and the row-id -> ticker list.

    The index is an inner-product index over L2-normalised embeddings of the
    ``doc`` column. It is persisted to ``etf_faiss.index`` and reused on later
    runs, but only when its vector count matches the current ticker list;
    otherwise it is rebuilt and the file is overwritten.

    Returns:
        ``(model, index, ticker_list)`` where position ``i`` of
        ``ticker_list`` is the ticker for row id ``i`` of ``index``.
    """
    etf_table, *_ = load_etf_data()
    model = SentenceTransformer(
        "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
    )
    ticker_list = etf_table["Ticker"].tolist()

    idx_path = "etf_faiss.index"
    index = faiss.read_index(idx_path) if os.path.exists(idx_path) else None

    # Search results are translated with ticker_list[row_id], so an index file
    # left over from a different ETF table would yield wrong tickers (or an
    # IndexError). Treat a size mismatch like a missing file and rebuild.
    if index is None or index.ntotal != len(ticker_list):
        embs = model.encode(etf_table["doc"].tolist(), convert_to_numpy=True)
        faiss.normalize_L2(embs)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        faiss.write_index(index, idx_path)

    return model, index, ticker_list

@st.cache_resource(show_spinner=False)
def load_ner_models():
    """Load both token-classification checkpoints and the set of known tickers.

    Returns:
        ``((tok1, m1), (tok2, m2), valid_ticker_set)``: the DistilBERT
        tokenizer/model pair, the ALBERT tokenizer/model pair, and the
        upper-cased ``Ticker`` column of the ETF table as a set.
    """

    def _load_pair(repo_id):
        # Tokenizer first, then model, for the same checkpoint id.
        tokenizer = AutoTokenizer.from_pretrained(repo_id)
        classifier = AutoModelForTokenClassification.from_pretrained(repo_id)
        return tokenizer, classifier

    distilbert_pair = _load_pair("hskwon7/distilbert-base-uncased-for-etf-ticker")
    albert_pair = _load_pair("hskwon7/albert-base-v2-for-etf-ticker")

    etf_table, *_ = load_etf_data()
    known_tickers = set(etf_table["Ticker"].str.upper())

    return distilbert_pair, albert_pair, known_tickers

# ─── INITIALIZE ─────────────────────────────────────────────────────────────

# Module-level singletons used by the routines below. Each loader is wrapped in
# a Streamlit cache decorator, so these calls are cheap on script reruns.
(
    df_etf,
    df_analyst_report,
    available_tickers,
    df_annual_return_master,
) = load_etf_data()

# Sentence encoder, FAISS index, and the ticker for each index row.
s2_model, faiss_index, etf_list = build_search_resources()

# Two (tokenizer, model) pairs for ticker NER plus the set of valid tickers.
_ner_pair_1, _ner_pair_2, valid_ticker_set = load_ner_models()
tok1, m1 = _ner_pair_1
tok2, m2 = _ner_pair_2

# ─── CORE ROUTINES ──────────────────────────────────────────────────────────

# Semantic Search
def semantic_search(q: str, top_k: int = 500):
    """Return the tickers of the ETFs most similar to the free-text query.

    The query is embedded with ``s2_model``, L2-normalised, and searched
    against ``faiss_index``; the resulting row ids are translated through
    ``etf_list``.

    Args:
        q: Natural-language search text.
        top_k: Maximum number of tickers to return.

    Returns:
        A list of at most ``top_k`` tickers in the order FAISS returned them.
    """
    query_vec = s2_model.encode([q], convert_to_numpy=True)
    faiss.normalize_L2(query_vec)
    _scores, row_ids = faiss_index.search(query_vec, top_k)

    # FAISS pads the id matrix with -1 when the index holds fewer than top_k
    # vectors. Indexing etf_list with -1 would silently return the last ticker,
    # so drop the padding instead of translating it.
    return [etf_list[i] for i in row_ids[0] if i >= 0]

# Ensemble function: union of both models' predictions
def ensemble_ticker_extraction(query):
    """Extract tickers from *query* with both NER models and union the results.

    For each (tokenizer, model) pair the query is tokenised, the arg-max label
    of every token is looked up in the model's ``id2label`` map, and the
    tokens and labels are handed to ``modules.extract_valid_tickers`` together
    with ``valid_ticker_set``.

    Args:
        query: Raw user text.

    Returns:
        A set containing every ticker reported by either model.
    """
    found = set()

    for tokenizer, model in ((tok1, m1), (tok2, m2)):
        # Chat input has no length limit; truncate to the tokenizer's maximum
        # so an over-long message cannot overflow the encoder and crash.
        encoded = tokenizer(query, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = model(**encoded).logits

        label_ids = logits.argmax(dim=-1)[0].tolist()
        tokens = tokenizer.convert_ids_to_tokens(encoded["input_ids"][0])
        labels = [model.config.id2label[i] for i in label_ids]

        found.update(
            modules.extract_valid_tickers(tokens, labels, tokenizer, valid_ticker_set)
        )

    return found

# Rule-based fallback: catch literal 2–4 char tickers in the text
def rule_fallback(query, valid_set):
    """Return the upper-cased 2-4 character alphanumeric words of *query* found in *valid_set*.

    Args:
        query: Raw user text.
        valid_set: Container of upper-case tickers; only ``in`` is used on it.

    Returns:
        A set of upper-case tickers.
    """
    candidates = (
        match.group(0).upper()
        for match in re.finditer(r"\b[A-Za-z0-9]{2,4}\b", query)
    )
    return {ticker for ticker in candidates if ticker in valid_set}

# ─── UI HELPERS ─────────────────────────────────────────────────────────────

def display_sample_query_boxes(key_prefix=""):
    """Render one card per sub-app with sample queries and a navigation button.

    Each card shows the sub-app title, a short description and example
    queries. Pressing the button under a card stores the matching page name in
    ``st.session_state["page"]`` and reruns the script.

    Args:
        key_prefix: Prepended to each button's widget key so the same cards
            can be rendered in more than one place without key collisions.
    """
    sample_queries = {
        "search_etf": {
            "title": "ETF Search",
            "description": "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
            "query": [
                'High-dividend ETFs in the tech sector.',
                'Precious metals ETFs with low expense ratio.',
                'Large growth ETFs with high returns.'
            ]
        },
        "comparison": {
            "title": "ETF Performance Comparison",
            "description": "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
            "query": [
                "I'd like to compare performance of QQQ with GLD.",
                "Compare SPY and VOO.",
                "SCHD vs. VTI"
            ]
        },
        "portfolio_projection": {
            "title": "Portfolio Projection",
            "description": "Project a portfolio with your choice of ETFs over 30 years.",
            "query": [
                "I want to invest in SPY, QQQ, SCHD, and IAU.",
                "Portfolio projection for VTI, XLF, and XLY."
            ]
        },
    }

    # Card key -> page name stored in the session; constant, so built once
    # rather than on every button press.
    page_map = {
        "search_etf": "ETF Search",
        "comparison": "ETF Comparison",
        "portfolio_projection": "ETF Portfolio"
    }

    cols = st.columns(len(sample_queries))
    title_h, desc_h, query_h = "30px", "60px", "70px"

    for idx, (key, details) in enumerate(sample_queries.items()):
        title = details["title"]
        description = details["description"]
        # Wrap each sample in typographic quotes. The opening quote used to be
        # a mis-encoded byte sequence that rendered as garbage.
        quoted_queries = "<br>".join(f"“{q}”" for q in details["query"])

        with cols[idx]:
            # NOTE(review): `height:350` has no unit, which CSS treats as
            # invalid outside quirks mode — confirm whether 350px was intended.
            st.markdown(f"""
            <div style="
              width:100%; height:350; border:1px solid #ddd;
              border-radius:10px; padding:15px; margin:auto;
              display:flex; flex-direction:column; justify-content:space-between;
              box-shadow:2px 2px 8px rgba(0,0,0,0.1);
            ">
              <div style="height:{title_h}; text-align:center;">
                <b style="font-size:16px; color:#2c3e50;">
                  {title}
                </b>
              </div>
              <div style="height:{desc_h}; text-align:center; color:#7f8c8d; font-size:14px; overflow:auto;">
                {description}
              </div>
              <div style="height:{query_h}; text-align:center; color:#34495e; font-size:13px; font-style:italic; overflow:auto;">
                {quoted_queries}
              </div>
            </div>
            """, unsafe_allow_html=True)

            # center the button directly under the box
            st.markdown("<div style='text-align:center; margin-top:10px;'>", unsafe_allow_html=True)
            if st.button("Go to this app", key=key_prefix+key):
                st.session_state["page"] = page_map[key]
                st.rerun()
            st.markdown("</div>", unsafe_allow_html=True)

def display_chat_history(task: str):
    """Replay the stored conversation for *task*.

    For every history entry, whichever of these parts are present are rendered
    in this order: the user query, the Plotly figure, the ETF table, and the
    assistant response.

    Args:
        task: Task identifier used in the ``all_chat_history_<task>`` session key.
    """
    history = st.session_state.get(f"all_chat_history_{task}", [])

    for item in history:
        user_text = item.get("query")
        if user_text:
            st.chat_message("user").write(user_text)

        figure = item.get("fig")
        if figure:
            st.plotly_chart(figure, use_container_width=True)

        table = item.get("df")
        if table is not None:
            modules.display_matching_etfs(table)

        reply = item.get("response")
        if reply:
            st.chat_message("assistant").write(reply)

def _record_chat(task: str, response, **fields):
    """Append one entry built by ``modules.form_d_chat_history`` to the history of *task*."""
    st.session_state[f"all_chat_history_{task}"].append(
        modules.form_d_chat_history(str(uuid.uuid4()), response, task, **fields)
    )


def _extract_tickers(query: str):
    """Return the sorted union of NER-ensemble and rule-based ticker matches."""
    return sorted(
        ensemble_ticker_extraction(query) | rule_fallback(query, valid_ticker_set)
    )


def process_query(task: str, query: str):
    """Handle one user query for the given task and render the answer.

    The query is echoed and stored in the task's chat history, the
    task-specific analysis is run, its output is rendered, and the response is
    stored in the history as a second entry.

    Args:
        task: ``"search_etf"``, ``"comparison"`` or ``"portfolio_projection"``.
            Any other value is ignored.
        query: Raw text typed by the user.
    """
    if task not in ("search_etf", "comparison", "portfolio_projection"):
        return

    # Number of ETFs to fetch from the index and to display, respectively.
    top_k, top_n = 50, 20

    # Common prelude: echo the query and store it in the chat history.
    st.chat_message("user").write(query)
    _record_chat(task, None, df=None, query=query)

    if task == "search_etf":
        with st.spinner("Hang on tight! Searching ETFs..."):
            fetched = semantic_search(query, top_k)
            # Get ETF data from the list of tickers
            df_out = modules.get_etf_recommendations_from_list(
                fetched, df_etf, top_n
            )

        # Generate response
        relevant_tickers = df_out["Ticker"].tolist()
        response = modules.format_etf_search_results_inline(relevant_tickers)

        # Display results
        st.markdown("### ETF Search Results")
        modules.display_matching_etfs(df_out)
        st.chat_message("assistant").write(response)

        _record_chat(task, response, df=df_out)

    elif task == "comparison":
        with st.spinner("Hang on tight! Running comparison analysis..."):
            tk = _extract_tickers(query)

            # A comparison needs exactly two tickers.
            if len(tk) != 2:
                response, fig, df_out = "Please specify exactly two tickers.", None, None
            else:
                # Get ETF data from the list of tickers
                df_out = modules.get_etf_recommendations_from_list(
                    tk, df_etf, top_n=2
                )
                # Get performance comparison plot
                fig = modules.compare_etfs_interactive(tk[0], tk[1])

                # Generate response
                d_analyst_reports = modules.lookup_etf_report(
                    tk, df_analyst_report=df_analyst_report
                )
                response = modules.format_insights_report(d_analyst_reports)

                # Display comparison chart and table
                st.markdown("### Performance Comparison")
                st.plotly_chart(fig, use_container_width=True)
                modules.display_matching_etfs(df_out)

        st.chat_message("assistant").write(response)
        _record_chat(task, response, fig=fig, df=df_out)

    else:  # portfolio_projection
        with st.spinner("Hang on tight! Projecting portfolio ..."):
            tk = _extract_tickers(query)

            if not tk:
                # Mirror the comparison branch: ask for input instead of
                # running the analysis on an empty ticker list.
                response, fig = "Please specify at least one ETF ticker.", None
            else:
                # Run portfolio analysis
                df_port_output, d_summary = modules.run_portfolio_analysis(
                    tk, df_etf, df_annual_return_master
                )

                # Form a report
                response = modules.format_portfolio_summary(d_summary=d_summary)

                # Display projection
                fig = modules.portfolio_interactive_chart(df_port_output)
                st.markdown("### 30 Years Investment Return Projection")
                st.plotly_chart(fig, use_container_width=True)

            st.chat_message("assistant").write(response)

        _record_chat(task, response, fig=fig)

# ─── MAIN ────────────────────────────────────────────────────────────────

def main():
    """Streamlit entry point: set up session state, the sidebar and the active page.

    The active page name is kept in ``st.session_state["page"]``. "Home" shows
    an introduction and the sample-query cards. Every other page shows its
    description, the stored chat history for its task and a chat input whose
    submissions are passed to ``process_query``.
    """
    st.set_page_config(layout="wide")

    # Initialise session state on first run.
    st.session_state.setdefault("page", "Home")
    for t in ["search_etf", "comparison", "portfolio_projection"]:
        st.session_state.setdefault(f"all_chat_history_{t}", [])

    # Sidebar navigation. The emoji labels were previously mis-encoded and
    # rendered as garbage characters.
    st.sidebar.title("ETF Assistant")
    sidebar_links = (
        ("🏠 Home", "Home"),
        ("🔎 ETF Search", "ETF Search"),
        ("⚖️ ETF Comparison", "ETF Comparison"),
        ("💼 ETF Portfolio", "ETF Portfolio"),
    )
    for label, target in sidebar_links:
        if st.sidebar.button(label):
            st.session_state["page"] = target

    # Main page
    page = st.session_state["page"]
    st.title(page if page != "Home" else "ETF Assistant")

    if page == "Home":
        st.header("How can I assist you today?")

        # Introduction text 1 (em dashes restored from mis-encoded bytes)
        etf_intro_text = (
            "An exchange-traded fund (ETF) is an investment vehicle that holds a diversified basket of assets—such as stocks, bonds,"
            " or commodities—and trades on an exchange like a single stock. ETFs combine the diversification and low costs of mutual funds "
            "with the flexibility and intraday liquidity of individual equities."
        )
        st.write(etf_intro_text)

        # Introduction text 2
        app_intro_text = "Find ETFs that align with your investment goals and sector interests, compare performance, and estimate your portfolio—all in one place!"
        st.write(app_intro_text)
        display_sample_query_boxes(key_prefix="home_")
    else:
        # page name -> (task id, description, chat-input placeholder).
        # An unknown page name raises KeyError, as before.
        pages = {
            "ETF Search": (
                "search_etf",
                "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
                "Search for ETFs…",
            ),
            "ETF Comparison": (
                "comparison",
                "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
                "Compare ETFs…",
            ),
            "ETF Portfolio": (
                "portfolio_projection",
                "Project a portfolio with your choice of ETFs over 30 years.",
                "Project portfolio…",
            ),
        }
        task, app_description_text, placeholder = pages[page]

        # Display introduction text
        st.write(app_description_text)

        # Display all previous chat history
        display_chat_history(task)

        # Display input box and process a submitted query
        q = st.chat_input(placeholder, key=task)
        if q:
            process_query(task, q)


if __name__ == "__main__":
    main()