Spaces:
Runtime error
Runtime error
File size: 11,930 Bytes
c19ca42 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 |
"""k-diffusion transformer diffusion models, version 1."""
import math
from einops import rearrange
import torch
from torch import nn
import torch._dynamo
from torch.nn import functional as F
from . import flags
from .. import layers
from .axial_rope import AxialRoPE, make_axial_pos
if flags.get_use_compile():
    # Only touched when compilation is enabled via flags.get_use_compile().
    # NOTE(review): presumably makes TorchDynamo fall back to eager execution instead of
    # raising when it fails to compile a function (e.g. the @flags.compile_wrap helpers
    # below) — confirm against torch._dynamo.config for the pinned PyTorch version.
    torch._dynamo.config.suppress_errors = True
def zero_init(layer):
    """Zero ``layer.weight`` (and ``layer.bias`` when present) in place.

    Returns the same ``layer`` object so the call can wrap a constructor inline,
    e.g. ``zero_init(nn.Linear(a, b))``.
    """
    nn.init.zeros_(layer.weight)
    bias = layer.bias
    if bias is None:
        return layer
    nn.init.zeros_(bias)
    return layer
def checkpoint_helper(function, *args, **kwargs):
    """Call ``function(*args, **kwargs)``, using activation checkpointing when enabled.

    When ``flags.get_checkpointing()`` is truthy the call is routed through
    ``torch.utils.checkpoint.checkpoint``. ``use_reentrant`` defaults to ``True``
    unless the caller passes it explicitly. Otherwise ``function`` is called directly
    and its result returned.

    Args:
        function: the callable (typically an ``nn.Module``) to run.
        *args, **kwargs: forwarded to ``function`` (and, when checkpointing, to
            ``torch.utils.checkpoint.checkpoint``).

    Returns:
        Whatever ``function`` returns.
    """
    if not flags.get_checkpointing():
        return function(*args, **kwargs)
    # ``torch.utils.checkpoint`` is a submodule. Import it explicitly instead of
    # relying on some other import having already loaded it as a side effect.
    from torch.utils.checkpoint import checkpoint

    kwargs.setdefault("use_reentrant", True)
    return checkpoint(function, *args, **kwargs)
def tag_param(param, tag):
    """Add ``tag`` to ``param._tags``, creating the set on first use.

    Returns ``param`` so calls can be chained or used inline.
    """
    if hasattr(param, "_tags"):
        param._tags.add(tag)
    else:
        param._tags = {tag}
    return param
def tag_module(module, tag):
    """Apply ``tag`` (via ``tag_param``) to every parameter of ``module``; return ``module``."""
    for p in module.parameters():
        tag_param(p, tag)
    return module
def apply_wd(module):
    """Tag ``module`` parameters whose qualified name ends in ``"weight"`` with ``"wd"``.

    Biases and other parameters are left untagged. Returns ``module``.
    """
    weight_params = (p for qual_name, p in module.named_parameters() if qual_name.endswith("weight"))
    for p in weight_params:
        tag_param(p, "wd")
    return module
def filter_params(function, module):
    """Lazily yield the parameters of ``module`` whose tag set satisfies ``function``.

    ``function`` receives the parameter's ``_tags`` set, or an empty set for an
    untagged parameter, and should return a truthy value to select it.
    """
    yield from (
        p for p in module.parameters()
        if function(getattr(p, "_tags", set()))
    )
def scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0):
    """Attention with an optional FlashAttention-2 fast path.

    If ``flags.get_use_flash_attention_2()`` is truthy and no ``attn_mask`` is given,
    ``flash_attn.flash_attn_func`` is tried first. If importing it raises ImportError,
    or the call raises RuntimeError, this falls back to
    ``F.scaled_dot_product_attention``, which also handles every other case.

    The callers in this file pass ``q``, ``k`` and ``v`` laid out as
    ``(batch, heads, seq, head_dim)``. The flash path swaps axes -3 and -2 on the way
    in and swaps them back on the way out.

    NOTE(review): ``dropout_p`` is applied as given, not gated on training mode.
    ``SelfAttentionBlock`` in this file relies on the 0.0 default.
    """
    if flags.get_use_flash_attention_2() and attn_mask is None:
        try:
            from flash_attn import flash_attn_func

            # presumably flash_attn_func wants (batch, seq, heads, head_dim) — swap heads/seq.
            q_bshd, k_bshd, v_bshd = (t.transpose(-3, -2) for t in (q, k, v))
            out_bshd = flash_attn_func(q_bshd, k_bshd, v_bshd, dropout_p=dropout_p)
            return out_bshd.transpose(-3, -2)
        except (ImportError, RuntimeError):
            # Best effort: missing package or an unsupported input -> use the torch kernel.
            pass
    return F.scaled_dot_product_attention(q, k, v, attn_mask, dropout_p=dropout_p)
@flags.compile_wrap
def geglu(x):
    """GEGLU: split the last dim into halves ``(value, gate)`` and return ``value * gelu(gate)``.

    The output's last dimension is half of the input's.
    """
    value, gate = x.chunk(2, dim=-1)
    gated = F.gelu(gate)
    return value * gated
@flags.compile_wrap
def rms_norm(x, scale, eps):
    """RMS-normalize ``x`` over its last dimension and multiply by ``scale``.

    The statistics are computed in at least float32 (``x``'s dtype promoted with
    float32). The combined factor is then cast back so the result keeps ``x``'s dtype.
    ``eps`` is added to the mean square before the reciprocal square root.
    """
    compute_dtype = torch.promote_types(x.dtype, torch.float32)
    x_hi = x.to(compute_dtype)
    inv_rms = torch.rsqrt(x_hi.pow(2).mean(dim=-1, keepdim=True) + eps)
    factor = scale.to(compute_dtype) * inv_rms
    return x * factor.to(x.dtype)
class GEGLU(nn.Module):
    """Parameter-free module wrapper around the ``geglu`` activation."""

    def forward(self, x):
        out = geglu(x)
        return out
class RMSNorm(nn.Module):
    """RMS normalization (via ``rms_norm``) with a learnable scale initialised to ones.

    Args:
        param_shape: shape of the learnable ``scale`` parameter.
        eps: added to the mean square for numerical stability.
    """

    def __init__(self, param_shape, eps=1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.ones(param_shape))

    def extra_repr(self):
        shape = tuple(self.scale.shape)
        return f"shape={shape}, eps={self.eps}"

    def forward(self, x):
        eps, scale = self.eps, self.scale
        return rms_norm(x, scale, eps)
class QKNorm(nn.Module):
    """Per-head RMS normalization for queries/keys with a learnable log-scale.

    ``self.scale`` holds one log-scale per head, initialised to ``log(10)``. It is
    clamped in place to at most ``log(max_scale)`` at construction and on every
    forward call.

    Args:
        n_heads: number of attention heads (length of ``self.scale``).
        eps: added to the mean square inside ``rms_norm``.
        max_scale: upper bound for ``exp(self.scale)``; stored as its log.
    """

    def __init__(self, n_heads, eps=1e-6, max_scale=100.0):
        super().__init__()
        self.eps = eps
        self.max_scale = math.log(max_scale)
        init_log_scale = math.log(10.0)
        self.scale = nn.Parameter(torch.full((n_heads,), init_log_scale))
        self.proj_()

    def extra_repr(self):
        n_heads = self.scale.shape[0]
        return f"n_heads={n_heads}, eps={self.eps}"

    @torch.no_grad()
    def proj_(self):
        """Clamp ``self.scale`` in place to ``self.max_scale``.

        Projecting the stored parameter, rather than clamping inside the graph, keeps
        the gradient from becoming zero once the value reaches the cap.
        """
        self.scale.clamp_(max=self.max_scale)

    def forward(self, x):
        self.proj_()
        log_dim = math.log(x.shape[-1])
        # Per-head factor exp(scale / 2) * d ** -0.25, where d = x.shape[-1].
        per_head = torch.exp(0.5 * self.scale - 0.25 * log_dim)
        # [:, None, None] broadcasts the per-head factor over the two trailing axes
        # (sequence, head_dim in SelfAttentionBlock's (n, h, l, e) layout).
        return rms_norm(x, per_head[:, None, None], self.eps)
