# confer_tech2 / src/streamlit_app.py
# Last commit: 5fffe72 "Update src/streamlit_app.py" (author: molba2see)
import os
os.environ['HF_HOME'] = '/tmp/hf_cache'
import streamlit as st
import pandas as pd
import plotly.express as px
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizerFast
import sqlite3
# 타 코드에서 모듈 불러오기
from analyze_portfolio_risk import classify_investment_style # 사용자 성향 파악
# ----------------------------------------------------------------------
# 0. (Required) LLM model locations and NASDAQ100 data paths
# ----------------------------------------------------------------------
BASE_PATH = 'src/data/'
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = BASE_PATH + "earningcall"  # PEFT adapter dir (loaded with PeftModel.from_pretrained)
DB_PATH = BASE_PATH + "news.db"           # SQLite database holding the `articles` table
hf_token = os.environ.get("HF_TOKEN")
if hf_token:
    # login(token=None) falls back to an interactive prompt, which would block a
    # headless Streamlit server, so only authenticate when a token is configured.
    # NOTE(review): without a token, a gated base model presumably has to be
    # available in the local HF cache already — confirm for the deployment.
    from huggingface_hub import login
    login(token=hf_token)
@st.cache_data
def load_ticker_data():
    """Read ``ticker_list.csv`` and build the lookups used by the ticker picker.

    Returns:
        A 3-tuple ``(options, display_to_ticker, ticker_to_price)``:
        the ``display_name`` column as a list (file order), a
        ``display_name -> Ticker`` dict and a ``Ticker -> Price`` dict.
    """
    ticker_table = pd.read_csv(BASE_PATH + 'ticker_list.csv')
    indexed_by_display = ticker_table.set_index('display_name')
    indexed_by_ticker = ticker_table.set_index('Ticker')
    return (
        ticker_table['display_name'].tolist(),
        indexed_by_display['Ticker'].to_dict(),
        indexed_by_ticker['Price'].to_dict(),
    )
@st.cache_data
def load_company_metrics():
    """Load ``NASDAQ100_metrics.csv`` as a ``{Ticker: {column: value, ...}}`` dict.

    The ``Ticker`` column is kept inside each row dict as well as being the key.
    """
    metrics_table = pd.read_csv(BASE_PATH + 'NASDAQ100_metrics.csv')
    return metrics_table.set_index('Ticker', drop=False).to_dict(orient='index')
@st.cache_data
def load_full_df():
    """Return the S&P 500 + NASDAQ-100 metrics table (input to the risk classifier)."""
    return pd.read_csv(BASE_PATH + "us_market_metrics_sp500_nasdaq100.csv")
@st.cache_resource  # heavy objects like the model are cached as a shared resource
def load_my_model():
    """Load the Llama-3 base model, attach the PEFT adapter and load the tokenizer.

    Returns:
        ``(peft_model, tokenizer)``. The tokenizer's pad token is set to its EOS
        token because no dedicated pad token is configured here.
    """
    tokenizer = LlamaTokenizerFast.from_pretrained(BASE_MODEL_NAME, trust_remote_code=True)
    tokenizer.pad_token = tokenizer.eos_token

    base_llm = LlamaForCausalLM.from_pretrained(
        BASE_MODEL_NAME,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
        device_map="balanced",  # let accelerate spread layers over the available devices
    )
    adapted_llm = PeftModel.from_pretrained(base_llm, ADAPTER_PATH)
    return adapted_llm, tokenizer
