LigandMPNN / data_utils.py
gabboud's picture
initial commit from source repo
d95502a
from __future__ import print_function
import numpy as np
import torch
import torch.utils
from prody import *
confProDy(verbosity="none")
# One-letter -> three-letter residue names; "X" / "UNK" is the unknown residue.
# Keys follow the order A R N D C Q E G H I L K M F P S T W Y V X.
restype_1to3 = dict(
    zip(
        "ARNDCQEGHILKMFPSTWYVX",
        [
            "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU",
            "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE",
            "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK",
        ],
    )
)

# One-letter code -> integer token; the position in the string is the token.
restype_str_to_int = {
    letter: token for token, letter in enumerate("ACDEFGHIKLMNPQRSTVWYX")
}

# Integer token -> one-letter code (exact inverse of restype_str_to_int).
restype_int_to_str = {token: letter for letter, token in restype_str_to_int.items()}

# One-letter codes listed in token order.
alphabet = list(restype_str_to_int)
# Upper-cased element symbols ordered by atomic number (index 0 is hydrogen),
# written one period of the periodic table per line.
# NOTE(review): position 42 holds "Mb" instead of the standard symbol "Mo".
# It is kept as is because changing the symbol set changes which atoms are
# recognised -- confirm before touching it.
element_list = (
    "H He "
    "Li Be B C N O F Ne "
    "Na Mg Al Si P S Cl Ar "
    "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
    "Rb Sr Y Zr Nb Mb Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe "
    "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
    "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn "
    "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
    "Rf Db Sg Bh Hs Mt Ds Rg Cn Uut Fl Uup Lv Uus Uuo"
).upper().split()

# Integer element type (starting at 1) -> symbol. The numbering covers
# 1 .. len(element_list) - 1, so the final symbol of the list has no entry.
element_dict_rev = {
    number: symbol for number, symbol in enumerate(element_list[:-1], start=1)
}
def get_seq_rec(S: torch.Tensor, S_pred: torch.Tensor, mask: torch.Tensor):
    """
    Fraction of masked positions where the predicted residue equals the true one.

    S : true sequence shape=[batch, length]
    S_pred : predicted sequence shape=[batch, length]
    mask : mask to compute average over the region shape=[batch, length]
    returns : averaged sequence recovery shape=[batch]

    The denominator is the plain sum of the mask, so a row whose mask is all
    zeros evaluates 0 / 0.
    """
    is_correct = torch.eq(S, S_pred)
    n_correct = (is_correct * mask).sum(dim=-1)
    n_positions = mask.sum(dim=-1)
    return n_correct / n_positions
def get_score(S: torch.Tensor, log_probs: torch.Tensor, mask: torch.Tensor):
    """
    Categorical cross entropy of the true sequence under the predicted log-probabilities.

    S : true sequence (integer tokens in 0..20) shape=[batch, length]
    log_probs : predicted log-probabilities shape=[batch, length, 21]
    mask : mask to compute average over the region shape=[batch, length]

    Returns
    average_loss : masked average of the categorical cross entropy (CCE) shape=[batch]
    loss_per_residue : per position CCE shape=[batch, length]
    """
    # one_hot only accepts int64 indices, while parse_PDB in this file stores
    # S as int32; .long() is a no-op for tensors that are already int64.
    S_one_hot = torch.nn.functional.one_hot(S.long(), 21)
    loss_per_residue = -(S_one_hot * log_probs).sum(-1)  # [B, L]
    # The small constant keeps the division finite when the mask is all zeros.
    average_loss = torch.sum(loss_per_residue * mask, dim=-1) / (
        torch.sum(mask, dim=-1) + 1e-8
    )
    return average_loss, loss_per_residue
