Spaces:
Sleeping
Sleeping
| import numpy as np | |
| import pandas as pd | |
| import yfinance as yf | |
| import datetime as dt | |
| import plotly.graph_objects as go | |
| from sklearn.preprocessing import MinMaxScaler | |
| from tensorflow.keras.models import load_model | |
| import gradio as gr | |
| import warnings | |
| import os | |
| import requests | |
# Silence every Python warning so library notices do not clutter the app log.
warnings.filterwarnings(action="ignore")

# --- Tunable constants -------------------------------------------------------
# Number of steps generated by predict_future() (and drawn by create_plot()).
PREDICTION_DAYS = 30
# Length of each sliding input window fed to the LSTM.
TIME_STEP = 60
# Years of price history get_stock_data() downloads (365 days per year).
DATA_YEARS = 1

# --- Model -------------------------------------------------------------------
# Load the pre-trained Keras model once at import time; all requests reuse it.
model = load_model("stock_price_model.h5")
# Pre-build the predict function up front (intended to speed up inference).
model.make_predict_function()
def preprocess_data(df):
    """Normalize a yfinance download into a Date-indexed OHLCV DataFrame.

    Flattens tuple (MultiIndex) column labels to their first level, moves the
    index into a ``Date`` column, keeps only High/Low/Open/Close/Volume,
    parses ``Date`` as datetimes and makes it the index again.

    The caller's DataFrame is left untouched; a new frame is returned.

    Args:
        df: Raw frame from ``yf.download``. Its index must be named ``Date``
            or be unnamed (it then surfaces as ``index`` and is renamed).

    Returns:
        pandas.DataFrame indexed by a ``Date`` column of datetimes, with
        columns ``['High', 'Low', 'Open', 'Close', 'Volume']``.

    Raises:
        KeyError: if any of the required price columns is missing.
    """
    # Work on a copy so that replacing ``.columns`` below does not rename the
    # columns of the frame the caller passed in.
    df = df.copy()
    # yfinance may return MultiIndex columns such as ('Close', '7203.T');
    # keep only the price-field level.
    df.columns = [col[0] if isinstance(col, tuple) else col for col in df.columns]
    df = df.reset_index().rename(columns={'index': 'Date'})
    # .copy() makes the selection an independent frame, so the Date assignment
    # below is not a chained assignment on a slice.
    df = df[['Date', 'High', 'Low', 'Open', 'Close', 'Volume']].copy()
    df['Date'] = pd.to_datetime(df['Date'])
    return df.set_index('Date')
def get_stock_data(stock_symbol):
    """Download about DATA_YEARS years of daily prices for one ticker.

    Any symbol yfinance understands can be used, e.g. Japanese tickers with
    the ``.T`` suffix such as Toyota (7203.T) or Sony (6758.T), or US tickers
    such as MSFT.

    Args:
        stock_symbol: Ticker symbol; surrounding whitespace is ignored.

    Returns:
        The Date-indexed OHLCV frame produced by ``preprocess_data``.

    Raises:
        ValueError: if the symbol is empty/None, or the download contains no
            rows (for example an unknown or delisted symbol).
    """
    symbol = (stock_symbol or "").strip()
    if not symbol:
        raise ValueError("A stock symbol is required (e.g. 7203.T, 6758.T, MSFT).")
    end_date = dt.datetime.now()
    start_date = end_date - dt.timedelta(days=365 * DATA_YEARS)
    df = yf.download(symbol, start=start_date, end=end_date)
    # yf.download may return an empty frame instead of raising for a bad
    # symbol; fail here with a clear message rather than deep in the pipeline.
    if df is None or df.empty:
        raise ValueError(f"No price data found for symbol '{symbol}'.")
    return preprocess_data(df)
