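"""Neural-network latent upscaler node for ComfyUI.

Resizes SD 1.x / SDXL latents with a small residual/attention network
(`LatentResizer`) instead of plain interpolation, and exposes it as the
`NNLatentUpscale` node.
"""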
import os

import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

from comfy import model_management

# from .latent_resizer import LatentResizer  # superseded: LatentResizer is defined below
def normalization(channels):
    """GroupNorm with 32 groups, as used throughout the resizer."""
    return nn.GroupNorm(32, channels)


def zero_module(module):
    """Zero-initialize a module's parameters so its block starts as an identity residual."""
    for p in module.parameters():
        p.detach().zero_()
    return module
class AttnBlock(nn.Module):
    """Single-head self-attention over spatial positions, with 1x1 conv q/k/v projections."""

    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = normalization(in_channels)
        self.q = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.k = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.v = nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
        self.proj_out = nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def attention(self, h_: torch.Tensor) -> torch.Tensor:
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # Flatten the spatial grid into a token axis: (b, c, h, w) -> (b, 1, h*w, c).
        b, c, h, w = q.shape
        q, k, v = map(
            lambda x: rearrange(x, "b c h w -> b 1 (h w) c").contiguous(), (q, k, v)
        )
        h_ = nn.functional.scaled_dot_product_attention(
            q, k, v
        )  # scale is dim ** -0.5 per default

        return rearrange(h_, "b 1 (h w) c -> b c h w", h=h, w=w, c=c, b=b)

    def forward(self, x, **kwargs):
        h_ = x
        h_ = self.attention(h_)
        h_ = self.proj_out(h_)
        # Residual connection around the attention block.
        return x + h_


def make_attn(in_channels, attn_kwargs=None):
    return AttnBlock(in_channels)
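# Shape sketch: AttnBlock preserves its input shape, e.g. a (b, 128, 32, 32)
# feature map comes back as (b, 128, 32, 32) after attention + residual.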
class ResBlockEmb(nn.Module):
    """Residual block conditioned on an embedding, adapted from diffusion UNet ResBlocks."""

    def __init__(
        self,
        channels,
        emb_channels,
        dropout=0,
        out_channels=None,
        use_conv=False,
        use_scale_shift_norm=False,
        kernel_size=3,
        exchange_temb_dims=False,
        skip_t_emb=False,
    ):
        super().__init__()
        self.channels = channels
        self.emb_channels = emb_channels
        self.dropout = dropout
        self.out_channels = out_channels or channels
        self.use_conv = use_conv
        self.use_scale_shift_norm = use_scale_shift_norm
        self.exchange_temb_dims = exchange_temb_dims

        padding = kernel_size // 2
        self.in_layers = nn.Sequential(
            normalization(channels),
            nn.SiLU(),
            nn.Conv2d(channels, self.out_channels, kernel_size, padding=padding),
        )

        self.skip_t_emb = skip_t_emb
        self.emb_out_channels = (
            2 * self.out_channels if use_scale_shift_norm else self.out_channels
        )
        if self.skip_t_emb:
            print(f"Skipping timestep embedding in {self.__class__.__name__}")
            assert not self.use_scale_shift_norm
            self.emb_layers = None
            self.exchange_temb_dims = False
        else:
            self.emb_layers = nn.Sequential(
                nn.SiLU(),
                nn.Linear(
                    emb_channels,
                    self.emb_out_channels,
                ),
            )

        self.out_layers = nn.Sequential(
            normalization(self.out_channels),
            nn.SiLU(),
            nn.Dropout(p=dropout),
            # Zero-initialized so the block starts out as an identity mapping.
            zero_module(
                nn.Conv2d(
                    self.out_channels,
                    self.out_channels,
                    kernel_size,
                    padding=padding,
                )
            ),
        )

        if self.out_channels == channels:
            self.skip_connection = nn.Identity()
        elif use_conv:
            self.skip_connection = nn.Conv2d(
                channels, self.out_channels, kernel_size, padding=padding
            )
        else:
            self.skip_connection = nn.Conv2d(channels, self.out_channels, 1)

    def forward(self, x, emb):
        h = self.in_layers(x)
        if self.skip_t_emb:
            emb_out = torch.zeros_like(h)
        else:
            emb_out = self.emb_layers(emb).type(h.dtype)
        # Broadcast the embedding over the spatial dimensions.
        while len(emb_out.shape) < len(h.shape):
            emb_out = emb_out[..., None]
        if self.use_scale_shift_norm:
            out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
            scale, shift = torch.chunk(emb_out, 2, dim=1)
            h = out_norm(h) * (1 + scale) + shift
            h = out_rest(h)
        else:
            if self.exchange_temb_dims:
                emb_out = rearrange(emb_out, "b t c ... -> b c t ...")
            h = h + emb_out
            h = self.out_layers(h)
        return self.skip_connection(x) + h
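# Shape sketch: ResBlockEmb(channels=128, emb_channels=32) maps x of shape
# (b, 128, h, w) together with emb of shape (b, 32) back to (b, 128, h, w)
# (or to out_channels if one was given).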
class LatentResizer(nn.Module):
    """Latent-space resizer: residual/attention encoder, bilinear resize, then decoder."""

    def __init__(self, in_blocks=10, out_blocks=10, channels=128, dropout=0, attn=True):
        super().__init__()
        self.conv_in = nn.Conv2d(4, channels, 3, padding=1)
        self.channels = channels

        # Embedding of the scale factor, used to condition every ResBlockEmb.
        embed_dim = 32
        self.embed = nn.Sequential(
            nn.Linear(1, embed_dim),
            nn.SiLU(),
            nn.Linear(embed_dim, embed_dim),
        )

        # Attention is interleaved after the second and last residual blocks.
        self.in_blocks = nn.ModuleList([])
        for b in range(in_blocks):
            if (b == 1 or b == in_blocks - 1) and attn:
                self.in_blocks.append(make_attn(channels))
            self.in_blocks.append(ResBlockEmb(channels, embed_dim, dropout))

        self.out_blocks = nn.ModuleList([])
        for b in range(out_blocks):
            if (b == 1 or b == out_blocks - 1) and attn:
                self.out_blocks.append(make_attn(channels))
            self.out_blocks.append(ResBlockEmb(channels, embed_dim, dropout))

        self.norm_out = normalization(channels)
        self.conv_out = nn.Conv2d(channels, 4, 3, padding=1)

    @classmethod
    def load_model(cls, filename, device="cpu", dtype=torch.float32, dropout=0):
        # Older torch.load versions do not accept the weights_only argument.
        if "weights_only" not in torch.load.__code__.co_varnames:
            weights = torch.load(filename, map_location=torch.device("cpu"))
        else:
            weights = torch.load(
                filename, map_location=torch.device("cpu"), weights_only=True
            )

        # Infer the architecture hyperparameters from the checkpoint's state dict.
        in_blocks = 0
        out_blocks = 0
        in_tfs = 0
        out_tfs = 0
        channels = weights["conv_in.bias"].shape[0]
        for k in weights.keys():
            k = k.split(".")
            if k[0] == "in_blocks":
                in_blocks = max(in_blocks, int(k[1]))
                if k[2] == "q" and k[3] == "weight":
                    in_tfs += 1
            if k[0] == "out_blocks":
                out_blocks = max(out_blocks, int(k[1]))
                if k[2] == "q" and k[3] == "weight":
                    out_tfs += 1
        # Module indices count both residual and attention blocks, so subtract
        # the attention blocks to recover the residual block count.
        in_blocks = in_blocks + 1 - in_tfs
        out_blocks = out_blocks + 1 - out_tfs

        resizer = cls(
            in_blocks=in_blocks,
            out_blocks=out_blocks,
            channels=channels,
            dropout=dropout,
            attn=(out_tfs != 0),
        )
        resizer.load_state_dict(weights)
        resizer.eval()
        resizer.to(device, dtype=dtype)
        return resizer

    def forward(self, x, scale=None, size=None):
        if scale is None and size is None:
            raise ValueError("Either scale or size must be given")
        if scale is not None and size is not None:
            raise ValueError("Only one of scale and size may be given")
        if scale is not None:
            size = (x.shape[-2] * scale, x.shape[-1] * scale)
            size = tuple([int(round(i)) for i in size])
        else:
            scale = size[-1] / x.shape[-1]

        # Output would be the same size as the input; nothing to do.
        if size == x.shape[-2:]:
            return x

        # The network is conditioned on (scale - 1), embedded to embed_dim.
        scale = torch.tensor([scale - 1], dtype=x.dtype).to(x.device).unsqueeze(0)
        emb = self.embed(scale)

        x = self.conv_in(x)
        for b in self.in_blocks:
            if isinstance(b, ResBlockEmb):
                x = b(x, emb)
            else:
                x = b(x)

        # Resize in feature space, then refine with the output blocks.
        x = F.interpolate(x, size=size, mode="bilinear")

        for b in self.out_blocks:
            if isinstance(b, ResBlockEmb):
                x = b(x, emb)
            else:
                x = b(x)

        x = self.norm_out(x)
        x = F.silu(x)
        x = self.conv_out(x)
        return x
########################################################
class NNLatentUpscale:
    """
    Upscales SD 1.x / SDXL latents using a neural network.
    """

    def __init__(self):
        self.local_dir = os.path.dirname(os.path.realpath(__file__))
        self.scale_factor = 0.13025
        self.dtype = torch.float32
        if model_management.should_use_fp16():
            self.dtype = torch.float16
        self.weight_path = {
            "SDXL": os.path.join(self.local_dir, "sdxl_resizer.pt"),
            "SD 1.x": os.path.join(self.local_dir, "sd15_resizer.pt"),
        }
        self.version = "none"

    @classmethod
    def INPUT_TYPES(s):
        return {
            "required": {
                "latent": ("LATENT",),
                "version": (["SDXL", "SD 1.x"],),
                "upscale": (
                    "FLOAT",
                    {
                        "default": 1.5,
                        "min": 1.0,
                        "max": 2.0,
                        "step": 0.01,
                        "display": "number",
                    },
                ),
            },
        }

    RETURN_TYPES = ("LATENT",)
    FUNCTION = "upscale"
    CATEGORY = "latent"

    def upscale(self, latent, version, upscale):
        device = model_management.get_torch_device()
        samples = latent["samples"].to(device=device, dtype=self.dtype)

        # Lazily (re)load the resizer when the requested model version changes.
        if version != self.version:
            self.model = LatentResizer.load_model(
                self.weight_path[version], device, self.dtype
            )
            self.version = version

        self.model.to(device=device)
        latent_out = (
            self.model(self.scale_factor * samples, scale=upscale) / self.scale_factor
        )

        if self.dtype != torch.float32:
            latent_out = latent_out.to(dtype=torch.float32)
        latent_out = latent_out.to(device="cpu")

        # Move the model off the compute device when done.
        self.model.to(device=model_management.vae_offload_device())

        return ({"samples": latent_out},)
NODE_CLASS_MAPPINGS = {
    "NNLatentUpscale": NNLatentUpscale,
}

NODE_DISPLAY_NAME_MAPPINGS = {
    "NNLatentUpscale": "EFF Latent Upscale",
}
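

# A minimal smoke-test sketch (not part of the node). It only runs where
# ComfyUI's `comfy` package is importable (the module imports it at the top)
# and assumes `sd15_resizer.pt` sits next to this file; the random latent
# stands in for a real VAE encoding.
if __name__ == "__main__":
    weights = os.path.join(
        os.path.dirname(os.path.realpath(__file__)), "sd15_resizer.pt"
    )
    model = LatentResizer.load_model(weights, device="cpu", dtype=torch.float32)
    latent = torch.randn(1, 4, 64, 64)  # a 64x64 latent ~ 512x512 pixels for SD 1.x
    with torch.no_grad():
        out = model(0.13025 * latent, scale=1.5) / 0.13025
    print(out.shape)  # expected: torch.Size([1, 4, 96, 96])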