import copy
import gc
import math
import os
from dataclasses import dataclass
from itertools import chain
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

import transformers as ts
from datasets import load_dataset
from optimi import StableAdamW
from transformers import (
    AutoTokenizer,
    AutoModelForMaskedLM,
    DataCollatorForLanguageModeling,
    PretrainedConfig,
    Trainer,
    TrainingArguments,
)
from transformers.activations import ACT2FN
from transformers.modeling_outputs import *
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
|
|
| |
| try: |
| from flash_attn.bert_padding import unpad_input, pad_input |
| from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func |
| from flash_attn.layers.rotary import RotaryEmbedding |
| from flash_attn.ops.triton.rotary import apply_rotary |
| FLASH_ATTN_AVAILABLE = True |
| print("✅ FlashAttention is available.") |
| except ImportError: |
| FLASH_ATTN_AVAILABLE = False |
| print("❌ FlashAttention is not available. Using PyTorch SDPA fallback.") |
|
|
| from .configuration_modernalbert import ModernALBERTConfig |
|
|
| |
class SharedLoraFFN(nn.Module):
    """
    A shared feed-forward network whose up/down projections are perturbed by
    externally supplied, pre-merged LoRA factors.

    The caller passes the (already weight-merged) LoRA matrices so a single
    shared FFN can serve many experts.
    """

    def __init__(self, config):
        super().__init__()
        dim = config.hidden_size
        intermediate_dim = config.expert_intermediate_size

        self.linear1 = nn.Linear(dim, intermediate_dim)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(intermediate_dim, dim)
        # Standard LoRA scaling: alpha / rank.
        self.lora_scaling = config.lora_alpha / config.lora_rank

    def forward(self, x, lora_A1, lora_B1, lora_A2, lora_B2):
        scale = self.lora_scaling
        # Up-projection: base linear plus scaled low-rank correction.
        up = self.linear1(x) + scale * (x @ lora_A1.T @ lora_B1.T)
        up = self.act(up)
        # Down-projection with its own low-rank correction.
        down = self.linear2(up) + scale * (up @ lora_A2.T @ lora_B2.T)
        return down
