Spaces:
Running
on
Zero
Running
on
Zero
Commit
·
012e1d0
1
Parent(s):
680053b
add inference
Browse files- NoiseTransformer.py +26 -0
- SVDNoiseUnet.py +430 -0
- app.py +330 -0
- dpm_solver_v3.py +904 -0
- free_lunch_utils.py +303 -0
- requirements.txt +11 -0
- sampler.py +315 -0
- uni_pc.py +757 -0
NoiseTransformer.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from timm import create_model
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
__all__ = ['NoiseTransformer']
|
| 8 |
+
|
| 9 |
+
class NoiseTransformer(nn.Module):
    """Swin-Transformer-based noise refiner.

    Maps a 4-channel latent to a refined 4-channel latent: resize up to the
    Swin input size (224x224), run a pretrained Swin-Tiny backbone, resize
    back to the latent resolution, and project channels with 1x1 convs.
    """

    def __init__(self, resolution=(128,96)):
        super().__init__()
        # Resize helpers: Swin expects 224x224 inputs; `downsample` restores
        # the latent spatial resolution afterwards.
        self.upsample = lambda x: F.interpolate(x, [224,224])
        self.downsample = lambda x: F.interpolate(x, [resolution[0],resolution[1]])
        # 1x1 channel projections.
        # NOTE(review): `upconv` expects 7 input channels — presumably the
        # channel layout of `swin.forward_features` output after resizing;
        # confirm against the installed timm version's feature format.
        self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
        self.downconv = nn.Conv2d(4,3,(1,1),(1,1),(0,0))
        # self.upconv = nn.Conv2d(7,4,(1,1),(1,1),(0,0))
        self.swin = create_model("swin_tiny_patch4_window7_224",pretrained=True)

    def forward(self, x, residual=False):
        # Pipeline: 4ch latent -> 3ch 224x224 -> Swin features -> latent-size
        # feature map -> 4ch latent. `residual=True` adds the input back as a
        # skip connection.
        if residual:
            x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x))))) + x
        else:
            x = self.upconv(self.downsample(self.swin.forward_features(self.downconv(self.upsample(x)))))

        return x
|
SVDNoiseUnet.py
ADDED
|
@@ -0,0 +1,430 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import einops
|
| 4 |
+
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
from torch.jit import Final
|
| 7 |
+
from timm.layers import use_fused_attn
|
| 8 |
+
from timm.models.layers import PatchEmbed, Mlp, DropPath, trunc_normal_, lecun_normal_, get_act_layer
|
| 9 |
+
from abc import abstractmethod
|
| 10 |
+
from NoiseTransformer import NoiseTransformer
|
| 11 |
+
from einops import rearrange
|
| 12 |
+
__all__ = ['SVDNoiseUnet', 'SVDNoiseUnet_Concise']
|
| 13 |
+
|
| 14 |
+
class Attention(nn.Module):
    """Multi-head self-attention (adapted from timm's ViT Attention).

    Uses torch's fused scaled_dot_product_attention when timm reports it is
    available; otherwise falls back to explicit softmax attention.
    """
    fused_attn: Final[bool]

    def __init__(
            self,
            dim: int,
            num_heads: int = 8,
            qkv_bias: bool = False,
            qk_norm: bool = False,
            attn_drop: float = 0.,
            proj_drop: float = 0.,
            norm_layer: nn.Module = nn.LayerNorm,
    ) -> None:
        super().__init__()
        assert dim % num_heads == 0, 'dim should be divisible by num_heads'
        self.num_heads = num_heads
        self.head_dim = dim // num_heads
        self.scale = self.head_dim ** -0.5
        self.fused_attn = use_fused_attn()

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        # Optional per-head LayerNorm on queries/keys (QK-norm).
        self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv.unbind(0)
        q, k = self.q_norm(q), self.k_norm(k)

        if self.fused_attn:
            # Fused kernel applies scaling and dropout internally.
            x = F.scaled_dot_product_attention(
                q, k, v,
                dropout_p=self.attn_drop.p if self.training else 0.,
            )
        else:
            q = q * self.scale
            attn = q @ k.transpose(-2, -1)
            attn = attn.softmax(dim=-1)
            attn = self.attn_drop(attn)
            x = attn @ v

        # Merge heads back: (B, heads, N, head_dim) -> (B, N, C).
        x = x.transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class SVDNoiseUnet(nn.Module):
    """Refines noise in the SVD domain of a folded latent.

    The 4-channel latent is folded into one 2D matrix per sample, decomposed
    with SVD, and small MLP/attention heads predict a correction to the
    singular values before the matrix is re-assembled and unfolded.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=(128,96)): # resolution = size // 8
        super(SVDNoiseUnet, self).__init__()

        # The folded matrix is (2*resolution[0]) x (2*resolution[1]); these
        # are the per-axis feature widths seen by the MLPs below.
        _in_1 = int(resolution[0] * in_channels // 2)
        _out_1 = int(resolution[0] * out_channels // 2)

        _in_2 = int(resolution[1] * in_channels // 2)
        _out_2 = int(resolution[1] * out_channels // 2)
        self.mlp1 = nn.Sequential(
            nn.Linear(_in_1, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out_1),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(_in_2, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out_2),
        )

        self.mlp3 = nn.Sequential(
            nn.Linear(_in_2, _out_2),
        )

        self.attention = Attention(_out_2)

        # NOTE(review): hard-coded 256/192 match resolution=(128,96) only;
        # other resolutions would need different BatchNorm/FFN sizes.
        self.bn = nn.BatchNorm1d(256)
        self.bn2 = nn.BatchNorm1d(192)

        self.mlp4 = nn.Sequential(
            nn.Linear(_out_2, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, _out_2),
        )
        self.ffn = nn.Sequential(
            nn.Linear(256, 384), # Expand
            nn.ReLU(inplace=True),
            nn.Linear(384, 192) # Reduce to target size
        )
        self.ffn2 = nn.Sequential(
            nn.Linear(256, 384), # Expand
            nn.ReLU(inplace=True),
            nn.Linear(384, 192) # Reduce to target size
        )
        # self.adaptive_pool = nn.AdaptiveAvgPool2d((256, 192))

    def forward(self, x, residual=False):
        b, c, h, w = x.shape
        # Fold (b, 4, h, w) into one (b, 2h, 2w) matrix per sample by tiling
        # the 4 channels into a 2x2 spatial grid.
        # NOTE(review): original shape comments below assume a square 256x256
        # matrix; with resolution=(128,96) it is actually (b, 256, 192).
        x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
        # NOTE(review): torch.linalg.svd returns (U, S, Vh); `V` here is V^H,
        # which is consistent with the `U @ diag @ V` reconstruction below.
        U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
        U_T = U.permute(0, 2, 1)
        U_out = self.ffn(self.mlp1(U_T))
        U_out = self.bn(U_out)
        U_out = U_out.transpose(1, 2)
        U_out = self.ffn2(U_out) # [b, 256, 256] -> [b, 256, 192]
        U_out = self.bn2(U_out)
        U_out = U_out.transpose(1, 2)
        # U_out = self.bn(U_out)
        V_out = self.mlp2(V)
        s_out = self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
        out = U_out + V_out + s_out
        # print(out.size())
        out = out.squeeze(1)
        # Pool the attention output over the sequence axis to get one
        # correction vector of singular values per sample.
        out = self.attention(out).mean(1)
        out = self.mlp4(out) + s
        # Rebuild a tall rectangular diagonal so U @ Sigma @ Vh shapes align.
        diagonal_out = torch.diag_embed(out)
        padded_diag = F.pad(diagonal_out, (0, 0, 0, 64), mode='constant', value=0) # Shape: [b, 1, 256, 192]
        pred = U @ padded_diag @ V
        # Unfold back to (b, 4, h, w).
        return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
|
| 134 |
+
|
| 135 |
+
class SVDNoiseUnet64(nn.Module):
    """Square-latent (64x64) variant of SVDNoiseUnet.

    Same SVD-domain scheme as SVDNoiseUnet but for square latents, so no
    padding of the singular-value diagonal is needed.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=64): # resolution = size // 8
        super(SVDNoiseUnet64, self).__init__()

        # The folded matrix is (2*resolution) x (2*resolution).
        _in = int(resolution * in_channels // 2)
        _out = int(resolution * out_channels // 2)
        self.mlp1 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )

        self.mlp3 = nn.Sequential(
            nn.Linear(_in, _out),
        )

        self.attention = Attention(_out)

        # NOTE(review): defined but never used in forward().
        self.bn = nn.BatchNorm2d(_out)

        self.mlp4 = nn.Sequential(
            nn.Linear(_out, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, _out),
        )

    def forward(self, x, residual=False):
        b, c, h, w = x.shape
        # Fold the 4 channels into a 2x2 spatial grid -> (b, 2h, 2w).
        x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
        # NOTE(review): torch.linalg.svd returns (U, S, Vh); `V` is V^H here.
        U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
        U_T = U.permute(0, 2, 1)
        out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
        out = self.attention(out).mean(1)
        # Predicted residual on the singular values.
        out = self.mlp4(out) + s
        pred = U @ torch.diag_embed(out) @ V
        return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
|
| 176 |
+
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class SVDNoiseUnet128(nn.Module):
    """Square-latent (128x128) variant of SVDNoiseUnet.

    NOTE(review): structurally identical to SVDNoiseUnet64 except the default
    resolution; the two could share one implementation.
    """

    def __init__(self, in_channels=4, out_channels=4, resolution=128): # resolution = size // 8
        super(SVDNoiseUnet128, self).__init__()

        # The folded matrix is (2*resolution) x (2*resolution).
        _in = int(resolution * in_channels // 2)
        _out = int(resolution * out_channels // 2)
        self.mlp1 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )
        self.mlp2 = nn.Sequential(
            nn.Linear(_in, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, _out),
        )

        self.mlp3 = nn.Sequential(
            nn.Linear(_in, _out),
        )

        self.attention = Attention(_out)

        # NOTE(review): defined but never used in forward().
        self.bn = nn.BatchNorm2d(_out)

        self.mlp4 = nn.Sequential(
            nn.Linear(_out, 1024),
            nn.ReLU(inplace=True),
            nn.Linear(1024, _out),
        )

    def forward(self, x, residual=False):
        b, c, h, w = x.shape
        # Fold the 4 channels into a 2x2 spatial grid -> (b, 2h, 2w).
        x = einops.rearrange(x, "b (a c)h w ->b (a h)(c w)", a=2,c=2) # x -> [1, 256, 256]
        # NOTE(review): torch.linalg.svd returns (U, S, Vh); `V` is V^H here.
        U, s, V = torch.linalg.svd(x) # U->[b 256 256], s-> [b 256], V->[b 256 256]
        U_T = U.permute(0, 2, 1)
        out = self.mlp1(U_T) + self.mlp2(V) + self.mlp3(s).unsqueeze(1) # s -> [b, 1, 256] => [b, 256, 256]
        out = self.attention(out).mean(1)
        # Predicted residual on the singular values.
        out = self.mlp4(out) + s
        pred = U @ torch.diag_embed(out) @ V
        return einops.rearrange(pred, "b (a h)(c w) -> b (a c) h w", a=2,c=2)
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
class SVDNoiseUnet_Concise(nn.Module):
    # NOTE(review): stub — exported in __all__ but defines no layers and no
    # forward(); the constructor only calls super(). Presumably a placeholder
    # for a planned concise variant.
    def __init__(self, in_channels=4, out_channels=4, resolution=64):
        super(SVDNoiseUnet_Concise, self).__init__()
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
from diffusers.models.normalization import AdaGroupNorm
|
| 229 |
+
|
| 230 |
+
class NPNet(nn.Module):
    """Noise Prompt Network: turns an initial latent noise into "golden noise".

    Combines an SVD-domain refiner (`unet_svd`), a Swin-based refiner
    (`unet_embedding`), and a prompt-conditioned normalization
    (`text_embedding`) with two learned scalars `_alpha`/`_beta`.
    """

    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
        super(NPNet, self).__init__()

        assert model_id in ['SD1.5', 'DreamShaper', 'DiT']

        self.model_id = model_id
        self.device = device
        self.pretrained_path = pretrained_path

        (
            self.unet_svd,
            self.unet_embedding,
            self.text_embedding,
            self._alpha,
            self._beta,
        ) = self.get_model()

    def save_model(self, save_path: str):
        """
        Save this NPNet so that get_model() can later reload it.
        """
        torch.save({
            "unet_svd": self.unet_svd.state_dict(),
            "unet_embedding": self.unet_embedding.state_dict(),
            "embeeding": self.text_embedding.state_dict(),  # (sic) key must match get_model()
            "alpha": self._alpha,
            "beta": self._beta,
        }, save_path)
        print(f"NPNet saved to {save_path}")

    def get_model(self):
        """Build the submodules; load weights when `pretrained_path` is a .pth file.

        Returns:
            (unet_svd, unet_embedding, text_embedding, _alpha, _beta)
        """
        unet_embedding = NoiseTransformer(resolution=(128, 96)).to(self.device).to(torch.float32)
        unet_svd = SVDNoiseUnet(resolution=(128, 96)).to(self.device).to(torch.float32)

        # The original had identical DiT/non-DiT branches; collapsed to one.
        text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)

        # Random initialization, overwritten below when a checkpoint is loaded.
        _alpha = torch.randn(1, device=self.device)
        _beta = torch.randn(1, device=self.device)

        if '.pth' in self.pretrained_path:
            # map_location: allow loading checkpoints saved on another device
            # (e.g. a cuda-saved checkpoint on a cpu-only host).
            golden_unet = torch.load(self.pretrained_path, map_location=self.device)
            unet_svd.load_state_dict(golden_unet["unet_svd"], strict=True)
            unet_embedding.load_state_dict(golden_unet["unet_embedding"], strict=True)
            text_embedding.load_state_dict(golden_unet["embeeding"], strict=True)
            _alpha = golden_unet["alpha"]
            _beta = golden_unet["beta"]

            print("Load Successfully!")

        return unet_svd, unet_embedding, text_embedding, _alpha, _beta

    def forward(self, initial_noise, prompt_embeds):
        """Refine `initial_noise` into golden noise conditioned on the prompt."""
        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)

        encoder_hidden_states_svd = initial_noise
        encoder_hidden_states_embedding = initial_noise + text_emb

        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())

        # Learned blend: 2*sigmoid(_alpha)-1 maps the text term's weight
        # into (-1, 1); _beta scales the Swin-refined term.
        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
            2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding

        return golden_noise
