import sys

import numpy as np
import torch
from torch.distributions import Normal
from tqdm import tqdm

from .dynamics import dynamics
from .utils.utils import kabsch


class MDs:
    """A batch of ``num_samples`` molecular-dynamics simulation objects.

    Each simulation is built by calling the ``dynamics`` attribute named
    ``args.molecule.title()`` with ``(args, args.start_state)``. One extra
    instance, built with ``args.end_state``, supplies the target position and
    per-system constants (``num_particles``, ``heavy_atoms``,
    ``energy_function``, ``std``, ``m``). On construction the target position
    is translated so its heavy-atom mean is at the origin, and the start
    position is transformed with the result of ``kabsch`` fitted on heavy atoms.
    """

    def __init__(self, args):
        """Build the batch from a configuration object.

        Args:
            args: Object exposing ``device``, ``molecule``, ``end_state``,
                ``num_samples`` and ``start_state``. It is also forwarded
                unchanged to every ``dynamics`` constructor.

        Raises:
            ValueError: If ``args.num_samples`` is less than 1.
        """
        self.device = args.device
        self.molecule = args.molecule
        self.end_state = args.end_state
        self.num_samples = args.num_samples
        self.start_state = args.start_state

        self.get_md_info(args)
        self.mds = self._init_mds(args)
        self.log_prob = Normal(0, self.std).log_prob

        # Shift the target so the mean over its heavy atoms (axis -2 after
        # indexing axis 1 with heavy_atoms) sits at the origin.
        heavy_mean = self.target_position[:, self.heavy_atoms].mean(-2, keepdim=True)
        self.target_position = self.target_position - heavy_mean

        # Fit on heavy atoms only, then apply the transform to all atoms of the
        # start position. NOTE(review): assumes kabsch returns a rotation R and
        # translation t mapping its first argument onto its second — confirm.
        R, t = kabsch(
            self.start_position[:, self.heavy_atoms],
            self.target_position[:, self.heavy_atoms],
        )
        self.start_position = torch.matmul(self.start_position, R.transpose(-2, -1)) + t

    def get_md_info(self, args):
        """Read per-system constants from a simulation built in ``end_state``.

        Sets ``num_particles``, ``heavy_atoms`` (tensor on ``self.device`` used
        to index the atom axis), ``energy_function``, ``target_position``
        (float tensor with a leading axis of size 1), ``std`` (float tensor)
        and ``m`` (float tensor with a trailing singleton axis).

        Args:
            args: Configuration object forwarded to the ``dynamics`` constructor.
        """
        md = getattr(dynamics, self.molecule.title())(args, self.end_state)
        self.num_particles = md.num_particles
        self.heavy_atoms = torch.from_numpy(md.heavy_atoms).to(self.device)
        self.energy_function = md.energy_function
        self.target_position = torch.tensor(
            md.position, dtype=torch.float, device=self.device
        ).unsqueeze(0)
        # Use self.device everywhere so all per-system tensors share a device.
        self.std = torch.tensor(md.std, dtype=torch.float, device=self.device)
        # Presumably per-particle masses; the trailing axis is added so they
        # broadcast against per-coordinate tensors — TODO confirm md.m's shape.
        self.m = torch.tensor(md.m, dtype=torch.float, device=self.device).unsqueeze(-1)

    def _init_mds(self, args):
        """Create ``self.num_samples`` simulations in ``self.start_state``.

        Also stores the last simulation's ``position`` as
        ``self.start_position`` (float tensor on ``self.device`` with a
        leading axis of size 1).

        Args:
            args: Configuration object forwarded to each ``dynamics`` constructor.

        Returns:
            list: The newly created simulation objects, in creation order.

        Raises:
            ValueError: If ``self.num_samples`` is less than 1. Without at least
                one simulation there is no start position to read.
        """
        if self.num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {self.num_samples}")

        # The class lookup is loop-invariant; resolve it once.
        md_cls = getattr(dynamics, self.molecule.title())
        mds = [md_cls(args, self.start_state) for _ in tqdm(range(self.num_samples))]

        self.start_position = torch.tensor(
            mds[-1].position, dtype=torch.float, device=self.device
        ).unsqueeze(0)
        return mds

    def step(self, force):
        """Advance every simulation by one step under the given force.

        Args:
            force: Tensor whose i-th entry along the first axis is passed to the
                i-th simulation's ``step``. It is detached and copied to a CPU
                NumPy array first.
        """
        force = force.detach().cpu().numpy()
        # Index force by position (not zip) so too few rows still raise IndexError.
        for i, md in enumerate(self.mds):
            md.step(force[i])

    def report(self):
        """Collect the current ``(position, force)`` from every simulation.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: ``(positions, forces)`` as float
            tensors on ``self.device``, stacked along a new leading axis with
            one entry per simulation.
        """
        positions, forces = [], []
        for md in self.mds:
            position, force = md.report()
            positions.append(position)
            forces.append(force)
        # Stack into a single ndarray first. torch.tensor() on a list of
        # ndarrays takes a slow element-wise path and emits a UserWarning.
        positions = torch.as_tensor(np.asarray(positions), dtype=torch.float, device=self.device)
        forces = torch.as_tensor(np.asarray(forces), dtype=torch.float, device=self.device)
        return positions, forces

    def reset(self):
        """Call ``reset()`` on every simulation."""
        for md in self.mds:
            md.reset()

    def set_temperature(self, temperature):
        """Call ``set_temperature(temperature)`` on every simulation.

        Args:
            temperature: Value forwarded unchanged to each simulation.
        """
        for md in self.mds:
            md.set_temperature(temperature)