Spaces:
Configuration error
Configuration error
| import importlib | |
| from utils import Timer | |
class MLFlow:
    """Optional MLflow logger that degrades to no-ops when MLflow is unavailable.

    Attribute access for the names in ``mlflow_ftns_with_tag_and_value``
    (``log_param``, ``log_metric``) and ``mlflow_ftns`` (``start_run``) returns a
    wrapper around the function of the same name on the ``mlflow`` module.
    When logging is disabled or ``mlflow`` cannot be imported, the wrappers do
    nothing and return ``None``.
    """

    def __init__(self, log_dir, logger, enabled):
        """Import ``mlflow`` if logging is enabled.

        Args:
            log_dir: Accepted for signature parity with ``TensorboardWriter``;
                currently unused.
            logger: Object with a ``warning(str)`` method, used to report a
                missing ``mlflow`` installation.
            enabled: Whether MLflow logging was requested.
        """
        self.mlflow = None
        # Name of the backend module actually loaded; used in error messages.
        self.selected_module = ""

        if enabled:
            # Import lazily so that mlflow stays an optional dependency.
            try:
                self.mlflow = importlib.import_module("mlflow")
            except ImportError:
                message = (
                    "Warning: visualization (mlflow) is configured to use, but currently not installed on "
                    "this machine. Please install mlflow with 'pip install mlflow' or turn off the option in "
                    "the 'config.json' file."
                )
                logger.warning(message)
            else:
                self.selected_module = "mlflow"

        self.step = 0
        self.mode = ''

        # Functions called as fn(tag, data, ...); the tag may get a "/<mode>" suffix.
        self.mlflow_ftns_with_tag_and_value = {
            'log_param', 'log_metric'
        }
        # Functions forwarded with their arguments untouched.
        self.mlflow_ftns = {
            'start_run'
        }
        # Tagged functions whose tag must NOT receive the "/<mode>" suffix.
        # The tagged wrapper reads this attribute, so it has to exist.
        self.tag_mode_exceptions = set()

    def set_step(self, step, mode='train'):
        """Record the current step and the mode (e.g. 'train'/'valid') used for tags.

        Unlike ``TensorboardWriter.set_step`` this does not log a steps/sec value.
        """
        self.mode = mode
        self.step = step

    def __getattr__(self, name):
        """Return a forwarding wrapper for a supported mlflow function.

        Python only calls this when normal attribute lookup fails. Configuration
        is read from ``self.__dict__`` directly, so a partially constructed
        instance (e.g. during ``copy.deepcopy`` or unpickling) raises
        ``AttributeError`` instead of recursing back into this method.

        Raises:
            AttributeError: ``name`` is neither a wrapped mlflow function nor a
                regular attribute.
        """
        state = self.__dict__

        if name in state.get('mlflow_ftns_with_tag_and_value', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is None:
                    return None
                # Only add the mode when one is set: an empty mode would produce
                # names like "loss/", which MLflow's name validation rejects
                # (names must be unchanged by path normalization).
                if self.mode and name not in self.tag_mode_exceptions:
                    tag = '{}/{}'.format(tag, self.mode)
                return add_data(tag, data, *args, **kwargs)
            return wrapper

        if name in state.get('mlflow_ftns', ()):
            add_data = getattr(state.get('mlflow'), name, None)

            def wrapper(*args, **kwargs):
                if add_data is None:
                    return None
                # Pass the result through so `with writer.start_run(): ...` works.
                return add_data(*args, **kwargs)
            return wrapper

        # Raise directly: `object` has no `__getattr__`, and touching `self.<attr>`
        # here could re-enter this method.
        raise AttributeError("type object '{}' has no attribute '{}'".format(
            state.get('selected_module', ''), name))
class TensorboardWriter:
    """Optional TensorBoard writer that degrades to no-ops when unavailable.

    Attribute access for the names in ``tb_writer_ftns`` returns a wrapper that
    calls ``writer.<name>(tag, data, self.step, ...)``. Unless the name is in
    ``tag_mode_exceptions``, the tag first gets a ``"/<mode>"`` suffix. When no
    writer could be created, the wrappers do nothing.
    """

    def __init__(self, log_dir, logger, enabled):
        """Create a ``SummaryWriter`` if visualization is enabled.

        Args:
            log_dir: Directory passed (as ``str``) to ``SummaryWriter``.
            logger: Object with a ``warning(str)`` method, used to report that
                no TensorBoard backend could be imported.
            enabled: Whether TensorBoard logging was requested.
        """
        self.writer = None
        # Name of the backend module actually loaded; used in error messages.
        self.selected_module = ""

        if enabled:
            log_dir = str(log_dir)

            # Retrieve the visualization writer, preferring the one bundled with PyTorch.
            for module in ("torch.utils.tensorboard", "tensorboardX"):
                try:
                    self.writer = importlib.import_module(module).SummaryWriter(log_dir)
                except ImportError:
                    continue
                # Record the backend that succeeded (not the last one that failed).
                self.selected_module = module
                break
            else:
                # The loop finished without `break`: no backend could be imported.
                message = (
                    "Warning: visualization (Tensorboard) is configured to use, but currently not installed on "
                    "this machine. Please install either TensorboardX with 'pip install tensorboardx', upgrade "
                    "PyTorch to version >= 1.1 for using 'torch.utils.tensorboard' or turn off the option in "
                    "the 'config.json' file."
                )
                logger.warning(message)

        self.step = 0
        self.mode = ''

        self.tb_writer_ftns = {
            'add_scalar', 'add_scalars', 'add_image', 'add_images', 'add_audio',
            'add_text', 'add_histogram', 'add_pr_curve', 'add_embedding'
        }
        # These functions keep their tag as given (no "/<mode>" suffix).
        self.tag_mode_exceptions = {'add_histogram', 'add_embedding'}
        self.timer = Timer()

    def set_step(self, step, mode='train'):
        """Set the global step and mode; for step > 0 also log 'steps_per_sec'.

        The rate is 1 / ``self.timer.check()``. A step of 0 resets the timer instead.
        """
        self.mode = mode
        self.step = step
        if step == 0:
            self.timer.reset()
            return
        duration = self.timer.check()
        # Guard against a zero-length interval, which would raise
        # ZeroDivisionError. Assumes check() returns elapsed seconds — TODO confirm.
        if duration > 0:
            self.add_scalar('steps_per_sec', 1 / duration)

    def __getattr__(self, name):
        """Return a wrapper for a supported ``SummaryWriter`` method.

        Python only calls this when normal attribute lookup fails. Configuration
        is read from ``self.__dict__`` directly, so a partially constructed
        instance (e.g. during ``copy.deepcopy`` or unpickling) raises
        ``AttributeError`` instead of recursing back into this method.

        Raises:
            AttributeError: ``name`` is neither a wrapped writer method nor a
                regular attribute.
        """
        state = self.__dict__

        if name in state.get('tb_writer_ftns', ()):
            add_data = getattr(state.get('writer'), name, None)

            def wrapper(tag, data, *args, **kwargs):
                if add_data is not None:
                    # Add the mode (train/valid) to the tag.
                    if name not in self.tag_mode_exceptions:
                        tag = '{}/{}'.format(tag, self.mode)
                    add_data(tag, data, self.step, *args, **kwargs)
            return wrapper

        # Raise directly: `object` has no `__getattr__`, and touching `self.<attr>`
        # here could re-enter this method.
        raise AttributeError("type object '{}' has no attribute '{}'".format(
            state.get('selected_module', ''), name))