AlanRex commited on
Commit
b3374d4
·
verified ·
1 Parent(s): b0e9a2b

Delete model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +0 -833
model_predictor.py DELETED
@@ -1,833 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """model_predictor.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1pIuCvafVPCRzTLojc-rZH_MFKsxMam2L
8
- """
9
-
10
- # model_predictor.py
11
- # 深度學習股價預測模型 - 適用於 HUGGING_FACE_V4.2
12
-
13
- import os
14
- import numpy as np
15
- import pandas as pd
16
- import yfinance as yf
17
- from datetime import datetime, timedelta
18
- import warnings
19
- warnings.filterwarnings('ignore')
20
-
21
- # TensorFlow/Keras 相關
22
- try:
23
- import tensorflow as tf
24
- from tensorflow.keras.models import Sequential, load_model
25
- from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization, LeakyReLU
26
- from tensorflow.keras.optimizers import Adam
27
- from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
28
- from tensorflow.keras.regularizers import l2
29
- from sklearn.preprocessing import MinMaxScaler, RobustScaler
30
- from sklearn.model_selection import train_test_split
31
- from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
32
- TENSORFLOW_AVAILABLE = True
33
- except ImportError:
34
- TENSORFLOW_AVAILABLE = False
35
-
36
- # 設定隨機種子以確保結果可重現
37
- if TENSORFLOW_AVAILABLE:
38
- tf.random.set_seed(42)
39
- np.random.seed(42)
40
-
41
- class StockPredictor:
42
- """股價預測模型類別"""
43
-
44
- def __init__(self):
45
- self.model = None
46
- self.feature_scaler = None
47
- self.target_scalers = {} # 為每個目標變數建立獨立的縮放器
48
- self.feature_columns = [
49
- 'volume', 'rate', 'DJI', 'NAS', 'SOX', 'S&P_500', 'TSM_ADR',
50
- 'RSI', 'MACD', 'MACDsign', 'MACDvol', 'K', 'D',
51
- '+DI', '-DI', 'ADX', 'business_climate', 'PMI'
52
- ]
53
- self.target_columns = [
54
- 'close_1d', 'close_5d', 'close_10d', 'close_20d', 'close_60d'
55
- ]
56
- self.sequence_length = 60 # 使用60天的歷史數據
57
- self.model_path = 'lstm_stock_model.h5'
58
- self.scalers_path = 'scalers.npz'
59
-
60
- def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
61
- """從yfinance獲取股市數據"""
62
- try:
63
- # 台積電 (2330.TW) 作為主要目標股票
64
- taiwan_stock = yf.Ticker('2314.TW')
65
- taiwan_data = taiwan_stock.history(start=start_date, end=end_date)
66
-
67
- if taiwan_data.empty:
68
- print("警告: 無法獲取台灣股市數據")
69
- return None
70
-
71
- # 獲取美國市場數據
72
- symbols = {
73
- 'DJI': '^DJI',
74
- 'NAS': '^IXIC',
75
- 'SOX': '^SOX',
76
- 'S&P_500': '^GSPC',
77
- 'TSM_ADR': 'TSM'
78
- }
79
-
80
- us_data = {}
81
- for name, symbol in symbols.items():
82
- try:
83
- ticker = yf.Ticker(symbol)
84
- data = ticker.history(start=start_date, end=end_date)
85
- if not data.empty:
86
- us_data[name] = data['Close']
87
- else:
88
- print(f"警告: 無法獲取 {name} 數據")
89
- except Exception as e:
90
- print(f"獲取 {name} 數據時發生錯誤: {e}")
91
-
92
- # 合併數據
93
- main_df = pd.DataFrame(index=taiwan_data.index)
94
- main_df['close'] = taiwan_data['Close']
95
- main_df['volume'] = taiwan_data['Volume']
96
-
97
- # 計算報酬率
98
- main_df['rate'] = main_df['close'].pct_change()
99
-
100
- # 添加美國市場數據
101
- for name, data in us_data.items():
102
- # 重新索引以匹配台灣股市交易日
103
- main_df[name] = data.reindex(main_df.index, method='ffill')
104
-
105
- return main_df
106
-
107
- except Exception as e:
108
- print(f"獲取yfinance數據時發生錯誤: {e}")
109
- return None
110
-
111
- def load_external_data(self):
112
- """載入外部經濟數據"""
113
- business_climate = pd.DataFrame()
114
- pmi_data = pd.DataFrame()
115
-
116
- # 載入景氣燈號數據
117
- try:
118
- if os.path.exists('business_climate.csv'):
119
- business_climate = pd.read_csv('business_climate.csv')
120
- business_climate['Date'] = pd.to_datetime(business_climate['Date'])
121
- business_climate.set_index('Date', inplace=True)
122
- print("成功載入景氣燈號數據")
123
- except Exception as e:
124
- print(f"載入景氣燈號數據失敗: {e}")
125
-
126
- # 載入PMI數據
127
- try:
128
- if os.path.exists('taiwan_pmi.csv'):
129
- pmi_data = pd.read_csv('taiwan_pmi.csv')
130
- pmi_data['Date'] = pd.to_datetime(pmi_data['Date'])
131
- pmi_data.set_index('Date', inplace=True)
132
- print("成功載入PMI數據")
133
- except Exception as e:
134
- print(f"載入PMI數據失敗: {e}")
135
-
136
- return business_climate, pmi_data
137
-
138
- def calculate_technical_indicators(self, df):
139
- """計算技術指標"""
140
- try:
141
- # RSI
142
- delta = df['close'].diff()
143
- gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
144
- loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
145
- rs = gain / loss
146
- df['RSI'] = 100 - (100 / (1 + rs))
147
-
148
- # MACD
149
- exp1 = df['close'].ewm(span=12).mean()
150
- exp2 = df['close'].ewm(span=26).mean()
151
- df['MACD'] = exp1 - exp2
152
- df['MACDsign'] = df['MACD'].ewm(span=9).mean()
153
- df['MACDvol'] = df['MACD'] - df['MACDsign']
154
-
155
- # KD指標
156
- low_min = df['close'].rolling(window=9).min()
157
- high_max = df['close'].rolling(window=9).max()
158
- rsv = (df['close'] - low_min) / (high_max - low_min) * 100
159
- df['K'] = rsv.ewm(com=2).mean()
160
- df['D'] = df['K'].ewm(com=2).mean()
161
-
162
- # DMI指標 (簡化版本,使用close價格)
163
- df['high_low_diff'] = df['close'].rolling(2).max() - df['close'].rolling(2).min()
164
- df['+DI'] = df['high_low_diff'].rolling(14).mean()
165
- df['-DI'] = df['high_low_diff'].rolling(14).std()
166
- df['ADX'] = (df['+DI'] + df['-DI']).rolling(14).mean()
167
-
168
- # 清理臨時欄位
169
- df.drop(['high_low_diff'], axis=1, inplace=True)
170
-
171
- return df
172
-
173
- except Exception as e:
174
- print(f"計算技術指標時發生錯誤: {e}")
175
- return df
176
-
177
- def create_sample_data(self, days=500):
178
- """創建示例數據用於訓練(當CSV載入失敗時的後備方案)"""
179
- try:
180
- print("創建示例數據進行訓練...")
181
-
182
- # 獲取台積電數據作為基礎
183
- taiwan_data = self.fetch_yfinance_data(
184
- start_date=(datetime.now() - timedelta(days=days)).strftime('%Y-%m-%d'),
185
- end_date=datetime.now().strftime('%Y-%m-%d')
186
- )
187
-
188
- if taiwan_data is None or taiwan_data.empty:
189
- print("無法獲取示例數據")
190
- return None
191
-
192
- # 確保有基本的close和volume數據
193
- if 'close' not in taiwan_data.columns or 'volume' not in taiwan_data.columns:
194
- print("示例數據缺少必要欄位")
195
- return None
196
-
197
- # 計算技術指標
198
- taiwan_data = self.calculate_technical_indicators(taiwan_data)
199
-
200
- # 添加經濟指標(使用固定值)
201
- taiwan_data['business_climate'] = 25.0
202
- taiwan_data['PMI'] = 50.0
203
-
204
- # 確保所有特徵欄位存在
205
- for feature in self.feature_columns:
206
- if feature not in taiwan_data.columns:
207
- taiwan_data[feature] = 0.0
208
-
209
- # 計算未來價格目標
210
- for days in [1, 5, 10, 20, 60]:
211
- taiwan_data[f'close_{days}d'] = taiwan_data['close'].shift(-days)
212
-
213
- # 移除缺失值
214
- taiwan_data = taiwan_data.dropna()
215
-
216
- if len(taiwan_data) < 100:
217
- print("示例數據不足")
218
- return None
219
-
220
- print(f"成功創建示例數據: {taiwan_data.shape}")
221
- return taiwan_data
222
-
223
- except Exception as e:
224
- print(f"創建示例數據時發生錯誤: {e}")
225
- return None
226
- """調試CSV檔案結構"""
227
- try:
228
- print(f"\n=== 調試CSV檔案: {csv_path} ===")
229
-
230
- # 讀取前幾行看看結構
231
- with open(csv_path, 'r', encoding='utf-8') as f:
232
- first_lines = [f.readline().strip() for _ in range(5)]
233
-
234
- print("前5行原始內容:")
235
- for i, line in enumerate(first_lines):
236
- print(f"第{i+1}行: {line[:100]}...") # 只顯示前100個字符
237
-
238
- # 嘗試不同的編碼和分隔符
239
- encodings = ['utf-8', 'utf-8-sig', 'latin-1', 'cp1252']
240
- separators = [',', ';', '\t', '|']
241
-
242
- for encoding in encodings:
243
- for sep in separators:
244
- try:
245
- df_test = pd.read_csv(csv_path, encoding=encoding, sep=sep, nrows=5)
246
- if len(df_test.columns) > 5: # 如果有合理的欄位數量
247
- print(f"\n成功讀取 (編碼: {encoding}, 分隔符: '{sep}'):")
248
- print(f"欄位: {list(df_test.columns)}")
249
- print(f"數據形狀: {df_test.shape}")
250
- return encoding, sep
251
- except:
252
- continue
253
-
254
- print("無法找到合適的讀取參數")
255
- return None, None
256
-
257
- except Exception as e:
258
- print(f"調試CSV檔案時發生錯誤: {e}")
259
- return None, None
260
-
261
- def prepare_training_data(self, csv_path=None):
262
- """準備訓練數據"""
263
- try:
264
- if csv_path and os.path.exists(csv_path):
265
- # 先調試CSV檔案
266
- encoding, separator = self.debug_csv_file(csv_path)
267
-
268
- # 如果提供了CSV檔案,直接載入
269
- print(f"\n從 {csv_path} 載入數據...")
270
-
271
- # 使用找到的最佳參數讀取
272
- read_params = {}
273
- if encoding:
274
- read_params['encoding'] = encoding
275
- if separator and separator != ',':
276
- read_params['sep'] = separator
277
-
278
- df = pd.read_csv(csv_path, **read_params)
279
-
280
- # 檢查CSV檔案結構
281
- print(f"CSV檔案欄位: {list(df.columns)}")
282
- print(f"數據形狀: {df.shape}")
283
- print(f"前5行數據:")
284
- print(df.head())
285
-
286
- # 處理日期欄位
287
- date_columns = ['Date', 'date', 'DATE', 'Unnamed: 0']
288
- date_col = None
289
- for col in date_columns:
290
- if col in df.columns:
291
- date_col = col
292
- break
293
-
294
- if date_col:
295
- print(f"使用日期欄位: {date_col}")
296
- df[date_col] = pd.to_datetime(df[date_col])
297
- df.set_index(date_col, inplace=True)
298
- elif df.index.dtype == 'object':
299
- df.index = pd.to_datetime(df.index)
300
-
301
- print(f"處理日期後的數據形狀: {df.shape}")
302
- print(f"日期範圍: {df.index.min()} 到 {df.index.max()}")
303
-
304
- else:
305
- # 從yfinance和外部檔案獲取數據
306
- print("從yfinance獲取數據...")
307
- df = self.fetch_yfinance_data()
308
- if df is None:
309
- return None, None, None, None
310
-
311
- # 計算技術指標
312
- df = self.calculate_technical_indicators(df)
313
-
314
- # 載入外部經濟數據
315
- business_climate, pmi_data = self.load_external_data()
316
-
317
- # 合併外部數據
318
- if not business_climate.empty:
319
- df['business_climate'] = business_climate['Index'].reindex(
320
- df.index, method='ffill'
321
- )
322
- else:
323
- df['business_climate'] = 25.0 # 預設值
324
-
325
- if not pmi_data.empty:
326
- df['PMI'] = pmi_data['Index'].reindex(df.index, method='ffill')
327
- else:
328
- df['PMI'] = 50.0 # 預設值
329
-
330
- # 檢查並映射欄位名稱
331
- column_mapping = {
332
- # 可能的volume欄位名稱
333
- 'Volume': 'volume', 'vol': 'volume', 'VOLUME': 'volume',
334
- # 可能的close欄位名稱
335
- 'Close': 'close', 'close_price': 'close', 'CLOSE': 'close', 'price': 'close',
336
- # 可能的rate欄位名稱
337
- 'Rate': 'rate', 'return': 'rate', 'pct_change': 'rate', 'RATE': 'rate',
338
- # 美股指數
339
- 'DJI': 'DJI', 'DOW': 'DJI', 'dow': 'DJI',
340
- 'NAS': 'NAS', 'NASDAQ': 'NAS', 'nasdaq': 'NAS',
341
- 'SOX': 'SOX', 'sox': 'SOX',
342
- 'S&P_500': 'S&P_500', 'SP500': 'S&P_500', 'sp500': 'S&P_500',
343
- 'TSM_ADR': 'TSM_ADR', 'TSM': 'TSM_ADR', 'tsm': 'TSM_ADR',
344
- # 技術指標
345
- 'rsi': 'RSI', 'macd': 'MACD', 'macdsign': 'MACDsign', 'macdvol': 'MACDvol',
346
- 'k': 'K', 'd': 'D', '+di': '+DI', '-di': '-DI', 'adx': 'ADX',
347
- # 經濟指標
348
- 'Business_Climate': 'business_climate', 'business_climate_index': 'business_climate',
349
- 'pmi': 'PMI', 'PMI_Index': 'PMI'
350
- }
351
-
352
- # 應用欄位映射
353
- df = df.rename(columns=column_mapping)
354
- print(f"映射後的欄位: {list(df.columns)}")
355
-
356
- # 如果沒有close欄位但有其他價格欄位,嘗試使用
357
- if 'close' not in df.columns:
358
- price_candidates = ['Close', 'Price', 'CLOSE', 'close_price']
359
- for candidate in price_candidates:
360
- if candidate in df.columns:
361
- df['close'] = df[candidate]
362
- print(f"使用 {candidate} 作為 close 價格")
363
- break
364
-
365
- # 計算missing的技術指標(如果數據中沒有)
366
- if 'close' in df.columns:
367
- if 'rate' not in df.columns:
368
- df['rate'] = df['close'].pct_change()
369
- print("計算了price return rate")
370
-
371
- # 如果缺少技術指標,計算它們
372
- if 'RSI' not in df.columns:
373
- df = self.calculate_technical_indicators(df)
374
- print("計算了技術指標")
375
-
376
- # 計算未來價格目標
377
- if 'close' in df.columns:
378
- for days in [1, 5, 10, 20, 60]:
379
- df[f'close_{days}d'] = df['close'].shift(-days)
380
- print("計算了未來價格目標")
381
- else:
382
- print("錯誤: 找不到價格數據,無法計算目標變數")
383
- return None, None, None, None
384
-
385
- print(f"計算目標變數後的數據形狀: {df.shape}")
386
-
387
- # 移除缺失值
388
- original_len = len(df)
389
- df = df.dropna()
390
- print(f"移除缺失值: {original_len} -> {len(df)} 行")
391
-
392
- if df.empty:
393
- print("錯誤: 處理後的數據集為空")
394
- print("可能原因:")
395
- print("1. 所有數據都有缺失值")
396
- print("2. 日期格式不正確")
397
- print("3. 欄位名稱不匹配")
398
- return None, None, None, None
399
-
400
- # 確保所有需要的欄位都存在
401
- missing_features = set(self.feature_columns) - set(df.columns)
402
- if missing_features:
403
- print(f"警告: 缺少特徵欄位: {missing_features}")
404
- # 為缺少的特徵填充預設值
405
- for feature in missing_features:
406
- if feature == 'business_climate':
407
- df[feature] = 25.0 # 景氣燈號預設值
408
- elif feature == 'PMI':
409
- df[feature] = 50.0 # PMI預設值
410
- else:
411
- df[feature] = 0.0
412
- print("已填充缺失的特徵欄位")
413
-
414
- missing_targets = set(self.target_columns) - set(df.columns)
415
- if missing_targets:
416
- print(f"錯誤: 缺少目標欄位: {missing_targets}")
417
- return None, None, None, None
418
-
419
- # 提取特徵和目標變數
420
- X = df[self.feature_columns].values
421
- y = df[self.target_columns].values
422
-
423
- print(f"數據形狀: X={X.shape}, y={y.shape}")
424
- print(f"數據日期範圍: {df.index.min()} 到 {df.index.max()}")
425
-
426
- return X, y, df.index, df
427
-
428
- except Exception as e:
429
- print(f"準備訓練數據時發生錯誤: {e}")
430
- return None, None, None, None
431
-
432
- def create_sequences(self, X, y):
433
- """創建時間序列序列"""
434
- X_seq, y_seq = [], []
435
-
436
- for i in range(self.sequence_length, len(X)):
437
- X_seq.append(X[i-self.sequence_length:i])
438
- y_seq.append(y[i])
439
-
440
- return np.array(X_seq), np.array(y_seq)
441
-
442
- def build_model(self, input_shape, output_shape):
443
- """建構LSTM模型"""
444
- if not TENSORFLOW_AVAILABLE:
445
- raise ImportError("TensorFlow未安裝,無法建立模型")
446
-
447
- model = Sequential([
448
- # 第一層LSTM
449
- LSTM(128, return_sequences=True, input_shape=input_shape,
450
- kernel_regularizer=l2(0.001)),
451
- BatchNormalization(),
452
- Dropout(0.2),
453
-
454
- # 第二層LSTM
455
- LSTM(64, return_sequences=True, kernel_regularizer=l2(0.001)),
456
- BatchNormalization(),
457
- Dropout(0.2),
458
-
459
- # 第三層LSTM
460
- LSTM(32, return_sequences=False, kernel_regularizer=l2(0.001)),
461
- BatchNormalization(),
462
- Dropout(0.2),
463
-
464
- # 全連接層
465
- Dense(64, kernel_regularizer=l2(0.001)),
466
- LeakyReLU(alpha=0.1),
467
- BatchNormalization(),
468
- Dropout(0.3),
469
-
470
- Dense(32, kernel_regularizer=l2(0.001)),
471
- LeakyReLU(alpha=0.1),
472
- Dropout(0.2),
473
-
474
- # 輸出層
475
- Dense(output_shape, activation='linear')
476
- ])
477
-
478
- # 編譯模型
479
- optimizer = Adam(learning_rate=0.001, clipnorm=1.0)
480
- model.compile(
481
- optimizer=optimizer,
482
- loss='huber', # 對異常值較不敏感
483
- metrics=['mse', 'mae']
484
- )
485
-
486
- return model
487
-
488
- def train_model(self, csv_path=None):
489
- """訓練模型"""
490
- if not TENSORFLOW_AVAILABLE:
491
- print("錯誤: TensorFlow未安裝,無法訓練模型")
492
- return False
493
-
494
- print("開始準備訓練數據...")
495
- X, y, dates, df = self.prepare_training_data(csv_path)
496
-
497
- # 如果CSV載入失敗,嘗試使用示例數據
498
- if (X is None or y is None) and csv_path:
499
- print("CSV載入失敗,嘗試創建示例數據...")
500
- df = self.create_sample_data()
501
- if df is not None:
502
- X = df[self.feature_columns].values
503
- y = df[self.target_columns].values
504
- dates = df.index
505
- print("使用示例數據繼續訓練")
506
-
507
- if X is None or y is None:
508
- print("錯誤: 無法準備訓練數據")
509
- return False
510
-
511
- # 檢查數據質量
512
- if len(X) < 100:
513
- print(f"警告: 訓練數據過少 ({len(X)} 樣本),建議至少100個樣本")
514
- return False
515
-
516
- print("正在縮放數據...")
517
- # 縮放特徵
518
- self.feature_scaler = RobustScaler()
519
- X_scaled = self.feature_scaler.fit_transform(X)
520
-
521
- # 為每個目標變數建立獨立的縮放器
522
- y_scaled = np.zeros_like(y)
523
- for i, target in enumerate(self.target_columns):
524
- scaler = RobustScaler()
525
- y_scaled[:, i:i+1] = scaler.fit_transform(y[:, i:i+1])
526
- self.target_scalers[target] = scaler
527
-
528
- print("正在創建時間序列...")
529
- X_seq, y_seq = self.create_sequences(X_scaled, y_scaled)
530
-
531
- if len(X_seq) == 0:
532
- print(f"錯誤: 序列長度不足,需要至少{self.sequence_length + 1}個數據點")
533
- return False
534
-
535
- print(f"序列形狀: X_seq={X_seq.shape}, y_seq={y_seq.shape}")
536
-
537
- # 分割訓練和驗證集
538
- split_idx = int(len(X_seq) * 0.8) # 使用時間順序分割而不是隨機分割
539
- X_train, X_val = X_seq[:split_idx], X_seq[split_idx:]
540
- y_train, y_val = y_seq[:split_idx], y_seq[split_idx:]
541
-
542
- print(f"訓練集大小: {X_train.shape}, 驗證集大小: {X_val.shape}")
543
-
544
- # 建立模型
545
- print("正在建立模型...")
546
- input_shape = (X_seq.shape[1], X_seq.shape[2])
547
- output_shape = y_seq.shape[1]
548
-
549
- self.model = self.build_model(input_shape, output_shape)
550
- print(f"模型架構: 輸入={input_shape}, 輸出={output_shape}")
551
-
552
- # 設定回調函數
553
- callbacks = [
554
- EarlyStopping(
555
- monitor='val_loss',
556
- patience=15,
557
- restore_best_weights=True,
558
- verbose=1
559
- ),
560
- ReduceLROnPlateau(
561
- monitor='val_loss',
562
- factor=0.5,
563
- patience=8,
564
- min_lr=1e-6,
565
- verbose=1
566
- )
567
- ]
568
-
569
- # 訓練模型
570
- print("開始訓練模型...")
571
- try:
572
- history = self.model.fit(
573
- X_train, y_train,
574
- validation_data=(X_val, y_val),
575
- epochs=50, # 減少epoch數量以加快訓練
576
- batch_size=min(32, len(X_train) // 4), # 根據數據大小調整batch size
577
- callbacks=callbacks,
578
- verbose=1
579
- )
580
- except Exception as e:
581
- print(f"訓練過程中發生錯誤: {e}")
582
- return False
583
-
584
- # 評估模型
585
- print("\n評估模型性能...")
586
- try:
587
- train_loss = self.model.evaluate(X_train, y_train, verbose=0)
588
- val_loss = self.model.evaluate(X_val, y_val, verbose=0)
589
-
590
- print(f"訓練集損失: {train_loss[0]:.4f}")
591
- print(f"驗證集損失: {val_loss[0]:.4f}")
592
- except Exception as e:
593
- print(f"評估過程中發生錯誤: {e}")
594
-
595
- # 儲存模型和縮放器
596
- self.save_model()
597
-
598
- return True
599
-
600
- def save_model(self):
601
- """儲存模型和縮放器"""
602
- try:
603
- if self.model:
604
- self.model.save(self.model_path)
605
- print(f"模型已儲存至: {self.model_path}")
606
-
607
- # 儲存縮放器
608
- scalers_dict = {'feature_scaler': self.feature_scaler}
609
- scalers_dict.update(self.target_scalers)
610
-
611
- # 將sklearn縮放器轉換為可序列化的格式
612
- scalers_data = {}
613
- for name, scaler in scalers_dict.items():
614
- if hasattr(scaler, 'scale_'):
615
- scalers_data[f'{name}_scale'] = scaler.scale_
616
- scalers_data[f'{name}_center'] = scaler.center_
617
-
618
- np.savez(self.scalers_path, **scalers_data)
619
- print(f"縮放器已儲存至: {self.scalers_path}")
620
-
621
- except Exception as e:
622
- print(f"儲存模型時發生錯誤: {e}")
623
-
624
- def load_model(self):
625
- """載入模型和縮放器"""
626
- try:
627
- if os.path.exists(self.model_path) and TENSORFLOW_AVAILABLE:
628
- self.model = load_model(self.model_path)
629
- print("模型載入成功")
630
-
631
- # 載入縮放器
632
- if os.path.exists(self.scalers_path):
633
- scalers_data = np.load(self.scalers_path)
634
-
635
- # 重建特徵縮放器
636
- if 'feature_scaler_scale' in scalers_data:
637
- self.feature_scaler = RobustScaler()
638
- self.feature_scaler.scale_ = scalers_data['feature_scaler_scale']
639
- self.feature_scaler.center_ = scalers_data['feature_scaler_center']
640
-
641
- # 重建目標縮放器
642
- for target in self.target_columns:
643
- scale_key = f'{target}_scale'
644
- center_key = f'{target}_center'
645
- if scale_key in scalers_data:
646
- scaler = RobustScaler()
647
- scaler.scale_ = scalers_data[scale_key]
648
- scaler.center_ = scalers_data[center_key]
649
- self.target_scalers[target] = scaler
650
-
651
- print("縮放器載入成功")
652
-
653
- return True
654
- else:
655
- print("模型文件不存在或TensorFlow未安裝")
656
- return False
657
-
658
- except Exception as e:
659
- print(f"載入模型時發生錯誤: {e}")
660
- return False
661
-
662
- # 全域預測器實例
663
- _predictor = None
664
-
665
- def get_predictor():
666
- """獲取預測器實例"""
667
- global _predictor
668
- if _predictor is None:
669
- _predictor = StockPredictor()
670
- _predictor.load_model()
671
- return _predictor
672
-
673
- def advanced_lstm_predict(predict_days):
674
- """
675
- 進階LSTM預測函數 - 與main程式的介面
676
-
677
- Args:
678
- predict_days: 預測天數 (1, 5, 10, 20, 60)
679
-
680
- Returns:
681
- dict: 包含predicted_price, change_pct, confidence的字典
682
- None: 如果預測失敗
683
- """
684
- try:
685
- predictor = get_predictor()
686
-
687
- if predictor.model is None:
688
- print("模型未載入,無法進行預測")
689
- return None
690
-
691
- # 獲取最新數據進行預測
692
- current_data = predictor.fetch_yfinance_data(
693
- start_date=(datetime.now() - timedelta(days=90)).strftime('%Y-%m-%d'),
694
- end_date=datetime.now().strftime('%Y-%m-%d')
695
- )
696
-
697
- if current_data is None or len(current_data) < predictor.sequence_length:
698
- print("無法獲取足夠的當前數據進行預測")
699
- return None
700
-
701
- # 計算技術指標
702
- current_data = predictor.calculate_technical_indicators(current_data)
703
-
704
- # 載入外部數據
705
- business_climate, pmi_data = predictor.load_external_data()
706
-
707
- # 合併外部數據
708
- if not business_climate.empty:
709
- current_data['business_climate'] = business_climate['Index'].reindex(
710
- current_data.index, method='ffill'
711
- ).fillna(25.0)
712
- else:
713
- current_data['business_climate'] = 25.0
714
-
715
- if not pmi_data.empty:
716
- current_data['PMI'] = pmi_data['Index'].reindex(
717
- current_data.index, method='ffill'
718
- ).fillna(50.0)
719
- else:
720
- current_data['PMI'] = 50.0
721
-
722
- # 填補缺失的特徵
723
- for feature in predictor.feature_columns:
724
- if feature not in current_data.columns:
725
- current_data[feature] = 0.0
726
-
727
- current_data = current_data.dropna()
728
-
729
- if len(current_data) < predictor.sequence_length:
730
- print("處理後的數據不足以進行預測")
731
- return None
732
-
733
- # 提取特徵並縮放
734
- X_current = current_data[predictor.feature_columns].values
735
- X_current_scaled = predictor.feature_scaler.transform(X_current)
736
-
737
- # 創建序列
738
- X_seq = X_current_scaled[-predictor.sequence_length:].reshape(
739
- 1, predictor.sequence_length, len(predictor.feature_columns)
740
- )
741
-
742
- # 進行預測
743
- prediction_scaled = predictor.model.predict(X_seq, verbose=0)
744
-
745
- # 確定目標欄位索引
746
- target_map = {1: 'close_1d', 5: 'close_5d', 10: 'close_10d',
747
- 20: 'close_20d', 60: 'close_60d'}
748
- target_col = target_map.get(predict_days, 'close_5d')
749
- target_idx = predictor.target_columns.index(target_col)
750
-
751
- # 反縮放預測結果
752
- if target_col in predictor.target_scalers:
753
- predicted_price = predictor.target_scalers[target_col].inverse_transform(
754
- prediction_scaled[:, target_idx:target_idx+1]
755
- )[0, 0]
756
- else:
757
- print(f"未找到 {target_col} 的縮放器")
758
- return None
759
-
760
- # 計算變化百分比
761
- current_price = current_data['close'].iloc[-1]
762
- change_pct = ((predicted_price - current_price) / current_price) * 100
763
-
764
- # 計算信心度 (簡化版本,基於歷史波動性)
765
- recent_volatility = current_data['close'].pct_change().tail(20).std()
766
- confidence = max(0.5, min(0.9, 1 - recent_volatility * 5))
767
-
768
- return {
769
- 'predicted_price': predicted_price,
770
- 'change_pct': change_pct,
771
- 'confidence': confidence
772
- }
773
-
774
- except Exception as e:
775
- print(f"LSTM預測時發生錯誤: {e}")
776
- return None
777
-
778
- def train_model_from_csv(csv_path):
779
- """從CSV檔案訓練模型的便利函數"""
780
- predictor = StockPredictor()
781
- return predictor.train_model(csv_path)
782
-
783
- if __name__ == "__main__":
784
- # 測試模式
785
- print("=== 股價預測模型測試 ===")
786
-
787
- # 首先檢查TensorFlow是否可用
788
- if not TENSORFLOW_AVAILABLE:
789
- print("警告: TensorFlow未安裝,無法使用深度學習功能")
790
- print("請安裝TensorFlow: pip install tensorflow")
791
- exit(1)
792
-
793
- # 檢查是否有CSV檔案
794
- csv_file = "新期末專案輸入資料20220912-20250909.csv"
795
-
796
- if os.path.exists(csv_file):
797
- print(f"找到CSV檔案: {csv_file}")
798
-
799
- # 先創建預測器並調試CSV
800
- predictor = StockPredictor()
801
-
802
- success = predictor.train_model(csv_file)
803
- if success:
804
- print("模型訓練完成!")
805
- else:
806
- print("CSV訓練失敗,嘗試使用yfinance數據...")
807
- success = predictor.train_model()
808
- if success:
809
- print("使用yfinance數據訓練完成!")
810
- else:
811
- print("所有訓練方法都失敗!")
812
- else:
813
- print(f"未找到CSV檔案: {csv_file}")
814
- print("將使用yfinance數據進行訓練...")
815
- predictor = StockPredictor()
816
- success = predictor.train_model()
817
- if success:
818
- print("模型訓練完成!")
819
- else:
820
- print("模型訓練失敗!")
821
-
822
- # 測試預測
823
- print("\n=== 測試預測功能 ===")
824
- test_predictions = [1, 5, 10, 20, 60]
825
-
826
- for days in test_predictions:
827
- result = advanced_lstm_predict(days)
828
- if result:
829
- print(f"{days}日預測: 價格={result['predicted_price']:.2f}, "
830
- f"變化={result['change_pct']:+.2f}%, "
831
- f"信心度={result['confidence']:.1%}")
832
- else:
833
- print(f"{days}日預測失敗")