# modernalbert-tiny-v1.0 — modeling_modernalbert.py
# Uploaded as "Upload ModernALBERTForMaskedLM" by mohammadmahdinouri (commit aeec937, verified).
from datasets import load_dataset
from transformers import (
AutoTokenizer,
AutoModelForMaskedLM,
DataCollatorForLanguageModeling,
Trainer,
TrainingArguments,
)
from itertools import chain
import torch
import transformers as ts
from optimi import StableAdamW
import os
from transformers.modeling_outputs import *
import torch.nn as nn
import torch
from dataclasses import dataclass
from typing import Optional, Tuple
import transformers as ts
import gc
from transformers import PretrainedConfig
import torch.nn.functional as F
from typing import Optional, Union
from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
from transformers.activations import ACT2FN
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
import math
# Optional FlashAttention import
try:
from flash_attn.bert_padding import unpad_input, pad_input
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
from flash_attn.layers.rotary import RotaryEmbedding
from flash_attn.ops.triton.rotary import apply_rotary
FLASH_ATTN_AVAILABLE = True
print("✅ FlashAttention is available.")
except ImportError:
FLASH_ATTN_AVAILABLE = False
print("❌ FlashAttention is not available. Using PyTorch SDPA fallback.")
from .configuration_modernalbert import ModernALBERTConfig
# --- Shared FFN (Unchanged) ---
class SharedLoraFFN(nn.Module):
    """Feed-forward network shared across experts, modulated by LoRA deltas.

    The caller supplies already-merged low-rank factor matrices; each linear
    projection is augmented with ``(x @ A.T @ B.T) * (alpha / rank)``.
    """

    def __init__(self, config):
        super().__init__()
        hidden = config.hidden_size
        inner = config.expert_intermediate_size
        self.linear1 = nn.Linear(hidden, inner)
        self.act = nn.GELU()
        self.linear2 = nn.Linear(inner, hidden)
        # Standard LoRA scaling factor applied to every low-rank contribution.
        self.lora_scaling = config.lora_alpha / config.lora_rank

    def forward(self, x, lora_A1, lora_B1, lora_A2, lora_B2):
        """Run the shared FFN with the given adapter applied to both layers.

        lora_A1/B1 adapt the up-projection, lora_A2/B2 the down-projection.
        """
        scale = self.lora_scaling
        expanded = self.linear1(x) + (x @ lora_A1.T @ lora_B1.T) * scale
        activated = self.act(expanded)
        return self.linear2(activated) + (activated @ lora_A2.T @ lora_B2.T) * scale
