basketball_code / logger.py
youqiwong's picture
Upload folder using huggingface_hub
0c51b93 verified
from torch.utils.tensorboard import SummaryWriter
from collections import defaultdict
import json
import os
import csv
import shutil
import torch
import numpy as np
from termcolor import colored
# Console columns printed for every training dump, regardless of agent.
# Each entry is (meter key, short console label, format type); the format
# type is one of 'int', 'float' or 'time' as understood by
# MetersGroup._format.
COMMON_TRAIN_FORMAT = [
    # progress counters
    ('episode', 'E', 'int'),
    ('step', 'S', 'int'),
    # rewards
    ('episode_reward', 'R', 'float'),
    ('true_episode_reward', 'TR', 'float'),
    # feedback counters
    ('total_feedback', 'TF', 'int'),
    ('labeled_feedback', 'LR', 'int'),
    ('noisy_feedback', 'NR', 'int'),
    # timings, rendered with an " s" suffix
    ('duration', 'D', 'time'),
    ('total_duration', 'TD', 'time'),
]
# Console columns printed for every evaluation dump.
# Each entry is (meter key, short console label, format type).
COMMON_EVAL_FORMAT = [
    # progress counters
    ('episode', 'E', 'int'),
    ('step', 'S', 'int'),
    # evaluation results
    ('episode_reward', 'R', 'float'),
    ('true_episode_reward', 'TR', 'float'),
    ('true_episode_success', 'TS', 'float'),
]
# Extra training columns per agent name. Logger looks the agent up here
# and appends the matching list to COMMON_TRAIN_FORMAT, so the values must
# stay lists of (meter key, short console label, format type) tuples.
AGENT_TRAIN_FORMAT = {
    'sac': [
        ('batch_reward', 'BR', 'float'),
        # losses
        ('actor_loss', 'ALOSS', 'float'),
        ('critic_loss', 'CLOSS', 'float'),
        ('alpha_loss', 'TLOSS', 'float'),
        ('alpha_value', 'TVAL', 'float'),
        ('actor_entropy', 'AENT', 'float'),
        ('bc_loss', 'BCLOSS', 'float'),
    ],
    'ppo': [
        ('batch_reward', 'BR', 'float'),
    ],
}
class AverageMeter(object):
    """Running mean built from an accumulated sum and an item count."""

    def __init__(self):
        # Both accumulators start empty; value() is 0 until update() runs.
        self._count = 0
        self._sum = 0

    def update(self, value, n=1):
        """Add ``value`` to the running sum and ``n`` to the item count.

        ``value`` is added as-is (it is not multiplied by ``n``), so it is
        treated as the total contributed by ``n`` items.
        """
        self._sum += value
        self._count += n

    def value(self):
        """Return sum / count, using a divisor of 1 while count is <= 1."""
        # Guard against dividing by zero before the first update.
        divisor = self._count if self._count > 1 else 1
        return self._sum / divisor
