|
|
import os |
|
|
import tensorflow as tf |
|
|
import transformers |
|
|
from tensorflow import keras |
|
|
from transformers import BertTokenizer, TFBertModel |
|
|
import pandas as pd |
|
|
from datetime import date, timedelta |
|
|
import requests |
|
|
import time |
|
|
from typing import List, Optional, Dict, Any |
|
|
|
|
|
from transformers import AutoTokenizer |
|
|
|
|
|
Bert_model_name = "hfl/rbt3"  # HF hub id of the compact Chinese RBT3 BERT checkpoint


class BertPredictor:
    """Load a fine-tuned BERT model, fetch news via the GNews API, and predict
    a stock-market-impact score for each article.

    Workflow: on construction the tokenizer and Keras model are loaded, then
    yesterday's news is fetched and scored — unless a CSV cache file for today
    already exists next to this module, in which case fetching is skipped.
    """

    def __init__(self, tokenizer_name: str = Bert_model_name, max_news_per_keyword: int = 5):
        """Initialize the predictor: load tokenizer and model, then fetch news.

        Args:
            tokenizer_name (str): Name of the BERT tokenizer on the HF hub.
            max_news_per_keyword (int): Max number of articles fetched per keyword.
        """
        self.current_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(self.current_dir, 'Best-complete-model.h5')

        # One cache file per calendar day: reruns on the same day skip fetching.
        today_date_str = date.today().strftime('%Y-%m-%d')
        self.news_csv_path = os.path.join(self.current_dir, f'news_{today_date_str}.csv')
        # BUG FIX: removed a leftover debug assignment that unconditionally
        # overwrote news_csv_path with the hard-coded 'news_2025-09-12.csv',
        # which defeated the daily-cache logic above.

        # News is searched for "yesterday" so a complete day of articles exists.
        self.target_date = date.today() - timedelta(days=1)
        self.target_date_str = self.target_date.strftime('%Y-%m-%d')

        # SECURITY: prefer the GNEWS_API_KEY environment variable; the literal
        # fallback keeps backward compatibility but should be rotated/removed.
        self.api_key = os.environ.get("GNEWS_API_KEY", "fd12e84a158c7d9eaf31627aaae0927a")
        self.base_url = "https://gnews.io/api/v4/search"
        self.keywords = ["Fed", "Interest Rates", "Inflation", "Tariffs", "ADR", "Treasury Yields"]
        self.max_news_per_keyword = max_news_per_keyword

        self.text_max_length = 256  # token budget per article (title + description)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        print("正在加載模型...")
        # custom_objects is required so Keras can deserialize the embedded
        # TFBertModel layer stored inside the saved .h5 file.
        self.model = keras.models.load_model(
            self.model_path,
            custom_objects={'TFBertModel': TFBertModel}
        )
        print("模型加載完成。")

        self._check_file_and_get_news_if_needed()

    def _encode_texts(self, texts: list):
        """Convert texts to BERT input format (input_ids, attention_mask)."""
        return self.tokenizer(
            texts,
            max_length=self.text_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='tf'
        )

    def _predict(self, new_text: str) -> float:
        """Predict the market-impact score for a single news text.

        Args:
            new_text (str): News text to score.

        Returns:
            float: Predicted stock-market impact score.
        """
        new_encoding = self._encode_texts([new_text])
        # Model outputs shape (1, 1); unwrap to a plain Python float.
        predicted_score = self.model.predict(dict(new_encoding), verbose=0)[0][0]
        return float(predicted_score)

    def _check_file_and_get_news_if_needed(self):
        """Fetch news only if today's CSV cache file does not exist yet."""
        if not os.path.exists(self.news_csv_path):
            print(f"找不到今天的檔案 '{os.path.basename(self.news_csv_path)}'。")
            self._get_news()
        else:
            print(f"已找到今天的檔案 '{os.path.basename(self.news_csv_path)}',將跳過新聞抓取步驟。")

    def _get_news(self):
        """Fetch yesterday's news via GNews, score each article, save to CSV."""
        print("開始執行新聞抓取與即時預測...")
        print(f"搜尋日期設定為:{self.target_date_str} (將存檔至檔名含今日日期的檔案)")

        results = []
        for kw in self.keywords:
            params = {
                "q": kw, "lang": "en", "country": "us", "max": self.max_news_per_keyword,
                "in": "title,description", "apikey": self.api_key,
                "from": f"{self.target_date_str}T00:00:00Z",
                "to": f"{self.target_date_str}T23:59:59Z"
            }
            try:
                # timeout added so a stalled connection cannot hang the whole run
                response = requests.get(self.base_url, params=params, timeout=15)
                response.raise_for_status()
                data = response.json()
                print(f"關鍵字 '{kw}' 成功抓取到: {data.get('totalArticles', 0)} 則新聞")
                if "articles" in data:
                    for article in data["articles"]:
                        published_date = pd.to_datetime(article['publishedAt']).strftime('%Y-%m-%d')
                        # 'or' also guards against an explicit null description
                        news_content = f"{article['title']} - {article.get('description') or ''}"
                        score = self._predict(news_content)
                        results.append({
                            "時間": published_date,
                            "分數": score,
                            "內容": news_content
                        })
            except requests.exceptions.RequestException as e:
                print(f"錯誤:API 請求失敗 - {e}")
                continue
            finally:
                time.sleep(0.5)  # stay under the GNews API rate limit

        if not results:
            print("抓取完成。未找到任何相關新聞。")
            df_to_save = pd.DataFrame(columns=['時間', '分數', '內容'])
        else:
            print(f"成功抓取並預測 {len(results)} 筆新聞。")
            df_to_save = pd.DataFrame(results)

        try:
            print(f"正在將結果寫入檔案 '{self.news_csv_path}'...")
            # utf-8-sig so spreadsheet apps detect the CJK headers correctly
            df_to_save.to_csv(self.news_csv_path, index=False, encoding='utf-8-sig')
            print(f"成功!檔案已儲存至 '{self.news_csv_path}'。")
        except IOError as e:
            print(f"錯誤:寫入檔案失敗 - {e}")

    def get_news_index(self) -> Optional[float]:
        """Read all scores from today's news CSV and return their mean.

        Returns:
            float or None: Mean of all news scores; None if the file is
            missing, empty, lacks the score column, or has no numeric scores.
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            if df.empty or '分數' not in df.columns:
                print(f"'{self.news_csv_path}' 為空或缺少 '分數' 欄位。")
                return None

            # errors='coerce' turns malformed cells into NaN, which mean() skips
            average_score = pd.to_numeric(df['分數'], errors='coerce').mean()
            return average_score if pd.notna(average_score) else None

        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或計算 CSV 檔案時發生錯誤:{e}")
            return None

    def get_news(self) -> Optional[List[str]]:
        """Return the three news texts with the highest absolute scores.

        Returns:
            list[str] or None: Top-3 contents by |score| ([] if the file has
            no valid rows); None on read/processing failure.
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            df['分數'] = pd.to_numeric(df['分數'], errors='coerce')
            df.dropna(subset=['分數'], inplace=True)
            if df.empty:
                return []

            df['abs_score'] = df['分數'].abs()
            top_3_news_df = df.sort_values(by='abs_score', ascending=False).head(3)

            return top_3_news_df['內容'].tolist()

        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或處理 CSV 檔案時發生錯誤:{e}")
            return None
