# Hosting-page residue (not part of the module): karthikeya1212's picture /
# "Upload 115 files" / commit cda88e0 (verified).
import numpy as np
import torch
from torch.utils.tensorboard import SummaryWriter
def tensorboard_plot_loss(win_name, loss, writer):
    """Log the most recent loss value as a scalar under ``Loss/<win_name>``.

    Args:
        win_name: Suffix of the scalar tag; the full tag is ``Loss/<win_name>``.
        loss: Sequence of loss values. Its last element is the value that is
            logged and its length is used as the global step.
        writer: Object exposing ``add_scalar(tag, value, step)`` and
            ``flush()`` (e.g. a ``SummaryWriter``).
    """
    # An empty history has no "latest" value; skip instead of raising
    # IndexError on loss[-1].
    if len(loss) == 0:
        return
    writer.add_scalar(f"Loss/{win_name}", loss[-1], len(loss))
    writer.flush()
def normalize_img(imgs):
    """Scale paired images by a shared per-pair peak and clamp to [0, 1].

    The batch is treated as two halves of ``b // 2`` images each: image ``i``
    and image ``b // 2 + i`` are both divided by the maximum value of image
    ``i``. (The ``gt_batch`` name suggests the first half holds ground-truth
    images -- confirm against callers.) With an odd batch size the last image
    is not scaled, only clamped.

    Args:
        imgs: 4-D tensor whose shape unpacks as ``(b, c, h, w)``.

    Returns:
        A new tensor of the same shape with values clamped to ``[0.0, 1.0]``.
        The input tensor is left unmodified.
    """
    # Unpacking enforces a 4-D input; only the batch size is used below.
    b, _c, _h, _w = imgs.shape
    gt_batch = b // 2
    # Work on a copy so that a logging helper never rewrites the caller's
    # tensor as a side effect of the per-image assignments below.
    scaled = imgs.clone()
    for i in range(gt_batch):
        factor = torch.max(scaled[i])
        # A zero peak would yield NaN (0/0) and a negative peak would flip
        # signs; in both cases the reference image has no positive values,
        # so leave the pair unscaled and let the clamp handle it.
        if factor > 0:
            scaled[i] = scaled[i] / factor
            scaled[gt_batch + i] = scaled[gt_batch + i] / factor
    return torch.clamp(scaled, 0.0, 1.0)
def tensorboard_show_batch(imgs, writer, win_name=None, nrow=2, normalize=True, step=0):
    """Write a batch of images to the writer and flush it.

    Args:
        imgs: Image batch handed to ``writer.add_images``; passed through
            ``normalize_img`` first when ``normalize`` is true.
        writer: Object exposing ``add_images(tag, imgs, step)`` and
            ``flush()`` (e.g. a ``SummaryWriter``).
        win_name: Tag for the images; it is string-formatted, so the default
            ``None`` produces the tag ``"None"``.
        nrow: Accepted but not used by this function.
        normalize: Whether to run ``normalize_img`` on the batch before
            logging.
        step: Global step recorded with the images.
    """
    batch = normalize_img(imgs) if normalize else imgs
    tag = f'{win_name}'
    writer.add_images(tag, batch, step)
    writer.flush()
def tensorboard_log(log_info, writer, win_name='logger', step=0):
    """Write a text entry to the writer and flush it.

    Args:
        log_info: Text passed as the entry's content.
        writer: Object exposing ``add_text(tag, text, step)`` and ``flush()``
            (e.g. a ``SummaryWriter``).
        win_name: Tag under which the text is recorded.
        step: Global step recorded with the entry.
    """
    tag, text = win_name, log_info
    writer.add_text(tag, text, step)
    # Flush right away so the entry is not left sitting in the writer's buffer.
    writer.flush()