Spaces:
Sleeping
Sleeping
| from typing import TYPE_CHECKING | |
| from tensorboardX import SummaryWriter | |
| if TYPE_CHECKING: | |
| # TYPE_CHECKING is always False at runtime, but mypy will evaluate the contents of this block. | |
| # So if you import this module within TYPE_CHECKING, you will get code hints and other benefits. | |
| # Here is a good answer on stackoverflow: | |
| # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports | |
| from ding.framework import Parallel | |
class DistributedWriter(SummaryWriter):
    """
    Overview:
        A subclass of ``SummaryWriter`` that, in multi-process mode, funnels all writes into one process.
        It works best together with the ``router``: the router's message and event components carry \
        write requests from the other processes to the main writer (see ``plugin``).
    Interfaces:
        ``get_instance``, ``plugin``, ``initialize``, ``__del__``
    """

    # The first instance created through ``get_instance(...)`` with arguments;
    # returned by ``get_instance()`` when called without arguments.
    root = None

    def __init__(self, *args, **kwargs):
        """
        Overview:
            Initialize the DistributedWriter object. No file writer is opened here; it is created lazily.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter. ``write_to_disk`` (default ``True``) is remembered and applied later by \
                ``initialize``.
        """
        self._default_writer_to_disk = kwargs.get("write_to_disk", True)
        # Data is written to files lazily, so the parent must not open a real file writer here.
        # ``initialize`` creates it later: on the first ``add_*`` call that writes locally,
        # or in ``plugin`` for the main writer.
        kwargs["write_to_disk"] = False
        super().__init__(*args, **kwargs)
        self._in_parallel = False  # True once plugged into an active router
        self._router = None  # the router used to forward write requests
        self._is_writer = False  # True for the process that actually writes to disk
        self._lazy_initialized = False  # True once ``initialize`` has created the file writer

    @classmethod
    def get_instance(cls, *args, **kwargs) -> "DistributedWriter":
        """
        Overview:
            Create a new instance when any arguments are given; otherwise return the root instance. \
            The first instance created through this method becomes the root instance.
        Arguments:
            - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
                SummaryWriter.
        Returns:
            - instance (:obj:`DistributedWriter`): The newly created instance, or the root instance \
                (``None`` if no instance has been created through this method yet).
        """
        if not (args or kwargs):
            return cls.root
        instance = cls(*args, **kwargs)
        if cls.root is None:
            cls.root = instance
        return instance

    def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
        """
        Overview:
            Plug in ``router``. While the router is active, writes from non-main writers are sent as \
            requests to the main writer instead of being written to disk locally. This collects data \
            from multiple processes into one file. Does nothing if the router is not active.
        Arguments:
            - router (:obj:`Parallel`): The router to be plugged in.
            - is_writer (:obj:`bool`): Whether this writer is the main writer.
        Returns:
            - writer (:obj:`DistributedWriter`): ``self``, to allow chaining.
        Examples:
            >>> DistributedWriter().plugin(router, is_writer=True)
        """
        if router.is_active:
            self._in_parallel = True
            self._router = router
            self._is_writer = is_writer
            if is_writer:
                # The main writer opens its file writer eagerly; ``initialize`` also sets
                # ``_lazy_initialized`` so the first ``add_*`` call will not reopen it.
                self.initialize()
            router.on("distributed_writer", self._on_distributed_writer)
        return self

    def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
        """
        Overview:
            Handle a write request received through the router by calling ``fn_name`` on this writer.
            Only the main writer acts on requests; other writers ignore them.
        Arguments:
            - fn_name (:obj:`str`): The name of the function to be called.
            - args (:obj:`Tuple`): The arguments passed to the function to be called.
            - kwargs (:obj:`Dict`): The keyword arguments passed to the function to be called.
        """
        if not self._is_writer:
            return
        # NOTE(review): ``fn_name`` arrives from another process and is resolved with ``getattr``;
        # this assumes every peer on the router is trusted.
        getattr(self, fn_name)(*args, **kwargs)

    def initialize(self):
        """
        Overview:
            (Re)create the file writer, honouring the ``write_to_disk`` value given to ``__init__``.
        """
        # Close the placeholder writer created in ``__init__`` (or a previous one) before replacing it.
        self.close()
        # ``_write_to_disk`` / ``_get_file_writer`` are private tensorboardX ``SummaryWriter`` internals.
        self._write_to_disk = self._default_writer_to_disk
        self._get_file_writer()
        self._lazy_initialized = True

    def __del__(self):
        """
        Overview:
            Close the file writer when the object is garbage-collected.
        """
        self.close()
def enable_parallel(fn_name, fn):
    """
    Overview:
        Wrap a ``SummaryWriter`` method so that, in parallel mode, calls made on non-main writers are \
        forwarded to the main writer through the router instead of being executed locally.
    Arguments:
        - fn_name (:obj:`str`): The name of the wrapped method; it is sent with each forwarded call so \
            the main writer can dispatch it.
        - fn (:obj:`Callable`): The original (unbound) method, called when the write happens locally.
    Returns:
        - parallel_fn (:obj:`Callable`): The wrapped method. It keeps ``fn``'s name and docstring. It returns \
            ``fn``'s result for local writes and ``None`` for forwarded ones.
    """
    import functools

    @functools.wraps(fn)
    def _parallel_fn(self: "DistributedWriter", *args, **kwargs):
        if self._in_parallel and not self._is_writer:
            # Non-main writers never touch the disk, so they must not open a file writer either.
            self._router.emit("distributed_writer", fn_name, *args, **kwargs)
            return None
        # The file writer is created lazily, on the first write that actually happens locally.
        if not self._lazy_initialized:
            self.initialize()
        return fn(self, *args, **kwargs)

    return _parallel_fn
# ``SummaryWriter`` methods that are wrapped by ``enable_parallel``. Names that the installed
# tensorboardX version does not define are skipped.
ready_to_parallel_fns = [
    'add_audio',
    'add_custom_scalars',
    'add_custom_scalars_marginchart',
    'add_custom_scalars_multilinechart',
    'add_embedding',
    'add_figure',
    'add_graph',
    'add_graph_deprecated',
    'add_histogram',
    'add_histogram_raw',
    'add_hparams',
    'add_image',
    'add_image_with_boxes',
    'add_images',
    'add_mesh',
    'add_onnx_graph',
    'add_openvino_graph',
    'add_pr_curve',
    'add_pr_curve_raw',
    'add_scalar',
    'add_scalars',
    'add_text',
    'add_video',
]


def _patch_parallel_fns(writer_cls, fn_names):
    """
    Overview:
        Replace each method of ``writer_cls`` named in ``fn_names`` with its ``enable_parallel`` wrapper, \
        skipping names the class does not define.
    Arguments:
        - writer_cls (:obj:`type`): The class to patch in place.
        - fn_names (:obj:`Iterable[str]`): Names of the methods to wrap.
    """
    for name in fn_names:
        if hasattr(writer_cls, name):
            setattr(writer_cls, name, enable_parallel(name, getattr(writer_cls, name)))


# Patching inside a helper keeps the loop variable out of the module namespace.
_patch_parallel_fns(DistributedWriter, ready_to_parallel_fns)
# Module-level default writer. Example usage:
#   In main: ``distributed_writer.plugin(task.router, is_writer=True)``
#   In middleware: ``distributed_writer.add_scalar(tag, value, step)`` (or any other patched ``add_*`` method)
# Note: this instance is created directly, not through ``get_instance``, so it does not become
# ``DistributedWriter.root``.
distributed_writer = DistributedWriter()