fix mask and position bug for batch generation
#35 by qingsonglv · opened
Files changed: modeling_chatglm.py (+27 -6)
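The fix targets batched generation over left-padded inputs: without it, leading pad tokens are attended to and position ids are not shifted by the pad length, so outputs from a padded batch drift away from the same prompts run one at a time. A minimal repro sketch follows; the checkpoint name, the left-padding setup, and the generation arguments are assumptions for illustration, not part of this patch.

# Hypothetical repro: batched generation over a left-padded batch.
# Checkpoint name and padding setup are assumptions, not part of the patch.
import torch
from transformers import AutoTokenizer, AutoModel

tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True).half().cuda()
model.eval()

prompts = ["Short prompt", "A much longer prompt that forces the shorter one to be padded"]
tokenizer.padding_side = "left"  # leading pads are what get_pad_length() counts and skips
batch = tokenizer(prompts, return_tensors="pt", padding=True)
input_ids = batch["input_ids"].to(model.device)

with torch.no_grad():
    outputs = model.generate(input_ids, max_new_tokens=64, do_sample=False)

for prompt, output in zip(prompts, outputs):
    print(prompt, "->", tokenizer.decode(output, skip_special_tokens=True))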
modeling_chatglm.py
CHANGED
@@ -662,6 +662,12 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         """Initialize the weights."""
         return
 
+    def get_pad_length(self, seq):
+        l = 0
+        while l < len(seq) and seq[l] == self.config.pad_token_id:
+            l += 1
+        return l
+
     def get_masks(self, input_ids, device):
         batch_size, seq_length = input_ids.shape
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
@@ -669,6 +675,10 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
         attention_mask.tril_()
         for i, context_length in enumerate(context_lengths):
             attention_mask[i, :, :context_length] = 1
+        pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
+        for i, pad_length in enumerate(pad_lengths):
+            attention_mask[i, :, :pad_length] = 0
+            attention_mask[i, :pad_length, :] = 0
         attention_mask.unsqueeze_(1)
         attention_mask = (attention_mask < 0.5).bool()
 
@@ -676,16 +686,22 @@ class ChatGLMPreTrainedModel(PreTrainedModel):
 
     def get_position_ids(self, input_ids, mask_positions, device, gmask=False):
         batch_size, seq_length = input_ids.shape
+        pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
         context_lengths = [seq.tolist().index(self.config.bos_token_id) for seq in input_ids]
         if self.position_encoding_2d:
-            position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
-            for i, context_length in enumerate(context_lengths):
-                position_ids[i, context_length:] = mask_positions[i]
+            position_ids = [torch.arange(seq_length - pad_length, dtype=torch.long, device=device) for pad_length in pad_lengths]
+            for i, (context_length, pad_length) in enumerate(zip(context_lengths, pad_lengths)):
+                position_ids[i][context_length - pad_length:] = mask_positions[i] - pad_length
             block_position_ids = [torch.cat((
                 torch.zeros(context_length, dtype=torch.long, device=device),
                 torch.arange(seq_length - context_length, dtype=torch.long, device=device) + 1
             )) for context_length in context_lengths]
             block_position_ids = torch.stack(block_position_ids, dim=0)
+            position_ids = [torch.cat((
+                torch.zeros(pad_length, dtype=torch.long, device=device),
+                range_pos
+            )) for pad_length, range_pos in zip(pad_lengths, position_ids)]
+            position_ids = torch.stack(position_ids, dim=0)
             position_ids = torch.stack((position_ids, block_position_ids), dim=1)
         else:
             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0).repeat(batch_size, 1)
@@ -1094,15 +1110,20 @@ class ChatGLMForConditionalGeneration(ChatGLMPreTrainedModel):
             if attention_mask is not None and attention_mask.dtype == torch.bool:
                 attention_mask = attention_mask[:, :, -1:]
             else:
-                attention_mask = None
+                attention_mask = self.get_masks(
+                    input_ids,
+                    device=input_ids.device
+                )
+                attention_mask = attention_mask[:, :, -1:]
             if position_ids is not None:
                 position_ids = position_ids[..., -1:]
             else:
+                pad_lengths = [self.get_pad_length(seq.tolist()) for seq in input_ids]
                 context_lengths = [seq.index(self.config.bos_token_id) for seq in seqs]
                 if self.position_encoding_2d:
                     position_ids = torch.tensor(
-                        [[mask_position, seq_length - context_length] for mask_position, context_length in
-                         zip(mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
+                        [[mask_position - pad_length, seq_length - context_length] for pad_length, mask_position, context_length in
+                         zip(pad_lengths, mask_positions, context_lengths)], dtype=torch.long, device=input_ids.device).unsqueeze(-1)
                 else:
                     position_ids = torch.tensor([mask_position for mask_position in mask_positions], dtype=torch.long,
                                                 device=input_ids.device).unsqueeze(-1)
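For intuition, the sketch below re-implements just the pad-aware part of the logic above on a toy batch: counting leading pads, blanking the pad rows and columns of the causal mask, and restarting the position ids after the pads. The token id values and helper names are made up for illustration; the 2D block positions and the mask_positions handling are omitted.

# Self-contained illustration of the pad-aware mask / position logic above.
import torch

pad_id, bos_id = 3, 130004   # arbitrary ids for illustration
# Two left-padded sequences of equal length; row 0 carries two leading pads.
input_ids = torch.tensor([
    [pad_id, pad_id, 11, 12, bos_id, 21],
    [13, 14, 15, 16, bos_id, 22],
])
batch_size, seq_length = input_ids.shape

def leading_pads(seq):
    # Mirrors get_pad_length(): count pad tokens at the start of the sequence.
    n = 0
    while n < len(seq) and seq[n] == pad_id:
        n += 1
    return n

pad_lengths = [leading_pads(seq.tolist()) for seq in input_ids]       # [2, 0]
context_lengths = [seq.tolist().index(bos_id) for seq in input_ids]   # [4, 4]

# Causal mask, fully visible prefix, then zero out the rows/columns of the pads.
attention_mask = torch.ones(batch_size, seq_length, seq_length).tril_()
for i, context_length in enumerate(context_lengths):
    attention_mask[i, :, :context_length] = 1
for i, pad_length in enumerate(pad_lengths):
    attention_mask[i, :, :pad_length] = 0
    attention_mask[i, :pad_length, :] = 0
attention_mask = (attention_mask.unsqueeze(1) < 0.5)   # True means "masked out"

# Position ids restart after the pads, so both rows count 0, 1, 2, ... from their first real token.
position_ids = torch.stack([
    torch.cat((torch.zeros(pad_length, dtype=torch.long),
               torch.arange(seq_length - pad_length, dtype=torch.long)))
    for pad_length in pad_lengths
])
print(pad_lengths)    # [2, 0]
print(position_ids)   # [[0, 0, 0, 1, 2, 3], [0, 1, 2, 3, 4, 5]]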