# CMSSP / code / train.py
# (Hugging Face page header preserved as comments so this file parses as
# Python: uploaded by OliXio, "Upload 5 files", commit 5946936, verified.)
from utils import *
from modules import *
import os, sys
import numpy as np
from tqdm import tqdm
import random
import torch
from torch import nn
from config import CFG
from dataset import *
import torch.utils.data
import copy, json, pickle
import itertools as it
import loss
loss_func = loss.infoNCE_loss2
def make_next_record_dir(basedir, prefix=''):
    """Create and return the first free numbered run directory.

    Candidates are probed as ``<basedir>/<prefix>001/``, ``.../<prefix>002/``,
    ... and the first one that does not exist yet is created (with a trailing
    slash) and returned.
    """
    counter = 1
    candidate = f'{basedir}/{prefix}{counter:03d}/'
    while os.path.exists(candidate):
        counter += 1
        candidate = f'{basedir}/{prefix}{counter:03d}/'
    os.makedirs(candidate)
    return candidate
def setup_seed(seed):
    """Seed torch (CPU and CUDA), numpy and random for reproducible runs."""
    for seeder in (torch.manual_seed, torch.cuda.manual_seed,
                   np.random.seed, random.seed):
        seeder(seed)
    # Trade a little speed for deterministic cuDNN kernel selection.
    torch.backends.cudnn.deterministic = True
def my_collate(batch):
    """Custom DataLoader collate_fn: drop failed (None) samples, batch the rest.

    Fixed-size fields ('ms_bins', 'mol_fps', 'mol_fmvec') are torch.stack-ed.
    Graph fields are handled only when V, A and mol_size are all present:
    V/A are padded to the largest node count in the batch (pad_V/pad_A come
    from utils via the star import) and 'mol_size' tensors are concatenated.
    Returns a possibly-empty dict of batched tensors.
    """
    samples = [s for s in batch if s is not None]
    out = {}
    for key in ('ms_bins', 'mol_fps', 'mol_fmvec'):
        values = [s[key] for s in samples if key in s]
        if values:
            out[key] = torch.stack(values)
    vertices = [s['V'] for s in samples if 'V' in s]
    adjacency = [s['A'] for s in samples if 'A' in s]
    sizes = [s['mol_size'] for s in samples if 'mol_size' in s]
    if vertices and adjacency and sizes:
        max_n = max(v.shape[0] for v in vertices)
        out['V'] = torch.stack([pad_V(v, max_n) for v in vertices])
        out['A'] = torch.stack([pad_A(a, max_n) for a in adjacency])
        out['mol_size'] = torch.cat(sizes, dim=0)
    return out
def make_train_valid(data, valid_ratio, seed=1234):
    """Deterministically split *data* into (train_set, valid_set) lists.

    The index permutation is seeded, so repeated calls with the same seed
    produce the identical split; the first ``int(valid_ratio * len(data))``
    entries of the shuffled order become the validation set.
    """
    np.random.seed(seed)
    order = np.arange(len(data))
    np.random.shuffle(order)
    n_valid = int(valid_ratio * len(data))
    valid_set = [data[i] for i in order[:n_valid]]
    train_set = [data[i] for i in order[n_valid:]]
    return train_set, valid_set
def build_loaders(inp, mode, cfg, num_workers):
    """Wrap *inp* in a dataset and return a DataLoader using my_collate.

    Dict samples go to Dataset, anything else (file paths) to PathDataset
    (both from dataset.py).  Shuffling is enabled only in "train" mode.
    NOTE(review): the exact check ``type(inp[0]) is dict`` is kept on
    purpose — dict subclasses would be routed to PathDataset, as before.
    """
    is_dict_input = type(inp[0]) is dict
    ds = Dataset(inp, cfg) if is_dict_input else PathDataset(inp, cfg)
    return torch.utils.data.DataLoader(
        ds,
        batch_size=cfg.batch_size,
        num_workers=num_workers,
        shuffle=(mode == "train"),
        collate_fn=my_collate,
    )
