| """k-diffusion transformer diffusion models, version 1.""" | |
| import math | |
| from einops import rearrange | |
| import torch | |
| from torch import nn | |
| import torch._dynamo | |
| from torch.nn import functional as F | |
| from . import flags | |
| from .. import layers | |
| from .axial_rope import AxialRoPE, make_axial_pos | |
# When torch.compile is enabled, let Dynamo fall back to eager mode on graph
# breaks/unsupported ops instead of raising.
if flags.get_use_compile():
    torch._dynamo.config.suppress_errors = True
def zero_init(layer):
    """Zero the layer's weight, and its bias when one exists, in place.

    Returns the same layer so the call can be chained inline at construction.
    """
    targets = [layer.weight]
    if layer.bias is not None:
        targets.append(layer.bias)
    for tensor in targets:
        nn.init.zeros_(tensor)
    return layer
def checkpoint_helper(function, *args, **kwargs):
    """Invoke *function*, wrapping it in gradient checkpointing when the
    project-wide checkpointing flag is set; otherwise call it directly."""
    if not flags.get_checkpointing():
        return function(*args, **kwargs)
    # Preserve the historical reentrant behavior unless the caller overrides it.
    kwargs.setdefault("use_reentrant", True)
    return torch.utils.checkpoint.checkpoint(function, *args, **kwargs)
def tag_param(param, tag):
    """Record *tag* in the parameter's ``_tags`` set, creating the set on the
    first tag. Returns the parameter for chaining."""
    if hasattr(param, "_tags"):
        param._tags.add(tag)
    else:
        param._tags = {tag}
    return param
def tag_module(module, tag):
    """Apply *tag* to every parameter of *module*; returns the module."""
    for p in module.parameters():
        tag_param(p, tag)
    return module
def apply_wd(module):
    """Mark every parameter whose name ends in "weight" with the "wd" tag so
    the optimizer applies weight decay to it. Returns the module."""
    for name, p in module.named_parameters():
        if name.endswith("weight"):
            tag_param(p, "wd")
    return module
def filter_params(function, module):
    """Yield the parameters of *module* whose tag set satisfies *function*.

    Untagged parameters are treated as having an empty tag set.
    """
    for p in module.parameters():
        if function(getattr(p, "_tags", set())):
            yield p
def scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0):
    """Compute attention, preferring FlashAttention-2 when enabled and usable.

    Inputs are in (n, h, l, e) layout; flash-attn expects (n, l, h, e), so the
    head and sequence axes are swapped on the way in and back on the way out.
    Any import or runtime failure of flash-attn falls back silently to
    ``F.scaled_dot_product_attention``. Masks are not supported by the flash
    path, so a non-None ``attn_mask`` always uses the fallback.
    """
    if flags.get_use_flash_attention_2() and attn_mask is None:
        try:
            from flash_attn import flash_attn_func
            out = flash_attn_func(
                q.transpose(-3, -2),
                k.transpose(-3, -2),
                v.transpose(-3, -2),
                dropout_p=dropout_p,
            )
            return out.transpose(-3, -2)
        except (ImportError, RuntimeError):
            pass
    return F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=dropout_p)
def geglu(x):
    """GEGLU activation: split the last dimension in half and multiply the
    first half by GELU of the second half, halving the feature dimension."""
    value, gate = x.chunk(2, dim=-1)
    return value * F.gelu(gate)
def rms_norm(x, scale, eps):
    """Scaled RMS normalization over the last dimension.

    The mean square is accumulated in at least float32 for numerical
    stability; the combined scale is cast back to x's dtype before applying.
    """
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    mean_square = x.to(compute_dtype).pow(2).mean(dim=-1, keepdim=True)
    total_scale = scale.to(compute_dtype) * torch.rsqrt(mean_square + eps)
    return x * total_scale.to(x.dtype)
class GEGLU(nn.Module):
    """Module wrapper around the :func:`geglu` activation (halves the last dim)."""

    def forward(self, x):
        return geglu(x)
