import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_emcoder import EmCoderConfig


class EmCoderCore(nn.Module):
    """The core encoder architecture of EmCoder, without the classification head."""
    def __init__(self, config: EmCoderConfig):
        super().__init__()

        self.token_embedding = nn.Embedding(config.vocab_size, config.d_model)
        self.pos_embedding = nn.Embedding(config.max_seq_len, config.d_model)
        self.embed_norm = nn.LayerNorm(config.d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_head,
            dim_feedforward=config.d_ffn,
            dropout=config.dropout,
            activation="gelu",
            norm_first=True,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config.n_layers
        )
        
        self.final_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)


    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass through the encoder.

        `mask` follows the attention-mask convention used elsewhere in this file:
        1 marks a real token, 0 marks padding.
        """
        seq_len = x.size(1)
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, S), broadcast over the batch

        x = self.token_embedding(x) + self.pos_embedding(pos_ids)

        x = self.embed_norm(x)
        x = self.dropout(x)

        # nn.TransformerEncoder expects True at padded positions, hence the inversion.
        padding_mask = (mask == 0)

        encoded = self.encoder(x, src_key_padding_mask=padding_mask)
        return self.final_norm(encoded)
        

class EmCoder(PreTrainedModel):
    """The full EmCoder model, including the classification head."""
    config_class = EmCoderConfig

    def __init__(self, config: EmCoderConfig):
        super().__init__(config)

        self.encoder = EmCoderCore(config)
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.num_labels)
        )

        self.post_init()

    def _set_mc_dropout(self, active: bool = True):
        """Toggles all dropout layers to training mode so they stay stochastic at inference time."""
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train(active)

    @staticmethod
    def _masked_mean_pooling(features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean-pools token features over the sequence dimension, ignoring padded positions."""
        mask = mask.unsqueeze(-1).to(features.dtype)  # (B, S, 1); cast so the clamp epsilon below takes effect on integer masks
        masked_features = features * mask  # (B, S, D)
        sum_masked_features = masked_features.sum(dim=1)  # (B, D)
        count_tokens = torch.clamp(mask.sum(dim=1), min=1e-9)  # (B, 1); guard against all-padding rows
        return sum_masked_features / count_tokens  # (B, D)

    def mc_forward(self, x: torch.Tensor, mask: torch.Tensor, n_samples: int) -> torch.Tensor:
        """Performs Monte Carlo Dropout inference to quantify epistemic uncertainty.

        Returns logits of shape (n_samples, B, num_labels), one stochastic forward pass
        per sample. Dropout layers are left in training mode afterwards; call `eval()`
        (or `_set_mc_dropout(False)`) to restore deterministic inference. A usage sketch
        at the bottom of this file shows one way to aggregate the samples.
        """
        self._set_mc_dropout(active=True)

        B, S = x.shape
        # Tile the batch so all MC samples run in a single forward pass.
        x_stacked = x.repeat(n_samples, 1)        # (n_samples * B, S)
        mask_stacked = mask.repeat(n_samples, 1)  # (n_samples * B, S)

        features = self.encoder(x_stacked, mask_stacked)
        pooled = self._masked_mean_pooling(features, mask_stacked)
        logits = self.classifier(pooled)          # (n_samples * B, num_labels)

        return logits.view(n_samples, B, -1)


    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass without MC Dropout."""
        features = self.encoder(x, mask)

        pooled = self._masked_mean_pooling(features, mask)
        return self.classifier(pooled)
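

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the model definition).
# It assumes EmCoderConfig accepts the fields referenced above as keyword
# arguments; the specific values below are placeholder assumptions, and the
# real defaults live in configuration_emcoder.py.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = EmCoderConfig(
        vocab_size=30_000,   # assumed value
        d_model=256,         # assumed value
        n_head=4,            # assumed value
        d_ffn=1024,          # assumed value
        n_layers=4,          # assumed value
        max_seq_len=128,     # assumed value
        dropout=0.1,         # assumed value
        num_labels=7,        # assumed value
    )
    model = EmCoder(config)
    model.eval()

    # Dummy batch: token ids plus an attention mask (1 = real token, 0 = padding).
    input_ids = torch.randint(0, config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[:, 12:] = 0  # pad out the last few positions

    with torch.no_grad():
        # Deterministic forward pass: (B, num_labels) logits.
        logits = model(input_ids, attention_mask)

        # MC Dropout: n_samples stochastic passes -> (n_samples, B, num_labels).
        mc_logits = model.mc_forward(input_ids, attention_mask, n_samples=20)

    probs = mc_logits.softmax(dim=-1)
    mean_probs = probs.mean(dim=0)   # predictive mean over MC samples
    uncertainty = probs.std(dim=0)   # per-class spread as a simple epistemic proxy
    print(logits.shape, mean_probs.shape, uncertainty.shape)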