AlanRex commited on
Commit
2ded77c
·
verified ·
1 Parent(s): f962002

Upload 3 files

Browse files
.gitattributes CHANGED
@@ -34,3 +34,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  stock_lstm_model_v2.keras filter=lfs diff=lfs merge=lfs -text
 
 
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
  stock_lstm_model_v2.keras filter=lfs diff=lfs merge=lfs -text
37
+ 9CE6ABB0E688BCE5A5B3E69920220912-20250909.xlsx filter=lfs diff=lfs merge=lfs -text
9CE6ABB0E688BCE5A5B3E69920220912-20250909.xlsx ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:89dd2c60285fa57c04c55a4397207553f73e00ee75d72ee9e91fbad39247c7e8
3
+ size 121399
predictor_logic.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pandas as pd
2
+ import xgboost as xgb
3
+ import os
4
+
5
+ # 定義一個類別來封裝模型載入和預測邏輯
6
+ class StockPredictor:
7
+ """
8
+ 這個類別負責載入模型和資料,並提供預測方法。
9
+ """
10
+ def __init__(self, model_path="xgboost_model.json", data_path="期末專案輸入資料20220912-20250909.xlsx"):
11
+ """
12
+ 初始化預測器物件,載入模型和資料集。
13
+ 可接受無引數,自動帶入預設路徑檔名。
14
+ """
15
+ if not os.path.exists(model_path):
16
+ raise FileNotFoundError(f"找不到模型檔案: {model_path}")
17
+ if not os.path.exists(data_path):
18
+ raise FileNotFoundError(f"找不到資料檔案: {data_path}")
19
+
20
+ self.model = xgb.XGBRegressor()
21
+ self.model.load_model(model_path)
22
+ self.historical_df = pd.read_excel(data_path)
23
+
24
+ # 確保 '日期' 欄位是 datetime 格式
25
+ self.historical_df['日期'] = pd.to_datetime(self.historical_df['日期'], format='%Y%m%d')
26
+
27
+ # 確保欄位名稱正確
28
+ self.historical_df.rename(columns={
29
+ '加權指\n數開盤': 'Open',
30
+ '加權指\n數最高': 'High',
31
+ '加權指\n數最低': 'Low',
32
+ '加權指\n數收盤': 'Close',
33
+ '加權指數\n成交量': 'Volume'
34
+ }, inplace=True)
35
+
36
+ def predict(self, date_str):
37
+ """
38
+ 輸入一個日期字串,自動抓取前一筆資料進行預測。
39
+ 返回預測結果。
40
+ """
41
+ # 將輸入日期轉換為 datetime 物件
42
+ pred_date = pd.to_datetime(date_str)
43
+
44
+ # 尋找使用者輸入日期在資料集中的位置
45
+ current_row_index = self.historical_df[self.historical_df['日期'] == pred_date].index
46
+
47
+ if current_row_index.empty:
48
+ raise ValueError("找不到您輸入日期所對應的資料。")
49
+
50
+ # 取得前一筆資料的索引
51
+ prev_data_index = current_row_index[0] - 1
52
+
53
+ if prev_data_index < 0:
54
+ raise ValueError("資料集沒有前一筆資料可以進行預測。")
55
+
56
+ # 抓取前一筆資料
57
+ selected_row = self.historical_df.iloc[[prev_data_index]]
58
+
59
+ # 準備模型輸入
60
+ input_data = selected_row[['Open', 'High', 'Low', 'Volume']].values
61
+
62
+ # 進行預測,模型需要二維陣列的輸入
63
+ predicted_price = self.model.predict(input_data.reshape(1, -1))[0]
64
+
65
+ return predicted_price
xgboost_model.json ADDED
The diff for this file is too large to render. See raw diff