Spaces:
Sleeping
Sleeping
Upload 36 files
Browse files- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__init__.py +0 -0
- src/data/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/__pycache__/pinder_datamodule.cpython-310.pyc +0 -0
- src/data/components/__init__.py +0 -0
- src/data/components/__pycache__/__init__.cpython-310.pyc +0 -0
- src/data/components/__pycache__/pinder_dataset.cpython-310.pyc +0 -0
- src/data/components/__pycache__/prepare_data.cpython-310.pyc +0 -0
- src/data/components/pinder_dataset.py +64 -0
- src/data/components/prepare_data.py +175 -0
- src/data/pinder_datamodule.py +167 -0
- src/eval.py +99 -0
- src/models/__init__.py +0 -0
- src/models/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/__pycache__/pinder_module.cpython-310.pyc +0 -0
- src/models/components/__init__.py +0 -0
- src/models/components/__pycache__/__init__.cpython-310.pyc +0 -0
- src/models/components/__pycache__/equivariant_mpnn.cpython-310.pyc +0 -0
- src/models/components/__pycache__/utils.cpython-310.pyc +0 -0
- src/models/components/equivariant_mpnn.py +231 -0
- src/models/components/utils.py +100 -0
- src/models/pinder_module.py +297 -0
- src/train.py +133 -0
- src/utils/__init__.py +5 -0
- src/utils/__pycache__/__init__.cpython-310.pyc +0 -0
- src/utils/__pycache__/instantiators.cpython-310.pyc +0 -0
- src/utils/__pycache__/logging_utils.cpython-310.pyc +0 -0
- src/utils/__pycache__/pylogger.cpython-310.pyc +0 -0
- src/utils/__pycache__/rich_utils.cpython-310.pyc +0 -0
- src/utils/__pycache__/utils.cpython-310.pyc +0 -0
- src/utils/instantiators.py +56 -0
- src/utils/logging_utils.py +57 -0
- src/utils/pylogger.py +51 -0
- src/utils/rich_utils.py +103 -0
- src/utils/utils.py +119 -0
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (138 Bytes). View file
|
|
|
src/data/__init__.py
ADDED
|
File without changes
|
src/data/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (143 Bytes). View file
|
|
|
src/data/__pycache__/pinder_datamodule.cpython-310.pyc
ADDED
|
Binary file (6.15 kB). View file
|
|
|
src/data/components/__init__.py
ADDED
|
File without changes
|
src/data/components/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (154 Bytes). View file
|
|
|
src/data/components/__pycache__/pinder_dataset.cpython-310.pyc
ADDED
|
Binary file (2.09 kB). View file
|
|
|
src/data/components/__pycache__/prepare_data.cpython-310.pyc
ADDED
|
Binary file (5.29 kB). View file
|
|
|
src/data/components/pinder_dataset.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import __main__
|
| 4 |
+
import rootutils
|
| 5 |
+
import torch
|
| 6 |
+
from torch_geometric.data import Dataset
|
| 7 |
+
|
| 8 |
+
# setup root dir and pythonpath
|
| 9 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 10 |
+
from src.data.components.prepare_data import CropPairedPDB
|
| 11 |
+
|
| 12 |
+
setattr(__main__, "CropPairedPDB", CropPairedPDB)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PinderDataset(Dataset):
    """PyTorch Geometric dataset backed by a list of serialized sample files.

    Every entry of ``file_paths`` is handed to ``torch.load`` when the
    corresponding index is requested; nothing is loaded up front.
    """

    def __init__(self, file_paths: List[str]) -> None:
        """Create the dataset.

        Args:
            file_paths: Paths of the serialized samples, one per dataset item.
        """
        super().__init__()
        self.file_paths = file_paths

    @property
    def processed_file_names(self) -> List[str]:
        """Return the sample paths given at construction time.

        Returns:
            List[str]: The stored list of file paths.
        """
        return self.file_paths

    def len(self) -> int:
        """Return the number of samples.

        Returns:
            int: Number of entries in ``processed_file_names``.
        """
        sample_paths = self.processed_file_names
        return len(sample_paths)

    def get(self, idx) -> CropPairedPDB:
        """Load and return the sample stored at position ``idx``.

        Args:
            idx: Index into ``processed_file_names``.

        Returns:
            CropPairedPDB: The object deserialized from the file.
        """
        sample_path = self.processed_file_names[idx]
        # NOTE(review): weights_only=False allows arbitrary pickled objects,
        # so only files from a trusted source should be listed here.
        return torch.load(sample_path, weights_only=False)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
if __name__ == "__main__":
    # Manual smoke test: load a single processed sample and print it.
    dataset = PinderDataset(
        file_paths=["./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"]
    )
    first_sample = dataset[0]
    print(first_sample)
|
src/data/components/prepare_data.py
ADDED
|
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import multiprocessing
|
| 2 |
+
import os
|
| 3 |
+
from argparse import ArgumentParser
|
| 4 |
+
from pathlib import Path
|
| 5 |
+
from typing import Optional
|
| 6 |
+
|
| 7 |
+
import rootutils
|
| 8 |
+
import torch
|
| 9 |
+
from loguru import logger
|
| 10 |
+
from pinder.core import PinderSystem, get_index
|
| 11 |
+
from pinder.core.loader.geodata import PairedPDB, structure2tensor
|
| 12 |
+
from pinder.core.loader.structure import Structure
|
| 13 |
+
from tqdm.auto import tqdm
|
| 14 |
+
|
| 15 |
+
# setup root dir and pythonpath
|
| 16 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 17 |
+
|
| 18 |
+
# torch-cluster is optional: record whether it is importable so that callers
# can skip k-NN edge construction when it is missing.
try:
    from torch_cluster import knn_graph

    torch_cluster_installed = True
