Upload modeling_esm_plusplus.py with huggingface_hub
Browse files
- modeling_esm_plusplus.py +47 -36
modeling_esm_plusplus.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch._inductor.config as inductor_config
|
| 3 |
import torch._dynamo as dynamo
|
|
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
|
|
| 429 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 430 |
"""
|
| 431 |
from enum import Enum
|
| 432 |
-
from
|
|
|
|
| 433 |
|
| 434 |
import torch
|
| 435 |
import torch.nn as nn
|
|
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
|
|
| 447 |
|
| 448 |
|
| 449 |
def _get_flex_attention_fn():
|
| 450 |
-
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
global _compiled_flex_attention
|
| 452 |
if flex_attention is None:
|
| 453 |
return None
|
|
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
|
|
| 455 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 456 |
return flex_attention
|
| 457 |
if _compiled_flex_attention is None:
|
| 458 |
-
_compiled_flex_attention = torch.compile(
|
|
|
|
|
|
|
|
|
|
| 459 |
return _compiled_flex_attention
|
| 460 |
|
| 461 |
|
| 462 |
### Kernels Flash Attention Detection
|
| 463 |
-
def _infer_kernels_flash_variant(kernel) -> str
|
| 464 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 465 |
return "flash_attn2"
|
| 466 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
| 576 |
).reshape(-1, *other_shape)
|
| 577 |
|
| 578 |
@staticmethod
|
| 579 |
-
def backward(ctx, grad_output) ->
|
| 580 |
(indices,) = ctx.saved_tensors
|
| 581 |
assert grad_output.ndim >= 2
|
| 582 |
other_shape = grad_output.shape[1:]
|
|
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
|
|
| 599 |
return output
|
| 600 |
|
| 601 |
@staticmethod
|
| 602 |
-
def backward(ctx, grad_output) ->
|
| 603 |
(indices,) = ctx.saved_tensors
|
| 604 |
return grad_output[indices], None, None
|
| 605 |
|
|
@@ -618,7 +629,7 @@ def _unpad_input(
|
|
| 618 |
key_layer: torch.Tensor,
|
| 619 |
value_layer: torch.Tensor,
|
| 620 |
attention_mask_2d: torch.Tensor,
|
| 621 |
-
) ->
|
| 622 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 623 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 624 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
|
|
| 634 |
query_states: torch.Tensor,
|
| 635 |
key_states: torch.Tensor,
|
| 636 |
value_states: torch.Tensor,
|
| 637 |
-
attention_mask_2d: torch.Tensor
|
| 638 |
causal: bool = False,
|
| 639 |
) -> torch.Tensor:
|
| 640 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
@@ -707,7 +718,7 @@ def get_attention_mask(
|
|
| 707 |
seq_len: int,
|
| 708 |
device: torch.device,
|
| 709 |
attention_mask: Optional[torch.Tensor] = None,
|
| 710 |
-
) ->
|
| 711 |
"""Build padding masks once for all encoder layers.
|
| 712 |
|
| 713 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
@@ -783,7 +794,7 @@ class ESMplusplusConfig(PretrainedConfig):
|
|
| 783 |
num_attention_heads: int = 15,
|
| 784 |
num_hidden_layers: int = 30,
|
| 785 |
num_labels: int = 2,
|
| 786 |
-
problem_type: str
|
| 787 |
dropout: float = 0.0,
|
| 788 |
initializer_range: float = 0.02,
|
| 789 |
attn_backend: str = "sdpa",
|
|
@@ -1057,12 +1068,12 @@ class MultiHeadAttention(nn.Module):
|
|
| 1057 |
def forward(
|
| 1058 |
self,
|
| 1059 |
x: torch.Tensor,
|
| 1060 |
-
attention_mask_2d: torch.Tensor
|
| 1061 |
-
attention_mask_4d: torch.Tensor
|
| 1062 |
-
flex_block_mask:
|
| 1063 |
output_attentions: bool = False,
|
| 1064 |
output_s_max: bool = False,
|
| 1065 |
-
) ->
|
| 1066 |
qkv_BLD3 = self.layernorm_qkv(x)
|
| 1067 |
query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
|
| 1068 |
query_BLD, key_BLD = (
|
|
@@ -1089,12 +1100,12 @@ class MultiHeadAttention(nn.Module):
|
|
| 1089 |
query_BHLD: torch.Tensor,
|
| 1090 |
key_BHLD: torch.Tensor,
|
| 1091 |
value_BHLD: torch.Tensor,
|
| 1092 |
-
attention_mask_2d: torch.Tensor
|
| 1093 |
-
attention_mask_4d: torch.Tensor
|
| 1094 |
-
flex_block_mask:
|
| 1095 |
output_attentions: bool = False,
|
| 1096 |
output_s_max: bool = False,
|
| 1097 |
-
) ->
|
| 1098 |
if output_attentions:
|
| 1099 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 1100 |
|
|
@@ -1111,7 +1122,7 @@ class MultiHeadAttention(nn.Module):
|
|
| 1111 |
return attn_output, attn_weights, s_max
|
| 1112 |
|
| 1113 |
@torch.no_grad()
|
| 1114 |
-
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) ->
|
| 1115 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 1116 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 1117 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * self.scale
|
|
@@ -1122,9 +1133,9 @@ class MultiHeadAttention(nn.Module):
|
|
| 1122 |
query_BHLD: torch.Tensor,
|
| 1123 |
key_BHLD: torch.Tensor,
|
| 1124 |
value_BHLD: torch.Tensor,
|
| 1125 |
-
attention_mask_4d: torch.Tensor
|
| 1126 |
output_s_max: bool = False,
|
| 1127 |
-
) ->
|
| 1128 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * self.scale
|
| 1129 |
if attention_mask_4d is not None:
|
| 1130 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
@@ -1139,8 +1150,8 @@ class MultiHeadAttention(nn.Module):
|
|
| 1139 |
query_BHLD: torch.Tensor,
|
| 1140 |
key_BHLD: torch.Tensor,
|
| 1141 |
value_BHLD: torch.Tensor,
|
| 1142 |
-
attention_mask_2d: torch.Tensor
|
| 1143 |
-
) ->
|
| 1144 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 1145 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 1146 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
@@ -1155,8 +1166,8 @@ class MultiHeadAttention(nn.Module):
|
|
| 1155 |
query_BHLD: torch.Tensor,
|
| 1156 |
key_BHLD: torch.Tensor,
|
| 1157 |
value_BHLD: torch.Tensor,
|
| 1158 |
-
flex_block_mask:
|
| 1159 |
-
) ->
|
| 1160 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1161 |
fn = _get_flex_attention_fn()
|
| 1162 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
|
|
@@ -1167,8 +1178,8 @@ class MultiHeadAttention(nn.Module):
|
|
| 1167 |
query_BHLD: torch.Tensor,
|
| 1168 |
key_BHLD: torch.Tensor,
|
| 1169 |
value_BHLD: torch.Tensor,
|
| 1170 |
-
attention_mask_4d: torch.Tensor
|
| 1171 |
-
) ->
|
| 1172 |
context_BHLD = F.scaled_dot_product_attention(
|
| 1173 |
query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, scale=self.scale,
|
| 1174 |
)
|
|
@@ -1214,12 +1225,12 @@ class UnifiedTransformerBlock(nn.Module):
|
|
| 1214 |
def forward(
|
| 1215 |
self,
|
| 1216 |
x: torch.Tensor,
|
| 1217 |
-
attention_mask_2d: torch.Tensor
|
| 1218 |
-
attention_mask_4d: torch.Tensor
|
| 1219 |
-
flex_block_mask:
|
| 1220 |
output_attentions: bool = False,
|
| 1221 |
output_s_max: bool = False,
|
| 1222 |
-
) ->
|
| 1223 |
attn_output, attn_weights, s_max = self.attn(
|
| 1224 |
x,
|
| 1225 |
attention_mask_2d=attention_mask_2d,
|
|
@@ -1240,7 +1251,7 @@ class TransformerOutput(ModelOutput):
|
|
| 1240 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 1241 |
hidden_states: Optional[Tuple[torch.Tensor]] = None
|
| 1242 |
attentions: Optional[Tuple[torch.Tensor]] = None
|
| 1243 |
-
s_max: Optional[Tuple[
|
| 1244 |
|
| 1245 |
|
| 1246 |
@dataclass
|
|
@@ -1251,7 +1262,7 @@ class ESMplusplusOutput(ModelOutput):
|
|
| 1251 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 1252 |
hidden_states: Optional[Tuple[torch.Tensor]] = None
|
| 1253 |
attentions: Optional[Tuple[torch.Tensor]] = None
|
| 1254 |
-
s_max: Optional[Tuple[
|
| 1255 |
|
| 1256 |
|
| 1257 |
### Transformer Stack
|
|
@@ -1772,7 +1783,7 @@ def get_esmc_checkpoint_path(model: str) -> Path:
|
|
| 1772 |
def _load_esmc_checkpoint_model(
|
| 1773 |
config: ESMplusplusConfig,
|
| 1774 |
model: str,
|
| 1775 |
-
device: torch.device
|
| 1776 |
) -> ESMplusplusForMaskedLM:
|
| 1777 |
key = _resolve_esmc_checkpoint_key(model)
|
| 1778 |
spec = _ESMC_CHECKPOINT_SPECS[key]
|
|
@@ -1795,7 +1806,7 @@ def _load_esmc_checkpoint_model(
|
|
| 1795 |
return model_obj
|
| 1796 |
|
| 1797 |
|
| 1798 |
-
def ESMplusplus_300M(device: torch.device | str = "cpu"):
|
| 1799 |
config = ESMplusplusConfig(
|
| 1800 |
hidden_size=960,
|
| 1801 |
num_attention_heads=15,
|
|
@@ -1804,7 +1815,7 @@ def ESMplusplus_300M(device: torch.device | str = "cpu"):
|
|
| 1804 |
return _load_esmc_checkpoint_model(config=config, model="esmc-300", device=device)
|
| 1805 |
|
| 1806 |
|
| 1807 |
-
def ESMplusplus_600M(device: torch.device
|
| 1808 |
config = ESMplusplusConfig(
|
| 1809 |
hidden_size=1152,
|
| 1810 |
num_attention_heads=18,
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import torch._inductor.config as inductor_config
|
| 5 |
import torch._dynamo as dynamo
|
|
|
|
| 431 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 432 |
"""
|
| 433 |
from enum import Enum
|
| 434 |
+
from functools import partial
|
| 435 |
+
from typing import Dict, List, Optional, Tuple
|
| 436 |
|
| 437 |
import torch
|
| 438 |
import torch.nn as nn
|
|
|
|
| 450 |
|
| 451 |
|
| 452 |
def _get_flex_attention_fn():
|
| 453 |
+
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
| 454 |
+
|
| 455 |
+
Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
|
| 456 |
+
on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
|
| 457 |
+
on older hardware.
|
| 458 |
+
"""
|
| 459 |
global _compiled_flex_attention
|
| 460 |
if flex_attention is None:
|
| 461 |
return None
|
|
|
|
| 463 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 464 |
return flex_attention
|
| 465 |
if _compiled_flex_attention is None:
|
| 466 |
+
_compiled_flex_attention = torch.compile(
|
| 467 |
+
partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
|
| 468 |
+
dynamic=False,
|
| 469 |
+
)
|
| 470 |
return _compiled_flex_attention
|
| 471 |
|
| 472 |
|
| 473 |
### Kernels Flash Attention Detection
|
| 474 |
+
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
|
| 475 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 476 |
return "flash_attn2"
|
| 477 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
|
|
| 587 |
).reshape(-1, *other_shape)
|
| 588 |
|
| 589 |
@staticmethod
|
| 590 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
|
| 591 |
(indices,) = ctx.saved_tensors
|
| 592 |
assert grad_output.ndim >= 2
|
| 593 |
other_shape = grad_output.shape[1:]
|
|
|
|
| 610 |
return output
|
| 611 |
|
| 612 |
@staticmethod
|
| 613 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
|
| 614 |
(indices,) = ctx.saved_tensors
|
| 615 |
return grad_output[indices], None, None
|
| 616 |
|
|
|
|
| 629 |
key_layer: torch.Tensor,
|
| 630 |
value_layer: torch.Tensor,
|
| 631 |
attention_mask_2d: torch.Tensor,
|
| 632 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
| 633 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 634 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 635 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
|
|
| 645 |
query_states: torch.Tensor,
|
| 646 |
key_states: torch.Tensor,
|
| 647 |
value_states: torch.Tensor,
|
| 648 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 649 |
causal: bool = False,
|
| 650 |
) -> torch.Tensor:
|
| 651 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
|
|
| 718 |
seq_len: int,
|
| 719 |
device: torch.device,
|
| 720 |
attention_mask: Optional[torch.Tensor] = None,
|
| 721 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
|
| 722 |
"""Build padding masks once for all encoder layers.
|
| 723 |
|
| 724 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
|
|
| 794 |
num_attention_heads: int = 15,
|
| 795 |
num_hidden_layers: int = 30,
|
| 796 |
num_labels: int = 2,
|
| 797 |
+
problem_type: Optional[str] = None,
|
| 798 |
dropout: float = 0.0,
|
| 799 |
initializer_range: float = 0.02,
|
| 800 |
attn_backend: str = "sdpa",
|
|
|
|
| 1068 |
def forward(
|
| 1069 |
self,
|
| 1070 |
x: torch.Tensor,
|
| 1071 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1072 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1073 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1074 |
output_attentions: bool = False,
|
| 1075 |
output_s_max: bool = False,
|
| 1076 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1077 |
qkv_BLD3 = self.layernorm_qkv(x)
|
| 1078 |
query_BLD, key_BLD, value_BLD = torch.chunk(qkv_BLD3, 3, dim=-1)
|
| 1079 |
query_BLD, key_BLD = (
|
|
|
|
| 1100 |
query_BHLD: torch.Tensor,
|
| 1101 |
key_BHLD: torch.Tensor,
|
| 1102 |
value_BHLD: torch.Tensor,
|
| 1103 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1104 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1105 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1106 |
output_attentions: bool = False,
|
| 1107 |
output_s_max: bool = False,
|
| 1108 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1109 |
if output_attentions:
|
| 1110 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 1111 |
|
|
|
|
| 1122 |
return attn_output, attn_weights, s_max
|
| 1123 |
|
| 1124 |
@torch.no_grad()
|
| 1125 |
+
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
|
| 1126 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 1127 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 1128 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values * self.scale
|
|
|
|
| 1133 |
query_BHLD: torch.Tensor,
|
| 1134 |
key_BHLD: torch.Tensor,
|
| 1135 |
value_BHLD: torch.Tensor,
|
| 1136 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1137 |
output_s_max: bool = False,
|
| 1138 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
|
| 1139 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-2, -1)) * self.scale
|
| 1140 |
if attention_mask_4d is not None:
|
| 1141 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
|
|
| 1150 |
query_BHLD: torch.Tensor,
|
| 1151 |
key_BHLD: torch.Tensor,
|
| 1152 |
value_BHLD: torch.Tensor,
|
| 1153 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1154 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1155 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 1156 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 1157 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
|
|
| 1166 |
query_BHLD: torch.Tensor,
|
| 1167 |
key_BHLD: torch.Tensor,
|
| 1168 |
value_BHLD: torch.Tensor,
|
| 1169 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1170 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1171 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1172 |
fn = _get_flex_attention_fn()
|
| 1173 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=self.scale)
|
|
|
|
| 1178 |
query_BHLD: torch.Tensor,
|
| 1179 |
key_BHLD: torch.Tensor,
|
| 1180 |
value_BHLD: torch.Tensor,
|
| 1181 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1182 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1183 |
context_BHLD = F.scaled_dot_product_attention(
|
| 1184 |
query_BHLD, key_BHLD, value_BHLD, attn_mask=attention_mask_4d, scale=self.scale,
|
| 1185 |
)
|
|
|
|
| 1225 |
def forward(
|
| 1226 |
self,
|
| 1227 |
x: torch.Tensor,
|
| 1228 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1229 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1230 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1231 |
output_attentions: bool = False,
|
| 1232 |
output_s_max: bool = False,
|
| 1233 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1234 |
attn_output, attn_weights, s_max = self.attn(
|
| 1235 |
x,
|
| 1236 |
attention_mask_2d=attention_mask_2d,
|
|
|
|
| 1251 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 1252 |
hidden_states: Optional[Tuple[torch.Tensor]] = None
|
| 1253 |
attentions: Optional[Tuple[torch.Tensor]] = None
|
| 1254 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 1255 |
|
| 1256 |
|
| 1257 |
@dataclass
|
|
|
|
| 1262 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 1263 |
hidden_states: Optional[Tuple[torch.Tensor]] = None
|
| 1264 |
attentions: Optional[Tuple[torch.Tensor]] = None
|
| 1265 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 1266 |
|
| 1267 |
|
| 1268 |
### Transformer Stack
|
|
|
|
| 1783 |
def _load_esmc_checkpoint_model(
|
| 1784 |
config: ESMplusplusConfig,
|
| 1785 |
model: str,
|
| 1786 |
+
device: Union[torch.device, str] = "cpu",
|
| 1787 |
) -> ESMplusplusForMaskedLM:
|
| 1788 |
key = _resolve_esmc_checkpoint_key(model)
|
| 1789 |
spec = _ESMC_CHECKPOINT_SPECS[key]
|
|
|
|
| 1806 |
return model_obj
|
| 1807 |
|
| 1808 |
|
| 1809 |
+
def ESMplusplus_300M(device: Union[torch.device, str] = "cpu"):
|
| 1810 |
config = ESMplusplusConfig(
|
| 1811 |
hidden_size=960,
|
| 1812 |
num_attention_heads=15,
|
|
|
|
| 1815 |
return _load_esmc_checkpoint_model(config=config, model="esmc-300", device=device)
|
| 1816 |
|
| 1817 |
|
| 1818 |
+
def ESMplusplus_600M(device: Union[torch.device, str] = "cpu"):
|
| 1819 |
config = ESMplusplusConfig(
|
| 1820 |
hidden_size=1152,
|
| 1821 |
num_attention_heads=18,
|