| from typing import Any, Dict, Optional |
|
|
| import torch |
| from transformers import GenerationMixin |
| from transformers.cache_utils import Cache |
| from transformers.utils import ModelOutput |
|
|
|
|
class VoraGenerationMixin(GenerationMixin):
    """Generation mixin that supports pre-built 4D additive attention masks.

    HuggingFace's generation loop assumes a 2D padding mask of shape
    ``(batch, kv_len)``.  This mixin lets callers instead supply a 4D
    additive mask of shape ``(batch, 1, q_len, kv_len)`` (0 = visible,
    large negative = masked): a 2D summary is derived for the parent
    class's bookkeeping, and the 4D mask itself is grown by one query
    row / key column per generated token.
    """

    def prepare_inputs_for_generation(
        self,
        input_ids: torch.LongTensor,
        past_key_values: Optional[Cache] = None,
        attention_mask: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ):
        """Prepare model inputs, passing a 4D mask through unchanged.

        If ``attention_mask`` is 4D, a 2D summary (a key position counts as
        attended when any query row can see it, i.e. its additive value is
        0 somewhere in that column) is handed to the parent implementation,
        and the original 4D mask is restored in the returned dict so the
        forward pass receives it intact.  2D/None masks are forwarded as-is.
        """
        is_4d = attention_mask is not None and attention_mask.ndim == 4
        # The parent implementation only understands 2D padding masks, so
        # collapse head dim 0 of the 4D mask to (batch, kv_len) for its
        # internal bookkeeping.  The derived tensor already lives on the
        # mask's device, so no explicit .to() is needed.
        mask_for_parent = (
            (attention_mask[:, 0, :, :] == 0).any(dim=1).long()
            if is_4d
            else attention_mask
        )
        model_input = super().prepare_inputs_for_generation(
            input_ids,
            past_key_values=past_key_values,
            attention_mask=mask_for_parent,
            inputs_embeds=inputs_embeds,
            cache_position=cache_position,
            **kwargs,
        )
        if is_4d:
            # Restore the full 4D mask for the actual forward pass.
            model_input['attention_mask'] = attention_mask
        return model_input

    def _update_model_kwargs_for_generation(
        self,
        outputs: ModelOutput,
        model_kwargs: Dict[str, Any],
        is_encoder_decoder: bool = False,
        num_new_tokens: int = 1,
    ) -> Dict[str, Any]:
        """Advance generation kwargs, growing a 4D mask by one step.

        The parent update would append a column of ones to a 2D mask, which
        is wrong for a 4D mask.  A 4D mask is therefore popped first, the
        parent update runs without it, and the mask — extended by one query
        row and one key column — is then re-inserted.
        """
        attention_mask = model_kwargs.get("attention_mask")
        if attention_mask is None or attention_mask.ndim != 4:
            return super()._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=is_encoder_decoder,
                num_new_tokens=num_new_tokens,
            )
        del model_kwargs["attention_mask"]
        model_kwargs = super()._update_model_kwargs_for_generation(
            outputs,
            model_kwargs,
            is_encoder_decoder=is_encoder_decoder,
            num_new_tokens=num_new_tokens,
        )
        bs, _, q_len, kv_len = attention_mask.shape
        min_dtype = torch.finfo(attention_mask.dtype).min
        # Existing query rows must not see the newly generated key (causal),
        # so the appended key column is fully masked.
        new_key_col = attention_mask.new_full((bs, 1, q_len, 1), min_dtype)
        # The new query row attends to every key generated so far.
        # NOTE(review): exactly one row/column is appended regardless of
        # num_new_tokens — confirm callers always decode one token per step.
        new_query_row = attention_mask.new_zeros((bs, 1, 1, kv_len + 1))
        model_kwargs["attention_mask"] = torch.cat(
            [torch.cat([attention_mask, new_key_col], dim=-1), new_query_row],
            dim=2,
        )
        return model_kwargs
|
|
|
|
def custom_prepare_4d_causal_attention_mask_with_cache_position(
    attention_mask: torch.Tensor,
    sequence_length: int,
    target_length: int,
    dtype: torch.dtype,
    device: torch.device,
    cache_position: torch.Tensor,
    batch_size: int,
    **kwargs,
):
    """Build a 4D additive causal mask of shape (batch, 1, seq, target).

    If ``attention_mask`` is already 4D it is assumed to be a pre-built
    additive mask and is only sliced down to the current query/key window.
    Otherwise a causal mask gated by ``cache_position`` is constructed,
    and — when a 2D padding mask is given — padded key positions are
    additionally filled with the dtype's minimum value.
    """
    # Fast path: caller supplied a ready-made 4D mask — just trim it.
    if attention_mask is not None and attention_mask.dim() == 4:
        return attention_mask[:, :, -sequence_length:, -target_length:]

    min_dtype = torch.finfo(dtype).min
    mask = torch.full(
        (sequence_length, target_length), min_dtype, dtype=dtype, device=device
    )
    if sequence_length != 1:
        # Strictly-upper triangle stays masked; diagonal and below open up.
        mask = torch.triu(mask, diagonal=1)
    # Keys beyond each query's cache position remain masked; earlier keys
    # (including cached ones) are zeroed out, i.e. visible.
    key_positions = torch.arange(target_length, device=device)
    mask = mask * (key_positions > cache_position.reshape(-1, 1))
    mask = mask[None, None, :, :].expand(batch_size, 1, -1, -1)
    if attention_mask is not None:
        # 2D padding mask: mask out keys where both the causal entry is 0
        # (visible) and the padding entry is 0 (padded), i.e. the sum is 0.
        mask = mask.clone()  # expand() gives a view; clone before writing
        key_len = attention_mask.shape[-1]
        padded = (mask[..., :key_len] + attention_mask[:, None, None, :]) == 0
        mask[..., :key_len] = mask[..., :key_len].masked_fill(padded, min_dtype)
    return mask
|
|