AlanRex committed on
Commit
cc07c78
·
verified ·
1 Parent(s): d6aec87

Delete model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +0 -847
model_predictor.py DELETED
@@ -1,847 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
- """model_predictor.ipynb
3
-
4
- Automatically generated by Colab.
5
-
6
- Original file is located at
7
- https://colab.research.google.com/drive/1pIuCvafVPCRzTLojc-rZH_MFKsxMam2L
8
- """
9
-
10
- # model_predictor.py
11
- # 深度學習股價預測模型 - 適用於 HUGGING_FACE_V4.2
12
-
13
- import os
14
- import numpy as np
15
- import pandas as pd
16
- import yfinance as yf
17
- from datetime import datetime, timedelta
18
- import warnings
19
- warnings.filterwarnings('ignore')
20
-
21
- # TensorFlow/Keras 相關
22
- try:
23
- import tensorflow as tf
24
- from tensorflow.keras.models import Sequential, load_model
25
- from tensorflow.keras.layers import LSTM, Dense, Dropout, BatchNormalization, LeakyReLU
26
- from tensorflow.keras.optimizers import Adam
27
- from tensorflow.keras.callbacks import EarlyStopping, ReduceLROnPlateau
28
- from tensorflow.keras.regularizers import l2
29
- from sklearn.preprocessing import MinMaxScaler, RobustScaler
30
- from sklearn.model_selection import train_test_split
31
- from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
32
- TENSORFLOW_AVAILABLE = True
33
- except ImportError:
34
- TENSORFLOW_AVAILABLE = False
35
-
36
# Fix the random seeds so results are reproducible.
# NumPy is imported unconditionally, so it is always seeded; TensorFlow is
# seeded only when its import succeeded.
np.random.seed(42)
if TENSORFLOW_AVAILABLE:
    tf.random.set_seed(42)
