diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py index 0674cd18319a..07aced67f650 100644 --- a/src/transformers/models/lfm2/modeling_lfm2.py +++ b/src/transformers/models/lfm2/modeling_lfm2.py @@ -165,6 +165,8 @@ def __init__( ) torch._dynamo.mark_static_address(conv_state) self.conv_cache.append(conv_state) + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) def update( self, @@ -190,35 +192,27 @@ def update( A tuple containing the updated key and value states. """ # Update the cache - if key_states is not None: - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + if self.key_cache[layer_idx].numel() == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + if self.conv_cache[layer_idx].numel(): + device = self.conv_cache[layer_idx].device + self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed.""" diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py index 5832a4d457a0..16a69fa0dc06 100644 --- a/src/transformers/models/lfm2/modular_lfm2.py +++ b/src/transformers/models/lfm2/modular_lfm2.py @@ -123,6 +123,8 @@ def __init__( ) torch._dynamo.mark_static_address(conv_state) self.conv_cache.append(conv_state) + self.key_cache.append(torch.tensor([])) + self.value_cache.append(torch.tensor([])) def update( self, @@ -148,35 +150,27 @@ def update( A tuple containing the updated key and value states. """ # Update the cache - if key_states is not None: - if len(self.key_cache) <= layer_idx: - # There may be skipped layers, fill them with empty lists - for _ in range(len(self.key_cache), layer_idx): - self.key_cache.append(torch.tensor([])) - self.value_cache.append(torch.tensor([])) - self.key_cache.append(key_states) - self.value_cache.append(value_states) - elif ( - not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model - ): # fills previously skipped layers; checking for tensor causes errors - self.key_cache[layer_idx] = key_states - self.value_cache[layer_idx] = value_states - else: - self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) - self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) + if self.key_cache[layer_idx].numel() == 0: + self.key_cache[layer_idx] = key_states + self.value_cache[layer_idx] = value_states + else: + self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2) + self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2) return self.key_cache[layer_idx], self.value_cache[layer_idx] def reorder_cache(self, beam_idx: torch.LongTensor): """Reorders the cache for beam search, given the selected beam indices.""" for layer_idx in range(len(self.key_cache)): - device = self.key_cache[layer_idx].device - self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) - device = self.value_cache[layer_idx].device - self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) - - device = self.conv_cache[layer_idx].device - self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) + if self.key_cache[layer_idx].numel(): + device = self.key_cache[layer_idx].device + self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device)) + device = self.value_cache[layer_idx].device + self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device)) + + if self.conv_cache[layer_idx].numel(): + device = self.conv_cache[layer_idx].device + self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device)) def get_seq_length(self, layer_idx: Optional[int] = 0) -> int: """Returns the sequence length of the cached states. A layer index can be optionally passed."""