# --- 1. The Router (Fixed for Flash Attention / Unpadded Inputs) ---
class SwitchRouterTopK(nn.Module):
    """
    Calculates the EMA weights for expert merging.
    Optimized for unpadded (Flash Attention) inputs where shape is (total_nnz, dim).

    Two modes, selected by ``config.routing_strategy``:
      * "ema": returns a single global per-expert weight vector maintained as an
        exponential moving average of batch-mean router probabilities.
      * otherwise: classic token-level top-k routing; returns the chosen expert
        indices, renormalized top-k probabilities, and a load-balancing loss.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.num_experts = config.num_experts
        # Use a slightly lower decay during training if starting from scratch (e.g., 0.9)
        self.ema_decay = getattr(config, "router_ema_decay", 0.99)
        # Linear gate: hidden state -> per-expert logits (bias-free).
        self.layer = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.k = config.top_k
        self.jitter_noise = config.router_jitter_noise
        # Buffer for inference (frozen stats); starts uniform over experts.
        self.register_buffer("ema_weights", torch.ones(config.num_experts) / config.num_experts)

    def forward(self, hidden_states):
        # hidden_states shape: (total_nnz, hidden_size)
        if self.config.routing_strategy == "ema":
            # 1. Compute Router Probabilities
            logits = self.layer(hidden_states)  # Shape: (total_nnz, num_experts)
            probs = F.softmax(logits, dim=-1)
            if self.training:
                # 2. Compute batch-level routing vector r_b
                # Since inputs are unpadded (Batch * Seq flattened to dim 0),
                # we simply average across all tokens to get the global batch stats.
                r_b = probs.mean(dim=0)  # Shape: (num_experts,)
                # 3. Calculate the weight to USE for this step (Allow Gradients!)
                # We mix history (detached) with current (with grad) to stabilize training.
                weights_for_forward = self.ema_decay * self.ema_weights.detach() + (1 - self.ema_decay) * r_b
                # 4. Update the buffer in the background (No Gradients needed for storage)
                new_ema_value = weights_for_forward.detach()
                self.ema_weights.copy_(new_ema_value)
                # Normalize to ensure sum is 1
                self.ema_weights.div_(self.ema_weights.sum() + 1e-9)
                # NOTE(review): the returned (grad-carrying) vector is the
                # pre-normalization mix; only the stored buffer is renormalized.
                return weights_for_forward
            # During inference, return the frozen stable weights
            return self.ema_weights
        else:
            num_tokens = hidden_states.shape[0]
            # if self.training and self.jitter_noise > 0:
            #     noise = torch.randn_like(hidden_states) * self.jitter_noise
            #     hidden_states = hidden_states + noise
            logits = self.layer(hidden_states)
            # Softmax in float32 for a numerically stable aux loss.
            probs = F.softmax(logits, dim=-1, dtype=torch.float32)
            topk_probs, topk_indices = torch.topk(probs, k=self.k, dim=-1)
            # Renormalize the k selected probabilities to sum to 1 per token.
            topk_probs_normalized = topk_probs / torch.sum(topk_probs, dim=-1, keepdim=True)
            # Load Balancing for K = 1
            # flat_topk_indices = topk_indices.flatten()
            # one_hot_assignments = F.one_hot(flat_topk_indices, num_classes=self.num_experts).float()
            # tokens_per_expert_fraction = one_hot_assignments.sum(0) / num_tokens
            # print(tokens_per_expert_fraction)
            # router_prob_per_expert = torch.mean(probs, dim=0)
            # Load Balancing for K > 1
            one_hot = F.one_hot(topk_indices, num_classes=self.num_experts).float()
            # Probability-weighted fraction of routing mass landing on each expert.
            tokens_per_expert = torch.sum(one_hot * topk_probs.unsqueeze(-1), dim=(0, 1)) / num_tokens
            router_prob_per_expert = torch.mean(probs, dim=0)
            # Switch-Transformer-style auxiliary loss: pushes toward uniform routing.
            aux_loss = self.num_experts * torch.mean(tokens_per_expert * router_prob_per_expert)
            # --- DEBUG: print expert utilization for this batch ---
            # print("Expert utilization (fraction of tokens per expert):", tokens_per_expert.detach().cpu().numpy())
            # print(aux_loss)
            return topk_indices, topk_probs_normalized, aux_loss
# --- 2. The MoE Layer (Minor cleanup for debug prints) ---
class LoraMoELayerTopK(nn.Module):
    """
    Implements the MoL layer with expert merging.
    Allows for efficient dense computation by collapsing experts
    into a single adapter based on router weights.

    Three modes mirror SwitchRouterTopK: "ema" and "uniform" merge all expert
    LoRA deltas into one adapter and run a single dense FFN pass; any other
    strategy performs token-level top-k dispatch across experts.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        dim = config.hidden_size
        expert_intermediate_dim = config.expert_intermediate_size
        num_experts = config.num_experts
        lora_rank = config.lora_rank
        self.k = config.top_k
        self.num_experts = num_experts
        self.norm = nn.LayerNorm(dim, eps=config.layer_norm_eps)
        self.router = SwitchRouterTopK(config)
        self.shared_ffn = SharedLoraFFN(config)
        # The pool of Expert LoRA weights {\Delta_1, ..., \Delta_E}
        # B-factors start at zero so every expert's initial delta is the zero
        # matrix (standard LoRA initialization).
        self.lora_A1 = nn.Parameter(torch.randn(num_experts, lora_rank, dim))
        self.lora_B1 = nn.Parameter(torch.zeros(num_experts, expert_intermediate_dim, lora_rank))
        self.lora_A2 = nn.Parameter(torch.randn(num_experts, lora_rank, expert_intermediate_dim))
        self.lora_B2 = nn.Parameter(torch.zeros(num_experts, dim, lora_rank))
        # Initialization (Kaiming Uniform)
        for i in range(num_experts):
            nn.init.kaiming_uniform_(self.lora_A1[i], a=math.sqrt(5))
            nn.init.kaiming_uniform_(self.lora_A2[i], a=math.sqrt(5))

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (total_nnz, hidden_size) when running unpadded.
        if self.config.routing_strategy == "ema":
            residual = hidden_states
            hidden_states_norm = self.norm(hidden_states)
            # 1. Get the global merging weights (w_t)
            # Returns shape: (num_experts,)
            merge_weights = self.router(hidden_states_norm)
            # 2. Weighted Merge of all LoRA parameters
            # Formula: \Delta_{merged} = \sum_{j=1}^E w_j * \Delta_j
            # We reshape weights to [Experts, 1, 1] for broadcasting against [Experts, Rank, Dim]
            w = merge_weights.view(-1, 1, 1)
            merged_A1 = torch.sum(w * self.lora_A1, dim=0)
            merged_B1 = torch.sum(w * self.lora_B1, dim=0)
            merged_A2 = torch.sum(w * self.lora_A2, dim=0)
            merged_B2 = torch.sum(w * self.lora_B2, dim=0)
            # 3. Dense Forward Pass
            # Pass the merged adapter to the FFN.
            output = self.shared_ffn(
                hidden_states_norm,
                merged_A1, merged_B1,
                merged_A2, merged_B2
            )
            # We return 0.0 for aux_loss because we are not doing load balancing in this mode
            return residual + output, torch.tensor(0.0, device=hidden_states.device)
        elif self.config.routing_strategy == "uniform":
            residual = hidden_states
            hidden_states_norm = self.norm(hidden_states)
            # 1. Get the global merging weights (w_t)
            # Returns shape: (num_experts,) — fixed uniform 1/E, router bypassed.
            merge_weights = torch.ones(self.config.num_experts, dtype=hidden_states_norm.dtype, device=hidden_states_norm.device) / (self.config.num_experts)
            # 2. Weighted Merge of all LoRA parameters
            # Formula: \Delta_{merged} = \sum_{j=1}^E w_j * \Delta_j
            # We reshape weights to [Experts, 1, 1] for broadcasting against [Experts, Rank, Dim]
            w = merge_weights.view(-1, 1, 1)
            merged_A1 = torch.sum(w * self.lora_A1, dim=0)
            merged_B1 = torch.sum(w * self.lora_B1, dim=0)
            merged_A2 = torch.sum(w * self.lora_A2, dim=0)
            merged_B2 = torch.sum(w * self.lora_B2, dim=0)
            # 3. Dense Forward Pass
            # Pass the merged adapter to the FFN.
            output = self.shared_ffn(
                hidden_states_norm,
                merged_A1, merged_B1,
                merged_A2, merged_B2
            )
            # We return 0.0 for aux_loss because we are not doing load balancing in this mode
            return residual + output, torch.tensor(0.0, device=hidden_states.device)
        else:
            residual = hidden_states
            hidden_states_norm = self.norm(hidden_states)
            num_tokens, dim = hidden_states_norm.shape
            topk_indices, topk_probs, aux_loss = self.router(hidden_states_norm)
            # Efficient permutation-based dispatch: replicate each token k times,
            # sort the replicas by assigned expert, and process per-expert slices.
            flat_token_indices = torch.arange(num_tokens, device=hidden_states.device).repeat_interleave(self.k)
            flat_expert_indices = topk_indices.flatten()
            perm_indices = torch.argsort(flat_expert_indices)
            sorted_token_indices = flat_token_indices[perm_indices]
            sorted_expert_indices = flat_expert_indices[perm_indices]
            permuted_tokens = hidden_states_norm[sorted_token_indices]
            permuted_probs = topk_probs.flatten()[perm_indices]
            tokens_per_expert = F.one_hot(sorted_expert_indices, self.num_experts).sum(dim=0)
            split_tokens = torch.split(permuted_tokens, tokens_per_expert.tolist(), dim=0)
            split_probs = torch.split(permuted_probs, tokens_per_expert.tolist(), dim=0)
            # Batched processing loop over experts
            expert_outputs = []
            for i in range(self.num_experts):
                if tokens_per_expert[i] > 0:
                    output = self.shared_ffn(
                        split_tokens[i],
                        self.lora_A1[i], self.lora_B1[i],
                        self.lora_A2[i], self.lora_B2[i]
                    )
                    # Weight each expert output by its (renormalized) router prob.
                    expert_outputs.append(output * split_probs[i].unsqueeze(1))
                else:
                    # NOTE(review): torch.empty defaults to float32; if hidden
                    # states run in fp16/bf16 the torch.cat below would mix
                    # dtypes — confirm against the training dtype.
                    expert_outputs.append(torch.empty(0, dim, device=hidden_states.device))
            # Un-permute and combine results
            concatenated_outputs = torch.cat(expert_outputs, dim=0)
            inverse_perm_indices = torch.argsort(perm_indices)
            unpermuted_outputs = concatenated_outputs[inverse_perm_indices]
            # Sum the k expert contributions back into one vector per token.
            final_output = unpermuted_outputs.view(num_tokens, self.k, dim).sum(dim=1)
            # Final residual connection
            output = residual + final_output
            return output, aux_loss
