import os
import streamlit as st
import uuid
import pandas as pd
import modules
import torch
from sentence_transformers import SentenceTransformer
import faiss
from transformers import AutoTokenizer, AutoModelForTokenClassification
import re
# ─── CACHES ─────────────────────────────────────────────────────────────────
@st.cache_data(show_spinner=False)
def load_etf_data():
    """Load every CSV the app needs and return the prepared tables.

    The general-info table is read from the "doc added" CSV when that file
    exists. Otherwise it is built from the raw CSV by adding a ``doc`` column
    (one ``modules.make_doc_text`` call per row) and written back to disk so
    the next run can reuse it.

    Returns:
        A 4-tuple ``(df_etf, df_analyst_report, available_tickers,
        df_annual_return_master)``. ``df_etf`` and ``available_tickers`` are
        whatever ``modules.set_etf_data`` returns for the merged info table.
    """
    ticker_rename = {"ticker": "Ticker"}
    enriched_path = "etf_general_info_enriched_doc_added.csv"
    raw_path = "etf_general_info_enriched.csv"

    if not os.path.exists(enriched_path):
        # First run: derive the per-row document text and persist the result.
        df_info = pd.read_csv(raw_path).rename(columns=ticker_rename)
        df_info["doc"] = df_info.apply(modules.make_doc_text, axis=1)
        df_info.to_csv(enriched_path, index=False)
    else:
        df_info = pd.read_csv(enriched_path).rename(columns=ticker_rename)

    # Attach the summarized holdings; a left join keeps every info row.
    holdings = pd.read_csv('etf_holdings_summarized.csv')
    holdings = holdings.rename(
        columns={'ticker': 'Ticker', 'holdingInformation': 'Holdings'}
    )
    df_info = df_info.merge(holdings, how='left', on='Ticker')
    df_etf, available_tickers = modules.set_etf_data(df_info)

    df_analyst_report = pd.read_csv("etf_analyst_report_full.csv")
    df_analyst_report = df_analyst_report.rename(columns=ticker_rename)

    df_annual_return_master = pd.read_csv("annual_return.csv")
    df_annual_return_master = df_annual_return_master.rename(columns=ticker_rename)

    return df_etf, df_analyst_report, available_tickers, df_annual_return_master
@st.cache_resource(show_spinner=False)
def build_search_resources():
    """Load the sentence encoder and the FAISS index used for semantic search.

    The index is read from ``etf_faiss.index`` when that file exists and still
    holds exactly one vector per ticker. Otherwise it is rebuilt from the
    ``doc`` column of the ETF table (L2-normalized embeddings in an
    inner-product index) and written back to disk.

    Returns:
        A 3-tuple ``(model, index, ticker_list)`` where position ``i`` of the
        index corresponds to ``ticker_list[i]``.
    """
    df_etf, *_ = load_etf_data()
    model = SentenceTransformer(
        "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
    )
    ticker_list = df_etf["Ticker"].tolist()
    idx_path = "etf_faiss.index"

    index = faiss.read_index(idx_path) if os.path.exists(idx_path) else None

    # A persisted index whose size no longer matches the ticker list is stale:
    # its row positions would map to the wrong tickers, so rebuild it.
    # This only detects size changes, not reordering of the same rows.
    if index is None or index.ntotal != len(ticker_list):
        embs = model.encode(df_etf["doc"].tolist(), convert_to_numpy=True)
        # Normalizing first makes the inner product behave as cosine similarity.
        faiss.normalize_L2(embs)
        index = faiss.IndexFlatIP(embs.shape[1])
        index.add(embs)
        faiss.write_index(index, idx_path)

    return model, index, ticker_list
@st.cache_resource(show_spinner=False)
def load_ner_models():
    """Load both ticker-tagging checkpoints and the set of known tickers.

    Returns:
        ``((tok1, m1), (tok2, m2), valid_ticker_set)``: a tokenizer/model pair
        for the DistilBERT checkpoint, the same for the ALBERT checkpoint, and
        the upper-cased ``Ticker`` values of the ETF table as a set.
    """
    def _load_pair(checkpoint):
        # Tokenizer first, then the token-classification head, per checkpoint.
        tokenizer = AutoTokenizer.from_pretrained(checkpoint)
        classifier = AutoModelForTokenClassification.from_pretrained(checkpoint)
        return tokenizer, classifier

    distilbert_pair = _load_pair("hskwon7/distilbert-base-uncased-for-etf-ticker")
    albert_pair = _load_pair("hskwon7/albert-base-v2-for-etf-ticker")

    etf_frame = load_etf_data()[0]
    known_tickers = set(etf_frame["Ticker"].str.upper())
    return distilbert_pair, albert_pair, known_tickers
