|
|
import sys |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
|
|
|
from .utils.utils import kabsch |
|
|
from .utils.rbf import grad_log_wrt_positions |
|
|
|
|
|
class BiasForceTransformer(nn.Module):
    """Transformer encoder that predicts a per-atom bias force.

    Each atom is one token. Token features are built from positions and
    velocities expressed in a frame aligned to the target (via ``kabsch`` on
    the heavy atoms), optionally the displacement to the target, and the
    distance to the target. Two small MLP heads produce a positive scalar
    scale and a 3-vector per atom, which ``forward`` combines into the force.

    ``log_z`` is a learnable scalar initialised to 0.0; it is not used inside
    ``forward``. ``bias`` is copied from ``args`` and likewise not read by any
    method of this class.
    """

    def __init__(self, mds, args, d_model=256, nhead=8, num_layers=4,
                 dim_feedforward=512, dropout=0.1):
        """Build the network and move it to ``args.device``.

        Args:
            mds: object providing ``heavy_atoms`` (indices used for the
                alignment) and ``num_particles``.
            args: configuration object providing ``device``,
                ``use_delta_to_target``, ``rbf``, ``sigma`` and ``bias``.
            d_model: token embedding width.
            nhead: number of attention heads.
            num_layers: number of encoder layers.
            dim_feedforward: hidden width of each encoder feed-forward block.
            dropout: dropout probability inside the encoder layers.
        """
        super().__init__()

        # System description.
        self.device = args.device
        self.heavy_atoms = mds.heavy_atoms
        self.N = mds.num_particles

        # Feature / direction options.
        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        # Per-atom features: position (3) + velocity (3) + distance (1),
        # plus the displacement to the target (3) when enabled.
        feat_dim = 10 if self.use_delta_to_target else 7

        # NOTE: submodules are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence every time.
        self.input_proj = nn.Linear(feat_dim, d_model)
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )

        self.scale_head = self._make_head(d_model, 1)
        self.vec_head_aligned = self._make_head(d_model, 3)

        self.bias = args.bias

        self.log_z = nn.Parameter(torch.tensor(0.0))
        self.to(self.device)

    @staticmethod
    def _make_head(d_model, out_dim):
        """Return a Linear -> GELU -> Linear head mapping d_model to out_dim."""
        hidden = d_model // 2
        return nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        )

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        """Softplus of ``x`` shifted up by ``eps`` so the result stays positive."""
        activated = F.softplus(x, beta=beta, threshold=threshold)
        return activated + eps

    def forward(self, pos, vel, target):
        """Compute the bias force for a batch of configurations.

        Args:
            pos: positions, (B, N, 3).
            vel: velocities, (B, N, 3).
            target: target positions, (B, N, 3).

        Returns:
            A single tensor (B, N, 3): ``scale * d`` plus the part of the
            predicted vector that is orthogonal to ``d``, where ``d`` is the
            per-atom direction towards the target expressed in the frame of
            ``pos``.

        Raises:
            AssertionError: if the number of atoms differs from ``self.N``.
        """
        _, N, _ = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        # Align using heavy atoms only.
        # NOTE(review): kabsch is assumed to return (R, t) such that
        # ``x @ R^T + t`` maps the pos frame onto the target frame - confirm
        # against utils.kabsch.
        heavy_idx = self.heavy_atoms.to(pos.device)
        R, t = kabsch(pos[:, heavy_idx], target[:, heavy_idx])
        R_T = R.transpose(-2, -1)

        pos_al = pos @ R_T + t
        vel_al = vel @ R_T  # velocities are rotated only, never translated

        delta_al = target - pos_al
        dist_al = delta_al.norm(dim=-1, keepdim=True)

        # Assemble token features; the displacement block is optional.
        pieces = [pos_al, vel_al]
        if self.use_delta_to_target:
            pieces.append(delta_al)
        pieces.append(dist_al)
        feats = torch.cat(pieces, dim=-1)

        tokens = self.encoder(self.input_proj(feats))

        # Positive per-atom scale, kept as (B, N, 1) for broadcasting below.
        gain = self._softplus_unit(self.scale_head(tokens))
        # The vector head predicts in the aligned frame; rotate it back.
        vector = self.vec_head_aligned(tokens) @ R

        # Target expressed in the original (pos) frame.
        target_posframe = (target - t) @ R

        if self.rbf:
            # Direction comes from the helper and is cut from the autograd graph.
            d = grad_log_wrt_positions(pos, target_posframe, self.sigma).detach()
        else:
            d = target_posframe - pos

        scaled = gain.expand(-1, -1, 3) * d

        # Strip the component of ``vector`` along ``d``; the clamp guards the
        # division when ``d`` is (numerically) zero.
        tiny = torch.finfo(pos.dtype).eps
        d_sq = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(tiny)
        along = (vector * d).sum(dim=-1, keepdim=True) / d_sq
        vec_perp = vector - along * d

        return vec_perp + scaled