|
| 303 |
+
|
| 304 |
+
|
| 305 |
+
class NPNet64(nn.Module):
    """NPNet variant for 64x64 latents (512px images, SD1.5-scale models)."""

    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
        super(NPNet64, self).__init__()
        self.model_id = model_id
        self.device = device
        self.pretrained_path = pretrained_path

        (
            self.unet_svd,
            self.unet_embedding,
            self.text_embedding,
            self._alpha,
            self._beta,
        ) = self.get_model()

    def save_model(self, save_path: str):
        """
        Save this NPNet so that get_model() can later reload it.
        """
        torch.save({
            "unet_svd": self.unet_svd.state_dict(),
            "unet_embedding": self.unet_embedding.state_dict(),
            "embeeding": self.text_embedding.state_dict(),  # (sic) key must match get_model()
            "alpha": self._alpha,
            "beta": self._beta,
        }, save_path)
        print(f"NPNet saved to {save_path}")

    def get_model(self):
        """Build the submodules; load weights when `pretrained_path` is a .pth file.

        Bug fix vs. original: without a checkpoint the original fell off the
        end and returned None, crashing the tuple-unpack in __init__. Now the
        freshly initialized modules are always returned.
        """
        unet_embedding = NoiseTransformer(resolution=(64, 64)).to(self.device).to(torch.float32)
        unet_svd = SVDNoiseUnet64(resolution=64).to(self.device).to(torch.float32)
        # Random initialization, overwritten below when a checkpoint is loaded.
        _alpha = torch.randn(1, device=self.device)
        _beta = torch.randn(1, device=self.device)

        text_embedding = AdaGroupNorm(768 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)

        if '.pth' in self.pretrained_path:
            # map_location: allow loading checkpoints saved on another device.
            golden_unet = torch.load(self.pretrained_path, map_location=self.device)
            unet_svd.load_state_dict(golden_unet["unet_svd"])
            unet_embedding.load_state_dict(golden_unet["unet_embedding"])
            text_embedding.load_state_dict(golden_unet["embeeding"])
            _alpha = golden_unet["alpha"]
            _beta = golden_unet["beta"]

            print("Load Successfully!")

        return unet_svd, unet_embedding, text_embedding, _alpha, _beta

    def forward(self, initial_noise, prompt_embeds):
        """Refine `initial_noise` into golden noise conditioned on the prompt."""
        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)

        encoder_hidden_states_svd = initial_noise
        encoder_hidden_states_embedding = initial_noise + text_emb

        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())

        # Learned blend: 2*sigmoid(_alpha)-1 in (-1,1) weights the text term.
        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
            2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding

        return golden_noise
|
| 370 |
+
|
| 371 |
+
class NPNet128(nn.Module):
    """NPNet variant for 128x128 latents (1024px images, SDXL-scale models).

    Unlike NPNet/NPNet64 this variant requires pretrained weights: get_model()
    raises when no usable checkpoint path is given.
    """

    def __init__(self, model_id, pretrained_path=' ', device='cuda') -> None:
        # Bug fix vs. original: the default was `pretrained_path=True`, which
        # made `'.pth' in self.pretrained_path` raise TypeError. A string
        # default matches the sibling classes; a missing checkpoint now fails
        # with an explicit FileNotFoundError in get_model().
        super(NPNet128, self).__init__()

        assert model_id in ['SDXL', 'DreamShaper', 'DiT']

        self.model_id = model_id
        self.device = device
        self.pretrained_path = pretrained_path

        (
            self.unet_svd,
            self.unet_embedding,
            self.text_embedding,
            self._alpha,
            self._beta,
        ) = self.get_model()

    def get_model(self):
        """Build the submodules and load the required checkpoint.

        Raises:
            FileNotFoundError: when `pretrained_path` is not a .pth file.
              (The original used `assert "..."` — a no-op on a truthy string —
              and then implicitly returned None, crashing the caller.)
        """
        unet_embedding = NoiseTransformer(resolution=(128, 128)).to(self.device).to(torch.float32)
        unet_svd = SVDNoiseUnet128(resolution=128).to(self.device).to(torch.float32)

        # DiT uses 1024-dim text embeddings; SDXL-family models use 2048-dim.
        if self.model_id == 'DiT':
            text_embedding = AdaGroupNorm(1024 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)
        else:
            text_embedding = AdaGroupNorm(2048 * 77, 4, 1, eps=1e-6).to(self.device).to(torch.float32)

        if '.pth' in self.pretrained_path:
            # map_location: allow loading checkpoints saved on another device.
            golden_unet = torch.load(self.pretrained_path, map_location=self.device)
            unet_svd.load_state_dict(golden_unet["unet_svd"])
            unet_embedding.load_state_dict(golden_unet["unet_embedding"])
            text_embedding.load_state_dict(golden_unet["embeeding"])
            _alpha = golden_unet["alpha"]
            _beta = golden_unet["beta"]

            print("Load Successfully!")

            return unet_svd, unet_embedding, text_embedding, _alpha, _beta

        raise FileNotFoundError("No Pretrained Weights Found!")

    def forward(self, initial_noise, prompt_embeds):
        """Refine `initial_noise` into golden noise conditioned on the prompt."""
        prompt_embeds = prompt_embeds.float().view(prompt_embeds.shape[0], -1)
        text_emb = self.text_embedding(initial_noise.float(), prompt_embeds)

        encoder_hidden_states_svd = initial_noise
        encoder_hidden_states_embedding = initial_noise + text_emb

        golden_embedding = self.unet_embedding(encoder_hidden_states_embedding.float())

        # Learned blend: 2*sigmoid(_alpha)-1 in (-1,1) weights the text term.
        golden_noise = self.unet_svd(encoder_hidden_states_svd.float()) + (
            2 * torch.sigmoid(self._alpha) - 1) * text_emb + self._beta * golden_embedding

        return golden_noise
