hskwon7 commited on
Commit
02470e1
Β·
verified Β·
1 Parent(s): 135bebd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -128
app.py CHANGED
@@ -1,130 +1,209 @@
1
- # app.py
2
-
3
- """
4
- Streamlit application for Image-to-Story demo with history sidebar.
5
- Allows demo/upload image, generates a caption, a trimmed story,
6
- and plays back as MP3 via gTTS. Keeps history of all runs.
7
- """
8
  import streamlit as st
9
- from PIL import Image
10
- import warnings
11
- from modules import (
12
- load_captioner, load_story_gen,
13
- generate_caption, generate_story_simple,
14
- generate_audio
15
- )
16
- import io
17
-
18
- warnings.filterwarnings("ignore", category=DeprecationWarning)
19
-
20
- # Reset state when switching image source
21
- def reset_state():
22
- for key in ["caption", "story", "audio_bytes", "audio_mime", "selected_index"]:
23
- if key in st.session_state:
24
- del st.session_state[key]
25
-
26
- def main():
27
- st.set_page_config(layout="wide")
28
- st.title("🎨 Magic Picture Story Time!")
29
- st.write("Pick or upload a picture, and watch it turn into a fun story with voice! Ready for a magical tale?")
30
-
31
- # --- Sidebar: History ---
32
- st.sidebar.header("History")
33
- if "history" not in st.session_state:
34
- st.session_state.history = [] # list of dicts
35
- if "selected_index" not in st.session_state:
36
- st.session_state.selected_index = None
37
-
38
- # Render thumbnails & select buttons
39
- for idx, entry in enumerate(st.session_state.history):
40
- with st.sidebar.container():
41
- st.sidebar.image(entry["image_bytes"], width=100)
42
- if st.sidebar.button(f"View #{idx+1}", key=f"view_{idx}"):
43
- st.session_state.selected_index = idx
44
-
45
- # Sidebar clear-all button
46
- if st.sidebar.button("Clear History"):
47
- st.session_state.history = []
48
- st.session_state.selected_index = None
49
-
50
- # --- Main panel: image selection ---
51
- source = st.radio("Image source:",
52
- ("Upload my own image", "Use demo image"),
53
- on_change=reset_state)
54
-
55
- # Load pipelines once
56
- if "models_loaded" not in st.session_state:
57
- with st.spinner("Loading models…"):
58
- st.session_state.captioner = load_captioner()
59
- st.session_state.story_gen = load_story_gen()
60
- st.session_state.models_loaded = True
61
-
62
- captioner = st.session_state.captioner
63
- story_gen = st.session_state.story_gen
64
-
65
- # If user clicked a history entry, load it
66
- sel = st.session_state.selected_index
67
- if sel is not None:
68
- entry = st.session_state.history[sel]
69
- img = Image.open(io.BytesIO(entry["image_bytes"])).convert("RGB")
70
- st.image(img, use_container_width=True)
71
- st.markdown(f"**Caption:** {entry['caption']}")
72
- st.markdown(f"**Story:** {entry['story']}")
73
- if st.button("πŸ”Š Play Story Audio"):
74
- st.audio(data=entry["audio_bytes"], format=entry["audio_mime"])
75
- return
76
-
77
- # Otherwise, handle a fresh upload/demo
78
- if source == "Use demo image":
79
- img = Image.open("test_kids_playing.jpg").convert("RGB")
80
- # grab raw bytes for history
81
- buf = io.BytesIO()
82
- img.save(buf, format="JPEG")
83
- img_bytes = buf.getvalue()
84
- else:
85
- uploaded = st.file_uploader("Upload an image",
86
- type=["png", "jpg", "jpeg"])
87
- if not uploaded:
88
- return
89
- img = Image.open(uploaded).convert("RGB")
90
- img_bytes = uploaded.getvalue()
91
-
92
- st.image(img, use_container_width=True)
93
-
94
- # Step 1: Caption
95
- if "caption" not in st.session_state:
96
- with st.spinner("Captioning image…"):
97
- st.session_state.caption = generate_caption(captioner, img)
98
- st.markdown(f"**Caption:** {st.session_state.caption}")
99
-
100
- # Step 2: Story
101
- if "story" not in st.session_state:
102
- with st.spinner("Creating story…"):
103
- st.session_state.story = generate_story_simple(
104
- story_gen, st.session_state.caption, 50, 100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
105
  )
