confer_tech2 / streamlit_app.py
molba2see's picture
Rename app.py to streamlit_app.py
5faec39 verified
raw
history blame
16.8 kB
import os
import streamlit as st
import pandas as pd
import plotly.express as px
import torch
from peft import PeftModel
from transformers import LlamaForCausalLM, LlamaTokenizerFast
import sqlite3
import yfinance as yf
# 타 코드에서 모듈 불러오기
from analyze_portfolio_risk import classify_investment_style # 사용자 성향 파악
# ----------------------------------------------------------------------
# 0. (필수) LLM 모델 로드 및 NASDAQ100 리스트 준비
# ----------------------------------------------------------------------
BASE_PATH = 'data/'
BASE_MODEL_NAME = "meta-llama/Meta-Llama-3-8B-Instruct"
ADAPTER_PATH = BASE_PATH + "earningcall"
DB_PATH = BASE_PATH + "news.db"
hf_token = os.environ.get("HF_TOKEN")
from huggingface_hub import login
login(token=hf_token)
@st.cache_data
def load_ticker_data():
TICKERS = pd.read_csv(BASE_PATH + 'ticker_list.csv')
TICKER_OPTIONS_LIST = TICKERS['display_name'].tolist()
DISPLAY_TO_TICKER_MAP = TICKERS.set_index('display_name')['Ticker'].to_dict()
TICKER_TO_PRICE_MAP = TICKERS.set_index('Ticker')['Price'].to_dict()
return TICKER_OPTIONS_LIST, DISPLAY_TO_TICKER_MAP, TICKER_TO_PRICE_MAP
@st.cache_data
def load_company_metrics():
nasdaq = pd.read_csv(BASE_PATH + 'NASDAQ100_metrics.csv')
nasdaq = nasdaq.set_index('Ticker', drop = False)
company = nasdaq.to_dict(orient='index')
return company
@st.cache_data
def load_full_df():
full_df = pd.read_csv(BASE_PATH + "us_market_metrics_sp500_nasdaq100.csv")
return full_df
@st.cache_resource # 모델처럼 무거운 객체는 캐시
def load_my_model():
base_model = LlamaForCausalLM.from_pretrained(
BASE_MODEL_NAME,
torch_dtype=torch.bfloat16,
trust_remote_code=True,
device_map="balanced",)
tokenizer = LlamaTokenizerFast.from_pretrained(BASE_MODEL_NAME, trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token
peft_model = PeftModel.from_pretrained(base_model, ADAPTER_PATH)
return peft_model, tokenizer
def get_news_summaries_for_ticker(ticker_symbol: str):
conn = sqlite3.connect(DB_PATH)
query = """
SELECT summary
FROM articles
WHERE ticker = ?
AND summary IS NOT NULL AND LENGTH(TRIM(summary)) > 0
ORDER BY pubdate DESC
LIMIT 5
"""
cursor = conn.cursor()
rows = cursor.execute(query, (ticker_symbol.upper(),)).fetchall()
summaries = [row[0] for row in rows]
conn.close()
if not summaries:
return ["최근 뉴스 없음"]
return summaries
# 리포트 전체 생성
def generate_llm_reports(portfolio_list, override_style=None):
df = pd.DataFrame(portfolio_list)
df['total_value'] = df['quantity'] * df['price']
if override_style:
investor_style = override_style
st.toast(f"'{investor_style}' 스타일(수동)로 재생성 시작...")
else:
srisk, investor_style = classify_investment_style(full_market_df, portfolio_list)
st.toast(f"srisk: {srisk:.2f} '{investor_style}' 스타일(자동)로 생성 시작...")
reports = {}
for item in portfolio_list:
ticker = item['ticker']
if ticker in company:
report = generate_llm_report(investor_style, ticker)
reports[ticker] = report
else:
st.warning(f"{ticker} 종목은 NASDAQ100에 없어 건너뜁니다.")
continue
return reports
# LLM 리포트 생성 함수
def generate_llm_report(investor_style, ticker):
peft_model = st.session_state.peft_model
tokenizer = st.session_state.tokenizer
if not peft_model or not tokenizer:
st.error("모델이 로드되지 않았습니다. (generate_llm_report)")
return "오류: 모델 로드 실패"
company_data = str(company[ticker])
news = get_news_summaries_for_ticker(ticker)
news_text = "\n\n".join([f"- {s}" for s in news])
user_prompt = f"""Analyze all the provided data and generate a report tailored to the investor's profile.
1. Investor Style: {investor_style}
2. Company Under Review, Key Data from Corporate Filings:
{company_data}
3. Recent News (Last 10 Articles):
{news_text}
"""
system_prompt = """You are an expert financial analyst. Your mission is to write a concise, objective investment report for a client based on their specific risk profile.
ANALYSIS INSTRUCTIONS:
- Use BOTH the provided financial metrics ("Facts") and recent news headlines ("News").
- Always include the metrics listed under `core_metrics` in Facts. These are the most important indicators for the company/sector.
- Each Key Highlight should integrate at least one financial fact and one news item together (not listed separately).
- Do not hallucinate numbers that are not in Facts.
- Adjust the focus and tone strictly based on the investor's style:
If the style is SAFE (Conservative):
* Focus: Capital preservation and stable income.
* Highlight: Balance sheet strength, liquidity, predictable returns.
* Mention risks (regulatory, legal, earnings decline) first, then cautiously note positives.
* Downplay speculative or uncertain news.
If the style is NEUTRAL (Moderate):
* Focus: A balance between growth and safety.
* Highlight: Strategic trade-offs. Analyze how growth initiatives (from news) interact with financial stability (from facts).
* Present risks and opportunities in equal measure.
If the style is RISKY (Aggressive):
* Focus: High growth potential and maximum returns.
* Highlight: Exciting, forward-looking growth story. Emphasize innovation, expansion, competitive advantages.
* Frame risks as natural volatility on the path to high rewards.
* Place financial facts in the context of supporting aggressive growth.
OUTPUT REPORT TEMPLATE
Report for: A (investor_style) Investor
Company: (company_name)
1. Executive Summary:
(Provide a brief, one-paragraph summary that aligns with the investor's style, integrating at least one key metric and one recent news item.)
2. Key Analysis & Highlights:
(5–7 bullet points. Each bullet must combine a financial metric with a relevant news event, written from the perspective of the given investor style.)
3. Concluding Remark:
(One or two sentences, neutrally summarizing the company’s current standing for this type of investor.
Do NOT provide direct financial advice or buy/sell recommendations.)
IMPORTANT:
- Keep the tone professional and concise.
- Reports must be grounded in Facts and News only.
- Different investor styles should produce clearly differentiated tone and emphasis.
"""
message = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt}]
tokens = tokenizer.apply_chat_template(message,tokenize=True,padding=True,add_generation_prompt=True, return_tensors="pt")
input_ids_length = tokens.shape[1]
with torch.no_grad():
res_base = peft_model.generate(tokens, max_new_tokens=1024)
result = tokenizer.decode(res_base[0, input_ids_length:], skip_special_tokens=True)
return result
# ----------------------------------------------------------------------
# 1. 세션 상태(Session State) 초기화
# ----------------------------------------------------------------------
# st.session_state : 스트림릿이 재실행되어도 값을 유지하는 마법의 변수
if 'portfolio' not in st.session_state:
st.session_state.portfolio = [] # 사용자의 포트폴리오를 저장할 리스트
if 'last_report' not in st.session_state:
st.session_state.last_report = None # 생성된 보고서를 저장할 변수
if 'peft_model' not in st.session_state:
st.session_state.peft_model = None
if 'tokenizer' not in st.session_state:
st.session_state.tokenizer = None
# ----------------------------------------------------------------------
# 2. 페이지 기본 설정
# ----------------------------------------------------------------------
st.set_page_config(page_title="AI 주식 포트폴리오 분석", layout="wide")
st.title("🤖 AI 주식 포트폴리오 보고서 생성기")
st.write("NASDAQ 100 종목을 검색하여 포트폴리오를 구성하고, 맞춤형 AI 보고서를 받아보세요.")
TICKER_OPTIONS_LIST, DISPLAY_TO_TICKER_MAP, TICKER_TO_PRICE_MAP = load_ticker_data()
company = load_company_metrics()
full_market_df = load_full_df() # (Srisk 모듈용 데이터)
# ----------------------------------------------------------------------
# 3. 입력 섹션 (종목 추가)
# ----------------------------------------------------------------------
st.subheader("1. 보유 종목 추가하기")
# 컬럼을 나눠서 UI를 깔끔하게 구성
col1, col2 = st.columns([2, 1])
with col1:
# `selectbox`를 검색 가능한 입력창으로 사용
selected_display = st.selectbox(
"종목 검색 (NASDAQ100 or S&P500 티커 또는 기업명)",
options=TICKER_OPTIONS_LIST,
index=None,
placeholder="티커를 검색하거나 선택하세요 (예: AAPL 또는 Apple)"
)
with col2:
quantity = st.number_input("보유 수량 (주)", min_value=0.01, step=0.1, format="%.2f")
# '종목 추가' 버튼
if st.button("➕ 포트폴리오에 추가", use_container_width=True):
selected_ticker = None
if selected_display:
selected_ticker = DISPLAY_TO_TICKER_MAP.get(selected_display)
current_price = TICKER_TO_PRICE_MAP.get(selected_ticker)
if selected_ticker:
st.session_state.portfolio.append({
"ticker": selected_ticker,
"quantity": quantity,
"price": current_price,
"total_value": quantity * current_price
})
st.success(f"{selected_ticker} {quantity}주 (현재가 ${current_price:,.2f})를 포트폴리오에 추가했습니다.")
else:
st.warning("종목, 수량을 모두 올바르게 입력하세요.")
# ----------------------------------------------------------------------
# 4. 포트폴리오 요약 및 보고서 생성 (스케치 레이아웃)
# ----------------------------------------------------------------------
st.subheader("2. 포트폴리오 요약 및 보고서 생성")
col_chart, col_controls = st.columns(2, gap="large")
with col_chart:
st.markdown("### 📊 포트폴리오 구성")
if st.session_state.portfolio:
df = pd.DataFrame(st.session_state.portfolio)
# Plotly 파이 차트 생성 (스케치와 유사하게)
fig = px.pie(
df,
values='total_value',
names='ticker',
hole=.3 # 도넛 차트 형태
)
fig.update_traces(textposition='inside', textinfo='percent+label')
st.plotly_chart(fig, use_container_width=True)
else:
st.info("종목을 추가하면 여기에 파이 차트가 표시됩니다.")
with col_controls:
st.markdown("### ✏️ 포트폴리오 수정 (삭제)")
if st.session_state.portfolio:
df = pd.DataFrame(st.session_state.portfolio)
edited_df = st.data_editor(
df,
column_config={
"ticker": st.column_config.TextColumn("티커", disabled=True),
"quantity": st.column_config.NumberColumn("수량", min_value=0.01, format="%.2f"),
"price": st.column_config.NumberColumn("현재가", disabled=True, format="$%.2f"),
"total_value": st.column_config.NumberColumn("총 가치", disabled=True, format="$%.2f"),
},
hide_index=True,
num_rows="dynamic",
key="portfolio_editor"
)
if not df.equals(edited_df):
# 삭제되거나 수정된 DataFrame을 다시 세션 상태(list of dicts)로 변환
st.session_state.portfolio = edited_df.to_dict('records')
st.toast("포트폴리오가 수정(삭제)되었습니다.")
st.rerun()
if st.button("🔄 포트폴리오 전체 초기화", use_container_width=True, type="secondary"):
st.session_state.portfolio = []
st.session_state.last_report = None
st.toast("포트폴리오가 초기화되었습니다.")
st.rerun()
# ----------------------------------------------------------------------
# 5. 보고서 생성 버튼 (메인 LLM 호출)
# ----------------------------------------------------------------------
if st.button("🚀 AI 보고서 생성하기", type="primary", use_container_width=True, disabled=(not st.session_state.portfolio)):
if st.session_state.peft_model and st.session_state.tokenizer:
with st.spinner("AI가 포트폴리오를 분석하고 보고서를 작성 중입니다..."):
generated_reports = generate_llm_reports(
st.session_state.portfolio)
st.session_state.last_report = generated_reports
else:
# 모델이 아직 로드 중일 때
st.warning("모델이 아직 로드 중입니다. 잠시 후 다시 시도해주세요.")
st.divider()
# ----------------------------------------------------------------------
# 6. 보고서 재생성 버튼 (메인 LLM 호출)
# ----------------------------------------------------------------------
if st.session_state.last_report:
st.markdown("##### 🔄 다른 성향으로 보고서 다시 뽑기")
col_style, col_regen = st.columns([3, 2])
with col_style:
new_style = st.selectbox(
"보고서 성향 선택",
["안정형", "공격형", "중립형"],
key="report_style_select",
label_visibility="collapsed" # 레이블 숨기기
)
with col_regen:
if st.button(f"'{new_style}' 스타일로 재생성", use_container_width=True):
with st.spinner(f"'{new_style}' 스타일로 보고서를 다시 작성 중입니다..."):
regenerated_reports = generate_llm_reports(
st.session_state.portfolio,
override_style=new_style
)
st.session_state.last_report = regenerated_reports
st.rerun() # 화면을 즉시 새로고침
# ----------------------------------------------------------------------
# 7. 보고서 출력 섹션
# ----------------------------------------------------------------------
st.divider()
if st.session_state.last_report:
st.subheader("📑 생성된 AI 보고서")
report_data = st.session_state.last_report
ordered_tickers = [item['ticker'] for item in st.session_state.portfolio if item['ticker'] in report_data]
ticker_tabs = st.tabs(ordered_tickers)
for i, ticker in enumerate(ordered_tickers):
with ticker_tabs[i]:
st.markdown(report_data[ticker]) # LLM이 생성한 마크다운 보고서 출력
else:
st.info("보고서를 생성하면 이 곳에 결과가 표시됩니다.")
# ----------------------------------------------------------------------
# 8. (신규) LLM 모델 로딩 (모든 UI를 그린 후 마지막에 실행)
# ----------------------------------------------------------------------
# peft_model, tokenizer를 st.session_state로 관리
if 'peft_model' not in st.session_state:
st.session_state.peft_model = None
if 'tokenizer' not in st.session_state:
st.session_state.tokenizer = None
# 세션에 모델이 없으면(최초 실행 시) 로드
if st.session_state.peft_model is None:
# (중요) UI를 먼저 그린 후, 스피너를 표시하며 모델 로드
with st.spinner("AI 분석 모델(LLM)을 로드 중입니다... (최초 실행 시 1-2분 소요)"):
st.session_state.peft_model, st.session_state.tokenizer = load_my_model()
# 로드가 완료되면 스피너를 없애기 위해 화면을 한 번 새로고침
st.rerun()
# 세션에 저장된 모델을 전역 변수처럼 사용
peft_model = st.session_state.peft_model
tokenizer = st.session_state.tokenizer
# (중요) 모델 로딩 실패 시 버튼 비활성화
if peft_model is None or tokenizer is None:
st.error("모델 로딩에 실패했습니다. 앱을 새로고침하거나 관리자에게 문의하세요.")