|
| 430 |
+
|
app.py
CHANGED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import numpy as np
|
| 3 |
+
import random
|
| 4 |
+
import json
|
| 5 |
+
import spaces #[uncomment to use ZeroGPU]
|
| 6 |
+
from diffusers import (
|
| 7 |
+
AutoencoderKL,
|
| 8 |
+
StableDiffusionXLPipeline,
|
| 9 |
+
)
|
| 10 |
+
from huggingface_hub import login, hf_hub_download
|
| 11 |
+
from PIL import Image
|
| 12 |
+
# from huggingface_hub import login
|
| 13 |
+
from SVDNoiseUnet import NPNet64
|
| 14 |
+
import functools
|
| 15 |
+
import random
|
| 16 |
+
from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
from einops import rearrange
|
| 20 |
+
from torchvision.utils import make_grid
|
| 21 |
+
import time
|
| 22 |
+
from pytorch_lightning import seed_everything
|
| 23 |
+
from torch import autocast
|
| 24 |
+
from contextlib import contextmanager, nullcontext
|
| 25 |
+
import accelerate
|
| 26 |
+
import torchsde
|
| 27 |
+
from SVDNoiseUnet import NPNet128
|
| 28 |
+
from tqdm import tqdm, trange
|
| 29 |
+
from itertools import islice
|
| 30 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 31 |
+
model_repo_id = "Lykon/dreamshaper-xl-1-0" # Replace to the model you would like to use
|
| 32 |
+
from sampler import UniPCSampler
|
| 33 |
+
|
| 34 |
+
precision_scope = autocast
|
| 35 |
+
|
| 36 |
+
def chunk(it, size):
    """Split an iterable into consecutive tuples of length `size`.

    The final tuple may be shorter; an empty iterable yields nothing.
    """
    source = iter(it)

    def _take():
        return tuple(islice(source, size))

    # iter(callable, sentinel): stop when a chunk comes back empty.
    return iter(_take, ())
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def numpy_to_pil(images):
    """
    Convert a numpy image or a batch of images to a PIL image.

    A single HWC array is promoted to a batch of one; float values in [0, 1]
    are scaled to uint8 before conversion.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    as_uint8 = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in as_uint8]
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def load_replacement(x):
    # Replace a (presumably safety-flagged) image array with the placeholder
    # "assets/rick.jpeg", resized to match; on ANY failure (missing asset,
    # shape mismatch) silently return the original image unchanged.
    try:
        hwc = x.shape
        # PIL resize takes (width, height) = (shape[1], shape[0]).
        y = Image.open("assets/rick.jpeg").convert("RGB").resize((hwc[1], hwc[0]))
        y = (np.array(y) / 255.0).astype(x.dtype)
        assert y.shape == x.shape
        return y
    except Exception:
        return x
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
# Adapted from pipelines.StableDiffusionPipeline.encode_prompt
|
| 65 |
+
def encode_prompt(prompt_batch, text_encoder, tokenizer, proportion_empty_prompts, is_train=True):
    """Tokenize and encode a batch of captions with a text encoder.

    With probability `proportion_empty_prompts` a caption is replaced by the
    empty string (classifier-free-guidance-style dropout). List/array captions
    contribute one element: random during training, the first one otherwise.
    Returns the encoder's hidden states (first output) with no gradients.
    """
    captions = []
    for caption in prompt_batch:
        if random.random() < proportion_empty_prompts:
            captions.append("")
        elif isinstance(caption, str):
            captions.append(caption)
        elif isinstance(caption, (list, np.ndarray)):
            # take a random caption if there are multiple
            captions.append(random.choice(caption) if is_train else caption[0])

    with torch.no_grad():
        text_inputs = tokenizer(
            captions,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids
        prompt_embeds = text_encoder(text_input_ids.to(text_encoder.device))[0]

    return prompt_embeds
|
| 88 |
+
|
| 89 |
+
def chunk(it, size):
    """Split an iterable into consecutive tuples of length `size`.

    NOTE(review): exact duplicate of the `chunk` defined earlier in this
    file; this later definition is the one bound at import time.
    """
    iterator = iter(it)
    next_piece = lambda: tuple(islice(iterator, size))
    return iter(next_piece, ())
|
| 92 |
+
|
| 93 |
+
def convert_caption_json_to_str(json):
    """Return the "caption" field of a caption record (dict).

    NOTE(review): the parameter name shadows the module-level `json` import
    inside this function; kept unchanged to preserve keyword-call
    compatibility for existing callers.
    """
    return json["caption"]
|
| 96 |
+
|
| 97 |
+
def prepare_sdxl_pipeline_step_parameter(pipe, prompts, need_cfg, device, negative_prompts, W = 1024, H = 1024):
    """Build SDXL UNet conditioning for a manual sampling loop.

    Encodes prompts/negatives via the pipeline and assembles the SDXL
    `added_cond_kwargs` (pooled text embeds + micro-conditioning time ids).
    When `need_cfg` is True the negative and positive halves are concatenated
    along the batch dimension in [negative, positive] order.

    Returns:
        (prompt_embeds, {"text_embeds": ..., "time_ids": ...})
    """
    (
        prompt_embeds,
        negative_prompt_embeds,
        pooled_prompt_embeds,
        negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=prompts,
        negative_prompt=negative_prompts,
        device=device,
        do_classifier_free_guidance=need_cfg,
    )
    # timesteps = pipe.scheduler.timesteps

    prompt_embeds = prompt_embeds.to(device)
    add_text_embeds = pooled_prompt_embeds.to(device)
    # SDXL micro-conditioning: original size, crop offsets, target size.
    original_size = (W, H)
    crops_coords_top_left = (0, 0)
    target_size = (W, H)
    text_encoder_projection_dim = None
    add_time_ids = list(original_size + crops_coords_top_left + target_size)
    if pipe.text_encoder_2 is None:
        text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
    else:
        text_encoder_projection_dim = pipe.text_encoder_2.config.projection_dim
    # Sanity check: the UNet's add_embedding must accept the assembled vector.
    passed_add_embed_dim = (
        pipe.unet.config.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
    )
    expected_add_embed_dim = pipe.unet.add_embedding.linear_1.in_features
    if expected_add_embed_dim != passed_add_embed_dim:
        raise ValueError(
            f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
        )
    add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)
    add_time_ids = add_time_ids.to(device)
    # Negative prompt shares the same micro-conditioning.
    negative_add_time_ids = add_time_ids

    if need_cfg:
        # CFG layout: [negative, positive] along the batch dimension.
        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
        add_text_embeds = torch.cat([negative_pooled_prompt_embeds, add_text_embeds], dim=0)
        add_time_ids = torch.cat([negative_add_time_ids, add_time_ids], dim=0)
    ret_dict = {
        "text_embeds": add_text_embeds,
        "time_ids": add_time_ids
    }
    return prompt_embeds, ret_dict
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def model_closure(pipe):
    """Bind `pipe` into a sampler-compatible denoiser callable.

    The returned function has signature (x, t, c), where `c` is a sequence:
    c[0] is the prompt-embedding tensor, and c[1] (optional) is the SDXL
    added-condition kwargs dict. The prompt is cast to x's device/dtype and
    the UNet's `.sample` tensor is returned.
    """
    def model_fn(x, t, c):
        prompt = c[0]
        cond_kwargs = c[1] if len(c) > 1 else None
        hidden_states = prompt.to(device=x.device, dtype=x.dtype)
        result = pipe.unet(
            x,
            t,
            encoder_hidden_states=hidden_states,
            added_cond_kwargs=cond_kwargs,
        )
        return result.sample

    return model_fn
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
# ---- Global model setup (runs once at import time; requires CUDA) ----
torch_dtype = torch.float16
repo_id = "madebyollin/sdxl-vae-fp16-fix"  # e.g., "distilbert/distilgpt2"
# fp16-safe SDXL VAE (avoids NaNs when decoding in float16).
vae = AutoencoderKL.from_pretrained("madebyollin/sdxl-vae-fp16-fix",torch_dtype=torch_dtype) #from_single_file(downloaded_path, torch_dtype=torch_dtype)
vae.to('cuda')

pipe = StableDiffusionXLPipeline.from_pretrained("John6666/illustrij-evo-lvl3-sdxl",torch_dtype=torch_dtype,vae=vae)
# pipe = StableDiffusionXLPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0",torch_dtype=torch.float16,vae=vae)

pipe.to('cuda')

MAX_SEED = np.iinfo(np.int32).max
MAX_IMAGE_SIZE = 1024

accelerator = accelerate.Accelerator()
|
| 175 |
+
|
| 176 |
+
def generate_image_with_steps(prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps):
    """Sample one image with the UniPC sampler at `num_inference_steps` steps.

    Args:
        prompt / negative_prompt: text conditioning (negative used only when CFG is on).
        seed: RNG seed controlling the initial latent noise.
        width, height: output resolution in pixels (must be multiples of 8).
        guidance_scale: CFG scale; 1.0 disables classifier-free guidance.
        num_inference_steps: number of sampler steps (UI offers 6 or 8).

    Returns:
        A PIL.Image with the decoded sample.
    """
    # Fix the global RNG so `seed` actually controls the initial noise.
    # Previously `seed` was accepted but never used, so "randomize seed" had
    # no effect and the quick vs. 50-step images did not share latents.
    torch.manual_seed(seed)

    prompts = [prompt]
    sampler = UniPCSampler(pipe, model_closure=model_closure, steps=num_inference_steps, guidance_scale=guidance_scale)
    c = prompts
    # CFG needs an unconditional branch; skip it entirely when guidance is off.
    uc = [negative_prompt] * len(c) if guidance_scale != 1.0 else None
    # Latent shape: (channels, width/8, height/8).  NOTE(review): width comes
    # before height here, mirroring the original code — confirm the sampler
    # expects this order.
    shape = [4, width // 8, height // 8]
    samples, _ = sampler.sample(
        conditioning=c,
        batch_size=1,
        shape=shape,
        unconditional_conditioning=uc,
        x_T=None,  # let the sampler draw the initial noise (now seeded above)
        # FreeU kicks in later when more steps are available.
        start_free_u_step=6 if num_inference_steps == 8 else 4,
        xl_preprocess_closure=prepare_sdxl_pipeline_step_parameter,
        # npnet = npn_net,
        use_corrector=True,
    )

    # Decode latents to pixel space and map from [-1, 1] to [0, 1].
    x_samples = pipe.vae.decode(samples / pipe.vae.config.scaling_factor).sample
    x_samples = torch.clamp((x_samples + 1.0) / 2.0, min=0.0, max=1.0)
    x_samples = x_samples.cpu().permute(0, 2, 3, 1).numpy()

    x_image_torch = torch.from_numpy(x_samples).permute(0, 3, 1, 2)

    # batch_size is 1, so return the first (only) decoded image.
    for x_sample in x_image_torch:
        x_sample = 255.0 * rearrange(x_sample.cpu().numpy(), "c h w -> h w c")
        img = Image.fromarray(x_sample.astype(np.uint8))
        return img
|
| 207 |
+
|
| 208 |
+
@spaces.GPU
def infer(
    prompt,
    negative_prompt,
    seed,
    randomize_seed,
    resolution,
    guidance_scale,
    num_inference_steps,
    progress=gr.Progress(track_tqdm=True),
):
    """Gradio callback: produce the fast few-step image and a 50-step
    reference image with the same settings, plus the seed actually used."""
    if randomize_seed:
        seed = random.randint(0, MAX_SEED)

    # "1024x1024" -> (1024, 1024)
    width, height = (int(part) for part in resolution.split('x'))

    # Fast few-step result.
    image_quick = generate_image_with_steps(
        prompt, negative_prompt, seed, width, height, guidance_scale, num_inference_steps
    )

    # High-quality 50-step baseline for side-by-side comparison.
    image_50_steps = generate_image_with_steps(
        prompt, negative_prompt, seed, width, height, guidance_scale, 50
    )

    return image_quick, image_50_steps, seed
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
# Example prompts shown under the input box; clicking one fills the prompt field.
examples = [
    "Astronaut in a jungle, cold color, muted colors, detailed, 8k",
    "a painting of a virus monster playing guitar",
    "a painting of a squirrel eating a burger",
]

# Page-level CSS: center the main column and cap its width.
css = """
#col-container {
    margin: 0 auto;
    max-width: 640px;
}
"""
|
| 246 |
+
|
| 247 |
+
# ---------------------------------------------------------------------------
# Gradio UI: prompt box + run button, side-by-side comparison of the fast
# few-step result against a 50-step reference, and advanced settings.
# ---------------------------------------------------------------------------
with gr.Blocks(css=css) as demo:
    with gr.Column(elem_id="col-container"):
        gr.Markdown(" # Hyperparameters are all you need")

        with gr.Row():
            prompt = gr.Text(
                label="Prompt",
                show_label=False,
                max_lines=1,
                placeholder="Enter your prompt",
                container=False,
            )

            run_button = gr.Button("Run", scale=0, variant="primary")

        # Two result panes: the few-step sampler output and the 50-step baseline.
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Our fast inference Result")
                result = gr.Image(label="Quick Result", show_label=False)
            with gr.Column():
                gr.Markdown("### Original 50 steps Result")
                result_50_steps = gr.Image(label="50 Steps Result", show_label=False)

        with gr.Accordion("Advanced Settings", open=False):
            # Currently hidden; `infer` still receives its (empty) value.
            negative_prompt = gr.Text(
                label="Negative prompt",
                max_lines=1,
                placeholder="Enter a negative prompt",
                visible=False,
            )

            seed = gr.Slider(
                label="Seed",
                minimum=0,
                maximum=MAX_SEED,
                step=1,
                value=0,
            )

            randomize_seed = gr.Checkbox(label="Randomize seed", value=True)

            # SDXL-native resolutions: square plus ~3:2 landscape/portrait.
            resolution = gr.Dropdown(
                choices=[
                    "1024x1024",
                    "1216x832",
                    "832x1216"
                ],
                value="1024x1024",
                label="Resolution",
            )

            with gr.Row():
                guidance_scale = gr.Slider(
                    label="Guidance scale",
                    minimum=0.0,
                    maximum=10.0,
                    step=0.1,
                    value=7.5,
                )

                # Only 6 or 8 steps: `generate_image_with_steps` tunes its
                # FreeU start step for exactly these two counts.
                num_inference_steps = gr.Dropdown(
                    choices=[6, 8],
                    value=8,
                    label="Number of inference steps",
                )

        gr.Examples(examples=examples, inputs=[prompt])
    # Wire both the Run button and pressing Enter in the prompt box to `infer`.
    gr.on(
        triggers=[run_button.click, prompt.submit],
        fn=infer,
        inputs=[
            prompt,
            negative_prompt,
            seed,
            randomize_seed,
            resolution,
            guidance_scale,
            num_inference_steps,
        ],
        outputs=[result, result_50_steps, seed],
    )

if __name__ == "__main__":
    demo.launch()
|
dpm_solver_v3.py
ADDED
|
@@ -0,0 +1,904 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import numpy as np
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class NoiseScheduleVP:
    """Wrapper of the forward SDE (VP type) used by DPM-Solver.

    The forward SDE gives q_{t|0}(x_t | x_0) = N(alpha_t * x_0, sigma_t^2 * I),
    and lambda_t = log(alpha_t) - log(sigma_t) is the half-logSNR.  The class
    exposes alpha_t, sigma_t, lambda_t and the inverse of lambda for t in [0, T].

    Supported schedules:
      * "discrete" -- discrete-time DPMs; pass either `betas` or
        `alphas_cumprod` (the \\hat{alpha}_n array of DDPM, so
        alpha_{t_n} = sqrt(\\hat{alpha}_n)).  Discrete steps n = 0..N-1 are
        mapped to continuous times t_n = (n + 1) / N.
      * "linear" / "cosine" -- continuous-time VPSDEs with the standard DDPM /
        improved-DDPM hyperparameters (`continuous_beta_0`, `continuous_beta_1`).

    Example:
        >>> ns = NoiseScheduleVP('discrete', alphas_cumprod=alphas_cumprod)
        >>> ns = NoiseScheduleVP('linear', continuous_beta_0=0.1, continuous_beta_1=20.)
    """

    def __init__(
        self,
        schedule="discrete",
        betas=None,
        alphas_cumprod=None,
        continuous_beta_0=0.1,
        continuous_beta_1=20.0,
    ):
        if schedule not in ["discrete", "linear", "cosine"]:
            raise ValueError(
                "Unsupported noise schedule {}. The schedule needs to be 'discrete' or 'linear' or 'cosine'".format(
                    schedule
                )
            )
        # Bug fix: `alphas_cumprod` is optional (schedule="linear"/"cosine",
        # or "discrete" constructed from `betas`); previously
        # `(1 - alphas_cumprod)` raised a TypeError when it was None.
        # `sigmas`/`log_sigmas` exist only when `alphas_cumprod` is given;
        # they back `sigma_to_t` and `get_special_sigmas_with_timesteps`.
        self.alphas_cumprod = alphas_cumprod
        if alphas_cumprod is not None:
            self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
            self.log_sigmas = self.sigmas.log()
        self.schedule = schedule
        if schedule == "discrete":
            if betas is not None:
                # log(alpha_t) accumulated from the beta array.
                log_alphas = 0.5 * torch.log(1 - betas).cumsum(dim=0)
            else:
                assert alphas_cumprod is not None
                log_alphas = 0.5 * torch.log(alphas_cumprod)
            self.total_N = len(log_alphas)
            self.T = 1.0
            # Piecewise-linear interpolation grid for log(alpha_t).
            self.t_array = torch.linspace(0.0, 1.0, self.total_N + 1)[1:].reshape((1, -1))
            self.log_alpha_array = log_alphas.reshape((1, -1))
        else:
            self.total_N = 1000
            self.beta_0 = continuous_beta_0
            self.beta_1 = continuous_beta_1
            self.cosine_s = 0.008
            self.cosine_beta_max = 999.0
            self.cosine_t_max = (
                math.atan(self.cosine_beta_max * (1.0 + self.cosine_s) / math.pi)
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            self.cosine_log_alpha_0 = math.log(math.cos(self.cosine_s / (1.0 + self.cosine_s) * math.pi / 2.0))
            self.schedule = schedule
            if schedule == "cosine":
                # T = 1 has numerical issues for the cosine schedule; 0.9946 is
                # the empirically safe ending time used by DPM-Solver.
                self.T = 0.9946
            else:
                self.T = 1.0

    def marginal_log_mean_coeff(self, t):
        """Compute log(alpha_t) of a given continuous-time label t in [0, T]."""
        if self.schedule == "discrete":
            # NOTE: relies on the module-level `interpolate_fn` helper.
            return interpolate_fn(
                t.reshape((-1, 1)), self.t_array.to(t.device), self.log_alpha_array.to(t.device)
            ).reshape((-1))
        elif self.schedule == "linear":
            return -0.25 * t**2 * (self.beta_1 - self.beta_0) - 0.5 * t * self.beta_0
        elif self.schedule == "cosine":
            log_alpha_fn = lambda s: torch.log(torch.cos((s + self.cosine_s) / (1.0 + self.cosine_s) * math.pi / 2.0))
            log_alpha_t = log_alpha_fn(t) - self.cosine_log_alpha_0
            return log_alpha_t

    def sigma_to_t(self, sigma, quantize=None):
        """Map a sigma value (sigma = sqrt((1-a)/a)) to a fractional discrete
        timestep by log-linear interpolation over `log_sigmas`.

        NOTE: `quantize` is deliberately forced to None (always interpolate
        rather than snap to the nearest index), matching the original behavior.
        Requires the instance to have been built with `alphas_cumprod`.
        """
        quantize = None
        log_sigma = sigma.log()
        dists = log_sigma - self.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.log_sigmas[low_idx], self.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def get_special_sigmas_with_timesteps(self, timesteps):
        """Return sigmas for (possibly fractional) discrete `timesteps` by
        linearly interpolating `alphas_cumprod` between neighboring integer
        steps (indices clamped to 999, i.e. a 1000-step table is assumed).

        `timesteps` is a 1-D numpy array of floats.
        """
        # Bug fix: torch tensors must be indexed with integer arrays; the
        # floor/ceil results were float64 before, which raises an IndexError.
        low_idx = np.minimum(np.floor(timesteps), 999).astype(np.int64)
        high_idx = np.minimum(np.ceil(timesteps), 999).astype(np.int64)
        w = torch.from_numpy(timesteps - np.floor(timesteps))
        self.alphas_cumprod = self.alphas_cumprod.to('cpu')
        alphas = (1 - w) * self.alphas_cumprod[low_idx] + w * self.alphas_cumprod[high_idx]
        return ((1 - alphas) / alphas) ** 0.5

    def marginal_alpha(self, t):
        """Compute alpha_t of a given continuous-time label t in [0, T]."""
        return torch.exp(self.marginal_log_mean_coeff(t))

    def marginal_std(self, t):
        """Compute sigma_t of a given continuous-time label t in [0, T]."""
        return torch.sqrt(1.0 - torch.exp(2.0 * self.marginal_log_mean_coeff(t)))

    def marginal_lambda(self, t):
        """Compute lambda_t = log(alpha_t) - log(sigma_t) for t in [0, T]."""
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = 0.5 * torch.log(1.0 - torch.exp(2.0 * log_mean_coeff))
        return log_mean_coeff - log_std

    def inverse_lambda(self, lamb):
        """Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t."""
        if self.schedule == "linear":
            tmp = 2.0 * (self.beta_1 - self.beta_0) * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            Delta = self.beta_0**2 + tmp
            return tmp / (torch.sqrt(Delta) + self.beta_0) / (self.beta_1 - self.beta_0)
        elif self.schedule == "discrete":
            log_alpha = -0.5 * torch.logaddexp(torch.zeros((1,)).to(lamb.device), -2.0 * lamb)
            t = interpolate_fn(
                log_alpha.reshape((-1, 1)),
                torch.flip(self.log_alpha_array.to(lamb.device), [1]),
                torch.flip(self.t_array.to(lamb.device), [1]),
            )
            return t.reshape((-1,))
        else:
            log_alpha = -0.5 * torch.logaddexp(-2.0 * lamb, torch.zeros((1,)).to(lamb))
            t_fn = (
                lambda log_alpha_t: torch.arccos(torch.exp(log_alpha_t + self.cosine_log_alpha_0))
                * 2.0
                * (1.0 + self.cosine_s)
                / math.pi
                - self.cosine_s
            )
            t = t_fn(log_alpha)
            return t
|
| 226 |
+
|
| 227 |
+
|
| 228 |
+
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    classifier_fn=None,
    classifier_kwargs={},
):
    """Create a continuous-time noise-prediction function for DPM-Solver.

    The returned `model_fn(x, t_continuous)` always outputs predicted noise,
    regardless of the underlying model's parameterization.

    Args:
        model: the diffusion model; call convention
            model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score.
        noise_schedule: a NoiseScheduleVP-like object.
        model_type: "noise" | "x_start" | "v" | "score" -- the model's
            prediction parameterization ("v" as in progressive distillation;
            for "score", noise(x_t, t) = -sigma_t * score(x_t, t)).
        model_kwargs: extra keyword args forwarded to `model`.
        guidance_type: "uncond" | "classifier" | "classifier-free".
        condition: conditioning input for guided sampling.
        unconditional_condition: conditioning for the unconditional branch
            ("classifier-free" only).
        guidance_scale: guidance strength; 1.0 disables CFG.
        classifier_fn: classifier for "classifier" guidance;
            classifier_fn(x, t_input, cond, **classifier_kwargs) -> log-prob.
        classifier_kwargs: extra keyword args for `classifier_fn`.

    Returns:
        model_fn(x, t_continuous) -> predicted noise.
    """
    # Validate eagerly, before any closure is built or called.  Fix: "score"
    # was previously rejected by this assert even though `noise_pred_fn`
    # implements it and the docstring advertises it.
    assert model_type in ["noise", "x_start", "v", "score"]
    assert guidance_type in ["uncond", "classifier", "classifier-free"]

    def get_model_input_time(t_continuous):
        """Map continuous time in [1/N, 1] to the discrete model's input time
        in [0, 1000*(N-1)/N]; continuous-time models take t as-is."""
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * 1000.0
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        """Call `model` and convert its output to a noise prediction."""
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, None, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return (x - expand_dims(alpha_t, dims) * output) / expand_dims(sigma_t, dims)
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return expand_dims(alpha_t, dims) * output + expand_dims(sigma_t, dims) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            dims = x.dim()
            return -expand_dims(sigma_t, dims) * output

    def cond_grad_fn(x, t_input):
        """Gradient of the classifier log-probability, nabla_x log p_t(cond | x_t)."""
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """The noise prediction function used by DPM-Solver."""
        if t_continuous.reshape((-1,)).shape[0] == 1:
            t_continuous = t_continuous.expand((x.shape[0]))
        if guidance_type == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_type == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, dims=cond_grad.dim()) * cond_grad
        elif guidance_type == "classifier-free":
            if guidance_scale == 1.0 or unconditional_condition is None:
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                if isinstance(condition, torch.Tensor) and ( isinstance(unconditional_condition, torch.Tensor) or unconditional_condition is None ):
                    # Tensor conditions: batch [uncond, cond] and split after.
                    c_in = torch.cat([unconditional_condition, condition])
                else:
                    # NOTE(review): for non-tensor conditions the pair is passed as
                    # [condition, unconditional_condition] -- the reverse of the
                    # tensor branch.  The commented-out line below suggests this is
                    # intentional (the downstream closure appears to unpack
                    # (embeddings, added kwargs) rather than uncond/cond); confirm
                    # against the sampler before changing.
                    c_in = [condition, unconditional_condition]
                # c_in = torch.cat([unconditional_condition, condition])
                noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    return model_fn
|
| 404 |
+
|
| 405 |
+
|
| 406 |
+
def weighted_cumsumexp_trapezoid(a, x, b, cumsum=True):
    """Trapezoidal approximation of the running integral of b * e^a over x.

    Inputs `a`, `x` (and `b`, if given) share shape (N+1, ...); `b=None` is
    treated as an all-ones weight.  With `cumsum=True` the result has shape
    (N+1, ...) with y_0 = 0 and
    y_n = sum_{i=1}^{n} 0.5 * (x_i - x_{i-1}) * (b_i e^{a_i} + b_{i-1} e^{a_{i-1}});
    with `cumsum=False` only the total integral (shape (1, ...)) is returned.
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Log-sum-exp trick: factor out the peak exponent for numerical stability.
    a_peak = np.amax(a, axis=0, keepdims=True)
    if b is None:
        integrand = np.exp(a - a_peak)
    else:
        integrand = np.asarray(b) * np.exp(a - a_peak)

    # Per-interval trapezoid areas (still scaled by e^{-a_peak}).
    segments = 0.5 * (x[1:] - x[:-1]) * (integrand[1:] + integrand[:-1])
    if not cumsum:
        return np.sum(segments, axis=0) * np.exp(a_peak)
    running = np.cumsum(segments, axis=0) * np.exp(a_peak)
    # Prepend the zero initial value so the output aligns with the inputs.
    return np.concatenate([np.zeros_like(running[[0]]), running], axis=0)
