File size: 2,923 Bytes
5e7715d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# --- msd_utils.py ---
import torch
import torch.nn as nn
from torch.nn import functional as F
from einops import rearrange, repeat

# Try to import Mamba, provide a placeholder if it fails (though it should be installed)
# Try to import Mamba, provide a placeholder if it fails (though it should be installed)
try:
    from mamba_ssm.modules.mamba_simple import Mamba
except ImportError:
    print("Warning (msd_utils.py): Failed to import Mamba. Using placeholder.")

    class Mamba(nn.Module):
        """Shape-preserving stand-in for mamba_ssm's Mamba (a single Linear).

        The real Mamba takes ``d_model`` as its first positional argument, so
        this placeholder accepts it either positionally or as a keyword; all
        other arguments are accepted and ignored.

        Raises:
            ValueError: if ``d_model`` is not supplied at all.
        """

        def __init__(self, *args, **kwargs):
            super().__init__()
            # Accept d_model positionally (matching the real Mamba signature)
            # or as a keyword; keyword wins if both are given.
            d_model = kwargs.get('d_model', args[0] if args else None)
            if d_model is None:
                raise ValueError("d_model must be provided")
            self.dummy_layer = nn.Linear(d_model, d_model)

        def forward(self, x, **kwargs):
            # Pass-through projection: (..., d_model) -> (..., d_model).
            return self.dummy_layer(x)

# Mamba Block designed to replace Self-Attention within BasicTransformerBlock
class MambaSequentialBlock(nn.Module):
    """Pre-norm residual Mamba sublayer: ``x + Mamba(LayerNorm(x))``.

    Operates on sequences shaped (batch, seq_len, d_model) and preserves
    that shape, making it a drop-in substitute for a self-attention sublayer.
    """

    def __init__(self, d_model, d_state=16, d_conv=4, expand=2, **kwargs):
        super().__init__()
        self.d_model = d_model
        self.norm = nn.LayerNorm(d_model)
        self.mamba = Mamba(d_model=d_model, d_state=d_state, d_conv=d_conv, expand=expand)

    def forward(self, x, **kwargs):
        """Apply normalized Mamba with a residual connection.

        Args:
            x: tensor of shape (batch, seq_len, d_model).

        Raises:
            ValueError: if the last dimension does not match ``d_model``.
        """
        # Unpacking enforces a 3-D input before the dimension check.
        batch, seq_len, dim = x.shape
        if dim != self.d_model:
            raise ValueError(f"Input dim {dim} != Mamba d_model {self.d_model}")
        return x + self.mamba(self.norm(x))

# Function to replace SA in UNet
from diffusers import UNet2DConditionModel
from diffusers.models.attention import BasicTransformerBlock

def replace_unet_self_attention_with_mamba(unet: UNet2DConditionModel, mamba_block_kwargs: dict):
    """Swap each BasicTransformerBlock's self-attention (attn1) for a MambaSequentialBlock.

    Cross-attention (``attn2``) is intentionally left untouched; a sanity
    check reports if it ever ends up holding a Mamba block. The UNet is
    mutated in place and also returned.

    Args:
        unet: the diffusers UNet to modify in place.
        mamba_block_kwargs: extra keyword args forwarded to MambaSequentialBlock
            (``d_model`` is derived from each block's norm1).
    """
    replaced = 0
    candidates = 0
    for name, module in unet.named_modules():
        if not isinstance(module, BasicTransformerBlock):
            continue
        attn1 = getattr(module, 'attn1', None)
        if attn1 is None:
            continue
        if not isinstance(attn1, MambaSequentialBlock):
            candidates += 1
            try:
                # norm1 feeds attn1, so its normalized feature size is the attention width.
                dim = module.norm1.normalized_shape[0]
                module.attn1 = MambaSequentialBlock(d_model=dim, **mamba_block_kwargs)
                replaced += 1
            except Exception as e:
                print(f"Warning: Failed to replace attn1 in {name}: {e}")
        # Check if CA was accidentally replaced (shouldn't happen)
        if hasattr(module, 'attn2') and isinstance(module.attn2, MambaSequentialBlock):
            print(f"ERROR: Cross-Attention attn2 in {name} was replaced by Mamba!")
    print(f"Replaced {replaced}/{candidates} Self-Attention blocks with MambaSequentialBlock.")
    return unet
# --- End of msd_utils.py ---