# unpad-impl / modeling_roberta.py
from typing import Unpack
import torch
from transformers import (
    RobertaModel,
    Cache,
    EncoderDecoderCache,
    DynamicCache,
    DataCollatorWithFlattening,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaForQuestionAnswering,
    RobertaForMultipleChoice,
    RobertaForCausalLM,
)
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.utils import TransformersKwargs


def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Drop padding tokens and pack every sequence in the batch into one flat row.
    # DataCollatorWithFlattening(return_flash_attn_kwargs=True) also emits the
    # FlashAttention varlen kwargs (cu_seq_lens_q/k, max_length_q/k) and restarts
    # position_ids at 0 for each packed sequence.
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    features = collator([{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)])
    return features


def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    # Scatter unpadded (packed) hidden states back into a zero-padded
    # (batch, seqlen, ...) tensor using the flat indices of the real tokens.
    if inputs.dim() == 3:
        # Packed outputs arrive as (1, total_tokens, hidden); drop only the dummy batch dim.
        inputs = inputs.squeeze(0)
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs
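
# A quick shape sketch (illustrative values): for a batch with attention_mask
# [[1, 1, 1, 0], [1, 1, 0, 0]], _unpad_input packs the 5 real tokens into a single row,
# with per-sequence position_ids [0, 1, 2, 0, 1] and cu_seq_lens_q = cu_seq_lens_k = [0, 3, 5].
# _pad_output reverses the packing, scattering the hidden states back to shape
# (2, 4, hidden) with zeros at the padded slots.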


class UnpadRobertaModel(RobertaModel):
    _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer=add_pooling_layer)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False
        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        if input_ids is not None:
            device = input_ids.device
            seq_length = input_ids.shape[1]
            batch_size = input_ids.size(0)
        else:
            device = inputs_embeds.device
            seq_length = inputs_embeds.shape[1]
            batch_size = inputs_embeds.size(0)
        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(past_key_values_length, past_key_values_length + seq_length, device=device)
        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat indices of the non-padding tokens, used later to re-pad the output.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            # Pack all sequences of the batch into a single row and collect the
            # FlashAttention varlen kwargs produced by the flattening collator.
            features = _unpad_input(input_ids, attention_mask)
            input_ids = features["input_ids"].to(device=device)
            # RoBERTa position ids start at padding_idx + 1 = 2, so shift the collator's
            # zero-based per-sequence position ids by 2.
            position_ids = (features["position_ids"] + 2).to(device=device)
            # With packed sequences, the padding mask is replaced by cu_seq_lens/max_length.
            attention_mask = None
            kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
            kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
            kwargs["max_length_k"] = features["max_length_k"]
            kwargs["max_length_q"] = features["max_length_q"]
        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )
        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )
        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state
        if self.config._attn_implementation.startswith("flash_attention"):
            # Scatter the packed hidden states back to the original (batch, seqlen, hidden)
            # layout so downstream heads see the same shapes as the padded baseline.
            sequence_output = _pad_output(
                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
            )
        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )


class UnpadRobertaForCausalLM(RobertaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        # Swap in the unpadding-aware backbone; the head itself is unchanged.
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMaskedLM(RobertaForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForSequenceClassification(RobertaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForTokenClassification(RobertaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMultipleChoice(RobertaForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        # Multiple choice keeps the pooling layer, matching the stock implementation.
        self.roberta = UnpadRobertaModel(config)
        self.post_init()


class UnpadRobertaForQuestionAnswering(RobertaForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


def enable_roberta_unpadding():
    # Monkey-patch the stock RobertaModel so that every existing Roberta* head class
    # (and any code importing RobertaModel directly) picks up the unpadded forward pass.
    RobertaModel.forward = UnpadRobertaModel.forward
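

# ---------------------------------------------------------------------------
# Illustrative usage sketch. The checkpoint name, dtype, and device below are
# assumptions, not defaults of this module; unpadding only activates when the
# configured attention implementation starts with "flash_attention", which in
# turn needs flash-attn installed and a CUDA GPU with fp16/bf16 weights.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "roberta-base"  # assumption: any RoBERTa-style checkpoint should work
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = UnpadRobertaForMaskedLM.from_pretrained(
        model_name,
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    ).to("cuda").eval()

    batch = tokenizer(
        ["The capital of France is <mask>.", "Unpadding skips the <mask> tokens."],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        # The padded batch is flattened inside UnpadRobertaModel.forward, run through
        # the varlen FlashAttention kernels, and re-padded before the MLM head.
        logits = model(**batch).logits
    print(logits.shape)  # (batch, seqlen, vocab_size), same layout as the padded path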