class RMSNorm(nn.Module):
    """RMS normalization layer with a learned per-feature scale.

    Args:
        param_shape: shape of the learned scale parameter (initialized to ones).
        eps: small constant added to the mean square for stability.
    """

    def __init__(self, param_shape, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(param_shape))

    def extra_repr(self):
        return f"shape={tuple(self.scale.shape)}, eps={self.eps}"

    def forward(self, x):
        return rms_norm(x, self.scale, self.eps)
class QKNorm(nn.Module):
    """Query/key RMS normalization with a learned per-head log-scale.

    The per-head scale is stored in log space, initialized to log(10), and
    clamped from above at log(max_scale) so attention logits cannot grow
    without bound.

    Args:
        n_heads: number of attention heads (one scale per head).
        eps: stability constant for the RMS norm.
        max_scale: upper bound on the (non-log) scale.
    """

    def __init__(self, n_heads, eps=1e-6, max_scale=100.0):
        super().__init__()
        self.eps = eps
        self.max_scale = math.log(max_scale)
        self.scale = nn.Parameter(torch.full((n_heads,), math.log(10.0)))
        self.proj_()

    def extra_repr(self):
        return f"n_heads={self.scale.shape[0]}, eps={self.eps}"

    # BUG FIX: the clamp must run under no_grad — an in-place op on a leaf
    # parameter that requires grad raises a RuntimeError, so without the
    # decorator even constructing this module crashed.
    @torch.no_grad()
    def proj_(self):
        """Modify the scale in-place so it doesn't get "stuck" with zero gradient if it's clamped
        to the max value."""
        self.scale.clamp_(max=self.max_scale)

    def forward(self, x):
        self.proj_()
        # exp(scale / 2) / x.shape[-1] ** 0.25, computed in log space; the
        # extra head/sequence axes are broadcast via [:, None, None].
        scale = torch.exp(0.5 * self.scale - 0.25 * math.log(x.shape[-1]))
        return rms_norm(x, scale[:, None, None], self.eps)
