from typing import Unpack

import torch
from transformers import (
    RobertaModel,
    Cache,
    EncoderDecoderCache,
    DynamicCache,
    DataCollatorWithFlattening,
    RobertaForMaskedLM,
    RobertaForSequenceClassification,
    RobertaForTokenClassification,
    RobertaForQuestionAnswering,
    RobertaForMultipleChoice,
    RobertaForCausalLM,
)
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions
from transformers.utils import TransformersKwargs


def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Flatten the padded batch into one packed sequence and let the collator
    # compute the FlashAttention varlen kwargs (cu_seq_lens_*, max_length_*).
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    features = collator(
        [{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)]
    )
    return features


def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    # Scatter the packed outputs back into a zero-padded (batch, seqlen, ...) tensor,
    # using the flattened indices of the non-padding positions.
    if inputs.dim() == 3:
        inputs = inputs.squeeze()
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs


class UnpadRobertaModel(RobertaModel):
    """RobertaModel whose forward strips padding tokens before attention when a
    FlashAttention implementation is selected, then re-pads the output."""

    _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]

    def __init__(self, config, add_pooling_layer=True):
        super().__init__(config, add_pooling_layer=add_pooling_layer)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        token_type_ids: torch.Tensor | None = None,
        position_ids: torch.Tensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        encoder_hidden_states: torch.Tensor | None = None,
        encoder_attention_mask: torch.Tensor | None = None,
        past_key_values: Cache | None = None,
        use_cache: bool | None = None,
        cache_position: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> tuple[torch.Tensor] | BaseModelOutputWithPoolingAndCrossAttentions:
        if self.config.is_decoder:
            use_cache = use_cache if use_cache is not None else self.config.use_cache
        else:
            use_cache = False

        if use_cache and past_key_values is None:
            past_key_values = (
                EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
                if encoder_hidden_states is not None or self.config.is_encoder_decoder
                else DynamicCache(config=self.config)
            )

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if input_ids is not None:
            device = input_ids.device
            seq_length = input_ids.shape[1]
            batch_size = input_ids.size(0)
        else:
            device = inputs_embeds.device
            seq_length = inputs_embeds.shape[1]
            batch_size = inputs_embeds.size(0)

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=device
            )

        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Remember which flattened positions are real tokens so the
                # output can be re-padded after the encoder.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
                features = _unpad_input(input_ids, attention_mask)
                input_ids = features["input_ids"].to(device=device)
                # roberta requires shifting position_ids by 2
                position_ids = (features["position_ids"] + 2).to(device=device)
                attention_mask = None
                kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
                kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
                kwargs["max_length_k"] = features["max_length_k"]
                kwargs["max_length_q"] = features["max_length_q"]

        embedding_output = self.embeddings(
            input_ids=input_ids,
            position_ids=position_ids,
            token_type_ids=token_type_ids,
            inputs_embeds=inputs_embeds,
            past_key_values_length=past_key_values_length,
        )

        attention_mask, encoder_attention_mask = self._create_attention_masks(
            attention_mask=attention_mask,
            encoder_attention_mask=encoder_attention_mask,
            embedding_output=embedding_output,
            encoder_hidden_states=encoder_hidden_states,
            cache_position=cache_position,
            past_key_values=past_key_values,
        )

        encoder_outputs = self.encoder(
            embedding_output,
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            past_key_values=past_key_values,
            use_cache=use_cache,
            cache_position=cache_position,
            position_ids=position_ids,
            **kwargs,
        )
        sequence_output = encoder_outputs.last_hidden_state

        if self.config._attn_implementation.startswith("flash_attention"):
            # Re-pad the packed hidden states so downstream heads see (batch, seqlen, hidden).
            sequence_output = _pad_output(
                inputs=sequence_output, indices=indices, batch=batch_size, seqlen=seq_length
            )

        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None

        return BaseModelOutputWithPoolingAndCrossAttentions(
            last_hidden_state=sequence_output,
            pooler_output=pooled_output,
            past_key_values=encoder_outputs.past_key_values,
        )


class UnpadRobertaForCausalLM(RobertaForCausalLM):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMaskedLM(RobertaForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForSequenceClassification(RobertaForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForTokenClassification(RobertaForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


class UnpadRobertaForMultipleChoice(RobertaForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config)
        self.post_init()


class UnpadRobertaForQuestionAnswering(RobertaForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.roberta = UnpadRobertaModel(config, add_pooling_layer=False)
        self.post_init()


def enable_roberta_unpadding():
    # Monkey-patch the stock RobertaModel so every RoBERTa head class picks up
    # the unpadded forward without switching to the Unpad* subclasses above.
    RobertaModel.forward = UnpadRobertaModel.forward