AlanRex committed on
Commit
146b63d
·
verified ·
1 Parent(s): 20e725c

Upload 4 files

Browse files
Files changed (4) hide show
  1. lstm_model.pth +3 -0
  2. model_predictor.py +343 -534
  3. scaler_X.pkl +3 -0
  4. scaler_y.pkl +3 -0
lstm_model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c857878b0e754a7034e1e5c54eb540557264a01dbf050cb11820f7bb43b7ddcc
3
+ size 141685
model_predictor.py CHANGED
@@ -4,569 +4,378 @@
4
  Automatically generated by Colab.
5
 
6
  Original file is located at
7
- https://colab.research.google.com/drive/1CaAPRdPsp3Jt5tQ3BLVcK19euWZmFme5
8
  """
9
 
10
  # model_predictor.py
11
- # 進階LSTM模型預測器,適用於HUGING_FACE_V4.2
 
 
 
 
 
12
 
13
  import os
14
- import numpy as np
15
  import pandas as pd
 
 
 
 
 
 
16
  import yfinance as yf
17
- from datetime import datetime, timedelta
18
- import joblib
19
- from sklearn.preprocessing import StandardScaler, RobustScaler
20
- from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score
21
- import warnings
22
- warnings.filterwarnings('ignore')
23
-
24
- # TensorFlow/Keras 相關
25
- try:
26
- import tensorflow as tf
27
- from tensorflow.keras.models import Sequential, load_model
28
- from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization, GRU, Bidirectional
29
- from tensorflow.keras.optimizers import Adam
30
- from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau, ModelCheckpoint
31
- from tensorflow.keras.regularizers import l1_l2
32
- print("TensorFlow 載入成功")
33
- except ImportError:
34
- print("警告:TensorFlow 未安裝,模型將無法正常運作")
35
- tf = None
36
-
37
class AdvancedStockPredictor:
    """Multi-horizon TAIEX close-price forecaster built on stacked LSTM layers.

    The predictor downloads market history through yfinance, derives
    technical indicators from the TAIEX series, optionally merges two
    indicator CSVs (``business_climate.csv`` and ``taiwan_pmi.csv``), and
    trains a Keras model whose five outputs are the closes 1, 5, 10, 20 and
    60 rows ahead.
    """

    # Markets whose closing price is copied into the TAIEX frame as a feature.
    _EXTERNAL_MARKETS = ('DJI', 'NAS', 'SOX', 'SP500', 'TSM_ADR')
    # Supported prediction horizon -> column index in the model output.
    _HORIZON_INDEX = {1: 0, 5: 1, 10: 2, 20: 3, 60: 4}

    def __init__(self, model_name='taiwan_stock_predictor'):
        """Create an untrained predictor.

        Args:
            model_name: Prefix used for the saved model and scaler files.
        """
        self.model_name = model_name
        self.model = None
        self.scaler_X = RobustScaler()
        self.scaler_y = StandardScaler()
        self.sequence_length = 60  # rows of history fed to the model per sample
        self.feature_names = [
            'volume', 'rate', 'DJI', 'NAS', 'SOX', 'SP500', 'TSM_ADR',
            'RSI', 'MACD', 'MACDsign', 'MACDvol', 'K', 'D',
            '+DI', '-DI', 'ADX', 'business_climate', 'PMI'
        ]
        self.target_names = ['close_1d', 'close_5d', 'close_10d', 'close_20d', 'close_60d']
        self.is_trained = False

    def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
        """Download price history for every tracked symbol.

        Args:
            start_date: First date to request (``YYYY-MM-DD``).
            end_date: End date passed to yfinance (``YYYY-MM-DD``).

        Returns:
            dict mapping a short market name to its history DataFrame. Symbols
            that fail or come back empty are omitted.
        """
        print("正在從 yfinance 獲取數據...")

        symbols = {
            'TAIEX': '^TWII',    # Taiwan weighted index
            'DJI': '^DJI',       # Dow Jones Industrial Average
            'NAS': '^IXIC',      # Nasdaq Composite
            'SOX': '^SOX',       # PHLX Semiconductor index
            'SP500': '^GSPC',    # S&P 500
            'TSM_ADR': 'TSM',    # TSMC ADR
        }

        data_dict = {}
        for name, symbol in symbols.items():
            try:
                hist = yf.Ticker(symbol).history(start=start_date, end=end_date)
            except Exception as e:
                # Best effort: one failing symbol must not abort the others.
                print(f"錯誤:獲取 {name} 數據時發生錯誤: {e}")
                continue
            if hist.empty:
                print(f"警告:無法獲取 {name} 數據")
                continue
            data_dict[name] = hist
            print(f"成功獲取 {name} 數據: {len(hist)} 筆記錄")

        return data_dict

    @staticmethod
    def _load_indicator_csv(path, column, label):
        """Read a (date, value) CSV into a date-indexed, sorted DataFrame.

        Returns None when the file is missing, has fewer than two columns, or
        cannot be parsed.
        """
        try:
            if not os.path.exists(path):
                return None
            frame = pd.read_csv(path)
            if len(frame.columns) < 2:
                return None
            # Keep only the first two columns: assigning two names to a wider
            # frame raises a length-mismatch error.
            frame = frame.iloc[:, :2]
            frame.columns = ['Date', column]
            frame['Date'] = pd.to_datetime(frame['Date'], errors='coerce').dt.tz_localize(None)
            # reindex(method='ffill') later requires a monotonic index.
            frame = frame.dropna(subset=['Date']).set_index('Date').sort_index()
            print(f"成功載入{label}數據: {len(frame)} 筆記錄")
            return frame
        except Exception as e:
            print(f"載入{label}數據時發生錯誤: {e}")
            return None

    def load_economic_data(self):
        """Load the optional business-climate and PMI CSV files.

        Returns:
            dict with the keys ``'business_climate'`` and/or ``'PMI'`` for each
            file that could be loaded.
        """
        sources = (
            ('business_climate.csv', 'business_climate', '景氣燈號'),
            ('taiwan_pmi.csv', 'PMI', ' PMI '),
        )
        economic_data = {}
        for path, column, label in sources:
            frame = self._load_indicator_csv(path, column, label)
            if frame is not None:
                economic_data[column] = frame
        return economic_data

    def calculate_technical_indicators(self, df):
        """Add volume, rate, RSI, MACD, KD and DMI/ADX columns to a price frame.

        Args:
            df: Frame with ``Close``, ``High``, ``Low`` and ``Volume`` columns.

        Returns:
            A copy of ``df`` with the indicator columns added. ``df`` itself is
            returned when it is empty; an empty DataFrame is returned when
            fewer than 60 rows are available or a computation fails.
        """
        if df.empty:
            return df

        # Too little history to compute the indicators.
        if len(df) < 60:
            return pd.DataFrame()

        try:
            df = df.copy()  # do not modify the caller's frame
            close, high, low = df['Close'], df['High'], df['Low']

            df['volume'] = df['Volume']
            df['rate'] = close.pct_change()

            # RSI over a 14-row window
            delta = close.diff()
            gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
            rs = gain / loss
            df['RSI'] = 100 - (100 / (1 + rs))

            # MACD (12/26 EMA difference, 9-span signal line)
            df['MACD'] = close.ewm(span=12).mean() - close.ewm(span=26).mean()
            df['MACDsign'] = df['MACD'].ewm(span=9).mean()
            df['MACDvol'] = df['MACD'] - df['MACDsign']

            # KD stochastic oscillator over a 9-row window
            low_min = low.rolling(window=9).min()
            high_max = high.rolling(window=9).max()
            rsv = (close - low_min) / (high_max - low_min) * 100
            df['K'] = rsv.ewm(com=2).mean()
            df['D'] = df['K'].ewm(com=2).mean()

            # DMI / ADX
            prev_close = close.shift(1)
            df['up_move'] = high - high.shift(1)
            df['down_move'] = low.shift(1) - low
            df['+DM'] = np.where((df['up_move'] > df['down_move']) & (df['up_move'] > 0), df['up_move'], 0)
            df['-DM'] = np.where((df['down_move'] > df['up_move']) & (df['down_move'] > 0), df['down_move'], 0)
            df['TR'] = np.max([high - low,
                               abs(high - prev_close),
                               abs(low - prev_close)], axis=0)

            smoothed_tr = df['TR'].ewm(com=13).mean()
            df['+DI'] = (df['+DM'].ewm(com=13).mean() / smoothed_tr) * 100
            df['-DI'] = (df['-DM'].ewm(com=13).mean() / smoothed_tr) * 100
            df['DX'] = abs(df['+DI'] - df['-DI']) / (df['+DI'] + df['-DI']) * 100
            df['ADX'] = df['DX'].ewm(com=13).mean()

        except Exception as e:
            print(f"計算技術指標時發生錯誤: {e}")
            return pd.DataFrame()

        return df

    def _merge_external_data(self, main_df, market_data, economic_data):
        """Add foreign-market closes and economic indicators to ``main_df`` in place."""
        for name in self._EXTERNAL_MARKETS:
            data = market_data.get(name)
            if data is None or data.empty:
                continue
            closes = data['Close'].copy()
            # Drop timezone info so the index aligns with the TAIEX index.
            closes.index = closes.index.tz_localize(None)
            main_df[name] = closes.reindex(main_df.index)

        for name in ('business_climate', 'PMI'):
            if name in economic_data:
                # Carry the last published value forward to each row.
                main_df[name] = economic_data[name][name].reindex(main_df.index, method='ffill')

    def prepare_training_data(self):
        """Build the feature matrix and target matrix used for training.

        Returns:
            ``(X, y)`` as NumPy arrays, or ``(None, None)`` when the TAIEX data
            is unavailable, the indicators cannot be computed, or too few
            complete rows remain.
        """
        print("開始準備訓練數據...")

        market_data = self.fetch_yfinance_data()
        economic_data = self.load_economic_data()

        if 'TAIEX' not in market_data:
            print("錯誤:無法獲取台股指數數據")
            return None, None

        # TAIEX is the primary series; everything else is aligned to it.
        main_df = market_data['TAIEX'].copy()
        main_df.index = main_df.index.tz_localize(None)
        main_df = self.calculate_technical_indicators(main_df)

        if main_df.empty:
            print("錯誤:技術指標計算失敗")
            return None, None

        self._merge_external_data(main_df, market_data, economic_data)

        # Future closes are the training targets.
        close_prices = main_df['Close']
        for days in (1, 5, 10, 20, 60):
            main_df[f'close_{days}d'] = close_prices.shift(-days)

        # Any feature that could not be built is filled with a constant 0.
        for feature in self.feature_names:
            if feature not in main_df.columns:
                print(f"警告:特徵 {feature} 不存在,使用預設值 0")
                main_df[feature] = 0

        print(f"處理前數據量: {len(main_df)}")
        main_df = main_df.dropna()
        print(f"處理後數據量: {len(main_df)}")

        if len(main_df) < self.sequence_length + 60:
            print("錯誤:數據量不足以進行訓練")
            return None, None

        X = main_df[self.feature_names].values
        y = main_df[self.target_names].values

        print(f"數據準備完成:X shape: {X.shape}, y shape: {y.shape}")
        return X, y

    def create_sequences(self, X, y):
        """Slice ``X`` into windows of ``sequence_length`` rows.

        Sample ``k`` holds rows ``[i - sequence_length, i)`` of ``X`` and is
        paired with ``y[i]``, for every ``i`` from ``sequence_length`` to
        ``len(X) - 1``.

        Returns:
            ``(X_seq, y_seq)`` as NumPy arrays.
        """
        window = self.sequence_length
        ends = range(window, len(X))
        X_seq = [X[end - window:end] for end in ends]
        y_seq = [y[end] for end in ends]
        return np.array(X_seq), np.array(y_seq)

    def build_model(self, input_shape, output_shape):
        """Build and compile the Keras network.

        Args:
            input_shape: ``(sequence_length, n_features)``.
            output_shape: Number of output units.

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if tf is None:
            raise ImportError("TensorFlow 未安裝,無法建立模型")

        model = Sequential([
            # Layer 1: bidirectional LSTM
            Bidirectional(LSTM(128, return_sequences=True, dropout=0.2, recurrent_dropout=0.2),
                          input_shape=input_shape),
            BatchNormalization(),

            # Layer 2: LSTM
            LSTM(64, return_sequences=True, dropout=0.2, recurrent_dropout=0.2),
            BatchNormalization(),

            # Layer 3: LSTM (returns only the final step)
            LSTM(32, dropout=0.2, recurrent_dropout=0.2),
            BatchNormalization(),

            # Fully connected head
            Dense(64, activation='relu', kernel_regularizer=l1_l2(l1=0.01, l2=0.01)),
            Dropout(0.3),

            Dense(32, activation='relu', kernel_regularizer=l1_l2(l1=0.01, l2=0.01)),
            Dropout(0.2),

            # Output layer
            Dense(output_shape, activation='linear')
        ])

        model.compile(
            optimizer=Adam(learning_rate=0.001),
            loss='huber',
            metrics=['mae', 'mse']
        )

        return model

    def train(self, epochs=100, batch_size=32, validation_split=0.2):
        """Prepare data, fit the model, save it and print validation metrics.

        Returns:
            True on success, False when no training data or no sequences could
            be produced.
        """
        print("開始訓練模型...")

        X, y = self.prepare_training_data()
        if X is None or y is None:
            print("錯誤:無法準備訓練數據")
            return False

        # NOTE(review): the scalers are fitted on every row, including those
        # that later form the validation split.
        X_scaled = self.scaler_X.fit_transform(X)
        y_scaled = self.scaler_y.fit_transform(y)

        X_seq, y_seq = self.create_sequences(X_scaled, y_scaled)

        if len(X_seq) == 0:
            print("錯誤:無法創建有效序列")
            return False

        print(f"訓練數據形狀:X_seq: {X_seq.shape}, y_seq: {y_seq.shape}")

        self.model = self.build_model(
            input_shape=(X_seq.shape[1], X_seq.shape[2]),
            output_shape=y_seq.shape[1]
        )

        print("模型架構:")
        self.model.summary()

        callbacks = [
            EarlyStopping(patience=15, restore_best_weights=True, monitor='val_loss'),
            ReduceLROnPlateau(factor=0.5, patience=8, min_lr=0.0001, monitor='val_loss'),
            ModelCheckpoint(f'{self.model_name}.keras', save_best_only=True, monitor='val_loss')
        ]

        self.model.fit(
            X_seq, y_seq,
            epochs=epochs,
            batch_size=batch_size,
            validation_split=validation_split,
            callbacks=callbacks,
            verbose=1
        )

        self.save_model()
        self.evaluate_model(X_seq, y_seq, validation_split)

        self.is_trained = True
        print("模型訓練完成!")
        return True

    def evaluate_model(self, X_seq, y_seq, validation_split):
        """Print MAE, MSE and R2 per target on the trailing validation slice."""
        print("\n模型評估結果:")

        # Same trailing fraction that Keras used as validation data.
        split_idx = int(len(X_seq) * (1 - validation_split))
        X_val, y_val = X_seq[split_idx:], y_seq[split_idx:]

        y_pred = self.model.predict(X_val)

        # Convert back to price units before computing the metrics.
        y_val_orig = self.scaler_y.inverse_transform(y_val)
        y_pred_orig = self.scaler_y.inverse_transform(y_pred)

        for i, target in enumerate(self.target_names):
            actual, predicted = y_val_orig[:, i], y_pred_orig[:, i]
            mae = mean_absolute_error(actual, predicted)
            mse = mean_squared_error(actual, predicted)
            r2 = r2_score(actual, predicted)
            print(f"{target}: MAE={mae:.2f}, MSE={mse:.2f}, R2={r2:.4f}")

    def save_model(self):
        """Write the model and both scalers to disk; errors are printed, not raised."""
        try:
            if self.model is not None:
                self.model.save(f'{self.model_name}.keras')
                print(f"模型已儲存: {self.model_name}.keras")

            joblib.dump(self.scaler_X, f'{self.model_name}_scaler_X.pkl')
            joblib.dump(self.scaler_y, f'{self.model_name}_scaler_y.pkl')
            print("縮放器已儲存")

        except Exception as e:
            print(f"儲存模型時發生錯誤: {e}")

    def load_model(self):
        """Load the saved model and scalers.

        Returns:
            True only when the model file and both scaler files were loaded;
            False when any file is missing, TensorFlow is unavailable, or
            loading fails. ``is_trained`` is set only on success.
        """
        model_path = f'{self.model_name}.keras'
        scaler_x_path = f'{self.model_name}_scaler_X.pkl'
        scaler_y_path = f'{self.model_name}_scaler_y.pkl'

        try:
            missing = [path for path in (model_path, scaler_x_path, scaler_y_path)
                       if not os.path.exists(path)]
            if missing:
                print(f"載入模型時發生錯誤: 找不到檔案 {missing}")
                return False
            if tf is None:
                print("載入模型時發生錯誤: TensorFlow 未安裝")
                return False

            self.model = load_model(model_path)
            print("模型載入成功")

            self.scaler_X = joblib.load(scaler_x_path)
            print("X 縮放器載入成功")

            self.scaler_y = joblib.load(scaler_y_path)
            print("y 縮放器載入成功")

            self.is_trained = True
            return True

        except Exception as e:
            print(f"載入模型時發生錯誤: {e}")
            return False

    def predict(self, predict_days=5):
        """Predict the TAIEX close ``predict_days`` ahead from fresh market data.

        Args:
            predict_days: One of 1, 5, 10, 20 or 60.

        Returns:
            dict with ``predicted_price``, ``change_pct``, ``confidence``,
            ``current_price`` and ``prediction_days``, or None on any failure.
        """
        if predict_days not in self._HORIZON_INDEX:
            # Reject unsupported horizons before doing any network work.
            print(f"不支援的預測天數:{predict_days}")
            return None

        if not self.is_trained and not self.load_model():
            print("錯誤:模型未訓練且無法載入已訓練的模型")
            return None

        if self.model is None:
            print("錯誤:模型未載入")
            return None

        try:
            print("正在獲取最新數據進行預測...")
            now = datetime.now()
            market_data = self.fetch_yfinance_data(
                start_date=(now - timedelta(days=120)).strftime('%Y-%m-%d'),
                end_date=now.strftime('%Y-%m-%d')
            )
            economic_data = self.load_economic_data()

            if 'TAIEX' not in market_data:
                print("錯誤:無法獲取最新台股數據")
                return None

            # Same pipeline as training.
            main_df = market_data['TAIEX'].copy()
            main_df.index = main_df.index.tz_localize(None)
            main_df = self.calculate_technical_indicators(main_df)

            if main_df.empty or len(main_df) < self.sequence_length:
                print("錯誤:數據不足以進行預測")
                return None

            self._merge_external_data(main_df, market_data, economic_data)

            for feature in self.feature_names:
                if feature not in main_df.columns:
                    main_df[feature] = 0

            # ffill() replaces the deprecated fillna(method='ffill').
            main_df = main_df.ffill().fillna(0)

            X = main_df[self.feature_names].values
            if len(X) < self.sequence_length:
                print("錯誤:歷史數據不足")
                return None

            # Scale the window row by row (the scaler was fitted on
            # n_features columns), then add the batch axis. Flattening the
            # window to a single row first gives the scaler the wrong width.
            X_recent = X[-self.sequence_length:]
            X_scaled = self.scaler_X.transform(X_recent)[np.newaxis, ...]

            y_pred_scaled = self.model.predict(X_scaled)
            y_pred = self.scaler_y.inverse_transform(y_pred_scaled)

            current_price = main_df['Close'].iloc[-1]
            predicted_price = y_pred[0][self._HORIZON_INDEX[predict_days]]
            change_pct = ((predicted_price - current_price) / current_price) * 100

            # Simplified confidence: shrinks with the size of the predicted
            # move and is clamped to [0.6, 0.9].
            confidence = min(0.9, max(0.6, 1 - abs(change_pct) / 100))

            result = {
                'predicted_price': float(predicted_price),
                'change_pct': float(change_pct),
                'confidence': float(confidence),
                'current_price': float(current_price),
                'prediction_days': predict_days
            }

            print(f"預測結果:{predict_days}天後價格 = {predicted_price:.2f}, 變化 = {change_pct:+.2f}%")
            return result

        except Exception as e:
            print(f"預測時發生錯誤: {e}")
            return None