class AdaRMSNorm(nn.Module):
    """Adaptive RMSNorm whose per-feature scale is predicted from a
    conditioning vector.

    The conditioning projection is zero-initialized and offset by +1 in
    forward, so the layer starts out as a plain unit-scale RMS norm. Its
    parameters are tagged "mapping" for the mapping-network optimizer group.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.linear = apply_wd(zero_init(nn.Linear(cond_features, features, bias=False)))
        tag_module(self.linear, "mapping")

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        # +1 makes the zero-initialized projection an identity scale at init.
        return rms_norm(x, self.linear(cond) + 1, self.eps)
class SelfAttentionBlock(nn.Module):
    """Pre-norm residual self-attention block with QK-norm and axial RoPE.

    Args:
        d_model: model (token feature) width; must be divisible by d_head.
        d_head: per-head feature width.
        dropout: dropout probability applied to the attention output.
    """

    def __init__(self, d_model, d_head, dropout=0.0):
        super().__init__()
        self.d_head = d_head
        self.n_heads = d_model // d_head
        self.norm = AdaRMSNorm(d_model, d_model)
        self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False))
        self.qk_norm = QKNorm(self.n_heads)
        self.pos_emb = AxialRoPE(d_head, self.n_heads)
        self.dropout = nn.Dropout(dropout)
        # Zero-initialized so the block is an identity mapping at init.
        self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False)))

    def extra_repr(self):
        return f"d_head={self.d_head},"

    def forward(self, x, pos, attn_mask, cond):
        skip = x
        x = self.norm(x, cond)
        qkv = self.qkv_proj(x)
        q, k, v = (
            rearrange(t, "n l (h e) -> n h l e", e=self.d_head)
            for t in qkv.chunk(3, dim=-1)
        )
        # Normalize then rotate queries and keys; values stay untouched.
        q = self.pos_emb(self.qk_norm(q), pos)
        k = self.pos_emb(self.qk_norm(k), pos)
        x = scaled_dot_product_attention(q, k, v, attn_mask)
        x = rearrange(x, "n h l e -> n l (h e)")
        x = self.out_proj(self.dropout(x))
        return x + skip
class FeedForwardBlock(nn.Module):
    """Pre-norm residual GEGLU feed-forward block conditioned via AdaRMSNorm."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.norm = AdaRMSNorm(d_model, d_model)
        # up_proj emits 2 * d_ff features because GEGLU halves the last dim.
        self.up_proj = apply_wd(nn.Linear(d_model, d_ff * 2, bias=False))
        self.act = GEGLU()
        self.dropout = nn.Dropout(dropout)
        # Zero-initialized so the block is an identity mapping at init.
        self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False)))

    def forward(self, x, cond):
        skip = x
        x = self.norm(x, cond)
        x = self.down_proj(self.dropout(self.act(self.up_proj(x))))
        return x + skip
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention then feed-forward, with each
    sub-block optionally gradient-checkpointed via checkpoint_helper."""

    def __init__(self, d_model, d_ff, d_head, dropout=0.0):
        super().__init__()
        self.self_attn = SelfAttentionBlock(d_model, d_head, dropout=dropout)
        self.ff = FeedForwardBlock(d_model, d_ff, dropout=dropout)

    def forward(self, x, pos, attn_mask, cond):
        x = checkpoint_helper(self.self_attn, x, pos, attn_mask, cond)
        return checkpoint_helper(self.ff, x, cond)
class Patching(nn.Module):
    """Fold non-overlapping pixel patches into tokens and compute their
    axial positions.

    Each output token carries ``features * patch_size[0] * patch_size[1]``
    values (exposed as ``d_out``).
    """

    def __init__(self, features, patch_size):
        super().__init__()
        self.features = features
        self.patch_size = patch_size
        self.d_out = features * patch_size[0] * patch_size[1]

    def extra_repr(self):
        return f"features={self.features}, patch_size={self.patch_size!r}"

    def forward(self, x, pixel_aspect_ratio=1.0):
        *_, h, w = x.shape
        if h % self.patch_size[0] != 0 or w % self.patch_size[1] != 0:
            raise ValueError(f"Image size {h}x{w} is not divisible by patch size {self.patch_size[0]}x{self.patch_size[1]}")
        h_out = h // self.patch_size[0]
        w_out = w // self.patch_size[1]
        x = rearrange(x, "... c (h i) (w j) -> ... (h w) (c i j)", i=self.patch_size[0], j=self.patch_size[1])
        # Fold the (possibly non-square) patch shape into the aspect ratio
        # used for the positional grid.
        aspect = pixel_aspect_ratio * self.patch_size[0] / self.patch_size[1]
        pos = make_axial_pos(h_out, w_out, aspect, device=x.device)
        return x, pos
class Unpatching(nn.Module):
    """Inverse of Patching: unfold token features back into an image of the
    given height and width.

    Each input token must carry ``features * patch_size[0] * patch_size[1]``
    values (exposed as ``d_in``).
    """

    def __init__(self, features, patch_size):
        super().__init__()
        self.features = features
        self.patch_size = patch_size
        self.d_in = features * patch_size[0] * patch_size[1]

    def extra_repr(self):
        return f"features={self.features}, patch_size={self.patch_size!r}"

    def forward(self, x, h, w):
        rows = h // self.patch_size[0]
        cols = w // self.patch_size[1]
        return rearrange(x, "... (h w) (c i j) -> ... c (h i) (w j)", h=rows, w=cols, i=self.patch_size[0], j=self.patch_size[1])
class MappingFeedForwardBlock(nn.Module):
    """Residual GEGLU feed-forward block for the unconditional mapping network
    (plain RMSNorm instead of the adaptive variant)."""

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.up_proj = apply_wd(nn.Linear(d_model, d_ff * 2, bias=False))
        self.act = GEGLU()
        self.dropout = nn.Dropout(dropout)
        # Zero-initialized so the block is an identity mapping at init.
        self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False)))

    def forward(self, x):
        skip = x
        x = self.down_proj(self.dropout(self.act(self.up_proj(self.norm(x)))))
        return x + skip
class MappingNetwork(nn.Module):
    """Stack of mapping feed-forward blocks bracketed by RMS norms; produces
    the conditioning vector consumed by the AdaRMSNorm layers."""

    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.in_norm = RMSNorm(d_model)
        self.blocks = nn.ModuleList(
            [MappingFeedForwardBlock(d_model, d_ff, dropout=dropout) for _ in range(n_layers)]
        )
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        x = self.in_norm(x)
        for block in self.blocks:
            x = block(x)
        return self.out_norm(x)
class ImageTransformerDenoiserModelV1(nn.Module):
    """Image denoising transformer, version 1.

    Patches the input image into tokens, conditions every block on a mapping
    vector built from the noise level, augmentation vector, and optional class
    label, runs a stack of transformer blocks, and unpatches back to an image.
    """

    def __init__(self, n_layers, d_model, d_ff, in_features, out_features, patch_size, num_classes=0, dropout=0.0, sigma_data=1.0):
        super().__init__()
        self.sigma_data = sigma_data
        self.num_classes = num_classes
        self.patch_in = Patching(in_features, patch_size)
        self.patch_out = Unpatching(out_features, patch_size)
        # Noise-level embedding: Fourier features of c_noise, then a linear map.
        self.time_emb = layers.FourierFeatures(1, d_model)
        self.time_in_proj = nn.Linear(d_model, d_model, bias=False)
        # Augmentation conditioning is a 9-dimensional vector (see forward).
        self.aug_emb = layers.FourierFeatures(9, d_model)
        self.aug_in_proj = nn.Linear(d_model, d_model, bias=False)
        # Class embedding only exists when class conditioning is requested.
        self.class_emb = nn.Embedding(num_classes, d_model) if num_classes else None
        self.mapping = tag_module(MappingNetwork(2, d_model, d_ff, dropout=dropout), "mapping")
        self.in_proj = nn.Linear(self.patch_in.d_out, d_model, bias=False)
        # Head dim is fixed at 64; n_heads = d_model // 64 inside each block.
        self.blocks = nn.ModuleList([TransformerBlock(d_model, d_ff, 64, dropout=dropout) for _ in range(n_layers)])
        self.out_norm = RMSNorm(d_model)
        # Zero-initialized output projection: the model predicts zero at init.
        self.out_proj = zero_init(nn.Linear(d_model, self.patch_out.d_in, bias=False))

    def proj_(self):
        """Re-clamp every block's QK-norm scales in place (see QKNorm.proj_)."""
        for block in self.blocks:
            block.self_attn.qk_norm.proj_()

    def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3):
        """Build optimizer parameter groups from the "wd"/"mapping" tags.

        Mapping-network parameters get a scaled-down learning rate; parameters
        not tagged "wd" get weight_decay=0.
        """
        wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self)
        no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self)
        mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self)
        mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self)
        groups = [
            {"params": list(wd), "lr": base_lr},
            {"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0},
            {"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale},
            {"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0}
        ]
        return groups

    def forward(self, x, sigma, aug_cond=None, class_cond=None):
        # Patching
        *_, h, w = x.shape
        x, pos = self.patch_in(x)
        # TODO: pixel aspect ratio for nonsquare patches
        attn_mask = None
        x = self.in_proj(x)
        # Mapping network
        if class_cond is None and self.class_emb is not None:
            raise ValueError("class_cond must be specified if num_classes > 0")
        # log-sigma noise conditioning, scaled by 1/4 (EDM-style c_noise).
        c_noise = torch.log(sigma) / 4
        time_emb = self.time_in_proj(self.time_emb(c_noise[..., None]))
        # Missing augmentation conditioning defaults to the zero vector.
        aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond
        aug_emb = self.aug_in_proj(self.aug_emb(aug_cond))
        class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0
        # unsqueeze(-2) broadcasts the conditioning across the sequence axis.
        cond = self.mapping(time_emb + aug_emb + class_emb).unsqueeze(-2)
        # Transformer
        for block in self.blocks:
            x = block(x, pos, attn_mask, cond)
        # Unpatching
        x = self.out_norm(x)
        x = self.out_proj(x)
        x = self.patch_out(x, h, w)
        return x