Spaces:
Running
Running
| import gradio as gr | |
| import requests | |
| from bs4 import BeautifulSoup | |
| import pandas as pd | |
| import numpy as np | |
| from sklearn.preprocessing import MinMaxScaler | |
| from tensorflow.keras.models import Sequential | |
| from tensorflow.keras.layers import LSTM, Dense, Dropout | |
| from tensorflow.keras.optimizers import Adam | |
| import matplotlib.pyplot as plt | |
| import io | |
| import matplotlib as mpl | |
| import matplotlib.font_manager as fm | |
| import tempfile | |
| import os | |
| import yfinance as yf | |
| import logging | |
| from datetime import datetime, timedelta | |
# Logging setup: INFO level, each record formatted as
# "<timestamp> - <level name> - <message>".
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')
# Font setup
def setup_font():
    """Download a CJK font and make it matplotlib's default font family.

    The font is fetched over HTTP, written to a temporary ``.ttf`` file,
    registered with matplotlib's font manager and selected through
    ``mpl.rc``. Any failure (network error, HTTP error status, unreadable
    font file, ...) is logged and the ``'SimHei'`` family is selected instead,
    so this function never raises.

    Returns:
        None. The effect is the change to matplotlib's global font settings.
    """
    try:
        url_font = "https://drive.google.com/uc?id=1eGAsTN1HBpJAkeVM57_C7ccp7hbgSz3_"
        # Bound the wait so a stalled download cannot hang application start-up.
        response_font = requests.get(url_font, timeout=30)
        # Fail fast on an HTTP error instead of saving an error page as a ".ttf".
        response_font.raise_for_status()
        with tempfile.NamedTemporaryFile(delete=False, suffix='.ttf') as tmp_file:
            tmp_file.write(response_font.content)
            tmp_file_path = tmp_file.name
        # Register only after the `with` block has closed (and flushed) the
        # file. delete=False keeps the file on disk; presumably matplotlib
        # still needs the path when text is rendered later.
        fm.fontManager.addfont(tmp_file_path)
        mpl.rc('font', family='Taipei Sans TC Beta')
    except Exception as e:
        logging.error(f"字體設置失敗: {str(e)}")
        # Fall back to an alternative font family.
        mpl.rc('font', family='SimHei')
# HTTP request headers sent with every scraping request (browser-like
# User-Agent, Traditional-Chinese-first language preference).
headers = {
    'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
    'Accept-Language': 'zh-TW,zh;q=0.9,en-US;q=0.8,en;q=0.7',
    # 'br' is deliberately NOT advertised: brotli responses are only decoded
    # when an optional brotli package is installed, and without it
    # `response.text` would be undecoded bytes. gzip/deflate always work.
    'Accept-Encoding': 'gzip, deflate',
    'Connection': 'keep-alive',
    'Upgrade-Insecure-Requests': '1',
}
def fetch_stock_categories():
    """Scrape the Yahoo Taiwan stock class page into a category mapping.

    Each ``div`` with class ``C($c-link-text)`` is treated as one main
    category: its ``h2`` heading supplies the category name and its matching
    anchors supply the sub-categories.

    Returns:
        dict: ``{main category name: [{'類股': sub-category name,
        '網址': absolute URL}, ...]}``. An empty dict is returned (and the
        error logged) if the request or the parsing fails.
    """
    # Local import keeps this function self-contained.
    from urllib.parse import urljoin

    base_url = "https://tw.stock.yahoo.com"
    try:
        url = "https://tw.stock.yahoo.com/class/"
        response = requests.get(url, headers=headers, timeout=10)
        response.raise_for_status()
        soup = BeautifulSoup(response.text, 'html.parser')
        main_categories = soup.find_all('div', class_='C($c-link-text)')
        category_dict = {}
        for category in main_categories:
            heading = category.find('h2', class_="Fw(b) Fz(24px) Lh(32px)")
            if not heading:
                # A container without a heading cannot be named; skip it.
                continue
            main_category_name = heading.text.strip()
            sub_categories = category.find_all('a', class_='Fz(16px) Lh(1.5) C($c-link-text) C($c-active-text):h Fw(b):h Td(n)')
            for sub_category in sub_categories:
                href = sub_category.get('href')
                if not href:
                    # One anchor without a link must not discard the whole
                    # scrape (indexing ['href'] would raise KeyError).
                    continue
                category_dict.setdefault(main_category_name, []).append({
                    '類股': sub_category.text.strip(),
                    # urljoin handles root-relative, relative and already
                    # absolute hrefs; plain concatenation only handles the
                    # root-relative form.
                    '網址': urljoin(base_url, href),
                })
        return category_dict
    except Exception as e:
        logging.error(f"獲取股票類別失敗: {str(e)}")
        return {}
