yezdata committed on
Commit
12731fa
·
verified ·
1 Parent(s): 29c0a9c

Delete emcoder.py

Browse files
Files changed (1) hide show
  1. emcoder.py +0 -138
emcoder.py DELETED
@@ -1,138 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- from safetensors.torch import load_file
4
- from pydantic import BaseModel, model_validator, field_validator
5
-
6
-
7
class ModelConfig(BaseModel):
    """Validated configuration for EmCoder.

    Mirrors the model_config.json written at training time so inference
    can rebuild the exact same architecture and label mapping.
    """

    # Tokenizer / sequence geometry.
    vocab_size: int
    max_seq_len: int

    # Transformer dimensions.
    d_model: int
    n_head: int
    n_layers: int
    d_ffn: int

    dropout: float

    # Classification head and label <-> id mappings.
    num_labels: int
    id2label: dict[int, str]
    label2id: dict[str, int]

    base_encoder_path: str

    @field_validator("id2label", mode="before")
    @classmethod
    def coerce_keys_to_int(cls, v):
        """Coerce JSON string keys (e.g. "0") back to ints before validation."""
        coerced = {}
        for key, label in v.items():
            coerced[int(key)] = label
        return coerced

    @model_validator(mode='after')
    def check_consistency(self):
        """Reject configs where num_labels disagrees with the id2label map."""
        if self.num_labels != len(self.id2label):
            raise ValueError("num_labels does not match id2label dictionary len")
        return self
34
-
35
-
36
-
37
-
38
class EmCoderCore(nn.Module):
    """The core encoder architecture of EmCoder, without the classification head.

    Learned token + absolute position embeddings feeding a pre-norm
    TransformerEncoder stack (GELU activation, batch_first tensors).
    """

    def __init__(self, config: "ModelConfig"):
        super().__init__()

        self.token_embedding = nn.Embedding(
            config.vocab_size,
            config.d_model
        )
        # Learned absolute position embedding, capped at max_seq_len positions.
        self.pos_embedding = nn.Embedding(
            config.max_seq_len,
            config.d_model
        )

        self.embed_norm = nn.LayerNorm(config.d_model)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=config.d_model,
            nhead=config.n_head,
            dim_feedforward=config.d_ffn,
            dropout=config.dropout,
            activation="gelu",
            norm_first=True,   # pre-norm variant: more stable for deeper stacks
            batch_first=True   # all tensors are (B, S, D)
        )
        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=config.n_layers
        )

        self.final_norm = nn.LayerNorm(config.d_model)
        self.dropout = nn.Dropout(config.dropout)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Encode token ids into contextual features.

        Args:
            x:    (B, S) long tensor of token ids.
            mask: (B, S) tensor where non-zero marks real tokens and zero marks
                  padding (the same convention EmCoder's mean pooling uses).

        Returns:
            (B, S, d_model) contextual token features.

        NOTE(review): the original class defined no forward(), so calling the
        module — as EmCoder.forward/mc_forward does via self.encoder(x, mask) —
        raised NotImplementedError. This implementation follows the standard
        embed -> norm -> dropout -> encoder -> final-norm pipeline implied by
        the submodules created in __init__; confirm against the training code.
        """
        seq_len = x.size(1)
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0)  # (1, S)

        h = self.token_embedding(x) + self.pos_embedding(positions)
        h = self.dropout(self.embed_norm(h))

        # TransformerEncoder expects True at positions that should be IGNORED.
        padding_mask = mask == 0
        h = self.encoder(h, src_key_padding_mask=padding_mask)

        return self.final_norm(h)
70
-
71
-
72
-
73
class EmCoder(nn.Module):
    """The full EmCoder model: core encoder plus a 2-layer classification head."""

    def __init__(self, encoder: "EmCoderCore", config: "ModelConfig"):
        super().__init__()

        self.encoder = encoder

        # MLP head over mean-pooled features: Linear -> GELU -> Dropout -> Linear.
        self.classifier = nn.Sequential(
            nn.Linear(config.d_model, config.d_model),
            nn.GELU(),
            nn.Dropout(config.dropout),
            nn.Linear(config.d_model, config.num_labels)
        )

    def _set_mc_dropout(self, active: bool = True):
        """Force every Dropout submodule into train (active=True) or eval mode."""
        for m in self.modules():
            if isinstance(m, nn.Dropout):
                m.train(active)

    @classmethod
    def from_pretrained(cls, emcoder_path: str):
        """Loads the EmCoder model from the specified directory.

        Expects `model_config.json` (architecture + labels) and
        `model.safetensors` (weights) inside `emcoder_path`.
        """
        # Use model_config.json to initialize same parameters as in training
        with open(f"{emcoder_path}/model_config.json", "r") as f:
            model_config = ModelConfig.model_validate_json(f.read())

        encoder = EmCoderCore(model_config)
        model = cls(encoder, model_config)

        # strict=True: fail loudly on any missing/unexpected weight keys.
        state_dict = load_file(f"{emcoder_path}/model.safetensors")
        model.load_state_dict(state_dict, strict=True)
        return model

    @staticmethod
    def _masked_mean_pooling(features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Mean over the sequence axis, counting only unmasked (non-zero) positions.

        Args:
            features: (B, S, D) token features.
            mask:     (B, S) tensor, non-zero at real tokens.

        Returns:
            (B, D) pooled features.
        """
        mask = mask.unsqueeze(-1)                              # (B, S, 1)
        sum_masked_features = (features * mask).sum(dim=1)     # (B, D)
        # Clamp guards against divide-by-zero for an all-padding row.
        count_tokens = torch.clamp(mask.sum(dim=1), min=1e-9)  # (B, 1)
        return sum_masked_features / count_tokens              # (B, D)

    def mc_forward(self, x: torch.Tensor, mask: torch.Tensor, n_samples: int) -> torch.Tensor:
        """Performs Monte Carlo Dropout inference to quantify epistemic uncertainty.

        Runs n_samples stochastic forward passes (dropout active) in one
        batched call.

        Returns:
            (n_samples, B, num_labels) logits, one slice per MC sample.

        FIX: the original left every Dropout layer in train mode after
        returning, silently making all subsequent forward() calls stochastic.
        Each Dropout layer's prior mode is now restored on exit.
        """
        dropouts = [m for m in self.modules() if isinstance(m, nn.Dropout)]
        prev_modes = [m.training for m in dropouts]
        self._set_mc_dropout(active=True)
        try:
            B, S = x.shape
            # Tile whole copies of the batch; samples differ only via dropout.
            x_stacked = x.repeat(n_samples, 1)      # (n_samples * B, S)
            mask_stacked = mask.repeat(n_samples, 1)

            features = self.encoder(x_stacked, mask_stacked)
            pooled = self._masked_mean_pooling(features, mask_stacked)
            logits = self.classifier(pooled)        # (n_samples * B, num_labels)

            # repeat() stacked full copies along dim 0, so this un-tiles cleanly.
            return logits.view(n_samples, B, -1)
        finally:
            for m, was_training in zip(dropouts, prev_modes):
                m.train(was_training)

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Standard forward pass without MC Dropout: (B, S) -> (B, num_labels)."""
        features = self.encoder(x, mask)

        pooled = self._masked_mean_pooling(features, mask)
        return self.classifier(pooled)