File size: 1,827 Bytes
7968cb0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
from pytorch_lightning.callbacks import Callback
import os
import shutil
from omegaconf import OmegaConf

class SetupCallback(Callback):
    """Create the run directories and persist the run configuration when fitting starts.

    Args:
        now: Value interpolated into the saved config filename
            (``"<now>-project.yaml"``); presumably a run timestamp string —
            confirm with the caller.
        logdir: Directory that receives ``argv_content.txt``; created if missing.
        ckptdir: Checkpoint directory; this callback only creates it.
        cfgdir: Directory that receives the saved project config; created if missing.
        config: Configuration object handed to ``OmegaConf.to_yaml`` and
            ``OmegaConf.save``.
        argv_content: Any object; its ``str()`` is written to
            ``<logdir>/argv_content.txt`` (the text ``"None"`` when omitted).
    """

    def __init__(self, now, logdir, ckptdir, cfgdir, config, argv_content=None):
        super().__init__()
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.argv_content = argv_content

    # Lightning hook invoked at the start of the fit routine.
    def on_fit_start(self, trainer, pl_module):
        """Create log/checkpoint/config directories and save the run config.

        Prints the config as YAML, saves it to ``<cfgdir>/<now>-project.yaml``
        and writes ``str(argv_content)`` to ``<logdir>/argv_content.txt``.
        ``trainer`` and ``pl_module`` are unused.
        """
        # Create logdirs (exist_ok so a resumed run does not fail).
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        OmegaConf.save(self.config,
                       os.path.join(self.cfgdir, "{}-project.yaml".format(self.now)))

        # Explicit UTF-8: argv may contain non-ASCII arguments/paths, which the
        # platform default encoding (e.g. cp1252 on Windows) cannot always encode.
        argv_path = os.path.join(self.logdir, "argv_content.txt")
        with open(argv_path, "w", encoding="utf-8") as f:
            f.write(str(self.argv_content))

class BackupCodeCallback(Callback):
    """Snapshot a source directory into ``<backup_dir>/code`` when training starts.

    Args:
        source_dir: Directory whose contents are copied.
        backup_dir: Directory that receives the ``code`` sub-directory; created
            if missing. An existing ``<backup_dir>/code`` is deleted and replaced.
        ignore_patterns: One of
            - ``None``: copy everything;
            - a callable usable as ``shutil.copytree``'s ``ignore`` argument
              (e.g. ``shutil.ignore_patterns("*.pyc", "logs")``);
            - a glob pattern string or an iterable of glob pattern strings,
              which are wrapped with ``shutil.ignore_patterns``.
    """

    def __init__(self, source_dir, backup_dir, ignore_patterns=None):
        super().__init__()
        self.source_dir = source_dir
        self.backup_dir = backup_dir
        self.ignore_patterns = ignore_patterns

    def _build_ignore(self, dest_dir):
        """Return a ``copytree`` ignore callable that honours ``ignore_patterns``
        and always skips ``dest_dir`` itself."""
        user_ignore = self.ignore_patterns
        if user_ignore is not None and not callable(user_ignore):
            # Plain glob patterns: copytree needs a callable, so wrap them.
            patterns = [user_ignore] if isinstance(user_ignore, str) else list(user_ignore)
            user_ignore = shutil.ignore_patterns(*patterns)

        dest_parent, dest_name = os.path.split(os.path.abspath(dest_dir))

        def ignore(directory, names):
            ignored = set(user_ignore(directory, names)) if user_ignore is not None else set()
            # If backup_dir lives under source_dir, copytree would otherwise copy
            # the growing destination into itself until the path-length limit.
            if dest_name in names and os.path.abspath(directory) == dest_parent:
                ignored.add(dest_name)
            return ignored

        return ignore

    def on_train_start(self, trainer, pl_module):
        """Copy ``source_dir`` to ``<backup_dir>/code``, replacing any previous copy.

        The backup is best-effort: any ordinary exception is reported and
        training continues. ``trainer`` and ``pl_module`` are unused.
        """
        code_dir = os.path.join(self.backup_dir, "code")
        try:
            os.makedirs(self.backup_dir, exist_ok=True)
            if os.path.exists(code_dir):
                shutil.rmtree(code_dir)
            shutil.copytree(self.source_dir, code_dir, ignore=self._build_ignore(code_dir))
        except Exception as err:  # best-effort: never abort training over a backup
            print(f"Fail in copying file backed up to {self.backup_dir}: {err}")
        else:
            print(f"Code file backed up to {self.backup_dir}")