|
|
|
|
|
class BiasForceTransformerNoVel(nn.Module):
    """Velocity-free transformer encoder that predicts a per-atom bias force.

    Each atom is one token. Token features are built from positions expressed
    in a frame aligned to the target (via ``kabsch`` on the heavy atoms),
    optionally the displacement to the target, and the distance to the
    target. Two small MLP heads produce a positive scalar scale and a
    3-vector per atom, which ``forward`` combines into the force.

    ``log_z`` is a learnable scalar initialised to 0.0; it is not used inside
    ``forward``.
    """

    def __init__(self, mds, args, d_model=256, nhead=8, num_layers=4,
                 dim_feedforward=512, dropout=0.1):
        """Build the network and move it to ``args.device``.

        Args:
            mds: object providing ``heavy_atoms`` (indices used for the
                alignment) and ``num_particles``.
            args: configuration object providing ``device``,
                ``use_delta_to_target``, ``rbf`` and ``sigma``.
            d_model: token embedding width.
            nhead: number of attention heads.
            num_layers: number of encoder layers.
            dim_feedforward: hidden width of each encoder feed-forward block.
            dropout: dropout probability inside the encoder layers.
        """
        super().__init__()

        # System description.
        self.device = args.device
        self.heavy_atoms = mds.heavy_atoms
        self.N = mds.num_particles

        # Feature / direction options.
        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        # Per-atom features: position (3) + distance (1), plus the
        # displacement to the target (3) when enabled.
        feat_dim = 7 if self.use_delta_to_target else 4

        # NOTE: submodules are created in a fixed order so that parameter
        # initialisation consumes the RNG in the same sequence every time.
        self.input_proj = nn.Linear(feat_dim, d_model)
        layer_kwargs = dict(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(**layer_kwargs),
            num_layers=num_layers,
        )

        self.scale_head = self._make_head(d_model, 1)
        self.vec_head_aligned = self._make_head(d_model, 3)

        self.log_z = nn.Parameter(torch.tensor(0.0))
        self.to(self.device)

    @staticmethod
    def _make_head(d_model, out_dim):
        """Return a Linear -> GELU -> Linear head mapping d_model to out_dim."""
        hidden = d_model // 2
        return nn.Sequential(
            nn.Linear(d_model, hidden),
            nn.GELU(),
            nn.Linear(hidden, out_dim),
        )

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        """Softplus of ``x`` shifted up by ``eps`` so the result stays positive."""
        activated = F.softplus(x, beta=beta, threshold=threshold)
        return activated + eps

    def forward(self, pos, target):
        """Compute the bias force for a batch of configurations.

        Args:
            pos: positions, (B, N, D) with N atoms and D = 3.
            target: target positions, (B, N, D).

        Returns:
            A single tensor (B, N, D): ``scale * d`` plus the part of the
            predicted vector that is orthogonal to ``d``, where ``d`` is the
            per-atom direction towards the target expressed in the frame of
            ``pos``.

        Raises:
            AssertionError: if the number of atoms differs from ``self.N``.
        """
        _, N, _ = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        # Align using heavy atoms only.
        # NOTE(review): kabsch is assumed to return (R, t) such that
        # ``x @ R^T + t`` maps the pos frame onto the target frame - confirm
        # against utils.kabsch.
        heavy_idx = self.heavy_atoms.to(pos.device)
        R, t = kabsch(pos[:, heavy_idx], target[:, heavy_idx])

        pos_al = pos @ R.transpose(-2, -1) + t

        delta_al = target - pos_al
        dist_al = delta_al.norm(dim=-1, keepdim=True)

        # Assemble token features; the displacement block is optional.
        pieces = [pos_al]
        if self.use_delta_to_target:
            pieces.append(delta_al)
        pieces.append(dist_al)
        feats = torch.cat(pieces, dim=-1)

        tokens = self.encoder(self.input_proj(feats))

        # Positive per-atom scale, kept as (B, N, 1) for broadcasting below.
        gain = self._softplus_unit(self.scale_head(tokens))
        # The vector head predicts in the aligned frame; rotate it back.
        vector = self.vec_head_aligned(tokens) @ R

        # Target expressed in the original (pos) frame.
        target_posframe = (target - t) @ R

        if self.rbf:
            # Direction comes from the helper and is cut from the autograd graph.
            d = grad_log_wrt_positions(pos, target_posframe, self.sigma).detach()
        else:
            d = target_posframe - pos

        scaled = gain.expand(-1, -1, 3) * d

        # Strip the component of ``vector`` along ``d``; the clamp guards the
        # division when ``d`` is (numerically) zero.
        tiny = torch.finfo(pos.dtype).eps
        d_sq = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(tiny)
        along = (vector * d).sum(dim=-1, keepdim=True) / d_sq
        vec_perp = vector - along * d

        return vec_perp + scaled