File size: 19,608 Bytes
f28049f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 |
"""Diffusion transformer modules."""
from math import log, pi
from typing import Optional
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, reduce, repeat
from einops.layers.torch import Rearrange
from torch import Tensor, einsum
from .utils import exists, default, rand_bool
class AdaLayerNorm(nn.Module):
    """Layer normalisation followed by a style-predicted scale and shift.

    The style vector ``s`` is projected to ``2 * channels`` values per batch
    item; the first half is used as ``gamma`` and the second half as ``beta``
    in ``(1 + gamma) * layer_norm(x) + beta``.

    Args:
        style_dim: Size of the style vector fed to the projection.
        channels: Size of the axis that is normalised and modulated.
        eps: Epsilon passed to ``F.layer_norm``.
    """

    def __init__(self, style_dim, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # A single projection produces both modulation vectors (gamma, beta).
        self.fc = nn.Linear(style_dim, channels * 2)

    def forward(self, x, s):
        """Normalise ``x`` and modulate it with parameters derived from ``s``.

        Args:
            x: Input whose last axis (after the axis swaps below) has size
                ``channels``. Assumed to be ``(batch, length, channels)``; for
                a rank-3 tensor the swaps on entry and exit cancel out
                -- TODO confirm no other rank is used by callers.
            s: Style tensor, expected 2-D ``(batch, style_dim)`` because the
                projection is viewed as ``(batch, 2 * channels, 1)``.

        Returns:
            Tensor with the same layout as ``x``.
        """
        # Entry axis swaps; they are undone in reverse order on the way out.
        features_last = x.transpose(-1, -2).transpose(1, -1)

        modulation = self.fc(s)
        modulation = modulation.view(modulation.size(0), modulation.size(1), 1)
        # Split into (batch, channels, 1) halves, then move channels last so
        # they broadcast against the normalised input.
        scale, shift = (
            part.transpose(1, -1) for part in modulation.chunk(2, dim=1)
        )

        normed = F.layer_norm(features_last, (self.channels,), eps=self.eps)
        modulated = (1 + scale) * normed + shift
        return modulated.transpose(1, -1).transpose(-1, -2)
class LearnedPositionalEmbedding(nn.Module):
    """Fourier features of a batch of scalars using learned frequencies.

    Args:
        dim: Number of sine/cosine features produced; must be even because
            half are sines and half are cosines.
    """

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        # One learnable frequency per sine/cosine pair.
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, x: Tensor) -> Tensor:
        """Embed each scalar in ``x``.

        Args:
            x: 1-D tensor of shape ``(batch,)``.

        Returns:
            Tensor of shape ``(batch, dim + 1)``: the raw value, followed by
            the sines and then the cosines of ``x * weights * 2 * pi``.
        """
        (batch,) = x.shape  # 1-D input only; any other rank fails to unpack
        column = x.reshape(batch, 1)
        angles = column * self.weights.unsqueeze(0) * 2 * pi
        return torch.cat((column, angles.sin(), angles.cos()), dim=-1)
def TimePositionalEmbedding(dim: int, out_features: int) -> nn.Module:
    """Build a learned Fourier embedding followed by a linear projection.

    Args:
        dim: Passed to ``LearnedPositionalEmbedding``.
        out_features: Output size of the projection.

    Returns:
        An ``nn.Sequential`` of the embedding and an ``nn.Linear`` whose input
        size is ``dim + 1`` (the embedding prepends the raw value to its
        ``dim`` Fourier features).
    """
    fourier = LearnedPositionalEmbedding(dim)
    projection = nn.Linear(in_features=dim + 1, out_features=out_features)
    return nn.Sequential(fourier, projection)
class FixedEmbedding(nn.Module):
    """Learned embedding of positions ``0 .. length - 1``, shared by the batch.

    Args:
        max_length: Number of positions in the embedding table.
        features: Size of each position's embedding vector.
    """

    def __init__(self, max_length: int, features: int):
        super().__init__()
        self.max_length = max_length
        self.embedding = nn.Embedding(max_length, features)

    def forward(self, x: Tensor) -> Tensor:
        """Return position embeddings shaped after the first two dims of ``x``.

        Only the batch size, the length (dims 0 and 1) and the device of ``x``
        are read; its values are ignored.

        Args:
            x: Tensor with at least two dimensions ``(batch, length, ...)``.

        Returns:
            Tensor of shape ``(batch, length, features)`` in which every batch
            item holds the same position embeddings.

        Raises:
            AssertionError: If ``length`` exceeds ``max_length``.
        """
        batch_size, length = x.shape[:2]
        assert length <= self.max_length, "Input sequence length must be <= max_length"
        positions = torch.arange(length, device=x.device)
        table = self.embedding(positions)
        # Broadcast the (length, features) table across the batch dimension.
        return table.unsqueeze(0).expand(batch_size, -1, -1)