|
| 431 |
+
|
| 432 |
+
|
| 433 |
+
def weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=True):
    """Torch twin of :func:`weighted_cumsumexp_trapezoid`: ∫ b * e^a dx.

    All tensors share shape (N+1, ...). Returns the running integral (with a
    leading zero row) when ``cumsum=True``, otherwise just the total
    (shape (1, ...)). Numerically stabilized by factoring out max(a).
    """
    assert x.shape[0] == a.shape[0] and x.ndim == a.ndim
    if b is not None:
        assert a.shape[0] == b.shape[0] and a.ndim == b.ndim

    # Peak of `a` along the grid axis, kept for re-scaling at the end.
    peak = torch.amax(a, dim=0, keepdim=True)

    weighted = torch.exp(a - peak)
    if b is not None:
        weighted = b * weighted

    # Trapezoid rule over each grid interval.
    segments = 0.5 * (x[1:] - x[:-1]) * (weighted[1:] + weighted[:-1])
    if not cumsum:
        return torch.sum(segments, dim=0) * torch.exp(peak)
    running = torch.cumsum(segments, dim=0) * torch.exp(peak)
    # Leading zero row so the result is aligned with the input grid.
    return torch.cat([torch.zeros_like(running[[0]]), running], dim=0)
|
| 451 |
+
|
| 452 |
+
|
| 453 |
+
def index_list(lst, index):
    """Return the elements of *lst* at the positions listed in *index*.

    Args:
        lst: Any indexable sequence.
        index: Iterable of integer positions (negative indices allowed,
            following normal Python indexing).
    Returns:
        A new list ``[lst[i] for i in index]``; *lst* is not modified.
    """
    # Comprehension replaces the original manual append loop (same behavior).
    return [lst[i] for i in index]
