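---
language: zh
tags:
- bert
- multilabel-classification
- chinese
- intent-classification
- time-lbs
base_model:
- google-bert/bert-base-chinese
---

# Chinese Multi-Label Intent Classification Model (BERT)

This is a multi-label classification model fine-tuned from `bert-base-chinese`. It classifies Chinese queries along the following dimensions:

- Multi-class: intent recognition (chat / simple question / complex question)
- Binary: whether the query is time-related, and whether it is location-related (LBS, location-based services)

## Model Architecture

- Base model: [`bert-base-chinese`](https://huggingface.co/bert-base-chinese)
- Output layer: a 5-dimensional multi-label output vector, with a sigmoid applied to each dimension independently
  - `[intent-chat, intent-simple, intent-complex, is-time-related, is-LBS-related]`

## Usage

```python
import torch
from transformers import BertTokenizer

# BertMultiLabelClassifier is defined in a local module, bert_classifier_3.py,
# which must be importable from your working directory.
from bert_classifier_3 import BertMultiLabelClassifier

# Load the tokenizer and model
bert_base = "bert-base-chinese"
model_id = "Xiaoxi2333/bert_multilabel_chinese"

tokenizer = BertTokenizer.from_pretrained(model_id)
model = BertMultiLabelClassifier(pretrained_model_path=bert_base, num_labels=5)

# Download the fine-tuned weights. An explicit file_name keeps this checkpoint
# from colliding with other "pytorch_model.bin" files in the torch.hub cache.
state_dict = torch.hub.load_state_dict_from_url(
    f"https://huggingface.co/{model_id}/resolve/main/pytorch_model.bin",
    map_location="cpu",
    file_name="bert_multilabel_chinese.bin",
)
model.load_state_dict(state_dict)
model.eval()

# Label names (order matches the model's output vector)
intent_labels = ["chat", "simple question", "complex question"]
yesno_labels = ["no", "yes"]

# Prediction function
def predict(query):
    enc = tokenizer(
        query,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(enc["input_ids"], enc["attention_mask"])
    probs = torch.sigmoid(logits).squeeze(0)

    # Intent: pick the highest-scoring of the first three outputs
    intent_index = torch.argmax(probs[:3]).item()
    # Time / LBS: independent binary decisions at a 0.5 threshold
    is_time = int(probs[3] > 0.5)
    is_lbs = int(probs[4] > 0.5)

    return {
        "query": query,
        "intent": intent_labels[intent_index],
        "time_related": yesno_labels[is_time],
        "lbs_related": yesno_labels[is_lbs],
        "raw_probs": probs.tolist()
    }

# Example query: "What's the weather like in Beijing tomorrow?"
result = predict("明天北京天气怎么样?")
print(result)
```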
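The decoding rule from the usage example can be tested without loading the model. Below is a minimal sketch built on a hypothetical helper, `decode_logits`, which is not part of this repository. It accepts the five raw logits for a single query as a plain Python list and applies a sigmoid to each one. It then chooses the intent by argmax over the first three positions and thresholds the last two at 0.5. The label order is assumed to match the Model Architecture section, and the `threshold` parameter exists only for illustration.

```python
import math
from typing import Dict, List, Sequence

# Assumed label order of the 5-dim output:
# [intent-chat, intent-simple, intent-complex, is-time-related, is-LBS-related]
INTENT_LABELS = ["chat", "simple question", "complex question"]


def sigmoid(x: float) -> float:
    """Numerically stable logistic function (avoids overflow in math.exp)."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    z = math.exp(x)
    return z / (1.0 + z)


def decode_logits(logits: Sequence[float], threshold: float = 0.5) -> Dict[str, object]:
    """Hypothetical helper: map 5 raw logits for one query to labels.

    Intent: argmax over the first three sigmoid probabilities
    (ties resolve to the first index, as torch.argmax does).
    Time / LBS: independent yes/no decisions, probability > threshold.
    """
    if len(logits) != 5:
        raise ValueError(f"expected 5 logits, got {len(logits)}")
    probs: List[float] = [sigmoid(v) for v in logits]
    intent_index = max(range(3), key=lambda i: probs[i])
    return {
        "intent": INTENT_LABELS[intent_index],
        "is_time_related": probs[3] > threshold,
        "is_lbs_related": probs[4] > threshold,
        "probs": probs,
    }


if __name__ == "__main__":
    # Made-up logits for illustration only, not real model output.
    print(decode_logits([-2.0, 1.5, 0.3, 2.2, 3.1]))
```

The sigmoid is strictly increasing, so an argmax over the raw logits picks the same intent. The three intent probabilities come from independent sigmoids, which means they do not have to sum to 1.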