molba2see committed on
Commit
df159f4
·
verified ·
1 Parent(s): 5f0840e

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +388 -38
src/streamlit_app.py CHANGED
@@ -1,40 +1,390 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
  import streamlit as st
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
1
+ import os
 
 
2
  import streamlit as st
3
+ import pandas as pd
4
+ import plotly.express as px
5
+ import torch
6
+ from peft import PeftModel
7
+ from transformers import LlamaForCausalLM, LlamaTokenizerFast
8
+ import sqlite3
9
+ import yfinance as yf
10
+
11
+ # 타 코드에서 모듈 불러오기
12
+ from analyze_portfolio_risk import classify_investment_style # 사용자 성향 파악
13
+ # ----------------------------------------------------------------------
14
+ # 0. (필수) LLM 모델 로드 및 NASDAQ100 리스트 준비
15
+ # ----------------------------------------------------------------------
16
+
17
# Locations of the bundled data files, the base checkpoint and the PEFT adapter.
BASE_PATH = 'data/'
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = BASE_PATH + "earningcall"
DB_PATH = BASE_PATH + "news.db"

hf_token = os.environ.get("HF_TOKEN")
from huggingface_hub import login
# Authenticate only when a token is configured: login(token=None) falls back
# to an interactive prompt, which nobody can answer in a headless app.
if hf_token:
    login(token=hf_token)
25
+
26
@st.cache_data
def load_ticker_data():
    """Read the ticker list CSV and derive the lookup structures used by the UI.

    Returns:
        A 3-tuple of: the list of ``display_name`` values, a dict mapping
        ``display_name`` to ``Ticker``, and a dict mapping ``Ticker`` to
        ``Price``.
    """
    ticker_table = pd.read_csv(BASE_PATH + 'ticker_list.csv')

    by_display_name = ticker_table.set_index('display_name')
    by_ticker = ticker_table.set_index('Ticker')

    display_names = ticker_table['display_name'].tolist()
    display_to_ticker = by_display_name['Ticker'].to_dict()
    ticker_to_price = by_ticker['Price'].to_dict()
    return display_names, display_to_ticker, ticker_to_price
33
+
34
@st.cache_data
def load_company_metrics():
    """Load the NASDAQ-100 metrics CSV as a dict of per-ticker records.

    Returns:
        dict keyed by ``Ticker``; each value is a dict of that row's columns.
        ``Ticker`` stays in the row because the index is set with
        ``drop=False``.
    """
    metrics_table = pd.read_csv(BASE_PATH + 'NASDAQ100_metrics.csv')
    return metrics_table.set_index('Ticker', drop=False).to_dict(orient='index')
40
+
41
@st.cache_data
def load_full_df():
    """Load the combined market metrics CSV and return it as a DataFrame."""
    return pd.read_csv(BASE_PATH + "us_market_metrics_sp500_nasdaq100.csv")
45
+
46
@st.cache_resource  # heavyweight objects such as models are cached as resources
def load_my_model():
    """Load the base Llama model and tokenizer, then attach the PEFT adapter.

    Returns:
        A ``(peft_model, tokenizer)`` tuple.
    """
    model_options = {
        "torch_dtype": torch.bfloat16,
        "trust_remote_code": True,
        "device_map": "balanced",
    }
    base_model = LlamaForCausalLM.from_pretrained(BASE_MODEL_NAME, **model_options)

    tokenizer = LlamaTokenizerFast.from_pretrained(BASE_MODEL_NAME, trust_remote_code=True)
    # The EOS token is reused as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    return PeftModel.from_pretrained(base_model, ADAPTER_PATH), tokenizer
