| """This module implements the writer class for logging to tensorboard or wandb.""" |
|
|
| import logging |
| import os |
| from typing import Any, Dict, Optional |
|
|
| from omegaconf import DictConfig |
| from torch import nn |
| from torch.utils.tensorboard import SummaryWriter as TFSummaryWriter |
|
|
| from siclib import __module_name__ |
|
|
logger = logging.getLogger(__name__)

# wandb is an optional dependency: SummaryWriter only requires it when wandb
# logging is requested. Default the module-level name to None and replace it
# with the real module when the import succeeds.
wandb = None
try:
    import wandb  # noqa: F811
except ImportError:
    logger.debug("Could not import wandb.")
|
|
| |
|
|
|
|
def dot_conf(conf: DictConfig) -> Dict[str, Any]:
    """Flatten a (possibly nested) DictConfig into a single-level dict.

    Nested sections are descended into recursively and their keys are joined
    to the parent key with a "." (e.g. ``{"train": {"lr": 1}}`` becomes
    ``{"train.lr": 1}``). Non-DictConfig values are copied through as-is.
    """
    flat: Dict[str, Any] = {}
    for key, value in conf.items():
        if not isinstance(value, DictConfig):
            flat[key] = value
            continue
        for sub_key, sub_value in dot_conf(value).items():
            flat[f"{key}.{sub_key}"] = sub_value
    return flat
|
|
|
|
class SummaryWriter:
    """Writer class for logging to tensorboard or wandb.

    Backends are selected from ``conf.train.writer``. Logging goes to wandb
    when it contains ``"wandb"`` and to tensorboard when it contains
    ``"tensorboard"``; both may be active at once. A falsy value disables
    logging, and every method then becomes a no-op.
    """

    def __init__(self, conf: DictConfig, args: DictConfig, log_dir: str):
        """Initialize the writer.

        Args:
            conf: Experiment configuration. ``conf.train.writer`` selects the
                backends. When wandb is enabled, the whole config is uploaded
                as the run config, flattened with dotted keys.
            args: Run arguments; ``args.experiment`` is used as the wandb run name.
            log_dir: Directory passed to the tensorboard writer.

        Raises:
            NotImplementedError: If ``conf.train.writer`` is set but selects
                neither supported backend.
            ImportError: If wandb logging is requested but wandb is not installed.
        """
        self.conf = conf

        if not conf.train.writer:
            self.use_wandb = False
            self.use_tensorboard = False
            return

        self.use_wandb = "wandb" in conf.train.writer
        self.use_tensorboard = "tensorboard" in conf.train.writer

        # Fail fast on an unrecognized selection, before any backend is created.
        if not self.use_wandb and not self.use_tensorboard:
            raise NotImplementedError(f"Writer {conf.train.writer} not implemented")

        if self.use_wandb and not wandb:
            raise ImportError("wandb not installed.")

        if self.use_tensorboard:
            self.writer = TFSummaryWriter(log_dir=log_dir)

        if self.use_wandb:
            # NOTE(review): presumably raises wandb's service start-up wait
            # (seconds) for slow/shared machines — confirm against wandb docs.
            os.environ["WANDB__SERVICE_WAIT"] = "300"
            wandb.init(project=__module_name__, name=args.experiment, config=dot_conf(conf))

    @staticmethod
    def _wandb_step(step: Optional[int]) -> Optional[int]:
        """Return the step reported to wandb: step 0 is remapped to 1.

        Computed separately so the remap never leaks into the step that is
        handed to tensorboard.
        """
        return 1 if step == 0 else step

    def add_scalar(self, tag: str, value: float, step: Optional[int] = None):
        """Log a scalar value to tensorboard and/or wandb."""
        if self.use_wandb:
            wandb.log({tag: value}, step=self._wandb_step(step))

        if self.use_tensorboard:
            # tensorboard receives the caller's step unchanged.
            self.writer.add_scalar(tag, value, step)

    def add_figure(self, tag: str, figure, step: Optional[int] = None):
        """Log a figure to tensorboard and/or wandb."""
        # Keep wandb first: torch's add_figure closes the figure by default
        # (close=True), so it must not run before wandb has consumed it.
        if self.use_wandb:
            wandb.log({tag: figure}, step=self._wandb_step(step))
        if self.use_tensorboard:
            self.writer.add_figure(tag, figure, step)

    def add_histogram(self, tag: str, values, step: Optional[int] = None):
        """Log a histogram to tensorboard (not forwarded to wandb)."""
        if self.use_tensorboard:
            self.writer.add_histogram(tag, values, step)

    def add_text(self, tag: str, text: str, step: Optional[int] = None):
        """Log text to tensorboard (not forwarded to wandb)."""
        if self.use_tensorboard:
            self.writer.add_text(tag, text, step)

    def add_pr_curve(self, tag: str, values, step: Optional[int] = None):
        """Log a precision-recall curve to tensorboard and/or wandb.

        Args:
            tag: Name of the curve.
            values: A ``(labels, predictions)`` pair. These are the ground-truth
                labels and the predicted scores for the same samples.
            step: Global step for the entry.
        """
        if not (self.use_wandb or self.use_tensorboard):
            return

        labels, predictions = values

        if self.use_wandb:
            # NOTE(review): wandb.plots is wandb's legacy plotting API and may
            # expect per-class probabilities for y_probas; newer releases
            # expose wandb.plot.pr_curve — confirm against the installed version.
            wandb.log(
                {tag: wandb.plots.precision_recall(labels, predictions)},
                step=self._wandb_step(step),
            )

        if self.use_tensorboard:
            # Pass the step by keyword: the third positional parameter of
            # tensorboard's add_pr_curve is `predictions`, not the step.
            self.writer.add_pr_curve(tag, labels, predictions, global_step=step)

    def watch(self, model: nn.Module, log_freq: int = 1000):
        """Ask wandb to log the model's gradients every ``log_freq`` steps (wandb only)."""
        if self.use_wandb:
            wandb.watch(
                model,
                log="gradients",
                log_freq=log_freq,
            )

    def close(self):
        """Finish the wandb run and/or close the tensorboard writer."""
        if self.use_wandb:
            wandb.finish()

        if self.use_tensorboard:
            self.writer.close()
|
|