Enable Mamba v2 compatibility
Browse files
- modeling_bimamba.py +29 -10
modeling_bimamba.py
CHANGED
|
@@ -2,21 +2,29 @@
|
|
| 2 |
|
| 3 |
"""
|
| 4 |
|
|
|
|
| 5 |
import math
|
| 6 |
from functools import partial
|
| 7 |
from typing import Optional, Tuple, Union
|
| 8 |
|
| 9 |
import torch
|
| 10 |
-
from mamba_ssm.modules.mamba_simple import Mamba
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
from torch import nn
|
| 12 |
from torch.nn import functional as F
|
| 13 |
from transformers import PreTrainedModel
|
| 14 |
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput
|
| 15 |
|
| 16 |
try:
|
| 17 |
-
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
|
| 18 |
except ImportError:
|
| 19 |
-
|
|
|
|
|
|
|
|
|
|
| 20 |
|
| 21 |
from .configuration_bimamba import BiMambaConfig
|
| 22 |
|
|
@@ -52,13 +60,24 @@ def create_block(
|
|
| 52 |
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
| 53 |
)
|
| 54 |
block_cls = Block
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
block.layer_idx = layer_idx
|
| 63 |
return block
|
| 64 |
|
|
|
|
| 2 |
|
| 3 |
"""
|
| 4 |
|
| 5 |
+
import inspect
|
| 6 |
import math
|
| 7 |
from functools import partial
|
| 8 |
from typing import Optional, Tuple, Union
|
| 9 |
|
| 10 |
import torch
|
| 11 |
+
from mamba_ssm.modules.mamba_simple import Mamba
|
| 12 |
+
try:
|
| 13 |
+
from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
|
| 14 |
+
except ImportError:
|
| 15 |
+
from mamba_ssm.modules.block import Block # mambav2 file structure
|
| 16 |
from torch import nn
|
| 17 |
from torch.nn import functional as F
|
| 18 |
from transformers import PreTrainedModel
|
| 19 |
from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput
|
| 20 |
|
| 21 |
try:
|
| 22 |
+
from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
|
| 23 |
except ImportError:
|
| 24 |
+
try:
|
| 25 |
+
from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
|
| 26 |
+
except ImportError:
|
| 27 |
+
RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
|
| 28 |
|
| 29 |
from .configuration_bimamba import BiMambaConfig
|
| 30 |
|
|
|
|
| 60 |
nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
|
| 61 |
)
|
| 62 |
block_cls = Block
|
| 63 |
+
# mambav2 compatibility
|
| 64 |
+
if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
|
| 65 |
+
block = block_cls(
|
| 66 |
+
d_model,
|
| 67 |
+
mixer_cls,
|
| 68 |
+
mlp_cls=nn.Identity,
|
| 69 |
+
norm_cls=norm_cls,
|
| 70 |
+
fused_add_norm=fused_add_norm,
|
| 71 |
+
residual_in_fp32=residual_in_fp32,
|
| 72 |
+
)
|
| 73 |
+
else:
|
| 74 |
+
block = block_cls(
|
| 75 |
+
d_model,
|
| 76 |
+
mixer_cls,
|
| 77 |
+
norm_cls=norm_cls,
|
| 78 |
+
fused_add_norm=fused_add_norm,
|
| 79 |
+
residual_in_fp32=residual_in_fp32,
|
| 80 |
+
)
|
| 81 |
block.layer_idx = layer_idx
|
| 82 |
return block
|
| 83 |
|