from utils import *
from modules import *
import os, sys
import numpy as np
from tqdm import tqdm
import random
import torch
from torch import nn
# FIX: F was only (possibly) in scope via the wildcard imports above; make the
# dependency explicit since F.cosine_similarity is used below.
import torch.nn.functional as F
from config import CFG
from dataset import *
import torch.utils.data
import copy, json, pickle
import itertools as it
import loss

# Contrastive loss used for both training and validation.
loss_func = loss.infoNCE_loss2


def make_next_record_dir(basedir, prefix=''):
    """Create and return the next free numbered run directory.

    Directories are named ``<basedir>/<prefix>001/``, ``<prefix>002/``, ...;
    the first non-existing one is created.

    NOTE(review): not race-safe if two processes start simultaneously.
    """
    path = '%s/%%s001/' % basedir
    n = 2
    while os.path.exists(path % prefix):
        path = '%s/%%s%.3d/' % (basedir, n)
        n += 1
    pth = path % prefix
    os.makedirs(pth)
    return pth


def setup_seed(seed):
    """Seed torch (CPU+CUDA), numpy and random for reproducibility."""
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True


def my_collate(batch):
    """Collate a list of sample dicts into a single batch dict.

    ``None`` samples (failed loads) are dropped.  Fixed-size tensors
    (``ms_bins``, ``mol_fps``, ``mol_fmvec``) are stacked directly; graph
    tensors ``V``/``A`` are padded to the largest molecule in the batch
    before stacking (uses ``pad_V``/``pad_A`` from utils/modules).
    """
    batch = list(filter(lambda x: (x is not None), batch))
    msbinl, molfpl, molfml, vl, al, msl = [], [], [], [], [], []
    bat = {}
    for b in batch:
        if 'ms_bins' in b:
            msbinl.append(b['ms_bins'])
        if 'mol_fps' in b:
            molfpl.append(b['mol_fps'])
        if 'mol_fmvec' in b:
            molfml.append(b['mol_fmvec'])
        if 'V' in b:
            vl.append(b['V'])
        if 'A' in b:
            al.append(b['A'])
        if 'mol_size' in b:
            msl.append(b['mol_size'])
    if msbinl:
        bat['ms_bins'] = torch.stack(msbinl)
    if molfpl:
        bat['mol_fps'] = torch.stack(molfpl)
    if molfml:
        bat['mol_fmvec'] = torch.stack(molfml)
    if vl and al and msl:
        # Pad every graph to the largest node count in the batch.
        max_n = max(map(lambda x: x.shape[0], vl))
        vl1, al1 = [], []
        for v in vl:
            vl1.append(pad_V(v, max_n))
        for a in al:
            al1.append(pad_A(a, max_n))
        bat['V'] = torch.stack(vl1)
        bat['A'] = torch.stack(al1)
        bat['mol_size'] = torch.cat(msl, dim=0)
    #return torch.utils.data.dataloader.default_collate(batch)
    return bat


def make_train_valid(data, valid_ratio, seed=1234):
    """Shuffle ``data`` with a fixed seed and split off ``valid_ratio`` for validation."""
    idxs = np.arange(len(data))
    np.random.seed(seed)
    np.random.shuffle(idxs)
    lenval = int(valid_ratio * len(data))
    valid_set = [data[i] for i in idxs[:lenval]]
    train_set = [data[i] for i in idxs[lenval:]]
    return train_set, valid_set


def build_loaders(inp, mode, cfg, num_workers):
    """Build a DataLoader; dict samples go to Dataset, paths to PathDataset.

    Shuffling is enabled only in ``"train"`` mode.
    """
    # FIX: isinstance instead of `type(...) is dict` (idiomatic, allows subclasses).
    if isinstance(inp[0], dict):
        dataset = Dataset(inp, cfg)
    else:
        dataset = PathDataset(inp, cfg)
    dataloader = torch.utils.data.DataLoader(
        dataset,
        batch_size=cfg.batch_size,
        num_workers=num_workers,
        shuffle=True if mode == "train" else False,
        collate_fn=my_collate,
    )
    return dataloader