59
+
60
+
61
def get_news_summaries_for_ticker(ticker_symbol: str, db_path=None, limit: int = 5):
    """Return the most recent non-empty news summaries stored for a ticker.

    Args:
        ticker_symbol: Ticker to look up; it is upper-cased before matching.
        db_path: SQLite database file to read. When ``None`` the module-level
            ``DB_PATH`` is used.
        limit: Maximum number of summaries returned, newest ``pubdate`` first.

    Returns:
        A list of summary strings, or ``["최근 뉴스 없음"]`` when no row matches.
    """
    query = """
        SELECT summary
        FROM articles
        WHERE ticker = ?
          AND summary IS NOT NULL AND LENGTH(TRIM(summary)) > 0
        ORDER BY pubdate DESC
        LIMIT ?
    """
    conn = sqlite3.connect(DB_PATH if db_path is None else db_path)
    try:
        rows = conn.execute(query, (ticker_symbol.upper(), limit)).fetchall()
    finally:
        # Close the connection even when the query fails (e.g. missing table).
        conn.close()

    summaries = [row[0] for row in rows]
    return summaries or ["최근 뉴스 없음"]
79
+
80
# Build the reports for every holding in the portfolio
def generate_llm_reports(portfolio_list, override_style=None):
    """Generate one LLM report per distinct portfolio ticker.

    Args:
        portfolio_list: List of holding dicts; each must have a ``'ticker'`` key.
        override_style: When truthy, used as the investor style instead of the
            one returned by ``classify_investment_style``.

    Returns:
        dict mapping ticker to report text, for tickers found in ``company``.
        Tickers missing from ``company`` are skipped with a warning.
    """
    if override_style:
        investor_style = override_style
        st.toast(f"'{investor_style}' 스타일(수동)로 재생성 시작...")
    else:
        srisk, investor_style = classify_investment_style(full_market_df, portfolio_list)
        st.toast(f"srisk: {srisk:.2f} '{investor_style}' 스타일(자동)로 생성 시작...")

    reports = {}
    for item in portfolio_list:
        ticker = item['ticker']

        if ticker in reports:
            # The same ticker can be held in several rows; the report does not
            # depend on the row, so the expensive generation runs only once.
            continue
        if ticker not in company:
            st.warning(f"{ticker} 종목은 NASDAQ100에 없어 건너뜁니다.")
            continue

        reports[ticker] = generate_llm_report(investor_style, ticker)

    return reports