def get_news_summaries_for_ticker(ticker_symbol: str, db_path=None, limit=5):
    """Return the newest non-empty news summaries stored for a ticker.

    Args:
        ticker_symbol: ticker to look up. It is upper-cased before matching, so
            tickers in ``articles.ticker`` are presumably stored upper-case.
        db_path: SQLite database file; ``None`` (default) means ``DB_PATH``.
        limit: maximum number of summaries to return (default 5).

    Returns:
        List of summary strings, newest first by ``pubdate``. Ordering assumes
        ``pubdate`` sorts chronologically as stored (e.g. ISO-8601 text) — TODO
        confirm. Returns ``["최근 뉴스 없음"]`` ("no recent news") when nothing
        matches.
    """
    if db_path is None:
        db_path = DB_PATH
    query = """
    SELECT summary
    FROM articles
    WHERE ticker = ?
    AND summary IS NOT NULL AND LENGTH(TRIM(summary)) > 0
    ORDER BY pubdate DESC
    LIMIT ?
    """
    conn = sqlite3.connect(db_path)
    try:
        rows = conn.execute(query, (ticker_symbol.upper(), limit)).fetchall()
    finally:
        # Close even when the query fails (e.g. missing table) so handles don't leak.
        conn.close()
    summaries = [row[0] for row in rows]
    return summaries or ["최근 뉴스 없음"]
# Generate the full set of reports (one per supported ticker).
def generate_llm_reports(portfolio_list, override_style=None):
    """Generate one LLM report per distinct, supported ticker in the portfolio.

    Args:
        portfolio_list: list of position dicts, each with at least a ``'ticker'``
            key (this app also stores quantity/price/total_value).
        override_style: investor style chosen manually in the UI ("SAFE",
            "RISKY" or "NEUTRAL"). When falsy, the style is classified
            automatically from ``full_market_df`` and the portfolio.

    Returns:
        dict mapping ticker -> report text. Tickers absent from ``company``
        are skipped with a warning.
    """
    if override_style:
        investor_style = override_style
        st.toast(f"'{investor_style}' 스타일(수동)로 재생성 시작...")
    else:
        srisk, investor_style = classify_investment_style(full_market_df, portfolio_list)
        st.toast(f"srisk: {srisk:.2f} '{investor_style}' 스타일(자동)로 생성 시작...")

    reports = {}
    # dict.fromkeys de-duplicates while keeping portfolio order, so a ticker that
    # appears twice is only sent through the (slow) LLM once.
    for ticker in dict.fromkeys(item['ticker'] for item in portfolio_list):
        if ticker not in company:
            st.warning(f"{ticker} 종목은 NASDAQ100에 없어 건너뜁니다.")
            continue
        reports[ticker] = generate_llm_report(investor_style, ticker)
    return reports
