"""Interactive command-line emotion prediction with a sequence classifier.

Loads a tokenizer and classification model from the ``sentiment_roberta``
directory next to this file, plus an id-to-name label mapping from
``text-emotion.yaml``, then classifies lines typed on standard input.
"""

import os
import re
import sys
from typing import Dict, Tuple

import numpy as np
import torch
import yaml
from transformers import AutoModelForSequenceClassification, AutoTokenizer

# Separator between id and name in string-form list entries: an ASCII colon
# or a full-width colon.
_ENTRY_SEPARATOR = re.compile("[::]")


def _label_id(raw_key: object, yaml_path: str) -> int:
    """Convert a raw mapping key to an integer label id.

    Args:
        raw_key: Key as produced by the YAML parser or by splitting a string
            entry.
        yaml_path: Path of the file being parsed; used only in the error
            message.

    Raises:
        ValueError: If ``raw_key`` cannot be converted with ``int()``.
    """
    try:
        return int(raw_key)
    except (TypeError, ValueError) as err:
        raise ValueError(
            f"标签编号不是整数:{raw_key!r}(文件:{yaml_path})"
        ) from err


def load_label_map(yaml_path: str) -> Dict[int, str]:
    """Read the label-id to label-name mapping from a YAML file.

    Accepted document shapes:

    * a mapping, ``{0: name, 1: name, ...}``;
    * a list whose items are single mappings (``- 0: name`` or
      ``- {0: name}``) or strings of the form ``"0:name"``; the string form
      is split at the first ASCII or full-width colon.

    List items of any other shape, and strings without a colon, are ignored.

    Args:
        yaml_path: Path to the YAML file, read as UTF-8.

    Returns:
        Dict mapping integer label ids to label names.

    Raises:
        ValueError: If the document is neither a list nor a mapping, if a
            key is not an integer, or if no labels were found.
    """
    with open(yaml_path, "r", encoding="utf-8") as f:
        data = yaml.safe_load(f)

    label_map: Dict[int, str] = {}
    if isinstance(data, dict):
        for key, value in data.items():
            label_map[_label_id(key, yaml_path)] = str(value)
    elif isinstance(data, list):
        for item in data:
            if isinstance(item, dict):
                for key, value in item.items():
                    label_map[_label_id(key, yaml_path)] = str(value)
            elif isinstance(item, str):
                parts = _ENTRY_SEPARATOR.split(item, maxsplit=1)
                if len(parts) == 2:
                    label_map[_label_id(parts[0], yaml_path)] = parts[1].strip()
    else:
        raise ValueError(f"无法解析标签映射:{yaml_path}")

    if not label_map:
        raise ValueError(f"标签映射为空:{yaml_path}")
    return label_map


def predict(
    text: str, tokenizer, model, device: torch.device
) -> Tuple[int, float, np.ndarray]:
    """Classify one piece of text and return the most probable label.

    The text is tokenized with truncation to at most 128 tokens, the
    resulting tensors are moved to ``device``, and the model is run in eval
    mode with gradient tracking disabled.

    Args:
        text: Input text to classify.
        tokenizer: Callable returning a mapping of tensors when invoked with
            ``return_tensors="pt"``.
        model: Model whose output exposes a ``logits`` attribute; expected
            to already reside on ``device``.
        device: Device the input tensors are moved to.

    Returns:
        ``(pred_id, confidence, probs)``: the index of the highest
        probability, that probability as a float, and the softmax
        probabilities of the first row of the logits as a NumPy array.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=128,
    )
    inputs = {k: v.to(device) for k, v in inputs.items()}

    model.eval()
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits
        # Softmax over the last axis; [0] selects the first row of the result.
        probs = torch.softmax(logits, dim=-1).detach().cpu().numpy()[0]

    pred_id = int(np.argmax(probs))
    confidence = float(probs[pred_id])
    return pred_id, confidence, probs


def main() -> None:
    """Run the interactive prediction loop.

    Exits with status 1 when the model directory or the label file is
    missing, or when the label file cannot be read or parsed. Otherwise
    reads lines from standard input until an empty line, EOF, or Ctrl-C,
    printing the predicted label and its confidence for each line.
    """
    base_dir = os.path.dirname(os.path.abspath(__file__))
    model_dir = os.path.join(base_dir, "sentiment_roberta")
    yaml_path = os.path.join(base_dir, "text-emotion.yaml")

    if not os.path.isdir(model_dir):
        print(f"找不到模型目录:{model_dir}")
        print("请先训练并确保训练脚本 output_dir=./sentiment_roberta(相对 data_preload 目录)。")
        sys.exit(1)
    if not os.path.isfile(yaml_path):
        print(f"找不到标签映射文件:{yaml_path}")
        sys.exit(1)

    # Report an unreadable or malformed label file the same way as the
    # missing-path cases above instead of letting a traceback escape.
    try:
        label_map = load_label_map(yaml_path)
    except (OSError, ValueError, yaml.YAMLError) as err:
        print(f"读取标签映射失败:{err}")
        sys.exit(1)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"推理设备:{device}")

    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSequenceClassification.from_pretrained(model_dir)
    model.to(device)

    print("请输入一段文本(直接回车退出):")
    while True:
        try:
            text = input("> ").strip()
        except (EOFError, KeyboardInterrupt):
            print("\n退出")
            break
        if not text:
            print("退出")
            break

        pred_id, conf, _ = predict(text, tokenizer, model, device)
        # Fall back to a placeholder when the id is missing from the mapping.
        emotion_cn = label_map.get(pred_id, f"未知标签({pred_id})")
        print(f"情绪预测:{emotion_cn}")
        print(f"置信度:{conf:.4f}")


if __name__ == "__main__":
    main()