# Billy Lin
# text-emotion-classification
# 97a5393
import os
import sys
import yaml
import torch
import numpy as np
from transformers import AutoTokenizer, AutoModelForSequenceClassification
def load_label_map(yaml_path: str):
    """Read an ``id -> label name`` mapping from a YAML file.

    The top-level YAML value may be a mapping (``0: 伤心``) or a list.
    Two list-item spellings are supported as mappings: ``- 0: 伤心`` and
    ``- {0: 伤心}``.  A list item that loads as a plain string containing
    ``:`` is split at the first colon into id and name.  List items of any
    other shape are ignored.

    Args:
        yaml_path: Path of the YAML file; it is opened as UTF-8 text.

    Returns:
        A dict with ``int`` label ids as keys and ``str`` label names as
        values.

    Raises:
        ValueError: If the top-level value is neither a list nor a dict, or
            if no entry could be extracted.  Errors raised by ``int()`` on
            a key that is not integer-like propagate unchanged.
    """
    with open(yaml_path, "r", encoding="utf-8") as stream:
        document = yaml.safe_load(stream)

    # Treat a bare mapping as a one-element list so that both layouts share
    # the extraction loop below.
    if isinstance(document, dict):
        entries = [document]
    elif isinstance(document, list):
        entries = document
    else:
        raise ValueError(f"无法解析标签映射:{yaml_path}")

    id_to_name = {}
    for entry in entries:
        if isinstance(entry, dict):
            for raw_id, raw_name in entry.items():
                id_to_name[int(raw_id)] = str(raw_name)
            continue
        if isinstance(entry, str) and ":" in entry:
            # Only the first colon separates id from name; the name itself
            # may contain further colons.
            raw_id, _, raw_name = entry.partition(":")
            id_to_name[int(raw_id.strip())] = raw_name.strip()

    if id_to_name:
        return id_to_name
    raise ValueError(f"标签映射为空:{yaml_path}")
def predict(text: str, tokenizer, model, device: torch.device, *, max_length: int = 128):
    """Classify a single text and return the winning class with its score.

    Args:
        text: The input text to classify.
        tokenizer: Callable invoked as ``tokenizer(text, return_tensors="pt",
            truncation=True, padding=True, max_length=...)``; it must return
            a mapping of tensors.
        model: Callable accepting those tensors as keyword arguments and
            returning an object with a ``logits`` attribute; it must also
            provide ``eval()``.
        device: Device the input tensors are moved to before the forward
            pass.  The model is not moved here.
        max_length: Token limit passed to the tokenizer for truncation.
            Defaults to 128, the value that was previously hard-coded.

    Returns:
        A tuple ``(pred_id, confidence, probs)``: the index of the highest
        probability as ``int``, that probability as ``float``, and the
        NumPy array of softmax probabilities for the first (only) row.
    """
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=max_length,
    )
    encoded = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Switch to evaluation mode on every call so the result does not depend
    # on what the caller did with the model beforehand.
    model.eval()
    with torch.no_grad():
        logits = model(**encoded).logits
        # Softmax over the last axis; row 0 is the single input text.
        probs = torch.softmax(logits, dim=-1).cpu().numpy()[0]

    pred_id = int(np.argmax(probs))
    confidence = float(probs[pred_id])
    return pred_id, confidence, probs
def main():
    """Run an interactive emotion-classification prompt.

    Looks for the model directory ``sentiment_roberta`` and the label file
    ``text-emotion.yaml`` next to this script and exits with status 1 if
    either is missing.  Otherwise it loads the label map, tokenizer and
    model, then classifies each line read from stdin until an empty line,
    EOF or Ctrl-C at the prompt.
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(script_dir, "sentiment_roberta")
    yaml_path = os.path.join(script_dir, "text-emotion.yaml")

    # Fail early with a readable message instead of a loader traceback.
    if not os.path.isdir(model_dir):
        print(f"找不到模型目录:{model_dir}")
        print("请先训练并确保训练脚本 output_dir=./sentiment_roberta(相对 data_preload 目录)。")
        sys.exit(1)
    if not os.path.isfile(yaml_path):
        print(f"找不到标签映射文件:{yaml_path}")
        sys.exit(1)

    label_map = load_label_map(yaml_path)

    use_gpu = torch.cuda.is_available()
    device = torch.device("cuda" if use_gpu else "cpu")
    print(f"推理设备:{device}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device)

    print("请输入一段文本(直接回车退出):")
    while True:
        try:
            line = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Leading newline moves past the interrupted prompt line.
            print("\n退出")
            return
        if line == "":
            print("退出")
            return

        label_id, score, _probs = predict(line, tokenizer, model, device)
        # Ids missing from the label map are still reported, not dropped.
        fallback = f"未知标签({label_id})"
        label_name = label_map.get(label_id, fallback)
        print(f"情绪预测:{label_name}")
        print(f"置信度:{score:.4f}")
# Start the interactive prompt only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()