Spaces:
Sleeping
Sleeping
| import os | |
| import streamlit as st | |
| import uuid | |
| import pandas as pd | |
| import modules | |
| import torch | |
| from sentence_transformers import SentenceTransformer | |
| import faiss | |
| from transformers import AutoTokenizer, AutoModelForTokenClassification | |
| import re | |
# ─── CACHES ─────────────────────────────────────────────────────────────────
@st.cache_data(show_spinner=False)
def load_etf_data():
    """Load the ETF tables used throughout the app.

    The general-info CSV is read from ``etf_general_info_enriched_doc_added.csv``
    when it exists. Otherwise it is built from ``etf_general_info_enriched.csv``
    by adding a ``doc`` text column (via ``modules.make_doc_text``), and the
    result is written back to the enriched path for next time. Summarized
    holdings are left-merged on ``Ticker`` and the result is passed through
    ``modules.set_etf_data``.

    Cached with ``st.cache_data`` because Streamlit re-executes the whole script
    on every widget interaction, and this loader is also called by
    ``build_search_resources`` and ``load_ner_models``. ``show_spinner=False``
    keeps the cache from rendering anything before ``st.set_page_config`` runs
    in ``main``.

    Returns:
        tuple: ``(df_etf, df_analyst_report, available_tickers,
        df_annual_return_master)``.
    """
    enriched_path = "etf_general_info_enriched_doc_added.csv"
    raw_path = "etf_general_info_enriched.csv"
    # Every source CSV uses a lower-case ``ticker`` column; the app keys on ``Ticker``.
    ticker_rename = {"ticker": "Ticker"}

    if os.path.exists(enriched_path):
        df_info = pd.read_csv(enriched_path).rename(columns=ticker_rename)
    else:
        df_info = pd.read_csv(raw_path).rename(columns=ticker_rename)
        # Build the per-ETF text document once and persist it for later runs.
        df_info["doc"] = df_info.apply(modules.make_doc_text, axis=1)
        df_info.to_csv(enriched_path, index=False)

    df_etf_holdings = pd.read_csv("etf_holdings_summarized.csv").rename(
        columns={"ticker": "Ticker", "holdingInformation": "Holdings"}
    )
    df_info = df_info.merge(df_etf_holdings, how="left", on="Ticker")
    df_etf, available_tickers = modules.set_etf_data(df_info)

    df_analyst_report = pd.read_csv("etf_analyst_report_full.csv").rename(columns=ticker_rename)
    df_annual_return_master = pd.read_csv("annual_return.csv").rename(columns=ticker_rename)
    return df_etf, df_analyst_report, available_tickers, df_annual_return_master
@st.cache_resource(show_spinner=False)
def build_search_resources():
    """Load the sentence encoder and the FAISS inner-product index over ETF docs.

    The index is read from ``etf_faiss.index`` when present. It is rebuilt
    (and re-saved) when the file is missing, or when its row count no longer
    matches the current ETF list. A stale index would otherwise map search hits
    to the wrong tickers, because rows are resolved by position in
    ``ticker_list``.

    Cached with ``st.cache_resource`` so the model and index are created once
    per process instead of on every Streamlit rerun.

    Returns:
        tuple: ``(model, index, ticker_list)``, where row ``i`` of ``index``
        corresponds to ``ticker_list[i]``.
    """
    df_etf, *_ = load_etf_data()
    model = SentenceTransformer(
        "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
    )
    ticker_list = df_etf["Ticker"].tolist()
    idx_path = "etf_faiss.index"

    index = faiss.read_index(idx_path) if os.path.exists(idx_path) else None
    if index is None or index.ntotal != len(ticker_list):
        embs = model.encode(df_etf["doc"].tolist(), convert_to_numpy=True)
        # L2-normalise so inner product equals cosine similarity.
        faiss.normalize_L2(embs)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        faiss.write_index(index, idx_path)
    return model, index, ticker_list
def _load_token_classifier(repo_id):
    """Return a ``(tokenizer, model)`` pair for a token-classification checkpoint."""
    return (
        AutoTokenizer.from_pretrained(repo_id),
        AutoModelForTokenClassification.from_pretrained(repo_id),
    )


