Spaces:
Sleeping
Sleeping
| import os | |
| import tensorflow as tf | |
| import transformers | |
| from tensorflow import keras | |
| from transformers import BertTokenizer, TFBertModel | |
| import pandas as pd | |
| from datetime import date, timedelta | |
| import requests | |
| import time | |
| from typing import List, Optional, Dict, Any | |
class BertPredictor:
    """
    Loads a fine-tuned BERT regression model, fetches news through the GNews
    API, scores each article's expected stock-market impact, and reads the
    aggregated scores back from a daily CSV file.
    """

    def __init__(self, tokenizer_name: str = 'hfl/rbt3', max_news_per_keyword: int = 5):
        """
        Initialize the predictor: resolve paths, load the tokenizer and the
        saved Keras model, then fetch news if today's CSV does not exist yet.

        Args:
            tokenizer_name (str): Hugging Face identifier of the BERT tokenizer.
            max_news_per_keyword (int): Maximum number of articles fetched per keyword.
        """
        # --- Paths and file names ---
        self.current_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(self.current_dir, 'Best-complete-model.h5')
        # The CSV file name carries today's date, but its articles are from yesterday.
        today_date_str = date.today().strftime('%Y-%m-%d')
        self.news_csv_path = os.path.join(self.current_dir, f'news_{today_date_str}.csv')
        # NOTE(demo): deliberately override with a fixed snapshot file for the
        # showcase. Comment out the next line to use the dated file computed above.
        self.news_csv_path = os.path.join(self.current_dir, "news_2025-09-12.csv")
        # The API query window is still "yesterday".
        self.target_date = date.today() - timedelta(days=1)
        self.target_date_str = self.target_date.strftime('%Y-%m-%d')

        # --- GNews API settings ---
        # SECURITY: prefer the GNEWS_API_KEY environment variable. The literal
        # fallback preserves the old behavior, but a key committed to source
        # control should be considered leaked and rotated.
        self.api_key = os.environ.get("GNEWS_API_KEY", "fd12e84a158c7d9eaf31627aaae0927a")
        self.base_url = "https://gnews.io/api/v4/search"
        self.keywords = ["Fed", "Interest Rates", "Inflation", "Tariffs", "ADR", "Treasury Yields"]
        self.max_news_per_keyword = max_news_per_keyword

        # --- Model setup ---
        self.text_max_length = 256
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        print("正在加載模型...")
        # The saved model embeds a TFBertModel layer, so it must be registered
        # as a custom object for Keras to deserialize it.
        self.model = keras.models.load_model(
            self.model_path,
            custom_objects={'TFBertModel': TFBertModel}
        )
        print("模型加載完成。")

        # --- Bootstrap: fetch news only when today's CSV is missing ---
        self._check_file_and_get_news_if_needed()

    # --- Internal helpers ---
    def _encode_texts(self, texts: List[str]):
        """Tokenize texts into BERT inputs (input_ids, attention_mask) as tf tensors."""
        return self.tokenizer(
            texts,
            max_length=self.text_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='tf'
        )

    def _predict(self, new_text: str) -> float:
        """
        Score a single news text with the regression model.

        Args:
            new_text (str): News text to score.

        Returns:
            float: Predicted stock-market impact score.
        """
        new_encoding = self._encode_texts([new_text])
        predicted_score = self.model.predict(dict(new_encoding), verbose=0)[0][0]
        return float(predicted_score)

    def _check_file_and_get_news_if_needed(self):
        """Check whether today's news CSV exists; call _get_news() only if it does not."""
        if not os.path.exists(self.news_csv_path):
            print(f"找不到今天的檔案 '{os.path.basename(self.news_csv_path)}'。")
            self._get_news()
        else:
            print(f"已找到今天的檔案 '{os.path.basename(self.news_csv_path)}',將跳過新聞抓取步驟。")

    def _get_news(self):
        """
        Fetch news for the target date (yesterday) from the GNews API, score
        each article immediately, and write the results to the CSV file.
        """
        print("開始執行新聞抓取與即時預測...")
        print(f"搜尋日期設定為:{self.target_date_str} (將存檔至檔名含今日日期的檔案)")
        results = []
        for kw in self.keywords:
            params = {
                "q": kw, "lang": "en", "country": "us", "max": self.max_news_per_keyword,
                "in": "title,description", "apikey": self.api_key,
                "from": f"{self.target_date_str}T00:00:00Z",
                "to": f"{self.target_date_str}T23:59:59Z"
            }
            try:
                # FIX: a timeout keeps a dead API endpoint from hanging forever.
                response = requests.get(self.base_url, params=params, timeout=10)
                response.raise_for_status()
                data = response.json()
                print(f"關鍵字 '{kw}' 成功抓取到: {data.get('totalArticles', 0)} 則新聞")
                if "articles" in data:
                    for article in data["articles"]:
                        published_date = pd.to_datetime(article['publishedAt']).strftime('%Y-%m-%d')
                        # FIX: `or ''` also covers an explicit null description,
                        # which .get() alone would render as the string "None".
                        news_content = f"{article['title']} - {article.get('description') or ''}"
                        score = self._predict(news_content)
                        results.append({
                            "時間": published_date,
                            "分數": score,
                            "內容": news_content
                        })
            except requests.exceptions.RequestException as e:
                print(f"錯誤:API 請求失敗 - {e}")
                continue
            finally:
                # Rate-limit between keyword queries.
                time.sleep(0.5)
        if not results:
            print("抓取完成。未找到任何相關新聞。")
            df_to_save = pd.DataFrame(columns=['時間', '分數', '內容'])
        else:
            print(f"成功抓取並預測 {len(results)} 筆新聞。")
            df_to_save = pd.DataFrame(results)
        try:
            print(f"正在將結果寫入檔案 '{self.news_csv_path}'...")
            # utf-8-sig keeps the CJK headers readable when opened in Excel.
            df_to_save.to_csv(self.news_csv_path, index=False, encoding='utf-8-sig')
            print(f"成功!檔案已儲存至 '{self.news_csv_path}'。")
        except IOError as e:
            print(f"錯誤:寫入檔案失敗 - {e}")

    # --- Public API ---
    def get_news_index(self) -> Optional[float]:
        """
        Read every news score from today's CSV file and return the average.

        Returns:
            float | None: Mean score, or None when the file is missing,
            empty, lacks the score column, or holds no numeric scores.
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            if df.empty or '分數' not in df.columns:
                print(f"'{self.news_csv_path}' 為空或缺少 '分數' 欄位。")
                return None
            average_score = pd.to_numeric(df['分數'], errors='coerce').mean()
            # FIX: cast to a plain Python float instead of leaking numpy.float64,
            # matching the Optional[float] annotation.
            return float(average_score) if pd.notna(average_score) else None
        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或計算 CSV 檔案時發生錯誤:{e}")
            return None

    def get_news(self) -> Optional[List[str]]:
        """
        Return the contents of the three news items with the highest absolute
        scores from today's CSV ([] when no valid rows, None on read failure).
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            df['分數'] = pd.to_numeric(df['分數'], errors='coerce')
            df.dropna(subset=['分數'], inplace=True)
            if df.empty:
                return []
            df['abs_score'] = df['分數'].abs()
            top_3_news_df = df.sort_values(by='abs_score', ascending=False).head(3)
            # Return the '內容' column as a list of strings.
            return top_3_news_df['內容'].tolist()
        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或處理 CSV 檔案時發生錯誤:{e}")
            return None
# --- Entry point: runs only when the script is executed directly ---
if __name__ == "__main__":
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_file = os.path.join(script_dir, 'Best-complete-model.h5')
    if not os.path.exists(model_file):
        print("錯誤:找不到模型文件 'Best-complete-model.h5'。請先訓練模型並確保它已保存。")
    else:
        # Fetch at most 3 articles per keyword for the demo run.
        predictor = BertPredictor(max_news_per_keyword=3)

        separator = "\n" + "=" * 30

        # Average impact score across all stored news.
        print(separator)
        avg_score = predictor.get_news_index()
        if avg_score is None:
            print("無法計算新聞檔案中的平均分數。")
        else:
            print(f"從新聞檔案中計算出的平均分數為:{avg_score:.4f}")

        # Top three news items by absolute score.
        print(separator)
        top_news_content = predictor.get_news()
        if top_news_content:
            print("\n分數絕對值最高的三則新聞內容:")
            for rank, content in enumerate(top_news_content, start=1):
                print(f" {rank}. {content}")
        elif top_news_content == []:
            print("新聞檔案中無有效內容可顯示。")
        else:
            print("無法獲取最高分新聞。")