import torch
import torch.nn as nn
from transformers import PreTrainedModel

from .configuration_emcoder import EmCoderConfig


class EmCoderCore(nn.Module):
    """The core encoder architecture of EmCoder, without the classification head."""

    def __init__(self, config: EmCoderConfig):
        super().__init__()

        # Learned token and absolute-position embeddings, normalized and
        # regularized before entering the encoder stack.
        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Embedding(config.max_seq_len, config.d_model)
        self.embed_norm = nn.LayerNorm(config.d_model)

        # Pre-norm (norm_first=True) Transformer encoder with GELU feed-forward blocks.
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_head,
            dim_feedforward=config.d_ffn,
            dropout=config.dropout,
            activation="gelu",
            norm_first=True,
            batch_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config.n_layers,
        )

        self.final_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass through the encoder.

        Args:
            x: token ids of shape (batch, seq_len).
            mask: attention mask of shape (batch, seq_len); 1 for real tokens, 0 for padding.
        """
        seq_len = x.size(1)
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)

        # Sum token and position embeddings, then normalize and apply dropout.
        x = self.token_embedding(x) + self.pos_embedding(pos_ids)
        x = self.embed_norm(x)
        x = self.dropout(x)

        # nn.TransformerEncoder expects True at positions that should be ignored (padding).
        padding_mask = (mask == 0)

        encoded = self.encoder(x, src_key_padding_mask=padding_mask)
        return self.final_norm(encoded)


class EmCoder(PreTrainedModel):
    """The full EmCoder model, including the classification head."""

    config_class = EmCoderConfig

    def __init__(self, config: EmCoderConfig):
        super().__init__(config)

        self.encoder = EmCoderCore(config)

        # Two-layer MLP head applied to the mean-pooled sequence representation.
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.num_labels),
        )

        self.post_init()

    def _set_mc_dropout(self, active: bool = True):
        """Toggle every dropout layer independently of the module's train/eval mode."""
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train(active)

    @staticmethod
    def _masked_mean_pooling(features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Average token features over non-padding positions."""
        # Cast the mask to the feature dtype so integer attention masks broadcast
        # cleanly and the clamp below operates on floats.
        mask = mask.unsqueeze(-1).to(features.dtype)
        masked_features = features * mask
        sum_masked_features = masked_features.sum(dim=1)
        count_tokens = torch.clamp(mask.sum(dim=1), min=1e-9)
        return sum_masked_features / count_tokens

    def mc_forward(self, x: torch.Tensor, mask: torch.Tensor, n_samples: int) -> torch.Tensor:
        """Performs Monte Carlo Dropout inference to quantify epistemic uncertainty.

        Returns logits of shape (n_samples, batch, num_labels).
        """
        # Keep dropout stochastic even when the model is in eval mode.
        self._set_mc_dropout(active=True)

        # Tile the batch n_samples times so all samples are drawn in a single pass.
        B = x.size(0)
        x_stacked = x.repeat(n_samples, 1)
        mask_stacked = mask.repeat(n_samples, 1)

        features = self.encoder(x_stacked, mask_stacked)
        pooled = self._masked_mean_pooling(features, mask_stacked)
        logits = self.classifier(pooled)

        # Restore dropout layers to the module's overall mode so subsequent
        # calls to forward() are not left with dropout active.
        self._set_mc_dropout(active=self.training)

        return logits.view(n_samples, B, -1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass without MC Dropout."""
        features = self.encoder(x, mask)
        pooled = self._masked_mean_pooling(features, mask)
        return self.classifier(pooled)
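

# --- Minimal usage sketch (illustrative only, not part of the module API) ---
# Assumes EmCoderConfig accepts the fields used above as keyword arguments; the
# values below are placeholders, not recommended settings.
if __name__ == "__main__":
    config = EmCoderConfig(
        vocab_size=30_000,
        d_model=256,
        n_head=8,
        d_ffn=1024,
        n_layers=4,
        max_seq_len=128,
        dropout=0.1,
        num_labels=3,
    )
    model = EmCoder(config).eval()

    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)

    with torch.no_grad():
        # Deterministic prediction: (batch, num_labels).
        logits = model(input_ids, attention_mask)

        # MC Dropout samples: (n_samples, batch, num_labels).
        mc_logits = model.mc_forward(input_ids, attention_mask, n_samples=20)
        probs = mc_logits.softmax(dim=-1)
        mean_probs = probs.mean(dim=0)            # predictive mean over samples
        epistemic = probs.var(dim=0).sum(dim=-1)  # spread across samples as a rough uncertainty score

    print(logits.shape, mean_probs.shape, epistemic.shape)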