Instructions to use Synthyra/DPLM2-650M with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use Synthyra/DPLM2-650M with Transformers:
# Load model directly
from transformers import EsmForDPLM2
model = EsmForDPLM2.from_pretrained("Synthyra/DPLM2-650M", trust_remote_code=True, dtype="auto")
- Notebooks
- Google Colab
- Kaggle
Upload modeling_dplm2.py with huggingface_hub
Browse files
- modeling_dplm2.py +47 -36
modeling_dplm2.py
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
import torch._inductor.config as inductor_config
|
| 3 |
import torch._dynamo as dynamo
|
|
@@ -429,7 +431,8 @@ Contains: AttentionBackend enum, backend resolution, mask creation,
|
|
| 429 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 430 |
"""
|
| 431 |
from enum import Enum
|
| 432 |
-
from
|
|
|
|
| 433 |
|
| 434 |
import torch
|
| 435 |
import torch.nn as nn
|
|
@@ -447,7 +450,12 @@ _compiled_flex_attention = None
|
|
| 447 |
|
| 448 |
|
| 449 |
def _get_flex_attention_fn():
|
| 450 |
-
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 451 |
global _compiled_flex_attention
|
| 452 |
if flex_attention is None:
|
| 453 |
return None
|
|
@@ -455,12 +463,15 @@ def _get_flex_attention_fn():
|
|
| 455 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 456 |
return flex_attention
|
| 457 |
if _compiled_flex_attention is None:
|
| 458 |
-
_compiled_flex_attention = torch.compile(
|
|
|
|
|
|
|
|
|
|
| 459 |
return _compiled_flex_attention
|
| 460 |
|
| 461 |
|
| 462 |
### Kernels Flash Attention Detection
|
| 463 |
-
def _infer_kernels_flash_variant(kernel) -> str
|
| 464 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 465 |
return "flash_attn2"
|
| 466 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
@@ -576,7 +587,7 @@ class IndexFirstAxis(torch.autograd.Function):
|
|
| 576 |
).reshape(-1, *other_shape)
|
| 577 |
|
| 578 |
@staticmethod
|
| 579 |
-
def backward(ctx, grad_output) ->
|
| 580 |
(indices,) = ctx.saved_tensors
|
| 581 |
assert grad_output.ndim >= 2
|
| 582 |
other_shape = grad_output.shape[1:]
|
|
@@ -599,7 +610,7 @@ class IndexPutFirstAxis(torch.autograd.Function):
|
|
| 599 |
return output
|
| 600 |
|
| 601 |
@staticmethod
|
| 602 |
-
def backward(ctx, grad_output) ->
|
| 603 |
(indices,) = ctx.saved_tensors
|
| 604 |
return grad_output[indices], None, None
|
| 605 |
|
|
@@ -618,7 +629,7 @@ def _unpad_input(
|
|
| 618 |
key_layer: torch.Tensor,
|
| 619 |
value_layer: torch.Tensor,
|
| 620 |
attention_mask_2d: torch.Tensor,
|
| 621 |
-
) ->
|
| 622 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 623 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 624 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
@@ -634,7 +645,7 @@ def kernels_flash_attention_func(
|
|
| 634 |
query_states: torch.Tensor,
|
| 635 |
key_states: torch.Tensor,
|
| 636 |
value_states: torch.Tensor,
|
| 637 |
-
attention_mask_2d: torch.Tensor
|
| 638 |
causal: bool = False,
|
| 639 |
) -> torch.Tensor:
|
| 640 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
@@ -707,7 +718,7 @@ def get_attention_mask(
|
|
| 707 |
seq_len: int,
|
| 708 |
device: torch.device,
|
| 709 |
attention_mask: Optional[torch.Tensor] = None,
|
| 710 |
-
) ->
|
| 711 |
"""Build padding masks once for all encoder layers.
|
| 712 |
|
| 713 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
@@ -838,7 +849,7 @@ class DPLM2MaskedLMOutput(ModelOutput):
|
|
| 838 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 839 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 840 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 841 |
-
s_max: Optional[Tuple[
|
| 842 |
|
| 843 |
|
| 844 |
@dataclass
|
|
@@ -846,7 +857,7 @@ class DPLM2EncoderOutput(ModelOutput):
|
|
| 846 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 847 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 848 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 849 |
-
s_max: Optional[Tuple[
|
| 850 |
|
| 851 |
|
| 852 |
class DPLM2Config(EsmConfig):
|
|
@@ -986,13 +997,13 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 986 |
def forward(
|
| 987 |
self,
|
| 988 |
hidden_states: torch.Tensor,
|
| 989 |
-
attention_mask_2d: torch.Tensor
|
| 990 |
-
attention_mask_4d: torch.Tensor
|
| 991 |
-
flex_block_mask:
|
| 992 |
output_attentions: bool = False,
|
| 993 |
output_s_max: bool = False,
|
| 994 |
type_ids: Optional[torch.Tensor] = None,
|
| 995 |
-
) ->
|
| 996 |
batch_size, seq_length = hidden_states.shape[:-1]
|
| 997 |
hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
|
| 998 |
query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
@@ -1019,12 +1030,12 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1019 |
query_BHLD: torch.Tensor,
|
| 1020 |
key_BHLD: torch.Tensor,
|
| 1021 |
value_BHLD: torch.Tensor,
|
| 1022 |
-
attention_mask_2d: torch.Tensor
|
| 1023 |
-
attention_mask_4d: torch.Tensor
|
| 1024 |
-
flex_block_mask:
|
| 1025 |
output_attentions: bool = False,
|
| 1026 |
output_s_max: bool = False,
|
| 1027 |
-
) ->
|
| 1028 |
if output_attentions:
|
| 1029 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 1030 |
|
|
@@ -1041,7 +1052,7 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1041 |
return attn_output, attn_weights, s_max
|
| 1042 |
|
| 1043 |
@torch.no_grad()
|
| 1044 |
-
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) ->
|
| 1045 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 1046 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 1047 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
|
|
@@ -1052,9 +1063,9 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1052 |
query_BHLD: torch.Tensor,
|
| 1053 |
key_BHLD: torch.Tensor,
|
| 1054 |
value_BHLD: torch.Tensor,
|
| 1055 |
-
attention_mask_4d: torch.Tensor
|
| 1056 |
output_s_max: bool = False,
|
| 1057 |
-
) ->
|
| 1058 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 1059 |
if attention_mask_4d is not None:
|
| 1060 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
@@ -1071,8 +1082,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1071 |
query_BHLD: torch.Tensor,
|
| 1072 |
key_BHLD: torch.Tensor,
|
| 1073 |
value_BHLD: torch.Tensor,
|
| 1074 |
-
attention_mask_2d: torch.Tensor
|
| 1075 |
-
) ->
|
| 1076 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 1077 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 1078 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
@@ -1087,8 +1098,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1087 |
query_BHLD: torch.Tensor,
|
| 1088 |
key_BHLD: torch.Tensor,
|
| 1089 |
value_BHLD: torch.Tensor,
|
| 1090 |
-
flex_block_mask:
|
| 1091 |
-
) ->
|
| 1092 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1093 |
fn = _get_flex_attention_fn()
|
| 1094 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
@@ -1099,8 +1110,8 @@ class ModifiedEsmSelfAttention(EsmSelfAttention):
|
|
| 1099 |
query_BHLD: torch.Tensor,
|
| 1100 |
key_BHLD: torch.Tensor,
|
| 1101 |
value_BHLD: torch.Tensor,
|
| 1102 |
-
attention_mask_4d: torch.Tensor
|
| 1103 |
-
) ->
|
| 1104 |
context_BHLD = F.scaled_dot_product_attention(
|
| 1105 |
query_BHLD, key_BHLD, value_BHLD,
|
| 1106 |
attn_mask=attention_mask_4d,
|
|
@@ -1120,13 +1131,13 @@ class ModifiedEsmAttention(EsmAttention):
|
|
| 1120 |
def forward(
|
| 1121 |
self,
|
| 1122 |
hidden_states: torch.Tensor,
|
| 1123 |
-
attention_mask_2d: torch.Tensor
|
| 1124 |
-
attention_mask_4d: torch.Tensor
|
| 1125 |
-
flex_block_mask:
|
| 1126 |
output_attentions: bool = False,
|
| 1127 |
output_s_max: bool = False,
|
| 1128 |
type_ids: Optional[torch.Tensor] = None,
|
| 1129 |
-
) ->
|
| 1130 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 1131 |
attn_output, attn_weights, s_max = self.self(
|
| 1132 |
hidden_states_ln,
|
|
@@ -1154,13 +1165,13 @@ class ModifiedEsmLayer(EsmLayer):
|
|
| 1154 |
def forward(
|
| 1155 |
self,
|
| 1156 |
hidden_states: torch.Tensor,
|
| 1157 |
-
attention_mask_2d: torch.Tensor
|
| 1158 |
-
attention_mask_4d: torch.Tensor
|
| 1159 |
-
flex_block_mask:
|
| 1160 |
output_attentions: bool = False,
|
| 1161 |
output_s_max: bool = False,
|
| 1162 |
type_ids: Optional[torch.Tensor] = None,
|
| 1163 |
-
) ->
|
| 1164 |
attention_output, attn_weights, s_max = self.attention(
|
| 1165 |
hidden_states,
|
| 1166 |
attention_mask_2d=attention_mask_2d,
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
import torch
|
| 4 |
import torch._inductor.config as inductor_config
|
| 5 |
import torch._dynamo as dynamo
|
|
|
|
| 431 |
flex attention helpers, flash kernel detection/dispatch, and pad/unpad utilities.
|
| 432 |
"""
|
| 433 |
from enum import Enum
|
| 434 |
+
from functools import partial
|
| 435 |
+
from typing import Dict, List, Optional, Tuple
|
| 436 |
|
| 437 |
import torch
|
| 438 |
import torch.nn as nn
|
|
|
|
| 450 |
|
| 451 |
|
| 452 |
def _get_flex_attention_fn():
|
| 453 |
+
"""Return flex_attention callable: compiled (fused kernel) by default, or eager when debug flag is set.
|
| 454 |
+
|
| 455 |
+
Uses kernel_options={"BACKEND": "FLASH"} to prefer Flash Attention 4 (FA4)
|
| 456 |
+
on Hopper/Blackwell GPUs (PyTorch 2.11+). Automatically falls back to Triton
|
| 457 |
+
on older hardware.
|
| 458 |
+
"""
|
| 459 |
global _compiled_flex_attention
|
| 460 |
if flex_attention is None:
|
| 461 |
return None
|
|
|
|
| 463 |
if getattr(flex_mod, "_FLEX_ATTENTION_DISABLE_COMPILE_DEBUG", False):
|
| 464 |
return flex_attention
|
| 465 |
if _compiled_flex_attention is None:
|
| 466 |
+
_compiled_flex_attention = torch.compile(
|
| 467 |
+
partial(flex_attention, kernel_options={"BACKEND": "FLASH"}),
|
| 468 |
+
dynamic=False,
|
| 469 |
+
)
|
| 470 |
return _compiled_flex_attention
|
| 471 |
|
| 472 |
|
| 473 |
### Kernels Flash Attention Detection
|
| 474 |
+
def _infer_kernels_flash_variant(kernel) -> Optional[str]:
|
| 475 |
if hasattr(kernel, "fwd") and hasattr(kernel, "varlen_fwd"):
|
| 476 |
return "flash_attn2"
|
| 477 |
if hasattr(kernel, "flash_attn_func") and hasattr(kernel, "flash_attn_varlen_func"):
|
|
|
|
| 587 |
).reshape(-1, *other_shape)
|
| 588 |
|
| 589 |
@staticmethod
|
| 590 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None]:
|
| 591 |
(indices,) = ctx.saved_tensors
|
| 592 |
assert grad_output.ndim >= 2
|
| 593 |
other_shape = grad_output.shape[1:]
|
|
|
|
| 610 |
return output
|
| 611 |
|
| 612 |
@staticmethod
|
| 613 |
+
def backward(ctx, grad_output) -> Tuple[torch.Tensor, None, None]:
|
| 614 |
(indices,) = ctx.saved_tensors
|
| 615 |
return grad_output[indices], None, None
|
| 616 |
|
|
|
|
| 629 |
key_layer: torch.Tensor,
|
| 630 |
value_layer: torch.Tensor,
|
| 631 |
attention_mask_2d: torch.Tensor,
|
| 632 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, Tuple[torch.Tensor, torch.Tensor], Tuple[int, int]]:
|
| 633 |
batch_size, seq_len, num_heads, head_dim = query_layer.shape
|
| 634 |
seqlens = attention_mask_2d.sum(dim=1).int()
|
| 635 |
cu_seqlens = F.pad(seqlens.cumsum(0, dtype=torch.int32), (1, 0))
|
|
|
|
| 645 |
query_states: torch.Tensor,
|
| 646 |
key_states: torch.Tensor,
|
| 647 |
value_states: torch.Tensor,
|
| 648 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 649 |
causal: bool = False,
|
| 650 |
) -> torch.Tensor:
|
| 651 |
assert FLASH_KERNEL is not None, "Kernel Flash Attention is not available in this environment."
|
|
|
|
| 718 |
seq_len: int,
|
| 719 |
device: torch.device,
|
| 720 |
attention_mask: Optional[torch.Tensor] = None,
|
| 721 |
+
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[BlockMask]]:
|
| 722 |
"""Build padding masks once for all encoder layers.
|
| 723 |
|
| 724 |
Returns (attention_mask_2d, attention_mask_4d, flex_block_mask).
|
|
|
|
| 849 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 850 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 851 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 852 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 853 |
|
| 854 |
|
| 855 |
@dataclass
|
|
|
|
| 857 |
last_hidden_state: Optional[torch.Tensor] = None
|
| 858 |
hidden_states: Optional[Tuple[torch.Tensor, ...]] = None
|
| 859 |
attentions: Optional[Tuple[torch.Tensor, ...]] = None
|
| 860 |
+
s_max: Optional[Tuple[List[torch.Tensor], ...]] = None
|
| 861 |
|
| 862 |
|
| 863 |
class DPLM2Config(EsmConfig):
|
|
|
|
| 997 |
def forward(
|
| 998 |
self,
|
| 999 |
hidden_states: torch.Tensor,
|
| 1000 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1001 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1002 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1003 |
output_attentions: bool = False,
|
| 1004 |
output_s_max: bool = False,
|
| 1005 |
type_ids: Optional[torch.Tensor] = None,
|
| 1006 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1007 |
batch_size, seq_length = hidden_states.shape[:-1]
|
| 1008 |
hidden_shape = (batch_size, seq_length, -1, self.attention_head_size)
|
| 1009 |
query_BHLD = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
|
|
|
|
| 1030 |
query_BHLD: torch.Tensor,
|
| 1031 |
key_BHLD: torch.Tensor,
|
| 1032 |
value_BHLD: torch.Tensor,
|
| 1033 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1034 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1035 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1036 |
output_attentions: bool = False,
|
| 1037 |
output_s_max: bool = False,
|
| 1038 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1039 |
if output_attentions:
|
| 1040 |
return self._manual_attn(query_BHLD, key_BHLD, value_BHLD, attention_mask_4d, output_s_max)
|
| 1041 |
|
|
|
|
| 1052 |
return attn_output, attn_weights, s_max
|
| 1053 |
|
| 1054 |
@torch.no_grad()
|
| 1055 |
+
def _compute_s_max(self, query_BHLD: torch.Tensor, key_BHLD: torch.Tensor) -> List[torch.Tensor]:
|
| 1056 |
q_norm = torch.linalg.vector_norm(query_BHLD, dim=-1)
|
| 1057 |
k_norm = torch.linalg.vector_norm(key_BHLD, dim=-1)
|
| 1058 |
s_max_bound = (q_norm.max(dim=-1).values * k_norm.max(dim=-1).values).max(dim=0).values
|
|
|
|
| 1063 |
query_BHLD: torch.Tensor,
|
| 1064 |
key_BHLD: torch.Tensor,
|
| 1065 |
value_BHLD: torch.Tensor,
|
| 1066 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1067 |
output_s_max: bool = False,
|
| 1068 |
+
) -> Tuple[torch.Tensor, torch.Tensor, Optional[List[torch.Tensor]]]:
|
| 1069 |
attn_weights = torch.matmul(query_BHLD, key_BHLD.transpose(-1, -2))
|
| 1070 |
if attention_mask_4d is not None:
|
| 1071 |
attn_weights = attn_weights.masked_fill(attention_mask_4d.logical_not(), float("-inf"))
|
|
|
|
| 1082 |
query_BHLD: torch.Tensor,
|
| 1083 |
key_BHLD: torch.Tensor,
|
| 1084 |
value_BHLD: torch.Tensor,
|
| 1085 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1086 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1087 |
query_BLHD = query_BHLD.transpose(1, 2).contiguous()
|
| 1088 |
key_BLHD = key_BHLD.transpose(1, 2).contiguous()
|
| 1089 |
value_BLHD = value_BHLD.transpose(1, 2).contiguous()
|
|
|
|
| 1098 |
query_BHLD: torch.Tensor,
|
| 1099 |
key_BHLD: torch.Tensor,
|
| 1100 |
value_BHLD: torch.Tensor,
|
| 1101 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1102 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1103 |
assert flex_attention is not None, "Flex attention is not available in this environment."
|
| 1104 |
fn = _get_flex_attention_fn()
|
| 1105 |
context_BHLD = fn(query_BHLD, key_BHLD, value_BHLD, block_mask=flex_block_mask, scale=1.0)
|
|
|
|
| 1110 |
query_BHLD: torch.Tensor,
|
| 1111 |
key_BHLD: torch.Tensor,
|
| 1112 |
value_BHLD: torch.Tensor,
|
| 1113 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1114 |
+
) -> Tuple[torch.Tensor, None]:
|
| 1115 |
context_BHLD = F.scaled_dot_product_attention(
|
| 1116 |
query_BHLD, key_BHLD, value_BHLD,
|
| 1117 |
attn_mask=attention_mask_4d,
|
|
|
|
| 1131 |
def forward(
|
| 1132 |
self,
|
| 1133 |
hidden_states: torch.Tensor,
|
| 1134 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1135 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1136 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1137 |
output_attentions: bool = False,
|
| 1138 |
output_s_max: bool = False,
|
| 1139 |
type_ids: Optional[torch.Tensor] = None,
|
| 1140 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1141 |
hidden_states_ln = self.LayerNorm(hidden_states)
|
| 1142 |
attn_output, attn_weights, s_max = self.self(
|
| 1143 |
hidden_states_ln,
|
|
|
|
| 1165 |
def forward(
|
| 1166 |
self,
|
| 1167 |
hidden_states: torch.Tensor,
|
| 1168 |
+
attention_mask_2d: Optional[torch.Tensor] = None,
|
| 1169 |
+
attention_mask_4d: Optional[torch.Tensor] = None,
|
| 1170 |
+
flex_block_mask: Optional[BlockMask] = None,
|
| 1171 |
output_attentions: bool = False,
|
| 1172 |
output_s_max: bool = False,
|
| 1173 |
type_ids: Optional[torch.Tensor] = None,
|
| 1174 |
+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[List[torch.Tensor]]]:
|
| 1175 |
attention_output, attn_weights, s_max = self.attention(
|
| 1176 |
hidden_states,
|
| 1177 |
attention_mask_2d=attention_mask_2d,
|