# ─── INITIALIZE ─────────────────────────────────────────────────────────────
# Run the three loaders when the module is imported. Each loader is wrapped in
# a Streamlit cache decorator where it is defined. The names bound here are
# read as module-level globals by the search, ticker-extraction and
# query-processing routines further down.
# ETF table, analyst reports, ticker list and annual-return table.
df_etf, df_analyst_report, available_tickers, df_annual_return_master = load_etf_data()
# Sentence encoder, FAISS index, and the ticker list aligned with index rows.
s2_model, faiss_index, etf_list = build_search_resources()
# Two (tokenizer, model) pairs plus the set of upper-cased valid tickers.
(tok1, m1), (tok2, m2), valid_ticker_set = load_ner_models()
# ─── CORE ROUTINES ──────────────────────────────────────────────────────────
# Semantic Search
def semantic_search(q: str, top_k: int = 500):
    """Return the tickers whose documents are most similar to the query.

    Args:
        q: Free-text query.
        top_k: Maximum number of tickers to return.

    Returns:
        Tickers ordered from most to least similar. The list holds at most
        ``top_k`` entries and never more than the index contains; it is empty
        when ``top_k`` is not positive or the index is empty.
    """
    # Asking FAISS for more neighbours than it stores pads the result with -1,
    # which as a Python index would silently alias the last ticker.
    k = min(top_k, faiss_index.ntotal)
    if k <= 0:
        return []

    emb = s2_model.encode([q], convert_to_numpy=True)
    faiss.normalize_L2(emb)
    _, positions = faiss_index.search(emb, k)

    # Keep only positions that map to a real ticker (drops -1 padding and
    # anything beyond the ticker list).
    n_tickers = len(etf_list)
    return [etf_list[i] for i in positions[0] if 0 <= i < n_tickers]
# Ensemble function: union of both models' predictions
def ensemble_ticker_extraction(query):
    """Return the union of tickers predicted by both token-classification models.

    Args:
        query: Free-text user query.

    Returns:
        A set with whatever ``modules.extract_valid_tickers`` yields for each
        model's token/label sequence, merged across the two models.
    """
    preds = set()
    for tok, mdl in ((tok1, m1), (tok2, m2)):
        # Truncate so an over-long chat message cannot exceed the tokenizer's
        # configured maximum length and crash the forward pass.
        enc = tok(query, return_tensors="pt", truncation=True)
        with torch.no_grad():
            logits = mdl(**enc).logits
        # Highest-scoring label id per token of the single input sequence.
        pred_ids = logits.argmax(dim=-1)[0].tolist()
        tokens = tok.convert_ids_to_tokens(enc["input_ids"][0])
        labels = [mdl.config.id2label[i] for i in pred_ids]
        preds.update(
            modules.extract_valid_tickers(tokens, labels, tok, valid_ticker_set)
        )
    return preds
# Rule-based fallback: catch literal 2–4 char tickers in the text
def rule_fallback(query, valid_set):
    """Return the literal 2-4 character words in ``query`` that are valid tickers.

    Args:
        query: Free-text user query.
        valid_set: Collection of upper-case tickers to match against.

    Returns:
        The set of upper-cased alphanumeric words, 2 to 4 characters long,
        that are members of ``valid_set``.
    """
    matched = set()
    for hit in re.finditer(r"\b[A-Za-z0-9]{2,4}\b", query):
        candidate = hit.group(0).upper()
        if candidate in valid_set:
            matched.add(candidate)
    return matched