# Stock prediction model class
class StockPredictor:
    """One-step-ahead LSTM forecaster over min-max scaled feature rows.

    The model is trained to map the scaled feature row of one time step to
    the scaled feature row of the following time step (sequence length 1).

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self):
        # Keras model; stays None until train() has been called.
        self.model = None
        # Scaler fitted in prepare_data(); callers reuse it to transform
        # inputs and to invert predictions back to original units.
        self.scaler = MinMaxScaler()

    def prepare_data(self, df, selected_features):
        """Fit the scaler and build (current row, next row) training pairs.

        Args:
            df: Table supporting ``df[selected_features]`` and ``len(df)``
                (a pandas DataFrame in this file).
            selected_features: Non-empty list of column names to use.

        Returns:
            tuple: ``(X, y)`` where ``X`` has shape
            ``(rows - 1, 1, n_features)`` and ``y`` has shape
            ``(rows - 1, n_features)``; both are scaled values.

        Raises:
            ValueError: If no feature is selected or fewer than two rows
                are available (no training pair can be formed).
        """
        if not selected_features:
            raise ValueError("至少需要選擇一個特徵")
        if len(df) < 2:
            raise ValueError("至少需要兩筆資料才能訓練模型")
        scaled_data = np.asarray(self.scaler.fit_transform(df[selected_features]))
        # Row i is the input, row i + 1 is its target.
        X = scaled_data[:-1].reshape(-1, 1, len(selected_features))
        y = scaled_data[1:]
        return X, y

    def build_model(self, input_shape):
        """Build and compile the two-layer LSTM network.

        Args:
            input_shape: ``(timesteps, n_features)``; the output layer has
                ``input_shape[1]`` units, i.e. one per feature.

        Returns:
            The compiled Keras ``Sequential`` model (Adam, MSE loss).
        """
        model = Sequential([
            LSTM(100, activation='relu', input_shape=input_shape, return_sequences=True),
            Dropout(0.2),
            LSTM(50, activation='relu'),
            Dropout(0.2),
            Dense(input_shape[1])
        ])
        model.compile(optimizer=Adam(learning_rate=0.001), loss='mse')
        return model

    def train(self, df, selected_features):
        """Prepare the data, build a fresh model and fit it.

        Args:
            df: Source table, see ``prepare_data``.
            selected_features: Column names to train on.

        Returns:
            The history object returned by ``model.fit``.

        Raises:
            ValueError: Propagated from ``prepare_data``.
        """
        X, y = self.prepare_data(df, selected_features)
        self.model = self.build_model((1, X.shape[2]))
        history = self.model.fit(
            X, y,
            epochs=50,
            batch_size=32,
            validation_split=0.2,
            verbose=0
        )
        return history

    def predict(self, last_data, n_days):
        """Roll the model forward ``n_days`` steps from ``last_data``.

        Each prediction is fed back as the input of the next step.

        Args:
            last_data: Scaled feature values of the most recent time step
                (any array-like; it is copied, never modified).
            n_days: Number of steps to forecast.

        Returns:
            numpy.ndarray: One row of scaled predictions per step.

        Raises:
            RuntimeError: If called before ``train``.
        """
        if self.model is None:
            raise RuntimeError("模型尚未訓練，請先呼叫 train()")
        predictions = []
        # Private 1-D working copy so the caller's array is left untouched.
        current_data = np.array(last_data, dtype=float).flatten()
        for _ in range(n_days):
            next_day = self.model.predict(current_data.reshape(1, 1, -1), verbose=0)[0]
            predictions.append(next_day)
            # Feed the prediction back in as the next input.
            current_data[:len(next_day)] = next_day
        return np.array(predictions)
# Gradio interface functions
def update_stocks(category):
    """Return the sub-category names listed under ``category``.

    Looks the category up in the module-level ``category_dict``. A falsy or
    unknown category yields an empty list.
    """
    if category and category in category_dict:
        return [entry['類股'] for entry in category_dict[category]]
    return []
def get_stock_items(url):
    """Scrape a sub-category page and map display names to stock codes.

    Every ``li`` with class ``List(n)`` that contains both a name element
    and a code element contributes one entry. The key is the stock name
    immediately followed by the code text before its first ``.``; the value
    is the full code text.

    Args:
        url: Address of the page to fetch.

    Returns:
        dict: ``{display name: full code}``; an empty dict (with the error
        logged) when the request or the parsing fails.
    """
    try:
        page = requests.get(url, headers=headers, timeout=10)
        page.raise_for_status()
        parsed = BeautifulSoup(page.text, 'html.parser')
        result = {}
        for entry in parsed.find_all('li', class_='List(n)'):
            name_tag = entry.find('div', class_='Lh(20px) Fw(600) Fz(16px) Ell')
            code_tag = entry.find('span', class_='Fz(14px) C(#979ba7) Ell')
            if not (name_tag and code_tag):
                # Skip list items that lack either the name or the code.
                continue
            ticker = code_tag.text.strip()
            short_code = ticker.split('.')[0]
            result[f"{name_tag.text.strip()}{short_code}"] = ticker
        return result
    except Exception as e:
        logging.error(f"獲取股票項目失敗: {str(e)}")
        return {}
def update_category(category):
    """Handle a change of the main-category dropdown.

    Repopulates the sub-category dropdown from ``update_stocks`` and resets
    the stock dropdown, the plot and the status text.

    Returns:
        dict: Gradio updates keyed by the component they apply to.
    """
    sub_category_names = update_stocks(category)
    updates = {}
    updates[stock_dropdown] = gr.update(choices=sub_category_names, value=None)
    updates[stock_item_dropdown] = gr.update(choices=[], value=None)
    updates[stock_plot] = gr.update(value=None)
    updates[status_output] = gr.update(value="")
    return updates
def update_stock(category, stock):
    """Handle a change of the sub-category dropdown.

    Resolves the sub-category's URL from ``category_dict``, scrapes its
    stocks and offers them in the stock dropdown. The plot and status text
    are always cleared. If either argument is falsy or no URL is found, the
    stock dropdown is emptied.

    Returns:
        dict: Gradio updates keyed by the component they apply to.
    """
    choices = []
    if category and stock:
        matching_urls = (
            item['網址']
            for item in category_dict.get(category, [])
            if item['類股'] == stock
        )
        url = next(matching_urls, None)
        if url:
            choices = list(get_stock_items(url).keys())
    return {
        stock_item_dropdown: gr.update(choices=choices, value=None),
        stock_plot: gr.update(value=None),
        status_output: gr.update(value=""),
    }
def predict_stock(category, stock, stock_item, period, selected_features):
    """Train a model on the chosen stock and plot a 5-step forecast.

    Args:
        category: Selected main category (key of ``category_dict``).
        stock: Selected sub-category name.
        stock_item: Selected stock display name, as produced by
            ``get_stock_items``.
        period: History window passed to ``yf.download`` (e.g. ``"1y"``).
        selected_features: Column names used for both training and plotting.

    Returns:
        tuple: ``(plot update, status message)``. On any failure the plot
        update clears the plot and the message describes the problem; this
        function does not raise.
    """
    if not all([category, stock, stock_item]):
        return gr.update(value=None), "請選擇產業類別、類股和股票"
    if not selected_features:
        # Nothing to train on or plot; report it instead of failing later
        # inside the scaler with an obscure error.
        return gr.update(value=None), "請至少選擇一個預測特徵"

    forecast_days = 5
    fig = None
    try:
        url = next((item['網址'] for item in category_dict.get(category, [])
                    if item['類股'] == stock), None)
        if not url:
            return gr.update(value=None), "無法獲取類股網址"
        stock_items = get_stock_items(url)
        stock_code = stock_items.get(stock_item, "")
        if not stock_code:
            return gr.update(value=None), "無法獲取股票代碼"

        # Download price history for the user-selected period.
        df = yf.download(stock_code, period=period)
        if df.empty:
            raise ValueError("無法獲取股票數據")
        # NOTE(review): recent yfinance releases appear to return
        # (field, ticker) MultiIndex columns even for a single ticker;
        # flatten to the field level so plain column names work. Confirm
        # against the installed yfinance version.
        if isinstance(df.columns, pd.MultiIndex):
            df = df.copy()
            df.columns = df.columns.get_level_values(0)
        missing = [name for name in selected_features if name not in df.columns]
        if missing:
            raise ValueError(f"股票數據缺少欄位: {', '.join(missing)}")
        # Drop incomplete rows so NaNs cannot reach the scaler or the model.
        feature_df = df[selected_features].dropna()
        if len(feature_df) < 2:
            raise ValueError("股票數據不足，無法訓練模型")

        # Train and forecast.
        predictor = StockPredictor()
        predictor.train(feature_df, selected_features)
        last_data = predictor.scaler.transform(feature_df.iloc[-1:].values)
        predictions = predictor.predict(last_data[0], forecast_days)

        # Convert the scaled predictions back to original units; the first
        # plotted point is the last observed row itself.
        last_original = feature_df.iloc[-1].values
        predictions_original = predictor.scaler.inverse_transform(
            np.vstack([last_data, predictions])
        )
        all_predictions = np.vstack([last_original, predictions_original[1:]])

        # Date labels: today plus one label per forecast step.
        # NOTE(review): these are calendar days, so weekends are included.
        dates = [datetime.now() + timedelta(days=i) for i in range(forecast_days + 1)]
        date_labels = [d.strftime('%m/%d') for d in dates]

        # Plot one line per selected feature.
        fig, ax = plt.subplots(figsize=(14, 7))
        palette = ['#FF9999', '#66B2FF']
        for i, feature in enumerate(selected_features):
            # Use the fixed palette for the first features and let matplotlib
            # choose afterwards. Zipping features with the two-colour palette
            # would silently drop every feature after the second.
            color = palette[i] if i < len(palette) else None
            ax.plot(date_labels, all_predictions[:, i], label=f'預測{feature}',
                    marker='o', color=color, linewidth=2)
            for j, value in enumerate(all_predictions[:, i]):
                ax.annotate(f'{value:.2f}', (date_labels[j], value),
                            textcoords="offset points", xytext=(0, 10),
                            ha='center', va='bottom')
        ax.set_title(f'{stock_item} 股價預測 (未來{forecast_days}天)', pad=20, fontsize=14)
        ax.set_xlabel('日期', labelpad=10)
        ax.set_ylabel('股價', labelpad=10)
        ax.legend(loc='upper left', bbox_to_anchor=(1, 1))
        ax.grid(True, linestyle='--', alpha=0.7)
        plt.tight_layout()
        return gr.update(value=fig), "預測成功"
    except Exception as e:
        logging.error(f"預測過程發生錯誤: {str(e)}")
        return gr.update(value=None), f"預測過程發生錯誤: {str(e)}"
    finally:
        # Unregister the figure from pyplot so a long-running server does not
        # accumulate one open figure per request. The Figure object itself
        # stays usable by whoever holds the returned reference.
        if fig is not None:
            plt.close(fig)
# Initialisation: configure the plot font, then scrape the category tree once
# at import time; the dropdown choices are the main-category names.
setup_font()
category_dict = fetch_stock_categories()
categories = list(category_dict)
# Gradio interface
# NOTE(review): the original indentation was lost; the nesting below (inputs
# in a single column, plot in its own row, event bindings inside the Blocks
# context) is a reconstruction — confirm against the deployed layout.
with gr.Blocks() as demo:
    gr.Markdown("# 台股預測系統")
    with gr.Row():
        with gr.Column():
            # Level 1: main category, populated from the start-up scrape.
            category_dropdown = gr.Dropdown(
                choices=categories,
                label="產業類別",
                value=None
            )
            # Level 2: sub-category, filled by update_category.
            stock_dropdown = gr.Dropdown(
                choices=[],
                label="類股",
                value=None
            )
            # Level 3: individual stock, filled by update_stock.
            stock_item_dropdown = gr.Dropdown(
                choices=[],
                label="股票",
                value=None
            )
            # History window handed to predict_stock as `period`.
            period_dropdown = gr.Dropdown(
                choices=["1y", "6mo", "3mo", "1mo"],
                label="抓取時間範圍",
                value="1y"
            )
            # Columns handed to predict_stock as `selected_features`.
            features_checkbox = gr.CheckboxGroup(
                choices=['Open', 'High', 'Low', 'Close', 'Adj Close', 'Volume'],
                label="選擇要用於預測的特徵",
                value=['Open', 'Close']
            )
            predict_button = gr.Button("開始預測", variant="primary")
            status_output = gr.Textbox(label="狀態", interactive=False)
    with gr.Row():
        stock_plot = gr.Plot(label="股價預測圖")
    # Event bindings
    # Changing the main category refreshes the sub-category list and resets
    # everything downstream of it.
    category_dropdown.change(
        update_category,
        inputs=[category_dropdown],
        outputs=[stock_dropdown, stock_item_dropdown, stock_plot, status_output]
    )
    # Changing the sub-category refreshes the stock list and clears the
    # plot and status.
    stock_dropdown.change(
        update_stock,
        inputs=[category_dropdown, stock_dropdown],
        outputs=[stock_item_dropdown, stock_plot, status_output]
    )
    # The button runs training + forecasting and shows the plot and status.
    predict_button.click(
        predict_stock,
        inputs=[category_dropdown, stock_dropdown, stock_item_dropdown, period_dropdown, features_checkbox],
        outputs=[stock_plot, status_output]
    )
# Launch the application when run as a script; share=False is passed so no
# public share link is requested.
if __name__ == "__main__":
    demo.launch(share=False)