| import json |
| from time import time |
| import argparse |
| import logging |
| import os |
| from pathlib import Path |
| import math |
|
|
| import numpy as np |
| from PIL import Image |
| from copy import deepcopy |
|
|
| import torch |
| import torch.distributed as dist |
| from torch.utils.data import Dataset, DataLoader |
| from torch.utils.data.distributed import DistributedSampler |
| from torchvision import transforms |
|
|
| from accelerate import Accelerator |
| from accelerate.utils import ProjectConfiguration, set_seed |
| from diffusers.optimization import get_scheduler |
| from accelerate.utils import DistributedType |
| from peft import LoraConfig, set_peft_model_state_dict, PeftModel, get_peft_model |
| from peft.utils import get_peft_model_state_dict |
| from huggingface_hub import snapshot_download |
| from safetensors.torch import save_file |
|
|
| from diffusers.models import AutoencoderKL |
|
|
| from OmniGen import OmniGen, OmniGenProcessor |
| from OmniGen.train_helper import DatasetFromJson, TrainDataCollator |
| from OmniGen.train_helper import training_losses |
| from OmniGen.utils import ( |
| create_logger, |
| update_ema, |
| requires_grad, |
| center_crop_arr, |
| crop_arr, |
| vae_encode, |
| vae_encode_list |
| ) |
|
|
| def main(args): |
| |
| from accelerate import DistributedDataParallelKwargs as DDPK |
| kwargs = DDPK(find_unused_parameters=False) |
| accelerator = Accelerator( |
| gradient_accumulation_steps=args.gradient_accumulation_steps, |
| mixed_precision=args.mixed_precision, |
| log_with=args.report_to, |
| project_dir=args.results_dir, |
| kwargs_handlers=[kwargs], |
| ) |
| device = accelerator.device |
| accelerator.init_trackers("tensorboard_log", config=args.__dict__) |
|
|
| |
| checkpoint_dir = f"{args.results_dir}/checkpoints" |
| logger = create_logger(args.results_dir) |
| if accelerator.is_main_process: |
| os.makedirs(checkpoint_dir, exist_ok=True) |
| logger.info(f"Experiment directory created at {args.results_dir}") |
| json.dump(args.__dict__, open(os.path.join(args.results_dir, 'train_args.json'), 'w')) |
|
|
|
|
| |
| if not os.path.exists(args.model_name_or_path): |
| cache_folder = os.getenv('HF_HUB_CACHE') |
| args.model_name_or_path = snapshot_download(repo_id=args.model_name_or_path, |
| cache_dir=cache_folder, |
| ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5']) |
| logger.info(f"Downloaded model to {args.model_name_or_path}") |
| model = OmniGen.from_pretrained(args.model_name_or_path) |
| model.llm.config.use_cache = False |
| model.llm.gradient_checkpointing_enable() |
| model = model.to(device) |
|
|
| if args.vae_path is None: |
| print(args.model_name_or_path) |
| vae_path = os.path.join(args.model_name_or_path, "vae") |
| if os.path.exists(vae_path): |
| vae = AutoencoderKL.from_pretrained(vae_path).to(device) |
| else: |
| logger.info("No VAE found in model, downloading stabilityai/sdxl-vae from HF") |
| logger.info("If you have VAE in local folder, please specify the path with --vae_path") |
| vae = AutoencoderKL.from_pretrained("stabilityai/sdxl-vae").to(device) |
| else: |
| vae = AutoencoderKL.from_pretrained(args.vae_path).to(device) |
|
|
| weight_dtype = torch.float32 |
| if accelerator.mixed_precision == "fp16": |
| weight_dtype = torch.float16 |
| elif accelerator.mixed_precision == "bf16": |
| weight_dtype = torch.bfloat16 |
| vae.to(dtype=torch.float32) |
| model.to(weight_dtype) |
|
|
| processor = OmniGenProcessor.from_pretrained(args.model_name_or_path) |
|
|
| requires_grad(vae, False) |
| if args.use_lora: |
| if accelerator.distributed_type == DistributedType.FSDP: |
| raise NotImplementedError("FSDP does not support LoRA") |
| requires_grad(model, False) |
| transformer_lora_config = LoraConfig( |
| r=args.lora_rank, |
| lora_alpha=args.lora_rank, |
| init_lora_weights="gaussian", |
| target_modules=["qkv_proj", "o_proj"], |
| ) |
| model.llm.enable_input_require_grads() |
| model = get_peft_model(model, transformer_lora_config) |
| model.to(weight_dtype) |
| transformer_lora_parameters = list(filter(lambda p: p.requires_grad, model.parameters())) |
| for n,p in model.named_parameters(): |
| print(n, p.requires_grad) |
| opt = torch.optim.AdamW(transformer_lora_parameters, lr=args.lr, weight_decay=args.adam_weight_decay) |
| else: |
| opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.adam_weight_decay) |
|
|
| ema = None |
| if args.use_ema: |
| ema = deepcopy(model).to(device) |
| requires_grad(ema, False) |
| |
|
|
| |
| crop_func = crop_arr |
| if not args.keep_raw_resolution: |
| crop_func = center_crop_arr |
| image_transform = transforms.Compose([ |
| transforms.Lambda(lambda pil_image: crop_func(pil_image, args.max_image_size)), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) |
| ]) |
| |
| dataset = DatasetFromJson(json_file=args.json_file, |
| image_path=args.image_path, |
| processer=processor, |
| image_transform=image_transform, |
| max_input_length_limit=args.max_input_length_limit, |
| condition_dropout_prob=args.condition_dropout_prob, |
| keep_raw_resolution=args.keep_raw_resolution |
| ) |
| collate_fn = TrainDataCollator(pad_token_id=processor.text_tokenizer.eos_token_id, hidden_size=model.llm.config.hidden_size, keep_raw_resolution=args.keep_raw_resolution) |
|
|
| loader = DataLoader( |
| dataset, |
| collate_fn=collate_fn, |
| batch_size=args.batch_size_per_device, |
| shuffle=True, |
| num_workers=args.num_workers, |
| pin_memory=True, |
| drop_last=True, |
| prefetch_factor=2, |
| ) |
| |
| if accelerator.is_main_process: |
| logger.info(f"Dataset contains {len(dataset):,}") |
|
|
| num_update_steps_per_epoch = math.ceil(len(loader) / args.gradient_accumulation_steps) |
| max_train_steps = args.epochs * num_update_steps_per_epoch |
| lr_scheduler = get_scheduler( |
| args.lr_scheduler, |
| optimizer=opt, |
| num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, |
| num_training_steps=max_train_steps * args.gradient_accumulation_steps, |
| ) |
|
|
| |
| model.train() |
| |
| if ema is not None: |
| update_ema(ema, model, decay=0) |
| ema.eval() |
| |
|
|
| if ema is not None: |
| model, ema = accelerator.prepare(model, ema) |
| else: |
| model = accelerator.prepare(model) |
|
|
| opt, loader, lr_scheduler = accelerator.prepare(opt, loader, lr_scheduler) |
| |
| |
| |
| train_steps, log_steps = 0, 0 |
| running_loss = 0 |
| start_time = time() |
| |
| if accelerator.is_main_process: |
| logger.info(f"Training for {args.epochs} epochs...") |
| for epoch in range(args.epochs): |
| if accelerator.is_main_process: |
| logger.info(f"Beginning epoch {epoch}...") |
| |
| for data in loader: |
| with accelerator.accumulate(model): |
| with torch.no_grad(): |
| output_images = data['output_images'] |
| input_pixel_values = data['input_pixel_values'] |
| if isinstance(output_images, list): |
| output_images = vae_encode_list(vae, output_images, weight_dtype) |
| if input_pixel_values is not None: |
| input_pixel_values = vae_encode_list(vae, input_pixel_values, weight_dtype) |
| else: |
| output_images = vae_encode(vae, output_images, weight_dtype) |
| if input_pixel_values is not None: |
| input_pixel_values = vae_encode(vae, input_pixel_values, weight_dtype) |
| |
|
|
| model_kwargs = dict(input_ids=data['input_ids'], input_img_latents=input_pixel_values, input_image_sizes=data['input_image_sizes'], attention_mask=data['attention_mask'], position_ids=data['position_ids'], padding_latent=data['padding_images'], past_key_values=None, return_past_key_values=False) |
| |
| loss_dict = training_losses(model, output_images, model_kwargs) |
| loss = loss_dict["loss"].mean() |
|
|
| running_loss += loss.item() |
| accelerator.backward(loss) |
| if args.max_grad_norm is not None and accelerator.sync_gradients: |
| accelerator.clip_grad_norm_(model.parameters(), args.max_grad_norm) |
| opt.step() |
| lr_scheduler.step() |
| opt.zero_grad() |
|
|
| log_steps += 1 |
| train_steps += 1 |
|
|
| accelerator.log({"training_loss": loss.item()}, step=train_steps) |
| if train_steps % args.gradient_accumulation_steps == 0: |
| if accelerator.sync_gradients and ema is not None: |
| update_ema(ema, model) |
| |
| if train_steps % (args.log_every * args.gradient_accumulation_steps) == 0 and train_steps > 0: |
| torch.cuda.synchronize() |
| end_time = time() |
| steps_per_sec = log_steps / args.gradient_accumulation_steps / (end_time - start_time) |
| |
| avg_loss = torch.tensor(running_loss / log_steps, device=device) |
| dist.all_reduce(avg_loss, op=dist.ReduceOp.SUM) |
| avg_loss = avg_loss.item() / accelerator.num_processes |
| |
| if accelerator.is_main_process: |
| cur_lr = opt.param_groups[0]["lr"] |
| logger.info(f"(step={int(train_steps/args.gradient_accumulation_steps):07d}) Train Loss: {avg_loss:.4f}, Train Steps/Sec: {steps_per_sec:.2f}, Epoch: {train_steps/len(loader)}, LR: {cur_lr}") |
|
|
| |
| running_loss = 0 |
| log_steps = 0 |
| start_time = time() |
|
|
|
|
| if train_steps % (args.ckpt_every * args.gradient_accumulation_steps) == 0 and train_steps > 0: |
| if accelerator.distributed_type == DistributedType.FSDP: |
| state_dict = accelerator.get_state_dict(model) |
| ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None |
| else: |
| if not args.use_lora: |
| state_dict = model.module.state_dict() |
| ema_state_dict = accelerator.get_state_dict(ema) if ema is not None else None |
|
|
| if accelerator.is_main_process: |
| if args.use_lora: |
| checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/" |
| os.makedirs(checkpoint_path, exist_ok=True) |
|
|
| model.module.save_pretrained(checkpoint_path) |
| else: |
| checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}/" |
| os.makedirs(checkpoint_path, exist_ok=True) |
| torch.save(state_dict, os.path.join(checkpoint_path, "model.pt")) |
| processor.text_tokenizer.save_pretrained(checkpoint_path) |
| model.llm.config.save_pretrained(checkpoint_path) |
| if ema_state_dict is not None: |
| checkpoint_path = f"{checkpoint_dir}/{int(train_steps/args.gradient_accumulation_steps):07d}_ema" |
| os.makedirs(checkpoint_path, exist_ok=True) |
| torch.save(state_dict, os.path.join(checkpoint_path, "model.pt")) |
| processor.text_tokenizer.save_pretrained(checkpoint_path) |
| model.llm.config.save_pretrained(checkpoint_path) |
| logger.info(f"Saved checkpoint to {checkpoint_path}") |
|
|
| dist.barrier() |
| accelerator.end_training() |
| model.eval() |
| |
| if accelerator.is_main_process: |
| logger.info("Done!") |
|
|
|
|
| if __name__ == "__main__": |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--results_dir", type=str, default="results") |
| parser.add_argument("--model_name_or_path", type=str, default="OmniGen") |
| parser.add_argument("--json_file", type=str) |
| parser.add_argument("--image_path", type=str, default=None) |
| parser.add_argument("--epochs", type=int, default=1400) |
| parser.add_argument("--batch_size_per_device", type=int, default=1) |
| parser.add_argument("--vae_path", type=str, default=None) |
| parser.add_argument("--num_workers", type=int, default=4) |
| parser.add_argument("--log_every", type=int, default=100) |
| parser.add_argument("--ckpt_every", type=int, default=20000) |
| parser.add_argument("--max_grad_norm", type=float, default=1.0) |
| parser.add_argument("--lr", type=float, default=1e-4) |
| parser.add_argument("--max_input_length_limit", type=int, default=1024) |
| parser.add_argument("--condition_dropout_prob", type=float, default=0.1) |
| parser.add_argument("--adam_weight_decay", type=float, default=0.0) |
| parser.add_argument( |
| "--keep_raw_resolution", |
| action="store_true", |
| help="multiple_resolutions", |
| ) |
| parser.add_argument("--max_image_size", type=int, default=1344) |
|
|
| parser.add_argument( |
| "--use_lora", |
| action="store_true", |
| ) |
| parser.add_argument( |
| "--lora_rank", |
| type=int, |
| default=8 |
| ) |
|
|
| parser.add_argument( |
| "--use_ema", |
| action="store_true", |
| help="Whether or not to use ema.", |
| ) |
| parser.add_argument( |
| "--lr_scheduler", |
| type=str, |
| default="constant", |
| help=( |
| 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' |
| ' "constant", "constant_with_warmup"]' |
| ), |
| ) |
| parser.add_argument( |
| "--lr_warmup_steps", type=int, default=1000, help="Number of steps for the warmup in the lr scheduler." |
| ) |
| parser.add_argument( |
| "--report_to", |
| type=str, |
| default="tensorboard", |
| help=( |
| 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`' |
| ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.' |
| ), |
| ) |
| parser.add_argument( |
| "--mixed_precision", |
| type=str, |
| default="bf16", |
| choices=["no", "fp16", "bf16"], |
| help=( |
| "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=" |
| " 1.10.and an Nvidia Ampere GPU. Default to the value of accelerate config of the current system or the" |
| " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config." |
| ), |
| ) |
| parser.add_argument( |
| "--gradient_accumulation_steps", |
| type=int, |
| default=1, |
| help="Number of updates steps to accumulate before performing a backward/update pass.", |
| ) |
|
|
|
|
| args = parser.parse_args() |
| assert args.max_image_size % 16 == 0, "Image size must be divisible by 16." |
|
|
| main(args) |
|
|
|
|
|
|