class AdaRMSNorm(nn.Module):
    """RMS normalization whose scale is ``1 + linear(cond)``.

    The bias-free projection is zero-initialised, so the scale starts at exactly 1.
    Its weight is tagged ``"wd"`` (via ``apply_wd``) and all its parameters are tagged
    ``"mapping"``.

    Args:
        features: width of the normalized input (output width of the projection).
        cond_features: width of the conditioning vector.
        eps: added to the mean square inside ``rms_norm``.
    """

    def __init__(self, features, cond_features, eps=1e-6):
        super().__init__()
        self.eps = eps
        projection = nn.Linear(cond_features, features, bias=False)
        self.linear = apply_wd(zero_init(projection))
        tag_module(self.linear, "mapping")

    def extra_repr(self):
        return f"eps={self.eps},"

    def forward(self, x, cond):
        scale = self.linear(cond) + 1
        return rms_norm(x, scale, self.eps)
class SelfAttentionBlock(nn.Module):
    """Pre-norm multi-head self-attention with a residual connection.

    The input is normalised by ``AdaRMSNorm`` (modulated by ``cond``) and projected to
    queries/keys/values. These are split into ``d_model // d_head`` heads. Queries and
    keys are normalised by ``QKNorm`` and passed through ``AxialRoPE`` with ``pos``.
    The block then attends, merges heads, applies dropout and the zero-initialised
    output projection, and adds the result to the input.

    Args:
        d_model: token width.
        d_head: per-head width; must be positive and divide ``d_model``.
        dropout: dropout probability applied to the merged attention output.

    Raises:
        ValueError: if ``d_head`` is not positive or does not divide ``d_model``.
    """

    def __init__(self, d_model, d_head, dropout=0.0):
        super().__init__()
        # Validate up front. Otherwise a bad pair only fails later inside forward, with
        # an opaque "(h e)" rearrange error, or as ZeroDivisionError for d_head == 0.
        if d_head <= 0 or d_model % d_head != 0:
            raise ValueError(f"d_model ({d_model}) must be divisible by a positive d_head ({d_head})")
        self.d_head = d_head
        self.n_heads = d_model // d_head
        self.norm = AdaRMSNorm(d_model, d_model)
        self.qkv_proj = apply_wd(nn.Linear(d_model, d_model * 3, bias=False))
        self.qk_norm = QKNorm(self.n_heads)
        self.pos_emb = AxialRoPE(d_head, self.n_heads)
        self.dropout = nn.Dropout(dropout)
        # Zero-initialised, so the block starts out as the identity (output == skip).
        self.out_proj = apply_wd(zero_init(nn.Linear(d_model, d_model, bias=False)))

    def extra_repr(self):
        return f"d_head={self.d_head},"

    def forward(self, x, pos, attn_mask, cond):
        skip = x
        x = self.norm(x, cond)
        # (n, l, 3 * d_model) -> three (n, l, d_model) chunks -> (n, heads, l, d_head) each.
        q, k, v = (
            rearrange(t, "n l (h e) -> n h l e", e=self.d_head)
            for t in self.qkv_proj(x).chunk(3, dim=-1)
        )
        q = self.pos_emb(self.qk_norm(q), pos)
        k = self.pos_emb(self.qk_norm(k), pos)
        x = scaled_dot_product_attention(q, k, v, attn_mask)
        x = rearrange(x, "n h l e -> n l (h e)")
        x = self.dropout(x)
        x = self.out_proj(x)
        return x + skip
