AlanRex commited on
Commit
d012b7e
·
verified ·
1 Parent(s): 0ea63d4

Update model_predictor.py

Browse files
Files changed (1) hide show
  1. model_predictor.py +79 -64
model_predictor.py CHANGED
@@ -58,82 +58,97 @@ class StockPredictor:
58
  self.scalers_path = 'scalers.npz'
59
 
60
  def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
61
- """從yfinance獲取股市數據"""
62
- try:
63
- # 台積電 (2330.TW) 作為主要目標股票
64
- taiwan_stock = yf.Ticker('2330.TW')
65
- taiwan_data = taiwan_stock.history(start=start_date, end=end_date)
66
-
67
- if taiwan_data.empty:
68
- print("警告: 無法獲取台灣股市數據")
69
- return None
 
 
 
70
 
71
- # 獲取美國市場數據
72
- symbols = {
73
- 'DJI': '^DJI',
74
- 'NAS': '^IXIC',
75
- 'SOX': '^SOX',
76
- 'S&P_500': '^GSPC',
77
- 'TSM_ADR': 'TSM'
78
- }
79
 
80
- us_data = {}
81
- for name, symbol in symbols.items():
82
- try:
83
- ticker = yf.Ticker(symbol)
84
- data = ticker.history(start=start_date, end=end_date)
85
- if not data.empty:
86
- us_data[name] = data['Close']
87
- else:
88
- print(f"警告: 無法獲取 {name} 數據")
89
- except Exception as e:
90
- print(f"獲取 {name} 數據時發生錯誤: {e}")
 
 
 
 
91
 
92
- # 合併數據
93
- main_df = pd.DataFrame(index=taiwan_data.index)
94
- main_df['close'] = taiwan_data['Close']
95
- main_df['volume'] = taiwan_data['Volume']
96
 
97
- # 計算報酬率
98
- main_df['rate'] = main_df['close'].pct_change()
99
 
100
- # 添加美國市場數據
101
- for name, data in us_data.items():
102
- # 重新索引以匹配台灣股市交易日
103
- main_df[name] = data.reindex(main_df.index, method='ffill')
104
 
105
- return main_df
106
 
107
- except Exception as e:
108
- print(f"獲取yfinance數據時發生錯誤: {e}")
109
- return None
110
 
111
  def load_external_data(self):
112
- """載入外部經濟數據"""
113
- business_climate = pd.DataFrame()
114
- pmi_data = pd.DataFrame()
115
 
116
- # 載入景氣燈號數據
117
- try:
118
- if os.path.exists('business_climate.csv'):
119
- business_climate = pd.read_csv('business_climate.csv')
120
- business_climate['Date'] = pd.to_datetime(business_climate['Date'])
121
- business_climate.set_index('Date', inplace=True)
122
- print("成功載入景氣燈號數據")
123
- except Exception as e:
124
- print(f"載入景氣燈號數據失敗: {e}")
 
 
 
 
125
 
126
- # 載入PMI數據
127
- try:
128
- if os.path.exists('taiwan_pmi.csv'):
129
- pmi_data = pd.read_csv('taiwan_pmi.csv')
130
- pmi_data['Date'] = pd.to_datetime(pmi_data['Date'])
131
- pmi_data.set_index('Date', inplace=True)
132
- print("成功載入PMI數據")
133
- except Exception as e:
134
- print(f"載入PMI數據失敗: {e}")
 
 
 
 
135
 
136
- return business_climate, pmi_data
137
 
138
  def calculate_technical_indicators(self, df):
139
  """計算技術指標"""
 
58
  self.scalers_path = 'scalers.npz'
59
 
60
  def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
61
+ """從yfinance獲取股市數據"""
62
+ try:
63
+ # 台積電 (2330.TW) 作為主要目標股票
64
+ taiwan_stock = yf.Ticker('2330.TW')
65
+ taiwan_data = taiwan_stock.history(start=start_date, end=end_date)
66
+
67
+ # 新增: 移除時區,使索引為 tz-naive
68
+ taiwan_data.index = taiwan_data.index.tz_localize(None)
69
+
70
+ if taiwan_data.empty:
71
+ print("警告: 無法獲取台灣股市數據")
72
+ return None
73
 
74
+ # 獲取美國市場數據
75
+ symbols = {
76
+ 'DJI': '^DJI',
77
+ 'NAS': '^IXIC',
78
+ 'SOX': '^SOX',
79
+ 'S&P_500': '^GSPC',
80
+ 'TSM_ADR': 'TSM'
81
+ }
82
 
83
+ us_data = {}
84
+ for name, symbol in symbols.items():
85
+ try:
86
+ ticker = yf.Ticker(symbol)
87
+ data = ticker.history(start=start_date, end=end_date)
88
+
89
+ # 新增: 移除時區,使索引為 tz-naive
90
+ data.index = data.index.tz_localize(None)
91
+
92
+ if not data.empty:
93
+ us_data[name] = data['Close']
94
+ else:
95
+ print(f"警告: 無法獲取 {name} 數據")
96
+ except Exception as e:
97
+ print(f"獲取 {name} 數據時發生錯誤: {e}")
98
 
99
+ # 合併數據
100
+ main_df = pd.DataFrame(index=taiwan_data.index) # 現在 index 已 tz-naive
101
+ main_df['close'] = taiwan_data['Close']
102
+ main_df['volume'] = taiwan_data['Volume']
103
 
104
+ # 計算報酬率
105
+ main_df['rate'] = main_df['close'].pct_change()
106
 
107
+ # 添加美國市場數據
108
+ for name, data in us_data.items():
109
+ # 重新索引以匹配台灣股市交易日
110
+ main_df[name] = data.reindex(main_df.index, method='ffill')
111
 
112
+ return main_df
113
 
114
+ except Exception as e:
115
+ print(f"獲取yfinance數據時發生錯誤: {e}")
116
+ return None
117
 
118
  def load_external_data(self):
119
+ """載入外部經濟數據"""
120
+ business_climate = pd.DataFrame()
121
+ pmi_data = pd.DataFrame()
122
 
123
+ # 載入景氣燈號數據
124
+ try:
125
+ if os.path.exists('business_climate.csv'):
126
+ business_climate = pd.read_csv('business_climate.csv')
127
+ business_climate['Date'] = pd.to_datetime(business_climate['Date'])
128
+ business_climate.set_index('Date', inplace=True)
129
+
130
+ # 新增: 確保索引為 tz-naive
131
+ business_climate.index = business_climate.index.tz_localize(None)
132
+
133
+ print("成功載入景氣燈號數據")
134
+ except Exception as e:
135
+ print(f"載入景氣燈號數據失敗: {e}")
136
 
137
+ # 載入PMI數據
138
+ try:
139
+ if os.path.exists('taiwan_pmi.csv'):
140
+ pmi_data = pd.read_csv('taiwan_pmi.csv')
141
+ pmi_data['Date'] = pd.to_datetime(pmi_data['Date'])
142
+ pmi_data.set_index('Date', inplace=True)
143
+
144
+ # 新增: 確保索引為 tz-naive
145
+ pmi_data.index = pmi_data.index.tz_localize(None)
146
+
147
+ print("成功載入PMI數據")
148
+ except Exception as e:
149
+ print(f"載入PMI數據失敗: {e}")
150
 
151
+ return business_climate, pmi_data
152
 
153
  def calculate_technical_indicators(self, df):
154
  """計算技術指標"""