528
 
529
# Module-wide predictor instance, created on first use.
_predictor = None


def get_predictor():
    """Return the shared AdvancedStockPredictor, creating it on the first call."""
    global _predictor
    if _predictor is not None:
        return _predictor
    _predictor = AdvancedStockPredictor()
    return _predictor
538
 
539
def advanced_lstm_predict(predict_days=5):
    """Prediction entry point called by HUGING_FACE_V4.2.

    Args:
        predict_days (int): Prediction horizon in days (1, 5, 10, 20, 60).

    Returns:
        dict or None: Result dictionary containing predicted_price,
        change_pct and confidence, or None when prediction raises.
    """
    try:
        shared_predictor = get_predictor()
        outcome = shared_predictor.predict(predict_days)
    except Exception as e:
        print(f"advanced_lstm_predict 錯誤: {e}")
        outcome = None
    return outcome
555
 
556
def train_model():
    """Train a fresh AdvancedStockPredictor and report whether it succeeded."""
    print("開始訓練進階LSTM模型...")
    succeeded = AdvancedStockPredictor().train(epochs=50, batch_size=16)

    if not succeeded:
        print("模型訓練失敗!")
        return False

    print("模型訓練成功!")
    return True


if __name__ == "__main__":
    # Running this file as a script trains the model.
    train_model()
 
