Spaces:
Sleeping
Sleeping
| import ml_collections | |
| import torch | |
| from torch import multiprocessing as mp | |
| from data.data_factory import OnlineFeatures | |
| import utils | |
| import accelerate | |
| from tqdm.auto import tqdm | |
| import tempfile | |
| from absl import logging | |
| import builtins | |
| import os | |
| import wandb | |
| from PIL import Image | |
| if 'NCCL_ASYNC_ERROR_HANDLING' in os.environ: | |
| os.environ['TORCH_NCCL_ASYNC_ERROR_HANDLING'] = os.environ['NCCL_ASYNC_ERROR_HANDLING'] | |
| import libs.autoencoder | |
| from diffusion.flow_matching import FlowMatching, ODEEulerFlowMatchingSolver | |
| from tools.fid_score import calculate_fid_given_paths | |
| from libs.janus.models import MultiModalityCausalLM, VLChatProcessor | |
| from transformers import AutoModelForCausalLM | |
def train(config):
    """Run the full training loop, with periodic visualisation, checkpointing and FID evaluation.

    The function:
      1. creates an ``accelerate.Accelerator`` and seeds every process,
      2. sets up logging and (on the main process only) a wandb run,
      3. builds the ``OnlineFeatures`` dataset and its train/test dataloaders,
      4. prepares the network, EMA network and optimizer with the accelerator
         and restores a checkpoint from ``config.resume_ckpt_root``,
      5. trains until ``train_state.step`` reaches ``config.train.n_steps``,
         saving a sample grid every ``config.train.eval_interval`` steps and
         a checkpoint plus FID score every ``config.train.save_interval`` steps,
      6. reloads the checkpoint with the lowest recorded FID and runs a final
         evaluation with ``config.sample.n_samples`` samples.

    Args:
        config: a mutable ``ml_collections`` config. ``config.mixed_precision``
            is written before the config is frozen into a ``FrozenConfigDict``.

    Raises:
        AssertionError: if ``config.train.batch_size`` is not divisible by the
            number of processes, or if one of the in-loop sanity asserts fails.
        NotImplementedError: if the dataset exposes no ``vis_image_paths``.
        ValueError: if a training batch has neither 3 nor 4 elements.
    """
    mp.set_start_method('spawn')
    accelerator = accelerate.Accelerator()
    device = accelerator.device
    accelerate.utils.set_seed(config.seed, device_specific=True)
    logging.info(f'Process {accelerator.process_index} using device: {device}')

    config.mixed_precision = accelerator.mixed_precision
    config = ml_collections.FrozenConfigDict(config)

    # The global batch is split evenly across processes.
    assert config.train.batch_size % accelerator.num_processes == 0
    mini_batch_size = config.train.batch_size // accelerator.num_processes

    if accelerator.is_main_process:
        os.makedirs(config.ckpt_root, exist_ok=True)
        os.makedirs(config.sample_dir, exist_ok=True)
    accelerator.wait_for_everyone()

    if accelerator.is_main_process:
        # Precedence: explicit config value, then environment variable, then default.
        wandb_mode = (
            getattr(config, 'wandb_mode', None) or
            os.environ.get('WANDB_MODE', None) or
            'online'
        )
        wandb_project = (
            getattr(config, 'wandb_project', None) or
            os.environ.get('WANDB_PROJECT', None) or
            f'{config.config_name}_{config.dataset.name}'
        )
        wandb.init(dir=os.path.abspath(config.workdir), project=wandb_project, config=config.to_dict(),
                   name=config.hparams, job_type='train', mode=wandb_mode)
        utils.set_logger(log_level='info', fname=os.path.join(config.workdir, 'output.log'))
        logging.info(config)
        logging.info(f'Optimizer config: {config.optimizer}')
    else:
        utils.set_logger(log_level='error')
        # Silence print() on non-main ranks. The replacement must accept keyword
        # arguments too (file=, end=, flush=, sep=), otherwise any such call
        # would raise TypeError on these ranks.
        builtins.print = lambda *args, **kwargs: None
    logging.info(f'Run on {accelerator.num_processes} devices')

    model_path = "deepseek-ai/Janus-Pro-1B"
    vl_chat_processor: VLChatProcessor = VLChatProcessor.from_pretrained(model_path)

    # Tokenize the (fixed) prompt template once and reuse the ids for every batch.
    training_question = ""
    sft_format = vl_chat_processor.apply_sft_template_for_multi_turn_prompts(
        conversations=[
            {"role": "<|User|>", "content": f"<image_placeholder>\n{training_question}"},
            {"role": "<|Assistant|>", "content": ""},
        ],
        sft_format=vl_chat_processor.sft_format,
        system_prompt=vl_chat_processor.system_prompt,
    )
    cached_training_input_ids = vl_chat_processor.tokenizer.encode(sft_format)
    logging.info(f'Pre-encoded tokenizer completed, input_ids length: {len(cached_training_input_ids)}')

    dataset = OnlineFeatures(
        train_tar_pattern=config.dataset.train_tar_pattern,
        test_tar_pattern=config.dataset.test_tar_pattern,
        vis_image_root=config.dataset.vis_image_root,
        task=config.dataset.task,
        resolution=config.dataset.resolution,
        shuffle_buffer=config.dataset.shuffle_buffer,
        resampled=config.dataset.resampled,
        split_data_by_node=config.dataset.split_data_by_node,
        estimated_samples_per_shard=config.dataset.estimated_samples_per_shard,
        cfg=config.dataset.cfg,
        fid_stat_path=getattr(config, 'fid_stat_path', None),
        num_workers=getattr(config, 'num_workers', 8),
        batch_size=mini_batch_size,
        test_batch_size=config.sample.mini_batch_size,
        test_num_workers=3,
        vl_chat_processor=vl_chat_processor,
        device=device,
        sampling_weights=getattr(config.dataset, 'sampling_weights', None),
    )
    test_dataset = dataset.test
    train_dataset_loader, test_dataset_loader = dataset.train_dataloader, dataset.test_dataloader

    train_state = utils.initialize_train_state(config, device)
    nnet, nnet_ema, optimizer = accelerator.prepare(
        train_state.nnet, train_state.nnet_ema, train_state.optimizer
    )
    lr_scheduler = train_state.lr_scheduler

    # A path that names an existing "*.ckpt" directory is loaded directly;
    # anything else is handed to train_state.resume().
    resume_path = config.resume_ckpt_root
    if resume_path and resume_path.endswith('.ckpt') and os.path.isdir(resume_path):
        logging.info(f'Load from checkpoint directory: {resume_path}')
        train_state.load(resume_path)
    else:
        train_state.resume(resume_path)

    autoencoder = libs.autoencoder.get_model(**config.autoencoder)
    autoencoder.to(device)

    vl_gpt: MultiModalityCausalLM = AutoModelForCausalLM.from_pretrained(
        model_path, trust_remote_code=True, use_safetensors=True
    )
    vl_gpt = vl_gpt.half().eval().to(device)

    def decode(_batch):
        """Decode a batch of latents with the autoencoder."""
        return autoencoder.decode(_batch)

    def get_data_generator():
        """Yield training batches forever, restarting the loader when it is exhausted."""
        while True:
            remaining_steps = config.train.n_steps - train_state.step
            for data in tqdm(
                train_dataset_loader,
                disable=not accelerator.is_main_process,
                desc=f'step {train_state.step}/{config.train.n_steps}',
                unit=' its',
                ncols=120,
                dynamic_ncols=True,
                total=remaining_steps,
                bar_format="{l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]"
            ):
                yield data

    data_generator = get_data_generator()

    def get_context_generator():
        """Yield test batches forever, restarting the loader when it is exhausted."""
        while True:
            for data in tqdm(
                test_dataset_loader,
                disable=not accelerator.is_main_process,
                desc='step',
                unit=' its'
            ):
                yield data

    context_generator = get_context_generator()

    world_size = accelerator.num_processes
    rank = accelerator.process_index
    _flow_matching_model = FlowMatching(world_size=world_size, rank=rank)
    if accelerator.is_main_process:
        logging.info(f"FlowMatching initialized with world_size={world_size}, rank={rank}")
        logging.info(f"ClipLoss will use multi-GPU feature gathering: {world_size > 1}")

    def train_step(_batch):
        """Run one optimisation step on a 4-tuple batch and return its metrics.

        ``_batch`` is ``(input_images, output_image_tensor, input_image_tensor,
        batch_type)``. Returns a dict with the current learning rate, the
        process-averaged ``loss`` and every entry of the loss dict returned by
        the flow-matching model.
        """
        _metrics = dict()
        optimizer.zero_grad()

        assert len(_batch) == 4
        assert not config.dataset.cfg
        _batch_input_img, _batch_output_img, _batch_input_img_tensor, _batch_type = _batch

        if isinstance(_batch_output_img, torch.Tensor):
            _batch_output_img = _batch_output_img.to(device, non_blocking=True)
            if _batch_output_img.dim() == 3:
                # A single image without a batch axis: add one.
                _batch_output_img = _batch_output_img.unsqueeze(0)
        if isinstance(_batch_input_img_tensor, torch.Tensor):
            _batch_input_img_tensor = _batch_input_img_tensor.to(device, non_blocking=True)
            if _batch_input_img_tensor.dim() == 3:
                _batch_input_img_tensor = _batch_input_img_tensor.unsqueeze(0)

        batch_size = _batch_output_img.shape[0]
        use_cross_atten_mask = utils.build_cross_atten_mask_from_batch_type(_batch_type, batch_size, device)

        # Encode target and input images into autoencoder latents (no gradient
        # flows into the autoencoder because the moments are detached).
        moments_256 = autoencoder(_batch_output_img, fn='encode_moments').detach()
        _z = autoencoder.sample(moments_256)
        input_moments_256 = autoencoder(_batch_input_img_tensor, fn='encode_moments').detach()
        _input_image_latent = autoencoder.sample(input_moments_256)

        _batch_con, _batch_mask = utils.get_input_image_embeddings_and_masks(
            batch_input_images=_batch_input_img,
            vl_chat_processor=vl_chat_processor,
            vl_gpt=vl_gpt,
            device=device,
            question="",
            num_image_tokens=576,
            output_tokens=576,
            accelerator=accelerator,
            cached_input_ids=cached_training_input_ids,
        )

        loss, loss_dict = _flow_matching_model(
            _z, nnet,
            loss_coeffs=config.loss_coeffs,
            cond=_batch_con, con_mask=_batch_mask,
            batch_img_clip=_batch_output_img,
            nnet_style=config.nnet.name,
            text_token=None,
            model_config=config.nnet.model_args,
            all_config=config,
            training_step=train_state.step,
            image_latent=_input_image_latent,
            use_cross_atten_mask=use_cross_atten_mask,
        )

        # Metrics are averaged over all processes; the backward pass uses the local loss.
        _metrics['loss'] = accelerator.gather(loss.detach()).mean()
        for key in loss_dict.keys():
            _metrics[key] = accelerator.gather(loss_dict[key].detach()).mean()

        accelerator.backward(loss.mean())
        optimizer.step()
        lr_scheduler.step()
        train_state.ema_update(config.get('ema_rate', 0.9999))
        train_state.step += 1
        return dict(lr=train_state.optimizer.param_groups[0]['lr'], **_metrics)

    def ode_fm_solver_sample(nnet_ema, _n_samples, _sample_steps, context=None, token_mask=None, image_latent=None, use_cross_atten_mask=None):
        """Sample ``_n_samples`` images with the Euler flow-matching ODE solver.

        The initial state comes from the network's text-encoder branch applied
        to ``context``; the Gaussian tensor is only used to fix its shape.
        Returns the decoded (still un-postprocessed) images.
        """
        with torch.no_grad():
            _z_gaussian = torch.randn(_n_samples, *config.z_shape, device=device)
            _z_x0, _mu, _log_var = nnet_ema(context, text_encoder = True, shape = _z_gaussian.shape, mask=token_mask)
            _z_init = _z_x0.reshape(_z_gaussian.shape)

            assert config.sample.scale > 1
            _cfg = config.sample.scale
            has_null_indicator = hasattr(config.nnet.model_args, "cfg_indicator")

            ode_solver = ODEEulerFlowMatchingSolver(nnet_ema, step_size_type="step_in_dsigma", guidance_scale=_cfg)
            _z, _ = ode_solver.sample(
                x_T=_z_init, batch_size=_n_samples, sample_steps=_sample_steps,
                unconditional_guidance_scale=_cfg, has_null_indicator=has_null_indicator,
                image_latent=image_latent, use_cross_atten_mask=use_cross_atten_mask,
            )
            image_unprocessed = decode(_z)
            return image_unprocessed

    def eval_step(n_samples, sample_steps):
        """Generate ``n_samples`` images, compute FID on the main process and return it.

        The FID is computed on the main process only (other processes
        contribute 0) and then sum-reduced, so every process returns the same
        value. On the main process the score and the generated images are also
        logged to ``eval.log`` and wandb.
        """
        # Ensure n_samples is not greater than the size of the test dataset.
        if hasattr(test_dataset, 'num_samples'):
            test_dataset_size = test_dataset.num_samples
            if n_samples > test_dataset_size:
                logging.warning(f"n_samples ({n_samples}) is greater than the size of the test dataset ({test_dataset_size}), using the test dataset size")
                n_samples = test_dataset_size
        else:
            logging.info(f"Skip dataset size check, using n_samples={n_samples}")

        logging.info(f'eval_step: n_samples={n_samples}, sample_steps={sample_steps}, algorithm=ODE_Euler_Flow_Matching_Solver, '
                     f'mini_batch_size={config.sample.mini_batch_size}')

        def sample_fn(_n_samples, return_caption=False, config=None):
            """Produce one mini-batch: ``(generated, ground_truth, input_for_vis)``.

            ``config`` is accepted for call compatibility and is not used here.
            """
            assert not return_caption
            _batch_data = next(context_generator)
            if len(_batch_data) == 4:
                _input_img, _output_img, _input_img_tensor, _batch_type = _batch_data
            else:
                _input_img, _output_img, _input_img_tensor = _batch_data
                _batch_type = None

            # Keep a copy for visualisation before moving to device.
            _input_img_for_vis = _input_img_tensor.clone() if isinstance(_input_img_tensor, torch.Tensor) else _input_img_tensor

            if isinstance(_input_img_tensor, torch.Tensor):
                _input_img_tensor = _input_img_tensor.to(device, non_blocking=True)
                if _input_img_tensor.dim() == 3:
                    _input_img_tensor = _input_img_tensor.unsqueeze(0)

            input_moments_256 = autoencoder(_input_img_tensor, fn='encode_moments').detach()
            _input_image_latent = autoencoder.sample(input_moments_256)

            _context, _token_mask = utils.get_input_image_embeddings_and_masks(
                batch_input_images=_input_img,
                vl_chat_processor=vl_chat_processor,
                vl_gpt=vl_gpt,
                device=device,
                question="",
                num_image_tokens=576,
                output_tokens=576,
                accelerator=accelerator,
                cached_input_ids=cached_training_input_ids,
            )
            assert _context.size(0) == _n_samples

            use_cross_atten_mask = utils.build_cross_atten_mask_from_batch_type(_batch_type, _n_samples, device)
            generated_samples = ode_fm_solver_sample(
                nnet_ema, _n_samples, sample_steps,
                context=_context, token_mask=_token_mask,
                image_latent=_input_image_latent, use_cross_atten_mask=use_cross_atten_mask,
            )

            if isinstance(_output_img, torch.Tensor):
                _output_img = _output_img.to(device, non_blocking=True)
                if _output_img.dim() == 3:
                    _output_img = _output_img.unsqueeze(0)

            if isinstance(_input_img_for_vis, torch.Tensor):
                _input_img_for_vis = _input_img_for_vis.to(device, non_blocking=True)
                if _input_img_for_vis.dim() == 3:
                    _input_img_for_vis = _input_img_for_vis.unsqueeze(0)
                # Force the visualisation batch to exactly _n_samples entries:
                # repeat the last image when short, truncate when long.
                if _input_img_for_vis.size(0) != _n_samples:
                    if _input_img_for_vis.size(0) < _n_samples:
                        pad = _input_img_for_vis[-1:].expand(_n_samples - _input_img_for_vis.size(0), -1, -1, -1)
                        _input_img_for_vis = torch.cat([_input_img_for_vis, pad], dim=0)
                    else:
                        _input_img_for_vis = _input_img_for_vis[:_n_samples]

            return generated_samples, _output_img, _input_img_for_vis

        with tempfile.TemporaryDirectory() as temp_path:
            # Everything that reads `path` must stay inside this block, because
            # `path` may be the temporary directory.
            path = config.sample.path or temp_path
            if accelerator.is_main_process:
                os.makedirs(path, exist_ok=True)
            utils.sample2dir_with_gt(accelerator, path, n_samples, config.sample.mini_batch_size, sample_fn, dataset.unpreprocess, config=config)

            _fid = 0
            if accelerator.is_main_process:
                inception_ckpt_path = getattr(config, 'inception_ckpt_path', None)
                _fid = calculate_fid_given_paths((dataset.fid_stat, path), inception_ckpt_path=inception_ckpt_path)
                logging.info(f'step={train_state.step} fid{n_samples}={_fid}')
                with open(os.path.join(config.workdir, 'eval.log'), 'a') as f:
                    print(f'step={train_state.step} fid{n_samples}={_fid}', file=f)
                wandb.log({f'fid{n_samples}': _fid}, step=train_state.step)

                eval_images = []
                for i in range(n_samples):
                    img_path = os.path.join(path, f"{i}.png")
                    if os.path.exists(img_path):
                        # Close the file handle right after decoding; convert()
                        # returns an independent in-memory image.
                        with Image.open(img_path) as img_file:
                            img_pil = img_file.convert("RGB")
                        eval_images.append(wandb.Image(img_pil, caption=f"eval_sample_{i}_step_{train_state.step}"))
                if eval_images:
                    wandb.log({f'eval_samples_{n_samples}_step_{train_state.step}': eval_images}, step=train_state.step)
                    logging.info(f'Uploaded {len(eval_images)} evaluation samples to wandb at step {train_state.step}')

            # Non-main processes hold 0, so the sum equals the main-process FID.
            _fid = torch.tensor(_fid, device=device)
            _fid = accelerator.reduce(_fid, reduction='sum')

        return _fid.item()

    logging.info(f'Start fitting, step={train_state.step}, mixed_precision={config.mixed_precision}')

    step_fid = []
    # Defined up front so the cleanup after the loop is valid even when the
    # loop body never runs (e.g. resuming a run that already reached n_steps).
    metrics = None
    while train_state.step < config.train.n_steps:
        nnet.train()
        batch = next(data_generator)

        # Normalise 3-tuples to the 4-tuple layout expected by train_step.
        if len(batch) == 3:
            batch = (batch[0], batch[1], batch[2], None)
        elif len(batch) != 4:
            raise ValueError(f"Unexpected batch length: {len(batch)}, expected 3 or 4")

        if isinstance(batch[1], torch.Tensor) and batch[1].device != device:
            batch = (batch[0], batch[1].to(device, non_blocking=True),
                     batch[2].to(device, non_blocking=True) if isinstance(batch[2], torch.Tensor) else batch[2],
                     batch[3])
        elif isinstance(batch[2], torch.Tensor) and batch[2].device != device:
            batch = (batch[0], batch[1], batch[2].to(device, non_blocking=True), batch[3])

        metrics = train_step(batch)
        nnet.eval()

        if accelerator.is_main_process and train_state.step % config.train.log_interval == 0:
            logging.info(utils.dct2str(dict(step=train_state.step, **metrics)))
            logging.info(config.workdir)
            wandb.log(metrics, step=train_state.step)

        # Save a grid of sampled images for visual inspection.
        if train_state.step % config.train.eval_interval == 0:
            torch.cuda.empty_cache()
            logging.info('Save a grid of images...')
            if not hasattr(dataset, "vis_image_paths"):
                raise NotImplementedError()

            vis_image_paths = dataset.vis_image_paths[:config.train.n_samples_eval]
            use_cross_atten_mask = utils.build_cross_atten_mask_from_paths(vis_image_paths, device)
            vis_input_image_latent = utils.load_images_as_latents(
                vis_image_paths, config.dataset.resolution, autoencoder, device
            )
            contexts, token_mask = utils.get_input_image_embeddings_and_masks(
                batch_input_images=vis_image_paths,
                vl_chat_processor=vl_chat_processor,
                vl_gpt=vl_gpt,
                device=device,
                question="",
                num_image_tokens=576,
                output_tokens=576,
                accelerator=accelerator,
                cached_input_ids=cached_training_input_ids,
            )
            # Use the number of images actually available: the slice above can
            # be shorter than n_samples_eval, and the sample count must match
            # the context batch.
            samples = ode_fm_solver_sample(
                nnet_ema, _n_samples=len(vis_image_paths), _sample_steps=50,
                context=contexts, token_mask=token_mask,
                image_latent=vis_input_image_latent, use_cross_atten_mask=use_cross_atten_mask,
            )
            samples_unpreprocessed = dataset.unpreprocess(samples)

            if accelerator.is_main_process:
                input_images_pil = dataset.get_vis_images_as_pil(max_images=config.train.n_samples_eval)
                gt_images_pil = (
                    dataset.get_vis_output_images_as_pil(max_images=config.train.n_samples_eval)
                    if hasattr(dataset, "get_vis_output_images_as_pil") else []
                )
                target_device = samples_unpreprocessed[0].device if len(samples_unpreprocessed) > 0 else accelerator.device
                utils.save_vis_grid_and_log(
                    samples_unpreprocessed, input_images_pil, gt_images_pil,
                    config.sample_dir, train_state.step, wandb, target_device,
                )
            accelerator.wait_for_everyone()
            torch.cuda.empty_cache()

        # Save a checkpoint and evaluate it.
        if train_state.step % config.train.save_interval == 0 or train_state.step == config.train.n_steps:
            torch.cuda.empty_cache()
            logging.info(f'Save and eval checkpoint {train_state.step}...')
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                utils.clean_stale_ckpt_files(config.ckpt_root)
            accelerator.wait_for_everyone()
            if accelerator.is_main_process:
                ckpt_path = os.path.join(config.ckpt_root, f'{train_state.step}.ckpt')
                # Deliberately best-effort: a failed save is logged and
                # training continues.
                try:
                    train_state.save(ckpt_path)
                except Exception as e:
                    logging.error(f'Failed to save checkpoint at step {train_state.step}: {e}')
                    logging.warning('Continuing training despite checkpoint save failure')
            accelerator.wait_for_everyone()
            fid = eval_step(n_samples=30000, sample_steps=50)  # calculate fid of the saved checkpoint
            step_fid.append((train_state.step, fid))
            torch.cuda.empty_cache()
        # NOTE(review): placed at loop level (one barrier per step); confirm
        # this was not meant to sit inside the save branch above.
        accelerator.wait_for_everyone()

    logging.info(f'Finish fitting, step={train_state.step}')
    logging.info(f'step_fid: {step_fid}')

    if step_fid:
        # min() returns the first entry with the lowest FID, matching a stable sort.
        step_best = min(step_fid, key=lambda x: x[1])[0]
        logging.info(f'step_best: {step_best}')
        train_state.load(os.path.join(config.ckpt_root, f'{step_best}.ckpt'))
    else:
        # No checkpoint was scored in this run, so there is nothing to select;
        # the final evaluation uses the weights currently held by train_state.
        logging.warning('No checkpoint was evaluated in this run; skipping best-checkpoint reload')

    del metrics
    accelerator.wait_for_everyone()
    eval_step(n_samples=config.sample.n_samples, sample_steps=config.sample.sample_steps)