Spaces:
Running
on
Zero
Running
on
Zero
| from pytorch_lightning.callbacks import Callback | |
| import os | |
| import shutil | |
| from omegaconf import OmegaConf | |
class SetupCallback(Callback):
    """Create the run's output directories and record its configuration.

    When fitting starts, this callback does the following:

    * creates ``logdir``, ``ckptdir`` and ``cfgdir`` if they are missing;
    * prints ``config`` as YAML and saves it to
      ``<cfgdir>/<now>-project.yaml``;
    * writes ``str(argv_content)`` to ``<logdir>/argv_content.txt``.

    Args:
        now: Value used as the file-name prefix of the saved config
            (presumably a run timestamp string -- confirm with the caller).
        logdir: Directory that receives ``argv_content.txt``.
        ckptdir: Checkpoint directory; this callback only creates it.
        cfgdir: Directory that receives the saved project config.
        config: OmegaConf config object to print and save.
        argv_content: Any object; its ``str()`` is written to disk. With the
            default ``None`` the file contains the text ``None``.
    """

    def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None):
        super().__init__()
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.argv_content = argv_content

    # Lightning hook invoked when fitting begins (the original comment
    # described it as "called at the start of the pretrain routine").
    def on_fit_start(self, trainer, pl_module):
        """Create the run directories, then print and persist config and argv."""
        # NOTE(review): there is no rank check, so under multi-process (DDP)
        # training every process repeats these writes -- confirm whether a
        # ``trainer.global_rank == 0`` guard is intended.
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        OmegaConf.save(self.config,
                       os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

        # Explicit UTF-8 so the file's encoding does not depend on the
        # platform locale (argv content may contain non-ASCII paths).
        argv_path = os.path.join(self.logdir, "argv_content.txt")
        with open(argv_path, "w", encoding="utf-8") as f:
            f.write(str(self.argv_content))
class BackupCodeCallback(Callback):
    """Snapshot a source tree into ``<backup_dir>/code`` when training starts.

    Any previous snapshot at that location is deleted first. The backup is
    best-effort: if copying fails, a message including the error is printed
    and training continues.

    Args:
        source_dir: Directory to copy.
        backup_dir: Directory under which the ``code`` snapshot is written.
        ignore_patterns: Optional callable passed as ``ignore=`` to
            ``shutil.copytree`` (for example the result of
            ``shutil.ignore_patterns(...)``). ``None`` copies everything.
    """

    def __init__(self, source_dir, backup_dir, ignore_patterns=None):
        super().__init__()
        self.source_dir = source_dir
        self.backup_dir = backup_dir
        self.ignore_patterns = ignore_patterns

    def on_train_start(self, trainer, pl_module):
        """Copy ``source_dir`` to ``<backup_dir>/code``, replacing any old copy."""
        code_dir = os.path.join(self.backup_dir, "code")
        # NOTE(review): if backup_dir lies inside source_dir, copytree may
        # descend into the snapshot it is creating -- confirm callers exclude
        # it via ignore_patterns.
        try:
            os.makedirs(self.backup_dir, exist_ok=True)
            # copytree requires a non-existent destination by default, so a
            # stale snapshot is removed first.
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree(self.source_dir, code_dir, ignore=self.ignore_patterns)
        except Exception as err:
            # Deliberately non-fatal: a failed backup must not abort training.
            # Catching Exception (not a bare except) still lets
            # KeyboardInterrupt/SystemExit propagate.
            print(f"Fail in copying file backed up to {self.backup_dir}: {err}")
            return
        print(f"Code file backed up to {self.backup_dir}")