Image Classification
English
TTA
ReservoirTTA / utils /misc.py
GuillaumeVray
Uploading files
02ba886
import torch
import logging
# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)
@torch.no_grad()
def ema_update_model(model_to_update, model_to_merge, momentum, device, update_all=False):
    """Blend the parameters of ``model_to_merge`` into ``model_to_update``.

    Each selected parameter of ``model_to_update`` becomes
    ``momentum * own_value + (1 - momentum) * other_value``, where
    ``other_value`` is the matching parameter of ``model_to_merge`` moved to
    ``device``. Parameters are paired positionally, in the order yielded by
    ``.parameters()``; only parameters are visited.

    Args:
        model_to_update: Model whose parameters are overwritten.
        model_to_merge: Model supplying the values that are mixed in.
        momentum: Weight kept from ``model_to_update``. When it is not
            strictly below 1.0, nothing is modified.
        device: Device the merged-in parameter data is moved to before mixing.
        update_all: If true, parameters with ``requires_grad=False`` are
            blended as well; otherwise they are skipped.

    Returns:
        ``model_to_update`` (the same object that was passed in).
    """
    # Nothing to blend unless momentum is strictly below 1.0.
    if not (momentum < 1.0):
        return model_to_update

    merge_weight = 1 - momentum
    param_pairs = zip(model_to_update.parameters(), model_to_merge.parameters())
    for target, source in param_pairs:
        # Frozen parameters are left alone unless the caller forces an update.
        if not (target.requires_grad or update_all):
            continue
        retained = momentum * target.data
        incoming = merge_weight * source.data.to(device)
        target.data = retained + incoming

    return model_to_update
def print_memory_info():
    """Log CUDA memory statistics and return them as a dict.

    Queries ``torch.cuda.memory_allocated``, ``max_memory_allocated``,
    ``memory_reserved`` and ``max_memory_reserved`` (each called with no
    arguments) and logs every value divided by 1e6 with an ``MB`` suffix,
    framed by two separator lines.

    Returns:
        dict mapping each metric name to the raw value returned by the
        corresponding ``torch.cuda`` function (not divided by 1e6).
    """
    metric_names = (
        'memory_allocated',
        'max_memory_allocated',
        'memory_reserved',
        'max_memory_reserved',
    )
    separator = '-' * 40

    logger.info(separator)
    mem_dict = {}
    for metric in metric_names:
        # Look the function up by attribute instead of eval()-ing a code
        # string: same call, no dynamic code execution.
        value = getattr(torch.cuda, metric)()
        mem_dict[metric] = value
        logger.info(f"{metric:>20s}: {value / 1e6:10.2f}MB")
    logger.info(separator)
    return mem_dict