File size: 1,212 Bytes
83d8d3c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
import torch
from torch.utils.tensorboard import SummaryWriter


class Visualizer:
    """
    TensorBoard visualization / monitoring class.

    Thin wrapper around ``SummaryWriter`` that logs image batches and scalar
    losses during training.
    """

    def __init__(self, opt):
        """
        Create the underlying TensorBoard writer.

        Parameters:
            opt -- options object; only ``opt.log_dir`` is read here, as the
                   directory the ``SummaryWriter`` writes its event files to.
        """
        self.opt = opt  # cache the option
        self.writer = SummaryWriter(log_dir=opt.log_dir)

    @staticmethod
    def _to_uint8_images(image, max_images=2):
        """
        Prepare an image batch tensor for ``SummaryWriter.add_images``.

        Keeps at most the first ``max_images`` items along dim 0, scales by
        255 and converts to ``torch.uint8``. The x255 scaling assumes the
        input is a float tensor with values in [0, 1] -- TODO confirm with
        the producers of ``visuals_dict``.

        Parameters:
            image (torch.Tensor) -- image batch, NCHW layout
            max_images (int)     -- maximum number of leading images to keep

        Returns:
            torch.Tensor -- uint8 tensor with values in [0, 255]
        """
        # Detach so logging never builds an autograd graph; slicing past the
        # batch size is safe, so no explicit size check is needed.
        image = image.detach()[:max_images]
        # Clamp before the cast: float -> uint8 conversion of out-of-range
        # values does not saturate, so e.g. -0.01 or 1.01 would otherwise
        # become wrapped/garbage pixel values.
        return (image * 255.0).clamp(0.0, 255.0).to(torch.uint8)

    def display_current_results(self, iters, visuals_dict, max_images=2):
        """
        Log the current image batches to TensorBoard.

        Parameters:
            iters (int)                -- the current iteration, used as the global step
            visuals_dict (OrderedDict) -- maps a label to an image batch tensor in
                                          NCHW layout (presumably float in [0, 1])
            max_images (int)           -- at most this many leading images of each
                                          batch are logged (default 2)
        """
        for label, image in visuals_dict.items():
            self.writer.add_images(
                str(label),
                self._to_uint8_images(image, max_images),
                global_step=iters,
                dataformats="NCHW",
            )

    def plot_current_losses(self, iters, loss_dict):
        """
        Log each loss as a scalar under the ``Loss/`` tag prefix.

        Parameters:
            iters (int)             -- the current iteration, used as the global step
            loss_dict (OrderedDict) -- training losses stored as (name, torch.Tensor) pairs
        """
        for name, value in loss_dict.items():
            self.writer.add_scalar(f"Loss/{name}", value, iters)

    def close(self):
        """Flush pending events and close the underlying TensorBoard writer."""
        self.writer.close()