|
|
from sklearn.metrics import r2_score |
|
|
import numpy as np |
|
|
|
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from tqdm import tqdm |
|
|
import argparse |
|
|
|
|
|
import pickle |
|
|
|
|
|
from utils_data import read_data, MolDataset, collate_mol, get_train_test_data |
|
|
from utils_model import ModelDimeNet |
|
|
|
|
|
|
|
|
def get_model():
    """Build the network trained by this script.

    Returns:
        A newly constructed ``ModelDimeNet`` created with its default
        constructor arguments.
    """
    return ModelDimeNet()
|
|
|
|
|
|
|
|
def get_optimizer(model, e_start):
    """Create the RMSprop optimizer for ``model``.

    Args:
        model: Module whose ``parameters()`` are handed to the optimizer.
        e_start: Negated base-10 exponent of the initial learning rate; the
            optimizer starts at ``10 ** -e_start``.

    Returns:
        A ``torch.optim.RMSprop`` instance over the parameters of ``model``.
    """
    initial_lr = 10 ** -e_start
    return torch.optim.RMSprop(model.parameters(), lr=initial_lr)
|
|
|
|
|
def get_loss(mode):
    """Return the training loss selected by ``mode``.

    Args:
        mode: Tuple whose first item names the loss. ``('mae',)`` selects the
            mean absolute error. ``('adaptive', k)`` selects the absolute
            error divided by ``(|target| + 1e-5) ** k`` and then averaged.

    Returns:
        A callable ``loss(pred, en)`` that reduces two tensors to a scalar.

    Raises:
        ValueError: If the loss name is unknown, or if ``'adaptive'`` is
            requested without its exponent.
    """
    name = mode[0]

    if name == 'mae':
        return lambda pred, en: (pred - en).abs().mean()

    if name == 'adaptive':
        if len(mode) < 2:
            raise ValueError(
                "loss mode 'adaptive' requires an exponent: ('adaptive', k)"
            )
        # Bind the exponent now so the returned loss does not depend on the
        # ``mode`` container staying unchanged.
        k = mode[1]
        return lambda pred, en: ((pred - en).abs() / (en.abs() + 1e-5) ** k).mean()

    # An unknown name used to fall through and return None, which only failed
    # later with "'NoneType' object is not callable" inside the training loop.
    raise ValueError(f"unknown loss mode: {name!r}")
|
|
|
|
|
def train_epoch(model, optimizer, dl_train, loss_fn, device):
    """Run one optimization pass over ``dl_train``.

    Args:
        model: Module called as ``model(atoms, coords, batch)``.
        optimizer: Optimizer that is zeroed and stepped once per batch.
        dl_train: Iterable yielding ``(atoms, coords, energy, batch)`` tensors.
        loss_fn: Callable ``loss_fn(pred, target)`` returning a scalar tensor.
        device: Device every tensor of a batch is moved to before the forward
            pass.

    Returns:
        None. The parameters of ``model`` are updated in place.
    """
    model.train()

    for atoms, coords, energy, batch in dl_train:
        optimizer.zero_grad()

        atoms, coords, energy, batch = (
            tensor.to(device) for tensor in (atoms, coords, energy, batch)
        )

        # Drop size-1 dimensions from both target and prediction before the
        # loss is evaluated.
        target = energy.squeeze()
        output = model(atoms, coords, batch).squeeze()

        loss_fn(output, target).backward()
        optimizer.step()