class ModernAlbertMLP(nn.Module):
    """Gated MLP in the ModernBERT ``Wi``/``Wo`` layout.

    A single fused up-projection emits both the activation input and the gate
    (chunked along the last dim), GeGLU-style: ``Wo(drop(act(in) * gate))``.
    """

    def __init__(self, config: ModernALBERTConfig):
        super().__init__()
        self.config = config
        width = config.hidden_size
        inner = int(config.intermediate_size)
        # One matmul produces both halves of the gated activation.
        self.Wi = nn.Linear(width, inner * 2, bias=False)
        # self.act = ACT2FN[config.hidden_activation]
        self.act = ACT2FN["gelu"]
        self.drop = nn.Dropout(config.hidden_dropout_prob)
        self.Wo = nn.Linear(config.intermediate_size, width, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        gated_input, gate = self.Wi(hidden_states).chunk(2, dim=-1)
        activated = self.act(gated_input) * gate
        return self.Wo(self.drop(activated))
#Flash Attention Rotatory Embedding
class ApplyRotaryEmbUnpad(torch.autograd.Function):
    """Autograd wrapper around flash-attn's fused ``apply_rotary`` kernel for
    unpadded (varlen) packed-QKV tensors. Rotation is applied in place to the
    query/key slices only; values pass through untouched."""

    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
    ):
        # (total_nnz, 3, nheads, headdim)
        qkv = qkv.contiguous()
        total_nnz, _three, _nheads, headdim = qkv.shape
        # We need qkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        # qk = rearrange(qkv[:, :2], "b_s t h d -> b_s (t h) d")
        qk = qkv[:, :2].view(total_nnz, -1, headdim)
        apply_rotary(
            qk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            interleaved=False,
            inplace=True,
        )
        ctx.save_for_backward(cos, sin, cu_seqlens)
        ctx.max_seqlen = max_seqlen
        return qkv

    @staticmethod
    def backward(ctx, do):
        cos, sin, cu_seqlens = ctx.saved_tensors
        do = do.contiguous()
        total_nnz, _three, _nheads, headdim = do.shape
        # We need dqkv to be contiguous so that when we reshape to combine (3, nheads) dimensions,
        # we get the same tensor
        dqk = do[:, :2].view(total_nnz, -1, headdim)
        # conjugate=True applies the inverse rotation, i.e. the gradient of
        # the forward rotation, in place on the incoming grad.
        apply_rotary(
            dqk,
            cos,
            sin,
            seqlen_offsets=0,
            cu_seqlens=cu_seqlens,
            max_seqlen=ctx.max_seqlen,
            interleaved=False,
            inplace=True,
            conjugate=True,
        )
        # NOTE(review): 6 None slots are returned for 4 non-ctx extra inputs
        # (cos, sin, cu_seqlens, max_seqlen); surplus trailing Nones appear to
        # be tolerated by autograd — confirm against the PyTorch version in use.
        return do, None, None, None, None, None, None
def apply_rotary_unpadded(
    qkv,
    cos,
    sin,
    cu_seqlens: Optional[torch.Tensor] = None,
    max_seqlen: Optional[int] = None,
):
    """Apply rotary position embeddings to an unpadded packed-QKV tensor.

    Thin functional wrapper over :class:`ApplyRotaryEmbUnpad`.

    Arguments:
        qkv: (total_nnz, 3, nheads, headdim) packed query/key/value tensor.
        cos, sin: (seqlen_rotary, rotary_dim / 2) rotation tables; rotary_dim
            must be <= headdim, and only the first rotary_dim of each head is
            rotated (GPT-NeoX style halves, not interleaved).
        cu_seqlens: (batch + 1,) cumulative sequence lengths, or None.
        max_seqlen: max sequence length in the batch.
    Returns:
        The rotated qkv tensor (rotation is performed in place by the kernel).
    """
    return ApplyRotaryEmbUnpad.apply(qkv, cos, sin, cu_seqlens, max_seqlen)
