try to simplify checkpointing

modeling_bert.py  CHANGED  +2 −247

@@ -329,9 +329,7 @@ class BertPreTrainedModel(nn.Module):
         """
         # Instantiate model.
         model = cls(config, *inputs, **kwargs)
-        load_return = model.load_state_dict(
-            remap_state_dict(state_dict_from_pretrained(model_name), config), strict=False
-        )
+        load_return = model.load_state_dict(state_dict_from_pretrained(model_name), strict=False)
         logger.info(load_return)
         return model

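Note: with the remapping gone, `from_pretrained` hands the checkpoint straight to PyTorch's `load_state_dict` with `strict=False`, which returns a report of mismatched keys instead of raising; that report is what gets logged as `load_return`. A minimal, self-contained sketch of that PyTorch behavior (the toy module and checkpoint below are made up for illustration):

import torch
import torch.nn as nn

# Toy module and checkpoint, purely illustrative of strict=False behavior.
model = nn.Linear(4, 4)
ckpt = {"weight": torch.zeros(4, 4), "extra.bias": torch.zeros(4)}  # no "bias" entry

load_return = model.load_state_dict(ckpt, strict=False)  # reports mismatches, does not raise
print(load_return.missing_keys)     # ['bias']
print(load_return.unexpected_keys)  # ['extra.bias']
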
@@ -528,247 +526,4 @@ class BertForPreTraining(BertPreTrainedModel):
             loss=total_loss,
             prediction_logits=prediction_scores,
             seq_relationship_logits=seq_relationship_score,
-        )
-
-
-def remap_state_dict(state_dict, config: PretrainedConfig):
-    """
-    Map the state_dict of a Huggingface BERT model to be flash_attn compatible.
-    """
-
-    # LayerNorm
-    def key_mapping_ln_gamma_beta(key):
-        key = re.sub(r"LayerNorm.gamma$", "LayerNorm.weight", key)
-        key = re.sub(r"LayerNorm.beta$", "LayerNorm.bias", key)
-        return key
-
-    state_dict = OrderedDict((key_mapping_ln_gamma_beta(k), v) for k, v in state_dict.items())
-
-    # Layers
-    def key_mapping_layers(key):
-        return re.sub(r"^bert.encoder.layer.", "bert.encoder.layers.", key)
-
-    state_dict = OrderedDict((key_mapping_layers(k), v) for k, v in state_dict.items())
-
-    # LayerNorm
-    def key_mapping_ln(key):
-        key = re.sub(r"^bert.embeddings.LayerNorm.", "bert.emb_ln.", key)
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).attention.output.LayerNorm.(weight|bias)",
-            r"bert.encoder.layers.\1.norm1.\2",
-            key,
-        )
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).output.LayerNorm.(weight|bias)",
-            r"bert.encoder.layers.\1.norm2.\2",
-            key,
-        )
-        key = re.sub(
-            r"^cls.predictions.transform.LayerNorm.(weight|bias)",
-            r"cls.predictions.transform.layer_norm.\1",
-            key,
-        )
-        return key
-
-    state_dict = OrderedDict((key_mapping_ln(k), v) for k, v in state_dict.items())
-
-    # MLP
-    def key_mapping_mlp(key):
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).intermediate.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mlp.fc1.\2",
-            key,
-        )
-        key = re.sub(
-            r"^bert.encoder.layers.(\d+).output.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mlp.fc2.\2",
-            key,
-        )
-        return key
-
-    state_dict = OrderedDict((key_mapping_mlp(k), v) for k, v in state_dict.items())
-
-    # Attention
-    last_layer_subset = getattr(config, "last_layer_subset", False)
-    for d in range(config.num_hidden_layers):
-        Wq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.weight")
-        Wk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.weight")
-        Wv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.weight")
-        bq = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.query.bias")
-        bk = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.key.bias")
-        bv = state_dict.pop(f"bert.encoder.layers.{d}.attention.self.value.bias")
-        if not (last_layer_subset and d == config.num_hidden_layers - 1):
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat(
-                [Wq, Wk, Wv], dim=0
-            )
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wqkv.bias"] = torch.cat([bq, bk, bv], dim=0)
-        else:
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.weight"] = Wq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.weight"] = torch.cat([Wk, Wv], dim=0)
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wq.bias"] = bq
-            state_dict[f"bert.encoder.layers.{d}.mixer.Wkv.bias"] = torch.cat([bk, bv], dim=0)
-
-    def key_mapping_attn(key):
-        return re.sub(
-            r"^bert.encoder.layers.(\d+).attention.output.dense.(weight|bias)",
-            r"bert.encoder.layers.\1.mixer.out_proj.\2",
-            key,
-        )
-
-    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
-
-    def key_mapping_decoder_bias(key):
-        return re.sub(r"^cls.predictions.bias", "cls.predictions.decoder.bias", key)
-
-    state_dict = OrderedDict((key_mapping_decoder_bias(k), v) for k, v in state_dict.items())
-
-    # Word embedding
-    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
-    if pad_vocab_size_multiple > 1:
-        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
-        state_dict["bert.embeddings.word_embeddings.weight"] = F.pad(
-            word_embeddings, (0, 0, 0, config.vocab_size - word_embeddings.shape[0])
-        )
-        decoder_weight = state_dict["cls.predictions.decoder.weight"]
-        state_dict["cls.predictions.decoder.weight"] = F.pad(
-            decoder_weight, (0, 0, 0, config.vocab_size - decoder_weight.shape[0])
-        )
-        # If the vocab was padded, we want to set the decoder bias for those padded indices to be
-        # strongly negative (i.e. the decoder shouldn't predict those indices).
-        # TD [2022-05-09]: I don't think it affects the MLPerf training.
-        decoder_bias = state_dict["cls.predictions.decoder.bias"]
-        state_dict["cls.predictions.decoder.bias"] = F.pad(
-            decoder_bias, (0, config.vocab_size - decoder_bias.shape[0]), value=-100.0
-        )
-
-    return state_dict
-
-
-def inv_remap_state_dict(state_dict, config: PretrainedConfig):
-    """
-    Map the state_dict of a flash_attn model to be Huggingface BERT compatible.
-
-    This function is meant to be the inverse of remap_state_dict.
-    """
-    # Word embedding
-    pad_vocab_size_multiple = getattr(config, "pad_vocab_size_multiple", 1)
-    if pad_vocab_size_multiple > 1:
-        word_embeddings = state_dict["bert.embeddings.word_embeddings.weight"]
-        decoder_weight = state_dict["cls.predictions.decoder.weight"]
-        decoder_bias = state_dict["cls.predictions.decoder.bias"]
-        # unpad embeddings
-        state_dict["bert.embeddings.word_embeddings.weight"] = word_embeddings[
-            : config.orig_vocab_size, :
-        ]
-        state_dict["cls.predictions.decoder.weight"] = decoder_weight[: config.orig_vocab_size, :]
-        state_dict["cls.predictions.decoder.bias"] = decoder_bias[: config.orig_vocab_size]
-
-    for d in range(config.num_hidden_layers):
-        last_layer_subset = getattr(config, "last_layer_subset", False)
-        if not last_layer_subset or d != (config.num_hidden_layers - 1):
-            Wqkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.weight")
-            Wqkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wqkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wqkv_weights[
-                : Wqkv_weights.shape[0] // 3, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wqkv_weights[
-                Wqkv_weights.shape[0] // 3 : 2 * Wqkv_weights.shape[0] // 3, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wqkv_weights[
-                2 * Wqkv_weights.shape[0] // 3 :, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wqkv_biases[
-                : Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wqkv_biases[
-                Wqkv_biases.shape[0] // 3 : 2 * Wqkv_biases.shape[0] // 3
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wqkv_biases[
-                2 * Wqkv_biases.shape[0] // 3 :
-            ]
-        else:
-            Wq_weight = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.weight")
-            Wkv_weights = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.weight")
-            Wq_bias = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wq.bias")
-            Wkv_biases = state_dict.pop(f"bert.encoder.layers.{d}.mixer.Wkv.bias")
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.weight"] = Wq_weight
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.weight"] = Wkv_weights[
-                : Wkv_weights.shape[0] // 2, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.weight"] = Wkv_weights[
-                Wkv_weights.shape[0] // 2 :, :
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.query.bias"] = Wq_bias
-            state_dict[f"bert.encoder.layers.{d}.attention.self.key.bias"] = Wkv_biases[
-                : Wkv_biases.shape[0] // 2
-            ]
-            state_dict[f"bert.encoder.layers.{d}.attention.self.value.bias"] = Wkv_biases[
-                Wkv_biases.shape[0] // 2 :
-            ]
-
-    def inv_key_mapping_ln(key):
-        key = re.sub(r"bert.emb_ln.", "bert.embeddings.LayerNorm.", key)
-        key = re.sub(
-            r"bert.encoder.layers.(\d+).norm1.(weight|bias)",
-            r"bert.encoder.layers.\1.attention.output.LayerNorm.\2",
-            key,
-        )
-        key = re.sub(
-            r"bert.encoder.layers.(\d+).norm2.(weight|bias)",
-            r"bert.encoder.layers.\1.output.LayerNorm.\2",
-            key,
-        )
-        key = re.sub(
-            r"cls.predictions.transform.layer_norm.(weight|bias)",
-            r"cls.predictions.transform.LayerNorm.\1",
-            key,
-        )
-        return key
-
-    def inv_key_mapping_ln_gamma_beta(key):
-        key = re.sub(r"LayerNorm.weight$", "LayerNorm.gamma", key)
-        key = re.sub(r"LayerNorm.bias$", "LayerNorm.beta", key)
-        return key
-
-    def inv_key_mapping_layers(key):
-        return re.sub(r"bert.encoder.layers.", "bert.encoder.layer.", key)
-
-    def inv_key_mapping_mlp(key):
-        key = re.sub(
-            r"bert.encoder.layer.(\d+).mlp.fc1.(weight|bias)",
-            r"bert.encoder.layer.\1.intermediate.dense.\2",
-            key,
-        )
-        key = re.sub(
-            r"bert.encoder.layer.(\d+).mlp.fc2.(weight|bias)",
-            r"bert.encoder.layer.\1.output.dense.\2",
-            key,
-        )
-        return key
-
-    def inv_key_mapping_attn(key):
-        return re.sub(
-            r"bert.encoder.layer.(\d+).mixer.out_proj.(weight|bias)",
-            r"bert.encoder.layer.\1.attention.output.dense.\2",
-            key,
-        )
-
-    def inv_key_mapping_decoder_bias(key):
-        return re.sub(r"cls.predictions.decoder.bias", "cls.predictions.bias", key)
-
-    state_dict = OrderedDict((inv_key_mapping_ln(key), value) for key, value in state_dict.items())
-    state_dict = OrderedDict(
-        (inv_key_mapping_ln_gamma_beta(key), value) for key, value in state_dict.items()
-    )
-    state_dict = OrderedDict(
-        (inv_key_mapping_layers(key), value) for key, value in state_dict.items()
-    )
-    state_dict = OrderedDict((inv_key_mapping_mlp(key), value) for key, value in state_dict.items())
-    state_dict = OrderedDict(
-        (inv_key_mapping_attn(key), value) for key, value in state_dict.items()
-    )
-    state_dict = OrderedDict(
-        (inv_key_mapping_decoder_bias(key), value) for key, value in state_dict.items()
-    )
-
-    return state_dict
+        )
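For reference, the bulk of the deletion is `remap_state_dict` and `inv_remap_state_dict`, which translated checkpoint keys between the Hugging Face BERT layout and the fused-projection layout used here, and back. The core per-layer step was concatenating the separate query/key/value projections into a single `Wqkv` tensor; a condensed, weights-only sketch of that idea (not the removed function itself, and ignoring the `last_layer_subset` case):

import torch
from collections import OrderedDict

def fuse_qkv_weights(state_dict, num_layers):
    # Condensed illustration: only the Q/K/V weight fusion is shown;
    # key renaming, biases, and vocab padding are omitted.
    sd = OrderedDict(state_dict)
    for d in range(num_layers):
        prefix = f"bert.encoder.layers.{d}.attention.self"
        Wq = sd.pop(f"{prefix}.query.weight")
        Wk = sd.pop(f"{prefix}.key.weight")
        Wv = sd.pop(f"{prefix}.value.weight")
        # The fused projection stacks the three matrices along the output dimension.
        sd[f"bert.encoder.layers.{d}.mixer.Wqkv.weight"] = torch.cat([Wq, Wk, Wv], dim=0)
    return sd

After this change that translation no longer happens at load time, so a checkpoint whose keys still follow the Hugging Face layout would surface as missing/unexpected keys under `strict=False` rather than being remapped.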