| |
| |
| |
| |
| @@ -165,6 +165,8 @@ def __init__( |
| ) |
| torch._dynamo.mark_static_address(conv_state) |
| self.conv_cache.append(conv_state) |
| + self.key_cache.append(torch.tensor([])) |
| + self.value_cache.append(torch.tensor([])) |
| |
| def update( |
| self, |
| @@ -190,35 +192,27 @@ def update( |
| A tuple containing the updated key and value states. |
| """ |
| # Update the cache |
| - if key_states is not None: |
| - if len(self.key_cache) <= layer_idx: |
| - # There may be skipped layers, fill them with empty lists |
| - for _ in range(len(self.key_cache), layer_idx): |
| - self.key_cache.append(torch.tensor([])) |
| - self.value_cache.append(torch.tensor([])) |
| - self.key_cache.append(key_states) |
| - self.value_cache.append(value_states) |
| - elif ( |
| - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model |
| - ): # fills previously skipped layers; checking for tensor causes errors |
| - self.key_cache[layer_idx] = key_states |
| - self.value_cache[layer_idx] = value_states |
| - else: |
| - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
| - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
| + if self.key_cache[layer_idx].numel() == 0: |
| + self.key_cache[layer_idx] = key_states |
| + self.value_cache[layer_idx] = value_states |
| + else: |
| + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
| + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
| |
| return self.key_cache[layer_idx], self.value_cache[layer_idx] |
| |
| def reorder_cache(self, beam_idx: torch.LongTensor): |
| """Reorders the cache for beam search, given the selected beam indices.""" |
| for layer_idx in range(len(self.key_cache)): |
| - device = self.key_cache[layer_idx].device |
| - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| - device = self.value_cache[layer_idx].device |
| - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| - |
| - device = self.conv_cache[layer_idx].device |
| - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| + if self.key_cache[layer_idx].numel(): |
| + device = self.key_cache[layer_idx].device |
| + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| + device = self.value_cache[layer_idx].device |
| + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| + |
| + if self.conv_cache[layer_idx].numel(): |
| + device = self.conv_cache[layer_idx].device |
| + self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| |
| def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
| """Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
| |
| |
| |
| |
| @@ -123,6 +123,8 @@ def __init__( |
| ) |
| torch._dynamo.mark_static_address(conv_state) |
| self.conv_cache.append(conv_state) |
| + self.key_cache.append(torch.tensor([])) |
| + self.value_cache.append(torch.tensor([])) |
| |
| def update( |
| self, |
| @@ -148,35 +150,27 @@ def update( |
| A tuple containing the updated key and value states. |
| """ |
| # Update the cache |
| - if key_states is not None: |
| - if len(self.key_cache) <= layer_idx: |
| - # There may be skipped layers, fill them with empty lists |
| - for _ in range(len(self.key_cache), layer_idx): |
| - self.key_cache.append(torch.tensor([])) |
| - self.value_cache.append(torch.tensor([])) |
| - self.key_cache.append(key_states) |
| - self.value_cache.append(value_states) |
| - elif ( |
| - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model |
| - ): # fills previously skipped layers; checking for tensor causes errors |
| - self.key_cache[layer_idx] = key_states |
| - self.value_cache[layer_idx] = value_states |
| - else: |
| - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
| - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
| + if self.key_cache[layer_idx].numel() == 0: |
| + self.key_cache[layer_idx] = key_states |
| + self.value_cache[layer_idx] = value_states |
| + else: |
| + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) |
| + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) |
| |
| return self.key_cache[layer_idx], self.value_cache[layer_idx] |
| |
| def reorder_cache(self, beam_idx: torch.LongTensor): |
| """Reorders the cache for beam search, given the selected beam indices.""" |
| for layer_idx in range(len(self.key_cache)): |
| - device = self.key_cache[layer_idx].device |
| - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| - device = self.value_cache[layer_idx].device |
| - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| - |
| - device = self.conv_cache[layer_idx].device |
| - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| + if self.key_cache[layer_idx].numel(): |
| + device = self.key_cache[layer_idx].device |
| + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| + device = self.value_cache[layer_idx].device |
| + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| + |
| + if self.conv_cache[layer_idx].numel(): |
| + device = self.conv_cache[layer_idx].device |
| + self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) |
| |
| def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: |
| """Returns the sequence length of the cached states. A layer index can be optionally passed.""" |
|
|