| import sys | |
| import torch | |
| from tqdm import tqdm | |
| from torch.distributions import Normal | |
| from .dynamics import dynamics | |
| from .utils.utils import kabsch | |
class MDs:
    """A batch of ``args.num_samples`` molecular-dynamics simulations driven together.

    Every simulation is an instance of the ``dynamics`` class named
    ``args.molecule.title()``, built from ``args.start_state``. One extra instance
    built from ``args.end_state`` supplies the target geometry, ``std``, ``m``,
    heavy-atom indices and the energy function (see ``get_md_info``).
    """

    def __init__(self, args):
        self.device = args.device
        self.molecule = args.molecule
        self.end_state = args.end_state
        self.num_samples = args.num_samples
        self.start_state = args.start_state
        self.get_md_info(args)
        self.mds = self._init_mds(args)
        # Log-density of a zero-mean Gaussian whose scale is the end-state ``std``.
        self.log_prob = Normal(0, self.std).log_prob
        # Centre the target on the mean of its heavy-atom coordinates (mean over dim -2;
        # assumed to be the particle axis if md.position is (num_particles, 3) — confirm).
        heavy_centroid = self.target_position[:, self.heavy_atoms].mean(-2, keepdim=True)
        self.target_position = self.target_position - heavy_centroid
        # Apply the rotation R and translation t returned by kabsch on the heavy atoms.
        # NOTE(review): presumably this superimposes the start geometry onto the centred
        # target — confirm the direction convention against utils.kabsch.
        R, t = kabsch(
            self.start_position[:, self.heavy_atoms],
            self.target_position[:, self.heavy_atoms],
        )
        self.start_position = torch.matmul(self.start_position, R.transpose(-2, -1)) + t

    def _make_md(self, args, state):
        """Instantiate the ``dynamics`` class named after ``self.molecule`` for ``state``."""
        return getattr(dynamics, self.molecule.title())(args, state)

    def get_md_info(self, args):
        """Read static system information from an ``end_state`` simulation.

        Sets ``num_particles``, ``heavy_atoms`` and ``energy_function`` on ``self``.
        Also sets ``target_position`` (``md.position`` with a leading axis of size 1),
        ``std``, and ``m`` (with a trailing singleton axis; presumably per-particle
        masses — confirm). All tensors are created on ``self.device``.
        """
        md = self._make_md(args, self.end_state)
        self.num_particles = md.num_particles
        self.heavy_atoms = torch.from_numpy(md.heavy_atoms).to(self.device)
        self.energy_function = md.energy_function
        self.target_position = torch.tensor(
            md.position, dtype=torch.float, device=self.device
        ).unsqueeze(0)
        self.std = torch.tensor(md.std, dtype=torch.float, device=self.device)
        self.m = torch.tensor(md.m, dtype=torch.float, device=self.device).unsqueeze(-1)

    def _init_mds(self, args):
        """Build ``num_samples`` simulations from ``start_state``.

        Also sets ``self.start_position`` from the position of the last simulation
        built, with a leading axis of size 1.

        Raises:
            ValueError: if ``num_samples`` is less than 1, because no start position
                could be read.
        """
        if self.num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {self.num_samples}")
        mds = [self._make_md(args, self.start_state) for _ in tqdm(range(self.num_samples))]
        # Read once after construction instead of rebuilding the tensor every iteration.
        self.start_position = torch.tensor(
            mds[-1].position, dtype=torch.float, device=self.device
        ).unsqueeze(0)
        return mds

    def step(self, force):
        """Advance every simulation by one step.

        ``force`` is a tensor indexed by sample along its first axis. It is detached
        and copied to host memory as a NumPy array; row ``i`` is passed to
        simulation ``i``.
        """
        force = force.detach().cpu().numpy()
        # Index rather than zip so that a force batch with too few rows still raises.
        for i, md in enumerate(self.mds):
            md.step(force[i])

    def report(self):
        """Collect the current position and force from every simulation.

        Returns:
            (positions, forces): float32 tensors on ``self.device``, with the sample
            index as the leading axis.
        """
        positions, forces = [], []
        for md in self.mds:
            position, force = md.report()
            positions.append(position)
            forces.append(force)
        return self._stack(positions), self._stack(forces)

    def _stack(self, arrays):
        """Stack per-sample arrays into one float32 tensor on ``self.device``."""
        if not arrays:
            return torch.empty(0, dtype=torch.float, device=self.device)
        # torch.tensor() on a list of ndarrays is very slow; convert each item, then
        # stack on the host so that only one transfer to the target device happens.
        return torch.stack([torch.as_tensor(a, dtype=torch.float) for a in arrays]).to(
            self.device
        )

    def reset(self):
        """Reset every simulation."""
        for md in self.mds:
            md.reset()

    def set_temperature(self, temperature):
        """Set ``temperature`` on every simulation."""
        for md in self.mds:
            md.set_temperature(temperature)