|
|
import sys |
|
|
import torch |
|
|
import numpy as np |
|
|
from tqdm import tqdm |
|
|
|
|
|
import math |
|
|
from .utils.utils import kabsch |
|
|
from .bias import BiasForceTransformer |
|
|
|
|
|
|
|
|
class EntangledSBM:
    """Biased-MD path sampler with a learned bias-force policy.

    Holds the ``BiasForceTransformer`` policy, the ``PathObjective`` used to
    score trajectories and, only when ``args.training`` is set, a
    ``ReplayBuffer`` of scored trajectories.
    """

    def __init__(self, args, mds):
        self.bias_net = BiasForceTransformer(mds, args)
        self.target_measure = PathObjective(args, mds)
        if args.training:
            self.replay = ReplayBuffer(args, mds)
        # Counter advanced externally via increment_rollout().
        self.rollout_idx = 0

    def increment_rollout(self):
        """Advance the rollout counter by one."""
        self.rollout_idx += 1

    def sample(self, args, mds, temperature):
        """Roll out ``args.num_steps`` bias-driven MD steps and record them.

        The recorded trajectories are scored with ``self.target_measure``.
        When ``args.training`` is set, they are pushed into the replay buffer
        with ``log_ri`` as the ranking score. Sample ``i`` is then saved,
        truncated at ``final_idx[i]`` (as returned by the target measure), to
        ``{args.save_dir}/positions/{i}.npy``.

        Args:
            args: config namespace (reads ``num_samples``, ``num_steps``,
                ``device``, ``timestep``, ``training``, ``save_dir``).
            mds: MD simulation wrapper (``report``/``reset``/``step``/
                ``set_temperature``/``target_position``/``num_particles``).
            temperature: forwarded to ``mds.set_temperature`` before the rollout.
        """
        shape = (args.num_samples, args.num_steps + 1, mds.num_particles, 3)
        positions = torch.zeros(shape, device=args.device)
        forces = torch.zeros(shape, device=args.device)

        # Record the starting state, then reset the simulator for the rollout.
        position, force = mds.report()
        positions[:, 0] = position.detach().clone()
        forces[:, 0] = force.detach().clone()
        mds.reset()
        mds.set_temperature(temperature)

        prev_position = position.detach().clone()

        for step in tqdm(range(1, args.num_steps + 1), desc="Sampling"):
            # Backward finite-difference velocity from the two most recent
            # positions (zero on the very first step, where no history exists).
            if step == 1:
                velocity = torch.zeros_like(position)
            else:
                velocity = (position - prev_position) / args.timestep

            bias_force = self.bias_net(
                position.detach().clone(),
                velocity.detach().clone(),
                mds.target_position,
            ).detach()

            # Snapshot the pre-step position *before* advancing. Refreshing it
            # after mds.report() would make it equal to ``position`` and force
            # the velocity above to be identically zero.
            prev_position = position.detach().clone()

            mds.step(bias_force)
            position, force = mds.report()

            if not _is_finite(position, force):
                print("MD produced non-finite: pos nan/inf", torch.isnan(position).sum().item(), torch.isinf(position).sum().item(),
                      "force nan/inf", torch.isnan(force).sum().item(), torch.isinf(force).sum().item())
                # Keep the last finite position, record the offending force,
                # and stop; the remaining steps stay zero-filled.
                positions[:, step] = prev_position
                forces[:, step] = force
                break

            positions[:, step] = position
            # NOTE(review): assumes mds.report() returns a force that already
            # includes the applied bias scaled by 1e-6, so this stores the
            # unbiased force; train() re-adds 1e-6 * bias_net(...). Confirm.
            forces[:, step] = force - 1e-6 * bias_force

        mds.reset()
        log_tpm, final_idx, log_ri = self.target_measure(positions, forces)

        if args.training:
            self.replay.add_ranked((positions, forces, log_tpm), score=log_ri)

        for i in range(args.num_samples):
            np.save(
                f"{args.save_dir}/positions/{i}.npy",
                positions[i][: final_idx[i] + 1].cpu().numpy(),
            )

    def train(self, args, mds):
        """Run ``args.trains_per_rollout`` optimizer steps on replayed paths.

        Args:
            args: config namespace (reads ``objective``, ``control_variate``,
                ``log_z_lr``, ``policy_lr``, ``trains_per_rollout``,
                ``timestep``, ``friction``, ``control_cost``, ``max_grad_norm``).
            mds: MD wrapper (reads ``target_position``, ``m``, ``log_prob``).

        Returns:
            ``(mean_loss, positions)``, where ``positions`` is the last batch
            drawn from the replay buffer.

        Raises:
            ValueError: if ``args.objective`` is not ``"ce"`` or ``"lv"``, or if
                the ``"lv"`` objective is paired with a ``control_variate``
                other than ``"global"``, ``"local"`` or ``"zero"``.
        """
        # Validate up front instead of failing mid-loop with UnboundLocalError.
        if args.objective not in ("ce", "lv"):
            raise ValueError(f"unknown objective: {args.objective!r}")
        if args.objective == "lv" and args.control_variate not in ("global", "local", "zero"):
            raise ValueError(f"unknown control_variate: {args.control_variate!r}")

        # log_z gets its own learning rate; every other parameter uses policy_lr.
        log_z_param = self.bias_net.log_z
        params_except = [p for p in self.bias_net.parameters() if p is not log_z_param]
        optimizer = torch.optim.Adam(
            [
                {"params": [log_z_param], "lr": args.log_z_lr},
                {"params": params_except, "lr": args.policy_lr},
            ]
        )

        loss_sum = 0
        for _ in tqdm(range(args.trains_per_rollout), desc="Training"):
            positions, forces, log_tpm, _log_ri = self.replay.sample()
            assert positions.shape == forces.shape, f"{positions.shape=} != {forces.shape=}"

            velocities = (positions[:, 1:] - positions[:, :-1]) / args.timestep

            # Flatten (batch, time) so the policy sees a plain batch of frames.
            biases = 1e-6 * self.bias_net(
                positions[:, :-1].reshape(-1, positions.size(-2), positions.size(-1)),
                velocities.view(-1, velocities.size(-2), velocities.size(-1)),
                mds.target_position,
            )
            biases = biases.view(*velocities.shape)

            # Predicted next velocity under the biased dynamics.
            means = (
                1 - args.friction * args.timestep
            ) * velocities + args.timestep / mds.m * (forces[:, :-1] + biases)
            resid = _sanitize(velocities[:, 1:] - means[:, :-1])
            log_bpm = mds.log_prob(resid).mean((1, 2, 3))

            if args.control_variate == "global":
                log_z = self.bias_net.log_z
            elif args.control_variate == "local":
                log_z = (log_tpm - log_bpm).mean().detach()
            elif args.control_variate == "zero":
                log_z = 0

            if args.objective == "ce":
                # Self-normalised importance weights over the batch.
                log_rnd = (log_tpm - log_bpm).detach()
                weights = torch.softmax(log_rnd, dim=0)
                if args.control_cost:
                    control_cost = 0.5 * args.timestep * (biases[:, :-1].square().sum((-1, -2, -3))).mean()
                    loss = -(weights * log_bpm).sum() + control_cost
                else:
                    loss = -(weights * log_bpm).sum()
            else:  # "lv" (validated above)
                loss = (log_z + log_bpm - log_tpm).square().mean()

            loss.backward()
            for group in optimizer.param_groups:
                torch.nn.utils.clip_grad_norm_(group["params"], args.max_grad_norm)
            optimizer.step()
            optimizer.zero_grad()
            loss_sum += loss.item()

        loss = loss_sum / args.trains_per_rollout
        return loss, positions
