# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0
| import datetime | |
| import getpass | |
| import hashlib | |
| import json | |
| import os | |
| import os.path as osp | |
| import random | |
| import time | |
| import types | |
| import warnings | |
| from dataclasses import asdict | |
| from pathlib import Path | |
| import numpy as np | |
| import pyrallis | |
| import torch | |
| from accelerate import Accelerator, InitProcessGroupKwargs | |
| from accelerate.utils import DistributedType | |
| from PIL import Image | |
| from termcolor import colored | |
| warnings.filterwarnings("ignore") # ignore warning | |
| from diffusion import DPMS, FlowEuler, Scheduler | |
| from diffusion.data.builder import build_dataloader, build_dataset | |
| from diffusion.data.wids import DistributedRangedSampler | |
| from diffusion.model.builder import build_model, get_tokenizer_and_text_encoder, get_vae, vae_decode, vae_encode | |
| from diffusion.model.respace import compute_density_for_timestep_sampling | |
| from diffusion.utils.checkpoint import load_checkpoint, save_checkpoint | |
| from diffusion.utils.config import SanaConfig | |
| from diffusion.utils.data_sampler import AspectRatioBatchSampler | |
| from diffusion.utils.dist_utils import clip_grad_norm_, flush, get_world_size | |
| from diffusion.utils.logger import LogBuffer, get_root_logger | |
| from diffusion.utils.lr_scheduler import build_lr_scheduler | |
| from diffusion.utils.misc import DebugUnderflowOverflow, init_random_seed, read_config, set_random_seed | |
| from diffusion.utils.optimizer import auto_scale_lr, build_optimizer | |
| os.environ["TOKENIZERS_PARALLELISM"] = "false" | |
| def set_fsdp_env(): | |
| os.environ["ACCELERATE_USE_FSDP"] = "true" | |
| os.environ["FSDP_AUTO_WRAP_POLICY"] = "TRANSFORMER_BASED_WRAP" | |
| os.environ["FSDP_BACKWARD_PREFETCH"] = "BACKWARD_PRE" | |
| os.environ["FSDP_TRANSFORMER_CLS_TO_WRAP"] = "SanaBlock" | |
| def log_validation(accelerator, config, model, logger, step, device, vae=None, init_noise=None): | |
| torch.cuda.empty_cache() | |
| vis_sampler = config.scheduler.vis_sampler | |
| model = accelerator.unwrap_model(model).eval() | |
| hw = torch.tensor([[image_size, image_size]], dtype=torch.float, device=device).repeat(1, 1) | |
| ar = torch.tensor([[1.0]], device=device).repeat(1, 1) | |
| null_y = torch.load(null_embed_path, map_location="cpu") | |
| null_y = null_y["uncond_prompt_embeds"].to(device) | |
| # Create sampling noise: | |
| logger.info("Running validation... ") | |
| image_logs = [] | |
| def run_sampling(init_z=None, label_suffix="", vae=None, sampler="dpm-solver"): | |
| latents = [] | |
| current_image_logs = [] | |
| for prompt in validation_prompts: | |
| z = ( | |
| torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device=device) | |
| if init_z is None | |
| else init_z | |
| ) | |
| embed = torch.load( | |
| osp.join(config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}"), | |
| map_location="cpu", | |
| ) | |
| caption_embs, emb_masks = embed["caption_embeds"].to(device), embed["emb_mask"].to(device) | |
| # caption_embs = caption_embs[:, None] | |
| # emb_masks = emb_masks[:, None] | |
| model_kwargs = dict(data_info={"img_hw": hw, "aspect_ratio": ar}, mask=emb_masks) | |
| if sampler == "dpm-solver": | |
| dpm_solver = DPMS( | |
| model.forward_with_dpmsolver, | |
| condition=caption_embs, | |
| uncondition=null_y, | |
| cfg_scale=4.5, | |
| model_kwargs=model_kwargs, | |
| ) | |
| denoised = dpm_solver.sample( | |
| z, | |
| steps=14, | |
| order=2, | |
| skip_type="time_uniform", | |
| method="multistep", | |
| ) | |
| elif sampler == "flow_euler": | |
| flow_solver = FlowEuler( | |
| model, condition=caption_embs, uncondition=null_y, cfg_scale=4.5, model_kwargs=model_kwargs | |
| ) | |
| denoised = flow_solver.sample(z, steps=28) | |
| elif sampler == "flow_dpm-solver": | |
| dpm_solver = DPMS( | |
| model.forward_with_dpmsolver, | |
| condition=caption_embs, | |
| uncondition=null_y, | |
| cfg_scale=4.5, | |
| model_type="flow", | |
| model_kwargs=model_kwargs, | |
| schedule="FLOW", | |
| ) | |
| denoised = dpm_solver.sample( | |
| z, | |
| steps=20, | |
| order=2, | |
| skip_type="time_uniform_flow", | |
| method="multistep", | |
| flow_shift=config.scheduler.flow_shift, | |
| ) | |
| else: | |
| raise ValueError(f"{sampler} not implemented") | |
| latents.append(denoised) | |
| torch.cuda.empty_cache() | |
| if vae is None: | |
| vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16) | |
| for prompt, latent in zip(validation_prompts, latents): | |
| latent = latent.to(torch.float16) | |
| samples = vae_decode(config.vae.vae_type, vae, latent) | |
| samples = ( | |
| torch.clamp(127.5 * samples + 128.0, 0, 255).permute(0, 2, 3, 1).to("cpu", dtype=torch.uint8).numpy()[0] | |
| ) | |
| image = Image.fromarray(samples) | |
| current_image_logs.append({"validation_prompt": prompt + label_suffix, "images": [image]}) | |
| return current_image_logs | |
| # First run with original noise | |
| image_logs += run_sampling(init_z=None, label_suffix="", vae=vae, sampler=vis_sampler) | |
| # Second run with init_noise if provided | |
| if init_noise is not None: | |
| init_noise = torch.clone(init_noise).to(device) | |
| image_logs += run_sampling(init_z=init_noise, label_suffix=" w/ init noise", vae=vae, sampler=vis_sampler) | |
| formatted_images = [] | |
| for log in image_logs: | |
| images = log["images"] | |
| validation_prompt = log["validation_prompt"] | |
| for image in images: | |
| formatted_images.append((validation_prompt, np.asarray(image))) | |
| for tracker in accelerator.trackers: | |
| if tracker.name == "tensorboard": | |
| for validation_prompt, image in formatted_images: | |
| tracker.writer.add_images(validation_prompt, image[None, ...], step, dataformats="NHWC") | |
| elif tracker.name == "wandb": | |
| import wandb | |
| wandb_images = [] | |
| for validation_prompt, image in formatted_images: | |
| wandb_images.append(wandb.Image(image, caption=validation_prompt, file_type="jpg")) | |
| tracker.log({"validation": wandb_images}) | |
| else: | |
| logger.warn(f"image logging not implemented for {tracker.name}") | |
| def concatenate_images(image_caption, images_per_row=5, image_format="webp"): | |
| import io | |
| images = [log["images"][0] for log in image_caption] | |
| if images[0].size[0] > 1024: | |
| images = [image.resize((1024, 1024)) for image in images] | |
| widths, heights = zip(*(img.size for img in images)) | |
| max_width = max(widths) | |
| total_height = sum(heights[i : i + images_per_row][0] for i in range(0, len(images), images_per_row)) | |
| new_im = Image.new("RGB", (max_width * images_per_row, total_height)) | |
| y_offset = 0 | |
| for i in range(0, len(images), images_per_row): | |
| row_images = images[i : i + images_per_row] | |
| x_offset = 0 | |
| for img in row_images: | |
| new_im.paste(img, (x_offset, y_offset)) | |
| x_offset += max_width | |
| y_offset += heights[i] | |
| webp_image_bytes = io.BytesIO() | |
| new_im.save(webp_image_bytes, format=image_format) | |
| webp_image_bytes.seek(0) | |
| new_im = Image.open(webp_image_bytes) | |
| return new_im | |
| if config.train.local_save_vis: | |
| file_format = "webp" | |
| local_vis_save_path = osp.join(config.work_dir, "log_vis") | |
| os.umask(0o000) | |
| os.makedirs(local_vis_save_path, exist_ok=True) | |
| concatenated_image = concatenate_images(image_logs, images_per_row=5, image_format=file_format) | |
| save_path = ( | |
| osp.join(local_vis_save_path, f"vis_{step}.{file_format}") | |
| if init_noise is None | |
| else osp.join(local_vis_save_path, f"vis_{step}_w_init.{file_format}") | |
| ) | |
| concatenated_image.save(save_path) | |
| del vae | |
| flush() | |
| return image_logs | |
| def train(config, args, accelerator, model, optimizer, lr_scheduler, train_dataloader, train_diffusion, logger): | |
| if getattr(config.train, "debug_nan", False): | |
| DebugUnderflowOverflow(model) | |
| logger.info("NaN debugger registered. Start to detect overflow during training.") | |
| log_buffer = LogBuffer() | |
| global_step = start_step + 1 | |
| skip_step = max(config.train.skip_step, global_step) % train_dataloader_len | |
| skip_step = skip_step if skip_step < (train_dataloader_len - 20) else 0 | |
| loss_nan_timer = 0 | |
| # Cache Dataset for BatchSampler | |
| if args.caching and config.model.multi_scale: | |
| caching_start = time.time() | |
| logger.info( | |
| f"Start caching your dataset for batch_sampler at {cache_file}. \n" | |
| f"This may take a lot of time...No training will launch" | |
| ) | |
| train_dataloader.batch_sampler.sampler.set_start(max(train_dataloader.batch_sampler.exist_ids, 0)) | |
| accelerator.wait_for_everyone() | |
| for index, _ in enumerate(train_dataloader): | |
| accelerator.wait_for_everyone() | |
| if index % 2000 == 0: | |
| logger.info( | |
| f"rank: {rank}, Cached file len: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}" | |
| ) | |
| print( | |
| f"rank: {rank}, Cached file len: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}" | |
| ) | |
| if (time.time() - caching_start) / 3600 > 3.7: | |
| json.dump(train_dataloader.batch_sampler.cached_idx, open(cache_file, "w"), indent=4) | |
| accelerator.wait_for_everyone() | |
| break | |
| if len(train_dataloader.batch_sampler.cached_idx) == len(train_dataloader) - 1000: | |
| logger.info( | |
| f"Saving rank: {rank}, Cached file len: {len(train_dataloader.batch_sampler.cached_idx)} / {len(train_dataloader)}" | |
| ) | |
| json.dump(train_dataloader.batch_sampler.cached_idx, open(cache_file, "w"), indent=4) | |
| accelerator.wait_for_everyone() | |
| continue | |
| accelerator.wait_for_everyone() | |
| print(f"Saving rank-{rank} Cached file len: {len(train_dataloader.batch_sampler.cached_idx)}") | |
| json.dump(train_dataloader.batch_sampler.cached_idx, open(cache_file, "w"), indent=4) | |
| return | |
| # Now you train the model | |
| for epoch in range(start_epoch + 1, config.train.num_epochs + 1): | |
| time_start, last_tic = time.time(), time.time() | |
| sampler = ( | |
| train_dataloader.batch_sampler.sampler | |
| if (num_replicas > 1 or config.model.multi_scale) | |
| else train_dataloader.sampler | |
| ) | |
| sampler.set_epoch(epoch) | |
| sampler.set_start(max((skip_step - 1) * config.train.train_batch_size, 0)) | |
| if skip_step > 1 and accelerator.is_main_process: | |
| logger.info(f"Skipped Steps: {skip_step}") | |
| skip_step = 1 | |
| data_time_start = time.time() | |
| data_time_all = 0 | |
| lm_time_all = 0 | |
| vae_time_all = 0 | |
| model_time_all = 0 | |
| for step, batch in enumerate(train_dataloader): | |
| # image, json_info, key = batch | |
| accelerator.wait_for_everyone() | |
| data_time_all += time.time() - data_time_start | |
| vae_time_start = time.time() | |
| if load_vae_feat: | |
| z = batch[0].to(accelerator.device) | |
| else: | |
| with torch.no_grad(): | |
| with torch.amp.autocast( | |
| "cuda", | |
| enabled=(config.model.mixed_precision == "fp16" or config.model.mixed_precision == "bf16"), | |
| ): | |
| z = vae_encode( | |
| config.vae.vae_type, vae, batch[0], config.vae.sample_posterior, accelerator.device | |
| ) | |
| accelerator.wait_for_everyone() | |
| vae_time_all += time.time() - vae_time_start | |
| clean_images = z | |
| data_info = batch[3] | |
| lm_time_start = time.time() | |
| if load_text_feat: | |
| y = batch[1] # bs, 1, N, C | |
| y_mask = batch[2] # bs, 1, 1, N | |
| else: | |
| if "T5" in config.text_encoder.text_encoder_name: | |
| with torch.no_grad(): | |
| txt_tokens = tokenizer( | |
| batch[1], max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" | |
| ).to(accelerator.device) | |
| y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None] | |
| y_mask = txt_tokens.attention_mask[:, None, None] | |
| elif ( | |
| "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name | |
| ): | |
| with torch.no_grad(): | |
| if not config.text_encoder.chi_prompt: | |
| max_length_all = config.text_encoder.model_max_length | |
| prompt = batch[1] | |
| else: | |
| chi_prompt = "\n".join(config.text_encoder.chi_prompt) | |
| prompt = [chi_prompt + i for i in batch[1]] | |
| num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) | |
| max_length_all = ( | |
| num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 | |
| ) # magic number 2: [bos], [_] | |
| txt_tokens = tokenizer( | |
| prompt, | |
| padding="max_length", | |
| max_length=max_length_all, | |
| truncation=True, | |
| return_tensors="pt", | |
| ).to(accelerator.device) | |
| select_index = [0] + list( | |
| range(-config.text_encoder.model_max_length + 1, 0) | |
| ) # first bos and end N-1 | |
| y = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][:, None][ | |
| :, :, select_index | |
| ] | |
| y_mask = txt_tokens.attention_mask[:, None, None][:, :, :, select_index] | |
| else: | |
| print("error") | |
| exit() | |
| # Sample a random timestep for each image | |
| bs = clean_images.shape[0] | |
| timesteps = torch.randint( | |
| 0, config.scheduler.train_sampling_steps, (bs,), device=clean_images.device | |
| ).long() | |
| if config.scheduler.weighting_scheme in ["logit_normal"]: | |
| # adapting from diffusers.training_utils | |
| u = compute_density_for_timestep_sampling( | |
| weighting_scheme=config.scheduler.weighting_scheme, | |
| batch_size=bs, | |
| logit_mean=config.scheduler.logit_mean, | |
| logit_std=config.scheduler.logit_std, | |
| mode_scale=None, # not used | |
| ) | |
| timesteps = (u * config.scheduler.train_sampling_steps).long().to(clean_images.device) | |
| grad_norm = None | |
| accelerator.wait_for_everyone() | |
| lm_time_all += time.time() - lm_time_start | |
| model_time_start = time.time() | |
| with accelerator.accumulate(model): | |
| # Predict the noise residual | |
| optimizer.zero_grad() | |
| loss_term = train_diffusion.training_losses( | |
| model, clean_images, timesteps, model_kwargs=dict(y=y, mask=y_mask, data_info=data_info) | |
| ) | |
| loss = loss_term["loss"].mean() | |
| accelerator.backward(loss) | |
| if accelerator.sync_gradients: | |
| grad_norm = accelerator.clip_grad_norm_(model.parameters(), config.train.gradient_clip) | |
| optimizer.step() | |
| lr_scheduler.step() | |
| accelerator.wait_for_everyone() | |
| model_time_all += time.time() - model_time_start | |
| if torch.any(torch.isnan(loss)): | |
| loss_nan_timer += 1 | |
| lr = lr_scheduler.get_last_lr()[0] | |
| logs = {args.loss_report_name: accelerator.gather(loss).mean().item()} | |
| if grad_norm is not None: | |
| logs.update(grad_norm=accelerator.gather(grad_norm).mean().item()) | |
| log_buffer.update(logs) | |
| if (step + 1) % config.train.log_interval == 0 or (step + 1) == 1: | |
| accelerator.wait_for_everyone() | |
| t = (time.time() - last_tic) / config.train.log_interval | |
| t_d = data_time_all / config.train.log_interval | |
| t_m = model_time_all / config.train.log_interval | |
| t_lm = lm_time_all / config.train.log_interval | |
| t_vae = vae_time_all / config.train.log_interval | |
| avg_time = (time.time() - time_start) / (step + 1) | |
| eta = str(datetime.timedelta(seconds=int(avg_time * (total_steps - global_step - 1)))) | |
| eta_epoch = str( | |
| datetime.timedelta( | |
| seconds=int( | |
| avg_time | |
| * (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step - 1) | |
| ) | |
| ) | |
| ) | |
| log_buffer.average() | |
| current_step = ( | |
| global_step - sampler.step_start // config.train.train_batch_size | |
| ) % train_dataloader_len | |
| current_step = train_dataloader_len if current_step == 0 else current_step | |
| info = ( | |
| f"Epoch: {epoch} | Global Step: {global_step} | Local Step: {current_step} // {train_dataloader_len}, " | |
| f"total_eta: {eta}, epoch_eta:{eta_epoch}, time: all:{t:.3f}, model:{t_m:.3f}, data:{t_d:.3f}, " | |
| f"lm:{t_lm:.3f}, vae:{t_vae:.3f}, lr:{lr:.3e}, Cap: {batch[5][0]}, " | |
| ) | |
| info += ( | |
| f"s:({model.module.h}, {model.module.w}), " | |
| if hasattr(model, "module") | |
| else f"s:({model.h}, {model.w}), " | |
| ) | |
| info += ", ".join([f"{k}:{v:.4f}" for k, v in log_buffer.output.items()]) | |
| last_tic = time.time() | |
| log_buffer.clear() | |
| data_time_all = 0 | |
| model_time_all = 0 | |
| lm_time_all = 0 | |
| vae_time_all = 0 | |
| if accelerator.is_main_process: | |
| logger.info(info) | |
| logs.update(lr=lr) | |
| if accelerator.is_main_process: | |
| accelerator.log(logs, step=global_step) | |
| global_step += 1 | |
| if loss_nan_timer > 20: | |
| raise ValueError("Loss is NaN too much times. Break here.") | |
| if ( | |
| global_step % config.train.save_model_steps == 0 | |
| or (time.time() - training_start_time) / 3600 > config.train.training_hours | |
| ): | |
| accelerator.wait_for_everyone() | |
| if accelerator.is_main_process: | |
| os.umask(0o000) | |
| ckpt_saved_path = save_checkpoint( | |
| osp.join(config.work_dir, "checkpoints"), | |
| epoch=epoch, | |
| step=global_step, | |
| model=accelerator.unwrap_model(model), | |
| optimizer=optimizer, | |
| lr_scheduler=lr_scheduler, | |
| generator=generator, | |
| add_symlink=True, | |
| ) | |
| if config.train.online_metric and global_step % config.train.eval_metric_step == 0 and step > 1: | |
| online_metric_monitor_dir = osp.join(config.work_dir, config.train.online_metric_dir) | |
| os.makedirs(online_metric_monitor_dir, exist_ok=True) | |
| with open(f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt", "w") as f: | |
| f.write(osp.join(config.work_dir, "config.py") + "\n") | |
| f.write(ckpt_saved_path) | |
| if (time.time() - training_start_time) / 3600 > config.train.training_hours: | |
| logger.info(f"Stopping training at epoch {epoch}, step {global_step} due to time limit.") | |
| return | |
| if config.train.visualize and (global_step % config.train.eval_sampling_steps == 0 or (step + 1) == 1): | |
| accelerator.wait_for_everyone() | |
| if accelerator.is_main_process: | |
| if validation_noise is not None: | |
| log_validation( | |
| accelerator=accelerator, | |
| config=config, | |
| model=model, | |
| logger=logger, | |
| step=global_step, | |
| device=accelerator.device, | |
| vae=vae, | |
| init_noise=validation_noise, | |
| ) | |
| else: | |
| log_validation( | |
| accelerator=accelerator, | |
| config=config, | |
| model=model, | |
| logger=logger, | |
| step=global_step, | |
| device=accelerator.device, | |
| vae=vae, | |
| ) | |
| # avoid dead-lock of multiscale data batch sampler | |
| # for internal, refactor dataloader logic to remove the ad-hoc implementation | |
| if ( | |
| config.model.multi_scale | |
| and (train_dataloader_len - sampler.step_start // config.train.train_batch_size - step) < 30 | |
| ): | |
| global_step = epoch * train_dataloader_len | |
| logger.info("Early stop current iteration") | |
| break | |
| data_time_start = time.time() | |
| if epoch % config.train.save_model_epochs == 0 or epoch == config.train.num_epochs and not config.debug: | |
| accelerator.wait_for_everyone() | |
| if accelerator.is_main_process: | |
| # os.umask(0o000) | |
| ckpt_saved_path = save_checkpoint( | |
| osp.join(config.work_dir, "checkpoints"), | |
| epoch=epoch, | |
| step=global_step, | |
| model=accelerator.unwrap_model(model), | |
| optimizer=optimizer, | |
| lr_scheduler=lr_scheduler, | |
| generator=generator, | |
| add_symlink=True, | |
| ) | |
| online_metric_monitor_dir = osp.join(config.work_dir, config.train.online_metric_dir) | |
| os.makedirs(online_metric_monitor_dir, exist_ok=True) | |
| with open(f"{online_metric_monitor_dir}/{ckpt_saved_path.split('/')[-1]}.txt", "w") as f: | |
| f.write(osp.join(config.work_dir, "config.py") + "\n") | |
| f.write(ckpt_saved_path) | |
| accelerator.wait_for_everyone() | |
| def main(cfg: SanaConfig) -> None: | |
| global train_dataloader_len, start_epoch, start_step, vae, generator, num_replicas, rank, training_start_time | |
| global load_vae_feat, load_text_feat, validation_noise, text_encoder, tokenizer | |
| global max_length, validation_prompts, latent_size, valid_prompt_embed_suffix, null_embed_path | |
| global image_size, cache_file, total_steps | |
| config = cfg | |
| args = cfg | |
| # config = read_config(args.config) | |
| training_start_time = time.time() | |
| load_from = True | |
| if args.resume_from or config.model.resume_from: | |
| load_from = False | |
| config.model.resume_from = dict( | |
| checkpoint=args.resume_from or config.model.resume_from, | |
| load_ema=False, | |
| resume_optimizer=True, | |
| resume_lr_scheduler=True, | |
| ) | |
| if args.debug: | |
| config.train.log_interval = 1 | |
| config.train.train_batch_size = min(64, config.train.train_batch_size) | |
| args.report_to = "tensorboard" | |
| os.umask(0o000) | |
| os.makedirs(config.work_dir, exist_ok=True) | |
| init_handler = InitProcessGroupKwargs() | |
| init_handler.timeout = datetime.timedelta(seconds=5400) # change timeout to avoid a strange NCCL bug | |
| # Initialize accelerator and tensorboard logging | |
| if config.train.use_fsdp: | |
| init_train = "FSDP" | |
| from accelerate import FullyShardedDataParallelPlugin | |
| from torch.distributed.fsdp.fully_sharded_data_parallel import FullStateDictConfig | |
| set_fsdp_env() | |
| fsdp_plugin = FullyShardedDataParallelPlugin( | |
| state_dict_config=FullStateDictConfig(offload_to_cpu=False, rank0_only=False), | |
| ) | |
| else: | |
| init_train = "DDP" | |
| fsdp_plugin = None | |
| accelerator = Accelerator( | |
| mixed_precision=config.model.mixed_precision, | |
| gradient_accumulation_steps=config.train.gradient_accumulation_steps, | |
| log_with=args.report_to, | |
| project_dir=osp.join(config.work_dir, "logs"), | |
| fsdp_plugin=fsdp_plugin, | |
| kwargs_handlers=[init_handler], | |
| ) | |
| log_name = "train_log.log" | |
| logger = get_root_logger(osp.join(config.work_dir, log_name)) | |
| logger.info(accelerator.state) | |
| config.train.seed = init_random_seed(getattr(config.train, "seed", None)) | |
| set_random_seed(config.train.seed + int(os.environ["LOCAL_RANK"])) | |
| generator = torch.Generator(device="cpu").manual_seed(config.train.seed) | |
| if accelerator.is_main_process: | |
| pyrallis.dump(config, open(osp.join(config.work_dir, "config.yaml"), "w"), sort_keys=False, indent=4) | |
| if args.report_to == "wandb": | |
| import wandb | |
| wandb.init(project=args.tracker_project_name, name=args.name, resume="allow", id=args.name) | |
| logger.info(f"Config: \n{config}") | |
| logger.info(f"World_size: {get_world_size()}, seed: {config.train.seed}") | |
| logger.info(f"Initializing: {init_train} for training") | |
| image_size = config.model.image_size | |
| latent_size = int(image_size) // config.vae.vae_downsample_rate | |
| pred_sigma = getattr(config.scheduler, "pred_sigma", True) | |
| learn_sigma = getattr(config.scheduler, "learn_sigma", True) and pred_sigma | |
| max_length = config.text_encoder.model_max_length | |
| vae = None | |
| validation_noise = ( | |
| torch.randn(1, config.vae.vae_latent_dim, latent_size, latent_size, device="cpu", generator=generator) | |
| if getattr(config.train, "deterministic_validation", False) | |
| else None | |
| ) | |
| if not config.data.load_vae_feat: | |
| vae = get_vae(config.vae.vae_type, config.vae.vae_pretrained, accelerator.device).to(torch.float16) | |
| tokenizer = text_encoder = None | |
| if not config.data.load_text_feat: | |
| tokenizer, text_encoder = get_tokenizer_and_text_encoder( | |
| name=config.text_encoder.text_encoder_name, device=accelerator.device | |
| ) | |
| text_embed_dim = text_encoder.config.hidden_size | |
| else: | |
| text_embed_dim = config.text_encoder.caption_channels | |
| logger.info(f"vae type: {config.vae.vae_type}") | |
| if config.text_encoder.chi_prompt: | |
| chi_prompt = "\n".join(config.text_encoder.chi_prompt) | |
| logger.info(f"Complex Human Instruct: {chi_prompt}") | |
| os.makedirs(config.train.null_embed_root, exist_ok=True) | |
| null_embed_path = osp.join( | |
| config.train.null_embed_root, | |
| f"null_embed_diffusers_{config.text_encoder.text_encoder_name}_{max_length}token_{text_embed_dim}.pth", | |
| ) | |
| if config.train.visualize and len(config.train.validation_prompts): | |
| # preparing embeddings for visualization. We put it here for saving GPU memory | |
| valid_prompt_embed_suffix = f"{max_length}token_{config.text_encoder.text_encoder_name}_{text_embed_dim}.pth" | |
| validation_prompts = config.train.validation_prompts | |
| skip = True | |
| if config.text_encoder.chi_prompt: | |
| uuid_chi_prompt = hashlib.sha256(chi_prompt.encode()).hexdigest() | |
| else: | |
| uuid_chi_prompt = hashlib.sha256(b"").hexdigest() | |
| config.train.valid_prompt_embed_root = osp.join(config.train.valid_prompt_embed_root, uuid_chi_prompt) | |
| Path(config.train.valid_prompt_embed_root).mkdir(parents=True, exist_ok=True) | |
| if config.text_encoder.chi_prompt: | |
| # Save complex human instruct to a file | |
| chi_prompt_file = osp.join(config.train.valid_prompt_embed_root, "chi_prompt.txt") | |
| with open(chi_prompt_file, "w", encoding="utf-8") as f: | |
| f.write(chi_prompt) | |
| for prompt in validation_prompts: | |
| prompt_embed_path = osp.join( | |
| config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}" | |
| ) | |
| if not (osp.exists(prompt_embed_path) and osp.exists(null_embed_path)): | |
| skip = False | |
| logger.info("Preparing Visualization prompt embeddings...") | |
| break | |
| if accelerator.is_main_process and not skip: | |
| if config.data.load_text_feat and (tokenizer is None or text_encoder is None): | |
| logger.info(f"Loading text encoder and tokenizer from {config.text_encoder.text_encoder_name} ...") | |
| tokenizer, text_encoder = get_tokenizer_and_text_encoder(name=config.text_encoder.text_encoder_name) | |
| for prompt in validation_prompts: | |
| prompt_embed_path = osp.join( | |
| config.train.valid_prompt_embed_root, f"{prompt[:50]}_{valid_prompt_embed_suffix}" | |
| ) | |
| if "T5" in config.text_encoder.text_encoder_name: | |
| txt_tokens = tokenizer( | |
| prompt, max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" | |
| ).to(accelerator.device) | |
| caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0] | |
| caption_emb_mask = txt_tokens.attention_mask | |
| elif ( | |
| "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name | |
| ): | |
| if not config.text_encoder.chi_prompt: | |
| max_length_all = config.text_encoder.model_max_length | |
| else: | |
| chi_prompt = "\n".join(config.text_encoder.chi_prompt) | |
| prompt = chi_prompt + prompt | |
| num_chi_prompt_tokens = len(tokenizer.encode(chi_prompt)) | |
| max_length_all = ( | |
| num_chi_prompt_tokens + config.text_encoder.model_max_length - 2 | |
| ) # magic number 2: [bos], [_] | |
| txt_tokens = tokenizer( | |
| prompt, | |
| max_length=max_length_all, | |
| padding="max_length", | |
| truncation=True, | |
| return_tensors="pt", | |
| ).to(accelerator.device) | |
| select_index = [0] + list(range(-config.text_encoder.model_max_length + 1, 0)) | |
| caption_emb = text_encoder(txt_tokens.input_ids, attention_mask=txt_tokens.attention_mask)[0][ | |
| :, select_index | |
| ] | |
| caption_emb_mask = txt_tokens.attention_mask[:, select_index] | |
| else: | |
| raise ValueError(f"{config.text_encoder.text_encoder_name} is not supported!!") | |
| torch.save({"caption_embeds": caption_emb, "emb_mask": caption_emb_mask}, prompt_embed_path) | |
| null_tokens = tokenizer( | |
| "", max_length=max_length, padding="max_length", truncation=True, return_tensors="pt" | |
| ).to(accelerator.device) | |
| if "T5" in config.text_encoder.text_encoder_name: | |
| null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0] | |
| elif "gemma" in config.text_encoder.text_encoder_name or "Qwen" in config.text_encoder.text_encoder_name: | |
| null_token_emb = text_encoder(null_tokens.input_ids, attention_mask=null_tokens.attention_mask)[0] | |
| else: | |
| raise ValueError(f"{config.text_encoder.text_encoder_name} is not supported!!") | |
| torch.save( | |
| {"uncond_prompt_embeds": null_token_emb, "uncond_prompt_embeds_mask": null_tokens.attention_mask}, | |
| null_embed_path, | |
| ) | |
| if config.data.load_text_feat: | |
| del tokenizer | |
| del text_encoder | |
| del null_token_emb | |
| del null_tokens | |
| flush() | |
| os.environ["AUTOCAST_LINEAR_ATTN"] = "true" if config.model.autocast_linear_attn else "false" | |
| # 1. build scheduler | |
| train_diffusion = Scheduler( | |
| str(config.scheduler.train_sampling_steps), | |
| noise_schedule=config.scheduler.noise_schedule, | |
| predict_v=config.scheduler.predict_v, | |
| learn_sigma=learn_sigma, | |
| pred_sigma=pred_sigma, | |
| snr=config.train.snr_loss, | |
| flow_shift=config.scheduler.flow_shift, | |
| ) | |
| predict_info = f"v-prediction: {config.scheduler.predict_v}, noise schedule: {config.scheduler.noise_schedule}" | |
| if "flow" in config.scheduler.noise_schedule: | |
| predict_info += f", flow shift: {config.scheduler.flow_shift}" | |
| if config.scheduler.weighting_scheme in ["logit_normal", "mode"]: | |
| predict_info += ( | |
| f", flow weighting: {config.scheduler.weighting_scheme}, " | |
| f"logit-mean: {config.scheduler.logit_mean}, logit-std: {config.scheduler.logit_std}" | |
| ) | |
| logger.info(predict_info) | |
| # 2. build models | |
| model_kwargs = { | |
| "pe_interpolation": config.model.pe_interpolation, | |
| "config": config, | |
| "model_max_length": max_length, | |
| "qk_norm": config.model.qk_norm, | |
| "micro_condition": config.model.micro_condition, | |
| "caption_channels": text_embed_dim, | |
| "y_norm": config.text_encoder.y_norm, | |
| "attn_type": config.model.attn_type, | |
| "ffn_type": config.model.ffn_type, | |
| "mlp_ratio": config.model.mlp_ratio, | |
| "mlp_acts": list(config.model.mlp_acts), | |
| "in_channels": config.vae.vae_latent_dim, | |
| "y_norm_scale_factor": config.text_encoder.y_norm_scale_factor, | |
| "use_pe": config.model.use_pe, | |
| "linear_head_dim": config.model.linear_head_dim, | |
| "pred_sigma": pred_sigma, | |
| "learn_sigma": learn_sigma, | |
| } | |
| model = build_model( | |
| config.model.model, | |
| config.train.grad_checkpointing, | |
| getattr(config.model, "fp32_attention", False), | |
| input_size=latent_size, | |
| **model_kwargs, | |
| ).train() | |
| logger.info( | |
| colored( | |
| f"{model.__class__.__name__}:{config.model.model}, " | |
| f"Model Parameters: {sum(p.numel() for p in model.parameters()) / 1e6:.2f}M", | |
| "green", | |
| attrs=["bold"], | |
| ) | |
| ) | |
    # 2-1. load model
    # A CLI --load_from overrides the config value.
    if args.load_from is not None:
        config.model.load_from = args.load_from
    # `load_from` here is presumably a flag computed earlier gating the initial
    # weight load (distinct from resume handling below) — confirm against the
    # top of main().
    if config.model.load_from is not None and load_from:
        _, missing, unexpected, _ = load_checkpoint(
            config.model.load_from,
            model,
            load_ema=config.model.resume_from.get("load_ema", False),
            null_embed_path=null_embed_path,
        )
        # Non-fatal: surface any state-dict mismatch loudly but keep going.
        logger.warning(f"Missing keys: {missing}")
        logger.warning(f"Unexpected keys: {unexpected}")

    # prepare for FSDP clip grad norm calculation
    # FSDP shards parameters, so the stock clip_grad_norm_ would compute a wrong
    # norm; bind the project's FSDP-aware implementation onto each prepared model.
    if accelerator.distributed_type == DistributedType.FSDP:
        for m in accelerator._models:
            m.clip_grad_norm_ = types.MethodType(clip_grad_norm_, m)
    # 3. build dataloader
    # Normalize data_dir to a list of absolute paths; remote URLs and already
    # rooted/home-relative paths are left untouched.
    config.data.data_dir = config.data.data_dir if isinstance(config.data.data_dir, list) else [config.data.data_dir]
    config.data.data_dir = [
        data if data.startswith(("https://", "http://", "gs://", "/", "~")) else osp.abspath(osp.expanduser(data))
        for data in config.data.data_dir
    ]
    # Rank/world-size come from the launcher's environment (torchrun/accelerate);
    # raises KeyError if run outside a distributed launcher.
    num_replicas = int(os.environ["WORLD_SIZE"])
    rank = int(os.environ["RANK"])
    dataset = build_dataset(
        asdict(config.data),
        resolution=image_size,
        aspect_ratio_type=config.model.aspect_ratio_type,
        real_prompt_ratio=config.train.real_prompt_ratio,
        max_length=max_length,
        config=config,
        caption_proportion=config.data.caption_proportion,
        sort_dataset=config.data.sort_dataset,
        vae_downsample_rate=config.vae.vae_downsample_rate,
    )
