Spaces:
Running
on
Zero
Running
on
Zero
File size: 1,827 Bytes
7968cb0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 |
from pytorch_lightning.callbacks import Callback
import os
import shutil
from omegaconf import OmegaConf
class SetupCallback(Callback):
    """Create the run's output directories and record its configuration.

    When fitting starts, this callback:

    - makes sure the log, checkpoint and config directories exist;
    - prints the project config as YAML;
    - saves it to ``<cfgdir>/<now>-project.yaml``;
    - writes ``str(argv_content)`` to ``<logdir>/argv_content.txt``.

    Args:
        now: Prefix for the saved config filename (presumably a run
            timestamp string -- confirm with the caller).
        logdir: Directory for run logs; receives ``argv_content.txt``.
        ckptdir: Checkpoint directory; this callback only creates it.
        cfgdir: Directory that receives the saved project config.
        config: Project configuration handed to ``OmegaConf.to_yaml`` and
            ``OmegaConf.save`` (assumed to be an OmegaConf container).
        argv_content: Any object; it is recorded via ``str()``. With the
            default ``None`` the file contains the text ``None``.
    """

    def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None):
        super().__init__()
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.argv_content = argv_content

    # Called when the pretraining (fit) routine starts.
    def on_fit_start(self, trainer, pl_module):
        """Lightning hook: create output directories and save the run config.

        ``trainer`` and ``pl_module`` are not used.
        """
        # Create all output directories up front; exist_ok makes reruns safe.
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        OmegaConf.save(self.config,
                       os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

        # Write explicitly as UTF-8: str(argv_content) may contain non-ASCII
        # characters (e.g. in file paths) that the platform's default
        # encoding cannot always represent.
        argv_path = os.path.join(self.logdir, "argv_content.txt")
        with open(argv_path, "w", encoding="utf-8") as f:
            f.write(str(self.argv_content))
class BackupCodeCallback(Callback):
    """Snapshot the project's source tree when training starts.

    Copies ``source_dir`` to ``<backup_dir>/code`` and replaces any previous
    snapshot there. The backup is best-effort: a failure is reported with
    ``print`` and never interrupts training.

    Args:
        source_dir: Directory to copy.
        backup_dir: Parent directory of the ``code`` snapshot; created if
            missing.
        ignore_patterns: Entries to skip while copying. Accepted forms:

            - a callable with the ``shutil.copytree`` ``ignore`` signature
              (e.g. the result of ``shutil.ignore_patterns(...)``);
            - a single glob-pattern string;
            - an iterable of glob patterns;
            - ``None`` to copy everything.
    """

    def __init__(self, source_dir, backup_dir, ignore_patterns=None):
        super().__init__()
        self.source_dir = source_dir
        self.backup_dir = backup_dir
        self.ignore_patterns = ignore_patterns

    @staticmethod
    def _resolve_ignore(ignore_patterns):
        """Convert ``ignore_patterns`` into a value ``copytree`` accepts."""
        if ignore_patterns is None or callable(ignore_patterns):
            return ignore_patterns
        if isinstance(ignore_patterns, str):
            # Wrap explicitly: a bare string would otherwise be unpacked
            # character by character.
            return shutil.ignore_patterns(ignore_patterns)
        return shutil.ignore_patterns(*ignore_patterns)

    def on_train_start(self, trainer, pl_module):
        """Lightning hook: copy ``source_dir`` into ``<backup_dir>/code``.

        ``trainer`` and ``pl_module`` are not used.
        """
        code_dir = os.path.join(self.backup_dir, "code")
        try:
            os.makedirs(self.backup_dir, exist_ok=True)
            # copytree refuses to copy into an existing directory, so the
            # previous snapshot is removed first.
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree(self.source_dir, code_dir,
                            ignore=self._resolve_ignore(self.ignore_patterns))
        except Exception as err:
            # Best-effort: a failed backup must not abort training, but the
            # reason is reported. Catching Exception (not everything) lets
            # KeyboardInterrupt / SystemExit propagate.
            print(f"Failed to back up code from {self.source_dir} "
                  f"to {self.backup_dir}: {err!r}")
        else:
            print(f"Code file backed up to {self.backup_dir}")