|
|
|
|
|
|
|
|
class ReplayBuffer: |
|
|
def __init__(self, args, mds): |
|
|
|
|
|
self.positions = torch.zeros( |
|
|
(args.buffer_size, args.num_steps + 1, mds.num_particles, 3), |
|
|
device=args.device, |
|
|
) |
|
|
self.forces = torch.zeros( |
|
|
(args.buffer_size, args.num_steps + 1, mds.num_particles, 3), |
|
|
device=args.device, |
|
|
) |
|
|
self.log_tpm = torch.zeros(args.buffer_size, device=args.device) |
|
|
self.idx = 0 |
|
|
|
|
|
self.device = args.device |
|
|
self.batch_size = args.batch_size |
|
|
self.num_samples = args.num_samples |
|
|
self.buffer_size = args.buffer_size |
|
|
self.args = args |
|
|
|
|
|
|
|
|
self.scores = torch.zeros(args.buffer_size, device=args.device) |
|
|
self.count = 0 |
|
|
|
|
|
def add(self, data): |
|
|
pos_batch, force_batch, tpm_batch = data |
|
|
|
|
|
newN = pos_batch.size(0) |
|
|
|
|
|
indices = (torch.arange(self.idx, self.idx + newN, device=self.device) % self.buffer_size) |
|
|
|
|
|
self.idx = (self.idx + newN) % self.buffer_size |
|
|
|
|
|
self.positions[indices] = pos_batch.detach().to(self.device).clone() |
|
|
self.forces[indices] = force_batch.detach().to(self.device).clone() |
|
|
self.log_tpm[indices] = tpm_batch.detach().to(self.device).clone() |
|
|
|
|
|
self.count = min(self.count + newN, self.buffer_size) |
|
|
|
|
|
@torch.no_grad() |
|
|
def add_ranked(self, data, score=None): |
|
|
positions, forces, log_tpm = data |
|
|
if score is None: |
|
|
score = log_tpm |
|
|
|
|
|
|
|
|
positions, forces, log_tpm, score = ( |
|
|
positions.clone().detach(), |
|
|
forces.clone().detach(), |
|
|
log_tpm.clone().detach(), |
|
|
score.clone().detach() |
|
|
) |
|
|
|
|
|
valid = torch.isfinite(positions).all((1,2,3)) & torch.isfinite(forces).all((1,2,3)) & torch.isfinite(log_tpm) |
|
|
|
|
|
if valid.any(): |
|
|
positions = positions[valid] |
|
|
forces = forces[valid] |
|
|
log_tpm = log_tpm[valid] |
|
|
score = score[valid] |
|
|
|
|
|
curr = self.count |
|
|
newN = positions.size(0) |
|
|
keepN = min(self.buffer_size, curr + newN) |
|
|
|
|
|
if curr > 0: |
|
|
pos_cat = torch.cat([self.positions[:curr], positions], dim=0) |
|
|
force_cat = torch.cat([self.forces[:curr], forces], dim=0) |
|
|
tpm_cat = torch.cat([self.log_tpm[:curr], log_tpm], dim=0) |
|
|
sco_cat = torch.cat([self.scores[:curr], score], dim=0) |
|
|
else: |
|
|
pos_cat, force_cat, tpm_cat, sco_cat = positions, forces, log_tpm, score |
|
|
|
|
|
top_vals, top_idx = torch.topk(sco_cat, k=keepN, largest=True, sorted=False) |
|
|
|
|
|
self.positions[:keepN] = pos_cat.index_select(0, top_idx) |
|
|
self.forces[:keepN] = force_cat.index_select(0, top_idx) |
|
|
self.log_tpm[:keepN] = tpm_cat.index_select(0, top_idx) |
|
|
self.scores[:keepN] = top_vals |
|
|
self.count = keepN |
|
|
|
|
|
def sample(self): |
|
|
assert self.count > 0, "buffer is empty" |
|
|
if self.args.importance_sample: |
|
|
idx = torch.multinomial(torch.softmax(self.scores[:self.count], 0), |
|
|
num_samples=self.batch_size, replacement=True) |
|
|
else: |
|
|
idx = torch.randint(0, self.count, (self.batch_size,), device=self.device) |
|
|
|
|
|
|
|
|
return ( |
|
|
self.positions[idx].clone().detach(), |
|
|
self.forces[idx].clone().detach(), |
|
|
self.log_tpm[idx].clone().detach(), |
|
|
self.scores[idx].clone().detach(), |
|
|
) |
|
|
|
|
|
|
|
|
class PathObjective:
    """Log target path measure: unbiased dynamics term plus relaxed indicator.

    Calling an instance with ``(positions, forces)`` returns
    ``(log_tpm, final_idx, log_ri)``, where ``log_tpm = log_upm + log_ri``.
    ``positions`` is rank-4 (the measures reduce over dims 1-3); presumably
    ``(batch, time, atoms, 3)`` -- TODO confirm.
    """

    def __init__(self, args, mds):
        # Integration / kernel hyper-parameters.
        self.sigma = args.sigma
        self.timestep = args.timestep
        self.friction = args.friction
        # System description taken from the MD wrapper.
        self.heavy_atoms = mds.heavy_atoms
        self.target_position = mds.target_position
        self.m = mds.m
        self.log_prob = mds.log_prob

    def __call__(self, positions, forces):
        unbiased = self.unbiased_path_measure(positions, forces)
        indicator, best_steps = self.relaxed_indicator(positions, self.target_position)
        return unbiased + indicator, best_steps, indicator

    def unbiased_path_measure(self, positions, forces):
        """Mean ``log_prob`` of the velocity residuals under force-only dynamics."""
        dt = self.timestep
        vel = (positions[:, 1:] - positions[:, :-1]) / dt
        damping = 1 - self.friction * dt
        predicted = damping * vel + dt / self.m * forces[:, :-1]
        residual = _sanitize(vel[:, 1:] - predicted[:, :-1])
        return self.log_prob(residual).mean((1, 2, 3))

    def relaxed_indicator(self, positions, target_position):
        """Per-sample max over time of ``rbf`` on heavy atoms, and its argmax step."""
        heavy_traj = positions[:, :, self.heavy_atoms]
        heavy_target = target_position[:, self.heavy_atoms]

        n_paths = heavy_traj.size(0)
        dev = heavy_traj.device
        log_ri = torch.zeros(n_paths, device=dev)
        final_idx = torch.zeros(n_paths, device=dev, dtype=torch.long)

        for i, traj in enumerate(heavy_traj):
            best, best_step = self.rbf(traj, heavy_target).max(0)
            log_ri[i] = best
            final_idx[i] = best_step
        return log_ri, final_idx

    def rbf(self, positions, target_position):
        """Gaussian log-kernel on the mean squared deviation after alignment.

        ``kabsch`` presumably returns the best-fit rotation ``R`` and
        translation ``t`` mapping ``positions`` onto ``target_position``.
        """
        R, t = kabsch(positions, target_position)
        aligned = positions @ R.transpose(-2, -1) + t
        msd = (aligned - target_position).square().mean((-2, -1))
        return -0.5 / self.sigma**2 * msd
|
|
|
|
|
def _is_finite(*tensors): |
|
|
return all(torch.isfinite(t).all().item() for t in tensors) |
|
|
|
|
|
def _sanitize(t, max_abs=1e6): |
|
|
t = torch.nan_to_num(t, nan=0.0, posinf=max_abs, neginf=-max_abs) |
|
|
return torch.clamp(t, min=-max_abs, max=max_abs) |