text-to-image-model / logger.py
JBlitzar
ahahahaha it works
9f5a022
raw
history blame contribute delete
745 Bytes
import torch
from torch.utils.tensorboard import SummaryWriter
import matplotlib.pyplot as plt
# Module-wide SummaryWriter shared by the logging helpers below; it stays
# None until init_logger() assigns it.
writer = None
def log_data(data, i):
    """Log every entry of ``data`` as a scalar at step ``i``.

    Each key is used as the scalar tag and its value as the scalar value,
    forwarded to ``add_scalar`` on the module-level ``writer``.

    Args:
        data: Mapping of tag to value.
        i: Step index, passed as the third argument of ``add_scalar``.

    Note:
        ``writer`` is ``None`` until ``init_logger()`` has been called, so
        logging a non-empty mapping before that raises ``AttributeError``.
    """
    # items() yields each value alongside its key, so no second lookup is
    # needed per entry.
    for tag, value in data.items():
        writer.add_scalar(tag, value, i)
def log_img(img, name, step=None):
    """Log an image under the tag ``name`` via the module-level ``writer``.

    Args:
        img: Image data forwarded unchanged to ``writer.add_image``.
        name: Tag the image is recorded under.
        step: Optional global step to associate with the image. The default
            of ``None`` matches the previous behavior, where no step was
            passed.

    Note:
        ``writer`` is ``None`` until ``init_logger()`` has been called, so
        calling this before that raises ``AttributeError``.
    """
    writer.add_image(name, img, global_step=step)
def save_grid_with_label(img_grid, label, out_file):
    """Render an image grid with a title and save it to ``out_file``.

    Args:
        img_grid: Tensor whose axes are reordered with ``permute(1, 2, 0)``
            before plotting (presumably a channels-first (C, H, W) grid;
            confirm against callers).
        label: Text shown as the plot title.
        out_file: Destination passed to ``savefig``.
    """
    # detach() and cpu() let the conversion succeed for tensors that track
    # gradients or live on an accelerator; plain CPU tensors are unaffected.
    img_array = img_grid.detach().cpu().permute(1, 2, 0).numpy()

    fig, ax = plt.subplots(figsize=(8, 8))
    try:
        ax.imshow(img_array)
        ax.set_title(label, fontsize=20)
        ax.axis('off')
        fig.subplots_adjust(top=0.85)
        fig.savefig(out_file, bbox_inches='tight', pad_inches=0.1)
    finally:
        # Release the figure even when plotting or saving fails.
        plt.close(fig)
        # Kept from the original: this also closes every other open pyplot
        # figure, not just the one created here.
        plt.close("all")
def init_logger(dir="runs"):
    """Create the module-level ``SummaryWriter`` if none exists yet.

    Once a writer has been created, later calls leave it untouched, so a
    different ``dir`` passed afterwards has no effect.

    Args:
        dir: Directory passed to ``SummaryWriter``. The name shadows the
            builtin ``dir``; it is kept so keyword callers keep working.
    """
    global writer
    # Identity check against None: creation depends only on whether a writer
    # has been assigned, not on the truthiness of the writer object.
    if writer is None:
        writer = SummaryWriter(dir)