| import torch | |
| from torch import nn | |
| import torch.nn.functional as F | |
| from torch_geometric.nn.models import SchNet, DimeNetPlusPlus | |
| import ase | |
| import ase.io | |
| import re | |
| from sklearn.metrics import r2_score | |
| import numpy as np | |
| import sys | |
| from tqdm import tqdm | |
| import argparse | |
| from utils_model import ModellDimeNet | |
def get_model_and_optimizer(model_type):
    """Build the requested network together with its optimizer.

    Args:
        model_type: Name of the architecture, ``'SchNet'`` or ``'DimeNet'``.

    Returns:
        A ``(model, optimizer)`` tuple. The optimizer is AdamW over all model
        parameters with a learning rate of 3e-4.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    if model_type == 'SchNet':
        model = SchNet()
    elif model_type == 'DimeNet':
        # NOTE: the wrapper class is imported from ``utils_model`` under the
        # name ``ModellDimeNet`` (double "l"); that is the only spelling bound
        # in this module, so it is the one that must be called here.
        model = ModellDimeNet()
    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown model type {model_type!r}; expected 'SchNet' or 'DimeNet'"
        )
    optimizer = torch.optim.AdamW(model.parameters(), lr=3e-4)
    return model, optimizer
def train_epoch(model, optimizer, geoms, energies_n, mean_grad=32):
    """Run one training pass over the data with gradient accumulation.

    Structures are fed to the model one at a time. The Huber loss of each
    prediction is divided by ``mean_grad`` and back-propagated; the optimizer
    steps once every ``mean_grad`` structures, so each update uses the mean
    gradient of that group.

    Args:
        model: Module called as ``model(atomic_numbers, positions, batch)``.
        optimizer: Optimizer exposing ``zero_grad()`` and ``step()``.
        geoms: Iterable of structures providing ``get_positions()`` and
            ``get_atomic_numbers()``.
        energies_n: 1-D tensor of target energies, paired with ``geoms`` by
            position.
        mean_grad: Number of structures accumulated per optimizer step.

    Raises:
        ValueError: If ``mean_grad`` is smaller than 1.
    """
    if mean_grad < 1:
        raise ValueError(f"mean_grad must be >= 1, got {mean_grad}")

    model.train()
    optimizer.zero_grad()
    accumulated = 0
    # Iterate over a detached copy so the caller's tensor is never touched by
    # autograd bookkeeping.
    targets = energies_n.clone().detach()
    for geom, energy in zip(geoms, targets):
        coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
        atoms = torch.tensor(geom.get_atomic_numbers())
        # All atoms belong to a single graph, hence an all-zero batch index.
        batch = torch.zeros_like(atoms)
        pred = model(atoms, coords, batch)
        loss = F.huber_loss(pred.squeeze(), energy)
        (loss / mean_grad).backward()
        accumulated += 1
        # Step after exactly ``mean_grad`` samples so the 1/mean_grad scaling
        # above really yields the mean gradient of the group.
        if accumulated == mean_grad:
            optimizer.step()
            optimizer.zero_grad()
            accumulated = 0

    # Apply the gradients of a trailing, incomplete group instead of
    # silently discarding those samples.
    if accumulated:
        optimizer.step()
        optimizer.zero_grad()
def test_epoch(model, optimizer, geoms, energies_n):
    """Evaluate the model on every structure and summarise the errors.

    Args:
        model: Module called as ``model(atomic_numbers, positions, batch)``;
            it is switched to eval mode and run without gradient tracking.
        optimizer: Accepted for signature symmetry with ``train_epoch``; it
            is not used here.
        geoms: Iterable of structures providing ``get_positions()`` and
            ``get_atomic_numbers()``.
        energies_n: 1-D tensor of reference energies, paired with ``geoms``
            by position.

    Returns:
        Dict with ``'r2_score'`` (coefficient of determination between the
        reference and predicted values) and ``'mae'`` (mean of the per-sample
        absolute errors).
    """
    model.eval()
    references = energies_n.clone().detach()

    predicted_values = []
    reference_values = []
    abs_error_sum = 0
    n_structures = 0

    for structure, target in zip(geoms, references):
        positions = torch.tensor(structure.get_positions(), dtype=torch.float32)
        numbers = torch.tensor(structure.get_atomic_numbers())
        # Single structure per call: every atom maps to graph index 0.
        graph_index = torch.zeros_like(numbers)
        reference = target.clone().detach()

        with torch.no_grad():
            output = model(numbers, positions, graph_index)

        predicted_values.append(output.item())
        reference_values.append(reference.item())
        abs_error_sum += F.l1_loss(output.squeeze(), reference).item()
        n_structures += 1

    r2 = r2_score(np.array(reference_values), np.array(predicted_values))
    return {
        'r2_score': r2,
        'mae': abs_error_sum / n_structures,
    }
def train(model, optimizer, geoms, energies_n, n_epochs=100):
    """Train for several epochs and report the best metrics observed.

    After every training epoch the model is evaluated on the same
    ``geoms`` / ``energies_n`` it was trained on, and the best value of each
    metric is tracked independently (they may come from different epochs).

    Args:
        model: Network to optimise.
        optimizer: Optimizer bound to ``model``'s parameters.
        geoms: Structures forwarded to ``train_epoch`` and ``test_epoch``.
        energies_n: Target energies forwarded alongside ``geoms``.
        n_epochs: Number of train/evaluate rounds.

    Returns:
        ``(best_r2score, best_mae)``: the highest R2 score and the lowest
        MAE seen across all epochs.
    """
    top_r2 = -1e100
    lowest_mae = 1e100
    for _ in tqdm(range(n_epochs)):
        train_epoch(model, optimizer, geoms, energies_n)
        scores = test_epoch(model, optimizer, geoms, energies_n)
        # max/min keep the current best on ties, like a strict comparison.
        top_r2 = max(top_r2, scores['r2_score'])
        lowest_mae = min(lowest_mae, scores['mae'])
    return top_r2, lowest_mae
def _parse_xyz_energies(text):
    """Extract one energy per frame from the comment lines of XYZ text.

    Each frame is expected to be an atom-count line, a comment line
    containing ``energy: <float>``, and then that many atom lines. Parsing
    stops at the first line that is not an integer atom count (for example a
    trailing blank line).

    Args:
        text: Full contents of a multi-frame XYZ file.

    Returns:
        List of energies as floats, in frame order.

    Raises:
        ValueError: If a frame has a negative atom count, lacks its comment
            line, or its comment line carries no parsable energy.
    """
    # Raw string avoids the invalid "\." escape; the optional exponent keeps
    # values such as "-1.5e-3" from being truncated to "-1.5".
    energy_re = re.compile(r'energy:\s+(-?\d*\.\d*(?:[eE][-+]?\d+)?)')
    lines = text.split('\n')
    energies = []
    pos = 0
    while pos < len(lines):
        try:
            n_atoms = int(lines[pos].strip())
        except ValueError:
            break
        if n_atoms < 0:
            # A negative count would stall or rewind the cursor.
            raise ValueError(f"Negative atom count on line {pos + 1}")
        if pos + 1 >= len(lines):
            raise ValueError(f"Frame starting on line {pos + 1} has no comment line")
        found = energy_re.search(lines[pos + 1])
        if found is None:
            raise ValueError(f"No 'energy: <value>' tag on line {pos + 2}")
        energies.append(float(found.group(1)))
        # Skip the count line, the comment line and the atom lines.
        pos += n_atoms + 2
    return energies


def main(trajectory_file, model_type):
    """Train a model on an XYZ trajectory and print its best metrics.

    Structures are read with ASE; energies are taken from the per-frame
    comment lines of the same file. Energies are shifted so that the minimum
    is zero and multiplied by 627.5 (presumably a Hartree -> kcal/mol
    conversion, given the unit printed with the MAE; confirm against the
    data source).

    Args:
        trajectory_file: Path to a multi-frame XYZ file.
        model_type: Model name passed to ``get_model_and_optimizer``.

    Raises:
        ValueError: If no energies can be parsed from the file.
    """
    geoms = ase.io.read(trajectory_file, format='xyz', index=':')
    with open(trajectory_file) as f:
        energy_values = _parse_xyz_energies(f.read())
    if not energy_values:
        raise ValueError(f"No frames with energies found in {trajectory_file!r}")

    energies = torch.tensor(energy_values)
    energies_n = (energies - energies.min()) * 627.5

    model, optimizer = get_model_and_optimizer(model_type)
    best_r2score, best_mae = train(model, optimizer, geoms, energies_n)
    print(f'R2_score: {best_r2score:.4f}')
    print(f'MAE: {best_mae:.3f} kcal/mol')
# Model names accepted by the command-line interface.
avaliable_models = ['SchNet', 'DimeNet']

if __name__ == "__main__":
    # Command line: <trajectory file> <model name>
    cli = argparse.ArgumentParser(
        description="Обработчик файлов с различными моделями",
    )
    cli.add_argument("filename", help="Путь к обрабатываемому файлу")
    model_help = f"Выбор модели из доступных: {', '.join(avaliable_models)}"
    cli.add_argument("model", choices=avaliable_models, help=model_help)
    parsed = cli.parse_args()
    main(parsed.filename, parsed.model)