yezdata committed
Commit b1ffb95 · verified · 1 Parent(s): c086040

Delete emcoder.py

Files changed (1)
  1. emcoder.py +0 -155
emcoder.py DELETED
@@ -1,155 +0,0 @@
- import torch
- import torch.nn as nn
- from safetensors.torch import load_file
- from pydantic import BaseModel, model_validator, field_validator
-
-
- class ModelConfig(BaseModel):
-     vocab_size: int
-     max_seq_len: int
-
-     d_model: int
-     n_head: int
-     n_layers: int
-     d_ffn: int
-
-     dropout: float
-
-     num_labels: int
-     id2label: dict[int, str]
-     label2id: dict[str, int]
-
-     base_encoder_path: str
-
-     @field_validator("id2label", mode="before")
-     @classmethod
-     def coerce_keys_to_int(cls, v):
-         return {int(k): val for k, val in v.items()}
-
-     @model_validator(mode='after')
-     def check_consistency(self):
-         if len(self.id2label) != self.num_labels:
-             raise ValueError("num_labels does not match id2label dictionary len")
-         return self
-
-
-
-
- class EmCoderCore(nn.Module):
-     """The core encoder architecture of EmCoder, without the classification head."""
-     def __init__(self, config: ModelConfig):
-         super().__init__()
-
-         self.token_embedding = nn.Embedding(
-             config.vocab_size,
-             config.d_model
-         )
-         self.pos_embedding = nn.Embedding(
-             config.max_seq_len,
-             config.d_model
-         )
-
-         self.embed_norm = nn.LayerNorm(config.d_model)
-
-         encoder_layer = nn.TransformerEncoderLayer(
-             d_model=config.d_model,
-             nhead=config.n_head,
-             dim_feedforward=config.d_ffn,
-             dropout=config.dropout,
-             activation="gelu",
-             norm_first=True,
-             batch_first=True
-         )
-         self.encoder = nn.TransformerEncoder(
-             encoder_layer=encoder_layer,
-             num_layers=config.n_layers
-         )
-
-         self.final_norm = nn.LayerNorm(config.d_model)
-         self.dropout = nn.Dropout(config.dropout)
-
-     def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-         """Standard forward pass through the encoder."""
-         seq_len = x.size(1)
-         pos_ids = torch.arange(seq_len, device=x.device).unsqueeze(0)
-
-         x = self.token_embedding(x) + self.pos_embedding(pos_ids)
-
-         x = self.embed_norm(x)
-         x = self.dropout(x)
-
-         padding_mask = (mask == 0)
-
-         encoded = self.encoder(x, src_key_padding_mask=padding_mask)
-         return self.final_norm(encoded)
-
-
-
- class EmCoder(nn.Module):
-     """The full EmCoder model, including the classification head."""
-     def __init__(self, encoder: EmCoderCore, config: ModelConfig):
-         super().__init__()
-
-         self.encoder = encoder
-         self.config = config
-
-         self.classifier = nn.Sequential(
-             nn.Linear(config.d_model, config.d_model),
-             nn.GELU(),
-             nn.Dropout(config.dropout),
-             nn.Linear(config.d_model, config.num_labels)
-         )
-
-
-     def _set_mc_dropout(self, active: bool = True):
-         for m in self.modules():
-             if isinstance(m, nn.Dropout):
-                 m.train(active)
-
-
-     @classmethod
-     def from_pretrained(cls, emcoder_path: str):
-         """Loads the EmCoder model from the specified directory."""
-         # Use model_config.json to initialize the same parameters as in training
-         with open(f"{emcoder_path}/model_config.json", "r") as f:
-             model_config = ModelConfig.model_validate_json(f.read())
-
-
-         encoder = EmCoderCore(model_config)
-         model = cls(encoder, model_config)
-
-         state_dict = load_file(f"{emcoder_path}/model.safetensors")
-         model.load_state_dict(state_dict, strict=True)
-         return model
-
-
-     @staticmethod
-     def _masked_mean_pooling(features: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-         mask = mask.unsqueeze(-1).float()                       # (B, S, 1)
-         masked_features = features * mask                       # (B, S, D)
-         sum_masked_features = masked_features.sum(dim=1)        # (B, D)
-         count_tokens = torch.clamp(mask.sum(dim=1), min=1e-9)   # (B, 1)
-         return sum_masked_features / count_tokens               # (B, D)
-
-
-     def mc_forward(self, x: torch.Tensor, mask: torch.Tensor, n_samples: int) -> torch.Tensor:
-         """Performs Monte Carlo Dropout inference to quantify epistemic uncertainty."""
-         self._set_mc_dropout(active=True)
-
-         B, S = x.shape
-         x_stacked = x.repeat(n_samples, 1)  # (n_samples * B, S)
-         mask_stacked = mask.repeat(n_samples, 1)
-
-         features = self.encoder(x_stacked, mask_stacked)
-         pooled = self._masked_mean_pooling(features, mask_stacked)
-         logits = self.classifier(pooled)  # (n_samples * B, num_labels)
-
-         return logits.view(n_samples, B, -1)
-
-
-     def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
-         """Standard forward pass without MC Dropout."""
-         features = self.encoder(x, mask)
-
-         pooled = self._masked_mean_pooling(features, mask)
-         return self.classifier(pooled)
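
Note for anyone pinning to the parent commit c086040: from_pretrained expected a model_config.json sitting next to model.safetensors, validated through ModelConfig. A minimal sketch of building that config directly, with every value hypothetical (the actual training configuration is not part of this diff):

    from emcoder import ModelConfig  # the module deleted in this commit

    # Every value below is illustrative, not the real training config.
    config = ModelConfig.model_validate({
        "vocab_size": 30522,
        "max_seq_len": 512,
        "d_model": 256,      # must be divisible by n_head
        "n_head": 4,
        "n_layers": 4,
        "d_ffn": 1024,
        "dropout": 0.1,
        "num_labels": 2,
        "id2label": {"0": "negative", "1": "positive"},
        "label2id": {"negative": 0, "positive": 1},
        "base_encoder_path": "base_encoder",
    })

String keys in id2label are accepted because the coerce_keys_to_int validator casts them to int before the num_labels consistency check runs.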
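
Inference took two tensors: token ids plus a 0/1 attention mask, where 0 marks padding (EmCoderCore flips it into src_key_padding_mask). A minimal usage sketch; the checkpoint directory and the dummy batch are placeholders, and no tokenizer ships in this file:

    import torch
    from emcoder import EmCoder

    model = EmCoder.from_pretrained("checkpoints/emcoder")  # hypothetical directory
    model.eval()

    # Dummy batch (B=2, S=16); mask uses 1 for real tokens, 0 for padding
    input_ids = torch.randint(0, model.config.vocab_size, (2, 16))
    attention_mask = torch.ones(2, 16, dtype=torch.long)
    attention_mask[1, 12:] = 0  # pretend the second sequence is shorter

    with torch.no_grad():
        logits = model(input_ids, attention_mask)  # (B, num_labels)
    preds = [model.config.id2label[i] for i in logits.argmax(dim=-1).tolist()]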
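
mc_forward returns raw logits shaped (n_samples, B, num_labels) and leaves aggregation to the caller. One common recipe, sketched here rather than taken from the deleted file: average the per-sample softmax distributions and read uncertainty off their spread or the predictive entropy. Also worth noting that _set_mc_dropout(True) leaves every dropout layer in train mode after mc_forward returns, so switch back with model.eval() once sampling is done:

    with torch.no_grad():
        mc_logits = model.mc_forward(input_ids, attention_mask, n_samples=20)

    probs = torch.softmax(mc_logits, dim=-1)  # (n_samples, B, num_labels)
    mean_probs = probs.mean(dim=0)            # predictive distribution per example
    std_probs = probs.std(dim=0)              # disagreement across dropout samples
    # Predictive entropy: one scalar uncertainty score per example
    entropy = -(mean_probs * mean_probs.clamp_min(1e-9).log()).sum(dim=-1)

    model.eval()  # mc_forward leaves dropout active; restore inference mode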