Commit 005e4d3 · Parent(s): 777c843
Add source code (clean)
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitignore +9 -0
- MolecularDiffusion/__init__.py +26 -0
- MolecularDiffusion/_version.py +21 -0
- MolecularDiffusion/callbacks/__init__.py +13 -0
- MolecularDiffusion/callbacks/train_helper.py +259 -0
- MolecularDiffusion/cli/__init__.py +6 -0
- MolecularDiffusion/cli/_hydra.py +129 -0
- MolecularDiffusion/cli/analyze.py +380 -0
- MolecularDiffusion/cli/eval_predict.py +259 -0
- MolecularDiffusion/cli/generate.py +282 -0
- MolecularDiffusion/cli/main.py +197 -0
- MolecularDiffusion/cli/predict.py +395 -0
- MolecularDiffusion/cli/train.py +453 -0
- MolecularDiffusion/configs/data/filter_molecules_by_property.py +0 -0
- MolecularDiffusion/configs/data/formed_data.yaml +20 -0
- MolecularDiffusion/configs/data/mol_dataset.yaml +25 -0
- MolecularDiffusion/configs/data/mol_dataset_extraf.yaml +23 -0
- MolecularDiffusion/configs/engine/lightning.yaml +33 -0
- MolecularDiffusion/configs/engine/original.yaml +4 -0
- MolecularDiffusion/configs/hydra/default.yaml +19 -0
- MolecularDiffusion/configs/interference/gen_cfg.yaml +15 -0
- MolecularDiffusion/configs/interference/gen_cfggg.yaml +29 -0
- MolecularDiffusion/configs/interference/gen_conditional.yaml +12 -0
- MolecularDiffusion/configs/interference/gen_gg.yaml +29 -0
- MolecularDiffusion/configs/interference/gen_hybrid.yaml +28 -0
- MolecularDiffusion/configs/interference/gen_inpaint.yaml +69 -0
- MolecularDiffusion/configs/interference/gen_outpaint.yaml +31 -0
- MolecularDiffusion/configs/interference/gen_outpaintft.yaml +18 -0
- MolecularDiffusion/configs/interference/gen_unconditional.yaml +11 -0
- MolecularDiffusion/configs/interference/prediction.yaml +2 -0
- MolecularDiffusion/configs/logger/default.yaml +9 -0
- MolecularDiffusion/configs/logger/wandb.yaml +9 -0
- MolecularDiffusion/configs/models/tabasco_transformer.yaml +72 -0
- MolecularDiffusion/configs/tasks/diffusion.yaml +48 -0
- MolecularDiffusion/configs/tasks/diffusion_egt.yaml +54 -0
- MolecularDiffusion/configs/tasks/diffusion_extraf.yaml +47 -0
- MolecularDiffusion/configs/tasks/diffusion_hybrid.yaml +95 -0
- MolecularDiffusion/configs/tasks/diffusion_hybrid_egcl.yaml +53 -0
- MolecularDiffusion/configs/tasks/diffusion_integer.yaml +62 -0
- MolecularDiffusion/configs/tasks/diffusion_pretrained.yaml +47 -0
- MolecularDiffusion/configs/tasks/diffusion_pyg.yaml +82 -0
- MolecularDiffusion/configs/tasks/diffusion_pyg_egcl.yaml +55 -0
- MolecularDiffusion/configs/tasks/diffusion_pyg_egt.yaml +56 -0
- MolecularDiffusion/configs/tasks/diffusion_tabasco.yaml +66 -0
- MolecularDiffusion/configs/tasks/guidance.yaml +40 -0
- MolecularDiffusion/configs/tasks/guidance_esen.yaml +43 -0
- MolecularDiffusion/configs/tasks/guidance_pc.yaml +43 -0
- MolecularDiffusion/configs/tasks/ldm_dit.yaml +24 -0
- MolecularDiffusion/configs/tasks/regression.yaml +30 -0
- MolecularDiffusion/configs/tasks/regression_esen.yaml +34 -0
.gitignore
ADDED
@@ -0,0 +1,9 @@
__pycache__/
*.pyc
*.pyo
*.pyd
.DS_Store
.env
.venv
env/
venv/
MolecularDiffusion/__init__.py
ADDED
@@ -0,0 +1,26 @@
"""
MolecularDiffusion - A molecular diffusion framework.

This package provides tools and models for molecular diffusion processes.
"""

__version__ = "0.1.0"
__author__ = "Thanapat Worakul"
__email__ = "thanapat.worakul@epfl.ch"

# Import main modules to make them available at package level
from . import core
from . import data
from . import modules
from . import utils
from . import callbacks
from . import runmodes

__all__ = [
    "core",
    "data",
    "modules",
    "utils",
    "callbacks",
    "runmodes"
]
MolecularDiffusion/_version.py
ADDED
@@ -0,0 +1,21 @@
# file generated by setuptools-scm
# don't change, don't track in version control

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple
    from typing import Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

__version__ = version = '0.1.dev26+gff3c644.d20250809'
__version_tuple__ = version_tuple = (0, 1, 'dev26', 'gff3c644.d20250809')
MolecularDiffusion/callbacks/__init__.py
ADDED
@@ -0,0 +1,13 @@
from .train_helper import (
    Queue,
    gradient_clipping,
    EMA,
    SP_regularizer
)

__all__ = [
    "Queue",
    "gradient_clipping",
    "EMA",
    "SP_regularizer"
]
MolecularDiffusion/callbacks/train_helper.py
ADDED
@@ -0,0 +1,259 @@
import numpy as np
import torch
import logging

logger = logging.getLogger(__name__)
logger.setLevel(logging.CRITICAL)


class Queue:
    def __init__(self, max_len=50):
        self.items = []
        self.max_len = max_len

    def __len__(self):
        return len(self.items)

    def add(self, item):
        self.items.insert(0, item)
        if len(self) > self.max_len:
            self.items.pop()

    def mean(self):
        return np.mean(self.items)

    def std(self):
        return np.std(self.items)


class gradient_clipping:
    def __init__(self, m=1, max_len=200):
        self.max_grad_norm = None
        self.max_grad_norms = []
        self.max_len = max_len
        self.m = m
        self.FACTOR = 100

    def __call__(self, model, gradnorm_queue):
        self.max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()
        if len(self.max_grad_norms) == 0:
            self.max_grad_norms.append(self.max_grad_norm)
        else:
            # max_grad_norm_mean = torch.mean(torch.tensor(self.max_grad_norms))
            previous_max_grad_norm = self.max_grad_norms[-1]
            # If the current max_grad_norm exceeds the previous one, cap it
            if self.max_grad_norm > previous_max_grad_norm:
                self.max_grad_norm = previous_max_grad_norm * self.m
                if self.max_grad_norm > previous_max_grad_norm * 1e5:
                    self.max_grad_norm = previous_max_grad_norm * self.m / self.FACTOR

            self.max_grad_norms.append(self.max_grad_norm)

        if len(self.max_grad_norms) > self.max_len:
            self.max_grad_norms.pop(0)

        # Clip the gradient and return the norm
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_grad_norm, norm_type=2.0
        )
        if float(grad_norm) > self.max_grad_norm:
            gradnorm_queue.add(float(self.max_grad_norm))
        else:
            gradnorm_queue.add(float(grad_norm))

        if float(grad_norm) > self.max_grad_norm:
            logger.info(
                f"Clipped gradient with value {grad_norm:.1f} "
                f"while allowed {self.max_grad_norm:.1f}"
            )
        return grad_norm


class gradient_clipping_0:
    def __init__(self, m=1, max_len=200):
        self.max_grad_norm = None
        self.max_grad_norms = []
        self.max_len = max_len
        self.m = m

    def __call__(self, model, gradnorm_queue):
        self.max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()
        if len(self.max_grad_norms) == 0:
            self.max_grad_norms.append(self.max_grad_norm)
        else:
            max_grad_norm_mean = torch.mean(torch.tensor(self.max_grad_norms))
            # If the current max_grad_norm exceeds the mean of the previous ones, cap it
            if self.max_grad_norm > max_grad_norm_mean:
                self.max_grad_norm = max_grad_norm_mean * self.m
                if self.max_grad_norm > max_grad_norm_mean * 1e5:
                    self.max_grad_norm = max_grad_norm_mean * self.m / 10
            self.max_grad_norms.append(self.max_grad_norm)

        if len(self.max_grad_norms) > self.max_len:
            self.max_grad_norms.pop(0)

        # Clip the gradient and return the norm
        grad_norm = torch.nn.utils.clip_grad_norm_(
            model.parameters(), max_norm=self.max_grad_norm, norm_type=2.0
        )
        if float(grad_norm) > self.max_grad_norm:
            gradnorm_queue.add(float(self.max_grad_norm))
        else:
            gradnorm_queue.add(float(grad_norm))

        if float(grad_norm) > self.max_grad_norm:
            print(
                f"Clipped gradient with value {grad_norm:.1f} "
                f"while allowed {self.max_grad_norm:.1f}"
            )
        return grad_norm


class EMA:
    def __init__(self, beta):
        super().__init__()
        self.beta = beta

    def update_model_average(self, ma_model, current_model):
        for current_params, ma_params in zip(
            current_model.parameters(), ma_model.parameters()
        ):
            old_weight, up_weight = ma_params.data, current_params.data
            ma_params.data = self.update_average(old_weight, up_weight)

    def update_average(self, old, new):
        if old is None:
            return new
        return old * self.beta + (1 - self.beta) * new


class SP_regularizer:
    def __init__(
        self,
        regularizer: str,
        lambda_: float = 10,
        lambda_2: float = 100,
        lambda_update_value: float = 50,
        lambda_update_step: int = 2500,
        polynomial_p: float = 1.5,
        warm_up_steps: int = 100,
    ):
        """
        Self-paced regularizer for curriculum learning.

        Args:
            regularizer (str): Regularizer to use. Options are:
                - hard
                - hard_relax
                - linear
                - logaritmic
                - logistic
                - polynomial
            lambda_ (float): Initial lambda value
            lambda_2 (float): Initial lambda value for the second regularizer
            lambda_update_value (float): Value by which lambda is incremented
            lambda_update_step (int): Number of calls between lambda updates
            polynomial_p (float): Value of p for the polynomial regularizer
            warm_up_steps (int): Number of warm-up steps before the regularizer is applied
        """
        self.regularizer = regularizer
        self.lambda_ = lambda_
        self.lambda_2 = lambda_2
        self.n_calls = 1
        self.lambda_update_value = lambda_update_value
        self.lambda_update_step = lambda_update_step
        self.p = polynomial_p
        self.warm_up_steps = warm_up_steps

    def __call__(self, losses: torch.Tensor):
        # TODO: during warm-up steps, keep the loss information to determine lambda
        if self.n_calls < self.warm_up_steps:
            self.n_calls += 1
            return losses
        else:
            if self.regularizer == "hard":
                weighted_loss = self.hard(losses)
            elif self.regularizer == "linear":
                weighted_loss = self.linear(losses)
            elif self.regularizer == "logaritmic":
                weighted_loss = self.logaritmic(losses)
            elif self.regularizer == "logistic":
                weighted_loss = self.logistic(losses)
            elif self.regularizer == "polynomial":
                weighted_loss = self.polynomial(losses)
            elif self.regularizer == "hard_relax":
                weighted_loss = self.hard_relax(losses)
            else:
                raise ValueError("Regularizer not implemented")
            self.n_calls += 1
            self.update_lambda()
            return weighted_loss

    def update_lambda(self):
        if self.n_calls % self.lambda_update_step == 0:
            self.lambda_ += self.lambda_update_value
            self.lambda_2 += self.lambda_update_value

    def hard(self, losses: torch.Tensor):
        weights = (losses <= self.lambda_).float()
        sp_loss = losses * weights
        return sp_loss

    def hard_relax(self, losses: torch.Tensor):
        weights = torch.where(
            losses < self.lambda_,
            torch.ones_like(losses),
            (1 - losses / self.lambda_2) ** (1 / (self.p - 1)),
        )
        idces_zero = torch.where(losses > self.lambda_2)
        weights[idces_zero] = 0
        weights = torch.clamp(weights, 0, 1)
        sp_loss = losses * weights
        return sp_loss

    def linear(self, losses: torch.Tensor):
        weights = torch.where(
            losses > self.lambda_, torch.zeros_like(losses), 1 - losses / self.lambda_
        )
        weights = torch.clamp(weights, 0, 1)
        sp_loss = losses * weights
        return sp_loss

    def logaritmic(self, losses: torch.Tensor):
        weights = torch.where(
            losses > self.lambda_,
            torch.zeros_like(losses),
            torch.log(2 - losses / self.lambda_),
        )
        weights = torch.clamp(weights, 0, 1)
        sp_loss = losses * weights
        return sp_loss

    def logistic(self, losses: torch.Tensor):
        weights = torch.where(
            losses > self.lambda_,
            torch.zeros_like(losses),
            (1 - torch.exp(torch.tensor(self.lambda_)))
            / (1 - torch.exp(losses - self.lambda_)),
        )
        weights = torch.clamp(weights, 0, 1)
        sp_loss = losses * weights
        return sp_loss

    def polynomial(self, losses: torch.Tensor):
        weights = torch.where(
            losses > self.lambda_,
            torch.zeros_like(losses),
            (1 - losses / self.lambda_) ** (1 / (self.p - 1)),
        )
        weights = torch.clamp(weights, 0, 1)
        sp_loss = losses * weights
        return sp_loss
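These helpers are plain Python objects rather than framework callbacks, so a training loop has to wire them together by hand. Below is a minimal sketch of that wiring under stated assumptions: the model, batch, and optimizer are hypothetical placeholders, and only Queue, gradient_clipping, EMA, and SP_regularizer come from this file (the import assumes the package and its dependencies are installed).

import torch
from MolecularDiffusion.callbacks import Queue, gradient_clipping, EMA, SP_regularizer

model = torch.nn.Linear(8, 1)                  # placeholder model
ema_model = torch.nn.Linear(8, 1)              # shadow copy maintained by EMA
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

gradnorm_queue = Queue(max_len=50)
gradnorm_queue.add(3.0)                        # seed: Queue.mean() on an empty queue is NaN
clipper = gradient_clipping(m=1)
ema = EMA(beta=0.999)
sp_reg = SP_regularizer(regularizer="hard", lambda_=10, warm_up_steps=1)

x, y = torch.randn(16, 8), torch.randn(16, 1)  # placeholder batch
per_sample_loss = (model(x) - y).pow(2).squeeze(-1)
loss = sp_reg(per_sample_loss).mean()          # self-paced re-weighting of easy/hard samples
optimizer.zero_grad()
loss.backward()
clipper(model, gradnorm_queue)                 # threshold = 1.5 * mean + 2 * std of recent norms
optimizer.step()
ema.update_model_average(ema_model, model)     # exponential moving average of the weights

Seeding the queue matters because gradient_clipping derives its clipping threshold from the running mean and standard deviation of recently observed gradient norms, which are undefined on an empty queue.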
MolecularDiffusion/cli/__init__.py
ADDED
@@ -0,0 +1,6 @@
# CLI module for MolecularDiffusion.
"""Unified command-line interface for MolecularDiffusion package."""

from MolecularDiffusion.cli.main import cli

__all__ = ["cli"]
MolecularDiffusion/cli/_hydra.py
ADDED
@@ -0,0 +1,129 @@
"""Hydra configuration utilities for CLI.

Provides utilities for discovering and loading bundled configs
while allowing user configs to reference them.
"""

import os
from pathlib import Path
from typing import Optional, List
from importlib import resources


def get_package_config_path() -> Path:
    """Get the absolute path to the bundled config directory.

    Returns:
        Path to the configs directory within the installed package.
    """
    # Use importlib.resources for Python 3.9+
    try:
        pkg_files = resources.files("MolecularDiffusion")
        config_path = pkg_files / "configs"
        # Convert to a real path (handles both installed and editable installs)
        if hasattr(config_path, '_path'):
            # Traversable from importlib.resources
            real_path = Path(str(config_path))
        else:
            real_path = Path(config_path)
        if real_path.is_dir():
            return real_path
    except Exception:
        pass

    # Fallback: relative to this module
    module_dir = Path(__file__).parent.parent
    config_path = module_dir / "configs"
    if config_path.is_dir():
        return config_path

    raise FileNotFoundError(
        "Could not find bundled configs. Ensure package is installed correctly."
    )


def setup_hydra_config(
    config_name: str,
    config_dir: Optional[str] = None,
    overrides: Optional[List[str]] = None,
):
    """Set up the Hydra configuration with proper search paths.

    Configures Hydra to search:
    1. The user's config_dir (if provided) or the current directory
    2. The package's bundled configs (via searchpath)

    Args:
        config_name: Name of the config file (without .yaml extension)
        config_dir: Optional user config directory
        overrides: Optional list of Hydra override strings

    Returns:
        DictConfig from Hydra
    """
    from hydra import compose, initialize_config_dir
    from hydra.core.global_hydra import GlobalHydra

    # Get the package config path for defaults
    pkg_config_path = get_package_config_path()

    # Determine the primary config directory.
    # If config_name contains a path (e.g., "configs/train.yaml"), extract the directory.
    config_name_path = Path(config_name)
    if config_name_path.parent != Path("."):
        # Config name includes a directory; use it as config_dir
        if config_dir is None:
            config_dir = str(config_name_path.parent)
        config_name = config_name_path.name

    if config_dir:
        primary_config_dir = os.path.abspath(config_dir)
    else:
        primary_config_dir = os.getcwd()

    # Clear any existing Hydra state
    GlobalHydra.instance().clear()

    # Initialize with the primary config directory
    initialize_config_dir(
        config_dir=primary_config_dir,
        version_base="1.3",
    )

    # Build overrides to include a searchpath for bundled configs
    all_overrides = overrides or []

    # Add the package config path to the searchpath using the file:// protocol.
    # This allows Hydra to find bundled defaults like data/mol_dataset.yaml.
    searchpath_override = f"hydra.searchpath=[file://{pkg_config_path}]"
    all_overrides = [searchpath_override] + all_overrides

    # Handle the config name (strip .yaml if present)
    if config_name.endswith(".yaml"):
        config_name = config_name[:-5]

    # Compose the configuration
    cfg = compose(config_name=config_name, overrides=all_overrides)

    return cfg


def run_hydra_app(
    config_name: str,
    task_function,
    config_dir: Optional[str] = None,
    overrides: Optional[List[str]] = None,
):
    """Run a Hydra-based task function with proper config setup.

    This is the main entry point for CLI commands that use Hydra configs.

    Args:
        config_name: Name of the config file
        task_function: Function to call with the composed config
        config_dir: Optional user config directory
        overrides: Optional Hydra overrides
    """
    cfg = setup_hydra_config(config_name, config_dir, overrides)
    return task_function(cfg)
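To make the search-path mechanics concrete, here is a rough usage sketch. The train.yaml config, the configs/ directory, the seed key, and the run_training function are hypothetical stand-ins, not names shipped in this commit; only run_hydra_app comes from the file above.

from omegaconf import DictConfig

from MolecularDiffusion.cli._hydra import run_hydra_app


def run_training(cfg: DictConfig):
    # Hypothetical task function: receives the composed config, with any
    # bundled defaults resolved through the injected hydra.searchpath.
    print(list(cfg.keys()))


# Hypothetical layout: looks for ./configs/train.yaml, while defaults it
# references (e.g. data/mol_dataset.yaml) can come from the bundled configs.
run_hydra_app("train.yaml", run_training, config_dir="configs", overrides=["seed=42"])

The point of the hydra.searchpath override is that a user's own config can declare defaults that Hydra then resolves from the bundled MolecularDiffusion/configs directory.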
MolecularDiffusion/cli/analyze.py
ADDED
@@ -0,0 +1,380 @@
"""Analyze CLI subcommands for 3D molecule analysis.

Provides subcommands for:
- optimize: XTB geometry optimization
- metrics: Validity/connectivity metrics
- compare: RMSD, energy, and optional bond analysis
- xyz2mol: XYZ to SMILES conversion + fingerprints
"""

import os

import click

# Enable -h as an alias for --help
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


@click.group(context_settings=CONTEXT_SETTINGS)
def analyze():
    """Analyze 3D molecular structures.

    \b
    Subcommands:
      optimize  XTB geometry optimization
      metrics   Validity/connectivity metrics
      compare   RMSD, energy, and bond analysis
      xyz2mol   Convert XYZ to SMILES + fingerprints
    """
    pass


# ============================================================================
# OPTIMIZE: XTB geometry optimization
# ============================================================================

@analyze.command("optimize", context_settings=CONTEXT_SETTINGS)
@click.argument("input_dir", type=click.Path(exists=True))
@click.option("--output-dir", "-o", "--o", default=None, type=click.Path(),
              help="Output directory for optimized files (default: input_dir/optimized_xyz)")
@click.option("--charge", "-c", "--c", default=0, type=int,
              help="Molecular charge for xTB (default: 0)")
@click.option("--level", "-l", "--l", default="gfn1", type=click.Choice(["gfn1", "gfn2", "gfn-ff", "mmff94"]),
              help="Optimization level (default: gfn1)")
@click.option("--timeout", "-t", "--t", default=240, type=int,
              help="Timeout per molecule in seconds (default: 240)")
@click.option("--scale-factor", "-s", "--s", default=1.3, type=float,
              help="Scale factor for covalent radii (default: 1.3)")
@click.option("--csv", "csv_path", default=None, type=click.Path(),
              help="CSV file to filter which files to optimize")
@click.option("--filter-column", default=None, type=str,
              help="Column name in CSV to filter by (values must be 1)")
def optimize(input_dir, output_dir, charge, level, timeout, scale_factor, csv_path, filter_column):
    """Optimize XYZ geometries using xTB.

    \b
    Examples:
      MolCraftDiff analyze optimize gen_xyz/
      MolCraftDiff analyze optimize gen_xyz/ --o optimized/ --level gfn2
    """
    from MolecularDiffusion.runmodes.analyze.xtb_optimization import get_xtb_optimized_xyz

    output_dir = output_dir or os.path.join(input_dir, "optimized_xyz")

    click.echo(f"Optimizing XYZ files from: {input_dir}")
    click.echo(f"Output directory: {output_dir}")
    click.echo(f"xTB level: {level}, charge: {charge}")

    optimized_files = get_xtb_optimized_xyz(
        input_directory=input_dir,
        output_directory=output_dir,
        charge=charge,
        level=level,
        timeout=timeout,
        scale_factor=scale_factor,
        csv_path=csv_path,
        filter_column=filter_column,
    )

    click.echo(f"\nSuccessfully optimized {len(optimized_files)} files.")


# ============================================================================
# METRICS: Validity/connectivity metrics
# ============================================================================

@analyze.command("metrics", context_settings=CONTEXT_SETTINGS)
@click.argument("input_dir", type=click.Path(exists=True))
@click.option("--output", "-o", "--o", "--output-csv", default=None, type=click.Path(),
              help="Output CSV file for results")
@click.option("--metrics", "-m", "--m", "metrics_type", default="all",
              type=click.Choice(["all", "core", "posebuster", "geom_revised"]),
              help="Which metrics to compute (default: all)")
@click.option("--recheck-topo", is_flag=True, default=False,
              help="Recheck topology using RDKit")
@click.option("--check-strain", is_flag=True, default=False,
              help="Check strain via XTB optimization")
@click.option("--portion", "-p", "--p", default=1.0, type=float,
              help="Portion of XYZ files to process (default: 1.0 = all)")
@click.option("--mol-converter", default="cell2mol",
              type=click.Choice(["cell2mol", "openbabel"]),
              help="XYZ to mol converter (default: cell2mol)")
@click.option("--skip-atoms", multiple=True, type=int,
              help="Atom indices to skip in validation")
@click.option("--n-subsets", "-n", "--n", default=5, type=int,
              help="Number of subsets for std calculation (default: 5)")
@click.option("--timeout", "-t", "--t", default=10, type=int,
              help="Timeout per xyz2mol conversion in seconds (default: 10)")
def metrics(input_dir, output, metrics_type, recheck_topo, check_strain, portion, mol_converter, skip_atoms, n_subsets, timeout):
    """Compute validity and connectivity metrics for XYZ files.

    \b
    Metrics types:
      all           Run all metrics (core + posebuster + geom_revised)
      core          Basic validity checks (connectivity, atom stability)
      posebuster    PoseBusters checks (bond lengths, angles, clashes)
      geom_revised  Aromatic-aware stability metrics

    \b
    Examples:
      MolCraftDiff analyze metrics gen_xyz/
      MolCraftDiff analyze metrics gen_xyz/ --metrics posebuster
      MolCraftDiff analyze metrics gen_xyz/ --metrics geom_revised --mol-converter openbabel
    """
    import argparse
    from MolecularDiffusion.runmodes.analyze.compute_metrics import runner

    args = argparse.Namespace(
        input=input_dir,
        output=output,
        metrics=metrics_type,
        recheck_topo=recheck_topo,
        check_strain=check_strain,
        portion=portion,
        mol_converter=mol_converter,
        skip_atoms=list(skip_atoms) if skip_atoms else None,
        n_subsets=n_subsets,
        timeout=timeout,
    )

    click.echo(f"Computing {metrics_type} metrics for: {input_dir}")
    runner(args)


# ============================================================================
# COMPARE: Unified RMSD, energy, and bond analysis
# ============================================================================

@analyze.command("compare", context_settings=CONTEXT_SETTINGS)
@click.argument("directory", type=click.Path(exists=True))
@click.option("--mol-converter", default="openbabel", type=click.Choice(["openbabel", "cell2mol"]),
              help="Converter for bond perception (default: openbabel)")
@click.option("--n-subsets", "-n", "--n", default=5, type=int,
              help="Number of subsets for std calculation (default: 5)")
@click.option("--output", "-o", "--o", "--csv", "csv_path", default=None, type=click.Path(),
              help="Output CSV filename for results")
@click.option("--charge", "-c", "--c", default=0, type=int,
              help="Molecular charge for xTB energy (default: 0)")
@click.option("--level", "-l", "--l", default="gfn2", type=click.Choice(["gfn1", "gfn2", "gfn-ff", "mmff94"]),
              help="xTB level for energy calculation (default: gfn2)")
@click.option("--timeout", "-t", "--t", default=120, type=int,
              help="Timeout per xTB calculation in seconds (default: 120)")
def compare(directory, mol_converter, n_subsets, csv_path, charge, level, timeout):
    """Compare XYZ files with their optimized counterparts.

    Computes RMSD, xTB energy difference, and bond geometry metrics.
    Enforces strict connectivity checks.

    Requires an 'optimized_xyz' subdirectory with *_opt.xyz files.
    """
    import argparse
    from MolecularDiffusion.runmodes.analyze.compare_to_optimized import run_compare_analysis

    # Construct an args namespace to pass to run_compare_analysis
    args = argparse.Namespace(
        directory=directory,
        mol_converter=mol_converter,
        n_subsets=n_subsets,
        csv_path=csv_path,
        charge=charge,
        level=level,
        timeout=timeout
    )

    run_compare_analysis(args)


# ============================================================================
# XYZ2MOL: Convert XYZ to SMILES + fingerprints
# ============================================================================

@analyze.command("xyz2mol", context_settings=CONTEXT_SETTINGS)
@click.argument("xyz_dir", type=click.Path(exists=True))
@click.option("--input-csv", "-i", "--i", default=None, type=click.Path(),
              help="Optional input CSV with xyz file list")
@click.option("--label", "-l", "--l", default=None, type=str,
              help="Label for processed files")
@click.option("--timeout", "-t", "--t", default=30, type=int,
              help="Timeout per conversion in seconds (default: 30)")
@click.option("--bits", "-b", "--b", default=2048, type=int,
              help="Number of bits for Morgan fingerprint (default: 2048)")
@click.option("--verbose", "-v", "--v", is_flag=True,
              help="Enable verbose output")
def xyz2mol(xyz_dir, input_csv, label, timeout, bits, verbose):
    """Convert XYZ files to SMILES and extract fingerprints/scaffolds.

    Outputs are saved to xyz_dir/2d_reprs/:
    - smiles_processed.csv
    - fingerprints.npy
    - scaffolds.txt
    - substructures.json

    \b
    Examples:
      MolCraftDiff analyze xyz2mol gen_xyz/
      MolCraftDiff analyze xyz2mol gen_xyz/ --bits 1024 -v
    """
    from pathlib import Path
    import pandas as pd
    import numpy as np
    import json
    import logging

    from MolecularDiffusion.runmodes.analyze.xyz2mol import (
        load_file_list_from_dir, run_processing, extract_scaffold_and_fingerprints
    )

    if verbose:
        logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")

    xyz_dir = Path(xyz_dir)
    two_d_reprs_dir = xyz_dir / "2d_reprs"
    two_d_reprs_dir.mkdir(parents=True, exist_ok=True)

    smiles_csv_output = two_d_reprs_dir / "smiles_processed.csv"

    click.echo(f"Processing XYZ files from: {xyz_dir}")
    click.echo(f"Output directory: {two_d_reprs_dir}")

    # Load the file list
    if input_csv:
        df = pd.read_csv(input_csv)
    else:
        df = load_file_list_from_dir(str(xyz_dir))

    # Generate SMILES
    df_smiles = run_processing(df, str(xyz_dir), label, smiles_csv_output, timeout=timeout, verbose=verbose)

    if df_smiles is None or 'smiles' not in df_smiles.columns or df_smiles['smiles'].isnull().all():
        click.echo("No valid SMILES generated.", err=True)
        return

    # Extract fingerprints and scaffolds
    click.echo("\nExtracting fingerprints and scaffolds...")
    fps, scaffolds, clean_smiles, n_fail, substruct_counts = \
        extract_scaffold_and_fingerprints(df_smiles["smiles"].dropna().values, fp_bits=bits)

    np.save(two_d_reprs_dir / "fingerprints.npy", fps)
    with open(two_d_reprs_dir / "scaffolds.txt", "w") as f:
        f.write("\n".join(scaffolds))
    with open(two_d_reprs_dir / "smiles_cleaned.txt", "w") as f:
        f.write("\n".join(clean_smiles))
    with open(two_d_reprs_dir / "substructures.json", "w") as f:
        json.dump(substruct_counts, f, indent=2)

    total = len(df_smiles["smiles"].dropna())
    click.echo(f"\n--- Summary ---")
    click.echo(f"Total SMILES: {total}")
    click.echo(f"Failed FP extraction: {n_fail}")
    click.echo(f"Unique substructures: {len(substruct_counts)}")
    click.echo(f"Outputs saved to: {two_d_reprs_dir}")


# ============================================================================
# XTB-ELECTRONIC: Compute XTB electronic properties
# ============================================================================

@analyze.command("xtb-electronic", context_settings=CONTEXT_SETTINGS)
@click.argument("input_dir", type=click.Path(exists=True))
@click.option("--output", "--o", "-o", default=None, type=click.Path(),
              help="Output file path (without extension for 'all' format)")
@click.option("--method", "--m", "-m", default="2", type=click.Choice(["1", "2", "ptb"]),
              help="XTB method: 1=GFN1, 2=GFN2, ptb=PTB (default: 2)")
@click.option("--charge", "--c", "-c", default=0, type=int,
              help="Molecular charge (default: 0)")
@click.option("--n-unpaired", "--unpaired", default=0, type=int,
              help="Number of unpaired electrons (default: 0)")
@click.option("--solvent", "--s", "-s", default=None, type=str,
              help="Solvent for solvation calculations (e.g., 'water', 'thf', 'chcl3')")
@click.option("--properties", "--prop", "-p", multiple=True,
              type=click.Choice(["energy", "dipole", "reactivity", "global",
                                 "charges", "fukui", "bond_orders", "all"]),
              help="Property groups to compute (default: energy)")
@click.option("--corrected/--no-corrected", default=True,
              help="Apply empirical IP/EA correction (default: True)")
@click.option("--timeout", "--t", "-t", default=120, type=int,
              help="Timeout per molecule in seconds (default: 120)")
@click.option("--n-jobs", "--jobs", "-j", default=1, type=int,
              help="Number of parallel jobs (default: 1)")
@click.option("--format", "--fmt", "-f", "output_format", default="csv",
              type=click.Choice(["csv", "json", "ase", "all"]),
              help="Output format: csv, json, ase (.db), or all (default: csv)")
def xtb_electronic(input_dir, output, method, charge, n_unpaired,
                   solvent, properties, corrected, timeout, n_jobs, output_format):
    """Compute XTB electronic properties for XYZ files.

    Uses morfeus to calculate quantum-chemical descriptors at the GFN-xTB level.

    \b
    Property groups (molecular-level):
      energy      Total energy, HOMO, LUMO, gap, Fermi level
      dipole      Dipole moment and vector
      reactivity  IP, EA, electronegativity, hardness, softness
      global      Electrophilicity, nucleophilicity, fugalities
      solvation   Solvation energy, H-bond correction (requires --solvent)

    \b
    Property groups (atomic-level):
      charges      Atomic charges (Mulliken)
      fukui        Fukui indices (f+, f-, f, dual)
      bond_orders  Bond orders between atom pairs

    \b
    Output formats:
      csv   Molecular-level properties only (one row per molecule)
      json  Full data including atomic-level properties
      ase   ASE database with properties in atoms.info/arrays
      all   Generate all three formats

    \b
    Examples:
      MolCraftDiff analyze xtb-electronic gen_xyz/
      MolCraftDiff analyze xtb-electronic gen_xyz/ -p energy -p reactivity
      MolCraftDiff analyze xtb-electronic gen_xyz/ -s water -p solvation
      MolCraftDiff analyze xtb-electronic gen_xyz/ -p all -f ase -o results.db
    """
    from MolecularDiffusion.runmodes.analyze.xtb_electronic import batch_xtb_electronic

    # Parse the method
    if method in ["1", "2"]:
        method = int(method)

    # Default properties
    if not properties:
        properties = ["energy"]

    # Default output path
    if output is None:
        output = os.path.join(input_dir, "xtb_electronic")

    click.echo(f"Computing XTB electronic properties for: {input_dir}")
    click.echo(f"Method: GFN{method}-xTB" if method != "ptb" else "Method: PTB")
    click.echo(f"Charge: {charge}, Unpaired: {n_unpaired}")
    if solvent:
        click.echo(f"Solvent: {solvent}")
    click.echo(f"Properties: {', '.join(properties)}")
    click.echo(f"Output format: {output_format}")

    df = batch_xtb_electronic(
        input_dir=input_dir,
        output_path=output,
        output_format=output_format,
        method=method,
        charge=charge,
        n_unpaired=n_unpaired,
        solvent=solvent,
        properties=list(properties),
        corrected=corrected,
        timeout=timeout,
        n_jobs=n_jobs,
    )

    n_success = df["success"].sum() if "success" in df.columns else len(df)
    n_total = len(df)

    click.echo(f"\n--- Summary ---")
    click.echo(f"Processed: {n_total} molecules")
    click.echo(f"Successful: {n_success}")
    click.echo(f"Failed: {n_total - n_success}")
    click.echo(f"Output saved to: {output}")
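Because analyze is an ordinary click group, it can be exercised without the MolCraftDiff console entry point, e.g. through click's test runner. A small sketch, assuming the package and click are importable; gen_xyz/ is a hypothetical directory:

from click.testing import CliRunner

from MolecularDiffusion.cli.analyze import analyze

runner = CliRunner()

# --help short-circuits before the heavy runmodes imports inside each command
result = runner.invoke(analyze, ["metrics", "--help"])
print(result.output)

# Equivalent to: MolCraftDiff analyze optimize gen_xyz/ --level gfn2
# (click.Path(exists=True) rejects missing directories with exit code 2)
result = runner.invoke(analyze, ["optimize", "gen_xyz/", "--level", "gfn2"])
print(result.exit_code)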
MolecularDiffusion/cli/eval_predict.py
ADDED
|
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Eval-Predict command for MolCraft CLI.
|
| 2 |
+
|
| 3 |
+
Adapted from scripts/eval_predict.py for package-level execution.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import Any, Dict, Tuple
|
| 8 |
+
|
| 9 |
+
import hydra
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pandas as pd
|
| 12 |
+
import torch
|
| 13 |
+
from omegaconf import DictConfig, OmegaConf
|
| 14 |
+
from torch.utils.data import ConcatDataset
|
| 15 |
+
|
| 16 |
+
from MolecularDiffusion.core import Engine
|
| 17 |
+
from MolecularDiffusion.runmodes.train import DataModule, ModelTaskFactory_EGCL, OptimSchedulerFactory
|
| 18 |
+
from MolecularDiffusion.utils import RankedLogger, seed_everything
|
| 19 |
+
from MolecularDiffusion.utils.plot_function import (
|
| 20 |
+
plot_kde_distribution,
|
| 21 |
+
plot_histogram_distribution,
|
| 22 |
+
plot_kde_distribution_multiple,
|
| 23 |
+
plot_correlation_with_histograms,
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
log = RankedLogger(__name__, rank_zero_only=True)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def is_rank_zero():
|
| 30 |
+
"""Check if current process is rank zero."""
|
| 31 |
+
if torch.distributed.is_available() and torch.distributed.is_initialized():
|
| 32 |
+
return torch.distributed.get_rank() == 0
|
| 33 |
+
return True
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def load_checkpoint_weights(task, chkpt_path):
|
| 37 |
+
"""Load weights from checkpoint with support for Engine and Lightning formats."""
|
| 38 |
+
log.info(f"Loading weights from: {chkpt_path}")
|
| 39 |
+
|
| 40 |
+
checkpoint = torch.load(chkpt_path, map_location="cpu", weights_only=False)
|
| 41 |
+
|
| 42 |
+
# Check if it's a Lightning checkpoint
|
| 43 |
+
if "state_dict" in checkpoint:
|
| 44 |
+
log.info("Detected Lightning checkpoint.")
|
| 45 |
+
state_dict = checkpoint["state_dict"]
|
| 46 |
+
cleaned_state_dict = {}
|
| 47 |
+
for k, v in state_dict.items():
|
| 48 |
+
if k.startswith("task."):
|
| 49 |
+
cleaned_state_dict[k[5:]] = v
|
| 50 |
+
else:
|
| 51 |
+
cleaned_state_dict[k] = v
|
| 52 |
+
|
| 53 |
+
load_result = task.load_state_dict(cleaned_state_dict, strict=False)
|
| 54 |
+
log.info(f"Loaded {len(cleaned_state_dict)} parameters from state_dict")
|
| 55 |
+
if load_result.missing_keys:
|
| 56 |
+
log.warning(f"Missing keys: {load_result.missing_keys}")
|
| 57 |
+
|
| 58 |
+
# Recover statistics
|
| 59 |
+
for key in ["mean", "std", "weight"]:
|
| 60 |
+
val = None
|
| 61 |
+
if key in checkpoint:
|
| 62 |
+
val = checkpoint[key]
|
| 63 |
+
elif f"task.{key}" in state_dict:
|
| 64 |
+
val = state_dict[f"task.{key}"]
|
| 65 |
+
elif key in state_dict:
|
| 66 |
+
val = state_dict[key]
|
| 67 |
+
|
| 68 |
+
if val is not None:
|
| 69 |
+
if not isinstance(val, torch.Tensor):
|
| 70 |
+
val = torch.as_tensor(val, dtype=torch.float32)
|
| 71 |
+
|
| 72 |
+
# Register as buffer to ensure it moves with the model to the correct device
|
| 73 |
+
if key in task._buffers:
|
| 74 |
+
task._buffers[key].copy_(val)
|
| 75 |
+
else:
|
| 76 |
+
task.register_buffer(key, val)
|
| 77 |
+
elif "model" in checkpoint:
|
| 78 |
+
log.info("Detected original Engine checkpoint.")
|
| 79 |
+
task.load_state_dict(checkpoint["model"], strict=False)
|
| 80 |
+
# Recover statistics
|
| 81 |
+
for key in ["mean", "std", "weight"]:
|
| 82 |
+
if key in checkpoint["model"]:
|
| 83 |
+
val = checkpoint["model"][key]
|
| 84 |
+
if not isinstance(val, torch.Tensor):
|
| 85 |
+
val = torch.as_tensor(val, dtype=torch.float32)
|
| 86 |
+
|
| 87 |
+
# Register as buffer to ensure it moves with the model to the correct device
|
| 88 |
+
if key in task._buffers:
|
| 89 |
+
task._buffers[key].copy_(val)
|
| 90 |
+
else:
|
| 91 |
+
task.register_buffer(key, val)
|
| 92 |
+
else:
|
| 93 |
+
# Fallback for unexpected formats
|
| 94 |
+
log.warning("Unknown checkpoint format. Attempting direct load.")
|
| 95 |
+
task.load_state_dict(checkpoint, strict=False)
|
| 96 |
+
|
| 97 |
+
# Ensure task has a device attribute for initial loading,
|
| 98 |
+
# but don't hardcode it if it's about to be moved by Engine
|
| 99 |
+
if not hasattr(task, 'device'):
|
| 100 |
+
task.device = next(task.parameters()).device if list(task.parameters()) else torch.device('cpu')
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def engine_wrapper(task_module, data_module, trainer_module):
|
| 104 |
+
"""Run evaluation with Engine."""
|
| 105 |
+
trainer_module.get_optimizer()
|
| 106 |
+
trainer_module.get_scheduler()
|
| 107 |
+
|
| 108 |
+
pred_dataset = ConcatDataset([data_module.valid_set, data_module.test_set])
|
| 109 |
+
solver = Engine(
|
| 110 |
+
task_module.task,
|
| 111 |
+
None,
|
| 112 |
+
None,
|
| 113 |
+
pred_dataset,
|
| 114 |
+
batch_size=data_module.batch_size,
|
| 115 |
+
collate_fn=data_module.collate_fn,
|
| 116 |
+
logger="logging",
|
| 117 |
+
)
|
| 118 |
+
# Ensure task.device is updated to the actual device solver is using
|
| 119 |
+
task_module.task.device = solver.device
|
| 120 |
+
|
| 121 |
+
_, preds_test, targets_test = solver.evaluate("test")
|
| 122 |
+
y_preds = torch.cat(preds_test, dim=0)
|
| 123 |
+
y_trues = torch.cat(targets_test, dim=0)
|
| 124 |
+
return y_preds, y_trues
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
|
| 128 |
+
"""Evaluate predictions on validation/test sets."""
|
| 129 |
+
if cfg.get("seed"):
|
| 130 |
+
seed_everything(cfg.seed, workers=True)
|
| 131 |
+
|
| 132 |
+
log.info(f"Instantiating datamodule <{cfg.data._target_}>")
|
| 133 |
+
data_module: DataModule = hydra.utils.instantiate(
|
| 134 |
+
cfg.data, task_type=cfg.tasks.task_type, train_ratio=0
|
| 135 |
+
)
|
| 136 |
+
data_module.load()
|
| 137 |
+
|
| 138 |
+
log.info(f"Instantiating task <{cfg.tasks._target_}>")
|
| 139 |
+
act_fn = hydra.utils.instantiate(cfg.tasks.act_fn)
|
| 140 |
+
|
| 141 |
+
# Store checkpoint path and temporarily disable it for task_module.build()
|
| 142 |
+
# to avoid the factory's internal (legacy) loading.
|
| 143 |
+
chkpt_path = cfg.tasks.get("chkpt_path")
|
| 144 |
+
|
| 145 |
+
# Create a copy of the config to modify safely
|
| 146 |
+
tasks_cfg = OmegaConf.to_container(cfg.tasks, resolve=True)
|
| 147 |
+
tasks_cfg['chkpt_path'] = None
|
| 148 |
+
tasks_cfg = OmegaConf.create(tasks_cfg)
|
| 149 |
+
|
| 150 |
+
task_module: ModelTaskFactory_EGCL = hydra.utils.instantiate(tasks_cfg, act_fn=act_fn)
|
| 151 |
+
task_module.build()
|
| 152 |
+
|
| 153 |
+
# Manually load weights using our robust loader
|
| 154 |
+
if chkpt_path:
|
| 155 |
+
load_checkpoint_weights(task_module.task, chkpt_path)
|
| 156 |
+
|
| 157 |
+
log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
|
| 158 |
+
trainer_module: OptimSchedulerFactory = hydra.utils.instantiate(
|
| 159 |
+
cfg.trainer, parameters=task_module.task.parameters()
|
| 160 |
+
)
|
| 161 |
+
|
| 162 |
+
object_dict = {
|
| 163 |
+
"cfg": cfg,
|
| 164 |
+
"datamodule": data_module,
|
| 165 |
+
"task": task_module,
|
| 166 |
+
"trainer": trainer_module,
|
| 167 |
+
}
|
| 168 |
+
|
| 169 |
+
log.info("Logging hyperparameters!")
|
| 170 |
+
log_hyperparameters(object_dict)
|
| 171 |
+
|
| 172 |
+
y_preds, y_trues = engine_wrapper(task_module, data_module, trainer_module)
|
| 173 |
+
|
| 174 |
+
df = pd.read_csv(cfg.data.filename)
|
| 175 |
+
task_matrix = df[cfg.tasks.task_learn].to_numpy()
|
| 176 |
+
filenames = df["filename"].to_numpy()
|
| 177 |
+
filenames_aligned = []
    for row in y_trues.cpu().numpy():
        mask = np.all(np.isclose(task_matrix, row, atol=1e-4), axis=1)
        idx = np.flatnonzero(mask)

        if idx.size == 0:
            raise ValueError(f"No match for row {row}")
        if idx.size > 1:
            raise ValueError(f"Multiple matches for row {row}: {filenames[idx].tolist()}")

        filenames_aligned.append(filenames[idx[0]])

    df_compiled = pd.DataFrame({
        "filename": filenames_aligned,
        "y_true": y_trues.cpu().numpy().tolist(),
        "y_pred": y_preds.cpu().numpy().tolist(),
    })

    os.makedirs(cfg.output_directory, exist_ok=True)
    df_compiled.to_csv(f"{cfg.output_directory}/predictions.csv", index=False)

    log.info("Prediction statistics:")
    for task_name in cfg.tasks.task_learn:
        log.info(f"--- {task_name} ---")
        log.info(f"Mean: {df[task_name].mean():.4f}")
        log.info(f"Std: {df[task_name].std():.4f}")
        log.info(f"Min: {df[task_name].min():.4f}")
        log.info(f"Max: {df[task_name].max():.4f}")

    log.info("Plotting distributions...")
    props = []
    for i, prop in enumerate(cfg.tasks.task_learn):
        plot_kde_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_kde.png")
        plot_histogram_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_hist.png")
        plot_correlation_with_histograms(
            y_trues[:, i].cpu().numpy(),
            y_preds[:, i].cpu().numpy(),
            prop,
            "",
            f"{cfg.output_directory}/{prop}_correlation.png",
        )
        props.append(df[prop].values)

    props = np.array(props).T
    plot_kde_distribution_multiple(props, cfg.tasks.task_learn, f"{cfg.output_directory}/kde_all.png")


def log_hyperparameters(object_dict: dict):
    """Log hyperparameters for debugging."""
    if not is_rank_zero():
        return

    log.info("\n========== Logging Hyperparameters ==========\n")
    for name, obj in object_dict.items():
        log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
        if name == "cfg":
            if isinstance(obj, dict):
                log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
            else:
                log.info("\n" + OmegaConf.to_yaml(obj))
        else:
            if hasattr(obj, '__dict__'):
                for k, v in vars(obj).items():
                    if not k.startswith("_"):
                        log.info(f"{k}: {v}")
        log.info(f"{'=' * (44 + len(name))}\n")

    if "task" in object_dict and hasattr(object_dict["task"], "task"):
        model = object_dict["task"].task
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        log.info(f"{'=' * 20} MODEL PARAMS {'=' * 20}")
        log.info(f"model/params/total: {total}")
        log.info(f"model/params/trainable: {trainable}")
        log.info("=" * 54 + "\n")

    log.info("========== End of Hyperparameters ==========\n")


def eval_predict_main(cfg: DictConfig):
    """Entry point for CLI eval-predict command."""
    predict(cfg)
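The row-matching loop above locates each ground-truth row inside the dataset's property matrix by elementwise near-equality, recovering the filename that produced it. A minimal sketch of the idea with made-up toy values (names and numbers below are hypothetical, not taken from the repository):

import numpy as np

# Each y_true row is searched for in task_matrix with a small tolerance;
# the matching index recovers the original filename.
task_matrix = np.array([[0.10, 1.0], [0.20, 2.0]])
filenames = np.array(["mol_a.xyz", "mol_b.xyz"])

row = np.array([0.20, 2.0])
mask = np.all(np.isclose(task_matrix, row, atol=1e-4), axis=1)  # per-row match
idx = np.flatnonzero(mask)                                      # matching indices
assert filenames[idx[0]] == "mol_b.xyz"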
MolecularDiffusion/cli/generate.py (ADDED)
@@ -0,0 +1,282 @@
"""Generation command for MolCraft CLI.

Adapted from scripts/generate.py for package-level execution.
"""

import glob
import os
import re
import time
import copy
import pickle

import hydra
import torch
from omegaconf import DictConfig, OmegaConf

from MolecularDiffusion.core import Engine
from MolecularDiffusion.runmodes.generate.tasks_generate import GenerativeFactory
from MolecularDiffusion.utils import (
    RankedLogger,
    seed_everything,
    recursive_module_to_device,
)

log = RankedLogger(__name__, rank_zero_only=True)


def is_rank_zero():
    """Check if current process is rank zero."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    return True


def load_lightning_model(chkpt_path, task_config, atom_vocab=None, total_step=0):
    """Load model from a Lightning checkpoint (.ckpt)."""
    log.info(f"Loading Lightning checkpoint from: {chkpt_path}")

    try:
        from MolecularDiffusion.core.engine_lightning import EngineLightning
        wrapper = EngineLightning.load_from_checkpoint(chkpt_path, map_location="cpu")
        log.info("Successfully loaded model using EngineLightning.load_from_checkpoint")

        if atom_vocab and hasattr(wrapper.task, 'atom_vocab') and wrapper.task.atom_vocab is None:
            wrapper.task.atom_vocab = atom_vocab

        # Apply the diffusion_steps override from the config
        if total_step > 0:
            if hasattr(wrapper.task, 'model') and hasattr(wrapper.task.model, 'T'):
                log.info(f"Overriding diffusion steps: {wrapper.task.model.T} -> {total_step}")
                wrapper.task.model.T = total_step
            elif hasattr(wrapper.task, 'T'):
                log.info(f"Overriding diffusion steps: {wrapper.task.T} -> {total_step}")
                wrapper.task.T = total_step

        wrapper.task.eval()
        return wrapper.task

    except Exception as e:
        log.warning(f"EngineLightning.load_from_checkpoint failed ({type(e).__name__}: {e}). "
                    "Falling back to manual config reconstruction.")

    # Fallback: load the checkpoint manually
    checkpoint = torch.load(chkpt_path, map_location="cpu", weights_only=False)

    hparams = checkpoint.get("hyper_parameters", {})
    if "model_config" in hparams and hparams["model_config"] is not None:
        task_config = OmegaConf.create(hparams["model_config"])
        log.info("Loaded task configuration from checkpoint hyperparameters")
    elif task_config is None:
        raise ValueError("task_config not provided and 'model_config' not found in checkpoint.")

    task_config = copy.deepcopy(task_config)
    OmegaConf.set_readonly(task_config, False)
    OmegaConf.set_struct(task_config, False)

    n_types = len(atom_vocab) if atom_vocab else 0

    if OmegaConf.is_missing(task_config, "num_atom_types") or task_config.get("num_atom_types") == "???":
        task_config.num_atom_types = n_types if n_types > 0 else 100

    if hasattr(task_config, "transformer_config"):
        if OmegaConf.is_missing(task_config.transformer_config, "atom_dim"):
            task_config.transformer_config.atom_dim = task_config.num_atom_types

    if hasattr(task_config, "dataset_stats"):
        if OmegaConf.is_missing(task_config.dataset_stats, "max_atoms"):
            task_config.dataset_stats.max_atoms = 150

    log.info(f"Building task from config: {task_config._target_}")
    task_factory = hydra.utils.instantiate(task_config, atom_vocab=atom_vocab)
    task = task_factory.build()

    state_dict = checkpoint.get('state_dict', {})
    cleaned_state_dict = {}
    for key, value in state_dict.items():
        if key.startswith('task.'):
            cleaned_state_dict[key[5:]] = value
        else:
            cleaned_state_dict[key] = value

    task.load_state_dict(cleaned_state_dict, strict=False)
    log.info(f"Loaded {len(cleaned_state_dict)} parameters from checkpoint")

    if 'data_stats' in checkpoint:
        task.tabasco_model.set_data_stats(checkpoint['data_stats'])
    if 'node_dist_model' in checkpoint:
        task._node_dist_model = checkpoint['node_dist_model']
    if 'prop_dist_model' in checkpoint:
        task.prop_dist_model = checkpoint['prop_dist_model']

    if total_step > 0:
        if hasattr(task, 'model') and hasattr(task.model, 'T'):
            task.model.T = total_step
        elif hasattr(task, 'T'):
            task.T = total_step

    task.eval()
    return task


def load_model(chkpt_directory, task_config=None, atom_vocab=None, total_step=0):
    """Load a model from a checkpoint directory with auto-detection."""
    ckpt_files = glob.glob(os.path.join(chkpt_directory, '*.ckpt'))

    if ckpt_files:
        best_metric = -1.0
        best_checkpoint = None

        for ckpt_file in ckpt_files:
            match = re.search(r"(?:metric|val)[_=](\d+\.?\d*)", os.path.basename(ckpt_file))
            if match:
                metric = float(match.group(1))
                if metric > best_metric:
                    best_metric = metric
                    best_checkpoint = ckpt_file

        if best_checkpoint is None:
            last_ckpt = os.path.join(chkpt_directory, 'last.ckpt')
            best_checkpoint = last_ckpt if os.path.exists(last_ckpt) else ckpt_files[0]

        task = load_lightning_model(best_checkpoint, task_config, atom_vocab, total_step)

        try:
            with open(os.path.join(chkpt_directory, "edm_stat.pkl"), "rb") as file:
                edm_stats = pickle.load(file)
            task.node_dist_model = edm_stats.get("node")
            if "prop" in edm_stats:
                task.prop_dist_model = edm_stats["prop"]
        except (ImportError, FileNotFoundError):
            log.warning("edm_stat.pkl not found")

        return task

    # Original engine (.pkl files)
    model_path = os.path.join(chkpt_directory, "edm_chem.pkl")

    if not os.path.exists(model_path):
        checkpoint_files = glob.glob(os.path.join(chkpt_directory, '*.pkl'))
        checkpoint_files = [f for f in checkpoint_files if 'edm_stat.pkl' not in os.path.basename(f)]

        if not checkpoint_files:
            raise FileNotFoundError(f"No checkpoints found in {chkpt_directory}")

        best_metric = -1.0
        best_checkpoint = None

        for ckpt_file in checkpoint_files:
            match = re.search(r"metric=([\d.]+)\.pkl", os.path.basename(ckpt_file))
            if match:
                metric = float(match.group(1))
                if metric > best_metric:
                    best_metric = metric
                    best_checkpoint = ckpt_file

        model_path = best_checkpoint or checkpoint_files[0]

    log.info(f"Loading original engine checkpoint from: {model_path}")

    edm_stats = {"node": None, "prop": None}
    stat_path = os.path.join(chkpt_directory, "edm_stat.pkl")
    if os.path.exists(stat_path):
        try:
            with open(stat_path, "rb") as file:
                loaded_stats = pickle.load(file)
            if "node" in loaded_stats:
                edm_stats["node"] = loaded_stats["node"]
            elif "node_dist_model" in loaded_stats:
                edm_stats["node"] = loaded_stats["node_dist_model"]
            if "prop" in loaded_stats:
                edm_stats["prop"] = loaded_stats["prop"]
            elif "prop_dist_model" in loaded_stats:
                edm_stats["prop"] = loaded_stats["prop_dist_model"]
        except Exception as e:
            log.warning(f"Failed to load edm_stat.pkl: {e}")

    engine = Engine(None, None, None, None, None)
    engine = engine.load_from_checkpoint(model_path, interference_mode=True)
    task = engine.model

    if edm_stats["node"] is not None:
        task.node_dist_model = edm_stats["node"]
    if edm_stats["prop"] is not None:
        task.prop_dist_model = edm_stats["prop"]

    if total_step > 0:
        if hasattr(task, 'model') and hasattr(task.model, 'T'):
            task.model.T = total_step
        elif hasattr(task, 'T'):
            task.T = total_step

    task.eval()
    return task


def generate(cfg: DictConfig) -> None:
    """Main generation function."""
    if cfg.get("seed"):
        seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating diffusion task and loading the model <{cfg.tasks._target_}>")
    task = load_model(
        cfg.chkpt_directory,
        task_config=cfg.tasks,
        atom_vocab=cfg.atom_vocab,
        total_step=cfg.diffusion_steps,
    )

    if not hasattr(task, 'atom_vocab') or task.atom_vocab is None:
        task.atom_vocab = cfg.atom_vocab

    if not hasattr(task, 'device'):
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        recursive_module_to_device(task, device)

    log.info(f"Instantiating generator... <{cfg.interference._target_}>")
    generator: GenerativeFactory = hydra.utils.instantiate(cfg.interference, task=task)

    object_dict = {"cfg": cfg, "task": task, "generator": generator}

    log.info("Logging hyperparameters!")
    log_hyperparameters(object_dict)

    os.makedirs(cfg.interference.output_path, exist_ok=True)

    if is_rank_zero():
        config_path = os.path.join(cfg.interference.output_path, "config.yaml")
        with open(config_path, "w") as f:
            OmegaConf.save(config=cfg, f=f)
        log.info(f"Configuration saved to {config_path}")

    generator.run()


def log_hyperparameters(object_dict: dict):
    """Log hyperparameters for debugging."""
    if not is_rank_zero():
        return

    log.info("\n========== Logging Hyperparameters ==========\n")
    for name, obj in object_dict.items():
        log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
        if name == "cfg":
            if isinstance(obj, dict):
                log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
            else:
                log.info("\n" + OmegaConf.to_yaml(obj))
        else:
            if hasattr(obj, '__dict__'):
                for k, v in vars(obj).items():
                    if not k.startswith("_"):
                        log.info(f"{k}: {v}")
        log.info(f"{'=' * (44 + len(name))}\n")
    log.info("========== End of Hyperparameters ==========\n")


def generate_main(cfg: DictConfig):
    """Entry point for CLI generate command."""
    start_time = time.time()
    generate(cfg)
    total_time = time.time() - start_time
    log.warning(f"Total time of execution: {total_time:.2f} seconds")
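For reference, this is how the filename-based ranking in load_model() behaves: only checkpoints whose names carry a `metric=`/`val=` score participate, and `last.ckpt` is only a fallback. A small sketch on hypothetical filenames:

import re

# Hypothetical checkpoint names; the regex mirrors the one in load_model().
names = ["epoch=3-metric=0.91.ckpt", "epoch=7-metric=0.95.ckpt", "last.ckpt"]
pattern = r"(?:metric|val)[_=](\d+\.?\d*)"

scores = {}
for name in names:
    match = re.search(pattern, name)
    if match:
        scores[name] = float(match.group(1))  # 0.91 and 0.95; last.ckpt has no score

best = max(scores, key=scores.get)
assert best == "epoch=7-metric=0.95.ckpt"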
MolecularDiffusion/cli/main.py (ADDED)
@@ -0,0 +1,197 @@
"""MolCraft CLI - Unified command-line interface for MolecularDiffusion.

Usage:
    molcraft train config.yaml [overrides...]
    molcraft generate config.yaml [overrides...]
    molcraft predict config.yaml [overrides...]
"""

import os
import logging
import platform

import click

# Setup logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


def log_system_info():
    """Log basic system information."""
    import psutil

    logger.info("=" * 60)
    logger.info(f"OS: {platform.system()} {platform.release()}")
    logger.info(f"CPU: {platform.processor()}, Cores: {os.cpu_count()}")

    ram = psutil.virtual_memory()
    logger.info(f"RAM: Total {ram.total / (1024**3):.2f} GB, Available {ram.available / (1024**3):.2f} GB")
    logger.info(f"Python: {platform.python_version()}")

    try:
        import torch
        logger.info(f"PyTorch: {torch.__version__}")
        if torch.cuda.is_available():
            logger.info(f"CUDA: {torch.version.cuda}, GPUs: {torch.cuda.device_count()}")
    except ImportError:
        pass

    logger.info("=" * 60)


# Enable -h as alias for --help
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])


@click.group(context_settings=CONTEXT_SETTINGS)
@click.version_option(package_name="MolecularDiffusion")
def cli():
    """MolCraft - Molecular Diffusion CLI.

    A unified command-line interface for training, generation, and prediction
    with molecular diffusion models.

    \b
    Examples:
        molcraft train configs/my_train_config.yaml
        molcraft generate configs/my_gen_config.yaml
        molcraft predict configs/my_pred_config.yaml
    """
    pass


@cli.command(context_settings=CONTEXT_SETTINGS)
@click.argument("config", type=str)
@click.argument("overrides", nargs=-1)
def train(config: str, overrides: tuple):
    """Train a molecular diffusion model.

    \b
    Arguments:
        CONFIG     Config file path (e.g., configs/train.yaml)
        OVERRIDES  Hydra-style config overrides (e.g., trainer.num_epochs=100)

    \b
    Examples:
        molcraft train configs/train_tabasco_geom.yaml
        molcraft train configs/my_config.yaml trainer.num_epochs=50 seed=42
    """
    log_system_info()
    logger.info(f"Starting training with config: {config}")

    from MolecularDiffusion.cli._hydra import run_hydra_app
    from MolecularDiffusion.cli.train import train_main

    run_hydra_app(
        config_name=config,
        task_function=train_main,
        config_dir=None,
        overrides=list(overrides),
    )


@cli.command(context_settings=CONTEXT_SETTINGS)
@click.argument("config", type=str)
@click.argument("overrides", nargs=-1)
def generate(config: str, overrides: tuple):
    """Generate molecules using a trained model.

    \b
    Arguments:
        CONFIG     Config file path (e.g., configs/generate.yaml)
        OVERRIDES  Hydra-style config overrides

    \b
    Examples:
        molcraft generate configs/gen_config.yaml
        molcraft generate configs/gen_config.yaml interference.n_samples=1000
    """
    log_system_info()
    logger.info(f"Starting generation with config: {config}")

    from MolecularDiffusion.cli._hydra import run_hydra_app
    from MolecularDiffusion.cli.generate import generate_main

    run_hydra_app(
        config_name=config,
        task_function=generate_main,
        config_dir=None,
        overrides=list(overrides),
    )


@cli.command(context_settings=CONTEXT_SETTINGS)
@click.argument("config", type=str)
@click.argument("overrides", nargs=-1)
def predict(config: str, overrides: tuple):
    """Run property prediction on molecules.

    \b
    Arguments:
        CONFIG     Config file path (e.g., configs/predict.yaml)
        OVERRIDES  Hydra-style config overrides

    \b
    Examples:
        molcraft predict configs/predict.yaml
        molcraft predict configs/my_pred.yaml xyz_directory=/path/to/xyz
    """
    log_system_info()
    logger.info(f"Starting prediction with config: {config}")

    from MolecularDiffusion.cli._hydra import run_hydra_app
    from MolecularDiffusion.cli.predict import predict_main

    run_hydra_app(
        config_name=config,
        task_function=predict_main,
        config_dir=None,
        overrides=list(overrides),
    )


@cli.command("eval-predict", context_settings=CONTEXT_SETTINGS)
@click.argument("config", type=str)
@click.argument("overrides", nargs=-1)
def eval_predict(config: str, overrides: tuple):
    """Evaluate model predictions on validation/test sets.

    \b
    Arguments:
        CONFIG     Config file path (e.g., configs/eval_predict.yaml)
        OVERRIDES  Hydra-style config overrides

    \b
    Examples:
        molcraft eval-predict configs/eval_predict.yaml
    """
    log_system_info()
    logger.info(f"Starting eval-predict with config: {config}")

    from MolecularDiffusion.cli._hydra import run_hydra_app
    from MolecularDiffusion.cli.eval_predict import eval_predict_main

    run_hydra_app(
        config_name=config,
        task_function=eval_predict_main,
        config_dir=None,
        overrides=list(overrides),
    )


# Register analyze subcommand group
from MolecularDiffusion.cli.analyze import analyze
cli.add_command(analyze)


def main():
    """Entry point."""
    cli()


if __name__ == "__main__":
    main()
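Because the commands are plain click objects, the group can also be exercised in-process rather than through a subprocess. A minimal smoke-test sketch using click's own test runner (the assertions are illustrative, not a shipped test):

from click.testing import CliRunner

from MolecularDiffusion.cli.main import cli

# Equivalent to running `molcraft --help` on the command line.
runner = CliRunner()
result = runner.invoke(cli, ["--help"])
assert result.exit_code == 0
assert "train" in result.output and "generate" in result.output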
MolecularDiffusion/cli/predict.py (ADDED)
@@ -0,0 +1,395 @@
"""Prediction command for MolCraft CLI.

Adapted from scripts/predict.py for package-level execution.
"""

import os
from glob import glob
from typing import Tuple

import hydra
import numpy as np
import pandas as pd
import torch
from ase.data import atomic_numbers
from omegaconf import DictConfig, OmegaConf
from torch_geometric.data import Data
from torch_geometric.nn import knn_graph, radius_graph
from tqdm import tqdm

from MolecularDiffusion.core import Engine
from MolecularDiffusion.data.component.pointcloud import PointCloud_Mol
from MolecularDiffusion.data.component.feature import (
    onehot,
    atom_topological,
    atom_geom,
    atom_geom_compact,
    atom_geom_opt,
    atom_geom_v2,
    atom_geom_v2_trun,
)
from MolecularDiffusion.utils import RankedLogger, seed_everything
from MolecularDiffusion.utils.plot_function import (
    plot_kde_distribution,
    plot_histogram_distribution,
    plot_kde_distribution_multiple,
)

log = RankedLogger(__name__, rank_zero_only=True)


def is_rank_zero():
    """Check if current process is rank zero."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    return True


def load_model(chkpt_path, task_config=None, atom_vocab=None):
    """Load a pre-trained model from checkpoint with auto-detection."""
    log.info(f"Loading checkpoint from: {chkpt_path}")

    # Try loading as a Lightning checkpoint first if it has a .ckpt extension
    if chkpt_path.endswith('.ckpt'):
        try:
            from MolecularDiffusion.core.engine_lightning import EngineLightning
            wrapper = EngineLightning.load_from_checkpoint(chkpt_path, map_location="cpu")
            log.info("Successfully loaded model using EngineLightning.load_from_checkpoint")

            # Return something with a .model attribute for backward compatibility
            class SolverWrapper:
                def __init__(self, task):
                    self.model = task

            solver = SolverWrapper(wrapper.task)
            solver.model.eval()
            return solver
        except Exception as e:
            log.warning(f"EngineLightning.load_from_checkpoint failed ({type(e).__name__}: {e}). "
                        "Trying manual fallback.")

    # Manual fallback or original engine (.pkl/no extension)
    checkpoint = torch.load(chkpt_path, map_location="cpu", weights_only=False)

    # Check whether it is a Lightning checkpoint dictionary
    if "hyper_parameters" in checkpoint:
        log.info("Detected Lightning checkpoint dictionary.")
        hparams = checkpoint.get("hyper_parameters", {})

        # Try to get model_config from the checkpoint
        model_config = hparams.get("model_config", task_config)
        if model_config is None:
            raise ValueError("Lightning checkpoint lacks 'model_config' and no 'task_config' provided.")

        # Instantiate the task
        if isinstance(model_config, dict):
            model_config = OmegaConf.create(model_config)

        # Ensure we have atom_vocab if needed
        if atom_vocab is not None and ('atom_vocab' not in model_config or model_config.atom_vocab is None):
            OmegaConf.set_struct(model_config, False)
            model_config.atom_vocab = atom_vocab

        task_factory = hydra.utils.instantiate(model_config)
        task = task_factory.build()

        # Load weights
        state_dict = checkpoint.get("state_dict", {})
        cleaned_state_dict = {}
        for k, v in state_dict.items():
            if k.startswith("task."):
                cleaned_state_dict[k[5:]] = v
            else:
                cleaned_state_dict[k] = v

        task.load_state_dict(cleaned_state_dict, strict=False)
        log.info(f"Loaded {len(cleaned_state_dict)} parameters from state_dict")

        # Recover mean/std if they live in the checkpoint root or state_dict but not as buffers
        for key in ["mean", "std", "weight"]:
            val = None
            if key in checkpoint:
                val = checkpoint[key]
            elif f"task.{key}" in state_dict:
                val = state_dict[f"task.{key}"]
            elif key in state_dict:
                val = state_dict[key]

            if val is not None:
                if not isinstance(val, torch.Tensor):
                    val = torch.as_tensor(val, dtype=torch.float32)

                # Register as a buffer so it moves with the model to the correct device
                if key in task._buffers:
                    task._buffers[key].copy_(val)
                else:
                    task.register_buffer(key, val)

        class SolverWrapper:
            def __init__(self, task):
                self.model = task

        solver = SolverWrapper(task)
        solver.model.eval()
        # Keep task.device in sync with wherever the parameters actually live
        params = list(solver.model.parameters())
        solver.model.device = params[0].device if params else torch.device('cpu')
        return solver
    else:
        # Original Engine checkpoint
        engine = Engine(None, None, None, None, None)
        solver = engine.load_from_checkpoint(chkpt_path, interference_mode=True)
        solver.model.eval()
        # Keep task.device in sync with the device the solver is using
        solver.model.device = solver.device
        return solver


def xyz2mol(xyz_file, atom_vocab, node_feature, edge_type="fully_connected",
            radius=4.0, n_neigh=5, device="cpu"):
    """Convert an XYZ file into a PyTorch Geometric Data object."""
    mol_obj = {}
    mol_xyz = PointCloud_Mol.from_xyz(xyz_file, with_hydrogen=True, forbidden_atoms=[])
    coords = mol_xyz.get_coord()
    n_nodes = len(mol_xyz.atoms)

    node_features = []
    for atom in mol_xyz.atoms:
        node_features.append(onehot(atom.element, atom_vocab, allow_unknown=False))

    charges = [
        atomic_numbers[atom.element]
        for atom in mol_xyz.atoms
        if atom.element in atomic_numbers
    ]

    if node_feature:
        if node_feature in [
            "atom_topological", "atom_geom", "atom_geom_v2",
            "atom_geom_v2_trun", "atom_geom_opt", "atom_geom_compact"
        ]:
            feature_mapping = {
                "atom_topological": atom_topological,
                "atom_geom": atom_geom,
                "atom_geom_v2": atom_geom_v2,
                "atom_geom_v2_trun": atom_geom_v2_trun,
                "atom_geom_opt": atom_geom_opt,
                "atom_geom_compact": atom_geom_compact,
            }
            feature_function = feature_mapping.get(node_feature)
            if feature_function is not None:
                node_features_extra = feature_function(charges, coords)
                node_features = torch.cat(
                    [torch.tensor(node_features), node_features_extra], dim=1
                )
        else:
            raise ValueError("Unknown node feature type")

    # as_tensor handles both the plain one-hot list and the concatenated tensor above
    node_features = torch.as_tensor(node_features, dtype=torch.float32)
    charges = torch.as_tensor(charges, dtype=torch.long)
    node_mask = torch.ones(n_nodes, dtype=torch.int8)

    edge_mask = node_mask.unsqueeze(0) * node_mask.unsqueeze(1)
    diag_mask = ~torch.eye(n_nodes, dtype=torch.bool)
    edge_mask *= diag_mask
    edge_mask = edge_mask.view(1 * n_nodes * n_nodes, 1)
    h = node_features.view(1 * n_nodes, -1).clone()

    if edge_type == "distance":
        edge_index = radius_graph(coords, r=radius)
    elif edge_type == "neighbor":
        edge_index = knn_graph(coords, k=n_neigh)
    elif edge_type == "fully_connected":
        num_nodes = coords.size(0)
        row = torch.arange(num_nodes).repeat_interleave(num_nodes)
        col = torch.arange(num_nodes).repeat(num_nodes)
        edge_index = torch.stack([row, col], dim=0)
        edge_index = edge_index[:, row != col]
    else:
        raise ValueError(f"Unknown edge type {edge_type}")

    graph_data = Data(
        x=h,
        pos=coords,
        atomic_numbers=charges,
        natoms=torch.tensor([n_nodes]),
        edge_index=edge_index,
        times=torch.tensor([0]),
        batch=torch.zeros(n_nodes, dtype=torch.long),
    ).to(device)

    mol_obj["graph"] = graph_data
    return mol_obj
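The `fully_connected` branch in xyz2mol() enumerates every ordered pair (i, j) with i != j, so a 3-atom molecule yields 6 directed edges. A worked example of that construction in isolation:

import torch

# Fully connected edge construction for num_nodes = 3, as in xyz2mol().
num_nodes = 3
row = torch.arange(num_nodes).repeat_interleave(num_nodes)  # tensor([0,0,0,1,1,1,2,2,2])
col = torch.arange(num_nodes).repeat(num_nodes)             # tensor([0,1,2,0,1,2,0,1,2])
edge_index = torch.stack([row, col], dim=0)
edge_index = edge_index[:, row != col]  # drop self-loops -> 6 directed edges
assert edge_index.shape == (2, 6)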

def count_atoms_from_xyz(path: str) -> int:
    """Fast atom counter for XYZ files."""
    try:
        with open(path, "r") as f:
            first = f.readline().strip()
        return int(first)
    except Exception:
        return 0


def _runner(solver, xyz_paths: list, max_atoms: int = 100) -> Tuple[np.ndarray, list]:
    """Run predictions on a list of XYZ files."""
    params = list(solver.model.parameters())
    device = getattr(solver.model, 'device', params[0].device if params else torch.device('cpu'))
    task_names = list(solver.model.task.keys())
    num_molecules = len(xyz_paths)

    progress_bar = tqdm(
        enumerate(xyz_paths),
        desc="Predicting molecules",
        leave=True,
        dynamic_ncols=True,
        total=num_molecules,
    )

    predictions = []
    xyz_paths_clear = []
    skipped = 0

    for i, xyz_path in progress_bar:
        n_atoms = count_atoms_from_xyz(xyz_path)
        if n_atoms > max_atoms:
            skipped += 1
            progress_bar.set_postfix({"batch": i + 1, "skipped": skipped})
            log.info(f"Skipping {xyz_path} (atoms={n_atoms} > max_atoms={max_atoms})")
            continue

        mol_obj = xyz2mol(
            xyz_file=xyz_path,
            atom_vocab=solver.model.atom_vocab,
            node_feature=solver.model.node_feature,
            device=device,
        )
        prediction = solver.model.predict(mol_obj, evaluate=True)[0]
        predictions.append(prediction.detach().cpu().numpy())
        current_preds_dict = {prop_name: prediction[j].item() for j, prop_name in enumerate(task_names)}
        progress_bar.set_postfix({"batch": i + 1, "skipped": skipped, **current_preds_dict})
        xyz_paths_clear.append(xyz_path)

    predictions = np.array(predictions)
    if predictions.ndim > 1 and predictions.shape[-1] == 1:
        predictions = predictions.squeeze(-1)

    return predictions, xyz_paths_clear


def runner(cfg: DictConfig) -> None:
    """Property prediction run."""
    if cfg.get("seed"):
        seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating diffusion task and loading the model <{cfg.tasks._target_}>")
    solver = load_model(cfg.chkpt_directory, task_config=cfg.tasks, atom_vocab=cfg.atom_vocab)

    task_names = list(solver.model.task.keys())

    if not hasattr(solver.model, 'std') or solver.model.std is None:
        chkpt = torch.load(cfg.chkpt_directory, weights_only=False)
        if "model" in chkpt:
            solver.model.std = chkpt["model"].get("std", torch.ones(1)).to(solver.model.device)
            solver.model.weight = chkpt["model"].get("weight", torch.ones(1)).to(solver.model.device)
            solver.model.mean = chkpt["model"].get("mean", torch.zeros(1)).to(solver.model.device)
        elif "state_dict" in chkpt:
            # Fallback for Lightning if not already loaded by load_model
            sd = chkpt["state_dict"]
            solver.model.std = sd.get("task.std", sd.get("std", torch.ones(1))).to(solver.model.device)
            solver.model.weight = sd.get("task.weight", sd.get("weight", torch.ones(1))).to(solver.model.device)
            solver.model.mean = sd.get("task.mean", sd.get("mean", torch.zeros(1))).to(solver.model.device)

    if not hasattr(solver.model, 'atom_vocab'):
        solver.model.atom_vocab = cfg.atom_vocab
    if not hasattr(solver.model, 'node_feature'):
        solver.model.node_feature = cfg.node_feature

    object_dict = {"cfg": cfg, "solver": solver}

    log.info("Logging hyperparameters!")
    log_hyperparameters(object_dict)

    os.makedirs(cfg.output_directory, exist_ok=True)

    if is_rank_zero():
        config_path = os.path.join(cfg.output_directory, "config.yaml")
        with open(config_path, "w") as f:
            OmegaConf.save(config=cfg, f=f)
        log.info(f"Configuration saved to {config_path}")

    log.info("Running the predictions...")
    xyz_paths = glob(f"{cfg.xyz_directory}/*.xyz")
    xyz_paths = [str(xyz_path) for xyz_path in xyz_paths]
    predictions, xyz_paths_clear = _runner(solver, xyz_paths, max_atoms=cfg.get("max_atoms", 100))

    df_dicts = {}
    for task_name, prediction in zip(task_names, predictions.T):
        df_dicts[task_name] = prediction
    df_dicts["xyz_path"] = xyz_paths_clear

    df = pd.DataFrame(df_dicts)
    df = df.sort_values(by="xyz_path")
    df.to_csv(f"{cfg.output_directory}/predictions.csv", index=False)

    log.info("Prediction statistics:")
    for task_name in task_names:
        log.info(f"--- {task_name} ---")
        log.info(f"Mean: {df[task_name].mean():.4f}")
        log.info(f"Std: {df[task_name].std():.4f}")
        log.info(f"Min: {df[task_name].min():.4f}")
        log.info(f"Max: {df[task_name].max():.4f}")

    log.info("Plotting distributions...")
    props = []
    for prop in task_names:
        plot_kde_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_kde.png")
        plot_histogram_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_hist.png")
        props.append(df[prop].values)

    props = np.array(props).T
    plot_kde_distribution_multiple(props, task_names, f"{cfg.output_directory}/kde_all.png")


def log_hyperparameters(object_dict: dict):
    """Log hyperparameters for debugging."""
    if not is_rank_zero():
        return

    log.info("\n========== Logging Hyperparameters ==========\n")
    for name, obj in object_dict.items():
        log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
        if name == "cfg":
            if isinstance(obj, dict):
                log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
            else:
                log.info("\n" + OmegaConf.to_yaml(obj))
        else:
            if hasattr(obj, '__dict__'):
                for k, v in vars(obj).items():
                    if not k.startswith("_"):
                        log.info(f"{k}: {v}")
        log.info(f"{'=' * (44 + len(name))}\n")

    if "task" in object_dict and hasattr(object_dict["task"], "task"):
        model = object_dict["task"].task
        total = sum(p.numel() for p in model.parameters())
        trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
        log.info(f"{'=' * 20} MODEL PARAMS {'=' * 20}")
        log.info(f"model/params/total: {total}")
        log.info(f"model/params/trainable: {trainable}")
        log.info("=" * 54 + "\n")

    log.info("========== End of Hyperparameters ==========\n")


def predict_main(cfg: DictConfig):
    """Entry point for CLI predict command."""
    runner(cfg)
MolecularDiffusion/cli/train.py (ADDED)
@@ -0,0 +1,453 @@
"""Training command for MolCraft CLI.

Adapted from scripts/train.py for package-level execution.
"""

from typing import Any, Dict, Optional, Tuple
import os
import pickle
import logging

import hydra
import torch
from omegaconf import DictConfig, OmegaConf

from MolecularDiffusion.core import Engine
from MolecularDiffusion.runmodes.train import (
    evaluate,
    DataModule,
    Logger,
    OptimSchedulerFactory,
    get_versioned_output_path,
)
from MolecularDiffusion.utils import (
    RankedLogger,
    task_wrapper,
    seed_everything,
)

log = RankedLogger(__name__, rank_zero_only=True)


def is_rank_zero():
    """Check if current process is rank zero."""
    if torch.distributed.is_available() and torch.distributed.is_initialized():
        return torch.distributed.get_rank() == 0
    return True


def load_weights(task, ckpt_path):
    """Load model weights from a checkpoint file (weights only).

    This loads the state_dict from the checkpoint into the task model,
    ignoring optimizer/scheduler states and other metadata.
    Useful for fine-tuning or starting from a pre-trained model.
    """
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Checkpoint not found at: {ckpt_path}")

    log.info(f"Loading weights from: {ckpt_path}")

    # Load checkpoint
    checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    state_dict = checkpoint.get("state_dict", checkpoint)

    # Prepare state dict for loading
    cleaned_state_dict = {}
    for key, value in state_dict.items():
        # Strip 'task.' prefix if present (common in Lightning checkpoints)
        if key.startswith("task."):
            cleaned_state_dict[key[5:]] = value
        else:
            cleaned_state_dict[key] = value

    # Load into task
    missing, unexpected = task.load_state_dict(cleaned_state_dict, strict=False)

    if len(missing) > 0:
        log.warning(f"Missing keys when loading weights: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if len(unexpected) > 0:
        log.warning(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    log.info(f"Successfully loaded {len(cleaned_state_dict)} parameters into task.")
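The `task.` prefix handling in load_weights() is the usual Lightning-to-bare-module key translation: a LightningModule that stores the network under `self.task` saves keys like `task.encoder.weight`, which must be renamed before loading into the bare module. A toy illustration with made-up keys:

# Toy sketch of the key cleaning performed by load_weights() (keys are hypothetical).
state_dict = {"task.encoder.weight": 1, "head.bias": 2}
cleaned = {
    (key[5:] if key.startswith("task.") else key): value
    for key, value in state_dict.items()
}
assert cleaned == {"encoder.weight": 1, "head.bias": 2}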

# Lightning imports (optional)
try:
    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.callbacks import LearningRateMonitor
    from MolecularDiffusion.core.engine_lightning import EngineLightning
    from MolecularDiffusion.data.lightning_data_module import MolecularDiffusionDataModule
    from MolecularDiffusion.core.lightning_callbacks import GenerativeEvalCallback
    LIGHTNING_AVAILABLE = True
except ImportError as e:
    LIGHTNING_AVAILABLE = False
    log.warning(f"PyTorch Lightning not found: {e}. Only original Engine available.")


def engine_wrapper(task_module, data_module, trainer_module, logger_module,
                   resume_from_checkpoint=None, **kwargs):
    """Training loop using the original Engine."""
    trainer_module.get_optimizer()
    trainer_module.get_scheduler()

    solver = Engine(
        task_module.task,
        data_module.train_set,
        data_module.valid_set,
        data_module.test_set,
        batch_size=data_module.batch_size,
        collate_fn=data_module.collate_fn,
        optimizer=trainer_module.optimizer,
        ema_decay=trainer_module.ema_decay,
        scheduler=trainer_module.scheduler,
        clipping_gradient=trainer_module.gradient_clip_mode,
        clip_value=trainer_module.gradnorm_queue,
        logger=logger_module.logger,
        log_interval=logger_module.log_interval,
        name_wandb=logger_module.name_wandb,
        project_wandb=logger_module.project_wandb,
        dir_wandb=trainer_module.output_path,
    )

    # Resume from checkpoint if provided
    start_epoch = 0
    if resume_from_checkpoint:
        start_epoch = solver.resume(resume_from_checkpoint, strict=False)
        log.info(f"Resumed from epoch {start_epoch}")

    use_amp = trainer_module.precision in ["bf16", 16]

    best_checkpoints = []
    if hasattr(task_module.task, "sample") and kwargs.get("generative_analysis"):
        best_metrics = -torch.inf
        models_to_save = {"node": task_module.task.node_dist_model}
        if len(task_module.condition_names) > 0:
            models_to_save["prop"] = task_module.task.prop_dist_model
        if is_rank_zero():
            with open(os.path.join(trainer_module.output_path, "edm_stat.pkl"), "wb") as f:
                pickle.dump(models_to_save, f)
    else:
        best_metrics = torch.inf

    # Create versioned checkpoint folder (like Lightning's version_X folders)
    versioned_ckpt_path = get_versioned_output_path(trainer_module.output_path)

    # The loop continues from start_epoch when resuming
    for i in range(start_epoch, trainer_module.num_epochs):
        solver.train(num_epoch=1, use_amp=use_amp, precision=trainer_module.precision)
        if i % trainer_module.validation_interval == 0 or i == trainer_module.num_epochs - 1:
            if hasattr(task_module.task, "sample"):
                output_generated_dir = os.path.join(versioned_ckpt_path, "generated_molecules")
                os.makedirs(output_generated_dir, exist_ok=True)
                best_metrics, best_checkpoints = evaluate(
                    task_module.task_type, solver, i, best_metrics, best_checkpoints,
                    logger_module.logger, output_generated_dir=output_generated_dir,
                    generative_analysis=kwargs.get("generative_analysis", False),
                    n_samples=kwargs.get("n_samples", 100),
                    metric=kwargs.get("metric", "Validity Relax and connected"),
                    output_path=versioned_ckpt_path,
                    use_amp=use_amp, precision=trainer_module.precision,
                    use_posebuster=kwargs.get("use_posebuster", False),
                    batch_size=kwargs.get("batch_size", 1),
                    save_top_k=getattr(trainer_module, "save_top_k", 3),
                    save_every_val_epoch=getattr(trainer_module, "save_every_val_epoch", False),
                )
            else:
                best_metrics, best_checkpoints = evaluate(
                    task_module.task_type, solver, i, best_metrics, best_checkpoints,
                    logger_module.logger, output_path=versioned_ckpt_path,
                    save_top_k=getattr(trainer_module, "save_top_k", 3),
                    save_every_val_epoch=getattr(trainer_module, "save_every_val_epoch", False),
                )
    return best_metrics, solver


def lightning_wrapper(task_module, data_module, trainer_module, logger_module, engine_cfg,
                      ckpt_path=None, monitor_metric=None, monitor_mode=None, model_config=None, **kwargs):
    """Training using the PyTorch Lightning Trainer."""
    if not LIGHTNING_AVAILABLE:
        raise ImportError("PyTorch Lightning required. Install with: pip install pytorch-lightning")

    if hasattr(task_module.task, "preprocess"):
        log.info("Calling task.preprocess() for Lightning engine")
        result = task_module.task.preprocess(data_module.train_set)
        if result is not None:
            data_module.train_set, data_module.valid_set, data_module.test_set = result

    pl_data_module = MolecularDiffusionDataModule(
        data_module=data_module,
        batch_size=data_module.batch_size,
        num_workers=getattr(trainer_module, "num_worker", 0),
    )

    pl_module = EngineLightning(
        task=task_module.task,
        optimizer_config={
            "optimizer_choice": trainer_module.optimizer_choice,
            "lr": trainer_module.lr,
            "weight_decay": trainer_module.weight_decay,
            "betas": trainer_module.betas,
            "eps": trainer_module.eps,
        },
        scheduler_config={
            "scheduler": trainer_module.scheduler_choice,
            "scheduler_kwargs": trainer_module.scheduler_choice_kwargs,
        },
        model_config=model_config,
        monitor_metric=monitor_metric,
        ema_decay=trainer_module.ema_decay,
        gradnorm_queue=trainer_module.gradnorm_queue,
        gradient_clip_algorithm=getattr(trainer_module, 'gradient_clip_algorithm', 'adaptive'),
    )

    callbacks = []

    if hasattr(task_module.task, "sample") and kwargs.get("generative_analysis"):
        callbacks.append(GenerativeEvalCallback(
            n_samples=kwargs.get("n_samples", 100),
            batch_size=kwargs.get("batch_size", 100),
            metric=kwargs.get("metric", "Validity Relax and connected"),
            output_dir=os.path.join(trainer_module.output_path, "generated_molecules"),
            use_posebuster=kwargs.get("use_posebuster", False),
            monitor_metric=monitor_metric,
        ))

    # Checkpoint callback: handle OmegaConf ListConfig properly
    if monitor_metric is not None:
        # Convert OmegaConf types to plain Python types
        if OmegaConf.is_list(monitor_metric):
            monitor_metric_key = str(monitor_metric[0])
        elif isinstance(monitor_metric, (list, tuple)):
            monitor_metric_key = str(monitor_metric[0])
        else:
            monitor_metric_key = str(monitor_metric)
        mode = monitor_mode or ("min" if "loss" in monitor_metric_key else "max")
    elif hasattr(task_module.task, "sample"):
        monitor_metric_key = f"gen/{kwargs.get('metric', 'Validity Relax and connected')}"
        mode = "max"
    else:
        monitor_metric_key = "val/loss"
        mode = "min"

    # Handle save_every_val_epoch
    save_top_k = trainer_module.save_top_k
    if getattr(trainer_module, "save_every_val_epoch", False) or kwargs.get("save_every_val_epoch", False):
        log.info("save_every_val_epoch=True: Overriding save_top_k to -1 (save all checkpoints)")
        save_top_k = -1

    callbacks.append(ModelCheckpoint(
        monitor=monitor_metric_key,
        mode=mode,
        save_top_k=save_top_k,
        filename=f"epoch={{epoch}}-{monitor_metric_key.replace('/', '_').replace(' ', '_')}={{{monitor_metric_key}:.3f}}",
        save_last=True,
    ))

    # Learning rate monitor for wandb logging
    callbacks.append(LearningRateMonitor(logging_interval='step'))

    trainer_config = OmegaConf.to_container(engine_cfg.trainer_config, resolve=True)
    precision_map = {32: 32, 16: "16-mixed", "16": "16-mixed", "bf16": "bf16-mixed"}
    trainer_config["precision"] = precision_map.get(trainer_config.get("precision", 32), 32)

    if logger_module.logger == "wandb":
        pl_logger = pl.loggers.WandbLogger(
            project=logger_module.project_wandb,
            name=logger_module.name_wandb,
            save_dir=trainer_module.output_path,
        )
    else:
        pl_logger = True

    trainer = hydra.utils.instantiate(trainer_config, callbacks=callbacks, logger=pl_logger)

    if ckpt_path:
        trainer.fit(pl_module, datamodule=pl_data_module, ckpt_path=ckpt_path)
    else:
        trainer.fit(pl_module, datamodule=pl_data_module)

    return trainer.callback_metrics, trainer


@task_wrapper
def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
    """Main training function."""
    output_path = cfg.trainer.output_path
    os.makedirs(output_path, exist_ok=True)

    if is_rank_zero():
        config_path = os.path.join(output_path, "config.yaml")
        with open(config_path, "w") as f:
            OmegaConf.save(config=cfg, f=f)
        log.info(f"Configuration saved to {config_path}")

    if cfg.get("seed"):
        seed_everything(cfg.seed, workers=True)

    log.info(f"Instantiating datamodule <{cfg.data._target_}>")
    data_module: DataModule = hydra.utils.instantiate(cfg.data, task_type=cfg.tasks.task_type)
    data_module.load()

    log.info(f"Instantiating task <{cfg.tasks._target_}>")
    data_point_chk = data_module.train_set[0]
    node_feature_0 = getattr(data_point_chk, "node_feature", None)
    if node_feature_0 is not None:
        n_dim = node_feature_0.shape[1]
    else:
        try:
            node_feature_0 = getattr(data_point_chk, "x", None)
            n_dim = node_feature_0.shape[1]
        except Exception:
            n_dim = 0

    factory_cfg = cfg.tasks
    overrides = {}

    if ("tasks_egt" in factory_cfg._target_
            or "tasks_esen" in factory_cfg._target_
            or "diffusion_tabasco" in factory_cfg._target_):
        overrides["train_set"] = data_module.train_set
        if "condition_names" in factory_cfg:
            overrides["task_names"] = factory_cfg.condition_names

    if "atom_vocab" in cfg.data:
        overrides["atom_vocab"] = list(cfg.data.atom_vocab)

        if cfg.data.get("allow_unknown", False):
            overrides["atom_vocab"].append("Suisei")

    if cfg.tasks.get("metrics", None) == "valid_posebuster":
        overrides["use_posebuster"] = True
        try:
            import posebusters
        except ImportError:
            log.warning("PoseBuster not installed. Falling back to 'Validity Relax and connected'.")
            overrides["use_posebuster"] = False
            overrides["metrics"] = ["Validity Relax and connected"]

    task_module = hydra.utils.instantiate(factory_cfg, **overrides)
    task_module.build()

    # Optional: load weights from a checkpoint (without resuming full state)
    if cfg.trainer.get("load_weights_from"):
        load_weights(task_module.task, cfg.trainer.load_weights_from)

    log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
    trainer_module: OptimSchedulerFactory = hydra.utils.instantiate(
        cfg.trainer, parameters=task_module.task.parameters()
    )
|
| 342 |
+
|
| 343 |
+
name_wandb = trainer_module.output_path.split('/')[-1] if "/" in trainer_module.output_path else trainer_module.output_path
|
| 344 |
+
log.info(f"Instantiating loggers... <{cfg.logger._target_}>")
|
| 345 |
+
logger_module: Logger = hydra.utils.instantiate(cfg.logger, name_wandb=name_wandb)
|
| 346 |
+
|
| 347 |
+
object_dict = {
|
| 348 |
+
"cfg": cfg,
|
| 349 |
+
"datamodule": data_module,
|
| 350 |
+
"task": task_module,
|
| 351 |
+
"trainer": trainer_module,
|
| 352 |
+
"logger": logger_module,
|
| 353 |
+
}
|
| 354 |
+
|
| 355 |
+
log.info("Logging hyperparameters!")
|
| 356 |
+
log_hyperparameters(object_dict)
|
| 357 |
+
|
| 358 |
+
engine_type = cfg.get("engine", {}).get("engine_type", "original")
|
| 359 |
+
log.info(f"Using engine: {engine_type}")
|
| 360 |
+
|
| 361 |
+
if engine_type == "lightning":
|
| 362 |
+
gen_analysis = cfg.get("generative_analysis", cfg.tasks.get("generative_analysis", False))
|
| 363 |
+
n_samples = cfg.get("n_samples", cfg.tasks.get("n_samples", 100))
|
| 364 |
+
metric = cfg.get("metrics", cfg.get("metric", cfg.tasks.get("metrics", "Validity Relax and connected")))
|
| 365 |
+
use_posebuster = cfg.get("use_posebuster", cfg.tasks.get("use_posebuster", False))
|
| 366 |
+
gen_batch_size = cfg.get("batch_size", cfg.tasks.get("batch_size", 100))
|
| 367 |
+
|
| 368 |
+
# Always save model_config for checkpoint reconstruction (VAE, LDM, etc.)
|
| 369 |
+
model_config = OmegaConf.to_container(factory_cfg, resolve=True)
|
| 370 |
+
for k, v in overrides.items():
|
| 371 |
+
if k != "train_set":
|
| 372 |
+
model_config[k] = v
|
| 373 |
+
|
| 374 |
+
if hasattr(task_module.task, "sample"):
|
| 375 |
+
metrics = lightning_wrapper(
|
| 376 |
+
task_module, data_module, trainer_module, logger_module,
|
| 377 |
+
engine_cfg=cfg.engine,
|
| 378 |
+
generative_analysis=gen_analysis, n_samples=n_samples,
|
| 379 |
+
metric=metric, use_posebuster=use_posebuster, batch_size=gen_batch_size,
|
| 380 |
+
ckpt_path=cfg.trainer.get("resume_from_checkpoint", None),
|
| 381 |
+
monitor_metric=cfg.trainer.get("monitor_metric", None),
|
| 382 |
+
monitor_mode=cfg.trainer.get("monitor_mode", None),
|
| 383 |
+
model_config=model_config,
|
| 384 |
+
)
|
| 385 |
+
else:
|
| 386 |
+
metrics = lightning_wrapper(
|
| 387 |
+
task_module, data_module, trainer_module, logger_module,
|
| 388 |
+
engine_cfg=cfg.engine,
|
| 389 |
+
ckpt_path=cfg.trainer.get("resume_from_checkpoint", None),
|
| 390 |
+
monitor_metric=cfg.trainer.get("monitor_metric", None),
|
| 391 |
+
monitor_mode=cfg.trainer.get("monitor_mode", None),
|
| 392 |
+
model_config=model_config,
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
elif engine_type == "original":
|
| 396 |
+
resume_ckpt = cfg.trainer.get("resume_from_checkpoint", None)
|
| 397 |
+
if hasattr(task_module.task, "sample"):
|
| 398 |
+
metrics = engine_wrapper(
|
| 399 |
+
task_module, data_module, trainer_module, logger_module,
|
| 400 |
+
resume_from_checkpoint=resume_ckpt,
|
| 401 |
+
generative_analysis=cfg.tasks.generative_analysis,
|
| 402 |
+
n_samples=cfg.tasks.n_samples,
|
| 403 |
+
metric=cfg.tasks.metrics,
|
| 404 |
+
use_posebuster=cfg.tasks.use_posebuster,
|
| 405 |
+
batch_size=cfg.tasks.batch_size,
|
| 406 |
+
)
|
| 407 |
+
else:
|
| 408 |
+
metrics = engine_wrapper(
|
| 409 |
+
task_module, data_module, trainer_module, logger_module,
|
| 410 |
+
resume_from_checkpoint=resume_ckpt,
|
| 411 |
+
)
|
| 412 |
+
else:
|
| 413 |
+
raise ValueError(f"Unknown engine_type: {engine_type}")
|
| 414 |
+
|
| 415 |
+
return metrics, object_dict
|
| 416 |
+
|
| 417 |
+
|
| 418 |
+
def log_hyperparameters(object_dict: dict):
|
| 419 |
+
"""Log hyperparameters for debugging."""
|
| 420 |
+
if not is_rank_zero():
|
| 421 |
+
return
|
| 422 |
+
|
| 423 |
+
log.info("\n========== Logging Hyperparameters ==========\n")
|
| 424 |
+
for name, obj in object_dict.items():
|
| 425 |
+
log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
|
| 426 |
+
if name == "cfg":
|
| 427 |
+
if isinstance(obj, dict):
|
| 428 |
+
log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
|
| 429 |
+
else:
|
| 430 |
+
log.info("\n" + OmegaConf.to_yaml(obj))
|
| 431 |
+
else:
|
| 432 |
+
if hasattr(obj, '__dict__'):
|
| 433 |
+
for k, v in vars(obj).items():
|
| 434 |
+
if not k.startswith("_"):
|
| 435 |
+
log.info(f"{k}: {v}")
|
| 436 |
+
log.info(f"{'=' * (44 + len(name))}\n")
|
| 437 |
+
|
| 438 |
+
if "task" in object_dict and hasattr(object_dict["task"], "task"):
|
| 439 |
+
model = object_dict["task"].task
|
| 440 |
+
total = sum(p.numel() for p in model.parameters())
|
| 441 |
+
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 442 |
+
log.info(f"{'=' * 20} MODEL PARAMS {'=' * 20}")
|
| 443 |
+
log.info(f"model/params/total: {total}")
|
| 444 |
+
log.info(f"model/params/trainable: {trainable}")
|
| 445 |
+
log.info("=" * 54 + "\n")
|
| 446 |
+
|
| 447 |
+
log.info("========== End of Hyperparameters ==========\n")
|
| 448 |
+
|
| 449 |
+
|
| 450 |
+
def train_main(cfg: DictConfig):
|
| 451 |
+
"""Entry point for CLI train command."""
|
| 452 |
+
metric, _ = train(cfg)
|
| 453 |
+
return metric
|
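An aside on the `ModelCheckpoint` filename template in the wrapper above: the escaped braces in that f-string are easy to misread. This standalone illustration (the monitor key is hypothetical) shows what the template resolves to before Lightning substitutes the actual epoch and metric values:

```python
# Illustration only: doubled braces survive f-string evaluation as literal
# braces, which Lightning's ModelCheckpoint later fills in at save time.
monitor_metric_key = "val/loss"  # hypothetical monitor key
filename = (
    f"epoch={{epoch}}-"
    f"{monitor_metric_key.replace('/', '_').replace(' ', '_')}"
    f"={{{monitor_metric_key}:.3f}}"
)
print(filename)  # -> epoch={epoch}-val_loss={val/loss:.3f}
```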
MolecularDiffusion/configs/data/filter_molecules_by_property.py
ADDED
File without changes
MolecularDiffusion/configs/data/formed_data.yaml
ADDED
@@ -0,0 +1,20 @@
+_target_: MolecularDiffusion.runmodes.train.DataModule
+root: /home/pregabalin/RF/blue_edm/data/formed
+filename: /home/pregabalin/RF/blue_edm/data/formed/Data_FORMED_scored.csv # 4k or ready
+atom_vocab: [H,B,C,N,O,F,Al,Si,P,S,Cl,As,Se,Br,I,Hg,Bi]
+dataset_name: formed
+with_hydrogen: True
+node_feature: null # atom_topological, atom_geom, atom_geom_compact, atom_geom_opt
+max_atom: 120
+xyz_dir: /home/pregabalin/RF/blue_edm/data/formed/XYZ_FORMED/
+coord_file: null
+natoms_file: null
+forbidden_atom: []
+data_efficient_collator: True
+train_ratio: 0.8
+load_pkl: null
+save_pkl: data/test.pkl
+data_type: pyg # pyg or pointcloud
+batch_size: 32
+num_workers: 0
+allow_unknown: False # additional atom type for the unknown in OHE
MolecularDiffusion/configs/data/mol_dataset.yaml
ADDED
@@ -0,0 +1,25 @@
+_target_: MolecularDiffusion.runmodes.train.DataModule
+root: data/
+filename: path_to_csv.csv # 4k or ready
+atom_vocab: [H,B,C,N,O,F,Al,Si,P,S,Cl,As,Se,Br,I,Hg,Bi] # Ge,Sn,Te,Sb
+dataset_name: qm9
+with_hydrogen: True
+use_ohe_feature: True
+allow_unknown: False # True to add +1 "unknown" column to OHE for rare/unseen atoms
+node_feature_choice: null # atom_topological, atom_geom, atom_geom_compact, atom_geom_opt
+max_atom: 29
+xyz_dir: path_to_xyz
+coord_file: null
+natoms_file: null
+forbidden_atom: []
+data_efficient_collator: True
+train_ratio: 0.8
+load_pkl: null
+save_pkl: data/test.pkl # TODO: this is not really used anymore
+data_type: pointcloud # pyg or pointcloud
+batch_size: 48
+num_workers: 0
+edge_type: fully_connected
+radius: 4.0
+n_neigh: 5
+# consider_global_attributes: False # deprecated
MolecularDiffusion/configs/data/mol_dataset_extraf.yaml
ADDED
@@ -0,0 +1,23 @@
+_target_: MolecularDiffusion.runmodes.train.DataModule
+root: /home/pregabalin/RF/blue_edm/data/qm9
+filename: /home/pregabalin/RF/blue_edm/data/qm9/dsgdb9nsd_4k.csv # 4k or ready
+atom_vocab: [H,B,C,N,O,F,Al,Si,P,S,Cl,As,Se,Br,I,Hg,Bi]
+dataset_name: qm9
+with_hydrogen: True
+node_feature: atom_geom_compact # atom_topological, atom_geom, atom_geom_compact, atom_geom_opt
+max_atom: 29
+xyz_dir: /home/pregabalin/RF/blue_edm/data/qm9/dsgdb9nsd/
+coord_file: null
+natoms_file: null
+forbidden_atom: []
+data_efficient_collator: True
+train_ratio: 0.8
+load_pkl: null
+save_pkl: data/test.pkl
+data_type: pointcloud # pyg or pointcloud
+batch_size: 48
+num_workers: 0
+allow_unknown: False # additional atom type for the unknown in OHE
+edge_type: fully_connected
+radius: 4.0
+n_neigh: 5
MolecularDiffusion/configs/engine/lightning.yaml
ADDED
@@ -0,0 +1,33 @@
+# Use PyTorch Lightning Trainer
+engine_type: lightning
+
+# Lightning-specific trainer configuration
+trainer_config:
+  _target_: pytorch_lightning.Trainer
+
+  # Training
+  max_epochs: ${trainer.num_epochs}
+  accelerator: auto
+  devices: auto
+  strategy: auto # Lightning auto-selects ddp/ddp_spawn based on devices
+
+  # Precision - will be converted to Lightning format in Python
+  precision: ${trainer.precision}
+
+  # Optimization
+  accumulate_grad_batches: 1
+  gradient_clip_val: ${trainer.grad_clip_value}
+  gradient_clip_algorithm: ${trainer.gradient_clip_mode}
+
+  # Logging & Validation
+  log_every_n_steps: ${logger.log_interval}
+  check_val_every_n_epoch: ${trainer.validation_interval}
+
+  # Checkpointing
+  enable_checkpointing: true
+  default_root_dir: ${trainer.output_path}
+
+  # Other
+  num_sanity_val_steps: 0 # Skip sanity validation
+  enable_progress_bar: true
+  enable_model_summary: true
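The `precision` comment above refers to the small conversion table in the Lightning wrapper in `cli/train.py`. As a quick sketch of that mapping:

```python
# Mirrors the precision_map used by the Lightning wrapper: legacy int/str
# precision values become Lightning's mixed-precision strings; anything
# unrecognized falls back to full 32-bit precision.
precision_map = {32: 32, 16: "16-mixed", "16": "16-mixed", "bf16": "bf16-mixed"}

for raw in (32, 16, "bf16", "64"):
    print(raw, "->", precision_map.get(raw, 32))
# 32 -> 32, 16 -> 16-mixed, bf16 -> bf16-mixed, 64 -> 32 (fallback)
```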
MolecularDiffusion/configs/engine/original.yaml
ADDED
@@ -0,0 +1,4 @@
+# Use the original custom Engine class
+engine_type: original
+
+# Engine is instantiated inline in train.py using engine_wrapper()
MolecularDiffusion/configs/hydra/default.yaml
ADDED
@@ -0,0 +1,19 @@
+# https://hydra.cc/docs/configure_hydra/intro/
+
+# enable color logging
+# install hydra-colorlog==1.2.0
+defaults:
+  - override hydra_logging: colorlog
+  - override job_logging: colorlog
+
+# output directory, generated dynamically on each run
+run:
+  dir: ${trainer.output_path}
+  # dir: ${trainer.output_path}/${tasks.task_type}/runs/${name}_${now:%Y-%m-%d}_${now:%H-%M-%S}
+
+job_logging:
+  handlers:
+    file:
+      # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
+      filename: ${hydra.runtime.output_dir}/${name}_${now:%Y-%m-%d}_${now:%H-%M-%S}.log
+
MolecularDiffusion/configs/interference/gen_cfg.yaml
ADDED
@@ -0,0 +1,15 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: cfg
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [0,0]
+max_mol_size: 0
+target_values: [3,1.5]
+property_names: ["S1_exc", "T1_exc"]
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
+condition_configs:
+  cfg_scale: 1
+  cfg_scale_schedule: null
MolecularDiffusion/configs/interference/gen_cfggg.yaml
ADDED
@@ -0,0 +1,29 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: gradient_guidance # cfggg
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [0,0]
+max_mol_size: 0
+target_values: [3,1.5]
+property_names: ["S1_exc", "T1_exc"]
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
+condition_configs:
+  cfg_scale: 1
+  target_function:
+    _target_: scripts.gradient_guidance.sf_energy_score.SFEnergyScore
+    _partial_: true
+    chkpt_directory: trained_models/egcl_guidance_s1t1.ckpt
+  gg_scale: 1e-3
+  max_norm: 1e-3
+  scheduler:
+    _target_: scripts.gradient_guidance.scheduler.CosineAnnealing
+    _partial_: true
+    T_max: 1000
+    eta_min: 0
+  guidance_ver: 2
+  guidance_at: 1
+  guidance_stop: 0
+  n_backwards: 3
MolecularDiffusion/configs/interference/gen_conditional.yaml
ADDED
@@ -0,0 +1,12 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: conditional
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [0,0]
+max_mol_size: 0
+target_values: [3,1.5]
+property_names: ["S1_exc", "T1_exc"]
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
MolecularDiffusion/configs/interference/gen_gg.yaml
ADDED
@@ -0,0 +1,29 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: gradient_guidance # gg
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [0,0]
+max_mol_size: 0
+target_values: []
+property_names: []
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
+condition_configs:
+  cfg_scale: 0
+  target_function:
+    _target_: scripts.gradient_guidance.sf_energy_score.SFEnergyScore
+    _partial_: true
+    chkpt_directory: trained_models/egcl_guidance_s1t1.ckpt
+  gg_scale: 1e-3
+  max_norm: 1e-3
+  scheduler:
+    _target_: scripts.gradient_guidance.scheduler.CosineAnnealing
+    _partial_: true
+    T_max: 1000
+    eta_min: 0
+  guidance_ver: 2
+  guidance_at: 1
+  guidance_stop: 0
+  n_backwards: 0
MolecularDiffusion/configs/interference/gen_hybrid.yaml
ADDED
@@ -0,0 +1,28 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: inpaint_cfg
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [0,0]
+max_mol_size: 0
+target_values: [3,1.5]
+property_names: ["S1_exc", "T1_exc"]
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
+condition_configs:
+  cfg_scale: 1
+  reference_structure_path: "data/template_structures/INT2_0.xyz"
+  inpaint_cfgs:
+    t_start: 0.8
+    t_critical: 0.05
+
+# inpaint
+# denoising_strength: 0.7
+# noise_initial_mask: False
+# mask_node_index:
+#   - 5
+#   - 30
+#   - 31
+
+
MolecularDiffusion/configs/interference/gen_inpaint.yaml
ADDED
@@ -0,0 +1,69 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: inpaint
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [0,0]
+max_mol_size: 0
+target_values: []
+property_names: []
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
+condition_configs:
+  reference_structure_path: "data/template_structures/BINOLCpHHH.xyz"
+  condition_component: xh
+  inpaint_cfgs:
+    mask_node_index:
+      - 5
+      - 30
+      - 31
+      - 6
+      - 7
+      - 45
+      - 8
+      - 32
+      - 9
+      - 10
+      - 33
+      - 11
+      - 34
+      - 12
+      - 35
+      - 13
+      - 36
+      - 14
+      - 15
+      - 16
+      - 17
+      - 18
+      - 37
+      - 19
+      - 38
+      - 20
+      - 39
+      - 21
+      - 40
+      - 22
+      - 23
+      - 41
+      - 24
+      - 44
+      - 25
+      - 26
+      - 43
+      - 42
+    denoising_strength: 0.75
+    t_start: 0.8
+    t_critical_1: 0.8
+    t_critical_2: 1
+    d_threshold_f: 1.5
+    w_b: 10
+    all_frozen: True
+    use_covalent_radii: True
+    scale_factor: 1.2
+    noise_initial_mask: True
+    n_frames: 0
+    n_retrys: 0
+    t_retry: 180
+
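One plausible reading of `use_covalent_radii` together with `scale_factor` (the authoritative logic lives in the generate run mode, so treat this sketch as an assumption): a candidate bond between two atoms is accepted when their distance stays below the scaled sum of covalent radii.

```python
# Hypothetical sketch, NOT the repository's implementation: a distance
# threshold built from covalent radii and the scale_factor option above.
COVALENT_RADII = {"H": 0.31, "C": 0.76, "N": 0.71, "O": 0.66}  # Angstrom (abridged)

def bond_threshold(elem_i: str, elem_j: str, scale_factor: float = 1.2) -> float:
    """Scaled sum of covalent radii for an atom pair."""
    return scale_factor * (COVALENT_RADII[elem_i] + COVALENT_RADII[elem_j])

print(round(bond_threshold("C", "C"), 3))  # 1.824
```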
MolecularDiffusion/configs/interference/gen_outpaint.yaml
ADDED
@@ -0,0 +1,31 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: outpaint
+sampling_mode: ddpm
+num_generate: 100
+mol_size: [0, 0]
+max_mol_size: 0
+target_values: []
+property_names: []
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
+
+condition_configs:
+  reference_structure_path: data/template_structures/BINOLCp.xyz
+  condition_component: xh
+
+  outpaint_cfgs:
+    t_start: 0.8
+    t_critical_1: 0.7
+    t_critical_2: 0.4
+    d_threshold_f: 2
+    w_b: 0.1
+    all_frozen: false
+    use_covalent_radii: true
+    scale_factor: 1.1
+    noise_initial_mask: false
+    connector_dicts: {} # fill if needed, e.g. {0: [3]}
+
+  n_retrys: 3
+  t_retry: 180
MolecularDiffusion/configs/interference/gen_outpaintft.yaml
ADDED
@@ -0,0 +1,18 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: outpaintft
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [76,76]
+target_values: []
+property_names: []
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
+condition_configs:
+  reference_structure_path: "data/template_structures/INT2_0.xyz"
+  outpaint_cfgs:
+    t_start: 1
+    n_retrys: 0
+    t_retry: 180
+
MolecularDiffusion/configs/interference/gen_unconditional.yaml
ADDED
@@ -0,0 +1,11 @@
+_target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+task_type: unconditional
+sampling_mode: "ddpm"
+num_generate: 100
+mol_size: [16]
+target_values: []
+property_names: []
+batch_size: 1
+seed: 86
+n_frames: 0
+output_path: generated_mol
MolecularDiffusion/configs/interference/prediction.yaml
ADDED
@@ -0,0 +1,2 @@
+prop_names: ["S1_exc", "T1_exc"]
+hit_criteria: null
MolecularDiffusion/configs/logger/default.yaml
ADDED
@@ -0,0 +1,9 @@
+_target_: MolecularDiffusion.runmodes.train.Logger
+logger: logging # wandb, logging
+log_interval: 2
+name_wandb: MolecularDiffusion
+project_wandb: MolecularDiffusion
+dir_wandb: ${trainer.output_path}
+
+
+
MolecularDiffusion/configs/logger/wandb.yaml
ADDED
@@ -0,0 +1,9 @@
+_target_: MolecularDiffusion.runmodes.train.Logger
+logger: wandb # wandb, logging
+log_interval: 2
+name_wandb: MolecularDiffusion
+project_wandb: MolecularDiffusion
+dir_wandb: ${trainer.output_path}
+
+
+
MolecularDiffusion/configs/models/tabasco_transformer.yaml
ADDED
@@ -0,0 +1,72 @@
+# TABASCO Transformer model configuration
+# State-of-the-art non-equivariant flow matching model for molecules
+_target_: MolecularDiffusion.modules.tasks.diffusion_tabasco.TabascoDiffusionTask
+
+# Number of atom types from dataset vocabulary
+num_atom_types: 19 # Will be overridden by ${data.num_atom_types} at runtime
+
+# Transformer backbone configuration
+transformer_config:
+  _target_: MolecularDiffusion.modules.layers.tabasco.transformer_module.TransformerModule
+  spatial_dim: 3
+  atom_dim: 19 # Will be overridden by ${data.num_atom_types}
+  hidden_dim: 256
+  num_layers: 16
+  num_heads: 8
+  activation: SiLU
+  implementation: pytorch # or 'reimplemented'
+  cross_attention: true
+  add_sinusoid_posenc: true
+  concat_combine_input: false
+  custom_weight_init: null # or 'xavier', 'kaiming', etc.
+
+# Continuous coordinate interpolant configuration
+coords_interpolant_config:
+  _target_: MolecularDiffusion.modules.models.tabasco.flow.interpolate.SDEMetricInterpolant
+  key: coords
+  loss_weight: 1.0
+  centered: true
+  scale_noise_by_log_num_atoms: false
+  noise_scale: 1.0
+  # Langevin sampling schedule for SDE integration
+  langevin_sampling_schedule:
+    _target_: MolecularDiffusion.modules.models.tabasco.sample.noise_schedule.SampleNoiseSchedule
+    cutoff: 0.9
+    white_noise_sampling_scale: 0.01
+  # Time-dependent loss weighting
+  time_factor:
+    _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
+    max_value: 100.0
+    min_value: 0.05
+    zero_before: 0.0
+    eps: 1.0e-6
+
+# Discrete atom type interpolant configuration
+atomics_interpolant_config:
+  _target_: MolecularDiffusion.modules.models.tabasco.flow.interpolate.DiscreteInterpolant
+  key: atomics
+  loss_weight: 0.1
+  # Time-dependent loss weighting
+  time_factor:
+    _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
+    max_value: 100.0
+    min_value: 0.05
+    zero_before: 0.0
+    eps: 1.0e-6
+
+# Flow matching training configuration
+flow_matching_config:
+  _target_: MolecularDiffusion.modules.models.tabasco.flow_model.FlowMatchingModel
+  time_distribution:
+    _target_: MolecularDiffusion.modules.models.tabasco.flow.utils.HistogramTimeDistribution
+  time_alpha_factor: 1.8
+  num_random_augmentations: 7 # +1 original = 8 total
+  sample_schedule: log # or 'linear', 'power'
+  compile: false
+  interdist_loss: null
+
+# Dataset statistics (populated at runtime)
+dataset_stats:
+  max_atoms: 29 # Will be set from data config
+  atom_count_histogram: {} # Computed from dataset
+  all_smiles: [] # Collected from dataset
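The two `InverseTimeFactor` blocks above describe a time-dependent loss weight. A minimal sketch of the behavior the parameters suggest, assuming the weight is an inverse of the flow time clipped to `[min_value, max_value]` (the actual class lives at the `_target_` path above and may differ in detail):

```python
import numpy as np

def inverse_time_factor(t, max_value=100.0, min_value=0.05, zero_before=0.0, eps=1e-6):
    """Assumed behavior: weight ~ 1/(t + eps), clipped, zeroed for t < zero_before."""
    w = np.clip(1.0 / (t + eps), min_value, max_value)
    return np.where(t < zero_before, 0.0, w)

print(inverse_time_factor(np.array([0.0, 0.5, 1.0])))  # ~[100., 2., 1.]
```

This up-weights the loss near t = 0, where the interpolant is closest to pure noise, while `max_value` keeps the weight from blowing up.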
MolecularDiffusion/configs/tasks/diffusion.yaml
ADDED
@@ -0,0 +1,48 @@
+_target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+task_type: diffusion
+atom_vocab: ${data.atom_vocab}
+condition_names: []
+hidden_size: 192
+act_fn:
+  _target_: torch.nn.SiLU
+num_layers: 9
+attention: True
+tanh: True
+num_sublayers: 1
+sin_embedding: False
+aggregation_method: "sum"
+dropout: 0.0
+normalization: False
+include_cosine: True
+norm_constant: 1.0
+normalization_factor: 1.0
+chkpt_path: null
+
+# specific to diffusion
+diffusion_steps: 900
+diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+normalize_factors: [1,4,10]
+extra_norm_values: []
+augment_noise: False
+data_augmentation: False
+context_mask_rate: 0.2
+mask_value: 5
+normalize_condition: value_10 # [None, "maxmin", "mad"]
+sp_regularizer_deploy: False
+sp_regularizer_regularizer: hard
+sp_regularizer_lambda_: 0
+sp_regularizer_lambda_2: 1000
+sp_regularizer_lambda_update_value: 1
+sp_regularizer_lambda_update_step: 100
+sp_regularizer_polynomial_p: 1.1
+sp_regularizer_warm_up_steps: 100
+use_unknown_fallback: False
+reference_indices: null # indices of core atoms for the outpainting objective
+# evaluator parameters
+use_posebuster: True
+metrics: valid_posebuster # use_posebuster must be true
+n_samples: 48
+batch_size: 4
+generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_egt.yaml
ADDED
@@ -0,0 +1,54 @@
+_target_: MolecularDiffusion.runmodes.train.tasks_egt.ModelTaskFactory
+task_type: diffusion
+model_class: GraphTransformer
+atom_vocab: ${data.atom_vocab}
+condition_names: []
+hidden_dims:
+  dx: 256
+  de: 64
+  dy: 4
+  n_head: 4
+  dim_ffX: 256
+  dim_ffE: 64
+  dim_ffy: 1
+hidden_mlp_dims:
+  X: 256
+  E: 64
+  y: 256
+  pos: 512
+act_fn_in:
+  _target_: torch.nn.SiLU
+act_fn_out:
+  _target_: torch.nn.SiLU
+num_layers: 6
+dropout: 0.1
+chkpt_path: null
+
+
+# specific to diffusion
+diffusion_steps: 400
+diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+normalize_factors: [1,4,10]
+extra_norm_values: []
+augment_noise: False
+data_augmentation: False
+context_mask_rate: 0.2
+mask_value: 5
+normalize_condition: value_10 # [None, "maxmin", "mad"]
+sp_regularizer_deploy: False
+sp_regularizer_regularizer: hard
+sp_regularizer_lambda_: 0
+sp_regularizer_lambda_2: 1000
+sp_regularizer_lambda_update_value: 1
+sp_regularizer_lambda_update_step: 100
+sp_regularizer_polynomial_p: 1.1
+sp_regularizer_warm_up_steps: 100
+reference_indices: null # indices of core atoms for the outpainting objective
+# evaluator parameters
+use_posebuster: True
+metrics: valid_posebuster # use_posebuster must be true
+n_samples: 24
+generative_analysis: True
+batch_size: 4
MolecularDiffusion/configs/tasks/diffusion_extraf.yaml
ADDED
@@ -0,0 +1,47 @@
+_target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+task_type: diffusion
+atom_vocab: ${data.atom_vocab}
+condition_names: []
+hidden_size: 192
+act_fn:
+  _target_: torch.nn.SiLU
+num_layers: 1
+attention: True
+tanh: True
+num_sublayers: 12
+sin_embedding: False
+aggregation_method: "sum"
+dropout: 0.0
+normalization: False
+include_cosine: True
+norm_constant: 1.0
+normalization_factor: 1.0
+chkpt_path: null
+
+# specific to diffusion
+diffusion_steps: 400
+diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+normalize_factors: [1,4,10]
+extra_norm_values: [10,10]
+augment_noise: False
+data_augmentation: False
+context_mask_rate: 0.2
+mask_value: 5
+normalize_condition: value_10 # [None, "maxmin", "mad"]
+sp_regularizer_deploy: False
+sp_regularizer_regularizer: hard
+sp_regularizer_lambda_: 0
+sp_regularizer_lambda_2: 1000
+sp_regularizer_lambda_update_value: 1
+sp_regularizer_lambda_update_step: 100
+sp_regularizer_polynomial_p: 1.1
+sp_regularizer_warm_up_steps: 100
+reference_indices: null # indices of core atoms for the outpainting objective
+# evaluator parameters
+use_posebuster: True
+metrics: valid_posebuster # use_posebuster must be true
+n_samples: 24
+generative_analysis: True
+batch_size: 4
MolecularDiffusion/configs/tasks/diffusion_hybrid.yaml
ADDED
@@ -0,0 +1,95 @@
+_target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
+task_type: diffusion_hybrid
+
+# === Atom Vocabulary ===
+# Specify either atom_vocab directly OR use the one from the data config
+# Available base vocabularies: H, C, N, O, F, P, S, Cl, Br, I (common organic)
+# The number of classes is automatically determined from the vocab length
+atom_vocab: ${data.atom_vocab}
+# atom_vocab: ["C", "N", "O", "H", "F", "S", "Cl", "Br", "P", "I"] # Example custom
+
+condition_names: []
+
+# eSEN specific parameters
+hidden_size: 64
+hidden_channels: 64
+num_layers: 9
+lmax: 2
+mmax: 2
+grid_resolution: null
+cutoff: 30
+edge_channels: 128
+distance_function: "gaussian"
+num_distance_basis: 512
+norm_type: "rms_norm_sh"
+act_type: "s2"
+mlp_type: "grid"
+otf_graph: True
+use_envelope: False
+activation_checkpointing: False
+global_attributes: False
+sphere_embedding_type: "mixed" # DO NOT CHANGE
+aggregation_method: "sum"
+
+chkpt_path: null
+
+# === Continuous Diffusion Parameters ===
+diffusion_steps: 450
+diffusion_noise_schedule: polynomial_2 # Options: cosine, polynomial_2, polynomial_3, learned
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: l2 # Options: vlb, l2
+normalize_factors: [1, 1]
+extra_norm_values: []
+augment_noise: False
+data_augmentation: False
+context_mask_rate: 0.2
+mask_value: 5
+normalize_condition: value_10
+sp_regularizer_deploy: False
+sp_regularizer_regularizer: hard
+sp_regularizer_lambda_: 0
+sp_regularizer_lambda_2: 1000
+sp_regularizer_lambda_update_value: 1
+sp_regularizer_lambda_update_step: 100
+sp_regularizer_polynomial_p: 1.1
+sp_regularizer_warm_up_steps: 100
+reference_indices: null
+
+# === Discrete Diffusion Parameters (Atom Types) ===
+# Number of atom classes (automatically set from atom_vocab length if not specified)
+num_atom_classes: 19
+
+# Weight for discrete loss in combined loss: L_total = L_continuous + λ * L_discrete
+discrete_loss_weight: 0.2
+
+# Discrete masking schedule for absorbing-state diffusion
+# Each schedule controls how quickly tokens get masked during forward diffusion
+#
+# Available schedules:
+#   - "cosine"      : Smooth cosine decay (default, from improved DDPM)
+#   - "linear"      : Linear increase in masking probability
+#   - "sqrt"        : Square root schedule (faster initial masking)
+#   - "quadratic"   : Quadratic schedule (slower initial, faster later)
+#   - "cubic"       : Cubic schedule (even slower start than quadratic)
+#   - "sigmoid"     : S-curve transition (smooth start and end)
+#   - "exponential" : Exponential decay of survival probability
+#   - "log"         : Logarithmic schedule (fast early, slow late)
+#   - "uniform"     : Constant masking rate each step
+#
+discrete_schedule: "cosine"
+
+# MLP layers for atom classification head
+atom_head_mlp_layers: 2
+
+# === eSEN Dynamics specific ===
+use_adapter_module: False
+tanh: True
+coords_range: 10
+normalization_factor: 1.0
+
+# === Evaluator Parameters ===
+use_posebuster: True
+metrics: valid_posebuster
+n_samples: 96
+batch_size: 8
+generative_analysis: True
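For intuition about the `discrete_schedule` options documented above, here is a hedged sketch of how a few of the named schedules could map normalized time to a masking probability in absorbing-state diffusion (the repository's own schedule functions may differ in constants and parametrization):

```python
import math

def mask_probability(t: float, schedule: str = "cosine") -> float:
    """Fraction of tokens masked at normalized time t in [0, 1] (illustrative)."""
    if schedule == "cosine":
        return 1.0 - math.cos(0.5 * math.pi * t) ** 2  # smooth survival decay
    if schedule == "linear":
        return t
    if schedule == "sqrt":
        return math.sqrt(t)  # faster initial masking
    if schedule == "quadratic":
        return t ** 2  # slower initial masking
    raise ValueError(f"unknown schedule: {schedule}")

print([round(mask_probability(t), 3) for t in (0.0, 0.5, 1.0)])  # [0.0, 0.5, 1.0]
```

Whatever the exact shape, all schedules share the boundary conditions: no tokens masked at t = 0 and everything absorbed into the mask state at t = 1.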
MolecularDiffusion/configs/tasks/diffusion_hybrid_egcl.yaml
ADDED
@@ -0,0 +1,53 @@
+_target_: MolecularDiffusion.runmodes.train.tasks_egcl.ModelTaskFactory
+task_type: diffusion_hybrid
+
+# === Atom Vocabulary ===
+atom_vocab: ${data.atom_vocab}
+
+condition_names: []
+
+# === EGNN Parameters ===
+hidden_size: 192
+num_layers: 9
+attention: True
+norm_diff: True
+tanh: True
+coords_range: 15
+num_sublayers: 1
+sin_embedding: True
+include_cosine: False
+normalization_factor: 1.0
+aggregation_method: "sum"
+dropout: 0.0
+normalization: False
+
+chkpt_path: null
+
+# === Continuous Diffusion Parameters ===
+diffusion_steps: 900
+diffusion_noise_schedule: polynomial_2
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+normalize_factors: [1, 4]
+extra_norm_values: []
+augment_noise: False
+data_augmentation: False
+context_mask_rate: 0.0
+mask_value: 5.0
+normalize_condition: value_10
+sp_regularizer_deploy: False
+
+# === Discrete Diffusion Parameters (Atom Types) ===
+num_atom_classes: 19
+discrete_loss_weight: 0.2
+discrete_schedule: "cosine"
+
+# MLP layers for atom classification head
+atom_head_mlp_layers: 2
+
+# === Evaluator Parameters ===
+use_posebuster: True
+metrics: valid_posebuster
+n_samples: 48
+batch_size: 8
+generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_integer.yaml
ADDED
@@ -0,0 +1,62 @@
+_target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
+task_type: diffusion
+atom_vocab: ${data.atom_vocab}
+condition_names: []
+
+# eSEN specific parameters
+hidden_size: 256
+hidden_channels: 256
+num_layers: 4
+lmax: 2
+mmax: 2
+grid_resolution: null
+cutoff: 5.0
+edge_channels: 128
+distance_function: "gaussian"
+num_distance_basis: 512
+norm_type: "rms_norm_sh"
+act_type: "s2"
+mlp_type: "grid"
+otf_graph: True #!!
+use_envelope: False
+activation_checkpointing: False
+global_attributes: False
+sphere_embedding_type: "gaussian" #!!
+aggregation_method: "sum"
+
+chkpt_path: null
+
+# Diffusion kwargs
+diffusion_steps: 450
+diffusion_noise_schedule: polynomial_2
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+normalize_factors: [1, 1]
+extra_norm_values: []
+augment_noise: False
+data_augmentation: False
+context_mask_rate: 0.2
+mask_value: 5
+normalize_condition: value_10
+sp_regularizer_deploy: False
+sp_regularizer_regularizer: hard
+sp_regularizer_lambda_: 0
+sp_regularizer_lambda_2: 1000
+sp_regularizer_lambda_update_value: 1
+sp_regularizer_lambda_update_step: 100
+sp_regularizer_polynomial_p: 1.1
+sp_regularizer_warm_up_steps: 100
+reference_indices: null
+
+# eSEN_dynamics specific kwargs
+use_adapter_module: False
+tanh: True
+coords_range: 10
+normalization_factor: 1.0
+
+# Evaluator parameters
+use_posebuster: True
+metrics: valid_posebuster
+n_samples: 96
+batch_size: 8
+generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_pretrained.yaml
ADDED
@@ -0,0 +1,47 @@
+_target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+task_type: diffusion
+atom_vocab: ${data.atom_vocab}
+condition_names: []
+hidden_size: 192
+act_fn:
+  _target_: torch.nn.SiLU
+num_layers: 9
+attention: True
+tanh: True
+num_sublayers: 1
+sin_embedding: False
+aggregation_method: "sum"
+dropout: 0.0
+normalization: False
+include_cosine: True
+norm_constant: 1.0
+normalization_factor: 1.0
+chkpt_path: null
+
+# specific to diffusion
+diffusion_steps: 900
+diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+normalize_factors: [1,4,10]
+extra_norm_values: []
+augment_noise: False
+data_augmentation: False
+context_mask_rate: 0.2
+mask_value: 5
+normalize_condition: value_10 # [None, "maxmin", "mad"]
+sp_regularizer_deploy: False
+sp_regularizer_regularizer: hard
+sp_regularizer_lambda_: 0
+sp_regularizer_lambda_2: 1000
+sp_regularizer_lambda_update_value: 1
+sp_regularizer_lambda_update_step: 100
+sp_regularizer_polynomial_p: 1.1
+sp_regularizer_warm_up_steps: 100
+reference_indices: null # indices of core atoms for the outpainting objective
+# evaluator parameters
+use_posebuster: True
+metrics: valid_posebuster # use_posebuster must be true
+n_samples: 24
+generative_analysis: True
+batch_size: 4
MolecularDiffusion/configs/tasks/diffusion_pyg.yaml
ADDED
@@ -0,0 +1,82 @@
+_target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
+task_type: diffusion_pyg
+
+# === Atom Vocabulary ===
+atom_vocab: ${data.atom_vocab}
+
+condition_names: []
+
+# === eSEN Model Parameters ===
+hidden_size: 256
+hidden_channels: 32
+num_layers: 9
+lmax: 2
+mmax: 2
+grid_resolution: null
+cutoff: 15
+edge_channels: 128
+distance_function: "gaussian"
+num_distance_basis: 10
+norm_type: "rms_norm_sh"
+act_type: "s2"
+mlp_type: "grid"
+otf_graph: True
+use_envelope: False
+activation_checkpointing: False
+global_attributes: False
+
+# IMPORTANT: Use "gaussian" for float features during diffusion!
+# "gaussian" uses Gaussian smearing + MLP, fully float-compatible
+# Other options ("embedding", "mixed") require integer atomic_numbers
+sphere_embedding_type: "gaussian"
+
+aggregation_method: "sum"
+
+chkpt_path: null
+
+# === Continuous Diffusion Parameters ===
+# All features (positions, one-hot, integer) use continuous Gaussian diffusion
+diffusion_steps: 900
+diffusion_noise_schedule: polynomial_2 # Options: cosine, polynomial_2, polynomial_3, learned
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb # Options: vlb, l2
+
+# Normalization factors: [positions, categorical (one-hot), integer (atomic_numbers)]
+normalize_factors: [1.0, 4.0, 10.0]
+extra_norm_values: []
+
+# Data augmentation
+augment_noise: False
+data_augmentation: False
+
+# Context masking for classifier-free guidance
+context_mask_rate: 0.0
+mask_value: 0.0
+normalize_condition: null
+
+# Self-paced learning regularizer
+sp_regularizer_deploy: False
+sp_regularizer_regularizer: hard
+sp_regularizer_lambda_: 0
+sp_regularizer_lambda_2: 1000
+sp_regularizer_lambda_update_value: 1
+sp_regularizer_lambda_update_step: 100
+sp_regularizer_polynomial_p: 1.1
+sp_regularizer_warm_up_steps: 100
+
+# Outpainting/inpainting
+reference_indices: null
+use_unknown_fallback: False # Set to True when data.allow_unknown is True
+
+# === eSEN Dynamics Specific ===
+use_adapter_module: False
+tanh: True
+coords_range: 10
+normalization_factor: 1.0
+
+# === Evaluation Parameters ===
+use_posebuster: True
+metrics: valid_posebuster
+n_samples: 48
+batch_size: 8
+generative_analysis: True
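The comment above reads `normalize_factors` as `[positions, categorical (one-hot), integer (atomic_numbers)]`. A minimal sketch under the assumption that normalization is plain per-group division (the task classes hold the authoritative version, including any bias terms):

```python
import torch

def normalize_features(pos, one_hot, atomic_numbers, factors=(1.0, 4.0, 10.0)):
    """Assumed normalization: divide each feature group by its factor."""
    return pos / factors[0], one_hot / factors[1], atomic_numbers / factors[2]

pos = torch.zeros(3, 3)             # centered coordinates
one_hot = torch.eye(19)[[0, 2, 3]]  # H, C, N in a 19-type vocabulary
z = torch.tensor([1.0, 6.0, 7.0])   # atomic numbers
_, oh, z_n = normalize_features(pos, one_hot, z)
print(oh.max().item(), z_n.tolist())  # 0.25 [0.1, 0.6, 0.7]
```

Shrinking the one-hot and integer channels this way keeps all feature groups on a comparable scale against the unit-variance Gaussian noise used during diffusion.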
MolecularDiffusion/configs/tasks/diffusion_pyg_egcl.yaml
ADDED
@@ -0,0 +1,55 @@
+_target_: MolecularDiffusion.runmodes.train.tasks_egcl.ModelTaskFactory
+task_type: diffusion_pyg
+
+# === Atom Vocabulary ===
+atom_vocab: ${data.atom_vocab}
+
+condition_names: []
+
+# === EGNN Parameters ===
+hidden_size: 256
+num_layers: 9
+attention: True
+norm_diff: True
+tanh: True
+coords_range: 10
+num_sublayers: 1
+sin_embedding: False
+include_cosine: True
+normalization_factor: 1.0
+aggregation_method: "sum"
+dropout: 0.0
+normalization: False
+
+chkpt_path: null
+
+# === Continuous Diffusion Parameters ===
+# All features use continuous Gaussian diffusion (same as EnVariationalDiffusion)
+diffusion_steps: 900
+diffusion_noise_schedule: polynomial_2
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+
+# Normalization factors: [positions, categorical (one-hot), integer (atomic_numbers)]
+normalize_factors: [1.0, 4.0, 10.0]
+extra_norm_values: []
+
+# Data augmentation
+augment_noise: False
+data_augmentation: False
+
+# Context masking for classifier-free guidance
+context_mask_rate: 0.0
+mask_value: 0.0
+normalize_condition: null
+
+# Self-paced learning regularizer
+sp_regularizer_deploy: False
+use_unknown_fallback: False # Set to True when data.allow_unknown is True
+
+# === Evaluation Parameters ===
+use_posebuster: True
+metrics: valid_posebuster
+n_samples: 48
+batch_size: 8
+generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_pyg_egt.yaml
ADDED
@@ -0,0 +1,56 @@
+_target_: MolecularDiffusion.runmodes.train.tasks_egt.ModelTaskFactory
+task_type: diffusion_pyg
+
+# === Atom Vocabulary ===
+atom_vocab: ${data.atom_vocab}
+condition_names: []
+
+# === Graph Transformer Parameters ===
+model_class: GraphTransformerPyG
+hidden_dims:
+  dx: 256
+  de: 1
+  dy: 32
+  n_head: 4
+  dim_ffX: 256
+  dim_ffE: 1
+  dim_ffy: 32
+hidden_mlp_dims:
+  X: 256
+  E: 1
+  y: 32
+  pos: 512
+act_fn_in:
+  _target_: torch.nn.SiLU
+act_fn_out:
+  _target_: torch.nn.SiLU
+num_layers: 6
+dropout: 0.1
+chkpt_path: null
+
+# === Diffusion Parameters ===
+diffusion_steps: 900
+diffusion_noise_schedule: polynomial_2
+diffusion_noise_precision: 1e-5
+diffusion_loss_type: vlb
+normalize_factors: [1.0, 4.0, 10.0]
+extra_norm_values: []
+
+# Data augmentation
+augment_noise: False
+data_augmentation: False
+
+# Context masking for CFG
+context_mask_rate: 0.0
+mask_value: 0.0
+normalize_condition: null
+
+# Self-paced regularizer
+sp_regularizer_deploy: False
+
+# === Evaluation ===
+use_posebuster: True
+metrics: valid_posebuster
+n_samples: 48
+batch_size: 8
+generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_tabasco.yaml
ADDED
@@ -0,0 +1,66 @@
# TABASCO diffusion task configuration
# This config is referenced by: defaults: - override tasks: diffusion_tabasco

_target_: MolecularDiffusion.modules.tasks.diffusion_tabasco.ModelTaskFactory
task_type: diffusion_tabasco

# Automatically populated from dataset
num_atom_types: ???

# Transformer backbone configuration
transformer_config:
  spatial_dim: 3
  atom_dim: ???
  hidden_dim: 256
  num_layers: 16
  num_heads: 8
  activation: SiLU
  implementation: pytorch
  cross_attention: true
  add_sinusoid_posenc: true
  concat_combine_input: false
  custom_weight_init: null

# Continuous coordinate interpolant
coords_interpolant_config:
  key: coords
  loss_weight: 1.0
  centered: true
  scale_noise_by_log_num_atoms: false
  noise_scale: 1.0
  langevin_sampling_schedule:
    _target_: MolecularDiffusion.modules.models.tabasco.sample.noise_schedule.SampleNoiseSchedule
    cutoff: 0.9
    white_noise_sampling_scale: 0.01
  time_factor:
    _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
    max_value: 100.0
    min_value: 0.05
    zero_before: 0.0
    eps: 1.0e-6

# Discrete atom type interpolant
atomics_interpolant_config:
  key: atomics
  loss_weight: 0.1
  time_factor:
    _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
    max_value: 100.0
    min_value: 0.05
    zero_before: 0.0
    eps: 1.0e-6

# Flow matching configuration
flow_matching_config:
  time_distribution: beta
  time_alpha_factor: 1.8
  num_random_augmentations: 7
  sample_schedule: log
  compile: false
  interdist_loss: null

# Dataset statistics (populated at runtime)
dataset_stats:
  max_atoms: ???
  atom_count_histogram: {}
  all_smiles: []
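The `???` entries are OmegaConf's MISSING marker: composition succeeds, but reading them before the training entry point fills them in raises an error. A short sketch of that runtime population, with stand-in values in place of real dataset statistics:

# Sketch of the runtime population implied by the '???' (MISSING) fields.
from omegaconf import OmegaConf

cfg = OmegaConf.load("MolecularDiffusion/configs/tasks/diffusion_tabasco.yaml")
assert OmegaConf.is_missing(cfg, "num_atom_types")  # access before filling would raise

cfg.num_atom_types = 10                    # stand-in: size of the dataset's atom vocabulary
cfg.transformer_config.atom_dim = cfg.num_atom_types
cfg.dataset_stats.max_atoms = 64           # stand-in dataset statistic
print(OmegaConf.to_yaml(cfg))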
MolecularDiffusion/configs/tasks/guidance.yaml
ADDED
@@ -0,0 +1,40 @@
_target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
task_type: guidance
atom_vocab: ${data.atom_vocab}
condition_names: []
hidden_size: 512
act_fn:
  _target_: torch.nn.ReLU
num_layers: 1
attention: True
tanh: True
num_sublayers: 5
sin_embedding: False
aggregation_method: "sum"
dropout: 0.0
normalization: False
include_cosine: True
norm_constant: 1.0
normalization_factor: 1.0
chkpt_path: null

# Guidance-specific parameters
task_learn: [S1_exc, T1_exc]
criterion: mse
metric: [mae]
num_mlp_layer: 3
mlp_dropout: 0.2
mlp_batch_norm: True  # True/False for legacy mode, null/'layernorm'/'batchnorm' for new mode
prediction_mlp_type: legacy  # 'legacy' (backward compat), 'pernode', or 'padded'
prediction_activation: relu  # 'relu' or 'silu'
diffusion_steps: 900
diffusion_noise_precision: 1e-5
nu_arr: [2, 2, 2]
mapping: ["pos", "categorical", "integer"]
weight_classes: null
norm_values: [1, 4, 10]
t_max: 0.7
loss_weighting: linear
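`t_max: 0.7` with `loss_weighting: linear` suggests the guidance head is trained only up to a cutoff in normalized diffusion time, with noisier samples down-weighted linearly. The exact scheme lives in `ModelTaskFactory_EGCL`; the sketch below is one plausible reading, not the project's implementation:

import torch

def guidance_loss_weight(t: torch.Tensor, t_max: float = 0.7) -> torch.Tensor:
    # Hypothetical linear weighting: full weight at t=0, zero at and beyond t_max.
    return torch.clamp(1.0 - t / t_max, min=0.0)

t = torch.rand(8)                 # normalized diffusion times of a batch
per_sample_loss = torch.rand(8)   # stand-in per-sample MSE on [S1_exc, T1_exc]
loss = (guidance_loss_weight(t) * per_sample_loss).mean()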
MolecularDiffusion/configs/tasks/guidance_esen.yaml
ADDED
@@ -0,0 +1,43 @@
# Uses the existing ModelTaskFactory from tasks_esen.py with task_type: guidance
_target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
task_type: guidance
atom_vocab: ${data.atom_vocab}
condition_names: []

# eSEN Backbone parameters
sphere_channels: 128
hidden_channels: 128
lmax: 2
mmax: 2
num_layers: 4
edge_channels: 128
distance_function: "gaussian"
num_distance_basis: 512
cutoff: 5.0
max_neighbors: 300
norm_type: "rms_norm_sh"
act_type: "s2"
mlp_type: "grid"

# CRITICAL: Use "mlp" or "gaussian" for differentiable gradients
sphere_embedding_type: "mlp"
# in_node_channels is computed by factory: len(atom_vocab) + n_extra + 1 (charge) + 1 (time)

# Guidance-specific parameters
task_learn: [S1_exc, T1_exc]
criterion: mse
metric: [mae]
num_mlp_layer: 3
mlp_dropout: 0.2
mlp_batch_norm: True  # True/False for legacy mode, null/'layernorm'/'batchnorm' for new mode
prediction_mlp_type: legacy  # 'legacy' (backward compat), 'pernode', or 'padded'
prediction_activation: relu  # 'relu' or 'silu'
diffusion_steps: 600
diffusion_noise_precision: 1e-5
nu_arr: [2, 2, 2]
mapping: ["pos", "categorical", "integer"]
weight_classes: null
norm_values: [1, 4, 10]
t_max: 0.7
loss_weighting: linear
normalization: False
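The `in_node_channels` comment fixes the input feature layout the factory assembles: one-hot atom types, optional extra features, plus one charge and one time channel. Spelled out with a hypothetical vocabulary and `n_extra`:

# Input width the factory derives, per the comment in the config:
# len(atom_vocab) + n_extra + 1 (charge) + 1 (time).
atom_vocab = ["H", "C", "N", "O", "F"]  # stand-in; real value comes from ${data.atom_vocab}
n_extra = 0                             # hypothetical count of extra per-node features
in_node_channels = len(atom_vocab) + n_extra + 1 + 1
print(in_node_channels)                 # -> 7 for this stand-in vocabulary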
MolecularDiffusion/configs/tasks/guidance_pc.yaml
ADDED
@@ -0,0 +1,43 @@
# Configuration for PointCloud-optimized EGCL Guidance Model
# Uses GuidanceModelPredictionPointCloud with dense_mode=True

_target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
task_type: guidance
atom_vocab: ${data.atom_vocab}
condition_names: []
hidden_size: 512
act_fn:
  _target_: torch.nn.ReLU
num_layers: 1
attention: True
tanh: True
num_sublayers: 5
sin_embedding: False
aggregation_method: "sum"
dropout: 0.0
normalization: False
include_cosine: True
norm_constant: 1.0
normalization_factor: 1.0
chkpt_path: null

# Enable dense mode for PointCloud inference
dense_mode: True

# Guidance-specific parameters
task_learn: [S1_exc, T1_exc]
criterion: mse
metric: [mae]
num_mlp_layer: 3
mlp_dropout: 0.2
mlp_batch_norm: True  # True/False for legacy mode, null/'layernorm'/'batchnorm' for new mode
prediction_mlp_type: legacy  # 'legacy' (backward compat), 'pernode', or 'padded'
prediction_activation: relu  # 'relu' or 'silu'
diffusion_steps: 900
diffusion_noise_precision: 1e-5
nu_arr: [2, 2, 2]
mapping: ["pos", "categorical", "integer"]
weight_classes: null
norm_values: [1, 4, 10]
t_max: 0.7
loss_weighting: linear
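Any of these task configs is selected with a single override on the `tasks` group, and individual fields can be overridden the same way. A sketch via Hydra's compose API (the config-module path and root config name `config` are assumptions about this repository's layout):

# Sketch using Hydra's compose API; adjust names to the repo's actual setup.
from hydra import compose, initialize_config_module

with initialize_config_module(config_module="MolecularDiffusion.configs", version_base=None):
    cfg = compose(config_name="config",  # hypothetical root config name
                  overrides=["tasks=guidance_pc", "tasks.dense_mode=False"])
print(cfg.tasks.task_type)  # -> guidance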
MolecularDiffusion/configs/tasks/ldm_dit.yaml
ADDED
@@ -0,0 +1,24 @@
# Latent Diffusion with DiT denoiser
_target_: MolecularDiffusion.modules.tasks.diffusion_ldm.LDMTaskFactory
task_type: ldm_dit
_recursive_: False
autoencoder_ckpt: ???  # Required: path to pre-trained VAE

denoiser:
  _target_: MolecularDiffusion.modules.models.ldm.denoisers.dit.DiT
  # d_x is auto-inferred from VAE latent_dim
  d_model: 384
  num_layers: 12
  nhead: 6
  class_dropout_prob: 0.1

interpolant:
  type: flow_matching
  min_t: 0.01
  corrupt: true
  num_timesteps: 100
  self_condition: false
  self_condition_prob: 0.5

# Data augmentation
augment_rotation: true
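`interpolant.type: flow_matching` with `min_t: 0.01` points at the standard linear interpolant between Gaussian noise and clean VAE latents, with t kept away from 0. A generic sketch of that corruption step (illustrative flow matching, not necessarily the exact code path in `diffusion_ldm`):

import torch

def corrupt(x1: torch.Tensor, min_t: float = 0.01):
    # Standard linear flow-matching interpolant: x_t = (1 - t) * x0 + t * x1.
    x0 = torch.randn_like(x1)                        # Gaussian source
    t = min_t + (1.0 - min_t) * torch.rand(x1.shape[0], 1, 1)
    xt = (1.0 - t) * x0 + t * x1
    return xt, t, x1 - x0                            # velocity target for the DiT

xt, t, v = corrupt(torch.randn(4, 16, 8))            # (batch, atoms, latent_dim) stand-in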
MolecularDiffusion/configs/tasks/regression.yaml
ADDED
@@ -0,0 +1,30 @@
_target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
task_type: regression
atom_vocab: ${data.atom_vocab}
condition_names: []
hidden_size: 512
act_fn:
  _target_: torch.nn.ReLU
num_layers: 1
attention: True
tanh: True
num_sublayers: 5
sin_embedding: False
aggregation_method: "sum"
dropout: 0.0
normalization: False  # For EGNN backbone layer norm
include_cosine: True
norm_constant: 1.0
normalization_factor: 1.0
chkpt_path: null

# specific to regression
task_learn: [S1_exc, T1_exc]
criterion: mse
metric: [mae]
num_mlp_layer: 3
mlp_batch_norm: batchnorm  # Options: null, layernorm, batchnorm
target_normalization: True  # Normalize targets by mean/std in loss
mlp_dropout: 0.2
prediction_mlp_type: "pernode"
prediction_activation: "relu"
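`target_normalization: True` means the loss is computed on mean/std-standardized targets, so predictions must be mapped back to physical units before reporting MAE. A minimal sketch of the pattern, with stand-in statistics:

import torch
import torch.nn.functional as F

# Stand-in per-task statistics; in practice computed over the training targets.
target_mean = torch.tensor([2.1, 1.8])   # [S1_exc, T1_exc]
target_std = torch.tensor([0.4, 0.35])

def normalized_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # Loss in standardized target space, per `criterion: mse` above.
    return F.mse_loss(pred, (target - target_mean) / target_std)

def denormalize(pred: torch.Tensor) -> torch.Tensor:
    # Map predictions back to physical units for MAE-style metrics.
    return pred * target_std + target_mean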
MolecularDiffusion/configs/tasks/regression_esen.yaml
ADDED
@@ -0,0 +1,34 @@
_target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_ESEN
task_type: regression
atom_vocab: ${data.atom_vocab}
# condition_names: []
hidden_size: 256
hidden_channels: 256
num_layers: 4
lmax: 2
mmax: 2
grid_resolution: null
cutoff: 5.0
edge_channels: 128
distance_function: "gaussian"
num_distance_basis: 512
norm_type: "rms_norm_sh"
act_type: "s2"
mlp_type: "grid"
use_envelope: False
activation_checkpointing: False
global_attributes: False
sphere_embedding_type: "mixed"
aggregation_method: mean
chkpt_path: null

# specific to regression
task_learn: [S1_exc, T1_exc]
criterion: mse
metric: [mae]
num_mlp_layer: 3
mlp_dropout: 0.2
mlp_batch_norm: batchnorm  # Options: null, layernorm, batchnorm
target_normalization: True  # Normalize targets by mean/std in loss
prediction_mlp_type: "pernode"
prediction_activation: "relu"