# EAR_challenge/utils/tools.py
# NOTE: web-scrape header ("srijandas07's picture / Upload 52 files / 1c990f3 verified")
# commented out so this file is importable as a Python module.
import numpy
import torch.distributed as dist
import torch
import clip
import os
def reduce_tensor(tensor, n=None):
    """Sum-reduce ``tensor`` across all distributed processes and divide by ``n``.

    Args:
        tensor: tensor to reduce; every rank must call with the same shape.
        n: divisor applied after the all-reduce. Defaults to the world size,
            which turns the cross-process sum into a mean.

    Returns:
        A new tensor holding the reduced (averaged) value; the input is
        left untouched.
    """
    divisor = dist.get_world_size() if n is None else n
    reduced = tensor.clone()  # all_reduce is in-place, so work on a copy
    dist.all_reduce(reduced, op=dist.ReduceOp.SUM)
    return reduced / divisor
class AverageMeter:
    """Tracks the latest value, running sum, count, and average of a metric.

    Attributes:
        val: most recently recorded value.
        sum: weighted sum of all recorded values.
        count: total weight (number of samples) recorded.
        avg: running average, ``sum / count``.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Zero out all statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average.

        Args:
            val: new metric value (typically already averaged over a batch).
            n: weight of this observation, e.g. the batch size.
        """
        self.val = val
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

    def sync(self):
        """Average the statistics across all distributed processes.

        Requires an initialized process group and CUDA. ``val`` is averaged
        over ranks; ``sum`` and ``count`` are globally summed so ``avg``
        becomes the true global average.
        """
        # Fix: the original computed an unused local `rank = dist.get_rank()`.
        world_size = dist.get_world_size()
        val = torch.tensor(self.val).cuda()
        sum_v = torch.tensor(self.sum).cuda()
        count = torch.tensor(self.count).cuda()
        self.val = reduce_tensor(val, world_size).item()
        self.sum = reduce_tensor(sum_v, 1).item()
        self.count = reduce_tensor(count, 1).item()
        self.avg = self.sum / self.count
def epoch_saving(config, epoch, model, max_accuracy, optimizer, lr_scheduler, logger, working_dir, is_best):
    """Write a full training checkpoint for ``epoch`` into ``working_dir``.

    Saves model/optimizer/scheduler state plus bookkeeping (epoch, best
    accuracy, config) to ``ckpt_epoch_{epoch}.pth``; when ``is_best`` is
    true, the same state is additionally written to ``best.pth``.

    Args:
        config: run configuration, stored inside the checkpoint.
        epoch: epoch index used in the checkpoint filename.
        model / optimizer / lr_scheduler: objects providing ``state_dict()``.
        max_accuracy: best validation accuracy seen so far.
        logger: logger with an ``info`` method.
        working_dir: destination directory (must already exist).
        is_best: whether this epoch is the best so far.
    """
    state = {
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'lr_scheduler': lr_scheduler.state_dict(),
        'max_accuracy': max_accuracy,
        'epoch': epoch,
        'config': config,
    }
    ckpt_path = os.path.join(working_dir, f'ckpt_epoch_{epoch}.pth')
    logger.info(f"{ckpt_path} saving......")
    torch.save(state, ckpt_path)
    logger.info(f"{ckpt_path} saved !!!")

    if is_best:
        best_path = os.path.join(working_dir, 'best.pth')
        torch.save(state, best_path)
        logger.info(f"{best_path} saved !!!")
def load_checkpoint(config, model, optimizer, lr_scheduler, logger):
    """Resume training state from ``config.MODEL.RESUME`` if that file exists.

    Model weights are loaded non-strictly after dropping cached
    prompt-learner buffers (they are regenerated for the current class set,
    so stale copies must not be restored). Optimizer, scheduler, epoch and
    best accuracy are restored on a best-effort basis.

    Args:
        config: configuration exposing ``MODEL.RESUME`` (checkpoint path).
        model: module whose ``load_state_dict`` accepts ``strict=False``.
        optimizer / lr_scheduler: objects providing ``load_state_dict``.
        logger: logger with an ``info`` method.

    Returns:
        ``(start_epoch, max_accuracy)`` — ``(0, 0.0)`` when no checkpoint
        was found or the training state could not be restored.
    """
    if not os.path.isfile(config.MODEL.RESUME):
        logger.info("=> no checkpoint found at '{}'".format(config.MODEL.RESUME))
        # Fix: consistent (int, float) return across all exit paths.
        return 0, 0.

    logger.info(f"==============> Resuming from {config.MODEL.RESUME}....................")
    # NOTE(review): torch.load unpickles arbitrary objects — only resume
    # from trusted checkpoint files.
    checkpoint = torch.load(config.MODEL.RESUME, map_location='cpu')
    load_state_dict = checkpoint['model']

    # Remove the unwanted prompt-learner keys before loading.
    for key in ("module.prompt_learner.token_prefix",
                "module.prompt_learner.token_suffix",
                "module.prompt_learner.complete_text_embeddings"):
        load_state_dict.pop(key, None)

    msg = model.load_state_dict(load_state_dict, strict=False)
    logger.info(f"resume model: {msg}")
    try:
        optimizer.load_state_dict(checkpoint['optimizer'])
        lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
        start_epoch = checkpoint['epoch'] + 1
        max_accuracy = checkpoint['max_accuracy']
        logger.info(f"=> loaded successfully '{config.MODEL.RESUME}' (epoch {checkpoint['epoch']})")
        return start_epoch, max_accuracy
    except Exception as e:
        # Fix: was a bare `except:` that swallowed KeyboardInterrupt/SystemExit
        # and hid the failure reason entirely.
        logger.info(f"=> failed to restore optimizer/scheduler state: {e}")
        return 0, 0.
    finally:
        # Fix: cleanup was duplicated in both branches; `finally` runs it once.
        del checkpoint
        torch.cuda.empty_cache()  # no-op when CUDA is unused
def auto_resume_helper(output_dir):
    """Return the most recently modified ``*pth`` checkpoint in ``output_dir``.

    Args:
        output_dir: directory to scan (must exist).

    Returns:
        Full path of the newest checkpoint by mtime, or ``None`` when the
        directory contains no ``pth`` files.
    """
    candidates = [name for name in os.listdir(output_dir) if name.endswith('pth')]
    print(f"All checkpoints founded in {output_dir}: {candidates}")
    if not candidates:
        return None
    latest = max((os.path.join(output_dir, name) for name in candidates),
                 key=os.path.getmtime)
    print(f"The latest checkpoint founded: {latest}")
    return latest
def generate_text(data):
    """Tokenize every class name in ``data.classes`` with the CLIP tokenizer.

    Args:
        data: dataset-like object whose ``classes`` attribute yields
            ``(index, class_name)`` pairs — assumed from the pair unpacking
            below; confirm against the dataset definition.

    Returns:
        The concatenation of ``clip.tokenize`` outputs, one row of 77 token
        ids per class.
    """
    # Fix: was `f"{{}}"`, an f-string with nothing to interpolate; the
    # plain literal is the same identity template "{}".
    text_aug = "{}"
    return torch.cat([clip.tokenize(text_aug.format(c), context_length=77)
                      for i, c in data.classes])