中文对话情感分类模型(8分类,上下文感知)
基于 chinese-roberta-wwm-ext-large 微调的中文对话场景情感分类模型,支持 8 种情感标签,输入为带上下文的对话文本。
本仓库提供 ONNX FP16 格式,可直接通过 ONNX Runtime 推理,无需依赖 PyTorch。
模型亮点
- 上下文感知:输入包含
[上文]对话历史和[当前]目标消息,结合语境判断情感 - 8 类情感:开心、难过、生气、惊讶、害怕、厌恶、中性、关心
- 高精度:验证集准确率 98.86%,宏平均 F1 值 0.9886
- 日常聊天风格:训练数据为互联网日常聊天语料,包含网络用语、颜文字、缩写等
- 轻量部署:ONNX FP16 格式,模型体积 623.7 MB,无需 PyTorch 即可推理
评估结果
| 标签 | 精确率 | 召回率 | F1 值 | 样本数 |
|---|---|---|---|---|
| 开心 | 1.0000 | 0.9911 | 0.9955 | 224 |
| 难过 | 0.9742 | 0.9956 | 0.9848 | 228 |
| 生气 | 0.9790 | 0.9708 | 0.9749 | 240 |
| 惊讶 | 0.9926 | 1.0000 | 0.9963 | 270 |
| 害怕 | 0.9955 | 0.9779 | 0.9866 | 226 |
| 厌恶 | 0.9784 | 0.9784 | 0.9784 | 231 |
| 中性 | 0.9901 | 0.9950 | 0.9926 | 202 |
| 关心 | 1.0000 | 1.0000 | 1.0000 | 228 |
| 加权平均 | 0.9887 | 0.9886 | 0.9886 | 1849 |
输入格式
模型输入为拼接的对话文本,格式如下:
[上文]
A:今天开会的时候看你一直揉太阳穴,是不是昨晚又熬夜肝新番了
B:啊被发现了,其实不只是追番啦,最近项目deadline压得有点喘不过气
[当前]
B:话说回来你脸色也不太好的说,是不是胃又不舒服了,要不要去茶水间泡点热茶休息一下,身体才是革命的本钱啊
模型预测 [当前] 部分发言者的情感。
使用方法
安装依赖
pip install onnxruntime transformers numpy
推理示例
import onnxruntime as ort
from transformers import AutoTokenizer
import numpy as np
# 标签列表
labels = ["开心", "难过", "生气", "惊讶", "害怕", "厌恶", "中性", "关心"]
# 加载分词器和模型
tokenizer = AutoTokenizer.from_pretrained("foxyanuo/chinese-chat-sentiment-8class")
session = ort.InferenceSession("model.onnx")
# 构造输入(必须包含上下文)
text = """[上文]
A:你听说了吗,小王要走了
B:啊?真的假的
[当前]
A:嗯,听到这个消息我都快哭了..."""
inputs = tokenizer(text, return_tensors="np", max_length=256, truncation=True, padding="max_length")
logits = session.run(None, {k: v for k, v in inputs.items()})[0]
# 获取预测结果
probs = np.exp(logits) / np.exp(logits).sum(axis=-1, keepdims=True)
pred = int(np.argmax(probs, axis=-1)[0])
confidence = float(probs[0][pred])
print(f"情感:{labels[pred]},置信度:{confidence:.2%}")
# 情感:难过,置信度:99.91%
标签映射
| 编号 | 标签 |
|---|---|
| 0 | 开心 |
| 1 | 难过 |
| 2 | 生气 |
| 3 | 惊讶 |
| 4 | 害怕 |
| 5 | 厌恶 |
| 6 | 中性 |
| 7 | 关心 |
训练细节
| 参数 | 值 |
|---|---|
| 训练集大小 | 35k |
| 验证集大小 | 1.8k |
| 最大长度 | 256 |
| 批大小 | 32 |
| 学习率 | 2e-5 |
| 训练轮数 | 5(早停策略,耐心值=2) |
| 预热比例 | 0.1 |
| 权重衰减 | 0.01 |
| 优化器 | AdamW |
| 混合精度训练 | 是 |
模型信息
| 项目 | 值 |
|---|---|
| 格式 | ONNX |
| 数值精度 | FP16 |
| 模型大小 | 623.7 MB |
| 输入 | input_ids、attention_mask、token_type_ids |
| 输出 | logits(8维) |
| 动态轴 | 支持(批大小、序列长度) |
许可证
CC-BY-NC-4.0
- Downloads last month
- 19
Model tree for foxyanuo/chinese-chat-sentiment-8class
Base model
hfl/chinese-roberta-wwm-ext-largeEvaluation results
- Accuracyself-reported0.989
- Macro F1self-reported0.989