| import torch |
| from torch import nn |
| import torch.nn.functional as F |
| from torch_geometric.nn.models import SchNet, DimeNetPlusPlus |
|
|
| import ase |
| import ase.io |
| import re |
| from sklearn.metrics import r2_score |
| import numpy as np |
|
|
| import sys |
|
|
| from tqdm import tqdm |
|
|
| import argparse |
|
|
| from utils_model import ModellDimeNet |
|
|
def get_model_and_optimizer(model_type, lr=3e-4):
    """Build the requested model and an AdamW optimizer over its parameters.

    Args:
        model_type: ``'SchNet'`` (torch_geometric ``SchNet`` with default
            hyper-parameters) or ``'DimeNet'`` (``ModellDimeNet`` from
            ``utils_model``).
        lr: Learning rate passed to ``torch.optim.AdamW``.

    Returns:
        A ``(model, optimizer)`` tuple.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == 'SchNet':
        model = SchNet()
    elif model_type == 'DimeNet':
        # Must match the name imported from utils_model at the top of the file.
        model = ModellDimeNet()
    else:
        # Without this branch `model` would be unbound below (UnboundLocalError).
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected 'SchNet' or 'DimeNet'"
        )

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer
|
|
|
|
def train_epoch(model, optimizer, geoms, energies_n, mean_grad=32):
    """Run one training pass over all molecules with gradient accumulation.

    Each molecule is fed to the model on its own (every atom gets batch
    index 0). Gradients from ``mean_grad`` consecutive molecules are
    accumulated before a single ``optimizer.step()``. A trailing group
    smaller than ``mean_grad`` still gets its own step. Its loss is also
    divided by ``mean_grad``, so that last step is proportionally smaller.

    Args:
        model: Module called as ``model(atomic_numbers, positions, batch)``.
        optimizer: Optimizer over ``model``'s parameters.
        geoms: Iterable of ASE ``Atoms``-like objects providing
            ``get_positions()`` and ``get_atomic_numbers()``.
        energies_n: 1-D tensor of target energies aligned with ``geoms``.
        mean_grad: Number of molecules per optimizer step; must be >= 1.

    Raises:
        ValueError: If ``mean_grad`` is smaller than 1.
    """
    if mean_grad < 1:
        raise ValueError(f'mean_grad must be >= 1, got {mean_grad}')

    model.train()
    optimizer.zero_grad()
    pending = 0  # molecules whose gradients are accumulated but not yet applied

    for geom, energy in zip(geoms, energies_n.detach()):
        coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
        atoms = torch.tensor(geom.get_atomic_numbers())
        # Single-graph batch: every atom belongs to graph 0.
        batch = torch.zeros_like(atoms)

        pred = model(atoms, coords, batch)
        loss = F.huber_loss(pred.squeeze(), energy)
        # Scale so the accumulated gradient is the mean over a full group.
        (loss / mean_grad).backward()
        pending += 1

        if pending == mean_grad:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0

    # Apply the last, incomplete group instead of silently discarding it.
    if pending:
        optimizer.step()
        optimizer.zero_grad()
|
|
def test_epoch(model, optimizer, geoms, energies_n):
    """Evaluate ``model`` on every molecule and report R² and MAE.

    ``optimizer`` is accepted for signature symmetry with ``train_epoch``
    but is not used.

    Args:
        model: Module called as ``model(atomic_numbers, positions, batch)``.
        optimizer: Unused.
        geoms: Iterable of ASE ``Atoms``-like objects providing
            ``get_positions()`` and ``get_atomic_numbers()``.
        energies_n: 1-D tensor of target energies aligned with ``geoms``.

    Returns:
        dict with ``'r2_score'`` (sklearn ``r2_score`` of targets vs.
        predictions) and ``'mae'`` (per-molecule L1 error averaged over
        all molecules, in the units of ``energies_n``).
    """
    model.eval()
    predictions, targets, abs_errors = [], [], []

    for geom, target in zip(geoms, energies_n.clone().detach()):
        positions = torch.tensor(geom.get_positions(), dtype=torch.float32)
        numbers = torch.tensor(geom.get_atomic_numbers())
        graph_index = torch.zeros_like(numbers)  # single-graph batch
        target = target.clone().detach()

        with torch.no_grad():
            output = model(numbers, positions, graph_index)

        predictions.append(output.item())
        targets.append(target.item())
        abs_errors.append(F.l1_loss(output.squeeze(), target).item())

    # R² first, then MAE, mirroring the evaluation order of the result dict.
    r2 = r2_score(np.array(targets), np.array(predictions))
    mae = sum(abs_errors) / len(abs_errors)
    return {'r2_score': r2, 'mae': mae}
|
|
def train(model, optimizer, geoms, energies_n, n_epochs=100, mean_grad=32):
    """Train for ``n_epochs`` epochs and return the best metrics observed.

    After every epoch the model is scored with ``test_epoch`` on the same
    ``geoms``/``energies_n`` it was trained on. The reported numbers
    therefore measure fit to the training data, not generalization.

    Args:
        model: Model to train in place.
        optimizer: Optimizer over ``model``'s parameters.
        geoms: Sequence of ASE ``Atoms``-like objects; iterated once per
            training epoch and once per evaluation.
        energies_n: 1-D tensor of target energies aligned with ``geoms``.
        n_epochs: Number of training epochs.
        mean_grad: Gradient-accumulation group size forwarded to
            ``train_epoch``.

    Returns:
        ``(best_r2score, best_mae)``: the highest R² and the lowest MAE seen
        across epochs. They are tracked independently and may come from
        different epochs. With ``n_epochs == 0`` they are ``-inf`` and
        ``inf``.
    """
    best_r2score = float('-inf')
    best_mae = float('inf')

    for _ in tqdm(range(n_epochs)):
        train_epoch(model, optimizer, geoms, energies_n, mean_grad=mean_grad)
        metrics = test_epoch(model, optimizer, geoms, energies_n)

        # A NaN metric never replaces the current best (comparisons are False).
        best_r2score = max(best_r2score, metrics['r2_score'])
        best_mae = min(best_mae, metrics['mae'])

    return best_r2score, best_mae
|
|
def _read_xyz_energies(trajectory_file):
    """Return the ``energy:`` value from each frame's comment line of an XYZ file.

    Each frame is an atom-count line, a comment line containing
    ``energy: <number>``, then that many atom lines. Parsing stops at the
    first line that is not an integer atom count (e.g. a trailing blank
    line).

    Raises:
        ValueError: If a frame is truncated, has a negative atom count, or
            its comment line has no ``energy:`` value.
    """
    # Accepts plain decimals as before, plus integers and exponent notation.
    energy_re = re.compile(r'energy:\s*([-+]?(?:\d+\.?\d*|\.\d+)(?:[eE][-+]?\d+)?)')

    with open(trajectory_file, encoding='utf-8') as f:
        lines = f.read().split('\n')

    energies = []
    i = 0
    while i < len(lines):
        try:
            n_atoms = int(lines[i].strip())
        except ValueError:
            break  # no further frames
        if n_atoms < 0:
            # A negative count would move `i` backwards and never terminate.
            raise ValueError(f'{trajectory_file}: negative atom count on line {i + 1}')
        if i + 1 >= len(lines):
            raise ValueError(f'{trajectory_file}: frame on line {i + 1} has no comment line')
        match = energy_re.search(lines[i + 1])
        if match is None:
            raise ValueError(f'{trajectory_file}: no "energy:" value on line {i + 2}')
        energies.append(float(match.group(1)))
        i += n_atoms + 2
    return energies


def main(trajectory_file, model_type):
    """Train ``model_type`` on an XYZ trajectory and print the best metrics.

    Targets are the per-frame energies read from the comment lines, shifted
    so the lowest is zero and multiplied by 627.5. That factor converts
    Hartree to kcal/mol, matching the kcal/mol label printed below. This
    assumes the file stores energies in Hartree; confirm for new data.

    Raises:
        ValueError: If the file has no frames with energies, or if the
            number of parsed energies differs from the number of geometries
            ASE read.
    """
    geoms = ase.io.read(trajectory_file, format='xyz', index=':')
    energies = _read_xyz_energies(trajectory_file)

    if not energies:
        raise ValueError(f'{trajectory_file}: no frames with energies found')
    if len(energies) != len(geoms):
        # zip() in training would otherwise silently misalign targets.
        raise ValueError(
            f'{trajectory_file}: parsed {len(energies)} energies '
            f'but {len(geoms)} geometries'
        )

    # Shift in float64: absolute energies are large, and subtracting in float32
    # would discard most of the small relative differences being learned.
    # The result is cast to float32, the dtype the targets had before.
    energies = torch.tensor(energies, dtype=torch.float64)
    energies_n = ((energies - energies.min()) * 627.5).to(torch.float32)

    model, optimizer = get_model_and_optimizer(model_type)
    best_r2score, best_mae = train(model, optimizer, geoms, energies_n)

    print(f'R2_score: {best_r2score:.4f}')
    print(f'MAE: {best_mae:.3f} kcal/mol')
|
|
# Model names accepted on the command line; these are the names that
# get_model_and_optimizer knows how to build.
avaliable_models = ['SchNet', 'DimeNet']


def _parse_cli_args(argv=None):
    """Parse the command line into ``filename`` and ``model`` attributes.

    The user-facing help texts are in Russian. The description means
    "File processor with various models". The ``filename`` help means
    "Path to the file being processed". The ``model`` help means "Choose a
    model from the available ones: ...".
    """
    parser = argparse.ArgumentParser(description="Обработчик файлов с различными моделями")
    parser.add_argument("filename", help="Путь к обрабатываемому файлу")
    parser.add_argument(
        "model",
        choices=avaliable_models,
        help=f"Выбор модели из доступных: {', '.join(avaliable_models)}",
    )
    return parser.parse_args(argv)


if __name__ == "__main__":
    cli_args = _parse_cli_args()
    main(cli_args.filename, cli_args.model)
|
|
|
|