|
|
| |
class SwitchRouterTopK(nn.Module):
    """
    Router for the MoL layer.

    Two modes, selected by ``config.routing_strategy``:
      * "ema": returns a single per-expert merge-weight vector, smoothed with an
        exponential moving average of the batch-mean router probabilities.
      * otherwise: classic top-k routing; returns per-token expert indices,
        normalized top-k probabilities, and a Switch-style load-balancing loss.

    Optimized for unpadded (Flash Attention) inputs where shape is (total_nnz, dim).
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        # EMA smoothing factor for the "ema" strategy (default 0.99).
        self.ema_decay = getattr(config, "router_ema_decay", 0.99)
        # Token -> expert logits.
        self.layer = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.k = config.top_k
        # NOTE(review): jitter_noise is stored but never applied in forward — confirm intent.
        self.jitter_noise = config.router_jitter_noise

        # Persistent EMA of expert usage, initialized uniform; saved with state_dict.
        self.register_buffer("ema_weights", torch.ones(config.num_experts) / config.num_experts)

    def forward(self, hidden_states):
        # hidden_states: (num_tokens, hidden_size) — unpadded token stream.
        if self.config.routing_strategy == "ema":
            logits = self.layer(hidden_states)
            probs = F.softmax(logits, dim=-1)

            if self.training:
                # Batch-mean routing probability per expert.
                r_b = probs.mean(dim=0)

                # Mix the detached EMA state with the current (differentiable) batch
                # statistic; gradients flow only through the (1 - decay) * r_b term.
                weights_for_forward = self.ema_decay * self.ema_weights.detach() + (1 - self.ema_decay) * r_b

                # Update the buffer in-place with the detached value...
                new_ema_value = weights_for_forward.detach()
                self.ema_weights.copy_(new_ema_value)
                # ...and renormalize it to sum to 1 (epsilon guards against a zero sum).
                self.ema_weights.div_(self.ema_weights.sum() + 1e-9)

                return weights_for_forward

            # Eval: use the stored (already normalized) EMA weights directly.
            return self.ema_weights
        else:
            num_tokens = hidden_states.shape[0]

            logits = self.layer(hidden_states)
            # Softmax in fp32 for numerical stability regardless of activation dtype.
            probs = F.softmax(logits, dim=-1, dtype=torch.float32)
            topk_probs, topk_indices = torch.topk(probs, k=self.k, dim=-1)
            # Renormalize the selected k probabilities so they sum to 1 per token.
            topk_probs_normalized = topk_probs / torch.sum(topk_probs, dim=-1, keepdim=True)

            # Switch-Transformer-style load-balancing loss:
            # fraction of (probability-weighted) tokens per expert x mean router prob.
            one_hot = F.one_hot(topk_indices, num_classes=self.num_experts).float()
            tokens_per_expert = torch.sum(one_hot * topk_probs.unsqueeze(-1), dim=(0, 1)) / num_tokens
            router_prob_per_expert = torch.mean(probs, dim=0)

            aux_loss = self.num_experts * torch.mean(tokens_per_expert * router_prob_per_expert)

            return topk_indices, topk_probs_normalized, aux_loss
|
|
| |
class LoraMoELayerTopK(nn.Module):
    """
    Mixture-of-LoRA (MoL) layer over a single shared FFN.

    Routing strategies (selected by ``config.routing_strategy``):
      * "ema":     merge all expert adapters with EMA-smoothed router weights and
                   run the shared FFN once (dense, no token dispatch).
      * "uniform": merge adapters with equal weights (dense).
      * otherwise: top-k token dispatch with a load-balancing auxiliary loss.

    Returns ``(hidden_states + ffn_output, aux_loss)``; aux_loss is 0 for the
    dense strategies.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        dim = config.hidden_size
        expert_intermediate_dim = config.expert_intermediate_size
        num_experts = config.num_experts
        lora_rank = config.lora_rank

        self.k = config.top_k
        self.num_experts = num_experts

        self.norm = nn.LayerNorm(dim, eps=config.layer_norm_eps)

        self.router = SwitchRouterTopK(config)
        self.shared_ffn = SharedLoraFFN(config)

        # Per-expert LoRA factors. B matrices start at zero so every expert
        # initially contributes nothing beyond the shared FFN.
        self.lora_A1 = nn.Parameter(torch.randn(num_experts, lora_rank, dim))
        self.lora_B1 = nn.Parameter(torch.zeros(num_experts, expert_intermediate_dim, lora_rank))
        self.lora_A2 = nn.Parameter(torch.randn(num_experts, lora_rank, expert_intermediate_dim))
        self.lora_B2 = nn.Parameter(torch.zeros(num_experts, dim, lora_rank))

        for i in range(num_experts):
            nn.init.kaiming_uniform_(self.lora_A1[i], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lora_A2[i], a=math.sqrt(5))

    def _merged_forward(self, residual, hidden_states_norm, merge_weights):
        """Collapse all experts into one adapter via `merge_weights`, then apply the FFN once."""
        w = merge_weights.view(-1, 1, 1)
        merged_A1 = torch.sum(w * self.lora_A1, dim=0)
        merged_B1 = torch.sum(w * self.lora_B1, dim=0)
        merged_A2 = torch.sum(w * self.lora_A2, dim=0)
        merged_B2 = torch.sum(w * self.lora_B2, dim=0)

        output = self.shared_ffn(hidden_states_norm, merged_A1, merged_B1, merged_A2, merged_B2)
        # Dense paths carry no load-balancing loss.
        return residual + output, torch.tensor(0.0, device=hidden_states_norm.device)

    def forward(self, hidden_states: torch.Tensor):
        residual = hidden_states
        hidden_states_norm = self.norm(hidden_states)

        if self.config.routing_strategy == "ema":
            merge_weights = self.router(hidden_states_norm)
            return self._merged_forward(residual, hidden_states_norm, merge_weights)

        if self.config.routing_strategy == "uniform":
            merge_weights = torch.ones(
                self.config.num_experts,
                dtype=hidden_states_norm.dtype,
                device=hidden_states_norm.device,
            ) / self.config.num_experts
            return self._merged_forward(residual, hidden_states_norm, merge_weights)

        # --- top-k dispatch path; expects unpadded input of shape (num_tokens, dim) ---
        num_tokens, dim = hidden_states_norm.shape

        topk_indices, topk_probs, aux_loss = self.router(hidden_states_norm)

        # Flatten (token, expert) assignments and sort by expert id so each expert
        # processes one contiguous slice.
        flat_token_indices = torch.arange(num_tokens, device=hidden_states.device).repeat_interleave(self.k)
        flat_expert_indices = topk_indices.flatten()

        perm_indices = torch.argsort(flat_expert_indices)
        sorted_token_indices = flat_token_indices[perm_indices]
        sorted_expert_indices = flat_expert_indices[perm_indices]

        permuted_tokens = hidden_states_norm[sorted_token_indices]
        permuted_probs = topk_probs.flatten()[perm_indices]

        tokens_per_expert = F.one_hot(sorted_expert_indices, self.num_experts).sum(dim=0)
        split_tokens = torch.split(permuted_tokens, tokens_per_expert.tolist(), dim=0)
        split_probs = torch.split(permuted_probs, tokens_per_expert.tolist(), dim=0)

        expert_outputs = []
        for i in range(self.num_experts):
            if tokens_per_expert[i] > 0:
                output = self.shared_ffn(
                    split_tokens[i],
                    self.lora_A1[i], self.lora_B1[i],
                    self.lora_A2[i], self.lora_B2[i],
                )
                expert_outputs.append(output * split_probs[i].unsqueeze(1))
            else:
                # BUG FIX: placeholder must match the activation dtype; the previous
                # implicit float32 would break torch.cat under mixed precision.
                expert_outputs.append(
                    torch.empty(0, dim, device=hidden_states.device, dtype=hidden_states_norm.dtype)
                )

        # Undo the expert-sorted permutation and sum the k expert outputs per token.
        concatenated_outputs = torch.cat(expert_outputs, dim=0)
        inverse_perm_indices = torch.argsort(perm_indices)
        unpermuted_outputs = concatenated_outputs[inverse_perm_indices]

        final_output = unpermuted_outputs.view(num_tokens, self.k, dim).sum(dim=1)

        return residual + final_output, aux_loss
|
|
| |
class ModernAlbertMLP(nn.Module):
    """Gated (GeGLU-style) MLP: one fused up-projection split into value and gate."""

    def __init__(self, config: ModernALBERTConfig):
        super().__init__()
        self.config = config
        # Single projection producing both the activation input and the gate.
        self.Wi = nn.Linear(config.hidden_size, int(config.intermediate_size) * 2, bias=False)
        self.act = ACT2FN["gelu"]
        self.drop = nn.Dropout(config.hidden_dropout_prob)
        self.Wo = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.Wi(hidden_states)
        value, gate = projected.chunk(2, dim=-1)
        gated = self.act(value) * gate
        return self.Wo(self.drop(gated))