@st.cache_resource(show_spinner=False)
def load_ner_models():
    """Load the two ticker-NER models and the set of known tickers.

    Cached with ``st.cache_resource`` so both Hugging Face checkpoints are
    loaded once per process rather than on every Streamlit rerun.

    Returns:
        tuple: ``((tok1, m1), (tok2, m2), valid_ticker_set)``. The first pair is
        the DistilBERT checkpoint and the second the ALBERT checkpoint.
        ``valid_ticker_set`` holds every ``Ticker`` value upper-cased.
    """
    distilbert_pair = _load_token_classifier("hskwon7/distilbert-base-uncased-for-etf-ticker")
    albert_pair = _load_token_classifier("hskwon7/albert-base-v2-for-etf-ticker")
    df_etf, *_ = load_etf_data()
    valid_ticker_set = set(df_etf["Ticker"].str.upper())
    return distilbert_pair, albert_pair, valid_ticker_set
# ─── INITIALIZE ─────────────────────────────────────────────────────────────
# Load data, the semantic-search stack and the NER models at import time; the
# routines below read these module-level names directly.
(
    df_etf,
    df_analyst_report,
    available_tickers,
    df_annual_return_master,
) = load_etf_data()
s2_model, faiss_index, etf_list = build_search_resources()
(tok1, m1), (tok2, m2), valid_ticker_set = load_ner_models()
# ─── CORE ROUTINES ──────────────────────────────────────────────────────────
def semantic_search(q: str, top_k: int = 500):
    """Return ETF tickers ranked by semantic similarity to ``q``.

    The query is embedded with ``s2_model``, L2-normalised, and searched in
    ``faiss_index``; hit row ``i`` is mapped to ``etf_list[i]``.

    Args:
        q: Free-text search query.
        top_k: Maximum number of tickers to return. It is clamped to the index
            size, so asking for more than exist simply returns all of them.

    Returns:
        list: Tickers in descending similarity order (empty if the index is empty).
    """
    k = min(top_k, faiss_index.ntotal)
    if k <= 0:
        return []
    emb = s2_model.encode([q], convert_to_numpy=True)
    faiss.normalize_L2(emb)
    _scores, ids = faiss_index.search(emb, k)
    # FAISS pads missing results with -1; without this filter etf_list[-1]
    # would silently inject the last ticker into the results.
    return [etf_list[i] for i in ids[0] if i >= 0]
