import os

# NCCL and CUDA-allocator settings must already be in the environment when
# torch creates its CUDA context / process group, so set them before torch
# (or anything that imports torch) is imported.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"  # comment this on H100!
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

import argparse
import gc
import math
import random
import time
from collections import defaultdict
from contextlib import nullcontext
from datetime import datetime

import bitsandbytes as bnb
import comet_ml
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn.functional as F
import wandb
from accelerate import Accelerator
from datasets import load_from_disk
from diffusers import CosmosTransformer3DModel, AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler
from diffusers.models.attention_processor import AttnProcessor2_0
from PIL import Image, ImageOps
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.checkpoint import checkpoint
from torch.utils.data import DataLoader, Sampler
from tqdm import tqdm
from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
from transformers.optimization import Adafactor

# Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git
try:
    from muon_adamw8bit import MuonAdamW8bit
except ImportError:
    # Optional dependency: only needed when optimizer_type == "muon_adam8bit".
    # Selecting that optimizer without the package fails when the optimizer is created.
    MuonAdamW8bit = None

# --------------------------- Parameters ---------------------------
ds_path = "datasets/ds234_640_vae_qwen"
project = "transformer"

# Per-GPU batch size scales linearly with VRAM: 7 samples per 32 GB (at least 1).
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
local_bs = max(1, int((gpu_mem_gb / 32) * 7))
num_gpus = torch.cuda.device_count()
batch_size = local_bs * num_gpus

base_learning_rate = 4e-5
min_learning_rate = 4e-6
learning_rate_scale = 3
base_learning_rate = base_learning_rate / learning_rate_scale
min_learning_rate = min_learning_rate / learning_rate_scale
print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")

num_epochs = num_gpus
sink_interval_share = 10      # loss-window / sync interval = steps_per_epoch // this
sample_interval_min = 20      # minutes between sample generations
cfg_dropout = 0.10            # probability of replacing a caption with "" (unconditional)
# Timestep t = sigmoid(randn + sigmoid_bias):
#   bias = -0.5 -> focus on details (~300), bias = 0.5 -> focus on structure,
#   bias = 0 -> bell-shaped / symmetric around 0.5.
sigmoid_bias = 0.1
max_length = 250
use_precomputed_embeddings = False
use_wandb = False
use_comet_ml = False
save_model = True
use_decay = True
fbp = False
torch_compile = False
transformer_gradient = True
loss_normalize = False
fixed_seed = False
shuffle = True
optimizer_type = "adafactor"
if optimizer_type == "muon_adam8bit":
    batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
# Defined unconditionally so it is always available to the optimizer factory.
muon_lr_scale = 500

# Credentials must not live in source control. Comet ML also falls back to its
# own config / COMET_API_KEY when api_key is None.
# NOTE(review): the key previously hard-coded here should be revoked.
comet_ml_api_key = os.environ.get("COMET_API_KEY")
comet_ml_workspace = "recoilme"

torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)

save_barrier = 1.25           # save when loss < best_loss * save_barrier
warmup_percent = 0.0025
betta2 = 0.997
eps = 1e-6
clip_grad_norm = 1.0
limit = 0                     # >0: train on the first `limit` dataset rows only
checkpoints_folder = ""
gradient_accumulation_steps = 1
dtype = torch.float32
mixed_precision = "bf16"

# Diffusion sampling parameters
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 7.0

# Folder for generated samples
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)

# Seed setup (the date-based seed is also reused for sample generation).
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d")) + 42
if fixed_seed:
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)

accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps
)
device = accelerator.device
print("init")

# --------------------------- CLI overrides ---------------------------
parser = argparse.ArgumentParser(description='Train a model on a dataset.')
parser.add_argument('--ds-path', type=str, default=ds_path, help='Path to the dataset')
parser.add_argument('--ep', type=int, default=num_epochs, help='Number of epochs to train the model')
parser.add_argument('--batch', type=int, default=batch_size, help='Total batch size')
parser.add_argument('--min-lr', type=float, default=min_learning_rate, help='Minimum learning rate')
parser.add_argument('--max-lr', type=float, default=base_learning_rate, help='Maximum learning rate')
parser.add_argument('--dry-run', action='store_true', default=False, help='Dry run train without saving/sampling')
parser.add_argument('--lvl', type=float, default=0.0, help='Train level, from 0.5 to 5')
args = parser.parse_args()