| accelerator.wait_for_everyone() | |
| if config.model.multi_scale: | |
| drop_last = True | |
| uuid = hashlib.sha256("-".join(config.data.data_dir).encode()).hexdigest()[:8] | |
| cache_dir = osp.expanduser(f"~/.cache/_wids_batchsampler_cache") | |
| os.makedirs(cache_dir, exist_ok=True) | |
| base_pattern = ( | |
| f"{cache_dir}/{getpass.getuser()}-{uuid}-sort_dataset{config.data.sort_dataset}" | |
| f"-hq_only{config.data.hq_only}-valid_num{config.data.valid_num}" | |
| f"-aspect_ratio{len(dataset.aspect_ratio)}-droplast{drop_last}" | |
| f"dataset_len{len(dataset)}" | |
| ) | |
| cache_file = f"{base_pattern}-num_replicas{num_replicas}-rank{rank}" | |
| for i in config.data.data_dir: | |
| cache_file += f"-{i}" | |
| cache_file += ".json" | |
| sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank) | |
| batch_sampler = AspectRatioBatchSampler( | |
| sampler=sampler, | |
| dataset=dataset, | |
| batch_size=config.train.train_batch_size, | |
| aspect_ratios=dataset.aspect_ratio, | |
| drop_last=drop_last, | |
| ratio_nums=dataset.ratio_nums, | |
| config=config, | |
| valid_num=config.data.valid_num, | |
| hq_only=config.data.hq_only, | |
| cache_file=cache_file, | |
| caching=args.caching, | |
| ) | |
| train_dataloader = build_dataloader(dataset, batch_sampler=batch_sampler, num_workers=config.train.num_workers) | |
| train_dataloader_len = len(train_dataloader) | |
| logger.info(f"rank-{rank} Cached file len: {len(train_dataloader.batch_sampler.cached_idx)}") | |
| else: | |
| sampler = DistributedRangedSampler(dataset, num_replicas=num_replicas, rank=rank) | |
| train_dataloader = build_dataloader( | |
| dataset, | |
| num_workers=config.train.num_workers, | |
| batch_size=config.train.train_batch_size, | |
| shuffle=False, | |
| sampler=sampler, | |
| ) | |
| train_dataloader_len = len(train_dataloader) | |
| load_vae_feat = getattr(train_dataloader.dataset, "load_vae_feat", False) | |
| load_text_feat = getattr(train_dataloader.dataset, "load_text_feat", False) | |
    # 4. build optimizer and lr scheduler
    # Optionally scale the LR with the effective global batch size
    # (per-rank batch * world size * grad-accumulation steps).
    lr_scale_ratio = 1
    if getattr(config.train, "auto_lr", None):
        lr_scale_ratio = auto_scale_lr(
            config.train.train_batch_size * get_world_size() * config.train.gradient_accumulation_steps,
            config.train.optimizer,
            **config.train.auto_lr,
        )
    optimizer = build_optimizer(model, config.train.optimizer)
    # Warmup is configured in optimizer steps but the scheduler here ticks per
    # rank, so multiply by world size to keep wall-clock warmup constant —
    # presumably matching how build_lr_scheduler counts steps; confirm there.
    if config.train.lr_schedule_args and config.train.lr_schedule_args.get("num_warmup_steps", None):
        config.train.lr_schedule_args["num_warmup_steps"] = (
            config.train.lr_schedule_args["num_warmup_steps"] * num_replicas
        )
    lr_scheduler = build_lr_scheduler(config.train, optimizer, train_dataloader, lr_scale_ratio)
    # One-line run summary of the key training hyper-parameters.
    logger.warning(
        f"{colored(f'Basic Setting: ', 'green', attrs=['bold'])}"
        f"lr: {config.train.optimizer['lr']:.5f}, bs: {config.train.train_batch_size}, gc: {config.train.grad_checkpointing}, "
        f"gc_accum_step: {config.train.gradient_accumulation_steps}, qk norm: {config.model.qk_norm}, "
        f"fp32 attn: {config.model.fp32_attention}, attn type: {config.model.attn_type}, ffn type: {config.model.ffn_type}, "
        f"text encoder: {config.text_encoder.text_encoder_name}, captions: {config.data.caption_proportion}, precision: {config.model.mixed_precision}"
    )
