import os

import numpy
import torch
import torch.distributed as dist

import clip


def reduce_tensor(tensor, n=None):
    """All-reduce (sum) ``tensor`` across ranks and divide by ``n``.

    Args:
        tensor: tensor to reduce; must live on a device the backend supports.
        n: divisor applied after the sum. Defaults to the world size, which
           turns the global sum into a mean across ranks.

    Returns:
        A new tensor with the reduced value; the input is left untouched.
    """
    if n is None:
        n = dist.get_world_size()
    rt = tensor.clone()  # clone so the caller's tensor is not mutated in-place
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt = rt / n
    return rt


class AverageMeter:
    """Computes and stores the average and current value."""

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0    # most recent value
        self.avg = 0    # running average (sum / count)
        self.sum = 0    # weighted sum of observed values
        self.count = 0  # total number of observations

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def sync(self):
        """Synchronize statistics across all distributed ranks.

        ``val`` is averaged over ranks, while ``sum`` and ``count`` are
        summed (divisor 1), so the recomputed ``avg`` is the true global
        average. Requires an initialized process group and CUDA.
        """
        world_size = dist.get_world_size()
        val = torch.tensor(self.val).cuda()
        sum_v = torch.tensor(self.sum).cuda()
        count = torch.tensor(self.count).cuda()
        self.val = reduce_tensor(val, world_size).item()
        self.sum = reduce_tensor(sum_v, 1).item()
        self.count = reduce_tensor(count, 1).item()
        self.avg = self.sum / self.count


def epoch_saving(config, epoch, model, max_accuracy, optimizer, lr_scheduler,
                 logger, working_dir, is_best):
    """Save the full training state for ``epoch`` into ``working_dir``.

    Writes ``ckpt_epoch_{epoch}.pth`` and, when ``is_best`` is true, an
    additional copy named ``best.pth``.
    """
    save_state = {'model': model.state_dict(),
                  'optimizer': optimizer.state_dict(),
                  'lr_scheduler': lr_scheduler.state_dict(),
                  'max_accuracy': max_accuracy,
                  'epoch': epoch,
                  'config': config}
    save_path = os.path.join(working_dir, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{save_path} saving......")
    torch.save(save_state, save_path)
    logger.info(f"{save_path} saved !!!")
    if is_best:
        best_path = os.path.join(working_dir, 'best.pth')
        torch.save(save_state, best_path)
        logger.info(f"{best_path} saved !!!")


def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Resume model/optimizer/scheduler state from ``config.MODEL.RESUME``.

    Model weights are always loaded (non-strict). Optimizer/scheduler/epoch
    state is loaded best-effort: a checkpoint that lacks those entries
    (e.g. pretrained weights only) falls back to a fresh schedule.

    Returns:
        ``(start_epoch, max_accuracy)``; ``(0, 0.)`` when no training state
        could be resumed.
    """
    if os.path.isfile(config.MODEL.RESUME):
        logger.info(f"==============> Resuming form {config.MODEL.RESUME}....................")
        checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
        load_state_dict = checkpoint['model']

        # Drop cached prompt-learner buffers whose shape depends on the
        # class set of the checkpoint's dataset; they are rebuilt for the
        # current dataset, so loading them would conflict.
        for key in ("module.prompt_learner.token_prefix",
                    "module.prompt_learner.token_suffix",
                    "module.prompt_learner.complete_text_embeddings"):
            if key in load_state_dict:
                del load_state_dict[key]

        msg = model.load_state_dict(load_state_dict, strict=False)
        logger.info(f"resume model: {msg}")

        try:
            optimizer.load_state_dict(checkpoint['optimizer'])
            lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
            start_epoch = checkpoint['epoch'] + 1
            max_accuracy = checkpoint['max_accuracy']
            logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
            del checkpoint
            torch.cuda.empty_cache()
            return start_epoch, max_accuracy
        except Exception:
            # Deliberate best-effort: missing optimizer/scheduler/epoch keys
            # mean weights-only resume; start training from scratch.
            del checkpoint
            torch.cuda.empty_cache()
            return 0, 0.
    else:
        logger.info("=> no checkpoint found at '{}'".format(config.MODEL.RESUME))
        return 0, 0.


def auto_resume_helper(output_dir):
    """Return the most recently modified ``*.pth`` file in ``output_dir``.

    Returns ``None`` when the directory contains no checkpoint files.
    """
    checkpoints = os.listdir(output_dir)
    checkpoints = [ckpt for ckpt in checkpoints if ckpt.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {checkpoints}")
    if len(checkpoints) > 0:
        # Pick the newest checkpoint by file modification time.
        latest_checkpoint = max([os.path.join(output_dir, d) for d in checkpoints],
                                key=os.path.getmtime)
        print(f"The latest checkpoint founded: {latest_checkpoint}")
        resume_file = latest_checkpoint
    else:
        resume_file = None
    return resume_file


def generate_text(data):
    """Tokenize every class name in ``data.classes`` with CLIP's tokenizer.

    Assumes ``data.classes`` is an iterable of ``(index, name)`` pairs
    — TODO confirm against the dataset class.

    Returns:
        A tensor of shape (num_classes, 77) of CLIP token ids.
    """
    text_aug = "{}"  # identity template: the bare class name is the prompt
    classes = torch.cat([clip.tokenize(text_aug.format(c), context_length=77)
                         for i, c in data.classes])
    return classes