Spaces:
Sleeping
Sleeping
Update model_predictor.py
Browse files- model_predictor.py +79 -64
model_predictor.py
CHANGED
|
@@ -58,82 +58,97 @@ class StockPredictor:
|
|
| 58 |
self.scalers_path = 'scalers.npz'
|
| 59 |
|
| 60 |
def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
|
|
|
|
|
|
|
|
|
| 70 |
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 91 |
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
|
| 97 |
-
|
| 98 |
-
|
| 99 |
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
|
| 105 |
-
|
| 106 |
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
|
| 111 |
def load_external_data(self):
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 125 |
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 135 |
|
| 136 |
-
|
| 137 |
|
| 138 |
def calculate_technical_indicators(self, df):
|
| 139 |
"""計算技術指標"""
|
|
|
|
| 58 |
self.scalers_path = 'scalers.npz'
|
| 59 |
|
| 60 |
def fetch_yfinance_data(self, start_date='2022-09-12', end_date='2025-09-08'):
|
| 61 |
+
"""從yfinance獲取股市數據"""
|
| 62 |
+
try:
|
| 63 |
+
# 台積電 (2330.TW) 作為主要目標股票
|
| 64 |
+
taiwan_stock = yf.Ticker('2330.TW')
|
| 65 |
+
taiwan_data = taiwan_stock.history(start=start_date, end=end_date)
|
| 66 |
+
|
| 67 |
+
# 新增: 移除時區,使索引為 tz-naive
|
| 68 |
+
taiwan_data.index = taiwan_data.index.tz_localize(None)
|
| 69 |
+
|
| 70 |
+
if taiwan_data.empty:
|
| 71 |
+
print("警告: 無法獲取台灣股市數據")
|
| 72 |
+
return None
|
| 73 |
|
| 74 |
+
# 獲取美國市場數據
|
| 75 |
+
symbols = {
|
| 76 |
+
'DJI': '^DJI',
|
| 77 |
+
'NAS': '^IXIC',
|
| 78 |
+
'SOX': '^SOX',
|
| 79 |
+
'S&P_500': '^GSPC',
|
| 80 |
+
'TSM_ADR': 'TSM'
|
| 81 |
+
}
|
| 82 |
|
| 83 |
+
us_data = {}
|
| 84 |
+
for name, symbol in symbols.items():
|
| 85 |
+
try:
|
| 86 |
+
ticker = yf.Ticker(symbol)
|
| 87 |
+
data = ticker.history(start=start_date, end=end_date)
|
| 88 |
+
|
| 89 |
+
# 新增: 移除時區,使索引為 tz-naive
|
| 90 |
+
data.index = data.index.tz_localize(None)
|
| 91 |
+
|
| 92 |
+
if not data.empty:
|
| 93 |
+
us_data[name] = data['Close']
|
| 94 |
+
else:
|
| 95 |
+
print(f"警告: 無法獲取 {name} 數據")
|
| 96 |
+
except Exception as e:
|
| 97 |
+
print(f"獲取 {name} 數據時發生錯誤: {e}")
|
| 98 |
|
| 99 |
+
# 合併數據
|
| 100 |
+
main_df = pd.DataFrame(index=taiwan_data.index) # 現在 index 已 tz-naive
|
| 101 |
+
main_df['close'] = taiwan_data['Close']
|
| 102 |
+
main_df['volume'] = taiwan_data['Volume']
|
| 103 |
|
| 104 |
+
# 計算報酬率
|
| 105 |
+
main_df['rate'] = main_df['close'].pct_change()
|
| 106 |
|
| 107 |
+
# 添加美國市場數據
|
| 108 |
+
for name, data in us_data.items():
|
| 109 |
+
# 重新索引以匹配台灣股市交易日
|
| 110 |
+
main_df[name] = data.reindex(main_df.index, method='ffill')
|
| 111 |
|
| 112 |
+
return main_df
|
| 113 |
|
| 114 |
+
except Exception as e:
|
| 115 |
+
print(f"獲取yfinance數據時發生錯誤: {e}")
|
| 116 |
+
return None
|
| 117 |
|
| 118 |
def load_external_data(self):
|
| 119 |
+
"""載入外部經濟數據"""
|
| 120 |
+
business_climate = pd.DataFrame()
|
| 121 |
+
pmi_data = pd.DataFrame()
|
| 122 |
|
| 123 |
+
# 載入景氣燈號數據
|
| 124 |
+
try:
|
| 125 |
+
if os.path.exists('business_climate.csv'):
|
| 126 |
+
business_climate = pd.read_csv('business_climate.csv')
|
| 127 |
+
business_climate['Date'] = pd.to_datetime(business_climate['Date'])
|
| 128 |
+
business_climate.set_index('Date', inplace=True)
|
| 129 |
+
|
| 130 |
+
# 新增: 確保索引為 tz-naive
|
| 131 |
+
business_climate.index = business_climate.index.tz_localize(None)
|
| 132 |
+
|
| 133 |
+
print("成功載入景氣燈號數據")
|
| 134 |
+
except Exception as e:
|
| 135 |
+
print(f"載入景氣燈號數據失敗: {e}")
|
| 136 |
|
| 137 |
+
# 載入PMI數據
|
| 138 |
+
try:
|
| 139 |
+
if os.path.exists('taiwan_pmi.csv'):
|
| 140 |
+
pmi_data = pd.read_csv('taiwan_pmi.csv')
|
| 141 |
+
pmi_data['Date'] = pd.to_datetime(pmi_data['Date'])
|
| 142 |
+
pmi_data.set_index('Date', inplace=True)
|
| 143 |
+
|
| 144 |
+
# 新增: 確保索引為 tz-naive
|
| 145 |
+
pmi_data.index = pmi_data.index.tz_localize(None)
|
| 146 |
+
|
| 147 |
+
print("成功載入PMI數據")
|
| 148 |
+
except Exception as e:
|
| 149 |
+
print(f"載入PMI數據失敗: {e}")
|
| 150 |
|
| 151 |
+
return business_climate, pmi_data
|
| 152 |
|
| 153 |
def calculate_technical_indicators(self, df):
|
| 154 |
"""計算技術指標"""
|