40
-
41
class StockPredictor:
    """LSTM-based predictor of future closing prices at several horizons.

    Inputs are the daily columns listed in ``feature_columns``; outputs are
    the closing prices 1, 5, 10, 20 and 60 rows ahead (``target_columns``).
    The trained model and the scaler parameters are persisted to
    ``model_path`` and ``scalers_path``.
    """

    # Alternative column spellings accepted in user CSV files, mapped to the
    # canonical names used by ``feature_columns`` / the ``close`` column.
    _COLUMN_ALIASES = {
        # volume
        'Volume': 'volume', 'vol': 'volume', 'VOLUME': 'volume',
        # close
        'Close': 'close', 'close_price': 'close', 'CLOSE': 'close', 'price': 'close',
        # rate of return
        'Rate': 'rate', 'return': 'rate', 'pct_change': 'rate', 'RATE': 'rate',
        # US indices
        'DJI': 'DJI', 'DOW': 'DJI', 'dow': 'DJI',
        'NAS': 'NAS', 'NASDAQ': 'NAS', 'nasdaq': 'NAS',
        'SOX': 'SOX', 'sox': 'SOX',
        'S&P_500': 'S&P_500', 'SP500': 'S&P_500', 'sp500': 'S&P_500',
        'TSM_ADR': 'TSM_ADR', 'TSM': 'TSM_ADR', 'tsm': 'TSM_ADR',
        # technical indicators
        'rsi': 'RSI', 'macd': 'MACD', 'macdsign': 'MACDsign', 'macdvol': 'MACDvol',
        'k': 'K', 'd': 'D', '+di': '+DI', '-di': '-DI', 'adx': 'ADX',
        # macro indicators
        'Business_Climate': 'business_climate', 'business_climate_index': 'business_climate',
        'pmi': 'PMI', 'PMI_Index': 'PMI'
    }

    # Forecast horizons (in rows); each yields a ``close_{n}d`` target column.
    _HORIZONS = (1, 5, 10, 20, 60)

    def __init__(self):
        """Initialise column lists, window length and artifact paths.

        Nothing is loaded or trained here.
        """
        self.model = None
        self.feature_scaler = None
        # One independent scaler per target column, keyed by column name.
        self.target_scalers = {}
        self.feature_columns = [
            'volume', 'rate', 'DJI', 'NAS', 'SOX', 'S&P_500', 'TSM_ADR',
            'RSI', 'MACD', 'MACDsign', 'MACDvol', 'K', 'D',
            '+DI', '-DI', 'ADX', 'business_climate', 'PMI'
        ]
        self.target_columns = [
            'close_1d', 'close_5d', 'close_10d', 'close_20d', 'close_60d'
        ]
        self.sequence_length = 60  # rows of history fed to the LSTM per sample
        self.model_path = 'lstm_stock_model.h5'
        self.scalers_path = 'scalers.npz'

    def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
        """Download TSMC (2330.TW) prices and US market closes via yfinance.

        Args:
            start_date: First date to request, ``YYYY-MM-DD``.
            end_date: End date passed to yfinance, ``YYYY-MM-DD``.

        Returns:
            A DataFrame indexed by the tz-naive 2330.TW dates with columns
            ``close``, ``volume``, ``rate`` and one column per US symbol that
            could be downloaded (forward-filled onto the Taiwan dates), or
            ``None`` if the Taiwan data is unavailable or an error occurs.
        """
        try:
            taiwan_data = yf.Ticker('2330.TW').history(start=start_date, end=end_date)

            # Test for emptiness before touching the index so an empty
            # download yields the intended warning instead of an index error.
            if taiwan_data.empty:
                print("警告: 無法獲取台灣股市數據")
                return None

            # Drop the timezone so all frames share a tz-naive index.
            taiwan_data.index = taiwan_data.index.tz_localize(None)

            symbols = {
                'DJI': '^DJI',
                'NAS': '^IXIC',
                'SOX': '^SOX',
                'S&P_500': '^GSPC',
                'TSM_ADR': 'TSM'
            }

            us_data = {}
            for name, symbol in symbols.items():
                # A failure for one symbol must not abort the others.
                try:
                    data = yf.Ticker(symbol).history(start=start_date, end=end_date)
                    if data.empty:
                        print(f"警告: 無法獲取 {name} 數據")
                        continue
                    data.index = data.index.tz_localize(None)
                    us_data[name] = data['Close']
                except Exception as e:
                    print(f"獲取 {name} 數據時發生錯誤: {e}")

            main_df = pd.DataFrame(index=taiwan_data.index)
            main_df['close'] = taiwan_data['Close']
            main_df['volume'] = taiwan_data['Volume']
            main_df['rate'] = main_df['close'].pct_change()

            # NOTE(review): US closes are aligned to the same calendar date as
            # the Taiwan row; confirm this matches what is known at
            # prediction time.
            for name, data in us_data.items():
                main_df[name] = data.reindex(main_df.index, method='ffill')

            return main_df

        except Exception as e:
            print(f"獲取yfinance數據時發生錯誤: {e}")
            return None

    @staticmethod
    def _load_dated_csv(path, success_message, failure_prefix):
        """Read ``path`` into a frame indexed by its tz-naive ``Date`` column.

        Returns an empty DataFrame when the file is absent or cannot be
        parsed, so callers never receive a half-processed frame.
        """
        if not os.path.exists(path):
            return pd.DataFrame()
        try:
            frame = pd.read_csv(path)
            frame['Date'] = pd.to_datetime(frame['Date'])
            frame.set_index('Date', inplace=True)
            frame.index = frame.index.tz_localize(None)
        except Exception as e:
            print(f"{failure_prefix}: {e}")
            return pd.DataFrame()
        print(success_message)
        return frame

    def load_external_data(self):
        """Load the optional macro-economic CSV files from the working directory.

        Returns:
            ``(business_climate, pmi_data)``: two DataFrames indexed by date,
            each empty when its file (``business_climate.csv`` /
            ``taiwan_pmi.csv``) is missing or unreadable.
        """
        business_climate = self._load_dated_csv(
            'business_climate.csv', "成功載入景氣燈號數據", "載入景氣燈號數據失敗"
        )
        pmi_data = self._load_dated_csv(
            'taiwan_pmi.csv', "成功載入PMI數據", "載入PMI數據失敗"
        )
        return business_climate, pmi_data

    def calculate_technical_indicators(self, df):
        """Add RSI, MACD, KD and simplified DMI columns computed from ``close``.

        The frame is modified in place and also returned. On error the
        message is printed and the (possibly partially updated) frame is
        returned.

        Args:
            df: DataFrame containing a ``close`` column.

        Returns:
            The same DataFrame with the indicator columns added.
        """
        try:
            close = df['close']

            # RSI over a 14-row simple moving average of gains and losses.
            delta = close.diff()
            gain = (delta.where(delta > 0, 0)).rolling(window=14).mean()
            loss = (-delta.where(delta < 0, 0)).rolling(window=14).mean()
            rs = gain / loss
            df['RSI'] = 100 - (100 / (1 + rs))

            # MACD (12/26 EMA difference), its 9-span signal and the histogram.
            fast_ema = close.ewm(span=12).mean()
            slow_ema = close.ewm(span=26).mean()
            df['MACD'] = fast_ema - slow_ema
            df['MACDsign'] = df['MACD'].ewm(span=9).mean()
            df['MACDvol'] = df['MACD'] - df['MACDsign']

            # KD stochastic; the 9-row range is taken from closes only.
            low_min = close.rolling(window=9).min()
            high_max = close.rolling(window=9).max()
            rsv = (close - low_min) / (high_max - low_min) * 100
            df['K'] = rsv.ewm(com=2).mean()
            df['D'] = df['K'].ewm(com=2).mean()

            # Simplified DMI proxy built from the 2-row close range; this is
            # not the textbook +DI/-DI/ADX computation.
            close_range = close.rolling(2).max() - close.rolling(2).min()
            df['+DI'] = close_range.rolling(14).mean()
            df['-DI'] = close_range.rolling(14).std()
            df['ADX'] = (df['+DI'] + df['-DI']).rolling(14).mean()

            return df

        except Exception as e:
            print(f"計算技術指標時發生錯誤: {e}")
            return df

    def create_sample_data(self, days=500):
        """Build a fallback training frame from recent yfinance data.

        Used when loading a CSV fails. Macro indicators are filled with
        constants and any other missing feature with 0.0.

        Args:
            days: Number of calendar days of history to request.

        Returns:
            A DataFrame with all feature and target columns and no missing
            values, or ``None`` when data is unavailable or fewer than 100
            complete rows remain.
        """
        try:
            print("創建示例數據進行訓練...")

            now = datetime.now()
            taiwan_data = self.fetch_yfinance_data(
                start_date=(now - timedelta(days=days)).strftime('%Y-%m-%d'),
                end_date=now.strftime('%Y-%m-%d')
            )

            if taiwan_data is None or taiwan_data.empty:
                print("無法獲取示例數據")
                return None

            if 'close' not in taiwan_data.columns or 'volume' not in taiwan_data.columns:
                print("示例數據缺少必要欄位")
                return None

            taiwan_data = self.calculate_technical_indicators(taiwan_data)

            # Macro indicators are not available here; use fixed values.
            taiwan_data['business_climate'] = 25.0
            taiwan_data['PMI'] = 50.0

            for feature in self.feature_columns:
                if feature not in taiwan_data.columns:
                    taiwan_data[feature] = 0.0

            # Future closing prices; a distinct loop name avoids shadowing
            # the ``days`` parameter.
            for horizon in self._HORIZONS:
                taiwan_data[f'close_{horizon}d'] = taiwan_data['close'].shift(-horizon)

            taiwan_data = taiwan_data.dropna()

            if len(taiwan_data) < 100:
                print("示例數據不足")
                return None

            print(f"成功創建示例數據: {taiwan_data.shape}")
            return taiwan_data

        except Exception as e:
            print(f"創建示例數據時發生錯誤: {e}")
            return None

    def debug_csv_file(self, csv_path):
        """Print the head of a CSV file and probe for working read parameters.

        Args:
            csv_path: Path of the CSV file to inspect.

        Returns:
            ``(encoding, separator)`` of the first combination that parses
            into more than five columns, or ``(None, None)`` if none does or
            the file cannot be opened.
        """
        try:
            print(f"\n=== 調試CSV檔案: {csv_path} ===")

            # errors='replace' keeps this preview from aborting the probe
            # below when the file is not valid UTF-8.
            with open(csv_path, 'r', encoding='utf-8', errors='replace') as f:
                first_lines = [f.readline().strip() for _ in range(5)]

            print("前5行原始內容:")
            for i, line in enumerate(first_lines):
                print(f"第{i+1}行: {line[:100]}...")  # first 100 characters only

            encodings = ['utf-8', 'utf-8-sig', 'latin-1', 'cp1252']
            separators = [',', ';', '\t', '|']

            for encoding in encodings:
                for sep in separators:
                    try:
                        df_test = pd.read_csv(csv_path, encoding=encoding, sep=sep, nrows=5)
                    except Exception:
                        # This combination cannot parse the file; try the next.
                        continue
                    if len(df_test.columns) > 5:  # plausible number of columns
                        print(f"\n成功讀取 (編碼: {encoding}, 分隔符: '{sep}'):")
                        print(f"欄位: {list(df_test.columns)}")
                        print(f"數據形狀: {df_test.shape}")
                        return encoding, sep

            print("無法找到合適的讀取參數")
            return None, None

        except Exception as e:
            print(f"調試CSV檔案時發生錯誤: {e}")
            return None, None

    def _read_training_csv(self, csv_path):
        """Read a user CSV with probed parameters and index it by date."""
        encoding, separator = self.debug_csv_file(csv_path)

        print(f"\n從 {csv_path} 載入數據...")

        read_params = {}
        if encoding:
            read_params['encoding'] = encoding
        if separator and separator != ',':
            read_params['sep'] = separator

        df = pd.read_csv(csv_path, **read_params)

        print(f"CSV檔案欄位: {list(df.columns)}")
        print(f"數據形狀: {df.shape}")
        print("前5行數據:")
        print(df.head())

        # The first recognised date column becomes the index.
        date_col = next(
            (col for col in ('Date', 'date', 'DATE', 'Unnamed: 0') if col in df.columns),
            None
        )
        if date_col:
            print(f"使用日期欄位: {date_col}")
            df[date_col] = pd.to_datetime(df[date_col])
            df.set_index(date_col, inplace=True)
        elif df.index.dtype == 'object':
            df.index = pd.to_datetime(df.index)

        print(f"處理日期後的數據形狀: {df.shape}")
        print(f"日期範圍: {df.index.min()} 到 {df.index.max()}")
        return df

    def _download_training_frame(self):
        """Build the training frame from yfinance plus the optional macro CSVs.

        Returns ``None`` when the market data cannot be downloaded.
        """
        print("從yfinance獲取數據...")
        df = self.fetch_yfinance_data()
        if df is None:
            return None

        df = self.calculate_technical_indicators(df)

        business_climate, pmi_data = self.load_external_data()

        if not business_climate.empty:
            df['business_climate'] = business_climate['Index'].reindex(
                df.index, method='ffill'
            )
        else:
            df['business_climate'] = 25.0  # default value

        if not pmi_data.empty:
            df['PMI'] = pmi_data['Index'].reindex(df.index, method='ffill')
        else:
            df['PMI'] = 50.0  # default value

        return df

    def prepare_training_data(self, csv_path=None):
        """Assemble the feature matrix and target matrix for training.

        Data comes from ``csv_path`` when it exists, otherwise from yfinance
        and the optional macro CSV files. Missing indicators and the future
        price targets are derived from the ``close`` column.

        Args:
            csv_path: Optional path to a CSV file with the training data.

        Returns:
            ``(X, y, index, df)``: feature array, target array, the row
            index and the cleaned DataFrame; or ``(None, None, None, None)``
            on any failure.
        """
        failure = (None, None, None, None)
        try:
            if csv_path and os.path.exists(csv_path):
                df = self._read_training_csv(csv_path)
            else:
                df = self._download_training_frame()
                if df is None:
                    return failure

            # Normalise alternative column spellings.
            df = df.rename(columns=self._COLUMN_ALIASES)
            print(f"映射後的欄位: {list(df.columns)}")

            # Fall back to another price column when 'close' is still absent.
            if 'close' not in df.columns:
                for candidate in ('Close', 'Price', 'CLOSE', 'close_price'):
                    if candidate in df.columns:
                        df['close'] = df[candidate]
                        print(f"使用 {candidate} 作為 close 價格")
                        break

            if 'close' not in df.columns:
                print("錯誤: 找不到價格數據,無法計算目標變數")
                return failure

            # Derive whatever the data does not already provide.
            if 'rate' not in df.columns:
                df['rate'] = df['close'].pct_change()
                print("計算了price return rate")

            if 'RSI' not in df.columns:
                df = self.calculate_technical_indicators(df)
                print("計算了技術指標")

            for horizon in self._HORIZONS:
                df[f'close_{horizon}d'] = df['close'].shift(-horizon)
            print("計算了未來價格目標")

            print(f"計算目標變數後的數據形狀: {df.shape}")

            # Drop incomplete rows, judging only the columns the model uses
            # so that unrelated CSV columns cannot wipe out the dataset.
            relevant = [
                col for col in ['close', *self.feature_columns, *self.target_columns]
                if col in df.columns
            ]
            original_len = len(df)
            df = df.dropna(subset=relevant)
            print(f"移除缺失值: {original_len} -> {len(df)} 行")

            if df.empty:
                print("錯誤: 處理後的數據集為空")
                print("可能原因:")
                print("1. 所有數據都有缺失值")
                print("2. 日期格式不正確")
                print("3. 欄位名稱不匹配")
                return failure

            # Fill features that are entirely absent with defaults.
            missing_features = set(self.feature_columns) - set(df.columns)
            if missing_features:
                print(f"警告: 缺少特徵欄位: {missing_features}")
                defaults = {'business_climate': 25.0, 'PMI': 50.0}
                for feature in missing_features:
                    df[feature] = defaults.get(feature, 0.0)
                print("已填充缺失的特徵欄位")

            missing_targets = set(self.target_columns) - set(df.columns)
            if missing_targets:
                print(f"錯誤: 缺少目標欄位: {missing_targets}")
                return failure

            X = df[self.feature_columns].values
            y = df[self.target_columns].values

            print(f"數據形狀: X={X.shape}, y={y.shape}")
            print(f"數據日期範圍: {df.index.min()} 到 {df.index.max()}")

            return X, y, df.index, df

        except Exception as e:
            print(f"準備訓練數據時發生錯誤: {e}")
            return failure

    def create_sequences(self, X, y):
        """Slice row-aligned features and targets into LSTM training windows.

        Each sample is ``sequence_length`` consecutive feature rows paired
        with the target row of the LAST row in that window. This matches
        inference, where the window ends on the most recent row and the
        prediction is interpreted relative to that row.

        Args:
            X: 2-D array of features, one row per time step.
            y: Array of targets with the same number of rows as ``X``.

        Returns:
            ``(X_seq, y_seq)`` as NumPy arrays; both are empty when ``X`` has
            fewer than ``sequence_length`` rows.
        """
        window = self.sequence_length
        X_seq, y_seq = [], []

        for end in range(window, len(X) + 1):
            X_seq.append(X[end - window:end])
            y_seq.append(y[end - 1])

        return np.array(X_seq), np.array(y_seq)

    def build_model(self, input_shape, output_shape):
        """Build and compile the stacked-LSTM regression network.

        Args:
            input_shape: ``(timesteps, features)`` of one input window.
            output_shape: Number of output units (one per target column).

        Returns:
            The compiled Keras model.

        Raises:
            ImportError: If TensorFlow could not be imported.
        """
        if not TENSORFLOW_AVAILABLE:
            raise ImportError("TensorFlow未安裝,無法建立模型")

        model = Sequential([
            # First LSTM layer
            LSTM(128, return_sequences=True, input_shape=input_shape,
                 kernel_regularizer=l2(0.001)),
            BatchNormalization(),
            Dropout(0.2),

            # Second LSTM layer
            LSTM(64, return_sequences=True, kernel_regularizer=l2(0.001)),
            BatchNormalization(),
            Dropout(0.2),

            # Third LSTM layer
            LSTM(32, return_sequences=False, kernel_regularizer=l2(0.001)),
            BatchNormalization(),
            Dropout(0.2),

            # Fully connected head
            Dense(64, kernel_regularizer=l2(0.001)),
            LeakyReLU(alpha=0.1),
            BatchNormalization(),
            Dropout(0.3),

            Dense(32, kernel_regularizer=l2(0.001)),
            LeakyReLU(alpha=0.1),
            Dropout(0.2),

            # Output layer
            Dense(output_shape, activation='linear')
        ])

        optimizer = Adam(learning_rate=0.001, clipnorm=1.0)
        model.compile(
            optimizer=optimizer,
            loss='huber',  # less sensitive to outliers than MSE
            metrics=['mse', 'mae']
        )

        return model

    def train_model(self, csv_path=None):
        """Train the LSTM and persist the model and scalers.

        When ``csv_path`` is given but cannot be turned into training data,
        sample data from yfinance is used instead.

        Args:
            csv_path: Optional path to a CSV file with the training data.

        Returns:
            ``True`` on success, ``False`` if TensorFlow is missing, data is
            unavailable or insufficient, or fitting fails.
        """
        if not TENSORFLOW_AVAILABLE:
            print("錯誤: TensorFlow未安裝,無法訓練模型")
            return False

        print("開始準備訓練數據...")
        X, y, _, _ = self.prepare_training_data(csv_path)

        # Fall back to sample data when the CSV could not be used.
        if (X is None or y is None) and csv_path:
            print("CSV載入失敗,嘗試創建示例數據...")
            df = self.create_sample_data()
            if df is not None:
                X = df[self.feature_columns].values
                y = df[self.target_columns].values
                print("使用示例數據繼續訓練")

        if X is None or y is None:
            print("錯誤: 無法準備訓練數據")
            return False

        if len(X) < 100:
            print(f"警告: 訓練數據過少 ({len(X)} 樣本),建議至少100個樣本")
            return False

        print("正在縮放數據...")
        self.feature_scaler = RobustScaler()
        X_scaled = self.feature_scaler.fit_transform(X)

        # Each target gets its own scaler so it can be inverted independently.
        y_scaled = np.zeros(y.shape, dtype=float)
        for i, target in enumerate(self.target_columns):
            scaler = RobustScaler()
            y_scaled[:, i:i+1] = scaler.fit_transform(y[:, i:i+1])
            self.target_scalers[target] = scaler

        print("正在創建時間序列...")
        X_seq, y_seq = self.create_sequences(X_scaled, y_scaled)

        if len(X_seq) == 0:
            print(f"錯誤: 序列長度不足,需要至少{self.sequence_length}個數據點")
            return False

        print(f"序列形狀: X_seq={X_seq.shape}, y_seq={y_seq.shape}")

        # Chronological split (no shuffling) to avoid look-ahead leakage.
        split_idx = int(len(X_seq) * 0.8)
        X_train, X_val = X_seq[:split_idx], X_seq[split_idx:]
        y_train, y_val = y_seq[:split_idx], y_seq[split_idx:]

        print(f"訓練集大小: {X_train.shape}, 驗證集大小: {X_val.shape}")

        print("正在建立模型...")
        input_shape = (X_seq.shape[1], X_seq.shape[2])
        output_shape = y_seq.shape[1]

        self.model = self.build_model(input_shape, output_shape)
        print(f"模型架構: 輸入={input_shape}, 輸出={output_shape}")

        callbacks = [
            EarlyStopping(
                monitor='val_loss',
                patience=15,
                restore_best_weights=True,
                verbose=1
            ),
            ReduceLROnPlateau(
                monitor='val_loss',
                factor=0.5,
                patience=8,
                min_lr=1e-6,
                verbose=1
            )
        ]

        print("開始訓練模型...")
        try:
            self.model.fit(
                X_train, y_train,
                validation_data=(X_val, y_val),
                epochs=50,
                # Scale the batch size with the data, but never below 1.
                batch_size=max(1, min(32, len(X_train) // 4)),
                callbacks=callbacks,
                verbose=1
            )
        except Exception as e:
            print(f"訓練過程中發生錯誤: {e}")
            return False

        # Evaluation is informational only; a failure here is not fatal.
        print("\n評估模型性能...")
        try:
            train_loss = self.model.evaluate(X_train, y_train, verbose=0)
            val_loss = self.model.evaluate(X_val, y_val, verbose=0)

            print(f"訓練集損失: {train_loss[0]:.4f}")
            print(f"驗證集損失: {val_loss[0]:.4f}")
        except Exception as e:
            print(f"評估過程中發生錯誤: {e}")

        self.save_model()

        return True

    def save_model(self):
        """Write the model to ``model_path`` and scaler parameters to ``scalers_path``.

        Only the ``scale_`` and ``center_`` arrays of fitted scalers are
        stored. Errors are printed, not raised.
        """
        try:
            if self.model:
                self.model.save(self.model_path)
                print(f"模型已儲存至: {self.model_path}")

            scalers_dict = {'feature_scaler': self.feature_scaler}
            scalers_dict.update(self.target_scalers)

            # Flatten the fitted scalers into plain arrays for np.savez.
            scalers_data = {}
            for name, scaler in scalers_dict.items():
                if hasattr(scaler, 'scale_'):
                    scalers_data[f'{name}_scale'] = scaler.scale_
                    scalers_data[f'{name}_center'] = scaler.center_

            np.savez(self.scalers_path, **scalers_data)
            print(f"縮放器已儲存至: {self.scalers_path}")

        except Exception as e:
            print(f"儲存模型時發生錯誤: {e}")

    @staticmethod
    def _restore_scaler(scale, center):
        """Rebuild a RobustScaler from stored ``scale_`` / ``center_`` arrays."""
        scaler = RobustScaler()
        scaler.scale_ = scale
        scaler.center_ = center
        return scaler

    def load_model(self):
        """Load the saved model and, if present, the scaler parameters.

        Returns:
            ``True`` when the model was loaded; ``False`` when the model
            file is missing, TensorFlow is unavailable, or loading fails.
        """
        try:
            if not (os.path.exists(self.model_path) and TENSORFLOW_AVAILABLE):
                print("模型文件不存在或TensorFlow未安裝")
                return False

            self.model = load_model(self.model_path)
            print("模型載入成功")

            if os.path.exists(self.scalers_path):
                # The context manager closes the underlying .npz file.
                with np.load(self.scalers_path) as scalers_data:
                    if 'feature_scaler_scale' in scalers_data:
                        self.feature_scaler = self._restore_scaler(
                            scalers_data['feature_scaler_scale'],
                            scalers_data['feature_scaler_center']
                        )

                    for target in self.target_columns:
                        scale_key = f'{target}_scale'
                        center_key = f'{target}_center'
                        if scale_key in scalers_data:
                            self.target_scalers[target] = self._restore_scaler(
                                scalers_data[scale_key], scalers_data[center_key]
                            )

                print("縮放器載入成功")

            return True

        except Exception as e:
            print(f"載入模型時發生錯誤: {e}")
            return False
675
-
676
# Module-wide predictor instance, created lazily by get_predictor().
_predictor = None


def get_predictor():
    """Return the shared StockPredictor, building it on first use.

    The first call constructs the predictor and attempts to load the saved
    model and scalers; later calls return the same instance.
    """
    global _predictor
    if _predictor is not None:
        return _predictor

    _predictor = StockPredictor()
    _predictor.load_model()
    return _predictor
686
-
687
def advanced_lstm_predict(predict_days):
    """Predict the future closing price with the saved LSTM model.

    This is the interface used by the main program.

    Args:
        predict_days: Forecast horizon in days (1, 5, 10, 20 or 60). Any
            other value falls back to the 5-day target.

    Returns:
        dict with ``predicted_price``, ``change_pct`` (percent change versus
        the latest close) and ``confidence`` (a volatility-based heuristic
        clamped to [0.5, 0.9]); ``None`` if the prediction fails.
    """
    try:
        predictor = get_predictor()

        if predictor.model is None:
            print("模型未載入,無法進行預測")
            return None

        # The indicators need warm-up rows (the ADX proxy alone consumes 27)
        # before any row is complete, so request clearly more history than
        # ``sequence_length`` rows. Doubling converts trading rows to a
        # generous number of calendar days (weekends and holidays).
        warmup_rows = 40
        lookback_days = (predictor.sequence_length + warmup_rows) * 2
        now = datetime.now()
        current_data = predictor.fetch_yfinance_data(
            start_date=(now - timedelta(days=lookback_days)).strftime('%Y-%m-%d'),
            end_date=now.strftime('%Y-%m-%d')
        )

        if current_data is None or len(current_data) < predictor.sequence_length:
            print("無法獲取足夠的當前數據進行預測")
            return None

        current_data = predictor.calculate_technical_indicators(current_data)

        # Merge the optional macro indicators, with defaults where missing.
        business_climate, pmi_data = predictor.load_external_data()

        if not business_climate.empty:
            current_data['business_climate'] = business_climate['Index'].reindex(
                current_data.index, method='ffill'
            ).fillna(25.0)
        else:
            current_data['business_climate'] = 25.0

        if not pmi_data.empty:
            current_data['PMI'] = pmi_data['Index'].reindex(
                current_data.index, method='ffill'
            ).fillna(50.0)
        else:
            current_data['PMI'] = 50.0

        # Features that could not be obtained are zero-filled.
        for feature in predictor.feature_columns:
            if feature not in current_data.columns:
                current_data[feature] = 0.0

        current_data = current_data.dropna()

        if len(current_data) < predictor.sequence_length:
            print("處理後的數據不足以進行預測")
            return None

        # Scale the features and take the most recent window.
        X_current = current_data[predictor.feature_columns].values
        X_current_scaled = predictor.feature_scaler.transform(X_current)

        X_seq = X_current_scaled[-predictor.sequence_length:].reshape(
            1, predictor.sequence_length, len(predictor.feature_columns)
        )

        prediction_scaled = predictor.model.predict(X_seq, verbose=0)

        # Select the output column for the requested horizon.
        target_map = {1: 'close_1d', 5: 'close_5d', 10: 'close_10d',
                      20: 'close_20d', 60: 'close_60d'}
        target_col = target_map.get(predict_days, 'close_5d')
        target_idx = predictor.target_columns.index(target_col)

        if target_col not in predictor.target_scalers:
            print(f"未找到 {target_col} 的縮放器")
            return None

        # Undo the target scaling to obtain a price.
        predicted_price = predictor.target_scalers[target_col].inverse_transform(
            prediction_scaled[:, target_idx:target_idx+1]
        )[0, 0]

        current_price = current_data['close'].iloc[-1]
        change_pct = ((predicted_price - current_price) / current_price) * 100

        # Simplified confidence: lower when recent 20-row volatility is higher.
        recent_volatility = current_data['close'].pct_change().tail(20).std()
        confidence = max(0.5, min(0.9, 1 - recent_volatility * 5))

        return {
            'predicted_price': predicted_price,
            'change_pct': change_pct,
            'confidence': confidence
        }

    except Exception as e:
        print(f"LSTM預測時發生錯誤: {e}")
        return None
791
-
792
def train_model_from_csv(csv_path):
    """Train a fresh StockPredictor on ``csv_path`` and return its success flag."""
    return StockPredictor().train_model(csv_path)
796
-
797
if __name__ == "__main__":
    # Self-test: train a model, then run one prediction per horizon.
    print("=== 股價預測模型測試 ===")

    # Deep-learning features require TensorFlow.
    if not TENSORFLOW_AVAILABLE:
        print("警告: TensorFlow未安裝,無法使用深度學習功能")
        print("請安裝TensorFlow: pip install tensorflow")
        # SystemExit works even where the site-provided exit() is absent.
        raise SystemExit(1)

    # Prefer the project CSV when it is present in the working directory.
    csv_file = "新期末專案輸入資料20220912-20250909.csv"

    if os.path.exists(csv_file):
        print(f"找到CSV檔案: {csv_file}")

        predictor = StockPredictor()

        success = predictor.train_model(csv_file)
        if success:
            print("模型訓練完成!")
        else:
            print("CSV訓練失敗,嘗試使用yfinance數據...")
            success = predictor.train_model()
            if success:
                print("使用yfinance數據訓練完成!")
            else:
                print("所有訓練方法都失敗!")
    else:
        print(f"未找到CSV檔案: {csv_file}")
        print("將使用yfinance數據進行訓練...")
        predictor = StockPredictor()
        success = predictor.train_model()
        if success:
            print("模型訓練完成!")
        else:
            print("模型訓練失敗!")

    # Exercise the prediction interface for every supported horizon.
    print("\n=== 測試預測功能 ===")
    test_predictions = [1, 5, 10, 20, 60]

    for days in test_predictions:
        result = advanced_lstm_predict(days)
        if result:
            print(f"{days}日預測: 價格={result['predicted_price']:.2f}, "
                  f"變化={result['change_pct']:+.2f}%, "
                  f"信心度={result['confidence']:.1%}")
        else:
            print(f"{days}日預測失敗")