| |
| |
| |
| |
| @@ -159,8 +159,12 @@ def forward(self, hidden_states: torch.Tensor, attention_dists: torch.Tensor) -> |
| # shaw's relative positional embedding |
| dist = attention_dists.to(hidden_states.device) |
| rel_pos_emb = self.rel_pos_emb(dist) |
| - rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) |
| - pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale |
| + # alternative computation of `pos_attn` - for readability |
| + # rel_pos_emb_expanded = rel_pos_emb.view([1, 1, 1] + list(rel_pos_emb.shape)) |
| + # pos_attn = torch.sum(query_states.unsqueeze(-2) * rel_pos_emb_expanded, dim=-1) * self.scale |
| + # einsum implementation of pos_attn - gives x30 speedup over the alternative |
| + # TODO (@avihu111) find a fast alternative to einsum |
| + pos_attn = torch.einsum("b m h c d, c r d -> b m h c r", query_states, rel_pos_emb) * self.scale |
| |
| if remainder > 0: |
| # masked attention in the extended block |
| @@ -541,17 +545,34 @@ def generate(self, *args, **kwargs) -> torch.LongTensor: |
| self.disable_adapters() |
| return super().generate(*args, input_features=input_features, **kwargs) |
| |
def save_pretrained(self, save_directory, *args, **kwargs):
    """Save the adapter first (when one is loaded), then the base model.

    Args:
        save_directory: Directory passed to the parent ``save_pretrained``.
            When an adapter is loaded it is also recorded as that adapter's
            ``base_model_name_or_path``.
        *args: Forwarded unchanged to the parent ``save_pretrained``.
        **kwargs: Forwarded unchanged to the parent ``save_pretrained``.
    """
    # overwrite save_pretrained to first save the adapter if we have one
    # NOTE(review): `is_peft_available` is referenced here without being
    # called; if it is a function this operand is always truthy and the
    # check reduces to `_hf_peft_config_loaded` -- confirm the intent.
    if is_peft_available and self._hf_peft_config_loaded:
        adapter_name = self._get_adapter_name()
        # Record the export directory as the adapter's base model path.
        self.peft_config[adapter_name].base_model_name_or_path = save_directory
        super().save_pretrained(save_directory, *args, **kwargs)

    # Then save the base model afterwards, with the adapter flag cleared for
    # the duration of the call.
    prev_val = self._hf_peft_config_loaded
    self._hf_peft_config_loaded = False
    try:
        super().save_pretrained(save_directory, *args, **kwargs)
    finally:
        # Restore the flag even if saving fails, so an error during export
        # does not leave the instance looking like it has no adapter loaded.
        self._hf_peft_config_loaded = prev_val
| + |
@staticmethod
def _fix_state_dict_key_on_save(key) -> tuple[str, bool]:
    """Return ``key`` in the original weights format.

    Every occurrence of ``".base_layer"`` is removed from ``key``. The second
    element of the returned tuple is always ``False``.
    """
    original_format_key = key.replace(".base_layer", "")
    return (original_format_key, False)
| + |
def _fix_state_dict_keys_on_save(self, state_dict):
    """Return the state dict with the key names that should be written.

    When an adapter is loaded, ``state_dict`` is returned unchanged. Otherwise
    a new dict is built: entries whose key contains ``".lora_"`` are dropped
    and each remaining key is renamed via ``_fix_state_dict_key_on_save``.
    """
    adapter_loaded = is_peft_available and self._hf_peft_config_loaded
    if adapter_loaded:
        # state dict is only adapter, should keep the same
        return state_dict

    # rename back the base model state dict
    renamed_state_dict = {}
    for name, tensor in state_dict.items():
        if ".lora_" in name:
            continue
        restored_name, _ = self._fix_state_dict_key_on_save(name)
        renamed_state_dict[restored_name] = tensor
    return renamed_state_dict
| + |
def _get_adapter_name(self):
    """Return the first key of ``self.peft_config``.

    Raises:
        IndexError: If ``self.peft_config`` has no entries.
    """
    adapter_names = list(self.peft_config.keys())
    return adapter_names[0]
| |
| |
| __all__ = [ |
|
|