Sophia Tang
Initial commit with LFS
7efee70
import sys

import numpy as np
import torch
from torch.distributions import Normal
from tqdm import tqdm

from .dynamics import dynamics
from .utils.utils import kabsch
class MDs:
    """Batch of independent molecular-dynamics simulations driven in lockstep.

    One simulation object per sample is created from
    ``dynamics.<args.molecule.title()>`` at ``args.start_state``. A separate
    instance built at ``args.end_state`` supplies the target structure and the
    per-system constants (heavy-atom indices, energy function, ``std``, ``m``).
    The start structure is rigidly aligned onto the target using heavy atoms.
    """

    def __init__(self, args):
        """Build the simulations and align the start structure to the target.

        Args:
            args: configuration namespace. ``device``, ``molecule``,
                ``start_state``, ``end_state`` and ``num_samples`` are read here;
                the whole object is forwarded to the ``dynamics`` constructors.
        """
        self.device = args.device
        self.molecule = args.molecule
        self.end_state = args.end_state
        self.num_samples = args.num_samples
        self.start_state = args.start_state

        self.get_md_info(args)
        self.mds = self._init_mds(args)

        # Zero-mean Normal scaled by the system's ``std``; only its
        # log-density function is kept.
        self.log_prob = Normal(0, self.std).log_prob

        # Translate the target so that its heavy-atom centroid is at the origin.
        heavy_centroid = self.target_position[:, self.heavy_atoms].mean(-2, keepdim=True)
        self.target_position = self.target_position - heavy_centroid

        # Rigidly superimpose the start structure onto the centred target using
        # heavy atoms only. ``kabsch`` (from .utils.utils) presumably returns the
        # rotation ``R`` and translation ``t`` mapping start onto target — confirm.
        R, t = kabsch(
            self.start_position[:, self.heavy_atoms],
            self.target_position[:, self.heavy_atoms],
        )
        self.start_position = torch.matmul(self.start_position, R.transpose(-2, -1)) + t

    def _md_class(self):
        """Return the ``dynamics`` class named after ``self.molecule`` (title-cased)."""
        return getattr(dynamics, self.molecule.title())

    def get_md_info(self, args):
        """Read the target structure and per-system constants from an end-state system.

        Sets ``num_particles``, ``heavy_atoms``, ``energy_function``,
        ``target_position``, ``std`` and ``m`` on ``self``.

        Args:
            args: configuration namespace forwarded to the ``dynamics`` constructor.
        """
        md = self._md_class()(args, self.end_state)
        self.num_particles = md.num_particles
        # Index tensor used to select heavy atoms along the atom dimension.
        self.heavy_atoms = torch.from_numpy(md.heavy_atoms).to(self.device)
        self.energy_function = md.energy_function
        # Leading singleton batch dim; presumably (1, num_particles, 3) — confirm.
        self.target_position = torch.tensor(
            md.position, dtype=torch.float, device=self.device
        ).unsqueeze(0)
        self.std = torch.tensor(md.std, dtype=torch.float, device=self.device)
        # NOTE(review): ``md.m`` presumably holds per-atom masses; the trailing
        # singleton dim lets it broadcast against per-coordinate tensors.
        self.m = torch.tensor(md.m, dtype=torch.float, device=self.device).unsqueeze(-1)

    def _init_mds(self, args):
        """Create ``num_samples`` start-state simulations.

        Also sets ``self.start_position`` from the last simulation created.

        Args:
            args: configuration namespace forwarded to the ``dynamics`` constructor.

        Returns:
            list of simulation objects, one per sample.

        Raises:
            ValueError: if ``num_samples`` is less than 1. Without this check
                there would be no simulation to read a start position from.
        """
        if self.num_samples < 1:
            raise ValueError(f"num_samples must be at least 1, got {self.num_samples}")
        md_class = self._md_class()
        mds = [md_class(args, self.start_state) for _ in tqdm(range(self.num_samples))]
        # All samples are constructed from the same start state, so convert the
        # position once rather than on every iteration.
        self.start_position = torch.tensor(
            mds[-1].position, dtype=torch.float, device=self.device
        ).unsqueeze(0)
        return mds

    def step(self, force):
        """Advance every simulation by one step under an external force.

        Args:
            force: tensor whose leading dim indexes samples. It is detached,
                moved to CPU as NumPy, and row ``i`` is passed to
                ``self.mds[i].step``.
        """
        force = force.detach().cpu().numpy()
        for i, md in enumerate(self.mds):
            md.step(force[i])

    def report(self):
        """Collect positions and forces from every simulation.

        Returns:
            (positions, forces): float32 tensors on ``self.device``, with the
            per-sample reports stacked along a new leading dim.
        """
        reports = [md.report() for md in self.mds]
        # Stack into a single ndarray first: building a tensor directly from a
        # list of ndarrays is very slow in PyTorch and emits a UserWarning.
        positions = np.array([position for position, _ in reports])
        forces = np.array([force for _, force in reports])
        return (
            torch.tensor(positions, dtype=torch.float, device=self.device),
            torch.tensor(forces, dtype=torch.float, device=self.device),
        )

    def reset(self):
        """Call ``reset()`` on every simulation."""
        for md in self.mds:
            md.reset()

    def set_temperature(self, temperature):
        """Forward ``temperature`` to ``set_temperature`` on every simulation."""
        for md in self.mds:
            md.set_temperature(temperature)