import torch
import torch.nn as nn
from transformers import PreTrainedModel, AutoConfig, AutoModel

from .configuration_emcoder import EmCoderConfig


class EmCoderEncoder(nn.Module):
    """The core encoder architecture of EmCoder Transformer.

    Token embeddings plus learned absolute position embeddings, followed by a
    stack of pre-norm (``norm_first=True``) ``nn.TransformerEncoderLayer``
    blocks and a final LayerNorm.
    """

    def __init__(self, config: EmCoderConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Embedding(config.max_seq_len, config.d_model)
        self.embed_norm = nn.LayerNorm(config.d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_head,
            dim_feedforward=config.d_ffn,
            dropout=config.dropout,
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        # Padding is handled through src_key_padding_mask, so the nested-tensor
        # conversion inside nn.TransformerEncoder is switched off.
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config.n_layers,
            enable_nested_tensor=False,
        )
        self.final_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Encode a batch of token IDs.

        Args:
            x: Token IDs of shape (B, S).
            mask: Attention mask of shape (B, S); positions equal to 0 are
                treated as padding.

        Returns:
            Final-normalized hidden states of shape (B, S, d_model).

        Raises:
            ValueError: If S exceeds the number of learned positions.
        """
        seq_len = x.size(1)
        max_len = self.pos_embedding.num_embeddings
        if seq_len > max_len:
            # Fail with a readable message instead of an IndexError (CPU) or a
            # device-side assert (CUDA) from the position-embedding lookup.
            raise ValueError(
                f"Sequence length {seq_len} exceeds max_seq_len {max_len}"
            )
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, S)
        x = self.token_embedding(x) + self.pos_embedding(pos_ids)
        x = self.embed_norm(x)
        x = self.dropout(x)
        # nn.TransformerEncoder expects True at the positions it must ignore.
        padding_mask = mask == 0
        encoded = self.encoder(x, src_key_padding_mask=padding_mask)
        return self.final_norm(encoded)


class EmCoder(PreTrainedModel):
    """The full EmCoder model: encoder, masked mean pooling, classification head."""

    config_class = EmCoderConfig

    # Module types switched into training mode for MC Dropout.
    # nn.TransformerEncoderLayer must be included: when it is in eval mode,
    # PyTorch may route it through a fused inference "fast path" kernel (e.g.
    # under torch.no_grad()) that applies no dropout at all, so toggling only
    # the inner nn.Dropout / nn.MultiheadAttention modules would leave every
    # Monte Carlo sample identical.
    _MC_DROPOUT_MODULES = (
        nn.Dropout,
        nn.MultiheadAttention,
        nn.TransformerEncoderLayer,
    )

    def __init__(self, config: EmCoderConfig):
        super().__init__(config)
        self.encoder = EmCoderEncoder(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.num_labels),
        )
        self.post_init()

    def _init_weights(self, module: nn.Module) -> None:
        """Initialise one submodule (Hugging Face init hook, run via ``post_init``).

        Linear and Embedding weights get a truncated normal with std 0.02,
        Linear biases and an Embedding's padding row are zeroed, and LayerNorms
        are reset to the identity transform.
        """
        # NOTE(review): nn.MultiheadAttention keeps its fused input projection
        # in a bare ``in_proj_weight`` Parameter (not an nn.Linear), so it is
        # not re-initialised here — confirm this is intended.
        if isinstance(module, nn.Linear):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if module.bias is not None:
                nn.init.zeros_(module.bias)
        elif isinstance(module, nn.Embedding):
            nn.init.trunc_normal_(module.weight, std=0.02)
            if hasattr(module, "padding_idx") and module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()
        elif isinstance(module, nn.LayerNorm):
            nn.init.ones_(module.weight)
            nn.init.zeros_(module.bias)

    def _set_mc_dropout(self, active: bool = True) -> None:
        """Put every dropout-bearing submodule into train (active) or eval mode.

        Only modules listed in ``_MC_DROPOUT_MODULES`` (and their children) are
        switched; the rest of the model keeps its current mode. The children of
        a TransformerEncoderLayer also include LayerNorm/Linear modules, which
        behave the same in either mode.
        """
        for module in self.modules():
            if isinstance(module, self._MC_DROPOUT_MODULES):
                module.train(active)

    @staticmethod
    def _masked_mean_pooling(
        features: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """Average ``features`` over the sequence axis, ignoring masked positions.

        Args:
            features: Hidden states of shape (B, S, D).
            mask: Mask of shape (B, S); 0 marks padding. Any dtype (bool/int/float).

        Returns:
            Pooled features of shape (B, D) in ``features.dtype``. Rows whose
            mask is entirely zero pool to zeros rather than NaN.
        """
        # Cast the mask once so the multiply, the count and the division all stay
        # in the features' dtype (an integer mask would otherwise promote the
        # result and break half-precision models).
        weights = mask.unsqueeze(-1).to(features.dtype)  # (B, S, 1)
        summed = (features * weights).sum(dim=1)  # (B, D)
        # A literal floor like 1e-9 underflows to 0 in float16; use the smallest
        # normal number of the working dtype instead.
        floor = torch.finfo(features.dtype).tiny
        counts = weights.sum(dim=1).clamp(min=floor)  # (B, 1)
        return summed / counts  # (B, D)

    @staticmethod
    def _resolve_inputs(
        input_ids: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        kwargs: dict,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(x, mask)`` from HF-style arguments or ``x`` / ``mask`` kwargs."""
        x = input_ids if input_ids is not None else kwargs.get("x")
        mask = attention_mask if attention_mask is not None else kwargs.get("mask")
        if x is None or mask is None:
            raise ValueError("input_ids (x) and attention_mask (mask) must be provided")
        return x, mask

    def mc_forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        n_samples: int = 10,
        max_batch_size: int | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Run Monte Carlo Dropout inference to quantify epistemic uncertainty.

        Dropout (including attention dropout) is forced on for the duration of
        the call and restored to the model's previous mode afterwards, even if
        an exception is raised. Gradient tracking is left to the caller; wrap
        the call in ``torch.no_grad()`` for pure inference.

        Args:
            input_ids: Token IDs of shape (B, S). May also be passed as ``x=``.
            attention_mask: Attention mask of shape (B, S); 0 marks padding.
                May also be passed as ``mask=``.
            n_samples: Total number of Monte Carlo samples (>= 1).
            max_batch_size: Maximum number of samples per forward pass (>= 1).
                Defaults to ``n_samples``, i.e. a single pass.
            return_dict: Accepted for Hugging Face API compatibility; unused.
            **kwargs: May carry ``x`` / ``mask`` instead of the named arguments.

        Returns:
            Logits of shape (n_samples, B, num_labels), allocated in torch's
            default float dtype on the input's device.

        Raises:
            ValueError: If the inputs are missing or a sample count is < 1.
        """
        x, mask = self._resolve_inputs(input_ids, attention_mask, kwargs)
        if n_samples < 1:
            raise ValueError(f"n_samples must be >= 1, got {n_samples}")
        if max_batch_size is None:
            max_batch_size = n_samples
        elif max_batch_size < 1:
            raise ValueError(f"max_batch_size must be >= 1, got {max_batch_size}")

        batch_size, _seq_len = x.shape
        num_labels = self.classifier[-1].out_features
        all_logits = torch.empty((n_samples, batch_size, num_labels), device=x.device)

        was_training = self.training
        self._set_mc_dropout(active=True)
        try:
            for start in range(0, n_samples, max_batch_size):
                chunk = min(max_batch_size, n_samples - start)
                # repeat() tiles the whole batch ``chunk`` times along dim 0, so
                # row k * B + b is sample k of example b; the view below relies
                # on that layout.
                x_stacked = x.repeat(chunk, 1)  # (chunk * B, S)
                mask_stacked = mask.repeat(chunk, 1)  # (chunk * B, S)
                features = self.encoder(x_stacked, mask_stacked)  # (chunk * B, S, D)
                pooled = self._masked_mean_pooling(features, mask_stacked)
                logits = self.classifier(pooled)  # (chunk * B, num_labels)
                all_logits[start : start + chunk] = logits.view(chunk, batch_size, -1)
        finally:
            self._set_mc_dropout(active=was_training)
        return all_logits

    def forward(
        self,
        input_ids: torch.Tensor | None = None,
        attention_mask: torch.Tensor | None = None,
        return_dict: bool | None = None,
        **kwargs,
    ) -> torch.Tensor:
        """Standard forward pass without MC Dropout.

        Dropout follows the model's current train/eval mode.

        Args:
            input_ids: Token IDs of shape (B, S). May also be passed as ``x=``.
            attention_mask: Attention mask of shape (B, S); 0 marks padding.
                May also be passed as ``mask=``.
            return_dict: Accepted for Hugging Face API compatibility; unused.

        Returns:
            Logits of shape (B, num_labels).

        Raises:
            ValueError: If the token IDs or the mask are missing.
        """
        x, mask = self._resolve_inputs(input_ids, attention_mask, kwargs)
        features = self.encoder(x, mask)
        pooled = self._masked_mean_pooling(features, mask)
        return self.classifier(pooled)


# Make the model discoverable through AutoConfig / AutoModel. Registration is
# best-effort: a ValueError (e.g. the "emcoder" name already being registered
# when this module is imported again) is deliberately ignored.
# NOTE(review): AutoConfig.register also raises ValueError when
# EmCoderConfig.model_type != "emcoder"; that misconfiguration is hidden here.
try:
    AutoConfig.register("emcoder", EmCoderConfig)
    AutoModel.register(EmCoderConfig, EmCoder)
except ValueError:
    pass