iflp1908sl committed
Commit 005e4d3 · Parent(s): 777c843

Add source code (clean)

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +9 -0
  2. MolecularDiffusion/__init__.py +26 -0
  3. MolecularDiffusion/_version.py +21 -0
  4. MolecularDiffusion/callbacks/__init__.py +13 -0
  5. MolecularDiffusion/callbacks/train_helper.py +259 -0
  6. MolecularDiffusion/cli/__init__.py +6 -0
  7. MolecularDiffusion/cli/_hydra.py +129 -0
  8. MolecularDiffusion/cli/analyze.py +380 -0
  9. MolecularDiffusion/cli/eval_predict.py +259 -0
  10. MolecularDiffusion/cli/generate.py +282 -0
  11. MolecularDiffusion/cli/main.py +197 -0
  12. MolecularDiffusion/cli/predict.py +395 -0
  13. MolecularDiffusion/cli/train.py +453 -0
  14. MolecularDiffusion/configs/data/filter_molecules_by_property.py +0 -0
  15. MolecularDiffusion/configs/data/formed_data.yaml +20 -0
  16. MolecularDiffusion/configs/data/mol_dataset.yaml +25 -0
  17. MolecularDiffusion/configs/data/mol_dataset_extraf.yaml +23 -0
  18. MolecularDiffusion/configs/engine/lightning.yaml +33 -0
  19. MolecularDiffusion/configs/engine/original.yaml +4 -0
  20. MolecularDiffusion/configs/hydra/default.yaml +19 -0
  21. MolecularDiffusion/configs/interference/gen_cfg.yaml +15 -0
  22. MolecularDiffusion/configs/interference/gen_cfggg.yaml +29 -0
  23. MolecularDiffusion/configs/interference/gen_conditional.yaml +12 -0
  24. MolecularDiffusion/configs/interference/gen_gg.yaml +29 -0
  25. MolecularDiffusion/configs/interference/gen_hybrid.yaml +28 -0
  26. MolecularDiffusion/configs/interference/gen_inpaint.yaml +69 -0
  27. MolecularDiffusion/configs/interference/gen_outpaint.yaml +31 -0
  28. MolecularDiffusion/configs/interference/gen_outpaintft.yaml +18 -0
  29. MolecularDiffusion/configs/interference/gen_unconditional.yaml +11 -0
  30. MolecularDiffusion/configs/interference/prediction.yaml +2 -0
  31. MolecularDiffusion/configs/logger/default.yaml +9 -0
  32. MolecularDiffusion/configs/logger/wandb.yaml +9 -0
  33. MolecularDiffusion/configs/models/tabasco_transformer.yaml +72 -0
  34. MolecularDiffusion/configs/tasks/diffusion.yaml +48 -0
  35. MolecularDiffusion/configs/tasks/diffusion_egt.yaml +54 -0
  36. MolecularDiffusion/configs/tasks/diffusion_extraf.yaml +47 -0
  37. MolecularDiffusion/configs/tasks/diffusion_hybrid.yaml +95 -0
  38. MolecularDiffusion/configs/tasks/diffusion_hybrid_egcl.yaml +53 -0
  39. MolecularDiffusion/configs/tasks/diffusion_integer.yaml +62 -0
  40. MolecularDiffusion/configs/tasks/diffusion_pretrained.yaml +47 -0
  41. MolecularDiffusion/configs/tasks/diffusion_pyg.yaml +82 -0
  42. MolecularDiffusion/configs/tasks/diffusion_pyg_egcl.yaml +55 -0
  43. MolecularDiffusion/configs/tasks/diffusion_pyg_egt.yaml +56 -0
  44. MolecularDiffusion/configs/tasks/diffusion_tabasco.yaml +66 -0
  45. MolecularDiffusion/configs/tasks/guidance.yaml +40 -0
  46. MolecularDiffusion/configs/tasks/guidance_esen.yaml +43 -0
  47. MolecularDiffusion/configs/tasks/guidance_pc.yaml +43 -0
  48. MolecularDiffusion/configs/tasks/ldm_dit.yaml +24 -0
  49. MolecularDiffusion/configs/tasks/regression.yaml +30 -0
  50. MolecularDiffusion/configs/tasks/regression_esen.yaml +34 -0
.gitignore ADDED
@@ -0,0 +1,9 @@
+ __pycache__/
+ *.pyc
+ *.pyo
+ *.pyd
+ .DS_Store
+ .env
+ .venv
+ env/
+ venv/
MolecularDiffusion/__init__.py ADDED
@@ -0,0 +1,26 @@
+ """
+ MolecularDiffusion - A molecular diffusion framework.
+
+ This package provides tools and models for molecular diffusion processes.
+ """
+
+ __version__ = "0.1.0"
+ __author__ = "Thanapat Worakul"
+ __email__ = "thanapat.worakul@epfl.ch"
+
+ # Import main modules to make them available at package level
+ from . import core
+ from . import data
+ from . import modules
+ from . import utils
+ from . import callbacks
+ from . import runmodes
+
+ __all__ = [
+     "core",
+     "data",
+     "modules",
+     "utils",
+     "callbacks",
+     "runmodes"
+ ]
MolecularDiffusion/_version.py ADDED
@@ -0,0 +1,21 @@
+ # file generated by setuptools-scm
+ # don't change, don't track in version control
+
+ __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
+
+ TYPE_CHECKING = False
+ if TYPE_CHECKING:
+     from typing import Tuple
+     from typing import Union
+
+     VERSION_TUPLE = Tuple[Union[int, str], ...]
+ else:
+     VERSION_TUPLE = object
+
+ version: str
+ __version__: str
+ __version_tuple__: VERSION_TUPLE
+ version_tuple: VERSION_TUPLE
+
+ __version__ = version = '0.1.dev26+gff3c644.d20250809'
+ __version_tuple__ = version_tuple = (0, 1, 'dev26', 'gff3c644.d20250809')
MolecularDiffusion/callbacks/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from .train_helper import (
+     Queue,
+     gradient_clipping,
+     EMA,
+     SP_regularizer
+ )
+
+ __all__ = [
+     "Queue",
+     "gradient_clipping",
+     "EMA",
+     "SP_regularizer"
+ ]
MolecularDiffusion/callbacks/train_helper.py ADDED
@@ -0,0 +1,259 @@
+ import numpy as np
+ import torch
+ import logging
+
+ logger = logging.getLogger(__name__)
+ logger.setLevel(logging.CRITICAL)
+
+
+ class Queue:
+     def __init__(self, max_len=50):
+         self.items = []
+         self.max_len = max_len
+
+     def __len__(self):
+         return len(self.items)
+
+     def add(self, item):
+         self.items.insert(0, item)
+         if len(self) > self.max_len:
+             self.items.pop()
+
+     def mean(self):
+         return np.mean(self.items)
+
+     def std(self):
+         return np.std(self.items)
+
+
+ class gradient_clipping:
+     def __init__(self, m=1, max_len=200):
+         self.max_grad_norm = None
+         self.max_grad_norms = []
+         self.max_len = max_len
+         self.m = m
+         self.FACTOR = 100
+
+     def __call__(self, model, gradnorm_queue):
+         self.max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()
+         if len(self.max_grad_norms) == 0:
+             self.max_grad_norms.append(self.max_grad_norm)
+         else:
+             # max_grad_norm_mean = torch.mean(torch.tensor(self.max_grad_norms))
+             previous_max_grad_norm = self.max_grad_norms[-1]
+             # if the current max_grad_norm exceeds the previous one, cap it
+             if self.max_grad_norm > previous_max_grad_norm:
+                 self.max_grad_norm = previous_max_grad_norm * self.m
+             if self.max_grad_norm > previous_max_grad_norm * 1e5:
+                 self.max_grad_norm = previous_max_grad_norm * self.m / self.FACTOR
+
+             self.max_grad_norms.append(self.max_grad_norm)
+
+         if len(self.max_grad_norms) > self.max_len:
+             self.max_grad_norms.pop(0)
+
+         # Clip gradients and return the norm
+         grad_norm = torch.nn.utils.clip_grad_norm_(
+             model.parameters(), max_norm=self.max_grad_norm, norm_type=2.0
+         )
+         if float(grad_norm) > self.max_grad_norm:
+             gradnorm_queue.add(float(self.max_grad_norm))
+         else:
+             gradnorm_queue.add(float(grad_norm))
+
+         if float(grad_norm) > self.max_grad_norm:
+             logger.info(
+                 f"Clipped gradient with value {grad_norm:.1f} "
+                 f"while allowed {self.max_grad_norm:.1f}"
+             )
+         return grad_norm
+
+
+ class gradient_clipping_0:
+     def __init__(self, m=1, max_len=200):
+         self.max_grad_norm = None
+         self.max_grad_norms = []
+         self.max_len = max_len
+         self.m = m
+
+     def __call__(self, model, gradnorm_queue):
+         self.max_grad_norm = 1.5 * gradnorm_queue.mean() + 2 * gradnorm_queue.std()
+         if len(self.max_grad_norms) == 0:
+             self.max_grad_norms.append(self.max_grad_norm)
+         else:
+             max_grad_norm_mean = torch.mean(torch.tensor(self.max_grad_norms))
+             # if the current max_grad_norm exceeds the mean of the previous ones, cap it
+             if self.max_grad_norm > max_grad_norm_mean:
+                 self.max_grad_norm = max_grad_norm_mean * self.m
+             if self.max_grad_norm > max_grad_norm_mean * 1e5:
+                 self.max_grad_norm = max_grad_norm_mean * self.m / 10
+             self.max_grad_norms.append(self.max_grad_norm)
+
+         if len(self.max_grad_norms) > self.max_len:
+             self.max_grad_norms.pop(0)
+
+         # Clip gradients and return the norm
+         grad_norm = torch.nn.utils.clip_grad_norm_(
+             model.parameters(), max_norm=self.max_grad_norm, norm_type=2.0
+         )
+         if float(grad_norm) > self.max_grad_norm:
+             gradnorm_queue.add(float(self.max_grad_norm))
+         else:
+             gradnorm_queue.add(float(grad_norm))
+
+         if float(grad_norm) > self.max_grad_norm:
+             print(
+                 f"Clipped gradient with value {grad_norm:.1f} "
+                 f"while allowed {self.max_grad_norm:.1f}"
+             )
+         return grad_norm
+
+
+ class EMA:
+     def __init__(self, beta):
+         super().__init__()
+         self.beta = beta
+
+     def update_model_average(self, ma_model, current_model):
+         for current_params, ma_params in zip(
+             current_model.parameters(), ma_model.parameters()
+         ):
+             old_weight, up_weight = ma_params.data, current_params.data
+             ma_params.data = self.update_average(old_weight, up_weight)
+
+     def update_average(self, old, new):
+         if old is None:
+             return new
+         return old * self.beta + (1 - self.beta) * new
+
+
+ class SP_regularizer:
+     def __init__(
+         self,
+         regularizer: str,
+         lambda_: float = 10,
+         lambda_2: float = 100,
+         lambda_update_value: float = 50,
+         lambda_update_step: int = 2500,
+         polynomial_p: float = 1.5,
+         warm_up_steps: int = 100,
+     ):
+         """
+         Self-paced regularizer for curriculum learning.
+
+         Args:
+             regularizer (str): Regularizer to use. Options are:
+                 - hard
+                 - linear
+                 - logaritmic
+                 - logistic
+             lambda_ (float): Initial lambda value
+             lambda_2 (float): Initial lambda value for the second regularizer
+             lambda_update_value (float): Value by which lambda is increased
+             lambda_update_step (int): Number of steps between lambda updates
+             polynomial_p (float): Value of p for the polynomial regularizer
+             warm_up_steps (int): Number of warm-up steps before the regularizer is applied
+         """
+         self.regularizer = regularizer
+         self.lambda_ = lambda_
+         self.lambda_2 = lambda_2
+         self.n_calls = 1
+         self.lambda_update_value = lambda_update_value
+         self.lambda_update_step = lambda_update_step
+         self.p = polynomial_p
+         self.warm_up_steps = warm_up_steps
+
+     def __call__(self, losses: torch.Tensor):
+         # TODO: during warm-up steps, keep the loss information to determine lambda
+         if self.n_calls < self.warm_up_steps:
+             self.n_calls += 1
+             return losses
+         else:
+             if self.regularizer == "hard":
+                 weighted_loss = self.hard(losses)
+             elif self.regularizer == "linear":
+                 weighted_loss = self.linear(losses)
+             elif self.regularizer == "logaritmic":
+                 weighted_loss = self.logaritmic(losses)
+             elif self.regularizer == "logistic":
+                 weighted_loss = self.logistic(losses)
+             elif self.regularizer == "polynomial":
+                 weighted_loss = self.polynomial(losses)
+             elif self.regularizer == "hard_relax":
+                 weighted_loss = self.hard_relax(losses)
+             else:
+                 raise ValueError("Regularizer not implemented")
+             self.n_calls += 1
+             self.update_lambda()
+             return weighted_loss
+
+     def update_lambda(self):
+         if self.n_calls % self.lambda_update_step == 0:
+             self.lambda_ += self.lambda_update_value
+             self.lambda_2 += self.lambda_update_value
+         elif self.n_calls == 0:
+             self.lambda_ = self.lambda_
+             self.lambda_2 = self.lambda_2
+
+     def hard(self, losses: torch.Tensor):
+         weights = (losses <= self.lambda_).float()
+         sp_loss = losses * weights
+         return sp_loss
+
+     def hard_relax(self, losses: torch.Tensor):
+         weights = torch.where(
+             losses < self.lambda_,
+             torch.ones_like(losses),
+             (1 - losses / self.lambda_2) ** (1 / (self.p - 1)),
+         )
+         idces_zero = torch.where(losses > self.lambda_2)
+         weights[idces_zero] = 0
+         weights = torch.clamp(weights, 0, 1)
+         sp_loss = losses * weights
+         return sp_loss
+
+     def linear(self, losses: torch.Tensor):
+         weights = torch.where(
+             losses > self.lambda_, torch.zeros_like(losses), 1 - losses / self.lambda_
+         )
+         weights = torch.clamp(weights, 0, 1)
+         sp_loss = losses * weights
+         return sp_loss
+
+     def logaritmic(self, losses: torch.Tensor):
+         weights = torch.where(
+             losses > self.lambda_,
+             torch.zeros_like(losses),
+             torch.log(2 - losses / self.lambda_),
+         )
+         weights = torch.clamp(weights, 0, 1)
+         sp_loss = losses * weights
+         return sp_loss
+
+     def logistic(self, losses: torch.Tensor):
+         weights = torch.where(
+             losses > self.lambda_,
+             torch.zeros_like(losses),
+             (1 - torch.exp(torch.tensor(self.lambda_)))
+             / (1 - torch.exp(losses - self.lambda_)),
+         )
+         weights = torch.clamp(weights, 0, 1)
+         sp_loss = losses * weights
+         return sp_loss
+
+     def polynomial(self, losses: torch.Tensor):
+         weights = torch.where(
+             losses > self.lambda_,
+             torch.zeros_like(losses),
+             (1 - losses / self.lambda_) ** (1 / (self.p - 1)),
+         )
+         weights = torch.clamp(weights, 0, 1)
+         sp_loss = losses * weights
+         return sp_loss
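
For reference (not part of this commit): a minimal sketch of how these training helpers could be wired into a standard PyTorch loop. Here `model`, `optimizer`, and `loader` are hypothetical placeholders, and the gradient-norm queue is seeded once so that `Queue.mean()` and `Queue.std()` are defined on the first step.

    # Illustrative sketch only -- not part of the commit. Assumes a standard
    # PyTorch training loop; `model`, `optimizer`, `loader` are hypothetical.
    import copy
    from MolecularDiffusion.callbacks import Queue, gradient_clipping, EMA

    gradnorm_queue = Queue(max_len=50)
    gradnorm_queue.add(3.0)             # seed so mean()/std() are defined
    clip_fn = gradient_clipping(m=1)

    ema = EMA(beta=0.999)
    ema_model = copy.deepcopy(model)    # shadow copy accumulating the average

    for batch in loader:
        optimizer.zero_grad()
        loss = model(batch).mean()
        loss.backward()
        clip_fn(model, gradnorm_queue)  # adaptive clip from queue statistics
        optimizer.step()
        ema.update_model_average(ema_model, model)
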
MolecularDiffusion/cli/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # CLI module for MolecularDiffusion.
+ """Unified command-line interface for the MolecularDiffusion package."""
+
+ from MolecularDiffusion.cli.main import cli
+
+ __all__ = ["cli"]
MolecularDiffusion/cli/_hydra.py ADDED
@@ -0,0 +1,129 @@
+ """Hydra configuration utilities for CLI.
+
+ Provides utilities for discovering and loading bundled configs
+ while allowing user configs to reference them.
+ """
+
+ import os
+ from pathlib import Path
+ from typing import Optional, List
+ from importlib import resources
+
+
+ def get_package_config_path() -> Path:
+     """Get the absolute path to the bundled config directory.
+
+     Returns:
+         Path to the configs directory within the installed package.
+     """
+     # Use importlib.resources for Python 3.9+
+     try:
+         pkg_files = resources.files("MolecularDiffusion")
+         config_path = pkg_files / "configs"
+         # Convert to a real path (handles both installed and editable installs)
+         if hasattr(config_path, '_path'):
+             # Traversable from importlib.resources
+             real_path = Path(str(config_path))
+         else:
+             real_path = Path(config_path)
+         if real_path.is_dir():
+             return real_path
+     except Exception:
+         pass
+
+     # Fallback: relative to this module
+     module_dir = Path(__file__).parent.parent
+     config_path = module_dir / "configs"
+     if config_path.is_dir():
+         return config_path
+
+     raise FileNotFoundError(
+         "Could not find bundled configs. Ensure the package is installed correctly."
+     )
+
+
+ def setup_hydra_config(
+     config_name: str,
+     config_dir: Optional[str] = None,
+     overrides: Optional[List[str]] = None,
+ ):
+     """Set up the Hydra configuration with proper search paths.
+
+     Configures Hydra to search:
+     1. The user's config_dir (if provided) or the current directory
+     2. The package's bundled configs (via searchpath)
+
+     Args:
+         config_name: Name of the config file (without .yaml extension)
+         config_dir: Optional user config directory
+         overrides: Optional list of Hydra override strings
+
+     Returns:
+         DictConfig from Hydra
+     """
+     from hydra import compose, initialize_config_dir
+     from hydra.core.global_hydra import GlobalHydra
+
+     # Get the package config path for defaults
+     pkg_config_path = get_package_config_path()
+
+     # Determine the primary config directory.
+     # If config_name contains a path (e.g., "configs/train.yaml"), extract the directory.
+     config_name_path = Path(config_name)
+     if config_name_path.parent != Path("."):
+         # The config name includes a directory; use it as config_dir
+         if config_dir is None:
+             config_dir = str(config_name_path.parent)
+         config_name = config_name_path.name
+
+     if config_dir:
+         primary_config_dir = os.path.abspath(config_dir)
+     else:
+         primary_config_dir = os.getcwd()
+
+     # Clear any existing Hydra state
+     GlobalHydra.instance().clear()
+
+     # Initialize with the primary config directory
+     initialize_config_dir(
+         config_dir=primary_config_dir,
+         version_base="1.3",
+     )
+
+     # Build overrides to include a searchpath for the bundled configs
+     all_overrides = overrides or []
+
+     # Add the package config path to the searchpath using the file:// protocol.
+     # This allows Hydra to find bundled defaults like data/mol_dataset.yaml.
+     searchpath_override = f"hydra.searchpath=[file://{pkg_config_path}]"
+     all_overrides = [searchpath_override] + all_overrides
+
+     # Handle the config name (strip .yaml if present)
+     if config_name.endswith(".yaml"):
+         config_name = config_name[:-5]
+
+     # Compose the configuration
+     cfg = compose(config_name=config_name, overrides=all_overrides)
+
+     return cfg
+
+
+ def run_hydra_app(
+     config_name: str,
+     task_function,
+     config_dir: Optional[str] = None,
+     overrides: Optional[List[str]] = None,
+ ):
+     """Run a Hydra-based task function with proper config setup.
+
+     This is the main entry point for CLI commands that use Hydra configs.
+
+     Args:
+         config_name: Name of the config file
+         task_function: Function to call with the composed config
+         config_dir: Optional user config directory
+         overrides: Optional Hydra overrides
+     """
+     cfg = setup_hydra_config(config_name, config_dir, overrides)
+     return task_function(cfg)
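
For reference (not part of this commit): a minimal sketch of composing a user config through this helper, assuming a hypothetical `configs/train.yaml` that defines a `seed` key; the directory part of the path becomes the primary config dir, while bundled defaults are resolved through the injected searchpath.

    # Illustrative sketch only -- not part of the commit.
    from MolecularDiffusion.cli._hydra import setup_hydra_config

    cfg = setup_hydra_config(
        config_name="configs/train.yaml",   # hypothetical user config
        overrides=["seed=42"],
    )
    print(cfg.seed)                         # assumes the config defines `seed`
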
MolecularDiffusion/cli/analyze.py ADDED
@@ -0,0 +1,380 @@
+ """Analyze CLI subcommands for 3D molecule analysis.
+
+ Provides subcommands for:
+ - optimize: XTB geometry optimization
+ - metrics: Validity/connectivity metrics
+ - compare: RMSD, energy, and optional bond analysis
+ - xyz2mol: XYZ to SMILES conversion + fingerprints
+ """
+
+ import os
+
+ import click
+
+ # Enable -h as alias for --help
+ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
+
+
+ @click.group(context_settings=CONTEXT_SETTINGS)
+ def analyze():
+     """Analyze 3D molecular structures.
+
+     \b
+     Subcommands:
+         optimize   XTB geometry optimization
+         metrics    Validity/connectivity metrics
+         compare    RMSD, energy, and bond analysis
+         xyz2mol    Convert XYZ to SMILES + fingerprints
+     """
+     pass
+
+
+ # ============================================================================
+ # OPTIMIZE: XTB geometry optimization
+ # ============================================================================
+
+ @analyze.command("optimize", context_settings=CONTEXT_SETTINGS)
+ @click.argument("input_dir", type=click.Path(exists=True))
+ @click.option("--output-dir", "-o", "--o", default=None, type=click.Path(),
+               help="Output directory for optimized files (default: input_dir/optimized_xyz)")
+ @click.option("--charge", "-c", "--c", default=0, type=int,
+               help="Molecular charge for xTB (default: 0)")
+ @click.option("--level", "-l", "--l", default="gfn1", type=click.Choice(["gfn1", "gfn2", "gfn-ff", "mmff94"]),
+               help="Optimization level (default: gfn1)")
+ @click.option("--timeout", "-t", "--t", default=240, type=int,
+               help="Timeout per molecule in seconds (default: 240)")
+ @click.option("--scale-factor", "-s", "--s", default=1.3, type=float,
+               help="Scale factor for covalent radii (default: 1.3)")
+ @click.option("--csv", "csv_path", default=None, type=click.Path(),
+               help="CSV file to filter which files to optimize")
+ @click.option("--filter-column", default=None, type=str,
+               help="Column name in CSV to filter by (values must be 1)")
+ def optimize(input_dir, output_dir, charge, level, timeout, scale_factor, csv_path, filter_column):
+     """Optimize XYZ geometries using xTB.
+
+     \b
+     Examples:
+         MolCraftDiff analyze optimize gen_xyz/
+         MolCraftDiff analyze optimize gen_xyz/ --o optimized/ --level gfn2
+     """
+     from MolecularDiffusion.runmodes.analyze.xtb_optimization import get_xtb_optimized_xyz
+
+     output_dir = output_dir or os.path.join(input_dir, "optimized_xyz")
+
+     click.echo(f"Optimizing XYZ files from: {input_dir}")
+     click.echo(f"Output directory: {output_dir}")
+     click.echo(f"xTB level: {level}, charge: {charge}")
+
+     optimized_files = get_xtb_optimized_xyz(
+         input_directory=input_dir,
+         output_directory=output_dir,
+         charge=charge,
+         level=level,
+         timeout=timeout,
+         scale_factor=scale_factor,
+         csv_path=csv_path,
+         filter_column=filter_column,
+     )
+
+     click.echo(f"\nSuccessfully optimized {len(optimized_files)} files.")
+
+
+ # ============================================================================
+ # METRICS: Validity/connectivity metrics
+ # ============================================================================
+
+ @analyze.command("metrics", context_settings=CONTEXT_SETTINGS)
+ @click.argument("input_dir", type=click.Path(exists=True))
+ @click.option("--output", "-o", "--o", "--output-csv", default=None, type=click.Path(),
+               help="Output CSV file for results")
+ @click.option("--metrics", "-m", "--m", "metrics_type", default="all",
+               type=click.Choice(["all", "core", "posebuster", "geom_revised"]),
+               help="Which metrics to compute (default: all)")
+ @click.option("--recheck-topo", is_flag=True, default=False,
+               help="Recheck topology using RDKit")
+ @click.option("--check-strain", is_flag=True, default=False,
+               help="Check strain via XTB optimization")
+ @click.option("--portion", "-p", "--p", default=1.0, type=float,
+               help="Portion of XYZ files to process (default: 1.0 = all)")
+ @click.option("--mol-converter", default="cell2mol",
+               type=click.Choice(["cell2mol", "openbabel"]),
+               help="XYZ to mol converter (default: cell2mol)")
+ @click.option("--skip-atoms", multiple=True, type=int,
+               help="Atom indices to skip in validation")
+ @click.option("--n-subsets", "-n", "--n", default=5, type=int,
+               help="Number of subsets for std calculation (default: 5)")
+ @click.option("--timeout", "-t", "--t", default=10, type=int,
+               help="Timeout per xyz2mol conversion in seconds (default: 10)")
+ def metrics(input_dir, output, metrics_type, recheck_topo, check_strain, portion, mol_converter, skip_atoms, n_subsets, timeout):
+     """Compute validity and connectivity metrics for XYZ files.
+
+     \b
+     Metrics types:
+         all           Run all metrics (core + posebuster + geom_revised)
+         core          Basic validity checks (connectivity, atom stability)
+         posebuster    PoseBusters checks (bond lengths, angles, clashes)
+         geom_revised  Aromatic-aware stability metrics
+
+     \b
+     Examples:
+         MolCraftDiff analyze metrics gen_xyz/
+         MolCraftDiff analyze metrics gen_xyz/ --metrics posebuster
+         MolCraftDiff analyze metrics gen_xyz/ --metrics geom_revised --mol-converter openbabel
+     """
+     import argparse
+     from MolecularDiffusion.runmodes.analyze.compute_metrics import runner
+
+     args = argparse.Namespace(
+         input=input_dir,
+         output=output,
+         metrics=metrics_type,
+         recheck_topo=recheck_topo,
+         check_strain=check_strain,
+         portion=portion,
+         mol_converter=mol_converter,
+         skip_atoms=list(skip_atoms) if skip_atoms else None,
+         n_subsets=n_subsets,
+         timeout=timeout,
+     )
+
+     click.echo(f"Computing {metrics_type} metrics for: {input_dir}")
+     runner(args)
+
+
+ # ============================================================================
+ # COMPARE: Unified RMSD, energy, and bond analysis
+ # ============================================================================
+
+ @analyze.command("compare", context_settings=CONTEXT_SETTINGS)
+ @click.argument("directory", type=click.Path(exists=True))
+ @click.option("--mol-converter", default="openbabel", type=click.Choice(["openbabel", "cell2mol"]),
+               help="Converter for bond perception (default: openbabel)")
+ @click.option("--n-subsets", "-n", "--n", default=5, type=int,
+               help="Number of subsets for std calculation (default: 5)")
+ @click.option("--output", "-o", "--o", "--csv", "csv_path", default=None, type=click.Path(),
+               help="Output CSV filename for results")
+ @click.option("--charge", "-c", "--c", default=0, type=int,
+               help="Molecular charge for xTB energy (default: 0)")
+ @click.option("--level", "-l", "--l", default="gfn2", type=click.Choice(["gfn1", "gfn2", "gfn-ff", "mmff94"]),
+               help="xTB level for energy calculation (default: gfn2)")
+ @click.option("--timeout", "-t", "--t", default=120, type=int,
+               help="Timeout per xTB calculation in seconds (default: 120)")
+ def compare(directory, mol_converter, n_subsets, csv_path, charge, level, timeout):
+     """Compare XYZ files with their optimized counterparts.
+
+     Computes RMSD, xTB energy difference, and bond geometry metrics.
+     Enforces strict connectivity checks.
+
+     Requires an 'optimized_xyz' subdirectory with *_opt.xyz files.
+     """
+     import argparse
+     from MolecularDiffusion.runmodes.analyze.compare_to_optimized import run_compare_analysis
+
+     # Construct an args namespace to pass to run_compare_analysis
+     args = argparse.Namespace(
+         directory=directory,
+         mol_converter=mol_converter,
+         n_subsets=n_subsets,
+         csv_path=csv_path,
+         charge=charge,
+         level=level,
+         timeout=timeout
+     )
+
+     run_compare_analysis(args)
+
+
+ # ============================================================================
+ # XYZ2MOL: Convert XYZ to SMILES + fingerprints
+ # ============================================================================
+
+ @analyze.command("xyz2mol", context_settings=CONTEXT_SETTINGS)
+ @click.argument("xyz_dir", type=click.Path(exists=True))
+ @click.option("--input-csv", "-i", "--i", default=None, type=click.Path(),
+               help="Optional input CSV with xyz file list")
+ @click.option("--label", "-l", "--l", default=None, type=str,
+               help="Label for processed files")
+ @click.option("--timeout", "-t", "--t", default=30, type=int,
+               help="Timeout per conversion in seconds (default: 30)")
+ @click.option("--bits", "-b", "--b", default=2048, type=int,
+               help="Number of bits for Morgan fingerprint (default: 2048)")
+ @click.option("--verbose", "-v", "--v", is_flag=True,
+               help="Enable verbose output")
+ def xyz2mol(xyz_dir, input_csv, label, timeout, bits, verbose):
+     """Convert XYZ files to SMILES and extract fingerprints/scaffolds.
+
+     Outputs are saved to xyz_dir/2d_reprs/:
+     - smiles_processed.csv
+     - fingerprints.npy
+     - scaffolds.txt
+     - substructures.json
+
+     \b
+     Examples:
+         MolCraftDiff analyze xyz2mol gen_xyz/
+         MolCraftDiff analyze xyz2mol gen_xyz/ --bits 1024 -v
+     """
+     from pathlib import Path
+     import pandas as pd
+     import numpy as np
+     import json
+     import logging
+
+     from MolecularDiffusion.runmodes.analyze.xyz2mol import (
+         load_file_list_from_dir, run_processing, extract_scaffold_and_fingerprints
+     )
+
+     if verbose:
+         logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
+
+     xyz_dir = Path(xyz_dir)
+     two_d_reprs_dir = xyz_dir / "2d_reprs"
+     two_d_reprs_dir.mkdir(parents=True, exist_ok=True)
+
+     smiles_csv_output = two_d_reprs_dir / "smiles_processed.csv"
+
+     click.echo(f"Processing XYZ files from: {xyz_dir}")
+     click.echo(f"Output directory: {two_d_reprs_dir}")
+
+     # Load file list
+     if input_csv:
+         df = pd.read_csv(input_csv)
+     else:
+         df = load_file_list_from_dir(str(xyz_dir))
+
+     # Generate SMILES
+     df_smiles = run_processing(df, str(xyz_dir), label, smiles_csv_output, timeout=timeout, verbose=verbose)
+
+     if df_smiles is None or 'smiles' not in df_smiles.columns or df_smiles['smiles'].isnull().all():
+         click.echo("No valid SMILES generated.", err=True)
+         return
+
+     # Extract fingerprints and scaffolds
+     click.echo("\nExtracting fingerprints and scaffolds...")
+     fps, scaffolds, clean_smiles, n_fail, substruct_counts = \
+         extract_scaffold_and_fingerprints(df_smiles["smiles"].dropna().values, fp_bits=bits)
+
+     np.save(two_d_reprs_dir / "fingerprints.npy", fps)
+     with open(two_d_reprs_dir / "scaffolds.txt", "w") as f:
+         f.write("\n".join(scaffolds))
+     with open(two_d_reprs_dir / "smiles_cleaned.txt", "w") as f:
+         f.write("\n".join(clean_smiles))
+     with open(two_d_reprs_dir / "substructures.json", "w") as f:
+         json.dump(substruct_counts, f, indent=2)
+
+     total = len(df_smiles["smiles"].dropna())
+     click.echo(f"\n--- Summary ---")
+     click.echo(f"Total SMILES: {total}")
+     click.echo(f"Failed FP extraction: {n_fail}")
+     click.echo(f"Unique substructures: {len(substruct_counts)}")
+     click.echo(f"Outputs saved to: {two_d_reprs_dir}")
+
+
+ # ============================================================================
+ # XTB-ELECTRONIC: Compute XTB electronic properties
+ # ============================================================================
+
+ @analyze.command("xtb-electronic", context_settings=CONTEXT_SETTINGS)
+ @click.argument("input_dir", type=click.Path(exists=True))
+ @click.option("--output", "--o", "-o", default=None, type=click.Path(),
+               help="Output file path (without extension for 'all' format)")
+ @click.option("--method", "--m", "-m", default="2", type=click.Choice(["1", "2", "ptb"]),
+               help="XTB method: 1=GFN1, 2=GFN2, ptb=PTB (default: 2)")
+ @click.option("--charge", "--c", "-c", default=0, type=int,
+               help="Molecular charge (default: 0)")
+ @click.option("--n-unpaired", "--unpaired", default=0, type=int,
+               help="Number of unpaired electrons (default: 0)")
+ @click.option("--solvent", "--s", "-s", default=None, type=str,
+               help="Solvent for solvation calculations (e.g., 'water', 'thf', 'chcl3')")
+ @click.option("--properties", "--prop", "-p", multiple=True,
+               type=click.Choice(["energy", "dipole", "reactivity", "global",
+                                  "charges", "fukui", "bond_orders", "all"]),
+               help="Property groups to compute (default: energy)")
+ @click.option("--corrected/--no-corrected", default=True,
+               help="Apply empirical IP/EA correction (default: True)")
+ @click.option("--timeout", "--t", "-t", default=120, type=int,
+               help="Timeout per molecule in seconds (default: 120)")
+ @click.option("--n-jobs", "--jobs", "-j", default=1, type=int,
+               help="Number of parallel jobs (default: 1)")
+ @click.option("--format", "--fmt", "-f", "output_format", default="csv",
+               type=click.Choice(["csv", "json", "ase", "all"]),
+               help="Output format: csv, json, ase (.db), or all (default: csv)")
+ def xtb_electronic(input_dir, output, method, charge, n_unpaired,
+                    solvent, properties, corrected, timeout, n_jobs, output_format):
+     """Compute XTB electronic properties for XYZ files.
+
+     Uses morfeus to calculate quantum-chemical descriptors at the GFN-xTB level.
+
+     \b
+     Property groups (molecular-level):
+         energy      Total energy, HOMO, LUMO, gap, Fermi level
+         dipole      Dipole moment and vector
+         reactivity  IP, EA, electronegativity, hardness, softness
+         global      Electrophilicity, nucleophilicity, fugalities
+         solvation   Solvation energy, H-bond correction (requires --solvent)
+
+     \b
+     Property groups (atomic-level):
+         charges      Atomic charges (Mulliken)
+         fukui        Fukui indices (f+, f-, f, dual)
+         bond_orders  Bond orders between atom pairs
+
+     \b
+     Output formats:
+         csv   Molecular-level properties only (one row per molecule)
+         json  Full data including atomic-level properties
+         ase   ASE database with properties in atoms.info/arrays
+         all   Generate all three formats
+
+     \b
+     Examples:
+         MolCraftDiff analyze xtb-electronic gen_xyz/
+         MolCraftDiff analyze xtb-electronic gen_xyz/ -p energy -p reactivity
+         MolCraftDiff analyze xtb-electronic gen_xyz/ -s water -p solvation
+         MolCraftDiff analyze xtb-electronic gen_xyz/ -p all -f ase -o results.db
+     """
+     from MolecularDiffusion.runmodes.analyze.xtb_electronic import batch_xtb_electronic
+
+     # Parse method
+     if method in ["1", "2"]:
+         method = int(method)
+
+     # Default properties
+     if not properties:
+         properties = ["energy"]
+
+     # Default output path
+     if output is None:
+         output = os.path.join(input_dir, "xtb_electronic")
+
+     click.echo(f"Computing XTB electronic properties for: {input_dir}")
+     click.echo(f"Method: GFN{method}-xTB" if method != "ptb" else "Method: PTB")
+     click.echo(f"Charge: {charge}, Unpaired: {n_unpaired}")
+     if solvent:
+         click.echo(f"Solvent: {solvent}")
+     click.echo(f"Properties: {', '.join(properties)}")
+     click.echo(f"Output format: {output_format}")
+
+     df = batch_xtb_electronic(
+         input_dir=input_dir,
+         output_path=output,
+         output_format=output_format,
+         method=method,
+         charge=charge,
+         n_unpaired=n_unpaired,
+         solvent=solvent,
+         properties=list(properties),
+         corrected=corrected,
+         timeout=timeout,
+         n_jobs=n_jobs,
+     )
+
+     n_success = df["success"].sum() if "success" in df.columns else len(df)
+     n_total = len(df)
+
+     click.echo(f"\n--- Summary ---")
+     click.echo(f"Processed: {n_total} molecules")
+     click.echo(f"Successful: {n_success}")
+     click.echo(f"Failed: {n_total - n_success}")
+     click.echo(f"Output saved to: {output}")
MolecularDiffusion/cli/eval_predict.py ADDED
@@ -0,0 +1,259 @@
+ """Eval-Predict command for MolCraft CLI.
+
+ Adapted from scripts/eval_predict.py for package-level execution.
+ """
+
+ import os
+ from typing import Any, Dict, Tuple
+
+ import hydra
+ import numpy as np
+ import pandas as pd
+ import torch
+ from omegaconf import DictConfig, OmegaConf
+ from torch.utils.data import ConcatDataset
+
+ from MolecularDiffusion.core import Engine
+ from MolecularDiffusion.runmodes.train import DataModule, ModelTaskFactory_EGCL, OptimSchedulerFactory
+ from MolecularDiffusion.utils import RankedLogger, seed_everything
+ from MolecularDiffusion.utils.plot_function import (
+     plot_kde_distribution,
+     plot_histogram_distribution,
+     plot_kde_distribution_multiple,
+     plot_correlation_with_histograms,
+ )
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def is_rank_zero():
+     """Check if the current process is rank zero."""
+     if torch.distributed.is_available() and torch.distributed.is_initialized():
+         return torch.distributed.get_rank() == 0
+     return True
+
+
+ def load_checkpoint_weights(task, chkpt_path):
+     """Load weights from a checkpoint, supporting both Engine and Lightning formats."""
+     log.info(f"Loading weights from: {chkpt_path}")
+
+     checkpoint = torch.load(chkpt_path, map_location="cpu", weights_only=False)
+
+     # Check if it's a Lightning checkpoint
+     if "state_dict" in checkpoint:
+         log.info("Detected Lightning checkpoint.")
+         state_dict = checkpoint["state_dict"]
+         cleaned_state_dict = {}
+         for k, v in state_dict.items():
+             if k.startswith("task."):
+                 cleaned_state_dict[k[5:]] = v
+             else:
+                 cleaned_state_dict[k] = v
+
+         load_result = task.load_state_dict(cleaned_state_dict, strict=False)
+         log.info(f"Loaded {len(cleaned_state_dict)} parameters from state_dict")
+         if load_result.missing_keys:
+             log.warning(f"Missing keys: {load_result.missing_keys}")
+
+         # Recover statistics
+         for key in ["mean", "std", "weight"]:
+             val = None
+             if key in checkpoint:
+                 val = checkpoint[key]
+             elif f"task.{key}" in state_dict:
+                 val = state_dict[f"task.{key}"]
+             elif key in state_dict:
+                 val = state_dict[key]
+
+             if val is not None:
+                 if not isinstance(val, torch.Tensor):
+                     val = torch.as_tensor(val, dtype=torch.float32)
+
+                 # Register as a buffer so it moves with the model to the correct device
+                 if key in task._buffers:
+                     task._buffers[key].copy_(val)
+                 else:
+                     task.register_buffer(key, val)
+     elif "model" in checkpoint:
+         log.info("Detected original Engine checkpoint.")
+         task.load_state_dict(checkpoint["model"], strict=False)
+         # Recover statistics
+         for key in ["mean", "std", "weight"]:
+             if key in checkpoint["model"]:
+                 val = checkpoint["model"][key]
+                 if not isinstance(val, torch.Tensor):
+                     val = torch.as_tensor(val, dtype=torch.float32)
+
+                 # Register as a buffer so it moves with the model to the correct device
+                 if key in task._buffers:
+                     task._buffers[key].copy_(val)
+                 else:
+                     task.register_buffer(key, val)
+     else:
+         # Fallback for unexpected formats
+         log.warning("Unknown checkpoint format. Attempting direct load.")
+         task.load_state_dict(checkpoint, strict=False)
+
+     # Ensure the task has a device attribute for initial loading,
+     # but don't hardcode it if it's about to be moved by Engine
+     if not hasattr(task, 'device'):
+         task.device = next(task.parameters()).device if list(task.parameters()) else torch.device('cpu')
+
+
+ def engine_wrapper(task_module, data_module, trainer_module):
+     """Run evaluation with Engine."""
+     trainer_module.get_optimizer()
+     trainer_module.get_scheduler()
+
+     pred_dataset = ConcatDataset([data_module.valid_set, data_module.test_set])
+     solver = Engine(
+         task_module.task,
+         None,
+         None,
+         pred_dataset,
+         batch_size=data_module.batch_size,
+         collate_fn=data_module.collate_fn,
+         logger="logging",
+     )
+     # Ensure task.device is updated to the actual device the solver is using
+     task_module.task.device = solver.device
+
+     _, preds_test, targets_test = solver.evaluate("test")
+     y_preds = torch.cat(preds_test, dim=0)
+     y_trues = torch.cat(targets_test, dim=0)
+     return y_preds, y_trues
+
+
+ def predict(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """Evaluate predictions on validation/test sets."""
+     if cfg.get("seed"):
+         seed_everything(cfg.seed, workers=True)
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+     data_module: DataModule = hydra.utils.instantiate(
+         cfg.data, task_type=cfg.tasks.task_type, train_ratio=0
+     )
+     data_module.load()
+
+     log.info(f"Instantiating task <{cfg.tasks._target_}>")
+     act_fn = hydra.utils.instantiate(cfg.tasks.act_fn)
+
+     # Store the checkpoint path and temporarily disable it for task_module.build()
+     # to avoid the factory's internal (legacy) loading.
+     chkpt_path = cfg.tasks.get("chkpt_path")
+
+     # Create a copy of the config to modify safely
+     tasks_cfg = OmegaConf.to_container(cfg.tasks, resolve=True)
+     tasks_cfg['chkpt_path'] = None
+     tasks_cfg = OmegaConf.create(tasks_cfg)
+
+     task_module: ModelTaskFactory_EGCL = hydra.utils.instantiate(tasks_cfg, act_fn=act_fn)
+     task_module.build()
+
+     # Manually load weights using our robust loader
+     if chkpt_path:
+         load_checkpoint_weights(task_module.task, chkpt_path)
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+     trainer_module: OptimSchedulerFactory = hydra.utils.instantiate(
+         cfg.trainer, parameters=task_module.task.parameters()
+     )
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": data_module,
+         "task": task_module,
+         "trainer": trainer_module,
+     }
+
+     log.info("Logging hyperparameters!")
+     log_hyperparameters(object_dict)
+
+     y_preds, y_trues = engine_wrapper(task_module, data_module, trainer_module)
+
+     df = pd.read_csv(cfg.data.filename)
+     task_matrix = df[cfg.tasks.task_learn].to_numpy()
+     filenames = df["filename"].to_numpy()
+     filenames_aligned = []
+
+     for row in y_trues.cpu().numpy():
+         mask = np.all(np.isclose(task_matrix, row, atol=1e-4), axis=1)
+         idx = np.flatnonzero(mask)
+
+         if idx.size == 0:
+             raise ValueError(f"No match for row {row}")
+         if idx.size > 1:
+             raise ValueError(f"Multiple matches for row {row}: {filenames[idx].tolist()}")
+
+         filenames_aligned.append(filenames[idx[0]])
+
+     df_compiled = pd.DataFrame({
+         "filename": filenames_aligned,
+         "y_true": y_trues.cpu().numpy().tolist(),
+         "y_pred": y_preds.cpu().numpy().tolist(),
+     })
+
+     os.makedirs(cfg.output_directory, exist_ok=True)
+     df_compiled.to_csv(f"{cfg.output_directory}/predictions.csv", index=False)
+
+     log.info("Prediction statistics:")
+     for task_name in cfg.tasks.task_learn:
+         log.info(f"--- {task_name} ---")
+         log.info(f"Mean: {df[task_name].mean():.4f}")
+         log.info(f"Std: {df[task_name].std():.4f}")
+         log.info(f"Min: {df[task_name].min():.4f}")
+         log.info(f"Max: {df[task_name].max():.4f}")
+
+     log.info("Plotting distributions...")
+     props = []
+     for i, prop in enumerate(cfg.tasks.task_learn):
+         plot_kde_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_kde.png")
+         plot_histogram_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_hist.png")
+         plot_correlation_with_histograms(
+             y_trues[:, i].cpu().numpy(),
+             y_preds[:, i].cpu().numpy(),
+             prop,
+             "",
+             f"{cfg.output_directory}/{prop}_correlation.png",
+         )
+         props.append(df[prop].values)
+
+     props = np.array(props).T
+     plot_kde_distribution_multiple(props, cfg.tasks.task_learn, f"{cfg.output_directory}/kde_all.png")
+
+
+ def log_hyperparameters(object_dict: dict):
+     """Log hyperparameters for debugging."""
+     if not is_rank_zero():
+         return
+
+     log.info("\n========== Logging Hyperparameters ==========\n")
+     for name, obj in object_dict.items():
+         log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
+         if name == "cfg":
+             if isinstance(obj, dict):
+                 log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
+             else:
+                 log.info("\n" + OmegaConf.to_yaml(obj))
+         else:
+             if hasattr(obj, '__dict__'):
+                 for k, v in vars(obj).items():
+                     if not k.startswith("_"):
+                         log.info(f"{k}: {v}")
+         log.info(f"{'=' * (44 + len(name))}\n")
+
+     if "task" in object_dict and hasattr(object_dict["task"], "task"):
+         model = object_dict["task"].task
+         total = sum(p.numel() for p in model.parameters())
+         trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+         log.info(f"{'=' * 20} MODEL PARAMS {'=' * 20}")
+         log.info(f"model/params/total: {total}")
+         log.info(f"model/params/trainable: {trainable}")
+         log.info("=" * 54 + "\n")
+
+     log.info("========== End of Hyperparameters ==========\n")
+
+
+ def eval_predict_main(cfg: DictConfig):
+     """Entry point for the CLI eval-predict command."""
+     predict(cfg)
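
For reference (not part of this commit): a standalone sketch of the row-matching used in `predict` above, which aligns each evaluated target row back to a filename in the source CSV by near-equality on the property columns. The toy arrays here are hypothetical.

    # Illustrative sketch only -- not part of the commit.
    import numpy as np

    task_matrix = np.array([[0.10, 1.0], [0.25, 2.0], [0.40, 3.0]])
    filenames = np.array(["a.xyz", "b.xyz", "c.xyz"])

    row = np.array([0.2500001, 2.0])  # a target row recovered from the dataloader
    mask = np.all(np.isclose(task_matrix, row, atol=1e-4), axis=1)
    idx = np.flatnonzero(mask)
    assert idx.size == 1              # missing or ambiguous rows raise in the CLI
    print(filenames[idx[0]])          # -> "b.xyz"
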
MolecularDiffusion/cli/generate.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Generation command for MolCraft CLI.
2
+
3
+ Adapted from scripts/generate.py for package-level execution.
4
+ """
5
+
6
+ import glob
7
+ import os
8
+ import re
9
+ import time
10
+ import copy
11
+ import pickle
12
+ from typing import Any, Dict, Optional, Tuple
13
+
14
+ import hydra
15
+ import torch
16
+ from omegaconf import DictConfig, OmegaConf
17
+
18
+ from MolecularDiffusion.core import Engine
19
+ from MolecularDiffusion.runmodes.generate.tasks_generate import GenerativeFactory
20
+ from MolecularDiffusion.utils import (
21
+ RankedLogger,
22
+ seed_everything,
23
+ recursive_module_to_device,
24
+ )
25
+
26
+ log = RankedLogger(__name__, rank_zero_only=True)
27
+
28
+
29
+ def is_rank_zero():
30
+ """Check if current process is rank zero."""
31
+ if torch.distributed.is_available() and torch.distributed.is_initialized():
32
+ return torch.distributed.get_rank() == 0
33
+ return True
34
+
35
+
36
+ def load_lightning_model(chkpt_path, task_config, atom_vocab=None, total_step=0):
37
+ """Load model from Lightning checkpoint (.ckpt)."""
38
+ log.info(f"Loading Lightning checkpoint from: {chkpt_path}")
39
+
40
+ try:
41
+ from MolecularDiffusion.core.engine_lightning import EngineLightning
42
+ wrapper = EngineLightning.load_from_checkpoint(chkpt_path, map_location="cpu")
43
+ log.info("Successfully loaded model using EngineLightning.load_from_checkpoint")
44
+
45
+ if atom_vocab and hasattr(wrapper.task, 'atom_vocab') and wrapper.task.atom_vocab is None:
46
+ wrapper.task.atom_vocab = atom_vocab
47
+
48
+ # Apply diffusion_steps override from config
49
+ if total_step > 0:
50
+ if hasattr(wrapper.task, 'model') and hasattr(wrapper.task.model, 'T'):
51
+ log.info(f"Overriding diffusion steps: {wrapper.task.model.T} -> {total_step}")
52
+ wrapper.task.model.T = total_step
53
+ elif hasattr(wrapper.task, 'T'):
54
+ log.info(f"Overriding diffusion steps: {wrapper.task.T} -> {total_step}")
55
+ wrapper.task.T = total_step
56
+
57
+ wrapper.task.eval()
58
+ return wrapper.task
59
+
60
+ except Exception as e:
61
+ log.warning(f"EngineLightning.load_from_checkpoint failed ({type(e).__name__}: {e}). Falling back to manual config reconstruction.")
62
+
63
+ # Fallback: Load checkpoint manually
64
+ checkpoint = torch.load(chkpt_path, map_location="cpu", weights_only=False)
65
+
66
+ hparams = checkpoint.get("hyper_parameters", {})
67
+ if "model_config" in hparams and hparams["model_config"] is not None:
68
+ task_config = OmegaConf.create(hparams["model_config"])
69
+ log.info("Loaded task configuration from checkpoint hyperparameters")
70
+ elif task_config is None:
71
+ raise ValueError("task_config not provided and 'model_config' not found in checkpoint.")
72
+
73
+ task_config = copy.deepcopy(task_config)
74
+ OmegaConf.set_readonly(task_config, False)
75
+ OmegaConf.set_struct(task_config, False)
76
+
77
+ n_types = len(atom_vocab) if atom_vocab else 0
78
+
79
+ if OmegaConf.is_missing(task_config, "num_atom_types") or task_config.get("num_atom_types") == "???":
80
+ task_config.num_atom_types = n_types if n_types > 0 else 100
81
+
82
+ if hasattr(task_config, "transformer_config"):
83
+ if OmegaConf.is_missing(task_config.transformer_config, "atom_dim"):
84
+ task_config.transformer_config.atom_dim = task_config.num_atom_types
85
+
86
+ if hasattr(task_config, "dataset_stats"):
87
+ if OmegaConf.is_missing(task_config.dataset_stats, "max_atoms"):
88
+ task_config.dataset_stats.max_atoms = 150
89
+
90
+ log.info(f"Building task from config: {task_config._target_}")
91
+ task_factory = hydra.utils.instantiate(task_config, atom_vocab=atom_vocab)
92
+ task = task_factory.build()
93
+
94
+ state_dict = checkpoint.get('state_dict', {})
95
+ cleaned_state_dict = {}
96
+ for key, value in state_dict.items():
97
+ if key.startswith('task.'):
98
+ cleaned_state_dict[key[5:]] = value
99
+ else:
100
+ cleaned_state_dict[key] = value
101
+
102
+ task.load_state_dict(cleaned_state_dict, strict=False)
103
+ log.info(f"Loaded {len(cleaned_state_dict)} parameters from checkpoint")
104
+
105
+ if 'data_stats' in checkpoint:
106
+ task.tabasco_model.set_data_stats(checkpoint['data_stats'])
107
+ if 'node_dist_model' in checkpoint:
108
+ task._node_dist_model = checkpoint['node_dist_model']
109
+ if 'prop_dist_model' in checkpoint:
110
+ task.prop_dist_model = checkpoint['prop_dist_model']
111
+
112
+ if total_step > 0:
113
+ if hasattr(task, 'model') and hasattr(task.model, 'T'):
114
+ task.model.T = total_step
115
+ elif hasattr(task, 'T'):
116
+ task.T = total_step
117
+
118
+ task.eval()
119
+ return task
120
+
121
+
122
+ def load_model(chkpt_directory, task_config=None, atom_vocab=None, total_step=0):
123
+ """Load model from checkpoint directory with auto-detection."""
124
+ ckpt_files = glob.glob(os.path.join(chkpt_directory, '*.ckpt'))
125
+
126
+ if ckpt_files:
127
+ best_metric = -1.0
128
+ best_checkpoint = None
129
+
130
+ for ckpt_file in ckpt_files:
131
+ match = re.search(r"(?:metric|val)[_=](\d+\.?\d*)", os.path.basename(ckpt_file))
132
+ if match:
133
+ metric = float(match.group(1))
134
+ if metric > best_metric:
135
+ best_metric = metric
136
+ best_checkpoint = ckpt_file
137
+
138
+ if best_checkpoint is None:
139
+ last_ckpt = os.path.join(chkpt_directory, 'last.ckpt')
140
+ best_checkpoint = last_ckpt if os.path.exists(last_ckpt) else ckpt_files[0]
141
+
142
+ task = load_lightning_model(best_checkpoint, task_config, atom_vocab, total_step)
143
+
144
+ try:
145
+ with open(os.path.join(chkpt_directory, "edm_stat.pkl"), "rb") as file:
146
+ edm_stats = pickle.load(file)
147
+ task.node_dist_model = edm_stats.get("node")
148
+ if "prop" in edm_stats:
149
+ task.prop_dist_model = edm_stats["prop"]
150
+ except (ImportError, FileNotFoundError):
151
+ log.warning("edm_stat.pkl not found")
152
+
153
+ return task
154
+
155
+ # Original engine (.pkl files)
156
+ model_path = os.path.join(chkpt_directory, "edm_chem.pkl")
157
+
158
+ if not os.path.exists(model_path):
159
+ checkpoint_files = glob.glob(os.path.join(chkpt_directory, '*.pkl'))
160
+ checkpoint_files = [f for f in checkpoint_files if 'edm_stat.pkl' not in os.path.basename(f)]
161
+
162
+ if not checkpoint_files:
163
+ raise FileNotFoundError(f"No checkpoints found in {chkpt_directory}")
164
+
165
+ best_metric = -1.0
166
+ best_checkpoint = None
167
+
168
+ for ckpt_file in checkpoint_files:
169
+ match = re.search(r"metric=([\d.]+)\.pkl", os.path.basename(ckpt_file))
170
+ if match:
171
+ metric = float(match.group(1))
172
+ if metric > best_metric:
173
+ best_metric = metric
174
+ best_checkpoint = ckpt_file
175
+
176
+ model_path = best_checkpoint or checkpoint_files[0]
177
+
178
+ log.info(f"Loading original engine checkpoint from: {model_path}")
179
+
180
+ edm_stats = {"node": None, "prop": None}
181
+ stat_path = os.path.join(chkpt_directory, "edm_stat.pkl")
182
+ if os.path.exists(stat_path):
183
+ try:
184
+ with open(stat_path, "rb") as file:
185
+ loaded_stats = pickle.load(file)
186
+ if "node" in loaded_stats:
187
+ edm_stats["node"] = loaded_stats["node"]
188
+ elif "node_dist_model" in loaded_stats:
189
+ edm_stats["node"] = loaded_stats["node_dist_model"]
190
+ if "prop" in loaded_stats:
191
+ edm_stats["prop"] = loaded_stats["prop"]
192
+ elif "prop_dist_model" in loaded_stats:
193
+ edm_stats["prop"] = loaded_stats["prop_dist_model"]
194
+ except Exception as e:
195
+ log.warning(f"Failed to load edm_stat.pkl: {e}")
196
+
197
+ engine = Engine(None, None, None, None, None)
198
+ engine = engine.load_from_checkpoint(model_path, interference_mode=True)
199
+ task = engine.model
200
+
201
+ if edm_stats["node"] is not None:
202
+ task.node_dist_model = edm_stats["node"]
203
+ if edm_stats["prop"] is not None:
204
+ task.prop_dist_model = edm_stats["prop"]
205
+
206
+ if total_step > 0:
207
+ if hasattr(task, 'model') and hasattr(task.model, 'T'):
208
+ task.model.T = total_step
209
+ elif hasattr(task, 'T'):
210
+ task.T = total_step
211
+
212
+ task.eval()
213
+ return task
214
+
215
+
216
+ def generate(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
217
+ """Main generation function."""
218
+ if cfg.get("seed"):
219
+ seed_everything(cfg.seed, workers=True)
220
+
221
+ log.info(f"Instantiating diffusion task and loading the model <{cfg.tasks._target_}>")
222
+ task = load_model(
223
+ cfg.chkpt_directory,
224
+ task_config=cfg.tasks,
225
+ atom_vocab=cfg.atom_vocab,
226
+ total_step=cfg.diffusion_steps,
227
+ )
228
+
229
+ if not hasattr(task, 'atom_vocab') or task.atom_vocab is None:
230
+ task.atom_vocab = cfg.atom_vocab
231
+
232
+ if not hasattr(task, 'device'):
233
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
234
+ recursive_module_to_device(task, device)
235
+
236
+ log.info(f"Instantiating generator... <{cfg.interference._target_}>")
237
+ generator: GenerativeFactory = hydra.utils.instantiate(cfg.interference, task=task)
238
+
239
+ object_dict = {"cfg": cfg, "task": task, "generator": generator}
240
+
241
+ log.info("Logging hyperparameters!")
242
+ log_hyperparameters(object_dict)
243
+
244
+ os.makedirs(cfg.interference.output_path, exist_ok=True)
245
+
246
+ if is_rank_zero():
247
+ config_path = os.path.join(cfg.interference.output_path, "config.yaml")
248
+ with open(config_path, "w") as f:
249
+ OmegaConf.save(config=cfg, f=f)
250
+ log.info(f"Configuration saved to {config_path}")
251
+
252
+ generator.run()
253
+
254
+
255
+ def log_hyperparameters(object_dict: dict):
256
+ """Log hyperparameters for debugging."""
257
+ if not is_rank_zero():
258
+ return
259
+
260
+ log.info("\n========== Logging Hyperparameters ==========\n")
261
+ for name, obj in object_dict.items():
262
+ log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
263
+ if name == "cfg":
264
+ if isinstance(obj, dict):
265
+ log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
266
+ else:
267
+ log.info("\n" + OmegaConf.to_yaml(obj))
268
+ else:
269
+ if hasattr(obj, '__dict__'):
270
+ for k, v in vars(obj).items():
271
+ if not k.startswith("_"):
272
+ log.info(f"{k}: {v}")
273
+ log.info(f"{'=' * (44 + len(name))}\n")
274
+ log.info("========== End of Hyperparameters ==========\n")
275
+
276
+
277
+ def generate_main(cfg: DictConfig):
278
+ """Entry point for CLI generate command."""
279
+ start_time = time.time()
280
+ generate(cfg)
281
+ total_time = time.time() - start_time
282
+ log.warning(f"Total time of execution: {total_time:.2f} seconds")
MolecularDiffusion/cli/main.py ADDED
@@ -0,0 +1,197 @@
+ """MolCraft CLI - Unified command-line interface for MolecularDiffusion.
+
+ Usage:
+     molcraft train config.yaml [overrides...]
+     molcraft generate config.yaml [overrides...]
+     molcraft predict config.yaml [overrides...]
+ """
+
+ import os
+ import logging
+ import platform
+
+ import click
+
+ # Set up logging
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ def log_system_info():
+     """Log basic system information."""
+     import psutil
+
+     logger.info("=" * 60)
+     logger.info(f"OS: {platform.system()} {platform.release()}")
+     logger.info(f"CPU: {platform.processor()}, Cores: {os.cpu_count()}")
+
+     ram = psutil.virtual_memory()
+     logger.info(f"RAM: Total {ram.total / (1024**3):.2f} GB, Available {ram.available / (1024**3):.2f} GB")
+     logger.info(f"Python: {platform.python_version()}")
+
+     try:
+         import torch
+         logger.info(f"PyTorch: {torch.__version__}")
+         if torch.cuda.is_available():
+             logger.info(f"CUDA: {torch.version.cuda}, GPUs: {torch.cuda.device_count()}")
+     except ImportError:
+         pass
+
+     logger.info("=" * 60)
+
+
+ # Enable -h as an alias for --help
+ CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
+
+
+ @click.group(context_settings=CONTEXT_SETTINGS)
+ @click.version_option(package_name="MolecularDiffusion")
+ def cli():
+     """MolCraft - Molecular Diffusion CLI.
+
+     A unified command-line interface for training, generation, and prediction
+     with molecular diffusion models.
+
+     \b
+     Examples:
+         molcraft train configs/my_train_config.yaml
+         molcraft generate configs/my_gen_config.yaml
+         molcraft predict configs/my_pred_config.yaml
+     """
+     pass
+
+
+ @cli.command(context_settings=CONTEXT_SETTINGS)
+ @click.argument("config", type=str)
+ @click.argument("overrides", nargs=-1)
+ def train(config: str, overrides: tuple):
+     """Train a molecular diffusion model.
+
+     \b
+     Arguments:
+         CONFIG     Config file path (e.g., configs/train.yaml)
+         OVERRIDES  Hydra-style config overrides (e.g., trainer.num_epochs=100)
+
+     \b
+     Examples:
+         molcraft train configs/train_tabasco_geom.yaml
+         molcraft train configs/my_config.yaml trainer.num_epochs=50 seed=42
+     """
+     log_system_info()
+     logger.info(f"Starting training with config: {config}")
+
+     from MolecularDiffusion.cli._hydra import run_hydra_app
+     from MolecularDiffusion.cli.train import train_main
+
+     run_hydra_app(
+         config_name=config,
+         task_function=train_main,
+         config_dir=None,
+         overrides=list(overrides),
+     )
+
+
+ @cli.command(context_settings=CONTEXT_SETTINGS)
+ @click.argument("config", type=str)
+ @click.argument("overrides", nargs=-1)
+ def generate(config: str, overrides: tuple):
+     """Generate molecules using a trained model.
+
+     \b
+     Arguments:
+         CONFIG     Config file path (e.g., configs/generate.yaml)
+         OVERRIDES  Hydra-style config overrides
+
+     \b
+     Examples:
+         molcraft generate configs/gen_config.yaml
+         molcraft generate configs/gen_config.yaml interference.n_samples=1000
+     """
+     log_system_info()
+     logger.info(f"Starting generation with config: {config}")
+
+     from MolecularDiffusion.cli._hydra import run_hydra_app
+     from MolecularDiffusion.cli.generate import generate_main
+
+     run_hydra_app(
+         config_name=config,
+         task_function=generate_main,
+         config_dir=None,
+         overrides=list(overrides),
+     )
+
+
+ @cli.command(context_settings=CONTEXT_SETTINGS)
+ @click.argument("config", type=str)
+ @click.argument("overrides", nargs=-1)
+ def predict(config: str, overrides: tuple):
+     """Run property prediction on molecules.
+
+     \b
+     Arguments:
+         CONFIG     Config file path (e.g., configs/predict.yaml)
+         OVERRIDES  Hydra-style config overrides
+
+     \b
+     Examples:
+         molcraft predict configs/predict.yaml
+         molcraft predict configs/my_pred.yaml xyz_directory=/path/to/xyz
+     """
+     log_system_info()
+     logger.info(f"Starting prediction with config: {config}")
+
+     from MolecularDiffusion.cli._hydra import run_hydra_app
+     from MolecularDiffusion.cli.predict import predict_main
+
+     run_hydra_app(
+         config_name=config,
+         task_function=predict_main,
+         config_dir=None,
+         overrides=list(overrides),
+     )
+
+
+ @cli.command("eval-predict", context_settings=CONTEXT_SETTINGS)
+ @click.argument("config", type=str)
+ @click.argument("overrides", nargs=-1)
+ def eval_predict(config: str, overrides: tuple):
+     """Evaluate model predictions on validation/test sets.
+
+     \b
+     Arguments:
+         CONFIG     Config file path (e.g., configs/eval_predict.yaml)
+         OVERRIDES  Hydra-style config overrides
+
+     \b
+     Examples:
+         molcraft eval-predict configs/eval_predict.yaml
+     """
+     log_system_info()
+     logger.info(f"Starting eval-predict with config: {config}")
+
+     from MolecularDiffusion.cli._hydra import run_hydra_app
+     from MolecularDiffusion.cli.eval_predict import eval_predict_main
+
+     run_hydra_app(
+         config_name=config,
+         task_function=eval_predict_main,
+         config_dir=None,
+         overrides=list(overrides),
+     )
+
+
+ # Register analyze subcommand group
+ from MolecularDiffusion.cli.analyze import analyze
+ cli.add_command(analyze)
+
+
+ def main():
+     """Entry point."""
+     cli()
+
+
+ if __name__ == "__main__":
+     main()
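
Since `cli` is a plain click group, it can be exercised in-process with click's standard test runner; a small smoke-test sketch (not part of the diff):

```python
from click.testing import CliRunner
from MolecularDiffusion.cli.main import cli

runner = CliRunner()
result = runner.invoke(cli, ["train", "--help"])  # '-h' also works via CONTEXT_SETTINGS
assert result.exit_code == 0
print(result.output)
```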
MolecularDiffusion/cli/predict.py ADDED
@@ -0,0 +1,395 @@
+ """Prediction command for MolCraft CLI.
+
+ Adapted from scripts/predict.py for package-level execution.
+ """
+
+ import os
+ from glob import glob
+ from typing import Any, Dict, Tuple
+
+ import hydra
+ import numpy as np
+ import pandas as pd
+ import torch
+ from ase.data import atomic_numbers
+ from omegaconf import DictConfig, OmegaConf
+ from torch_geometric.data import Data
+ from torch_geometric.nn import knn_graph, radius_graph
+ from tqdm import tqdm
+
+ from MolecularDiffusion.core import Engine
+ from MolecularDiffusion.data.component.pointcloud import PointCloud_Mol
+ from MolecularDiffusion.data.component.feature import (
+     onehot,
+     atom_topological,
+     atom_geom,
+     atom_geom_compact,
+     atom_geom_opt,
+     atom_geom_v2,
+     atom_geom_v2_trun,
+ )
+ from MolecularDiffusion.utils import RankedLogger, seed_everything
+ from MolecularDiffusion.utils.plot_function import (
+     plot_kde_distribution,
+     plot_histogram_distribution,
+     plot_kde_distribution_multiple,
+ )
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def is_rank_zero():
+     """Check if current process is rank zero."""
+     if torch.distributed.is_available() and torch.distributed.is_initialized():
+         return torch.distributed.get_rank() == 0
+     return True
+
+
+ def load_model(chkpt_path, task_config=None, atom_vocab=None):
+     """Load a pre-trained model from checkpoint with auto-detection."""
+     log.info(f"Loading checkpoint from: {chkpt_path}")
+
+     # Try loading as a Lightning checkpoint first if it has the .ckpt extension
+     if chkpt_path.endswith('.ckpt'):
+         try:
+             from MolecularDiffusion.core.engine_lightning import EngineLightning
+             wrapper = EngineLightning.load_from_checkpoint(chkpt_path, map_location="cpu")
+             log.info("Successfully loaded model using EngineLightning.load_from_checkpoint")
+
+             # Return something with a .model attribute for backward compatibility
+             class SolverWrapper:
+                 def __init__(self, task):
+                     self.model = task
+
+             solver = SolverWrapper(wrapper.task)
+             solver.model.eval()
+             return solver
+         except Exception as e:
+             log.warning(f"EngineLightning.load_from_checkpoint failed ({type(e).__name__}: {e}). Trying manual fallback.")
+
+     # Manual fallback or original engine (.pkl/no extension)
+     checkpoint = torch.load(chkpt_path, map_location="cpu", weights_only=False)
+
+     # Check if it's a Lightning checkpoint dictionary
+     if "hyper_parameters" in checkpoint:
+         log.info("Detected Lightning checkpoint dictionary.")
+         hparams = checkpoint.get("hyper_parameters", {})
+
+         # Try to get model_config from the checkpoint
+         model_config = hparams.get("model_config", task_config)
+         if model_config is None:
+             raise ValueError("Lightning checkpoint lacks 'model_config' and no 'task_config' provided.")
+
+         # Instantiate task
+         if isinstance(model_config, dict):
+             model_config = OmegaConf.create(model_config)
+
+         # Ensure we have atom_vocab if needed
+         if atom_vocab is not None and ('atom_vocab' not in model_config or model_config.atom_vocab is None):
+             OmegaConf.set_struct(model_config, False)
+             model_config.atom_vocab = atom_vocab
+
+         task_factory = hydra.utils.instantiate(model_config)
+         task = task_factory.build()
+
+         # Load weights, stripping the Lightning 'task.' prefix where present
+         state_dict = checkpoint.get("state_dict", {})
+         cleaned_state_dict = {}
+         for k, v in state_dict.items():
+             if k.startswith("task."):
+                 cleaned_state_dict[k[5:]] = v
+             else:
+                 cleaned_state_dict[k] = v
+
+         task.load_state_dict(cleaned_state_dict, strict=False)
+         log.info(f"Loaded {len(cleaned_state_dict)} parameters from state_dict")
+
+         # Try to recover mean/std if they are in the checkpoint root or state_dict but not as buffers
+         for key in ["mean", "std", "weight"]:
+             val = None
+             if key in checkpoint:
+                 val = checkpoint[key]
+             elif f"task.{key}" in state_dict:
+                 val = state_dict[f"task.{key}"]
+             elif key in state_dict:
+                 val = state_dict[key]
+
+             if val is not None:
+                 if not isinstance(val, torch.Tensor):
+                     val = torch.as_tensor(val, dtype=torch.float32)
+
+                 # Register as a buffer so it moves with the model to the correct device
+                 if key in task._buffers:
+                     task._buffers[key].copy_(val)
+                 else:
+                     task.register_buffer(key, val)
+
+         class SolverWrapper:
+             def __init__(self, task):
+                 self.model = task
+
+         solver = SolverWrapper(task)
+         solver.model.eval()
+         # Ensure task.device reflects where the parameters actually live
+         solver.model.device = next(solver.model.parameters()).device if list(solver.model.parameters()) else torch.device('cpu')
+         return solver
+     else:
+         # Original Engine checkpoint
+         engine = Engine(None, None, None, None, None)
+         solver = engine.load_from_checkpoint(chkpt_path, interference_mode=True)
+         solver.model.eval()
+         # Ensure task.device is updated to the actual device the solver is using
+         solver.model.device = solver.device
+         return solver
+
+
+ def xyz2mol(xyz_file, atom_vocab, node_feature, edge_type="fully_connected",
+             radius=4.0, n_neigh=5, device="cpu"):
+     """Convert an XYZ file into a PyTorch Geometric Data object."""
+     mol_obj = {}
+     mol_xyz = PointCloud_Mol.from_xyz(xyz_file, with_hydrogen=True, forbidden_atoms=[])
+     coords = mol_xyz.get_coord()
+     n_nodes = len(mol_xyz.atoms)
+
+     node_features = []
+     for atom in mol_xyz.atoms:
+         node_features.append(onehot(atom.element, atom_vocab, allow_unknown=False))
+
+     charges = [
+         atomic_numbers[atom.element]
+         for atom in mol_xyz.atoms
+         if atom.element in atomic_numbers
+     ]
+
+     if node_feature:
+         if node_feature in [
+             "atom_topological", "atom_geom", "atom_geom_v2",
+             "atom_geom_v2_trun", "atom_geom_opt", "atom_geom_compact"
+         ]:
+             feature_mapping = {
+                 "atom_topological": atom_topological,
+                 "atom_geom": atom_geom,
+                 "atom_geom_v2": atom_geom_v2,
+                 "atom_geom_v2_trun": atom_geom_v2_trun,
+                 "atom_geom_opt": atom_geom_opt,
+                 "atom_geom_compact": atom_geom_compact,
+             }
+             feature_function = feature_mapping.get(node_feature)
+             if feature_function is not None:
+                 node_features_extra = feature_function(charges, coords)
+                 node_features = torch.cat(
+                     [torch.tensor(node_features), node_features_extra], dim=1
+                 )
+         else:
+             raise ValueError("Unknown node feature type")
+
+     # as_tensor is a no-op for tensors, so this covers both branches above
+     node_features = torch.as_tensor(node_features, dtype=torch.float32)
+     charges = torch.as_tensor(charges, dtype=torch.long)
+     node_mask = torch.ones(n_nodes, dtype=torch.int8)
+
+     edge_mask = node_mask.unsqueeze(0) * node_mask.unsqueeze(1)
+     diag_mask = ~torch.eye(n_nodes, dtype=torch.bool)
+     edge_mask *= diag_mask
+     edge_mask = edge_mask.view(1 * n_nodes * n_nodes, 1)
+     h = node_features.view(1 * n_nodes, -1).clone()
+
+     if edge_type == "distance":
+         edge_index = radius_graph(coords, r=radius)
+     elif edge_type == "neighbor":
+         edge_index = knn_graph(coords, k=n_neigh)
+     elif edge_type == "fully_connected":
+         num_nodes = coords.size(0)
+         row = torch.arange(num_nodes).repeat_interleave(num_nodes)
+         col = torch.arange(num_nodes).repeat(num_nodes)
+         edge_index = torch.stack([row, col], dim=0)
+         edge_index = edge_index[:, row != col]
+     else:
+         raise ValueError(f"Unknown edge type {edge_type}")
+
+     graph_data = Data(
+         x=h,
+         pos=coords,
+         atomic_numbers=charges,
+         natoms=torch.tensor([n_nodes]),
+         edge_index=edge_index,
+         times=torch.tensor([0]),
+         batch=torch.zeros(n_nodes, dtype=torch.long),
+     ).to(device)
+
+     mol_obj["graph"] = graph_data
+     return mol_obj
+
+
+ def count_atoms_from_xyz(path: str) -> int:
+     """Fast atom counter for XYZ files (the first line holds the atom count)."""
+     try:
+         with open(path, "r") as f:
+             first = f.readline().strip()
+         return int(first)
+     except Exception:
+         return 0
+
+
+ def _runner(solver, xyz_paths: list, max_atoms: int = 100) -> Tuple[np.ndarray, list]:
+     """Run predictions on a list of XYZ files; returns (predictions, kept paths)."""
+     device = getattr(
+         solver.model, 'device',
+         next(solver.model.parameters()).device if list(solver.model.parameters()) else torch.device('cpu'),
+     )
+     task_names = list(solver.model.task.keys())
+     num_molecules = len(xyz_paths)
+
+     progress_bar = tqdm(
+         enumerate(xyz_paths),
+         desc="Predicting molecules",
+         leave=True,
+         dynamic_ncols=True,
+         total=num_molecules,
+     )
+
+     predictions = []
+     xyz_paths_clear = []
+     skipped = 0
+
+     for i, xyz_path in progress_bar:
+         n_atoms = count_atoms_from_xyz(xyz_path)
+         if n_atoms > max_atoms:
+             skipped += 1
+             progress_bar.set_postfix({"batch": i + 1, "skipped": skipped})
+             log.info(f"Skipping {xyz_path} (atoms={n_atoms} > max_atoms={max_atoms})")
+             continue
+
+         mol_obj = xyz2mol(
+             xyz_file=xyz_path,
+             atom_vocab=solver.model.atom_vocab,
+             node_feature=solver.model.node_feature,
+             device=device,
+         )
+         prediction = solver.model.predict(mol_obj, evaluate=True)[0]
+         predictions.append(prediction.detach().cpu().numpy())
+         current_preds_dict = {prop_name: prediction[j].item() for j, prop_name in enumerate(task_names)}
+         progress_bar.set_postfix({"batch": i + 1, "skipped": skipped, **current_preds_dict})
+         xyz_paths_clear.append(xyz_path)
+
+     predictions = np.array(predictions)
+     if predictions.ndim > 1 and predictions.shape[-1] == 1:
+         predictions = predictions.squeeze(-1)
+
+     return predictions, xyz_paths_clear
+
+
+ def runner(cfg: DictConfig) -> None:
+     """Property prediction run."""
+     if cfg.get("seed"):
+         seed_everything(cfg.seed, workers=True)
+
+     log.info(f"Instantiating diffusion task and loading the model <{cfg.tasks._target_}>")
+     solver = load_model(cfg.chkpt_directory, task_config=cfg.tasks, atom_vocab=cfg.atom_vocab)
+
+     task_names = list(solver.model.task.keys())
+
+     if not hasattr(solver.model, 'std') or solver.model.std is None:
+         chkpt = torch.load(cfg.chkpt_directory, weights_only=False)
+         if "model" in chkpt:
+             solver.model.std = chkpt["model"].get("std", torch.ones(1)).to(solver.model.device)
+             solver.model.weight = chkpt["model"].get("weight", torch.ones(1)).to(solver.model.device)
+             solver.model.mean = chkpt["model"].get("mean", torch.zeros(1)).to(solver.model.device)
+         elif "state_dict" in chkpt:
+             # Fallback for Lightning if not already loaded by load_model
+             sd = chkpt["state_dict"]
+             solver.model.std = sd.get("task.std", sd.get("std", torch.ones(1))).to(solver.model.device)
+             solver.model.weight = sd.get("task.weight", sd.get("weight", torch.ones(1))).to(solver.model.device)
+             solver.model.mean = sd.get("task.mean", sd.get("mean", torch.zeros(1))).to(solver.model.device)
+
+     if not hasattr(solver.model, 'atom_vocab'):
+         solver.model.atom_vocab = cfg.atom_vocab
+     if not hasattr(solver.model, 'node_feature'):
+         solver.model.node_feature = cfg.node_feature
+
+     object_dict = {"cfg": cfg, "solver": solver}
+
+     log.info("Logging hyperparameters!")
+     log_hyperparameters(object_dict)
+
+     os.makedirs(cfg.output_directory, exist_ok=True)
+
+     if is_rank_zero():
+         config_path = os.path.join(cfg.output_directory, "config.yaml")
+         with open(config_path, "w") as f:
+             OmegaConf.save(config=cfg, f=f)
+         log.info(f"Configuration saved to {config_path}")
+
+     log.info("Running the predictions...")
+     xyz_paths = glob(f"{cfg.xyz_directory}/*.xyz")
+     xyz_paths = [str(xyz_path) for xyz_path in xyz_paths]
+     predictions, xyz_paths_clear = _runner(solver, xyz_paths, max_atoms=cfg.get("max_atoms", 100))
+
+     if predictions.ndim == 1:
+         predictions = predictions[:, None]  # single-task case: keep shape (N, 1) for the loop below
+
+     df_dicts = {}
+     for task_name, prediction in zip(task_names, predictions.T):
+         df_dicts[task_name] = prediction
+     df_dicts["xyz_path"] = xyz_paths_clear
+
+     df = pd.DataFrame(df_dicts)
+     df = df.sort_values(by="xyz_path")
+     df.to_csv(f"{cfg.output_directory}/predictions.csv", index=False)
+
+     log.info("Prediction statistics:")
+     for task_name in task_names:
+         log.info(f"--- {task_name} ---")
+         log.info(f"Mean: {df[task_name].mean():.4f}")
+         log.info(f"Std: {df[task_name].std():.4f}")
+         log.info(f"Min: {df[task_name].min():.4f}")
+         log.info(f"Max: {df[task_name].max():.4f}")
+
+     log.info("Plotting distributions...")
+     props = []
+     for prop in task_names:
+         plot_kde_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_kde.png")
+         plot_histogram_distribution(df[prop], prop, f"{cfg.output_directory}/{prop}_hist.png")
+         props.append(df[prop].values)
+
+     props = np.array(props).T
+     plot_kde_distribution_multiple(props, task_names, f"{cfg.output_directory}/kde_all.png")
+
+
+ def log_hyperparameters(object_dict: dict):
+     """Log hyperparameters for debugging."""
+     if not is_rank_zero():
+         return
+
+     log.info("\n========== Logging Hyperparameters ==========\n")
+     for name, obj in object_dict.items():
+         log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
+         if name == "cfg":
+             if isinstance(obj, dict):
+                 log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
+             else:
+                 log.info("\n" + OmegaConf.to_yaml(obj))
+         else:
+             if hasattr(obj, '__dict__'):
+                 for k, v in vars(obj).items():
+                     if not k.startswith("_"):
+                         log.info(f"{k}: {v}")
+         log.info(f"{'=' * (44 + len(name))}\n")
+
+     if "task" in object_dict and hasattr(object_dict["task"], "task"):
+         model = object_dict["task"].task
+         total = sum(p.numel() for p in model.parameters())
+         trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+         log.info(f"{'=' * 20} MODEL PARAMS {'=' * 20}")
+         log.info(f"model/params/total: {total}")
+         log.info(f"model/params/trainable: {trainable}")
+         log.info("=" * 54 + "\n")
+
+     log.info("========== End of Hyperparameters ==========\n")
+
+
+ def predict_main(cfg: DictConfig):
+     """Entry point for CLI predict command."""
+     runner(cfg)
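
For reference, the `fully_connected` branch of `xyz2mol()` builds every directed pair (i, j) with i != j; a self-contained illustration of that construction:

```python
import torch

num_nodes = 3
row = torch.arange(num_nodes).repeat_interleave(num_nodes)  # 0,0,0,1,1,1,2,2,2
col = torch.arange(num_nodes).repeat(num_nodes)             # 0,1,2,0,1,2,0,1,2
edge_index = torch.stack([row, col], dim=0)                 # shape (2, 9)
edge_index = edge_index[:, row != col]                      # drop self-loops -> (2, 6)
print(edge_index)
```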
MolecularDiffusion/cli/train.py ADDED
@@ -0,0 +1,453 @@
+ """Training command for MolCraft CLI.
+
+ Adapted from scripts/train.py for package-level execution.
+ """
+
+ from typing import Any, Dict, Optional, Tuple
+ import os
+ import pickle
+ import logging
+
+ import hydra
+ import torch
+ from omegaconf import DictConfig, OmegaConf
+
+ from MolecularDiffusion.core import Engine
+ from MolecularDiffusion.runmodes.train import (
+     evaluate,
+     DataModule,
+     Logger,
+     OptimSchedulerFactory,
+     get_versioned_output_path,
+ )
+ from MolecularDiffusion.utils import (
+     RankedLogger,
+     task_wrapper,
+     seed_everything,
+ )
+
+ log = RankedLogger(__name__, rank_zero_only=True)
+
+
+ def is_rank_zero():
+     """Check if current process is rank zero."""
+     if torch.distributed.is_available() and torch.distributed.is_initialized():
+         return torch.distributed.get_rank() == 0
+     return True
+
+
+ def load_weights(task, ckpt_path):
+     """Load model weights from a checkpoint file (weights only).
+
+     This loads the state_dict from the checkpoint into the task model,
+     ignoring optimizer/scheduler states and other metadata.
+     Useful for fine-tuning or starting from a pre-trained model.
+     """
+     if not os.path.exists(ckpt_path):
+         raise FileNotFoundError(f"Checkpoint not found at: {ckpt_path}")
+
+     log.info(f"Loading weights from: {ckpt_path}")
+
+     # Load checkpoint
+     checkpoint = torch.load(ckpt_path, map_location="cpu", weights_only=False)
+     state_dict = checkpoint.get("state_dict", checkpoint)
+
+     # Prepare state dict for loading
+     cleaned_state_dict = {}
+     for key, value in state_dict.items():
+         # Strip 'task.' prefix if present (common in Lightning checkpoints)
+         if key.startswith("task."):
+             cleaned_state_dict[key[5:]] = value
+         else:
+             cleaned_state_dict[key] = value
+
+     # Load into task
+     missing, unexpected = task.load_state_dict(cleaned_state_dict, strict=False)
+
+     if len(missing) > 0:
+         log.warning(f"Missing keys when loading weights: {missing[:5]}{'...' if len(missing) > 5 else ''}")
+     if len(unexpected) > 0:
+         log.warning(f"Unexpected keys in checkpoint: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")
+
+     log.info(f"Successfully loaded {len(cleaned_state_dict)} parameters into task.")
+
+
+ # Lightning imports (optional)
+ try:
+     import pytorch_lightning as pl
+     from pytorch_lightning.callbacks import ModelCheckpoint
+     from pytorch_lightning.callbacks import LearningRateMonitor
+     from MolecularDiffusion.core.engine_lightning import EngineLightning
+     from MolecularDiffusion.data.lightning_data_module import MolecularDiffusionDataModule
+     from MolecularDiffusion.core.lightning_callbacks import GenerativeEvalCallback
+     LIGHTNING_AVAILABLE = True
+ except ImportError as e:
+     LIGHTNING_AVAILABLE = False
+     log.warning(f"PyTorch Lightning not found: {e}. Only original Engine available.")
+
+
+ def engine_wrapper(task_module, data_module, trainer_module, logger_module,
+                    resume_from_checkpoint=None, **kwargs):
+     """Training loop using the original Engine."""
+     trainer_module.get_optimizer()
+     trainer_module.get_scheduler()
+
+     solver = Engine(
+         task_module.task,
+         data_module.train_set,
+         data_module.valid_set,
+         data_module.test_set,
+         batch_size=data_module.batch_size,
+         collate_fn=data_module.collate_fn,
+         optimizer=trainer_module.optimizer,
+         ema_decay=trainer_module.ema_decay,
+         scheduler=trainer_module.scheduler,
+         clipping_gradient=trainer_module.gradient_clip_mode,
+         clip_value=trainer_module.gradnorm_queue,
+         logger=logger_module.logger,
+         log_interval=logger_module.log_interval,
+         name_wandb=logger_module.name_wandb,
+         project_wandb=logger_module.project_wandb,
+         dir_wandb=trainer_module.output_path,
+     )
+
+     # Resume from checkpoint if provided
+     start_epoch = 0
+     if resume_from_checkpoint:
+         start_epoch = solver.resume(resume_from_checkpoint, strict=False)
+         log.info(f"Resumed from epoch {start_epoch}")
+
+     use_amp = trainer_module.precision in ["bf16", 16]
+
+     best_checkpoints = []
+     if hasattr(task_module.task, "sample") and kwargs.get("generative_analysis"):
+         best_metrics = -torch.inf
+         models_to_save = {"node": task_module.task.node_dist_model}
+         if len(task_module.condition_names) > 0:
+             models_to_save["prop"] = task_module.task.prop_dist_model
+         if is_rank_zero():
+             with open(os.path.join(trainer_module.output_path, "edm_stat.pkl"), "wb") as f:
+                 pickle.dump(models_to_save, f)
+     else:
+         best_metrics = torch.inf
+
+     # Create versioned checkpoint folder (like Lightning's version_X folders)
+     versioned_ckpt_path = get_versioned_output_path(trainer_module.output_path)
+
+     # Loop continues from start_epoch when resuming
+     for i in range(start_epoch, trainer_module.num_epochs):
+         solver.train(num_epoch=1, use_amp=use_amp, precision=trainer_module.precision)
+         if i % trainer_module.validation_interval == 0 or i == trainer_module.num_epochs - 1:
+             if hasattr(task_module.task, "sample"):
+                 output_generated_dir = os.path.join(versioned_ckpt_path, "generated_molecules")
+                 os.makedirs(output_generated_dir, exist_ok=True)
+                 best_metrics, best_checkpoints = evaluate(
+                     task_module.task_type, solver, i, best_metrics, best_checkpoints,
+                     logger_module.logger, output_generated_dir=output_generated_dir,
+                     generative_analysis=kwargs.get("generative_analysis", False),
+                     n_samples=kwargs.get("n_samples", 100),
+                     metric=kwargs.get("metric", "Validity Relax and connected"),
+                     output_path=versioned_ckpt_path,
+                     use_amp=use_amp, precision=trainer_module.precision,
+                     use_posebuster=kwargs.get("use_posebuster", False),
+                     batch_size=kwargs.get("batch_size", 1),
+                     save_top_k=getattr(trainer_module, "save_top_k", 3),
+                     save_every_val_epoch=getattr(trainer_module, "save_every_val_epoch", False),
+                 )
+             else:
+                 best_metrics, best_checkpoints = evaluate(
+                     task_module.task_type, solver, i, best_metrics, best_checkpoints,
+                     logger_module.logger, output_path=versioned_ckpt_path,
+                     save_top_k=getattr(trainer_module, "save_top_k", 3),
+                     save_every_val_epoch=getattr(trainer_module, "save_every_val_epoch", False),
+                 )
+     return best_metrics, solver
+
+
+ def lightning_wrapper(task_module, data_module, trainer_module, logger_module, engine_cfg,
+                       ckpt_path=None, monitor_metric=None, monitor_mode=None, model_config=None, **kwargs):
+     """Training using the PyTorch Lightning Trainer."""
+     if not LIGHTNING_AVAILABLE:
+         raise ImportError("PyTorch Lightning required. Install with: pip install pytorch-lightning")
+
+     if hasattr(task_module.task, "preprocess"):
+         log.info("Calling task.preprocess() for Lightning engine")
+         result = task_module.task.preprocess(data_module.train_set)
+         if result is not None:
+             data_module.train_set, data_module.valid_set, data_module.test_set = result
+
+     pl_data_module = MolecularDiffusionDataModule(
+         data_module=data_module,
+         batch_size=data_module.batch_size,
+         num_workers=getattr(trainer_module, "num_worker", 0),
+     )
+
+     pl_module = EngineLightning(
+         task=task_module.task,
+         optimizer_config={
+             "optimizer_choice": trainer_module.optimizer_choice,
+             "lr": trainer_module.lr,
+             "weight_decay": trainer_module.weight_decay,
+             "betas": trainer_module.betas,
+             "eps": trainer_module.eps,
+         },
+         scheduler_config={
+             "scheduler": trainer_module.scheduler_choice,
+             "scheduler_kwargs": trainer_module.scheduler_choice_kwargs,
+         },
+         model_config=model_config,
+         monitor_metric=monitor_metric,
+         ema_decay=trainer_module.ema_decay,
+         gradnorm_queue=trainer_module.gradnorm_queue,
+         gradient_clip_algorithm=getattr(trainer_module, 'gradient_clip_algorithm', 'adaptive'),
+     )
+
+     callbacks = []
+
+     if hasattr(task_module.task, "sample") and kwargs.get("generative_analysis"):
+         callbacks.append(GenerativeEvalCallback(
+             n_samples=kwargs.get("n_samples", 100),
+             batch_size=kwargs.get("batch_size", 100),
+             metric=kwargs.get("metric", "Validity Relax and connected"),
+             output_dir=os.path.join(trainer_module.output_path, "generated_molecules"),
+             use_posebuster=kwargs.get("use_posebuster", False),
+             monitor_metric=monitor_metric,
+         ))
+
+     # Checkpoint callback; handle OmegaConf ListConfig properly
+     if monitor_metric is not None:
+         # Convert OmegaConf types to Python types
+         if OmegaConf.is_list(monitor_metric):
+             monitor_metric_key = str(monitor_metric[0])
+         elif isinstance(monitor_metric, (list, tuple)):
+             monitor_metric_key = str(monitor_metric[0])
+         else:
+             monitor_metric_key = str(monitor_metric)
+         mode = monitor_mode or ("min" if "loss" in monitor_metric_key else "max")
+     elif hasattr(task_module.task, "sample"):
+         monitor_metric_key = f"gen/{kwargs.get('metric', 'Validity Relax and connected')}"
+         mode = "max"
+     else:
+         monitor_metric_key = "val/loss"
+         mode = "min"
+
+     # Handle save_every_val_epoch
+     save_top_k = trainer_module.save_top_k
+     if getattr(trainer_module, "save_every_val_epoch", False) or kwargs.get("save_every_val_epoch", False):
+         log.info("save_every_val_epoch=True: Overriding save_top_k to -1 (save all checkpoints)")
+         save_top_k = -1
+
+     callbacks.append(ModelCheckpoint(
+         monitor=monitor_metric_key,
+         mode=mode,
+         save_top_k=save_top_k,
+         filename=f"epoch={{epoch}}-{monitor_metric_key.replace('/', '_').replace(' ', '_')}={{{monitor_metric_key}:.3f}}",
+         save_last=True,
+     ))
+
+     # Learning rate monitor for wandb logging
+     callbacks.append(LearningRateMonitor(logging_interval='step'))
+
+     trainer_config = OmegaConf.to_container(engine_cfg.trainer_config, resolve=True)
+     precision_map = {32: 32, 16: "16-mixed", "16": "16-mixed", "bf16": "bf16-mixed"}
+     trainer_config["precision"] = precision_map.get(trainer_config.get("precision", 32), 32)
+
+     if logger_module.logger == "wandb":
+         pl_logger = pl.loggers.WandbLogger(
+             project=logger_module.project_wandb,
+             name=logger_module.name_wandb,
+             save_dir=trainer_module.output_path,
+         )
+     else:
+         pl_logger = True
+
+     trainer = hydra.utils.instantiate(trainer_config, callbacks=callbacks, logger=pl_logger)
+
+     # ckpt_path=None is a no-op, so a single fit() call covers both cases
+     trainer.fit(pl_module, datamodule=pl_data_module, ckpt_path=ckpt_path)
+
+     return trainer.callback_metrics, trainer
+
+
+ @task_wrapper
+ def train(cfg: DictConfig) -> Tuple[Dict[str, Any], Dict[str, Any]]:
+     """Main training function."""
+     output_path = cfg.trainer.output_path
+     os.makedirs(output_path, exist_ok=True)
+
+     if is_rank_zero():
+         config_path = os.path.join(output_path, "config.yaml")
+         with open(config_path, "w") as f:
+             OmegaConf.save(config=cfg, f=f)
+         log.info(f"Configuration saved to {config_path}")
+
+     if cfg.get("seed"):
+         seed_everything(cfg.seed, workers=True)
+
+     log.info(f"Instantiating datamodule <{cfg.data._target_}>")
+     data_module: DataModule = hydra.utils.instantiate(cfg.data, task_type=cfg.tasks.task_type)
+     data_module.load()
+
+     log.info(f"Instantiating task <{cfg.tasks._target_}>")
+     data_point_chk = data_module.train_set[0]
+     node_feature_0 = getattr(data_point_chk, "node_feature", None)
+     if node_feature_0 is not None:
+         n_dim = node_feature_0.shape[1]
+     else:
+         try:
+             node_feature_0 = getattr(data_point_chk, "x", None)
+             n_dim = node_feature_0.shape[1]
+         except (AttributeError, IndexError):
+             n_dim = 0
+
+     factory_cfg = cfg.tasks
+     overrides = {}
+
+     if "tasks_egt" in factory_cfg._target_ or "tasks_esen" in factory_cfg._target_ or "diffusion_tabasco" in factory_cfg._target_:
+         overrides["train_set"] = data_module.train_set
+         if "condition_names" in factory_cfg:
+             overrides["task_names"] = factory_cfg.condition_names
+
+     if "atom_vocab" in cfg.data:
+         overrides["atom_vocab"] = list(cfg.data.atom_vocab)
+         if cfg.data.get("allow_unknown", False):
+             overrides["atom_vocab"].append("Suisei")
+
+     if cfg.tasks.get("metrics", None) == "valid_posebuster":
+         overrides["use_posebuster"] = True
+         try:
+             import posebusters
+         except ImportError:
+             log.warning("PoseBusters not installed. Falling back to 'Validity Relax and connected'.")
+             overrides["use_posebuster"] = False
+             overrides["metrics"] = ["Validity Relax and connected"]
+
+     task_module = hydra.utils.instantiate(factory_cfg, **overrides)
+     task_module.build()
+
+     # Optional: load weights from a checkpoint (without resuming full state)
+     if cfg.trainer.get("load_weights_from"):
+         load_weights(task_module.task, cfg.trainer.load_weights_from)
+
+     log.info(f"Instantiating trainer <{cfg.trainer._target_}>")
+     trainer_module: OptimSchedulerFactory = hydra.utils.instantiate(
+         cfg.trainer, parameters=task_module.task.parameters()
+     )
+
+     name_wandb = trainer_module.output_path.split('/')[-1] if "/" in trainer_module.output_path else trainer_module.output_path
+     log.info(f"Instantiating loggers... <{cfg.logger._target_}>")
+     logger_module: Logger = hydra.utils.instantiate(cfg.logger, name_wandb=name_wandb)
+
+     object_dict = {
+         "cfg": cfg,
+         "datamodule": data_module,
+         "task": task_module,
+         "trainer": trainer_module,
+         "logger": logger_module,
+     }
+
+     log.info("Logging hyperparameters!")
+     log_hyperparameters(object_dict)
+
+     engine_type = cfg.get("engine", {}).get("engine_type", "original")
+     log.info(f"Using engine: {engine_type}")
+
+     if engine_type == "lightning":
+         gen_analysis = cfg.get("generative_analysis", cfg.tasks.get("generative_analysis", False))
+         n_samples = cfg.get("n_samples", cfg.tasks.get("n_samples", 100))
+         metric = cfg.get("metrics", cfg.get("metric", cfg.tasks.get("metrics", "Validity Relax and connected")))
+         use_posebuster = cfg.get("use_posebuster", cfg.tasks.get("use_posebuster", False))
+         gen_batch_size = cfg.get("batch_size", cfg.tasks.get("batch_size", 100))
+
+         # Always save model_config for checkpoint reconstruction (VAE, LDM, etc.)
+         model_config = OmegaConf.to_container(factory_cfg, resolve=True)
+         for k, v in overrides.items():
+             if k != "train_set":
+                 model_config[k] = v
+
+         if hasattr(task_module.task, "sample"):
+             metrics = lightning_wrapper(
+                 task_module, data_module, trainer_module, logger_module,
+                 engine_cfg=cfg.engine,
+                 generative_analysis=gen_analysis, n_samples=n_samples,
+                 metric=metric, use_posebuster=use_posebuster, batch_size=gen_batch_size,
+                 ckpt_path=cfg.trainer.get("resume_from_checkpoint", None),
+                 monitor_metric=cfg.trainer.get("monitor_metric", None),
+                 monitor_mode=cfg.trainer.get("monitor_mode", None),
+                 model_config=model_config,
+             )
+         else:
+             metrics = lightning_wrapper(
+                 task_module, data_module, trainer_module, logger_module,
+                 engine_cfg=cfg.engine,
+                 ckpt_path=cfg.trainer.get("resume_from_checkpoint", None),
+                 monitor_metric=cfg.trainer.get("monitor_metric", None),
+                 monitor_mode=cfg.trainer.get("monitor_mode", None),
+                 model_config=model_config,
+             )
+
+     elif engine_type == "original":
+         resume_ckpt = cfg.trainer.get("resume_from_checkpoint", None)
+         if hasattr(task_module.task, "sample"):
+             metrics = engine_wrapper(
+                 task_module, data_module, trainer_module, logger_module,
+                 resume_from_checkpoint=resume_ckpt,
+                 generative_analysis=cfg.tasks.generative_analysis,
+                 n_samples=cfg.tasks.n_samples,
+                 metric=cfg.tasks.metrics,
+                 use_posebuster=cfg.tasks.use_posebuster,
+                 batch_size=cfg.tasks.batch_size,
+             )
+         else:
+             metrics = engine_wrapper(
+                 task_module, data_module, trainer_module, logger_module,
+                 resume_from_checkpoint=resume_ckpt,
+             )
+     else:
+         raise ValueError(f"Unknown engine_type: {engine_type}")
+
+     return metrics, object_dict
+
+
+ def log_hyperparameters(object_dict: dict):
+     """Log hyperparameters for debugging."""
+     if not is_rank_zero():
+         return
+
+     log.info("\n========== Logging Hyperparameters ==========\n")
+     for name, obj in object_dict.items():
+         log.info(f"{'=' * 20} {name.upper()} {'=' * 20}")
+         if name == "cfg":
+             if isinstance(obj, dict):
+                 log.info("\n" + OmegaConf.to_yaml(OmegaConf.create(obj)))
+             else:
+                 log.info("\n" + OmegaConf.to_yaml(obj))
+         else:
+             if hasattr(obj, '__dict__'):
+                 for k, v in vars(obj).items():
+                     if not k.startswith("_"):
+                         log.info(f"{k}: {v}")
+         log.info(f"{'=' * (44 + len(name))}\n")
+
+     if "task" in object_dict and hasattr(object_dict["task"], "task"):
+         model = object_dict["task"].task
+         total = sum(p.numel() for p in model.parameters())
+         trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
+         log.info(f"{'=' * 20} MODEL PARAMS {'=' * 20}")
+         log.info(f"model/params/total: {total}")
+         log.info(f"model/params/trainable: {trainable}")
+         log.info("=" * 54 + "\n")
+
+     log.info("========== End of Hyperparameters ==========\n")
+
+
+ def train_main(cfg: DictConfig):
+     """Entry point for CLI train command."""
+     metric, _ = train(cfg)
+     return metric
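
The `task.` prefix stripping in `load_weights()` is easy to verify in isolation; Lightning stores parameters under the wrapper module's attribute name:

```python
# Toy state dict mimicking an EngineLightning checkpoint.
state_dict = {"task.encoder.weight": 1, "task.encoder.bias": 2, "epoch": 3}
cleaned = {k[5:] if k.startswith("task.") else k: v for k, v in state_dict.items()}
assert cleaned == {"encoder.weight": 1, "encoder.bias": 2, "epoch": 3}
```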
MolecularDiffusion/configs/data/filter_molecules_by_property.py ADDED
File without changes
MolecularDiffusion/configs/data/formed_data.yaml ADDED
@@ -0,0 +1,20 @@
+ _target_: MolecularDiffusion.runmodes.train.DataModule
+ root: /home/pregabalin/RF/blue_edm/data/formed
+ filename: /home/pregabalin/RF/blue_edm/data/formed/Data_FORMED_scored.csv # 4k or ready
+ atom_vocab: [H,B,C,N,O,F,Al,Si,P,S,Cl,As,Se,Br,I,Hg,Bi]
+ dataset_name: formed
+ with_hydrogen: True
+ node_feature: null # atom_topological, atom_geom, atom_geom_compact, atom_geom_opt
+ max_atom: 120
+ xyz_dir: /home/pregabalin/RF/blue_edm/data/formed/XYZ_FORMED/
+ coord_file: null
+ natoms_file: null
+ forbidden_atom: []
+ data_efficient_collator: True
+ train_ratio: 0.8
+ load_pkl: null
+ save_pkl: data/test.pkl
+ data_type: pyg # pyg or pointcloud
+ batch_size: 32
+ num_workers: 0
+ allow_unknown: False # additional atom type for the unknown in OHE
MolecularDiffusion/configs/data/mol_dataset.yaml ADDED
@@ -0,0 +1,25 @@
+ _target_: MolecularDiffusion.runmodes.train.DataModule
+ root: data/
+ filename: path_to_csv.csv # 4k or ready
+ atom_vocab: [H,B,C,N,O,F,Al,Si,P,S,Cl,As,Se,Br,I,Hg,Bi] # Ge,Sn,Te,Sb
+ dataset_name: qm9
+ with_hydrogen: True
+ use_ohe_feature: True
+ allow_unknown: False # True to add +1 "unknown" column to OHE for rare/unseen atoms
+ node_feature_choice: null # atom_topological, atom_geom, atom_geom_compact, atom_geom_opt
+ max_atom: 29
+ xyz_dir: path_to_xyz
+ coord_file: null
+ natoms_file: null
+ forbidden_atom: []
+ data_efficient_collator: True
+ train_ratio: 0.8
+ load_pkl: null
+ save_pkl: data/test.pkl # TODO: this is not really used anymore
+ data_type: pointcloud # pyg or pointcloud
+ batch_size: 48
+ num_workers: 0
+ edge_type: fully_connected
+ radius: 4.0
+ n_neigh: 5
+ # consider_global_attributes: False # deprecated
MolecularDiffusion/configs/data/mol_dataset_extraf.yaml ADDED
@@ -0,0 +1,23 @@
+ _target_: MolecularDiffusion.runmodes.train.DataModule
+ root: /home/pregabalin/RF/blue_edm/data/qm9
+ filename: /home/pregabalin/RF/blue_edm/data/qm9/dsgdb9nsd_4k.csv # 4k or ready
+ atom_vocab: [H,B,C,N,O,F,Al,Si,P,S,Cl,As,Se,Br,I,Hg,Bi]
+ dataset_name: qm9
+ with_hydrogen: True
+ node_feature: atom_geom_compact # atom_topological, atom_geom, atom_geom_compact, atom_geom_opt
+ max_atom: 29
+ xyz_dir: /home/pregabalin/RF/blue_edm/data/qm9/dsgdb9nsd/
+ coord_file: null
+ natoms_file: null
+ forbidden_atom: []
+ data_efficient_collator: True
+ train_ratio: 0.8
+ load_pkl: null
+ save_pkl: data/test.pkl
+ data_type: pointcloud # pyg or pointcloud
+ batch_size: 48
+ num_workers: 0
+ allow_unknown: False # additional atom type for the unknown in OHE
+ edge_type: fully_connected
+ radius: 4.0
+ n_neigh: 5
MolecularDiffusion/configs/engine/lightning.yaml ADDED
@@ -0,0 +1,33 @@
+ # Use PyTorch Lightning Trainer
+ engine_type: lightning
+
+ # Lightning-specific trainer configuration
+ trainer_config:
+   _target_: pytorch_lightning.Trainer
+
+   # Training
+   max_epochs: ${trainer.num_epochs}
+   accelerator: auto
+   devices: auto
+   strategy: auto # Lightning auto-selects ddp/ddp_spawn based on devices
+
+   # Precision - converted to the Lightning format in Python
+   precision: ${trainer.precision}
+
+   # Optimization
+   accumulate_grad_batches: 1
+   gradient_clip_val: ${trainer.grad_clip_value}
+   gradient_clip_algorithm: ${trainer.gradient_clip_mode}
+
+   # Logging & Validation
+   log_every_n_steps: ${logger.log_interval}
+   check_val_every_n_epoch: ${trainer.validation_interval}
+
+   # Checkpointing
+   enable_checkpointing: true
+   default_root_dir: ${trainer.output_path}
+
+   # Other
+   num_sanity_val_steps: 0 # Skip sanity validation
+   enable_progress_bar: true
+   enable_model_summary: true
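
The `${trainer.*}` references above are OmegaConf interpolations, resolved once Hydra merges the config groups; a small sketch with illustrative values:

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "trainer": {"num_epochs": 100, "precision": "bf16"},
    "engine": {"trainer_config": {"max_epochs": "${trainer.num_epochs}",
                                  "precision": "${trainer.precision}"}},
})
print(cfg.engine.trainer_config.max_epochs)  # 100
print(cfg.engine.trainer_config.precision)   # bf16 (train.py maps this to "bf16-mixed")
```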
MolecularDiffusion/configs/engine/original.yaml ADDED
@@ -0,0 +1,4 @@
+ # Use the original custom Engine class
+ engine_type: original
+
+ # Engine is instantiated inline in train.py using engine_wrapper()
MolecularDiffusion/configs/hydra/default.yaml ADDED
@@ -0,0 +1,19 @@
+ # https://hydra.cc/docs/configure_hydra/intro/
+
+ # enable color logging
+ # install hydra-colorlog==1.2.0
+ defaults:
+   - override hydra_logging: colorlog
+   - override job_logging: colorlog
+
+ # output directory, generated dynamically on each run
+ run:
+   dir: ${trainer.output_path}
+   # dir: ${trainer.output_path}/${tasks.task_type}/runs/${name}_${now:%Y-%m-%d}_${now:%H-%M-%S}
+
+ job_logging:
+   handlers:
+     file:
+       # Incorporates fix from https://github.com/facebookresearch/hydra/pull/2242
+       filename: ${hydra.runtime.output_dir}/${name}_${now:%Y-%m-%d}_${now:%H-%M-%S}.log
MolecularDiffusion/configs/interference/gen_cfg.yaml ADDED
@@ -0,0 +1,15 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: cfg
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [0, 0]
+ max_mol_size: 0
+ target_values: [3, 1.5]
+ property_names: ["S1_exc", "T1_exc"]
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
+ condition_configs:
+   cfg_scale: 1
+   cfg_scale_schedule: null
MolecularDiffusion/configs/interference/gen_cfggg.yaml ADDED
@@ -0,0 +1,29 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: gradient_guidance # cfggg
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [0, 0]
+ max_mol_size: 0
+ target_values: [3, 1.5]
+ property_names: ["S1_exc", "T1_exc"]
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
+ condition_configs:
+   cfg_scale: 1
+   target_function:
+     _target_: scripts.gradient_guidance.sf_energy_score.SFEnergyScore
+     _partial_: true
+   chkpt_directory: trained_models/egcl_guidance_s1t1.ckpt
+   gg_scale: 1e-3
+   max_norm: 1e-3
+   scheduler:
+     _target_: scripts.gradient_guidance.scheduler.CosineAnnealing
+     _partial_: true
+     T_max: 1000
+     eta_min: 0
+   guidance_ver: 2
+   guidance_at: 1
+   guidance_stop: 0
+   n_backwards: 3
MolecularDiffusion/configs/interference/gen_conditional.yaml ADDED
@@ -0,0 +1,12 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: conditional
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [0, 0]
+ max_mol_size: 0
+ target_values: [3, 1.5]
+ property_names: ["S1_exc", "T1_exc"]
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
MolecularDiffusion/configs/interference/gen_gg.yaml ADDED
@@ -0,0 +1,29 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: gradient_guidance # gg
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [0, 0]
+ max_mol_size: 0
+ target_values: []
+ property_names: []
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
+ condition_configs:
+   cfg_scale: 0
+   target_function:
+     _target_: scripts.gradient_guidance.sf_energy_score.SFEnergyScore
+     _partial_: true
+   chkpt_directory: trained_models/egcl_guidance_s1t1.ckpt
+   gg_scale: 1e-3
+   max_norm: 1e-3
+   scheduler:
+     _target_: scripts.gradient_guidance.scheduler.CosineAnnealing
+     _partial_: true
+     T_max: 1000
+     eta_min: 0
+   guidance_ver: 2
+   guidance_at: 1
+   guidance_stop: 0
+   n_backwards: 0
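
The `_partial_: true` entries above make `hydra.utils.instantiate` return a `functools.partial` instead of an instance, so the generator can inject runtime arguments later; a minimal, repo-independent sketch of that mechanism:

```python
import hydra
from omegaconf import OmegaConf

node = OmegaConf.create({"_target_": "collections.Counter", "_partial_": True})
make_counter = hydra.utils.instantiate(node)  # functools.partial(Counter)
print(make_counter("aab"))                    # Counter({'a': 2, 'b': 1})
```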
MolecularDiffusion/configs/interference/gen_hybrid.yaml ADDED
@@ -0,0 +1,28 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: inpaint_cfg
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [0, 0]
+ max_mol_size: 0
+ target_values: [3, 1.5]
+ property_names: ["S1_exc", "T1_exc"]
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
+ condition_configs:
+   cfg_scale: 1
+   reference_structure_path: "data/template_structures/INT2_0.xyz"
+   inpaint_cfgs:
+     t_start: 0.8
+     t_critical: 0.05
+
+     # inpaint
+     # denoising_strength: 0.7
+     # noise_initial_mask: False
+     # mask_node_index:
+     #   - 5
+     #   - 30
+     #   - 31
MolecularDiffusion/configs/interference/gen_inpaint.yaml ADDED
@@ -0,0 +1,69 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: inpaint
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [0, 0]
+ max_mol_size: 0
+ target_values: []
+ property_names: []
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
+ condition_configs:
+   reference_structure_path: "data/template_structures/BINOLCpHHH.xyz"
+   condition_component: xh
+   inpaint_cfgs:
+     mask_node_index:
+       - 5
+       - 30
+       - 31
+       - 6
+       - 7
+       - 45
+       - 8
+       - 32
+       - 9
+       - 10
+       - 33
+       - 11
+       - 34
+       - 12
+       - 35
+       - 13
+       - 36
+       - 14
+       - 15
+       - 16
+       - 17
+       - 18
+       - 37
+       - 19
+       - 38
+       - 20
+       - 39
+       - 21
+       - 40
+       - 22
+       - 23
+       - 41
+       - 24
+       - 44
+       - 25
+       - 26
+       - 43
+       - 42
+     denoising_strength: 0.75
+     t_start: 0.8
+     t_critical_1: 0.8
+     t_critical_2: 1
+     d_threshold_f: 1.5
+     w_b: 10
+     all_frozen: True
+     use_covalent_radii: True
+     scale_factor: 1.2
+     noise_initial_mask: True
+     n_frames: 0
+   n_retrys: 0
+   t_retry: 180
MolecularDiffusion/configs/interference/gen_outpaint.yaml ADDED
@@ -0,0 +1,31 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: outpaint
+ sampling_mode: ddpm
+ num_generate: 100
+ mol_size: [0, 0]
+ max_mol_size: 0
+ target_values: []
+ property_names: []
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
+
+ condition_configs:
+   reference_structure_path: data/template_structures/BINOLCp.xyz
+   condition_component: xh
+
+   outpaint_cfgs:
+     t_start: 0.8
+     t_critical_1: 0.7
+     t_critical_2: 0.4
+     d_threshold_f: 2
+     w_b: 0.1
+     all_frozen: false
+     use_covalent_radii: true
+     scale_factor: 1.1
+     noise_initial_mask: false
+     connector_dicts: {} # fill if needed, e.g. {0: [3]}
+
+   n_retrys: 3
+   t_retry: 180
MolecularDiffusion/configs/interference/gen_outpaintft.yaml ADDED
@@ -0,0 +1,18 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: outpaintft
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [76, 76]
+ target_values: []
+ property_names: []
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
+ condition_configs:
+   reference_structure_path: "data/template_structures/INT2_0.xyz"
+   outpaint_cfgs:
+     t_start: 1
+   n_retrys: 0
+   t_retry: 180
MolecularDiffusion/configs/interference/gen_unconditional.yaml ADDED
@@ -0,0 +1,11 @@
+ _target_: MolecularDiffusion.runmodes.generate.GenerativeFactory
+ task_type: unconditional
+ sampling_mode: "ddpm"
+ num_generate: 100
+ mol_size: [16]
+ target_values: []
+ property_names: []
+ batch_size: 1
+ seed: 86
+ n_frames: 0
+ output_path: generated_mol
MolecularDiffusion/configs/interference/prediction.yaml ADDED
@@ -0,0 +1,2 @@
+ prop_names: ["S1_exc", "T1_exc"]
+ hit_criteria: null
MolecularDiffusion/configs/logger/default.yaml ADDED
@@ -0,0 +1,9 @@
+ _target_: MolecularDiffusion.runmodes.train.Logger
+ logger: logging # wandb, logging
+ log_interval: 2
+ name_wandb: MolecularDiffusion
+ project_wandb: MolecularDiffusion
+ dir_wandb: ${trainer.output_path}
MolecularDiffusion/configs/logger/wandb.yaml ADDED
@@ -0,0 +1,9 @@
+ _target_: MolecularDiffusion.runmodes.train.Logger
+ logger: wandb # wandb, logging
+ log_interval: 2
+ name_wandb: MolecularDiffusion
+ project_wandb: MolecularDiffusion
+ dir_wandb: ${trainer.output_path}
MolecularDiffusion/configs/models/tabasco_transformer.yaml ADDED
@@ -0,0 +1,72 @@
+ # TABASCO Transformer model configuration
+ # State-of-the-art non-equivariant flow matching model for molecules
+ _target_: MolecularDiffusion.modules.tasks.diffusion_tabasco.TabascoDiffusionTask
+
+ # Number of atom types from dataset vocabulary
+ num_atom_types: 19 # Will be overridden by ${data.num_atom_types} at runtime
+
+ # Transformer backbone configuration
+ transformer_config:
+   _target_: MolecularDiffusion.modules.layers.tabasco.transformer_module.TransformerModule
+   spatial_dim: 3
+   atom_dim: 19 # Will be overridden by ${data.num_atom_types}
+   hidden_dim: 256
+   num_layers: 16
+   num_heads: 8
+   activation: SiLU
+   implementation: pytorch # or 'reimplemented'
+   cross_attention: true
+   add_sinusoid_posenc: true
+   concat_combine_input: false
+   custom_weight_init: null # or 'xavier', 'kaiming', etc.
+
+ # Continuous coordinate interpolant configuration
+ coords_interpolant_config:
+   _target_: MolecularDiffusion.modules.models.tabasco.flow.interpolate.SDEMetricInterpolant
+   key: coords
+   loss_weight: 1.0
+   centered: true
+   scale_noise_by_log_num_atoms: false
+   noise_scale: 1.0
+   # Langevin sampling schedule for SDE integration
+   langevin_sampling_schedule:
+     _target_: MolecularDiffusion.modules.models.tabasco.sample.noise_schedule.SampleNoiseSchedule
+     cutoff: 0.9
+     white_noise_sampling_scale: 0.01
+   # Time-dependent loss weighting
+   time_factor:
+     _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
+     max_value: 100.0
+     min_value: 0.05
+     zero_before: 0.0
+     eps: 1.0e-6
+
+ # Discrete atom type interpolant configuration
+ atomics_interpolant_config:
+   _target_: MolecularDiffusion.modules.models.tabasco.flow.interpolate.DiscreteInterpolant
+   key: atomics
+   loss_weight: 0.1
+   # Time-dependent loss weighting
+   time_factor:
+     _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
+     max_value: 100.0
+     min_value: 0.05
+     zero_before: 0.0
+     eps: 1.0e-6
+
+ # Flow matching training configuration
+ flow_matching_config:
+   _target_: MolecularDiffusion.modules.models.tabasco.flow_model.FlowMatchingModel
+   time_distribution:
+     _target_: MolecularDiffusion.modules.models.tabasco.flow.utils.HistogramTimeDistribution
+   time_alpha_factor: 1.8
+   num_random_augmentations: 7 # +1 original = 8 total
+   sample_schedule: log # or 'linear', 'power'
+   compile: false
+   interdist_loss: null
+
+ # Dataset statistics (populated at runtime)
+ dataset_stats:
+   max_atoms: 29 # Will be set from data config
+   atom_count_histogram: {} # Computed from dataset
+   all_smiles: [] # Collected from dataset
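The "Will be overridden by `${data.num_atom_types}` at runtime" comments refer to OmegaConf interpolation: instead of a hard-coded `19`, a key can point at the data config and resolve lazily. A minimal sketch of the mechanism (key paths are illustrative):

```python
from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "data": {"num_atom_types": 19},  # populated from the dataset vocabulary
    "models": {"transformer_config": {"atom_dim": "${data.num_atom_types}"}},
})
assert cfg.models.transformer_config.atom_dim == 19  # resolved on access
cfg.data.num_atom_types = 24                          # changing the source ...
assert cfg.models.transformer_config.atom_dim == 24   # ... propagates automatically
```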
MolecularDiffusion/configs/tasks/diffusion.yaml ADDED
@@ -0,0 +1,48 @@
+ _target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+ task_type: diffusion
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+ hidden_size: 192
+ act_fn:
+   _target_: torch.nn.SiLU
+ num_layers: 9
+ attention: True
+ tanh: True
+ num_sublayers: 1
+ sin_embedding: False
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False
+ include_cosine: True
+ norm_constant: 1.0
+ normalization_factor: 1.0
+ chkpt_path: null
+
+ # specific to diffusion
+ diffusion_steps: 900
+ diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+ normalize_factors: [1,4,10]
+ extra_norm_values: []
+ augment_noise: False
+ data_augmentation: False
+ context_mask_rate: 0.2
+ mask_value: 5
+ normalize_condition: value_10 # options: null, "maxmin", "mad", or value_<n> (e.g. value_10)
+ sp_regularizer_deploy: False
+ sp_regularizer_regularizer: hard
+ sp_regularizer_lambda_: 0
+ sp_regularizer_lambda_2: 1000
+ sp_regularizer_lambda_update_value: 1
+ sp_regularizer_lambda_update_step: 100
+ sp_regularizer_polynomial_p: 1.1
+ sp_regularizer_warm_up_steps: 100
+ use_unknown_fallback: False
+ reference_indices: null # indices of core atoms for the outpainting objective
+ # evaluator parameters
+ use_posebuster: True
+ metrics: valid_posebuster # use_posebuster must be true
+ n_samples: 48
+ batch_size: 4
+ generative_analysis: True
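A rough sketch of what a `polynomial_2` schedule with `diffusion_noise_precision: 1e-5` typically computes, following the convention of E(3)-equivariant diffusion models (the repo's exact implementation may differ, e.g. in clipping details): the squared signal coefficient decays polynomially from ~1 to ~0 over `diffusion_steps`, and the precision keeps it strictly inside (0, 1).

```python
import numpy as np

def polynomial_alpha2(timesteps: int = 900, precision: float = 1e-5, power: float = 2.0):
    """Squared signal coefficient alpha_bar_t^2 for t = 0..timesteps (sketch)."""
    t = np.linspace(0.0, 1.0, timesteps + 1)
    alphas2 = (1.0 - t**power) ** 2                          # 1 at t=0, 0 at t=T
    alphas2 = precision + (1.0 - 2.0 * precision) * alphas2  # clamp into (0, 1)
    sigmas2 = 1.0 - alphas2                                  # variance-preserving noise
    return alphas2, sigmas2

a2, s2 = polynomial_alpha2()  # matches diffusion_steps: 900 above
```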
MolecularDiffusion/configs/tasks/diffusion_egt.yaml ADDED
@@ -0,0 +1,54 @@
+ _target_: MolecularDiffusion.runmodes.train.tasks_egt.ModelTaskFactory
+ task_type: diffusion
+ model_class: GraphTransformer
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+ hidden_dims:
+   dx: 256
+   de: 64
+   dy: 4
+   n_head: 4
+   dim_ffX: 256
+   dim_ffE: 64
+   dim_ffy: 1
+ hidden_mlp_dims:
+   X: 256
+   E: 64
+   y: 256
+   pos: 512
+ act_fn_in:
+   _target_: torch.nn.SiLU
+ act_fn_out:
+   _target_: torch.nn.SiLU
+ num_layers: 6
+ dropout: 0.1
+ chkpt_path: null
+
+
+ # specific to diffusion
+ diffusion_steps: 400
+ diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+ normalize_factors: [1,4,10]
+ extra_norm_values: []
+ augment_noise: False
+ data_augmentation: False
+ context_mask_rate: 0.2
+ mask_value: 5
+ normalize_condition: value_10 # options: null, "maxmin", "mad", or value_<n> (e.g. value_10)
+ sp_regularizer_deploy: False
+ sp_regularizer_regularizer: hard
+ sp_regularizer_lambda_: 0
+ sp_regularizer_lambda_2: 1000
+ sp_regularizer_lambda_update_value: 1
+ sp_regularizer_lambda_update_step: 100
+ sp_regularizer_polynomial_p: 1.1
+ sp_regularizer_warm_up_steps: 100
+ reference_indices: null # indices of core atoms for the outpainting objective
+ # evaluator parameters
+ use_posebuster: True
+ metrics: valid_posebuster # use_posebuster must be true
+ n_samples: 24
+ generative_analysis: True
+ batch_size: 4
MolecularDiffusion/configs/tasks/diffusion_extraf.yaml ADDED
@@ -0,0 +1,47 @@
+ _target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+ task_type: diffusion
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+ hidden_size: 192
+ act_fn:
+   _target_: torch.nn.SiLU
+ num_layers: 1
+ attention: True
+ tanh: True
+ num_sublayers: 12
+ sin_embedding: False
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False
+ include_cosine: True
+ norm_constant: 1.0
+ normalization_factor: 1.0
+ chkpt_path: null
+
+ # specific to diffusion
+ diffusion_steps: 400
+ diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+ normalize_factors: [1,4,10]
+ extra_norm_values: [10,10]
+ augment_noise: False
+ data_augmentation: False
+ context_mask_rate: 0.2
+ mask_value: 5
+ normalize_condition: value_10 # options: null, "maxmin", "mad", or value_<n> (e.g. value_10)
+ sp_regularizer_deploy: False
+ sp_regularizer_regularizer: hard
+ sp_regularizer_lambda_: 0
+ sp_regularizer_lambda_2: 1000
+ sp_regularizer_lambda_update_value: 1
+ sp_regularizer_lambda_update_step: 100
+ sp_regularizer_polynomial_p: 1.1
+ sp_regularizer_warm_up_steps: 100
+ reference_indices: null # indices of core atoms for the outpainting objective
+ # evaluator parameters
+ use_posebuster: True
+ metrics: valid_posebuster # use_posebuster must be true
+ n_samples: 24
+ generative_analysis: True
+ batch_size: 4
MolecularDiffusion/configs/tasks/diffusion_hybrid.yaml ADDED
@@ -0,0 +1,95 @@
+ _target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
+ task_type: diffusion_hybrid
+
+ # === Atom Vocabulary ===
+ # Specify either atom_vocab directly OR use the one from data config
+ # Available base vocabularies: H, C, N, O, F, P, S, Cl, Br, I (common organic)
+ # The number of classes is automatically determined from vocab length
+ atom_vocab: ${data.atom_vocab}
+ # atom_vocab: ["C", "N", "O", "H", "F", "S", "Cl", "Br", "P", "I"] # Example custom
+
+ condition_names: []
+
+ # eSEN specific parameters
+ hidden_size: 64
+ hidden_channels: 64
+ num_layers: 9
+ lmax: 2
+ mmax: 2
+ grid_resolution: null
+ cutoff: 30
+ edge_channels: 128
+ distance_function: "gaussian"
+ num_distance_basis: 512
+ norm_type: "rms_norm_sh"
+ act_type: "s2"
+ mlp_type: "grid"
+ otf_graph: True
+ use_envelope: False
+ activation_checkpointing: False
+ global_attributes: False
+ sphere_embedding_type: "mixed" # DO NOT CHANGE
+ aggregation_method: "sum"
+
+ chkpt_path: null
+
+ # === Continuous Diffusion Parameters ===
+ diffusion_steps: 450
+ diffusion_noise_schedule: polynomial_2 # Options: cosine, polynomial_2, polynomial_3, learned
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: l2 # Options: vlb, l2
+ normalize_factors: [1, 1]
+ extra_norm_values: []
+ augment_noise: False
+ data_augmentation: False
+ context_mask_rate: 0.2
+ mask_value: 5
+ normalize_condition: value_10
+ sp_regularizer_deploy: False
+ sp_regularizer_regularizer: hard
+ sp_regularizer_lambda_: 0
+ sp_regularizer_lambda_2: 1000
+ sp_regularizer_lambda_update_value: 1
+ sp_regularizer_lambda_update_step: 100
+ sp_regularizer_polynomial_p: 1.1
+ sp_regularizer_warm_up_steps: 100
+ reference_indices: null
+
+ # === Discrete Diffusion Parameters (Atom Types) ===
+ # Number of atom classes (automatically set from atom_vocab length if not specified)
+ num_atom_classes: 19
+
+ # Weight for discrete loss in combined loss: L_total = L_continuous + λ * L_discrete
+ discrete_loss_weight: 0.2
+
+ # Discrete masking schedule for absorbing-state diffusion
+ # Each schedule controls how quickly tokens get masked during forward diffusion
+ #
+ # Available schedules:
+ #   - "cosine"      : Smooth cosine decay (default, from improved DDPM)
+ #   - "linear"      : Linear increase in masking probability
+ #   - "sqrt"        : Square root schedule (faster initial masking)
+ #   - "quadratic"   : Quadratic schedule (slower initial, faster later)
+ #   - "cubic"       : Cubic schedule (even slower start than quadratic)
+ #   - "sigmoid"     : S-curve transition (smooth start and end)
+ #   - "exponential" : Exponential decay of survival probability
+ #   - "log"         : Logarithmic schedule (fast early, slow late)
+ #   - "uniform"     : Constant masking rate each step
+ #
+ discrete_schedule: "cosine"
+
+ # MLP layers for atom classification head
+ atom_head_mlp_layers: 2
+
+ # === eSEN Dynamics specific ===
+ use_adapter_module: False
+ tanh: True
+ coords_range: 10
+ normalization_factor: 1.0
+
+ # === Evaluator Parameters ===
+ use_posebuster: True
+ metrics: valid_posebuster
+ n_samples: 96
+ batch_size: 8
+ generative_analysis: True
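The schedule names listed in the comment block of diffusion_hybrid.yaml admit a compact reading as "probability that an atom-type token has been masked by normalized time t in [0, 1]". A hedged sketch of the most common parameterizations (the repo's exact formulas may differ):

```python
import math

def mask_prob(t: float, schedule: str = "cosine") -> float:
    """P(token is masked by normalized time t) for absorbing-state diffusion (sketch)."""
    if schedule == "cosine":      # smooth: survival probability decays like cos
        return 1.0 - math.cos(0.5 * math.pi * t)
    if schedule == "linear":
        return t
    if schedule == "sqrt":        # faster initial masking
        return math.sqrt(t)
    if schedule == "quadratic":   # slower initial, faster later
        return t * t
    if schedule == "cubic":
        return t ** 3
    raise ValueError(f"unknown schedule: {schedule!r}")

# e.g. halfway through forward diffusion:
print({s: round(mask_prob(0.5, s), 3) for s in ("cosine", "linear", "sqrt", "quadratic")})
```

The discrete loss computed under this schedule is then folded into the combined objective with the configured weight, i.e. L_total = L_continuous + 0.2 * L_discrete for `discrete_loss_weight: 0.2`.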
MolecularDiffusion/configs/tasks/diffusion_hybrid_egcl.yaml ADDED
@@ -0,0 +1,53 @@
+ _target_: MolecularDiffusion.runmodes.train.tasks_egcl.ModelTaskFactory
+ task_type: diffusion_hybrid
+
+ # === Atom Vocabulary ===
+ atom_vocab: ${data.atom_vocab}
+
+ condition_names: []
+
+ # === EGNN Parameters ===
+ hidden_size: 192
+ num_layers: 9
+ attention: True
+ norm_diff: True
+ tanh: True
+ coords_range: 15
+ num_sublayers: 1
+ sin_embedding: True
+ include_cosine: False
+ normalization_factor: 1.0
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False
+
+ chkpt_path: null
+
+ # === Continuous Diffusion Parameters ===
+ diffusion_steps: 900
+ diffusion_noise_schedule: polynomial_2
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+ normalize_factors: [1, 4]
+ extra_norm_values: []
+ augment_noise: False
+ data_augmentation: False
+ context_mask_rate: 0.0
+ mask_value: 5.0
+ normalize_condition: value_10
+ sp_regularizer_deploy: False
+
+ # === Discrete Diffusion Parameters (Atom Types) ===
+ num_atom_classes: 19
+ discrete_loss_weight: 0.2
+ discrete_schedule: "cosine"
+
+ # MLP layers for atom classification head
+ atom_head_mlp_layers: 2
+
+ # === Evaluator Parameters ===
+ use_posebuster: True
+ metrics: valid_posebuster
+ n_samples: 48
+ batch_size: 8
+ generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_integer.yaml ADDED
@@ -0,0 +1,62 @@
+ _target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
+ task_type: diffusion
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+
+ # eSEN specific parameters
+ hidden_size: 256
+ hidden_channels: 256
+ num_layers: 4
+ lmax: 2
+ mmax: 2
+ grid_resolution: null
+ cutoff: 5.0
+ edge_channels: 128
+ distance_function: "gaussian"
+ num_distance_basis: 512
+ norm_type: "rms_norm_sh"
+ act_type: "s2"
+ mlp_type: "grid"
+ otf_graph: True #!!
+ use_envelope: False
+ activation_checkpointing: False
+ global_attributes: False
+ sphere_embedding_type: "gaussian" #!!
+ aggregation_method: "sum"
+
+ chkpt_path: null
+
+ # Diffusion kwargs
+ diffusion_steps: 450
+ diffusion_noise_schedule: polynomial_2
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+ normalize_factors: [1, 1]
+ extra_norm_values: []
+ augment_noise: False
+ data_augmentation: False
+ context_mask_rate: 0.2
+ mask_value: 5
+ normalize_condition: value_10
+ sp_regularizer_deploy: False
+ sp_regularizer_regularizer: hard
+ sp_regularizer_lambda_: 0
+ sp_regularizer_lambda_2: 1000
+ sp_regularizer_lambda_update_value: 1
+ sp_regularizer_lambda_update_step: 100
+ sp_regularizer_polynomial_p: 1.1
+ sp_regularizer_warm_up_steps: 100
+ reference_indices: null
+
+ # eSEN_dynamics specific kwargs
+ use_adapter_module: False
+ tanh: True
+ coords_range: 10
+ normalization_factor: 1.0
+
+ # Evaluator parameters
+ use_posebuster: True
+ metrics: valid_posebuster
+ n_samples: 96
+ batch_size: 8
+ generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_pretrained.yaml ADDED
@@ -0,0 +1,47 @@
+ _target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+ task_type: diffusion
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+ hidden_size: 192
+ act_fn:
+   _target_: torch.nn.SiLU
+ num_layers: 9
+ attention: True
+ tanh: True
+ num_sublayers: 1
+ sin_embedding: False
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False
+ include_cosine: True
+ norm_constant: 1.0
+ normalization_factor: 1.0
+ chkpt_path: null
+
+ # specific to diffusion
+ diffusion_steps: 900
+ diffusion_noise_schedule: polynomial_2 # learned, cosine_x, polynomial_x, issnr_x, smld_x
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+ normalize_factors: [1,4,10]
+ extra_norm_values: []
+ augment_noise: False
+ data_augmentation: False
+ context_mask_rate: 0.2
+ mask_value: 5
+ normalize_condition: value_10 # options: null, "maxmin", "mad", or value_<n> (e.g. value_10)
+ sp_regularizer_deploy: False
+ sp_regularizer_regularizer: hard
+ sp_regularizer_lambda_: 0
+ sp_regularizer_lambda_2: 1000
+ sp_regularizer_lambda_update_value: 1
+ sp_regularizer_lambda_update_step: 100
+ sp_regularizer_polynomial_p: 1.1
+ sp_regularizer_warm_up_steps: 100
+ reference_indices: null # indices of core atoms for the outpainting objective
+ # evaluator parameters
+ use_posebuster: True
+ metrics: valid_posebuster # use_posebuster must be true
+ n_samples: 24
+ generative_analysis: True
+ batch_size: 4
MolecularDiffusion/configs/tasks/diffusion_pyg.yaml ADDED
@@ -0,0 +1,82 @@
+ _target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
+ task_type: diffusion_pyg
+
+ # === Atom Vocabulary ===
+ atom_vocab: ${data.atom_vocab}
+
+ condition_names: []
+
+ # === eSEN Model Parameters ===
+ hidden_size: 256
+ hidden_channels: 32
+ num_layers: 9
+ lmax: 2
+ mmax: 2
+ grid_resolution: null
+ cutoff: 15
+ edge_channels: 128
+ distance_function: "gaussian"
+ num_distance_basis: 10
+ norm_type: "rms_norm_sh"
+ act_type: "s2"
+ mlp_type: "grid"
+ otf_graph: True
+ use_envelope: False
+ activation_checkpointing: False
+ global_attributes: False
+
+ # IMPORTANT: Use "gaussian" for float features during diffusion!
+ # "gaussian" uses Gaussian smearing + MLP, fully float-compatible
+ # Other options ("embedding", "mixed") require integer atomic_numbers
+ sphere_embedding_type: "gaussian"
+
+ aggregation_method: "sum"
+
+ chkpt_path: null
+
+ # === Continuous Diffusion Parameters ===
+ # All features (positions, one-hot, integer) use continuous Gaussian diffusion
+ diffusion_steps: 900
+ diffusion_noise_schedule: polynomial_2 # Options: cosine, polynomial_2, polynomial_3, learned
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb # Options: vlb, l2
+
+ # Normalization factors: [positions, categorical (one-hot), integer (atomic_numbers)]
+ normalize_factors: [1.0, 4.0, 10.0]
+ extra_norm_values: []
+
+ # Data augmentation
+ augment_noise: False
+ data_augmentation: False
+
+ # Context masking for classifier-free guidance
+ context_mask_rate: 0.0
+ mask_value: 0.0
+ normalize_condition: null
+
+ # Self-paced learning regularizer
+ sp_regularizer_deploy: False
+ sp_regularizer_regularizer: hard
+ sp_regularizer_lambda_: 0
+ sp_regularizer_lambda_2: 1000
+ sp_regularizer_lambda_update_value: 1
+ sp_regularizer_lambda_update_step: 100
+ sp_regularizer_polynomial_p: 1.1
+ sp_regularizer_warm_up_steps: 100
+
+ # Outpainting/inpainting
+ reference_indices: null
+ use_unknown_fallback: False # Set to True when data.allow_unknown is True
+
+ # === eSEN Dynamics Specific ===
+ use_adapter_module: False
+ tanh: True
+ coords_range: 10
+ normalization_factor: 1.0
+
+ # === Evaluation Parameters ===
+ use_posebuster: True
+ metrics: valid_posebuster
+ n_samples: 48
+ batch_size: 8
+ generative_analysis: True
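The `sphere_embedding_type: "gaussian"` comment in diffusion_pyg.yaml is the crux of that config: during diffusion, atom features are noisy floats, so an integer embedding lookup would break. Gaussian smearing sidesteps this by expanding a scalar onto a grid of radial basis functions, which is differentiable and accepts any real value. A minimal SchNet-style sketch (grid bounds and basis count are illustrative):

```python
import torch

def gaussian_smearing(x: torch.Tensor, start: float = 0.0, stop: float = 20.0,
                      num_basis: int = 16) -> torch.Tensor:
    """Expand scalar features onto Gaussian basis functions (sketch)."""
    centers = torch.linspace(start, stop, num_basis)
    width = (stop - start) / (num_basis - 1)
    return torch.exp(-((x.unsqueeze(-1) - centers) ** 2) / (2.0 * width ** 2))

z_noisy = torch.tensor([5.7, 7.2, 8.1])  # diffused (non-integer) atomic numbers
feats = gaussian_smearing(z_noisy)       # shape (3, 16); an MLP then maps this to node features
```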
MolecularDiffusion/configs/tasks/diffusion_pyg_egcl.yaml ADDED
@@ -0,0 +1,55 @@
+ _target_: MolecularDiffusion.runmodes.train.tasks_egcl.ModelTaskFactory
+ task_type: diffusion_pyg
+
+ # === Atom Vocabulary ===
+ atom_vocab: ${data.atom_vocab}
+
+ condition_names: []
+
+ # === EGNN Parameters ===
+ hidden_size: 256
+ num_layers: 9
+ attention: True
+ norm_diff: True
+ tanh: True
+ coords_range: 10
+ num_sublayers: 1
+ sin_embedding: False
+ include_cosine: True
+ normalization_factor: 1.0
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False
+
+ chkpt_path: null
+
+ # === Continuous Diffusion Parameters ===
+ # All features use continuous Gaussian diffusion (same as EnVariationalDiffusion)
+ diffusion_steps: 900
+ diffusion_noise_schedule: polynomial_2
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+
+ # Normalization factors: [positions, categorical (one-hot), integer (atomic_numbers)]
+ normalize_factors: [1.0, 4.0, 10.0]
+ extra_norm_values: []
+
+ # Data augmentation
+ augment_noise: False
+ data_augmentation: False
+
+ # Context masking for classifier-free guidance
+ context_mask_rate: 0.0
+ mask_value: 0.0
+ normalize_condition: null
+
+ # Self-paced learning regularizer
+ sp_regularizer_deploy: False
+ use_unknown_fallback: False # Set to True when data.allow_unknown is True
+
+ # === Evaluation Parameters ===
+ use_posebuster: True
+ metrics: valid_posebuster
+ n_samples: 48
+ batch_size: 8
+ generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_pyg_egt.yaml ADDED
@@ -0,0 +1,56 @@
+ _target_: MolecularDiffusion.runmodes.train.tasks_egt.ModelTaskFactory
+ task_type: diffusion_pyg
+
+ # === Atom Vocabulary ===
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+
+ # === Graph Transformer Parameters ===
+ model_class: GraphTransformerPyG
+ hidden_dims:
+   dx: 256
+   de: 1
+   dy: 32
+   n_head: 4
+   dim_ffX: 256
+   dim_ffE: 1
+   dim_ffy: 32
+ hidden_mlp_dims:
+   X: 256
+   E: 1
+   y: 32
+   pos: 512
+ act_fn_in:
+   _target_: torch.nn.SiLU
+ act_fn_out:
+   _target_: torch.nn.SiLU
+ num_layers: 6
+ dropout: 0.1
+ chkpt_path: null
+
+ # === Diffusion Parameters ===
+ diffusion_steps: 900
+ diffusion_noise_schedule: polynomial_2
+ diffusion_noise_precision: 1e-5
+ diffusion_loss_type: vlb
+ normalize_factors: [1.0, 4.0, 10.0]
+ extra_norm_values: []
+
+ # Data augmentation
+ augment_noise: False
+ data_augmentation: False
+
+ # Context masking for CFG
+ context_mask_rate: 0.0
+ mask_value: 0.0
+ normalize_condition: null
+
+ # Self-paced regularizer
+ sp_regularizer_deploy: False
+
+ # === Evaluation ===
+ use_posebuster: True
+ metrics: valid_posebuster
+ n_samples: 48
+ batch_size: 8
+ generative_analysis: True
MolecularDiffusion/configs/tasks/diffusion_tabasco.yaml ADDED
@@ -0,0 +1,66 @@
+ # TABASCO diffusion task configuration
+ # This config is referenced by: defaults: - override tasks: diffusion_tabasco
+
+ _target_: MolecularDiffusion.modules.tasks.diffusion_tabasco.ModelTaskFactory
+ task_type: diffusion_tabasco
+
+ # Automatically populated from dataset
+ num_atom_types: ???
+
+ # Transformer backbone configuration
+ transformer_config:
+   spatial_dim: 3
+   atom_dim: ???
+   hidden_dim: 256
+   num_layers: 16
+   num_heads: 8
+   activation: SiLU
+   implementation: pytorch
+   cross_attention: true
+   add_sinusoid_posenc: true
+   concat_combine_input: false
+   custom_weight_init: null
+
+ # Continuous coordinate interpolant
+ coords_interpolant_config:
+   key: coords
+   loss_weight: 1.0
+   centered: true
+   scale_noise_by_log_num_atoms: false
+   noise_scale: 1.0
+   langevin_sampling_schedule:
+     _target_: MolecularDiffusion.modules.models.tabasco.sample.noise_schedule.SampleNoiseSchedule
+     cutoff: 0.9
+     white_noise_sampling_scale: 0.01
+   time_factor:
+     _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
+     max_value: 100.0
+     min_value: 0.05
+     zero_before: 0.0
+     eps: 1.0e-6
+
+ # Discrete atom type interpolant
+ atomics_interpolant_config:
+   key: atomics
+   loss_weight: 0.1
+   time_factor:
+     _target_: MolecularDiffusion.modules.models.tabasco.flow.time_factor.InverseTimeFactor
+     max_value: 100.0
+     min_value: 0.05
+     zero_before: 0.0
+     eps: 1.0e-6
+
+ # Flow matching configuration
+ flow_matching_config:
+   time_distribution: beta
+   time_alpha_factor: 1.8
+   num_random_augmentations: 7
+   sample_schedule: log
+   compile: false
+   interdist_loss: null
+
+ # Dataset statistics (populated at runtime)
+ dataset_stats:
+   max_atoms: ???
+   atom_count_histogram: {}
+   all_smiles: []
MolecularDiffusion/configs/tasks/guidance.yaml ADDED
@@ -0,0 +1,40 @@
+ _target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+ task_type: guidance
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+ hidden_size: 512
+ act_fn:
+   _target_: torch.nn.ReLU
+ num_layers: 1
+ attention: True
+ tanh: True
+ num_sublayers: 5
+ sin_embedding: False
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False
+ include_cosine: True
+ norm_constant: 1.0
+ normalization_factor: 1.0
+ chkpt_path: null
+
+ # specific to guidance
+ task_learn: [S1_exc,T1_exc]
+ criterion: mse
+ metric: [mae]
+ num_mlp_layer: 3
+ mlp_dropout: 0.2
+ mlp_batch_norm: True # True/False for legacy mode, null/'layernorm'/'batchnorm' for new mode
+ prediction_mlp_type: legacy # 'legacy' (backward compat), 'pernode', or 'padded'
+ prediction_activation: relu # 'relu' or 'silu'
+ diffusion_steps: 900
+ diffusion_noise_precision: 1e-5
+ nu_arr: [2,2,2]
+ mapping: ["pos", "categorical", "integer"]
+ weight_classes: null
+ norm_values: [1,4,10]
+ t_max: 0.7
+ loss_weighting: linear
+
+
+
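A guidance model of this kind is a property regressor trained on *noisy* diffusion intermediates so its gradients can steer sampling. One plausible reading of `t_max: 0.7` and `loss_weighting: linear` is sketched below: train only on steps up to 0.7·T and down-weight noisier steps linearly. This is an assumption about the semantics, not a transcription of the repo's code:

```python
import torch

def sample_steps_and_weights(batch_size: int, diffusion_steps: int = 900,
                             t_max: float = 0.7):
    """Sample timesteps t <= t_max * T with linearly decaying loss weights (sketch)."""
    t_cut = int(t_max * diffusion_steps)
    t = torch.randint(0, t_cut + 1, (batch_size,))
    weights = 1.0 - t.float() / t_cut   # 1 at t=0 (clean data), 0 at the cutoff
    return t, weights

t, w = sample_steps_and_weights(8)
# a per-sample regression loss would then be weighted as:
# (w * (pred - target).pow(2).mean(dim=-1)).mean()
```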
MolecularDiffusion/configs/tasks/guidance_esen.yaml ADDED
@@ -0,0 +1,43 @@
+ # Uses the existing ModelTaskFactory from tasks_esen.py with task_type: guidance
+ _target_: MolecularDiffusion.runmodes.train.tasks_esen.ModelTaskFactory
+ task_type: guidance
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+
+ # eSEN Backbone parameters
+ sphere_channels: 128
+ hidden_channels: 128
+ lmax: 2
+ mmax: 2
+ num_layers: 4
+ edge_channels: 128
+ distance_function: "gaussian"
+ num_distance_basis: 512
+ cutoff: 5.0
+ max_neighbors: 300
+ norm_type: "rms_norm_sh"
+ act_type: "s2"
+ mlp_type: "grid"
+
+ # CRITICAL: Use "mlp" or "gaussian" for differentiable gradients
+ sphere_embedding_type: "mlp"
+ # in_node_channels is computed by factory: len(atom_vocab) + n_extra + 1 (charge) + 1 (time)
+
+ # Guidance-specific parameters
+ task_learn: [S1_exc, T1_exc]
+ criterion: mse
+ metric: [mae]
+ num_mlp_layer: 3
+ mlp_dropout: 0.2
+ mlp_batch_norm: True # True/False for legacy mode, null/'layernorm'/'batchnorm' for new mode
+ prediction_mlp_type: legacy # 'legacy' (backward compat), 'pernode', or 'padded'
+ prediction_activation: relu # 'relu' or 'silu'
+ diffusion_steps: 600
+ diffusion_noise_precision: 1e-5
+ nu_arr: [2, 2, 2]
+ mapping: ["pos", "categorical", "integer"]
+ weight_classes: null
+ norm_values: [1, 4, 10]
+ t_max: 0.7
+ loss_weighting: linear
+ normalization: False
MolecularDiffusion/configs/tasks/guidance_pc.yaml ADDED
@@ -0,0 +1,43 @@
+ # Configuration for PointCloud-optimized EGCL Guidance Model
+ # Uses GuidanceModelPredictionPointCloud with dense_mode=True
+
+ _target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+ task_type: guidance
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+ hidden_size: 512
+ act_fn:
+   _target_: torch.nn.ReLU
+ num_layers: 1
+ attention: True
+ tanh: True
+ num_sublayers: 5
+ sin_embedding: False
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False
+ include_cosine: True
+ norm_constant: 1.0
+ normalization_factor: 1.0
+ chkpt_path: null
+
+ # Enable dense mode for PointCloud inference
+ dense_mode: True
+
+ # Guidance-specific parameters
+ task_learn: [S1_exc, T1_exc]
+ criterion: mse
+ metric: [mae]
+ num_mlp_layer: 3
+ mlp_dropout: 0.2
+ mlp_batch_norm: True # True/False for legacy mode, null/'layernorm'/'batchnorm' for new mode
+ prediction_mlp_type: legacy # 'legacy' (backward compat), 'pernode', or 'padded'
+ prediction_activation: relu # 'relu' or 'silu'
+ diffusion_steps: 900
+ diffusion_noise_precision: 1e-5
+ nu_arr: [2, 2, 2]
+ mapping: ["pos", "categorical", "integer"]
+ weight_classes: null
+ norm_values: [1, 4, 10]
+ t_max: 0.7
+ loss_weighting: linear
MolecularDiffusion/configs/tasks/ldm_dit.yaml ADDED
@@ -0,0 +1,24 @@
+ # Latent Diffusion with DiT denoiser
+ _target_: MolecularDiffusion.modules.tasks.diffusion_ldm.LDMTaskFactory
+ task_type: ldm_dit
+ _recursive_: False
+ autoencoder_ckpt: ??? # Required: path to pre-trained VAE
+
+ denoiser:
+   _target_: MolecularDiffusion.modules.models.ldm.denoisers.dit.DiT
+   # d_x is auto-inferred from VAE latent_dim
+   d_model: 384
+   num_layers: 12
+   nhead: 6
+   class_dropout_prob: 0.1
+
+ interpolant:
+   type: flow_matching
+   min_t: 0.01
+   corrupt: true
+   num_timesteps: 100
+   self_condition: false
+   self_condition_prob: 0.5
+
+ # Data augmentation
+ augment_rotation: true
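The `interpolant` block configures flow matching in latent space. A minimal sketch of what `type: flow_matching` with `min_t: 0.01` and `corrupt: true` typically means: mix the clean latent with Gaussian noise at a random time t >= min_t and regress the constant velocity. The details (direction of t, conditioning, self-conditioning) are assumptions here:

```python
import torch

def flow_matching_corrupt(x1: torch.Tensor, min_t: float = 0.01):
    """Linear interpolant x_t = (1 - t) * noise + t * x1 with velocity target x1 - noise."""
    t = min_t + (1.0 - min_t) * torch.rand(x1.shape[0], 1)  # avoid the degenerate t = 0
    x0 = torch.randn_like(x1)
    xt = (1.0 - t) * x0 + t * x1
    v_target = x1 - x0        # the denoiser (here, the DiT) is trained to predict this
    return xt, t, v_target

latents = torch.randn(4, 32)  # batch of VAE latents (dimension illustrative)
xt, t, v = flow_matching_corrupt(latents)
```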
MolecularDiffusion/configs/tasks/regression.yaml ADDED
@@ -0,0 +1,30 @@
+ _target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_EGCL
+ task_type: regression
+ atom_vocab: ${data.atom_vocab}
+ condition_names: []
+ hidden_size: 512
+ act_fn:
+   _target_: torch.nn.ReLU
+ num_layers: 1
+ attention: True
+ tanh: True
+ num_sublayers: 5
+ sin_embedding: False
+ aggregation_method: "sum"
+ dropout: 0.0
+ normalization: False # For EGNN backbone layer norm
+ include_cosine: True
+ norm_constant: 1.0
+ normalization_factor: 1.0
+ chkpt_path: null
+
+ # specific to regression
+ task_learn: [S1_exc,T1_exc]
+ criterion: mse
+ metric: [mae]
+ num_mlp_layer: 3
+ mlp_batch_norm: batchnorm # Options: null, layernorm, batchnorm
+ target_normalization: True # Normalize targets by mean/std in loss
+ mlp_dropout: 0.2
+ prediction_mlp_type: "pernode"
+ prediction_activation: "relu"
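`target_normalization: True` means the regression loss is computed on standardized targets; a sketch of the usual pattern (statistics from the training split), with dummy tensors standing in for real `[S1_exc, T1_exc]` labels:

```python
import torch
import torch.nn.functional as F

targets = torch.randn(32, 2) * 0.5 + 3.0  # dummy [S1_exc, T1_exc] labels
pred = torch.randn(32, 2)                 # raw network output (normalized scale)

mean, std = targets.mean(dim=0), targets.std(dim=0)  # training-set statistics
loss = F.mse_loss(pred, (targets - mean) / std)      # optimize in normalized space
physical = pred * std + mean                         # un-normalize for reporting (e.g. eV)
```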
MolecularDiffusion/configs/tasks/regression_esen.yaml ADDED
@@ -0,0 +1,34 @@
+ _target_: MolecularDiffusion.runmodes.train.ModelTaskFactory_ESEN
+ task_type: regression
+ atom_vocab: ${data.atom_vocab}
+ # condition_names: []
+ hidden_size: 256
+ hidden_channels: 256
+ num_layers: 4
+ lmax: 2
+ mmax: 2
+ grid_resolution: null
+ cutoff: 5.0
+ edge_channels: 128
+ distance_function: "gaussian"
+ num_distance_basis: 512
+ norm_type: "rms_norm_sh"
+ act_type: "s2"
+ mlp_type: "grid"
+ use_envelope: False
+ activation_checkpointing: False
+ global_attributes: False
+ sphere_embedding_type: "mixed"
+ aggregation_method: mean
+ chkpt_path: null
+
+ # specific to regression
+ task_learn: [S1_exc,T1_exc]
+ criterion: mse
+ metric: [mae]
+ num_mlp_layer: 3
+ mlp_dropout: 0.2
+ mlp_batch_norm: batchnorm # Options: null, layernorm, batchnorm
+ target_normalization: True # Normalize targets by mean/std in loss
+ prediction_mlp_type: "pernode"
+ prediction_activation: "relu"