| timestamp = time.strftime("%Y-%m-%d_%H:%M:%S", time.localtime()) | |
| if accelerator.is_main_process: | |
| tracker_config = dict(vars(config)) | |
| try: | |
| accelerator.init_trackers(args.tracker_project_name, tracker_config) | |
| except: | |
| accelerator.init_trackers(f"tb_{timestamp}") | |
| start_epoch = 0 | |
| start_step = 0 | |
| total_steps = train_dataloader_len * config.train.num_epochs | |
| # Resume training | |
| if config.model.resume_from is not None and config.model.resume_from["checkpoint"] is not None: | |
| rng_state = None | |
| ckpt_path = osp.join(config.work_dir, "checkpoints") | |
| check_flag = osp.exists(ckpt_path) and len(os.listdir(ckpt_path)) != 0 | |
| if config.model.resume_from["checkpoint"] == "latest": | |
| if check_flag: | |
| checkpoints = os.listdir(ckpt_path) | |
| if "latest.pth" in checkpoints and osp.exists(osp.join(ckpt_path, "latest.pth")): | |
| config.model.resume_from["checkpoint"] = osp.realpath(osp.join(ckpt_path, "latest.pth")) | |
| else: | |
| checkpoints = [i for i in checkpoints if i.startswith("epoch_")] | |
| checkpoints = sorted(checkpoints, key=lambda x: int(x.replace(".pth", "").split("_")[3])) | |
| config.model.resume_from["checkpoint"] = osp.join(ckpt_path, checkpoints[-1]) | |
| else: | |
| config.model.resume_from["checkpoint"] = config.model.load_from | |
| if config.model.resume_from["checkpoint"] is not None: | |
| _, missing, unexpected, rng_state = load_checkpoint( | |
| **config.model.resume_from, | |
| model=model, | |
| optimizer=optimizer if check_flag else None, | |
| lr_scheduler=lr_scheduler if check_flag else None, | |
| null_embed_path=null_embed_path, | |
| ) | |
| logger.warning(f"Missing keys: {missing}") | |
| logger.warning(f"Unexpected keys: {unexpected}") | |
| path = osp.basename(config.model.resume_from["checkpoint"]) | |
| try: | |
| start_epoch = int(path.replace(".pth", "").split("_")[1]) - 1 | |
| start_step = int(path.replace(".pth", "").split("_")[3]) | |
| except: | |
| pass | |
| # resume randomise | |
| if rng_state: | |
| logger.info("resuming randomise") | |
| torch.set_rng_state(rng_state["torch"]) | |
| np.random.set_state(rng_state["numpy"]) | |
| random.setstate(rng_state["python"]) | |
| generator.set_state(rng_state["generator"]) # resume generator status | |
| try: | |
| torch.cuda.set_rng_state_all(rng_state["torch_cuda"]) | |
| except: | |
| logger.warning("Failed to resume torch_cuda rng state") | |
    # Prepare everything
    # There is no specific order to remember, you just need to unpack the
    # objects in the same order you gave them to the prepare method.
    # NOTE: the model is prepared separately so checkpoint/RNG restoration above
    # happens on the unwrapped module before accelerate wraps it (DDP/FSDP).
    model = accelerator.prepare(model)
    optimizer, lr_scheduler = accelerator.prepare(optimizer, lr_scheduler)

    # Start Training
    # Hand everything to the train() loop defined elsewhere in this file.
    train(
        config=config,
        args=args,
        accelerator=accelerator,
        model=model,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        train_dataloader=train_dataloader,
        train_diffusion=train_diffusion,
        logger=logger,
    )
# Standard script entry guard: run training only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()