# Source: 2b / train-Copy1.py
# Author: babkasotona
# Commit: "Upload folder using huggingface_hub" (58bb2b7, verified)
import os
import math
import torch
import numpy as np
import matplotlib.pyplot as plt
import wandb, comet_ml
import random, time
import gc
import bitsandbytes as bnb
import torch.nn.functional as F
import argparse
from datetime import datetime
from diffusers import CosmosTransformer3DModel, AutoencoderKLQwenImage, FlowMatchEulerDiscreteScheduler
from transformers import Qwen3_5Tokenizer, Qwen3_5ForConditionalGeneration
from torch.utils.data import DataLoader, Sampler
from torch.optim.lr_scheduler import LambdaLR
from collections import defaultdict
from accelerate import Accelerator
from datasets import load_from_disk
from tqdm import tqdm
from PIL import Image, ImageOps
from torch.utils.checkpoint import checkpoint
from diffusers.models.attention_processor import AttnProcessor2_0
from contextlib import nullcontext
from transformers.optimization import Adafactor
# Muon not tested! pip install git+https://github.com/recoilme/muon_adamw8bit.git
from muon_adamw8bit import MuonAdamW8bit
# Distributed / allocator environment. These are set after `import torch` but before
# any process group or CUDA allocation is created in this script.
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1" # comment this on H100!
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
# --------------------------- Parameters ---------------------------
ds_path = "datasets/ds234_640_vae_qwen"
# Used as the checkpoint sub-folder, sample filename prefix and tracker project name.
project = "transformer"
# Per-GPU batch heuristic: ~7 samples per 32 GB of memory on device 0 (at least 1).
gpu_mem_gb = torch.cuda.get_device_properties(0).total_memory / 1e9
local_bs = max(1, int((gpu_mem_gb / 32) * 7))
num_gpus = torch.cuda.device_count()
# Global batch across all GPUs (the sampler divides it back per rank).
batch_size = local_bs * num_gpus
# Peak / floor of the warmup + cosine schedule, both divided by learning_rate_scale.
base_learning_rate = 4e-5
min_learning_rate = 4e-6
learning_rate_scale = 3
base_learning_rate = base_learning_rate / learning_rate_scale
min_learning_rate = min_learning_rate / learning_rate_scale
print(f"Calculated params max-lr:{base_learning_rate} min-lr:{min_learning_rate} GPUs: {num_gpus}, Global BS: {batch_size}")
# NOTE(review): the epoch count defaults to the GPU count; presumably deliberate — confirm.
num_epochs = num_gpus
# The loss window used for the save decision is steps_per_epoch // sink_interval_share.
sink_interval_share = 10
# Minutes between preview-sample generations.
sample_interval_min = 20
# Probability of replacing a caption with "" (classifier-free guidance training).
cfg_dropout = 0.10
# Timestep t sampling: bias = -0.5 (focus on details, ~300), bias = 0.5 (focus on structure),
# bias = 0 (bell-shaped / even). t = sigmoid(randn + sigmoid_bias) is the noise fraction.
sigmoid_bias = 0.1
# Token length prompts are padded/truncated to.
max_length = 250
# If True, read "embeddings"/"attention_mask" from the dataset instead of running the text encoder.
use_precomputed_embeddings = False
use_wandb = False
use_comet_ml = False
save_model = True
# If False, the LR stays at base_learning_rate (no warmup / cosine decay).
use_decay = True
# Per-parameter optimizers stepped from grad hooks (presumably "fused backward pass").
fbp = False
torch_compile = False
# Enables gradient checkpointing on the transformer.
transformer_gradient = True
# NOTE(review): not read anywhere else in this script.
loss_normalize = False
fixed_seed = False
shuffle = True
# One of "adam8bit", "adam", "adafactor", "muon_adam8bit" (see create_optimizer).
optimizer_type = "adafactor"
if optimizer_type == "muon_adam8bit":
    # With Muon the per-GPU batch heuristic drops from ~7 to ~3 samples per 32 GB.
    batch_size = num_gpus * max(1, int((gpu_mem_gb / 32) * 3))
# LR multiplier passed to MuonAdamW8bit as muon_lr_mult.
muon_lr_scale = 500
# Comet credentials are read from the environment (COMET_API_KEY) instead of being
# hard-coded: a key committed to the source is readable by anyone who has the file.
# NOTE: the key previously hard-coded here should be treated as leaked and rotated.
# With None, comet_ml is expected to fall back to its own config/env lookup.
comet_ml_api_key = os.environ.get("COMET_API_KEY")
comet_ml_workspace = "recoilme"
# Allow TF32 matmuls/convolutions and the fused SDPA kernels; the "math" SDPA
# backend is disabled globally (VAE decoding re-enables it locally).
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(False)
# A checkpoint is saved when the recent loss is below min_loss * save_barrier.
save_barrier = 1.25
# Fraction of total scheduler steps used for linear LR warmup.
warmup_percent = 0.0025
# Second Adam beta.
betta2 = 0.997
eps = 1e-6
clip_grad_norm = 1.0
# If > 0, only the first `limit` dataset rows are used.
limit = 0
# The transformer checkpoint is loaded from / saved to <checkpoints_folder>/<project>.
checkpoints_folder = ""
gradient_accumulation_steps = 1
# Weight dtype; Accelerate applies bf16 autocast on top.
dtype = torch.float32
mixed_precision = "bf16"
# Diffusion sampling parameters
n_diffusion_steps = 40
samples_to_generate = 12
guidance_scale = 7.0
# Output folder for preview samples
generated_folder = "samples"
os.makedirs(generated_folder, exist_ok=True)
# Seed setup: derived from today's date (YYYYMMDD + 42). It is always computed,
# but the global RNGs are only seeded when fixed_seed is enabled.
current_date = datetime.now()
seed = int(current_date.strftime("%Y%m%d")) + 42
if fixed_seed:
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
accelerator = Accelerator(
    mixed_precision=mixed_precision,
    gradient_accumulation_steps=gradient_accumulation_steps,
)
device = accelerator.device
print("init")

