|
|
"""
|
|
|
Integrate numerical values for some iterations
|
|
|
Typically used for loss computation / logging to tensorboard
|
|
|
Call finalize and create a new Integrator when you want to display/log
|
|
|
"""
|
|
|
|
|
|
import torch
|
|
|
|
|
|
|
|
|
class Integrator:
    """Accumulate scalar metrics over several iterations and log their means.

    Values are summed per key together with a per-key count of how many
    values were added. ``finalize`` divides each sum by its count and hands
    the result to ``logger.log_metrics``.
    """

    def __init__(self, logger, distributed=True, local_rank=0, world_size=1):
        """Create an empty integrator.

        Args:
            logger: Object exposing ``log_metrics(prefix, key, value, it, f)``.
            distributed: When true, ``finalize`` reduces every average across
                processes with ``torch.distributed`` before logging.
            local_rank: Rank of this process; only rank 0 logs when
                ``distributed`` is true.
            world_size: Divisor applied to the reduced sum in ``finalize``.
        """
        self.values = {}   # key -> running sum of the added values
        self.counts = {}   # key -> number of values added for that key
        self.hooks = []    # callables evaluated at the start of finalize()

        self.logger = logger

        self.distributed = distributed
        self.local_rank = local_rank
        self.world_size = world_size

    def add_tensor(self, key, tensor):
        """Add one observation for ``key``.

        Args:
            key: Metric name.
            tensor: A Python number (``int``/``float``, including subclasses
                such as ``bool``), or any object supporting
                ``.mean().item()`` (e.g. a ``torch.Tensor``), which is
                reduced to its mean.

        Raises:
            AttributeError: If ``tensor`` is neither a number nor an object
                with ``.mean().item()``. The accumulator is left untouched
                in that case.
        """
        # Convert first, so that a failing conversion cannot leave
        # ``counts`` and ``values`` out of sync with each other.
        if isinstance(tensor, (int, float)):
            scalar = tensor
        else:
            scalar = tensor.mean().item()

        if key in self.values:
            self.counts[key] += 1
            self.values[key] += scalar
        else:
            self.counts[key] = 1
            self.values[key] = scalar

    def add_dict(self, tensor_dict):
        """Add every ``(key, value)`` pair of ``tensor_dict`` via ``add_tensor``."""
        for k, v in tensor_dict.items():
            self.add_tensor(k, v)

    def add_hook(self, hook):
        """Register one hook, or a list/tuple of hooks.

        A hook computes a new metric from the values in the dict: it takes
        the dict of accumulated sums as its argument and returns a
        ``(k, v)`` tuple, e.g. for computing IoU.
        """
        if isinstance(hook, (list, tuple)):
            self.hooks.extend(hook)
        else:
            self.hooks.append(hook)

    def reset_except_hooks(self):
        """Discard all accumulated sums and counts; registered hooks are kept."""
        self.values = {}
        self.counts = {}

    def finalize(self, prefix, it, f=None):
        """Run the hooks, then log the mean of every non-hidden key.

        Args:
            prefix: Forwarded to ``logger.log_metrics``.
            it: Forwarded to ``logger.log_metrics`` (presumably the
                iteration number -- confirm against the logger).
            f: Forwarded unchanged to ``logger.log_metrics``.
        """
        # Hooks see the raw sums (not the averages) and contribute one more
        # observation each through the regular add_tensor path.
        for hook in self.hooks:
            k, v = hook(self.values)
            self.add_tensor(k, v)

        for k, v in self.values.items():
            # Keys starting with 'hide' are bookkeeping-only and never logged.
            if k[:4] == 'hide':
                continue

            avg = v / self.counts[k]

            if self.distributed:
                # reduce() is a collective call: every process must take
                # part, even though only rank 0 logs the result.
                avg = torch.tensor(avg).cuda()
                torch.distributed.reduce(avg, dst=0)

                # NOTE(review): the reduce targets global rank 0 while this
                # checks local_rank; on multi-node runs these may differ --
                # confirm how local_rank is supplied by the caller.
                if self.local_rank == 0:
                    avg = (avg / self.world_size).cpu().item()
                    self.logger.log_metrics(prefix, k, avg, it, f)
            else:
                self.logger.log_metrics(prefix, k, avg, it, f)
|
|
|
|
|
|
|