Spaces:
Sleeping
Sleeping
Update sentiment_analyzer.py
Browse files- sentiment_analyzer.py +36 -14
sentiment_analyzer.py
CHANGED
|
@@ -5,11 +5,12 @@ import re
|
|
| 5 |
from typing import Dict, Tuple, Optional
|
| 6 |
import jieba
|
| 7 |
import emoji
|
|
|
|
| 8 |
|
| 9 |
logger = logging.getLogger(__name__)
|
| 10 |
|
| 11 |
class SentimentAnalyzer:
|
| 12 |
-
"""中文新聞情緒分析器"""
|
| 13 |
|
| 14 |
def __init__(self, model_name: str = "uer/roberta-base-finetuned-jd-binary-chinese"):
|
| 15 |
self.model_name = model_name
|
|
@@ -18,6 +19,8 @@ class SentimentAnalyzer:
|
|
| 18 |
self.classifier = None
|
| 19 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 20 |
|
|
|
|
|
|
|
| 21 |
# 初始化模型
|
| 22 |
self._load_model()
|
| 23 |
|
|
@@ -35,26 +38,47 @@ class SentimentAnalyzer:
|
|
| 35 |
}
|
| 36 |
|
| 37 |
def _load_model(self):
|
| 38 |
-
"""載入預訓練模型"""
|
| 39 |
try:
|
| 40 |
-
logger.info(f"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 46 |
self.classifier = pipeline(
|
| 47 |
"text-classification",
|
| 48 |
model=self.model,
|
| 49 |
tokenizer=self.tokenizer,
|
| 50 |
device=0 if self.device == "cuda" else -1,
|
| 51 |
-
return_all_scores=True
|
| 52 |
)
|
| 53 |
|
| 54 |
-
logger.info("情緒分析模型載入成功")
|
| 55 |
|
| 56 |
except Exception as e:
|
| 57 |
-
logger.error(f"載入模型時發生錯誤: {e}")
|
|
|
|
| 58 |
self.classifier = None
|
| 59 |
|
| 60 |
def _preprocess_text(self, text: str) -> str:
|
|
@@ -120,10 +144,8 @@ class SentimentAnalyzer:
|
|
| 120 |
|
| 121 |
# 處理模型結果
|
| 122 |
if results and len(results) > 0:
|
| 123 |
-
scores = results[0]
|
| 124 |
-
|
| 125 |
# 找到最高分數的標籤
|
| 126 |
-
best_result = max(
|
| 127 |
|
| 128 |
# 標籤映射
|
| 129 |
label_mapping = {
|
|
@@ -186,7 +208,7 @@ class SentimentAnalyzer:
|
|
| 186 |
results.append(result)
|
| 187 |
|
| 188 |
# 避免GPU記憶體問題
|
| 189 |
-
if i % 10 == 0:
|
| 190 |
-
torch.cuda.empty_cache()
|
| 191 |
|
| 192 |
return results
|
|
|
|
| 5 |
from typing import Dict, Tuple, Optional
|
| 6 |
import jieba
|
| 7 |
import emoji
|
| 8 |
+
import os
|
| 9 |
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
class SentimentAnalyzer:
|
| 13 |
+
"""中文新聞情緒分析器 - 改進版"""
|
| 14 |
|
| 15 |
def __init__(self, model_name: str = "uer/roberta-base-finetuned-jd-binary-chinese"):
|
| 16 |
self.model_name = model_name
|
|
|
|
| 19 |
self.classifier = None
|
| 20 |
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 21 |
|
| 22 |
+
logger.info(f"Device set to use {self.device}")
|
| 23 |
+
|
| 24 |
# 初始化模型
|
| 25 |
self._load_model()
|
| 26 |
|
|
|
|
| 38 |
}
|
| 39 |
|
| 40 |
def _load_model(self):
|
| 41 |
+
"""載入預訓練模型 - 改進版"""
|
| 42 |
try:
|
| 43 |
+
logger.info(f"開始載入情緒分析模型: {self.model_name}")
|
| 44 |
+
|
| 45 |
+
# 檢查模型是否已快取
|
| 46 |
+
cache_dir = os.path.expanduser("~/.cache/huggingface/transformers")
|
| 47 |
+
logger.info(f"模型快取目錄: {cache_dir}")
|
| 48 |
|
| 49 |
+
# 載入 tokenizer
|
| 50 |
+
logger.info("載入 tokenizer...")
|
| 51 |
+
self.tokenizer = AutoTokenizer.from_pretrained(
|
| 52 |
+
self.model_name,
|
| 53 |
+
trust_remote_code=True
|
| 54 |
+
)
|
| 55 |
|
| 56 |
+
# 載入模型
|
| 57 |
+
logger.info("載入模型...")
|
| 58 |
+
self.model = AutoModelForSequenceClassification.from_pretrained(
|
| 59 |
+
self.model_name,
|
| 60 |
+
trust_remote_code=True
|
| 61 |
+
)
|
| 62 |
+
|
| 63 |
+
# 移動到適當的設備
|
| 64 |
+
if self.device == "cuda":
|
| 65 |
+
self.model = self.model.cuda()
|
| 66 |
+
|
| 67 |
+
# 創建分類器管道 - 修正過時的參數
|
| 68 |
+
logger.info("創建分類器管道...")
|
| 69 |
self.classifier = pipeline(
|
| 70 |
"text-classification",
|
| 71 |
model=self.model,
|
| 72 |
tokenizer=self.tokenizer,
|
| 73 |
device=0 if self.device == "cuda" else -1,
|
| 74 |
+
top_k=None # 替代 return_all_scores=True
|
| 75 |
)
|
| 76 |
|
| 77 |
+
logger.info("✅ 情緒分析模型載入成功")
|
| 78 |
|
| 79 |
except Exception as e:
|
| 80 |
+
logger.error(f"❌ 載入模型時發生錯誤: {e}")
|
| 81 |
+
logger.info("將使用關鍵字分析作為備用方案")
|
| 82 |
self.classifier = None
|
| 83 |
|
| 84 |
def _preprocess_text(self, text: str) -> str:
|
|
|
|
| 144 |
|
| 145 |
# 處理模型結果
|
| 146 |
if results and len(results) > 0:
|
|
|
|
|
|
|
| 147 |
# 找到最高分數的標籤
|
| 148 |
+
best_result = max(results, key=lambda x: x['score'])
|
| 149 |
|
| 150 |
# 標籤映射
|
| 151 |
label_mapping = {
|
|
|
|
| 208 |
results.append(result)
|
| 209 |
|
| 210 |
# 避免GPU記憶體問題
|
| 211 |
+
if i % 10 == 0 and torch.cuda.is_available():
|
| 212 |
+
torch.cuda.empty_cache()
|
| 213 |
|
| 214 |
return results
|