from sklearn.metrics import r2_score
import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm
import argparse
import pickle

from utils_data import read_data, MolDataset, collate_mol, get_train_test_data
from utils_model import ModelDimeNet


def get_model():
    model = ModelDimeNet()
    return model


def get_optimizer(model, e_start):
    # e_start is the negative base-10 exponent of the initial learning rate.
    optimizer = torch.optim.RMSprop(model.parameters(), lr=10 ** -e_start)
    return optimizer


def get_loss(mode):
    # mode is a tuple: ('mae',) or ('adaptive', k).
    if mode[0] == 'mae':
        return lambda pred, en: (pred - en).abs().mean()
    if mode[0] == 'adaptive':
        # Absolute error scaled by |target|**k, i.e. a relative-error style loss.
        return lambda pred, en: ((pred - en).abs() / (en.abs() + 1e-5) ** mode[1]).mean()


def train_epoch(model, optimizer, dl_train, loss_fn, device):
    model.train()
    for atoms, coords, energy, batch in dl_train:
        optimizer.zero_grad()
        atoms = atoms.to(device)
        coords = coords.to(device)
        energy = energy.to(device)
        batch = batch.to(device)
        en = energy.squeeze()
        pred = model(atoms, coords, batch).squeeze()
        loss = loss_fn(pred, en)
        loss.backward()
        optimizer.step()


def test_epoch(model, optimizer, dl_test, device):
    all_loss = 0
    all_mols = 0
    all_preds = []
    all_trues = []
    model.eval()
    for atoms, coords, energy, batch in dl_test:
        atoms = atoms.to(device)
        coords = coords.to(device)
        energy = energy.to(device)
        batch = batch.to(device)
        en = energy.squeeze()
        with torch.no_grad():
            pred = model(atoms, coords, batch).squeeze()
        all_preds.append(pred.cpu().numpy())
        all_trues.append(en.cpu().numpy())
        # Accumulate the MAE weighted by the number of molecules in the batch.
        all_loss += F.l1_loss(pred, en).item() * len(pred)
        all_mols += len(pred)
    all_trues = np.concatenate(all_trues)
    all_preds = np.concatenate(all_preds)
    return {
        'r2_score': r2_score(all_trues, all_preds),
        'mae': all_loss / all_mols,
    }


def refresh_lr(optimizer, i, n, e_start, downscale=2.0):
    # Exponential decay from 10**-e_start towards 10**-(e_start + downscale) over n epochs.
    new_lr = 10 ** -(e_start + i / n * downscale)
    for g in optimizer.param_groups:
        g['lr'] = new_lr
    return new_lr


def train(n_epoch, model, optimizer, loss_fn, e_start, dl_train, dl_test, device, checkpoint_prefix):
    all_metrics = []
    new_lr = 10 ** -e_start
    for i in tqdm(range(n_epoch)):
        train_epoch(model, optimizer, dl_train, loss_fn, device)
        metrics = test_epoch(model, optimizer, dl_test, device)
        cur_lr = new_lr
        new_lr = refresh_lr(optimizer, i, n_epoch, e_start)
        all_metrics.append((
            metrics['r2_score'],
            metrics['mae'],
            cur_lr,
        ))
        # Overwrite the checkpoint after every epoch so the latest state is kept.
        torch.save(model.state_dict(), checkpoint_prefix + '.ckpt')
    return all_metrics


def main(loss_mode, normalize, pretrain, checkpoint_prefix, data_filename):
    all_numbers, all_coords, energies, groups = read_data(data_filename)
    ds_all = MolDataset(all_numbers, all_coords, energies, normalize=normalize)
    loss_fn = get_loss(loss_mode)
    model = get_model()
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    model = model.to(device)

    # Pretraining
    if pretrain:
        e_start = 4
        ds_train, ds_test = get_train_test_data(ds_all, groups, 'pretrain')
        dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol)
        dl_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol)
        optimizer = get_optimizer(model, e_start=e_start)
        all_metrics = train(100, model, optimizer, loss_fn, e_start, dl_train, dl_test, device,
                            checkpoint_prefix + '_pretrain_model')
        with open(checkpoint_prefix + '_pretrain_metrics.pkl', 'wb') as f:
            pickle.dump(all_metrics, f)

    # Fine-tuning
    e_start = 5
    ds_train, ds_test = get_train_test_data(ds_all, groups, 'finetune')
    dl_train = torch.utils.data.DataLoader(ds_train, batch_size=32, shuffle=True, collate_fn=collate_mol)
    dl_test = torch.utils.data.DataLoader(ds_test, batch_size=32, shuffle=False, collate_fn=collate_mol)
    optimizer = get_optimizer(model, e_start=e_start)
    all_metrics = train(100, model, optimizer, loss_fn, e_start, dl_train, dl_test, device,
                        checkpoint_prefix + '_finetune_model')
    with open(checkpoint_prefix + '_finetune_metrics.pkl', 'wb') as f:
        pickle.dump(all_metrics, f)


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description='Model training parameters')

    # Required positional arguments
    parser.add_argument('loss_mode', choices=['mae', 'adaptive'],
                        help="Loss mode: 'mae' or 'adaptive'")
    parser.add_argument('checkpoint_prefix', type=str, help="Prefix for checkpoint files")
    parser.add_argument('data_filename', type=str, help="Path to the dataset file")

    # Boolean flags
    parser.add_argument('--normalize', action='store_true',
                        help="Apply normalization (only for loss_mode='mae')")
    parser.add_argument('--pretrain', action='store_true', help="Run the pretraining stage")

    # Parameter used only in adaptive mode
    parser.add_argument('--loss_k', type=float, default=None,
                        help="Exponent k for the adaptive loss (required when loss_mode='adaptive')")

    args = parser.parse_args()

    # Check parameter compatibility
    if args.loss_mode == 'adaptive':
        if args.normalize:
            raise ValueError("--normalize is incompatible with loss_mode='adaptive'")
        if args.loss_k is None:
            raise ValueError("--loss_k is required for the adaptive loss")
        # Build the mode tuple for adaptive loss
        loss_mode_arg = ('adaptive', args.loss_k)
    else:  # loss_mode == 'mae'
        if args.loss_k is not None:
            raise ValueError("--loss_k can only be used with loss_mode='adaptive'")
        loss_mode_arg = ('mae',)

    # Run the main entry point
    main(
        loss_mode=loss_mode_arg,
        normalize=args.normalize,
        pretrain=args.pretrain,
        checkpoint_prefix=args.checkpoint_prefix,
        data_filename=args.data_filename,
    )
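
# Example invocations (a sketch only: 'train.py', 'dataset.pkl' and the 'ckpt/...'
# prefixes are placeholder names, not files provided by this repository):
#
#   python train.py mae ckpt/run_mae dataset.pkl --normalize --pretrain
#   python train.py adaptive ckpt/run_adaptive dataset.pkl --loss_k 0.5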