class FeedForwardBlock(nn.Module):
    """Pre-norm GEGLU feed-forward with a residual connection, modulated by ``cond``.

    The sequence is AdaRMSNorm -> up projection to ``2 * d_ff`` -> GEGLU (halves it to
    ``d_ff``) -> dropout -> zero-initialised down projection. The result is added to
    the input.
    """

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.norm = AdaRMSNorm(d_model, d_model)
        self.up_proj = apply_wd(nn.Linear(d_model, d_ff * 2, bias=False))
        self.act = GEGLU()
        self.dropout = nn.Dropout(dropout)
        self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False)))

    def forward(self, x, cond):
        h = self.norm(x, cond)
        for layer in (self.up_proj, self.act, self.dropout, self.down_proj):
            h = layer(h)
        return h + x
class TransformerBlock(nn.Module):
    """One transformer layer: self-attention, then feed-forward.

    Each sub-block is residual internally and is run through ``checkpoint_helper``,
    so it is activation-checkpointed when that flag is on.
    """

    def __init__(self, d_model, d_ff, d_head, dropout=0.0):
        super().__init__()
        self.self_attn = SelfAttentionBlock(d_model, d_head, dropout=dropout)
        self.ff = FeedForwardBlock(d_model, d_ff, dropout=dropout)

    def forward(self, x, pos, attn_mask, cond):
        attended = checkpoint_helper(self.self_attn, x, pos, attn_mask, cond)
        return checkpoint_helper(self.ff, attended, cond)
