fudii0921's picture
Update app.py
af37acf verified
import ast
import datetime as dt
import json
import os
import warnings

import gradio as gr
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import requests
import yfinance as yf
from sklearn.preprocessing import MinMaxScaler
from tensorflow.keras.models import load_model
# Tunable constants, read by the functions below.
# Number of future values produced by the recursive forecast.
PREDICTION_DAYS = 30
# Length of the input window fed to the LSTM (number of past closes).
TIME_STEP = 60
# How much price history to download, in years of 365 days.
DATA_YEARS = 1

# Silence every Python warning process-wide so library notices don't clutter
# the app output.
warnings.filterwarnings("ignore")
# Load the pre-trained Keras model once, at import time. predict_stock uses
# this module-level instance for every request and passes it to
# predict_future.
# NOTE(review): inputs are built as (batch, TIME_STEP, 1) windows of
# MinMax-scaled closing prices — presumably matching how the model was
# trained; confirm against the training code.
model = load_model('stock_price_model.h5')
model.make_predict_function() # For faster inference: builds the predict function up front (presumably so the first request doesn't pay for it)
def preprocess_data(df):
    """Normalise a ``yf.download`` result into a date-indexed OHLCV frame.

    ``yf.download`` may return MultiIndex columns such as
    ``('Close', '7203.T')``; only the first level (the price field) is kept.

    Args:
        df: DataFrame returned by ``yf.download`` for a single ticker. It is
            not modified.

    Returns:
        A new DataFrame indexed by ``Date`` (datetime64) with the columns
        ``High``, ``Low``, ``Open``, ``Close`` and ``Volume``, in that order.

    Raises:
        ValueError: if ``df`` is ``None`` or has no rows (e.g. an unknown
            ticker symbol).
        KeyError: if any of the required columns is missing.
    """
    if df is None or df.empty:
        raise ValueError("No price data was returned for this symbol.")
    # Flatten the column labels on a new object so the caller's frame keeps
    # its original (possibly MultiIndex) columns.
    flat_columns = [col[0] if isinstance(col, tuple) else col for col in df.columns]
    out = df.set_axis(flat_columns, axis="columns")
    # Move the index into a 'Date' column; an unnamed index comes out of
    # reset_index() as 'index', so rename it.
    out = out.reset_index().rename(columns={'index': 'Date'})
    # .copy() makes the column selection an independent frame, so assigning
    # the converted Date column below is not a chained assignment on a view.
    out = out[['Date', 'High', 'Low', 'Open', 'Close', 'Volume']].copy()
    out['Date'] = pd.to_datetime(out['Date'])
    return out.set_index('Date')
def get_stock_data(stock_symbol):
    """Download about ``DATA_YEARS`` years (365 days each) of prices for a ticker.

    yfinance also covers Japanese stocks via the ``.T`` suffix, e.g. Toyota
    ``7203.T`` or Sony ``6758.T``; tickers such as ``MSFT`` work as well.

    Args:
        stock_symbol: Ticker symbol as entered by the user. Surrounding
            whitespace is ignored.

    Returns:
        The downloaded data after ``preprocess_data``.

    Raises:
        ValueError: if ``stock_symbol`` is empty, ``None`` or only whitespace.
    """
    # Validate before touching the network so an empty text box yields a
    # clear message instead of an opaque download or column error.
    symbol = (stock_symbol or "").strip()
    if not symbol:
        raise ValueError("Stock symbol is empty.")
    end_date = dt.datetime.now()
    start_date = end_date - dt.timedelta(days=365 * DATA_YEARS)
    df = yf.download(symbol, start=start_date, end=end_date)
    return preprocess_data(df)
