# pyCrawing / sentiment_analyzer.py
# khjhs60199 — "Update sentiment_analyzer.py" (commit d740988, verified)
import logging
import os
import re
from typing import Any, Dict, Optional, Tuple

import emoji
import jieba
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification, pipeline
# Module-level logger; SentimentAnalyzer reports model loading and analysis errors through it.
logger = logging.getLogger(__name__)
class SentimentAnalyzer:
    """Chinese news sentiment analyzer (revised version).

    Classifies text with a Hugging Face text-classification pipeline and
    falls back to keyword counting when the model is unavailable or fails.
    Every analysis returns a dict with the keys ``sentiment``
    (``"positive"``, ``"negative"`` or ``"neutral"``), ``confidence``
    (float) and ``method`` (``"model"``, ``"hybrid"``, ``"keyword"``,
    ``"default"`` or ``"error"``).
    """

    # Character cap applied before classification, kept because of the
    # BERT model's input-length limit.
    MAX_TEXT_LENGTH = 500

    # Model scores below this are compared against the keyword score.
    LOW_CONFIDENCE_THRESHOLD = 0.7

    # Default finance-related sentiment vocabularies. Each instance gets
    # its own mutable copy in __init__.
    POSITIVE_KEYWORDS = frozenset({
        '上漲', '漲', '漲幅', '上升', '增長', '成長', '利好', '利多', '買進', '看好',
        '樂觀', '獲利', '盈利', '突破', '新高', '強勢', '回升', '反彈', '看漲',
        '推薦', '買入', '增持', '超買', '牛市', '多頭', '正面', '積極', '飆漲',
        '大漲', '創新高', '成功', '贏家', '提升', '改善',
    })
    NEGATIVE_KEYWORDS = frozenset({
        '下跌', '跌', '跌幅', '下滑', '下降', '減少', '衰退', '利空', '賣出', '看壞',
        '悲觀', '虧損', '損失', '破底', '新低', '弱勢', '下探', '重挫', '看跌',
        '減持', '超賣', '熊市', '空頭', '負面', '消極', '警告', '暴跌',
        '大跌', '崩盤', '危機', '風險', '下修',
    })

    def __init__(self, model_name: str = "uer/roberta-base-finetuned-jd-binary-chinese"):
        """Create an analyzer and try to load *model_name*.

        Args:
            model_name: Hugging Face model id passed to ``from_pretrained``.
        """
        self.model_name = model_name
        self.tokenizer = None
        self.model = None
        self.classifier = None
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Device set to use {self.device}")

        # Per-instance copies so callers can extend one analyzer's
        # vocabulary without affecting others.
        self.positive_keywords = set(self.POSITIVE_KEYWORDS)
        self.negative_keywords = set(self.NEGATIVE_KEYWORDS)

        # On failure this leaves self.classifier as None, which selects the
        # keyword fallback in analyze_sentiment.
        self._load_model()

    def _load_model(self):
        """Load the tokenizer, model and text-classification pipeline.

        Any exception is logged and swallowed; ``self.classifier`` is then
        set to ``None`` so analysis falls back to keyword counting.
        """
        try:
            logger.info(f"開始載入情緒分析模型: {self.model_name}")

            # NOTE(review): trust_remote_code=True executes Python code that
            # ships with the model repository. It is presumably unnecessary
            # for the default RoBERTa checkpoint — confirm, and disable it
            # unless model_name always comes from a trusted source.
            logger.info("載入 tokenizer...")
            self.tokenizer = AutoTokenizer.from_pretrained(
                self.model_name,
                trust_remote_code=True,
            )

            logger.info("載入模型...")
            self.model = AutoModelForSequenceClassification.from_pretrained(
                self.model_name,
                trust_remote_code=True,
            )
            if self.device == "cuda":
                self.model = self.model.cuda()

            # The pipeline already returns only the top-scoring label by
            # default, so the deprecated return_all_scores=False is omitted.
            logger.info("創建分類器管道...")
            self.classifier = pipeline(
                "text-classification",
                model=self.model,
                tokenizer=self.tokenizer,
                device=0 if self.device == "cuda" else -1,
            )
            logger.info("✅ 情緒分析模型載入成功")
        except Exception as e:
            logger.error(f"❌ 載入模型時發生錯誤: {e}")
            logger.info("將使用關鍵字分析作為備用方案")
            self.classifier = None

    def _preprocess_text(self, text: str) -> str:
        """Normalize *text* before analysis.

        Removes emoji and characters outside the allowed set, collapses
        whitespace and truncates to ``MAX_TEXT_LENGTH`` characters. Returns
        ``""`` for empty input. If cleaning fails, the error is logged and
        the uncleaned text (as ``str``, truncated) is returned instead.
        """
        if not text:
            return ""
        try:
            # Drop emoji outright. demojize(language='zh') produced
            # Chinese-named shortcodes that the ASCII-only ':[a-zA-Z_]+:'
            # pattern never removed, leaking emoji names into the text.
            cleaned = emoji.replace_emoji(text, replace="")
            # Keep CJK ideographs, word characters, whitespace and basic
            # half-/full-width punctuation; drop everything else.
            cleaned = re.sub(r'[^\u4e00-\u9fff\u3400-\u4dbf\w\s.,!?()(),。!?]', '', cleaned)
            cleaned = re.sub(r'\s+', ' ', cleaned).strip()
        except Exception as e:
            logger.error(f"文本預處理錯誤: {e}")
            cleaned = str(text)
        # Truncate on every path so the model never receives oversized input.
        return cleaned[:self.MAX_TEXT_LENGTH]

    def _keyword_sentiment(self, text: str) -> Tuple[str, float]:
        """Score *text* by which sentiment keywords it contains.

        Each keyword is a substring test counted at most once, so
        overlapping keywords (e.g. '漲' and '大漲') both count.

        Returns:
            ``(label, score)`` where score runs from 0.0 (only negative
            keywords) through 0.5 (no or balanced keywords) to 1.0 (only
            positive keywords).

        NOTE(review): this score is a polarity scale, not a certainty —
        a strongly negative text scores near 0.0, whereas the model's
        score is the probability of its predicted label.
        """
        if not text:
            return "neutral", 0.5

        positive_hits = sum(keyword in text for keyword in self.positive_keywords)
        negative_hits = sum(keyword in text for keyword in self.negative_keywords)
        total_hits = positive_hits + negative_hits
        if total_hits == 0:
            return "neutral", 0.5

        positive_ratio = positive_hits / total_hits
        if positive_ratio > 0.6:
            return "positive", min(1.0, 0.7 + (positive_ratio - 0.6) * 0.75)
        if positive_ratio < 0.4:
            # Clamp: at ratio 0 the float arithmetic lands at about -5.6e-17.
            return "negative", max(0.0, 0.3 - (0.4 - positive_ratio) * 0.75)
        return "neutral", 0.5

    @staticmethod
    def _extract_top_prediction(result) -> Optional[dict]:
        """Return the first prediction dict from a pipeline result.

        Depending on the transformers version and call options, a single
        string input may yield a dict, a list of dicts or a nested list;
        unwrap leading list/tuple levels until a dict is found. Returns
        ``None`` when the result is empty or holds no dict.
        """
        while isinstance(result, (list, tuple)):
            if not result:
                return None
            result = result[0]
        return result if isinstance(result, dict) else None

    @staticmethod
    def _map_label(label) -> str:
        """Map a model label to ``"positive"``, ``"negative"`` or ``"neutral"``.

        Accepts ``LABEL_0``/``LABEL_1`` as well as descriptive labels that
        start with "negative"/"positive" (case-insensitive). Some UER
        checkpoints appear to use labels such as "positive (stars 4 and 5)" —
        verify against ``model.config.id2label``. Anything else is neutral.
        """
        normalized = str(label).strip().lower()
        if normalized in ("label_0", "neg") or normalized.startswith("negative"):
            return "negative"
        if normalized in ("label_1", "pos") or normalized.startswith("positive"):
            return "positive"
        return "neutral"

    def _combine_with_keywords(self, prediction: dict, processed_text: str) -> Dict[str, Any]:
        """Build the result dict for a model *prediction*.

        When the model's score is below ``LOW_CONFIDENCE_THRESHOLD`` and the
        keyword score lies further from 0.5 than the model score, the
        keyword label wins and the two scores are averaged (``"hybrid"``);
        otherwise the model result is returned as-is (``"model"``).
        """
        sentiment = self._map_label(prediction.get('label', ''))
        confidence = prediction.get('score', 0.5)
        method = "model"

        if confidence < self.LOW_CONFIDENCE_THRESHOLD:
            keyword_sentiment, keyword_score = self._keyword_sentiment(processed_text)
            if abs(confidence - 0.5) < abs(keyword_score - 0.5):
                sentiment = keyword_sentiment
                # NOTE(review): averages a label probability with a keyword
                # polarity score; the two scales are not directly comparable.
                confidence = (confidence + keyword_score) / 2
                method = "hybrid"

        return {
            "sentiment": sentiment,
            "confidence": confidence,
            "method": method,
        }

    def analyze_sentiment(self, text: str, title: str = "") -> Dict[str, Any]:
        """Analyze the sentiment of *text*, optionally prefixed by *title*.

        Args:
            text: Body text to analyze.
            title: Optional headline; when non-empty it is prepended to
                *text* with a separating space.

        Returns:
            Dict with ``sentiment``, ``confidence`` and ``method`` keys.
            ``method`` is ``"default"`` when nothing remains after
            preprocessing, ``"model"``/``"hybrid"`` when the classifier
            produced a usable prediction, ``"keyword"`` when the keyword
            fallback was used, and ``"error"`` on any unexpected failure.
        """
        try:
            full_text = f"{title} {text}" if title else text
            processed_text = self._preprocess_text(full_text)
            if not processed_text:
                return {
                    "sentiment": "neutral",
                    "confidence": 0.5,
                    "method": "default",
                }

            if self.classifier is not None:
                try:
                    prediction = self._extract_top_prediction(self.classifier(processed_text))
                    if prediction is not None:
                        return self._combine_with_keywords(prediction, processed_text)
                except Exception as e:
                    logger.error(f"模型分析錯誤: {e}")
                    logger.debug(f"錯誤詳情: {str(e)}")

            # Fallback: no classifier, empty model output, or model failure.
            sentiment, confidence = self._keyword_sentiment(processed_text)
            return {
                "sentiment": sentiment,
                "confidence": confidence,
                "method": "keyword",
            }
        except Exception as e:
            logger.error(f"情緒分析錯誤: {e}")
            return {
                "sentiment": "neutral",
                "confidence": 0.5,
                "method": "error",
            }

    def batch_analyze(self, texts: list, titles: Optional[list] = None) -> list:
        """Analyze each text in *texts*, pairing it with the title at the same index.

        Missing titles (when *titles* is ``None`` or shorter than *texts*)
        are treated as ``""``. Returns one result dict per text, in order.
        """
        titles = titles or []
        results = []
        for i, text in enumerate(texts):
            title = titles[i] if i < len(titles) else ""
            results.append(self.analyze_sentiment(text, title))
            # Periodically release cached GPU memory during long batches.
            if i % 10 == 0 and torch.cuda.is_available():
                torch.cuda.empty_cache()
        return results