AlanRex commited on
Commit
c618dae
·
verified ·
1 Parent(s): d772e4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -19
app.py CHANGED
@@ -17,16 +17,28 @@ from bs4 import BeautifulSoup
17
  import requests
18
  import time # 引用 time 模組以處理時間戳
19
 
 
20
  # 引用您組員的預測器程式
21
  from Bert_predict import BertPredictor
22
 
 
 
 
 
 
 
 
 
 
 
 
23
  # ========================= CACHE 設定 START =========================
24
  # 分析結果的快取字典
25
  ANALYSIS_CACHE = {}
26
  # 快取有效時間(秒),例如:4 小時 = 4 * 60 * 60 = 14400 秒
27
  CACHE_DURATION_SECONDS = 8 * 60 * 60
28
  # ========================== CACHE 設定 END ==========================
29
-
30
 
31
  # 台股代號對應表 (移除台指期,因為它現在是獨立區塊)
32
  TAIWAN_STOCKS = {
@@ -171,10 +183,9 @@ def get_stock_data(symbol, period='1y'):
171
  except:
172
  return pd.DataFrame()
173
 
174
- def simple_lstm_predict(data, predict_days=5):
175
- """簡化的LSTM預測模型 (使用統計方法模擬)"""
176
- if len(data) < 60:
177
- return None
178
  prices = data['Close'].values
179
  ma_short = np.mean(prices[-5:])
180
  ma_medium = np.mean(prices[-20:])
@@ -182,21 +193,29 @@ def simple_lstm_predict(data, predict_days=5):
182
  recent_trend = np.polyfit(range(20), prices[-20:], 1)[0]
183
  volatility = np.std(prices[-20:]) / np.mean(prices[-20:])
184
  base_change = recent_trend * predict_days
185
- trend_factor = 1.0
186
- if ma_short > ma_medium > ma_long:
187
- trend_factor = 1.02
188
- elif ma_short < ma_medium < ma_long:
189
- trend_factor = 0.98
190
- else:
191
- trend_factor = 1.0
192
  noise_factor = np.random.normal(1, volatility * 0.1)
193
  predicted_price = prices[-1] * trend_factor + base_change + (prices[-1] * noise_factor * 0.01)
194
  change_pct = ((predicted_price - prices[-1]) / prices[-1]) * 100
195
- return {
196
- 'predicted_price': predicted_price,
197
- 'change_pct': change_pct,
198
- 'confidence': max(0.6, 1 - volatility * 2)
199
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
200
 
201
  def calculate_technical_indicators(df):
202
  """計算技術指標"""
@@ -467,18 +486,26 @@ app.layout = html.Div([
467
  def update_taiex_prediction(predict_days):
468
  data = get_stock_data('^TWII', '2y')
469
  if data.empty: return html.Div("無法獲取台指期資料"), {}
470
- final_prediction = simple_lstm_predict(data, predict_days)
 
 
 
471
  if final_prediction is None: return html.Div("資料不足,無法進行預測"), {}
472
  current_price, last_date = data['Close'].iloc[-1], data.index[-1]
473
  predicted_price, change_pct, confidence = final_prediction['predicted_price'], final_prediction['change_pct'], final_prediction['confidence']
 
474
  prediction_paths = {1: [1], 5: [1, 5], 10: [1, 5, 10], 20: [1, 10, 20], 60: [1, 10, 20, 60]}
475
  intervals_to_predict = prediction_paths.get(predict_days, [predict_days])
476
  prediction_dates, prediction_prices = [last_date], [current_price]
 
477
  for days in intervals_to_predict:
478
- interim_prediction = simple_lstm_predict(data, days)
 
479
  if interim_prediction:
480
  prediction_dates.append(last_date + timedelta(days=days))
481
  prediction_prices.append(interim_prediction['predicted_price'])
 
 
482
  color, arrow = ('red', '📈') if change_pct >= 0 else ('green', '📉')
483
  result_card = html.Div([
484
  html.H4(f"{predict_days}日後預測結果", style={'margin': '0 0 15px 0', 'color': 'white'}),
 
17
  import requests
18
  import time # 引用 time 模組以處理時間戳
19
 
20
+ # ========================= 引用外部模組 START =========================
21
  # 引用您組員的預測器程式
22
  from Bert_predict import BertPredictor
23
 
24
+ # 引用新的模型預測器
25
+ from model_predictor import advanced_lstm_predict
26
+ # ========================== 引用外部模組 END ==========================
27
+
28
+ # ========================= 全域設定 START =========================
29
+ # 【【【模型切換開關】】】
30
+ # False: 使用簡易統計模型 (預設)
31
+ # True: 使用 model_predictor.py 中的進階 LSTM 模型 (未來啟用)
32
+ USE_ADVANCED_MODEL = False
33
+
34
+
35
  # ========================= CACHE 設定 START =========================
36
  # 分析結果的快取字典
37
  ANALYSIS_CACHE = {}
38
  # 快取有效時間(秒),例如:4 小時 = 4 * 60 * 60 = 14400 秒
39
  CACHE_DURATION_SECONDS = 8 * 60 * 60
40
  # ========================== CACHE 設定 END ==========================
41
+ # ========================== 全域設定 END ==========================
42
 
43
  # 台股代號對應表 (移除台指期,因為它現在是獨立區塊)
44
  TAIWAN_STOCKS = {
 
183
  except:
184
  return pd.DataFrame()
185
 
186
+ def simple_statistical_predict(data, predict_days=5):
187
+ """【備用模型】簡化的統計預測模型。"""
188
+ if len(data) < 60: return None
 
189
  prices = data['Close'].values
190
  ma_short = np.mean(prices[-5:])
191
  ma_medium = np.mean(prices[-20:])
 
193
  recent_trend = np.polyfit(range(20), prices[-20:], 1)[0]
194
  volatility = np.std(prices[-20:]) / np.mean(prices[-20:])
195
  base_change = recent_trend * predict_days
196
+ trend_factor = 1.0 + (0.02 if ma_short > ma_medium > ma_long else -0.02 if ma_short < ma_medium < ma_long else 0)
 
 
 
 
 
 
197
  noise_factor = np.random.normal(1, volatility * 0.1)
198
  predicted_price = prices[-1] * trend_factor + base_change + (prices[-1] * noise_factor * 0.01)
199
  change_pct = ((predicted_price - prices[-1]) / prices[-1]) * 100
200
+ return {'predicted_price': predicted_price, 'change_pct': change_pct, 'confidence': max(0.6, 1 - volatility * 2)}
201
+
202
+ def get_prediction(data, predict_days=5):
203
+ """
204
+ 【【模型預測控制器】】
205
+ 根據 USE_ADVANCED_MODEL 的設定,呼叫對應的預測模型。
206
+ """
207
+ if USE_ADVANCED_MODEL:
208
+ print(f"模式: 進階LSTM模型 | 預測天期: {predict_days}天")
209
+ prediction = advanced_lstm_predict(predict_days)
210
+ # 如果進階模型預測失敗,則自動降級使用簡易模型
211
+ if prediction is not None:
212
+ return prediction
213
+ else:
214
+ print("進階模型預測失敗,自動降級為簡易統計模型。")
215
+
216
+ # 預設或降級時執行簡易模型
217
+ print(f"模式: 簡易統計模型 | 預測天期: {predict_days}天")
218
+ return simple_statistical_predict(data, predict_days)
219
 
220
  def calculate_technical_indicators(df):
221
  """計算技術指標"""
 
486
  def update_taiex_prediction(predict_days):
487
  data = get_stock_data('^TWII', '2y')
488
  if data.empty: return html.Div("無法獲取台指期資料"), {}
489
+
490
+ # === 修改點:統一呼叫 get_prediction 控制器 ===
491
+ final_prediction = get_prediction(data, predict_days)
492
+
493
  if final_prediction is None: return html.Div("資料不足,無法進行預測"), {}
494
  current_price, last_date = data['Close'].iloc[-1], data.index[-1]
495
  predicted_price, change_pct, confidence = final_prediction['predicted_price'], final_prediction['change_pct'], final_prediction['confidence']
496
+
497
  prediction_paths = {1: [1], 5: [1, 5], 10: [1, 5, 10], 20: [1, 10, 20], 60: [1, 10, 20, 60]}
498
  intervals_to_predict = prediction_paths.get(predict_days, [predict_days])
499
  prediction_dates, prediction_prices = [last_date], [current_price]
500
+
501
  for days in intervals_to_predict:
502
+ # === 修改點:迴圈內也使用統一的預測控制器 ===
503
+ interim_prediction = get_prediction(data, days)
504
  if interim_prediction:
505
  prediction_dates.append(last_date + timedelta(days=days))
506
  prediction_prices.append(interim_prediction['predicted_price'])
507
+
508
+ # (後續繪圖邏輯不變)
509
  color, arrow = ('red', '📈') if change_pct >= 0 else ('green', '📉')
510
  result_card = html.Div([
511
  html.H4(f"{predict_days}日後預測結果", style={'margin': '0 0 15px 0', 'color': 'white'}),