File size: 4,345 Bytes
1cd928a | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 | '''
author: wayn391@mastertones
'''
import os
import json
import time
import yaml
import datetime
import torch
import matplotlib.pyplot as plt
from . import utils
import numpy as np
from torch.utils.tensorboard import SummaryWriter
class Saver(object):
    """Experiment bookkeeping: TensorBoard logging, a text log and checkpoints.

    All artifacts (TensorBoard event files, the plain-text info log and model
    checkpoints) are written under ``args.log_dir``.
    """

    def __init__(self, args, initial_global_step=0):
        """Create the log directory and the TensorBoard writer.

        Args:
            args: config object; must expose ``log_dir`` and ``sample_rate``.
            initial_global_step: value the step counter starts from
                (e.g. when resuming training).
        """
        # cold start
        self.global_step = initial_global_step
        self.init_time = time.time()
        self.last_time = time.time()
        self.log_dir = args.log_dir
        self.sample_rate = args.sample_rate
        # text file that log_info() appends to
        self.path_log_info = os.path.join(self.log_dir, 'log_info.txt')
        # ckpt
        os.makedirs(self.log_dir, exist_ok=True)
        # writer
        self.writer = SummaryWriter(self.log_dir)

    def log_info(self, msg):
        """Print ``msg`` and append it to the text log.

        Args:
            msg: a string (or any object, converted with ``str``), or a dict
                rendered as one ``key: value`` line per item; int values get
                thousands separators (``1000`` -> ``1,000``).
        """
        if isinstance(msg, dict):
            lines = []
            for k, v in msg.items():
                # bool is a subclass of int: keep True/False readable instead
                # of formatting them as 1/0
                if isinstance(v, int) and not isinstance(v, bool):
                    lines.append('{}: {:,}'.format(k, v))
                else:
                    lines.append('{}: {}'.format(k, v))
            msg_str = '\n'.join(lines)
        else:
            msg_str = str(msg)
        # display
        print(msg_str)
        # save; explicit utf-8 so non-ASCII text does not depend on the locale
        with open(self.path_log_info, 'a', encoding='utf-8') as fp:
            fp.write(msg_str + '\n')

    def log_value(self, dict):
        """Write each ``{tag: scalar}`` item to TensorBoard at the current step.

        The parameter name shadows the builtin ``dict``; it is kept so
        keyword callers keep working.
        """
        for k, v in dict.items():
            self.writer.add_scalar(k, v, self.global_step)

    def log_spec(self, name, spec, vmin=-14, vmax=3.5):
        """Render ``spec`` as an image and add it to TensorBoard.

        Args:
            name: TensorBoard tag.
            spec: array or tensor accepted by ``plt.imshow`` (presumably a
                2-D spectrogram -- confirm with callers); tensors are
                detached and moved to CPU first.
            vmin, vmax: colour-scale limits passed to ``imshow``.
        """
        if isinstance(spec, torch.Tensor):
            # detach so tensors that still require grad can become numpy
            spec = spec.detach().cpu().numpy()
        fig = plt.figure(figsize=(12, 6))
        plt.imshow(spec, aspect='auto', vmin=vmin, vmax=vmax)
        plt.colorbar()
        # imshow puts row 0 at the top; flip so it is drawn at the bottom
        plt.gca().invert_yaxis()
        plt.tight_layout()
        self.writer.add_figure(name, fig, self.global_step)
        # close the figure to release its resources
        plt.close(fig)

    def log_audio(self, dict):
        """Add each ``{tag: audio}`` item to TensorBoard at ``self.sample_rate``.

        Values are passed unchanged to ``SummaryWriter.add_audio``.
        """
        for k, v in dict.items():
            self.writer.add_audio(
                k, v, global_step=self.global_step, sample_rate=self.sample_rate)

    def get_interval_time(self, update=True):
        """Return seconds since the last update (initially: construction).

        When ``update`` is true the reference point is reset to now.
        """
        cur_time = time.time()
        time_interval = cur_time - self.last_time
        if update:
            self.last_time = cur_time
        return time_interval

    def get_total_time(self, to_str=True):
        """Return seconds since construction, as a float or ``H:MM:SS.f``."""
        total_time = time.time() - self.init_time
        if to_str:
            total_time = self._format_duration(total_time)
        return total_time

    @staticmethod
    def _format_duration(seconds):
        """Format ``seconds`` like ``str(timedelta)`` cut to one decimal digit."""
        delta = datetime.timedelta(seconds=seconds)
        text = str(delta)
        # str(timedelta) drops the fractional part when microseconds == 0;
        # pad it so removing the last five digits always leaves '.f'
        if delta.microseconds == 0:
            text += '.000000'
        return text[:-5]

    def save_model(
            self,
            model,
            optimizer,
            name='model',
            postfix='',
            to_json=False):
        """Save a checkpoint to ``<log_dir>/<name>[_<postfix>].pt``.

        The checkpoint is a dict with keys ``global_step``, ``model`` and,
        when ``optimizer`` is not None, ``optimizer``.

        Args:
            model: object with ``state_dict()`` (e.g. ``torch.nn.Module``).
            optimizer: object with ``state_dict()``, or None to omit it.
            name: checkpoint base name.
            postfix: optional suffix, appended as ``_<postfix>``.
            to_json: accepted for backward compatibility; currently ignored.
        """
        if postfix:
            postfix = '_' + postfix
        path_pt = os.path.join(self.log_dir, name + postfix + '.pt')
        checkpoint = {
            'global_step': self.global_step,
            'model': model.state_dict(),
        }
        if optimizer is not None:
            checkpoint['optimizer'] = optimizer.state_dict()
        torch.save(checkpoint, path_pt)
        # report only once the file has actually been written
        print(' [*] model checkpoint saved: {}'.format(path_pt))

    def delete_model(self, name='model', postfix=''):
        """Delete ``<log_dir>/<name>[_<postfix>].pt`` if it exists.

        Uses the same directory and naming scheme as ``save_model``; a
        missing file is ignored.
        """
        if postfix:
            postfix = '_' + postfix
        path_pt = os.path.join(self.log_dir, name + postfix + '.pt')
        if os.path.exists(path_pt):
            os.remove(path_pt)
            print(' [*] model checkpoint deleted: {}'.format(path_pt))

    def global_step_increment(self):
        """Advance the step counter used by all TensorBoard writes."""
        self.global_step += 1
|