# csjjin2002's picture
# Upload 3 files
# 74ef49c verified
import html
import io
import os
import re
import textwrap
from datetime import datetime, timedelta, timezone

import dspy
import gradio as gr
import matplotlib.pyplot as plt
import plotly.graph_objects as go
import plotly.io as pio
import requests
import yfinance as yf
from bs4 import BeautifulSoup
# === Environment setup ===
# Replace the module-level `requests.get` with a thin wrapper so that calls
# going through `requests.get` carry the configured User-Agent header.
os.environ["USER_AGENT"] = "Mozilla/5.0"
original_get = requests.get


def patched_get(url, *args, **kwargs):
    """Call the original ``requests.get`` with the configured User-Agent set.

    Args:
        url: Target URL, forwarded unchanged.
        *args: Extra positional arguments, forwarded unchanged.
        **kwargs: Keyword arguments for ``requests.get``. Any ``headers``
            mapping is copied and its ``User-Agent`` entry is overwritten
            with ``os.environ["USER_AGENT"]``.

    Returns:
        Whatever the original ``requests.get`` returns.
    """
    # Copy instead of updating in place so the caller's dict is never
    # mutated, and treat an explicit ``headers=None`` like a missing argument.
    headers = dict(kwargs.get("headers") or {})
    headers["User-Agent"] = os.environ["USER_AGENT"]
    kwargs["headers"] = headers
    return original_get(url, *args, **kwargs)


requests.get = patched_get
# === LLM configuration ===
# The API key comes from the environment (None when the variable is unset);
# the resulting LM is registered as DSPy's default language model.
api_key = os.getenv("OPENAI_API_KEY")
lm = dspy.LM(model="openai/gpt-4o", api_key=api_key)
dspy.configure(lm=lm)
# === Ticker extraction signature ===
# DSPy signature with one input (`user_input`) and one output
# (`suggested_ticker`). `extract_ticker` upper-cases the output and checks it
# against its ticker patterns.
# NOTE(review): documented with `#` comments rather than a docstring because
# DSPy appears to use a Signature's docstring as the LM instructions, so adding
# one would change the prompt. Confirm before converting.
class ExtractTicker(dspy.Signature):
    user_input: str = dspy.InputField()  # free-form text passed in by gradio_run
    suggested_ticker: str = dspy.OutputField()  # ticker symbol proposed by the LM


# Predictor built from the signature; called as ticker_extractor(user_input=...).
ticker_extractor = dspy.Predict(ExtractTicker)
# === Korean stock ticker extraction signature ===
# DSPy signature with one input (`user_input`) and one output
# (`suggested_kr_ticker`). `extract_ticker` falls back to this extractor when
# the first suggestion does not look like a US- or Korean-style ticker.
# NOTE(review): kept as `#` comments rather than a docstring because DSPy
# appears to use a Signature's docstring as the LM instructions. Confirm
# before converting.
class ExtractKRStock(dspy.Signature):
    user_input: str = dspy.InputField()  # free-form text passed in by gradio_run
    suggested_kr_ticker: str = dspy.OutputField()  # Korean ticker proposed by the LM


# Predictor built from the signature; called as kr_extractor(user_input=...).
kr_extractor = dspy.Predict(ExtractKRStock)
# US-style symbols such as "TSLA" or "BRK.BA".
_US_TICKER_RE = re.compile(r"[A-Z]{1,5}(\.[A-Z]{2})?")
# Korean six-digit codes with a KOSPI (.KS) or KOSDAQ (.KQ) suffix.
_KR_TICKER_RE = re.compile(r"\d{6}\.(KS|KQ)")
# Loose sanity check for the fallback answer: one short token made only of
# characters that occur in ticker symbols (e.g. "BRK-B", "^KS11", "005930.KS").
_PLAUSIBLE_SYMBOL_RE = re.compile(r"[A-Z0-9^][A-Z0-9.\-=^]{0,14}")


def _clean_symbol(raw) -> str:
    """Normalize one LM answer: tolerate None, trim whitespace/quotes, upper-case."""
    return str(raw or "").strip().strip("'\"`").strip().upper()