class MetersGroup(object):
    """Accumulate named averages and flush them to a CSV file and stdout.

    Values are collected per key with ``log`` and averaged by an
    ``AverageMeter``. ``dump`` writes one CSV row plus one console line
    and then resets all meters.

    Args:
        file_name: Path prefix of the CSV file; ``'.csv'`` is appended.
            Any existing file at that path is replaced.
        formating: Iterable of ``(meter key, console label, format type)``
            tuples selecting what is printed to the console. The format
            type must be ``'int'``, ``'float'`` or ``'time'``.
    """

    def __init__(self, file_name, formating):
        self._csv_file_name = self._prepare_file(file_name, 'csv')
        self._formating = formating
        self._meters = defaultdict(AverageMeter)
        # The csv module asks for newline='' on files it writes to;
        # without it, platforms that translate '\n' emit blank rows.
        self._csv_file = open(self._csv_file_name, 'w', newline='')
        # Created lazily on the first dump, once the column set is known.
        self._csv_writer = None

    def _prepare_file(self, prefix, suffix):
        """Return ``'<prefix>.<suffix>'`` after removing any stale file."""
        file_name = f'{prefix}.{suffix}'
        try:
            os.remove(file_name)
        except FileNotFoundError:
            # Nothing to clean up.
            pass
        return file_name

    def log(self, key, value, n=1):
        """Add ``value`` (covering ``n`` items) to the meter for ``key``."""
        self._meters[key].update(value, n)

    def _prime_meters(self):
        """Return ``{column name: mean}`` for every meter.

        The leading ``'train'`` or ``'eval'`` prefix and the separator
        character after it are stripped from each key, and any remaining
        ``'/'`` is replaced by ``'_'``.
        """
        data = dict()
        for key, meter in self._meters.items():
            # Keys that do not start with 'train' are treated as 'eval'.
            prefix = 'train' if key.startswith('train') else 'eval'
            column = key[len(prefix) + 1:].replace('/', '_')
            data[column] = meter.value()
        return data

    def _dump_to_csv(self, data):
        """Append ``data`` as one CSV row, writing the header on first use."""
        if self._csv_writer is None:
            # NOTE: the column set is fixed by the first row. DictWriter
            # raises ValueError for later rows carrying unseen keys, and
            # fills keys missing from a later row with 0.0.
            self._csv_writer = csv.DictWriter(self._csv_file,
                                              fieldnames=sorted(data.keys()),
                                              restval=0.0)
            self._csv_writer.writeheader()
        self._csv_writer.writerow(data)
        # Flush so the file is readable while training is still running.
        self._csv_file.flush()

    def _format(self, key, value, ty):
        """Render ``'<key>: <value>'`` according to the format type ``ty``.

        Raises:
            ValueError: If ``ty`` is not ``'int'``, ``'float'`` or ``'time'``.
        """
        if ty == 'int':
            value = int(value)
            return f'{key}: {value}'
        if ty == 'float':
            return f'{key}: {value:.04f}'
        if ty == 'time':
            return f'{key}: {value:04.1f} s'
        # Raising a plain string is itself a TypeError in Python 3, which
        # hid the real message; raise a proper exception instead.
        raise ValueError(f'invalid format type: {ty}')

    def _dump_to_console(self, data, prefix):
        """Print one ``' | '``-separated line with the configured columns."""
        prefix = colored(prefix, 'yellow' if prefix == 'train' else 'green')
        pieces = [f'| {prefix: <14}']
        for key, disp_key, ty in self._formating:
            # Columns without a logged value are shown as 0.
            value = data.get(key, 0)
            pieces.append(self._format(disp_key, value, ty))
        print(' | '.join(pieces))

    def dump(self, step, prefix, save=True):
        """Flush the collected averages and reset all meters.

        Args:
            step: Value stored in the ``'step'`` column of the row.
            prefix: Label printed at the start of the console line.
            save: When false, nothing is written or printed; the meters
                are still cleared.
        """
        if not self._meters:
            return
        if save:
            data = self._prime_meters()
            data['step'] = step
            self._dump_to_csv(data)
            self._dump_to_console(data, prefix)
        self._meters.clear()

    def close(self):
        """Close the underlying CSV file. Calling it again is harmless."""
        self._csv_file.close()