# Generate the LLM report for a single ticker.
def generate_llm_report(investor_style, ticker):
    """Build the analyst prompt for ``ticker`` and generate a report with the PEFT model.

    Args:
        investor_style: style label inserted into the prompt (e.g. "SAFE").
        ticker: key into ``company``; also used to fetch recent news summaries.

    Returns:
        The generated report text (prompt tokens excluded), or an error string
        when the model/tokenizer have not been loaded into session state.
    """
    peft_model = st.session_state.peft_model
    tokenizer = st.session_state.tokenizer
    if peft_model is None or tokenizer is None:
        st.error("모델이 로드되지 않았습니다. (generate_llm_report)")
        return "오류: 모델 로드 실패"

    company_data = str(company[ticker])
    news = get_news_summaries_for_ticker(ticker)
    news_text = "\n\n".join(f"- {s}" for s in news)
    # The news helper returns only a handful of the newest summaries, so the
    # prompt must not promise a fixed article count.
    user_prompt = f"""Analyze all the provided data and generate a report tailored to the investor's profile.
1. Investor Style: {investor_style}
2. Company Under Review, Key Data from Corporate Filings:
{company_data}
3. Recent News (Most Recent Articles):
{news_text}
"""
    system_prompt = """You are an expert financial analyst. Your mission is to write a concise, objective investment report for a client based on their specific risk profile.
ANALYSIS INSTRUCTIONS:
- Use BOTH the provided financial metrics ("Facts") and recent news headlines ("News").
- Always include the metrics listed under `core_metrics` in Facts. These are the most important indicators for the company/sector.
- Each Key Highlight should integrate at least one financial fact and one news item together (not listed separately).
- Do not hallucinate numbers that are not in Facts.
- Adjust the focus and tone strictly based on the investor's style:
If the style is SAFE (Conservative):
* Focus: Capital preservation and stable income.
* Highlight: Balance sheet strength, liquidity, predictable returns.
* Mention risks (regulatory, legal, earnings decline) first, then cautiously note positives.
* Downplay speculative or uncertain news.
If the style is NEUTRAL (Moderate):
* Focus: A balance between growth and safety.
* Highlight: Strategic trade-offs. Analyze how growth initiatives (from news) interact with financial stability (from facts).
* Present risks and opportunities in equal measure.
If the style is RISKY (Aggressive):
* Focus: High growth potential and maximum returns.
* Highlight: Exciting, forward-looking growth story. Emphasize innovation, expansion, competitive advantages.
* Frame risks as natural volatility on the path to high rewards.
* Place financial facts in the context of supporting aggressive growth.
OUTPUT REPORT TEMPLATE
Report for: A (investor_style) Investor
Company: (company_name)
1. Executive Summary:
(Provide a brief, one-paragraph summary that aligns with the investor's style, integrating at least one key metric and one recent news item.)
2. Key Analysis & Highlights:
(5–7 bullet points. Each bullet must combine a financial metric with a relevant news event, written from the perspective of the given investor style.)
3. Concluding Remark:
(One or two sentences, neutrally summarizing the company’s current standing for this type of investor.
Do NOT provide direct financial advice or buy/sell recommendations.)
IMPORTANT:
- Keep the tone professional and concise.
- Reports must be grounded in Facts and News only.
- Different investor styles should produce clearly differentiated tone and emphasis.
"""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        padding=True,
        add_generation_prompt=True,
        return_tensors="pt",
    )
    # Put the prompt on the model's (first) device; a CPU tensor fed to a GPU
    # model triggers a device-mismatch warning in generate().
    input_ids = input_ids.to(peft_model.device)
    prompt_length = input_ids.shape[1]
    with torch.no_grad():
        output_ids = peft_model.generate(
            input_ids,
            # Single unpadded sequence: every position is real. Passing the mask and
            # pad id explicitly avoids generate() guessing, since pad == eos here.
            attention_mask=torch.ones_like(input_ids),
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=1024,
        )
    # Decode only the newly generated tokens, not the echoed prompt.
    return tokenizer.decode(output_ids[0, prompt_length:], skip_special_tokens=True)
