Spaces:
Runtime error
Runtime error
xyh1756
commited on
Commit
·
ad16774
1
Parent(s):
f35fe56
first commit
Browse files- app.py +16 -0
- bert-base-chinese/README.md +3 -0
- bert-base-chinese/config.json +25 -0
- bert-base-chinese/flax_model.msgpack +3 -0
- bert-base-chinese/pytorch_model.bin +3 -0
- bert-base-chinese/tf_model.h5 +3 -0
- bert-base-chinese/tokenizer.json +0 -0
- bert-base-chinese/tokenizer_config.json +3 -0
- bert-base-chinese/vocab.txt +0 -0
- bert/__init__.py +15 -0
- bert/modeling_jointbert.py +63 -0
- bert/module.py +23 -0
- book_model/config.json +27 -0
- book_model/pytorch_model.bin +3 -0
- book_model/training_args.bin +3 -0
- data/intent_label.txt +3 -0
- data/slot_label.txt +13 -0
- predictOnce.py +180 -0
app.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from predictOnce import Estimator
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
def predict(inputText):
|
| 6 |
+
global e
|
| 7 |
+
res = e.predict(inputText)
|
| 8 |
+
return res[0], res[1]
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
if __name__ == '__main__':
|
| 12 |
+
e = Estimator()
|
| 13 |
+
iface = gr.Interface(fn=e.predict, inputs=gr.inputs.Textbox(lines=2, label="输入语句", placeholder="输入要识别的语句..."),
|
| 14 |
+
outputs=[gr.outputs.Textbox(label="意图"), gr.outputs.Textbox(label="槽值")], live=True,
|
| 15 |
+
theme="huggingface", allow_screenshot=False, allow_flagging=False)
|
| 16 |
+
iface.launch(share=True)
|
bert-base-chinese/README.md
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
language: zh
|
| 3 |
+
---
|
bert-base-chinese/config.json
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"BertForMaskedLM"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"directionality": "bidi",
|
| 7 |
+
"hidden_act": "gelu",
|
| 8 |
+
"hidden_dropout_prob": 0.1,
|
| 9 |
+
"hidden_size": 768,
|
| 10 |
+
"initializer_range": 0.02,
|
| 11 |
+
"intermediate_size": 3072,
|
| 12 |
+
"layer_norm_eps": 1e-12,
|
| 13 |
+
"max_position_embeddings": 512,
|
| 14 |
+
"model_type": "bert",
|
| 15 |
+
"num_attention_heads": 12,
|
| 16 |
+
"num_hidden_layers": 12,
|
| 17 |
+
"pad_token_id": 0,
|
| 18 |
+
"pooler_fc_size": 768,
|
| 19 |
+
"pooler_num_attention_heads": 12,
|
| 20 |
+
"pooler_num_fc_layers": 3,
|
| 21 |
+
"pooler_size_per_head": 128,
|
| 22 |
+
"pooler_type": "first_token_transform",
|
| 23 |
+
"type_vocab_size": 2,
|
| 24 |
+
"vocab_size": 21128
|
| 25 |
+
}
|
bert-base-chinese/flax_model.msgpack
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:76df8425215fb9ede22e0393e356f82a99d84e79f078cd141afbbf9277460c8e
|
| 3 |
+
size 409168515
|
bert-base-chinese/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8a693db616eaf647ed2bfe531e1fa446637358fc108a8bf04e8d4db17e837ee9
|
| 3 |
+
size 411577189
|
bert-base-chinese/tf_model.h5
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:612acd33db45677c3d6ba70615336619dc65cddf1ecf9d39a22dd1934af4aff2
|
| 3 |
+
size 478309336
|
bert-base-chinese/tokenizer.json
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
bert-base-chinese/tokenizer_config.json
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"do_lower_case": false
|
| 3 |
+
}
|
bert-base-chinese/vocab.txt
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
bert/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2018 The Google AI Language Team Authors.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
bert/modeling_jointbert.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
from transformers.modeling_bert import BertPreTrainedModel, BertModel, BertConfig
|
| 4 |
+
from torchcrf import CRF
|
| 5 |
+
from .module import IntentClassifier, SlotClassifier
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class JointBERT(BertPreTrainedModel):
|
| 9 |
+
def __init__(self, config, args, intent_label_lst, slot_label_lst):
|
| 10 |
+
super(JointBERT, self).__init__(config)
|
| 11 |
+
self.args = args
|
| 12 |
+
self.num_intent_labels = len(intent_label_lst)
|
| 13 |
+
self.num_slot_labels = len(slot_label_lst)
|
| 14 |
+
self.bert = BertModel(config=config) # Load pretrained bert
|
| 15 |
+
|
| 16 |
+
self.intent_classifier = IntentClassifier(config.hidden_size, self.num_intent_labels, args.dropout_rate)
|
| 17 |
+
self.slot_classifier = SlotClassifier(config.hidden_size, self.num_slot_labels, args.dropout_rate)
|
| 18 |
+
|
| 19 |
+
if args.use_crf:
|
| 20 |
+
self.crf = CRF(num_tags=self.num_slot_labels, batch_first=True)
|
| 21 |
+
|
| 22 |
+
def forward(self, input_ids, attention_mask, token_type_ids, intent_label_ids, slot_labels_ids):
|
| 23 |
+
outputs = self.bert(input_ids, attention_mask=attention_mask,
|
| 24 |
+
token_type_ids=token_type_ids) # sequence_output, pooled_output, (hidden_states), (attentions)
|
| 25 |
+
sequence_output = outputs[0]
|
| 26 |
+
pooled_output = outputs[1] # [CLS]
|
| 27 |
+
|
| 28 |
+
intent_logits = self.intent_classifier(pooled_output)
|
| 29 |
+
slot_logits = self.slot_classifier(sequence_output)
|
| 30 |
+
|
| 31 |
+
total_loss = 0
|
| 32 |
+
# 1. Intent Softmax
|
| 33 |
+
if intent_label_ids is not None:
|
| 34 |
+
if self.num_intent_labels == 1:
|
| 35 |
+
intent_loss_fct = nn.MSELoss()
|
| 36 |
+
intent_loss = intent_loss_fct(intent_logits.view(-1), intent_label_ids.view(-1))
|
| 37 |
+
else:
|
| 38 |
+
intent_loss_fct = nn.CrossEntropyLoss()
|
| 39 |
+
intent_loss = intent_loss_fct(intent_logits.view(-1, self.num_intent_labels), intent_label_ids.view(-1))
|
| 40 |
+
total_loss += intent_loss
|
| 41 |
+
|
| 42 |
+
# 2. Slot Softmax
|
| 43 |
+
if slot_labels_ids is not None:
|
| 44 |
+
if self.args.use_crf:
|
| 45 |
+
slot_loss = self.crf(slot_logits, slot_labels_ids, mask=attention_mask.byte(), reduction='mean')
|
| 46 |
+
slot_loss = -1 * slot_loss # negative log-likelihood
|
| 47 |
+
else:
|
| 48 |
+
slot_loss_fct = nn.CrossEntropyLoss(ignore_index=self.args.ignore_index)
|
| 49 |
+
# Only keep active parts of the loss
|
| 50 |
+
if attention_mask is not None:
|
| 51 |
+
active_loss = attention_mask.view(-1) == 1
|
| 52 |
+
active_logits = slot_logits.view(-1, self.num_slot_labels)[active_loss]
|
| 53 |
+
active_labels = slot_labels_ids.view(-1)[active_loss]
|
| 54 |
+
slot_loss = slot_loss_fct(active_logits, active_labels)
|
| 55 |
+
else:
|
| 56 |
+
slot_loss = slot_loss_fct(slot_logits.view(-1, self.num_slot_labels), slot_labels_ids.view(-1))
|
| 57 |
+
total_loss += self.args.slot_loss_coef * slot_loss
|
| 58 |
+
|
| 59 |
+
outputs = ((intent_logits, slot_logits),) + outputs[2:] # add hidden states and attention if they are here
|
| 60 |
+
|
| 61 |
+
outputs = (total_loss,) + outputs
|
| 62 |
+
|
| 63 |
+
return outputs # (loss), logits, (hidden_states), (attentions) # Logits is a tuple of intent and slot logits
|
bert/module.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class IntentClassifier(nn.Module):
|
| 5 |
+
def __init__(self, input_dim, num_intent_labels, dropout_rate=0.):
|
| 6 |
+
super(IntentClassifier, self).__init__()
|
| 7 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 8 |
+
self.linear = nn.Linear(input_dim, num_intent_labels)
|
| 9 |
+
|
| 10 |
+
def forward(self, x):
|
| 11 |
+
x = self.dropout(x)
|
| 12 |
+
return self.linear(x)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class SlotClassifier(nn.Module):
|
| 16 |
+
def __init__(self, input_dim, num_slot_labels, dropout_rate=0.):
|
| 17 |
+
super(SlotClassifier, self).__init__()
|
| 18 |
+
self.dropout = nn.Dropout(dropout_rate)
|
| 19 |
+
self.linear = nn.Linear(input_dim, num_slot_labels)
|
| 20 |
+
|
| 21 |
+
def forward(self, x):
|
| 22 |
+
x = self.dropout(x)
|
| 23 |
+
return self.linear(x)
|
book_model/config.json
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"JointBERT"
|
| 4 |
+
],
|
| 5 |
+
"attention_probs_dropout_prob": 0.1,
|
| 6 |
+
"directionality": "bidi",
|
| 7 |
+
"finetuning_task": "book",
|
| 8 |
+
"gradient_checkpointing": false,
|
| 9 |
+
"hidden_act": "gelu",
|
| 10 |
+
"hidden_dropout_prob": 0.1,
|
| 11 |
+
"hidden_size": 768,
|
| 12 |
+
"initializer_range": 0.02,
|
| 13 |
+
"intermediate_size": 3072,
|
| 14 |
+
"layer_norm_eps": 1e-12,
|
| 15 |
+
"max_position_embeddings": 512,
|
| 16 |
+
"model_type": "bert",
|
| 17 |
+
"num_attention_heads": 12,
|
| 18 |
+
"num_hidden_layers": 12,
|
| 19 |
+
"pad_token_id": 0,
|
| 20 |
+
"pooler_fc_size": 768,
|
| 21 |
+
"pooler_num_attention_heads": 12,
|
| 22 |
+
"pooler_num_fc_layers": 3,
|
| 23 |
+
"pooler_size_per_head": 128,
|
| 24 |
+
"pooler_type": "first_token_transform",
|
| 25 |
+
"type_vocab_size": 2,
|
| 26 |
+
"vocab_size": 21128
|
| 27 |
+
}
|
book_model/pytorch_model.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6389ddb0d25ffbbea13eae6adbeb3a8e9dde3dd71ad811abd019862f51570ede
|
| 3 |
+
size 409203155
|
book_model/training_args.bin
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:0e2a51024072e0d7bf7f5c695ade9f2bf7b52f85696d12e73389e95a8d63fe9c
|
| 3 |
+
size 1199
|
data/intent_label.txt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
UNK
|
| 2 |
+
query
|
| 3 |
+
chat
|
data/slot_label.txt
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
PAD
|
| 2 |
+
UNK
|
| 3 |
+
O
|
| 4 |
+
B-Author
|
| 5 |
+
I-Author
|
| 6 |
+
B-Book
|
| 7 |
+
I-Book
|
| 8 |
+
B-Press
|
| 9 |
+
I-Press
|
| 10 |
+
B-Tag
|
| 11 |
+
I-Tag
|
| 12 |
+
B-Topic
|
| 13 |
+
I-Topic
|
predictOnce.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import BertTokenizer
|
| 7 |
+
from bert.modeling_jointbert import JointBERT
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class Estimator:
|
| 11 |
+
class Args:
|
| 12 |
+
adam_epsilon = 1e-08
|
| 13 |
+
batch_size = 16
|
| 14 |
+
data_dir = 'data'
|
| 15 |
+
device = 'cpu'
|
| 16 |
+
do_eval = True
|
| 17 |
+
do_train = False
|
| 18 |
+
dropout_rate = 0.1
|
| 19 |
+
eval_batch_size = 64
|
| 20 |
+
gradient_accumulation_steps = 1
|
| 21 |
+
ignore_index = 0
|
| 22 |
+
intent_label_file = 'data/intent_label.txt'
|
| 23 |
+
learning_rate = 5e-05
|
| 24 |
+
logging_steps = 50
|
| 25 |
+
max_grad_norm = 1.0
|
| 26 |
+
max_seq_len = 50
|
| 27 |
+
max_steps = -1
|
| 28 |
+
model_dir = 'book_model'
|
| 29 |
+
model_name_or_path = 'bert-base-chinese'
|
| 30 |
+
model_type = 'bert-chinese'
|
| 31 |
+
no_cuda = False
|
| 32 |
+
num_train_epochs = 5.0
|
| 33 |
+
save_steps = 200
|
| 34 |
+
seed = 1234
|
| 35 |
+
slot_label_file = 'data/slot_label.txt'
|
| 36 |
+
slot_loss_coef = 1.0
|
| 37 |
+
slot_pad_label = 'PAD'
|
| 38 |
+
task = 'book'
|
| 39 |
+
train_batch_size = 32
|
| 40 |
+
use_crf = False
|
| 41 |
+
warmup_steps = 0
|
| 42 |
+
weight_decay = 0.0
|
| 43 |
+
|
| 44 |
+
def __init__(self, args=Args):
|
| 45 |
+
self.intent_label_lst = [label.strip() for label in open(args.intent_label_file, 'r', encoding='utf-8')]
|
| 46 |
+
self.slot_label_lst = [label.strip() for label in open(args.slot_label_file, 'r', encoding='utf-8')]
|
| 47 |
+
|
| 48 |
+
# Check whether model exists
|
| 49 |
+
if not os.path.exists(args.model_dir):
|
| 50 |
+
raise Exception("Model doesn't exists! Train first!")
|
| 51 |
+
|
| 52 |
+
self.model = JointBERT.from_pretrained(args.model_dir,
|
| 53 |
+
args=args,
|
| 54 |
+
intent_label_lst=self.intent_label_lst,
|
| 55 |
+
slot_label_lst=self.slot_label_lst)
|
| 56 |
+
self.model.to(args.device)
|
| 57 |
+
self.model.eval()
|
| 58 |
+
self.args = args
|
| 59 |
+
self.tokenizer = BertTokenizer.from_pretrained(self.args.model_name_or_path)
|
| 60 |
+
|
| 61 |
+
def convert_input_to_tensor_data(self, input, tokenizer, pad_token_label_id,
|
| 62 |
+
cls_token_segment_id=0,
|
| 63 |
+
pad_token_segment_id=0,
|
| 64 |
+
sequence_a_segment_id=0,
|
| 65 |
+
mask_padding_with_zero=True):
|
| 66 |
+
# Setting based on the current model type
|
| 67 |
+
cls_token = tokenizer.cls_token
|
| 68 |
+
sep_token = tokenizer.sep_token
|
| 69 |
+
unk_token = tokenizer.unk_token
|
| 70 |
+
pad_token_id = tokenizer.pad_token_id
|
| 71 |
+
|
| 72 |
+
slot_label_mask = []
|
| 73 |
+
|
| 74 |
+
words = list(input)
|
| 75 |
+
tokens = []
|
| 76 |
+
for word in words:
|
| 77 |
+
word_tokens = tokenizer.tokenize(word)
|
| 78 |
+
if not word_tokens:
|
| 79 |
+
word_tokens = [unk_token] # For handling the bad-encoded word
|
| 80 |
+
tokens.extend(word_tokens)
|
| 81 |
+
# Use the real label id for the first token of the word, and padding ids for the remaining tokens
|
| 82 |
+
slot_label_mask.extend([pad_token_label_id + 1] + [pad_token_label_id] * (len(word_tokens) - 1))
|
| 83 |
+
|
| 84 |
+
# Account for [CLS] and [SEP]
|
| 85 |
+
special_tokens_count = 2
|
| 86 |
+
if len(tokens) > self.args.max_seq_len - special_tokens_count:
|
| 87 |
+
tokens = tokens[: (self.args.max_seq_len - special_tokens_count)]
|
| 88 |
+
slot_label_mask = slot_label_mask[:(self.args.max_seq_len - special_tokens_count)]
|
| 89 |
+
|
| 90 |
+
# Add [SEP] token
|
| 91 |
+
tokens += [sep_token]
|
| 92 |
+
token_type_ids = [sequence_a_segment_id] * len(tokens)
|
| 93 |
+
slot_label_mask += [pad_token_label_id]
|
| 94 |
+
|
| 95 |
+
# Add [CLS] token
|
| 96 |
+
tokens = [cls_token] + tokens
|
| 97 |
+
token_type_ids = [cls_token_segment_id] + token_type_ids
|
| 98 |
+
slot_label_mask = [pad_token_label_id] + slot_label_mask
|
| 99 |
+
|
| 100 |
+
input_ids = tokenizer.convert_tokens_to_ids(tokens)
|
| 101 |
+
|
| 102 |
+
# The mask has 1 for real tokens and 0 for padding tokens. Only real tokens are attended to.
|
| 103 |
+
attention_mask = [1 if mask_padding_with_zero else 0] * len(input_ids)
|
| 104 |
+
|
| 105 |
+
# Zero-pad up to the sequence length.
|
| 106 |
+
padding_length = self.args.max_seq_len - len(input_ids)
|
| 107 |
+
input_ids = input_ids + ([pad_token_id] * padding_length)
|
| 108 |
+
attention_mask = attention_mask + ([0 if mask_padding_with_zero else 1] * padding_length)
|
| 109 |
+
token_type_ids = token_type_ids + ([pad_token_segment_id] * padding_length)
|
| 110 |
+
slot_label_mask = slot_label_mask + ([pad_token_label_id] * padding_length)
|
| 111 |
+
|
| 112 |
+
# Change to Tensor
|
| 113 |
+
input_ids = torch.tensor([input_ids], dtype=torch.long)
|
| 114 |
+
attention_mask = torch.tensor([attention_mask], dtype=torch.long)
|
| 115 |
+
token_type_ids = torch.tensor([token_type_ids], dtype=torch.long)
|
| 116 |
+
slot_label_mask = torch.tensor([slot_label_mask], dtype=torch.long)
|
| 117 |
+
|
| 118 |
+
data = [input_ids, attention_mask, token_type_ids, slot_label_mask]
|
| 119 |
+
|
| 120 |
+
return data
|
| 121 |
+
|
| 122 |
+
def predict(self, input):
|
| 123 |
+
# Convert input file to TensorDataset
|
| 124 |
+
pad_token_label_id = self.args.ignore_index
|
| 125 |
+
batch = self.convert_input_to_tensor_data(input, self.tokenizer, pad_token_label_id)
|
| 126 |
+
|
| 127 |
+
# Predict
|
| 128 |
+
batch = tuple(t.to(self.args.device) for t in batch)
|
| 129 |
+
with torch.no_grad():
|
| 130 |
+
inputs = {"input_ids": batch[0],
|
| 131 |
+
"attention_mask": batch[1],
|
| 132 |
+
"token_type_ids": batch[2],
|
| 133 |
+
"intent_label_ids": None,
|
| 134 |
+
"slot_labels_ids": None}
|
| 135 |
+
outputs = self.model(**inputs)
|
| 136 |
+
_, (intent_logits, slot_logits) = outputs[:2]
|
| 137 |
+
|
| 138 |
+
# Intent Prediction
|
| 139 |
+
intent_pred = intent_logits.detach().cpu().numpy()
|
| 140 |
+
|
| 141 |
+
# Slot prediction
|
| 142 |
+
if self.args.use_crf:
|
| 143 |
+
# decode() in `torchcrf` returns list with best index directly
|
| 144 |
+
slot_preds = np.array(self.model.crf.decode(slot_logits))
|
| 145 |
+
else:
|
| 146 |
+
slot_preds = slot_logits.detach().cpu().numpy()
|
| 147 |
+
all_slot_label_mask = batch[3].detach().cpu().numpy()
|
| 148 |
+
|
| 149 |
+
intent_pred = np.argmax(intent_pred, axis=1)[0]
|
| 150 |
+
|
| 151 |
+
if not self.args.use_crf:
|
| 152 |
+
slot_preds = np.argmax(slot_preds, axis=2)
|
| 153 |
+
|
| 154 |
+
slot_label_map = {i: label for i, label in enumerate(self.slot_label_lst)}
|
| 155 |
+
slot_preds_list = []
|
| 156 |
+
|
| 157 |
+
for i in range(slot_preds.shape[1]):
|
| 158 |
+
if all_slot_label_mask[0, i] != pad_token_label_id:
|
| 159 |
+
slot_preds_list.append(slot_label_map[slot_preds[0][i]])
|
| 160 |
+
|
| 161 |
+
words = list(input)
|
| 162 |
+
slots = dict()
|
| 163 |
+
slot = str()
|
| 164 |
+
for i in range(len(words)):
|
| 165 |
+
if slot_preds_list[i] == 'O':
|
| 166 |
+
if slot == '':
|
| 167 |
+
continue
|
| 168 |
+
slots[slot_preds_list[i - 1].split('-')[1]] = slot
|
| 169 |
+
slot = str()
|
| 170 |
+
else:
|
| 171 |
+
slot += words[i]
|
| 172 |
+
if slot != '':
|
| 173 |
+
slots[slot_preds_list[len(words) - 1].split('-')[1]] = slot
|
| 174 |
+
return self.intent_label_lst[intent_pred], slots
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
|
| 178 |
+
e = Estimator()
|
| 179 |
+
while True:
|
| 180 |
+
print(e.predict(input(">>")))
|