import torch

# from .latent_resizer import LatentResizer
from comfy import model_management
import os
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange


def normalization(channels):
    return nn.GroupNorm(32, channels)


def zero_module(module):
    """Zero out the parameters of a module and return it."""
    for p in module.parameters():
        p.detach().zero_()
    return module


class AttnBlock(nn.Module):
    """Single-head self-attention over the spatial dimensions."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalization(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Flatten the spatial dims into a token axis with a single head.
        b, c, h, w = q.shape
        q, k, v = map(
            lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
        )
        h_ = nn.functional.scaled_dot_product_attention(
            q, k, v
        )  # scale is dim ** -0.5 per default
        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x, **kwargs):
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        return x + h_


def make_attn(in_channels, attn_kwargs=None):
    return AttnBlock(in_channels)


class ResBlockEmb(nn.Module):
    """Residual block conditioned on an embedding (here: the scale factor)."""

    def __init__(
        self,
        channels,
        emb_channels,
        dropout=0,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        kernel_size=3,
        exchange_temb_dims=False,
        skip_t_emb=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        padding = kernel_size // 2

        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.skip_t_emb = skip_t_emb
        self.emb_out_channels = (
            2 * self.out_channels if use_scale_shift_norm else self.out_channels
        )
        if self.skip_t_emb:
            print(f"Skipping timestep embedding in {self.__class__.__name__}")
            assert not self.use_scale_shift_norm
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                nn.Linear(
                    emb_channels,
                    self.emb_out_channels,
                ),
            )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            zero_module(
                nn.Conv2d(
                    self.out_channels,
                    self.out_channels,
                    kernel_size,
                    padding=padding,
                )
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv2d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

    def forward(self, x, emb):
        h = self.in_layers(x)

        if self.skip_t_emb:
            emb_out = torch.zeros_like(h)
        else:
            emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]

        if self.use_scale_shift_norm:
            # FiLM-style conditioning: split into scale and shift.
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            if self.exchange_temb_dims:
                emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
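
# Shape sanity sketch (illustrative only, not part of the module): a
# ResBlockEmb maps (B, C, H, W) -> (B, C, H, W) given an embedding of shape
# (B, emb_channels):
#
#   block = ResBlockEmb(channels=128, emb_channels=32)
#   y = block(torch.randn(1, 128, 64, 64), torch.randn(1, 32))
#   assert y.shape == (1, 128, 64, 64)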

class LatentResizer(nn.Module):
    """Resizes 4-channel latents with conv/attention blocks around a bilinear
    interpolation, conditioned on an embedding of the scale factor."""

    def __init__(self, in_blocks=10, out_blocks=10, channels=128, dropout=0, attn=True):
        super().__init__()
        self.conv_in = nn.Conv2d(4, channels, 3, padding=1)
        self.channels = channels
        embed_dim = 32
        # Embeds the (scale - 1) scalar that conditions the ResBlocks.
        self.embed = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

        self.in_blocks = nn.ModuleList([])
        for b in range(in_blocks):
            if (b == 1 or b == in_blocks - 1) and attn:
                self.in_blocks.append(make_attn(channels))
            self.in_blocks.append(ResBlockEmb(channels, embed_dim, dropout))

        self.out_blocks = nn.ModuleList([])
        for b in range(out_blocks):
            if (b == 1 or b == out_blocks - 1) and attn:
                self.out_blocks.append(make_attn(channels))
            self.out_blocks.append(ResBlockEmb(channels, embed_dim, dropout))

        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, 4, 3, padding=1)

    @classmethod
    def load_model(cls, filename, device="cpu", dtype=torch.float32, dropout=0):
        # Older torch versions don't support the weights_only argument.
        if "weights_only" not in torch.load.__code__.co_varnames:
            weights = torch.load(filename, map_location=torch.device("cpu"))
        else:
            weights = torch.load(
                filename, map_location=torch.device("cpu"), weights_only=True
            )

        # Infer the architecture from the state dict: count blocks, and count
        # attention layers (identified by their "q" projection weights).
        in_blocks = 0
        out_blocks = 0
        in_tfs = 0
        out_tfs = 0
        channels = weights["conv_in.bias"].shape[0]
        for k in weights.keys():
            k = k.split(".")
            if k[0] == "in_blocks":
                in_blocks = max(in_blocks, int(k[1]))
                if k[2] == "q" and k[3] == "weight":
                    in_tfs += 1
            if k[0] == "out_blocks":
                out_blocks = max(out_blocks, int(k[1]))
                if k[2] == "q" and k[3] == "weight":
                    out_tfs += 1
        in_blocks = in_blocks + 1 - in_tfs
        out_blocks = out_blocks + 1 - out_tfs

        resizer = cls(
            in_blocks=in_blocks,
            out_blocks=out_blocks,
            channels=channels,
            dropout=dropout,
            attn=(out_tfs != 0),
        )
        resizer.load_state_dict(weights)
        resizer.eval()
        resizer.to(device, dtype=dtype)
        return resizer

    def forward(self, x, scale=None, size=None):
        if scale is None and size is None:
            raise ValueError("Either scale or size must be given")
        if scale is not None and size is not None:
            raise ValueError("Only one of scale and size may be given")
        if scale is not None:
            size = (x.shape[-2] * scale, x.shape[-1] * scale)
            size = tuple([int(round(i)) for i in size])
        else:
            scale = size[-1] / x.shape[-1]

        # Output is the same size as input
        if size == x.shape[-2:]:
            return x

        # The scale embedding is conditioned on (scale - 1).
        scale = torch.tensor([scale - 1], dtype=x.dtype).to(x.device).unsqueeze(0)
        emb = self.embed(scale)

        x = self.conv_in(x)
        for b in self.in_blocks:
            if isinstance(b, ResBlockEmb):
                x = b(x, emb)
            else:
                x = b(x)

        x = F.interpolate(x, size=size, mode="bilinear")

        for b in self.out_blocks:
            if isinstance(b, ResBlockEmb):
                x = b(x, emb)
            else:
                x = b(x)

        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        return x
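
# Usage sketch (illustrative, not part of the node API; assumes a resizer
# checkpoint such as "sdxl_resizer.pt" is available, and that latents are
# pre-multiplied by the model's scale factor, 0.13025 below):
#
#   resizer = LatentResizer.load_model(
#       "sdxl_resizer.pt", device="cuda", dtype=torch.float16
#   )
#   with torch.no_grad():
#       # (1, 4, 64, 64) latent -> (1, 4, 96, 96) at scale=1.5
#       out = resizer(0.13025 * latent, scale=1.5) / 0.13025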

########################################################


class NNLatentUpscale:
    """
    Upscales SDXL and SD 1.x latents using a neural network
    """

    def __init__(self):
        self.local_dir = os.path.dirname(os.path.realpath(__file__))
        self.scale_factor = 0.13025
        self.dtype = torch.float32
        if model_management.should_use_fp16():
            self.dtype = torch.float16
        self.weight_path = {
            "SDXL": os.path.join(self.local_dir, "sdxl_resizer.pt"),
            "SD 1.x": os.path.join(self.local_dir, "sd15_resizer.pt"),
        }
        self.version = "none"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latent": ("LATENT",),
                "version": (["SDXL", "SD 1.x"],),
                "upscale": (
                    "FLOAT",
                    {
                        "default": 1.5,
                        "min": 1.0,
                        "max": 2.0,
                        "step": 0.01,
                        "display": "number",
                    },
                ),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    def upscale(self, latent, version, upscale):
        device = model_management.get_torch_device()
        samples = latent["samples"].to(device=device, dtype=self.dtype)

        # Load the resizer weights lazily, and only when the version changes.
        if version != self.version:
            self.model = LatentResizer.load_model(
                self.weight_path[version], device, self.dtype
            )
            self.version = version

        self.model.to(device=device)
        latent_out = (
            self.model(self.scale_factor * samples, scale=upscale) / self.scale_factor
        )

        if self.dtype != torch.float32:
            latent_out = latent_out.to(dtype=torch.float32)

        latent_out = latent_out.to(device="cpu")
        # Offload the model alongside the VAE between invocations.
        self.model.to(device=model_management.vae_offload_device())

        return ({"samples": latent_out},)


NODE_CLASS_MAPPINGS = {"NNLatentUpscale": NNLatentUpscale}

NODE_DISPLAY_NAME_MAPPINGS = {"NNLatentUpscale": "EFF Latent Upscale"}
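
# Standalone smoke test (a hypothetical sketch; it only runs where
# `comfy.model_management` is importable, i.e. inside a ComfyUI environment,
# and assumes the resizer weights sit next to this file):
#
#   node = NNLatentUpscale()
#   (result,) = node.upscale(
#       {"samples": torch.randn(1, 4, 64, 64)}, version="SDXL", upscale=1.5
#   )
#   print(result["samples"].shape)  # torch.Size([1, 4, 96, 96])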