| import math |
| from dataclasses import dataclass |
| from typing import Any, Dict |
|
|
| import torch |
| from einops import rearrange |
| from einops.layers.torch import Rearrange |
| from torch import Tensor, nn |
| from torch.nn.attention import SDPBackend, sdpa_kernel |
|
|
|
|
def get_image_ids(bs: int, h: int, w: int, patch_size: int, device: torch.device) -> Tensor:
    """Return (bs, n_patches, 2) float ids holding each patch's (row, col).

    The grid is (h // patch_size) x (w // patch_size), flattened row-major,
    and the same id table is repeated for every batch element.
    """
    grid_h = h // patch_size
    grid_w = w // patch_size
    ids = torch.zeros(grid_h, grid_w, 2, device=device)
    ids[..., 0] = torch.arange(grid_h, device=device).unsqueeze(1)
    ids[..., 1] = torch.arange(grid_w, device=device).unsqueeze(0)
    flat_ids = ids.view(grid_h * grid_w, 2)
    return flat_ids[None].repeat(bs, 1, 1)
|
|
|
|
def apply_rope(xq: Tensor, freqs_cis: Tensor) -> Tensor:
    """Rotate consecutive channel pairs of `xq` by the 2x2 matrices in `freqs_cis`.

    The last dim of `xq` is viewed as (d/2, 2) pairs; each pair is multiplied
    by its rotation matrix and the result is flattened back to `xq`'s shape
    and dtype. Computation happens in float32 for stability.
    """
    original_shape = xq.shape
    pairs = xq.float().unflatten(-1, (-1, 2)).unsqueeze(-2)  # (..., d/2, 1, 2)
    rotated = freqs_cis[..., 0] * pairs[..., 0] + freqs_cis[..., 1] * pairs[..., 1]
    return rotated.reshape(original_shape).type_as(xq)
