"""EmCoder encoder model with a sequence-classification head and MC-Dropout inference."""

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoConfig, AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

from .configuration_emcoder import EmCoderConfig
from .rope_embeddings import RotaryEmbedding


class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learnable gain.

    Args:
        dim: Size of the last (feature) dimension; one gain value is learned per feature.
        eps: Constant added to the mean square before taking the reciprocal square root.
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` divided by the RMS of its last dimension and scaled by ``weight``."""
        mean_square = x.pow(2).mean(-1, keepdim=True)
        return x * torch.rsqrt(mean_square + self.eps) * self.weight


class SwiGLU(nn.Module):
    """Gated feed-forward block: ``wo(x1 * silu(x2))`` where ``x1, x2`` split ``wi(x)``.

    Args:
        d_model: Input and output feature size.
        d_ffn: Hidden size of each of the two halves produced by ``wi``.
    """

    def __init__(self, d_model: int, d_ffn: int):
        super().__init__()
        # A single projection yields both the value half and the gate half.
        self.wi = nn.Linear(d_model, 2 * d_ffn, bias=False)
        self.wo = nn.Linear(d_ffn, d_model, bias=False)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated feed-forward transform along the last dimension."""
        x1, x2 = self.wi(x).chunk(2, dim=-1)
        return self.wo(x1 * F.silu(x2))


class EmCoderEncoderLayer(nn.Module):
    """Custom Pre-LN Transformer encoder layer with RoPE.

    Attention is computed with ``F.scaled_dot_product_attention``; the feed-forward
    part is a :class:`SwiGLU`. Both sub-layers normalise their input with
    :class:`RMSNorm` and add their (dropped-out) output back to the residual stream.

    Args:
        config: Supplies ``d_model``, ``n_head``, ``d_ffn`` and ``dropout``.
        rope: Rotary embedding object; it is stored as given, so one instance
            can be passed to several layers.
    """

    def __init__(self, config: EmCoderConfig, rope: RotaryEmbedding):
        super().__init__()
        self.n_head = config.n_head
        self.d_head = config.d_model // config.n_head
        self.rope = rope

        # Attention projections
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.k_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.v_proj = nn.Linear(config.d_model, config.d_model, bias=False)
        self.out_proj = nn.Linear(config.d_model, config.d_model, bias=False)

        self.ln1 = RMSNorm(config.d_model)
        self.ln2 = RMSNorm(config.d_model)
        self.ffn = SwiGLU(config.d_model, config.d_ffn)
        self.dropout = nn.Dropout(config.dropout)

        # mark for initialization: EmCoder._init_weights uses a smaller std
        # for layers carrying this flag.
        self.out_proj._is_residual = True
        self.ffn.wo._is_residual = True

    def forward(self, x: torch.Tensor, attn_mask: torch.Tensor) -> torch.Tensor:
        """Run self-attention followed by the feed-forward block.

        Args:
            x: Hidden states of shape (B, S, d_model).
            attn_mask: Mask forwarded unchanged to
                ``F.scaled_dot_product_attention``.

        Returns:
            Updated hidden states with the same shape as ``x``.
        """
        # MULTI-HEAD ATTENTION
        residual = x
        nx = self.ln1(x)
        B, S, _ = nx.shape

        # Projections -> (B, H, S, D_head)
        q = self.q_proj(nx).view(B, S, self.n_head, self.d_head).transpose(1, 2)
        k = self.k_proj(nx).view(B, S, self.n_head, self.d_head).transpose(1, 2)
        v = self.v_proj(nx).view(B, S, self.n_head, self.d_head).transpose(1, 2)

        q = self.rope.rotate_queries_or_keys(q)
        k = self.rope.rotate_queries_or_keys(k)

        attn_out = F.scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attn_mask,
            # Follow the nn.Dropout module's own mode so that MC-Dropout
            # (dropout switched on in an otherwise eval model) also applies here.
            dropout_p=self.dropout.p if self.dropout.training else 0.0,
        )

        # Join heads -> (B, S, D_model)
        attn_out = attn_out.transpose(1, 2).contiguous().view(B, S, -1)
        x = residual + self.dropout(self.out_proj(attn_out))

        # FEED-FORWARD
        x = x + self.dropout(self.ffn(self.ln2(x)))
        return x


class EmCoderEncoder(nn.Module):
    """The core encoder architecture of EmCoder Transformer.

    Token embedding -> RMSNorm -> dropout -> ``config.n_layers`` encoder layers
    (all sharing one rotary-embedding object) -> final RMSNorm.
    """

    def __init__(self, config: EmCoderConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.embed_norm = RMSNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

        self.rope = RotaryEmbedding(dim=config.d_model // config.n_head)
        self.layers = nn.ModuleList(
            [EmCoderEncoderLayer(config, self.rope) for _ in range(config.n_layers)]
        )
        self.final_norm = RMSNorm(config.d_model)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass through the encoder.

        Args:
            x: Token IDs of shape (B, S).
            mask: Padding mask of shape (B, S). It is converted to bool, so any
                non-zero entry becomes ``True``.

        Returns:
            Normalised hidden states of shape (B, S, d_model).
        """
        x = self.token_embedding(x)
        x = self.embed_norm(x)
        x = self.dropout(x)

        B, S = mask.shape
        # reshape (rather than view) so a non-contiguous mask, e.g. a slice of
        # a larger tensor, is accepted as well.
        attn_mask = mask.reshape(B, 1, 1, S).to(dtype=torch.bool)

        for layer in self.layers:
            x = layer(x, attn_mask)

        return self.final_norm(x)


class EmCoder(PreTrainedModel):
    """The full EmCoder model, including the backbone encoder and the classification head."""

    config_class = EmCoderConfig

    def __init__(self, config: EmCoderConfig):
        super().__init__(config)
        self.encoder = EmCoderEncoder(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.num_labels),
        )
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one sub-module (truncated normal for weights, zeros for biases)."""
        if isinstance(module, nn.Linear):
            # scale down the init for residual connections
            if getattr(module, "_is_residual", False):
                std = 0.02 / ((2 * self.config.n_layers) ** 0.5)
            else:
                std = 0.02
            nn.init.trunc_normal_(module.weight, std=std)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, std=0.02)
        elif isinstance(module, RMSNorm):
            nn.init.ones_(module.weight)

    def _set_mc_dropout(self, active: bool = True) -> None:
        """Put every ``nn.Dropout`` sub-module into train (``True``) or eval mode."""
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train(active)

    @staticmethod
    def _masked_mean_pooling(
        features: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Average ``features`` over the sequence axis, weighting positions by ``mask``.

        Args:
            features: Hidden states of shape (B, S, D).
            mask: Per-token weights of shape (B, S); any dtype convertible to
                the dtype of ``features`` (bool, integer or floating point).

        Returns:
            Pooled tensor of shape (B, D) in the dtype of ``features``. A row
            whose mask is all zero pools to zeros.
        """
        weights = mask.unsqueeze(-1).to(features.dtype)  # (B, S, 1)
        summed = (features * weights).sum(dim=1)  # (B, D)
        # Count in float32: the 1e-9 floor is not representable in fp16, and
        # a count taken in the mask's own integer dtype would pull the result
        # away from the feature dtype expected by the classifier.
        counts = weights.sum(dim=1, dtype=torch.float32).clamp(min=1e-9)  # (B, 1)
        return (summed / counts).to(features.dtype)  # (B, D)

    @staticmethod
    def _resolve_inputs(
        input_ids: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        extra: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Pick token IDs and mask, falling back to the ``x`` / ``mask`` keyword aliases."""
        x = input_ids if input_ids is not None else extra.get("x")
        mask = attention_mask if attention_mask is not None else extra.get("mask")
        if x is None or mask is None:
            raise ValueError("input_ids (x) and attention_mask (mask) must be provided")
        return x, mask

    @staticmethod
    def _classification_loss(
        logits: torch.Tensor, labels: torch.Tensor | None
    ) -> torch.Tensor | None:
        """Return the BCE-with-logits loss, or ``None`` when no labels were given."""
        if labels is None:
            return None
        return nn.BCEWithLogitsLoss()(logits, labels.to(logits.dtype))

    @staticmethod
    def _format_output(
        loss: torch.Tensor | None, logits: torch.Tensor, return_dict: bool
    ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
        """Wrap ``loss`` and ``logits`` as a tuple or a ``SequenceClassifierOutput``."""
        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output
        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def mc_forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        n_samples: int = 10,
        max_batch_size: int | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
        """
        Performs Monte Carlo Dropout inference to quantify uncertainty.

        Every ``nn.Dropout`` is switched to train mode for the duration of the
        call and put back into its previous mode afterwards; the forward passes
        run under ``torch.no_grad()``.

        Args:
            input_ids: Input token IDs of shape (B, S). ``x=`` is accepted as an alias.
            attention_mask: Attention mask of shape (B, S). ``mask=`` is accepted
                as an alias.
            labels: Optional targets; when given, a BCE-with-logits loss is
                computed on the logits averaged over the samples.
            n_samples: Total number of Monte Carlo samples (must be >= 1).
            max_batch_size: Maximum number of samples in one forward pass
                (must be >= 1). Defaults to ``n_samples``.
            return_dict: Return a ``SequenceClassifierOutput`` when ``True`` or
                ``None``; a plain tuple when ``False``.

        Returns:
            ``SequenceClassifierOutput`` (or ``(loss, logits)`` / ``(logits,)``
            tuple) whose logits have shape (n_samples, B, num_labels).

        Raises:
            ValueError: If the inputs are missing, or ``n_samples`` /
                ``max_batch_size`` is smaller than 1.
        """
        return_dict = return_dict if return_dict is not None else True
        x, mask = self._resolve_inputs(input_ids, attention_mask, kwargs)

        if n_samples < 1:
            raise ValueError("n_samples must be a positive integer")
        if max_batch_size is None:
            max_batch_size = n_samples
        if max_batch_size < 1:
            # A non-positive step would either make range() fail or skip the
            # loop entirely and hand back the uninitialised buffer below.
            raise ValueError("max_batch_size must be a positive integer")

        B, S = x.shape
        num_labels = self.classifier[-1].out_features
        all_logits = torch.empty((n_samples, B, num_labels), device=x.device)

        # Remember each dropout module's own mode so that it is restored
        # exactly, even if it differed from the model-wide training flag.
        dropouts = [m for m in self.modules() if isinstance(m, nn.Dropout)]
        previous_modes = [m.training for m in dropouts]
        self._set_mc_dropout(active=True)

        try:
            with torch.no_grad():
                for i in range(0, n_samples, max_batch_size):
                    batch_samples = min(max_batch_size, n_samples - i)

                    # Stack several stochastic passes into one larger batch.
                    x_stacked = x.repeat(batch_samples, 1)  # (batch_samples * B, S)
                    mask_stacked = mask.repeat(batch_samples, 1)  # (batch_samples * B, S)

                    features = self.encoder(
                        x_stacked, mask_stacked
                    )  # (batch_samples * B, S, D)
                    pooled = self._masked_mean_pooling(features, mask_stacked)
                    logits = self.classifier(pooled)  # (batch_samples * B, num_labels)

                    all_logits[i : i + batch_samples] = logits.view(batch_samples, B, -1)
        finally:
            for module, was_training in zip(dropouts, previous_modes):
                module.train(was_training)

        loss = self._classification_loss(all_logits.mean(dim=0), labels)
        return self._format_output(loss, all_logits, return_dict)

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        labels: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> tuple[torch.Tensor, ...] | SequenceClassifierOutput:
        """Standard forward pass without MC Dropout.

        Args:
            input_ids: Input token IDs of shape (B, S). ``x=`` is accepted as an alias.
            attention_mask: Attention mask of shape (B, S). ``mask=`` is accepted
                as an alias.
            labels: Optional targets for a BCE-with-logits loss.
            return_dict: Return a ``SequenceClassifierOutput`` when ``True`` or
                ``None``; a plain tuple when ``False``.

        Returns:
            ``SequenceClassifierOutput`` (or ``(loss, logits)`` / ``(logits,)``
            tuple) whose logits have shape (B, num_labels).

        Raises:
            ValueError: If token IDs or the mask are missing.
        """
        return_dict = return_dict if return_dict is not None else True
        x, mask = self._resolve_inputs(input_ids, attention_mask, kwargs)

        features = self.encoder(x, mask)
        pooled = self._masked_mean_pooling(features, mask)
        logits = self.classifier(pooled)

        loss = self._classification_loss(logits, labels)
        return self._format_output(loss, logits, return_dict)


# Best-effort registration with the Auto* factories. Each call is attempted on
# its own so that a ValueError from the first (e.g. on a repeated import) does
# not skip the second.
try:
    AutoConfig.register("emcoder", EmCoderConfig)
except ValueError:
    pass

try:
    AutoModel.register(EmCoderConfig, EmCoder)
except ValueError:
    pass