Instructions to use babkasotona/1b with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use babkasotona/1b with Diffusers:
pip install -U diffusers transformers accelerate
```python
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("babkasotona/1b", dtype=torch.bfloat16, device_map="cuda")

prompt = "sdxs-1b"
image = pipe(prompt).images[0]
```
- Notebooks
- Google Colab
- Kaggle
- Local Apps Settings
- Draw Things
- DiffusionBee
| import os | |
| import math | |
| import torch | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
| import wandb,comet_ml | |
| import random,time | |
| import gc | |
| import bitsandbytes as bnb | |
| import torch.nn.functional as F | |
| import argparse | |
| from datetime import datetime | |
| from diffusers import UNet2DConditionModel, AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler | |
| from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration | |
| from torch.utils.data import DataLoader, Sampler | |
| from torch.optim.lr_scheduler import LambdaLR | |
| from collections import defaultdict | |
| from accelerate import Accelerator | |
| from datasets import load_from_disk | |
| from tqdm import tqdm | |
| from PIL import Image, ImageOps | |
| from torch.utils.checkpoint import checkpoint | |
| from diffusers.models.attention_processor import AttnProcessor2_0 | |
| from contextlib import nullcontext | |
| from transformers.optimization import Adafactor | |
| from torch.nn.attention import sdpa_kernel, SDPBackend | |
| # Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git | |
| from muon_adamw8bit import MuonAdamW8bit | |
# Distributed / allocator environment. These are presumably read when NCCL and the CUDA
# allocator initialise, so they are set before any CUDA call below — confirm if moved.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"  # comment this on H100!
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# --------------------------- Parameters ---------------------------
#ds_path = "datasets/ds1234_noanime_704_vae8x16x" #alchemist_704_vae8x16x_imgpool"
ds_path = "/root/sdxs-2b/datasets/ds12345_640_vae_qwen"
project = "unet"
# Per-GPU batch scales with the memory of GPU 0: 7 samples per 32 GB.
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
local_bs = max(1, int((gpu_mem_gb / 32) * 7))
num_gpus = torch.cuda.device_count()
batch_size = local_bs * num_gpus
base_learning_rate = 4e-5
min_learning_rate = 4e-6
# Learning-rate divisor per training stage:
# 0.5 - pretrain (base forms)
# 1 - base train (composition)
# 3 - finetuning (anatomy)
# 5 - small details (faces)
learning_rate_scale = 3
base_learning_rate = base_learning_rate / learning_rate_scale
min_learning_rate = min_learning_rate / learning_rate_scale
print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
num_epochs = num_gpus
# The checkpoint decision averages the last steps_per_epoch // sink_interval_share losses.
sink_interval_share = 20
# Minutes between preview / checkpoint rounds.
sample_interval_min = 60
# Caption-dropout probability for CFG; a negative value disables it (random.random() is never < 0).
cfg_dropout = -0.10
# Timestep sampling t = sigmoid(N(0, 1) + bias): bias = -0.5 focuses on details (~300),
# bias = 0.5 on structure, bias = 0 gives a symmetric bell.
sigmoid_bias = -0.1
max_length = 250
max_snr_gamma = 5.0
use_precomputed_embeddings = False
use_wandb = False
use_comet_ml = True
save_model = True
use_decay = True
fbp = False  # one optimizer per parameter, stepped from post-accumulate-grad hooks
torch_compile = False
unet_gradient = True  # enables UNet gradient checkpointing
loss_normalize = False
fixed_seed = False
shuffle = True
optimizer_type = "adafactor"  #"adam8bit"
# Passed to MuonAdamW8bit as muon_lr_mult; defined unconditionally so it always exists.
muon_lr_scale = 500
if optimizer_type == "muon_adam8bit":
    # Smaller per-GPU batch with Muon: 3 samples per 32 GB instead of 7.
    batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
# Comet credentials: the COMET_API_KEY environment variable takes precedence; the literal key is
# only a fallback for unattended runs. NOTE(review): this key is public in the source — rotate it.
comet_ml_api_key = os.environ.get("COMET_API_KEY", "Agctp26mbqnoYrrlvQuKSTk6r")
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# MAX_JOBS=4 pip install flash-attn --no-build-isolation
#torch.backends.cuda.enable_flash_sdp(True)
#torch.backends.cuda.enable_mem_efficient_sdp(True)
#torch.backends.cuda.enable_math_sdp(False)  # disable the slow kernel
# A checkpoint is written only while the recent loss stays below min_loss * save_barrier.
save_barrier = 1.25
warmup_percent = 0.0025
betta2 = 0.997
eps = 1e-6
clip_grad_norm = 1.0
limit = 0  # > 0: use only the first `limit` dataset rows
checkpoints_folder = ""
gradient_accumulation_steps = 1
dtype = torch.float32
mixed_precision = "bf16"
# Preview-sampling parameters
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4
# Output folder for preview images.
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)
# Seed derived from today's date (YYYYMMDD + 42). Preview noise always uses it; the global
# RNGs are seeded with it only when fixed_seed is enabled.
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d")) + 42
if fixed_seed:
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
device = accelerator.device
print("init")

