Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -17,16 +17,28 @@ from bs4 import BeautifulSoup
|
|
| 17 |
import requests
|
| 18 |
import time # 引用 time 模組以處理時間戳
|
| 19 |
|
|
|
|
| 20 |
# 引用您組員的預測器程式
|
| 21 |
from Bert_predict import BertPredictor
|
| 22 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 23 |
# ========================= CACHE 設定 START =========================
|
| 24 |
# 分析結果的快取字典
|
| 25 |
ANALYSIS_CACHE = {}
|
| 26 |
# 快取有效時間(秒),例如:4 小時 = 4 * 60 * 60 = 14400 秒
|
| 27 |
CACHE_DURATION_SECONDS = 8 * 60 * 60
|
| 28 |
# ========================== CACHE 設定 END ==========================
|
| 29 |
-
|
| 30 |
|
| 31 |
# 台股代號對應表 (移除台指期,因為它現在是獨立區塊)
|
| 32 |
TAIWAN_STOCKS = {
|
|
@@ -171,10 +183,9 @@ def get_stock_data(symbol, period='1y'):
|
|
| 171 |
except:
|
| 172 |
return pd.DataFrame()
|
| 173 |
|
| 174 |
-
def
|
| 175 |
-
"""
|
| 176 |
-
if len(data) < 60:
|
| 177 |
-
return None
|
| 178 |
prices = data['Close'].values
|
| 179 |
ma_short = np.mean(prices[-5:])
|
| 180 |
ma_medium = np.mean(prices[-20:])
|
|
@@ -182,21 +193,29 @@ def simple_lstm_predict(data, predict_days=5):
|
|
| 182 |
recent_trend = np.polyfit(range(20), prices[-20:], 1)[0]
|
| 183 |
volatility = np.std(prices[-20:]) / np.mean(prices[-20:])
|
| 184 |
base_change = recent_trend * predict_days
|
| 185 |
-
trend_factor = 1.0
|
| 186 |
-
if ma_short > ma_medium > ma_long:
|
| 187 |
-
trend_factor = 1.02
|
| 188 |
-
elif ma_short < ma_medium < ma_long:
|
| 189 |
-
trend_factor = 0.98
|
| 190 |
-
else:
|
| 191 |
-
trend_factor = 1.0
|
| 192 |
noise_factor = np.random.normal(1, volatility * 0.1)
|
| 193 |
predicted_price = prices[-1] * trend_factor + base_change + (prices[-1] * noise_factor * 0.01)
|
| 194 |
change_pct = ((predicted_price - prices[-1]) / prices[-1]) * 100
|
| 195 |
-
return {
|
| 196 |
-
|
| 197 |
-
|
| 198 |
-
|
| 199 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 200 |
|
| 201 |
def calculate_technical_indicators(df):
|
| 202 |
"""計算技術指標"""
|
|
@@ -467,18 +486,26 @@ app.layout = html.Div([
|
|
| 467 |
def update_taiex_prediction(predict_days):
|
| 468 |
data = get_stock_data('^TWII', '2y')
|
| 469 |
if data.empty: return html.Div("無法獲取台指期資料"), {}
|
| 470 |
-
|
|
|
|
|
|
|
|
|
|
| 471 |
if final_prediction is None: return html.Div("資料不足,無法進行預測"), {}
|
| 472 |
current_price, last_date = data['Close'].iloc[-1], data.index[-1]
|
| 473 |
predicted_price, change_pct, confidence = final_prediction['predicted_price'], final_prediction['change_pct'], final_prediction['confidence']
|
|
|
|
| 474 |
prediction_paths = {1: [1], 5: [1, 5], 10: [1, 5, 10], 20: [1, 10, 20], 60: [1, 10, 20, 60]}
|
| 475 |
intervals_to_predict = prediction_paths.get(predict_days, [predict_days])
|
| 476 |
prediction_dates, prediction_prices = [last_date], [current_price]
|
|
|
|
| 477 |
for days in intervals_to_predict:
|
| 478 |
-
|
|
|
|
| 479 |
if interim_prediction:
|
| 480 |
prediction_dates.append(last_date + timedelta(days=days))
|
| 481 |
prediction_prices.append(interim_prediction['predicted_price'])
|
|
|
|
|
|
|
| 482 |
color, arrow = ('red', '📈') if change_pct >= 0 else ('green', '📉')
|
| 483 |
result_card = html.Div([
|
| 484 |
html.H4(f"{predict_days}日後預測結果", style={'margin': '0 0 15px 0', 'color': 'white'}),
|
|
|
|
| 17 |
import requests
|
| 18 |
import time # 引用 time 模組以處理時間戳
|
| 19 |
|
| 20 |
+
# ========================= 引用外部模組 START =========================
|
| 21 |
# 引用您組員的預測器程式
|
| 22 |
from Bert_predict import BertPredictor
|
| 23 |
|
| 24 |
+
# 引用新的模型預測器
|
| 25 |
+
from model_predictor import advanced_lstm_predict
|
| 26 |
+
# ========================== 引用外部模組 END ==========================
|
| 27 |
+
|
| 28 |
+
# ========================= 全域設定 START =========================
|
| 29 |
+
# 【【【模型切換開關】】】
|
| 30 |
+
# False: 使用簡易統計模型 (預設)
|
| 31 |
+
# True: 使用 model_predictor.py 中的進階 LSTM 模型 (未來啟用)
|
| 32 |
+
USE_ADVANCED_MODEL = False
|
| 33 |
+
|
| 34 |
+
|
| 35 |
# ========================= CACHE 設定 START =========================
|
| 36 |
# 分析結果的快取字典
|
| 37 |
ANALYSIS_CACHE = {}
|
| 38 |
# 快取有效時間(秒),例如:4 小時 = 4 * 60 * 60 = 14400 秒
|
| 39 |
CACHE_DURATION_SECONDS = 8 * 60 * 60
|
| 40 |
# ========================== CACHE 設定 END ==========================
|
| 41 |
+
# ========================== 全域設定 END ==========================
|
| 42 |
|
| 43 |
# 台股代號對應表 (移除台指期,因為它現在是獨立區塊)
|
| 44 |
TAIWAN_STOCKS = {
|
|
|
|
| 183 |
except:
|
| 184 |
return pd.DataFrame()
|
| 185 |
|
| 186 |
+
def simple_statistical_predict(data, predict_days=5):
|
| 187 |
+
"""【備用模型】簡化的統計預測模型。"""
|
| 188 |
+
if len(data) < 60: return None
|
|
|
|
| 189 |
prices = data['Close'].values
|
| 190 |
ma_short = np.mean(prices[-5:])
|
| 191 |
ma_medium = np.mean(prices[-20:])
|
|
|
|
| 193 |
recent_trend = np.polyfit(range(20), prices[-20:], 1)[0]
|
| 194 |
volatility = np.std(prices[-20:]) / np.mean(prices[-20:])
|
| 195 |
base_change = recent_trend * predict_days
|
| 196 |
+
trend_factor = 1.0 + (0.02 if ma_short > ma_medium > ma_long else -0.02 if ma_short < ma_medium < ma_long else 0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
noise_factor = np.random.normal(1, volatility * 0.1)
|
| 198 |
predicted_price = prices[-1] * trend_factor + base_change + (prices[-1] * noise_factor * 0.01)
|
| 199 |
change_pct = ((predicted_price - prices[-1]) / prices[-1]) * 100
|
| 200 |
+
return {'predicted_price': predicted_price, 'change_pct': change_pct, 'confidence': max(0.6, 1 - volatility * 2)}
|
| 201 |
+
|
| 202 |
+
def get_prediction(data, predict_days=5):
|
| 203 |
+
"""
|
| 204 |
+
【【模型預測控制器】】
|
| 205 |
+
根據 USE_ADVANCED_MODEL 的設定,呼叫對應的預測模型。
|
| 206 |
+
"""
|
| 207 |
+
if USE_ADVANCED_MODEL:
|
| 208 |
+
print(f"模式: 進階LSTM模型 | 預測天期: {predict_days}天")
|
| 209 |
+
prediction = advanced_lstm_predict(predict_days)
|
| 210 |
+
# 如果進階模型預測失敗,則自動降級使用簡易模型
|
| 211 |
+
if prediction is not None:
|
| 212 |
+
return prediction
|
| 213 |
+
else:
|
| 214 |
+
print("進階模型預測失敗,自動降級為簡易統計模型。")
|
| 215 |
+
|
| 216 |
+
# 預設或降級時執行簡易模型
|
| 217 |
+
print(f"模式: 簡易統計模型 | 預測天期: {predict_days}天")
|
| 218 |
+
return simple_statistical_predict(data, predict_days)
|
| 219 |
|
| 220 |
def calculate_technical_indicators(df):
|
| 221 |
"""計算技術指標"""
|
|
|
|
| 486 |
def update_taiex_prediction(predict_days):
|
| 487 |
data = get_stock_data('^TWII', '2y')
|
| 488 |
if data.empty: return html.Div("無法獲取台指期資料"), {}
|
| 489 |
+
|
| 490 |
+
# === 修改點:統一呼叫 get_prediction 控制器 ===
|
| 491 |
+
final_prediction = get_prediction(data, predict_days)
|
| 492 |
+
|
| 493 |
if final_prediction is None: return html.Div("資料不足,無法進行預測"), {}
|
| 494 |
current_price, last_date = data['Close'].iloc[-1], data.index[-1]
|
| 495 |
predicted_price, change_pct, confidence = final_prediction['predicted_price'], final_prediction['change_pct'], final_prediction['confidence']
|
| 496 |
+
|
| 497 |
prediction_paths = {1: [1], 5: [1, 5], 10: [1, 5, 10], 20: [1, 10, 20], 60: [1, 10, 20, 60]}
|
| 498 |
intervals_to_predict = prediction_paths.get(predict_days, [predict_days])
|
| 499 |
prediction_dates, prediction_prices = [last_date], [current_price]
|
| 500 |
+
|
| 501 |
for days in intervals_to_predict:
|
| 502 |
+
# === 修改點:迴圈內也使用統一的預測控制器 ===
|
| 503 |
+
interim_prediction = get_prediction(data, days)
|
| 504 |
if interim_prediction:
|
| 505 |
prediction_dates.append(last_date + timedelta(days=days))
|
| 506 |
prediction_prices.append(interim_prediction['predicted_price'])
|
| 507 |
+
|
| 508 |
+
# (後續繪圖邏輯不變)
|
| 509 |
color, arrow = ('red', '📈') if change_pct >= 0 else ('green', '📉')
|
| 510 |
result_card = html.Div([
|
| 511 |
html.H4(f"{predict_days}日後預測結果", style={'margin': '0 0 15px 0', 'color': 'white'}),
|