def prepare_data(df):
    """Scale closing prices and build sliding-window inputs for the LSTM.

    Args:
        df: DataFrame with a ``Close`` column, one row per trading day.

    Returns:
        (X, y, scaler):
          * X: array of shape (n_windows, TIME_STEP, 1). X[i] holds the
            MinMax-scaled closes of rows [i, i + TIME_STEP).
          * y: array of shape (n_windows,). y[i] is the scaled close of the
            row immediately after window X[i] (row i + TIME_STEP).
          * scaler: the fitted MinMaxScaler, for inverting predictions.
        n_windows is len(df) - TIME_STEP - 1.

    Raises:
        ValueError: if df has too few rows to build a single window.
    """
    closes = df['Close'].values.reshape(-1, 1)
    # The window count len - TIME_STEP - 1 is kept as-is: callers align
    # in-sample predictions with dates based on it.
    n_windows = len(closes) - TIME_STEP - 1
    if n_windows < 1:
        raise ValueError(
            f"Need at least {TIME_STEP + 2} price rows to build a "
            f"{TIME_STEP}-day window, got {len(closes)}."
        )
    scaler = MinMaxScaler()
    scaled_data = scaler.fit_transform(closes)
    # Create the dataset using a sliding window.
    X = np.array([scaled_data[i:i + TIME_STEP, 0] for i in range(n_windows)])
    # The target of window [i, i + TIME_STEP) is the next day, i + TIME_STEP
    # (previously scaled_data[TIME_STEP + 1:], which skipped a day).
    y = scaled_data[TIME_STEP:TIME_STEP + n_windows, 0]
    return X.reshape(n_windows, TIME_STEP, 1), y, scaler
def predict_future(model, data, scaler):
    """Forecast PREDICTION_DAYS values past the end of a scaled series.

    The forecast is autoregressive. The newest TIME_STEP values of ``data``
    form the first input window. After each model call the window slides left
    by one position, and the fresh prediction becomes its newest element.

    Args:
        model: Object with a Keras-style ``predict(x, verbose=0)`` whose result
            holds the next value at ``[0, 0]``.
        data: Scaled series with at least TIME_STEP elements (the caller passes
            ``scaler.transform(...)`` output; presumably shape (n, 1)).
        scaler: Fitted scaler whose ``inverse_transform`` maps values back to
            price units.

    Returns:
        ``scaler.inverse_transform`` applied to the forecasts, arranged as a
        (PREDICTION_DAYS, 1) float32 column.
    """
    window = data[-TIME_STEP:].reshape(1, TIME_STEP, 1)
    forecasts = []
    for _step in range(PREDICTION_DAYS):
        nxt = model.predict(window, verbose=0)[0, 0]
        forecasts.append(nxt)
        # Drop the oldest value and append the prediction as the newest one.
        window = np.roll(window, -1, axis=1)
        window[0, -1, 0] = nxt
    scaled_forecasts = np.asarray(forecasts, dtype='float32').reshape(-1, 1)
    return scaler.inverse_transform(scaled_forecasts)
def create_plot(df, pred_data=None, future_data=None, title=""):
    """Build a dark-themed Plotly figure of closing prices with optional overlays.

    Args:
        df: Date-indexed frame with a ``Close`` column, drawn as the actual price.
        pred_data: Optional in-sample predictions of shape (n, 1). Row i is the
            model output for the window of rows [i, i + TIME_STEP), so it is
            drawn at df.index[i + TIME_STEP], the day it forecasts.
        future_data: Optional out-of-sample forecast of shape (k, 1), one row per
            step after the last date in ``df``.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``.
    """
    fig = go.Figure()
    # Main price line
    fig.add_trace(go.Scatter(
        x=df.index,
        y=df['Close'],
        name='実株価',
        line=dict(color='blue')
    ))
    # In-sample prediction line
    if pred_data is not None:
        # Prediction i forecasts the day right after its input window,
        # i.e. row i + TIME_STEP. Starting at TIME_STEP + 1 would draw every
        # prediction one day late.
        pred_dates = df.index[TIME_STEP:TIME_STEP + len(pred_data)]
        fig.add_trace(go.Scatter(
            x=pred_dates,
            y=pred_data[:len(pred_dates), 0],
            name='予想株価',
            line=dict(color='orange')
        ))
    # Future prediction
    if future_data is not None:
        # The history has one row per trading day, so each forecast step is
        # one trading day. Lay the steps out on business days after the last
        # observed date (weekends are skipped; exchange holidays are not).
        future_dates = pd.bdate_range(
            start=df.index[-1] + pd.offsets.BDay(1),
            periods=len(future_data),
        )
        fig.add_trace(go.Scatter(
            x=future_dates,
            y=future_data[:, 0],
            name='30日予測',
            line=dict(color='green')
        ))
    fig.update_layout(
        title=title,
        template='plotly_dark',
        margin=dict(l=20, r=20, t=40, b=20)
    )
    return fig
