"""Sequence unpadding for ModernBERT.

Removes padding tokens before the encoder runs and restores the padded
layout afterwards, so FlashAttention's variable-length kernels only
process real tokens.
"""

from typing import Unpack

import torch
from transformers import (
    DataCollatorWithFlattening,
    ModernBertModel,
    ModernBertConfig,
    ModernBertForMaskedLM,
    ModernBertForSequenceClassification,
    ModernBertForTokenClassification,
    ModernBertForQuestionAnswering,
    ModernBertForMultipleChoice,
)
from transformers.masking_utils import (
    create_bidirectional_mask,
    create_bidirectional_sliding_window_mask,
)
from transformers.modeling_outputs import BaseModelOutput
from transformers.utils import TransformersKwargs


def _unpad_input(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Flatten a padded batch into one packed sequence plus FlashAttention metadata.

    ``DataCollatorWithFlattening`` concatenates the non-padded tokens of every
    row into a single sequence of shape ``(1, total_tokens)`` and, with
    ``return_flash_attn_kwargs=True``, also emits ``cu_seq_lens_q``,
    ``cu_seq_lens_k``, ``max_length_q`` and ``max_length_k``.
    """
    collator = DataCollatorWithFlattening(return_flash_attn_kwargs=True)
    # Keep only the tokens whose attention-mask bit is set, then re-collate.
    features = collator(
        [{"input_ids": i[a.bool()].tolist()} for i, a in zip(input_ids, attention_mask)]
    )
    return features


def _pad_output(inputs: torch.Tensor, indices: torch.Tensor, batch: int, seqlen: int) -> torch.Tensor:
    """Scatter unpadded values back into a zero-padded ``(batch, seqlen, ...)`` layout."""
    if inputs.dim() == 3:
        # The packed forward pass runs with a dummy batch dimension of one.
        # Squeeze only that dimension; a bare squeeze() could also drop a
        # size-one feature dimension.
        inputs = inputs.squeeze(0)
    if inputs.dim() == 1:
        # Per-token scalars.
        output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen)
    else:
        # Per-token feature vectors, e.g. hidden states of shape (total_tokens, hidden).
        _, *rest = inputs.shape
        output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
        output[indices] = inputs
        padded_inputs = output.view(batch, seqlen, *rest)
    return padded_inputs
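
# Worked example of the round trip (shapes only), assuming two sequences of
# lengths 3 and 2 padded to length 4; the values below are illustrative:
#
#   attention_mask = [[1, 1, 1, 0],
#                     [1, 1, 0, 0]]
#   indices        = [0, 1, 2, 4, 5]      # flat positions of the real tokens
#   packed hidden  : (5, hidden)          # total_tokens = 5
#
# _pad_output(hidden, indices, batch=2, seqlen=4) scatters the five rows back
# into a (2, 4, hidden) tensor with zeros at the padded positions.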


class UnpadModernBertModel(ModernBertModel):
    """ModernBERT encoder that unpads its inputs for FlashAttention and re-pads the output."""

    def __init__(self, config: ModernBertConfig):
        super().__init__(config)

    def forward(
        self,
        input_ids: torch.LongTensor | None = None,
        attention_mask: torch.Tensor | None = None,
        position_ids: torch.LongTensor | None = None,
        inputs_embeds: torch.Tensor | None = None,
        **kwargs: Unpack[TransformersKwargs],
    ) -> BaseModelOutput:
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        # Record the padded geometry before unpadding so the output can be restored.
        seq_len = inputs_embeds.shape[1] if inputs_embeds is not None else input_ids.shape[1]
        batch_size = inputs_embeds.shape[0] if inputs_embeds is not None else input_ids.shape[0]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        indices = None
        if self.config._attn_implementation.startswith("flash_attention"):
            if input_ids is None or attention_mask is None:
                raise ValueError("Unpadding requires both input_ids and attention_mask")
            with torch.no_grad():
                # Flat positions of the real (non-padded) tokens, used for re-padding.
                indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
            features = _unpad_input(input_ids, attention_mask)
            input_ids = features["input_ids"].to(device=device)
            position_ids = features["position_ids"].to(device=device)
            # FlashAttention's varlen kernels take cumulative sequence lengths
            # instead of an attention mask.
            attention_mask = None
            kwargs["cu_seq_lens_k"] = features["cu_seq_lens_k"].to(device=device)
            kwargs["cu_seq_lens_q"] = features["cu_seq_lens_q"].to(device=device)
            kwargs["max_length_k"] = features["max_length_k"]
            kwargs["max_length_q"] = features["max_length_q"]

        if position_ids is None:
            position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

        # The mask may already be a per-layer-type mapping; otherwise build one
        # bidirectional mask for global layers and one for sliding-window layers.
        if not isinstance(attention_mask_mapping := attention_mask, dict):
            mask_kwargs = {
                "config": self.config,
                "inputs_embeds": hidden_states,
                "attention_mask": attention_mask,
            }
            attention_mask_mapping = {
                "full_attention": create_bidirectional_mask(**mask_kwargs),
                "sliding_attention": create_bidirectional_sliding_window_mask(**mask_kwargs),
            }

        # Rotary embeddings differ between global and sliding-window layers.
        position_embeddings = {}
        for layer_type in self.config.layer_types:
            position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

        for encoder_layer in self.layers:
            hidden_states = encoder_layer(
                hidden_states,
                attention_mask=attention_mask_mapping[encoder_layer.attention_type],
                position_embeddings=position_embeddings[encoder_layer.attention_type],
                **kwargs,
            )

        hidden_states = self.final_norm(hidden_states)

        # Restore the (batch, seq_len, hidden) layout expected by the task heads.
        if self.config._attn_implementation.startswith("flash_attention"):
            hidden_states = _pad_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )
        return BaseModelOutput(last_hidden_state=hidden_states)


# The task heads are unchanged; each simply swaps in the unpadded backbone.
class UnpadModernBertForMaskedLM(ModernBertForMaskedLM):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForSequenceClassification(ModernBertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForTokenClassification(ModernBertForTokenClassification):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForQuestionAnswering(ModernBertForQuestionAnswering):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


class UnpadModernBertForMultipleChoice(ModernBertForMultipleChoice):
    def __init__(self, config):
        super().__init__(config)
        self.model = UnpadModernBertModel(config)
        self.post_init()


def enable_modernbert_unpadding():
    """Monkey-patch the stock ``ModernBertModel.forward`` so every ModernBERT
    instance, including already-loaded checkpoints, gains unpadding."""
    ModernBertModel.forward = UnpadModernBertModel.forward
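

# Minimal usage sketch, not part of the module proper: the checkpoint name is
# the public "answerdotai/ModernBERT-base"; the CUDA/fp16 setup is an
# assumption, since unpadding only activates under a flash_attention
# implementation.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("answerdotai/ModernBERT-base")
    model = UnpadModernBertForMaskedLM.from_pretrained(
        "answerdotai/ModernBERT-base",
        attn_implementation="flash_attention_2",  # required for the unpad path
        torch_dtype=torch.float16,
    ).to("cuda").eval()

    batch = tokenizer(
        ["ModernBERT packs [MASK] sequences.", "Short input."],
        padding=True,
        return_tensors="pt",
    ).to("cuda")

    with torch.no_grad():
        # Logits come back re-padded to (batch, seq_len, vocab_size).
        logits = model(**batch).logits
    print(logits.shape)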