khjhs60199 commited on
Commit
62fddb6
·
verified ·
1 Parent(s): 99f90dc

Update sentiment_analyzer.py

Browse files
Files changed (1) hide show
  1. sentiment_analyzer.py +36 -14
sentiment_analyzer.py CHANGED
@@ -5,11 +5,12 @@ import re
5
  from typing import Dict, Tuple, Optional
6
  import jieba
7
  import emoji
 
8
 
9
  logger = logging.getLogger(__name__)
10
 
11
  class SentimentAnalyzer:
12
- """中文新聞情緒分析器"""
13
 
14
  def __init__(self, model_name: str = "uer/roberta-base-finetuned-jd-binary-chinese"):
15
  self.model_name = model_name
@@ -18,6 +19,8 @@ class SentimentAnalyzer:
18
  self.classifier = None
19
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
20
 
 
 
21
  # 初始化模型
22
  self._load_model()
23
 
@@ -35,26 +38,47 @@ class SentimentAnalyzer:
35
  }
36
 
37
  def _load_model(self):
38
- """載入預訓練模型"""
39
  try:
40
- logger.info(f"載入情緒分析模型: {self.model_name}")
 
 
 
 
41
 
42
- self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
43
- self.model = AutoModelForSequenceClassification.from_pretrained(self.model_name)
 
 
 
 
44
 
45
- # 創建分類器管道
 
 
 
 
 
 
 
 
 
 
 
 
46
  self.classifier = pipeline(
47
  "text-classification",
48
  model=self.model,
49
  tokenizer=self.tokenizer,
50
  device=0 if self.device == "cuda" else -1,
51
- return_all_scores=True
52
  )
53
 
54
- logger.info("情緒分析模型載入成功")
55
 
56
  except Exception as e:
57
- logger.error(f"載入模型時發生錯誤: {e}")
 
58
  self.classifier = None
59
 
60
  def _preprocess_text(self, text: str) -> str:
@@ -120,10 +144,8 @@ class SentimentAnalyzer:
120
 
121
  # 處理模型結果
122
  if results and len(results) > 0:
123
- scores = results[0]
124
-
125
  # 找到最高分數的標籤
126
- best_result = max(scores, key=lambda x: x['score'])
127
 
128
  # 標籤映射
129
  label_mapping = {
@@ -186,7 +208,7 @@ class SentimentAnalyzer:
186
  results.append(result)
187
 
188
  # 避免GPU記憶體問題
189
- if i % 10 == 0:
190
- torch.cuda.empty_cache() if torch.cuda.is_available() else None
191
 
192
  return results
 
5
  from typing import Dict, Tuple, Optional
6
  import jieba
7
  import emoji
8
+ import os
9
 
10
  logger = logging.getLogger(__name__)
11
 
12
  class SentimentAnalyzer:
13
+ """中文新聞情緒分析器 - 改進版"""
14
 
15
  def __init__(self, model_name: str = "uer/roberta-base-finetuned-jd-binary-chinese"):
16
  self.model_name = model_name
 
19
  self.classifier = None
20
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
21
 
22
+ logger.info(f"Device set to use {self.device}")
23
+
24
  # 初始化模型
25
  self._load_model()
26
 
 
38
  }
39
 
40
  def _load_model(self):
41
+ """載入預訓練模型 - 改進版"""
42
  try:
43
+ logger.info(f"開始載入情緒分析模型: {self.model_name}")
44
+
45
+ # 檢查模型是否已快取
46
+ cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
47
+ logger.info(f"模型快取目錄: {cache_dir}")
48
 
49
+ # 載入 tokenizer
50
+ logger.info("載入 tokenizer...")
51
+ self.tokenizer = AutoTokenizer.from_pretrained(
52
+ self.model_name,
53
+ trust_remote_code=True
54
+ )
55
 
56
+ # 載入模型
57
+ logger.info("載入模型...")
58
+ self.model = AutoModelForSequenceClassification.from_pretrained(
59
+ self.model_name,
60
+ trust_remote_code=True
61
+ )
62
+
63
+ # 移動到適當的設備
64
+ if self.device == "cuda":
65
+ self.model = self.model.cuda()
66
+
67
+ # 創建分類器管道 - 修正過時的參數
68
+ logger.info("創建分類器管道...")
69
  self.classifier = pipeline(
70
  "text-classification",
71
  model=self.model,
72
  tokenizer=self.tokenizer,
73
  device=0 if self.device == "cuda" else -1,
74
+ top_k=None # 替代 return_all_scores=True
75
  )
76
 
77
+ logger.info("情緒分析模型載入成功")
78
 
79
  except Exception as e:
80
+ logger.error(f"載入模型時發生錯誤: {e}")
81
+ logger.info("將使用關鍵字分析作為備用方案")
82
  self.classifier = None
83
 
84
  def _preprocess_text(self, text: str) -> str:
 
144
 
145
  # 處理模型結果
146
  if results and len(results) > 0:
 
 
147
  # 找到最高分數的標籤
148
+ best_result = max(results, key=lambda x: x['score'])
149
 
150
  # 標籤映射
151
  label_mapping = {
 
208
  results.append(result)
209
 
210
  # 避免GPU記憶體問題
211
+ if i % 10 == 0 and torch.cuda.is_available():
212
+ torch.cuda.empty_cache()
213
 
214
  return results