# Command-line overrides; every default is the module-level value defined above.
parser = argparse.ArgumentParser(description="Train a model on a dataset.")
parser.add_argument("--ds-path", type=str, default=ds_path, help="Path to the dataset")
parser.add_argument("--ep", type=int, default=num_epochs, help="Number of epochs to train the model")
parser.add_argument("--batch", type=int, default=batch_size, help="Total batch size")
parser.add_argument("--min-lr", type=float, default=min_learning_rate, help="Minimum learning rate")
parser.add_argument("--max-lr", type=float, default=base_learning_rate, help="Maximum learning rate")
parser.add_argument("--dry-run", action="store_true", default=False, help="Dry run train without saving/sampling")
parser.add_argument("--lvl", type=float, default=0.0, help="Train level, from 0.5 to 5")
args = parser.parse_args()

ds_path = args.ds_path
batch_size = args.batch
num_epochs = args.ep
base_learning_rate = args.max_lr
min_learning_rate = args.min_lr
lvl = args.lvl
# --dry-run turns checkpoint saving off.
save_model = save_model and not args.dry_run
# A training level >= 0.1 divides both learning-rate bounds by that level.
if lvl >= 0.1:
    base_learning_rate /= lvl
    min_learning_rate /= lvl
print(f"max-lr:{base_learning_rate} min-lr:{min_learning_rate}")
# --------------------------- Experiment tracking (WandB / Comet) ---------------------------
# Trackers are created on the main process only.
if accelerator.is_main_process and use_wandb:
    wandb.init(
        project=project,
        config={
            "batch_size": batch_size,
            "base_learning_rate": base_learning_rate,
            "num_epochs": num_epochs,
            "optimizer_type": optimizer_type,
        },
    )
if accelerator.is_main_process and use_comet_ml:
    from comet_ml import Experiment

    comet_experiment = Experiment(
        api_key=comet_ml_api_key,
        project_name=project,
        workspace=comet_ml_workspace,
    )
    hyper_params = dict(
        batch_size=batch_size,
        base_learning_rate=base_learning_rate,
        num_epochs=num_epochs,
    )
    comet_experiment.log_parameters(hyper_params)
# --------------------------- Model loading ---------------------------
vae = AutoencoderKLQwenImage.from_pretrained("vae", torch_dtype=dtype).to(device=device, dtype=dtype).eval()
scheduler = FlowMatchEulerDiscreteScheduler.from_pretrained("scheduler")

# Text-conditioning components live in module globals and are filled by load_text_encoder().
tokenizer = None
text_encoder = None


def load_text_encoder():
    """Load the Qwen tokenizer and text encoder into the module globals.

    Idempotent: a component that is already loaded is left untouched. The
    encoder is loaded in the module-level ``dtype``, moved to ``device`` and
    put in eval mode.
    """
    global tokenizer, text_encoder
    if tokenizer is None:
        tokenizer = Qwen3_5Tokenizer.from_pretrained("tokenizer")
    if text_encoder is not None:
        return
    text_encoder = Qwen3_5ForConditionalGeneration.from_pretrained(
        "text_encoder",
        torch_dtype=dtype,
    ).to(device).eval()


load_text_encoder()
@torch.no_grad()
def encode_texts(text, max_length=max_length):
    """Encode prompt(s) with the Qwen text encoder.

    Each prompt is wrapped in a single-turn user chat message. The messages are
    tokenized, padded or truncated to ``max_length`` tokens, and run through the
    text encoder. The second-to-last hidden layer is returned as the conditioning
    sequence.

    Args:
        text: a prompt string, a list of prompt strings, or None (treated as "").
        max_length: fixed token length for padding/truncation.

    Returns:
        (hidden, attention_mask): hidden states cast to the module-level ``dtype``
        with padded positions set to 0, and the tokenizer attention mask as int64.
    """
    if text is None:
        text = ""
    if isinstance(text, str):
        text = [text]
    formatted_prompts = [
        tokenizer.apply_chat_template(
            [{"role": "user", "content": [{"type": "text", "text": t}]}],
            tokenize=False,
            add_generation_prompt=False,
        )
        for t in text
    ]
    toks = tokenizer(
        formatted_prompts,
        padding="max_length",
        max_length=max_length,
        truncation=True,
        return_tensors="pt",
    ).to(device)
    outputs = text_encoder(
        input_ids=toks.input_ids,
        attention_mask=toks.attention_mask,
        output_hidden_states=True,
    )
    hidden = outputs.hidden_states[-2].to(dtype=dtype)
    # Zero padded positions straight from the attention mask. Slicing with
    # `hidden[i, length:] = 0` assumes right padding; with a left-padding
    # tokenizer it would wipe real tokens and keep the pads.
    pad_positions = toks.attention_mask.unsqueeze(-1) == 0
    hidden = hidden.masked_fill(pad_positions, 0)
    return hidden, toks.attention_mask.to(dtype=torch.int64)


@torch.no_grad()
def encode_texts_fast(text, max_length=max_length):
    """Backward-compatible alias of ``encode_texts``.

    The two functions used to carry byte-for-byte duplicated bodies.
    """
    return encode_texts(text, max_length)
def _config_or_default(config, key, default):
    """Return ``config.<key>``, or ``default`` when the key is missing or None."""
    value = getattr(config, key, default)
    return default if value is None else value


shift_factor = _config_or_default(vae.config, "shift_factor", 0.0)
scaling_factor = _config_or_default(vae.config, "scaling_factor", 1.0)

# Per-channel latent statistics from the VAE config, viewed as (1, C, 1, 1).
mean = getattr(vae.config, "latents_mean", None)
std = getattr(vae.config, "latents_std", None)
latents_std = None
latents_mean = None
if mean is not None and std is not None:
    # NOTE: the Cosmos pipeline uses the inverse std (1.0 / std) for decoding;
    # the plain std is kept here.
    latents_std = torch.tensor(std, device=device, dtype=dtype).view(1, len(std), 1, 1)
    latents_mean = torch.tensor(mean, device=device, dtype=dtype).view(1, len(mean), 1, 1)

