yairschiff committed on
Commit
2cf6420
·
verified ·
1 Parent(s): 4431cac

Enable mambav2 compat

Browse files
Files changed (1) hide show
  1. modeling_bimamba.py +29 -10
modeling_bimamba.py CHANGED
@@ -2,21 +2,29 @@
2
 
3
  """
4
 
 
5
  import math
6
  from functools import partial
7
  from typing import Optional, Tuple, Union
8
 
9
  import torch
10
- from mamba_ssm.modules.mamba_simple import Mamba, Block
 
 
 
 
11
  from torch import nn
12
  from torch.nn import functional as F
13
  from transformers import PreTrainedModel
14
  from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput
15
 
16
  try:
17
- from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn
18
  except ImportError:
19
- RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
 
 
 
20
 
21
  from .configuration_bimamba import BiMambaConfig
22
 
@@ -52,13 +60,24 @@ def create_block(
52
  nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
53
  )
54
  block_cls = Block
55
- block = block_cls(
56
- d_model,
57
- mixer_cls,
58
- norm_cls=norm_cls,
59
- fused_add_norm=fused_add_norm,
60
- residual_in_fp32=residual_in_fp32,
61
- )
 
 
 
 
 
 
 
 
 
 
 
62
  block.layer_idx = layer_idx
63
  return block
64
 
 
2
 
3
  """
4
 
5
+ import inspect
6
  import math
7
  from functools import partial
8
  from typing import Optional, Tuple, Union
9
 
10
  import torch
11
+ from mamba_ssm.modules.mamba_simple import Mamba
12
+ try:
13
+ from mamba_ssm.modules.mamba_simple import Block # Legacy mambav1 file structure
14
+ except ImportError:
15
+ from mamba_ssm.modules.block import Block # mambav2 file structure
16
  from torch import nn
17
  from torch.nn import functional as F
18
  from transformers import PreTrainedModel
19
  from transformers.modeling_outputs import BaseModelOutputWithNoAttention, MaskedLMOutput
20
 
21
  try:
22
+ from mamba_ssm.ops.triton.layernorm import RMSNorm, layer_norm_fn, rms_norm_fn # Legacy mambav1 file structure
23
  except ImportError:
24
+ try:
25
+ from mamba_ssm.ops.triton.layer_norm import RMSNorm, layer_norm_fn, rms_norm_fn # mambav2 file structure
26
+ except ImportError:
27
+ RMSNorm, layer_norm_fn, rms_norm_fn = None, None, None
28
 
29
  from .configuration_bimamba import BiMambaConfig
30
 
 
60
  nn.LayerNorm if not rms_norm else RMSNorm, eps=norm_epsilon, **factory_kwargs
61
  )
62
  block_cls = Block
63
+ # mambav2 compatibility
64
+ if "mlp_cls" in inspect.signature(block_cls.__init__).parameters:
65
+ block = block_cls(
66
+ d_model,
67
+ mixer_cls,
68
+ mlp_cls=nn.Identity,
69
+ norm_cls=norm_cls,
70
+ fused_add_norm=fused_add_norm,
71
+ residual_in_fp32=residual_in_fp32,
72
+ )
73
+ else:
74
+ block = block_cls(
75
+ d_model,
76
+ mixer_cls,
77
+ norm_cls=norm_cls,
78
+ fused_add_norm=fused_add_norm,
79
+ residual_in_fp32=residual_in_fp32,
80
+ )
81
  block.layer_idx = layer_idx
82
  return block
83