# CMSSP / code / config.py
# Uploaded by OliXio ("Upload 13 files", commit d5233a9, verified).
import torch, json, math, os
# Default settings for the project; copied into the module-level CFG below.
d = dict(
    debug=True,
    dataset_path='data/path_to_your_dataset.json',
    fptype='morgan',
    valid_ratio=0.1,
    batch_size=128,
    lr=1e-3,
    weight_decay=1e-3,
    patience=2,
    factor=0.5,
    add_nl=True,
    binary_intn=False,
    max_mz=2000,
    min_mz=20,
    energy='Energy1',
    epochs=50,
    bin_size=0.05,
    # Placeholder only: ConfigDict.calc_ms_embedding_dim recomputes this
    # from max_mz/min_mz/bin_size (and add_nl) whenever bin_size is set.
    ms_embedding_dim=300,
    projection_dim=256,
    ms_projection_layers=1,
    mol_embedding_dim=2048,
    mol_projection_layers=1,
    tsfm_in_ms=True,
    tsfm_in_mol=False,
    tsfm_layers=6,
    tsfm_heads=8,
    lstm_layers=2,
    lstm_in_ms=False,
    lstm_in_mol=False,
    dropout=0.1,
    nmodels=1,
    mol_encoder='fp',  # fp, gnn or gnn+fp
    molgnn_n_filters_list=[256, 256, 256],
    molgnn_nhead=4,
    molgnn_readout_layers=2,
    seed=1234,
    dev_name='cuda',
    keep_best_models_num=3,
)
class ConfigDict(dict):
    '''
    Makes a dictionary behave like an object, with attribute-style access.

    Keys can be read and written either as items (``cfg['lr']``) or as
    attributes (``cfg.lr``). A missing key raises ``AttributeError`` on
    attribute access, so ``getattr(cfg, key, default)`` and ``hasattr``
    behave as for ordinary objects.
    '''

    def __getattr__(self, name):
        # Only called when normal lookup fails, so real methods and the
        # ``device`` property win over dict keys with the same name.
        try:
            return self[name]
        except KeyError:
            # ``from None`` keeps the internal KeyError out of the traceback.
            raise AttributeError(name) from None

    def __setattr__(self, name, value):
        # Every attribute assignment is stored as a dict item.
        self[name] = value

    def save(self, fn, onlyprint=False):
        '''
        Write the configuration to ``fn`` as indented JSON.

        Args:
            fn: destination file path (unused when ``onlyprint`` is True).
            onlyprint: if True, print the config instead of writing a file.
        '''
        if onlyprint:
            print(self)
            return
        # The context manager guarantees the file is flushed and closed.
        with open(fn, 'w') as f:
            json.dump(self, f, indent=2)

    def load_dict(self, dic):
        '''
        Copy every key/value pair of ``dic`` into this config, then
        recompute ``ms_embedding_dim`` from the binning settings.
        '''
        for key, value in dic.items():
            self[key] = value
        self.calc_ms_embedding_dim()

    def load(self, fn):
        '''
        Merge settings from a dict, a JSON file path, or a JSON string.

        A string is treated as a file path when that path exists; otherwise
        it is parsed as a JSON document. Loading is best-effort. Any error
        is printed rather than raised, and on failure some keys may already
        have been copied in.

        Args:
            fn: a ``dict`` (including subclasses such as ``ConfigDict``),
                a path to a JSON file, or a JSON string.
        '''
        try:
            if isinstance(fn, dict):
                new_values = fn
            elif isinstance(fn, str):
                if os.path.exists(fn):
                    # JSON text is UTF-8 by specification (RFC 8259).
                    with open(fn, 'r', encoding='utf-8') as f:
                        new_values = json.load(f)
                else:
                    new_values = json.loads(fn)
            else:
                raise TypeError(
                    f'cannot load config from object of type {type(fn).__name__}'
                )
            self.load_dict(new_values)
        except Exception as e:
            print(e)

    def calc_ms_embedding_dim(self):
        '''
        Recompute ``ms_embedding_dim`` from the m/z binning settings.

        When ``bin_size`` is set, the dimension becomes the number of bins
        spanning ``max_mz - min_mz``. When ``add_nl`` is truthy, the bins
        spanning a further 200 m/z units are added. That is presumably the
        neutral-loss range (TODO confirm).

        Raises:
            KeyError: if ``bin_size`` is set but ``max_mz``/``min_mz`` are
                missing, or if ``add_nl`` is truthy without ``bin_size``.
        '''
        if 'bin_size' in self:
            self['ms_embedding_dim'] = math.ceil(
                (self['max_mz'] - self['min_mz']) / self['bin_size']
            )
        if 'ms_embedding_dim' in self and self.get('add_nl'):
            self['ms_embedding_dim'] += math.ceil(200 / self['bin_size'])

    @property
    def device(self):
        '''
        Return a ``torch.device`` built from ``dev_name``.

        Falls back to CPU when ``dev_name`` is missing or torch rejects it.
        NOTE(review): constructing ``torch.device('cuda')`` presumably
        succeeds even without a GPU, so this fallback would not cover an
        unavailable CUDA device. Confirm before relying on it.
        '''
        try:
            return torch.device(self['dev_name'])
        except Exception:
            return torch.device('cpu')
# Module-wide configuration object: start empty, copy in the defaults from
# `d`, then derive ms_embedding_dim from the binning keys just copied.
CFG = ConfigDict()
CFG.update(d)
CFG.calc_ms_embedding_dim()