class RelativePositionBias(nn.Module):
    """Learned per-head attention bias indexed by bucketed relative position.

    Args:
        num_buckets: Total number of buckets (rows of the bias table).
        max_distance: Distance at which the logarithmic buckets reach the
            last bucket of their half.
        num_heads: Number of bias values stored per bucket.
    """

    def __init__(self, num_buckets: int, max_distance: int, num_heads: int):
        super().__init__()
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.num_heads = num_heads
        self.relative_attention_bias = nn.Embedding(num_buckets, num_heads)

    @staticmethod
    def _relative_position_bucket(relative_position: Tensor, num_buckets: int, max_distance: int):
        """Map signed relative positions to bucket indices.

        Negative offsets use the lower half of the buckets and non-negative
        offsets the upper half. Within a half, distances below ``max_exact``
        each get their own bucket; larger distances share logarithmically
        spaced buckets, capped at the last bucket of the half.
        """
        half = num_buckets // 2
        max_exact = half // 2
        distance = relative_position.abs()

        # A distance of 0 gives -inf here; when max_exact > 0 those entries
        # are always taken from the exact branch of the `where` below.
        log_ratio = torch.log(distance.float() / max_exact) / log(max_distance / max_exact)
        coarse = max_exact + (log_ratio * (half - max_exact)).long()
        coarse = coarse.clamp(max=half - 1)

        sign_offset = (relative_position >= 0).to(torch.long) * half
        return sign_offset + torch.where(distance < max_exact, distance, coarse)

    def forward(self, num_queries: int, num_keys: int) -> Tensor:
        """Return the bias for a ``num_queries x num_keys`` attention map.

        Queries are placed at the last ``num_queries`` key positions, i.e.
        positions ``num_keys - num_queries .. num_keys - 1``.

        Returns:
            Tensor of shape ``(1, num_heads, num_queries, num_keys)``.
        """
        device = self.relative_attention_bias.weight.device
        query_positions = torch.arange(
            num_keys - num_queries, num_keys, dtype=torch.long, device=device
        )
        key_positions = torch.arange(num_keys, dtype=torch.long, device=device)
        # offsets[i, j] = key position j minus query position i
        offsets = key_positions.unsqueeze(0) - query_positions.unsqueeze(1)
        buckets = self._relative_position_bucket(
            offsets, num_buckets=self.num_buckets, max_distance=self.max_distance
        )
        per_pair = self.relative_attention_bias(buckets)  # (queries, keys, heads)
        return per_pair.permute(2, 0, 1).unsqueeze(0)
def FeedForward(features: int, multiplier: int) -> nn.Module:
    """Build a two-layer MLP: expand, GELU, project back.

    Args:
        features: Input and output size.
        multiplier: The hidden size is ``features * multiplier``.

    Returns:
        ``nn.Sequential(Linear, GELU, Linear)``.
    """
    hidden = features * multiplier
    layers = [
        nn.Linear(in_features=features, out_features=hidden),
        nn.GELU(),
        nn.Linear(in_features=hidden, out_features=features),
    ]
    return nn.Sequential(*layers)