def write_full_PDB(
    save_path: str,
    X: np.ndarray,
    X_m: np.ndarray,
    b_factors: np.ndarray,
    R_idx: np.ndarray,
    chain_letters: np.ndarray,
    S: np.ndarray,
    other_atoms=None,
    icodes=None,
    force_hetatm=False,
):
    """
    Write the protein atoms (plus optional extra atoms) to a PDB file via prody.

    save_path : path where the PDB will be written to
    X : protein atom xyz coordinates shape=[length, 14, 3]
    X_m : protein atom mask shape=[length, 14]
    b_factors: shape=[length, 14]
    R_idx: protein residue indices shape=[length]
    chain_letters: protein chain letters shape=[length]
    S : protein amino acid sequence (integer tokens) shape=[length]
    other_atoms: other atoms parsed by prody; appended after the protein when given
    icodes: a list of insertion codes for the PDB; e.g. antibody loops.
        When None every residue gets a blank insertion code.
    force_hetatm: if True the "hetatm" flags of other_atoms are copied onto
        the atoms that are written out
    """

    def atom14(*names):
        # Pad the atom names of one residue type with "" up to 14 slots.
        return list(names) + [""] * (14 - len(names))

    restype_1to3 = dict(
        zip(
            "ARNDCQEGHILKMFPSTWYVX",
            [
                "ALA", "ARG", "ASN", "ASP", "CYS", "GLN", "GLU",
                "GLY", "HIS", "ILE", "LEU", "LYS", "MET", "PHE",
                "PRO", "SER", "THR", "TRP", "TYR", "VAL", "UNK",
            ],
        )
    )
    restype_INTtoSTR = dict(enumerate("ACDEFGHIKLMNPQRSTVWYX"))
    # Atom name stored in each of the 14 coordinate slots, per residue type.
    restype_name_to_atom14_names = {
        "ALA": atom14("N", "CA", "C", "O", "CB"),
        "ARG": atom14("N", "CA", "C", "O", "CB", "CG", "CD", "NE", "CZ", "NH1", "NH2"),
        "ASN": atom14("N", "CA", "C", "O", "CB", "CG", "OD1", "ND2"),
        "ASP": atom14("N", "CA", "C", "O", "CB", "CG", "OD1", "OD2"),
        "CYS": atom14("N", "CA", "C", "O", "CB", "SG"),
        "GLN": atom14("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "NE2"),
        "GLU": atom14("N", "CA", "C", "O", "CB", "CG", "CD", "OE1", "OE2"),
        "GLY": atom14("N", "CA", "C", "O"),
        "HIS": atom14("N", "CA", "C", "O", "CB", "CG", "ND1", "CD2", "CE1", "NE2"),
        "ILE": atom14("N", "CA", "C", "O", "CB", "CG1", "CG2", "CD1"),
        "LEU": atom14("N", "CA", "C", "O", "CB", "CG", "CD1", "CD2"),
        "LYS": atom14("N", "CA", "C", "O", "CB", "CG", "CD", "CE", "NZ"),
        "MET": atom14("N", "CA", "C", "O", "CB", "CG", "SD", "CE"),
        "PHE": atom14(
            "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ"
        ),
        "PRO": atom14("N", "CA", "C", "O", "CB", "CG", "CD"),
        "SER": atom14("N", "CA", "C", "O", "CB", "OG"),
        "THR": atom14("N", "CA", "C", "O", "CB", "OG1", "CG2"),
        "TRP": atom14(
            "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2",
            "CE2", "CE3", "NE1", "CZ2", "CZ3", "CH2",
        ),
        "TYR": atom14(
            "N", "CA", "C", "O", "CB", "CG", "CD1", "CD2", "CE1", "CE2", "CZ", "OH"
        ),
        "VAL": atom14("N", "CA", "C", "O", "CB", "CG1", "CG2"),
        "UNK": atom14(),
    }

    # icodes is optional in the signature, so fall back to blank insertion
    # codes instead of failing on icodes[i] below.
    if icodes is None:
        icodes = [""] * len(S)

    S_str = [restype_1to3[restype_INTtoSTR[token]] for token in S]

    X_list = []
    b_factor_list = []
    atom_name_list = []
    element_name_list = []
    residue_name_list = []
    residue_number_list = []
    chain_id_list = []
    icodes_list = []
    for i, resname in enumerate(S_str):
        sel = X_m[i].astype(np.int32) == 1  # atoms present in this residue
        total = int(np.sum(sel))  # plain int so list repetition is unambiguous
        atom_names = np.array(restype_name_to_atom14_names[resname])[sel]
        X_list.append(X[i][sel])
        b_factor_list.append(b_factors[i][sel])
        atom_name_list.append(atom_names)
        # The element is taken to be the first character of the atom name.
        element_name_list += [atom_name[:1] for atom_name in atom_names]
        residue_name_list += total * [resname]
        residue_number_list += total * [R_idx[i]]
        chain_id_list += total * [chain_letters[i]]
        icodes_list += total * [icodes[i]]

    X_stack = np.concatenate(X_list, 0)
    b_factor_stack = np.concatenate(b_factor_list, 0)
    atom_name_stack = np.concatenate(atom_name_list, 0)

    protein = prody.AtomGroup()
    protein.setCoords(X_stack)
    protein.setBetas(b_factor_stack)
    protein.setNames(atom_name_stack)
    protein.setResnames(residue_name_list)
    protein.setElements(element_name_list)
    protein.setOccupancies(np.ones([X_stack.shape[0]]))
    protein.setResnums(residue_number_list)
    protein.setChids(chain_id_list)
    protein.setIcodes(icodes_list)

    if other_atoms:
        # Copy the extra atoms into a fresh AtomGroup so they can be
        # concatenated with the protein.
        other_atoms_g = prody.AtomGroup()
        other_atoms_g.setCoords(other_atoms.getCoords())
        other_atoms_g.setNames(other_atoms.getNames())
        other_atoms_g.setResnames(other_atoms.getResnames())
        other_atoms_g.setElements(other_atoms.getElements())
        other_atoms_g.setOccupancies(other_atoms.getOccupancies())
        other_atoms_g.setResnums(other_atoms.getResnums())
        other_atoms_g.setChids(other_atoms.getChids())
        if force_hetatm:
            other_atoms_g.setFlags("hetatm", other_atoms.getFlags("hetatm"))
        writePDB(save_path, protein + other_atoms_g)
    else:
        writePDB(save_path, protein)
def get_aligned_coordinates(protein_atoms, CA_dict: dict, atom_name: str):
    """
    Collect the coordinates of one atom type, aligned to the residue order of CA_dict.

    protein_atoms: prody atom group
    CA_dict: mapping between chain_residue_idx_icodes and integers
    atom_name: atom to be parsed; e.g. CA

    Returns
    atom_coords_ : float32 array shape=[len(CA_dict), 3]; rows of residues
        without this atom stay zero
    atom_coords_m : int32 array shape=[len(CA_dict)]; 1 where the atom was found
    """
    atom_coords_ = np.zeros([len(CA_dict), 3], np.float32)
    atom_coords_m = np.zeros([len(CA_dict)], np.int32)

    atom_atoms = protein_atoms.select(f"name {atom_name}")
    if atom_atoms is None:
        # The selection matched nothing: every residue lacks this atom.
        return atom_coords_, atom_coords_m

    per_atom = zip(
        atom_atoms.getCoords(),
        atom_atoms.getChids(),
        atom_atoms.getResnums(),
        atom_atoms.getIcodes(),
    )
    for coords, chain_id, resnum, icode in per_atom:
        code = chain_id + "_" + str(resnum) + "_" + icode
        # One hash lookup per atom instead of rebuilding list(CA_dict) each time.
        row = CA_dict.get(code)
        if row is not None:
            atom_coords_[row, :] = coords
            atom_coords_m[row] = 1
    return atom_coords_, atom_coords_m
def parse_PDB(
    input_path: str,
    device: str = "cpu",
    chains: list = [],
    parse_all_atoms: bool = False,
    parse_atoms_with_zero_occupancy: bool = False
):
    """
    Parse a PDB file with prody into the tensors used by the featurizer.

    input_path : path for the input PDB
    device: device for the torch.Tensor
    chains: a list specifying which chains need to be parsed; e.g. ["A", "B"].
        An empty list keeps every chain. The default list is never mutated.
    parse_all_atoms: if False parse only N,CA,C,O otherwise the 36 atom types
        listed in atom_order except OXT (the OXT slot of xyz_37 stays zero)
    parse_atoms_with_zero_occupancy: if True atoms with zero occupancy will be parsed

    Returns (output_dict, backbone, other_atoms, CA_icodes, water_atoms)
    output_dict : dictionary with
        X [L, 4, 3] backbone N, CA, C, O coordinates,
        mask [L] 1 where all four backbone atoms exist,
        Y / Y_t / Y_m non-protein, non-water atom coordinates, element types and mask,
        R_idx [L] residue numbers, chain_labels [L] chain indices,
        chain_letters (array of chain ids), chain_list (sorted unique chain ids),
        mask_c (one boolean tensor per entry of chain_list),
        S [L] sequence tokens, xyz_37 [L, 37, 3] and xyz_37_m [L, 37]
    backbone : prody selection "backbone" of the protein atoms
    other_atoms : prody selection "not protein and not water" (None when empty)
    CA_icodes : insertion codes of the CA atoms
    water_atoms : prody selection "water" (None when empty)
    """
    # Upper-cased element symbols in atomic-number order.
    # NOTE(review): position 42 is "Mb", not the standard "Mo", so molybdenum
    # atoms are treated as unknown; kept because changing the symbol set
    # changes which ligand atoms are emitted.
    element_list = (
        "H He "
        "Li Be B C N O F Ne "
        "Na Mg Al Si P S Cl Ar "
        "K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr "
        "Rb Sr Y Zr Nb Mb Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe "
        "Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu "
        "Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn "
        "Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr "
        "Rf Db Sg Bh Hs Mt Ds Rg Cn Uut Fl Uup Lv Uus Uuo"
    ).upper().split()
    # Symbol -> integer element type starting at 1 (0 is reserved for unknown).
    # The range stops at len - 1, so the last symbol has no entry.
    element_dict = dict(zip(element_list, range(1, len(element_list))))

    restype_3to1 = {
        "ALA": "A", "ARG": "R", "ASN": "N", "ASP": "D", "CYS": "C",
        "GLN": "Q", "GLU": "E", "GLY": "G", "HIS": "H", "ILE": "I",
        "LEU": "L", "LYS": "K", "MET": "M", "PHE": "F", "PRO": "P",
        "SER": "S", "THR": "T", "TRP": "W", "TYR": "Y", "VAL": "V",
    }
    restype_STRtoINT = {
        letter: token for token, letter in enumerate("ACDEFGHIKLMNPQRSTVWYX")
    }

    # Slot of every atom name inside the 37-wide per-residue arrays.
    atom_names_37 = [
        "N", "CA", "C", "CB", "O", "CG", "CG1", "CG2", "OG", "OG1",
        "SG", "CD", "CD1", "CD2", "ND1", "ND2", "OD1", "OD2", "SD", "CE",
        "CE1", "CE2", "CE3", "NE", "NE1", "NE2", "OE1", "OE2", "CH2", "NH1",
        "NH2", "OH", "CZ", "CZ2", "CZ3", "NZ", "OXT",
    ]
    atom_order = {name: slot for slot, name in enumerate(atom_names_37)}
    if parse_all_atoms:
        atom_types = atom_names_37[:-1]  # everything except OXT
    else:
        atom_types = ["N", "CA", "C", "O"]

    atoms = parsePDB(input_path)
    if not parse_atoms_with_zero_occupancy:
        atoms = atoms.select("occupancy > 0")
    if chains:
        # Builds e.g. "chain A or chain B".
        atoms = atoms.select(" or ".join("chain " + item for item in chains))

    protein_atoms = atoms.select("protein")
    backbone = protein_atoms.select("backbone")
    other_atoms = atoms.select("not protein and not water")
    water_atoms = atoms.select("water")

    CA_atoms = protein_atoms.select("name CA")
    CA_resnums = CA_atoms.getResnums()
    CA_chain_ids = CA_atoms.getChids()
    CA_icodes = CA_atoms.getIcodes()

    # "chain_resnum_icode" -> row index; a repeated code keeps the last index.
    CA_dict = {}
    for i, (chain_id, resnum, icode) in enumerate(
        zip(CA_chain_ids, CA_resnums, CA_icodes)
    ):
        CA_dict[chain_id + "_" + str(resnum) + "_" + icode] = i

    xyz_37 = np.zeros([len(CA_dict), 37, 3], np.float32)
    xyz_37_m = np.zeros([len(CA_dict), 37], np.int32)
    for atom_name in atom_types:
        xyz, xyz_m = get_aligned_coordinates(protein_atoms, CA_dict, atom_name)
        xyz_37[:, atom_order[atom_name], :] = xyz
        xyz_37_m[:, atom_order[atom_name]] = xyz_m

    backbone_slots = [atom_order[name] for name in ("N", "CA", "C", "O")]
    X = xyz_37[:, backbone_slots, :]  # [L, 4, 3] in N, CA, C, O order
    mask = np.prod(xyz_37_m[:, backbone_slots], axis=1)  # must all 4 atoms exist

    chain_labels = np.array(CA_atoms.getChindices(), dtype=np.int32)
    R_idx = np.array(CA_resnums, dtype=np.int32)
    # Residue names outside the 20 standard ones become "X".
    S = np.array(
        [
            restype_STRtoINT[restype_3to1.get(resname, "X")]
            for resname in CA_atoms.getResnames()
        ],
        np.int32,
    )

    ligand_elements = None if other_atoms is None else other_atoms.getElements()
    if ligand_elements is None:
        # No non-protein atoms: emit one fully masked placeholder atom.
        Y = np.zeros([1, 3], np.float32)
        Y_t = np.zeros([1], np.int32)
        Y_m = np.zeros([1], np.int32)
    else:
        Y = np.array(other_atoms.getCoords(), dtype=np.float32)
        Y_t = np.array(
            [element_dict.get(symbol.upper(), 0) for symbol in ligand_elements],
            dtype=np.int32,
        )
        # Drop hydrogens (type 1) and unrecognised elements (type 0).
        Y_m = (Y_t != 1) * (Y_t != 0)
        Y = Y[Y_m, :]
        Y_t = Y_t[Y_m]
        Y_m = Y_m[Y_m]

    output_dict = {}
    output_dict["X"] = torch.tensor(X, device=device, dtype=torch.float32)
    output_dict["mask"] = torch.tensor(mask, device=device, dtype=torch.int32)
    output_dict["Y"] = torch.tensor(Y, device=device, dtype=torch.float32)
    output_dict["Y_t"] = torch.tensor(Y_t, device=device, dtype=torch.int32)
    output_dict["Y_m"] = torch.tensor(Y_m, device=device, dtype=torch.int32)
    output_dict["R_idx"] = torch.tensor(R_idx, device=device, dtype=torch.int32)
    output_dict["chain_labels"] = torch.tensor(
        chain_labels, device=device, dtype=torch.int32
    )
    output_dict["chain_letters"] = CA_chain_ids

    chain_list = sorted(set(CA_chain_ids))
    # One boolean per-residue mask for every chain, in chain_list order.
    mask_c = [
        torch.tensor(
            [chain == item for item in CA_chain_ids],
            device=device,
            dtype=bool,
        )
        for chain in chain_list
    ]
    output_dict["mask_c"] = mask_c
    output_dict["chain_list"] = chain_list
    output_dict["S"] = torch.tensor(S, device=device, dtype=torch.int32)
    output_dict["xyz_37"] = torch.tensor(xyz_37, device=device, dtype=torch.float32)
    output_dict["xyz_37_m"] = torch.tensor(xyz_37_m, device=device, dtype=torch.int32)

    return output_dict, backbone, other_atoms, CA_icodes, water_atoms
def get_nearest_neighbours(CB, mask, Y, Y_t, Y_m, number_of_ligand_atoms):
    """
    For every residue pick the closest ligand atoms, padded to a fixed count.

    CB : per-residue reference coordinates shape=[A, 3]
    mask : residue mask shape=[A]
    Y : ligand atom coordinates shape=[B, 3]
    Y_t : ligand atom element types shape=[B]
    Y_m : ligand atom mask shape=[B]
    number_of_ligand_atoms : number of neighbour slots per residue

    Returns
    Y : float32 coordinates of the selected atoms shape=[A, number_of_ligand_atoms, 3]
    Y_t : int32 element types shape=[A, number_of_ligand_atoms]
    Y_m : int32 atom mask shape=[A, number_of_ligand_atoms]
    D_AB_closest : distance to the closest atom shape=[A]; residue/atom pairs
        that are masked out count as squared distance 1000.0

    Slots beyond the number of available atoms are left at zero. With no
    ligand atoms at all, every slot is zero and D_AB_closest is sqrt(1000.0),
    the same value a fully masked residue gets.
    """
    device = CB.device
    num_residues = CB.shape[0]
    masked_sq_dist = 1000.0

    Y_out = torch.zeros(
        [num_residues, number_of_ligand_atoms, 3], dtype=torch.float32, device=device
    )
    Y_t_out = torch.zeros(
        [num_residues, number_of_ligand_atoms], dtype=torch.int32, device=device
    )
    Y_m_out = torch.zeros(
        [num_residues, number_of_ligand_atoms], dtype=torch.int32, device=device
    )

    if Y.shape[0] == 0:
        # Nothing to gather; indexing column 0 of an empty neighbour table
        # would raise, so return the padding and the masked distance instead.
        D_AB_closest = torch.sqrt(
            torch.full([num_residues], masked_sq_dist, dtype=CB.dtype, device=device)
        )
        return Y_out, Y_t_out, Y_m_out, D_AB_closest

    pair_mask = mask[:, None] * Y_m[None, :]  # [A,B]
    sq_dist = torch.sum((CB[:, None, :] - Y[None, :, :]) ** 2, -1)
    sq_dist = sq_dist * pair_mask + (1 - pair_mask) * masked_sq_dist
    nn_idx = torch.argsort(sq_dist, -1)[:, :number_of_ligand_atoms]
    D_AB_closest = torch.sqrt(torch.gather(sq_dist, 1, nn_idx)[:, 0])

    # Indexing with nn_idx selects, per residue, the rows of its neighbours.
    num_nn_update = nn_idx.shape[1]
    Y_out[:, :num_nn_update] = Y[nn_idx]
    Y_t_out[:, :num_nn_update] = Y_t[nn_idx]
    Y_m_out[:, :num_nn_update] = Y_m[nn_idx]
    return Y_out, Y_t_out, Y_m_out, D_AB_closest
def featurize(
    input_dict,
    cutoff_for_score=8.0,
    use_atom_context=True,
    number_of_ligand_atoms=16,
    model_type="protein_mpnn",
):
    """
    Turn one parsed structure into model inputs with a leading batch dimension.

    input_dict : dictionary with at least "R_idx", "chain_labels", "S",
        "chain_mask", "mask" and "X"; "ligand_mpnn" additionally reads "Y",
        "Y_t", "Y_m" (and "side_chain_mask" when present); the membrane model
        types read "membrane_per_residue_labels"; "xyz_37" / "xyz_37_m" are
        passed through when present
    cutoff_for_score : residues whose closest ligand atom is nearer than this
        distance are flagged in "mask_XY" (ligand_mpnn only)
    use_atom_context : if False the ligand atom mask "Y_m" is zeroed out
    number_of_ligand_atoms : number of nearest ligand atoms kept per residue
    model_type : "ligand_mpnn", "per_residue_label_membrane_mpnn",
        "global_label_membrane_mpnn", or anything else for the plain protein
        features

    Returns a dictionary in which every tensor has a new first axis of size 1.
    "R_idx" is renumbered so that consecutive residues sharing a residue
    number get distinct, increasing indices; "R_idx_original" keeps the input
    numbering.
    """
    output_dict = {}
    if model_type == "ligand_mpnn":
        mask = input_dict["mask"]
        backbone_xyz = input_dict["X"]
        N = backbone_xyz[:, 0, :]
        CA = backbone_xyz[:, 1, :]
        C = backbone_xyz[:, 2, :]
        # CB position computed from the backbone atoms with fixed coefficients.
        b = CA - N
        c = C - CA
        a = torch.cross(b, c, dim=-1)
        CB = -0.58273431 * a + 0.56802827 * b - 0.54067466 * c + CA
        Y, Y_t, Y_m, D_XY = get_nearest_neighbours(
            CB,
            mask,
            input_dict["Y"],
            input_dict["Y_t"],
            input_dict["Y_m"],
            number_of_ligand_atoms,
        )
        # Residues that have a valid ligand atom within the cutoff.
        mask_XY = (D_XY < cutoff_for_score) * mask * Y_m[:, 0]
        output_dict["mask_XY"] = mask_XY[None,]
        if "side_chain_mask" in input_dict:
            output_dict["side_chain_mask"] = input_dict["side_chain_mask"][None,]
        output_dict["Y"] = Y[None,]
        output_dict["Y_t"] = Y_t[None,]
        output_dict["Y_m"] = Y_m[None,]
        if not use_atom_context:
            output_dict["Y_m"] = 0.0 * output_dict["Y_m"]
    elif model_type in (
        "per_residue_label_membrane_mpnn",
        "global_label_membrane_mpnn",
    ):
        output_dict["membrane_per_residue_labels"] = input_dict[
            "membrane_per_residue_labels"
        ][None,]

    # Every residue that repeats the residue number of its predecessor shifts
    # itself and all later residues up by one. Done with a cumulative sum so
    # that empty input works and no sentinel "previous" value can collide
    # with a real residue number.
    R_idx = input_dict["R_idx"]
    if R_idx.shape[0] > 1:
        repeats = torch.zeros_like(R_idx)
        repeats[1:] = R_idx[1:] == R_idx[:-1]
        # dtype is passed explicitly because cumsum widens integer inputs.
        R_idx_renumbered = R_idx + torch.cumsum(repeats, dim=0, dtype=R_idx.dtype)
    else:
        R_idx_renumbered = R_idx.clone()

    output_dict["R_idx"] = R_idx_renumbered[None,]
    output_dict["R_idx_original"] = input_dict["R_idx"][None,]
    output_dict["chain_labels"] = input_dict["chain_labels"][None,]
    output_dict["S"] = input_dict["S"][None,]
    output_dict["chain_mask"] = input_dict["chain_mask"][None,]
    output_dict["mask"] = input_dict["mask"][None,]
    output_dict["X"] = input_dict["X"][None,]
    if "xyz_37" in input_dict:
        output_dict["xyz_37"] = input_dict["xyz_37"][None,]
        output_dict["xyz_37_m"] = input_dict["xyz_37_m"][None,]
    return output_dict