def train_epoch(model, train_loader, optimizer, lr_scheduler, step):
    """Run one training epoch.

    Returns ``(loss_meter, mean_cosine_similarity)`` where the cosine
    similarity is between molecule and spectrum embeddings, averaged over
    batches.  ``lr_scheduler.step()`` is only called here when
    ``step == "batch"``; per-epoch stepping is the caller's job.
    """
    model.train()
    loss_meter = AvgMeter()
    tqdm_object = tqdm(train_loader, total=len(train_loader))
    total_cos_sim = 0
    for batch in tqdm_object:
        for k, v in batch.items():
            batch[k] = v.to(CFG.device)
        optimizer.zero_grad()
        mol_features, ms_features = model(batch)
        # FIX: renamed local from `loss` to avoid shadowing the imported
        # `loss` module.
        batch_loss = loss_func(mol_features, ms_features)
        batch_loss.backward()
        optimizer.step()
        with torch.no_grad():
            cos_sim = F.cosine_similarity(
                mol_features.detach(), ms_features.detach()
            ).mean().item()
            total_cos_sim += cos_sim
        if step == "batch":
            lr_scheduler.step()
        count = batch["ms_bins"].size(0)
        loss_meter.update(batch_loss.item(), count)
        tqdm_object.set_postfix(train_loss=loss_meter.avg,
                                train_cos_sim=round(cos_sim, 4),
                                lr=get_lr(optimizer))
        # Drop references eagerly to keep GPU memory pressure low.
        del mol_features, ms_features, batch_loss, cos_sim
        for k in list(batch.keys()):
            del batch[k]
        del batch
    return loss_meter, total_cos_sim / len(train_loader)


def valid_epoch(model, valid_loader):
    """Run one validation epoch (no gradients).

    Returns ``(loss_meter, mean_cosine_similarity)``, mirroring train_epoch.
    """
    model.eval()
    loss_meter = AvgMeter()
    total_cos_sim = 0
    with torch.no_grad():
        tqdm_object = tqdm(valid_loader, total=len(valid_loader))
        for batch in tqdm_object:
            for k, v in batch.items():
                batch[k] = v.to(CFG.device)
            mol_features, ms_features = model(batch)
            batch_loss = loss_func(mol_features, ms_features)
            count = batch["ms_bins"].size(0)
            loss_meter.update(batch_loss.item(), count)
            cos_sim = F.cosine_similarity(mol_features.detach(),
                                          ms_features.detach()).mean().item()
            total_cos_sim += cos_sim
            tqdm_object.set_postfix(valid_loss=loss_meter.avg,
                                    valid_cos_sim=round(cos_sim, 4))
            del mol_features, ms_features, batch_loss, cos_sim
            for k in list(batch.keys()):
                del batch[k]
            del batch
    return loss_meter, total_cos_sim / len(valid_loader)


def main(data, cfg=CFG, savedir='data/train', model_path=None, ratio=1):
    """Train FragSimiModel on ``data``; checkpoint every epoch into ``savedir``.

    Args:
        data: list of sample dicts or dataset file paths.
        cfg: configuration object (defaults to the global CFG).
        savedir: directory for checkpoints and the training log.
        model_path: optional .pth checkpoint to resume weights from.
        ratio: fraction of the training split actually used (< 1 subsamples).

    Returns:
        (list of kept checkpoint paths, best validation loss).
    """
    setup_seed(cfg.seed)
    train_set, valid_set = make_train_valid(data,
                                            valid_ratio=cfg.valid_ratio,
                                            seed=cfg.seed)
    log_file = f'{savedir}/trainlog.txt'
    n = len(train_set)
    if ratio < 1:
        train_set = random.sample(train_set, int(n * ratio))
        print(f'Ratio {ratio}, lenall {n}, newtrainset {len(train_set)}')
    train_loader = build_loaders(train_set, "train", cfg, 1)
    valid_loader = build_loaders(valid_set, "valid", cfg, 1)

    step = "epoch"
    best_loss = float('inf')
    best_model_fn = ''
    best_model_fns = []
    best_model_fnl = []

    model = FragSimiModel(cfg).to(cfg.device)
    print(model)
    optimizer = torch.optim.AdamW(
        model.parameters(), lr=cfg.lr, weight_decay=cfg.weight_decay
    )
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=cfg.patience, factor=cfg.factor
    )

    # Load pre-trained model if path is provided
    if model_path and os.path.exists(model_path):
        print(f"Loading model from {model_path}")
        checkpoint = torch.load(model_path, map_location=cfg.device)
        model.load_state_dict(checkpoint['state_dict'])
        '''if 'optimizer' in checkpoint:
            optimizer.load_state_dict(checkpoint['optimizer'])
            print("Loaded optimizer state")'''
        print(f"Resuming training")
        del checkpoint

    # write training log
    with open(log_file, 'a', encoding='utf8') as f:
        f.write(f'Start training:\n')
        f.write(f'Data path: {cfg.dataset_path}, valid ratio: {cfg.valid_ratio}\n')
        if model_path:
            f.write(f'Resuming from: {model_path}\n')
        print(model, file=f)
        f.write(f'\n')

    for epoch in range(cfg.epochs):
        print(f"Epoch: {epoch + 1}/{cfg.epochs}")
        train_loss, t_cos_sim = train_epoch(model, train_loader, optimizer,
                                            lr_scheduler, step)
        valid_loss, v_cos_sim = valid_epoch(model, valid_loader)
        # FIX: the ReduceLROnPlateau scheduler was created but never stepped
        # (step == "epoch" means train_epoch never calls it), so the learning
        # rate could never decay.  Feed it the validation loss once per epoch.
        lr_scheduler.step(valid_loss.avg)
        txt = (f"Train Loss: {train_loss.avg:.4f} | Val Loss: {valid_loss.avg:.4f}"
               f" | Train cos sim: {t_cos_sim:.4f} | Val cos sim: {v_cos_sim:.4f}")
        print(txt)
        # FIX: use a context manager so the log handle is closed each epoch.
        with open(log_file, 'a') as f:
            f.write(f"Epoch {epoch + 1}/{cfg.epochs}: {txt}\n")

        if True:  # valid_loss.avg < best_loss:  -- deliberately save every epoch
            best_loss = valid_loss.avg
            best_model_fn = (f"{savedir}/model-tloss{round(train_loss.avg, 3)}"
                             f"-vloss{round(valid_loss.avg, 3)}"
                             f"-tcos{round(t_cos_sim, 3)}"
                             f"-vcos{round(v_cos_sim, 3)}-epoch{epoch}.pth")
            # Avoid clobbering an existing checkpoint with the same metrics.
            best_model_fn_base = best_model_fn.replace('.pth', '')
            n = 1
            while os.path.exists(best_model_fn):
                best_model_fn = best_model_fn_base + f'-{n}.pth'
                n += 1
            best_model_fns.append(best_model_fn)
            torch.save({'epoch': epoch,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'config': dict(CFG),
                        'train_loss': train_loss.avg,
                        'valid_loss': valid_loss.avg,
                        'train_cos_sim': t_cos_sim,
                        'val_cos_sim': v_cos_sim}, best_model_fn)
            print("Saved new best model!")
            # Keep only the most recent cfg.keep_best_models_num checkpoints.
            best_model_fnl = []
            for fn in best_model_fns:
                if os.path.exists(fn):
                    best_model_fnl.append(fn)
            for fn in best_model_fnl[:-cfg.keep_best_models_num]:
                os.remove(fn)
            best_model_fnl = best_model_fnl[-cfg.keep_best_models_num:]

    print(best_model_fnl, best_loss)
    return best_model_fnl, best_loss


