|
|
from typing import TypedDict |
|
|
from torch import nn |
|
|
|
|
|
|
|
|
class TransformerLayerCFG(TypedDict):
    """Typed config dict for a single transformer layer.

    NOTE(review): key names look like ``nn.TransformerEncoderLayer``
    constructor arguments — confirm against the consumer of this config.
    """

    d_model: int
    nhead: int
    batch_first: bool
    norm_first: bool
    bias: bool
    dim_feedforward: int
    dropout: float
    layer_norm_eps: float

    @classmethod
    def create(cls,
               d_model: int = 768,
               nhead: int = 12,
               batch_first: bool = True,
               norm_first: bool = False,
               bias: bool = True,
               mlp_ratio: float = 4.0,
               dropout: float = 0.0,
               layer_norm_eps: float = 1e-6) -> 'TransformerLayerCFG':
        """Build a config dict from ViT-Base-style defaults.

        ``dim_feedforward`` is not passed directly; it is derived as
        ``int(d_model * mlp_ratio)``.
        """
        return cls(
            d_model=d_model,
            nhead=nhead,
            batch_first=batch_first,
            norm_first=norm_first,
            bias=bias,
            # MLP width is expressed as a ratio of the embedding size.
            dim_feedforward=int(d_model * mlp_ratio),
            dropout=dropout,
            layer_norm_eps=layer_norm_eps,
        )
|
|
|
|
|
|
|
|
|
|
|
class TransformerEncoderCFG(TypedDict):
    """Typed config dict for a stack of transformer layers.

    NOTE(review): key names look like ``nn.TransformerEncoder``
    constructor arguments — confirm against the consumer of this config.
    """

    num_layers: int
    enable_nested_tensor: bool
    mask_check: bool

    @classmethod
    def create(cls,
               num_layers: int = 12,
               enable_nested_tensor: bool = False,
               mask_check: bool = True) -> 'TransformerEncoderCFG':
        """Build an encoder config dict with the given overrides."""
        return cls(
            num_layers=num_layers,
            enable_nested_tensor=enable_nested_tensor,
            mask_check=mask_check,
        )