import datetime
import io
import logging
import os
import shutil
import time
from os.path import join

import pandas as pd
import torch
import torch.backends.cudnn as cudnn
import torch.distributed as dist
import wandb

from dataset import (MetaLoader, MetaLoader_rs2, create_dataset, create_loader,
                     create_sampler, create_stateful_sampler)
from dataset.serialize import local_broadcast_process_authkey
from models import *
from tasks.retrieval_utils import evaluation_wrapper
from tasks.shared_utils import get_media_types, setup_model
from utils.basic_utils import (MetricLogger, SmoothedValue,
                               remove_files_if_exist, setup_seed)
from utils.config_utils import setup_main
from utils.distributed import get_rank, get_world_size, is_main_process
from utils.logger import log_dict_to_wandb, setup_wandb

# petrel_client (a Ceph/S3 storage client) is optional: when it is missing,
# every Ceph upload below falls back to saving checkpoints locally.
try:
    from petrel_client.client import Client
except ImportError:
    Client = None

logger = logging.getLogger(__name__)

ceph_ckpt_bucket = "shdd:s3://avp_ckpt"

def train(
    model,
    train_loaders,
    optimizer,
    tokenizer,
    epoch,
    global_step,
    device,
    scheduler,
    scaler,
    config,
    skip_num=0,
):
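    """Train for one epoch over the mixed media-type loaders.

    `global_step` counts batches across the whole run; `skip_num` is forwarded
    to MetaLoader_rs2 (iter-train mode only) so that a resumed job can skip
    batches it has already seen within the current epoch. Returns the updated
    `global_step`.
    """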
    try:
        ceph_ckpt_path = f"{ceph_ckpt_bucket}/{config.output_dir.split('/')[-3]}/{config.output_dir.split('/')[-2]}/{config.output_dir.split('/')[-1]}"
        client_ckpt = Client(conf_path='~/petreloss.conf')
    except Exception as e:
        print(e)
        logger.info("Ceph is not working!!!")

    if config.use_half_precision:
        if config.get('use_bf16', False):
            cast_dtype = torch.bfloat16
        else:
            cast_dtype = torch.float16
    else:
        cast_dtype = None

    model.train()

    metric_logger = MetricLogger(delimiter=" ")
    metric_logger.add_meter("lr", SmoothedValue(window=100, fmt="{value:.6f}"))
    metric_logger.add_meter("temperature", SmoothedValue(window=100, fmt="{value:.4f}"))
    loss_names = ["loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]
    if config.get("use_raw_text", False):  # for cosa
        loss_names = loss_names + ["c_loss_" + k for k, v in config.criterion.loss_weight.items() if v != 0]

    uta_all = config.criterion.get('uta_all', False)

    media_types = get_media_types(train_loaders)

    for name in loss_names:
        for m in media_types:
            metric_logger.add_meter(
                f"{m}-{name}", SmoothedValue(window=100, fmt="{value:.4f}")
            )

    header = f"Train Epoch: [{epoch}]"
    log_freq = config.log_freq

    if config.distributed:
        for d in train_loaders:
            d.sampler.set_epoch(epoch)
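    # Two iteration strategies: MetaLoader_rs2 is the resumable variant (it
    # honors skip_num to fast-forward past already-consumed batches), while
    # plain MetaLoader always starts each media type's loader from scratch.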
    if config.get('use_iter_train', False):
        train_loader = MetaLoader_rs2(name2loader=dict(list(zip(media_types, train_loaders))), skip_num=skip_num)
    else:
        train_loader = MetaLoader(name2loader=dict(list(zip(media_types, train_loaders))))

    model_without_ddp = model.module if config.distributed else model
    iterator = metric_logger.log_every(train_loader, log_freq, header)
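    # When resuming without the stateful loader, fast-forward manually: the
    # batch index within the current epoch is global_step modulo the epoch
    # length, and earlier batches are skipped at the top of the loop below.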
    begin_step = global_step % len(train_loader)
    logger.info(f"Epoch={epoch}, begin_step={begin_step} save_ckpt_iter={config.get('save_ckpt_iter', None)}")
    for local_step, (media_type, (media, text, idx)) in enumerate(iterator):
        if local_step < begin_step:
            logger.warning(f"Jump local_step: {local_step} (begin_step={begin_step})!!!")
            metric_logger.update(lr=optimizer.param_groups[0]["lr"])
            continue

        if config.get("save_ckpt_iter", None) is not None:  # and not is_iter_resume
            if local_step != 0 and local_step % config.get("save_ckpt_iter") == 0:
                if hasattr(config, "deepspeed") and config.deepspeed.enable:
                    tag = f"ckpt_e{epoch:02d}_local{local_step}_global{global_step}"
                    client_state = {'epoch': epoch, 'global_step': global_step, 'local_step': local_step}
                    model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, client_state=client_state)
                    logger.info(f"save ckpt file to local ({config.output_dir}/{tag})!!!")
                elif is_main_process():
                    state_dict = model_without_ddp.state_dict()
                    for k in config.get("no_save_params_prefix", []):
                        kk = [x for x in state_dict.keys() if x.startswith(k)]
                        logger.info(f"Not saving {len(kk)} params with prefix {k}")
                        for kkk in kk:
                            state_dict.pop(kkk)
                    save_obj = {
                        "model": state_dict,
                        "optimizer": optimizer.state_dict(),
                        "scheduler": scheduler.state_dict(),
                        "scaler": scaler.state_dict(),
                        "config": config,
                        "epoch": epoch,
                        "local_step": local_step,
                        "global_step": global_step,
                    }
                    try:
                        with io.BytesIO() as buffer:
                            torch.save(save_obj, buffer)
                            client_ckpt.put(f"{ceph_ckpt_path}/ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth", buffer.getvalue())
                        logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth)!!!")
                    except Exception as e:
                        print(e)
                        torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth"))
                        logger.warning(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_{epoch:02d}_local{local_step}_global{global_step}.pth')})!!!")

        if media_type == 'audio_video':
            # when audio comes as a pair, the second element (presumably a
            # mask) keeps its original dtype
            if isinstance(media[0], list):
                assert len(media[0]) == 2
                audio = [media[0][0].to(device, dtype=cast_dtype, non_blocking=True),
                         media[0][1].to(device, non_blocking=True)]
            else:
                audio = media[0].to(device, dtype=cast_dtype, non_blocking=True)
            video = media[1].to(device, dtype=cast_dtype, non_blocking=True)
            media = [audio, video]
        else:
            media = media.to(device, dtype=cast_dtype, non_blocking=True)
        idx = idx.to(device, non_blocking=True)
| if config.get("use_raw_text", False) or config.get("use_cosa", False): | |
| max_length = config.inputs.max_txt_l[media_type] | |
| else: | |
| if type(text) is dict: | |
| text_input = {} | |
| for k in text.keys(): | |
| text_input[k] = tokenizer( | |
| text[k], | |
| padding="max_length", | |
| truncation=True, | |
| max_length=config.inputs.max_txt_l[media_type], | |
| return_tensors="pt", | |
| ).to( | |
| device) # change from "longest" to "max_length" | |
| else: | |
| text_input = tokenizer( | |
| text, | |
| padding="max_length", | |
| truncation=True, | |
| max_length=config.inputs.max_txt_l[media_type], | |
| return_tensors="pt", | |
| ).to( | |
| device) # change from "longest" to "max_length" | |
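        # Three update paths: DeepSpeed owns backward/step itself; bf16 and
        # fp32 use a plain backward (bf16 shares fp32's exponent range, so no
        # GradScaler is needed); fp16 goes through the GradScaler to keep
        # small gradients from underflowing.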
| if hasattr(config, "deepspeed") and config.deepspeed.enable: | |
| loss_dict = model(media, text_input, idx=idx, media_type=media_type) | |
| loss = sum(loss_dict.values()) | |
| model.backward(loss) | |
| model.step() | |
| else: # NOTE We shouldn't use scaler if we only involve bf16, check this! | |
| with torch.cuda.amp.autocast(enabled=config.use_half_precision, dtype=cast_dtype): | |
| loss_dict = model(media, text_input, idx=idx, media_type=media_type) | |
| loss = sum(loss_dict.values()) | |
| if not config.use_half_precision or config.get('use_bf16', False): | |
| optimizer.zero_grad() | |
| loss.backward() | |
| if config.optimizer.max_grad_norm > 0: | |
| torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) | |
| optimizer.step() | |
| scheduler.step() | |
| else: | |
| optimizer.zero_grad() | |
| scaler.scale(loss).backward() | |
| if config.optimizer.max_grad_norm > 0: | |
| scaler.unscale_(optimizer) | |
| torch.nn.utils.clip_grad_norm_(model.parameters(), config.optimizer.max_grad_norm) | |
| scaler.step(optimizer) | |
| scaler.update() | |
| scheduler.step() | |
        # logging
        for name in loss_names:
            if name in loss_dict.keys():
                value = loss_dict[name]
                value = value if isinstance(value, float) else value.item()
                metric_logger.update(**{f"{media_type}-{name}": value})
        metric_logger.update(lr=optimizer.param_groups[0]["lr"])
        metric_logger.update(temperature=model_without_ddp.temp.item())

        if is_main_process() and config.wandb.enable and global_step % log_freq == 0:
            try:
                logs = metric_logger.get_global_avg_dict()
                log_dict_to_wandb(logs, step=global_step, prefix="train/")
            except Exception as e:
                logger.warning("Wandb is not working!!!")
                print(e)

        global_step += 1

        if config.debug and global_step % 20 == 0:
            logger.info("debug mode, break training loop")
            break

        if config.debug and global_step % (2 * log_freq + 3) == 0:
            logger.info("debug mode, break training loop")
            break

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    logger.info(f"Averaged stats: {metric_logger.global_avg()}")
    return global_step

def setup_dataloaders(config, mode="pt"):
    # train datasets, create a list of data loaders
    logger.info(f"Creating dataset for {mode} use_iter_train={config.get('use_iter_train', False)}")
    train_datasets = create_dataset(f"{mode}_train", config)
    media_types = get_media_types(train_datasets)

    if config.get('use_iter_train', False):
        if config.distributed:
            batch_size = [config.inputs.batch_size[k] for k in media_types]  # batch size for each GPU
            samplers = create_stateful_sampler(train_datasets, batch_size)
        else:
            raise NotImplementedError
    else:
        if config.distributed:
            num_tasks = get_world_size()
            global_rank = get_rank()
            samplers = create_sampler(
                train_datasets, [True] * len(media_types), num_tasks, global_rank
            )
        else:
            samplers = [None] * len(media_types)

    train_loaders = create_loader(
        train_datasets,
        samplers,
        batch_size=[config.inputs.batch_size[k] for k in media_types],
        num_workers=[config.num_workers] * len(media_types),
        is_trains=[True] * len(media_types),
        collate_fns=[None] * len(media_types),
    )

    # test datasets, a mapping from dataset name to data loader
    test_datasets, test_dataset_names = create_dataset(f"{mode}_eval", config)
    test_loaders = create_loader(
        test_datasets,
        [None] * len(test_datasets),
        batch_size=[config.inputs.batch_size_test[d.media_type] for d in test_datasets],
        num_workers=[config.num_workers] * len(test_datasets),
        is_trains=[False] * len(test_datasets),
        collate_fns=[None] * len(test_datasets),
    )
    test_name2loaders = {k: v for k, v in zip(test_dataset_names, test_loaders)}
    return train_loaders, test_name2loaders, media_types

def main(config):
    if config.get('use_flash_sdp', False):
        torch.backends.cuda.enable_flash_sdp(enabled=True)
    elif config.get('use_mem_efficient_sdp', False):
        torch.backends.cuda.enable_mem_efficient_sdp(enabled=True)

    try:
        ceph_ckpt_path = f"{ceph_ckpt_bucket}/{config.output_dir.split('/')[-3]}/{config.output_dir.split('/')[-2]}/{config.output_dir.split('/')[-1]}"
        client_ckpt = Client(conf_path='~/petreloss.conf')
    except Exception as e:
        print(e)
        logger.info("Ceph is not working!!!")

    if is_main_process() and config.wandb.enable:
        try:
            run = setup_wandb(config)
            logger.info("Wandb is working!!!")
        except Exception as e:
            logger.warning("Wandb is not working!!!")
            print(e)

    is_pretrain = config.mode == "pt"

    logger.info(f"train_file: {config.train_file}")

    setup_seed(config.seed + get_rank())
    device = torch.device(config.device)

    train_loaders, test_name2loaders, train_media_types = setup_dataloaders(
        config, mode=config.mode
    )
    num_steps_per_epoch = sum(len(d) for d in train_loaders)
    config.scheduler.num_training_steps = num_steps_per_epoch * config.scheduler.epochs
    config.scheduler.num_warmup_steps = num_steps_per_epoch * config.scheduler.warmup_epochs

    # set cudnn.benchmark=True only when input size is fixed
    # https://discuss.pytorch.org/t/what-does-torch-backends-cudnn-benchmark-do/5936/3
    cudnn.benchmark = len(train_media_types) == 1

    print(f"\033[31m CURRENT NODE NAME: {os.environ['SLURMD_NODENAME']} dataloader is OK {datetime.datetime.now().strftime('%Y-%m-%d-%H_%M_%S')}!!! \033[0m")

    find_unused_parameters = config.model.get('find_unused_parameters', False)
    logger.info(f"find_unused_parameters={find_unused_parameters}")
    model_cls = eval(config.model.get('model_cls'))
    (
        model,
        model_without_ddp,
        optimizer,
        scheduler,
        scaler,
        tokenizer,
        start_epoch,
        global_step,
    ) = setup_model(
        config,
        model_cls=model_cls,
        add_decoder=False,
        pretrain=is_pretrain,
        find_unused_parameters=find_unused_parameters,
    )

    if is_main_process() and config.wandb.enable:
        try:
            wandb.watch(model)
        except Exception as e:
            logger.warning("Wandb is not working!!!")
            print(e)

    best = 0
    best_epoch = 0
    if isinstance(config.best_key, str):
        best_key = [config.best_key, "t2v_r1"]
    elif isinstance(config.best_key, list) and len(config.best_key) == 2:
        best_key = config.best_key
    else:
        raise NotImplementedError(config.best_key)

    best_ckpt_id = datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    logger.info(f"Start training, start_epoch={start_epoch}")
    start_time = time.time()
    start_step = start_epoch * num_steps_per_epoch
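    # setup_model restores global_step when resuming from a checkpoint, so
    # global_step - start_step below is the number of batches already consumed
    # in the current epoch, which train() skips via skip_num.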
    for epoch in range(start_epoch, config.scheduler.epochs):
        if not config.evaluate:
            global_step = train(
                model,
                train_loaders,
                optimizer,
                tokenizer,
                epoch,
                global_step,
                device,
                scheduler,
                scaler,
                config,
                skip_num=global_step - start_step,
            )
| if hasattr(config, "deepspeed") and config.deepspeed.enable: | |
| if config.get("save_latest", False): | |
| tag = "ckpt_latest" | |
| else: | |
| tag = f"ckpt_{epoch:02d}" | |
| client_state = {'epoch': epoch, 'global_step': global_step} | |
| model.save_checkpoint(config.output_dir, tag=tag, save_latest=False, client_state=client_state) | |
| logger.info(f"save ckpt file to local ({config.output_dir}/{tag})!!!") | |
| if is_main_process() and config.get("delete_ds_optim_states", False): | |
| if config.get("save_latest", False): | |
| if epoch == (config.scheduler.epochs - 1): # last epoch | |
| last_tag = "ckpt_latest" | |
| last_ckpt_path = f"{config.output_dir}/{last_tag}" | |
| if os.path.exists(last_ckpt_path): | |
| logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!") | |
| for file in os.listdir(last_ckpt_path): | |
| if file.endswith('optim_states.pt'): | |
| os.remove(os.path.join(last_ckpt_path, file)) | |
| else: | |
| last_tag = f"ckpt_{epoch-1:02d}" | |
| last_ckpt_path = f"{config.output_dir}/{last_tag}" | |
| if os.path.exists(last_ckpt_path): | |
| logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!") | |
| for file in os.listdir(last_ckpt_path): | |
| if file.endswith('optim_states.pt'): | |
| os.remove(os.path.join(last_ckpt_path, file)) | |
| if epoch == (config.scheduler.epochs - 1): # last epoch | |
| last_tag = f"ckpt_{epoch:02d}" | |
| last_ckpt_path = f"{config.output_dir}/{last_tag}" | |
| if os.path.exists(last_ckpt_path): | |
| logger.info(f"remove optim states in ({config.output_dir}/{last_tag})!!!") | |
| for file in os.listdir(last_ckpt_path): | |
| if file.endswith('optim_states.pt'): | |
| os.remove(os.path.join(last_ckpt_path, file)) | |
        if is_main_process():
            if not (hasattr(config, "deepspeed") and config.deepspeed.enable):
                state_dict = model_without_ddp.state_dict()
                for k in config.get("no_save_params_prefix", []):
                    kk = [x for x in state_dict.keys() if x.startswith(k)]
                    logger.info(f"Not saving {len(kk)} params with prefix {k}")
                    for kkk in kk:
                        state_dict.pop(kkk)
                save_obj = {
                    "model": state_dict,
                    "optimizer": optimizer.state_dict(),
                    "scheduler": scheduler.state_dict(),
                    "scaler": scaler.state_dict(),
                    "config": config,
                    "epoch": epoch,
                    "global_step": global_step,
                }
                try:
                    with io.BytesIO() as buffer:
                        torch.save(save_obj, buffer)
                        if config.get("save_latest", False):
                            client_ckpt.put(f"{ceph_ckpt_path}/ckpt_latest.pth", buffer.getvalue())
                            logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_latest.pth)!!!")
                        else:
                            client_ckpt.put(f"{ceph_ckpt_path}/ckpt_{epoch:02d}.pth", buffer.getvalue())
                            logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_{epoch:02d}.pth)!!!")
                except Exception as e:
                    print(e)
                    if config.get("save_latest", False):
                        torch.save(save_obj, join(config.output_dir, "ckpt_latest.pth"))
                        logger.warning(f"Ceph is not working, save to local ({join(config.output_dir, 'ckpt_latest.pth')})!!!")
                    else:
                        torch.save(save_obj, join(config.output_dir, f"ckpt_{epoch:02d}.pth"))
                        logger.warning(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_{epoch:02d}.pth')})!!!")
| if config.get("jump_evaluate", False) and not config.evaluate: | |
| logger.warn(f"Jump the evaluation'))!!!") | |
| else: | |
| try: | |
| eval_res = {} | |
| for test_name, test_loader in test_name2loaders.items(): | |
| if test_name not in config.test_types: | |
| logger.info( | |
| f"Skip eval {test_name} split. All test_types {config.test_types}" | |
| ) | |
| continue | |
| res = evaluation_wrapper( | |
| model_without_ddp, test_loader, tokenizer, device, config, prefix=test_name | |
| ) | |
| eval_res.update(res) | |
| if is_main_process(): | |
| # log to wandb | |
| if config.wandb.enable: | |
| try: | |
| for p, v in eval_res.items(): | |
| log_dict_to_wandb(v, step=global_step, prefix=p) | |
| except Exception as e: | |
| logger.warn("Wandb is not working!!!") | |
| print(e) | |
| try: | |
| cur_recall = eval_res[best_key[0]][best_key[1]] | |
| except Exception as e: | |
| logger.warn(e) | |
| print(e) | |
| # print(eval_res) | |
| cur_recall = best - 1 | |
| eval_res = pd.DataFrame(eval_res) | |
| logger.info(f"Epoch {epoch}") | |
| logger.info(f"\n{eval_res.transpose().to_string(max_cols=30)}") | |
| eval_res.to_json(join(config.output_dir, "eval_res_latest.json")) | |
                    state_dict = model_without_ddp.state_dict()
                    for k in config.get("no_save_params_prefix", []):
                        kk = [x for x in state_dict.keys() if x.startswith(k)]
                        logger.info(f"Not saving {len(kk)} params with prefix {k}")
                        for kkk in kk:
                            state_dict.pop(kkk)

                    if not config.evaluate and cur_recall > best:
                        if not (hasattr(config, "deepspeed") and config.deepspeed.enable):
                            try:
                                with io.BytesIO() as buffer:
                                    torch.save(save_obj, buffer)
                                    client_ckpt.put(f"{ceph_ckpt_path}/ckpt_best_{best_ckpt_id}.pth", buffer.getvalue())
                                logger.info(f"Save to ceph ({ceph_ckpt_path}/ckpt_best_{best_ckpt_id}.pth)!!!")
                            except Exception as e:
                                print(e)
                                torch.save(save_obj, join(config.output_dir, f"ckpt_best_{best_ckpt_id}.pth"))
                                logger.warning(f"Ceph is not working, save to local ({join(config.output_dir, f'ckpt_best_{best_ckpt_id}.pth')})!!!")
                        else:
                            now_ckpt_path = f"{config.output_dir}/{tag}/mp_rank_00_model_states.pt"
                            best_ckpt_path = f"{config.output_dir}/best_mp_rank_00_model_states.pt"
                            if os.path.exists(now_ckpt_path):
                                shutil.copy(now_ckpt_path, best_ckpt_path)
                                logger.info(f"Copy {now_ckpt_path} to {best_ckpt_path}!!!")
                            else:
                                logger.warning(f"Can't find {now_ckpt_path}, something is wrong!!!")
                        eval_file = "eval_res_best.json"
                        eval_res.to_json(join(config.output_dir, eval_file))
                        best = cur_recall
                        best_epoch = epoch
            except Exception as e:
                logger.warning("Something went wrong during eval or save!!!")
                print(e)
                if config.evaluate:
                    raise e

        if config.evaluate:
            break
        start_step = global_step
        dist.barrier()

    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info(f"Training time {total_time_str}")
    logger.info(f"best epoch {best_epoch} [best_key {best_key}]")
    logger.info(f"Checkpoints and Logs saved at {config.output_dir}")

    if is_main_process() and config.wandb.enable:
        try:
            run.finish()
        except Exception as e:
            logger.warning("Wandb is not working!!!")
            print(e)


if __name__ == "__main__":
    print(f"\033[31m NODE LIST: {os.environ['SLURM_NODELIST']} \033[0m")
    logger.info(f"NODE LIST: {os.environ['SLURM_NODELIST']}")
    cfg = setup_main()
    local_broadcast_process_authkey()
    main(cfg)