| | |
| |
|
| | from fla.modules.convolution import ImplicitLongConvolution, LongConvolution, ShortConvolution |
| | from fla.modules.fused_bitlinear import BitLinear, FusedBitLinear |
| | from fla.modules.fused_cross_entropy import FusedCrossEntropyLoss |
| | from fla.modules.fused_kl_div import FusedKLDivLoss |
| | from fla.modules.fused_linear_cross_entropy import FusedLinearCrossEntropyLoss |
| | from fla.modules.fused_linear_listnet_loss import FusedLinearListNetLoss |
| | from fla.modules.fused_norm_gate import ( |
| | FusedLayerNormGated, |
| | FusedLayerNormSwishGate, |
| | FusedLayerNormSwishGateLinear, |
| | FusedRMSNormGated, |
| | FusedRMSNormSwishGate, |
| | FusedRMSNormSwishGateLinear |
| | ) |
| | from fla.modules.layernorm import GroupNorm, GroupNormLinear, LayerNorm, LayerNormLinear, RMSNorm, RMSNormLinear |
| | from fla.modules.mlp import GatedMLP |
| | from fla.modules.rotary import RotaryEmbedding |
| |
|
# Public API of this package: the names listed here are what
# `from <package> import *` re-exports. Every entry must match a name
# imported at the top of this file, grouped below by originating submodule.
__all__ = [
    # convolution
    'ImplicitLongConvolution', 'LongConvolution', 'ShortConvolution',
    # fused_bitlinear
    'BitLinear', 'FusedBitLinear',
    # fused loss modules
    'FusedCrossEntropyLoss', 'FusedLinearCrossEntropyLoss', 'FusedKLDivLoss',
    'FusedLinearListNetLoss',
    # layernorm
    'GroupNorm', 'GroupNormLinear', 'LayerNorm', 'LayerNormLinear', 'RMSNorm', 'RMSNormLinear',
    # fused_norm_gate
    'FusedLayerNormGated', 'FusedLayerNormSwishGate', 'FusedLayerNormSwishGateLinear',
    'FusedRMSNormGated', 'FusedRMSNormSwishGate', 'FusedRMSNormSwishGateLinear',
    # mlp
    'GatedMLP',
    # rotary
    'RotaryEmbedding',
]
| |
|