# Register the sigma settings used by sampling, keeping existing values and
# filling in defaults when the scheduler config lacks them.
if scheduler is not None:
    scheduler.register_to_config(
        sigma_max=getattr(scheduler.config, "sigma_max", 80.0),
        sigma_min=getattr(scheduler.config, "sigma_min", 0.002),
        sigma_data=getattr(scheduler.config, "sigma_data", 1.0),
        final_sigmas_type=getattr(scheduler.config, "final_sigmas_type", "sigma_min"),
    )
import numpy as np
from torch.utils.data import Sampler
class DistributedResolutionBatchSampler(Sampler):
    """Batch sampler yielding resolution-homogeneous batches, sharded across ranks.

    Rows are grouped by their ("width", "height") columns; when those columns are
    missing, all rows form a single group. Each group is cut into *global* batches
    of ``batch_size // num_replicas * num_replicas`` indices. Every rank yields its
    own contiguous slice of every global batch, so all ranks see the same
    resolution at the same step.

    With ``shuffle=True``, the indices inside each group are re-permuted every
    epoch before batching, and the batch order is permuted as well. The RNG is
    seeded with ``seed + epoch``, so all ranks agree. Re-permuting inside groups
    means batch composition changes between epochs and the per-group remainder
    that does not fill a global batch is a different subset each epoch, rather
    than the same rows being skipped forever.

    Args:
        dataset: indexable dataset exposing "width"/"height" columns via ``dataset[col]``.
        batch_size: global batch size (split evenly across replicas, min 1 per rank).
        num_replicas: number of processes.
        rank: index of this process.
        drop_last: kept for API compatibility; incomplete global batches are
            always dropped.
        shuffle: shuffle indices within groups and batch order each epoch.
        seed: base seed added to the epoch number.
    """

    def __init__(self, dataset, batch_size, num_replicas, rank, drop_last=True, shuffle=True, seed=0):
        self.dataset = dataset
        self.num_replicas = num_replicas
        self.rank = rank
        self.shuffle = shuffle
        self.drop_last = drop_last
        self.seed = seed
        self.epoch = 0
        self.batch_size = max(1, batch_size // num_replicas)
        self.global_batch = self.batch_size * num_replicas
        try:
            widths = np.asarray(dataset["width"])
            heights = np.asarray(dataset["height"])
        except KeyError:
            widths = np.zeros(len(dataset))
            heights = np.zeros(len(dataset))
        groups = {}
        for i, (w, h) in enumerate(zip(widths, heights)):
            groups.setdefault((w, h), []).append(i)
        # Keep only groups large enough to fill at least one global batch.
        self.groups = [
            np.asarray(indices, dtype=np.int64)
            for indices in groups.values()
            if len(indices) >= self.global_batch
        ]
        # Unshuffled layout; also used directly when shuffle=False.
        self.global_batches = self._build_batches(None)
        self.num_batches = len(self.global_batches)

    def _build_batches(self, rng):
        """Cut every group into global batches, permuting each group first if ``rng`` is given."""
        chunks = []
        for idx in self.groups:
            if rng is not None:
                idx = rng.permutation(idx)
            n = len(idx) // self.global_batch
            chunks.append(idx[: n * self.global_batch].reshape(n, self.global_batch))
        if chunks:
            return np.concatenate(chunks, axis=0)
        return np.empty((0, self.global_batch), dtype=np.int64)

    def __iter__(self):
        if self.shuffle:
            rng = np.random.RandomState(self.seed + self.epoch)
            batches = self._build_batches(rng)
            order = rng.permutation(self.num_batches)
        else:
            batches = self.global_batches
            order = np.arange(self.num_batches)
        start = self.rank * self.batch_size
        end = start + self.batch_size
        for i in order:
            yield batches[i][start:end]

    def __len__(self):
        # The per-group batch count does not depend on the permutation.
        return self.num_batches

    def set_epoch(self, epoch):
        self.epoch = epoch
def get_fixed_samples_by_resolution(dataset, samples_per_group=1):
    """Pick fixed preview samples from ``dataset``, grouped by (width, height).

    Up to ``samples_per_group`` random rows are taken from each resolution bucket.
    When the dataset has a single bucket, up to ``samples_to_generate`` rows are
    taken instead. Latents are moved to ``device``/``dtype``. Conditioning comes
    from the precomputed columns or from ``encode_texts``.

    Args:
        dataset: indexable dataset with "vae" and "text" columns, plus optional
            "width"/"height", and "embeddings"/"attention_mask" when
            ``use_precomputed_embeddings`` is set.
        samples_per_group: samples per bucket when there are several buckets.

    Returns:
        dict mapping (width, height) -> (latents, embeddings, masks, texts).
    """
    size_groups = defaultdict(list)
    try:
        widths = dataset["width"]
        heights = dataset["height"]
    except KeyError:
        # No resolution columns: the whole dataset becomes one (0, 0) bucket.
        widths = [0] * len(dataset)
        heights = [0] * len(dataset)
    for i, size in enumerate(zip(widths, heights)):
        size_groups[size].append(i)

    single_group = len(size_groups) == 1
    fixed_samples = {}
    for size, indices in size_groups.items():
        wanted = samples_to_generate if single_group else samples_per_group
        # Clamp to the bucket size: random.sample raises ValueError when asked for
        # more items than exist (e.g. a small single-resolution dataset).
        n_samples = min(wanted, len(indices))
        if n_samples <= 0:
            continue
        sample_indices = random.sample(indices, n_samples)
        samples_data = [dataset[idx] for idx in sample_indices]
        latents = torch.tensor(np.array([item["vae"] for item in samples_data])).to(device=device, dtype=dtype)
        # Normalize to 5-D with a singleton axis at dim 2.
        if latents.ndim == 4:
            latents = latents.unsqueeze(2)
        elif latents.ndim == 6:
            latents = latents.squeeze(2)
        texts = [item["text"] for item in samples_data]
        if use_precomputed_embeddings:
            embeddings = torch.tensor(
                np.array([item["embeddings"] for item in samples_data]),
                device=device,
                dtype=dtype,
            )
            masks = torch.tensor(
                np.array([item["attention_mask"] for item in samples_data]),
                device=device,
                dtype=torch.int64,
            )
        else:
            embeddings, masks = encode_texts(texts, max_length)
        fixed_samples[size] = (latents, embeddings, masks, texts)
    print(f"Создано {len(fixed_samples)} групп фиксированных семплов по разрешениям")
    return fixed_samples
# Load the preprocessed dataset; `limit > 0` keeps only its first rows.
dataset = load_from_disk(ds_path)
if limit > 0:
    # Clamp so a limit larger than the dataset does not make .select() raise.
    dataset = dataset.select(range(min(limit, len(dataset))))
print(f"images: {len(dataset)}")
def _clean_caption(raw):
    """Normalize one raw caption before encoding (CFG dropout is applied here).

    In order:
    - a missing (None/empty) caption is treated as "";
    - captions starting with "zero" (case-insensitive) become "";
    - with probability ``cfg_dropout`` the caption becomes "";
    - a leading "." is stripped together with the following whitespace, and
      nothing else is changed;
    - otherwise boilerplate phrases are removed wherever they occur, and the
      result is stripped.
    """
    text = raw or ""  # None captions used to crash on .lower()
    if text.lower().startswith("zero"):
        return ""
    if random.random() < cfg_dropout:
        return ""
    if text.startswith("."):
        return text[1:].lstrip()
    for phrase in ("The image shows ", "The image is ", "This image captures "):
        text = text.replace(phrase, "")
    return text.strip()


def collate_fn_simple(batch):
    """Collate dataset rows into ``(latents, embeddings, attention_mask)`` on ``device``.

    Latents are normalized to 5-D (a singleton axis is added or removed at dim 2).
    Conditioning comes from the precomputed "embeddings"/"attention_mask" columns
    when ``use_precomputed_embeddings`` is set. Otherwise captions are cleaned by
    ``_clean_caption`` and encoded with ``encode_texts``.

    NOTE(review): latents/embeddings pass through float16 on the host; if the
    dataset stores float32 this loses precision — confirm that this is intended.
    NOTE(review): the precomputed path applies no CFG caption dropout.
    """
    latents = torch.from_numpy(
        np.array([item["vae"] for item in batch], dtype=np.float16)
    ).to(device, dtype=dtype)
    if latents.ndim == 4:
        latents = latents.unsqueeze(2)
    elif latents.ndim == 6:
        latents = latents.squeeze(2)
    if use_precomputed_embeddings:
        embeddings = torch.from_numpy(
            np.array([item["embeddings"] for item in batch], dtype=np.float16)
        ).to(device, dtype=dtype)
        attention_mask = torch.from_numpy(
            np.array([item["attention_mask"] for item in batch], dtype=np.int64)
        ).to(device)
        return latents, embeddings, attention_mask
    texts = [_clean_caption(item["text"]) for item in batch]
    embeddings, attention_mask = encode_texts(texts, max_length)
    return latents, embeddings, attention_mask.to(dtype=torch.int64)
# Each rank receives its own slice of every resolution-homogeneous global batch.
batch_sampler = DistributedResolutionBatchSampler(
    dataset=dataset,
    batch_size=batch_size,
    num_replicas=accelerator.num_processes,
    rank=accelerator.process_index,
    shuffle = shuffle
)
# No num_workers is passed (DataLoader default 0), so collate_fn_simple, which
# runs the text encoder on `device`, executes in the main process.
dataloader = DataLoader(dataset, batch_sampler=batch_sampler, collate_fn=collate_fn_simple)
if accelerator.is_main_process:
    # len(dataloader) here is the number of batches, not the number of samples.
    print("Total samples", len(dataloader))
# NOTE(review): batch_sampler already shards per rank. In multi-process runs,
# Accelerate's prepare() typically wraps a DataLoader's batch_sampler in
# BatchSamplerShard, which would shard a second time (each rank keeping only every
# num_processes-th batch). Each epoch would then cover ~1/num_processes of the data,
# which may be why num_epochs defaults to num_gpus. Confirm against the installed
# accelerate version.
dataloader = accelerator.prepare(dataloader)
start_epoch = 0
global_step = 0
# Steps per process over the whole run (len(dataloader) is taken after prepare()).
total_training_steps = len(dataloader) * num_epochs
world_size = accelerator.state.num_processes

# Fine-tuning starts from an existing transformer in <checkpoints_folder>/<project>.
latest_checkpoint = os.path.join(checkpoints_folder, project)
if not os.path.isdir(latest_checkpoint):
    raise FileNotFoundError(f"Transformer checkpoint not found at {latest_checkpoint}")
print("Загружаем Transformer из чекпоинта:", latest_checkpoint)
transformer = CosmosTransformer3DModel.from_pretrained(latest_checkpoint).to(device=device, dtype=dtype)
if transformer_gradient:
    transformer.enable_gradient_checkpointing()
def create_optimizer(name, params):
    """Build the optimizer called ``name`` over ``params``.

    Every variant starts at ``base_learning_rate``. The Adam-style variants share
    betas ``(0.9, betta2)`` and ``eps``.

    Args:
        name: "adam8bit", "adam", "adafactor" or "muon_adam8bit".
        params: iterable of parameters (or parameter groups).

    Raises:
        ValueError: if ``name`` is not a known optimizer.
    """
    def adam_style(optimizer_cls):
        return optimizer_cls(
            params, lr=base_learning_rate, betas=(0.9, betta2), eps=eps, weight_decay=0.001
        )

    # Builders are lazy so that only the selected optimizer is constructed.
    builders = {
        "adam8bit": lambda: adam_style(bnb.optim.AdamW8bit),
        "adam": lambda: adam_style(torch.optim.AdamW),
        "adafactor": lambda: Adafactor(
            params,
            lr=base_learning_rate,
            eps=(1e-30, 1e-3),
            clip_threshold=1.0,
            decay_rate=-0.8,
            beta1=None,
            weight_decay=0.001,
            relative_step=False,
            scale_parameter=False,
            warmup_init=False,
        ),
        "muon_adam8bit": lambda: MuonAdamW8bit(
            params,
            lr=base_learning_rate,
            betas=(0.9, betta2),
            eps=eps,
            weight_decay=0.01,
            muon_lr_mult=muon_lr_scale,
        ),
    }
    builder = builders.get(name)
    if builder is None:
        raise ValueError(f"Unknown optimizer: {name}")
    return builder()
if fbp:
    # Per-parameter optimizers stepped from a post-accumulate-grad hook; each
    # gradient is cleared (set_to_none) right after its optimizer step.
    # NOTE(review): this path trains *all* parameters, unlike the attn2-only branch below.
    trainable_params = list(transformer.parameters())
    optimizer_dict = {p: create_optimizer(optimizer_type, [p]) for p in trainable_params}

    def optimizer_hook(param):
        optimizer_dict[param].step()
        optimizer_dict[param].zero_grad(set_to_none=True)

    for param in trainable_params:
        param.register_post_accumulate_grad_hook(optimizer_hook)
    transformer, optimizer = accelerator.prepare(transformer, optimizer_dict)
else:
    # 1. Freeze every parameter first.
    transformer.requires_grad_(False)
    # 2. Unfreeze only parameters whose name contains one of these keys (cross-attention).
    trainable_params_names = ["attn2"]
    trainable_params = []
    print("--- РАЗМОРОЖЕННЫЕ СЛОИ ---")
    for name, param in transformer.named_parameters():
        if any(target in name for target in trainable_params_names):
            param.requires_grad_(True)
            trainable_params.append(param)
            print(f"Обучаемый слой: {name}")
    print("--------------------------")
    # Fail fast if the key matched nothing.
    if not trainable_params:
        raise ValueError("Ошибка: ни один слой не был разморожен! Проверь ключи.")
    # Give the optimizer only the unfrozen parameters (the list built above).
    # Passing transformer.parameters() made it track frozen weights too.
    optimizer = create_optimizer(optimizer_type, trainable_params)
def lr_schedule(step):
    """Return the absolute learning rate for scheduler step ``step``.

    The LR warms up linearly from ``min_learning_rate`` to ``base_learning_rate``
    over the first ``warmup_percent`` of training. It then follows a cosine decay
    back to ``min_learning_rate``. Progress is measured against
    ``total_training_steps * world_size``, presumably because the
    Accelerate-prepared scheduler advances once per process per optimizer step.
    Progress is clamped to 1, so running past the planned steps holds the LR at
    the minimum instead of letting the cosine rise again.
    When ``use_decay`` is False the LR is constant at ``base_learning_rate``.
    """
    if not use_decay:
        return base_learning_rate
    total = max(1, total_training_steps * world_size)
    x = min(1.0, step / total)
    warmup = warmup_percent
    if x < warmup:
        return min_learning_rate + (base_learning_rate - min_learning_rate) * (x / warmup)
    decay_span = 1 - warmup
    decay_ratio = (x - warmup) / decay_span if decay_span > 0 else 1.0
    return min_learning_rate + 0.5 * (base_learning_rate - min_learning_rate) * \
        (1 + math.cos(math.pi * decay_ratio))


# LambdaLR expects a multiplier of the optimizer's initial LR (base_learning_rate).
# With fbp, `optimizer` is a dict of per-parameter optimizers, which LambdaLR
# rejects with a TypeError; the training loop then uses a constant LR instead.
lr_scheduler = None if fbp else LambdaLR(optimizer, lambda step: lr_schedule(step) / base_learning_rate)
if torch_compile:
    print("Compiling Transformer... Это займет несколько минут, не прерывайте!")
    transformer = torch.compile(transformer)
    print("Compiling - ok")

# In fbp mode the model and per-parameter optimizers were already prepared.
if not fbp:
    prepared = accelerator.prepare(transformer, optimizer, lr_scheduler)
    transformer, optimizer, lr_scheduler = prepared

# Fixed per-resolution preview samples, reused for every sampling round.
fixed_samples = get_fixed_samples_by_resolution(dataset)
def get_negative_embedding(neg_prompt="", batch_size=1):
    """Return ``(embeddings, mask)`` for the negative prompt, repeated ``batch_size`` times.

    A non-empty prompt is encoded with ``encode_texts``. An empty prompt skips the
    text encoder and yields an all-zero embedding of shape
    ``(batch_size, max_length, 2048)`` with an all-ones int64 mask.
    """
    if neg_prompt:
        emb, mask = encode_texts([neg_prompt], max_length)
        return (
            emb.to(dtype=dtype, device=device).repeat(batch_size, 1, 1),
            mask.to(device=device).repeat(batch_size, 1),
        )
    # NOTE(review): 2048 is hard-coded; presumably the text encoder hidden size — confirm.
    shape = (batch_size, max_length)
    zeros = torch.zeros((*shape, 2048), dtype=dtype, device=device)
    ones = torch.ones(shape, dtype=torch.int64, device=device)
    return zeros, ones
if use_precomputed_embeddings:
    # Encode the negative prompt once, keep it on CPU, then release the text
    # encoder: with precomputed embeddings the batches no longer need it.
    load_text_encoder()
    uncond_emb, uncond_mask = get_negative_embedding("low quality")
    uncond_emb = uncond_emb.to("cpu")
    uncond_mask = uncond_mask.to("cpu")
    # Drop the last reference, collect garbage, then release cached CUDA blocks.
    # Emptying the cache before gc.collect() can leave memory reserved if the
    # encoder is still held by a reference cycle.
    text_encoder = None
    gc.collect()
    torch.cuda.empty_cache()
else:
    uncond_emb, uncond_mask = get_negative_embedding("low quality")
def pad_to_match(a, b, pad_value=0):
    """Right-pad the shorter of ``a`` / ``b`` along dim 1 so both have equal length.

    Works for any tensor rank >= 2, e.g. ``(B, T, D)`` embeddings or ``(B, T)``
    masks; all other dimensions are left untouched. If the lengths already match,
    the inputs are returned unchanged (the same objects).

    Args:
        a, b: tensors whose dim 1 is the sequence axis.
        pad_value: fill value for the padded positions.

    Returns:
        ``(a_padded, b_padded)``.
    """
    len_a, len_b = a.shape[1], b.shape[1]
    if len_a == len_b:
        return a, b
    target = max(len_a, len_b)

    def _pad(x):
        missing = target - x.shape[1]
        if missing <= 0:
            return x
        # F.pad takes (left, right) pairs starting from the LAST dim: skip every
        # dim after dim 1, then pad dim 1 on the right. A fixed (0, 0, 0, n)
        # tuple only targets dim 1 for 3-D inputs.
        trailing = (0, 0) * (x.ndim - 2)
        return torch.nn.functional.pad(x, trailing + (0, missing), value=pad_value)

    return _pad(a), _pad(b)
@torch.compiler.disable()
@torch.no_grad()
def generate_and_save_samples(fixed_samples_cpu, uncond_data, step):
    """Generate preview images for every fixed sample group and save/log them.

    For each (width, height) group, the function:
    1. starts from seeded Gaussian noise scaled by the scheduler's sigma_max;
    2. runs ``n_diffusion_steps`` scheduler steps with classifier-free guidance
       against the negative embedding (when ``guidance_scale != 1``);
    3. de-normalizes the latents with the VAE mean/std and decodes them;
    4. writes each image to ``generated_folder`` and logs the padded images to
       wandb/comet when enabled.

    At ``step == 0`` the dataset latents of the group are decoded instead of the
    generated ones (the diffusion loop still runs).

    Args:
        fixed_samples_cpu: dict (width, height) -> (latents, embeddings, mask, texts),
            as built by get_fixed_samples_by_resolution.
        uncond_data: (uncond_emb, uncond_mask) for the negative prompt.
        step: global step, used for logging/filenames and the step-0 behavior above.

    The VAE is moved to ``device`` for the call and moved to CPU in ``finally``.
    """
    uncond_emb, uncond_mask = uncond_data
    uncond_emb = uncond_emb.to(device)
    uncond_mask = uncond_mask.to(device)
    original_model = None
    try:
        if not torch_compile:
            original_model = accelerator.unwrap_model(transformer, keep_torch_compile=True).eval()
        else:
            original_model = transformer.eval()
        vae.to(device=device).eval()
        all_generated_images = []
        all_captions = []
        # NOTE(review): sample_mask is unpacked but never passed to the model.
        for size, (sample_latents, sample_text_embeddings, sample_mask, sample_text) in fixed_samples_cpu.items():
            width, height = size
            curr_batch_size = sample_latents.shape[0]
            in_channels = original_model.config.in_channels
            sample_text_embeddings = sample_text_embeddings.to(dtype=dtype, device=device)
            # Sigma schedule set up the same way for every group; appears to mirror the
            # diffusers Cosmos text-to-image pipeline — confirm when changing the scheduler.
            sigmas_dtype = torch.float32
            sigmas = torch.linspace(0, 1, n_diffusion_steps, dtype=sigmas_dtype)
            scheduler.set_timesteps(sigmas=sigmas, device=device)
            if scheduler.config.get("final_sigmas_type", "zero") == "sigma_min":
                scheduler.sigmas[-1] = scheduler.sigmas[-2]
            # Avoid a zero final sigma (it is a divisor below).
            if scheduler.sigmas[-1] == 0.0:
                scheduler.sigmas[-1] = 1e-4
            sigma_max = getattr(scheduler.config, "sigma_max", 80.0)
            # Same seed for every call, so previews are comparable across steps.
            latents = torch.randn(
                (curr_batch_size, in_channels, 1, sample_latents.shape[3], sample_latents.shape[4]),
                device=device,
                dtype=dtype,
                generator=torch.Generator(device=device).manual_seed(seed)
            ) * sigma_max
            padding_mask = torch.zeros((1, 1, sample_latents.shape[3], sample_latents.shape[4]), device=device, dtype=dtype)
            if guidance_scale != 1:
                neg_emb_batch = uncond_emb[0:1].expand(curr_batch_size, -1, -1)
                # Conditional and negative embeddings must share a sequence length.
                neg_emb_batch, sample_text_embeddings = pad_to_match(neg_emb_batch, sample_text_embeddings)
            for i, t in enumerate(scheduler.timesteps):
                current_sigma = scheduler.sigmas[i]
                if current_sigma == 0.0:
                    current_sigma = torch.tensor(1e-4, dtype=current_sigma.dtype, device=device)
                # Map sigma to t in (0, 1) and build the c_in / c_skip / c_out
                # preconditioning coefficients from it.
                current_t = current_sigma / (current_sigma + 1.0)
                c_in = 1.0 - current_t
                c_skip = 1.0 - current_t
                c_out = -current_t
                latent_model_input = (latents * c_in).to(dtype)
                t_val = float(current_t.item()) if torch.is_tensor(current_t) else float(current_t)
                timestep_tensor = torch.tensor([t_val], device=device, dtype=dtype).expand(curr_batch_size)
                noise_pred = original_model(
                    hidden_states=latent_model_input,
                    timestep=timestep_tensor,
                    encoder_hidden_states=sample_text_embeddings,
                    padding_mask=padding_mask,
                    return_dict=False
                )[0]
                noise_pred = (c_skip * latents + c_out * noise_pred.float()).to(dtype)
                if guidance_scale != 1:
                    noise_pred_uncond = original_model(
                        hidden_states=latent_model_input,
                        timestep=timestep_tensor,
                        encoder_hidden_states=neg_emb_batch,
                        padding_mask=padding_mask,
                        return_dict=False
                    )[0]
                    noise_pred_uncond = (c_skip * latents + c_out * noise_pred_uncond.float()).to(dtype)
                    # Classifier-free guidance, applied to the preconditioned predictions.
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred - noise_pred_uncond)
                # Convert the predicted clean latents to the derivative passed to scheduler.step.
                noise_pred = (latents - noise_pred) / current_sigma
                latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
            current_latents = latents
            if step == 0:
                current_latents = sample_latents
            if latents_mean is not None and latents_std is not None:
                sigma_data = getattr(scheduler.config, "sigma_data", 1.0)
                # Move the normalization vectors to float32
                l_mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, torch.float32)
                l_std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, torch.float32)
                # Cast latents to float32 before multiplying, to keep precision
                latents_for_decode = (current_latents.to(torch.float32) * l_std) / sigma_data + l_mean
            else:
                latents_for_decode = current_latents.to(torch.float32)
            # 2. Decode with the MATH SDP backend FORCED on for this step only
            #    (it is disabled globally at startup).
            # NOTE(review): torch.backends.cuda.sdp_kernel is deprecated in recent PyTorch
            # in favor of torch.nn.attention.sdpa_kernel.
            with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
                decoded = vae.decode(latents_for_decode).sample
            # 3. Drop the extra video (frame) dimension
            if decoded.ndim == 5:
                decoded = decoded[:, :, 0, :, :]
            # 4. Already float32, so it can go straight into the loop
            decoded_fp32 = decoded
            for img_idx, img_tensor in enumerate(decoded_fp32):
                # Map [-1, 1] to [0, 1] and CHW to HWC.
                img = (img_tensor / 2 + 0.5).clamp(0, 1).cpu().numpy()
                img = img.transpose(1, 2, 0)
                if np.isnan(img).any():
                    print("NaNs found, saving stopped! Step:", step)
                    img = np.nan_to_num(img, nan=0.0)
                pil_img = Image.fromarray((img * 255).astype("uint8"))
                # NOTE(review): loop-invariant; could be computed once before the loops.
                max_w_overall = max(s[0] for s in fixed_samples_cpu.keys())
                max_h_overall = max(s[1] for s in fixed_samples_cpu.keys())
                max_w_overall = max(255, max_w_overall)
                max_h_overall = max(255, max_h_overall)
                # The padded copy goes to trackers; the unpadded image goes to disk.
                padded_img = ImageOps.pad(pil_img, (max_w_overall, max_h_overall), color='white')
                all_generated_images.append(padded_img)
                caption_text = sample_text[img_idx][:300] if img_idx < len(sample_text) else ""
                all_captions.append(caption_text)
                sample_path = f"{generated_folder}/{project}_{width}x{height}_{img_idx}.jpg"
                pil_img.save(sample_path, "JPEG", quality=95)
        if use_wandb and accelerator.is_main_process:
            wandb_images = [
                wandb.Image(img, caption=f"{all_captions[i]}")
                for i, img in enumerate(all_generated_images)
            ]
            wandb.log({"generated_images": wandb_images})
        if use_comet_ml and accelerator.is_main_process:
            for i, img in enumerate(all_generated_images):
                comet_experiment.log_image(
                    image_data=img,
                    name=f"step_{step}_img_{i}",
                    step=step,
                    metadata={"caption": all_captions[i]}
                )
    finally:
        vae.to("cpu")
        # NOTE(review): these only rebind locals; the caller's tensors are unaffected.
        uncond_emb = uncond_emb.to("cpu")
        uncond_mask = uncond_mask.to("cpu")
        # Best-effort release of large locals before emptying the CUDA cache; a name
        # that was never bound (early failure, or guidance_scale == 1) raises
        # UnboundLocalError, which stops the remaining deletes.
        try:
            all_generated_images.clear()
            all_captions.clear()
            del all_generated_images, all_captions
            del latents, current_latents, latent_model_input
            del decoded, decoded_fp32
            del sample_latents, sample_text_embeddings, sample_mask
            del noise_pred, noise_pred_uncond
        except UnboundLocalError:
            pass
        torch.cuda.synchronize()
        torch.cuda.empty_cache()
        gc.collect()
