# --- msd_utils.py ---
import torch
import torch.nn as nn

from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock

# Try to import Mamba; provide a placeholder if it fails (it should normally be installed).
try:
    from mamba_ssm.modules.mamba_simple import Mamba
except ImportError:
    print("Warning (msd_utils.py): Failed to import Mamba. Using placeholder.")

    class Mamba(nn.Module):
        """Stand-in that applies a single Linear layer so the rest of the code can run."""

        def __init__(self, *args, **kwargs):
            super().__init__()
            # Accept d_model positionally or as a keyword, matching the real Mamba signature.
            d_model = kwargs.get("d_model", args[0] if args else None)
            if d_model is None:
                raise ValueError("d_model must be provided")
            self.dummy_layer = nn.Linear(d_model, d_model)

        def forward(self, x, **kwargs):
            return self.dummy_layer(x)


class MambaSequentialBlock(nn.Module):
    """Mamba block designed to replace self-attention within BasicTransformerBlock.

    Note: BasicTransformerBlock already applies norm1 before calling attn1 and adds
    the residual itself, so the LayerNorm and residual here stack on top of those.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, **kwargs):
        super().__init__()
        self.d_model = d_model
        self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )

    def forward(self, x, **kwargs):
        # Expects (B, L, D); extra kwargs passed by BasicTransformerBlock
        # (encoder_hidden_states, attention_mask, ...) are accepted and ignored.
        B, L, D = x.shape
        if D != self.d_model:
            raise ValueError(f"Input dim {D} != Mamba d_model {self.d_model}")
        residual = x
        y = self.mamba(self.norm(x))
        return residual + y


def replace_unet_self_attention_with_mamba(unet: UNet2DConditionModel, mamba_block_kwargs: dict):
    """Swap attn1 (self-attention) in every BasicTransformerBlock for a MambaSequentialBlock.

    Cross-attention (attn2) is left untouched.
    """
    mamba_block_count = 0
    sa_block_count = 0
    for name, module in unet.named_modules():
        if isinstance(module, BasicTransformerBlock):
            if getattr(module, "attn1", None) is not None and not isinstance(module.attn1, MambaSequentialBlock):
                sa_block_count += 1
                try:
                    # norm1 is a LayerNorm over this block's hidden dim.
                    dim = module.norm1.normalized_shape[0]
                    mamba_block = MambaSequentialBlock(d_model=dim, **mamba_block_kwargs)
                    # Match the device/dtype of the attention block being replaced.
                    param = next(module.attn1.parameters(), None)
                    if param is not None:
                        mamba_block = mamba_block.to(device=param.device, dtype=param.dtype)
                    module.attn1 = mamba_block
                    mamba_block_count += 1
                except Exception as e:
                    print(f"Warning: Failed to replace attn1 in {name}: {e}")
            # Sanity check: cross-attention should never have been replaced.
            if isinstance(getattr(module, "attn2", None), MambaSequentialBlock):
                print(f"ERROR: Cross-Attention attn2 in {name} was replaced by Mamba!")
    print(f"Replaced {mamba_block_count}/{sa_block_count} Self-Attention blocks with MambaSequentialBlock.")
    return unet
# --- End of msd_utils.py ---
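
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal smoke test: build a small randomly-initialized UNet2DConditionModel,
# swap its self-attention for Mamba blocks, and run one forward pass. The tiny
# config below (block sizes, cross_attention_dim=32, 8 text tokens) is an
# assumption chosen for a fast test, not a configuration taken from the source.
# With the Mamba placeholder active this runs on CPU through the dummy Linear;
# with the real mamba_ssm installed, a CUDA device is typically required.
if __name__ == "__main__":
    unet = UNet2DConditionModel(
        sample_size=32,
        in_channels=4,
        out_channels=4,
        layers_per_block=1,
        block_out_channels=(32, 64),
        down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
        up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
        cross_attention_dim=32,
        attention_head_dim=8,
    )
    unet = replace_unet_self_attention_with_mamba(
        unet, mamba_block_kwargs={"d_state": 16, "d_conv": 4, "expand": 2}
    )
    sample = torch.randn(1, 4, 32, 32)    # latent input (B, C, H, W)
    timestep = torch.tensor([10])         # a single diffusion timestep
    text_emb = torch.randn(1, 8, 32)      # (B, tokens, cross_attention_dim)
    with torch.no_grad():
        out = unet(sample, timestep, encoder_hidden_states=text_emb).sample
    print("Output shape:", out.shape)     # expected: (1, 4, 32, 32)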