class ModernAlbertUnpaddedRotaryEmbedding(RotaryEmbedding):
    """
    The rotary position embeddings applied directly to unpadded sequences.

    Subclasses flash-attn's RotaryEmbedding to reuse its cos/sin cache, but
    applies the rotation through the varlen (cu_seqlens-aware) kernel.
    """

    def __init__(
        self,
        dim: int,
        base: float = 10000.0,
        max_seqlen: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        """
        max_seqlen: if max_seqlen, device, and dtype are provided, we precompute the cos_sin_cache
        up to max_seqlen. If the max_seqlen, device, or dtype during training/inference differ,
        the cos_sin_cache will be recomputed during the forward pass.
        """
        super().__init__(dim=dim, base=base, device=device, interleaved=False)
        self.max_seqlen = max_seqlen
        # Eagerly build the cache only when we already know where (device) and
        # in which dtype it will be consumed.
        if max_seqlen is not None and device is not None and dtype is not None:
            self._update_cos_sin_cache(max_seqlen, device=device, dtype=dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        cu_seqlens: torch.Tensor,
        max_seqlen: Optional[int] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
        """
        Apply rotary embedding *inplace* to qkv.
        qkv: (total_nnz, 3, nheads, headdim)
        cu_seqlens: (batch + 1,) cumulative sequence lengths
        max_seqlen: int max seq length in the batch
        """
        if max_seqlen is not None:
            # Refresh the cache for the current batch's device/dtype/length.
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        qkv = apply_rotary_unpadded(
            qkv,
            self._cos_cached,
            self._sin_cached,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
        )
        return qkv

    def extra_repr(self) -> str:
        return f"dim={self.dim}, base={self.base}, scale_base={self.scale_base}"
class ModernAlbertRotaryEmbedding(nn.Module):
    """HF-style rotary embedding for the padded (SDPA fallback) path: returns
    per-position cos/sin tables computed in float32."""

    def __init__(self, config: ModernALBERTConfig, device=None):
        super().__init__()
        # BC: "rope_type" was originally "type"
        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
            self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
        else:
            self.rope_type = "default"
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings
        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Kept so dynamic-RoPE variants can restore the original frequencies.
        self.original_inv_freq = self.inv_freq

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        # (batch, dim/2, 1) @ (batch, 1, seq) -> transposed to (batch, seq, dim/2)
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            # Duplicate frequencies so cos/sin cover the full head dimension.
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling
        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
# GeGLU unchanged
class GeGLU(nn.Module):
    """Gated GELU unit: ``gelu(w1 x) * (w2 x)`` with two independent projections."""

    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.w1 = nn.Linear(dim_in, dim_out)  # activation branch
        self.w2 = nn.Linear(dim_in, dim_out)  # gating branch

    def forward(self, x):
        gate = self.w2(x)
        return F.gelu(self.w1(x)) * gate
#Flash Attention
def _unpad_modernbert_input(
inputs: torch.Tensor,
attention_mask: torch.Tensor,
position_ids: Optional[torch.Tensor] = None,
labels: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, Optional[torch.Tensor], Optional[torch.Tensor]]:
"""
Remove padding from input sequences.
Args:
inputs: (batch, seqlen, ...) or (batch, seqlen)
attention_mask: (batch, seqlen), bool / int, 1 means valid and 0 means not valid.
position_ids: (batch, seqlen), int, position ids
labels: (batch, seqlen), int, labels
Returns:
unpadded_inputs: (total_nnz, ...), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
cu_seqlens: (batch + 1), the cumulative sequence lengths
max_seqlen_in_batch: int
unpadded_position_ids: (total_nnz) or None
unpadded_labels: (total_nnz) or None
"""
seqlens_in_batch = attention_mask.sum(dim=-1, dtype=torch.int32)
indices = torch.nonzero(attention_mask.flatten(), as_tuple=False).flatten()
max_seqlen_in_batch = int(seqlens_in_batch.max().item())
cu_seqlens = torch.nn.functional.pad(torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0))
if inputs.dim() == 2:
unpadded_inputs = inputs.flatten()[indices]
else:
batch, seqlen, *rest = inputs.shape
shape = batch * seqlen
unpadded_inputs = inputs.view(shape, *rest)[indices]
unpadded_position_ids = position_ids.flatten()[indices] if position_ids is not None else None
unpadded_labels = labels.flatten()[indices] if labels is not None else None
return unpadded_inputs, indices, cu_seqlens, max_seqlen_in_batch, unpadded_position_ids, unpadded_labels
def _pad_modernbert_output(
inputs: torch.Tensor,
indices: torch.Tensor,
batch: int,
seqlen: int,
) -> torch.Tensor:
"""
Add padding to sequences.
Args:
inputs: (total_nnz, ...) or (total_nnz,), where total_nnz = number of tokens selected in attention_mask.
indices: (total_nnz)
batch: int, batch size
seqlen: int, max sequence length
Returns:
padded_inputs: (batch, seqlen, ...) or (batch, seqlen)
"""
if inputs.dim() == 1:
output = torch.zeros(batch * seqlen, dtype=inputs.dtype, device=inputs.device)
output[indices] = inputs
padded_inputs = output.view(batch, seqlen)
else:
_, *rest = inputs.shape
output = torch.zeros(batch * seqlen, *rest, dtype=inputs.dtype, device=inputs.device)
output[indices] = inputs
padded_inputs = output.view(batch, seqlen, *rest)
return padded_inputs
def flash_attention_forward(
    module: "SharedGroup",
    qkv: torch.Tensor,
    rotary_emb: ModernAlbertUnpaddedRotaryEmbedding,
    cu_seqlens: torch.Tensor,
    max_seqlen: int,
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    target_dtype: torch.dtype = torch.bfloat16,
    **_kwargs,
) -> tuple[torch.Tensor]:
    """Varlen FlashAttention over packed QKV of shape (total_seqlen, 3, nheads, headdim).

    Rotary embeddings are applied first. FA2 only supports fp16/bf16, so any
    other dtype is cast to ``target_dtype`` for the kernel and the output is
    cast back afterwards. Returns a 1-tuple of the output viewed as (bs, dim).
    """
    qkv = rotary_emb(qkv, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)

    original_dtype = qkv.dtype
    # FA2 implementation only supports fp16 and bf16. If FA2 is supported,
    # bfloat16 must be supported as of FA2 2.5.7. (Turing GPUs not supported)
    needs_cast = original_dtype not in (torch.float16, torch.bfloat16)
    if needs_cast:
        qkv = qkv.to(target_dtype)

    attn = flash_attn_varlen_qkvpacked_func(
        qkv,
        cu_seqlens=cu_seqlens,
        max_seqlen=max_seqlen,
        dropout_p=module.att_dropout.p if module.training else 0.0,
        window_size=local_attention,
    )

    if needs_cast:
        attn = attn.to(original_dtype)
    return (attn.view(bs, dim),)
def _rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotate the last dimension's halves: (x1, x2) -> (-x2, x1) (NeoX style)."""
    half = x.shape[-1] // 2
    x1, x2 = x[..., :half], x[..., half:]
    return torch.cat((-x2, x1), dim=-1)


def _apply_rotary_pos_emb(query, key, cos, sin):
    """Apply rotary position embeddings to padded query/key tensors.

    query/key: (batch, heads, seq_len, head_dim);
    cos/sin: (batch, seq_len, head_dim), unsqueezed to broadcast over heads.
    """
    cos = cos.unsqueeze(1)
    sin = sin.unsqueeze(1)
    q_embed = (query * cos) + (_rotate_half(query) * sin)
    k_embed = (key * cos) + (_rotate_half(key) * sin)
    return q_embed, k_embed


def sdpa_attention_forward(
    module: "SharedGroup",
    qkv: torch.Tensor,
    attention_mask: torch.Tensor,
    sliding_window_mask: torch.Tensor,
    position_ids: Optional[torch.LongTensor],
    local_attention: tuple[int, int],
    bs: int,
    dim: int,
    **_kwargs,
) -> tuple[torch.Tensor]:
    """PyTorch SDPA fallback attention for padded inputs.

    Args:
        module: the owning SharedGroup (provides `rotary_emb`, `att_dropout`).
        qkv: [batch_size, seqlen, 3, nheads, headdim] fused projection.
        attention_mask: broadcastable additive/bool mask for SDPA, or None.
        sliding_window_mask: mask used instead when local attention is active.
        local_attention: (-1, -1) means global attention.
    Returns:
        1-tuple with the attention output of shape (bs, seqlen, dim).
    """
    # qkv: [batch_size, seqlen, 3, nheads, headdim]
    cos, sin = module.rotary_emb(qkv, position_ids=position_ids)
    query, key, value = qkv.transpose(3, 1).unbind(dim=2)
    # query, key, value: [batch_size, heads, seq_len, head_dim]
    # BUGFIX: `apply_rotary_pos_emb` was never defined/imported in this file
    # (NameError at runtime); use the local helper instead.
    query, key = _apply_rotary_pos_emb(query, key, cos, sin)
    if local_attention != (-1, -1):
        attention_mask = sliding_window_mask
    attn_output = (
        F.scaled_dot_product_attention(
            query,
            key,
            value,
            # BUGFIX: SharedGroup names its dropout module `att_dropout`;
            # `module.attention_dropout` raised AttributeError.
            dropout_p=module.att_dropout.p if module.training else 0.0,
            attn_mask=attention_mask,
        )
        .transpose(1, 2)
        .contiguous()
    )
    attn_output = attn_output.view(bs, -1, dim)
    return (attn_output,)
class SharedGroup(nn.Module):
    """One weight-shared transformer group (ALBERT-style sharing).

    `forward` applies the same attention + FFN weights `config.group_depth`
    times. Attention runs either through FlashAttention (unpadded/varlen
    inputs) or the SDPA fallback (padded inputs), chosen at import time.
    """

    def __init__(self, config):  # config: ModernALBERTConfig
        super().__init__()
        self.config = config
        hs, nh = config.hidden_size, config.num_attention_heads
        self.head_dim = hs // nh
        self.num_heads = nh
        self.use_adapter = config.use_adapter
        eps = config.layer_norm_eps
        rope_theta = 10000
        # Pre-norms for the attention and FFN sub-blocks
        self.att_pre_norm = nn.LayerNorm(hs, eps=eps)
        self.ffn_pre_norm = nn.LayerNorm(hs, eps=eps)
        # Attention: fused QKV projection + output projection
        self.qkv = nn.Linear(hs, 3 * hs)
        self.out_proj = nn.Linear(hs, hs)
        self.att_dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.local_attention = (-1, -1)  # (-1, -1) == global attention
        if FLASH_ATTN_AVAILABLE:
            self.rotary_emb = ModernAlbertUnpaddedRotaryEmbedding(
                dim=self.head_dim, max_seqlen=config.max_position_embeddings, base=rope_theta
            )
        else:
            # BUGFIX: `copy` was used here without ever being imported,
            # raising NameError on the SDPA fallback path.
            import copy

            config_copy = copy.deepcopy(config)
            config_copy.rope_theta = rope_theta
            self.rotary_emb = ModernAlbertRotaryEmbedding(config=config_copy)
        # FFN
        self.mlp = ModernAlbertMLP(config)

    def forward(self, inputs, mask, config, start_idx=0, use_moa=False, **kwargs):
        """Run `config.group_depth` shared layers over `inputs`.

        Args:
            inputs: hidden states — (total_nnz, hs) unpadded or (batch, seq, hs) padded.
            mask: (batch, seq) boolean attention mask, or None (SDPA path only).
            config: model config (depth / output flags read per call).
            start_idx: unused; kept for interface compatibility.
            use_moa: when True, skip the FFN of the LAST iteration so the
                caller's MoA layer can replace it.
            **kwargs: forwarded to the attention implementation
                (cu_seqlens/max_seqlen for flash, sliding_window_mask/
                position_ids for SDPA).
        Returns:
            (hidden_states, collected_hidden_states_or_None, attn_maps_or_None)
        """
        outputs = [] if config.output_hidden_states else None
        attn_maps = [] if config.output_attentions else None
        x = inputs
        for i in range(config.group_depth):
            h = x
            h_norm = self.att_pre_norm(h)
            qkv_proj = self.qkv(h_norm)
            bs = h.shape[0]
            # --- Attention Calculation ---
            if FLASH_ATTN_AVAILABLE:
                qkv = qkv_proj.view(-1, 3, self.num_heads, self.head_dim)
                attn_outputs = flash_attention_forward(
                    self,
                    qkv=qkv,
                    rotary_emb=self.rotary_emb,
                    local_attention=self.local_attention,
                    bs=bs,
                    dim=self.head_dim * self.num_heads,
                    **kwargs,
                )
            else:  # Fallback to PyTorch Scaled Dot Product Attention
                qkv = qkv_proj.view(bs, -1, 3, self.num_heads, self.head_dim)
                # BUGFIX: the broadcastable mask was built but never handed to
                # the attention call, which also made `attention_mask` a
                # missing required argument of sdpa_attention_forward.
                attn_mask = mask[:, None, None, :] if mask is not None else None
                attn_outputs = sdpa_attention_forward(
                    self,
                    qkv=qkv,
                    attention_mask=attn_mask,
                    local_attention=self.local_attention,
                    bs=bs,
                    dim=self.head_dim * self.num_heads,
                    **kwargs,
                )
            attn_out = attn_outputs[0]
            # Residual around attention (dropout on the projected output).
            x = self.att_dropout(self.out_proj(attn_out)) + h
            if use_moa and i == config.group_depth - 1:
                # Leave the final FFN slot to the caller's MoA layer.
                return x, outputs, attn_maps
            # FFN block with residual
            h2 = x
            h2_norm = self.ffn_pre_norm(h2)
            x = self.mlp(h2_norm) + h2
            # Collect hidden state if needed
            if config.output_hidden_states:
                outputs.append(x)
        return x, outputs, attn_maps
class ModernAlbertEmbeddings(nn.Module):
    """
    Factorized ALBERT-style embeddings: token lookup in a small embedding
    space, projected up to the model width, then LayerNorm and dropout
    (dropout probability is fixed at 0.0).
    """

    def __init__(self, config: ModernALBERTConfig):
        super().__init__()
        self.config = config
        self.tok_embeddings = nn.Embedding(
            config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id
        )
        # Up-projection from the factorized embedding width to hidden_size.
        self.embed_proj = nn.Linear(config.embedding_size, config.hidden_size)
        self.norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
        self.drop = nn.Dropout(0.0)

    def forward(
        self, input_ids: Optional[torch.LongTensor] = None, inputs_embeds: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Embed pre-computed vectors when given, otherwise token ids."""
        embedded = inputs_embeds if inputs_embeds is not None else self.tok_embeddings(input_ids)
        return self.drop(self.norm(self.embed_proj(embedded)))
@dataclass
class MoABaseModelOutput(BaseModelOutput):
    """BaseModelOutput extended with the MoE auxiliary routing loss."""

    # Mean auxiliary (load-balancing) loss over all MoA layers, already scaled
    # by `load_balancing_loss_coef`; None when no MoA layer produced one.
    load_balancing_loss: Optional[torch.FloatTensor] = None
class ModernALBERTModel(ts.PreTrainedModel):
    """ModernALBERT encoder: factorized embeddings, weight-shared groups, and
    optional MoA (LoRA mixture-of-experts) layers after the last groups.

    With FlashAttention available, inputs are unpadded (varlen) internally and
    re-padded on output when this model did the unpadding itself.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "modernAlbert"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.config = config
        # Factorized embeddings
        self.embeddings = ModernAlbertEmbeddings(config)
        self.final_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, bias=False)
        # Each group holds one set of shared weights applied group_depth times.
        self.num_groups = config.num_hidden_layers // config.group_depth
        self.groups = nn.ModuleList([SharedGroup(config) for _ in range(self.num_groups)])
        if config.use_moa:
            # MoA layers replace the FFN of the final iteration of the last
            # `num_expert_modules` groups (see forward).
            self.moa_layers = nn.ModuleList(
                [LoraMoELayerTopK(config) for _ in range(config.num_expert_modules)]
            )
        self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
        self.post_init()

    def forward(self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ):
        """Encode a batch; returns a MoABaseModelOutput.

        Callers may pass pre-unpadded tensors (indices/cu_seqlens/max_seqlen);
        otherwise unpadding (and final re-padding) is handled here when
        FlashAttention is available.
        """
        all_hidden_states = []
        if batch_size is None and seq_len is None:
            if inputs_embeds is not None:
                batch_size, seq_len = inputs_embeds.shape[:2]
            else:
                batch_size, seq_len = input_ids.shape[:2]
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        if output_hidden_states:
            # Groups read this flag from the shared config object.
            self.config.output_hidden_states = True
        hs, atts = ([] if output_hidden_states else None), ([] if output_attentions else None)
        all_aux_losses = []
        repad = False  # did WE unpad (and therefore must re-pad the output)?
        if FLASH_ATTN_AVAILABLE:
            if indices is None and cu_seqlens is None and max_seqlen is None:
                repad = True
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, *_ = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask
                    )
        else:
            # BUGFIX: `device` was referenced below without ever being defined,
            # raising NameError on the SDPA fallback path.
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            if position_ids is None:
                position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
            # NOTE(review): `_update_attention_mask` is not defined on this
            # class (it exists on HF ModernBERT); this SDPA branch will raise
            # AttributeError until that helper is ported — confirm intent.
            attention_mask, sliding_window_mask = self._update_attention_mask(
                attention_mask, output_attentions=output_attentions
            )
        hidden_states = self.embeddings(input_ids=input_ids, inputs_embeds=inputs_embeds)
        x = hidden_states
        if output_hidden_states:
            hs.append(x)
        # Boolean mask for the attention layers (None when no mask was given).
        mask = None
        if attention_mask is not None:
            mask = attention_mask.to(torch.bool)
        for i, group in enumerate(self.groups):
            # The last len(moa_layers) groups end with a MoA layer in place of
            # their final FFN.
            is_moa = self.config.use_moa and (i > len(self.groups) - len(self.moa_layers) - 1)
            moa_idx = i - (len(self.groups) - len(self.moa_layers))
            x, layer_hs, layer_atts = group(
                x,
                mask,
                self.config,
                sliding_window_mask=sliding_window_mask,
                position_ids=position_ids,
                cu_seqlens=cu_seqlens,
                max_seqlen=max_seqlen,
                use_moa=is_moa,
                output_attentions=output_attentions,
            )
            if output_hidden_states and layer_hs:
                hs.extend(layer_hs)
            if output_attentions and layer_atts:
                atts.extend(layer_atts)
            if self.config.use_moa and is_moa:
                x, aux_loss = self.moa_layers[moa_idx](x)
                if output_hidden_states:
                    hs.append(x)
                all_aux_losses.append(aux_loss)
        hidden_states = self.final_norm(x)
        if repad:
            hidden_states = _pad_modernbert_output(
                inputs=hidden_states, indices=indices, batch=batch_size, seqlen=seq_len
            )
            if all_hidden_states is not None:
                # NOTE(review): `all_hidden_states` is never populated above, so
                # this only produces an empty tuple — confirm whether it was
                # meant to re-pad `hs` instead.
                all_hidden_states = tuple(
                    _pad_modernbert_output(inputs=hs, indices=indices, batch=batch_size, seqlen=seq_len)
                    for hs in all_hidden_states
                )
        load_balancing_loss = None
        if all_aux_losses:
            load_balancing_loss = torch.stack(all_aux_losses).mean() * self.config.load_balancing_loss_coef
        return MoABaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=hs,
            attentions=atts,
            load_balancing_loss=load_balancing_loss,
        )