|
|
| |
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd wrapper around the Triton `apply_rotary` kernel for unpadded packed QKV.

    Rotary embedding is applied in-place to the query and key slices of the packed
    (total_nnz, 3, nheads, headdim) tensor; the value slice is left untouched.
    """

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        # Kernel requires contiguous memory.
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape

        # View q and k together as a single (total_nnz, 2 * nheads, headdim) slab so
        # one kernel launch rotates both in-place.
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )

        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape

        # Rotate the incoming q/k gradients by the conjugate (inverse) rotation, in-place.
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )

        # Gradients for (qkv, cos, sin, cu_seqlens, max_seqlen); autograd tolerates
        # the extra trailing Nones (kept for parity with the upstream ModernBERT code).
        return do, None, None, None, None, None, None
|
|
|
|
def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """
    Apply rotary position embeddings (in-place) to unpadded, packed QKV.

    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) - packed QKV for all tokens in the batch.
        cos, sin: (seqlen_rotary, rotary_dim / 2) rotation tables.
        cu_seqlens: (batch + 1,) cumulative sequence lengths, or None.
        max_seqlen: int, maximum sequence length in the batch.
    Return:
        The same (total_nnz, 3, nheads, headdim) tensor with rotary applied to the
        query/key slices. rotary_dim must be <= headdim; only the first rotary_dim
        of each head is rotated, non-interleaved (GPT-NeoX-style halves).
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
|
|
|
|
class ModernAlbertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    The rotary position embeddings applied directly to unpadded sequences.

    NOTE(review): this class subclasses flash_attn's RotaryEmbedding but is defined
    unconditionally — if flash_attn failed to import, defining this class raises
    NameError at module import time despite the SDPA fallback. Confirm whether the
    definition should be guarded by FLASH_ATTN_AVAILABLE.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
        up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
        the cos_sin_cache will be recomputed during the forward pass.
        """
        super().__init__(dim=dim, base=base, device=device, interleaved=False)
        self.max_seqlen = max_seqlen

        # Precompute the cache only when everything needed for it is known up front.
        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Apply rotary embedding *inplace* to qkv.
        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch
        """
        # The base-class cache update is a no-op when seqlen/device/dtype are unchanged.
        if max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)

        qkv = apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )

        return qkv

    def extra_repr(self) -> str:
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
|
|
class ModernAlbertRotaryEmbedding(nn.Module):
    """Padded-input rotary embedding (HF-style); produces (cos, sin) tables per position.

    Used by the SDPA fallback path; the FlashAttention path uses the unpadded variant.
    """

    def __init__(self, config: ModernALBERTConfig, device=None):
        super().__init__()
        # Pick the rope variant from config.rope_scaling when present, else "default".
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        # inv_freq: (head_dim / 2,) inverse frequencies; not persisted in state_dict.
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so @dynamic_rope_update can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update
    def forward(self, x, position_ids):
        # (batch, head_dim/2, 1) x (batch, 1, seq) -> per-position angles.
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force fp32 math for the trig; autocast is disabled inside this block.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate the angles so cos/sin cover the full head_dim (NeoX halves).
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
|
| |
class GeGLU(nn.Module):
    """Gated GELU unit: gelu(w1(x)) elementwise-multiplied by w2(x)."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_out)
        self.w2 = nn.Linear(dim_in, dim_out)

    def forward(self, x):
        gate = F.gelu(self.w1(x))
        value = self.w2(x)
        return gate * value
|
|
|
|
| |
| def _unpad_modernbert_input( |
| inputs: torch.Tensor, |
| attention_mask: torch.Tensor, |
| position_ids: Optional[torch.Tensor] = None, |
| labels: Optional[torch.Tensor] = None, |
| ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]: |
| """ |
| Remove padding from input sequences. |
| |
| Args: |
| inputs: (batch, seqlen, ...) or (batch, seqlen) |
| attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid. |
| position_ids: (batch, seqlen), int, position ids |
| labels: (batch, seqlen), int, labels |
| |
| Returns: |
| unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask. |
| indices: (total_nnz) |
| cu_seqlens: (batch + 1), the cumulative sequence lengths |
| max_seqlen_in_batch: int |
| unpadded_position_ids: (total_nnz) or None |
| unpadded_labels: (total_nnz) or None |
| """ |
| seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32) |
| indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten() |
| max_seqlen_in_batch = int(seqlens_in_batch.max().item()) |
| cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)) |
|
|
| if inputs.dim() == 2: |
| unpadded_inputs = inputs.flatten()[indices] |
| else: |
| batch, seqlen, *rest = inputs.shape |
| shape = batch * seqlen |
| unpadded_inputs = inputs.view(shape, *rest)[indices] |
|
|
| unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None |
| unpadded_labels = labels.flatten()[indices] if labels is not None else None |
|
|
| return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels |
|
|
|
|
| def _pad_modernbert_output( |
| inputs: torch.Tensor, |
| indices: torch.Tensor, |
| batch: int, |
| seqlen: int, |
| ) -> torch.Tensor: |
| """ |
| Add padding to sequences. |
| |
| Args: |
| inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask. |
| indices: (total_nnz) |
| batch: int, batch size |
| seqlen: int, max sequence length |
| |
| Returns: |
| padded_inputs: (batch, seqlen, ...) or (batch, seqlen) |
| """ |
| if inputs.dim() == 1: |
| output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device) |
| output[indices] = inputs |
| padded_inputs = output.view(batch, seqlen) |
| else: |
| _, *rest = inputs.shape |
| output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device) |
| output[indices] = inputs |
| padded_inputs = output.view(batch, seqlen, *rest) |
|
|
| return padded_inputs |
|
|
|
|
def flash_attention_forward(
    module: "SharedGroup",
    qkv: torch.Tensor,
    rotary_emb: ModernAlbertUnpaddedRotaryEmbedding,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> tuple[torch.Tensor]:
    """
    Varlen (unpadded) FlashAttention over packed QKV.

    Args:
        module: owning attention module; supplies dropout probability and training flag.
        qkv: (total_nnz, 3, nheads, headdim) packed queries/keys/values.
        rotary_emb: unpadded rotary embedding, applied in-place to q and k.
        cu_seqlens: (batch + 1,) cumulative sequence lengths.
        max_seqlen: longest sequence in the batch.
        local_attention: flash-attn `window_size`; (-1, -1) means global attention.
        bs: number of rows for the output reshape (total_nnz in the unpadded layout).
        dim: hidden size (nheads * headdim).
        target_dtype: dtype to cast to when the input is not fp16/bf16.

    Returns:
        1-tuple with the attention output of shape (bs, dim).
    """
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    # flash-attn kernels only support fp16/bf16; round-trip through target_dtype
    # when needed. (Previously the call was duplicated across both branches.)
    orig_dtype = qkv.dtype
    needs_cast = orig_dtype not in (torch.float16, torch.bfloat16)
    if needs_cast:
        qkv = qkv.to(target_dtype)

    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=module.att_dropout.p if module.training else 0.0,
        window_size=local_attention,
    )
    if needs_cast:
        attn = attn.to(orig_dtype)

    return (attn.view(bs, dim),)
|
|
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """NeoX-style rotary helper: split the last dim into halves (x1, x2) -> (-x2, x1)."""
    half = x.shape[-1] // 2
    return torch.cat((-x[..., half:], x[..., :half]), dim=-1)


def sdpa_attention_forward(
    module: "SharedGroup",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> tuple[torch.Tensor]:
    """
    PyTorch SDPA fallback for padded inputs.

    Args:
        module: owning attention module; supplies rotary_emb, dropout, training flag.
        qkv: (bs, seqlen, 3, nheads, headdim) packed queries/keys/values.
        attention_mask: broadcastable SDPA mask (True/0 = attend), or None.
        sliding_window_mask: mask used instead when local attention is enabled.
        position_ids: (bs, seqlen) positions for the rotary embedding.
        local_attention: (-1, -1) means global; anything else selects the sliding mask.
        bs: batch size for the output reshape.
        dim: hidden size (nheads * headdim).

    Returns:
        1-tuple with the attention output of shape (bs, seqlen, dim).
    """
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    # (bs, seqlen, 3, nheads, headdim) -> q, k, v each (bs, nheads, seqlen, headdim).
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)

    # BUG FIX: `apply_rotary_pos_emb` was undefined in this module; apply the
    # standard NeoX rotary rotation inline (cos/sin broadcast over the head dim).
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    query = (query * cos) + (_rotate_half(query) * sin)
    key = (key * cos) + (_rotate_half(key) * sin)

    if local_attention != (-1, -1):
        attention_mask = sliding_window_mask

    attn_output = (
        F.scaled_dot_product_attention(
            query,
            key,
            value,
            # BUG FIX: SharedGroup defines `att_dropout`, not `attention_dropout`.
            dropout_p=module.att_dropout.p if module.training else 0.0,
            attn_mask=attention_mask,
        )
        .transpose(1, 2)
        .contiguous()
    )
    attn_output = attn_output.view(bs, -1, dim)
    return (attn_output,)
|
|
class SharedGroup(nn.Module):
    """
    One ALBERT-style shared layer group: pre-norm attention + pre-norm MLP,
    unrolled ``config.group_depth`` times with the same parameters.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config

        hs, nh = config.hidden_size, config.num_attention_heads
        self.head_dim = hs // nh
        self.num_heads = nh
        self.use_adapter = config.use_adapter
        eps = config.layer_norm_eps

        rope_theta = 10000

        self.att_pre_norm = nn.LayerNorm(hs, eps=eps)
        self.ffn_pre_norm = nn.LayerNorm(hs, eps=eps)

        self.qkv = nn.Linear(hs, 3 * hs)
        self.out_proj = nn.Linear(hs, hs)
        self.att_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        # (-1, -1) disables windowed/local attention.
        self.local_attention = (-1, -1)

        if FLASH_ATTN_AVAILABLE:
            self.rotary_emb = ModernAlbertUnpaddedRotaryEmbedding(
                dim=self.head_dim, max_seqlen=config.max_position_embeddings, base=rope_theta
            )
        else:
            # BUG FIX: `copy` was used here without being imported at module level.
            config_copy = copy.deepcopy(config)
            config_copy.rope_theta = rope_theta
            self.rotary_emb = ModernAlbertRotaryEmbedding(config=config_copy)

        self.mlp = ModernAlbertMLP(config)

    def forward(self, inputs, mask, config, start_idx=0, use_moa=False, **kwargs):
        """
        Run the shared block `config.group_depth` times.

        When `use_moa` is True the final iteration skips the internal MLP so the
        caller can substitute a MoA layer in its place.
        """
        outputs = [] if config.output_hidden_states else None
        attn_maps = [] if config.output_attentions else None

        x = inputs

        for i in range(config.group_depth):
            h = x
            h_norm = self.att_pre_norm(h)
            qkv_proj = self.qkv(h_norm)
            bs = h.shape[0]

            if FLASH_ATTN_AVAILABLE:
                # Unpadded layout: (total_nnz, 3, nheads, headdim).
                qkv = qkv_proj.view(-1, 3, self.num_heads, self.head_dim)
                attn_outputs = flash_attention_forward(
                    self,
                    qkv=qkv,
                    rotary_emb=self.rotary_emb,
                    local_attention=self.local_attention,
                    bs=bs,
                    dim=self.head_dim * self.num_heads,
                    **kwargs,
                )
            else:
                # Padded layout: (bs, seqlen, 3, nheads, headdim).
                qkv = qkv_proj.view(bs, -1, 3, self.num_heads, self.head_dim)
                # BUG FIX: the broadcastable mask was built but never passed to
                # sdpa_attention_forward (a required parameter).
                attn_mask = mask[:, None, None, :] if mask is not None else None
                attn_outputs = sdpa_attention_forward(
                    self,
                    qkv=qkv,
                    attention_mask=attn_mask,
                    local_attention=self.local_attention,
                    bs=bs,
                    dim=self.head_dim * self.num_heads,
                    **kwargs,
                )

            attn_out = attn_outputs[0]
            x = self.att_dropout(self.out_proj(attn_out)) + h

            # On the last unrolled step of a MoA group, return before the MLP so the
            # caller's MoA layer replaces it.
            if use_moa and i == config.group_depth - 1:
                return x, outputs, attn_maps

            h2 = x
            x = self.mlp(self.ffn_pre_norm(h2)) + h2

            if config.output_hidden_states:
                outputs.append(x)

        return x, outputs, attn_maps