def ensemble_ticker_extraction(query):
    """Extract ETF tickers from ``query`` with both NER models and union the results.

    Each (tokenizer, model) pair labels the query token by token. The tokens
    and predicted labels are handed to ``modules.extract_valid_tickers``
    together with ``valid_ticker_set``, presumably to keep only known tickers.

    Args:
        query: Free-text user query.

    Returns:
        set: Union of the tickers returned for each model.
    """
    preds = set()
    for tok, mdl in ((tok1, m1), (tok2, m2)):
        # Truncate to the tokenizer's model_max_length: longer inputs overflow
        # the model's position embeddings and make the forward pass fail.
        enc = tok(query, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = mdl(**enc).logits
        # A single query is encoded, so row 0 holds its per-token predictions.
        pred_ids = logits.argmax(dim=-1)[0].tolist()
        tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
        labels = [mdl.config.id2label[i] for i in pred_ids]
        preds.update(modules.extract_valid_tickers(tokens, labels, tok, valid_ticker_set))
    return preds
def rule_fallback(query, valid_set):
    """Rule-based fallback: catch literal 2–4 character tickers in the text.

    Every standalone alphanumeric word of length 2 to 4 is upper-cased and kept
    if it appears in ``valid_set``; matching is therefore case-insensitive.

    Returns:
        set: Upper-cased tickers found in ``query``.
    """
    found = set()
    for match in re.finditer(r"\b[A-Za-z0-9]{2,4}\b", query):
        candidate = match.group(0).upper()
        if candidate in valid_set:
            found.add(candidate)
    return found
# ─── UI HELPERS ─────────────────────────────────────────────────────────────
def _sample_query_card_html(details, title_h, desc_h, query_h):
    """Return the HTML for one sample-query card.

    Every line starts at column 0 on purpose: Markdown treats lines indented by
    four or more spaces as a code block, which would show the raw HTML instead
    of rendering it.
    """
    queries_html = '<br>'.join(f'β{q}β' for q in details['query'])
    return "\n".join([
        '<div style="width:100%; height:350px; border:1px solid #ddd; '
        'border-radius:10px; padding:15px; margin:auto; '
        'display:flex; flex-direction:column; justify-content:space-between; '
        'box-shadow:2px 2px 8px rgba(0,0,0,0.1);">',
        f'<div style="height:{title_h}; text-align:center;">'
        f'<b style="font-size:16px; color:#2c3e50;">{details["title"]}</b></div>',
        f'<div style="height:{desc_h}; text-align:center; color:#7f8c8d; font-size:14px; overflow:auto;">'
        f'{details["description"]}</div>',
        f'<div style="height:{query_h}; text-align:center; color:#34495e; font-size:13px; font-style:italic; overflow:auto;">'
        f'{queries_html}</div>',
        '</div>',
    ])


def display_sample_query_boxes(key_prefix=""):
    """Render one card per app feature, each with example queries and a button.

    Clicking a card's button stores the matching page name in
    ``st.session_state["page"]`` and calls ``st.rerun()`` so the new page is shown.

    Args:
        key_prefix: Prepended to each button's widget key (presumably so the
            cards can appear in more than one place without key collisions).
    """
    sample_queries = {
        "search_etf": {
            "title": "ETF Search",
            "description": "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
            "query": [
                'High-dividend ETFs in the tech sector.',
                'Precious metals ETFs with low expense ratio.',
                'Large growth ETFs with high returns.'
            ]
        },
        "comparison": {
            "title": "ETF Performance Comparison",
            "description": "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
            "query": [
                "I'd like to compare performance of QQQ with GLD.",
                "Compare SPY and VOO.",
                "SCHD vs. VTI"
            ]
        },
        "portfolio_projection": {
            "title": "Portfolio Projection",
            "description": "Project a portfolio with your choice of ETFs over 30 years.",
            "query": [
                "I want to invest in SPY, QQQ, SCHD, and IAU.",
                "Portfolio projection for VTI, XLF, and XLY."
            ]
        },
    }
    # Feature key -> page name handled by main().
    page_map = {
        "search_etf": "ETF Search",
        "comparison": "ETF Comparison",
        "portfolio_projection": "ETF Portfolio"
    }
    title_h, desc_h, query_h = "30px", "60px", "70px"

    cols = st.columns(len(sample_queries))
    for col, (key, details) in zip(cols, sample_queries.items()):
        with col:
            st.markdown(
                _sample_query_card_html(details, title_h, desc_h, query_h),
                unsafe_allow_html=True,
            )
            # Each st.markdown call renders as its own element, so HTML cannot
            # wrap a widget; stretching the button to the column width keeps it
            # aligned directly under the card.
            if st.button("Go to this app", key=key_prefix + key, use_container_width=True):
                st.session_state["page"] = page_map[key]
                st.rerun()
def display_chat_history(task: str):
    """Replay the stored chat entries for ``task`` in insertion order.

    An entry may carry a user query, a Plotly figure, a results table and an
    assistant response; only the parts that are present are rendered.
    """
    history = st.session_state.get(f"all_chat_history_{task}", [])
    for record in history:
        user_text = record.get("query")
        if user_text:
            st.chat_message("user").write(user_text)

        figure = record.get("fig")
        if figure:
            st.plotly_chart(figure, use_container_width=True)

        table = record.get("df")
        if table is not None:
            modules.display_matching_etfs(table)

        reply = record.get("response")
        if reply:
            st.chat_message("assistant").write(reply)
def _record_user_query(task: str, query: str) -> None:
    """Echo the user's query as a chat message and append it to the task history."""
    st.chat_message("user").write(query)
    st.session_state[f"all_chat_history_{task}"].append(
        modules.form_d_chat_history(str(uuid.uuid4()), None, task, df=None, query=query)
    )


def _record_assistant_reply(task: str, response, **payload) -> None:
    """Show the assistant response and append it, with any ``fig``/``df`` payload, to history."""
    st.chat_message("assistant").write(response)
    st.session_state[f"all_chat_history_{task}"].append(
        modules.form_d_chat_history(str(uuid.uuid4()), response, task, **payload)
    )


def _extract_tickers(query: str) -> list:
    """Return the sorted union of NER-ensemble and rule-based ticker matches."""
    ensemble_preds = ensemble_ticker_extraction(query)
    fallback_preds = rule_fallback(query, valid_ticker_set)
    return sorted(ensemble_preds | fallback_preds)


def process_query(task: str, query: str):
    """Handle one chat query for ``task`` and record it in the session history.

    Tasks:
        ``"search_etf"``: semantic search, then a table of the top ETFs.
        ``"comparison"``: needs exactly two tickers; shows a performance chart,
            a table and analyst insights.
        ``"portfolio_projection"``: needs at least one ticker; shows a 30-year
            projection chart and a summary.
    Any other ``task`` value is ignored.
    """
    # Candidates fetched from FAISS, and how many of them are shown.
    top_k, top_n = 50, 20

    if task == "search_etf":
        _record_user_query(task, query)
        with st.spinner("Hang on tight! Searching ETFs..."):
            fetched = semantic_search(query, top_k)
            df_out = modules.get_etf_recommendations_from_list(fetched, df_etf, top_n)
            relevant_tickers = df_out['Ticker'].tolist()
            response = modules.format_etf_search_results_inline(relevant_tickers)
        st.markdown("### ETF Search Results")
        modules.display_matching_etfs(df_out)
        _record_assistant_reply(task, response, df=df_out)

    elif task == "comparison":
        _record_user_query(task, query)
        with st.spinner("Hang on tight! Running comparison analysis..."):
            tk = _extract_tickers(query)
            compared = len(tk) == 2
            if compared:
                df_out = modules.get_etf_recommendations_from_list(tk, df_etf, top_n=2)
                fig = modules.compare_etfs_interactive(tk[0], tk[1])
                d_analyst_reports = modules.lookup_etf_report(tk, df_analyst_report=df_analyst_report)
                response = modules.format_insights_report(d_analyst_reports)
            else:
                response, fig, df_out = "Please specify exactly two tickers.", None, None
        # Only render chart/table when a comparison actually ran; fig/df are None otherwise.
        if compared:
            st.markdown("### Performance Comparison")
            st.plotly_chart(fig, use_container_width=True)
            modules.display_matching_etfs(df_out)
        _record_assistant_reply(task, response, fig=fig, df=df_out)

    elif task == "portfolio_projection":
        _record_user_query(task, query)
        with st.spinner("Hang on tight! Projecting portfolio ..."):
            tk = _extract_tickers(query)
            if tk:
                df_port_output, d_summary = modules.run_portfolio_analysis(
                    tk, df_etf, df_annual_return_master
                )
                response = modules.format_portfolio_summary(d_summary=d_summary)
                fig = modules.portfolio_interactive_chart(df_port_output)
            else:
                # Nothing to project: mirror the comparison branch's guard.
                response, fig = "Please specify at least one ticker.", None
        if tk:
            st.markdown("### 30 Years Investment Return Projection")
            st.plotly_chart(fig, use_container_width=True)
        _record_assistant_reply(task, response, fig=fig)
# ─── MAIN ───────────────────────────────────────────────────────────────────
def main():
    """Render the ETF Assistant: sidebar navigation plus the selected page.

    The current page is kept in ``st.session_state["page"]``. Every non-home
    page maps to a task key that selects its chat history and query handler.
    """
    st.set_page_config(layout="wide")

    # Session defaults: start on the home page with an empty history per task.
    st.session_state.setdefault("page", "Home")
    for task_key in ("search_etf", "comparison", "portfolio_projection"):
        st.session_state.setdefault(f"all_chat_history_{task_key}", [])

    # Sidebar navigation; buttons are created in this order.
    st.sidebar.title("ETF Assistant")
    nav_entries = (
        ("π Home", "Home"),
        ("π ETF Search", "ETF Search"),
        ("βοΈ ETF Comparison", "ETF Comparison"),
        ("πΌ ETF Portfolio", "ETF Portfolio"),
    )
    for label, target_page in nav_entries:
        if st.sidebar.button(label):
            st.session_state["page"] = target_page

    page = st.session_state["page"]
    st.title("ETF Assistant" if page == "Home" else page)

    if page == "Home":
        st.header("How can I assist you today?")
        st.write(
            "An exchange-traded fund (ETF) is an investment vehicle that holds a diversified basket of assetsβsuch as stocks, bonds,"
            " or commoditiesβand trades on an exchange like a single stock. ETFs combine the diversification and low costs of mutual funds "
            "with the flexibility and intraday liquidity of individual equities."
        )
        st.write(
            "Find ETFs that align with your investment goals and sector interests, compare performance, and estimate your portfolioβall in one place!"
        )
        display_sample_query_boxes(key_prefix="home_")
        return

    # page name -> (task key, description, chat placeholder)
    page_specs = {
        "ETF Search": (
            "search_etf",
            "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
            "Search for ETFsβ¦",
        ),
        "ETF Comparison": (
            "comparison",
            "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
            "Compare ETFsβ¦",
        ),
        "ETF Portfolio": (
            "portfolio_projection",
            "Project a portfolio with your choice of ETFs over 30 years.",
            "Project portfolioβ¦",
        ),
    }
    task, description, placeholder = page_specs[page]
    st.write(description)
    display_chat_history(task)
    user_query = st.chat_input(placeholder, key=task)
    if user_query:
        process_query(task, user_query)


if __name__ == "__main__":
    # Run the app when this file is executed as the top-level script.
    main()