def prepare_data(df):
    """Scale closing prices and cut them into LSTM input windows.

    Args:
        df: Frame with a ``Close`` column, one row per observation.

    Returns:
        X: array of shape ``(n - TIME_STEP - 1, TIME_STEP, 1)``; window ``i``
            holds the scaled closes of rows ``i .. i + TIME_STEP - 1``.
        y: array of shape ``(n - TIME_STEP - 1,)``; ``y[i]`` is the scaled
            close of the row right after window ``i`` (row ``i + TIME_STEP``).
        scaler: the ``MinMaxScaler`` fitted on the closes, needed to map
            model outputs back to prices.

    Raises:
        ValueError: if there are fewer than ``TIME_STEP + 2`` rows, i.e. not
            enough history for a single window.
    """
    closes = df['Close'].values.reshape(-1, 1)
    n_rows = len(closes)
    # n_rows - TIME_STEP - 1 windows are built below, so at least
    # TIME_STEP + 2 rows are needed for one window.
    if n_rows < TIME_STEP + 2:
        raise ValueError(
            f"Need at least {TIME_STEP + 2} rows of price history, got {n_rows}."
        )
    scaler = MinMaxScaler()
    scaled_data = scaler.fit_transform(closes)
    # The window count leaves out the final possible window; it is kept
    # because the in-sample plot sizes its x-axis to this count.
    n_windows = n_rows - TIME_STEP - 1
    X = np.array([scaled_data[i:i + TIME_STEP, 0] for i in range(n_windows)])
    # The target of window i is the very next row (i + TIME_STEP), matching
    # predict_future, which feeds each prediction back as the row directly
    # after the window.
    y = scaled_data[TIME_STEP:TIME_STEP + n_windows, 0]
    return X.reshape(n_windows, TIME_STEP, 1), y, scaler
def predict_future(model, data, scaler):
    """Forecast PREDICTION_DAYS future values by recursive one-step prediction.

    The last TIME_STEP scaled values form the initial window. Each step
    predicts the next value, then slides the window forward: the oldest value
    is dropped and the prediction appended. Later steps are therefore built
    on earlier predictions rather than on observed prices.

    Args:
        model: Keras model taking a ``(1, TIME_STEP, 1)`` window; element
            ``[0, 0]`` of its output is used as the next scaled value.
        data: Scaled series, e.g. ``scaler.transform`` of the closes with
            shape ``(n, 1)``; only its last TIME_STEP values are used.
        scaler: Fitted scaler that maps predictions back to prices.

    Returns:
        Array of shape ``(PREDICTION_DAYS, 1)`` of predicted prices.

    Raises:
        ValueError: if ``data`` has fewer than TIME_STEP values.
    """
    series = np.asarray(data).reshape(-1)
    if series.size < TIME_STEP:
        raise ValueError(
            f"Need at least {TIME_STEP} scaled values to forecast, got {series.size}."
        )
    window = series[-TIME_STEP:]
    future_preds = np.zeros(PREDICTION_DAYS, dtype='float32')
    for step in range(PREDICTION_DAYS):
        # Call the model directly: Keras documents model.predict() as meant
        # for batch processing and recommends model(x) for small inputs
        # processed inside a loop, where predict()'s per-call overhead dominates.
        output = model(window.reshape(1, TIME_STEP, 1), training=False)
        next_pred = float(output[0, 0])
        future_preds[step] = next_pred
        # Slide the window one step (np.roll returns a copy, so `data` is
        # never modified), then put the new prediction in the last slot.
        window = np.roll(window, -1)
        window[-1] = next_pred
    return scaler.inverse_transform(future_preds.reshape(-1, 1))
