Fixes exceeding maximum sequence length when using generate().
modeling_phi.py  (+16 -8)
@@ -481,7 +481,7 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
     num_heads, head_dim = kv.shape[-2:]

     if layer_idx not in inference_params.key_value_memory_dict:
-        kv_cache = torch.empty(
+        inference_params.key_value_memory_dict[layer_idx] = torch.empty(
             inference_params.max_batch_size,
             inference_params.max_seqlen,
             2,
@@ -490,9 +490,6 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
             dtype=kv.dtype,
             device=kv.device,
         )
-        inference_params.key_value_memory_dict[layer_idx] = kv_cache
-    else:
-        kv_cache = inference_params.key_value_memory_dict[layer_idx]

     batch_start = inference_params.batch_size_offset
     batch_end = batch_start + kv.shape[0]
@@ -500,9 +497,14 @@ def _update_kv_cache(kv: torch.FloatTensor, inference_params: InferenceParams, layer_idx: int) -> torch.FloatTensor:
     sequence_start = inference_params.seqlen_offset
     sequence_end = sequence_start + kv.shape[1]

-    kv_cache[batch_start:batch_end, sequence_start:sequence_end, ...] = kv
-    kv = kv_cache[batch_start:batch_end, :sequence_end, ...]
+    # When the current sequence length is equal to or larger than the maximum sequence length,
+    # we need to roll the cache to the left and update it
+    if sequence_end >= inference_params.max_seqlen:
+        inference_params.key_value_memory_dict[layer_idx] = inference_params.key_value_memory_dict[layer_idx].roll(-(sequence_end - sequence_start), 1)

+    inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, sequence_start:sequence_end, ...] = kv
+    kv = inference_params.key_value_memory_dict[layer_idx][batch_start:batch_end, :sequence_end, ...]
+
     return kv
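The roll keeps each layer's cache tensor at a fixed max_seqlen length: once a write would reach the end of the buffer, the whole cache is shifted left along the sequence dimension so the oldest entries fall off and the newest key/value pair lands in the last slot. Below is a minimal, self-contained sketch of that roll-and-write pattern; toy_cache, write_kv, and the scalar-per-slot layout are illustrative stand-ins, not code from modeling_phi.py.

import torch

max_seqlen = 4
toy_cache = torch.zeros(1, max_seqlen, dtype=torch.long)  # (batch, seqlen); one scalar stands in for each kv slot

def write_kv(cache, kv, seqlen_offset):
    sequence_start = seqlen_offset
    sequence_end = sequence_start + kv.shape[1]
    # Same guard as the diff: once the write would touch the end of the buffer,
    # shift everything left so the oldest entries fall off the front.
    if sequence_end >= max_seqlen:
        cache = cache.roll(-(sequence_end - sequence_start), 1)
    cache[:, sequence_start:sequence_end] = kv
    return cache

for step, token in enumerate(torch.arange(1, 8)):
    # Mirrors prepare_inputs_for_generation below: the offset is capped at max_seqlen - 1.
    offset = min(step, max_seqlen - 1)
    toy_cache = write_kv(toy_cache, token.view(1, 1), offset)

print(toy_cache)  # tensor([[4, 5, 6, 7]]): only the most recent max_seqlen tokens survive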
@@ -710,7 +712,6 @@ class MHA(nn.Module):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
-        # TODO: Need an alternative way for dynamic control flow: torch.any(~attention_mask.bool())
         if attention_mask is not None:
             attention_mask = attention_mask.bool()
         else:
@@ -863,6 +864,13 @@ class PhiPreTrainedModel(PreTrainedModel):
         attention_mask: Optional[Union[torch.LongTensor, torch.BoolTensor]] = None,
         **kwargs,
     ) -> Dict[str, Any]:
+        # Truncate `input_ids` and `attention_mask` (if necessary) to prevent exceeding
+        # the maximum sequence length
+        if input_ids.shape[1] > self.config.n_positions:
+            input_ids = input_ids[:, -self.config.n_positions :]
+            if attention_mask is not None:
+                attention_mask = attention_mask[:, -self.config.n_positions :]
+
         if past_key_values is None or not (isinstance(past_key_values, InferenceParams)):
             past_key_values = InferenceParams(
                 max_seqlen=self.config.n_positions,
@@ -874,7 +882,7 @@ class PhiPreTrainedModel(PreTrainedModel):
             )
         else:
             # Assume that `past_key_values` has cached all tokens up to the last token in `input_ids`
-            past_key_values.seqlen_offset = len(input_ids[0]) - 1
+            past_key_values.seqlen_offset = input_ids.shape[1] - 1
             input_ids = input_ids[:, -1].unsqueeze(-1)

         return {
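The truncation in prepare_inputs_for_generation keeps the inputs consistent with the rolled cache: only the most recent n_positions tokens are handed to the model, so seqlen_offset is capped at n_positions - 1 and the cache write above never runs past max_seqlen. A rough sketch of that arithmetic on a toy batch follows; n_positions = 6 and the tensors below are placeholders, not the model's real configuration.

import torch

n_positions = 6
input_ids = torch.arange(10).unsqueeze(0)      # 10 tokens generated so far, batch size 1
attention_mask = torch.ones_like(input_ids)

# Keep only the most recent n_positions tokens, mirroring the diff above.
if input_ids.shape[1] > n_positions:
    input_ids = input_ids[:, -n_positions:]
    attention_mask = attention_mask[:, -n_positions:]

# The cache is assumed to already hold every token except the newest one.
seqlen_offset = input_ids.shape[1] - 1

print(input_ids)      # tensor([[4, 5, 6, 7, 8, 9]])
print(seqlen_offset)  # 5, i.e. n_positions - 1 once truncation kicks in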
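A hypothetical end-to-end check of the fix, assuming this modeling_phi.py ships with a causal-LM checkpoint loaded via trust_remote_code; the checkpoint path is a placeholder, not a reference to a specific repo.

from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "path/to/phi-checkpoint"  # placeholder: any checkpoint that bundles this modeling_phi.py
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(checkpoint, trust_remote_code=True)

inputs = tokenizer("def fizzbuzz(n):", return_tensors="pt")

# Ask for a total length beyond config.n_positions: previously this would run past the
# fixed-size KV cache; with the rolled cache and truncated inputs it should now complete.
outputs = model.generate(**inputs, max_length=model.config.n_positions + 64)
print(tokenizer.decode(outputs[0]))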