# Source revision: fefd7ae
from typing import TypedDict
from torch import nn
class TransformerLayerCFG(TypedDict):
    """Keyword-argument bundle for ``nn.TransformerEncoderLayer``.

    Keys mirror the constructor parameters of
    ``torch.nn.TransformerEncoderLayer`` so a config can be splatted
    directly: ``nn.TransformerEncoderLayer(**cfg)``.
    """

    d_model: int          # embedding dimension
    nhead: int            # number of attention heads
    batch_first: bool     # True -> inputs are (batch, seq, feature)
    norm_first: bool      # True -> pre-norm, False -> post-norm
    bias: bool            # bias terms in the layer's sub-modules
    dim_feedforward: int  # hidden width of the feed-forward block
    dropout: float
    layer_norm_eps: float

    @classmethod
    def create(cls,
               d_model: int = 768,
               nhead: int = 12,
               batch_first: bool = True,
               norm_first: bool = False,
               bias: bool = True,
               mlp_ratio: float = 4.0,
               dropout: float = 0.0,
               layer_norm_eps: float = 1e-6) -> 'TransformerLayerCFG':
        """Build a config with ViT/BERT-style defaults.

        ``dim_feedforward`` is derived rather than passed directly:
        ``int(d_model * mlp_ratio)`` (the conventional 4x MLP expansion
        by default). All other arguments are stored as given.
        """
        return cls(d_model=d_model,
                   nhead=nhead,
                   batch_first=batch_first,
                   norm_first=norm_first,
                   bias=bias,
                   dim_feedforward=int(d_model * mlp_ratio),
                   dropout=dropout,
                   layer_norm_eps=layer_norm_eps)
# NOTE: the final ``norm`` module of ``nn.TransformerEncoder`` is NOT part
# of this config — it needs to be defined by the user at construction time.
class TransformerEncoderCFG(TypedDict):
    """Keyword-argument bundle for ``torch.nn.TransformerEncoder``.

    Keys mirror the (non-module) constructor parameters of
    ``torch.nn.TransformerEncoder`` so a config can be splatted directly:
    ``nn.TransformerEncoder(layer, norm=my_norm, **cfg)``.
    """

    num_layers: int             # number of stacked encoder layers
    enable_nested_tensor: bool  # allow the nested-tensor fast path
    mask_check: bool            # validate mask shapes at call time

    @classmethod
    def create(cls,
               num_layers: int = 12,
               enable_nested_tensor: bool = False,
               mask_check: bool = True) -> 'TransformerEncoderCFG':
        """Build a config with ViT-Base-style defaults (12 layers)."""
        return cls(num_layers=num_layers,
                   enable_nested_tensor=enable_nested_tensor,
                   mask_check=mask_check)