def create_plot(df, pred_data=None, future_data=None, title=""):
    """Build a dark-themed Plotly chart of closing prices and predictions.

    Args:
        df: Date-indexed frame with a ``Close`` column.
        pred_data: Optional ``(m, 1)`` array of in-sample predictions; row
            ``i`` is the model output for the window of rows
            ``i .. i + TIME_STEP - 1``.
        future_data: Optional ``(k, 1)`` array of forecasts for the rows
            following the last date in ``df``.
        title: Chart title.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    # Actual closing prices.
    fig.add_trace(go.Scatter(
        x=df.index,
        y=df['Close'],
        name='実株価',
        line=dict(color='blue'),
    ))
    if pred_data is not None:
        # Window i ends on row i + TIME_STEP - 1, and its output is treated as
        # the next row (predict_future feeds predictions back that way). So
        # prediction i belongs on the date of row i + TIME_STEP.
        pred_dates = df.index[TIME_STEP:TIME_STEP + len(pred_data)]
        fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_data[:len(pred_dates), 0],
            name='予想株価',
            line=dict(color='orange'),
        ))
    if future_data is not None:
        # Each forecast step stands for one more row of daily market data.
        # Those rows skip weekends, so place the points on consecutive
        # business days after the last known date. The count comes from the
        # data itself.
        future_dates = pd.bdate_range(
            start=df.index[-1] + pd.offsets.BDay(1),
            periods=len(future_data),
        )
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=future_data[:, 0],
            name='30日予測',
            line=dict(color='green'),
        ))
    fig.update_layout(
        title=title,
        template='plotly_dark',
        margin=dict(l=20, r=20, t=40, b=20),
    )
    return fig
def _create_technical_plot(df, stock_symbol):
    """Chart the close with its 50- and 200-row simple moving averages.

    Adds ``SMA_50`` and ``SMA_200`` columns to ``df`` in place.
    """
    df['SMA_50'] = df['Close'].rolling(50).mean()
    df['SMA_200'] = df['Close'].rolling(200).mean()
    tech_fig = go.Figure()
    for column, label, color in (
        ('Close', 'Price', 'blue'),
        ('SMA_50', '50-Day SMA', 'orange'),
        ('SMA_200', '200-Day SMA', 'red'),
    ):
        tech_fig.add_trace(go.Scatter(
            x=df.index, y=df[column],
            name=label, line=dict(color=color)))
    tech_fig.update_layout(
        title=f"{stock_symbol} テクニカル・インジケーター",
        template='plotly_dark'
    )
    return tech_fig


def predict_stock(stock_symbol, symbol):
    """Run the full prediction pipeline for one ticker (Gradio click handler).

    Args:
        stock_symbol: Ticker from the text box, e.g. ``7203.T`` or ``MSFT``.
        symbol: Value of the company dropdown. Not used here: the text box is
            authoritative, and the dropdown only fills it via its change
            handler. Kept because the click handler passes both inputs.

    Returns:
        Tuple ``(last_close, last_date, main_plot, future_plot, tech_plot)``:
        the latest close formatted to two decimals, its date as
        ``YYYY-MM-DD``, and three Plotly figures.

    Raises:
        gr.Error: wrapping any failure so Gradio shows it in the UI. The
            original exception is chained as ``__cause__`` for the logs.
    """
    try:
        df = get_stock_data(stock_symbol)
        X, _, scaler = prepare_data(df)
        # In-sample predictions, mapped back from the scaler's range to prices.
        y_pred = scaler.inverse_transform(model.predict(X))
        # Recursive forecast seeded with the full scaled closing series.
        future_prices = predict_future(
            model,
            scaler.transform(df['Close'].values.reshape(-1, 1)),
            scaler
        )
        main_plot = create_plot(
            df,
            pred_data=y_pred,
            title=f"{stock_symbol} 株価予測"
        )
        future_plot = create_plot(
            df,
            future_data=future_prices,
            title=f"{stock_symbol} 30日予測"
        )
        tech_fig = _create_technical_plot(df, stock_symbol)
        return (
            f"{df['Close'].iloc[-1]:.2f}",
            df.index[-1].strftime('%Y-%m-%d'),
            main_plot,
            future_plot,
            tech_fig
        )
    except Exception as e:
        # UI boundary: surface the message in Gradio, keep the traceback chained.
        raise gr.Error(f"Prediction failed: {str(e)}") from e
# Fetch the (company name, symbol code) list that populates the dropdown.
# The timeout stops app start-up from hanging forever on an unresponsive
# server. raise_for_status() fails fast on an HTTP error instead of trying to
# parse an error page.
dataid = requests.get(
    "https://www.ryhintl.com/dbjson/getjson?sqlcmd=select symbol_code,company_name_jp from stock_symbol",
    timeout=30,
)
dataid.raise_for_status()
# Decode the JSON response
data_str = dataid.content.decode('utf-8')
# SECURITY: never eval() this remote response. That would execute whatever
# code the server (or anyone able to tamper with the response) sends. Parse
# it strictly as data instead: JSON first, then Python-literal syntax as a
# fallback for payloads that are valid Python literals but not strict JSON.
try:
    data = json.loads(data_str)
except json.JSONDecodeError:
    data = ast.literal_eval(data_str)
# (label, value) pairs for gr.Dropdown: the Japanese company name is shown,
# the symbol code (as a string) is the selected value.
choices = [(item["company_name_jp"], str(item["symbol_code"])) for item in data]
# Gradio Interface
# Company preselected in the dropdown (display label), and its symbol code.
# gr.Dropdown's `value` must be one of the choice *values* (the symbol codes),
# not a display label; recent Gradio versions reject a non-choice value when
# the form is submitted.
_DEFAULT_COMPANY = "トヨタ自動車株式会社"
_default_code = next(
    (code for label, code in choices if label == _DEFAULT_COMPANY), None
)
# Page CSS. Every rule is closed, and transparent-bg is a class selector.
_CSS = (
    "footer {visibility: hidden;} "
    "#header {display: flex; justify-content: space-between; align-items: center; font-size: 24px; font-weight: bold;} "
    "#logo {width: 50px; height: 50px;} "
    ".logout-btn { background-color: #3498db; border-radius: 10px; color: white; padding: 10px 20px; border: none; cursor: pointer; } "
    ".transparent-bg {background-color: transparent; color: black; padding: 10px; border: none;}"
)
with gr.Blocks(title="株価予測", theme=gr.themes.Glass(), css=_CSS) as demo:
    gr.Markdown("# 📈リアルタイム株価予測")
    gr.Markdown("LSTMを利用して株価の予測を行う。")
    with gr.Row():
        symbol_input = gr.Dropdown(choices, label="ドロップダウンを選択", value=_default_code)
        stock_input = gr.Textbox(
            label="株コード (Examples: トヨタ (7203.T)、ソニー (6758.T) MSFT)",
            placeholder="株コードを入力してください。 例) 7203.T, 6758.T, MSFT)",
            # Start consistent with the preselected dropdown entry, so the
            # predict button works without touching the dropdown.
            value=f"{_default_code}.T" if _default_code else "",
        )

    def update_stock(selected_symbol):
        """Turn the selected symbol code into a Tokyo ticker (``<code>.T``)."""
        if not selected_symbol:
            # Nothing selected: leave the text box as it is.
            return gr.update()
        print("selected:", selected_symbol + ".T")
        return selected_symbol + ".T"

    # Refresh the ticker text box whenever the dropdown selection changes.
    symbol_input.change(
        update_stock,
        inputs=symbol_input,
        outputs=stock_input
    )
    submit_btn = gr.Button("予測", variant="primary")
    with gr.Row():
        with gr.Column():
            last_price = gr.Textbox(label="終値")
            last_date = gr.Textbox(label="前日")
    with gr.Tabs():
        with gr.Tab("株価予測"):
            main_plot = gr.Plot(label="株価予測")
        with gr.Tab("30日予測"):
            future_plot = gr.Plot(label="将来予測")
        with gr.Tab("テクニカル・インジケーター"):
            tech_plot = gr.Plot(label="テクニカル分析")
    submit_btn.click(
        fn=predict_stock,
        inputs=[stock_input, symbol_input],
        outputs=[last_price, last_date, main_plot, future_plot, tech_plot]
    )
# Start the app (also the entry point on Hugging Face Spaces).
demo.launch(debug=False, favicon_path="favicon.ico")