Upload model
Browse files
- model.safetensors +1 -1
- modeling_transformer.py +36 -2
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 250204
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3c6565fcbb6c375fc4b9c112ed3d73602e2741f45d6f7bea766a81c603e2f0be
|
| 3 |
size 250204
|
modeling_transformer.py
CHANGED
|
@@ -71,7 +71,7 @@ def masked_softmax(X, valid_lens):
|
|
| 71 |
valid_lens = valid_lens.reshape(-1)
|
| 72 |
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
|
| 73 |
X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
|
| 74 |
-
value=-
|
| 75 |
return nn.functional.softmax(X.reshape(shape), dim=-1)
|
| 76 |
|
| 77 |
class DotProductAttention(nn.Module):
|
|
@@ -411,4 +411,38 @@ class transformerModel(PreTrainedModel):
|
|
| 411 |
def forward(self, enc_X, dec_X, *args):
|
| 412 |
enc_outputs = self.encoder(enc_X, *args)
|
| 413 |
dec_state = self.decoder.init_state(enc_outputs, *args)
|
| 414 |
-
return self.decoder(dec_X, dec_state)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
valid_lens = valid_lens.reshape(-1)
|
| 72 |
# 最后一轴上被掩蔽的元素使用一个非常大的负值替换,从而其softmax输出为0
|
| 73 |
X = sequence_mask(X.reshape(-1, shape[-1]), valid_lens,
|
| 74 |
+
value=-1e4)
|
| 75 |
return nn.functional.softmax(X.reshape(shape), dim=-1)
|
| 76 |
|
| 77 |
class DotProductAttention(nn.Module):
|
|
|
|
| 411 |
def forward(self, enc_X, dec_X, *args):
|
| 412 |
enc_outputs = self.encoder(enc_X, *args)
|
| 413 |
dec_state = self.decoder.init_state(enc_outputs, *args)
|
| 414 |
+
return self.decoder(dec_X, dec_state)
|
def predict_seq2seq(net, src_sentence, src_vocab, tgt_vocab, num_steps,
                    device, save_attention_weights=False):
    """Greedily decode ``src_sentence`` with a trained seq2seq model.

    Returns a ``(translation, attention_weight_seq)`` pair; the attention
    list is only populated when ``save_attention_weights`` is True.

    Defined in :numref:`sec_seq2seq_training`"""
    # Inference mode: disables dropout and similar training-only behavior.
    net.eval()
    tokens = src_vocab[src_sentence.lower().split(' ')] + [
        src_vocab['<eos>']]
    enc_valid_len = torch.tensor([len(tokens)], device=device)
    padded = d2l.truncate_pad(tokens, num_steps, src_vocab['<pad>'])
    # Give the encoder input a batch axis.
    enc_X = torch.unsqueeze(
        torch.tensor(padded, dtype=torch.long, device=device), dim=0)
    dec_state = net.decoder.init_state(
        net.encoder(enc_X, enc_valid_len), enc_valid_len)
    # The first decoder input is <bos>, also with a batch axis.
    dec_X = torch.unsqueeze(torch.tensor(
        [tgt_vocab['<bos>']], dtype=torch.long, device=device), dim=0)
    output_seq, attention_weight_seq = [], []
    for _ in range(num_steps):
        Y, dec_state = net.decoder(dec_X, dec_state)
        # Feed the highest-probability token back in as the next input.
        dec_X = Y.argmax(dim=2)
        pred = dec_X.squeeze(dim=0).type(torch.int32).item()
        # Keep attention weights when requested (discussed later).
        if save_attention_weights:
            attention_weight_seq.append(net.decoder.attention_weights)
        # Stop at <eos>; it is not appended to the output sequence.
        if pred == tgt_vocab['<eos>']:
            break
        output_seq.append(pred)
    return ' '.join(tgt_vocab.to_tokens(output_seq)), attention_weight_seq