AlanRex commited on
Commit
5777f47
·
verified ·
1 Parent(s): 3311f4d

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +833 -0
model_predictor.py CHANGED
@@ -0,0 +1,833 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """model_predictor.ipynb
3
+
4
+ Automatically generated by Colab.
5
+
6
+ Original file is located at
7
+ https://colab.research.google.com/drive/1pIuCvafVPCRzTLojc-rZH_MFKsxMam2L
8
+ """
9
+
10
+ # model_predictor.py
11
+ # 深度學習股價預測模型 - 適用於 HUGGING_FACE_V4.2
12
+
13
+ import os
14
+ import numpy as np
15
+ import pandas as pd
16
+ import yfinance as yf
17
+ from datetime import datetime, timedelta
18
+ import warnings
19
+ warnings.filterwarnings('ignore')
20
+
21
+ # TensorFlow/Keras 相關
22
+ try:
23
+ import tensorflow as tf
24
+ from tensorflow.keras.models import Sequential, load_model
25
+ from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization, LeakyReLU
26
+ from tensorflow.keras.optimizers import Adam
27
+ from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
28
+ from tensorflow.keras.regularizers import l2
29
+ from sklearn.preprocessing import MinMaxScaler, RobustScaler
30
+ from sklearn.model_selection import train_test_split
31
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
32
+ TENSORFLOW_AVAILABLE = True
33
+ except ImportError:
34
+ TENSORFLOW_AVAILABLE = False
35
+
36
+ # 設定隨機種子以確保結果可重現
37
+ if TENSORFLOW_AVAILABLE:
38
+ tf.random.set_seed(42)
39
+ np.random.seed(42)
40
+
41
+ class StockPredictor:
42
+ """股價預測模型類別"""
43
+
44
+ def __init__(self):
45
+ self.model = None
46
+ self.feature_scaler = None
47
+ self.target_scalers = {} # 為每個目標變數建立獨立的縮放器
48
+ self.feature_columns = [
49
+ 'volume', 'rate', 'DJI', 'NAS', 'SOX', 'S&P_500', 'TSM_ADR',
50
+ 'RSI', 'MACD', 'MACDsign', 'MACDvol', 'K', 'D',
51
+ '+DI', '-DI', 'ADX', 'business_climate', 'PMI'
52
+ ]
53
+ self.target_columns = [
54
+ 'close_1d', 'close_5d', 'close_10d', 'close_20d', 'close_60d'
55
+ ]
56
+ self.sequence_length = 60 # 使用60天的歷史數據
57
+ self.model_path = 'lstm_stock_model.h5'
58
+ self.scalers_path = 'scalers.npz'
59
+
60
+ def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
61
+ """從yfinance獲取股市數據"""
62
+ try:
63
+ # 台積電 (2330.TW) 作為主要目標股票
64
+ taiwan_stock = yf.Ticker('2314.TW')
65
+ taiwan_data = taiwan_stock.history(start=start_date, end=end_date)
66
+
67
+ if taiwan_data.empty:
68
+ print("警告: 無法獲取台灣股市數據")
69
+ return None
70
+
71
+ # 獲取美國市場數據
72
+ symbols = {
73
+ 'DJI': '^DJI',
74
+ 'NAS': '^IXIC',
75
+ 'SOX': '^SOX',
76
+ 'S&P_500': '^GSPC',
77
+ 'TSM_ADR': 'TSM'
78
+ }
79
+
80
+ us_data = {}
81
+ for name, symbol in symbols.items():
82
+ try:
83
+ ticker = yf.Ticker(symbol)
84
+ data = ticker.history(start=start_date, end=end_date)
85
+ if not data.empty:
86
+ us_data[name] = data['Close']
87
+ else:
88
+ print(f"警告: 無法獲取 {name} 數據")
89
+ except Exception as e:
90
+ print(f"獲取 {name} 數據時發生錯誤: {e}")
91
+
92
+ # 合併數據
93
+ main_df = pd.DataFrame(index=taiwan_data.index)
94
+ main_df['close'] = taiwan_data['Close']
95
+ main_df['volume'] = taiwan_data['Volume']
96
+
97
+ # 計算報酬率
98
+ main_df['rate'] = main_df['close'].pct_change()
99
+
100
+ # 添加美國市場數據
101
+ for name, data in us_data.items():
102
+ # 重新索引以匹配台灣股市交易日
103
+ main_df[name] = data.reindex(main_df.index, method='ffill')
104
+
105
+ return main_df
106
+
107
+ except Exception as e:
108
+ print(f"獲取yfinance數據時發生錯誤: {e}")
109
+ return None
110
+
111
+ def load_external_data(self):
112
+ """載入外部經濟數據"""
113
+ business_climate = pd.DataFrame()
114
+ pmi_data = pd.DataFrame()
115
+
116
+ # 載入景氣燈號數據
117
+ try:
118
+ if os.path.exists('business_climate.csv'):
119
+ business_climate = pd.read_csv('business_climate.csv')
120
+ business_climate['Date'] = pd.to_datetime(business_climate['Date'])
121
+ business_climate.set_index('Date', inplace=True)
122
+ print("成功載入景氣燈號數據")
123
+ except Exception as e:
124
+ print(f"載入景氣燈號數據失敗: {e}")
125
+
126
+ # 載入PMI數據
127
+ try:
128
+ if os.path.exists('taiwan_pmi.csv'):
129
+ pmi_data = pd.read_csv('taiwan_pmi.csv')
130
+ pmi_data['Date'] = pd.to_datetime(pmi_data['Date'])
131
+ pmi_data.set_index('Date', inplace=True)
132
+ print("成功載入PMI數據")
133
+ except Exception as e:
134
+ print(f"載入PMI數據失敗: {e}")
135
+
136
+ return business_climate, pmi_data
137
+
138
+ def calculate_technical_indicators(self, df):
139
+ """計算技術指標"""
140
+ try:
141
+ # RSI
142
+ delta = df['close'].diff()
143
+ gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
144
+ loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
145
+ rs = gain / loss
146
+ df['RSI'] = 100 - (100 / (1 + rs))
147
+
148
+ # MACD
149
+ exp1 = df['close'].ewm(span=12).mean()
150
+ exp2 = df['close'].ewm(span=26).mean()
151
+ df['MACD'] = exp1 - exp2
152
+ df['MACDsign'] = df['MACD'].ewm(span=9).mean()
153
+ df['MACDvol'] = df['MACD'] - df['MACDsign']
154
+
155
+ # KD指標
156
+ low_min = df['close'].rolling(window=9).min()
157
+ high_max = df['close'].rolling(window=9).max()
158
+ rsv = (df['close'] - low_min) / (high_max - low_min) * 100
159
+ df['K'] = rsv.ewm(com=2).mean()
160
+ df['D'] = df['K'].ewm(com=2).mean()
161
+
162
+ # DMI指標 (簡化版本,使用close價格)
163
+ df['high_low_diff'] = df['close'].rolling(2).max() - df['close'].rolling(2).min()
164
+ df['+DI'] = df['high_low_diff'].rolling(14).mean()
165
+ df['-DI'] = df['high_low_diff'].rolling(14).std()
166
+ df['ADX'] = (df['+DI'] + df['-DI']).rolling(14).mean()
167
+
168
+ # 清理臨時欄位
169
+ df.drop(['high_low_diff'], axis=1, inplace=True)
170
+
171
+ return df
172
+
173
+ except Exception as e:
174
+ print(f"計算技術指標時發生錯誤: {e}")
175
+ return df
176
+
177
+ def create_sample_data(self, days=500):
178
+ """創建示例數據用於訓練(當CSV載入失敗時的後備方案)"""
179
+ try:
180
+ print("創建示例數據進行訓練...")
181
+
182
+ # 獲取台積電數據作為基礎
183
+ taiwan_data = self.fetch_yfinance_data(
184
+ start_date=(datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d'),
185
+ end_date=datetime.now().strftime('%Y-%m-%d')
186
+ )
187
+
188
+ if taiwan_data is None or taiwan_data.empty:
189
+ print("無法獲取示例數據")
190
+ return None
191
+
192
+ # 確保有基本的close和volume數據
193
+ if 'close' not in taiwan_data.columns or 'volume' not in taiwan_data.columns:
194
+ print("示例數據缺少必要欄位")
195
+ return None
196
+
197
+ # 計算技術指標
198
+ taiwan_data = self.calculate_technical_indicators(taiwan_data)
199
+
200
+ # 添加經濟指標(使用固定值)
201
+ taiwan_data['business_climate'] = 25.0
202
+ taiwan_data['PMI'] = 50.0
203
+
204
+ # 確保所有特徵欄位存在
205
+ for feature in self.feature_columns:
206
+ if feature not in taiwan_data.columns:
207
+ taiwan_data[feature] = 0.0
208
+
209
+ # 計算未來價格目標
210
+ for days in [1, 5, 10, 20, 60]:
211
+ taiwan_data[f'close_{days}d'] = taiwan_data['close'].shift(-days)
212
+
213
+ # 移除缺失值
214
+ taiwan_data = taiwan_data.dropna()
215
+
216
+ if len(taiwan_data) < 100:
217
+ print("示例數據不足")
218
+ return None
219
+
220
+ print(f"成功創建示例數據: {taiwan_data.shape}")
221
+ return taiwan_data
222
+
223
+ except Exception as e:
224
+ print(f"創建示例數據時發生錯誤: {e}")
225
+ return None
226
+ """調試CSV檔案結構"""
227
+ try:
228
+ print(f"\n=== 調試CSV檔案: {csv_path} ===")
229
+
230
+ # 讀取前幾行看看結構
231
+ with open(csv_path, 'r', encoding='utf-8') as f:
232
+ first_lines = [f.readline().strip() for _ in range(5)]
233
+
234
+ print("前5行原始內容:")
235
+ for i, line in enumerate(first_lines):
236
+ print(f"第{i+1}行: {line[:100]}...") # 只顯示前100個字符
237
+
238
+ # 嘗試不同的編碼和分隔符
239
+ encodings = ['utf-8', 'utf-8-sig', 'latin-1', 'cp1252']
240
+ separators = [',', ';', '\t', '|']
241
+
242
+ for encoding in encodings:
243
+ for sep in separators:
244
+ try:
245
+ df_test = pd.read_csv(csv_path, encoding=encoding, sep=sep, nrows=5)
246
+ if len(df_test.columns) > 5: # 如果有合理的欄位數量
247
+ print(f"\n成功讀取 (編碼: {encoding}, 分隔符: '{sep}'):")
248
+ print(f"欄位: {list(df_test.columns)}")
249
+ print(f"數據形狀: {df_test.shape}")
250
+ return encoding, sep
251
+ except:
252
+ continue
253
+
254
+ print("無法找到合適的讀取參數")
255
+ return None, None
256
+
257
+ except Exception as e:
258
+ print(f"調試CSV檔案時發生錯誤: {e}")
259
+ return None, None
260
+
261
+ def prepare_training_data(self, csv_path=None):
262
+ """準備訓練數據"""
263
+ try:
264
+ if csv_path and os.path.exists(csv_path):
265
+ # 先調試CSV檔案
266
+ encoding, separator = self.debug_csv_file(csv_path)
267
+
268
+ # 如果提供了CSV檔案,直接載入
269
+ print(f"\n從 {csv_path} 載入數據...")
270
+
271
+ # 使用找到的最佳參數讀取
272
+ read_params = {}
273
+ if encoding:
274
+ read_params['encoding'] = encoding
275
+ if separator and separator != ',':
276
+ read_params['sep'] = separator
277
+
278
+ df = pd.read_csv(csv_path, **read_params)
279
+
280
+ # 檢查CSV檔案結構
281
+ print(f"CSV檔案欄位: {list(df.columns)}")
282
+ print(f"數據形狀: {df.shape}")
283
+ print(f"前5行數據:")
284
+ print(df.head())
285
+
286
+ # 處理日期欄位
287
+ date_columns = ['Date', 'date', 'DATE', 'Unnamed: 0']
288
+ date_col = None
289
+ for col in date_columns:
290
+ if col in df.columns:
291
+ date_col = col
292
+ break
293
+
294
+ if date_col:
295
+ print(f"使用日期欄位: {date_col}")
296
+ df[date_col] = pd.to_datetime(df[date_col])
297
+ df.set_index(date_col, inplace=True)
298
+ elif df.index.dtype == 'object':
299
+ df.index = pd.to_datetime(df.index)
300
+
301
+ print(f"處理日期後的數據形狀: {df.shape}")
302
+ print(f"日期範圍: {df.index.min()} 到 {df.index.max()}")
303
+
304
+ else:
305
+ # 從yfinance和外部檔案獲取數據
306
+ print("從yfinance獲取數據...")
307
+ df = self.fetch_yfinance_data()
308
+ if df is None:
309
+ return None, None, None, None
310
+
311
+ # 計算技術指標
312
+ df = self.calculate_technical_indicators(df)
313
+
314
+ # 載入外部經濟數據
315
+ business_climate, pmi_data = self.load_external_data()
316
+
317
+ # 合併外部數據
318
+ if not business_climate.empty:
319
+ df['business_climate'] = business_climate['Index'].reindex(
320
+ df.index, method='ffill'
321
+ )
322
+ else:
323
+ df['business_climate'] = 25.0 # 預設值
324
+
325
+ if not pmi_data.empty:
326
+ df['PMI'] = pmi_data['Index'].reindex(df.index, method='ffill')
327
+ else:
328
+ df['PMI'] = 50.0 # 預設值
329
+
330
+ # 檢查並映射欄位名稱
331
+ column_mapping = {
332
+ # 可能的volume欄位名稱
333
+ 'Volume': 'volume', 'vol': 'volume', 'VOLUME': 'volume',
334
+ # 可能的close欄位名稱
335
+ 'Close': 'close', 'close_price': 'close', 'CLOSE': 'close', 'price': 'close',
336
+ # 可能的rate欄位名稱
337
+ 'Rate': 'rate', 'return': 'rate', 'pct_change': 'rate', 'RATE': 'rate',
338
+ # 美股指數
339
+ 'DJI': 'DJI', 'DOW': 'DJI', 'dow': 'DJI',
340
+ 'NAS': 'NAS', 'NASDAQ': 'NAS', 'nasdaq': 'NAS',
341
+ 'SOX': 'SOX', 'sox': 'SOX',
342
+ 'S&P_500': 'S&P_500', 'SP500': 'S&P_500', 'sp500': 'S&P_500',
343
+ 'TSM_ADR': 'TSM_ADR', 'TSM': 'TSM_ADR', 'tsm': 'TSM_ADR',
344
+ # 技術指標
345
+ 'rsi': 'RSI', 'macd': 'MACD', 'macdsign': 'MACDsign', 'macdvol': 'MACDvol',
346
+ 'k': 'K', 'd': 'D', '+di': '+DI', '-di': '-DI', 'adx': 'ADX',
347
+ # 經濟指標
348
+ 'Business_Climate': 'business_climate', 'business_climate_index': 'business_climate',
349
+ 'pmi': 'PMI', 'PMI_Index': 'PMI'
350
+ }
351
+
352
+ # 應用欄位映射
353
+ df = df.rename(columns=column_mapping)
354
+ print(f"映射後的欄位: {list(df.columns)}")
355
+
356
+ # 如果沒有close欄位但有其他價格欄位,嘗試使用
357
+ if 'close' not in df.columns:
358
+ price_candidates = ['Close', 'Price', 'CLOSE', 'close_price']
359
+ for candidate in price_candidates:
360
+ if candidate in df.columns:
361
+ df['close'] = df[candidate]
362
+ print(f"使用 {candidate} 作為 close 價格")
363
+ break
364
+
365
+ # 計算missing的技術指標(如果數據中沒有)
366
+ if 'close' in df.columns:
367
+ if 'rate' not in df.columns:
368
+ df['rate'] = df['close'].pct_change()
369
+ print("計算了price return rate")
370
+
371
+ # 如果缺少技術指標,計算它們
372
+ if 'RSI' not in df.columns:
373
+ df = self.calculate_technical_indicators(df)
374
+ print("計算了技術指標")
375
+
376
+ # 計算未來價格目標
377
+ if 'close' in df.columns:
378
+ for days in [1, 5, 10, 20, 60]:
379
+ df[f'close_{days}d'] = df['close'].shift(-days)
380
+ print("計算了未來價格目標")
381
+ else:
382
+ print("錯誤: 找不到價格數據,無法計算目標變數")
383
+ return None, None, None, None
384
+
385
+ print(f"計算目標變數後的數據形狀: {df.shape}")
386
+
387
+ # 移除缺失值
388
+ original_len = len(df)
389
+ df = df.dropna()
390
+ print(f"移除缺失值: {original_len} -> {len(df)} 行")
391
+
392
+ if df.empty:
393
+ print("錯誤: 處理後的數據集為空")
394
+ print("可能原因:")
395
+ print("1. 所有數據都有缺失值")
396
+ print("2. 日期格式不正確")
397
+ print("3. 欄位名稱不匹配")
398
+ return None, None, None, None
399
+
400
+ # 確保所有需要的欄位都存在
401
+ missing_features = set(self.feature_columns) - set(df.columns)
402
+ if missing_features:
403
+ print(f"警告: 缺少特徵欄位: {missing_features}")
404
+ # 為缺少的特徵填充預設值
405
+ for feature in missing_features:
406
+ if feature == 'business_climate':
407
+ df[feature] = 25.0 # 景氣燈號預設值
408
+ elif feature == 'PMI':
409
+ df[feature] = 50.0 # PMI預設值
410
+ else:
411
+ df[feature] = 0.0
412
+ print("已填充缺失的特徵欄位")
413
+
414
+ missing_targets = set(self.target_columns) - set(df.columns)
415
+ if missing_targets:
416
+ print(f"錯誤: 缺少目標欄位: {missing_targets}")
417
+ return None, None, None, None
418
+
419
+ # 提取特徵和目標變數
420
+ X = df[self.feature_columns].values
421
+ y = df[self.target_columns].values
422
+
423
+ print(f"數據形狀: X={X.shape}, y={y.shape}")
424
+ print(f"數據日期範圍: {df.index.min()} 到 {df.index.max()}")
425
+
426
+ return X, y, df.index, df
427
+
428
+ except Exception as e:
429
+ print(f"準備訓練數據時發生錯誤: {e}")
430
+ return None, None, None, None
431
+
432
+ def create_sequences(self, X, y):
433
+ """創建時間序列序列"""
434
+ X_seq, y_seq = [], []
435
+
436
+ for i in range(self.sequence_length, len(X)):
437
+ X_seq.append(X[i-self.sequence_length:i])
438
+ y_seq.append(y[i])
439
+
440
+ return np.array(X_seq), np.array(y_seq)
441
+
442
+ def build_model(self, input_shape, output_shape):
443
+ """建構LSTM模型"""
444
+ if not TENSORFLOW_AVAILABLE:
445
+ raise ImportError("TensorFlow未安裝,無法建立模型")
446
+
447
+ model = Sequential([
448
+ # 第一層LSTM
449
+ LSTM(128, return_sequences=True, input_shape=input_shape,
450
+ kernel_regularizer=l2(0.001)),
451
+ BatchNormalization(),
452
+ Dropout(0.2),
453
+
454
+ # 第二層LSTM
455
+ LSTM(64, return_sequences=True, kernel_regularizer=l2(0.001)),
456
+ BatchNormalization(),
457
+ Dropout(0.2),
458
+
459
+ # 第三層LSTM
460
+ LSTM(32, return_sequences=False, kernel_regularizer=l2(0.001)),
461
+ BatchNormalization(),
462
+ Dropout(0.2),
463
+
464
+ # 全連接層
465
+ Dense(64, kernel_regularizer=l2(0.001)),
466
+ LeakyReLU(alpha=0.1),
467
+ BatchNormalization(),
468
+ Dropout(0.3),
469
+
470
+ Dense(32, kernel_regularizer=l2(0.001)),
471
+ LeakyReLU(alpha=0.1),
472
+ Dropout(0.2),
473
+
474
+ # 輸出層
475
+ Dense(output_shape, activation='linear')
476
+ ])
477
+
478
+ # 編譯模型
479
+ optimizer = Adam(learning_rate=0.001, clipnorm=1.0)
480
+ model.compile(
481
+ optimizer=optimizer,
482
+ loss='huber', # 對異常值較不敏感
483
+ metrics=['mse', 'mae']
484
+ )
485
+
486
+ return model
487
+
488
+ def train_model(self, csv_path=None):
489
+ """訓練模型"""
490
+ if not TENSORFLOW_AVAILABLE:
491
+ print("錯誤: TensorFlow未安裝,無法訓練模型")
492
+ return False
493
+
494
+ print("開始準備訓練數據...")
495
+ X, y, dates, df = self.prepare_training_data(csv_path)
496
+
497
+ # 如果CSV載入失敗,嘗試使用示例數據
498
+ if (X is None or y is None) and csv_path:
499
+ print("CSV載入失敗,嘗試創建示例數據...")
500
+ df = self.create_sample_data()
501
+ if df is not None:
502
+ X = df[self.feature_columns].values
503
+ y = df[self.target_columns].values
504
+ dates = df.index
505
+ print("使用示例數據繼續訓練")
506
+
507
+ if X is None or y is None:
508
+ print("錯誤: 無法準備訓練數據")
509
+ return False
510
+
511
+ # 檢查數據質量
512
+ if len(X) < 100:
513
+ print(f"警告: 訓練數據過少 ({len(X)} 樣本),建議至少100個樣本")
514
+ return False
515
+
516
+ print("正在縮放數據...")
517
+ # 縮放特徵
518
+ self.feature_scaler = RobustScaler()
519
+ X_scaled = self.feature_scaler.fit_transform(X)
520
+
521
+ # 為每個目標變數建立獨立的縮放器
522
+ y_scaled = np.zeros_like(y)
523
+ for i, target in enumerate(self.target_columns):
524
+ scaler = RobustScaler()
525
+ y_scaled[:, i:i+1] = scaler.fit_transform(y[:, i:i+1])
526
+ self.target_scalers[target] = scaler
527
+
528
+ print("正在創建時間序列...")
529
+ X_seq, y_seq = self.create_sequences(X_scaled, y_scaled)
530
+
531
+ if len(X_seq) == 0:
532
+ print(f"錯誤: 序列長度不足,需要至少{self.sequence_length + 1}個數據點")
533
+ return False
534
+
535
+ print(f"序列形狀: X_seq={X_seq.shape}, y_seq={y_seq.shape}")
536
+
537
+ # 分割訓練和驗證集
538
+ split_idx = int(len(X_seq) * 0.8) # 使用時間順序分割而不是隨機分割
539
+ X_train, X_val = X_seq[:split_idx], X_seq[split_idx:]
540
+ y_train, y_val = y_seq[:split_idx], y_seq[split_idx:]
541
+
542
+ print(f"訓練集大小: {X_train.shape}, 驗證集大小: {X_val.shape}")
543
+
544
+ # 建立模型
545
+ print("正在建立模型...")
546
+ input_shape = (X_seq.shape[1], X_seq.shape[2])
547
+ output_shape = y_seq.shape[1]
548
+
549
+ self.model = self.build_model(input_shape, output_shape)
550
+ print(f"模型架構: 輸入={input_shape}, 輸出={output_shape}")
551
+
552
+ # 設定回調函數
553
+ callbacks = [
554
+ EarlyStopping(
555
+ monitor='val_loss',
556
+ patience=15,
557
+ restore_best_weights=True,
558
+ verbose=1
559
+ ),
560
+ ReduceLROnPlateau(
561
+ monitor='val_loss',
562
+ factor=0.5,
563
+ patience=8,
564
+ min_lr=1e-6,
565
+ verbose=1
566
+ )
567
+ ]
568
+
569
+ # 訓練模型
570
+ print("開始訓練模型...")
571
+ try:
572
+ history = self.model.fit(
573
+ X_train, y_train,
574
+ validation_data=(X_val, y_val),
575
+ epochs=50, # 減少epoch數量以加快訓練
576
+ batch_size=min(32, len(X_train) // 4), # 根據數據大小調整batch size
577
+ callbacks=callbacks,
578
+ verbose=1
579
+ )
580
+ except Exception as e:
581
+ print(f"訓練過程中發生錯誤: {e}")
582
+ return False
583
+
584
+ # 評估模型
585
+ print("\n評估模型性能...")
586
+ try:
587
+ train_loss = self.model.evaluate(X_train, y_train, verbose=0)
588
+ val_loss = self.model.evaluate(X_val, y_val, verbose=0)
589
+
590
+ print(f"訓練集損失: {train_loss[0]:.4f}")
591
+ print(f"驗證集損失: {val_loss[0]:.4f}")
592
+ except Exception as e:
593
+ print(f"評估過程中發生錯誤: {e}")
594
+
595
+ # 儲存模型和縮放器
596
+ self.save_model()
597
+
598
+ return True
599
+
600
+ def save_model(self):
601
+ """儲存模型和縮放器"""
602
+ try:
603
+ if self.model:
604
+ self.model.save(self.model_path)
605
+ print(f"模型已儲存至: {self.model_path}")
606
+
607
+ # 儲存縮放器
608
+ scalers_dict = {'feature_scaler': self.feature_scaler}
609
+ scalers_dict.update(self.target_scalers)
610
+
611
+ # 將sklearn縮放器轉換為可序列化的格式
612
+ scalers_data = {}
613
+ for name, scaler in scalers_dict.items():
614
+ if hasattr(scaler, 'scale_'):
615
+ scalers_data[f'{name}_scale'] = scaler.scale_
616
+ scalers_data[f'{name}_center'] = scaler.center_
617
+
618
+ np.savez(self.scalers_path, **scalers_data)
619
+ print(f"縮放器已儲存至: {self.scalers_path}")
620
+
621
+ except Exception as e:
622
+ print(f"儲存模型時發生錯誤: {e}")
623
+
624
+ def load_model(self):
625
+ """載入模型和縮放器"""
626
+ try:
627
+ if os.path.exists(self.model_path) and TENSORFLOW_AVAILABLE:
628
+ self.model = load_model(self.model_path)
629
+ print("模型載入成功")
630
+
631
+ # 載入縮放器
632
+ if os.path.exists(self.scalers_path):
633
+ scalers_data = np.load(self.scalers_path)
634
+
635
+ # 重建特徵縮放器
636
+ if 'feature_scaler_scale' in scalers_data:
637
+ self.feature_scaler = RobustScaler()
638
+ self.feature_scaler.scale_ = scalers_data['feature_scaler_scale']
639
+ self.feature_scaler.center_ = scalers_data['feature_scaler_center']
640
+
641
+ # 重建目標縮放器
642
+ for target in self.target_columns:
643
+ scale_key = f'{target}_scale'
644
+ center_key = f'{target}_center'
645
+ if scale_key in scalers_data:
646
+ scaler = RobustScaler()
647
+ scaler.scale_ = scalers_data[scale_key]
648
+ scaler.center_ = scalers_data[center_key]
649
+ self.target_scalers[target] = scaler
650
+
651
+ print("縮放器載入成功")
652
+
653
+ return True
654
+ else:
655
+ print("模型文件不存在或TensorFlow未安裝")
656
+ return False
657
+
658
+ except Exception as e:
659
+ print(f"載入模型時發生錯誤: {e}")
660
+ return False
661
+
662
+ # 全域預測器實例
663
+ _predictor = None
664
+
665
+ def get_predictor():
666
+ """獲取預測器實例"""
667
+ global _predictor
668
+ if _predictor is None:
669
+ _predictor = StockPredictor()
670
+ _predictor.load_model()
671
+ return _predictor
672
+
673
+ def advanced_lstm_predict(predict_days):
674
+ """
675
+ 進階LSTM預測函數 - 與main程式的介面
676
+
677
+ Args:
678
+ predict_days: 預測天數 (1, 5, 10, 20, 60)
679
+
680
+ Returns:
681
+ dict: 包含predicted_price, change_pct, confidence的字典
682
+ None: 如果預測失敗
683
+ """
684
+ try:
685
+ predictor = get_predictor()
686
+
687
+ if predictor.model is None:
688
+ print("模型未載入,無法進行預測")
689
+ return None
690
+
691
+ # 獲取最新數據進行預測
692
+ current_data = predictor.fetch_yfinance_data(
693
+ start_date=(datetime.now() - timedelta(days=90)).strftime('%Y-%m-%d'),
694
+ end_date=datetime.now().strftime('%Y-%m-%d')
695
+ )
696
+
697
+ if current_data is None or len(current_data) < predictor.sequence_length:
698
+ print("無法獲取足夠的當前數據進行預測")
699
+ return None
700
+
701
+ # 計算技術指標
702
+ current_data = predictor.calculate_technical_indicators(current_data)
703
+
704
+ # 載入外部數據
705
+ business_climate, pmi_data = predictor.load_external_data()
706
+
707
+ # 合併外部數據
708
+ if not business_climate.empty:
709
+ current_data['business_climate'] = business_climate['Index'].reindex(
710
+ current_data.index, method='ffill'
711
+ ).fillna(25.0)
712
+ else:
713
+ current_data['business_climate'] = 25.0
714
+
715
+ if not pmi_data.empty:
716
+ current_data['PMI'] = pmi_data['Index'].reindex(
717
+ current_data.index, method='ffill'
718
+ ).fillna(50.0)
719
+ else:
720
+ current_data['PMI'] = 50.0
721
+
722
+ # 填補缺失的特徵
723
+ for feature in predictor.feature_columns:
724
+ if feature not in current_data.columns:
725
+ current_data[feature] = 0.0
726
+
727
+ current_data = current_data.dropna()
728
+
729
+ if len(current_data) < predictor.sequence_length:
730
+ print("處理後的數據不足以進行預測")
731
+ return None
732
+
733
+ # 提取特徵並縮放
734
+ X_current = current_data[predictor.feature_columns].values
735
+ X_current_scaled = predictor.feature_scaler.transform(X_current)
736
+
737
+ # 創建序列
738
+ X_seq = X_current_scaled[-predictor.sequence_length:].reshape(
739
+ 1, predictor.sequence_length, len(predictor.feature_columns)
740
+ )
741
+
742
+ # 進行預測
743
+ prediction_scaled = predictor.model.predict(X_seq, verbose=0)
744
+
745
+ # 確定目標欄位索引
746
+ target_map = {1: 'close_1d', 5: 'close_5d', 10: 'close_10d',
747
+ 20: 'close_20d', 60: 'close_60d'}
748
+ target_col = target_map.get(predict_days, 'close_5d')
749
+ target_idx = predictor.target_columns.index(target_col)
750
+
751
+ # 反縮放預測結果
752
+ if target_col in predictor.target_scalers:
753
+ predicted_price = predictor.target_scalers[target_col].inverse_transform(
754
+ prediction_scaled[:, target_idx:target_idx+1]
755
+ )[0, 0]
756
+ else:
757
+ print(f"未找到 {target_col} 的縮放器")
758
+ return None
759
+
760
+ # 計算變化百分比
761
+ current_price = current_data['close'].iloc[-1]
762
+ change_pct = ((predicted_price - current_price) / current_price) * 100
763
+
764
+ # 計算信心度 (簡化版本,基於歷史波動性)
765
+ recent_volatility = current_data['close'].pct_change().tail(20).std()
766
+ confidence = max(0.5, min(0.9, 1 - recent_volatility * 5))
767
+
768
+ return {
769
+ 'predicted_price': predicted_price,
770
+ 'change_pct': change_pct,
771
+ 'confidence': confidence
772
+ }
773
+
774
+ except Exception as e:
775
+ print(f"LSTM預測時發生錯誤: {e}")
776
+ return None
777
+
778
+ def train_model_from_csv(csv_path):
779
+ """從CSV檔案訓練模型的便利函數"""
780
+ predictor = StockPredictor()
781
+ return predictor.train_model(csv_path)
782
+
783
+ if __name__ == "__main__":
784
+ # 測試模式
785
+ print("=== 股價預測模型測試 ===")
786
+
787
+ # 首先檢查TensorFlow是否可用
788
+ if not TENSORFLOW_AVAILABLE:
789
+ print("警告: TensorFlow未安裝,無法使用深度學習功能")
790
+ print("請安裝TensorFlow: pip install tensorflow")
791
+ exit(1)
792
+
793
+ # 檢查是否有CSV檔案
794
+ csv_file = "新期末專案輸入資料20220912-20250909.csv"
795
+
796
+ if os.path.exists(csv_file):
797
+ print(f"找到CSV檔案: {csv_file}")
798
+
799
+ # 先創建預測器並調試CSV
800
+ predictor = StockPredictor()
801
+
802
+ success = predictor.train_model(csv_file)
803
+ if success:
804
+ print("模型訓練完成!")
805
+ else:
806
+ print("CSV訓練失敗,嘗試使用yfinance數據...")
807
+ success = predictor.train_model()
808
+ if success:
809
+ print("使用yfinance數據訓練完成!")
810
+ else:
811
+ print("所有訓練方法都失敗!")
812
+ else:
813
+ print(f"未找到CSV檔案: {csv_file}")
814
+ print("將使用yfinance數據進行訓練...")
815
+ predictor = StockPredictor()
816
+ success = predictor.train_model()
817
+ if success:
818
+ print("模型訓練完成!")
819
+ else:
820
+ print("模型訓練失敗!")
821
+
822
+ # 測試預測
823
+ print("\n=== 測試預測功能 ===")
824
+ test_predictions = [1, 5, 10, 20, 60]
825
+
826
+ for days in test_predictions:
827
+ result = advanced_lstm_predict(days)
828
+ if result:
829
+ print(f"{days}日預測: 價格={result['predicted_price']:.2f}, "
830
+ f"變化={result['change_pct']:+.2f}%, "
831
+ f"信心度={result['confidence']:.1%}")
832
+ else:
833
+ print(f"{days}日預測失敗")