|
|
|
|
|
|
|
|
def main() -> None:
    """Script entry point: build a predictor and report the day's news scores.

    Requires 'Best-complete-model.h5' to exist next to this file; otherwise
    prints an error and exits without constructing the predictor.
    """
    model_file = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Best-complete-model.h5')
    if not os.path.exists(model_file):
        print("錯誤:找不到模型文件 'Best-complete-model.h5'。請先訓練模型並確保它已保存。")
        return

    predictor = BertPredictor(max_news_per_keyword=3)

    print("\n" + "="*30)
    avg_score = predictor.get_news_index()
    if avg_score is not None:
        print(f"從新聞檔案中計算出的平均分數為:{avg_score:.4f}")
    else:
        print("無法計算新聞檔案中的平均分數。")

    print("\n" + "="*30)
    top_news_content = predictor.get_news()
    if top_news_content:
        print("\n分數絕對值最高的三則新聞內容:")
        for i, content in enumerate(top_news_content):
            print(f"  {i+1}. {content}")
    elif top_news_content == []:
        # Empty list means the file was readable but held no valid rows;
        # None (handled below) means the read itself failed.
        print("新聞檔案中無有效內容可顯示。")
    else:
        print("無法獲取最高分新聞。")


if __name__ == "__main__":
    main()
|
|
|
|
|
|