tkkbbo332 commited on
Commit
64e47fe
·
1 Parent(s): 30035b6

Upload Bert_predict.py

Browse files
Files changed (1) hide show
  1. Bert_predict.py +216 -0
Bert_predict.py ADDED
@@ -0,0 +1,216 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import tensorflow as tf
3
+ import transformers
4
+ from tensorflow import keras
5
+ from transformers import BertTokenizer, TFBertModel
6
+ import pandas as pd
7
+ from datetime import date, timedelta
8
+ import requests
9
+ import time
10
+ from typing import List, Optional, Dict, Any
11
+
12
+ class BertPredictor:
13
+ """
14
+ 用於加載 BERT 模型、獲取新聞並對其進行股市影響預測的類別。
15
+ """
16
+ def __init__(self, tokenizer_name: str = 'hfl/rbt3', max_news_per_keyword: int = 5):
17
+ """
18
+ 初始化預測器,載入分詞器、預訓練模型並獲取新聞。
19
+
20
+ Args:
21
+ tokenizer_name (str): BERT 分詞器的名稱。
22
+ max_news_per_keyword (int): 每個關鍵字要抓取的新聞最大數量。
23
+ """
24
+ # --- 路徑和檔案名稱設置 ---
25
+ self.current_dir = os.path.dirname(os.path.abspath(__file__))
26
+ self.model_path = os.path.join(self.current_dir, 'Best-complete-model.h5')
27
+
28
+ # 檔案名稱用今天的日期,但內容是昨天的
29
+ today_date_str = date.today().strftime('%Y-%m-%d')
30
+ self.news_csv_path = os.path.join(self.current_dir, f'news_{today_date_str}.csv')
31
+
32
+ # 用於API查詢的日期仍然是昨天
33
+ self.target_date = date.today() - timedelta(days=1)
34
+ self.target_date_str = self.target_date.strftime('%Y-%m-%d')
35
+
36
+ # --- GNews API 設定 ---
37
+ self.api_key = "00270dacb75799771e6842ae1d6d6e71" # 請替換成您的 API Key
38
+ self.base_url = "https://gnews.io/api/v4/search"
39
+ self.keywords = ["Fed", "Interest Rates", "Inflation", "Tariffs", "ADR", "Treasury Yields"]
40
+ self.max_news_per_keyword = max_news_per_keyword
41
+
42
+ # --- 模型相關設置 ---
43
+ self.text_max_length = 256
44
+ self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
45
+
46
+ # 載入最佳模型
47
+ print("正在加載模型...")
48
+ self.model = keras.models.load_model(
49
+ self.model_path,
50
+ custom_objects={'TFBertModel': TFBertModel}
51
+ )
52
+ print("模型加載完成。")
53
+
54
+ # --- 初始化流程 ---
55
+ self._check_file_and_get_news_if_needed()
56
+
57
+
58
+ # --- 內部使用方法 ---
59
+
60
+ def _encode_texts(self, texts: list):
61
+ """將文本轉換為 BERT 輸入格式 (input_ids, attention_mask)"""
62
+ return self.tokenizer(
63
+ texts,
64
+ max_length=self.text_max_length,
65
+ padding='max_length',
66
+ truncation=True,
67
+ return_tensors='tf'
68
+ )
69
+
70
+ def _predict(self, new_text: str) -> float:
71
+ """
72
+ 對單一新聞文本進行預測。
73
+
74
+ Args:
75
+ new_text (str): 待預測的新聞文本。
76
+
77
+ Returns:
78
+ float: 預測的股市影響分數。
79
+ """
80
+ new_encoding = self._encode_texts([new_text])
81
+ predicted_score = self.model.predict(dict(new_encoding), verbose=0)[0][0]
82
+ return float(predicted_score)
83
+
84
+ def _check_file_and_get_news_if_needed(self):
85
+ """
86
+ 檢查今天的 news csv 是否存在。如果不存在,則呼叫 _get_news() 進行抓取。
87
+ """
88
+ if not os.path.exists(self.news_csv_path):
89
+ print(f"找不到今天的檔案 '{os.path.basename(self.news_csv_path)}'。")
90
+ self._get_news()
91
+ else:
92
+ print(f"已找到今天的檔案 '{os.path.basename(self.news_csv_path)}',將跳過新聞抓取步驟。")
93
+
94
+ def _get_news(self):
95
+ """
96
+ 使用 GNews API 抓取目標日期(昨天)的新聞,即時預測分數並儲存。
97
+ """
98
+ print("開始執行新聞抓取與即時預測...")
99
+ print(f"搜尋日期設定為:{self.target_date_str} (將存檔至檔名含今日日期的檔案)")
100
+
101
+ results = []
102
+ for kw in self.keywords:
103
+ params = {
104
+ "q": kw, "lang": "en", "country": "us", "max": self.max_news_per_keyword,
105
+ "in": "title,description", "apikey": self.api_key,
106
+ "from": f"{self.target_date_str}T00:00:00Z",
107
+ "to": f"{self.target_date_str}T23:59:59Z"
108
+ }
109
+ try:
110
+ response = requests.get(self.base_url, params=params)
111
+ response.raise_for_status()
112
+ data = response.json()
113
+ print(f"關鍵字 '{kw}' 成功抓取到: {data.get('totalArticles', 0)} 則新聞")
114
+ if "articles" in data:
115
+ for article in data["articles"]:
116
+ published_date = pd.to_datetime(article['publishedAt']).strftime('%Y-%m-%d')
117
+ news_content = f"{article['title']} - {article.get('description', '')}"
118
+ score = self._predict(news_content)
119
+ results.append({
120
+ "時間": published_date,
121
+ "分數": score,
122
+ "內容": news_content
123
+ })
124
+ except requests.exceptions.RequestException as e:
125
+ print(f"錯誤:API 請求失敗 - {e}")
126
+ continue
127
+ finally:
128
+ time.sleep(0.5)
129
+
130
+ if not results:
131
+ print("抓取完成。未找到任何相關新聞。")
132
+ df_to_save = pd.DataFrame(columns=['時間', '分數', '內容'])
133
+ else:
134
+ print(f"成功抓取並預測 {len(results)} 筆新聞。")
135
+ df_to_save = pd.DataFrame(results)
136
+
137
+ try:
138
+ print(f"正在將結果寫入檔案 '{self.news_csv_path}'...")
139
+ df_to_save.to_csv(self.news_csv_path, index=False, encoding='utf-8-sig')
140
+ print(f"成功!檔案已儲存至 '{self.news_csv_path}'。")
141
+ except IOError as e:
142
+ print(f"錯誤:寫入檔案失敗 - {e}")
143
+
144
+ # --- 公開方法 ---
145
+
146
+ def get_news_index(self) -> Optional[float]:
147
+ """
148
+ 從今天的 news csv 檔案中讀取所有新聞分數並回傳其平均值。
149
+
150
+ Returns:
151
+ float or None: 所有新聞的平均分數,如果檔案不存在或為空則回傳 None。
152
+ """
153
+ try:
154
+ df = pd.read_csv(self.news_csv_path)
155
+ if df.empty or '分數' not in df.columns:
156
+ print(f"'{self.news_csv_path}' 為空或缺少 '分數' 欄位。")
157
+ return None
158
+
159
+ average_score = pd.to_numeric(df['分數'], errors='coerce').mean()
160
+ return average_score if pd.notna(average_score) else None
161
+
162
+ except FileNotFoundError:
163
+ print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
164
+ return None
165
+ except Exception as e:
166
+ print(f"讀取或計算 CSV 檔案時發生錯誤:{e}")
167
+ return None
168
+
169
+ def get_news(self) -> Optional[List[str]]:
170
+ """
171
+ 讀取今天的 news csv 檔案,並以 list 格式回傳分數絕對值最高的三則新聞內容。
172
+ """
173
+ try:
174
+ df = pd.read_csv(self.news_csv_path)
175
+ df['分數'] = pd.to_numeric(df['分數'], errors='coerce')
176
+ df.dropna(subset=['分數'], inplace=True)
177
+ if df.empty:
178
+ return []
179
+
180
+ df['abs_score'] = df['分數'].abs()
181
+ top_3_news_df = df.sort_values(by='abs_score', ascending=False).head(3)
182
+
183
+ # 將 '內容' 欄位轉換為 list of strings
184
+ return top_3_news_df['內容'].tolist()
185
+
186
+ except FileNotFoundError:
187
+ print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
188
+ return None
189
+ except Exception as e:
190
+ print(f"讀取或處理 CSV 檔案時發生錯誤:{e}")
191
+ return None
192
+
193
+ # --- 主程式區塊:只有當腳本直接執行時才運行 ---
194
+ if __name__ == "__main__":
195
+ if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Best-complete-model.h5')):
196
+ print("錯誤:找不到模型文件 'Best-complete-model.h5'。請先訓練模型並確保它已保存。")
197
+ else:
198
+ predictor = BertPredictor(max_news_per_keyword=3)
199
+ print("\n" + "="*30)
200
+ avg_score = predictor.get_news_index()
201
+ if avg_score is not None:
202
+ print(f"從新聞檔案中計算出的平均分數為:{avg_score:.4f}")
203
+ else:
204
+ print("無法計算新聞檔案中的平均分數。")
205
+
206
+ print("\n" + "="*30)
207
+ top_news_content = predictor.get_news()
208
+ if top_news_content:
209
+ print("\n分數絕對值最高的三則新聞內容:")
210
+ for i, content in enumerate(top_news_content):
211
+ print(f" {i+1}. {content}")
212
+ elif top_news_content == []:
213
+ print("新聞檔案中無有效內容可顯示。")
214
+ else:
215
+ print("無法獲取最高分新聞。")
216
+