recoilme committed
Commit 0c837de · 1 Parent(s): 6451318
result_grid.jpg CHANGED

Git LFS Details

  • SHA256: f52bf87a4ac6698eb77c01ff1af063274f13bd3db2691bfcaa92e25b3fd17d29
  • Pointer size: 132 Bytes
  • Size of remote file: 8.21 MB

Git LFS Details

  • SHA256: b64f037fd81a29ae8a2afeb6a5ad1935a28e7ebee82eab88b2eda83560c94f2f
  • Pointer size: 132 Bytes
  • Size of remote file: 8.14 MB
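The two "Git LFS Details" panels above describe the two versions (before and after) of result_grid.jpg. The repository itself stores only a small text pointer per version; the SHA256 and remote size shown come from that pointer. A minimal sketch of reading such a pointer — the helper name is made up for illustration, and the pointer format is the same three-line layout shown verbatim in the unet/config.json diff further down:

def parse_lfs_pointer(text: str) -> dict:
    """Split the three 'key value' lines of a Git LFS pointer file into a dict."""
    fields = dict(line.split(" ", 1) for line in text.strip().splitlines())
    return {
        "oid": fields["oid"].removeprefix("sha256:"),  # SHA256 of the remote blob
        "size": int(fields["size"]),                   # remote file size in bytes
    }

# Example using the unet/config.json pointer from this commit (see its diff below):
ptr = parse_lfs_pointer(
    "version https://git-lfs.github.com/spec/v1\n"
    "oid sha256:afc06beff07034f0ce9f671c83222e7f78eedc3b3ce93293143accdebef1b111\n"
    "size 1887"
)
assert ptr["size"] == 1887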
src/cherrypick.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
src/sample.ipynb CHANGED
The diff for this file is too large to render. See raw diff
 
train-Copy1.py DELETED
@@ -1,756 +0,0 @@
- import os
- import math
- import torch
- import numpy as np
- import matplotlib.pyplot as plt
- from torch.utils.data import DataLoader, Sampler
- from torch.utils.data.distributed import DistributedSampler
- from torch.optim.lr_scheduler import LambdaLR
- from collections import defaultdict
- from torch.optim.lr_scheduler import LambdaLR
- from diffusers import UNet2DConditionModel, AutoencoderKLWan, AutoencoderKL
- from accelerate import Accelerator
- from datasets import load_from_disk
- from tqdm import tqdm
- from PIL import Image, ImageOps
- import wandb
- import random
- import gc
- from accelerate.state import DistributedType
- from torch.distributed import broadcast_object_list
- from torch.utils.checkpoint import checkpoint
- from diffusers.models.attention_processor import AttnProcessor2_0
- from datetime import datetime
- import bitsandbytes as bnb
- import torch.nn.functional as F
- from collections import deque
-
- # --------------------------- Parameters ---------------------------
- ds_path = "/workspace/sdxs3d/datasets/mjnj"
- project = "unet"
- batch_size = 128
- base_learning_rate = 8e-5
- min_learning_rate = 2e-5
- num_epochs = 50
- # samples/save per epoch
- sample_interval_share = 5
- use_wandb = True
- use_comet_ml = False
- save_model = True
- use_decay = True
- fbp = False  # fused backward pass
- optimizer_type = "adam8bit"
- torch_compile = False
- unet_gradient = True
- clip_sample = False  # Scheduler
- fixed_seed = True
- shuffle = True
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"  # API key for Comet ML
- comet_ml_workspace = "recoilme"  # Comet ML workspace
- torch.backends.cuda.matmul.allow_tf32 = True
- torch.backends.cudnn.allow_tf32 = True
- torch.backends.cuda.enable_mem_efficient_sdp(False)
- dtype = torch.float32
- save_barrier = 1.01
- warmup_percent = 0.01
- percentile_clipping = 95  # 8bit optim
- betta2 = 0.97
- eps = 1e-6
- clip_grad_norm = 1.0
- steps_offset = 0  # Scheduler
- limit = 0
- checkpoints_folder = ""
- mixed_precision = "no"  # "fp16"
- gradient_accumulation_steps = 1
- accelerator = Accelerator(
-     mixed_precision=mixed_precision,
-     gradient_accumulation_steps=gradient_accumulation_steps
- )
- device = accelerator.device
-
- # Diffusion parameters
- n_diffusion_steps = 50
- samples_to_generate = 12
- guidance_scale = 4
-
- # Folders for saving results
- generated_folder = "samples"
- os.makedirs(generated_folder, exist_ok=True)
-
- # Seed setup for reproducibility
- current_date = datetime.now()
- seed = int(current_date.strftime("%Y%m%d"))
- if fixed_seed:
-     torch.manual_seed(seed)
-     np.random.seed(seed)
-     random.seed(seed)
-     if torch.cuda.is_available():
-         torch.cuda.manual_seed_all(seed)
-
- # --------------------------- LoRA parameters ---------------------------
- lora_name = ""
- lora_rank = 32
- lora_alpha = 64
-
- print("init")
-
- # --------------------------- Helper functions ---------------------------
- def sample_timesteps_bias(
-     batch_size: int,
-     progress: float,             # [0..1]
-     num_train_timesteps: int,    # usually 1000
-     steps_offset: int = 0,
-     device=None,
-     mode: str = "beta",          # "beta", "uniform"
- ) -> torch.Tensor:
-     """
-     Returns timesteps with a different bias:
-     - beta   : as before (shifted toward the start or the end depending on progress)
-     - normal : around the middle (Gaussian distribution)
-     - uniform: uniformly over all timesteps
-     """
-
-     max_idx = num_train_timesteps - 1 - steps_offset
-
-     if mode == "beta":
-         alpha = 1.0 + .5 * (1.0 - progress)
-         beta = 1.0 + .5 * progress
-         samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
-
-     elif mode == "uniform":
-         samples = torch.rand(batch_size)
-
-     else:
-         raise ValueError(f"Unknown mode: {mode}")
-
-     timesteps = steps_offset + (samples * max_idx).long().to(device)
-     return timesteps
-
- def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
-     normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
-
-     logit_normal_samples = torch.sigmoid(normal_samples)
-
-     return logit_normal_samples
-
- # --------------------------- WandB initialization ---------------------------
- if accelerator.is_main_process:
-     if use_wandb:
-         wandb.init(project=project+lora_name, config={
-             "batch_size": batch_size,
-             "base_learning_rate": base_learning_rate,
-             "num_epochs": num_epochs,
-             "fbp": fbp,
-             "optimizer_type": optimizer_type,
-         })
-     if use_comet_ml:
-         from comet_ml import Experiment
-         comet_experiment = Experiment(
-             api_key=comet_ml_api_key,
-             project_name=project,
-             workspace=comet_ml_workspace
-         )
-         # Log hyperparameters to Comet ML
-         hyper_params = {
-             "batch_size": batch_size,
-             "base_learning_rate": base_learning_rate,
-             "min_learning_rate": min_learning_rate,
-             "num_epochs": num_epochs,
-             "n_diffusion_steps": n_diffusion_steps,
-             "guidance_scale": guidance_scale,
-             "optimizer_type": optimizer_type,
-             "mixed_precision": mixed_precision,
-         }
-         comet_experiment.log_parameters(hyper_params)
-
- # Enable Flash Attention 2 / SDPA
- torch.backends.cuda.enable_flash_sdp(True)
- # --------------------------- Accelerator initialization --------------------
- gen = torch.Generator(device=device)
- gen.manual_seed(seed)
-
- # --------------------------- Model loading ---------------------------
- # The VAE is loaded on CPU to save GPU memory (as in the original code)
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
-
- shift_factor = getattr(vae.config, "shift_factor", 0.0)
- if shift_factor is None:
-     shift_factor = 0.0
-
- scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
- if scaling_factor is None:
-     scaling_factor = 1.0
-
- latents_mean = getattr(vae.config, "latents_mean", None)
- latents_std = getattr(vae.config, "latents_std", None)
-
- from diffusers import FlowMatchEulerDiscreteScheduler
-
- # Adjust to your own parameters
- num_train_timesteps = 1000
-
- scheduler = FlowMatchEulerDiscreteScheduler(
-     num_train_timesteps=num_train_timesteps,
-     # shift=3.0,  # example; tune if needed
-     # use_dynamic_shifting=True
- )
-
-
- class DistributedResolutionBatchSampler(Sampler):
-     def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
-         self.dataset = dataset
-         self.batch_size = max(1, batch_size // num_replicas)
-         self.num_replicas = num_replicas
-         self.rank = rank
-         self.shuffle = shuffle
-         self.drop_last = drop_last
-         self.epoch = 0
-
-         try:
-             widths = np.array(dataset["width"])
-             heights = np.array(dataset["height"])
-         except KeyError:
-             widths = np.zeros(len(dataset))
-             heights = np.zeros(len(dataset))
-
-         self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
-         self.size_groups = {}
-         for w, h in self.size_keys:
-             mask = (widths == w) & (heights == h)
-             self.size_groups[(w, h)] = np.where(mask)[0]
-
-         self.group_num_batches = {}
-         total_batches = 0
-         for size, indices in self.size_groups.items():
-             num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
-             self.group_num_batches[size] = num_full_batches
-             total_batches += num_full_batches
-
-         self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
-
-     def __iter__(self):
-         if torch.cuda.is_available():
-             torch.cuda.empty_cache()
-         all_batches = []
-         rng = np.random.RandomState(self.epoch)
-
-         for size, indices in self.size_groups.items():
-             indices = indices.copy()
-             if self.shuffle:
-                 rng.shuffle(indices)
-             num_full_batches = self.group_num_batches[size]
-             if num_full_batches == 0:
-                 continue
-             valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
-             batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
-             start_idx = self.rank * self.batch_size
-             end_idx = start_idx + self.batch_size
-             gpu_batches = batches[:, start_idx:end_idx]
-             all_batches.extend(gpu_batches)
-
-         if self.shuffle:
-             rng.shuffle(all_batches)
-         accelerator.wait_for_everyone()
-         return iter(all_batches)
-
-     def __len__(self):
-         return self.num_batches
-
-     def set_epoch(self, epoch):
-         self.epoch = epoch
-
- # Pick fixed samples per resolution
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
-     size_groups = defaultdict(list)
-     try:
-         widths = dataset["width"]
-         heights = dataset["height"]
-     except KeyError:
-         widths = [0] * len(dataset)
-         heights = [0] * len(dataset)
-     for i, (w, h) in enumerate(zip(widths, heights)):
-         size = (w, h)
-         size_groups[size].append(i)
-
-     fixed_samples = {}
-     for size, indices in size_groups.items():
-         n_samples = min(samples_per_group, len(indices))
-         if len(size_groups) == 1:
-             n_samples = samples_to_generate
-         if n_samples == 0:
-             continue
-         sample_indices = random.sample(indices, n_samples)
-         samples_data = [dataset[idx] for idx in sample_indices]
-         latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
-         embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device, dtype=dtype)
-         texts = [item["text"] for item in samples_data]
-         fixed_samples[size] = (latents, embeddings, texts)
-
-     print(f"Created {len(fixed_samples)} fixed-sample groups by resolution")
-     return fixed_samples
-
- if limit > 0:
-     dataset = load_from_disk(ds_path).select(range(limit))
- else:
-     dataset = load_from_disk(ds_path)
-
- def collate_fn_simple(batch):
-     latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device, dtype=dtype)
-     embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device, dtype=dtype)
-     return latents, embeddings
-
- batch_sampler = DistributedResolutionBatchSampler(
-     dataset=dataset,
-     batch_size=batch_size,
-     num_replicas=accelerator.num_processes,
-     rank=accelerator.process_index,
-     shuffle=shuffle
- )
-
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
- print("Total samples", len(dataloader))
- dataloader = accelerator.prepare(dataloader)
-
- start_epoch = 0
- global_step = 0
- total_training_steps = (len(dataloader) * num_epochs)
- world_size = accelerator.state.num_processes
-
- # Option to load the model from the latest checkpoint (if it exists)
- latest_checkpoint = os.path.join(checkpoints_folder, project)
- if os.path.isdir(latest_checkpoint):
-     print("Loading UNet from checkpoint:", latest_checkpoint)
-     unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
-     if torch_compile:
-         print("compiling")
-         torch.set_float32_matmul_precision('high')
-         unet = torch.compile(unet)
-         print("compiling - ok")
-     if unet_gradient:
-         unet.enable_gradient_checkpointing()
-     unet.set_use_memory_efficient_attention_xformers(False)
-     try:
-         unet.set_attn_processor(AttnProcessor2_0())
-     except Exception as e:
-         print(f"Error enabling SDPA: {e}")
-         unet.set_use_memory_efficient_attention_xformers(True)
-
- else:
-     # FIX: if there is no checkpoint, stop with a clear error (better than an unexpected NameError later)
-     raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Place a UNet checkpoint in {latest_checkpoint} or specify a different path.")
-
- if lora_name:
-     print(f"--- Configuring LoRA via PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
-     from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
-     from peft.tuners.lora import LoraModel
-     import os
-     unet.requires_grad_(False)
-     print("Base UNet parameters are frozen.")
-
-     lora_config = LoraConfig(
-         r=lora_rank,
-         lora_alpha=lora_alpha,
-         target_modules=["to_q", "to_k", "to_v", "to_out.0"],
-     )
-     unet.add_adapter(lora_config)
-
-     from peft import get_peft_model
-     peft_unet = get_peft_model(unet, lora_config)
-     params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
-
-     if accelerator.is_main_process:
-         lora_params_count = sum(p.numel() for p in params_to_optimize)
-         total_params_count = sum(p.numel() for p in unet.parameters())
-         print(f"Number of trainable parameters (LoRA): {lora_params_count:,}")
-         print(f"Total UNet parameters: {total_params_count:,}")
-
-     lora_save_path = os.path.join("lora", lora_name)
-     os.makedirs(lora_save_path, exist_ok=True)
-
-     def save_lora_checkpoint(model):
-         if accelerator.is_main_process:
-             print(f"Saving LoRA adapters to {lora_save_path}")
-             from peft.utils.save_and_load import get_peft_model_state_dict
-             lora_state_dict = get_peft_model_state_dict(model)
-             torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
-             model.peft_config["default"].save_pretrained(lora_save_path)
-             from diffusers import StableDiffusionXLPipeline
-             StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
-
- # --------------------------- Optimizer ---------------------------
- if lora_name:
-     trainable_params = [p for p in unet.parameters() if p.requires_grad]
- else:
-     if fbp:
-         trainable_params = list(unet.parameters())
-
- def create_optimizer(name, params):
-     if name == "adam8bit":
-         return bnb.optim.AdamW8bit(
-             params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
-             percentile_clipping=percentile_clipping
-         )
-     elif name == "adam":
-         return torch.optim.AdamW(
-             params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
-         )
-     else:
-         raise ValueError(f"Unknown optimizer: {name}")
-
- if fbp:
-     optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
-     def optimizer_hook(param):
-         optimizer_dict[param].step()
-         optimizer_dict[param].zero_grad(set_to_none=True)
-     for param in trainable_params:
-         param.register_post_accumulate_grad_hook(optimizer_hook)
-     unet, optimizer = accelerator.prepare(unet, optimizer_dict)
- else:
-     optimizer = create_optimizer(optimizer_type, unet.parameters())
-
- def lr_schedule(step):
-     x = step / (total_training_steps * world_size)
-     warmup = warmup_percent
-     if not use_decay:
-         return base_learning_rate
-     if x < warmup:
-         return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
-     decay_ratio = (x - warmup) / (1 - warmup)
-     return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
-         (1 + math.cos(math.pi * decay_ratio))
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
-
- num_params = sum(p.numel() for p in unet.parameters())
- print(f"[rank {accelerator.process_index}] total params: {num_params}")
- for name, param in unet.named_parameters():
-     if torch.isnan(param).any() or torch.isinf(param).any():
-         print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
-
- # --------------------------- Fixed samples for generation ---------------------------
- fixed_samples = get_fixed_samples_by_resolution(dataset)
-
- def get_negative_embedding(neg_prompt="", batch_size=1):
-     """
-     Returns the negative-prompt embedding with a batch dimension.
-     Loads the models, computes the embedding, then offloads the models to CPU.
-     """
-     import torch
-     from transformers import AutoTokenizer, AutoModel
-
-     # Settings
-     dtype = torch.float16
-     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
-     # Load the models (if not loaded yet)
-     if not hasattr(get_negative_embedding, "tokenizer"):
-         get_negative_embedding.tokenizer = AutoTokenizer.from_pretrained(
-             "Qwen/Qwen3-Embedding-0.6B", padding_side="left"
-         )
-         get_negative_embedding.text_model = AutoModel.from_pretrained(
-             "Qwen/Qwen3-Embedding-0.6B"
-         ).to(device).eval()
-
-     # Pooling helper
-     def last_token_pool(last_hidden_states, attention_mask):
-         left_padding = (attention_mask[:, -1].sum() == attention_mask.shape[0])
-         if left_padding:
-             return last_hidden_states[:, -1]
-         else:
-             sequence_lengths = attention_mask.sum(dim=1) - 1
-             batch_size = last_hidden_states.shape[0]
-             return last_hidden_states[
-                 torch.arange(batch_size, device=last_hidden_states.device),
-                 sequence_lengths
-             ]
-
-     # Compute the embedding
-     def encode_texts(texts, max_length=512):
-         with torch.inference_mode():
-             toks = get_negative_embedding.tokenizer(
-                 texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
-             ).to(device)
-             outs = get_negative_embedding.text_model(**toks)
-             emb = last_token_pool(outs.last_hidden_state, toks["attention_mask"])
-             emb = emb.unsqueeze(1)  # add a sequence_length dimension
-             return emb
-
-     # Return the embedding
-     if not neg_prompt:
-         hidden_dim = 1024  # Qwen3-Embedding-0.6B embedding dimension
-         return torch.zeros((batch_size, 1, hidden_dim), dtype=dtype, device=device)
-
-     uncond_emb = encode_texts([neg_prompt]).to(dtype=dtype, device=device)
-     uncond_emb = uncond_emb.repeat(batch_size, 1, 1)  # expand to the batch
-
-     # Offload the models
-     if hasattr(get_negative_embedding, "text_model"):
-         get_negative_embedding.text_model = get_negative_embedding.text_model.to("cpu")
-     if hasattr(get_negative_embedding, "tokenizer"):
-         del get_negative_embedding.tokenizer  # free memory
-     torch.cuda.empty_cache()
-
-     return uncond_emb
-
- uncond_emb = get_negative_embedding("low quality")
-
- @torch.compiler.disable()
- @torch.no_grad()
- def generate_and_save_samples(fixed_samples_cpu, empty_embeddings, step):
-     original_model = None
-     try:
-         original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
-         vae.to(device=device).eval()  # temporarily move the VAE to GPU for decoding
-
-         all_generated_images = []
-         all_captions = []
-
-         for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
-             width, height = size
-             sample_latents = sample_latents.to(dtype=dtype, device=device)
-             sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
-
-             # initial noise
-             latents = torch.randn(
-                 sample_latents.shape,
-                 device=device,
-                 dtype=sample_latents.dtype,
-                 generator=torch.Generator(device=device).manual_seed(seed)
-             )
-
-             # prepare timesteps via the scheduler
-             scheduler.set_timesteps(n_diffusion_steps, device=device)
-
-             for t in scheduler.timesteps:
-                 # guidance: double the batch
-                 if guidance_scale != 1:
-                     latent_model_input = torch.cat([latents, latents], dim=0)
-                     text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
-                 else:
-                     latent_model_input = latents
-                     text_embeddings_batch = sample_text_embeddings
-
-                 # flow (velocity) prediction
-                 model_out = original_model(latent_model_input, t, encoder_hidden_states=text_embeddings_batch)
-                 flow = getattr(model_out, "sample", model_out)
-
-                 # combine guidance
-                 if guidance_scale != 1:
-                     flow_uncond, flow_cond = flow.chunk(2)
-                     flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
-
-                 # scheduler step
-                 latents = scheduler.step(flow, t, latents).prev_sample
-
-             current_latents = latents
-
-             # normalization parameters
-             latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
-
-             decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
-             # decoded = decoded[:, :, 0, :, :]  # [3, H, W]
-             # print(decoded.ndim, decoded.shape)
-
-             decoded_fp32 = decoded.to(torch.float32)
-             for img_idx, img_tensor in enumerate(decoded_fp32):
-
-                 # Shape: [3, H, W] -> convert to [H, W, 3]
-                 img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
-                 img = img.transpose(1, 2, 0)  # from [3, H, W] to [H, W, 3]
-
-                 # img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
-                 if np.isnan(img).any():
-                     print("NaNs found, saving stopped! Step:", step)
-                 pil_img = Image.fromarray((img * 255).astype("uint8"))
-
-                 max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
-                 max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
-                 max_w_overall = max(255, max_w_overall)
-                 max_h_overall = max(255, max_h_overall)
-
-                 padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
-                 all_generated_images.append(padded_img)
-
-                 caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
-                 all_captions.append(caption_text)
-
-                 sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
-                 pil_img.save(sample_path, "JPEG", quality=96)
-
-         if use_wandb and accelerator.is_main_process:
-             wandb_images = [
-                 wandb.Image(img, caption=f"{all_captions[i]}")
-                 for i, img in enumerate(all_generated_images)
-             ]
-             wandb.log({"generated_images": wandb_images})
-         if use_comet_ml and accelerator.is_main_process:
-             for i, img in enumerate(all_generated_images):
-                 comet_experiment.log_image(
-                     image_data=img,
-                     name=f"step_{step}_img_{i}",
-                     step=step,
-                     metadata={
-                         "caption": all_captions[i],
-                         "width": img.width,
-                         "height": img.height,
-                         "global_step": step
-                     }
-                 )
-     finally:
-         # move the VAE back to CPU (as in the original code)
-         vae.to("cpu")
-         for var in list(locals().keys()):
-             if isinstance(locals()[var], torch.Tensor):
-                 del locals()[var]
-         torch.cuda.empty_cache()
-         gc.collect()
-
- # --------------------------- Sample generation before training ---------------------------
- if accelerator.is_main_process:
-     if save_model:
-         print("Generating samples before training starts...")
-         generate_and_save_samples(fixed_samples, uncond_emb, 0)
- accelerator.wait_for_everyone()
-
- # Modify the model-saving function to support LoRA
- def save_checkpoint(unet, variant=""):
-     if accelerator.is_main_process:
-         if lora_name:
-             save_lora_checkpoint(unet)
-         else:
-             if variant != "":
-                 accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"), variant=variant)
-             else:
-                 accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
-             unet = unet.to(dtype=dtype)
-
- # --------------------------- Training loop ---------------------------
- if accelerator.is_main_process:
-     print(f"Total steps per GPU: {total_training_steps}")
-
- epoch_loss_points = []
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
-
- steps_per_epoch = len(dataloader)
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
- min_loss = 2.0
-
- for epoch in range(start_epoch, start_epoch + num_epochs):
-     batch_losses = []
-     batch_grads = []
-     batch_sampler.set_epoch(epoch)
-     accelerator.wait_for_everyone()
-     unet.train()
-     # print("epoch:", epoch)
-     for step, (latents, embeddings) in enumerate(dataloader):
-         with accelerator.accumulate(unet):
-             if save_model == False and step == 5:
-                 used_gb = torch.cuda.max_memory_allocated() / 1024**3
-                 print(f"Step {step}: {used_gb:.2f} GB")
-
-             # noise
-             noise = torch.randn_like(latents, dtype=latents.dtype)
-
-             # sample t from [0, 1]
-             t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
-
-             # interpolate between x0 and the noise
-             noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
-
-             # make integer timesteps for the UNet
-             timesteps = (t * scheduler.config.num_train_timesteps).long()
-
-             # flow prediction
-             model_pred = unet(noisy_latents, timesteps, embeddings).sample
-
-             # target is the vector field (= difference between the endpoints)
-             target = noise - latents  # or latents - noise?
-
-             # MSE loss
-             mse_loss = F.mse_loss(model_pred.float(), target.float())
-
-             # Save for logging (we keep the MSE separately as a metric)
-             batch_losses.append(mse_loss.detach().item())
-
-             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
-                 accelerator.wait_for_everyone()
-
-             # Backward
-             accelerator.backward(mse_loss)
-
-             if (global_step % 100 == 0) or (global_step % sample_interval == 0):
-                 accelerator.wait_for_everyone()
-
-             grad = 0.0
-             if not fbp:
-                 if accelerator.sync_gradients:
-                     with torch.amp.autocast('cuda', enabled=False):
-                         grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
-                         grad = float(grad_val)
-                 optimizer.step()
-                 lr_scheduler.step()
-                 optimizer.zero_grad(set_to_none=True)
-
-             if accelerator.sync_gradients:
-                 global_step += 1
-                 progress_bar.update(1)
-             # Log metrics
-             if accelerator.is_main_process:
-                 if fbp:
-                     current_lr = base_learning_rate
-                 else:
-                     current_lr = lr_scheduler.get_last_lr()[0]
-                 batch_grads.append(grad)
-
-                 log_data = {}
-                 log_data["loss"] = mse_loss.detach().item()
-                 log_data["lr"] = current_lr
-                 log_data["grad"] = grad
-                 if accelerator.sync_gradients:
-                     if use_wandb:
-                         wandb.log(log_data, step=global_step)
-                     if use_comet_ml:
-                         comet_experiment.log_metrics(log_data, step=global_step)
-
-                 # Generate samples at the configured interval
-                 if global_step % sample_interval == 0:
-                     generate_and_save_samples(fixed_samples, uncond_emb, global_step)
-                     last_n = sample_interval
-
-                     if save_model:
-                         avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
-                         print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
-                         if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
-                             min_loss = avg_sample_loss
-                             save_checkpoint(unet)
-
-
-     if accelerator.is_main_process:
-         # local averages
-         avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
-         avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
-
-         print(f"\nEpoch {epoch} finished. Average loss: {avg_epoch_loss:.6f}")
-         log_data_ep = {
-             "epoch_loss": avg_epoch_loss,
-             "epoch_grad": avg_epoch_grad,
-             "epoch": epoch + 1,
-         }
-         if use_wandb:
-             wandb.log(log_data_ep)
-         if use_comet_ml:
-             comet_experiment.log_metrics(log_data_ep)
-
- # Training finished - save the final model
- if accelerator.is_main_process:
-     print("Training finished! Saving the final model...")
-     if save_model:
-         save_checkpoint(unet, "fp16")
-     if use_comet_ml:
-         comet_experiment.end()
- accelerator.free_memory()
- if torch.distributed.is_initialized():
-     torch.distributed.destroy_process_group()
-
- print("Done!")
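The deleted train-Copy1.py trained a UNet2DConditionModel with a flow-matching objective: latents are linearly interpolated with Gaussian noise at a random t in [0, 1], and the network regresses the constant velocity field noise - latents. The core of that training step, extracted as a standalone sketch (the unet, latents, and embeddings arguments stand in for the script's prepared objects):

import torch
import torch.nn.functional as F

def flow_matching_loss(unet, latents, embeddings, num_train_timesteps=1000):
    noise = torch.randn_like(latents)
    # t=0 is clean data, t=1 is pure noise
    t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
    x_t = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
    timesteps = (t * num_train_timesteps).long()  # integer timesteps for the UNet
    pred = unet(x_t, timesteps, embeddings).sample
    # d/dt x_t = noise - latents, so that is the regression target
    return F.mse_loss(pred.float(), (noise - latents).float())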
train.py CHANGED
@@ -28,12 +28,12 @@ from collections import deque
 # --------------------------- Parameters ---------------------------
 ds_path = "/workspace/sdxs3d/datasets/mjnj"
 project = "unet"
-batch_size = 64
-base_learning_rate = 4e-5
-min_learning_rate = 8e-6
-num_epochs = 50
+batch_size = 32
+base_learning_rate = 2e-5
+min_learning_rate = 6e-6
+num_epochs = 300
 # samples/save per epoch
-sample_interval_share = 5
+sample_interval_share = 2
 use_wandb = True
 use_comet_ml = False
 save_model = True
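The new configuration halves the batch size and scales the base learning rate down with it, consistent with the linear LR scaling heuristic (the minimum LR drops less aggressively, 8e-6 to 6e-6, training is stretched from 50 to 300 epochs, and samples/saves happen twice per epoch instead of five times). A quick check of the proportionality, using names that mirror the config variables:

old_batch_size, old_base_lr = 64, 4e-5
new_batch_size = 32
print(old_base_lr * new_batch_size / old_batch_size)  # 2e-05, the new base_learning_rate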
unet/config.json CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:0ef8fbaff98c8d479d68b566d07ef4fb8e51ac26b9e8b5a3cb2b23f9a978f6ca
-size 1874
+oid sha256:afc06beff07034f0ce9f671c83222e7f78eedc3b3ce93293143accdebef1b111
+size 1887
unet/diffusion_pytorch_model.fp16.safetensors DELETED
@@ -1,3 +0,0 @@
-version https://git-lfs.github.com/spec/v1
-oid sha256:8b32a7de45f43802f7e6248deb67b0b1dc2c94623e8a98d13b5b4e50ba84169b
-size 3092571208
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:1b8104061a32712d5c148806b79d64ebca9b2254298aa2e149b541c35496c1f8
-size 6184944280
+oid sha256:613da2e12017b414266fb0d52161fecfe74c073be40cb098093449c481582895
+size 3092571208
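The fp16 safetensors variant is deleted and the main weight file shrinks from 6184944280 to 3092571208 bytes, exactly the size of the removed fp16 file, which suggests the half-precision weights now ship as the primary checkpoint (the deleted script also saved its final checkpoint with variant "fp16"). A sketch, not part of the repo, for verifying a fetched file against its LFS pointer's oid:

import hashlib

def sha256_of(path: str, chunk_size: int = 1 << 20) -> str:
    """Stream the file in 1 MiB chunks and return its hex SHA256 digest."""
    h = hashlib.sha256()
    with open(path, "rb") as f:
        for chunk in iter(lambda: f.read(chunk_size), b""):
            h.update(chunk)
    return h.hexdigest()

# After this commit, the digest of unet/diffusion_pytorch_model.safetensors should be
# 613da2e12017b414266fb0d52161fecfe74c073be40cb098093449c481582895 (the pointer's oid).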