|
| 458 |
+
|
| 459 |
+
|
| 460 |
+
class DPM_Solver_v3:
    """Multistep predictor–corrector diffusion ODE sampler driven by
    precomputed empirical-model-statistics (EMS) coefficients l, s, b.

    The statistics are loaded from ``statistics_dir`` (``l.npz`` / ``sb.npz``)
    on a dense logSNR grid; from them the integrated quantities L, S, I, B, C
    are built once with the trapezoid helpers and then indexed per step.
    """

    def __init__(
        self,
        statistics_dir,
        noise_schedule,
        steps=10,
        t_start=None,
        t_end=None,
        skip_type="time_uniform",
        degenerated=False,
        device="cuda",
    ):
        # model is injected later by sample(); None until then.
        self.device = device
        self.model = None
        self.noise_schedule = noise_schedule
        self.steps = steps
        # Default sampling interval: [1/N, T] for discrete-time DPMs.
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"

        # Empirical statistics on the dense grid; `degenerated=True` reduces
        # the solver to the plain (l=1, s=b=0) parameterization.
        l = np.load(os.path.join(statistics_dir, "l.npz"))["l"]
        sb = np.load(os.path.join(statistics_dir, "sb.npz"))
        s, b = sb["s"], sb["b"]
        if degenerated:
            l = np.ones_like(l)
            s = np.zeros_like(s)
            b = np.zeros_like(b)
        self.statistics_steps = l.shape[0] - 1
        # Dense logSNR grid matching the statistics, shaped (N+1, 1, 1, 1)
        # so it broadcasts against image-like tensors.
        ts = noise_schedule.marginal_lambda(
            self.get_time_steps("logSNR", t_T, t_0, self.statistics_steps, "cpu")
        ).numpy()[:, None, None, None]
        # NOTE(review): `.cuda()` here (and below) ignores the `device` arg —
        # this class is CUDA-only as written; confirm before running on CPU.
        self.ts = torch.from_numpy(ts).cuda()
        self.lambda_T = self.ts[0].cpu().item()
        self.lambda_0 = self.ts[-1].cpu().item()
        z = np.zeros_like(l)
        o = np.ones_like(l)
        # Integrated coefficients on the dense grid:
        # L = ∫ l dλ, S = ∫ s dλ, I = ∫ e^{L+S} dλ, B = ∫ b e^{-S} dλ,
        # C = ∫ B e^{L+S} dλ.
        L = weighted_cumsumexp_trapezoid(z, ts, l)
        S = weighted_cumsumexp_trapezoid(z, ts, s)

        I = weighted_cumsumexp_trapezoid(L + S, ts, o)
        B = weighted_cumsumexp_trapezoid(-S, ts, b)
        C = weighted_cumsumexp_trapezoid(L + S, ts, B)
        self.l = torch.from_numpy(l).cuda()
        self.s = torch.from_numpy(s).cuda()
        self.b = torch.from_numpy(b).cuda()
        self.L = torch.from_numpy(L).cuda()
        self.S = torch.from_numpy(S).cuda()
        self.I = torch.from_numpy(I).cuda()
        self.B = torch.from_numpy(B).cuda()
        self.C = torch.from_numpy(C).cuda()

        # precompute timesteps
        if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic" or skip_type == "customed_time_karras":
            self.timesteps = self.get_time_steps(skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            # Snap each timestep to the nearest dense-grid index, then map the
            # snapped indexes back to timesteps so both stay consistent.
            self.indexes = self.convert_to_indexes(self.timesteps)
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        elif skip_type == "edm":
            self.indexes, self.timesteps = self.get_timesteps_edm(N=steps, device=device)
            self.timesteps = self.convert_to_timesteps(self.indexes, device)
        else:
            raise ValueError(f"Unsupported timestep strategy {skip_type}")

        print("Indexes", self.indexes)
        print("Time steps", self.timesteps)
        print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))

        # store high-order exponential coefficients (lazy)
        self.exp_coeffs = {}

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def convert_to_indexes(self, timesteps):
        """Map timesteps to their nearest indexes on the dense logSNR grid."""
        logSNR_steps = self.noise_schedule.marginal_lambda(timesteps)
        indexes = list(
            (self.statistics_steps * (logSNR_steps - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes

    def convert_to_timesteps(self, indexes, device):
        """Inverse of convert_to_indexes: dense-grid indexes -> timesteps."""
        logSNR_steps = (
            self.lambda_T + (self.lambda_0 - self.lambda_T) * torch.Tensor(indexes).to(device) / self.statistics_steps
        )
        return self.noise_schedule.inverse_lambda(logSNR_steps)

    def append_zero(self, x):
        # Append a terminal sigma=0 entry (Karras-style schedule convention).
        return torch.cat([x, x.new_zeros([1])])

    def get_sigmas_karras(self, n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
        """Constructs the noise schedule of Karras et al. (2022)."""
        ramp = torch.linspace(0, 1, n)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return self.append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)

    def sigma_to_t(self, sigma, quantize=None):
        """Interpolated inverse of the discrete sigma schedule (k-diffusion style).

        NOTE(review): `quantize` is immediately overwritten with False, so the
        argmin branch is dead code as written.
        """
        quantize = False
        log_sigma = sigma.log()
        dists = log_sigma - self.noise_schedule.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        # Bracket log_sigma between two schedule entries and interpolate.
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.noise_schedule.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.noise_schedule.log_sigmas[low_idx], self.noise_schedule.log_sigmas[high_idx]
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)

    def get_time_steps(self, skip_type, t_T, t_0, N, device):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: A `int`. The total number of the spacing of the time steps.
            device: A torch device.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        elif skip_type == "customed_time_karras":
            # Hand-tuned hybrid schedules for specific step budgets:
            # build a Karras sigma ladder, truncate it, then re-space the
            # truncated range with a second (rho=1.2) power schedule.
            # NOTE(review): only N in {8, 5, 6} is handled; other N would
            # raise NameError on real_ct — confirm intended budgets.
            sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
            sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
            if N == 8:
                sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
                ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 5:
                sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            elif N == 6:
                sigmas = self.sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(),rho=1.2, device='cpu',need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            # Timesteps are normalized to [0, 1] (divided by 999 above).
            none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
            return none_k_ct#real_ct
        else:
            raise ValueError(
                "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
            )

    def get_timesteps_edm(self, N, device):
        """Constructs the noise schedule of Karras et al. (2022)."""

        rho = 7.0  # 7.0 is the value used in the paper

        # sigma = exp(-lambda): derive the EDM sigma range from the solver's
        # logSNR endpoints so the schedule matches the statistics grid.
        sigma_min: float = np.exp(-self.lambda_0)
        sigma_max: float = np.exp(-self.lambda_T)
        ramp = np.linspace(0, 1, N + 1)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        lambdas = torch.Tensor(-np.log(sigmas)).to(device)
        timesteps = self.noise_schedule.inverse_lambda(lambdas)

        # Snap each lambda to its nearest dense-grid index.
        indexes = list(
            (self.statistics_steps * (lambdas - self.lambda_T) / (self.lambda_0 - self.lambda_T))
            .round()
            .cpu()
            .numpy()
            .astype(np.int64)
        )
        return indexes, timesteps

    def get_g(self, f_t, i_s, i_t):
        # Change of variables from f to the solver's "g" function, anchored at
        # grid index i_s (see the DPM-Solver-v3 formulation).
        return torch.exp(self.S[i_s] - self.S[i_t]) * f_t - torch.exp(self.S[i_s]) * (self.B[i_t] - self.B[i_s])

    def compute_exponential_coefficients_high_order(self, i_s, i_t, order=2):
        """Lazily compute and cache ∫ e^{(L+S)-(L+S)_s} (t - t_s)^n / n! dt
        between grid indexes i_s and i_t (used by order >= 2 updates)."""
        key = (i_s, i_t, order)
        if key in self.exp_coeffs.keys():
            coeffs = self.exp_coeffs[key]
        else:
            n = order - 1
            a = self.L[i_s : i_t + 1] + self.S[i_s : i_t + 1] - self.L[i_s] - self.S[i_s]
            x = self.ts[i_s : i_t + 1]
            b = (self.ts[i_s : i_t + 1] - self.ts[i_s]) ** n / math.factorial(n)
            coeffs = weighted_cumsumexp_trapezoid_torch(a, x, b, cumsum=False)
            self.exp_coeffs[key] = coeffs
        return coeffs

    def compute_high_order_derivatives(self, n, lambda_0n, g_0n, pseudo=False):
        # return g^(1), ..., g^(n)
        if pseudo:
            # Divided-difference (Newton) estimates of the derivatives.
            D = [[] for _ in range(n + 1)]
            D[0] = g_0n
            for i in range(1, n + 1):
                for j in range(n - i + 1):
                    D[i].append((D[i - 1][j] - D[i - 1][j + 1]) / (lambda_0n[j] - lambda_0n[i + j]))

            return [D[i][0] * math.factorial(i) for i in range(1, n + 1)]
        else:
            # Solve the Vandermonde-like linear system for exact Taylor
            # coefficients through the n+1 cached (lambda, g) pairs.
            R = []
            for i in range(1, n + 1):
                R.append(torch.pow(lambda_0n[1:] - lambda_0n[0], i))
            R = torch.stack(R).t()
            B = (torch.stack(g_0n[1:]) - g_0n[0]).reshape(n, -1)
            shape = g_0n[0].shape
            solution = torch.linalg.inv(R) @ B
            solution = solution.reshape([n] + list(shape))
            return [solution[i - 1] * math.factorial(i) for i in range(1, n + 1)]

    def multistep_predictor_update(self, x_lst, eps_lst, time_lst, index_lst, t, i_t, order=1, pseudo=False):
        """Advance x from the most recent cached state to time t (index i_t)
        using an order-`order` multistep update built from the cached history."""
        # x_lst: [..., x_s]
        # eps_lst: [..., eps_s]
        # time_lst: [..., time_s]
        ns = self.noise_schedule
        n = order - 1
        # Most recent n+1 cache entries, newest first.
        indexes = [-i - 1 for i in range(n + 1)]
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], ns.marginal_alpha(t)
        i_s = index_0n[0]
        x_s = x_0n[0]
        # Transform each cached eps prediction into the g-space value.
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        # First-order (exact linear part) of the update.
        x_t = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        if order > 1:
            # Taylor corrections from estimated derivatives of g.
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t = (
                    x_t
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t

    def multistep_corrector_update(self, x_lst, eps_lst, time_lst, index_lst, order=1, pseudo=False):
        """Recompute x at the latest time using the just-obtained model output
        there (implicit corrector); same formula as the predictor but anchored
        at the second-to-last cache entry."""
        # x_lst: [..., x_s, x_t]
        # eps_lst: [..., eps_s, eps_t]
        # lambda_lst: [..., lambda_s, lambda_t]
        ns = self.noise_schedule
        n = order - 1
        indexes = [-i - 1 for i in range(n + 1)]
        # Swap so position 0 is s (second-to-last) and position 1 is t (last).
        indexes[0] = -2
        indexes[1] = -1
        x_0n = index_list(x_lst, indexes)
        eps_0n = index_list(eps_lst, indexes)
        time_0n = torch.FloatTensor(index_list(time_lst, indexes)).cuda()
        index_0n = index_list(index_lst, indexes)
        lambda_0n = ns.marginal_lambda(time_0n)
        alpha_0n = ns.marginal_alpha(time_0n)
        sigma_0n = ns.marginal_std(time_0n)

        alpha_s, alpha_t = alpha_0n[0], alpha_0n[1]
        i_s, i_t = index_0n[0], index_0n[1]
        x_s = x_0n[0]
        g_0n = []
        for i in range(n + 1):
            f_i = (sigma_0n[i] * eps_0n[i] - self.l[index_0n[i]] * x_0n[i]) / alpha_0n[i]
            g_i = self.get_g(f_i, index_0n[0], index_0n[i])
            g_0n.append(g_i)
        g_0 = g_0n[0]
        x_t_new = (
            alpha_t / alpha_s * torch.exp(self.L[i_s] - self.L[i_t]) * x_s
            - alpha_t * torch.exp(-self.L[i_t] - self.S[i_s]) * (self.I[i_t] - self.I[i_s]) * g_0
            - alpha_t
            * torch.exp(-self.L[i_t])
            * (self.C[i_t] - self.C[i_s] - self.B[i_s] * (self.I[i_t] - self.I[i_s]))
        )
        if order > 1:
            g_d = self.compute_high_order_derivatives(n, lambda_0n, g_0n, pseudo=pseudo)
            for i in range(order - 1):
                x_t_new = (
                    x_t_new
                    - alpha_t
                    * torch.exp(self.L[i_s] - self.L[i_t])
                    * self.compute_exponential_coefficients_high_order(i_s, i_t, order=i + 2)
                    * g_d[i]
                )
        return x_t_new

    def sample(
        self,
        x,
        model_fn,
        order,
        p_pseudo,
        use_corrector,
        c_pseudo,
        lower_order_final,
        start_free_u_step=None,
        free_u_apply_callback=None,
        free_u_stop_callback=None,
        half=False,
        return_intermediate=False,
    ):
        """Run the full predictor(-corrector) sampling loop from noise x.

        Args:
            x: Initial latent at time t_T.
            model_fn: Noise-prediction model, called as model_fn(x, t_batch).
            order: Maximum multistep order for the predictor.
            p_pseudo / c_pseudo: Use pseudo (divided-difference) derivatives
                for predictor / corrector respectively.
            use_corrector: Apply the implicit corrector after each model call.
            lower_order_final: Reduce order near the end (stability).
            start_free_u_step / free_u_*_callback: Hooks to toggle FreeU
                partway through sampling.
            half: If True, skip the corrector once t drops below 0.5.
            return_intermediate: Also return the list of cached latents.
        """
        # Broadcast the scalar timestep to the batch dimension for the model.
        self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
        steps = self.steps
        cached_x = []
        cached_model_output = []
        cached_time = []
        cached_index = []
        indexes, timesteps = self.indexes, self.timesteps
        step_p_order = 0
        if free_u_stop_callback is not None:
            free_u_stop_callback()
        for step in range(1, steps + 1):
            if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None:
                free_u_apply_callback()
            cached_x.append(x)
            cached_model_output.append(self.noise_prediction_fn(x, timesteps[step - 1]))
            cached_index.append(indexes[step - 1]) if False else None  # NOTE(review): see ordering below
            cached_time.append(timesteps[step - 1])
            cached_index.append(indexes[step - 1])
            if use_corrector and (timesteps[step - 1] > 0.5 or not half):
                step_c_order = step_p_order + c_pseudo
                if step_c_order > 1:
                    x_new = self.multistep_corrector_update(
                        cached_x, cached_model_output, cached_time, cached_index, order=step_c_order, pseudo=c_pseudo
                    )
                    # Keep the l-weighted residual N = sigma*eps - l*x fixed
                    # while replacing x with the corrected value, then back
                    # out the consistent eps for the cache.
                    sigma_t = self.noise_schedule.marginal_std(cached_time[-1])
                    l_t = self.l[cached_index[-1]]
                    N_old = sigma_t * cached_model_output[-1] - l_t * cached_x[-1]
                    cached_x[-1] = x_new
                    cached_model_output[-1] = (N_old + l_t * cached_x[-1]) / sigma_t
            # Warm up the multistep order, then optionally taper it at the end.
            if step < order:
                step_p_order = step
            else:
                step_p_order = order
            if lower_order_final:
                step_p_order = min(step_p_order, steps + 1 - step)
            t = timesteps[step]
            i_t = indexes[step]

            x = self.multistep_predictor_update(
                cached_x, cached_model_output, cached_time, cached_index, t, i_t, order=step_p_order, pseudo=p_pseudo
            )

        if return_intermediate:
            return x, cached_x
        else:
            return x
|
| 841 |
+
|
| 842 |
+
|
| 843 |
+
#############################################################
|
| 844 |
+
# other utility functions
|
| 845 |
+
#############################################################
|
| 846 |
+
|
| 847 |
+
|
| 848 |
+
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    # Insert each query x among the keypoints and sort; the position of x in
    # the sorted row tells us which keypoint interval it falls into.
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    # x was at position 0 before sorting, so argmin of the permutation gives
    # the sorted position of the query itself.
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    # Clamp out-of-range queries to the outermost interval (positions in the
    # sorted row that includes the query, hence the 1 / K-2 endpoints).
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    # Same bracketing expressed in original (unsorted) keypoint indices, used
    # to pick the matching y values.
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    # Linear interpolation within the bracketed interval.
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand
|
| 892 |
+
|
| 893 |
+
|
| 894 |
+
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dim`: a `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    # Append dims-1 singleton axes after the existing ones via advanced indexing.
    trailing_axes = (None,) * (dims - 1)
    return v[(Ellipsis,) + trailing_axes]
|
free_lunch_utils.py
ADDED
|
@@ -0,0 +1,303 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.fft as fft
|
| 3 |
+
from diffusers.utils import is_torch_version
|
| 4 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def isinstance_str(x: object, cls_name: str):
    """
    Checks whether x has any class *named* cls_name in its ancestry.
    Doesn't require access to the class's implementation.

    Useful for patching!
    """
    # Walk the MRO and compare class names only (no import of the class needed).
    return any(ancestor.__name__ == cls_name for ancestor in x.__class__.__mro__)
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def Fourier_filter(x, threshold, scale):
    """Rescale the low-frequency band of x in the 2-D Fourier domain (FreeU).

    Args:
        x: Tensor of shape (B, C, H, W); any float dtype (computed in float32).
        threshold: Half-width (in frequency bins) of the centered low-frequency
            square to rescale.
        scale: Multiplier applied to that low-frequency square.
    Returns:
        Filtered tensor with the same shape, dtype and device as `x`.
    """
    dtype = x.dtype
    # Work in float32: torch.fft support/precision for half precision is limited.
    x = x.type(torch.float32)
    # FFT over the spatial dims, shifted so DC sits at the center.
    x_freq = fft.fftn(x, dim=(-2, -1))
    x_freq = fft.fftshift(x_freq, dim=(-2, -1))

    B, C, H, W = x_freq.shape
    # Fix: build the mask on x's device. The original hard-coded .cuda(),
    # which crashed on CPU-only runs and could hit the wrong GPU.
    mask = torch.ones((B, C, H, W), device=x.device)

    crow, ccol = H // 2, W // 2
    # Centered square of side 2*threshold gets the scale factor.
    mask[..., crow - threshold:crow + threshold, ccol - threshold:ccol + threshold] = scale
    x_freq = x_freq * mask

    # IFFT back to the spatial domain; imaginary part is numerical noise.
    x_freq = fft.ifftshift(x_freq, dim=(-2, -1))
    x_filtered = fft.ifftn(x_freq, dim=(-2, -1)).real

    x_filtered = x_filtered.type(dtype)
    return x_filtered
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def register_upblock2d(model):
    """Monkey-patch every UpBlock2D in model.unet with a plain (non-FreeU)
    forward; used to restore the stock behavior after register_free_upblock2d.

    Args:
        model: A pipeline-like object exposing `model.unet.up_blocks`
            (diffusers UNet2DConditionModel layout).
    """
    def up_forward(self):
        # Closure capturing the block instance; mirrors diffusers'
        # UpBlock2D.forward implementation.
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in upblock2d, hidden states shape: {hidden_states.shape}")
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    # use_reentrant=False requires torch >= 1.11.
                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    # Rebind forward on every plain UpBlock2D (matched by class name so no
    # direct diffusers import of the class is needed).
    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def register_free_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    """Monkey-patch every UpBlock2D in model.unet with a FreeU-enabled forward.

    FreeU amplifies backbone features (b1/b2) and damps skip-connection
    low frequencies (s1/s2) in the first two decoder stages.

    Args:
        model: Pipeline-like object exposing `model.unet.up_blocks`.
        b1, b2: Backbone feature multipliers for stage 1 / stage 2.
        s1, s2: Fourier low-frequency scales for the matching skip features.
    """
    def up_forward(self):
        def forward(hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
            for resnet in self.resnets:
                # pop res hidden states
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                #print(f"in free upblock2d, hidden states shape: {hidden_states.shape}")

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                # NOTE(review): stages are detected by channel count (1280/640),
                # which assumes the SD-1.x/2.x UNet channel layout. The scaling
                # mutates hidden_states in place on the first half of channels.
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module):
                        def custom_forward(*inputs):
                            return module(*inputs)

                        return custom_forward

                    if is_torch_version(">=", "1.11.0"):
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
                        )
                    else:
                        hidden_states = torch.utils.checkpoint.checkpoint(
                            create_custom_forward(resnet), hidden_states, temb
                        )
                else:
                    hidden_states = resnet(hidden_states, temb)

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    # Attach the FreeU hyperparameters directly on each patched block so the
    # closure can read them via `self`.
    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "UpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def register_crossattn_upblock2d(model):
    """Monkey-patch every CrossAttnUpBlock2D in model.unet with a plain
    (non-FreeU) forward; restores stock behavior after the FreeU variant.

    Args:
        model: Pipeline-like object exposing `model.unet.up_blocks`
            (diffusers UNet2DConditionModel layout).
    """
    def up_forward(self):
        # Mirrors diffusers' CrossAttnUpBlock2D.forward signature.
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]
                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    # use_reentrant=False requires torch >= 1.11.
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    # Positional args follow Transformer2DModel.forward's order.
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                        attention_mask=attention_mask,
                        encoder_attention_mask=encoder_attention_mask,
                        return_dict=False,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    # Rebind forward on every CrossAttnUpBlock2D (matched by class name).
    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