def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training epoch.

    Returns (loss AvgMeter, mean per-batch cosine similarity).  With
    ``step == "batch"`` the scheduler is stepped every batch; main() passes
    step == "epoch", so that branch never runs there.
    """
    model.train()
    loss_meter = AvgMeter()  # AvgMeter comes from utils (star import)
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    total_cos_sim = 0
    for batch in tqdm_object:
        # Move every collated tensor to the configured device in place.
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)
        optimizer.zero_grad()
        # Model emits one embedding per modality; loss_func (infoNCE_loss2
        # from loss.py) contrasts the paired (mol, ms) embeddings.
        mol_features, ms_features = model(batch)
        loss = loss_func(mol_features, ms_features)
        loss.backward()
        optimizer.step()
        with torch.no_grad():
            # Monitoring only: row-wise cosine similarity between the paired
            # embeddings, averaged over the batch.
            cos_sim = F.cosine_similarity(
                mol_features.detach(),
                ms_features.detach()
            ).mean().item()
        total_cos_sim += cos_sim
        if step == "batch":
            lr_scheduler.step()
        # Batch size taken from 'ms_bins'; assumes every batch has that key
        # (my_collate only adds it when present) — TODO confirm for all configs.
        count = batch["ms_bins"].size(0)
        loss_meter.update(loss.item(), count)
        tqdm_object.set_postfix(train_loss=loss_meter.avg, train_cos_sim=round(cos_sim, 4), lr=get_lr(optimizer))
        # Drop references eagerly, presumably to release GPU memory between
        # batches.
        del mol_features, ms_features, loss, cos_sim
        for k in list(batch.keys()):
            del batch[k]
        del batch
    return loss_meter, total_cos_sim / len(train_loader)
def valid_epoch(model, valid_loader):
    """Run one validation pass with gradients disabled.

    Returns (loss AvgMeter, cosine similarity between paired molecule and
    spectrum embeddings averaged over batches).
    """
    model.eval()
    loss_meter = AvgMeter()
    cos_sim_sum = 0
    with torch.no_grad():
        progress = tqdm(valid_loader, total=len(valid_loader))
        for batch in progress:
            batch = {key: val.to(CFG.device) for key, val in batch.items()}
            mol_features, ms_features = model(batch)
            loss = loss_func(mol_features, ms_features)
            loss_meter.update(loss.item(), batch["ms_bins"].size(0))
            cos_sim = F.cosine_similarity(mol_features.detach(), ms_features.detach()).mean().item()
            cos_sim_sum += cos_sim
            progress.set_postfix(valid_loss=loss_meter.avg, valid_cos_sim=round(cos_sim, 4))
            # Release references eagerly to keep GPU memory flat.
            del mol_features, ms_features, loss, cos_sim
            batch.clear()
            del batch
    return loss_meter, cos_sim_sum / len(valid_loader)
def main(data, cfg=CFG, savedir='data/train', model_path=None, ratio=1):
    """Train FragSimiModel on *data*, checkpointing every epoch.

    Parameters
    ----------
    data : list of sample dicts or of file paths (routed by build_loaders).
    cfg : configuration object (defaults to the global CFG).
    savedir : directory that receives checkpoints and trainlog.txt.
    model_path : optional .pth checkpoint whose weights are loaded first.
    ratio : fraction of the training split actually used (< 1 subsamples).

    Returns
    -------
    (kept_checkpoint_paths, last_valid_loss)
    """
    setup_seed(cfg.seed)
    train_set, valid_set = make_train_valid(data, valid_ratio=cfg.valid_ratio, seed=cfg.seed)
    log_file = f'{savedir}/trainlog.txt'
    n_total = len(train_set)
    if ratio < 1:
        train_set = random.sample(train_set, int(n_total * ratio))
        print(f'Ratio {ratio}, lenall {n_total}, newtrainset {len(train_set)}')
    train_loader = build_loaders(train_set, "train", cfg, 1)
    valid_loader = build_loaders(valid_set, "valid", cfg, 1)
    step = "epoch"
    best_loss = float('inf')
    best_model_fns = []      # every checkpoint ever written this run
    best_model_fnl = []      # robustness: defined even if cfg.epochs == 0
    model = FragSimiModel(cfg).to(cfg.device)
    print(model)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
    )
    # Load pre-trained weights if a checkpoint path is provided.
    if model_path and os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        checkpoint = torch.load(model_path, map_location=cfg.device)
        model.load_state_dict(checkpoint['state_dict'])
        print(f"Resuming training")
        del checkpoint
    # Write the training header to the log.
    with open(log_file, 'a', encoding='utf8') as f:
        f.write(f'Start training:\n')
        f.write(f'Data path: {cfg.dataset_path}, valid ratio: {cfg.valid_ratio}\n')
        if model_path:
            f.write(f'Resuming from: {model_path}\n')
        print(model, file=f)
        f.write(f'\n')
    for epoch in range(cfg.epochs):
        print(f"Epoch: {epoch + 1}/{cfg.epochs}")
        train_loss, t_cos_sim = train_epoch(model, train_loader, optimizer, lr_scheduler, step)
        valid_loss, v_cos_sim = valid_epoch(model, valid_loader)
        # BUG FIX: ReduceLROnPlateau was constructed but never stepped, so
        # the learning rate could never decay.  Feed it the epoch metric.
        if step == "epoch":
            lr_scheduler.step(valid_loss.avg)
        txt = f"Train Loss: {train_loss.avg:.4f} | Val Loss: {valid_loss.avg:.4f} | Train cos sim: {t_cos_sim:.4f} | Val cos sim: {v_cos_sim:.4f}"
        print(txt)
        # BUG FIX: was open(log_file, 'a').write(...) which leaked the handle.
        with open(log_file, 'a') as f:
            f.write(f"Epoch {epoch + 1}/{cfg.epochs}: {txt}\n")
        # A checkpoint is written every epoch (the best-only check was
        # deliberately disabled upstream), so best_loss tracks the latest
        # validation loss.
        best_loss = valid_loss.avg
        ckpt_fn = f"{savedir}/model-tloss{round(train_loss.avg, 3)}-vloss{round(valid_loss.avg, 3)}-tcos{round(t_cos_sim, 3)}-vcos{round(v_cos_sim, 3)}-epoch{epoch}.pth"
        ckpt_base = ckpt_fn.replace('.pth', '')
        suffix = 1
        while os.path.exists(ckpt_fn):  # never clobber an existing file
            ckpt_fn = ckpt_base + f'-{suffix}.pth'
            suffix += 1
        best_model_fns.append(ckpt_fn)
        torch.save({'epoch': epoch,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    # BUG FIX: was dict(CFG) — honor the cfg argument so a
                    # caller-supplied config is the one saved.
                    'config': dict(cfg),
                    'train_loss': train_loss.avg,
                    'valid_loss': valid_loss.avg,
                    'train_cos_sim': t_cos_sim,
                    'val_cos_sim': v_cos_sim
                    }, ckpt_fn)
        print("Saved new best model!")
        # Keep only the newest cfg.keep_best_models_num checkpoints on disk.
        existing = [fn for fn in best_model_fns if os.path.exists(fn)]
        for fn in existing[:-cfg.keep_best_models_num]:
            os.remove(fn)
        best_model_fnl = existing[-cfg.keep_best_models_num:]
    print(best_model_fnl, best_loss)
    return best_model_fnl, best_loss
if __name__ == "__main__":
import pickle
from tqdm import tqdm
try:
conffn = sys.argv[1]
if conffn.endswith('.json'):
CFG.load(sys.argv[1])
elif conffn.endswith('.pth'):
dpath = CFG.dataset_path
d = torch.load(conffn)
CFG.load(d['config'])
CFG.dataset_path = dpath
print('Use config from', conffn)
except:
pass
try:
savedir = sys.argv[2]
except:
savedir = 'data/'
os.system('mkdir -p %s' %savedir)
try:
prev_model_pth = sys.argv[3]
except:
prev_model_pth = None
print(CFG)
if os.path.isdir(CFG.dataset_path):
data = [os.path.join(CFG.dataset_path, i) for i in os.listdir(CFG.dataset_path) if i.endswith('mgf')]
elif os.path.isfile(CFG.dataset_path):
if CFG.dataset_path.endswith('.pkl'):
print(f'loading data from {CFG.dataset_path} ...')
data = pickle.load(open(CFG.dataset_path, 'rb'))
else:
data = json.load(open(CFG.dataset_path))
pklfn = CFG.dataset_path.replace('.json', '.pkl')
if not os.path.exists(pklfn):
pickle.dump(data, open(pklfn, 'wb'))
subdir = make_next_record_dir(savedir, f'train-neg-')
os.system(f'cp -a *py {subdir}; cp -a GNN {subdir}')
CFG.save(f'{subdir}/config.json')
modelfnl, _ = main(data, CFG, subdir, prev_model_pth)