batch_size = args.batch
ds_path = args.ds_path
base_learning_rate = args.max_lr
min_learning_rate = args.min_lr
num_epochs = args.ep
lvl = args.lvl
if args.dry_run:
    save_model = False
# A "train level" >= 0.1 divides both learning-rate bounds.
if lvl >= 0.1:
    base_learning_rate = base_learning_rate / lvl
    min_learning_rate = min_learning_rate / lvl
print(f"max-lr:{base_learning_rate} min-lr:{min_learning_rate}")

# --------------------------- Experiment tracking ---------------------------
if accelerator.is_main_process:
    if use_wandb:
        wandb.init(project=project, config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        })
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace
        )
        hyper_params = {
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
        }
        comet_experiment.log_parameters(hyper_params)

# --------------------------- Model loading ---------------------------
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).to(dtype=dtype).eval()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")

tokenizer = None
text_encoder = None


def load_text_encoder():
    """Lazily load the tokenizer and text encoder into module globals (no-op if already loaded)."""
    global tokenizer, text_encoder
    if tokenizer is None:
        tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
    if text_encoder is None:
        text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
            "text_encoder", torch_dtype=dtype
        ).to(device).eval()


load_text_encoder()


@torch.no_grad()
def encode_texts(text, max_length=max_length):
    """Encode one prompt or a list of prompts with the text encoder.

    Each prompt is wrapped in the tokenizer's chat template as a single user
    message, tokenized with padding/truncation to ``max_length``, and run
    through the text encoder.

    Args:
        text: a string, a list of strings, or None (treated as "").
        max_length: token length every prompt is padded/truncated to.

    Returns:
        (hidden, attention_mask): the second-to-last hidden state cast to
        ``dtype`` with every padded position zeroed (presumably shaped
        (batch, max_length, hidden_size)), and the int64 attention mask.
    """
    if text is None:
        text = ""
    if isinstance(text, str):
        text = [text]

    formatted_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": [{"type": "text", "text": t}]}],
            tokenize=False,
            add_generation_prompt=False,
        )
        for t in text
    ]
    toks = tokenizer(
        formatted_prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt"
    ).to(device)
    outputs = text_encoder(
        input_ids=toks.input_ids,
        attention_mask=toks.attention_mask,
        output_hidden_states=True
    )
    hidden = outputs.hidden_states[-2].to(dtype=dtype)
    # Zero exactly the positions the attention mask marks as padding. Unlike
    # slicing hidden[i, length:], this is correct for left- and right-padding.
    hidden = hidden.masked_fill(toks.attention_mask.unsqueeze(-1) == 0, 0)
    return hidden, toks.attention_mask.to(dtype=torch.int64)


@torch.no_grad()
def encode_texts_fast(text, max_length=max_length):
    """Alias of :func:`encode_texts`, kept for callers that still use this name."""
    return encode_texts(text, max_length)


# --------------------------- VAE latent statistics ---------------------------
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0

mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if mean is not None and std is not None:
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
    # Note: Cosmos uses an inverted std (1.0 / std) for decoding.
else:
    latents_std = None
    latents_mean = None

# Ensure the Cosmos-style sigma settings exist on the scheduler config.
if scheduler is not None:
    scheduler.register_to_config(
        sigma_max=getattr(scheduler.config, "sigma_max", 80.0),
        sigma_min=getattr(scheduler.config, "sigma_min", 0.002),
        sigma_data=getattr(scheduler.config, "sigma_data", 1.0),
        final_sigmas_type=getattr(scheduler.config, "final_sigmas_type", "sigma_min"),
    )


