# vscf_mlff / train.py
# Initial commit (5fae7ca, verified) by timcryt
from sklearn.metrics import r2_score
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import argparse
import pickle
from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
from utils_model import ModelDimeNet
def get_model():
    """Instantiate the DimeNet-based energy model."""
    return ModelDimeNet()
def get_optimizer(model, e_start):
    """Build an RMSprop optimizer with learning rate ``10 ** -e_start``.

    ``e_start`` is the (positive) exponent of the initial learning rate,
    e.g. ``e_start=4`` gives lr = 1e-4.
    """
    initial_lr = 10 ** -e_start
    return torch.optim.RMSprop(model.parameters(), lr=initial_lr)
def get_loss(mode):
    """Return the loss function selected by *mode*.

    Parameters
    ----------
    mode : tuple
        ``('mae',)`` for plain mean absolute error, or ``('adaptive', k)``
        where each sample's absolute error is divided by
        ``(|target| + 1e-5) ** k`` (relative-error-style weighting).

    Returns
    -------
    callable
        ``(pred, en) -> scalar tensor``.

    Raises
    ------
    ValueError
        If ``mode[0]`` is not a known loss name.  (Previously an unknown
        mode silently returned ``None``, which only failed later with a
        confusing ``TypeError`` at the first training step.)
    """
    name = mode[0]
    if name == 'mae':
        return lambda pred, en: (pred - en).abs().mean()
    if name == 'adaptive':
        k = mode[1]  # bind now so the closure does not depend on `mode`
        return lambda pred, en: ((pred - en).abs() / (en.abs() + 1e-5) ** k).mean()
    raise ValueError(f"Unknown loss mode: {name!r}")
def train_epoch(model, optimizer, dl_train, loss_fn, device):
    """Run one optimization pass over ``dl_train``.

    Each batch is moved to ``device``, the model's (squeezed) prediction is
    compared to the squeezed energy target with ``loss_fn``, and a single
    optimizer step is taken per batch.  Runs the model in train mode.
    """
    model.train()
    for atoms, coords, energy, batch in dl_train:
        atoms = atoms.to(device)
        coords = coords.to(device)
        energy = energy.to(device)
        batch = batch.to(device)
        optimizer.zero_grad()
        target = energy.squeeze()
        prediction = model(atoms, coords, batch).squeeze()
        batch_loss = loss_fn(prediction, target)
        batch_loss.backward()
        optimizer.step()
def test_epoch(model, optimizer, dl_test, device):
    """Evaluate ``model`` on ``dl_test`` without gradients.

    ``optimizer`` is unused; it is kept so the signature mirrors
    ``train_epoch``.  Returns ``{'r2_score': ..., 'mae': ...}`` computed
    over all molecules in the loader (MAE is length-weighted per batch).
    """
    model.eval()
    preds, trues = [], []
    abs_err_sum = 0.0
    n_mols = 0
    for atoms, coords, energy, batch in dl_test:
        atoms = atoms.to(device)
        coords = coords.to(device)
        energy = energy.to(device)
        batch = batch.to(device)
        target = energy.squeeze()
        with torch.no_grad():
            pred = model(atoms, coords, batch).squeeze()
        preds.append(pred.cpu().numpy())
        trues.append(target.cpu().numpy())
        abs_err_sum += F.l1_loss(pred, target).item() * len(pred)
        n_mols += len(pred)
    trues = np.concatenate(trues)
    preds = np.concatenate(preds)
    return {
        'r2_score': r2_score(trues, preds),
        'mae': abs_err_sum / n_mols,
    }
def refresh_lr(optimizer, i, n, e_start, downscale=2.0):
    """Anneal the learning rate after epoch ``i`` of ``n``.

    The lr exponent moves linearly from ``e_start`` toward
    ``e_start + downscale`` across the run, i.e. the lr decays
    geometrically by ``10 ** -downscale`` overall.  Updates every param
    group in place and returns the new learning rate.
    """
    new_lr = 10 ** -(e_start + i / n * downscale)
    for group in optimizer.param_groups:
        group['lr'] = new_lr
    return new_lr
