|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from .activation import Mish |
|
|
from .attention import (Attention, AttentionMaskType, AttentionParams, |
|
|
BertAttention, BlockSparseAttnParams, CogVLMAttention, |
|
|
KeyValueCacheParams, PositionEmbeddingType, |
|
|
SpecDecodingParams) |
|
|
from .cast import Cast |
|
|
from .conv import Conv1d, Conv2d, ConvTranspose2d |
|
|
from .embedding import Embedding, PromptTuningEmbedding |
|
|
from .linear import (ColumnLinear, Linear, ParallelLMHead, QKVColumnLinear, |
|
|
RowLinear) |
|
|
from .lora import Lora, LoraParams, LoraRuntimeParams |
|
|
from .mlp import MLP, FusedGatedMLP, GatedMLP |
|
|
from .moe import MOE, MoeConfig |
|
|
from .normalization import GroupNorm, LayerNorm, RmsNorm |
|
|
from .pooling import AvgPool2d |
|
|
from .recurrent import FusedRgLru, GroupedLinear, Recurrent, RgLru |
|
|
from .ssm import Mamba, Mamba2 |
|
|
|
|
|
# Explicit public surface of the package: these are the names re-exported by
# `from <package> import *`. Every entry maps to one of the submodule imports
# above; keep the two in sync when adding a layer.
__all__ = [
    'LayerNorm', 'RmsNorm',
    'ColumnLinear', 'Linear', 'RowLinear', 'QKVColumnLinear', 'ParallelLMHead',
    'AttentionMaskType', 'PositionEmbeddingType', 'Attention', 'BertAttention',
    'CogVLMAttention',
    'GroupNorm',
    'Embedding', 'PromptTuningEmbedding',
    'Conv2d', 'ConvTranspose2d', 'Conv1d',
    'AvgPool2d',
    'Mish',
    'MLP', 'GatedMLP', 'FusedGatedMLP',
    'Cast',
    'AttentionParams', 'SpecDecodingParams', 'KeyValueCacheParams',
    'BlockSparseAttnParams',
    'Lora', 'LoraParams', 'LoraRuntimeParams',
    'MOE', 'MoeConfig',
    'Mamba', 'Mamba2',
    'Recurrent', 'GroupedLinear', 'RgLru', 'FusedRgLru',
]
|
|
|