class DistributedResolutionBatchSampler(Sampler):
    """Yield per-rank index batches in which every sample has the same (width, height).

    Indices are grouped by resolution. Each group is cut into *global* batches
    of ``batch_size // num_replicas * num_replicas`` indices; every rank yields
    its own contiguous slice of each global batch. Incomplete global batches are
    always dropped. When ``shuffle`` is set, each group is permuted and the
    batch order shuffled once per epoch with an RNG seeded only by the epoch.
    All ranks therefore build identical global batches, and different tail
    samples are dropped in different epochs.
    """

    def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        # Kept for API compatibility: incomplete global batches are always dropped.
        self.drop_last = drop_last
        self.epoch = 0
        self.batch_size = max(1, batch_size // num_replicas)
        self.global_batch = self.batch_size * num_replicas

        try:
            widths = np.asarray(dataset["width"])
            heights = np.asarray(dataset["height"])
        except KeyError:
            # No resolution columns: treat the whole dataset as one group.
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))

        groups = {}
        for i, (w, h) in enumerate(zip(widths, heights)):
            groups.setdefault((w, h), []).append(i)
        self._groups = [np.asarray(indices, dtype=np.int64) for indices in groups.values()]

        # Unshuffled layout; also fixes the number of batches per epoch.
        self.global_batches = self._build_global_batches()
        self.num_batches = len(self.global_batches)

    def _build_global_batches(self, rng=None):
        """Cut every group into full global batches, permuting each group first if ``rng`` is given."""
        all_batches = []
        for idx in self._groups:
            num_batches = len(idx) // self.global_batch
            if num_batches == 0:
                continue
            if rng is not None:
                idx = rng.permutation(idx)
            idx = idx[: num_batches * self.global_batch]
            all_batches.append(idx.reshape(num_batches, self.global_batch))
        if all_batches:
            return np.concatenate(all_batches, axis=0)
        return np.empty((0, self.global_batch), dtype=np.int64)

    def __iter__(self):
        rng = np.random.RandomState(self.epoch)
        if self.shuffle:
            batches = self._build_global_batches(rng)
            order = rng.permutation(len(batches))
        else:
            batches = self.global_batches
            order = np.arange(self.num_batches)

        start = self.rank * self.batch_size
        end = start + self.batch_size
        for i in order:
            yield batches[i][start:end]

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch


def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick fixed preview samples per resolution and prepare their conditioning.

    Takes ``samples_per_group`` random rows per (width, height) group, or
    ``samples_to_generate`` rows when the dataset has a single resolution. The
    count is always capped at the group size.

    Returns:
        dict mapping (width, height) -> (latents, embeddings, masks, texts).
    """
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, (w, h) in enumerate(zip(widths, heights)):
        size_groups[(w, h)].append(i)

    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = samples_to_generate if len(size_groups) == 1 else samples_per_group
        # random.sample raises ValueError if asked for more items than exist.
        n_samples = min(n_samples, len(indices))
        if n_samples == 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]

        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        # Normalize to a 5D layout with a single frame at dim 2.
        if latents.ndim == 4:
            latents = latents.unsqueeze(2)
        elif latents.ndim == 6:
            latents = latents.squeeze(2)

        texts = [item["text"] for item in samples_data]
        if use_precomputed_embeddings:
            embeddings = torch.tensor(
                np.array([item["embeddings"] for item in samples_data]), device=device, dtype=dtype
            )
            masks = torch.tensor(
                np.array([item["attention_mask"] for item in samples_data]), device=device, dtype=torch.int64
            )
        else:
            embeddings, masks = encode_texts(texts, max_length)
        fixed_samples[size] = (latents, embeddings, masks, texts)

    print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
    return fixed_samples


if limit > 0:
    dataset = load_from_disk(ds_path).select(range(limit))
else:
    dataset = load_from_disk(ds_path)
print(f"images: {len(dataset)}")


def _prepare_caption(t):
    """Apply training caption rules.

    Captions starting with "zero" are emptied. Otherwise a caption is emptied
    with probability ``cfg_dropout`` (CFG dropout). A leading "." marks a
    caption to keep verbatim; all other captions lose boilerplate prefixes.
    """
    if t.lower().startswith("zero"):
        return ""
    if random.random() < cfg_dropout:
        return ""
    if t.startswith("."):
        return t[1:].lstrip()
    return (
        t.replace("The image shows ", "")
        .replace("The image is ", "")
        .replace("This image captures ", "")
        .strip()
    )


def collate_fn_simple(batch):
    """Build (latents, embeddings, attention_mask) on ``device`` for one batch of dataset rows."""
    # NOTE(review): latents are round-tripped through float16 here but not in
    # get_fixed_samples_by_resolution — confirm the stored latents are fp16 or
    # that the precision loss is acceptable.
    latents = torch.from_numpy(
        np.array([item["vae"] for item in batch], dtype=np.float16)
    ).to(device, dtype=dtype)
    if latents.ndim == 4:
        latents = latents.unsqueeze(2)
    elif latents.ndim == 6:
        latents = latents.squeeze(2)

    if use_precomputed_embeddings:
        embeddings = torch.from_numpy(
            np.array([item["embeddings"] for item in batch], dtype=np.float16)
        ).to(device, dtype=dtype)
        attention_mask = torch.from_numpy(
            np.array([item["attention_mask"] for item in batch], dtype=np.int64)
        ).to(device)
        return latents, embeddings, attention_mask

    texts = [_prepare_caption(item["text"]) for item in batch]
    embeddings, attention_mask = encode_texts(texts, max_length)
    attention_mask = attention_mask.to(dtype=torch.int64)
    return latents, embeddings, attention_mask


batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle
)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    print("Total samples", len(dataloader))
# NOTE(review): accelerator.prepare may wrap this batch sampler in its own
# per-process shard, although DistributedResolutionBatchSampler already slices
# per rank. Confirm how many batches each rank actually sees; num_epochs
# defaulting to num_gpus may be compensating for this.
dataloader = accelerator.prepare(dataloader)

start_epoch = 0
global_step = 0
total_training_steps = (len(dataloader) * num_epochs)
world_size = accelerator.state.num_processes

# --------------------------- Transformer ---------------------------
latest_checkpoint = os.path.join(checkpoints_folder, project)
if os.path.isdir(latest_checkpoint):
    print("Загружаем Transformer из чекпоинта:", latest_checkpoint)
    transformer = CosmosTransformer3DModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
    if transformer_gradient:
        transformer.enable_gradient_checkpointing()
else:
    raise FileNotFoundError(f"Transformer checkpoint not found at {latest_checkpoint}")


def create_optimizer(name, params):
    """Build the optimizer called ``name`` over ``params`` at ``base_learning_rate``.

    Raises:
        ImportError: "muon_adam8bit" was requested but muon_adamw8bit is not installed.
        ValueError: unknown optimizer name.
    """
    if name == "adam8bit":
        return bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
        )
    elif name == "adam":
        return torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
        )
    elif name == "adafactor":
        return Adafactor(
            params,
            lr=base_learning_rate,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.001,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False
        )
    elif name == "muon_adam8bit":
        if MuonAdamW8bit is None:
            raise ImportError(
                "muon_adamw8bit is not installed: "
                "pip install git+https://github.com/recoilme/muon_adamw8bit.git"
            )
        return MuonAdamW8bit(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=eps,
            weight_decay=0.01,
            muon_lr_mult=muon_lr_scale,
        )
    else:
        raise ValueError(f"Unknown optimizer: {name}")


def lr_schedule(step):
    """Return the absolute learning rate for scheduler step ``step``.

    Linear warm-up from min to max LR over the first ``warmup_percent`` of
    training, then cosine decay back to the min LR (constant max LR when
    ``use_decay`` is off). The step count is normalized by
    ``total_training_steps * world_size``; presumably this is because the
    Accelerate-wrapped scheduler advances once per process per optimizer step
    — confirm.
    """
    # max(1, ...) guards an empty dataloader (0 steps) against ZeroDivisionError.
    x = step / max(1, total_training_steps * world_size)
    warmup = warmup_percent
    if not use_decay:
        return base_learning_rate
    if x < warmup:
        return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
    # Clamp so a few extra steps past the end do not make the cosine rise again.
    decay_ratio = min(1.0, (x - warmup) / (1 - warmup))
    return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
        (1 + math.cos(math.pi * decay_ratio))


if fbp:
    # Fused backward pass: one optimizer per parameter, stepped from a
    # post-accumulate-grad hook, so gradients are freed as soon as they are used.
    trainable_params = list(transformer.parameters())
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)

    transformer, optimizer = accelerator.prepare(transformer, optimizer_dict)
    # There is no single optimizer to schedule; the training loop logs base_learning_rate.
    lr_scheduler = None
else:
    # 1. First freeze ALL parameters.
    transformer.requires_grad_(False)

    # 2. Name fragments selecting the layers to train (cross-attention).
    trainable_params_names = ["attn2"]
    trainable_params = []

    print("--- РАЗМОРОЖЕННЫЕ СЛОИ ---")
    for name, param in transformer.named_parameters():
        if any(target in name for target in trainable_params_names):
            param.requires_grad_(True)  # unfreeze
            trainable_params.append(param)
            print(f"Обучаемый слой: {name}")
    print("--------------------------")

    # Sanity check: the name filter must match at least one parameter.
    if len(trainable_params) == 0:
        raise ValueError("Ошибка: ни один слой не был разморожен! Проверь ключи.")

    # Only the unfrozen parameters are handed to the optimizer.
    optimizer = create_optimizer(optimizer_type, trainable_params)
    # LambdaLR expects a multiplier of the optimizer's initial lr (= base_learning_rate).
    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)

if torch_compile:
    print("Compiling Transformer... Это займет несколько минут, не прерывайте!")
    transformer = torch.compile(transformer)
    print("Compiling - ok")

if not fbp:
    transformer, optimizer, lr_scheduler = accelerator.prepare(transformer, optimizer, lr_scheduler)

# Fixed samples
fixed_samples = get_fixed_samples_by_resolution(dataset)


def get_negative_embedding(neg_prompt="", batch_size=1):
    """Return (embedding, mask) for the negative prompt, repeated ``batch_size`` times.

    An empty prompt yields all-zero embeddings with an all-ones mask.
    """
    if not neg_prompt:
        # NOTE(review): 2048 is assumed to be the text encoder's hidden size — confirm.
        hidden_dim = 2048
        seq_len = max_length
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask

    uncond_emb, uncond_mask = encode_texts([neg_prompt], max_length)
    uncond_emb = uncond_emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
    return uncond_emb, uncond_mask


if use_precomputed_embeddings:
    # Encode the negative prompt once, keep it on the CPU and free the text encoder.
    load_text_encoder()
    uncond_emb, uncond_mask = get_negative_embedding("low quality")
    uncond_emb = uncond_emb.to("cpu")
    uncond_mask = uncond_mask.to("cpu")
    del text_encoder
    torch.cuda.empty_cache()
    gc.collect()
    text_encoder = None
else:
    uncond_emb, uncond_mask = get_negative_embedding("low quality")


def pad_to_match(a, b, pad_value=0):
    """Right-pad dim 1 (sequence) of two 3D tensors so both have the longer length."""
    Ta, Tb = a.shape[1], b.shape[1]
    if Ta == Tb:
        return a, b
    T = max(Ta, Tb)

    def pad(x, T_target):
        pad_len = T_target - x.shape[1]
        if pad_len <= 0:
            return x
        return torch.nn.functional.pad(x, (0, 0, 0, pad_len), value=pad_value)

    return pad(a, T), pad(b, T)


@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Sample images for every fixed-sample group, save them as JPEGs and log them.

    For each resolution group, runs the Cosmos-style sampling loop (with
    classifier-free guidance when ``guidance_scale != 1``), decodes with the
    VAE and writes ``{generated_folder}/{project}_{W}x{H}_{i}.jpg``. At
    ``step == 0`` the dataset latents themselves are decoded instead of the
    sampled ones. The sampling loop still runs, which exercises the sampling
    path before training. The VAE is moved back to the CPU afterwards.
    """
    uncond_emb, uncond_mask = uncond_data
    uncond_emb = uncond_emb.to(device)
    uncond_mask = uncond_mask.to(device)
    original_model = None
    try:
        # Always strip DDP; a compiled wrapper is kept and forwards .config / forward.
        original_model = accelerator.unwrap_model(transformer, keep_torch_compile=True).eval()
        vae.to(device=device).eval()

        all_generated_images = []
        all_captions = []
        # Common canvas for logged images (at least 255 px per side).
        max_w_overall = max(255, max((s[0] for s in fixed_samples_cpu), default=0))
        max_h_overall = max(255, max((s[1] for s in fixed_samples_cpu), default=0))

        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            curr_batch_size = sample_latents.shape[0]
            in_channels = original_model.config.in_channels
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)

            sigmas_dtype = torch.float32
            sigmas = torch.linspace(0, 1, n_diffusion_steps, dtype=sigmas_dtype)
            scheduler.set_timesteps(sigmas=sigmas, device=device)
            if scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
                scheduler.sigmas[-1] = scheduler.sigmas[-2]
            if scheduler.sigmas[-1] == 0.0:
                scheduler.sigmas[-1] = 1e-4

            sigma_max = getattr(scheduler.config, "sigma_max", 80.0)
            latents = torch.randn(
                (curr_batch_size, in_channels, 1, sample_latents.shape[3], sample_latents.shape[4]),
                device=device,
                dtype=dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            ) * sigma_max
            padding_mask = torch.zeros(
                (1, 1, sample_latents.shape[3], sample_latents.shape[4]), device=device, dtype=dtype
            )

            if guidance_scale != 1:
                neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                neg_emb_batch, sample_text_embeddings = pad_to_match(neg_emb_batch, sample_text_embeddings)

            for i, t in enumerate(scheduler.timesteps):
                current_sigma = scheduler.sigmas[i]
                if current_sigma == 0.0:
                    current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
                # Preconditioning with t = sigma / (sigma + 1).
                current_t = current_sigma / (current_sigma + 1.0)
                c_in = 1.0 - current_t
                c_skip = 1.0 - current_t
                c_out = -current_t

                latent_model_input = (latents * c_in).to(dtype)
                t_val = float(current_t.item()) if torch.is_tensor(current_t) else float(current_t)
                timestep_tensor = torch.tensor([t_val], device=device, dtype=dtype).expand(curr_batch_size)

                noise_pred = original_model(
                    hidden_states=latent_model_input,
                    timestep=timestep_tensor,
                    encoder_hidden_states=sample_text_embeddings,
                    padding_mask=padding_mask,
                    return_dict=False
                )[0]
                noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(dtype)

                if guidance_scale != 1:
                    noise_pred_uncond = original_model(
                        hidden_states=latent_model_input,
                        timestep=timestep_tensor,
                        encoder_hidden_states=neg_emb_batch,
                        padding_mask=padding_mask,
                        return_dict=False
                    )[0]
                    noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(dtype)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)

                noise_pred = (latents - noise_pred) / current_sigma
                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]

            current_latents = latents
            if step == 0:
                current_latents = sample_latents

            if latents_mean is not None and latents_std is not None:
                sigma_data = getattr(scheduler.config, "sigma_data", 1.0)
                # Normalization vectors in float32.
                l_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, torch.float32)
                l_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, torch.float32)
                # Cast latents to float32 before multiplying to preserve precision.
                latents_for_decode = (current_latents.to(torch.float32) * l_std) / sigma_data + l_mean
            else:
                latents_for_decode = current_latents.to(torch.float32)

            # Decode with the math SDP backend forced on for this call only
            # (it is disabled globally at start-up).
            with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
                decoded = vae.decode(latents_for_decode).sample

            # Drop the single-frame video dimension.
            if decoded.ndim == 5:
                decoded = decoded[:, :, 0, :, :]

            # Already float32; iterate directly.
            decoded_fp32 = decoded
            for img_idx, img_tensor in enumerate(decoded_fp32):
                # [-1, 1] -> [0, 1], CHW -> HWC.
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)

                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                img = np.nan_to_num(img, nan=0.0)
                pil_img = Image.fromarray((img * 255).astype("uint8"))

                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)

                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)

                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=95)

        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        vae.to("cpu")
        # Drop references to large GPU tensors before emptying the cache. A name
        # that was never bound (early failure, or no CFG) ends the cleanup early.
        try:
            all_generated_images.clear()
            all_captions.clear()
            del all_generated_images, all_captions
            del latents, current_latents, latent_model_input
            del decoded, decoded_fp32
            del sample_latents, sample_text_embeddings, sample_mask
            del noise_pred, noise_pred_uncond
        except UnboundLocalError:
            pass
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()