class ModernAlbertPredictionHead(nn.Module):
    """MLM transform head: project hidden_size -> embedding_size, apply GELU,
    then LayerNorm — feeding the (tied) vocabulary decoder."""

    def __init__(self, config: ModernALBERTConfig):
        super().__init__()
        self.config = config
        self.dense = nn.Linear(config.hidden_size, config.embedding_size, bias=False)
        self.act = ACT2FN["gelu"]
        self.norm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps, bias=False)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        projected = self.dense(hidden_states)
        return self.norm(self.act(projected))
class ModernALBERTForMaskedLM(ts.PreTrainedModel):
    """
    Modern ALBERT model with a Masked Language Modeling (MLM) head,
    optimized to mirror the HuggingFace `AlbertForMaskedLM` API.
    """

    _tied_weights_keys = ["decoder.weight"]
    config_class = ModernALBERTConfig
    base_model_prefix = "modernAlbert"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.config = config
        # Base encoder without pooling
        self.albert = ModernALBERTModel(config)
        # MLM head: hidden -> embedding space, then decoder to vocab logits.
        self.head = ModernAlbertPredictionHead(config)
        self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.albert.embeddings.tok_embeddings

    def get_output_embeddings(self):
        return self.decoder

    def set_output_embeddings(self, new_embeddings: nn.Linear):
        self.decoder = new_embeddings

    @torch.compile(dynamic=True)
    def compiled_head(self, output: torch.Tensor) -> torch.Tensor:
        # Compiled variant of the MLM head. NOTE(review): `forward` currently
        # calls head/decoder directly instead of this — confirm which is intended.
        return self.decoder(self.head(output))

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ):
        """MLM forward pass.

        When FlashAttention is available, inputs (and labels) are unpadded
        here so the encoder can skip the re-padding round-trip; logits are
        padded back to (batch, seq_len, vocab) before returning. The MoE
        load-balancing loss is added to the MLM loss during training.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if FLASH_ATTN_AVAILABLE:
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device
                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, labels = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=labels
                    )
        # Encode
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]
        if FLASH_ATTN_AVAILABLE:
            # Re-padded copy appended so hidden_states consumers see the
            # final layer in (batch, seq, hidden) layout.
            # (Renamed from the misleading `last_hidden_state_unpaded`.)
            last_hidden_state_padded = _pad_modernbert_output(
                inputs=last_hidden_state, indices=indices, batch=batch_size, seqlen=seq_len
            )
            # BUGFIX: identity comparison with None (`is not None`, not `!=`).
            if outputs.hidden_states is not None:
                outputs.hidden_states.append(last_hidden_state_padded)
        logits = self.decoder(self.head(last_hidden_state))
        loss = None
        if labels is not None:
            loss = self.loss_function(logits, labels, vocab_size=self.config.vocab_size, **kwargs)
            # BUGFIX: identity comparison with None (`is not None`, not `!=`).
            if outputs.load_balancing_loss is not None and self.training:
                loss += outputs.load_balancing_loss
        if FLASH_ATTN_AVAILABLE:
            # Pad logits back to the original (batch, seq_len, vocab) shape.
            logits = _pad_modernbert_output(inputs=logits, indices=indices, batch=batch_size, seqlen=seq_len)
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return MaskedLMOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class ModernALBERTForSequenceClassification(ts.PreTrainedModel):
    """ModernALBERT encoder topped with a mean-pooled classification head.

    The final hidden states are mean-pooled over non-padding tokens, passed
    through the shared prediction head and dropout, then projected to
    ``num_labels`` logits.
    """

    config_class = ModernALBERTConfig

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config
        self.albert = ModernALBERTModel(config)
        self.head = ModernAlbertPredictionHead(config)
        self.drop = torch.nn.Dropout(0.0)
        self.classifier = nn.Linear(config.embedding_size, config.num_labels)
        # Initialize weights and apply final processing
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        inputs_embeds: Optional[torch.Tensor] = None,
        labels: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
        """Encode, mean-pool over valid tokens, and classify.

        The loss type follows ``config.problem_type``; when unset it is
        inferred from ``num_labels`` and the label dtype (regression /
        single-label / multi-label), matching the standard HF heuristic.
        Returns a ``SequenceClassifierOutput`` or, when ``return_dict`` is
        False, a ``(loss?, logits)`` tuple.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        if FLASH_ATTN_AVAILABLE:
            # FlashAttention consumes an unpadded (flattened) batch; build the
            # unpadding metadata unless the caller already supplied it.
            if indices is None and cu_seqlens is None and max_seqlen is None:
                if batch_size is None and seq_len is None:
                    if inputs_embeds is not None:
                        batch_size, seq_len = inputs_embeds.shape[:2]
                    else:
                        batch_size, seq_len = input_ids.shape[:2]
                device = input_ids.device if input_ids is not None else inputs_embeds.device
                if attention_mask is None:
                    attention_mask = torch.ones((batch_size, seq_len), device=device, dtype=torch.bool)
                if inputs_embeds is None:
                    with torch.no_grad():
                        input_ids, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
                            inputs=input_ids, attention_mask=attention_mask, position_ids=position_ids, labels=None
                        )
                else:
                    inputs_embeds, indices, cu_seqlens, max_seqlen, position_ids, _ = _unpad_modernbert_input(
                        inputs=inputs_embeds, attention_mask=attention_mask, position_ids=position_ids, labels=None
                    )
        outputs = self.albert(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # NOTE(review): sliding_window_mask is accepted but intentionally
            # not forwarded here (it was commented out) — confirm.
            position_ids=position_ids,
            inputs_embeds=inputs_embeds,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        last_hidden_state = outputs[0]
        # Fix: re-pad only when the batch was actually unpadded. Previously
        # this was unconditional and crashed when FlashAttention was
        # unavailable (indices is None and the output is already padded).
        if indices is not None:
            last_hidden_state = _pad_modernbert_output(
                inputs=last_hidden_state, indices=indices, batch=batch_size, seqlen=seq_len
            )
        # Fix: mean pooling needs a mask even on the non-FlashAttention path,
        # where the default mask above was never built; fall back to all-ones.
        if attention_mask is None:
            attention_mask = torch.ones(
                last_hidden_state.shape[:2], device=last_hidden_state.device, dtype=torch.bool
            )
        # Mean pooling over non-padding tokens (cls pooling is not used here).
        last_hidden_state = (last_hidden_state * attention_mask.unsqueeze(-1)).sum(dim=1) / attention_mask.sum(
            dim=1, keepdim=True
        )
        pooled_output = self.head(last_hidden_state)
        pooled_output = self.drop(pooled_output)
        logits = self.classifier(pooled_output)
        loss = None
        if labels is not None:
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"
            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(logits, labels)
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
class ModernALBERTForQuestionAnswering(ts.PreTrainedModel):
    """ModernALBERT with a span-extraction head for extractive QA.

    The encoder output is passed through the shared prediction head and a
    linear layer producing two logits per token (span start and span end).
    """

    config_class = ModernALBERTConfig

    def __init__(self, config: ModernALBERTConfig):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.albert = ModernALBERTModel(config)
        self.head = ModernAlbertPredictionHead(config)
        self.drop = torch.nn.Dropout(0.0)
        self.classifier_head = nn.Linear(config.embedding_size, config.num_labels)
        self.post_init()

    def forward(
        self,
        input_ids: Optional[torch.Tensor],
        attention_mask: Optional[torch.Tensor] = None,
        sliding_window_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        start_positions: Optional[torch.Tensor] = None,
        end_positions: Optional[torch.Tensor] = None,
        indices: Optional[torch.Tensor] = None,
        cu_seqlens: Optional[torch.Tensor] = None,
        max_seqlen: Optional[int] = None,
        batch_size: Optional[int] = None,
        seq_len: Optional[int] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
        """Compute per-token start/end logits (and the span loss if both
        ``start_positions`` and ``end_positions`` are provided)."""
        if return_dict is None:
            return_dict = self.config.use_return_dict
        encoder_outputs = self.albert(
            input_ids,
            attention_mask=attention_mask,
            sliding_window_mask=sliding_window_mask,
            position_ids=position_ids,
            indices=indices,
            cu_seqlens=cu_seqlens,
            max_seqlen=max_seqlen,
            batch_size=batch_size,
            seq_len=seq_len,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        # Head + dropout on every token position, then project to 2 logits.
        sequence_output = self.drop(self.head(encoder_outputs[0]))
        span_logits = self.classifier_head(sequence_output)
        start_logits, end_logits = (
            piece.squeeze(-1).contiguous() for piece in span_logits.split(1, dim=-1)
        )
        total_loss = None
        if start_positions is not None and end_positions is not None:
            # Standard HF QA loss (cross-entropy over start and end indices).
            total_loss = self.loss_function(start_logits, end_logits, start_positions, end_positions, **kwargs)
        if not return_dict:
            output = (start_logits, end_logits) + encoder_outputs[1:]
            if total_loss is None:
                return output
            return (total_loss,) + output
        return QuestionAnsweringModelOutput(
            loss=total_loss,
            start_logits=start_logits,
            end_logits=end_logits,
            hidden_states=encoder_outputs.hidden_states,
            attentions=encoder_outputs.attentions,
        )
# Distillation
@dataclass
class DistillationOutputWithPasts(ModelOutput):
    """Output container returned by the distillation wrappers below.

    Only ``loss``, ``logits``, ``hidden_states`` and ``attentions`` are
    populated by the wrappers in this file; the remaining fields (past key
    values and the depth/audio entries) are never set here — presumably kept
    for interface compatibility with another pipeline. TODO confirm they can
    be removed.
    """

    # Combined student + distillation loss (or plain student loss when no
    # teacher is attached).
    loss: Optional[torch.FloatTensor] = None
    # Student prediction logits.
    logits: Optional[torch.FloatTensor] = None
    last_hidden_state: Optional[torch.FloatTensor] = None
    past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    # Student hidden states (one tensor per layer), when requested.
    hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_loss: Optional[torch.FloatTensor] = None
    audio_logits: Optional[torch.FloatTensor] = None
    depth_past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
    depth_hidden_states: Optional[Tuple[torch.FloatTensor, ...]] = None
    depth_attentions: Optional[Tuple[torch.FloatTensor, ...]] = None
class DistillationWrapper(ts.PreTrainedModel):
    """Trains a ModernALBERT MLM student against an optional frozen teacher.

    When a teacher is attached, the training loss is
    ``1.0 * student_mlm_loss + 3.0 * hidden_cosine_loss + 5.0 * kl_logit_loss``;
    without one it reduces to the plain student MLM loss.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "model"
    _no_split_modules = ["LlamaDecoderLayer", "FlowDecoderLayerGroup", "MimiTransformerLayer"]
    _keys_to_ignore_on_load_missing = ["speech_tokenizer", "teacher"]
    _tied_weights_keys = ["llm.decoder.weight"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _tp_plan = []

    def __init__(self, config, student=None, teacher=None):
        super().__init__(config)
        # NOTE(review): the `student` argument is accepted but ignored — the
        # student is always built fresh from `config`. Confirm intent.
        self.llm = ModernALBERTForMaskedLM(config)
        self.teacher = teacher  # optional frozen teacher model
        self.attention_loss = nn.KLDivLoss(reduction="mean")  # currently unused
        self.hidden_loss = nn.CosineEmbeddingLoss(reduction="mean")
        self.output_loss = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 1.0  # softmax temperature for logit distillation

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the student (and, if attached, the teacher) and combine losses."""
        device = self.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        student_outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
            **kwargs,
        )
        if self.teacher is not None:
            # Teacher runs without gradients, possibly on a different device.
            with torch.no_grad():
                input_ids = input_ids.to(self.teacher.device)
                attention_mask = attention_mask.to(self.teacher.device)
                teacher_outputs = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                    **kwargs,
                )
            s_hiddens = student_outputs.hidden_states[-1]
            t_hiddens = teacher_outputs.hidden_states[-1].detach()
            s_logits = student_outputs.logits
            t_logits = teacher_outputs.logits.detach()
            hidden_loss = self.compute_hidden_loss(s_hiddens, t_hiddens, attention_mask)
            output_loss = self.compute_output_loss(s_logits, t_logits, labels)
            total_loss = (1.0 * student_outputs.loss) + (3.0 * hidden_loss) + (5.0 * output_loss)
        else:
            total_loss = student_outputs.loss
        return DistillationOutputWithPasts(
            loss=total_loss,
            logits=student_outputs.logits,
            hidden_states=student_outputs.hidden_states,
            attentions=student_outputs.attentions,
        )

    def compute_output_loss(self, s_logits, t_logits, labels):
        """Temperature-scaled KL(student || teacher) over positions whose
        label is > -1 (i.e. the masked positions in MLM training).

        The previous masked_fill-then-select zeroed out ignored rows before
        discarding them; selecting the valid rows directly is equivalent and
        skips the extra pass.
        """
        valid_rows = (labels > -1).view(-1)
        s_logits_slct = s_logits.view(-1, s_logits.size(-1))[valid_rows]
        t_logits_slct = t_logits.view(-1, t_logits.size(-1))[valid_rows]
        output_loss = (
            self.output_loss(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        return output_loss

    def compute_hidden_loss(self, s_hiddens, t_hiddens, attention_mask, lambdas=None):
        """Cosine-embedding loss pulling student final hidden states toward
        the teacher's.

        NOTE(review): `attention_mask` and `lambdas` are currently unused, so
        padding positions also contribute — confirm this is intended.
        """
        assert s_hiddens.size() == t_hiddens.size()
        # Target of 1 means "make the pair similar" for CosineEmbeddingLoss.
        target = s_hiddens.new(s_hiddens.size(0)).fill_(1)
        return self.hidden_loss(s_hiddens, t_hiddens, target)
class DistillationWrapperForSequenceClassification(ts.PreTrainedModel):
    """Distillation wrapper around ModernALBERTForSequenceClassification.

    With a teacher attached, trains with
    ``1.0 * student_loss + 3.0 * hidden_cosine_loss + 5.0 * kl_logit_loss``;
    otherwise falls back to the plain student classification loss.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "model"
    _no_split_modules = ["LlamaDecoderLayer", "FlowDecoderLayerGroup", "MimiTransformerLayer"]
    _keys_to_ignore_on_load_missing = ["speech_tokenizer", "teacher"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.llm = ModernALBERTForSequenceClassification(config)
        # Attach a frozen teacher externally to enable distillation.
        self.teacher = None
        self.attention_loss = nn.KLDivLoss(reduction="mean")  # currently unused
        self.hidden_loss = nn.CosineEmbeddingLoss(reduction="mean")
        self.output_loss = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 1.0  # softmax temperature for logit distillation

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the student (and, if attached, the teacher) and combine losses.

        NOTE(review): `**kwargs` is intentionally not forwarded to either
        model (it was commented out in the original) — confirm.
        """
        device = self.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        student_outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )
        if self.teacher is not None:
            # Teacher runs without gradients, possibly on a different device.
            with torch.no_grad():
                input_ids = input_ids.to(self.teacher.device)
                attention_mask = attention_mask.to(self.teacher.device)
                teacher_outputs = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )
            s_hiddens = student_outputs.hidden_states[-1]
            t_hiddens = teacher_outputs.hidden_states[-1].detach()
            s_logits = student_outputs.logits
            t_logits = teacher_outputs.logits.detach()
            hidden_loss = self.compute_hidden_loss(s_hiddens, t_hiddens, attention_mask)
            output_loss = self.compute_output_loss(s_logits, t_logits, labels)
            total_loss = (1.0 * student_outputs.loss) + (3.0 * hidden_loss) + (5.0 * output_loss)
        else:
            total_loss = student_outputs.loss
        return DistillationOutputWithPasts(
            loss=total_loss,
            logits=student_outputs.logits,
            hidden_states=student_outputs.hidden_states,
            attentions=student_outputs.attentions,
        )

    def compute_output_loss(self, s_logits, t_logits, labels):
        """Temperature-scaled KL between student and teacher class logits,
        restricted to examples whose label is > -1."""
        mask = (labels > -1).unsqueeze(-1).expand_as(s_logits).bool()
        s_logits_slct = torch.masked_select(s_logits, mask)
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))
        t_logits_slct = torch.masked_select(t_logits, mask)
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))
        assert t_logits_slct.size() == s_logits_slct.size()
        output_loss = (
            self.output_loss(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        return output_loss

    def compute_hidden_loss(self, s_hiddens, t_hiddens, attention_mask, lambdas=None):
        """Cosine-embedding loss between student and teacher hidden states.

        NOTE(review): `attention_mask` and `lambdas` are unused (masked
        selection was commented out in the original), so padding positions
        also contribute — confirm this is intended.
        """
        assert s_hiddens.size() == t_hiddens.size()
        # Target of 1 means "make the pair similar" for CosineEmbeddingLoss.
        target = s_hiddens.new(s_hiddens.size(0)).fill_(1)
        return self.hidden_loss(s_hiddens, t_hiddens, target)
class DistillationWrapperForQuestionAnswering(ts.PreTrainedModel):
    """Distillation wrapper around ModernALBERTForQuestionAnswering.

    NOTE(review): unlike the other wrappers, this one returns the raw student
    outputs — the combined distillation loss is computed but discarded (the
    `DistillationOutputWithPasts` return was commented out). The trainer
    therefore optimizes `student_outputs.loss` only. Confirm intent.
    """

    config_class = ModernALBERTConfig
    base_model_prefix = "model"
    _no_split_modules = ["LlamaDecoderLayer", "FlowDecoderLayerGroup", "MimiTransformerLayer"]
    _keys_to_ignore_on_load_missing = ["speech_tokenizer", "teacher"]
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.llm = ModernALBERTForQuestionAnswering(config)
        # Attach a frozen teacher externally to enable distillation.
        self.teacher = None
        self.attention_loss = nn.KLDivLoss(reduction="mean")  # currently unused
        self.hidden_loss = nn.CosineEmbeddingLoss(reduction="mean")
        self.output_loss = nn.KLDivLoss(reduction="batchmean")
        self.temperature = 1.0  # softmax temperature for logit distillation

    @property
    def device(self):
        return next(self.parameters()).device

    def forward(self, input_ids=None, attention_mask=None, labels=None, **kwargs):
        """Run the student (and, if attached, the teacher); return the
        student's own outputs.

        NOTE(review): `**kwargs` is not forwarded, so `start_positions` /
        `end_positions` never reach the QA student and `student_outputs.loss`
        may be None; with a teacher attached and `labels` None, the
        `compute_output_loss` call below would fail — confirm call sites.
        """
        device = self.device
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        if labels is not None:
            labels = labels.to(device)
        student_outputs = self.llm(
            input_ids=input_ids,
            attention_mask=attention_mask,
            labels=labels,
            output_hidden_states=True,
        )
        if self.teacher is not None:
            # Teacher runs without gradients, possibly on a different device.
            with torch.no_grad():
                input_ids = input_ids.to(self.teacher.device)
                attention_mask = attention_mask.to(self.teacher.device)
                teacher_outputs = self.teacher(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    output_hidden_states=True,
                )
            s_hiddens = student_outputs.hidden_states[-1]
            t_hiddens = teacher_outputs.hidden_states[-1].detach()
            s_logits = student_outputs.logits
            t_logits = teacher_outputs.logits.detach()
            # Computed for parity with the other wrappers, but not returned.
            hidden_loss = self.compute_hidden_loss(s_hiddens, t_hiddens, attention_mask)
            output_loss = self.compute_output_loss(s_logits, t_logits, labels)
            total_loss = (1.0 * student_outputs.loss) + (3.0 * hidden_loss) + (5.0 * output_loss)
        else:
            total_loss = student_outputs.loss
        return student_outputs

    def compute_output_loss(self, s_logits, t_logits, labels):
        """Temperature-scaled KL between student and teacher logits,
        restricted to positions whose label is > -1."""
        mask = (labels > -1).unsqueeze(-1).expand_as(s_logits).bool()
        s_logits_slct = torch.masked_select(s_logits, mask)
        s_logits_slct = s_logits_slct.view(-1, s_logits.size(-1))
        t_logits_slct = torch.masked_select(t_logits, mask)
        t_logits_slct = t_logits_slct.view(-1, s_logits.size(-1))
        assert t_logits_slct.size() == s_logits_slct.size()
        output_loss = (
            self.output_loss(
                nn.functional.log_softmax(s_logits_slct / self.temperature, dim=-1),
                nn.functional.softmax(t_logits_slct / self.temperature, dim=-1),
            )
            * (self.temperature) ** 2
        )
        return output_loss

    def compute_hidden_loss(self, s_hiddens, t_hiddens, attention_mask, lambdas=None):
        """Cosine-embedding loss between student and teacher hidden states.

        NOTE(review): `attention_mask` and `lambdas` are unused (masked
        selection was commented out in the original), so padding positions
        also contribute — confirm this is intended.
        """
        assert s_hiddens.size() == t_hiddens.size()
        # Target of 1 means "make the pair similar" for CosineEmbeddingLoss.
        target = s_hiddens.new(s_hiddens.size(0)).fill_(1)
        return self.hidden_loss(s_hiddens, t_hiddens, target)