class AttentionBase(nn.Module):
    """Multi-head scaled dot-product attention over pre-projected q/k/v.

    Args:
        features: Used only as the default for ``out_features``.
        head_features: Size of each head; sets the ``head_features ** -0.5``
            scale.
        num_heads: Number of heads the last axis of q/k/v is split into.
        use_rel_pos: Whether to add a ``RelativePositionBias`` to the scores.
        out_features: Output size of the final projection; defaults to
            ``features``.
        rel_pos_num_buckets: Required when ``use_rel_pos`` is true.
        rel_pos_max_distance: Required when ``use_rel_pos`` is true.
    """

    def __init__(self, features: int, *, head_features: int, num_heads: int, use_rel_pos: bool,
                 out_features: Optional[int] = None, rel_pos_num_buckets: Optional[int] = None,
                 rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        self.scale = head_features ** -0.5
        self.num_heads = num_heads
        self.use_rel_pos = use_rel_pos

        if use_rel_pos:
            assert exists(rel_pos_num_buckets) and exists(rel_pos_max_distance)
            self.rel_pos = RelativePositionBias(
                num_buckets=rel_pos_num_buckets,
                max_distance=rel_pos_max_distance,
                num_heads=num_heads,
            )

        self.to_out = nn.Linear(
            in_features=head_features * num_heads,
            out_features=features if out_features is None else out_features,
        )

    def _split_heads(self, t: Tensor) -> Tensor:
        """Reshape ``(b, n, h * d)`` into ``(b, h, n, d)``."""
        batch, length, _ = t.shape  # rank-3 input only
        return t.reshape(batch, length, self.num_heads, -1).transpose(1, 2)

    @staticmethod
    def _merge_heads(t: Tensor) -> Tensor:
        """Reshape ``(b, h, n, d)`` back into ``(b, n, h * d)``."""
        batch, _, length, _ = t.shape
        return t.transpose(1, 2).reshape(batch, length, -1)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Attend from ``q`` over ``k``/``v`` and project the result.

        Args:
            q: ``(batch, n, num_heads * head_features)``.
            k: ``(batch, m, num_heads * head_features)``.
            v: ``(batch, m, num_heads * head_features)``.

        Returns:
            Tensor of shape ``(batch, n, out_features)``.
        """
        q, k, v = (self._split_heads(t) for t in (q, k, v))

        scores = einsum("b h n d, b h m d -> b h n m", q, k)
        if self.use_rel_pos:
            scores = scores + self.rel_pos(*scores.shape[-2:])
        # The scale is applied after the bias is added, so the bias is
        # scaled together with the dot products.
        weights = (scores * self.scale).softmax(dim=-1)

        mixed = einsum("b h n m, b h m d -> b h n d", weights, v)
        return self.to_out(self._merge_heads(mixed))
class StyleAttention(nn.Module):
    """Attention whose input normalisations are conditioned on a style vector.

    Queries are projected from ``x`` and keys/values from ``context``; both
    inputs are first normalised by an ``AdaLayerNorm`` driven by ``s``.

    Args:
        features: Feature size of ``x`` (and of the output).
        style_dim: Size of the style vector given to both ``AdaLayerNorm``s.
        head_features: Size of each attention head.
        num_heads: Number of attention heads.
        context_features: Feature size of ``context``; resolved through
            ``default(context_features, features)``.
        use_rel_pos: Forwarded to ``AttentionBase``.
        rel_pos_num_buckets: Forwarded to ``AttentionBase``.
        rel_pos_max_distance: Forwarded to ``AttentionBase``.
    """

    def __init__(self, features: int, *, style_dim: int, head_features: int, num_heads: int,
                 context_features: Optional[int] = None, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        # Store the caller's raw value (possibly None) before resolving it.
        self.context_features = context_features
        resolved_context = default(context_features, features)
        inner = head_features * num_heads

        self.norm = AdaLayerNorm(style_dim, features)
        self.norm_context = AdaLayerNorm(style_dim, resolved_context)
        self.to_q = nn.Linear(in_features=features, out_features=inner, bias=False)
        # Keys and values share one projection and are split in forward().
        self.to_kv = nn.Linear(in_features=resolved_context, out_features=inner * 2, bias=False)
        self.attention = AttentionBase(
            features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Attend from ``x`` over ``context``, both style-normalised by ``s``.

        Args:
            x: Query-side input.
            s: Style tensor passed to both normalisation layers.
            context: Key/value-side input; resolved via ``default(context, x)``.

        Returns:
            Output of the wrapped ``AttentionBase``.
        """
        context = default(context, x)
        queries = self.to_q(self.norm(x, s))
        keys, values = self.to_kv(self.norm_context(context, s)).chunk(2, dim=-1)
        return self.attention(queries, keys, values)
class Attention(nn.Module):
    """Pre-normalised multi-head attention with optional separate context.

    Queries are projected from ``x`` and keys/values from ``context``; each
    input goes through its own ``nn.LayerNorm`` first.

    Args:
        features: Feature size of ``x``.
        head_features: Size of each attention head.
        num_heads: Number of attention heads.
        out_features: Forwarded to ``AttentionBase`` as its output size.
        context_features: Feature size of ``context``; resolved through
            ``default(context_features, features)``.
        use_rel_pos: Forwarded to ``AttentionBase``.
        rel_pos_num_buckets: Forwarded to ``AttentionBase``.
        rel_pos_max_distance: Forwarded to ``AttentionBase``.
    """

    def __init__(self, features: int, *, head_features: int, num_heads: int, out_features: Optional[int] = None,
                 context_features: Optional[int] = None, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None):
        super().__init__()
        # Store the caller's raw value (possibly None) before resolving it.
        self.context_features = context_features
        resolved_context = default(context_features, features)
        inner = head_features * num_heads

        self.norm = nn.LayerNorm(features)
        self.norm_context = nn.LayerNorm(resolved_context)
        self.to_q = nn.Linear(in_features=features, out_features=inner, bias=False)
        # Keys and values share one projection and are split in forward().
        self.to_kv = nn.Linear(in_features=resolved_context, out_features=inner * 2, bias=False)
        self.attention = AttentionBase(
            features,
            out_features=out_features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Attend from ``x`` over ``context``.

        Args:
            x: Query-side input.
            context: Key/value-side input; resolved via ``default(context, x)``.

        Returns:
            Output of the wrapped ``AttentionBase``.
        """
        context = default(context, x)
        queries = self.to_q(self.norm(x))
        keys, values = self.to_kv(self.norm_context(context)).chunk(2, dim=-1)
        return self.attention(queries, keys, values)
class StyleTransformerBlock(nn.Module):
    """Residual block: style self-attention, optional cross-attention, MLP.

    Cross-attention is created and used only when ``context_features`` is
    given and greater than zero.

    Args:
        features: Feature size of the block's input and output.
        num_heads: Number of attention heads.
        head_features: Size of each attention head.
        style_dim: Size of the style vector used by the attention layers.
        multiplier: Hidden-size multiplier of the feed-forward network.
        use_rel_pos: Forwarded to the attention layers.
        rel_pos_num_buckets: Forwarded to the attention layers.
        rel_pos_max_distance: Forwarded to the attention layers.
        context_features: Feature size of the cross-attention context.
    """

    def __init__(self, features: int, num_heads: int, head_features: int, style_dim: int, multiplier: int,
                 use_rel_pos: bool, rel_pos_num_buckets: Optional[int] = None,
                 rel_pos_max_distance: Optional[int] = None, context_features: Optional[int] = None):
        super().__init__()
        self.use_cross_attention = exists(context_features) and context_features > 0

        # Settings shared by the self- and cross-attention layers.
        shared = dict(
            features=features,
            style_dim=style_dim,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )
        self.attention = StyleAttention(**shared)
        if self.use_cross_attention:
            self.cross_attention = StyleAttention(context_features=context_features, **shared)
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, s: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Apply the block; every sub-layer is wrapped in a residual add.

        Args:
            x: Input sequence.
            s: Style tensor passed to the attention layers.
            context: Passed to the cross-attention layer when it exists.

        Returns:
            Tensor produced by the final residual feed-forward step.
        """
        attended = self.attention(x, s)
        x = attended + x
        if self.use_cross_attention:
            crossed = self.cross_attention(x, s, context=context)
            x = crossed + x
        return self.feed_forward(x) + x
class TransformerBlock(nn.Module):
    """Residual block: self-attention, optional cross-attention, MLP.

    Cross-attention is created and used only when ``context_features`` is
    given and greater than zero.

    Args:
        features: Feature size of the block's input and output.
        num_heads: Number of attention heads.
        head_features: Size of each attention head.
        multiplier: Hidden-size multiplier of the feed-forward network.
        use_rel_pos: Forwarded to the attention layers.
        rel_pos_num_buckets: Forwarded to the attention layers.
        rel_pos_max_distance: Forwarded to the attention layers.
        context_features: Feature size of the cross-attention context.
    """

    def __init__(self, features: int, num_heads: int, head_features: int, multiplier: int, use_rel_pos: bool,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None):
        super().__init__()
        self.use_cross_attention = exists(context_features) and context_features > 0

        # Settings shared by the self- and cross-attention layers.
        shared = dict(
            features=features,
            num_heads=num_heads,
            head_features=head_features,
            use_rel_pos=use_rel_pos,
            rel_pos_num_buckets=rel_pos_num_buckets,
            rel_pos_max_distance=rel_pos_max_distance,
        )
        self.attention = Attention(**shared)
        if self.use_cross_attention:
            self.cross_attention = Attention(context_features=context_features, **shared)
        self.feed_forward = FeedForward(features=features, multiplier=multiplier)

    def forward(self, x: Tensor, *, context: Optional[Tensor] = None) -> Tensor:
        """Apply the block; every sub-layer is wrapped in a residual add.

        Args:
            x: Input sequence.
            context: Passed to the cross-attention layer when it exists.

        Returns:
            Tensor produced by the final residual feed-forward step.
        """
        attended = self.attention(x)
        x = attended + x
        if self.use_cross_attention:
            crossed = self.cross_attention(x, context=context)
            x = crossed + x
        return self.feed_forward(x) + x
class StyleTransformer1d(nn.Module):
    """Stack of style-conditioned transformer blocks over an embedding sequence.

    The input ``x`` is expanded along dim 1 to the embedding's length and
    concatenated with the embedding on the feature axis. A conditioning
    vector built from ``time`` and/or ``features`` is added before every
    block, and ``features`` is also passed to each block as its style input.
    The result is averaged over the sequence and projected to ``channels``.

    Args:
        num_layers: Number of ``StyleTransformerBlock``s.
        channels: Feature size of ``x`` and of the output.
        num_heads: Attention heads per block.
        head_features: Size of each attention head.
        multiplier: Feed-forward hidden-size multiplier.
        use_context_time: Whether ``time`` contributes to the conditioning.
        use_rel_pos: Forwarded to the blocks.
        context_features_multiplier: Accepted but not used by this class.
        rel_pos_num_buckets: Forwarded to the blocks.
        rel_pos_max_distance: Forwarded to the blocks.
        context_features: Size of ``features``; also the blocks' ``style_dim``.
        context_embedding_features: Feature size of ``embedding``. It is
            added to ``channels``, so the ``None`` default cannot be used.
        embedding_max_length: Maximum embedding length supported by the
            fixed embedding.
    """

    def __init__(self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int,
                 use_context_time: bool = True, use_rel_pos: bool = False, context_features_multiplier: int = 1,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None, context_embedding_features: Optional[int] = None,
                 embedding_max_length: int = 512):
        super().__init__()
        # Width of the tokens seen by the blocks: input channels + embedding.
        width = channels + context_embedding_features

        layers = []
        for _ in range(num_layers):
            layers.append(
                StyleTransformerBlock(
                    features=width,
                    head_features=head_features,
                    num_heads=num_heads,
                    multiplier=multiplier,
                    style_dim=context_features,
                    use_rel_pos=use_rel_pos,
                    rel_pos_num_buckets=rel_pos_num_buckets,
                    rel_pos_max_distance=rel_pos_max_distance,
                )
            )
        self.blocks = nn.ModuleList(layers)

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(in_channels=width, out_channels=channels, kernel_size=1),
        )

        self.use_context_features = exists(context_features)
        self.use_context_time = use_context_time

        if self.use_context_time or self.use_context_features:
            self.to_mapping = nn.Sequential(
                nn.Linear(width, width),
                nn.GELU(),
                nn.Linear(width, width),
                nn.GELU(),
            )
        if self.use_context_time:
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(dim=channels, out_features=width),
                nn.GELU(),
            )
        if self.use_context_features:
            self.to_features = nn.Sequential(
                nn.Linear(in_features=context_features, out_features=width),
                nn.GELU(),
            )

        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(self, time: Optional[Tensor] = None, features: Optional[Tensor] = None) -> Optional[Tensor]:
        """Build the conditioning vector from the enabled context sources.

        Returns:
            The summed time/feature embeddings passed through ``to_mapping``,
            or ``None`` when neither source is enabled.
        """
        if not (self.use_context_time or self.use_context_features):
            return None

        sources = []
        if self.use_context_time:
            sources.append(self.to_time(time))
        if self.use_context_features:
            sources.append(self.to_features(features))

        summed = reduce(torch.stack(sources), "n b m -> b m", "sum")
        return self.to_mapping(summed)

    def run(self, x, time, embedding, features):
        """Run the block stack for one embedding and pool the result.

        Returns:
            Tensor with a single sequence position and ``channels`` features,
            i.e. ``(batch, 1, channels)``.
        """
        mapping = self.get_mapping(time, features)
        steps = embedding.size(1)

        tokens = torch.cat([x.expand(-1, steps, -1), embedding], dim=-1)
        mapping = mapping.unsqueeze(1).expand(-1, steps, -1)

        for block in self.blocks:
            # The conditioning is re-added before every block.
            tokens = block(tokens + mapping, features)

        pooled = tokens.mean(dim=1).unsqueeze(1)
        return self.to_out(pooled).transpose(-1, -2)

    def forward(self, x: Tensor, time: Tensor, embedding_mask_proba: float = 0.0, embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None, embedding_scale: float = 1.0) -> Tensor:
        """Run the model, optionally masking or re-weighting the embedding.

        Args:
            x: Input expanded along dim 1 to the embedding length; presumably
                ``(batch, 1, channels)`` -- TODO confirm against callers.
            time: Input to the time embedding when ``use_context_time``.
            embedding_mask_proba: When positive, a per-batch-item mask from
                ``rand_bool`` selects items whose embedding is replaced by the
                fixed positional embedding.
            embedding: Conditioning sequence. Its shape and device are read
                unconditionally, so it must be provided despite the ``None``
                default.
            features: Style/context features.
            embedding_scale: When not ``1.0``, the result is
                ``out_masked + (out - out_masked) * embedding_scale`` where
                ``out_masked`` uses the fixed embedding in place of
                ``embedding``.

        Returns:
            Output of ``run`` (or the scaled combination of two runs).
        """
        batch, device = embedding.shape[0], embedding.device
        fixed = self.fixed_embedding(embedding)

        if embedding_mask_proba > 0.0:
            drop = rand_bool(shape=(batch, 1, 1), proba=embedding_mask_proba, device=device)
            embedding = torch.where(drop, fixed, embedding)

        if embedding_scale == 1.0:
            return self.run(x, time, embedding=embedding, features=features)

        out = self.run(x, time, embedding=embedding, features=features)
        out_masked = self.run(x, time, embedding=fixed, features=features)
        return out_masked + (out - out_masked) * embedding_scale