# ─── UI HELPERS ─────────────────────────────────────────────────────────────
def display_sample_query_boxes(key_prefix=""):
    """Render one card per app, each with sample queries and a navigation button.

    Each card shows the app title, a short description and example queries.
    Pressing a card's button stores the matching page name in
    ``st.session_state["page"]`` and reruns the script.

    Args:
        key_prefix: Prefix added to each button's widget key so the cards can
            be rendered from more than one place without key collisions.
    """
    sample_queries = {
        "search_etf": {
            "title": "ETF Search",
            "description": "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
            "query": [
                'High-dividend ETFs in the tech sector.',
                'Precious metals ETFs with low expense ratio.',
                'Large growth ETFs with high returns.'
            ]
        },
        "comparison": {
            "title": "ETF Performance Comparison",
            "description": "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
            "query": [
                "I'd like to compare performance of QQQ with GLD.",
                "Compare SPY and VOO.",
                "SCHD vs. VTI"
            ]
        },
        "portfolio_projection": {
            "title": "Portfolio Projection",
            "description": "Project a portfolio with your choice of ETFs over 30 years.",
            "query": [
                "I want to invest in SPY, QQQ, SCHD, and IAU.",
                "Portfolio projection for VTI, XLF, and XLY."
            ]
        },
    }
    # Destination page for each card's button; keys mirror sample_queries.
    page_map = {
        "search_etf": "ETF Search",
        "comparison": "ETF Comparison",
        "portfolio_projection": "ETF Portfolio"
    }
    # Minimum heights per card section so the three cards line up.
    title_h, desc_h, query_h = "30px", "60px", "70px"

    cols = st.columns(len(sample_queries))
    for col, (key, details) in zip(cols, sample_queries.items()):
        with col:
            # NOTE(review): the card markup is a reconstruction (the string
            # literals in this block were split across lines and the heights
            # above were unused) -- confirm the layout against the design.
            quoted_queries = "<br>".join(f"“{q}”" for q in details["query"])
            card_html = (
                '<div style="border:1px solid #ccc; border-radius:8px; padding:12px;">'
                f'<div style="min-height:{title_h}; font-weight:bold;">{details["title"]}</div>'
                f'<div style="min-height:{desc_h};">{details["description"]}</div>'
                f'<div style="min-height:{query_h}; font-style:italic;">{quoted_queries}</div>'
                "</div>"
            )
            st.markdown(card_html, unsafe_allow_html=True)
            # center the button directly under the box
            # NOTE(review): this markup string is empty, so it styles nothing;
            # presumably the centering rule was lost -- confirm.
            st.markdown("", unsafe_allow_html=True)
            if st.button("Go to this app", key=key_prefix + key):
                st.session_state["page"] = page_map[key]
                st.rerun()
    # Vertical spacer below the row of cards.
    st.markdown("<br>", unsafe_allow_html=True)
def display_chat_history(task: str):
    """Replay the stored conversation for ``task``.

    For each history entry, in order: a truthy ``query`` is written as a user
    message, a truthy ``fig`` is drawn as a Plotly chart, a ``df`` that is not
    ``None`` is shown through ``modules.display_matching_etfs``, and a truthy
    ``response`` is written as an assistant message. A task with no stored
    history renders nothing.

    Args:
        task: Task identifier used in the ``all_chat_history_<task>`` key.
    """
    history = st.session_state.get(f"all_chat_history_{task}", [])
    for entry in history:
        user_text = entry.get("query")
        figure = entry.get("fig")
        table = entry.get("df")
        reply = entry.get("response")

        if user_text:
            st.chat_message("user").write(user_text)
        if figure:
            st.plotly_chart(figure, use_container_width=True)
        if table is not None:
            modules.display_matching_etfs(table)
        if reply:
            st.chat_message("assistant").write(reply)
def _append_chat_entry(task, response, **fields):
    """Append one ``modules.form_d_chat_history`` entry to the task's history."""
    st.session_state[f"all_chat_history_{task}"].append(
        modules.form_d_chat_history(str(uuid.uuid4()), response, task, **fields)
    )


def _extract_tickers(query):
    """Return the sorted union of model-predicted and rule-matched tickers."""
    return sorted(
        ensemble_ticker_extraction(query) | rule_fallback(query, valid_ticker_set)
    )


def _answer_search(task, query):
    """Run the semantic search, show the matching ETFs and record the answer."""
    # Number of ETFs to fetch from the index and to display, respectively.
    top_k, top_n = 50, 20
    with st.spinner("Hang on tight! Searching ETFs..."):
        fetched = semantic_search(query, top_k)
        df_out = modules.get_etf_recommendations_from_list(fetched, df_etf, top_n)
        relevant_tickers = df_out['Ticker'].tolist()
        response = modules.format_etf_search_results_inline(relevant_tickers)

    st.markdown("### ETF Search Results")
    modules.display_matching_etfs(df_out)
    st.chat_message("assistant").write(response)
    _append_chat_entry(task, response, df=df_out)


def _answer_comparison(task, query):
    """Compare exactly two ETFs named in the query, or ask the user to retry."""
    with st.spinner("Hang on tight! Running comparison analysis..."):
        tk = _extract_tickers(query)
        have_pair = len(tk) == 2
        if have_pair:
            df_out = modules.get_etf_recommendations_from_list(tk, df_etf, top_n=2)
            fig = modules.compare_etfs_interactive(tk[0], tk[1])
            d_analyst_reports = modules.lookup_etf_report(
                tk, df_analyst_report=df_analyst_report
            )
            response = modules.format_insights_report(d_analyst_reports)
        else:
            response, fig, df_out = "Please specify exactly two tickers.", None, None

    # Only draw the chart and table when there is something to draw; the
    # error path has neither a figure nor a table.
    if have_pair:
        st.markdown("### Performance Comparison")
        st.plotly_chart(fig, use_container_width=True)
        modules.display_matching_etfs(df_out)
    st.chat_message("assistant").write(response)
    _append_chat_entry(task, response, fig=fig, df=df_out)