def register_free_crossattn_upblock2d(model, b1=1.2, b2=1.4, s1=0.9, s2=0.2):
    """Monkey-patch FreeU onto every CrossAttnUpBlock2D of ``model.unet``.

    FreeU rescales part of the backbone features by ``b1``/``b2`` and
    low-pass filters the skip connections by ``s1``/``s2`` (via
    ``Fourier_filter``). Passing all-1.0 values restores stock behavior.
    """
    def up_forward(self):
        def forward(
            hidden_states: torch.FloatTensor,
            res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
            temb: Optional[torch.FloatTensor] = None,
            encoder_hidden_states: Optional[torch.FloatTensor] = None,
            cross_attention_kwargs: Optional[Dict[str, Any]] = None,
            upsample_size: Optional[int] = None,
            attention_mask: Optional[torch.FloatTensor] = None,
            encoder_attention_mask: Optional[torch.FloatTensor] = None,
        ):
            for resnet, attn in zip(self.resnets, self.attentions):
                # pop res hidden states
                #print(f"in free crossatten upblock2d, hidden states shape: {hidden_states.shape}")
                res_hidden_states = res_hidden_states_tuple[-1]
                res_hidden_states_tuple = res_hidden_states_tuple[:-1]

                # --------------- FreeU code -----------------------
                # Only operate on the first two stages
                # NOTE: scales the first half of the channels in place and
                # filters the matching skip tensor; the channel-count check
                # selects the stage.
                if hidden_states.shape[1] == 1280:
                    hidden_states[:,:640] = hidden_states[:,:640] * self.b1
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s1)
                if hidden_states.shape[1] == 640:
                    hidden_states[:,:320] = hidden_states[:,:320] * self.b2
                    res_hidden_states = Fourier_filter(res_hidden_states, threshold=1, scale=self.s2)
                # ---------------------------------------------------------

                hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

                if self.training and self.gradient_checkpointing:

                    def create_custom_forward(module, return_dict=None):
                        def custom_forward(*inputs):
                            if return_dict is not None:
                                return module(*inputs, return_dict=return_dict)
                            else:
                                return module(*inputs)

                        return custom_forward

                    # Non-reentrant checkpointing is only available on torch >= 1.11.
                    ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(resnet),
                        hidden_states,
                        temb,
                        **ckpt_kwargs,
                    )
                    hidden_states = torch.utils.checkpoint.checkpoint(
                        create_custom_forward(attn, return_dict=False),
                        hidden_states,
                        encoder_hidden_states,
                        None,  # timestep
                        None,  # class_labels
                        cross_attention_kwargs,
                        attention_mask,
                        encoder_attention_mask,
                        **ckpt_kwargs,
                    )[0]
                else:
                    hidden_states = resnet(hidden_states, temb)
                    # hidden_states = attn(
                    #     hidden_states,
                    #     encoder_hidden_states=encoder_hidden_states,
                    #     cross_attention_kwargs=cross_attention_kwargs,
                    #     encoder_attention_mask=encoder_attention_mask,
                    #     return_dict=False,
                    # )[0]
                    # NOTE(review): unlike the checkpointed branch above, this
                    # call drops attention_mask / encoder_attention_mask —
                    # confirm that masks are never used on this path.
                    hidden_states = attn(
                        hidden_states,
                        encoder_hidden_states=encoder_hidden_states,
                        cross_attention_kwargs=cross_attention_kwargs,
                    )[0]

            if self.upsamplers is not None:
                for upsampler in self.upsamplers:
                    hidden_states = upsampler(hidden_states, upsample_size)

            return hidden_states

        return forward

    # Patch every cross-attention up-block and stash the FreeU scales on it.
    for i, upsample_block in enumerate(model.unet.up_blocks):
        if isinstance_str(upsample_block, "CrossAttnUpBlock2D"):
            upsample_block.forward = up_forward(upsample_block)
            setattr(upsample_block, 'b1', b1)
            setattr(upsample_block, 'b2', b2)
            setattr(upsample_block, 's1', s1)
            setattr(upsample_block, 's2', s2)
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tqdm
|
| 2 |
+
einops
|
| 3 |
+
pytorch_lightning
|
| 4 |
+
accelerate
|
| 5 |
+
torchsde
|
| 6 |
+
pycocotools
|
| 7 |
+
diffusers
|
| 8 |
+
timm
|
| 9 |
+
transformers
|
| 10 |
+
opencv-python
|
| 11 |
+
omegaconf
|
sampler.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""SAMPLING ONLY."""
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from dpm_solver_v3 import NoiseScheduleVP, model_wrapper, DPM_Solver_v3
|
| 6 |
+
from uni_pc import UniPC
|
| 7 |
+
from free_lunch_utils import register_free_upblock2d, register_free_crossattn_upblock2d
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DPMSolverv3Sampler:
    """Sample from a diffusers pipeline with the DPM-Solver-v3 ODE solver.

    The solver is parameterized by precomputed statistics (EMS) loaded from
    ``stats_dir``; the noise schedule is rebuilt from the pipeline scheduler's
    ``alphas_cumprod``.
    """

    def __init__(self, stats_dir, pipe, steps, guidance_scale, **kwargs):
        """Build the noise schedule and the DPM-Solver-v3 instance.

        Args:
            stats_dir: Directory holding the solver statistics (required).
            pipe: A diffusers pipeline exposing ``.unet`` and ``.scheduler``.
            steps: Number of sampling steps.
            guidance_scale: Classifier-free guidance scale.
        """
        super().__init__()
        self.model = pipe
        DTYPE = torch.float32  # torch.float16 works as well, but pictures seem to be a bit worse
        # NOTE(review): device is hard-coded to "cuda" here (pipe.device is
        # ignored) — confirm this is intentional.
        device = "cuda"
        noise_scheduler = pipe.scheduler
        alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.alphas_cumprod = alpha_schedule
        self.device = device
        self.guidance_scale = guidance_scale

        # Discrete-time VP noise schedule derived from the pipeline scheduler.
        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        assert stats_dir is not None, f"No statistics file found in {stats_dir}."
        print("Use statistics", stats_dir)
        self.dpm_solver_v3 = DPM_Solver_v3(
            statistics_dir=stats_dir,
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type="customed_time_karras",
            degenerated=False,
            device=self.device,
        )
        self.steps = steps

    @torch.no_grad()
    def apply_free_unet(self):
        """Enable FreeU feature/skip rescaling on the UNet up-blocks."""
        register_free_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.model, b1=1.1, b2=1.1, s1=0.9, s2=0.2)

    @torch.no_grad()
    def stop_free_unet(self):
        """Disable FreeU (all-unit scales restore stock UNet behavior)."""
        register_free_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.model, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Run the solver and return ``(latents, None)``.

        Args:
            batch_size: Number of samples to draw.
            shape: Latent shape as ``(C, H, W)``.
            conditioning: Text/class conditioning (None for unconditional).
            x_T: Optional initial noise; sampled fresh when None.
            unconditional_conditioning: Negative conditioning for CFG.
            use_corrector: Whether the solver applies its corrector step.
            half: Passed through to the solver.
            start_free_u_step: Step index at which FreeU is switched on
                (FreeU callbacks are only wired when this is not None).
        """
        if conditioning is not None:
            # Sanity-check the conditioning batch size; mismatch is only warned.
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")
            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")

        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)

        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model.unet(x, t, encoder_hidden_states=c).sample,
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning,
                unconditional_condition=unconditional_conditioning,
                guidance_scale=self.guidance_scale,
            )
            # Lower solver order for guided sampling; order 2 only at 8 steps.
            if self.steps == 8:
                ORDER = 2
            else:
                ORDER = 1

        x = self.dpm_solver_v3.sample(
            img,
            model_fn,
            order=ORDER,
            p_pseudo=False,
            c_pseudo=True,
            lower_order_final=True,
            use_corrector=use_corrector,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )

        return x.to(self.device), None
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class UniPCSampler:
    """Sample from a diffusers pipeline with the UniPC predictor-corrector.

    The UNet call itself is produced by ``model_closure(pipe)``, so this class
    only wires the noise schedule, the solver, and optional FreeU / NPNet
    extras around it.
    """

    def __init__(self
        , pipe
        , model_closure
        , steps
        , guidance_scale,denoise_to_zero=False
        , need_fp16_discrete_method = False
        , ultilize_vae_in_fp16 = False
        , is_high_resoulution = True
        , skip_type="customed_time_karras"
        , force_not_use_afs=False
        , **kwargs):
        """Build the noise schedule and UniPC solver.

        Args:
            pipe: diffusers pipeline exposing ``.unet``, ``.scheduler``, ``.device``.
            model_closure: Callable mapping ``pipe`` to a ``(x, t, c)`` noise model.
            steps: Number of sampling steps.
            guidance_scale: Classifier-free guidance scale.
            denoise_to_zero: Passed through to the solver.
            need_fp16_discrete_method: Enables the fp16-specific discretization.
            ultilize_vae_in_fp16: Passed through to the solver (sic: "utilize").
            is_high_resoulution: Whether the target resolution is high (sic).
            skip_type: Timestep spacing strategy for the solver.
            force_not_use_afs: Disables AFS even for few-step high-res runs.
        """
        super().__init__()
        self.model = model_closure(pipe)
        self.pipe = pipe
        self.need_fp16_discrete_method = need_fp16_discrete_method
        DTYPE = self.pipe.unet.dtype  # torch.float16 works as well, but pictures seem to be a bit worse
        device = self.pipe.device
        noise_scheduler = pipe.scheduler
        alpha_schedule = noise_scheduler.alphas_cumprod.to(device=device, dtype=DTYPE)
        self.alphas_cumprod = alpha_schedule
        self.device = device
        self.guidance_scale = guidance_scale
        # AFS (saving the first model evaluation) only pays off for few-step,
        # high-resolution runs.
        self.use_afs = steps <= 8 and is_high_resoulution and not force_not_use_afs

        self.ns = NoiseScheduleVP("discrete", alphas_cumprod=self.alphas_cumprod)

        self.unipc_solver = UniPC(
            noise_schedule=self.ns,
            steps=steps,
            t_start=None,
            t_end=None,
            skip_type=skip_type,
            degenerated=False,
            use_afs=self.use_afs,
            device=self.device,
            denoise_to_zero=denoise_to_zero,
            need_fp16_discrete_method = self.need_fp16_discrete_method,
            ultilize_vae_in_fp16 = ultilize_vae_in_fp16,
            is_high_resoulution=is_high_resoulution,
        )
        self.steps = steps

    @torch.no_grad()
    def apply_free_unet(self):
        """Enable FreeU feature/skip rescaling on the UNet up-blocks."""
        register_free_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)
        register_free_crossattn_upblock2d(self.pipe, b1=1.2, b2=1.2, s1=0.9, s2=0.2)

    @torch.no_grad()
    def stop_free_unet(self):
        """Disable FreeU (all-unit scales restore stock UNet behavior)."""
        register_free_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)
        register_free_crossattn_upblock2d(self.pipe, b1=1.0, b2=1.0, s1=1.0, s2=1.0)

    @torch.no_grad()
    def sample(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Run UniPC and return ``(latents, solver_cache)``.

        When ``xl_preprocess_closure`` is given, prompts are first encoded into
        SDXL-style embeddings; when ``npnet`` is also given, its refined noise
        is handed to the solver as ``npnet_x`` (the raw ``img`` still seeds the
        trajectory).
        """
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        new_img = None
        if xl_preprocess_closure is not None:
            prompt_embeds, cond_kwargs = xl_preprocess_closure(pipe=self.pipe,prompts = conditioning, need_cfg=True, device=self.device,negative_prompts=unconditional_conditioning)
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T
        if xl_preprocess_closure is not None and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add dummy dimension for npnet
            new_img = npnet(img, c)

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
                unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
                guidance_scale=self.guidance_scale,
            )
            if self.steps >= 7:
                ORDER = 2
            else:
                ORDER = 1

        x, full_cache = self.unipc_solver.sample(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            # `new_img if new_img is not None else None` is just `new_img`.
            npnet_x=new_img,
            npnet_scale=self.guidance_scale if new_img is not None else None,
            half=half,
        )

        return x.to(self.device), full_cache

    @torch.no_grad()
    def sample_mix(
        self,
        batch_size,
        shape,
        conditioning=None,
        x_T=None,
        unconditional_conditioning=None,
        use_corrector=False,
        half=False,
        start_free_u_step=None,
        xl_preprocess_closure=None,
        npnet=None,
        # this has to come in the same format as the conditioning, # e.g. as encoded tokens, ...
        **kwargs,
    ):
        """Mixed-precision variant of :meth:`sample`.

        Differences from ``sample``: the NPNet output (when available)
        *replaces* the initial noise, the order threshold is 8 steps and is
        additionally gated on ``need_fp16_discrete_method``, and the solver's
        ``sample_mix`` entry point is used.
        """
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        if xl_preprocess_closure is not None:
            prompt_embeds, cond_kwargs = xl_preprocess_closure(pipe=self.pipe,prompts = conditioning, need_cfg=True, device=self.device,negative_prompts=unconditional_conditioning)
        if x_T is None:
            img = torch.randn(size, device=self.device)
        else:
            img = x_T
        if xl_preprocess_closure is not None and npnet is not None:
            c, _ = prompt_embeds
            c = c.unsqueeze(0)  # add dummy dimension for npnet
            img = npnet(img, c)

        if conditioning is None:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="uncond",
            )
            ORDER = 3
        else:
            model_fn = model_wrapper(
                lambda x, t, c: self.model(x, t, c),
                self.ns,
                model_type="noise",
                guidance_type="classifier-free",
                condition=conditioning if xl_preprocess_closure is None else prompt_embeds,
                unconditional_condition=unconditional_conditioning if xl_preprocess_closure is None else cond_kwargs,
                guidance_scale=self.guidance_scale,
            )
            if self.steps >= 8 and not self.need_fp16_discrete_method:
                ORDER = 2
            else:
                ORDER = 1

        x, full_cache = self.unipc_solver.sample_mix(
            x=img,
            model_fn=model_fn,
            order=ORDER,
            use_corrector=use_corrector,
            lower_order_final=True,
            start_free_u_step=start_free_u_step,
            free_u_apply_callback=self.apply_free_unet if start_free_u_step is not None else None,
            free_u_stop_callback=self.stop_free_unet if start_free_u_step is not None else None,
            half=half,
        )

        return x.to(self.device), full_cache
|
uni_pc.py
ADDED
|
@@ -0,0 +1,757 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dpm_solver_v3 import NoiseScheduleVP, model_wrapper
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
|
| 8 |
+
class UniPC:
|
| 9 |
+
    def __init__(
        self,
        noise_schedule,
        steps=10,
        t_start=None,
        t_end=None,
        skip_type="customed_time_karras",
        degenerated=False,
        use_afs = False,
        denoise_to_zero=False,
        need_fp16_discrete_method = False,
        ultilize_vae_in_fp16 = False,
        is_high_resoulution = True,
        device="cuda",
    ):
        """Configure the UniPC solver and precompute its time discretization.

        Args:
            noise_schedule: A NoiseScheduleVP-like object (provides ``total_N``,
                ``T``, ``sigmas``, ``marginal_lambda``, ``inverse_lambda``).
            steps: Number of solver steps (one extra is added when AFS is on).
            t_start / t_end: Optional sampling time range overrides; defaults
                are ``T`` and ``1 / total_N``.
            skip_type: Timestep spacing strategy; only "logSNR",
                "time_uniform", "time_quadratic" and "customed_time_karras"
                are accepted.
            degenerated: Unused here; kept for interface compatibility.
            use_afs: Analytical-first-step mode (skips the first model call).
            denoise_to_zero / need_fp16_discrete_method / ultilize_vae_in_fp16 /
                is_high_resoulution: Flags forwarded into the timestep
                computation and later sampling (note the original spellings).
            device: Torch device for the precomputed timesteps.
        """
        self.device = device
        self.model = None  # the model function is attached later, at sampling time
        self.noise_schedule = noise_schedule
        # AFS saves one model evaluation, so one extra step keeps the budget.
        self.steps = steps if not use_afs else steps + 1
        self.use_afs = use_afs
        self.ultilize_vae_in_fp16 = ultilize_vae_in_fp16
        self.need_fp16_discrete_method = need_fp16_discrete_method
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        self.is_high_resolution = is_high_resoulution
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"


        # precompute timesteps
        if skip_type == "logSNR" or skip_type == "time_uniform" or skip_type == "time_quadratic" or skip_type == "customed_time_karras":
            self.timesteps = self.get_time_steps(skip_type
                , t_T=t_T
                , t_0=t_0
                , N=steps
                , device=device,denoise_to_zero=denoise_to_zero
                , is_high_resolution=self.is_high_resolution)
        else:
            raise ValueError(f"Unsupported timestep strategy {skip_type}")
        # NOTE(review): despite the names, these store the endpoint *timesteps*
        # (self.timesteps entries), not logSNR lambda values — confirm intended.
        self.lambda_T = self.timesteps[0].cpu().item()
        self.lambda_0 = self.timesteps[-1].cpu().item()

        # print("Time steps", self.timesteps)
        # print("LogSNR steps", self.noise_schedule.marginal_lambda(self.timesteps))

        # store high-order exponential coefficients (lazy)
        self.exp_coeffs = {}
|
| 57 |
+
|
| 58 |
+
    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        # Thin wrapper around the model function attached before sampling
        # (``self.model`` is None until then).
        return self.model(x, t)
|
| 63 |
+
|
| 64 |
+
def append_zero(self, x):
|
| 65 |
+
return torch.cat([x, x.new_zeros([1])])
|
| 66 |
+
|
| 67 |
+
def get_sigmas_karras(self, n, sigma_min, sigma_max, rho=7., device='cpu', need_append_zero=True):
|
| 68 |
+
"""Constructs the noise schedule of Karras et al. (2022)."""
|
| 69 |
+
ramp = torch.linspace(0, 1, n)
|
| 70 |
+
min_inv_rho = sigma_min ** (1 / rho)
|
| 71 |
+
max_inv_rho = sigma_max ** (1 / rho)
|
| 72 |
+
sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
|
| 73 |
+
return self.append_zero(sigmas).to(device) if need_append_zero else sigmas.to(device)
|
| 74 |
+
|
| 75 |
+
    def sigma_to_t(self, sigma, quantize=None):
        """Map continuous sigma values to fractional discrete timestep indices.

        Interpolates linearly (in log-sigma space) between the two nearest
        entries of ``self.noise_schedule.log_sigmas``.
        """
        # NOTE(review): the `quantize` parameter is overridden here, so the
        # nearest-index branch below is currently unreachable — confirm intended.
        quantize = False
        log_sigma = sigma.log()
        # Signed distance from log(sigma) to each entry of the log-sigma table.
        dists = log_sigma - self.noise_schedule.log_sigmas[:, None]
        if quantize:
            return dists.abs().argmin(dim=0).view(sigma.shape)
        # Index of the lower bracketing table entry, clamped so that
        # high_idx = low_idx + 1 stays inside the table.
        low_idx = dists.ge(0).cumsum(dim=0).argmax(dim=0).clamp(max=self.noise_schedule.log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1
        low, high = self.noise_schedule.log_sigmas[low_idx], self.noise_schedule.log_sigmas[high_idx]
        # Interpolation weight in log-sigma space, clamped to [0, 1].
        w = (low - log_sigma) / (low - high)
        w = w.clamp(0, 1)
        # Fractional index between the two bracketing discrete timesteps.
        t = (1 - w) * low_idx + w * high_idx
        return t.view(sigma.shape)
|
| 88 |
+
|
| 89 |
+
def get_time_steps(self, skip_type, t_T, t_0, N, device, denoise_to_zero=False, is_high_resolution=True):
    """Compute the intermediate time steps for sampling.

    Args:
        skip_type: A `str`. The type for the spacing of the time steps. We support:
            - 'logSNR': uniform logSNR for the time steps.
            - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
            - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
            - 'customed_time_karras': hand-tuned Karras-style schedules for
              specific step counts (N in {5, 6, 7, 8}), with separate constant
              tables for high- and low-resolution sampling.
        t_T: A `float`. The starting time of the sampling (default is T).
        t_0: A `float`. The ending time of the sampling (default is epsilon).
        N: A `int`. The total number of the spacing of the time steps.
        device: A torch device.
        denoise_to_zero: whether to append a final step at ``t_0`` (only used
            by some fp16 'customed_time_karras' branches).
        is_high_resolution: selects the high- vs low-resolution constant table.
    Returns:
        A pytorch tensor of the time steps, with the shape (N + 1,).
    """
    if skip_type == "logSNR":
        lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
        lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
        logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
        return self.noise_schedule.inverse_lambda(logSNR_steps)
    elif skip_type == "time_uniform":
        return torch.linspace(t_T, t_0, N + 1).to(device)
    elif skip_type == "time_quadratic":
        t_order = 2
        t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
        return t
    elif skip_type == "customed_time_karras" and is_high_resolution:
        # The schedule's sigma table is ascending, so sigmas[-1] is the largest.
        sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
        sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
        # Each (N, fp16-mode) pair below uses hand-tuned Karras parameters:
        # an outer sigma ramp, a slice of it, and an inner rho=1.2 re-spacing
        # of the corresponding timesteps. Constants were presumably tuned
        # empirically — do not alter without re-validation.
        if N == 8:
            sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=12.0, device=device)
            if not self.need_fp16_discrete_method:
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[10])
                ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                # Normalize discrete timesteps into [0, 1] (999 = last index).
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            else:
                sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(8, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                tmp_t = [self.noise_schedule.sigma_to_t(sigma).to('cpu') for sigma in sigmas_ct]
                real_ct = [t / 999 for t in tmp_t]
        elif N == 5:
            sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
            if not self.need_fp16_discrete_method:
                sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=12.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
                ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            else:
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(5, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
        elif N == 6:
            sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
            if not self.need_fp16_discrete_method:
                sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=12.0, device=device)
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[10])
                ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            else:
                if denoise_to_zero:
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                    ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                    # Extra final step at t_0 for denoise-to-zero.
                    real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype, device='cpu'))
                else:
                    sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[7])
                    ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
        elif N == 7:
            sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
            if not self.need_fp16_discrete_method:
                ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                ct = self.get_sigmas_karras(8, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
            else:
                if denoise_to_zero:
                    ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
                    ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
                    sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
                    real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
                    real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype, device='cpu'))
                # NOTE(review): with N == 7, fp16 mode and denoise_to_zero=False,
                # real_ct is never assigned and the code below would raise
                # UnboundLocalError — confirm this combination is never used.
        # if denoise_to_zero:
        #     real_ct.append(torch.tensor(t_0).to(dtype=real_ct[-1].dtype,device='cpu'))

        if self.use_afs:
            # AFS (analytic first step): insert a midpoint between the first
            # two timesteps so the first model call can be skipped.
            tmp_t = (real_ct[0] + real_ct[1]) / 2
            real_ct.insert(1, tmp_t)
        none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
        return none_k_ct  # real_ct
    elif skip_type == "customed_time_karras" and not is_high_resolution:
        sigma_T = self.noise_schedule.sigmas[-1].cpu().item()
        sigma_0 = self.noise_schedule.sigmas[0].cpu().item()
        if N == 8:
            sigmas = self.get_sigmas_karras(12, sigma_0, sigma_T, rho=7.0, device=device)
            ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[9])
            ct = self.get_sigmas_karras(9, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
            sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
            real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
        elif N == 5:
            sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
            ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
            ct = self.get_sigmas_karras(6, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
            sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
            real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
        elif N == 6:
            # NOTE(review): also stores the ramp on self.sigmas as a side
            # effect (chained assignment in the original) — confirm intended.
            sigmas = self.sigmas = self.get_sigmas_karras(8, sigma_0, sigma_T, rho=5.0, device=device)
            ct_start, ct_end = self.noise_schedule.sigma_to_t(sigmas[0]), self.sigma_to_t(sigmas[6])
            ct = self.get_sigmas_karras(7, ct_end.item(), ct_start.item(), rho=1.2, device='cpu', need_append_zero=False).numpy()
            sigmas_ct = self.noise_schedule.get_special_sigmas_with_timesteps(ct).to(device=device)
            real_ct = [self.noise_schedule.sigma_to_t(sigma).to('cpu') / 999 for sigma in sigmas_ct]
        none_k_ct = torch.from_numpy(np.array(real_ct)).to(device)
        return none_k_ct  # real_ct
    else:
        raise ValueError(
            "Unsupported skip_type {}, need to be 'logSNR' or 'time_uniform' or 'time_quadratic'".format(skip_type)
        )
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
def multistep_uni_pc_update(self, x, model_prev_list:list, t_prev_list: list, t, order, **kwargs):
|
| 218 |
+
if len(model_prev_list) == 0 or len(t_prev_list) == 0:
|
| 219 |
+
return None, None
|
| 220 |
+
if len(t.shape) == 0:
|
| 221 |
+
t = t.view(-1)
|
| 222 |
+
if True:#'bh' in self.variant:
|
| 223 |
+
return self.multistep_uni_pc_bh_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 224 |
+
else:
|
| 225 |
+
# assert self.variant == 'vary_coeff'
|
| 226 |
+
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 227 |
+
|
| 228 |
+
def multistep_uni_pc_sde_update(self, x, model_prev_list:list, t_prev_list: list, t, order, level = 1.0, **kwargs):
|
| 229 |
+
if len(model_prev_list) == 0 or len(t_prev_list) == 0:
|
| 230 |
+
return None, None
|
| 231 |
+
if len(t.shape) == 0:
|
| 232 |
+
t = t.view(-1)
|
| 233 |
+
if True:#'bh' in self.variant:
|
| 234 |
+
return self.multistep_uni_pc_bh_sde_update(x, model_prev_list, t_prev_list, t, level=level, order= order, **kwargs)
|
| 235 |
+
else:
|
| 236 |
+
# assert self.variant == 'vary_coeff'
|
| 237 |
+
return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)
|
| 238 |
+
|
| 239 |
+
def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
    """One multistep UniPC update of solver type B(h).

    Builds the Vandermonde-style system R @ rho = b from the logSNR ratios of
    the cached steps, applies the explicit predictor, then (optionally) one
    corrector pass using a fresh model evaluation at the predicted point.

    Args:
        x: current sample.
        model_prev_list: cached model outputs, newest last.
        t_prev_list: timesteps matching ``model_prev_list``.
        t: target timestep, shape (1,).
        order: solver order (<= len(model_prev_list)).
        x_t: optional precomputed predictor result; skips the predictor.
        use_corrector: whether to run the corrector (costs one model call).
    Returns:
        ``(x_t, model_t)`` where ``model_t`` is the model output at ``t``
        (``None`` when the corrector is skipped).
    """
    ns = self.noise_schedule
    assert order <= len(model_prev_list)
    dims = x.dim()

    # first compute rks (ratios of logSNR gaps relative to the current step h)
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    h = lambda_t - lambda_prev_0

    rks = []
    D1s = []  # scaled first differences of past model outputs
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
        rks.append(rk)
        D1s.append((model_prev_i - model_prev_0) / rk)

    rks.append(1.)
    rks = torch.tensor(rks, device=x.device)

    R = []
    b = []

    hh = h[0]
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    # B(h) = h variant is hard-wired; the expm1 alternative below is dead code
    # kept from the reference implementation.
    if True:
        B_h = hh
    else:
        B_h = torch.expm1(hh)

    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= (i + 1)
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=x.device)

    # now predictor
    use_predictor = len(D1s) > 0 and x_t is None
    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        if x_t is None:
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
    else:
        D1s = None

    if use_corrector:
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

    model_t = None

    # DPM-Solver-style base update (noise-prediction parameterization).
    x_t_ = (
        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
        - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
    )
    if x_t is None:
        if use_predictor:
            pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
        else:
            pred_res = 0
        x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * pred_res

    if use_corrector:
        model_t = self.noise_prediction_fn(x_t, t)
        if D1s is not None:
            corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = (model_t - model_prev_0)
        x_t = x_t_ - expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)

    return x_t, model_t
|
| 336 |
+
|
| 337 |
+
def multistep_uni_pc_bh_sde_update(self, x, model_prev_list, t_prev_list, t, order, level=0, x_t=None, use_corrector=True):
    """Stochastic (SDE) variant of the B(h) multistep UniPC update.

    Identical in structure to ``multistep_uni_pc_bh_update``, but injects
    Gaussian noise scaled by ``level`` and amplifies the drift terms by
    ``(1 + level)`` in the corrector; ``level=0`` reduces the corrector to the
    deterministic update.

    Args:
        level: stochastic strength; also scales the injected noise ``z``.
        (remaining args as in ``multistep_uni_pc_bh_update``.)
    Returns:
        ``(x_t, model_t)`` (``model_t`` is ``None`` if corrector skipped).
    """
    ns = self.noise_schedule
    assert order <= len(model_prev_list)
    dims = x.dim()

    # first compute rks
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    h = lambda_t - lambda_prev_0
    # SDE noise term: variance matches the exponential-integrator SDE step.
    z = torch.randn(x.shape, device=self.device)
    z = sigma_t * torch.sqrt(torch.expm1(2.0 * h[0])) * z

    rks = []
    D1s = []
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
        rks.append(rk)
        D1s.append((model_prev_i - model_prev_0) / rk)

    rks.append(1.)
    rks = torch.tensor(rks, device=x.device)

    R = []
    b = []

    hh = h[0]
    h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
    h_phi_k = h_phi_1 / hh - 1

    factorial_i = 1

    # B(h) = h variant is hard-wired (dead else kept from reference code).
    if True:
        B_h = hh
    else:
        B_h = torch.expm1(hh)

    for i in range(1, order + 1):
        R.append(torch.pow(rks, i - 1))
        b.append(h_phi_k * factorial_i / B_h)
        factorial_i *= (i + 1)
        h_phi_k = h_phi_k / hh - 1 / factorial_i

    R = torch.stack(R)
    b = torch.tensor(b, device=x.device)

    # now predictor
    use_predictor = len(D1s) > 0 and x_t is None
    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        if x_t is None:
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], device=b.device)
            else:
                rhos_p = torch.linalg.solve(R[:-1, :-1], b[:-1])
    else:
        D1s = None

    if use_corrector:
        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], device=b.device)
        else:
            rhos_c = torch.linalg.solve(R, b)

    model_t = None

    # Corrector base point: drift amplified by (1 + level).
    x_t_ = (
        expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
        - expand_dims(sigma_t * h_phi_1, dims) * (1 + level) * model_prev_0
    )
    if x_t is None:
        if use_predictor:
            pred_res = torch.einsum('k,bkchw->bchw', rhos_p, D1s)
        else:
            pred_res = 0

        # Predictor uses the deterministic (level-free) base point.
        x_t_p = (
            expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
            - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
        )
        x_t = x_t_p - expand_dims(sigma_t * B_h, dims) * pred_res

    if use_corrector:
        model_t = self.noise_prediction_fn(x_t, t)
        if D1s is not None:
            corr_res = torch.einsum('k,bkchw->bchw', rhos_c[:-1], D1s)
        else:
            corr_res = 0
        D1_t = (model_t - model_prev_0)
        # Amplified corrector plus level-scaled stochastic noise.
        x_t = x_t_ - (1 + level) * expand_dims(sigma_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t) + z * level

    return x_t, model_t
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
    """Multistep UniPC update of solver type 'vary_coeff'.

    NOTE(review): this path is unreachable from ``multistep_uni_pc_update``
    (which hard-selects the B(h) variant); kept for reference.
    Also note ``hh = h`` here keeps a (1,)-shaped tensor, unlike the B(h)
    variants which use the scalar ``h[0]`` — confirm intended.
    """
    ns = self.noise_schedule
    assert order <= len(model_prev_list)
    dims = x.dim()
    # first compute rks
    t_prev_0 = t_prev_list[-1]
    lambda_prev_0 = ns.marginal_lambda(t_prev_0)
    lambda_t = ns.marginal_lambda(t)
    model_prev_0 = model_prev_list[-1]
    sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
    log_alpha_t = ns.marginal_log_mean_coeff(t)
    alpha_t = torch.exp(log_alpha_t)

    h = lambda_t - lambda_prev_0

    rks = []
    D1s = []
    for i in range(1, order):
        t_prev_i = t_prev_list[-(i + 1)]
        model_prev_i = model_prev_list[-(i + 1)]
        lambda_prev_i = ns.marginal_lambda(t_prev_i)
        rk = ((lambda_prev_i - lambda_prev_0) / h)[0]
        rks.append(rk)
        D1s.append((model_prev_i - model_prev_0) / rk)

    rks.append(1.)
    rks = torch.tensor(rks, device=x.device)

    K = len(rks)
    # build C matrix (column k holds rks^k / k! scaled terms)
    C = []

    col = torch.ones_like(rks)
    for k in range(1, K + 1):
        C.append(col)
        col = col * rks / (k + 1)
    C = torch.stack(C, dim=1)

    if len(D1s) > 0:
        D1s = torch.stack(D1s, dim=1)  # (B, K)
        C_inv_p = torch.linalg.inv(C[:-1, :-1])
        A_p = C_inv_p

    if use_corrector:
        C_inv = torch.linalg.inv(C)
        A_c = C_inv

    hh = h
    h_phi_1 = torch.expm1(hh)
    h_phi_ks = []
    factorial_k = 1
    h_phi_k = h_phi_1
    for k in range(1, K + 2):
        h_phi_ks.append(h_phi_k)
        h_phi_k = h_phi_k / hh - 1 / factorial_k
        factorial_k *= (k + 1)

    model_t = None
    if True:
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        x_t_ = (
            expand_dims((torch.exp(log_alpha_t - log_alpha_prev_0)), dims) * x
            - expand_dims((sigma_t * h_phi_1), dims) * model_prev_0
        )
        # now predictor
        x_t = x_t_
        if len(D1s) > 0:
            # compute the residuals for predictor
            for k in range(K - 1):
                x_t = x_t - expand_dims(sigma_t * h_phi_ks[k + 1], dims) * torch.einsum('bkchw,k->bchw', D1s, A_p[k])
        # now corrector
        if use_corrector:
            model_t = self.noise_prediction_fn(x_t, t)
            D1_t = (model_t - model_prev_0)
            x_t = x_t_
            k = 0
            for k in range(K - 1):
                x_t = x_t - expand_dims(sigma_t * h_phi_ks[k + 1], dims) * torch.einsum('bkchw,k->bchw', D1s, A_c[k][:-1])
            x_t = x_t - expand_dims(sigma_t * h_phi_ks[K], dims) * (D1_t * A_c[k][-1])
    return x_t, model_t
|
| 525 |
+
|
| 526 |
+
def sample(
    self,
    x,
    model_fn,
    order,
    use_corrector,
    lower_order_final,
    start_free_u_step=None,
    free_u_apply_callback=None,
    free_u_stop_callback=None,
    npnet_x=None,
    npnet_scale=None,
    half=False,
    return_intermediate=False,
):
    """Deterministic multistep UniPC sampling loop.

    Args:
        x: initial noise latent.
        model_fn: noise-prediction model; wrapped to broadcast t over batch.
        order: UniPC order; first ``order - 1`` steps warm up at lower order.
        use_corrector: requested corrector setting (overridden per step below).
        lower_order_final: reduce order near the end (recommended few-step).
        start_free_u_step / free_u_apply_callback / free_u_stop_callback:
            FreeU toggling hooks; FreeU is enabled starting at the given step.
        npnet_x / npnet_scale: optional NPNet replacement for the first
            model output (golden-noise style initialization).
        half, return_intermediate: currently unused — kept for interface
            compatibility.
    Returns:
        ``(x, full_cache)`` — final sample and per-step intermediate list.
    """
    self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
    steps = self.steps
    vec_t = self.timesteps[0].expand((x.shape[0]))
    if free_u_stop_callback is not None:
        free_u_stop_callback()
    if start_free_u_step is not None and 0 == start_free_u_step and free_u_apply_callback is not None:
        free_u_apply_callback()
        has_called_free_u = True
    if not self.use_afs:
        fir_output = self.noise_prediction_fn(x, vec_t)
    else:
        # AFS: approximate the first model call analytically.
        fir_output = x * 0.97  # ultilize npnet there in the future
    if npnet_x is not None and npnet_scale is not None:
        fir_output = npnet_x
        # fir_output = fir_output - npnet_scale * (npnet_out - fir_output) #guidance_scale * (noise - noise_uncond)
    # NOTE(review): the latent x is replaced by the first model output here
    # (and again below with model_x) — confirm this is the intended
    # parameterization rather than keeping the solver state.
    x = fir_output.clone().detach().to(fir_output.device)

    model_prev_list = [fir_output]
    full_cache = [fir_output]
    t_prev_list = [vec_t]
    # NOTE(review): this reset clobbers the True set above when
    # start_free_u_step == 0; harmless today because later guards compare
    # against steps >= 1, but fragile — confirm.
    has_called_free_u = False
    # Warm-up: increase the solver order one step at a time.
    for init_order in range(1, order):
        if start_free_u_step is not None and init_order == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
            free_u_apply_callback()
            has_called_free_u = True
        vec_t = self.timesteps[init_order].expand(x.shape[0])
        x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
        if model_x is None:
            model_x = self.noise_prediction_fn(x, vec_t)
        x = model_x.clone().detach().to(torch.float32).to(model_x.device)
        full_cache.append(x)
        model_prev_list.append(model_x)
        t_prev_list.append(vec_t)

    # Main loop at full order (or decaying order near the end).
    for step in range(order, steps + 1):
        if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
            free_u_apply_callback()
        vec_t = self.timesteps[step].expand(x.shape[0])
        if lower_order_final:
            step_order = min(order, steps + 1 - step)
        else:
            step_order = order
        if step == steps:
            # do not run corrector at the last step
            use_corrector = False
        else:
            use_corrector = True
        x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
        # Shift the history window forward by one step.
        for i in range(order - 1):
            t_prev_list[i] = t_prev_list[i + 1]
            model_prev_list[i] = model_prev_list[i + 1]
        t_prev_list[-1] = vec_t
        # We do not need to evaluate the final model value.
        full_cache.append(x)
        if step < steps:
            if model_x is None:
                model_x = self.noise_prediction_fn(x, vec_t)
            model_prev_list[-1] = model_x
    return x, full_cache
|
| 602 |
+
def sample_mix(
    self,
    x,
    model_fn,
    order,
    use_corrector,
    lower_order_final,
    start_free_u_step=None,
    free_u_apply_callback=None,
    free_u_stop_callback=None,
    noise_level=0.1,
    half=False,
    return_intermediate=False,
):
    """Mixed ODE/SDE multistep UniPC sampling loop.

    Same structure as ``sample``, but every update goes through the SDE
    variant: steps at or after ``start_free_u_step`` use
    ``level=noise_level`` (stochastic); earlier steps use ``level=0.0``
    (deterministic). FreeU is applied via the same callbacks.

    Args:
        noise_level: stochastic strength for the late (FreeU) steps.
        (remaining args as in ``sample``.)
    Returns:
        ``(x, full_cache)`` — final sample and per-step intermediate list.
    """
    self.model = lambda x, t: model_fn(x, t.expand((x.shape[0])))
    steps = self.steps
    vec_t = self.timesteps[0].expand((x.shape[0]))
    fir_output = self.noise_prediction_fn(x, vec_t)
    model_prev_list = [fir_output]
    full_cache = [fir_output]
    t_prev_list = [vec_t]
    has_called_free_u = False
    if free_u_stop_callback is not None:
        free_u_stop_callback()
    # Warm-up: increase the solver order one step at a time.
    for init_order in range(1, order):
        if start_free_u_step is not None and init_order == start_free_u_step and free_u_apply_callback is not None:
            free_u_apply_callback()
            has_called_free_u = True
        vec_t = self.timesteps[init_order].expand(x.shape[0])
        if start_free_u_step is not None and init_order >= start_free_u_step and free_u_apply_callback is not None:
            # FreeU active: stochastic update.
            x, model_x = self.multistep_uni_pc_sde_update(
                x, model_prev_list, t_prev_list, vec_t, init_order,
                use_corrector=True, level=noise_level)
        else:
            # Deterministic update (level 0 disables the noise term).
            x, model_x = self.multistep_uni_pc_sde_update(
                x, model_prev_list, t_prev_list, vec_t, init_order,
                use_corrector=True, level=0.0)
        if model_x is None:
            model_x = self.noise_prediction_fn(x, vec_t)
        # NOTE(review): the latent x is replaced by the model output here —
        # confirm intended (same pattern as in ``sample``).
        x = model_x.clone().detach().to(torch.float32).to(model_x.device)
        full_cache.append(x)
        model_prev_list.append(model_x)
        t_prev_list.append(vec_t)

    if free_u_stop_callback is not None:
        free_u_stop_callback()
    # Main loop at full order (or decaying order near the end).
    for step in range(order, steps + 1):
        if start_free_u_step is not None and step == start_free_u_step and free_u_apply_callback is not None and (not has_called_free_u):
            free_u_apply_callback()
        vec_t = self.timesteps[step].expand(x.shape[0])
        if lower_order_final:
            step_order = min(order, steps + 1 - step)
        else:
            step_order = order
        if step == steps:
            # do not run corrector at the last step
            use_corrector = False
        else:
            use_corrector = True
        if start_free_u_step is not None and step >= start_free_u_step and free_u_apply_callback is not None:
            x, model_x = self.multistep_uni_pc_sde_update(
                x, model_prev_list, t_prev_list, vec_t, step_order,
                use_corrector=use_corrector, level=noise_level)
        else:
            x, model_x = self.multistep_uni_pc_sde_update(
                x, model_prev_list, t_prev_list, vec_t, step_order,
                use_corrector=use_corrector, level=0.0)
        # Shift the history window forward by one step.
        for i in range(order - 1):
            t_prev_list[i] = t_prev_list[i + 1]
            model_prev_list[i] = model_prev_list[i + 1]
        t_prev_list[-1] = vec_t
        # We do not need to evaluate the final model value.
        full_cache.append(x)
        if step < steps:
            if model_x is None:
                model_x = self.noise_prediction_fn(x, vec_t)
            model_prev_list[-1] = model_x
    return x, full_cache
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
|
| 701 |
+
#############################################################
|
| 702 |
+
# other utility functions
|
| 703 |
+
#############################################################
|
| 704 |
+
|
| 705 |
+
def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x-axis. (For x beyond the bounds of xp, we use the outmost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    # Sort x together with the keypoints; the rank of x among the keypoints
    # is recovered from where index 0 (x itself) lands after sorting.
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    # Clamp out-of-range queries onto the first/last segment so extrapolation
    # continues the outermost linear piece.
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    # Separate index into yp (unsorted keypoint space).
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K), torch.tensor(K - 2, device=x.device), cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    # Linear interpolation between the bracketing keypoints.
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand
|
| 745 |
+
|
| 746 |
+
|
| 747 |
+
def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dim`: a `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    out = v
    for _ in range(dims - 1):
        out = out[..., None]
    return out
|