# Sophia Tang
# Initial commit with LFS
# 7efee70
import sys
import torch
import torch.nn as nn
import torch.nn.functional as F
from .utils.utils import kabsch
from .utils.rbf import grad_log_wrt_positions
class BiasForceTransformer(nn.Module):
    """Transformer encoder predicting a per-atom bias force from positions and velocities.

    Each configuration is rigidly aligned onto ``target`` with ``kabsch``, which is
    fitted on the heavy atoms only and then applied to all atoms. Per-atom features
    built in that aligned frame are encoded by a Transformer. Two heads then predict:

    - a positive per-atom scale;
    - a free 3-vector (in the aligned frame).

    The output force is the vector's component perpendicular to a direction ``d``,
    plus ``scale * d``. It is expressed in the original frame of ``pos``.
    """

    def __init__(
        self,
        mds,
        args,
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        dropout=0.1,
    ):
        """Build the network.

        Args:
            mds: provides ``heavy_atoms`` (tensor used to index the atom axis for the
                Kabsch fit) and ``num_particles`` (the expected N).
            args: provides ``device``, ``use_delta_to_target``, ``rbf``, ``sigma``
                and ``bias``.
            d_model, nhead, num_layers, dim_feedforward, dropout: Transformer
                encoder hyper-parameters.
        """
        super().__init__()
        self.device = args.device
        self.heavy_atoms = mds.heavy_atoms
        self.N = mds.num_particles
        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        # Per-atom input: aligned position (3) + aligned velocity (3)
        # + optional displacement to target (3) + distance to target (1).
        feat_dim = 3 + 3 + (3 if self.use_delta_to_target else 0) + 1
        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Heads: positive magnitude along d, and a free vector in the aligned frame.
        # (Registration order is kept unchanged so parameter order stays stable.)
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        self.vec_head_aligned = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 3),
        )
        self.bias = args.bias
        # Learnable scalar that forward() does not use; presumably consumed by the
        # training objective — TODO confirm.
        self.log_z = nn.Parameter(torch.tensor(0.0))
        self.to(self.device)

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        """Strictly positive softplus: ``softplus(x) + eps``."""
        return F.softplus(x, beta=beta, threshold=threshold) + eps

    @staticmethod
    def _compose_force(vector, d, scale, eps):
        """Combine the head outputs into the output force.

        Removes from ``vector`` its component along ``d`` and adds ``scale * d``.
        The squared norm of ``d`` is clamped to ``eps``. An all-zero ``d``
        therefore yields a zero parallel part instead of a division by zero.

        Args:
            vector: (B, N, 3) free vector.
            d: (B, N, 3) direction.
            scale: (B, N) magnitude along ``d``.
            eps: lower bound applied to ``|d|^2``.

        Returns:
            (B, N, 3) tensor ``vector_perp + scale * d``.
        """
        denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)  # (B,N,1)
        vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
        vec_perp = vector - vec_parallel
        # (B,N,1) broadcasts against (B,N,3); no explicit expand needed.
        return vec_perp + scale.unsqueeze(-1) * d

    def forward(self, pos, vel, target):
        """Predict the bias force.

        Args:
            pos: (B, N, 3) positions.
            vel: (B, N, 3) velocities.
            target: (B, N, 3) target positions.

        Returns:
            (B, N, 3) force, expressed in the frame of ``pos``.
        """
        _, N, _ = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        # Rigid fit on heavy atoms only, applied to every atom.
        # Assumes kabsch(P, Q) returns (R, t) with P @ R^T + t ~= Q; the inverse
        # maps below (right-multiplying by R) rely on R being orthogonal.
        heavy = self.heavy_atoms.to(pos.device)
        R, t = kabsch(pos[:, heavy], target[:, heavy])
        pos_al = pos @ R.transpose(-2, -1) + t
        vel_al = vel @ R.transpose(-2, -1)  # velocities rotate, no translation

        delta_al = target - pos_al  # (B,N,3)
        dist_al = torch.norm(delta_al, dim=-1, keepdim=True)  # (B,N,1)
        parts = [pos_al, vel_al]
        if self.use_delta_to_target:
            parts.append(delta_al)
        parts.append(dist_al)
        feats = torch.cat(parts, dim=-1)

        x = self.encoder(self.input_proj(feats))  # (B,N,d_model)
        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)  # (B,N)
        vector = self.vec_head_aligned(x) @ R  # rotate back to pos frame, (B,N,3)
        target_posframe = (target - t) @ R  # target mapped into pos frame, (B,N,3)

        if self.rbf:
            # NOTE(review): d is detached only in this branch; the linear branch
            # keeps gradients through R and t. Confirm this asymmetry is intended.
            d = grad_log_wrt_positions(pos, target_posframe, self.sigma).detach()
        else:
            d = target_posframe - pos

        eps = torch.finfo(pos.dtype).eps
        return self._compose_force(vector, d, scale, eps)
