|
|
''' |
|
|
author: wayn391@mastertones |
|
|
''' |
|
|
|
|
|
import os |
|
|
import json |
|
|
import time |
|
|
import yaml |
|
|
import datetime |
|
|
import torch |
|
|
import matplotlib.pyplot as plt |
|
|
from . import utils |
|
|
import numpy as np |
|
|
from torch.utils.tensorboard import SummaryWriter |
|
|
class Saver(object):
    '''Training-run logger/checkpointer.

    Writes a plain-text log, TensorBoard scalars/figures/audio, tracks
    wall-clock time, and saves/deletes model checkpoints under one log dir.

    Args:
        args: config namespace; must provide ``log_dir`` and ``sample_rate``.
        initial_global_step (int): starting value of the step counter.
    '''

    def __init__(
            self,
            args,
            initial_global_step=0):

        self.global_step = initial_global_step
        self.init_time = time.time()   # reference for get_total_time()
        self.last_time = time.time()   # reference for get_interval_time()
        self.log_dir = args.log_dir
        self.sample_rate = args.sample_rate

        # BUG FIX: log_info() appends to self.path_log_info, but it was
        # never assigned anywhere, so the first log_info() call raised
        # AttributeError. Define the text-log path here.
        self.path_log_info = os.path.join(self.log_dir, 'log_info.txt')

        os.makedirs(self.log_dir, exist_ok=True)

        self.writer = SummaryWriter(self.log_dir)

    def log_info(self, msg):
        '''Print *msg* and append it to the run's text log.

        A dict is rendered one ``key: value`` pair per line; int values
        get thousands separators. Anything else is written as-is.
        '''
        if isinstance(msg, dict):
            msg_list = []
            for k, v in msg.items():
                tmp_str = ''
                if isinstance(v, int):
                    tmp_str = '{}: {:,}'.format(k, v)
                else:
                    tmp_str = '{}: {}'.format(k, v)

                msg_list.append(tmp_str)
            msg_str = '\n'.join(msg_list)
        else:
            msg_str = msg

        print(msg_str)

        with open(self.path_log_info, 'a') as fp:
            fp.write(msg_str+'\n')

    def log_value(self, dict):
        '''Write each key/value pair as a TensorBoard scalar at the current step.'''
        # NOTE: parameter name shadows the builtin `dict`; kept unchanged
        # for backward compatibility with keyword callers.
        for k, v in dict.items():
            self.writer.add_scalar(k, v, self.global_step)

    def log_spec(self, name, spec, vmin=-14, vmax=3.5):
        '''Render a 2-D spectrogram (tensor or array) as a TensorBoard figure.

        Args:
            name: TensorBoard tag.
            spec: 2-D array-like; a torch.Tensor is moved to CPU first.
            vmin, vmax: color scale limits for imshow.
        '''
        if isinstance(spec, torch.Tensor):
            spec = spec.cpu().numpy()

        fig = plt.figure(figsize=(12, 6))

        plt.imshow(spec, aspect='auto', vmin=vmin, vmax=vmax)
        plt.colorbar()
        # Put low frequencies at the bottom, as conventional for spectrograms.
        plt.gca().invert_yaxis()
        plt.tight_layout()

        self.writer.add_figure(name, fig, self.global_step)

        # Close explicitly so repeated calls don't leak matplotlib figures.
        plt.close(fig)

    def log_audio(self, dict):
        '''Write each key/value pair as TensorBoard audio at the current step.'''
        # NOTE: parameter name shadows the builtin `dict`; kept unchanged
        # for backward compatibility with keyword callers.
        for k, v in dict.items():
            self.writer.add_audio(k, v, global_step=self.global_step, sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
        '''Return seconds since the previous call.

        Args:
            update: when True, reset the reference point to now.
        '''
        cur_time = time.time()
        time_interval = cur_time - self.last_time
        if update:
            self.last_time = cur_time
        return time_interval

    def get_total_time(self, to_str=True):
        '''Return seconds since construction; formatted "H:MM:SS.f" if *to_str*.'''
        total_time = time.time() - self.init_time
        if to_str:
            # [:-5] trims microsecond digits down to one decimal place.
            total_time = str(datetime.timedelta(
                seconds=total_time))[:-5]
        return total_time

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
        '''Save a checkpoint to ``<log_dir>/<name>[_<postfix>].pt``.

        The checkpoint dict contains ``global_step``, ``model`` state, and
        (when *optimizer* is not None) ``optimizer`` state.

        NOTE(review): *to_json* is accepted but unused in this version;
        kept for interface compatibility.
        '''
        if postfix:
            postfix = '_' + postfix
        path_pt = os.path.join(
            self.log_dir, name+postfix+'.pt')

        if optimizer is not None:
            torch.save({
                'global_step': self.global_step,
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict()}, path_pt)
        else:
            torch.save({
                'global_step': self.global_step,
                'model': model.state_dict()}, path_pt)

        # BUG FIX: report success only after torch.save actually completed
        # (previously printed before writing, lying if the save failed).
        print(' [*] model checkpoint saved: {}'.format(path_pt))

    def delete_model(self, name='model', postfix=''):
        '''Delete the checkpoint ``<log_dir>/<name>[_<postfix>].pt`` if present.'''
        if postfix:
            postfix = '_' + postfix
        # BUG FIX: previously used self.expdir, which is never assigned
        # (only self.log_dir exists), raising AttributeError. Use log_dir,
        # consistent with save_model().
        path_pt = os.path.join(
            self.log_dir, name+postfix+'.pt')

        if os.path.exists(path_pt):
            os.remove(path_pt)
            print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        '''Advance the global step counter by one.'''
        self.global_step += 1
|
|
|
|
|
|
|
|
|