Instructions to use Synthyra/DPLM-650M with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/DPLM-650M with Transformers:
# Load model directly from transformers import EsmForDPLM model = EsmForDPLM.from_pretrained("Synthyra/DPLM-650M", trust_remote_code=True, dtype="auto") - Notebooks
- Google Colab
- Kaggle
Upload modeling_dplm.py with huggingface_hub
Browse files- modeling_dplm.py +23 -16
modeling_dplm.py
CHANGED
|
@@ -420,9 +420,9 @@ def get_attention_mask(
|
|
| 420 |
attention_mask: Optional[torch.Tensor] = None,
|
| 421 |
) -> Tuple[Optional[torch.Tensor], Optional[object]]:
|
| 422 |
if attention_mask is None:
|
| 423 |
-
|
| 424 |
else:
|
| 425 |
-
|
| 426 |
|
| 427 |
if attn_backend == "flex":
|
| 428 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
@@ -430,8 +430,10 @@ def get_attention_mask(
|
|
| 430 |
if attention_mask is None:
|
| 431 |
flex_block_mask = None
|
| 432 |
else:
|
|
|
|
|
|
|
| 433 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 434 |
-
return (
|
| 435 |
|
| 436 |
flex_block_mask = create_block_mask(
|
| 437 |
mask_mod,
|
|
@@ -441,12 +443,12 @@ def get_attention_mask(
|
|
| 441 |
seq_len,
|
| 442 |
device=device,
|
| 443 |
)
|
| 444 |
-
|
| 445 |
else:
|
| 446 |
flex_block_mask = None
|
| 447 |
-
|
| 448 |
|
| 449 |
-
return
|
| 450 |
|
| 451 |
|
| 452 |
@dataclass
|
|
@@ -478,6 +480,11 @@ class DPLMPreTrainedModel(EsmPreTrainedModel):
|
|
| 478 |
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
| 479 |
all_tied_weights_keys = {}
|
| 480 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 481 |
@property
|
| 482 |
def attn_backend(self) -> str:
|
| 483 |
return self.config.attn_backend
|
|
@@ -899,12 +906,12 @@ class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 899 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 900 |
|
| 901 |
if attention_mask is None:
|
| 902 |
-
|
| 903 |
elif attention_mask.dim() == 2:
|
| 904 |
-
|
| 905 |
elif attention_mask.dim() == 4:
|
| 906 |
assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
|
| 907 |
-
|
| 908 |
else:
|
| 909 |
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
| 910 |
|
|
@@ -919,19 +926,19 @@ class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 919 |
|
| 920 |
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 921 |
|
| 922 |
-
embedding_attention_mask =
|
| 923 |
if embedding_attention_mask is None and input_ids is not None:
|
| 924 |
embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 925 |
|
| 926 |
if self.config.attn_backend == "flex" and output_attentions:
|
| 927 |
raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
|
| 928 |
|
| 929 |
-
|
| 930 |
attn_backend=self.config.attn_backend,
|
| 931 |
batch_size=batch_size,
|
| 932 |
seq_len=seq_length,
|
| 933 |
device=device,
|
| 934 |
-
attention_mask=
|
| 935 |
)
|
| 936 |
|
| 937 |
embedding_output = self.embeddings(
|
|
@@ -942,7 +949,7 @@ class FAST_DPLM_ENCODER(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 942 |
)
|
| 943 |
encoder_outputs = self.encoder(
|
| 944 |
embedding_output,
|
| 945 |
-
attention_mask=
|
| 946 |
head_mask=head_mask,
|
| 947 |
encoder_hidden_states=encoder_hidden_states,
|
| 948 |
encoder_attention_mask=encoder_extended_attention_mask,
|
|
@@ -1041,7 +1048,7 @@ class DPLMForMaskedLM(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 1041 |
def __init__(self, config, dropout: float = 0.1):
|
| 1042 |
config.hidden_dropout_prob = dropout
|
| 1043 |
DPLMPreTrainedModel.__init__(self, config)
|
| 1044 |
-
self.esm =
|
| 1045 |
self.lm_head = EsmLMHead(config)
|
| 1046 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 1047 |
self.post_init()
|
|
@@ -1136,7 +1143,7 @@ class DPLMForSequenceClassification(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 1136 |
def __init__(self, config):
|
| 1137 |
DPLMPreTrainedModel.__init__(self, config)
|
| 1138 |
self.num_labels = config.num_labels
|
| 1139 |
-
self.esm =
|
| 1140 |
self.classifier = EsmClassificationHead(config)
|
| 1141 |
self.mse = nn.MSELoss()
|
| 1142 |
self.ce = nn.CrossEntropyLoss()
|
|
@@ -1206,7 +1213,7 @@ class DPLMForTokenClassification(DPLMPreTrainedModel, EmbeddingMixin):
|
|
| 1206 |
def __init__(self, config):
|
| 1207 |
DPLMPreTrainedModel.__init__(self, config)
|
| 1208 |
self.num_labels = config.num_labels
|
| 1209 |
-
self.esm =
|
| 1210 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1211 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1212 |
self.loss_fct = nn.CrossEntropyLoss()
|
|
|
|
| 420 |
attention_mask: Optional[torch.Tensor] = None,
|
| 421 |
) -> Tuple[Optional[torch.Tensor], Optional[object]]:
|
| 422 |
if attention_mask is None:
|
| 423 |
+
attention_mask_2d = torch.ones((batch_size, seq_len), device=device).bool()
|
| 424 |
else:
|
| 425 |
+
attention_mask_2d = attention_mask.bool()
|
| 426 |
|
| 427 |
if attn_backend == "flex":
|
| 428 |
assert create_block_mask is not None, "Flex attention backend requested but torch.create_block_mask is unavailable."
|
|
|
|
| 430 |
if attention_mask is None:
|
| 431 |
flex_block_mask = None
|
| 432 |
else:
|
| 433 |
+
valid_lens = attention_mask_2d.sum(dim=-1)
|
| 434 |
+
|
| 435 |
def mask_mod(batch_idx, head_idx, q_idx, kv_idx):
|
| 436 |
+
return (q_idx < valid_lens[batch_idx]) & (kv_idx < valid_lens[batch_idx])
|
| 437 |
|
| 438 |
flex_block_mask = create_block_mask(
|
| 439 |
mask_mod,
|
|
|
|
| 443 |
seq_len,
|
| 444 |
device=device,
|
| 445 |
)
|
| 446 |
+
attention_mask_4d = None
|
| 447 |
else:
|
| 448 |
flex_block_mask = None
|
| 449 |
+
attention_mask_4d = attention_mask_2d[:, None, :, None] & attention_mask_2d[:, None, None, :]
|
| 450 |
|
| 451 |
+
return attention_mask_4d, flex_block_mask
|
| 452 |
|
| 453 |
|
| 454 |
@dataclass
|
|
|
|
| 480 |
tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t6_8M_UR50D")
|
| 481 |
all_tied_weights_keys = {}
|
| 482 |
|
| 483 |
+
@classmethod
|
| 484 |
+
def is_remote_code(cls) -> bool:
|
| 485 |
+
# Prevent post-load reinitialization of tensors already loaded from checkpoints.
|
| 486 |
+
return True
|
| 487 |
+
|
| 488 |
@property
|
| 489 |
def attn_backend(self) -> str:
|
| 490 |
return self.config.attn_backend
|
|
|
|
| 906 |
past_key_values_length = past_key_values[0][0].shape[2] if past_key_values is not None else 0
|
| 907 |
|
| 908 |
if attention_mask is None:
|
| 909 |
+
attention_mask_2d = torch.ones((batch_size, seq_length + past_key_values_length), device=device).bool()
|
| 910 |
elif attention_mask.dim() == 2:
|
| 911 |
+
attention_mask_2d = attention_mask.bool()
|
| 912 |
elif attention_mask.dim() == 4:
|
| 913 |
assert input_ids is not None, "4D attention_mask requires input_ids to infer token-level mask."
|
| 914 |
+
attention_mask_2d = input_ids.ne(self.config.pad_token_id)
|
| 915 |
else:
|
| 916 |
raise ValueError(f"Unsupported attention_mask shape: {attention_mask.shape}")
|
| 917 |
|
|
|
|
| 926 |
|
| 927 |
head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
|
| 928 |
|
| 929 |
+
embedding_attention_mask = attention_mask_2d
|
| 930 |
if embedding_attention_mask is None and input_ids is not None:
|
| 931 |
embedding_attention_mask = input_ids.ne(self.config.pad_token_id)
|
| 932 |
|
| 933 |
if self.config.attn_backend == "flex" and output_attentions:
|
| 934 |
raise AssertionError("output_attentions=True is not supported with attn_backend='flex'.")
|
| 935 |
|
| 936 |
+
attention_mask_4d, flex_block_mask = get_attention_mask(
|
| 937 |
attn_backend=self.config.attn_backend,
|
| 938 |
batch_size=batch_size,
|
| 939 |
seq_len=seq_length,
|
| 940 |
device=device,
|
| 941 |
+
attention_mask=attention_mask_2d,
|
| 942 |
)
|
| 943 |
|
| 944 |
embedding_output = self.embeddings(
|
|
|
|
| 949 |
)
|
| 950 |
encoder_outputs = self.encoder(
|
| 951 |
embedding_output,
|
| 952 |
+
attention_mask=attention_mask_4d,
|
| 953 |
head_mask=head_mask,
|
| 954 |
encoder_hidden_states=encoder_hidden_states,
|
| 955 |
encoder_attention_mask=encoder_extended_attention_mask,
|
|
|
|
| 1048 |
def __init__(self, config, dropout: float = 0.1):
|
| 1049 |
config.hidden_dropout_prob = dropout
|
| 1050 |
DPLMPreTrainedModel.__init__(self, config)
|
| 1051 |
+
self.esm = FAST_DPLM_ENCODER(config)
|
| 1052 |
self.lm_head = EsmLMHead(config)
|
| 1053 |
self.loss_fct = nn.CrossEntropyLoss()
|
| 1054 |
self.post_init()
|
|
|
|
| 1143 |
def __init__(self, config):
|
| 1144 |
DPLMPreTrainedModel.__init__(self, config)
|
| 1145 |
self.num_labels = config.num_labels
|
| 1146 |
+
self.esm = FAST_DPLM_ENCODER(config)
|
| 1147 |
self.classifier = EsmClassificationHead(config)
|
| 1148 |
self.mse = nn.MSELoss()
|
| 1149 |
self.ce = nn.CrossEntropyLoss()
|
|
|
|
| 1213 |
def __init__(self, config):
|
| 1214 |
DPLMPreTrainedModel.__init__(self, config)
|
| 1215 |
self.num_labels = config.num_labels
|
| 1216 |
+
self.esm = FAST_DPLM_ENCODER(config)
|
| 1217 |
self.dropout = nn.Dropout(config.hidden_dropout_prob)
|
| 1218 |
self.classifier = nn.Linear(config.hidden_size, config.num_labels)
|
| 1219 |
self.loss_fct = nn.CrossEntropyLoss()
|