# Command-line overrides; every default is the value computed above.
parser = argparse.ArgumentParser(description='Train a model on a dataset.')
parser.add_argument('--ds-path', type=str, default=ds_path, help='Path to the dataset')
parser.add_argument('--ep', type=int, default=num_epochs, help='Number of epochs to train the model')
parser.add_argument('--batch', type=int, default=batch_size, help='Total batch size')
parser.add_argument('--min-lr', type=float, default=min_learning_rate, help='Minimum learning rate')
parser.add_argument('--max-lr', type=float, default=base_learning_rate, help='Maximum learning rate')
parser.add_argument('--dry-run', action='store_true', default=False, help='Dry run train without saving/sampling')
parser.add_argument('--lvl', type=float, default=0.0, help='Train level, from 0.5 to 5')
args = parser.parse_args()

batch_size, ds_path = args.batch, args.ds_path
base_learning_rate, min_learning_rate = args.max_lr, args.min_lr
num_epochs, lvl = args.ep, args.lvl
# A dry run switches save_model off.
save_model = save_model and not args.dry_run
# NOTE(review): --lvl divides defaults that were already divided by learning_rate_scale, so
# e.g. --lvl 3 gives 4e-5 / 9 instead of 4e-5 / 3 — confirm this double scaling is intended.
if lvl >= 0.1:
    base_learning_rate /= lvl
    min_learning_rate /= lvl
print(f"max-lr:{base_learning_rate} min-lr:{min_learning_rate}")
# --------------------------- Experiment tracking (main process only) ---------------------------
if accelerator.is_main_process and use_wandb:
    wandb.init(project=project, config={
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
        "optimizer_type": optimizer_type,
    })
if accelerator.is_main_process and use_comet_ml:
    from comet_ml import Experiment
    comet_experiment = Experiment(
        api_key=comet_ml_api_key,
        project_name=project,
        workspace=comet_ml_workspace,
    )
    hyper_params = dict(
        batch_size=batch_size,
        base_learning_rate=base_learning_rate,
        num_epochs=num_epochs,
    )
    comet_experiment.log_parameters(hyper_params)
# --------------------------- Model loading ---------------------------
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")

# Text-side components are module globals, filled lazily by load_text_encoder().
tokenizer = None
text_encoder = None


def load_text_encoder():
    """Load the global ``tokenizer`` and ``text_encoder`` if they are not loaded yet.

    Safe to call repeatedly: a component that is already present is left untouched.
    The encoder is loaded in float16, moved to ``device`` and switched to eval mode.
    """
    global tokenizer, text_encoder
    need_tokenizer = tokenizer is None
    need_encoder = text_encoder is None
    if need_tokenizer:
        tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
    if need_encoder:
        encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
            "text_encoder",
            torch_dtype=torch.float16,
        )
        text_encoder = encoder.to(device).eval()


