from typing import Unpack

import torch
from transformers import (
    DataCollatorWithFlattening,
    ModernBertModel,
    ModernBertConfig,
    ModernBertForMaskedLM,
    ModernBertForSequenceClassification,
    ModernBertForTokenClassification,
    ModernBertForQuestionAnswering,
    ModernBertForMultipleChoice,
)
from transformers.masking_utils import create_bidirectional_mask, create_bidirectional_sliding_window_mask
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import TransformersKwargs


def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten a padded batch into one packed sequence plus FlashAttention kwargs."""
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    # Keep only the non-padding tokens of each row; the collator concatenates
    # them and computes the cu_seq_lens / max_length metadata.
    features = collator(
        [{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)]
    )
    return features


def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter packed token states back into a zero-padded (batch, seqlen, ...) tensor."""
    if inputs.dim() == 3:
        # Drop only the leading batch dimension of 1 added by the collator;
        # a bare squeeze() could also remove a size-1 trailing dimension.
        inputs = inputs.squeeze(0)
    if inputs.dim() == 1:
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs


class UnpadModernBertModel(ModernBertModel):
    def __init__(self, config: ModernBertConfig):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
        seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
        batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat indices of the non-padding positions; used to re-pad the output.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
                features = _unpad_input(input_ids, attention_mask)
                input_ids = features["input_ids"].to(device=device)
                position_ids = features["position_ids"].to(device=device)
                # The packed sequence needs no padding mask; FlashAttention uses
                # the cumulative sequence lengths to keep tokens from attending
                # across sequence boundaries.
                attention_mask = None
                kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
                kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
                kwargs["max_length_k"] = features["max_length_k"]
                kwargs["max_length_q"] = features["max_length_q"]

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

        if not isinstance(attention_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": hidden_states,
                "attention_mask": attention_mask,
            }
            attention_mask_mapping = {
                "full_attention": create_bidirectional_mask(**mask_kwargs),
                "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
            }

        position_embeddings = {}
        for layer_type in self.config.layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask_mapping[encoder_layer.attention_type],
                position_embeddings=position_embeddings[encoder_layer.attention_type],
                **kwargs,
            )

        hidden_states = self.final_norm(hidden_states)

        if self.config._attn_implementation.startswith("flash_attention"):
            # Restore the (batch, seqlen, hidden) layout expected by downstream heads.
            hidden_states = _pad_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )

        return BaseModelOutput(last_hidden_state=hidden_states)


class UnpadModernBertForMaskedLM(ModernBertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForSequenceClassification(ModernBertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForTokenClassification(ModernBertForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForQuestionAnswering(ModernBertForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForMultipleChoice(ModernBertForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


def enable_modernbert_unpadding():
    """Monkey-patch ModernBertModel so all instances, including ones created
    via from_pretrained, use the unpadding forward pass."""
    ModernBertModel.forward = UnpadModernBertModel.forward
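

# A minimal usage sketch. The model id, dtype, device, and example inputs below
# are illustrative assumptions, not part of this module. After the patch, padded
# batches are flattened before FlashAttention and re-padded afterwards, so
# callers still receive the usual (batch, seqlen, hidden) outputs.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    enable_modernbert_unpadding()

    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
    # Unpadding only triggers when a flash attention implementation is selected.
    model = ModernBertModel.from_pretrained(
        "answerdotai/ModernBERT-base",
        attn_implementation="flash_attention_2",
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    batch = tokenizer(
        ["a short sentence", "a much longer sentence that forces padding in this batch"],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        out = model(**batch)
    # Shape matches the padded input: (2, padded_seq_len, hidden_size).
    print(out.last_hidden_state.shape)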