def extract_ticker(user_input: str) -> str:
    """Ask the LM for the ticker symbol referred to by *user_input*.

    The general extractor is tried first. If its answer does not look like a
    US-style or Korean-style ticker, the input is assumed to name a Korean
    company and the Korean extractor is asked instead.

    Args:
        user_input: Free-form text from the user.

    Returns:
        The upper-cased ticker symbol, or ``""`` when the input is empty or
        the fallback answer is not a plausible single symbol (so callers can
        test the result for truthiness).
    """
    # Nothing to extract from an empty query, so skip both LM calls.
    if not user_input or not user_input.strip():
        return ""

    ticker = _clean_symbol(ticker_extractor(user_input=user_input).suggested_ticker)
    # A US-style or Korean-style ticker is returned as is.
    if _US_TICKER_RE.fullmatch(ticker) or _KR_TICKER_RE.fullmatch(ticker):
        return ticker

    # Otherwise assume a Korean company name and ask for its Korean ticker.
    kr_ticker = _clean_symbol(kr_extractor(user_input=user_input).suggested_kr_ticker)
    # Reject sentences and other non-symbol answers instead of passing them on.
    return kr_ticker if _PLAUSIBLE_SYMBOL_RE.fullmatch(kr_ticker) else ""
# === News collection ===
def get_yahoo_news(ticker: str, days: int = 7, limit: int = 6, timeout: float = 10.0) -> list:
    """Fetch recent headlines for *ticker* from the Yahoo Finance RSS feed.

    Args:
        ticker: Symbol to query.
        days: Only items published within this many days are kept.
        limit: Maximum number of headlines returned.
        timeout: Seconds to wait for the HTTP request before giving up.

    Returns:
        A list of dicts with the keys ``"title"``, ``"source"`` and
        ``"link"``. The list is empty when the request or parsing fails or
        when no item is recent enough.
    """
    url = "https://feeds.finance.yahoo.com/rss/2.0/headline"
    # Passing the query via `params` lets requests URL-encode symbols such as
    # "^KS11" instead of splicing them into the URL verbatim.
    params = {"s": ticker, "region": "US", "lang": "en-US"}
    try:
        response = requests.get(url, params=params, timeout=timeout)
        response.raise_for_status()
        soup = BeautifulSoup(response.content, "xml")
    except Exception as e:
        # Best effort: news is optional for the analysis, so report and go on.
        print("Yahoo news fetch error:", e)
        return []

    cutoff = datetime.now(timezone.utc) - timedelta(days=days)
    news = []
    for item in soup.find_all("item"):
        try:
            published = datetime.strptime(item.pubDate.text, "%a, %d %b %Y %H:%M:%S %z")
            entry = {
                "title": item.title.text,
                "source": "Yahoo Finance",
                "link": item.link.text,
            }
        except (AttributeError, ValueError):
            # Skip an item with a missing tag or unparsable date rather than
            # discarding the whole feed.
            continue
        if published >= cutoff:
            news.append(entry)
    return news[:max(limit, 0)]
# === Stock price information ===
def _percent_change(current: float, previous) -> float:
    """Return the percent change from *previous* to *current*, or 0.0 if unknown."""
    try:
        previous = float(previous)
    except (TypeError, ValueError):
        return 0.0
    # `not previous > 0` also rejects NaN; a non-positive base has no meaning.
    if not previous > 0:
        return 0.0
    return round((current - previous) / previous * 100, 2)


def fetch_stock_info(ticker: str):
    """Look up the company name, latest close and daily change for *ticker*.

    Args:
        ticker: Symbol understood by yfinance.

    Returns:
        A dict with ``"company"``, ``"price"`` (latest close, 2 decimals) and
        ``"change_percent"`` (change versus the reported previous close, 0.0
        when that value is unavailable), or ``None`` when no price history or
        company name is available or the lookup fails.
    """
    try:
        stock = yf.Ticker(ticker)
        hist = stock.history(period="1d")
        info = stock.info
        if hist.empty:
            return None
        # Some listings only carry a short name; accept it as a fallback.
        company = info.get("longName") or info.get("shortName")
        if not company:
            return None
        last_close = float(hist["Close"].iloc[-1])
        if last_close != last_close:  # NaN close means no usable price
            return None
        return {
            "company": company,
            "price": round(last_close, 2),
            # Divide by the real previous close; the old `max(prev, 1)` clamp
            # distorted the result for prices below 1 and for missing values.
            "change_percent": _percent_change(last_close, info.get("previousClose")),
        }
    except Exception as e:
        print("Stock info fetch error:", e)
        return None