load_text_encoder()
def encode_texts(text, max_length=max_length):
    """Encode prompt(s) with the chat-templated Qwen text encoder.

    Args:
        text: a prompt, a list of prompts, or None (treated as "").
        max_length: every prompt is padded / truncated to exactly this many tokens.

    Returns:
        ``(hidden, mask)``: the hidden states of the second-to-last encoder layer cast to
        ``dtype`` (presumably ``[batch, max_length, hidden_size]``), and the attention mask
        as a bool tensor.
    """
    if text is None:
        text = ""
    if isinstance(text, str):
        text = [text]
    formatted_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": [{"type": "text", "text": t}]}],
            tokenize=False,
            add_generation_prompt=False,
        )
        for t in text
    ]
    # Fixed-length padding keeps every batch the same sequence length.
    # (Dynamic padding via padding=True was tried as an alternative.)
    toks = tokenizer(
        formatted_prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    # Only UNet parameters are optimised. Without no_grad, every call built an autograd graph
    # through the encoder, and the UNet loss backpropagated into it, wasting memory and compute.
    with torch.no_grad():
        outputs = text_encoder(
            input_ids=toks.input_ids,
            attention_mask=toks.attention_mask,
            output_hidden_states=True,
        )
    hidden = outputs.hidden_states[-2]
    return hidden.to(dtype=dtype), toks.attention_mask.to(dtype=torch.bool)
def encode_texts_fast(text, max_length=max_length):
    """Encode prompt(s) exactly like ``encode_texts``; kept under its old name for callers.

    The previous body was a line-for-line copy of ``encode_texts``: same chat template,
    ``max_length`` padding, second-to-last hidden layer, and bool mask. Delegating keeps the
    two from drifting apart. ``no_grad`` guarantees the result is never attached to an
    autograd graph through the text encoder.
    """
    with torch.no_grad():
        return encode_texts(text, max_length)
# VAE latent normalisation constants from the VAE config; missing / None values fall back
# to neutral defaults.
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0
# Per-channel statistics used to de-normalise latents before decoding. Always defined, so
# that `latents_mean is not None` checks elsewhere never hit a NameError.
latents_mean = None
latents_std = None
mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if mean is not None and std is not None:
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
| import numpy as np | |
| from torch.utils.data import Sampler | |
class DistributedResolutionBatchSampler(Sampler):
    """Batch sampler that keeps every batch at a single resolution and shards it across ranks.

    Samples are grouped by their ``(width, height)``. Each group is cut into *global*
    batches of ``per_rank_batch * num_replicas`` indices, and every rank receives its own
    contiguous slice of each global batch. Samples that do not fill a whole global batch
    are dropped.

    Args:
        dataset: indexable dataset; its ``"width"`` / ``"height"`` columns are used when
            present, otherwise every sample belongs to one group.
        batch_size: global batch size; each rank gets ``max(1, batch_size // num_replicas)``.
        num_replicas: number of processes.
        rank: index of this process.
        drop_last: kept for API compatibility; incomplete batches are always dropped.
        shuffle: when True, the samples inside each group AND the order of the batches are
            reshuffled every epoch. Shuffling is seeded by the epoch number, so all ranks
            agree on the same layout.
    """

    def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0
        # Batch per GPU and the resulting global batch.
        self.batch_size = max(1, batch_size // num_replicas)
        self.global_batch = self.batch_size * num_replicas
        try:
            widths = np.asarray(dataset["width"])
            heights = np.asarray(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))
        groups = {}
        for i, (w, h) in enumerate(zip(widths, heights)):
            groups.setdefault((w, h), []).append(i)
        self._groups = [np.asarray(indices, dtype=np.int64) for indices in groups.values()]
        # Unshuffled layout, also used as-is when shuffle=False.
        self.global_batches = self._build_global_batches()
        self.num_batches = len(self.global_batches)

    def _build_global_batches(self, rng=None):
        """Cut each group into full global batches, permuting the group first when ``rng`` is given."""
        chunks = []
        for idx in self._groups:
            n_full = len(idx) // self.global_batch
            if n_full == 0:
                continue
            if rng is not None:
                idx = rng.permutation(idx)
            chunks.append(idx[: n_full * self.global_batch].reshape(n_full, self.global_batch))
        if not chunks:
            return np.empty((0, self.global_batch), dtype=np.int64)
        return np.concatenate(chunks, axis=0)

    def __iter__(self):
        rng = np.random.RandomState(self.epoch)
        if self.shuffle:
            # Regroup every epoch so batch composition changes, not just batch order.
            batches = self._build_global_batches(rng)
            order = rng.permutation(len(batches))
        else:
            batches = self.global_batches
            order = np.arange(len(batches))
        start = self.rank * self.batch_size
        end = start + self.batch_size
        for i in order:
            yield batches[i, start:end].tolist()

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick fixed validation samples per ``(width, height)`` group for preview generation.

    Args:
        dataset: rows with ``"vae"`` and ``"text"`` (plus ``"embeddings"`` /
            ``"attention_mask"`` when ``use_precomputed_embeddings``). ``"width"`` /
            ``"height"`` are optional; without them everything forms one group.
        samples_per_group: samples drawn from each group when there are several groups.
            With a single group, ``samples_to_generate`` samples are drawn instead.

    Returns:
        dict mapping ``(width, height)`` -> ``(latents, embeddings, masks, texts)`` with the
        tensors on ``device``. The sample count is capped at the group size.
    """
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, size in enumerate(zip(widths, heights)):
        size_groups[size].append(i)

    single_group = len(size_groups) == 1
    fixed_samples = {}
    for size, indices in size_groups.items():
        wanted = samples_to_generate if single_group else samples_per_group
        # random.sample raises ValueError when asked for more items than the group holds.
        n_samples = min(wanted, len(indices))
        if n_samples <= 0:
            continue
        samples_data = [dataset[idx] for idx in random.sample(indices, n_samples)]
        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        if latents.ndim == 5:
            # Drop axis 2 (the single frame axis of the Qwen VAE latents).
            latents = latents.squeeze(2)
        texts = [item["text"] for item in samples_data]
        if use_precomputed_embeddings:
            embeddings = torch.tensor(
                np.array([item["embeddings"] for item in samples_data]),
                device=device,
                dtype=dtype,
            )
            masks = torch.tensor(
                np.array([item["attention_mask"] for item in samples_data]),
                device=device,
                dtype=torch.int64,
            )
        else:
            # These tensors live for the whole run, so they must not hold an autograd graph.
            with torch.no_grad():
                embeddings, masks = encode_texts(texts, max_length)
        fixed_samples[size] = (latents, embeddings, masks, texts)
    print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
    return fixed_samples
# Load the pre-encoded dataset; `limit` > 0 keeps only its first rows (quick experiments).
dataset = load_from_disk(ds_path)
if limit > 0:
    dataset = dataset.select(range(limit))
print(f"images: {len(dataset)}")
def _clean_caption(text):
    """Return the caption fed to the text encoder for one raw dataset caption."""
    text = text or ""  # a missing caption is treated as empty instead of crashing
    # Captions starting with "zero" are always trained as empty prompts.
    if text.lower().startswith("zero"):
        return ""
    # Caption dropout for classifier-free guidance (inactive while cfg_dropout <= 0).
    if random.random() < cfg_dropout:
        return ""
    # A leading "." is stripped and the boilerplate removal below is skipped.
    if text.startswith("."):
        return text[1:].lstrip()
    return (
        text.replace("The image shows ", "")
        .replace("The image is ", "")
        .replace("This image captures ", "")
        .strip()
    )


def collate_fn_simple(batch):
    """Turn a list of dataset rows into ``(latents, embeddings, attention_mask)`` on ``device``.

    Latents come from the ``"vae"`` column; axis 2 is squeezed away for 5-D latents.
    Text conditioning is either read from the precomputed columns or produced by
    encoding the cleaned captions.
    """
    latents = torch.from_numpy(
        np.array([item["vae"] for item in batch], dtype=np.float16)
    ).to(device, dtype=dtype)
    if latents.ndim == 5:
        latents = latents.squeeze(2)
    if use_precomputed_embeddings:
        embeddings = torch.from_numpy(
            np.array([item["embeddings"] for item in batch], dtype=np.float16)
        ).to(device, dtype=dtype)
        attention_mask = torch.from_numpy(
            np.array([item["attention_mask"] for item in batch], dtype=np.int64)
        ).to(device)
        return latents, embeddings, attention_mask
    texts = [_clean_caption(item["text"]) for item in batch]
    # The text encoder is not trained: keep its forward out of the UNet's autograd graph.
    with torch.no_grad():
        embeddings, attention_mask = encode_texts(texts, max_length)
    return latents, embeddings, attention_mask.to(dtype=torch.bool)
# Resolution-bucketed, rank-sharded batches; the sampler already splits the global batch.
batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle,
)
dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    # len(dataloader) is the number of batches, despite the label.
    print("Total samples", len(dataloader))
dataloader = accelerator.prepare(dataloader)

# Training bookkeeping (no resume of epoch / step counters).
start_epoch, global_step = 0, 0
total_training_steps = len(dataloader) * num_epochs
world_size = accelerator.state.num_processes
# --------------------------- UNet ---------------------------
# Training always starts from <checkpoints_folder>/<project>; there is no from-scratch path.
latest_checkpoint = os.path.join(checkpoints_folder, project)
if not os.path.isdir(latest_checkpoint):
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")

print("Загружаем UNet из чекпоинта:", latest_checkpoint)
unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
if unet_gradient:
    unet.enable_gradient_checkpointing()
# Debug aid: list the SDPA backends exposed by this torch build.
print(dir(SDPBackend))
try:
    unet.set_attn_processor(AttnProcessor2_0())
except Exception as e:
    # Fall back to xformers attention when the SDPA processor cannot be installed.
    print(f"Ошибка при включении SDPA: {e}")
    unet.set_use_memory_efficient_attention_xformers(True)
def create_optimizer(name, params):
    """Build the optimizer named ``name`` over ``params`` from the module hyperparameters.

    Supported names: ``"adam8bit"``, ``"adam"``, ``"adafactor"``, ``"muon_adam8bit"``.

    Raises:
        ValueError: for any other name.
    """
    # Factories are lambdas so that only the selected optimizer's globals are touched.
    factories = {
        "adam8bit": lambda: bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
        ),
        "adam": lambda: torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
        ),
        "adafactor": lambda: Adafactor(
            params,
            lr=base_learning_rate,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.001,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        ),
        "muon_adam8bit": lambda: MuonAdamW8bit(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=eps,
            weight_decay=0.01,
            muon_lr_mult=muon_lr_scale,
        ),
    }
    factory = factories.get(name)
    if factory is None:
        raise ValueError(f"Unknown optimizer: {name}")
    return factory()
if fbp:
    # One optimizer per parameter, stepped (and its grad cleared) as soon as that
    # parameter's gradient has been accumulated during backward.
    trainable_params = list(unet.parameters())
    optimizer_dict = {}

    def optimizer_hook(param):
        param_opt = optimizer_dict[param]
        param_opt.step()
        param_opt.zero_grad(set_to_none=True)

    for param in trainable_params:
        optimizer_dict[param] = create_optimizer(optimizer_type, [param])
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)
else:
    unet.requires_grad_(True)
    optimizer = create_optimizer(optimizer_type, unet.parameters())
# Alternative kept for reference: freeze the UNet and train only conv_in / conv_out
# (weights and biases) via unet.requires_grad_(False) plus selective requires_grad = True.
def lr_schedule(step):
    """Absolute learning rate at scheduler step ``step``: linear warm-up, then cosine decay.

    Progress is ``step / (total_training_steps * world_size)``, clamped to [0, 1], so the
    rate settles at ``min_learning_rate`` instead of climbing back after the schedule ends.
    NOTE(review): the ``world_size`` factor presumably compensates for the accelerate-wrapped
    scheduler advancing once per process per optimizer step — confirm.
    Returns ``base_learning_rate`` unchanged when ``use_decay`` is off.
    """
    if not use_decay:
        return base_learning_rate
    total = max(1, total_training_steps * world_size)  # avoid 0/0 for an empty dataloader
    x = min(1.0, step / total)
    warmup = warmup_percent
    if x < warmup:
        return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
    decay_ratio = (x - warmup) / max(1e-12, 1 - warmup)
    return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
        (1 + math.cos(math.pi * decay_ratio))


# LambdaLR expects a multiplier of the optimizer's initial lr (base_learning_rate).
# With fbp the "optimizer" is a dict of per-parameter optimizers that LambdaLR cannot wrap,
# and the training loop never touches lr_scheduler in that mode.
lr_scheduler = None if fbp else LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
if torch_compile:
    print("Compiling UNet... Это займет несколько минут, не прерывайте!")
    unet = torch.compile(unet)
    print("Compiling - ok")

# In fbp mode the model and per-parameter optimizers were already prepared above.
if not fbp:
    prepared = accelerator.prepare(unet, optimizer, lr_scheduler)
    unet, optimizer, lr_scheduler = prepared

# Fixed validation samples, reused for every preview round.
fixed_samples = get_fixed_samples_by_resolution(dataset)
# --- Negative (unconditional) embedding for classifier-free guidance; returns (embedding, mask) ---
def get_negative_embedding(neg_prompt="", batch_size=1):
    """Return ``(embedding, mask)`` for the negative prompt, repeated ``batch_size`` times.

    An empty prompt yields an all-zero embedding of width 1024 with an all-ones int64 mask.
    Otherwise the prompt is encoded once and repeated. The result is detached, so
    long-lived copies never keep an autograd graph (and encoder tensors) alive.
    """
    if not neg_prompt:
        hidden_dim = 1024  # NOTE(review): hard-coded; must match the text encoder width — confirm
        seq_len = max_length
        empty_emb = torch.zeros((batch_size, seq_len, hidden_dim), dtype=dtype, device=device)
        empty_mask = torch.ones((batch_size, seq_len), dtype=torch.int64, device=device)
        return empty_emb, empty_mask
    with torch.no_grad():
        uncond_emb, uncond_mask = encode_texts([neg_prompt], max_length)
    uncond_emb = uncond_emb.detach().to(dtype=dtype, device=device).repeat(batch_size, 1, 1)
    uncond_mask = uncond_mask.to(device=device).repeat(batch_size, 1)
    return uncond_emb, uncond_mask
# Negative ("low quality") condition used for classifier-free guidance in previews.
if use_precomputed_embeddings:
    # The encoder is only needed for this one prompt: ensure it is loaded, encode,
    # park the result on the CPU, then release the encoder's GPU memory.
    load_text_encoder()
    uncond_emb, uncond_mask = get_negative_embedding("low quality")
    # detach() so that no autograd graph keeps encoder tensors alive after release.
    uncond_emb = uncond_emb.detach().to("cpu")
    uncond_mask = uncond_mask.detach().to("cpu")
    text_encoder = None
    # Collect first: tensors held only by reference cycles are freed before the cache is trimmed.
    gc.collect()
    torch.cuda.empty_cache()
else:
    uncond_emb, uncond_mask = get_negative_embedding("low quality")
    uncond_emb = uncond_emb.detach()
def pad_to_match(a, b, pad_value=0):
    """Right-pad the sequence axis (dim 1) of two ``[B, T, D]`` tensors to a common length.

    The shorter tensor is padded with ``pad_value``. A tensor that already has the target
    length is returned as the same object.
    """
    len_a, len_b = a.shape[1], b.shape[1]
    if len_a == len_b:
        return a, b
    target = max(len_a, len_b)

    def _extend(tensor):
        missing = target - tensor.shape[1]
        if missing > 0:
            return torch.nn.functional.pad(tensor, (0, 0, 0, missing), value=pad_value)
        return tensor

    return _extend(a), _extend(b)
def pad_mask(a, b):
    """Right-pad two ``[B, T]`` masks with zeros (masked out) to the longer length.

    A mask that is already long enough is returned as the same object.
    """
    target = max(a.shape[1], b.shape[1])
    padded = []
    for mask in (a, b):
        extra = target - mask.shape[1]
        padded.append(mask if extra <= 0 else torch.nn.functional.pad(mask, (0, extra), value=0))
    return padded[0], padded[1]
# --- Preview sample generation ---
def _cfg_condition(text_emb, text_mask, uncond_emb, uncond_mask):
    """Stack the [negative; positive] conditioning, padded to a common length, for CFG."""
    batch = text_emb.shape[0]
    neg_emb = uncond_emb[0:1].expand(batch, -1, -1)
    neg_emb, text_emb = pad_to_match(neg_emb, text_emb)
    neg_mask = uncond_mask[0:1].expand(batch, -1)
    neg_mask, text_mask = pad_mask(neg_mask, text_mask)
    return torch.cat([neg_emb, text_emb], dim=0), torch.cat([neg_mask, text_mask], dim=0)


def _render_group(model, sample_latents, text_emb, text_mask, uncond_emb, uncond_mask, decode_reference):
    """Denoise one resolution group and return its decoded images as HWC float arrays in [0, 1]."""
    use_cfg = guidance_scale != 1
    if use_cfg:
        # The conditioning does not change across timesteps, so build it once.
        cond_emb, cond_mask = _cfg_condition(text_emb, text_mask, uncond_emb, uncond_mask)
    else:
        cond_emb, cond_mask = text_emb, text_mask
    latents = torch.randn(
        sample_latents.shape,
        device=device,
        dtype=sample_latents.dtype,
        generator=torch.Generator(device=device).manual_seed(seed),
    )
    scheduler.set_timesteps(n_diffusion_steps, device=device)
    for t in scheduler.timesteps:
        model_input = torch.cat([latents, latents], dim=0) if use_cfg else latents
        model_out = model(
            model_input,
            t,
            encoder_hidden_states=cond_emb,
            encoder_attention_mask=cond_mask,
        )
        flow = getattr(model_out, "sample", model_out)
        if use_cfg:
            flow_uncond, flow_cond = flow.chunk(2)
            flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
        latents = scheduler.step(flow, t, latents).prev_sample
    # decode_reference: decode the dataset latents (VAE reconstruction) instead of the generated ones.
    to_decode = sample_latents if decode_reference else latents
    # The Qwen VAE expects 5-D input: add a single frame axis.
    vae_input = to_decode.unsqueeze(2).to(torch.float32)
    if latents_mean is not None and latents_std is not None:
        vae_input = vae_input * latents_std.unsqueeze(2) + latents_mean.unsqueeze(2)
    decoded = vae.decode(vae_input).sample
    if decoded.ndim == 5:
        decoded = decoded.squeeze(2)  # drop the frame axis again
    decoded = decoded.to(torch.float32)
    return [(img / 2 + 0.5).clamp(0, 1).cpu().numpy().transpose(1, 2, 0) for img in decoded]


@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Generate preview images for the fixed samples, save them as JPEGs and log them.

    Each resolution group runs ``n_diffusion_steps`` flow-matching scheduler steps, with
    classifier-free guidance when ``guidance_scale != 1``. At ``step == 0`` the dataset
    latents are decoded instead of the generated ones, as a VAE-reconstruction reference.
    Runs under ``no_grad``: the 40-step loop and the VAE decode must not build autograd
    graphs.

    Args:
        fixed_samples_cpu: ``(width, height) -> (latents, embeddings, mask, texts)`` as built
            by ``get_fixed_samples_by_resolution``.
        uncond_data: ``(uncond_emb, uncond_mask)`` used as the negative condition.
        step: global training step, used for logging.
    """
    uncond_emb, uncond_mask = uncond_data
    uncond_emb = uncond_emb.to(device)
    uncond_mask = uncond_mask.to(device)
    # Shared canvas (at least 255x255) so that every logged preview has the same size.
    canvas_w = max(255, max((size[0] for size in fixed_samples_cpu), default=0))
    canvas_h = max(255, max((size[1] for size in fixed_samples_cpu), default=0))
    all_generated_images = []
    all_captions = []
    try:
        if not torch_compile:
            model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            model = unet.eval()
        vae.to(device=device).eval()
        for (width, height), (sample_latents, sample_emb, sample_mask, sample_text) in fixed_samples_cpu.items():
            images = _render_group(
                model,
                sample_latents.to(dtype=dtype, device=device),
                sample_emb.to(dtype=dtype, device=device),
                sample_mask.to(device=device),
                uncond_emb,
                uncond_mask,
                decode_reference=(step == 0),
            )
            for img_idx, img in enumerate(images):
                if np.isnan(img).any():
                    # Skip the broken image instead of writing NaN garbage to disk / loggers.
                    print("NaNs found, saving stopped! Step:", step)
                    continue
                pil_img = Image.fromarray((img * 255).astype("uint8"))
                all_generated_images.append(ImageOps.pad(pil_img, (canvas_w, canvas_h), color='white'))
                all_captions.append(sample_text[img_idx][:300] if img_idx < len(sample_text) else "")
                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=95)
        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]},
                )
    finally:
        vae.to("cpu")
        # GPU work tensors lived inside _render_group; drop the last references held here.
        uncond_emb = uncond_mask = None
        all_generated_images.clear()
        all_captions.clear()
        torch.cuda.synchronize()
        # Collect first so cyclically-referenced tensors are freed before trimming the cache.
        gc.collect()
        torch.cuda.empty_cache()