|
|
|
|
| def _sdpa(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor: |
| |
| if q.is_cuda and q.dtype in (torch.float16, torch.bfloat16): |
| try: |
| with sdpa_kernel(SDPBackend.CUDNN_ATTENTION): |
| return torch.nn.functional.scaled_dot_product_attention( |
| q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask |
| ) |
| except RuntimeError: |
| pass |
| return torch.nn.functional.scaled_dot_product_attention( |
| q.contiguous(), k.contiguous(), v.contiguous(), attn_mask=attn_mask |
| ) |
|
|
|
|
class EmbedND(nn.Module):
    """Multi-axis rotary position embedding.

    Each axis i of the incoming id tensor gets its own rotary table of width
    ``axes_dim[i]``; the per-axis tables are concatenated along the frequency
    dimension, and a head axis is inserted for broadcasting.
    """

    def __init__(self, dim: int, theta: int, axes_dim: list[int]):
        super().__init__()
        self.dim = dim
        self.theta = theta
        self.axes_dim = axes_dim

    def rope(self, pos: Tensor, dim: int, theta: int) -> Tensor:
        """Return (..., dim/2, 2, 2) float32 rotation matrices for `pos`."""
        assert dim % 2 == 0
        # Frequencies computed in float64 before the final float32 cast.
        exponent = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
        inv_freq = 1.0 / (theta**exponent)
        angles = pos.unsqueeze(-1) * inv_freq.unsqueeze(0)
        cos, sin = torch.cos(angles), torch.sin(angles)
        # Rows of each 2x2 matrix: [cos, -sin] and [sin, cos].
        mats = torch.stack([cos, -sin, sin, cos], dim=-1)
        return mats.unflatten(-1, (2, 2)).float()

    def forward(self, ids: Tensor) -> Tensor:
        """ids: (b, n, n_axes) positions -> (b, 1, n, sum(axes_dim)/2, 2, 2)."""
        tables = []
        for axis in range(ids.shape[-1]):
            tables.append(self.rope(ids[:, :, axis], self.axes_dim[axis], self.theta))
        return torch.cat(tables, dim=-3).unsqueeze(1)
|
|
|
|
def timestep_embedding(t: Tensor, dim: int, max_period: int = 10000, time_factor: float = 1000.0) -> Tensor:
    """
    Create sinusoidal timestep embeddings.
    :param t: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :param time_factor: multiplier applied to `t` before embedding.
    :return: an (N, D) Tensor of positional embeddings.
    """
    scaled_t = time_factor * t
    half = dim // 2
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = scaled_t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad odd output dims with a zero column.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding
|
|
|
|
class MLPEmbedder(nn.Module):
    """Two-layer MLP (Linear -> SiLU -> Linear) mapping in_dim to hidden_dim."""

    def __init__(self, in_dim: int, hidden_dim: int):
        super().__init__()
        self.in_layer = nn.Linear(in_dim, hidden_dim, bias=True)
        self.silu = nn.SiLU()
        self.out_layer = nn.Linear(hidden_dim, hidden_dim, bias=True)

    def forward(self, x: Tensor) -> Tensor:
        hidden = self.silu(self.in_layer(x))
        return self.out_layer(hidden)
|
|
|
|
class RMSNorm(torch.nn.Module):
    """Root-mean-square normalization with a learnable per-channel scale.

    Statistics are computed in float32 (eps=1e-6) and the result is cast back
    to the input dtype.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(dim))

    def forward(self, x: Tensor) -> Tensor:
        input_dtype = x.dtype
        xf = x.float()
        inv_rms = torch.rsqrt(xf.pow(2).mean(dim=-1, keepdim=True) + 1e-6)
        return (xf * inv_rms * self.scale).to(dtype=input_dtype)
|
|
|
|
class QKNorm(torch.nn.Module):
    """Independent RMS normalization for attention queries and keys.

    Outputs are cast to `v`'s dtype so all three attention inputs match.
    """

    def __init__(self, dim: int):
        super().__init__()
        self.query_norm = RMSNorm(dim)
        self.key_norm = RMSNorm(dim)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> tuple[Tensor, Tensor]:
        normed_q = self.query_norm(q).to(v)
        normed_k = self.key_norm(k).to(v)
        return normed_q, normed_k
|
|
|
|
@dataclass
class ModulationOut:
    """Shift/scale/gate triple used to modulate one sub-layer of a block."""

    # Additive shift applied after the pre-layer normalization.
    shift: Tensor
    # Multiplicative scale, used as (1 + scale) on the normalized input.
    scale: Tensor
    # Gate multiplying the sub-layer output before the residual addition.
    gate: Tensor
|
|
|
|
class Modulation(nn.Module):
    """Projects a conditioning vector into two (shift, scale, gate) triples.

    The linear layer is zero-initialized, so at the start of training the
    modulation is the identity (shift=0, scale=0, gate=0).
    """

    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, 6 * dim, bias=True)
        nn.init.constant_(self.lin.weight, 0)
        nn.init.constant_(self.lin.bias, 0)

    def forward(self, vec: Tensor) -> tuple[ModulationOut, ModulationOut]:
        # (B, dim) -> (B, 1, 6*dim) -> six (B, 1, dim) chunks.
        chunks = self.lin(nn.functional.silu(vec)).unsqueeze(1).chunk(6, dim=-1)
        attn_mod = ModulationOut(shift=chunks[0], scale=chunks[1], gate=chunks[2])
        mlp_mod = ModulationOut(shift=chunks[3], scale=chunks[4], gate=chunks[5])
        return attn_mod, mlp_mod
|
|
class PRXBlock(nn.Module):
    """
    A PRX block: modulated image self-attention with text (and optional
    spatial-conditioning / guiding-image) keys-values, followed by a gated MLP.

    The image stream is modulated (shift/scale/gate) by a conditioning vector
    `vec`; text tokens contribute keys/values only. Both sub-layers are
    residual.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
        mlp_ratio: float = 4.0,
        qk_scale: float | None = None,
        use_image_guidance: bool = False,
        image_guidance_hidden_size: int = 1280,
    ):
        super().__init__()

        # Hints consumed by the training framework (FSDP wrapping /
        # activation checkpointing).
        self._fsdp_wrap = True
        self._activation_checkpointing = True

        self.hidden_dim = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads
        self.scale = qk_scale or self.head_dim**-0.5

        self.use_image_guidance = use_image_guidance
        self.mlp_hidden_dim = int(hidden_size * mlp_ratio)
        self.hidden_size = hidden_size

        # Image stream: pre-norm, fused QKV projection, output projection.
        self.img_pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.img_qkv_proj = nn.Linear(hidden_size, hidden_size * 3, bias=False)
        self.attn_out = nn.Linear(hidden_size, hidden_size, bias=False)
        self.qk_norm = QKNorm(self.head_dim)

        # Text stream contributes keys/values only (no queries).
        self.txt_kv_proj = nn.Linear(hidden_size, hidden_size * 2, bias=False)
        self.k_norm = RMSNorm(self.head_dim)

        # Optional cross-attention to a guiding image; its output projection
        # is zero-initialized so guidance starts as a no-op.
        if self.use_image_guidance:
            self.guiding_img_kv_proj = nn.Linear(image_guidance_hidden_size, hidden_size * 2, bias=False)
            self.guiding_img_norm = RMSNorm(self.head_dim)
            self.attn_img_out = nn.Linear(hidden_size, hidden_size, bias=False)
            nn.init.zeros_(self.attn_img_out.weight)

        # Gated MLP: down(act(gate(x)) * up(x)) with tanh-approximated GELU.
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.gate_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.up_proj = nn.Linear(hidden_size, self.mlp_hidden_dim, bias=False)
        self.down_proj = nn.Linear(self.mlp_hidden_dim, hidden_size, bias=False)
        self.mlp_act = nn.GELU(approximate="tanh")

        self.modulation = Modulation(hidden_size)

        # Installed externally when spatial conditioning is used.
        self.spatial_cond_kv_proj: None | nn.Linear = None

    def attn_forward(
        self,
        img: Tensor,
        txt: Tensor,
        pe: Tensor,
        modulation: ModulationOut,
        spatial_conditioning: None | Tensor = None,
        image_conditioning: None | Tensor = None,
        image_guidance_scale: None | Tensor = None,
        attention_mask: None | Tensor = None,
    ) -> Tensor:
        """Attention with image queries over [cond | text | image] keys/values.

        `attention_mask`, if given, is a (B, L_txt) boolean-like mask over the
        text tokens; conditioning and image tokens are always attendable.
        """
        # Shift/scale-modulate the normalized image tokens.
        img_mod = (1 + modulation.scale) * self.img_pre_norm(img) + modulation.shift

        img_qkv = self.img_qkv_proj(img_mod)
        img_q, img_k, img_v = rearrange(img_qkv, "B L (K H D) -> K B H L D", K=3, H=self.num_heads)
        img_q, img_k = self.qk_norm(img_q, img_k, img_v)

        # Text keys/values (text has no queries in this block).
        txt_kv = self.txt_kv_proj(txt)
        txt_k, txt_v = rearrange(txt_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
        txt_k = self.k_norm(txt_k)

        # Rotary embedding on image q/k, then joint [txt | img] keys/values.
        img_q, img_k = apply_rope(img_q, pe), apply_rope(img_k, pe)
        k = torch.cat((txt_k, img_k), dim=2)
        v = torch.cat((txt_v, img_v), dim=2)

        # Optional spatial-conditioning tokens are prepended to keys/values.
        cond_len = 0
        if self.spatial_cond_kv_proj is not None:
            assert spatial_conditioning is not None
            cond_kv = self.spatial_cond_kv_proj(spatial_conditioning)
            cond_k, cond_v = rearrange(cond_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
            cond_k = apply_rope(cond_k, pe)
            cond_len = cond_k.shape[2]

            k = torch.cat((cond_k, k), dim=2)
            v = torch.cat((cond_v, v), dim=2)

        # BUG FIX: attn_mask must be defined on the mask-free path too.
        # Previously it was only assigned inside the branch below, so calling
        # this method with attention_mask=None raised UnboundLocalError.
        attn_mask = None
        if attention_mask is not None:
            bs, _, l_img, _ = img_q.shape
            l_txt = txt_k.shape[2]

            assert attention_mask.dim() == 2, f"Unsupported attention_mask shape: {attention_mask.shape}"
            assert (
                attention_mask.shape[-1] == l_txt
            ), f"attention_mask last dim {attention_mask.shape[-1]} must equal text length {l_txt}"

            device = img_q.device

            # Image and conditioning tokens are always visible; only text
            # tokens can be masked out.
            ones_img = torch.ones((bs, l_img), dtype=torch.bool, device=device)
            cond_mask = torch.ones((bs, cond_len), dtype=torch.bool, device=device)

            mask_parts = [
                cond_mask,
                attention_mask.to(torch.bool),
                ones_img,
            ]
            joint_mask = torch.cat(mask_parts, dim=-1)

            # Broadcast to (B, H, L_img, L_kv) as expected by SDPA.
            attn_mask = joint_mask[:, None, None, :].expand(-1, self.num_heads, l_img, -1)

        attn = _sdpa(img_q, k, v, attn_mask=attn_mask)
        attn = rearrange(attn, "B H L D -> B L (H D)")
        attn = self.attn_out(attn)

        # Optional cross-attention to the guiding image, scaled and added.
        if image_conditioning is not None:
            assert self.use_image_guidance
            assert image_guidance_scale is not None
            guiding_img_kv = self.guiding_img_kv_proj(image_conditioning)
            guiding_img_k, guiding_img_v = rearrange(guiding_img_kv, "B L (K H D) -> K B H L D", K=2, H=self.num_heads)
            guiding_img_k = self.guiding_img_norm(guiding_img_k)
            img_attn = torch.nn.functional.scaled_dot_product_attention(img_q, guiding_img_k, guiding_img_v)
            img_attn = rearrange(img_attn, "B H L D -> B L (H D)")
            img_attn = self.attn_img_out(img_attn)
            attn = attn + img_attn * image_guidance_scale[..., None, None].to(guiding_img_k.dtype)

        return attn

    def ffn_forward(self, x: Tensor, modulation: ModulationOut) -> Tensor:
        """Modulated gated MLP: down(act(gate(x)) * up(x))."""
        x = (1 + modulation.scale) * self.post_attention_layernorm(x) + modulation.shift
        return self.down_proj(self.mlp_act(self.gate_proj(x)) * self.up_proj(x))

    def forward(
        self,
        img: Tensor,
        txt: Tensor,
        vec: Tensor,
        pe: Tensor,
        spatial_conditioning: Tensor | None = None,
        image_conditioning: Tensor | None = None,
        image_guidance_scale: Tensor | None = None,
        attention_mask: Tensor | None = None,
        **_: dict[str, Any],
    ) -> Tensor:
        """Residual attention + MLP update of the image stream.

        `vec` conditions both sub-layers through shift/scale/gate modulation.
        Extra keyword arguments are accepted and ignored.
        """
        mod_attn, mod_mlp = self.modulation(vec)

        img = img + mod_attn.gate * self.attn_forward(
            img,
            txt,
            pe,
            mod_attn,
            image_conditioning=image_conditioning,
            image_guidance_scale=image_guidance_scale,
            spatial_conditioning=spatial_conditioning,
            attention_mask=attention_mask,
        )
        img = img + mod_mlp.gate * self.ffn_forward(img, mod_mlp)
        return img
|
|
class LastLayer(nn.Module):
    """Final adaLN-modulated projection from hidden features to patch outputs.

    Both the modulation projection and the output projection are
    zero-initialized, so the layer initially emits zeros.
    """

    def __init__(self, hidden_size: int, patch_size: int, out_channels: int):
        super().__init__()
        self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
        self.linear = nn.Linear(hidden_size, patch_size * patch_size * out_channels, bias=True)
        self.adaLN_modulation = nn.Sequential(nn.SiLU(), nn.Linear(hidden_size, 2 * hidden_size, bias=True))

        for module in (self.adaLN_modulation[1], self.linear):
            nn.init.constant_(module.weight, 0)
            nn.init.constant_(module.bias, 0)

    def forward(self, x: Tensor, vec: Tensor) -> Tensor:
        shift, scale = self.adaLN_modulation(vec).chunk(2, dim=1)
        modulated = self.norm_final(x) * (1 + scale[:, None, :]) + shift[:, None, :]
        return self.linear(modulated)
|
|