"""Train a SchNet / DimeNet model on energies parsed from a multi-frame xyz file.

Usage: python <script> TRAJECTORY.xyz {SchNet,DimeNet}
"""
import torch
from torch import nn
import torch.nn.functional as F
from torch_geometric.nn.models import SchNet, DimeNetPlusPlus
import ase
import ase.io
import re
from sklearn.metrics import r2_score
import numpy as np
import sys
from tqdm import tqdm
import argparse
from utils_model import ModellDimeNet


# Conversion factor applied to the relative energies before training.
# NOTE(review): assumes the energies in the xyz comment lines are in Hartree
# (627.5 kcal/mol per Hartree) — confirm against the data source.
HARTREE_TO_KCAL_MOL = 627.5

# Extracts the number from e.g. "energy: -76.4123" in an xyz comment line.
_ENERGY_RE = re.compile(r'energy:\s+(-?\d*\.\d*)')

# Model names accepted on the command line / by get_model_and_optimizer.
avaliable_models = ['SchNet', 'DimeNet']


def get_model_and_optimizer(model_type, lr=3e-4):
    """Build a fresh model of the requested type and an AdamW optimizer for it.

    Args:
        model_type: ``'SchNet'`` or ``'DimeNet'``.
        lr: AdamW learning rate (default 3e-4).

    Returns:
        ``(model, optimizer)``.

    Raises:
        ValueError: if ``model_type`` is not a supported model name.
    """
    if model_type == 'SchNet':
        model = SchNet()
    elif model_type == 'DimeNet':
        # The class is imported from utils_model under the name ``ModellDimeNet``.
        model = ModellDimeNet()
    else:
        raise ValueError(
            f"Unknown model_type {model_type!r}; expected one of {avaliable_models}"
        )
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
    return model, optimizer


def _geom_to_inputs(geom):
    """Convert an ASE ``Atoms`` object into ``(atoms, coords, batch)`` model inputs.

    Every atom belongs to the single graph 0, so ``batch`` is all zeros.
    """
    coords = torch.tensor(geom.get_positions(), dtype=torch.float32)
    atoms = torch.tensor(geom.get_atomic_numbers())
    batch = torch.zeros_like(atoms)
    return atoms, coords, batch


def train_epoch(model, optimizer, geoms, energies_n, mean_grad=32):
    """Run one training pass (one molecule per forward) with gradient accumulation.

    Each per-molecule Huber loss is scaled by ``1 / mean_grad`` and
    back-propagated; the optimizer steps once every ``mean_grad`` molecules.
    Gradients left over from a trailing partial group are applied at the end
    of the epoch rather than discarded.

    Args:
        model: Callable as ``model(atoms, coords, batch)``.
        optimizer: Optimizer over ``model``'s parameters.
        geoms: Sequence of ASE ``Atoms`` objects.
        energies_n: 1-D tensor of target energies, aligned with ``geoms``.
        mean_grad: Number of molecules per optimizer step (>= 1).

    Raises:
        ValueError: if ``mean_grad < 1``.
    """
    if mean_grad < 1:
        raise ValueError(f"mean_grad must be >= 1, got {mean_grad}")
    model.train()
    optimizer.zero_grad()
    pending = 0  # samples whose gradients are accumulated but not yet applied
    for geom, energy in zip(geoms, energies_n.detach()):
        atoms, coords, batch = _geom_to_inputs(geom)
        pred = model(atoms, coords, batch)
        loss = F.huber_loss(pred.squeeze(), energy)
        (loss / mean_grad).backward()
        pending += 1
        # Step after exactly ``mean_grad`` samples, matching the 1/mean_grad scaling.
        if pending == mean_grad:
            optimizer.step()
            optimizer.zero_grad()
            pending = 0
    if pending:
        optimizer.step()
        optimizer.zero_grad()


