"""Streamlit front end for an ETF assistant.

Three chat-style pages are offered: semantic ETF search, a two-ETF
comparison, and a portfolio projection. Data frames, the sentence-embedding
model, the FAISS index and the two token-classification models are loaded
once through Streamlit caches.
"""

# NOTE: import order is kept exactly as in the original file; several of these
# packages load native libraries, so reordering is left for a separately
# tested change.
import os
import streamlit as st
import uuid
import pandas as pd
import modules
import torch
from sentence_transformers import SentenceTransformer
import faiss
from transformers import AutoTokenizer, AutoModelForTokenClassification
import re


# ─── CACHES ─────────────────────────────────────────────────────────────────
@st.cache_data(show_spinner=False)
def load_etf_data():
    """Load every CSV-backed data frame the app needs.

    If the enriched CSV (with a precomputed ``doc`` column) exists it is read
    directly; otherwise the raw CSV is read, ``doc`` is built row by row with
    ``modules.make_doc_text`` and the result is written to the enriched path
    so later runs can skip that step.

    Returns:
        A 4-tuple ``(df_etf, df_analyst_report, available_tickers,
        df_annual_return_master)``. ``df_etf`` and ``available_tickers`` are
        whatever ``modules.set_etf_data`` returns.
    """
    enriched_path = "etf_general_info_enriched_doc_added.csv"
    raw_path = "etf_general_info_enriched.csv"

    if os.path.exists(enriched_path):
        df_info = pd.read_csv(enriched_path).rename(columns={"ticker": "Ticker"})
    else:
        df_info = pd.read_csv(raw_path).rename(columns={"ticker": "Ticker"})
        df_info["doc"] = df_info.apply(modules.make_doc_text, axis=1)
        df_info.to_csv(enriched_path, index=False)

    # Holdings are merged after the enriched file is written, so the cached
    # CSV only ever contains the general info plus the "doc" column.
    df_etf_holdings = pd.read_csv('etf_holdings_summarized.csv').rename(
        columns={'ticker': 'Ticker', 'holdingInformation': 'Holdings'}
    )
    df_info = df_info.merge(df_etf_holdings, how='left', on='Ticker')
    df_etf, available_tickers = modules.set_etf_data(df_info)

    df_analyst_report = pd.read_csv("etf_analyst_report_full.csv").rename(
        columns={"ticker": "Ticker"}
    )
    df_annual_return_master = (
        pd.read_csv("annual_return.csv")
        .rename(columns={"ticker": "Ticker"})
    )
    return df_etf, df_analyst_report, available_tickers, df_annual_return_master


@st.cache_resource(show_spinner=False)
def build_search_resources():
    """Return the embedding model, the FAISS index and the ticker list.

    Index position ``i`` corresponds to ``ticker_list[i]``. A saved index is
    reused only when it holds exactly one vector per ticker; otherwise it is
    rebuilt from the ``doc`` column and written back to disk, because a
    size mismatch would make search results point at the wrong tickers.

    Returns:
        ``(model, index, ticker_list)``.
    """
    df_etf, *_ = load_etf_data()
    model = SentenceTransformer(
        "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
    )
    ticker_list = df_etf["Ticker"].tolist()
    idx_path = "etf_faiss.index"

    index = None
    if os.path.exists(idx_path):
        index = faiss.read_index(idx_path)
        if index.ntotal != len(ticker_list):
            # Stale file built from a different data set: discard and rebuild.
            index = None

    if index is None:
        embs = model.encode(df_etf["doc"].tolist(), convert_to_numpy=True)
        # Normalised vectors + inner product index == cosine similarity.
        faiss.normalize_L2(embs)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        faiss.write_index(index, idx_path)

    return model, index, ticker_list


def _load_token_classifier(model_name):
    """Return ``(tokenizer, model)`` for one token-classification checkpoint."""
    return (
        AutoTokenizer.from_pretrained(model_name),
        AutoModelForTokenClassification.from_pretrained(model_name),
    )


