---
language: zh
tags:
- bert
- multilabel-classification
- chinese
- intent-classification
- time-lbs
base_model:
- google-bert/bert-base-chinese
---

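# Chinese Multi-Label Intent Recognition Model (BERT)

This is a multi-label classification model fine-tuned from `bert-base-chinese`. It classifies Chinese queries on the following tasks:

- Multi-class: intent recognition (chat / simple question / complex question)
- Binary: whether the query is time-related, and whether it is location (LBS) related
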
## Model Architecture

- Base model: [`bert-base-chinese`](https://huggingface.co/bert-base-chinese)
- Output layer: a 5-dimensional multi-label output vector with sigmoid activation
  - `[intent-chat, intent-simple, intent-complex, time-related, LBS-related]`

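To make the layout of the 5-dimensional vector concrete, here is a minimal sketch in plain Python. It assumes that labels are represented as multi-hot vectors in the same order as the output layer; the model card does not confirm how training targets were built. The helper names `encode_labels` and `decode_vector` are hypothetical and are not part of the model repository.

```python
# Assumed label order, taken from the output vector layout listed above.
INTENT_LABELS = ["chat", "simple question", "complex question"]


def encode_labels(intent, is_time, is_lbs):
    """Build a 5-dim multi-hot vector: three intent slots, then the time and LBS flags."""
    if intent not in INTENT_LABELS:
        raise ValueError(f"unknown intent: {intent!r}")
    # One-hot over the three intent slots (one intent per query is assumed).
    vector = [1.0 if label == intent else 0.0 for label in INTENT_LABELS]
    # Two independent binary flags follow the intent slots.
    vector.append(1.0 if is_time else 0.0)
    vector.append(1.0 if is_lbs else 0.0)
    return vector


def decode_vector(vector, threshold=0.5):
    """Inverse mapping: argmax over the intent slots, threshold the two flags."""
    if len(vector) != 5:
        raise ValueError("expected a 5-dim vector")
    intent_scores = vector[:3]
    intent = INTENT_LABELS[intent_scores.index(max(intent_scores))]
    return intent, vector[3] > threshold, vector[4] > threshold


print(encode_labels("simple question", is_time=True, is_lbs=True))
# → [0.0, 1.0, 0.0, 1.0, 1.0]
```

The intent slots are read with an argmax even though each output is an independent sigmoid, so exactly one intent is always reported; the time and LBS flags are thresholded independently of the intent and of each other.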
## Usage

```python
import torch
from transformers import BertTokenizer

# BertMultiLabelClassifier is a custom class, not part of transformers;
# bert_classifier_3.py must be importable from the working directory.
from bert_classifier_3 import BertMultiLabelClassifier

# Load the tokenizer and the model
bert_base = "bert-base-chinese"
model_id = "Xiaoxi2333/bert_multilabel_chinese"
tokenizer = BertTokenizer.from_pretrained(model_id)
model = BertMultiLabelClassifier(pretrained_model_path=bert_base, num_labels=5)
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
    map_location="cpu",
)
model.load_state_dict(state_dict)
model.eval()

# Define the labels
intent_labels = ["chat", "simple question", "complex question"]
yesno_labels = ["否", "是"]  # index 0 = "no", index 1 = "yes"

# Define the prediction function
def predict(query):
    enc = tokenizer(
        query,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt",
    )
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
        probs = torch.sigmoid(logits).squeeze(0)
    # Intent: the largest of the first three probabilities
    intent_index = torch.argmax(probs[:3]).item()
    # Time and LBS flags: threshold each probability at 0.5
    is_time = int(probs[3] > 0.5)
    is_lbs = int(probs[4] > 0.5)

    return {
        "query": query,
        "意图": intent_labels[intent_index],  # intent
        "是否时间相关": yesno_labels[is_time],  # time-related?
        "是否lbs相关": yesno_labels[is_lbs],  # LBS-related?
        "原始概率": probs.tolist(),  # raw probabilities
    }

# Example query: "What will the weather be like in Beijing tomorrow?"
result = predict("明天北京天气怎么样?")
print(result)