|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from torch.nn import functional as F |
|
|
from einops import rearrange, repeat |
|
|
|
|
|
|
|
|
# Prefer the real selective-state-space implementation; fall back to a
# shape-preserving Linear placeholder so this module stays importable
# (e.g. on machines without mamba_ssm / CUDA).
try:


    from mamba_ssm.modules.mamba_simple import Mamba


except ImportError:


    print("Warning (msd_utils.py): Failed to import Mamba. Using placeholder.")


    class Mamba(nn.Module):
        """Drop-in stand-in for mamba_ssm's Mamba.

        A single ``nn.Linear(d_model, d_model)``: shape-preserving over
        (B, L, d_model) inputs, like the real layer. Extra constructor
        arguments (d_state, d_conv, expand, ...) are accepted and ignored.
        """


        def __init__(self, d_model=None, *args, **kwargs):


            super().__init__()


            # The real Mamba takes d_model as its first positional argument;
            # accept it either positionally or as a keyword so both call
            # styles work with the placeholder.
            if d_model is None:
                d_model = kwargs.get('d_model')


            if d_model is None: raise ValueError("d_model must be provided")


            self.dummy_layer = nn.Linear(d_model, d_model)


        def forward(self, x, **kwargs):
            # (B, L, d_model) -> (B, L, d_model); extra kwargs ignored.
            return self.dummy_layer(x)
|
|
|
|
|
|
|
|
class MambaSequentialBlock(nn.Module):
    """Pre-norm residual Mamba block.

    Computes ``out = x + Mamba(LayerNorm(x))`` for inputs of shape
    (batch, seq_len, d_model). d_state/d_conv/expand are forwarded to the
    underlying Mamba layer; extra kwargs are accepted and ignored.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, **kwargs):
        super().__init__()
        self.d_model = d_model
        # Pre-normalization before the sequence mixer.
        self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )

    def forward(self, x, **kwargs):
        # Unpacking also enforces a rank-3 input (B, L, D).
        batch, seq_len, D = x.shape
        if D != self.d_model:
            raise ValueError(f"Input dim {D} != Mamba d_model {self.d_model}")

        # Pre-norm residual: normalize, mix along the sequence, add skip.
        return x + self.mamba(self.norm(x))
|
|
|
|
|
|
|
|
from diffusers import UNet2DConditionModel |
|
|
from diffusers.models.attention import BasicTransformerBlock |
|
|
|
|
|
def replace_unet_self_attention_with_mamba(unet: UNet2DConditionModel, mamba_block_kwargs: dict):
    """Swap each self-attention (attn1) in the UNet's BasicTransformerBlocks
    for a MambaSequentialBlock, in place.

    Cross-attention (attn2) is never touched; a diagnostic is printed if one
    is found to already be a MambaSequentialBlock. The Mamba width is taken
    from each block's norm1.normalized_shape. Returns the (mutated) unet.
    """
    replaced = 0
    candidates = 0

    for name, module in unet.named_modules():
        if not isinstance(module, BasicTransformerBlock):
            continue

        attn1 = getattr(module, 'attn1', None)
        if attn1 is not None and not isinstance(attn1, MambaSequentialBlock):
            candidates += 1
            try:
                # norm1 is the LayerNorm feeding attn1, so its width is the
                # hidden dim the replacement block must match.
                dim = module.norm1.normalized_shape[0]
                module.attn1 = MambaSequentialBlock(d_model=dim, **mamba_block_kwargs)
                replaced += 1
            except Exception as e:
                # Best-effort: leave this block's attention intact and move on.
                print(f"Warning: Failed to replace attn1 in {name}: {e}")

        # Sanity check: cross-attention must remain real attention.
        if hasattr(module, 'attn2') and isinstance(module.attn2, MambaSequentialBlock):
            print(f"ERROR: Cross-Attention attn2 in {name} was replaced by Mamba!")

    print(f"Replaced {replaced}/{candidates} Self-Attention blocks with MambaSequentialBlock.")
    return unet
|
|
|