@st.cache_resource(show_spinner=False)
def load_ner_models():
    """Load both ticker-tagging models and the set of known tickers.

    Returns:
        ``((tok1, m1), (tok2, m2), valid_ticker_set)`` where the set holds
        the upper-cased values of ``df_etf["Ticker"]``.
    """
    distilbert_pair = _load_token_classifier(
        "hskwon7/distilbert-base-uncased-for-etf-ticker"
    )
    albert_pair = _load_token_classifier("hskwon7/albert-base-v2-for-etf-ticker")
    df_etf, *_ = load_etf_data()
    valid_ticker_set = set(df_etf["Ticker"].str.upper())
    return distilbert_pair, albert_pair, valid_ticker_set


# ─── INITIALIZE ─────────────────────────────────────────────────────────────
df_etf, df_analyst_report, available_tickers, df_annual_return_master = load_etf_data()
s2_model, faiss_index, etf_list = build_search_resources()
(tok1, m1), (tok2, m2), valid_ticker_set = load_ner_models()

# Literal 2–4 character alphanumeric words; candidates for ticker symbols.
_TICKER_WORD_RE = re.compile(r"\b[A-Za-z0-9]{2,4}\b")

# Maps a task key to the page name stored in st.session_state["page"].
_TASK_TO_PAGE = {
    "search_etf": "ETF Search",
    "comparison": "ETF Comparison",
    "portfolio_projection": "ETF Portfolio",
}


# ─── CORE ROUTINES ──────────────────────────────────────────────────────────
# Semantic Search
def semantic_search(q: str, top_k: int = 500):
    """Return up to ``top_k`` tickers ranked by similarity to the query ``q``.

    FAISS pads its result with ``-1`` when the index holds fewer than
    ``top_k`` vectors. Those slots are dropped; indexing ``etf_list`` with
    ``-1`` would otherwise return the last ticker as a bogus match.
    """
    emb = s2_model.encode([q], convert_to_numpy=True)
    faiss.normalize_L2(emb)
    _scores, neighbor_ids = faiss_index.search(emb, top_k)
    # Only the tickers are returned; scores are not used by any caller here.
    return [etf_list[i] for i in neighbor_ids[0] if i >= 0]


# Ensemble function: union of both models' predictions
def ensemble_ticker_extraction(query):
    """Return the union of tickers tagged in ``query`` by both NER models.

    Each model labels every token; ``modules.extract_valid_tickers`` turns
    the token/label sequence into tickers checked against
    ``valid_ticker_set``.
    """
    preds = set()
    for tok, mdl in ((tok1, m1), (tok2, m2)):
        enc = tok(query, return_tensors="pt")
        with torch.no_grad():  # inference only, no gradients needed
            logits = mdl(**enc).logits
        pred_ids = logits.argmax(dim=-1)[0].tolist()
        tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
        labels = [mdl.config.id2label[i] for i in pred_ids]
        preds.update(
            modules.extract_valid_tickers(tokens, labels, tok, valid_ticker_set)
        )
    return preds


# Rule-based fallback: catch literal 2–4 char tickers in the text
def rule_fallback(query, valid_set):
    """Return upper-cased 2–4 character words of ``query`` found in ``valid_set``."""
    candidates = (word.upper() for word in _TICKER_WORD_RE.findall(query))
    return {word for word in candidates if word in valid_set}


def _extract_tickers(query):
    """Return the sorted union of model-based and rule-based ticker matches."""
    return sorted(
        ensemble_ticker_extraction(query) | rule_fallback(query, valid_ticker_set)
    )