def train(n_epoch, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix):
    """Train for ``n_epoch`` epochs with linear lr-exponent annealing.

    Each epoch: one training pass, one evaluation pass, an lr update via
    ``refresh_lr``, and a checkpoint saved to ``checkpoint_prefix + '.ckpt'``
    (overwritten every epoch, so the file always holds the latest weights).

    Returns a list of ``(r2_score, mae, lr)`` tuples, where ``lr`` is the
    learning rate that was in effect during that epoch.

    Fix: the first tuple previously stored the raw exponent ``e_start``
    (e.g. ``4``) in the lr slot instead of the actual learning rate
    (``1e-4``); the lr history now has consistent units throughout.
    """
    all_metrics = []
    new_lr = 10 ** -e_start  # lr actually used in epoch 0 (set by get_optimizer)
    for i in tqdm(range(n_epoch)):
        train_epoch(model, optimizer, dl_train, loss_fn, device)
        metrics = test_epoch(model, optimizer, dl_test, device)
        cur_lr = new_lr
        new_lr = refresh_lr(optimizer, i, n_epoch, e_start)
        all_metrics.append((
            metrics['r2_score'],
            metrics['mae'],
            cur_lr,
        ))
        torch.save(model.state_dict(), checkpoint_prefix + '.ckpt')
    return all_metrics
def main(loss_mode, normalize, pretrain, checkpoint_prefix, data_filename):
    """Entry point: load the dataset, build the model, then run the
    optional pretraining stage followed by fine-tuning.

    Parameters
    ----------
    loss_mode : tuple
        ``('mae',)`` or ``('adaptive', k)`` — forwarded to ``get_loss``.
    normalize : bool
        Passed to ``MolDataset`` (energy normalization).
    pretrain : bool
        If true, run a 'pretrain' stage (lr exponent 4) before the
        'finetune' stage (lr exponent 5).
    checkpoint_prefix : str
        Prefix for checkpoint and metrics files.
    data_filename : str
        Path to the dataset file read by ``read_data``.
    """
    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=normalize)
    loss_fn = get_loss(loss_mode)
    model = get_model()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)
    # Pretraining (optional)
    if pretrain:
        _run_stage(model, ds_all, groups, loss_fn, device,
                   stage='pretrain', e_start=4, checkpoint_prefix=checkpoint_prefix)
    # Fine-tuning (always; continues from pretrained weights if any)
    _run_stage(model, ds_all, groups, loss_fn, device,
               stage='finetune', e_start=5, checkpoint_prefix=checkpoint_prefix)


def _run_stage(model, ds_all, groups, loss_fn, device, *, stage, e_start, checkpoint_prefix):
    """Run one 100-epoch training stage ('pretrain' or 'finetune'):
    build the split/loaders, train, and pickle the metrics history."""
    ds_train, ds_test = get_train_test_data(ds_all, groups, stage)
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol)
    dl_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol)
    optimizer = get_optimizer(model, e_start=e_start)
    all_metrics = train(100, model, optimizer, loss_fn, e_start, dl_train, dl_test, device,
                        checkpoint_prefix + '_' + stage + '_model')
    with open(checkpoint_prefix + '_' + stage + '_metrics.pkl', 'wb') as f:
        pickle.dump(all_metrics, f)
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Параметры для обучения модели')
    # Required positional arguments
    parser.add_argument('loss_mode', choices=['mae', 'adaptive'],
                        help="Режим потерь: 'mae' или 'adaptive'")
    parser.add_argument('checkpoint_prefix', type=str,
                        help="Префикс для чекпоинтов")
    parser.add_argument('data_filename', type=str,
                        help="Путь к файлу с датасетом")
    # Boolean flags
    parser.add_argument('--normalize', action='store_true',
                        help="Применить нормализацию (только для loss_mode='mae')")
    parser.add_argument('--pretrain', action='store_true',
                        help="Использовать предобучение")
    # Only meaningful for the adaptive loss mode
    parser.add_argument('--loss_k', type=float, default=None,
                        help="Коэффициент k для adaptive loss (требуется при loss_mode='adaptive')")
    args = parser.parse_args()

    # Cross-parameter validation: --normalize and --loss_k are each tied
    # to a specific loss_mode.
    is_adaptive = args.loss_mode == 'adaptive'
    if is_adaptive and args.normalize:
        raise ValueError("Параметр --normalize несовместим с loss_mode='adaptive'")
    if is_adaptive and args.loss_k is None:
        raise ValueError("Для adaptive loss требуется параметр --loss_k")
    if not is_adaptive and args.loss_k is not None:
        raise ValueError("Параметр --loss_k можно использовать только с loss_mode='adaptive'")
    loss_mode_arg = ('adaptive', args.loss_k) if is_adaptive else ('mae', )

    main(
        loss_mode=loss_mode_arg,
        normalize=args.normalize,
        pretrain=args.pretrain,
        checkpoint_prefix=args.checkpoint_prefix,
        data_filename=args.data_filename,
    )