import torch if not hasattr(torch.library, 'wrap_triton'): def wrap_triton(fn): return fn torch.library.wrap_triton = wrap_triton # Fix graph breaks from scalar outputs import torch._dynamo torch._dynamo.config.capture_scalar_outputs = True import torch.nn as nn import torch.nn.functional as F from dataclasses import dataclass from typing import Optional, Tuple, Union from transformers import PreTrainedModel, PretrainedConfig from transformers.modeling_outputs import MaskedLMOutput, BaseModelOutputWithPast, SequenceClassifierOutput import bert_padding from attention import FlexBertUnpadRopeAttention from torch.distributed import init_process_group, destroy_process_group from torch.nn.parallel import DistributedDataParallel as DDP import torch.distributed as dist try: from liger_kernel.transformers import LigerLayerNorm LayerNormClass = LigerLayerNorm except ImportError: LayerNormClass = nn.LayerNorm # HuggingFace-compatible Configuration class CustomTransformerConfig(PretrainedConfig): """ Configuration class for CustomTransformer model. This class stores the configuration of a CustomTransformer model and is compatible with HuggingFace's transformers library. It replaces the old ModelConfig dataclass. """ model_type = "custom_transformer" # auto_map tells HF which classes to use when loading with AutoModel/AutoConfig auto_map = { "AutoConfig": "model.CustomTransformerConfig", "AutoModel": "model.CustomTransformerModel", "AutoModelForMaskedLM": "model.CustomTransformerForMaskedLM", "AutoModelForSequenceClassification": "model.CustomTransformerForSequenceClassification", } def __init__( self, vocab_size: int = 50368, num_dims: int = 768, num_heads: int = 12, num_kv_heads: int = 12, num_layers: int = 12, ffn_hidden_dims: int = 1536, layernorm_eps: float = 1e-6, attention_probs_dropout_prob: float = 0.1, attn_qkv_bias: bool = False, attn_out_bias: bool = False, attn_out_dropout_prob: float = 0.0, global_attn_every_n_layers: int = 3, sliding_window: int = 128, rotary_emb_base: int = 10000, context_len: int = 128, use_cache: bool = False, use_flash: bool = True, use_moe: bool = True, moe_num_experts: int = 15, moe_routed_experts: int = 1, moe_eps: float = 1e-6, moe_aux_loss_coef: float = 0.01, moe_shared_experts: int = 1, use_lossfreebalance: bool = True, pad_token_id: int = 0, bos_token_id: int = 1, eos_token_id: int = 2, mask_token_id: int = 3, rope_theta: float = 1e5, ffn_dim_multiplier: Optional[int] = None, rotary_emb_dim: Optional[int] = None, local_attn_rotary_emb_base: int = -1, local_attn_rotary_emb_dim: Optional[int] = None, rotary_emb_scale_base: Optional[float] = None, rotary_emb_interleaved: bool = False, use_fa2: Optional[bool] = None, deterministic_fa2: bool = False, use_sdpa_attn_mask: bool = False, num_labels: int = 2, classifier_dropout: Optional[float] = None, **kwargs ): """Initialize CustomTransformerConfig.""" super().__init__( pad_token_id=pad_token_id, bos_token_id=bos_token_id, eos_token_id=eos_token_id, **kwargs ) self.vocab_size = vocab_size self.num_dims = num_dims self.num_heads = num_heads self.num_kv_heads = num_kv_heads self.num_layers = num_layers self.ffn_hidden_dims = ffn_hidden_dims self.layernorm_eps = layernorm_eps self.attention_probs_dropout_prob = attention_probs_dropout_prob self.attn_qkv_bias = attn_qkv_bias self.attn_out_bias = attn_out_bias self.attn_out_dropout_prob = attn_out_dropout_prob self.global_attn_every_n_layers = global_attn_every_n_layers self.sliding_window = sliding_window self.rotary_emb_base = rotary_emb_base self.context_len = context_len self.use_cache = use_cache self.use_flash = use_flash self.use_moe = use_moe self.moe_num_experts = moe_num_experts self.moe_routed_experts = moe_routed_experts self.moe_eps = moe_eps self.moe_aux_loss_coef = moe_aux_loss_coef self.moe_shared_experts = moe_shared_experts self.use_lossfreebalance = use_lossfreebalance self.mask_token_id = mask_token_id self.rope_theta = rope_theta self.ffn_dim_multiplier = ffn_dim_multiplier self.rotary_emb_dim = rotary_emb_dim self.local_attn_rotary_emb_base = local_attn_rotary_emb_base self.local_attn_rotary_emb_dim = local_attn_rotary_emb_dim self.rotary_emb_scale_base = rotary_emb_scale_base self.rotary_emb_interleaved = rotary_emb_interleaved self.use_fa2 = use_fa2 self.deterministic_fa2 = deterministic_fa2 self.use_sdpa_attn_mask = use_sdpa_attn_mask self.num_labels = num_labels self.classifier_dropout = classifier_dropout # Derived attributes for compatibility with attention module self.hidden_size = num_dims self.num_attention_heads = num_heads self.embedding_size = num_dims # Mirror old ModelConfig.__post_init__ if self.use_fa2 is None: self.use_fa2 = self.use_flash # Keep ModelConfig as a thin alias for backward compatibility with existing training scripts @dataclass class ModelConfig: vocab_size: int num_dims: int num_heads: int num_kv_heads: int num_layers: int ffn_hidden_dims: int context_len: int use_cache: bool use_flash: bool use_moe: bool moe_num_experts: int moe_routed_experts: int moe_eps: float = 1e-6 moe_aux_loss_coef: float = 0.00 moe_shared_experts: int = 0 use_lossfreebalance: bool = False layernorm_eps: float = 1e-6 rope_theta: float = 1e5 attention_probs_dropout_prob: float = 0.0 attn_qkv_bias: bool = False attn_out_bias: bool = False attn_out_dropout_prob: float = 0.0 global_attn_every_n_layers: int = 0 sliding_window: int = -1 rotary_emb_dim: Optional[int] = None rotary_emb_base: Optional[float] = None local_attn_rotary_emb_base: int = -1 local_attn_rotary_emb_dim: Optional[int] = None rotary_emb_scale_base: Optional[float] = None rotary_emb_interleaved: bool = False use_fa2: Optional[bool] = None deterministic_fa2: bool = False use_sdpa_attn_mask: bool = False hidden_size: Optional[int] = None num_attention_heads: Optional[int] = None embedding_size: Optional[int] = None ffn_dim_multiplier: Optional[int] = None def __post_init__(self): if self.hidden_size is None: self.hidden_size = self.num_dims if self.num_attention_heads is None: self.num_attention_heads = self.num_heads if self.rotary_emb_base is None: self.rotary_emb_base = self.rope_theta if self.use_fa2 is None: self.use_fa2 = self.use_flash # Model Layers class FlexBertUnpadAttention(nn.Module): """Thin wrapper that preserves the state_dict key path: block.attention.attn.* In ModernBERT-style global unpadding the data is already (total_nnz, dim) so this wrapper just forwards directly to FlexBertUnpadRopeAttention without any pad/unpad work. cu_seqlens, max_seqlen, indices, and attn_mask are passed through from the Transformer level. """ def __init__(self, config, layer_id: Optional[int] = None): super().__init__() self.attn = FlexBertUnpadRopeAttention(config=config, layer_id=layer_id) def forward( self, hidden_states: torch.Tensor, cu_seqlens: torch.Tensor, max_seqlen: int, indices: torch.Tensor, attn_mask: torch.Tensor, ) -> torch.Tensor: """Forward on already-unpadded data. Args: hidden_states: (total_nnz, dim) cu_seqlens: (batch + 1,) max_seqlen: int indices: (total_nnz,) attn_mask: (batch, seq_len) Returns: (total_nnz, dim) """ return self.attn( hidden_states=hidden_states, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask, ) class FeedForward(nn.Module): """Default Feed Forward Layer. Works on both 2D (total_nnz, dim) and 3D inputs.""" def __init__(self, config): super().__init__() self.hidden_dim = config.ffn_hidden_dims self.w1 = nn.Linear(config.num_dims, self.hidden_dim, bias=False) self.w2 = nn.Linear(self.hidden_dim, config.num_dims, bias=False) self.w3 = nn.Linear(config.num_dims, self.hidden_dim, bias=False) self.act = nn.GELU() def forward(self, x: torch.Tensor): return self.w2(self.act(self.w1(x)) * self.w3(x)), None class FFNwMoE(nn.Module): """ Feed Forward with MoE with optional shared experts. Works on 2D (total_nnz, dim) unpadded inputs. Uses batched_mm (torch.bmm) for expert dispatch. Expert weights are stored as stacked nn.Parameters: (num_experts, out_dim, in_dim). Old checkpoints with per-expert nn.Linear weights are automatically converted at load time via _load_from_state_dict. Returns after forward: output: Combined outputs from experts aux_loss: Auxiliary loss tensor or routing metadata """ def __init__(self, config): super().__init__() self.hidden_dim = config.ffn_hidden_dims self.num_dims = config.num_dims self.moe_routed_experts = config.moe_routed_experts self.moe_aux_loss_coef = config.moe_aux_loss_coef self.moe_eps = config.moe_eps self.moe_shared_experts = config.moe_shared_experts self.num_experts = config.moe_num_experts self.use_lossfreebalance = config.use_lossfreebalance self.router = nn.Linear(config.num_dims, self.num_experts, bias=False) # Stacked expert weights — the actual trainable parameters # w1: projects dim -> hidden (gate) # w2: projects hidden -> dim (down) # w3: projects dim -> hidden (up) self.w1_stacked = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, config.num_dims)) self.w2_stacked = nn.Parameter(torch.empty(self.num_experts, config.num_dims, self.hidden_dim)) self.w3_stacked = nn.Parameter(torch.empty(self.num_experts, self.hidden_dim, config.num_dims)) # Initialize for i in range(self.num_experts): nn.init.kaiming_uniform_(self.w1_stacked.data[i]) nn.init.kaiming_uniform_(self.w2_stacked.data[i]) nn.init.kaiming_uniform_(self.w3_stacked.data[i]) # shared experts (for DeepSeekMoE) self.shared_experts = nn.ModuleList() for _ in range(self.moe_shared_experts): self.shared_experts.append( nn.ModuleList([ nn.Linear(config.num_dims, self.hidden_dim, bias=False), nn.Linear(self.hidden_dim, config.num_dims, bias=False), nn.Linear(config.num_dims, self.hidden_dim, bias=False) ])) # Auxiliary-loss-free load balancing strategy for MoE (DeepSeek) if self.use_lossfreebalance: self.expert_biases = nn.Parameter(torch.zeros(self.num_experts)) def forward(self, x: torch.Tensor): # x can be (total_nnz, dim) or (batch, seq_len, dim) input_shape = x.shape if x.ndim == 3: c_batch_size, c_context_len, c_dim = input_shape x_flat = x.view(-1, c_dim) else: x_flat = x c_dim = x.shape[-1] router_out = self.router(x_flat) router_probs = F.softmax(router_out, dim=-1) _, topk_indices = router_out.topk(self.moe_routed_experts, dim=-1) self.last_topk_indices = topk_indices.detach() aux_loss, topk_probs = self._compute_aux_loss(router_out, router_probs, topk_indices) output = self._compute_expert_outputs(x_flat, topk_indices, topk_probs, router_probs) if x.ndim == 3: output = output.view(c_batch_size, c_context_len, c_dim) return output, aux_loss def _compute_aux_loss(self, router_out, router_probs, topk_indices): """Computes the auxiliary loss based on whether loss-free balancing is used or not.""" if not self.use_lossfreebalance: topk_probs, _ = router_probs.topk(self.moe_routed_experts, dim=-1) expert_mask = F.one_hot(topk_indices[:, 0], self.num_experts).float() density = expert_mask.mean(dim=0) router_prob_mean = router_probs.mean(dim=0) aux_loss = self.moe_aux_loss_coef * torch.sum(density * router_prob_mean) * self.num_experts else: router_out = router_out + self.expert_biases router_probs = torch.sigmoid(router_out) topk_probs = router_probs.gather(-1, topk_indices) topk_probs = topk_probs / topk_probs.sum(dim=-1, keepdim=True) aux_loss = (router_probs, topk_indices) return aux_loss, topk_probs def _compute_expert_outputs(self, x_flat, topk_indices, topk_probs, router_probs): """Compute expert outputs using sort-based dispatch with stacked weights. Sort tokens by expert, slice contiguous chunks, run each expert via matmul on the stacked weight tensors. No weight duplication, minimal memory overhead. """ num_tokens, dim = x_flat.shape # Flatten top-k: (num_tokens * top_k,) flat_expert_ids = topk_indices.view(-1) flat_probs = topk_probs.view(-1) flat_token_ids = torch.arange(num_tokens, device=x_flat.device).unsqueeze(1).expand(-1, self.moe_routed_experts).reshape(-1) # Sort by expert id for contiguous batching sorted_expert_ids, sort_indices = flat_expert_ids.sort(stable=True) sorted_token_ids = flat_token_ids[sort_indices] sorted_probs = flat_probs[sort_indices] # Gather sorted input tokens sorted_x = x_flat[sorted_token_ids] # (num_tokens * top_k, dim) # Find expert boundaries expert_counts = torch.bincount(sorted_expert_ids, minlength=self.num_experts) expert_offsets = torch.zeros(self.num_experts + 1, dtype=torch.long, device=x_flat.device) torch.cumsum(expert_counts, dim=0, out=expert_offsets[1:]) # Run each expert on its contiguous slice using stacked weights sorted_output = torch.zeros_like(sorted_x) for expert_id in range(self.num_experts): start = expert_offsets[expert_id].item() end = expert_offsets[expert_id + 1].item() if start == end: continue expert_input = sorted_x[start:end] # (n_tokens, dim) # Use stacked weights directly: w1[expert_id] is (hidden, dim) h1 = F.linear(expert_input, self.w1_stacked[expert_id]) # (n, hidden) h3 = F.linear(expert_input, self.w3_stacked[expert_id]) # (n, hidden) h = F.gelu(h1) * h3 sorted_output[start:end] = F.linear(h, self.w2_stacked[expert_id]) # (n, dim) # Weight by router probabilities sorted_output = sorted_output * sorted_probs.unsqueeze(-1) # Scatter back to original token positions output = torch.zeros_like(x_flat) output.scatter_add_(0, sorted_token_ids.unsqueeze(-1).expand_as(sorted_output), sorted_output) # Shared experts (for DeepSeekMoE) — unchanged for shared_expert_id in range(self.moe_shared_experts): w1, w2, w3 = self.shared_experts[shared_expert_id] expert_output = w2(F.gelu(w1(x_flat)) * w3(x_flat)) output = output + expert_output return output class Block(nn.Module): """Transformer block operating on unpadded (total_nnz, dim) tensors. Receives unpadding metadata (cu_seqlens, max_seqlen, indices, attn_mask) from the Transformer level and passes them to attention. Norms and FFN operate directly on the 2D unpadded tensor, avoiding wasted compute on padding tokens. """ def __init__(self, config, layer_id: Optional[int] = None): super().__init__() self.is_first_block = (layer_id == 0) self.attention = FlexBertUnpadAttention(config, layer_id=layer_id) if config.use_moe: self.ffn = FFNwMoE(config) else: self.ffn = FeedForward(config) self.norm_attention = LayerNormClass(config.num_dims, eps=config.layernorm_eps) self.norm_ffn = LayerNormClass(config.num_dims, eps=config.layernorm_eps) def forward(self, x, cu_seqlens, max_seqlen, indices, attn_mask): """ Args: x: (total_nnz, dim) - unpadded hidden states cu_seqlens: (batch + 1,) max_seqlen: int indices: (total_nnz,) attn_mask: (batch, seq_len) Returns: x: (total_nnz, dim) aux_loss: auxiliary loss from MoE or None """ if self.is_first_block: attn_in = x else: attn_in = self.norm_attention(x) x = x + self.attention( attn_in, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask, ) ffn_out, aux_loss = self.ffn( self.norm_ffn(x) ) x = x + ffn_out return x, aux_loss # Core Transformer (nn.Module backbone used inside HF wrappers) class Transformer(nn.Module): """ModernBERT-style Transformer: unpad once before embeddings, repad once at the end. All blocks, norms, and FFNs operate on (total_nnz, dim) tensors, avoiding wasted compute on padding tokens. """ def __init__(self, config): super().__init__() self.vocab_size = config.vocab_size self.num_dims = config.num_dims self.num_heads = config.num_heads self.context_len = config.context_len self.use_moe = config.use_moe self.use_lossfreebalance = config.use_lossfreebalance and self.use_moe self.num_layers = config.num_layers hidden_dim = 4 * config.num_dims self.tokens_embedding = nn.Embedding(config.vocab_size, config.num_dims) self.norm_embeddings = LayerNormClass(config.num_dims, eps=config.layernorm_eps) self.blocks = nn.ModuleList() for layer_id in range(self.num_layers): self.blocks.append(Block(config, layer_id=layer_id)) self.norm = LayerNormClass(config.num_dims, eps=config.layernorm_eps) self.ll_head = nn.Linear(config.num_dims, config.vocab_size, bias=False) self.tokens_embedding.weight = self.ll_head.weight def _unpad(self, input_ids, attention_mask): """Compute unpadding metadata and unpad input_ids before embedding. Unpads input_ids (cheap 1D integer indexing) so that embedding and all subsequent layers only process real tokens. Args: input_ids: (batch, seq_len) attention_mask: (batch, seq_len) or None Returns: input_ids_unpadded: (total_nnz,) indices: (total_nnz,) cu_seqlens: (batch + 1,) max_seqlen: int attn_mask: (batch, seq_len) batch_size: int seq_len: int """ batch_size, seq_len = input_ids.shape if attention_mask is None: attn_mask = torch.ones((batch_size, seq_len), device=input_ids.device, dtype=torch.int32) else: attn_mask = attention_mask.to(dtype=torch.int32) # Unpad input_ids using the same bert_padding logic but on (batch, seq_len, 1) # so we can reuse unpad_input which expects 3D input_ids_3d = input_ids.unsqueeze(-1).float() # (batch, seq_len, 1) input_ids_unpadded, indices, cu_seqlens, max_seqlen = bert_padding.unpad_input(input_ids_3d, attn_mask) input_ids_unpadded = input_ids_unpadded.squeeze(-1).long() # (total_nnz,) return input_ids_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len def forward( self, x: torch.Tensor, targets: Optional[torch.Tensor] = None, start_pos: int = 0, attention_mask: Optional[torch.Tensor] = None, ): batch_size, seq_len = x.shape # Unpad input_ids before embedding — only embed real tokens x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self._unpad(x, attention_mask) # Embed only real tokens (total_nnz, dim) x = self.tokens_embedding(x_unpadded) x = self.norm_embeddings(x) total_aux_loss = 0 for block in self.blocks: x, aux_loss = block( x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask, ) if self.use_moe and not self.use_lossfreebalance: total_aux_loss += aux_loss x = self.norm(x) # Repad once — back to (batch, seq_len, dim) for the LM head / loss x = bert_padding.pad_input(x, indices, batch_size, seq_len) logits = self.ll_head(x) if targets is None: loss = None ce_loss = None else: c_batch_size, c_context_len, c_dim = logits.shape logits = logits.view(c_batch_size * c_context_len, c_dim) targets = targets.view(c_batch_size * c_context_len) ce_loss = F.cross_entropy(logits, targets) if self.use_moe and not self.use_lossfreebalance: loss = ce_loss + total_aux_loss else: loss = ce_loss ce_loss = aux_loss return logits, loss, ce_loss @torch.no_grad() def generate(self, x: torch.Tensor, max_tokens: int, temperature: float = 1.0, top_k: int = 50, use_cache: bool = False): """Generate text from x up to max_tokens.""" for c_tkn_pos in range(max_tokens): if use_cache: if c_tkn_pos == 0: logits, _, ce_loss = self.forward(x, start_pos=c_tkn_pos) else: logits, _, ce_loss = self.forward(x[:, -1:], start_pos=c_tkn_pos) else: logits, _, ce_loss = self.forward(x) logits = logits[:, -1, :] / temperature if top_k is not None: tkl, idx = torch.topk(logits, min(top_k, logits.size(-1))) logits[logits < tkl[:, [-1]]] = -float('Inf') probs = F.softmax(logits, dim=-1) next_token = torch.multinomial(probs, num_samples=1) x = torch.cat((x, next_token), dim=1) return x # HuggingFace PreTrainedModel Wrappers class CustomTransformerPreTrainedModel(PreTrainedModel): """Base class for CustomTransformer models.""" config_class = CustomTransformerConfig base_model_prefix = "transformer" supports_gradient_checkpointing = False _no_split_modules = ["Block"] def _init_weights(self, module): """Initialize weights - handled by model itself.""" pass class CustomTransformerModel(CustomTransformerPreTrainedModel): """The bare CustomTransformer Model outputting raw hidden-states.""" def __init__(self, config: CustomTransformerConfig): super().__init__(config) self.config = config self.transformer = Transformer(config) self.post_init() def get_input_embeddings(self): return self.transformer.tokens_embedding def set_input_embeddings(self, value): self.transformer.tokens_embedding = value def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, inputs_embeds: Optional[torch.FloatTensor] = None, use_cache: Optional[bool] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, BaseModelOutputWithPast]: """Forward pass returning raw hidden states.""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Unpad input_ids before embedding x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self.transformer._unpad(input_ids, attention_mask) # Embed only real tokens x = self.transformer.tokens_embedding(x_unpadded) x = self.transformer.norm_embeddings(x) for block in self.transformer.blocks: x, _ = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask) x = self.transformer.norm(x) # Repad once hidden_states = bert_padding.pad_input(x, indices, batch_size, seq_len) if not return_dict: return (hidden_states,) return BaseModelOutputWithPast( last_hidden_state=hidden_states, past_key_values=None, hidden_states=None, attentions=None, ) class CustomTransformerForMaskedLM(CustomTransformerPreTrainedModel): """CustomTransformer Model with a masked language modeling head on top.""" _tied_weights_keys = ["transformer.ll_head.weight", "transformer.tokens_embedding.weight"] def __init__(self, config: CustomTransformerConfig): super().__init__(config) self.config = config self.transformer = Transformer(config) self.post_init() def get_input_embeddings(self): return self.transformer.tokens_embedding def set_input_embeddings(self, value): self.transformer.tokens_embedding = value def get_output_embeddings(self): return self.transformer.ll_head def set_output_embeddings(self, new_embeddings): self.transformer.ll_head = new_embeddings def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, MaskedLMOutput]: """Forward pass for masked language modeling.""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict logits, model_loss, ce_loss = self.transformer( input_ids, targets=labels, start_pos=0, attention_mask=attention_mask ) masked_lm_loss = None if labels is not None: masked_lm_loss = model_loss if not return_dict: output = (logits,) return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output return MaskedLMOutput( loss=masked_lm_loss, logits=logits, hidden_states=None, attentions=None, ) class CustomTransformerForSequenceClassification(CustomTransformerPreTrainedModel): """CustomTransformer Model with a sequence classification head on top.""" def __init__(self, config: CustomTransformerConfig): super().__init__(config) self.num_labels = config.num_labels self.config = config self.transformer = Transformer(config) # Classification head classifier_dropout = ( config.classifier_dropout if config.classifier_dropout is not None else config.attention_probs_dropout_prob ) self.dropout = nn.Dropout(classifier_dropout) self.classifier = nn.Linear(config.num_dims, config.num_labels) self._init_classifier_weights() self.post_init() def _init_classifier_weights(self): std = 0.02 if isinstance(self.classifier, nn.Linear): self.classifier.weight.data.normal_(mean=0.0, std=std) if self.classifier.bias is not None: self.classifier.bias.data.zero_() def forward( self, input_ids: Optional[torch.LongTensor] = None, attention_mask: Optional[torch.FloatTensor] = None, position_ids: Optional[torch.LongTensor] = None, head_mask: Optional[torch.FloatTensor] = None, inputs_embeds: Optional[torch.FloatTensor] = None, labels: Optional[torch.LongTensor] = None, output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, ) -> Union[Tuple, SequenceClassifierOutput]: """Forward pass for sequence classification.""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states # Unpad input_ids before embedding x_unpadded, indices, cu_seqlens, max_seqlen, attn_mask, batch_size, seq_len = self.transformer._unpad(input_ids, attention_mask) # Embed only real tokens x = self.transformer.tokens_embedding(x_unpadded) x = self.transformer.norm_embeddings(x) # Collect hidden states if requested (repad each for the output tuple) all_hidden_states = () if output_hidden_states else None if output_hidden_states: all_hidden_states = all_hidden_states + (bert_padding.pad_input(x, indices, batch_size, seq_len),) for block in self.transformer.blocks: x, _ = block(x, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen, indices=indices, attn_mask=attn_mask) if output_hidden_states: all_hidden_states = all_hidden_states + (bert_padding.pad_input(x, indices, batch_size, seq_len),) x = self.transformer.norm(x) # Repad once hidden_states = bert_padding.pad_input(x, indices, batch_size, seq_len) # Use [CLS] token representation (first token) for classification pooled_output = hidden_states[:, 0, :] pooled_output = self.dropout(pooled_output) logits = self.classifier(pooled_output) loss = None if labels is not None: if self.config.problem_type is None: if self.num_labels == 1: self.config.problem_type = "regression" elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int): self.config.problem_type = "single_label_classification" else: self.config.problem_type = "multi_label_classification" if self.config.problem_type == "regression": loss_fct = nn.MSELoss() if self.num_labels == 1: loss = loss_fct(logits.squeeze(), labels.squeeze()) else: loss = loss_fct(logits, labels) elif self.config.problem_type == "single_label_classification": loss_fct = nn.CrossEntropyLoss() loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1)) elif self.config.problem_type == "multi_label_classification": loss_fct = nn.BCEWithLogitsLoss() loss = loss_fct(logits, labels) if not return_dict: output = (logits,) + (all_hidden_states,) + (None,) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=all_hidden_states, attentions=None, ) # Auto-registration try: from transformers import AutoConfig, AutoModel, AutoModelForMaskedLM, AutoModelForSequenceClassification AutoConfig.register("custom_transformer", CustomTransformerConfig) AutoModel.register(CustomTransformerConfig, CustomTransformerModel) AutoModelForMaskedLM.register(CustomTransformerConfig, CustomTransformerForMaskedLM) AutoModelForSequenceClassification.register(CustomTransformerConfig, CustomTransformerForSequenceClassification) except Exception: pass def main(): pass if __name__ == "__main__": main()