# ─── UI HELPERS ─────────────────────────────────────────────────────────────
def display_sample_query_boxes(key_prefix=""):
    """Render one card per feature with example queries and a navigation button.

    Args:
        key_prefix: Prepended to each button's widget key so the boxes can be
            rendered in more than one place without key collisions.
    """
    sample_queries = {
        "search_etf": {
            "title": "ETF Search",
            "description": "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
            "query": [
                'High-dividend ETFs in the tech sector.',
                'Precious metals ETFs with low expense ratio.',
                'Large growth ETFs with high returns.'
            ]
        },
        "comparison": {
            "title": "ETF Performance Comparison",
            "description": "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
            "query": [
                "I'd like to compare performance of QQQ with GLD.",
                "Compare SPY and VOO.",
                "SCHD vs. VTI"
            ]
        },
        "portfolio_projection": {
            "title": "Portfolio Projection",
            "description": "Project a portfolio with your choice of ETFs over 30 years.",
            "query": [
                "I want to invest in SPY, QQQ, SCHD, and IAU.",
                "Portfolio projection for VTI, XLF, and XLY."
            ]
        },
    }

    cols = st.columns(len(sample_queries))
    # Fixed heights keep the three cards aligned regardless of text length.
    title_h, desc_h, query_h = "30px", "60px", "70px"

    for idx, (key, details) in enumerate(sample_queries.items()):
        with cols[idx]:
            # NOTE(review): the HTML tags of this card were missing from the
            # file as received (the string literals were left unterminated).
            # The markup below is a reconstruction that uses the three height
            # variables; confirm the styling against the deployed app.
            quoted_examples = "<br>".join(f"“{q}”" for q in details["query"])
            st.markdown(
                f"""
<div style="border:1px solid #ddd; border-radius:8px; padding:12px; margin-bottom:8px;">
  <div style="height:{title_h}; font-weight:bold;">{details['title']}</div>
  <div style="height:{desc_h};">{details['description']}</div>
  <div style="height:{query_h}; font-style:italic;">{quoted_examples}</div>
</div>
""",
                unsafe_allow_html=True,
            )

            # center the button directly under the box
            st.markdown("<div style='text-align:center;'>", unsafe_allow_html=True)
            if st.button("Go to this app", key=key_prefix + key):
                st.session_state["page"] = _TASK_TO_PAGE[key]
                st.rerun()
            st.markdown("</div>", unsafe_allow_html=True)


def display_chat_history(task: str):
    """Replay every stored chat entry for ``task`` in its original order.

    For each entry the query, figure, data frame and response are rendered
    if present.
    """
    for entry in st.session_state.get(f"all_chat_history_{task}", []):
        if entry.get("query"):
            st.chat_message("user").write(entry["query"])
        # Explicit None checks: figures and data frames should not be
        # evaluated for truthiness.
        if entry.get("fig") is not None:
            st.plotly_chart(entry["fig"], use_container_width=True)
        if entry.get("df") is not None:
            modules.display_matching_etfs(entry["df"])
        if entry.get("response"):
            st.chat_message("assistant").write(entry["response"])


def _append_history(task, entry):
    """Append one chat-history entry to the session list for ``task``."""
    st.session_state[f"all_chat_history_{task}"].append(entry)


def _record_user_query(task, query):
    """Echo the user's query in the chat and store it in the history."""
    st.chat_message("user").write(query)
    _append_history(
        task,
        modules.form_d_chat_history(str(uuid.uuid4()), None, task, df=None, query=query),
    )


def _run_search(task, query):
    """Handle an ETF search query: semantic search, table, summary text."""
    # Number of ETFs to fetch from the index and to display
    top_k, top_n = 50, 20

    with st.spinner("Hang on tight! Searching ETFs..."):
        fetched = semantic_search(query, top_k)
        # Get ETF data from the list of tickers
        df_out = modules.get_etf_recommendations_from_list(fetched, df_etf, top_n)
        # Generate response
        relevant_tickers = df_out['Ticker'].tolist()
        response = modules.format_etf_search_results_inline(relevant_tickers)

    st.markdown("### ETF Search Results")
    modules.display_matching_etfs(df_out)
    st.chat_message("assistant").write(response)

    _append_history(
        task,
        modules.form_d_chat_history(str(uuid.uuid4()), response, task, df=df_out),
    )


def _run_comparison(task, query):
    """Handle a comparison query; exactly two tickers must be recognised."""
    with st.spinner("Hang on tight! Running comparison analysis..."):
        # Extract tickers from query
        tk = _extract_tickers(query)

        if len(tk) != 2:
            response, fig, df_out = "Please specify exactly two tickers.", None, None
        else:
            # Get ETF data from the list of tickers
            df_out = modules.get_etf_recommendations_from_list(tk, df_etf, top_n=2)
            # Get performance comparison plot
            fig = modules.compare_etfs_interactive(tk[0], tk[1])
            # Generate response
            d_analyst_reports = modules.lookup_etf_report(
                tk, df_analyst_report=df_analyst_report
            )
            response = modules.format_insights_report(d_analyst_reports)

    # Chart and table are rendered only when the comparison actually ran;
    # otherwise only the guidance message is shown.
    if fig is not None:
        st.markdown("### Performance Comparison")
        st.plotly_chart(fig, use_container_width=True)
        modules.display_matching_etfs(df_out)

    st.chat_message("assistant").write(response)

    _append_history(
        task,
        modules.form_d_chat_history(
            str(uuid.uuid4()), response, task, fig=fig, df=df_out
        ),
    )


