Spaces:
Runtime error
Runtime error
| from torch.utils.tensorboard import SummaryWriter | |
| import os, sys, glob | |
class logger():
    """Thin wrapper around a TensorBoard ``SummaryWriter`` for one session.

    Event files are written to ``{args.outdir}/{args.session_name}/log``.
    When ``args.test`` is falsy, regular files already present directly in that
    directory are deleted before a new writer is opened there. When
    ``args.test`` is truthy, no writer is created, and :meth:`add` validates
    its input but writes nothing.
    """

    def __init__(self, args, mode, suffix = None):
        """Set up the output paths and, outside test mode, the writer.

        Args:
            args: object exposing ``outdir``, ``session_name`` and ``test``
                attributes (presumably an argparse namespace -- confirm).
            mode: not used by this class; kept for caller compatibility.
            suffix: not used by this class; kept for caller compatibility.
        """
        self.outdir = f'{args.outdir}/{args.session_name}'
        # None means "logging disabled"; add() and close() check for it.
        # Previously this attribute was missing in test mode, so any add()
        # call raised AttributeError.
        self.writer = None
        if not args.test:
            logdir = f'{self.outdir}/log'
            # Clear event files from earlier runs so they are not mixed with
            # this one. Only regular files are removed; sub-directories stay.
            for file in glob.glob(f'{logdir}/*'):
                if os.path.isfile(file):
                    os.remove(file)
            self.writer = SummaryWriter(log_dir=logdir, flush_secs=1)

    def add(self, tag, item, step, itemtype):
        """Log ``item`` under ``tag`` at global step ``step``.

        Args:
            tag: TensorBoard tag name.
            item: for ``'Image'``, an object whose ``.shape`` has 3 entries
                (logged as CHW) or 4 entries (logged as NCHW); for
                ``'Scalar'``, the scalar value to log.
            step: global step passed through to the writer.
            itemtype: ``'Image'`` or ``'Scalar'``. Any other value prints a
                message to stderr and logs nothing.

        Raises:
            ValueError: if an ``'Image'`` item does not have 3 or 4 dims.
        """
        if itemtype == 'Image':
            ndim = len(item.shape)
            # Validate even when logging is disabled so shape bugs surface
            # in test runs as well.
            if ndim not in (3, 4):
                raise ValueError("item.shape must be 3 or 4 (%d)" % ndim)
            if self.writer is None:
                return
            if ndim == 3:
                self.writer.add_image(tag, item, global_step=step, dataformats='CHW')
            else:
                self.writer.add_images(tag, item, global_step=step, dataformats='NCHW')
        elif itemtype == 'Scalar':
            if self.writer is not None:
                self.writer.add_scalar(tag, item, global_step=step)
        else:
            print('itemtype is not in "Image, Scalar"', file = sys.stderr)

    def close(self):
        """Close the underlying writer if one is open; safe to call repeatedly."""
        if self.writer is not None:
            self.writer.close()
            self.writer = None