class Patching(nn.Module):
    """Cut ``(..., C, H, W)`` images into non-overlapping patches flattened into tokens.

    ``forward`` returns ``(tokens, pos)``. ``tokens`` has shape
    ``(..., (H/ph)*(W/pw), C*ph*pw)``. ``pos`` comes from ``make_axial_pos`` for the
    patch grid, with the pixel aspect ratio corrected by ``ph / pw``.

    Args:
        features: number of input channels ``C``.
        patch_size: ``(ph, pw)`` patch height and width.
    """

    def __init__(self, features, patch_size):
        super().__init__()
        self.features = features
        self.patch_size = patch_size
        self.d_out = features * patch_size[0] * patch_size[1]

    def extra_repr(self):
        return f"features={self.features}, patch_size={self.patch_size!r}"

    def forward(self, x, pixel_aspect_ratio=1.0):
        ph, pw = self.patch_size[0], self.patch_size[1]
        *_, h, w = x.shape
        if h % ph != 0 or w % pw != 0:
            raise ValueError(f"Image size {h}x{w} is not divisible by patch size {ph}x{pw}")
        tokens = rearrange(x, "... c (h i) (w j) -> ... (h w) (c i j)", i=ph, j=pw)
        patch_aspect_ratio = pixel_aspect_ratio * ph / pw
        pos = make_axial_pos(h // ph, w // pw, patch_aspect_ratio, device=x.device)
        return tokens, pos
class Unpatching(nn.Module):
    """Inverse of ``Patching``: fold patch tokens back into ``(..., C, H, W)`` images.

    Args:
        features: number of output channels ``C``.
        patch_size: ``(ph, pw)`` patch height and width.
    """

    def __init__(self, features, patch_size):
        super().__init__()
        self.features = features
        self.patch_size = patch_size
        self.d_in = features * patch_size[0] * patch_size[1]

    def extra_repr(self):
        return f"features={self.features}, patch_size={self.patch_size!r}"

    def forward(self, x, h, w):
        """Reassemble tokens ``(..., (h/ph)*(w/pw), C*ph*pw)`` into ``(..., C, h, w)``.

        Raises:
            ValueError: if ``h`` or ``w`` is not divisible by the patch size. Without
                this check, ``h // ph`` silently floors and the output image would be
                smaller than the requested ``h x w``. The check mirrors ``Patching``.
        """
        ph, pw = self.patch_size[0], self.patch_size[1]
        if h % ph != 0 or w % pw != 0:
            raise ValueError(f"Image size {h}x{w} is not divisible by patch size {ph}x{pw}")
        return rearrange(
            x,
            "... (h w) (c i j) -> ... c (h i) (w j)",
            h=h // ph,
            w=w // pw,
            i=ph,
            j=pw,
        )
class MappingFeedForwardBlock(nn.Module):
    """Unconditioned pre-norm GEGLU feed-forward block used by ``MappingNetwork``.

    The sequence is RMSNorm -> up projection to ``2 * d_ff`` -> GEGLU -> dropout ->
    zero-initialised down projection. The result is added to the input.
    """

    def __init__(self, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.norm = RMSNorm(d_model)
        self.up_proj = apply_wd(nn.Linear(d_model, d_ff * 2, bias=False))
        self.act = GEGLU()
        self.dropout = nn.Dropout(dropout)
        self.down_proj = apply_wd(zero_init(nn.Linear(d_ff, d_model, bias=False)))

    def forward(self, x):
        h = x
        for layer in (self.norm, self.up_proj, self.act, self.dropout, self.down_proj):
            h = layer(h)
        return h + x
class MappingNetwork(nn.Module):
    """Stack of ``n_layers`` ``MappingFeedForwardBlock``s between input and output RMSNorms."""

    def __init__(self, n_layers, d_model, d_ff, dropout=0.0):
        super().__init__()
        self.in_norm = RMSNorm(d_model)
        self.blocks = nn.ModuleList(
            MappingFeedForwardBlock(d_model, d_ff, dropout=dropout)
            for _ in range(n_layers)
        )
        self.out_norm = RMSNorm(d_model)

    def forward(self, x):
        h = self.in_norm(x)
        for block in self.blocks:
            h = block(h)
        return self.out_norm(h)
class ImageTransformerDenoiserModelV1(nn.Module):
    """Patch-based transformer denoiser conditioned on noise level, augmentation and class.

    Images are patchified and projected to ``d_model``, then processed by ``n_layers``
    ``TransformerBlock``s. Each block is conditioned on a vector produced by a 2-layer
    ``MappingNetwork`` from the noise-level, augmentation and (optional) class
    embeddings. The tokens are then normalized, projected by a zero-initialised layer,
    and unpatchified.

    Args:
        n_layers: number of transformer blocks.
        d_model: token width.
        d_ff: hidden width of the GEGLU feed-forward layers.
        in_features: input image channels.
        out_features: output image channels.
        patch_size: ``(ph, pw)`` patch height and width.
        num_classes: size of the class-embedding table; 0 disables class conditioning.
        dropout: dropout probability used in the blocks and the mapping network.
        sigma_data: stored on the module as ``self.sigma_data``; not read by this
            module's ``forward``.
        d_head: attention head width. ``d_model`` is split into ``d_model // d_head``
            heads, so it should divide ``d_model``. Defaults to 64, the previously
            hard-coded value.
    """

    def __init__(self, n_layers, d_model, d_ff, in_features, out_features, patch_size, num_classes=0, dropout=0.0, sigma_data=1.0, d_head=64):
        super().__init__()
        self.sigma_data = sigma_data
        self.num_classes = num_classes
        self.patch_in = Patching(in_features, patch_size)
        self.patch_out = Unpatching(out_features, patch_size)
        self.time_emb = layers.FourierFeatures(1, d_model)
        self.time_in_proj = nn.Linear(d_model, d_model, bias=False)
        self.aug_emb = layers.FourierFeatures(9, d_model)
        self.aug_in_proj = nn.Linear(d_model, d_model, bias=False)
        self.class_emb = nn.Embedding(num_classes, d_model) if num_classes else None
        self.mapping = tag_module(MappingNetwork(2, d_model, d_ff, dropout=dropout), "mapping")
        self.in_proj = nn.Linear(self.patch_in.d_out, d_model, bias=False)
        self.blocks = nn.ModuleList([TransformerBlock(d_model, d_ff, d_head, dropout=dropout) for _ in range(n_layers)])
        self.out_norm = RMSNorm(d_model)
        # Zero-initialised, so the untrained model outputs zeros.
        self.out_proj = zero_init(nn.Linear(d_model, self.patch_out.d_in, bias=False))

    def proj_(self):
        """Clamp every block's QKNorm log-scale in place (see ``QKNorm.proj_``)."""
        for block in self.blocks:
            block.self_attn.qk_norm.proj_()

    def param_groups(self, base_lr=5e-4, mapping_lr_scale=1 / 3):
        """Build optimizer parameter groups from the ``"wd"`` / ``"mapping"`` tags.

        Returns four groups: weight-decayed, non-decayed, mapping weight-decayed and
        mapping non-decayed. Mapping groups use ``base_lr * mapping_lr_scale``.
        Non-decayed groups set ``weight_decay`` to 0.0. Decayed groups inherit the
        optimizer's default.

        NOTE(review): ``time_in_proj``, ``aug_in_proj``, ``class_emb``, ``in_proj`` and
        ``out_proj`` carry no tags, so they land in the non-decayed, non-mapping group.
        Confirm that is intended.
        """
        wd = filter_params(lambda tags: "wd" in tags and "mapping" not in tags, self)
        no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" not in tags, self)
        mapping_wd = filter_params(lambda tags: "wd" in tags and "mapping" in tags, self)
        mapping_no_wd = filter_params(lambda tags: "wd" not in tags and "mapping" in tags, self)
        groups = [
            {"params": list(wd), "lr": base_lr},
            {"params": list(no_wd), "lr": base_lr, "weight_decay": 0.0},
            {"params": list(mapping_wd), "lr": base_lr * mapping_lr_scale},
            {"params": list(mapping_no_wd), "lr": base_lr * mapping_lr_scale, "weight_decay": 0.0},
        ]
        return groups

    def forward(self, x, sigma, aug_cond=None, class_cond=None):
        """Denoise ``x`` at noise level ``sigma``.

        Args:
            x: images shaped ``(..., C, H, W)``; ``H``/``W`` must be divisible by the
                patch size.
            sigma: noise levels (must be positive, since ``log`` is taken). Presumably
                one per batch item — confirm against callers.
            aug_cond: augmentation conditioning of width 9; defaults to zeros.
            class_cond: class indices; required when ``num_classes > 0``.

        Returns:
            Tensor shaped ``(..., out_features, H, W)``.

        Raises:
            ValueError: if ``class_cond`` is missing while class conditioning is enabled.
        """
        # Validate before doing any work.
        if class_cond is None and self.class_emb is not None:
            raise ValueError("class_cond must be specified if num_classes > 0")

        # Patching
        *_, h, w = x.shape
        x, pos = self.patch_in(x)
        attn_mask = None
        x = self.in_proj(x)

        # Mapping network
        c_noise = torch.log(sigma) / 4
        time_emb = self.time_in_proj(self.time_emb(c_noise[..., None]))
        aug_cond = x.new_zeros([x.shape[0], 9]) if aug_cond is None else aug_cond
        aug_emb = self.aug_in_proj(self.aug_emb(aug_cond))
        class_emb = self.class_emb(class_cond) if self.class_emb is not None else 0
        # unsqueeze(-2) inserts a token axis so cond presumably broadcasts over tokens.
        cond = self.mapping(time_emb + aug_emb + class_emb).unsqueeze(-2)

        # Transformer
        for block in self.blocks:
            x = block(x, pos, attn_mask, cond)

        # Unpatching
        x = self.out_norm(x)
        x = self.out_proj(x)
        x = self.patch_out(x, h, w)
        return x
|