def _build_tech_figure(df, stock_symbol):
    """Return a figure of the close price with its 50- and 200-day SMAs."""
    # Compute the moving averages as standalone Series so df is not mutated.
    sma_50 = df['Close'].rolling(50).mean()
    sma_200 = df['Close'].rolling(200).mean()
    tech_fig = go.Figure()
    tech_fig.add_trace(go.Scatter(
        x=df.index, y=df['Close'],
        name='Price', line=dict(color='blue')))
    tech_fig.add_trace(go.Scatter(
        x=df.index, y=sma_50,
        name='50-Day SMA', line=dict(color='orange')))
    tech_fig.add_trace(go.Scatter(
        x=df.index, y=sma_200,
        name='200-Day SMA', line=dict(color='red')))
    tech_fig.update_layout(
        title=f"{stock_symbol} テクニカル・インジケーター",
        template='plotly_dark'
    )
    return tech_fig


def predict_stock(stock_symbol, symbol):
    """Gradio handler: fetch prices, run the LSTM and build the output widgets.

    Args:
        stock_symbol: Ticker from the text box (e.g. 7203.T, MSFT).
        symbol: Value selected in the company dropdown (a symbol code without
            the ``.T`` suffix). Used as ``symbol + ".T"`` when the text box is
            blank, for example when the dropdown was never changed.

    Returns:
        Tuple of (last close formatted to 2 decimals, last date as
        'YYYY-MM-DD', in-sample prediction figure, forecast figure,
        technical-indicator figure).

    Raises:
        gr.Error: wrapping any failure so Gradio shows it to the user.
    """
    ticker = (stock_symbol or "").strip()
    if not ticker and symbol:
        # Same convention as update_stock(): dropdown codes get a ".T" suffix.
        ticker = f"{symbol}.T"
    try:
        df = get_stock_data(ticker)
        X, _, scaler = prepare_data(df)
        # In-sample predictions, mapped back to price units.
        y_pred = scaler.inverse_transform(model.predict(X, verbose=0))
        # Future prediction, seeded with the full scaled close series.
        future_prices = predict_future(
            model,
            scaler.transform(df['Close'].values.reshape(-1, 1)),
            scaler
        )
        main_plot = create_plot(
            df,
            pred_data=y_pred,
            title=f"{ticker} 株価予測"
        )
        future_plot = create_plot(
            df,
            future_data=future_prices,
            title=f"{ticker} 30日予測"
        )
        tech_fig = _build_tech_figure(df, ticker)
        return (
            f"{df['Close'].iloc[-1]:.2f}",
            df.index[-1].strftime('%Y-%m-%d'),
            main_plot,
            future_plot,
            tech_fig
        )
    except Exception as e:
        # Top-level UI boundary: surface any failure to the user, keeping the cause.
        raise gr.Error(f"Prediction failed: {str(e)}") from e
# Company list for the dropdown, fetched once at startup.
# NOTE(review): this endpoint executes the raw SQL passed in the `sqlcmd` query
# parameter. The query here is a fixed string, but the service deserves a
# security review.
import ast
import json