105
+
106
# Generate the LLM report for a single ticker
def generate_llm_report(investor_style, ticker):
    """Build the prompt for one company and return the model's generated report.

    Args:
        investor_style: Style label inserted into the user prompt.
        ticker: Key into the module-level ``company`` dict; also used to look
            up recent news summaries.

    Returns:
        The decoded generation (prompt tokens stripped), or
        ``"오류: 모델 로드 실패"`` when the model or tokenizer is not in the
        session state yet.

    Raises:
        KeyError: If ``ticker`` is not a key of ``company``.
    """
    peft_model = st.session_state.peft_model
    tokenizer = st.session_state.tokenizer
    # Explicit None checks: truthiness of a tokenizer/model object is not a
    # reliable "is it loaded" signal.
    if peft_model is None or tokenizer is None:
        st.error("모델이 로드되지 않았습니다. (generate_llm_report)")
        return "오류: 모델 로드 실패"

    company_data = str(company[ticker])

    news = get_news_summaries_for_ticker(ticker)
    news_text = "\n\n".join([f"- {s}" for s in news])

    # NOTE(review): the heading says "Last 10 Articles" while the news helper
    # returns at most 5 summaries by default. The wording is left untouched
    # because the adapter may have been trained on this exact prompt - confirm.
    user_prompt = f"""Analyze all the provided data and generate a report tailored to the investor's profile.
1. Investor Style: {investor_style}

2. Company Under Review, Key Data from Corporate Filings:
{company_data}

3. Recent News (Last 10 Articles):
{news_text}
"""
    system_prompt = """You are an expert financial analyst. Your mission is to write a concise, objective investment report for a client based on their specific risk profile.

ANALYSIS INSTRUCTIONS:
- Use BOTH the provided financial metrics ("Facts") and recent news headlines ("News").
- Always include the metrics listed under `core_metrics` in Facts. These are the most important indicators for the company/sector.
- Each Key Highlight should integrate at least one financial fact and one news item together (not listed separately).
- Do not hallucinate numbers that are not in Facts.
- Adjust the focus and tone strictly based on the investor's style:

If the style is SAFE (Conservative):
* Focus: Capital preservation and stable income.
* Highlight: Balance sheet strength, liquidity, predictable returns.
* Mention risks (regulatory, legal, earnings decline) first, then cautiously note positives.
* Downplay speculative or uncertain news.

If the style is NEUTRAL (Moderate):
* Focus: A balance between growth and safety.
* Highlight: Strategic trade-offs. Analyze how growth initiatives (from news) interact with financial stability (from facts).
* Present risks and opportunities in equal measure.

If the style is RISKY (Aggressive):
* Focus: High growth potential and maximum returns.
* Highlight: Exciting, forward-looking growth story. Emphasize innovation, expansion, competitive advantages.
* Frame risks as natural volatility on the path to high rewards.
* Place financial facts in the context of supporting aggressive growth.

OUTPUT REPORT TEMPLATE
Report for: A (investor_style) Investor
Company: (company_name)

1. Executive Summary:
(Provide a brief, one-paragraph summary that aligns with the investor's style, integrating at least one key metric and one recent news item.)

2. Key Analysis & Highlights:
(5–7 bullet points. Each bullet must combine a financial metric with a relevant news event, written from the perspective of the given investor style.)

3. Concluding Remark:
(One or two sentences, neutrally summarizing the company’s current standing for this type of investor.
Do NOT provide direct financial advice or buy/sell recommendations.)

IMPORTANT:
- Keep the tone professional and concise.
- Reports must be grounded in Facts and News only.
- Different investor styles should produce clearly differentiated tone and emphasis.
"""

    message = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    tokens = tokenizer.apply_chat_template(
        message,
        tokenize=True,
        padding=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # The chat template returns a CPU tensor; generation expects the prompt on
    # the same device as the model's input embeddings.
    tokens = tokens.to(peft_model.device)
    input_ids_length = tokens.shape[1]

    with torch.no_grad():
        res_base = peft_model.generate(tokens, max_new_tokens=1024)

    # Drop the prompt tokens so only the newly generated text is decoded.
    return tokenizer.decode(res_base[0, input_ids_length:], skip_special_tokens=True)
188
+
189
+
190
# ----------------------------------------------------------------------
# 1. Session state initialisation
# ----------------------------------------------------------------------
# st.session_state keeps its values across Streamlit reruns, so every key is
# seeded only when it is not present yet.
if 'portfolio' not in st.session_state:
    st.session_state['portfolio'] = []  # the user's holdings (list of dicts)

# last_report holds the generated reports; the model and tokenizer are
# filled in once loading has finished.
for _state_key in ('last_report', 'peft_model', 'tokenizer'):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = None

# ----------------------------------------------------------------------
# 2. Basic page setup
# ----------------------------------------------------------------------
st.set_page_config(page_title="AI 주식 포트폴리오 분석", layout="wide")
st.title("🤖 AI 주식 포트폴리오 보고서 생성기")
st.write("NASDAQ 100 종목을 검색하여 포트폴리오를 구성하고, 맞춤형 AI 보고서를 받아보세요.")

(
    TICKER_OPTIONS_LIST,
    DISPLAY_TO_TICKER_MAP,
    TICKER_TO_PRICE_MAP,
) = load_ticker_data()
company = load_company_metrics()
full_market_df = load_full_df()  # passed to classify_investment_style
214
+
215
# ----------------------------------------------------------------------
# 3. Input section (add a holding)
# ----------------------------------------------------------------------
st.subheader("1. 보유 종목 추가하기")

# Two columns keep the ticker picker and the quantity box on one row.
col1, col2 = st.columns([2, 1])

with col1:
    # A selectbox doubles as a searchable input.
    selected_display = st.selectbox(
        "종목 검색 (NASDAQ100 or S&P500 티커 또는 기업명)",
        options=TICKER_OPTIONS_LIST,
        index=None,
        placeholder="티커를 검색하거나 선택하세요 (예: AAPL 또는 Apple)"
    )
with col2:
    quantity = st.number_input("보유 수량 (주)", min_value=0.01, step=0.1, format="%.2f")

# "Add to portfolio" button
if st.button("➕ 포트폴리오에 추가", use_container_width=True):
    selected_ticker = DISPLAY_TO_TICKER_MAP.get(selected_display) if selected_display else None
    current_price = TICKER_TO_PRICE_MAP.get(selected_ticker) if selected_ticker else None

    # A ticker without a usable price (absent from the price map, or an empty
    # cell in the CSV) cannot be valued, so it is rejected like a missing
    # selection instead of failing on the multiplication below.
    has_price = current_price is not None and not pd.isna(current_price)

    if selected_ticker and has_price:
        st.session_state.portfolio.append({
            "ticker": selected_ticker,
            "quantity": quantity,
            "price": current_price,
            "total_value": quantity * current_price,
        })
        st.success(f"{selected_ticker} {quantity}주 (현재가 ${current_price:,.2f})를 포트폴리오에 추가했습니다.")
    else:
        st.warning("종목, 수량을 모두 올바르게 입력하세요.")
251
+
252
# ----------------------------------------------------------------------
# 4. Portfolio summary and report generation (sketch layout)
# ----------------------------------------------------------------------
# NOTE(review): this places the reset button and sections 5-6 inside
# col_controls; that nesting is an assumption - confirm the intended layout.
st.subheader("2. 포트폴리오 요약 및 보고서 생성")

col_chart, col_controls = st.columns(2, gap="large")

with col_chart:
    st.markdown("### 📊 포트폴리오 구성")
    if st.session_state.portfolio:
        df = pd.DataFrame(st.session_state.portfolio)

        # Plotly pie chart; hole=.3 renders it as a donut.
        fig = px.pie(
            df,
            values='total_value',
            names='ticker',
            hole=.3
        )
        fig.update_traces(textposition='inside', textinfo='percent+label')
        st.plotly_chart(fig, use_container_width=True)

    else:
        st.info("종목을 추가하면 여기에 파이 차트가 표시됩니다.")

with col_controls:
    st.markdown("### ✏️ 포트폴리오 수정 (삭제)")
    if st.session_state.portfolio:
        df = pd.DataFrame(st.session_state.portfolio)

        edited_df = st.data_editor(
            df,
            column_config={
                "ticker": st.column_config.TextColumn("티커", disabled=True),
                "quantity": st.column_config.NumberColumn("수량", min_value=0.01, format="%.2f"),
                "price": st.column_config.NumberColumn("현재가", disabled=True, format="$%.2f"),
                "total_value": st.column_config.NumberColumn("총 가치", disabled=True, format="$%.2f"),
            },
            hide_index=True,
            num_rows="dynamic",
            key="portfolio_editor"
        )

        if not df.equals(edited_df):
            # Rows added in the editor have no ticker/price (those columns are
            # read-only), so incomplete rows are dropped rather than stored.
            cleaned_df = edited_df.dropna(subset=["ticker", "quantity", "price"]).copy()
            # total_value is read-only in the editor; recompute it so an edited
            # quantity is reflected in the pie chart.
            cleaned_df["total_value"] = cleaned_df["quantity"] * cleaned_df["price"]
            cleaned_rows = cleaned_df.to_dict('records')

            # Rerun only when the stored portfolio really changes; otherwise a
            # pending empty row in the editor would retrigger this branch forever.
            if cleaned_rows != st.session_state.portfolio:
                st.session_state.portfolio = cleaned_rows
                st.toast("포트폴리오가 수정(삭제)되었습니다.")
                st.rerun()

    if st.button("🔄 포트폴리오 전체 초기화", use_container_width=True, type="secondary"):
        st.session_state.portfolio = []
        st.session_state.last_report = None
        st.toast("포트폴리오가 초기화되었습니다.")
        st.rerun()

    # ------------------------------------------------------------------
    # 5. Report generation button (main LLM call)
    # ------------------------------------------------------------------
    if st.button("🚀 AI 보고서 생성하기", type="primary", use_container_width=True, disabled=(not st.session_state.portfolio)):
        # Explicit None checks instead of relying on object truthiness.
        model_ready = (
            st.session_state.peft_model is not None
            and st.session_state.tokenizer is not None
        )
        if model_ready:
            with st.spinner("AI가 포트폴리오를 분석하고 보고서를 작성 중입니다..."):
                generated_reports = generate_llm_reports(
                    st.session_state.portfolio)
                st.session_state.last_report = generated_reports
        else:
            # The model has not finished loading yet.
            st.warning("모델이 아직 로드 중입니다. 잠시 후 다시 시도해주세요.")
    st.divider()

    # ------------------------------------------------------------------
    # 6. Report regeneration button (main LLM call)
    # ------------------------------------------------------------------
    if st.session_state.last_report:
        st.markdown("##### 🔄 다른 성향으로 보고서 다시 뽑기")

        col_style, col_regen = st.columns([3, 2])

        with col_style:
            # NOTE(review): these labels are passed to the prompt as the
            # investor style, while the system prompt describes SAFE / NEUTRAL /
            # RISKY - confirm the model handles the Korean labels as intended.
            new_style = st.selectbox(
                "보고서 성향 선택",
                ["안정형", "공격형", "중립형"],
                key="report_style_select",
                label_visibility="collapsed"  # hide the label
            )

        with col_regen:
            if st.button(f"'{new_style}' 스타일로 재생성", use_container_width=True):
                with st.spinner(f"'{new_style}' 스타일로 보고서를 다시 작성 중입니다..."):
                    regenerated_reports = generate_llm_reports(
                        st.session_state.portfolio,
                        override_style=new_style
                    )
                    st.session_state.last_report = regenerated_reports
                st.rerun()  # refresh the page immediately
346
+
347
# ----------------------------------------------------------------------
# 7. Report output section
# ----------------------------------------------------------------------
st.divider()

report_data = st.session_state.last_report or {}
# Portfolio order, one tab per ticker: dict.fromkeys removes duplicate
# holdings of the same ticker while preserving first-seen order.
ordered_tickers = list(dict.fromkeys(
    item['ticker'] for item in st.session_state.portfolio if item['ticker'] in report_data
))

# st.tabs rejects an empty label list, which happens when every reported
# ticker has since been removed from the portfolio; fall back to the hint.
if ordered_tickers:
    st.subheader("📑 생성된 AI 보고서")

    for tab, ticker in zip(st.tabs(ordered_tickers), ordered_tickers):
        with tab:
            st.markdown(report_data[ticker])  # render the LLM's markdown report
else:
    st.info("보고서를 생성하면 이 곳에 결과가 표시됩니다.")
364
+
365
# ----------------------------------------------------------------------
# 8. LLM model loading (runs last, after the whole UI has been drawn)
# ----------------------------------------------------------------------

# The model and tokenizer are kept in st.session_state.
if 'peft_model' not in st.session_state:
    st.session_state.peft_model = None
if 'tokenizer' not in st.session_state:
    st.session_state.tokenizer = None

# Load the model when the session does not have one yet (first run).
if st.session_state.peft_model is None:
    loaded_model = loaded_tokenizer = None
    # The UI above is already rendered; show a spinner while loading.
    with st.spinner("AI 분석 모델(LLM)을 로드 중입니다... (최초 실행 시 1-2분 소요)"):
        try:
            loaded_model, loaded_tokenizer = load_my_model()
        except Exception as load_error:  # top-level boundary: report, don't crash the page
            st.exception(load_error)

    if loaded_model is not None and loaded_tokenizer is not None:
        st.session_state.peft_model = loaded_model
        st.session_state.tokenizer = loaded_tokenizer
        # Rerun once so the spinner disappears and the buttons see the model.
        # Kept outside the try block so the rerun signal is never intercepted.
        st.rerun()

# Module-level aliases for the model stored in the session.
peft_model = st.session_state.peft_model
tokenizer = st.session_state.tokenizer

# Reached only when loading failed above.
if peft_model is None or tokenizer is None:
    st.error("모델 로딩에 실패했습니다. 앱을 새로고침하거나 관리자에게 문의하세요.")