|
|
import torch |
|
|
import logging |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
@torch.no_grad()
def ema_update_model(model_to_update, model_to_merge, momentum, device, update_all=False):
    """Blend the parameters of one model into another with an exponential moving average.

    Each selected parameter of ``model_to_update`` is replaced by
    ``momentum * own_value + (1 - momentum) * merged_value``, where
    ``merged_value`` is the matching parameter of ``model_to_merge`` moved to
    ``device``. Parameters are paired positionally via ``zip`` over both
    models' ``parameters()``.

    Args:
        model_to_update: Model whose parameters are overwritten.
        model_to_merge: Model supplying the values that are mixed in.
        momentum: Weight kept from the current parameters; when it is not
            below 1.0 the model is returned untouched.
        device: Device the merged parameter data is moved to before blending.
        update_all: If true, parameters with ``requires_grad == False`` are
            blended as well; otherwise they are skipped.

    Returns:
        ``model_to_update`` (the same object that was passed in).
    """
    # Nothing to blend when the old weights keep their full share.
    if not momentum < 1.0:
        return model_to_update

    merge_weight = 1 - momentum
    paired_params = zip(model_to_update.parameters(), model_to_merge.parameters())
    for target, source in paired_params:
        if not (target.requires_grad or update_all):
            continue
        incoming = source.data.to(device)
        target.data = momentum * target.data + merge_weight * incoming

    return model_to_update
|
|
|
|
|
|
|
|
def print_memory_info():
    """Log the current and peak CUDA memory counters and return them.

    Calls ``torch.cuda.memory_allocated``, ``max_memory_allocated``,
    ``memory_reserved`` and ``max_memory_reserved`` (each with its default
    device argument), logs every value divided by 1e6 with an "MB" suffix
    between two separator lines, and collects the raw values.

    Returns:
        dict: Maps each metric name to the raw value returned by the
        corresponding ``torch.cuda`` function (bytes, per the PyTorch
        documentation), in the order listed above.
    """
    separator = '-' * 40
    metrics = ('memory_allocated', 'max_memory_allocated', 'memory_reserved', 'max_memory_reserved')

    logger.info(separator)
    mem_dict = {}
    for metric in metrics:
        # Look the function up by attribute name; no need to build and
        # eval() source code for a plain attribute access.
        mem_dict[metric] = getattr(torch.cuda, metric)()
        logger.info(f"{metric:>20s}: {mem_dict[metric] / 1e6:10.2f}MB")
    logger.info(separator)
    return mem_dict
|
|
|