|
|
|
|
class ModernAlbertEmbeddings(nn.Module):
    """
    Token embeddings (factorized, ALBERT-style) projected up to the hidden size,
    then LayerNorm and dropout.
    """

    def __init__(self, config: ModernALBERTConfig):
        super().__init__()
        self.config = config

        self.tok_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
        self.embed_proj = nn.Linear(config.embedding_size, config.hidden_size)

        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
        self.drop = nn.Dropout(0.0)

    def forward(
        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        # Precomputed embeddings take precedence over token ids.
        embeddings = inputs_embeds if inputs_embeds is not None else self.tok_embeddings(input_ids)
        return self.drop(self.norm(self.embed_proj(embeddings)))
|
|
@dataclass
class MoABaseModelOutput(BaseModelOutput):
    """BaseModelOutput extended with the mixture-of-adapters auxiliary loss."""

    # Scaled mean load-balancing loss across MoA layers; None when no aux losses
    # were produced (e.g. EMA/uniform routing, or use_moa disabled).
    load_balancing_loss: Optional[torch.FloatTensor] = None
|
|
class ModernALBERTModel(ts.PreTrainedModel):
    """
    ALBERT-style encoder with cross-layer parameter sharing (`SharedGroup`s) and
    optional Mixture-of-Adapters (MoA) layers attached to the trailing groups.

    With FlashAttention available, inputs are unpadded once up front and re-padded
    at the end; otherwise a padded SDPA path is used.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "modernAlbert"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.config = config

        self.embeddings = ModernAlbertEmbeddings(config)
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
        # Each group is unrolled config.group_depth times inside its own forward.
        self.num_groups = config.num_hidden_layers // config.group_depth
        self.groups = nn.ModuleList([SharedGroup(config) for _ in range(self.num_groups)])

        if config.use_moa:
            # One MoA (LoRA-MoE) layer for each of the last num_expert_modules groups.
            self.moa_layers = nn.ModuleList([
                LoraMoELayerTopK(config) for _ in range(config.num_expert_modules)
            ])

        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()

    def forward(self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """
        Returns a `MoABaseModelOutput` whose `load_balancing_loss` aggregates the
        MoA auxiliary losses (None when there were none).
        """
        if batch_size is None and seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]

        # BUG FIX: `device` was referenced in the SDPA fallback below but never
        # defined, raising NameError whenever FlashAttention is unavailable.
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions

        # Groups read this flag straight off the config, so propagate the request.
        if output_hidden_states:
            self.config.output_hidden_states = True

        hs, atts = ([] if output_hidden_states else None), ([] if output_attentions else None)
        all_aux_losses = []

        repad = False
        if FLASH_ATTN_AVAILABLE:
            if indices is None and cu_seqlens is None and max_seqlen is None:
                # We unpadded here, so we must re-pad before returning.
                repad = True
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask
                    )
        else:
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            # NOTE(review): `_update_attention_mask` is not defined in this class or
            # any base visible here — confirm it exists on the runtime base class.
            attention_mask, sliding_window_mask = self._update_attention_mask(
                attention_mask, output_attentions=output_attentions
            )

        x = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)

        if output_hidden_states:
            hs.append(x)

        mask = attention_mask.to(torch.bool) if attention_mask is not None else None

        for i, group in enumerate(self.groups):
            # The last len(moa_layers) groups have their final MLP replaced by MoA.
            is_moa = self.config.use_moa and (i > len(self.groups) - len(self.moa_layers) - 1)
            moa_idx = i - (len(self.groups) - len(self.moa_layers))

            x, layer_hs, layer_atts = group(
                x,
                mask,
                self.config,
                sliding_window_mask=sliding_window_mask,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                use_moa=is_moa,
                output_attentions=output_attentions,
            )

            if output_hidden_states and layer_hs:
                hs.extend(layer_hs)
            if output_attentions and layer_atts:
                atts.extend(layer_atts)

            if self.config.use_moa and is_moa:
                x, aux_loss = self.moa_layers[moa_idx](x)
                if output_hidden_states:
                    hs.append(x)
                all_aux_losses.append(aux_loss)

        hidden_states = self.final_norm(x)

        if repad:
            hidden_states = _pad_modernbert_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )

        load_balancing_loss = None
        if all_aux_losses:
            load_balancing_loss = torch.stack(all_aux_losses).mean() * self.config.load_balancing_loss_coef

        return MoABaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=hs,
            attentions=atts,
            load_balancing_loss=load_balancing_loss,
        )
|
|
class ModernAlbertPredictionHead(nn.Module):
    """Maps hidden states to embedding size (dense -> GELU -> LayerNorm) for the MLM decoder."""

    def __init__(self, config: ModernALBERTConfig):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.embedding_size, bias=False)
        self.act = ACT2FN["gelu"]
        self.norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        activated = self.act(projected)
        return self.norm(activated)
|
|
class ModernALBERTForMaskedLM(ts.PreTrainedModel):
    """
    Modern ALBERT model with a Masked Language Modeling (MLM) head,
    optimized to mirror the HuggingFace `AlbertForMaskedLM` API.
    """

    _tied_weights_keys = ["decoder.weight"]
    config_class = ModernALBERTConfig
    base_model_prefix = "modernAlbert"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.config = config

        self.albert = ModernALBERTModel(config)

        # MLM head: hidden -> embedding size, then a (tieable) vocab decoder.
        self.head = ModernAlbertPredictionHead(config)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)

        self.post_init()

    def get_input_embeddings(self):
        return self.albert.embeddings.tok_embeddings

    def get_output_embeddings(self):
        return self.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        # Compiled variant of head+decoder; not used in forward (kept for callers).
        return self.decoder(self.head(output))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """
        Compute MLM logits (and loss when `labels` is given); the MoE
        load-balancing loss is added to the MLM loss during training.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if FLASH_ATTN_AVAILABLE:
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device

                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

                # Unpad here so labels/position_ids stay aligned with the tokens.
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                    )

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs[0]

        if FLASH_ATTN_AVAILABLE:
            # Re-pad a copy of the final hidden state for the hidden_states list
            # (logits themselves are computed on the unpadded tokens below).
            padded_last_hidden_state = _pad_modernbert_output(
                inputs=last_hidden_state, indices=indices, batch=batch_size, seqlen=seq_len
            )
            if outputs.hidden_states is not None:
                outputs.hidden_states.append(padded_last_hidden_state)

        logits = self.decoder(self.head(last_hidden_state))

        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
            # BUG FIX: the load-balancing term is now guarded by the labels check so
            # it can never be added to a None loss.
            if outputs.load_balancing_loss is not None and self.training:
                loss = loss + outputs.load_balancing_loss

        if FLASH_ATTN_AVAILABLE:
            logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