class Logger(object):
    """Route training/evaluation metrics to CSV, console and TensorBoard.

    Scalars go to a ``MetersGroup`` chosen by the key prefix (``'train'``
    or ``'eval'``) and, when TensorBoard is enabled, to a
    ``SummaryWriter`` as well. Videos and histograms only go to
    TensorBoard.

    Args:
        log_dir: Directory that receives ``train.csv``, ``eval.csv`` and,
            if enabled, the ``tb`` sub-directory.
        save_tb: Whether to create a TensorBoard ``SummaryWriter``.
        log_frequency: Default step interval used when a logging call
            passes a falsy ``log_frequency``.
        agent: Key into ``AGENT_TRAIN_FORMAT`` selecting the extra
            training columns shown on the console.
    """

    def __init__(self,
                 log_dir,
                 save_tb=False,
                 log_frequency=10000,
                 agent='sac'):
        self._log_dir = log_dir
        self._log_frequency = log_frequency
        if save_tb:
            tb_dir = os.path.join(log_dir, 'tb')
            if os.path.exists(tb_dir):
                # Best effort: a stale directory that cannot be removed
                # should not prevent training from starting.
                try:
                    shutil.rmtree(tb_dir)
                except OSError:
                    print("logger.py warning: Unable to remove tb directory")
            self._sw = SummaryWriter(tb_dir)
        else:
            self._sw = None
        # each agent has specific output format for training
        assert agent in AGENT_TRAIN_FORMAT
        train_format = COMMON_TRAIN_FORMAT + AGENT_TRAIN_FORMAT[agent]
        self._train_mg = MetersGroup(os.path.join(log_dir, 'train'),
                                     formating=train_format)
        self._eval_mg = MetersGroup(os.path.join(log_dir, 'eval'),
                                    formating=COMMON_EVAL_FORMAT)

    def _should_log(self, step, log_frequency):
        """Return True when ``step`` is a multiple of the log frequency.

        A falsy ``log_frequency`` (``None`` or ``0``) falls back to the
        frequency given to the constructor.
        """
        log_frequency = log_frequency or self._log_frequency
        return step % log_frequency == 0

    def _try_sw_log(self, key, value, step):
        """Write a scalar to TensorBoard if a writer is configured."""
        if self._sw is not None:
            self._sw.add_scalar(key, value, step)

    def _try_sw_log_video(self, key, frames, step):
        """Write a video to TensorBoard if a writer is configured.

        ``frames`` is converted through ``np.array`` into a tensor and a
        leading dimension of size 1 is added before it is handed over.
        """
        if self._sw is not None:
            frames = torch.from_numpy(np.array(frames))
            frames = frames.unsqueeze(0)
            self._sw.add_video(key, frames, step, fps=30)

    def _try_sw_log_histogram(self, key, histogram, step):
        """Write a histogram to TensorBoard if a writer is configured."""
        if self._sw is not None:
            self._sw.add_histogram(key, histogram, step)

    def log(self, key, value, step, n=1, log_frequency=1):
        """Record a scalar under ``key``.

        Args:
            key: Metric name; must start with ``'train'`` or ``'eval'``.
            value: Total contributed by ``n`` items. A tensor is reduced
                with ``.item()``.
            step: Current step, used for frequency gating and as the
                TensorBoard step.
            n: Number of items ``value`` covers; TensorBoard receives
                ``value / n``.
            log_frequency: Only log when ``step`` is a multiple of this.
        """
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith(('train', 'eval'))
        # isinstance also covers Tensor subclasses such as nn.Parameter.
        if isinstance(value, torch.Tensor):
            value = value.item()
        self._try_sw_log(key, value / n, step)
        mg = self._train_mg if key.startswith('train') else self._eval_mg
        mg.log(key, value, n)

    def log_param(self, key, param, step, log_frequency=None):
        """Log histograms of a module's weight and bias and their grads.

        Histogram keys are ``key`` plus ``'_w'``, ``'_w_g'``, ``'_b'``
        and ``'_b_g'``. Gradient and bias histograms are skipped when
        the corresponding attribute is missing or ``None``.
        """
        if not self._should_log(step, log_frequency):
            return
        # Forward log_frequency so the histograms are gated by the same
        # interval as this call rather than re-checked against the default.
        self.log_histogram(key + '_w', param.weight.data, step,
                           log_frequency)
        if hasattr(param.weight, 'grad') and param.weight.grad is not None:
            self.log_histogram(key + '_w_g', param.weight.grad.data, step,
                               log_frequency)
        if hasattr(param, 'bias') and hasattr(param.bias, 'data'):
            self.log_histogram(key + '_b', param.bias.data, step,
                               log_frequency)
            if hasattr(param.bias, 'grad') and param.bias.grad is not None:
                self.log_histogram(key + '_b_g', param.bias.grad.data, step,
                                   log_frequency)

    def log_video(self, key, frames, step, log_frequency=None):
        """Log a video under ``key`` (prefix ``'train'`` or ``'eval'``)."""
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith(('train', 'eval'))
        self._try_sw_log_video(key, frames, step)

    def log_histogram(self, key, histogram, step, log_frequency=None):
        """Log a histogram under ``key`` (prefix ``'train'`` or ``'eval'``)."""
        if not self._should_log(step, log_frequency):
            return
        assert key.startswith(('train', 'eval'))
        self._try_sw_log_histogram(key, histogram, step)

    def dump(self, step, save=True, ty=None):
        """Flush the meter groups.

        Args:
            step: Step value recorded with the flushed row.
            save: Passed through to ``MetersGroup.dump``.
            ty: ``'train'``, ``'eval'``, or ``None`` to flush both.

        Raises:
            ValueError: If ``ty`` is any other value.
        """
        if ty is None:
            self._train_mg.dump(step, 'train', save)
            self._eval_mg.dump(step, 'eval', save)
        elif ty == 'eval':
            self._eval_mg.dump(step, 'eval', save)
        elif ty == 'train':
            self._train_mg.dump(step, 'train', save)
        else:
            # Raising a plain string is a TypeError in Python 3 and hid
            # the intended message.
            raise ValueError(f'invalid log type: {ty}')