harness / diffs /39028.patch
ArthurZ's picture
ArthurZ HF Staff
Initial harness: 100 perf tasks + Gradio browser
dfefe0b verified
diff --git a/src/transformers/models/granite_speech/modeling_granite_speech.py b/src/transformers/models/granite_speech/modeling_granite_speech.py
index d30254ca62af..6e61f732b77b 100644
--- a/src/transformers/models/granite_speech/modeling_granite_speech.py
+++ b/src/transformers/models/granite_speech/modeling_granite_speech.py
@@ -159,8 +159,12 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) ->
# shaw's relative positional embedding
dist = attention_dists.to(hidden_states.device)
rel_pos_emb = self.rel_pos_emb(dist)
- rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
- pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+ # alternative computation of `pos_attn` - for readability
+ # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape))
+ # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale
+ # einsum implementation of pos_attn - gives x30 speedup over the alternative
+ # TODO (@avihu111) find a fast alternative to einsum
+ pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale
if remainder > 0:
# masked attention in the extended block
@@ -541,17 +545,34 @@ def generate(self, *args, **kwargs) -> torch.LongTensor:
self.disable_adapters()
return super().generate(*args, input_features=input_features, **kwargs)
- def save_pretrained(self, *args, **kwargs):
+ def save_pretrained(self, save_directory, *args, **kwargs):
# overwrite save_pretrained to first save the adapter if we have one
- # NOTE - this will use the base model path we are exporting in the lora
- # adapter, which may not necessarily be the best behavior, but for now
- # we keep this for portability, since using the local dir causes problems
- # if the model is loaded from outside of the current working dir.
if is_peft_available and self._hf_peft_config_loaded:
- super().save_pretrained(*args, **kwargs)
+ adapter_name = self._get_adapter_name()
+ self.peft_config[adapter_name].base_model_name_or_path = save_directory
+ super().save_pretrained(save_directory, *args, **kwargs)
# Then save the base model afterwards
+ prev_val = self._hf_peft_config_loaded
self._hf_peft_config_loaded = False
- super().save_pretrained(*args, **kwargs)
+ super().save_pretrained(save_directory, *args, **kwargs)
+ self._hf_peft_config_loaded = prev_val
+
+ @staticmethod
+ def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
+ # save the model with the original weights format
+ return key.replace(".base_layer", ""), False
+
+ def _fix_state_dict_keys_on_save(self, state_dict):
+ if is_peft_available and self._hf_peft_config_loaded:
+ # state dict is only adapter, should keep the same
+ return state_dict
+ # rename back the base model state dict
+ return {
+ self._fix_state_dict_key_on_save(key)[0]: value for key, value in state_dict.items() if ".lora_" not in key
+ }
+
+ def _get_adapter_name(self):
+ return list(self.peft_config.keys())[0]
__all__ = [