# NOTE(review): removed stray paste artifacts ("Spaces:" / "Runtime error") — they were not part of the script.
# flake8: noqa
import hydra
import pyrootutils
import os
import torch
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from tqdm.auto import tqdm
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler, \
    Transformer2DModel
from transformers import CLIPTextModel, CLIPTokenizer
import argparse
from flask import Flask, request
from typing import List, Union
import json
from typing import Optional
import transformers
from dataclasses import dataclass, field, asdict, is_dataclass
from torchdata.dataloader2 import DataLoader2, MultiProcessingReadingService, DistributedReadingService, \
    SequentialReadingService
import logging

# Register the project root (marked by a `.project-root` file) on sys.path so
# the `src.*` imports below resolve regardless of the working directory.
# Must run before the `src.train` imports.
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)

from src.train.schedular import get_scheduler
from src.train.dist_utils import all_gather

# logger = get_logger(__name__, log_level='info')
# Plain stdlib logging (accelerate's get_logger left disabled above).
log_format = '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
logging.basicConfig(level=logging.INFO, format=log_format)
logger = logging.getLogger(__name__)

# os.environ["WANDB_MODE"] = "offline"
@dataclass
class ConfigPathArguments:
    """Paths to the OmegaConf/Hydra config files for each pipeline component.

    Parsed from the CLI by ``transformers.HfArgumentParser``; every attribute
    is a filesystem path to a YAML config (or ``None`` when unused).

    BUGFIX: the ``@dataclass`` decorator was missing. ``field(...)`` entries
    are only processed by the dataclass machinery, and ``HfArgumentParser``
    requires dataclass types — without the decorator the attributes stay bare
    ``Field`` objects and ``parse_args_into_dataclasses`` cannot work.
    """
    image_transform: Optional[str] = field(default=None, metadata={"help": "config path of image transform"})
    sd_image_transform: Optional[str] = field(default=None,
                                              metadata={"help": "config path of stable diffusion image transform"})
    visual_encoder: Optional[str] = field(default=None, metadata={"help": "config path of visual encoder"})
    discrete_model: Optional[str] = field(default=None, metadata={"help": "config path of discrete model"})
    adapter: Optional[str] = field(default=None, metadata={"help": "config path of adapter"})
    train_dataset: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
    fsdp_plugin: Optional[str] = field(default=None, metadata={"help": "config path of fsdp plugin"})
    deepspeed_plugin: Optional[str] = field(default=None, metadata={"help": "config path of deepspeed plugin"})
    tokenizer: Optional[str] = field(default=None,
                                     metadata={"help": "config path of tokenizer used to initialize tokenizer"})
    llm_model: Optional[str] = field(default=None, metadata={"help": "config path of llm"})
    agent_model: Optional[str] = field(default=None, metadata={"help": "config path of agent"})