# Baseline previews before training (at step 0 the dataset latents are decoded).
should_preview = accelerator.is_main_process and save_model
if should_preview:
    print("Генерация сэмплов до старта обучения...")
    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), 0)
accelerator.wait_for_everyone()
def save_checkpoint(model_net, variant=""):
    """Save the transformer to ``<checkpoints_folder>/<project>`` (main process only).

    Args:
        model_net: the prepared transformer, possibly wrapped by DDP and/or
            torch.compile.
        variant: if non-empty, weights are saved in bfloat16 under this diffusers
            variant name; otherwise they are saved in the model's current dtype.

    All processes then synchronize CUDA and release cached memory.
    """
    if accelerator.is_main_process:
        # Strip DDP and torch.compile wrappers so `save_pretrained` is always reachable,
        # including a compiled model that prepare() additionally wrapped in DDP.
        model_to_save = accelerator.unwrap_model(model_net, keep_torch_compile=False)
        save_dir = os.path.join(checkpoints_folder, project)
        if variant != "":
            # `.to()` mutates the module in place. Restore the original dtype afterwards
            # so a caller that keeps training does not silently continue in bf16.
            original_dtype = model_to_save.dtype
            try:
                model_to_save.to(dtype=torch.bfloat16).save_pretrained(save_dir, variant=variant)
            finally:
                model_to_save.to(dtype=original_dtype)
        else:
            model_to_save.save_pretrained(save_dir)
    torch.cuda.synchronize()
    torch.cuda.empty_cache()
    gc.collect()