class ModernALBERTForSequenceClassification(ts.PreTrainedModel):
    """ModernALBERT encoder with a mean-pooled sequence-classification head.

    The final hidden states are mean-pooled over non-padding tokens (using
    the attention mask), projected through the prediction head, and fed to a
    linear classifier of width ``config.embedding_size`` -> ``num_labels``.
    """

    config_class = ModernALBERTConfig

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = ModernALBERTModel(config)
        self.head = ModernAlbertPredictionHead(config)
        self.drop = torch.nn.Dropout(0.0)
        # The prediction head outputs embedding_size features, so the
        # classifier consumes that width rather than hidden_size.
        self.classifier = nn.Linear(config.embedding_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        """Encode, mean-pool, and classify.

        Returns logits of shape ``(batch_size, num_labels)`` and, when
        ``labels`` is given, a loss chosen by ``config.problem_type``
        (regression / single-label / multi-label), resolved lazily from the
        label dtype and ``num_labels`` like the standard transformers heads.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Infer the padded batch dimensions when the caller did not supply
        # them (needed both for unpadding/repadding and for pooling).
        if batch_size is None and seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        device = input_ids.device if input_ids is not None else inputs_embeds.device

        # The mask is required by the mean pooling below even when
        # FlashAttention is unavailable, so default it in all paths.
        if attention_mask is None:
            attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)

        if FLASH_ATTN_AVAILABLE and indices is None and cu_seqlens is None and max_seqlen is None:
            # Unpad once here so the encoder runs on packed (varlen) sequences.
            if inputs_embeds is None:
                with torch.no_grad():
                    input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
                        inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=None
                    )
            else:
                inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
                    inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=None
                )

        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_hidden_state = outputs[0]
        # Re-pad the packed encoder output before pooling; without
        # FlashAttention the output is already (batch, seq, dim) and
        # `indices` is None, so padding must be skipped.
        if FLASH_ATTN_AVAILABLE:
            last_hidden_state = _pad_modernbert_output(
                inputs=last_hidden_state, indices=indices, batch=batch_size, seqlen=seq_len
            )

        # Mask-weighted mean pooling over valid (non-padding) tokens.
        last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
            dim=1, keepdim=True
        )

        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
|
|
class ModernALBERTForQuestionAnswering(ts.PreTrainedModel):
    """ModernALBERT encoder with an extractive QA (start/end span) head."""

    config_class = ModernALBERTConfig

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels

        self.albert = ModernALBERTModel(config)
        self.head = ModernAlbertPredictionHead(config)
        self.drop = torch.nn.Dropout(0.0)
        self.classifier_head = nn.Linear(config.embedding_size, config.num_labels)

        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        """Project every token to start/end logits; compute the span loss
        when both ``start_positions`` and ``end_positions`` are provided."""
        if return_dict is None:
            return_dict = self.config.use_return_dict

        encoder_outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        # Head projection followed by (inactive, p=0.0) dropout.
        sequence_output = self.drop(self.head(encoder_outputs[0]))
        span_logits = self.classifier_head(sequence_output)

        # Separate the two-channel output into start and end scores.
        start_logits, end_logits = span_logits.split(1, dim=-1)
        start_logits = start_logits.squeeze(-1).contiguous()
        end_logits = end_logits.squeeze(-1).contiguous()

        total_loss = None
        if start_positions is not None and end_positions is not None:
            total_loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)

        if not return_dict:
            output = (start_logits, end_logits) + encoder_outputs[1:]
            return ((total_loss,) + output) if total_loss is not None else output

        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
|
|
|
|
| |
|
|
@dataclass
class DistillationOutputWithPasts(ModelOutput):
    """Output container returned by the distillation wrappers.

    Extends ``ModelOutput`` with optional depth-/audio-stream fields so the
    same container can carry student outputs alongside auxiliary streams.
    All fields default to ``None`` and are therefore annotated Optional.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_loss: Optional[torch.FloatTensor] = None
    audio_logits: Optional[torch.FloatTensor] = None
    depth_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    depth_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
|
|
class DistillationWrapper(ts.PreTrainedModel):
    """Wraps a ModernALBERT masked-LM student with an optional frozen teacher.

    When a teacher is attached, the returned loss is a fixed-weight sum of
    the student's MLM loss, a cosine loss aligning the final hidden states,
    and a temperature-scaled KL divergence on the logits at label positions.
    Without a teacher it degenerates to the plain student MLM loss.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "model"
    _no_split_modules = ["LlamaDecoderLayer", "FlowDecoderLayerGroup", "MimiTransformerLayer"]
    _keys_to_ignore_on_load_missing = ["speech_tokenizer", "teacher"]
    _tied_weights_keys = ["llm.decoder.weight"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _tp_plan = []

    def __init__(self, config, student=None, teacher=None):
        super().__init__(config)
        # NOTE(review): `student` is accepted but never used — the student is
        # always built from `config`. Confirm whether it should be honored.
        self.llm = ModernALBERTForMaskedLM(config)

        self.teacher = teacher

        self.attention_loss = nn.KLDivLoss(reduction="mean")  # currently unused
        self.hidden_loss = nn.CosineEmbeddingLoss(reduction="mean")
        self.output_loss = nn.KLDivLoss(reduction="batchmean")

        self.temperature = 1.0

    @property
    def device(self):
        # Device of the first parameter; assumes the student is not sharded.
        return next(self.parameters()).device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the student (and teacher, if set) and combine the losses."""
        device = self.device

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)

        student_outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            **kwargs,
        )

        hidden_loss = None
        output_loss = None

        if self.teacher is not None:
            with torch.no_grad():
                # The teacher may live on a different device.
                input_ids = input_ids.to(self.teacher.device)
                attention_mask = attention_mask.to(self.teacher.device)

                teacher_outputs = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    **kwargs,
                )

            s_hiddens = student_outputs.hidden_states[-1]
            t_hiddens = teacher_outputs.hidden_states[-1].detach()

            s_logits = student_outputs.logits
            t_logits = teacher_outputs.logits.detach()

            # NOTE(review): requires `labels` when a teacher is attached;
            # compute_output_loss masks on `labels > -1`.
            hidden_loss = self.compute_hidden_loss(s_hiddens, t_hiddens, attention_mask)
            output_loss = self.compute_output_loss(s_logits, t_logits, labels)

        if self.teacher is not None:
            # Fixed weighting: 1x task loss, 3x hidden alignment, 5x logit KL.
            total_loss = (1.0 * student_outputs.loss) + (3.0 * hidden_loss) + (5.0 * output_loss)
        else:
            total_loss = student_outputs.loss

        return DistillationOutputWithPasts(
            loss=total_loss,
            logits=student_outputs.logits,
            hidden_states=student_outputs.hidden_states,
            attentions=student_outputs.attentions,
        )

    def compute_output_loss(self, s_logits, t_logits, labels):
        """Temperature-scaled KL(student || teacher) on the logits at
        positions where ``labels > -1`` (i.e. the supervised MLM tokens)."""
        mask = (labels > -1).unsqueeze(-1)

        s_logits_masked = s_logits.masked_fill(~mask, 0.0)
        t_logits_masked = t_logits.masked_fill(~mask, 0.0)

        s_logits_slct = s_logits_masked.view(-1, s_logits.size(-1))
        t_logits_slct = t_logits_masked.view(-1, t_logits.size(-1))

        # Keep only rows corresponding to supervised positions.
        valid_rows = mask.view(-1)
        s_logits_slct = s_logits_slct[valid_rows]
        t_logits_slct = t_logits_slct[valid_rows]

        output_loss = (
            self.output_loss(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )

        return output_loss

    def compute_hidden_loss(self, s_hiddens, t_hiddens, attention_mask, lambdas=None):
        """Cosine-embedding loss pulling student hidden states toward the
        teacher's (target +1 for every row).

        NOTE(review): ``nn.CosineEmbeddingLoss`` expects (N, D) inputs, which
        matches unpadded FlashAttention outputs; confirm for the padded
        (batch, seq, dim) case.
        """
        assert s_hiddens.size() == t_hiddens.size()

        # Target of +1 means "maximize cosine similarity" for every row.
        target = s_hiddens.new_ones(s_hiddens.size(0))

        return self.hidden_loss(s_hiddens, t_hiddens, target)
|
|
|
class DistillationWrapperForSequenceClassification(ts.PreTrainedModel):
    """Distillation wrapper around the sequence-classification student.

    Mirrors ``DistillationWrapper``: with a teacher attached (assign
    ``self.teacher`` externally), the loss combines the student task loss
    with a hidden-state cosine loss and a temperature-scaled logit KL.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "model"
    _no_split_modules = ["LlamaDecoderLayer", "FlowDecoderLayerGroup", "MimiTransformerLayer"]
    _keys_to_ignore_on_load_missing = ["speech_tokenizer", "teacher"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.llm = ModernALBERTForSequenceClassification(config)

        # Attach a frozen teacher externally to enable distillation.
        self.teacher = None

        self.attention_loss = nn.KLDivLoss(reduction="mean")  # currently unused
        self.hidden_loss = nn.CosineEmbeddingLoss(reduction="mean")
        self.output_loss = nn.KLDivLoss(reduction="batchmean")

        self.temperature = 1.0

    @property
    def device(self):
        # Device of the first parameter; assumes the student is not sharded.
        return next(self.parameters()).device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the student (and teacher, if set) and combine the losses."""
        device = self.device

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        if labels is not None:
            labels = labels.to(device)

        student_outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            # Forward extra args to the student (consistent with
            # DistillationWrapper); they were previously dropped.
            **kwargs,
        )

        hidden_loss = None
        output_loss = None

        if self.teacher is not None:
            with torch.no_grad():
                input_ids = input_ids.to(self.teacher.device)
                attention_mask = attention_mask.to(self.teacher.device)

                teacher_outputs = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )

            s_hiddens = student_outputs.hidden_states[-1]
            t_hiddens = teacher_outputs.hidden_states[-1].detach()

            s_logits = student_outputs.logits
            t_logits = teacher_outputs.logits.detach()

            hidden_loss = self.compute_hidden_loss(s_hiddens, t_hiddens, attention_mask)
            output_loss = self.compute_output_loss(s_logits, t_logits, labels)

        if self.teacher is not None:
            # Fixed weighting: 1x task loss, 3x hidden alignment, 5x logit KL.
            total_loss = (1.0 * student_outputs.loss) + (3.0 * hidden_loss) + (5.0 * output_loss)
        else:
            total_loss = student_outputs.loss

        return DistillationOutputWithPasts(
            loss=total_loss,
            logits=student_outputs.logits,
            hidden_states=student_outputs.hidden_states,
            attentions=student_outputs.attentions,
        )

    def compute_output_loss(self, s_logits, t_logits, labels):
        """Temperature-scaled KL on the logit rows whose label is > -1."""
        # Expand the (batch,) label mask across the logit dimension so that
        # masked_select keeps whole rows.
        mask = (labels > -1).unsqueeze(-1).expand_as(s_logits).bool()

        s_logits_slct = torch.masked_select(s_logits, mask)
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))
        t_logits_slct = torch.masked_select(t_logits, mask)
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))
        assert t_logits_slct.size() == s_logits_slct.size()

        output_loss = (
            self.output_loss(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )

        return output_loss

    def compute_hidden_loss(self, s_hiddens, t_hiddens, attention_mask, lambdas=None):
        """Cosine-embedding loss aligning student and teacher hidden states.

        NOTE(review): ``nn.CosineEmbeddingLoss`` expects (N, D) inputs —
        confirm the shape of ``hidden_states[-1]`` in the padded case.
        """
        assert s_hiddens.size() == t_hiddens.size()

        # Target of +1 means "maximize cosine similarity" for every row.
        target = s_hiddens.new_ones(s_hiddens.size(0))

        return self.hidden_loss(s_hiddens, t_hiddens, target)
|
|
class DistillationWrapperForQuestionAnswering(ts.PreTrainedModel):
    """Distillation wrapper around the question-answering student.

    Mirrors ``DistillationWrapper``: with a teacher attached (assign
    ``self.teacher`` externally), the combined loss adds a hidden-state
    cosine term and a temperature-scaled logit KL to the student task loss.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "model"
    _no_split_modules = ["LlamaDecoderLayer", "FlowDecoderLayerGroup", "MimiTransformerLayer"]
    _keys_to_ignore_on_load_missing = ["speech_tokenizer", "teacher"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.llm = ModernALBERTForQuestionAnswering(config)

        # Attach a frozen teacher externally to enable distillation.
        self.teacher = None

        self.attention_loss = nn.KLDivLoss(reduction="mean")  # currently unused
        self.hidden_loss = nn.CosineEmbeddingLoss(reduction="mean")
        self.output_loss = nn.KLDivLoss(reduction="batchmean")

        self.temperature = 1.0

    @property
    def device(self):
        # Device of the first parameter; assumes the student is not sharded.
        return next(self.parameters()).device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the student (and teacher, if set); return the student outputs
        with the combined distillation loss attached."""
        device = self.device

        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)

        if labels is not None:
            labels = labels.to(device)

        student_outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            # Forward extra args (notably start_positions/end_positions) to
            # the QA student — previously dropped, so its span loss could
            # never be computed.
            **kwargs,
        )

        hidden_loss = None
        output_loss = None

        if self.teacher is not None:
            with torch.no_grad():
                input_ids = input_ids.to(self.teacher.device)
                attention_mask = attention_mask.to(self.teacher.device)

                teacher_outputs = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )

            s_hiddens = student_outputs.hidden_states[-1]
            t_hiddens = teacher_outputs.hidden_states[-1].detach()

            # NOTE(review): QuestionAnsweringModelOutput exposes
            # start_logits/end_logits, not `.logits` — this teacher path
            # likely fails as written; confirm the intended logit pairing.
            s_logits = student_outputs.logits
            t_logits = teacher_outputs.logits.detach()

            hidden_loss = self.compute_hidden_loss(s_hiddens, t_hiddens, attention_mask)
            output_loss = self.compute_output_loss(s_logits, t_logits, labels)

        if self.teacher is not None:
            # Fixed weighting: 1x task loss, 3x hidden alignment, 5x logit KL.
            total_loss = (1.0 * student_outputs.loss) + (3.0 * hidden_loss) + (5.0 * output_loss)
        else:
            total_loss = student_outputs.loss

        # Attach the combined loss so the distillation terms are not silently
        # discarded (the previous version returned the raw student outputs).
        if total_loss is not None:
            student_outputs.loss = total_loss

        return student_outputs

    def compute_output_loss(self, s_logits, t_logits, labels):
        """Temperature-scaled KL on the logit rows whose label is > -1."""
        mask = (labels > -1).unsqueeze(-1).expand_as(s_logits).bool()

        s_logits_slct = torch.masked_select(s_logits, mask)
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))
        t_logits_slct = torch.masked_select(t_logits, mask)
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))
        assert t_logits_slct.size() == s_logits_slct.size()

        output_loss = (
            self.output_loss(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )

        return output_loss

    def compute_hidden_loss(self, s_hiddens, t_hiddens, attention_mask, lambdas=None):
        """Cosine-embedding loss aligning student and teacher hidden states.

        NOTE(review): ``nn.CosineEmbeddingLoss`` expects (N, D) inputs —
        confirm the shape of ``hidden_states[-1]`` in the padded case.
        """
        assert s_hiddens.size() == t_hiddens.size()

        # Target of +1 means "maximize cosine similarity" for every row.
        target = s_hiddens.new_ones(s_hiddens.size(0))

        return self.hidden_loss(s_hiddens, t_hiddens, target)