# === Risk scoring signature ===
# DSPy signature taking the stock summary and news text built by `gradio_run`
# and producing, for six risk categories, a score plus its reasoning, followed
# by an overall score, risk level and investment message.
# `create_risk_plot` converts the six category scores to integers and draws
# them on a 0-10 axis; `gradio_run` shows the remaining outputs as text.
# NOTE(review): kept as `#` comments rather than a docstring because DSPy
# appears to use a Signature's docstring as the LM instructions. Confirm
# before converting.
class StructuredRiskScoringSignature(dspy.Signature):
    # Inputs
    stock_info: str = dspy.InputField()
    news: str = dspy.InputField()
    # Per-category score / reasoning pairs
    overvaluation_score: str = dspy.OutputField()
    overvaluation_reasoning: str = dspy.OutputField()
    poor_earnings_score: str = dspy.OutputField()
    poor_earnings_reasoning: str = dspy.OutputField()
    financial_instability_score: str = dspy.OutputField()
    financial_instability_reasoning: str = dspy.OutputField()
    theme_overheating_score: str = dspy.OutputField()
    theme_overheating_reasoning: str = dspy.OutputField()
    recurring_negatives_score: str = dspy.OutputField()
    recurring_negatives_reasoning: str = dspy.OutputField()
    # Labelled "FII Sell" / "FII Sell-off" in the chart and the detail text.
    selloff_score: str = dspy.OutputField()
    selloff_reasoning: str = dspy.OutputField()
    # Overall verdict shown in the summary box
    total_score: str = dspy.OutputField()
    risk_level: str = dspy.OutputField()
    investment_message: str = dspy.OutputField()


# Chain-of-thought module built from the signature; called by gradio_run.
risk_model = dspy.ChainOfThought(StructuredRiskScoringSignature)
def create_price_plot(ticker: str):
    """Build a Plotly line chart of the last month's closing prices for *ticker*.

    Returns:
        The Plotly figure, or ``None`` when no price history is available or
        any step raises (the error is printed).
    """
    try:
        history = yf.Ticker(ticker).history(period="1mo")
        if history.empty:
            return None

        close_trace = go.Scatter(
            x=history.index,
            y=history["Close"],
            mode='lines+markers',
            name='Close Price',
            line={"color": 'royalblue'},
            marker={"size": 6},
        )
        layout = {
            "title": f"{ticker} Price Trend (1mo)",
            "xaxis_title": "Date",
            "yaxis_title": "Close Price ($)",
            "font": {
                "family": "Arial, sans-serif",
                "size": 14,
                "color": "#333333",
            },
            "template": "plotly_white",
            "height": 300,
            "margin": {"l": 20, "r": 20, "t": 40, "b": 20},
            "plot_bgcolor": "#fefbd8",   # plotting area
            "paper_bgcolor": "#e0f7fa",  # surrounding paper
        }

        figure = go.Figure()
        figure.add_trace(close_trace)
        figure.update_layout(**layout)
        # Returned as a figure object; the UI shows it in a gr.Plot component.
        return figure
    except Exception as err:
        print("Plotly price chart error:", err)
        return None
# === Visualization ===
from PIL import Image
def _parse_risk_score(raw) -> int:
    """Convert one LM-produced score to an int clamped to the chart's 0-10 range.

    The first number found in the text is used, so answers such as "7",
    "7/10" or "Score: 7.5" are all accepted; text without a number counts as 0.
    """
    found = re.search(r"-?\d+(?:\.\d+)?", str(raw))
    if found is None:
        return 0
    return max(0, min(10, round(float(found.group()))))