if accelerator.is_main_process:
    if save_model:
        print("Генерация сэмплов до старта обучения...")
        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()


def save_checkpoint(model_net, variant=""):
    """Save the unwrapped model to ``checkpoints_folder/project`` (main process only).

    With a non-empty ``variant`` the model is cast to bfloat16 in place before
    saving, so it is intended only for the final save.
    """
    if accelerator.is_main_process:
        # Strip both DDP and torch.compile wrappers so save_pretrained is reachable.
        model_to_save = accelerator.unwrap_model(model_net, keep_torch_compile=False)
        save_dir = os.path.join(checkpoints_folder, f"{project}")
        if variant != "":
            model_to_save.to(dtype=torch.bfloat16).save_pretrained(save_dir, variant=variant)
        else:
            model_to_save.save_pretrained(save_dir)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    gc.collect()


# --------------------------- Training ---------------------------
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")

progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")

steps_per_epoch = len(dataloader)
sink_interval = max(1, steps_per_epoch // sink_interval_share)
min_loss = 4.
last_sample_time = time.time()
sample_interval_seconds = sample_interval_min * 60


def _is_sync_step():
    """True on steps where ranks are explicitly synchronized (every 100 steps or every sink_interval)."""
    return (global_step % 100 == 0) or (global_step % sink_interval == 0)


for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    transformer.train()

    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        if not save_model and epoch == 0 and step == 5:
            used_gb = torch.cuda.max_memory_allocated() / 1024**3
            print(f"Шаг {step}: {used_gb:.2f} GB")

        amp_context = accelerator.autocast() if torch_compile else nullcontext()
        with accelerator.accumulate(transformer):
            with amp_context:
                # Flow matching: x_t = (1 - t) * x0 + t * noise, target = noise - x0,
                # with t drawn from a logit-normal distribution shifted by sigmoid_bias.
                noise = torch.randn_like(latents, dtype=latents.dtype)
                t = torch.sigmoid(torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype) + sigmoid_bias)
                noisy_latents_5d = (1.0 - t.view(-1, 1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1, 1) * noise
                target_5d = noise - latents

                padding_mask = torch.zeros((1, 1, latents.shape[3], latents.shape[4]), device=device, dtype=dtype)
                timestep_tensor = t.flatten().to(dtype)
                model_pred = transformer(
                    hidden_states=noisy_latents_5d,
                    timestep=timestep_tensor,
                    encoder_hidden_states=embeddings,
                    padding_mask=padding_mask,
                    return_dict=False
                )[0]
                mse_loss = F.mse_loss(model_pred.float(), target_5d.float())

            batch_losses.append(mse_loss.detach().item())

            if _is_sync_step():
                accelerator.wait_for_everyone()
            accelerator.backward(mse_loss)
            if _is_sync_step():
                accelerator.wait_for_everyone()

            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
                    grad_val = accelerator.clip_grad_norm_(transformer.parameters(), clip_grad_norm)
                    grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)

        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)

        if accelerator.is_main_process:
            current_lr = base_learning_rate if fbp else lr_scheduler.get_last_lr()[0]
            batch_grads.append(grad)

            log_data = {
                "loss_mse": mse_loss.detach().item(),
                "lr": current_lr,
                "grad": grad,
            }
            if accelerator.sync_gradients:
                if use_wandb:
                    wandb.log(log_data, step=global_step)
                if use_comet_ml:
                    comet_experiment.log_metrics(log_data, step=global_step)

            current_time = time.time()
            is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
            if is_time_to_sample or global_step == 50:
                if save_model:
                    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                elif epoch % 10 == 0:
                    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)

                if save_model:
                    # Save when the worse of (recent average, last) loss is within
                    # save_barrier of the best loss seen at a previous save.
                    has_losses = len(batch_losses) > 0
                    avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
                    last_loss = batch_losses[-1] if has_losses else 0.0
                    max_loss = max(avg_sample_loss, last_loss)
                    should_save = max_loss < min_loss * save_barrier

                    print(
                        f"Saving: {should_save} | Max: {max_loss:.4f} | "
                        f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                    )
                    if should_save:
                        min_loss = max_loss
                        save_checkpoint(transformer)

                last_sample_time = current_time
                transformer.train()

    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
        print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)

progress_bar.close()

if accelerator.is_main_process:
    print("Обучение завершено! Сохраняем финальную модель...")
    save_checkpoint(transformer, "bf16")
    if use_comet_ml:
        comet_experiment.end()
    if use_wandb:
        wandb.finish()

accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Готово!")