Spaces:
Runtime error
Runtime error
| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| All rights reserved. | |
| This source code is licensed under the license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import os | |
class TrainPlatform:
    """Base class for experiment-tracking backends whose hooks all do nothing.

    Subclasses override the hooks they support. Every hook defined here just
    returns ``None``, so calls a backend does not implement are ignored.
    """

    def __init__(self, save_dir):
        """Accept the run directory; the base class keeps no state."""

    def report_scalar(self, name, value, iteration, group_name=None):
        """Record scalar ``value`` as ``name`` at step ``iteration`` (ignored here)."""

    def report_args(self, args, name):
        """Record the run configuration ``args`` under ``name`` (ignored here)."""

    def close(self):
        """Release any backend resources (nothing to release here)."""
class ClearmlPlatform(TrainPlatform):
    """Report training scalars and run arguments to a ClearML task."""

    def __init__(self, save_dir):
        """Create a ClearML task for the run stored in ``save_dir``.

        Args:
            save_dir: run directory. Its last path component becomes the task
                name, and its parent directory is passed as ``output_uri``.
        """
        # Imported lazily so clearml is only required when this backend is chosen.
        from clearml import Task
        # normpath drops a trailing separator: os.path.split('runs/exp/') would
        # return ('runs/exp', '') and give the task an empty name.
        path, name = os.path.split(os.path.normpath(save_dir))
        self.task = Task.init(project_name='motion_diffusion',
                              task_name=name,
                              output_uri=path)
        self.logger = self.task.get_logger()

    def report_scalar(self, name, value, iteration, group_name=None):
        """Log ``value`` as series ``name`` at step ``iteration``.

        Args:
            name: series name within the plot.
            value: scalar value to log.
            iteration: step index for the x-axis.
            group_name: plot title. Defaults to ``None`` so the signature
                matches ``TrainPlatform.report_scalar``; when it is ``None``
                the series name is reused as the title.
        """
        title = name if group_name is None else group_name
        self.logger.report_scalar(title=title, series=name, iteration=iteration, value=value)

    def report_args(self, args, name):
        """Attach ``args`` to the ClearML task under the section ``name``."""
        self.task.connect(args, name=name)

    def close(self):
        """Close the ClearML task."""
        self.task.close()
class TensorboardPlatform(TrainPlatform):
    """Write training scalars to TensorBoard event files.

    ``report_args`` is not overridden, so the inherited base-class method is
    used and arguments are not recorded by this backend.
    """

    def __init__(self, save_dir):
        """Open a SummaryWriter that logs into ``save_dir``."""
        # Imported lazily so torch/tensorboard is only required for this backend.
        from torch.utils.tensorboard import SummaryWriter
        self.writer = SummaryWriter(log_dir=save_dir)

    def report_scalar(self, name, value, iteration, group_name=None):
        """Log ``value`` at step ``iteration`` under tag ``group_name/name``.

        When ``group_name`` is ``None`` or empty, the tag is just ``name``
        rather than the literal ``'None/<name>'`` or ``'/<name>'``.
        """
        tag = f'{group_name}/{name}' if group_name else name
        self.writer.add_scalar(tag, value, iteration)

    def close(self):
        """Close the underlying SummaryWriter."""
        self.writer.close()
class NoPlatform(TrainPlatform):
    """Backend used when no experiment tracking is wanted; all reports are dropped."""

    def __init__(self, save_dir):
        """Ignore ``save_dir``; nothing needs to be set up."""