def _run_portfolio_projection(task, query):
    """Handle a portfolio query: project returns for the recognised tickers."""
    with st.spinner("Hang on tight! Projecting portfolio ..."):
        # Extract tickers from query
        tk = _extract_tickers(query)

        if not tk:
            # Nothing to project; ask for tickers instead of handing an empty
            # list to the analysis (its behaviour on empty input is not
            # visible from this file).
            response, fig = "Please specify at least one ETF ticker.", None
        else:
            # Run portfolio analysis
            df_port_output, d_summary = modules.run_portfolio_analysis(
                tk, df_etf, df_annual_return_master
            )
            # Form a report
            response = modules.format_portfolio_summary(d_summary=d_summary)
            fig = modules.portfolio_interactive_chart(df_port_output)

    if fig is not None:
        st.markdown("### 30 Years Investment Return Projection")
        st.plotly_chart(fig, use_container_width=True)

    st.chat_message("assistant").write(response)

    _append_history(
        task,
        modules.form_d_chat_history(str(uuid.uuid4()), response, task, fig=fig),
    )


# One handler per task key; each takes (task, query).
_TASK_HANDLERS = {
    "search_etf": _run_search,
    "comparison": _run_comparison,
    "portfolio_projection": _run_portfolio_projection,
}


def process_query(task: str, query: str):
    """Run one user query for ``task`` and record the exchange in the history.

    Args:
        task: ``"search_etf"``, ``"comparison"`` or ``"portfolio_projection"``.
            Any other value is ignored.
        query: The text the user typed.
    """
    handler = _TASK_HANDLERS.get(task)
    if handler is None:
        return
    _record_user_query(task, query)
    handler(task, query)


# ─── MAIN ────────────────────────────────────────────────────────────────
def main():
    """Configure the page, draw the sidebar and render the selected page."""
    st.set_page_config(layout="wide")

    # init
    st.session_state.setdefault("page", "Home")
    for t in ["search_etf", "comparison", "portfolio_projection"]:
        st.session_state.setdefault(f"all_chat_history_{t}", [])

    # sidebar
    st.sidebar.title("ETF Assistant")
    if st.sidebar.button("🏠 Home"):
        st.session_state["page"] = "Home"
    if st.sidebar.button("🔎 ETF Search"):
        st.session_state["page"] = "ETF Search"
    if st.sidebar.button("⚖️ ETF Comparison"):
        st.session_state["page"] = "ETF Comparison"
    if st.sidebar.button("💼 ETF Portfolio"):
        st.session_state["page"] = "ETF Portfolio"

    # main page
    page = st.session_state["page"]
    st.title(page if page != "Home" else "ETF Assistant")

    # display content
    if page == "Home":
        st.header("How can I assist you today?")

        # Display introduction text 1
        etf_intro_text = (
            "An exchange-traded fund (ETF) is an investment vehicle that holds a diversified basket of assets—such as stocks, bonds,"
            " or commodities—and trades on an exchange like a single stock. ETFs combine the diversification and low costs of mutual funds "
            "with the flexibility and intraday liquidity of individual equities."
        )
        st.write(etf_intro_text)

        # Display introduction text 2
        app_intro_text = "Find ETFs that align with your investment goals and sector interests, compare performance, and estimate your portfolio—all in one place!"
        st.write(app_intro_text)

        display_sample_query_boxes(key_prefix="home_")
    else:
        task = {
            "ETF Search": "search_etf",
            "ETF Comparison": "comparison",
            "ETF Portfolio": "portfolio_projection"
        }[page]

        # Display introduction text
        app_description_text = {
            "ETF Search": "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
            "ETF Comparison": "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
            "ETF Portfolio": "Project a portfolio with your choice of ETFs over 30 years."
        }[page]
        st.write(app_description_text)

        # Display all previous chat history
        display_chat_history(task)

        # Display input box
        q = st.chat_input({
            "ETF Search": "Search for ETFs…",
            "ETF Comparison": "Compare ETFs…",
            "ETF Portfolio": "Project portfolio…"
        }[page], key=task)

        # Process query
        if q:
            process_query(task, q)


if __name__ == "__main__":
    main()