except ImportError:
    # Adjacent string literals are concatenated verbatim, so each sentence
    # needs its own trailing space to keep the message readable.
    logger.warning(
        "torch-cluster is not installed! "
        "Please install the appropriate library for your pytorch installation. "
        "See https://github.com/rusty1s/pytorch_cluster/issues/185 for background."
    )
    torch_cluster_installed = False
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def create_lr_files(system_id: str, apo_complex_path: str, save_path: str):
    """Split a complex PDB file into receptor-only and ligand-only PDB files.

    The input file is looked up next to ``apo_complex_path``; its name is the
    stem of ``apo_complex_path`` immediately followed by ``f"{system_id}.pdb"``.
    ``ATOM``/``HETATM`` records are routed by the character at column index 21:
    ``"R"`` goes to the receptor file, ``"L"`` to the ligand file, and records
    with any other value are dropped. All other lines are copied to both files.

    Args:
        system_id: Identifier used in the input and output file names.
        apo_complex_path: Path (``str`` or ``pathlib.Path``) whose directory
            and stem are used to locate the input file.
        save_path: Directory in which the two output files are written.

    Returns:
        Tuple ``(apo_r_path, apo_l_path)`` with the receptor and ligand file paths.
    """
    # The annotation promises ``str`` but ``with_name``/``stem`` only exist on
    # ``Path`` objects, so normalise first to support both.
    apo_complex_path = Path(apo_complex_path)

    apo_r_path = os.path.join(save_path, f"apo_r_{system_id}.pdb")
    apo_l_path = os.path.join(save_path, f"apo_l_{system_id}.pdb")
    native_path = apo_complex_path.with_name(apo_complex_path.stem + f"{system_id}.pdb")

    with open(native_path) as infile, open(apo_r_path, "w") as output_r, open(
        apo_l_path, "w"
    ) as output_l:
        for line in infile:
            if line.startswith(("ATOM", "HETATM")):
                # Coordinate record: route by the chain ID at position 21.
                chain_id = line[21]
                if chain_id == "R":
                    output_r.write(line)
                elif chain_id == "L":
                    output_l.write(line)
            else:
                # Write other lines (e.g., HEADER, REMARK) to both files
                output_r.write(line)
                output_l.write(line)
    return apo_r_path, apo_l_path
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class CropPairedPDB(PairedPDB):
    """``PairedPDB`` graph built from the masked complexes of a PINDER system."""

    @classmethod
    def from_crop_system(
        cls,
        system_id: str,
        root: str = "./data/",
        k: int = 10,
        add_edges: bool = True,
        predicted_structures: bool = True,
        split: str = "train",
    ) -> None:
        """Write the raw PDB files and the processed graph for one system.

        The masked holo/apo/predicted complexes are written below
        ``<root>/raw/<type>/<split>/`` and the resulting graph is saved with
        ``torch.save`` below ``<root>/processed/<predicted|apo>/<split>/``.
        Nothing is returned; if creating or writing the complexes fails, the
        error is logged and the system is skipped.

        Args:
            system_id: PINDER system identifier.
            root: Base data directory.
            k: Number of neighbours passed to ``knn_graph``.
            add_edges: Whether to add k-NN edges to the graph.
            predicted_structures: Use the predicted complex in place of the
                apo complex as the unbound input.
            split: Name of the split sub-directory.
        """
        system = PinderSystem(system_id)
        root_dir = Path(root)

        # Make sure every raw output directory exists before writing.
        for subdir in ("apo", "holo", "predicted"):
            os.makedirs(root_dir / "raw" / subdir / split, exist_ok=True)

        try:
            holo_complex, apo_complex, pred_complex = system.create_masked_bound_unbound_complexes(
                renumber_residues=True
            )
            complexes_by_type = {
                "apo": apo_complex,
                "holo": holo_complex,
                "predicted": pred_complex,
            }
            for complex_type, complex_obj in complexes_by_type.items():
                pdb_path = root_dir / "raw" / complex_type / split / f"{system_id}_complex.pdb"
                complex_obj.to_pdb(pdb_path)
        except Exception as e:
            # Best effort: log the failing system and move on.
            logger.error(f"Error in writing PDB files: {e}, {system_id}")
            return None

        unbound_kind = "predicted" if predicted_structures else "apo"
        if predicted_structures:
            apo_complex = pred_complex
        save_path = os.path.join(root, "processed", unbound_kind, split)

        # create the directory if it does not exist
        os.makedirs(save_path, exist_ok=True)

        graph = cls.from_structure_pair(
            holo_complex=holo_complex,
            apo_complex=apo_complex,
            add_edges=add_edges,
            k=k,
        )
        torch.save(graph, os.path.join(save_path, f"{system_id}.pt"))

    @classmethod
    def from_structure_pair(
        cls,
        holo_complex: Structure,
        apo_complex: Structure,
        add_edges: bool = True,
        k: int = 10,
    ) -> PairedPDB:
        """Build a receptor/ligand graph from a bound and an unbound complex.

        Node features ``x`` and positions ``pos`` come from ``apo_complex``;
        targets ``y`` are the atom coordinates of ``holo_complex``. The first
        atoms (as many as there are rows with ``chain_id == "R"``) are treated
        as the receptor and the remainder as the ligand.

        Args:
            holo_complex: Bound complex providing the target coordinates.
            apo_complex: Unbound complex providing the inputs.
            add_edges: Add k-NN edges (only if torch-cluster is available).
            k: Number of neighbours for ``knn_graph``.

        Returns:
            The populated graph.
        """

        def get_structure_props(structure: Structure, start: int, end: Optional[int]):
            # NOTE(review): the same atom-level window is also applied to the
            # CA-only arrays below — confirm that this is intended.
            window = slice(start, end)
            calpha = structure.filter("atom_name", mask=["CA"])
            atoms = structure.atom_array
            ca_atoms = calpha.atom_array
            return structure2tensor(
                atom_coordinates=structure.coords[window],
                atom_types=atoms.atom_name[window],
                element_types=atoms.element[window],
                residue_coordinates=calpha.coords[window],
                residue_types=ca_atoms.res_name[window],
                residue_ids=ca_atoms.res_id[window],
            )

        graph = cls()
        # Number of receptor rows in each complex; used as the split point.
        r_h = (holo_complex.dataframe["chain_id"] == "R").sum()
        r_a = (apo_complex.dataframe["chain_id"] == "R").sum()

        holo_r_props = get_structure_props(holo_complex, 0, r_h)
        holo_l_props = get_structure_props(holo_complex, r_h, None)
        apo_r_props = get_structure_props(apo_complex, 0, r_a)
        apo_l_props = get_structure_props(apo_complex, r_a, None)

        node_sources = (
            ("ligand", apo_l_props, holo_l_props),
            ("receptor", apo_r_props, holo_r_props),
        )
        for node_type, unbound_props, bound_props in node_sources:
            graph[node_type].x = unbound_props["atom_types"]
            graph[node_type].pos = unbound_props["atom_coordinates"]
            graph[node_type].y = bound_props["atom_coordinates"]

        if add_edges and torch_cluster_installed:
            for node_type in ("ligand", "receptor"):
                graph[node_type, node_type].edge_index = knn_graph(graph[node_type].pos, k=k)

        return graph
|
| 141 |
+
|
| 142 |
+
|
| 143 |
+
def process_system_id(
    system_id: str,
    predicted_structures: bool = False,
    k: int = 10,
    split: str = "train",
):
    """Process a single PINDER system (worker entry point for the pool).

    This lives at module level, outside the ``__main__`` guard, so that worker
    processes started with the ``spawn`` or ``forkserver`` start methods can
    resolve it when unpickling tasks.

    Args:
        system_id: PINDER system identifier.
        predicted_structures: Forwarded to ``CropPairedPDB.from_crop_system``.
        k: Forwarded to ``CropPairedPDB.from_crop_system``.
        split: Forwarded to ``CropPairedPDB.from_crop_system``.

    Returns:
        Whatever ``CropPairedPDB.from_crop_system`` returns.
    """
    return CropPairedPDB.from_crop_system(
        system_id,
        predicted_structures=predicted_structures,
        k=k,
        split=split,
    )


if __name__ == "__main__":
    from functools import partial

    parser = ArgumentParser()
    parser.add_argument("--n_jobs", type=int, default=20)
    parser.add_argument("--k", type=int, default=10)
    parser.add_argument("--predicted_structures", action="store_true")
    parser.add_argument("--split", type=str, default="train")
    args = parser.parse_args()

    predicted_structures = args.predicted_structures

    # get indices for train, validation, and test splits
    indices = get_index()

    if predicted_structures:
        query = '(split == "{split}") and ((apo_R == False and apo_L == False) and (predicted_R==True and predicted_L==True))'
    else:
        query = '(split == "{split}") and (apo_R == True and apo_L == True)'

    system_idx = indices.query(query.format(split=args.split)).reset_index(drop=True)

    system_ids = system_idx.id.tolist()

    # Bind the command-line options up front; a partial of a module-level
    # function is picklable regardless of the multiprocessing start method.
    worker = partial(
        process_system_id,
        predicted_structures=predicted_structures,
        k=args.k,
        split=args.split,
    )

    with multiprocessing.Pool(args.n_jobs) as pool:
        # list(...) drains the lazy imap iterator so every system is processed.
        results = list(tqdm(pool.imap(worker, system_ids), total=len(system_ids)))
