Spaces:
Build error
Build error
| """Tensorboard logger with add image interface.""" | |
| # Copyright (C) 2020 Intel Corporation | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, | |
| # software distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions | |
| # and limitations under the License. | |
| from typing import Any, Optional, Union | |
| import numpy as np | |
| from matplotlib.figure import Figure | |
| from pytorch_lightning.loggers.tensorboard import TensorBoardLogger | |
| from pytorch_lightning.utilities import rank_zero_only | |
| from .base import ImageLoggerBase | |
class AnomalibTensorBoardLogger(ImageLoggerBase, TensorBoardLogger):
    """Logger for tensorboard.

    Adds interface for `add_image` in the logger rather than calling the experiment object.

    Note:
        Same as the Tensorboard Logger provided by PyTorch Lightning and the docstring is reproduced below.

    Logs are saved to ``os.path.join(save_dir, name, version)``. This is the default logger in Lightning,
    it comes preinstalled.

    Example:
        >>> from pytorch_lightning import Trainer
        >>> from anomalib.utils.loggers import AnomalibTensorBoardLogger
        >>> logger = AnomalibTensorBoardLogger("tb_logs", name="my_model")
        >>> trainer = Trainer(logger=logger)

    Args:
        save_dir (str): Save directory
        name (Optional, str): Experiment name. Defaults to ``'default'``. If it is the empty string then no
            per-experiment subdirectory is used.
        version (Optional, int, str): Experiment version. If version is not specified the logger inspects the save
            directory for existing versions, then automatically assigns the next available version.
            If it is a string then it is used as the run-specific subdirectory name,
            otherwise ``'version_${version}'`` is used.
        log_graph (bool): Adds the computational graph to tensorboard. This requires that
            the user has defined the `self.example_input_array` attribute in their
            model.
        default_hp_metric (bool): Enables a placeholder metric with key `hp_metric` when `log_hyperparams` is
            called without a metric (otherwise calls to log_hyperparams without a metric are ignored).
        prefix (str): A string to put at the beginning of metric keys.
        **kwargs: Additional arguments like `comment`, `filename_suffix`, etc. used by
            :class:`SummaryWriter` can be passed as keyword arguments in this logger.
    """

    def __init__(
        self,
        save_dir: str,
        name: Optional[str] = "default",
        version: Optional[Union[int, str]] = None,
        log_graph: bool = False,
        default_hp_metric: bool = True,
        prefix: str = "",
        **kwargs: Any,
    ) -> None:
        # All arguments are forwarded unchanged along the MRO to the Lightning TensorBoardLogger.
        super().__init__(
            save_dir,
            name=name,
            version=version,
            log_graph=log_graph,
            default_hp_metric=default_hp_metric,
            prefix=prefix,
            **kwargs,
        )

    @rank_zero_only
    def add_image(self, image: Union[np.ndarray, Figure], name: Optional[str] = None, **kwargs: Any) -> None:
        """Interface to add image to tensorboard logger.

        Only executed on the rank-zero process; other ranks return immediately.

        Args:
            image (Union[np.ndarray, Figure]): Image to log. A Matplotlib ``Figure`` is rendered to an
                ``HWC`` uint8 RGB array first (note: this turns off the figure's current axes and sets its
                margins to 0, mutating the caller's figure).
            name (Optional[str]): The tag of the image
            kwargs: Must contain ``global_step`` (int), the step at which to log the image. All keyword
                arguments are forwarded to ``self.experiment.add_image``.

        Raises:
            ValueError: If ``global_step`` is not provided.
        """
        if "global_step" not in kwargs:
            raise ValueError("`global_step` is required for tensorboard logger")

        # A Matplotlib Figure cannot be passed to add_image directly; convert it to array data.
        if isinstance(image, Figure):
            image = self._figure_to_array(image)
            kwargs["dataformats"] = "HWC"

        self.experiment.add_image(img_tensor=image, tag=name, **kwargs)

    @staticmethod
    def _figure_to_array(figure: Figure) -> np.ndarray:
        """Render ``figure`` and return its pixels as a contiguous ``(H, W, 3)`` uint8 array."""
        axis = figure.gca()
        axis.axis("off")
        axis.margins(0)
        figure.canvas.draw()  # render so the canvas pixel buffer is populated
        # buffer_rgba() already has the physical (H, W, 4) shape, unlike get_width_height(), which
        # reports the logical size and breaks on HiDPI canvases. tostring_rgb() was removed in
        # Matplotlib 3.10. Drop alpha and copy so the result does not alias the renderer's buffer.
        rgba = np.asarray(figure.canvas.buffer_rgba())
        return np.ascontiguousarray(rgba[..., :3])