class BiasForceTransformerNoVel(nn.Module):
    """Position-only variant of the bias-force Transformer (no velocity input).

    Each configuration is rigidly aligned onto ``target`` with ``kabsch``, which is
    fitted on the heavy atoms only and then applied to all atoms. Per-atom features
    in that aligned frame are encoded by a Transformer. Two heads then predict:

    - a positive per-atom scale;
    - a free 3-vector.

    The output force is the vector's component perpendicular to a direction ``d``,
    plus ``scale * d``. It is expressed in the original frame of ``pos``.
    """

    def __init__(
        self,
        mds,
        args,
        d_model=256,
        nhead=8,
        num_layers=4,
        dim_feedforward=512,
        dropout=0.1,
    ):
        """Build the network.

        Args:
            mds: provides ``heavy_atoms`` (tensor used to index the atom axis for the
                Kabsch fit) and ``num_particles`` (the expected N).
            args: provides ``device``, ``use_delta_to_target``, ``rbf`` and ``sigma``.
            d_model, nhead, num_layers, dim_feedforward, dropout: Transformer
                encoder hyper-parameters.
        """
        super().__init__()
        self.device = args.device
        self.heavy_atoms = mds.heavy_atoms
        self.N = mds.num_particles
        self.use_delta_to_target = args.use_delta_to_target
        self.rbf = args.rbf
        self.sigma = args.sigma

        # Per-atom input: aligned position (3) + optional displacement to
        # target (3) + distance to target (1).
        feat_dim = 3 + (3 if self.use_delta_to_target else 0) + 1
        self.input_proj = nn.Linear(feat_dim, d_model)
        enc_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            activation="gelu",
            batch_first=True,
            norm_first=True,
        )
        self.encoder = nn.TransformerEncoder(enc_layer, num_layers=num_layers)

        # Heads: positive magnitude along d, and a free vector in the aligned frame.
        # (Registration order is kept unchanged so parameter order stays stable.)
        self.scale_head = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 1),
        )
        self.vec_head_aligned = nn.Sequential(
            nn.Linear(d_model, d_model // 2),
            nn.GELU(),
            nn.Linear(d_model // 2, 3),
        )
        # Learnable scalar that forward() does not use; presumably consumed by the
        # training objective — TODO confirm.
        self.log_z = nn.Parameter(torch.tensor(0.0))
        self.to(self.device)

    @staticmethod
    def _softplus_unit(x, beta=1.0, threshold=20.0, eps=1e-8):
        """Strictly positive softplus: ``softplus(x) + eps``."""
        return F.softplus(x, beta=beta, threshold=threshold) + eps

    @staticmethod
    def _compose_force(vector, d, scale, eps):
        """Combine the head outputs into the output force.

        Removes from ``vector`` its component along ``d`` and adds ``scale * d``.
        The squared norm of ``d`` is clamped to ``eps``. An all-zero ``d``
        therefore yields a zero parallel part instead of a division by zero.

        Args:
            vector: (B, N, 3) free vector.
            d: (B, N, 3) direction.
            scale: (B, N) magnitude along ``d``.
            eps: lower bound applied to ``|d|^2``.

        Returns:
            (B, N, 3) tensor ``vector_perp + scale * d``.
        """
        denom = d.pow(2).sum(dim=-1, keepdim=True).clamp_min(eps)  # (B,N,1)
        vec_parallel = ((vector * d).sum(dim=-1, keepdim=True) / denom) * d
        vec_perp = vector - vec_parallel
        # (B,N,1) broadcasts against (B,N,3); no explicit expand needed.
        return vec_perp + scale.unsqueeze(-1) * d

    def forward(self, pos, target):
        """Predict the bias force.

        Args:
            pos: (B, N, 3) positions; N must equal ``num_particles``.
            target: (B, N, 3) target positions.

        Returns:
            (B, N, 3) force, expressed in the frame of ``pos``.
        """
        _, N, _ = pos.shape
        assert N == self.N, f"Expected N={self.N}, got {N}"

        # Rigid fit on heavy atoms only, applied to every atom.
        # Assumes kabsch(P, Q) returns (R, t) with P @ R^T + t ~= Q; the inverse
        # maps below (right-multiplying by R) rely on R being orthogonal.
        heavy = self.heavy_atoms.to(pos.device)
        R, t = kabsch(pos[:, heavy], target[:, heavy])
        pos_al = pos @ R.transpose(-2, -1) + t

        delta_al = target - pos_al  # (B,N,3)
        dist_al = torch.norm(delta_al, dim=-1, keepdim=True)  # (B,N,1)
        parts = [pos_al]
        if self.use_delta_to_target:
            parts.append(delta_al)
        parts.append(dist_al)
        feats = torch.cat(parts, dim=-1)

        x = self.encoder(self.input_proj(feats))  # (B,N,d_model)
        scale = self._softplus_unit(self.scale_head(x)).squeeze(-1)  # (B,N)
        vector = self.vec_head_aligned(x) @ R  # rotate back to pos frame, (B,N,3)
        target_posframe = (target - t) @ R  # target mapped into pos frame, (B,N,3)

        if self.rbf:
            # NOTE(review): d is detached only in this branch; the linear branch
            # keeps gradients through R and t. Confirm this asymmetry is intended.
            d = grad_log_wrt_positions(pos, target_posframe, self.sigma).detach()
        else:
            d = target_posframe - pos

        eps = torch.finfo(pos.dtype).eps
        return self._compose_force(vector, d, scale, eps)