# --------------------------- Preview samples before training ---------------------------
should_preview = accelerator.is_main_process and save_model
if should_preview:
    print("Генерация сэмплов до старта обучения...")
    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
# Every rank waits here, so the others do not start training while the main rank samples.
accelerator.wait_for_everyone()
def save_checkpoint(unet, variant=""):
    """Save the UNet to ``os.path.join(checkpoints_folder, project)`` on the main process.

    Args:
        unet: the training model (possibly accelerate-wrapped or compiled).
        variant: ``""`` saves the weights in their current dtype. Any other value casts the
            model to float16 IN PLACE and saves it under that variant name. Because the
            cast also changes the live training model, use a variant only for the final save.
    """
    if not accelerator.is_main_process:
        return
    # A compiled model is saved as-is; otherwise strip the accelerate wrapper first.
    model_to_save = unet if torch_compile else accelerator.unwrap_model(unet)
    save_dir = os.path.join(checkpoints_folder, f"{project}")
    if variant != "":
        model_to_save.to(dtype=torch.float16).save_pretrained(save_dir, variant=variant)
    else:
        model_to_save.save_pretrained(save_dir)
    torch.cuda.synchronize()
    # Collect first so tensors held only by reference cycles are released before the
    # allocator cache is trimmed.
    gc.collect()
    torch.cuda.empty_cache()