# --------------------------- Training loop ---------------------------
if accelerator.is_main_process:
    print(f"Total steps per GPU: {total_training_steps}")
progress_bar = tqdm(total=total_training_steps, disable=not accelerator.is_local_main_process, desc="Training", unit="step")
steps_per_epoch = len(dataloader)
# Window (in steps) of recent losses used for the save decision and the barrier cadence.
sink_interval = max(1, steps_per_epoch // sink_interval_share)
min_loss = 4.
last_sample_time = time.time()
sample_interval_seconds = sample_interval_min * 60
for epoch in range(start_epoch, start_epoch + num_epochs):
    batch_losses = []
    batch_grads = []
    # Reseed the resolution sampler so every rank gets the same batch order.
    batch_sampler.set_epoch(epoch)
    accelerator.wait_for_everyone()
    transformer.train()
    for step, (latents, embeddings, attention_mask) in enumerate(dataloader):
        if not save_model and epoch == 0 and step == 5:
            used_gb = torch.cuda.max_memory_allocated() / 1024**3
            print(f"Шаг {step}: {used_gb:.2f} GB")
        # All ranks share global_step, so the barriers guarded by this flag stay matched.
        is_sync_point = (global_step % 100 == 0) or (global_step % sink_interval == 0)
        amp_context = accelerator.autocast() if torch_compile else nullcontext()
        with accelerator.accumulate(transformer):
            with amp_context:
                noise = torch.randn_like(latents, dtype=latents.dtype)
                # Noise level t in (0, 1), logit-normal and shifted by sigmoid_bias.
                t = torch.sigmoid(torch.randn(latents.shape[0], device=latents.device, dtype=latents.dtype) + sigmoid_bias)
                t_5d = t.view(-1, 1, 1, 1, 1)
                # Linear interpolation between data and noise; the model regresses noise - x0.
                noisy_latents_5d = (1.0 - t_5d) * latents + t_5d * noise
                target_5d = noise - latents
                padding_mask = torch.zeros((1, 1, latents.shape[3], latents.shape[4]), device=device, dtype=dtype)
                timestep_tensor = t.flatten().to(dtype)
                model_pred = transformer(
                    hidden_states=noisy_latents_5d,
                    timestep=timestep_tensor,
                    encoder_hidden_states=embeddings,
                    padding_mask=padding_mask,
                    return_dict=False
                )[0]
                mse_loss = F.mse_loss(model_pred.float(), target_5d.float())
            # One host sync for the loss value, reused for logging below.
            loss_value = mse_loss.detach().item()
            batch_losses.append(loss_value)
            if is_sync_point:
                accelerator.wait_for_everyone()
            accelerator.backward(mse_loss)
            if is_sync_point:
                accelerator.wait_for_everyone()
            grad = 0.0
            if not fbp:
                if accelerator.sync_gradients:
                    grad_val = accelerator.clip_grad_norm_(transformer.parameters(), clip_grad_norm)
                    grad = grad_val.float().item() if torch.is_tensor(grad_val) else float(grad_val)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad(set_to_none=True)
        if accelerator.sync_gradients:
            global_step += 1
            progress_bar.update(1)
        if accelerator.is_main_process:
            # fbp mode has no LR scheduler, so its LR is the constant base rate.
            current_lr = base_learning_rate if fbp else lr_scheduler.get_last_lr()[0]
            batch_grads.append(grad)
            log_data = {"loss_mse": loss_value, "lr": current_lr, "grad": grad}
            if accelerator.sync_gradients:
                if use_wandb:
                    wandb.log(log_data, step=global_step)
                if use_comet_ml:
                    comet_experiment.log_metrics(log_data, step=global_step)
            current_time = time.time()
            is_time_to_sample = (current_time - last_sample_time) >= sample_interval_seconds
            if is_time_to_sample or global_step == 50:
                # Dry runs (save_model False) only sample on every 10th epoch.
                if save_model or epoch % 10 == 0:
                    generate_and_save_samples(fixed_samples, (uncond_emb, uncond_mask), global_step)
                if save_model:
                    # Save only if the recent loss is within save_barrier of the best seen.
                    has_losses = bool(batch_losses)
                    avg_sample_loss = np.mean(batch_losses[-sink_interval:]) if has_losses else 0.0
                    last_loss = batch_losses[-1] if has_losses else 0.0
                    max_loss = max(avg_sample_loss, last_loss)
                    should_save = max_loss < min_loss * save_barrier
                    print(
                        f"Saving: {should_save} | Max: {max_loss:.4f} | "
                        f"Last: {last_loss:.4f} | Avg: {avg_sample_loss:.4f}"
                    )
                    if should_save:
                        min_loss = max_loss
                        save_checkpoint(transformer)
                last_sample_time = current_time
                # Sampling switched the model to eval mode.
                transformer.train()
    if accelerator.is_main_process:
        avg_epoch_loss = np.mean(batch_losses) if len(batch_losses) > 0 else 0.0
        avg_epoch_grad = np.mean(batch_grads) if len(batch_grads) > 0 else 0.0
        print(f"\nЭпоха {epoch} завершена. Средний лосс: {avg_epoch_loss:.6f}")
        log_data_ep = {
            "epoch_loss": avg_epoch_loss,
            "epoch_grad": avg_epoch_grad,
            "epoch": epoch + 1,
        }
        if use_wandb:
            wandb.log(log_data_ep)
        if use_comet_ml:
            comet_experiment.log_metrics(log_data_ep)
progress_bar.close()
# Make sure every rank has finished its last step before the final save.
accelerator.wait_for_everyone()
if accelerator.is_main_process:
    print("Обучение завершено! Сохраняем финальную модель...")
    save_checkpoint(transformer, "bf16")
    # Trackers exist only on the main process (created under is_main_process), so
    # they are closed here; on other ranks comet_experiment is undefined.
    if use_comet_ml:
        comet_experiment.end()
    if use_wandb:
        wandb.finish()
accelerator.free_memory()
if torch.distributed.is_initialized():
    torch.distributed.destroy_process_group()
print("Готово!")