harness / diffs /41396.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/models/lfm2/modeling_lfm2.py b/src/transformers/models/lfm2/modeling_lfm2.py
index 0674cd18319a..07aced67f650 100644
--- a/src/transformers/models/lfm2/modeling_lfm2.py
+++ b/src/transformers/models/lfm2/modeling_lfm2.py
@@ -165,6 +165,8 @@ def __init__(
)
torch._dynamo.mark_static_address(conv_state)
self.conv_cache.append(conv_state)
+ self.key_cache.append(torch.tensor([]))
+ self.value_cache.append(torch.tensor([]))
def update(
self,
@@ -190,35 +192,27 @@ def update(
A tuple containing the updated key and value states.
"""
# Update the cache
- if key_states is not None:
- if len(self.key_cache) <= layer_idx:
- # There may be skipped layers, fill them with empty lists
- for _ in range(len(self.key_cache), layer_idx):
- self.key_cache.append(torch.tensor([]))
- self.value_cache.append(torch.tensor([]))
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
- elif (
- not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
- ): # fills previously skipped layers; checking for tensor causes errors
- self.key_cache[layer_idx] = key_states
- self.value_cache[layer_idx] = value_states
- else:
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ if self.key_cache[layer_idx].numel() == 0:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
- device = self.key_cache[layer_idx].device
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
- device = self.value_cache[layer_idx].device
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
- device = self.conv_cache[layer_idx].device
- self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
+ if self.key_cache[layer_idx].numel():
+ device = self.key_cache[layer_idx].device
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+ device = self.value_cache[layer_idx].device
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+ if self.conv_cache[layer_idx].numel():
+ device = self.conv_cache[layer_idx].device
+ self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""
diff --git a/src/transformers/models/lfm2/modular_lfm2.py b/src/transformers/models/lfm2/modular_lfm2.py
index 5832a4d457a0..16a69fa0dc06 100644
--- a/src/transformers/models/lfm2/modular_lfm2.py
+++ b/src/transformers/models/lfm2/modular_lfm2.py
@@ -123,6 +123,8 @@ def __init__(
)
torch._dynamo.mark_static_address(conv_state)
self.conv_cache.append(conv_state)
+ self.key_cache.append(torch.tensor([]))
+ self.value_cache.append(torch.tensor([]))
def update(
self,
@@ -148,35 +150,27 @@ def update(
A tuple containing the updated key and value states.
"""
# Update the cache
- if key_states is not None:
- if len(self.key_cache) <= layer_idx:
- # There may be skipped layers, fill them with empty lists
- for _ in range(len(self.key_cache), layer_idx):
- self.key_cache.append(torch.tensor([]))
- self.value_cache.append(torch.tensor([]))
- self.key_cache.append(key_states)
- self.value_cache.append(value_states)
- elif (
- not self.key_cache[layer_idx].numel() # prefers not t.numel() to len(t) == 0 to export the model
- ): # fills previously skipped layers; checking for tensor causes errors
- self.key_cache[layer_idx] = key_states
- self.value_cache[layer_idx] = value_states
- else:
- self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
- self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
+ if self.key_cache[layer_idx].numel() == 0:
+ self.key_cache[layer_idx] = key_states
+ self.value_cache[layer_idx] = value_states
+ else:
+ self.key_cache[layer_idx] = torch.cat([self.key_cache[layer_idx], key_states], dim=-2)
+ self.value_cache[layer_idx] = torch.cat([self.value_cache[layer_idx], value_states], dim=-2)
return self.key_cache[layer_idx], self.value_cache[layer_idx]
def reorder_cache(self, beam_idx: torch.LongTensor):
"""Reorders the cache for beam search, given the selected beam indices."""
for layer_idx in range(len(self.key_cache)):
- device = self.key_cache[layer_idx].device
- self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
- device = self.value_cache[layer_idx].device
- self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
- device = self.conv_cache[layer_idx].device
- self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
+ if self.key_cache[layer_idx].numel():
+ device = self.key_cache[layer_idx].device
+ self.key_cache[layer_idx] = self.key_cache[layer_idx].index_select(0, beam_idx.to(device))
+ device = self.value_cache[layer_idx].device
+ self.value_cache[layer_idx] = self.value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+ if self.conv_cache[layer_idx].numel():
+ device = self.conv_cache[layer_idx].device
+ self.conv_cache[layer_idx] = self.conv_cache[layer_idx].index_select(0, beam_idx.to(device))
def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
"""Returns the sequence length of the cached states. A layer index can be optionally passed."""