4
  Automatically generated by Colab.
5
 
6
  Original file is located at
7
+ https://colab.research.google.com/drive/1zJ1bLFePNZEiz12aSub_bWwjS4CMtnH3
8
  """
9
 
10
  # model_predictor.py
11
+ # Advanced LSTM predictor for stock price forecasting, designed for HUGING_FACE_V4.2(輕量AI版).py.
12
+ # Trains model using provided CSVs and saves lstm_model.pth, scaler_X.pkl, scaler_y.pkl for deployment.
13
+ # Features: volume, rate, DJI, NAS, SOX, S&P_500, RSI, MACD, MACDsign, MACDvol, K, D, +DI, -DI, ADX, business_climate, PMI.
14
+ # Targets: close prices 1, 5, 10, 20, 60 days ahead.
15
+ # For Hugging Face deployment, upload the generated model and scaler files.
16
+ # Prediction uses yfinance for real-time features, falling back to last CSV row if unavailable.
17
 
18
  import os
 
19
  import pandas as pd
20
+ import numpy as np
21
+ import torch
22
+ import torch.nn as nn
23
+ from sklearn.preprocessing import StandardScaler
24
+ import pickle
25
+ from datetime import datetime
26
  import yfinance as yf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
 
28
+ # Define LSTM model
29
class LSTMPredictor(nn.Module):
    """LSTM followed by a linear head that maps the last time step to the outputs.

    Args:
        input_size: Number of input features per time step.
        hidden_size: Width of each LSTM layer.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values predicted per sample.
    """

    def __init__(self, input_size, hidden_size=50, num_layers=2, output_size=5):
        super().__init__()
        # Attribute names are part of the saved state_dict keys; keep them.
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
        )
        self.fc = nn.Linear(in_features=hidden_size, out_features=output_size)

    def forward(self, x):
        """Return predictions of shape (batch, output_size).

        A 2-D input (batch, features) is treated as a sequence of length one.
        """
        sequence = x.unsqueeze(1) if x.dim() == 2 else x
        hidden_states, _ = self.lstm(sequence)
        last_step = hidden_states[:, -1, :]
        return self.fc(last_step)
41
+
42
# Global variables
# Populated by load_model_and_scalers() or train_model().
model = None
scaler_X = None
scaler_y = None
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def load_model_and_scalers():
    """
    Load the pre-trained model and scalers from the working directory.

    Reads 'lstm_model.pth', 'scaler_X.pkl' and 'scaler_y.pkl'. The module
    globals are only assigned once all three have loaded, so a failure never
    leaves them half-populated.

    Returns True if successful, False otherwise.
    """
    global model, scaler_X, scaler_y

    required_files = ('lstm_model.pth', 'scaler_X.pkl', 'scaler_y.pkl')
    if not all(os.path.exists(path) for path in required_files):
        print("Pre-trained files not found. Cannot predict without training.")
        return False

    try:
        state_dict = torch.load('lstm_model.pth', map_location=device)
        # Take the feature count from the checkpoint instead of hard-coding it:
        # the first-layer input weights have shape (4 * hidden_size, input_size),
        # and a hard-coded value that differs from training makes
        # load_state_dict fail on a size mismatch.
        input_size = state_dict['lstm.weight_ih_l0'].shape[1]
        loaded_model = LSTMPredictor(input_size).to(device)
        loaded_model.load_state_dict(state_dict)
        loaded_model.eval()

        # NOTE: pickle executes code from the file; only load scaler files
        # produced by this project's own training run.
        with open('scaler_X.pkl', 'rb') as f:
            loaded_scaler_X = pickle.load(f)
        with open('scaler_y.pkl', 'rb') as f:
            loaded_scaler_y = pickle.load(f)
    except (OSError, RuntimeError, KeyError, EOFError, pickle.UnpicklingError) as e:
        print(f"Failed to load pre-trained files: {e}")
        return False

    model, scaler_X, scaler_y = loaded_model, loaded_scaler_X, loaded_scaler_y
    print("Pre-trained model and scalers loaded successfully.")
    return True
 
70
 
71
def _load_monthly_indicator(path, date_column):
    """Read a monthly CSV whose date column holds 'YYYY-MM' strings, indexed by month start."""
    if not os.path.exists(path):
        raise FileNotFoundError(f"{path} not found.")
    frame = pd.read_csv(path)
    frame[date_column] = pd.to_datetime(frame[date_column] + '-01')
    frame = frame.set_index(date_column).sort_index()
    # reindex() below needs unique labels; keep the latest row per month.
    return frame[~frame.index.duplicated(keep='last')]


def prepare_data_from_csvs(
    main_csv='新期末專案輸入資料20220912-20250909.csv',
    business_climate_csv='business_climate.csv',
    pmi_csv='taiwan_pmi.csv',
    start_date='20220912',
    end_date='20250908',
    horizons=(1, 5, 10, 20, 60),
):
    """
    Prepare training data from CSVs.

    Args:
        main_csv: Daily data file with a 'date' column, the feature columns
            and a 'close' column.
        business_climate_csv: Monthly file with a 'Date' column ('YYYY-MM').
        pmi_csv: Monthly file with 'DATE' ('YYYY-MM') and 'INDEX' columns.
        start_date, end_date: Inclusive date range kept from the main file.
        horizons: Row offsets for which future closes become targets.

    Returns:
        (X, y, feature_columns): float32 feature matrix, float32 target
        matrix with one column per horizon, and the feature column names.

    Raises:
        FileNotFoundError: If any of the three CSV files is missing.
    """
    if not os.path.exists(main_csv):
        raise FileNotFoundError(f"Main CSV file not found. Please upload '{main_csv}'.")

    main_data = pd.read_csv(main_csv, parse_dates=['date'])
    in_range = main_data['date'].between(pd.to_datetime(start_date), pd.to_datetime(end_date))
    main_data = main_data[in_range].set_index('date').sort_index()

    business_climate = _load_monthly_indicator(business_climate_csv, 'Date')

    # The PMI value is renamed so it cannot collide with a 'PMI' column in the main file.
    taiwan_pmi = _load_monthly_indicator(pmi_csv, 'DATE').rename(columns={'INDEX': 'PMI_external'})

    # Align the monthly values to the daily index by carrying each month's
    # value forward. An exact-date join would silently drop every month whose
    # first day is not a trading day.
    main_data = main_data.join(business_climate.reindex(main_data.index, method='ffill'), how='left')
    main_data = main_data.join(taiwan_pmi['PMI_external'].reindex(main_data.index, method='ffill'), how='left')

    # Fill remaining gaps: forward, then backward, then the column mean.
    # ffill()/bfill() replace the deprecated fillna(method=...).
    for column in ('business_climate', 'PMI_external'):
        series = main_data[column]
        main_data[column] = series.ffill().bfill().fillna(series.mean())

    # TSM_ADR is not one of the features; drop it whenever it is present.
    if 'TSM_ADR' in main_data.columns:
        main_data = main_data.drop(columns=['TSM_ADR'])

    # Feature columns (PMI here is the column from the main file).
    feature_columns = ['volume', 'rate', 'DJI', 'NAS', 'SOX', 'S&P_500', 'RSI', 'MACD', 'MACDsign', 'MACDvol', 'K', 'D', '+DI', '-DI', 'ADX', 'business_climate', 'PMI']

    # Remove rows with NaN in features or close.
    main_data = main_data.dropna(subset=feature_columns + ['close']).copy()

    # Targets: the close `days` rows ahead.
    target_columns = [f'close_{days}d' for days in horizons]
    for days, column in zip(horizons, target_columns):
        main_data[column] = main_data['close'].shift(-days)

    # The trailing rows have no future close for the longest horizon.
    main_data = main_data.dropna(subset=target_columns)

    X = main_data[feature_columns].values.astype(np.float32)
    y = main_data[target_columns].values.astype(np.float32)

    return X, y, feature_columns
135
 
136
def train_model(num_epochs=50, batch_size=32, learning_rate=0.001, train_fraction=0.8):
    """
    Train the LSTM model and save the model and scalers.

    Rows are split chronologically: the first `train_fraction` of the rows
    are used for training and the remainder for validation. The scalers are
    fitted on the training rows only, so validation data does not influence
    them. 'lstm_model.pth', 'scaler_X.pkl' and 'scaler_y.pkl' are written,
    and the module globals updated, only after training completes.

    Args:
        num_epochs: Number of passes over the training rows.
        batch_size: Mini-batch size for both loaders.
        learning_rate: Adam learning rate.
        train_fraction: Share of rows used for training.

    Returns True if successful, False otherwise (the error is printed).
    """
    global model, scaler_X, scaler_y

    try:
        from torch.utils.data import TensorDataset, DataLoader

        X, y, _feature_columns = prepare_data_from_csvs()

        train_size = int(train_fraction * len(X))
        if train_size == 0 or train_size == len(X):
            raise ValueError("Not enough rows to build both a training and a validation split.")

        X_train, X_val = X[:train_size], X[train_size:]
        y_train, y_val = y[:train_size], y[train_size:]

        # Fit on the training rows only, then apply to both splits.
        fitted_scaler_X = StandardScaler().fit(X_train)
        fitted_scaler_y = StandardScaler().fit(y_train)

        def to_dataset(features, targets):
            return TensorDataset(
                torch.tensor(fitted_scaler_X.transform(features), dtype=torch.float32),
                torch.tensor(fitted_scaler_y.transform(targets), dtype=torch.float32),
            )

        train_dataset = to_dataset(X_train, y_train)
        val_dataset = to_dataset(X_val, y_val)
        train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
        val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

        network = LSTMPredictor(X.shape[1]).to(device)
        criterion = nn.MSELoss()
        optimizer = torch.optim.Adam(network.parameters(), lr=learning_rate)

        for epoch in range(num_epochs):
            network.train()
            train_loss = 0.0
            for batch_X, batch_y in train_loader:
                batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                optimizer.zero_grad()
                loss = criterion(network(batch_X), batch_y)
                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch figure is a per-row mean.
                train_loss += loss.item() * batch_X.size(0)
            train_loss /= len(train_dataset)

            network.eval()
            val_loss = 0.0
            with torch.no_grad():
                for batch_X, batch_y in val_loader:
                    batch_X, batch_y = batch_X.to(device), batch_y.to(device)
                    val_loss += criterion(network(batch_X), batch_y).item() * batch_X.size(0)
            val_loss /= len(val_dataset)

            if (epoch + 1) % 10 == 0:
                print(f'Epoch [{epoch+1}/{num_epochs}], Train Loss: {train_loss:.4f}, Val Loss: {val_loss:.4f}')

        # Persist everything together so the files on disk always match.
        with open('scaler_X.pkl', 'wb') as f:
            pickle.dump(fitted_scaler_X, f)
        with open('scaler_y.pkl', 'wb') as f:
            pickle.dump(fitted_scaler_y, f)
        torch.save(network.state_dict(), 'lstm_model.pth')

        network.eval()
        model, scaler_X, scaler_y = network, fitted_scaler_X, fitted_scaler_y
        return True

    except Exception as e:
        print(f"Training error: {e}")
        return False
221
 
222
def _read_latest_csv_inputs(main_csv, business_climate_csv, pmi_csv):
    """Return (last row of the main CSV, last business-climate 'Index', last PMI 'INDEX')."""
    last_row = pd.read_csv(main_csv, parse_dates=['date']).iloc[-1]
    last_bc = pd.read_csv(business_climate_csv)['Index'].iloc[-1]
    last_pmi = pd.read_csv(pmi_csv)['INDEX'].iloc[-1]
    return last_row, last_bc, last_pmi


def _fetch_live_market_features():
    """Return the latest index closes and ^TWII-derived indicators from yfinance, keyed by feature name."""
    live = {}

    tickers = {'DJI': '^DJI', 'NAS': '^IXIC', 'SOX': '^SOX', 'S&P_500': '^GSPC'}
    for key, ticker in tickers.items():
        hist = yf.Ticker(ticker).history(period='1d')
        if hist.empty:
            raise Exception(f"Failed to fetch {key}")
        live[key] = hist['Close'].iloc[-1]

    # Volume and technical indicators come from the last 60 sessions of ^TWII.
    hist = yf.Ticker('^TWII').history(period='60d')
    if hist.empty:
        raise Exception("Failed to fetch ^TWII")
    close, high, low = hist['Close'], hist['High'], hist['Low']
    prev_close = close.shift(1)

    # RSI over a 14-row window
    delta = close.diff()
    gain = delta.where(delta > 0, 0).rolling(window=14).mean()
    loss = -delta.where(delta < 0, 0).rolling(window=14).mean()
    rsi = 100 - (100 / (1 + gain / loss))

    # MACD (12/26 EMA difference, 9-span signal line)
    macd = close.ewm(span=12).mean() - close.ewm(span=26).mean()
    macd_signal = macd.ewm(span=9).mean()

    # KD stochastic oscillator over a 9-row window
    low_min = low.rolling(window=9).min()
    high_max = high.rolling(window=9).max()
    rsv = (close - low_min) / (high_max - low_min) * 100
    k = rsv.ewm(com=2).mean()
    d = k.ewm(com=2).mean()

    # DMI / ADX
    up_move = high - high.shift(1)
    down_move = low.shift(1) - low
    plus_dm = np.where((up_move > down_move) & (up_move > 0), up_move, 0)
    minus_dm = np.where((down_move > up_move) & (down_move > 0), down_move, 0)
    tr = np.max([high - low, abs(high - prev_close), abs(low - prev_close)], axis=0)
    smoothed_tr = pd.Series(tr).ewm(com=13, adjust=False).mean()
    plus_di = (pd.Series(plus_dm).ewm(com=13, adjust=False).mean() / smoothed_tr) * 100
    minus_di = (pd.Series(minus_dm).ewm(com=13, adjust=False).mean() / smoothed_tr) * 100
    dx = abs(plus_di - minus_di) / (plus_di + minus_di) * 100
    adx = dx.ewm(com=13, adjust=False).mean()

    live.update({
        'volume': hist['Volume'].iloc[-1],
        'RSI': rsi.iloc[-1],
        'MACD': macd.iloc[-1],
        'MACDsign': macd_signal.iloc[-1],
        'MACDvol': (macd - macd_signal).iloc[-1],
        'K': k.iloc[-1],
        'D': d.iloc[-1],
        '+DI': plus_di.iloc[-1],
        '-DI': minus_di.iloc[-1],
        'ADX': adx.iloc[-1],
    })
    return live


def get_current_features(
    main_csv='新期末專案輸入資料20220912-20250909.csv',
    business_climate_csv='business_climate.csv',
    pmi_csv='taiwan_pmi.csv',
):
    """
    Build the single-row feature vector used for prediction.

    Market features are fetched with yfinance; if that fails for any reason,
    the last row of the main CSV is used instead. 'rate' always comes from
    the last row of the main CSV, and business_climate / PMI always come
    from the last rows of their CSVs.

    Returns:
        float32 array of shape (1, 17) in feature-column order.
    """
    feature_columns = ['volume', 'rate', 'DJI', 'NAS', 'SOX', 'S&P_500', 'RSI', 'MACD', 'MACDsign', 'MACDvol', 'K', 'D', '+DI', '-DI', 'ADX', 'business_climate', 'PMI']

    try:
        values = _fetch_live_market_features()
        last_row, last_bc, last_pmi = _read_latest_csv_inputs(main_csv, business_climate_csv, pmi_csv)
        # 'rate' is not derived from yfinance; reuse the last recorded value.
        values['rate'] = last_row['rate']
    except Exception as e:
        # Deliberate best effort: any live-fetch problem falls back to the CSV.
        print(f"yfinance fetch failed: {e}. Using last CSV row.")
        last_row, last_bc, last_pmi = _read_latest_csv_inputs(main_csv, business_climate_csv, pmi_csv)
        values = {name: last_row[name] for name in feature_columns[:-2]}

    values['business_climate'] = last_bc
    values['PMI'] = last_pmi

    features = [values[name] for name in feature_columns]
    return np.array(features).reshape(1, -1).astype(np.float32)
 
 
323
 
324
def advanced_lstm_predict(predict_days=5):
    """
    Predict stock closes for 1, 5, 10, 20, 60 days ahead.

    Args:
        predict_days: Horizon to report. An unsupported value falls back to
            the 1-day horizon.

    Returns:
        dict with predicted_price, change_pct and confidence for the reported
        horizon, plus current_price, prediction_days (the horizon actually
        used) and all_predictions (every horizon, keyed '1d', '5d', ...).
        All numbers are plain Python floats. Returns None when the model
        cannot be loaded or prediction fails.
    """
    if model is None or scaler_X is None or scaler_y is None:
        if not load_model_and_scalers():
            return None

    try:
        current_features = get_current_features()
        current_scaled = scaler_X.transform(current_features)
        current_tensor = torch.tensor(current_scaled, dtype=torch.float32).to(device)

        model.eval()
        with torch.no_grad():
            pred_scaled = model(current_tensor)

        pred = scaler_y.inverse_transform(pred_scaled.cpu().numpy())

        # Current close from ^TWII, or from the last CSV row if that fails.
        try:
            current_close = yf.Ticker('^TWII').history(period='1d')['Close'].iloc[-1]
        except Exception:
            main_data = pd.read_csv('新期末專案輸入資料20220912-20250909.csv')
            current_close = main_data['close'].iloc[-1]
        current_close = float(current_close)

        horizons = [1, 5, 10, 20, 60]
        predictions = {}
        for i, days in enumerate(horizons):
            # float() keeps the result JSON-serialisable (np.float32 is not).
            pred_price = float(pred[0, i])
            predictions[f'{days}d'] = {
                'predicted_price': pred_price,
                'change_pct': ((pred_price - current_close) / current_close) * 100,
                'confidence': 0.8  # Placeholder
            }

        used_days = predict_days if f'{predict_days}d' in predictions else 1

        # Copy the selected entry: adding all_predictions to the entry itself
        # would make the dictionary contain itself.
        result = dict(predictions[f'{used_days}d'])
        result['current_price'] = current_close
        result['prediction_days'] = used_days
        result['all_predictions'] = predictions

        return result

    except Exception as e:
        print(f"Prediction error: {e}")
        return None
374
 
375
# Running this file as a script trains the model and writes the deployment files.
if __name__ == '__main__':
    print("Training model...")
    training_succeeded = train_model()
    print("Model and scalers saved successfully." if training_succeeded else "Training failed.")
 
 
 
 
 
scaler_X.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bfa2a7e0e256257525b7949a92aa59e2385d01f65afe1431441c73bff85098a3
3
+ size 858
scaler_y.pkl ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fdd5b70d72852c4486dcc93c276dd3975ea81a98fbc54095bac900631fc418b
3
+ size 570