import torch
import torch.nn as nn
from safetensors.torch import load_file
from pydantic import BaseModel, model_validator, field_validator


class ModelConfig(BaseModel):
    """Configuration for EmCoder, typically serialized as model_config.json."""

    vocab_size: int
    max_seq_len: int
    d_model: int
    n_head: int
    n_layers: int
    d_ffn: int
    dropout: float
    num_labels: int
    id2label: dict[int, str]
    label2id: dict[str, int]
    base_encoder_path: str

    @field_validator("id2label", mode="before")
    @classmethod
    def coerce_keys_to_int(cls, v):
        # JSON object keys are always strings, so coerce them back to int ids.
        return {int(k): val for k, val in v.items()}

    @model_validator(mode="after")
    def check_consistency(self):
        if len(self.id2label) != self.num_labels:
            raise ValueError("num_labels does not match the length of id2label")
        return self
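
# For reference, a minimal model_config.json matching the fields above might look
# like the following. The values are illustrative only and are not taken from an
# actual EmCoder checkpoint:
#
#   {
#     "vocab_size": 32000,
#     "max_seq_len": 512,
#     "d_model": 512,
#     "n_head": 8,
#     "n_layers": 6,
#     "d_ffn": 2048,
#     "dropout": 0.1,
#     "num_labels": 3,
#     "id2label": {"0": "label_a", "1": "label_b", "2": "label_c"},
#     "label2id": {"label_a": 0, "label_b": 1, "label_c": 2},
#     "base_encoder_path": "path/to/base_encoder"
#   }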


class EmCoderCore(nn.Module):
    """The core encoder architecture of EmCoder, without the classification head."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.token_embedding = nn.Embedding(
            config.vocab_size,
            config.d_model
        )
        self.pos_embedding = nn.Embedding(
            config.max_seq_len,
            config.d_model
        )
        self.embed_norm = nn.LayerNorm(config.d_model)
        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_head,
            dim_feedforward=config.d_ffn,
            dropout=config.dropout,
            activation="gelu",
            norm_first=True,
            batch_first=True
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config.n_layers
        )
        self.final_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass through the encoder."""
        seq_len = x.size(1)
        pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, S)
        x = self.token_embedding(x) + self.pos_embedding(pos_ids)  # (B, S, D)
        x = self.embed_norm(x)
        x = self.dropout(x)
        # src_key_padding_mask expects True at padding positions, so invert the attention mask.
        padding_mask = (mask == 0)
        encoded = self.encoder(x, src_key_padding_mask=padding_mask)
        return self.final_norm(encoded)


class EmCoder(nn.Module):
    """The full EmCoder model, including the classification head."""

    def __init__(self, encoder: EmCoderCore, config: ModelConfig):
        super().__init__()
        self.encoder = encoder
        self.config = config
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.num_labels)
        )

    def _set_mc_dropout(self, active: bool = True):
        # Toggle only the Dropout modules so they stay stochastic during
        # Monte Carlo inference, even when the rest of the model is in eval mode.
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train(active)

    @classmethod
    def from_pretrained(cls, emcoder_path: str):
        """Loads EmCoder from a directory containing model_config.json and model.safetensors."""
        # Use model_config.json to initialize the same parameters as in training.
        with open(f"{emcoder_path}/model_config.json", "r") as f:
            model_config = ModelConfig.model_validate_json(f.read())
        encoder = EmCoderCore(model_config)
        model = cls(encoder, model_config)
        state_dict = load_file(f"{emcoder_path}/model.safetensors")
        model.load_state_dict(state_dict, strict=True)
        return model

    @staticmethod
    def _masked_mean_pooling(features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        # Cast the (typically integer) attention mask to the feature dtype so the
        # multiplication and the clamp below behave as intended.
        mask = mask.unsqueeze(-1).to(features.dtype)  # (B, S, 1)
        masked_features = features * mask  # (B, S, D)
        sum_masked_features = masked_features.sum(dim=1)  # (B, D)
        count_tokens = torch.clamp(mask.sum(dim=1), min=1e-9)  # (B, 1)
        return sum_masked_features / count_tokens  # (B, D)

    def mc_forward(self, x: torch.Tensor, mask: torch.Tensor, n_samples: int) -> torch.Tensor:
        """Performs Monte Carlo Dropout inference to quantify epistemic uncertainty."""
        self._set_mc_dropout(active=True)
        B, S = x.shape
        # Tile the batch n_samples times so all stochastic passes run in a single forward call.
        x_stacked = x.repeat(n_samples, 1)  # (n_samples * B, S)
        mask_stacked = mask.repeat(n_samples, 1)
        features = self.encoder(x_stacked, mask_stacked)
        pooled = self._masked_mean_pooling(features, mask_stacked)
        logits = self.classifier(pooled)  # (n_samples * B, num_labels)
        return logits.view(n_samples, B, -1)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass without MC Dropout."""
        features = self.encoder(x, mask)
        pooled = self._masked_mean_pooling(features, mask)
        return self.classifier(pooled)
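

# A minimal usage sketch (illustrative, not part of the original file): it builds
# a small ModelConfig by hand with made-up values instead of loading a real
# checkpoint, runs the standard forward pass, and turns the MC Dropout samples
# into mean probabilities and predictive entropy as a simple uncertainty score.
if __name__ == "__main__":
    config = ModelConfig(
        vocab_size=1000,
        max_seq_len=64,
        d_model=32,
        n_head=4,
        n_layers=2,
        d_ffn=64,
        dropout=0.1,
        num_labels=3,
        id2label={0: "a", 1: "b", 2: "c"},
        label2id={"a": 0, "b": 1, "c": 2},
        base_encoder_path="",
    )
    model = EmCoder(EmCoderCore(config), config)
    model.eval()

    # Dummy batch: 2 sequences of 10 token ids; the last 3 positions of the
    # second sequence are treated as padding.
    x = torch.randint(0, config.vocab_size, (2, 10))
    mask = torch.ones(2, 10, dtype=torch.long)
    mask[1, 7:] = 0

    with torch.no_grad():
        logits = model(x, mask)  # (2, 3)
        mc_logits = model.mc_forward(x, mask, n_samples=20)  # (20, 2, 3)

    probs = torch.softmax(mc_logits, dim=-1).mean(dim=0)  # (2, 3), mean over MC samples
    entropy = -(probs * probs.clamp_min(1e-12).log()).sum(dim=-1)  # (2,)
    print("point-estimate logits:", logits)
    print("MC mean probabilities:", probs)
    print("predictive entropy:", entropy)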