Upload modeling_fastesm.py with huggingface_hub
Browse files
- modeling_fastesm.py +18 -2
modeling_fastesm.py
CHANGED
|
@@ -53,6 +53,7 @@ class FastEsmConfig(PretrainedConfig):
|
|
| 53 |
layer_norm_eps: float = 1e-12,
|
| 54 |
position_embedding_type: str = "absolute",
|
| 55 |
emb_layer_norm_before: bool = None,
|
|
|
|
| 56 |
**kwargs,
|
| 57 |
):
|
| 58 |
super().__init__(
|
|
@@ -74,6 +75,7 @@ class FastEsmConfig(PretrainedConfig):
|
|
| 74 |
self.position_embedding_type = position_embedding_type
|
| 75 |
self.emb_layer_norm_before = emb_layer_norm_before
|
| 76 |
self.tie_word_embeddings = False
|
|
|
|
| 77 |
|
| 78 |
def to_dict(self) -> Dict[str, Any]:
|
| 79 |
"""
|
|
@@ -209,6 +211,8 @@ class EsmEmbeddings(nn.Module):
|
|
| 209 |
self.register_buffer(
|
| 210 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 211 |
)
|
|
|
|
|
|
|
| 212 |
|
| 213 |
def forward(
|
| 214 |
self,
|
|
@@ -223,6 +227,18 @@ class EsmEmbeddings(nn.Module):
|
|
| 223 |
|
| 224 |
embeddings = inputs_embeds
|
| 225 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 226 |
if self.layer_norm is not None:
|
| 227 |
embeddings = self.layer_norm(embeddings)
|
| 228 |
if attention_mask is not None:
|
|
@@ -300,8 +316,8 @@ class EsmSelfAttention(nn.Module):
|
|
| 300 |
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
|
| 301 |
|
| 302 |
if output_attentions:
|
| 303 |
-
# Manual attention computation
|
| 304 |
-
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
|
| 305 |
if attention_mask is not None:
|
| 306 |
attention_scores = attention_scores + attention_mask
|
| 307 |
attention_probs = F.softmax(attention_scores, dim=-1)
|
|
|
|
| 53 |
layer_norm_eps: float = 1e-12,
|
| 54 |
position_embedding_type: str = "absolute",
|
| 55 |
emb_layer_norm_before: bool = None,
|
| 56 |
+
token_dropout: bool = True,
|
| 57 |
**kwargs,
|
| 58 |
):
|
| 59 |
super().__init__(
|
|
|
|
| 75 |
self.position_embedding_type = position_embedding_type
|
| 76 |
self.emb_layer_norm_before = emb_layer_norm_before
|
| 77 |
self.tie_word_embeddings = False
|
| 78 |
+
self.token_dropout = token_dropout
|
| 79 |
|
| 80 |
def to_dict(self) -> Dict[str, Any]:
|
| 81 |
"""
|
|
|
|
| 211 |
self.register_buffer(
|
| 212 |
"position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
|
| 213 |
)
|
| 214 |
+
self.token_dropout = config.token_dropout
|
| 215 |
+
self.mask_token_id = config.mask_token_id
|
| 216 |
|
| 217 |
def forward(
|
| 218 |
self,
|
|
|
|
| 227 |
|
| 228 |
embeddings = inputs_embeds
|
| 229 |
|
| 230 |
+
if attention_mask is None:
|
| 231 |
+
attention_mask = torch.ones_like(input_ids)
|
| 232 |
+
|
| 233 |
+
if self.token_dropout:
|
| 234 |
+
embeddings = embeddings.masked_fill((input_ids == self.mask_token_id).unsqueeze(-1), 0)
|
| 235 |
+
mask_ratio_train = 0.15 * 0.8
|
| 236 |
+
src_lengths = attention_mask.sum(-1)
|
| 237 |
+
mask_ratio_observed = (input_ids == self.mask_token_id).sum(-1).float() / src_lengths
|
| 238 |
+
embeddings = (embeddings * (1 - mask_ratio_train) / (1 - mask_ratio_observed)[:, None, None]).to(
|
| 239 |
+
embeddings.dtype
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
if self.layer_norm is not None:
|
| 243 |
embeddings = self.layer_norm(embeddings)
|
| 244 |
if attention_mask is not None:
|
|
|
|
| 316 |
query_layer, key_layer = self.rotary_embeddings(query_layer, key_layer)
|
| 317 |
|
| 318 |
if output_attentions:
|
| 319 |
+
# Manual attention computation - apply scaling here
|
| 320 |
+
attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2)) * self.scale
|
| 321 |
if attention_mask is not None:
|
| 322 |
attention_scores = attention_scores + attention_mask
|
| 323 |
attention_probs = F.softmax(attention_scores, dim=-1)
|