tkkbbo332 commited on
Commit
27216bd
·
1 Parent(s): f767c71

Update Bert_predict.py

Browse files
Files changed (1) hide show
  1. Bert_predict.py +7 -2
Bert_predict.py CHANGED
@@ -9,11 +9,15 @@ import requests
9
  import time
10
  from typing import List, Optional, Dict, Any
11
 
 
 
 
 
12
  class BertPredictor:
13
  """
14
  用於加載 BERT 模型、獲取新聞並對其進行股市影響預測的類別。
15
  """
16
- def __init__(self, tokenizer_name: str = 'hfl/rbt3', max_news_per_keyword: int = 5):
17
  """
18
  初始化預測器,載入分詞器、預訓練模型並獲取新聞。
19
 
@@ -43,7 +47,8 @@ class BertPredictor:
43
 
44
  # --- 模型相關設置 ---
45
  self.text_max_length = 256
46
- self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
 
47
 
48
  # 載入最佳模型
49
  print("正在加載模型...")
 
9
  import time
10
  from typing import List, Optional, Dict, Any
11
 
12
+ from transformers import AutoTokenizer
13
+ ### huawei-noah/TinyBERT_General_4L_312D huawei-noah/TinyBERT_General_6L_768D hfl/rbt3 bert-base-chinese
14
+ Bert_model_name = "hfl/rbt3"
15
+
16
  class BertPredictor:
17
  """
18
  用於加載 BERT 模型、獲取新聞並對其進行股市影響預測的類別。
19
  """
20
+ def __init__(self, tokenizer_name: str = Bert_model_name, max_news_per_keyword: int = 5):
21
  """
22
  初始化預測器,載入分詞器、預訓練模型並獲取新聞。
23
 
 
47
 
48
  # --- 模型相關設置 ---
49
  self.text_max_length = 256
50
+ self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
51
+ #self.tokenizer = BertTokenizer.from_pretrained(tokenizer_name)
52
 
53
  # 載入最佳模型
54
  print("正在加載模型...")