recoilme committed on
Commit
24c1527
·
1 Parent(s): a9765ed
samples/unet_192x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 19ecd5d8bcc9fc82302229e8def5628673614c40131944d167a4d225beacd004
  • Pointer size: 130 Bytes
  • Size of remote file: 23.3 kB

Git LFS Details

  • SHA256: ac2e69de87e7f3cf767b13f43d2860c8fac677e429d9f70ef7e8e590bd332647
  • Pointer size: 130 Bytes
  • Size of remote file: 50 kB
samples/unet_256x384_0.jpg CHANGED

Git LFS Details

  • SHA256: 6ffafc696afa16f4b674ce8ae96957ef03fcdbcd951490eb392292ca935c6bb0
  • Pointer size: 130 Bytes
  • Size of remote file: 60.9 kB

Git LFS Details

  • SHA256: 3ae11dbe1e3ff86a70ca966d353ea67dfffb482284e3413e7f4ba2f15e778722
  • Pointer size: 130 Bytes
  • Size of remote file: 54.1 kB
samples/unet_320x384_0.jpg CHANGED

Git LFS Details

  • SHA256: dd7527d279fd00015f2cad9d8e4a7e245abd54caffe9bb3da45fa7e7359036fa
  • Pointer size: 130 Bytes
  • Size of remote file: 62 kB

Git LFS Details

  • SHA256: 9b33d8af5a779ffce31818e3dc0092c300250734ca0653bd8bec06eaa1b37f2d
  • Pointer size: 130 Bytes
  • Size of remote file: 43 kB
samples/unet_384x192_0.jpg CHANGED

Git LFS Details

  • SHA256: af38c88e6663334f9d67d857a0a5e6766520b09066eee676160d64ec94f35e00
  • Pointer size: 130 Bytes
  • Size of remote file: 44.3 kB

Git LFS Details

  • SHA256: d0e83f396229099798bc7dd38348b753008d674a813addedf235ea781f607d68
  • Pointer size: 130 Bytes
  • Size of remote file: 32.8 kB
samples/unet_384x256_0.jpg CHANGED

Git LFS Details

  • SHA256: 7eb455baf1cbd8cbca8b903d17b7814cd95d0e78b678be73d9c6ca81be56d791
  • Pointer size: 130 Bytes
  • Size of remote file: 48 kB

Git LFS Details

  • SHA256: 440f7dabbfc3eb4747d6c1ac541d5ef559f13a170ac22f4f6aa2d1b4c5e8efc5
  • Pointer size: 130 Bytes
  • Size of remote file: 39.7 kB
samples/unet_384x320_0.jpg CHANGED

Git LFS Details

  • SHA256: 252f58a3bfde894c162a061b79080d0f830fbcd67c2c4099936897166259e936
  • Pointer size: 130 Bytes
  • Size of remote file: 61.1 kB

Git LFS Details

  • SHA256: a05ac64e8139133dc627ee89014bb8ef0b6663d68a52140f1bd5520c0d9704d0
  • Pointer size: 130 Bytes
  • Size of remote file: 45.9 kB
samples/unet_384x384_0.jpg CHANGED

Git LFS Details

  • SHA256: b467fcecee636eec12c101aa0fdfbf8c3d5fd7449dc21d57ac0f901a2de39a6e
  • Pointer size: 130 Bytes
  • Size of remote file: 86 kB

Git LFS Details

  • SHA256: c88ac075bcc53c63679d296d43737b01b5d4ef66c518192ac7f1910a80934970
  • Pointer size: 130 Bytes
  • Size of remote file: 62.7 kB
train-Copy1.py ADDED
@@ -0,0 +1,679 @@
1
+ import os
2
+ import math
3
+ import torch
4
+ import numpy as np
5
+ import matplotlib.pyplot as plt
6
+ from torch.utils.data import DataLoader, Sampler
7
+ from torch.utils.data.distributed import DistributedSampler
8
+ from torch.optim.lr_scheduler import LambdaLR
9
+ from collections import defaultdict
10
+ from torch.optim.lr_scheduler import LambdaLR
11
+ from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
+ from accelerate import Accelerator
13
+ from datasets import load_from_disk
14
+ from tqdm import tqdm
15
+ from PIL import Image,ImageOps
16
+ import wandb
17
+ import random
18
+ import gc
19
+ from accelerate.state import DistributedType
20
+ from torch.distributed import broadcast_object_list
21
+ from torch.utils.checkpoint import checkpoint
22
+ from diffusers.models.attention_processor import AttnProcessor2_0
23
+ from datetime import datetime
24
+ import bitsandbytes as bnb
25
+ import torch.nn.functional as F
26
+ from collections import deque
27
+
28
+ # --------------------------- Parameters ---------------------------
29
+ ds_path = "/workspace/sdxs3d/datasets/mjnj"
30
+ project = "unet"
31
+ batch_size = 128
32
+ base_learning_rate = 9e-5
33
+ min_learning_rate = 1e-5
34
+ num_epochs = 84
35
+ # samples/save per epoch
36
+ sample_interval_share = 5
37
+ use_wandb = True
38
+ use_comet_ml = False
39
+ save_model = True
40
+ use_decay = True
41
+ fbp = False # fused backward pass
42
+ optimizer_type = "adam8bit"
43
+ torch_compile = False
44
+ unet_gradient = True
45
+ clip_sample = False #Scheduler
46
+ fixed_seed = True
47
+ shuffle = True
48
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Comet ML API key
49
+ comet_ml_workspace = "recoilme" # Comet ML workspace
50
+ torch.backends.cuda.matmul.allow_tf32 = True
51
+ torch.backends.cudnn.allow_tf32 = True
52
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
53
+ dtype = torch.float32
54
+ save_barrier = 1.01
55
+ warmup_percent = 0.01
56
+ percentile_clipping = 99 # 8bit optim
57
+ betta2 = 0.995
58
+ eps = 1e-8
59
+ clip_grad_norm = 1.0
60
+ steps_offset = 0 # Scheduler
61
+ limit = 0
62
+ checkpoints_folder = ""
63
+ mixed_precision = "no" #"fp16"
64
+ gradient_accumulation_steps = 1
65
+ accelerator = Accelerator(
66
+ mixed_precision=mixed_precision,
67
+ gradient_accumulation_steps=gradient_accumulation_steps
68
+ )
69
+ device = accelerator.device
70
+
71
+ # Diffusion parameters
72
+ n_diffusion_steps = 50
73
+ samples_to_generate = 12
74
+ guidance_scale = 1
75
+
76
+ # Folders for saving results
77
+ generated_folder = "samples"
78
+ os.makedirs(generated_folder, exist_ok=True)
79
+
80
+ # Seed setup for reproducibility
81
+ current_date = datetime.now()
82
+ seed = int(current_date.strftime("%Y%m%d"))
83
+ if fixed_seed:
84
+ torch.manual_seed(seed)
85
+ np.random.seed(seed)
86
+ random.seed(seed)
87
+ if torch.cuda.is_available():
88
+ torch.cuda.manual_seed_all(seed)
89
+
90
+ # --------------------------- LoRA parameters ---------------------------
91
+ lora_name = ""
92
+ lora_rank = 32
93
+ lora_alpha = 64
94
+
95
+ print("init")
96
+
97
+ # --------------------------- helper functions ---------------------------
98
+ def sample_timesteps_bias(
99
+ batch_size: int,
100
+ progress: float, # [0..1]
101
+ num_train_timesteps: int, # usually 1000
102
+ steps_offset: int = 0,
103
+ device=None,
104
+ mode: str = "beta", # "beta", "uniform"
105
+ ) -> torch.Tensor:
106
+ """
107
+ Возвращает timesteps с разным bias:
108
+ - beta : как раньше (сдвиг в начало или конец в зависимости от progress)
109
+ - normal : около середины (гауссовое распределение)
110
+ - uniform: равномерно по всем timestep’ам
111
+ """
112
+
113
+ max_idx = num_train_timesteps - 1 - steps_offset
114
+
115
+ if mode == "beta":
116
+ alpha = 1.0 + .5 * (1.0 - progress)
117
+ beta = 1.0 + .5 * progress
118
+ samples = torch.distributions.Beta(alpha, beta).sample((batch_size,))
119
+
120
+ elif mode == "uniform":
121
+ samples = torch.rand(batch_size)
122
+
123
+ else:
124
+ raise ValueError(f"Unknown mode: {mode}")
125
+
126
+ timesteps = steps_offset + (samples * max_idx).long().to(device)
127
+ return timesteps
128
+
129
+ def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
+ normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
+
132
+ logit_normal_samples = torch.sigmoid(normal_samples)
133
+
134
+ return logit_normal_samples
135
+
136
+ # --------------------------- WandB initialization ---------------------------
137
+ if accelerator.is_main_process:
138
+ if use_wandb:
139
+ wandb.init(project=project+lora_name, config={
140
+ "batch_size": batch_size,
141
+ "base_learning_rate": base_learning_rate,
142
+ "num_epochs": num_epochs,
143
+ "fbp": fbp,
144
+ "optimizer_type": optimizer_type,
145
+ })
146
+ if use_comet_ml:
147
+ from comet_ml import Experiment
148
+ comet_experiment = Experiment(
149
+ api_key=comet_ml_api_key,
150
+ project_name=project,
151
+ workspace=comet_ml_workspace
152
+ )
153
+ # Log hyperparameters to Comet ML
154
+ hyper_params = {
155
+ "batch_size": batch_size,
156
+ "base_learning_rate": base_learning_rate,
157
+ "min_learning_rate": min_learning_rate,
158
+ "num_epochs": num_epochs,
159
+ "n_diffusion_steps": n_diffusion_steps,
160
+ "guidance_scale": guidance_scale,
161
+ "optimizer_type": optimizer_type,
162
+ "mixed_precision": mixed_precision,
163
+ }
164
+ comet_experiment.log_parameters(hyper_params)
165
+
166
+ # Enable Flash Attention 2/SDPA
167
+ torch.backends.cuda.enable_flash_sdp(True)
168
+ # --------------------------- Accelerator initialization --------------------
169
+ gen = torch.Generator(device=device)
170
+ gen.manual_seed(seed)
171
+
172
+ # --------------------------- Model loading ---------------------------
173
+ # The VAE is loaded on CPU to save GPU memory (as in the original code)
174
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to("cpu").eval()
175
+
176
+ shift_factor = getattr(vae.config, "shift_factor", 0.0)
177
+ if shift_factor is None:
178
+ shift_factor = 0.0
179
+
180
+ scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
181
+ if scaling_factor is None:
182
+ scaling_factor = 1.0
183
+
184
+ latents_mean = getattr(vae.config, "latents_mean", None)
185
+ latents_std = getattr(vae.config, "latents_std", None)
186
+
187
+
188
+
189
+
190
+ class DistributedResolutionBatchSampler(Sampler):
191
+ def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
192
+ self.dataset = dataset
193
+ self.batch_size = max(1, batch_size // num_replicas)
194
+ self.num_replicas = num_replicas
195
+ self.rank = rank
196
+ self.shuffle = shuffle
197
+ self.drop_last = drop_last
198
+ self.epoch = 0
199
+
200
+ try:
201
+ widths = np.array(dataset["width"])
202
+ heights = np.array(dataset["height"])
203
+ except KeyError:
204
+ widths = np.zeros(len(dataset))
205
+ heights = np.zeros(len(dataset))
206
+
207
+ self.size_keys = np.unique(np.stack([widths, heights], axis=1), axis=0)
208
+ self.size_groups = {}
209
+ for w, h in self.size_keys:
210
+ mask = (widths == w) & (heights == h)
211
+ self.size_groups[(w, h)] = np.where(mask)[0]
212
+
213
+ self.group_num_batches = {}
214
+ total_batches = 0
215
+ for size, indices in self.size_groups.items():
216
+ num_full_batches = len(indices) // (self.batch_size * self.num_replicas)
217
+ self.group_num_batches[size] = num_full_batches
218
+ total_batches += num_full_batches
219
+
220
+ self.num_batches = (total_batches // self.num_replicas) * self.num_replicas
221
+
222
+ def __iter__(self):
223
+ if torch.cuda.is_available():
224
+ torch.cuda.empty_cache()
225
+ all_batches = []
226
+ rng = np.random.RandomState(self.epoch)
227
+
228
+ for size, indices in self.size_groups.items():
229
+ indices = indices.copy()
230
+ if self.shuffle:
231
+ rng.shuffle(indices)
232
+ num_full_batches = self.group_num_batches[size]
233
+ if num_full_batches == 0:
234
+ continue
235
+ valid_indices = indices[:num_full_batches * self.batch_size * self.num_replicas]
236
+ batches = valid_indices.reshape(-1, self.batch_size * self.num_replicas)
237
+ start_idx = self.rank * self.batch_size
238
+ end_idx = start_idx + self.batch_size
239
+ gpu_batches = batches[:, start_idx:end_idx]
240
+ all_batches.extend(gpu_batches)
241
+
242
+ if self.shuffle:
243
+ rng.shuffle(all_batches)
244
+ accelerator.wait_for_everyone()
245
+ return iter(all_batches)
246
+
247
+ def __len__(self):
248
+ return self.num_batches
249
+
250
+ def set_epoch(self, epoch):
251
+ self.epoch = epoch
252
+
253
+ # Function that picks fixed samples grouped by resolution
254
+ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
255
+ size_groups = defaultdict(list)
256
+ try:
257
+ widths = dataset["width"]
258
+ heights = dataset["height"]
259
+ except KeyError:
260
+ widths = [0] * len(dataset)
261
+ heights = [0] * len(dataset)
262
+ for i, (w, h) in enumerate(zip(widths, heights)):
263
+ size = (w, h)
264
+ size_groups[size].append(i)
265
+
266
+ fixed_samples = {}
267
+ for size, indices in size_groups.items():
268
+ n_samples = min(samples_per_group, len(indices))
269
+ if len(size_groups)==1:
270
+ n_samples = samples_to_generate
271
+ if n_samples == 0:
272
+ continue
273
+ sample_indices = random.sample(indices, n_samples)
274
+ samples_data = [dataset[idx] for idx in sample_indices]
275
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
276
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
277
+ texts = [item["text"] for item in samples_data]
278
+ fixed_samples[size] = (latents, embeddings, texts)
279
+
280
+ print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
281
+ return fixed_samples
282
+
283
+ if limit > 0:
284
+ dataset = load_from_disk(ds_path).select(range(limit))
285
+ else:
286
+ dataset = load_from_disk(ds_path)
287
+
288
+ def collate_fn_simple(batch):
289
+ latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
290
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
291
+ return latents, embeddings
292
+
293
+ batch_sampler = DistributedResolutionBatchSampler(
294
+ dataset=dataset,
295
+ batch_size=batch_size,
296
+ num_replicas=accelerator.num_processes,
297
+ rank=accelerator.process_index,
298
+ shuffle=shuffle
299
+ )
300
+
301
+ dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
302
+ print("Total samples",len(dataloader))
303
+ dataloader = accelerator.prepare(dataloader)
304
+
305
+ start_epoch = 0
306
+ global_step = 0
307
+ total_training_steps = (len(dataloader) * num_epochs)
308
+ world_size = accelerator.state.num_processes
309
+
310
+ # Optionally load the model from the latest checkpoint (if it exists)
311
+ latest_checkpoint = os.path.join(checkpoints_folder, project)
312
+ if os.path.isdir(latest_checkpoint):
313
+ print("Загружаем UNet из чекпоинта:", latest_checkpoint)
314
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
315
+ if torch_compile:
316
+ print("compiling")
317
+ torch.set_float32_matmul_precision('high')
318
+ unet = torch.compile(unet)
319
+ print("compiling - ok")
320
+ if unet_gradient:
321
+ unet.enable_gradient_checkpointing()
322
+ unet.set_use_memory_efficient_attention_xformers(False)
323
+ try:
324
+ unet.set_attn_processor(AttnProcessor2_0())
325
+ except Exception as e:
326
+ print(f"Ошибка при включении SDPA: {e}")
327
+ unet.set_use_memory_efficient_attention_xformers(True)
328
+
329
+ else:
330
+ # FIX: if there is no checkpoint, stop with a clear error (better than an unexpected NameError later)
331
+ raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Place a UNet checkpoint at {latest_checkpoint} or set a different path.")
332
+
333
+ if lora_name:
334
+ print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
335
+ from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
336
+ from peft.tuners.lora import LoraModel
337
+ import os
338
+ unet.requires_grad_(False)
339
+ print("Параметры базового UNet заморожены.")
340
+
341
+ lora_config = LoraConfig(
342
+ r=lora_rank,
343
+ lora_alpha=lora_alpha,
344
+ target_modules=["to_q", "to_k", "to_v", "to_out.0"],
345
+ )
346
+ unet.add_adapter(lora_config)
347
+
348
+ from peft import get_peft_model
349
+ peft_unet = get_peft_model(unet, lora_config)
350
+ params_to_optimize = list(p for p in peft_unet.parameters() if p.requires_grad)
351
+
352
+ if accelerator.is_main_process:
353
+ lora_params_count = sum(p.numel() for p in params_to_optimize)
354
+ total_params_count = sum(p.numel() for p in unet.parameters())
355
+ print(f"Количество обучаемых параметров (LoRA): {lora_params_count:,}")
356
+ print(f"Общее количество параметров UNet: {total_params_count:,}")
357
+
358
+ lora_save_path = os.path.join("lora", lora_name)
359
+ os.makedirs(lora_save_path, exist_ok=True)
360
+
361
+ def save_lora_checkpoint(model):
362
+ if accelerator.is_main_process:
363
+ print(f"Сохраняем LoRA адаптеры в {lora_save_path}")
364
+ from peft.utils.save_and_load import get_peft_model_state_dict
365
+ lora_state_dict = get_peft_model_state_dict(model)
366
+ torch.save(lora_state_dict, os.path.join(lora_save_path, "adapter_model.bin"))
367
+ model.peft_config["default"].save_pretrained(lora_save_path)
368
+ from diffusers import StableDiffusionXLPipeline
369
+ StableDiffusionXLPipeline.save_lora_weights(lora_save_path, lora_state_dict)
370
+
371
+ # --------------------------- Optimizer ---------------------------
372
+ if lora_name:
373
+ trainable_params = [p for p in unet.parameters() if p.requires_grad]
374
+ else:
375
+ if fbp:
376
+ trainable_params = list(unet.parameters())
377
+
378
+ def create_optimizer(name, params):
379
+ if name == "adam8bit":
380
+ return bnb.optim.AdamW8bit(
381
+ params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.01,
382
+ percentile_clipping=percentile_clipping
383
+ )
384
+ elif name == "adam":
385
+ return torch.optim.AdamW(
386
+ params, lr=base_learning_rate, betas=(0.9, 0.999), eps=1e-8, weight_decay=0.01
387
+ )
388
+ elif name == "lion8bit":
389
+ return bnb.optim.Lion8bit(
390
+ params, lr=base_learning_rate, betas=(0.9, 0.97), weight_decay=0.01,
391
+ percentile_clipping=percentile_clipping
392
+ )
393
+ elif name == "adafactor":
394
+ from transformers import Adafactor
395
+ return Adafactor(
396
+ params, lr=base_learning_rate, scale_parameter=True, relative_step=False,
397
+ warmup_init=False, eps=(1e-30, 1e-3), clip_threshold=1.0,
398
+ beta1=0.9, weight_decay=0.01
399
+ )
400
+ else:
401
+ raise ValueError(f"Unknown optimizer: {name}")
402
+
403
+ if fbp:
404
+ optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}
405
+ def optimizer_hook(param):
406
+ optimizer_dict[param].step()
407
+ optimizer_dict[param].zero_grad(set_to_none=True)
408
+ for param in trainable_params:
409
+ param.register_post_accumulate_grad_hook(optimizer_hook)
410
+ unet, optimizer = accelerator.prepare(unet, optimizer_dict)
411
+ else:
412
+ optimizer = create_optimizer(optimizer_type, unet.parameters())
413
+ def lr_schedule(step):
414
+ x = step / (total_training_steps * world_size)
415
+ warmup = warmup_percent
416
+ if not use_decay:
417
+ return base_learning_rate
418
+ if x < warmup:
419
+ return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
420
+ decay_ratio = (x - warmup) / (1 - warmup)
421
+ return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
422
+ (1 + math.cos(math.pi * decay_ratio))
423
+ lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
424
+
425
+ num_params = sum(p.numel() for p in unet.parameters())
426
+ print(f"[rank {accelerator.process_index}] total params: {num_params}")
427
+ for name, param in unet.named_parameters():
428
+ if torch.isnan(param).any() or torch.isinf(param).any():
429
+ print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
430
+ unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
431
+
432
+ # --------------------------- Fixed samples for generation ---------------------------
433
+ fixed_samples = get_fixed_samples_by_resolution(dataset)
434
+
435
+ @torch.compiler.disable()
436
+ @torch.no_grad()
437
+ def generate_and_save_samples(fixed_samples_cpu, step):
438
+ original_model = None
439
+ try:
440
+ original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
441
+ vae.to(device=device).eval() # temporarily move the VAE to the GPU for decoding
442
+
443
+
444
+ all_generated_images = []
445
+ all_captions = []
446
+
447
+ for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
448
+ width, height = size
449
+ sample_latents = sample_latents.to(dtype=dtype, device=device)
450
+ sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
451
+
452
+ noise = torch.randn(
453
+ sample_latents.shape,
454
+ generator=gen,
455
+ device=device,
456
+ dtype=sample_latents.dtype
457
+ )
458
+ current_latents = noise.clone()
459
+
460
+ if guidance_scale != 1:
461
+ empty_embeddings = torch.zeros_like(sample_text_embeddings, dtype=sample_text_embeddings.dtype, device=device)
462
+ text_embeddings_batch = torch.cat([empty_embeddings, sample_text_embeddings], dim=0)
463
+ else:
464
+ text_embeddings_batch = sample_text_embeddings
465
+
466
+ timesteps = torch.linspace(0, 1, n_diffusion_steps+1, device=device, dtype=sample_latents.dtype)
467
+ for i in range(0, n_diffusion_steps):
468
+ t_cur = timesteps[i].unsqueeze(0)
469
+ t_next = timesteps[i+1]
470
+ dt = t_next - t_cur
471
+ if guidance_scale != 1:
472
+ latent_model_input = torch.cat((current_latents, current_latents))
473
+ else:
474
+ latent_model_input = current_latents
475
+ t_batch = t_cur.repeat(latent_model_input.shape[0]).to(device)
476
+ t_batch = (t_batch * 1000).long().view(-1)
477
+ flow = original_model(latent_model_input, t_batch, text_embeddings_batch).sample
478
+
479
+ if guidance_scale != 1:
480
+ flow_uncond, flow_cond = flow.chunk(2)
481
+ flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
482
+
483
+ current_latents = current_latents + flow * dt.to(device)
484
+
485
+ # Normalization parameters
486
+ latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
487
+
488
+ decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
489
+ #decoded = decoded[:, :, 0, :, :] # [3, H, W]
490
+ #print(decoded.ndim, decoded.shape)
491
+
492
+ decoded_fp32 = decoded.to(torch.float32)
493
+ for img_idx, img_tensor in enumerate(decoded_fp32):
494
+
495
+ # Shape: [3, H, W] -> convert to [H, W, 3]
496
+ img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
497
+ img = img.transpose(1, 2, 0) # from [3, H, W] to [H, W, 3]
498
+
499
+ #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
500
+ if np.isnan(img).any():
501
+ print("NaNs found, saving stopped! Step:", step)
502
+ pil_img = Image.fromarray((img * 255).astype("uint8"))
503
+
504
+ max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
505
+ max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
506
+ max_w_overall = max(255, max_w_overall)
507
+ max_h_overall = max(255, max_h_overall)
508
+
509
+ padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
510
+ all_generated_images.append(padded_img)
511
+
512
+ caption_text = sample_text[img_idx][:200] if img_idx < len(sample_text) else ""
513
+ all_captions.append(caption_text)
514
+
515
+ sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
516
+ pil_img.save(sample_path, "JPEG", quality=96)
517
+
518
+ if use_wandb and accelerator.is_main_process:
519
+ wandb_images = [
520
+ wandb.Image(img, caption=f"{all_captions[i]}")
521
+ for i, img in enumerate(all_generated_images)
522
+ ]
523
+ wandb.log({"generated_images": wandb_images})
524
+ if use_comet_ml and accelerator.is_main_process:
525
+ for i, img in enumerate(all_generated_images):
526
+ comet_experiment.log_image(
527
+ image_data=img,
528
+ name=f"step_{step}_img_{i}",
529
+ step=step,
530
+ metadata={
531
+ "caption": all_captions[i],
532
+ "width": img.width,
533
+ "height": img.height,
534
+ "global_step": step
535
+ }
536
+ )
537
+ finally:
538
+ # move the VAE back to CPU (as in the original code)
539
+ vae.to("cpu")
540
+ for var in list(locals().keys()):
541
+ if isinstance(locals()[var], torch.Tensor):
542
+ del locals()[var]
543
+ torch.cuda.empty_cache()
544
+ gc.collect()
545
+
546
+ # --------------------------- Sample generation before training ---------------------------
547
+ if accelerator.is_main_process:
548
+ if save_model:
549
+ print("Генерация сэмплов до старта обучения...")
550
+ generate_and_save_samples(fixed_samples,0)
551
+ accelerator.wait_for_everyone()
552
+
553
+ # Modify the model-saving function to support LoRA
554
+ def save_checkpoint(unet,variant=""):
555
+ if accelerator.is_main_process:
556
+ if lora_name:
557
+ save_lora_checkpoint(unet)
558
+ else:
559
+ if variant!="":
560
+ accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
561
+ else:
562
+ accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
563
+ unet = unet.to(dtype=dtype)
564
+
565
+ # --------------------------- Training loop ---------------------------
566
+ if accelerator.is_main_process:
567
+ print(f"Total steps per GPU: {total_training_steps}")
568
+
569
+ epoch_loss_points = []
570
+ progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
571
+
572
+ steps_per_epoch = len(dataloader)
573
+ sample_interval = max(1, steps_per_epoch // sample_interval_share)
574
+ min_loss = 2.
575
+
576
+ for epoch in range(start_epoch, start_epoch + num_epochs):
577
+ batch_losses = []
578
+ batch_grads = []
579
+ batch_sampler.set_epoch(epoch)
580
+ accelerator.wait_for_everyone()
581
+ unet.train()
582
+ #print("epoch:",epoch)
583
+ for step, (latents, embeddings) in enumerate(dataloader):
584
+ with accelerator.accumulate(unet):
585
+ if save_model == False and step == 5:
586
+ used_gb = torch.cuda.max_memory_allocated() / 1024**3
587
+ print(f"Step {step}: {used_gb:.2f} GB")
588
+
589
+ noise = torch.randn_like(latents, dtype=latents.dtype)
590
+ t = logit_normal_samples((latents.shape[0], 1, 1, 1), mu=0.0, sigma=1.0, device=latents.device, dtype=latents.dtype)
591
+ noisy_latents = (1 - t) * noise + t * latents
592
+
593
+ t_for_unet = (t * 1000).long().view(-1)
594
+ model_pred = unet(noisy_latents, t_for_unet, embeddings).sample
595
+ target_pred = latents - noise
596
+
597
+ mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
598
+
599
+ # Keep for logging (MSE is stored separately, as an indicator)
600
+ batch_losses.append(mse_loss.detach().item())
601
+
602
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
603
+ accelerator.wait_for_everyone()
604
+
605
+ # Backward
606
+ accelerator.backward(mse_loss)
607
+
608
+ if (global_step % 100 == 0) or (global_step % sample_interval == 0):
609
+ accelerator.wait_for_everyone()
610
+
611
+ grad = 0.0
612
+ if not fbp:
613
+ if accelerator.sync_gradients:
614
+ with torch.amp.autocast('cuda', enabled=False):
615
+ grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
616
+ grad = float(grad_val)
617
+ optimizer.step()
618
+ lr_scheduler.step()
619
+ optimizer.zero_grad(set_to_none=True)
620
+
621
+ global_step += 1
622
+ progress_bar.update(1)
623
+
624
+ # Log metrics
625
+ if accelerator.is_main_process:
626
+ if fbp:
627
+ current_lr = base_learning_rate
628
+ else:
629
+ current_lr = lr_scheduler.get_last_lr()[0]
630
+ batch_grads.append(grad)
631
+
632
+ log_data = {}
633
+ log_data["loss"] = mse_loss.detach().item()
634
+ log_data["lr"] = current_lr
635
+ log_data["grad"] = grad
636
+ if accelerator.sync_gradients:
637
+ if use_wandb:
638
+ wandb.log(log_data, step=global_step)
639
+ if use_comet_ml:
640
+ comet_experiment.log_metrics(log_data, step=global_step)
641
+
642
+ # Generate samples at the configured interval
643
+ if global_step % sample_interval == 0:
644
+ generate_and_save_samples(fixed_samples,global_step)
645
+ last_n = sample_interval
646
+
647
+ if save_model:
648
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses)>0 else 0.0
649
+ print("saving:",avg_sample_loss < min_loss*save_barrier,"Current:",log_data["loss"],"Avg:",avg_sample_loss)
650
+ if log_data["loss"] < min_loss*save_barrier:
651
+ min_loss = avg_sample_loss
652
+ save_checkpoint(unet)
653
+ if use_wandb:
654
+ wandb.log(log_data, step=global_step)
655
+ if use_comet_ml:
656
+ comet_experiment.log_metrics(log_data, step=global_step)
657
+
658
+
659
+ if accelerator.is_main_process:
660
+ avg_epoch_loss = np.mean(batch_losses[-steps_per_epoch:]) if len(batch_losses)>0 else 0.0
661
+ avg_epoch_grad = np.mean(batch_grads[-steps_per_epoch:]) if len(batch_grads)>0 else 0.0
662
+ print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
663
+ if use_wandb:
664
+ wandb.log({"epoch_loss": avg_epoch_loss, "epoch_grad": avg_epoch_grad, "epoch": epoch+1})
665
+ #if use_comet_ml:
666
+ # comet_experiment.log_metrics(epoch_data)
667
+
668
+ # Training finished - save the final model
669
+ if accelerator.is_main_process:
670
+ print("Training complete! Saving the final model...")
671
+ if save_model:
672
+ save_checkpoint(unet,"fp16")
673
+ if use_comet_ml:
674
+ comet_experiment.end()
675
+ accelerator.free_memory()
676
+ if torch.distributed.is_initialized():
677
+ torch.distributed.destroy_process_group()
678
+
679
+ print("Готово!")
train.py CHANGED
@@ -7,12 +7,11 @@ from torch.utils.data import DataLoader, Sampler
7
  from torch.utils.data.distributed import DistributedSampler
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
- from torch.optim.lr_scheduler import LambdaLR
11
- from diffusers import UNet2DConditionModel, AutoencoderKLWan,AutoencoderKL
12
  from accelerate import Accelerator
13
  from datasets import load_from_disk
14
  from tqdm import tqdm
15
- from PIL import Image,ImageOps
16
  import wandb
17
  import random
18
  import gc
@@ -45,8 +44,8 @@ unet_gradient = True
45
  clip_sample = False #Scheduler
46
  fixed_seed = True
47
  shuffle = True
48
- comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r" # Добавлен API ключ для Comet ML
49
- comet_ml_workspace = "recoilme" # Добавлен workspace для Comet ML
50
  torch.backends.cuda.matmul.allow_tf32 = True
51
  torch.backends.cudnn.allow_tf32 = True
52
  torch.backends.cuda.enable_mem_efficient_sdp(False)
@@ -60,7 +59,7 @@ clip_grad_norm = 1.0
60
  steps_offset = 0 # Scheduler
61
  limit = 0
62
  checkpoints_folder = ""
63
- mixed_precision = "no" #"fp16"
64
  gradient_accumulation_steps = 1
65
  accelerator = Accelerator(
66
  mixed_precision=mixed_precision,
@@ -88,7 +87,7 @@ if fixed_seed:
88
  torch.cuda.manual_seed_all(seed)
89
 
90
  # --------------------------- Параметры LoRA ---------------------------
91
- lora_name = ""
92
  lora_rank = 32
93
  lora_alpha = 64
94
 
@@ -103,13 +102,6 @@ def sample_timesteps_bias(
103
  device=None,
104
  mode: str = "beta", # "beta", "uniform"
105
  ) -> torch.Tensor:
106
- """
107
- Возвращает timesteps с разным bias:
108
- - beta : как раньше (сдвиг в начало или конец в зависимости от progress)
109
- - normal : около середины (гауссовое распределение)
110
- - uniform: равномерно по всем timestep’ам
111
- """
112
-
113
  max_idx = num_train_timesteps - 1 - steps_offset
114
 
115
  if mode == "beta":
@@ -126,17 +118,16 @@ def sample_timesteps_bias(
126
  timesteps = steps_offset + (samples * max_idx).long().to(device)
127
  return timesteps
128
 
 
129
  def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
130
  normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
131
-
132
  logit_normal_samples = torch.sigmoid(normal_samples)
133
-
134
  return logit_normal_samples
135
 
136
  # --------------------------- Инициализация WandB ---------------------------
137
  if accelerator.is_main_process:
138
  if use_wandb:
139
- wandb.init(project=project+lora_name, config={
140
  "batch_size": batch_size,
141
  "base_learning_rate": base_learning_rate,
142
  "num_epochs": num_epochs,
@@ -150,7 +141,6 @@ if accelerator.is_main_process:
150
  project_name=project,
151
  workspace=comet_ml_workspace
152
  )
153
- # Логируем гиперпараметры в Comet ML
154
  hyper_params = {
155
  "batch_size": batch_size,
156
  "base_learning_rate": base_learning_rate,
@@ -170,8 +160,7 @@ gen = torch.Generator(device=device)
170
  gen.manual_seed(seed)
171
 
172
  # --------------------------- Загрузка моделей ---------------------------
173
- # VAE загружается на CPU для экономии GPU-памяти (как в твоём оригинальном коде)
174
- vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae",subfolder="simple_vae_nightly",torch_dtype=dtype).to("cpu").eval()
175
 
176
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
177
  if shift_factor is None:
@@ -185,8 +174,6 @@ latents_mean = getattr(vae.config, "latents_mean", None)
185
  latents_std = getattr(vae.config, "latents_std", None)
186
 
187
 
188
-
189
-
190
  class DistributedResolutionBatchSampler(Sampler):
191
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
192
  self.dataset = dataset
@@ -266,14 +253,16 @@ def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
266
  fixed_samples = {}
267
  for size, indices in size_groups.items():
268
  n_samples = min(samples_per_group, len(indices))
269
- if len(size_groups)==1:
270
  n_samples = samples_to_generate
271
  if n_samples == 0:
272
  continue
273
  sample_indices = random.sample(indices, n_samples)
274
  samples_data = [dataset[idx] for idx in sample_indices]
275
- latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device,dtype=dtype)
276
- embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data])).to(device,dtype=dtype)
 
 
277
  texts = [item["text"] for item in samples_data]
278
  fixed_samples[size] = (latents, embeddings, texts)
279
 
@@ -285,9 +274,10 @@ if limit > 0:
285
  else:
286
  dataset = load_from_disk(ds_path)
287
 
 
288
  def collate_fn_simple(batch):
289
- latents = torch.tensor(np.array([item["vae"] for item in batch])).to(device,dtype=dtype)
290
- embeddings = torch.tensor(np.array([item["embeddings"] for item in batch])).to(device,dtype=dtype)
291
  return latents, embeddings
292
 
293
  batch_sampler = DistributedResolutionBatchSampler(
@@ -298,20 +288,23 @@ batch_sampler = DistributedResolutionBatchSampler(
298
  shuffle=shuffle
299
  )
300
 
 
 
 
301
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
302
- print("Total samples",len(dataloader))
303
  dataloader = accelerator.prepare(dataloader)
304
 
305
- start_epoch = 0
306
- global_step = 0
307
- total_training_steps = (len(dataloader) * num_epochs)
308
- world_size = accelerator.state.num_processes
309
 
310
- # Опция загрузки модели из последнего чекпоинта (если существует)
311
  latest_checkpoint = os.path.join(checkpoints_folder, project)
312
  if os.path.isdir(latest_checkpoint):
313
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
314
- unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device,dtype=dtype)
315
  if torch_compile:
316
  print("compiling")
317
  torch.set_float32_matmul_precision('high')
@@ -325,16 +318,14 @@ if os.path.isdir(latest_checkpoint):
325
  except Exception as e:
326
  print(f"Ошибка при включении SDPA: {e}")
327
  unet.set_use_memory_efficient_attention_xformers(True)
328
-
329
  else:
330
- # FIX: если чекпоинта нет — прекращаем с понятной ошибкой (лучше, чем неожиданные NameError дальше)
331
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
332
 
 
333
  if lora_name:
334
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
335
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
336
  from peft.tuners.lora import LoraModel
337
- import os
338
  unet.requires_grad_(False)
339
  print("Параметры базового UNet заморожены.")
340
 
@@ -372,8 +363,8 @@ if lora_name:
372
  if lora_name:
373
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
374
  else:
375
- if fbp:
376
- trainable_params = list(unet.parameters())
377
 
378
  def create_optimizer(name, params):
379
  if name == "adam8bit":
@@ -407,11 +398,15 @@ if fbp:
407
  optimizer_dict[param].zero_grad(set_to_none=True)
408
  for param in trainable_params:
409
  param.register_post_accumulate_grad_hook(optimizer_hook)
 
410
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
411
  else:
412
- optimizer = create_optimizer(optimizer_type, unet.parameters())
 
 
413
  def lr_schedule(step):
414
- x = step / (total_training_steps * world_size)
 
415
  warmup = warmup_percent
416
  if not use_decay:
417
  return base_learning_rate
@@ -420,6 +415,8 @@ else:
420
  decay_ratio = (x - warmup) / (1 - warmup)
421
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
422
  (1 + math.cos(math.pi * decay_ratio))
 
 
423
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
424
 
425
  num_params = sum(p.numel() for p in unet.parameters())
@@ -427,12 +424,13 @@ else:
427
  for name, param in unet.named_parameters():
428
  if torch.isnan(param).any() or torch.isinf(param).any():
429
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
 
 
430
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
431
 
432
  # --------------------------- Фиксированные семплы для генерации ---------------------------
433
  fixed_samples = get_fixed_samples_by_resolution(dataset)
434
 
435
- @torch.compiler.disable()
436
  @torch.no_grad()
437
  def generate_and_save_samples(fixed_samples_cpu, step):
438
  original_model = None
@@ -440,12 +438,12 @@ def generate_and_save_samples(fixed_samples_cpu, step):
440
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
441
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
442
 
443
-
444
  all_generated_images = []
445
  all_captions = []
446
 
447
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
448
  width, height = size
 
449
  sample_latents = sample_latents.to(dtype=dtype, device=device)
450
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
451
 
@@ -484,19 +482,11 @@ def generate_and_save_samples(fixed_samples_cpu, step):
484
 
485
  # Параметры нормализации
486
  latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
487
-
488
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
489
- #decoded = decoded[:, :, 0, :, :] # [3, H, W]
490
- #print(decoded.ndim, decoded.shape)
491
-
492
  decoded_fp32 = decoded.to(torch.float32)
493
  for img_idx, img_tensor in enumerate(decoded_fp32):
494
-
495
- # Форма: [3, H, W] -> преобразуем в [H, W, 3]
496
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
497
- img = img.transpose(1, 2, 0) # Из [3, H, W] в [H, W, 3]
498
-
499
- #img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0)
500
  if np.isnan(img).any():
501
  print("NaNs found, saving stopped! Step:", step)
502
  pil_img = Image.fromarray((img * 255).astype("uint8"))
@@ -520,7 +510,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
520
  wandb.Image(img, caption=f"{all_captions[i]}")
521
  for i, img in enumerate(all_generated_images)
522
  ]
523
- wandb.log({"generated_images": wandb_images})
524
  if use_comet_ml and accelerator.is_main_process:
525
  for i, img in enumerate(all_generated_images):
526
  comet_experiment.log_image(
@@ -535,11 +525,7 @@ def generate_and_save_samples(fixed_samples_cpu, step):
535
  }
536
  )
537
  finally:
538
- # вернуть VAE на CPU (как было в твоём коде)
539
  vae.to("cpu")
540
- for var in list(locals().keys()):
541
- if isinstance(locals()[var], torch.Tensor):
542
- del locals()[var]
543
  torch.cuda.empty_cache()
544
  gc.collect()
545
 
@@ -547,20 +533,23 @@ def generate_and_save_samples(fixed_samples_cpu, step):
547
  if accelerator.is_main_process:
548
  if save_model:
549
  print("Генерация сэмплов до старта обучения...")
550
- generate_and_save_samples(fixed_samples,0)
551
  accelerator.wait_for_everyone()
552
 
553
  # Модифицируем функцию сохранения модели для поддержки LoRA
554
- def save_checkpoint(unet,variant=""):
555
  if accelerator.is_main_process:
556
  if lora_name:
557
  save_lora_checkpoint(unet)
558
  else:
559
- if variant!="":
560
- accelerator.unwrap_model(unet.to(dtype=torch.float16)).save_pretrained(os.path.join(checkpoints_folder, f"{project}"),variant=variant)
 
 
 
561
  else:
562
- accelerator.unwrap_model(unet).save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
563
- unet = unet.to(dtype=dtype)
564
 
565
  # --------------------------- Тренировочный цикл ---------------------------
566
  if accelerator.is_main_process:
@@ -569,20 +558,25 @@ if accelerator.is_main_process:
569
  epoch_loss_points = []
570
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
571
 
572
- steps_per_epoch = len(dataloader)
573
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
574
  min_loss = 2.
575
 
576
- for epoch in range(start_epoch, start_epoch + num_epochs):
577
- batch_losses = []
578
- batch_grads = []
579
- batch_sampler.set_epoch(epoch)
 
 
 
580
  accelerator.wait_for_everyone()
581
  unet.train()
582
- #print("epoch:",epoch)
 
 
 
583
  for step, (latents, embeddings) in enumerate(dataloader):
584
  with accelerator.accumulate(unet):
585
- if save_model == False and step == 5 :
586
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
587
  print(f"Шаг {step}: {used_gb:.2f} GB")
588
 
@@ -596,80 +590,94 @@ for epoch in range(start_epoch, start_epoch + num_epochs):
596
 
597
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
598
 
599
- # Сохраняем для логов (мы сохраняем MSE отдельно — как показатель)
600
- batch_losses.append(mse_loss.detach().item())
601
 
602
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
603
- accelerator.wait_for_everyone()
604
-
605
  # Backward
606
  accelerator.backward(mse_loss)
607
 
608
- if (global_step % 100 == 0) or (global_step % sample_interval == 0):
609
- accelerator.wait_for_everyone()
610
-
611
- grad = 0.0
612
  if not fbp:
613
  if accelerator.sync_gradients:
614
- with torch.amp.autocast('cuda', enabled=False):
615
- grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
616
- grad = float(grad_val)
617
  optimizer.step()
618
  lr_scheduler.step()
619
  optimizer.zero_grad(set_to_none=True)
620
 
621
- global_step += 1
622
- progress_bar.update(1)
623
-
624
- # Логируем метрики
625
- if accelerator.is_main_process:
626
- if fbp:
627
- current_lr = base_learning_rate
628
  else:
629
- current_lr = lr_scheduler.get_last_lr()[0]
630
- batch_grads.append(grad)
631
-
632
- log_data = {}
633
- log_data["loss"] = mse_loss.detach().item()
634
- log_data["lr"] = current_lr
635
- log_data["grad"] = grad
636
- if accelerator.sync_gradients:
637
- if use_wandb:
638
- wandb.log(log_data, step=global_step)
639
- if use_comet_ml:
640
- comet_experiment.log_metrics(log_data, step=global_step)
641
 
642
- # Генерируем сэмплы с заданным интервалом
643
- if global_step % sample_interval == 0:
644
- generate_and_save_samples(fixed_samples,global_step)
645
- last_n = sample_interval
646
-
647
- if save_model:
648
- avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses)>0 else 0.0
649
- print("saving:",avg_sample_loss < min_loss*save_barrier,"Current:",log_data["loss"],"Avg:",avg_sample_loss)
650
- if log_data["loss"] < min_loss*save_barrier:
651
- min_loss = avg_sample_loss
652
- save_checkpoint(unet)
653
  if use_wandb:
654
  wandb.log(log_data, step=global_step)
655
  if use_comet_ml:
656
  comet_experiment.log_metrics(log_data, step=global_step)
657
 
658
-
659
  if accelerator.is_main_process:
660
- avg_epoch_loss = np.mean(batch_losses[-steps_per_epoch:]) if len(batch_losses)>0 else 0.0
661
- avg_epoch_grad = np.mean(batch_grads[-steps_per_epoch:]) if len(batch_grads)>0 else 0.0
662
- print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
 
 
 
 
663
  if use_wandb:
664
- wandb.log({"epoch_loss": avg_epoch_loss, "epoch_grad": avg_epoch_grad, "epoch": epoch+1})
665
- #if use_comet_ml:
666
- # comet_experiment.log_metrics(epoch_data)
667
 
668
  # Завершение обучения - сохраняем финальную модель
669
  if accelerator.is_main_process:
670
  print("Обучение завершено! Сохраняем финальную модель...")
671
  if save_model:
672
- save_checkpoint(unet,"fp16")
673
  if use_comet_ml:
674
  comet_experiment.end()
675
  accelerator.free_memory()
 
7
  from torch.utils.data.distributed import DistributedSampler
8
  from torch.optim.lr_scheduler import LambdaLR
9
  from collections import defaultdict
10
+ from diffusers import UNet2DConditionModel, AutoencoderKL
 
11
  from accelerate import Accelerator
12
  from datasets import load_from_disk
13
  from tqdm import tqdm
14
+ from PIL import Image, ImageOps
15
  import wandb
16
  import random
17
  import gc
 
44
  clip_sample = False #Scheduler
45
  fixed_seed = True
46
  shuffle = True
47
+ comet_ml_api_key = "Agctp26mbqnoYrrlvQuKSTk6r"
48
+ comet_ml_workspace = "recoilme"
49
  torch.backends.cuda.matmul.allow_tf32 = True
50
  torch.backends.cudnn.allow_tf32 = True
51
  torch.backends.cuda.enable_mem_efficient_sdp(False)
 
59
  steps_offset = 0 # Scheduler
60
  limit = 0
61
  checkpoints_folder = ""
62
+ mixed_precision = "no" # "fp16"
63
  gradient_accumulation_steps = 1
64
  accelerator = Accelerator(
65
  mixed_precision=mixed_precision,
 
87
  torch.cuda.manual_seed_all(seed)
88
 
89
  # --------------------------- Параметры LoRA ---------------------------
90
+ lora_name = ""
91
  lora_rank = 32
92
  lora_alpha = 64
93
 
 
102
  device=None,
103
  mode: str = "beta", # "beta", "uniform"
104
) -> torch.Tensor:
105
  max_idx = num_train_timesteps - 1 - steps_offset
106
 
107
  if mode == "beta":
 
118
  timesteps = steps_offset + (samples * max_idx).long().to(device)
119
  return timesteps
120
 
121
+
122
  def logit_normal_samples(shape, mu=0.0, sigma=1.0, device=None, dtype=None):
123
  normal_samples = torch.normal(mean=mu, std=sigma, size=shape, device=device, dtype=dtype)
 
124
  logit_normal_samples = torch.sigmoid(normal_samples)
 
125
  return logit_normal_samples
126
 
127
  # --------------------------- Инициализация WandB ---------------------------
128
  if accelerator.is_main_process:
129
  if use_wandb:
130
+ wandb.init(project=project + lora_name, config={
131
  "batch_size": batch_size,
132
  "base_learning_rate": base_learning_rate,
133
  "num_epochs": num_epochs,
 
141
  project_name=project,
142
  workspace=comet_ml_workspace
143
  )
 
144
  hyper_params = {
145
  "batch_size": batch_size,
146
  "base_learning_rate": base_learning_rate,
 
160
  gen.manual_seed(seed)
161
 
162
  # --------------------------- Загрузка моделей ---------------------------
163
+ vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae", torch_dtype=dtype).to("cpu").eval()
 
164
 
165
  shift_factor = getattr(vae.config, "shift_factor", 0.0)
166
  if shift_factor is None:
 
174
  latents_std = getattr(vae.config, "latents_std", None)
175
 
176
 
 
 
177
  class DistributedResolutionBatchSampler(Sampler):
178
  def __init__(self, dataset, batch_size, num_replicas, rank, shuffle=True, drop_last=True):
179
  self.dataset = dataset
 
253
  fixed_samples = {}
254
  for size, indices in size_groups.items():
255
  n_samples = min(samples_per_group, len(indices))
256
+ if len(size_groups) == 1:
257
  n_samples = samples_to_generate
258
  if n_samples == 0:
259
  continue
260
  sample_indices = random.sample(indices, n_samples)
261
  samples_data = [dataset[idx] for idx in sample_indices]
262
+
263
+ # FIXED: keep fixed samples on CPU to avoid device/device-transfer issues when creating dataset
264
+ latents = torch.tensor(np.array([item["vae"] for item in samples_data]), dtype=dtype).cpu()
265
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in samples_data]), dtype=dtype).cpu()
266
  texts = [item["text"] for item in samples_data]
267
  fixed_samples[size] = (latents, embeddings, texts)
268
 
 
274
  else:
275
  dataset = load_from_disk(ds_path)
276
 
277
+
278
  def collate_fn_simple(batch):
279
+ latents = torch.tensor(np.array([item["vae"] for item in batch]), dtype=dtype).to(device)
280
+ embeddings = torch.tensor(np.array([item["embeddings"] for item in batch]), dtype=dtype).to(device)
281
  return latents, embeddings
282
 
283
  batch_sampler = DistributedResolutionBatchSampler(
 
288
  shuffle=shuffle
289
  )
290
 
291
+ # NOTE: we create dataloader first, then prepare it with accelerator. We'll create optimizer/lr_scheduler after
292
+ # we know len(dataloader) (which is per-process after prepare) so that scheduling is consistent.
293
+
294
  dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
295
+ print("Total batches (pre-prepare):", len(dataloader))
296
  dataloader = accelerator.prepare(dataloader)
297
 
298
+ # --------------------------- Now it is safe to compute step counts and create the optimizer/scheduler ---------------------------
299
+ steps_per_epoch = len(dataloader) # this is per-process (after prepare)
300
+ total_training_steps = steps_per_epoch * num_epochs
301
+ print(f"[rank {accelerator.process_index}] steps_per_epoch={steps_per_epoch}, total_training_steps={total_training_steps}")
302
 
303
+ # --------------------------- Create/load the UNet ---------------------------
304
  latest_checkpoint = os.path.join(checkpoints_folder, project)
305
  if os.path.isdir(latest_checkpoint):
306
  print("Загружаем UNet из чекпоинта:", latest_checkpoint)
307
+ unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
308
  if torch_compile:
309
  print("compiling")
310
  torch.set_float32_matmul_precision('high')
 
318
  except Exception as e:
319
  print(f"Ошибка при включении SDPA: {e}")
320
  unet.set_use_memory_efficient_attention_xformers(True)
 
321
  else:
 
322
  raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}. Положи UNet чекпоинт в {latest_checkpoint} или укажи другой путь.")
323
 
324
+ # --------------------------- LoRA (if needed) ---------------------------
325
  if lora_name:
326
  print(f"--- Настройка LoRA через PEFT (Rank={lora_rank}, Alpha={lora_alpha}) ---")
327
  from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training
328
  from peft.tuners.lora import LoraModel
 
329
  unet.requires_grad_(False)
330
  print("Параметры базового UNet заморожены.")
331
 
 
363
  if lora_name:
364
  trainable_params = [p for p in unet.parameters() if p.requires_grad]
365
  else:
366
+ trainable_params = list(unet.parameters()) if fbp else [p for p in unet.parameters() if p.requires_grad]
367
+
368
 
369
  def create_optimizer(name, params):
370
  if name == "adam8bit":
 
398
  optimizer_dict[param].zero_grad(set_to_none=True)
399
  for param in trainable_params:
400
  param.register_post_accumulate_grad_hook(optimizer_hook)
401
+ # FIXED: prepare fbp variant (keeps original logic)
402
  unet, optimizer = accelerator.prepare(unet, optimizer_dict)
403
  else:
404
+ optimizer = create_optimizer(optimizer_type, trainable_params)
405
+
406
+ # FIXED: LR schedule should be based on total_training_steps (per-process steps * epochs)
407
  def lr_schedule(step):
408
+ # step is current scheduler step (0..total_training_steps)
409
+ x = step / max(1, total_training_steps)
410
  warmup = warmup_percent
411
  if not use_decay:
412
  return base_learning_rate
 
415
  decay_ratio = (x - warmup) / (1 - warmup)
416
  return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
417
  (1 + math.cos(math.pi * decay_ratio))
418
+
419
+ # LambdaLR expects a multiplier, so divide by base_learning_rate
420
  lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
421
 
422
  num_params = sum(p.numel() for p in unet.parameters())
 
424
  for name, param in unet.named_parameters():
425
  if torch.isnan(param).any() or torch.isinf(param).any():
426
  print(f"[rank {accelerator.process_index}] NaN/Inf in {name}")
427
+
428
+ # FIXED: prepare model, optimizer, scheduler AFTER creating them and after dataloader.prepare
429
  unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
430
 
431
  # --------------------------- Фиксированные семплы для генерации ---------------------------
432
  fixed_samples = get_fixed_samples_by_resolution(dataset)
433
 
 
434
  @torch.no_grad()
435
  def generate_and_save_samples(fixed_samples_cpu, step):
436
  original_model = None
 
438
  original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
439
  vae.to(device=device).eval() # временно подгружаем VAE на GPU для декодинга
440
 
 
441
  all_generated_images = []
442
  all_captions = []
443
 
444
  for size, (sample_latents, sample_text_embeddings, sample_text) in fixed_samples_cpu.items():
445
  width, height = size
446
+ # move CPU tensors to device here (they were kept on CPU in get_fixed_samples_by_resolution)
447
  sample_latents = sample_latents.to(dtype=dtype, device=device)
448
  sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
449
 
 
482
 
483
  # Параметры нормализации
484
  latent_for_vae = current_latents.detach() / scaling_factor + shift_factor
 
485
  decoded = vae.decode(latent_for_vae.to(torch.float32)).sample
 
 
 
486
  decoded_fp32 = decoded.to(torch.float32)
487
  for img_idx, img_tensor in enumerate(decoded_fp32):
 
 
488
  img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
489
+ img = img.transpose(1, 2, 0)
 
 
490
  if np.isnan(img).any():
491
  print("NaNs found, saving stopped! Step:", step)
492
  pil_img = Image.fromarray((img * 255).astype("uint8"))
 
510
  wandb.Image(img, caption=f"{all_captions[i]}")
511
  for i, img in enumerate(all_generated_images)
512
  ]
513
+ wandb.log({"generated_images": wandb_images}, step=step)
514
  if use_comet_ml and accelerator.is_main_process:
515
  for i, img in enumerate(all_generated_images):
516
  comet_experiment.log_image(
 
525
  }
526
  )
527
  finally:
 
528
  vae.to("cpu")
 
 
 
529
  torch.cuda.empty_cache()
530
  gc.collect()
531
 
 
533
  if accelerator.is_main_process:
534
  if save_model:
535
  print("Генерация сэмплов до старта обучения...")
536
+ generate_and_save_samples(fixed_samples, 0)
537
  accelerator.wait_for_everyone()
538
 
539
  # Модифицируем функцию сохранения модели для поддержки LoRA
540
+ def save_checkpoint(unet, variant=""):
541
  if accelerator.is_main_process:
542
  if lora_name:
543
  save_lora_checkpoint(unet)
544
  else:
545
+ # FIXED: don't change dtype of model wrapped by accelerator. Unwrap and save as-is.
546
+ model_to_save = accelerator.unwrap_model(unet)
547
+ dest = os.path.join(checkpoints_folder, f"{project}")
548
+ if variant != "":
549
+ model_to_save.save_pretrained(dest, variant=variant)
550
  else:
551
+ model_to_save.save_pretrained(dest)
552
+
553
 
554
  # --------------------------- Тренировочный цикл ---------------------------
555
  if accelerator.is_main_process:
 
558
  epoch_loss_points = []
559
  progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
560
 
 
561
  sample_interval = max(1, steps_per_epoch // sample_interval_share)
562
  min_loss = 2.
563
 
564
+ for epoch in range(0, num_epochs):
565
+ # FIXED: set epoch on the dataloader's batch_sampler if available (accelerator may have wrapped it)
566
+ if hasattr(dataloader, "batch_sampler") and hasattr(dataloader.batch_sampler, "set_epoch"):
567
+ dataloader.batch_sampler.set_epoch(epoch)
568
+ elif hasattr(batch_sampler, "set_epoch"):
569
+ batch_sampler.set_epoch(epoch)
570
+
571
  accelerator.wait_for_everyone()
572
  unet.train()
573
+
574
+ batch_losses = []
575
+ batch_grads = []
576
+
577
  for step, (latents, embeddings) in enumerate(dataloader):
578
  with accelerator.accumulate(unet):
579
+ if save_model == False and step == 5:
580
  used_gb = torch.cuda.max_memory_allocated() / 1024**3
581
  print(f"Шаг {step}: {used_gb:.2f} GB")
582
 
 
590
 
591
  mse_loss = F.mse_loss(model_pred.float(), target_pred.float())
592
 
593
+ # Keep for local logging
594
+ batch_losses.append(mse_loss.detach().cpu().item())
595
 
 
 
 
596
  # Backward
597
  accelerator.backward(mse_loss)
598
 
599
+ grad_norm_val = 0.0
 
 
 
600
  if not fbp:
601
  if accelerator.sync_gradients:
602
+ # Clip gradients and step only when gradients are synchronized (i.e. actual optimizer step)
603
+ grad_norm = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
604
+ grad_norm_val = float(grad_norm)
605
  optimizer.step()
606
  lr_scheduler.step()
607
  optimizer.zero_grad(set_to_none=True)
608
 
609
+ # increment global_step only when we have synchronized gradients (i.e. on optimizer step)
610
+ # FIXED: ensure global_step reflects optimizer updates, not micro-batches
611
+ if accelerator.sync_gradients:
612
+ try:
613
+ global_step += 1
614
+ except NameError:
615
+ global_step = 1
616
+ progress_bar.update(1)
617
+
618
+ # Aggregate loss across processes for correct logging
619
+ loss_tensor = mse_loss.detach().clone()
620
+ # move to device if not already
621
+ loss_tensor = loss_tensor.to(device)
622
+ gathered = accelerator.gather(loss_tensor)
623
+ if accelerator.is_main_process:
624
+ reduced_loss = gathered.mean().item()
625
  else:
626
+ reduced_loss = None
627
+
628
+ # Log metrics only on the main process
629
+ if accelerator.is_main_process:
630
+ if fbp:
631
+ current_lr = base_learning_rate
632
+ else:
633
+ current_lr = lr_scheduler.get_last_lr()[0]
634
+
635
+ batch_grads.append(grad_norm_val)
636
+
637
+ log_data = {
638
+ "loss": reduced_loss,
639
+ "lr": current_lr,
640
+ "grad": grad_norm_val,
641
+ "epoch": epoch + 1,
642
+ "global_step": global_step,
643
+ }
644

645
  if use_wandb:
646
  wandb.log(log_data, step=global_step)
647
  if use_comet_ml:
648
  comet_experiment.log_metrics(log_data, step=global_step)
649
 
650
+ # Generate samples at the configured interval (main process only)
651
+ if global_step % sample_interval == 0:
652
+ generate_and_save_samples(fixed_samples, global_step)
653
+
654
+ if save_model:
655
+ # use recent local losses to decide saving (still local); you may want to use reduced_loss here
656
+ avg_sample_loss = np.mean(batch_losses[-sample_interval:]) if len(batch_losses) > 0 else 0.0
657
+ if use_wandb:
658
+ wandb.log({"sample_loss": avg_sample_loss)
659
+ print("saving:", reduced_loss is not None and reduced_loss < min_loss * save_barrier, "Current:", reduced_loss, "Avg:", avg_sample_loss)
660
+ if reduced_loss is not None and reduced_loss < min_loss * save_barrier:
661
+ min_loss = avg_sample_loss
662
+ save_checkpoint(unet)
663
+
664
+ # Epoch finished: aggregate and log the averages
665
  if accelerator.is_main_process:
666
+ # local averages
667
+ avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
668
+ avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
669
+
670
+ # FIXED: optionally reduce across processes if you want a true global epoch average
671
+ # Here we compute local values and log them (main process only). For global average use accelerator.gather
672
+ print(f"\nЭпоха {epoch} завершена. Средний лосс (local main proc): {avg_epoch_loss:.6f}")
673
  if use_wandb:
674
+ wandb.log({"epoch_loss": avg_epoch_loss, "epoch_grad": avg_epoch_grad, "epoch": epoch + 1})
 
 
675
 
676
  # Завершение обучения - сохраняем финальную модель
677
  if accelerator.is_main_process:
678
  print("Обучение завершено! Сохраняем финальную модель...")
679
  if save_model:
680
+ save_checkpoint(unet, "fp16")
681
  if use_comet_ml:
682
  comet_experiment.end()
683
  accelerator.free_memory()
unet/diffusion_pytorch_model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:77ac76865ef38f3cabffab60d7cbb5ef9ef1e18d5ad1117db6cd121545c39fd5
3
  size 6184944280
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e9471beb1dbfceb94fecef7d779b179a0d3d2ca499c2ed2f7593b3e68c3d3dc5
3
  size 6184944280
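For context, the unet/diffusion_pytorch_model.safetensors weights updated by this commit belong to a UNet2DConditionModel saved with save_pretrained, and train.py decodes latents with the AiArtLab/simplevae autoencoder. A rough inference sketch under those assumptions follows; the local checkpoint path, the latent shape, and the way text embeddings are obtained are placeholders (the training data ships precomputed embeddings), not something this commit defines.

import torch
from diffusers import UNet2DConditionModel, AutoencoderKL

device = "cuda"
# "unet" is the local checkpoint directory used in train.py (checkpoints_folder + project)
unet = UNet2DConditionModel.from_pretrained("unet").to(device).eval()
vae = AutoencoderKL.from_pretrained("AiArtLab/simplevae", subfolder="vae").to(device).eval()

scaling = getattr(vae.config, "scaling_factor", 1.0) or 1.0
shift = getattr(vae.config, "shift_factor", 0.0) or 0.0

@torch.no_grad()
def generate(text_embeddings, latent_shape=(1, 4, 48, 48), steps=50):
    # text_embeddings must match what the dataset stored at training time;
    # latent_shape is a placeholder and depends on the VAE and target resolution
    x = torch.randn(latent_shape, device=device)
    ts = torch.linspace(0, 1, steps + 1, device=device)
    for i in range(steps):
        t_batch = (ts[i].repeat(x.shape[0]) * 1000).long()
        x = x + unet(x, t_batch, text_embeddings).sample * (ts[i + 1] - ts[i])
    img = vae.decode(x / scaling + shift).sample
    return (img / 2 + 0.5).clamp(0, 1)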