Instructions to use Synthyra/DPLM-650M with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/DPLM-650M with Transformers:
```python
# Load model directly
from transformers import EsmForDPLM

model = EsmForDPLM.from_pretrained("Synthyra/DPLM-650M", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
Upload modeling_dplm.py with huggingface_hub
Browse files
- modeling_dplm.py +38 -27
modeling_dplm.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch._inductor.config as inductor_config
|
| 3 |
import torch._dynamo as dynamo
|
|
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
|
|
| 429 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 430 |
"""
|
| 431 |
from enum import Enum
|
| 432 |
-
from
|
|
|
|
| 433 |
|
| 434 |
import torch
|
| 435 |
import torch.nn as nn
|
|
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
|
|
| 447 |
|
| 448 |
|
| 449 |
def _get_flex_attention_fn():
|
| 450 |
-
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
global _compiled_flex_attention
|
| 452 |
if flex_attention is None:
|
| 453 |
return None
|
|
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
|
|
| 455 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 456 |
return flex_attention
|
| 457 |
if _compiled_flex_attention is None:
|
| 458 |
-
_compiled_flex_attention = torch.compile(
|
|
|
|
|
|
|
|
|
|
| 459 |
return _compiled_flex_attention
|
| 460 |
|
| 461 |
|
| 462 |
### Kernels Flash Attention Detection
|
| 463 |
-
def _infer_kernels_flash_variant(kernel) -> str
|
| 464 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 465 |
return "flash_attn2"
|
| 466 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
| 576 |
).reshape(-1, *other_shape)
|
| 577 |
|
| 578 |
@staticmethod
|
| 579 |
-
def backward(ctx, grad_output) ->
|
| 580 |
(indices,) = ctx.saved_tensors
|
| 581 |
assert grad_output.ndim >= 2
|
| 582 |
other_shape = grad_output.shape[1:]
|
|
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
|
|
| 599 |
return output
|
| 600 |
|
| 601 |
@staticmethod
|
| 602 |
-
def backward(ctx, grad_output) ->
|
| 603 |
(indices,) = ctx.saved_tensors
|
| 604 |
return grad_output[indices], None, None
|
| 605 |
|
|
@@ -618,7 +629,7 @@ def _unpad_input(
|
|
| 618 |
key_layer: torch.Tensor,
|
| 619 |
value_layer: torch.Tensor,
|
| 620 |
attention_mask_2d: torch.Tensor,
|
| 621 |
-
) ->
|
| 622 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 623 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 624 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
|
|
| 634 |
query_states: torch.Tensor,
|
| 635 |
key_states: torch.Tensor,
|
| 636 |
value_states: torch.Tensor,
|
| 637 |
-
attention_mask_2d: torch.Tensor
|
| 638 |
causal: bool = False,
|
| 639 |
) -> torch.Tensor:
|
| 640 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
@@ -707,7 +718,7 @@ def get_attention_mask(
|
|
| 707 |
seq_len: int,
|
| 708 |
device: torch.device,
|
| 709 |
attention_mask: Optional[torch.Tensor] = None,
|
| 710 |
-
) ->
|
| 711 |
"""Build padding masks once for all encoder layers.
|
| 712 |
|
| 713 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
@@ -782,7 +793,7 @@ class DPLMMaskedLMOutput(ModelOutput):
|
|
| 782 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 783 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 784 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 785 |
-
s_max: Optional[Tuple[
|
| 786 |
|
| 787 |
|
| 788 |
@dataclass
|
|
@@ -790,7 +801,7 @@ class DPLMEncoderOutput(ModelOutput):
|
|
| 790 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 791 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 792 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 793 |
-
s_max: Optional[Tuple[
|
| 794 |
|
| 795 |
|
| 796 |
class DPLMConfig(EsmConfig):
|
|
@@ -859,7 +870,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 859 |
output_attentions: Optional[bool] = False,
|
| 860 |
output_s_max: Optional[bool] = False,
|
| 861 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 862 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[
|
| 863 |
if past_key_values is not None:
|
| 864 |
past_key_value = past_key_values
|
| 865 |
|
|
@@ -930,12 +941,12 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 930 |
query_BHLD: torch.Tensor,
|
| 931 |
key_BHLD: torch.Tensor,
|
| 932 |
value_BHLD: torch.Tensor,
|
| 933 |
-
attention_mask_2d: torch.Tensor
|
| 934 |
-
attention_mask_4d: torch.Tensor
|
| 935 |
-
flex_block_mask:
|
| 936 |
output_attentions: bool = False,
|
| 937 |
output_s_max: bool = False,
|
| 938 |
-
) ->
|
| 939 |
if output_attentions:
|
| 940 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 941 |
|
|
@@ -952,7 +963,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 952 |
return attn_output, attn_weights, s_max
|
| 953 |
|
| 954 |
@torch.no_grad()
|
| 955 |
-
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) ->
|
| 956 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 957 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 958 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
|
|
@@ -963,9 +974,9 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 963 |
query_BHLD: torch.Tensor,
|
| 964 |
key_BHLD: torch.Tensor,
|
| 965 |
value_BHLD: torch.Tensor,
|
| 966 |
-
attention_mask_4d: torch.Tensor
|
| 967 |
output_s_max: bool = False,
|
| 968 |
-
) ->
|
| 969 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 970 |
if attention_mask_4d is not None:
|
| 971 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
@@ -980,8 +991,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 980 |
query_BHLD: torch.Tensor,
|
| 981 |
key_BHLD: torch.Tensor,
|
| 982 |
value_BHLD: torch.Tensor,
|
| 983 |
-
attention_mask_2d: torch.Tensor
|
| 984 |
-
) ->
|
| 985 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 986 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 987 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
@@ -996,8 +1007,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 996 |
query_BHLD: torch.Tensor,
|
| 997 |
key_BHLD: torch.Tensor,
|
| 998 |
value_BHLD: torch.Tensor,
|
| 999 |
-
flex_block_mask:
|
| 1000 |
-
) ->
|
| 1001 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1002 |
fn = _get_flex_attention_fn()
|
| 1003 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
@@ -1008,8 +1019,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1008 |
query_BHLD: torch.Tensor,
|
| 1009 |
key_BHLD: torch.Tensor,
|
| 1010 |
value_BHLD: torch.Tensor,
|
| 1011 |
-
attention_mask_4d: torch.Tensor
|
| 1012 |
-
) ->
|
| 1013 |
context_BHLD = F.scaled_dot_product_attention(
|
| 1014 |
query_BHLD, key_BHLD, value_BHLD,
|
| 1015 |
attn_mask=attention_mask_4d,
|
|
@@ -1038,7 +1049,7 @@ class ModifiedEsmAttention(EsmAttention):
|
|
| 1038 |
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1039 |
output_attentions: bool = False,
|
| 1040 |
output_s_max: bool = False,
|
| 1041 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[
|
| 1042 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 1043 |
attn_output, attn_weights, s_max = self.self(
|
| 1044 |
hidden_states_ln,
|
|
@@ -1084,7 +1095,7 @@ class ModifiedEsmLayer(EsmLayer):
|
|
| 1084 |
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1085 |
output_attentions: bool = False,
|
| 1086 |
output_s_max: bool = False,
|
| 1087 |
-
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[
|
| 1088 |
attention_output, attn_weights, s_max = self.attention(
|
| 1089 |
hidden_states,
|
| 1090 |
attention_mask_2d=attention_mask_2d,
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import torch._inductor.config as inductor_config
|
| 5 |
import torch._dynamo as dynamo
|
|
|
|
| 431 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 432 |
"""
|
| 433 |
from enum import Enum
|
| 434 |
+
from functools import partial
|
| 435 |
+
from typing import Dict, List, Optional, Tuple
|
| 436 |
|
| 437 |
import torch
|
| 438 |
import torch.nn as nn
|
|
|
|
| 450 |
|
| 451 |
|
| 452 |
def _get_flex_attention_fn():
|
| 453 |
+
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
| 454 |
+
|
| 455 |
+
Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
|
| 456 |
+
on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
|
| 457 |
+
on older hardware.
|
| 458 |
+
"""
|
| 459 |
global _compiled_flex_attention
|
| 460 |
if flex_attention is None:
|
| 461 |
return None
|
|
|
|
| 463 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 464 |
return flex_attention
|
| 465 |
if _compiled_flex_attention is None:
|
| 466 |
+
_compiled_flex_attention = torch.compile(
|
| 467 |
+
partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
|
| 468 |
+
dynamic=False,
|
| 469 |
+
)
|
| 470 |
return _compiled_flex_attention
|
| 471 |
|
| 472 |
|
| 473 |
### Kernels Flash Attention Detection
|
| 474 |
+
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
|
| 475 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 476 |
return "flash_attn2"
|
| 477 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
|
|
| 587 |
).reshape(-1, *other_shape)
|
| 588 |
|
| 589 |
@staticmethod
|
| 590 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
|
| 591 |
(indices,) = ctx.saved_tensors
|
| 592 |
assert grad_output.ndim >= 2
|
| 593 |
other_shape = grad_output.shape[1:]
|
|
|
|
| 610 |
return output
|
| 611 |
|
| 612 |
@staticmethod
|
| 613 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
|
| 614 |
(indices,) = ctx.saved_tensors
|
| 615 |
return grad_output[indices], None, None
|
| 616 |
|
|
|
|
| 629 |
key_layer: torch.Tensor,
|
| 630 |
value_layer: torch.Tensor,
|
| 631 |
attention_mask_2d: torch.Tensor,
|
| 632 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
| 633 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 634 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 635 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
|
|
| 645 |
query_states: torch.Tensor,
|
| 646 |
key_states: torch.Tensor,
|
| 647 |
value_states: torch.Tensor,
|
| 648 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 649 |
causal: bool = False,
|
| 650 |
) -> torch.Tensor:
|
| 651 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
|
|
| 718 |
seq_len: int,
|
| 719 |
device: torch.device,
|
| 720 |
attention_mask: Optional[torch.Tensor] = None,
|
| 721 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
|
| 722 |
"""Build padding masks once for all encoder layers.
|
| 723 |
|
| 724 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
|
|
| 793 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 794 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 795 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 796 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 797 |
|
| 798 |
|
| 799 |
@dataclass
|
|
|
|
| 801 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 802 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 803 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 804 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 805 |
|
| 806 |
|
| 807 |
class DPLMConfig(EsmConfig):
|
|
|
|
| 870 |
output_attentions: Optional[bool] = False,
|
| 871 |
output_s_max: Optional[bool] = False,
|
| 872 |
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 873 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 874 |
if past_key_values is not None:
|
| 875 |
past_key_value = past_key_values
|
| 876 |
|
|
|
|
| 941 |
query_BHLD: torch.Tensor,
|
| 942 |
key_BHLD: torch.Tensor,
|
| 943 |
value_BHLD: torch.Tensor,
|
| 944 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 945 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 946 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 947 |
output_attentions: bool = False,
|
| 948 |
output_s_max: bool = False,
|
| 949 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 950 |
if output_attentions:
|
| 951 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 952 |
|
|
|
|
| 963 |
return attn_output, attn_weights, s_max
|
| 964 |
|
| 965 |
@torch.no_grad()
|
| 966 |
+
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
|
| 967 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 968 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 969 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
|
|
|
|
| 974 |
query_BHLD: torch.Tensor,
|
| 975 |
key_BHLD: torch.Tensor,
|
| 976 |
value_BHLD: torch.Tensor,
|
| 977 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 978 |
output_s_max: bool = False,
|
| 979 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
|
| 980 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 981 |
if attention_mask_4d is not None:
|
| 982 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
|
|
| 991 |
query_BHLD: torch.Tensor,
|
| 992 |
key_BHLD: torch.Tensor,
|
| 993 |
value_BHLD: torch.Tensor,
|
| 994 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 995 |
+
) -> Tuple[torch.Tensor, None]:
|
| 996 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 997 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 998 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
|
|
| 1007 |
query_BHLD: torch.Tensor,
|
| 1008 |
key_BHLD: torch.Tensor,
|
| 1009 |
value_BHLD: torch.Tensor,
|
| 1010 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1011 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1012 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1013 |
fn = _get_flex_attention_fn()
|
| 1014 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
|
|
| 1019 |
query_BHLD: torch.Tensor,
|
| 1020 |
key_BHLD: torch.Tensor,
|
| 1021 |
value_BHLD: torch.Tensor,
|
| 1022 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1023 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1024 |
context_BHLD = F.scaled_dot_product_attention(
|
| 1025 |
query_BHLD, key_BHLD, value_BHLD,
|
| 1026 |
attn_mask=attention_mask_4d,
|
|
|
|
| 1049 |
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1050 |
output_attentions: bool = False,
|
| 1051 |
output_s_max: bool = False,
|
| 1052 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1053 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 1054 |
attn_output, attn_weights, s_max = self.self(
|
| 1055 |
hidden_states_ln,
|
|
|
|
| 1095 |
past_key_value: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
|
| 1096 |
output_attentions: bool = False,
|
| 1097 |
output_s_max: bool = False,
|
| 1098 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1099 |
attention_output, attn_weights, s_max = self.attention(
|
| 1100 |
hidden_states,
|
| 1101 |
attention_mask_2d=attention_mask_2d,
|