dataid = requests.get(
    "https://www.ryhintl.com/dbjson/getjson?sqlcmd=select symbol_code,company_name_jp from stock_symbol",
    timeout=30,  # never let app startup hang on an unresponsive server
)
# Fail loudly on HTTP errors instead of trying to parse an error page.
dataid.raise_for_status()
data_str = dataid.content.decode('utf-8')
# Parse the response strictly as data; never eval() content from a remote
# server. json.loads handles a standard JSON body, and ast.literal_eval accepts
# the Python-literal form (e.g. single-quoted strings) without executing code.
try:
    data = json.loads(data_str)
except json.JSONDecodeError:
    data = ast.literal_eval(data_str)
# Convert rows to (label, value) pairs: label = Japanese company name,
# value = symbol code as a string.
choices = [(item["company_name_jp"], str(item["symbol_code"])) for item in data]
# Gradio Interface
# Page CSS: hides the Gradio footer and styles a few optional elements.
_APP_CSS = (
    "footer {visibility: hidden;} "
    "#header {display: flex; justify-content: space-between; align-items: center; "
    "font-size: 24px; font-weight: bold;} "
    "#logo {width: 50px; height: 50px;} "
    ".logout-btn { background-color: #3498db; border-radius: 10px; color: white; "
    "padding: 10px 20px; border: none; cursor: pointer;} "
    ".transparent-bg {background-color: transparent; color: black; padding: 10px; border: none;}"
)

# Company preselected in the dropdown. The Dropdown's `value` refers to a choice
# value (the symbol code), not its display name, so look the code up.
_DEFAULT_COMPANY = "トヨタ自動車株式会社"
_default_code = next(
    (code for label, code in choices if label == _DEFAULT_COMPANY), None
)

with gr.Blocks(title="株価予測", theme=gr.themes.Glass(), css=_APP_CSS) as demo:
    gr.Markdown("# 📈リアルタイム株価予測")
    gr.Markdown("LSTMを利用して株価の予測を行う。")
    with gr.Row():
        symbol_input = gr.Dropdown(choices, label="ドロップダウンを選択", value=_default_code)
        stock_input = gr.Textbox(
            label="株コード (Examples: トヨタ (7203.T)、ソニー (6758.T) MSFT)",
            placeholder="株コードを入力してください。 例) 7203.T, 6758.T, MSFT)",
            # Pre-fill from the default selection; the change event below is not
            # expected to fire for the dropdown's initial value.
            value=f"{_default_code}.T" if _default_code else "",
        )

    def update_stock(selected_symbol):
        """Return the ticker for a dropdown code: the code plus the ".T" suffix."""
        if not selected_symbol:
            # Dropdown cleared: clear the ticker box instead of failing on None + ".T".
            return ""
        print("selected:", selected_symbol + ".T")
        return selected_symbol + ".T"

    # When the dropdown selection changes, copy the matching ticker into the text box.
    symbol_input.change(
        update_stock,
        inputs=symbol_input,
        outputs=stock_input
    )

    submit_btn = gr.Button("予測", variant="primary")
    with gr.Row():
        with gr.Column():
            last_price = gr.Textbox(label="終値")
            last_date = gr.Textbox(label="前日")
    with gr.Tabs():
        with gr.Tab("株価予測"):
            main_plot = gr.Plot(label="株価予測")
        with gr.Tab("30日予測"):
            future_plot = gr.Plot(label="将来予測")
        with gr.Tab("テクニカル・インジケーター"):
            tech_plot = gr.Plot(label="テクニカル分析")
    submit_btn.click(
        fn=predict_stock,
        inputs=[stock_input, symbol_input],
        outputs=[last_price, last_date, main_plot, future_plot, tech_plot]
    )

# Launch the app (this module is the Hugging Face Spaces entry point).
demo.launch(debug=False, favicon_path="favicon.ico")