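# (Hugging Face Hub page header captured with this file; kept as comments so the script parses.)
# Text-to-Image
# Diffusers
# Safetensors
# 1b / train.py
# babkasotona's picture
# Upload folder using huggingface_hub
# 7e72b1d verified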
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import wandb,comet_ml
import random,time
import gc
import bitsandbytes as bnb
import torch.nn.functional as F
import argparse
from datetime import datetime
from diffusers import UNet2DConditionModel, AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler
from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
from torch.utils.data import DataLoader, Sampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from contextlib import nullcontext
from transformers.optimization import Adafactor
from torch.nn.attention import sdpa_kernel, SDPBackend
# Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git
from muon_adamw8bit import MuonAdamW8bit
# Process-wide NCCL / CUDA allocator settings, exported before any
# distributed or CUDA work below.
os.environ.update(
    {
        "NCCL_P2P_DISABLE": "1",
        "NCCL_IB_DISABLE": "1",  # remove this entry on H100!
        "PYTORCH_CUDA_ALLOC_CONF": "expandable_segments:True",
    }
)
# --------------------------- Parameters ---------------------------
#ds_path = "datasets/ds1234_noanime_704_vae8x16x" #alchemist_704_vae8x16x_imgpool"
ds_path = "/root/sdxs-2b/datasets/ds12345_640_vae_qwen"
project = "unet"
# Per-GPU batch: 7 samples per 32 GB of device memory (at least 1).
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
local_bs = max(1, int((gpu_mem_gb / 32) * 7))
num_gpus = torch.cuda.device_count()
batch_size = local_bs * num_gpus
base_learning_rate = 4e-5
min_learning_rate = 4e-6
# Suggested learning_rate_scale per training stage (both LRs are divided by it):
# 0.5 - pretrain (base forms)
# 1 - base train (composition)
# 3 - finetuning (anatomy)
# 5 - small details (faces)
learning_rate_scale = 3
base_learning_rate = base_learning_rate / learning_rate_scale
min_learning_rate = min_learning_rate / learning_rate_scale
print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
num_epochs = num_gpus
# steps_per_epoch // sink_interval_share gives the interval used for extra
# barriers and for the loss window consulted before saving.
sink_interval_share = 20
# Minutes between preview generations during training.
sample_interval_min = 60
# Caption-dropout probability: collate_fn_simple drops a caption when
# random.random() < cfg_dropout. random.random() is never negative, so a
# negative value disables caption dropout.
cfg_dropout = -0.10
# Training time t = sigmoid(randn + sigmoid_bias):
# bias = -0.5 -> focus on details (~300), bias = 0.5 -> focus on structure,
# bias = 0 -> bell-shaped / symmetric.
sigmoid_bias = -0.1
max_length = 250  # tokens per prompt (fixed-length padding)
max_snr_gamma = 5.0  # min-SNR loss-weight cap; <= 0 disables the weighting
use_precomputed_embeddings = False
use_wandb = False
use_comet_ml = True
save_model = True
use_decay = True  # cosine LR decay after warmup
fbp = False  # per-parameter optimizers stepped from grad hooks
torch_compile = False
unet_gradient = True  # gradient checkpointing
loss_normalize = False
fixed_seed = False
shuffle = True
optimizer_type = "adafactor" #"adam8bit"
if optimizer_type == "muon_adam8bit":
    # Smaller per-GPU batch for Muon: 3 samples per 32 GB instead of 7.
    batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
muon_lr_scale = 500
# Read from COMET_API_KEY when set; the hard-coded fallback keeps existing
# blind runs working. Prefer the environment variable over committing keys.
comet_ml_api_key = os.environ.get("COMET_API_KEY", "Agctp26mbqnoYrrlvQuKSTk6r")
comet_ml_workspace = "recoilme"
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
# MAX_JOBS=4 pip install flash-attn --no-build-isolation
#torch.backends.cuda.enable_flash_sdp(True)
#torch.backends.cuda.enable_mem_efficient_sdp(True)
#torch.backends.cuda.enable_math_sdp(False) # disable the slow fallback
save_barrier = 1.25  # save when the recent loss < best loss * save_barrier
warmup_percent = 0.0025  # fraction of training spent in linear warmup
betta2 = 0.997
eps = 1e-6
clip_grad_norm = 1.0
limit = 0  # > 0: use only the first `limit` dataset rows
checkpoints_folder = ""
gradient_accumulation_steps = 1
dtype = torch.float32
mixed_precision = "bf16"
# Diffusion sampling parameters (previews)
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 4
# Output folder for preview images
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)
# Seed = today's date as YYYYMMDD plus 42; it is also used for preview noise.
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d")) + 42
if fixed_seed:
    # Seed every RNG the script draws from.
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# Accelerate owns device placement, bf16 mixed precision and gradient accumulation.
accelerator = Accelerator(
    gradient_accumulation_steps=gradient_accumulation_steps,
    mixed_precision=mixed_precision,
)
device = accelerator.device
print("init")
# Command-line interface; defaults are the values computed above.
parser = argparse.ArgumentParser(description='Train a model on a dataset.')
parser.add_argument('--ds-path', type=str, default=ds_path, help='Path to the dataset')
parser.add_argument('--ep', type=int, default=num_epochs, help='Number of epochs to train the model')
parser.add_argument('--batch', type=int, default=batch_size, help='Total batch size')
parser.add_argument('--min-lr', type=float, default=min_learning_rate, help='Minimum learning rate')
parser.add_argument('--max-lr', type=float, default=base_learning_rate, help='Maximum learning rate')
parser.add_argument('--dry-run', action='store_true', default=False, help='Dry run train without saving/sampling')
parser.add_argument('--lvl', type=float, default=0.0, help='Train level, from 0.5 to 5')
args = parser.parse_args()

# Write the parsed values back into the module-level settings.
ds_path = args.ds_path
batch_size = args.batch
num_epochs = args.ep
base_learning_rate = args.max_lr
min_learning_rate = args.min_lr
lvl = args.lvl
if args.dry_run:
    save_model = False
# --lvl further divides both learning rates; values below 0.1 (including the
# 0.0 default) leave them unchanged.
if lvl >= 0.1:
    base_learning_rate /= lvl
    min_learning_rate /= lvl
    print(f"max-lr:{base_learning_rate} min-lr:{min_learning_rate}")
# --------------------------- Experiment tracking (WandB / Comet) ---------------------------
# Trackers exist on the main process only; `comet_experiment` is undefined elsewhere.
if accelerator.is_main_process:
    _tracked_params = {
        "batch_size": batch_size,
        "base_learning_rate": base_learning_rate,
        "num_epochs": num_epochs,
    }
    if use_wandb:
        wandb.init(project=project, config={**_tracked_params, "optimizer_type": optimizer_type})
    if use_comet_ml:
        from comet_ml import Experiment
        comet_experiment = Experiment(
            api_key=comet_ml_api_key,
            project_name=project,
            workspace=comet_ml_workspace,
        )
        hyper_params = dict(_tracked_params)
        comet_experiment.log_parameters(hyper_params)
# --------------------------- Model loading ---------------------------
# VAE (in `dtype`, eval mode, on `device`) and the flow-matching scheduler come from local folders.
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device).eval()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")
# The text stack is loaded on demand by load_text_encoder().
tokenizer = text_encoder = None
def load_text_encoder():
    """Populate the module-level `tokenizer` and `text_encoder` if they are missing.

    The tokenizer comes from the local "tokenizer" folder; the encoder from
    "text_encoder" in float16, moved to `device` and set to eval mode.
    Components that are already loaded are left untouched.
    """
    global tokenizer, text_encoder
    if tokenizer is None:
        tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
    if text_encoder is not None:
        return
    text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
        "text_encoder",
        torch_dtype=torch.float16,
    ).to(device).eval()


load_text_encoder()
@torch.no_grad()
def encode_texts(text, max_length=max_length):
    """Encode one or more prompts with the Qwen text encoder.

    Each prompt becomes a single user chat message, rendered through the
    tokenizer's chat template without a generation prompt, then padded or
    truncated to exactly `max_length` tokens.

    Args:
        text: a string, an iterable of strings, or None (treated as "").
        max_length: fixed token length used for padding and truncation.

    Returns:
        Tuple of the encoder's second-to-last hidden state cast to `dtype`
        and the tokenizer attention mask as a bool tensor.
    """
    if text is None:
        text = ""
    prompts = [text] if isinstance(text, str) else text
    chat_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": [{"type": "text", "text": prompt}]}],
            tokenize=False,
            add_generation_prompt=False,
        )
        for prompt in prompts
    ]
    # Fixed-length padding (instead of dynamic padding=True) keeps every
    # embedding at max_length tokens.
    tokens = tokenizer(
        chat_prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    encoded = text_encoder(
        input_ids=tokens.input_ids,
        attention_mask=tokens.attention_mask,
        output_hidden_states=True,
    )
    penultimate = encoded.hidden_states[-2]
    return penultimate.to(dtype=dtype), tokens.attention_mask.to(dtype=torch.bool)
@torch.no_grad()
def encode_texts_fast(text, max_length=max_length):
    """Encode prompts exactly like encode_texts; kept for existing callers.

    This used the same chat template, the same fixed-length padding and the
    same second-to-last hidden state as encode_texts. It now delegates so the
    two encoders cannot drift apart.

    Returns:
        Tuple of (hidden states cast to `dtype`, bool attention mask).
    """
    return encode_texts(text, max_length)
# Latent normalization constants from the VAE config. Missing or None entries
# fall back to neutral values (shift 0, scale 1).
shift_factor = getattr(vae.config, "shift_factor", 0.0)
if shift_factor is None:
    shift_factor = 0.0
scaling_factor = getattr(vae.config, "scaling_factor", 1.0)
if scaling_factor is None:
    scaling_factor = 1.0
# Per-channel latent statistics, shaped (1, C, 1, 1) for broadcasting. Both
# names are always defined (None when the config lacks them), so later
# `latents_mean is not None` checks cannot raise NameError.
latents_mean = None
latents_std = None
mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
if mean is not None and std is not None:
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)
import numpy as np
from torch.utils.data import Sampler
class DistributedResolutionBatchSampler(Sampler):
    """Yield per-rank index batches whose samples all share one (width, height).

    Indices are grouped by resolution. Each group is cut into "global" batches
    of ``per_rank_batch * num_replicas`` indices, and rank ``r`` yields slice
    ``r`` of every global batch, so all ranks visit the same resolution at the
    same step. Indices that do not fill a complete global batch are dropped.

    With ``shuffle=True`` every epoch (seeded by the epoch number, hence
    identical on all ranks) permutes the indices inside each group before
    cutting them into batches, and then shuffles the batch order.

    Args:
        dataset: supports ``dataset["width"]`` / ``dataset["height"]`` column
            access and ``len``; a missing column (KeyError) puts every index
            into a single group.
        batch_size: global batch size; per rank ``max(1, batch_size // num_replicas)``.
        num_replicas: number of processes.
        rank: index of this process.
        drop_last: accepted for API compatibility; incomplete global batches
            are always dropped.
        shuffle: re-batch and reorder every epoch.
    """

    def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.epoch = 0
        # Batch size on a single GPU.
        self.batch_size = max(1, batch_size // num_replicas)
        self.global_batch = self.batch_size * num_replicas
        try:
            widths = np.asarray(dataset["width"])
            heights = np.asarray(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))
        # Group indices by resolution (groups ordered by first occurrence).
        groups = {}
        for i, (w, h) in enumerate(zip(widths, heights)):
            groups.setdefault((w, h), []).append(i)
        self.groups = [np.asarray(indices, dtype=np.int64) for indices in groups.values()]
        # Unshuffled global batches; used as-is when shuffle=False.
        self.global_batches = self._make_global_batches(rng=None)
        self.num_batches = len(self.global_batches)

    def _make_global_batches(self, rng):
        """Cut each group into full global batches, permuting the group first when `rng` is given."""
        chunks = []
        for idx in self.groups:
            num_batches = len(idx) // self.global_batch
            if num_batches == 0:
                continue
            if rng is not None:
                idx = rng.permutation(idx)
            idx = idx[: num_batches * self.global_batch]
            chunks.append(idx.reshape(num_batches, self.global_batch))
        if chunks:
            return np.concatenate(chunks, axis=0)
        return np.empty((0, self.global_batch), dtype=np.int64)

    def __iter__(self):
        # Seeded by the epoch only, so every rank derives identical batches and order.
        rng = np.random.RandomState(self.epoch)
        if self.shuffle:
            # Re-cut the groups each epoch. Otherwise only the batch order would
            # change: batch compositions would repeat, and the same leftover
            # samples would be dropped every epoch.
            global_batches = self._make_global_batches(rng)
            order = rng.permutation(len(global_batches))
        else:
            global_batches = self.global_batches
            order = np.arange(len(global_batches))
        start = self.rank * self.batch_size
        end = start + self.batch_size
        for i in order:
            yield global_batches[i][start:end]

    def __len__(self):
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick fixed preview samples from `dataset`, grouped by (width, height).

    Each resolution group contributes `samples_per_group` random rows. When
    the dataset has a single resolution, it contributes up to
    `samples_to_generate` rows instead. The count is always clamped to the
    group size, because `random.sample` raises ValueError when asked for more
    items than the group holds (e.g. a small single-resolution dataset).

    Returns:
        dict mapping (width, height) -> (latents, embeddings, masks, texts).
        Latents are on `device` in `dtype`, with dim 2 squeezed from 5-D arrays.
        Embeddings/masks come from the dataset's precomputed columns or from
        encode_texts.
    """
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        # No resolution columns: treat the whole dataset as one group.
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, size in enumerate(zip(widths, heights)):
        size_groups[size].append(i)

    per_group = samples_to_generate if len(size_groups) == 1 else samples_per_group
    fixed_samples = {}
    for size, indices in size_groups.items():
        n_samples = min(per_group, len(indices))
        if n_samples <= 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]
        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        if latents.ndim == 5:
            latents = latents.squeeze(2)
        texts = [item["text"] for item in samples_data]
        if use_precomputed_embeddings:
            embeddings = torch.tensor(
                np.array([item["embeddings"] for item in samples_data]),
                device=device,
                dtype=dtype,
            )
            masks = torch.tensor(
                np.array([item["attention_mask"] for item in samples_data]),
                device=device,
                dtype=torch.int64,
            )
        else:
            embeddings, masks = encode_texts(texts, max_length)
        fixed_samples[size] = (latents, embeddings, masks, texts)
    print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
    return fixed_samples
# Load the preprocessed dataset; with limit > 0 keep only the first `limit` rows.
dataset = load_from_disk(ds_path)
if limit > 0:
    dataset = dataset.select(range(limit))
print(f"images: {len(dataset)}")
def collate_fn_simple(batch):
    """Turn dataset rows into (latents, embeddings, attention_mask) on `device`.

    Latents come from each row's "vae" field. They are stacked as float16
    numpy, then cast to `dtype`; dim 2 is squeezed from 5-D latents. With
    precomputed embeddings, the rows' "embeddings" and "attention_mask"
    fields are used directly. Otherwise captions are cleaned and encoded on
    the fly with encode_texts, and the mask is returned as bool.
    """

    def clean_caption(caption):
        # Branch order matters: random.random() is drawn only for captions
        # that don't start with "zero", matching the original RNG stream.
        if caption.lower().startswith("zero"):
            return ""
        if random.random() < cfg_dropout:
            return ""
        if caption.startswith("."):
            return caption[1:].lstrip()
        for phrase in ("The image shows ", "The image is ", "This image captures "):
            caption = caption.replace(phrase, "")
        return caption.strip()

    stacked = np.array([row["vae"] for row in batch], dtype=np.float16)
    latents = torch.from_numpy(stacked).to(device, dtype=dtype)
    if latents.ndim == 5:
        latents = latents.squeeze(2)

    if use_precomputed_embeddings:
        embeddings = torch.from_numpy(
            np.array([row["embeddings"] for row in batch], dtype=np.float16)
        ).to(device, dtype=dtype)
        attention_mask = torch.from_numpy(
            np.array([row["attention_mask"] for row in batch], dtype=np.int64)
        ).to(device)
        return latents, embeddings, attention_mask

    captions = [clean_caption(row["text"]) for row in batch]
    embeddings, attention_mask = encode_texts(captions, max_length)
    return latents, embeddings, attention_mask.to(dtype=torch.bool)
# Resolution-bucketed batches: each process receives its own slice of every global batch.
batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle=shuffle,
)
dataloader = DataLoader(dataset, collate_fn=collate_fn_simple, batch_sampler=batch_sampler)
if accelerator.is_main_process:
    # len(dataloader) is the number of batches.
    print("Total samples", len(dataloader))
# NOTE(review): accelerator.prepare may wrap/shard a custom batch_sampler across
# processes by itself; confirm it does not re-shard the per-rank batches above.
dataloader = accelerator.prepare(dataloader)

start_epoch, global_step = 0, 0
total_training_steps = len(dataloader) * num_epochs
world_size = accelerator.state.num_processes
# Load the UNet from <checkpoints_folder>/<project>; training cannot start without it.
latest_checkpoint = os.path.join(checkpoints_folder, project)
if not os.path.isdir(latest_checkpoint):
    raise FileNotFoundError(f"UNet checkpoint not found at {latest_checkpoint}")
print("Загружаем UNet из чекпоинта:", latest_checkpoint)
unet = UNet2DConditionModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
if unet_gradient:
    unet.enable_gradient_checkpointing()
print(dir(SDPBackend))
# Prefer PyTorch SDPA attention; fall back to xformers if that cannot be set.
try:
    unet.set_attn_processor(AttnProcessor2_0())
except Exception as e:
    print(f"Ошибка при включении SDPA: {e}")
    unet.set_use_memory_efficient_attention_xformers(True)
def create_optimizer(name, params):
    """Build the optimizer called `name` over `params`, starting at `base_learning_rate`.

    Supported names:
        "adam8bit"      - bitsandbytes AdamW8bit, weight decay 0.001
        "adam"          - torch AdamW, weight decay 0.001
        "adafactor"     - transformers Adafactor with a fixed lr (no relative
                          step, no parameter scaling), weight decay 0.001
        "muon_adam8bit" - MuonAdamW8bit, weight decay 0.01, muon_lr_mult=muon_lr_scale

    Raises:
        ValueError: for any other name.
    """
    # Builders are lazy so only the selected optimizer is constructed.
    builders = (
        ("adam8bit", lambda: bnb.optim.AdamW8bit(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
        )),
        ("adam", lambda: torch.optim.AdamW(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
        )),
        ("adafactor", lambda: Adafactor(
            params,
            lr=base_learning_rate,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.001,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        )),
        ("muon_adam8bit", lambda: MuonAdamW8bit(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=eps,
            weight_decay=0.01,
            muon_lr_mult=muon_lr_scale,
        )),
    )
    for key, build in builders:
        if name == key:
            return build()
    raise ValueError(f"Unknown optimizer: {name}")
if not fbp:
    # Standard path: one optimizer over all UNet parameters.
    unet.requires_grad_(True)
    optimizer = create_optimizer(optimizer_type, unet.parameters())
else:
    # fbp path: one optimizer per parameter, stepped and zeroed from a
    # post-accumulate-grad hook as soon as that parameter's gradient is ready.
    trainable_params = list(unet.parameters())
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    unet, optimizer = accelerator.prepare(unet, optimizer_dict)

# Disabled alternative: freeze the whole UNet and train only selected layers.
# unet.requires_grad_(False)
# trainable_params_names = ["conv_in.weight", "conv_in.bias", "conv_out.weight", "conv_out.bias"]
# train_params = []
# for name, param in unet.named_parameters():
#     if any(target in name for target in trainable_params_names):
#         param.requires_grad = True
#         train_params.append(param)
#         print(f"Trainable layer: {name}")
def lr_schedule(step):
    """Return the absolute learning rate for scheduler step `step`.

    Schedule: linear warmup from `min_learning_rate` to `base_learning_rate`
    over the first `warmup_percent` of training, then a half cosine back
    down to `min_learning_rate`. With `use_decay=False` the LR is constant
    at `base_learning_rate`.

    Progress is ``step / (total_training_steps * world_size)``. The
    `world_size` factor presumably accounts for the prepared scheduler
    advancing once per process per optimizer step (accelerate's scheduler
    wrapper) — confirm if that setup changes.

    Progress is clamped to [0, 1], so stepping past the planned total keeps
    the LR at the minimum instead of climbing back up the cosine. A run with
    zero planned steps, or a warmup covering all of training, no longer
    divides by zero.
    """
    if not use_decay:
        return base_learning_rate
    total = total_training_steps * world_size
    x = min(max(step / total, 0.0), 1.0) if total > 0 else 0.0
    warmup = warmup_percent
    if x < warmup:
        return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
    decay_span = 1 - warmup
    decay_ratio = (x - warmup) / decay_span if decay_span > 0 else 1.0
    return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
        (1 + math.cos(math.pi * decay_ratio))
# LambdaLR multiplies the optimizer's initial lr by the returned factor, so the
# absolute schedule is divided by base_learning_rate.
# In fbp mode `optimizer` is a dict of per-parameter optimizers stepped from
# grad hooks. LambdaLR only accepts a torch Optimizer, so no scheduler is built
# there; the training loop reports base_learning_rate for fbp.
if fbp:
    lr_scheduler = None
else:
    lr_scheduler = LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
if torch_compile:
    print("Compiling UNet... Это займет несколько минут, не прерывайте!")
    unet = torch.compile(unet)
    print("Compiling - ok")
if not fbp:
    unet, optimizer, lr_scheduler = accelerator.prepare(unet, optimizer, lr_scheduler)
# Fixed samples (one set per resolution), reused for every preview generation.
fixed_samples = get_fixed_samples_by_resolution(dataset)
# --- Negative (unconditional) condition helper; returns (embedding, mask) ---
def get_negative_embedding(neg_prompt="", batch_size=1):
    """Return the (embedding, attention_mask) pair used as the negative condition.

    A falsy `neg_prompt` yields an all-zero embedding of shape
    (batch_size, max_length, 1024) with an all-ones int64 mask, without using
    the text encoder. Otherwise the prompt is encoded once with encode_texts
    and repeated `batch_size` times along dim 0.

    NOTE(review): the zero-embedding width is hard-coded to 1024; confirm it
    matches the text encoder's hidden size.
    """
    if neg_prompt:
        emb, mask = encode_texts([neg_prompt], max_length)
        return (
            emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1),
            mask.to(device=device).repeat(batch_size, 1),
        )
    shape = (batch_size, max_length)
    return (
        torch.zeros((*shape, 1024), dtype=dtype, device=device),
        torch.ones(shape, dtype=torch.int64, device=device),
    )
# Negative ("low quality") condition used when sampling previews.
if use_precomputed_embeddings:
    # With precomputed embeddings the encoder is only needed for this prompt:
    # make sure it is loaded, encode, park the result on the CPU, then drop
    # the encoder.
    load_text_encoder()
    uncond_emb, uncond_mask = get_negative_embedding("low quality")
    uncond_emb = uncond_emb.to("cpu")
    uncond_mask = uncond_mask.to("cpu")
    text_encoder = None
    # Collect first so tensors kept alive only by reference cycles are freed
    # before the CUDA cache is emptied; the reverse order leaves them cached.
    gc.collect()
    torch.cuda.empty_cache()
else:
    uncond_emb, uncond_mask = get_negative_embedding("low quality")
def pad_to_match(a, b, pad_value=0):
    """Right-pad `a` and `b` along the sequence axis to a common length.

    Intended for 3-D (batch, seq, hidden) embeddings. Lengths are read from
    dim 1, and the shorter tensor gets `pad_value` rows appended along the
    second-to-last dim (dim 1 for 3-D inputs); the last dim is never padded.
    A tensor already at the target length is returned unchanged (same object).
    """
    len_a, len_b = a.shape[1], b.shape[1]
    if len_a == len_b:
        return a, b
    target = max(len_a, len_b)
    padded = []
    for tensor in (a, b):
        missing = target - tensor.shape[1]
        if missing > 0:
            tensor = torch.nn.functional.pad(tensor, (0, 0, 0, missing), value=pad_value)
        padded.append(tensor)
    return tuple(padded)
def pad_mask(a, b):
    """Right-pad attention masks `a` and `b` with zeros to a common length.

    Lengths are read from dim 1 and padding is applied to the last dim (the
    same axis for 2-D (batch, seq) masks). A mask already at the target
    length is returned unchanged (same object).
    """
    target = max(a.shape[1], b.shape[1])

    def extend(mask):
        missing = target - mask.shape[1]
        if missing > 0:
            return torch.nn.functional.pad(mask, (0, missing), value=0)
        return mask

    return extend(a), extend(b)
# --- Sample generation ---
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Generate preview images for every fixed-sample group, then save and log them.

    Args:
        fixed_samples_cpu: dict (width, height) -> (latents, text embeddings,
            attention mask, captions). The tensors are moved to `device` here,
            so they may live on CPU or GPU.
        uncond_data: (embedding, mask) negative condition for
            classifier-free guidance.
        step: global step used in log names. At step 0 the dataset latents
            are decoded instead of generated ones (a VAE reconstruction
            baseline).

    For each group the UNet (set to eval mode; callers restore train mode)
    denoises seeded noise for `n_diffusion_steps` scheduler steps, applying
    guidance when `guidance_scale != 1`. Results are decoded by the VAE and
    written as JPEG to `generated_folder`. On the main process they are
    logged to WandB / Comet, padded to the largest resolution (at least 255
    px per side). Images containing NaNs are reported and skipped.
    Afterwards the VAE goes back to CPU and CUDA memory is released.
    """
    uncond_emb, uncond_mask = uncond_data
    uncond_emb = uncond_emb.to(device)
    uncond_mask = uncond_mask.to(device)
    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(unet, keep_torch_compile=True).eval()
        else:
            original_model = unet.eval()
        vae.to(device=device).eval()
        all_generated_images = []
        all_captions = []
        # Common canvas for logged previews; it only depends on the group keys,
        # so compute it once rather than per image.
        max_w_overall = max(255, max((s[0] for s in fixed_samples_cpu.keys()), default=0))
        max_h_overall = max(255, max((s[1] for s in fixed_samples_cpu.keys()), default=0))
        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            sample_latents = sample_latents.to(dtype=dtype, device=device)
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            sample_mask = sample_mask.to(device=device)
            latents = torch.randn(
                sample_latents.shape,
                device=device,
                dtype=sample_latents.dtype,
                generator=torch.Generator(device=device).manual_seed(seed),
            )
            scheduler.set_timesteps(n_diffusion_steps, device=device)
            for t in scheduler.timesteps:
                if guidance_scale != 1:
                    latent_model_input = torch.cat([latents, latents], dim=0)
                    curr_batch_size = sample_text_embeddings.shape[0]
                    # Broadcast the single negative condition over the batch and
                    # pad both conditions to a common length before stacking
                    # [uncond, cond] along the batch.
                    neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                    neg_emb_batch, sample_text_embeddings = pad_to_match(neg_emb_batch, sample_text_embeddings)
                    neg_mask_batch = uncond_mask[0:1].expand(curr_batch_size, -1)
                    neg_mask_batch, sample_mask = pad_mask(neg_mask_batch, sample_mask)
                    text_embeddings_batch = torch.cat([neg_emb_batch, sample_text_embeddings], dim=0)
                    attention_mask_batch = torch.cat([neg_mask_batch, sample_mask], dim=0)
                else:
                    latent_model_input = latents
                    text_embeddings_batch = sample_text_embeddings
                    attention_mask_batch = sample_mask
                model_out = original_model(
                    latent_model_input,
                    t,
                    encoder_hidden_states=text_embeddings_batch,
                    encoder_attention_mask=attention_mask_batch,
                )
                flow = getattr(model_out, "sample", model_out)
                if guidance_scale != 1:
                    flow_uncond, flow_cond = flow.chunk(2)
                    flow = flow_uncond + guidance_scale * (flow_cond - flow_uncond)
                latents = scheduler.step(flow, t, latents).prev_sample
            current_latents = latents
            if step == 0:
                current_latents = sample_latents
            # The Qwen VAE expects 5-D input: add a singleton frame axis at dim 2,
            # then undo the per-channel latent normalization.
            vae_input = current_latents.unsqueeze(2).to(torch.float32)
            if latents_mean is not None and latents_std is not None:
                vae_input = vae_input * latents_std.unsqueeze(2) + latents_mean.unsqueeze(2)
            decoded = vae.decode(vae_input).sample
            # Drop the singleton frame axis again.
            if decoded.ndim == 5:
                decoded = decoded.squeeze(2)
            decoded_fp32 = decoded.to(torch.float32)
            for img_idx, img_tensor in enumerate(decoded_fp32):
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)
                if np.isnan(img).any():
                    # clamp keeps NaNs, and casting them to uint8 yields garbage
                    # pixels, so this image is neither saved nor logged.
                    print("NaNs found, saving stopped! Step:", step)
                    continue
                pil_img = Image.fromarray((img * 255).astype("uint8"))
                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)
                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)
                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=95)
        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]},
                )
    finally:
        vae.to("cpu")
        uncond_emb = uncond_emb.to("cpu")
        uncond_mask = uncond_mask.to("cpu")
        # Release GPU tensors still referenced by locals before emptying the
        # CUDA cache. Rebinding to None (rather than `del`) works whether or
        # not an exception left some of these names unassigned, so a missing
        # name no longer skips the remaining cleanup.
        all_generated_images = all_captions = None
        latents = current_latents = latent_model_input = model_out = None
        flow = flow_uncond = flow_cond = None
        decoded = decoded_fp32 = vae_input = img_tensor = None
        sample_latents = sample_text_embeddings = sample_mask = None
        text_embeddings_batch = attention_mask_batch = None
        neg_emb_batch = neg_mask_batch = None
        torch.cuda.synchronize()
        # Collect before emptying the cache so memory freed by GC is returned too.
        gc.collect()
        torch.cuda.empty_cache()
# --------------------------- Samples before training ---------------------------
# Main process only, and only while save_model is set (--dry-run clears it).
if accelerator.is_main_process and save_model:
    print("Генерация сэмплов до старта обучения...")
    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
# Every process waits here so the others don't start training during sampling.
accelerator.wait_for_everyone()
def save_checkpoint(unet, variant=""):
    """Save the UNet to <checkpoints_folder>/<project> (main process only).

    With an empty `variant` the model is saved as-is. With a non-empty
    `variant` it is first cast to float16 via nn.Module.to, which converts the
    module in place, and then saved under that variant name. Afterwards CUDA
    is synchronized, its cache emptied and the garbage collector run.
    """
    if accelerator.is_main_process:
        target = unet if torch_compile else accelerator.unwrap_model(unet)
        if variant == "":
            target.save_pretrained(os.path.join(checkpoints_folder, f"{project}"))
        else:
            target.to(dtype=torch.float16).save_pretrained(
                os.path.join(checkpoints_folder, f"{project}"), variant=variant
            )
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    gc.collect()
# --------------------------- Training loop setup ---------------------------
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")
epoch_loss_points = []
progress_bar = tqdm(
    total=total_training_steps,
    disable=not accelerator.is_local_main_process,
    desc="Training",
    unit="step",
)
steps_per_epoch = len(dataloader)
# Step interval for extra barriers and for the loss window checked before saving.
sink_interval = max(1, steps_per_epoch // sink_interval_share)
# Best loss so far; a checkpoint is saved when the recent loss beats min_loss * save_barrier.
min_loss = 4.0
sample_interval_seconds = sample_interval_min * 60  # minutes -> seconds
last_sample_time = time.time()
# --------------------------- Training loop ---------------------------
# NOTE(review): the source's indentation was lost; the nesting here is a
# reconstruction (optimizer step inside accelerator.accumulate; logging,
# sampling and checkpointing on the main process after each synced step) —
# confirm against the original file.
for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    # Re-seeds the resolution sampler's shuffling for this epoch.
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    unet.train()
    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        # When not saving (e.g. --dry-run), report peak GPU memory once at step 5 of the first epoch.
        if save_model == False and epoch == 0 and step == 5 :
            used_gb = torch.cuda.max_memory_allocated() / 1024**3
            print(f"Шаг {step}: {used_gb:.2f} GB")
        # Explicit autocast only for the compiled model; otherwise a no-op context.
        amp_context = accelerator.autocast() if torch_compile else nullcontext()
        with accelerator.accumulate(unet):
            with amp_context:
                # Noise
                noise = torch.randn_like(latents, dtype=latents.dtype)
                # Per-sample time t in (0, 1): sigmoid of a standard normal shifted by sigmoid_bias.
                t = torch.sigmoid(torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype) + sigmoid_bias)
                # Linear interpolation between the clean latents (t=0) and noise (t=1).
                noisy_latents = (1.0 - t.view(-1, 1, 1, 1)) * latents + t.view(-1, 1, 1, 1) * noise
                # UNet timesteps: t scaled to [0, 999]; kept as float, not rounded to integers.
                timesteps = t.to(torch.float32).mul(999.0)
                # --- UNet call with the text attention mask ---
                # All SDPA backends are allowed, MATH included as the last fallback.
                #with sdpa_kernel([SDPBackend.FLASH_ATTENTION, SDPBackend.EFFICIENT_ATTENTION], SDPBackend.CUDNN_ATTENTION):
                with sdpa_kernel([
                    SDPBackend.FLASH_ATTENTION,
                    SDPBackend.EFFICIENT_ATTENTION,
                    SDPBackend.CUDNN_ATTENTION,
                    SDPBackend.MATH
                ]):
                    model_pred = unet(
                        noisy_latents,
                        timesteps,
                        encoder_hidden_states=embeddings,
                        encoder_attention_mask=attention_mask,
                    ).sample
                # Flow-matching velocity target: d(noisy_latents)/dt = noise - latents.
                target = noise - latents
                if max_snr_gamma > 0:
                    # 1. Raw element-wise loss (no reduction over the batch yet).
                    raw_loss = F.mse_loss(model_pred.float(), target.float(), reduction='none')
                    # Average inside each image -> one loss per sample, shape [batch_size].
                    loss_per_sample = raw_loss.mean(dim=[1, 2, 3])
                    # 2. Signal-to-noise ratio at each sample's time: signal is (1 - t), noise is t.
                    snr = ((1.0 - t) / (t + 1e-5)) ** 2
                    # Weight for this (v-prediction-style) target: min(snr, gamma) / (snr + 1).
                    min_snr_weights = torch.clamp(snr, max=max_snr_gamma) / (snr + 1.0)
                    # 4. Weight the per-image losses and reduce to the scalar for backward();
                    # view(-1) keeps the shapes aligned.
                    mse_loss = (loss_per_sample * min_snr_weights.view(-1)).mean()
                else:
                    mse_loss = F.mse_loss(model_pred.float(), target.float())
                batch_losses.append(mse_loss.detach().item())
                # Periodic cross-process barriers (every 100 steps and every sink_interval steps).
                if (global_step % 100 == 0) or (global_step % sink_interval == 0):
                    accelerator.wait_for_everyone()
                # Not read anywhere else in this script.
                losses_dict = {}
                losses_dict["mse"] = mse_loss
                if (global_step % 100 == 0) or (global_step % sink_interval == 0):
                    accelerator.wait_for_everyone()
                accelerator.backward(mse_loss)
                if (global_step % 100 == 0) or (global_step % sink_interval == 0):
                    accelerator.wait_for_everyone()
                # Gradient norm for logging; stays 0.0 under fbp (hooks step the per-parameter optimizers).
                grad = 0.0
                if not fbp:
                    if accelerator.sync_gradients:
                        grad_val = accelerator.clip_grad_norm_(unet.parameters(), clip_grad_norm)
                        grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
                    optimizer.step()
                    lr_scheduler.step()
                    optimizer.zero_grad(set_to_none=True)
        # One optimizer step completed (gradients synchronized).
        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)
            if accelerator.is_main_process:
                if fbp:
                    current_lr = base_learning_rate
                else:
                    current_lr = lr_scheduler.get_last_lr()[0]
                batch_grads.append(grad)
                log_data = {}
                log_data["loss_mse"] = mse_loss.detach().item()
                log_data["lr"] = current_lr
                log_data["grad"] = grad
                if accelerator.sync_gradients:
                    if use_wandb:
                        wandb.log(log_data, step=global_step)
                    if use_comet_ml:
                        comet_experiment.log_metrics(log_data, step=global_step)
                current_time = time.time()
                is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
                # Preview every sample_interval_min minutes, and once at global_step 50.
                if is_time_to_sample or global_step == 50:
                    # Pass the (emb, mask) tuple as the negative condition
                    if save_model:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    elif epoch % 10 == 0:
                        generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                    last_n = sink_interval
                    if save_model:
                        # Save only if both the recent average and the last loss stay
                        # below the best loss so far times save_barrier.
                        has_losses = len(batch_losses) > 0
                        avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
                        last_loss = batch_losses[-1] if has_losses else 0.0
                        max_loss = max(avg_sample_loss, last_loss)
                        should_save = max_loss < min_loss * save_barrier
                        print(
                            f"Saving: {should_save} | Max: {max_loss:.4f} | "
                            f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                        )
                        # 6. Save and update the best-loss threshold
                        if should_save:
                            min_loss = max_loss
                            save_checkpoint(unet)
                    last_sample_time = current_time
                    # Sampling switched the model to eval mode; switch back.
                    unet.train()
    # End of epoch: log average loss and gradient norm (main process).
    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
        print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)
# --------------------------- Finish ---------------------------
if accelerator.is_main_process:
    print("Обучение завершено! Сохраняем финальную модель...")
    # --dry-run is documented as "without saving" and works by clearing
    # save_model, so the final fp16 export also respects that flag.
    if save_model:
        save_checkpoint(unet, "fp16")
    # comet_experiment only exists on the main process.
    if use_comet_ml:
        comet_experiment.end()
accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Готово!")