Spaces:
Paused
Paused
| import os | |
| import glob | |
class TrainPlatform:
    """Base class for training-report back ends.

    Every reporting method here is a no-op; subclasses override the ones
    their back end supports.
    """

    def __init__(self, save_dir, *args, **kwargs):
        """Derive ``path`` and ``name`` from ``save_dir``.

        ``save_dir`` is split into its parent directory (stored as
        ``self.path``) and its last component. The last component becomes
        ``self.name`` unless a ``name`` keyword argument is supplied.
        Extra positional and keyword arguments are accepted and ignored.
        """
        parent_dir, last_component = os.path.split(save_dir)
        self.path = parent_dir
        self.name = kwargs['name'] if 'name' in kwargs else last_component

    def report_scalar(self, name, value, iteration, group_name=None):
        """Report a scalar value; does nothing in the base class."""
        return None

    def report_media(self, title, series, iteration, local_path):
        """Report media found at ``local_path``; does nothing in the base class."""
        return None

    def report_args(self, args, name):
        """Report run arguments; does nothing in the base class."""
        return None

    def close(self):
        """Release back-end resources; does nothing in the base class."""
        return None
# Deprecated
class ClearmlPlatform(TrainPlatform):
    """Training platform that reports to a ClearML task.

    The constructor does not call ``TrainPlatform.__init__``, so
    ``self.path`` and ``self.name`` are not set on instances of this class.
    """

    def __init__(self, save_dir):
        """Start a ClearML task named after the last component of ``save_dir``.

        The parent directory of ``save_dir`` is passed as the task's
        ``output_uri``.
        """
        # Imported lazily so clearml is only required when this platform is used.
        from clearml import Task

        path, name = os.path.split(save_dir)
        self.task = Task.init(project_name='motion_diffusion',
                              task_name=name,
                              output_uri=path)
        self.logger = self.task.get_logger()

    def report_scalar(self, name, value, iteration, group_name=None):
        """Report a scalar, using ``group_name`` as the title and ``name`` as the series.

        ``group_name`` defaults to ``None`` to match the signature of
        ``TrainPlatform.report_scalar``; when it is ``None`` the scalar's
        own ``name`` is used as the title.
        """
        title = name if group_name is None else group_name
        self.logger.report_scalar(title=title, series=name, iteration=iteration, value=value)

    def report_media(self, title, series, iteration, local_path):
        """Forward a media report for ``local_path`` to the ClearML logger."""
        self.logger.report_media(title=title, series=series, iteration=iteration, local_path=local_path)

    def report_args(self, args, name):
        """Connect ``args`` to the task under the given ``name``."""
        self.task.connect(args, name=name)

    def close(self):
        """Close the ClearML task."""
        self.task.close()
class TensorboardPlatform(TrainPlatform):
    """Training platform that writes scalars through a TensorBoard ``SummaryWriter``.

    Only ``report_scalar`` and ``close`` are overridden; the other reporting
    methods are inherited from ``TrainPlatform``. The constructor does not
    call ``TrainPlatform.__init__``, so ``self.path`` and ``self.name`` are
    not set on instances of this class.
    """

    def __init__(self, save_dir):
        """Create a ``SummaryWriter`` that logs into ``save_dir``."""
        # Imported lazily so torch is only required when this platform is used.
        from torch.utils.tensorboard import SummaryWriter

        self.writer = SummaryWriter(log_dir=save_dir)

    def report_scalar(self, name, value, iteration, group_name=None):
        """Write ``value`` at step ``iteration``.

        The tag is ``"<group_name>/<name>"`` when a group is given and just
        ``name`` otherwise, so the default ``group_name=None`` no longer
        produces a literal ``"None/"`` prefix.
        """
        tag = name if group_name is None else f'{group_name}/{name}'
        self.writer.add_scalar(tag, value, iteration)

    def close(self):
        """Close the underlying writer."""
        self.writer.close()
class NoPlatform(TrainPlatform):
    """Platform that reports nothing.

    The constructor accepts the same arguments as ``TrainPlatform`` but
    ignores them and does not call the base constructor, so ``self.path``
    and ``self.name`` are not set.
    """

    def __init__(self, save_dir, *args, **kwargs):
        """Accept and discard all arguments."""
        return None
class WandBPlatform(TrainPlatform):
    """Training platform that reports to Weights & Biases."""

    # Importing inside the class body binds the module as a class attribute,
    # which is why every method reaches it through ``self.wandb``.
    # NOTE: this import runs when the class is defined, so ``wandb`` must be
    # installed for this module to be imported at all.
    import wandb

    def __init__(self, save_dir, config=None, *args, **kwargs):
        """Log in to W&B and start (or resume) a run named after ``save_dir``.

        The run name comes from ``TrainPlatform.__init__``: the last
        component of ``save_dir`` unless a ``name`` keyword is given.
        The host and API key are read from the ``WANDB_BASE_URL`` and
        ``WANDB_API_KEY`` environment variables.
        """
        super().__init__(save_dir, *args, **kwargs)
        self.wandb.login(host=os.getenv("WANDB_BASE_URL"), key=os.getenv("WANDB_API_KEY"))
        self.wandb.init(
            project='motion_diffusion',
            name=self.name,
            id=self.name,  # in order to send continued runs to the same record
            resume='allow',  # in order to send continued runs to the same record
            entity='tau-motion',  # will use your default entity if not set
            save_code=True,
            config=config)  # config can also be sent via report_args()

    def report_scalar(self, name, value, iteration, group_name=None):
        """Log ``value`` under the key ``name`` at step ``iteration``.

        ``group_name`` is accepted for interface compatibility and is not used.
        """
        self.wandb.log({name: value}, step=iteration)

    def report_media(self, title, series, iteration, local_path):
        """Log every ``*.mp4`` file directly inside ``local_path`` as a video.

        The videos are logged as a list under the key ``series``;
        ``title`` is not used.
        """
        files = glob.glob(f'{local_path}/*.mp4')
        videos = [self.wandb.Video(file, format='mp4', fps=20) for file in files]
        self.wandb.log({series: videos}, step=iteration)

    def report_args(self, args, name):
        """Merge ``args`` into the run config; ``name`` is not used."""
        # Pass allow_val_change=True to config.update() ONLY if you want to
        # change existing args (e.g., overwrite).
        self.wandb.config.update(args)

    def watch_model(self, *args, **kwargs):
        """Forward all arguments to ``wandb.watch``."""
        # Unpack so positional and keyword arguments reach wandb.watch as
        # given; the previous code passed the args tuple and the kwargs dict
        # as two positional arguments.
        self.wandb.watch(*args, **kwargs)

    def close(self):
        """Finish the W&B run."""
        self.wandb.finish()