File size: 2,923 Bytes
5e7715d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# --- msd_utils.py ---
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat
# Try to import Mamba; fall back to a lightweight placeholder so this module
# stays importable on machines without mamba_ssm (e.g. CPU-only dev boxes).
try:
    from mamba_ssm.modules.mamba_simple import Mamba
except ImportError:
    print("Warning (msd_utils.py): Failed to import Mamba. Using placeholder.")

    class Mamba(nn.Module):
        """Drop-in stand-in for mamba_ssm's Mamba.

        Mirrors the real constructor closely enough for this file: accepts
        ``d_model`` either positionally (the real signature is
        ``Mamba(d_model, ...)``) or as a keyword, ignores the SSM-specific
        options (d_state, d_conv, expand, ...), and maps a (B, L, D) input
        to a (B, L, D) output through a single Linear layer.

        Raises:
            ValueError: if ``d_model`` is not supplied.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()
            # The real Mamba takes d_model as its first positional argument;
            # also honor the keyword form used elsewhere in this file.
            d_model = args[0] if args else kwargs.get('d_model')
            if d_model is None:
                raise ValueError("d_model must be provided")
            self.dummy_layer = nn.Linear(d_model, d_model)

        def forward(self, x, **kwargs):
            # Shape-preserving stand-in: (B, L, D) -> (B, L, D).
            return self.dummy_layer(x)
# Mamba Block designed to replace Self-Attention within BasicTransformerBlock
class MambaSequentialBlock(nn.Module):
    """Pre-norm residual Mamba sub-layer: ``x + Mamba(LayerNorm(x))``.

    Operates on sequences shaped (B, L, D) and is intended as a drop-in
    substitute for the self-attention (attn1) sub-layer of a
    BasicTransformerBlock.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, **kwargs):
        super().__init__()
        self.d_model = d_model
        self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(
            d_model=d_model,
            d_state=d_state,
            d_conv=d_conv,
            expand=expand,
        )

    def forward(self, x, **kwargs):
        """Apply the residual Mamba sub-layer to a (B, L, D) tensor."""
        # Unpacking enforces a 3-D input; mismatched feature width is an error.
        batch, seq_len, feat = x.shape
        if feat != self.d_model:
            raise ValueError(f"Input dim {feat} != Mamba d_model {self.d_model}")
        # Pre-norm residual connection.
        return x + self.mamba(self.norm(x))
# Function to replace SA in UNet
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock
def replace_unet_self_attention_with_mamba(unet: UNet2DConditionModel, mamba_block_kwargs: dict):
    """Swap every self-attention (attn1) inside the UNet's BasicTransformerBlocks
    for a MambaSequentialBlock, leaving cross-attention (attn2) untouched.

    Args:
        unet: the diffusers UNet to mutate in place.
        mamba_block_kwargs: extra kwargs forwarded to MambaSequentialBlock
            (d_state, d_conv, expand, ...).

    Returns:
        The same UNet instance, mutated in place.
    """
    replaced = 0
    candidates = 0
    for name, module in unet.named_modules():
        if not isinstance(module, BasicTransformerBlock):
            continue
        attn1 = getattr(module, 'attn1', None)
        if attn1 is not None and not isinstance(attn1, MambaSequentialBlock):
            candidates += 1
            try:
                # The block's hidden size is the LayerNorm's normalized dim.
                hidden_dim = module.norm1.normalized_shape[0]
                module.attn1 = MambaSequentialBlock(d_model=hidden_dim, **mamba_block_kwargs)
                replaced += 1
            except Exception as e:
                print(f"Warning: Failed to replace attn1 in {name}: {e}")
        # Check if CA was accidentally replaced (shouldn't happen)
        if isinstance(getattr(module, 'attn2', None), MambaSequentialBlock):
            print(f"ERROR: Cross-Attention attn2 in {name} was replaced by Mamba!")
    print(f"Replaced {replaced}/{candidates} Self-Attention blocks with MambaSequentialBlock.")
    return unet
# --- End of msd_utils.py --- |