babkasotona commited on
Commit
bfec1c4
·
verified ·
1 Parent(s): dfd2311

Upload folder using huggingface_hub

Browse files
.ipynb_checkpoints/config-checkpoint.json ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AsymmetricAutoencoderKL",
3
+ "_diffusers_version": "0.37.1",
4
+ "_name_or_path": "vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 128,
9
+ 256,
10
+ 512,
11
+ 512
12
+ ],
13
+ "down_block_out_channels": [
14
+ 128,
15
+ 256,
16
+ 512,
17
+ 512
18
+ ],
19
+ "down_block_types": [
20
+ "DownEncoderBlock2D",
21
+ "DownEncoderBlock2D",
22
+ "DownEncoderBlock2D",
23
+ "DownEncoderBlock2D"
24
+ ],
25
+ "force_upcast": false,
26
+ "in_channels": 3,
27
+ "latent_channels": 32,
28
+ "latents_mean": [
29
+ -0.03542253375053406,
30
+ 0.20086465775966644,
31
+ -0.016413161531090736,
32
+ -0.0956302210688591,
33
+ -0.2672063112258911,
34
+ 0.2609933018684387,
35
+ -0.07806991040706635,
36
+ -0.48407721519470215,
37
+ 0.21844269335269928,
38
+ -0.1122383326292038,
39
+ 0.27197545766830444,
40
+ -0.18958772718906403,
41
+ 0.18776826560497284,
42
+ 0.0987580344080925,
43
+ 0.2837068736553192,
44
+ -0.4486690163612366,
45
+ 0.4816776514053345,
46
+ 0.02947971224784851,
47
+ -0.1337375044822693,
48
+ -0.39750921726226807,
49
+ -0.08513020724058151,
50
+ -0.054023586213588715,
51
+ -0.3943594992160797,
52
+ 0.23918119072914124,
53
+ -0.12466679513454437,
54
+ 0.09935147315263748,
55
+ 0.31858691573143005,
56
+ 0.48585832118988037,
57
+ -0.6416525840759277,
58
+ -0.15164820849895477,
59
+ -0.4693508744239807,
60
+ -0.13071806728839874
61
+ ],
62
+ "latents_std": [
63
+ 1.5792087316513062,
64
+ 1.5769503116607666,
65
+ 1.5864241123199463,
66
+ 1.6454921960830688,
67
+ 1.5336694717407227,
68
+ 1.5587652921676636,
69
+ 1.5838669538497925,
70
+ 1.5659377574920654,
71
+ 1.6860467195510864,
72
+ 1.5192310810089111,
73
+ 1.573639988899231,
74
+ 1.5953549146652222,
75
+ 1.5271092653274536,
76
+ 1.6246271133422852,
77
+ 1.7054023742675781,
78
+ 1.607722282409668,
79
+ 1.558642864227295,
80
+ 1.5824549198150635,
81
+ 1.6202995777130127,
82
+ 1.6206320524215698,
83
+ 1.6379750967025757,
84
+ 1.6527063846588135,
85
+ 1.498811960220337,
86
+ 1.5706247091293335,
87
+ 1.5854856967926025,
88
+ 1.4828169345855713,
89
+ 1.5693111419677734,
90
+ 1.692481517791748,
91
+ 1.6409776210784912,
92
+ 1.6216280460357666,
93
+ 1.6087706089019775,
94
+ 1.5776633024215698
95
+ ],
96
+ "layers_per_down_block": 2,
97
+ "layers_per_up_block": 2,
98
+ "norm_num_groups": 32,
99
+ "out_channels": 3,
100
+ "sample_size": 32,
101
+ "scaling_factor": 1.0,
102
+ "up_block_out_channels": [
103
+ 128,
104
+ 128,
105
+ 256,
106
+ 512,
107
+ 512
108
+ ],
109
+ "up_block_types": [
110
+ "UpDecoderBlock2D",
111
+ "UpDecoderBlock2D",
112
+ "UpDecoderBlock2D",
113
+ "UpDecoderBlock2D",
114
+ "UpDecoderBlock2D"
115
+ ]
116
+ }
config.json ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_class_name": "AsymmetricAutoencoderKL",
3
+ "_diffusers_version": "0.37.1",
4
+ "_name_or_path": "vae",
5
+ "act_fn": "silu",
6
+ "block_out_channels": [
7
+ 128,
8
+ 128,
9
+ 256,
10
+ 512,
11
+ 512
12
+ ],
13
+ "down_block_out_channels": [
14
+ 128,
15
+ 256,
16
+ 512,
17
+ 512
18
+ ],
19
+ "down_block_types": [
20
+ "DownEncoderBlock2D",
21
+ "DownEncoderBlock2D",
22
+ "DownEncoderBlock2D",
23
+ "DownEncoderBlock2D"
24
+ ],
25
+ "force_upcast": false,
26
+ "in_channels": 3,
27
+ "latent_channels": 32,
28
+ "latents_mean": [
29
+ -0.03542253375053406,
30
+ 0.20086465775966644,
31
+ -0.016413161531090736,
32
+ -0.0956302210688591,
33
+ -0.2672063112258911,
34
+ 0.2609933018684387,
35
+ -0.07806991040706635,
36
+ -0.48407721519470215,
37
+ 0.21844269335269928,
38
+ -0.1122383326292038,
39
+ 0.27197545766830444,
40
+ -0.18958772718906403,
41
+ 0.18776826560497284,
42
+ 0.0987580344080925,
43
+ 0.2837068736553192,
44
+ -0.4486690163612366,
45
+ 0.4816776514053345,
46
+ 0.02947971224784851,
47
+ -0.1337375044822693,
48
+ -0.39750921726226807,
49
+ -0.08513020724058151,
50
+ -0.054023586213588715,
51
+ -0.3943594992160797,
52
+ 0.23918119072914124,
53
+ -0.12466679513454437,
54
+ 0.09935147315263748,
55
+ 0.31858691573143005,
56
+ 0.48585832118988037,
57
+ -0.6416525840759277,
58
+ -0.15164820849895477,
59
+ -0.4693508744239807,
60
+ -0.13071806728839874
61
+ ],
62
+ "latents_std": [
63
+ 1.5792087316513062,
64
+ 1.5769503116607666,
65
+ 1.5864241123199463,
66
+ 1.6454921960830688,
67
+ 1.5336694717407227,
68
+ 1.5587652921676636,
69
+ 1.5838669538497925,
70
+ 1.5659377574920654,
71
+ 1.6860467195510864,
72
+ 1.5192310810089111,
73
+ 1.573639988899231,
74
+ 1.5953549146652222,
75
+ 1.5271092653274536,
76
+ 1.6246271133422852,
77
+ 1.7054023742675781,
78
+ 1.607722282409668,
79
+ 1.558642864227295,
80
+ 1.5824549198150635,
81
+ 1.6202995777130127,
82
+ 1.6206320524215698,
83
+ 1.6379750967025757,
84
+ 1.6527063846588135,
85
+ 1.498811960220337,
86
+ 1.5706247091293335,
87
+ 1.5854856967926025,
88
+ 1.4828169345855713,
89
+ 1.5693111419677734,
90
+ 1.692481517791748,
91
+ 1.6409776210784912,
92
+ 1.6216280460357666,
93
+ 1.6087706089019775,
94
+ 1.5776633024215698
95
+ ],
96
+ "layers_per_down_block": 2,
97
+ "layers_per_up_block": 2,
98
+ "norm_num_groups": 32,
99
+ "out_channels": 3,
100
+ "sample_size": 32,
101
+ "scaling_factor": 1.0,
102
+ "up_block_out_channels": [
103
+ 128,
104
+ 128,
105
+ 256,
106
+ 512,
107
+ 512
108
+ ],
109
+ "up_block_types": [
110
+ "UpDecoderBlock2D",
111
+ "UpDecoderBlock2D",
112
+ "UpDecoderBlock2D",
113
+ "UpDecoderBlock2D",
114
+ "UpDecoderBlock2D"
115
+ ]
116
+ }
diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2898bf83f498827cae5de6feeff1e8d376c784def9b22c1de77135f69f237b0d
3
+ size 383499124
train_sdxs_vae.py ADDED
@@ -0,0 +1,623 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ import os
3
+ import math
4
+ import re
5
+ import torch
6
+ import numpy as np
7
+ import random
8
+ import gc
9
+ from datetime import datetime
10
+ from pathlib import Path
11
+
12
+ import torchvision.transforms as transforms
13
+ import torch.nn.functional as F
14
+ from torch.utils.data import DataLoader, Dataset
15
+ from torch.optim.lr_scheduler import LambdaLR
16
+ from diffusers import AutoencoderKL, AsymmetricAutoencoderKL
17
+ # QWEN: импорт класса
18
+ from diffusers import AutoencoderKLQwenImage
19
+ from diffusers import AutoencoderKLWan
20
+
21
+ from accelerate import Accelerator
22
+ from PIL import Image, UnidentifiedImageError
23
+ from tqdm import tqdm
24
+ import bitsandbytes as bnb
25
+ import wandb
26
+ import lpips # pip install lpips
27
+ from FDL_pytorch import FDL_loss # pip install fdl-pytorch
28
+ from collections import deque
29
+
30
+ # --------------------------- Параметры ---------------------------
31
+ ds_path = "/workspace/d23"
32
+ project = "vae"
33
+ batch_size = 1
34
+ base_learning_rate = 6e-6
35
+ min_learning_rate = 7e-7
36
+ num_epochs = 1
37
+ sample_interval_share = 30
38
+ use_wandb = False
39
+ save_model = True
40
+ use_decay = True
41
+ optimizer_type = "adam8bit"
42
+ dtype = torch.float32
43
+
44
+ model_resolution = 512
45
+ high_resolution = 1024
46
+ limit = 0
47
+ save_barrier = 1.3
48
+ warmup_percent = 0.005
49
+ beta2 = 0.997
50
+ eps = 1e-8
51
+ clip_grad_norm = 1.0
52
+ mixed_precision = "no"
53
+ gradient_accumulation_steps = 1
54
+ generated_folder = "samples"
55
+ save_as = "vae2"
56
+ num_workers = 0
57
+ device = None
58
+ torch.backends.cuda.matmul.allow_tf32 = True
59
+ torch.backends.cudnn.allow_tf32 = True
60
+ # Включение Flash Attention 2/SDPA #MAX_JOBS=4 pip install flash-attn --no-build-isolation
61
+ torch.backends.cuda.enable_flash_sdp(True)
62
+ torch.backends.cuda.enable_mem_efficient_sdp(True)
63
+ torch.backends.cuda.enable_math_sdp(False)
64
+
65
+ # --- Режимы обучения ---
66
+ # QWEN: учим только декодер
67
+ train_decoder_only = True
68
+ train_up_only = False
69
+ full_training = False # если True — учим весь VAE и добавляем KL (ниже)
70
+ kl_ratio = 0.0
71
+
72
+ # Доли лоссов
73
+ loss_ratios = {
74
+ "lpips": 0.70,#0.50,
75
+ "fdl" : 0.10,#0.25,
76
+ "edge": 0.05,
77
+ "mse": 0.10,
78
+ "mae": 0.05,
79
+ "kl": 0.00,
80
+ }
81
+ median_coeff_steps = 250
82
+
83
+ resize_long_side = 1280 # ресайз длинной стороны исходных картинок
84
+
85
+ # QWEN: конфиг загрузки модели
86
+ vae_kind = "kl" # "qwen" или "kl" (обычный)
87
+
88
+ Path(generated_folder).mkdir(parents=True, exist_ok=True)
89
+
90
+ accelerator = Accelerator(
91
+ mixed_precision=mixed_precision,
92
+ gradient_accumulation_steps=gradient_accumulation_steps
93
+ )
94
+ device = accelerator.device
95
+
96
+ # reproducibility
97
+ seed = int(datetime.now().strftime("%Y%m%d")) + 42
98
+ torch.manual_seed(seed); np.random.seed(seed); random.seed(seed)
99
+ torch.backends.cudnn.benchmark = False
100
+
101
+ # --------------------------- WandB ---------------------------
102
+ if use_wandb and accelerator.is_main_process:
103
+ wandb.init(project=project, config={
104
+ "batch_size": batch_size,
105
+ "base_learning_rate": base_learning_rate,
106
+ "num_epochs": num_epochs,
107
+ "optimizer_type": optimizer_type,
108
+ "model_resolution": model_resolution,
109
+ "high_resolution": high_resolution,
110
+ "gradient_accumulation_steps": gradient_accumulation_steps,
111
+ "train_decoder_only": train_decoder_only,
112
+ "full_training": full_training,
113
+ "kl_ratio": kl_ratio,
114
+ "vae_kind": vae_kind,
115
+ })
116
+
117
+ # --------------------------- VAE ---------------------------
118
+ def get_core_model(model):
119
+ m = model
120
+ # если модель уже обёрнута torch.compile
121
+ if hasattr(m, "_orig_mod"):
122
+ m = m._orig_mod
123
+ return m
124
+
125
+ def is_video_vae(model) -> bool:
126
+ # WAN/Qwen — это видео-VAEs
127
+ if vae_kind in ("wan", "qwen"):
128
+ return True
129
+ # fallback по структуре (если понадобится)
130
+ try:
131
+ core = get_core_model(model)
132
+ enc = getattr(core, "encoder", None)
133
+ conv_in = getattr(enc, "conv_in", None)
134
+ w = getattr(conv_in, "weight", None)
135
+ if isinstance(w, torch.nn.Parameter):
136
+ return w.ndim == 5
137
+ except Exception:
138
+ pass
139
+ return False
140
+
141
+ # загрузка
142
+ if vae_kind == "qwen":
143
+ vae = AutoencoderKLQwenImage.from_pretrained("Qwen/Qwen-Image", subfolder="vae")
144
+ else:
145
+ if vae_kind == "wan":
146
+ vae = AutoencoderKLWan.from_pretrained(project)
147
+ else:
148
+ # старое поведение (пример)
149
+ if model_resolution==high_resolution:
150
+ vae = AutoencoderKL.from_pretrained(project)
151
+ else:
152
+ vae = AsymmetricAutoencoderKL.from_pretrained(project)
153
+
154
+ vae = vae.to(dtype)
155
+
156
+ # torch.compile (опционально)
157
+ if hasattr(torch, "compile"):
158
+ try:
159
+ vae = torch.compile(vae)
160
+ except Exception as e:
161
+ print(f"[WARN] torch.compile failed: {e}")
162
+
163
+ # --------------------------- Freeze/Unfreeze ---------------------------
164
+ core = get_core_model(vae)
165
+
166
+ for p in core.parameters():
167
+ p.requires_grad = False
168
+
169
+ unfrozen_param_names = []
170
+
171
+ if full_training and not train_decoder_only:
172
+ for name, p in core.named_parameters():
173
+ p.requires_grad = True
174
+ unfrozen_param_names.append(name)
175
+ loss_ratios["kl"] = float(kl_ratio)
176
+ trainable_module = core
177
+ else:
178
+ # учим только 0-й блок декодера + post_quant_conv
179
+ if hasattr(core, "decoder"):
180
+ if train_up_only:#hasattr(core.decoder, "up_blocks") and len(core.decoder.up_blocks) > 0:
181
+ # --- только 0-й up_block ---
182
+ for name, p in core.decoder.up_blocks[0].named_parameters():
183
+ p.requires_grad = True
184
+ unfrozen_param_names.append(f"{name}")
185
+ else:
186
+ print("Decoder — fallback to full decoder")
187
+ for name, p in core.decoder.named_parameters():
188
+ p.requires_grad = True
189
+ unfrozen_param_names.append(f"decoder.{name}")
190
+ if hasattr(core, "post_quant_conv"):
191
+ for name, p in core.post_quant_conv.named_parameters():
192
+ p.requires_grad = True
193
+ unfrozen_param_names.append(f"post_quant_conv.{name}")
194
+ trainable_module = core.decoder if hasattr(core, "decoder") else core
195
+
196
+
197
+ print(f"[INFO] Разморожено параметров: {len(unfrozen_param_names)}. Первые 200 имён:")
198
+ for nm in unfrozen_param_names[:100]:
199
+ print(" ", nm)
200
+
201
+ # --------------------------- Датасет ---------------------------
202
+ class PngFolderDataset(Dataset):
203
+ def __init__(self, root_dir, min_exts=('.png',), resolution=1024, limit=0):
204
+ self.root_dir = root_dir
205
+ self.resolution = resolution
206
+ self.paths = []
207
+ for root, _, files in os.walk(root_dir):
208
+ for fname in files:
209
+ if fname.lower().endswith(tuple(ext.lower() for ext in min_exts)):
210
+ self.paths.append(os.path.join(root, fname))
211
+ if limit:
212
+ self.paths = self.paths[:limit]
213
+ valid = []
214
+ for p in self.paths:
215
+ try:
216
+ with Image.open(p) as im:
217
+ im.verify()
218
+ valid.append(p)
219
+ except (OSError, UnidentifiedImageError):
220
+ continue
221
+ self.paths = valid
222
+ if len(self.paths) == 0:
223
+ raise RuntimeError(f"No valid PNG images found under {root_dir}")
224
+ random.shuffle(self.paths)
225
+
226
+ def __len__(self):
227
+ return len(self.paths)
228
+
229
+ def __getitem__(self, idx):
230
+ p = self.paths[idx % len(self.paths)]
231
+ with Image.open(p) as img:
232
+ img = img.convert("RGB")
233
+ if not resize_long_side or resize_long_side <= 0:
234
+ return img
235
+ w, h = img.size
236
+ long = max(w, h)
237
+ if long <= resize_long_side:
238
+ return img
239
+ scale = resize_long_side / float(long)
240
+ new_w = int(round(w * scale))
241
+ new_h = int(round(h * scale))
242
+ return img.resize((new_w, new_h), Image.BICUBIC)
243
+
244
+ def random_crop(img, sz):
245
+ w, h = img.size
246
+ if w < sz or h < sz:
247
+ img = img.resize((max(sz, w), max(sz, h)), Image.BICUBIC)
248
+ x = random.randint(0, max(1, img.width - sz))
249
+ y = random.randint(0, max(1, img.height - sz))
250
+ return img.crop((x, y, x + sz, y + sz))
251
+
252
+ tfm = transforms.Compose([
253
+ transforms.ToTensor(),
254
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])
255
+ ])
256
+
257
+ dataset = PngFolderDataset(ds_path, min_exts=('.png',), resolution=high_resolution, limit=limit)
258
+ print("len(dataset)",len(dataset))
259
+ if len(dataset) < batch_size:
260
+ raise RuntimeError(f"Not enough valid images ({len(dataset)}) to form a batch of size {batch_size}")
261
+
262
+ def collate_fn(batch):
263
+ imgs = []
264
+ for img in batch:
265
+ img = random_crop(img, high_resolution)
266
+ imgs.append(tfm(img))
267
+ return torch.stack(imgs)
268
+
269
+ dataloader = DataLoader(
270
+ dataset,
271
+ batch_size=batch_size,
272
+ shuffle=True,
273
+ collate_fn=collate_fn,
274
+ num_workers=num_workers,
275
+ pin_memory=True,
276
+ drop_last=True
277
+ )
278
+
279
+ # --------------------------- Оптимизатор ---------------------------
280
+ def get_param_groups(module, weight_decay=0.001):
281
+ no_decay = ["bias", "LayerNorm.weight", "layer_norm.weight", "ln_1.weight", "ln_f.weight"]
282
+ decay_params, no_decay_params = [], []
283
+ for n, p in vae.named_parameters(): # глобально по vae, с фильтром requires_grad
284
+ if not p.requires_grad:
285
+ continue
286
+ if any(nd in n for nd in no_decay):
287
+ no_decay_params.append(p)
288
+ else:
289
+ decay_params.append(p)
290
+ return [
291
+ {"params": decay_params, "weight_decay": weight_decay},
292
+ {"params": no_decay_params, "weight_decay": 0.0},
293
+ ]
294
+
295
+ def get_param_groups(module, weight_decay=0.001):
296
+ no_decay_tokens = ("bias", "norm", "rms", "layernorm")
297
+ decay_params, no_decay_params = [], []
298
+ for n, p in module.named_parameters():
299
+ if not p.requires_grad:
300
+ continue
301
+ n_l = n.lower()
302
+ if any(t in n_l for t in no_decay_tokens):
303
+ no_decay_params.append(p)
304
+ else:
305
+ decay_params.append(p)
306
+ return [
307
+ {"params": decay_params, "weight_decay": weight_decay},
308
+ {"params": no_decay_params, "weight_decay": 0.0},
309
+ ]
310
+
311
+ def create_optimizer(name, param_groups):
312
+ if name == "adam8bit":
313
+ return bnb.optim.AdamW8bit(param_groups, lr=base_learning_rate, betas=(0.9, beta2), eps=eps)
314
+ raise ValueError(name)
315
+
316
+ param_groups = get_param_groups(get_core_model(vae), weight_decay=0.001)
317
+ optimizer = create_optimizer(optimizer_type, param_groups)
318
+
319
+ # --------------------------- LR schedule ---------------------------
320
+ batches_per_epoch = len(dataloader)
321
+ steps_per_epoch = int(math.ceil(batches_per_epoch / float(gradient_accumulation_steps)))
322
+ total_steps = steps_per_epoch * num_epochs
323
+
324
+ def lr_lambda(step):
325
+ if not use_decay:
326
+ return 1.0
327
+ x = float(step) / float(max(1, total_steps))
328
+ warmup = float(warmup_percent)
329
+ min_ratio = float(min_learning_rate) / float(base_learning_rate)
330
+ if x < warmup:
331
+ return min_ratio + (1.0 - min_ratio) * (x / warmup)
332
+ decay_ratio = (x - warmup) / (1.0 - warmup)
333
+ return min_ratio + 0.5 * (1.0 - min_ratio) * (1.0 + math.cos(math.pi * decay_ratio))
334
+
335
+ scheduler = LambdaLR(optimizer, lr_lambda)
336
+
337
+ # Подготовка
338
+ dataloader, vae, optimizer, scheduler = accelerator.prepare(dataloader, vae, optimizer, scheduler)
339
+ trainable_params = [p for p in vae.parameters() if p.requires_grad]
340
+
341
+ # fdl
342
+ fdl_loss = FDL_loss()
343
+ fdl_loss = fdl_loss.to(accelerator.device)
344
+
345
+ # --------------------------- LPIPS и вспомогательные ---------------------------
346
+ _lpips_net = None
347
+ def _get_lpips():
348
+ global _lpips_net
349
+ if _lpips_net is None:
350
+ _lpips_net = lpips.LPIPS(net='vgg', verbose=False).eval().to(accelerator.device).eval()
351
+ return _lpips_net
352
+
353
+ _sobel_kx = torch.tensor([[[[-1,0,1],[-2,0,2],[-1,0,1]]]], dtype=torch.float32)
354
+ _sobel_ky = torch.tensor([[[[-1,-2,-1],[0,0,0],[1,2,1]]]], dtype=torch.float32)
355
+ def sobel_edges(x: torch.Tensor) -> torch.Tensor:
356
+ C = x.shape[1]
357
+ kx = _sobel_kx.to(x.device, x.dtype).repeat(C, 1, 1, 1)
358
+ ky = _sobel_ky.to(x.device, x.dtype).repeat(C, 1, 1, 1)
359
+ gx = F.conv2d(x, kx, padding=1, groups=C)
360
+ gy = F.conv2d(x, ky, padding=1, groups=C)
361
+ return torch.sqrt(gx * gx + gy * gy + 1e-12)
362
+
363
+ class MedianLossNormalizer:
364
+ def __init__(self, desired_ratios: dict, window_steps: int):
365
+ s = sum(desired_ratios.values())
366
+ self.ratios = {k: (v / s) if s > 0 else 0.0 for k, v in desired_ratios.items()}
367
+ self.buffers = {k: deque(maxlen=window_steps) for k in self.ratios.keys()}
368
+ self.window = window_steps
369
+
370
+ def update_and_total(self, abs_losses: dict):
371
+ for k, v in abs_losses.items():
372
+ if k in self.buffers:
373
+ self.buffers[k].append(float(v.detach().abs().cpu()))
374
+ meds = {k: (np.median(self.buffers[k]) if len(self.buffers[k]) > 0 else 1.0) for k in self.buffers}
375
+ coeffs = {k: (self.ratios[k] / max(meds[k], 1e-12)) for k in self.ratios}
376
+ total = sum(coeffs[k] * abs_losses[k] for k in abs_losses if k in coeffs)
377
+ return total, coeffs, meds
378
+
379
+ if full_training and not train_decoder_only:
380
+ loss_ratios["kl"] = float(kl_ratio)
381
+ normalizer = MedianLossNormalizer(loss_ratios, median_coeff_steps)
382
+
383
+ # --------------------------- Сэмплы ---------------------------
384
+ @torch.no_grad()
385
+ def get_fixed_samples(n=3):
386
+ idx = random.sample(range(len(dataset)), min(n, len(dataset)))
387
+ pil_imgs = [dataset[i] for i in idx]
388
+ tensors = []
389
+ for img in pil_imgs:
390
+ img = random_crop(img, high_resolution)
391
+ tensors.append(tfm(img))
392
+ return torch.stack(tensors).to(accelerator.device, dtype)
393
+
394
+ fixed_samples = get_fixed_samples()
395
+
396
+ @torch.no_grad()
397
+ def _to_pil_uint8(img_tensor: torch.Tensor) -> Image.Image:
398
+ arr = ((img_tensor.float().clamp(-1, 1) + 1.0) * 127.5).clamp(0, 255).byte().cpu().numpy().transpose(1, 2, 0)
399
+ return Image.fromarray(arr)
400
+
401
+
402
+ @torch.no_grad()
403
+ def generate_and_save_samples(step=None):
404
+ try:
405
+ #temp_vae = accelerator.unwrap_model(vae).eval()
406
+ if hasattr(vae, "module"):
407
+ # Если это DDP или DistributedDataParallel
408
+ unwrapped_vae = vae.module
409
+ else:
410
+ unwrapped_vae = vae
411
+
412
+ # Если использовался torch.compile, достаем оригинал
413
+ if hasattr(unwrapped_vae, "_orig_mod"):
414
+ temp_vae = unwrapped_vae._orig_mod
415
+ else:
416
+ temp_vae = unwrapped_vae
417
+
418
+ temp_vae = temp_vae.eval()
419
+ lpips_net = _get_lpips()
420
+ with torch.no_grad():
421
+ orig_high = fixed_samples
422
+ orig_low = F.interpolate(
423
+ orig_high,
424
+ size=(model_resolution, model_resolution),
425
+ mode="bilinear",
426
+ align_corners=False
427
+ )
428
+ model_dtype = next(temp_vae.parameters()).dtype
429
+ orig_low = orig_low.to(dtype=model_dtype)
430
+
431
+ # Encode/decode с учётом видео-режима
432
+ if is_video_vae(temp_vae):
433
+ x_in = orig_low.unsqueeze(2) # [B,3,1,H,W]
434
+ enc = temp_vae.encode(x_in)
435
+ latents_mean = enc.latent_dist.mean
436
+ dec = temp_vae.decode(latents_mean).sample # [B,3,1,H,W]
437
+ rec = dec.squeeze(2) # [B,3,H,W]
438
+ else:
439
+ enc = temp_vae.encode(orig_low)
440
+ latents_mean = enc.latent_dist.mean
441
+ rec = temp_vae.decode(latents_mean).sample
442
+
443
+ # Подгон размеров, если надо
444
+ #if rec.shape[-2:] != orig_high.shape[-2:]:
445
+ # rec = F.interpolate(rec, size=orig_high.shape[-2:], mode="bilinear", align_corners=False)
446
+
447
+ # Сохраняем все real/decoded
448
+ for i in range(rec.shape[0]):
449
+ real_img = _to_pil_uint8(orig_high[i])
450
+ dec_img = _to_pil_uint8(rec[i])
451
+ real_img.save(f"{generated_folder}/sample_real_{i}.png")
452
+ dec_img.save(f"{generated_folder}/sample_decoded_{i}.png")
453
+
454
+ # LPIPS
455
+ lpips_scores = []
456
+ for i in range(rec.shape[0]):
457
+ orig_full = orig_high[i:i+1].to(torch.float32)
458
+ rec_full = rec[i:i+1].to(torch.float32)
459
+ #if rec_full.shape[-2:] != orig_full.shape[-2:]:
460
+ # rec_full = F.interpolate(rec_full, size=orig_full.shape[-2:], mode="bilinear", align_corners=False)
461
+ lpips_val = lpips_net(orig_full, rec_full).item()
462
+ lpips_scores.append(lpips_val)
463
+ avg_lpips = float(np.mean(lpips_scores))
464
+
465
+ # W&B логирование
466
+ if use_wandb and accelerator.is_main_process:
467
+ log_data = {"lpips_mean": avg_lpips}
468
+ for i in range(rec.shape[0]):
469
+ log_data[f"sample/real_{i}"] = wandb.Image(f"{generated_folder}/sample_real_{i}.png", caption=f"real_{i}")
470
+ log_data[f"sample/decoded_{i}"] = wandb.Image(f"{generated_folder}/sample_decoded_{i}.png", caption=f"decoded_{i}")
471
+ wandb.log(log_data, step=step)
472
+
473
+ finally:
474
+ gc.collect()
475
+ torch.cuda.empty_cache()
476
+
477
+
478
+ if accelerator.is_main_process and save_model:
479
+ print("Генерация сэмплов до старта обучения...")
480
+ generate_and_save_samples(0)
481
+
482
+ accelerator.wait_for_everyone()
483
+
484
+ # --------------------------- Тренировка ---------------------------
485
+ progress = tqdm(total=total_steps, disable=not accelerator.is_local_main_process)
486
+ global_step = 0
487
+ min_loss = float("inf")
488
+ sample_interval = max(1, total_steps // max(1, sample_interval_share * num_epochs))
489
+
490
+ for epoch in range(num_epochs):
491
+ vae.train()
492
+ batch_losses, batch_grads = [], []
493
+ track_losses = {k: [] for k in loss_ratios.keys()}
494
+
495
+ for imgs in dataloader:
496
+ with accelerator.accumulate(vae):
497
+ imgs = imgs.to(accelerator.device)
498
+
499
+ if high_resolution != model_resolution:
500
+ imgs_low = F.interpolate(imgs, size=(model_resolution, model_resolution),mode="area") # mode="bilinear", align_corners=False)
501
+ else:
502
+ imgs_low = imgs
503
+
504
+ model_dtype = next(vae.parameters()).dtype
505
+ imgs_low_model = imgs_low.to(dtype=model_dtype) if imgs_low.dtype != model_dtype else imgs_low
506
+
507
+ # Вместо: current_vae = accelerator.unwrap_model(vae)
508
+ unwrapped = vae.module if hasattr(vae, "module") else vae
509
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
510
+
511
+
512
+ # QWEN: encode/decode с T=1
513
+ if is_video_vae(current_vae):
514
+ x_in = imgs_low_model.unsqueeze(2) # [B,3,1,H,W]
515
+ enc = current_vae.encode(x_in)
516
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
517
+ dec = current_vae.decode(latents).sample # [B,3,1,H,W]
518
+ rec = dec.squeeze(2) # [B,3,H,W]
519
+ else:
520
+ enc = current_vae.encode(imgs_low_model)
521
+ latents = enc.latent_dist.mean if train_decoder_only else enc.latent_dist.sample()
522
+ rec = current_vae.decode(latents).sample
523
+
524
+ #if rec.shape[-2:] != imgs.shape[-2:]:
525
+ # rec = F.interpolate(rec, size=imgs.shape[-2:], mode="bilinear", align_corners=False)
526
+
527
+ rec_f32 = rec.to(torch.float32)
528
+ imgs_f32 = imgs.to(torch.float32)
529
+
530
+ abs_losses = {
531
+ "mae": F.l1_loss(rec_f32, imgs_f32),
532
+ "mse": F.mse_loss(rec_f32, imgs_f32),
533
+ "lpips": _get_lpips()(rec_f32, imgs_f32).mean(),
534
+ "fdl": fdl_loss(rec_f32, imgs_f32),
535
+ "edge": F.l1_loss(sobel_edges(rec_f32), sobel_edges(imgs_f32)),
536
+ }
537
+
538
+ if full_training and not train_decoder_only:
539
+ mean = enc.latent_dist.mean
540
+ logvar = enc.latent_dist.logvar
541
+ kl = -0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp())
542
+ abs_losses["kl"] = kl
543
+ else:
544
+ abs_losses["kl"] = torch.tensor(0.0, device=accelerator.device, dtype=torch.float32)
545
+
546
+ total_loss, coeffs, meds = normalizer.update_and_total(abs_losses)
547
+
548
+ if torch.isnan(total_loss) or torch.isinf(total_loss):
549
+ raise RuntimeError("NaN/Inf loss")
550
+
551
+ accelerator.backward(total_loss)
552
+
553
+ grad_norm = torch.tensor(0.0, device=accelerator.device)
554
+ if accelerator.sync_gradients:
555
+ grad_norm = accelerator.clip_grad_norm_(trainable_params, clip_grad_norm)
556
+ optimizer.step()
557
+ scheduler.step()
558
+ optimizer.zero_grad(set_to_none=True)
559
+ global_step += 1
560
+ progress.update(1)
561
+
562
+ if accelerator.is_main_process:
563
+ try:
564
+ current_lr = optimizer.param_groups[0]["lr"]
565
+ except Exception:
566
+ current_lr = scheduler.get_last_lr()[0]
567
+
568
+ batch_losses.append(total_loss.detach().item())
569
+ batch_grads.append(float(grad_norm.detach().cpu().item()) if isinstance(grad_norm, torch.Tensor) else float(grad_norm))
570
+ for k, v in abs_losses.items():
571
+ track_losses[k].append(float(v.detach().item()))
572
+
573
+ if use_wandb and accelerator.sync_gradients:
574
+ log_dict = {
575
+ "total_loss": float(total_loss.detach().item()),
576
+ "learning_rate": current_lr,
577
+ "epoch": epoch,
578
+ "grad_norm": batch_grads[-1],
579
+ }
580
+ for k, v in abs_losses.items():
581
+ log_dict[f"loss_{k}"] = float(v.detach().item())
582
+ for k in coeffs:
583
+ log_dict[f"coeff_{k}"] = float(coeffs[k])
584
+ log_dict[f"median_{k}"] = float(meds[k])
585
+ wandb.log(log_dict, step=global_step)
586
+
587
+ if global_step > 0 and global_step % sample_interval == 0:
588
+ if accelerator.is_main_process:
589
+ generate_and_save_samples(global_step)
590
+ accelerator.wait_for_everyone()
591
+
592
+ n_micro = sample_interval * gradient_accumulation_steps
593
+ avg_loss = float(np.mean(batch_losses[-n_micro:])) if len(batch_losses) >= n_micro else float(np.mean(batch_losses)) if batch_losses else float("nan")
594
+ avg_grad = float(np.mean(batch_grads[-n_micro:])) if len(batch_grads) >= 1 else float(np.mean(batch_grads)) if batch_grads else 0.0
595
+
596
+ if accelerator.is_main_process:
597
+ print(f"Epoch {epoch} step {global_step} loss: {avg_loss:.6f}, grad_norm: {avg_grad:.6f}, lr: {current_lr:.9f}")
598
+ if save_model and avg_loss < min_loss * save_barrier:
599
+ min_loss = avg_loss
600
+ unwrapped = vae.module if hasattr(vae, "module") else vae
601
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
602
+ current_vae.save_pretrained(save_as)
603
+ if use_wandb:
604
+ wandb.log({"interm_loss": avg_loss, "interm_grad": avg_grad}, step=global_step)
605
+
606
+ if accelerator.is_main_process:
607
+ epoch_avg = float(np.mean(batch_losses)) if batch_losses else float("nan")
608
+ print(f"Epoch {epoch} done, avg loss {epoch_avg:.6f}")
609
+ if use_wandb:
610
+ wandb.log({"epoch_loss": epoch_avg, "epoch": epoch + 1}, step=global_step)
611
+
612
+ # --------------------------- Финальное сохранение ---------------------------
613
+ if accelerator.is_main_process:
614
+ print("Training finished – saving final model")
615
+ if save_model:
616
+ unwrapped = vae.module if hasattr(vae, "module") else vae
617
+ current_vae = getattr(unwrapped, "_orig_mod", unwrapped)
618
+ current_vae.save_pretrained(save_as)
619
+
620
+ accelerator.free_memory()
621
+ if torch.distributed.is_initialized():
622
+ torch.distributed.destroy_process_group()
623
+ print("Готово!")