# --------------------------- Training loop ---------------------------
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
steps_per_epoch = len(dataloader)
# Window (in steps) averaged when deciding whether to save a checkpoint.
sink_interval = max(1, steps_per_epoch // sink_interval_share)
min_loss = 4.
last_sample_time = time.time()
sample_interval_seconds = sample_interval_min * 60  # minutes -> seconds
for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()
    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        if not save_model and epoch == 0 and step == 5:
            # Dry run: report peak memory after a few warm-up steps.
            used_gb = torch.cuda.max_memory_allocated() / 1024**3
            print(f"Шаг {step}: {used_gb:.2f} GB")
        # Evaluated identically on every rank, so the extra barriers below stay matched.
        barrier_step = (global_step % 100 == 0) or (global_step % sink_interval == 0)
        amp_context = accelerator.autocast() if torch_compile else nullcontext()
        with accelerator.accumulate(unet):
            with amp_context:
                noise = torch.randn_like(latents, dtype=latents.dtype)
                # Logit-normal time sampling, shifted by sigmoid_bias.
                t = torch.sigmoid(torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype) + sigmoid_bias)
                t_b = t.view(-1, 1, 1, 1)
                # Linear interpolation between the clean latents (t=0) and noise (t=1).
                noisy_latents = (1.0 - t_b) * latents + t_b * noise
                # Continuous (float) timesteps in [0, 999] for the UNet.
                timesteps = t.to(torch.float32).mul(999.0)
                with sdpa_kernel([
                    SDPBackend.FLASH_ATTENTION,
                    SDPBackend.EFFICIENT_ATTENTION,
                    SDPBackend.CUDNN_ATTENTION,
                    SDPBackend.MATH
                ]):
                    model_pred = unet(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states=embeddings,
                        encoder_attention_mask=attention_mask,
                    ).sample
                # Velocity target of the interpolation above.
                target = noise - latents
                if max_snr_gamma > 0:
                    # Per-sample MSE, re-weighted by min(SNR, gamma) / (SNR + 1)
                    # with SNR = ((1 - t) / t)^2.
                    raw_loss = F.mse_loss(model_pred.float(), target.float(), reduction='none')
                    loss_per_sample = raw_loss.mean(dim=[1, 2, 3])
                    snr = ((1.0 - t) / (t + 1e-5)) ** 2
                    min_snr_weights = torch.clamp(snr, max=max_snr_gamma) / (snr + 1.0)
                    mse_loss = (loss_per_sample * min_snr_weights.view(-1)).mean()
                else:
                    mse_loss = F.mse_loss(model_pred.float(), target.float())
                batch_losses.append(mse_loss.detach().item())
                if barrier_step:
                    accelerator.wait_for_everyone()
            if barrier_step:
                accelerator.wait_for_everyone()
            accelerator.backward(mse_loss)
            if barrier_step:
                accelerator.wait_for_everyone()
            grad = 0.0
            if not fbp:
                # In fbp mode the per-parameter hooks already stepped during backward.
                if accelerator.sync_gradients:
                    grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                    grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)
        if accelerator.is_main_process:
            current_lr = base_learning_rate if fbp else lr_scheduler.get_last_lr()[0]
            batch_grads.append(grad)
            log_data = {
                "loss_mse": batch_losses[-1],
                "lr": current_lr,
                "grad": grad,
            }
            if accelerator.sync_gradients:
                if use_wandb:
                    wandb.log(log_data, step=global_step)
                if use_comet_ml:
                    comet_experiment.log_metrics(log_data, step=global_step)
            current_time = time.time()
            is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
            # Only on real optimizer steps, so accumulation micro-steps cannot repeat a round.
            if accelerator.sync_gradients and (is_time_to_sample or global_step == 50):
                if save_model or epoch % 10 == 0:
                    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                if save_model:
                    has_losses = len(batch_losses) > 0
                    avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
                    last_loss = batch_losses[-1] if has_losses else 0.0
                    max_loss = max(avg_sample_loss, last_loss)
                    should_save = max_loss < min_loss * save_barrier
                    print(
                        f"Saving: {should_save} | Max: {max_loss:.4f} | "
                        f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                    )
                    if should_save:
                        min_loss = max_loss
                        save_checkpoint(unet)
                # Restart the interval timer in every mode. Previously it was reset only when
                # save_model was set, so dry runs re-entered this branch on every step.
                last_sample_time = current_time
                unet.train()
    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if batch_losses else 0.0
        avg_epoch_grad = np.mean(batch_grads) if batch_grads else 0.0
        print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)
progress_bar.close()
if accelerator.is_main_process:
    print("Обучение завершено! Сохраняем финальную модель...")
    # NOTE(review): the final fp16 save runs even in --dry-run (the save_model guard was
    # deliberately commented out upstream) — confirm this is intended.
    save_checkpoint(unet, "fp16")
    if use_comet_ml:
        comet_experiment.end()
    if use_wandb:
        wandb.finish()
accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Готово!")