@dataclass
class TrainingArguments:
    """Optimization hyper-parameters, parsed by ``transformers.HfArgumentParser``.

    ``output_dir`` is the only required argument; everything else has a default.

    BUGFIX: the ``@dataclass`` decorator was missing (see ConfigPathArguments)
    — ``field(...)`` and ``HfArgumentParser`` both require it.
    """
    output_dir: str = field(
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."}, )
    diffusion_model_path: Optional[str] = field(default=None, metadata={"help": "config path of training dataset"})
    resume_from_checkpoint: Optional[str] = field(
        default=None, metadata={"help": "The path to a folder with a valid checkpoint for your model."})
    resume_steps: Optional[int] = field(default=None, metadata={"help": "The training sterps of saved checkpoint"})
    learning_rate: float = field(default=5e-5, metadata={"help": "The initial learning rate for AdamW."})
    weight_decay: float = field(default=0.0, metadata={"help": "Weight decay for AdamW if we apply some."})
    max_grad_norm: float = field(default=1.0, metadata={"help": "Max gradient norm."})
    gradient_accumulation_steps: int = field(
        default=1, metadata={"help": "Number of updates steps to accumulate before performing a backward/update pass."})
    mixed_precision: Optional[str] = field(
        default='no',
        metadata={
            "help":
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >=1.10.and an Nvidia Ampere GPU."
        })
    num_train_epochs: int = field(default=3, metadata={"help": "Total number of training epochs to perform."})
    max_steps: int = field(default=-1, metadata={"help": "Total number of training steps to perform. "})
    save_steps: int = field(default=10000, metadata={"help": "Number of updates steps before two checkpoint saves."})
    lr_scheduler_type: str = field(default="cosine", metadata={"help": "The scheduler type to use."})
    warmup_steps: int = field(default=0, metadata={"help": "Linear warmup over warmup_steps."})
    min_lr_ratio: float = field(default=0.01, metadata={"help": "Minimal learning rate ratio."})
    dataloader_num_workers: int = field(default=8, metadata={"help": "The number of workers to use for data loading."})
    project_name: str = field(default="IPAdapter", metadata={"help": "The name of experiment"})
    expr_name: str = field(default="", metadata={"help": "The name of experiment"})
def build_dataloader(dataset_cfg, image_transform, sd_image_transform, tokenizer, dataloader_num_workers=4):
    """Instantiate the dataset described by ``dataset_cfg`` and wrap it in a DataLoader2.

    The reading service chains a distributed service (shards the pipe across
    ranks) with a multiprocessing service (``dataloader_num_workers`` worker
    processes per rank), in that order.
    """
    dataset = hydra.utils.instantiate(
        dataset_cfg,
        image_transform=image_transform,
        sd_image_transform=sd_image_transform,
        tokenizer=tokenizer,
    )
    reading_service = SequentialReadingService(
        DistributedReadingService(),
        MultiProcessingReadingService(num_workers=dataloader_num_workers),
    )
    return DataLoader2(dataset, reading_service=reading_service)
def get_metric(output):
    """Extract scalar loss metrics from a model-output dict.

    Keeps only entries whose key contains ``'loss'`` and converts each value
    (assumed to be a 0-dim tensor) to a Python number via ``.item()``.
    """
    return {key: value.item() for key, value in output.items() if 'loss' in key}
def merge_config(**kwargs):
    """Normalize heterogeneous config objects into one plain nested dict.

    Each keyword argument may be an ``argparse.Namespace``, an OmegaConf
    ``DictConfig``, a dataclass instance, or a plain ``dict``; anything else
    is skipped with an error log entry. Returns ``{name: plain_dict}``.
    """
    merged = {}
    for key, value in kwargs.items():
        if isinstance(value, argparse.Namespace):
            converted = vars(value)
        elif isinstance(value, DictConfig):
            converted = OmegaConf.to_object(value)
        elif is_dataclass(value):
            converted = asdict(value)
        elif isinstance(value, dict):
            converted = value
        else:
            logger.error(f'key: {key}, value: {value} will not be merged.')
            continue
        merged[key] = converted
    return merged
def trainable_params(model):
    """Return the total number of elements across parameters with requires_grad=True."""
    return sum(param.numel() for _, param in model.named_parameters() if param.requires_grad)
def train():
    """Train the SD adapter on LLM-reconstructed image embeddings.

    Per optimization step:
      1. Encode images with the frozen visual encoder + discrete model.
      2. Run the frozen agent/LLM to reconstruct image embeddings.
      3. Add scheduler noise to the VAE latents of the target images.
      4. Train the adapter to predict the noise from the noisy latents,
         conditioned on the reconstructed embeddings.

    Component choices come from the config files in ``ConfigPathArguments``;
    hyper-parameters from ``TrainingArguments`` (both parsed from the CLI).

    Fixes vs. previous version:
      * ``noise_scheduler.config.num_train_timesteps`` instead of the
        deprecated direct attribute access.
      * ``merge_config`` now receives ``llm_model_cfg`` (the config) instead
        of the instantiated model, which it could not serialize.
    """
    parser = transformers.HfArgumentParser((ConfigPathArguments, TrainingArguments))
    cfg_path, args = parser.parse_args_into_dataclasses()

    project_config = ProjectConfiguration(project_dir=args.output_dir,
                                          logging_dir=os.path.join(args.output_dir, 'logs'))

    # FSDP and DeepSpeed are mutually exclusive.
    assert int(cfg_path.fsdp_plugin is not None) + int(cfg_path.deepspeed_plugin is not None) <= 1
    if cfg_path.fsdp_plugin is not None:
        fsdp_plugin_cfg = OmegaConf.load(cfg_path.fsdp_plugin)
        fsdp_plugin = hydra.utils.instantiate(fsdp_plugin_cfg)
        logger.info('Use FSDP plugin')
    else:
        fsdp_plugin = None
    if cfg_path.deepspeed_plugin is not None:
        deepspeed_plugin_cfg = OmegaConf.load(cfg_path.deepspeed_plugin)
        deepspeed_plugin = hydra.utils.instantiate(deepspeed_plugin_cfg)
        logger.info('Use deepspeed plugin')
    else:
        deepspeed_plugin = None

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=['tensorboard', 'wandb'],
        project_config=project_config,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        step_scheduler_with_optimizer=False,
        fsdp_plugin=fsdp_plugin,
        deepspeed_plugin=deepspeed_plugin,
    )
    logger.info('Init accelerator done.')
    if cfg_path.deepspeed_plugin is not None:
        # NOTE(review): hard-coded micro batch size overrides whatever the
        # deepspeed config file says — confirm this is intentional.
        accelerator.state.deepspeed_plugin.deepspeed_config['train_micro_batch_size_per_gpu'] = 100

    os.makedirs(args.output_dir, exist_ok=True)

    # --- Instantiate preprocessing and frozen encoder components from configs.
    image_transform_cfg = OmegaConf.load(cfg_path.image_transform)
    image_transform = hydra.utils.instantiate(image_transform_cfg)
    sd_image_transform_cfg = OmegaConf.load(cfg_path.sd_image_transform)
    sd_image_transform = hydra.utils.instantiate(sd_image_transform_cfg)
    tokenizer_cfg = OmegaConf.load(cfg_path.tokenizer)
    tokenizer = hydra.utils.instantiate(tokenizer_cfg)
    visual_encoder_cfg = OmegaConf.load(cfg_path.visual_encoder)
    visual_encoder = hydra.utils.instantiate(visual_encoder_cfg)
    logger.info('Load visual encoder done.')
    discrete_model_cfg = OmegaConf.load(cfg_path.discrete_model)
    discrete_model = hydra.utils.instantiate(discrete_model_cfg)
    logger.info('Load discrete model done.')

    # --- Diffusion components from a pretrained SD checkpoint.
    noise_scheduler = DDPMScheduler.from_pretrained(args.diffusion_model_path, subfolder="scheduler")
    text_encoder = None  # text conditioning is disabled in this setup
    vae = AutoencoderKL.from_pretrained(args.diffusion_model_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.diffusion_model_path, subfolder="unet")
    unet.enable_xformers_memory_efficient_attention()
    unet.enable_gradient_checkpointing()

    # Freeze everything except the adapter.
    vae.requires_grad_(False)
    visual_encoder.requires_grad_(False)
    discrete_model.requires_grad_(False)

    adapter_cfg = OmegaConf.load(cfg_path.adapter)
    adapter = hydra.utils.instantiate(adapter_cfg, unet=unet)
    logger.info('Load adapter done.')

    # Frozen components run in the mixed-precision dtype; the adapter stays fp32.
    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    vae.to(accelerator.device, dtype=weight_dtype)
    visual_encoder.to(accelerator.device, dtype=weight_dtype)
    discrete_model.to(accelerator.device, dtype=weight_dtype)
    if text_encoder is not None:
        text_encoder.to(accelerator.device, dtype=weight_dtype)

    train_dataset_cfg = OmegaConf.load(cfg_path.train_dataset)
    train_dataloader = build_dataloader(dataset_cfg=train_dataset_cfg,
                                        image_transform=image_transform,
                                        sd_image_transform=sd_image_transform,
                                        tokenizer=tokenizer,
                                        dataloader_num_workers=args.dataloader_num_workers)

    # --- Frozen LLM/agent used to reconstruct image embeddings.
    llm_model_cfg = OmegaConf.load(cfg_path.llm_model)
    llm_model = hydra.utils.instantiate(llm_model_cfg)
    llm_model.gradient_checkpointing_enable()
    llm_model.config.use_cache = False
    logger.info('Load llm model done.')
    agent_model_cfg = OmegaConf.load(cfg_path.agent_model)
    agent_model = hydra.utils.instantiate(agent_model_cfg, llm=llm_model).to(accelerator.device, dtype=weight_dtype)
    agent_model.requires_grad_(False)
    agent_model.llm.base_model.model.use_kv_cache_head = False
    logger.info('Load agent model done.')

    # Under FSDP the model must be wrapped BEFORE the optimizer is created so
    # the optimizer sees the flattened FSDP parameters.
    if cfg_path.fsdp_plugin is not None:
        adapter = accelerator.prepare(adapter)
    optimizer = torch.optim.AdamW(adapter.params_to_opt(), lr=args.learning_rate, weight_decay=args.weight_decay)
    logger.info('Init optimizer done.')
    scheduler = get_scheduler(name=args.lr_scheduler_type,
                              optimizer=optimizer,
                              num_warmup_steps=args.warmup_steps,
                              num_training_steps=args.max_steps,
                              min_lr_ratio=args.min_lr_ratio)
    if cfg_path.fsdp_plugin is not None:
        optimizer, scheduler = accelerator.prepare(optimizer, scheduler)
    else:
        adapter, optimizer, scheduler = accelerator.prepare(adapter, optimizer, scheduler)
    logger.info('Prepare accelerator done.')

    # BUGFIX: previously the instantiated `llm_model` object (not its config)
    # was passed here; merge_config only accepts Namespace/DictConfig/
    # dataclass/dict, so the LLM entry was dropped with an error log.
    config_record = merge_config(discrete_model=discrete_model_cfg,
                                 visual_encoder=visual_encoder_cfg,
                                 image_transform=image_transform_cfg,
                                 sd_image_transform=sd_image_transform_cfg,
                                 train_dataset=train_dataset_cfg,
                                 adapter=adapter_cfg,
                                 train_args=args,
                                 agent_model=agent_model_cfg,
                                 llm_model=llm_model_cfg,
                                 tokenizer=tokenizer_cfg)
    accelerator.init_trackers(project_name=args.project_name,
                              init_kwargs={"wandb": {
                                  "config": config_record,
                                  "name": args.expr_name,
                                  "dir": args.output_dir
                              }})

    if args.resume_from_checkpoint is not None:
        logger.info(f'Load checkpoint from {args.resume_from_checkpoint}')
        accelerator.load_state(args.resume_from_checkpoint)

    num_params = trainable_params(adapter)
    logger.info("***** Running training *****")
    logger.info(f" Total optimization steps = {args.max_steps}")
    logger.info(f" Total trainable params = {num_params}")
    for name, param in adapter.named_parameters():
        if param.requires_grad:
            print(name)

    # Only show the progress bar once on each machine.
    progress_bar = tqdm(range(args.max_steps), disable=not accelerator.is_main_process)
    progress_bar.set_description("Steps")
    global_step = 0
    if args.resume_steps is not None:
        global_step = args.resume_steps
        progress_bar.update(args.resume_steps)

    for epoch in range(args.num_train_epochs):
        logger.info('Start new epoch')
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(adapter):
                # Frozen encoders and the frozen agent run without autograd:
                # no gradient is needed through them and this saves memory.
                with torch.no_grad():
                    image_embeds = visual_encoder(batch['images'].to(accelerator.device, dtype=weight_dtype))
                    image_embeds = discrete_model.encode_image_embeds(image_embeds)
                    if text_encoder is not None:
                        text_embeds = text_encoder(batch['text_input_ids'].to(accelerator.device))[0]
                    else:
                        text_embeds = None
                    latents = vae.encode(
                        batch["sd_images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                    llm_output = agent_model(input_ids=batch['input_ids'].to(accelerator.device),
                                             attention_mask=batch['attention_mask'].to(accelerator.device),
                                             labels=batch['labels'].to(accelerator.device),
                                             image_embeds=image_embeds,
                                             embeds_gen_mask=batch['embeds_gen_mask'].to(accelerator.device)
                                             if batch['embeds_gen_mask'] is not None else None,
                                             embeds_cmp_mask=batch['embeds_cmp_mask'].to(accelerator.device)
                                             if batch['embeds_cmp_mask'] is not None else None,
                                             ids_gen_mask=batch['ids_gen_mask'].to(accelerator.device),
                                             ids_cmp_mask=batch['ids_cmp_mask'].to(accelerator.device),
                                             return_recon_image_embeds=True)

                time_ids = batch['time_ids'].to(accelerator.device)
                # Sample noise that we'll add to the latents.
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image.
                # BUGFIX: read num_train_timesteps from the scheduler config;
                # direct attribute access is deprecated/removed in diffusers.
                timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,),
                                          device=latents.device)
                timesteps = timesteps.long()
                # Forward diffusion: noise the latents per-sample timestep.
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                output = adapter(noisy_latents=noisy_latents,
                                 timesteps=timesteps,
                                 image_embeds=llm_output['recon_image_embeds'],
                                 text_embeds=None,
                                 noise=noise,
                                 time_ids=time_ids)
                loss = output['total_loss']
                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(adapter.parameters(), max_norm=args.max_grad_norm)
                optimizer.step()
                scheduler.step()
                optimizer.zero_grad()

            # Only count a global step (and maybe checkpoint) once gradients
            # have actually been synchronized across accumulation steps.
            if accelerator.sync_gradients:
                progress_bar.update(1)
                global_step += 1
                if global_step % args.save_steps == 0:
                    save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                    accelerator.save_state(save_path)

            metric = get_metric(output)
            metric['lr'] = optimizer.param_groups[0]['lr']
            accelerator.log(metric, step=global_step)
            metric = {key: (format(value, ".6f") if isinstance(value, float) else value)
                      for key, value in metric.items()}
            if accelerator.is_main_process:
                tqdm.write(str(metric))
            # NOTE(review): this only breaks the inner loop; the epoch loop
            # re-enters and breaks again immediately — original behavior kept.
            if global_step >= args.max_steps:
                break

    accelerator.end_training()
# Script entry point: parse CLI args and run the training loop.
if __name__ == '__main__':
    train()