File size: 9,022 Bytes
625124c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76d2f06
a45f165
625124c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
import os
import tensorflow as tf
import transformers
from tensorflow import keras
from transformers import BertTokenizer, TFBertModel
import pandas as pd
from datetime import date, timedelta
import requests
import time
from typing import List, Optional, Dict, Any

class BertPredictor:
    """
    用於加載 BERT 模型、獲取新聞並對其進行股市影響預測的類別。
    """
    def __init__(self, tokenizer_name: str = 'hfl/rbt3', max_news_per_keyword: int = 5):
        """
        初始化預測器,載入分詞器、預訓練模型並獲取新聞。

        Args:
            tokenizer_name (str): BERT 分詞器的名稱。
            max_news_per_keyword (int): 每個關鍵字要抓取的新聞最大數量。
        """
        # --- 路徑和檔案名稱設置 ---
        self.current_dir = os.path.dirname(os.path.abspath(__file__))
        self.model_path = os.path.join(self.current_dir, 'Best-complete-model.h5')
        
        # 檔案名稱用今天的日期,但內容是昨天的
        today_date_str = date.today().strftime('%Y-%m-%d')
        self.news_csv_path = os.path.join(self.current_dir, f'news_{today_date_str}.csv')
        ### 抓取固定檔案而非今天,成果展示要註解
        self.news_csv_path = os.path.join(self.current_dir, "news_2025-09-12.csv")
        
        # 用於API查詢的日期仍然是昨天
        self.target_date = date.today() - timedelta(days=1)
        self.target_date_str = self.target_date.strftime('%Y-%m-%d')
        
        # --- GNews API 設定 ---
        self.api_key = "fd12e84a158c7d9eaf31627aaae0927a"  # 請替換成您的 API Key
        self.base_url = "https://gnews.io/api/v4/search"
        self.keywords = ["Fed", "Interest Rates", "Inflation", "Tariffs", "ADR", "Treasury Yields"]
        self.max_news_per_keyword = max_news_per_keyword

        # --- 模型相關設置 ---
        self.text_max_length = 256
        self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
        
        # 載入最佳模型
        print("正在加載模型...")
        self.model = keras.models.load_model(
            self.model_path,
            custom_objects={'TFBertModel': TFBertModel}
        )
        print("模型加載完成。")

        # --- 初始化流程 ---
        self._check_file_and_get_news_if_needed()


    # --- 內部使用方法 ---

    def _encode_texts(self, texts: list):    
        """將文本轉換為 BERT 輸入格式 (input_ids, attention_mask)"""
        return self.tokenizer(
            texts,
            max_length=self.text_max_length,
            padding='max_length',
            truncation=True,
            return_tensors='tf'
        )

    def _predict(self, new_text: str) -> float:
        """
        對單一新聞文本進行預測。

        Args:
            new_text (str): 待預測的新聞文本。

        Returns:
            float: 預測的股市影響分數。
        """
        new_encoding = self._encode_texts([new_text])
        predicted_score = self.model.predict(dict(new_encoding), verbose=0)[0][0]
        return float(predicted_score)

    def _check_file_and_get_news_if_needed(self):
        """
        檢查今天的 news csv 是否存在。如果不存在,則呼叫 _get_news() 進行抓取。
        """
        if not os.path.exists(self.news_csv_path):
            print(f"找不到今天的檔案 '{os.path.basename(self.news_csv_path)}'。")
            self._get_news()
        else:
            print(f"已找到今天的檔案 '{os.path.basename(self.news_csv_path)}',將跳過新聞抓取步驟。")

    def _get_news(self):
        """
        使用 GNews API 抓取目標日期(昨天)的新聞,即時預測分數並儲存。
        """
        print("開始執行新聞抓取與即時預測...")
        print(f"搜尋日期設定為:{self.target_date_str} (將存檔至檔名含今日日期的檔案)")

        results = []
        for kw in self.keywords:
            params = {
                "q": kw, "lang": "en", "country": "us", "max": self.max_news_per_keyword,
                "in": "title,description", "apikey": self.api_key,
                "from": f"{self.target_date_str}T00:00:00Z",
                "to": f"{self.target_date_str}T23:59:59Z"
            }
            try:
                response = requests.get(self.base_url, params=params)
                response.raise_for_status()
                data = response.json()
                print(f"關鍵字 '{kw}' 成功抓取到: {data.get('totalArticles', 0)} 則新聞")
                if "articles" in data:
                    for article in data["articles"]:
                        published_date = pd.to_datetime(article['publishedAt']).strftime('%Y-%m-%d')
                        news_content = f"{article['title']} - {article.get('description', '')}"
                        score = self._predict(news_content)
                        results.append({
                            "時間": published_date,
                            "分數": score,
                            "內容": news_content
                        })
            except requests.exceptions.RequestException as e:
                print(f"錯誤:API 請求失敗 - {e}")
                continue
            finally:
                time.sleep(0.5)

        if not results:
            print("抓取完成。未找到任何相關新聞。")
            df_to_save = pd.DataFrame(columns=['時間', '分數', '內容'])
        else:
            print(f"成功抓取並預測 {len(results)} 筆新聞。")
            df_to_save = pd.DataFrame(results)
        
        try:
            print(f"正在將結果寫入檔案 '{self.news_csv_path}'...")
            df_to_save.to_csv(self.news_csv_path, index=False, encoding='utf-8-sig')
            print(f"成功!檔案已儲存至 '{self.news_csv_path}'。")
        except IOError as e:
            print(f"錯誤:寫入檔案失敗 - {e}")

    # --- 公開方法 ---

    def get_news_index(self) -> Optional[float]:
        """
        從今天的 news csv 檔案中讀取所有新聞分數並回傳其平均值。

        Returns:
            float or None: 所有新聞的平均分數,如果檔案不存在或為空則回傳 None。
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            if df.empty or '分數' not in df.columns:
                print(f"'{self.news_csv_path}' 為空或缺少 '分數' 欄位。")
                return None
            
            average_score = pd.to_numeric(df['分數'], errors='coerce').mean()
            return average_score if pd.notna(average_score) else None
            
        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或計算 CSV 檔案時發生錯誤:{e}")
            return None
    
    def get_news(self) -> Optional[List[str]]:
        """
        讀取今天的 news csv 檔案,並以 list 格式回傳分數絕對值最高的三則新聞內容。
        """
        try:
            df = pd.read_csv(self.news_csv_path)
            df['分數'] = pd.to_numeric(df['分數'], errors='coerce')
            df.dropna(subset=['分數'], inplace=True)
            if df.empty:
                return [] 

            df['abs_score'] = df['分數'].abs()
            top_3_news_df = df.sort_values(by='abs_score', ascending=False).head(3)
            
            # 將 '內容' 欄位轉換為 list of strings
            return top_3_news_df['內容'].tolist()

        except FileNotFoundError:
            print(f"錯誤:找不到檔案 '{self.news_csv_path}'。")
            return None
        except Exception as e:
            print(f"讀取或處理 CSV 檔案時發生錯誤:{e}")
            return None

# --- 主程式區塊:只有當腳本直接執行時才運行 ---
if __name__ == "__main__":
    if not os.path.exists(os.path.join(os.path.dirname(os.path.abspath(__file__)), 'Best-complete-model.h5')):
        print("錯誤:找不到模型文件 'Best-complete-model.h5'。請先訓練模型並確保它已保存。")
    else:
        predictor = BertPredictor(max_news_per_keyword=3)
        print("\n" + "="*30)
        avg_score = predictor.get_news_index()
        if avg_score is not None:
            print(f"從新聞檔案中計算出的平均分數為:{avg_score:.4f}")
        else:
            print("無法計算新聞檔案中的平均分數。")

        print("\n" + "="*30)
        top_news_content = predictor.get_news()
        if top_news_content:
            print("\n分數絕對值最高的三則新聞內容:")
            for i, content in enumerate(top_news_content):
                print(f"  {i+1}. {content}")
        elif top_news_content == []:
             print("新聞檔案中無有效內容可顯示。")
        else:
            print("無法獲取最高分新聞。")