recoilme committed
Commit f08d8ce · 1 Parent(s): 5170729
pipeline_sdxs-Copy1.py DELETED
@@ -1,190 +0,0 @@
1
- from diffusers import DiffusionPipeline
2
- import torch
3
- from diffusers.utils import BaseOutput
4
- from dataclasses import dataclass
5
- from typing import List, Union, Optional
6
- from PIL import Image
7
- import numpy as np
8
- from tqdm import tqdm
9
-
10
- @dataclass
11
- class SdxsPipelineOutput(BaseOutput):
12
- images: Union[List[Image.Image], np.ndarray]
13
-
14
- class SdxsPipeline(DiffusionPipeline):
15
- def __init__(self, vae, text_encoder, tokenizer, unet, scheduler, max_length: int = 150):
16
- super().__init__()
17
- self.register_modules(
18
- vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
19
- unet=unet, scheduler=scheduler
20
- )
21
- self.vae_scale_factor = 16
22
- self.max_length = max_length
23
-
24
- def encode_prompt(self, prompt=None, negative_prompt=None, device=None, dtype=None):
25
- device = device or self.device
26
- dtype = dtype or next(self.unet.parameters()).dtype
27
-
28
- # Convert to lists
29
- if isinstance(prompt, str):
30
- prompt = [prompt]
31
- if isinstance(negative_prompt, str):
32
- negative_prompt = [negative_prompt]
33
-
34
- # If no prompts are given, use empty (zero) embeddings
35
- if prompt is None and negative_prompt is None:
36
- hidden_dim = 1024 # embedding dimension of Qwen3-0.6B
37
- seq_len = 150
38
- batch_size = 1
39
- return torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
40
-
41
- # Tokenization with fixed max_length=150 and padding="max_length"
42
- def encode_texts(texts, max_length=150):
43
- with torch.no_grad():
44
- toks = self.tokenizer(
45
- texts,
46
- return_tensors="pt",
47
- padding="max_length",
48
- truncation=True,
49
- max_length=max_length
50
- ).to(device)
51
- outs = self.text_encoder(**toks, output_hidden_states=True, return_dict=True)
52
- hidden = outs.hidden_states[-1]
53
- mask = toks["attention_mask"].unsqueeze(-1) # (B, L, 1)
54
-
55
- # 3. Zero-pad embeddings for pad tokens
56
- hidden = hidden * mask
57
-
58
- return hidden
59
-
60
- # Encode positive and negative prompts
61
- pos_embeddings = encode_texts(prompt) if prompt is not None else None
62
- neg_embeddings = encode_texts(negative_prompt) if negative_prompt is not None else None
63
-
64
- # Align batch sizes
65
- batch_size = max(
66
- pos_embeddings.shape[0] if pos_embeddings is not None else 0,
67
- neg_embeddings.shape[0] if neg_embeddings is not None else 0
68
- )
69
-
70
- # Repeat embeddings to match batch_size
71
- if pos_embeddings is not None and pos_embeddings.shape[0] < batch_size:
72
- pos_embeddings = pos_embeddings.repeat(batch_size, 1, 1)
73
- if neg_embeddings is not None and neg_embeddings.shape[0] < batch_size:
74
- neg_embeddings = neg_embeddings.repeat(batch_size, 1, 1)
75
-
76
- # Concatenate for classifier-free guidance
77
- if pos_embeddings is not None and neg_embeddings is not None:
78
- text_embeddings = torch.cat([neg_embeddings, pos_embeddings], dim=0)
79
- elif pos_embeddings is not None:
80
- text_embeddings = pos_embeddings
81
- else:
82
- text_embeddings = neg_embeddings
83
-
84
- return text_embeddings.to(device=device, dtype=dtype)
85
-
86
-
87
- @torch.no_grad()
88
- def generate_latents(
89
- self,
90
- text_embeddings,
91
- height: int = 1280,
92
- width: int = 1024,
93
- num_inference_steps: int = 40,
94
- guidance_scale: float = 4.0,
95
- latent_channels: int = 16,
96
- batch_size: int = 1,
97
- generator=None,
98
- ):
99
- device = self.device
100
- dtype = next(self.unet.parameters()).dtype
101
-
102
- self.scheduler.set_timesteps(num_inference_steps, device=device)
103
-
104
- # Split embeddings into conditional and unconditional
105
- if guidance_scale > 1:
106
- neg_embeds, pos_embeds = text_embeddings.chunk(2)
107
- # Repeat if batch_size is larger
108
- if batch_size > pos_embeds.shape[0]:
109
- pos_embeds = pos_embeds.repeat(batch_size, 1, 1)
110
- neg_embeds = neg_embeds.repeat(batch_size, 1, 1)
111
- text_embeddings = torch.cat([neg_embeds, pos_embeds], dim=0)
112
- else:
113
- text_embeddings = text_embeddings.repeat(batch_size, 1, 1)
114
-
115
- # Initialize latents
116
- latent_shape = (
117
- batch_size,
118
- latent_channels,
119
- height // self.vae_scale_factor,
120
- width // self.vae_scale_factor
121
- )
122
- latents = torch.randn(latent_shape, device=device, dtype=dtype, generator=generator)
123
-
124
- # Diffusion loop
125
- for t in tqdm(self.scheduler.timesteps, desc="Генерация"):
126
- latent_input = torch.cat([latents, latents], dim=0) if guidance_scale > 1 else latents
127
- noise_pred = self.unet(latent_input, t, text_embeddings).sample
128
-
129
- if guidance_scale > 1:
130
- noise_uncond, noise_text = noise_pred.chunk(2)
131
- noise_pred = noise_uncond + guidance_scale * (noise_text - noise_uncond)
132
-
133
- latents = self.scheduler.step(noise_pred, t, latents).prev_sample
134
-
135
- return latents
136
-
137
-
138
- def decode_latents(self, latents, output_type="pil"):
139
- """Декодирование латентов в изображения."""
140
- latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
141
- with torch.no_grad():
142
- images = self.vae.decode(latents).sample
143
- images = (images / 2 + 0.5).clamp(0, 1)
144
-
145
- if output_type == "pil":
146
- images = images.cpu().permute(0, 2, 3, 1).float().numpy()
147
- images = (images * 255).round().astype("uint8")
148
- return [Image.fromarray(image) for image in images]
149
- return images.cpu().permute(0, 2, 3, 1).float().numpy()
150
-
151
- @torch.no_grad()
152
- def __call__(
153
- self,
154
- prompt: Optional[Union[str, List[str]]] = None,
155
- height: int = 1280,
156
- width: int = 1024,
157
- num_inference_steps: int = 40,
158
- guidance_scale: float = 4.0,
159
- latent_channels: int = 16,
160
- output_type: str = "pil",
161
- return_dict: bool = True,
162
- batch_size: int = 1,
163
- seed: Optional[int] = None,
164
- negative_prompt: Optional[Union[str, List[str]]] = None,
165
- text_embeddings: Optional[torch.FloatTensor] = None,
166
- ):
167
- device = self.device
168
- generator = torch.Generator(device=device).manual_seed(seed) if seed is not None else None
169
-
170
- if text_embeddings is None:
171
- if prompt is None and negative_prompt is None:
172
- raise ValueError("Необходимо указать prompt, negative_prompt или text_embeddings")
173
- text_embeddings = self.encode_prompt(prompt, negative_prompt, device=device, dtype=next(self.unet.parameters()).dtype)
174
-
175
- # text_embeddings already has shape [B_uncond + B_cond, seq_len, hid]; dtype and device are compatible
176
- latents = self.generate_latents(
177
- text_embeddings=text_embeddings,
178
- height=height,
179
- width=width,
180
- num_inference_steps=num_inference_steps,
181
- guidance_scale=guidance_scale,
182
- latent_channels=latent_channels,
183
- batch_size=batch_size,
184
- generator=generator
185
- )
186
-
187
- images = self.decode_latents(latents, output_type=output_type)
188
- if not return_dict:
189
- return images
190
- return SdxsPipelineOutput(images=images)
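For reference, a minimal sketch of how the SdxsPipeline removed above would be invoked; the checkpoint path, prompt, and device are illustrative assumptions, not values taken from this commit:

import torch
from pipeline_sdxs import SdxsPipeline  # hypothetical import path for the class above

# Hypothetical checkpoint directory with vae/text_encoder/tokenizer/unet/scheduler subfolders.
pipe = SdxsPipeline.from_pretrained("path/to/sdxs-checkpoint", torch_dtype=torch.float16).to("cuda")

result = pipe(
    prompt="a watercolor fox in a snowy forest",
    negative_prompt="low quality",
    height=1280,
    width=1024,
    num_inference_steps=40,
    guidance_scale=4.0,
    seed=42,
)
result.images[0].save("sample.jpg")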
dataset_new.py → src/dataset_new.py RENAMED
File without changes
dataset_old.py → src/dataset_old.py RENAMED
File without changes
pipeline_sdxs_no_pooling.py → src/pipeline_sdxs_no_pooling.py RENAMED
File without changes
pooling.py → src/pooling.py RENAMED
File without changes
tokenize.ipynb → src/tokenize.ipynb RENAMED
File without changes
train_no_pooling.py → src/train_no_pooling.py RENAMED
File without changes
test.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:907963d0ec5739bde474ea18ebd1f843e42419ead2ed5c57a64ddf363058efd3
3
  size 6169057
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a742a3da3ff7d38f7f98b059ce105cedaf65730d1d16a5801c0aa196b74f4069
3
  size 6169057
train-Copy1.py DELETED
@@ -1,771 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
- from accelerate import Accelerator
13
- from datasets import load_from_disk
14
- from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
- import wandb
17
- import random
18
- import gc
19
- from accelerate.state import DistributedType
20
- from torch.distributed import broadcast_object_list
21
- from torch.utils.checkpoint import checkpoint
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from datetime import datetime
24
- import bitsandbytes as bnb
25
- import torch.nn.functional as F
26
- from collections import deque
27
-
28
- # --------------------------- Parameters ---------------------------
29
- ds_path = "/workspace/sdxs/datasets/640"
30
- project = "unet"
31
- batch_size = 48
32
- base_learning_rate = 4e-5
33
- min_learning_rate = 2e-5
34
- num_epochs = 50
35
- # samples/save per epoch
36
- sample_interval_share = 5
37
- use_wandb = True
38
- use_comet_ml = False
39
- save_model = True
40
- use_decay = True
41
- fbp = False # fused backward pass
42
- optimizer_type = "adam8bit"
43
- torch_compile = False
44
- unet_gradient = True
45
- clip_sample = False #Scheduler
46
- fixed_seed = True
47
- shuffle = True
48
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Comet ML API key
49
- comet_ml_workspace = "recoilme" # Comet ML workspace
50
- torch.backends.cuda.matmul.allow_tf32 = True
51
- torch.backends.cudnn.allow_tf32 = True
52
- torch.backends.cuda.enable_mem_efficient_sdp(False)
53
- dtype = torch.float32
54
- save_barrier = 1.006
55
- warmup_percent = 0.005
56
- percentile_clipping = 99 # 8bit optim
57
- betta2 = 0.99
58
- eps = 1e-8
59
- clip_grad_norm = 1.0
60
- steps_offset = 0 # Scheduler
61
- limit = 0
62
- checkpoints_folder = ""
63
- mixed_precision = "no" #"fp16"
64
- gradient_accumulation_steps = 1
65
- accelerator = Accelerator(
66
- mixed_precision=mixed_precision,
67
- gradient_accumulation_steps=gradient_accumulation_steps
68
- )
69
- device = accelerator.device
70
-
71
- # Diffusion parameters
72
- n_diffusion_steps = 50
73
- samples_to_generate = 12
74
- guidance_scale = 4
75
-
76
- # Folders for saving results
77
- generated_folder = "samples"
78
- os.makedirs(generated_folder, exist_ok=True)
79
-
80
- # Seed setup for reproducibility
81
- current_date = datetime.now()
82
- seed = int(current_date.strftime("%Y%m%d"))
83
- if fixed_seed:
84
- torch.manual_seed(seed)
85
- np.random.seed(seed)
86
- random.seed(seed)
87
- if torch.cuda.is_available():
88
- torch.cuda.manual_seed_all(seed)
89
-
90
- # --------------------------- LoRA parameters ---------------------------
91
- lora_name = ""
92
- lora_rank = 32
93
- lora_alpha = 64
94
-
95
- print("init")
96
-
97
- # --------------------------- Helper functions ---------------------------
98
- def sample_timesteps_bias(
99
- batch_size: int,
100
- progress: float, # [0..1]
101
- num_train_timesteps: int, # usually 1000
102
- steps_offset: int = 0,
103
- device=None,
104
- mode: str = "beta", # "beta", "uniform"
105
- ) -> torch.Tensor:
106
- """
107
- Returns timesteps with different sampling biases:
109
- - beta : as before (shifted toward the start or the end depending on progress)
110
- - normal : around the middle (Gaussian distribution)
111
- - uniform: uniform over all timesteps
111
- """
112
-
113
- max_idx = num_train_timesteps - 1 - steps_offset
114
-
115
- if mode == "beta":
116
- alpha = 1.0 + .5 * (1.0 - progress)
117
- beta = 1.0 + .5 * progress
118
- samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
119
-
120
- elif mode == "uniform":
121
- samples = torch.rand(batch_size)
122
-
123
- else:
124
- raise ValueError(f"Unknown mode: {mode}")
125
-
126
- timesteps = steps_offset + (samples * max_idx).long().to(device)
127
- return timesteps
128
-
129
- def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
- normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
-
132
- logit_normal_samples = torch.sigmoid(normal_samples)
133
-
134
- return logit_normal_samples
135
-
136
- # --------------------------- WandB initialization ---------------------------
137
- if accelerator.is_main_process:
138
- if use_wandb:
139
- wandb.init(project=project+lora_name, config={
140
- "batch_size": batch_size,
141
- "base_learning_rate": base_learning_rate,
142
- "num_epochs": num_epochs,
143
- "fbp": fbp,
144
- "optimizer_type": optimizer_type,
145
- })
146
- if use_comet_ml:
147
- from comet_ml import Experiment
148
- comet_experiment = Experiment(
149
- api_key=comet_ml_api_key,
150
- project_name=project,
151
- workspace=comet_ml_workspace
152
- )
153
- # Log hyperparameters to Comet ML
154
- hyper_params = {
155
- "batch_size": batch_size,
156
- "base_learning_rate": base_learning_rate,
157
- "min_learning_rate": min_learning_rate,
158
- "num_epochs": num_epochs,
159
- "n_diffusion_steps": n_diffusion_steps,
160
- "guidance_scale": guidance_scale,
161
- "optimizer_type": optimizer_type,
162
- "mixed_precision": mixed_precision,
163
- }
164
- comet_experiment.log_parameters(hyper_params)
165
-
166
- # Enable Flash Attention 2 / SDPA
167
- torch.backends.cuda.enable_flash_sdp(True)
168
- # --------------------------- Accelerator initialization --------------------
169
- gen = torch.Generator(device=device)
170
- gen.manual_seed(seed)
171
-
172
- # --------------------------- Model loading ---------------------------
174
- # VAE is loaded on CPU to save GPU memory (as in the original code)
174
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
175
-
176
- shift_factor = getattr(vae.config, "shift_factor", 0.0)
177
- if shift_factor is None:
178
- shift_factor = 0.0
179
-
180
- scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
181
- if scaling_factor is None:
182
- scaling_factor = 1.0
183
-
184
- latents_mean = getattr(vae.config, "latents_mean", None)
185
- latents_std = getattr(vae.config, "latents_std", None)
186
-
187
- from diffusers import FlowMatchEulerDiscreteScheduler
188
-
189
- # Adjust to your own settings
190
- num_train_timesteps = 1000
191
-
192
- scheduler = FlowMatchEulerDiscreteScheduler(
193
- num_train_timesteps=num_train_timesteps,
194
- #shift=3.0, # example; tune if needed
195
- #use_dynamic_shifting=True
196
- )
197
-
198
-
199
- class DistributedResolutionBatchSampler(Sampler):
200
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
201
- self.dataset = dataset
202
- self.batch_size = max(1, batch_size // num_replicas)
203
- self.num_replicas = num_replicas
204
- self.rank = rank
205
- self.shuffle = shuffle
206
- self.drop_last = drop_last
207
- self.epoch = 0
208
-
209
- try:
210
- widths = np.array(dataset["width"])
211
- heights = np.array(dataset["height"])
212
- except KeyError:
213
- widths = np.zeros(len(dataset))
214
- heights = np.zeros(len(dataset))
215
-
216
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
217
- self.size_groups = {}
218
- for w, h in self.size_keys:
219
- mask = (widths == w) & (heights == h)
220
- self.size_groups[(w, h)] = np.where(mask)[0]
221
-
222
- self.group_num_batches = {}
223
- total_batches = 0
224
- for size, indices in self.size_groups.items():
225
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
226
- self.group_num_batches[size] = num_full_batches
227
- total_batches += num_full_batches
228
-
229
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
230
-
231
- def __iter__(self):
232
- if torch.cuda.is_available():
233
- torch.cuda.empty_cache()
234
- all_batches = []
235
- rng = np.random.RandomState(self.epoch)
236
-
237
- for size, indices in self.size_groups.items():
238
- indices = indices.copy()
239
- if self.shuffle:
240
- rng.shuffle(indices)
241
- num_full_batches = self.group_num_batches[size]
242
- if num_full_batches == 0:
243
- continue
244
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
245
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
246
- start_idx = self.rank * self.batch_size
247
- end_idx = start_idx + self.batch_size
248
- gpu_batches = batches[:, start_idx:end_idx]
249
- all_batches.extend(gpu_batches)
250
-
251
- if self.shuffle:
252
- rng.shuffle(all_batches)
253
- accelerator.wait_for_everyone()
254
- return iter(all_batches)
255
-
256
- def __len__(self):
257
- return self.num_batches
258
-
259
- def set_epoch(self, epoch):
260
- self.epoch = epoch
261
-
262
- # Select fixed samples per resolution
263
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
264
- size_groups = defaultdict(list)
265
- try:
266
- widths = dataset["width"]
267
- heights = dataset["height"]
268
- except KeyError:
269
- widths = [0] * len(dataset)
270
- heights = [0] * len(dataset)
271
- for i, (w, h) in enumerate(zip(widths, heights)):
272
- size = (w, h)
273
- size_groups[size].append(i)
274
-
275
- fixed_samples = {}
276
- for size, indices in size_groups.items():
277
- n_samples = min(samples_per_group, len(indices))
278
- if len(size_groups)==1:
279
- n_samples = samples_to_generate
280
- if n_samples == 0:
281
- continue
282
- sample_indices = random.sample(indices, n_samples)
283
- samples_data = [dataset[idx] for idx in sample_indices]
284
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
285
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
286
- texts = [item["text"] for item in samples_data]
287
- fixed_samples[size] = (latents, embeddings, texts)
288
-
289
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
290
- return fixed_samples
291
-
292
- if limit > 0:
293
- dataset = load_from_disk(ds_path).select(range(limit))
294
- else:
295
- dataset = load_from_disk(ds_path)
296
-
297
- def collate_fn_simple(batch):
298
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
299
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
300
- return latents, embeddings
301
-
302
- batch_sampler = DistributedResolutionBatchSampler(
303
- dataset=dataset,
304
- batch_size=batch_size,
305
- num_replicas=accelerator.num_processes,
306
- rank=accelerator.process_index,
307
- shuffle=shuffle
308
- )
309
-
310
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
311
- print("Total samples",len(dataloader))
312
- dataloader = accelerator.prepare(dataloader)
313
-
314
- start_epoch = 0
315
- global_step = 0
316
- total_training_steps = (len(dataloader) * num_epochs)
317
- world_size = accelerator.state.num_processes
318
-
319
- # Optionally load the model from the latest checkpoint (if it exists)
320
- latest_checkpoint = os.path.join(checkpoints_folder, project)
321
- if os.path.isdir(latest_checkpoint):
322
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
323
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
324
- if unet_gradient:
325
- unet.enable_gradient_checkpointing()
326
- unet.set_use_memory_efficient_attention_xformers(False)
327
- try:
328
- unet.set_attn_processor(AttnProcessor2_0())
329
- except Exception as e:
330
- print(f"Ошибка при включении SDPA: {e}")
331
- unet.set_use_memory_efficient_attention_xformers(True)
332
-
333
- else:
334
- # FIX: if there is no checkpoint, stop with a clear error (better than an unexpected NameError later)
335
- raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
336
-
337
- if lora_name:
338
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
339
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
340
- from peft.tuners.lora import LoraModel
341
- import os
342
- unet.requires_grad_(False)
343
- print("Параметры базового UNet заморожены.")
344
-
345
- lora_config = LoraConfig(
346
- r=lora_rank,
347
- lora_alpha=lora_alpha,
348
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
349
- )
350
- unet.add_adapter(lora_config)
351
-
352
- from peft import get_peft_model
353
- peft_unet = get_peft_model(unet, lora_config)
354
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
355
-
356
- if accelerator.is_main_process:
357
- lora_params_count = sum(p.numel() for p in params_to_optimize)
358
- total_params_count = sum(p.numel() for p in unet.parameters())
359
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
360
- print(f"Общее количество параметров UNet: {total_params_count:,}")
361
-
362
- lora_save_path = os.path.join("lora", lora_name)
363
- os.makedirs(lora_save_path, exist_ok=True)
364
-
365
- def save_lora_checkpoint(model):
366
- if accelerator.is_main_process:
367
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
368
- from peft.utils.save_and_load import get_peft_model_state_dict
369
- lora_state_dict = get_peft_model_state_dict(model)
370
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
371
- model.peft_config["default"].save_pretrained(lora_save_path)
372
- from diffusers import StableDiffusionXLPipeline
373
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
374
-
375
- # --------------------------- Optimizer ---------------------------
376
- if lora_name:
377
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
378
- else:
379
- if fbp:
380
- trainable_params = list(unet.parameters())
381
-
382
- def create_optimizer(name, params):
383
- if name == "adam8bit":
384
- return bnb.optim.AdamW8bit(
385
- params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
386
- percentile_clipping=percentile_clipping
387
- )
388
- elif name == "adam":
389
- return torch.optim.AdamW(
390
- params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
391
- )
392
- else:
393
- raise ValueError(f"Unknown optimizer: {name}")
394
-
395
- if fbp:
396
- optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
397
- def optimizer_hook(param):
398
- optimizer_dict[param].step()
399
- optimizer_dict[param].zero_grad(set_to_none=True)
400
- for param in trainable_params:
401
- param.register_post_accumulate_grad_hook(optimizer_hook)
402
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
- else:
404
- optimizer = create_optimizer(optimizer_type, unet.parameters())
405
- def lr_schedule(step):
406
- x = step / (total_training_steps * world_size)
407
- warmup = warmup_percent
408
- if not use_decay:
409
- return base_learning_rate
410
- if x < warmup:
411
- return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
412
- decay_ratio = (x - warmup) / (1 - warmup)
413
- return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
414
- (1 + math.cos(math.pi * decay_ratio))
415
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
416
-
417
- num_params = sum(p.numel() for p in unet.parameters())
418
- print(f"[rank {accelerator.process_index}] total params: {num_params}")
419
- for name, param in unet.named_parameters():
420
- if torch.isnan(param).any() or torch.isinf(param).any():
421
- print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
422
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
423
-
424
- if torch_compile:
425
- print("compiling")
426
- torch.set_float32_matmul_precision('high')
427
- torch.backends.cudnn.allow_tf32 = True
428
- torch.backends.cuda.matmul.allow_tf32 = True
429
- unet = torch.compile(unet)#, mode='max-autotune')
430
- print("compiling - ok")
431
-
432
- # --------------------------- Fixed samples for generation ---------------------------
433
- fixed_samples = get_fixed_samples_by_resolution(dataset)
434
-
435
- def get_negative_embedding(neg_prompt="", batch_size=1):
436
- """
437
- Returns the negative-prompt embedding with a batch dimension.
438
- Loads the models, computes the embedding, and offloads the models back to CPU.
439
- """
440
- import torch
441
- from transformers import AutoTokenizer, AutoModel
442
-
443
- # Settings
444
- dtype = torch.float16
445
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
446
-
447
- # Load the models (if not loaded yet)
448
- if not hasattr(get_negative_embedding, "tokenizer"):
449
- get_negative_embedding.tokenizer = AutoTokenizer.from_pretrained(
450
- "Qwen/Qwen3-0.6B"
451
- )
452
- get_negative_embedding.text_model = AutoModel.from_pretrained(
453
- "Qwen/Qwen3-0.6B"
454
- ).to(device).eval()
455
-
456
- # Compute the embedding
457
- def encode_texts(texts, max_length=150):
458
- with torch.inference_mode():
459
- toks = get_negative_embedding.tokenizer(
460
- texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
461
- ).to(device)
462
-
463
- outs = get_negative_embedding.text_model(**toks, output_hidden_states=True)
464
- hidden_states = outs.hidden_states[-1] # [B, L, D]
465
- return hidden_states
466
-
467
- # Return the embedding
468
- if not neg_prompt:
469
- hidden_dim = 1024 # embedding dimension of Qwen3-Embedding-0.6B
470
- seq_len = 150
471
- return torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
472
-
473
- uncond_emb = encode_texts([neg_prompt]).to(dtype=dtype, device=device)
474
- uncond_emb = uncond_emb.repeat(batch_size, 1, 1) # Добавляем батч
475
-
476
- # Offload the models
477
- if hasattr(get_negative_embedding, "text_model"):
478
- get_negative_embedding.text_model = get_negative_embedding.text_model.to("cpu")
479
- if hasattr(get_negative_embedding, "tokenizer"):
480
- del get_negative_embedding.tokenizer # free memory
481
- torch.cuda.empty_cache()
482
-
483
- return uncond_emb
484
-
485
- uncond_emb = get_negative_embedding("low quality")
486
-
487
- @torch.compiler.disable()
488
- @torch.no_grad()
489
- def generate_and_save_samples(fixed_samples_cpu,empty_embeddings, step):
490
- original_model = None
491
- try:
492
- # safe unwrap: if compiled, unwrap is not needed
493
- if not torch_compile:
494
- original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
495
- else:
496
- original_model = unet.eval()
497
-
498
- vae.to(device=device).eval() # temporarily move the VAE to GPU for decoding
499
-
500
-
501
- all_generated_images = []
502
- all_captions = []
503
-
504
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
505
- width, height = size
506
- sample_latents = sample_latents.to(dtype=dtype, device=device)
507
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
508
-
509
- # initial noise
510
- latents = torch.randn(
511
- sample_latents.shape,
512
- device=device,
513
- dtype=sample_latents.dtype,
514
- generator=torch.Generator(device=device).manual_seed(seed)
515
- )
516
-
517
- # prepare timesteps via the scheduler
518
- scheduler.set_timesteps(n_diffusion_steps, device=device)
519
-
520
- for t in scheduler.timesteps:
521
- # guidance: double the batch
522
- if guidance_scale != 1:
523
- latent_model_input = torch.cat([latents, latents], dim=0)
524
-
525
- # empty_embeddings: [1, 1, hidden_dim] → repeat over seq_len and batch
526
- seq_len = sample_text_embeddings.shape[1]
527
- hidden_dim = sample_text_embeddings.shape[2]
528
- empty_embeddings_exp = empty_embeddings.expand(-1, seq_len, hidden_dim) # [1, seq_len, hidden_dim]
529
- empty_embeddings_exp = empty_embeddings_exp.repeat(sample_text_embeddings.shape[0], 1, 1) # [batch, seq_len, hidden_dim]
530
-
531
- text_embeddings_batch = torch.cat([empty_embeddings_exp, sample_text_embeddings], dim=0)
532
- else:
533
- latent_model_input = latents
534
- text_embeddings_batch = sample_text_embeddings
535
-
536
-
537
-
538
- # flow (velocity) prediction
539
- model_out = original_model(latent_model_input, t, encoder_hidden_states=text_embeddings_batch)
540
- flow = getattr(model_out, "sample", model_out)
541
-
542
- # guidance combination
543
- if guidance_scale != 1:
544
- flow_uncond, flow_cond = flow.chunk(2)
545
- flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
546
-
547
- # scheduler step
548
- latents = scheduler.step(flow, t, latents).prev_sample
549
-
550
- current_latents = latents
551
-
552
-
553
- # Undo latent normalization before the VAE decode
554
- latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
555
-
556
- decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
557
- #decoded = decoded[:, :, 0, :, :] # [3, H, W]
558
- #print(decoded.ndim, decoded.shape)
559
-
560
- decoded_fp32 = decoded.to(torch.float32)
561
- for img_idx, img_tensor in enumerate(decoded_fp32):
562
-
563
- # Shape: [3, H, W] -> convert to [H, W, 3]
564
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
565
- img = img.transpose(1, 2, 0) # from [3, H, W] to [H, W, 3]
566
-
567
- #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
568
- if np.isnan(img).any():
569
- print("NaNs found, saving stopped! Step:", step)
570
- pil_img = Image.fromarray((img * 255).astype("uint8"))
571
-
572
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
573
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
574
- max_w_overall = max(255, max_w_overall)
575
- max_h_overall = max(255, max_h_overall)
576
-
577
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
578
- all_generated_images.append(padded_img)
579
-
580
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
581
- all_captions.append(caption_text)
582
-
583
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
584
- pil_img.save(sample_path, "JPEG", quality=96)
585
-
586
- if use_wandb and accelerator.is_main_process:
587
- wandb_images = [
588
- wandb.Image(img, caption=f"{all_captions[i]}")
589
- for i, img in enumerate(all_generated_images)
590
- ]
591
- wandb.log({"generated_images": wandb_images})
592
- if use_comet_ml and accelerator.is_main_process:
593
- for i, img in enumerate(all_generated_images):
594
- comet_experiment.log_image(
595
- image_data=img,
596
- name=f"step_{step}_img_{i}",
597
- step=step,
598
- metadata={
599
- "caption": all_captions[i],
600
- "width": img.width,
601
- "height": img.height,
602
- "global_step": step
603
- }
604
- )
605
- finally:
606
- # move the VAE back to CPU (as in the original code)
607
- vae.to("cpu")
608
- for var in list(locals().keys()):
609
- if isinstance(locals()[var], torch.Tensor):
610
- del locals()[var]
611
- torch.cuda.empty_cache()
612
- gc.collect()
613
-
614
- # --------------------------- Sample generation before training ---------------------------
615
- if accelerator.is_main_process:
616
- if save_model:
617
- print("Генерация сэмплов до старта обучения...")
618
- generate_and_save_samples(fixed_samples,uncond_emb,0)
619
- accelerator.wait_for_everyone()
620
-
621
- # Model-saving function, modified to support LoRA
622
- def save_checkpoint(unet, variant=""):
623
- if accelerator.is_main_process:
624
- if lora_name:
625
- save_lora_checkpoint(unet)
626
- else:
627
- # safe unwrap for a compiled model
628
- model_to_save = None
629
- if not torch_compile:
630
- model_to_save = accelerator.unwrap_model(unet)
631
- else:
632
- model_to_save = unet
633
-
634
- if variant != "":
635
- model_to_save.to(dtype=torch.float16).save_pretrained(
636
- os.path.join(checkpoints_folder, f"{project}"), variant=variant
637
- )
638
- else:
639
- model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
640
-
641
- unet = unet.to(dtype=dtype)
642
-
643
- # --------------------------- Training loop ---------------------------
644
- if accelerator.is_main_process:
645
- print(f"Total steps per GPU: {total_training_steps}")
646
-
647
- epoch_loss_points = []
648
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
649
-
650
- steps_per_epoch = len(dataloader)
651
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
652
- min_loss = 2.
653
-
654
- for epoch in range(start_epoch, start_epoch + num_epochs):
655
- batch_losses = []
656
- batch_grads = []
657
- batch_sampler.set_epoch(epoch)
658
- accelerator.wait_for_everyone()
659
- unet.train()
660
- #print("epoch:",epoch)
661
- for step, (latents, embeddings) in enumerate(dataloader):
662
- with accelerator.accumulate(unet):
663
- if save_model == False and step == 5 :
664
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
665
- print(f"Шаг {step}: {used_gb:.2f} GB")
666
-
667
- # noise
668
- noise = torch.randn_like(latents, dtype=latents.dtype)
669
-
670
- # sample t from [0, 1]
671
- t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
672
-
673
- # interpolate between x0 and noise
674
- noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
675
-
676
- # convert to integer timesteps for the UNet
677
- timesteps = (t * scheduler.config.num_train_timesteps).long()
678
-
679
- # flow prediction
680
- model_pred = unet(noisy_latents, timesteps, embeddings).sample
681
-
682
- # target is the velocity field (the difference between the path endpoints)
683
- target = noise - latents # noise - data matches the FlowMatchEulerDiscreteScheduler convention (sigma runs from 1 to 0)
684
-
685
- # MSE loss
686
- mse_loss = F.mse_loss(model_pred.float(), target.float())
687
-
688
- # Keep for logging (MSE is tracked separately as a metric)
689
- batch_losses.append(mse_loss.detach().item())
690
-
691
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
692
- accelerator.wait_for_everyone()
693
-
694
- # Backward
695
- accelerator.backward(mse_loss)
696
-
697
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
698
- accelerator.wait_for_everyone()
699
-
700
- grad = 0.0
701
- if not fbp:
702
- if accelerator.sync_gradients:
703
- with torch.amp.autocast('cuda', enabled=False):
704
- grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
705
- grad = float(grad_val)
706
- optimizer.step()
707
- lr_scheduler.step()
708
- optimizer.zero_grad(set_to_none=True)
709
-
710
- if accelerator.sync_gradients:
711
- global_step += 1
712
- progress_bar.update(1)
713
- # Log metrics
714
- if accelerator.is_main_process:
715
- if fbp:
716
- current_lr = base_learning_rate
717
- else:
718
- current_lr = lr_scheduler.get_last_lr()[0]
719
- batch_grads.append(grad)
720
-
721
- log_data = {}
722
- log_data["loss"] = mse_loss.detach().item()
723
- log_data["lr"] = current_lr
724
- log_data["grad"] = grad
725
- if accelerator.sync_gradients:
726
- if use_wandb:
727
- wandb.log(log_data, step=global_step)
728
- if use_comet_ml:
729
- comet_experiment.log_metrics(log_data, step=global_step)
730
-
731
- # Generate samples at the configured interval
732
- if global_step % sample_interval == 0:
733
- generate_and_save_samples(fixed_samples,uncond_emb, global_step)
734
- last_n = sample_interval
735
-
736
- if save_model:
737
- avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
738
- print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
739
- if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
740
- min_loss = avg_sample_loss
741
- save_checkpoint(unet)
742
-
743
-
744
- if accelerator.is_main_process:
745
- # local averages
746
- avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
747
- avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
748
-
749
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
750
- log_data_ep = {
751
- "epoch_loss": avg_epoch_loss,
752
- "epoch_grad": avg_epoch_grad,
753
- "epoch": epoch + 1,
754
- }
755
- if use_wandb:
756
- wandb.log(log_data_ep)
757
- if use_comet_ml:
758
- comet_experiment.log_metrics(log_data_ep)
759
-
760
- # Training finished: save the final model
761
- if accelerator.is_main_process:
762
- print("Обучение завершено! Сохраняем финальную модель...")
763
- if save_model:
764
- save_checkpoint(unet,"fp16")
765
- if use_comet_ml:
766
- comet_experiment.end()
767
- accelerator.free_memory()
768
- if torch.distributed.is_initialized():
769
- torch.distributed.destroy_process_group()
770
-
771
- print("Готово!")
train-Copy2.py DELETED
@@ -1,747 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
- from accelerate import Accelerator
13
- from datasets import load_from_disk
14
- from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
- import wandb
17
- import random
18
- import gc
19
- from accelerate.state import DistributedType
20
- from torch.distributed import broadcast_object_list
21
- from torch.utils.checkpoint import checkpoint
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from datetime import datetime
24
- import bitsandbytes as bnb
25
- import torch.nn.functional as F
26
- from collections import deque
27
-
28
- # --------------------------- Parameters ---------------------------
29
- ds_path = "/workspace/sdxs/datasets/ds1234_640"
30
- project = "unet"
31
- batch_size = 64
32
- base_learning_rate = 6e-5
33
- min_learning_rate = 2.5e-5
34
- num_epochs = 80
35
- # samples/save per epoch
36
- sample_interval_share = 10
37
- use_wandb = True
38
- use_comet_ml = False
39
- save_model = True
40
- use_decay = True
41
- fbp = False # fused backward pass
42
- optimizer_type = "adam8bit"
43
- torch_compile = False
44
- unet_gradient = True
45
- clip_sample = False #Scheduler
46
- fixed_seed = False
47
- shuffle = True
48
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Comet ML API key
49
- comet_ml_workspace = "recoilme" # Comet ML workspace
50
- torch.backends.cuda.matmul.allow_tf32 = True
51
- torch.backends.cudnn.allow_tf32 = True
52
- torch.backends.cuda.enable_mem_efficient_sdp(False)
53
- dtype = torch.float32
54
- save_barrier = 1.006
55
- warmup_percent = 0.01
56
- percentile_clipping = 99 # 8bit optim
57
- betta2 = 0.99
58
- eps = 1e-8
59
- clip_grad_norm = 1.0
60
- steps_offset = 0 # Scheduler
61
- limit = 0
62
- checkpoints_folder = ""
63
- mixed_precision = "no" #"fp16"
64
- gradient_accumulation_steps = 1
65
- accelerator = Accelerator(
66
- mixed_precision=mixed_precision,
67
- gradient_accumulation_steps=gradient_accumulation_steps
68
- )
69
- device = accelerator.device
70
-
71
- # Diffusion parameters
72
- n_diffusion_steps = 50
73
- samples_to_generate = 12
74
- guidance_scale = 4
75
-
76
- # Folders for saving results
77
- generated_folder = "samples"
78
- os.makedirs(generated_folder, exist_ok=True)
79
-
80
- # Seed setup for reproducibility
81
- current_date = datetime.now()
82
- seed = int(current_date.strftime("%Y%m%d"))
83
- if fixed_seed:
84
- torch.manual_seed(seed)
85
- np.random.seed(seed)
86
- random.seed(seed)
87
- if torch.cuda.is_available():
88
- torch.cuda.manual_seed_all(seed)
89
-
90
- # --------------------------- LoRA parameters ---------------------------
91
- lora_name = ""
92
- lora_rank = 32
93
- lora_alpha = 64
94
-
95
- print("init")
96
-
97
- # --------------------------- WandB initialization ---------------------------
98
- if accelerator.is_main_process:
99
- if use_wandb:
100
- wandb.init(project=project+lora_name, config={
101
- "batch_size": batch_size,
102
- "base_learning_rate": base_learning_rate,
103
- "num_epochs": num_epochs,
104
- "fbp": fbp,
105
- "optimizer_type": optimizer_type,
106
- })
107
- if use_comet_ml:
108
- from comet_ml import Experiment
109
- comet_experiment = Experiment(
110
- api_key=comet_ml_api_key,
111
- project_name=project,
112
- workspace=comet_ml_workspace
113
- )
114
- # Log hyperparameters to Comet ML
115
- hyper_params = {
116
- "batch_size": batch_size,
117
- "base_learning_rate": base_learning_rate,
118
- "min_learning_rate": min_learning_rate,
119
- "num_epochs": num_epochs,
120
- "n_diffusion_steps": n_diffusion_steps,
121
- "guidance_scale": guidance_scale,
122
- "optimizer_type": optimizer_type,
123
- "mixed_precision": mixed_precision,
124
- }
125
- comet_experiment.log_parameters(hyper_params)
126
-
127
- # Enable Flash Attention 2 / SDPA
128
- torch.backends.cuda.enable_flash_sdp(True)
129
- # --------------------------- Accelerator initialization --------------------
130
- gen = torch.Generator(device=device)
131
- gen.manual_seed(seed)
132
-
133
- # --------------------------- Model loading ---------------------------
134
- # VAE is loaded on CPU to save GPU memory (as in the original code)
135
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
136
-
137
- shift_factor = getattr(vae.config, "shift_factor", 0.0)
138
- if shift_factor is None:
139
- shift_factor = 0.0
140
-
141
- scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
142
- if scaling_factor is None:
143
- scaling_factor = 1.0
144
-
145
- latents_mean = getattr(vae.config, "latents_mean", None)
146
- latents_std = getattr(vae.config, "latents_std", None)
147
-
148
- from diffusers import FlowMatchEulerDiscreteScheduler
149
-
150
- # Adjust to your own settings
151
- num_train_timesteps = 1000
152
-
153
- scheduler = FlowMatchEulerDiscreteScheduler(
154
- num_train_timesteps=num_train_timesteps,
155
- #shift=3.0, # example; tune if needed
156
- #use_dynamic_shifting=True
157
- )
158
-
159
-
160
- class DistributedResolutionBatchSampler(Sampler):
161
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
162
- self.dataset = dataset
163
- self.batch_size = max(1, batch_size // num_replicas)
164
- self.num_replicas = num_replicas
165
- self.rank = rank
166
- self.shuffle = shuffle
167
- self.drop_last = drop_last
168
- self.epoch = 0
169
-
170
- try:
171
- widths = np.array(dataset["width"])
172
- heights = np.array(dataset["height"])
173
- except KeyError:
174
- widths = np.zeros(len(dataset))
175
- heights = np.zeros(len(dataset))
176
-
177
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
178
- self.size_groups = {}
179
- for w, h in self.size_keys:
180
- mask = (widths == w) & (heights == h)
181
- self.size_groups[(w, h)] = np.where(mask)[0]
182
-
183
- self.group_num_batches = {}
184
- total_batches = 0
185
- for size, indices in self.size_groups.items():
186
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
187
- self.group_num_batches[size] = num_full_batches
188
- total_batches += num_full_batches
189
-
190
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
191
-
192
- def __iter__(self):
193
- if torch.cuda.is_available():
194
- torch.cuda.empty_cache()
195
- all_batches = []
196
- rng = np.random.RandomState(self.epoch)
197
-
198
- for size, indices in self.size_groups.items():
199
- indices = indices.copy()
200
- if self.shuffle:
201
- rng.shuffle(indices)
202
- num_full_batches = self.group_num_batches[size]
203
- if num_full_batches == 0:
204
- continue
205
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
206
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
207
- start_idx = self.rank * self.batch_size
208
- end_idx = start_idx + self.batch_size
209
- gpu_batches = batches[:, start_idx:end_idx]
210
- all_batches.extend(gpu_batches)
211
-
212
- if self.shuffle:
213
- rng.shuffle(all_batches)
214
- accelerator.wait_for_everyone()
215
- return iter(all_batches)
216
-
217
- def __len__(self):
218
- return self.num_batches
219
-
220
- def set_epoch(self, epoch):
221
- self.epoch = epoch
222
-
223
- # Select fixed samples per resolution
224
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
225
- size_groups = defaultdict(list)
226
- try:
227
- widths = dataset["width"]
228
- heights = dataset["height"]
229
- except KeyError:
230
- widths = [0] * len(dataset)
231
- heights = [0] * len(dataset)
232
- for i, (w, h) in enumerate(zip(widths, heights)):
233
- size = (w, h)
234
- size_groups[size].append(i)
235
-
236
- fixed_samples = {}
237
- for size, indices in size_groups.items():
238
- n_samples = min(samples_per_group, len(indices))
239
- if len(size_groups)==1:
240
- n_samples = samples_to_generate
241
- if n_samples == 0:
242
- continue
243
- sample_indices = random.sample(indices, n_samples)
244
- samples_data = [dataset[idx] for idx in sample_indices]
245
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
246
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
247
- texts = [item["text"] for item in samples_data]
248
- fixed_samples[size] = (latents, embeddings, texts)
249
-
250
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
251
- return fixed_samples
252
-
253
- if limit > 0:
254
- dataset = load_from_disk(ds_path).select(range(limit))
255
- else:
256
- dataset = load_from_disk(ds_path)
257
-
258
- def collate_fn_simple(batch):
259
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
260
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
261
- attention_mask = torch.abs(embeddings).sum(dim=-1) > 1e-6
262
- attention_mask = attention_mask.to(device, dtype=torch.int64)
263
- return latents, embeddings, attention_mask
264
-
265
- batch_sampler = DistributedResolutionBatchSampler(
266
- dataset=dataset,
267
- batch_size=batch_size,
268
- num_replicas=accelerator.num_processes,
269
- rank=accelerator.process_index,
270
- shuffle=shuffle
271
- )
272
-
273
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
274
- print("Total samples",len(dataloader))
275
- dataloader = accelerator.prepare(dataloader)
276
-
277
- start_epoch = 0
278
- global_step = 0
279
- total_training_steps = (len(dataloader) * num_epochs)
280
- world_size = accelerator.state.num_processes
281
-
282
- # Optionally load the model from the latest checkpoint (if it exists)
283
- latest_checkpoint = os.path.join(checkpoints_folder, project)
284
- if os.path.isdir(latest_checkpoint):
285
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
286
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
287
- if unet_gradient:
288
- unet.enable_gradient_checkpointing()
289
- unet.set_use_memory_efficient_attention_xformers(False)
290
- try:
291
- unet.set_attn_processor(AttnProcessor2_0())
292
- except Exception as e:
293
- print(f"Ошибка при включении SDPA: {e}")
294
- unet.set_use_memory_efficient_attention_xformers(True)
295
-
296
- else:
297
- # FIX: if there is no checkpoint, stop with a clear error (better than an unexpected NameError later)
298
- raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
299
-
300
- if lora_name:
301
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
302
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
303
- from peft.tuners.lora import LoraModel
304
- import os
305
- unet.requires_grad_(False)
306
- print("Параметры базового UNet заморожены.")
307
-
308
- lora_config = LoraConfig(
309
- r=lora_rank,
310
- lora_alpha=lora_alpha,
311
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
312
- )
313
- unet.add_adapter(lora_config)
314
-
315
- from peft import get_peft_model
316
- peft_unet = get_peft_model(unet, lora_config)
317
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
318
-
319
- if accelerator.is_main_process:
320
- lora_params_count = sum(p.numel() for p in params_to_optimize)
321
- total_params_count = sum(p.numel() for p in unet.parameters())
322
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
323
- print(f"Общее количество параметров UNet: {total_params_count:,}")
324
-
325
- lora_save_path = os.path.join("lora", lora_name)
326
- os.makedirs(lora_save_path, exist_ok=True)
327
-
328
- def save_lora_checkpoint(model):
329
- if accelerator.is_main_process:
330
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
331
- from peft.utils.save_and_load import get_peft_model_state_dict
332
- lora_state_dict = get_peft_model_state_dict(model)
333
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
334
- model.peft_config["default"].save_pretrained(lora_save_path)
335
- from diffusers import StableDiffusionXLPipeline
336
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
337
-
338
- # --------------------------- Optimizer ---------------------------
339
- if lora_name:
340
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
341
- else:
342
- if fbp:
343
- trainable_params = list(unet.parameters())
344
-
345
- def create_optimizer(name, params):
346
- if name == "adam8bit":
347
- return bnb.optim.AdamW8bit(
348
- params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
349
- percentile_clipping=percentile_clipping
350
- )
351
- elif name == "adam":
352
- return torch.optim.AdamW(
353
- params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
354
- )
355
- else:
356
- raise ValueError(f"Unknown optimizer: {name}")
357
-
358
- if fbp:
359
- optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
360
- def optimizer_hook(param):
361
- optimizer_dict[param].step()
362
- optimizer_dict[param].zero_grad(set_to_none=True)
363
- for param in trainable_params:
364
- param.register_post_accumulate_grad_hook(optimizer_hook)
365
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
366
- else:
367
- optimizer = create_optimizer(optimizer_type, unet.parameters())
368
- def lr_schedule(step):
369
- x = step / (total_training_steps * world_size)
370
- warmup = warmup_percent
371
- if not use_decay:
372
- return base_learning_rate
373
- if x < warmup:
374
- return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
375
- decay_ratio = (x - warmup) / (1 - warmup)
376
- return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
377
- (1 + math.cos(math.pi * decay_ratio))
378
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
379
-
380
- num_params = sum(p.numel() for p in unet.parameters())
381
- print(f"[rank {accelerator.process_index}] total params: {num_params}")
382
- for name, param in unet.named_parameters():
383
- if torch.isnan(param).any() or torch.isinf(param).any():
384
- print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
385
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
386
-
387
- if torch_compile:
388
- print("compiling")
389
- torch.set_float32_matmul_precision('high')
390
- torch.backends.cudnn.allow_tf32 = True
391
- torch.backends.cuda.matmul.allow_tf32 = True
392
- unet = torch.compile(unet)#, mode='max-autotune')
393
- print("compiling - ok")
394
-
395
- # --------------------------- Fixed samples for generation ---------------------------
396
- fixed_samples = get_fixed_samples_by_resolution(dataset)
397
-
398
- def get_negative_embedding(neg_prompt="", batch_size=1):
399
- """
400
- Returns the negative-prompt embedding with a batch dimension.
401
- Loads the models, computes the embedding, and offloads the models back to CPU.
402
- """
403
- import torch
404
- from transformers import AutoTokenizer, AutoModel
405
-
406
- # Settings
407
- dtype = torch.float16
408
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
409
-
410
- # Load the models (if not loaded yet)
411
- if not hasattr(get_negative_embedding, "tokenizer"):
412
- get_negative_embedding.tokenizer = AutoTokenizer.from_pretrained(
413
- "Qwen/Qwen3-0.6B"
414
- )
415
- get_negative_embedding.text_model = AutoModel.from_pretrained(
416
- "Qwen/Qwen3-0.6B"
417
- ).to(device).eval()
418
-
419
- # Compute the embedding
420
- def encode_texts(texts, max_length=150):
421
- with torch.inference_mode():
422
- toks = get_negative_embedding.tokenizer(
423
- texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
424
- ).to(device)
425
-
426
- outs = get_negative_embedding.text_model(**toks, output_hidden_states=True, return_dict=True)
427
- hidden = outs.hidden_states[-1] # [B, L, D]
428
- mask = toks["attention_mask"].unsqueeze(-1) # (B, L, 1)
429
- hidden = hidden * mask
430
-
431
- return hidden
432
-
433
- # Return the embedding
434
- if not neg_prompt:
435
- hidden_dim = 1024 # embedding dimension of Qwen3-Embedding-0.6B
436
- seq_len = 150
437
- return torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
438
-
439
- uncond_emb = encode_texts([neg_prompt]).to(dtype=dtype, device=device)
440
- uncond_emb = uncond_emb.repeat(batch_size, 1, 1) # expand to batch size
441
-
442
- # Offload the models
443
- if 0:
444
- if hasattr(get_negative_embedding, "text_model"):
445
- get_negative_embedding.text_model = get_negative_embedding.text_model.to("cpu")
446
- if hasattr(get_negative_embedding, "tokenizer"):
447
- del get_negative_embedding.tokenizer # free memory
448
- torch.cuda.empty_cache()
449
-
450
- return uncond_emb
451
-
452
- uncond_emb = get_negative_embedding("low quality")
453
-
454
- @torch.compiler.disable()
455
- @torch.no_grad()
456
- def generate_and_save_samples(fixed_samples_cpu,empty_embeddings, step):
457
- original_model = None
458
- try:
459
- # safe unwrap: if compiled, unwrap is not needed
460
- if not torch_compile:
461
- original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
462
- else:
463
- original_model = unet.eval()
464
-
465
- vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
466
-
467
-
468
- all_generated_images = []
469
- all_captions = []
470
-
471
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
472
- width, height = size
473
- sample_latents = sample_latents.to(dtype=dtype, device=device)
474
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
475
-
476
- # начальный шум
477
- latents = torch.randn(
478
- sample_latents.shape,
479
- device=device,
480
- dtype=sample_latents.dtype,
481
- generator=torch.Generator(device=device).manual_seed(seed)
482
- )
483
-
484
- # подготовим timesteps через шедулер
485
- scheduler.set_timesteps(n_diffusion_steps, device=device)
486
- prompt_mask = torch.abs(sample_text_embeddings).sum(dim=-1) > 1e-6 # (B, Seq)
487
- prompt_mask = prompt_mask.to(dtype=torch.int64)
488
-
489
- # Создаем маску для негатива (empty_embeddings)
490
- # empty_embeddings у вас [Batch, Seq, Dim], скорее всего там нули кроме первых токенов
491
- neg_mask = torch.abs(empty_embeddings).sum(dim=-1) > 1e-6
492
- neg_mask = neg_mask.repeat(sample_text_embeddings.shape[0], 1).to(dtype=torch.int64, device=device)
493
-
494
- for t in scheduler.timesteps:
495
- # guidance: удваиваем батч
496
- if guidance_scale != 1:
497
- latent_model_input = torch.cat([latents, latents], dim=0)
498
-
499
- # empty_embeddings: [1, 1, hidden_dim] → повторяем по seq_len и batch
500
- seq_len = sample_text_embeddings.shape[1]
501
- hidden_dim = sample_text_embeddings.shape[2]
502
- empty_embeddings_exp = empty_embeddings.expand(-1, seq_len, hidden_dim) # [1, seq_len, hidden_dim]
503
- empty_embeddings_exp = empty_embeddings_exp.repeat(sample_text_embeddings.shape[0], 1, 1) # [batch, seq_len, hidden_dim]
504
-
505
- text_embeddings_batch = torch.cat([empty_embeddings_exp, sample_text_embeddings], dim=0)
506
- attention_mask_batch = torch.cat([neg_mask, prompt_mask], dim=0)
507
- else:
508
- latent_model_input = latents
509
- text_embeddings_batch = sample_text_embeddings
510
- attention_mask_batch = prompt_mask
511
-
512
-
513
- # предсказание потока (velocity)
514
- model_out = original_model(
515
- latent_model_input,
516
- t,
517
- encoder_hidden_states=text_embeddings_batch,
518
- encoder_attention_mask=attention_mask_batch)
519
- flow = getattr(model_out, "sample", model_out)
520
-
521
- # guidance объединение
522
- if guidance_scale != 1:
523
- flow_uncond, flow_cond = flow.chunk(2)
524
- flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
525
-
526
- # шаг через scheduler
527
- latents = scheduler.step(flow, t, latents).prev_sample
528
-
529
- current_latents = latents
530
-
531
-
532
- # Параметры нормализации
533
- latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
534
- decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
535
-
536
- decoded_fp32 = decoded.to(torch.float32)
537
- for img_idx, img_tensor in enumerate(decoded_fp32):
538
-
539
- # Форма: [3, H, W] -> преобразуем в [H, W, 3]
540
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
541
- img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
542
-
543
- #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
544
- if np.isnan(img).any():
545
- print("NaNs found, saving stopped! Step:", step)
546
- pil_img = Image.fromarray((img * 255).astype("uint8"))
547
-
548
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
549
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
550
- max_w_overall = max(255, max_w_overall)
551
- max_h_overall = max(255, max_h_overall)
552
-
553
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
554
- all_generated_images.append(padded_img)
555
-
556
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
557
- all_captions.append(caption_text)
558
-
559
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
560
- pil_img.save(sample_path, "JPEG", quality=96)
561
-
562
- if use_wandb and accelerator.is_main_process:
563
- wandb_images = [
564
- wandb.Image(img, caption=f"{all_captions[i]}")
565
- for i, img in enumerate(all_generated_images)
566
- ]
567
- wandb.log({"generated_images": wandb_images})
568
- if use_comet_ml and accelerator.is_main_process:
569
- for i, img in enumerate(all_generated_images):
570
- comet_experiment.log_image(
571
- image_data=img,
572
- name=f"step_{step}_img_{i}",
573
- step=step,
574
- metadata={
575
- "caption": all_captions[i],
576
- "width": img.width,
577
- "height": img.height,
578
- "global_step": step
579
- }
580
- )
581
- finally:
582
- # вернуть VAE на CPU (как было в твоём коде)
583
- vae.to("cpu")
584
- for var in list(locals().keys()):
585
- if isinstance(locals()[var], torch.Tensor):
586
- del locals()[var]
587
- torch.cuda.empty_cache()
588
- gc.collect()
589
-
590
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
591
- if accelerator.is_main_process:
592
- if save_model:
593
- print("Генерация сэмплов до старта обучения...")
594
- generate_and_save_samples(fixed_samples,uncond_emb,0)
595
- accelerator.wait_for_everyone()
596
-
597
- # Модифицируем функцию сохранения модели для поддержки LoRA
598
- def save_checkpoint(unet, variant=""):
599
- if accelerator.is_main_process:
600
- if lora_name:
601
- save_lora_checkpoint(unet)
602
- else:
603
- # безопасный unwrap для компилированной модели
604
- model_to_save = None
605
- if not torch_compile:
606
- model_to_save = accelerator.unwrap_model(unet)
607
- else:
608
- model_to_save = unet
609
-
610
- if variant != "":
611
- model_to_save.to(dtype=torch.float16).save_pretrained(
612
- os.path.join(checkpoints_folder, f"{project}"), variant=variant
613
- )
614
- else:
615
- model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
616
-
617
- unet = unet.to(dtype=dtype)
618
-
619
- # --------------------------- Тренировочный цикл ---------------------------
620
- if accelerator.is_main_process:
621
- print(f"Total steps per GPU: {total_training_steps}")
622
-
623
- epoch_loss_points = []
624
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
625
-
626
- steps_per_epoch = len(dataloader)
627
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
628
- min_loss = 2.
629
-
630
- for epoch in range(start_epoch, start_epoch + num_epochs):
631
- batch_losses = []
632
- batch_grads = []
633
- batch_sampler.set_epoch(epoch)
634
- accelerator.wait_for_everyone()
635
- unet.train()
636
- #print("epoch:",epoch)
637
- for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
638
- with accelerator.accumulate(unet):
639
- if save_model == False and step == 5 :
640
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
641
- print(f"Шаг {step}: {used_gb:.2f} GB")
642
-
643
- # шум
644
- noise = torch.randn_like(latents, dtype=latents.dtype)
645
-
646
- # берём t из [0, 1]
647
- t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
648
-
649
- # интерполяция между x0 и шумом
650
- noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
651
-
652
- # делаем integer timesteps для UNet
653
- timesteps = (t * scheduler.config.num_train_timesteps).long()
654
-
655
- # предсказание потока (Flow)
656
- model_pred = unet(noisy_latents, timesteps, embeddings, encoder_attention_mask=attention_mask).sample
657
-
658
- # target is the vector field between the two endpoints of the interpolation
659
- target = noise - latents # consistent with x_t = (1-t)*x0 + t*noise, whose time derivative is noise - x0
660
-
661
- # MSE лосс
662
- mse_loss = F.mse_loss(model_pred.float(), target.float())
663
-
664
- # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
665
- batch_losses.append(mse_loss.detach().item())
666
-
667
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
668
- accelerator.wait_for_everyone()
669
-
670
- # Backward
671
- accelerator.backward(mse_loss)
672
-
673
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
674
- accelerator.wait_for_everyone()
675
-
676
- grad = 0.0
677
- if not fbp:
678
- if accelerator.sync_gradients:
679
- with torch.amp.autocast('cuda', enabled=False):
680
- grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
681
- grad = float(grad_val)
682
- optimizer.step()
683
- lr_scheduler.step()
684
- optimizer.zero_grad(set_to_none=True)
685
-
686
- if accelerator.sync_gradients:
687
- global_step += 1
688
- progress_bar.update(1)
689
- # Логируем метрики
690
- if accelerator.is_main_process:
691
- if fbp:
692
- current_lr = base_learning_rate
693
- else:
694
- current_lr = lr_scheduler.get_last_lr()[0]
695
- batch_grads.append(grad)
696
-
697
- log_data = {}
698
- log_data["loss"] = mse_loss.detach().item()
699
- log_data["lr"] = current_lr
700
- log_data["grad"] = grad
701
- if accelerator.sync_gradients:
702
- if use_wandb:
703
- wandb.log(log_data, step=global_step)
704
- if use_comet_ml:
705
- comet_experiment.log_metrics(log_data, step=global_step)
706
-
707
- # Генерируем сэмплы с заданным интервалом
708
- if global_step % sample_interval == 0:
709
- generate_and_save_samples(fixed_samples,uncond_emb, global_step)
710
- last_n = sample_interval
711
-
712
- if save_model:
713
- avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
714
- print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
715
- if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
716
- min_loss = avg_sample_loss
717
- save_checkpoint(unet)
718
-
719
-
720
- if accelerator.is_main_process:
721
- # local averages
722
- avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
723
- avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
724
-
725
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
726
- log_data_ep = {
727
- "epoch_loss": avg_epoch_loss,
728
- "epoch_grad": avg_epoch_grad,
729
- "epoch": epoch + 1,
730
- }
731
- if use_wandb:
732
- wandb.log(log_data_ep)
733
- if use_comet_ml:
734
- comet_experiment.log_metrics(log_data_ep)
735
-
736
- # Завершение обучения - сохраняем финальную модель
737
- if accelerator.is_main_process:
738
- print("Обучение завершено! Сохраняем финальную модель...")
739
- if save_model:
740
- save_checkpoint(unet,"fp16")
741
- if use_comet_ml:
742
- comet_experiment.end()
743
- accelerator.free_memory()
744
- if torch.distributed.is_initialized():
745
- torch.distributed.destroy_process_group()
746
-
747
- print("Готово!")
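The next deleted file, train-Copy3.py, keeps an fbp ("fused backward pass") switch that attaches a separate optimizer to every parameter via a post-accumulate-grad hook. A minimal sketch of that pattern with plain AdamW (the script itself uses bnb.optim.AdamW8bit):

import torch

# Per-parameter "fused backward pass": each parameter gets its own optimizer,
# stepped from a post-accumulate-grad hook, so gradients are consumed as soon
# as they are ready instead of being held for one global optimizer.step().
model = torch.nn.Linear(16, 16)
optimizer_dict = {p: torch.optim.AdamW([p], lr=1e-4) for p in model.parameters()}

def optimizer_hook(param):
    optimizer_dict[param].step()
    optimizer_dict[param].zero_grad(set_to_none=True)

for p in model.parameters():
    p.register_post_accumulate_grad_hook(optimizer_hook)

loss = model(torch.randn(2, 16)).sum()
loss.backward()  # each parameter is updated the moment its gradient is accumulated
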
train-Copy3.py DELETED
@@ -1,747 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
- from accelerate import Accelerator
13
- from datasets import load_from_disk
14
- from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
- import wandb
17
- import random
18
- import gc
19
- from accelerate.state import DistributedType
20
- from torch.distributed import broadcast_object_list
21
- from torch.utils.checkpoint import checkpoint
22
- from diffusers.models.attention_processor import AttnProcessor2_0
23
- from datetime import datetime
24
- import bitsandbytes as bnb
25
- import torch.nn.functional as F
26
- from collections import deque
27
-
28
- # --------------------------- Параметры ---------------------------
29
- ds_path = "/workspace/sdxs/datasets/ds1234_640"
30
- project = "unet"
31
- batch_size = 64
32
- base_learning_rate = 6e-5
33
- min_learning_rate = 2.5e-5
34
- num_epochs = 80
35
- # samples/save per epoch
36
- sample_interval_share = 2
37
- use_wandb = True
38
- use_comet_ml = False
39
- save_model = True
40
- use_decay = True
41
- fbp = False # fused backward pass
42
- optimizer_type = "adam8bit"
43
- torch_compile = False
44
- unet_gradient = True
45
- clip_sample = False #Scheduler
46
- fixed_seed = False
47
- shuffle = True
48
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Добавлен API ключ для Comet ML
49
- comet_ml_workspace = "recoilme" # Добавлен workspace для Comet ML
50
- torch.backends.cuda.matmul.allow_tf32 = True
51
- torch.backends.cudnn.allow_tf32 = True
52
- torch.backends.cuda.enable_mem_efficient_sdp(False)
53
- dtype = torch.float32
54
- save_barrier = 1.006
55
- warmup_percent = 0.01
56
- percentile_clipping = 99 # 8bit optim
57
- betta2 = 0.99
58
- eps = 1e-8
59
- clip_grad_norm = 1.0
60
- steps_offset = 0 # Scheduler
61
- limit = 0
62
- checkpoints_folder = ""
63
- mixed_precision = "no" #"fp16"
64
- gradient_accumulation_steps = 1
65
- accelerator = Accelerator(
66
- mixed_precision=mixed_precision,
67
- gradient_accumulation_steps=gradient_accumulation_steps
68
- )
69
- device = accelerator.device
70
-
71
- # Параметры для диффузии
72
- n_diffusion_steps = 50
73
- samples_to_generate = 12
74
- guidance_scale = 4
75
-
76
- # Папки для сохранения результатов
77
- generated_folder = "samples"
78
- os.makedirs(generated_folder, exist_ok=True)
79
-
80
- # Настройка seed для воспроизводимости
81
- current_date = datetime.now()
82
- seed = int(current_date.strftime("%Y%m%d"))
83
- if fixed_seed:
84
- torch.manual_seed(seed)
85
- np.random.seed(seed)
86
- random.seed(seed)
87
- if torch.cuda.is_available():
88
- torch.cuda.manual_seed_all(seed)
89
-
90
- # --------------------------- Параметры LoRA ---------------------------
91
- lora_name = ""
92
- lora_rank = 32
93
- lora_alpha = 64
94
-
95
- print("init")
96
-
97
- # --------------------------- Инициализация WandB ---------------------------
98
- if accelerator.is_main_process:
99
- if use_wandb:
100
- wandb.init(project=project+lora_name, config={
101
- "batch_size": batch_size,
102
- "base_learning_rate": base_learning_rate,
103
- "num_epochs": num_epochs,
104
- "fbp": fbp,
105
- "optimizer_type": optimizer_type,
106
- })
107
- if use_comet_ml:
108
- from comet_ml import Experiment
109
- comet_experiment = Experiment(
110
- api_key=comet_ml_api_key,
111
- project_name=project,
112
- workspace=comet_ml_workspace
113
- )
114
- # Логируем гиперпараметры в Comet ML
115
- hyper_params = {
116
- "batch_size": batch_size,
117
- "base_learning_rate": base_learning_rate,
118
- "min_learning_rate": min_learning_rate,
119
- "num_epochs": num_epochs,
120
- "n_diffusion_steps": n_diffusion_steps,
121
- "guidance_scale": guidance_scale,
122
- "optimizer_type": optimizer_type,
123
- "mixed_precision": mixed_precision,
124
- }
125
- comet_experiment.log_parameters(hyper_params)
126
-
127
- # Включение Flash Attention 2/SDPA
128
- torch.backends.cuda.enable_flash_sdp(True)
129
- # --------------------------- Инициализация Accelerator --------------------
130
- gen = torch.Generator(device=device)
131
- gen.manual_seed(seed)
132
-
133
- # --------------------------- Загрузка моделей ---------------------------
134
- # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
135
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
136
-
137
- shift_factor = getattr(vae.config, "shift_factor", 0.0)
138
- if shift_factor is None:
139
- shift_factor = 0.0
140
-
141
- scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
142
- if scaling_factor is None:
143
- scaling_factor = 1.0
144
-
145
- latents_mean = getattr(vae.config, "latents_mean", None)
146
- latents_std = getattr(vae.config, "latents_std", None)
147
-
148
- from diffusers import FlowMatchEulerDiscreteScheduler
149
-
150
- # Подстрой под свои параметры
151
- num_train_timesteps = 1000
152
-
153
- scheduler = FlowMatchEulerDiscreteScheduler(
154
- num_train_timesteps=num_train_timesteps,
155
- #shift=3.0, # пример; подбирается при необходимости
156
- #use_dynamic_shifting=True
157
- )
158
-
159
-
160
- class DistributedResolutionBatchSampler(Sampler):
161
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
162
- self.dataset = dataset
163
- self.batch_size = max(1, batch_size // num_replicas)
164
- self.num_replicas = num_replicas
165
- self.rank = rank
166
- self.shuffle = shuffle
167
- self.drop_last = drop_last
168
- self.epoch = 0
169
-
170
- try:
171
- widths = np.array(dataset["width"])
172
- heights = np.array(dataset["height"])
173
- except KeyError:
174
- widths = np.zeros(len(dataset))
175
- heights = np.zeros(len(dataset))
176
-
177
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
178
- self.size_groups = {}
179
- for w, h in self.size_keys:
180
- mask = (widths == w) & (heights == h)
181
- self.size_groups[(w, h)] = np.where(mask)[0]
182
-
183
- self.group_num_batches = {}
184
- total_batches = 0
185
- for size, indices in self.size_groups.items():
186
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
187
- self.group_num_batches[size] = num_full_batches
188
- total_batches += num_full_batches
189
-
190
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
191
-
192
- def __iter__(self):
193
- if torch.cuda.is_available():
194
- torch.cuda.empty_cache()
195
- all_batches = []
196
- rng = np.random.RandomState(self.epoch)
197
-
198
- for size, indices in self.size_groups.items():
199
- indices = indices.copy()
200
- if self.shuffle:
201
- rng.shuffle(indices)
202
- num_full_batches = self.group_num_batches[size]
203
- if num_full_batches == 0:
204
- continue
205
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
206
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
207
- start_idx = self.rank * self.batch_size
208
- end_idx = start_idx + self.batch_size
209
- gpu_batches = batches[:, start_idx:end_idx]
210
- all_batches.extend(gpu_batches)
211
-
212
- if self.shuffle:
213
- rng.shuffle(all_batches)
214
- accelerator.wait_for_everyone()
215
- return iter(all_batches)
216
-
217
- def __len__(self):
218
- return self.num_batches
219
-
220
- def set_epoch(self, epoch):
221
- self.epoch = epoch
222
-
223
- # Функция для выборки фиксированных семплов по размерам
224
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
225
- size_groups = defaultdict(list)
226
- try:
227
- widths = dataset["width"]
228
- heights = dataset["height"]
229
- except KeyError:
230
- widths = [0] * len(dataset)
231
- heights = [0] * len(dataset)
232
- for i, (w, h) in enumerate(zip(widths, heights)):
233
- size = (w, h)
234
- size_groups[size].append(i)
235
-
236
- fixed_samples = {}
237
- for size, indices in size_groups.items():
238
- n_samples = min(samples_per_group, len(indices))
239
- if len(size_groups)==1:
240
- n_samples = samples_to_generate
241
- if n_samples == 0:
242
- continue
243
- sample_indices = random.sample(indices, n_samples)
244
- samples_data = [dataset[idx] for idx in sample_indices]
245
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
246
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
247
- texts = [item["text"] for item in samples_data]
248
- fixed_samples[size] = (latents, embeddings, texts)
249
-
250
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
251
- return fixed_samples
252
-
253
- if limit > 0:
254
- dataset = load_from_disk(ds_path).select(range(limit))
255
- else:
256
- dataset = load_from_disk(ds_path)
257
-
258
- def collate_fn_simple(batch):
259
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
260
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
261
- attention_mask = torch.abs(embeddings).sum(dim=-1) > 1e-6
262
- attention_mask = attention_mask.to(device, dtype=torch.int64)
263
- return latents, embeddings, attention_mask
264
-
265
- batch_sampler = DistributedResolutionBatchSampler(
266
- dataset=dataset,
267
- batch_size=batch_size,
268
- num_replicas=accelerator.num_processes,
269
- rank=accelerator.process_index,
270
- shuffle=shuffle
271
- )
272
-
273
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
274
- print("Total samples",len(dataloader))
275
- dataloader = accelerator.prepare(dataloader)
276
-
277
- start_epoch = 0
278
- global_step = 0
279
- total_training_steps = (len(dataloader) * num_epochs)
280
- world_size = accelerator.state.num_processes
281
-
282
- # Опция загрузки модели из последнего чекпоинта (если существует)
283
- latest_checkpoint = os.path.join(checkpoints_folder, project)
284
- if os.path.isdir(latest_checkpoint):
285
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
286
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
287
- if unet_gradient:
288
- unet.enable_gradient_checkpointing()
289
- unet.set_use_memory_efficient_attention_xformers(False)
290
- try:
291
- unet.set_attn_processor(AttnProcessor2_0())
292
- except Exception as e:
293
- print(f"Ошибка при включении SDPA: {e}")
294
- unet.set_use_memory_efficient_attention_xformers(True)
295
-
296
- else:
297
- # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
298
- raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
299
-
300
- if lora_name:
301
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
302
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
303
- from peft.tuners.lora import LoraModel
304
- import os
305
- unet.requires_grad_(False)
306
- print("Параметры базового UNet заморожены.")
307
-
308
- lora_config = LoraConfig(
309
- r=lora_rank,
310
- lora_alpha=lora_alpha,
311
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
312
- )
313
- unet.add_adapter(lora_config)
314
-
315
- from peft import get_peft_model
316
- peft_unet = get_peft_model(unet, lora_config)
317
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
318
-
319
- if accelerator.is_main_process:
320
- lora_params_count = sum(p.numel() for p in params_to_optimize)
321
- total_params_count = sum(p.numel() for p in unet.parameters())
322
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
323
- print(f"Общее количество параметров UNet: {total_params_count:,}")
324
-
325
- lora_save_path = os.path.join("lora", lora_name)
326
- os.makedirs(lora_save_path, exist_ok=True)
327
-
328
- def save_lora_checkpoint(model):
329
- if accelerator.is_main_process:
330
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
331
- from peft.utils.save_and_load import get_peft_model_state_dict
332
- lora_state_dict = get_peft_model_state_dict(model)
333
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
334
- model.peft_config["default"].save_pretrained(lora_save_path)
335
- from diffusers import StableDiffusionXLPipeline
336
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
337
-
338
- # --------------------------- Оптимизатор ---------------------------
339
- if lora_name:
340
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
341
- else:
342
- if fbp:
343
- trainable_params = list(unet.parameters())
344
-
345
- def create_optimizer(name, params):
346
- if name == "adam8bit":
347
- return bnb.optim.AdamW8bit(
348
- params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
349
- percentile_clipping=percentile_clipping
350
- )
351
- elif name == "adam":
352
- return torch.optim.AdamW(
353
- params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
354
- )
355
- else:
356
- raise ValueError(f"Unknown optimizer: {name}")
357
-
358
- if fbp:
359
- optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
360
- def optimizer_hook(param):
361
- optimizer_dict[param].step()
362
- optimizer_dict[param].zero_grad(set_to_none=True)
363
- for param in trainable_params:
364
- param.register_post_accumulate_grad_hook(optimizer_hook)
365
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
366
- else:
367
- optimizer = create_optimizer(optimizer_type, unet.parameters())
368
- def lr_schedule(step):
369
- x = step / (total_training_steps * world_size)
370
- warmup = warmup_percent
371
- if not use_decay:
372
- return base_learning_rate
373
- if x < warmup:
374
- return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
375
- decay_ratio = (x - warmup) / (1 - warmup)
376
- return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
377
- (1 + math.cos(math.pi * decay_ratio))
378
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
379
-
380
- num_params = sum(p.numel() for p in unet.parameters())
381
- print(f"[rank {accelerator.process_index}] total params: {num_params}")
382
- for name, param in unet.named_parameters():
383
- if torch.isnan(param).any() or torch.isinf(param).any():
384
- print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
385
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
386
-
387
- if torch_compile:
388
- print("compiling")
389
- torch.set_float32_matmul_precision('high')
390
- torch.backends.cudnn.allow_tf32 = True
391
- torch.backends.cuda.matmul.allow_tf32 = True
392
- unet = torch.compile(unet)#, mode='max-autotune')
393
- print("compiling - ok")
394
-
395
- # --------------------------- Фиксированные семплы для генерации ---------------------------
396
- fixed_samples = get_fixed_samples_by_resolution(dataset)
397
-
398
- def get_negative_embedding(neg_prompt="", batch_size=1):
399
- """
400
- Returns the negative-prompt embedding repeated to the requested batch size.
401
- Loads the models, computes the embedding, then moves the models back to CPU.
402
- """
403
- import torch
404
- from transformers import AutoTokenizer, AutoModel
405
-
406
- # Настройки
407
- dtype = torch.float16
408
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
409
-
410
- # Загрузка моделей (если ещё не загружены)
411
- if not hasattr(get_negative_embedding, "tokenizer"):
412
- get_negative_embedding.tokenizer = AutoTokenizer.from_pretrained(
413
- "Qwen/Qwen3-0.6B"
414
- )
415
- get_negative_embedding.text_model = AutoModel.from_pretrained(
416
- "Qwen/Qwen3-0.6B"
417
- ).to(device).eval()
418
-
419
- # Вычисление эмбеддинга
420
- def encode_texts(texts, max_length=150):
421
- with torch.inference_mode():
422
- toks = get_negative_embedding.tokenizer(
423
- texts, return_tensors="pt", padding="max_length", truncation=True, max_length=max_length
424
- ).to(device)
425
-
426
- outs = get_negative_embedding.text_model(**toks, output_hidden_states=True, return_dict=True)
427
- hidden = outs.hidden_states[-1] # [B, L, D]
428
- mask = toks["attention_mask"].unsqueeze(-1) # (B, L, 1)
429
- hidden = hidden * mask
430
-
431
- return hidden
432
-
433
- # Возвращаем эмбеддинг
434
- if not neg_prompt:
435
- hidden_dim = 1024 # Размерность эмбеддинга Qwen3-Embedding-0.6B
436
- seq_len = 150
437
- return torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
438
-
439
- uncond_emb = encode_texts([neg_prompt]).to(dtype=dtype, device=device)
440
- uncond_emb = uncond_emb.repeat(batch_size, 1, 1) # Добавляем батч
441
-
442
- # Выгружаем модели
443
- if 1:
444
- if hasattr(get_negative_embedding, "text_model"):
445
- get_negative_embedding.text_model = get_negative_embedding.text_model.to("cpu")
446
- if hasattr(get_negative_embedding, "tokenizer"):
447
- del get_negative_embedding.tokenizer # Освобождаем память
448
- torch.cuda.empty_cache()
449
-
450
- return uncond_emb
451
-
452
- uncond_emb = get_negative_embedding("low quality")
453
-
454
- @torch.compiler.disable()
455
- @torch.no_grad()
456
- def generate_and_save_samples(fixed_samples_cpu,empty_embeddings, step):
457
- original_model = None
458
- try:
459
- # безопасный unwrap: если компилировано, unwrap не нужен
460
- if not torch_compile:
461
- original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
462
- else:
463
- original_model = unet.eval()
464
-
465
- vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
466
-
467
-
468
- all_generated_images = []
469
- all_captions = []
470
-
471
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
472
- width, height = size
473
- sample_latents = sample_latents.to(dtype=dtype, device=device)
474
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
475
-
476
- # начальный шум
477
- latents = torch.randn(
478
- sample_latents.shape,
479
- device=device,
480
- dtype=sample_latents.dtype,
481
- generator=torch.Generator(device=device).manual_seed(seed)
482
- )
483
-
484
- # подготовим timesteps через шедулер
485
- scheduler.set_timesteps(n_diffusion_steps, device=device)
486
- prompt_mask = torch.abs(sample_text_embeddings).sum(dim=-1) > 1e-6 # (B, Seq)
487
- prompt_mask = prompt_mask.to(dtype=torch.int64)
488
-
489
- # Создаем маску для негатива (empty_embeddings)
490
- # empty_embeddings у вас [Batch, Seq, Dim], скорее всего там нули кроме первых токенов
491
- neg_mask = torch.abs(empty_embeddings).sum(dim=-1) > 1e-6
492
- neg_mask = neg_mask.repeat(sample_text_embeddings.shape[0], 1).to(dtype=torch.int64, device=device)
493
-
494
- for t in scheduler.timesteps:
495
- # guidance: удваиваем батч
496
- if guidance_scale != 1:
497
- latent_model_input = torch.cat([latents, latents], dim=0)
498
-
499
- # empty_embeddings: [1, 1, hidden_dim] → повторяем по seq_len и batch
500
- seq_len = sample_text_embeddings.shape[1]
501
- hidden_dim = sample_text_embeddings.shape[2]
502
- empty_embeddings_exp = empty_embeddings.expand(-1, seq_len, hidden_dim) # [1, seq_len, hidden_dim]
503
- empty_embeddings_exp = empty_embeddings_exp.repeat(sample_text_embeddings.shape[0], 1, 1) # [batch, seq_len, hidden_dim]
504
-
505
- text_embeddings_batch = torch.cat([empty_embeddings_exp, sample_text_embeddings], dim=0)
506
- attention_mask_batch = torch.cat([neg_mask, prompt_mask], dim=0)
507
- else:
508
- latent_model_input = latents
509
- text_embeddings_batch = sample_text_embeddings
510
- attention_mask_batch = prompt_mask
511
-
512
-
513
- # предсказание потока (velocity)
514
- model_out = original_model(
515
- latent_model_input,
516
- t,
517
- encoder_hidden_states=text_embeddings_batch,
518
- encoder_attention_mask=attention_mask_batch)
519
- flow = getattr(model_out, "sample", model_out)
520
-
521
- # guidance объединение
522
- if guidance_scale != 1:
523
- flow_uncond, flow_cond = flow.chunk(2)
524
- flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
525
-
526
- # шаг через scheduler
527
- latents = scheduler.step(flow, t, latents).prev_sample
528
-
529
- current_latents = latents
530
-
531
-
532
- # Параметры нормализации
533
- latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
534
- decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
535
-
536
- decoded_fp32 = decoded.to(torch.float32)
537
- for img_idx, img_tensor in enumerate(decoded_fp32):
538
-
539
- # Форма: [3, H, W] -> преобразуем в [H, W, 3]
540
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
541
- img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
542
-
543
- #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
544
- if np.isnan(img).any():
545
- print("NaNs found, saving stopped! Step:", step)
546
- pil_img = Image.fromarray((img * 255).astype("uint8"))
547
-
548
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
549
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
550
- max_w_overall = max(255, max_w_overall)
551
- max_h_overall = max(255, max_h_overall)
552
-
553
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
554
- all_generated_images.append(padded_img)
555
-
556
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
557
- all_captions.append(caption_text)
558
-
559
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
560
- pil_img.save(sample_path, "JPEG", quality=96)
561
-
562
- if use_wandb and accelerator.is_main_process:
563
- wandb_images = [
564
- wandb.Image(img, caption=f"{all_captions[i]}")
565
- for i, img in enumerate(all_generated_images)
566
- ]
567
- wandb.log({"generated_images": wandb_images})
568
- if use_comet_ml and accelerator.is_main_process:
569
- for i, img in enumerate(all_generated_images):
570
- comet_experiment.log_image(
571
- image_data=img,
572
- name=f"step_{step}_img_{i}",
573
- step=step,
574
- metadata={
575
- "caption": all_captions[i],
576
- "width": img.width,
577
- "height": img.height,
578
- "global_step": step
579
- }
580
- )
581
- finally:
582
- # вернуть VAE на CPU (как было в твоём коде)
583
- vae.to("cpu")
584
- for var in list(locals().keys()):
585
- if isinstance(locals()[var], torch.Tensor):
586
- del locals()[var]
587
- torch.cuda.empty_cache()
588
- gc.collect()
589
-
590
- # --------------------------- Генерация сэмплов перед обучением ---------------------------
591
- if accelerator.is_main_process:
592
- if save_model:
593
- print("Генерация сэмплов до старта обучения...")
594
- generate_and_save_samples(fixed_samples,uncond_emb,0)
595
- accelerator.wait_for_everyone()
596
-
597
- # Модифицируем функцию сохранения модели для поддержки LoRA
598
- def save_checkpoint(unet, variant=""):
599
- if accelerator.is_main_process:
600
- if lora_name:
601
- save_lora_checkpoint(unet)
602
- else:
603
- # безопасный unwrap для компилированной модели
604
- model_to_save = None
605
- if not torch_compile:
606
- model_to_save = accelerator.unwrap_model(unet)
607
- else:
608
- model_to_save = unet
609
-
610
- if variant != "":
611
- model_to_save.to(dtype=torch.float16).save_pretrained(
612
- os.path.join(checkpoints_folder, f"{project}"), variant=variant
613
- )
614
- else:
615
- model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
616
-
617
- unet = unet.to(dtype=dtype)
618
-
619
- # --------------------------- Тренировочный цикл ---------------------------
620
- if accelerator.is_main_process:
621
- print(f"Total steps per GPU: {total_training_steps}")
622
-
623
- epoch_loss_points = []
624
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
625
-
626
- steps_per_epoch = len(dataloader)
627
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
628
- min_loss = 2.
629
-
630
- for epoch in range(start_epoch, start_epoch + num_epochs):
631
- batch_losses = []
632
- batch_grads = []
633
- batch_sampler.set_epoch(epoch)
634
- accelerator.wait_for_everyone()
635
- unet.train()
636
- #print("epoch:",epoch)
637
- for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
638
- with accelerator.accumulate(unet):
639
- if save_model == False and step == 5 :
640
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
641
- print(f"Шаг {step}: {used_gb:.2f} GB")
642
-
643
- # шум
644
- noise = torch.randn_like(latents, dtype=latents.dtype)
645
-
646
- # берём t из [0, 1]
647
- t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
648
-
649
- # интерполяция между x0 и шумом
650
- noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
651
-
652
- # делаем integer timesteps для UNet
653
- timesteps = (t * scheduler.config.num_train_timesteps).long()
654
-
655
- # предсказание потока (Flow)
656
- model_pred = unet(noisy_latents, timesteps, embeddings, encoder_attention_mask=attention_mask).sample
657
-
658
- # target is the vector field between the two endpoints of the interpolation
659
- target = noise - latents # consistent with x_t = (1-t)*x0 + t*noise, whose time derivative is noise - x0
660
-
661
- # MSE лосс
662
- mse_loss = F.mse_loss(model_pred.float(), target.float())
663
-
664
- # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
665
- batch_losses.append(mse_loss.detach().item())
666
-
667
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
668
- accelerator.wait_for_everyone()
669
-
670
- # Backward
671
- accelerator.backward(mse_loss)
672
-
673
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
674
- accelerator.wait_for_everyone()
675
-
676
- grad = 0.0
677
- if not fbp:
678
- if accelerator.sync_gradients:
679
- with torch.amp.autocast('cuda', enabled=False):
680
- grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
681
- grad = float(grad_val)
682
- optimizer.step()
683
- lr_scheduler.step()
684
- optimizer.zero_grad(set_to_none=True)
685
-
686
- if accelerator.sync_gradients:
687
- global_step += 1
688
- progress_bar.update(1)
689
- # Логируем метрики
690
- if accelerator.is_main_process:
691
- if fbp:
692
- current_lr = base_learning_rate
693
- else:
694
- current_lr = lr_scheduler.get_last_lr()[0]
695
- batch_grads.append(grad)
696
-
697
- log_data = {}
698
- log_data["loss"] = mse_loss.detach().item()
699
- log_data["lr"] = current_lr
700
- log_data["grad"] = grad
701
- if accelerator.sync_gradients:
702
- if use_wandb:
703
- wandb.log(log_data, step=global_step)
704
- if use_comet_ml:
705
- comet_experiment.log_metrics(log_data, step=global_step)
706
-
707
- # Генерируем сэмплы с заданным интервалом
708
- if global_step % sample_interval == 0:
709
- generate_and_save_samples(fixed_samples,uncond_emb, global_step)
710
- last_n = sample_interval
711
-
712
- if save_model:
713
- avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
714
- print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
715
- if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
716
- min_loss = avg_sample_loss
717
- save_checkpoint(unet)
718
-
719
-
720
- if accelerator.is_main_process:
721
- # local averages
722
- avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
723
- avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
724
-
725
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
726
- log_data_ep = {
727
- "epoch_loss": avg_epoch_loss,
728
- "epoch_grad": avg_epoch_grad,
729
- "epoch": epoch + 1,
730
- }
731
- if use_wandb:
732
- wandb.log(log_data_ep)
733
- if use_comet_ml:
734
- comet_experiment.log_metrics(log_data_ep)
735
-
736
- # Завершение обучения - сохраняем финальную модель
737
- if accelerator.is_main_process:
738
- print("Обучение завершено! Сохраняем финальную модель...")
739
- if save_model:
740
- save_checkpoint(unet,"fp16")
741
- if use_comet_ml:
742
- comet_experiment.end()
743
- accelerator.free_memory()
744
- if torch.distributed.is_initialized():
745
- torch.distributed.destroy_process_group()
746
-
747
- print("Готово!")
train.py CHANGED
@@ -29,7 +29,7 @@ from transformers import AutoTokenizer, AutoModel
29
  ds_path = "/workspace/sdxs/datasets/ds1234_640"
30
  project = "unet"
31
  batch_size = 56
32
- base_learning_rate = 3e-5
33
  min_learning_rate = 3e-5
34
  num_epochs = 50
35
  sample_interval_share = 2
 
29
  ds_path = "/workspace/sdxs/datasets/ds1234_640"
30
  project = "unet"
31
  batch_size = 56
32
+ base_learning_rate = 5e-5
33
  min_learning_rate = 3e-5
34
  num_epochs = 50
35
  sample_interval_share = 2
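
The next deleted file, train_pooling_copy.py, additionally derives a sentence-level embedding by pooling the last non-padding token of the encoder output. A minimal sketch of that pooling step, with illustrative shapes:

import torch

# Last-non-padding-token pooling, as used by encode_texts in train_pooling_copy.py below.
B, L, D = 2, 150, 1024
hidden = torch.randn(B, L, D)                        # [B, L, D] last hidden states
attention_mask = torch.zeros(B, L, dtype=torch.long)
attention_mask[0, :5] = 1                            # first prompt: 5 real tokens
attention_mask[1, :12] = 1                           # second prompt: 12 real tokens

sequence_lengths = attention_mask.sum(dim=1) - 1     # index of the last real token per row
pooled = hidden[torch.arange(B), sequence_lengths]   # [B, D]
print(pooled.shape)  # torch.Size([2, 1024])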
train_pooling_copy.py DELETED
@@ -1,736 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import numpy as np
5
- import matplotlib.pyplot as plt
6
- from torch.utils.data import DataLoader, Sampler
7
- from torch.utils.data.distributed import DistributedSampler
8
- from torch.optim.lr_scheduler import LambdaLR
9
- from collections import defaultdict
10
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
11
- from accelerate import Accelerator
12
- from datasets import load_from_disk
13
- from tqdm import tqdm
14
- from PIL import Image,ImageOps
15
- import wandb
16
- import random
17
- import gc
18
- from accelerate.state import DistributedType
19
- from torch.distributed import broadcast_object_list
20
- from torch.utils.checkpoint import checkpoint
21
- from diffusers.models.attention_processor import AttnProcessor2_0
22
- from datetime import datetime
23
- import bitsandbytes as bnb
24
- import torch.nn.functional as F
25
- from collections import deque
26
- from transformers import AutoTokenizer, AutoModel
27
-
28
- # --------------------------- Параметры ---------------------------
29
- ds_path = "/workspace/sdxs/datasets/ds1234_640"
30
- project = "unet"
31
- batch_size = 64
32
- base_learning_rate = 6e-5
33
- min_learning_rate = 2.5e-5
34
- num_epochs = 80
35
- # samples/save per epoch
36
- sample_interval_share = 2
37
- use_wandb = True
38
- use_comet_ml = False
39
- save_model = True
40
- use_decay = True
41
- fbp = False # fused backward pass
42
- optimizer_type = "adam8bit"
43
- torch_compile = False
44
- unet_gradient = True
45
- clip_sample = False #Scheduler
46
- fixed_seed = False
47
- shuffle = True
48
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Добавлен API ключ для Comet ML
49
- comet_ml_workspace = "recoilme" # Добавлен workspace для Comet ML
50
- torch.backends.cuda.matmul.allow_tf32 = True
51
- torch.backends.cudnn.allow_tf32 = True
52
- torch.backends.cuda.enable_mem_efficient_sdp(False)
53
- dtype = torch.float32
54
- save_barrier = 1.006
55
- warmup_percent = 0.01
56
- percentile_clipping = 99 # 8bit optim
57
- betta2 = 0.99
58
- eps = 1e-8
59
- clip_grad_norm = 1.0
60
- steps_offset = 0 # Scheduler
61
- limit = 0
62
- checkpoints_folder = ""
63
- mixed_precision = "no" #"fp16"
64
- gradient_accumulation_steps = 1
65
- accelerator = Accelerator(
66
- mixed_precision=mixed_precision,
67
- gradient_accumulation_steps=gradient_accumulation_steps
68
- )
69
- device = accelerator.device
70
-
71
- # Параметры для диффузии
72
- n_diffusion_steps = 50
73
- samples_to_generate = 12
74
- guidance_scale = 4
75
-
76
- # Папки для сохранения результатов
77
- generated_folder = "samples"
78
- os.makedirs(generated_folder, exist_ok=True)
79
-
80
- # Настройка seed для воспроизводимости
81
- current_date = datetime.now()
82
- seed = int(current_date.strftime("%Y%m%d"))
83
- if fixed_seed:
84
- torch.manual_seed(seed)
85
- np.random.seed(seed)
86
- random.seed(seed)
87
- if torch.cuda.is_available():
88
- torch.cuda.manual_seed_all(seed)
89
-
90
- # --------------------------- Параметры LoRA ---------------------------
91
- lora_name = ""
92
- lora_rank = 32
93
- lora_alpha = 64
94
-
95
- print("init")
96
-
97
- # --------------------------- Инициализация WandB ---------------------------
98
- if accelerator.is_main_process:
99
- if use_wandb:
100
- wandb.init(project=project+lora_name, config={
101
- "batch_size": batch_size,
102
- "base_learning_rate": base_learning_rate,
103
- "num_epochs": num_epochs,
104
- "fbp": fbp,
105
- "optimizer_type": optimizer_type,
106
- })
107
- if use_comet_ml:
108
- from comet_ml import Experiment
109
- comet_experiment = Experiment(
110
- api_key=comet_ml_api_key,
111
- project_name=project,
112
- workspace=comet_ml_workspace
113
- )
114
- # Логируем гиперпараметры в Comet ML
115
- hyper_params = {
116
- "batch_size": batch_size,
117
- "base_learning_rate": base_learning_rate,
118
- "min_learning_rate": min_learning_rate,
119
- "num_epochs": num_epochs,
120
- "n_diffusion_steps": n_diffusion_steps,
121
- "guidance_scale": guidance_scale,
122
- "optimizer_type": optimizer_type,
123
- "mixed_precision": mixed_precision,
124
- }
125
- comet_experiment.log_parameters(hyper_params)
126
-
127
- # Включение Flash Attention 2/SDPA
128
- torch.backends.cuda.enable_flash_sdp(True)
129
- # --------------------------- Инициализация Accelerator --------------------
130
- gen = torch.Generator(device=device)
131
- gen.manual_seed(seed)
132
-
133
- # --------------------------- Загрузка моделей ---------------------------
134
- # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
135
- vae = AutoencoderKL.from_pretrained("vae1x", torch_dtype=dtype).to("cpu").eval()
136
- tokenizer = AutoTokenizer.from_pretrained("tokenizer")
137
- text_model = AutoModel.from_pretrained("text_encoder").to(device).eval()
138
-
139
- def encode_texts(texts, max_length=150):
140
- with torch.no_grad():
141
- toks = tokenizer(
142
- texts,
143
- return_tensors="pt",
144
- padding="max_length",
145
- truncation=True,
146
- max_length=max_length
147
- ).to(device)
148
- outs = text_model(**toks, output_hidden_states=True, return_dict=True)
149
- # Token embeddings (for cross-attention)
150
- hidden = outs.hidden_states[-1] # use the last hidden state
151
- # Attention mask (for cross-attention)
152
- attention_mask = toks["attention_mask"]
153
-
154
- # Pooled embedding (for class conditioning): take the embedding of the last non-padding token.
155
- sequence_lengths = attention_mask.sum(dim=1) - 1
156
- batch_size = hidden.shape[0]
157
- pooled = hidden[torch.arange(batch_size, device=hidden.device), sequence_lengths]
158
-
159
- return hidden, attention_mask, pooled
160
-
161
- shift_factor = getattr(vae.config, "shift_factor", 0.0)
162
- if shift_factor is None:
163
- shift_factor = 0.0
164
-
165
- scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
166
- if scaling_factor is None:
167
- scaling_factor = 1.0
168
-
169
- latents_mean = getattr(vae.config, "latents_mean", None)
170
- latents_std = getattr(vae.config, "latents_std", None)
171
-
172
- from diffusers import FlowMatchEulerDiscreteScheduler
173
-
174
- # Подстрой под свои параметры
175
- num_train_timesteps = 1000
176
-
177
- scheduler = FlowMatchEulerDiscreteScheduler(
178
- num_train_timesteps=num_train_timesteps,
179
- #shift=3.0, # пример; подбирается при необходимости
180
- #use_dynamic_shifting=True
181
- )
182
-
183
-
184
- class DistributedResolutionBatchSampler(Sampler):
185
- def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
186
- self.dataset = dataset
187
- self.batch_size = max(1, batch_size // num_replicas)
188
- self.num_replicas = num_replicas
189
- self.rank = rank
190
- self.shuffle = shuffle
191
- self.drop_last = drop_last
192
- self.epoch = 0
193
-
194
- try:
195
- widths = np.array(dataset["width"])
196
- heights = np.array(dataset["height"])
197
- except KeyError:
198
- widths = np.zeros(len(dataset))
199
- heights = np.zeros(len(dataset))
200
-
201
- self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
202
- self.size_groups = {}
203
- for w, h in self.size_keys:
204
- mask = (widths == w) & (heights == h)
205
- self.size_groups[(w, h)] = np.where(mask)[0]
206
-
207
- self.group_num_batches = {}
208
- total_batches = 0
209
- for size, indices in self.size_groups.items():
210
- num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
211
- self.group_num_batches[size] = num_full_batches
212
- total_batches += num_full_batches
213
-
214
- self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
215
-
216
- def __iter__(self):
217
- if torch.cuda.is_available():
218
- torch.cuda.empty_cache()
219
- all_batches = []
220
- rng = np.random.RandomState(self.epoch)
221
-
222
- for size, indices in self.size_groups.items():
223
- indices = indices.copy()
224
- if self.shuffle:
225
- rng.shuffle(indices)
226
- num_full_batches = self.group_num_batches[size]
227
- if num_full_batches == 0:
228
- continue
229
- valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
230
- batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
231
- start_idx = self.rank * self.batch_size
232
- end_idx = start_idx + self.batch_size
233
- gpu_batches = batches[:, start_idx:end_idx]
234
- all_batches.extend(gpu_batches)
235
-
236
- if self.shuffle:
237
- rng.shuffle(all_batches)
238
- accelerator.wait_for_everyone()
239
- return iter(all_batches)
240
-
241
- def __len__(self):
242
- return self.num_batches
243
-
244
- def set_epoch(self, epoch):
245
- self.epoch = epoch
246
-
247
- # Функция для выборки фиксированных семплов по размерам
248
- def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
249
- size_groups = defaultdict(list)
250
- try:
251
- widths = dataset["width"]
252
- heights = dataset["height"]
253
- except KeyError:
254
- widths = [0] * len(dataset)
255
- heights = [0] * len(dataset)
256
- for i, (w, h) in enumerate(zip(widths, heights)):
257
- size = (w, h)
258
- size_groups[size].append(i)
259
-
260
- fixed_samples = {}
261
- for size, indices in size_groups.items():
262
- n_samples = min(samples_per_group, len(indices))
263
- if len(size_groups)==1:
264
- n_samples = samples_to_generate
265
- if n_samples == 0:
266
- continue
267
- sample_indices = random.sample(indices, n_samples)
268
- samples_data = [dataset[idx] for idx in sample_indices]
269
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
270
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
271
- texts = [item["text"] for item in samples_data]
272
- fixed_samples[size] = (latents, embeddings, texts)
273
-
274
- print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
275
- return fixed_samples
276
-
277
- if limit > 0:
278
- dataset = load_from_disk(ds_path).select(range(limit))
279
- else:
280
- dataset = load_from_disk(ds_path)
281
-
282
- def collate_fn_simple(batch):
283
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
284
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
285
- attention_mask = torch.abs(embeddings).sum(dim=-1) > 1e-6
286
- attention_mask = attention_mask.to(device, dtype=torch.int64)
287
- return latents, embeddings, attention_mask
288
-
289
- batch_sampler = DistributedResolutionBatchSampler(
290
- dataset=dataset,
291
- batch_size=batch_size,
292
- num_replicas=accelerator.num_processes,
293
- rank=accelerator.process_index,
294
- shuffle=shuffle
295
- )
296
-
297
- dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
298
- print("Total samples",len(dataloader))
299
- dataloader = accelerator.prepare(dataloader)
300
-
301
- start_epoch = 0
302
- global_step = 0
303
- total_training_steps = (len(dataloader) * num_epochs)
304
- world_size = accelerator.state.num_processes
305
-
306
- # Опция загрузки модели из последнего чекпоинта (если существует)
307
- latest_checkpoint = os.path.join(checkpoints_folder, project)
308
- if os.path.isdir(latest_checkpoint):
309
- print("Загружаем UNet из чекпоинта:", latest_checkpoint)
310
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
311
- if unet_gradient:
312
- unet.enable_gradient_checkpointing()
313
- unet.set_use_memory_efficient_attention_xformers(False)
314
- try:
315
- unet.set_attn_processor(AttnProcessor2_0())
316
- except Exception as e:
317
- print(f"Ошибка при включении SDPA: {e}")
318
- unet.set_use_memory_efficient_attention_xformers(True)
319
-
320
- else:
321
- # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
322
- raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
323
-
324
- if lora_name:
325
- print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
326
- from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
327
- from peft.tuners.lora import LoraModel
328
- import os
329
- unet.requires_grad_(False)
330
- print("Параметры базового UNet заморожены.")
331
-
332
- lora_config = LoraConfig(
333
- r=lora_rank,
334
- lora_alpha=lora_alpha,
335
- target_modules=["to_q", "to_k", "to_v", "to_out.0"],
336
- )
337
- unet.add_adapter(lora_config)
338
-
339
- from peft import get_peft_model
340
- peft_unet = get_peft_model(unet, lora_config)
341
- params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
342
-
343
- if accelerator.is_main_process:
344
- lora_params_count = sum(p.numel() for p in params_to_optimize)
345
- total_params_count = sum(p.numel() for p in unet.parameters())
346
- print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
347
- print(f"Общее количество параметров UNet: {total_params_count:,}")
348
-
349
- lora_save_path = os.path.join("lora", lora_name)
350
- os.makedirs(lora_save_path, exist_ok=True)
351
-
352
- def save_lora_checkpoint(model):
353
- if accelerator.is_main_process:
354
- print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
355
- from peft.utils.save_and_load import get_peft_model_state_dict
356
- lora_state_dict = get_peft_model_state_dict(model)
357
- torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
358
- model.peft_config["default"].save_pretrained(lora_save_path)
359
- from diffusers import StableDiffusionXLPipeline
360
- StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
361
-
362
- # --------------------------- Оптимизатор ---------------------------
363
- if lora_name:
364
- trainable_params = [p for p in unet.parameters() if p.requires_grad]
365
- else:
366
- if fbp:
367
- trainable_params = list(unet.parameters())
368
-
369
- def create_optimizer(name, params):
370
- if name == "adam8bit":
371
- return bnb.optim.AdamW8bit(
372
- params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
373
- percentile_clipping=percentile_clipping
374
- )
375
- elif name == "adam":
376
- return torch.optim.AdamW(
377
- params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
378
- )
379
- else:
380
- raise ValueError(f"Unknown optimizer: {name}")
381
-
382
- if fbp:
383
- optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
384
- def optimizer_hook(param):
385
- optimizer_dict[param].step()
386
- optimizer_dict[param].zero_grad(set_to_none=True)
387
- for param in trainable_params:
388
- param.register_post_accumulate_grad_hook(optimizer_hook)
389
- unet, optimizer = accelerator.prepare(unet, optimizer_dict)
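# The "fbp" branch above fuses the optimizer step into the backward pass: one small
# optimizer is created per parameter and stepped inside a post-accumulate-grad hook,
# so each gradient is consumed and freed as soon as it is produced instead of being
# held until a single global optimizer.step(). This lowers peak GPU memory at the
# cost of some per-parameter overhead.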
390
- else:
391
- optimizer = create_optimizer(optimizer_type, trainable_params if lora_name else unet.parameters())  # with LoRA, only the adapter parameters are trainable
392
- def lr_schedule(step):
393
- x = step / (total_training_steps * world_size)
394
- warmup = warmup_percent
395
- if not use_decay:
396
- return base_learning_rate
397
- if x < warmup:
398
- return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
399
- decay_ratio = (x - warmup) / (1 - warmup)
400
- return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
401
- (1 + math.cos(math.pi * decay_ratio))
402
- lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
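# lr_schedule above is a linear warmup followed by cosine decay on the normalized
# progress x = step / (total_training_steps * world_size):
#   use_decay == False : lr = base_learning_rate (constant)
#   x <  warmup_percent: lr = min_lr + (base_lr - min_lr) * x / warmup_percent
#   x >= warmup_percent: lr = min_lr + 0.5 * (base_lr - min_lr) * (1 + cos(pi * (x - warmup_percent) / (1 - warmup_percent)))
# LambdaLR multiplies the optimizer's base LR, hence the division by base_learning_rate.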
403
-
404
- num_params = sum(p.numel() for p in unet.parameters())
405
- print(f"[rank {accelerator.process_index}] total params: {num_params}")
406
- for name, param in unet.named_parameters():
407
- if torch.isnan(param).any() or torch.isinf(param).any():
408
- print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
409
- unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
410
-
411
- if torch_compile:
412
- print("compiling")
413
- torch.set_float32_matmul_precision('high')
414
- torch.backends.cudnn.allow_tf32 = True
415
- torch.backends.cuda.matmul.allow_tf32 = True
416
- unet = torch.compile(unet)#, mode='max-autotune')
417
- print("compiling - ok")
418
-
419
- # --------------------------- Fixed samples for generation ---------------------------
420
- fixed_samples = get_fixed_samples_by_resolution(dataset)
421
-
422
- def get_negative_embedding(neg_prompt="", batch_size=1):
423
- """
424
- Return the negative-prompt embedding (plus attention mask and pooled vector) with a batch dimension.
425
- For an empty prompt, zero tensors are returned instead of running the text encoder.
426
- """
427
-
428
- # Empty prompt: return zeros (Qwen3-Embedding-0.6B hidden size is 1024, fixed seq_len of 150)
429
- if not neg_prompt:
430
- hidden_dim = 1024
431
- seq_len = 150
432
- return (torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device), torch.zeros((batch_size, seq_len), dtype=torch.int64, device=device), torch.zeros((batch_size, hidden_dim), dtype=dtype, device=device))
433
-
434
- # encode_texts is assumed to return (embeddings [B, L, D], attention_mask [B, L], pooled [B, D]); .to() cannot be applied to the returned tuple
435
- uncond_emb, attention_mask, pooling = encode_texts([neg_prompt])
436
- uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)  # add the batch dimension
437
- uncond_mask = attention_mask.to(device=device).repeat(batch_size, 1)
438
- uncond_pool = pooling.to(dtype=dtype, device=device).repeat(batch_size, 1)
439
- return uncond_emb, uncond_mask, uncond_pool
440
-
441
- uncond_emb, uncond_mask, uncond_pool = get_negative_embedding("low quality")
442
-
443
- @torch.compiler.disable()
444
- @torch.no_grad()
445
- def generate_and_save_samples(fixed_samples_cpu, empty_embeddings, step):
446
- original_model = None
447
- try:
448
- # safe unwrap: if the model is compiled, unwrapping is not needed
449
- if not torch_compile:
450
- original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
451
- else:
452
- original_model = unet.eval()
453
-
454
- vae.to(device=device).eval()  # temporarily move the VAE to the GPU for decoding
455
-
456
-
457
- all_generated_images = []
458
- all_captions = []
459
-
460
- for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
461
- width, height = size
462
- sample_latents = sample_latents.to(dtype=dtype, device=device)
463
- sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
464
-
465
- # initial noise
466
- latents = torch.randn(
467
- sample_latents.shape,
468
- device=device,
469
- dtype=sample_latents.dtype,
470
- generator=torch.Generator(device=device).manual_seed(seed)
471
- )
472
-
473
- # prepare timesteps via the scheduler
474
- scheduler.set_timesteps(n_diffusion_steps, device=device)
475
- prompt_mask = torch.abs(sample_text_embeddings).sum(dim=-1) > 1e-6 # (B, Seq)
476
- prompt_mask = prompt_mask.to(dtype=torch.int64)
477
-
478
- # Build the attention mask for the negative prompt (empty_embeddings)
479
- # empty_embeddings is [Batch, Seq, Dim]; entries are zero except for the leading real tokens
480
- neg_mask = torch.abs(empty_embeddings).sum(dim=-1) > 1e-6
481
- neg_mask = neg_mask.repeat(sample_text_embeddings.shape[0], 1).to(dtype=torch.int64, device=device)
482
-
483
- for t in scheduler.timesteps:
484
- # guidance: double the batch
485
- if guidance_scale != 1:
486
- latent_model_input = torch.cat([latents, latents], dim=0)
487
-
488
- # empty_embeddings: expand/repeat to [batch, seq_len, hidden_dim] to match the prompt embeddings
489
- seq_len = sample_text_embeddings.shape[1]
490
- hidden_dim = sample_text_embeddings.shape[2]
491
- empty_embeddings_exp = empty_embeddings.expand(-1, seq_len, hidden_dim) # [1, seq_len, hidden_dim]
492
- empty_embeddings_exp = empty_embeddings_exp.repeat(sample_text_embeddings.shape[0], 1, 1) # [batch, seq_len, hidden_dim]
493
-
494
- text_embeddings_batch = torch.cat([empty_embeddings_exp, sample_text_embeddings], dim=0)
495
- attention_mask_batch = torch.cat([neg_mask, prompt_mask], dim=0)
496
- else:
497
- latent_model_input = latents
498
- text_embeddings_batch = sample_text_embeddings
499
- attention_mask_batch = prompt_mask
500
-
501
-
502
- # flow (velocity) prediction
503
- model_out = original_model(
504
- latent_model_input,
505
- t,
506
- encoder_hidden_states=text_embeddings_batch,
507
- encoder_attention_mask=attention_mask_batch)
508
- flow = getattr(model_out, "sample", model_out)
509
-
510
- # classifier-free guidance combination
511
- if guidance_scale != 1:
512
- flow_uncond, flow_cond = flow.chunk(2)
513
- flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
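# Classifier-free guidance: the batch was duplicated into [unconditional, conditional]
# halves, so the guided velocity is the unconditional prediction pushed further in the
# direction of the conditional one by a factor of guidance_scale.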
514
-
515
- # scheduler step
516
- latents = scheduler.step(flow, t, latents).prev_sample
517
-
518
- current_latents = latents
519
-
520
-
521
- # undo the latent normalization before decoding (divide by scaling_factor, add shift_factor)
522
- latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
523
- decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
524
-
525
- decoded_fp32 = decoded.to(torch.float32)
526
- for img_idx, img_tensor in enumerate(decoded_fp32):
527
-
528
- # Shape: [3, H, W] -> convert to [H, W, 3]
529
- img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
530
- img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
531
-
532
- #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
533
- if np.isnan(img).any():
534
- print("NaNs found, saving stopped! Step:", step)
535
- pil_img = Image.fromarray((img * 255).astype("uint8"))
536
-
537
- max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
538
- max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
539
- max_w_overall = max(255, max_w_overall)
540
- max_h_overall = max(255, max_h_overall)
541
-
542
- padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
543
- all_generated_images.append(padded_img)
544
-
545
- caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
546
- all_captions.append(caption_text)
547
-
548
- sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
549
- pil_img.save(sample_path, "JPEG", quality=96)
550
-
551
- if use_wandb and accelerator.is_main_process:
552
- wandb_images = [
553
- wandb.Image(img, caption=f"{all_captions[i]}")
554
- for i, img in enumerate(all_generated_images)
555
- ]
556
- wandb.log({"generated_images": wandb_images})
557
- if use_comet_ml and accelerator.is_main_process:
558
- for i, img in enumerate(all_generated_images):
559
- comet_experiment.log_image(
560
- image_data=img,
561
- name=f"step_{step}_img_{i}",
562
- step=step,
563
- metadata={
564
- "caption": all_captions[i],
565
- "width": img.width,
566
- "height": img.height,
567
- "global_step": step
568
- }
569
- )
570
- finally:
571
- # move the VAE back to the CPU
572
- vae.to("cpu")
573
- # Note: deleting entries from locals() is a no-op inside a function body in CPython,
574
- # so explicit per-variable deletion would not release anything here; the
575
- # empty_cache() and gc.collect() calls below are what actually free GPU memory.
576
- torch.cuda.empty_cache()
577
- gc.collect()
578
-
579
- # --------------------------- Sample generation before training ---------------------------
580
- if accelerator.is_main_process:
581
- if save_model:
582
- print("Генерация сэмплов до старта обучения...")
583
- generate_and_save_samples(fixed_samples, uncond_emb, 0)
584
- accelerator.wait_for_everyone()
585
-
586
- # Model-saving helper, adjusted to support LoRA
587
- def save_checkpoint(unet, variant=""):
588
- if accelerator.is_main_process:
589
- if lora_name:
590
- save_lora_checkpoint(unet)
591
- else:
592
- # safe unwrap for a compiled model
593
- model_to_save = None
594
- if not torch_compile:
595
- model_to_save = accelerator.unwrap_model(unet)
596
- else:
597
- model_to_save = unet
598
-
599
- if variant != "":
600
- model_to_save.to(dtype=torch.float16).save_pretrained(
601
- os.path.join(checkpoints_folder, f"{project}"), variant=variant
602
- )
603
- else:
604
- model_to_save.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
605
-
606
- unet = unet.to(dtype=dtype)
607
-
608
- # --------------------------- Training loop ---------------------------
609
- if accelerator.is_main_process:
610
- print(f"Total steps per GPU: {total_training_steps}")
611
-
612
- epoch_loss_points = []
613
- progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
614
-
615
- steps_per_epoch = len(dataloader)
616
- sample_interval = max(1, steps_per_epoch // sample_interval_share)
617
- min_loss = 2.
618
-
619
- for epoch in range(start_epoch, start_epoch + num_epochs):
620
- batch_losses = []
621
- batch_grads = []
622
- batch_sampler.set_epoch(epoch)
623
- accelerator.wait_for_everyone()
624
- unet.train()
625
- #print("epoch:",epoch)
626
- for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
627
- with accelerator.accumulate(unet):
628
- if not save_model and step == 5:
630
- used_gb = torch.cuda.max_memory_allocated() / 1024**3
631
- print(f"Step {step}: peak GPU memory {used_gb:.2f} GB")
631
-
632
- # noise
633
- noise = torch.randn_like(latents, dtype=latents.dtype)
634
-
635
- # sample t uniformly from [0, 1]
636
- t = torch.rand(latents.shape[0], device=latents.device, dtype=latents.dtype)
637
-
638
- # linear interpolation between x0 and noise
639
- noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
640
-
641
- # convert to integer timesteps for the UNet
642
- timesteps = (t * scheduler.config.num_train_timesteps).long()
643
-
644
- # flow prediction
645
- model_pred = unet(noisy_latents, timesteps, embeddings, encoder_attention_mask=attention_mask).sample
646
-
647
- # target is the vector field, i.e. the difference between the two endpoints of the interpolation
648
- target = noise - latents  # for x_t = (1 - t) * latents + t * noise, d x_t / d t = noise - latents
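# A minimal standalone check (illustrative, not part of the script) that the velocity
# target of the linear interpolation x_t = (1 - t) * x0 + t * x1 is x1 - x0:
#
#   import torch
#   x0, x1 = torch.randn(4), torch.randn(4)
#   t = torch.tensor(0.3, requires_grad=True)
#   xt = ((1 - t) * x0 + t * x1).sum()
#   (v,) = torch.autograd.grad(xt, t)          # d/dt sum(x_t)
#   assert torch.allclose(v, (x1 - x0).sum())  # matches sum(x1 - x0)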
649
-
650
- # MSE loss
651
- mse_loss = F.mse_loss(model_pred.float(), target.float())
652
-
653
- # Keep the MSE for logging (tracked separately as a progress indicator)
654
- batch_losses.append(mse_loss.detach().item())
655
-
656
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
657
- accelerator.wait_for_everyone()
658
-
659
- # Backward
660
- accelerator.backward(mse_loss)
661
-
662
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
663
- accelerator.wait_for_everyone()
664
-
665
- grad = 0.0
666
- if not fbp:
667
- if accelerator.sync_gradients:
668
- with torch.amp.autocast('cuda', enabled=False):
669
- grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
670
- grad = float(grad_val)
671
- optimizer.step()
672
- lr_scheduler.step()
673
- optimizer.zero_grad(set_to_none=True)
674
-
675
- if accelerator.sync_gradients:
676
- global_step += 1
677
- progress_bar.update(1)
678
- # Log metrics
679
- if accelerator.is_main_process:
680
- if fbp:
681
- current_lr = base_learning_rate
682
- else:
683
- current_lr = lr_scheduler.get_last_lr()[0]
684
- batch_grads.append(grad)
685
-
686
- log_data = {}
687
- log_data["loss"] = mse_loss.detach().item()
688
- log_data["lr"] = current_lr
689
- log_data["grad"] = grad
690
- if accelerator.sync_gradients:
691
- if use_wandb:
692
- wandb.log(log_data, step=global_step)
693
- if use_comet_ml:
694
- comet_experiment.log_metrics(log_data, step=global_step)
695
-
696
- # Generate samples at the configured interval
697
- if global_step % sample_interval == 0:
698
- generate_and_save_samples(fixed_samples,uncond_emb, global_step)
699
- last_n = sample_interval
700
-
701
- if save_model:
702
- avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
703
- print("saving:", avg_sample_loss < min_loss * save_barrier, "Avg:", avg_sample_loss)
704
- if avg_sample_loss is not None and avg_sample_loss < min_loss * save_barrier:
705
- min_loss = avg_sample_loss
706
- save_checkpoint(unet)
707
-
708
-
709
- if accelerator.is_main_process:
710
- # local averages
711
- avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
712
- avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
713
-
714
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
715
- log_data_ep = {
716
- "epoch_loss": avg_epoch_loss,
717
- "epoch_grad": avg_epoch_grad,
718
- "epoch": epoch + 1,
719
- }
720
- if use_wandb:
721
- wandb.log(log_data_ep)
722
- if use_comet_ml:
723
- comet_experiment.log_metrics(log_data_ep)
724
-
725
- # Training is finished - save the final model
726
- if accelerator.is_main_process:
727
- print("Обучение завершено! Сохраняем финальную модель...")
728
- if save_model:
729
- save_checkpoint(unet,"fp16")
730
- if use_comet_ml:
731
- comet_experiment.end()
732
- accelerator.free_memory()
733
- if torch.distributed.is_initialized():
734
- torch.distributed.destroy_process_group()
735
-
736
- print("Готово!")