|
src/data/pinder_datamodule.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from typing import Any, Dict, Optional
|
| 3 |
+
|
| 4 |
+
import pandas as pd
|
| 5 |
+
import rootutils
|
| 6 |
+
from lightning import LightningDataModule
|
| 7 |
+
from torch_geometric.data import Dataset
|
| 8 |
+
from torch_geometric.loader import DataLoader
|
| 9 |
+
|
| 10 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 11 |
+
|
| 12 |
+
from src.data.components.pinder_dataset import PinderDataset
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class PINDERDataModule(LightningDataModule):
    """`LightningDataModule` for the PINDER dataset."""

    def __init__(
        self,
        data_dir: str = "data/processed",
        predicted_structures: bool = False,
        high_quality: bool = False,
        batch_size: int = 1,
        num_workers: int = 0,
        pin_memory: bool = True,
    ) -> None:
        """Initialize the `PINDERDataModule`.

        Reads ``metadata.csv`` from ``data_dir`` and builds the per-split file
        lists as ``<data_dir>/<complex>/<split>/<file_paths>``.

        Args:
            data_dir: Directory holding ``metadata.csv`` and the processed
                samples. Defaults to "data/processed".
            predicted_structures: Whether to include "predicted" complexes in
                addition to "apo" ones. Defaults to False.
            high_quality: Stored as a hyperparameter; not read anywhere in this
                class. Defaults to False.
            batch_size: Batch size. Defaults to 1.
            num_workers: Number of workers for parallel processing. Defaults to 0.
            pin_memory: Whether to pin memory. Defaults to True.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        # get metadata
        metadata = pd.read_csv(os.path.join(self.hparams.data_dir, "metadata.csv"))

        def get_files(split: str, complex_types: list) -> list:
            """Return the sample paths for ``split`` restricted to ``complex_types``."""
            selected = metadata[
                (metadata["split"] == split) & (metadata["complex"].isin(complex_types))
            ]
            # Build the paths directly instead of assigning a column on the
            # filtered slice: that avoids pandas' chained-assignment warning
            # and also works when the selection is empty. Paths are rooted at
            # data_dir so they stay consistent with where metadata.csv was read.
            return [
                os.path.join(self.hparams.data_dir, complex_type, split_name, file_name)
                for complex_type, split_name, file_name in zip(
                    selected["complex"], selected["split"], selected["file_paths"]
                )
            ]

        complex_types = ["apo", "predicted"] if self.hparams.predicted_structures else ["apo"]
        self.train_files = get_files("train", complex_types)
        self.val_files = get_files("val", complex_types)
        self.test_files = get_files("test", complex_types)

        self.data_train: Optional[Dataset] = None
        self.data_val: Optional[Dataset] = None
        self.data_test: Optional[Dataset] = None

        self.batch_size_per_device = batch_size

    def setup(self, stage: Optional[str] = None) -> None:
        """Load data. Set variables: `self.data_train`, `self.data_val`, `self.data_test`.

        This method is called by Lightning before `trainer.fit()`, `trainer.validate()`, `trainer.test()`, and
        `trainer.predict()`, so be careful not to execute things like random split twice! Also, it is called after
        `self.prepare_data()` and there is a barrier in between which ensures that all the processes proceed to
        `self.setup()` once the data is prepared and available for use.

        :param stage: The stage to setup. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`. Defaults to ``None``.
        :raises RuntimeError: If the batch size is not divisible by the trainer's world size.
        """
        # Divide batch size by the number of devices.
        if self.trainer is not None:
            if self.hparams.batch_size % self.trainer.world_size != 0:
                raise RuntimeError(
                    f"Batch size ({self.hparams.batch_size}) is not divisible by the number of devices ({self.trainer.world_size})."
                )
            self.batch_size_per_device = self.hparams.batch_size // self.trainer.world_size

        # load datasets only if not loaded already; compare against None
        # explicitly because an empty dataset is falsy as well
        if self.data_train is None and self.data_val is None and self.data_test is None:
            self.data_train = PinderDataset(self.train_files)
            self.data_val = PinderDataset(self.val_files)
            self.data_test = PinderDataset(self.test_files)

    def train_dataloader(self) -> DataLoader:
        """Create and return the train dataloader.

        :return: The train dataloader (shuffled, incomplete last batch dropped).
        """
        return DataLoader(
            dataset=self.data_train,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=True,
            drop_last=True,
        )

    def val_dataloader(self) -> DataLoader:
        """Create and return the validation dataloader.

        :return: The validation dataloader.
        """
        return DataLoader(
            dataset=self.data_val,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def test_dataloader(self) -> DataLoader:
        """Create and return the test dataloader.

        :return: The test dataloader.
        """
        return DataLoader(
            dataset=self.data_test,
            batch_size=self.batch_size_per_device,
            num_workers=self.hparams.num_workers,
            pin_memory=self.hparams.pin_memory,
            shuffle=False,
        )

    def teardown(self, stage: Optional[str] = None) -> None:
        """Lightning hook for cleaning up after `trainer.fit()`, `trainer.validate()`,
        `trainer.test()`, and `trainer.predict()`.

        :param stage: The stage being torn down. Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
            Defaults to ``None``.
        """
        pass

    def state_dict(self) -> Dict[Any, Any]:
        """Called when saving a checkpoint. Implement to generate and save the datamodule state.

        :return: A dictionary containing the datamodule state that you want to save.
        """
        return {}

    def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
        """Called when loading a checkpoint. Implement to reload datamodule state given datamodule
        `state_dict()`.

        :param state_dict: The datamodule state returned by `self.state_dict()`.
        """
        pass
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
if __name__ == "__main__":
    # Manual smoke test: build the datamodule with defaults, report how many
    # batches each loader yields, then print the first training batch.
    datamodule = PINDERDataModule()
    datamodule.setup()
    train_loader, val_loader, test_loader = (
        datamodule.train_dataloader(),
        datamodule.val_dataloader(),
        datamodule.test_dataloader(),
    )
    print(f"Number of training batches: {len(train_loader)}")
    print(f"Number of validation batches: {len(val_loader)}")
    print(f"Number of test batches: {len(test_loader)}")
    first_batch = next(iter(train_loader))
    print(first_batch)
|
src/eval.py
ADDED
|
@@ -0,0 +1,99 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Tuple
|
| 2 |
+
|
| 3 |
+
import hydra
|
| 4 |
+
import rootutils
|
| 5 |
+
from lightning import LightningDataModule, LightningModule, Trainer
|
| 6 |
+
from lightning.pytorch.loggers import Logger
|
| 7 |
+
from omegaconf import DictConfig
|
| 8 |
+
|
| 9 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 10 |
+
# ------------------------------------------------------------------------------------ #
|
| 11 |
+
# the setup_root above is equivalent to:
|
| 12 |
+
# - adding project root dir to PYTHONPATH
|
| 13 |
+
# (so you don't need to force user to install project as a package)
|
| 14 |
+
# (necessary before importing any local modules e.g. `from src import utils`)
|
| 15 |
+
# - setting up PROJECT_ROOT environment variable
|
| 16 |
+
# (which is used as a base for paths in "configs/paths/default.yaml")
|
| 17 |
+
# (this way all filepaths are the same no matter where you run the code)
|
| 18 |
+
# - loading environment variables from ".env" in root dir
|
| 19 |
+
#
|
| 20 |
+
# you can remove it if you:
|
| 21 |
+
# 1. either install project as a package or move entry files to project root dir
|
| 22 |
+
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
| 23 |
+
#
|
| 24 |
+
# more info: https://github.com/ashleve/rootutils
|
| 25 |
+
# ------------------------------------------------------------------------------------ #
|
| 26 |
+
|
| 27 |
+
from src.utils import (
|
| 28 |
+
RankedLogger,
|
| 29 |
+
extras,
|
| 30 |
+
instantiate_loggers,
|
| 31 |
+
log_hyperparameters,
|
| 32 |
+
task_wrapper,
|
| 33 |
+
)
|
| 34 |
+
|
| 35 |
+
log = RankedLogger(__name__, rank_zero_only=True)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@task_wrapper
def evaluate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Run the test loop for a checkpoint using the configured datamodule.

    The datamodule, model, loggers and trainer are instantiated from ``cfg``
    via Hydra, hyperparameters are logged when at least one logger exists, and
    ``trainer.test`` is run with ``cfg.ckpt_path``.

    The function is wrapped in the optional @task_wrapper decorator, which
    controls the behavior during failure (useful for multiruns, saving info
    about the crash, etc.).

    :param cfg: DictConfig configuration composed by Hydra; ``ckpt_path`` must be set.
    :return: Tuple of the trainer's callback metrics and a dict with all instantiated objects.
    """
    assert cfg.ckpt_path

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating loggers...")
    loggers: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, logger=loggers)

    # Everything that was instantiated, handed back to the caller.
    object_dict = dict(
        cfg=cfg,
        datamodule=datamodule,
        model=model,
        logger=loggers,
        trainer=trainer,
    )

    if loggers:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    log.info("Starting testing!")
    trainer.test(model=model, datamodule=datamodule, ckpt_path=cfg.ckpt_path)

    # for predictions, trainer.predict(...) can be used instead of trainer.test(...)

    return trainer.callback_metrics, object_dict
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
@hydra.main(version_base="1.3", config_path="../configs", config_name="eval.yaml")
def main(cfg: DictConfig) -> None:
    """Hydra entry point: apply the extra utilities, then run the evaluation.

    :param cfg: DictConfig configuration composed by Hydra.
    """
    # extras() applies optional utilities
    # (e.g. ask for tags if none are provided in cfg, print cfg tree, etc.)
    extras(cfg)

    evaluate(cfg)


if __name__ == "__main__":
    # Hydra supplies ``cfg`` when the script is executed directly.
    main()
|
src/models/__init__.py
ADDED
|
File without changes
|
src/models/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (145 Bytes). View file
|
|
|
src/models/__pycache__/pinder_module.cpython-310.pyc
ADDED
|
Binary file (8.44 kB). View file
|
|
|
src/models/components/__init__.py
ADDED
|
File without changes
|
src/models/components/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (156 Bytes). View file
|
|
|
src/models/components/__pycache__/equivariant_mpnn.cpython-310.pyc
ADDED
|
Binary file (6.84 kB). View file
|
|
|
src/models/components/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (2.74 kB). View file
|
|
|
src/models/components/equivariant_mpnn.py
ADDED
|
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import rootutils
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import BatchNorm1d, Linear, Module, ReLU, Sequential
|
| 5 |
+
from torch_geometric.loader import DataLoader
|
| 6 |
+
from torch_geometric.nn import MessagePassing
|
| 7 |
+
from torch_scatter import scatter
|
| 8 |
+
|
| 9 |
+
# setup root dir and pythonpath
|
| 10 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 11 |
+
|
| 12 |
+
from src.data.components.pinder_dataset import PinderDataset
|
| 13 |
+
from src.models.components.utils import (
|
| 14 |
+
compute_euler_angles_from_rotation_matrices,
|
| 15 |
+
compute_rotation_matrix_from_ortho6d,
|
| 16 |
+
)
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class EquivariantMPNNLayer(MessagePassing):
|
| 20 |
+
def __init__(self, emb_dim=64, out_dim=128, aggr="add"):
|
| 21 |
+
r"""Message Passing Neural Network Layer
|
| 22 |
+
|
| 23 |
+
This layer is equivariant to 3D rotations and translations.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
emb_dim: (int) - hidden dimension d
|
| 27 |
+
edge_dim: (int) - edge feature dimension d_e
|
| 28 |
+
aggr: (str) - aggregation function \oplus (sum/mean/max)
|
| 29 |
+
"""
|
| 30 |
+
# Set the aggregation function
|
| 31 |
+
super().__init__(aggr=aggr)
|
| 32 |
+
|
| 33 |
+
self.emb_dim = emb_dim
|
| 34 |
+
|
| 35 |
+
#
|
| 36 |
+
self.mlp_msg = Sequential(
|
| 37 |
+
Linear(2 * emb_dim + 1, emb_dim),
|
| 38 |
+
BatchNorm1d(emb_dim),
|
| 39 |
+
ReLU(),
|
| 40 |
+
Linear(emb_dim, emb_dim),
|
| 41 |
+
BatchNorm1d(emb_dim),
|
| 42 |
+
ReLU(),
|
| 43 |
+
)
|
| 44 |
+
|
| 45 |
+
self.mlp_pos = Sequential(
|
| 46 |
+
Linear(emb_dim, emb_dim), BatchNorm1d(emb_dim), ReLU(), Linear(emb_dim, 1)
|
| 47 |
+
) # MLP \psi
|
| 48 |
+
self.mlp_upd = Sequential(
|
| 49 |
+
Linear(2 * emb_dim, emb_dim),
|
| 50 |
+
BatchNorm1d(emb_dim),
|
| 51 |
+
ReLU(),
|
| 52 |
+
Linear(emb_dim, emb_dim),
|
| 53 |
+
BatchNorm1d(emb_dim),
|
| 54 |
+
ReLU(),
|
| 55 |
+
) # MLP \phi
|
| 56 |
+
# ===========================================
|
| 57 |
+
|
| 58 |
+
self.lin_out = Linear(emb_dim, out_dim)
|
| 59 |
+
|
| 60 |
+
def forward(self, data):
|
| 61 |
+
"""
|
| 62 |
+
The forward pass updates node features h via one round of message passing.
|
| 63 |
+
|
| 64 |
+
Args:
|
| 65 |
+
h: (n, d) - initial node features
|
| 66 |
+
pos: (n, 3) - initial node coordinates
|
| 67 |
+
edge_index: (e, 2) - pairs of edges (i, j)
|
| 68 |
+
edge_attr: (e, d_e) - edge features
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
out: [(n, d),(n,3)] - updated node features
|
| 72 |
+
"""
|
| 73 |
+
|
| 74 |
+
#
|
| 75 |
+
h, pos, edge_index = data
|
| 76 |
+
h_out, pos_out = self.propagate(edge_index=edge_index, h=h, pos=pos)
|
| 77 |
+
h_out = self.lin_out(h_out)
|
| 78 |
+
return h_out, pos_out, edge_index
|
| 79 |
+
# ==========================================
|
| 80 |
+
|
| 81 |
+
#
|
| 82 |
+
def message(self, h_i, h_j, pos_i, pos_j):
    """Build the per-edge feature message and the per-edge coordinate shift.

    Args:
        h_i, h_j: features of the two endpoint nodes of every edge.
        pos_i, pos_j: coordinates of the two endpoint nodes of every edge.
            (Presumably ``_i`` is the receiving node and ``_j`` the sending
            node, as in PyG -- TODO confirm.)

    Returns:
        Tuple ``(msg, shift)``: ``msg`` is the output of ``mlp_msg`` applied
        to ``[h_i, h_j, distance]``; ``shift`` is the relative vector
        ``pos_i - pos_j`` scaled by the output of ``mlp_pos(msg)``.
    """
    rel_vec = pos_i - pos_j
    # Euclidean edge length, kept as a column so it can be concatenated.
    edge_len = torch.norm(rel_vec, dim=-1)[:, None]

    msg = self.mlp_msg(torch.cat([h_i, h_j, edge_len], dim=-1))
    # Scale the relative vector by the learned per-edge weight.
    shift = rel_vec * self.mlp_pos(msg)
    return msg, shift
|
| 95 |
+
|
| 96 |
+
# ...
|
| 97 |
+
#
|
| 98 |
+
def aggregate(self, inputs, index):
    """Pool per-edge messages and coordinate shifts onto nodes.

    Args:
        inputs: tuple ``(msgs, pos_diffs)`` as returned by ``message``.
        index: node index of every edge, used as the scatter index
            (presumably a 1-D tensor of length e -- TODO confirm).

    Returns:
        Tuple ``(msg_aggr, pos_aggr)``: messages reduced with the layer's
        configured ``aggr`` mode, and coordinate shifts reduced with a mean
        regardless of that mode.
    """
    edge_msgs, edge_shifts = inputs
    axis = self.node_dim
    return (
        scatter(edge_msgs, index, dim=axis, reduce=self.aggr),
        # Shifts are always averaged, independent of the configured aggregation.
        scatter(edge_shifts, index, dim=axis, reduce="mean"),
    )
|
| 116 |
+
|
| 117 |
+
def update(self, aggr_out, h, pos):
    """Combine the pooled messages with the current node state.

    Args:
        aggr_out: tuple ``(msg_aggr, pos_aggr)`` as returned by ``aggregate``.
        h: current node features.
        pos: current node coordinates.

    Returns:
        Tuple ``(upd_h, upd_pos)``: ``mlp_upd`` applied to ``[h, msg_aggr]``
        and the coordinates moved by the pooled shift.
    """
    pooled_msgs, pooled_shifts = aggr_out
    stacked = torch.cat((h, pooled_msgs), dim=-1)
    return self.mlp_upd(stacked), pos + pooled_shifts
|
| 125 |
+
|
| 126 |
+
def __repr__(self) -> str:
    """Return ``ClassName(emb_dim=..., aggr=...)`` for quick inspection."""
    cls_name = type(self).__name__
    return f"{cls_name}(emb_dim={self.emb_dim}, aggr={self.aggr})"
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class PinderMPNNModel(Module):
    """Two-tower MPNN with receptor/ligand cross-attention.

    Receptor and ligand graphs are each embedded by their own stack of
    ``EquivariantMPNNLayer`` blocks, exchange information through one
    cross-attention module per side, and are finally mapped to a 3-vector
    per node.
    """

    def __init__(self, input_dim=1, emb_dim=64, num_heads=4):
        """Build both towers, the cross-attention modules and the output heads.

        Args:
            input_dim: (int) - initial node feature dimension d_n
            emb_dim: (int) - width of the input projection feeding the first
                MPNN layer of each tower
            num_heads: (int) - number of attention heads. It must divide the
                256-wide tower output, otherwise ``nn.MultiheadAttention``
                refuses to build. The former default of 5 did not, so the
                default is now 4.
        """
        super().__init__()

        # Linear projection for initial node features
        self.lin_in_rec = Linear(input_dim, emb_dim)
        self.lin_in_lig = Linear(input_dim, emb_dim)

        # Stack of MPNN layers (emb_dim -> 128 -> 256), one stack per molecule
        self.receptor_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
        )
        # The ligand tower follows emb_dim as well (it used to hard-code 64,
        # which broke any non-default emb_dim).
        self.ligand_mpnn = Sequential(
            EquivariantMPNNLayer(emb_dim, 128, aggr="mean"),
            EquivariantMPNNLayer(128, 256, aggr="mean"),
        )

        # Cross-attention layers: each side queries the other side's nodes
        self.rec_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)
        self.lig_cross_attention = nn.MultiheadAttention(256, num_heads, batch_first=True)

        # Linear heads for translation prediction (attention features + xyz)
        self.fc_translation_rec = nn.Linear(256 + 3, 3)
        self.fc_translation_lig = nn.Linear(256 + 3, 3)

    def forward(self, batch):
        """Predict one 3-vector per receptor node and per ligand node.

        Args:
            batch: mapping-like graph batch exposing ``batch["receptor"]`` and
                ``batch["ligand"]`` (each with ``.x`` and ``.pos``) plus
                ``batch["receptor", "receptor"].edge_index`` and
                ``batch["ligand", "ligand"].edge_index``.
                (Presumably a PyG ``HeteroData`` batch -- TODO confirm.)

        Returns:
            Tuple ``(receptor_coords, ligand_coords)``. Each is the sum of
            (a) Euler angles in degrees derived from the first six attention
            features of every node and (b) the translation head's output.
        """
        h_receptor = self.lin_in_rec(batch["receptor"].x)
        h_ligand = self.lin_in_lig(batch["ligand"].x)

        pos_receptor = batch["receptor"].pos
        pos_ligand = batch["ligand"].pos

        h_receptor, pos_receptor, _ = self.receptor_mpnn(
            (h_receptor, pos_receptor, batch["receptor", "receptor"].edge_index)
        )
        h_ligand, pos_ligand, _ = self.ligand_mpnn(
            (h_ligand, pos_ligand, batch["ligand", "ligand"].edge_index)
        )

        # Each side attends over the other side's node embeddings.
        # NOTE(review): node tensors are passed without a batch dimension, so
        # attention presumably spans all nodes in the batch -- confirm this is
        # intended when several complexes are batched together.
        attn_output_rec, _ = self.rec_cross_attention(h_receptor, h_ligand, h_ligand)
        attn_output_lig, _ = self.lig_cross_attention(h_ligand, h_receptor, h_receptor)

        emb_features_receptor = torch.cat((attn_output_rec, pos_receptor), dim=-1)
        emb_features_ligand = torch.cat((attn_output_lig, pos_ligand), dim=-1)

        translation_vector_r = self.fc_translation_rec(emb_features_receptor)
        translation_vector_l = self.fc_translation_lig(emb_features_ligand)

        # The helper reads only the first six feature columns of each node.
        rot_receptor = compute_rotation_matrix_from_ortho6d(attn_output_rec)
        rot_ligand = compute_rotation_matrix_from_ortho6d(attn_output_lig)

        # Radians -> degrees
        receptor_coords = (
            compute_euler_angles_from_rotation_matrices(rot_receptor) * 180 / torch.pi
        )
        ligand_coords = compute_euler_angles_from_rotation_matrices(rot_ligand) * 180 / torch.pi

        receptor_coords = receptor_coords + translation_vector_r
        ligand_coords = ligand_coords + translation_vector_l

        return receptor_coords, ligand_coords
|
| 220 |
+
|
| 221 |
+
|
| 222 |
+
if __name__ == "__main__":
    # Smoke test: load the same processed sample three times, push one batch
    # through a freshly initialised model and print the output shapes.
    sample_path = "./data/processed/apo/test/1a19__A1_P11540--1a19__B1_P11540.pt"
    loader = DataLoader(
        PinderDataset(file_paths=[sample_path] * 3), batch_size=3, shuffle=False
    )
    batch = next(iter(loader))

    model = PinderMPNNModel()
    n_params = sum(p.numel() for p in model.parameters())
    print("Number of parameters:", n_params)

    receptor_coords, ligand_coords = model(batch)
    print(receptor_coords.shape)
    print(ligand_coords.shape)
|
src/models/components/utils.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
# batch*n
|
| 5 |
+
def normalize_vector(v):
    """Scale every row of a 2-D tensor to unit Euclidean length.

    Row norms are floored at 1e-8 before dividing, so all-zero rows come back
    as zeros instead of NaN.

    Args:
        v: tensor of shape (batch, n).

    Returns:
        Tensor of the same shape with each row divided by its (floored) norm.
    """
    n_rows, n_cols = v.shape[0], v.shape[1]
    floor = torch.tensor(1e-8, device=v.device)
    norms = torch.max(v.pow(2).sum(1).sqrt(), floor)  # (batch,)
    return v / norms.view(n_rows, 1).expand(n_rows, n_cols)
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
# u, v batch*n
|
| 16 |
+
def cross_product(u, v):
    """Row-wise 3-D cross product ``u x v``.

    Args:
        u, v: tensors of shape (batch, 3); only columns 0..2 are read.

    Returns:
        Tensor of shape (batch, 3).
    """
    n_rows = u.shape[0]
    ux, uy, uz = u[:, 0], u[:, 1], u[:, 2]
    vx, vy, vz = v[:, 0], v[:, 1], v[:, 2]

    components = (
        uy * vz - uz * vy,
        uz * vx - ux * vz,
        ux * vy - uy * vx,
    )
    # Turn each component into a column and join them side by side.
    return torch.cat([c.view(n_rows, 1) for c in components], 1)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
# poses batch*6
|
| 30 |
+
# poses
|
| 31 |
+
def compute_rotation_matrix_from_ortho6d(poses):
    """Turn a 6-D pose representation into 3x3 matrices with orthonormal columns.

    The first three columns of ``poses`` give the raw x axis and the next
    three the raw y axis; any further columns are ignored. The x axis is
    normalised, z is the normalised cross product of x with the raw y, and
    y is rebuilt as ``z x x``.

    Args:
        poses: tensor of shape (batch, >=6).

    Returns:
        Tensor of shape (batch, 3, 3) whose columns are the x, y and z axes.
    """
    raw_x, raw_y = poses[:, 0:3], poses[:, 3:6]

    axis_x = normalize_vector(raw_x)
    axis_z = normalize_vector(cross_product(axis_x, raw_y))
    axis_y = cross_product(axis_z, axis_x)

    # Stack the three axes as matrix columns.
    return torch.cat([axis.view(-1, 3, 1) for axis in (axis_x, axis_y, axis_z)], 2)
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# input batch*4*4 or batch*3*3
|
| 48 |
+
# output torch batch*3 x, y, z in radiant
|
| 49 |
+
# the rotation is in the sequence of x,y,z
|
| 50 |
+
def compute_euler_angles_from_rotation_matrices(rotation_matrices):
    """Extract x, y, z Euler angles (radians) from a batch of rotation matrices.

    Only the upper-left 3x3 entries are read. When
    ``sqrt(R00^2 + R10^2) < 1e-6`` the matrix is treated as singular: x is
    recovered from the middle row/column instead and z is set to zero.

    The result keeps the dtype and device of the computed angles. Previously
    it was always written into a default-dtype buffer, which silently
    downcast float64 input.

    Args:
        rotation_matrices: tensor of shape (batch, 3, 3) or (batch, 4, 4).

    Returns:
        Tensor of shape (batch, 3) holding the (x, y, z) angles.
    """
    R = rotation_matrices
    sy = torch.sqrt(R[:, 0, 0] * R[:, 0, 0] + R[:, 1, 0] * R[:, 1, 0])
    singular = sy < 1e-6

    # Regular branch
    x_regular = torch.atan2(R[:, 2, 1], R[:, 2, 2])
    z_regular = torch.atan2(R[:, 1, 0], R[:, 0, 0])
    # Singular branch (z is fixed to zero there)
    x_singular = torch.atan2(-R[:, 1, 2], R[:, 1, 1])

    x = torch.where(singular, x_singular, x_regular)
    # The y angle has the same formula in both branches.
    y = torch.atan2(-R[:, 2, 0], sy)
    z = torch.where(singular, torch.zeros_like(z_regular), z_regular)

    return torch.stack((x, y, z), dim=1)
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def get_R(x, y, z):
    """Get rotation matrix from three rotation angles (radians). right-handed.

    The matrices are assembled with ``torch.stack`` rather than
    ``torch.tensor([...])``: the latter copies the scalar values out of the
    angle tensors and therefore cuts the result off from autograd.

    Args:
        x: rotation angle around x-axis (single-element tensor)
        y: rotation angle around y-axis (single-element tensor)
        z: rotation angle around z-axis (single-element tensor)
    Returns:
        R: [3, 3]. rotation matrix, composed as ``Rz @ Ry @ Rx``.
    """

    def _as_matrix(entries):
        # Nine single-element tensors (row-major) -> one 3x3 matrix.
        return torch.stack(entries).reshape(3, 3)

    # x
    cx, sx = torch.cos(x), torch.sin(x)
    one, zero = torch.ones_like(cx), torch.zeros_like(cx)
    Rx = _as_matrix([one, zero, zero, zero, cx, -sx, zero, sx, cx])
    # y
    cy, sy = torch.cos(y), torch.sin(y)
    one, zero = torch.ones_like(cy), torch.zeros_like(cy)
    Ry = _as_matrix([cy, zero, sy, zero, one, zero, -sy, zero, cy])
    # z
    cz, sz = torch.cos(z), torch.sin(z)
    one, zero = torch.ones_like(cz), torch.zeros_like(cz)
    Rz = _as_matrix([cz, -sz, zero, sz, cz, zero, zero, zero, one])

    R = torch.mm(Rz, torch.mm(Ry, Rx))
    return R
|
src/models/pinder_module.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from lightning import LightningModule
|
| 5 |
+
from torchmetrics import MeanMetric, MinMetric
|
| 6 |
+
from torchmetrics.regression import MeanAbsoluteError, MeanSquaredError
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class PinderLitModule(LightningModule):
    """`LightningModule` wrapping a receptor/ligand regression network.

    The wrapped ``net`` returns one prediction tensor for the receptor and
    one for the ligand. Both are compared against ``batch["receptor"].y`` and
    ``batch["ligand"].y`` with an MSE loss, and MSE / MAE metrics are tracked
    separately per molecule for the train, val and test stages.

    Docs:
        https://lightning.ai/docs/pytorch/latest/common/lightning_module.html
    """

    def __init__(
        self,
        net: torch.nn.Module,
        optimizer: torch.optim.Optimizer,
        scheduler: torch.optim.lr_scheduler,
        compile: bool,
    ) -> None:
        """Initialize a `PinderLitModule`.

        :param net: The model to train.
        :param optimizer: The optimizer to use for training (called with ``params=...`` in
            `configure_optimizers`, so it is expected to be a factory/partial).
        :param scheduler: The learning rate scheduler to use for training (called with
            ``optimizer=...``), or `None` for no scheduler.
        :param compile: Whether to wrap ``net`` with `torch.compile` when fitting.
        """
        super().__init__()

        # this line allows to access init params with 'self.hparams' attribute
        # also ensures init params will be stored in ckpt
        self.save_hyperparameters(logger=False)

        self.net = net

        # loss function
        self.criterion = torch.nn.MSELoss()

        # metric objects for calculating and averaging errors across batches
        self.train_mse_ligand = MeanSquaredError()
        self.val_mse_ligand = MeanSquaredError()
        self.test_mse_ligand = MeanSquaredError()

        self.train_mse_receptor = MeanSquaredError()
        self.val_mse_receptor = MeanSquaredError()
        self.test_mse_receptor = MeanSquaredError()

        self.train_mae_receptor = MeanAbsoluteError()
        self.val_mae_receptor = MeanAbsoluteError()
        self.test_mae_receptor = MeanAbsoluteError()

        self.train_mae_ligand = MeanAbsoluteError()
        self.val_mae_ligand = MeanAbsoluteError()
        self.test_mae_ligand = MeanAbsoluteError()

        # for averaging loss across batches
        self.train_loss = MeanMetric()
        self.val_loss = MeanMetric()
        self.test_loss = MeanMetric()

        # for tracking best so far validation mse
        self.val_mse_best = MinMetric()

    def forward(self, x: Any) -> Tuple[torch.Tensor, torch.Tensor]:
        """Perform a forward pass through the model `self.net`.

        :param x: A batch in whatever form `self.net` accepts.
        :return: The receptor and ligand predictions produced by `self.net`.
        """
        return self.net(x)

    def on_train_start(self) -> None:
        """Lightning hook that is called when training begins."""
        # by default lightning executes validation step sanity checks before training starts,
        # so it's worth to make sure validation metrics don't store results from these checks
        self.val_loss.reset()
        self.val_mse_ligand.reset()
        self.val_mse_receptor.reset()
        self.val_mae_receptor.reset()
        self.val_mae_ligand.reset()
        self.val_mse_best.reset()

    def model_step(
        self, batch: Any
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Perform a single model step on a batch of data.

        :param batch: A batch exposing ``batch["receptor"].y`` and ``batch["ligand"].y`` as
            regression targets.

        :return: A tuple containing (in order):
            - The summed receptor + ligand MSE loss.
            - The receptor predictions.
            - The ligand predictions.
            - The receptor targets.
            - The ligand targets.
        """
        receptor_coords, ligand_coords = self.forward(batch)
        loss_receptor = self.criterion(receptor_coords, batch["receptor"].y)
        loss_ligand = self.criterion(ligand_coords, batch["ligand"].y)
        loss = loss_receptor + loss_ligand
        return loss, receptor_coords, ligand_coords, batch["receptor"].y, batch["ligand"].y

    def _update_and_log(
        self,
        stage: str,
        step_out: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor],
        *,
        on_step: bool,
        on_epoch: bool,
    ) -> None:
        """Update and log the loss, MSE and MAE metrics of one stage.

        :param stage: Metric prefix, one of ``"train"``, ``"val"`` or ``"test"``.
        :param step_out: The 5-tuple returned by `model_step`.
        :param on_step: Forwarded to `self.log`.
        :param on_epoch: Forwarded to `self.log`.
        """
        loss, receptor_coords, ligand_coords, receptor_targets, ligand_targets = step_out
        pairs = {
            "ligand": (ligand_coords, ligand_targets),
            "receptor": (receptor_coords, receptor_targets),
        }

        # Update every metric first, then log them in the same order:
        # loss, mse_ligand, mse_receptor, mae_ligand, mae_receptor.
        loss_metric = getattr(self, f"{stage}_loss")
        loss_metric(loss)
        tracked = [(f"{stage}/loss", loss_metric)]
        for kind in ("mse", "mae"):
            for part, (preds, targets) in pairs.items():
                metric = getattr(self, f"{stage}_{kind}_{part}")
                metric(preds, targets)
                tracked.append((f"{stage}/{kind}_{part}", metric))

        for name, metric in tracked:
            self.log(name, metric, on_step=on_step, on_epoch=on_epoch, prog_bar=True)

    def training_step(self, batch: Any, batch_idx: int) -> torch.Tensor:
        """Perform a single training step on a batch of data from the training set.

        :param batch: A batch of data as accepted by `model_step`.
        :param batch_idx: The index of the current batch.
        :return: A tensor of losses between model predictions and targets.
        """
        step_out = self.model_step(batch)
        self._update_and_log("train", step_out, on_step=True, on_epoch=False)

        # return loss or backpropagation will fail
        return step_out[0]

    def on_train_epoch_end(self) -> None:
        "Lightning hook that is called when a training epoch ends."
        pass

    def validation_step(self, batch: Any, batch_idx: int) -> None:
        """Perform a single validation step on a batch of data from the validation set.

        :param batch: A batch of data as accepted by `model_step`.
        :param batch_idx: The index of the current batch.
        """
        self._update_and_log("val", self.model_step(batch), on_step=False, on_epoch=True)

    def on_validation_epoch_end(self) -> None:
        "Lightning hook that is called when a validation epoch ends."
        current_mse = self.val_mse_ligand.compute()  # current validation ligand MSE
        self.val_mse_best(current_mse)  # update best (lowest) so far
        # log the best value through `.compute()` instead of as a metric object,
        # otherwise the metric would be reset by lightning after each epoch.
        # NOTE: the key keeps its historical "acc_best" name although it holds the best
        # ligand MSE; renaming it would break anything configured to read that key.
        self.log("val/acc_best", self.val_mse_best.compute(), sync_dist=True, prog_bar=True)

    def test_step(self, batch: Any, batch_idx: int) -> None:
        """Perform a single test step on a batch of data from the test set.

        :param batch: A batch of data as accepted by `model_step`.
        :param batch_idx: The index of the current batch.
        """
        self._update_and_log("test", self.model_step(batch), on_step=False, on_epoch=True)

    def on_test_epoch_end(self) -> None:
        """Lightning hook that is called when a test epoch ends."""
        pass

    def setup(self, stage: str) -> None:
        """Lightning hook that is called at the beginning of fit (train + validate), validate,
        test, or predict.

        This is a good hook when you need to build models dynamically or adjust something about
        them. This hook is called on every process when using DDP.

        :param stage: Either `"fit"`, `"validate"`, `"test"`, or `"predict"`.
        """
        if self.hparams.compile and stage == "fit":
            self.net = torch.compile(self.net)

    def configure_optimizers(self) -> Dict[str, Any]:
        """Choose what optimizers and learning-rate schedulers to use in your optimization.
        Normally you'd need one. But in the case of GANs or similar you might have multiple.

        Examples:
            https://lightning.ai/docs/pytorch/latest/common/lightning_module.html#configure-optimizers

        :return: A dict containing the configured optimizers and learning-rate schedulers to be used for training.
        """
        optimizer = self.hparams.optimizer(params=self.trainer.model.parameters())
        if self.hparams.scheduler is not None:
            scheduler = self.hparams.scheduler(optimizer=optimizer)
            return {
                "optimizer": optimizer,
                "lr_scheduler": {
                    "scheduler": scheduler,
                    "monitor": "val/loss",
                    "interval": "epoch",
                    "frequency": 1,
                },
            }
        return {"optimizer": optimizer}
|
| 294 |
+
|
| 295 |
+
|
| 296 |
+
if __name__ == "__main__":
    # Construction smoke test: the module must build even with placeholder arguments.
    _ = PinderLitModule(net=None, optimizer=None, scheduler=None, compile=None)
|
src/train.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Optional, Tuple
|
| 2 |
+
|
| 3 |
+
import hydra
|
| 4 |
+
import lightning as L
|
| 5 |
+
import rootutils
|
| 6 |
+
import torch
|
| 7 |
+
from lightning import Callback, LightningDataModule, LightningModule, Trainer
|
| 8 |
+
from lightning.pytorch.loggers import Logger
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
|
| 11 |
+
rootutils.setup_root(__file__, indicator=".project-root", pythonpath=True)
|
| 12 |
+
# ------------------------------------------------------------------------------------ #
|
| 13 |
+
# the setup_root above is equivalent to:
|
| 14 |
+
# - adding project root dir to PYTHONPATH
|
| 15 |
+
# (so you don't need to force user to install project as a package)
|
| 16 |
+
# (necessary before importing any local modules e.g. `from src import utils`)
|
| 17 |
+
# - setting up PROJECT_ROOT environment variable
|
| 18 |
+
# (which is used as a base for paths in "configs/paths/default.yaml")
|
| 19 |
+
# (this way all filepaths are the same no matter where you run the code)
|
| 20 |
+
# - loading environment variables from ".env" in root dir
|
| 21 |
+
#
|
| 22 |
+
# you can remove it if you:
|
| 23 |
+
# 1. either install project as a package or move entry files to project root dir
|
| 24 |
+
# 2. set `root_dir` to "." in "configs/paths/default.yaml"
|
| 25 |
+
#
|
| 26 |
+
# more info: https://github.com/ashleve/rootutils
|
| 27 |
+
# ------------------------------------------------------------------------------------ #
|
| 28 |
+
|
| 29 |
+
from src.utils import (
|
| 30 |
+
RankedLogger,
|
| 31 |
+
extras,
|
| 32 |
+
get_metric_value,
|
| 33 |
+
instantiate_callbacks,
|
| 34 |
+
instantiate_loggers,
|
| 35 |
+
log_hyperparameters,
|
| 36 |
+
task_wrapper,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
log = RankedLogger(__name__, rank_zero_only=True)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Trains the model. Can additionally evaluate on a testset, using best weights obtained during
    training.

    This method is wrapped in optional @task_wrapper decorator, that controls the behavior during
    failure. Useful for multiruns, saving info about the crash, etc.

    :param cfg: A DictConfig configuration composed by Hydra.
    :return: A tuple with metrics and dict with all instantiated objects.
    """
    # set seed for random number generators in pytorch, numpy and python.random
    if cfg.get("seed"):
        L.seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    datamodule: LightningDataModule = hydra.utils.instantiate(cfg.data)

    log.info(f"Instantiating model <{cfg.model._target_}>")
    model: LightningModule = hydra.utils.instantiate(cfg.model)

    log.info("Instantiating callbacks...")
    callbacks: List[Callback] = instantiate_callbacks(cfg.get("callbacks"))

    log.info("Instantiating loggers...")
    logger: List[Logger] = instantiate_loggers(cfg.get("logger"))

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer: Trainer = hydra.utils.instantiate(cfg.trainer, callbacks=callbacks, logger=logger)

    object_dict = {
        "cfg": cfg,
        "datamodule": datamodule,
        "model": model,
        "callbacks": callbacks,
        "logger": logger,
        "trainer": trainer,
    }

    if logger:
        log.info("Logging hyperparameters!")
        log_hyperparameters(object_dict)

    if cfg.get("train"):
        log.info("Starting training!")
        trainer.fit(model=model, datamodule=datamodule, ckpt_path=cfg.get("ckpt_path"))

    train_metrics = trainer.callback_metrics

    if cfg.get("test"):
        log.info("Starting testing!")
        # `trainer.checkpoint_callback` is None when no checkpoint callback is configured;
        # treat that the same way as "no best checkpoint recorded" instead of crashing.
        ckpt_path = getattr(trainer.checkpoint_callback, "best_model_path", "")
        if not ckpt_path:
            log.warning("Best ckpt not found! Using current weights for testing...")
            ckpt_path = None
        trainer.test(model=model, datamodule=datamodule, ckpt_path=ckpt_path)
        log.info(f"Best ckpt path: {ckpt_path}")

    test_metrics = trainer.callback_metrics

    # merge train and test metrics
    metric_dict = {**train_metrics, **test_metrics}

    return metric_dict, object_dict
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
@hydra.main(version_base="1.3", config_path="../configs", config_name="train.yaml")
def main(cfg: DictConfig) -> Optional[float]:
    """Hydra entry point: run training and report the metric being optimized.

    :param cfg: DictConfig configuration composed by Hydra.
    :return: Value of the metric named by ``cfg.optimized_metric``, as resolved by
        `get_metric_value`; used for hydra-based hyperparameter optimization.
    """
    # optional pre-run utilities (tag prompts, config tree printing, ...)
    extras(cfg)

    metrics, _objects = train(cfg)

    # look the target metric up through the helper rather than indexing directly
    target_name = cfg.get("optimized_metric")
    return get_metric_value(metric_dict=metrics, metric_name=target_name)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
if __name__ == "__main__":
    # Select the "high" float32 matmul precision mode before any model code runs,
    # then hand control to the Hydra-decorated entry point.
    torch.set_float32_matmul_precision("high")
    main()
|
src/utils/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from src.utils.instantiators import instantiate_callbacks, instantiate_loggers
|
| 2 |
+
from src.utils.logging_utils import log_hyperparameters
|
| 3 |
+
from src.utils.pylogger import RankedLogger
|
| 4 |
+
from src.utils.rich_utils import enforce_tags, print_config_tree
|
| 5 |
+
from src.utils.utils import extras, get_metric_value, task_wrapper
|
src/utils/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (546 Bytes). View file
|
|
|
src/utils/__pycache__/instantiators.cpython-310.pyc
ADDED
|
Binary file (1.57 kB). View file
|
|
|
src/utils/__pycache__/logging_utils.cpython-310.pyc
ADDED
|
Binary file (1.96 kB). View file
|
|
|
src/utils/__pycache__/pylogger.cpython-310.pyc
ADDED
|
Binary file (2.55 kB). View file
|
|
|
src/utils/__pycache__/rich_utils.cpython-310.pyc
ADDED
|
Binary file (3.21 kB). View file
|
|
|
src/utils/__pycache__/utils.cpython-310.pyc
ADDED
|
Binary file (3.69 kB). View file
|
|
|
src/utils/instantiators.py
ADDED
|
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import List
|
| 2 |
+
|
| 3 |
+
import hydra
|
| 4 |
+
from lightning import Callback
|
| 5 |
+
from lightning.pytorch.loggers import Logger
|
| 6 |
+
from omegaconf import DictConfig
|
| 7 |
+
|
| 8 |
+
from src.utils import pylogger
|
| 9 |
+
|
| 10 |
+
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def instantiate_callbacks(callbacks_cfg: DictConfig) -> List[Callback]:
    """Build Lightning callbacks from a config group.

    Entries that are not a ``DictConfig`` or that have no ``_target_`` key are skipped.

    :param callbacks_cfg: A DictConfig object containing callback configurations.
    :return: A list of instantiated callbacks (empty when the config is missing or empty).
    :raises TypeError: If a non-empty config is not a DictConfig.
    """
    if not callbacks_cfg:
        log.warning("No callback configs found! Skipping..")
        return []

    if not isinstance(callbacks_cfg, DictConfig):
        raise TypeError("Callbacks config must be a DictConfig!")

    instantiated: List[Callback] = []
    for _, cb_conf in callbacks_cfg.items():
        # Only entries that declare a `_target_` can be instantiated by Hydra.
        if not (isinstance(cb_conf, DictConfig) and "_target_" in cb_conf):
            continue
        log.info(f"Instantiating callback <{cb_conf._target_}>")
        instantiated.append(hydra.utils.instantiate(cb_conf))

    return instantiated
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def instantiate_loggers(logger_cfg: DictConfig) -> List[Logger]:
    """Build Lightning loggers from a config group.

    Entries that are not a ``DictConfig`` or that have no ``_target_`` key are skipped.

    :param logger_cfg: A DictConfig object containing logger configurations.
    :return: A list of instantiated loggers (empty when the config is missing or empty).
    :raises TypeError: If a non-empty config is not a DictConfig.
    """
    if not logger_cfg:
        log.warning("No logger configs found! Skipping...")
        return []

    if not isinstance(logger_cfg, DictConfig):
        raise TypeError("Logger config must be a DictConfig!")

    instantiated: List[Logger] = []
    for _, lg_conf in logger_cfg.items():
        # Only entries that declare a `_target_` can be instantiated by Hydra.
        if not (isinstance(lg_conf, DictConfig) and "_target_" in lg_conf):
            continue
        log.info(f"Instantiating logger <{lg_conf._target_}>")
        instantiated.append(hydra.utils.instantiate(lg_conf))

    return instantiated
|
src/utils/logging_utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict
|
| 2 |
+
|
| 3 |
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
| 4 |
+
from omegaconf import OmegaConf
|
| 5 |
+
|
| 6 |
+
from src.utils import pylogger
|
| 7 |
+
|
| 8 |
+
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
@rank_zero_only
def log_hyperparameters(object_dict: Dict[str, Any]) -> None:
    """Send a selected subset of the run config to every Lightning logger.

    In addition to the chosen config sections, the total, trainable and
    non-trainable parameter counts of the model are recorded.

    :param object_dict: A dictionary containing the following objects:
        - `"cfg"`: A DictConfig object containing the main config.
        - `"model"`: The Lightning model.
        - `"trainer"`: The Lightning trainer.
    """
    cfg = OmegaConf.to_container(object_dict["cfg"])
    model = object_dict["model"]
    trainer = object_dict["trainer"]

    if not trainer.logger:
        log.warning("Logger not found! Skipping hyperparameter logging...")
        return

    hparams = {"model": cfg["model"]}

    # Walk the parameters once, keeping (element count, requires_grad) pairs.
    param_info = [(p.numel(), p.requires_grad) for p in model.parameters()]
    hparams["model/params/total"] = sum(count for count, _ in param_info)
    hparams["model/params/trainable"] = sum(count for count, grad in param_info if grad)
    hparams["model/params/non_trainable"] = sum(
        count for count, grad in param_info if not grad
    )

    # Mandatory config sections (a missing key raises KeyError).
    for section in ("data", "trainer"):
        hparams[section] = cfg[section]

    # Optional entries are stored as None when absent.
    for key in ("callbacks", "extras", "task_name", "tags", "ckpt_path", "seed"):
        hparams[key] = cfg.get(key)

    # send hparams to all loggers
    for logger in trainer.loggers:
        logger.log_hyperparams(hparams)
|
src/utils/pylogger.py
ADDED
|
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
from typing import Mapping, Optional
|
| 3 |
+
|
| 4 |
+
from lightning_utilities.core.rank_zero import rank_prefixed_message, rank_zero_only
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class RankedLogger(logging.LoggerAdapter):
    """A multi-GPU-friendly python command line logger."""

    def __init__(
        self,
        name: str = __name__,
        rank_zero_only: bool = False,
        extra: Optional[Mapping[str, object]] = None,
    ) -> None:
        """Create an adapter around ``logging.getLogger(name)`` that prefixes each message
        with the rank of the emitting process.

        :param name: The name of the logger. Default is ``__name__``.
        :param rank_zero_only: Whether to force all logs to only occur on the rank zero process. Default is `False`.
        :param extra: (Optional) A dict-like object which provides contextual information. See `logging.LoggerAdapter`.
        """
        super().__init__(logger=logging.getLogger(name), extra=extra)
        self.rank_zero_only = rank_zero_only

    def log(self, level: int, msg: str, rank: Optional[int] = None, *args, **kwargs) -> None:
        """Forward a log call to the wrapped logger with the process rank prefixed to the
        message. When `rank` is given (and the adapter is not rank-zero-only), the record is
        emitted only on that rank.

        NOTE(review): `rank` is the third positional parameter, so positional %-format
        arguments forwarded by `info()`/`warning()` etc. would bind to `rank` first --
        confirm callers only pass pre-formatted messages.

        :param level: The level to log at. Look at `logging.__init__.py` for more information.
        :param msg: The message to log.
        :param rank: The rank to log at.
        :param args: Additional args to pass to the underlying logging function.
        :param kwargs: Any additional keyword args to pass to the underlying logging function.
        :raises RuntimeError: If `rank_zero_only.rank` has not been set.
        """
        if not self.isEnabledFor(level):
            return

        msg, kwargs = self.process(msg, kwargs)

        # The module-level `rank_zero_only` helper carries the current process rank.
        current_rank = getattr(rank_zero_only, "rank", None)
        if current_rank is None:
            raise RuntimeError("The `rank_zero_only.rank` needs to be set before use")
        msg = rank_prefixed_message(msg, current_rank)

        if self.rank_zero_only:
            # Adapter-wide restriction: rank zero only, `rank` is ignored.
            should_emit = current_rank == 0
        else:
            # No rank requested -> every process logs; otherwise only the matching one.
            should_emit = rank is None or current_rank == rank

        if should_emit:
            self.logger.log(level, msg, *args, **kwargs)
|
src/utils/rich_utils.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from pathlib import Path
|
| 2 |
+
from typing import Sequence
|
| 3 |
+
|
| 4 |
+
import rich
|
| 5 |
+
import rich.syntax
|
| 6 |
+
import rich.tree
|
| 7 |
+
from hydra.core.hydra_config import HydraConfig
|
| 8 |
+
from lightning_utilities.core.rank_zero import rank_zero_only
|
| 9 |
+
from omegaconf import DictConfig, OmegaConf, open_dict
|
| 10 |
+
from rich.prompt import Prompt
|
| 11 |
+
|
| 12 |
+
from src.utils import pylogger
|
| 13 |
+
|
| 14 |
+
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
@rank_zero_only
def print_config_tree(
    cfg: DictConfig,
    print_order: Sequence[str] = (
        "data",
        "model",
        "callbacks",
        "logger",
        "trainer",
        "paths",
        "extras",
    ),
    resolve: bool = False,
    save_to_file: bool = False,
) -> None:
    """Render a DictConfig as a Rich tree, one branch per top-level field.

    :param cfg: A DictConfig composed by Hydra.
    :param print_order: Determines in what order config components are printed. Default is ``("data", "model",
        "callbacks", "logger", "trainer", "paths", "extras")``.
    :param resolve: Whether to resolve reference fields of DictConfig. Default is ``False``.
    :param save_to_file: Whether to export config to the hydra output folder. Default is ``False``.
    """
    style = "dim"
    tree = rich.tree.Tree("CONFIG", style=style, guide_style=style)

    # Fields requested in `print_order` come first; absent ones are reported and skipped.
    ordered_fields = []
    for field in print_order:
        if field in cfg:
            ordered_fields.append(field)
        else:
            log.warning(
                f"Field '{field}' not found in config. Skipping '{field}' config printing..."
            )

    # Remaining config fields follow in the config's own order.
    for field in cfg:
        if field not in ordered_fields:
            ordered_fields.append(field)

    # One branch per field; nested configs are dumped as YAML, scalars via str().
    for field in ordered_fields:
        branch = tree.add(field, style=style, guide_style=style)
        config_group = cfg[field]
        branch_content = (
            OmegaConf.to_yaml(config_group, resolve=resolve)
            if isinstance(config_group, DictConfig)
            else str(config_group)
        )
        branch.add(rich.syntax.Syntax(branch_content, "yaml"))

    rich.print(tree)

    if save_to_file:
        target = Path(cfg.paths.output_dir, "config_tree.log")
        with open(target, "w") as file:
            rich.print(tree, file=file)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
@rank_zero_only
def enforce_tags(cfg: DictConfig, save_to_file: bool = False) -> None:
    """Prompts user to input tags from command line if no tags are provided in config.

    The entered text is split on commas; each piece is stripped of surrounding
    whitespace and pieces that end up empty are discarded.

    :param cfg: A DictConfig composed by Hydra.
    :param save_to_file: Whether to export tags to the hydra output folder. Default is ``False``.
    :raises ValueError: If no tags are set and the Hydra job config contains an ``id`` key
        (the error message treats this as a multirun, where prompting is refused).
    """
    if not cfg.get("tags"):
        if "id" in HydraConfig().cfg.hydra.job:
            raise ValueError("Specify tags before launching a multirun!")

        log.warning("No tags provided in config. Prompting user to input tags...")
        raw_tags = Prompt.ask("Enter a list of comma separated tags", default="dev")
        # Strip before filtering so whitespace-only pieces (e.g. the middle of
        # "a, ,b") are dropped instead of being stored as empty-string tags.
        stripped = (piece.strip() for piece in raw_tags.split(","))
        tags = [tag for tag in stripped if tag]

        # `open_dict` temporarily allows adding the new `tags` key to the config.
        with open_dict(cfg):
            cfg.tags = tags

        log.info(f"Tags: {cfg.tags}")

    if save_to_file:
        with open(Path(cfg.paths.output_dir, "tags.log"), "w") as file:
            rich.print(cfg.tags, file=file)
|
src/utils/utils.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import warnings
|
| 2 |
+
from importlib.util import find_spec
|
| 3 |
+
from typing import Any, Callable, Dict, Optional, Tuple
|
| 4 |
+
|
| 5 |
+
from omegaconf import DictConfig
|
| 6 |
+
|
| 7 |
+
from src.utils import pylogger, rich_utils
|
| 8 |
+
|
| 9 |
+
log = pylogger.RankedLogger(__name__, rank_zero_only=True)
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def extras(cfg: DictConfig) -> None:
    """Run the optional pre-task utilities enabled under ``cfg.extras``.

    Utilities:
        - Ignoring python warnings
        - Setting tags from command line
        - Rich config printing

    :param cfg: A DictConfig object containing the config tree.
    """
    # Nothing to do when the `extras` section is missing or empty.
    if not cfg.get("extras"):
        log.warning("Extras config not found! <cfg.extras=null>")
        return

    options = cfg.extras

    # Silence all python warnings for the rest of the process.
    if options.get("ignore_warnings"):
        log.info("Disabling python warnings! <cfg.extras.ignore_warnings=True>")
        warnings.filterwarnings("ignore")

    # Ask for tags on the command line when the config provides none.
    if options.get("enforce_tags"):
        log.info("Enforcing tags! <cfg.extras.enforce_tags=True>")
        rich_utils.enforce_tags(cfg, save_to_file=True)

    # Pretty-print the resolved config tree with Rich.
    if options.get("print_config"):
        log.info("Printing config tree with Rich! <cfg.extras.print_config=True>")
        rich_utils.print_config_tree(cfg, resolve=True, save_to_file=True)
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def task_wrapper(task_func: Callable) -> Callable:
    """Optional decorator that controls the failure behavior when executing the task function.

    This wrapper can be used to:
        - make sure loggers are closed even if the task function raises an exception (prevents multirun failure)
        - save the exception to a `.log` file
        - mark the run as failed with a dedicated file in the `logs/` folder (so we can find and rerun it later)
        - etc. (adjust depending on your needs)

    Example:
    ```
    @utils.task_wrapper
    def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        ...
        return metric_dict, object_dict
    ```

    :param task_func: The task function to be wrapped.

    :return: The wrapped task function. It keeps the name, docstring and other
        metadata of ``task_func``.
    """
    # Function-scope import keeps this block self-contained.
    from functools import wraps

    @wraps(task_func)
    def wrap(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        # execute the task
        try:
            metric_dict, object_dict = task_func(cfg=cfg)

        # things to do if exception occurs
        except Exception:
            # save exception to `.log` file
            log.exception("")

            # some hyperparameter combinations might be invalid or cause out-of-memory errors
            # so when using hparam search plugins like Optuna, you might want to disable
            # raising the below exception to avoid multirun failure
            # (bare `raise` re-raises the active exception with its traceback untouched)
            raise

        # things to always do after either success or exception
        finally:
            # display output dir path in terminal
            log.info(f"Output dir: {cfg.paths.output_dir}")

            # always close wandb run (even if exception occurs so multirun won't fail)
            if find_spec("wandb"):  # check if wandb is installed
                import wandb

                if wandb.run:
                    log.info("Closing wandb!")
                    wandb.finish()

        return metric_dict, object_dict

    return wrap
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def get_metric_value(metric_dict: Dict[str, Any], metric_name: Optional[str]) -> Optional[float]:
    """Safely retrieves value of the metric logged in LightningModule.

    Values exposing an ``item()`` method (e.g. tensors) are unwrapped through it;
    any other value (e.g. a plain Python number) is returned unchanged.

    :param metric_dict: A dict containing metric values.
    :param metric_name: If provided, the name of the metric to retrieve.
    :return: If a metric name was provided, the value of the metric; otherwise ``None``.
    :raises Exception: If ``metric_name`` is given but missing from ``metric_dict``.
    """
    if not metric_name:
        log.info("Metric name is None! Skipping metric value retrieval...")
        return None

    if metric_name not in metric_dict:
        raise Exception(
            f"Metric value not found! <metric_name={metric_name}>\n"
            "Make sure metric name logged in LightningModule is correct!\n"
            "Make sure `optimized_metric` name in `hparams_search` config is correct!"
        )

    raw_value = metric_dict[metric_name]
    # Only call `.item()` when it exists, so plain numeric values no longer
    # fail with AttributeError.
    metric_value = raw_value.item() if hasattr(raw_value, "item") else raw_value
    log.info(f"Retrieved metric value! <{metric_name}={metric_value}>")

    return metric_value
|