|
|
|
|
|
|
|
|
def test_epoch(model, optimizer, dl_test, device):
    """Evaluate ``model`` on every batch of ``dl_test``.

    Args:
        model: Module called as ``model(atoms, coords, batch)``.
        optimizer: Unused; kept so existing call sites keep working.
        dl_test: Iterable yielding ``(atoms, coords, energy, batch)`` tensors.
        device: Device every tensor of a batch is moved to before the forward
            pass.

    Returns:
        Dict with two entries: ``'r2_score'`` (``sklearn.metrics.r2_score`` of
        targets vs. predictions over the whole loader) and ``'mae'`` (mean
        absolute error over all evaluated values).

    Raises:
        ValueError: If ``dl_test`` yields nothing to evaluate.
    """
    abs_err_sum = 0.0
    n_values = 0
    preds = []
    trues = []
    model.eval()

    # No gradients are needed anywhere in evaluation.
    with torch.no_grad():
        for atoms, coords, energy, batch in dl_test:
            atoms = atoms.to(device)
            coords = coords.to(device)
            energy = energy.to(device)
            batch = batch.to(device)

            # Flatten with reshape(-1) instead of squeeze(): when a batch
            # holds a single value, squeeze() produces a 0-d tensor, on which
            # len() raises TypeError and which np.concatenate refuses.
            en = energy.reshape(-1)
            pred = model(atoms, coords, batch).reshape(-1)

            preds.append(pred.cpu().numpy())
            trues.append(en.cpu().numpy())

            # Accumulate the summed absolute error so the final division
            # yields a per-value mean regardless of batch sizes.
            abs_err_sum += F.l1_loss(pred, en, reduction='sum').item()
            n_values += pred.numel()

    if n_values == 0:
        raise ValueError('dl_test produced no samples to evaluate')

    all_trues = np.concatenate(trues)
    all_preds = np.concatenate(preds)

    return {
        'r2_score': r2_score(all_trues, all_preds),
        'mae': abs_err_sum / n_values,
    }
|
|
|
|
|
def refresh_lr(optimizer, i, n, e_start, downscale=2.0):
    """Set the learning rate of every parameter group for step ``i`` of ``n``.

    The rate is ``10 ** -(e_start + i / n * downscale)``: it starts at
    ``10 ** -e_start`` for ``i == 0`` and shrinks by ``downscale`` orders of
    magnitude as ``i`` approaches ``n``.

    Args:
        optimizer: Object exposing ``param_groups``, a list of dicts with an
            ``'lr'`` entry.
        i: Current step index.
        n: Total number of steps.
        e_start: Negated base-10 exponent of the initial rate.
        downscale: Number of decades the rate drops over the schedule.

    Returns:
        The learning rate that was written to the parameter groups.
    """
    lr = 10 ** -(e_start + i / n * downscale)

    for group in optimizer.param_groups:
        group['lr'] = lr

    return lr
|
|
|
|
|
|
|
|
def train(n_epoch, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix):
    """Train for ``n_epoch`` epochs, evaluate after each one, then checkpoint.

    Args:
        n_epoch: Number of epochs to run.
        model: Module being trained.
        optimizer: Optimizer whose learning rate is rescheduled every epoch.
        loss_fn: Training loss passed through to ``train_epoch``.
        e_start: Negated base-10 exponent of the schedule's initial rate.
        dl_train: Loader of training batches.
        dl_test: Loader of evaluation batches.
        device: Device used for training and evaluation.
        checkpoint_prefix: Path prefix; the final weights are written to
            ``checkpoint_prefix + '.ckpt'``.

    Returns:
        List with one ``(r2_score, mae, lr)`` tuple per epoch, where ``lr`` is
        the learning rate the optimizer held while that epoch was trained.
    """
    all_metrics = []

    for i in tqdm(range(n_epoch)):
        # Record the rate actually in effect for this epoch. The value used
        # to be seeded with ``e_start`` itself (the exponent), so the first
        # record stored e.g. 4 instead of 1e-4.
        cur_lr = optimizer.param_groups[0]['lr']

        train_epoch(model, optimizer, dl_train, loss_fn, device)

        metrics = test_epoch(model, optimizer, dl_test, device)

        # The rate computed from index ``i`` takes effect from epoch i + 1.
        refresh_lr(optimizer, i, n_epoch, e_start)

        all_metrics.append((
            metrics['r2_score'],
            metrics['mae'],
            cur_lr,
        ))

    # Only the weights after the last epoch are persisted.
    torch.save(model.state_dict(), checkpoint_prefix + '.ckpt')

    return all_metrics
