recoilme committed
Commit e50147a · verified · 1 Parent(s): ab501b3

Upload folder using huggingface_hub

stae.py ADDED
@@ -0,0 +1,228 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class RMSNorm2d(nn.Module):
+     def __init__(self, channels, eps=1e-8, affine=True):
+         super().__init__()
+         self.eps = eps
+         self.affine = affine
+         if affine:
+             self.weight = nn.Parameter(torch.ones(channels))
+         else:
+             self.register_parameter("weight", None)
+
+     def forward(self, x):
+         norm = x.pow(2).mean(dim=1, keepdim=True).add(self.eps).rsqrt()
+         x = x * norm
+         if self.affine:
+             x = x * self.weight[:, None, None]
+         return x
+
+ class ConvMlp(nn.Module):
+     def __init__(self, in_features, hidden_features=None, out_features=None):
+         super().__init__()
+         # fall back to in_features when sizes are omitted (None would crash Conv2d)
+         hidden_features = hidden_features or in_features
+         out_features = out_features or in_features
+         self.model = nn.Sequential(
+             nn.Conv2d(in_channels=in_features, out_channels=hidden_features, kernel_size=1),
+             nn.GELU(),
+             nn.Conv2d(in_channels=hidden_features, out_channels=out_features, kernel_size=1),
+         )
+
+     def forward(self, x):
+         return self.model(x)
+
+ class GegluMlp(nn.Module):
+     def __init__(self, hidden_dim):
+         super().__init__()
+         self.conv_up = nn.Conv2d(hidden_dim, hidden_dim * 4, kernel_size=1)
+         self.conv_down = nn.Conv2d(hidden_dim * 2, hidden_dim, kernel_size=1)
+         self.activation = nn.GELU(approximate="tanh")
+
+     def forward(self, x):
+         x = self.conv_up(x)
+         x_gate, x_act = torch.chunk(x, 2, dim=1)
+         x = self.activation(x_act) * x_gate
+         x = self.conv_down(x)
+         return x
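+ # (Gating note: conv_up expands to 4*hidden_dim, torch.chunk splits that into
+ # two 2*hidden_dim halves, and the tanh-GELU of one half multiplicatively
+ # gates the other before conv_down projects back down: the GEGLU pattern,
+ # implemented here with 1x1 convs.)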
+
+ class EncoderBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.norm = RMSNorm2d(channels)
+         hidden_dim = channels
+         self.mlp = GegluMlp(hidden_dim)
+
+     def forward(self, x):
+         norm = self.norm(x)
+         mlp_out = self.mlp(norm)
+         x = x + mlp_out
+         return x
+
+ class DecoderBlock(nn.Module):
+     def __init__(self, channels):
+         super().__init__()
+         self.norm = RMSNorm2d(channels)
+         self.mlp = nn.Sequential(
+             nn.Conv2d(channels, channels, kernel_size=1),
+             nn.GELU(approximate="tanh"),
+             nn.Conv2d(channels, channels, kernel_size=3, padding=1),
+         )
+
+     def forward(self, x):
+         norm = self.norm(x)
+         mlp_out = self.mlp(norm)
+         x = x + mlp_out
+         return x
+
+ class StupidEncoder(nn.Module):
+     def __init__(self,
+                  hidden_dim,
+                  in_channels,
+                  out_channels,
+                  patch_size,
+                  num_blocks):
+         super().__init__()
+
+         self.initial = nn.Sequential(
+             nn.Conv2d(in_channels, hidden_dim, patch_size, padding=0, stride=patch_size),
+         )
+
+         self.blocks = nn.ModuleList(EncoderBlock(hidden_dim) for _ in range(num_blocks))
+         self.out = ConvMlp(hidden_dim, hidden_dim, out_channels)
+
+     def forward(self, x):
+         x = self.initial(x)
+         for block in self.blocks:
+             x = block(x)
+         x = self.out(x)
+         return x
+
+ class NerfHead(nn.Module):
+     def __init__(self, patch_dim, mlp_dim):
+         super().__init__()
+         self.mlp_dim = mlp_dim
+         self.param_gen = nn.Linear(patch_dim, self.mlp_dim * self.mlp_dim * 2)
+         self.norm = nn.RMSNorm(self.mlp_dim)
+
+     def forward(self, pixels, patches):
+         bs = pixels.shape[0]
+         params = self.param_gen(patches)
+         layer1, layer2 = params.chunk(2, dim=-1)
+         layer1 = layer1.view(bs, self.mlp_dim, self.mlp_dim)
+         layer2 = layer2.view(bs, self.mlp_dim, self.mlp_dim)
+
+         layer1 = torch.nn.functional.normalize(layer1, dim=-2)
+
+         res_x = pixels
+         pixels = self.norm(pixels)
+         pixels = torch.bmm(pixels, layer1)
+         pixels = torch.nn.functional.silu(pixels)
+         pixels = torch.bmm(pixels, layer2)
+         pixels = pixels + res_x
+         return pixels
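+ # (NerfHead is a small hypernetwork: param_gen maps each patch latent to the
+ # weights of a per-patch two-layer MLP, which is applied with bmm to the
+ # shared positional queries; column-normalizing layer1 keeps the generated
+ # first layer well scaled, and the residual connection stabilizes stacking.)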
+
+ class StupidDecoder(nn.Module):
+     def __init__(self,
+                  hidden_dim,
+                  in_channels,
+                  out_channels,
+                  patch_size,
+                  num_blocks,
+                  nerf_blocks,
+                  mlp_dim):
+         super().__init__()
+
+         self.out_channels = out_channels
+
+         self.patch_size = patch_size
+         self.conv_in = ConvMlp(in_channels, hidden_dim, hidden_dim)
+         self.blocks = []
+         for _ in range(num_blocks):
+             self.blocks.append(DecoderBlock(hidden_dim))
+             self.blocks.append(EncoderBlock(hidden_dim))
+         self.blocks = nn.ModuleList(self.blocks)
+
+         self.nerf = nn.ModuleList(NerfHead(hidden_dim, mlp_dim) for _ in range(nerf_blocks))
+         self.positions = nn.Parameter(torch.randn(1, self.patch_size**2, mlp_dim))
+         self.last = nn.Linear(mlp_dim, self.out_channels)
+
+     def forward(self, x):
+         B, C, H, W = x.shape
+         x = self.conv_in(x)
+         for block in self.blocks:
+             x = block(x)
+
+         patches = x.flatten(2).transpose(1, 2)  # B C H W -> B (HW) C
+         patch_count = H * W
+         total_len = x.shape[0] * patch_count
+         patches = patches.reshape(total_len, -1)
+         x = self.positions.repeat(total_len, 1, 1)
+
+         for block in self.nerf:
+             x = block(x, patches)          # [B*patch_count, ps*ps, mlp_dim]
+         x = self.last(x)
+         x = x.transpose(1, 2)              # [B*patch_count, ps*ps, C] -> [B*patch_count, C, ps*ps]
+         x = x.reshape(B, patch_count, -1)  # -> [B, patch_count, C*ps*ps]
+         x = x.transpose(1, 2)              # -> [B, C*ps*ps, patch_count]
+         x = torch.nn.functional.fold(x.contiguous(),
+                                      (H * self.patch_size, W * self.patch_size),
+                                      kernel_size=self.patch_size,
+                                      stride=self.patch_size)
+
+         return x
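+ # (F.fold is the inverse of patch extraction here: each column of the
+ # [B, C*ps*ps, patch_count] tensor holds one decoded patch, and fold with
+ # kernel_size = stride = patch_size tiles the non-overlapping patches back
+ # into a [B, C, H*ps, W*ps] image.)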
+
+ class SimpleStupidDecoder(nn.Module):
+     def __init__(self,
+                  hidden_dim,
+                  in_channels,
+                  out_channels,
+                  patch_size,
+                  num_blocks):
+         super().__init__()
+
+         self.out_channels = out_channels
+         self.patch_size = patch_size
+
+         self.conv_in = ConvMlp(in_channels, hidden_dim, hidden_dim)
+         self.blocks = nn.ModuleList(DecoderBlock(hidden_dim) for _ in range(num_blocks))
+
+         self.last = nn.Sequential(
+             ConvMlp(hidden_dim, hidden_dim, out_channels * patch_size * patch_size),
+             nn.PixelShuffle(patch_size)
+         )
+
+     def forward(self, x):
+         x = self.conv_in(x)
+         for block in self.blocks:
+             x = block(x)
+         return self.last(x)
+
+ class StupidAE(nn.Module):
+     def __init__(self):
+         super().__init__()
+
+         self.encoder = nn.Sequential(
+             StupidEncoder(in_channels=3, out_channels=4, hidden_dim=1024, patch_size=8, num_blocks=2),
+         )
+         self.decoder = nn.Sequential(
+             StupidDecoder(in_channels=4, out_channels=3, hidden_dim=1024, patch_size=8, num_blocks=2, nerf_blocks=1, mlp_dim=32)
+         )
+
+     def encode(self, x):
+         return self.encoder(x)
+
+     def decode(self, x):
+         return self.decoder(x)
+
+     def forward(self, x):
+         x = self.encode(x)
+         x = self.decode(x)
+         return x
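+
+ # A minimal smoke-test sketch (an addition, not part of the uploaded file;
+ # assumes the defaults above, and nn.RMSNorm needs a recent PyTorch):
+ # round-trips a random batch and checks that the 8x-downsampled 4-channel
+ # latent decodes back to the input resolution.
+ if __name__ == "__main__":
+     ae = StupidAE()
+     img = torch.randn(2, 3, 256, 256)
+     lat = ae.encode(img)   # expected shape [2, 4, 32, 32]
+     out = ae.decode(lat)   # expected shape [2, 3, 256, 256]
+     print(lat.shape, out.shape)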
train_stae_fdl.py ADDED
@@ -0,0 +1,646 @@
+ # -*- coding: utf-8 -*-
+ import os
+ import math
+ import re
+ import torch
+ import numpy as np
+ import random
+ import gc
+ from datetime import datetime
+ from pathlib import Path
+
+ import torchvision.transforms as transforms
+ import torch.nn.functional as F
+ from torch.utils.data import DataLoader, Dataset
+ from torch.optim.lr_scheduler import LambdaLR
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
+ # QWEN: import the class
+ from diffusers import AutoencoderKLQwenImage
+ from diffusers import AutoencoderKLWan
+
+ from accelerate import Accelerator
+ from PIL import Image, UnidentifiedImageError
+ from tqdm import tqdm
+ import bitsandbytes as bnb
+ import wandb
+ import lpips  # pip install lpips
+ from FDL_pytorch import FDL_loss  # pip install fdl-pytorch
+ from collections import deque
+ from stae import StupidAE
+ from huggingface_hub import hf_hub_download
+ # load_file added alongside save_file: it is used below to restore weights
+ from safetensors.torch import save_file, load_file
+
+ # --------------------------- Parameters ---------------------------
+ ds_path = "/workspace/d23"
+ project = "stae3"
+ batch_size = 16
+ base_learning_rate = 5e-5
+ min_learning_rate = 1e-5
+ num_epochs = 10
+ sample_interval_share = 10
+ use_wandb = True
+ save_model = True
+ use_decay = True
+ optimizer_type = "adam8bit"
+ dtype = torch.float32
+
+ model_resolution = 256
+ high_resolution = 256
+ limit = 0
+ save_barrier = 1.3
+ warmup_percent = 0.005
+ percentile_clipping = 99
+ beta2 = 0.997
+ eps = 1e-8
+ clip_grad_norm = 1.0
+ mixed_precision = "no"
+ gradient_accumulation_steps = 1
+ generated_folder = "samples"
+ save_as = "stae3"
+ num_workers = 0
+ device = None
+
+ # --- Training modes ---
+ # QWEN: train only the decoder
+ train_decoder_only = False
+ train_up_only = False
+ full_training = True  # if True, train the whole VAE and add KL (below)
+ kl_ratio = 0.00
+
+ # Loss shares
+ loss_ratios = {
+     "lpips": 0.70,  # 0.50
+     "fdl": 0.10,    # 0.25
+     "edge": 0.05,
+     "mse": 0.10,
+     "mae": 0.05,
+     "kl": 0.0,  # enabled when full_training=True
+ }
+ median_coeff_steps = 256
+
+ resize_long_side = 1280  # resize the long side of the source images
+
+ # QWEN: model loading config
+ vae_kind = "kl"  # "qwen" or "kl" (regular)
+
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
+
+ accelerator = Accelerator(
+     mixed_precision=mixed_precision,
+     gradient_accumulation_steps=gradient_accumulation_steps
+ )
+ device = accelerator.device
+
+ # reproducibility
+ seed = int(datetime.now().strftime("%Y%m%d"))
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
+ torch.backends.cudnn.benchmark = False
+
+ # --------------------------- WandB ---------------------------
+ if use_wandb and accelerator.is_main_process:
+     wandb.init(project=project, config={
+         "batch_size": batch_size,
+         "base_learning_rate": base_learning_rate,
+         "num_epochs": num_epochs,
+         "optimizer_type": optimizer_type,
+         "model_resolution": model_resolution,
+         "high_resolution": high_resolution,
+         "gradient_accumulation_steps": gradient_accumulation_steps,
+         "train_decoder_only": train_decoder_only,
+         "full_training": full_training,
+         "kl_ratio": kl_ratio,
+         "vae_kind": vae_kind,
+     })
+
+ # --------------------------- VAE ---------------------------
+ def get_core_model(model):
+     m = model
+     # if the model is already wrapped by torch.compile
+     if hasattr(m, "_orig_mod"):
+         m = m._orig_mod
+     return m
+
+ def is_video_vae(model) -> bool:
+     # WAN/Qwen are video VAEs
+     if vae_kind in ("wan", "qwen"):
+         return True
+     # structural fallback (in case it is needed)
+     try:
+         core = get_core_model(model)
+         enc = getattr(core, "encoder", None)
+         conv_in = getattr(enc, "conv_in", None)
+         w = getattr(conv_in, "weight", None)
+         if isinstance(w, torch.nn.Parameter):
+             return w.ndim == 5
+     except Exception:
+         pass
+     return False
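+ # (Video VAEs take [B, C, T, H, W] input, which is why their conv_in weight is
+ # 5-D: that is what the structural fallback checks. The training loop below
+ # feeds single frames by pairing unsqueeze(2) with squeeze(2), i.e. T=1.)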
+
+ # loading
+ if vae_kind == "qwen":
+     vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
+ else:
+     if vae_kind == "wan":
+         vae = AutoencoderKLWan.from_pretrained(project)
+     else:
+         # old behavior (example)
+         if model_resolution == high_resolution:
+             #vae = AutoencoderKL.from_pretrained(project)
+             vae = StupidAE().cuda().half()
+
+             # 2. Path to the weights file (the same one used when saving)
+             load_path = os.path.join(project, "vae.safetensors")
+
+             # 3. Load the weights from safetensors
+             if os.path.exists(load_path):
+                 state_dict = load_file(load_path, device="cuda")  # load directly onto the GPU
+
+                 # 4. Push the weights into the model;
+                 # strict=True guarantees that all keys matched
+                 vae.load_state_dict(state_dict, strict=True)
+                 #vae = vae.cuda().half()
+
+                 print(f"VAE successfully loaded from {load_path}")
+             else:
+                 print(f"File {load_path} not found!")
+         else:
+             vae = AsymmetricAutoencoderKL.from_pretrained(project)
+
+ vae = vae.to(dtype)
+
+ # torch.compile (optional)
+ if hasattr(torch, "compile"):
+     try:
+         vae = torch.compile(vae)
+     except Exception as e:
+         print(f"[WARN] torch.compile failed: {e}")
+
+ # --------------------------- Freeze/Unfreeze ---------------------------
+ core = get_core_model(vae)
+
+ for p in core.parameters():
+     p.requires_grad = False
+
+ unfrozen_param_names = []
+
+ if full_training and not train_decoder_only:
+     for name, p in core.named_parameters():
+         p.requires_grad = True
+         unfrozen_param_names.append(name)
+     loss_ratios["kl"] = float(kl_ratio)
+     trainable_module = core
+ else:
+     # train only decoder block 0 + post_quant_conv
+     if hasattr(core, "decoder"):
+         if train_up_only:  # previously: hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0
+             # --- only up_block 0 ---
+             for name, p in core.decoder.up_blocks[0].named_parameters():
+                 p.requires_grad = True
+                 unfrozen_param_names.append(f"decoder.up_blocks.0.{name}")
+         else:
+             print("Decoder — fallback to full decoder")
+             for name, p in core.decoder.named_parameters():
+                 p.requires_grad = True
+                 unfrozen_param_names.append(f"decoder.{name}")
+     if hasattr(core, "post_quant_conv"):
+         for name, p in core.post_quant_conv.named_parameters():
+             p.requires_grad = True
+             unfrozen_param_names.append(f"post_quant_conv.{name}")
+     trainable_module = core.decoder if hasattr(core, "decoder") else core
+
+
+ print(f"[INFO] Unfrozen parameters: {len(unfrozen_param_names)}. First 200 names:")
+ for nm in unfrozen_param_names[:200]:
+     print("  ", nm)
+
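+ # (With StupidAE and full_training=True every parameter ends up unfrozen; the
+ # decoder-only branch targets the diffusers VAE classes, which expose the
+ # decoder, up_blocks, and post_quant_conv attributes it touches.)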
+ # --------------------------- Dataset ---------------------------
+ class PngFolderDataset(Dataset):
+     def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
+         self.root_dir = root_dir
+         self.resolution = resolution
+         self.paths = []
+         for root, _, files in os.walk(root_dir):
+             for fname in files:
+                 if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
+                     self.paths.append(os.path.join(root, fname))
+         if limit:
+             self.paths = self.paths[:limit]
+         valid = []
+         for p in self.paths:
+             try:
+                 with Image.open(p) as im:
+                     im.verify()
+                 valid.append(p)
+             except (OSError, UnidentifiedImageError):
+                 continue
+         self.paths = valid
+         if len(self.paths) == 0:
+             raise RuntimeError(f"No valid PNG images found under {root_dir}")
+         random.shuffle(self.paths)
+
+     def __len__(self):
+         return len(self.paths)
+
+     def __getitem__(self, idx):
+         p = self.paths[idx % len(self.paths)]
+         with Image.open(p) as img:
+             img = img.convert("RGB")
+         if not resize_long_side or resize_long_side <= 0:
+             return img
+         w, h = img.size
+         long = max(w, h)
+         if long <= resize_long_side:
+             return img
+         scale = resize_long_side / float(long)
+         new_w = int(round(w * scale))
+         new_h = int(round(h * scale))
+         return img.resize((new_w, new_h), Image.BICUBIC)
+
+ def random_crop(img, sz):
+     w, h = img.size
+     if w < sz or h < sz:
+         img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
+     # upper bound uses max(0, ...) so the crop never runs past the image edge
+     x = random.randint(0, max(0, img.width - sz))
+     y = random.randint(0, max(0, img.height - sz))
+     return img.crop((x, y, x + sz, y + sz))
+
+ tfm = transforms.Compose([
+     transforms.ToTensor(),
+     transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
+ ])
+
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
+ if len(dataset) < batch_size:
+     raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
+
+ def collate_fn(batch):
+     imgs = []
+     for img in batch:
+         img = random_crop(img, high_resolution)
+         imgs.append(tfm(img))
+     return torch.stack(imgs)
+
+ dataloader = DataLoader(
+     dataset,
+     batch_size=batch_size,
+     shuffle=True,
+     collate_fn=collate_fn,
+     num_workers=num_workers,
+     pin_memory=True,
+     drop_last=True
+ )
+
+ # --------------------------- Optimizer ---------------------------
+ def get_param_groups(module, weight_decay=0.001):
+     no_decay_tokens = ("bias", "norm", "rms", "layernorm")
+     decay_params, no_decay_params = [], []
+     for n, p in module.named_parameters():
+         if not p.requires_grad:
+             continue
+         n_l = n.lower()
+         if any(t in n_l for t in no_decay_tokens):
+             no_decay_params.append(p)
+         else:
+             decay_params.append(p)
+     return [
+         {"params": decay_params, "weight_decay": weight_decay},
+         {"params": no_decay_params, "weight_decay": 0.0},
+     ]
+
+ def create_optimizer(name, param_groups):
+     if name == "adam8bit":
+         return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
+     raise ValueError(name)
+
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
+ optimizer = create_optimizer(optimizer_type, param_groups)
+
+ # --------------------------- LR schedule ---------------------------
+ batches_per_epoch = len(dataloader)
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
+ total_steps = steps_per_epoch * num_epochs
+
+ def lr_lambda(step):
+     if not use_decay:
+         return 1.0
+     x = float(step) / float(max(1, total_steps))
+     warmup = float(warmup_percent)
+     min_ratio = float(min_learning_rate) / float(base_learning_rate)
+     if x < warmup:
+         return min_ratio + (1.0 - min_ratio) * (x / warmup)
+     decay_ratio = (x - warmup) / (1.0 - warmup)
+     return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
+
+ scheduler = LambdaLR(optimizer, lr_lambda)
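+ # (lr_lambda is a linear warmup from min_ratio to 1.0 over the first
+ # warmup_percent of training, followed by a cosine decay back to min_ratio,
+ # so the LR sweeps base_learning_rate -> min_learning_rate: with the defaults
+ # above, 5e-5 down to 1e-5.)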
+
+ # Preparation
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
+
+ # fdl
+ fdl_loss = FDL_loss()
+ fdl_loss = fdl_loss.to(accelerator.device)
+
+ # --------------------------- LPIPS and helpers ---------------------------
+ _lpips_net = None
+ def _get_lpips():
+     global _lpips_net
+     if _lpips_net is None:
+         _lpips_net = lpips.LPIPS(net='vgg', verbose=False).to(accelerator.device).eval()
+     return _lpips_net
+
+ _sobel_kx = torch.tensor([[[[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]]], dtype=torch.float32)
+ _sobel_ky = torch.tensor([[[[-1, -2, -1], [0, 0, 0], [1, 2, 1]]]], dtype=torch.float32)
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
+     C = x.shape[1]
+     kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
+     gx = F.conv2d(x, kx, padding=1, groups=C)
+     gy = F.conv2d(x, ky, padding=1, groups=C)
+     return torch.sqrt(gx * gx + gy * gy + 1e-12)
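+ # (sobel_edges computes a per-channel gradient magnitude via grouped convs;
+ # the "edge" term in the loss dict below is an L1 between the edge maps of
+ # reconstruction and target, penalizing blurred or displaced edges.)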
+
+ class MedianLossNormalizer:
+     def __init__(self, desired_ratios: dict, window_steps: int):
+         s = sum(desired_ratios.values())
+         self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
+         self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
+         self.window = window_steps
+
+     def update_and_total(self, abs_losses: dict):
+         for k, v in abs_losses.items():
+             if k in self.buffers:
+                 self.buffers[k].append(float(v.detach().abs().cpu()))
+         meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
+         coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
+         total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
+         return total, coeffs, meds
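+ # (Normalization scheme: each loss keeps a running median of its recent raw
+ # values over window_steps; coeff_k = ratio_k / median_k, so after warm-up
+ # every term contributes roughly its desired share of the total regardless of
+ # its raw magnitude.)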
+
+ if full_training and not train_decoder_only:
+     loss_ratios["kl"] = float(kl_ratio)
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
+
+ # --------------------------- Samples ---------------------------
+ @torch.no_grad()
+ def get_fixed_samples(n=3):
+     idx = random.sample(range(len(dataset)), min(n, len(dataset)))
+     pil_imgs = [dataset[i] for i in idx]
+     tensors = []
+     for img in pil_imgs:
+         img = random_crop(img, high_resolution)
+         tensors.append(tfm(img))
+     return torch.stack(tensors).to(accelerator.device, dtype)
+
+ fixed_samples = get_fixed_samples()
+
+ @torch.no_grad()
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
+     arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
+     return Image.fromarray(arr)
+
+
+ @torch.no_grad()
+ def generate_and_save_samples(step=None):
+     try:
+         temp_vae = accelerator.unwrap_model(vae).eval()
+         lpips_net = _get_lpips()
+         with torch.no_grad():
+             orig_high = fixed_samples
+             orig_low = F.interpolate(
+                 orig_high,
+                 size=(model_resolution, model_resolution),
+                 mode="bilinear",
+                 align_corners=False
+             )
+             model_dtype = next(temp_vae.parameters()).dtype
+             orig_low = orig_low.to(dtype=model_dtype)
+
+             # Encode/decode, accounting for video mode
+             if is_video_vae(temp_vae):
+                 x_in = orig_low.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = temp_vae.encode(x_in)
+                 latents_mean = enc.latent_dist.mean
+                 dec = temp_vae.decode(latents_mean).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 latents_mean = temp_vae.encode(orig_low)
+                 #latents_mean = enc.latent_dist.mean
+                 rec = temp_vae.decode(latents_mean)  #.sample
+
+             # Resize to match, if ever needed
+             #if rec.shape[-2:] != orig_high.shape[-2:]:
+             #    rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
+
+             # Save all real/decoded pairs
+             for i in range(rec.shape[0]):
+                 real_img = _to_pil_uint8(orig_high[i])
+                 dec_img = _to_pil_uint8(rec[i])
+                 real_img.save(f"{generated_folder}/sample_real_{i}.jpg", quality=95)
+                 dec_img.save(f"{generated_folder}/sample_decoded_{i}.jpg", quality=95)
+
+             # LPIPS
+             lpips_scores = []
+             for i in range(rec.shape[0]):
+                 orig_full = orig_high[i:i+1].to(torch.float32)
+                 rec_full = rec[i:i+1].to(torch.float32)
+                 #if rec_full.shape[-2:] != orig_full.shape[-2:]:
+                 #    rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
+                 lpips_val = lpips_net(orig_full, rec_full).item()
+                 lpips_scores.append(lpips_val)
+             avg_lpips = float(np.mean(lpips_scores))
+
+             # W&B logging
+             if use_wandb and accelerator.is_main_process:
+                 log_data = {"lpips_mean": avg_lpips}
+                 for i in range(rec.shape[0]):
+                     log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.jpg", caption=f"real_{i}")
+                     log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.jpg", caption=f"decoded_{i}")
+                 wandb.log(log_data, step=step)
+
+     finally:
+         gc.collect()
+         torch.cuda.empty_cache()
+
+
+ if accelerator.is_main_process and save_model:
+     print("Generating samples before training starts...")
+     generate_and_save_samples(0)
+
+ accelerator.wait_for_everyone()
+
+ # --------------------------- Training ---------------------------
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
+ global_step = 0
+ min_loss = float("inf")
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
+
+ for epoch in range(num_epochs):
+     vae.train()
+     batch_losses, batch_grads = [], []
+     track_losses = {k: [] for k in loss_ratios.keys()}
+
+     for imgs in dataloader:
+         with accelerator.accumulate(vae):
+             imgs = imgs.to(accelerator.device)
+
+             if high_resolution != model_resolution:
+                 imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution), mode="bilinear", align_corners=False)
+             else:
+                 imgs_low = imgs
+
+             model_dtype = next(vae.parameters()).dtype
+             imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
+
+             # QWEN: encode/decode with T=1
+             if is_video_vae(vae):
+                 x_in = imgs_low_model.unsqueeze(2)  # [B,3,1,H,W]
+                 enc = vae.encode(x_in)
+                 latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                 dec = vae.decode(latents).sample  # [B,3,1,H,W]
+                 rec = dec.squeeze(2)  # [B,3,H,W]
+             else:
+                 enc = vae.encode(imgs_low_model)
+                 #latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
+                 rec = vae.decode(enc)  #.sample
+
+             #if rec.shape[-2:] != imgs.shape[-2:]:
+             #    rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
+
+             rec_f32 = rec.to(torch.float32)
+             imgs_f32 = imgs.to(torch.float32)
+
+             abs_losses = {
+                 "mae": F.l1_loss(rec_f32, imgs_f32),
+                 "mse": F.mse_loss(rec_f32, imgs_f32),
+                 "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
+                 "fdl": fdl_loss(rec_f32, imgs_f32),
+                 "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
+             }
+
+             if full_training and not train_decoder_only:
+                 #mean = enc.mean
+                 #logvar = enc.logvar
+                 #mean = enc.latent_dist.mean
+                 #logvar = enc.latent_dist.logvar
+                 #kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
+                 #abs_losses["kl"] = kl
+                 abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
+             else:
+                 abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
+
+             total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
+
+             if torch.isnan(total_loss) or torch.isinf(total_loss):
+                 raise RuntimeError("NaN/Inf loss")
+
+             accelerator.backward(total_loss)
+
+             grad_norm = torch.tensor(0.0, device=accelerator.device)
+             if accelerator.sync_gradients:
+                 grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
+                 optimizer.step()
+                 scheduler.step()
+                 optimizer.zero_grad(set_to_none=True)
+                 global_step += 1
+                 progress.update(1)
+
+             if accelerator.is_main_process:
+                 try:
+                     current_lr = optimizer.param_groups[0]["lr"]
+                 except Exception:
+                     current_lr = scheduler.get_last_lr()[0]
+
+                 batch_losses.append(total_loss.detach().item())
+                 batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
+                 for k, v in abs_losses.items():
+                     track_losses[k].append(float(v.detach().item()))
+
+                 if use_wandb and accelerator.sync_gradients:
+                     log_dict = {
+                         "total_loss": float(total_loss.detach().item()),
+                         "learning_rate": current_lr,
+                         "epoch": epoch,
+                         "grad_norm": batch_grads[-1],
+                     }
+                     for k, v in abs_losses.items():
+                         log_dict[f"loss_{k}"] = float(v.detach().item())
+                     for k in coeffs:
+                         log_dict[f"coeff_{k}"] = float(coeffs[k])
+                         log_dict[f"median_{k}"] = float(meds[k])
+                     wandb.log(log_dict, step=global_step)
+
+             if global_step > 0 and global_step % sample_interval == 0:
+                 if accelerator.is_main_process:
+                     generate_and_save_samples(global_step)
+                 accelerator.wait_for_everyone()
+
+                 n_micro = sample_interval * gradient_accumulation_steps
+                 avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
+                 avg_grad = float(np.mean(batch_grads[-n_micro:])) if batch_grads else 0.0
+
+                 if accelerator.is_main_process:
+                     print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
+                     if save_model and avg_loss < min_loss * save_barrier:
+                         min_loss = avg_loss
+                         #accelerator.unwrap_model(vae).save_pretrained(save_as)
+
+                         # 1. Unwrap the model from the accelerator
+                         unwrapped_model = accelerator.unwrap_model(vae)
+
+                         # 2. Get the state_dict (weight dictionary);
+                         # accelerator.get_state_dict gathers the weights correctly under DDP
+                         state_dict = accelerator.get_state_dict(vae)
+
+                         # 3. Create the folder if it does not exist
+                         os.makedirs(save_as, exist_ok=True)
+
+                         # 4. Save in safetensors format
+                         save_path = os.path.join(save_as, "vae.safetensors")
+                         save_file(state_dict, save_path)
+
+                         print(f"Model successfully saved to {save_path}")
+                     if use_wandb:
+                         wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
+
+     if accelerator.is_main_process:
+         epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
+         print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
+         if use_wandb:
+             wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
+
+ # --------------------------- Final save ---------------------------
+ if accelerator.is_main_process:
+     print("Training finished – saving final model")
+     if save_model:
+         #accelerator.unwrap_model(vae).save_pretrained(save_as)
+         # 1. Unwrap the model from the accelerator
+         unwrapped_model = accelerator.unwrap_model(vae)
+
+         # 2. Get the state_dict (weight dictionary);
+         # accelerator.get_state_dict gathers the weights correctly under DDP
+         state_dict = accelerator.get_state_dict(vae)
+
+         # 4. Save in safetensors format
+         save_path = os.path.join(save_as, "vae_model.safetensors")
+         save_file(state_dict, save_path)
+
+         print(f"Model successfully saved to {save_path}")
+
+ accelerator.free_memory()
+ if torch.distributed.is_initialized():
+     torch.distributed.destroy_process_group()
+ print("Done!")
vae.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f271b4a9e81187b9c486faaff9d0c2ef7fb0bea06b310ce9e1ec247388962c2
+ size 202307260
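
For reference, a minimal sketch of how the uploaded checkpoint could be restored; the repo id below is a placeholder, and StupidAE comes from the stae.py added above:

import torch
from huggingface_hub import hf_hub_download
from safetensors.torch import load_file
from stae import StupidAE

# "user/stae3" is a hypothetical repo id; substitute the actual model repo
path = hf_hub_download("user/stae3", "vae.safetensors")
ae = StupidAE()
ae.load_state_dict(load_file(path), strict=True)  # keys match the training-time save
ae.eval()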