def create_risk_plot(result, company_name, ticker, total_score, risk_level, investment_message):
    """Render the six category risk scores as a bar chart image.

    Args:
        result: Risk model prediction exposing the six ``*_score`` attributes.
        company_name: Company name shown in the chart title.
        ticker: Ticker symbol shown in the chart title.
        total_score: Overall score shown in the chart's subtitle line.
        risk_level: Risk level shown in the chart's subtitle line.
        investment_message: Accepted for interface compatibility; not drawn.

    Returns:
        A PIL image of the rendered PNG chart.
    """
    categories = [
        "Overval.", "Earnings", "Fin. Instab.",
        "Theme OH", "Neg. News", "FII Sell"
    ]
    # The scores arrive as LM-generated strings, so parse them leniently
    # instead of letting a value like "7/10" crash the whole request.
    scores = [
        _parse_risk_score(raw)
        for raw in (
            result.overvaluation_score,
            result.poor_earnings_score,
            result.financial_instability_score,
            result.theme_overheating_score,
            result.recurring_negatives_score,
            result.selloff_score,
        )
    ]

    # Work on an explicit figure/axes pair rather than pyplot's implicit
    # "current figure", and always close the figure so it cannot leak.
    fig, ax = plt.subplots(figsize=(10, 8), facecolor='#e0f7fa')  # figure background
    try:
        ax.set_facecolor('#fefbd8')  # axes background
        bars = ax.bar(categories, scores, color='pink')
        ax.set_ylim(0, 10)
        ax.set_title(f"{company_name} ({ticker})", fontsize=30)
        ax.set_ylabel("Score")
        ax.set_xlabel("Category")
        # Print each value just above its bar.
        for bar, score in zip(bars, scores):
            ax.text(
                bar.get_x() + bar.get_width() / 2, score + 0.2,
                str(score), ha='center', fontsize=12
            )
        ax.text(
            0.5, 0.91,
            f"Total Score: {total_score} / Risk Level: {risk_level}",
            fontsize=20, ha='center', transform=ax.transAxes
        )
        # Adjust margins.
        fig.subplots_adjust(top=0.85, bottom=0.10)
        buf = io.BytesIO()
        fig.savefig(buf, format='png')
    finally:
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf)
# === Gradio handler ===
def _error_outputs(message: str):
    """Return a 5-tuple matching the UI outputs that shows only *message* (in the summary)."""
    return None, "", message, None, ""


def _news_as_html(news) -> str:
    """Render the headline list as HTML, escaping all feed-supplied text."""
    if not news:
        return "<h4>📰 Cho's Pick Headlines</h4><p>(No recent news available)</p>"
    items = []
    for n in news:
        title = html.escape(str(n["title"]))
        source = html.escape(str(n["source"]))
        link = str(n["link"]).strip()
        # Only http(s) URLs become clickable; anything else is shown as text.
        if link.startswith(("http://", "https://")):
            href = html.escape(link, quote=True)
            items.append(
                f"<li><a href='{href}' target='_blank' rel='noopener noreferrer'>"
                f"{title}</a> ({source})</li>"
            )
        else:
            items.append(f"<li>{title} ({source})</li>")
    return "<h4>📰 Cho's Pick Headlines (Recent 7 Days)</h4><ul>" + "".join(items) + "</ul>"


def _news_as_text(news) -> str:
    """Render the headline list as plain text for the language model."""
    if not news:
        return "(No recent news available)"
    return "\n".join(f"- {n['title']} ({n['source']})" for n in news)