def test_epoch(model, optimizer, geoms, energies_n):
    """Evaluate ``model`` on every geometry without tracking gradients.

    Args:
        model: Callable as ``model(atoms, coords, batch)``.
        optimizer: Unused; kept for call-site compatibility.
        geoms: Sequence of ASE ``Atoms`` objects.
        energies_n: 1-D tensor of target energies, aligned with ``geoms``.

    Returns:
        ``{'r2_score': float, 'mae': float}``; MAE is in the units of ``energies_n``.

    Raises:
        ValueError: if there is nothing to evaluate.
    """
    model.eval()
    preds = []
    trues = []
    with torch.no_grad():
        for geom, energy in zip(geoms, energies_n.detach()):
            atoms, coords, batch = _geom_to_inputs(geom)
            preds.append(model(atoms, coords, batch).item())
            trues.append(energy.item())
    if not preds:
        raise ValueError("test_epoch needs at least one geometry/energy pair")
    preds = np.array(preds)
    trues = np.array(trues)
    return {
        'r2_score': r2_score(trues, preds),
        'mae': float(np.mean(np.abs(preds - trues))),
    }


def train(model, optimizer, geoms, energies_n, n_epochs=100, mean_grad=32):
    """Train for ``n_epochs`` epochs and return the best R2 and best MAE observed.

    NOTE(review): evaluation uses the same geometries as training, so these are
    training-set metrics, not held-out performance. Best R2 and best MAE are
    tracked independently and may come from different epochs.
    """
    best_r2score = float('-inf')
    best_mae = float('inf')
    for _ in tqdm(range(n_epochs)):
        train_epoch(model, optimizer, geoms, energies_n, mean_grad=mean_grad)
        metrics = test_epoch(model, optimizer, geoms, energies_n)
        best_r2score = max(best_r2score, metrics['r2_score'])
        best_mae = min(best_mae, metrics['mae'])
    return best_r2score, best_mae


def _read_xyz_energies(trajectory_file):
    """Return the energy stored in each frame's comment line of an xyz file.

    Frame layout: an atom-count line, a comment line containing
    ``energy: <value>``, then one line per atom. Parsing stops at the first
    line that is not an integer (e.g. the trailing empty line).

    Raises:
        ValueError: on a malformed frame header or a comment without an energy.
    """
    with open(trajectory_file) as f:
        lines = f.read().split('\n')
    energies = []
    i = 0
    while i < len(lines):
        try:
            n_atoms = int(lines[i].strip())
        except ValueError:
            break
        if n_atoms < 0 or i + 1 >= len(lines):
            raise ValueError(f"{trajectory_file}: malformed frame header at line {i + 1}")
        match = _ENERGY_RE.search(lines[i + 1])
        if match is None:
            raise ValueError(
                f"{trajectory_file}: no 'energy: <value>' found on line {i + 2}"
            )
        energies.append(float(match.group(1)))
        i += n_atoms + 2
    return energies


def main(trajectory_file, model_type):
    """Load geometries and energies, train the chosen model and print its best metrics."""
    geoms = ase.io.read(trajectory_file, format='xyz', index=':')
    energies = _read_xyz_energies(trajectory_file)
    if not energies:
        raise ValueError(f"{trajectory_file}: no frames with energies found")
    if len(energies) != len(geoms):
        raise ValueError(
            f"{trajectory_file}: parsed {len(energies)} energies but ASE read {len(geoms)} frames"
        )
    # Subtract the minimum in float64: absolute energies are large, and float32
    # spacing near 1000 is ~6e-5, which becomes ~0.04 after scaling by 627.5.
    energies = torch.tensor(energies, dtype=torch.float64)
    energies_n = ((energies - energies.min()) * HARTREE_TO_KCAL_MOL).to(torch.float32)
    model, optimizer = get_model_and_optimizer(model_type)
    best_r2score, best_mae = train(model, optimizer, geoms, energies_n)
    print(f'R2_score: {best_r2score:.4f}')
    print(f'MAE: {best_mae:.3f} kcal/mol')


if __name__ == "__main__":
    # CLI: positional trajectory-file path, then a model name from avaliable_models.
    # (Help strings are in Russian: "file processor with various models",
    # "path to the file to process", "choose a model from the available ones".)
    parser = argparse.ArgumentParser(description="Обработчик файлов с различными моделями")
    parser.add_argument("filename", help="Путь к обрабатываемому файлу")
    parser.add_argument(
        "model",
        choices=avaliable_models,
        help=f"Выбор модели из доступных: {', '.join(avaliable_models)}",
    )
    args = parser.parse_args()
    main(args.filename, args.model)