duzx16 committed
Commit 89621ea · 1 Parent(s): a8445da
Fix batch beam search

modeling_glm.py CHANGED (+82 -10)
@@ -30,6 +30,7 @@ from transformers.utils import (
 from transformers.modeling_outputs import (
     BaseModelOutputWithPastAndCrossAttentions,
     ModelOutput,
+    SequenceClassifierOutput,
 )

 from transformers.modeling_utils import (
@@ -780,17 +781,15 @@ class GLMModel(GLMPreTrainedModel):
         attention_mask = torch.zeros(batch_size)
         # Transformer.
         transformer_output = self.transformer(embeddings, position_ids, attention_mask, mems)
-
-
+        last_hidden_states, mems = transformer_output
+        logits = None
         if self.output_predict:
-
-            # logits_parallel = mpu.copy_to_model_parallel_region(
-            #     logits)
-            logits = F.linear(logits, self.word_embeddings.weight)
+            logits = F.linear(last_hidden_states, self.word_embeddings.weight)

         return ModelOutput(
+            last_hidden_states=last_hidden_states,
             logits=logits,
-            mems=
+            mems=mems,
         )
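For context: the rewritten hunk unpacks transformer_output into last_hidden_states and mems, and only computes vocabulary logits when output_predict is set, projecting the hidden states onto the tied input-embedding matrix. A minimal sketch of that tied projection, using toy shapes that are assumptions rather than the model's real sizes:

# Minimal sketch (not repo code): with tied embeddings, F.linear(x, W)
# computes x @ W.T, scoring each position's hidden vector against every
# row (token embedding) of the vocabulary matrix.
import torch
import torch.nn.functional as F

vocab_size, hidden_size, seq_len = 100, 16, 4   # toy sizes, assumed
embedding = torch.nn.Embedding(vocab_size, hidden_size)

last_hidden_states = torch.randn(1, seq_len, hidden_size)
logits = F.linear(last_hidden_states, embedding.weight)
assert logits.shape == (1, seq_len, vocab_size)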
@@ -815,7 +814,7 @@ class GLMForMultipleChoice(GLMPreTrainedModel):
                 mems=None,
                 **kwargs
                 ):
-        model_output = self.glm
+        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
         lm_logits = model_output.logits
         log_probs = []
         for output, choices, choice_index in zip(F.log_softmax(lm_logits, dim=-1), choice_ids, choice_indices):
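The context line in this hunk iterates over per-sample log-softmax outputs together with choice_ids and choice_indices. A hedged sketch of the usual way such a loop scores a candidate choice, by summing token log-probabilities at the positions that predict them; the variable semantics here are an assumption, not the file's exact code:

import torch
import torch.nn.functional as F

lm_logits = torch.randn(10, 100)        # (seq_len, vocab_size), toy sizes
token_log_probs = F.log_softmax(lm_logits, dim=-1)

choices = torch.tensor([42, 7, 13])     # token ids of one choice (assumed)
choice_index = torch.tensor([5, 6, 7])  # positions predicting them (assumed)

# Advanced indexing picks log P(choices[k] | position choice_index[k]).
score = token_log_probs[choice_index, choices].sum()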
@@ -874,6 +873,16 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
             position_ids = position_ids[:, :, :seq_length]
         if attention_mask is not None:
             attention_mask = attention_mask[:, :, :seq_length, :seq_length]
+        if position_ids is not None and input_ids.size(0) > position_ids.size(0):
+            batch_size = position_ids.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            position_ids = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
+            position_ids = position_ids.reshape(batch_size * num_beams, *position_ids.shape[-2:])
+        if attention_mask is not None and input_ids.size(0) > attention_mask.size(0):
+            batch_size = attention_mask.size(0)
+            num_beams = input_ids.size(0) // batch_size
+            attention_mask = attention_mask.unsqueeze(1).expand(-1, num_beams, -1, -1, -1)
+            attention_mask = attention_mask.reshape(batch_size * num_beams, *attention_mask.shape[-3:])
         return {
             "input_ids": input_ids,
             "position_ids": position_ids,
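This hunk is the heart of the commit: during beam search, generate expands input_ids to batch_size * num_beams rows, while the position_ids and attention_mask passed through model_kwargs keep the original batch dimension. The new code detects the row-count mismatch and repeats each batch row num_beams times. A self-contained sketch of the same unsqueeze/expand/reshape pattern, with toy shapes that are assumptions:

import torch

batch_size, num_beams = 2, 3
position_ids = torch.arange(batch_size * 2 * 5).reshape(batch_size, 2, 5)

# Insert a beam axis, broadcast it, then fold it into the batch axis so
# row i * num_beams + j holds beam j of sample i (generate's layout).
expanded = position_ids.unsqueeze(1).expand(-1, num_beams, -1, -1)
expanded = expanded.reshape(batch_size * num_beams, *position_ids.shape[-2:])

assert expanded.shape == (batch_size * num_beams, 2, 5)
assert torch.equal(expanded[0], position_ids[0])   # beams 0..2 copy sample 0
assert torch.equal(expanded[3], position_ids[1])   # beams 3..5 copy sample 1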
@@ -890,7 +899,7 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
                 mems=None,
                 **kwargs
                 ):
-        model_output = self.glm
+        model_output = self.glm(input_ids, position_ids, attention_mask, mems=mems, **kwargs)
         lm_logits = model_output.logits
         loss = None
         if labels is not None:
@@ -900,4 +909,67 @@ class GLMForConditionalGeneration(GLMPreTrainedModel):
             loss=loss,
             logits=lm_logits,
             mems=model_output.mems
-        )
+        )
+
+
+@add_start_docstrings(
+    """GLM Model transformer with a sequence classification/regression head on top (a linear layer on top of
+    the pooled output) e.g. for GLUE tasks. """,
+    GLM_START_DOCSTRING,
+)
+class GLMForSequenceClassification(GLMPreTrainedModel):
+    def __init__(self, config: GLMConfig, hidden_dropout=None, num_class=1):
+        super().__init__(config)
+        self.pool_token = config.pool_token
+        self.glm = GLMModel(config)
+        self.glm.output_predict = False
+        self.num_class = num_class
+        # Multi-choice head.
+        self.dense = torch.nn.Linear(config.hidden_size, config.hidden_size)
+        classifier_dropout = (
+            config.classifier_dropout if config.classifier_dropout is not None else config.output_dropout_prob
+        )
+        self.dropout = torch.nn.Dropout(classifier_dropout)
+        self.out_proj = torch.nn.Linear(config.hidden_size, config.num_labels)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    @add_start_docstrings_to_model_forward(GLM_INPUTS_DOCSTRING.format("batch_size, sequence_length"))
+    @add_code_sample_docstrings(
+        processor_class=_TOKENIZER_FOR_DOC,
+        checkpoint=_CHECKPOINT_FOR_DOC,
+        output_type=SequenceClassifierOutput,
+        config_class=_CONFIG_FOR_DOC,
+    )
+    def forward(self,
+                input_ids=None,
+                position_ids=None,
+                attention_mask=None,
+                labels=None):
+
+        num_choices = None
+
+        if len(input_ids.shape) == 3:
+            batch_size, num_choices = input_ids.shape[:2]
+            input_ids = input_ids.reshape(-1, input_ids.size(-1))
+            attention_mask = attention_mask.reshape(-1, *attention_mask.size()[2:])
+            position_ids = position_ids.reshape(-1, *position_ids.size()[2:])
+        model_out = self.glm(input_ids, position_ids, attention_mask)
+        outputs, mems = model_out.last_hidden_states, model_out.mems
+
+        output = outputs[:, 0, :]
+        output = self.dropout(output)
+        output = torch.tanh(self.dense(output))
+        output = self.dropout(output)
+        logits = self.out_proj(output)
+        if num_choices is not None:
+            logits = logits.view(-1, num_choices)
+        loss = None
+        if labels is not None:
+            loss_fct = CrossEntropyLoss()
+            loss = loss_fct(logits, labels)
+            # loss = F.cross_entropy(logits.contiguous().float(), labels.long())
+        return SequenceClassifierOutput(loss=loss,
+                                        logits=logits,
+                                        hidden_states=outputs)
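The new head disables the LM projection (output_predict = False), pools the first token of last_hidden_states, and classifies it through dropout, dense, tanh, dropout, and a final linear projection. A self-contained sketch of that data flow with toy shapes (the sizes are assumptions, not checkpoint values):

import torch

hidden_size, num_labels, batch, seq_len = 16, 3, 2, 8
dense = torch.nn.Linear(hidden_size, hidden_size)
dropout = torch.nn.Dropout(0.1)
out_proj = torch.nn.Linear(hidden_size, num_labels)

last_hidden_states = torch.randn(batch, seq_len, hidden_size)
pooled = dropout(last_hidden_states[:, 0, :])   # first-token pooling
pooled = dropout(torch.tanh(dense(pooled)))
logits = out_proj(pooled)                       # (batch, num_labels)

loss = torch.nn.CrossEntropyLoss()(logits, torch.tensor([0, 2]))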