def gradio_run(user_input: str):
    """Run the full analysis for one user query.

    Args:
        user_input: Free-form text from the input box.

    Returns:
        A 5-tuple in the order of the UI outputs: price figure, news HTML,
        summary text, risk chart image, per-category reasoning text. On
        failure the summary slot carries the error message and the slots that
        could not be produced are empty.
    """
    try:
        ticker = extract_ticker(user_input)
    except Exception as e:
        print("Ticker extraction error:", e)
        ticker = ""
    if not ticker:
        return _error_outputs("❌ Unable to recognize a valid stock ticker from your input.")

    stock = fetch_stock_info(ticker)
    if not stock:
        return _error_outputs(f"❌ Failed to retrieve stock information for: {ticker}")

    news = get_yahoo_news(ticker)
    news_html = _news_as_html(news)
    price_img = create_price_plot(ticker)

    stock_text = f"Company: {stock['company']}\nPrice: ${stock['price']} ({stock['change_percent']}%)"
    try:
        # The model gets plain text; HTML markup and entities are only for the browser.
        result = risk_model(stock_info=stock_text, news=_news_as_text(news))
    except Exception as e:
        print("Risk model error:", e)
        # Still show what was gathered before the model call failed.
        return price_img, news_html, f"❌ Risk analysis failed for: {ticker}", None, ""

    summary = "\n".join([
        f"📌 Ticker: {stock['company']} ({ticker})",
        f"🧮 Total Score: {result.total_score}",
        f"⚠️ Risk Level: {result.risk_level}",
        f"💬 Recommendation: {result.investment_message}",
    ])

    try:
        plot_img = create_risk_plot(
            result, stock["company"], ticker, result.total_score,
            result.risk_level, result.investment_message
        )
    except Exception as e:
        # A chart failure should not hide the textual analysis.
        print("Risk plot error:", e)
        plot_img = None

    # Per-category reasoning (the trailing "" keeps the final newline).
    annotations = "\n".join([
        "🧠 Category-wise Reasoning:",
        f"1️⃣ Overvaluation: {result.overvaluation_reasoning}",
        f"2️⃣ Poor Earnings: {result.poor_earnings_reasoning}",
        f"3️⃣ Financial Instability: {result.financial_instability_reasoning}",
        f"4️⃣ Theme Overheating: {result.theme_overheating_reasoning}",
        f"5️⃣ Recurring Negatives: {result.recurring_negatives_reasoning}",
        f"6️⃣ FII Sell-off: {result.selloff_reasoning}",
        "",
    ])
    return price_img, news_html, summary, plot_img, annotations
# === Build and launch the Gradio interface ===
# NOTE(review): the nesting of the `gr.Row()` sections below was reconstructed
# from a source whose indentation was lost; confirm the intended layout.
with gr.Blocks() as iface:
    # Logo + title banner.
    gr.HTML("""
<div style="display: flex; align-items: center; margin-bottom: 10px;">
<img src="https://www.hanyang.ac.kr/documents/20182/73809/HYU_logo_singlecolor_png.png/b8aabfbe-a488-437d-b4a5-bd616d1577da?t=1474070795276" style="height: 50px; margin-right: 10px;">
<h2 style="margin: 0;">HYU-Cho's 'Risk Scoring Model' for Retail Portfolios based on Chain-of-Thought</h2>
</div>
<p>Analyze and visualize the risk level of any stock using natural language input.</p>
""")
    # Guidance text.
    with gr.Row():
        gr.Markdown("📌 **Welcome to Cho's Risk Scoring Model. Enter a stock-related question to begin.**")
    with gr.Row():
        user_input = gr.Textbox(label="User Input", lines=2, placeholder="e.g. Is Tesla risky these days?")
        submit_btn = gr.Button("Submit", variant="primary")
        clear_btn = gr.Button("Clear")
    # Price trend goes first.
    output_price_plot = gr.Plot(label="📈 1-Month Price Trend")
    # News (clickable HTML); the headline title is rendered inside the HTML itself.
    output_news_only = gr.HTML(label=None)
    # Summary.
    output_summary = gr.Textbox(label="📋 Portfolio Risk Evaluation by CoT", lines=6)
    # Risk chart next to the detailed reasoning.
    with gr.Row():
        output_plot = gr.Image(label="📊 Risk Score Visualization")
        output_detail = gr.Textbox(label="🧠 Detailed Risk Reasoning", lines=25)
    # Event wiring: both handlers fill the same five outputs, in this order.
    submit_btn.click(fn=gradio_run, inputs=user_input,
                     outputs=[output_price_plot, output_news_only, output_summary, output_plot, output_detail])
    # The plot and image outputs are cleared with None; an empty string is
    # not a valid value for the gr.Plot component.
    clear_btn.click(lambda: (None, "", "", None, ""), inputs=[],
                    outputs=[output_price_plot, output_news_only, output_summary, output_plot, output_detail])

iface.launch()