|
|
| import torch |
| import torch.nn.functional as F |
| import torch.nn as nn |
| from torch import Tensor, FloatTensor |
| from typing import Optional, Callable, Tuple, Dict, List, Any, Union |
|
|
| import einops |
| from einops import rearrange |
| import copy |
| import math |
| import comfy.ldm.common_dit |
|
|
|
|
| from .latents import gaussian_blur_2d, median_blur_2d |
|
|
| |
| class StyleTransfer: |
| def __init__(self, |
| style_method = "WCT", |
| embedder_method = None, |
| patch_size = 1, |
| pinv_dtype = torch.float64, |
| dtype = torch.float64, |
| ): |
| self.style_method = style_method |
| |
| self.embedder_method = None |
| self.unembedder_method = None |
|
|
| if embedder_method is not None: |
| self.set_embedder_method(embedder_method) |
| |
| self.patch_size = patch_size |
| |
| |
| |
| self.pinv_dtype = pinv_dtype |
| self.dtype = dtype |
| |
|         self.patchify_method   = None    # renamed from self.patchify so the patchify()/unpatchify() methods below are not shadowed |
|         self.unpatchify_method = None |
| |
| self.orig_shape = None |
| self.grid_sizes = None |
| |
| |
| |
| |
|
|
| def set_patchify_method(self, patchify_method=None): |
| self.patchify_method = patchify_method |
|
|
| def set_unpatchify_method(self, unpatchify_method=None): |
| self.unpatchify_method = unpatchify_method |
| |
| def set_embedder_method(self, embedder_method): |
| self.embedder_method = copy.deepcopy(embedder_method).to(self.pinv_dtype) |
| self.W = self.embedder_method.weight |
| self.B = self.embedder_method.bias |
| |
| if isinstance(embedder_method, nn.Linear): |
| self.unembedder_method = self.invert_linear |
| |
| elif isinstance(embedder_method, nn.Conv2d): |
| self.unembedder_method = self.invert_conv2d |
| |
| elif isinstance(embedder_method, nn.Conv3d): |
| self.unembedder_method = self.invert_conv3d |
| |
| def set_patch_size(self, patch_size): |
| self.patch_size = patch_size |
|
|
| def unpatchify(self, x: Tensor) -> List[Tensor]: |
| x_arr = [] |
| for i, img_size in enumerate(self.img_sizes): |
| pH, pW = img_size |
| x_arr.append( |
| einops.rearrange(x[i, :pH*pW].reshape(1, pH, pW, -1), 'B H W (p1 p2 C) -> B C (H p1) (W p2)', |
| p1=self.patch_size, p2=self.patch_size) |
| ) |
| x = torch.cat(x_arr, dim=0) |
| return x |
|
|
| def patchify(self, x: Tensor): |
| x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size)) |
| |
| pH, pW = x.shape[-2] // self.patch_size, x.shape[-1] // self.patch_size |
| self.img_sizes = [[pH, pW]] * x.shape[0] |
| x = einops.rearrange(x, 'B C (H p1) (W p2) -> B (H W) (p1 p2 C)', p1=self.patch_size, p2=self.patch_size) |
| return x |
| |
| |
| def embedder(self, x): |
| if isinstance(self.embedder_method, nn.Linear): |
| x = self.patchify(x) |
| |
| self.orig_shape = x.shape |
| x = self.embedder_method(x) |
| self.grid_sizes = x.shape[2:] |
| |
| |
| |
| |
| |
| return x |
| |
| def unembedder(self, x): |
| |
| |
| |
| x = self.unembedder_method(x) |
| return x |
| |
| |
| def invert_linear(self, x : torch.Tensor,) -> torch.Tensor: |
| x = x.to(self.pinv_dtype) |
| |
| x = (x - self.B) @ torch.linalg.pinv(self.W).T |
| |
| return x.to(self.dtype) |
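| |
|     # Hedged usage sketch (names and sizes are illustrative, not part of this |
|     # module's API): for a Linear patch embedder, invert_linear recovers the |
|     # pre-embedding tokens up to pseudoinverse error when out_features >= in_features. |
|     # |
|     #   emb   = nn.Linear(64, 3072) |
|     #   st    = StyleTransfer(embedder_method=emb) |
|     #   x     = torch.randn(1, 4096, 64, dtype=torch.float64) |
|     #   z     = st.embedder_method(x)      # float64 copy made by set_embedder_method |
|     #   x_rec = st.invert_linear(z)        # ~= x up to numerical error |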
|
|
| |
| |
| def invert_conv2d(self, z: torch.Tensor,) -> torch.Tensor: |
| z = z.to(self.pinv_dtype) |
| conv = self.embedder_method |
| |
| B, C_in, H, W = self.orig_shape |
| C_out, _, kH, kW = conv.weight.shape |
| stride_h, stride_w = conv.stride |
| pad_h, pad_w = conv.padding |
|
|
| b = conv.bias.view(1, C_out, 1, 1).to(z) |
| z_nobias = z - b |
|
|
| W_flat = conv.weight.view(C_out, -1).to(z) |
| W_pinv = torch.linalg.pinv(W_flat) |
|
|
| Bz, Co, Hp, Wp = z_nobias.shape |
| z_flat = z_nobias.reshape(Bz, Co, -1) |
|
|
| x_patches = W_pinv @ z_flat |
|
|
| x_sum = F.fold( |
| x_patches, |
| output_size=(H + 2*pad_h, W + 2*pad_w), |
| kernel_size=(kH, kW), |
| stride=(stride_h, stride_w), |
| ) |
| ones = torch.ones_like(x_patches) |
| count = F.fold( |
| ones, |
| output_size=(H + 2*pad_h, W + 2*pad_w), |
| kernel_size=(kH, kW), |
| stride=(stride_h, stride_w), |
| ) |
|
|
| x_recon = x_sum / count.clamp(min=1e-6) |
| if pad_h > 0 or pad_w > 0: |
| x_recon = x_recon[..., pad_h:pad_h+H, pad_w:pad_w+W] |
|
|
| return x_recon.to(self.dtype) |
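| |
|     # Hedged note: for a standard non-overlapping patch embed (stride == kernel |
|     # size), `count` above is 1 everywhere and F.fold simply re-tiles the patches, |
|     # so the inversion is exact up to pinv error. Illustrative sketch: |
|     # |
|     #   conv  = nn.Conv2d(16, 3072, kernel_size=2, stride=2) |
|     #   st    = StyleTransfer(embedder_method=conv) |
|     #   x     = torch.randn(1, 16, 64, 64, dtype=torch.float64) |
|     #   z     = st.embedder(x)             # records orig_shape for the inversion |
|     #   x_rec = st.invert_conv2d(z)        # ~= x when C_out >= C_in * kH * kW |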
|
|
|
|
|
|
| def invert_conv3d(self, z: torch.Tensor, ) -> torch.Tensor: |
| z = z.to(self.pinv_dtype) |
| conv = self.embedder_method |
| grid_sizes = self.grid_sizes |
|
|
| B, C_in, D, H, W = self.orig_shape |
|         pD, pH, pW = self.patch_size   # assumes patch_size is a (pD, pH, pW) tuple for Conv3d embedders |
| sD, sH, sW = pD, pH, pW |
|
|
| if z.ndim == 3: |
| |
| S = z.shape[1] |
| if grid_sizes is None: |
| Dp = D // pD |
| Hp = H // pH |
| Wp = W // pW |
| else: |
| Dp, Hp, Wp = grid_sizes |
| C_out = z.shape[2] |
| z = z.transpose(1, 2).reshape(B, C_out, Dp, Hp, Wp) |
| else: |
| B2, C_out, Dp, Hp, Wp = z.shape |
| assert B2 == B, "Batch size mismatch... ya sharked it." |
|
|
| b = conv.bias.view(1, C_out, 1, 1, 1) |
| z_nobias = z - b |
|
|
| |
| w3 = conv.weight |
| w2 = w3.squeeze(2) |
| out_ch, in_ch, kH, kW = w2.shape |
| W_flat = w2.view(out_ch, -1) |
| W_pinv = torch.linalg.pinv(W_flat) |
|
|
| |
| z2 = z_nobias.permute(0,2,1,3,4).reshape(B*Dp, C_out, Hp, Wp) |
|
|
| |
| z_flat = z2.reshape(B*Dp, C_out, -1) |
| x_patches = W_pinv @ z_flat |
|
|
| |
| x2 = F.fold( |
| x_patches, |
| output_size=(H, W), |
| kernel_size=(pH, pW), |
| stride=(sH, sW) |
| ) |
|
|
| |
| x2 = x2.reshape(B, Dp, in_ch, H, W) |
| x_recon = x2.permute(0,2,1,3,4).contiguous() |
| return x_recon.to(self.dtype) |
|
|
|
|
|
|
| def adain_seq_inplace(self, content: torch.Tensor, style: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: |
| mean_c = content.mean(1, keepdim=True) |
| std_c = content.std (1, keepdim=True).add_(eps) |
| mean_s = style.mean (1, keepdim=True) |
| std_s = style.std (1, keepdim=True).add_(eps) |
|
|
| content.sub_(mean_c).div_(std_c).mul_(std_s).add_(mean_s) |
| return content |
|
|
|
|
|
|
|
|
|
|
|
|
| class StyleWCT: |
| def __init__(self, dtype=torch.float64, use_svd=False,): |
| self.dtype = dtype |
| self.use_svd = use_svd |
| self.y0_adain_embed = None |
| self.mu_s = None |
| self.y0_color = None |
| self.spatial_shape = None |
| |
| def whiten(self, f_s_centered: torch.Tensor, set=False): |
| cov = (f_s_centered.T.double() @ f_s_centered.double()) / (f_s_centered.size(0) - 1) |
|
|
| if self.use_svd: |
| U_svd, S_svd, Vh_svd = torch.linalg.svd(cov + 1e-5 * torch.eye(cov.size(0), dtype=cov.dtype, device=cov.device)) |
| S_eig = S_svd |
| U_eig = U_svd |
| else: |
| S_eig, U_eig = torch.linalg.eigh(cov + 1e-5 * torch.eye(cov.size(0), dtype=cov.dtype, device=cov.device)) |
| |
| if set: |
| S_eig_root = S_eig.clamp(min=0).sqrt() |
| else: |
| S_eig_root = S_eig.clamp(min=0).rsqrt() |
| |
| whiten = U_eig @ torch.diag(S_eig_root) @ U_eig.T |
| return whiten.to(f_s_centered) |
|
|
| def set(self, y0_adain_embed: torch.Tensor, spatial_shape=None): |
| if self.y0_adain_embed is None or self.y0_adain_embed.shape != y0_adain_embed.shape or torch.norm(self.y0_adain_embed - y0_adain_embed) > 0: |
| self.y0_adain_embed = y0_adain_embed.clone() |
| if spatial_shape is not None: |
| self.spatial_shape = spatial_shape |
| |
| f_s = y0_adain_embed[0] |
| self.mu_s = f_s.mean(dim=0, keepdim=True) |
| f_s_centered = f_s - self.mu_s |
| |
| self.y0_color = self.whiten(f_s_centered, set=True) |
| |
| def get(self, denoised_embed: torch.Tensor): |
| for wct_i in range(denoised_embed.shape[0]): |
| f_c = denoised_embed[wct_i] |
| mu_c = f_c.mean(dim=0, keepdim=True) |
| f_c_centered = f_c - mu_c |
|
|
| whiten = self.whiten(f_c_centered) |
|
|
| f_c_whitened = f_c_centered @ whiten.T |
| f_cs = f_c_whitened @ self.y0_color.T + self.mu_s |
| |
| denoised_embed[wct_i] = f_cs |
| |
| return denoised_embed |
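| |
| # Hedged usage sketch (shapes illustrative): StyleWCT whitens the content |
| # features, then re-colors them with the style covariance and mean. |
| # |
| #   wct           = StyleWCT() |
| #   style_feats   = torch.randn(1, 4096, 64, dtype=torch.float64)   # (B, tokens, C) |
| #   content_feats = torch.randn(1, 4096, 64, dtype=torch.float64) |
| #   wct.set(style_feats)               # caches coloring transform and mean |
| #   stylized = wct.get(content_feats)  # modified in place and returned |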
|
|
|
|
|
|
|
|
| class WaveletStyleWCT(StyleWCT): |
| def set(self, y0_adain_embed: torch.Tensor, h_len, w_len): |
| if self.y0_adain_embed is None or self.y0_adain_embed.shape != y0_adain_embed.shape or torch.norm(self.y0_adain_embed - y0_adain_embed) > 0: |
| self.y0_adain_embed = y0_adain_embed.clone() |
| |
|             B, HW, C = y0_adain_embed.shape |
|             spatial = rearrange(y0_adain_embed, "b (h w) c -> b c h w", h=h_len, w=w_len)   # true permute; a raw .view() would scramble channels and space |
|             LL, _, _, _ = haar_wavelet_decompose(spatial) |
| |
|             flat = rearrange(LL, "b c h w -> b (h w) c") |
|
|
| f_s = flat[0] |
| self.mu_s = f_s.mean(dim=0, keepdim=True) |
| f_s_centered = f_s - self.mu_s |
| self.y0_color = self.whiten(f_s_centered, set=True) |
| |
| |
| def get(self, denoised_embed: torch.Tensor, h_len, w_len, stylize_highfreq=False): |
|
|
|         B, HW, C = denoised_embed.shape |
| |
|         denoised_embed = rearrange(denoised_embed, "b (h w) c -> b c h w", h=h_len, w=w_len) |
| |
| for i in range(B): |
| x = denoised_embed[i:i+1] |
| LL, LH, HL, HH = haar_wavelet_decompose(x) |
|
|
|             def process_band(band): |
|                 flat   = rearrange(band, "b c h w -> b (h w) c") |
|                 styled = super(WaveletStyleWCT, self).get(flat) |
|                 return rearrange(styled, "b (h w) c -> b c h w", h=band.shape[-2], w=band.shape[-1]) |
|
|
| LL_styled = process_band(LL) |
|
|
| if stylize_highfreq: |
| LH_styled = process_band(LH) |
| HL_styled = process_band(HL) |
| HH_styled = process_band(HH) |
| else: |
| LH_styled, HL_styled, HH_styled = LH, HL, HH |
|
|
| recon = haar_wavelet_reconstruct(LL_styled, LH_styled, HL_styled, HH_styled) |
| denoised_embed[i] = recon.squeeze(0) |
|
|
|         return rearrange(denoised_embed, "b c h w -> b (h w) c") |
|
|
|
|
|
|
| def haar_wavelet_decompose(x): |
|     """ |
|     Haar decomposition, scaled so that haar_wavelet_reconstruct inverts it exactly. |
|     Input:  [B, C, H, W] |
|     Output: LL, LH, HL, HH with shape [B, C, H//2, W//2] |
|     """ |
| if x.dtype != torch.float32: |
| x = x.float() |
| |
| B, C, H, W = x.shape |
| assert H % 2 == 0 and W % 2 == 0, "Input must have even H, W" |
|
|
| |
| norm = 1 / 2**0.5 |
|
|
| x00 = x[:, :, 0::2, 0::2] |
| x01 = x[:, :, 0::2, 1::2] |
| x10 = x[:, :, 1::2, 0::2] |
| x11 = x[:, :, 1::2, 1::2] |
|
|
| LL = (x00 + x01 + x10 + x11) * norm * 0.5 |
| LH = (x00 - x01 + x10 - x11) * norm * 0.5 |
| HL = (x00 + x01 - x10 - x11) * norm * 0.5 |
| HH = (x00 - x01 - x10 + x11) * norm * 0.5 |
|
|
| return LL, LH, HL, HH |
|
|
| def haar_wavelet_reconstruct(LL, LH, HL, HH): |
|     """ |
|     Inverse Haar reconstruction (exact inverse of haar_wavelet_decompose). |
|     Input:  LL, LH, HL, HH [B, C, H, W] |
|     Output: Reconstructed [B, C, H*2, W*2] |
|     """ |
| norm = 1 / 2**0.5 |
| B, C, H, W = LL.shape |
|
|
| x00 = (LL + LH + HL + HH) * norm |
| x01 = (LL - LH + HL - HH) * norm |
| x10 = (LL + LH - HL - HH) * norm |
| x11 = (LL - LH - HL + HH) * norm |
|
|
| out = torch.zeros(B, C, H * 2, W * 2, device=LL.device, dtype=LL.dtype) |
| out[:, :, 0::2, 0::2] = x00 |
| out[:, :, 0::2, 1::2] = x01 |
| out[:, :, 1::2, 0::2] = x10 |
| out[:, :, 1::2, 1::2] = x11 |
|
|
| return out |
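| |
| # Hedged sanity check: the decompose/reconstruct pair above is scaled to be an |
| # exact round trip (forward bands carry 1/(2*sqrt(2)), the inverse 1/sqrt(2)): |
| # |
| #   x = torch.randn(1, 4, 8, 8) |
| #   LL, LH, HL, HH = haar_wavelet_decompose(x) |
| #   assert torch.allclose(haar_wavelet_reconstruct(LL, LH, HL, HH), x, atol=1e-5) |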
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """ |
| |
| class StyleFeatures: |
| def __init__(self, dtype=torch.float64,): |
| self.dtype = dtype |
| |
| def set(self, y0_adain_embed: torch.Tensor): |
| |
| def get(self, denoised_embed: torch.Tensor): |
| |
| return "Norpity McNerp" |
| |
| """ |
|
|
|
|
|
|
|
|
| class Retrojector: |
| def __init__(self, proj=None, patch_size=2, pinv_dtype=torch.float64, dtype=torch.float64, ENDO=False): |
| self.proj = proj |
| self.patch_size = patch_size |
| self.pinv_dtype = pinv_dtype |
| self.dtype = dtype |
| |
| self.LINEAR = isinstance(proj, nn.Linear) |
| self.CONV2D = isinstance(proj, nn.Conv2d) |
| self.CONV3D = isinstance(proj, nn.Conv3d) |
| self.ENDO = ENDO |
| self.W = proj.weight.data.to(dtype=pinv_dtype).cuda() |
| |
| if self.LINEAR: |
| self.W_inv = torch.linalg.pinv(self.W.cuda()) |
| elif self.CONV2D: |
| C_out, _, kH, kW = proj.weight.shape |
| W_flat = proj.weight.view(C_out, -1).to(dtype=pinv_dtype) |
| self.W_inv = torch.linalg.pinv(W_flat.cuda()) |
| |
|         if proj.bias is None: |
|             if self.LINEAR: |
|                 bias_size = proj.out_features |
|             else: |
|                 bias_size = proj.out_channels |
|             self.b = torch.zeros(bias_size, dtype=pinv_dtype, device=self.W.device)   # W always exists; W_inv is not built for Conv3d |
|         else: |
|             self.b = proj.bias.data.to(dtype=pinv_dtype).to(self.W.device) |
| |
|     def embed(self, img: torch.Tensor): |
|         img = comfy.ldm.common_dit.pad_to_patch_size(img, (self.patch_size, self.patch_size)) |
| |
|         # compute the token grid from the padded size so unembed's rearrange matches |
|         self.h = img.shape[-2] // self.patch_size |
|         self.w = img.shape[-1] // self.patch_size |
| |
| if self.CONV2D: |
| self.orig_shape = img.shape |
| img_embed = F.conv2d( |
| img.to(self.W), |
| weight=self.W, |
| bias=self.b, |
| stride=self.proj.stride, |
| padding=self.proj.padding |
| ) |
| |
| img_embed = rearrange(img_embed, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=1, pw=1) |
| |
| elif self.LINEAR: |
| if img.ndim == 4: |
| img = rearrange(img, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=self.patch_size, pw=self.patch_size) |
| if self.ENDO: |
| img_embed = F.linear(img.to(self.b) - self.b, self.W_inv) |
| else: |
| img_embed = F.linear(img.to(self.W), self.W, self.b) |
| |
| return img_embed.to(img) |
| |
| def unembed(self, img_embed: torch.Tensor): |
| if self.CONV2D: |
| |
| img_embed = rearrange(img_embed, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=self.h, w=self.w, ph=1, pw=1) |
| img = self.invert_conv2d(img_embed) |
| |
| elif self.LINEAR: |
| if self.ENDO: |
| img = F.linear(img_embed.to(self.W), self.W, self.b) |
| else: |
| img = F.linear(img_embed.to(self.b) - self.b, self.W_inv) |
| if img.ndim == 3: |
| img = rearrange(img, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=self.h, w=self.w, ph=self.patch_size, pw=self.patch_size) |
| |
| return img.to(img_embed) |
| |
| def invert_conv2d(self, z: torch.Tensor,) -> torch.Tensor: |
| z_dtype = z.dtype |
| z = z.to(self.pinv_dtype) |
| conv = self.proj |
| |
| B, C_in, H, W = self.orig_shape |
| C_out, _, kH, kW = conv.weight.shape |
| stride_h, stride_w = conv.stride |
| pad_h, pad_w = conv.padding |
|
|
| b = conv.bias.view(1, C_out, 1, 1).to(z) |
| z_nobias = z - b |
|
|
| |
| |
|
|
| Bz, Co, Hp, Wp = z_nobias.shape |
| z_flat = z_nobias.reshape(Bz, Co, -1) |
|
|
| x_patches = self.W_inv @ z_flat |
|
|
| x_sum = F.fold( |
| x_patches, |
|             output_size=(H + 2*pad_h, W + 2*pad_w), |
| kernel_size=(kH, kW), |
| stride=(stride_h, stride_w), |
| ) |
| ones = torch.ones_like(x_patches) |
| count = F.fold( |
| ones, |
| output_size=(H + 2*pad_h, W + 2*pad_w), |
| kernel_size=(kH, kW), |
| stride=(stride_h, stride_w), |
| ) |
|
|
| x_recon = x_sum / count.clamp(min=1e-6) |
| if pad_h > 0 or pad_w > 0: |
| x_recon = x_recon[..., pad_h:pad_h+H, pad_w:pad_w+W] |
|
|
| return x_recon.to(z_dtype) |
| |
| def invert_patch_embedding(self, z: torch.Tensor, original_shape: torch.Size, grid_sizes: Optional[Tuple[int,int,int]] = None) -> torch.Tensor: |
|
|
| B, C_in, D, H, W = original_shape |
|         pD, pH, pW = self.patch_size   # assumes patch_size is a (pD, pH, pW) tuple for Conv3d patch embeds |
| sD, sH, sW = pD, pH, pW |
|
|
| if z.ndim == 3: |
| |
| S = z.shape[1] |
| if grid_sizes is None: |
| Dp = D // pD |
| Hp = H // pH |
| Wp = W // pW |
| else: |
| Dp, Hp, Wp = grid_sizes |
| C_out = z.shape[2] |
| z = z.transpose(1, 2).reshape(B, C_out, Dp, Hp, Wp) |
| else: |
| B2, C_out, Dp, Hp, Wp = z.shape |
| assert B2 == B, "Batch size mismatch... ya sharked it." |
|
|
| |
|         b = self.proj.bias.view(1, C_out, 1, 1, 1) |
| z_nobias = z - b |
|
|
| |
|         w3 = self.proj.weight |
| w2 = w3.squeeze(2) |
| out_ch, in_ch, kH, kW = w2.shape |
| W_flat = w2.view(out_ch, -1) |
| W_pinv = torch.linalg.pinv(W_flat) |
|
|
| |
| z2 = z_nobias.permute(0,2,1,3,4).reshape(B*Dp, C_out, Hp, Wp) |
|
|
| |
| z_flat = z2.reshape(B*Dp, C_out, -1) |
| x_patches = W_pinv @ z_flat |
|
|
| |
| x2 = F.fold( |
| x_patches, |
| output_size=(H, W), |
| kernel_size=(pH, pW), |
| stride=(sH, sW) |
| ) |
|
|
| |
| x2 = x2.reshape(B, Dp, in_ch, H, W) |
| x_recon = x2.permute(0,2,1,3,4).contiguous() |
| return x_recon |
|
|
|
|
|
|
|
|
|
|
|
|
| def invert_conv2d( |
| conv: torch.nn.Conv2d, |
| z: torch.Tensor, |
| original_shape: torch.Size, |
| ) -> torch.Tensor: |
|
|
| B, C_in, H, W = original_shape |
| C_out, _, kH, kW = conv.weight.shape |
| stride_h, stride_w = conv.stride |
| pad_h, pad_w = conv.padding |
|
|
| if conv.bias is not None: |
| b = conv.bias.view(1, C_out, 1, 1).to(z) |
| z_nobias = z - b |
| else: |
| z_nobias = z |
|
|
| W_flat = conv.weight.view(C_out, -1).to(z) |
| W_pinv = torch.linalg.pinv(W_flat) |
|
|
| Bz, Co, Hp, Wp = z_nobias.shape |
| z_flat = z_nobias.reshape(Bz, Co, -1) |
|
|
| x_patches = W_pinv @ z_flat |
|
|
| x_sum = F.fold( |
| x_patches, |
| output_size=(H + 2*pad_h, W + 2*pad_w), |
| kernel_size=(kH, kW), |
| stride=(stride_h, stride_w), |
| ) |
| ones = torch.ones_like(x_patches) |
| count = F.fold( |
| ones, |
| output_size=(H + 2*pad_h, W + 2*pad_w), |
| kernel_size=(kH, kW), |
| stride=(stride_h, stride_w), |
| ) |
|
|
| x_recon = x_sum / count.clamp(min=1e-6) |
| if pad_h > 0 or pad_w > 0: |
| x_recon = x_recon[..., pad_h:pad_h+H, pad_w:pad_w+W] |
|
|
| return x_recon |
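| |
| # Hedged usage sketch (values illustrative): invert a conv patch embedding |
| # given the pre-conv shape. |
| # |
| #   conv  = torch.nn.Conv2d(4, 1024, kernel_size=2, stride=2).double() |
| #   x     = torch.randn(1, 4, 32, 32, dtype=torch.float64) |
| #   z     = conv(x) |
| #   x_rec = invert_conv2d(conv, z, x.shape)   # ~= x since C_out >= C_in*kH*kW |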
|
|
|
|
|
|
| def adain_seq_inplace(content: torch.Tensor, style: torch.Tensor, dim=1, eps: float = 1e-7) -> torch.Tensor: |
| mean_c = content.mean(dim, keepdim=True) |
| std_c = content.std (dim, keepdim=True).add_(eps) |
| mean_s = style.mean (dim, keepdim=True) |
| std_s = style.std (dim, keepdim=True).add_(eps) |
|
|
| content.sub_(mean_c).div_(std_c).mul_(std_s).add_(mean_s) |
| return content |
|
|
| def adain_seq(content: torch.Tensor, style: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: |
| return ((content - content.mean(1, keepdim=True)) / (content.std(1, keepdim=True) + eps)) * (style.std(1, keepdim=True) + eps) + style.mean(1, keepdim=True) |
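| |
| # Hedged example: adain_seq matches per-channel mean/std over the token |
| # dimension (dim=1) of a (B, tokens, C) tensor: |
| # |
| #   content = torch.randn(1, 4096, 64) |
| #   style   = torch.randn(1, 4096, 64) * 3.0 + 1.0 |
| #   out     = adain_seq(content, style) |
| #   # out.mean(1) ~= style.mean(1), out.std(1) ~= style.std(1) |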
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| def apply_scattersort_tiled( |
| denoised_spatial : torch.Tensor, |
| y0_adain_spatial : torch.Tensor, |
| tile_h : int, |
| tile_w : int, |
| pad : int, |
| ): |
| """ |
| Apply spatial scattersort between denoised_spatial and y0_adain_spatial |
| using local tile-wise sorted value matching. |
| |
| Args: |
| denoised_spatial (Tensor): (B, C, H, W) tensor. |
| y0_adain_spatial (Tensor): (B, C, H, W) reference tensor. |
| tile_h (int): tile height. |
| tile_w (int): tile width. |
| pad (int): padding size to apply around tiles. |
| |
|     Returns: |
|         denoised (Tensor): (B, C, H, W) tensor after tile-wise sort matching. |
|         Assumes tile_h and tile_w evenly divide H and W. |
|     """ |
| denoised_padded = F.pad(denoised_spatial, (pad, pad, pad, pad), mode='reflect') |
| y0_padded = F.pad(y0_adain_spatial, (pad, pad, pad, pad), mode='reflect') |
|
|
| denoised_padded_out = denoised_padded.clone() |
| _, _, h_len, w_len = denoised_spatial.shape |
|
|
| for ix in range(pad, h_len, tile_h): |
| for jx in range(pad, w_len, tile_w): |
| tile = denoised_padded[:, :, ix - pad:ix + tile_h + pad, jx - pad:jx + tile_w + pad] |
| y0_tile = y0_padded[:, :, ix - pad:ix + tile_h + pad, jx - pad:jx + tile_w + pad] |
|
|
| tile = rearrange(tile, "b c h w -> b c (h w)", h=tile_h + pad * 2, w=tile_w + pad * 2) |
| y0_tile = rearrange(y0_tile, "b c h w -> b c (h w)", h=tile_h + pad * 2, w=tile_w + pad * 2) |
|
|
| src_sorted, src_idx = tile.sort(dim=-1) |
| ref_sorted, ref_idx = y0_tile.sort(dim=-1) |
|
|
| new_tile = tile.scatter(dim=-1, index=src_idx, src=ref_sorted.expand(src_sorted.shape)) |
| new_tile = rearrange(new_tile, "b c (h w) -> b c h w", h=tile_h + pad * 2, w=tile_w + pad * 2) |
|
|
| denoised_padded_out[:, :, ix:ix + tile_h, jx:jx + tile_w] = ( |
| new_tile if pad == 0 else new_tile[:, :, pad:-pad, pad:-pad] |
| ) |
|
|
| denoised_padded_out = denoised_padded_out if pad == 0 else denoised_padded_out[:, :, pad:-pad, pad:-pad] |
| return denoised_padded_out |
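| |
| # Hedged usage sketch (sizes illustrative): |
| # |
| #   den   = torch.randn(1, 64, 64, 64) |
| #   style = torch.randn(1, 64, 64, 64) |
| #   out   = apply_scattersort_tiled(den, style, tile_h=8, tile_w=8, pad=2)   # (1, 64, 64, 64) |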
|
|
|
|
|
|
| def apply_scattersort_masked( |
| denoised_embed : torch.Tensor, |
| y0_adain_embed : torch.Tensor, |
| y0_style_pos_mask : torch.Tensor | None, |
| y0_style_pos_mask_edge : torch.Tensor | None, |
| h_len : int, |
| w_len : int |
| ): |
| if y0_style_pos_mask is None: |
|         flatmask = torch.ones(h_len * w_len, dtype=torch.bool) |
| else: |
| flatmask = F.interpolate(y0_style_pos_mask, size=(h_len, w_len)).bool().flatten().cpu() |
| flatunmask = ~flatmask |
|
|
| if y0_style_pos_mask_edge is not None: |
| edgemask = F.interpolate( |
| y0_style_pos_mask_edge.unsqueeze(0), size=(h_len, w_len) |
| ).bool().flatten() |
| flatmask = flatmask & (~edgemask) |
| flatunmask = flatunmask & (~edgemask) |
|
|
| denoised_masked = denoised_embed[:, flatmask, :].clone() |
| y0_adain_masked = y0_adain_embed[:, flatmask, :].clone() |
|
|
| src_sorted, src_idx = denoised_masked.sort(dim=-2) |
| ref_sorted, ref_idx = y0_adain_masked.sort(dim=-2) |
|
|
| denoised_embed[:, flatmask, :] = src_sorted.scatter(dim=-2, index=src_idx, src=ref_sorted.expand(src_sorted.shape)) |
|
|
|     if flatunmask.any(): |
| denoised_unmasked = denoised_embed[:, flatunmask, :].clone() |
| y0_adain_unmasked = y0_adain_embed[:, flatunmask, :].clone() |
|
|
| src_sorted, src_idx = denoised_unmasked.sort(dim=-2) |
| ref_sorted, ref_idx = y0_adain_unmasked.sort(dim=-2) |
|
|
| denoised_embed[:, flatunmask, :] = src_sorted.scatter(dim=-2, index=src_idx, src=ref_sorted.expand(src_sorted.shape)) |
|
|
| if y0_style_pos_mask_edge is not None: |
| denoised_edgemasked = denoised_embed[:, edgemask, :].clone() |
| y0_adain_edgemasked = y0_adain_embed[:, edgemask, :].clone() |
|
|
| src_sorted, src_idx = denoised_edgemasked.sort(dim=-2) |
| ref_sorted, ref_idx = y0_adain_edgemasked.sort(dim=-2) |
|
|
| denoised_embed[:, edgemask, :] = src_sorted.scatter(dim=-2, index=src_idx, src=ref_sorted.expand(src_sorted.shape)) |
|
|
| return denoised_embed |
|
|
|
|
|
|
|
|
| def apply_scattersort( |
| denoised_embed : torch.Tensor, |
| y0_adain_embed : torch.Tensor, |
| ): |
| |
| src_idx = denoised_embed.argsort(dim=-2) |
| ref_sorted = y0_adain_embed.sort(dim=-2)[0] |
|
|
|     denoised_embed.scatter_(dim=-2, index=src_idx, src=ref_sorted) |
|
|
| return denoised_embed |
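| |
| # Hedged example: global sort matching transfers the full per-channel value |
| # distribution of y0_adain_embed onto denoised_embed (in place): |
| # |
| #   a   = torch.randn(1, 4096, 64) |
| #   b   = torch.randn(1, 4096, 64) |
| #   out = apply_scattersort(a, b) |
| #   # sorted values of out along dim -2 now equal sorted values of b |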
|
|
|
|
|
|
|
|
|
|
|
|
| def apply_scattersort_spatial( |
| x_spatial : torch.Tensor, |
| y_spatial : torch.Tensor, |
| ): |
| x_emb = rearrange(x_spatial, "b c h w -> b (h w) c") |
| y_emb = rearrange(y_spatial, "b c h w -> b (h w) c") |
| |
| x_sorted, x_idx = x_emb.sort(dim=-2) |
| y_sorted, y_idx = y_emb.sort(dim=-2) |
|
|
| x_emb = x_sorted.scatter(dim=-2, index=x_idx, src=y_sorted.expand(x_sorted.shape)) |
| |
| return rearrange(x_emb, "b (h w) c -> b c h w", h=x_spatial.shape[-2], w=x_spatial.shape[-1]) |
|
|
|
|
|
|
|
|
| def apply_adain_spatial( |
| x_spatial : torch.Tensor, |
| y_spatial : torch.Tensor, |
| ): |
| x_emb = rearrange(x_spatial, "b c h w -> b (h w) c") |
| y_emb = rearrange(y_spatial, "b c h w -> b (h w) c") |
| |
| x_mean = x_emb.mean(-2, keepdim=True) |
| x_std = x_emb.std (-2, keepdim=True) |
| y_mean = y_emb.mean(-2, keepdim=True) |
| y_std = y_emb.std (-2, keepdim=True) |
|
|
|     assert not (x_std == 0).any(), "Target tensor has no variance!" |
|     assert not (y_std == 0).any(), "Reference tensor has no variance!" |
| |
|     x_emb_adain = (x_emb - x_mean) / x_std |
|     x_emb_adain = (x_emb_adain * y_std) + y_mean |
| |
|     return rearrange(x_emb_adain, "b (h w) c -> b c h w", h=x_spatial.shape[-2], w=x_spatial.shape[-1])   # reshape_as would scramble channels and space |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| def adain_patchwise(content: torch.Tensor, style: torch.Tensor, sigma: float = 1.0, kernel_size: int = None, eps: float = 1e-5) -> torch.Tensor: |
| |
| B, C, H, W = content.shape |
| device = content.device |
| dtype = content.dtype |
|
|
| if kernel_size is None: |
| kernel_size = int(2 * math.ceil(3 * sigma) + 1) |
| if kernel_size % 2 == 0: |
| kernel_size += 1 |
|
|
| pad = kernel_size // 2 |
| coords = torch.arange(kernel_size, dtype=torch.float64, device=device) - pad |
| gauss = torch.exp(-0.5 * (coords / sigma) ** 2) |
| gauss /= gauss.sum() |
| kernel_2d = (gauss[:, None] * gauss[None, :]).to(dtype=dtype) |
|
|
| weight = kernel_2d.view(1, 1, kernel_size, kernel_size) |
|
|
| content_padded = F.pad(content, (pad, pad, pad, pad), mode='reflect') |
| style_padded = F.pad(style, (pad, pad, pad, pad), mode='reflect') |
| result = torch.zeros_like(content) |
|
|
| for i in range(H): |
| for j in range(W): |
| c_patch = content_padded[:, :, i:i + kernel_size, j:j + kernel_size] |
| s_patch = style_padded[:, :, i:i + kernel_size, j:j + kernel_size] |
| w = weight.expand_as(c_patch) |
|
|
| c_mean = (c_patch * w).sum(dim=(-1, -2), keepdim=True) |
| c_std = ((c_patch - c_mean)**2 * w).sum(dim=(-1, -2), keepdim=True).sqrt() + eps |
| s_mean = (s_patch * w).sum(dim=(-1, -2), keepdim=True) |
| s_std = ((s_patch - s_mean)**2 * w).sum(dim=(-1, -2), keepdim=True).sqrt() + eps |
|
|
| normed = (c_patch[:, :, pad:pad+1, pad:pad+1] - c_mean) / c_std |
| stylized = normed * s_std + s_mean |
| result[:, :, i, j] = stylized.squeeze(-1).squeeze(-1) |
|
|
| return result |
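| |
| # Hedged usage sketch: this reference version loops over every pixel and is |
| # very slow; adain_patchwise_row_batch below computes the same result one row |
| # at a time. |
| # |
| #   content = torch.randn(1, 4, 16, 16) |
| #   style   = torch.randn(1, 4, 16, 16) |
| #   out     = adain_patchwise(content, style, sigma=1.0)   # (1, 4, 16, 16) |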
|
|
|
|
| def adain_patchwise_row_batch(content: torch.Tensor, style: torch.Tensor, sigma: float = 1.0, kernel_size: int = None, eps: float = 1e-5) -> torch.Tensor: |
|
|
| B, C, H, W = content.shape |
| device, dtype = content.device, content.dtype |
|
|
| if kernel_size is None: |
| kernel_size = int(2 * math.ceil(3 * sigma) + 1) |
| if kernel_size % 2 == 0: |
| kernel_size += 1 |
|
|
| pad = kernel_size // 2 |
| coords = torch.arange(kernel_size, dtype=torch.float64, device=device) - pad |
| gauss = torch.exp(-0.5 * (coords / sigma) ** 2) |
| gauss = (gauss / gauss.sum()).to(dtype) |
| kernel_2d = (gauss[:, None] * gauss[None, :]) |
|
|
| weight = kernel_2d.view(1, 1, kernel_size, kernel_size) |
|
|
| content_padded = F.pad(content, (pad, pad, pad, pad), mode='reflect') |
| style_padded = F.pad(style, (pad, pad, pad, pad), mode='reflect') |
| result = torch.zeros_like(content) |
|
|
| for i in range(H): |
| c_row_patches = torch.stack([ |
| content_padded[:, :, i:i+kernel_size, j:j+kernel_size] |
| for j in range(W) |
| ], dim=0) |
|
|
| s_row_patches = torch.stack([ |
| style_padded[:, :, i:i+kernel_size, j:j+kernel_size] |
| for j in range(W) |
| ], dim=0) |
|
|
| w = weight.expand_as(c_row_patches[0]) |
|
|
| c_mean = (c_row_patches * w).sum(dim=(-1, -2), keepdim=True) |
| c_std = ((c_row_patches - c_mean) ** 2 * w).sum(dim=(-1, -2), keepdim=True).sqrt() + eps |
| s_mean = (s_row_patches * w).sum(dim=(-1, -2), keepdim=True) |
| s_std = ((s_row_patches - s_mean) ** 2 * w).sum(dim=(-1, -2), keepdim=True).sqrt() + eps |
|
|
| center = kernel_size // 2 |
| central = c_row_patches[:, :, :, center:center+1, center:center+1] |
| normed = (central - c_mean) / c_std |
| stylized = normed * s_std + s_mean |
|
|
| result[:, :, i, :] = stylized.squeeze(-1).squeeze(-1).permute(1, 2, 0) |
|
|
| return result |
|
|
|
|
|
|
| def adain_patchwise_row_batch_med(content: torch.Tensor, style: torch.Tensor, sigma: float = 1.0, kernel_size: int = None, eps: float = 1e-5, mask: torch.Tensor = None, use_median_blur: bool = False, lowpass_weight=1.0, highpass_weight=1.0) -> torch.Tensor: |
| B, C, H, W = content.shape |
| device, dtype = content.device, content.dtype |
|
|
| if kernel_size is None: |
| kernel_size = int(2 * math.ceil(3 * abs(sigma)) + 1) |
| if kernel_size % 2 == 0: |
| kernel_size += 1 |
|
|
| pad = kernel_size // 2 |
|
|
| content_padded = F.pad(content, (pad, pad, pad, pad), mode='reflect') |
| style_padded = F.pad(style, (pad, pad, pad, pad), mode='reflect') |
| result = torch.zeros_like(content) |
|
|
| scaling = torch.ones((B, 1, H, W), device=device, dtype=dtype) |
| sigma_scale = torch.ones((H, W), device=device, dtype=torch.float32) |
| if mask is not None: |
| with torch.no_grad(): |
| padded_mask = F.pad(mask.float(), (pad, pad, pad, pad), mode="reflect") |
| blurred_mask = F.avg_pool2d(padded_mask, kernel_size=kernel_size, stride=1, padding=pad) |
| blurred_mask = blurred_mask[..., pad:-pad, pad:-pad] |
| edge_proximity = blurred_mask * (1.0 - blurred_mask) |
| scaling = 1.0 - (edge_proximity / 0.25).clamp(0.0, 1.0) |
| sigma_scale = scaling[0, 0] |
|
|
| if not use_median_blur: |
| coords = torch.arange(kernel_size, dtype=torch.float64, device=device) - pad |
| base_gauss = torch.exp(-0.5 * (coords / sigma) ** 2) |
| base_gauss = (base_gauss / base_gauss.sum()).to(dtype) |
| gaussian_table = {} |
| for s in sigma_scale.unique(): |
| sig = float((sigma * s + eps).clamp(min=1e-3)) |
| gauss_local = torch.exp(-0.5 * (coords / sig) ** 2) |
| gauss_local = (gauss_local / gauss_local.sum()).to(dtype) |
| kernel_2d = gauss_local[:, None] * gauss_local[None, :] |
| gaussian_table[s.item()] = kernel_2d |
|
|
| for i in range(H): |
| row_result = torch.zeros(B, C, W, dtype=dtype, device=device) |
| for j in range(W): |
| c_patch = content_padded[:, :, i:i+kernel_size, j:j+kernel_size] |
| s_patch = style_padded[:, :, i:i+kernel_size, j:j+kernel_size] |
|
|
| if use_median_blur: |
| |
| unfolded_c = c_patch.reshape(B, C, -1) |
| unfolded_s = s_patch.reshape(B, C, -1) |
|
|
| c_median = unfolded_c.median(dim=-1, keepdim=True).values |
| s_median = unfolded_s.median(dim=-1, keepdim=True).values |
|
|
| center = kernel_size // 2 |
| central = c_patch[:, :, center, center].view(B, C, 1) |
| residual = central - c_median |
| stylized = lowpass_weight * s_median + residual * highpass_weight |
| else: |
| k = gaussian_table[float(sigma_scale[i, j].item())] |
| local_weight = k.view(1, 1, kernel_size, kernel_size).expand(B, C, kernel_size, kernel_size) |
|
|
| c_mean = (c_patch * local_weight).sum(dim=(-1, -2), keepdim=True) |
| c_std = ((c_patch - c_mean) ** 2 * local_weight).sum(dim=(-1, -2), keepdim=True).sqrt() + eps |
| s_mean = (s_patch * local_weight).sum(dim=(-1, -2), keepdim=True) |
| s_std = ((s_patch - s_mean) ** 2 * local_weight).sum(dim=(-1, -2), keepdim=True).sqrt() + eps |
|
|
|                 center = kernel_size // 2 |
|                 central = c_patch[:, :, center:center+1, center:center+1] |
|                 normed = (central - c_mean) / c_std |
|                 stylized = normed * s_std + s_mean |
|                 central  = central.squeeze(-1)    # -> (B, C, 1), matching the median branch |
|                 stylized = stylized.squeeze(-1)   # so the blend below broadcasts for any batch size |
| |
| |
|             local_scaling = scaling[:, :, i, j].view(B, 1, 1) |
|             stylized = central * (1 - local_scaling) + stylized * local_scaling |
|
|
| row_result[:, :, j] = stylized.squeeze(-1) |
| result[:, :, i, :] = row_result |
|
|
| return result |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| def weighted_mix_n(tensor_list, weight_list, dim=-1, offset=0): |
| assert all(t.shape == tensor_list[0].shape for t in tensor_list) |
| assert len(tensor_list) == len(weight_list) |
|
|
| total_weight = sum(weight_list) |
| ratios = [w / total_weight for w in weight_list] |
|
|
| length = tensor_list[0].shape[dim] |
| idx = torch.arange(length) |
|
|
| |
|
|
| |
| counters = [0.0 for _ in ratios] |
| slots = torch.empty_like(idx) |
|
|
| for i in range(length): |
| |
| expected = [r * (i + 1) for r in ratios] |
| errors = [expected[j] - counters[j] for j in range(len(ratios))] |
| k = max(range(len(errors)), key=lambda j: errors[j]) |
| slots[i] = k |
| counters[k] += 1 |
|
|
| |
| out = tensor_list[0].clone() |
| for i, tensor in enumerate(tensor_list): |
| mask = slots == i |
| while mask.dim() < tensor.dim(): |
| mask = mask.unsqueeze(0) |
| mask = mask.expand_as(tensor) |
| out = torch.where(mask, tensor, out) |
| |
| return out |
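| |
| # Hedged example: interleave two equal-shape tensors roughly 2:1 along the |
| # last dimension (slot assignment follows the largest-remainder loop above): |
| # |
| #   a = torch.zeros(6) |
| #   b = torch.ones(6) |
| #   weighted_mix_n([a, b], [2.0, 1.0])   # ~2/3 of slots from a, ~1/3 from b |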
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| BLOCK_NAMES = {"double_blocks", "single_blocks", "up_blocks", "middle_blocks", "down_blocks", "input_blocks", "output_blocks"} |
|
|
| DEFAULT_BLOCK_WEIGHTS_MMDIT = { |
| "attn_norm" : 0.0, |
| "attn_norm_mod": 0.0, |
| "attn" : 1.0, |
| "attn_gated" : 0.0, |
| "attn_res" : 1.0, |
| "ff_norm" : 0.0, |
| "ff_norm_mod" : 0.0, |
| "ff" : 1.0, |
| "ff_gated" : 0.0, |
| "ff_res" : 1.0, |
| |
| "h_tile" : 8, |
| "w_tile" : 8, |
| } |
|
|
| DEFAULT_ATTN_WEIGHTS_MMDIT = { |
| "q_proj": 0.0, |
| "k_proj": 0.0, |
| "v_proj": 1.0, |
| "q_norm": 0.0, |
| "k_norm": 0.0, |
| "out" : 1.0, |
| |
| "h_tile": 8, |
| "w_tile": 8, |
| } |
|
|
| DEFAULT_BASE_WEIGHTS_MMDIT = { |
| "proj_in" : 1.0, |
| "proj_out": 1.0, |
| |
| "h_tile" : 8, |
| "w_tile" : 8, |
| } |
|
|
| class Stylizer: |
| buffer = {} |
| |
| CLS_WCT = StyleWCT() |
| |
| CLS_WCT2 = WaveletStyleWCT() |
| |
| def __init__(self, dtype=torch.float64, device=torch.device("cuda")): |
| self.dtype = dtype |
| self.device = device |
| self.mask = [None] |
| self.apply_to = [""] |
| self.method = ["passthrough"] |
| self.h_tile = [-1] |
| self.w_tile = [-1] |
| |
| self.w_len = 0 |
| self.h_len = 0 |
| self.img_len = 0 |
| |
| self.IMG_1ST = True |
| self.HEADS = 0 |
| self.KONTEXT = 0 |
| def set_mode(self, mode): |
| self.method = [mode] |
| |
| def set_weights(self, **kwargs): |
| for k, v in kwargs.items(): |
| if hasattr(self, k): |
| setattr(self, k, [v]) |
| |
| def set_weights_recursive(self, **kwargs): |
| for name, val in kwargs.items(): |
| if hasattr(self, name): |
| setattr(self, name, [val]) |
|
|
| for attr_name, attr_val in vars(self).items(): |
| if isinstance(attr_val, Stylizer): |
| attr_val.set_weights_recursive(**kwargs) |
|
|
| for list_name in BLOCK_NAMES: |
| lst = getattr(self, list_name, None) |
| if isinstance(lst, list): |
| for element in lst: |
| if isinstance(element, Stylizer): |
| element.set_weights_recursive(**kwargs) |
| |
| def merge_weights(self, other): |
| def recursive_merge(a, b, path): |
| if isinstance(a, list) and isinstance(b, list): |
| if path in BLOCK_NAMES: |
| out = [] |
| for i in range(max(len(a), len(b))): |
| if i < len(a) and i < len(b): |
| out.append(recursive_merge(a[i], b[i], path=None)) |
| elif i < len(a): |
| out.append(a[i]) |
| else: |
| out.append(b[i]) |
| return out |
| return a + b |
|
|
| if isinstance(a, dict) and isinstance(b, dict): |
| merged = dict(a) |
| for k, v_b in b.items(): |
| if k in merged: |
| merged[k] = recursive_merge(merged[k], v_b, path=None) |
| else: |
| merged[k] = v_b |
| return merged |
|
|
| if hasattr(a, "__dict__") and hasattr(b, "__dict__"): |
| for attr, val_b in vars(b).items(): |
| val_a = getattr(a, attr, None) |
| if val_a is not None: |
| setattr(a, attr, recursive_merge(val_a, val_b, path=attr)) |
| else: |
| setattr(a, attr, val_b) |
| return a |
| return b |
|
|
| for attr in vars(self): |
| if attr in BLOCK_NAMES: |
| merged = recursive_merge(getattr(self, attr), getattr(other, attr, []), path=attr) |
| elif hasattr(other, attr): |
| merged = recursive_merge(getattr(self, attr), getattr(other, attr), path=attr) |
| else: |
| continue |
| setattr(self, attr, merged) |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| self.h_len = h_len |
| self.w_len = w_len |
| self.img_slice = img_slice |
| self.txt_slice = txt_slice |
| self.img_len = h_len * w_len |
| self.HEADS = HEADS |
|
|
| @staticmethod |
| def middle_slice(length, weight): |
| """ |
| Returns a slice object that selects the middle `weight` fraction of a dimension. |
| Example: weight=1.0 → full slice; weight=0.5 → middle 50% |
| """ |
| if weight >= 1.0: |
| return slice(None) |
| wr = int((length * (1 - weight)) // 2) |
| return slice(wr, -wr if wr > 0 else None) |
|
|
| @staticmethod |
| def get_outer_slice(x, weight): |
| if weight >= 0.0: |
| return x |
| length = x.shape[-2] |
| wr = int((length * (1 - (-weight))) // 2) |
| |
| return torch.cat([x[...,:wr,:], x[...,-wr:,:]], dim=-2) |
|
|
| @staticmethod |
| def restore_outer_slice(x, x_outer, weight): |
| if weight >= 0.0: |
| return x |
| length = x.shape[-2] |
| wr = int((length * (1 - (-weight))) // 2) |
| |
| x[...,:wr,:] = x_outer[...,:wr,:] |
| x[...,-wr:,:] = x_outer[...,-wr:,:] |
| return x |
|
|
| def __call__(self, x, attr): |
| if x.shape[0] == 1 and not self.KONTEXT: |
| return x |
| |
| weight_list = getattr(self, attr) |
| weights_all_zero = all(weight == 0.0 for weight in weight_list) |
| if weights_all_zero: |
| return x |
| |
| |
| |
| |
| |
| |
| |
| |
| HEAD_DIM = x.shape[1] |
| if HEAD_DIM == self.HEADS: |
| B, HEAD_DIM, HW, C = x.shape |
| x = x.reshape(B, HW, C*HEAD_DIM) |
| |
| if hasattr(self, "KONTEXT") and self.KONTEXT == 1: |
| x = x.reshape(2, x.shape[1] // 2, x.shape[2]) |
| |
| txt_slice, img_slice, ktx_slice = self.txt_slice, self.img_slice, None |
| if hasattr(self, "KONTEXT") and self.KONTEXT == 2: |
| ktx_slice = self.img_slice |
| img_slice = slice(2 * self.img_slice.start, self.img_slice.start) |
| txt_slice = slice(None, 2 * self.txt_slice.stop) |
| |
| weights_all_one = all(weight == 1.0 for weight in weight_list) |
| methods_all_scattersort = all(name == "scattersort" for name in self.method) |
| masks_all_none = all(mask is None for mask in self.mask) |
| |
| if weights_all_one and methods_all_scattersort and len(weight_list) > 1 and masks_all_none: |
| buf = Stylizer.buffer |
| buf['src_idx'] = x[0:1].argsort(dim=-2) |
| buf['ref_sorted'], buf['ref_idx'] = x[1:].reshape(1, -1, x.shape[-1]).sort(dim=-2) |
| buf['src'] = buf['ref_sorted'][:,::len(weight_list)].expand_as(buf['src_idx']) |
| |
| x[0:1] = x[0:1].scatter_(dim=-2, index=buf['src_idx'], src=buf['src'],) |
| |
| else: |
| for i, (weight, mask) in enumerate(zip(weight_list, self.mask)): |
| if mask is not None: |
| x01 = x[0:1].clone() |
| slc = Stylizer.middle_slice(x.shape[-2], weight) |
| |
| |
| txt_method_name = self.method[i].removeprefix("tiled_") |
| txt_method = getattr(self, txt_method_name) |
| |
| method_name = self.method[i].removeprefix("tiled_") if self.img_len > x.shape[-2] or self.h_len < 0 else self.method[i] |
| method = getattr(self, method_name) |
| apply_to = self.apply_to[i] |
| if weight == 0.0: |
| continue |
| else: |
| if weight > 0 and weight < 1: |
| x_clone = x.clone() |
| if self.img_len == x.shape[-2] or apply_to == "img+txt" or self.h_len < 0: |
| x = method(x, idx=i+1, slc=slc) |
| elif self.img_len < x.shape[-2]: |
| if "img" in apply_to: |
| x[...,img_slice,:] = method(x[...,img_slice,:], idx=i+1, slc=slc) |
| |
| |
| |
| if "txt" in apply_to: |
| x[...,txt_slice,:] = txt_method(x[...,txt_slice,:], idx=i+1, slc=slc) |
| |
| if not "img" in apply_to and not "txt" in apply_to: |
| pass |
| else: |
| x = method(x, idx=i+1, slc=slc) |
| if weight > 0 and weight < 1 and txt_method_name != "scattersort": |
| x = torch.lerp(x_clone, x, weight) |
| |
| |
| |
| if mask is not None: |
| x[0:1,...,img_slice,:] = torch.lerp(x01[...,img_slice,:], x[0:1,...,img_slice,:], mask.view(1, -1, 1)) |
| if ktx_slice is not None: |
| x[0:1,...,ktx_slice,:] = torch.lerp(x01[...,ktx_slice,:], x[0:1,...,ktx_slice,:], mask.view(1, -1, 1)) |
| |
| |
| |
| |
| if hasattr(self, "KONTEXT") and self.KONTEXT == 1: |
| x = x.reshape(1, x.shape[1] * 2, x.shape[2]) |
| |
| if HEAD_DIM == self.HEADS: |
| return x.reshape(B, HEAD_DIM, HW, C) |
| else: |
| return x |
|
|
|
|
|
|
|     def WCT(self, x, idx=1, slc=slice(None)):   # slc accepted for a uniform dispatch signature; unused here |
|         Stylizer.CLS_WCT.set(x[idx:idx+1]) |
|         x[0:1] = Stylizer.CLS_WCT.get(x[0:1]) |
|         return x |
| |
|     def WCT2(self, x, idx=1, slc=slice(None)):  # slc accepted for a uniform dispatch signature; unused here |
|         Stylizer.CLS_WCT2.set(x[idx:idx+1], self.h_len, self.w_len) |
|         x[0:1] = Stylizer.CLS_WCT2.get(x[0:1], self.h_len, self.w_len) |
|         return x |
|
|
| @staticmethod |
| def AdaIN_(x, y, eps: float = 1e-7) -> torch.Tensor: |
| mean_c = x.mean(-2, keepdim=True) |
| std_c = x.std (-2, keepdim=True).add_(eps) |
| mean_s = y.mean (-2, keepdim=True) |
| std_s = y.std (-2, keepdim=True).add_(eps) |
| x.sub_(mean_c).div_(std_c).mul_(std_s).add_(mean_s) |
| return x |
|
|
|     def AdaIN(self, x, idx=1, eps: float = 1e-7, slc=slice(None)) -> torch.Tensor:   # slc: uniform dispatch signature; unused here |
| mean_c = x[0:1].mean(-2, keepdim=True) |
| std_c = x[0:1].std (-2, keepdim=True).add_(eps) |
| mean_s = x[idx:idx+1].mean (-2, keepdim=True) |
| std_s = x[idx:idx+1].std (-2, keepdim=True).add_(eps) |
| x[0:1].sub_(mean_c).div_(std_c).mul_(std_s).add_(mean_s) |
| return x |
|
|
|     def injection(self, x:torch.Tensor, idx=1, slc=slice(None)) -> torch.Tensor: |
| x[0:1] = x[idx:idx+1] |
| return x |
| |
| @staticmethod |
| def injection_(x:torch.Tensor, y:torch.Tensor) -> torch.Tensor: |
| return y |
| |
| @staticmethod |
|     def passthrough(x:torch.Tensor, idx=1, slc=slice(None)) -> torch.Tensor: |
| return x |
| |
| @staticmethod |
| def decompose_magnitude_direction(x, dim=-1, eps=1e-8): |
| magnitude = x.norm(p=2, dim=dim, keepdim=True) |
| direction = x / (magnitude + eps) |
| return magnitude, direction |
|
|
| @staticmethod |
| def scattersort_dir_(x, y, dim=-2): |
| |
| |
| |
| |
| |
| mag, _ = Stylizer.decompose_magnitude_direction(x.to(torch.float64), dim) |
| |
| buf = Stylizer.buffer |
| buf['src_idx'] = x.argsort(dim=-2) |
| buf['ref_sorted'], buf['ref_idx'] = y .sort(dim=-2) |
| x.scatter_(dim=-2, index=buf['src_idx'], src=buf['ref_sorted'].expand_as(buf['src_idx'])) |
| |
| |
| _, dir = Stylizer.decompose_magnitude_direction(x.to(torch.float64), dim) |
| |
| return (mag * dir).to(x) |
|
|
|
|
| @staticmethod |
| def scattersort_dir2_(x, y, dim=-2): |
| |
| |
| |
| |
| |
| |
| |
| buf = Stylizer.buffer |
| buf['src_sorted'], buf['src_idx'] = x.sort(dim=dim) |
| buf['ref_sorted'], buf['ref_idx'] = y.sort(dim=dim) |
| |
|
|
|
|
|
|
| buf['x_sub'], buf['x_sub_idx'] = buf['src_sorted'].sort(dim=-1) |
| buf['y_sub'], buf['y_sub_idx'] = buf['ref_sorted'].sort(dim=-1) |
| |
| mag, _ = Stylizer.decompose_magnitude_direction(buf['x_sub'].to(torch.float64), -1) |
| _, dir = Stylizer.decompose_magnitude_direction(buf['y_sub'].to(torch.float64), -1) |
| |
| buf['y_sub'] = (mag * dir).to(x) |
| |
| buf['ref_sorted'].scatter_(dim=-1, index=buf['y_sub_idx'], src=buf['y_sub'].expand_as(buf['y_sub_idx'])) |
|
|
|
|
|
|
| mag, _ = Stylizer.decompose_magnitude_direction(buf['src_sorted'].to(torch.float64), dim) |
| _, dir = Stylizer.decompose_magnitude_direction(buf['ref_sorted'].to(torch.float64), dim) |
| |
| buf['ref_sorted'] = (mag * dir).to(x) |
| |
| x.scatter_(dim=dim, index=buf['src_idx'], src=buf['ref_sorted'].expand_as(buf['src_idx'])) |
|
|
|
|
| return x |
|
|
|
|
| @staticmethod |
|     def scattersort_dir(x, idx=1, slc=slice(None)): |
| x[0:1] = Stylizer.scattersort_dir_(x[0:1], x[idx:idx+1]) |
| return x |
| |
|
|
| @staticmethod |
|     def scattersort_dir2(x, idx=1, slc=slice(None)): |
| x[0:1] = Stylizer.scattersort_dir2_(x[0:1], x[idx:idx+1]) |
| return x |
|
|
| @staticmethod |
| def scattersort_(x, y, slc=slice(None)): |
| buf = Stylizer.buffer |
| buf['src_idx'] = x.argsort(dim=-2) |
| buf['ref_sorted'], buf['ref_idx'] = y .sort(dim=-2) |
|
|
| return x.scatter_(dim=-2, index=buf['src_idx'][...,slc,:], src=buf['ref_sorted'][...,slc,:].expand_as(buf['src_idx'][...,slc,:])) |
| |
|
|
| @staticmethod |
| def scattersort_double(x, y): |
| buf = Stylizer.buffer |
| buf['src_sorted'], buf['src_idx'] = x.sort(dim=-2) |
| buf['ref_sorted'], buf['ref_idx'] = y.sort(dim=-2) |
| |
| buf['x_sub_idx'] = buf['src_sorted'].argsort(dim=-1) |
| buf['y_sub'], buf['y_sub_idx'] = buf['ref_sorted'].sort(dim=-1) |
| |
| x.scatter_(dim=-1, index=buf['x_sub_idx'], src=buf['y_sub'].expand_as(buf['x_sub_idx'])) |
|
|
| return x.scatter_(dim=-2, index=buf['src_idx'], src=buf['ref_sorted'].expand_as(buf['src_idx'])) |
| |
| |
| def scattersort_aoeu(self, x, idx=1, slc=slice(None)): |
| x[0:1] = Stylizer.scattersort_(x[0:1], x[idx:idx+1], slc) |
| return x |
| |
| def scattersort(self, x, idx=1, slc=slice(None)): |
| if x.shape[0] != 2: |
| x[0:1] = Stylizer.scattersort_(x[0:1], x[idx:idx+1], slc) |
| return x |
| |
| buf = Stylizer.buffer |
| buf['sorted'], buf['idx'] = x.sort(dim=-2) |
|
|
| return x.scatter_(dim=-2, index=buf['idx'][0:1][...,slc,:], src=buf['sorted'][1:2][...,slc,:].expand_as(buf['idx'][0:1][...,slc,:])) |
| |
|
|
| |
| |
|     def tiled_scattersort(self, x, idx=1, slc=slice(None)): |
|         # rearrange the img tokens into a true spatial layout (a permute, not a |
|         # raw memory reinterpretation), stylize tile-by-tile, then write back |
|         den   = rearrange(x[0:1,       self.img_slice, :], "b (h w) c -> b c h w", h=self.h_len, w=self.w_len) |
|         style = rearrange(x[idx:idx+1, self.img_slice, :], "b (h w) c -> b c h w", h=self.h_len, w=self.w_len) |
| |
|         tiles    = Stylizer.get_tiles_as_strided(den,   self.h_tile[idx-1], self.w_tile[idx-1]) |
|         ref_tile = Stylizer.get_tiles_as_strided(style, self.h_tile[idx-1], self.w_tile[idx-1]) |
| |
|         # (B, C, nH, nW, th, tw) -> (nH, nW, B, C, th, tw) so vmap maps over the tile grid |
|         tiles_v    = tiles   .permute(2, 3, 0, 1, 4, 5) |
|         ref_tile_v = ref_tile.permute(2, 3, 0, 1, 4, 5) |
| |
|         vmap2  = torch.vmap(torch.vmap(Stylizer.apply_scattersort_per_tile, in_dims=0), in_dims=0) |
|         result = vmap2(tiles_v, ref_tile_v) |
| |
|         result = result.permute(2, 3, 0, 1, 4, 5) |
| |
|         tiles.copy_(result)   # writes into den through the strided view |
|         x[0:1, self.img_slice, :] = rearrange(den, "b c h w -> b (h w) c") |
| |
|         return x |
| |
| |
|     def tiled_AdaIN(self, x, idx=1, slc=slice(None)): |
|         # same flow as tiled_scattersort, but AdaIN per tile |
|         den   = rearrange(x[0:1,       self.img_slice, :], "b (h w) c -> b c h w", h=self.h_len, w=self.w_len) |
|         style = rearrange(x[idx:idx+1, self.img_slice, :], "b (h w) c -> b c h w", h=self.h_len, w=self.w_len) |
| |
|         tiles    = Stylizer.get_tiles_as_strided(den,   self.h_tile[idx-1], self.w_tile[idx-1]) |
|         ref_tile = Stylizer.get_tiles_as_strided(style, self.h_tile[idx-1], self.w_tile[idx-1]) |
| |
|         tiles_v    = tiles   .permute(2, 3, 0, 1, 4, 5) |
|         ref_tile_v = ref_tile.permute(2, 3, 0, 1, 4, 5) |
| |
|         vmap2  = torch.vmap(torch.vmap(Stylizer.apply_AdaIN_per_tile, in_dims=0), in_dims=0) |
|         result = vmap2(tiles_v, ref_tile_v) |
| |
|         result = result.permute(2, 3, 0, 1, 4, 5) |
| |
|         tiles.copy_(result) |
|         x[0:1, self.img_slice, :] = rearrange(den, "b c h w -> b (h w) c") |
| |
|         return x |
| |
| |
| @staticmethod |
| def get_tiles_as_strided(x, tile_h, tile_w): |
| B, C, H, W = x.shape |
| stride = x.stride() |
| nH = H // tile_h |
| nW = W // tile_w |
|
|
| tiles = x.as_strided( |
| size=(B, C, nH, nW, tile_h, tile_w), |
| stride=(stride[0], stride[1], stride[2] * tile_h, stride[3] * tile_w, stride[2], stride[3]) |
| ) |
| return tiles |
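| |
|     # Hedged note: as_strided returns a *view*, so writing into the returned |
|     # tiles (e.g. tiles.copy_(...) in the tiled_* methods above) mutates the |
|     # source tensor. Illustrative sketch: |
|     # |
|     #   x = torch.zeros(1, 2, 4, 4) |
|     #   tiles = Stylizer.get_tiles_as_strided(x, 2, 2)   # (1, 2, 2, 2, 2, 2) |
|     #   tiles += 1                                       # x is now all ones |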
|
|
| @staticmethod |
| def apply_scattersort_per_tile(tile, ref_tile): |
| flat = tile .flatten(-2, -1) |
| ref_flat = ref_tile.flatten(-2, -1) |
|
|
| sorted_ref, _ = ref_flat .sort(dim=-1) |
| src_sorted, src_idx = flat.sort(dim=-1) |
| |
| out = flat.scatter(dim=-1, index=src_idx, src=sorted_ref) |
| return out.view_as(tile) |
|
|
|     @staticmethod |
|     def apply_AdaIN_per_tile(tile, ref_tile, eps: float = 1e-7): |
|         # per-channel stats over both spatial dims of the tile |
|         mean_c = tile.mean(dim=(-2, -1), keepdim=True) |
|         std_c  = tile.std (dim=(-2, -1), keepdim=True).add_(eps) |
|         mean_s = ref_tile.mean(dim=(-2, -1), keepdim=True) |
|         std_s  = ref_tile.std (dim=(-2, -1), keepdim=True).add_(eps) |
|         tile.sub_(mean_c).div_(std_c).mul_(std_s).add_(mean_s) |
|         return tile |
|
|
| class StyleMMDiT_Attn(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| |
| self.q_proj = [0.0] |
| self.k_proj = [0.0] |
| self.v_proj = [0.0] |
|
|
| self.q_norm = [0.0] |
| self.k_norm = [0.0] |
| |
| self.out = [0.0] |
|
|
| class StyleMMDiT_FF(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| |
| self.ff_1 = [0.0] |
| self.ff_1_silu = [0.0] |
| self.ff_3 = [0.0] |
| self.ff_13 = [0.0] |
| self.ff_2 = [0.0] |
| |
| class StyleMMDiT_MoE(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| |
| self.FF_SHARED = StyleMMDiT_FF(mode) |
| self.FF_SEPARATE = StyleMMDiT_FF(mode) |
| |
| self.shared = [0.0] |
| self.gate = [False] |
| self.topk_weight = [0.0] |
|
|
| self.separate = [0.0] |
| self.sum = [0.0] |
| self.out = [0.0] |
|
|
|
|
|
|
|
|
|
|
| class StyleMMDiT_SubBlock(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| |
| self.ATTN = StyleMMDiT_Attn(mode) |
|
|
| self.attn_norm = [0.0] |
| self.attn_norm_mod = [0.0] |
| self.attn = [0.0] |
| self.attn_gated = [0.0] |
| self.attn_res = [0.0] |
| |
| self.ff_norm = [0.0] |
| self.ff_norm_mod = [0.0] |
| self.ff = [0.0] |
| self.ff_gated = [0.0] |
| self.ff_res = [0.0] |
| |
| self.mask = [None] |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.ATTN.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
|
|
| class StyleMMDiT_IMG_Block(StyleMMDiT_SubBlock): |
| def __init__(self, mode): |
| super().__init__(mode) |
| self.FF = StyleMMDiT_MoE(mode) |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.FF.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| |
| class StyleMMDiT_TXT_Block(StyleMMDiT_SubBlock): |
| def __init__(self, mode): |
| super().__init__(mode) |
| self.FF = StyleMMDiT_FF(mode) |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.FF.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
|
|
|
|
|
|
|
|
|
|
| class StyleMMDiT_BaseBlock: |
| def __init__(self, mode="passthrough"): |
|
|
| self.img = StyleMMDiT_IMG_Block(mode) |
| self.txt = StyleMMDiT_TXT_Block(mode) |
| |
| self.mask = [None] |
| self.attn_mask = [None] |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| self.h_len = h_len |
| self.w_len = w_len |
| self.img_len = h_len * w_len |
| |
| self.img_slice = img_slice |
| self.txt_slice = txt_slice |
| self.HEADS = HEADS |
| |
| self.img.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.txt.set_len(-1, -1, img_slice, txt_slice, HEADS) |
| |
| for i, mask in enumerate(self.mask): |
| if mask is not None and mask.ndim > 1: |
| self.mask[i] = F.interpolate(mask.unsqueeze(0), size=(h_len, w_len)).flatten().to(torch.bfloat16).cuda() |
| self.img.mask = self.mask |
| for i, mask in enumerate(self.attn_mask): |
| if mask is not None and mask.ndim > 1: |
| self.attn_mask[i] = F.interpolate(mask.unsqueeze(0), size=(h_len, w_len)).flatten().to(torch.bfloat16).cuda() |
| self.img.ATTN.mask = self.attn_mask |
|
|
| class StyleMMDiT_DoubleBlock(StyleMMDiT_BaseBlock): |
| def __init__(self, mode="passthrough"): |
| super().__init__(mode) |
| self.txt = StyleMMDiT_TXT_Block(mode) |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.txt.set_len(-1, -1, img_slice, txt_slice, HEADS) |
|
|
| class StyleMMDiT_SingleBlock(StyleMMDiT_BaseBlock): |
| def __init__(self, mode="passthrough"): |
| super().__init__(mode) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| class StyleUNet_Resample(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| self.conv = [0.0] |
|
|
| class StyleUNet_Attn(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| self.q_proj = [0.0] |
| self.k_proj = [0.0] |
| self.v_proj = [0.0] |
| self.out = [0.0] |
|
|
| class StyleUNet_FF(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| self.proj = [0.0] |
| self.geglu = [0.0] |
| self.linear = [0.0] |
| |
| class StyleUNet_TransformerBlock(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| |
| self.ATTN1 = StyleUNet_Attn(mode) |
| self.FF = StyleUNet_FF (mode) |
| self.ATTN2 = StyleUNet_Attn(mode) |
|
|
| self.self_attn = [0.0] |
| self.ff = [0.0] |
| self.cross_attn = [0.0] |
| |
| self.self_attn_res = [0.0] |
| self.cross_attn_res = [0.0] |
| self.ff_res = [0.0] |
| |
| self.norm1 = [0.0] |
| self.norm2 = [0.0] |
| self.norm3 = [0.0] |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.ATTN1.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.ATTN2.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
|
|
| class StyleUNet_SpatialTransformer(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
| |
| self.TFMR = StyleUNet_TransformerBlock(mode) |
|
|
| self.spatial_norm_in = [0.0] |
| self.spatial_proj_in = [0.0] |
| self.spatial_transformer_block = [0.0] |
| self.spatial_transformer = [0.0] |
| self.spatial_proj_out = [0.0] |
| self.spatial_res = [0.0] |
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.TFMR.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
|
|
| class StyleUNet_ResBlock(Stylizer): |
| def __init__(self, mode): |
| super().__init__() |
|
|
| self.in_norm = [0.0] |
| self.in_silu = [0.0] |
| self.in_conv = [0.0] |
|
|
| self.emb_silu = [0.0] |
| self.emb_linear = [0.0] |
| self.emb_res = [0.0] |
|
|
| self.out_norm = [0.0] |
| self.out_silu = [0.0] |
| self.out_conv = [0.0] |
| |
| self.residual = [0.0] |
|
|
|
|
| class StyleUNet_BaseBlock(Stylizer): |
|     def __init__(self, mode="passthrough"): |
|         super().__init__()   # initialize Stylizer defaults (method, apply_to, mask, tile sizes) used by __call__ |
|
|
| self.resample_block = StyleUNet_Resample(mode) |
| self.res_block = StyleUNet_ResBlock(mode) |
| self.spatial_block = StyleUNet_SpatialTransformer(mode) |
| |
| self.resample = [0.0] |
| self.res = [0.0] |
| self.spatial = [0.0] |
| |
| self.mask = [None] |
| self.attn_mask = [None] |
| |
| self.KONTEXT = 0 |
|
|
| |
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| self.h_len = h_len |
| self.w_len = w_len |
| self.img_len = h_len * w_len |
| |
| self.img_slice = img_slice |
| self.txt_slice = txt_slice |
| self.HEADS = HEADS |
| |
| self.resample_block.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.res_block .set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| self.spatial_block .set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| |
| for i, mask in enumerate(self.mask): |
| if mask is not None and mask.ndim > 1: |
| self.mask[i] = F.interpolate(mask.unsqueeze(0), size=(h_len, w_len)).flatten().to(torch.bfloat16).cuda() |
| self.resample_block.mask = self.mask |
| self.res_block.mask = self.mask |
| self.spatial_block.mask = self.mask |
| self.spatial_block.TFMR.mask = self.mask |
| |
| for i, mask in enumerate(self.attn_mask): |
| if mask is not None and mask.ndim > 1: |
| self.attn_mask[i] = F.interpolate(mask.unsqueeze(0), size=(h_len, w_len)).flatten().to(torch.bfloat16).cuda() |
| self.spatial_block.TFMR.ATTN1.mask = self.attn_mask |
| |
|     def __call__(self, x, attr): |
|         B, C, H, W = x.shape |
|         x = super().__call__(rearrange(x, "b c h w -> b (h w) c"), attr)   # true permute, not a raw reshape |
|         return rearrange(x, "b (h w) c -> b c h w", h=H, w=W) |
| |
|
|
| class StyleUNet_InputBlock(StyleUNet_BaseBlock): |
| def __init__(self, mode="passthrough"): |
| super().__init__(mode) |
|
|
| class StyleUNet_MiddleBlock(StyleUNet_BaseBlock): |
| def __init__(self, mode="passthrough"): |
| super().__init__(mode) |
|
|
| class StyleUNet_OutputBlock(StyleUNet_BaseBlock): |
| def __init__(self, mode="passthrough"): |
| super().__init__(mode) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| class Style_Model(Stylizer): |
|
|
| def __init__(self, dtype=torch.float64, device=torch.device("cuda")): |
| super().__init__(dtype, device) |
| self.guides = [] |
| self.GUIDES_INITIALIZED = False |
| |
| |
| |
| |
| self.h_len = -1 |
| self.w_len = -1 |
| self.img_len = -1 |
| self.h_tile = [-1] |
| self.w_tile = [-1] |
| |
| self.proj_in = [0.0] |
| self.proj_out = [0.0] |
| |
| self.cond_pos = [None] |
| self.cond_neg = [None] |
| |
| self.noise_mode = "update" |
| self.recon_lure = "none" |
| self.data_shock = "none" |
| |
| self.data_shock_start_step = 0 |
| self.data_shock_end_step = 0 |
| |
| self.Retrojector = None |
| self.Endojector = None |
| |
| self.IMG_1ST = True |
| self.HEADS = 0 |
| self.KONTEXT = 0 |
| def __call__(self, x, attr): |
| if x.shape[0] == 1 and not self.KONTEXT: |
| return x |
| |
| weight_list = getattr(self, attr) |
| weights_all_zero = all(weight == 0.0 for weight in weight_list) |
| if weights_all_zero: |
| return x |
| |
| """x_ndim = x.ndim |
| if x_ndim == 4: |
| B, HEAD, HW, C = x.shape |
| |
| if x_ndim == 3: |
| B, HW, C = x.shape |
| if x.shape[-2] != self.HEADS and self.HEADS != 0: |
| x = x.reshape(B,self.HEADS,HW,-1)""" |
| |
| HEAD_DIM = x.shape[1] |
| if HEAD_DIM == self.HEADS: |
| B, HEAD_DIM, HW, C = x.shape |
| x = x.reshape(B, HW, C*HEAD_DIM) |
| |
| if self.KONTEXT == 1: |
| x = x.reshape(2, x.shape[1] // 2, x.shape[2]) |
| |
| weights_all_one = all(weight == 1.0 for weight in weight_list) |
| methods_all_scattersort = all(name == "scattersort" for name in self.method) |
| masks_all_none = all(mask is None for mask in self.mask) |
| |
| if weights_all_one and methods_all_scattersort and len(weight_list) > 1 and masks_all_none: |
| buf = Stylizer.buffer |
| buf['src_idx'] = x[0:1].argsort(dim=-2) |
| buf['ref_sorted'], buf['ref_idx'] = x[1:].reshape(1, -1, x.shape[-1]).sort(dim=-2) |
| buf['src'] = buf['ref_sorted'][:,::len(weight_list)].expand_as(buf['src_idx']) |
| |
| x[0:1] = x[0:1].scatter_(dim=-2, index=buf['src_idx'], src=buf['src'],) |
| else:
| for i, (weight, mask) in enumerate(zip(weight_list, self.mask)):
| if weight == 0.0:
| continue
| if 0.0 < weight < 1.0:
| x_clone = x.clone()  # kept for the post-hoc blend below
| if mask is not None:
| x01 = x[0:1].clone()  # unmasked source, restored outside the mask
| slc = Stylizer.middle_slice(x.shape[-2], weight)
|
| method = getattr(self, self.method[i])
| if weight == 1.0:
| x = method(x, idx=i+1)
| else:
| x = method(x, idx=i+1, slc=slc)
| if 0.0 < weight < 1.0 and self.method[i] != "scattersort":
| x = torch.lerp(x_clone, x, weight)
|
| if mask is not None:
| x[0:1] = torch.lerp(x01, x[0:1], mask.view(1, -1, 1))
| |
| |
| |
| # Undo the Kontext split and the head fold so the caller gets back exactly
| # the layout it passed in.
| if self.KONTEXT == 1:
| x = x.reshape(1, x.shape[1] * 2, x.shape[2])
|
| if HEAD_DIM == self.HEADS:
| return x.reshape(B, HW, HEAD_DIM, C).transpose(1, 2)
| return x
|
|
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| self.h_len = h_len |
| self.w_len = w_len |
| self.img_len = h_len * w_len |
| |
| self.img_slice = img_slice |
| self.txt_slice = txt_slice |
| self.HEADS = HEADS |
| |
|
| # Resize any 2D/3D spatial masks to the current latent grid and flatten
| # them to one weight per token. F.interpolate needs an (N, C, H, W) input,
| # so single-channel masks are normalized to that shape first.
| for i, mask in enumerate(self.mask):
| if mask is not None and mask.ndim > 1:
| m = mask.reshape(1, 1, *mask.shape[-2:])
| self.mask[i] = F.interpolate(m, size=(h_len, w_len)).flatten().to(torch.bfloat16).cuda()
|
|
| def init_guides(self, model): |
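| """Convert raw style guides (ComfyUI latent dicts or tensors) to model-space
| tensors on the configured dtype/device, then concatenate them along the
| batch dimension. Runs once; later calls are no-ops."""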
| if not self.GUIDES_INITIALIZED: |
| if self.guides == []: |
| self.guides = None |
| elif self.guides is not None: |
| for i, latent in enumerate(self.guides):
| # Accept either a ComfyUI latent dict ({'samples': ...}) or a raw tensor.
| if isinstance(latent, dict):
| latent = model.inner_model.inner_model.process_latent_in(latent['samples']).to(dtype=self.dtype, device=self.device)
| elif isinstance(latent, torch.Tensor):
| latent = latent.to(dtype=self.dtype, device=self.device)
| else:
| latent = None
| self.guides[i] = latent
|
| if any(g is None for g in self.guides):
| self.guides = None
| print("Encountered a None style guide; setting guides to None (Kontext).")
| else:
| self.guides = torch.cat(self.guides, dim=0)
| self.GUIDES_INITIALIZED = True |
| |
| def set_conditioning(self, positive, negative): |
| self.cond_pos = [positive] |
| self.cond_neg = [negative] |
|
|
| def apply_style_conditioning(self, UNCOND, base_context, base_y=None, base_llama3=None): |
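| """Stack the base prompt's conditioning with each style guide's conditioning.
| All token sequences (and pooled vectors) are right-padded to a shared length
| so they can be batched; row 0 is the base prompt, rows 1..N the guides."""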
|
|
| def get_max_token_lengths(style_conditioning, base_context, base_y=None, base_llama3=None): |
| context_max_len = base_context.shape[-2] |
| llama3_max_len = base_llama3.shape[-2] if base_llama3 is not None else -1 |
| y_max_len = base_y.shape[-1] if base_y is not None else -1 |
|
|
| for style_cond in style_conditioning: |
| if style_cond is None: |
| continue |
| context_max_len = max(context_max_len, style_cond[0][0].shape[-2]) |
| if base_llama3 is not None: |
| llama3_max_len = max(llama3_max_len, style_cond[0][1]['conditioning_llama3'].shape[-2]) |
| if base_y is not None: |
| y_max_len = max(y_max_len, style_cond[0][1]['pooled_output'].shape[-1]) |
|
|
| return context_max_len, llama3_max_len, y_max_len |
|
|
| def pad_to_len(x, target_len, pad_value=0.0, dim=-2):
| # Right-pad x along `dim` up to target_len; a negative target_len
| # disables padding. F.pad's pad list runs from the last dim backwards,
| # so the (left, right) pair for `dim` is placed at the matching offset.
| if target_len < 0:
| return x
| cur_len = x.shape[dim]
| if cur_len >= target_len:
| return x
| pad = [0] * (2 * x.ndim)
| pad[2 * (x.ndim - 1 - (dim % x.ndim)) + 1] = target_len - cur_len
| return F.pad(x, pad, value=pad_value)
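|
| # e.g. pad_to_len(torch.ones(1, 5, 8), 7, dim=-2).shape -> (1, 7, 8)
| #      pad_to_len(torch.ones(1, 8), 10, dim=-1).shape -> (1, 10)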
|
|
| style_conditioning = self.cond_pos if not UNCOND else self.cond_neg |
| |
| context_max_len, llama3_max_len, y_max_len = get_max_token_lengths( |
| style_conditioning = style_conditioning, |
| base_context = base_context, |
| base_y = base_y, |
| base_llama3 = base_llama3, |
| ) |
| |
| bsz_style = len(style_conditioning) |
| |
| context = base_context.repeat(bsz_style + 1, 1, 1) |
| y = base_y.repeat(bsz_style + 1, 1) if base_y is not None else None |
| llama3 = base_llama3.repeat(bsz_style + 1, 1, 1, 1) if base_llama3 is not None else None |
|
|
| context = pad_to_len(context, context_max_len, dim=-2) |
| llama3 = pad_to_len(llama3, llama3_max_len, dim=-2) if base_llama3 is not None else None |
| y = pad_to_len(y, y_max_len, dim=-1) if base_y is not None else None |
| |
| # Overwrite rows 1..N with each style guide's conditioning, padded to the
| # shared maximum lengths computed above.
| for ci, style_cond in enumerate(style_conditioning):
| if style_cond is None:
| continue
| context[ci+1:ci+2] = pad_to_len(style_cond[0][0], context_max_len, dim=-2).to(context)
| if llama3 is not None:
| llama3[ci+1:ci+2] = pad_to_len(style_cond[0][1]['conditioning_llama3'], llama3_max_len, dim=-2).to(llama3)
| if y is not None:
| y[ci+1:ci+2] = pad_to_len(style_cond[0][1]['pooled_output'], y_max_len, dim=-1).to(y)
| |
| return context, y, llama3 |
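|
| # Usage sketch (hypothetical shapes; single base prompt, T5-style context):
| #   context, y, llama3 = style_model.apply_style_conditioning(
| #       UNCOND=False, base_context=torch.zeros(1, 256, 4096))
| #   context.shape -> (1 + len(style_model.cond_pos), max_ctx_len, 4096)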
| |
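| # The WCT helpers delegate to shared whitening/coloring operators: roughly,
| # source tokens are whitened to zero mean and unit covariance, then recolored
| # with the guide's statistics, x_cs = C_s^(1/2) C_c^(-1/2) (x_c - mu_c) + mu_s.
| # WCT2 is assumed here to be the wavelet-based variant; the exact math lives
| # in the CLS_WCT / CLS_WCT2 implementations.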
| def WCT_data(self, denoised_embed, y0_style_embed): |
| Stylizer.CLS_WCT.set(y0_style_embed.to(denoised_embed)) |
| return Stylizer.CLS_WCT.get(denoised_embed) |
|
|
| def WCT2_data(self, denoised_embed, y0_style_embed): |
| Stylizer.CLS_WCT2.set(y0_style_embed.to(denoised_embed)) |
| return Stylizer.CLS_WCT2.get(denoised_embed) |
|
|
| def apply_to_data(self, denoised, y0_style=None, mode="none"):
| # Run one style method directly on the denoised latent: embed both the
| # prediction and the guides, style row 0 against the rest, and unembed.
| if mode == "none":
| return denoised
| y0_style = self.guides if y0_style is None else y0_style
|
| y0_style_embed = self.Retrojector.embed(y0_style)
| denoised_embed = self.Retrojector.embed(denoised)
| # Merge the B guides into one sequence, then take every B-th token so the
| # reference row matches the HW token count of the prediction.
| B, HW, C = y0_style_embed.shape
| embed = torch.cat([denoised_embed, y0_style_embed.view(1, B*HW, C)[:, ::B, :]], dim=0)
| method = getattr(self, mode)
| if mode == "scattersort":
| slc = Stylizer.middle_slice(embed.shape[-2], self.data_shock_weight)
| embed = method(embed, slc=slc)
| else:
| embed = method(embed)
| return self.Retrojector.unembed(embed[0:1])
|
|
| def apply_recon_lure(self, denoised, y0_style): |
| if self.recon_lure == "none": |
| return denoised |
| for i in range(denoised.shape[0]): |
| denoised[i:i+1] = self.apply_to_data(denoised[i:i+1], y0_style, self.recon_lure) |
| return denoised |
|
|
| def apply_data_shock(self, denoised):
| if self.data_shock == "none":
| return denoised
| datashock_ref = getattr(self, "datashock_ref", None)
| if self.data_shock == "scattersort":
| return self.apply_to_data(denoised, datashock_ref, self.data_shock)
| # Blend the shocked prediction back toward the original; torch.lerp accepts
| # a plain float weight, so no device/dtype juggling is needed.
| shocked = self.apply_to_data(denoised, datashock_ref, self.data_shock)
| return torch.lerp(denoised, shocked, float(self.data_shock_weight))
|
|
|
|
|
|
|
|
| class StyleMMDiT_Model(Style_Model): |
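| """Style hooks for MMDiT-style transformers: one hook per double-stream and
| single-stream block, all kept in sync with the current latent grid."""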
|
|
| def __init__(self, dtype=torch.float64, device=torch.device("cuda")): |
| super().__init__(dtype, device) |
| # Pools are oversized (100 slots) so any block index the model reports
| # has a matching hook.
| self.double_blocks = [StyleMMDiT_DoubleBlock() for _ in range(100)]
| self.single_blocks = [StyleMMDiT_SingleBlock() for _ in range(100)]
|
|
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| for block in self.double_blocks: |
| block.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| for block in self.single_blocks: |
| block.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
|
|
|
|
| class StyleUNet_Model(Style_Model): |
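| """Style hooks for UNet backbones: per-stage hooks for the input, middle,
| and output blocks, plus the (B, C, H, W) <-> (B, H*W, C) adaptation."""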
|
|
| def __init__(self, dtype=torch.float64, device=torch.device("cuda")): |
| super().__init__(dtype, device) |
| self.input_blocks = [StyleUNet_InputBlock() for _ in range(100)] |
| self.middle_blocks = [StyleUNet_MiddleBlock() for _ in range(100)] |
| self.output_blocks = [StyleUNet_OutputBlock() for _ in range(100)] |
|
|
| def set_len(self, h_len, w_len, img_slice, txt_slice, HEADS): |
| super().set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| for block in self.input_blocks: |
| block.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| for block in self.middle_blocks: |
| block.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
| for block in self.output_blocks: |
| block.set_len(h_len, w_len, img_slice, txt_slice, HEADS) |
|
|
| def __call__(self, x, attr):
| # Convert the (B, C, H, W) latent to (B, H*W, C) tokens and back; the
| # channel axis is moved explicitly since a raw reshape would interleave
| # channels and pixels.
| B, C, H, W = x.shape
| x = super().__call__(x.flatten(2).transpose(1, 2), attr)
| return x.transpose(1, 2).reshape(B, C, H, W)
| |
|
|