if __name__ == "__main__":
    import pickle
    from tqdm import tqdm

    # Optional argv[1]: either a .json config file, or a .pth checkpoint
    # whose stored config is reused (dataset_path from the current CFG wins).
    # FIX: explicit argv-length checks replace the bare `except: pass`, which
    # also silently hid real config-loading failures; those now warn instead.
    if len(sys.argv) > 1:
        conffn = sys.argv[1]
        try:
            if conffn.endswith('.json'):
                CFG.load(sys.argv[1])
            elif conffn.endswith('.pth'):
                dpath = CFG.dataset_path
                d = torch.load(conffn)
                CFG.load(d['config'])
                CFG.dataset_path = dpath
            print('Use config from', conffn)
        except Exception as e:
            print(f'Warning: failed to load config from {conffn}: {e}')

    savedir = sys.argv[2] if len(sys.argv) > 2 else 'data/'
    # FIX: os.makedirs instead of shelling out to `mkdir -p` (portable, no shell).
    os.makedirs(savedir, exist_ok=True)
    prev_model_pth = sys.argv[3] if len(sys.argv) > 3 else None

    print(CFG)
    if os.path.isdir(CFG.dataset_path):
        data = [os.path.join(CFG.dataset_path, i)
                for i in os.listdir(CFG.dataset_path) if i.endswith('mgf')]
    elif os.path.isfile(CFG.dataset_path):
        if CFG.dataset_path.endswith('.pkl'):
            print(f'loading data from {CFG.dataset_path} ...')
            # NOTE(review): pickle.load executes arbitrary code on load --
            # only use trusted dataset files here.
            with open(CFG.dataset_path, 'rb') as fh:
                data = pickle.load(fh)
        else:
            with open(CFG.dataset_path) as fh:
                data = json.load(fh)
            # Cache the parsed JSON as a pickle for faster future loads.
            pklfn = CFG.dataset_path.replace('.json', '.pkl')
            if not os.path.exists(pklfn):
                with open(pklfn, 'wb') as fh:
                    pickle.dump(data, fh)

    subdir = make_next_record_dir(savedir, f'train-neg-')
    # Snapshot the source code alongside the run for reproducibility.
    os.system(f'cp -a *py {subdir}; cp -a GNN {subdir}')
    CFG.save(f'{subdir}/config.json')
    modelfnl, _ = main(data, CFG, subdir, prev_model_pth)