106
- st.markdown(f"**Story:** {st.session_state.story}")
107
-
108
- # Step 3: Audio
109
- if "audio_bytes" not in st.session_state:
110
- with st.spinner("Generating audio…"):
111
- audio_bytes, mime = generate_audio(st.session_state.story)
112
- st.session_state.audio_bytes = audio_bytes
113
- st.session_state.audio_mime = mime
114
-
115
- if st.button("πŸ”Š Play Story Audio"):
116
- st.audio(data=st.session_state.audio_bytes,
117
- format=st.session_state.audio_mime)
118
-
119
- # Step 4: Append to history (only once per new run)
120
- if not st.session_state.history or st.session_state.history[-1]["image_bytes"] != img_bytes:
121
- st.session_state.history.append({
122
- "image_bytes": img_bytes,
123
- "caption": st.session_state.caption,
124
- "story": st.session_state.story,
125
- "audio_bytes": st.session_state.audio_bytes,
126
- "audio_mime": st.session_state.audio_mime
127
- })
128
-
129
- if __name__ == "__main__":
130
- main()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import streamlit as st
2
+ import openai
3
+ import uuid
4
+ import modules
5
+ import pandas as pd
6
+ import torch
7
+
8
+ from sentence_transformers import SentenceTransformer
9
+ import faiss
10
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
11
+
12
+ # ─── LOAD DATA & MODELS (unchanged) ─────────────────────────────────────────
13
+ df_etf, df_analyst_report, available_tickers, df_annual_return_master = modules.load_etf_data()
14
+
15
+ repo_name = "hskwon7/paraphrase-MiniLM-L6-v2-ft-for-etf-semantic-search"
16
+ s2_model = SentenceTransformer(repo_name)
17
+ df_etf["doc"] = df_etf.apply(modules.make_doc_text, axis=1)
18
+ etf_list = df_etf["ticker"].tolist()
19
+ doc_embs = s2_model.encode(df_etf["doc"].tolist(), convert_to_numpy=True, show_progress_bar=True)
20
+ faiss.normalize_L2(doc_embs)
21
+ index = faiss.IndexFlatIP(doc_embs.shape[1])
22
+ index.add(doc_embs)
23
+
24
+ def semantic_search(query, top_k=100):
25
+ q_emb = s2_model.encode([query], convert_to_numpy=True)
26
+ faiss.normalize_L2(q_emb)
27
+ D, I = index.search(q_emb, top_k)
28
+ return [(etf_list[idx], float(D[0][i])) for i, idx in enumerate(I[0])]
29
+
30
+ # NER ensemble remains the same
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ repo1 = "hskwon7/distilbert-base-uncased-for-etf-ticker"
33
+ repo2 = "hskwon7/albert-base-v2-for-etf-ticker"
34
+ tokenizer1 = AutoTokenizer.from_pretrained(repo1)
35
+ model1 = AutoModelForTokenClassification.from_pretrained(repo1).to(device).eval()
36
+ tokenizer2 = AutoTokenizer.from_pretrained(repo2)
37
+ model2 = AutoModelForTokenClassification.from_pretrained(repo2).to(device).eval()
38
+ valid_ticker_set = set(t.upper() for t in df_etf['ticker'].unique().tolist())
39
+
40
+ def ensemble_ticker_extraction(query):
41
+ preds = set()
42
+ for tok, mdl in ((tokenizer1, model1),(tokenizer2, model2)):
43
+ enc = tok(query, return_tensors="pt").to(device)
44
+ with torch.no_grad():
45
+ logits = mdl(**enc).logits
46
+ ids = logits.argmax(dim=-1)[0].tolist()
47
+ toks = tok.convert_ids_to_tokens(enc["input_ids"][0])
48
+ labs = [mdl.config.id2label[i] for i in ids]
49
+ preds |= modules.extract_valid_tickers(toks, labs, tok, valid_ticker_set)
50
+ return preds
51
+
52
+ # ─── HELPERS ──────────────────────────────────────────────────────────────
53
+
54
+ def display_sample_query_box(task: str):
55
+ """
56
+ Renders one sample-query box for the given task
57
+ and adds a 'Try this app' button that switches the sidebar page.
58
+ """
59
+ sample_queries = {
60
+ "search_etf": {
61
+ "title": "ETF Search",
62
+ "description": "Explore ETFs based on dividend, expense ratio, sector, etc.",
63
+ "examples": ['Find me technology ETFs', 'Show me Cryptocurrency ETFs']
64
+ },
65
+ "comparison": {
66
+ "title": "ETF Comparison",
67
+ "description": "Compare two ETFs side by side on performance, risk, etc.",
68
+ "examples": ["QQQ vs. SPY", "Compare performance of QQQ with SPY"]
69
+ },
70
+ "portfolio_projection": {
71
+ "title": "ETF Portfolio",
72
+ "description": "Project a multi-ETF portfolio out over 30 years.",
73
+ "examples": ["SPY, GLD, BND", "I want to invest in SCHD and IAU"]
74
+ }
75
+ }
76
+ details = sample_queries[task]
77
+ box_html = f"""
78
+ <div style='border:1px solid #ddd;padding:1rem;border-radius:8px;'>
79
+ <h4>{details['title']}</h4>
80
+ <p style='color:#555;margin-bottom:.5rem;'>{details['description']}</p>
81
+ <p style='font-style:italic;color:#333;'>
82
+ Examples:<br/>{'<br/>'.join(details['examples'])}
83
+ </p>
84
+ </div>
85
+ """
86
+ st.markdown(box_html, unsafe_allow_html=True)
87
+ if st.button("Try this app", key=f"try_{task}"):
88
+ page_map = {
89
+ "search_etf": "ETF Search",
90
+ "comparison": "ETF Comparison",
91
+ "portfolio_projection": "ETF Portfolio"
92
+ }
93
+ st.session_state["page"] = page_map[task]
94
+ st.experimental_rerun()
95
+
96
+ def display_chat_history(task: str):
97
+ """
98
+ Shows only the chat history for a given task.
99
+ """
100
+ hist = st.session_state.get(f"all_chat_history_{task}", [])
101
+ for entry in hist:
102
+ if task == "search_etf":
103
+ st.chat_message("assistant").write(entry["response"])
104
+ modules.display_matching_etfs(entry["df"])
105
+ elif task == "comparison":
106
+ st.chat_message("assistant").write(entry["response"])
107
+ st.plotly_chart(entry["fig"], use_container_width=True)
108
+ st.dataframe(entry["df"], hide_index=True)
109
+ elif task == "portfolio_projection":
110
+ st.chat_message("assistant").write(entry["response"])
111
+ st.plotly_chart(entry["fig"], use_container_width=True)
112
+
113
+ def process_query(task: str, query: str):
114
+ """
115
+ Core logic for each sub-app.
116
+ """
117
+ # make sure top_k / top_n are always available
118
+ top_k, top_n = 100, 30
119
+
120
+ if task == 'search_etf':
121
+ with st.spinner("Searching ETFs..."):
122
+ fetched = semantic_search(query, top_k=top_k)
123
+ df_out = modules.get_etf_recommendations_from_list(fetched,
124
+ modules.get_cols_to_display(), df_etf, top_n=top_n)
125
+ st.session_state[f"all_chat_history_{task}"].append(
126
+ modules.form_d_chat_history(
127
+ result_id=str(uuid.uuid4()),
128
+ response=f"{len(df_out)} ETFs found.",
129
+ task=task,
130
+ df=df_out
131
+ )
132
+ )
133
+
134
+ elif task == 'comparison':
135
+ with st.spinner("Running comparison..."):
136
+ tickers = ensemble_ticker_extraction(query)
137
+ if len(tickers) != 2:
138
+ response, fig, df_out = (
139
+ "Please specify exactly two tickers.", None, None
140
+ )
141
+ else:
142
+ df_out = modules.get_etf_recommendations_from_list(
143
+ [(t, None) for t in tickers],
144
+ modules.get_cols_to_display(), df_etf, top_n=2
145
+ )
146
+ fig = modules.compare_etfs_interactive(*tickers)
147
+ response = f"Compared {tickers[0]} vs. {tickers[1]}."
148
+ st.session_state[f"all_chat_history_{task}"].append(
149
+ modules.form_d_chat_history(
150
+ result_id=str(uuid.uuid4()),
151
+ response=response,
152
+ task=task,
153
+ fig=fig,
154
+ df=df_out
155
+ )
156
+ )
157
+
158
+ elif task == 'portfolio_projection':
159
+ with st.spinner("Projecting portfolio..."):
160
+ fetched = semantic_search(query, top_k=top_k)
161
+ df_port = modules.run_portfolio_analysis(fetched, df_etf, df_annual_return_master)
162
+ fig = modules.portfolio_interactive_chart(df_port)
163
+ response = "30-year projection generated."
164
+ st.session_state[f"all_chat_history_{task}"].append(
165
+ modules.form_d_chat_history(
166
+ result_id=str(uuid.uuid4()),
167
+ response=response,
168
+ task=task,
169
+ fig=fig
170
  )
171
+ )
172
+
173
+ # ─── MAIN ────────────────────────────────────────────────────────────────
174
+ st.set_page_config(layout="wide")
175
+ if "page" not in st.session_state:
176
+ st.session_state["page"] = "ETF Search" # default
177
+
178
+ # initialize histories
179
+ for t in ["search_etf","comparison","portfolio_projection"]:
180
+ st.session_state.setdefault(f"all_chat_history_{t}", [])
181
+
182
+ # sidebar navigation
183
+ st.sidebar.title("ETF Assistant")
184
+ st.sidebar.radio("Go to…", ["ETF Search","ETF Comparison","ETF Portfolio"], key="page")
185
+
186
+ # dispatch
187
+ page = st.session_state["page"]
188
+ st.title(page)
189
+
190
+ if page == "ETF Search":
191
+ display_sample_query_box("search_etf")
192
+ display_chat_history("search_etf")
193
+ q = st.chat_input("Search for ETFs…", key="in_search")
194
+ if q:
195
+ process_query("search_etf", q)
196
+
197
+ elif page == "ETF Comparison":
198
+ display_sample_query_box("comparison")
199
+ display_chat_history("comparison")
200
+ q = st.chat_input("Compare ETFs…", key="in_comp")
201
+ if q:
202
+ process_query("comparison", q)
203
+
204
+ elif page == "ETF Portfolio":
205
+ display_sample_query_box("portfolio_projection")
206
+ display_chat_history("portfolio_projection")
207
+ q = st.chat_input("Project portfolio…", key="in_port")
208
+ if q:
209
+ process_query("portfolio_projection", q)