from typing import Unpack

import torch

from transformers import (
    RobertaModel,
    Cache,
    EncoderDecoderCache,
    DynamicCache,
    DataCollatorWithFlattening,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaForQuestionAnswering,
    RobertaForMultipleChoice,
    RobertaForCausalLM,
)
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.utils import TransformersKwargs


def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten a padded batch into a single packed sequence.

    Uses DataCollatorWithFlattening to concatenate the non-padding tokens of every
    example and to emit the kwargs expected by the flash-attention kernels
    (cu_seq_lens_q/k, max_length_q/k) along with per-example position_ids.
    """
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    features = collator(
        [{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)]
    )
    return features


def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter a packed output back into a padded (batch, seqlen, ...) tensor.

    `indices` are the flat positions of the non-padding tokens in the original
    (batch * seqlen) layout; padding positions are filled with zeros.
    """
    if inputs.dim() == 3:
        # The packed encoder output carries a leading batch dimension of 1; drop it
        # explicitly with squeeze(0) so other singleton dimensions are left intact.
        inputs = inputs.squeeze(0)
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs


class UnpadRobertaModel(RobertaModel):
    """RobertaModel variant that unpads (packs) the batch whenever a flash-attention
    backend is active, so no compute is spent on padding tokens; the output is
    re-padded before being returned."""

    _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer=add_pooling_layer)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            device = input_ids.device
            batch_size, seq_length = input_ids.shape[:2]
        else:
            device = inputs_embeds.device
            batch_size, seq_length = inputs_embeds.shape[:2]

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=device
            )

        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Remember where the real (non-padding) tokens sit in the flattened
                # (batch * seq_len) layout so the output can be re-padded later.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            features = _unpad_input(input_ids, attention_mask)
            input_ids = features["input_ids"].to(device=device)

            # RoBERTa position ids start at padding_idx + 1 = 2, while the collator
            # emits positions starting at 0, hence the fixed offset.
            position_ids = (features["position_ids"] + 2).to(device=device)
            attention_mask = None
            kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
            kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
            kwargs["max_length_k"] = features["max_length_k"]
            kwargs["max_length_q"] = features["max_length_q"]

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )

        sequence_output = encoder_outputs.last_hidden_state
        if self.config._attn_implementation.startswith("flash_attention"):
            # Scatter the packed hidden states back into a padded
            # (batch, seq_len, hidden) tensor expected by the task heads.
            sequence_output = _pad_output(
                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
            )

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )


class UnpadRobertaForCausalLM(RobertaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMaskedLM(RobertaForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForSequenceClassification(RobertaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForTokenClassification(RobertaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMultipleChoice(RobertaForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config)
        self.post_init()


class UnpadRobertaForQuestionAnswering(RobertaForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


def enable_roberta_unpadding():
    """Monkey-patch the stock RobertaModel so already-instantiated models and
    pipelines pick up the unpadded forward pass without being rebuilt."""
    RobertaModel.forward = UnpadRobertaModel.forward


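# Minimal usage sketch (illustrative only, not part of the library API): it assumes
# flash-attn 2 is installed, a CUDA device is available, and uses the "roberta-base"
# checkpoint purely as an example; any RoBERTa checkpoint/tokenizer would do.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("roberta-base")
    model = UnpadRobertaForMaskedLM.from_pretrained(
        "roberta-base",
        attn_implementation="flash_attention_2",  # triggers the unpadded code path
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    model.eval()

    # A deliberately unbalanced batch: padding tokens are skipped inside the encoder.
    batch = tokenizer(
        ["Unpadding skips the <mask> tokens entirely.", "Short input."],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        logits = model(**batch).logits  # re-padded to (batch, seq_len, vocab_size)
    print(logits.shape)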