class Transformer1d(nn.Module):
    """Stack of transformer blocks applied to an embedding sequence.

    The input ``x`` is expanded along dim 1 to the embedding's length and
    concatenated with the embedding on the feature axis. When time and/or
    feature conditioning is enabled, a conditioning vector is added before
    every block. The result is averaged over the sequence and projected to
    ``channels``.

    Args:
        num_layers: Number of ``TransformerBlock``s.
        channels: Feature size of ``x`` and of the output.
        num_heads: Attention heads per block.
        head_features: Size of each attention head.
        multiplier: Feed-forward hidden-size multiplier.
        use_context_time: Whether ``time`` contributes to the conditioning.
        use_rel_pos: Forwarded to the blocks.
        context_features_multiplier: Accepted but not used by this class.
        rel_pos_num_buckets: Forwarded to the blocks.
        rel_pos_max_distance: Forwarded to the blocks.
        context_features: Size of ``features``; feature conditioning is
            enabled only when this is given.
        context_embedding_features: Feature size of ``embedding``. It is
            added to ``channels``, so the ``None`` default cannot be used.
        embedding_max_length: Maximum embedding length supported by the
            fixed embedding.
    """

    def __init__(self, num_layers: int, channels: int, num_heads: int, head_features: int, multiplier: int,
                 use_context_time: bool = True, use_rel_pos: bool = False, context_features_multiplier: int = 1,
                 rel_pos_num_buckets: Optional[int] = None, rel_pos_max_distance: Optional[int] = None,
                 context_features: Optional[int] = None, context_embedding_features: Optional[int] = None,
                 embedding_max_length: int = 512):
        super().__init__()
        # Width of the tokens seen by the blocks: input channels + embedding.
        width = channels + context_embedding_features

        # Blocks are built without context_features, so they carry no
        # cross-attention layer.
        self.blocks = nn.ModuleList([
            TransformerBlock(
                features=width,
                head_features=head_features,
                num_heads=num_heads,
                multiplier=multiplier,
                use_rel_pos=use_rel_pos,
                rel_pos_num_buckets=rel_pos_num_buckets,
                rel_pos_max_distance=rel_pos_max_distance,
            )
            for _ in range(num_layers)
        ])

        self.to_out = nn.Sequential(
            Rearrange("b t c -> b c t"),
            nn.Conv1d(in_channels=width, out_channels=channels, kernel_size=1),
        )

        self.use_context_features = exists(context_features)
        self.use_context_time = use_context_time

        if self.use_context_time or self.use_context_features:
            self.to_mapping = nn.Sequential(
                nn.Linear(width, width),
                nn.GELU(),
                nn.Linear(width, width),
                nn.GELU(),
            )
        if self.use_context_time:
            self.to_time = nn.Sequential(
                TimePositionalEmbedding(dim=channels, out_features=width),
                nn.GELU(),
            )
        if self.use_context_features:
            self.to_features = nn.Sequential(
                nn.Linear(in_features=context_features, out_features=width),
                nn.GELU(),
            )

        self.fixed_embedding = FixedEmbedding(
            max_length=embedding_max_length, features=context_embedding_features
        )

    def get_mapping(self, time: Optional[Tensor] = None, features: Optional[Tensor] = None) -> Optional[Tensor]:
        """Build the conditioning vector from the enabled context sources.

        Returns:
            The summed time/feature embeddings passed through ``to_mapping``,
            or ``None`` when neither source is enabled.
        """
        if not (self.use_context_time or self.use_context_features):
            return None

        sources = []
        if self.use_context_time:
            sources.append(self.to_time(time))
        if self.use_context_features:
            sources.append(self.to_features(features))

        summed = reduce(torch.stack(sources), "n b m -> b m", "sum")
        return self.to_mapping(summed)

    def run(self, x, time, embedding, features):
        """Run the block stack for one embedding and pool the result.

        When no conditioning source is enabled, ``get_mapping`` returns
        ``None`` and the blocks are applied without a conditioning shift.

        Returns:
            Tensor with a single sequence position and ``channels`` features,
            i.e. ``(batch, 1, channels)``.
        """
        mapping = self.get_mapping(time, features)
        steps = embedding.size(1)

        x = torch.cat([x.expand(-1, steps, -1), embedding], dim=-1)
        # Guard against the unconditioned configuration: previously a None
        # mapping crashed here with AttributeError on `.unsqueeze`.
        if mapping is not None:
            mapping = mapping.unsqueeze(1).expand(-1, steps, -1)

        for block in self.blocks:
            if mapping is not None:
                # The conditioning is re-added before every block.
                x = x + mapping
            x = block(x)

        x = x.mean(dim=1).unsqueeze(1)
        x = self.to_out(x)
        return x.transpose(-1, -2)

    def forward(self, x: Tensor, time: Tensor, embedding_mask_proba: float = 0.0, embedding: Optional[Tensor] = None,
                features: Optional[Tensor] = None, embedding_scale: float = 1.0) -> Tensor:
        """Run the model, optionally masking or re-weighting the embedding.

        Args:
            x: Input expanded along dim 1 to the embedding length; presumably
                ``(batch, 1, channels)`` -- TODO confirm against callers.
            time: Input to the time embedding when ``use_context_time``.
            embedding_mask_proba: When positive, a per-batch-item mask from
                ``rand_bool`` selects items whose embedding is replaced by the
                fixed positional embedding.
            embedding: Conditioning sequence. Its shape and device are read
                unconditionally, so it must be provided despite the ``None``
                default.
            features: Context features, used only for the conditioning
                vector.
            embedding_scale: When not ``1.0``, the result is
                ``out_masked + (out - out_masked) * embedding_scale`` where
                ``out_masked`` uses the fixed embedding in place of
                ``embedding``.

        Returns:
            Output of ``run`` (or the scaled combination of two runs).
        """
        batch, device = embedding.shape[0], embedding.device
        fixed_embedding = self.fixed_embedding(embedding)

        if embedding_mask_proba > 0.0:
            batch_mask = rand_bool(shape=(batch, 1, 1), proba=embedding_mask_proba, device=device)
            embedding = torch.where(batch_mask, fixed_embedding, embedding)

        if embedding_scale == 1.0:
            return self.run(x, time, embedding=embedding, features=features)

        out = self.run(x, time, embedding=embedding, features=features)
        out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
        return out_masked + (out - out_masked) * embedding_scale
|