# ----------------------------------------------------------------------
# 1. Session-state initialisation
# ----------------------------------------------------------------------
# st.session_state keeps values across Streamlit reruns of this script.
# Each key is seeded only when absent, so existing values survive reruns.
for _state_key, _initial_value in (
    ('portfolio', []),       # the user's holdings (list of position dicts)
    ('last_report', None),   # most recently generated reports
    ('peft_model', None),    # LLM, loaded lazily near the end of the script
    ('tokenizer', None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _initial_value
# ----------------------------------------------------------------------
# 2. Basic page setup and reference-data loading
# ----------------------------------------------------------------------
st.set_page_config(page_title="AI 주식 포트폴리오 분석", layout="wide")
st.title("🤖 AI 주식 포트폴리오 보고서 생성기")
st.write("NASDAQ 100 종목을 검색하여 포트폴리오를 구성하고, 맞춤형 AI 보고서를 받아보세요.")

# The loaders are st.cache_data functions, so reruns reuse the parsed CSVs.
ticker_lookups = load_ticker_data()
TICKER_OPTIONS_LIST, DISPLAY_TO_TICKER_MAP, TICKER_TO_PRICE_MAP = ticker_lookups
company = load_company_metrics()   # Ticker -> metrics row (used by the report helpers)
full_market_df = load_full_df()    # market-wide table for the S-risk classifier
# ----------------------------------------------------------------------
# 3. Input section (add a holding)
# ----------------------------------------------------------------------
st.subheader("1. 보유 종목 추가하기")
# Split the row into columns for a tidier layout.
col1, col2 = st.columns([2, 1])
with col1:
    # `selectbox` doubles as a searchable input box.
    selected_display = st.selectbox(
        "종목 검색 (NASDAQ100 or S&P500 티커 또는 기업명)",
        options=TICKER_OPTIONS_LIST,
        index=None,
        placeholder="티커를 검색하거나 선택하세요 (예: AAPL 또는 Apple)"
    )
with col2:
    quantity = st.number_input("보유 수량 (주)", min_value=0.01, step=0.1, format="%.2f")

# "Add to portfolio" button.
if st.button("➕ 포트폴리오에 추가", use_container_width=True):
    selected_ticker = DISPLAY_TO_TICKER_MAP.get(selected_display) if selected_display else None
    current_price = TICKER_TO_PRICE_MAP.get(selected_ticker) if selected_ticker else None
    if not selected_ticker:
        st.warning("종목, 수량을 모두 올바르게 입력하세요.")
    elif current_price is None or pd.isna(current_price):
        # Without a price, quantity * price would raise or put NaN into the chart.
        st.warning(f"{selected_ticker} 종목의 현재가 정보가 없어 추가할 수 없습니다.")
    else:
        existing = next(
            (pos for pos in st.session_state.portfolio if pos["ticker"] == selected_ticker),
            None,
        )
        if existing is None:
            st.session_state.portfolio.append({
                "ticker": selected_ticker,
                "quantity": quantity,
                "price": current_price,
                "total_value": quantity * current_price
            })
        else:
            # Merge into the existing position instead of creating a duplicate row
            # (duplicates would mean duplicate report tabs and repeated LLM calls).
            existing["quantity"] += quantity
            existing["price"] = current_price
            existing["total_value"] = existing["quantity"] * current_price
        st.success(f"{selected_ticker} {quantity}주 (현재가 ${current_price:,.2f})를 포트폴리오에 추가했습니다.")
# ----------------------------------------------------------------------
# 4. Portfolio summary and report generation (sketch layout)
# ----------------------------------------------------------------------
st.subheader("2. 포트폴리오 요약 및 보고서 생성")
col_chart, col_controls = st.columns(2, gap="large")
with col_chart:
    st.markdown("### 📊 포트폴리오 구성")
    if st.session_state.portfolio:
        chart_df = pd.DataFrame(st.session_state.portfolio)
        # Plotly pie chart (close to the sketch), drawn as a donut.
        fig = px.pie(
            chart_df,
            values='total_value',
            names='ticker',
            hole=.3  # donut style
        )
        fig.update_traces(textposition='inside', textinfo='percent+label')
        st.plotly_chart(fig, use_container_width=True)
    else:
        st.info("종목을 추가하면 여기에 파이 차트가 표시됩니다.")
with col_controls:
    st.markdown("### ✏️ 포트폴리오 수정 (삭제)")
    if st.session_state.portfolio:
        df = pd.DataFrame(st.session_state.portfolio)
        edited_df = st.data_editor(
            df,
            column_config={
                "ticker": st.column_config.TextColumn("티커", disabled=True),
                "quantity": st.column_config.NumberColumn("수량", min_value=0.01, format="%.2f"),
                "price": st.column_config.NumberColumn("현재가", disabled=True, format="$%.2f"),
                "total_value": st.column_config.NumberColumn("총 가치", disabled=True, format="$%.2f"),
            },
            hide_index=True,
            num_rows="dynamic",
            key="portfolio_editor"
        )
        # Rows added through the dynamic editor cannot get a ticker (that column
        # is read-only) or a price, so drop incomplete rows. Then recompute
        # total_value: only quantity is editable, and without this the stored
        # total (which drives the pie chart) would go stale after an edit.
        cleaned_df = edited_df.dropna(subset=["ticker", "quantity", "price"]).copy()
        cleaned_df["total_value"] = cleaned_df["quantity"] * cleaned_df["price"]
        if not df.equals(cleaned_df):
            # Store the edited/deleted rows back as a list of dicts in session state.
            st.session_state.portfolio = cleaned_df.to_dict('records')
            st.toast("포트폴리오가 수정(삭제)되었습니다.")
            st.rerun()
    if st.button("🔄 포트폴리오 전체 초기화", use_container_width=True, type="secondary"):
        st.session_state.portfolio = []
        st.session_state.last_report = None
        st.toast("포트폴리오가 초기화되었습니다.")
        st.rerun()
# ----------------------------------------------------------------------
# 5. Report-generation button (main LLM call)
# ----------------------------------------------------------------------
generate_clicked = st.button(
    "🚀 AI 보고서 생성하기",
    type="primary",
    use_container_width=True,
    disabled=(not st.session_state.portfolio),
)
if generate_clicked:
    model_ready = st.session_state.peft_model and st.session_state.tokenizer
    if not model_ready:
        # The model is loaded at the very end of the script and may not be ready yet.
        st.warning("모델이 아직 로드 중입니다. 잠시 후 다시 시도해주세요.")
    else:
        with st.spinner("AI가 포트폴리오를 분석하고 보고서를 작성 중입니다..."):
            st.session_state.last_report = generate_llm_reports(st.session_state.portfolio)
# ----------------------------------------------------------------------
# 6. Regenerate-report controls (main LLM call with a manual style)
# ----------------------------------------------------------------------
if st.session_state.last_report:
    st.markdown("##### 🔄 다른 성향으로 보고서 다시 뽑기")
    col_style, col_regen = st.columns([3, 2])
    with col_style:
        new_style = st.selectbox(
            "보고서 성향 선택",
            ["SAFE", "RISKY", "NEUTRAL"],
            key="report_style_select",
            label_visibility="collapsed"  # hide the label
        )
    with col_regen:
        # A report can outlive its holdings (rows deleted in the editor). Disable
        # regeneration for an empty portfolio, as the main generate button does:
        # there is nothing to report on.
        if st.button(
            f"'{new_style}' 스타일로 재생성",
            use_container_width=True,
            disabled=(not st.session_state.portfolio),
        ):
            with st.spinner(f"'{new_style}' 스타일로 보고서를 다시 작성 중입니다..."):
                regenerated_reports = generate_llm_reports(
                    st.session_state.portfolio,
                    override_style=new_style
                )
            st.session_state.last_report = regenerated_reports
            st.rerun()  # refresh the page immediately
# ----------------------------------------------------------------------
# 7. Report output section
# ----------------------------------------------------------------------
st.divider()
report_data = st.session_state.last_report
if report_data:
    st.subheader("📑 생성된 AI 보고서")
    # Unique tickers, in portfolio order, that still have a report. De-duplicate
    # so a repeated ticker does not produce repeated tabs.
    ordered_tickers = list(dict.fromkeys(
        item['ticker'] for item in st.session_state.portfolio if item['ticker'] in report_data
    ))
    if ordered_tickers:
        for tab, ticker in zip(st.tabs(ordered_tickers), ordered_tickers):
            with tab:
                st.markdown(report_data[ticker])  # render the LLM's markdown report
    else:
        # st.tabs rejects an empty label list. This happens when the reported
        # holdings were all removed from the portfolio after generation.
        st.info("현재 포트폴리오에 해당하는 보고서가 없습니다. 보고서를 다시 생성해주세요.")
else:
    st.info("보고서를 생성하면 이 곳에 결과가 표시됩니다.")
# ----------------------------------------------------------------------
# 8. LLM model loading (runs last, after the whole UI has been drawn)
# ----------------------------------------------------------------------
# peft_model and tokenizer live in st.session_state. Load them when this session
# does not have them yet (e.g. on first run).
if st.session_state.get('peft_model') is None:
    try:
        # The UI above is already drawn, so show a spinner while the model loads.
        with st.spinner("AI 분석 모델(LLM)을 로드 중입니다... (최초 실행 시 1-2분 소요)"):
            loaded_model, loaded_tokenizer = load_my_model()
    except Exception as exc:  # top-level UI boundary: show the failure instead of crashing
        st.exception(exc)
    else:
        st.session_state.peft_model = loaded_model
        st.session_state.tokenizer = loaded_tokenizer
        # Rerun once so the spinner disappears and the buttons see the loaded model.
        st.rerun()

# Expose the session's model as module-level names.
peft_model = st.session_state.get('peft_model')
tokenizer = st.session_state.get('tokenizer')
# Tell the user when loading failed. The generate handlers check the same session
# state and refuse to run without a model.
if peft_model is None or tokenizer is None:
    st.error("모델 로딩에 실패했습니다. 앱을 새로고침하거나 관리자에게 문의하세요.")