|
|
|
|
|
def main(loss_mode, normalize, pretrain, checkpoint_prefix, data_filename):
    """Load the dataset and run the optional pretrain stage plus finetuning.

    Args:
        loss_mode: Loss specification tuple forwarded to ``get_loss``.
        normalize: Forwarded to ``MolDataset`` as its ``normalize`` argument.
        pretrain: When true, a pretraining stage runs before finetuning.
        checkpoint_prefix: Path prefix for the checkpoint and metrics files
            written by each stage.
        data_filename: Path handed to ``read_data``.

    Returns:
        None. Each stage writes a model checkpoint and a pickled metrics list.
    """
    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=normalize)

    loss_fn = get_loss(loss_mode)

    model = get_model()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Each stage: (split name, learning-rate exponent, checkpoint suffix,
    # metrics-file suffix). The same model instance is carried across stages.
    stages = [('finetune', 5, '_finetune_model', '_finetune_metrics.pkl')]
    if pretrain:
        stages.insert(0, ('pretrain', 4, '_pretrain_model', '_pretrain_metrics.pkl'))

    for split, e_start, model_suffix, metrics_suffix in stages:
        ds_train, ds_test = get_train_test_data(ds_all, groups, split)
        dl_train = torch.utils.data.DataLoader(
            ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol
        )
        dl_test = torch.utils.data.DataLoader(
            ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol
        )

        # A fresh optimizer per stage, starting at that stage's rate.
        optimizer = get_optimizer(model, e_start=e_start)

        all_metrics = train(
            100, model, optimizer, loss_fn, e_start,
            dl_train, dl_test, device, checkpoint_prefix + model_suffix,
        )
        with open(checkpoint_prefix + metrics_suffix, 'wb') as f:
            pickle.dump(all_metrics, f)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: parse the arguments, check that the loss
    # options are consistent, then start training.
    parser = argparse.ArgumentParser(description='Параметры для обучения модели')

    # Required positional arguments.
    parser.add_argument('loss_mode',
                        choices=['mae', 'adaptive'],
                        help="Режим потерь: 'mae' или 'adaptive'")
    parser.add_argument('checkpoint_prefix',
                        type=str,
                        help="Префикс для чекпоинтов")
    parser.add_argument('data_filename',
                        type=str,
                        help="Путь к файлу с датасетом")

    # Optional flags.
    parser.add_argument('--normalize',
                        action='store_true',
                        help="Применить нормализацию (только для loss_mode='mae')")
    parser.add_argument('--pretrain',
                        action='store_true',
                        help="Использовать предобучение")

    # Exponent of the adaptive loss; only meaningful with loss_mode='adaptive'.
    parser.add_argument('--loss_k',
                        type=float,
                        default=None,
                        help="Коэффициент k для adaptive loss (требуется при loss_mode='adaptive')")

    args = parser.parse_args()

    # Report inconsistent option combinations through parser.error so the
    # user gets the usage text and exit status 2 instead of a traceback.
    if args.loss_mode == 'adaptive':
        if args.normalize:
            parser.error("Параметр --normalize несовместим с loss_mode='adaptive'")
        if args.loss_k is None:
            parser.error("Для adaptive loss требуется параметр --loss_k")

        loss_mode_arg = ('adaptive', args.loss_k)
    else:
        if args.loss_k is not None:
            parser.error("Параметр --loss_k можно использовать только с loss_mode='adaptive'")
        loss_mode_arg = ('mae', )

    main(
        loss_mode=loss_mode_arg,
        normalize=args.normalize,
        pretrain=args.pretrain,
        checkpoint_prefix=args.checkpoint_prefix,
        data_filename=args.data_filename,
    )