b3h-young123 commited on
Commit
66d99ea
·
verified ·
1 Parent(s): d864b06

Add files using upload-large-folder tool

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/__init__.py +0 -0
  2. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/DAT.py +1182 -0
  3. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/HAT.py +1277 -0
  4. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-DAT +201 -0
  5. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-ESRGAN +201 -0
  6. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-HAT +21 -0
  7. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-RealESRGAN +29 -0
  8. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SCUNet +201 -0
  9. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SPSR +201 -0
  10. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SwiftSRGAN +121 -0
  11. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-Swin2SR +201 -0
  12. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SwinIR +201 -0
  13. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-lama +201 -0
  14. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LaMa.py +694 -0
  15. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py +110 -0
  16. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/LICENSE +201 -0
  17. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/OSA.py +577 -0
  18. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py +60 -0
  19. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py +143 -0
  20. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/esa.py +294 -0
  21. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py +70 -0
  22. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py +31 -0
  23. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/RRDB.py +296 -0
  24. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SCUNet.py +455 -0
  25. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SPSR.py +383 -0
  26. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SRVGG.py +114 -0
  27. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SwiftSRGAN.py +161 -0
  28. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/Swin2SR.py +1377 -0
  29. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SwinIR.py +1224 -0
  30. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/__init__.py +0 -0
  31. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/block.py +546 -0
  32. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/LICENSE-GFPGAN +351 -0
  33. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/LICENSE-RestoreFormer +351 -0
  34. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/LICENSE-codeformer +35 -0
  35. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/arcface_arch.py +265 -0
  36. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/codeformer.py +790 -0
  37. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/fused_act.py +81 -0
  38. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/gfpgan_bilinear_arch.py +389 -0
  39. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/gfpganv1_arch.py +566 -0
  40. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/gfpganv1_clean_arch.py +370 -0
  41. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/restoreformer_arch.py +776 -0
  42. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/stylegan2_arch.py +865 -0
  43. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/stylegan2_bilinear_arch.py +709 -0
  44. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/stylegan2_clean_arch.py +453 -0
  45. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/upfirdn2d.py +194 -0
  46. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/LICENSE +201 -0
  47. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/drop.py +223 -0
  48. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/helpers.py +31 -0
  49. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/weight_init.py +128 -0
  50. LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/model_loading.py +99 -0
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/__init__.py ADDED
File without changes
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/DAT.py ADDED
@@ -0,0 +1,1182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ import math
3
+ import re
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.utils.checkpoint as checkpoint
9
+ from einops import rearrange
10
+ from einops.layers.torch import Rearrange
11
+ from torch import Tensor
12
+ from torch.nn import functional as F
13
+
14
+ from .timm.drop import DropPath
15
+ from .timm.weight_init import trunc_normal_
16
+
17
+
18
+ def img2windows(img, H_sp, W_sp):
19
+ """
20
+ Input: Image (B, C, H, W)
21
+ Output: Window Partition (B', N, C)
22
+ """
23
+ B, C, H, W = img.shape
24
+ img_reshape = img.view(B, C, H // H_sp, H_sp, W // W_sp, W_sp)
25
+ img_perm = (
26
+ img_reshape.permute(0, 2, 4, 3, 5, 1).contiguous().reshape(-1, H_sp * W_sp, C)
27
+ )
28
+ return img_perm
29
+
30
+
31
+ def windows2img(img_splits_hw, H_sp, W_sp, H, W):
32
+ """
33
+ Input: Window Partition (B', N, C)
34
+ Output: Image (B, H, W, C)
35
+ """
36
+ B = int(img_splits_hw.shape[0] / (H * W / H_sp / W_sp))
37
+
38
+ img = img_splits_hw.view(B, H // H_sp, W // W_sp, H_sp, W_sp, -1)
39
+ img = img.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
40
+ return img
41
+
42
+
43
+ class SpatialGate(nn.Module):
44
+ """Spatial-Gate.
45
+ Args:
46
+ dim (int): Half of input channels.
47
+ """
48
+
49
+ def __init__(self, dim):
50
+ super().__init__()
51
+ self.norm = nn.LayerNorm(dim)
52
+ self.conv = nn.Conv2d(
53
+ dim, dim, kernel_size=3, stride=1, padding=1, groups=dim
54
+ ) # DW Conv
55
+
56
+ def forward(self, x, H, W):
57
+ # Split
58
+ x1, x2 = x.chunk(2, dim=-1)
59
+ B, N, C = x.shape
60
+ x2 = (
61
+ self.conv(self.norm(x2).transpose(1, 2).contiguous().view(B, C // 2, H, W))
62
+ .flatten(2)
63
+ .transpose(-1, -2)
64
+ .contiguous()
65
+ )
66
+
67
+ return x1 * x2
68
+
69
+
70
+ class SGFN(nn.Module):
71
+ """Spatial-Gate Feed-Forward Network.
72
+ Args:
73
+ in_features (int): Number of input channels.
74
+ hidden_features (int | None): Number of hidden channels. Default: None
75
+ out_features (int | None): Number of output channels. Default: None
76
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
77
+ drop (float): Dropout rate. Default: 0.0
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ in_features,
83
+ hidden_features=None,
84
+ out_features=None,
85
+ act_layer=nn.GELU,
86
+ drop=0.0,
87
+ ):
88
+ super().__init__()
89
+ out_features = out_features or in_features
90
+ hidden_features = hidden_features or in_features
91
+ self.fc1 = nn.Linear(in_features, hidden_features)
92
+ self.act = act_layer()
93
+ self.sg = SpatialGate(hidden_features // 2)
94
+ self.fc2 = nn.Linear(hidden_features // 2, out_features)
95
+ self.drop = nn.Dropout(drop)
96
+
97
+ def forward(self, x, H, W):
98
+ """
99
+ Input: x: (B, H*W, C), H, W
100
+ Output: x: (B, H*W, C)
101
+ """
102
+ x = self.fc1(x)
103
+ x = self.act(x)
104
+ x = self.drop(x)
105
+
106
+ x = self.sg(x, H, W)
107
+ x = self.drop(x)
108
+
109
+ x = self.fc2(x)
110
+ x = self.drop(x)
111
+ return x
112
+
113
+
114
+ class DynamicPosBias(nn.Module):
115
+ # The implementation builds on Crossformer code https://github.com/cheerss/CrossFormer/blob/main/models/crossformer.py
116
+ """Dynamic Relative Position Bias.
117
+ Args:
118
+ dim (int): Number of input channels.
119
+ num_heads (int): Number of attention heads.
120
+ residual (bool): If True, use residual strage to connect conv.
121
+ """
122
+
123
+ def __init__(self, dim, num_heads, residual):
124
+ super().__init__()
125
+ self.residual = residual
126
+ self.num_heads = num_heads
127
+ self.pos_dim = dim // 4
128
+ self.pos_proj = nn.Linear(2, self.pos_dim)
129
+ self.pos1 = nn.Sequential(
130
+ nn.LayerNorm(self.pos_dim),
131
+ nn.ReLU(inplace=True),
132
+ nn.Linear(self.pos_dim, self.pos_dim),
133
+ )
134
+ self.pos2 = nn.Sequential(
135
+ nn.LayerNorm(self.pos_dim),
136
+ nn.ReLU(inplace=True),
137
+ nn.Linear(self.pos_dim, self.pos_dim),
138
+ )
139
+ self.pos3 = nn.Sequential(
140
+ nn.LayerNorm(self.pos_dim),
141
+ nn.ReLU(inplace=True),
142
+ nn.Linear(self.pos_dim, self.num_heads),
143
+ )
144
+
145
+ def forward(self, biases):
146
+ if self.residual:
147
+ pos = self.pos_proj(biases) # 2Gh-1 * 2Gw-1, heads
148
+ pos = pos + self.pos1(pos)
149
+ pos = pos + self.pos2(pos)
150
+ pos = self.pos3(pos)
151
+ else:
152
+ pos = self.pos3(self.pos2(self.pos1(self.pos_proj(biases))))
153
+ return pos
154
+
155
+
156
+ class Spatial_Attention(nn.Module):
157
+ """Spatial Window Self-Attention.
158
+ It supports rectangle window (containing square window).
159
+ Args:
160
+ dim (int): Number of input channels.
161
+ idx (int): The indentix of window. (0/1)
162
+ split_size (tuple(int)): Height and Width of spatial window.
163
+ dim_out (int | None): The dimension of the attention output. Default: None
164
+ num_heads (int): Number of attention heads. Default: 6
165
+ attn_drop (float): Dropout ratio of attention weight. Default: 0.0
166
+ proj_drop (float): Dropout ratio of output. Default: 0.0
167
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set
168
+ position_bias (bool): The dynamic relative position bias. Default: True
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ dim,
174
+ idx,
175
+ split_size=[8, 8],
176
+ dim_out=None,
177
+ num_heads=6,
178
+ attn_drop=0.0,
179
+ proj_drop=0.0,
180
+ qk_scale=None,
181
+ position_bias=True,
182
+ ):
183
+ super().__init__()
184
+ self.dim = dim
185
+ self.dim_out = dim_out or dim
186
+ self.split_size = split_size
187
+ self.num_heads = num_heads
188
+ self.idx = idx
189
+ self.position_bias = position_bias
190
+
191
+ head_dim = dim // num_heads
192
+ self.scale = qk_scale or head_dim**-0.5
193
+
194
+ if idx == 0:
195
+ H_sp, W_sp = self.split_size[0], self.split_size[1]
196
+ elif idx == 1:
197
+ W_sp, H_sp = self.split_size[0], self.split_size[1]
198
+ else:
199
+ print("ERROR MODE", idx)
200
+ exit(0)
201
+ self.H_sp = H_sp
202
+ self.W_sp = W_sp
203
+
204
+ if self.position_bias:
205
+ self.pos = DynamicPosBias(self.dim // 4, self.num_heads, residual=False)
206
+ # generate mother-set
207
+ position_bias_h = torch.arange(1 - self.H_sp, self.H_sp)
208
+ position_bias_w = torch.arange(1 - self.W_sp, self.W_sp)
209
+ biases = torch.stack(torch.meshgrid([position_bias_h, position_bias_w]))
210
+ biases = biases.flatten(1).transpose(0, 1).contiguous().float()
211
+ self.register_buffer("rpe_biases", biases)
212
+
213
+ # get pair-wise relative position index for each token inside the window
214
+ coords_h = torch.arange(self.H_sp)
215
+ coords_w = torch.arange(self.W_sp)
216
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w]))
217
+ coords_flatten = torch.flatten(coords, 1)
218
+ relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
219
+ relative_coords = relative_coords.permute(1, 2, 0).contiguous()
220
+ relative_coords[:, :, 0] += self.H_sp - 1
221
+ relative_coords[:, :, 1] += self.W_sp - 1
222
+ relative_coords[:, :, 0] *= 2 * self.W_sp - 1
223
+ relative_position_index = relative_coords.sum(-1)
224
+ self.register_buffer("relative_position_index", relative_position_index)
225
+
226
+ self.attn_drop = nn.Dropout(attn_drop)
227
+
228
+ def im2win(self, x, H, W):
229
+ B, N, C = x.shape
230
+ x = x.transpose(-2, -1).contiguous().view(B, C, H, W)
231
+ x = img2windows(x, self.H_sp, self.W_sp)
232
+ x = (
233
+ x.reshape(-1, self.H_sp * self.W_sp, self.num_heads, C // self.num_heads)
234
+ .permute(0, 2, 1, 3)
235
+ .contiguous()
236
+ )
237
+ return x
238
+
239
+ def forward(self, qkv, H, W, mask=None):
240
+ """
241
+ Input: qkv: (B, 3*L, C), H, W, mask: (B, N, N), N is the window size
242
+ Output: x (B, H, W, C)
243
+ """
244
+ q, k, v = qkv[0], qkv[1], qkv[2]
245
+
246
+ B, L, C = q.shape
247
+ assert L == H * W, "flatten img_tokens has wrong size"
248
+
249
+ # partition the q,k,v, image to window
250
+ q = self.im2win(q, H, W)
251
+ k = self.im2win(k, H, W)
252
+ v = self.im2win(v, H, W)
253
+
254
+ q = q * self.scale
255
+ attn = q @ k.transpose(-2, -1) # B head N C @ B head C N --> B head N N
256
+
257
+ # calculate drpe
258
+ if self.position_bias:
259
+ pos = self.pos(self.rpe_biases)
260
+ # select position bias
261
+ relative_position_bias = pos[self.relative_position_index.view(-1)].view(
262
+ self.H_sp * self.W_sp, self.H_sp * self.W_sp, -1
263
+ )
264
+ relative_position_bias = relative_position_bias.permute(
265
+ 2, 0, 1
266
+ ).contiguous()
267
+ attn = attn + relative_position_bias.unsqueeze(0)
268
+
269
+ N = attn.shape[3]
270
+
271
+ # use mask for shift window
272
+ if mask is not None:
273
+ nW = mask.shape[0]
274
+ attn = attn.view(B, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(
275
+ 0
276
+ )
277
+ attn = attn.view(-1, self.num_heads, N, N)
278
+
279
+ attn = nn.functional.softmax(attn, dim=-1, dtype=attn.dtype)
280
+ attn = self.attn_drop(attn)
281
+
282
+ x = attn @ v
283
+ x = x.transpose(1, 2).reshape(
284
+ -1, self.H_sp * self.W_sp, C
285
+ ) # B head N N @ B head N C
286
+
287
+ # merge the window, window to image
288
+ x = windows2img(x, self.H_sp, self.W_sp, H, W) # B H' W' C
289
+
290
+ return x
291
+
292
+
293
+ class Adaptive_Spatial_Attention(nn.Module):
294
+ # The implementation builds on CAT code https://github.com/Zhengchen1999/CAT
295
+ """Adaptive Spatial Self-Attention
296
+ Args:
297
+ dim (int): Number of input channels.
298
+ num_heads (int): Number of attention heads. Default: 6
299
+ split_size (tuple(int)): Height and Width of spatial window.
300
+ shift_size (tuple(int)): Shift size for spatial window.
301
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
302
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
303
+ drop (float): Dropout rate. Default: 0.0
304
+ attn_drop (float): Attention dropout rate. Default: 0.0
305
+ rg_idx (int): The indentix of Residual Group (RG)
306
+ b_idx (int): The indentix of Block in each RG
307
+ """
308
+
309
+ def __init__(
310
+ self,
311
+ dim,
312
+ num_heads,
313
+ reso=64,
314
+ split_size=[8, 8],
315
+ shift_size=[1, 2],
316
+ qkv_bias=False,
317
+ qk_scale=None,
318
+ drop=0.0,
319
+ attn_drop=0.0,
320
+ rg_idx=0,
321
+ b_idx=0,
322
+ ):
323
+ super().__init__()
324
+ self.dim = dim
325
+ self.num_heads = num_heads
326
+ self.split_size = split_size
327
+ self.shift_size = shift_size
328
+ self.b_idx = b_idx
329
+ self.rg_idx = rg_idx
330
+ self.patches_resolution = reso
331
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
332
+
333
+ assert (
334
+ 0 <= self.shift_size[0] < self.split_size[0]
335
+ ), "shift_size must in 0-split_size0"
336
+ assert (
337
+ 0 <= self.shift_size[1] < self.split_size[1]
338
+ ), "shift_size must in 0-split_size1"
339
+
340
+ self.branch_num = 2
341
+
342
+ self.proj = nn.Linear(dim, dim)
343
+ self.proj_drop = nn.Dropout(drop)
344
+
345
+ self.attns = nn.ModuleList(
346
+ [
347
+ Spatial_Attention(
348
+ dim // 2,
349
+ idx=i,
350
+ split_size=split_size,
351
+ num_heads=num_heads // 2,
352
+ dim_out=dim // 2,
353
+ qk_scale=qk_scale,
354
+ attn_drop=attn_drop,
355
+ proj_drop=drop,
356
+ position_bias=True,
357
+ )
358
+ for i in range(self.branch_num)
359
+ ]
360
+ )
361
+
362
+ if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or (
363
+ self.rg_idx % 2 != 0 and self.b_idx % 4 == 0
364
+ ):
365
+ attn_mask = self.calculate_mask(
366
+ self.patches_resolution, self.patches_resolution
367
+ )
368
+ self.register_buffer("attn_mask_0", attn_mask[0])
369
+ self.register_buffer("attn_mask_1", attn_mask[1])
370
+ else:
371
+ attn_mask = None
372
+ self.register_buffer("attn_mask_0", None)
373
+ self.register_buffer("attn_mask_1", None)
374
+
375
+ self.dwconv = nn.Sequential(
376
+ nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim),
377
+ nn.BatchNorm2d(dim),
378
+ nn.GELU(),
379
+ )
380
+ self.channel_interaction = nn.Sequential(
381
+ nn.AdaptiveAvgPool2d(1),
382
+ nn.Conv2d(dim, dim // 8, kernel_size=1),
383
+ nn.BatchNorm2d(dim // 8),
384
+ nn.GELU(),
385
+ nn.Conv2d(dim // 8, dim, kernel_size=1),
386
+ )
387
+ self.spatial_interaction = nn.Sequential(
388
+ nn.Conv2d(dim, dim // 16, kernel_size=1),
389
+ nn.BatchNorm2d(dim // 16),
390
+ nn.GELU(),
391
+ nn.Conv2d(dim // 16, 1, kernel_size=1),
392
+ )
393
+
394
+ def calculate_mask(self, H, W):
395
+ # The implementation builds on Swin Transformer code https://github.com/microsoft/Swin-Transformer/blob/main/models/swin_transformer.py
396
+ # calculate attention mask for shift window
397
+ img_mask_0 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=0
398
+ img_mask_1 = torch.zeros((1, H, W, 1)) # 1 H W 1 idx=1
399
+ h_slices_0 = (
400
+ slice(0, -self.split_size[0]),
401
+ slice(-self.split_size[0], -self.shift_size[0]),
402
+ slice(-self.shift_size[0], None),
403
+ )
404
+ w_slices_0 = (
405
+ slice(0, -self.split_size[1]),
406
+ slice(-self.split_size[1], -self.shift_size[1]),
407
+ slice(-self.shift_size[1], None),
408
+ )
409
+
410
+ h_slices_1 = (
411
+ slice(0, -self.split_size[1]),
412
+ slice(-self.split_size[1], -self.shift_size[1]),
413
+ slice(-self.shift_size[1], None),
414
+ )
415
+ w_slices_1 = (
416
+ slice(0, -self.split_size[0]),
417
+ slice(-self.split_size[0], -self.shift_size[0]),
418
+ slice(-self.shift_size[0], None),
419
+ )
420
+ cnt = 0
421
+ for h in h_slices_0:
422
+ for w in w_slices_0:
423
+ img_mask_0[:, h, w, :] = cnt
424
+ cnt += 1
425
+ cnt = 0
426
+ for h in h_slices_1:
427
+ for w in w_slices_1:
428
+ img_mask_1[:, h, w, :] = cnt
429
+ cnt += 1
430
+
431
+ # calculate mask for window-0
432
+ img_mask_0 = img_mask_0.view(
433
+ 1,
434
+ H // self.split_size[0],
435
+ self.split_size[0],
436
+ W // self.split_size[1],
437
+ self.split_size[1],
438
+ 1,
439
+ )
440
+ img_mask_0 = (
441
+ img_mask_0.permute(0, 1, 3, 2, 4, 5)
442
+ .contiguous()
443
+ .view(-1, self.split_size[0], self.split_size[1], 1)
444
+ ) # nW, sw[0], sw[1], 1
445
+ mask_windows_0 = img_mask_0.view(-1, self.split_size[0] * self.split_size[1])
446
+ attn_mask_0 = mask_windows_0.unsqueeze(1) - mask_windows_0.unsqueeze(2)
447
+ attn_mask_0 = attn_mask_0.masked_fill(
448
+ attn_mask_0 != 0, float(-100.0)
449
+ ).masked_fill(attn_mask_0 == 0, float(0.0))
450
+
451
+ # calculate mask for window-1
452
+ img_mask_1 = img_mask_1.view(
453
+ 1,
454
+ H // self.split_size[1],
455
+ self.split_size[1],
456
+ W // self.split_size[0],
457
+ self.split_size[0],
458
+ 1,
459
+ )
460
+ img_mask_1 = (
461
+ img_mask_1.permute(0, 1, 3, 2, 4, 5)
462
+ .contiguous()
463
+ .view(-1, self.split_size[1], self.split_size[0], 1)
464
+ ) # nW, sw[1], sw[0], 1
465
+ mask_windows_1 = img_mask_1.view(-1, self.split_size[1] * self.split_size[0])
466
+ attn_mask_1 = mask_windows_1.unsqueeze(1) - mask_windows_1.unsqueeze(2)
467
+ attn_mask_1 = attn_mask_1.masked_fill(
468
+ attn_mask_1 != 0, float(-100.0)
469
+ ).masked_fill(attn_mask_1 == 0, float(0.0))
470
+
471
+ return attn_mask_0, attn_mask_1
472
+
473
+ def forward(self, x, H, W):
474
+ """
475
+ Input: x: (B, H*W, C), H, W
476
+ Output: x: (B, H*W, C)
477
+ """
478
+ B, L, C = x.shape
479
+ assert L == H * W, "flatten img_tokens has wrong size"
480
+
481
+ qkv = self.qkv(x).reshape(B, -1, 3, C).permute(2, 0, 1, 3) # 3, B, HW, C
482
+ # V without partition
483
+ v = qkv[2].transpose(-2, -1).contiguous().view(B, C, H, W)
484
+
485
+ # image padding
486
+ max_split_size = max(self.split_size[0], self.split_size[1])
487
+ pad_l = pad_t = 0
488
+ pad_r = (max_split_size - W % max_split_size) % max_split_size
489
+ pad_b = (max_split_size - H % max_split_size) % max_split_size
490
+
491
+ qkv = qkv.reshape(3 * B, H, W, C).permute(0, 3, 1, 2) # 3B C H W
492
+ qkv = (
493
+ F.pad(qkv, (pad_l, pad_r, pad_t, pad_b))
494
+ .reshape(3, B, C, -1)
495
+ .transpose(-2, -1)
496
+ ) # l r t b
497
+ _H = pad_b + H
498
+ _W = pad_r + W
499
+ _L = _H * _W
500
+
501
+ # window-0 and window-1 on split channels [C/2, C/2]; for square windows (e.g., 8x8), window-0 and window-1 can be merged
502
+ # shift in block: (0, 4, 8, ...), (2, 6, 10, ...), (0, 4, 8, ...), (2, 6, 10, ...), ...
503
+ if (self.rg_idx % 2 == 0 and self.b_idx > 0 and (self.b_idx - 2) % 4 == 0) or (
504
+ self.rg_idx % 2 != 0 and self.b_idx % 4 == 0
505
+ ):
506
+ qkv = qkv.view(3, B, _H, _W, C)
507
+ qkv_0 = torch.roll(
508
+ qkv[:, :, :, :, : C // 2],
509
+ shifts=(-self.shift_size[0], -self.shift_size[1]),
510
+ dims=(2, 3),
511
+ )
512
+ qkv_0 = qkv_0.view(3, B, _L, C // 2)
513
+ qkv_1 = torch.roll(
514
+ qkv[:, :, :, :, C // 2 :],
515
+ shifts=(-self.shift_size[1], -self.shift_size[0]),
516
+ dims=(2, 3),
517
+ )
518
+ qkv_1 = qkv_1.view(3, B, _L, C // 2)
519
+
520
+ if self.patches_resolution != _H or self.patches_resolution != _W:
521
+ mask_tmp = self.calculate_mask(_H, _W)
522
+ x1_shift = self.attns[0](qkv_0, _H, _W, mask=mask_tmp[0].to(x.device))
523
+ x2_shift = self.attns[1](qkv_1, _H, _W, mask=mask_tmp[1].to(x.device))
524
+ else:
525
+ x1_shift = self.attns[0](qkv_0, _H, _W, mask=self.attn_mask_0)
526
+ x2_shift = self.attns[1](qkv_1, _H, _W, mask=self.attn_mask_1)
527
+
528
+ x1 = torch.roll(
529
+ x1_shift, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)
530
+ )
531
+ x2 = torch.roll(
532
+ x2_shift, shifts=(self.shift_size[1], self.shift_size[0]), dims=(1, 2)
533
+ )
534
+ x1 = x1[:, :H, :W, :].reshape(B, L, C // 2)
535
+ x2 = x2[:, :H, :W, :].reshape(B, L, C // 2)
536
+ # attention output
537
+ attened_x = torch.cat([x1, x2], dim=2)
538
+
539
+ else:
540
+ x1 = self.attns[0](qkv[:, :, :, : C // 2], _H, _W)[:, :H, :W, :].reshape(
541
+ B, L, C // 2
542
+ )
543
+ x2 = self.attns[1](qkv[:, :, :, C // 2 :], _H, _W)[:, :H, :W, :].reshape(
544
+ B, L, C // 2
545
+ )
546
+ # attention output
547
+ attened_x = torch.cat([x1, x2], dim=2)
548
+
549
+ # convolution output
550
+ conv_x = self.dwconv(v)
551
+
552
+ # Adaptive Interaction Module (AIM)
553
+ # C-Map (before sigmoid)
554
+ channel_map = (
555
+ self.channel_interaction(conv_x)
556
+ .permute(0, 2, 3, 1)
557
+ .contiguous()
558
+ .view(B, 1, C)
559
+ )
560
+ # S-Map (before sigmoid)
561
+ attention_reshape = attened_x.transpose(-2, -1).contiguous().view(B, C, H, W)
562
+ spatial_map = self.spatial_interaction(attention_reshape)
563
+
564
+ # C-I
565
+ attened_x = attened_x * torch.sigmoid(channel_map)
566
+ # S-I
567
+ conv_x = torch.sigmoid(spatial_map) * conv_x
568
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, L, C)
569
+
570
+ x = attened_x + conv_x
571
+
572
+ x = self.proj(x)
573
+ x = self.proj_drop(x)
574
+
575
+ return x
576
+
577
+
578
+ class Adaptive_Channel_Attention(nn.Module):
579
+ # The implementation builds on XCiT code https://github.com/facebookresearch/xcit
580
+ """Adaptive Channel Self-Attention
581
+ Args:
582
+ dim (int): Number of input channels.
583
+ num_heads (int): Number of attention heads. Default: 6
584
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
585
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set.
586
+ attn_drop (float): Attention dropout rate. Default: 0.0
587
+ drop_path (float): Stochastic depth rate. Default: 0.0
588
+ """
589
+
590
+ def __init__(
591
+ self,
592
+ dim,
593
+ num_heads=8,
594
+ qkv_bias=False,
595
+ qk_scale=None,
596
+ attn_drop=0.0,
597
+ proj_drop=0.0,
598
+ ):
599
+ super().__init__()
600
+ self.num_heads = num_heads
601
+ self.temperature = nn.Parameter(torch.ones(num_heads, 1, 1))
602
+
603
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
604
+ self.attn_drop = nn.Dropout(attn_drop)
605
+ self.proj = nn.Linear(dim, dim)
606
+ self.proj_drop = nn.Dropout(proj_drop)
607
+
608
+ self.dwconv = nn.Sequential(
609
+ nn.Conv2d(dim, dim, kernel_size=3, stride=1, padding=1, groups=dim),
610
+ nn.BatchNorm2d(dim),
611
+ nn.GELU(),
612
+ )
613
+ self.channel_interaction = nn.Sequential(
614
+ nn.AdaptiveAvgPool2d(1),
615
+ nn.Conv2d(dim, dim // 8, kernel_size=1),
616
+ nn.BatchNorm2d(dim // 8),
617
+ nn.GELU(),
618
+ nn.Conv2d(dim // 8, dim, kernel_size=1),
619
+ )
620
+ self.spatial_interaction = nn.Sequential(
621
+ nn.Conv2d(dim, dim // 16, kernel_size=1),
622
+ nn.BatchNorm2d(dim // 16),
623
+ nn.GELU(),
624
+ nn.Conv2d(dim // 16, 1, kernel_size=1),
625
+ )
626
+
627
+ def forward(self, x, H, W):
628
+ """
629
+ Input: x: (B, H*W, C), H, W
630
+ Output: x: (B, H*W, C)
631
+ """
632
+ B, N, C = x.shape
633
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
634
+ qkv = qkv.permute(2, 0, 3, 1, 4)
635
+ q, k, v = qkv[0], qkv[1], qkv[2]
636
+
637
+ q = q.transpose(-2, -1)
638
+ k = k.transpose(-2, -1)
639
+ v = v.transpose(-2, -1)
640
+
641
+ v_ = v.reshape(B, C, N).contiguous().view(B, C, H, W)
642
+
643
+ q = torch.nn.functional.normalize(q, dim=-1)
644
+ k = torch.nn.functional.normalize(k, dim=-1)
645
+
646
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
647
+ attn = attn.softmax(dim=-1)
648
+ attn = self.attn_drop(attn)
649
+
650
+ # attention output
651
+ attened_x = (attn @ v).permute(0, 3, 1, 2).reshape(B, N, C)
652
+
653
+ # convolution output
654
+ conv_x = self.dwconv(v_)
655
+
656
+ # Adaptive Interaction Module (AIM)
657
+ # C-Map (before sigmoid)
658
+ attention_reshape = attened_x.transpose(-2, -1).contiguous().view(B, C, H, W)
659
+ channel_map = self.channel_interaction(attention_reshape)
660
+ # S-Map (before sigmoid)
661
+ spatial_map = (
662
+ self.spatial_interaction(conv_x)
663
+ .permute(0, 2, 3, 1)
664
+ .contiguous()
665
+ .view(B, N, 1)
666
+ )
667
+
668
+ # S-I
669
+ attened_x = attened_x * torch.sigmoid(spatial_map)
670
+ # C-I
671
+ conv_x = conv_x * torch.sigmoid(channel_map)
672
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(B, N, C)
673
+
674
+ x = attened_x + conv_x
675
+
676
+ x = self.proj(x)
677
+ x = self.proj_drop(x)
678
+
679
+ return x
680
+
681
+
682
+ class DATB(nn.Module):
683
+ def __init__(
684
+ self,
685
+ dim,
686
+ num_heads,
687
+ reso=64,
688
+ split_size=[2, 4],
689
+ shift_size=[1, 2],
690
+ expansion_factor=4.0,
691
+ qkv_bias=False,
692
+ qk_scale=None,
693
+ drop=0.0,
694
+ attn_drop=0.0,
695
+ drop_path=0.0,
696
+ act_layer=nn.GELU,
697
+ norm_layer=nn.LayerNorm,
698
+ rg_idx=0,
699
+ b_idx=0,
700
+ ):
701
+ super().__init__()
702
+
703
+ self.norm1 = norm_layer(dim)
704
+
705
+ if b_idx % 2 == 0:
706
+ # DSTB
707
+ self.attn = Adaptive_Spatial_Attention(
708
+ dim,
709
+ num_heads=num_heads,
710
+ reso=reso,
711
+ split_size=split_size,
712
+ shift_size=shift_size,
713
+ qkv_bias=qkv_bias,
714
+ qk_scale=qk_scale,
715
+ drop=drop,
716
+ attn_drop=attn_drop,
717
+ rg_idx=rg_idx,
718
+ b_idx=b_idx,
719
+ )
720
+ else:
721
+ # DCTB
722
+ self.attn = Adaptive_Channel_Attention(
723
+ dim,
724
+ num_heads=num_heads,
725
+ qkv_bias=qkv_bias,
726
+ qk_scale=qk_scale,
727
+ attn_drop=attn_drop,
728
+ proj_drop=drop,
729
+ )
730
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
731
+
732
+ ffn_hidden_dim = int(dim * expansion_factor)
733
+ self.ffn = SGFN(
734
+ in_features=dim,
735
+ hidden_features=ffn_hidden_dim,
736
+ out_features=dim,
737
+ act_layer=act_layer,
738
+ )
739
+ self.norm2 = norm_layer(dim)
740
+
741
+ def forward(self, x, x_size):
742
+ """
743
+ Input: x: (B, H*W, C), x_size: (H, W)
744
+ Output: x: (B, H*W, C)
745
+ """
746
+ H, W = x_size
747
+ x = x + self.drop_path(self.attn(self.norm1(x), H, W))
748
+ x = x + self.drop_path(self.ffn(self.norm2(x), H, W))
749
+
750
+ return x
751
+
752
+
753
+ class ResidualGroup(nn.Module):
754
+ """ResidualGroup
755
+ Args:
756
+ dim (int): Number of input channels.
757
+ reso (int): Input resolution.
758
+ num_heads (int): Number of attention heads.
759
+ split_size (tuple(int)): Height and Width of spatial window.
760
+ expansion_factor (float): Ratio of ffn hidden dim to embedding dim.
761
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
762
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None
763
+ drop (float): Dropout rate. Default: 0
764
+ attn_drop(float): Attention dropout rate. Default: 0
765
+ drop_paths (float | None): Stochastic depth rate.
766
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
767
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
768
+ depth (int): Number of dual aggregation Transformer blocks in residual group.
769
+ use_chk (bool): Whether to use checkpointing to save memory.
770
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
771
+ """
772
+
773
+ def __init__(
774
+ self,
775
+ dim,
776
+ reso,
777
+ num_heads,
778
+ split_size=[2, 4],
779
+ expansion_factor=4.0,
780
+ qkv_bias=False,
781
+ qk_scale=None,
782
+ drop=0.0,
783
+ attn_drop=0.0,
784
+ drop_paths=None,
785
+ act_layer=nn.GELU,
786
+ norm_layer=nn.LayerNorm,
787
+ depth=2,
788
+ use_chk=False,
789
+ resi_connection="1conv",
790
+ rg_idx=0,
791
+ ):
792
+ super().__init__()
793
+ self.use_chk = use_chk
794
+ self.reso = reso
795
+
796
+ self.blocks = nn.ModuleList(
797
+ [
798
+ DATB(
799
+ dim=dim,
800
+ num_heads=num_heads,
801
+ reso=reso,
802
+ split_size=split_size,
803
+ shift_size=[split_size[0] // 2, split_size[1] // 2],
804
+ expansion_factor=expansion_factor,
805
+ qkv_bias=qkv_bias,
806
+ qk_scale=qk_scale,
807
+ drop=drop,
808
+ attn_drop=attn_drop,
809
+ drop_path=drop_paths[i],
810
+ act_layer=act_layer,
811
+ norm_layer=norm_layer,
812
+ rg_idx=rg_idx,
813
+ b_idx=i,
814
+ )
815
+ for i in range(depth)
816
+ ]
817
+ )
818
+
819
+ if resi_connection == "1conv":
820
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
821
+ elif resi_connection == "3conv":
822
+ self.conv = nn.Sequential(
823
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
824
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
825
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
826
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
827
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
828
+ )
829
+
830
+ def forward(self, x, x_size):
831
+ """
832
+ Input: x: (B, H*W, C), x_size: (H, W)
833
+ Output: x: (B, H*W, C)
834
+ """
835
+ H, W = x_size
836
+ res = x
837
+ for blk in self.blocks:
838
+ if self.use_chk:
839
+ x = checkpoint.checkpoint(blk, x, x_size)
840
+ else:
841
+ x = blk(x, x_size)
842
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
843
+ x = self.conv(x)
844
+ x = rearrange(x, "b c h w -> b (h w) c")
845
+ x = res + x
846
+
847
+ return x
848
+
849
+
850
+ class Upsample(nn.Sequential):
851
+ """Upsample module.
852
+ Args:
853
+ scale (int): Scale factor. Supported scales: 2^n and 3.
854
+ num_feat (int): Channel number of intermediate features.
855
+ """
856
+
857
+ def __init__(self, scale, num_feat):
858
+ m = []
859
+ if (scale & (scale - 1)) == 0: # scale = 2^n
860
+ for _ in range(int(math.log(scale, 2))):
861
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
862
+ m.append(nn.PixelShuffle(2))
863
+ elif scale == 3:
864
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
865
+ m.append(nn.PixelShuffle(3))
866
+ else:
867
+ raise ValueError(
868
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
869
+ )
870
+ super(Upsample, self).__init__(*m)
871
+
872
+
873
+ class UpsampleOneStep(nn.Sequential):
874
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
875
+ Used in lightweight SR to save parameters.
876
+
877
+ Args:
878
+ scale (int): Scale factor. Supported scales: 2^n and 3.
879
+ num_feat (int): Channel number of intermediate features.
880
+
881
+ """
882
+
883
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
884
+ self.num_feat = num_feat
885
+ self.input_resolution = input_resolution
886
+ m = []
887
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
888
+ m.append(nn.PixelShuffle(scale))
889
+ super(UpsampleOneStep, self).__init__(*m)
890
+
891
+ def flops(self):
892
+ h, w = self.input_resolution
893
+ flops = h * w * self.num_feat * 3 * 9
894
+ return flops
895
+
896
+
897
+ class DAT(nn.Module):
898
+ """Dual Aggregation Transformer
899
+ Args:
900
+ img_size (int): Input image size. Default: 64
901
+ in_chans (int): Number of input image channels. Default: 3
902
+ embed_dim (int): Patch embedding dimension. Default: 180
903
+ depths (tuple(int)): Depth of each residual group (number of DATB in each RG).
904
+ split_size (tuple(int)): Height and Width of spatial window.
905
+ num_heads (tuple(int)): Number of attention heads in different residual groups.
906
+ expansion_factor (float): Ratio of ffn hidden dim to embedding dim. Default: 4
907
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
908
+ qk_scale (float | None): Override default qk scale of head_dim ** -0.5 if set. Default: None
909
+ drop_rate (float): Dropout rate. Default: 0
910
+ attn_drop_rate (float): Attention dropout rate. Default: 0
911
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
912
+ act_layer (nn.Module): Activation layer. Default: nn.GELU
913
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm
914
+ use_chk (bool): Whether to use checkpointing to save memory.
915
+ upscale: Upscale factor. 2/3/4 for image SR
916
+ img_range: Image range. 1. or 255.
917
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
918
+ """
919
+
920
+ def __init__(self, state_dict):
921
+ super().__init__()
922
+
923
+ # defaults
924
+ img_size = 64
925
+ in_chans = 3
926
+ embed_dim = 180
927
+ split_size = [2, 4]
928
+ depth = [2, 2, 2, 2]
929
+ num_heads = [2, 2, 2, 2]
930
+ expansion_factor = 4.0
931
+ qkv_bias = True
932
+ qk_scale = None
933
+ drop_rate = 0.0
934
+ attn_drop_rate = 0.0
935
+ drop_path_rate = 0.1
936
+ act_layer = nn.GELU
937
+ norm_layer = nn.LayerNorm
938
+ use_chk = False
939
+ upscale = 2
940
+ img_range = 1.0
941
+ resi_connection = "1conv"
942
+ upsampler = "pixelshuffle"
943
+
944
+ self.model_arch = "DAT"
945
+ self.sub_type = "SR"
946
+ self.state = state_dict
947
+
948
+ state_keys = state_dict.keys()
949
+ if "conv_before_upsample.0.weight" in state_keys:
950
+ if "conv_up1.weight" in state_keys:
951
+ upsampler = "nearest+conv"
952
+ else:
953
+ upsampler = "pixelshuffle"
954
+ supports_fp16 = False
955
+ elif "upsample.0.weight" in state_keys:
956
+ upsampler = "pixelshuffledirect"
957
+ else:
958
+ upsampler = ""
959
+
960
+ num_feat = (
961
+ state_dict.get("conv_before_upsample.0.weight", None).shape[1]
962
+ if state_dict.get("conv_before_upsample.weight", None)
963
+ else 64
964
+ )
965
+
966
+ num_in_ch = state_dict["conv_first.weight"].shape[1]
967
+ in_chans = num_in_ch
968
+ if "conv_last.weight" in state_keys:
969
+ num_out_ch = state_dict["conv_last.weight"].shape[0]
970
+ else:
971
+ num_out_ch = num_in_ch
972
+
973
+ upscale = 1
974
+ if upsampler == "nearest+conv":
975
+ upsample_keys = [
976
+ x for x in state_keys if "conv_up" in x and "bias" not in x
977
+ ]
978
+
979
+ for upsample_key in upsample_keys:
980
+ upscale *= 2
981
+ elif upsampler == "pixelshuffle":
982
+ upsample_keys = [
983
+ x
984
+ for x in state_keys
985
+ if "upsample" in x and "conv" not in x and "bias" not in x
986
+ ]
987
+ for upsample_key in upsample_keys:
988
+ shape = state_dict[upsample_key].shape[0]
989
+ upscale *= math.sqrt(shape // num_feat)
990
+ upscale = int(upscale)
991
+ elif upsampler == "pixelshuffledirect":
992
+ upscale = int(
993
+ math.sqrt(state_dict["upsample.0.bias"].shape[0] // num_out_ch)
994
+ )
995
+
996
+ max_layer_num = 0
997
+ max_block_num = 0
998
+ for key in state_keys:
999
+ result = re.match(r"layers.(\d*).blocks.(\d*).norm1.weight", key)
1000
+ if result:
1001
+ layer_num, block_num = result.groups()
1002
+ max_layer_num = max(max_layer_num, int(layer_num))
1003
+ max_block_num = max(max_block_num, int(block_num))
1004
+
1005
+ depth = [max_block_num + 1 for _ in range(max_layer_num + 1)]
1006
+
1007
+ if "layers.0.blocks.1.attn.temperature" in state_keys:
1008
+ num_heads_num = state_dict["layers.0.blocks.1.attn.temperature"].shape[0]
1009
+ num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
1010
+ else:
1011
+ num_heads = depth
1012
+
1013
+ embed_dim = state_dict["conv_first.weight"].shape[0]
1014
+ expansion_factor = float(
1015
+ state_dict["layers.0.blocks.0.ffn.fc1.weight"].shape[0] / embed_dim
1016
+ )
1017
+
1018
+ # TODO: could actually count the layers, but this should do
1019
+ if "layers.0.conv.4.weight" in state_keys:
1020
+ resi_connection = "3conv"
1021
+ else:
1022
+ resi_connection = "1conv"
1023
+
1024
+ if "layers.0.blocks.2.attn.attn_mask_0" in state_keys:
1025
+ attn_mask_0_x, attn_mask_0_y, attn_mask_0_z = state_dict[
1026
+ "layers.0.blocks.2.attn.attn_mask_0"
1027
+ ].shape
1028
+
1029
+ img_size = int(math.sqrt(attn_mask_0_x * attn_mask_0_y))
1030
+
1031
+ if "layers.0.blocks.0.attn.attns.0.rpe_biases" in state_keys:
1032
+ split_sizes = (
1033
+ state_dict["layers.0.blocks.0.attn.attns.0.rpe_biases"][-1] + 1
1034
+ )
1035
+ split_size = [int(x) for x in split_sizes]
1036
+
1037
+ self.in_nc = num_in_ch
1038
+ self.out_nc = num_out_ch
1039
+ self.num_feat = num_feat
1040
+ self.embed_dim = embed_dim
1041
+ self.num_heads = num_heads
1042
+ self.depth = depth
1043
+ self.scale = upscale
1044
+ self.upsampler = upsampler
1045
+ self.img_size = img_size
1046
+ self.img_range = img_range
1047
+ self.expansion_factor = expansion_factor
1048
+ self.resi_connection = resi_connection
1049
+ self.split_size = split_size
1050
+
1051
+ self.supports_fp16 = False # Too much weirdness to support this at the moment
1052
+ self.supports_bfp16 = True
1053
+ self.min_size_restriction = 16
1054
+
1055
+ num_in_ch = in_chans
1056
+ num_out_ch = in_chans
1057
+ num_feat = 64
1058
+ self.img_range = img_range
1059
+ if in_chans == 3:
1060
+ rgb_mean = (0.4488, 0.4371, 0.4040)
1061
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
1062
+ else:
1063
+ self.mean = torch.zeros(1, 1, 1, 1)
1064
+ self.upscale = upscale
1065
+ self.upsampler = upsampler
1066
+
1067
+ # ------------------------- 1, Shallow Feature Extraction ------------------------- #
1068
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
1069
+
1070
+ # ------------------------- 2, Deep Feature Extraction ------------------------- #
1071
+ self.num_layers = len(depth)
1072
+ self.use_chk = use_chk
1073
+ self.num_features = (
1074
+ self.embed_dim
1075
+ ) = embed_dim # num_features for consistency with other models
1076
+ heads = num_heads
1077
+
1078
+ self.before_RG = nn.Sequential(
1079
+ Rearrange("b c h w -> b (h w) c"), nn.LayerNorm(embed_dim)
1080
+ )
1081
+
1082
+ curr_dim = embed_dim
1083
+ dpr = [
1084
+ x.item() for x in torch.linspace(0, drop_path_rate, np.sum(depth))
1085
+ ] # stochastic depth decay rule
1086
+
1087
+ self.layers = nn.ModuleList()
1088
+ for i in range(self.num_layers):
1089
+ layer = ResidualGroup(
1090
+ dim=embed_dim,
1091
+ num_heads=heads[i],
1092
+ reso=img_size,
1093
+ split_size=split_size,
1094
+ expansion_factor=expansion_factor,
1095
+ qkv_bias=qkv_bias,
1096
+ qk_scale=qk_scale,
1097
+ drop=drop_rate,
1098
+ attn_drop=attn_drop_rate,
1099
+ drop_paths=dpr[sum(depth[:i]) : sum(depth[: i + 1])],
1100
+ act_layer=act_layer,
1101
+ norm_layer=norm_layer,
1102
+ depth=depth[i],
1103
+ use_chk=use_chk,
1104
+ resi_connection=resi_connection,
1105
+ rg_idx=i,
1106
+ )
1107
+ self.layers.append(layer)
1108
+
1109
+ self.norm = norm_layer(curr_dim)
1110
+ # build the last conv layer in deep feature extraction
1111
+ if resi_connection == "1conv":
1112
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1113
+ elif resi_connection == "3conv":
1114
+ # to save parameters and memory
1115
+ self.conv_after_body = nn.Sequential(
1116
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
1117
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1118
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
1119
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1120
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
1121
+ )
1122
+
1123
+ # ------------------------- 3, Reconstruction ------------------------- #
1124
+ if self.upsampler == "pixelshuffle":
1125
+ # for classical SR
1126
+ self.conv_before_upsample = nn.Sequential(
1127
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1128
+ )
1129
+ self.upsample = Upsample(upscale, num_feat)
1130
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1131
+ elif self.upsampler == "pixelshuffledirect":
1132
+ # for lightweight SR (to save parameters)
1133
+ self.upsample = UpsampleOneStep(
1134
+ upscale, embed_dim, num_out_ch, (img_size, img_size)
1135
+ )
1136
+
1137
+ self.apply(self._init_weights)
1138
+ self.load_state_dict(state_dict, strict=True)
1139
+
1140
+ def _init_weights(self, m):
1141
+ if isinstance(m, nn.Linear):
1142
+ trunc_normal_(m.weight, std=0.02)
1143
+ if isinstance(m, nn.Linear) and m.bias is not None:
1144
+ nn.init.constant_(m.bias, 0)
1145
+ elif isinstance(
1146
+ m, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm, nn.InstanceNorm2d)
1147
+ ):
1148
+ nn.init.constant_(m.bias, 0)
1149
+ nn.init.constant_(m.weight, 1.0)
1150
+
1151
+ def forward_features(self, x):
1152
+ _, _, H, W = x.shape
1153
+ x_size = [H, W]
1154
+ x = self.before_RG(x)
1155
+ for layer in self.layers:
1156
+ x = layer(x, x_size)
1157
+ x = self.norm(x)
1158
+ x = rearrange(x, "b (h w) c -> b c h w", h=H, w=W)
1159
+
1160
+ return x
1161
+
1162
+ def forward(self, x):
1163
+ """
1164
+ Input: x: (B, C, H, W)
1165
+ """
1166
+ self.mean = self.mean.type_as(x)
1167
+ x = (x - self.mean) * self.img_range
1168
+
1169
+ if self.upsampler == "pixelshuffle":
1170
+ # for image SR
1171
+ x = self.conv_first(x)
1172
+ x = self.conv_after_body(self.forward_features(x)) + x
1173
+ x = self.conv_before_upsample(x)
1174
+ x = self.conv_last(self.upsample(x))
1175
+ elif self.upsampler == "pixelshuffledirect":
1176
+ # for lightweight SR
1177
+ x = self.conv_first(x)
1178
+ x = self.conv_after_body(self.forward_features(x)) + x
1179
+ x = self.upsample(x)
1180
+
1181
+ x = x / self.img_range + self.mean
1182
+ return x
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/HAT.py ADDED
@@ -0,0 +1,1277 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # HAT from https://github.com/XPixelGroup/HAT/blob/main/hat/archs/hat_arch.py
3
+ import math
4
+ import re
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+
11
+ from .timm.helpers import to_2tuple
12
+ from .timm.weight_init import trunc_normal_
13
+
14
+
15
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
16
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
17
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
18
+ """
19
+ if drop_prob == 0.0 or not training:
20
+ return x
21
+ keep_prob = 1 - drop_prob
22
+ shape = (x.shape[0],) + (1,) * (
23
+ x.ndim - 1
24
+ ) # work with diff dim tensors, not just 2D ConvNets
25
+ random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
26
+ random_tensor.floor_() # binarize
27
+ output = x.div(keep_prob) * random_tensor
28
+ return output
29
+
30
+
31
+ class DropPath(nn.Module):
32
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
33
+ From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
34
+ """
35
+
36
+ def __init__(self, drop_prob=None):
37
+ super(DropPath, self).__init__()
38
+ self.drop_prob = drop_prob
39
+
40
+ def forward(self, x):
41
+ return drop_path(x, self.drop_prob, self.training) # type: ignore
42
+
43
+
44
+ class ChannelAttention(nn.Module):
45
+ """Channel attention used in RCAN.
46
+ Args:
47
+ num_feat (int): Channel number of intermediate features.
48
+ squeeze_factor (int): Channel squeeze factor. Default: 16.
49
+ """
50
+
51
+ def __init__(self, num_feat, squeeze_factor=16):
52
+ super(ChannelAttention, self).__init__()
53
+ self.attention = nn.Sequential(
54
+ nn.AdaptiveAvgPool2d(1),
55
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
56
+ nn.ReLU(inplace=True),
57
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
58
+ nn.Sigmoid(),
59
+ )
60
+
61
+ def forward(self, x):
62
+ y = self.attention(x)
63
+ return x * y
64
+
65
+
66
+ class CAB(nn.Module):
67
+ def __init__(self, num_feat, compress_ratio=3, squeeze_factor=30):
68
+ super(CAB, self).__init__()
69
+
70
+ self.cab = nn.Sequential(
71
+ nn.Conv2d(num_feat, num_feat // compress_ratio, 3, 1, 1),
72
+ nn.GELU(),
73
+ nn.Conv2d(num_feat // compress_ratio, num_feat, 3, 1, 1),
74
+ ChannelAttention(num_feat, squeeze_factor),
75
+ )
76
+
77
+ def forward(self, x):
78
+ return self.cab(x)
79
+
80
+
81
+ class Mlp(nn.Module):
82
+ def __init__(
83
+ self,
84
+ in_features,
85
+ hidden_features=None,
86
+ out_features=None,
87
+ act_layer=nn.GELU,
88
+ drop=0.0,
89
+ ):
90
+ super().__init__()
91
+ out_features = out_features or in_features
92
+ hidden_features = hidden_features or in_features
93
+ self.fc1 = nn.Linear(in_features, hidden_features)
94
+ self.act = act_layer()
95
+ self.fc2 = nn.Linear(hidden_features, out_features)
96
+ self.drop = nn.Dropout(drop)
97
+
98
+ def forward(self, x):
99
+ x = self.fc1(x)
100
+ x = self.act(x)
101
+ x = self.drop(x)
102
+ x = self.fc2(x)
103
+ x = self.drop(x)
104
+ return x
105
+
106
+
107
+ def window_partition(x, window_size):
108
+ """
109
+ Args:
110
+ x: (b, h, w, c)
111
+ window_size (int): window size
112
+ Returns:
113
+ windows: (num_windows*b, window_size, window_size, c)
114
+ """
115
+ b, h, w, c = x.shape
116
+ x = x.view(b, h // window_size, window_size, w // window_size, window_size, c)
117
+ windows = (
118
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, c)
119
+ )
120
+ return windows
121
+
122
+
123
+ def window_reverse(windows, window_size, h, w):
124
+ """
125
+ Args:
126
+ windows: (num_windows*b, window_size, window_size, c)
127
+ window_size (int): Window size
128
+ h (int): Height of image
129
+ w (int): Width of image
130
+ Returns:
131
+ x: (b, h, w, c)
132
+ """
133
+ b = int(windows.shape[0] / (h * w / window_size / window_size))
134
+ x = windows.view(
135
+ b, h // window_size, w // window_size, window_size, window_size, -1
136
+ )
137
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(b, h, w, -1)
138
+ return x
139
+
140
+
141
+ class WindowAttention(nn.Module):
142
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
143
+ It supports both of shifted and non-shifted window.
144
+ Args:
145
+ dim (int): Number of input channels.
146
+ window_size (tuple[int]): The height and width of the window.
147
+ num_heads (int): Number of attention heads.
148
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
149
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
150
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
151
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
152
+ """
153
+
154
+ def __init__(
155
+ self,
156
+ dim,
157
+ window_size,
158
+ num_heads,
159
+ qkv_bias=True,
160
+ qk_scale=None,
161
+ attn_drop=0.0,
162
+ proj_drop=0.0,
163
+ ):
164
+ super().__init__()
165
+ self.dim = dim
166
+ self.window_size = window_size # Wh, Ww
167
+ self.num_heads = num_heads
168
+ head_dim = dim // num_heads
169
+ self.scale = qk_scale or head_dim**-0.5
170
+
171
+ # define a parameter table of relative position bias
172
+ self.relative_position_bias_table = nn.Parameter( # type: ignore
173
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
174
+ ) # 2*Wh-1 * 2*Ww-1, nH
175
+
176
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
177
+ self.attn_drop = nn.Dropout(attn_drop)
178
+ self.proj = nn.Linear(dim, dim)
179
+
180
+ self.proj_drop = nn.Dropout(proj_drop)
181
+
182
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
183
+ self.softmax = nn.Softmax(dim=-1)
184
+
185
+ def forward(self, x, rpi, mask=None):
186
+ """
187
+ Args:
188
+ x: input features with shape of (num_windows*b, n, c)
189
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
190
+ """
191
+ b_, n, c = x.shape
192
+ qkv = (
193
+ self.qkv(x)
194
+ .reshape(b_, n, 3, self.num_heads, c // self.num_heads)
195
+ .permute(2, 0, 3, 1, 4)
196
+ )
197
+ q, k, v = (
198
+ qkv[0],
199
+ qkv[1],
200
+ qkv[2],
201
+ ) # make torchscript happy (cannot use tensor as tuple)
202
+
203
+ q = q * self.scale
204
+ attn = q @ k.transpose(-2, -1)
205
+
206
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
207
+ self.window_size[0] * self.window_size[1],
208
+ self.window_size[0] * self.window_size[1],
209
+ -1,
210
+ ) # Wh*Ww,Wh*Ww,nH
211
+ relative_position_bias = relative_position_bias.permute(
212
+ 2, 0, 1
213
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
214
+ attn = attn + relative_position_bias.unsqueeze(0)
215
+
216
+ if mask is not None:
217
+ nw = mask.shape[0]
218
+ attn = attn.view(b_ // nw, nw, self.num_heads, n, n) + mask.unsqueeze(
219
+ 1
220
+ ).unsqueeze(0)
221
+ attn = attn.view(-1, self.num_heads, n, n)
222
+ attn = self.softmax(attn)
223
+ else:
224
+ attn = self.softmax(attn)
225
+
226
+ attn = self.attn_drop(attn)
227
+
228
+ x = (attn @ v).transpose(1, 2).reshape(b_, n, c)
229
+ x = self.proj(x)
230
+ x = self.proj_drop(x)
231
+ return x
232
+
233
+
234
+ class HAB(nn.Module):
235
+ r"""Hybrid Attention Block.
236
+ Args:
237
+ dim (int): Number of input channels.
238
+ input_resolution (tuple[int]): Input resolution.
239
+ num_heads (int): Number of attention heads.
240
+ window_size (int): Window size.
241
+ shift_size (int): Shift size for SW-MSA.
242
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
243
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
244
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
245
+ drop (float, optional): Dropout rate. Default: 0.0
246
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
247
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
248
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
249
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
250
+ """
251
+
252
+ def __init__(
253
+ self,
254
+ dim,
255
+ input_resolution,
256
+ num_heads,
257
+ window_size=7,
258
+ shift_size=0,
259
+ compress_ratio=3,
260
+ squeeze_factor=30,
261
+ conv_scale=0.01,
262
+ mlp_ratio=4.0,
263
+ qkv_bias=True,
264
+ qk_scale=None,
265
+ drop=0.0,
266
+ attn_drop=0.0,
267
+ drop_path=0.0,
268
+ act_layer=nn.GELU,
269
+ norm_layer=nn.LayerNorm,
270
+ ):
271
+ super().__init__()
272
+ self.dim = dim
273
+ self.input_resolution = input_resolution
274
+ self.num_heads = num_heads
275
+ self.window_size = window_size
276
+ self.shift_size = shift_size
277
+ self.mlp_ratio = mlp_ratio
278
+ if min(self.input_resolution) <= self.window_size:
279
+ # if window size is larger than input resolution, we don't partition windows
280
+ self.shift_size = 0
281
+ self.window_size = min(self.input_resolution)
282
+ assert (
283
+ 0 <= self.shift_size < self.window_size
284
+ ), "shift_size must in 0-window_size"
285
+
286
+ self.norm1 = norm_layer(dim)
287
+ self.attn = WindowAttention(
288
+ dim,
289
+ window_size=to_2tuple(self.window_size),
290
+ num_heads=num_heads,
291
+ qkv_bias=qkv_bias,
292
+ qk_scale=qk_scale,
293
+ attn_drop=attn_drop,
294
+ proj_drop=drop,
295
+ )
296
+
297
+ self.conv_scale = conv_scale
298
+ self.conv_block = CAB(
299
+ num_feat=dim, compress_ratio=compress_ratio, squeeze_factor=squeeze_factor
300
+ )
301
+
302
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
303
+ self.norm2 = norm_layer(dim)
304
+ mlp_hidden_dim = int(dim * mlp_ratio)
305
+ self.mlp = Mlp(
306
+ in_features=dim,
307
+ hidden_features=mlp_hidden_dim,
308
+ act_layer=act_layer,
309
+ drop=drop,
310
+ )
311
+
312
+ def forward(self, x, x_size, rpi_sa, attn_mask):
313
+ h, w = x_size
314
+ b, _, c = x.shape
315
+ # assert seq_len == h * w, "input feature has wrong size"
316
+
317
+ shortcut = x
318
+ x = self.norm1(x)
319
+ x = x.view(b, h, w, c)
320
+
321
+ # Conv_X
322
+ conv_x = self.conv_block(x.permute(0, 3, 1, 2))
323
+ conv_x = conv_x.permute(0, 2, 3, 1).contiguous().view(b, h * w, c)
324
+
325
+ # cyclic shift
326
+ if self.shift_size > 0:
327
+ shifted_x = torch.roll(
328
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
329
+ )
330
+ attn_mask = attn_mask
331
+ else:
332
+ shifted_x = x
333
+ attn_mask = None
334
+
335
+ # partition windows
336
+ x_windows = window_partition(
337
+ shifted_x, self.window_size
338
+ ) # nw*b, window_size, window_size, c
339
+ x_windows = x_windows.view(
340
+ -1, self.window_size * self.window_size, c
341
+ ) # nw*b, window_size*window_size, c
342
+
343
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
344
+ attn_windows = self.attn(x_windows, rpi=rpi_sa, mask=attn_mask)
345
+
346
+ # merge windows
347
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, c)
348
+ shifted_x = window_reverse(attn_windows, self.window_size, h, w) # b h' w' c
349
+
350
+ # reverse cyclic shift
351
+ if self.shift_size > 0:
352
+ attn_x = torch.roll(
353
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
354
+ )
355
+ else:
356
+ attn_x = shifted_x
357
+ attn_x = attn_x.view(b, h * w, c)
358
+
359
+ # FFN
360
+ x = shortcut + self.drop_path(attn_x) + conv_x * self.conv_scale
361
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
362
+
363
+ return x
364
+
365
+
366
+ class PatchMerging(nn.Module):
367
+ r"""Patch Merging Layer.
368
+ Args:
369
+ input_resolution (tuple[int]): Resolution of input feature.
370
+ dim (int): Number of input channels.
371
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
372
+ """
373
+
374
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
375
+ super().__init__()
376
+ self.input_resolution = input_resolution
377
+ self.dim = dim
378
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
379
+ self.norm = norm_layer(4 * dim)
380
+
381
+ def forward(self, x):
382
+ """
383
+ x: b, h*w, c
384
+ """
385
+ h, w = self.input_resolution
386
+ b, seq_len, c = x.shape
387
+ assert seq_len == h * w, "input feature has wrong size"
388
+ assert h % 2 == 0 and w % 2 == 0, f"x size ({h}*{w}) are not even."
389
+
390
+ x = x.view(b, h, w, c)
391
+
392
+ x0 = x[:, 0::2, 0::2, :] # b h/2 w/2 c
393
+ x1 = x[:, 1::2, 0::2, :] # b h/2 w/2 c
394
+ x2 = x[:, 0::2, 1::2, :] # b h/2 w/2 c
395
+ x3 = x[:, 1::2, 1::2, :] # b h/2 w/2 c
396
+ x = torch.cat([x0, x1, x2, x3], -1) # b h/2 w/2 4*c
397
+ x = x.view(b, -1, 4 * c) # b h/2*w/2 4*c
398
+
399
+ x = self.norm(x)
400
+ x = self.reduction(x)
401
+
402
+ return x
403
+
404
+
405
+ class OCAB(nn.Module):
406
+ # overlapping cross-attention block
407
+
408
+ def __init__(
409
+ self,
410
+ dim,
411
+ input_resolution,
412
+ window_size,
413
+ overlap_ratio,
414
+ num_heads,
415
+ qkv_bias=True,
416
+ qk_scale=None,
417
+ mlp_ratio=2,
418
+ norm_layer=nn.LayerNorm,
419
+ ):
420
+ super().__init__()
421
+ self.dim = dim
422
+ self.input_resolution = input_resolution
423
+ self.window_size = window_size
424
+ self.num_heads = num_heads
425
+ head_dim = dim // num_heads
426
+ self.scale = qk_scale or head_dim**-0.5
427
+ self.overlap_win_size = int(window_size * overlap_ratio) + window_size
428
+
429
+ self.norm1 = norm_layer(dim)
430
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
431
+ self.unfold = nn.Unfold(
432
+ kernel_size=(self.overlap_win_size, self.overlap_win_size),
433
+ stride=window_size,
434
+ padding=(self.overlap_win_size - window_size) // 2,
435
+ )
436
+
437
+ # define a parameter table of relative position bias
438
+ self.relative_position_bias_table = nn.Parameter( # type: ignore
439
+ torch.zeros(
440
+ (window_size + self.overlap_win_size - 1)
441
+ * (window_size + self.overlap_win_size - 1),
442
+ num_heads,
443
+ )
444
+ ) # 2*Wh-1 * 2*Ww-1, nH
445
+
446
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
447
+ self.softmax = nn.Softmax(dim=-1)
448
+
449
+ self.proj = nn.Linear(dim, dim)
450
+
451
+ self.norm2 = norm_layer(dim)
452
+ mlp_hidden_dim = int(dim * mlp_ratio)
453
+ self.mlp = Mlp(
454
+ in_features=dim, hidden_features=mlp_hidden_dim, act_layer=nn.GELU
455
+ )
456
+
457
+ def forward(self, x, x_size, rpi):
458
+ h, w = x_size
459
+ b, _, c = x.shape
460
+
461
+ shortcut = x
462
+ x = self.norm1(x)
463
+ x = x.view(b, h, w, c)
464
+
465
+ qkv = self.qkv(x).reshape(b, h, w, 3, c).permute(3, 0, 4, 1, 2) # 3, b, c, h, w
466
+ q = qkv[0].permute(0, 2, 3, 1) # b, h, w, c
467
+ kv = torch.cat((qkv[1], qkv[2]), dim=1) # b, 2*c, h, w
468
+
469
+ # partition windows
470
+ q_windows = window_partition(
471
+ q, self.window_size
472
+ ) # nw*b, window_size, window_size, c
473
+ q_windows = q_windows.view(
474
+ -1, self.window_size * self.window_size, c
475
+ ) # nw*b, window_size*window_size, c
476
+
477
+ kv_windows = self.unfold(kv) # b, c*w*w, nw
478
+ kv_windows = rearrange(
479
+ kv_windows,
480
+ "b (nc ch owh oww) nw -> nc (b nw) (owh oww) ch",
481
+ nc=2,
482
+ ch=c,
483
+ owh=self.overlap_win_size,
484
+ oww=self.overlap_win_size,
485
+ ).contiguous() # 2, nw*b, ow*ow, c
486
+ # Do the above rearrangement without the rearrange function
487
+ # kv_windows = kv_windows.view(
488
+ # 2, b, self.overlap_win_size, self.overlap_win_size, c, -1
489
+ # )
490
+ # kv_windows = kv_windows.permute(0, 5, 1, 2, 3, 4).contiguous()
491
+ # kv_windows = kv_windows.view(
492
+ # 2, -1, self.overlap_win_size * self.overlap_win_size, c
493
+ # )
494
+
495
+ k_windows, v_windows = kv_windows[0], kv_windows[1] # nw*b, ow*ow, c
496
+
497
+ b_, nq, _ = q_windows.shape
498
+ _, n, _ = k_windows.shape
499
+ d = self.dim // self.num_heads
500
+ q = q_windows.reshape(b_, nq, self.num_heads, d).permute(
501
+ 0, 2, 1, 3
502
+ ) # nw*b, nH, nq, d
503
+ k = k_windows.reshape(b_, n, self.num_heads, d).permute(
504
+ 0, 2, 1, 3
505
+ ) # nw*b, nH, n, d
506
+ v = v_windows.reshape(b_, n, self.num_heads, d).permute(
507
+ 0, 2, 1, 3
508
+ ) # nw*b, nH, n, d
509
+
510
+ q = q * self.scale
511
+ attn = q @ k.transpose(-2, -1)
512
+
513
+ relative_position_bias = self.relative_position_bias_table[rpi.view(-1)].view(
514
+ self.window_size * self.window_size,
515
+ self.overlap_win_size * self.overlap_win_size,
516
+ -1,
517
+ ) # ws*ws, wse*wse, nH
518
+ relative_position_bias = relative_position_bias.permute(
519
+ 2, 0, 1
520
+ ).contiguous() # nH, ws*ws, wse*wse
521
+ attn = attn + relative_position_bias.unsqueeze(0)
522
+
523
+ attn = self.softmax(attn)
524
+ attn_windows = (attn @ v).transpose(1, 2).reshape(b_, nq, self.dim)
525
+
526
+ # merge windows
527
+ attn_windows = attn_windows.view(
528
+ -1, self.window_size, self.window_size, self.dim
529
+ )
530
+ x = window_reverse(attn_windows, self.window_size, h, w) # b h w c
531
+ x = x.view(b, h * w, self.dim)
532
+
533
+ x = self.proj(x) + shortcut
534
+
535
+ x = x + self.mlp(self.norm2(x))
536
+ return x
537
+
538
+
539
+ class AttenBlocks(nn.Module):
540
+ """A series of attention blocks for one RHAG.
541
+ Args:
542
+ dim (int): Number of input channels.
543
+ input_resolution (tuple[int]): Input resolution.
544
+ depth (int): Number of blocks.
545
+ num_heads (int): Number of attention heads.
546
+ window_size (int): Local window size.
547
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
548
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
549
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
550
+ drop (float, optional): Dropout rate. Default: 0.0
551
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
552
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
553
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
554
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
555
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
556
+ """
557
+
558
+ def __init__(
559
+ self,
560
+ dim,
561
+ input_resolution,
562
+ depth,
563
+ num_heads,
564
+ window_size,
565
+ compress_ratio,
566
+ squeeze_factor,
567
+ conv_scale,
568
+ overlap_ratio,
569
+ mlp_ratio=4.0,
570
+ qkv_bias=True,
571
+ qk_scale=None,
572
+ drop=0.0,
573
+ attn_drop=0.0,
574
+ drop_path=0.0,
575
+ norm_layer=nn.LayerNorm,
576
+ downsample=None,
577
+ use_checkpoint=False,
578
+ ):
579
+ super().__init__()
580
+ self.dim = dim
581
+ self.input_resolution = input_resolution
582
+ self.depth = depth
583
+ self.use_checkpoint = use_checkpoint
584
+
585
+ # build blocks
586
+ self.blocks = nn.ModuleList(
587
+ [
588
+ HAB(
589
+ dim=dim,
590
+ input_resolution=input_resolution,
591
+ num_heads=num_heads,
592
+ window_size=window_size,
593
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
594
+ compress_ratio=compress_ratio,
595
+ squeeze_factor=squeeze_factor,
596
+ conv_scale=conv_scale,
597
+ mlp_ratio=mlp_ratio,
598
+ qkv_bias=qkv_bias,
599
+ qk_scale=qk_scale,
600
+ drop=drop,
601
+ attn_drop=attn_drop,
602
+ drop_path=drop_path[i]
603
+ if isinstance(drop_path, list)
604
+ else drop_path,
605
+ norm_layer=norm_layer,
606
+ )
607
+ for i in range(depth)
608
+ ]
609
+ )
610
+
611
+ # OCAB
612
+ self.overlap_attn = OCAB(
613
+ dim=dim,
614
+ input_resolution=input_resolution,
615
+ window_size=window_size,
616
+ overlap_ratio=overlap_ratio,
617
+ num_heads=num_heads,
618
+ qkv_bias=qkv_bias,
619
+ qk_scale=qk_scale,
620
+ mlp_ratio=mlp_ratio, # type: ignore
621
+ norm_layer=norm_layer,
622
+ )
623
+
624
+ # patch merging layer
625
+ if downsample is not None:
626
+ self.downsample = downsample(
627
+ input_resolution, dim=dim, norm_layer=norm_layer
628
+ )
629
+ else:
630
+ self.downsample = None
631
+
632
+ def forward(self, x, x_size, params):
633
+ for blk in self.blocks:
634
+ x = blk(x, x_size, params["rpi_sa"], params["attn_mask"])
635
+
636
+ x = self.overlap_attn(x, x_size, params["rpi_oca"])
637
+
638
+ if self.downsample is not None:
639
+ x = self.downsample(x)
640
+ return x
641
+
642
+
643
+ class RHAG(nn.Module):
644
+ """Residual Hybrid Attention Group (RHAG).
645
+ Args:
646
+ dim (int): Number of input channels.
647
+ input_resolution (tuple[int]): Input resolution.
648
+ depth (int): Number of blocks.
649
+ num_heads (int): Number of attention heads.
650
+ window_size (int): Local window size.
651
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
652
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
653
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
654
+ drop (float, optional): Dropout rate. Default: 0.0
655
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
656
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
657
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
658
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
659
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
660
+ img_size: Input image size.
661
+ patch_size: Patch size.
662
+ resi_connection: The convolutional block before residual connection.
663
+ """
664
+
665
+ def __init__(
666
+ self,
667
+ dim,
668
+ input_resolution,
669
+ depth,
670
+ num_heads,
671
+ window_size,
672
+ compress_ratio,
673
+ squeeze_factor,
674
+ conv_scale,
675
+ overlap_ratio,
676
+ mlp_ratio=4.0,
677
+ qkv_bias=True,
678
+ qk_scale=None,
679
+ drop=0.0,
680
+ attn_drop=0.0,
681
+ drop_path=0.0,
682
+ norm_layer=nn.LayerNorm,
683
+ downsample=None,
684
+ use_checkpoint=False,
685
+ img_size=224,
686
+ patch_size=4,
687
+ resi_connection="1conv",
688
+ ):
689
+ super(RHAG, self).__init__()
690
+
691
+ self.dim = dim
692
+ self.input_resolution = input_resolution
693
+
694
+ self.residual_group = AttenBlocks(
695
+ dim=dim,
696
+ input_resolution=input_resolution,
697
+ depth=depth,
698
+ num_heads=num_heads,
699
+ window_size=window_size,
700
+ compress_ratio=compress_ratio,
701
+ squeeze_factor=squeeze_factor,
702
+ conv_scale=conv_scale,
703
+ overlap_ratio=overlap_ratio,
704
+ mlp_ratio=mlp_ratio,
705
+ qkv_bias=qkv_bias,
706
+ qk_scale=qk_scale,
707
+ drop=drop,
708
+ attn_drop=attn_drop,
709
+ drop_path=drop_path,
710
+ norm_layer=norm_layer,
711
+ downsample=downsample,
712
+ use_checkpoint=use_checkpoint,
713
+ )
714
+
715
+ if resi_connection == "1conv":
716
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
717
+ elif resi_connection == "identity":
718
+ self.conv = nn.Identity()
719
+
720
+ self.patch_embed = PatchEmbed(
721
+ img_size=img_size,
722
+ patch_size=patch_size,
723
+ in_chans=0,
724
+ embed_dim=dim,
725
+ norm_layer=None,
726
+ )
727
+
728
+ self.patch_unembed = PatchUnEmbed(
729
+ img_size=img_size,
730
+ patch_size=patch_size,
731
+ in_chans=0,
732
+ embed_dim=dim,
733
+ norm_layer=None,
734
+ )
735
+
736
+ def forward(self, x, x_size, params):
737
+ return (
738
+ self.patch_embed(
739
+ self.conv(
740
+ self.patch_unembed(self.residual_group(x, x_size, params), x_size)
741
+ )
742
+ )
743
+ + x
744
+ )
745
+
746
+
747
+ class PatchEmbed(nn.Module):
748
+ r"""Image to Patch Embedding
749
+ Args:
750
+ img_size (int): Image size. Default: 224.
751
+ patch_size (int): Patch token size. Default: 4.
752
+ in_chans (int): Number of input image channels. Default: 3.
753
+ embed_dim (int): Number of linear projection output channels. Default: 96.
754
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
755
+ """
756
+
757
+ def __init__(
758
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
759
+ ):
760
+ super().__init__()
761
+ img_size = to_2tuple(img_size)
762
+ patch_size = to_2tuple(patch_size)
763
+ patches_resolution = [
764
+ img_size[0] // patch_size[0], # type: ignore
765
+ img_size[1] // patch_size[1], # type: ignore
766
+ ]
767
+ self.img_size = img_size
768
+ self.patch_size = patch_size
769
+ self.patches_resolution = patches_resolution
770
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
771
+
772
+ self.in_chans = in_chans
773
+ self.embed_dim = embed_dim
774
+
775
+ if norm_layer is not None:
776
+ self.norm = norm_layer(embed_dim)
777
+ else:
778
+ self.norm = None
779
+
780
+ def forward(self, x):
781
+ x = x.flatten(2).transpose(1, 2) # b Ph*Pw c
782
+ if self.norm is not None:
783
+ x = self.norm(x)
784
+ return x
785
+
786
+
787
+ class PatchUnEmbed(nn.Module):
788
+ r"""Image to Patch Unembedding
789
+ Args:
790
+ img_size (int): Image size. Default: 224.
791
+ patch_size (int): Patch token size. Default: 4.
792
+ in_chans (int): Number of input image channels. Default: 3.
793
+ embed_dim (int): Number of linear projection output channels. Default: 96.
794
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
795
+ """
796
+
797
+ def __init__(
798
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
799
+ ):
800
+ super().__init__()
801
+ img_size = to_2tuple(img_size)
802
+ patch_size = to_2tuple(patch_size)
803
+ patches_resolution = [
804
+ img_size[0] // patch_size[0], # type: ignore
805
+ img_size[1] // patch_size[1], # type: ignore
806
+ ]
807
+ self.img_size = img_size
808
+ self.patch_size = patch_size
809
+ self.patches_resolution = patches_resolution
810
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
811
+
812
+ self.in_chans = in_chans
813
+ self.embed_dim = embed_dim
814
+
815
+ def forward(self, x, x_size):
816
+ x = (
817
+ x.transpose(1, 2)
818
+ .contiguous()
819
+ .view(x.shape[0], self.embed_dim, x_size[0], x_size[1])
820
+ ) # b Ph*Pw c
821
+ return x
822
+
823
+
824
+ class Upsample(nn.Sequential):
825
+ """Upsample module.
826
+ Args:
827
+ scale (int): Scale factor. Supported scales: 2^n and 3.
828
+ num_feat (int): Channel number of intermediate features.
829
+ """
830
+
831
+ def __init__(self, scale, num_feat):
832
+ m = []
833
+ if (scale & (scale - 1)) == 0: # scale = 2^n
834
+ for _ in range(int(math.log(scale, 2))):
835
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
836
+ m.append(nn.PixelShuffle(2))
837
+ elif scale == 3:
838
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
839
+ m.append(nn.PixelShuffle(3))
840
+ else:
841
+ raise ValueError(
842
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
843
+ )
844
+ super(Upsample, self).__init__(*m)
845
+
846
+
847
+ class HAT(nn.Module):
848
+ r"""Hybrid Attention Transformer
849
+ A PyTorch implementation of : `Activating More Pixels in Image Super-Resolution Transformer`.
850
+ Some codes are based on SwinIR.
851
+ Args:
852
+ img_size (int | tuple(int)): Input image size. Default 64
853
+ patch_size (int | tuple(int)): Patch size. Default: 1
854
+ in_chans (int): Number of input image channels. Default: 3
855
+ embed_dim (int): Patch embedding dimension. Default: 96
856
+ depths (tuple(int)): Depth of each Swin Transformer layer.
857
+ num_heads (tuple(int)): Number of attention heads in different layers.
858
+ window_size (int): Window size. Default: 7
859
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
860
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
861
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
862
+ drop_rate (float): Dropout rate. Default: 0
863
+ attn_drop_rate (float): Attention dropout rate. Default: 0
864
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
865
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
866
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
867
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
868
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
869
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
870
+ img_range: Image range. 1. or 255.
871
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
872
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
873
+ """
874
+
875
+ def __init__(
876
+ self,
877
+ state_dict,
878
+ **kwargs,
879
+ ):
880
+ super(HAT, self).__init__()
881
+
882
+ # Defaults
883
+ img_size = 64
884
+ patch_size = 1
885
+ in_chans = 3
886
+ embed_dim = 96
887
+ depths = (6, 6, 6, 6)
888
+ num_heads = (6, 6, 6, 6)
889
+ window_size = 7
890
+ compress_ratio = 3
891
+ squeeze_factor = 30
892
+ conv_scale = 0.01
893
+ overlap_ratio = 0.5
894
+ mlp_ratio = 4.0
895
+ qkv_bias = True
896
+ qk_scale = None
897
+ drop_rate = 0.0
898
+ attn_drop_rate = 0.0
899
+ drop_path_rate = 0.1
900
+ norm_layer = nn.LayerNorm
901
+ ape = False
902
+ patch_norm = True
903
+ use_checkpoint = False
904
+ upscale = 2
905
+ img_range = 1.0
906
+ upsampler = ""
907
+ resi_connection = "1conv"
908
+
909
+ self.state = state_dict
910
+ self.model_arch = "HAT"
911
+ self.sub_type = "SR"
912
+ self.supports_fp16 = False
913
+ self.support_bf16 = True
914
+ self.min_size_restriction = 16
915
+
916
+ state_keys = list(state_dict.keys())
917
+
918
+ num_feat = state_dict["conv_last.weight"].shape[1]
919
+ in_chans = state_dict["conv_first.weight"].shape[1]
920
+ num_out_ch = state_dict["conv_last.weight"].shape[0]
921
+ embed_dim = state_dict["conv_first.weight"].shape[0]
922
+
923
+ if "conv_before_upsample.0.weight" in state_keys:
924
+ if "conv_up1.weight" in state_keys:
925
+ upsampler = "nearest+conv"
926
+ else:
927
+ upsampler = "pixelshuffle"
928
+ supports_fp16 = False
929
+ elif "upsample.0.weight" in state_keys:
930
+ upsampler = "pixelshuffledirect"
931
+ else:
932
+ upsampler = ""
933
+ upscale = 1
934
+ if upsampler == "nearest+conv":
935
+ upsample_keys = [
936
+ x for x in state_keys if "conv_up" in x and "bias" not in x
937
+ ]
938
+
939
+ for upsample_key in upsample_keys:
940
+ upscale *= 2
941
+ elif upsampler == "pixelshuffle":
942
+ upsample_keys = [
943
+ x
944
+ for x in state_keys
945
+ if "upsample" in x and "conv" not in x and "bias" not in x
946
+ ]
947
+ for upsample_key in upsample_keys:
948
+ shape = self.state[upsample_key].shape[0]
949
+ upscale *= math.sqrt(shape // num_feat)
950
+ upscale = int(upscale)
951
+ elif upsampler == "pixelshuffledirect":
952
+ upscale = int(
953
+ math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
954
+ )
955
+
956
+ max_layer_num = 0
957
+ max_block_num = 0
958
+ for key in state_keys:
959
+ result = re.match(
960
+ r"layers.(\d*).residual_group.blocks.(\d*).conv_block.cab.0.weight", key
961
+ )
962
+ if result:
963
+ layer_num, block_num = result.groups()
964
+ max_layer_num = max(max_layer_num, int(layer_num))
965
+ max_block_num = max(max_block_num, int(block_num))
966
+
967
+ depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
968
+
969
+ if (
970
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
971
+ in state_keys
972
+ ):
973
+ num_heads_num = self.state[
974
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
975
+ ].shape[-1]
976
+ num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
977
+ else:
978
+ num_heads = depths
979
+
980
+ mlp_ratio = float(
981
+ self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
982
+ / embed_dim
983
+ )
984
+
985
+ # TODO: could actually count the layers, but this should do
986
+ if "layers.0.conv.4.weight" in state_keys:
987
+ resi_connection = "3conv"
988
+ else:
989
+ resi_connection = "1conv"
990
+
991
+ window_size = int(math.sqrt(self.state["relative_position_index_SA"].shape[0]))
992
+
993
+ # Not sure if this is needed or used at all anywhere in HAT's config
994
+ if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
995
+ img_size = int(
996
+ math.sqrt(
997
+ self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
998
+ )
999
+ * window_size
1000
+ )
1001
+
1002
+ self.window_size = window_size
1003
+ self.shift_size = window_size // 2
1004
+ self.overlap_ratio = overlap_ratio
1005
+
1006
+ self.in_nc = in_chans
1007
+ self.out_nc = num_out_ch
1008
+ self.num_feat = num_feat
1009
+ self.embed_dim = embed_dim
1010
+ self.num_heads = num_heads
1011
+ self.depths = depths
1012
+ self.window_size = window_size
1013
+ self.mlp_ratio = mlp_ratio
1014
+ self.scale = upscale
1015
+ self.upsampler = upsampler
1016
+ self.img_size = img_size
1017
+ self.img_range = img_range
1018
+ self.resi_connection = resi_connection
1019
+
1020
+ num_in_ch = in_chans
1021
+ # num_out_ch = in_chans
1022
+ # num_feat = 64
1023
+ self.img_range = img_range
1024
+ if in_chans == 3:
1025
+ rgb_mean = (0.4488, 0.4371, 0.4040)
1026
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
1027
+ else:
1028
+ self.mean = torch.zeros(1, 1, 1, 1)
1029
+ self.upscale = upscale
1030
+ self.upsampler = upsampler
1031
+
1032
+ # relative position index
1033
+ relative_position_index_SA = self.calculate_rpi_sa()
1034
+ relative_position_index_OCA = self.calculate_rpi_oca()
1035
+ self.register_buffer("relative_position_index_SA", relative_position_index_SA)
1036
+ self.register_buffer("relative_position_index_OCA", relative_position_index_OCA)
1037
+
1038
+ # ------------------------- 1, shallow feature extraction ------------------------- #
1039
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
1040
+
1041
+ # ------------------------- 2, deep feature extraction ------------------------- #
1042
+ self.num_layers = len(depths)
1043
+ self.embed_dim = embed_dim
1044
+ self.ape = ape
1045
+ self.patch_norm = patch_norm
1046
+ self.num_features = embed_dim
1047
+ self.mlp_ratio = mlp_ratio
1048
+
1049
+ # split image into non-overlapping patches
1050
+ self.patch_embed = PatchEmbed(
1051
+ img_size=img_size,
1052
+ patch_size=patch_size,
1053
+ in_chans=embed_dim,
1054
+ embed_dim=embed_dim,
1055
+ norm_layer=norm_layer if self.patch_norm else None,
1056
+ )
1057
+ num_patches = self.patch_embed.num_patches
1058
+ patches_resolution = self.patch_embed.patches_resolution
1059
+ self.patches_resolution = patches_resolution
1060
+
1061
+ # merge non-overlapping patches into image
1062
+ self.patch_unembed = PatchUnEmbed(
1063
+ img_size=img_size,
1064
+ patch_size=patch_size,
1065
+ in_chans=embed_dim,
1066
+ embed_dim=embed_dim,
1067
+ norm_layer=norm_layer if self.patch_norm else None,
1068
+ )
1069
+
1070
+ # absolute position embedding
1071
+ if self.ape:
1072
+ self.absolute_pos_embed = nn.Parameter( # type: ignore[arg-type]
1073
+ torch.zeros(1, num_patches, embed_dim)
1074
+ )
1075
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
1076
+
1077
+ self.pos_drop = nn.Dropout(p=drop_rate)
1078
+
1079
+ # stochastic depth
1080
+ dpr = [
1081
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
1082
+ ] # stochastic depth decay rule
1083
+
1084
+ # build Residual Hybrid Attention Groups (RHAG)
1085
+ self.layers = nn.ModuleList()
1086
+ for i_layer in range(self.num_layers):
1087
+ layer = RHAG(
1088
+ dim=embed_dim,
1089
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1090
+ depth=depths[i_layer],
1091
+ num_heads=num_heads[i_layer],
1092
+ window_size=window_size,
1093
+ compress_ratio=compress_ratio,
1094
+ squeeze_factor=squeeze_factor,
1095
+ conv_scale=conv_scale,
1096
+ overlap_ratio=overlap_ratio,
1097
+ mlp_ratio=self.mlp_ratio,
1098
+ qkv_bias=qkv_bias,
1099
+ qk_scale=qk_scale,
1100
+ drop=drop_rate,
1101
+ attn_drop=attn_drop_rate,
1102
+ drop_path=dpr[
1103
+ sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) # type: ignore
1104
+ ], # no impact on SR results
1105
+ norm_layer=norm_layer,
1106
+ downsample=None,
1107
+ use_checkpoint=use_checkpoint,
1108
+ img_size=img_size,
1109
+ patch_size=patch_size,
1110
+ resi_connection=resi_connection,
1111
+ )
1112
+ self.layers.append(layer)
1113
+ self.norm = norm_layer(self.num_features)
1114
+
1115
+ # build the last conv layer in deep feature extraction
1116
+ if resi_connection == "1conv":
1117
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1118
+ elif resi_connection == "identity":
1119
+ self.conv_after_body = nn.Identity()
1120
+
1121
+ # ------------------------- 3, high quality image reconstruction ------------------------- #
1122
+ if self.upsampler == "pixelshuffle":
1123
+ # for classical SR
1124
+ self.conv_before_upsample = nn.Sequential(
1125
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1126
+ )
1127
+ self.upsample = Upsample(upscale, num_feat)
1128
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1129
+
1130
+ self.apply(self._init_weights)
1131
+ self.load_state_dict(self.state, strict=False)
1132
+
1133
+ def _init_weights(self, m):
1134
+ if isinstance(m, nn.Linear):
1135
+ trunc_normal_(m.weight, std=0.02)
1136
+ if isinstance(m, nn.Linear) and m.bias is not None:
1137
+ nn.init.constant_(m.bias, 0)
1138
+ elif isinstance(m, nn.LayerNorm):
1139
+ nn.init.constant_(m.bias, 0)
1140
+ nn.init.constant_(m.weight, 1.0)
1141
+
1142
+ def calculate_rpi_sa(self):
1143
+ # calculate relative position index for SA
1144
+ coords_h = torch.arange(self.window_size)
1145
+ coords_w = torch.arange(self.window_size)
1146
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
1147
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
1148
+ relative_coords = (
1149
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
1150
+ ) # 2, Wh*Ww, Wh*Ww
1151
+ relative_coords = relative_coords.permute(
1152
+ 1, 2, 0
1153
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
1154
+ relative_coords[:, :, 0] += self.window_size - 1 # shift to start from 0
1155
+ relative_coords[:, :, 1] += self.window_size - 1
1156
+ relative_coords[:, :, 0] *= 2 * self.window_size - 1
1157
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
1158
+ return relative_position_index
1159
+
1160
+ def calculate_rpi_oca(self):
1161
+ # calculate relative position index for OCA
1162
+ window_size_ori = self.window_size
1163
+ window_size_ext = self.window_size + int(self.overlap_ratio * self.window_size)
1164
+
1165
+ coords_h = torch.arange(window_size_ori)
1166
+ coords_w = torch.arange(window_size_ori)
1167
+ coords_ori = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, ws, ws
1168
+ coords_ori_flatten = torch.flatten(coords_ori, 1) # 2, ws*ws
1169
+
1170
+ coords_h = torch.arange(window_size_ext)
1171
+ coords_w = torch.arange(window_size_ext)
1172
+ coords_ext = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, wse, wse
1173
+ coords_ext_flatten = torch.flatten(coords_ext, 1) # 2, wse*wse
1174
+
1175
+ relative_coords = (
1176
+ coords_ext_flatten[:, None, :] - coords_ori_flatten[:, :, None]
1177
+ ) # 2, ws*ws, wse*wse
1178
+
1179
+ relative_coords = relative_coords.permute(
1180
+ 1, 2, 0
1181
+ ).contiguous() # ws*ws, wse*wse, 2
1182
+ relative_coords[:, :, 0] += (
1183
+ window_size_ori - window_size_ext + 1
1184
+ ) # shift to start from 0
1185
+ relative_coords[:, :, 1] += window_size_ori - window_size_ext + 1
1186
+
1187
+ relative_coords[:, :, 0] *= window_size_ori + window_size_ext - 1
1188
+ relative_position_index = relative_coords.sum(-1)
1189
+ return relative_position_index
1190
+
1191
+ def calculate_mask(self, x_size):
1192
+ # calculate attention mask for SW-MSA
1193
+ h, w = x_size
1194
+ img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
1195
+ h_slices = (
1196
+ slice(0, -self.window_size),
1197
+ slice(-self.window_size, -self.shift_size),
1198
+ slice(-self.shift_size, None),
1199
+ )
1200
+ w_slices = (
1201
+ slice(0, -self.window_size),
1202
+ slice(-self.window_size, -self.shift_size),
1203
+ slice(-self.shift_size, None),
1204
+ )
1205
+ cnt = 0
1206
+ for h in h_slices:
1207
+ for w in w_slices:
1208
+ img_mask[:, h, w, :] = cnt
1209
+ cnt += 1
1210
+
1211
+ mask_windows = window_partition(
1212
+ img_mask, self.window_size
1213
+ ) # nw, window_size, window_size, 1
1214
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
1215
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
1216
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
1217
+ attn_mask == 0, float(0.0)
1218
+ )
1219
+
1220
+ return attn_mask
1221
+
1222
+ @torch.jit.ignore # type: ignore
1223
+ def no_weight_decay(self):
1224
+ return {"absolute_pos_embed"}
1225
+
1226
+ @torch.jit.ignore # type: ignore
1227
+ def no_weight_decay_keywords(self):
1228
+ return {"relative_position_bias_table"}
1229
+
1230
+ def check_image_size(self, x):
1231
+ _, _, h, w = x.size()
1232
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
1233
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
1234
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1235
+ return x
1236
+
1237
+ def forward_features(self, x):
1238
+ x_size = (x.shape[2], x.shape[3])
1239
+
1240
+ # Calculate attention mask and relative position index in advance to speed up inference.
1241
+ # The original code is very time-cosuming for large window size.
1242
+ attn_mask = self.calculate_mask(x_size).to(x.device)
1243
+ params = {
1244
+ "attn_mask": attn_mask,
1245
+ "rpi_sa": self.relative_position_index_SA,
1246
+ "rpi_oca": self.relative_position_index_OCA,
1247
+ }
1248
+
1249
+ x = self.patch_embed(x)
1250
+ if self.ape:
1251
+ x = x + self.absolute_pos_embed
1252
+ x = self.pos_drop(x)
1253
+
1254
+ for layer in self.layers:
1255
+ x = layer(x, x_size, params)
1256
+
1257
+ x = self.norm(x) # b seq_len c
1258
+ x = self.patch_unembed(x, x_size)
1259
+
1260
+ return x
1261
+
1262
+ def forward(self, x):
1263
+ H, W = x.shape[2:]
1264
+ self.mean = self.mean.type_as(x)
1265
+ x = (x - self.mean) * self.img_range
1266
+ x = self.check_image_size(x)
1267
+
1268
+ if self.upsampler == "pixelshuffle":
1269
+ # for classical SR
1270
+ x = self.conv_first(x)
1271
+ x = self.conv_after_body(self.forward_features(x)) + x
1272
+ x = self.conv_before_upsample(x)
1273
+ x = self.conv_last(self.upsample(x))
1274
+
1275
+ x = x / self.img_range + self.mean
1276
+
1277
+ return x[:, :, : H * self.upscale, : W * self.upscale]
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-DAT ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-ESRGAN ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-HAT ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2022 Xiangyu Chen
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-RealESRGAN ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ BSD 3-Clause License
2
+
3
+ Copyright (c) 2021, Xintao Wang
4
+ All rights reserved.
5
+
6
+ Redistribution and use in source and binary forms, with or without
7
+ modification, are permitted provided that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright notice, this
10
+ list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright notice,
13
+ this list of conditions and the following disclaimer in the documentation
14
+ and/or other materials provided with the distribution.
15
+
16
+ 3. Neither the name of the copyright holder nor the names of its
17
+ contributors may be used to endorse or promote products derived from
18
+ this software without specific prior written permission.
19
+
20
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
21
+ AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
22
+ IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
23
+ DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
24
+ FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
25
+ DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
26
+ SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
27
+ CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
28
+ OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
29
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SCUNet ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2022 Kai Zhang (cskaizhang@gmail.com, https://cszn.github.io/). All rights reserved.
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SPSR ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2018-2022 BasicSR Authors
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SwiftSRGAN ADDED
@@ -0,0 +1,121 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Creative Commons Legal Code
2
+
3
+ CC0 1.0 Universal
4
+
5
+ CREATIVE COMMONS CORPORATION IS NOT A LAW FIRM AND DOES NOT PROVIDE
6
+ LEGAL SERVICES. DISTRIBUTION OF THIS DOCUMENT DOES NOT CREATE AN
7
+ ATTORNEY-CLIENT RELATIONSHIP. CREATIVE COMMONS PROVIDES THIS
8
+ INFORMATION ON AN "AS-IS" BASIS. CREATIVE COMMONS MAKES NO WARRANTIES
9
+ REGARDING THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS
10
+ PROVIDED HEREUNDER, AND DISCLAIMS LIABILITY FOR DAMAGES RESULTING FROM
11
+ THE USE OF THIS DOCUMENT OR THE INFORMATION OR WORKS PROVIDED
12
+ HEREUNDER.
13
+
14
+ Statement of Purpose
15
+
16
+ The laws of most jurisdictions throughout the world automatically confer
17
+ exclusive Copyright and Related Rights (defined below) upon the creator
18
+ and subsequent owner(s) (each and all, an "owner") of an original work of
19
+ authorship and/or a database (each, a "Work").
20
+
21
+ Certain owners wish to permanently relinquish those rights to a Work for
22
+ the purpose of contributing to a commons of creative, cultural and
23
+ scientific works ("Commons") that the public can reliably and without fear
24
+ of later claims of infringement build upon, modify, incorporate in other
25
+ works, reuse and redistribute as freely as possible in any form whatsoever
26
+ and for any purposes, including without limitation commercial purposes.
27
+ These owners may contribute to the Commons to promote the ideal of a free
28
+ culture and the further production of creative, cultural and scientific
29
+ works, or to gain reputation or greater distribution for their Work in
30
+ part through the use and efforts of others.
31
+
32
+ For these and/or other purposes and motivations, and without any
33
+ expectation of additional consideration or compensation, the person
34
+ associating CC0 with a Work (the "Affirmer"), to the extent that he or she
35
+ is an owner of Copyright and Related Rights in the Work, voluntarily
36
+ elects to apply CC0 to the Work and publicly distribute the Work under its
37
+ terms, with knowledge of his or her Copyright and Related Rights in the
38
+ Work and the meaning and intended legal effect of CC0 on those rights.
39
+
40
+ 1. Copyright and Related Rights. A Work made available under CC0 may be
41
+ protected by copyright and related or neighboring rights ("Copyright and
42
+ Related Rights"). Copyright and Related Rights include, but are not
43
+ limited to, the following:
44
+
45
+ i. the right to reproduce, adapt, distribute, perform, display,
46
+ communicate, and translate a Work;
47
+ ii. moral rights retained by the original author(s) and/or performer(s);
48
+ iii. publicity and privacy rights pertaining to a person's image or
49
+ likeness depicted in a Work;
50
+ iv. rights protecting against unfair competition in regards to a Work,
51
+ subject to the limitations in paragraph 4(a), below;
52
+ v. rights protecting the extraction, dissemination, use and reuse of data
53
+ in a Work;
54
+ vi. database rights (such as those arising under Directive 96/9/EC of the
55
+ European Parliament and of the Council of 11 March 1996 on the legal
56
+ protection of databases, and under any national implementation
57
+ thereof, including any amended or successor version of such
58
+ directive); and
59
+ vii. other similar, equivalent or corresponding rights throughout the
60
+ world based on applicable law or treaty, and any national
61
+ implementations thereof.
62
+
63
+ 2. Waiver. To the greatest extent permitted by, but not in contravention
64
+ of, applicable law, Affirmer hereby overtly, fully, permanently,
65
+ irrevocably and unconditionally waives, abandons, and surrenders all of
66
+ Affirmer's Copyright and Related Rights and associated claims and causes
67
+ of action, whether now known or unknown (including existing as well as
68
+ future claims and causes of action), in the Work (i) in all territories
69
+ worldwide, (ii) for the maximum duration provided by applicable law or
70
+ treaty (including future time extensions), (iii) in any current or future
71
+ medium and for any number of copies, and (iv) for any purpose whatsoever,
72
+ including without limitation commercial, advertising or promotional
73
+ purposes (the "Waiver"). Affirmer makes the Waiver for the benefit of each
74
+ member of the public at large and to the detriment of Affirmer's heirs and
75
+ successors, fully intending that such Waiver shall not be subject to
76
+ revocation, rescission, cancellation, termination, or any other legal or
77
+ equitable action to disrupt the quiet enjoyment of the Work by the public
78
+ as contemplated by Affirmer's express Statement of Purpose.
79
+
80
+ 3. Public License Fallback. Should any part of the Waiver for any reason
81
+ be judged legally invalid or ineffective under applicable law, then the
82
+ Waiver shall be preserved to the maximum extent permitted taking into
83
+ account Affirmer's express Statement of Purpose. In addition, to the
84
+ extent the Waiver is so judged Affirmer hereby grants to each affected
85
+ person a royalty-free, non transferable, non sublicensable, non exclusive,
86
+ irrevocable and unconditional license to exercise Affirmer's Copyright and
87
+ Related Rights in the Work (i) in all territories worldwide, (ii) for the
88
+ maximum duration provided by applicable law or treaty (including future
89
+ time extensions), (iii) in any current or future medium and for any number
90
+ of copies, and (iv) for any purpose whatsoever, including without
91
+ limitation commercial, advertising or promotional purposes (the
92
+ "License"). The License shall be deemed effective as of the date CC0 was
93
+ applied by Affirmer to the Work. Should any part of the License for any
94
+ reason be judged legally invalid or ineffective under applicable law, such
95
+ partial invalidity or ineffectiveness shall not invalidate the remainder
96
+ of the License, and in such case Affirmer hereby affirms that he or she
97
+ will not (i) exercise any of his or her remaining Copyright and Related
98
+ Rights in the Work or (ii) assert any associated claims and causes of
99
+ action with respect to the Work, in either case contrary to Affirmer's
100
+ express Statement of Purpose.
101
+
102
+ 4. Limitations and Disclaimers.
103
+
104
+ a. No trademark or patent rights held by Affirmer are waived, abandoned,
105
+ surrendered, licensed or otherwise affected by this document.
106
+ b. Affirmer offers the Work as-is and makes no representations or
107
+ warranties of any kind concerning the Work, express, implied,
108
+ statutory or otherwise, including without limitation warranties of
109
+ title, merchantability, fitness for a particular purpose, non
110
+ infringement, or the absence of latent or other defects, accuracy, or
111
+ the present or absence of errors, whether or not discoverable, all to
112
+ the greatest extent permissible under applicable law.
113
+ c. Affirmer disclaims responsibility for clearing rights of other persons
114
+ that may apply to the Work or any use thereof, including without
115
+ limitation any person's Copyright and Related Rights in the Work.
116
+ Further, Affirmer disclaims responsibility for obtaining any necessary
117
+ consents, permissions or other rights required for any use of the
118
+ Work.
119
+ d. Affirmer understands and acknowledges that Creative Commons is not a
120
+ party to this document and has no duty or obligation with respect to
121
+ this CC0 or use of the Work.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-Swin2SR ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2021] [SwinIR Authors]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-SwinIR ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2021] [SwinIR Authors]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LICENSE-lama ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2021] Samsung Research
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/LaMa.py ADDED
@@ -0,0 +1,694 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ """
3
+ Model adapted from advimman's lama project: https://github.com/advimman/lama
4
+ """
5
+
6
+ # Fast Fourier Convolution NeurIPS 2020
7
+ # original implementation https://github.com/pkumivision/FFC/blob/main/model_zoo/ffc.py
8
+ # paper https://proceedings.neurips.cc/paper/2020/file/2fd5d41ec6cfab47e32164d5624269b1-Paper.pdf
9
+
10
+ from typing import List
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from torchvision.transforms.functional import InterpolationMode, rotate
16
+
17
+
18
+ class LearnableSpatialTransformWrapper(nn.Module):
19
+ def __init__(self, impl, pad_coef=0.5, angle_init_range=80, train_angle=True):
20
+ super().__init__()
21
+ self.impl = impl
22
+ self.angle = torch.rand(1) * angle_init_range
23
+ if train_angle:
24
+ self.angle = nn.Parameter(self.angle, requires_grad=True)
25
+ self.pad_coef = pad_coef
26
+
27
+ def forward(self, x):
28
+ if torch.is_tensor(x):
29
+ return self.inverse_transform(self.impl(self.transform(x)), x)
30
+ elif isinstance(x, tuple):
31
+ x_trans = tuple(self.transform(elem) for elem in x)
32
+ y_trans = self.impl(x_trans)
33
+ return tuple(
34
+ self.inverse_transform(elem, orig_x) for elem, orig_x in zip(y_trans, x)
35
+ )
36
+ else:
37
+ raise ValueError(f"Unexpected input type {type(x)}")
38
+
39
+ def transform(self, x):
40
+ height, width = x.shape[2:]
41
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
42
+ x_padded = F.pad(x, [pad_w, pad_w, pad_h, pad_h], mode="reflect")
43
+ x_padded_rotated = rotate(
44
+ x_padded, self.angle.to(x_padded), InterpolationMode.BILINEAR, fill=0
45
+ )
46
+
47
+ return x_padded_rotated
48
+
49
+ def inverse_transform(self, y_padded_rotated, orig_x):
50
+ height, width = orig_x.shape[2:]
51
+ pad_h, pad_w = int(height * self.pad_coef), int(width * self.pad_coef)
52
+
53
+ y_padded = rotate(
54
+ y_padded_rotated,
55
+ -self.angle.to(y_padded_rotated),
56
+ InterpolationMode.BILINEAR,
57
+ fill=0,
58
+ )
59
+ y_height, y_width = y_padded.shape[2:]
60
+ y = y_padded[:, :, pad_h : y_height - pad_h, pad_w : y_width - pad_w]
61
+ return y
62
+
63
+
64
+ class SELayer(nn.Module):
65
+ def __init__(self, channel, reduction=16):
66
+ super(SELayer, self).__init__()
67
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
68
+ self.fc = nn.Sequential(
69
+ nn.Linear(channel, channel // reduction, bias=False),
70
+ nn.ReLU(inplace=True),
71
+ nn.Linear(channel // reduction, channel, bias=False),
72
+ nn.Sigmoid(),
73
+ )
74
+
75
+ def forward(self, x):
76
+ b, c, _, _ = x.size()
77
+ y = self.avg_pool(x).view(b, c)
78
+ y = self.fc(y).view(b, c, 1, 1)
79
+ res = x * y.expand_as(x)
80
+ return res
81
+
82
+
83
+ class FourierUnit(nn.Module):
84
+ def __init__(
85
+ self,
86
+ in_channels,
87
+ out_channels,
88
+ groups=1,
89
+ spatial_scale_factor=None,
90
+ spatial_scale_mode="bilinear",
91
+ spectral_pos_encoding=False,
92
+ use_se=False,
93
+ se_kwargs=None,
94
+ ffc3d=False,
95
+ fft_norm="ortho",
96
+ ):
97
+ # bn_layer not used
98
+ super(FourierUnit, self).__init__()
99
+ self.groups = groups
100
+
101
+ self.conv_layer = torch.nn.Conv2d(
102
+ in_channels=in_channels * 2 + (2 if spectral_pos_encoding else 0),
103
+ out_channels=out_channels * 2,
104
+ kernel_size=1,
105
+ stride=1,
106
+ padding=0,
107
+ groups=self.groups,
108
+ bias=False,
109
+ )
110
+ self.bn = torch.nn.BatchNorm2d(out_channels * 2)
111
+ self.relu = torch.nn.ReLU(inplace=True)
112
+
113
+ # squeeze and excitation block
114
+ self.use_se = use_se
115
+ if use_se:
116
+ if se_kwargs is None:
117
+ se_kwargs = {}
118
+ self.se = SELayer(self.conv_layer.in_channels, **se_kwargs)
119
+
120
+ self.spatial_scale_factor = spatial_scale_factor
121
+ self.spatial_scale_mode = spatial_scale_mode
122
+ self.spectral_pos_encoding = spectral_pos_encoding
123
+ self.ffc3d = ffc3d
124
+ self.fft_norm = fft_norm
125
+
126
+ def forward(self, x):
127
+ half_check = False
128
+ if x.type() == "torch.cuda.HalfTensor":
129
+ # half only works on gpu anyway
130
+ half_check = True
131
+
132
+ batch = x.shape[0]
133
+
134
+ if self.spatial_scale_factor is not None:
135
+ orig_size = x.shape[-2:]
136
+ x = F.interpolate(
137
+ x,
138
+ scale_factor=self.spatial_scale_factor,
139
+ mode=self.spatial_scale_mode,
140
+ align_corners=False,
141
+ )
142
+
143
+ # (batch, c, h, w/2+1, 2)
144
+ fft_dim = (-3, -2, -1) if self.ffc3d else (-2, -1)
145
+ if half_check == True:
146
+ ffted = torch.fft.rfftn(
147
+ x.float(), dim=fft_dim, norm=self.fft_norm
148
+ ) # .type(torch.cuda.HalfTensor)
149
+ else:
150
+ ffted = torch.fft.rfftn(x, dim=fft_dim, norm=self.fft_norm)
151
+
152
+ ffted = torch.stack((ffted.real, ffted.imag), dim=-1)
153
+ ffted = ffted.permute(0, 1, 4, 2, 3).contiguous() # (batch, c, 2, h, w/2+1)
154
+ ffted = ffted.view(
155
+ (
156
+ batch,
157
+ -1,
158
+ )
159
+ + ffted.size()[3:]
160
+ )
161
+
162
+ if self.spectral_pos_encoding:
163
+ height, width = ffted.shape[-2:]
164
+ coords_vert = (
165
+ torch.linspace(0, 1, height)[None, None, :, None]
166
+ .expand(batch, 1, height, width)
167
+ .to(ffted)
168
+ )
169
+ coords_hor = (
170
+ torch.linspace(0, 1, width)[None, None, None, :]
171
+ .expand(batch, 1, height, width)
172
+ .to(ffted)
173
+ )
174
+ ffted = torch.cat((coords_vert, coords_hor, ffted), dim=1)
175
+
176
+ if self.use_se:
177
+ ffted = self.se(ffted)
178
+
179
+ if half_check == True:
180
+ ffted = self.conv_layer(ffted.half()) # (batch, c*2, h, w/2+1)
181
+ else:
182
+ ffted = self.conv_layer(
183
+ ffted
184
+ ) # .type(torch.cuda.FloatTensor) # (batch, c*2, h, w/2+1)
185
+
186
+ ffted = self.relu(self.bn(ffted))
187
+ # forcing to be always float
188
+ ffted = ffted.float()
189
+
190
+ ffted = (
191
+ ffted.view(
192
+ (
193
+ batch,
194
+ -1,
195
+ 2,
196
+ )
197
+ + ffted.size()[2:]
198
+ )
199
+ .permute(0, 1, 3, 4, 2)
200
+ .contiguous()
201
+ ) # (batch,c, t, h, w/2+1, 2)
202
+
203
+ ffted = torch.complex(ffted[..., 0], ffted[..., 1])
204
+
205
+ ifft_shape_slice = x.shape[-3:] if self.ffc3d else x.shape[-2:]
206
+ output = torch.fft.irfftn(
207
+ ffted, s=ifft_shape_slice, dim=fft_dim, norm=self.fft_norm
208
+ )
209
+
210
+ if half_check == True:
211
+ output = output.half()
212
+
213
+ if self.spatial_scale_factor is not None:
214
+ output = F.interpolate(
215
+ output,
216
+ size=orig_size,
217
+ mode=self.spatial_scale_mode,
218
+ align_corners=False,
219
+ )
220
+
221
+ return output
222
+
223
+
224
+ class SpectralTransform(nn.Module):
225
+ def __init__(
226
+ self,
227
+ in_channels,
228
+ out_channels,
229
+ stride=1,
230
+ groups=1,
231
+ enable_lfu=True,
232
+ separable_fu=False,
233
+ **fu_kwargs,
234
+ ):
235
+ # bn_layer not used
236
+ super(SpectralTransform, self).__init__()
237
+ self.enable_lfu = enable_lfu
238
+ if stride == 2:
239
+ self.downsample = nn.AvgPool2d(kernel_size=(2, 2), stride=2)
240
+ else:
241
+ self.downsample = nn.Identity()
242
+
243
+ self.stride = stride
244
+ self.conv1 = nn.Sequential(
245
+ nn.Conv2d(
246
+ in_channels, out_channels // 2, kernel_size=1, groups=groups, bias=False
247
+ ),
248
+ nn.BatchNorm2d(out_channels // 2),
249
+ nn.ReLU(inplace=True),
250
+ )
251
+ fu_class = FourierUnit
252
+ self.fu = fu_class(out_channels // 2, out_channels // 2, groups, **fu_kwargs)
253
+ if self.enable_lfu:
254
+ self.lfu = fu_class(out_channels // 2, out_channels // 2, groups)
255
+ self.conv2 = torch.nn.Conv2d(
256
+ out_channels // 2, out_channels, kernel_size=1, groups=groups, bias=False
257
+ )
258
+
259
+ def forward(self, x):
260
+ x = self.downsample(x)
261
+ x = self.conv1(x)
262
+ output = self.fu(x)
263
+
264
+ if self.enable_lfu:
265
+ _, c, h, _ = x.shape
266
+ split_no = 2
267
+ split_s = h // split_no
268
+ xs = torch.cat(
269
+ torch.split(x[:, : c // 4], split_s, dim=-2), dim=1
270
+ ).contiguous()
271
+ xs = torch.cat(torch.split(xs, split_s, dim=-1), dim=1).contiguous()
272
+ xs = self.lfu(xs)
273
+ xs = xs.repeat(1, 1, split_no, split_no).contiguous()
274
+ else:
275
+ xs = 0
276
+
277
+ output = self.conv2(x + output + xs)
278
+
279
+ return output
280
+
281
+
282
+ class FFC(nn.Module):
283
+ def __init__(
284
+ self,
285
+ in_channels,
286
+ out_channels,
287
+ kernel_size,
288
+ ratio_gin,
289
+ ratio_gout,
290
+ stride=1,
291
+ padding=0,
292
+ dilation=1,
293
+ groups=1,
294
+ bias=False,
295
+ enable_lfu=True,
296
+ padding_type="reflect",
297
+ gated=False,
298
+ **spectral_kwargs,
299
+ ):
300
+ super(FFC, self).__init__()
301
+
302
+ assert stride == 1 or stride == 2, "Stride should be 1 or 2."
303
+ self.stride = stride
304
+
305
+ in_cg = int(in_channels * ratio_gin)
306
+ in_cl = in_channels - in_cg
307
+ out_cg = int(out_channels * ratio_gout)
308
+ out_cl = out_channels - out_cg
309
+ # groups_g = 1 if groups == 1 else int(groups * ratio_gout)
310
+ # groups_l = 1 if groups == 1 else groups - groups_g
311
+
312
+ self.ratio_gin = ratio_gin
313
+ self.ratio_gout = ratio_gout
314
+ self.global_in_num = in_cg
315
+
316
+ module = nn.Identity if in_cl == 0 or out_cl == 0 else nn.Conv2d
317
+ self.convl2l = module(
318
+ in_cl,
319
+ out_cl,
320
+ kernel_size,
321
+ stride,
322
+ padding,
323
+ dilation,
324
+ groups,
325
+ bias,
326
+ padding_mode=padding_type,
327
+ )
328
+ module = nn.Identity if in_cl == 0 or out_cg == 0 else nn.Conv2d
329
+ self.convl2g = module(
330
+ in_cl,
331
+ out_cg,
332
+ kernel_size,
333
+ stride,
334
+ padding,
335
+ dilation,
336
+ groups,
337
+ bias,
338
+ padding_mode=padding_type,
339
+ )
340
+ module = nn.Identity if in_cg == 0 or out_cl == 0 else nn.Conv2d
341
+ self.convg2l = module(
342
+ in_cg,
343
+ out_cl,
344
+ kernel_size,
345
+ stride,
346
+ padding,
347
+ dilation,
348
+ groups,
349
+ bias,
350
+ padding_mode=padding_type,
351
+ )
352
+ module = nn.Identity if in_cg == 0 or out_cg == 0 else SpectralTransform
353
+ self.convg2g = module(
354
+ in_cg,
355
+ out_cg,
356
+ stride,
357
+ 1 if groups == 1 else groups // 2,
358
+ enable_lfu,
359
+ **spectral_kwargs,
360
+ )
361
+
362
+ self.gated = gated
363
+ module = (
364
+ nn.Identity if in_cg == 0 or out_cl == 0 or not self.gated else nn.Conv2d
365
+ )
366
+ self.gate = module(in_channels, 2, 1)
367
+
368
+ def forward(self, x):
369
+ x_l, x_g = x if type(x) is tuple else (x, 0)
370
+ out_xl, out_xg = 0, 0
371
+
372
+ if self.gated:
373
+ total_input_parts = [x_l]
374
+ if torch.is_tensor(x_g):
375
+ total_input_parts.append(x_g)
376
+ total_input = torch.cat(total_input_parts, dim=1)
377
+
378
+ gates = torch.sigmoid(self.gate(total_input))
379
+ g2l_gate, l2g_gate = gates.chunk(2, dim=1)
380
+ else:
381
+ g2l_gate, l2g_gate = 1, 1
382
+
383
+ if self.ratio_gout != 1:
384
+ out_xl = self.convl2l(x_l) + self.convg2l(x_g) * g2l_gate
385
+ if self.ratio_gout != 0:
386
+ out_xg = self.convl2g(x_l) * l2g_gate + self.convg2g(x_g)
387
+
388
+ return out_xl, out_xg
389
+
390
+
391
+ class FFC_BN_ACT(nn.Module):
392
+ def __init__(
393
+ self,
394
+ in_channels,
395
+ out_channels,
396
+ kernel_size,
397
+ ratio_gin,
398
+ ratio_gout,
399
+ stride=1,
400
+ padding=0,
401
+ dilation=1,
402
+ groups=1,
403
+ bias=False,
404
+ norm_layer=nn.BatchNorm2d,
405
+ activation_layer=nn.Identity,
406
+ padding_type="reflect",
407
+ enable_lfu=True,
408
+ **kwargs,
409
+ ):
410
+ super(FFC_BN_ACT, self).__init__()
411
+ self.ffc = FFC(
412
+ in_channels,
413
+ out_channels,
414
+ kernel_size,
415
+ ratio_gin,
416
+ ratio_gout,
417
+ stride,
418
+ padding,
419
+ dilation,
420
+ groups,
421
+ bias,
422
+ enable_lfu,
423
+ padding_type=padding_type,
424
+ **kwargs,
425
+ )
426
+ lnorm = nn.Identity if ratio_gout == 1 else norm_layer
427
+ gnorm = nn.Identity if ratio_gout == 0 else norm_layer
428
+ global_channels = int(out_channels * ratio_gout)
429
+ self.bn_l = lnorm(out_channels - global_channels)
430
+ self.bn_g = gnorm(global_channels)
431
+
432
+ lact = nn.Identity if ratio_gout == 1 else activation_layer
433
+ gact = nn.Identity if ratio_gout == 0 else activation_layer
434
+ self.act_l = lact(inplace=True)
435
+ self.act_g = gact(inplace=True)
436
+
437
+ def forward(self, x):
438
+ x_l, x_g = self.ffc(x)
439
+ x_l = self.act_l(self.bn_l(x_l))
440
+ x_g = self.act_g(self.bn_g(x_g))
441
+ return x_l, x_g
442
+
443
+
444
+ class FFCResnetBlock(nn.Module):
445
+ def __init__(
446
+ self,
447
+ dim,
448
+ padding_type,
449
+ norm_layer,
450
+ activation_layer=nn.ReLU,
451
+ dilation=1,
452
+ spatial_transform_kwargs=None,
453
+ inline=False,
454
+ **conv_kwargs,
455
+ ):
456
+ super().__init__()
457
+ self.conv1 = FFC_BN_ACT(
458
+ dim,
459
+ dim,
460
+ kernel_size=3,
461
+ padding=dilation,
462
+ dilation=dilation,
463
+ norm_layer=norm_layer,
464
+ activation_layer=activation_layer,
465
+ padding_type=padding_type,
466
+ **conv_kwargs,
467
+ )
468
+ self.conv2 = FFC_BN_ACT(
469
+ dim,
470
+ dim,
471
+ kernel_size=3,
472
+ padding=dilation,
473
+ dilation=dilation,
474
+ norm_layer=norm_layer,
475
+ activation_layer=activation_layer,
476
+ padding_type=padding_type,
477
+ **conv_kwargs,
478
+ )
479
+ if spatial_transform_kwargs is not None:
480
+ self.conv1 = LearnableSpatialTransformWrapper(
481
+ self.conv1, **spatial_transform_kwargs
482
+ )
483
+ self.conv2 = LearnableSpatialTransformWrapper(
484
+ self.conv2, **spatial_transform_kwargs
485
+ )
486
+ self.inline = inline
487
+
488
+ def forward(self, x):
489
+ if self.inline:
490
+ x_l, x_g = (
491
+ x[:, : -self.conv1.ffc.global_in_num],
492
+ x[:, -self.conv1.ffc.global_in_num :],
493
+ )
494
+ else:
495
+ x_l, x_g = x if type(x) is tuple else (x, 0)
496
+
497
+ id_l, id_g = x_l, x_g
498
+
499
+ x_l, x_g = self.conv1((x_l, x_g))
500
+ x_l, x_g = self.conv2((x_l, x_g))
501
+
502
+ x_l, x_g = id_l + x_l, id_g + x_g
503
+ out = x_l, x_g
504
+ if self.inline:
505
+ out = torch.cat(out, dim=1)
506
+ return out
507
+
508
+
509
+ class ConcatTupleLayer(nn.Module):
510
+ def forward(self, x):
511
+ assert isinstance(x, tuple)
512
+ x_l, x_g = x
513
+ assert torch.is_tensor(x_l) or torch.is_tensor(x_g)
514
+ if not torch.is_tensor(x_g):
515
+ return x_l
516
+ return torch.cat(x, dim=1)
517
+
518
+
519
+ class FFCResNetGenerator(nn.Module):
520
+ def __init__(
521
+ self,
522
+ input_nc,
523
+ output_nc,
524
+ ngf=64,
525
+ n_downsampling=3,
526
+ n_blocks=18,
527
+ norm_layer=nn.BatchNorm2d,
528
+ padding_type="reflect",
529
+ activation_layer=nn.ReLU,
530
+ up_norm_layer=nn.BatchNorm2d,
531
+ up_activation=nn.ReLU(True),
532
+ init_conv_kwargs={},
533
+ downsample_conv_kwargs={},
534
+ resnet_conv_kwargs={},
535
+ spatial_transform_layers=None,
536
+ spatial_transform_kwargs={},
537
+ max_features=1024,
538
+ out_ffc=False,
539
+ out_ffc_kwargs={},
540
+ ):
541
+ assert n_blocks >= 0
542
+ super().__init__()
543
+ """
544
+ init_conv_kwargs = {'ratio_gin': 0, 'ratio_gout': 0, 'enable_lfu': False}
545
+ downsample_conv_kwargs = {'ratio_gin': '${generator.init_conv_kwargs.ratio_gout}', 'ratio_gout': '${generator.downsample_conv_kwargs.ratio_gin}', 'enable_lfu': False}
546
+ resnet_conv_kwargs = {'ratio_gin': 0.75, 'ratio_gout': '${generator.resnet_conv_kwargs.ratio_gin}', 'enable_lfu': False}
547
+ spatial_transform_kwargs = {}
548
+ out_ffc_kwargs = {}
549
+ """
550
+ """
551
+ print(input_nc, output_nc, ngf, n_downsampling, n_blocks, norm_layer,
552
+ padding_type, activation_layer,
553
+ up_norm_layer, up_activation,
554
+ spatial_transform_layers,
555
+ add_out_act, max_features, out_ffc, file=sys.stderr)
556
+
557
+ 4 3 64 3 18 <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
558
+ reflect <class 'torch.nn.modules.activation.ReLU'>
559
+ <class 'torch.nn.modules.batchnorm.BatchNorm2d'>
560
+ ReLU(inplace=True)
561
+ None sigmoid 1024 False
562
+ """
563
+ init_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
564
+ downsample_conv_kwargs = {"ratio_gin": 0, "ratio_gout": 0, "enable_lfu": False}
565
+ resnet_conv_kwargs = {
566
+ "ratio_gin": 0.75,
567
+ "ratio_gout": 0.75,
568
+ "enable_lfu": False,
569
+ }
570
+ spatial_transform_kwargs = {}
571
+ out_ffc_kwargs = {}
572
+
573
+ model = [
574
+ nn.ReflectionPad2d(3),
575
+ FFC_BN_ACT(
576
+ input_nc,
577
+ ngf,
578
+ kernel_size=7,
579
+ padding=0,
580
+ norm_layer=norm_layer,
581
+ activation_layer=activation_layer,
582
+ **init_conv_kwargs,
583
+ ),
584
+ ]
585
+
586
+ ### downsample
587
+ for i in range(n_downsampling):
588
+ mult = 2**i
589
+ if i == n_downsampling - 1:
590
+ cur_conv_kwargs = dict(downsample_conv_kwargs)
591
+ cur_conv_kwargs["ratio_gout"] = resnet_conv_kwargs.get("ratio_gin", 0)
592
+ else:
593
+ cur_conv_kwargs = downsample_conv_kwargs
594
+ model += [
595
+ FFC_BN_ACT(
596
+ min(max_features, ngf * mult),
597
+ min(max_features, ngf * mult * 2),
598
+ kernel_size=3,
599
+ stride=2,
600
+ padding=1,
601
+ norm_layer=norm_layer,
602
+ activation_layer=activation_layer,
603
+ **cur_conv_kwargs,
604
+ )
605
+ ]
606
+
607
+ mult = 2**n_downsampling
608
+ feats_num_bottleneck = min(max_features, ngf * mult)
609
+
610
+ ### resnet blocks
611
+ for i in range(n_blocks):
612
+ cur_resblock = FFCResnetBlock(
613
+ feats_num_bottleneck,
614
+ padding_type=padding_type,
615
+ activation_layer=activation_layer,
616
+ norm_layer=norm_layer,
617
+ **resnet_conv_kwargs,
618
+ )
619
+ if spatial_transform_layers is not None and i in spatial_transform_layers:
620
+ cur_resblock = LearnableSpatialTransformWrapper(
621
+ cur_resblock, **spatial_transform_kwargs
622
+ )
623
+ model += [cur_resblock]
624
+
625
+ model += [ConcatTupleLayer()]
626
+
627
+ ### upsample
628
+ for i in range(n_downsampling):
629
+ mult = 2 ** (n_downsampling - i)
630
+ model += [
631
+ nn.ConvTranspose2d(
632
+ min(max_features, ngf * mult),
633
+ min(max_features, int(ngf * mult / 2)),
634
+ kernel_size=3,
635
+ stride=2,
636
+ padding=1,
637
+ output_padding=1,
638
+ ),
639
+ up_norm_layer(min(max_features, int(ngf * mult / 2))),
640
+ up_activation,
641
+ ]
642
+
643
+ if out_ffc:
644
+ model += [
645
+ FFCResnetBlock(
646
+ ngf,
647
+ padding_type=padding_type,
648
+ activation_layer=activation_layer,
649
+ norm_layer=norm_layer,
650
+ inline=True,
651
+ **out_ffc_kwargs,
652
+ )
653
+ ]
654
+
655
+ model += [
656
+ nn.ReflectionPad2d(3),
657
+ nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),
658
+ ]
659
+ model.append(nn.Sigmoid())
660
+ self.model = nn.Sequential(*model)
661
+
662
+ def forward(self, image, mask):
663
+ return self.model(torch.cat([image, mask], dim=1))
664
+
665
+
666
+ class LaMa(nn.Module):
667
+ def __init__(self, state_dict) -> None:
668
+ super(LaMa, self).__init__()
669
+ self.model_arch = "LaMa"
670
+ self.sub_type = "Inpaint"
671
+ self.in_nc = 4
672
+ self.out_nc = 3
673
+ self.scale = 1
674
+
675
+ self.min_size = None
676
+ self.pad_mod = 8
677
+ self.pad_to_square = False
678
+
679
+ self.model = FFCResNetGenerator(self.in_nc, self.out_nc)
680
+ self.state = {
681
+ k.replace("generator.model", "model.model"): v
682
+ for k, v in state_dict.items()
683
+ }
684
+
685
+ self.supports_fp16 = False
686
+ self.support_bf16 = True
687
+
688
+ self.load_state_dict(self.state, strict=False)
689
+
690
+ def forward(self, img, mask):
691
+ masked_img = img * (1 - mask)
692
+ inpainted_mask = mask * self.model.forward(masked_img, mask)
693
+ result = inpainted_mask + (1 - mask) * img
694
+ return result
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/ChannelAttention.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch.nn as nn
4
+
5
+
6
+ class CA_layer(nn.Module):
7
+ def __init__(self, channel, reduction=16):
8
+ super(CA_layer, self).__init__()
9
+ # global average pooling
10
+ self.gap = nn.AdaptiveAvgPool2d(1)
11
+ self.fc = nn.Sequential(
12
+ nn.Conv2d(channel, channel // reduction, kernel_size=(1, 1), bias=False),
13
+ nn.GELU(),
14
+ nn.Conv2d(channel // reduction, channel, kernel_size=(1, 1), bias=False),
15
+ # nn.Sigmoid()
16
+ )
17
+
18
+ def forward(self, x):
19
+ y = self.fc(self.gap(x))
20
+ return x * y.expand_as(x)
21
+
22
+
23
+ class Simple_CA_layer(nn.Module):
24
+ def __init__(self, channel):
25
+ super(Simple_CA_layer, self).__init__()
26
+ self.gap = nn.AdaptiveAvgPool2d(1)
27
+ self.fc = nn.Conv2d(
28
+ in_channels=channel,
29
+ out_channels=channel,
30
+ kernel_size=1,
31
+ padding=0,
32
+ stride=1,
33
+ groups=1,
34
+ bias=True,
35
+ )
36
+
37
+ def forward(self, x):
38
+ return x * self.fc(self.gap(x))
39
+
40
+
41
+ class ECA_layer(nn.Module):
42
+ """Constructs a ECA module.
43
+ Args:
44
+ channel: Number of channels of the input feature map
45
+ k_size: Adaptive selection of kernel size
46
+ """
47
+
48
+ def __init__(self, channel):
49
+ super(ECA_layer, self).__init__()
50
+
51
+ b = 1
52
+ gamma = 2
53
+ k_size = int(abs(math.log(channel, 2) + b) / gamma)
54
+ k_size = k_size if k_size % 2 else k_size + 1
55
+ self.avg_pool = nn.AdaptiveAvgPool2d(1)
56
+ self.conv = nn.Conv1d(
57
+ 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
58
+ )
59
+ # self.sigmoid = nn.Sigmoid()
60
+
61
+ def forward(self, x):
62
+ # x: input features with shape [b, c, h, w]
63
+ # b, c, h, w = x.size()
64
+
65
+ # feature descriptor on the global spatial information
66
+ y = self.avg_pool(x)
67
+
68
+ # Two different branches of ECA module
69
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
70
+
71
+ # Multi-scale information fusion
72
+ # y = self.sigmoid(y)
73
+
74
+ return x * y.expand_as(x)
75
+
76
+
77
+ class ECA_MaxPool_layer(nn.Module):
78
+ """Constructs a ECA module.
79
+ Args:
80
+ channel: Number of channels of the input feature map
81
+ k_size: Adaptive selection of kernel size
82
+ """
83
+
84
+ def __init__(self, channel):
85
+ super(ECA_MaxPool_layer, self).__init__()
86
+
87
+ b = 1
88
+ gamma = 2
89
+ k_size = int(abs(math.log(channel, 2) + b) / gamma)
90
+ k_size = k_size if k_size % 2 else k_size + 1
91
+ self.max_pool = nn.AdaptiveMaxPool2d(1)
92
+ self.conv = nn.Conv1d(
93
+ 1, 1, kernel_size=k_size, padding=(k_size - 1) // 2, bias=False
94
+ )
95
+ # self.sigmoid = nn.Sigmoid()
96
+
97
+ def forward(self, x):
98
+ # x: input features with shape [b, c, h, w]
99
+ # b, c, h, w = x.size()
100
+
101
+ # feature descriptor on the global spatial information
102
+ y = self.max_pool(x)
103
+
104
+ # Two different branches of ECA module
105
+ y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
106
+
107
+ # Multi-scale information fusion
108
+ # y = self.sigmoid(y)
109
+
110
+ return x * y.expand_as(x)
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/OSA.py ADDED
@@ -0,0 +1,577 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ #############################################################
4
+ # File: OSA.py
5
+ # Created Date: Tuesday April 28th 2022
6
+ # Author: Chen Xuanhong
7
+ # Email: chenxuanhongzju@outlook.com
8
+ # Last Modified: Sunday, 23rd April 2023 3:07:42 pm
9
+ # Modified By: Chen Xuanhong
10
+ # Copyright (c) 2020 Shanghai Jiao Tong University
11
+ #############################################################
12
+
13
+ import torch
14
+ import torch.nn.functional as F
15
+ from einops import rearrange, repeat
16
+ from einops.layers.torch import Rearrange, Reduce
17
+ from torch import einsum, nn
18
+
19
+ from .layernorm import LayerNorm2d
20
+
21
+ # helpers
22
+
23
+
24
+ def exists(val):
25
+ return val is not None
26
+
27
+
28
+ def default(val, d):
29
+ return val if exists(val) else d
30
+
31
+
32
+ def cast_tuple(val, length=1):
33
+ return val if isinstance(val, tuple) else ((val,) * length)
34
+
35
+
36
+ # helper classes
37
+
38
+
39
+ class PreNormResidual(nn.Module):
40
+ def __init__(self, dim, fn):
41
+ super().__init__()
42
+ self.norm = nn.LayerNorm(dim)
43
+ self.fn = fn
44
+
45
+ def forward(self, x):
46
+ return self.fn(self.norm(x)) + x
47
+
48
+
49
+ class Conv_PreNormResidual(nn.Module):
50
+ def __init__(self, dim, fn):
51
+ super().__init__()
52
+ self.norm = LayerNorm2d(dim)
53
+ self.fn = fn
54
+
55
+ def forward(self, x):
56
+ return self.fn(self.norm(x)) + x
57
+
58
+
59
+ class FeedForward(nn.Module):
60
+ def __init__(self, dim, mult=2, dropout=0.0):
61
+ super().__init__()
62
+ inner_dim = int(dim * mult)
63
+ self.net = nn.Sequential(
64
+ nn.Linear(dim, inner_dim),
65
+ nn.GELU(),
66
+ nn.Dropout(dropout),
67
+ nn.Linear(inner_dim, dim),
68
+ nn.Dropout(dropout),
69
+ )
70
+
71
+ def forward(self, x):
72
+ return self.net(x)
73
+
74
+
75
+ class Conv_FeedForward(nn.Module):
76
+ def __init__(self, dim, mult=2, dropout=0.0):
77
+ super().__init__()
78
+ inner_dim = int(dim * mult)
79
+ self.net = nn.Sequential(
80
+ nn.Conv2d(dim, inner_dim, 1, 1, 0),
81
+ nn.GELU(),
82
+ nn.Dropout(dropout),
83
+ nn.Conv2d(inner_dim, dim, 1, 1, 0),
84
+ nn.Dropout(dropout),
85
+ )
86
+
87
+ def forward(self, x):
88
+ return self.net(x)
89
+
90
+
91
+ class Gated_Conv_FeedForward(nn.Module):
92
+ def __init__(self, dim, mult=1, bias=False, dropout=0.0):
93
+ super().__init__()
94
+
95
+ hidden_features = int(dim * mult)
96
+
97
+ self.project_in = nn.Conv2d(dim, hidden_features * 2, kernel_size=1, bias=bias)
98
+
99
+ self.dwconv = nn.Conv2d(
100
+ hidden_features * 2,
101
+ hidden_features * 2,
102
+ kernel_size=3,
103
+ stride=1,
104
+ padding=1,
105
+ groups=hidden_features * 2,
106
+ bias=bias,
107
+ )
108
+
109
+ self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
110
+
111
+ def forward(self, x):
112
+ x = self.project_in(x)
113
+ x1, x2 = self.dwconv(x).chunk(2, dim=1)
114
+ x = F.gelu(x1) * x2
115
+ x = self.project_out(x)
116
+ return x
117
+
118
+
119
+ # MBConv
120
+
121
+
122
+ class SqueezeExcitation(nn.Module):
123
+ def __init__(self, dim, shrinkage_rate=0.25):
124
+ super().__init__()
125
+ hidden_dim = int(dim * shrinkage_rate)
126
+
127
+ self.gate = nn.Sequential(
128
+ Reduce("b c h w -> b c", "mean"),
129
+ nn.Linear(dim, hidden_dim, bias=False),
130
+ nn.SiLU(),
131
+ nn.Linear(hidden_dim, dim, bias=False),
132
+ nn.Sigmoid(),
133
+ Rearrange("b c -> b c 1 1"),
134
+ )
135
+
136
+ def forward(self, x):
137
+ return x * self.gate(x)
138
+
139
+
140
+ class MBConvResidual(nn.Module):
141
+ def __init__(self, fn, dropout=0.0):
142
+ super().__init__()
143
+ self.fn = fn
144
+ self.dropsample = Dropsample(dropout)
145
+
146
+ def forward(self, x):
147
+ out = self.fn(x)
148
+ out = self.dropsample(out)
149
+ return out + x
150
+
151
+
152
+ class Dropsample(nn.Module):
153
+ def __init__(self, prob=0):
154
+ super().__init__()
155
+ self.prob = prob
156
+
157
+ def forward(self, x):
158
+ device = x.device
159
+
160
+ if self.prob == 0.0 or (not self.training):
161
+ return x
162
+
163
+ keep_mask = (
164
+ torch.FloatTensor((x.shape[0], 1, 1, 1), device=device).uniform_()
165
+ > self.prob
166
+ )
167
+ return x * keep_mask / (1 - self.prob)
168
+
169
+
170
+ def MBConv(
171
+ dim_in, dim_out, *, downsample, expansion_rate=4, shrinkage_rate=0.25, dropout=0.0
172
+ ):
173
+ hidden_dim = int(expansion_rate * dim_out)
174
+ stride = 2 if downsample else 1
175
+
176
+ net = nn.Sequential(
177
+ nn.Conv2d(dim_in, hidden_dim, 1),
178
+ # nn.BatchNorm2d(hidden_dim),
179
+ nn.GELU(),
180
+ nn.Conv2d(
181
+ hidden_dim, hidden_dim, 3, stride=stride, padding=1, groups=hidden_dim
182
+ ),
183
+ # nn.BatchNorm2d(hidden_dim),
184
+ nn.GELU(),
185
+ SqueezeExcitation(hidden_dim, shrinkage_rate=shrinkage_rate),
186
+ nn.Conv2d(hidden_dim, dim_out, 1),
187
+ # nn.BatchNorm2d(dim_out)
188
+ )
189
+
190
+ if dim_in == dim_out and not downsample:
191
+ net = MBConvResidual(net, dropout=dropout)
192
+
193
+ return net
194
+
195
+
196
+ # attention related classes
197
+ class Attention(nn.Module):
198
+ def __init__(
199
+ self,
200
+ dim,
201
+ dim_head=32,
202
+ dropout=0.0,
203
+ window_size=7,
204
+ with_pe=True,
205
+ ):
206
+ super().__init__()
207
+ assert (
208
+ dim % dim_head
209
+ ) == 0, "dimension should be divisible by dimension per head"
210
+
211
+ self.heads = dim // dim_head
212
+ self.scale = dim_head**-0.5
213
+ self.with_pe = with_pe
214
+
215
+ self.to_qkv = nn.Linear(dim, dim * 3, bias=False)
216
+
217
+ self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
218
+
219
+ self.to_out = nn.Sequential(
220
+ nn.Linear(dim, dim, bias=False), nn.Dropout(dropout)
221
+ )
222
+
223
+ # relative positional bias
224
+ if self.with_pe:
225
+ self.rel_pos_bias = nn.Embedding((2 * window_size - 1) ** 2, self.heads)
226
+
227
+ pos = torch.arange(window_size)
228
+ grid = torch.stack(torch.meshgrid(pos, pos))
229
+ grid = rearrange(grid, "c i j -> (i j) c")
230
+ rel_pos = rearrange(grid, "i ... -> i 1 ...") - rearrange(
231
+ grid, "j ... -> 1 j ..."
232
+ )
233
+ rel_pos += window_size - 1
234
+ rel_pos_indices = (rel_pos * torch.tensor([2 * window_size - 1, 1])).sum(
235
+ dim=-1
236
+ )
237
+
238
+ self.register_buffer("rel_pos_indices", rel_pos_indices, persistent=False)
239
+
240
+ def forward(self, x):
241
+ batch, height, width, window_height, window_width, _, device, h = (
242
+ *x.shape,
243
+ x.device,
244
+ self.heads,
245
+ )
246
+
247
+ # flatten
248
+
249
+ x = rearrange(x, "b x y w1 w2 d -> (b x y) (w1 w2) d")
250
+
251
+ # project for queries, keys, values
252
+
253
+ q, k, v = self.to_qkv(x).chunk(3, dim=-1)
254
+
255
+ # split heads
256
+
257
+ q, k, v = map(lambda t: rearrange(t, "b n (h d ) -> b h n d", h=h), (q, k, v))
258
+
259
+ # scale
260
+
261
+ q = q * self.scale
262
+
263
+ # sim
264
+
265
+ sim = einsum("b h i d, b h j d -> b h i j", q, k)
266
+
267
+ # add positional bias
268
+ if self.with_pe:
269
+ bias = self.rel_pos_bias(self.rel_pos_indices)
270
+ sim = sim + rearrange(bias, "i j h -> h i j")
271
+
272
+ # attention
273
+
274
+ attn = self.attend(sim)
275
+
276
+ # aggregate
277
+
278
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
279
+
280
+ # merge heads
281
+
282
+ out = rearrange(
283
+ out, "b h (w1 w2) d -> b w1 w2 (h d)", w1=window_height, w2=window_width
284
+ )
285
+
286
+ # combine heads out
287
+
288
+ out = self.to_out(out)
289
+ return rearrange(out, "(b x y) ... -> b x y ...", x=height, y=width)
290
+
291
+
292
+ class Block_Attention(nn.Module):
293
+ def __init__(
294
+ self,
295
+ dim,
296
+ dim_head=32,
297
+ bias=False,
298
+ dropout=0.0,
299
+ window_size=7,
300
+ with_pe=True,
301
+ ):
302
+ super().__init__()
303
+ assert (
304
+ dim % dim_head
305
+ ) == 0, "dimension should be divisible by dimension per head"
306
+
307
+ self.heads = dim // dim_head
308
+ self.ps = window_size
309
+ self.scale = dim_head**-0.5
310
+ self.with_pe = with_pe
311
+
312
+ self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
313
+ self.qkv_dwconv = nn.Conv2d(
314
+ dim * 3,
315
+ dim * 3,
316
+ kernel_size=3,
317
+ stride=1,
318
+ padding=1,
319
+ groups=dim * 3,
320
+ bias=bias,
321
+ )
322
+
323
+ self.attend = nn.Sequential(nn.Softmax(dim=-1), nn.Dropout(dropout))
324
+
325
+ self.to_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
326
+
327
+ def forward(self, x):
328
+ # project for queries, keys, values
329
+ b, c, h, w = x.shape
330
+
331
+ qkv = self.qkv_dwconv(self.qkv(x))
332
+ q, k, v = qkv.chunk(3, dim=1)
333
+
334
+ # split heads
335
+
336
+ q, k, v = map(
337
+ lambda t: rearrange(
338
+ t,
339
+ "b (h d) (x w1) (y w2) -> (b x y) h (w1 w2) d",
340
+ h=self.heads,
341
+ w1=self.ps,
342
+ w2=self.ps,
343
+ ),
344
+ (q, k, v),
345
+ )
346
+
347
+ # scale
348
+
349
+ q = q * self.scale
350
+
351
+ # sim
352
+
353
+ sim = einsum("b h i d, b h j d -> b h i j", q, k)
354
+
355
+ # attention
356
+ attn = self.attend(sim)
357
+
358
+ # aggregate
359
+
360
+ out = einsum("b h i j, b h j d -> b h i d", attn, v)
361
+
362
+ # merge heads
363
+ out = rearrange(
364
+ out,
365
+ "(b x y) head (w1 w2) d -> b (head d) (x w1) (y w2)",
366
+ x=h // self.ps,
367
+ y=w // self.ps,
368
+ head=self.heads,
369
+ w1=self.ps,
370
+ w2=self.ps,
371
+ )
372
+
373
+ out = self.to_out(out)
374
+ return out
375
+
376
+
377
+ class Channel_Attention(nn.Module):
378
+ def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
379
+ super(Channel_Attention, self).__init__()
380
+ self.heads = heads
381
+
382
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
383
+
384
+ self.ps = window_size
385
+
386
+ self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
387
+ self.qkv_dwconv = nn.Conv2d(
388
+ dim * 3,
389
+ dim * 3,
390
+ kernel_size=3,
391
+ stride=1,
392
+ padding=1,
393
+ groups=dim * 3,
394
+ bias=bias,
395
+ )
396
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
397
+
398
+ def forward(self, x):
399
+ b, c, h, w = x.shape
400
+
401
+ qkv = self.qkv_dwconv(self.qkv(x))
402
+ qkv = qkv.chunk(3, dim=1)
403
+
404
+ q, k, v = map(
405
+ lambda t: rearrange(
406
+ t,
407
+ "b (head d) (h ph) (w pw) -> b (h w) head d (ph pw)",
408
+ ph=self.ps,
409
+ pw=self.ps,
410
+ head=self.heads,
411
+ ),
412
+ qkv,
413
+ )
414
+
415
+ q = F.normalize(q, dim=-1)
416
+ k = F.normalize(k, dim=-1)
417
+
418
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
419
+ attn = attn.softmax(dim=-1)
420
+ out = attn @ v
421
+
422
+ out = rearrange(
423
+ out,
424
+ "b (h w) head d (ph pw) -> b (head d) (h ph) (w pw)",
425
+ h=h // self.ps,
426
+ w=w // self.ps,
427
+ ph=self.ps,
428
+ pw=self.ps,
429
+ head=self.heads,
430
+ )
431
+
432
+ out = self.project_out(out)
433
+
434
+ return out
435
+
436
+
437
+ class Channel_Attention_grid(nn.Module):
438
+ def __init__(self, dim, heads, bias=False, dropout=0.0, window_size=7):
439
+ super(Channel_Attention_grid, self).__init__()
440
+ self.heads = heads
441
+
442
+ self.temperature = nn.Parameter(torch.ones(heads, 1, 1))
443
+
444
+ self.ps = window_size
445
+
446
+ self.qkv = nn.Conv2d(dim, dim * 3, kernel_size=1, bias=bias)
447
+ self.qkv_dwconv = nn.Conv2d(
448
+ dim * 3,
449
+ dim * 3,
450
+ kernel_size=3,
451
+ stride=1,
452
+ padding=1,
453
+ groups=dim * 3,
454
+ bias=bias,
455
+ )
456
+ self.project_out = nn.Conv2d(dim, dim, kernel_size=1, bias=bias)
457
+
458
+ def forward(self, x):
459
+ b, c, h, w = x.shape
460
+
461
+ qkv = self.qkv_dwconv(self.qkv(x))
462
+ qkv = qkv.chunk(3, dim=1)
463
+
464
+ q, k, v = map(
465
+ lambda t: rearrange(
466
+ t,
467
+ "b (head d) (h ph) (w pw) -> b (ph pw) head d (h w)",
468
+ ph=self.ps,
469
+ pw=self.ps,
470
+ head=self.heads,
471
+ ),
472
+ qkv,
473
+ )
474
+
475
+ q = F.normalize(q, dim=-1)
476
+ k = F.normalize(k, dim=-1)
477
+
478
+ attn = (q @ k.transpose(-2, -1)) * self.temperature
479
+ attn = attn.softmax(dim=-1)
480
+ out = attn @ v
481
+
482
+ out = rearrange(
483
+ out,
484
+ "b (ph pw) head d (h w) -> b (head d) (h ph) (w pw)",
485
+ h=h // self.ps,
486
+ w=w // self.ps,
487
+ ph=self.ps,
488
+ pw=self.ps,
489
+ head=self.heads,
490
+ )
491
+
492
+ out = self.project_out(out)
493
+
494
+ return out
495
+
496
+
497
+ class OSA_Block(nn.Module):
498
+ def __init__(
499
+ self,
500
+ channel_num=64,
501
+ bias=True,
502
+ ffn_bias=True,
503
+ window_size=8,
504
+ with_pe=False,
505
+ dropout=0.0,
506
+ ):
507
+ super(OSA_Block, self).__init__()
508
+
509
+ w = window_size
510
+
511
+ self.layer = nn.Sequential(
512
+ MBConv(
513
+ channel_num,
514
+ channel_num,
515
+ downsample=False,
516
+ expansion_rate=1,
517
+ shrinkage_rate=0.25,
518
+ ),
519
+ Rearrange(
520
+ "b d (x w1) (y w2) -> b x y w1 w2 d", w1=w, w2=w
521
+ ), # block-like attention
522
+ PreNormResidual(
523
+ channel_num,
524
+ Attention(
525
+ dim=channel_num,
526
+ dim_head=channel_num // 4,
527
+ dropout=dropout,
528
+ window_size=window_size,
529
+ with_pe=with_pe,
530
+ ),
531
+ ),
532
+ Rearrange("b x y w1 w2 d -> b d (x w1) (y w2)"),
533
+ Conv_PreNormResidual(
534
+ channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
535
+ ),
536
+ # channel-like attention
537
+ Conv_PreNormResidual(
538
+ channel_num,
539
+ Channel_Attention(
540
+ dim=channel_num, heads=4, dropout=dropout, window_size=window_size
541
+ ),
542
+ ),
543
+ Conv_PreNormResidual(
544
+ channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
545
+ ),
546
+ Rearrange(
547
+ "b d (w1 x) (w2 y) -> b x y w1 w2 d", w1=w, w2=w
548
+ ), # grid-like attention
549
+ PreNormResidual(
550
+ channel_num,
551
+ Attention(
552
+ dim=channel_num,
553
+ dim_head=channel_num // 4,
554
+ dropout=dropout,
555
+ window_size=window_size,
556
+ with_pe=with_pe,
557
+ ),
558
+ ),
559
+ Rearrange("b x y w1 w2 d -> b d (w1 x) (w2 y)"),
560
+ Conv_PreNormResidual(
561
+ channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
562
+ ),
563
+ # channel-like attention
564
+ Conv_PreNormResidual(
565
+ channel_num,
566
+ Channel_Attention_grid(
567
+ dim=channel_num, heads=4, dropout=dropout, window_size=window_size
568
+ ),
569
+ ),
570
+ Conv_PreNormResidual(
571
+ channel_num, Gated_Conv_FeedForward(dim=channel_num, dropout=dropout)
572
+ ),
573
+ )
574
+
575
+ def forward(self, x):
576
+ out = self.layer(x)
577
+ return out
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/OSAG.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ #############################################################
4
+ # File: OSAG.py
5
+ # Created Date: Tuesday April 28th 2022
6
+ # Author: Chen Xuanhong
7
+ # Email: chenxuanhongzju@outlook.com
8
+ # Last Modified: Sunday, 23rd April 2023 3:08:49 pm
9
+ # Modified By: Chen Xuanhong
10
+ # Copyright (c) 2020 Shanghai Jiao Tong University
11
+ #############################################################
12
+
13
+
14
+ import torch.nn as nn
15
+
16
+ from .esa import ESA
17
+ from .OSA import OSA_Block
18
+
19
+
20
+ class OSAG(nn.Module):
21
+ def __init__(
22
+ self,
23
+ channel_num=64,
24
+ bias=True,
25
+ block_num=4,
26
+ ffn_bias=False,
27
+ window_size=0,
28
+ pe=False,
29
+ ):
30
+ super(OSAG, self).__init__()
31
+
32
+ # print("window_size: %d" % (window_size))
33
+ # print("with_pe", pe)
34
+ # print("ffn_bias: %d" % (ffn_bias))
35
+
36
+ # block_script_name = kwargs.get("block_script_name", "OSA")
37
+ # block_class_name = kwargs.get("block_class_name", "OSA_Block")
38
+
39
+ # script_name = "." + block_script_name
40
+ # package = __import__(script_name, fromlist=True)
41
+ block_class = OSA_Block # getattr(package, block_class_name)
42
+ group_list = []
43
+ for _ in range(block_num):
44
+ temp_res = block_class(
45
+ channel_num,
46
+ bias,
47
+ ffn_bias=ffn_bias,
48
+ window_size=window_size,
49
+ with_pe=pe,
50
+ )
51
+ group_list.append(temp_res)
52
+ group_list.append(nn.Conv2d(channel_num, channel_num, 1, 1, 0, bias=bias))
53
+ self.residual_layer = nn.Sequential(*group_list)
54
+ esa_channel = max(channel_num // 4, 16)
55
+ self.esa = ESA(esa_channel, channel_num)
56
+
57
+ def forward(self, x):
58
+ out = self.residual_layer(x)
59
+ out = out + x
60
+ return self.esa(out)
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/OmniSR.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ #############################################################
4
+ # File: OmniSR.py
5
+ # Created Date: Tuesday April 28th 2022
6
+ # Author: Chen Xuanhong
7
+ # Email: chenxuanhongzju@outlook.com
8
+ # Last Modified: Sunday, 23rd April 2023 3:06:36 pm
9
+ # Modified By: Chen Xuanhong
10
+ # Copyright (c) 2020 Shanghai Jiao Tong University
11
+ #############################################################
12
+
13
+ import math
14
+
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+ from .OSAG import OSAG
20
+ from .pixelshuffle import pixelshuffle_block
21
+
22
+
23
+ class OmniSR(nn.Module):
24
+ def __init__(
25
+ self,
26
+ state_dict,
27
+ **kwargs,
28
+ ):
29
+ super(OmniSR, self).__init__()
30
+ self.state = state_dict
31
+
32
+ bias = True # Fine to assume this for now
33
+ block_num = 1 # Fine to assume this for now
34
+ ffn_bias = True
35
+ pe = True
36
+
37
+ num_feat = state_dict["input.weight"].shape[0] or 64
38
+ num_in_ch = state_dict["input.weight"].shape[1] or 3
39
+ num_out_ch = num_in_ch # we can just assume this for now. pixelshuffle smh
40
+
41
+ pixelshuffle_shape = state_dict["up.0.weight"].shape[0]
42
+ up_scale = math.sqrt(pixelshuffle_shape / num_out_ch)
43
+ if up_scale - int(up_scale) > 0:
44
+ print(
45
+ "out_nc is probably different than in_nc, scale calculation might be wrong"
46
+ )
47
+ up_scale = int(up_scale)
48
+ res_num = 0
49
+ for key in state_dict.keys():
50
+ if "residual_layer" in key:
51
+ temp_res_num = int(key.split(".")[1])
52
+ if temp_res_num > res_num:
53
+ res_num = temp_res_num
54
+ res_num = res_num + 1 # zero-indexed
55
+
56
+ residual_layer = []
57
+ self.res_num = res_num
58
+
59
+ if (
60
+ "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
61
+ in state_dict.keys()
62
+ ):
63
+ rel_pos_bias_weight = state_dict[
64
+ "residual_layer.0.residual_layer.0.layer.2.fn.rel_pos_bias.weight"
65
+ ].shape[0]
66
+ self.window_size = int((math.sqrt(rel_pos_bias_weight) + 1) / 2)
67
+ else:
68
+ self.window_size = 8
69
+
70
+ self.up_scale = up_scale
71
+
72
+ for _ in range(res_num):
73
+ temp_res = OSAG(
74
+ channel_num=num_feat,
75
+ bias=bias,
76
+ block_num=block_num,
77
+ ffn_bias=ffn_bias,
78
+ window_size=self.window_size,
79
+ pe=pe,
80
+ )
81
+ residual_layer.append(temp_res)
82
+ self.residual_layer = nn.Sequential(*residual_layer)
83
+ self.input = nn.Conv2d(
84
+ in_channels=num_in_ch,
85
+ out_channels=num_feat,
86
+ kernel_size=3,
87
+ stride=1,
88
+ padding=1,
89
+ bias=bias,
90
+ )
91
+ self.output = nn.Conv2d(
92
+ in_channels=num_feat,
93
+ out_channels=num_feat,
94
+ kernel_size=3,
95
+ stride=1,
96
+ padding=1,
97
+ bias=bias,
98
+ )
99
+ self.up = pixelshuffle_block(num_feat, num_out_ch, up_scale, bias=bias)
100
+
101
+ # self.tail = pixelshuffle_block(num_feat,num_out_ch,up_scale,bias=bias)
102
+
103
+ # for m in self.modules():
104
+ # if isinstance(m, nn.Conv2d):
105
+ # n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
106
+ # m.weight.data.normal_(0, sqrt(2. / n))
107
+
108
+ # chaiNNer specific stuff
109
+ self.model_arch = "OmniSR"
110
+ self.sub_type = "SR"
111
+ self.in_nc = num_in_ch
112
+ self.out_nc = num_out_ch
113
+ self.num_feat = num_feat
114
+ self.scale = up_scale
115
+
116
+ self.supports_fp16 = True # TODO: Test this
117
+ self.supports_bfp16 = True
118
+ self.min_size_restriction = 16
119
+
120
+ self.load_state_dict(state_dict, strict=False)
121
+
122
+ def check_image_size(self, x):
123
+ _, _, h, w = x.size()
124
+ # import pdb; pdb.set_trace()
125
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
126
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
127
+ # x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), 'reflect')
128
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "constant", 0)
129
+ return x
130
+
131
+ def forward(self, x):
132
+ H, W = x.shape[2:]
133
+ x = self.check_image_size(x)
134
+
135
+ residual = self.input(x)
136
+ out = self.residual_layer(residual)
137
+
138
+ # origin
139
+ out = torch.add(self.output(out), residual)
140
+ out = self.up(out)
141
+
142
+ out = out[:, :, : H * self.up_scale, : W * self.up_scale]
143
+ return out
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/esa.py ADDED
@@ -0,0 +1,294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ #############################################################
4
+ # File: esa.py
5
+ # Created Date: Tuesday April 28th 2022
6
+ # Author: Chen Xuanhong
7
+ # Email: chenxuanhongzju@outlook.com
8
+ # Last Modified: Thursday, 20th April 2023 9:28:06 am
9
+ # Modified By: Chen Xuanhong
10
+ # Copyright (c) 2020 Shanghai Jiao Tong University
11
+ #############################################################
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+ import torch.nn.functional as F
16
+
17
+ from .layernorm import LayerNorm2d
18
+
19
+
20
+ def moment(x, dim=(2, 3), k=2):
21
+ assert len(x.size()) == 4
22
+ mean = torch.mean(x, dim=dim).unsqueeze(-1).unsqueeze(-1)
23
+ mk = (1 / (x.size(2) * x.size(3))) * torch.sum(torch.pow(x - mean, k), dim=dim)
24
+ return mk
25
+
26
+
27
+ class ESA(nn.Module):
28
+ """
29
+ Modification of Enhanced Spatial Attention (ESA), which is proposed by
30
+ `Residual Feature Aggregation Network for Image Super-Resolution`
31
+ Note: `conv_max` and `conv3_` are NOT used here, so the corresponding codes
32
+ are deleted.
33
+ """
34
+
35
+ def __init__(self, esa_channels, n_feats, conv=nn.Conv2d):
36
+ super(ESA, self).__init__()
37
+ f = esa_channels
38
+ self.conv1 = conv(n_feats, f, kernel_size=1)
39
+ self.conv_f = conv(f, f, kernel_size=1)
40
+ self.conv2 = conv(f, f, kernel_size=3, stride=2, padding=0)
41
+ self.conv3 = conv(f, f, kernel_size=3, padding=1)
42
+ self.conv4 = conv(f, n_feats, kernel_size=1)
43
+ self.sigmoid = nn.Sigmoid()
44
+ self.relu = nn.ReLU(inplace=True)
45
+
46
+ def forward(self, x):
47
+ c1_ = self.conv1(x)
48
+ c1 = self.conv2(c1_)
49
+ v_max = F.max_pool2d(c1, kernel_size=7, stride=3)
50
+ c3 = self.conv3(v_max)
51
+ c3 = F.interpolate(
52
+ c3, (x.size(2), x.size(3)), mode="bilinear", align_corners=False
53
+ )
54
+ cf = self.conv_f(c1_)
55
+ c4 = self.conv4(c3 + cf)
56
+ m = self.sigmoid(c4)
57
+ return x * m
58
+
59
+
60
+ class LK_ESA(nn.Module):
61
+ def __init__(
62
+ self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
63
+ ):
64
+ super(LK_ESA, self).__init__()
65
+ f = esa_channels
66
+ self.conv1 = conv(n_feats, f, kernel_size=1)
67
+ self.conv_f = conv(f, f, kernel_size=1)
68
+
69
+ kernel_size = 17
70
+ kernel_expand = kernel_expand
71
+ padding = kernel_size // 2
72
+
73
+ self.vec_conv = nn.Conv2d(
74
+ in_channels=f * kernel_expand,
75
+ out_channels=f * kernel_expand,
76
+ kernel_size=(1, kernel_size),
77
+ padding=(0, padding),
78
+ groups=2,
79
+ bias=bias,
80
+ )
81
+ self.vec_conv3x1 = nn.Conv2d(
82
+ in_channels=f * kernel_expand,
83
+ out_channels=f * kernel_expand,
84
+ kernel_size=(1, 3),
85
+ padding=(0, 1),
86
+ groups=2,
87
+ bias=bias,
88
+ )
89
+
90
+ self.hor_conv = nn.Conv2d(
91
+ in_channels=f * kernel_expand,
92
+ out_channels=f * kernel_expand,
93
+ kernel_size=(kernel_size, 1),
94
+ padding=(padding, 0),
95
+ groups=2,
96
+ bias=bias,
97
+ )
98
+ self.hor_conv1x3 = nn.Conv2d(
99
+ in_channels=f * kernel_expand,
100
+ out_channels=f * kernel_expand,
101
+ kernel_size=(3, 1),
102
+ padding=(1, 0),
103
+ groups=2,
104
+ bias=bias,
105
+ )
106
+
107
+ self.conv4 = conv(f, n_feats, kernel_size=1)
108
+ self.sigmoid = nn.Sigmoid()
109
+ self.relu = nn.ReLU(inplace=True)
110
+
111
+ def forward(self, x):
112
+ c1_ = self.conv1(x)
113
+
114
+ res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
115
+ res = self.hor_conv(res) + self.hor_conv1x3(res)
116
+
117
+ cf = self.conv_f(c1_)
118
+ c4 = self.conv4(res + cf)
119
+ m = self.sigmoid(c4)
120
+ return x * m
121
+
122
+
123
+ class LK_ESA_LN(nn.Module):
124
+ def __init__(
125
+ self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
126
+ ):
127
+ super(LK_ESA_LN, self).__init__()
128
+ f = esa_channels
129
+ self.conv1 = conv(n_feats, f, kernel_size=1)
130
+ self.conv_f = conv(f, f, kernel_size=1)
131
+
132
+ kernel_size = 17
133
+ kernel_expand = kernel_expand
134
+ padding = kernel_size // 2
135
+
136
+ self.norm = LayerNorm2d(n_feats)
137
+
138
+ self.vec_conv = nn.Conv2d(
139
+ in_channels=f * kernel_expand,
140
+ out_channels=f * kernel_expand,
141
+ kernel_size=(1, kernel_size),
142
+ padding=(0, padding),
143
+ groups=2,
144
+ bias=bias,
145
+ )
146
+ self.vec_conv3x1 = nn.Conv2d(
147
+ in_channels=f * kernel_expand,
148
+ out_channels=f * kernel_expand,
149
+ kernel_size=(1, 3),
150
+ padding=(0, 1),
151
+ groups=2,
152
+ bias=bias,
153
+ )
154
+
155
+ self.hor_conv = nn.Conv2d(
156
+ in_channels=f * kernel_expand,
157
+ out_channels=f * kernel_expand,
158
+ kernel_size=(kernel_size, 1),
159
+ padding=(padding, 0),
160
+ groups=2,
161
+ bias=bias,
162
+ )
163
+ self.hor_conv1x3 = nn.Conv2d(
164
+ in_channels=f * kernel_expand,
165
+ out_channels=f * kernel_expand,
166
+ kernel_size=(3, 1),
167
+ padding=(1, 0),
168
+ groups=2,
169
+ bias=bias,
170
+ )
171
+
172
+ self.conv4 = conv(f, n_feats, kernel_size=1)
173
+ self.sigmoid = nn.Sigmoid()
174
+ self.relu = nn.ReLU(inplace=True)
175
+
176
+ def forward(self, x):
177
+ c1_ = self.norm(x)
178
+ c1_ = self.conv1(c1_)
179
+
180
+ res = self.vec_conv(c1_) + self.vec_conv3x1(c1_)
181
+ res = self.hor_conv(res) + self.hor_conv1x3(res)
182
+
183
+ cf = self.conv_f(c1_)
184
+ c4 = self.conv4(res + cf)
185
+ m = self.sigmoid(c4)
186
+ return x * m
187
+
188
+
189
+ class AdaGuidedFilter(nn.Module):
190
+ def __init__(
191
+ self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
192
+ ):
193
+ super(AdaGuidedFilter, self).__init__()
194
+
195
+ self.gap = nn.AdaptiveAvgPool2d(1)
196
+ self.fc = nn.Conv2d(
197
+ in_channels=n_feats,
198
+ out_channels=1,
199
+ kernel_size=1,
200
+ padding=0,
201
+ stride=1,
202
+ groups=1,
203
+ bias=True,
204
+ )
205
+
206
+ self.r = 5
207
+
208
+ def box_filter(self, x, r):
209
+ channel = x.shape[1]
210
+ kernel_size = 2 * r + 1
211
+ weight = 1.0 / (kernel_size**2)
212
+ box_kernel = weight * torch.ones(
213
+ (channel, 1, kernel_size, kernel_size), dtype=torch.float32, device=x.device
214
+ )
215
+ output = F.conv2d(x, weight=box_kernel, stride=1, padding=r, groups=channel)
216
+ return output
217
+
218
+ def forward(self, x):
219
+ _, _, H, W = x.shape
220
+ N = self.box_filter(
221
+ torch.ones((1, 1, H, W), dtype=x.dtype, device=x.device), self.r
222
+ )
223
+
224
+ # epsilon = self.fc(self.gap(x))
225
+ # epsilon = torch.pow(epsilon, 2)
226
+ epsilon = 1e-2
227
+
228
+ mean_x = self.box_filter(x, self.r) / N
229
+ var_x = self.box_filter(x * x, self.r) / N - mean_x * mean_x
230
+
231
+ A = var_x / (var_x + epsilon)
232
+ b = (1 - A) * mean_x
233
+ m = A * x + b
234
+
235
+ # mean_A = self.box_filter(A, self.r) / N
236
+ # mean_b = self.box_filter(b, self.r) / N
237
+ # m = mean_A * x + mean_b
238
+ return x * m
239
+
240
+
241
+ class AdaConvGuidedFilter(nn.Module):
242
+ def __init__(
243
+ self, esa_channels, n_feats, conv=nn.Conv2d, kernel_expand=1, bias=True
244
+ ):
245
+ super(AdaConvGuidedFilter, self).__init__()
246
+ f = esa_channels
247
+
248
+ self.conv_f = conv(f, f, kernel_size=1)
249
+
250
+ kernel_size = 17
251
+ kernel_expand = kernel_expand
252
+ padding = kernel_size // 2
253
+
254
+ self.vec_conv = nn.Conv2d(
255
+ in_channels=f,
256
+ out_channels=f,
257
+ kernel_size=(1, kernel_size),
258
+ padding=(0, padding),
259
+ groups=f,
260
+ bias=bias,
261
+ )
262
+
263
+ self.hor_conv = nn.Conv2d(
264
+ in_channels=f,
265
+ out_channels=f,
266
+ kernel_size=(kernel_size, 1),
267
+ padding=(padding, 0),
268
+ groups=f,
269
+ bias=bias,
270
+ )
271
+
272
+ self.gap = nn.AdaptiveAvgPool2d(1)
273
+ self.fc = nn.Conv2d(
274
+ in_channels=f,
275
+ out_channels=f,
276
+ kernel_size=1,
277
+ padding=0,
278
+ stride=1,
279
+ groups=1,
280
+ bias=True,
281
+ )
282
+
283
+ def forward(self, x):
284
+ y = self.vec_conv(x)
285
+ y = self.hor_conv(y)
286
+
287
+ sigma = torch.pow(y, 2)
288
+ epsilon = self.fc(self.gap(y))
289
+
290
+ weight = sigma / (sigma + epsilon)
291
+
292
+ m = weight * x + (1 - weight)
293
+
294
+ return x * m
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/layernorm.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ #############################################################
4
+ # File: layernorm.py
5
+ # Created Date: Tuesday April 28th 2022
6
+ # Author: Chen Xuanhong
7
+ # Email: chenxuanhongzju@outlook.com
8
+ # Last Modified: Thursday, 20th April 2023 9:28:20 am
9
+ # Modified By: Chen Xuanhong
10
+ # Copyright (c) 2020 Shanghai Jiao Tong University
11
+ #############################################################
12
+
13
+ import torch
14
+ import torch.nn as nn
15
+
16
+
17
+ class LayerNormFunction(torch.autograd.Function):
18
+ @staticmethod
19
+ def forward(ctx, x, weight, bias, eps):
20
+ ctx.eps = eps
21
+ N, C, H, W = x.size()
22
+ mu = x.mean(1, keepdim=True)
23
+ var = (x - mu).pow(2).mean(1, keepdim=True)
24
+ y = (x - mu) / (var + eps).sqrt()
25
+ ctx.save_for_backward(y, var, weight)
26
+ y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
27
+ return y
28
+
29
+ @staticmethod
30
+ def backward(ctx, grad_output):
31
+ eps = ctx.eps
32
+
33
+ N, C, H, W = grad_output.size()
34
+ y, var, weight = ctx.saved_variables
35
+ g = grad_output * weight.view(1, C, 1, 1)
36
+ mean_g = g.mean(dim=1, keepdim=True)
37
+
38
+ mean_gy = (g * y).mean(dim=1, keepdim=True)
39
+ gx = 1.0 / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
40
+ return (
41
+ gx,
42
+ (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0),
43
+ grad_output.sum(dim=3).sum(dim=2).sum(dim=0),
44
+ None,
45
+ )
46
+
47
+
48
+ class LayerNorm2d(nn.Module):
49
+ def __init__(self, channels, eps=1e-6):
50
+ super(LayerNorm2d, self).__init__()
51
+ self.register_parameter("weight", nn.Parameter(torch.ones(channels)))
52
+ self.register_parameter("bias", nn.Parameter(torch.zeros(channels)))
53
+ self.eps = eps
54
+
55
+ def forward(self, x):
56
+ return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
57
+
58
+
59
+ class GRN(nn.Module):
60
+ """GRN (Global Response Normalization) layer"""
61
+
62
+ def __init__(self, dim):
63
+ super().__init__()
64
+ self.gamma = nn.Parameter(torch.zeros(1, dim, 1, 1))
65
+ self.beta = nn.Parameter(torch.zeros(1, dim, 1, 1))
66
+
67
+ def forward(self, x):
68
+ Gx = torch.norm(x, p=2, dim=(2, 3), keepdim=True)
69
+ Nx = Gx / (Gx.mean(dim=1, keepdim=True) + 1e-6)
70
+ return self.gamma * (x * Nx) + self.beta + x
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/OmniSR/pixelshuffle.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding:utf-8 -*-
3
+ #############################################################
4
+ # File: pixelshuffle.py
5
+ # Created Date: Friday July 1st 2022
6
+ # Author: Chen Xuanhong
7
+ # Email: chenxuanhongzju@outlook.com
8
+ # Last Modified: Friday, 1st July 2022 10:18:39 am
9
+ # Modified By: Chen Xuanhong
10
+ # Copyright (c) 2022 Shanghai Jiao Tong University
11
+ #############################################################
12
+
13
+ import torch.nn as nn
14
+
15
+
16
+ def pixelshuffle_block(
17
+ in_channels, out_channels, upscale_factor=2, kernel_size=3, bias=False
18
+ ):
19
+ """
20
+ Upsample features according to `upscale_factor`.
21
+ """
22
+ padding = kernel_size // 2
23
+ conv = nn.Conv2d(
24
+ in_channels,
25
+ out_channels * (upscale_factor**2),
26
+ kernel_size,
27
+ padding=1,
28
+ bias=bias,
29
+ )
30
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
31
+ return nn.Sequential(*[conv, pixel_shuffle])
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/RRDB.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import functools
5
+ import math
6
+ import re
7
+ from collections import OrderedDict
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+
13
+ from . import block as B
14
+
15
+
16
+ # Borrowed from https://github.com/rlaphoenix/VSGAN/blob/master/vsgan/archs/ESRGAN.py
17
+ # Which enhanced stuff that was already here
18
+ class RRDBNet(nn.Module):
19
+ def __init__(
20
+ self,
21
+ state_dict,
22
+ norm=None,
23
+ act: str = "leakyrelu",
24
+ upsampler: str = "upconv",
25
+ mode: B.ConvMode = "CNA",
26
+ ) -> None:
27
+ """
28
+ ESRGAN - Enhanced Super-Resolution Generative Adversarial Networks.
29
+ By Xintao Wang, Ke Yu, Shixiang Wu, Jinjin Gu, Yihao Liu, Chao Dong, Yu Qiao,
30
+ and Chen Change Loy.
31
+ This is old-arch Residual in Residual Dense Block Network and is not
32
+ the newest revision that's available at github.com/xinntao/ESRGAN.
33
+ This is on purpose, the newest Network has severely limited the
34
+ potential use of the Network with no benefits.
35
+ This network supports model files from both new and old-arch.
36
+ Args:
37
+ norm: Normalization layer
38
+ act: Activation layer
39
+ upsampler: Upsample layer. upconv, pixel_shuffle
40
+ mode: Convolution mode
41
+ """
42
+ super(RRDBNet, self).__init__()
43
+ self.model_arch = "ESRGAN"
44
+ self.sub_type = "SR"
45
+
46
+ self.state = state_dict
47
+ self.norm = norm
48
+ self.act = act
49
+ self.upsampler = upsampler
50
+ self.mode = mode
51
+
52
+ self.state_map = {
53
+ # currently supports old, new, and newer RRDBNet arch models
54
+ # ESRGAN, BSRGAN/RealSR, Real-ESRGAN
55
+ "model.0.weight": ("conv_first.weight",),
56
+ "model.0.bias": ("conv_first.bias",),
57
+ "model.1.sub./NB/.weight": ("trunk_conv.weight", "conv_body.weight"),
58
+ "model.1.sub./NB/.bias": ("trunk_conv.bias", "conv_body.bias"),
59
+ r"model.1.sub.\1.RDB\2.conv\3.0.\4": (
60
+ r"RRDB_trunk\.(\d+)\.RDB(\d)\.conv(\d+)\.(weight|bias)",
61
+ r"body\.(\d+)\.rdb(\d)\.conv(\d+)\.(weight|bias)",
62
+ ),
63
+ }
64
+ if "params_ema" in self.state:
65
+ self.state = self.state["params_ema"]
66
+ # self.model_arch = "RealESRGAN"
67
+ self.num_blocks = self.get_num_blocks()
68
+ self.plus = any("conv1x1" in k for k in self.state.keys())
69
+ if self.plus:
70
+ self.model_arch = "ESRGAN+"
71
+
72
+ self.state = self.new_to_old_arch(self.state)
73
+
74
+ self.key_arr = list(self.state.keys())
75
+
76
+ self.in_nc: int = self.state[self.key_arr[0]].shape[1]
77
+ self.out_nc: int = self.state[self.key_arr[-1]].shape[0]
78
+
79
+ self.scale: int = self.get_scale()
80
+ self.num_filters: int = self.state[self.key_arr[0]].shape[0]
81
+
82
+ c2x2 = False
83
+ if self.state["model.0.weight"].shape[-2] == 2:
84
+ c2x2 = True
85
+ self.scale = round(math.sqrt(self.scale / 4))
86
+ self.model_arch = "ESRGAN-2c2"
87
+
88
+ self.supports_fp16 = True
89
+ self.supports_bfp16 = True
90
+ self.min_size_restriction = None
91
+
92
+ # Detect if pixelunshuffle was used (Real-ESRGAN)
93
+ if self.in_nc in (self.out_nc * 4, self.out_nc * 16) and self.out_nc in (
94
+ self.in_nc / 4,
95
+ self.in_nc / 16,
96
+ ):
97
+ self.shuffle_factor = int(math.sqrt(self.in_nc / self.out_nc))
98
+ else:
99
+ self.shuffle_factor = None
100
+
101
+ upsample_block = {
102
+ "upconv": B.upconv_block,
103
+ "pixel_shuffle": B.pixelshuffle_block,
104
+ }.get(self.upsampler)
105
+ if upsample_block is None:
106
+ raise NotImplementedError(f"Upsample mode [{self.upsampler}] is not found")
107
+
108
+ if self.scale == 3:
109
+ upsample_blocks = upsample_block(
110
+ in_nc=self.num_filters,
111
+ out_nc=self.num_filters,
112
+ upscale_factor=3,
113
+ act_type=self.act,
114
+ c2x2=c2x2,
115
+ )
116
+ else:
117
+ upsample_blocks = [
118
+ upsample_block(
119
+ in_nc=self.num_filters,
120
+ out_nc=self.num_filters,
121
+ act_type=self.act,
122
+ c2x2=c2x2,
123
+ )
124
+ for _ in range(int(math.log(self.scale, 2)))
125
+ ]
126
+
127
+ self.model = B.sequential(
128
+ # fea conv
129
+ B.conv_block(
130
+ in_nc=self.in_nc,
131
+ out_nc=self.num_filters,
132
+ kernel_size=3,
133
+ norm_type=None,
134
+ act_type=None,
135
+ c2x2=c2x2,
136
+ ),
137
+ B.ShortcutBlock(
138
+ B.sequential(
139
+ # rrdb blocks
140
+ *[
141
+ B.RRDB(
142
+ nf=self.num_filters,
143
+ kernel_size=3,
144
+ gc=32,
145
+ stride=1,
146
+ bias=True,
147
+ pad_type="zero",
148
+ norm_type=self.norm,
149
+ act_type=self.act,
150
+ mode="CNA",
151
+ plus=self.plus,
152
+ c2x2=c2x2,
153
+ )
154
+ for _ in range(self.num_blocks)
155
+ ],
156
+ # lr conv
157
+ B.conv_block(
158
+ in_nc=self.num_filters,
159
+ out_nc=self.num_filters,
160
+ kernel_size=3,
161
+ norm_type=self.norm,
162
+ act_type=None,
163
+ mode=self.mode,
164
+ c2x2=c2x2,
165
+ ),
166
+ )
167
+ ),
168
+ *upsample_blocks,
169
+ # hr_conv0
170
+ B.conv_block(
171
+ in_nc=self.num_filters,
172
+ out_nc=self.num_filters,
173
+ kernel_size=3,
174
+ norm_type=None,
175
+ act_type=self.act,
176
+ c2x2=c2x2,
177
+ ),
178
+ # hr_conv1
179
+ B.conv_block(
180
+ in_nc=self.num_filters,
181
+ out_nc=self.out_nc,
182
+ kernel_size=3,
183
+ norm_type=None,
184
+ act_type=None,
185
+ c2x2=c2x2,
186
+ ),
187
+ )
188
+
189
+ # Adjust these properties for calculations outside of the model
190
+ if self.shuffle_factor:
191
+ self.in_nc //= self.shuffle_factor**2
192
+ self.scale //= self.shuffle_factor
193
+
194
+ self.load_state_dict(self.state, strict=False)
195
+
196
+ def new_to_old_arch(self, state):
197
+ """Convert a new-arch model state dictionary to an old-arch dictionary."""
198
+ if "params_ema" in state:
199
+ state = state["params_ema"]
200
+
201
+ if "conv_first.weight" not in state:
202
+ # model is already old arch, this is a loose check, but should be sufficient
203
+ return state
204
+
205
+ # add nb to state keys
206
+ for kind in ("weight", "bias"):
207
+ self.state_map[f"model.1.sub.{self.num_blocks}.{kind}"] = self.state_map[
208
+ f"model.1.sub./NB/.{kind}"
209
+ ]
210
+ del self.state_map[f"model.1.sub./NB/.{kind}"]
211
+
212
+ old_state = OrderedDict()
213
+ for old_key, new_keys in self.state_map.items():
214
+ for new_key in new_keys:
215
+ if r"\1" in old_key:
216
+ for k, v in state.items():
217
+ sub = re.sub(new_key, old_key, k)
218
+ if sub != k:
219
+ old_state[sub] = v
220
+ else:
221
+ if new_key in state:
222
+ old_state[old_key] = state[new_key]
223
+
224
+ # upconv layers
225
+ max_upconv = 0
226
+ for key in state.keys():
227
+ match = re.match(r"(upconv|conv_up)(\d)\.(weight|bias)", key)
228
+ if match is not None:
229
+ _, key_num, key_type = match.groups()
230
+ old_state[f"model.{int(key_num) * 3}.{key_type}"] = state[key]
231
+ max_upconv = max(max_upconv, int(key_num) * 3)
232
+
233
+ # final layers
234
+ for key in state.keys():
235
+ if key in ("HRconv.weight", "conv_hr.weight"):
236
+ old_state[f"model.{max_upconv + 2}.weight"] = state[key]
237
+ elif key in ("HRconv.bias", "conv_hr.bias"):
238
+ old_state[f"model.{max_upconv + 2}.bias"] = state[key]
239
+ elif key in ("conv_last.weight",):
240
+ old_state[f"model.{max_upconv + 4}.weight"] = state[key]
241
+ elif key in ("conv_last.bias",):
242
+ old_state[f"model.{max_upconv + 4}.bias"] = state[key]
243
+
244
+ # Sort by first numeric value of each layer
245
+ def compare(item1, item2):
246
+ parts1 = item1.split(".")
247
+ parts2 = item2.split(".")
248
+ int1 = int(parts1[1])
249
+ int2 = int(parts2[1])
250
+ return int1 - int2
251
+
252
+ sorted_keys = sorted(old_state.keys(), key=functools.cmp_to_key(compare))
253
+
254
+ # Rebuild the output dict in the right order
255
+ out_dict = OrderedDict((k, old_state[k]) for k in sorted_keys)
256
+
257
+ return out_dict
258
+
259
+ def get_scale(self, min_part: int = 6) -> int:
260
+ n = 0
261
+ for part in list(self.state):
262
+ parts = part.split(".")[1:]
263
+ if len(parts) == 2:
264
+ part_num = int(parts[0])
265
+ if part_num > min_part and parts[1] == "weight":
266
+ n += 1
267
+ return 2**n
268
+
269
+ def get_num_blocks(self) -> int:
270
+ nbs = []
271
+ state_keys = self.state_map[r"model.1.sub.\1.RDB\2.conv\3.0.\4"] + (
272
+ r"model\.\d+\.sub\.(\d+)\.RDB(\d+)\.conv(\d+)\.0\.(weight|bias)",
273
+ )
274
+ for state_key in state_keys:
275
+ for k in self.state:
276
+ m = re.search(state_key, k)
277
+ if m:
278
+ nbs.append(int(m.group(1)))
279
+ if nbs:
280
+ break
281
+ return max(*nbs) + 1
282
+
283
+ def forward(self, x):
284
+ if self.shuffle_factor:
285
+ _, _, h, w = x.size()
286
+ mod_pad_h = (
287
+ self.shuffle_factor - h % self.shuffle_factor
288
+ ) % self.shuffle_factor
289
+ mod_pad_w = (
290
+ self.shuffle_factor - w % self.shuffle_factor
291
+ ) % self.shuffle_factor
292
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
293
+ x = torch.pixel_unshuffle(x, downscale_factor=self.shuffle_factor)
294
+ x = self.model(x)
295
+ return x[:, :, : h * self.scale, : w * self.scale]
296
+ return self.model(x)
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SCUNet.py ADDED
@@ -0,0 +1,455 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # -----------------------------------------------------------------------------------
3
+ # SCUNet: Practical Blind Denoising via Swin-Conv-UNet and Data Synthesis, https://arxiv.org/abs/2203.13278
4
+ # Zhang, Kai and Li, Yawei and Liang, Jingyun and Cao, Jiezhang and Zhang, Yulun and Tang, Hao and Timofte, Radu and Van Gool, Luc
5
+ # -----------------------------------------------------------------------------------
6
+
7
+ import numpy as np
8
+ import torch
9
+ import torch.nn as nn
10
+ import torch.nn.functional as F
11
+ from einops import rearrange
12
+ from einops.layers.torch import Rearrange
13
+
14
+ from .timm.drop import DropPath
15
+ from .timm.weight_init import trunc_normal_
16
+
17
+
18
+ # Borrowed from https://github.com/cszn/SCUNet/blob/main/models/network_scunet.py
19
+ class WMSA(nn.Module):
20
+ """Self-attention module in Swin Transformer"""
21
+
22
+ def __init__(self, input_dim, output_dim, head_dim, window_size, type):
23
+ super(WMSA, self).__init__()
24
+ self.input_dim = input_dim
25
+ self.output_dim = output_dim
26
+ self.head_dim = head_dim
27
+ self.scale = self.head_dim**-0.5
28
+ self.n_heads = input_dim // head_dim
29
+ self.window_size = window_size
30
+ self.type = type
31
+ self.embedding_layer = nn.Linear(self.input_dim, 3 * self.input_dim, bias=True)
32
+
33
+ self.relative_position_params = nn.Parameter(
34
+ torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
35
+ )
36
+ # TODO recover
37
+ # self.relative_position_params = nn.Parameter(torch.zeros(self.n_heads, 2 * window_size - 1, 2 * window_size -1))
38
+ self.relative_position_params = nn.Parameter(
39
+ torch.zeros((2 * window_size - 1) * (2 * window_size - 1), self.n_heads)
40
+ )
41
+
42
+ self.linear = nn.Linear(self.input_dim, self.output_dim)
43
+
44
+ trunc_normal_(self.relative_position_params, std=0.02)
45
+ self.relative_position_params = torch.nn.Parameter(
46
+ self.relative_position_params.view(
47
+ 2 * window_size - 1, 2 * window_size - 1, self.n_heads
48
+ )
49
+ .transpose(1, 2)
50
+ .transpose(0, 1)
51
+ )
52
+
53
+ def generate_mask(self, h, w, p, shift):
54
+ """generating the mask of SW-MSA
55
+ Args:
56
+ shift: shift parameters in CyclicShift.
57
+ Returns:
58
+ attn_mask: should be (1 1 w p p),
59
+ """
60
+ # supporting square.
61
+ attn_mask = torch.zeros(
62
+ h,
63
+ w,
64
+ p,
65
+ p,
66
+ p,
67
+ p,
68
+ dtype=torch.bool,
69
+ device=self.relative_position_params.device,
70
+ )
71
+ if self.type == "W":
72
+ return attn_mask
73
+
74
+ s = p - shift
75
+ attn_mask[-1, :, :s, :, s:, :] = True
76
+ attn_mask[-1, :, s:, :, :s, :] = True
77
+ attn_mask[:, -1, :, :s, :, s:] = True
78
+ attn_mask[:, -1, :, s:, :, :s] = True
79
+ attn_mask = rearrange(
80
+ attn_mask, "w1 w2 p1 p2 p3 p4 -> 1 1 (w1 w2) (p1 p2) (p3 p4)"
81
+ )
82
+ return attn_mask
83
+
84
+ def forward(self, x):
85
+ """Forward pass of Window Multi-head Self-attention module.
86
+ Args:
87
+ x: input tensor with shape of [b h w c];
88
+ attn_mask: attention mask, fill -inf where the value is True;
89
+ Returns:
90
+ output: tensor shape [b h w c]
91
+ """
92
+ if self.type != "W":
93
+ x = torch.roll(
94
+ x,
95
+ shifts=(-(self.window_size // 2), -(self.window_size // 2)),
96
+ dims=(1, 2),
97
+ )
98
+
99
+ x = rearrange(
100
+ x,
101
+ "b (w1 p1) (w2 p2) c -> b w1 w2 p1 p2 c",
102
+ p1=self.window_size,
103
+ p2=self.window_size,
104
+ )
105
+ h_windows = x.size(1)
106
+ w_windows = x.size(2)
107
+ # square validation
108
+ # assert h_windows == w_windows
109
+
110
+ x = rearrange(
111
+ x,
112
+ "b w1 w2 p1 p2 c -> b (w1 w2) (p1 p2) c",
113
+ p1=self.window_size,
114
+ p2=self.window_size,
115
+ )
116
+ qkv = self.embedding_layer(x)
117
+ q, k, v = rearrange(
118
+ qkv, "b nw np (threeh c) -> threeh b nw np c", c=self.head_dim
119
+ ).chunk(3, dim=0)
120
+ sim = torch.einsum("hbwpc,hbwqc->hbwpq", q, k) * self.scale
121
+ # Adding learnable relative embedding
122
+ sim = sim + rearrange(self.relative_embedding(), "h p q -> h 1 1 p q")
123
+ # Using Attn Mask to distinguish different subwindows.
124
+ if self.type != "W":
125
+ attn_mask = self.generate_mask(
126
+ h_windows, w_windows, self.window_size, shift=self.window_size // 2
127
+ )
128
+ sim = sim.masked_fill_(attn_mask, float("-inf"))
129
+
130
+ probs = nn.functional.softmax(sim, dim=-1)
131
+ output = torch.einsum("hbwij,hbwjc->hbwic", probs, v)
132
+ output = rearrange(output, "h b w p c -> b w p (h c)")
133
+ output = self.linear(output)
134
+ output = rearrange(
135
+ output,
136
+ "b (w1 w2) (p1 p2) c -> b (w1 p1) (w2 p2) c",
137
+ w1=h_windows,
138
+ p1=self.window_size,
139
+ )
140
+
141
+ if self.type != "W":
142
+ output = torch.roll(
143
+ output,
144
+ shifts=(self.window_size // 2, self.window_size // 2),
145
+ dims=(1, 2),
146
+ )
147
+
148
+ return output
149
+
150
+ def relative_embedding(self):
151
+ cord = torch.tensor(
152
+ np.array(
153
+ [
154
+ [i, j]
155
+ for i in range(self.window_size)
156
+ for j in range(self.window_size)
157
+ ]
158
+ )
159
+ )
160
+ relation = cord[:, None, :] - cord[None, :, :] + self.window_size - 1
161
+ # negative is allowed
162
+ return self.relative_position_params[
163
+ :, relation[:, :, 0].long(), relation[:, :, 1].long()
164
+ ]
165
+
166
+
167
+ class Block(nn.Module):
168
+ def __init__(
169
+ self,
170
+ input_dim,
171
+ output_dim,
172
+ head_dim,
173
+ window_size,
174
+ drop_path,
175
+ type="W",
176
+ input_resolution=None,
177
+ ):
178
+ """SwinTransformer Block"""
179
+ super(Block, self).__init__()
180
+ self.input_dim = input_dim
181
+ self.output_dim = output_dim
182
+ assert type in ["W", "SW"]
183
+ self.type = type
184
+ if input_resolution <= window_size:
185
+ self.type = "W"
186
+
187
+ self.ln1 = nn.LayerNorm(input_dim)
188
+ self.msa = WMSA(input_dim, input_dim, head_dim, window_size, self.type)
189
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
190
+ self.ln2 = nn.LayerNorm(input_dim)
191
+ self.mlp = nn.Sequential(
192
+ nn.Linear(input_dim, 4 * input_dim),
193
+ nn.GELU(),
194
+ nn.Linear(4 * input_dim, output_dim),
195
+ )
196
+
197
+ def forward(self, x):
198
+ x = x + self.drop_path(self.msa(self.ln1(x)))
199
+ x = x + self.drop_path(self.mlp(self.ln2(x)))
200
+ return x
201
+
202
+
203
+ class ConvTransBlock(nn.Module):
204
+ def __init__(
205
+ self,
206
+ conv_dim,
207
+ trans_dim,
208
+ head_dim,
209
+ window_size,
210
+ drop_path,
211
+ type="W",
212
+ input_resolution=None,
213
+ ):
214
+ """SwinTransformer and Conv Block"""
215
+ super(ConvTransBlock, self).__init__()
216
+ self.conv_dim = conv_dim
217
+ self.trans_dim = trans_dim
218
+ self.head_dim = head_dim
219
+ self.window_size = window_size
220
+ self.drop_path = drop_path
221
+ self.type = type
222
+ self.input_resolution = input_resolution
223
+
224
+ assert self.type in ["W", "SW"]
225
+ if self.input_resolution <= self.window_size:
226
+ self.type = "W"
227
+
228
+ self.trans_block = Block(
229
+ self.trans_dim,
230
+ self.trans_dim,
231
+ self.head_dim,
232
+ self.window_size,
233
+ self.drop_path,
234
+ self.type,
235
+ self.input_resolution,
236
+ )
237
+ self.conv1_1 = nn.Conv2d(
238
+ self.conv_dim + self.trans_dim,
239
+ self.conv_dim + self.trans_dim,
240
+ 1,
241
+ 1,
242
+ 0,
243
+ bias=True,
244
+ )
245
+ self.conv1_2 = nn.Conv2d(
246
+ self.conv_dim + self.trans_dim,
247
+ self.conv_dim + self.trans_dim,
248
+ 1,
249
+ 1,
250
+ 0,
251
+ bias=True,
252
+ )
253
+
254
+ self.conv_block = nn.Sequential(
255
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
256
+ nn.ReLU(True),
257
+ nn.Conv2d(self.conv_dim, self.conv_dim, 3, 1, 1, bias=False),
258
+ )
259
+
260
+ def forward(self, x):
261
+ conv_x, trans_x = torch.split(
262
+ self.conv1_1(x), (self.conv_dim, self.trans_dim), dim=1
263
+ )
264
+ conv_x = self.conv_block(conv_x) + conv_x
265
+ trans_x = Rearrange("b c h w -> b h w c")(trans_x)
266
+ trans_x = self.trans_block(trans_x)
267
+ trans_x = Rearrange("b h w c -> b c h w")(trans_x)
268
+ res = self.conv1_2(torch.cat((conv_x, trans_x), dim=1))
269
+ x = x + res
270
+
271
+ return x
272
+
273
+
274
+ class SCUNet(nn.Module):
275
+ def __init__(
276
+ self,
277
+ state_dict,
278
+ in_nc=3,
279
+ config=[4, 4, 4, 4, 4, 4, 4],
280
+ dim=64,
281
+ drop_path_rate=0.0,
282
+ input_resolution=256,
283
+ ):
284
+ super(SCUNet, self).__init__()
285
+ self.model_arch = "SCUNet"
286
+ self.sub_type = "SR"
287
+
288
+ self.num_filters: int = 0
289
+
290
+ self.state = state_dict
291
+ self.config = config
292
+ self.dim = dim
293
+ self.head_dim = 32
294
+ self.window_size = 8
295
+
296
+ self.in_nc = in_nc
297
+ self.out_nc = self.in_nc
298
+ self.scale = 1
299
+ self.supports_fp16 = True
300
+
301
+ # drop path rate for each layer
302
+ dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(config))]
303
+
304
+ self.m_head = [nn.Conv2d(in_nc, dim, 3, 1, 1, bias=False)]
305
+
306
+ begin = 0
307
+ self.m_down1 = [
308
+ ConvTransBlock(
309
+ dim // 2,
310
+ dim // 2,
311
+ self.head_dim,
312
+ self.window_size,
313
+ dpr[i + begin],
314
+ "W" if not i % 2 else "SW",
315
+ input_resolution,
316
+ )
317
+ for i in range(config[0])
318
+ ] + [nn.Conv2d(dim, 2 * dim, 2, 2, 0, bias=False)]
319
+
320
+ begin += config[0]
321
+ self.m_down2 = [
322
+ ConvTransBlock(
323
+ dim,
324
+ dim,
325
+ self.head_dim,
326
+ self.window_size,
327
+ dpr[i + begin],
328
+ "W" if not i % 2 else "SW",
329
+ input_resolution // 2,
330
+ )
331
+ for i in range(config[1])
332
+ ] + [nn.Conv2d(2 * dim, 4 * dim, 2, 2, 0, bias=False)]
333
+
334
+ begin += config[1]
335
+ self.m_down3 = [
336
+ ConvTransBlock(
337
+ 2 * dim,
338
+ 2 * dim,
339
+ self.head_dim,
340
+ self.window_size,
341
+ dpr[i + begin],
342
+ "W" if not i % 2 else "SW",
343
+ input_resolution // 4,
344
+ )
345
+ for i in range(config[2])
346
+ ] + [nn.Conv2d(4 * dim, 8 * dim, 2, 2, 0, bias=False)]
347
+
348
+ begin += config[2]
349
+ self.m_body = [
350
+ ConvTransBlock(
351
+ 4 * dim,
352
+ 4 * dim,
353
+ self.head_dim,
354
+ self.window_size,
355
+ dpr[i + begin],
356
+ "W" if not i % 2 else "SW",
357
+ input_resolution // 8,
358
+ )
359
+ for i in range(config[3])
360
+ ]
361
+
362
+ begin += config[3]
363
+ self.m_up3 = [
364
+ nn.ConvTranspose2d(8 * dim, 4 * dim, 2, 2, 0, bias=False),
365
+ ] + [
366
+ ConvTransBlock(
367
+ 2 * dim,
368
+ 2 * dim,
369
+ self.head_dim,
370
+ self.window_size,
371
+ dpr[i + begin],
372
+ "W" if not i % 2 else "SW",
373
+ input_resolution // 4,
374
+ )
375
+ for i in range(config[4])
376
+ ]
377
+
378
+ begin += config[4]
379
+ self.m_up2 = [
380
+ nn.ConvTranspose2d(4 * dim, 2 * dim, 2, 2, 0, bias=False),
381
+ ] + [
382
+ ConvTransBlock(
383
+ dim,
384
+ dim,
385
+ self.head_dim,
386
+ self.window_size,
387
+ dpr[i + begin],
388
+ "W" if not i % 2 else "SW",
389
+ input_resolution // 2,
390
+ )
391
+ for i in range(config[5])
392
+ ]
393
+
394
+ begin += config[5]
395
+ self.m_up1 = [
396
+ nn.ConvTranspose2d(2 * dim, dim, 2, 2, 0, bias=False),
397
+ ] + [
398
+ ConvTransBlock(
399
+ dim // 2,
400
+ dim // 2,
401
+ self.head_dim,
402
+ self.window_size,
403
+ dpr[i + begin],
404
+ "W" if not i % 2 else "SW",
405
+ input_resolution,
406
+ )
407
+ for i in range(config[6])
408
+ ]
409
+
410
+ self.m_tail = [nn.Conv2d(dim, in_nc, 3, 1, 1, bias=False)]
411
+
412
+ self.m_head = nn.Sequential(*self.m_head)
413
+ self.m_down1 = nn.Sequential(*self.m_down1)
414
+ self.m_down2 = nn.Sequential(*self.m_down2)
415
+ self.m_down3 = nn.Sequential(*self.m_down3)
416
+ self.m_body = nn.Sequential(*self.m_body)
417
+ self.m_up3 = nn.Sequential(*self.m_up3)
418
+ self.m_up2 = nn.Sequential(*self.m_up2)
419
+ self.m_up1 = nn.Sequential(*self.m_up1)
420
+ self.m_tail = nn.Sequential(*self.m_tail)
421
+ # self.apply(self._init_weights)
422
+ self.load_state_dict(state_dict, strict=True)
423
+
424
+ def check_image_size(self, x):
425
+ _, _, h, w = x.size()
426
+ mod_pad_h = (64 - h % 64) % 64
427
+ mod_pad_w = (64 - w % 64) % 64
428
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
429
+ return x
430
+
431
+ def forward(self, x0):
432
+ h, w = x0.size()[-2:]
433
+ x0 = self.check_image_size(x0)
434
+
435
+ x1 = self.m_head(x0)
436
+ x2 = self.m_down1(x1)
437
+ x3 = self.m_down2(x2)
438
+ x4 = self.m_down3(x3)
439
+ x = self.m_body(x4)
440
+ x = self.m_up3(x + x4)
441
+ x = self.m_up2(x + x3)
442
+ x = self.m_up1(x + x2)
443
+ x = self.m_tail(x + x1)
444
+
445
+ x = x[:, :, :h, :w]
446
+ return x
447
+
448
+ def _init_weights(self, m):
449
+ if isinstance(m, nn.Linear):
450
+ trunc_normal_(m.weight, std=0.02)
451
+ if m.bias is not None:
452
+ nn.init.constant_(m.bias, 0)
453
+ elif isinstance(m, nn.LayerNorm):
454
+ nn.init.constant_(m.bias, 0)
455
+ nn.init.constant_(m.weight, 1.0)
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SPSR.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+ from . import block as B
11
+
12
+
13
+ class Get_gradient_nopadding(nn.Module):
14
+ def __init__(self):
15
+ super(Get_gradient_nopadding, self).__init__()
16
+ kernel_v = [[0, -1, 0], [0, 0, 0], [0, 1, 0]]
17
+ kernel_h = [[0, 0, 0], [-1, 0, 1], [0, 0, 0]]
18
+ kernel_h = torch.FloatTensor(kernel_h).unsqueeze(0).unsqueeze(0)
19
+ kernel_v = torch.FloatTensor(kernel_v).unsqueeze(0).unsqueeze(0)
20
+ self.weight_h = nn.Parameter(data=kernel_h, requires_grad=False) # type: ignore
21
+
22
+ self.weight_v = nn.Parameter(data=kernel_v, requires_grad=False) # type: ignore
23
+
24
+ def forward(self, x):
25
+ x_list = []
26
+ for i in range(x.shape[1]):
27
+ x_i = x[:, i]
28
+ x_i_v = F.conv2d(x_i.unsqueeze(1), self.weight_v, padding=1)
29
+ x_i_h = F.conv2d(x_i.unsqueeze(1), self.weight_h, padding=1)
30
+ x_i = torch.sqrt(torch.pow(x_i_v, 2) + torch.pow(x_i_h, 2) + 1e-6)
31
+ x_list.append(x_i)
32
+
33
+ x = torch.cat(x_list, dim=1)
34
+
35
+ return x
36
+
37
+
38
+ class SPSRNet(nn.Module):
39
+ def __init__(
40
+ self,
41
+ state_dict,
42
+ norm=None,
43
+ act: str = "leakyrelu",
44
+ upsampler: str = "upconv",
45
+ mode: B.ConvMode = "CNA",
46
+ ):
47
+ super(SPSRNet, self).__init__()
48
+ self.model_arch = "SPSR"
49
+ self.sub_type = "SR"
50
+
51
+ self.state = state_dict
52
+ self.norm = norm
53
+ self.act = act
54
+ self.upsampler = upsampler
55
+ self.mode = mode
56
+
57
+ self.num_blocks = self.get_num_blocks()
58
+
59
+ self.in_nc: int = self.state["model.0.weight"].shape[1]
60
+ self.out_nc: int = self.state["f_HR_conv1.0.bias"].shape[0]
61
+
62
+ self.scale = self.get_scale(4)
63
+ self.num_filters: int = self.state["model.0.weight"].shape[0]
64
+
65
+ self.supports_fp16 = True
66
+ self.supports_bfp16 = True
67
+ self.min_size_restriction = None
68
+
69
+ n_upscale = int(math.log(self.scale, 2))
70
+ if self.scale == 3:
71
+ n_upscale = 1
72
+
73
+ fea_conv = B.conv_block(
74
+ self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
75
+ )
76
+ rb_blocks = [
77
+ B.RRDB(
78
+ self.num_filters,
79
+ kernel_size=3,
80
+ gc=32,
81
+ stride=1,
82
+ bias=True,
83
+ pad_type="zero",
84
+ norm_type=norm,
85
+ act_type=act,
86
+ mode="CNA",
87
+ )
88
+ for _ in range(self.num_blocks)
89
+ ]
90
+ LR_conv = B.conv_block(
91
+ self.num_filters,
92
+ self.num_filters,
93
+ kernel_size=3,
94
+ norm_type=norm,
95
+ act_type=None,
96
+ mode=mode,
97
+ )
98
+
99
+ if upsampler == "upconv":
100
+ upsample_block = B.upconv_block
101
+ elif upsampler == "pixelshuffle":
102
+ upsample_block = B.pixelshuffle_block
103
+ else:
104
+ raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
105
+ if self.scale == 3:
106
+ a_upsampler = upsample_block(
107
+ self.num_filters, self.num_filters, 3, act_type=act
108
+ )
109
+ else:
110
+ a_upsampler = [
111
+ upsample_block(self.num_filters, self.num_filters, act_type=act)
112
+ for _ in range(n_upscale)
113
+ ]
114
+ self.HR_conv0_new = B.conv_block(
115
+ self.num_filters,
116
+ self.num_filters,
117
+ kernel_size=3,
118
+ norm_type=None,
119
+ act_type=act,
120
+ )
121
+ self.HR_conv1_new = B.conv_block(
122
+ self.num_filters,
123
+ self.num_filters,
124
+ kernel_size=3,
125
+ norm_type=None,
126
+ act_type=None,
127
+ )
128
+
129
+ self.model = B.sequential(
130
+ fea_conv,
131
+ B.ShortcutBlockSPSR(B.sequential(*rb_blocks, LR_conv)),
132
+ *a_upsampler,
133
+ self.HR_conv0_new,
134
+ )
135
+
136
+ self.get_g_nopadding = Get_gradient_nopadding()
137
+
138
+ self.b_fea_conv = B.conv_block(
139
+ self.in_nc, self.num_filters, kernel_size=3, norm_type=None, act_type=None
140
+ )
141
+
142
+ self.b_concat_1 = B.conv_block(
143
+ 2 * self.num_filters,
144
+ self.num_filters,
145
+ kernel_size=3,
146
+ norm_type=None,
147
+ act_type=None,
148
+ )
149
+ self.b_block_1 = B.RRDB(
150
+ self.num_filters * 2,
151
+ kernel_size=3,
152
+ gc=32,
153
+ stride=1,
154
+ bias=True,
155
+ pad_type="zero",
156
+ norm_type=norm,
157
+ act_type=act,
158
+ mode="CNA",
159
+ )
160
+
161
+ self.b_concat_2 = B.conv_block(
162
+ 2 * self.num_filters,
163
+ self.num_filters,
164
+ kernel_size=3,
165
+ norm_type=None,
166
+ act_type=None,
167
+ )
168
+ self.b_block_2 = B.RRDB(
169
+ self.num_filters * 2,
170
+ kernel_size=3,
171
+ gc=32,
172
+ stride=1,
173
+ bias=True,
174
+ pad_type="zero",
175
+ norm_type=norm,
176
+ act_type=act,
177
+ mode="CNA",
178
+ )
179
+
180
+ self.b_concat_3 = B.conv_block(
181
+ 2 * self.num_filters,
182
+ self.num_filters,
183
+ kernel_size=3,
184
+ norm_type=None,
185
+ act_type=None,
186
+ )
187
+ self.b_block_3 = B.RRDB(
188
+ self.num_filters * 2,
189
+ kernel_size=3,
190
+ gc=32,
191
+ stride=1,
192
+ bias=True,
193
+ pad_type="zero",
194
+ norm_type=norm,
195
+ act_type=act,
196
+ mode="CNA",
197
+ )
198
+
199
+ self.b_concat_4 = B.conv_block(
200
+ 2 * self.num_filters,
201
+ self.num_filters,
202
+ kernel_size=3,
203
+ norm_type=None,
204
+ act_type=None,
205
+ )
206
+ self.b_block_4 = B.RRDB(
207
+ self.num_filters * 2,
208
+ kernel_size=3,
209
+ gc=32,
210
+ stride=1,
211
+ bias=True,
212
+ pad_type="zero",
213
+ norm_type=norm,
214
+ act_type=act,
215
+ mode="CNA",
216
+ )
217
+
218
+ self.b_LR_conv = B.conv_block(
219
+ self.num_filters,
220
+ self.num_filters,
221
+ kernel_size=3,
222
+ norm_type=norm,
223
+ act_type=None,
224
+ mode=mode,
225
+ )
226
+
227
+ if upsampler == "upconv":
228
+ upsample_block = B.upconv_block
229
+ elif upsampler == "pixelshuffle":
230
+ upsample_block = B.pixelshuffle_block
231
+ else:
232
+ raise NotImplementedError(f"upsample mode [{upsampler}] is not found")
233
+ if self.scale == 3:
234
+ b_upsampler = upsample_block(
235
+ self.num_filters, self.num_filters, 3, act_type=act
236
+ )
237
+ else:
238
+ b_upsampler = [
239
+ upsample_block(self.num_filters, self.num_filters, act_type=act)
240
+ for _ in range(n_upscale)
241
+ ]
242
+
243
+ b_HR_conv0 = B.conv_block(
244
+ self.num_filters,
245
+ self.num_filters,
246
+ kernel_size=3,
247
+ norm_type=None,
248
+ act_type=act,
249
+ )
250
+ b_HR_conv1 = B.conv_block(
251
+ self.num_filters,
252
+ self.num_filters,
253
+ kernel_size=3,
254
+ norm_type=None,
255
+ act_type=None,
256
+ )
257
+
258
+ self.b_module = B.sequential(*b_upsampler, b_HR_conv0, b_HR_conv1)
259
+
260
+ self.conv_w = B.conv_block(
261
+ self.num_filters, self.out_nc, kernel_size=1, norm_type=None, act_type=None
262
+ )
263
+
264
+ self.f_concat = B.conv_block(
265
+ self.num_filters * 2,
266
+ self.num_filters,
267
+ kernel_size=3,
268
+ norm_type=None,
269
+ act_type=None,
270
+ )
271
+
272
+ self.f_block = B.RRDB(
273
+ self.num_filters * 2,
274
+ kernel_size=3,
275
+ gc=32,
276
+ stride=1,
277
+ bias=True,
278
+ pad_type="zero",
279
+ norm_type=norm,
280
+ act_type=act,
281
+ mode="CNA",
282
+ )
283
+
284
+ self.f_HR_conv0 = B.conv_block(
285
+ self.num_filters,
286
+ self.num_filters,
287
+ kernel_size=3,
288
+ norm_type=None,
289
+ act_type=act,
290
+ )
291
+ self.f_HR_conv1 = B.conv_block(
292
+ self.num_filters, self.out_nc, kernel_size=3, norm_type=None, act_type=None
293
+ )
294
+
295
+ self.load_state_dict(self.state, strict=False)
296
+
297
+ def get_scale(self, min_part: int = 4) -> int:
298
+ n = 0
299
+ for part in list(self.state):
300
+ parts = part.split(".")
301
+ if len(parts) == 3:
302
+ part_num = int(parts[1])
303
+ if part_num > min_part and parts[0] == "model" and parts[2] == "weight":
304
+ n += 1
305
+ return 2**n
306
+
307
+ def get_num_blocks(self) -> int:
308
+ nb = 0
309
+ for part in list(self.state):
310
+ parts = part.split(".")
311
+ n_parts = len(parts)
312
+ if n_parts == 5 and parts[2] == "sub":
313
+ nb = int(parts[3])
314
+ return nb
315
+
316
+ def forward(self, x):
317
+ x_grad = self.get_g_nopadding(x)
318
+ x = self.model[0](x)
319
+
320
+ x, block_list = self.model[1](x)
321
+
322
+ x_ori = x
323
+ for i in range(5):
324
+ x = block_list[i](x)
325
+ x_fea1 = x
326
+
327
+ for i in range(5):
328
+ x = block_list[i + 5](x)
329
+ x_fea2 = x
330
+
331
+ for i in range(5):
332
+ x = block_list[i + 10](x)
333
+ x_fea3 = x
334
+
335
+ for i in range(5):
336
+ x = block_list[i + 15](x)
337
+ x_fea4 = x
338
+
339
+ x = block_list[20:](x)
340
+ # short cut
341
+ x = x_ori + x
342
+ x = self.model[2:](x)
343
+ x = self.HR_conv1_new(x)
344
+
345
+ x_b_fea = self.b_fea_conv(x_grad)
346
+ x_cat_1 = torch.cat([x_b_fea, x_fea1], dim=1)
347
+
348
+ x_cat_1 = self.b_block_1(x_cat_1)
349
+ x_cat_1 = self.b_concat_1(x_cat_1)
350
+
351
+ x_cat_2 = torch.cat([x_cat_1, x_fea2], dim=1)
352
+
353
+ x_cat_2 = self.b_block_2(x_cat_2)
354
+ x_cat_2 = self.b_concat_2(x_cat_2)
355
+
356
+ x_cat_3 = torch.cat([x_cat_2, x_fea3], dim=1)
357
+
358
+ x_cat_3 = self.b_block_3(x_cat_3)
359
+ x_cat_3 = self.b_concat_3(x_cat_3)
360
+
361
+ x_cat_4 = torch.cat([x_cat_3, x_fea4], dim=1)
362
+
363
+ x_cat_4 = self.b_block_4(x_cat_4)
364
+ x_cat_4 = self.b_concat_4(x_cat_4)
365
+
366
+ x_cat_4 = self.b_LR_conv(x_cat_4)
367
+
368
+ # short cut
369
+ x_cat_4 = x_cat_4 + x_b_fea
370
+ x_branch = self.b_module(x_cat_4)
371
+
372
+ # x_out_branch = self.conv_w(x_branch)
373
+ ########
374
+ x_branch_d = x_branch
375
+ x_f_cat = torch.cat([x_branch_d, x], dim=1)
376
+ x_f_cat = self.f_block(x_f_cat)
377
+ x_out = self.f_concat(x_f_cat)
378
+ x_out = self.f_HR_conv0(x_out)
379
+ x_out = self.f_HR_conv1(x_out)
380
+
381
+ #########
382
+ # return x_out_branch, x_out, x_grad
383
+ return x_out
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SRVGG.py ADDED
@@ -0,0 +1,114 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import math
5
+
6
+ import torch.nn as nn
7
+ import torch.nn.functional as F
8
+
9
+
10
+ class SRVGGNetCompact(nn.Module):
11
+ """A compact VGG-style network structure for super-resolution.
12
+ It is a compact network structure, which performs upsampling in the last layer and no convolution is
13
+ conducted on the HR feature space.
14
+ Args:
15
+ num_in_ch (int): Channel number of inputs. Default: 3.
16
+ num_out_ch (int): Channel number of outputs. Default: 3.
17
+ num_feat (int): Channel number of intermediate features. Default: 64.
18
+ num_conv (int): Number of convolution layers in the body network. Default: 16.
19
+ upscale (int): Upsampling factor. Default: 4.
20
+ act_type (str): Activation type, options: 'relu', 'prelu', 'leakyrelu'. Default: prelu.
21
+ """
22
+
23
+ def __init__(
24
+ self,
25
+ state_dict,
26
+ act_type: str = "prelu",
27
+ ):
28
+ super(SRVGGNetCompact, self).__init__()
29
+ self.model_arch = "SRVGG (RealESRGAN)"
30
+ self.sub_type = "SR"
31
+
32
+ self.act_type = act_type
33
+
34
+ self.state = state_dict
35
+
36
+ if "params" in self.state:
37
+ self.state = self.state["params"]
38
+
39
+ self.key_arr = list(self.state.keys())
40
+
41
+ self.in_nc = self.get_in_nc()
42
+ self.num_feat = self.get_num_feats()
43
+ self.num_conv = self.get_num_conv()
44
+ self.out_nc = self.in_nc # :(
45
+ self.pixelshuffle_shape = None # Defined in get_scale()
46
+ self.scale = self.get_scale()
47
+
48
+ self.supports_fp16 = True
49
+ self.supports_bfp16 = True
50
+ self.min_size_restriction = None
51
+
52
+ self.body = nn.ModuleList()
53
+ # the first conv
54
+ self.body.append(nn.Conv2d(self.in_nc, self.num_feat, 3, 1, 1))
55
+ # the first activation
56
+ if act_type == "relu":
57
+ activation = nn.ReLU(inplace=True)
58
+ elif act_type == "prelu":
59
+ activation = nn.PReLU(num_parameters=self.num_feat)
60
+ elif act_type == "leakyrelu":
61
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
62
+ self.body.append(activation) # type: ignore
63
+
64
+ # the body structure
65
+ for _ in range(self.num_conv):
66
+ self.body.append(nn.Conv2d(self.num_feat, self.num_feat, 3, 1, 1))
67
+ # activation
68
+ if act_type == "relu":
69
+ activation = nn.ReLU(inplace=True)
70
+ elif act_type == "prelu":
71
+ activation = nn.PReLU(num_parameters=self.num_feat)
72
+ elif act_type == "leakyrelu":
73
+ activation = nn.LeakyReLU(negative_slope=0.1, inplace=True)
74
+ self.body.append(activation) # type: ignore
75
+
76
+ # the last conv
77
+ self.body.append(nn.Conv2d(self.num_feat, self.pixelshuffle_shape, 3, 1, 1)) # type: ignore
78
+ # upsample
79
+ self.upsampler = nn.PixelShuffle(self.scale)
80
+
81
+ self.load_state_dict(self.state, strict=False)
82
+
83
+ def get_num_conv(self) -> int:
84
+ return (int(self.key_arr[-1].split(".")[1]) - 2) // 2
85
+
86
+ def get_num_feats(self) -> int:
87
+ return self.state[self.key_arr[0]].shape[0]
88
+
89
+ def get_in_nc(self) -> int:
90
+ return self.state[self.key_arr[0]].shape[1]
91
+
92
+ def get_scale(self) -> int:
93
+ self.pixelshuffle_shape = self.state[self.key_arr[-1]].shape[0]
94
+ # Assume out_nc is the same as in_nc
95
+ # I cant think of a better way to do that
96
+ self.out_nc = self.in_nc
97
+ scale = math.sqrt(self.pixelshuffle_shape / self.out_nc)
98
+ if scale - int(scale) > 0:
99
+ print(
100
+ "out_nc is probably different than in_nc, scale calculation might be wrong"
101
+ )
102
+ scale = int(scale)
103
+ return scale
104
+
105
+ def forward(self, x):
106
+ out = x
107
+ for i in range(0, len(self.body)):
108
+ out = self.body[i](out)
109
+
110
+ out = self.upsampler(out)
111
+ # add the nearest upsampled image, so that the network learns the residual
112
+ base = F.interpolate(x, scale_factor=self.scale, mode="nearest")
113
+ out += base
114
+ return out
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SwiftSRGAN.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # From https://github.com/Koushik0901/Swift-SRGAN/blob/master/swift-srgan/models.py
2
+
3
+ import torch
4
+ from torch import nn
5
+
6
+
7
+ class SeperableConv2d(nn.Module):
8
+ def __init__(
9
+ self, in_channels, out_channels, kernel_size, stride=1, padding=1, bias=True
10
+ ):
11
+ super(SeperableConv2d, self).__init__()
12
+ self.depthwise = nn.Conv2d(
13
+ in_channels,
14
+ in_channels,
15
+ kernel_size=kernel_size,
16
+ stride=stride,
17
+ groups=in_channels,
18
+ bias=bias,
19
+ padding=padding,
20
+ )
21
+ self.pointwise = nn.Conv2d(in_channels, out_channels, kernel_size=1, bias=bias)
22
+
23
+ def forward(self, x):
24
+ return self.pointwise(self.depthwise(x))
25
+
26
+
27
+ class ConvBlock(nn.Module):
28
+ def __init__(
29
+ self,
30
+ in_channels,
31
+ out_channels,
32
+ use_act=True,
33
+ use_bn=True,
34
+ discriminator=False,
35
+ **kwargs,
36
+ ):
37
+ super(ConvBlock, self).__init__()
38
+
39
+ self.use_act = use_act
40
+ self.cnn = SeperableConv2d(in_channels, out_channels, **kwargs, bias=not use_bn)
41
+ self.bn = nn.BatchNorm2d(out_channels) if use_bn else nn.Identity()
42
+ self.act = (
43
+ nn.LeakyReLU(0.2, inplace=True)
44
+ if discriminator
45
+ else nn.PReLU(num_parameters=out_channels)
46
+ )
47
+
48
+ def forward(self, x):
49
+ return self.act(self.bn(self.cnn(x))) if self.use_act else self.bn(self.cnn(x))
50
+
51
+
52
+ class UpsampleBlock(nn.Module):
53
+ def __init__(self, in_channels, scale_factor):
54
+ super(UpsampleBlock, self).__init__()
55
+
56
+ self.conv = SeperableConv2d(
57
+ in_channels,
58
+ in_channels * scale_factor**2,
59
+ kernel_size=3,
60
+ stride=1,
61
+ padding=1,
62
+ )
63
+ self.ps = nn.PixelShuffle(
64
+ scale_factor
65
+ ) # (in_channels * 4, H, W) -> (in_channels, H*2, W*2)
66
+ self.act = nn.PReLU(num_parameters=in_channels)
67
+
68
+ def forward(self, x):
69
+ return self.act(self.ps(self.conv(x)))
70
+
71
+
72
+ class ResidualBlock(nn.Module):
73
+ def __init__(self, in_channels):
74
+ super(ResidualBlock, self).__init__()
75
+
76
+ self.block1 = ConvBlock(
77
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
78
+ )
79
+ self.block2 = ConvBlock(
80
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1, use_act=False
81
+ )
82
+
83
+ def forward(self, x):
84
+ out = self.block1(x)
85
+ out = self.block2(out)
86
+ return out + x
87
+
88
+
89
+ class Generator(nn.Module):
90
+ """Swift-SRGAN Generator
91
+ Args:
92
+ in_channels (int): number of input image channels.
93
+ num_channels (int): number of hidden channels.
94
+ num_blocks (int): number of residual blocks.
95
+ upscale_factor (int): factor to upscale the image [2x, 4x, 8x].
96
+ Returns:
97
+ torch.Tensor: super resolution image
98
+ """
99
+
100
+ def __init__(
101
+ self,
102
+ state_dict,
103
+ ):
104
+ super(Generator, self).__init__()
105
+ self.model_arch = "Swift-SRGAN"
106
+ self.sub_type = "SR"
107
+ self.state = state_dict
108
+ if "model" in self.state:
109
+ self.state = self.state["model"]
110
+
111
+ self.in_nc: int = self.state["initial.cnn.depthwise.weight"].shape[0]
112
+ self.out_nc: int = self.state["final_conv.pointwise.weight"].shape[0]
113
+ self.num_filters: int = self.state["initial.cnn.pointwise.weight"].shape[0]
114
+ self.num_blocks = len(
115
+ set([x.split(".")[1] for x in self.state.keys() if "residual" in x])
116
+ )
117
+ self.scale: int = 2 ** len(
118
+ set([x.split(".")[1] for x in self.state.keys() if "upsampler" in x])
119
+ )
120
+
121
+ in_channels = self.in_nc
122
+ num_channels = self.num_filters
123
+ num_blocks = self.num_blocks
124
+ upscale_factor = self.scale
125
+
126
+ self.supports_fp16 = True
127
+ self.supports_bfp16 = True
128
+ self.min_size_restriction = None
129
+
130
+ self.initial = ConvBlock(
131
+ in_channels, num_channels, kernel_size=9, stride=1, padding=4, use_bn=False
132
+ )
133
+ self.residual = nn.Sequential(
134
+ *[ResidualBlock(num_channels) for _ in range(num_blocks)]
135
+ )
136
+ self.convblock = ConvBlock(
137
+ num_channels,
138
+ num_channels,
139
+ kernel_size=3,
140
+ stride=1,
141
+ padding=1,
142
+ use_act=False,
143
+ )
144
+ self.upsampler = nn.Sequential(
145
+ *[
146
+ UpsampleBlock(num_channels, scale_factor=2)
147
+ for _ in range(upscale_factor // 2)
148
+ ]
149
+ )
150
+ self.final_conv = SeperableConv2d(
151
+ num_channels, in_channels, kernel_size=9, stride=1, padding=4
152
+ )
153
+
154
+ self.load_state_dict(self.state, strict=False)
155
+
156
+ def forward(self, x):
157
+ initial = self.initial(x)
158
+ x = self.residual(initial)
159
+ x = self.convblock(x) + initial
160
+ x = self.upsampler(x)
161
+ return (torch.tanh(self.final_conv(x)) + 1) / 2
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/Swin2SR.py ADDED
@@ -0,0 +1,1377 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # -----------------------------------------------------------------------------------
3
+ # Swin2SR: Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration, https://arxiv.org/abs/2209.11345
4
+ # Written by Conde and Choi et al.
5
+ # From: https://raw.githubusercontent.com/mv-lab/swin2sr/main/models/network_swin2sr.py
6
+ # -----------------------------------------------------------------------------------
7
+
8
+ import math
9
+ import re
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.utils.checkpoint as checkpoint
16
+
17
+ # Originally from the timm package
18
+ from .timm.drop import DropPath
19
+ from .timm.helpers import to_2tuple
20
+ from .timm.weight_init import trunc_normal_
21
+
22
+
23
+ class Mlp(nn.Module):
24
+ def __init__(
25
+ self,
26
+ in_features,
27
+ hidden_features=None,
28
+ out_features=None,
29
+ act_layer=nn.GELU,
30
+ drop=0.0,
31
+ ):
32
+ super().__init__()
33
+ out_features = out_features or in_features
34
+ hidden_features = hidden_features or in_features
35
+ self.fc1 = nn.Linear(in_features, hidden_features)
36
+ self.act = act_layer()
37
+ self.fc2 = nn.Linear(hidden_features, out_features)
38
+ self.drop = nn.Dropout(drop)
39
+
40
+ def forward(self, x):
41
+ x = self.fc1(x)
42
+ x = self.act(x)
43
+ x = self.drop(x)
44
+ x = self.fc2(x)
45
+ x = self.drop(x)
46
+ return x
47
+
48
+
49
+ def window_partition(x, window_size):
50
+ """
51
+ Args:
52
+ x: (B, H, W, C)
53
+ window_size (int): window size
54
+ Returns:
55
+ windows: (num_windows*B, window_size, window_size, C)
56
+ """
57
+ B, H, W, C = x.shape
58
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
59
+ windows = (
60
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
61
+ )
62
+ return windows
63
+
64
+
65
+ def window_reverse(windows, window_size, H, W):
66
+ """
67
+ Args:
68
+ windows: (num_windows*B, window_size, window_size, C)
69
+ window_size (int): Window size
70
+ H (int): Height of image
71
+ W (int): Width of image
72
+ Returns:
73
+ x: (B, H, W, C)
74
+ """
75
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
76
+ x = windows.view(
77
+ B, H // window_size, W // window_size, window_size, window_size, -1
78
+ )
79
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
80
+ return x
81
+
82
+
83
+ class WindowAttention(nn.Module):
84
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
85
+ It supports both of shifted and non-shifted window.
86
+ Args:
87
+ dim (int): Number of input channels.
88
+ window_size (tuple[int]): The height and width of the window.
89
+ num_heads (int): Number of attention heads.
90
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
91
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
92
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
93
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
94
+ """
95
+
96
+ def __init__(
97
+ self,
98
+ dim,
99
+ window_size,
100
+ num_heads,
101
+ qkv_bias=True,
102
+ attn_drop=0.0,
103
+ proj_drop=0.0,
104
+ pretrained_window_size=[0, 0],
105
+ ):
106
+ super().__init__()
107
+ self.dim = dim
108
+ self.window_size = window_size # Wh, Ww
109
+ self.pretrained_window_size = pretrained_window_size
110
+ self.num_heads = num_heads
111
+
112
+ self.logit_scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True) # type: ignore
113
+
114
+ # mlp to generate continuous relative position bias
115
+ self.cpb_mlp = nn.Sequential(
116
+ nn.Linear(2, 512, bias=True),
117
+ nn.ReLU(inplace=True),
118
+ nn.Linear(512, num_heads, bias=False),
119
+ )
120
+
121
+ # get relative_coords_table
122
+ relative_coords_h = torch.arange(
123
+ -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
124
+ )
125
+ relative_coords_w = torch.arange(
126
+ -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
127
+ )
128
+ relative_coords_table = (
129
+ torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
130
+ .permute(1, 2, 0)
131
+ .contiguous()
132
+ .unsqueeze(0)
133
+ ) # 1, 2*Wh-1, 2*Ww-1, 2
134
+ if pretrained_window_size[0] > 0:
135
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
136
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
137
+ else:
138
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
139
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
140
+ relative_coords_table *= 8 # normalize to -8, 8
141
+ relative_coords_table = (
142
+ torch.sign(relative_coords_table)
143
+ * torch.log2(torch.abs(relative_coords_table) + 1.0)
144
+ / np.log2(8)
145
+ )
146
+
147
+ self.register_buffer("relative_coords_table", relative_coords_table)
148
+
149
+ # get pair-wise relative position index for each token inside the window
150
+ coords_h = torch.arange(self.window_size[0])
151
+ coords_w = torch.arange(self.window_size[1])
152
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
153
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
154
+ relative_coords = (
155
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
156
+ ) # 2, Wh*Ww, Wh*Ww
157
+ relative_coords = relative_coords.permute(
158
+ 1, 2, 0
159
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
160
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
161
+ relative_coords[:, :, 1] += self.window_size[1] - 1
162
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
163
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
164
+ self.register_buffer("relative_position_index", relative_position_index)
165
+
166
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
167
+ if qkv_bias:
168
+ self.q_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
169
+ self.v_bias = nn.Parameter(torch.zeros(dim)) # type: ignore
170
+ else:
171
+ self.q_bias = None
172
+ self.v_bias = None
173
+ self.attn_drop = nn.Dropout(attn_drop)
174
+ self.proj = nn.Linear(dim, dim)
175
+ self.proj_drop = nn.Dropout(proj_drop)
176
+ self.softmax = nn.Softmax(dim=-1)
177
+
178
+ def forward(self, x, mask=None):
179
+ """
180
+ Args:
181
+ x: input features with shape of (num_windows*B, N, C)
182
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
183
+ """
184
+ B_, N, C = x.shape
185
+ qkv_bias = None
186
+ if self.q_bias is not None:
187
+ qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias)) # type: ignore
188
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
189
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
190
+ q, k, v = (
191
+ qkv[0],
192
+ qkv[1],
193
+ qkv[2],
194
+ ) # make torchscript happy (cannot use tensor as tuple)
195
+
196
+ # cosine attention
197
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
198
+ logit_scale = torch.clamp(
199
+ self.logit_scale,
200
+ max=torch.log(torch.tensor(1.0 / 0.01)).to(self.logit_scale.device),
201
+ ).exp()
202
+ attn = attn * logit_scale
203
+
204
+ relative_position_bias_table = self.cpb_mlp(self.relative_coords_table).view(
205
+ -1, self.num_heads
206
+ )
207
+ relative_position_bias = relative_position_bias_table[self.relative_position_index.view(-1)].view( # type: ignore
208
+ self.window_size[0] * self.window_size[1],
209
+ self.window_size[0] * self.window_size[1],
210
+ -1,
211
+ ) # Wh*Ww,Wh*Ww,nH
212
+ relative_position_bias = relative_position_bias.permute(
213
+ 2, 0, 1
214
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
215
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
216
+ attn = attn + relative_position_bias.unsqueeze(0)
217
+
218
+ if mask is not None:
219
+ nW = mask.shape[0]
220
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
221
+ 1
222
+ ).unsqueeze(0)
223
+ attn = attn.view(-1, self.num_heads, N, N)
224
+ attn = self.softmax(attn)
225
+ else:
226
+ attn = self.softmax(attn)
227
+
228
+ attn = self.attn_drop(attn)
229
+
230
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
231
+ x = self.proj(x)
232
+ x = self.proj_drop(x)
233
+ return x
234
+
235
+ def extra_repr(self) -> str:
236
+ return (
237
+ f"dim={self.dim}, window_size={self.window_size}, "
238
+ f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
239
+ )
240
+
241
+ def flops(self, N):
242
+ # calculate flops for 1 window with token length of N
243
+ flops = 0
244
+ # qkv = self.qkv(x)
245
+ flops += N * self.dim * 3 * self.dim
246
+ # attn = (q @ k.transpose(-2, -1))
247
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
248
+ # x = (attn @ v)
249
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
250
+ # x = self.proj(x)
251
+ flops += N * self.dim * self.dim
252
+ return flops
253
+
254
+
255
+ class SwinTransformerBlock(nn.Module):
256
+ r"""Swin Transformer Block.
257
+ Args:
258
+ dim (int): Number of input channels.
259
+ input_resolution (tuple[int]): Input resulotion.
260
+ num_heads (int): Number of attention heads.
261
+ window_size (int): Window size.
262
+ shift_size (int): Shift size for SW-MSA.
263
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
264
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
265
+ drop (float, optional): Dropout rate. Default: 0.0
266
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
267
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
268
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
269
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
270
+ pretrained_window_size (int): Window size in pre-training.
271
+ """
272
+
273
+ def __init__(
274
+ self,
275
+ dim,
276
+ input_resolution,
277
+ num_heads,
278
+ window_size=7,
279
+ shift_size=0,
280
+ mlp_ratio=4.0,
281
+ qkv_bias=True,
282
+ drop=0.0,
283
+ attn_drop=0.0,
284
+ drop_path=0.0,
285
+ act_layer=nn.GELU,
286
+ norm_layer=nn.LayerNorm,
287
+ pretrained_window_size=0,
288
+ ):
289
+ super().__init__()
290
+ self.dim = dim
291
+ self.input_resolution = input_resolution
292
+ self.num_heads = num_heads
293
+ self.window_size = window_size
294
+ self.shift_size = shift_size
295
+ self.mlp_ratio = mlp_ratio
296
+ if min(self.input_resolution) <= self.window_size:
297
+ # if window size is larger than input resolution, we don't partition windows
298
+ self.shift_size = 0
299
+ self.window_size = min(self.input_resolution)
300
+ assert (
301
+ 0 <= self.shift_size < self.window_size
302
+ ), "shift_size must in 0-window_size"
303
+
304
+ self.norm1 = norm_layer(dim)
305
+ self.attn = WindowAttention(
306
+ dim,
307
+ window_size=to_2tuple(self.window_size),
308
+ num_heads=num_heads,
309
+ qkv_bias=qkv_bias,
310
+ attn_drop=attn_drop,
311
+ proj_drop=drop,
312
+ pretrained_window_size=to_2tuple(pretrained_window_size),
313
+ )
314
+
315
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
316
+ self.norm2 = norm_layer(dim)
317
+ mlp_hidden_dim = int(dim * mlp_ratio)
318
+ self.mlp = Mlp(
319
+ in_features=dim,
320
+ hidden_features=mlp_hidden_dim,
321
+ act_layer=act_layer,
322
+ drop=drop,
323
+ )
324
+
325
+ if self.shift_size > 0:
326
+ attn_mask = self.calculate_mask(self.input_resolution)
327
+ else:
328
+ attn_mask = None
329
+
330
+ self.register_buffer("attn_mask", attn_mask)
331
+
332
+ def calculate_mask(self, x_size):
333
+ # calculate attention mask for SW-MSA
334
+ H, W = x_size
335
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
336
+ h_slices = (
337
+ slice(0, -self.window_size),
338
+ slice(-self.window_size, -self.shift_size),
339
+ slice(-self.shift_size, None),
340
+ )
341
+ w_slices = (
342
+ slice(0, -self.window_size),
343
+ slice(-self.window_size, -self.shift_size),
344
+ slice(-self.shift_size, None),
345
+ )
346
+ cnt = 0
347
+ for h in h_slices:
348
+ for w in w_slices:
349
+ img_mask[:, h, w, :] = cnt
350
+ cnt += 1
351
+
352
+ mask_windows = window_partition(
353
+ img_mask, self.window_size
354
+ ) # nW, window_size, window_size, 1
355
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
356
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
357
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
358
+ attn_mask == 0, float(0.0)
359
+ )
360
+
361
+ return attn_mask
362
+
363
+ def forward(self, x, x_size):
364
+ H, W = x_size
365
+ B, L, C = x.shape
366
+ # assert L == H * W, "input feature has wrong size"
367
+
368
+ shortcut = x
369
+ x = x.view(B, H, W, C)
370
+
371
+ # cyclic shift
372
+ if self.shift_size > 0:
373
+ shifted_x = torch.roll(
374
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
375
+ )
376
+ else:
377
+ shifted_x = x
378
+
379
+ # partition windows
380
+ x_windows = window_partition(
381
+ shifted_x, self.window_size
382
+ ) # nW*B, window_size, window_size, C
383
+ x_windows = x_windows.view(
384
+ -1, self.window_size * self.window_size, C
385
+ ) # nW*B, window_size*window_size, C
386
+
387
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
388
+ if self.input_resolution == x_size:
389
+ attn_windows = self.attn(
390
+ x_windows, mask=self.attn_mask
391
+ ) # nW*B, window_size*window_size, C
392
+ else:
393
+ attn_windows = self.attn(
394
+ x_windows, mask=self.calculate_mask(x_size).to(x.device)
395
+ )
396
+
397
+ # merge windows
398
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
399
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
400
+
401
+ # reverse cyclic shift
402
+ if self.shift_size > 0:
403
+ x = torch.roll(
404
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
405
+ )
406
+ else:
407
+ x = shifted_x
408
+ x = x.view(B, H * W, C)
409
+ x = shortcut + self.drop_path(self.norm1(x))
410
+
411
+ # FFN
412
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
413
+
414
+ return x
415
+
416
+ def extra_repr(self) -> str:
417
+ return (
418
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
419
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
420
+ )
421
+
422
+ def flops(self):
423
+ flops = 0
424
+ H, W = self.input_resolution
425
+ # norm1
426
+ flops += self.dim * H * W
427
+ # W-MSA/SW-MSA
428
+ nW = H * W / self.window_size / self.window_size
429
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
430
+ # mlp
431
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
432
+ # norm2
433
+ flops += self.dim * H * W
434
+ return flops
435
+
436
+
437
+ class PatchMerging(nn.Module):
438
+ r"""Patch Merging Layer.
439
+ Args:
440
+ input_resolution (tuple[int]): Resolution of input feature.
441
+ dim (int): Number of input channels.
442
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
443
+ """
444
+
445
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
446
+ super().__init__()
447
+ self.input_resolution = input_resolution
448
+ self.dim = dim
449
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
450
+ self.norm = norm_layer(2 * dim)
451
+
452
+ def forward(self, x):
453
+ """
454
+ x: B, H*W, C
455
+ """
456
+ H, W = self.input_resolution
457
+ B, L, C = x.shape
458
+ assert L == H * W, "input feature has wrong size"
459
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
460
+
461
+ x = x.view(B, H, W, C)
462
+
463
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
464
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
465
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
466
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
467
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
468
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
469
+
470
+ x = self.reduction(x)
471
+ x = self.norm(x)
472
+
473
+ return x
474
+
475
+ def extra_repr(self) -> str:
476
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
477
+
478
+ def flops(self):
479
+ H, W = self.input_resolution
480
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
481
+ flops += H * W * self.dim // 2
482
+ return flops
483
+
484
+
485
+ class BasicLayer(nn.Module):
486
+ """A basic Swin Transformer layer for one stage.
487
+ Args:
488
+ dim (int): Number of input channels.
489
+ input_resolution (tuple[int]): Input resolution.
490
+ depth (int): Number of blocks.
491
+ num_heads (int): Number of attention heads.
492
+ window_size (int): Local window size.
493
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
494
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
495
+ drop (float, optional): Dropout rate. Default: 0.0
496
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
497
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
498
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
499
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
500
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
501
+ pretrained_window_size (int): Local window size in pre-training.
502
+ """
503
+
504
+ def __init__(
505
+ self,
506
+ dim,
507
+ input_resolution,
508
+ depth,
509
+ num_heads,
510
+ window_size,
511
+ mlp_ratio=4.0,
512
+ qkv_bias=True,
513
+ drop=0.0,
514
+ attn_drop=0.0,
515
+ drop_path=0.0,
516
+ norm_layer=nn.LayerNorm,
517
+ downsample=None,
518
+ use_checkpoint=False,
519
+ pretrained_window_size=0,
520
+ ):
521
+ super().__init__()
522
+ self.dim = dim
523
+ self.input_resolution = input_resolution
524
+ self.depth = depth
525
+ self.use_checkpoint = use_checkpoint
526
+
527
+ # build blocks
528
+ self.blocks = nn.ModuleList(
529
+ [
530
+ SwinTransformerBlock(
531
+ dim=dim,
532
+ input_resolution=input_resolution,
533
+ num_heads=num_heads,
534
+ window_size=window_size,
535
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
536
+ mlp_ratio=mlp_ratio,
537
+ qkv_bias=qkv_bias,
538
+ drop=drop,
539
+ attn_drop=attn_drop,
540
+ drop_path=drop_path[i]
541
+ if isinstance(drop_path, list)
542
+ else drop_path,
543
+ norm_layer=norm_layer,
544
+ pretrained_window_size=pretrained_window_size,
545
+ )
546
+ for i in range(depth)
547
+ ]
548
+ )
549
+
550
+ # patch merging layer
551
+ if downsample is not None:
552
+ self.downsample = downsample(
553
+ input_resolution, dim=dim, norm_layer=norm_layer
554
+ )
555
+ else:
556
+ self.downsample = None
557
+
558
+ def forward(self, x, x_size):
559
+ for blk in self.blocks:
560
+ if self.use_checkpoint:
561
+ x = checkpoint.checkpoint(blk, x, x_size)
562
+ else:
563
+ x = blk(x, x_size)
564
+ if self.downsample is not None:
565
+ x = self.downsample(x)
566
+ return x
567
+
568
+ def extra_repr(self) -> str:
569
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
570
+
571
+ def flops(self):
572
+ flops = 0
573
+ for blk in self.blocks:
574
+ flops += blk.flops() # type: ignore
575
+ if self.downsample is not None:
576
+ flops += self.downsample.flops()
577
+ return flops
578
+
579
+ def _init_respostnorm(self):
580
+ for blk in self.blocks:
581
+ nn.init.constant_(blk.norm1.bias, 0) # type: ignore
582
+ nn.init.constant_(blk.norm1.weight, 0) # type: ignore
583
+ nn.init.constant_(blk.norm2.bias, 0) # type: ignore
584
+ nn.init.constant_(blk.norm2.weight, 0) # type: ignore
585
+
586
+
587
+ class PatchEmbed(nn.Module):
588
+ r"""Image to Patch Embedding
589
+ Args:
590
+ img_size (int): Image size. Default: 224.
591
+ patch_size (int): Patch token size. Default: 4.
592
+ in_chans (int): Number of input image channels. Default: 3.
593
+ embed_dim (int): Number of linear projection output channels. Default: 96.
594
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
595
+ """
596
+
597
+ def __init__(
598
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
599
+ ):
600
+ super().__init__()
601
+ img_size = to_2tuple(img_size)
602
+ patch_size = to_2tuple(patch_size)
603
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
604
+ self.img_size = img_size
605
+ self.patch_size = patch_size
606
+ self.patches_resolution = patches_resolution
607
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
608
+
609
+ self.in_chans = in_chans
610
+ self.embed_dim = embed_dim
611
+
612
+ self.proj = nn.Conv2d(
613
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size # type: ignore
614
+ )
615
+ if norm_layer is not None:
616
+ self.norm = norm_layer(embed_dim)
617
+ else:
618
+ self.norm = None
619
+
620
+ def forward(self, x):
621
+ B, C, H, W = x.shape
622
+ # FIXME look at relaxing size constraints
623
+ # assert H == self.img_size[0] and W == self.img_size[1],
624
+ # f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
625
+ x = self.proj(x).flatten(2).transpose(1, 2) # B Ph*Pw C
626
+ if self.norm is not None:
627
+ x = self.norm(x)
628
+ return x
629
+
630
+ def flops(self):
631
+ Ho, Wo = self.patches_resolution
632
+ flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) # type: ignore
633
+ if self.norm is not None:
634
+ flops += Ho * Wo * self.embed_dim
635
+ return flops
636
+
637
+
638
+ class RSTB(nn.Module):
639
+ """Residual Swin Transformer Block (RSTB).
640
+
641
+ Args:
642
+ dim (int): Number of input channels.
643
+ input_resolution (tuple[int]): Input resolution.
644
+ depth (int): Number of blocks.
645
+ num_heads (int): Number of attention heads.
646
+ window_size (int): Local window size.
647
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
648
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
649
+ drop (float, optional): Dropout rate. Default: 0.0
650
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
651
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
652
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
653
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
654
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
655
+ img_size: Input image size.
656
+ patch_size: Patch size.
657
+ resi_connection: The convolutional block before residual connection.
658
+ """
659
+
660
+ def __init__(
661
+ self,
662
+ dim,
663
+ input_resolution,
664
+ depth,
665
+ num_heads,
666
+ window_size,
667
+ mlp_ratio=4.0,
668
+ qkv_bias=True,
669
+ drop=0.0,
670
+ attn_drop=0.0,
671
+ drop_path=0.0,
672
+ norm_layer=nn.LayerNorm,
673
+ downsample=None,
674
+ use_checkpoint=False,
675
+ img_size=224,
676
+ patch_size=4,
677
+ resi_connection="1conv",
678
+ ):
679
+ super(RSTB, self).__init__()
680
+
681
+ self.dim = dim
682
+ self.input_resolution = input_resolution
683
+
684
+ self.residual_group = BasicLayer(
685
+ dim=dim,
686
+ input_resolution=input_resolution,
687
+ depth=depth,
688
+ num_heads=num_heads,
689
+ window_size=window_size,
690
+ mlp_ratio=mlp_ratio,
691
+ qkv_bias=qkv_bias,
692
+ drop=drop,
693
+ attn_drop=attn_drop,
694
+ drop_path=drop_path,
695
+ norm_layer=norm_layer,
696
+ downsample=downsample,
697
+ use_checkpoint=use_checkpoint,
698
+ )
699
+
700
+ if resi_connection == "1conv":
701
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
702
+ elif resi_connection == "3conv":
703
+ # to save parameters and memory
704
+ self.conv = nn.Sequential(
705
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
706
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
707
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
708
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
709
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
710
+ )
711
+
712
+ self.patch_embed = PatchEmbed(
713
+ img_size=img_size,
714
+ patch_size=patch_size,
715
+ in_chans=dim,
716
+ embed_dim=dim,
717
+ norm_layer=None,
718
+ )
719
+
720
+ self.patch_unembed = PatchUnEmbed(
721
+ img_size=img_size,
722
+ patch_size=patch_size,
723
+ in_chans=dim,
724
+ embed_dim=dim,
725
+ norm_layer=None,
726
+ )
727
+
728
+ def forward(self, x, x_size):
729
+ return (
730
+ self.patch_embed(
731
+ self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
732
+ )
733
+ + x
734
+ )
735
+
736
+ def flops(self):
737
+ flops = 0
738
+ flops += self.residual_group.flops()
739
+ H, W = self.input_resolution
740
+ flops += H * W * self.dim * self.dim * 9
741
+ flops += self.patch_embed.flops()
742
+ flops += self.patch_unembed.flops()
743
+
744
+ return flops
745
+
746
+
747
+ class PatchUnEmbed(nn.Module):
748
+ r"""Image to Patch Unembedding
749
+
750
+ Args:
751
+ img_size (int): Image size. Default: 224.
752
+ patch_size (int): Patch token size. Default: 4.
753
+ in_chans (int): Number of input image channels. Default: 3.
754
+ embed_dim (int): Number of linear projection output channels. Default: 96.
755
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
756
+ """
757
+
758
+ def __init__(
759
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
760
+ ):
761
+ super().__init__()
762
+ img_size = to_2tuple(img_size)
763
+ patch_size = to_2tuple(patch_size)
764
+ patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]] # type: ignore
765
+ self.img_size = img_size
766
+ self.patch_size = patch_size
767
+ self.patches_resolution = patches_resolution
768
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
769
+
770
+ self.in_chans = in_chans
771
+ self.embed_dim = embed_dim
772
+
773
+ def forward(self, x, x_size):
774
+ B, HW, C = x.shape
775
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
776
+ return x
777
+
778
+ def flops(self):
779
+ flops = 0
780
+ return flops
781
+
782
+
783
+ class Upsample(nn.Sequential):
784
+ """Upsample module.
785
+
786
+ Args:
787
+ scale (int): Scale factor. Supported scales: 2^n and 3.
788
+ num_feat (int): Channel number of intermediate features.
789
+ """
790
+
791
+ def __init__(self, scale, num_feat):
792
+ m = []
793
+ if (scale & (scale - 1)) == 0: # scale = 2^n
794
+ for _ in range(int(math.log(scale, 2))):
795
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
796
+ m.append(nn.PixelShuffle(2))
797
+ elif scale == 3:
798
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
799
+ m.append(nn.PixelShuffle(3))
800
+ else:
801
+ raise ValueError(
802
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
803
+ )
804
+ super(Upsample, self).__init__(*m)
805
+
806
+
807
+ class Upsample_hf(nn.Sequential):
808
+ """Upsample module.
809
+
810
+ Args:
811
+ scale (int): Scale factor. Supported scales: 2^n and 3.
812
+ num_feat (int): Channel number of intermediate features.
813
+ """
814
+
815
+ def __init__(self, scale, num_feat):
816
+ m = []
817
+ if (scale & (scale - 1)) == 0: # scale = 2^n
818
+ for _ in range(int(math.log(scale, 2))):
819
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
820
+ m.append(nn.PixelShuffle(2))
821
+ elif scale == 3:
822
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
823
+ m.append(nn.PixelShuffle(3))
824
+ else:
825
+ raise ValueError(
826
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
827
+ )
828
+ super(Upsample_hf, self).__init__(*m)
829
+
830
+
831
+ class UpsampleOneStep(nn.Sequential):
832
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
833
+ Used in lightweight SR to save parameters.
834
+
835
+ Args:
836
+ scale (int): Scale factor. Supported scales: 2^n and 3.
837
+ num_feat (int): Channel number of intermediate features.
838
+
839
+ """
840
+
841
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
842
+ self.num_feat = num_feat
843
+ self.input_resolution = input_resolution
844
+ m = []
845
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
846
+ m.append(nn.PixelShuffle(scale))
847
+ super(UpsampleOneStep, self).__init__(*m)
848
+
849
+ def flops(self):
850
+ H, W = self.input_resolution # type: ignore
851
+ flops = H * W * self.num_feat * 3 * 9
852
+ return flops
853
+
854
+
855
+ class Swin2SR(nn.Module):
856
+ r"""Swin2SR
857
+ A PyTorch impl of : `Swin2SR: SwinV2 Transformer for Compressed Image Super-Resolution and Restoration`.
858
+
859
+ Args:
860
+ img_size (int | tuple(int)): Input image size. Default 64
861
+ patch_size (int | tuple(int)): Patch size. Default: 1
862
+ in_chans (int): Number of input image channels. Default: 3
863
+ embed_dim (int): Patch embedding dimension. Default: 96
864
+ depths (tuple(int)): Depth of each Swin Transformer layer.
865
+ num_heads (tuple(int)): Number of attention heads in different layers.
866
+ window_size (int): Window size. Default: 7
867
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
868
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
869
+ drop_rate (float): Dropout rate. Default: 0
870
+ attn_drop_rate (float): Attention dropout rate. Default: 0
871
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
872
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
873
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
874
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
875
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
876
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
877
+ img_range: Image range. 1. or 255.
878
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
879
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
880
+ """
881
+
882
+ def __init__(
883
+ self,
884
+ state_dict,
885
+ **kwargs,
886
+ ):
887
+ super(Swin2SR, self).__init__()
888
+
889
+ # Defaults
890
+ img_size = 128
891
+ patch_size = 1
892
+ in_chans = 3
893
+ embed_dim = 96
894
+ depths = [6, 6, 6, 6]
895
+ num_heads = [6, 6, 6, 6]
896
+ window_size = 7
897
+ mlp_ratio = 4.0
898
+ qkv_bias = True
899
+ drop_rate = 0.0
900
+ attn_drop_rate = 0.0
901
+ drop_path_rate = 0.1
902
+ norm_layer = nn.LayerNorm
903
+ ape = False
904
+ patch_norm = True
905
+ use_checkpoint = False
906
+ upscale = 2
907
+ img_range = 1.0
908
+ upsampler = ""
909
+ resi_connection = "1conv"
910
+ num_in_ch = in_chans
911
+ num_out_ch = in_chans
912
+ num_feat = 64
913
+
914
+ self.model_arch = "Swin2SR"
915
+ self.sub_type = "SR"
916
+ self.state = state_dict
917
+ if "params_ema" in self.state:
918
+ self.state = self.state["params_ema"]
919
+ elif "params" in self.state:
920
+ self.state = self.state["params"]
921
+
922
+ state_keys = self.state.keys()
923
+
924
+ if "conv_before_upsample.0.weight" in state_keys:
925
+ if "conv_aux.weight" in state_keys:
926
+ upsampler = "pixelshuffle_aux"
927
+ elif "conv_up1.weight" in state_keys:
928
+ upsampler = "nearest+conv"
929
+ else:
930
+ upsampler = "pixelshuffle"
931
+ supports_fp16 = False
932
+ elif "upsample.0.weight" in state_keys:
933
+ upsampler = "pixelshuffledirect"
934
+ else:
935
+ upsampler = ""
936
+
937
+ num_feat = (
938
+ self.state.get("conv_before_upsample.0.weight", None).shape[1]
939
+ if self.state.get("conv_before_upsample.weight", None)
940
+ else 64
941
+ )
942
+
943
+ num_in_ch = self.state["conv_first.weight"].shape[1]
944
+ in_chans = num_in_ch
945
+ if "conv_last.weight" in state_keys:
946
+ num_out_ch = self.state["conv_last.weight"].shape[0]
947
+ else:
948
+ num_out_ch = num_in_ch
949
+
950
+ upscale = 1
951
+ if upsampler == "nearest+conv":
952
+ upsample_keys = [
953
+ x for x in state_keys if "conv_up" in x and "bias" not in x
954
+ ]
955
+
956
+ for upsample_key in upsample_keys:
957
+ upscale *= 2
958
+ elif upsampler == "pixelshuffle" or upsampler == "pixelshuffle_aux":
959
+ upsample_keys = [
960
+ x
961
+ for x in state_keys
962
+ if "upsample" in x and "conv" not in x and "bias" not in x
963
+ ]
964
+ for upsample_key in upsample_keys:
965
+ shape = self.state[upsample_key].shape[0]
966
+ upscale *= math.sqrt(shape // num_feat)
967
+ upscale = int(upscale)
968
+ elif upsampler == "pixelshuffledirect":
969
+ upscale = int(
970
+ math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
971
+ )
972
+
973
+ max_layer_num = 0
974
+ max_block_num = 0
975
+ for key in state_keys:
976
+ result = re.match(
977
+ r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
978
+ )
979
+ if result:
980
+ layer_num, block_num = result.groups()
981
+ max_layer_num = max(max_layer_num, int(layer_num))
982
+ max_block_num = max(max_block_num, int(block_num))
983
+
984
+ depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
985
+
986
+ if (
987
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
988
+ in state_keys
989
+ ):
990
+ num_heads_num = self.state[
991
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
992
+ ].shape[-1]
993
+ num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
994
+ else:
995
+ num_heads = depths
996
+
997
+ embed_dim = self.state["conv_first.weight"].shape[0]
998
+
999
+ mlp_ratio = float(
1000
+ self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
1001
+ / embed_dim
1002
+ )
1003
+
1004
+ # TODO: could actually count the layers, but this should do
1005
+ if "layers.0.conv.4.weight" in state_keys:
1006
+ resi_connection = "3conv"
1007
+ else:
1008
+ resi_connection = "1conv"
1009
+
1010
+ window_size = int(
1011
+ math.sqrt(
1012
+ self.state[
1013
+ "layers.0.residual_group.blocks.0.attn.relative_position_index"
1014
+ ].shape[0]
1015
+ )
1016
+ )
1017
+
1018
+ if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
1019
+ img_size = int(
1020
+ math.sqrt(
1021
+ self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
1022
+ )
1023
+ * window_size
1024
+ )
1025
+
1026
+ # The JPEG models are the only ones with window-size 7, and they also use this range
1027
+ img_range = 255.0 if window_size == 7 else 1.0
1028
+
1029
+ self.in_nc = num_in_ch
1030
+ self.out_nc = num_out_ch
1031
+ self.num_feat = num_feat
1032
+ self.embed_dim = embed_dim
1033
+ self.num_heads = num_heads
1034
+ self.depths = depths
1035
+ self.window_size = window_size
1036
+ self.mlp_ratio = mlp_ratio
1037
+ self.scale = upscale
1038
+ self.upsampler = upsampler
1039
+ self.img_size = img_size
1040
+ self.img_range = img_range
1041
+ self.resi_connection = resi_connection
1042
+
1043
+ self.supports_fp16 = False # Too much weirdness to support this at the moment
1044
+ self.supports_bfp16 = True
1045
+ self.min_size_restriction = 16
1046
+
1047
+ ## END AUTO DETECTION
1048
+
1049
+ if in_chans == 3:
1050
+ rgb_mean = (0.4488, 0.4371, 0.4040)
1051
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
1052
+ else:
1053
+ self.mean = torch.zeros(1, 1, 1, 1)
1054
+ self.upscale = upscale
1055
+ self.upsampler = upsampler
1056
+ self.window_size = window_size
1057
+
1058
+ #####################################################################################################
1059
+ ################################### 1, shallow feature extraction ###################################
1060
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
1061
+
1062
+ #####################################################################################################
1063
+ ################################### 2, deep feature extraction ######################################
1064
+ self.num_layers = len(depths)
1065
+ self.embed_dim = embed_dim
1066
+ self.ape = ape
1067
+ self.patch_norm = patch_norm
1068
+ self.num_features = embed_dim
1069
+ self.mlp_ratio = mlp_ratio
1070
+
1071
+ # split image into non-overlapping patches
1072
+ self.patch_embed = PatchEmbed(
1073
+ img_size=img_size,
1074
+ patch_size=patch_size,
1075
+ in_chans=embed_dim,
1076
+ embed_dim=embed_dim,
1077
+ norm_layer=norm_layer if self.patch_norm else None,
1078
+ )
1079
+ num_patches = self.patch_embed.num_patches
1080
+ patches_resolution = self.patch_embed.patches_resolution
1081
+ self.patches_resolution = patches_resolution
1082
+
1083
+ # merge non-overlapping patches into image
1084
+ self.patch_unembed = PatchUnEmbed(
1085
+ img_size=img_size,
1086
+ patch_size=patch_size,
1087
+ in_chans=embed_dim,
1088
+ embed_dim=embed_dim,
1089
+ norm_layer=norm_layer if self.patch_norm else None,
1090
+ )
1091
+
1092
+ # absolute position embedding
1093
+ if self.ape:
1094
+ self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim)) # type: ignore
1095
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
1096
+
1097
+ self.pos_drop = nn.Dropout(p=drop_rate)
1098
+
1099
+ # stochastic depth
1100
+ dpr = [
1101
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
1102
+ ] # stochastic depth decay rule
1103
+
1104
+ # build Residual Swin Transformer blocks (RSTB)
1105
+ self.layers = nn.ModuleList()
1106
+ for i_layer in range(self.num_layers):
1107
+ layer = RSTB(
1108
+ dim=embed_dim,
1109
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1110
+ depth=depths[i_layer],
1111
+ num_heads=num_heads[i_layer],
1112
+ window_size=window_size,
1113
+ mlp_ratio=self.mlp_ratio,
1114
+ qkv_bias=qkv_bias,
1115
+ drop=drop_rate,
1116
+ attn_drop=attn_drop_rate,
1117
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # type: ignore # no impact on SR results
1118
+ norm_layer=norm_layer,
1119
+ downsample=None,
1120
+ use_checkpoint=use_checkpoint,
1121
+ img_size=img_size,
1122
+ patch_size=patch_size,
1123
+ resi_connection=resi_connection,
1124
+ )
1125
+ self.layers.append(layer)
1126
+
1127
+ if self.upsampler == "pixelshuffle_hf":
1128
+ self.layers_hf = nn.ModuleList()
1129
+ for i_layer in range(self.num_layers):
1130
+ layer = RSTB(
1131
+ dim=embed_dim,
1132
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1133
+ depth=depths[i_layer],
1134
+ num_heads=num_heads[i_layer],
1135
+ window_size=window_size,
1136
+ mlp_ratio=self.mlp_ratio,
1137
+ qkv_bias=qkv_bias,
1138
+ drop=drop_rate,
1139
+ attn_drop=attn_drop_rate,
1140
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # type: ignore # no impact on SR results # type: ignore
1141
+ norm_layer=norm_layer,
1142
+ downsample=None,
1143
+ use_checkpoint=use_checkpoint,
1144
+ img_size=img_size,
1145
+ patch_size=patch_size,
1146
+ resi_connection=resi_connection,
1147
+ )
1148
+ self.layers_hf.append(layer)
1149
+
1150
+ self.norm = norm_layer(self.num_features)
1151
+
1152
+ # build the last conv layer in deep feature extraction
1153
+ if resi_connection == "1conv":
1154
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1155
+ elif resi_connection == "3conv":
1156
+ # to save parameters and memory
1157
+ self.conv_after_body = nn.Sequential(
1158
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
1159
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1160
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
1161
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1162
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
1163
+ )
1164
+
1165
+ #####################################################################################################
1166
+ ################################ 3, high quality image reconstruction ################################
1167
+ if self.upsampler == "pixelshuffle":
1168
+ # for classical SR
1169
+ self.conv_before_upsample = nn.Sequential(
1170
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1171
+ )
1172
+ self.upsample = Upsample(upscale, num_feat)
1173
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1174
+ elif self.upsampler == "pixelshuffle_aux":
1175
+ self.conv_bicubic = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1)
1176
+ self.conv_before_upsample = nn.Sequential(
1177
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1178
+ )
1179
+ self.conv_aux = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1180
+ self.conv_after_aux = nn.Sequential(
1181
+ nn.Conv2d(3, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1182
+ )
1183
+ self.upsample = Upsample(upscale, num_feat)
1184
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1185
+
1186
+ elif self.upsampler == "pixelshuffle_hf":
1187
+ self.conv_before_upsample = nn.Sequential(
1188
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1189
+ )
1190
+ self.upsample = Upsample(upscale, num_feat)
1191
+ self.upsample_hf = Upsample_hf(upscale, num_feat)
1192
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1193
+ self.conv_first_hf = nn.Sequential(
1194
+ nn.Conv2d(num_feat, embed_dim, 3, 1, 1), nn.LeakyReLU(inplace=True)
1195
+ )
1196
+ self.conv_after_body_hf = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1197
+ self.conv_before_upsample_hf = nn.Sequential(
1198
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1199
+ )
1200
+ self.conv_last_hf = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1201
+
1202
+ elif self.upsampler == "pixelshuffledirect":
1203
+ # for lightweight SR (to save parameters)
1204
+ self.upsample = UpsampleOneStep(
1205
+ upscale,
1206
+ embed_dim,
1207
+ num_out_ch,
1208
+ (patches_resolution[0], patches_resolution[1]),
1209
+ )
1210
+ elif self.upsampler == "nearest+conv":
1211
+ # for real-world SR (less artifacts)
1212
+ assert self.upscale == 4, "only support x4 now."
1213
+ self.conv_before_upsample = nn.Sequential(
1214
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1215
+ )
1216
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1217
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1218
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1219
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1220
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1221
+ else:
1222
+ # for image denoising and JPEG compression artifact reduction
1223
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1224
+
1225
+ self.apply(self._init_weights)
1226
+
1227
+ self.load_state_dict(state_dict)
1228
+
1229
+ def _init_weights(self, m):
1230
+ if isinstance(m, nn.Linear):
1231
+ trunc_normal_(m.weight, std=0.02)
1232
+ if isinstance(m, nn.Linear) and m.bias is not None:
1233
+ nn.init.constant_(m.bias, 0)
1234
+ elif isinstance(m, nn.LayerNorm):
1235
+ nn.init.constant_(m.bias, 0)
1236
+ nn.init.constant_(m.weight, 1.0)
1237
+
1238
+ @torch.jit.ignore # type: ignore
1239
+ def no_weight_decay(self):
1240
+ return {"absolute_pos_embed"}
1241
+
1242
+ @torch.jit.ignore # type: ignore
1243
+ def no_weight_decay_keywords(self):
1244
+ return {"relative_position_bias_table"}
1245
+
1246
+ def check_image_size(self, x):
1247
+ _, _, h, w = x.size()
1248
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
1249
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
1250
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1251
+ return x
1252
+
1253
+ def forward_features(self, x):
1254
+ x_size = (x.shape[2], x.shape[3])
1255
+ x = self.patch_embed(x)
1256
+ if self.ape:
1257
+ x = x + self.absolute_pos_embed
1258
+ x = self.pos_drop(x)
1259
+
1260
+ for layer in self.layers:
1261
+ x = layer(x, x_size)
1262
+
1263
+ x = self.norm(x) # B L C
1264
+ x = self.patch_unembed(x, x_size)
1265
+
1266
+ return x
1267
+
1268
+ def forward_features_hf(self, x):
1269
+ x_size = (x.shape[2], x.shape[3])
1270
+ x = self.patch_embed(x)
1271
+ if self.ape:
1272
+ x = x + self.absolute_pos_embed
1273
+ x = self.pos_drop(x)
1274
+
1275
+ for layer in self.layers_hf:
1276
+ x = layer(x, x_size)
1277
+
1278
+ x = self.norm(x) # B L C
1279
+ x = self.patch_unembed(x, x_size)
1280
+
1281
+ return x
1282
+
1283
+ def forward(self, x):
1284
+ H, W = x.shape[2:]
1285
+ x = self.check_image_size(x)
1286
+
1287
+ self.mean = self.mean.type_as(x)
1288
+ x = (x - self.mean) * self.img_range
1289
+
1290
+ if self.upsampler == "pixelshuffle":
1291
+ # for classical SR
1292
+ x = self.conv_first(x)
1293
+ x = self.conv_after_body(self.forward_features(x)) + x
1294
+ x = self.conv_before_upsample(x)
1295
+ x = self.conv_last(self.upsample(x))
1296
+ elif self.upsampler == "pixelshuffle_aux":
1297
+ bicubic = F.interpolate(
1298
+ x,
1299
+ size=(H * self.upscale, W * self.upscale),
1300
+ mode="bicubic",
1301
+ align_corners=False,
1302
+ )
1303
+ bicubic = self.conv_bicubic(bicubic)
1304
+ x = self.conv_first(x)
1305
+ x = self.conv_after_body(self.forward_features(x)) + x
1306
+ x = self.conv_before_upsample(x)
1307
+ aux = self.conv_aux(x) # b, 3, LR_H, LR_W
1308
+ x = self.conv_after_aux(aux)
1309
+ x = (
1310
+ self.upsample(x)[:, :, : H * self.upscale, : W * self.upscale]
1311
+ + bicubic[:, :, : H * self.upscale, : W * self.upscale]
1312
+ )
1313
+ x = self.conv_last(x)
1314
+ aux = aux / self.img_range + self.mean
1315
+ elif self.upsampler == "pixelshuffle_hf":
1316
+ # for classical SR with HF
1317
+ x = self.conv_first(x)
1318
+ x = self.conv_after_body(self.forward_features(x)) + x
1319
+ x_before = self.conv_before_upsample(x)
1320
+ x_out = self.conv_last(self.upsample(x_before))
1321
+
1322
+ x_hf = self.conv_first_hf(x_before)
1323
+ x_hf = self.conv_after_body_hf(self.forward_features_hf(x_hf)) + x_hf
1324
+ x_hf = self.conv_before_upsample_hf(x_hf)
1325
+ x_hf = self.conv_last_hf(self.upsample_hf(x_hf))
1326
+ x = x_out + x_hf
1327
+ x_hf = x_hf / self.img_range + self.mean
1328
+
1329
+ elif self.upsampler == "pixelshuffledirect":
1330
+ # for lightweight SR
1331
+ x = self.conv_first(x)
1332
+ x = self.conv_after_body(self.forward_features(x)) + x
1333
+ x = self.upsample(x)
1334
+ elif self.upsampler == "nearest+conv":
1335
+ # for real-world SR
1336
+ x = self.conv_first(x)
1337
+ x = self.conv_after_body(self.forward_features(x)) + x
1338
+ x = self.conv_before_upsample(x)
1339
+ x = self.lrelu(
1340
+ self.conv_up1(
1341
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
1342
+ )
1343
+ )
1344
+ x = self.lrelu(
1345
+ self.conv_up2(
1346
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest")
1347
+ )
1348
+ )
1349
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1350
+ else:
1351
+ # for image denoising and JPEG compression artifact reduction
1352
+ x_first = self.conv_first(x)
1353
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1354
+ x = x + self.conv_last(res)
1355
+
1356
+ x = x / self.img_range + self.mean
1357
+ if self.upsampler == "pixelshuffle_aux":
1358
+ # NOTE: I removed an "aux" output here. not sure what that was for
1359
+ return x[:, :, : H * self.upscale, : W * self.upscale] # type: ignore
1360
+
1361
+ elif self.upsampler == "pixelshuffle_hf":
1362
+ x_out = x_out / self.img_range + self.mean # type: ignore
1363
+ return x_out[:, :, : H * self.upscale, : W * self.upscale], x[:, :, : H * self.upscale, : W * self.upscale], x_hf[:, :, : H * self.upscale, : W * self.upscale] # type: ignore
1364
+
1365
+ else:
1366
+ return x[:, :, : H * self.upscale, : W * self.upscale]
1367
+
1368
+ def flops(self):
1369
+ flops = 0
1370
+ H, W = self.patches_resolution
1371
+ flops += H * W * 3 * self.embed_dim * 9
1372
+ flops += self.patch_embed.flops()
1373
+ for i, layer in enumerate(self.layers):
1374
+ flops += layer.flops() # type: ignore
1375
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1376
+ flops += self.upsample.flops() # type: ignore
1377
+ return flops
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/SwinIR.py ADDED
@@ -0,0 +1,1224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # -----------------------------------------------------------------------------------
3
+ # SwinIR: Image Restoration Using Swin Transformer, https://arxiv.org/abs/2108.10257
4
+ # Originally Written by Ze Liu, Modified by Jingyun Liang.
5
+ # -----------------------------------------------------------------------------------
6
+
7
+ import math
8
+ import re
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint as checkpoint
14
+
15
+ # Originally from the timm package
16
+ from .timm.drop import DropPath
17
+ from .timm.helpers import to_2tuple
18
+ from .timm.weight_init import trunc_normal_
19
+
20
+
21
+ class Mlp(nn.Module):
22
+ def __init__(
23
+ self,
24
+ in_features,
25
+ hidden_features=None,
26
+ out_features=None,
27
+ act_layer=nn.GELU,
28
+ drop=0.0,
29
+ ):
30
+ super().__init__()
31
+ out_features = out_features or in_features
32
+ hidden_features = hidden_features or in_features
33
+ self.fc1 = nn.Linear(in_features, hidden_features)
34
+ self.act = act_layer()
35
+ self.fc2 = nn.Linear(hidden_features, out_features)
36
+ self.drop = nn.Dropout(drop)
37
+
38
+ def forward(self, x):
39
+ x = self.fc1(x)
40
+ x = self.act(x)
41
+ x = self.drop(x)
42
+ x = self.fc2(x)
43
+ x = self.drop(x)
44
+ return x
45
+
46
+
47
+ def window_partition(x, window_size):
48
+ """
49
+ Args:
50
+ x: (B, H, W, C)
51
+ window_size (int): window size
52
+
53
+ Returns:
54
+ windows: (num_windows*B, window_size, window_size, C)
55
+ """
56
+ B, H, W, C = x.shape
57
+ x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
58
+ windows = (
59
+ x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
60
+ )
61
+ return windows
62
+
63
+
64
+ def window_reverse(windows, window_size, H, W):
65
+ """
66
+ Args:
67
+ windows: (num_windows*B, window_size, window_size, C)
68
+ window_size (int): Window size
69
+ H (int): Height of image
70
+ W (int): Width of image
71
+
72
+ Returns:
73
+ x: (B, H, W, C)
74
+ """
75
+ B = int(windows.shape[0] / (H * W / window_size / window_size))
76
+ x = windows.view(
77
+ B, H // window_size, W // window_size, window_size, window_size, -1
78
+ )
79
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
80
+ return x
81
+
82
+
83
+ class WindowAttention(nn.Module):
84
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
85
+ It supports both of shifted and non-shifted window.
86
+
87
+ Args:
88
+ dim (int): Number of input channels.
89
+ window_size (tuple[int]): The height and width of the window.
90
+ num_heads (int): Number of attention heads.
91
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
92
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
93
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
94
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ dim,
100
+ window_size,
101
+ num_heads,
102
+ qkv_bias=True,
103
+ qk_scale=None,
104
+ attn_drop=0.0,
105
+ proj_drop=0.0,
106
+ ):
107
+ super().__init__()
108
+ self.dim = dim
109
+ self.window_size = window_size # Wh, Ww
110
+ self.num_heads = num_heads
111
+ head_dim = dim // num_heads
112
+ self.scale = qk_scale or head_dim**-0.5
113
+
114
+ # define a parameter table of relative position bias
115
+ self.relative_position_bias_table = nn.Parameter( # type: ignore
116
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
117
+ ) # 2*Wh-1 * 2*Ww-1, nH
118
+
119
+ # get pair-wise relative position index for each token inside the window
120
+ coords_h = torch.arange(self.window_size[0])
121
+ coords_w = torch.arange(self.window_size[1])
122
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
123
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
124
+ relative_coords = (
125
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
126
+ ) # 2, Wh*Ww, Wh*Ww
127
+ relative_coords = relative_coords.permute(
128
+ 1, 2, 0
129
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
130
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
131
+ relative_coords[:, :, 1] += self.window_size[1] - 1
132
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
133
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
134
+ self.register_buffer("relative_position_index", relative_position_index)
135
+
136
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
137
+ self.attn_drop = nn.Dropout(attn_drop)
138
+ self.proj = nn.Linear(dim, dim)
139
+
140
+ self.proj_drop = nn.Dropout(proj_drop)
141
+
142
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
143
+ self.softmax = nn.Softmax(dim=-1)
144
+
145
+ def forward(self, x, mask=None):
146
+ """
147
+ Args:
148
+ x: input features with shape of (num_windows*B, N, C)
149
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
150
+ """
151
+ B_, N, C = x.shape
152
+ qkv = (
153
+ self.qkv(x)
154
+ .reshape(B_, N, 3, self.num_heads, C // self.num_heads)
155
+ .permute(2, 0, 3, 1, 4)
156
+ )
157
+ q, k, v = (
158
+ qkv[0],
159
+ qkv[1],
160
+ qkv[2],
161
+ ) # make torchscript happy (cannot use tensor as tuple)
162
+
163
+ q = q * self.scale
164
+ attn = q @ k.transpose(-2, -1)
165
+
166
+ relative_position_bias = self.relative_position_bias_table[
167
+ self.relative_position_index.view(-1) # type: ignore
168
+ ].view(
169
+ self.window_size[0] * self.window_size[1],
170
+ self.window_size[0] * self.window_size[1],
171
+ -1,
172
+ ) # Wh*Ww,Wh*Ww,nH
173
+ relative_position_bias = relative_position_bias.permute(
174
+ 2, 0, 1
175
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
176
+ attn = attn + relative_position_bias.unsqueeze(0)
177
+
178
+ if mask is not None:
179
+ nW = mask.shape[0]
180
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
181
+ 1
182
+ ).unsqueeze(0)
183
+ attn = attn.view(-1, self.num_heads, N, N)
184
+ attn = self.softmax(attn)
185
+ else:
186
+ attn = self.softmax(attn)
187
+
188
+ attn = self.attn_drop(attn)
189
+
190
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
191
+ x = self.proj(x)
192
+ x = self.proj_drop(x)
193
+ return x
194
+
195
+ def extra_repr(self) -> str:
196
+ return f"dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}"
197
+
198
+ def flops(self, N):
199
+ # calculate flops for 1 window with token length of N
200
+ flops = 0
201
+ # qkv = self.qkv(x)
202
+ flops += N * self.dim * 3 * self.dim
203
+ # attn = (q @ k.transpose(-2, -1))
204
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
205
+ # x = (attn @ v)
206
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
207
+ # x = self.proj(x)
208
+ flops += N * self.dim * self.dim
209
+ return flops
210
+
211
+
212
+ class SwinTransformerBlock(nn.Module):
213
+ r"""Swin Transformer Block.
214
+
215
+ Args:
216
+ dim (int): Number of input channels.
217
+ input_resolution (tuple[int]): Input resulotion.
218
+ num_heads (int): Number of attention heads.
219
+ window_size (int): Window size.
220
+ shift_size (int): Shift size for SW-MSA.
221
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
222
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
223
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
224
+ drop (float, optional): Dropout rate. Default: 0.0
225
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
226
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
227
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
228
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
229
+ """
230
+
231
+ def __init__(
232
+ self,
233
+ dim,
234
+ input_resolution,
235
+ num_heads,
236
+ window_size=7,
237
+ shift_size=0,
238
+ mlp_ratio=4.0,
239
+ qkv_bias=True,
240
+ qk_scale=None,
241
+ drop=0.0,
242
+ attn_drop=0.0,
243
+ drop_path=0.0,
244
+ act_layer=nn.GELU,
245
+ norm_layer=nn.LayerNorm,
246
+ ):
247
+ super().__init__()
248
+ self.dim = dim
249
+ self.input_resolution = input_resolution
250
+ self.num_heads = num_heads
251
+ self.window_size = window_size
252
+ self.shift_size = shift_size
253
+ self.mlp_ratio = mlp_ratio
254
+ if min(self.input_resolution) <= self.window_size:
255
+ # if window size is larger than input resolution, we don't partition windows
256
+ self.shift_size = 0
257
+ self.window_size = min(self.input_resolution)
258
+ assert (
259
+ 0 <= self.shift_size < self.window_size
260
+ ), "shift_size must in 0-window_size"
261
+
262
+ self.norm1 = norm_layer(dim)
263
+ self.attn = WindowAttention(
264
+ dim,
265
+ window_size=to_2tuple(self.window_size),
266
+ num_heads=num_heads,
267
+ qkv_bias=qkv_bias,
268
+ qk_scale=qk_scale,
269
+ attn_drop=attn_drop,
270
+ proj_drop=drop,
271
+ )
272
+
273
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
274
+ self.norm2 = norm_layer(dim)
275
+ mlp_hidden_dim = int(dim * mlp_ratio)
276
+ self.mlp = Mlp(
277
+ in_features=dim,
278
+ hidden_features=mlp_hidden_dim,
279
+ act_layer=act_layer,
280
+ drop=drop,
281
+ )
282
+
283
+ if self.shift_size > 0:
284
+ attn_mask = self.calculate_mask(self.input_resolution)
285
+ else:
286
+ attn_mask = None
287
+
288
+ self.register_buffer("attn_mask", attn_mask)
289
+
290
+ def calculate_mask(self, x_size):
291
+ # calculate attention mask for SW-MSA
292
+ H, W = x_size
293
+ img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
294
+ h_slices = (
295
+ slice(0, -self.window_size),
296
+ slice(-self.window_size, -self.shift_size),
297
+ slice(-self.shift_size, None),
298
+ )
299
+ w_slices = (
300
+ slice(0, -self.window_size),
301
+ slice(-self.window_size, -self.shift_size),
302
+ slice(-self.shift_size, None),
303
+ )
304
+ cnt = 0
305
+ for h in h_slices:
306
+ for w in w_slices:
307
+ img_mask[:, h, w, :] = cnt
308
+ cnt += 1
309
+
310
+ mask_windows = window_partition(
311
+ img_mask, self.window_size
312
+ ) # nW, window_size, window_size, 1
313
+ mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
314
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
315
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
316
+ attn_mask == 0, float(0.0)
317
+ )
318
+
319
+ return attn_mask
320
+
321
+ def forward(self, x, x_size):
322
+ H, W = x_size
323
+ B, L, C = x.shape
324
+ # assert L == H * W, "input feature has wrong size"
325
+
326
+ shortcut = x
327
+ x = self.norm1(x)
328
+ x = x.view(B, H, W, C)
329
+
330
+ # cyclic shift
331
+ if self.shift_size > 0:
332
+ shifted_x = torch.roll(
333
+ x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2)
334
+ )
335
+ else:
336
+ shifted_x = x
337
+
338
+ # partition windows
339
+ x_windows = window_partition(
340
+ shifted_x, self.window_size
341
+ ) # nW*B, window_size, window_size, C
342
+ x_windows = x_windows.view(
343
+ -1, self.window_size * self.window_size, C
344
+ ) # nW*B, window_size*window_size, C
345
+
346
+ # W-MSA/SW-MSA (to be compatible for testing on images whose shapes are the multiple of window size
347
+ if self.input_resolution == x_size:
348
+ attn_windows = self.attn(
349
+ x_windows, mask=self.attn_mask
350
+ ) # nW*B, window_size*window_size, C
351
+ else:
352
+ attn_windows = self.attn(
353
+ x_windows, mask=self.calculate_mask(x_size).to(x.device)
354
+ )
355
+
356
+ # merge windows
357
+ attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
358
+ shifted_x = window_reverse(attn_windows, self.window_size, H, W) # B H' W' C
359
+
360
+ # reverse cyclic shift
361
+ if self.shift_size > 0:
362
+ x = torch.roll(
363
+ shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2)
364
+ )
365
+ else:
366
+ x = shifted_x
367
+ x = x.view(B, H * W, C)
368
+
369
+ # FFN
370
+ x = shortcut + self.drop_path(x)
371
+ x = x + self.drop_path(self.mlp(self.norm2(x)))
372
+
373
+ return x
374
+
375
+ def extra_repr(self) -> str:
376
+ return (
377
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
378
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
379
+ )
380
+
381
+ def flops(self):
382
+ flops = 0
383
+ H, W = self.input_resolution
384
+ # norm1
385
+ flops += self.dim * H * W
386
+ # W-MSA/SW-MSA
387
+ nW = H * W / self.window_size / self.window_size
388
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
389
+ # mlp
390
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
391
+ # norm2
392
+ flops += self.dim * H * W
393
+ return flops
394
+
395
+
396
+ class PatchMerging(nn.Module):
397
+ r"""Patch Merging Layer.
398
+
399
+ Args:
400
+ input_resolution (tuple[int]): Resolution of input feature.
401
+ dim (int): Number of input channels.
402
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
403
+ """
404
+
405
+ def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
406
+ super().__init__()
407
+ self.input_resolution = input_resolution
408
+ self.dim = dim
409
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
410
+ self.norm = norm_layer(4 * dim)
411
+
412
+ def forward(self, x):
413
+ """
414
+ x: B, H*W, C
415
+ """
416
+ H, W = self.input_resolution
417
+ B, L, C = x.shape
418
+ assert L == H * W, "input feature has wrong size"
419
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
420
+
421
+ x = x.view(B, H, W, C)
422
+
423
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
424
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
425
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
426
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
427
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
428
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
429
+
430
+ x = self.norm(x)
431
+ x = self.reduction(x)
432
+
433
+ return x
434
+
435
+ def extra_repr(self) -> str:
436
+ return f"input_resolution={self.input_resolution}, dim={self.dim}"
437
+
438
+ def flops(self):
439
+ H, W = self.input_resolution
440
+ flops = H * W * self.dim
441
+ flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
442
+ return flops
443
+
444
+
445
+ class BasicLayer(nn.Module):
446
+ """A basic Swin Transformer layer for one stage.
447
+
448
+ Args:
449
+ dim (int): Number of input channels.
450
+ input_resolution (tuple[int]): Input resolution.
451
+ depth (int): Number of blocks.
452
+ num_heads (int): Number of attention heads.
453
+ window_size (int): Local window size.
454
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
455
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
456
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
457
+ drop (float, optional): Dropout rate. Default: 0.0
458
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
459
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
460
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
461
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
462
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
463
+ """
464
+
465
+ def __init__(
466
+ self,
467
+ dim,
468
+ input_resolution,
469
+ depth,
470
+ num_heads,
471
+ window_size,
472
+ mlp_ratio=4.0,
473
+ qkv_bias=True,
474
+ qk_scale=None,
475
+ drop=0.0,
476
+ attn_drop=0.0,
477
+ drop_path=0.0,
478
+ norm_layer=nn.LayerNorm,
479
+ downsample=None,
480
+ use_checkpoint=False,
481
+ ):
482
+ super().__init__()
483
+ self.dim = dim
484
+ self.input_resolution = input_resolution
485
+ self.depth = depth
486
+ self.use_checkpoint = use_checkpoint
487
+
488
+ # build blocks
489
+ self.blocks = nn.ModuleList(
490
+ [
491
+ SwinTransformerBlock(
492
+ dim=dim,
493
+ input_resolution=input_resolution,
494
+ num_heads=num_heads,
495
+ window_size=window_size,
496
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
497
+ mlp_ratio=mlp_ratio,
498
+ qkv_bias=qkv_bias,
499
+ qk_scale=qk_scale,
500
+ drop=drop,
501
+ attn_drop=attn_drop,
502
+ drop_path=drop_path[i]
503
+ if isinstance(drop_path, list)
504
+ else drop_path,
505
+ norm_layer=norm_layer,
506
+ )
507
+ for i in range(depth)
508
+ ]
509
+ )
510
+
511
+ # patch merging layer
512
+ if downsample is not None:
513
+ self.downsample = downsample(
514
+ input_resolution, dim=dim, norm_layer=norm_layer
515
+ )
516
+ else:
517
+ self.downsample = None
518
+
519
+ def forward(self, x, x_size):
520
+ for blk in self.blocks:
521
+ if self.use_checkpoint:
522
+ x = checkpoint.checkpoint(blk, x, x_size)
523
+ else:
524
+ x = blk(x, x_size)
525
+ if self.downsample is not None:
526
+ x = self.downsample(x)
527
+ return x
528
+
529
+ def extra_repr(self) -> str:
530
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
531
+
532
+ def flops(self):
533
+ flops = 0
534
+ for blk in self.blocks:
535
+ flops += blk.flops() # type: ignore
536
+ if self.downsample is not None:
537
+ flops += self.downsample.flops()
538
+ return flops
539
+
540
+
541
+ class RSTB(nn.Module):
542
+ """Residual Swin Transformer Block (RSTB).
543
+
544
+ Args:
545
+ dim (int): Number of input channels.
546
+ input_resolution (tuple[int]): Input resolution.
547
+ depth (int): Number of blocks.
548
+ num_heads (int): Number of attention heads.
549
+ window_size (int): Local window size.
550
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
551
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
552
+ qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
553
+ drop (float, optional): Dropout rate. Default: 0.0
554
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
555
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
556
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
557
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
558
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
559
+ img_size: Input image size.
560
+ patch_size: Patch size.
561
+ resi_connection: The convolutional block before residual connection.
562
+ """
563
+
564
+ def __init__(
565
+ self,
566
+ dim,
567
+ input_resolution,
568
+ depth,
569
+ num_heads,
570
+ window_size,
571
+ mlp_ratio=4.0,
572
+ qkv_bias=True,
573
+ qk_scale=None,
574
+ drop=0.0,
575
+ attn_drop=0.0,
576
+ drop_path=0.0,
577
+ norm_layer=nn.LayerNorm,
578
+ downsample=None,
579
+ use_checkpoint=False,
580
+ img_size=224,
581
+ patch_size=4,
582
+ resi_connection="1conv",
583
+ ):
584
+ super(RSTB, self).__init__()
585
+
586
+ self.dim = dim
587
+ self.input_resolution = input_resolution
588
+
589
+ self.residual_group = BasicLayer(
590
+ dim=dim,
591
+ input_resolution=input_resolution,
592
+ depth=depth,
593
+ num_heads=num_heads,
594
+ window_size=window_size,
595
+ mlp_ratio=mlp_ratio,
596
+ qkv_bias=qkv_bias,
597
+ qk_scale=qk_scale,
598
+ drop=drop,
599
+ attn_drop=attn_drop,
600
+ drop_path=drop_path,
601
+ norm_layer=norm_layer,
602
+ downsample=downsample,
603
+ use_checkpoint=use_checkpoint,
604
+ )
605
+
606
+ if resi_connection == "1conv":
607
+ self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
608
+ elif resi_connection == "3conv":
609
+ # to save parameters and memory
610
+ self.conv = nn.Sequential(
611
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
612
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
613
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
614
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
615
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
616
+ )
617
+
618
+ self.patch_embed = PatchEmbed(
619
+ img_size=img_size,
620
+ patch_size=patch_size,
621
+ in_chans=0,
622
+ embed_dim=dim,
623
+ norm_layer=None,
624
+ )
625
+
626
+ self.patch_unembed = PatchUnEmbed(
627
+ img_size=img_size,
628
+ patch_size=patch_size,
629
+ in_chans=0,
630
+ embed_dim=dim,
631
+ norm_layer=None,
632
+ )
633
+
634
+ def forward(self, x, x_size):
635
+ return (
636
+ self.patch_embed(
637
+ self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))
638
+ )
639
+ + x
640
+ )
641
+
642
+ def flops(self):
643
+ flops = 0
644
+ flops += self.residual_group.flops()
645
+ H, W = self.input_resolution
646
+ flops += H * W * self.dim * self.dim * 9
647
+ flops += self.patch_embed.flops()
648
+ flops += self.patch_unembed.flops()
649
+
650
+ return flops
651
+
652
+
653
+ class PatchEmbed(nn.Module):
654
+ r"""Image to Patch Embedding
655
+
656
+ Args:
657
+ img_size (int): Image size. Default: 224.
658
+ patch_size (int): Patch token size. Default: 4.
659
+ in_chans (int): Number of input image channels. Default: 3.
660
+ embed_dim (int): Number of linear projection output channels. Default: 96.
661
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
662
+ """
663
+
664
+ def __init__(
665
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
666
+ ):
667
+ super().__init__()
668
+ img_size = to_2tuple(img_size)
669
+ patch_size = to_2tuple(patch_size)
670
+ patches_resolution = [
671
+ img_size[0] // patch_size[0], # type: ignore
672
+ img_size[1] // patch_size[1], # type: ignore
673
+ ]
674
+ self.img_size = img_size
675
+ self.patch_size = patch_size
676
+ self.patches_resolution = patches_resolution
677
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
678
+
679
+ self.in_chans = in_chans
680
+ self.embed_dim = embed_dim
681
+
682
+ if norm_layer is not None:
683
+ self.norm = norm_layer(embed_dim)
684
+ else:
685
+ self.norm = None
686
+
687
+ def forward(self, x):
688
+ x = x.flatten(2).transpose(1, 2) # B Ph*Pw C
689
+ if self.norm is not None:
690
+ x = self.norm(x)
691
+ return x
692
+
693
+ def flops(self):
694
+ flops = 0
695
+ H, W = self.img_size
696
+ if self.norm is not None:
697
+ flops += H * W * self.embed_dim # type: ignore
698
+ return flops
699
+
700
+
701
+ class PatchUnEmbed(nn.Module):
702
+ r"""Image to Patch Unembedding
703
+
704
+ Args:
705
+ img_size (int): Image size. Default: 224.
706
+ patch_size (int): Patch token size. Default: 4.
707
+ in_chans (int): Number of input image channels. Default: 3.
708
+ embed_dim (int): Number of linear projection output channels. Default: 96.
709
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
710
+ """
711
+
712
+ def __init__(
713
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
714
+ ):
715
+ super().__init__()
716
+ img_size = to_2tuple(img_size)
717
+ patch_size = to_2tuple(patch_size)
718
+ patches_resolution = [
719
+ img_size[0] // patch_size[0], # type: ignore
720
+ img_size[1] // patch_size[1], # type: ignore
721
+ ]
722
+ self.img_size = img_size
723
+ self.patch_size = patch_size
724
+ self.patches_resolution = patches_resolution
725
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
726
+
727
+ self.in_chans = in_chans
728
+ self.embed_dim = embed_dim
729
+
730
+ def forward(self, x, x_size):
731
+ B, HW, C = x.shape
732
+ x = x.transpose(1, 2).view(B, self.embed_dim, x_size[0], x_size[1]) # B Ph*Pw C
733
+ return x
734
+
735
+ def flops(self):
736
+ flops = 0
737
+ return flops
738
+
739
+
740
+ class Upsample(nn.Sequential):
741
+ """Upsample module.
742
+
743
+ Args:
744
+ scale (int): Scale factor. Supported scales: 2^n and 3.
745
+ num_feat (int): Channel number of intermediate features.
746
+ """
747
+
748
+ def __init__(self, scale, num_feat):
749
+ m = []
750
+ if (scale & (scale - 1)) == 0: # scale = 2^n
751
+ for _ in range(int(math.log(scale, 2))):
752
+ m.append(nn.Conv2d(num_feat, 4 * num_feat, 3, 1, 1))
753
+ m.append(nn.PixelShuffle(2))
754
+ elif scale == 3:
755
+ m.append(nn.Conv2d(num_feat, 9 * num_feat, 3, 1, 1))
756
+ m.append(nn.PixelShuffle(3))
757
+ else:
758
+ raise ValueError(
759
+ f"scale {scale} is not supported. " "Supported scales: 2^n and 3."
760
+ )
761
+ super(Upsample, self).__init__(*m)
762
+
763
+
764
+ class UpsampleOneStep(nn.Sequential):
765
+ """UpsampleOneStep module (the difference with Upsample is that it always only has 1conv + 1pixelshuffle)
766
+ Used in lightweight SR to save parameters.
767
+
768
+ Args:
769
+ scale (int): Scale factor. Supported scales: 2^n and 3.
770
+ num_feat (int): Channel number of intermediate features.
771
+
772
+ """
773
+
774
+ def __init__(self, scale, num_feat, num_out_ch, input_resolution=None):
775
+ self.num_feat = num_feat
776
+ self.input_resolution = input_resolution
777
+ m = []
778
+ m.append(nn.Conv2d(num_feat, (scale**2) * num_out_ch, 3, 1, 1))
779
+ m.append(nn.PixelShuffle(scale))
780
+ super(UpsampleOneStep, self).__init__(*m)
781
+
782
+ def flops(self):
783
+ H, W = self.input_resolution # type: ignore
784
+ flops = H * W * self.num_feat * 3 * 9
785
+ return flops
786
+
787
+
788
+ class SwinIR(nn.Module):
789
+ r"""SwinIR
790
+ A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
791
+
792
+ Args:
793
+ img_size (int | tuple(int)): Input image size. Default 64
794
+ patch_size (int | tuple(int)): Patch size. Default: 1
795
+ in_chans (int): Number of input image channels. Default: 3
796
+ embed_dim (int): Patch embedding dimension. Default: 96
797
+ depths (tuple(int)): Depth of each Swin Transformer layer.
798
+ num_heads (tuple(int)): Number of attention heads in different layers.
799
+ window_size (int): Window size. Default: 7
800
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
801
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
802
+ qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
803
+ drop_rate (float): Dropout rate. Default: 0
804
+ attn_drop_rate (float): Attention dropout rate. Default: 0
805
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
806
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
807
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
808
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
809
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
810
+ upscale: Upscale factor. 2/3/4/8 for image SR, 1 for denoising and compress artifact reduction
811
+ img_range: Image range. 1. or 255.
812
+ upsampler: The reconstruction reconstruction module. 'pixelshuffle'/'pixelshuffledirect'/'nearest+conv'/None
813
+ resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
814
+ """
815
+
816
+ def __init__(
817
+ self,
818
+ state_dict,
819
+ **kwargs,
820
+ ):
821
+ super(SwinIR, self).__init__()
822
+
823
+ # Defaults
824
+ img_size = 64
825
+ patch_size = 1
826
+ in_chans = 3
827
+ embed_dim = 96
828
+ depths = [6, 6, 6, 6]
829
+ num_heads = [6, 6, 6, 6]
830
+ window_size = 7
831
+ mlp_ratio = 4.0
832
+ qkv_bias = True
833
+ qk_scale = None
834
+ drop_rate = 0.0
835
+ attn_drop_rate = 0.0
836
+ drop_path_rate = 0.1
837
+ norm_layer = nn.LayerNorm
838
+ ape = False
839
+ patch_norm = True
840
+ use_checkpoint = False
841
+ upscale = 2
842
+ img_range = 1.0
843
+ upsampler = ""
844
+ resi_connection = "1conv"
845
+ num_feat = 64
846
+ num_in_ch = in_chans
847
+ num_out_ch = in_chans
848
+ supports_fp16 = True
849
+ self.start_unshuffle = 1
850
+
851
+ self.model_arch = "SwinIR"
852
+ self.sub_type = "SR"
853
+ self.state = state_dict
854
+ if "params_ema" in self.state:
855
+ self.state = self.state["params_ema"]
856
+ elif "params" in self.state:
857
+ self.state = self.state["params"]
858
+
859
+ state_keys = self.state.keys()
860
+
861
+ if "conv_before_upsample.0.weight" in state_keys:
862
+ if "conv_up1.weight" in state_keys:
863
+ upsampler = "nearest+conv"
864
+ else:
865
+ upsampler = "pixelshuffle"
866
+ supports_fp16 = False
867
+ elif "upsample.0.weight" in state_keys:
868
+ upsampler = "pixelshuffledirect"
869
+ else:
870
+ upsampler = ""
871
+
872
+ num_feat = (
873
+ self.state.get("conv_before_upsample.0.weight", None).shape[1]
874
+ if self.state.get("conv_before_upsample.weight", None)
875
+ else 64
876
+ )
877
+
878
+ if "conv_first.1.weight" in self.state:
879
+ self.state["conv_first.weight"] = self.state.pop("conv_first.1.weight")
880
+ self.state["conv_first.bias"] = self.state.pop("conv_first.1.bias")
881
+ self.start_unshuffle = round(math.sqrt(self.state["conv_first.weight"].shape[1] // 3))
882
+
883
+ num_in_ch = self.state["conv_first.weight"].shape[1]
884
+ in_chans = num_in_ch
885
+ if "conv_last.weight" in state_keys:
886
+ num_out_ch = self.state["conv_last.weight"].shape[0]
887
+ else:
888
+ num_out_ch = num_in_ch
889
+
890
+ upscale = 1
891
+ if upsampler == "nearest+conv":
892
+ upsample_keys = [
893
+ x for x in state_keys if "conv_up" in x and "bias" not in x
894
+ ]
895
+
896
+ for upsample_key in upsample_keys:
897
+ upscale *= 2
898
+ elif upsampler == "pixelshuffle":
899
+ upsample_keys = [
900
+ x
901
+ for x in state_keys
902
+ if "upsample" in x and "conv" not in x and "bias" not in x
903
+ ]
904
+ for upsample_key in upsample_keys:
905
+ shape = self.state[upsample_key].shape[0]
906
+ upscale *= math.sqrt(shape // num_feat)
907
+ upscale = int(upscale)
908
+ elif upsampler == "pixelshuffledirect":
909
+ upscale = int(
910
+ math.sqrt(self.state["upsample.0.bias"].shape[0] // num_out_ch)
911
+ )
912
+
913
+ max_layer_num = 0
914
+ max_block_num = 0
915
+ for key in state_keys:
916
+ result = re.match(
917
+ r"layers.(\d*).residual_group.blocks.(\d*).norm1.weight", key
918
+ )
919
+ if result:
920
+ layer_num, block_num = result.groups()
921
+ max_layer_num = max(max_layer_num, int(layer_num))
922
+ max_block_num = max(max_block_num, int(block_num))
923
+
924
+ depths = [max_block_num + 1 for _ in range(max_layer_num + 1)]
925
+
926
+ if (
927
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
928
+ in state_keys
929
+ ):
930
+ num_heads_num = self.state[
931
+ "layers.0.residual_group.blocks.0.attn.relative_position_bias_table"
932
+ ].shape[-1]
933
+ num_heads = [num_heads_num for _ in range(max_layer_num + 1)]
934
+ else:
935
+ num_heads = depths
936
+
937
+ embed_dim = self.state["conv_first.weight"].shape[0]
938
+
939
+ mlp_ratio = float(
940
+ self.state["layers.0.residual_group.blocks.0.mlp.fc1.bias"].shape[0]
941
+ / embed_dim
942
+ )
943
+
944
+ # TODO: could actually count the layers, but this should do
945
+ if "layers.0.conv.4.weight" in state_keys:
946
+ resi_connection = "3conv"
947
+ else:
948
+ resi_connection = "1conv"
949
+
950
+ window_size = int(
951
+ math.sqrt(
952
+ self.state[
953
+ "layers.0.residual_group.blocks.0.attn.relative_position_index"
954
+ ].shape[0]
955
+ )
956
+ )
957
+
958
+ if "layers.0.residual_group.blocks.1.attn_mask" in state_keys:
959
+ img_size = int(
960
+ math.sqrt(
961
+ self.state["layers.0.residual_group.blocks.1.attn_mask"].shape[0]
962
+ )
963
+ * window_size
964
+ )
965
+
966
+ # The JPEG models are the only ones with window-size 7, and they also use this range
967
+ img_range = 255.0 if window_size == 7 else 1.0
968
+
969
+ self.in_nc = num_in_ch
970
+ self.out_nc = num_out_ch
971
+ self.num_feat = num_feat
972
+ self.embed_dim = embed_dim
973
+ self.num_heads = num_heads
974
+ self.depths = depths
975
+ self.window_size = window_size
976
+ self.mlp_ratio = mlp_ratio
977
+ self.scale = upscale / self.start_unshuffle
978
+ self.upsampler = upsampler
979
+ self.img_size = img_size
980
+ self.img_range = img_range
981
+ self.resi_connection = resi_connection
982
+
983
+ self.supports_fp16 = False # Too much weirdness to support this at the moment
984
+ self.supports_bfp16 = True
985
+ self.min_size_restriction = 16
986
+
987
+ self.img_range = img_range
988
+ if in_chans == 3:
989
+ rgb_mean = (0.4488, 0.4371, 0.4040)
990
+ self.mean = torch.Tensor(rgb_mean).view(1, 3, 1, 1)
991
+ else:
992
+ self.mean = torch.zeros(1, 1, 1, 1)
993
+ self.upscale = upscale
994
+ self.upsampler = upsampler
995
+ self.window_size = window_size
996
+
997
+ #####################################################################################################
998
+ ################################### 1, shallow feature extraction ###################################
999
+ self.conv_first = nn.Conv2d(num_in_ch, embed_dim, 3, 1, 1)
1000
+
1001
+ #####################################################################################################
1002
+ ################################### 2, deep feature extraction ######################################
1003
+ self.num_layers = len(depths)
1004
+ self.embed_dim = embed_dim
1005
+ self.ape = ape
1006
+ self.patch_norm = patch_norm
1007
+ self.num_features = embed_dim
1008
+ self.mlp_ratio = mlp_ratio
1009
+
1010
+ # split image into non-overlapping patches
1011
+ self.patch_embed = PatchEmbed(
1012
+ img_size=img_size,
1013
+ patch_size=patch_size,
1014
+ in_chans=embed_dim,
1015
+ embed_dim=embed_dim,
1016
+ norm_layer=norm_layer if self.patch_norm else None,
1017
+ )
1018
+ num_patches = self.patch_embed.num_patches
1019
+ patches_resolution = self.patch_embed.patches_resolution
1020
+ self.patches_resolution = patches_resolution
1021
+
1022
+ # merge non-overlapping patches into image
1023
+ self.patch_unembed = PatchUnEmbed(
1024
+ img_size=img_size,
1025
+ patch_size=patch_size,
1026
+ in_chans=embed_dim,
1027
+ embed_dim=embed_dim,
1028
+ norm_layer=norm_layer if self.patch_norm else None,
1029
+ )
1030
+
1031
+ # absolute position embedding
1032
+ if self.ape:
1033
+ self.absolute_pos_embed = nn.Parameter( # type: ignore
1034
+ torch.zeros(1, num_patches, embed_dim)
1035
+ )
1036
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
1037
+
1038
+ self.pos_drop = nn.Dropout(p=drop_rate)
1039
+
1040
+ # stochastic depth
1041
+ dpr = [
1042
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
1043
+ ] # stochastic depth decay rule
1044
+
1045
+ # build Residual Swin Transformer blocks (RSTB)
1046
+ self.layers = nn.ModuleList()
1047
+ for i_layer in range(self.num_layers):
1048
+ layer = RSTB(
1049
+ dim=embed_dim,
1050
+ input_resolution=(patches_resolution[0], patches_resolution[1]),
1051
+ depth=depths[i_layer],
1052
+ num_heads=num_heads[i_layer],
1053
+ window_size=window_size,
1054
+ mlp_ratio=self.mlp_ratio,
1055
+ qkv_bias=qkv_bias,
1056
+ qk_scale=qk_scale,
1057
+ drop=drop_rate,
1058
+ attn_drop=attn_drop_rate,
1059
+ drop_path=dpr[
1060
+ sum(depths[:i_layer]) : sum(depths[: i_layer + 1]) # type: ignore
1061
+ ], # no impact on SR results
1062
+ norm_layer=norm_layer,
1063
+ downsample=None,
1064
+ use_checkpoint=use_checkpoint,
1065
+ img_size=img_size,
1066
+ patch_size=patch_size,
1067
+ resi_connection=resi_connection,
1068
+ )
1069
+ self.layers.append(layer)
1070
+ self.norm = norm_layer(self.num_features)
1071
+
1072
+ # build the last conv layer in deep feature extraction
1073
+ if resi_connection == "1conv":
1074
+ self.conv_after_body = nn.Conv2d(embed_dim, embed_dim, 3, 1, 1)
1075
+ elif resi_connection == "3conv":
1076
+ # to save parameters and memory
1077
+ self.conv_after_body = nn.Sequential(
1078
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
1079
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1080
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
1081
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
1082
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
1083
+ )
1084
+
1085
+ #####################################################################################################
1086
+ ################################ 3, high quality image reconstruction ################################
1087
+ if self.upsampler == "pixelshuffle":
1088
+ # for classical SR
1089
+ self.conv_before_upsample = nn.Sequential(
1090
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1091
+ )
1092
+ self.upsample = Upsample(upscale, num_feat)
1093
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1094
+ elif self.upsampler == "pixelshuffledirect":
1095
+ # for lightweight SR (to save parameters)
1096
+ self.upsample = UpsampleOneStep(
1097
+ upscale,
1098
+ embed_dim,
1099
+ num_out_ch,
1100
+ (patches_resolution[0], patches_resolution[1]),
1101
+ )
1102
+ elif self.upsampler == "nearest+conv":
1103
+ # for real-world SR (less artifacts)
1104
+ self.conv_before_upsample = nn.Sequential(
1105
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
1106
+ )
1107
+ self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1108
+ if self.upscale == 4:
1109
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1110
+ elif self.upscale == 8:
1111
+ self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1112
+ self.conv_up3 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1113
+ self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
1114
+ self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
1115
+ self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
1116
+ else:
1117
+ # for image denoising and JPEG compression artifact reduction
1118
+ self.conv_last = nn.Conv2d(embed_dim, num_out_ch, 3, 1, 1)
1119
+
1120
+ self.apply(self._init_weights)
1121
+ self.load_state_dict(self.state, strict=False)
1122
+
1123
+ def _init_weights(self, m):
1124
+ if isinstance(m, nn.Linear):
1125
+ trunc_normal_(m.weight, std=0.02)
1126
+ if isinstance(m, nn.Linear) and m.bias is not None:
1127
+ nn.init.constant_(m.bias, 0)
1128
+ elif isinstance(m, nn.LayerNorm):
1129
+ nn.init.constant_(m.bias, 0)
1130
+ nn.init.constant_(m.weight, 1.0)
1131
+
1132
+ @torch.jit.ignore # type: ignore
1133
+ def no_weight_decay(self):
1134
+ return {"absolute_pos_embed"}
1135
+
1136
+ @torch.jit.ignore # type: ignore
1137
+ def no_weight_decay_keywords(self):
1138
+ return {"relative_position_bias_table"}
1139
+
1140
+ def check_image_size(self, x):
1141
+ _, _, h, w = x.size()
1142
+ mod_pad_h = (self.window_size - h % self.window_size) % self.window_size
1143
+ mod_pad_w = (self.window_size - w % self.window_size) % self.window_size
1144
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h), "reflect")
1145
+ return x
1146
+
1147
+ def forward_features(self, x):
1148
+ x_size = (x.shape[2], x.shape[3])
1149
+ x = self.patch_embed(x)
1150
+ if self.ape:
1151
+ x = x + self.absolute_pos_embed
1152
+ x = self.pos_drop(x)
1153
+
1154
+ for layer in self.layers:
1155
+ x = layer(x, x_size)
1156
+
1157
+ x = self.norm(x) # B L C
1158
+ x = self.patch_unembed(x, x_size)
1159
+
1160
+ return x
1161
+
1162
+ def forward(self, x):
1163
+ H, W = x.shape[2:]
1164
+ x = self.check_image_size(x)
1165
+
1166
+ self.mean = self.mean.type_as(x)
1167
+ x = (x - self.mean) * self.img_range
1168
+
1169
+ if self.start_unshuffle > 1:
1170
+ x = torch.nn.functional.pixel_unshuffle(x, self.start_unshuffle)
1171
+
1172
+ if self.upsampler == "pixelshuffle":
1173
+ # for classical SR
1174
+ x = self.conv_first(x)
1175
+ x = self.conv_after_body(self.forward_features(x)) + x
1176
+ x = self.conv_before_upsample(x)
1177
+ x = self.conv_last(self.upsample(x))
1178
+ elif self.upsampler == "pixelshuffledirect":
1179
+ # for lightweight SR
1180
+ x = self.conv_first(x)
1181
+ x = self.conv_after_body(self.forward_features(x)) + x
1182
+ x = self.upsample(x)
1183
+ elif self.upsampler == "nearest+conv":
1184
+ # for real-world SR
1185
+ x = self.conv_first(x)
1186
+ x = self.conv_after_body(self.forward_features(x)) + x
1187
+ x = self.conv_before_upsample(x)
1188
+ x = self.lrelu(
1189
+ self.conv_up1(
1190
+ torch.nn.functional.interpolate(x, scale_factor=2, mode="nearest") # type: ignore
1191
+ )
1192
+ )
1193
+ if self.upscale == 4:
1194
+ x = self.lrelu(
1195
+ self.conv_up2(
1196
+ torch.nn.functional.interpolate( # type: ignore
1197
+ x, scale_factor=2, mode="nearest"
1198
+ )
1199
+ )
1200
+ )
1201
+ elif self.upscale == 8:
1202
+ x = self.lrelu(self.conv_up2(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1203
+ x = self.lrelu(self.conv_up3(torch.nn.functional.interpolate(x, scale_factor=2, mode='nearest')))
1204
+ x = self.conv_last(self.lrelu(self.conv_hr(x)))
1205
+ else:
1206
+ # for image denoising and JPEG compression artifact reduction
1207
+ x_first = self.conv_first(x)
1208
+ res = self.conv_after_body(self.forward_features(x_first)) + x_first
1209
+ x = x + self.conv_last(res)
1210
+
1211
+ x = x / self.img_range + self.mean
1212
+
1213
+ return x[:, :, : H * self.upscale, : W * self.upscale]
1214
+
1215
+ def flops(self):
1216
+ flops = 0
1217
+ H, W = self.patches_resolution
1218
+ flops += H * W * 3 * self.embed_dim * 9
1219
+ flops += self.patch_embed.flops()
1220
+ for i, layer in enumerate(self.layers):
1221
+ flops += layer.flops() # type: ignore
1222
+ flops += H * W * 3 * self.embed_dim * self.embed_dim
1223
+ flops += self.upsample.flops() # type: ignore
1224
+ return flops
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/__init__.py ADDED
File without changes
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/block.py ADDED
@@ -0,0 +1,546 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from __future__ import annotations
5
+
6
+ from collections import OrderedDict
7
+ try:
8
+ from typing import Literal
9
+ except ImportError:
10
+ from typing_extensions import Literal
11
+
12
+ import torch
13
+ import torch.nn as nn
14
+
15
+ ####################
16
+ # Basic blocks
17
+ ####################
18
+
19
+
20
+ def act(act_type: str, inplace=True, neg_slope=0.2, n_prelu=1):
21
+ # helper selecting activation
22
+ # neg_slope: for leakyrelu and init of prelu
23
+ # n_prelu: for p_relu num_parameters
24
+ act_type = act_type.lower()
25
+ if act_type == "relu":
26
+ layer = nn.ReLU(inplace)
27
+ elif act_type == "leakyrelu":
28
+ layer = nn.LeakyReLU(neg_slope, inplace)
29
+ elif act_type == "prelu":
30
+ layer = nn.PReLU(num_parameters=n_prelu, init=neg_slope)
31
+ else:
32
+ raise NotImplementedError(
33
+ "activation layer [{:s}] is not found".format(act_type)
34
+ )
35
+ return layer
36
+
37
+
38
+ def norm(norm_type: str, nc: int):
39
+ # helper selecting normalization layer
40
+ norm_type = norm_type.lower()
41
+ if norm_type == "batch":
42
+ layer = nn.BatchNorm2d(nc, affine=True)
43
+ elif norm_type == "instance":
44
+ layer = nn.InstanceNorm2d(nc, affine=False)
45
+ else:
46
+ raise NotImplementedError(
47
+ "normalization layer [{:s}] is not found".format(norm_type)
48
+ )
49
+ return layer
50
+
51
+
52
+ def pad(pad_type: str, padding):
53
+ # helper selecting padding layer
54
+ # if padding is 'zero', do by conv layers
55
+ pad_type = pad_type.lower()
56
+ if padding == 0:
57
+ return None
58
+ if pad_type == "reflect":
59
+ layer = nn.ReflectionPad2d(padding)
60
+ elif pad_type == "replicate":
61
+ layer = nn.ReplicationPad2d(padding)
62
+ else:
63
+ raise NotImplementedError(
64
+ "padding layer [{:s}] is not implemented".format(pad_type)
65
+ )
66
+ return layer
67
+
68
+
69
+ def get_valid_padding(kernel_size, dilation):
70
+ kernel_size = kernel_size + (kernel_size - 1) * (dilation - 1)
71
+ padding = (kernel_size - 1) // 2
72
+ return padding
73
+
74
+
75
+ class ConcatBlock(nn.Module):
76
+ # Concat the output of a submodule to its input
77
+ def __init__(self, submodule):
78
+ super(ConcatBlock, self).__init__()
79
+ self.sub = submodule
80
+
81
+ def forward(self, x):
82
+ output = torch.cat((x, self.sub(x)), dim=1)
83
+ return output
84
+
85
+ def __repr__(self):
86
+ tmpstr = "Identity .. \n|"
87
+ modstr = self.sub.__repr__().replace("\n", "\n|")
88
+ tmpstr = tmpstr + modstr
89
+ return tmpstr
90
+
91
+
92
+ class ShortcutBlock(nn.Module):
93
+ # Elementwise sum the output of a submodule to its input
94
+ def __init__(self, submodule):
95
+ super(ShortcutBlock, self).__init__()
96
+ self.sub = submodule
97
+
98
+ def forward(self, x):
99
+ output = x + self.sub(x)
100
+ return output
101
+
102
+ def __repr__(self):
103
+ tmpstr = "Identity + \n|"
104
+ modstr = self.sub.__repr__().replace("\n", "\n|")
105
+ tmpstr = tmpstr + modstr
106
+ return tmpstr
107
+
108
+
109
+ class ShortcutBlockSPSR(nn.Module):
110
+ # Elementwise sum the output of a submodule to its input
111
+ def __init__(self, submodule):
112
+ super(ShortcutBlockSPSR, self).__init__()
113
+ self.sub = submodule
114
+
115
+ def forward(self, x):
116
+ return x, self.sub
117
+
118
+ def __repr__(self):
119
+ tmpstr = "Identity + \n|"
120
+ modstr = self.sub.__repr__().replace("\n", "\n|")
121
+ tmpstr = tmpstr + modstr
122
+ return tmpstr
123
+
124
+
125
+ def sequential(*args):
126
+ # Flatten Sequential. It unwraps nn.Sequential.
127
+ if len(args) == 1:
128
+ if isinstance(args[0], OrderedDict):
129
+ raise NotImplementedError("sequential does not support OrderedDict input.")
130
+ return args[0] # No sequential is needed.
131
+ modules = []
132
+ for module in args:
133
+ if isinstance(module, nn.Sequential):
134
+ for submodule in module.children():
135
+ modules.append(submodule)
136
+ elif isinstance(module, nn.Module):
137
+ modules.append(module)
138
+ return nn.Sequential(*modules)
139
+
140
+
141
+ ConvMode = Literal["CNA", "NAC", "CNAC"]
142
+
143
+
144
+ # 2x2x2 Conv Block
145
+ def conv_block_2c2(
146
+ in_nc,
147
+ out_nc,
148
+ act_type="relu",
149
+ ):
150
+ return sequential(
151
+ nn.Conv2d(in_nc, out_nc, kernel_size=2, padding=1),
152
+ nn.Conv2d(out_nc, out_nc, kernel_size=2, padding=0),
153
+ act(act_type) if act_type else None,
154
+ )
155
+
156
+
157
+ def conv_block(
158
+ in_nc: int,
159
+ out_nc: int,
160
+ kernel_size,
161
+ stride=1,
162
+ dilation=1,
163
+ groups=1,
164
+ bias=True,
165
+ pad_type="zero",
166
+ norm_type: str | None = None,
167
+ act_type: str | None = "relu",
168
+ mode: ConvMode = "CNA",
169
+ c2x2=False,
170
+ ):
171
+ """
172
+ Conv layer with padding, normalization, activation
173
+ mode: CNA --> Conv -> Norm -> Act
174
+ NAC --> Norm -> Act --> Conv (Identity Mappings in Deep Residual Networks, ECCV16)
175
+ """
176
+
177
+ if c2x2:
178
+ return conv_block_2c2(in_nc, out_nc, act_type=act_type)
179
+
180
+ assert mode in ("CNA", "NAC", "CNAC"), "Wrong conv mode [{:s}]".format(mode)
181
+ padding = get_valid_padding(kernel_size, dilation)
182
+ p = pad(pad_type, padding) if pad_type and pad_type != "zero" else None
183
+ padding = padding if pad_type == "zero" else 0
184
+
185
+ c = nn.Conv2d(
186
+ in_nc,
187
+ out_nc,
188
+ kernel_size=kernel_size,
189
+ stride=stride,
190
+ padding=padding,
191
+ dilation=dilation,
192
+ bias=bias,
193
+ groups=groups,
194
+ )
195
+ a = act(act_type) if act_type else None
196
+ if mode in ("CNA", "CNAC"):
197
+ n = norm(norm_type, out_nc) if norm_type else None
198
+ return sequential(p, c, n, a)
199
+ elif mode == "NAC":
200
+ if norm_type is None and act_type is not None:
201
+ a = act(act_type, inplace=False)
202
+ # Important!
203
+ # input----ReLU(inplace)----Conv--+----output
204
+ # |________________________|
205
+ # inplace ReLU will modify the input, therefore wrong output
206
+ n = norm(norm_type, in_nc) if norm_type else None
207
+ return sequential(n, a, p, c)
208
+ else:
209
+ assert False, f"Invalid conv mode {mode}"
210
+
211
+
212
+ ####################
213
+ # Useful blocks
214
+ ####################
215
+
216
+
217
+ class ResNetBlock(nn.Module):
218
+ """
219
+ ResNet Block, 3-3 style
220
+ with extra residual scaling used in EDSR
221
+ (Enhanced Deep Residual Networks for Single Image Super-Resolution, CVPRW 17)
222
+ """
223
+
224
+ def __init__(
225
+ self,
226
+ in_nc,
227
+ mid_nc,
228
+ out_nc,
229
+ kernel_size=3,
230
+ stride=1,
231
+ dilation=1,
232
+ groups=1,
233
+ bias=True,
234
+ pad_type="zero",
235
+ norm_type=None,
236
+ act_type="relu",
237
+ mode: ConvMode = "CNA",
238
+ res_scale=1,
239
+ ):
240
+ super(ResNetBlock, self).__init__()
241
+ conv0 = conv_block(
242
+ in_nc,
243
+ mid_nc,
244
+ kernel_size,
245
+ stride,
246
+ dilation,
247
+ groups,
248
+ bias,
249
+ pad_type,
250
+ norm_type,
251
+ act_type,
252
+ mode,
253
+ )
254
+ if mode == "CNA":
255
+ act_type = None
256
+ if mode == "CNAC": # Residual path: |-CNAC-|
257
+ act_type = None
258
+ norm_type = None
259
+ conv1 = conv_block(
260
+ mid_nc,
261
+ out_nc,
262
+ kernel_size,
263
+ stride,
264
+ dilation,
265
+ groups,
266
+ bias,
267
+ pad_type,
268
+ norm_type,
269
+ act_type,
270
+ mode,
271
+ )
272
+ # if in_nc != out_nc:
273
+ # self.project = conv_block(in_nc, out_nc, 1, stride, dilation, 1, bias, pad_type, \
274
+ # None, None)
275
+ # print('Need a projecter in ResNetBlock.')
276
+ # else:
277
+ # self.project = lambda x:x
278
+ self.res = sequential(conv0, conv1)
279
+ self.res_scale = res_scale
280
+
281
+ def forward(self, x):
282
+ res = self.res(x).mul(self.res_scale)
283
+ return x + res
284
+
285
+
286
+ class RRDB(nn.Module):
287
+ """
288
+ Residual in Residual Dense Block
289
+ (ESRGAN: Enhanced Super-Resolution Generative Adversarial Networks)
290
+ """
291
+
292
+ def __init__(
293
+ self,
294
+ nf,
295
+ kernel_size=3,
296
+ gc=32,
297
+ stride=1,
298
+ bias: bool = True,
299
+ pad_type="zero",
300
+ norm_type=None,
301
+ act_type="leakyrelu",
302
+ mode: ConvMode = "CNA",
303
+ _convtype="Conv2D",
304
+ _spectral_norm=False,
305
+ plus=False,
306
+ c2x2=False,
307
+ ):
308
+ super(RRDB, self).__init__()
309
+ self.RDB1 = ResidualDenseBlock_5C(
310
+ nf,
311
+ kernel_size,
312
+ gc,
313
+ stride,
314
+ bias,
315
+ pad_type,
316
+ norm_type,
317
+ act_type,
318
+ mode,
319
+ plus=plus,
320
+ c2x2=c2x2,
321
+ )
322
+ self.RDB2 = ResidualDenseBlock_5C(
323
+ nf,
324
+ kernel_size,
325
+ gc,
326
+ stride,
327
+ bias,
328
+ pad_type,
329
+ norm_type,
330
+ act_type,
331
+ mode,
332
+ plus=plus,
333
+ c2x2=c2x2,
334
+ )
335
+ self.RDB3 = ResidualDenseBlock_5C(
336
+ nf,
337
+ kernel_size,
338
+ gc,
339
+ stride,
340
+ bias,
341
+ pad_type,
342
+ norm_type,
343
+ act_type,
344
+ mode,
345
+ plus=plus,
346
+ c2x2=c2x2,
347
+ )
348
+
349
+ def forward(self, x):
350
+ out = self.RDB1(x)
351
+ out = self.RDB2(out)
352
+ out = self.RDB3(out)
353
+ return out * 0.2 + x
354
+
355
+
356
+ class ResidualDenseBlock_5C(nn.Module):
357
+ """
358
+ Residual Dense Block
359
+ style: 5 convs
360
+ The core module of paper: (Residual Dense Network for Image Super-Resolution, CVPR 18)
361
+ Modified options that can be used:
362
+ - "Partial Convolution based Padding" arXiv:1811.11718
363
+ - "Spectral normalization" arXiv:1802.05957
364
+ - "ICASSP 2020 - ESRGAN+ : Further Improving ESRGAN" N. C.
365
+ {Rakotonirina} and A. {Rasoanaivo}
366
+
367
+ Args:
368
+ nf (int): Channel number of intermediate features (num_feat).
369
+ gc (int): Channels for each growth (num_grow_ch: growth channel,
370
+ i.e. intermediate channels).
371
+ convtype (str): the type of convolution to use. Default: 'Conv2D'
372
+ gaussian_noise (bool): enable the ESRGAN+ gaussian noise (no new
373
+ trainable parameters)
374
+ plus (bool): enable the additional residual paths from ESRGAN+
375
+ (adds trainable parameters)
376
+ """
377
+
378
+ def __init__(
379
+ self,
380
+ nf=64,
381
+ kernel_size=3,
382
+ gc=32,
383
+ stride=1,
384
+ bias: bool = True,
385
+ pad_type="zero",
386
+ norm_type=None,
387
+ act_type="leakyrelu",
388
+ mode: ConvMode = "CNA",
389
+ plus=False,
390
+ c2x2=False,
391
+ ):
392
+ super(ResidualDenseBlock_5C, self).__init__()
393
+
394
+ ## +
395
+ self.conv1x1 = conv1x1(nf, gc) if plus else None
396
+ ## +
397
+
398
+ self.conv1 = conv_block(
399
+ nf,
400
+ gc,
401
+ kernel_size,
402
+ stride,
403
+ bias=bias,
404
+ pad_type=pad_type,
405
+ norm_type=norm_type,
406
+ act_type=act_type,
407
+ mode=mode,
408
+ c2x2=c2x2,
409
+ )
410
+ self.conv2 = conv_block(
411
+ nf + gc,
412
+ gc,
413
+ kernel_size,
414
+ stride,
415
+ bias=bias,
416
+ pad_type=pad_type,
417
+ norm_type=norm_type,
418
+ act_type=act_type,
419
+ mode=mode,
420
+ c2x2=c2x2,
421
+ )
422
+ self.conv3 = conv_block(
423
+ nf + 2 * gc,
424
+ gc,
425
+ kernel_size,
426
+ stride,
427
+ bias=bias,
428
+ pad_type=pad_type,
429
+ norm_type=norm_type,
430
+ act_type=act_type,
431
+ mode=mode,
432
+ c2x2=c2x2,
433
+ )
434
+ self.conv4 = conv_block(
435
+ nf + 3 * gc,
436
+ gc,
437
+ kernel_size,
438
+ stride,
439
+ bias=bias,
440
+ pad_type=pad_type,
441
+ norm_type=norm_type,
442
+ act_type=act_type,
443
+ mode=mode,
444
+ c2x2=c2x2,
445
+ )
446
+ if mode == "CNA":
447
+ last_act = None
448
+ else:
449
+ last_act = act_type
450
+ self.conv5 = conv_block(
451
+ nf + 4 * gc,
452
+ nf,
453
+ 3,
454
+ stride,
455
+ bias=bias,
456
+ pad_type=pad_type,
457
+ norm_type=norm_type,
458
+ act_type=last_act,
459
+ mode=mode,
460
+ c2x2=c2x2,
461
+ )
462
+
463
+ def forward(self, x):
464
+ x1 = self.conv1(x)
465
+ x2 = self.conv2(torch.cat((x, x1), 1))
466
+ if self.conv1x1:
467
+ # pylint: disable=not-callable
468
+ x2 = x2 + self.conv1x1(x) # +
469
+ x3 = self.conv3(torch.cat((x, x1, x2), 1))
470
+ x4 = self.conv4(torch.cat((x, x1, x2, x3), 1))
471
+ if self.conv1x1:
472
+ x4 = x4 + x2 # +
473
+ x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
474
+ return x5 * 0.2 + x
475
+
476
+
477
+ def conv1x1(in_planes, out_planes, stride=1):
478
+ return nn.Conv2d(in_planes, out_planes, kernel_size=1, stride=stride, bias=False)
479
+
480
+
481
+ ####################
482
+ # Upsampler
483
+ ####################
484
+
485
+
486
+ def pixelshuffle_block(
487
+ in_nc: int,
488
+ out_nc: int,
489
+ upscale_factor=2,
490
+ kernel_size=3,
491
+ stride=1,
492
+ bias=True,
493
+ pad_type="zero",
494
+ norm_type: str | None = None,
495
+ act_type="relu",
496
+ ):
497
+ """
498
+ Pixel shuffle layer
499
+ (Real-Time Single Image and Video Super-Resolution Using an Efficient Sub-Pixel Convolutional
500
+ Neural Network, CVPR17)
501
+ """
502
+ conv = conv_block(
503
+ in_nc,
504
+ out_nc * (upscale_factor**2),
505
+ kernel_size,
506
+ stride,
507
+ bias=bias,
508
+ pad_type=pad_type,
509
+ norm_type=None,
510
+ act_type=None,
511
+ )
512
+ pixel_shuffle = nn.PixelShuffle(upscale_factor)
513
+
514
+ n = norm(norm_type, out_nc) if norm_type else None
515
+ a = act(act_type) if act_type else None
516
+ return sequential(conv, pixel_shuffle, n, a)
517
+
518
+
519
+ def upconv_block(
520
+ in_nc: int,
521
+ out_nc: int,
522
+ upscale_factor=2,
523
+ kernel_size=3,
524
+ stride=1,
525
+ bias=True,
526
+ pad_type="zero",
527
+ norm_type: str | None = None,
528
+ act_type="relu",
529
+ mode="nearest",
530
+ c2x2=False,
531
+ ):
532
+ # Up conv
533
+ # described in https://distill.pub/2016/deconv-checkerboard/
534
+ upsample = nn.Upsample(scale_factor=upscale_factor, mode=mode)
535
+ conv = conv_block(
536
+ in_nc,
537
+ out_nc,
538
+ kernel_size,
539
+ stride,
540
+ bias=bias,
541
+ pad_type=pad_type,
542
+ norm_type=norm_type,
543
+ act_type=act_type,
544
+ c2x2=c2x2,
545
+ )
546
+ return sequential(upsample, conv)
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/LICENSE-GFPGAN ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tencent is pleased to support the open source community by making GFPGAN available.
2
+
3
+ Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
+
5
+ GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
6
+
7
+
8
+ Terms of the Apache License Version 2.0:
9
+ ---------------------------------------------
10
+ Apache License
11
+
12
+ Version 2.0, January 2004
13
+
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+ 1. Definitions.
18
+
19
+ “License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
20
+
21
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
22
+
23
+ “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
28
+
29
+ “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
30
+
31
+ “Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
32
+
33
+ “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
34
+
35
+ “Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
36
+
37
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
40
+
41
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
42
+
43
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44
+
45
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
46
+
47
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
48
+
49
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
50
+
51
+ If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
52
+
53
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
54
+
55
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
56
+
57
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
58
+
59
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
60
+
61
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62
+
63
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
64
+
65
+ END OF TERMS AND CONDITIONS
66
+
67
+
68
+
69
+ Other dependencies and licenses:
70
+
71
+
72
+ Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
73
+ ---------------------------------------------
74
+ 1. basicsr
75
+ Copyright 2018-2020 BasicSR Authors
76
+
77
+
78
+ This BasicSR project is released under the Apache 2.0 license.
79
+
80
+ A copy of Apache 2.0 is included in this file.
81
+
82
+ StyleGAN2
83
+ The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
84
+ The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
85
+ DFDNet
86
+ The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
87
+
88
+ Terms of the Nvidia License:
89
+ ---------------------------------------------
90
+
91
+ 1. Definitions
92
+
93
+ "Licensor" means any person or entity that distributes its Work.
94
+
95
+ "Software" means the original work of authorship made available under
96
+ this License.
97
+
98
+ "Work" means the Software and any additions to or derivative works of
99
+ the Software that are made available under this License.
100
+
101
+ "Nvidia Processors" means any central processing unit (CPU), graphics
102
+ processing unit (GPU), field-programmable gate array (FPGA),
103
+ application-specific integrated circuit (ASIC) or any combination
104
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
105
+
106
+ The terms "reproduce," "reproduction," "derivative works," and
107
+ "distribution" have the meaning as provided under U.S. copyright law;
108
+ provided, however, that for the purposes of this License, derivative
109
+ works shall not include works that remain separable from, or merely
110
+ link (or bind by name) to the interfaces of, the Work.
111
+
112
+ Works, including the Software, are "made available" under this License
113
+ by including in or with the Work either (a) a copyright notice
114
+ referencing the applicability of this License to the Work, or (b) a
115
+ copy of this License.
116
+
117
+ 2. License Grants
118
+
119
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
120
+ License, each Licensor grants to you a perpetual, worldwide,
121
+ non-exclusive, royalty-free, copyright license to reproduce,
122
+ prepare derivative works of, publicly display, publicly perform,
123
+ sublicense and distribute its Work and any resulting derivative
124
+ works in any form.
125
+
126
+ 3. Limitations
127
+
128
+ 3.1 Redistribution. You may reproduce or distribute the Work only
129
+ if (a) you do so under this License, (b) you include a complete
130
+ copy of this License with your distribution, and (c) you retain
131
+ without modification any copyright, patent, trademark, or
132
+ attribution notices that are present in the Work.
133
+
134
+ 3.2 Derivative Works. You may specify that additional or different
135
+ terms apply to the use, reproduction, and distribution of your
136
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
137
+ provide that the use limitation in Section 3.3 applies to your
138
+ derivative works, and (b) you identify the specific derivative
139
+ works that are subject to Your Terms. Notwithstanding Your Terms,
140
+ this License (including the redistribution requirements in Section
141
+ 3.1) will continue to apply to the Work itself.
142
+
143
+ 3.3 Use Limitation. The Work and any derivative works thereof only
144
+ may be used or intended for use non-commercially. The Work or
145
+ derivative works thereof may be used or intended for use by Nvidia
146
+ or its affiliates commercially or non-commercially. As used herein,
147
+ "non-commercially" means for research or evaluation purposes only.
148
+
149
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
150
+ against any Licensor (including any claim, cross-claim or
151
+ counterclaim in a lawsuit) to enforce any patents that you allege
152
+ are infringed by any Work, then your rights under this License from
153
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
154
+ terminate immediately.
155
+
156
+ 3.5 Trademarks. This License does not grant any rights to use any
157
+ Licensor's or its affiliates' names, logos, or trademarks, except
158
+ as necessary to reproduce the notices described in this License.
159
+
160
+ 3.6 Termination. If you violate any term of this License, then your
161
+ rights under this License (including the grants in Sections 2.1 and
162
+ 2.2) will terminate immediately.
163
+
164
+ 4. Disclaimer of Warranty.
165
+
166
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
167
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
168
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
169
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
170
+ THIS LICENSE.
171
+
172
+ 5. Limitation of Liability.
173
+
174
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
175
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
176
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
177
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
178
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
179
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
180
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
181
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
182
+ THE POSSIBILITY OF SUCH DAMAGES.
183
+
184
+ MIT License
185
+
186
+ Copyright (c) 2019 Kim Seonghyeon
187
+
188
+ Permission is hereby granted, free of charge, to any person obtaining a copy
189
+ of this software and associated documentation files (the "Software"), to deal
190
+ in the Software without restriction, including without limitation the rights
191
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
192
+ copies of the Software, and to permit persons to whom the Software is
193
+ furnished to do so, subject to the following conditions:
194
+
195
+ The above copyright notice and this permission notice shall be included in all
196
+ copies or substantial portions of the Software.
197
+
198
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
199
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
200
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
201
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
202
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
203
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
204
+ SOFTWARE.
205
+
206
+
207
+
208
+ Open Source Software licensed under the BSD 3-Clause license:
209
+ ---------------------------------------------
210
+ 1. torchvision
211
+ Copyright (c) Soumith Chintala 2016,
212
+ All rights reserved.
213
+
214
+ 2. torch
215
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
216
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
217
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
218
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
219
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
220
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
221
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
222
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
223
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
224
+
225
+
226
+ Terms of the BSD 3-Clause License:
227
+ ---------------------------------------------
228
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
229
+
230
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
231
+
232
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
233
+
234
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
235
+
236
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
237
+
238
+
239
+
240
+ Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
241
+ ---------------------------------------------
242
+ 1. numpy
243
+ Copyright (c) 2005-2020, NumPy Developers.
244
+ All rights reserved.
245
+
246
+ A copy of BSD 3-Clause License is included in this file.
247
+
248
+ The NumPy repository and source distributions bundle several libraries that are
249
+ compatibly licensed. We list these here.
250
+
251
+ Name: Numpydoc
252
+ Files: doc/sphinxext/numpydoc/*
253
+ License: BSD-2-Clause
254
+ For details, see doc/sphinxext/LICENSE.txt
255
+
256
+ Name: scipy-sphinx-theme
257
+ Files: doc/scipy-sphinx-theme/*
258
+ License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
259
+ For details, see doc/scipy-sphinx-theme/LICENSE.txt
260
+
261
+ Name: lapack-lite
262
+ Files: numpy/linalg/lapack_lite/*
263
+ License: BSD-3-Clause
264
+ For details, see numpy/linalg/lapack_lite/LICENSE.txt
265
+
266
+ Name: tempita
267
+ Files: tools/npy_tempita/*
268
+ License: MIT
269
+ For details, see tools/npy_tempita/license.txt
270
+
271
+ Name: dragon4
272
+ Files: numpy/core/src/multiarray/dragon4.c
273
+ License: MIT
274
+ For license text, see numpy/core/src/multiarray/dragon4.c
275
+
276
+
277
+
278
+ Open Source Software licensed under the MIT license:
279
+ ---------------------------------------------
280
+ 1. facexlib
281
+ Copyright (c) 2020 Xintao Wang
282
+
283
+ 2. opencv-python
284
+ Copyright (c) Olli-Pekka Heinisuo
285
+ Please note that only files in cv2 package are used.
286
+
287
+
288
+ Terms of the MIT License:
289
+ ---------------------------------------------
290
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
291
+
292
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
293
+
294
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
295
+
296
+
297
+
298
+ Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
299
+ ---------------------------------------------
300
+ 1. tqdm
301
+ Copyright (c) 2013 noamraph
302
+
303
+ `tqdm` is a product of collaborative work.
304
+ Unless otherwise stated, all authors (see commit logs) retain copyright
305
+ for their respective work, and release the work under the MIT licence
306
+ (text below).
307
+
308
+ Exceptions or notable authors are listed below
309
+ in reverse chronological order:
310
+
311
+ * files: *
312
+ MPLv2.0 2015-2020 (c) Casper da Costa-Luis
313
+ [casperdcl](https://github.com/casperdcl).
314
+ * files: tqdm/_tqdm.py
315
+ MIT 2016 (c) [PR #96] on behalf of Google Inc.
316
+ * files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
317
+ MIT 2013 (c) Noam Yorav-Raphael, original author.
318
+
319
+ [PR #96]: https://github.com/tqdm/tqdm/pull/96
320
+
321
+
322
+ Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
323
+ -----------------------------------------------
324
+
325
+ This Source Code Form is subject to the terms of the
326
+ Mozilla Public License, v. 2.0.
327
+ If a copy of the MPL was not distributed with this file,
328
+ You can obtain one at https://mozilla.org/MPL/2.0/.
329
+
330
+
331
+ MIT License (MIT)
332
+ -----------------
333
+
334
+ Copyright (c) 2013 noamraph
335
+
336
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
337
+ this software and associated documentation files (the "Software"), to deal in
338
+ the Software without restriction, including without limitation the rights to
339
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
340
+ the Software, and to permit persons to whom the Software is furnished to do so,
341
+ subject to the following conditions:
342
+
343
+ The above copyright notice and this permission notice shall be included in all
344
+ copies or substantial portions of the Software.
345
+
346
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
347
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
348
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
349
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
350
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
351
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/LICENSE-RestoreFormer ADDED
@@ -0,0 +1,351 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Tencent is pleased to support the open source community by making GFPGAN available.
2
+
3
+ Copyright (C) 2021 THL A29 Limited, a Tencent company. All rights reserved.
4
+
5
+ GFPGAN is licensed under the Apache License Version 2.0 except for the third-party components listed below.
6
+
7
+
8
+ Terms of the Apache License Version 2.0:
9
+ ---------------------------------------------
10
+ Apache License
11
+
12
+ Version 2.0, January 2004
13
+
14
+ http://www.apache.org/licenses/
15
+
16
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
17
+ 1. Definitions.
18
+
19
+ “License” shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
20
+
21
+ “Licensor” shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
22
+
23
+ “Legal Entity” shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, “control” means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
24
+
25
+ “You” (or “Your”) shall mean an individual or Legal Entity exercising permissions granted by this License.
26
+
27
+ “Source” form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
28
+
29
+ “Object” form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
30
+
31
+ “Work” shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
32
+
33
+ “Derivative Works” shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
34
+
35
+ “Contribution” shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, “submitted” means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as “Not a Contribution.”
36
+
37
+ “Contributor” shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
38
+
39
+ 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
40
+
41
+ 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
42
+
43
+ 4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
44
+
45
+ You must give any other recipients of the Work or Derivative Works a copy of this License; and
46
+
47
+ You must cause any modified files to carry prominent notices stating that You changed the files; and
48
+
49
+ You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
50
+
51
+ If the Work includes a “NOTICE” text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
52
+
53
+ You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
54
+
55
+ 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
56
+
57
+ 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
58
+
59
+ 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an “AS IS” BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
60
+
61
+ 8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
62
+
63
+ 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
64
+
65
+ END OF TERMS AND CONDITIONS
66
+
67
+
68
+
69
+ Other dependencies and licenses:
70
+
71
+
72
+ Open Source Software licensed under the Apache 2.0 license and Other Licenses of the Third-Party Components therein:
73
+ ---------------------------------------------
74
+ 1. basicsr
75
+ Copyright 2018-2020 BasicSR Authors
76
+
77
+
78
+ This BasicSR project is released under the Apache 2.0 license.
79
+
80
+ A copy of Apache 2.0 is included in this file.
81
+
82
+ StyleGAN2
83
+ The codes are modified from the repository stylegan2-pytorch. Many thanks to the author - Kim Seonghyeon 😊 for translating from the official TensorFlow codes to PyTorch ones. Here is the license of stylegan2-pytorch.
84
+ The official repository is https://github.com/NVlabs/stylegan2, and here is the NVIDIA license.
85
+ DFDNet
86
+ The codes are largely modified from the repository DFDNet. Their license is Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License.
87
+
88
+ Terms of the Nvidia License:
89
+ ---------------------------------------------
90
+
91
+ 1. Definitions
92
+
93
+ "Licensor" means any person or entity that distributes its Work.
94
+
95
+ "Software" means the original work of authorship made available under
96
+ this License.
97
+
98
+ "Work" means the Software and any additions to or derivative works of
99
+ the Software that are made available under this License.
100
+
101
+ "Nvidia Processors" means any central processing unit (CPU), graphics
102
+ processing unit (GPU), field-programmable gate array (FPGA),
103
+ application-specific integrated circuit (ASIC) or any combination
104
+ thereof designed, made, sold, or provided by Nvidia or its affiliates.
105
+
106
+ The terms "reproduce," "reproduction," "derivative works," and
107
+ "distribution" have the meaning as provided under U.S. copyright law;
108
+ provided, however, that for the purposes of this License, derivative
109
+ works shall not include works that remain separable from, or merely
110
+ link (or bind by name) to the interfaces of, the Work.
111
+
112
+ Works, including the Software, are "made available" under this License
113
+ by including in or with the Work either (a) a copyright notice
114
+ referencing the applicability of this License to the Work, or (b) a
115
+ copy of this License.
116
+
117
+ 2. License Grants
118
+
119
+ 2.1 Copyright Grant. Subject to the terms and conditions of this
120
+ License, each Licensor grants to you a perpetual, worldwide,
121
+ non-exclusive, royalty-free, copyright license to reproduce,
122
+ prepare derivative works of, publicly display, publicly perform,
123
+ sublicense and distribute its Work and any resulting derivative
124
+ works in any form.
125
+
126
+ 3. Limitations
127
+
128
+ 3.1 Redistribution. You may reproduce or distribute the Work only
129
+ if (a) you do so under this License, (b) you include a complete
130
+ copy of this License with your distribution, and (c) you retain
131
+ without modification any copyright, patent, trademark, or
132
+ attribution notices that are present in the Work.
133
+
134
+ 3.2 Derivative Works. You may specify that additional or different
135
+ terms apply to the use, reproduction, and distribution of your
136
+ derivative works of the Work ("Your Terms") only if (a) Your Terms
137
+ provide that the use limitation in Section 3.3 applies to your
138
+ derivative works, and (b) you identify the specific derivative
139
+ works that are subject to Your Terms. Notwithstanding Your Terms,
140
+ this License (including the redistribution requirements in Section
141
+ 3.1) will continue to apply to the Work itself.
142
+
143
+ 3.3 Use Limitation. The Work and any derivative works thereof only
144
+ may be used or intended for use non-commercially. The Work or
145
+ derivative works thereof may be used or intended for use by Nvidia
146
+ or its affiliates commercially or non-commercially. As used herein,
147
+ "non-commercially" means for research or evaluation purposes only.
148
+
149
+ 3.4 Patent Claims. If you bring or threaten to bring a patent claim
150
+ against any Licensor (including any claim, cross-claim or
151
+ counterclaim in a lawsuit) to enforce any patents that you allege
152
+ are infringed by any Work, then your rights under this License from
153
+ such Licensor (including the grants in Sections 2.1 and 2.2) will
154
+ terminate immediately.
155
+
156
+ 3.5 Trademarks. This License does not grant any rights to use any
157
+ Licensor's or its affiliates' names, logos, or trademarks, except
158
+ as necessary to reproduce the notices described in this License.
159
+
160
+ 3.6 Termination. If you violate any term of this License, then your
161
+ rights under this License (including the grants in Sections 2.1 and
162
+ 2.2) will terminate immediately.
163
+
164
+ 4. Disclaimer of Warranty.
165
+
166
+ THE WORK IS PROVIDED "AS IS" WITHOUT WARRANTIES OR CONDITIONS OF ANY
167
+ KIND, EITHER EXPRESS OR IMPLIED, INCLUDING WARRANTIES OR CONDITIONS OF
168
+ MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE, TITLE OR
169
+ NON-INFRINGEMENT. YOU BEAR THE RISK OF UNDERTAKING ANY ACTIVITIES UNDER
170
+ THIS LICENSE.
171
+
172
+ 5. Limitation of Liability.
173
+
174
+ EXCEPT AS PROHIBITED BY APPLICABLE LAW, IN NO EVENT AND UNDER NO LEGAL
175
+ THEORY, WHETHER IN TORT (INCLUDING NEGLIGENCE), CONTRACT, OR OTHERWISE
176
+ SHALL ANY LICENSOR BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY DIRECT,
177
+ INDIRECT, SPECIAL, INCIDENTAL, OR CONSEQUENTIAL DAMAGES ARISING OUT OF
178
+ OR RELATED TO THIS LICENSE, THE USE OR INABILITY TO USE THE WORK
179
+ (INCLUDING BUT NOT LIMITED TO LOSS OF GOODWILL, BUSINESS INTERRUPTION,
180
+ LOST PROFITS OR DATA, COMPUTER FAILURE OR MALFUNCTION, OR ANY OTHER
181
+ COMMERCIAL DAMAGES OR LOSSES), EVEN IF THE LICENSOR HAS BEEN ADVISED OF
182
+ THE POSSIBILITY OF SUCH DAMAGES.
183
+
184
+ MIT License
185
+
186
+ Copyright (c) 2019 Kim Seonghyeon
187
+
188
+ Permission is hereby granted, free of charge, to any person obtaining a copy
189
+ of this software and associated documentation files (the "Software"), to deal
190
+ in the Software without restriction, including without limitation the rights
191
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
192
+ copies of the Software, and to permit persons to whom the Software is
193
+ furnished to do so, subject to the following conditions:
194
+
195
+ The above copyright notice and this permission notice shall be included in all
196
+ copies or substantial portions of the Software.
197
+
198
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
199
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
200
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
201
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
202
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
203
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
204
+ SOFTWARE.
205
+
206
+
207
+
208
+ Open Source Software licensed under the BSD 3-Clause license:
209
+ ---------------------------------------------
210
+ 1. torchvision
211
+ Copyright (c) Soumith Chintala 2016,
212
+ All rights reserved.
213
+
214
+ 2. torch
215
+ Copyright (c) 2016- Facebook, Inc (Adam Paszke)
216
+ Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
217
+ Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
218
+ Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
219
+ Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
220
+ Copyright (c) 2011-2013 NYU (Clement Farabet)
221
+ Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
222
+ Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
223
+ Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
224
+
225
+
226
+ Terms of the BSD 3-Clause License:
227
+ ---------------------------------------------
228
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
229
+
230
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
231
+
232
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
233
+
234
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
235
+
236
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS “AS IS” AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
237
+
238
+
239
+
240
+ Open Source Software licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
241
+ ---------------------------------------------
242
+ 1. numpy
243
+ Copyright (c) 2005-2020, NumPy Developers.
244
+ All rights reserved.
245
+
246
+ A copy of BSD 3-Clause License is included in this file.
247
+
248
+ The NumPy repository and source distributions bundle several libraries that are
249
+ compatibly licensed. We list these here.
250
+
251
+ Name: Numpydoc
252
+ Files: doc/sphinxext/numpydoc/*
253
+ License: BSD-2-Clause
254
+ For details, see doc/sphinxext/LICENSE.txt
255
+
256
+ Name: scipy-sphinx-theme
257
+ Files: doc/scipy-sphinx-theme/*
258
+ License: BSD-3-Clause AND PSF-2.0 AND Apache-2.0
259
+ For details, see doc/scipy-sphinx-theme/LICENSE.txt
260
+
261
+ Name: lapack-lite
262
+ Files: numpy/linalg/lapack_lite/*
263
+ License: BSD-3-Clause
264
+ For details, see numpy/linalg/lapack_lite/LICENSE.txt
265
+
266
+ Name: tempita
267
+ Files: tools/npy_tempita/*
268
+ License: MIT
269
+ For details, see tools/npy_tempita/license.txt
270
+
271
+ Name: dragon4
272
+ Files: numpy/core/src/multiarray/dragon4.c
273
+ License: MIT
274
+ For license text, see numpy/core/src/multiarray/dragon4.c
275
+
276
+
277
+
278
+ Open Source Software licensed under the MIT license:
279
+ ---------------------------------------------
280
+ 1. facexlib
281
+ Copyright (c) 2020 Xintao Wang
282
+
283
+ 2. opencv-python
284
+ Copyright (c) Olli-Pekka Heinisuo
285
+ Please note that only files in cv2 package are used.
286
+
287
+
288
+ Terms of the MIT License:
289
+ ---------------------------------------------
290
+ Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the “Software”), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
291
+
292
+ The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
293
+
294
+ THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
295
+
296
+
297
+
298
+ Open Source Software licensed under the MIT license and Other Licenses of the Third-Party Components therein:
299
+ ---------------------------------------------
300
+ 1. tqdm
301
+ Copyright (c) 2013 noamraph
302
+
303
+ `tqdm` is a product of collaborative work.
304
+ Unless otherwise stated, all authors (see commit logs) retain copyright
305
+ for their respective work, and release the work under the MIT licence
306
+ (text below).
307
+
308
+ Exceptions or notable authors are listed below
309
+ in reverse chronological order:
310
+
311
+ * files: *
312
+ MPLv2.0 2015-2020 (c) Casper da Costa-Luis
313
+ [casperdcl](https://github.com/casperdcl).
314
+ * files: tqdm/_tqdm.py
315
+ MIT 2016 (c) [PR #96] on behalf of Google Inc.
316
+ * files: tqdm/_tqdm.py setup.py README.rst MANIFEST.in .gitignore
317
+ MIT 2013 (c) Noam Yorav-Raphael, original author.
318
+
319
+ [PR #96]: https://github.com/tqdm/tqdm/pull/96
320
+
321
+
322
+ Mozilla Public Licence (MPL) v. 2.0 - Exhibit A
323
+ -----------------------------------------------
324
+
325
+ This Source Code Form is subject to the terms of the
326
+ Mozilla Public License, v. 2.0.
327
+ If a copy of the MPL was not distributed with this file,
328
+ You can obtain one at https://mozilla.org/MPL/2.0/.
329
+
330
+
331
+ MIT License (MIT)
332
+ -----------------
333
+
334
+ Copyright (c) 2013 noamraph
335
+
336
+ Permission is hereby granted, free of charge, to any person obtaining a copy of
337
+ this software and associated documentation files (the "Software"), to deal in
338
+ the Software without restriction, including without limitation the rights to
339
+ use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
340
+ the Software, and to permit persons to whom the Software is furnished to do so,
341
+ subject to the following conditions:
342
+
343
+ The above copyright notice and this permission notice shall be included in all
344
+ copies or substantial portions of the Software.
345
+
346
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
347
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
348
+ FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
349
+ COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
350
+ IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
351
+ CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/LICENSE-codeformer ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ S-Lab License 1.0
2
+
3
+ Copyright 2022 S-Lab
4
+
5
+ Redistribution and use for non-commercial purpose in source and
6
+ binary forms, with or without modification, are permitted provided
7
+ that the following conditions are met:
8
+
9
+ 1. Redistributions of source code must retain the above copyright
10
+ notice, this list of conditions and the following disclaimer.
11
+
12
+ 2. Redistributions in binary form must reproduce the above copyright
13
+ notice, this list of conditions and the following disclaimer in
14
+ the documentation and/or other materials provided with the
15
+ distribution.
16
+
17
+ 3. Neither the name of the copyright holder nor the names of its
18
+ contributors may be used to endorse or promote products derived
19
+ from this software without specific prior written permission.
20
+
21
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22
+ "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23
+ LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24
+ A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25
+ HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26
+ SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27
+ LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29
+ THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30
+ (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31
+ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
32
+
33
+ In the event that redistribution and/or use for commercial purpose in
34
+ source or binary forms, with or without modification is required,
35
+ please contact the contributor(s) of the work.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/arcface_arch.py ADDED
@@ -0,0 +1,265 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+
3
+
4
+ def conv3x3(inplanes, outplanes, stride=1):
5
+ """A simple wrapper for 3x3 convolution with padding.
6
+
7
+ Args:
8
+ inplanes (int): Channel number of inputs.
9
+ outplanes (int): Channel number of outputs.
10
+ stride (int): Stride in convolution. Default: 1.
11
+ """
12
+ return nn.Conv2d(
13
+ inplanes, outplanes, kernel_size=3, stride=stride, padding=1, bias=False
14
+ )
15
+
16
+
17
+ class BasicBlock(nn.Module):
18
+ """Basic residual block used in the ResNetArcFace architecture.
19
+
20
+ Args:
21
+ inplanes (int): Channel number of inputs.
22
+ planes (int): Channel number of outputs.
23
+ stride (int): Stride in convolution. Default: 1.
24
+ downsample (nn.Module): The downsample module. Default: None.
25
+ """
26
+
27
+ expansion = 1 # output channel expansion ratio
28
+
29
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
30
+ super(BasicBlock, self).__init__()
31
+ self.conv1 = conv3x3(inplanes, planes, stride)
32
+ self.bn1 = nn.BatchNorm2d(planes)
33
+ self.relu = nn.ReLU(inplace=True)
34
+ self.conv2 = conv3x3(planes, planes)
35
+ self.bn2 = nn.BatchNorm2d(planes)
36
+ self.downsample = downsample
37
+ self.stride = stride
38
+
39
+ def forward(self, x):
40
+ residual = x
41
+
42
+ out = self.conv1(x)
43
+ out = self.bn1(out)
44
+ out = self.relu(out)
45
+
46
+ out = self.conv2(out)
47
+ out = self.bn2(out)
48
+
49
+ if self.downsample is not None:
50
+ residual = self.downsample(x)
51
+
52
+ out += residual
53
+ out = self.relu(out)
54
+
55
+ return out
56
+
57
+
58
+ class IRBlock(nn.Module):
59
+ """Improved residual block (IR Block) used in the ResNetArcFace architecture.
60
+
61
+ Args:
62
+ inplanes (int): Channel number of inputs.
63
+ planes (int): Channel number of outputs.
64
+ stride (int): Stride in convolution. Default: 1.
65
+ downsample (nn.Module): The downsample module. Default: None.
66
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
67
+ """
68
+
69
+ expansion = 1 # output channel expansion ratio
70
+
71
+ def __init__(self, inplanes, planes, stride=1, downsample=None, use_se=True):
72
+ super(IRBlock, self).__init__()
73
+ self.bn0 = nn.BatchNorm2d(inplanes)
74
+ self.conv1 = conv3x3(inplanes, inplanes)
75
+ self.bn1 = nn.BatchNorm2d(inplanes)
76
+ self.prelu = nn.PReLU()
77
+ self.conv2 = conv3x3(inplanes, planes, stride)
78
+ self.bn2 = nn.BatchNorm2d(planes)
79
+ self.downsample = downsample
80
+ self.stride = stride
81
+ self.use_se = use_se
82
+ if self.use_se:
83
+ self.se = SEBlock(planes)
84
+
85
+ def forward(self, x):
86
+ residual = x
87
+ out = self.bn0(x)
88
+ out = self.conv1(out)
89
+ out = self.bn1(out)
90
+ out = self.prelu(out)
91
+
92
+ out = self.conv2(out)
93
+ out = self.bn2(out)
94
+ if self.use_se:
95
+ out = self.se(out)
96
+
97
+ if self.downsample is not None:
98
+ residual = self.downsample(x)
99
+
100
+ out += residual
101
+ out = self.prelu(out)
102
+
103
+ return out
104
+
105
+
106
+ class Bottleneck(nn.Module):
107
+ """Bottleneck block used in the ResNetArcFace architecture.
108
+
109
+ Args:
110
+ inplanes (int): Channel number of inputs.
111
+ planes (int): Channel number of outputs.
112
+ stride (int): Stride in convolution. Default: 1.
113
+ downsample (nn.Module): The downsample module. Default: None.
114
+ """
115
+
116
+ expansion = 4 # output channel expansion ratio
117
+
118
+ def __init__(self, inplanes, planes, stride=1, downsample=None):
119
+ super(Bottleneck, self).__init__()
120
+ self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
121
+ self.bn1 = nn.BatchNorm2d(planes)
122
+ self.conv2 = nn.Conv2d(
123
+ planes, planes, kernel_size=3, stride=stride, padding=1, bias=False
124
+ )
125
+ self.bn2 = nn.BatchNorm2d(planes)
126
+ self.conv3 = nn.Conv2d(
127
+ planes, planes * self.expansion, kernel_size=1, bias=False
128
+ )
129
+ self.bn3 = nn.BatchNorm2d(planes * self.expansion)
130
+ self.relu = nn.ReLU(inplace=True)
131
+ self.downsample = downsample
132
+ self.stride = stride
133
+
134
+ def forward(self, x):
135
+ residual = x
136
+
137
+ out = self.conv1(x)
138
+ out = self.bn1(out)
139
+ out = self.relu(out)
140
+
141
+ out = self.conv2(out)
142
+ out = self.bn2(out)
143
+ out = self.relu(out)
144
+
145
+ out = self.conv3(out)
146
+ out = self.bn3(out)
147
+
148
+ if self.downsample is not None:
149
+ residual = self.downsample(x)
150
+
151
+ out += residual
152
+ out = self.relu(out)
153
+
154
+ return out
155
+
156
+
157
+ class SEBlock(nn.Module):
158
+ """The squeeze-and-excitation block (SEBlock) used in the IRBlock.
159
+
160
+ Args:
161
+ channel (int): Channel number of inputs.
162
+ reduction (int): Channel reduction ration. Default: 16.
163
+ """
164
+
165
+ def __init__(self, channel, reduction=16):
166
+ super(SEBlock, self).__init__()
167
+ self.avg_pool = nn.AdaptiveAvgPool2d(
168
+ 1
169
+ ) # pool to 1x1 without spatial information
170
+ self.fc = nn.Sequential(
171
+ nn.Linear(channel, channel // reduction),
172
+ nn.PReLU(),
173
+ nn.Linear(channel // reduction, channel),
174
+ nn.Sigmoid(),
175
+ )
176
+
177
+ def forward(self, x):
178
+ b, c, _, _ = x.size()
179
+ y = self.avg_pool(x).view(b, c)
180
+ y = self.fc(y).view(b, c, 1, 1)
181
+ return x * y
182
+
183
+
184
+ class ResNetArcFace(nn.Module):
185
+ """ArcFace with ResNet architectures.
186
+
187
+ Ref: ArcFace: Additive Angular Margin Loss for Deep Face Recognition.
188
+
189
+ Args:
190
+ block (str): Block used in the ArcFace architecture.
191
+ layers (tuple(int)): Block numbers in each layer.
192
+ use_se (bool): Whether use the SEBlock (squeeze and excitation block). Default: True.
193
+ """
194
+
195
+ def __init__(self, block, layers, use_se=True):
196
+ if block == "IRBlock":
197
+ block = IRBlock
198
+ self.inplanes = 64
199
+ self.use_se = use_se
200
+ super(ResNetArcFace, self).__init__()
201
+
202
+ self.conv1 = nn.Conv2d(1, 64, kernel_size=3, padding=1, bias=False)
203
+ self.bn1 = nn.BatchNorm2d(64)
204
+ self.prelu = nn.PReLU()
205
+ self.maxpool = nn.MaxPool2d(kernel_size=2, stride=2)
206
+ self.layer1 = self._make_layer(block, 64, layers[0])
207
+ self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
208
+ self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
209
+ self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
210
+ self.bn4 = nn.BatchNorm2d(512)
211
+ self.dropout = nn.Dropout()
212
+ self.fc5 = nn.Linear(512 * 8 * 8, 512)
213
+ self.bn5 = nn.BatchNorm1d(512)
214
+
215
+ # initialization
216
+ for m in self.modules():
217
+ if isinstance(m, nn.Conv2d):
218
+ nn.init.xavier_normal_(m.weight)
219
+ elif isinstance(m, nn.BatchNorm2d) or isinstance(m, nn.BatchNorm1d):
220
+ nn.init.constant_(m.weight, 1)
221
+ nn.init.constant_(m.bias, 0)
222
+ elif isinstance(m, nn.Linear):
223
+ nn.init.xavier_normal_(m.weight)
224
+ nn.init.constant_(m.bias, 0)
225
+
226
+ def _make_layer(self, block, planes, num_blocks, stride=1):
227
+ downsample = None
228
+ if stride != 1 or self.inplanes != planes * block.expansion:
229
+ downsample = nn.Sequential(
230
+ nn.Conv2d(
231
+ self.inplanes,
232
+ planes * block.expansion,
233
+ kernel_size=1,
234
+ stride=stride,
235
+ bias=False,
236
+ ),
237
+ nn.BatchNorm2d(planes * block.expansion),
238
+ )
239
+ layers = []
240
+ layers.append(
241
+ block(self.inplanes, planes, stride, downsample, use_se=self.use_se)
242
+ )
243
+ self.inplanes = planes
244
+ for _ in range(1, num_blocks):
245
+ layers.append(block(self.inplanes, planes, use_se=self.use_se))
246
+
247
+ return nn.Sequential(*layers)
248
+
249
+ def forward(self, x):
250
+ x = self.conv1(x)
251
+ x = self.bn1(x)
252
+ x = self.prelu(x)
253
+ x = self.maxpool(x)
254
+
255
+ x = self.layer1(x)
256
+ x = self.layer2(x)
257
+ x = self.layer3(x)
258
+ x = self.layer4(x)
259
+ x = self.bn4(x)
260
+ x = self.dropout(x)
261
+ x = x.view(x.size(0), -1)
262
+ x = self.fc5(x)
263
+ x = self.bn5(x)
264
+
265
+ return x
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/codeformer.py ADDED
@@ -0,0 +1,790 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Modified from https://github.com/sczhou/CodeFormer
3
+ VQGAN code, adapted from the original created by the Unleashing Transformers authors:
4
+ https://github.com/samb-t/unleashing-transformers/blob/master/models/vqgan.py
5
+ This verison of the arch specifically was gathered from an old version of GFPGAN. If this is a problem, please contact me.
6
+ """
7
+ import math
8
+ from typing import Optional
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.nn.functional as F
13
+ import logging as logger
14
+ from torch import Tensor
15
+
16
+
17
+ class VectorQuantizer(nn.Module):
18
+ def __init__(self, codebook_size, emb_dim, beta):
19
+ super(VectorQuantizer, self).__init__()
20
+ self.codebook_size = codebook_size # number of embeddings
21
+ self.emb_dim = emb_dim # dimension of embedding
22
+ self.beta = beta # commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
23
+ self.embedding = nn.Embedding(self.codebook_size, self.emb_dim)
24
+ self.embedding.weight.data.uniform_(
25
+ -1.0 / self.codebook_size, 1.0 / self.codebook_size
26
+ )
27
+
28
+ def forward(self, z):
29
+ # reshape z -> (batch, height, width, channel) and flatten
30
+ z = z.permute(0, 2, 3, 1).contiguous()
31
+ z_flattened = z.view(-1, self.emb_dim)
32
+
33
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
34
+ d = (
35
+ (z_flattened**2).sum(dim=1, keepdim=True)
36
+ + (self.embedding.weight**2).sum(1)
37
+ - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
38
+ )
39
+
40
+ mean_distance = torch.mean(d)
41
+ # find closest encodings
42
+ # min_encoding_indices = torch.argmin(d, dim=1).unsqueeze(1)
43
+ min_encoding_scores, min_encoding_indices = torch.topk(
44
+ d, 1, dim=1, largest=False
45
+ )
46
+ # [0-1], higher score, higher confidence
47
+ min_encoding_scores = torch.exp(-min_encoding_scores / 10)
48
+
49
+ min_encodings = torch.zeros(
50
+ min_encoding_indices.shape[0], self.codebook_size
51
+ ).to(z)
52
+ min_encodings.scatter_(1, min_encoding_indices, 1)
53
+
54
+ # get quantized latent vectors
55
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
56
+ # compute loss for embedding
57
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
58
+ (z_q - z.detach()) ** 2
59
+ )
60
+ # preserve gradients
61
+ z_q = z + (z_q - z).detach()
62
+
63
+ # perplexity
64
+ e_mean = torch.mean(min_encodings, dim=0)
65
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
66
+ # reshape back to match original input shape
67
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
68
+
69
+ return (
70
+ z_q,
71
+ loss,
72
+ {
73
+ "perplexity": perplexity,
74
+ "min_encodings": min_encodings,
75
+ "min_encoding_indices": min_encoding_indices,
76
+ "min_encoding_scores": min_encoding_scores,
77
+ "mean_distance": mean_distance,
78
+ },
79
+ )
80
+
81
+ def get_codebook_feat(self, indices, shape):
82
+ # input indices: batch*token_num -> (batch*token_num)*1
83
+ # shape: batch, height, width, channel
84
+ indices = indices.view(-1, 1)
85
+ min_encodings = torch.zeros(indices.shape[0], self.codebook_size).to(indices)
86
+ min_encodings.scatter_(1, indices, 1)
87
+ # get quantized latent vectors
88
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
89
+
90
+ if shape is not None: # reshape back to match original input shape
91
+ z_q = z_q.view(shape).permute(0, 3, 1, 2).contiguous()
92
+
93
+ return z_q
94
+
95
+
96
+ class GumbelQuantizer(nn.Module):
97
+ def __init__(
98
+ self,
99
+ codebook_size,
100
+ emb_dim,
101
+ num_hiddens,
102
+ straight_through=False,
103
+ kl_weight=5e-4,
104
+ temp_init=1.0,
105
+ ):
106
+ super().__init__()
107
+ self.codebook_size = codebook_size # number of embeddings
108
+ self.emb_dim = emb_dim # dimension of embedding
109
+ self.straight_through = straight_through
110
+ self.temperature = temp_init
111
+ self.kl_weight = kl_weight
112
+ self.proj = nn.Conv2d(
113
+ num_hiddens, codebook_size, 1
114
+ ) # projects last encoder layer to quantized logits
115
+ self.embed = nn.Embedding(codebook_size, emb_dim)
116
+
117
+ def forward(self, z):
118
+ hard = self.straight_through if self.training else True
119
+
120
+ logits = self.proj(z)
121
+
122
+ soft_one_hot = F.gumbel_softmax(logits, tau=self.temperature, dim=1, hard=hard)
123
+
124
+ z_q = torch.einsum("b n h w, n d -> b d h w", soft_one_hot, self.embed.weight)
125
+
126
+ # + kl divergence to the prior loss
127
+ qy = F.softmax(logits, dim=1)
128
+ diff = (
129
+ self.kl_weight
130
+ * torch.sum(qy * torch.log(qy * self.codebook_size + 1e-10), dim=1).mean()
131
+ )
132
+ min_encoding_indices = soft_one_hot.argmax(dim=1)
133
+
134
+ return z_q, diff, {"min_encoding_indices": min_encoding_indices}
135
+
136
+
137
+ class Downsample(nn.Module):
138
+ def __init__(self, in_channels):
139
+ super().__init__()
140
+ self.conv = torch.nn.Conv2d(
141
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
142
+ )
143
+
144
+ def forward(self, x):
145
+ pad = (0, 1, 0, 1)
146
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
147
+ x = self.conv(x)
148
+ return x
149
+
150
+
151
+ class Upsample(nn.Module):
152
+ def __init__(self, in_channels):
153
+ super().__init__()
154
+ self.conv = nn.Conv2d(
155
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
156
+ )
157
+
158
+ def forward(self, x):
159
+ x = F.interpolate(x, scale_factor=2.0, mode="nearest")
160
+ x = self.conv(x)
161
+
162
+ return x
163
+
164
+
165
+ class AttnBlock(nn.Module):
166
+ def __init__(self, in_channels):
167
+ super().__init__()
168
+ self.in_channels = in_channels
169
+
170
+ self.norm = normalize(in_channels)
171
+ self.q = torch.nn.Conv2d(
172
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
173
+ )
174
+ self.k = torch.nn.Conv2d(
175
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
176
+ )
177
+ self.v = torch.nn.Conv2d(
178
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
179
+ )
180
+ self.proj_out = torch.nn.Conv2d(
181
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
182
+ )
183
+
184
+ def forward(self, x):
185
+ h_ = x
186
+ h_ = self.norm(h_)
187
+ q = self.q(h_)
188
+ k = self.k(h_)
189
+ v = self.v(h_)
190
+
191
+ # compute attention
192
+ b, c, h, w = q.shape
193
+ q = q.reshape(b, c, h * w)
194
+ q = q.permute(0, 2, 1)
195
+ k = k.reshape(b, c, h * w)
196
+ w_ = torch.bmm(q, k)
197
+ w_ = w_ * (int(c) ** (-0.5))
198
+ w_ = F.softmax(w_, dim=2)
199
+
200
+ # attend to values
201
+ v = v.reshape(b, c, h * w)
202
+ w_ = w_.permute(0, 2, 1)
203
+ h_ = torch.bmm(v, w_)
204
+ h_ = h_.reshape(b, c, h, w)
205
+
206
+ h_ = self.proj_out(h_)
207
+
208
+ return x + h_
209
+
210
+
211
+ class Encoder(nn.Module):
212
+ def __init__(
213
+ self,
214
+ in_channels,
215
+ nf,
216
+ out_channels,
217
+ ch_mult,
218
+ num_res_blocks,
219
+ resolution,
220
+ attn_resolutions,
221
+ ):
222
+ super().__init__()
223
+ self.nf = nf
224
+ self.num_resolutions = len(ch_mult)
225
+ self.num_res_blocks = num_res_blocks
226
+ self.resolution = resolution
227
+ self.attn_resolutions = attn_resolutions
228
+
229
+ curr_res = self.resolution
230
+ in_ch_mult = (1,) + tuple(ch_mult)
231
+
232
+ blocks = []
233
+ # initial convultion
234
+ blocks.append(nn.Conv2d(in_channels, nf, kernel_size=3, stride=1, padding=1))
235
+
236
+ # residual and downsampling blocks, with attention on smaller res (16x16)
237
+ for i in range(self.num_resolutions):
238
+ block_in_ch = nf * in_ch_mult[i]
239
+ block_out_ch = nf * ch_mult[i]
240
+ for _ in range(self.num_res_blocks):
241
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
242
+ block_in_ch = block_out_ch
243
+ if curr_res in attn_resolutions:
244
+ blocks.append(AttnBlock(block_in_ch))
245
+
246
+ if i != self.num_resolutions - 1:
247
+ blocks.append(Downsample(block_in_ch))
248
+ curr_res = curr_res // 2
249
+
250
+ # non-local attention block
251
+ blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
252
+ blocks.append(AttnBlock(block_in_ch)) # type: ignore
253
+ blocks.append(ResBlock(block_in_ch, block_in_ch)) # type: ignore
254
+
255
+ # normalise and convert to latent size
256
+ blocks.append(normalize(block_in_ch)) # type: ignore
257
+ blocks.append(
258
+ nn.Conv2d(block_in_ch, out_channels, kernel_size=3, stride=1, padding=1) # type: ignore
259
+ )
260
+ self.blocks = nn.ModuleList(blocks)
261
+
262
+ def forward(self, x):
263
+ for block in self.blocks:
264
+ x = block(x)
265
+
266
+ return x
267
+
268
+
269
+ class Generator(nn.Module):
270
+ def __init__(self, nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim):
271
+ super().__init__()
272
+ self.nf = nf
273
+ self.ch_mult = ch_mult
274
+ self.num_resolutions = len(self.ch_mult)
275
+ self.num_res_blocks = res_blocks
276
+ self.resolution = img_size
277
+ self.attn_resolutions = attn_resolutions
278
+ self.in_channels = emb_dim
279
+ self.out_channels = 3
280
+ block_in_ch = self.nf * self.ch_mult[-1]
281
+ curr_res = self.resolution // 2 ** (self.num_resolutions - 1)
282
+
283
+ blocks = []
284
+ # initial conv
285
+ blocks.append(
286
+ nn.Conv2d(self.in_channels, block_in_ch, kernel_size=3, stride=1, padding=1)
287
+ )
288
+
289
+ # non-local attention block
290
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
291
+ blocks.append(AttnBlock(block_in_ch))
292
+ blocks.append(ResBlock(block_in_ch, block_in_ch))
293
+
294
+ for i in reversed(range(self.num_resolutions)):
295
+ block_out_ch = self.nf * self.ch_mult[i]
296
+
297
+ for _ in range(self.num_res_blocks):
298
+ blocks.append(ResBlock(block_in_ch, block_out_ch))
299
+ block_in_ch = block_out_ch
300
+
301
+ if curr_res in self.attn_resolutions:
302
+ blocks.append(AttnBlock(block_in_ch))
303
+
304
+ if i != 0:
305
+ blocks.append(Upsample(block_in_ch))
306
+ curr_res = curr_res * 2
307
+
308
+ blocks.append(normalize(block_in_ch))
309
+ blocks.append(
310
+ nn.Conv2d(
311
+ block_in_ch, self.out_channels, kernel_size=3, stride=1, padding=1
312
+ )
313
+ )
314
+
315
+ self.blocks = nn.ModuleList(blocks)
316
+
317
+ def forward(self, x):
318
+ for block in self.blocks:
319
+ x = block(x)
320
+
321
+ return x
322
+
323
+
324
+ class VQAutoEncoder(nn.Module):
325
+ def __init__(
326
+ self,
327
+ img_size,
328
+ nf,
329
+ ch_mult,
330
+ quantizer="nearest",
331
+ res_blocks=2,
332
+ attn_resolutions=[16],
333
+ codebook_size=1024,
334
+ emb_dim=256,
335
+ beta=0.25,
336
+ gumbel_straight_through=False,
337
+ gumbel_kl_weight=1e-8,
338
+ model_path=None,
339
+ ):
340
+ super().__init__()
341
+ self.in_channels = 3
342
+ self.nf = nf
343
+ self.n_blocks = res_blocks
344
+ self.codebook_size = codebook_size
345
+ self.embed_dim = emb_dim
346
+ self.ch_mult = ch_mult
347
+ self.resolution = img_size
348
+ self.attn_resolutions = attn_resolutions
349
+ self.quantizer_type = quantizer
350
+ self.encoder = Encoder(
351
+ self.in_channels,
352
+ self.nf,
353
+ self.embed_dim,
354
+ self.ch_mult,
355
+ self.n_blocks,
356
+ self.resolution,
357
+ self.attn_resolutions,
358
+ )
359
+ if self.quantizer_type == "nearest":
360
+ self.beta = beta # 0.25
361
+ self.quantize = VectorQuantizer(
362
+ self.codebook_size, self.embed_dim, self.beta
363
+ )
364
+ elif self.quantizer_type == "gumbel":
365
+ self.gumbel_num_hiddens = emb_dim
366
+ self.straight_through = gumbel_straight_through
367
+ self.kl_weight = gumbel_kl_weight
368
+ self.quantize = GumbelQuantizer(
369
+ self.codebook_size,
370
+ self.embed_dim,
371
+ self.gumbel_num_hiddens,
372
+ self.straight_through,
373
+ self.kl_weight,
374
+ )
375
+ self.generator = Generator(
376
+ nf, ch_mult, res_blocks, img_size, attn_resolutions, emb_dim
377
+ )
378
+
379
+ if model_path is not None:
380
+ chkpt = torch.load(model_path, map_location="cpu")
381
+ if "params_ema" in chkpt:
382
+ self.load_state_dict(
383
+ torch.load(model_path, map_location="cpu")["params_ema"]
384
+ )
385
+ logger.info(f"vqgan is loaded from: {model_path} [params_ema]")
386
+ elif "params" in chkpt:
387
+ self.load_state_dict(
388
+ torch.load(model_path, map_location="cpu")["params"]
389
+ )
390
+ logger.info(f"vqgan is loaded from: {model_path} [params]")
391
+ else:
392
+ raise ValueError("Wrong params!")
393
+
394
+ def forward(self, x):
395
+ x = self.encoder(x)
396
+ quant, codebook_loss, quant_stats = self.quantize(x)
397
+ x = self.generator(quant)
398
+ return x, codebook_loss, quant_stats
399
+
400
+
401
+ def calc_mean_std(feat, eps=1e-5):
402
+ """Calculate mean and std for adaptive_instance_normalization.
403
+ Args:
404
+ feat (Tensor): 4D tensor.
405
+ eps (float): A small value added to the variance to avoid
406
+ divide-by-zero. Default: 1e-5.
407
+ """
408
+ size = feat.size()
409
+ assert len(size) == 4, "The input feature should be 4D tensor."
410
+ b, c = size[:2]
411
+ feat_var = feat.view(b, c, -1).var(dim=2) + eps
412
+ feat_std = feat_var.sqrt().view(b, c, 1, 1)
413
+ feat_mean = feat.view(b, c, -1).mean(dim=2).view(b, c, 1, 1)
414
+ return feat_mean, feat_std
415
+
416
+
417
+ def adaptive_instance_normalization(content_feat, style_feat):
418
+ """Adaptive instance normalization.
419
+ Adjust the reference features to have the similar color and illuminations
420
+ as those in the degradate features.
421
+ Args:
422
+ content_feat (Tensor): The reference feature.
423
+ style_feat (Tensor): The degradate features.
424
+ """
425
+ size = content_feat.size()
426
+ style_mean, style_std = calc_mean_std(style_feat)
427
+ content_mean, content_std = calc_mean_std(content_feat)
428
+ normalized_feat = (content_feat - content_mean.expand(size)) / content_std.expand(
429
+ size
430
+ )
431
+ return normalized_feat * style_std.expand(size) + style_mean.expand(size)
432
+
433
+
434
+ class PositionEmbeddingSine(nn.Module):
435
+ """
436
+ This is a more standard version of the position embedding, very similar to the one
437
+ used by the Attention is all you need paper, generalized to work on images.
438
+ """
439
+
440
+ def __init__(
441
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
442
+ ):
443
+ super().__init__()
444
+ self.num_pos_feats = num_pos_feats
445
+ self.temperature = temperature
446
+ self.normalize = normalize
447
+ if scale is not None and normalize is False:
448
+ raise ValueError("normalize should be True if scale is passed")
449
+ if scale is None:
450
+ scale = 2 * math.pi
451
+ self.scale = scale
452
+
453
+ def forward(self, x, mask=None):
454
+ if mask is None:
455
+ mask = torch.zeros(
456
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
457
+ )
458
+ not_mask = ~mask # pylint: disable=invalid-unary-operand-type
459
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
460
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
461
+ if self.normalize:
462
+ eps = 1e-6
463
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
464
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
465
+
466
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
467
+ dim_t = self.temperature ** (2 * (dim_t // 2) / self.num_pos_feats)
468
+
469
+ pos_x = x_embed[:, :, :, None] / dim_t
470
+ pos_y = y_embed[:, :, :, None] / dim_t
471
+ pos_x = torch.stack(
472
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
473
+ ).flatten(3)
474
+ pos_y = torch.stack(
475
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
476
+ ).flatten(3)
477
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
478
+ return pos
479
+
480
+
481
+ def _get_activation_fn(activation):
482
+ """Return an activation function given a string"""
483
+ if activation == "relu":
484
+ return F.relu
485
+ if activation == "gelu":
486
+ return F.gelu
487
+ if activation == "glu":
488
+ return F.glu
489
+ raise RuntimeError(f"activation should be relu/gelu, not {activation}.")
490
+
491
+
492
+ class TransformerSALayer(nn.Module):
493
+ def __init__(
494
+ self, embed_dim, nhead=8, dim_mlp=2048, dropout=0.0, activation="gelu"
495
+ ):
496
+ super().__init__()
497
+ self.self_attn = nn.MultiheadAttention(embed_dim, nhead, dropout=dropout)
498
+ # Implementation of Feedforward model - MLP
499
+ self.linear1 = nn.Linear(embed_dim, dim_mlp)
500
+ self.dropout = nn.Dropout(dropout)
501
+ self.linear2 = nn.Linear(dim_mlp, embed_dim)
502
+
503
+ self.norm1 = nn.LayerNorm(embed_dim)
504
+ self.norm2 = nn.LayerNorm(embed_dim)
505
+ self.dropout1 = nn.Dropout(dropout)
506
+ self.dropout2 = nn.Dropout(dropout)
507
+
508
+ self.activation = _get_activation_fn(activation)
509
+
510
+ def with_pos_embed(self, tensor, pos: Optional[Tensor]):
511
+ return tensor if pos is None else tensor + pos
512
+
513
+ def forward(
514
+ self,
515
+ tgt,
516
+ tgt_mask: Optional[Tensor] = None,
517
+ tgt_key_padding_mask: Optional[Tensor] = None,
518
+ query_pos: Optional[Tensor] = None,
519
+ ):
520
+ # self attention
521
+ tgt2 = self.norm1(tgt)
522
+ q = k = self.with_pos_embed(tgt2, query_pos)
523
+ tgt2 = self.self_attn(
524
+ q, k, value=tgt2, attn_mask=tgt_mask, key_padding_mask=tgt_key_padding_mask
525
+ )[0]
526
+ tgt = tgt + self.dropout1(tgt2)
527
+
528
+ # ffn
529
+ tgt2 = self.norm2(tgt)
530
+ tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt2))))
531
+ tgt = tgt + self.dropout2(tgt2)
532
+ return tgt
533
+
534
+
535
+ def normalize(in_channels):
536
+ return torch.nn.GroupNorm(
537
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
538
+ )
539
+
540
+
541
+ @torch.jit.script # type: ignore
542
+ def swish(x):
543
+ return x * torch.sigmoid(x)
544
+
545
+
546
+ class ResBlock(nn.Module):
547
+ def __init__(self, in_channels, out_channels=None):
548
+ super(ResBlock, self).__init__()
549
+ self.in_channels = in_channels
550
+ self.out_channels = in_channels if out_channels is None else out_channels
551
+ self.norm1 = normalize(in_channels)
552
+ self.conv1 = nn.Conv2d(
553
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
554
+ )
555
+ self.norm2 = normalize(out_channels)
556
+ self.conv2 = nn.Conv2d(
557
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1 # type: ignore
558
+ )
559
+ if self.in_channels != self.out_channels:
560
+ self.conv_out = nn.Conv2d(
561
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0 # type: ignore
562
+ )
563
+
564
+ def forward(self, x_in):
565
+ x = x_in
566
+ x = self.norm1(x)
567
+ x = swish(x)
568
+ x = self.conv1(x)
569
+ x = self.norm2(x)
570
+ x = swish(x)
571
+ x = self.conv2(x)
572
+ if self.in_channels != self.out_channels:
573
+ x_in = self.conv_out(x_in)
574
+
575
+ return x + x_in
576
+
577
+
578
+ class Fuse_sft_block(nn.Module):
579
+ def __init__(self, in_ch, out_ch):
580
+ super().__init__()
581
+ self.encode_enc = ResBlock(2 * in_ch, out_ch)
582
+
583
+ self.scale = nn.Sequential(
584
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
585
+ nn.LeakyReLU(0.2, True),
586
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
587
+ )
588
+
589
+ self.shift = nn.Sequential(
590
+ nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1),
591
+ nn.LeakyReLU(0.2, True),
592
+ nn.Conv2d(out_ch, out_ch, kernel_size=3, padding=1),
593
+ )
594
+
595
+ def forward(self, enc_feat, dec_feat, w=1):
596
+ enc_feat = self.encode_enc(torch.cat([enc_feat, dec_feat], dim=1))
597
+ scale = self.scale(enc_feat)
598
+ shift = self.shift(enc_feat)
599
+ residual = w * (dec_feat * scale + shift)
600
+ out = dec_feat + residual
601
+ return out
602
+
603
+
604
+ class CodeFormer(VQAutoEncoder):
605
+ def __init__(self, state_dict):
606
+ dim_embd = 512
607
+ n_head = 8
608
+ n_layers = 9
609
+ codebook_size = 1024
610
+ latent_size = 256
611
+ connect_list = ["32", "64", "128", "256"]
612
+ fix_modules = ["quantize", "generator"]
613
+
614
+ # This is just a guess as I only have one model to look at
615
+ position_emb = state_dict["position_emb"]
616
+ dim_embd = position_emb.shape[1]
617
+ latent_size = position_emb.shape[0]
618
+
619
+ try:
620
+ n_layers = len(
621
+ set([x.split(".")[1] for x in state_dict.keys() if "ft_layers" in x])
622
+ )
623
+ except:
624
+ pass
625
+
626
+ codebook_size = state_dict["quantize.embedding.weight"].shape[0]
627
+
628
+ # This is also just another guess
629
+ n_head_exp = (
630
+ state_dict["ft_layers.0.self_attn.in_proj_weight"].shape[0] // dim_embd
631
+ )
632
+ n_head = 2**n_head_exp
633
+
634
+ in_nc = state_dict["encoder.blocks.0.weight"].shape[1]
635
+
636
+ self.model_arch = "CodeFormer"
637
+ self.sub_type = "Face SR"
638
+ self.scale = 8
639
+ self.in_nc = in_nc
640
+ self.out_nc = in_nc
641
+
642
+ self.state = state_dict
643
+
644
+ self.supports_fp16 = False
645
+ self.supports_bf16 = True
646
+ self.min_size_restriction = 16
647
+
648
+ super(CodeFormer, self).__init__(
649
+ 512, 64, [1, 2, 2, 4, 4, 8], "nearest", 2, [16], codebook_size
650
+ )
651
+
652
+ if fix_modules is not None:
653
+ for module in fix_modules:
654
+ for param in getattr(self, module).parameters():
655
+ param.requires_grad = False
656
+
657
+ self.connect_list = connect_list
658
+ self.n_layers = n_layers
659
+ self.dim_embd = dim_embd
660
+ self.dim_mlp = dim_embd * 2
661
+
662
+ self.position_emb = nn.Parameter(torch.zeros(latent_size, self.dim_embd)) # type: ignore
663
+ self.feat_emb = nn.Linear(256, self.dim_embd)
664
+
665
+ # transformer
666
+ self.ft_layers = nn.Sequential(
667
+ *[
668
+ TransformerSALayer(
669
+ embed_dim=dim_embd, nhead=n_head, dim_mlp=self.dim_mlp, dropout=0.0
670
+ )
671
+ for _ in range(self.n_layers)
672
+ ]
673
+ )
674
+
675
+ # logits_predict head
676
+ self.idx_pred_layer = nn.Sequential(
677
+ nn.LayerNorm(dim_embd), nn.Linear(dim_embd, codebook_size, bias=False)
678
+ )
679
+
680
+ self.channels = {
681
+ "16": 512,
682
+ "32": 256,
683
+ "64": 256,
684
+ "128": 128,
685
+ "256": 128,
686
+ "512": 64,
687
+ }
688
+
689
+ # after second residual block for > 16, before attn layer for ==16
690
+ self.fuse_encoder_block = {
691
+ "512": 2,
692
+ "256": 5,
693
+ "128": 8,
694
+ "64": 11,
695
+ "32": 14,
696
+ "16": 18,
697
+ }
698
+ # after first residual block for > 16, before attn layer for ==16
699
+ self.fuse_generator_block = {
700
+ "16": 6,
701
+ "32": 9,
702
+ "64": 12,
703
+ "128": 15,
704
+ "256": 18,
705
+ "512": 21,
706
+ }
707
+
708
+ # fuse_convs_dict
709
+ self.fuse_convs_dict = nn.ModuleDict()
710
+ for f_size in self.connect_list:
711
+ in_ch = self.channels[f_size]
712
+ self.fuse_convs_dict[f_size] = Fuse_sft_block(in_ch, in_ch)
713
+
714
+ self.load_state_dict(state_dict)
715
+
716
+ def _init_weights(self, module):
717
+ if isinstance(module, (nn.Linear, nn.Embedding)):
718
+ module.weight.data.normal_(mean=0.0, std=0.02)
719
+ if isinstance(module, nn.Linear) and module.bias is not None:
720
+ module.bias.data.zero_()
721
+ elif isinstance(module, nn.LayerNorm):
722
+ module.bias.data.zero_()
723
+ module.weight.data.fill_(1.0)
724
+
725
+ def forward(self, x, weight=0.5, **kwargs):
726
+ detach_16 = True
727
+ code_only = False
728
+ adain = True
729
+ # ################### Encoder #####################
730
+ enc_feat_dict = {}
731
+ out_list = [self.fuse_encoder_block[f_size] for f_size in self.connect_list]
732
+ for i, block in enumerate(self.encoder.blocks):
733
+ x = block(x)
734
+ if i in out_list:
735
+ enc_feat_dict[str(x.shape[-1])] = x.clone()
736
+
737
+ lq_feat = x
738
+ # ################# Transformer ###################
739
+ # quant_feat, codebook_loss, quant_stats = self.quantize(lq_feat)
740
+ pos_emb = self.position_emb.unsqueeze(1).repeat(1, x.shape[0], 1)
741
+ # BCHW -> BC(HW) -> (HW)BC
742
+ feat_emb = self.feat_emb(lq_feat.flatten(2).permute(2, 0, 1))
743
+ query_emb = feat_emb
744
+ # Transformer encoder
745
+ for layer in self.ft_layers:
746
+ query_emb = layer(query_emb, query_pos=pos_emb)
747
+
748
+ # output logits
749
+ logits = self.idx_pred_layer(query_emb) # (hw)bn
750
+ logits = logits.permute(1, 0, 2) # (hw)bn -> b(hw)n
751
+
752
+ if code_only: # for training stage II
753
+ # logits doesn't need softmax before cross_entropy loss
754
+ return logits, lq_feat
755
+
756
+ # ################# Quantization ###################
757
+ # if self.training:
758
+ # quant_feat = torch.einsum('btn,nc->btc', [soft_one_hot, self.quantize.embedding.weight])
759
+ # # b(hw)c -> bc(hw) -> bchw
760
+ # quant_feat = quant_feat.permute(0,2,1).view(lq_feat.shape)
761
+ # ------------
762
+ soft_one_hot = F.softmax(logits, dim=2)
763
+ _, top_idx = torch.topk(soft_one_hot, 1, dim=2)
764
+ quant_feat = self.quantize.get_codebook_feat(
765
+ top_idx, shape=[x.shape[0], 16, 16, 256] # type: ignore
766
+ )
767
+ # preserve gradients
768
+ # quant_feat = lq_feat + (quant_feat - lq_feat).detach()
769
+
770
+ if detach_16:
771
+ quant_feat = quant_feat.detach() # for training stage III
772
+ if adain:
773
+ quant_feat = adaptive_instance_normalization(quant_feat, lq_feat)
774
+
775
+ # ################## Generator ####################
776
+ x = quant_feat
777
+ fuse_list = [self.fuse_generator_block[f_size] for f_size in self.connect_list]
778
+
779
+ for i, block in enumerate(self.generator.blocks):
780
+ x = block(x)
781
+ if i in fuse_list: # fuse after i-th block
782
+ f_size = str(x.shape[-1])
783
+ if weight > 0:
784
+ x = self.fuse_convs_dict[f_size](
785
+ enc_feat_dict[f_size].detach(), x, weight
786
+ )
787
+ out = x
788
+ # logits doesn't need softmax before cross_entropy loss
789
+ # return out, logits, lq_feat
790
+ return out, logits
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/fused_act.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.autograd import Function
8
+
9
+ fused_act_ext = None
10
+
11
+
12
+ class FusedLeakyReLUFunctionBackward(Function):
13
+ @staticmethod
14
+ def forward(ctx, grad_output, out, negative_slope, scale):
15
+ ctx.save_for_backward(out)
16
+ ctx.negative_slope = negative_slope
17
+ ctx.scale = scale
18
+
19
+ empty = grad_output.new_empty(0)
20
+
21
+ grad_input = fused_act_ext.fused_bias_act(
22
+ grad_output, empty, out, 3, 1, negative_slope, scale
23
+ )
24
+
25
+ dim = [0]
26
+
27
+ if grad_input.ndim > 2:
28
+ dim += list(range(2, grad_input.ndim))
29
+
30
+ grad_bias = grad_input.sum(dim).detach()
31
+
32
+ return grad_input, grad_bias
33
+
34
+ @staticmethod
35
+ def backward(ctx, gradgrad_input, gradgrad_bias):
36
+ (out,) = ctx.saved_tensors
37
+ gradgrad_out = fused_act_ext.fused_bias_act(
38
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
39
+ )
40
+
41
+ return gradgrad_out, None, None, None
42
+
43
+
44
+ class FusedLeakyReLUFunction(Function):
45
+ @staticmethod
46
+ def forward(ctx, input, bias, negative_slope, scale):
47
+ empty = input.new_empty(0)
48
+ out = fused_act_ext.fused_bias_act(
49
+ input, bias, empty, 3, 0, negative_slope, scale
50
+ )
51
+ ctx.save_for_backward(out)
52
+ ctx.negative_slope = negative_slope
53
+ ctx.scale = scale
54
+
55
+ return out
56
+
57
+ @staticmethod
58
+ def backward(ctx, grad_output):
59
+ (out,) = ctx.saved_tensors
60
+
61
+ grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(
62
+ grad_output, out, ctx.negative_slope, ctx.scale
63
+ )
64
+
65
+ return grad_input, grad_bias, None, None
66
+
67
+
68
+ class FusedLeakyReLU(nn.Module):
69
+ def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
70
+ super().__init__()
71
+
72
+ self.bias = nn.Parameter(torch.zeros(channel))
73
+ self.negative_slope = negative_slope
74
+ self.scale = scale
75
+
76
+ def forward(self, input):
77
+ return fused_leaky_relu(input, self.bias, self.negative_slope, self.scale)
78
+
79
+
80
+ def fused_leaky_relu(input, bias, negative_slope=0.2, scale=2**0.5):
81
+ return FusedLeakyReLUFunction.apply(input, bias, negative_slope, scale)
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/gfpgan_bilinear_arch.py ADDED
@@ -0,0 +1,389 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+
9
+ from .gfpganv1_arch import ResUpBlock
10
+ from .stylegan2_bilinear_arch import (
11
+ ConvLayer,
12
+ EqualConv2d,
13
+ EqualLinear,
14
+ ResBlock,
15
+ ScaledLeakyReLU,
16
+ StyleGAN2GeneratorBilinear,
17
+ )
18
+
19
+
20
+ class StyleGAN2GeneratorBilinearSFT(StyleGAN2GeneratorBilinear):
21
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
22
+ It is the bilinear version. It does not use the complicated UpFirDnSmooth function that is not friendly for
23
+ deployment. It can be easily converted to the clean version: StyleGAN2GeneratorCSFT.
24
+ Args:
25
+ out_size (int): The spatial size of outputs.
26
+ num_style_feat (int): Channel number of style features. Default: 512.
27
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
28
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
29
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
30
+ narrow (float): The narrow ratio for channels. Default: 1.
31
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ out_size,
37
+ num_style_feat=512,
38
+ num_mlp=8,
39
+ channel_multiplier=2,
40
+ lr_mlp=0.01,
41
+ narrow=1,
42
+ sft_half=False,
43
+ ):
44
+ super(StyleGAN2GeneratorBilinearSFT, self).__init__(
45
+ out_size,
46
+ num_style_feat=num_style_feat,
47
+ num_mlp=num_mlp,
48
+ channel_multiplier=channel_multiplier,
49
+ lr_mlp=lr_mlp,
50
+ narrow=narrow,
51
+ )
52
+ self.sft_half = sft_half
53
+
54
+ def forward(
55
+ self,
56
+ styles,
57
+ conditions,
58
+ input_is_latent=False,
59
+ noise=None,
60
+ randomize_noise=True,
61
+ truncation=1,
62
+ truncation_latent=None,
63
+ inject_index=None,
64
+ return_latents=False,
65
+ ):
66
+ """Forward function for StyleGAN2GeneratorBilinearSFT.
67
+ Args:
68
+ styles (list[Tensor]): Sample codes of styles.
69
+ conditions (list[Tensor]): SFT conditions to generators.
70
+ input_is_latent (bool): Whether input is latent style. Default: False.
71
+ noise (Tensor | None): Input noise or None. Default: None.
72
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
73
+ truncation (float): The truncation ratio. Default: 1.
74
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
75
+ inject_index (int | None): The injection index for mixing noise. Default: None.
76
+ return_latents (bool): Whether to return style latents. Default: False.
77
+ """
78
+ # style codes -> latents with Style MLP layer
79
+ if not input_is_latent:
80
+ styles = [self.style_mlp(s) for s in styles]
81
+ # noises
82
+ if noise is None:
83
+ if randomize_noise:
84
+ noise = [None] * self.num_layers # for each style conv layer
85
+ else: # use the stored noise
86
+ noise = [
87
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
88
+ ]
89
+ # style truncation
90
+ if truncation < 1:
91
+ style_truncation = []
92
+ for style in styles:
93
+ style_truncation.append(
94
+ truncation_latent + truncation * (style - truncation_latent)
95
+ )
96
+ styles = style_truncation
97
+ # get style latents with injection
98
+ if len(styles) == 1:
99
+ inject_index = self.num_latent
100
+
101
+ if styles[0].ndim < 3:
102
+ # repeat latent code for all the layers
103
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
104
+ else: # used for encoder with different latent code for each layer
105
+ latent = styles[0]
106
+ elif len(styles) == 2: # mixing noises
107
+ if inject_index is None:
108
+ inject_index = random.randint(1, self.num_latent - 1)
109
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
110
+ latent2 = (
111
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
112
+ )
113
+ latent = torch.cat([latent1, latent2], 1)
114
+
115
+ # main generation
116
+ out = self.constant_input(latent.shape[0])
117
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
118
+ skip = self.to_rgb1(out, latent[:, 1])
119
+
120
+ i = 1
121
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
122
+ self.style_convs[::2],
123
+ self.style_convs[1::2],
124
+ noise[1::2],
125
+ noise[2::2],
126
+ self.to_rgbs,
127
+ ):
128
+ out = conv1(out, latent[:, i], noise=noise1)
129
+
130
+ # the conditions may have fewer levels
131
+ if i < len(conditions):
132
+ # SFT part to combine the conditions
133
+ if self.sft_half: # only apply SFT to half of the channels
134
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
135
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
136
+ out = torch.cat([out_same, out_sft], dim=1)
137
+ else: # apply SFT to all the channels
138
+ out = out * conditions[i - 1] + conditions[i]
139
+
140
+ out = conv2(out, latent[:, i + 1], noise=noise2)
141
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
142
+ i += 2
143
+
144
+ image = skip
145
+
146
+ if return_latents:
147
+ return image, latent
148
+ else:
149
+ return image, None
150
+
151
+
152
+ class GFPGANBilinear(nn.Module):
153
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
154
+ It is the bilinear version and it does not use the complicated UpFirDnSmooth function that is not friendly for
155
+ deployment. It can be easily converted to the clean version: GFPGANv1Clean.
156
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
157
+ Args:
158
+ out_size (int): The spatial size of outputs.
159
+ num_style_feat (int): Channel number of style features. Default: 512.
160
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
161
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
162
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
163
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
164
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
165
+ input_is_latent (bool): Whether input is latent style. Default: False.
166
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
167
+ narrow (float): The narrow ratio for channels. Default: 1.
168
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
169
+ """
170
+
171
+ def __init__(
172
+ self,
173
+ out_size,
174
+ num_style_feat=512,
175
+ channel_multiplier=1,
176
+ decoder_load_path=None,
177
+ fix_decoder=True,
178
+ # for stylegan decoder
179
+ num_mlp=8,
180
+ lr_mlp=0.01,
181
+ input_is_latent=False,
182
+ different_w=False,
183
+ narrow=1,
184
+ sft_half=False,
185
+ ):
186
+ super(GFPGANBilinear, self).__init__()
187
+ self.input_is_latent = input_is_latent
188
+ self.different_w = different_w
189
+ self.num_style_feat = num_style_feat
190
+ self.min_size_restriction = 512
191
+
192
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
193
+ channels = {
194
+ "4": int(512 * unet_narrow),
195
+ "8": int(512 * unet_narrow),
196
+ "16": int(512 * unet_narrow),
197
+ "32": int(512 * unet_narrow),
198
+ "64": int(256 * channel_multiplier * unet_narrow),
199
+ "128": int(128 * channel_multiplier * unet_narrow),
200
+ "256": int(64 * channel_multiplier * unet_narrow),
201
+ "512": int(32 * channel_multiplier * unet_narrow),
202
+ "1024": int(16 * channel_multiplier * unet_narrow),
203
+ }
204
+
205
+ self.log_size = int(math.log(out_size, 2))
206
+ first_out_size = 2 ** (int(math.log(out_size, 2)))
207
+
208
+ self.conv_body_first = ConvLayer(
209
+ 3, channels[f"{first_out_size}"], 1, bias=True, activate=True
210
+ )
211
+
212
+ # downsample
213
+ in_channels = channels[f"{first_out_size}"]
214
+ self.conv_body_down = nn.ModuleList()
215
+ for i in range(self.log_size, 2, -1):
216
+ out_channels = channels[f"{2**(i - 1)}"]
217
+ self.conv_body_down.append(ResBlock(in_channels, out_channels))
218
+ in_channels = out_channels
219
+
220
+ self.final_conv = ConvLayer(
221
+ in_channels, channels["4"], 3, bias=True, activate=True
222
+ )
223
+
224
+ # upsample
225
+ in_channels = channels["4"]
226
+ self.conv_body_up = nn.ModuleList()
227
+ for i in range(3, self.log_size + 1):
228
+ out_channels = channels[f"{2**i}"]
229
+ self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
230
+ in_channels = out_channels
231
+
232
+ # to RGB
233
+ self.toRGB = nn.ModuleList()
234
+ for i in range(3, self.log_size + 1):
235
+ self.toRGB.append(
236
+ EqualConv2d(
237
+ channels[f"{2**i}"],
238
+ 3,
239
+ 1,
240
+ stride=1,
241
+ padding=0,
242
+ bias=True,
243
+ bias_init_val=0,
244
+ )
245
+ )
246
+
247
+ if different_w:
248
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
249
+ else:
250
+ linear_out_channel = num_style_feat
251
+
252
+ self.final_linear = EqualLinear(
253
+ channels["4"] * 4 * 4,
254
+ linear_out_channel,
255
+ bias=True,
256
+ bias_init_val=0,
257
+ lr_mul=1,
258
+ activation=None,
259
+ )
260
+
261
+ # the decoder: stylegan2 generator with SFT modulations
262
+ self.stylegan_decoder = StyleGAN2GeneratorBilinearSFT(
263
+ out_size=out_size,
264
+ num_style_feat=num_style_feat,
265
+ num_mlp=num_mlp,
266
+ channel_multiplier=channel_multiplier,
267
+ lr_mlp=lr_mlp,
268
+ narrow=narrow,
269
+ sft_half=sft_half,
270
+ )
271
+
272
+ # load pre-trained stylegan2 model if necessary
273
+ if decoder_load_path:
274
+ self.stylegan_decoder.load_state_dict(
275
+ torch.load(
276
+ decoder_load_path, map_location=lambda storage, loc: storage
277
+ )["params_ema"]
278
+ )
279
+ # fix decoder without updating params
280
+ if fix_decoder:
281
+ for _, param in self.stylegan_decoder.named_parameters():
282
+ param.requires_grad = False
283
+
284
+ # for SFT modulations (scale and shift)
285
+ self.condition_scale = nn.ModuleList()
286
+ self.condition_shift = nn.ModuleList()
287
+ for i in range(3, self.log_size + 1):
288
+ out_channels = channels[f"{2**i}"]
289
+ if sft_half:
290
+ sft_out_channels = out_channels
291
+ else:
292
+ sft_out_channels = out_channels * 2
293
+ self.condition_scale.append(
294
+ nn.Sequential(
295
+ EqualConv2d(
296
+ out_channels,
297
+ out_channels,
298
+ 3,
299
+ stride=1,
300
+ padding=1,
301
+ bias=True,
302
+ bias_init_val=0,
303
+ ),
304
+ ScaledLeakyReLU(0.2),
305
+ EqualConv2d(
306
+ out_channels,
307
+ sft_out_channels,
308
+ 3,
309
+ stride=1,
310
+ padding=1,
311
+ bias=True,
312
+ bias_init_val=1,
313
+ ),
314
+ )
315
+ )
316
+ self.condition_shift.append(
317
+ nn.Sequential(
318
+ EqualConv2d(
319
+ out_channels,
320
+ out_channels,
321
+ 3,
322
+ stride=1,
323
+ padding=1,
324
+ bias=True,
325
+ bias_init_val=0,
326
+ ),
327
+ ScaledLeakyReLU(0.2),
328
+ EqualConv2d(
329
+ out_channels,
330
+ sft_out_channels,
331
+ 3,
332
+ stride=1,
333
+ padding=1,
334
+ bias=True,
335
+ bias_init_val=0,
336
+ ),
337
+ )
338
+ )
339
+
340
+ def forward(self, x, return_latents=False, return_rgb=True, randomize_noise=True):
341
+ """Forward function for GFPGANBilinear.
342
+ Args:
343
+ x (Tensor): Input images.
344
+ return_latents (bool): Whether to return style latents. Default: False.
345
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
346
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
347
+ """
348
+ conditions = []
349
+ unet_skips = []
350
+ out_rgbs = []
351
+
352
+ # encoder
353
+ feat = self.conv_body_first(x)
354
+ for i in range(self.log_size - 2):
355
+ feat = self.conv_body_down[i](feat)
356
+ unet_skips.insert(0, feat)
357
+
358
+ feat = self.final_conv(feat)
359
+
360
+ # style code
361
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
362
+ if self.different_w:
363
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
364
+
365
+ # decode
366
+ for i in range(self.log_size - 2):
367
+ # add unet skip
368
+ feat = feat + unet_skips[i]
369
+ # ResUpLayer
370
+ feat = self.conv_body_up[i](feat)
371
+ # generate scale and shift for SFT layers
372
+ scale = self.condition_scale[i](feat)
373
+ conditions.append(scale.clone())
374
+ shift = self.condition_shift[i](feat)
375
+ conditions.append(shift.clone())
376
+ # generate rgb images
377
+ if return_rgb:
378
+ out_rgbs.append(self.toRGB[i](feat))
379
+
380
+ # decoder
381
+ image, _ = self.stylegan_decoder(
382
+ [style_code],
383
+ conditions,
384
+ return_latents=return_latents,
385
+ input_is_latent=self.input_is_latent,
386
+ randomize_noise=randomize_noise,
387
+ )
388
+
389
+ return image, out_rgbs
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/gfpganv1_arch.py ADDED
@@ -0,0 +1,566 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .fused_act import FusedLeakyReLU
11
+ from .stylegan2_arch import (
12
+ ConvLayer,
13
+ EqualConv2d,
14
+ EqualLinear,
15
+ ResBlock,
16
+ ScaledLeakyReLU,
17
+ StyleGAN2Generator,
18
+ )
19
+
20
+
21
+ class StyleGAN2GeneratorSFT(StyleGAN2Generator):
22
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
23
+ Args:
24
+ out_size (int): The spatial size of outputs.
25
+ num_style_feat (int): Channel number of style features. Default: 512.
26
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
27
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
28
+ resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
29
+ applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
30
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
31
+ narrow (float): The narrow ratio for channels. Default: 1.
32
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
33
+ """
34
+
35
+ def __init__(
36
+ self,
37
+ out_size,
38
+ num_style_feat=512,
39
+ num_mlp=8,
40
+ channel_multiplier=2,
41
+ resample_kernel=(1, 3, 3, 1),
42
+ lr_mlp=0.01,
43
+ narrow=1,
44
+ sft_half=False,
45
+ ):
46
+ super(StyleGAN2GeneratorSFT, self).__init__(
47
+ out_size,
48
+ num_style_feat=num_style_feat,
49
+ num_mlp=num_mlp,
50
+ channel_multiplier=channel_multiplier,
51
+ resample_kernel=resample_kernel,
52
+ lr_mlp=lr_mlp,
53
+ narrow=narrow,
54
+ )
55
+ self.sft_half = sft_half
56
+
57
+ def forward(
58
+ self,
59
+ styles,
60
+ conditions,
61
+ input_is_latent=False,
62
+ noise=None,
63
+ randomize_noise=True,
64
+ truncation=1,
65
+ truncation_latent=None,
66
+ inject_index=None,
67
+ return_latents=False,
68
+ ):
69
+ """Forward function for StyleGAN2GeneratorSFT.
70
+ Args:
71
+ styles (list[Tensor]): Sample codes of styles.
72
+ conditions (list[Tensor]): SFT conditions to generators.
73
+ input_is_latent (bool): Whether input is latent style. Default: False.
74
+ noise (Tensor | None): Input noise or None. Default: None.
75
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
76
+ truncation (float): The truncation ratio. Default: 1.
77
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
78
+ inject_index (int | None): The injection index for mixing noise. Default: None.
79
+ return_latents (bool): Whether to return style latents. Default: False.
80
+ """
81
+ # style codes -> latents with Style MLP layer
82
+ if not input_is_latent:
83
+ styles = [self.style_mlp(s) for s in styles]
84
+ # noises
85
+ if noise is None:
86
+ if randomize_noise:
87
+ noise = [None] * self.num_layers # for each style conv layer
88
+ else: # use the stored noise
89
+ noise = [
90
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
91
+ ]
92
+ # style truncation
93
+ if truncation < 1:
94
+ style_truncation = []
95
+ for style in styles:
96
+ style_truncation.append(
97
+ truncation_latent + truncation * (style - truncation_latent)
98
+ )
99
+ styles = style_truncation
100
+ # get style latents with injection
101
+ if len(styles) == 1:
102
+ inject_index = self.num_latent
103
+
104
+ if styles[0].ndim < 3:
105
+ # repeat latent code for all the layers
106
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
107
+ else: # used for encoder with different latent code for each layer
108
+ latent = styles[0]
109
+ elif len(styles) == 2: # mixing noises
110
+ if inject_index is None:
111
+ inject_index = random.randint(1, self.num_latent - 1)
112
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
113
+ latent2 = (
114
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
115
+ )
116
+ latent = torch.cat([latent1, latent2], 1)
117
+
118
+ # main generation
119
+ out = self.constant_input(latent.shape[0])
120
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
121
+ skip = self.to_rgb1(out, latent[:, 1])
122
+
123
+ i = 1
124
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
125
+ self.style_convs[::2],
126
+ self.style_convs[1::2],
127
+ noise[1::2],
128
+ noise[2::2],
129
+ self.to_rgbs,
130
+ ):
131
+ out = conv1(out, latent[:, i], noise=noise1)
132
+
133
+ # the conditions may have fewer levels
134
+ if i < len(conditions):
135
+ # SFT part to combine the conditions
136
+ if self.sft_half: # only apply SFT to half of the channels
137
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
138
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
139
+ out = torch.cat([out_same, out_sft], dim=1)
140
+ else: # apply SFT to all the channels
141
+ out = out * conditions[i - 1] + conditions[i]
142
+
143
+ out = conv2(out, latent[:, i + 1], noise=noise2)
144
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
145
+ i += 2
146
+
147
+ image = skip
148
+
149
+ if return_latents:
150
+ return image, latent
151
+ else:
152
+ return image, None
153
+
154
+
155
+ class ConvUpLayer(nn.Module):
156
+ """Convolutional upsampling layer. It uses bilinear upsampler + Conv.
157
+ Args:
158
+ in_channels (int): Channel number of the input.
159
+ out_channels (int): Channel number of the output.
160
+ kernel_size (int): Size of the convolving kernel.
161
+ stride (int): Stride of the convolution. Default: 1
162
+ padding (int): Zero-padding added to both sides of the input. Default: 0.
163
+ bias (bool): If ``True``, adds a learnable bias to the output. Default: ``True``.
164
+ bias_init_val (float): Bias initialized value. Default: 0.
165
+ activate (bool): Whether use activateion. Default: True.
166
+ """
167
+
168
+ def __init__(
169
+ self,
170
+ in_channels,
171
+ out_channels,
172
+ kernel_size,
173
+ stride=1,
174
+ padding=0,
175
+ bias=True,
176
+ bias_init_val=0,
177
+ activate=True,
178
+ ):
179
+ super(ConvUpLayer, self).__init__()
180
+ self.in_channels = in_channels
181
+ self.out_channels = out_channels
182
+ self.kernel_size = kernel_size
183
+ self.stride = stride
184
+ self.padding = padding
185
+ # self.scale is used to scale the convolution weights, which is related to the common initializations.
186
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
187
+
188
+ self.weight = nn.Parameter(
189
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
190
+ )
191
+
192
+ if bias and not activate:
193
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
194
+ else:
195
+ self.register_parameter("bias", None)
196
+
197
+ # activation
198
+ if activate:
199
+ if bias:
200
+ self.activation = FusedLeakyReLU(out_channels)
201
+ else:
202
+ self.activation = ScaledLeakyReLU(0.2)
203
+ else:
204
+ self.activation = None
205
+
206
+ def forward(self, x):
207
+ # bilinear upsample
208
+ out = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
209
+ # conv
210
+ out = F.conv2d(
211
+ out,
212
+ self.weight * self.scale,
213
+ bias=self.bias,
214
+ stride=self.stride,
215
+ padding=self.padding,
216
+ )
217
+ # activation
218
+ if self.activation is not None:
219
+ out = self.activation(out)
220
+ return out
221
+
222
+
223
+ class ResUpBlock(nn.Module):
224
+ """Residual block with upsampling.
225
+ Args:
226
+ in_channels (int): Channel number of the input.
227
+ out_channels (int): Channel number of the output.
228
+ """
229
+
230
+ def __init__(self, in_channels, out_channels):
231
+ super(ResUpBlock, self).__init__()
232
+
233
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
234
+ self.conv2 = ConvUpLayer(
235
+ in_channels, out_channels, 3, stride=1, padding=1, bias=True, activate=True
236
+ )
237
+ self.skip = ConvUpLayer(
238
+ in_channels, out_channels, 1, bias=False, activate=False
239
+ )
240
+
241
+ def forward(self, x):
242
+ out = self.conv1(x)
243
+ out = self.conv2(out)
244
+ skip = self.skip(x)
245
+ out = (out + skip) / math.sqrt(2)
246
+ return out
247
+
248
+
249
+ class GFPGANv1(nn.Module):
250
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
251
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
252
+ Args:
253
+ out_size (int): The spatial size of outputs.
254
+ num_style_feat (int): Channel number of style features. Default: 512.
255
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
256
+ resample_kernel (list[int]): A list indicating the 1D resample kernel magnitude. A cross production will be
257
+ applied to extent 1D resample kernel to 2D resample kernel. Default: (1, 3, 3, 1).
258
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
259
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
260
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
261
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
262
+ input_is_latent (bool): Whether input is latent style. Default: False.
263
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
264
+ narrow (float): The narrow ratio for channels. Default: 1.
265
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
266
+ """
267
+
268
+ def __init__(
269
+ self,
270
+ out_size,
271
+ num_style_feat=512,
272
+ channel_multiplier=1,
273
+ resample_kernel=(1, 3, 3, 1),
274
+ decoder_load_path=None,
275
+ fix_decoder=True,
276
+ # for stylegan decoder
277
+ num_mlp=8,
278
+ lr_mlp=0.01,
279
+ input_is_latent=False,
280
+ different_w=False,
281
+ narrow=1,
282
+ sft_half=False,
283
+ ):
284
+ super(GFPGANv1, self).__init__()
285
+ self.input_is_latent = input_is_latent
286
+ self.different_w = different_w
287
+ self.num_style_feat = num_style_feat
288
+
289
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
290
+ channels = {
291
+ "4": int(512 * unet_narrow),
292
+ "8": int(512 * unet_narrow),
293
+ "16": int(512 * unet_narrow),
294
+ "32": int(512 * unet_narrow),
295
+ "64": int(256 * channel_multiplier * unet_narrow),
296
+ "128": int(128 * channel_multiplier * unet_narrow),
297
+ "256": int(64 * channel_multiplier * unet_narrow),
298
+ "512": int(32 * channel_multiplier * unet_narrow),
299
+ "1024": int(16 * channel_multiplier * unet_narrow),
300
+ }
301
+
302
+ self.log_size = int(math.log(out_size, 2))
303
+ first_out_size = 2 ** (int(math.log(out_size, 2)))
304
+
305
+ self.conv_body_first = ConvLayer(
306
+ 3, channels[f"{first_out_size}"], 1, bias=True, activate=True
307
+ )
308
+
309
+ # downsample
310
+ in_channels = channels[f"{first_out_size}"]
311
+ self.conv_body_down = nn.ModuleList()
312
+ for i in range(self.log_size, 2, -1):
313
+ out_channels = channels[f"{2**(i - 1)}"]
314
+ self.conv_body_down.append(
315
+ ResBlock(in_channels, out_channels, resample_kernel)
316
+ )
317
+ in_channels = out_channels
318
+
319
+ self.final_conv = ConvLayer(
320
+ in_channels, channels["4"], 3, bias=True, activate=True
321
+ )
322
+
323
+ # upsample
324
+ in_channels = channels["4"]
325
+ self.conv_body_up = nn.ModuleList()
326
+ for i in range(3, self.log_size + 1):
327
+ out_channels = channels[f"{2**i}"]
328
+ self.conv_body_up.append(ResUpBlock(in_channels, out_channels))
329
+ in_channels = out_channels
330
+
331
+ # to RGB
332
+ self.toRGB = nn.ModuleList()
333
+ for i in range(3, self.log_size + 1):
334
+ self.toRGB.append(
335
+ EqualConv2d(
336
+ channels[f"{2**i}"],
337
+ 3,
338
+ 1,
339
+ stride=1,
340
+ padding=0,
341
+ bias=True,
342
+ bias_init_val=0,
343
+ )
344
+ )
345
+
346
+ if different_w:
347
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
348
+ else:
349
+ linear_out_channel = num_style_feat
350
+
351
+ self.final_linear = EqualLinear(
352
+ channels["4"] * 4 * 4,
353
+ linear_out_channel,
354
+ bias=True,
355
+ bias_init_val=0,
356
+ lr_mul=1,
357
+ activation=None,
358
+ )
359
+
360
+ # the decoder: stylegan2 generator with SFT modulations
361
+ self.stylegan_decoder = StyleGAN2GeneratorSFT(
362
+ out_size=out_size,
363
+ num_style_feat=num_style_feat,
364
+ num_mlp=num_mlp,
365
+ channel_multiplier=channel_multiplier,
366
+ resample_kernel=resample_kernel,
367
+ lr_mlp=lr_mlp,
368
+ narrow=narrow,
369
+ sft_half=sft_half,
370
+ )
371
+
372
+ # load pre-trained stylegan2 model if necessary
373
+ if decoder_load_path:
374
+ self.stylegan_decoder.load_state_dict(
375
+ torch.load(
376
+ decoder_load_path, map_location=lambda storage, loc: storage
377
+ )["params_ema"]
378
+ )
379
+ # fix decoder without updating params
380
+ if fix_decoder:
381
+ for _, param in self.stylegan_decoder.named_parameters():
382
+ param.requires_grad = False
383
+
384
+ # for SFT modulations (scale and shift)
385
+ self.condition_scale = nn.ModuleList()
386
+ self.condition_shift = nn.ModuleList()
387
+ for i in range(3, self.log_size + 1):
388
+ out_channels = channels[f"{2**i}"]
389
+ if sft_half:
390
+ sft_out_channels = out_channels
391
+ else:
392
+ sft_out_channels = out_channels * 2
393
+ self.condition_scale.append(
394
+ nn.Sequential(
395
+ EqualConv2d(
396
+ out_channels,
397
+ out_channels,
398
+ 3,
399
+ stride=1,
400
+ padding=1,
401
+ bias=True,
402
+ bias_init_val=0,
403
+ ),
404
+ ScaledLeakyReLU(0.2),
405
+ EqualConv2d(
406
+ out_channels,
407
+ sft_out_channels,
408
+ 3,
409
+ stride=1,
410
+ padding=1,
411
+ bias=True,
412
+ bias_init_val=1,
413
+ ),
414
+ )
415
+ )
416
+ self.condition_shift.append(
417
+ nn.Sequential(
418
+ EqualConv2d(
419
+ out_channels,
420
+ out_channels,
421
+ 3,
422
+ stride=1,
423
+ padding=1,
424
+ bias=True,
425
+ bias_init_val=0,
426
+ ),
427
+ ScaledLeakyReLU(0.2),
428
+ EqualConv2d(
429
+ out_channels,
430
+ sft_out_channels,
431
+ 3,
432
+ stride=1,
433
+ padding=1,
434
+ bias=True,
435
+ bias_init_val=0,
436
+ ),
437
+ )
438
+ )
439
+
440
+ def forward(
441
+ self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
442
+ ):
443
+ """Forward function for GFPGANv1.
444
+ Args:
445
+ x (Tensor): Input images.
446
+ return_latents (bool): Whether to return style latents. Default: False.
447
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
448
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
449
+ """
450
+ conditions = []
451
+ unet_skips = []
452
+ out_rgbs = []
453
+
454
+ # encoder
455
+ feat = self.conv_body_first(x)
456
+ for i in range(self.log_size - 2):
457
+ feat = self.conv_body_down[i](feat)
458
+ unet_skips.insert(0, feat)
459
+
460
+ feat = self.final_conv(feat)
461
+
462
+ # style code
463
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
464
+ if self.different_w:
465
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
466
+
467
+ # decode
468
+ for i in range(self.log_size - 2):
469
+ # add unet skip
470
+ feat = feat + unet_skips[i]
471
+ # ResUpLayer
472
+ feat = self.conv_body_up[i](feat)
473
+ # generate scale and shift for SFT layers
474
+ scale = self.condition_scale[i](feat)
475
+ conditions.append(scale.clone())
476
+ shift = self.condition_shift[i](feat)
477
+ conditions.append(shift.clone())
478
+ # generate rgb images
479
+ if return_rgb:
480
+ out_rgbs.append(self.toRGB[i](feat))
481
+
482
+ # decoder
483
+ image, _ = self.stylegan_decoder(
484
+ [style_code],
485
+ conditions,
486
+ return_latents=return_latents,
487
+ input_is_latent=self.input_is_latent,
488
+ randomize_noise=randomize_noise,
489
+ )
490
+
491
+ return image, out_rgbs
492
+
493
+
494
+ class FacialComponentDiscriminator(nn.Module):
495
+ """Facial component (eyes, mouth, noise) discriminator used in GFPGAN."""
496
+
497
+ def __init__(self):
498
+ super(FacialComponentDiscriminator, self).__init__()
499
+ # It now uses a VGG-style architectrue with fixed model size
500
+ self.conv1 = ConvLayer(
501
+ 3,
502
+ 64,
503
+ 3,
504
+ downsample=False,
505
+ resample_kernel=(1, 3, 3, 1),
506
+ bias=True,
507
+ activate=True,
508
+ )
509
+ self.conv2 = ConvLayer(
510
+ 64,
511
+ 128,
512
+ 3,
513
+ downsample=True,
514
+ resample_kernel=(1, 3, 3, 1),
515
+ bias=True,
516
+ activate=True,
517
+ )
518
+ self.conv3 = ConvLayer(
519
+ 128,
520
+ 128,
521
+ 3,
522
+ downsample=False,
523
+ resample_kernel=(1, 3, 3, 1),
524
+ bias=True,
525
+ activate=True,
526
+ )
527
+ self.conv4 = ConvLayer(
528
+ 128,
529
+ 256,
530
+ 3,
531
+ downsample=True,
532
+ resample_kernel=(1, 3, 3, 1),
533
+ bias=True,
534
+ activate=True,
535
+ )
536
+ self.conv5 = ConvLayer(
537
+ 256,
538
+ 256,
539
+ 3,
540
+ downsample=False,
541
+ resample_kernel=(1, 3, 3, 1),
542
+ bias=True,
543
+ activate=True,
544
+ )
545
+ self.final_conv = ConvLayer(256, 1, 3, bias=True, activate=False)
546
+
547
+ def forward(self, x, return_feats=False, **kwargs):
548
+ """Forward function for FacialComponentDiscriminator.
549
+ Args:
550
+ x (Tensor): Input images.
551
+ return_feats (bool): Whether to return intermediate features. Default: False.
552
+ """
553
+ feat = self.conv1(x)
554
+ feat = self.conv3(self.conv2(feat))
555
+ rlt_feats = []
556
+ if return_feats:
557
+ rlt_feats.append(feat.clone())
558
+ feat = self.conv5(self.conv4(feat))
559
+ if return_feats:
560
+ rlt_feats.append(feat.clone())
561
+ out = self.final_conv(feat)
562
+
563
+ if return_feats:
564
+ return out, rlt_feats
565
+ else:
566
+ return out, None
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/gfpganv1_clean_arch.py ADDED
@@ -0,0 +1,370 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .stylegan2_clean_arch import StyleGAN2GeneratorClean
11
+
12
+
13
+ class StyleGAN2GeneratorCSFT(StyleGAN2GeneratorClean):
14
+ """StyleGAN2 Generator with SFT modulation (Spatial Feature Transform).
15
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
16
+ Args:
17
+ out_size (int): The spatial size of outputs.
18
+ num_style_feat (int): Channel number of style features. Default: 512.
19
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
20
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
21
+ narrow (float): The narrow ratio for channels. Default: 1.
22
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ out_size,
28
+ num_style_feat=512,
29
+ num_mlp=8,
30
+ channel_multiplier=2,
31
+ narrow=1,
32
+ sft_half=False,
33
+ ):
34
+ super(StyleGAN2GeneratorCSFT, self).__init__(
35
+ out_size,
36
+ num_style_feat=num_style_feat,
37
+ num_mlp=num_mlp,
38
+ channel_multiplier=channel_multiplier,
39
+ narrow=narrow,
40
+ )
41
+ self.sft_half = sft_half
42
+
43
+ def forward(
44
+ self,
45
+ styles,
46
+ conditions,
47
+ input_is_latent=False,
48
+ noise=None,
49
+ randomize_noise=True,
50
+ truncation=1,
51
+ truncation_latent=None,
52
+ inject_index=None,
53
+ return_latents=False,
54
+ ):
55
+ """Forward function for StyleGAN2GeneratorCSFT.
56
+ Args:
57
+ styles (list[Tensor]): Sample codes of styles.
58
+ conditions (list[Tensor]): SFT conditions to generators.
59
+ input_is_latent (bool): Whether input is latent style. Default: False.
60
+ noise (Tensor | None): Input noise or None. Default: None.
61
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
62
+ truncation (float): The truncation ratio. Default: 1.
63
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
64
+ inject_index (int | None): The injection index for mixing noise. Default: None.
65
+ return_latents (bool): Whether to return style latents. Default: False.
66
+ """
67
+ # style codes -> latents with Style MLP layer
68
+ if not input_is_latent:
69
+ styles = [self.style_mlp(s) for s in styles]
70
+ # noises
71
+ if noise is None:
72
+ if randomize_noise:
73
+ noise = [None] * self.num_layers # for each style conv layer
74
+ else: # use the stored noise
75
+ noise = [
76
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
77
+ ]
78
+ # style truncation
79
+ if truncation < 1:
80
+ style_truncation = []
81
+ for style in styles:
82
+ style_truncation.append(
83
+ truncation_latent + truncation * (style - truncation_latent)
84
+ )
85
+ styles = style_truncation
86
+ # get style latents with injection
87
+ if len(styles) == 1:
88
+ inject_index = self.num_latent
89
+
90
+ if styles[0].ndim < 3:
91
+ # repeat latent code for all the layers
92
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
93
+ else: # used for encoder with different latent code for each layer
94
+ latent = styles[0]
95
+ elif len(styles) == 2: # mixing noises
96
+ if inject_index is None:
97
+ inject_index = random.randint(1, self.num_latent - 1)
98
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
99
+ latent2 = (
100
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
101
+ )
102
+ latent = torch.cat([latent1, latent2], 1)
103
+
104
+ # main generation
105
+ out = self.constant_input(latent.shape[0])
106
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
107
+ skip = self.to_rgb1(out, latent[:, 1])
108
+
109
+ i = 1
110
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
111
+ self.style_convs[::2],
112
+ self.style_convs[1::2],
113
+ noise[1::2],
114
+ noise[2::2],
115
+ self.to_rgbs,
116
+ ):
117
+ out = conv1(out, latent[:, i], noise=noise1)
118
+
119
+ # the conditions may have fewer levels
120
+ if i < len(conditions):
121
+ # SFT part to combine the conditions
122
+ if self.sft_half: # only apply SFT to half of the channels
123
+ out_same, out_sft = torch.split(out, int(out.size(1) // 2), dim=1)
124
+ out_sft = out_sft * conditions[i - 1] + conditions[i]
125
+ out = torch.cat([out_same, out_sft], dim=1)
126
+ else: # apply SFT to all the channels
127
+ out = out * conditions[i - 1] + conditions[i]
128
+
129
+ out = conv2(out, latent[:, i + 1], noise=noise2)
130
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
131
+ i += 2
132
+
133
+ image = skip
134
+
135
+ if return_latents:
136
+ return image, latent
137
+ else:
138
+ return image, None
139
+
140
+
141
+ class ResBlock(nn.Module):
142
+ """Residual block with bilinear upsampling/downsampling.
143
+ Args:
144
+ in_channels (int): Channel number of the input.
145
+ out_channels (int): Channel number of the output.
146
+ mode (str): Upsampling/downsampling mode. Options: down | up. Default: down.
147
+ """
148
+
149
+ def __init__(self, in_channels, out_channels, mode="down"):
150
+ super(ResBlock, self).__init__()
151
+
152
+ self.conv1 = nn.Conv2d(in_channels, in_channels, 3, 1, 1)
153
+ self.conv2 = nn.Conv2d(in_channels, out_channels, 3, 1, 1)
154
+ self.skip = nn.Conv2d(in_channels, out_channels, 1, bias=False)
155
+ if mode == "down":
156
+ self.scale_factor = 0.5
157
+ elif mode == "up":
158
+ self.scale_factor = 2
159
+
160
+ def forward(self, x):
161
+ out = F.leaky_relu_(self.conv1(x), negative_slope=0.2)
162
+ # upsample/downsample
163
+ out = F.interpolate(
164
+ out, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
165
+ )
166
+ out = F.leaky_relu_(self.conv2(out), negative_slope=0.2)
167
+ # skip
168
+ x = F.interpolate(
169
+ x, scale_factor=self.scale_factor, mode="bilinear", align_corners=False
170
+ )
171
+ skip = self.skip(x)
172
+ out = out + skip
173
+ return out
174
+
175
+
176
+ class GFPGANv1Clean(nn.Module):
177
+ """The GFPGAN architecture: Unet + StyleGAN2 decoder with SFT.
178
+ It is the clean version without custom compiled CUDA extensions used in StyleGAN2.
179
+ Ref: GFP-GAN: Towards Real-World Blind Face Restoration with Generative Facial Prior.
180
+ Args:
181
+ out_size (int): The spatial size of outputs.
182
+ num_style_feat (int): Channel number of style features. Default: 512.
183
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
184
+ decoder_load_path (str): The path to the pre-trained decoder model (usually, the StyleGAN2). Default: None.
185
+ fix_decoder (bool): Whether to fix the decoder. Default: True.
186
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
187
+ input_is_latent (bool): Whether input is latent style. Default: False.
188
+ different_w (bool): Whether to use different latent w for different layers. Default: False.
189
+ narrow (float): The narrow ratio for channels. Default: 1.
190
+ sft_half (bool): Whether to apply SFT on half of the input channels. Default: False.
191
+ """
192
+
193
+ def __init__(
194
+ self,
195
+ state_dict,
196
+ ):
197
+ super(GFPGANv1Clean, self).__init__()
198
+
199
+ out_size = 512
200
+ num_style_feat = 512
201
+ channel_multiplier = 2
202
+ decoder_load_path = None
203
+ fix_decoder = False
204
+ num_mlp = 8
205
+ input_is_latent = True
206
+ different_w = True
207
+ narrow = 1
208
+ sft_half = True
209
+
210
+ self.model_arch = "GFPGAN"
211
+ self.sub_type = "Face SR"
212
+ self.scale = 8
213
+ self.in_nc = 3
214
+ self.out_nc = 3
215
+ self.state = state_dict
216
+
217
+ self.supports_fp16 = False
218
+ self.supports_bf16 = True
219
+ self.min_size_restriction = 512
220
+
221
+ self.input_is_latent = input_is_latent
222
+ self.different_w = different_w
223
+ self.num_style_feat = num_style_feat
224
+
225
+ unet_narrow = narrow * 0.5 # by default, use a half of input channels
226
+ channels = {
227
+ "4": int(512 * unet_narrow),
228
+ "8": int(512 * unet_narrow),
229
+ "16": int(512 * unet_narrow),
230
+ "32": int(512 * unet_narrow),
231
+ "64": int(256 * channel_multiplier * unet_narrow),
232
+ "128": int(128 * channel_multiplier * unet_narrow),
233
+ "256": int(64 * channel_multiplier * unet_narrow),
234
+ "512": int(32 * channel_multiplier * unet_narrow),
235
+ "1024": int(16 * channel_multiplier * unet_narrow),
236
+ }
237
+
238
+ self.log_size = int(math.log(out_size, 2))
239
+ first_out_size = 2 ** (int(math.log(out_size, 2)))
240
+
241
+ self.conv_body_first = nn.Conv2d(3, channels[f"{first_out_size}"], 1)
242
+
243
+ # downsample
244
+ in_channels = channels[f"{first_out_size}"]
245
+ self.conv_body_down = nn.ModuleList()
246
+ for i in range(self.log_size, 2, -1):
247
+ out_channels = channels[f"{2**(i - 1)}"]
248
+ self.conv_body_down.append(ResBlock(in_channels, out_channels, mode="down"))
249
+ in_channels = out_channels
250
+
251
+ self.final_conv = nn.Conv2d(in_channels, channels["4"], 3, 1, 1)
252
+
253
+ # upsample
254
+ in_channels = channels["4"]
255
+ self.conv_body_up = nn.ModuleList()
256
+ for i in range(3, self.log_size + 1):
257
+ out_channels = channels[f"{2**i}"]
258
+ self.conv_body_up.append(ResBlock(in_channels, out_channels, mode="up"))
259
+ in_channels = out_channels
260
+
261
+ # to RGB
262
+ self.toRGB = nn.ModuleList()
263
+ for i in range(3, self.log_size + 1):
264
+ self.toRGB.append(nn.Conv2d(channels[f"{2**i}"], 3, 1))
265
+
266
+ if different_w:
267
+ linear_out_channel = (int(math.log(out_size, 2)) * 2 - 2) * num_style_feat
268
+ else:
269
+ linear_out_channel = num_style_feat
270
+
271
+ self.final_linear = nn.Linear(channels["4"] * 4 * 4, linear_out_channel)
272
+
273
+ # the decoder: stylegan2 generator with SFT modulations
274
+ self.stylegan_decoder = StyleGAN2GeneratorCSFT(
275
+ out_size=out_size,
276
+ num_style_feat=num_style_feat,
277
+ num_mlp=num_mlp,
278
+ channel_multiplier=channel_multiplier,
279
+ narrow=narrow,
280
+ sft_half=sft_half,
281
+ )
282
+
283
+ # load pre-trained stylegan2 model if necessary
284
+ if decoder_load_path:
285
+ self.stylegan_decoder.load_state_dict(
286
+ torch.load(
287
+ decoder_load_path, map_location=lambda storage, loc: storage
288
+ )["params_ema"]
289
+ )
290
+ # fix decoder without updating params
291
+ if fix_decoder:
292
+ for _, param in self.stylegan_decoder.named_parameters():
293
+ param.requires_grad = False
294
+
295
+ # for SFT modulations (scale and shift)
296
+ self.condition_scale = nn.ModuleList()
297
+ self.condition_shift = nn.ModuleList()
298
+ for i in range(3, self.log_size + 1):
299
+ out_channels = channels[f"{2**i}"]
300
+ if sft_half:
301
+ sft_out_channels = out_channels
302
+ else:
303
+ sft_out_channels = out_channels * 2
304
+ self.condition_scale.append(
305
+ nn.Sequential(
306
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
307
+ nn.LeakyReLU(0.2, True),
308
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
309
+ )
310
+ )
311
+ self.condition_shift.append(
312
+ nn.Sequential(
313
+ nn.Conv2d(out_channels, out_channels, 3, 1, 1),
314
+ nn.LeakyReLU(0.2, True),
315
+ nn.Conv2d(out_channels, sft_out_channels, 3, 1, 1),
316
+ )
317
+ )
318
+ self.load_state_dict(state_dict)
319
+
320
+ def forward(
321
+ self, x, return_latents=False, return_rgb=True, randomize_noise=True, **kwargs
322
+ ):
323
+ """Forward function for GFPGANv1Clean.
324
+ Args:
325
+ x (Tensor): Input images.
326
+ return_latents (bool): Whether to return style latents. Default: False.
327
+ return_rgb (bool): Whether return intermediate rgb images. Default: True.
328
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
329
+ """
330
+ conditions = []
331
+ unet_skips = []
332
+ out_rgbs = []
333
+
334
+ # encoder
335
+ feat = F.leaky_relu_(self.conv_body_first(x), negative_slope=0.2)
336
+ for i in range(self.log_size - 2):
337
+ feat = self.conv_body_down[i](feat)
338
+ unet_skips.insert(0, feat)
339
+ feat = F.leaky_relu_(self.final_conv(feat), negative_slope=0.2)
340
+
341
+ # style code
342
+ style_code = self.final_linear(feat.view(feat.size(0), -1))
343
+ if self.different_w:
344
+ style_code = style_code.view(style_code.size(0), -1, self.num_style_feat)
345
+
346
+ # decode
347
+ for i in range(self.log_size - 2):
348
+ # add unet skip
349
+ feat = feat + unet_skips[i]
350
+ # ResUpLayer
351
+ feat = self.conv_body_up[i](feat)
352
+ # generate scale and shift for SFT layers
353
+ scale = self.condition_scale[i](feat)
354
+ conditions.append(scale.clone())
355
+ shift = self.condition_shift[i](feat)
356
+ conditions.append(shift.clone())
357
+ # generate rgb images
358
+ if return_rgb:
359
+ out_rgbs.append(self.toRGB[i](feat))
360
+
361
+ # decoder
362
+ image, _ = self.stylegan_decoder(
363
+ [style_code],
364
+ conditions,
365
+ return_latents=return_latents,
366
+ input_is_latent=self.input_is_latent,
367
+ randomize_noise=randomize_noise,
368
+ )
369
+
370
+ return image, out_rgbs
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/restoreformer_arch.py ADDED
@@ -0,0 +1,776 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ """Modified from https://github.com/wzhouxiff/RestoreFormer
4
+ """
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+
10
+
11
+ class VectorQuantizer(nn.Module):
12
+ """
13
+ see https://github.com/MishaLaskin/vqvae/blob/d761a999e2267766400dc646d82d3ac3657771d4/models/quantizer.py
14
+ ____________________________________________
15
+ Discretization bottleneck part of the VQ-VAE.
16
+ Inputs:
17
+ - n_e : number of embeddings
18
+ - e_dim : dimension of embedding
19
+ - beta : commitment cost used in loss term, beta * ||z_e(x)-sg[e]||^2
20
+ _____________________________________________
21
+ """
22
+
23
+ def __init__(self, n_e, e_dim, beta):
24
+ super(VectorQuantizer, self).__init__()
25
+ self.n_e = n_e
26
+ self.e_dim = e_dim
27
+ self.beta = beta
28
+
29
+ self.embedding = nn.Embedding(self.n_e, self.e_dim)
30
+ self.embedding.weight.data.uniform_(-1.0 / self.n_e, 1.0 / self.n_e)
31
+
32
+ def forward(self, z):
33
+ """
34
+ Inputs the output of the encoder network z and maps it to a discrete
35
+ one-hot vector that is the index of the closest embedding vector e_j
36
+ z (continuous) -> z_q (discrete)
37
+ z.shape = (batch, channel, height, width)
38
+ quantization pipeline:
39
+ 1. get encoder input (B,C,H,W)
40
+ 2. flatten input to (B*H*W,C)
41
+ """
42
+ # reshape z -> (batch, height, width, channel) and flatten
43
+ z = z.permute(0, 2, 3, 1).contiguous()
44
+ z_flattened = z.view(-1, self.e_dim)
45
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
46
+
47
+ d = (
48
+ torch.sum(z_flattened**2, dim=1, keepdim=True)
49
+ + torch.sum(self.embedding.weight**2, dim=1)
50
+ - 2 * torch.matmul(z_flattened, self.embedding.weight.t())
51
+ )
52
+
53
+ # could possible replace this here
54
+ # #\start...
55
+ # find closest encodings
56
+
57
+ min_value, min_encoding_indices = torch.min(d, dim=1)
58
+
59
+ min_encoding_indices = min_encoding_indices.unsqueeze(1)
60
+
61
+ min_encodings = torch.zeros(min_encoding_indices.shape[0], self.n_e).to(z)
62
+ min_encodings.scatter_(1, min_encoding_indices, 1)
63
+
64
+ # dtype min encodings: torch.float32
65
+ # min_encodings shape: torch.Size([2048, 512])
66
+ # min_encoding_indices.shape: torch.Size([2048, 1])
67
+
68
+ # get quantized latent vectors
69
+ z_q = torch.matmul(min_encodings, self.embedding.weight).view(z.shape)
70
+ # .........\end
71
+
72
+ # with:
73
+ # .........\start
74
+ # min_encoding_indices = torch.argmin(d, dim=1)
75
+ # z_q = self.embedding(min_encoding_indices)
76
+ # ......\end......... (TODO)
77
+
78
+ # compute loss for embedding
79
+ loss = torch.mean((z_q.detach() - z) ** 2) + self.beta * torch.mean(
80
+ (z_q - z.detach()) ** 2
81
+ )
82
+
83
+ # preserve gradients
84
+ z_q = z + (z_q - z).detach()
85
+
86
+ # perplexity
87
+
88
+ e_mean = torch.mean(min_encodings, dim=0)
89
+ perplexity = torch.exp(-torch.sum(e_mean * torch.log(e_mean + 1e-10)))
90
+
91
+ # reshape back to match original input shape
92
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
93
+
94
+ return z_q, loss, (perplexity, min_encodings, min_encoding_indices, d)
95
+
96
+ def get_codebook_entry(self, indices, shape):
97
+ # shape specifying (batch, height, width, channel)
98
+ # TODO: check for more easy handling with nn.Embedding
99
+ min_encodings = torch.zeros(indices.shape[0], self.n_e).to(indices)
100
+ min_encodings.scatter_(1, indices[:, None], 1)
101
+
102
+ # get quantized latent vectors
103
+ z_q = torch.matmul(min_encodings.float(), self.embedding.weight)
104
+
105
+ if shape is not None:
106
+ z_q = z_q.view(shape)
107
+
108
+ # reshape back to match original input shape
109
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
110
+
111
+ return z_q
112
+
113
+
114
+ # pytorch_diffusion + derived encoder decoder
115
+ def nonlinearity(x):
116
+ # swish
117
+ return x * torch.sigmoid(x)
118
+
119
+
120
+ def Normalize(in_channels):
121
+ return torch.nn.GroupNorm(
122
+ num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
123
+ )
124
+
125
+
126
+ class Upsample(nn.Module):
127
+ def __init__(self, in_channels, with_conv):
128
+ super().__init__()
129
+ self.with_conv = with_conv
130
+ if self.with_conv:
131
+ self.conv = torch.nn.Conv2d(
132
+ in_channels, in_channels, kernel_size=3, stride=1, padding=1
133
+ )
134
+
135
+ def forward(self, x):
136
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
137
+ if self.with_conv:
138
+ x = self.conv(x)
139
+ return x
140
+
141
+
142
+ class Downsample(nn.Module):
143
+ def __init__(self, in_channels, with_conv):
144
+ super().__init__()
145
+ self.with_conv = with_conv
146
+ if self.with_conv:
147
+ # no asymmetric padding in torch conv, must do it ourselves
148
+ self.conv = torch.nn.Conv2d(
149
+ in_channels, in_channels, kernel_size=3, stride=2, padding=0
150
+ )
151
+
152
+ def forward(self, x):
153
+ if self.with_conv:
154
+ pad = (0, 1, 0, 1)
155
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
156
+ x = self.conv(x)
157
+ else:
158
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
159
+ return x
160
+
161
+
162
+ class ResnetBlock(nn.Module):
163
+ def __init__(
164
+ self,
165
+ *,
166
+ in_channels,
167
+ out_channels=None,
168
+ conv_shortcut=False,
169
+ dropout,
170
+ temb_channels=512
171
+ ):
172
+ super().__init__()
173
+ self.in_channels = in_channels
174
+ out_channels = in_channels if out_channels is None else out_channels
175
+ self.out_channels = out_channels
176
+ self.use_conv_shortcut = conv_shortcut
177
+
178
+ self.norm1 = Normalize(in_channels)
179
+ self.conv1 = torch.nn.Conv2d(
180
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
181
+ )
182
+ if temb_channels > 0:
183
+ self.temb_proj = torch.nn.Linear(temb_channels, out_channels)
184
+ self.norm2 = Normalize(out_channels)
185
+ self.dropout = torch.nn.Dropout(dropout)
186
+ self.conv2 = torch.nn.Conv2d(
187
+ out_channels, out_channels, kernel_size=3, stride=1, padding=1
188
+ )
189
+ if self.in_channels != self.out_channels:
190
+ if self.use_conv_shortcut:
191
+ self.conv_shortcut = torch.nn.Conv2d(
192
+ in_channels, out_channels, kernel_size=3, stride=1, padding=1
193
+ )
194
+ else:
195
+ self.nin_shortcut = torch.nn.Conv2d(
196
+ in_channels, out_channels, kernel_size=1, stride=1, padding=0
197
+ )
198
+
199
+ def forward(self, x, temb):
200
+ h = x
201
+ h = self.norm1(h)
202
+ h = nonlinearity(h)
203
+ h = self.conv1(h)
204
+
205
+ if temb is not None:
206
+ h = h + self.temb_proj(nonlinearity(temb))[:, :, None, None]
207
+
208
+ h = self.norm2(h)
209
+ h = nonlinearity(h)
210
+ h = self.dropout(h)
211
+ h = self.conv2(h)
212
+
213
+ if self.in_channels != self.out_channels:
214
+ if self.use_conv_shortcut:
215
+ x = self.conv_shortcut(x)
216
+ else:
217
+ x = self.nin_shortcut(x)
218
+
219
+ return x + h
220
+
221
+
222
+ class MultiHeadAttnBlock(nn.Module):
223
+ def __init__(self, in_channels, head_size=1):
224
+ super().__init__()
225
+ self.in_channels = in_channels
226
+ self.head_size = head_size
227
+ self.att_size = in_channels // head_size
228
+ assert (
229
+ in_channels % head_size == 0
230
+ ), "The size of head should be divided by the number of channels."
231
+
232
+ self.norm1 = Normalize(in_channels)
233
+ self.norm2 = Normalize(in_channels)
234
+
235
+ self.q = torch.nn.Conv2d(
236
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
237
+ )
238
+ self.k = torch.nn.Conv2d(
239
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
240
+ )
241
+ self.v = torch.nn.Conv2d(
242
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
243
+ )
244
+ self.proj_out = torch.nn.Conv2d(
245
+ in_channels, in_channels, kernel_size=1, stride=1, padding=0
246
+ )
247
+ self.num = 0
248
+
249
+ def forward(self, x, y=None):
250
+ h_ = x
251
+ h_ = self.norm1(h_)
252
+ if y is None:
253
+ y = h_
254
+ else:
255
+ y = self.norm2(y)
256
+
257
+ q = self.q(y)
258
+ k = self.k(h_)
259
+ v = self.v(h_)
260
+
261
+ # compute attention
262
+ b, c, h, w = q.shape
263
+ q = q.reshape(b, self.head_size, self.att_size, h * w)
264
+ q = q.permute(0, 3, 1, 2) # b, hw, head, att
265
+
266
+ k = k.reshape(b, self.head_size, self.att_size, h * w)
267
+ k = k.permute(0, 3, 1, 2)
268
+
269
+ v = v.reshape(b, self.head_size, self.att_size, h * w)
270
+ v = v.permute(0, 3, 1, 2)
271
+
272
+ q = q.transpose(1, 2)
273
+ v = v.transpose(1, 2)
274
+ k = k.transpose(1, 2).transpose(2, 3)
275
+
276
+ scale = int(self.att_size) ** (-0.5)
277
+ q.mul_(scale)
278
+ w_ = torch.matmul(q, k)
279
+ w_ = F.softmax(w_, dim=3)
280
+
281
+ w_ = w_.matmul(v)
282
+
283
+ w_ = w_.transpose(1, 2).contiguous() # [b, h*w, head, att]
284
+ w_ = w_.view(b, h, w, -1)
285
+ w_ = w_.permute(0, 3, 1, 2)
286
+
287
+ w_ = self.proj_out(w_)
288
+
289
+ return x + w_
290
+
291
+
292
+ class MultiHeadEncoder(nn.Module):
293
+ def __init__(
294
+ self,
295
+ ch,
296
+ out_ch,
297
+ ch_mult=(1, 2, 4, 8),
298
+ num_res_blocks=2,
299
+ attn_resolutions=(16,),
300
+ dropout=0.0,
301
+ resamp_with_conv=True,
302
+ in_channels=3,
303
+ resolution=512,
304
+ z_channels=256,
305
+ double_z=True,
306
+ enable_mid=True,
307
+ head_size=1,
308
+ **ignore_kwargs
309
+ ):
310
+ super().__init__()
311
+ self.ch = ch
312
+ self.temb_ch = 0
313
+ self.num_resolutions = len(ch_mult)
314
+ self.num_res_blocks = num_res_blocks
315
+ self.resolution = resolution
316
+ self.in_channels = in_channels
317
+ self.enable_mid = enable_mid
318
+
319
+ # downsampling
320
+ self.conv_in = torch.nn.Conv2d(
321
+ in_channels, self.ch, kernel_size=3, stride=1, padding=1
322
+ )
323
+
324
+ curr_res = resolution
325
+ in_ch_mult = (1,) + tuple(ch_mult)
326
+ self.down = nn.ModuleList()
327
+ for i_level in range(self.num_resolutions):
328
+ block = nn.ModuleList()
329
+ attn = nn.ModuleList()
330
+ block_in = ch * in_ch_mult[i_level]
331
+ block_out = ch * ch_mult[i_level]
332
+ for i_block in range(self.num_res_blocks):
333
+ block.append(
334
+ ResnetBlock(
335
+ in_channels=block_in,
336
+ out_channels=block_out,
337
+ temb_channels=self.temb_ch,
338
+ dropout=dropout,
339
+ )
340
+ )
341
+ block_in = block_out
342
+ if curr_res in attn_resolutions:
343
+ attn.append(MultiHeadAttnBlock(block_in, head_size))
344
+ down = nn.Module()
345
+ down.block = block
346
+ down.attn = attn
347
+ if i_level != self.num_resolutions - 1:
348
+ down.downsample = Downsample(block_in, resamp_with_conv)
349
+ curr_res = curr_res // 2
350
+ self.down.append(down)
351
+
352
+ # middle
353
+ if self.enable_mid:
354
+ self.mid = nn.Module()
355
+ self.mid.block_1 = ResnetBlock(
356
+ in_channels=block_in,
357
+ out_channels=block_in,
358
+ temb_channels=self.temb_ch,
359
+ dropout=dropout,
360
+ )
361
+ self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
362
+ self.mid.block_2 = ResnetBlock(
363
+ in_channels=block_in,
364
+ out_channels=block_in,
365
+ temb_channels=self.temb_ch,
366
+ dropout=dropout,
367
+ )
368
+
369
+ # end
370
+ self.norm_out = Normalize(block_in)
371
+ self.conv_out = torch.nn.Conv2d(
372
+ block_in,
373
+ 2 * z_channels if double_z else z_channels,
374
+ kernel_size=3,
375
+ stride=1,
376
+ padding=1,
377
+ )
378
+
379
+ def forward(self, x):
380
+ hs = {}
381
+ # timestep embedding
382
+ temb = None
383
+
384
+ # downsampling
385
+ h = self.conv_in(x)
386
+ hs["in"] = h
387
+ for i_level in range(self.num_resolutions):
388
+ for i_block in range(self.num_res_blocks):
389
+ h = self.down[i_level].block[i_block](h, temb)
390
+ if len(self.down[i_level].attn) > 0:
391
+ h = self.down[i_level].attn[i_block](h)
392
+
393
+ if i_level != self.num_resolutions - 1:
394
+ # hs.append(h)
395
+ hs["block_" + str(i_level)] = h
396
+ h = self.down[i_level].downsample(h)
397
+
398
+ # middle
399
+ # h = hs[-1]
400
+ if self.enable_mid:
401
+ h = self.mid.block_1(h, temb)
402
+ hs["block_" + str(i_level) + "_atten"] = h
403
+ h = self.mid.attn_1(h)
404
+ h = self.mid.block_2(h, temb)
405
+ hs["mid_atten"] = h
406
+
407
+ # end
408
+ h = self.norm_out(h)
409
+ h = nonlinearity(h)
410
+ h = self.conv_out(h)
411
+ # hs.append(h)
412
+ hs["out"] = h
413
+
414
+ return hs
415
+
416
+
417
+ class MultiHeadDecoder(nn.Module):
418
+ def __init__(
419
+ self,
420
+ ch,
421
+ out_ch,
422
+ ch_mult=(1, 2, 4, 8),
423
+ num_res_blocks=2,
424
+ attn_resolutions=(16,),
425
+ dropout=0.0,
426
+ resamp_with_conv=True,
427
+ in_channels=3,
428
+ resolution=512,
429
+ z_channels=256,
430
+ give_pre_end=False,
431
+ enable_mid=True,
432
+ head_size=1,
433
+ **ignorekwargs
434
+ ):
435
+ super().__init__()
436
+ self.ch = ch
437
+ self.temb_ch = 0
438
+ self.num_resolutions = len(ch_mult)
439
+ self.num_res_blocks = num_res_blocks
440
+ self.resolution = resolution
441
+ self.in_channels = in_channels
442
+ self.give_pre_end = give_pre_end
443
+ self.enable_mid = enable_mid
444
+
445
+ # compute in_ch_mult, block_in and curr_res at lowest res
446
+ block_in = ch * ch_mult[self.num_resolutions - 1]
447
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
448
+ self.z_shape = (1, z_channels, curr_res, curr_res)
449
+ print(
450
+ "Working with z of shape {} = {} dimensions.".format(
451
+ self.z_shape, np.prod(self.z_shape)
452
+ )
453
+ )
454
+
455
+ # z to block_in
456
+ self.conv_in = torch.nn.Conv2d(
457
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
458
+ )
459
+
460
+ # middle
461
+ if self.enable_mid:
462
+ self.mid = nn.Module()
463
+ self.mid.block_1 = ResnetBlock(
464
+ in_channels=block_in,
465
+ out_channels=block_in,
466
+ temb_channels=self.temb_ch,
467
+ dropout=dropout,
468
+ )
469
+ self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
470
+ self.mid.block_2 = ResnetBlock(
471
+ in_channels=block_in,
472
+ out_channels=block_in,
473
+ temb_channels=self.temb_ch,
474
+ dropout=dropout,
475
+ )
476
+
477
+ # upsampling
478
+ self.up = nn.ModuleList()
479
+ for i_level in reversed(range(self.num_resolutions)):
480
+ block = nn.ModuleList()
481
+ attn = nn.ModuleList()
482
+ block_out = ch * ch_mult[i_level]
483
+ for i_block in range(self.num_res_blocks + 1):
484
+ block.append(
485
+ ResnetBlock(
486
+ in_channels=block_in,
487
+ out_channels=block_out,
488
+ temb_channels=self.temb_ch,
489
+ dropout=dropout,
490
+ )
491
+ )
492
+ block_in = block_out
493
+ if curr_res in attn_resolutions:
494
+ attn.append(MultiHeadAttnBlock(block_in, head_size))
495
+ up = nn.Module()
496
+ up.block = block
497
+ up.attn = attn
498
+ if i_level != 0:
499
+ up.upsample = Upsample(block_in, resamp_with_conv)
500
+ curr_res = curr_res * 2
501
+ self.up.insert(0, up) # prepend to get consistent order
502
+
503
+ # end
504
+ self.norm_out = Normalize(block_in)
505
+ self.conv_out = torch.nn.Conv2d(
506
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
507
+ )
508
+
509
+ def forward(self, z):
510
+ # assert z.shape[1:] == self.z_shape[1:]
511
+ self.last_z_shape = z.shape
512
+
513
+ # timestep embedding
514
+ temb = None
515
+
516
+ # z to block_in
517
+ h = self.conv_in(z)
518
+
519
+ # middle
520
+ if self.enable_mid:
521
+ h = self.mid.block_1(h, temb)
522
+ h = self.mid.attn_1(h)
523
+ h = self.mid.block_2(h, temb)
524
+
525
+ # upsampling
526
+ for i_level in reversed(range(self.num_resolutions)):
527
+ for i_block in range(self.num_res_blocks + 1):
528
+ h = self.up[i_level].block[i_block](h, temb)
529
+ if len(self.up[i_level].attn) > 0:
530
+ h = self.up[i_level].attn[i_block](h)
531
+ if i_level != 0:
532
+ h = self.up[i_level].upsample(h)
533
+
534
+ # end
535
+ if self.give_pre_end:
536
+ return h
537
+
538
+ h = self.norm_out(h)
539
+ h = nonlinearity(h)
540
+ h = self.conv_out(h)
541
+ return h
542
+
543
+
544
+ class MultiHeadDecoderTransformer(nn.Module):
545
+ def __init__(
546
+ self,
547
+ ch,
548
+ out_ch,
549
+ ch_mult=(1, 2, 4, 8),
550
+ num_res_blocks=2,
551
+ attn_resolutions=(16,),
552
+ dropout=0.0,
553
+ resamp_with_conv=True,
554
+ in_channels=3,
555
+ resolution=512,
556
+ z_channels=256,
557
+ give_pre_end=False,
558
+ enable_mid=True,
559
+ head_size=1,
560
+ **ignorekwargs
561
+ ):
562
+ super().__init__()
563
+ self.ch = ch
564
+ self.temb_ch = 0
565
+ self.num_resolutions = len(ch_mult)
566
+ self.num_res_blocks = num_res_blocks
567
+ self.resolution = resolution
568
+ self.in_channels = in_channels
569
+ self.give_pre_end = give_pre_end
570
+ self.enable_mid = enable_mid
571
+
572
+ # compute in_ch_mult, block_in and curr_res at lowest res
573
+ block_in = ch * ch_mult[self.num_resolutions - 1]
574
+ curr_res = resolution // 2 ** (self.num_resolutions - 1)
575
+ self.z_shape = (1, z_channels, curr_res, curr_res)
576
+ print(
577
+ "Working with z of shape {} = {} dimensions.".format(
578
+ self.z_shape, np.prod(self.z_shape)
579
+ )
580
+ )
581
+
582
+ # z to block_in
583
+ self.conv_in = torch.nn.Conv2d(
584
+ z_channels, block_in, kernel_size=3, stride=1, padding=1
585
+ )
586
+
587
+ # middle
588
+ if self.enable_mid:
589
+ self.mid = nn.Module()
590
+ self.mid.block_1 = ResnetBlock(
591
+ in_channels=block_in,
592
+ out_channels=block_in,
593
+ temb_channels=self.temb_ch,
594
+ dropout=dropout,
595
+ )
596
+ self.mid.attn_1 = MultiHeadAttnBlock(block_in, head_size)
597
+ self.mid.block_2 = ResnetBlock(
598
+ in_channels=block_in,
599
+ out_channels=block_in,
600
+ temb_channels=self.temb_ch,
601
+ dropout=dropout,
602
+ )
603
+
604
+ # upsampling
605
+ self.up = nn.ModuleList()
606
+ for i_level in reversed(range(self.num_resolutions)):
607
+ block = nn.ModuleList()
608
+ attn = nn.ModuleList()
609
+ block_out = ch * ch_mult[i_level]
610
+ for i_block in range(self.num_res_blocks + 1):
611
+ block.append(
612
+ ResnetBlock(
613
+ in_channels=block_in,
614
+ out_channels=block_out,
615
+ temb_channels=self.temb_ch,
616
+ dropout=dropout,
617
+ )
618
+ )
619
+ block_in = block_out
620
+ if curr_res in attn_resolutions:
621
+ attn.append(MultiHeadAttnBlock(block_in, head_size))
622
+ up = nn.Module()
623
+ up.block = block
624
+ up.attn = attn
625
+ if i_level != 0:
626
+ up.upsample = Upsample(block_in, resamp_with_conv)
627
+ curr_res = curr_res * 2
628
+ self.up.insert(0, up) # prepend to get consistent order
629
+
630
+ # end
631
+ self.norm_out = Normalize(block_in)
632
+ self.conv_out = torch.nn.Conv2d(
633
+ block_in, out_ch, kernel_size=3, stride=1, padding=1
634
+ )
635
+
636
+ def forward(self, z, hs):
637
+ # assert z.shape[1:] == self.z_shape[1:]
638
+ # self.last_z_shape = z.shape
639
+
640
+ # timestep embedding
641
+ temb = None
642
+
643
+ # z to block_in
644
+ h = self.conv_in(z)
645
+
646
+ # middle
647
+ if self.enable_mid:
648
+ h = self.mid.block_1(h, temb)
649
+ h = self.mid.attn_1(h, hs["mid_atten"])
650
+ h = self.mid.block_2(h, temb)
651
+
652
+ # upsampling
653
+ for i_level in reversed(range(self.num_resolutions)):
654
+ for i_block in range(self.num_res_blocks + 1):
655
+ h = self.up[i_level].block[i_block](h, temb)
656
+ if len(self.up[i_level].attn) > 0:
657
+ h = self.up[i_level].attn[i_block](
658
+ h, hs["block_" + str(i_level) + "_atten"]
659
+ )
660
+ # hfeature = h.clone()
661
+ if i_level != 0:
662
+ h = self.up[i_level].upsample(h)
663
+
664
+ # end
665
+ if self.give_pre_end:
666
+ return h
667
+
668
+ h = self.norm_out(h)
669
+ h = nonlinearity(h)
670
+ h = self.conv_out(h)
671
+ return h
672
+
673
+
674
+ class RestoreFormer(nn.Module):
675
+ def __init__(
676
+ self,
677
+ state_dict,
678
+ ):
679
+ super(RestoreFormer, self).__init__()
680
+
681
+ n_embed = 1024
682
+ embed_dim = 256
683
+ ch = 64
684
+ out_ch = 3
685
+ ch_mult = (1, 2, 2, 4, 4, 8)
686
+ num_res_blocks = 2
687
+ attn_resolutions = (16,)
688
+ dropout = 0.0
689
+ in_channels = 3
690
+ resolution = 512
691
+ z_channels = 256
692
+ double_z = False
693
+ enable_mid = True
694
+ fix_decoder = False
695
+ fix_codebook = True
696
+ fix_encoder = False
697
+ head_size = 8
698
+
699
+ self.model_arch = "RestoreFormer"
700
+ self.sub_type = "Face SR"
701
+ self.scale = 8
702
+ self.in_nc = 3
703
+ self.out_nc = out_ch
704
+ self.state = state_dict
705
+
706
+ self.supports_fp16 = False
707
+ self.supports_bf16 = True
708
+ self.min_size_restriction = 16
709
+
710
+ self.encoder = MultiHeadEncoder(
711
+ ch=ch,
712
+ out_ch=out_ch,
713
+ ch_mult=ch_mult,
714
+ num_res_blocks=num_res_blocks,
715
+ attn_resolutions=attn_resolutions,
716
+ dropout=dropout,
717
+ in_channels=in_channels,
718
+ resolution=resolution,
719
+ z_channels=z_channels,
720
+ double_z=double_z,
721
+ enable_mid=enable_mid,
722
+ head_size=head_size,
723
+ )
724
+ self.decoder = MultiHeadDecoderTransformer(
725
+ ch=ch,
726
+ out_ch=out_ch,
727
+ ch_mult=ch_mult,
728
+ num_res_blocks=num_res_blocks,
729
+ attn_resolutions=attn_resolutions,
730
+ dropout=dropout,
731
+ in_channels=in_channels,
732
+ resolution=resolution,
733
+ z_channels=z_channels,
734
+ enable_mid=enable_mid,
735
+ head_size=head_size,
736
+ )
737
+
738
+ self.quantize = VectorQuantizer(n_embed, embed_dim, beta=0.25)
739
+
740
+ self.quant_conv = torch.nn.Conv2d(z_channels, embed_dim, 1)
741
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, z_channels, 1)
742
+
743
+ if fix_decoder:
744
+ for _, param in self.decoder.named_parameters():
745
+ param.requires_grad = False
746
+ for _, param in self.post_quant_conv.named_parameters():
747
+ param.requires_grad = False
748
+ for _, param in self.quantize.named_parameters():
749
+ param.requires_grad = False
750
+ elif fix_codebook:
751
+ for _, param in self.quantize.named_parameters():
752
+ param.requires_grad = False
753
+
754
+ if fix_encoder:
755
+ for _, param in self.encoder.named_parameters():
756
+ param.requires_grad = False
757
+
758
+ self.load_state_dict(state_dict)
759
+
760
+ def encode(self, x):
761
+ hs = self.encoder(x)
762
+ h = self.quant_conv(hs["out"])
763
+ quant, emb_loss, info = self.quantize(h)
764
+ return quant, emb_loss, info, hs
765
+
766
+ def decode(self, quant, hs):
767
+ quant = self.post_quant_conv(quant)
768
+ dec = self.decoder(quant, hs)
769
+
770
+ return dec
771
+
772
+ def forward(self, input, **kwargs):
773
+ quant, diff, info, hs = self.encode(input)
774
+ dec = self.decode(quant, hs)
775
+
776
+ return dec, None
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/stylegan2_arch.py ADDED
@@ -0,0 +1,865 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
11
+ from .upfirdn2d import upfirdn2d
12
+
13
+
14
+ class NormStyleCode(nn.Module):
15
+ def forward(self, x):
16
+ """Normalize the style codes.
17
+
18
+ Args:
19
+ x (Tensor): Style codes with shape (b, c).
20
+
21
+ Returns:
22
+ Tensor: Normalized tensor.
23
+ """
24
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
25
+
26
+
27
+ def make_resample_kernel(k):
28
+ """Make resampling kernel for UpFirDn.
29
+
30
+ Args:
31
+ k (list[int]): A list indicating the 1D resample kernel magnitude.
32
+
33
+ Returns:
34
+ Tensor: 2D resampled kernel.
35
+ """
36
+ k = torch.tensor(k, dtype=torch.float32)
37
+ if k.ndim == 1:
38
+ k = k[None, :] * k[:, None] # to 2D kernel, outer product
39
+ # normalize
40
+ k /= k.sum()
41
+ return k
42
+
43
+
44
+ class UpFirDnUpsample(nn.Module):
45
+ """Upsample, FIR filter, and downsample (upsampole version).
46
+
47
+ References:
48
+ 1. https://docs.scipy.org/doc/scipy/reference/generated/scipy.signal.upfirdn.html # noqa: E501
49
+ 2. http://www.ece.northwestern.edu/local-apps/matlabhelp/toolbox/signal/upfirdn.html # noqa: E501
50
+
51
+ Args:
52
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
53
+ magnitude.
54
+ factor (int): Upsampling scale factor. Default: 2.
55
+ """
56
+
57
+ def __init__(self, resample_kernel, factor=2):
58
+ super(UpFirDnUpsample, self).__init__()
59
+ self.kernel = make_resample_kernel(resample_kernel) * (factor**2)
60
+ self.factor = factor
61
+
62
+ pad = self.kernel.shape[0] - factor
63
+ self.pad = ((pad + 1) // 2 + factor - 1, pad // 2)
64
+
65
+ def forward(self, x):
66
+ out = upfirdn2d(x, self.kernel.type_as(x), up=self.factor, down=1, pad=self.pad)
67
+ return out
68
+
69
+ def __repr__(self):
70
+ return f"{self.__class__.__name__}(factor={self.factor})"
71
+
72
+
73
+ class UpFirDnDownsample(nn.Module):
74
+ """Upsample, FIR filter, and downsample (downsampole version).
75
+
76
+ Args:
77
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
78
+ magnitude.
79
+ factor (int): Downsampling scale factor. Default: 2.
80
+ """
81
+
82
+ def __init__(self, resample_kernel, factor=2):
83
+ super(UpFirDnDownsample, self).__init__()
84
+ self.kernel = make_resample_kernel(resample_kernel)
85
+ self.factor = factor
86
+
87
+ pad = self.kernel.shape[0] - factor
88
+ self.pad = ((pad + 1) // 2, pad // 2)
89
+
90
+ def forward(self, x):
91
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=self.factor, pad=self.pad)
92
+ return out
93
+
94
+ def __repr__(self):
95
+ return f"{self.__class__.__name__}(factor={self.factor})"
96
+
97
+
98
+ class UpFirDnSmooth(nn.Module):
99
+ """Upsample, FIR filter, and downsample (smooth version).
100
+
101
+ Args:
102
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
103
+ magnitude.
104
+ upsample_factor (int): Upsampling scale factor. Default: 1.
105
+ downsample_factor (int): Downsampling scale factor. Default: 1.
106
+ kernel_size (int): Kernel size: Default: 1.
107
+ """
108
+
109
+ def __init__(
110
+ self, resample_kernel, upsample_factor=1, downsample_factor=1, kernel_size=1
111
+ ):
112
+ super(UpFirDnSmooth, self).__init__()
113
+ self.upsample_factor = upsample_factor
114
+ self.downsample_factor = downsample_factor
115
+ self.kernel = make_resample_kernel(resample_kernel)
116
+ if upsample_factor > 1:
117
+ self.kernel = self.kernel * (upsample_factor**2)
118
+
119
+ if upsample_factor > 1:
120
+ pad = (self.kernel.shape[0] - upsample_factor) - (kernel_size - 1)
121
+ self.pad = ((pad + 1) // 2 + upsample_factor - 1, pad // 2 + 1)
122
+ elif downsample_factor > 1:
123
+ pad = (self.kernel.shape[0] - downsample_factor) + (kernel_size - 1)
124
+ self.pad = ((pad + 1) // 2, pad // 2)
125
+ else:
126
+ raise NotImplementedError
127
+
128
+ def forward(self, x):
129
+ out = upfirdn2d(x, self.kernel.type_as(x), up=1, down=1, pad=self.pad)
130
+ return out
131
+
132
+ def __repr__(self):
133
+ return (
134
+ f"{self.__class__.__name__}(upsample_factor={self.upsample_factor}"
135
+ f", downsample_factor={self.downsample_factor})"
136
+ )
137
+
138
+
139
+ class EqualLinear(nn.Module):
140
+ """Equalized Linear as StyleGAN2.
141
+
142
+ Args:
143
+ in_channels (int): Size of each sample.
144
+ out_channels (int): Size of each output sample.
145
+ bias (bool): If set to ``False``, the layer will not learn an additive
146
+ bias. Default: ``True``.
147
+ bias_init_val (float): Bias initialized value. Default: 0.
148
+ lr_mul (float): Learning rate multiplier. Default: 1.
149
+ activation (None | str): The activation after ``linear`` operation.
150
+ Supported: 'fused_lrelu', None. Default: None.
151
+ """
152
+
153
+ def __init__(
154
+ self,
155
+ in_channels,
156
+ out_channels,
157
+ bias=True,
158
+ bias_init_val=0,
159
+ lr_mul=1,
160
+ activation=None,
161
+ ):
162
+ super(EqualLinear, self).__init__()
163
+ self.in_channels = in_channels
164
+ self.out_channels = out_channels
165
+ self.lr_mul = lr_mul
166
+ self.activation = activation
167
+ if self.activation not in ["fused_lrelu", None]:
168
+ raise ValueError(
169
+ f"Wrong activation value in EqualLinear: {activation}"
170
+ "Supported ones are: ['fused_lrelu', None]."
171
+ )
172
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
173
+
174
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
175
+ if bias:
176
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
177
+ else:
178
+ self.register_parameter("bias", None)
179
+
180
+ def forward(self, x):
181
+ if self.bias is None:
182
+ bias = None
183
+ else:
184
+ bias = self.bias * self.lr_mul
185
+ if self.activation == "fused_lrelu":
186
+ out = F.linear(x, self.weight * self.scale)
187
+ out = fused_leaky_relu(out, bias)
188
+ else:
189
+ out = F.linear(x, self.weight * self.scale, bias=bias)
190
+ return out
191
+
192
+ def __repr__(self):
193
+ return (
194
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
195
+ f"out_channels={self.out_channels}, bias={self.bias is not None})"
196
+ )
197
+
198
+
199
+ class ModulatedConv2d(nn.Module):
200
+ """Modulated Conv2d used in StyleGAN2.
201
+
202
+ There is no bias in ModulatedConv2d.
203
+
204
+ Args:
205
+ in_channels (int): Channel number of the input.
206
+ out_channels (int): Channel number of the output.
207
+ kernel_size (int): Size of the convolving kernel.
208
+ num_style_feat (int): Channel number of style features.
209
+ demodulate (bool): Whether to demodulate in the conv layer.
210
+ Default: True.
211
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
212
+ Default: None.
213
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
214
+ magnitude. Default: (1, 3, 3, 1).
215
+ eps (float): A value added to the denominator for numerical stability.
216
+ Default: 1e-8.
217
+ """
218
+
219
+ def __init__(
220
+ self,
221
+ in_channels,
222
+ out_channels,
223
+ kernel_size,
224
+ num_style_feat,
225
+ demodulate=True,
226
+ sample_mode=None,
227
+ resample_kernel=(1, 3, 3, 1),
228
+ eps=1e-8,
229
+ ):
230
+ super(ModulatedConv2d, self).__init__()
231
+ self.in_channels = in_channels
232
+ self.out_channels = out_channels
233
+ self.kernel_size = kernel_size
234
+ self.demodulate = demodulate
235
+ self.sample_mode = sample_mode
236
+ self.eps = eps
237
+
238
+ if self.sample_mode == "upsample":
239
+ self.smooth = UpFirDnSmooth(
240
+ resample_kernel,
241
+ upsample_factor=2,
242
+ downsample_factor=1,
243
+ kernel_size=kernel_size,
244
+ )
245
+ elif self.sample_mode == "downsample":
246
+ self.smooth = UpFirDnSmooth(
247
+ resample_kernel,
248
+ upsample_factor=1,
249
+ downsample_factor=2,
250
+ kernel_size=kernel_size,
251
+ )
252
+ elif self.sample_mode is None:
253
+ pass
254
+ else:
255
+ raise ValueError(
256
+ f"Wrong sample mode {self.sample_mode}, "
257
+ "supported ones are ['upsample', 'downsample', None]."
258
+ )
259
+
260
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
261
+ # modulation inside each modulated conv
262
+ self.modulation = EqualLinear(
263
+ num_style_feat,
264
+ in_channels,
265
+ bias=True,
266
+ bias_init_val=1,
267
+ lr_mul=1,
268
+ activation=None,
269
+ )
270
+
271
+ self.weight = nn.Parameter(
272
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
273
+ )
274
+ self.padding = kernel_size // 2
275
+
276
+ def forward(self, x, style):
277
+ """Forward function.
278
+
279
+ Args:
280
+ x (Tensor): Tensor with shape (b, c, h, w).
281
+ style (Tensor): Tensor with shape (b, num_style_feat).
282
+
283
+ Returns:
284
+ Tensor: Modulated tensor after convolution.
285
+ """
286
+ b, c, h, w = x.shape # c = c_in
287
+ # weight modulation
288
+ style = self.modulation(style).view(b, 1, c, 1, 1)
289
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
290
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
291
+
292
+ if self.demodulate:
293
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
294
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
295
+
296
+ weight = weight.view(
297
+ b * self.out_channels, c, self.kernel_size, self.kernel_size
298
+ )
299
+
300
+ if self.sample_mode == "upsample":
301
+ x = x.view(1, b * c, h, w)
302
+ weight = weight.view(
303
+ b, self.out_channels, c, self.kernel_size, self.kernel_size
304
+ )
305
+ weight = weight.transpose(1, 2).reshape(
306
+ b * c, self.out_channels, self.kernel_size, self.kernel_size
307
+ )
308
+ out = F.conv_transpose2d(x, weight, padding=0, stride=2, groups=b)
309
+ out = out.view(b, self.out_channels, *out.shape[2:4])
310
+ out = self.smooth(out)
311
+ elif self.sample_mode == "downsample":
312
+ x = self.smooth(x)
313
+ x = x.view(1, b * c, *x.shape[2:4])
314
+ out = F.conv2d(x, weight, padding=0, stride=2, groups=b)
315
+ out = out.view(b, self.out_channels, *out.shape[2:4])
316
+ else:
317
+ x = x.view(1, b * c, h, w)
318
+ # weight: (b*c_out, c_in, k, k), groups=b
319
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
320
+ out = out.view(b, self.out_channels, *out.shape[2:4])
321
+
322
+ return out
323
+
324
+ def __repr__(self):
325
+ return (
326
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
327
+ f"out_channels={self.out_channels}, "
328
+ f"kernel_size={self.kernel_size}, "
329
+ f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
330
+ )
331
+
332
+
333
+ class StyleConv(nn.Module):
334
+ """Style conv.
335
+
336
+ Args:
337
+ in_channels (int): Channel number of the input.
338
+ out_channels (int): Channel number of the output.
339
+ kernel_size (int): Size of the convolving kernel.
340
+ num_style_feat (int): Channel number of style features.
341
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
342
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
343
+ Default: None.
344
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
345
+ magnitude. Default: (1, 3, 3, 1).
346
+ """
347
+
348
+ def __init__(
349
+ self,
350
+ in_channels,
351
+ out_channels,
352
+ kernel_size,
353
+ num_style_feat,
354
+ demodulate=True,
355
+ sample_mode=None,
356
+ resample_kernel=(1, 3, 3, 1),
357
+ ):
358
+ super(StyleConv, self).__init__()
359
+ self.modulated_conv = ModulatedConv2d(
360
+ in_channels,
361
+ out_channels,
362
+ kernel_size,
363
+ num_style_feat,
364
+ demodulate=demodulate,
365
+ sample_mode=sample_mode,
366
+ resample_kernel=resample_kernel,
367
+ )
368
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
369
+ self.activate = FusedLeakyReLU(out_channels)
370
+
371
+ def forward(self, x, style, noise=None):
372
+ # modulate
373
+ out = self.modulated_conv(x, style)
374
+ # noise injection
375
+ if noise is None:
376
+ b, _, h, w = out.shape
377
+ noise = out.new_empty(b, 1, h, w).normal_()
378
+ out = out + self.weight * noise
379
+ # activation (with bias)
380
+ out = self.activate(out)
381
+ return out
382
+
383
+
384
+ class ToRGB(nn.Module):
385
+ """To RGB from features.
386
+
387
+ Args:
388
+ in_channels (int): Channel number of input.
389
+ num_style_feat (int): Channel number of style features.
390
+ upsample (bool): Whether to upsample. Default: True.
391
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
392
+ magnitude. Default: (1, 3, 3, 1).
393
+ """
394
+
395
+ def __init__(
396
+ self, in_channels, num_style_feat, upsample=True, resample_kernel=(1, 3, 3, 1)
397
+ ):
398
+ super(ToRGB, self).__init__()
399
+ if upsample:
400
+ self.upsample = UpFirDnUpsample(resample_kernel, factor=2)
401
+ else:
402
+ self.upsample = None
403
+ self.modulated_conv = ModulatedConv2d(
404
+ in_channels,
405
+ 3,
406
+ kernel_size=1,
407
+ num_style_feat=num_style_feat,
408
+ demodulate=False,
409
+ sample_mode=None,
410
+ )
411
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
412
+
413
+ def forward(self, x, style, skip=None):
414
+ """Forward function.
415
+
416
+ Args:
417
+ x (Tensor): Feature tensor with shape (b, c, h, w).
418
+ style (Tensor): Tensor with shape (b, num_style_feat).
419
+ skip (Tensor): Base/skip tensor. Default: None.
420
+
421
+ Returns:
422
+ Tensor: RGB images.
423
+ """
424
+ out = self.modulated_conv(x, style)
425
+ out = out + self.bias
426
+ if skip is not None:
427
+ if self.upsample:
428
+ skip = self.upsample(skip)
429
+ out = out + skip
430
+ return out
431
+
432
+
433
+ class ConstantInput(nn.Module):
434
+ """Constant input.
435
+
436
+ Args:
437
+ num_channel (int): Channel number of constant input.
438
+ size (int): Spatial size of constant input.
439
+ """
440
+
441
+ def __init__(self, num_channel, size):
442
+ super(ConstantInput, self).__init__()
443
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
444
+
445
+ def forward(self, batch):
446
+ out = self.weight.repeat(batch, 1, 1, 1)
447
+ return out
448
+
449
+
450
+ class StyleGAN2Generator(nn.Module):
451
+ """StyleGAN2 Generator.
452
+
453
+ Args:
454
+ out_size (int): The spatial size of outputs.
455
+ num_style_feat (int): Channel number of style features. Default: 512.
456
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
457
+ channel_multiplier (int): Channel multiplier for large networks of
458
+ StyleGAN2. Default: 2.
459
+ resample_kernel (list[int]): A list indicating the 1D resample kernel
460
+ magnitude. A cross production will be applied to extent 1D resample
461
+ kernel to 2D resample kernel. Default: (1, 3, 3, 1).
462
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
463
+ narrow (float): Narrow ratio for channels. Default: 1.0.
464
+ """
465
+
466
+ def __init__(
467
+ self,
468
+ out_size,
469
+ num_style_feat=512,
470
+ num_mlp=8,
471
+ channel_multiplier=2,
472
+ resample_kernel=(1, 3, 3, 1),
473
+ lr_mlp=0.01,
474
+ narrow=1,
475
+ ):
476
+ super(StyleGAN2Generator, self).__init__()
477
+ # Style MLP layers
478
+ self.num_style_feat = num_style_feat
479
+ style_mlp_layers = [NormStyleCode()]
480
+ for i in range(num_mlp):
481
+ style_mlp_layers.append(
482
+ EqualLinear(
483
+ num_style_feat,
484
+ num_style_feat,
485
+ bias=True,
486
+ bias_init_val=0,
487
+ lr_mul=lr_mlp,
488
+ activation="fused_lrelu",
489
+ )
490
+ )
491
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
492
+
493
+ channels = {
494
+ "4": int(512 * narrow),
495
+ "8": int(512 * narrow),
496
+ "16": int(512 * narrow),
497
+ "32": int(512 * narrow),
498
+ "64": int(256 * channel_multiplier * narrow),
499
+ "128": int(128 * channel_multiplier * narrow),
500
+ "256": int(64 * channel_multiplier * narrow),
501
+ "512": int(32 * channel_multiplier * narrow),
502
+ "1024": int(16 * channel_multiplier * narrow),
503
+ }
504
+ self.channels = channels
505
+
506
+ self.constant_input = ConstantInput(channels["4"], size=4)
507
+ self.style_conv1 = StyleConv(
508
+ channels["4"],
509
+ channels["4"],
510
+ kernel_size=3,
511
+ num_style_feat=num_style_feat,
512
+ demodulate=True,
513
+ sample_mode=None,
514
+ resample_kernel=resample_kernel,
515
+ )
516
+ self.to_rgb1 = ToRGB(
517
+ channels["4"],
518
+ num_style_feat,
519
+ upsample=False,
520
+ resample_kernel=resample_kernel,
521
+ )
522
+
523
+ self.log_size = int(math.log(out_size, 2))
524
+ self.num_layers = (self.log_size - 2) * 2 + 1
525
+ self.num_latent = self.log_size * 2 - 2
526
+
527
+ self.style_convs = nn.ModuleList()
528
+ self.to_rgbs = nn.ModuleList()
529
+ self.noises = nn.Module()
530
+
531
+ in_channels = channels["4"]
532
+ # noise
533
+ for layer_idx in range(self.num_layers):
534
+ resolution = 2 ** ((layer_idx + 5) // 2)
535
+ shape = [1, 1, resolution, resolution]
536
+ self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
537
+ # style convs and to_rgbs
538
+ for i in range(3, self.log_size + 1):
539
+ out_channels = channels[f"{2**i}"]
540
+ self.style_convs.append(
541
+ StyleConv(
542
+ in_channels,
543
+ out_channels,
544
+ kernel_size=3,
545
+ num_style_feat=num_style_feat,
546
+ demodulate=True,
547
+ sample_mode="upsample",
548
+ resample_kernel=resample_kernel,
549
+ )
550
+ )
551
+ self.style_convs.append(
552
+ StyleConv(
553
+ out_channels,
554
+ out_channels,
555
+ kernel_size=3,
556
+ num_style_feat=num_style_feat,
557
+ demodulate=True,
558
+ sample_mode=None,
559
+ resample_kernel=resample_kernel,
560
+ )
561
+ )
562
+ self.to_rgbs.append(
563
+ ToRGB(
564
+ out_channels,
565
+ num_style_feat,
566
+ upsample=True,
567
+ resample_kernel=resample_kernel,
568
+ )
569
+ )
570
+ in_channels = out_channels
571
+
572
+ def make_noise(self):
573
+ """Make noise for noise injection."""
574
+ device = self.constant_input.weight.device
575
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
576
+
577
+ for i in range(3, self.log_size + 1):
578
+ for _ in range(2):
579
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
580
+
581
+ return noises
582
+
583
+ def get_latent(self, x):
584
+ return self.style_mlp(x)
585
+
586
+ def mean_latent(self, num_latent):
587
+ latent_in = torch.randn(
588
+ num_latent, self.num_style_feat, device=self.constant_input.weight.device
589
+ )
590
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
591
+ return latent
592
+
593
+ def forward(
594
+ self,
595
+ styles,
596
+ input_is_latent=False,
597
+ noise=None,
598
+ randomize_noise=True,
599
+ truncation=1,
600
+ truncation_latent=None,
601
+ inject_index=None,
602
+ return_latents=False,
603
+ ):
604
+ """Forward function for StyleGAN2Generator.
605
+
606
+ Args:
607
+ styles (list[Tensor]): Sample codes of styles.
608
+ input_is_latent (bool): Whether input is latent style.
609
+ Default: False.
610
+ noise (Tensor | None): Input noise or None. Default: None.
611
+ randomize_noise (bool): Randomize noise, used when 'noise' is
612
+ False. Default: True.
613
+ truncation (float): TODO. Default: 1.
614
+ truncation_latent (Tensor | None): TODO. Default: None.
615
+ inject_index (int | None): The injection index for mixing noise.
616
+ Default: None.
617
+ return_latents (bool): Whether to return style latents.
618
+ Default: False.
619
+ """
620
+ # style codes -> latents with Style MLP layer
621
+ if not input_is_latent:
622
+ styles = [self.style_mlp(s) for s in styles]
623
+ # noises
624
+ if noise is None:
625
+ if randomize_noise:
626
+ noise = [None] * self.num_layers # for each style conv layer
627
+ else: # use the stored noise
628
+ noise = [
629
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
630
+ ]
631
+ # style truncation
632
+ if truncation < 1:
633
+ style_truncation = []
634
+ for style in styles:
635
+ style_truncation.append(
636
+ truncation_latent + truncation * (style - truncation_latent)
637
+ )
638
+ styles = style_truncation
639
+ # get style latent with injection
640
+ if len(styles) == 1:
641
+ inject_index = self.num_latent
642
+
643
+ if styles[0].ndim < 3:
644
+ # repeat latent code for all the layers
645
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
646
+ else: # used for encoder with different latent code for each layer
647
+ latent = styles[0]
648
+ elif len(styles) == 2: # mixing noises
649
+ if inject_index is None:
650
+ inject_index = random.randint(1, self.num_latent - 1)
651
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
652
+ latent2 = (
653
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
654
+ )
655
+ latent = torch.cat([latent1, latent2], 1)
656
+
657
+ # main generation
658
+ out = self.constant_input(latent.shape[0])
659
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
660
+ skip = self.to_rgb1(out, latent[:, 1])
661
+
662
+ i = 1
663
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
664
+ self.style_convs[::2],
665
+ self.style_convs[1::2],
666
+ noise[1::2],
667
+ noise[2::2],
668
+ self.to_rgbs,
669
+ ):
670
+ out = conv1(out, latent[:, i], noise=noise1)
671
+ out = conv2(out, latent[:, i + 1], noise=noise2)
672
+ skip = to_rgb(out, latent[:, i + 2], skip)
673
+ i += 2
674
+
675
+ image = skip
676
+
677
+ if return_latents:
678
+ return image, latent
679
+ else:
680
+ return image, None
681
+
682
+
683
+ class ScaledLeakyReLU(nn.Module):
684
+ """Scaled LeakyReLU.
685
+
686
+ Args:
687
+ negative_slope (float): Negative slope. Default: 0.2.
688
+ """
689
+
690
+ def __init__(self, negative_slope=0.2):
691
+ super(ScaledLeakyReLU, self).__init__()
692
+ self.negative_slope = negative_slope
693
+
694
+ def forward(self, x):
695
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
696
+ return out * math.sqrt(2)
697
+
698
+
699
+ class EqualConv2d(nn.Module):
700
+ """Equalized Linear as StyleGAN2.
701
+
702
+ Args:
703
+ in_channels (int): Channel number of the input.
704
+ out_channels (int): Channel number of the output.
705
+ kernel_size (int): Size of the convolving kernel.
706
+ stride (int): Stride of the convolution. Default: 1
707
+ padding (int): Zero-padding added to both sides of the input.
708
+ Default: 0.
709
+ bias (bool): If ``True``, adds a learnable bias to the output.
710
+ Default: ``True``.
711
+ bias_init_val (float): Bias initialized value. Default: 0.
712
+ """
713
+
714
+ def __init__(
715
+ self,
716
+ in_channels,
717
+ out_channels,
718
+ kernel_size,
719
+ stride=1,
720
+ padding=0,
721
+ bias=True,
722
+ bias_init_val=0,
723
+ ):
724
+ super(EqualConv2d, self).__init__()
725
+ self.in_channels = in_channels
726
+ self.out_channels = out_channels
727
+ self.kernel_size = kernel_size
728
+ self.stride = stride
729
+ self.padding = padding
730
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
731
+
732
+ self.weight = nn.Parameter(
733
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
734
+ )
735
+ if bias:
736
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
737
+ else:
738
+ self.register_parameter("bias", None)
739
+
740
+ def forward(self, x):
741
+ out = F.conv2d(
742
+ x,
743
+ self.weight * self.scale,
744
+ bias=self.bias,
745
+ stride=self.stride,
746
+ padding=self.padding,
747
+ )
748
+
749
+ return out
750
+
751
+ def __repr__(self):
752
+ return (
753
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
754
+ f"out_channels={self.out_channels}, "
755
+ f"kernel_size={self.kernel_size},"
756
+ f" stride={self.stride}, padding={self.padding}, "
757
+ f"bias={self.bias is not None})"
758
+ )
759
+
760
+
761
+ class ConvLayer(nn.Sequential):
762
+ """Conv Layer used in StyleGAN2 Discriminator.
763
+
764
+ Args:
765
+ in_channels (int): Channel number of the input.
766
+ out_channels (int): Channel number of the output.
767
+ kernel_size (int): Kernel size.
768
+ downsample (bool): Whether downsample by a factor of 2.
769
+ Default: False.
770
+ resample_kernel (list[int]): A list indicating the 1D resample
771
+ kernel magnitude. A cross production will be applied to
772
+ extent 1D resample kernel to 2D resample kernel.
773
+ Default: (1, 3, 3, 1).
774
+ bias (bool): Whether with bias. Default: True.
775
+ activate (bool): Whether use activateion. Default: True.
776
+ """
777
+
778
+ def __init__(
779
+ self,
780
+ in_channels,
781
+ out_channels,
782
+ kernel_size,
783
+ downsample=False,
784
+ resample_kernel=(1, 3, 3, 1),
785
+ bias=True,
786
+ activate=True,
787
+ ):
788
+ layers = []
789
+ # downsample
790
+ if downsample:
791
+ layers.append(
792
+ UpFirDnSmooth(
793
+ resample_kernel,
794
+ upsample_factor=1,
795
+ downsample_factor=2,
796
+ kernel_size=kernel_size,
797
+ )
798
+ )
799
+ stride = 2
800
+ self.padding = 0
801
+ else:
802
+ stride = 1
803
+ self.padding = kernel_size // 2
804
+ # conv
805
+ layers.append(
806
+ EqualConv2d(
807
+ in_channels,
808
+ out_channels,
809
+ kernel_size,
810
+ stride=stride,
811
+ padding=self.padding,
812
+ bias=bias and not activate,
813
+ )
814
+ )
815
+ # activation
816
+ if activate:
817
+ if bias:
818
+ layers.append(FusedLeakyReLU(out_channels))
819
+ else:
820
+ layers.append(ScaledLeakyReLU(0.2))
821
+
822
+ super(ConvLayer, self).__init__(*layers)
823
+
824
+
825
+ class ResBlock(nn.Module):
826
+ """Residual block used in StyleGAN2 Discriminator.
827
+
828
+ Args:
829
+ in_channels (int): Channel number of the input.
830
+ out_channels (int): Channel number of the output.
831
+ resample_kernel (list[int]): A list indicating the 1D resample
832
+ kernel magnitude. A cross production will be applied to
833
+ extent 1D resample kernel to 2D resample kernel.
834
+ Default: (1, 3, 3, 1).
835
+ """
836
+
837
+ def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
838
+ super(ResBlock, self).__init__()
839
+
840
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
841
+ self.conv2 = ConvLayer(
842
+ in_channels,
843
+ out_channels,
844
+ 3,
845
+ downsample=True,
846
+ resample_kernel=resample_kernel,
847
+ bias=True,
848
+ activate=True,
849
+ )
850
+ self.skip = ConvLayer(
851
+ in_channels,
852
+ out_channels,
853
+ 1,
854
+ downsample=True,
855
+ resample_kernel=resample_kernel,
856
+ bias=False,
857
+ activate=False,
858
+ )
859
+
860
+ def forward(self, x):
861
+ out = self.conv1(x)
862
+ out = self.conv2(out)
863
+ skip = self.skip(x)
864
+ out = (out + skip) / math.sqrt(2)
865
+ return out
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/stylegan2_bilinear_arch.py ADDED
@@ -0,0 +1,709 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+ import random
5
+
6
+ import torch
7
+ from torch import nn
8
+ from torch.nn import functional as F
9
+
10
+ from .fused_act import FusedLeakyReLU, fused_leaky_relu
11
+
12
+
13
+ class NormStyleCode(nn.Module):
14
+ def forward(self, x):
15
+ """Normalize the style codes.
16
+ Args:
17
+ x (Tensor): Style codes with shape (b, c).
18
+ Returns:
19
+ Tensor: Normalized tensor.
20
+ """
21
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
22
+
23
+
24
+ class EqualLinear(nn.Module):
25
+ """Equalized Linear as StyleGAN2.
26
+ Args:
27
+ in_channels (int): Size of each sample.
28
+ out_channels (int): Size of each output sample.
29
+ bias (bool): If set to ``False``, the layer will not learn an additive
30
+ bias. Default: ``True``.
31
+ bias_init_val (float): Bias initialized value. Default: 0.
32
+ lr_mul (float): Learning rate multiplier. Default: 1.
33
+ activation (None | str): The activation after ``linear`` operation.
34
+ Supported: 'fused_lrelu', None. Default: None.
35
+ """
36
+
37
+ def __init__(
38
+ self,
39
+ in_channels,
40
+ out_channels,
41
+ bias=True,
42
+ bias_init_val=0,
43
+ lr_mul=1,
44
+ activation=None,
45
+ ):
46
+ super(EqualLinear, self).__init__()
47
+ self.in_channels = in_channels
48
+ self.out_channels = out_channels
49
+ self.lr_mul = lr_mul
50
+ self.activation = activation
51
+ if self.activation not in ["fused_lrelu", None]:
52
+ raise ValueError(
53
+ f"Wrong activation value in EqualLinear: {activation}"
54
+ "Supported ones are: ['fused_lrelu', None]."
55
+ )
56
+ self.scale = (1 / math.sqrt(in_channels)) * lr_mul
57
+
58
+ self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
59
+ if bias:
60
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
61
+ else:
62
+ self.register_parameter("bias", None)
63
+
64
+ def forward(self, x):
65
+ if self.bias is None:
66
+ bias = None
67
+ else:
68
+ bias = self.bias * self.lr_mul
69
+ if self.activation == "fused_lrelu":
70
+ out = F.linear(x, self.weight * self.scale)
71
+ out = fused_leaky_relu(out, bias)
72
+ else:
73
+ out = F.linear(x, self.weight * self.scale, bias=bias)
74
+ return out
75
+
76
+ def __repr__(self):
77
+ return (
78
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
79
+ f"out_channels={self.out_channels}, bias={self.bias is not None})"
80
+ )
81
+
82
+
83
+ class ModulatedConv2d(nn.Module):
84
+ """Modulated Conv2d used in StyleGAN2.
85
+ There is no bias in ModulatedConv2d.
86
+ Args:
87
+ in_channels (int): Channel number of the input.
88
+ out_channels (int): Channel number of the output.
89
+ kernel_size (int): Size of the convolving kernel.
90
+ num_style_feat (int): Channel number of style features.
91
+ demodulate (bool): Whether to demodulate in the conv layer.
92
+ Default: True.
93
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
94
+ Default: None.
95
+ eps (float): A value added to the denominator for numerical stability.
96
+ Default: 1e-8.
97
+ """
98
+
99
+ def __init__(
100
+ self,
101
+ in_channels,
102
+ out_channels,
103
+ kernel_size,
104
+ num_style_feat,
105
+ demodulate=True,
106
+ sample_mode=None,
107
+ eps=1e-8,
108
+ interpolation_mode="bilinear",
109
+ ):
110
+ super(ModulatedConv2d, self).__init__()
111
+ self.in_channels = in_channels
112
+ self.out_channels = out_channels
113
+ self.kernel_size = kernel_size
114
+ self.demodulate = demodulate
115
+ self.sample_mode = sample_mode
116
+ self.eps = eps
117
+ self.interpolation_mode = interpolation_mode
118
+ if self.interpolation_mode == "nearest":
119
+ self.align_corners = None
120
+ else:
121
+ self.align_corners = False
122
+
123
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
124
+ # modulation inside each modulated conv
125
+ self.modulation = EqualLinear(
126
+ num_style_feat,
127
+ in_channels,
128
+ bias=True,
129
+ bias_init_val=1,
130
+ lr_mul=1,
131
+ activation=None,
132
+ )
133
+
134
+ self.weight = nn.Parameter(
135
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
136
+ )
137
+ self.padding = kernel_size // 2
138
+
139
+ def forward(self, x, style):
140
+ """Forward function.
141
+ Args:
142
+ x (Tensor): Tensor with shape (b, c, h, w).
143
+ style (Tensor): Tensor with shape (b, num_style_feat).
144
+ Returns:
145
+ Tensor: Modulated tensor after convolution.
146
+ """
147
+ b, c, h, w = x.shape # c = c_in
148
+ # weight modulation
149
+ style = self.modulation(style).view(b, 1, c, 1, 1)
150
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
151
+ weight = self.scale * self.weight * style # (b, c_out, c_in, k, k)
152
+
153
+ if self.demodulate:
154
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
155
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
156
+
157
+ weight = weight.view(
158
+ b * self.out_channels, c, self.kernel_size, self.kernel_size
159
+ )
160
+
161
+ if self.sample_mode == "upsample":
162
+ x = F.interpolate(
163
+ x,
164
+ scale_factor=2,
165
+ mode=self.interpolation_mode,
166
+ align_corners=self.align_corners,
167
+ )
168
+ elif self.sample_mode == "downsample":
169
+ x = F.interpolate(
170
+ x,
171
+ scale_factor=0.5,
172
+ mode=self.interpolation_mode,
173
+ align_corners=self.align_corners,
174
+ )
175
+
176
+ b, c, h, w = x.shape
177
+ x = x.view(1, b * c, h, w)
178
+ # weight: (b*c_out, c_in, k, k), groups=b
179
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
180
+ out = out.view(b, self.out_channels, *out.shape[2:4])
181
+
182
+ return out
183
+
184
+ def __repr__(self):
185
+ return (
186
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
187
+ f"out_channels={self.out_channels}, "
188
+ f"kernel_size={self.kernel_size}, "
189
+ f"demodulate={self.demodulate}, sample_mode={self.sample_mode})"
190
+ )
191
+
192
+
193
+ class StyleConv(nn.Module):
194
+ """Style conv.
195
+ Args:
196
+ in_channels (int): Channel number of the input.
197
+ out_channels (int): Channel number of the output.
198
+ kernel_size (int): Size of the convolving kernel.
199
+ num_style_feat (int): Channel number of style features.
200
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
201
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None.
202
+ Default: None.
203
+ """
204
+
205
+ def __init__(
206
+ self,
207
+ in_channels,
208
+ out_channels,
209
+ kernel_size,
210
+ num_style_feat,
211
+ demodulate=True,
212
+ sample_mode=None,
213
+ interpolation_mode="bilinear",
214
+ ):
215
+ super(StyleConv, self).__init__()
216
+ self.modulated_conv = ModulatedConv2d(
217
+ in_channels,
218
+ out_channels,
219
+ kernel_size,
220
+ num_style_feat,
221
+ demodulate=demodulate,
222
+ sample_mode=sample_mode,
223
+ interpolation_mode=interpolation_mode,
224
+ )
225
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
226
+ self.activate = FusedLeakyReLU(out_channels)
227
+
228
+ def forward(self, x, style, noise=None):
229
+ # modulate
230
+ out = self.modulated_conv(x, style)
231
+ # noise injection
232
+ if noise is None:
233
+ b, _, h, w = out.shape
234
+ noise = out.new_empty(b, 1, h, w).normal_()
235
+ out = out + self.weight * noise
236
+ # activation (with bias)
237
+ out = self.activate(out)
238
+ return out
239
+
240
+
241
+ class ToRGB(nn.Module):
242
+ """To RGB from features.
243
+ Args:
244
+ in_channels (int): Channel number of input.
245
+ num_style_feat (int): Channel number of style features.
246
+ upsample (bool): Whether to upsample. Default: True.
247
+ """
248
+
249
+ def __init__(
250
+ self, in_channels, num_style_feat, upsample=True, interpolation_mode="bilinear"
251
+ ):
252
+ super(ToRGB, self).__init__()
253
+ self.upsample = upsample
254
+ self.interpolation_mode = interpolation_mode
255
+ if self.interpolation_mode == "nearest":
256
+ self.align_corners = None
257
+ else:
258
+ self.align_corners = False
259
+ self.modulated_conv = ModulatedConv2d(
260
+ in_channels,
261
+ 3,
262
+ kernel_size=1,
263
+ num_style_feat=num_style_feat,
264
+ demodulate=False,
265
+ sample_mode=None,
266
+ interpolation_mode=interpolation_mode,
267
+ )
268
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
269
+
270
+ def forward(self, x, style, skip=None):
271
+ """Forward function.
272
+ Args:
273
+ x (Tensor): Feature tensor with shape (b, c, h, w).
274
+ style (Tensor): Tensor with shape (b, num_style_feat).
275
+ skip (Tensor): Base/skip tensor. Default: None.
276
+ Returns:
277
+ Tensor: RGB images.
278
+ """
279
+ out = self.modulated_conv(x, style)
280
+ out = out + self.bias
281
+ if skip is not None:
282
+ if self.upsample:
283
+ skip = F.interpolate(
284
+ skip,
285
+ scale_factor=2,
286
+ mode=self.interpolation_mode,
287
+ align_corners=self.align_corners,
288
+ )
289
+ out = out + skip
290
+ return out
291
+
292
+
293
+ class ConstantInput(nn.Module):
294
+ """Constant input.
295
+ Args:
296
+ num_channel (int): Channel number of constant input.
297
+ size (int): Spatial size of constant input.
298
+ """
299
+
300
+ def __init__(self, num_channel, size):
301
+ super(ConstantInput, self).__init__()
302
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
303
+
304
+ def forward(self, batch):
305
+ out = self.weight.repeat(batch, 1, 1, 1)
306
+ return out
307
+
308
+
309
+ class StyleGAN2GeneratorBilinear(nn.Module):
310
+ """StyleGAN2 Generator.
311
+ Args:
312
+ out_size (int): The spatial size of outputs.
313
+ num_style_feat (int): Channel number of style features. Default: 512.
314
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
315
+ channel_multiplier (int): Channel multiplier for large networks of
316
+ StyleGAN2. Default: 2.
317
+ lr_mlp (float): Learning rate multiplier for mlp layers. Default: 0.01.
318
+ narrow (float): Narrow ratio for channels. Default: 1.0.
319
+ """
320
+
321
+ def __init__(
322
+ self,
323
+ out_size,
324
+ num_style_feat=512,
325
+ num_mlp=8,
326
+ channel_multiplier=2,
327
+ lr_mlp=0.01,
328
+ narrow=1,
329
+ interpolation_mode="bilinear",
330
+ ):
331
+ super(StyleGAN2GeneratorBilinear, self).__init__()
332
+ # Style MLP layers
333
+ self.num_style_feat = num_style_feat
334
+ style_mlp_layers = [NormStyleCode()]
335
+ for i in range(num_mlp):
336
+ style_mlp_layers.append(
337
+ EqualLinear(
338
+ num_style_feat,
339
+ num_style_feat,
340
+ bias=True,
341
+ bias_init_val=0,
342
+ lr_mul=lr_mlp,
343
+ activation="fused_lrelu",
344
+ )
345
+ )
346
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
347
+
348
+ channels = {
349
+ "4": int(512 * narrow),
350
+ "8": int(512 * narrow),
351
+ "16": int(512 * narrow),
352
+ "32": int(512 * narrow),
353
+ "64": int(256 * channel_multiplier * narrow),
354
+ "128": int(128 * channel_multiplier * narrow),
355
+ "256": int(64 * channel_multiplier * narrow),
356
+ "512": int(32 * channel_multiplier * narrow),
357
+ "1024": int(16 * channel_multiplier * narrow),
358
+ }
359
+ self.channels = channels
360
+
361
+ self.constant_input = ConstantInput(channels["4"], size=4)
362
+ self.style_conv1 = StyleConv(
363
+ channels["4"],
364
+ channels["4"],
365
+ kernel_size=3,
366
+ num_style_feat=num_style_feat,
367
+ demodulate=True,
368
+ sample_mode=None,
369
+ interpolation_mode=interpolation_mode,
370
+ )
371
+ self.to_rgb1 = ToRGB(
372
+ channels["4"],
373
+ num_style_feat,
374
+ upsample=False,
375
+ interpolation_mode=interpolation_mode,
376
+ )
377
+
378
+ self.log_size = int(math.log(out_size, 2))
379
+ self.num_layers = (self.log_size - 2) * 2 + 1
380
+ self.num_latent = self.log_size * 2 - 2
381
+
382
+ self.style_convs = nn.ModuleList()
383
+ self.to_rgbs = nn.ModuleList()
384
+ self.noises = nn.Module()
385
+
386
+ in_channels = channels["4"]
387
+ # noise
388
+ for layer_idx in range(self.num_layers):
389
+ resolution = 2 ** ((layer_idx + 5) // 2)
390
+ shape = [1, 1, resolution, resolution]
391
+ self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
392
+ # style convs and to_rgbs
393
+ for i in range(3, self.log_size + 1):
394
+ out_channels = channels[f"{2**i}"]
395
+ self.style_convs.append(
396
+ StyleConv(
397
+ in_channels,
398
+ out_channels,
399
+ kernel_size=3,
400
+ num_style_feat=num_style_feat,
401
+ demodulate=True,
402
+ sample_mode="upsample",
403
+ interpolation_mode=interpolation_mode,
404
+ )
405
+ )
406
+ self.style_convs.append(
407
+ StyleConv(
408
+ out_channels,
409
+ out_channels,
410
+ kernel_size=3,
411
+ num_style_feat=num_style_feat,
412
+ demodulate=True,
413
+ sample_mode=None,
414
+ interpolation_mode=interpolation_mode,
415
+ )
416
+ )
417
+ self.to_rgbs.append(
418
+ ToRGB(
419
+ out_channels,
420
+ num_style_feat,
421
+ upsample=True,
422
+ interpolation_mode=interpolation_mode,
423
+ )
424
+ )
425
+ in_channels = out_channels
426
+
427
+ def make_noise(self):
428
+ """Make noise for noise injection."""
429
+ device = self.constant_input.weight.device
430
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
431
+
432
+ for i in range(3, self.log_size + 1):
433
+ for _ in range(2):
434
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
435
+
436
+ return noises
437
+
438
+ def get_latent(self, x):
439
+ return self.style_mlp(x)
440
+
441
+ def mean_latent(self, num_latent):
442
+ latent_in = torch.randn(
443
+ num_latent, self.num_style_feat, device=self.constant_input.weight.device
444
+ )
445
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
446
+ return latent
447
+
448
+ def forward(
449
+ self,
450
+ styles,
451
+ input_is_latent=False,
452
+ noise=None,
453
+ randomize_noise=True,
454
+ truncation=1,
455
+ truncation_latent=None,
456
+ inject_index=None,
457
+ return_latents=False,
458
+ ):
459
+ """Forward function for StyleGAN2Generator.
460
+ Args:
461
+ styles (list[Tensor]): Sample codes of styles.
462
+ input_is_latent (bool): Whether input is latent style.
463
+ Default: False.
464
+ noise (Tensor | None): Input noise or None. Default: None.
465
+ randomize_noise (bool): Randomize noise, used when 'noise' is
466
+ False. Default: True.
467
+ truncation (float): TODO. Default: 1.
468
+ truncation_latent (Tensor | None): TODO. Default: None.
469
+ inject_index (int | None): The injection index for mixing noise.
470
+ Default: None.
471
+ return_latents (bool): Whether to return style latents.
472
+ Default: False.
473
+ """
474
+ # style codes -> latents with Style MLP layer
475
+ if not input_is_latent:
476
+ styles = [self.style_mlp(s) for s in styles]
477
+ # noises
478
+ if noise is None:
479
+ if randomize_noise:
480
+ noise = [None] * self.num_layers # for each style conv layer
481
+ else: # use the stored noise
482
+ noise = [
483
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
484
+ ]
485
+ # style truncation
486
+ if truncation < 1:
487
+ style_truncation = []
488
+ for style in styles:
489
+ style_truncation.append(
490
+ truncation_latent + truncation * (style - truncation_latent)
491
+ )
492
+ styles = style_truncation
493
+ # get style latent with injection
494
+ if len(styles) == 1:
495
+ inject_index = self.num_latent
496
+
497
+ if styles[0].ndim < 3:
498
+ # repeat latent code for all the layers
499
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
500
+ else: # used for encoder with different latent code for each layer
501
+ latent = styles[0]
502
+ elif len(styles) == 2: # mixing noises
503
+ if inject_index is None:
504
+ inject_index = random.randint(1, self.num_latent - 1)
505
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
506
+ latent2 = (
507
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
508
+ )
509
+ latent = torch.cat([latent1, latent2], 1)
510
+
511
+ # main generation
512
+ out = self.constant_input(latent.shape[0])
513
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
514
+ skip = self.to_rgb1(out, latent[:, 1])
515
+
516
+ i = 1
517
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
518
+ self.style_convs[::2],
519
+ self.style_convs[1::2],
520
+ noise[1::2],
521
+ noise[2::2],
522
+ self.to_rgbs,
523
+ ):
524
+ out = conv1(out, latent[:, i], noise=noise1)
525
+ out = conv2(out, latent[:, i + 1], noise=noise2)
526
+ skip = to_rgb(out, latent[:, i + 2], skip)
527
+ i += 2
528
+
529
+ image = skip
530
+
531
+ if return_latents:
532
+ return image, latent
533
+ else:
534
+ return image, None
535
+
536
+
537
+ class ScaledLeakyReLU(nn.Module):
538
+ """Scaled LeakyReLU.
539
+ Args:
540
+ negative_slope (float): Negative slope. Default: 0.2.
541
+ """
542
+
543
+ def __init__(self, negative_slope=0.2):
544
+ super(ScaledLeakyReLU, self).__init__()
545
+ self.negative_slope = negative_slope
546
+
547
+ def forward(self, x):
548
+ out = F.leaky_relu(x, negative_slope=self.negative_slope)
549
+ return out * math.sqrt(2)
550
+
551
+
552
+ class EqualConv2d(nn.Module):
553
+ """Equalized Linear as StyleGAN2.
554
+ Args:
555
+ in_channels (int): Channel number of the input.
556
+ out_channels (int): Channel number of the output.
557
+ kernel_size (int): Size of the convolving kernel.
558
+ stride (int): Stride of the convolution. Default: 1
559
+ padding (int): Zero-padding added to both sides of the input.
560
+ Default: 0.
561
+ bias (bool): If ``True``, adds a learnable bias to the output.
562
+ Default: ``True``.
563
+ bias_init_val (float): Bias initialized value. Default: 0.
564
+ """
565
+
566
+ def __init__(
567
+ self,
568
+ in_channels,
569
+ out_channels,
570
+ kernel_size,
571
+ stride=1,
572
+ padding=0,
573
+ bias=True,
574
+ bias_init_val=0,
575
+ ):
576
+ super(EqualConv2d, self).__init__()
577
+ self.in_channels = in_channels
578
+ self.out_channels = out_channels
579
+ self.kernel_size = kernel_size
580
+ self.stride = stride
581
+ self.padding = padding
582
+ self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
583
+
584
+ self.weight = nn.Parameter(
585
+ torch.randn(out_channels, in_channels, kernel_size, kernel_size)
586
+ )
587
+ if bias:
588
+ self.bias = nn.Parameter(torch.zeros(out_channels).fill_(bias_init_val))
589
+ else:
590
+ self.register_parameter("bias", None)
591
+
592
+ def forward(self, x):
593
+ out = F.conv2d(
594
+ x,
595
+ self.weight * self.scale,
596
+ bias=self.bias,
597
+ stride=self.stride,
598
+ padding=self.padding,
599
+ )
600
+
601
+ return out
602
+
603
+ def __repr__(self):
604
+ return (
605
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, "
606
+ f"out_channels={self.out_channels}, "
607
+ f"kernel_size={self.kernel_size},"
608
+ f" stride={self.stride}, padding={self.padding}, "
609
+ f"bias={self.bias is not None})"
610
+ )
611
+
612
+
613
+ class ConvLayer(nn.Sequential):
614
+ """Conv Layer used in StyleGAN2 Discriminator.
615
+ Args:
616
+ in_channels (int): Channel number of the input.
617
+ out_channels (int): Channel number of the output.
618
+ kernel_size (int): Kernel size.
619
+ downsample (bool): Whether downsample by a factor of 2.
620
+ Default: False.
621
+ bias (bool): Whether with bias. Default: True.
622
+ activate (bool): Whether use activateion. Default: True.
623
+ """
624
+
625
+ def __init__(
626
+ self,
627
+ in_channels,
628
+ out_channels,
629
+ kernel_size,
630
+ downsample=False,
631
+ bias=True,
632
+ activate=True,
633
+ interpolation_mode="bilinear",
634
+ ):
635
+ layers = []
636
+ self.interpolation_mode = interpolation_mode
637
+ # downsample
638
+ if downsample:
639
+ if self.interpolation_mode == "nearest":
640
+ self.align_corners = None
641
+ else:
642
+ self.align_corners = False
643
+
644
+ layers.append(
645
+ torch.nn.Upsample(
646
+ scale_factor=0.5,
647
+ mode=interpolation_mode,
648
+ align_corners=self.align_corners,
649
+ )
650
+ )
651
+ stride = 1
652
+ self.padding = kernel_size // 2
653
+ # conv
654
+ layers.append(
655
+ EqualConv2d(
656
+ in_channels,
657
+ out_channels,
658
+ kernel_size,
659
+ stride=stride,
660
+ padding=self.padding,
661
+ bias=bias and not activate,
662
+ )
663
+ )
664
+ # activation
665
+ if activate:
666
+ if bias:
667
+ layers.append(FusedLeakyReLU(out_channels))
668
+ else:
669
+ layers.append(ScaledLeakyReLU(0.2))
670
+
671
+ super(ConvLayer, self).__init__(*layers)
672
+
673
+
674
+ class ResBlock(nn.Module):
675
+ """Residual block used in StyleGAN2 Discriminator.
676
+ Args:
677
+ in_channels (int): Channel number of the input.
678
+ out_channels (int): Channel number of the output.
679
+ """
680
+
681
+ def __init__(self, in_channels, out_channels, interpolation_mode="bilinear"):
682
+ super(ResBlock, self).__init__()
683
+
684
+ self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
685
+ self.conv2 = ConvLayer(
686
+ in_channels,
687
+ out_channels,
688
+ 3,
689
+ downsample=True,
690
+ interpolation_mode=interpolation_mode,
691
+ bias=True,
692
+ activate=True,
693
+ )
694
+ self.skip = ConvLayer(
695
+ in_channels,
696
+ out_channels,
697
+ 1,
698
+ downsample=True,
699
+ interpolation_mode=interpolation_mode,
700
+ bias=False,
701
+ activate=False,
702
+ )
703
+
704
+ def forward(self, x):
705
+ out = self.conv1(x)
706
+ out = self.conv2(out)
707
+ skip = self.skip(x)
708
+ out = (out + skip) / math.sqrt(2)
709
+ return out
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/stylegan2_clean_arch.py ADDED
@@ -0,0 +1,453 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ import math
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+ from torch.nn import init
9
+ from torch.nn.modules.batchnorm import _BatchNorm
10
+
11
+
12
+ @torch.no_grad()
13
+ def default_init_weights(module_list, scale=1, bias_fill=0, **kwargs):
14
+ """Initialize network weights.
15
+ Args:
16
+ module_list (list[nn.Module] | nn.Module): Modules to be initialized.
17
+ scale (float): Scale initialized weights, especially for residual
18
+ blocks. Default: 1.
19
+ bias_fill (float): The value to fill bias. Default: 0
20
+ kwargs (dict): Other arguments for initialization function.
21
+ """
22
+ if not isinstance(module_list, list):
23
+ module_list = [module_list]
24
+ for module in module_list:
25
+ for m in module.modules():
26
+ if isinstance(m, nn.Conv2d):
27
+ init.kaiming_normal_(m.weight, **kwargs)
28
+ m.weight.data *= scale
29
+ if m.bias is not None:
30
+ m.bias.data.fill_(bias_fill)
31
+ elif isinstance(m, nn.Linear):
32
+ init.kaiming_normal_(m.weight, **kwargs)
33
+ m.weight.data *= scale
34
+ if m.bias is not None:
35
+ m.bias.data.fill_(bias_fill)
36
+ elif isinstance(m, _BatchNorm):
37
+ init.constant_(m.weight, 1)
38
+ if m.bias is not None:
39
+ m.bias.data.fill_(bias_fill)
40
+
41
+
42
+ class NormStyleCode(nn.Module):
43
+ def forward(self, x):
44
+ """Normalize the style codes.
45
+ Args:
46
+ x (Tensor): Style codes with shape (b, c).
47
+ Returns:
48
+ Tensor: Normalized tensor.
49
+ """
50
+ return x * torch.rsqrt(torch.mean(x**2, dim=1, keepdim=True) + 1e-8)
51
+
52
+
53
+ class ModulatedConv2d(nn.Module):
54
+ """Modulated Conv2d used in StyleGAN2.
55
+ There is no bias in ModulatedConv2d.
56
+ Args:
57
+ in_channels (int): Channel number of the input.
58
+ out_channels (int): Channel number of the output.
59
+ kernel_size (int): Size of the convolving kernel.
60
+ num_style_feat (int): Channel number of style features.
61
+ demodulate (bool): Whether to demodulate in the conv layer. Default: True.
62
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
63
+ eps (float): A value added to the denominator for numerical stability. Default: 1e-8.
64
+ """
65
+
66
+ def __init__(
67
+ self,
68
+ in_channels,
69
+ out_channels,
70
+ kernel_size,
71
+ num_style_feat,
72
+ demodulate=True,
73
+ sample_mode=None,
74
+ eps=1e-8,
75
+ ):
76
+ super(ModulatedConv2d, self).__init__()
77
+ self.in_channels = in_channels
78
+ self.out_channels = out_channels
79
+ self.kernel_size = kernel_size
80
+ self.demodulate = demodulate
81
+ self.sample_mode = sample_mode
82
+ self.eps = eps
83
+
84
+ # modulation inside each modulated conv
85
+ self.modulation = nn.Linear(num_style_feat, in_channels, bias=True)
86
+ # initialization
87
+ default_init_weights(
88
+ self.modulation,
89
+ scale=1,
90
+ bias_fill=1,
91
+ a=0,
92
+ mode="fan_in",
93
+ nonlinearity="linear",
94
+ )
95
+
96
+ self.weight = nn.Parameter(
97
+ torch.randn(1, out_channels, in_channels, kernel_size, kernel_size)
98
+ / math.sqrt(in_channels * kernel_size**2)
99
+ )
100
+ self.padding = kernel_size // 2
101
+
102
+ def forward(self, x, style):
103
+ """Forward function.
104
+ Args:
105
+ x (Tensor): Tensor with shape (b, c, h, w).
106
+ style (Tensor): Tensor with shape (b, num_style_feat).
107
+ Returns:
108
+ Tensor: Modulated tensor after convolution.
109
+ """
110
+ b, c, h, w = x.shape # c = c_in
111
+ # weight modulation
112
+ style = self.modulation(style).view(b, 1, c, 1, 1)
113
+ # self.weight: (1, c_out, c_in, k, k); style: (b, 1, c, 1, 1)
114
+ weight = self.weight * style # (b, c_out, c_in, k, k)
115
+
116
+ if self.demodulate:
117
+ demod = torch.rsqrt(weight.pow(2).sum([2, 3, 4]) + self.eps)
118
+ weight = weight * demod.view(b, self.out_channels, 1, 1, 1)
119
+
120
+ weight = weight.view(
121
+ b * self.out_channels, c, self.kernel_size, self.kernel_size
122
+ )
123
+
124
+ # upsample or downsample if necessary
125
+ if self.sample_mode == "upsample":
126
+ x = F.interpolate(x, scale_factor=2, mode="bilinear", align_corners=False)
127
+ elif self.sample_mode == "downsample":
128
+ x = F.interpolate(x, scale_factor=0.5, mode="bilinear", align_corners=False)
129
+
130
+ b, c, h, w = x.shape
131
+ x = x.view(1, b * c, h, w)
132
+ # weight: (b*c_out, c_in, k, k), groups=b
133
+ out = F.conv2d(x, weight, padding=self.padding, groups=b)
134
+ out = out.view(b, self.out_channels, *out.shape[2:4])
135
+
136
+ return out
137
+
138
+ def __repr__(self):
139
+ return (
140
+ f"{self.__class__.__name__}(in_channels={self.in_channels}, out_channels={self.out_channels}, "
141
+ f"kernel_size={self.kernel_size}, demodulate={self.demodulate}, sample_mode={self.sample_mode})"
142
+ )
143
+
144
+
145
+ class StyleConv(nn.Module):
146
+ """Style conv used in StyleGAN2.
147
+ Args:
148
+ in_channels (int): Channel number of the input.
149
+ out_channels (int): Channel number of the output.
150
+ kernel_size (int): Size of the convolving kernel.
151
+ num_style_feat (int): Channel number of style features.
152
+ demodulate (bool): Whether demodulate in the conv layer. Default: True.
153
+ sample_mode (str | None): Indicating 'upsample', 'downsample' or None. Default: None.
154
+ """
155
+
156
+ def __init__(
157
+ self,
158
+ in_channels,
159
+ out_channels,
160
+ kernel_size,
161
+ num_style_feat,
162
+ demodulate=True,
163
+ sample_mode=None,
164
+ ):
165
+ super(StyleConv, self).__init__()
166
+ self.modulated_conv = ModulatedConv2d(
167
+ in_channels,
168
+ out_channels,
169
+ kernel_size,
170
+ num_style_feat,
171
+ demodulate=demodulate,
172
+ sample_mode=sample_mode,
173
+ )
174
+ self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
175
+ self.bias = nn.Parameter(torch.zeros(1, out_channels, 1, 1))
176
+ self.activate = nn.LeakyReLU(negative_slope=0.2, inplace=True)
177
+
178
+ def forward(self, x, style, noise=None):
179
+ # modulate
180
+ out = self.modulated_conv(x, style) * 2**0.5 # for conversion
181
+ # noise injection
182
+ if noise is None:
183
+ b, _, h, w = out.shape
184
+ noise = out.new_empty(b, 1, h, w).normal_()
185
+ out = out + self.weight * noise
186
+ # add bias
187
+ out = out + self.bias
188
+ # activation
189
+ out = self.activate(out)
190
+ return out
191
+
192
+
193
+ class ToRGB(nn.Module):
194
+ """To RGB (image space) from features.
195
+ Args:
196
+ in_channels (int): Channel number of input.
197
+ num_style_feat (int): Channel number of style features.
198
+ upsample (bool): Whether to upsample. Default: True.
199
+ """
200
+
201
+ def __init__(self, in_channels, num_style_feat, upsample=True):
202
+ super(ToRGB, self).__init__()
203
+ self.upsample = upsample
204
+ self.modulated_conv = ModulatedConv2d(
205
+ in_channels,
206
+ 3,
207
+ kernel_size=1,
208
+ num_style_feat=num_style_feat,
209
+ demodulate=False,
210
+ sample_mode=None,
211
+ )
212
+ self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
213
+
214
+ def forward(self, x, style, skip=None):
215
+ """Forward function.
216
+ Args:
217
+ x (Tensor): Feature tensor with shape (b, c, h, w).
218
+ style (Tensor): Tensor with shape (b, num_style_feat).
219
+ skip (Tensor): Base/skip tensor. Default: None.
220
+ Returns:
221
+ Tensor: RGB images.
222
+ """
223
+ out = self.modulated_conv(x, style)
224
+ out = out + self.bias
225
+ if skip is not None:
226
+ if self.upsample:
227
+ skip = F.interpolate(
228
+ skip, scale_factor=2, mode="bilinear", align_corners=False
229
+ )
230
+ out = out + skip
231
+ return out
232
+
233
+
234
+ class ConstantInput(nn.Module):
235
+ """Constant input.
236
+ Args:
237
+ num_channel (int): Channel number of constant input.
238
+ size (int): Spatial size of constant input.
239
+ """
240
+
241
+ def __init__(self, num_channel, size):
242
+ super(ConstantInput, self).__init__()
243
+ self.weight = nn.Parameter(torch.randn(1, num_channel, size, size))
244
+
245
+ def forward(self, batch):
246
+ out = self.weight.repeat(batch, 1, 1, 1)
247
+ return out
248
+
249
+
250
+ class StyleGAN2GeneratorClean(nn.Module):
251
+ """Clean version of StyleGAN2 Generator.
252
+ Args:
253
+ out_size (int): The spatial size of outputs.
254
+ num_style_feat (int): Channel number of style features. Default: 512.
255
+ num_mlp (int): Layer number of MLP style layers. Default: 8.
256
+ channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
257
+ narrow (float): Narrow ratio for channels. Default: 1.0.
258
+ """
259
+
260
+ def __init__(
261
+ self, out_size, num_style_feat=512, num_mlp=8, channel_multiplier=2, narrow=1
262
+ ):
263
+ super(StyleGAN2GeneratorClean, self).__init__()
264
+ # Style MLP layers
265
+ self.num_style_feat = num_style_feat
266
+ style_mlp_layers = [NormStyleCode()]
267
+ for i in range(num_mlp):
268
+ style_mlp_layers.extend(
269
+ [
270
+ nn.Linear(num_style_feat, num_style_feat, bias=True),
271
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
272
+ ]
273
+ )
274
+ self.style_mlp = nn.Sequential(*style_mlp_layers)
275
+ # initialization
276
+ default_init_weights(
277
+ self.style_mlp,
278
+ scale=1,
279
+ bias_fill=0,
280
+ a=0.2,
281
+ mode="fan_in",
282
+ nonlinearity="leaky_relu",
283
+ )
284
+
285
+ # channel list
286
+ channels = {
287
+ "4": int(512 * narrow),
288
+ "8": int(512 * narrow),
289
+ "16": int(512 * narrow),
290
+ "32": int(512 * narrow),
291
+ "64": int(256 * channel_multiplier * narrow),
292
+ "128": int(128 * channel_multiplier * narrow),
293
+ "256": int(64 * channel_multiplier * narrow),
294
+ "512": int(32 * channel_multiplier * narrow),
295
+ "1024": int(16 * channel_multiplier * narrow),
296
+ }
297
+ self.channels = channels
298
+
299
+ self.constant_input = ConstantInput(channels["4"], size=4)
300
+ self.style_conv1 = StyleConv(
301
+ channels["4"],
302
+ channels["4"],
303
+ kernel_size=3,
304
+ num_style_feat=num_style_feat,
305
+ demodulate=True,
306
+ sample_mode=None,
307
+ )
308
+ self.to_rgb1 = ToRGB(channels["4"], num_style_feat, upsample=False)
309
+
310
+ self.log_size = int(math.log(out_size, 2))
311
+ self.num_layers = (self.log_size - 2) * 2 + 1
312
+ self.num_latent = self.log_size * 2 - 2
313
+
314
+ self.style_convs = nn.ModuleList()
315
+ self.to_rgbs = nn.ModuleList()
316
+ self.noises = nn.Module()
317
+
318
+ in_channels = channels["4"]
319
+ # noise
320
+ for layer_idx in range(self.num_layers):
321
+ resolution = 2 ** ((layer_idx + 5) // 2)
322
+ shape = [1, 1, resolution, resolution]
323
+ self.noises.register_buffer(f"noise{layer_idx}", torch.randn(*shape))
324
+ # style convs and to_rgbs
325
+ for i in range(3, self.log_size + 1):
326
+ out_channels = channels[f"{2**i}"]
327
+ self.style_convs.append(
328
+ StyleConv(
329
+ in_channels,
330
+ out_channels,
331
+ kernel_size=3,
332
+ num_style_feat=num_style_feat,
333
+ demodulate=True,
334
+ sample_mode="upsample",
335
+ )
336
+ )
337
+ self.style_convs.append(
338
+ StyleConv(
339
+ out_channels,
340
+ out_channels,
341
+ kernel_size=3,
342
+ num_style_feat=num_style_feat,
343
+ demodulate=True,
344
+ sample_mode=None,
345
+ )
346
+ )
347
+ self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True))
348
+ in_channels = out_channels
349
+
350
+ def make_noise(self):
351
+ """Make noise for noise injection."""
352
+ device = self.constant_input.weight.device
353
+ noises = [torch.randn(1, 1, 4, 4, device=device)]
354
+
355
+ for i in range(3, self.log_size + 1):
356
+ for _ in range(2):
357
+ noises.append(torch.randn(1, 1, 2**i, 2**i, device=device))
358
+
359
+ return noises
360
+
361
+ def get_latent(self, x):
362
+ return self.style_mlp(x)
363
+
364
+ def mean_latent(self, num_latent):
365
+ latent_in = torch.randn(
366
+ num_latent, self.num_style_feat, device=self.constant_input.weight.device
367
+ )
368
+ latent = self.style_mlp(latent_in).mean(0, keepdim=True)
369
+ return latent
370
+
371
+ def forward(
372
+ self,
373
+ styles,
374
+ input_is_latent=False,
375
+ noise=None,
376
+ randomize_noise=True,
377
+ truncation=1,
378
+ truncation_latent=None,
379
+ inject_index=None,
380
+ return_latents=False,
381
+ ):
382
+ """Forward function for StyleGAN2GeneratorClean.
383
+ Args:
384
+ styles (list[Tensor]): Sample codes of styles.
385
+ input_is_latent (bool): Whether input is latent style. Default: False.
386
+ noise (Tensor | None): Input noise or None. Default: None.
387
+ randomize_noise (bool): Randomize noise, used when 'noise' is False. Default: True.
388
+ truncation (float): The truncation ratio. Default: 1.
389
+ truncation_latent (Tensor | None): The truncation latent tensor. Default: None.
390
+ inject_index (int | None): The injection index for mixing noise. Default: None.
391
+ return_latents (bool): Whether to return style latents. Default: False.
392
+ """
393
+ # style codes -> latents with Style MLP layer
394
+ if not input_is_latent:
395
+ styles = [self.style_mlp(s) for s in styles]
396
+ # noises
397
+ if noise is None:
398
+ if randomize_noise:
399
+ noise = [None] * self.num_layers # for each style conv layer
400
+ else: # use the stored noise
401
+ noise = [
402
+ getattr(self.noises, f"noise{i}") for i in range(self.num_layers)
403
+ ]
404
+ # style truncation
405
+ if truncation < 1:
406
+ style_truncation = []
407
+ for style in styles:
408
+ style_truncation.append(
409
+ truncation_latent + truncation * (style - truncation_latent)
410
+ )
411
+ styles = style_truncation
412
+ # get style latents with injection
413
+ if len(styles) == 1:
414
+ inject_index = self.num_latent
415
+
416
+ if styles[0].ndim < 3:
417
+ # repeat latent code for all the layers
418
+ latent = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
419
+ else: # used for encoder with different latent code for each layer
420
+ latent = styles[0]
421
+ elif len(styles) == 2: # mixing noises
422
+ if inject_index is None:
423
+ inject_index = random.randint(1, self.num_latent - 1)
424
+ latent1 = styles[0].unsqueeze(1).repeat(1, inject_index, 1)
425
+ latent2 = (
426
+ styles[1].unsqueeze(1).repeat(1, self.num_latent - inject_index, 1)
427
+ )
428
+ latent = torch.cat([latent1, latent2], 1)
429
+
430
+ # main generation
431
+ out = self.constant_input(latent.shape[0])
432
+ out = self.style_conv1(out, latent[:, 0], noise=noise[0])
433
+ skip = self.to_rgb1(out, latent[:, 1])
434
+
435
+ i = 1
436
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
437
+ self.style_convs[::2],
438
+ self.style_convs[1::2],
439
+ noise[1::2],
440
+ noise[2::2],
441
+ self.to_rgbs,
442
+ ):
443
+ out = conv1(out, latent[:, i], noise=noise1)
444
+ out = conv2(out, latent[:, i + 1], noise=noise2)
445
+ skip = to_rgb(out, latent[:, i + 2], skip) # feature back to the rgb space
446
+ i += 2
447
+
448
+ image = skip
449
+
450
+ if return_latents:
451
+ return image, latent
452
+ else:
453
+ return image, None
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/face/upfirdn2d.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # pylint: skip-file
2
+ # type: ignore
3
+ # modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
4
+
5
+ import os
6
+
7
+ import torch
8
+ from torch.autograd import Function
9
+ from torch.nn import functional as F
10
+
11
+ upfirdn2d_ext = None
12
+
13
+
14
+ class UpFirDn2dBackward(Function):
15
+ @staticmethod
16
+ def forward(
17
+ ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size
18
+ ):
19
+ up_x, up_y = up
20
+ down_x, down_y = down
21
+ g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1 = g_pad
22
+
23
+ grad_output = grad_output.reshape(-1, out_size[0], out_size[1], 1)
24
+
25
+ grad_input = upfirdn2d_ext.upfirdn2d(
26
+ grad_output,
27
+ grad_kernel,
28
+ down_x,
29
+ down_y,
30
+ up_x,
31
+ up_y,
32
+ g_pad_x0,
33
+ g_pad_x1,
34
+ g_pad_y0,
35
+ g_pad_y1,
36
+ )
37
+ grad_input = grad_input.view(in_size[0], in_size[1], in_size[2], in_size[3])
38
+
39
+ ctx.save_for_backward(kernel)
40
+
41
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
42
+
43
+ ctx.up_x = up_x
44
+ ctx.up_y = up_y
45
+ ctx.down_x = down_x
46
+ ctx.down_y = down_y
47
+ ctx.pad_x0 = pad_x0
48
+ ctx.pad_x1 = pad_x1
49
+ ctx.pad_y0 = pad_y0
50
+ ctx.pad_y1 = pad_y1
51
+ ctx.in_size = in_size
52
+ ctx.out_size = out_size
53
+
54
+ return grad_input
55
+
56
+ @staticmethod
57
+ def backward(ctx, gradgrad_input):
58
+ (kernel,) = ctx.saved_tensors
59
+
60
+ gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
61
+
62
+ gradgrad_out = upfirdn2d_ext.upfirdn2d(
63
+ gradgrad_input,
64
+ kernel,
65
+ ctx.up_x,
66
+ ctx.up_y,
67
+ ctx.down_x,
68
+ ctx.down_y,
69
+ ctx.pad_x0,
70
+ ctx.pad_x1,
71
+ ctx.pad_y0,
72
+ ctx.pad_y1,
73
+ )
74
+ # gradgrad_out = gradgrad_out.view(ctx.in_size[0], ctx.out_size[0],
75
+ # ctx.out_size[1], ctx.in_size[3])
76
+ gradgrad_out = gradgrad_out.view(
77
+ ctx.in_size[0], ctx.in_size[1], ctx.out_size[0], ctx.out_size[1]
78
+ )
79
+
80
+ return gradgrad_out, None, None, None, None, None, None, None, None
81
+
82
+
83
+ class UpFirDn2d(Function):
84
+ @staticmethod
85
+ def forward(ctx, input, kernel, up, down, pad):
86
+ up_x, up_y = up
87
+ down_x, down_y = down
88
+ pad_x0, pad_x1, pad_y0, pad_y1 = pad
89
+
90
+ kernel_h, kernel_w = kernel.shape
91
+ _, channel, in_h, in_w = input.shape
92
+ ctx.in_size = input.shape
93
+
94
+ input = input.reshape(-1, in_h, in_w, 1)
95
+
96
+ ctx.save_for_backward(kernel, torch.flip(kernel, [0, 1]))
97
+
98
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
99
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
100
+ ctx.out_size = (out_h, out_w)
101
+
102
+ ctx.up = (up_x, up_y)
103
+ ctx.down = (down_x, down_y)
104
+ ctx.pad = (pad_x0, pad_x1, pad_y0, pad_y1)
105
+
106
+ g_pad_x0 = kernel_w - pad_x0 - 1
107
+ g_pad_y0 = kernel_h - pad_y0 - 1
108
+ g_pad_x1 = in_w * up_x - out_w * down_x + pad_x0 - up_x + 1
109
+ g_pad_y1 = in_h * up_y - out_h * down_y + pad_y0 - up_y + 1
110
+
111
+ ctx.g_pad = (g_pad_x0, g_pad_x1, g_pad_y0, g_pad_y1)
112
+
113
+ out = upfirdn2d_ext.upfirdn2d(
114
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
115
+ )
116
+ # out = out.view(major, out_h, out_w, minor)
117
+ out = out.view(-1, channel, out_h, out_w)
118
+
119
+ return out
120
+
121
+ @staticmethod
122
+ def backward(ctx, grad_output):
123
+ kernel, grad_kernel = ctx.saved_tensors
124
+
125
+ grad_input = UpFirDn2dBackward.apply(
126
+ grad_output,
127
+ kernel,
128
+ grad_kernel,
129
+ ctx.up,
130
+ ctx.down,
131
+ ctx.pad,
132
+ ctx.g_pad,
133
+ ctx.in_size,
134
+ ctx.out_size,
135
+ )
136
+
137
+ return grad_input, None, None, None, None
138
+
139
+
140
+ def upfirdn2d(input, kernel, up=1, down=1, pad=(0, 0)):
141
+ if input.device.type == "cpu":
142
+ out = upfirdn2d_native(
143
+ input, kernel, up, up, down, down, pad[0], pad[1], pad[0], pad[1]
144
+ )
145
+ else:
146
+ out = UpFirDn2d.apply(
147
+ input, kernel, (up, up), (down, down), (pad[0], pad[1], pad[0], pad[1])
148
+ )
149
+
150
+ return out
151
+
152
+
153
+ def upfirdn2d_native(
154
+ input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1, pad_y0, pad_y1
155
+ ):
156
+ _, channel, in_h, in_w = input.shape
157
+ input = input.reshape(-1, in_h, in_w, 1)
158
+
159
+ _, in_h, in_w, minor = input.shape
160
+ kernel_h, kernel_w = kernel.shape
161
+
162
+ out = input.view(-1, in_h, 1, in_w, 1, minor)
163
+ out = F.pad(out, [0, 0, 0, up_x - 1, 0, 0, 0, up_y - 1])
164
+ out = out.view(-1, in_h * up_y, in_w * up_x, minor)
165
+
166
+ out = F.pad(
167
+ out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)]
168
+ )
169
+ out = out[
170
+ :,
171
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
172
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
173
+ :,
174
+ ]
175
+
176
+ out = out.permute(0, 3, 1, 2)
177
+ out = out.reshape(
178
+ [-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1]
179
+ )
180
+ w = torch.flip(kernel, [0, 1]).view(1, 1, kernel_h, kernel_w)
181
+ out = F.conv2d(out, w)
182
+ out = out.reshape(
183
+ -1,
184
+ minor,
185
+ in_h * up_y + pad_y0 + pad_y1 - kernel_h + 1,
186
+ in_w * up_x + pad_x0 + pad_x1 - kernel_w + 1,
187
+ )
188
+ out = out.permute(0, 2, 3, 1)
189
+ out = out[:, ::down_y, ::down_x, :]
190
+
191
+ out_h = (in_h * up_y + pad_y0 + pad_y1 - kernel_h) // down_y + 1
192
+ out_w = (in_w * up_x + pad_x0 + pad_x1 - kernel_w) // down_x + 1
193
+
194
+ return out.view(-1, channel, out_h, out_w)
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2019 Ross Wightman
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/drop.py ADDED
@@ -0,0 +1,223 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ DropBlock, DropPath
2
+
3
+ PyTorch implementations of DropBlock and DropPath (Stochastic Depth) regularization layers.
4
+
5
+ Papers:
6
+ DropBlock: A regularization method for convolutional networks (https://arxiv.org/abs/1810.12890)
7
+
8
+ Deep Networks with Stochastic Depth (https://arxiv.org/abs/1603.09382)
9
+
10
+ Code:
11
+ DropBlock impl inspired by two Tensorflow impl that I liked:
12
+ - https://github.com/tensorflow/tpu/blob/master/models/official/resnet/resnet_model.py#L74
13
+ - https://github.com/clovaai/assembled-cnn/blob/master/nets/blocks.py
14
+
15
+ Hacked together by / Copyright 2020 Ross Wightman
16
+ """
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.nn.functional as F
20
+
21
+
22
+ def drop_block_2d(
23
+ x,
24
+ drop_prob: float = 0.1,
25
+ block_size: int = 7,
26
+ gamma_scale: float = 1.0,
27
+ with_noise: bool = False,
28
+ inplace: bool = False,
29
+ batchwise: bool = False,
30
+ ):
31
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
32
+
33
+ DropBlock with an experimental gaussian noise option. This layer has been tested on a few training
34
+ runs with success, but needs further validation and possibly optimization for lower runtime impact.
35
+ """
36
+ _, C, H, W = x.shape
37
+ total_size = W * H
38
+ clipped_block_size = min(block_size, min(W, H))
39
+ # seed_drop_rate, the gamma parameter
40
+ gamma = (
41
+ gamma_scale
42
+ * drop_prob
43
+ * total_size
44
+ / clipped_block_size**2
45
+ / ((W - block_size + 1) * (H - block_size + 1))
46
+ )
47
+
48
+ # Forces the block to be inside the feature map.
49
+ w_i, h_i = torch.meshgrid(
50
+ torch.arange(W).to(x.device), torch.arange(H).to(x.device)
51
+ )
52
+ valid_block = (
53
+ (w_i >= clipped_block_size // 2) & (w_i < W - (clipped_block_size - 1) // 2)
54
+ ) & ((h_i >= clipped_block_size // 2) & (h_i < H - (clipped_block_size - 1) // 2))
55
+ valid_block = torch.reshape(valid_block, (1, 1, H, W)).to(dtype=x.dtype)
56
+
57
+ if batchwise:
58
+ # one mask for whole batch, quite a bit faster
59
+ uniform_noise = torch.rand((1, C, H, W), dtype=x.dtype, device=x.device)
60
+ else:
61
+ uniform_noise = torch.rand_like(x)
62
+ block_mask = ((2 - gamma - valid_block + uniform_noise) >= 1).to(dtype=x.dtype)
63
+ block_mask = -F.max_pool2d(
64
+ -block_mask,
65
+ kernel_size=clipped_block_size, # block_size,
66
+ stride=1,
67
+ padding=clipped_block_size // 2,
68
+ )
69
+
70
+ if with_noise:
71
+ normal_noise = (
72
+ torch.randn((1, C, H, W), dtype=x.dtype, device=x.device)
73
+ if batchwise
74
+ else torch.randn_like(x)
75
+ )
76
+ if inplace:
77
+ x.mul_(block_mask).add_(normal_noise * (1 - block_mask))
78
+ else:
79
+ x = x * block_mask + normal_noise * (1 - block_mask)
80
+ else:
81
+ normalize_scale = (
82
+ block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-7)
83
+ ).to(x.dtype)
84
+ if inplace:
85
+ x.mul_(block_mask * normalize_scale)
86
+ else:
87
+ x = x * block_mask * normalize_scale
88
+ return x
89
+
90
+
91
+ def drop_block_fast_2d(
92
+ x: torch.Tensor,
93
+ drop_prob: float = 0.1,
94
+ block_size: int = 7,
95
+ gamma_scale: float = 1.0,
96
+ with_noise: bool = False,
97
+ inplace: bool = False,
98
+ ):
99
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf
100
+
101
+ DropBlock with an experimental gaussian noise option. Simplied from above without concern for valid
102
+ block mask at edges.
103
+ """
104
+ _, _, H, W = x.shape
105
+ total_size = W * H
106
+ clipped_block_size = min(block_size, min(W, H))
107
+ gamma = (
108
+ gamma_scale
109
+ * drop_prob
110
+ * total_size
111
+ / clipped_block_size**2
112
+ / ((W - block_size + 1) * (H - block_size + 1))
113
+ )
114
+
115
+ block_mask = torch.empty_like(x).bernoulli_(gamma)
116
+ block_mask = F.max_pool2d(
117
+ block_mask.to(x.dtype),
118
+ kernel_size=clipped_block_size,
119
+ stride=1,
120
+ padding=clipped_block_size // 2,
121
+ )
122
+
123
+ if with_noise:
124
+ normal_noise = torch.empty_like(x).normal_()
125
+ if inplace:
126
+ x.mul_(1.0 - block_mask).add_(normal_noise * block_mask)
127
+ else:
128
+ x = x * (1.0 - block_mask) + normal_noise * block_mask
129
+ else:
130
+ block_mask = 1 - block_mask
131
+ normalize_scale = (
132
+ block_mask.numel() / block_mask.to(dtype=torch.float32).sum().add(1e-6)
133
+ ).to(dtype=x.dtype)
134
+ if inplace:
135
+ x.mul_(block_mask * normalize_scale)
136
+ else:
137
+ x = x * block_mask * normalize_scale
138
+ return x
139
+
140
+
141
+ class DropBlock2d(nn.Module):
142
+ """DropBlock. See https://arxiv.org/pdf/1810.12890.pdf"""
143
+
144
+ def __init__(
145
+ self,
146
+ drop_prob: float = 0.1,
147
+ block_size: int = 7,
148
+ gamma_scale: float = 1.0,
149
+ with_noise: bool = False,
150
+ inplace: bool = False,
151
+ batchwise: bool = False,
152
+ fast: bool = True,
153
+ ):
154
+ super(DropBlock2d, self).__init__()
155
+ self.drop_prob = drop_prob
156
+ self.gamma_scale = gamma_scale
157
+ self.block_size = block_size
158
+ self.with_noise = with_noise
159
+ self.inplace = inplace
160
+ self.batchwise = batchwise
161
+ self.fast = fast # FIXME finish comparisons of fast vs not
162
+
163
+ def forward(self, x):
164
+ if not self.training or not self.drop_prob:
165
+ return x
166
+ if self.fast:
167
+ return drop_block_fast_2d(
168
+ x,
169
+ self.drop_prob,
170
+ self.block_size,
171
+ self.gamma_scale,
172
+ self.with_noise,
173
+ self.inplace,
174
+ )
175
+ else:
176
+ return drop_block_2d(
177
+ x,
178
+ self.drop_prob,
179
+ self.block_size,
180
+ self.gamma_scale,
181
+ self.with_noise,
182
+ self.inplace,
183
+ self.batchwise,
184
+ )
185
+
186
+
187
+ def drop_path(
188
+ x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
189
+ ):
190
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
191
+
192
+ This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
193
+ the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
194
+ See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
195
+ changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
196
+ 'survival rate' as the argument.
197
+
198
+ """
199
+ if drop_prob == 0.0 or not training:
200
+ return x
201
+ keep_prob = 1 - drop_prob
202
+ shape = (x.shape[0],) + (1,) * (
203
+ x.ndim - 1
204
+ ) # work with diff dim tensors, not just 2D ConvNets
205
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
206
+ if keep_prob > 0.0 and scale_by_keep:
207
+ random_tensor.div_(keep_prob)
208
+ return x * random_tensor
209
+
210
+
211
+ class DropPath(nn.Module):
212
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
213
+
214
+ def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
215
+ super(DropPath, self).__init__()
216
+ self.drop_prob = drop_prob
217
+ self.scale_by_keep = scale_by_keep
218
+
219
+ def forward(self, x):
220
+ return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
221
+
222
+ def extra_repr(self):
223
+ return f"drop_prob={round(self.drop_prob,3):0.3f}"
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/helpers.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Layer/Module Helpers
2
+ Hacked together by / Copyright 2020 Ross Wightman
3
+ """
4
+ import collections.abc
5
+ from itertools import repeat
6
+
7
+
8
+ # From PyTorch internals
9
+ def _ntuple(n):
10
+ def parse(x):
11
+ if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
12
+ return x
13
+ return tuple(repeat(x, n))
14
+
15
+ return parse
16
+
17
+
18
+ to_1tuple = _ntuple(1)
19
+ to_2tuple = _ntuple(2)
20
+ to_3tuple = _ntuple(3)
21
+ to_4tuple = _ntuple(4)
22
+ to_ntuple = _ntuple
23
+
24
+
25
+ def make_divisible(v, divisor=8, min_value=None, round_limit=0.9):
26
+ min_value = min_value or divisor
27
+ new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
28
+ # Make sure that round down does not go down by more than 10%.
29
+ if new_v < round_limit * v:
30
+ new_v += divisor
31
+ return new_v
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/architecture/timm/weight_init.py ADDED
@@ -0,0 +1,128 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import warnings
3
+
4
+ import torch
5
+ from torch.nn.init import _calculate_fan_in_and_fan_out
6
+
7
+
8
+ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
9
+ # Cut & paste from PyTorch official master until it's in a few official releases - RW
10
+ # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
11
+ def norm_cdf(x):
12
+ # Computes standard normal cumulative distribution function
13
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
14
+
15
+ if (mean < a - 2 * std) or (mean > b + 2 * std):
16
+ warnings.warn(
17
+ "mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
18
+ "The distribution of values may be incorrect.",
19
+ stacklevel=2,
20
+ )
21
+
22
+ with torch.no_grad():
23
+ # Values are generated by using a truncated uniform distribution and
24
+ # then using the inverse CDF for the normal distribution.
25
+ # Get upper and lower cdf values
26
+ l = norm_cdf((a - mean) / std)
27
+ u = norm_cdf((b - mean) / std)
28
+
29
+ # Uniformly fill tensor with values from [l, u], then translate to
30
+ # [2l-1, 2u-1].
31
+ tensor.uniform_(2 * l - 1, 2 * u - 1)
32
+
33
+ # Use inverse cdf transform for normal distribution to get truncated
34
+ # standard normal
35
+ tensor.erfinv_()
36
+
37
+ # Transform to proper mean, std
38
+ tensor.mul_(std * math.sqrt(2.0))
39
+ tensor.add_(mean)
40
+
41
+ # Clamp to ensure it's in the proper range
42
+ tensor.clamp_(min=a, max=b)
43
+ return tensor
44
+
45
+
46
+ def trunc_normal_(
47
+ tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
48
+ ) -> torch.Tensor:
49
+ r"""Fills the input Tensor with values drawn from a truncated
50
+ normal distribution. The values are effectively drawn from the
51
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
52
+ with values outside :math:`[a, b]` redrawn until they are within
53
+ the bounds. The method used for generating the random values works
54
+ best when :math:`a \leq \text{mean} \leq b`.
55
+
56
+ NOTE: this impl is similar to the PyTorch trunc_normal_, the bounds [a, b] are
57
+ applied while sampling the normal with mean/std applied, therefore a, b args
58
+ should be adjusted to match the range of mean, std args.
59
+
60
+ Args:
61
+ tensor: an n-dimensional `torch.Tensor`
62
+ mean: the mean of the normal distribution
63
+ std: the standard deviation of the normal distribution
64
+ a: the minimum cutoff value
65
+ b: the maximum cutoff value
66
+ Examples:
67
+ >>> w = torch.empty(3, 5)
68
+ >>> nn.init.trunc_normal_(w)
69
+ """
70
+ return _no_grad_trunc_normal_(tensor, mean, std, a, b)
71
+
72
+
73
+ def trunc_normal_tf_(
74
+ tensor: torch.Tensor, mean=0.0, std=1.0, a=-2.0, b=2.0
75
+ ) -> torch.Tensor:
76
+ r"""Fills the input Tensor with values drawn from a truncated
77
+ normal distribution. The values are effectively drawn from the
78
+ normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
79
+ with values outside :math:`[a, b]` redrawn until they are within
80
+ the bounds. The method used for generating the random values works
81
+ best when :math:`a \leq \text{mean} \leq b`.
82
+
83
+ NOTE: this 'tf' variant behaves closer to Tensorflow / JAX impl where the
84
+ bounds [a, b] are applied when sampling the normal distribution with mean=0, std=1.0
85
+ and the result is subsquently scaled and shifted by the mean and std args.
86
+
87
+ Args:
88
+ tensor: an n-dimensional `torch.Tensor`
89
+ mean: the mean of the normal distribution
90
+ std: the standard deviation of the normal distribution
91
+ a: the minimum cutoff value
92
+ b: the maximum cutoff value
93
+ Examples:
94
+ >>> w = torch.empty(3, 5)
95
+ >>> nn.init.trunc_normal_(w)
96
+ """
97
+ _no_grad_trunc_normal_(tensor, 0, 1.0, a, b)
98
+ with torch.no_grad():
99
+ tensor.mul_(std).add_(mean)
100
+ return tensor
101
+
102
+
103
+ def variance_scaling_(tensor, scale=1.0, mode="fan_in", distribution="normal"):
104
+ fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
105
+ if mode == "fan_in":
106
+ denom = fan_in
107
+ elif mode == "fan_out":
108
+ denom = fan_out
109
+ elif mode == "fan_avg":
110
+ denom = (fan_in + fan_out) / 2
111
+
112
+ variance = scale / denom # type: ignore
113
+
114
+ if distribution == "truncated_normal":
115
+ # constant is stddev of standard normal truncated to (-2, 2)
116
+ trunc_normal_tf_(tensor, std=math.sqrt(variance) / 0.87962566103423978)
117
+ elif distribution == "normal":
118
+ tensor.normal_(std=math.sqrt(variance))
119
+ elif distribution == "uniform":
120
+ bound = math.sqrt(3 * variance)
121
+ # pylint: disable=invalid-unary-operand-type
122
+ tensor.uniform_(-bound, bound)
123
+ else:
124
+ raise ValueError(f"invalid distribution {distribution}")
125
+
126
+
127
+ def lecun_normal_(tensor):
128
+ variance_scaling_(tensor, mode="fan_in", distribution="truncated_normal")
LayerDiffuse-gradio-unofficial/ComfyUI/comfy_extras/chainner_models/model_loading.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging as logger
2
+
3
+ from .architecture.DAT import DAT
4
+ from .architecture.face.codeformer import CodeFormer
5
+ from .architecture.face.gfpganv1_clean_arch import GFPGANv1Clean
6
+ from .architecture.face.restoreformer_arch import RestoreFormer
7
+ from .architecture.HAT import HAT
8
+ from .architecture.LaMa import LaMa
9
+ from .architecture.OmniSR.OmniSR import OmniSR
10
+ from .architecture.RRDB import RRDBNet as ESRGAN
11
+ from .architecture.SCUNet import SCUNet
12
+ from .architecture.SPSR import SPSRNet as SPSR
13
+ from .architecture.SRVGG import SRVGGNetCompact as RealESRGANv2
14
+ from .architecture.SwiftSRGAN import Generator as SwiftSRGAN
15
+ from .architecture.Swin2SR import Swin2SR
16
+ from .architecture.SwinIR import SwinIR
17
+ from .types import PyTorchModel
18
+
19
+
20
+ class UnsupportedModel(Exception):
21
+ pass
22
+
23
+
24
+ def load_state_dict(state_dict) -> PyTorchModel:
25
+ logger.debug(f"Loading state dict into pytorch model arch")
26
+
27
+ state_dict_keys = list(state_dict.keys())
28
+
29
+ if "params_ema" in state_dict_keys:
30
+ state_dict = state_dict["params_ema"]
31
+ elif "params-ema" in state_dict_keys:
32
+ state_dict = state_dict["params-ema"]
33
+ elif "params" in state_dict_keys:
34
+ state_dict = state_dict["params"]
35
+
36
+ state_dict_keys = list(state_dict.keys())
37
+ # SRVGGNet Real-ESRGAN (v2)
38
+ if "body.0.weight" in state_dict_keys and "body.1.weight" in state_dict_keys:
39
+ model = RealESRGANv2(state_dict)
40
+ # SPSR (ESRGAN with lots of extra layers)
41
+ elif "f_HR_conv1.0.weight" in state_dict:
42
+ model = SPSR(state_dict)
43
+ # Swift-SRGAN
44
+ elif (
45
+ "model" in state_dict_keys
46
+ and "initial.cnn.depthwise.weight" in state_dict["model"].keys()
47
+ ):
48
+ model = SwiftSRGAN(state_dict)
49
+ # SwinIR, Swin2SR, HAT
50
+ elif "layers.0.residual_group.blocks.0.norm1.weight" in state_dict_keys:
51
+ if (
52
+ "layers.0.residual_group.blocks.0.conv_block.cab.0.weight"
53
+ in state_dict_keys
54
+ ):
55
+ model = HAT(state_dict)
56
+ elif "patch_embed.proj.weight" in state_dict_keys:
57
+ model = Swin2SR(state_dict)
58
+ else:
59
+ model = SwinIR(state_dict)
60
+ # GFPGAN
61
+ elif (
62
+ "toRGB.0.weight" in state_dict_keys
63
+ and "stylegan_decoder.style_mlp.1.weight" in state_dict_keys
64
+ ):
65
+ model = GFPGANv1Clean(state_dict)
66
+ # RestoreFormer
67
+ elif (
68
+ "encoder.conv_in.weight" in state_dict_keys
69
+ and "encoder.down.0.block.0.norm1.weight" in state_dict_keys
70
+ ):
71
+ model = RestoreFormer(state_dict)
72
+ elif (
73
+ "encoder.blocks.0.weight" in state_dict_keys
74
+ and "quantize.embedding.weight" in state_dict_keys
75
+ ):
76
+ model = CodeFormer(state_dict)
77
+ # LaMa
78
+ elif (
79
+ "model.model.1.bn_l.running_mean" in state_dict_keys
80
+ or "generator.model.1.bn_l.running_mean" in state_dict_keys
81
+ ):
82
+ model = LaMa(state_dict)
83
+ # Omni-SR
84
+ elif "residual_layer.0.residual_layer.0.layer.0.fn.0.weight" in state_dict_keys:
85
+ model = OmniSR(state_dict)
86
+ # SCUNet
87
+ elif "m_head.0.weight" in state_dict_keys and "m_tail.0.weight" in state_dict_keys:
88
+ model = SCUNet(state_dict)
89
+ # DAT
90
+ elif "layers.0.blocks.2.attn.attn_mask_0" in state_dict_keys:
91
+ model = DAT(state_dict)
92
+ # Regular ESRGAN, "new-arch" ESRGAN, Real-ESRGAN v1
93
+ else:
94
+ try:
95
+ model = ESRGAN(state_dict)
96
+ except:
97
+ # pylint: disable=raise-missing-from
98
+ raise UnsupportedModel
99
+ return model