File size: 2,461 Bytes
d5233a9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import torch, json, math, os

# Default configuration; loaded into the module-level CFG singleton below.
# Keys marked "confirm" are interpretations from names only -- verify against
# the training/dataset code that consumes CFG.
d = {
    'debug': True,
    'dataset_path': 'data/path_to_your_dataset.json',  # placeholder path -- must be replaced
    'fptype': 'morgan',  # fingerprint type; presumably Morgan/ECFP -- confirm
    'valid_ratio': 0.1,  # presumably validation split fraction -- confirm
    'batch_size': 128,
    'lr': 1e-3,  # learning rate
    'weight_decay': 1e-3,
    'patience': 2,   # presumably LR-scheduler patience -- confirm
    'factor': 0.5,   # presumably LR-scheduler decay factor -- confirm
    'add_nl': True,  # adds a 200 m/z neutral-loss band to ms_embedding_dim (see calc_ms_embedding_dim)
    'binary_intn': False,  # presumably binarize peak intensities -- confirm
    'max_mz': 2000,  # upper m/z bound used when deriving ms_embedding_dim
    'min_mz': 20,    # lower m/z bound used when deriving ms_embedding_dim
    'energy': 'Energy1',
    'epochs': 50,
    'bin_size': 0.05,  # m/z bin width; ms_embedding_dim = ceil((max_mz - min_mz) / bin_size)
    'ms_embedding_dim': 300,  # NOTE: overwritten by calc_ms_embedding_dim() when bin_size is set
    'projection_dim': 256,
    'ms_projection_layers': 1,
    'mol_embedding_dim': 2048,
    'mol_projection_layers': 1,
    'tsfm_in_ms': True,    # presumably: use a transformer in the MS branch -- confirm
    'tsfm_in_mol': False,  # presumably: use a transformer in the molecule branch -- confirm
    'tsfm_layers': 6,
    'tsfm_heads': 8,
    'lstm_layers': 2,
    'lstm_in_ms': False,
    'lstm_in_mol': False,
    'dropout': 0.1,
    'nmodels': 1,  # presumably ensemble size -- confirm
    'mol_encoder': 'fp', # fp, gnn or gnn+fp
    'molgnn_n_filters_list': [256, 256, 256],
    'molgnn_nhead': 4,
    'molgnn_readout_layers': 2,
    'seed': 1234,
    'dev_name': 'cuda',  # device string passed to torch.device (see ConfigDict.device)
    'keep_best_models_num': 3
}

class ConfigDict(dict):
    '''
    A dict subclass with attribute-style access (cfg.key <-> cfg['key']),
    JSON save/load helpers, and a derived spectrum-embedding size.
    '''

    def __getattr__(self, name):
        # Invoked only when normal attribute lookup fails; fall through to
        # the mapping. Catch KeyError specifically -- the original bare
        # `except:` would also swallow KeyboardInterrupt/SystemExit.
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Attribute assignment writes straight into the mapping so that
        # cfg.key = v and cfg['key'] = v are equivalent.
        self[name] = value

    def save(self, fn, onlyprint=False):
        '''Write the config as indented JSON to file *fn*, or just print
        the mapping when *onlyprint* is true.'''
        if onlyprint:
            print(self)
        else:
            # Context manager ensures the file handle is closed
            # (the original relied on GC to close it).
            with open(fn, 'w') as f:
                json.dump(self, f, indent=2)

    def load_dict(self, dic):
        '''Merge *dic* into the config, then refresh the derived
        ms_embedding_dim.'''
        for k, v in dic.items():
            self[k] = v
        self.calc_ms_embedding_dim()

    def load(self, fn):
        '''Load configuration from *fn*: a dict, a path to a JSON file,
        or a raw JSON string.

        Errors are printed rather than raised (best-effort load, matching
        the original behavior).
        '''
        try:
            if isinstance(fn, dict):
                d = fn
            elif isinstance(fn, str):
                if os.path.exists(fn):
                    with open(fn, 'r') as f:
                        d = json.load(f)
                else:
                    d = json.loads(fn)
            else:
                # The original fell through to an UnboundLocalError here;
                # raise a clear message instead (still swallowed below).
                raise TypeError(f'unsupported config source: {type(fn).__name__}')
            self.load_dict(d)
        except Exception as e:
            print(e)

    def calc_ms_embedding_dim(self):
        '''Derive ms_embedding_dim from the m/z window and bin size.

        Number of bins covering [min_mz, max_mz], plus an extra 200-m/z
        band when add_nl is set (presumably the neutral-loss range --
        confirm against the dataset/featurization code).
        '''
        if 'bin_size' in self:
            self['ms_embedding_dim'] = math.ceil((self['max_mz'] - self['min_mz']) / self['bin_size'])
        if 'ms_embedding_dim' in self and 'add_nl' in self and self['add_nl']:
            self['ms_embedding_dim'] += math.ceil(200 / self['bin_size'])

    @property
    def device(self):
        '''torch.device built from dev_name; falls back to CPU when the
        key is missing or the device name is invalid.'''
        try:
            return torch.device(self['dev_name'])
        except (KeyError, RuntimeError, TypeError):
            return torch.device('cpu')

# Module-level configuration singleton, seeded from the defaults above.
# load_dict() also derives ms_embedding_dim from min_mz/max_mz/bin_size
# (and adds the neutral-loss band because add_nl is True).
CFG = ConfigDict()
CFG.load_dict(d)