---
language: zh
tags:
- bert
- multilabel-classification
- chinese
- intent-classification
- time-lbs
base_model:
- google-bert/bert-base-chinese
---
# 中文多标签意图识别模型(BERT)
这是一个基于 `bert-base-chinese` 微调的多标签分类模型,用于对中文 query 进行分类,支持以下任务:

- 多分类:意图识别(chat / simple question / complex question)
- 二分类:是否时间相关、是否位置(LBS)相关
## 模型结构
- 基础模型:[`bert-base-chinese`](https://huggingface.co/google-bert/bert-base-chinese)
- 输出层:一个 5 维的 sigmoid 多标签输出向量,各维含义依次为:
  - `[意图-chat, 意图-simple, 意图-complex, 是否时间相关, 是否LBS相关]`
  - 推理时,前 3 维中概率最大的一维即为预测意图;后 2 维分别以 0.5 为阈值判断"是/否"。
## 使用方法

运行以下示例前,请确保定义了 `BertMultiLabelClassifier` 的 `bert_classifier_3` 模块位于 Python 路径中(例如与脚本放在同一目录)。
```python
import torch
from transformers import BertTokenizer
from bert_classifier_3 import BertMultiLabelClassifier
# Load the tokenizer and the model.
# The classifier is constructed from the public bert-base-chinese backbone
# (presumably used only for architecture/initialisation — confirm in
# bert_classifier_3); every parameter is then replaced by the fine-tuned
# state dict downloaded from model_id (load_state_dict is strict by default).
bert_base = "bert-base-chinese"
model_id = "Xiaoxi2333/bert_multilabel_chinese"
tokenizer = BertTokenizer.from_pretrained(model_id)
model = BertMultiLabelClassifier(pretrained_model_path=bert_base, num_labels=5)
# load_state_dict_from_url caches downloads under <hub_dir>/checkpoints keyed
# only by file name, and skips the download if that file already exists.
# Many repos ship a file literally named "pytorch_model.bin", so give the
# cached copy a repo-specific name to avoid silently loading stale weights
# from a different model.
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
    map_location="cpu",
    file_name=f"{model_id.replace('/', '__')}__pytorch_model.bin",
)
model.load_state_dict(state_dict)
# Inference mode: disables dropout.
model.eval()

# Label names.
# Output order: the 3 intent scores come first, then the 2 yes/no heads.
intent_labels = ["chat", "simple question", "complex question"]
# Index 0 = "no", index 1 = "yes"; indexed by the thresholded yes/no heads.
yesno_labels = ["否", "是"]
# Prediction helper.
def predict(query, *, threshold=0.5, max_length=128):
    """Classify a single Chinese query with the multi-label BERT model.

    Relies on the module-level ``tokenizer``, ``model``, ``intent_labels``
    and ``yesno_labels`` defined in the loading step above.

    Args:
        query: The query text to classify (a single string).
        threshold: Probability cut-off applied to the two yes/no outputs
            (time-related, LBS-related). Defaults to 0.5.
        max_length: Number of tokens the input is truncated / padded to.
            Defaults to 128.

    Returns:
        A dict with the original query, the predicted intent label, the
        yes/no labels for the time and LBS outputs, and the list of raw
        sigmoid probabilities for all outputs.
    """
    enc = tokenizer(
        query,
        truncation=True,
        padding="max_length",
        max_length=max_length,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
    # Drop the batch dimension of the single query
    # (assumes logits are shaped (1, 5) — confirm against the model class).
    probs = torch.sigmoid(logits).squeeze(0)
    # Outputs 0-2 score the three intents; the highest one wins.
    intent_index = torch.argmax(probs[:3]).item()
    # Outputs 3 and 4 are independent yes/no decisions.
    is_time = int(probs[3].item() > threshold)
    is_lbs = int(probs[4].item() > threshold)
    return {
        "query": query,
        "意图": intent_labels[intent_index],
        "是否时间相关": yesno_labels[is_time],
        "是否lbs相关": yesno_labels[is_lbs],
        "原始概率": probs.tolist(),
    }
# Example query.
result = predict("明天北京天气怎么样?")
print(result)