Update Bert_predict.py
Browse files
- Bert_predict.py +7 -2
Bert_predict.py
CHANGED
|
@@ -9,11 +9,15 @@ import requests
|
|
| 9 |
import time
|
| 10 |
from typing import List, Optional, Dict, Any
|
| 11 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
class BertPredictor:
|
| 13 |
"""
|
| 14 |
用於加載 BERT 模型、獲取新聞並對其進行股市影響預測的類別。
|
| 15 |
"""
|
| 16 |
-
def __init__(self, tokenizer_name: str =
|
| 17 |
"""
|
| 18 |
初始化預測器,載入分詞器、預訓練模型並獲取新聞。
|
| 19 |
|
|
@@ -43,7 +47,8 @@ class BertPredictor:
|
|
| 43 |
|
| 44 |
# --- 模型相關設置 ---
|
| 45 |
self.text_max_length = 256
|
| 46 |
-
self.tokenizer =
|
|
|
|
| 47 |
|
| 48 |
# 載入最佳模型
|
| 49 |
print("正在加載模型...")
|
|
|
|
| 9 |
import time
|
| 10 |
from typing import List, Optional, Dict, Any
|
| 11 |
|
| 12 |
+
from transformers import AutoTokenizer
|
| 13 |
+
### Model name options for Bert_model_name: huawei-noah/TinyBERT_General_4L_312D, huawei-noah/TinyBERT_General_6L_768D, hfl/rbt3, bert-base-chinese
|
| 14 |
+
Bert_model_name = "hfl/rbt3"
|
| 15 |
+
|
| 16 |
class BertPredictor:
|
| 17 |
"""
|
| 18 |
用於加載 BERT 模型、獲取新聞並對其進行股市影響預測的類別。
|
| 19 |
"""
|
| 20 |
+
def __init__(self, tokenizer_name: str = Bert_model_name, max_news_per_keyword: int = 5):
|
| 21 |
"""
|
| 22 |
初始化預測器,載入分詞器、預訓練模型並獲取新聞。
|
| 23 |
|
|
|
|
| 47 |
|
| 48 |
# --- 模型相關設置 ---
|
| 49 |
self.text_max_length = 256
|
| 50 |
+
self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
|
| 51 |
+
#self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
|
| 52 |
|
| 53 |
# 載入最佳模型
|
| 54 |
print("正在加載模型...")
|