def _answer_portfolio(task, query):
    """Project a portfolio of the ETFs named in the query, or ask the user to retry."""
    with st.spinner("Hang on tight! Projecting portfolio ..."):
        tk = _extract_tickers(query)
        if tk:
            df_port_output, d_summary = modules.run_portfolio_analysis(
                tk, df_etf, df_annual_return_master
            )
            response = modules.format_portfolio_summary(d_summary=d_summary)
            fig = modules.portfolio_interactive_chart(df_port_output)
        else:
            # No recognizable ticker: skip the analysis instead of running it
            # on an empty portfolio.
            response, fig = "Please specify at least one ETF ticker.", None

    if fig is not None:
        st.markdown("### 30 Years Investment Return Projection")
        st.plotly_chart(fig, use_container_width=True)
    st.chat_message("assistant").write(response)
    _append_chat_entry(task, response, fig=fig)


def process_query(task: str, query: str):
    """Answer one user query for the given task and record it in the history.

    The query is echoed as a user message and stored, then the task-specific
    handler computes, renders and stores the answer.

    Args:
        task: One of ``"search_etf"``, ``"comparison"`` or
            ``"portfolio_projection"``. Any other value is ignored.
        query: The text the user entered.
    """
    handlers = {
        "search_etf": _answer_search,
        "comparison": _answer_comparison,
        "portfolio_projection": _answer_portfolio,
    }
    handler = handlers.get(task)
    if handler is None:
        # Unrecognized tasks render and store nothing.
        return

    # Echo the query and store it before any work starts.
    st.chat_message("user").write(query)
    _append_chat_entry(task, None, df=None, query=query)

    handler(task, query)
# ─── MAIN ────────────────────────────────────────────────────────────────
def main():
    """Streamlit entry point: set up state, draw the sidebar and render the page.

    The current page name lives in ``st.session_state["page"]``. "Home" shows
    the introduction and the sample-query cards; every other page shows its
    description, the stored chat history for its task and a chat input whose
    submissions go to ``process_query``.
    """
    st.set_page_config(layout="wide")

    # init
    if "page" not in st.session_state:
        st.session_state["page"] = "Home"
    for task_name in ("search_etf", "comparison", "portfolio_projection"):
        st.session_state.setdefault(f"all_chat_history_{task_name}", [])

    # sidebar: every button is drawn; a pressed one switches the page
    st.sidebar.title("ETF Assistant")
    nav_buttons = (
        ("🏠 Home", "Home"),
        ("🔎 ETF Search", "ETF Search"),
        ("⚖️ ETF Comparison", "ETF Comparison"),
        ("💼 ETF Portfolio", "ETF Portfolio"),
    )
    for label, target in nav_buttons:
        if st.sidebar.button(label):
            st.session_state["page"] = target

    # main page
    page = st.session_state["page"]
    if page == "Home":
        st.title("ETF Assistant")
    else:
        st.title(page)

    if page == "Home":
        st.header("How can I assist you today?")
        # Introduction: what an ETF is
        etf_intro_text = (
            "An exchange-traded fund (ETF) is an investment vehicle that holds a diversified basket of assets—such as stocks, bonds,"
            " or commodities—and trades on an exchange like a single stock. ETFs combine the diversification and low costs of mutual funds "
            "with the flexibility and intraday liquidity of individual equities."
        )
        st.write(etf_intro_text)
        # Introduction: what this app offers
        app_intro_text = "Find ETFs that align with your investment goals and sector interests, compare performance, and estimate your portfolio—all in one place!"
        st.write(app_intro_text)
        display_sample_query_boxes(key_prefix="home_")
        return

    # page name -> (task id, description, chat-input placeholder)
    page_specs = {
        "ETF Search": (
            "search_etf",
            "Explore ETFs based on criteria such as high dividends, low expense ratios, or sector focus.",
            "Search for ETFs…",
        ),
        "ETF Comparison": (
            "comparison",
            "Compare two ETFs side by side to evaluate their performance, risk, and other metrics.",
            "Compare ETFs…",
        ),
        "ETF Portfolio": (
            "portfolio_projection",
            "Project a portfolio with your choice of ETFs over 30 years.",
            "Project portfolio…",
        ),
    }
    task, app_description_text, placeholder = page_specs[page]

    st.write(app_description_text)
    # Earlier conversation first, then the input box
    display_chat_history(task)
    q = st.chat_input(placeholder, key=task)
    if q:
        process_query(task, q)


if __name__ == "__main__":
    main()