manpreet88 committed on
Commit 4700e18 · 1 Parent(s): 70ec142

Create Polymer_Generation.py

Files changed (1)
  1. Downstream Tasks/Polymer_Generation.py +2105 -0
Downstream Tasks/Polymer_Generation.py ADDED
@@ -0,0 +1,2105 @@
+ # G2.py: PolyBART-style inverse design (single-task, fivefold per property, G1-style I/O)
+
+ import os
+ import re
+ import sys
+ import csv
+ import json
+ import math
+ import time
+ import copy
+ import random
+ import shutil
+ import warnings
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import List, Dict, Optional, Tuple, Any
+
+ import numpy as np
+ import pandas as pd
+
+ from sklearn.model_selection import train_test_split, KFold
+ from sklearn.preprocessing import StandardScaler
+ from sklearn.decomposition import PCA
+ from sklearn.gaussian_process import GaussianProcessRegressor
+ from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, WhiteKernel
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset, DataLoader
+
+ # Increase csv field size limit safely
+ try:
+     csv.field_size_limit(sys.maxsize)
+ except OverflowError:
+     csv.field_size_limit(2**31 - 1)
+
+ # HF Transformers
+ from transformers import DebertaV2ForMaskedLM, DebertaV2Tokenizer
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ from transformers.modeling_outputs import BaseModelOutput
+
+ # Optional: RDKit + selfies (required for this pipeline)
+ RDKit_AVAILABLE = False
+ SELFIES_AVAILABLE = False
+ try:
+     from rdkit import Chem
+     from rdkit.Chem import AllChem, DataStructs
+     RDKit_AVAILABLE = True
+ except Exception:
+     RDKit_AVAILABLE = False
+
+ try:
+     import selfies as sf
+     SELFIES_AVAILABLE = True
+ except Exception:
+     SELFIES_AVAILABLE = False
+
+ # PyG (optional)
+ try:
+     from torch_geometric.nn import GINEConv
+     from torch_geometric.nn.models import SchNet as PyGSchNet
+     from torch_geometric.nn import radius_graph
+ except Exception:
+     GINEConv = None
+     PyGSchNet = None
+     radius_graph = None
+
+ # =============================================================================
+ # Configuration
+ # =============================================================================
+
+ BASE_DIR = "Polymer_Foundational_Model"
+ POLYINFO_PATH = os.path.join(BASE_DIR, "polyinfo_with_modalities.csv")
+
+ # Pretrained model directories (your paths)
+ PRETRAINED_MULTIMODAL_DIR = "multimodal_output_5M/best"
+ BEST_GINE_DIR = "gin_output_5M/best"
+ BEST_SCHNET_DIR = "schnet_output_5M/best"
+ BEST_FP_DIR = "fingerprint_mlm_output_5M/best"
+ BEST_PSMILES_DIR = "polybert_output_5M/best"
+
+ # Output
+ OUTPUT_DIR = "multimodal_inverse_design_output_5M_polybart_style"
+ OUTPUT_RESULTS = os.path.join(OUTPUT_DIR, "inverse_design_results.txt")
+ OUTPUT_MODELS_DIR = os.path.join(OUTPUT_DIR, "best_models")
+ OUTPUT_GENERATIONS_DIR = os.path.join(OUTPUT_DIR, "best_fold_generations")
+
+ # Properties (order preserved)
+ REQUESTED_PROPERTIES = [
+     "density",
+     "glass transition",
+     "melting",
+     "specific volume",
+     "thermal decomposition"
+ ]
+
+ # -------------------------------------------------------------------------
+ # Model sizes / dims (match your CL encoder)
+ # -------------------------------------------------------------------------
+
+ CL_EMB_DIM = 600
+
+ # Model hyperparameters (matching your multimodal training)
+ MAX_ATOMIC_Z = 85
+ MASK_ATOM_ID = MAX_ATOMIC_Z + 1
+
+ # GINE params
+ NODE_EMB_DIM = 300
+ EDGE_EMB_DIM = 300
+ NUM_GNN_LAYERS = 5
+
+ # SchNet params
+ SCHNET_NUM_GAUSSIANS = 50
+ SCHNET_NUM_INTERACTIONS = 6
+ SCHNET_CUTOFF = 10.0
+ SCHNET_MAX_NEIGHBORS = 64
+ SCHNET_HIDDEN = 600
+
+ # Fingerprint params
+ FP_LENGTH = 2048
+ MASK_TOKEN_ID_FP = 2
+ VOCAB_SIZE_FP = 3
+
+ # DeBERTa params
+ DEBERTA_HIDDEN = 600
+ PSMILES_MAX_LEN = 128
+
+ # SELFIES-TED generation
+ GEN_MAX_LEN = 256
+ GEN_MIN_LEN = 10
+
+ # -------------------------------------------------------------------------
+ # Decoder fine-tuning params (single head; match G1-style simple schedule)
+ # -------------------------------------------------------------------------
+ BATCH_SIZE = 32
+ NUM_EPOCHS = 100
+ PATIENCE = 10
+ WEIGHT_DECAY = 0.0
+ LEARNING_RATE = 1e-4
+ COSINE_ETA_MIN = 1e-6
+
+ # PolyBART-style noise injection (latent space)
+ LATENT_NOISE_STD_TRAIN = 0.10 # training-time denoising std
+ LATENT_NOISE_STD_GEN = 0.15 # generation-time exploration std
+ N_FOLD_NOISE_SAMPLING = 16 # n-fold sampling around each seed embedding
+
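Review note on the noise-injection constants: they suggest that generation draws several noisy copies of each seed embedding and decodes each one. The sampling code itself is not visible in this part of the diff, so the following is only a minimal sketch under that assumption. `sample_noisy_latents` is a hypothetical helper, and plain Python lists stand in for tensors.

```python
import random
from typing import List


def sample_noisy_latents(seed: List[float], n_samples: int, noise_std: float,
                         rng: random.Random) -> List[List[float]]:
    """Return n_samples copies of `seed`, each perturbed by i.i.d. Gaussian noise."""
    return [[x + rng.gauss(0.0, noise_std) for x in seed] for _ in range(n_samples)]


rng = random.Random(0)
seed_embedding = [0.0] * 600  # stand-in for one CL_EMB_DIM-sized embedding
candidates = sample_noisy_latents(seed_embedding, n_samples=16, noise_std=0.15, rng=rng)
print(len(candidates), len(candidates[0]))  # prints: 16 600
```

With `noise_std=0.0` every candidate equals the seed, which is a quick sanity check for any real implementation.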
+ # Sampling config (decoder)
+ GEN_TOP_P = 0.92
+ GEN_TEMPERATURE = 1.0
+ GEN_REPETITION_PENALTY = 1.05
+
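Review note on `GEN_TOP_P`: presumably this value is handed to the decoder's sampling routine, which is outside this part of the diff. As a reminder of what nucleus (top-p) filtering does, here is a small pure-Python sketch. `top_p_filter` is a hypothetical name and the probabilities are invented for illustration.

```python
from typing import Dict, List


def top_p_filter(probs: List[float], top_p: float) -> Dict[int, float]:
    """Keep the smallest set of most-probable tokens whose mass reaches top_p,
    then renormalize the kept probabilities so they sum to 1."""
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break
    total = sum(probs[i] for i in kept)
    return {i: probs[i] / total for i in kept}


probs = [0.5, 0.3, 0.15, 0.05]
kept = top_p_filter(probs, top_p=0.92)
print(sorted(kept))  # [0, 1, 2]
```

With a threshold of 0.92 the least likely token is dropped before sampling, which trims the tail without forcing greedy decoding.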
+ # CV
+ NUM_FOLDS = 5
+
+ # Property guidance tolerance (scaled and optionally absolute)
+ PROP_TOL_SCALED = 0.5
+ PROP_TOL_UNSCALED_ABS = None
+
+ # GPR settings (PSMILES latent)
+ USE_PCA_BEFORE_GPR = True
+ PCA_DIM = 64
+ GPR_ALPHA = 1e-6
+
+ # Verification (optional auxiliary predictor)
+ VERIFY_GENERATED_PROPERTIES = True
+ PROP_PRED_EPOCHS = 20
+ PROP_PRED_PATIENCE = 5
+ PROP_PRED_BATCH_SIZE = 32
+ PROP_PRED_LR = 3e-4
+ PROP_PRED_WEIGHT_DECAY = 0.0
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ USE_AMP = bool(torch.cuda.is_available())
+ AMP_DTYPE = torch.float16
+
+ NUM_WORKERS = 0 if os.name == "nt" else 1
+
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
+ os.makedirs(OUTPUT_MODELS_DIR, exist_ok=True)
+ os.makedirs(OUTPUT_GENERATIONS_DIR, exist_ok=True)
+
+ warnings.filterwarnings("ignore", category=UserWarning)
+
+ # =============================================================================
+ # Utilities
+ # =============================================================================
+ def _safe_json_load(x):
+     if x is None:
+         return None
+     if isinstance(x, (dict, list)):
+         return x
+     s = str(x).strip()
+     if not s:
+         return None
+     try:
+         return json.loads(s)
+     except Exception:
+         try:
+             return json.loads(s.replace("'", '"'))
+         except Exception:
+             return None
+
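Review note on the fallback in `_safe_json_load`: replacing every `'` with `"` also rewrites apostrophes inside string values, and Python literals such as `True` or `None` still fail to parse as JSON, so such fields silently become `None`. One possible alternative, sketched here as a suggestion rather than as what the pipeline does, is to fall back to `ast.literal_eval`. The name `parse_json_or_literal` is hypothetical.

```python
import ast
import json


def parse_json_or_literal(text):
    """Try strict JSON first, then a Python literal; return None if both fail."""
    s = str(text).strip()
    if not s:
        return None
    try:
        return json.loads(s)
    except ValueError:  # json.JSONDecodeError is a subclass of ValueError
        pass
    try:
        return ast.literal_eval(s)
    except (ValueError, SyntaxError):
        return None


print(parse_json_or_literal("{'name': \"O'Neil\", 'ok': True}"))
```

`ast.literal_eval` only evaluates literals (strings, numbers, tuples, lists, dicts, sets, booleans, `None`), so it does not execute arbitrary code from the CSV field.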
+ def set_seed(seed: int):
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     if torch.cuda.is_available():
+         torch.cuda.manual_seed_all(seed)
+     try:
+         torch.backends.cudnn.benchmark = True
+     except Exception:
+         pass
+
+ def make_json_serializable(obj):
+     if isinstance(obj, dict):
+         return {make_json_serializable(k): make_json_serializable(v) for k, v in obj.items()}
+     if isinstance(obj, (list, tuple, set)):
+         return [make_json_serializable(x) for x in obj]
+     if isinstance(obj, np.ndarray):
+         return obj.tolist()
+     if isinstance(obj, np.generic):
+         try:
+             return obj.item()
+         except Exception:
+             return float(obj)
+     if isinstance(obj, torch.Tensor):
+         try:
+             return obj.detach().cpu().tolist()
+         except Exception:
+             return None
+     if isinstance(obj, (pd.Timestamp, pd.Timedelta)):
+         return str(obj)
+     try:
+         if isinstance(obj, (float, int, str, bool, type(None))):
+             return obj
+     except Exception:
+         pass
+     return str(obj)
+
+ def safe_get(d: dict, key: str, default=None):
+     return d[key] if (isinstance(d, dict) and key in d) else default
+
+ def find_property_columns(columns):
+     lowered = {c.lower(): c for c in columns}
+     found = {}
+     for req in REQUESTED_PROPERTIES:
+         req_low = req.lower().strip()
+         exact = None
+         for c_low, c_orig in lowered.items():
+             tokens = set(c_low.replace('_', ' ').split())
+             if req_low in tokens or c_low == req_low:
+                 if req_low == "density" and ("cohesive" in c_low or "cohesive energy" in c_low):
+                     continue
+                 exact = c_orig
+                 break
+         if exact is not None:
+             found[req] = exact
+             continue
+         candidates = [c_orig for c_low, c_orig in lowered.items() if req_low in c_low]
+         if req_low == "density":
+             candidates = [c for c in candidates if "cohesive" not in c.lower() and "cohesive energy" not in c.lower()]
+         chosen = candidates[0] if candidates else None
+         found[req] = chosen
+         if chosen is None:
+             print(f"[WARN] No candidates found for '{req}'.")
+         else:
+             print(f"[INFO] Requested property '{req}' -> chosen column: {chosen}")
+     return found
+
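Review note on `find_property_columns`: the first pass compares each requested name against single whitespace-separated tokens of the column name, so a multi-word request such as "glass transition" can only succeed there through the exact-equality test. Otherwise it depends on the substring fallback, which does not replace underscores. The sketch below isolates that two-pass logic (without the cohesive-energy exclusion) to show the effect. The column names are made up for illustration and `match_column` is a hypothetical helper.

```python
from typing import List, Optional, Tuple


def match_column(request: str, columns: List[str]) -> Tuple[Optional[str], str]:
    """Mimic the two-pass lookup: token or exact match first, then substring."""
    req = request.lower().strip()
    for col in columns:
        low = col.lower()
        tokens = set(low.replace("_", " ").split())
        if req in tokens or low == req:
            return col, "token"
    for col in columns:
        if req in col.lower():
            return col, "substring"
    return None, "none"


print(match_column("density", ["Density (g/cm3)"]))
print(match_column("glass transition", ["Glass transition temperature"]))
print(match_column("glass transition", ["glass_transition_temperature"]))
```

The third call finds nothing because the underscore-separated header contains neither the token nor the substring "glass transition". If the real CSV uses underscores, normalizing them in the fallback as well would be worth considering.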
+ # =============================================================================
+ # Graph / geometry / FP parsing for multimodal CL
+ # =============================================================================
+ def _parse_graph_for_gine(graph_field):
+     gf = _safe_json_load(graph_field)
+     if not isinstance(gf, dict):
+         return None
+     node_features = gf.get("node_features", None)
+     if not node_features or not isinstance(node_features, list):
+         return None
+     atomic_nums, chirality_vals, formal_charges = [], [], []
+     for nf in node_features:
+         if not isinstance(nf, dict):
+             continue
+         an = nf.get("atomic_num", nf.get("atomic_number", 0))
+         ch = nf.get("chirality", 0)
+         fc = nf.get("formal_charge", 0)
+         try: atomic_nums.append(int(an))
+         except Exception: atomic_nums.append(0)
+         try: chirality_vals.append(float(ch))
+         except Exception: chirality_vals.append(0.0)
+         try: formal_charges.append(float(fc))
+         except Exception: formal_charges.append(0.0)
+     if len(atomic_nums) == 0:
+         return None
+     edge_indices_raw = gf.get("edge_indices", None)
+     edge_features_raw = gf.get("edge_features", None)
+     srcs, dsts = [], []
+     if edge_indices_raw is None:
+         adj = gf.get("adjacency_matrix", None)
+         if isinstance(adj, list):
+             for i_r, row_adj in enumerate(adj):
+                 if not isinstance(row_adj, list):
+                     continue
+                 for j, val in enumerate(row_adj):
+                     if val:
+                         srcs.append(i_r); dsts.append(j)
+     else:
+         try:
+             if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0:
+                 if isinstance(edge_indices_raw[0], list) and len(edge_indices_raw[0]) == 2:
+                     srcs = [int(p[0]) for p in edge_indices_raw]
+                     dsts = [int(p[1]) for p in edge_indices_raw]
+                 elif len(edge_indices_raw) == 2:
+                     srcs = [int(x) for x in edge_indices_raw[0]]
+                     dsts = [int(x) for x in edge_indices_raw[1]]
+         except Exception:
+             srcs, dsts = [], []
+     if len(srcs) == 0:
+         edge_index = torch.empty((2, 0), dtype=torch.long)
+         edge_attr = torch.zeros((0, 3), dtype=torch.float)
+         return {
+             "z": torch.tensor(atomic_nums, dtype=torch.long),
+             "chirality": torch.tensor(chirality_vals, dtype=torch.float),
+             "formal_charge": torch.tensor(formal_charges, dtype=torch.float),
+             "edge_index": edge_index,
+             "edge_attr": edge_attr,
+         }
+     edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
+     edge_attr = None
+     if isinstance(edge_features_raw, list) and len(edge_features_raw) == len(srcs):
+         bt, st, ic = [], [], []
+         for ef in edge_features_raw:
+             if isinstance(ef, dict):
+                 bt.append(float(ef.get("bond_type", 0)))
+                 st.append(float(ef.get("stereo", 0)))
+                 ic.append(float(1.0 if ef.get("is_conjugated", False) else 0.0))
+             else:
+                 bt.append(0.0); st.append(0.0); ic.append(0.0)
+         edge_attr = torch.tensor(list(zip(bt, st, ic)), dtype=torch.float)
+     else:
+         edge_attr = torch.zeros((len(srcs), 3), dtype=torch.float)
+     return {
+         "z": torch.tensor(atomic_nums, dtype=torch.long),
+         "chirality": torch.tensor(chirality_vals, dtype=torch.float),
+         "formal_charge": torch.tensor(formal_charges, dtype=torch.float),
+         "edge_index": edge_index,
+         "edge_attr": edge_attr,
+     }
+
+ def _parse_geometry_for_schnet(geom_field):
+     gf = _safe_json_load(geom_field)
+     if not isinstance(gf, dict):
+         return None
+     conf = gf.get("best_conformer", None)
+     if not isinstance(conf, dict):
+         return None
+     atomic = conf.get("atomic_numbers", [])
+     coords = conf.get("coordinates", [])
+     if not (isinstance(atomic, list) and isinstance(coords, list)):
+         return None
+     if len(atomic) == 0 or len(atomic) != len(coords):
+         return None
+     return {"z": torch.tensor(atomic, dtype=torch.long), "pos": torch.tensor(coords, dtype=torch.float)}
+
+ def _parse_fingerprints(fp_field, fp_len: int = 2048):
+     fp = _safe_json_load(fp_field)
+     bits = None
+     if isinstance(fp, dict):
+         bits = fp.get("morgan_r3_bits", None)
+     elif isinstance(fp, list):
+         bits = fp
+     elif fp is None:
+         bits = None
+     if bits is None:
+         bits = [0] * fp_len
+     else:
+         norm = []
+         for b in bits[:fp_len]:
+             if isinstance(b, str):
+                 bc = b.strip().strip('"').strip("'")
+                 norm.append(1 if bc in ("1", "True", "true") else 0)
+             elif isinstance(b, (int, np.integer, float, np.floating)):
+                 norm.append(1 if int(b) != 0 else 0)
+             else:
+                 norm.append(0)
+         if len(norm) < fp_len:
+             norm.extend([0] * (fp_len - len(norm)))
+         bits = norm
+     return {
+         "input_ids": torch.tensor(bits, dtype=torch.long),
+         "attention_mask": torch.ones(fp_len, dtype=torch.bool),
+     }
+
+ # =============================================================================
+ # PSELFIES utilities
+ # =============================================================================
+
+ _SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")
+
+ def _split_selfies_tokens(selfies_str: str) -> List[str]:
+     if not isinstance(selfies_str, str) or len(selfies_str) == 0:
+         return []
+     if SELFIES_AVAILABLE:
+         try:
+             toks = list(sf.split_selfies(selfies_str.replace(" ", "")))
+             return [t for t in toks if isinstance(t, str) and t]
+         except Exception:
+             pass
+     return _SELFIES_TOKEN_RE.findall(selfies_str)
+
+ def _selfies_for_tokenizer(selfies_str: str) -> str:
+     s = str(selfies_str).strip()
+     if not s:
+         return ""
+     s = s.replace(" ", "")
+     s = s.replace("][", "] [")
+     return s
+
+ def _selfies_compact(selfies_str: str) -> str:
+     return str(selfies_str).replace(" ", "").strip()
+
+ def _ensure_two_at_endpoints(selfies_str: str) -> str:
+     s = _selfies_compact(selfies_str)
+     toks = _split_selfies_tokens(s)
+     if not toks:
+         return s
+     at = "[At]"
+     at_pos = [i for i, t in enumerate(toks) if t == at]
+     if len(at_pos) == 0:
+         toks = [at] + toks + [at]
+     elif len(at_pos) == 1:
+         toks = toks + [at]
+     elif len(at_pos) > 2:
+         first = at_pos[0]; last = at_pos[-1]
+         new = []
+         for i, t in enumerate(toks):
+             if t == at and i not in (first, last):
+                 continue
+             new.append(t)
+         toks = new
+     return "".join(toks)
+
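Review note on `_ensure_two_at_endpoints`: the function normalizes the number of `[At]` markers to two. The standalone sketch below mirrors its branches using only the regex fallback tokenizer (no `selfies` dependency), so the behavior can be checked on toy strings. `ensure_two_at` is a hypothetical re-implementation for illustration, not a replacement for the function in the diff.

```python
import re

TOKEN_RE = re.compile(r"\[[^\[\]]+\]")
AT = "[At]"


def ensure_two_at(selfies: str) -> str:
    """Pad or prune [At] tokens so that exactly two remain."""
    compact = selfies.replace(" ", "")
    toks = TOKEN_RE.findall(compact)
    if not toks:
        return compact
    at_pos = [i for i, t in enumerate(toks) if t == AT]
    if len(at_pos) == 0:
        toks = [AT] + toks + [AT]      # no marker: wrap both ends
    elif len(at_pos) == 1:
        toks = toks + [AT]             # one marker: append a second at the end
    elif len(at_pos) > 2:
        keep = {at_pos[0], at_pos[-1]}  # keep only the outermost markers
        toks = [t for i, t in enumerate(toks) if t != AT or i in keep]
    return "".join(toks)


print(ensure_two_at("[C][C]"))              # [At][C][C][At]
print(ensure_two_at("[At][C][At][O][At]"))  # [At][C][O][At]
print(ensure_two_at("[C][At][O]"))          # [C][At][O][At]
```

As the last call shows, the logic guarantees the count of `[At]` tokens but not their position: a single marker found mid-string stays where it is.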
+ def psmiles_to_at_smiles(psmiles: str, root_at: bool = True) -> Optional[str]:
+     if not RDKit_AVAILABLE:
+         return None
+     try:
+         mol = Chem.MolFromSmiles(psmiles)
+         if mol is None:
+             return None
+         mol = Chem.RWMol(mol)
+         at_indices = []
+         for atom in mol.GetAtoms():
+             if atom.GetAtomicNum() == 0:
+                 atom.SetAtomicNum(85)
+                 try: atom.SetNoImplicit(True)
+                 except Exception: pass
+                 try: atom.SetNumExplicitHs(0)
+                 except Exception: pass
+                 try: atom.SetFormalCharge(0)
+                 except Exception: pass
+                 at_indices.append(int(atom.GetIdx()))
+         mol = mol.GetMol()
+         try:
+             Chem.SanitizeMol(mol)  # raises on failure, so the except branch is reachable
+         except Exception:
+             return None
+         if root_at and len(at_indices) > 0:
+             try:
+                 can = Chem.MolToSmiles(mol, canonical=True, rootedAtAtom=at_indices[0])
+             except Exception:
+                 can = Chem.MolToSmiles(mol, canonical=True)
+         else:
+             can = Chem.MolToSmiles(mol, canonical=True)
+         return can
+     except Exception:
+         return None
+
+ def at_smiles_to_psmiles(at_smiles: str) -> Optional[str]:
+     if not RDKit_AVAILABLE:
+         return None
+     try:
+         mol = Chem.MolFromSmiles(at_smiles)
+         if mol is None:
+             return None
+         rw = Chem.RWMol(mol)
+         for atom in rw.GetAtoms():
+             if atom.GetAtomicNum() == 85:
+                 atom.SetAtomicNum(0)
+                 try: atom.SetNoImplicit(True)
+                 except Exception: pass
+                 try: atom.SetNumExplicitHs(0)
+                 except Exception: pass
+                 try: atom.SetFormalCharge(0)
+                 except Exception: pass
+         mol2 = rw.GetMol()
+         try:
+             Chem.SanitizeMol(mol2)
+         except Exception:
+             return None
+         can = Chem.MolToSmiles(mol2, canonical=True)
+         can = can.replace("[*]", "*")
+         return can
+     except Exception:
+         return None
+
+ def smiles_to_pselfies(smiles: str) -> Optional[str]:
+     if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
+         return None
+     try:
+         mol = Chem.MolFromSmiles(smiles)
+         if mol is None:
+             return None
+         can = Chem.MolToSmiles(mol, canonical=True)
+         s = sf.encoder(can)
+         if not isinstance(s, str) or len(s) == 0:
+             return None
+         return s
+     except Exception:
+         return None
+
+ def psmiles_to_pselfies(psmiles: str) -> Optional[str]:
+     if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
+         return None
+     at_smiles = psmiles_to_at_smiles(psmiles, root_at=True)
+     if at_smiles is None:
+         return None
+     s = smiles_to_pselfies(at_smiles)
+     if s is None:
+         return None
+     return _ensure_two_at_endpoints(s)
+
+ def selfies_to_smiles(selfies_str: str) -> Optional[str]:
+     if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
+         return None
+     try:
+         s = _selfies_compact(selfies_str)
+         smi = sf.decoder(s)
+         if not isinstance(smi, str) or len(smi) == 0:
+             return None
+         mol = Chem.MolFromSmiles(smi)
+         if mol is None:
+             return None
+         try:
+             Chem.SanitizeMol(mol)
+         except Exception:
+             return None
+         can = Chem.MolToSmiles(mol, canonical=True)
+         return can
+     except Exception:
+         return None
+
+ def pselfies_to_psmiles(selfies_str: str) -> Optional[str]:
+     if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
+         return None
+     at_smiles = selfies_to_smiles(selfies_str)
+     if at_smiles is None:
+         return None
+     return at_smiles_to_psmiles(at_smiles)
+
+ def canonicalize_psmiles(psmiles: str) -> Optional[str]:
+     psmiles = str(psmiles).strip()
+     if not psmiles:
+         return None
+     if not RDKit_AVAILABLE:
+         return psmiles
+     try:
+         mol = Chem.MolFromSmiles(psmiles)
+         if mol is None:
+             return None
+         try:
+             Chem.SanitizeMol(mol)  # raises on failure, so the except branch is reachable
+         except Exception:
+             return None
+         can = Chem.MolToSmiles(mol, canonical=True)
+         can = can.replace("[*]", "*")
+         return can
+     except Exception:
+         return None
+
+ def chem_validity_psmiles(psmiles: str) -> bool:
+     if not RDKit_AVAILABLE:
+         return False
+     try:
+         s = str(psmiles).strip()
+         if not s:
+             return False
+         mol = Chem.MolFromSmiles(s)
+         if mol is None:
+             return False
+         try:
+             Chem.SanitizeMol(mol)
+         except Exception:
+             return False
+         return True
+     except Exception:
+         return False
+
+ def polymer_validity_psmiles_strict(psmiles: str) -> bool:
+     if not RDKit_AVAILABLE:
+         return False
+     try:
+         s = str(psmiles).strip()
+         if not s:
+             return False
+         mol = Chem.MolFromSmiles(s)
+         if mol is None:
+             return False
+         try:
+             Chem.SanitizeMol(mol)
+         except Exception:
+             return False
+         stars = [a for a in mol.GetAtoms() if a.GetAtomicNum() == 0]
+         if len(stars) != 2:
+             return False
+         for a in stars:
+             if a.GetTotalDegree() != 1:
+                 return False
+         return True
+     except Exception:
+         return False
+
+ # =============================================================================
+ # CL encoder (multimodal) + fusion pooling (same heads/dims as CL pretraining)
+ # =============================================================================
+
+ def resolve_cl_checkpoint_path(cl_weights_dir: str) -> Optional[str]:
+     if cl_weights_dir is None:
+         return None
+     if os.path.isfile(cl_weights_dir):
+         return cl_weights_dir
+     if not os.path.isdir(cl_weights_dir):
+         return None
+     candidates = [
+         os.path.join(cl_weights_dir, "pytorch_model.bin"),
+         os.path.join(cl_weights_dir, "model.pt"),
+         os.path.join(cl_weights_dir, "best.pt"),
+         os.path.join(cl_weights_dir, "state_dict.pt"),
+     ]
+     for p in candidates:
+         if os.path.isfile(p):
+             return p
+     for ext in ("*.bin", "*.pt"):
+         files = sorted(Path(cl_weights_dir).glob(ext))
+         if files:
+             return str(files[0])
+     return None
+
+ def load_state_dict_any(ckpt_path: str) -> Dict[str, torch.Tensor]:
+     obj = torch.load(ckpt_path, map_location="cpu")
+     if isinstance(obj, dict):
+         if "state_dict" in obj and isinstance(obj["state_dict"], dict):
+             return obj["state_dict"]
+         if "model_state_dict" in obj and isinstance(obj["model_state_dict"], dict):
+             return obj["model_state_dict"]
+     if not isinstance(obj, dict):
+         raise RuntimeError(f"Checkpoint at {ckpt_path} did not contain a state_dict-like dict.")
+     return obj
+
+ def safe_load_into_module(module: nn.Module, sd: Dict[str, torch.Tensor], strict: bool = False) -> Tuple[int, int]:
+     incompatible = module.load_state_dict(sd, strict=strict)
+     missing = getattr(incompatible, "missing_keys", [])
+     unexpected = getattr(incompatible, "unexpected_keys", [])
+     return len(missing), len(unexpected)
+
+ class GineBlock(nn.Module):
+     def __init__(self, node_dim):
+         super().__init__()
+         self.mlp = nn.Sequential(nn.Linear(node_dim, node_dim), nn.ReLU(), nn.Linear(node_dim, node_dim))
+         if GINEConv is None:
+             raise RuntimeError("GINEConv is not available. Install torch_geometric with compatible versions.")
+         self.conv = GINEConv(self.mlp)
+         self.bn = nn.BatchNorm1d(node_dim)
+         self.act = nn.ReLU()
+     def forward(self, x, edge_index, edge_attr):
+         x = self.conv(x, edge_index, edge_attr)
+         x = self.bn(x)
+         x = self.act(x)
+         return x
+
+ class GineEncoder(nn.Module):
+     def __init__(self, node_emb_dim=NODE_EMB_DIM, edge_emb_dim=EDGE_EMB_DIM, num_layers=NUM_GNN_LAYERS, max_atomic_z=MAX_ATOMIC_Z):
+         super().__init__()
+         self.atom_emb = nn.Embedding(num_embeddings=MASK_ATOM_ID+1, embedding_dim=node_emb_dim, padding_idx=None)
+         self.node_attr_proj = nn.Sequential(nn.Linear(2, node_emb_dim), nn.ReLU(), nn.Linear(node_emb_dim, node_emb_dim))
+         self.edge_encoder = nn.Sequential(nn.Linear(3, edge_emb_dim), nn.ReLU(), nn.Linear(edge_emb_dim, edge_emb_dim))
+         self._edge_to_node_proj = nn.Linear(edge_emb_dim, node_emb_dim) if edge_emb_dim != node_emb_dim else None
+         self.gnn_layers = nn.ModuleList([GineBlock(node_emb_dim) for _ in range(num_layers)])
+         self.pool_proj = nn.Linear(node_emb_dim, node_emb_dim)
+         self.node_classifier = nn.Linear(node_emb_dim, MASK_ATOM_ID+1)
+     def _compute_node_reps(self, z, chirality, formal_charge, edge_index, edge_attr):
+         device = next(self.parameters()).device
+         atom_embedding = self.atom_emb(z.to(device))
+         if chirality is None or formal_charge is None:
+             node_attr = torch.zeros((z.size(0), 2), device=device)
+         else:
+             node_attr = torch.stack([chirality, formal_charge], dim=1).to(atom_embedding.device)
+         node_attr_emb = self.node_attr_proj(node_attr)
+         x = atom_embedding + node_attr_emb
+         if edge_attr is None or edge_attr.numel() == 0:
+             edge_emb = torch.zeros((0, EDGE_EMB_DIM), dtype=torch.float, device=x.device)
+         else:
+             edge_emb = self.edge_encoder(edge_attr.to(x.device))
+         edge_for_conv = self._edge_to_node_proj(edge_emb) if (self._edge_to_node_proj is not None and edge_emb.numel() > 0) else edge_emb
+         h = x
+         for layer in self.gnn_layers:
+             h = layer(h, edge_index.to(h.device), edge_for_conv)
+         return h
+     def forward(self, z, chirality, formal_charge, edge_index, edge_attr, batch=None):
+         h = self._compute_node_reps(z, chirality, formal_charge, edge_index, edge_attr)
+         if batch is None:
+             pooled = torch.mean(h, dim=0, keepdim=True)
+         else:
+             bsize = int(batch.max().item() + 1) if batch.numel() > 0 else 1
+             pooled = torch.zeros((bsize, h.size(1)), device=h.device)
+             for i in range(bsize):
+                 mask = batch == i
+                 if mask.sum() == 0:
+                     continue
+                 pooled[i] = h[mask].mean(dim=0)
+         return self.pool_proj(pooled)
+
+ class NodeSchNetWrapper(nn.Module):
+     def __init__(self, hidden_channels=SCHNET_HIDDEN, num_interactions=SCHNET_NUM_INTERACTIONS,
+                  num_gaussians=SCHNET_NUM_GAUSSIANS, cutoff=SCHNET_CUTOFF, max_num_neighbors=SCHNET_MAX_NEIGHBORS):
+         super().__init__()
+         if PyGSchNet is None:
+             raise RuntimeError("PyG SchNet is not available. Install torch_geometric with compatible extras.")
+         self.schnet = PyGSchNet(hidden_channels=hidden_channels, num_filters=hidden_channels,
+                                 num_interactions=num_interactions, num_gaussians=num_gaussians,
+                                 cutoff=cutoff, max_num_neighbors=max_num_neighbors)
+         self.pool_proj = nn.Linear(hidden_channels, hidden_channels)
+         self.cutoff = cutoff
+         self.max_num_neighbors = max_num_neighbors
+         self.node_classifier = nn.Linear(hidden_channels, MASK_ATOM_ID+1)
+     def forward(self, z, pos, batch=None):
+         device = next(self.parameters()).device
+         z = z.to(device); pos = pos.to(device)
+         if batch is None:
+             batch = torch.zeros(z.size(0), dtype=torch.long, device=z.device)
+         try:
+             edge_index = radius_graph(pos, r=self.cutoff, batch=batch, max_num_neighbors=self.max_num_neighbors)
+         except Exception:
+             edge_index = None
+         node_h = None
+         try:
+             node_h = self.schnet.embedding(z)
+         except Exception:
+             node_h = None
+         if node_h is not None and edge_index is not None and edge_index.numel() > 0:
+             row, col = edge_index
+             edge_weight = (pos[row] - pos[col]).norm(dim=-1)
+             edge_attr = None
+             if hasattr(self.schnet, "distance_expansion"):
+                 try: edge_attr = self.schnet.distance_expansion(edge_weight)
+                 except Exception: edge_attr = None
+             if edge_attr is None and hasattr(self.schnet, "gaussian_smearing"):
+                 try: edge_attr = self.schnet.gaussian_smearing(edge_weight)
+                 except Exception: edge_attr = None
+             if hasattr(self.schnet, "interactions") and getattr(self.schnet, "interactions") is not None:
+                 for interaction in self.schnet.interactions:
+                     try:
+                         node_h = node_h + interaction(node_h, edge_index, edge_weight, edge_attr)
+                     except TypeError:
+                         node_h = node_h + interaction(node_h, edge_index, edge_weight)
+         if node_h is None:
+             try:
+                 out = self.schnet(z=z, pos=pos, batch=batch)
+                 if isinstance(out, torch.Tensor) and out.dim() == 2 and out.size(0) == z.size(0):
+                     node_h = out
+                 elif hasattr(out, "last_hidden_state"):
+                     node_h = out.last_hidden_state
+                 elif isinstance(out, (tuple, list)) and len(out) > 0 and isinstance(out[0], torch.Tensor):
+                     cand = out[0]
+                     if cand.dim() == 2 and cand.size(0) == z.size(0):
+                         node_h = cand
+             except Exception as e:
+                 raise RuntimeError("Failed to obtain node-level embeddings from PyG SchNet.") from e
+         bsize = int(batch.max().item()) + 1 if z.numel() > 0 else 1
+         pooled = torch.zeros((bsize, node_h.size(1)), device=node_h.device)
+         for i in range(bsize):
+             mask = batch == i
+             if mask.sum() == 0:
+                 continue
+             pooled[i] = node_h[mask].mean(dim=0)
+         return self.pool_proj(pooled)
+
788
+ class FingerprintEncoder(nn.Module):
789
+ def __init__(self, vocab_size=VOCAB_SIZE_FP, hidden_dim=256, seq_len=FP_LENGTH,
790
+ num_layers=4, nhead=8, dim_feedforward=1024, dropout=0.1):
791
+ super().__init__()
792
+ self.token_emb = nn.Embedding(vocab_size, hidden_dim)
793
+ self.pos_emb = nn.Embedding(seq_len, hidden_dim)
794
+ encoder_layer = nn.TransformerEncoderLayer(d_model=hidden_dim, nhead=nhead,
795
+ dim_feedforward=dim_feedforward, dropout=dropout, batch_first=True)
796
+ self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
797
+ self.pool_proj = nn.Linear(hidden_dim, hidden_dim)
798
+ self.seq_len = seq_len
799
+ self.token_proj = nn.Linear(hidden_dim, vocab_size)
800
+ def forward(self, input_ids, attention_mask=None):
801
+ device = next(self.parameters()).device
802
+ input_ids = input_ids.to(device)
803
+ B, L = input_ids.shape
804
+ x = self.token_emb(input_ids)
805
+ pos_ids = torch.arange(L, device=input_ids.device).unsqueeze(0).expand(B, -1)
806
+ x = x + self.pos_emb(pos_ids)
807
+ key_padding_mask = (~attention_mask.to(input_ids.device).bool()) if attention_mask is not None else None  # cast first: "~" on an integer mask is a bitwise NOT, not a logical one
808
+ out = self.transformer(x, src_key_padding_mask=key_padding_mask)
809
+ if attention_mask is None:
810
+ pooled = out.mean(dim=1)
811
+ else:
812
+ am = attention_mask.to(out.device).float().unsqueeze(-1)
813
+ pooled = (out * am).sum(dim=1) / (am.sum(dim=1).clamp(min=1.0))
814
+ return self.pool_proj(pooled)
815
+
816
+ class PSMILESDebertaEncoder(nn.Module):
817
+ def __init__(self, model_dir_or_name: Optional[str] = None, vocab_size: Optional[int] = None):
818
+ super().__init__()
819
+ try:
820
+ if model_dir_or_name is not None and os.path.isdir(model_dir_or_name):
821
+ self.model = DebertaV2ForMaskedLM.from_pretrained(model_dir_or_name)
822
+ else:
823
+ self.model = DebertaV2ForMaskedLM.from_pretrained(model_dir_or_name or "microsoft/deberta-v2-xlarge")
824
+ except Exception:
825
+ from transformers import DebertaV2Config
826
+ cfg = DebertaV2Config(
827
+ vocab_size=int(vocab_size) if vocab_size is not None else 300,
828
+ hidden_size=DEBERTA_HIDDEN,
829
+ num_attention_heads=12,
830
+ num_hidden_layers=12,
831
+ intermediate_size=4 * DEBERTA_HIDDEN,
832
+ )
833
+ self.model = DebertaV2ForMaskedLM(cfg)
834
+ self.pool_proj = nn.Linear(self.model.config.hidden_size, self.model.config.hidden_size)
835
+ @property
836
+ def out_dim(self) -> int:
837
+ return int(self.model.config.hidden_size)
838
+ def forward(self, input_ids, attention_mask=None):
839
+ device = next(self.parameters()).device
840
+ input_ids = input_ids.to(device)
841
+ if attention_mask is not None:
842
+ attention_mask = attention_mask.to(device)
843
+ outputs = self.model.base_model(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
844
+ last_hidden = outputs.last_hidden_state
845
+ if attention_mask is None:
846
+ pooled = last_hidden.mean(dim=1)
847
+ else:
848
+ am = attention_mask.unsqueeze(-1).to(last_hidden.device).float()
849
+ pooled = (last_hidden * am).sum(dim=1) / (am.sum(dim=1).clamp(min=1.0))
850
+ return self.pool_proj(pooled)
851
+
852
+ class UniPolyFusionModule(nn.Module):
853
+ def __init__(self, d_model: int, nhead: int = 8, ffn_mult: int = 4, dropout: float = 0.1):
854
+ super().__init__()
855
+ self.ln1 = nn.LayerNorm(d_model)
856
+ self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
857
+ self.ln2 = nn.LayerNorm(d_model)
858
+ self.ffn = nn.Sequential(
859
+ nn.Linear(d_model, ffn_mult * d_model), nn.GELU(), nn.Dropout(dropout),
860
+ nn.Linear(ffn_mult * d_model, d_model), nn.Dropout(dropout),
861
+ )
862
+ self.pool_ln = nn.LayerNorm(d_model)
863
+ self.pool_q = nn.Parameter(torch.randn(d_model))
864
+ def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
865
+ key_padding = ~mask
866
+ h = self.ln1(x)
867
+ attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding)
868
+ x = x + attn_out
869
+ x = x + self.ffn(self.ln2(x))
870
+ x = self.pool_ln(x)
871
+ q = self.pool_q.unsqueeze(0).unsqueeze(-1)
872
+ scores = torch.matmul(x, q).squeeze(-1)
873
+ scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)  # dtype-aware fill; a literal -1e9 is not representable in float16
874
+ w = torch.softmax(scores, dim=-1).unsqueeze(-1)
875
+ pooled = (x * w).sum(dim=1)
876
+ return pooled
877
+
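The pooling step at the end of `UniPolyFusionModule.forward` scores each modality token against a learned query, masks out absent tokens, and takes a softmax-weighted sum. The sketch below reproduces only that arithmetic for a single sample in plain Python. The function name `attention_pool` is hypothetical, it uses `-inf` where the module uses a large negative constant, and it omits the LayerNorm and self-attention layers that precede pooling in the module.

```python
import math
from typing import List


def attention_pool(tokens: List[List[float]], query: List[float], mask: List[bool]) -> List[float]:
    """Softmax-weighted sum of token vectors; masked tokens get zero weight."""
    # Dot-product score per token; masked positions are pushed to -inf.
    scores = [
        sum(t * q for t, q in zip(tok, query)) if keep else -math.inf
        for tok, keep in zip(tokens, mask)
    ]
    top = max(scores)
    if top == -math.inf:
        raise ValueError("at least one token must be unmasked")
    # Subtract the maximum before exponentiating for numerical stability.
    exps = [math.exp(s - top) for s in scores]  # exp(-inf) evaluates to 0.0
    total = sum(exps)
    weights = [e / total for e in exps]
    dim = len(query)
    return [sum(w * tok[d] for w, tok in zip(weights, tokens)) for d in range(dim)]


# Three modality tokens, the third one masked out; a zero query gives uniform weights.
pooled = attention_pool([[1.0, 0.0], [0.0, 1.0], [5.0, 5.0]], [0.0, 0.0], [True, True, False])
```

With a zero query every unmasked token receives the same weight, so `pooled` is the mean of the first two tokens and the masked third token contributes nothing.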
878
+ class MultiModalCLPolymerEncoder(nn.Module):
879
+ def __init__(self, psmiles_tokenizer, emb_dim: int = CL_EMB_DIM,
880
+ cl_weights_dir: Optional[str] = PRETRAINED_MULTIMODAL_DIR,
881
+ use_gine: bool = True, use_schnet: bool = True, use_fp: bool = True, use_psmiles: bool = True):
882
+ super().__init__()
883
+ self.psm_tok = psmiles_tokenizer
884
+ self.emb_dim = int(emb_dim)
885
+ self.gine = None; self.schnet = None; self.fp = None; self.psmiles = None
886
+ if use_gine:
887
+ try:
888
+ self.gine = GineEncoder(NODE_EMB_DIM, EDGE_EMB_DIM, NUM_GNN_LAYERS, MAX_ATOMIC_Z)
889
+ except Exception as e:
890
+ print(f"[CL] GINE disabled: {e}"); self.gine = None
891
+ if use_schnet:
892
+ try:
893
+ self.schnet = NodeSchNetWrapper(SCHNET_HIDDEN, SCHNET_NUM_INTERACTIONS,
894
+ SCHNET_NUM_GAUSSIANS, SCHNET_CUTOFF, SCHNET_MAX_NEIGHBORS)
895
+ except Exception as e:
896
+ print(f"[CL] SchNet disabled: {e}"); self.schnet = None
897
+ if use_fp:
898
+ try:
899
+ self.fp = FingerprintEncoder(VOCAB_SIZE_FP, 256, FP_LENGTH, 4, 8, 1024, 0.1)
900
+ except Exception as e:
901
+ print(f"[CL] FP encoder disabled: {e}"); self.fp = None
902
+ if use_psmiles:
903
+ enc_src = BEST_PSMILES_DIR if (BEST_PSMILES_DIR and os.path.isdir(BEST_PSMILES_DIR)) else None
904
+ self.psmiles = PSMILESDebertaEncoder(model_dir_or_name=enc_src, vocab_size=getattr(psmiles_tokenizer, "vocab_size", None))
905
+ self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
906
+ self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
907
+ self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
908
+ self.proj_psmiles = nn.Linear(self.psmiles.out_dim, self.emb_dim) if self.psmiles is not None else None  # follow the loaded encoder's hidden size, which may differ from DEBERTA_HIDDEN
909
+ self.dropout = nn.Dropout(0.1)
910
+ self.out_dim = self.emb_dim
911
+ self.fusion = UniPolyFusionModule(d_model=self.emb_dim, nhead=8, ffn_mult=4, dropout=0.1)
912
+ self._load_multimodal_cl_checkpoint(cl_weights_dir)
913
+ def _load_multimodal_cl_checkpoint(self, cl_weights_dir: Optional[str]):
914
+ ckpt_path = resolve_cl_checkpoint_path(cl_weights_dir) if cl_weights_dir else None
915
+ if ckpt_path is None:
916
+ print(f"[CL] No multimodal CL checkpoint found at: {cl_weights_dir}. Using initialized encoders/projections.")
917
+ return
918
+ sd = load_state_dict_any(ckpt_path); model_sd = self.state_dict()
919
+ filtered = {}
920
+ for k, v in sd.items():
921
+ if k not in model_sd:
922
+ continue
923
+ if hasattr(v, "shape") and hasattr(model_sd[k], "shape") and tuple(v.shape) != tuple(model_sd[k].shape):
924
+ continue
925
+ filtered[k] = v
926
+ missing, unexpected = safe_load_into_module(self, filtered, strict=False)
927
+ print(f"[CL] Loaded multimodal CL from {ckpt_path}. loaded={len(filtered)} missing={missing} unexpected={unexpected}")
928
+ def freeze_cl_encoders(self):
929
+ for encoder_name, encoder in [("gine", self.gine), ("schnet", self.schnet), ("fp", self.fp), ("psmiles", self.psmiles)]:
930
+ if encoder is not None:
931
+ encoder.eval()
932
+ for p in encoder.parameters(): p.requires_grad = False
933
+ print(f"[CL] Froze {encoder_name} encoder")
934
+ self.fusion.eval()
935
+ for p in self.fusion.parameters(): p.requires_grad = False
936
+ def forward_multimodal(self, batch_mods: dict) -> torch.Tensor:
937
+ device = next(self.parameters()).device
938
+ B = None
939
+ if batch_mods.get("fp", None) is not None and isinstance(batch_mods["fp"].get("input_ids", None), torch.Tensor):
940
+ B = int(batch_mods["fp"]["input_ids"].size(0))
941
+ elif batch_mods.get("psmiles", None) is not None and isinstance(batch_mods["psmiles"].get("input_ids", None), torch.Tensor):
942
+ B = int(batch_mods["psmiles"]["input_ids"].size(0))
943
+ else:
944
+ if batch_mods.get("gine", None) is not None and isinstance(batch_mods["gine"].get("batch", None), torch.Tensor):
945
+ B = int(batch_mods["gine"]["batch"].max().item() + 1) if batch_mods["gine"]["batch"].numel() > 0 else 1
946
+ elif batch_mods.get("schnet", None) is not None and isinstance(batch_mods["schnet"].get("batch", None), torch.Tensor):
947
+ B = int(batch_mods["schnet"]["batch"].max().item() + 1) if batch_mods["schnet"]["batch"].numel() > 0 else 1
948
+ else:
949
+ B = 1
950
+ tokens = []; masks = []
951
+ def _append_token(z_token: torch.Tensor):
952
+ tokens.append(z_token); masks.append(torch.ones((z_token.size(0),), dtype=torch.bool, device=device))
953
+ if self.gine is not None and batch_mods.get("gine", None) is not None:
954
+ g = batch_mods["gine"]
955
+ if isinstance(g.get("z", None), torch.Tensor) and g["z"].numel() > 0:
956
+ emb_g = self.gine(
957
+ g["z"].to(device),
958
+ g.get("chirality", torch.zeros_like(g["z"], dtype=torch.float)).to(device) if isinstance(g.get("chirality", None), torch.Tensor) else None,
959
+ g.get("formal_charge", torch.zeros_like(g["z"], dtype=torch.float)).to(device) if isinstance(g.get("formal_charge", None), torch.Tensor) else None,
960
+ g.get("edge_index", torch.empty((2,0), dtype=torch.long)).to(device),
961
+ g.get("edge_attr", torch.zeros((0,3), dtype=torch.float)).to(device),
962
+ g.get("batch", None).to(device) if isinstance(g.get("batch", None), torch.Tensor) else None
963
+ )
964
+ zg = self.proj_gine(emb_g); zg = self.dropout(zg); _append_token(zg)
965
+ if self.schnet is not None and batch_mods.get("schnet", None) is not None:
966
+ s = batch_mods["schnet"]
967
+ if isinstance(s.get("z", None), torch.Tensor) and s["z"].numel() > 0:
968
+ emb_s = self.schnet(s["z"].to(device), s["pos"].to(device), s.get("batch", None).to(device) if isinstance(s.get("batch", None), torch.Tensor) else None)
969
+ zs = self.proj_schnet(emb_s); zs = self.dropout(zs); _append_token(zs)
970
+ if self.fp is not None and batch_mods.get("fp", None) is not None:
971
+ f = batch_mods["fp"]
972
+ if isinstance(f.get("input_ids", None), torch.Tensor) and f["input_ids"].numel() > 0:
973
+ emb_f = self.fp(f["input_ids"].to(device), f.get("attention_mask", None).to(device) if isinstance(f.get("attention_mask", None), torch.Tensor) else None)
974
+ zf = self.proj_fp(emb_f); zf = self.dropout(zf); _append_token(zf)
975
+ if self.psmiles is not None and batch_mods.get("psmiles", None) is not None:
976
+ p = batch_mods["psmiles"]
977
+ if isinstance(p.get("input_ids", None), torch.Tensor) and p["input_ids"].numel() > 0:
978
+ emb_p = self.psmiles(p["input_ids"].to(device), p.get("attention_mask", None).to(device) if isinstance(p.get("attention_mask", None), torch.Tensor) else None)
979
+ zp = self.proj_psmiles(emb_p); zp = self.dropout(zp); _append_token(zp)
980
+ if not tokens:
981
+ z = torch.zeros((B, self.emb_dim), device=device)
982
+ return F.normalize(z, dim=-1)
983
+ X = torch.stack(tokens, dim=1)
984
+ mask = torch.ones((B, X.size(1)), dtype=torch.bool, device=device)
985
+ pooled = self.fusion(X, mask)
986
+ pooled = F.normalize(pooled, dim=-1)
987
+ return pooled
988
+ @torch.no_grad()
989
+ def encode_psmiles(self, psmiles_list: List[str], max_len: int = PSMILES_MAX_LEN, batch_size: int = 64, device: str = DEVICE) -> np.ndarray:
990
+ self.eval()
991
+ if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
992
+ raise RuntimeError("PSMILES tokenizer/encoder/projection not available.")
993
+ dev = torch.device(device)
994
+ self.to(dev)
995
+ outs = []
996
+ for i in range(0, len(psmiles_list), batch_size):
997
+ chunk = [str(x) for x in psmiles_list[i:i + batch_size]]
998
+ enc = self.psm_tok(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
999
+ input_ids = enc["input_ids"].to(dev)
1000
+ attn = enc["attention_mask"].to(dev).bool()
1001
+ emb_p = self.psmiles(input_ids, attn)
1002
+ z = self.proj_psmiles(emb_p)
1003
+ z = F.normalize(z, dim=-1)
1004
+ outs.append(z.detach().cpu().numpy())
1005
+ return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
1006
+ @torch.no_grad()
1007
+ def encode_multimodal(self, records: List[dict], batch_size: int = 32, device: str = DEVICE) -> np.ndarray:
1008
+ self.eval(); dev = torch.device(device); self.to(dev)
1009
+ outs = []
1010
+ for i in range(0, len(records), batch_size):
1011
+ chunk = records[i:i + batch_size]
1012
+ psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
1013
+ p_enc = None
1014
+ if self.psm_tok is not None:
1015
+ p_enc = self.psm_tok(psmiles_texts, truncation=True, padding="max_length",
1016
+ max_length=PSMILES_MAX_LEN, return_tensors="pt")
1017
+ fp_ids, fp_attn = [], []
1018
+ for r in chunk:
1019
+ f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
1020
+ fp_ids.append(f["input_ids"]); fp_attn.append(f["attention_mask"])
1021
+ fp_ids = torch.stack(fp_ids, dim=0); fp_attn = torch.stack(fp_attn, dim=0)
1022
+ gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
1023
+ node_offset = 0
1024
+ for bi, r in enumerate(chunk):
1025
+ g = _parse_graph_for_gine(r.get("graph", None))
1026
+ if g is None or g["z"].numel() == 0: continue
1027
+ n = g["z"].size(0)
1028
+ gine_all["z"].append(g["z"]); gine_all["chirality"].append(g["chirality"]); gine_all["formal_charge"].append(g["formal_charge"])
1029
+ gine_all["batch"].append(torch.full((n,), bi, dtype=torch.long))
1030
+ ei = g["edge_index"]; ea = g["edge_attr"]
1031
+ if ei is not None and ei.numel() > 0:
1032
+ gine_all["edge_index"].append(ei + node_offset); gine_all["edge_attr"].append(ea)
1033
+ node_offset += n
1034
+ gine_batch = None
1035
+ if len(gine_all["z"]) > 0:
1036
+ z_b = torch.cat(gine_all["z"], dim=0)
1037
+ ch_b = torch.cat(gine_all["chirality"], dim=0)
1038
+ fc_b = torch.cat(gine_all["formal_charge"], dim=0)
1039
+ b_b = torch.cat(gine_all["batch"], dim=0)
1040
+ if len(gine_all["edge_index"]) > 0:
1041
+ ei_b = torch.cat(gine_all["edge_index"], dim=1)
1042
+ ea_b = torch.cat(gine_all["edge_attr"], dim=0)
1043
+ else:
1044
+ ei_b = torch.empty((2, 0), dtype=torch.long); ea_b = torch.zeros((0, 3), dtype=torch.float)
1045
+ gine_batch = {"z": z_b, "chirality": ch_b, "formal_charge": fc_b, "edge_index": ei_b, "edge_attr": ea_b, "batch": b_b}
1046
+ sch_all_z, sch_all_pos, sch_all_batch = [], [], []
1047
+ for bi, r in enumerate(chunk):
1048
+ s = _parse_geometry_for_schnet(r.get("geometry", None))
1049
+ if s is None or s["z"].numel() == 0: continue
1050
+ n = s["z"].size(0)
1051
+ sch_all_z.append(s["z"]); sch_all_pos.append(s["pos"]); sch_all_batch.append(torch.full((n,), bi, dtype=torch.long))
1052
+ schnet_batch = None
1053
+ if len(sch_all_z) > 0:
1054
+ schnet_batch = {"z": torch.cat(sch_all_z, dim=0), "pos": torch.cat(sch_all_pos, dim=0), "batch": torch.cat(sch_all_batch, dim=0)}
1055
+ batch_mods = {
1056
+ "gine": gine_batch,
1057
+ "schnet": schnet_batch,
1058
+ "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
1059
+ "psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None
1060
+ }
1061
+ z = self.forward_multimodal(batch_mods)
1062
+ outs.append(z.detach().cpu().numpy())
1063
+ return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
1064
+
1065
+ # =============================================================================
1066
+ # SELFIES-TED decoder conditioned on CL embeddings
1067
+ # =============================================================================
1068
+
1069
+ SELFIES_TED_MODEL_NAME = os.environ.get("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")
1070
+ HF_TOKEN = os.environ.get("HF_TOKEN", None)
1071
+
1072
+ def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
1073
+ last_err = None
1074
+ for t in range(max_tries):
1075
+ try:
1076
+ return load_fn()
1077
+ except Exception as e:
1078
+ last_err = e
1079
+ sleep_s = base_sleep * (1.6 ** t) + random.random()
1080
+ print(f"[WARN] HF load attempt {t+1}/{max_tries} failed: {e}. Sleeping {sleep_s:.1f}s then retry.")
1081
+ time.sleep(sleep_s)
1082
+ raise RuntimeError(f"Failed to load model from HF after {max_tries} attempts. Last error: {last_err}") from last_err
1083
+
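`_hf_load_with_retries` follows a retry pattern with exponential backoff and random jitter. A minimal standalone sketch of that pattern is shown below, assuming nothing about Hugging Face itself. The name `retry_with_backoff` and the injectable `sleep` parameter are illustrative additions that make the delay schedule testable, and unlike the function above this version skips the sleep after the final failed attempt.

```python
import random
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def retry_with_backoff(
    load_fn: Callable[[], T],
    max_tries: int = 5,
    base_sleep: float = 2.0,
    growth: float = 1.6,
    sleep: Callable[[float], None] = time.sleep,
) -> T:
    """Call load_fn until it succeeds or max_tries attempts have failed."""
    last_err = None
    for attempt in range(max_tries):
        try:
            return load_fn()
        except Exception as err:
            last_err = err
            if attempt == max_tries - 1:
                break  # nothing left to retry, so do not sleep again
            # Delay grows geometrically; up to one second of jitter spreads out retries.
            sleep(base_sleep * (growth ** attempt) + random.random())
    raise RuntimeError(f"Failed after {max_tries} attempts") from last_err
```

Passing `sleep=delays.append` (with `delays = []`) records the waits instead of blocking, which is a convenient way to inspect the schedule in a test.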
1084
+ def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
1085
+ def _load_tok():
1086
+ return AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True)
1087
+ def _load_model():
1088
+ return AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN)
1089
+ tok = _hf_load_with_retries(_load_tok, max_tries=5)
1090
+ model = _hf_load_with_retries(_load_model, max_tries=5)
1091
+ return tok, model
1092
+
1093
+ class CLConditionedSelfiesTEDGenerator(nn.Module):
1094
+ def __init__(self, tok, seq2seq_model, cl_emb_dim: int = CL_EMB_DIM, mem_len: int = 4):
1095
+ super().__init__()
1096
+ self.tok = tok
1097
+ self.model = seq2seq_model
1098
+ self.mem_len = int(mem_len)
1099
+ d_model = int(getattr(self.model.config, "d_model", 1024))
1100
+ self.cl_to_d = nn.Sequential(nn.Linear(cl_emb_dim, d_model), nn.Tanh(), nn.Dropout(0.1), nn.Linear(d_model, d_model))
1101
+ self.mem_pos = nn.Embedding(self.mem_len, d_model)
1102
+ def build_encoder_outputs(self, z: torch.Tensor) -> Tuple[BaseModelOutput, torch.Tensor]:
1103
+ device = z.device
1104
+ B = z.size(0)
1105
+ d = self.cl_to_d(z)
1106
+ d = d.unsqueeze(1).expand(B, self.mem_len, d.size(-1)).contiguous()
1107
+ pos = torch.arange(self.mem_len, device=device).unsqueeze(0).expand(B, -1)
1108
+ d = d + self.mem_pos(pos)
1109
+ attn = torch.ones((B, self.mem_len), dtype=torch.long, device=device)
1110
+ return BaseModelOutput(last_hidden_state=d), attn
1111
+ def forward_train(self, z: torch.Tensor, labels: torch.Tensor) -> Dict[str, torch.Tensor]:
1112
+ enc_out, attn = self.build_encoder_outputs(z)
1113
+ out = self.model(encoder_outputs=enc_out, attention_mask=attn, labels=labels)
1114
+ loss = out.loss
1115
+ return {"loss": loss, "ce": loss.detach()}
1116
+ @torch.no_grad()
1117
+ def generate(self, z: torch.Tensor, num_return_sequences: int = 1, max_len: int = GEN_MAX_LEN,
1118
+ top_p: float = GEN_TOP_P, temperature: float = GEN_TEMPERATURE, repetition_penalty: float = GEN_REPETITION_PENALTY) -> List[str]:
1119
+ self.eval()
1120
+ z = z.to(next(self.parameters()).device)
1121
+ enc_out, attn = self.build_encoder_outputs(z)
1122
+ gen = self.model.generate(
1123
+ encoder_outputs=enc_out,
1124
+ attention_mask=attn,
1125
+ do_sample=True,
1126
+ top_p=float(top_p),
1127
+ temperature=float(temperature),
1128
+ repetition_penalty=float(repetition_penalty),
1129
+ num_return_sequences=int(num_return_sequences),
1130
+ max_length=int(max_len),
1131
+ min_length=int(GEN_MIN_LEN),
1132
+ pad_token_id=int(self.tok.pad_token_id) if self.tok.pad_token_id is not None else None,
1133
+ eos_token_id=int(self.tok.eos_token_id) if self.tok.eos_token_id is not None else None,
1134
+ )
1135
+ outs = self.tok.batch_decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=True)
1136
+ outs = [_ensure_two_at_endpoints(_selfies_compact(o)) for o in outs]
1137
+ return outs
1138
+
1139
+ def create_optimizer_and_scheduler_decoder(model: CLConditionedSelfiesTEDGenerator):
1140
+ for p in model.parameters():
1141
+ p.requires_grad = True
1142
+ opt = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
1143
+ sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=NUM_EPOCHS, eta_min=COSINE_ETA_MIN)
1144
+ return opt, sch
1145
+
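`create_optimizer_and_scheduler_decoder` pairs AdamW with `CosineAnnealingLR`. When the scheduler is stepped once per epoch without restarts, the learning rate follows a closed-form cosine curve from the base rate down to `eta_min`. The sketch below evaluates that formula directly; the numeric values are illustrative placeholders, not the script's `LEARNING_RATE`, `COSINE_ETA_MIN`, or `NUM_EPOCHS`.

```python
import math


def cosine_annealing_lr(base_lr: float, eta_min: float, epoch: int, t_max: int) -> float:
    """Closed-form cosine annealing: base_lr at epoch 0, eta_min at epoch t_max."""
    return eta_min + 0.5 * (base_lr - eta_min) * (1.0 + math.cos(math.pi * epoch / t_max))


# Illustrative values only (assumed for this example).
schedule = [cosine_annealing_lr(1e-4, 1e-6, epoch, 10) for epoch in range(11)]
```

The rate starts at the base value, passes through the midpoint of the two bounds halfway through, and reaches `eta_min` at epoch `t_max`.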
1146
+ # =============================================================================
1147
+ # Datasets for latent-to-SELFIES training
1148
+ # =============================================================================
1149
+
1150
+ class LatentToPSELFIESDataset(Dataset):
1151
+ def __init__(self, records: List[dict], cl_encoder: MultiModalCLPolymerEncoder, selfies_tok,
1152
+ max_len: int = GEN_MAX_LEN, latent_noise_std: float = 0.0,
1153
+ cache_embeddings: bool = True, renormalize_after_noise: bool = True, use_multimodal: bool = True):
1154
+ self.records = records
1155
+ self.cl_encoder = cl_encoder
1156
+ self.tok = selfies_tok
1157
+ self.max_len = int(max_len)
1158
+ self.latent_noise_std = float(latent_noise_std)
1159
+ self.renorm = bool(renormalize_after_noise)
1160
+ self.use_multimodal = bool(use_multimodal)
1161
+ self.pad_id = int(self.tok.pad_token_id) if getattr(self.tok, "pad_token_id", None) is not None else 1
1162
+ self._cache = None
1163
+ if cache_embeddings:
1164
+ if self.use_multimodal:
1165
+ emb = self.cl_encoder.encode_multimodal(self.records, batch_size=32, device=DEVICE)
1166
+ else:
1167
+ psm = [str(r.get("psmiles", "")) for r in self.records]
1168
+ emb = self.cl_encoder.encode_psmiles(psm, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
1169
+ self._cache = emb.astype(np.float32)
1170
+ def __len__(self):
1171
+ return len(self.records)
1172
+ def __getitem__(self, idx):
1173
+ r = self.records[idx]
1174
+ tgt = str(r["pselfies"]).strip()
1175
+ tgt = _selfies_for_tokenizer(tgt)
1176
+ if self._cache is not None:
1177
+ z = torch.tensor(self._cache[idx], dtype=torch.float32)
1178
+ else:
1179
+ if self.use_multimodal:
1180
+ z_np = self.cl_encoder.encode_multimodal([r], batch_size=1, device=DEVICE)
1181
+ z = torch.tensor(z_np[0], dtype=torch.float32)
1182
+ else:
1183
+ psm = str(r.get("psmiles", "")).strip()
1184
+ z_np = self.cl_encoder.encode_psmiles([psm], max_len=PSMILES_MAX_LEN, batch_size=1, device=DEVICE)
1185
+ z = torch.tensor(z_np[0], dtype=torch.float32)
1186
+ if self.latent_noise_std > 0:
1187
+ z = z + torch.randn_like(z) * self.latent_noise_std
1188
+ if self.renorm:
1189
+ z = F.normalize(z, dim=-1)
1190
+ enc = self.tok(tgt, truncation=True, padding="max_length", max_length=self.max_len, return_tensors=None)
1191
+ labels = torch.tensor(enc["input_ids"], dtype=torch.long)
1192
+ labels = labels.masked_fill(labels == self.pad_id, -100)
1193
+ return {"z": z, "labels": labels, "psmiles": str(r.get("psmiles", "")).strip(), "pselfies_raw": _selfies_compact(r["pselfies"])}
1194
+
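`LatentToPSELFIESDataset.__getitem__` replaces padding positions in the label sequence with `-100`, the default `ignore_index` of PyTorch's cross-entropy loss, so padded positions do not contribute to the decoder loss. A plain-Python sketch of that masking step follows. The helper name `mask_padding` and the token ids are made up for illustration.

```python
from typing import List

IGNORE_INDEX = -100  # default ignore_index of PyTorch's cross-entropy loss


def mask_padding(input_ids: List[int], pad_id: int) -> List[int]:
    """Replace every padding id with IGNORE_INDEX so the loss skips those positions."""
    return [IGNORE_INDEX if tok == pad_id else tok for tok in input_ids]


# Hypothetical ids: 0 = BOS, 2 = EOS, 1 = PAD.
labels = mask_padding([0, 17, 42, 2, 1, 1], pad_id=1)
```

One caveat of matching on the id alone is that any genuine token sharing the pad id would be masked as well, so the tokenizer's pad id should not double as a meaningful token.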
1195
+ def latent_collate(batch: List[dict]) -> dict:
1196
+ z = torch.stack([b["z"] for b in batch], dim=0)
1197
+ labels = torch.stack([b["labels"] for b in batch], dim=0)
1198
+ return {"z": z, "labels": labels, "psmiles": [b["psmiles"] for b in batch], "pselfies_raw": [b["pselfies_raw"] for b in batch]}
1199
+
1200
+ def move_latent_batch_to_device(batch: dict, device: str):
1201
+ batch["z"] = batch["z"].to(device)
1202
+ batch["labels"] = batch["labels"].to(device)
1203
+
1204
+ # =============================================================================
1205
+ # Aux PSMILES property oracle (optional)
1206
+ # =============================================================================
1207
+
1208
+ class PSMILESPropertyDataset(Dataset):
1209
+ def __init__(self, samples: List[dict], psmiles_tokenizer, max_len: int = PSMILES_MAX_LEN):
1210
+ self.samples = samples
1211
+ self.tok = psmiles_tokenizer
1212
+ self.max_len = max_len
1213
+ def __len__(self):
1214
+ return len(self.samples)
1215
+ def __getitem__(self, idx):
1216
+ s = str(self.samples[idx].get("psmiles", "")).strip()
1217
+ y = float(self.samples[idx].get("target_scaled", self.samples[idx].get("target", 0.0)))
1218
+ enc = self.tok(s, truncation=True, padding="max_length", max_length=self.max_len)
1219
+ return {"input_ids": torch.tensor(enc["input_ids"], dtype=torch.long),
1220
+ "attention_mask": torch.tensor(enc["attention_mask"], dtype=torch.bool),
1221
+ "y": torch.tensor([y], dtype=torch.float32)}
1222
+
1223
+ def psmiles_prop_collate_fn(batch: List[dict]):
1224
+ input_ids = torch.stack([b["input_ids"] for b in batch], dim=0)
1225
+ attn = torch.stack([b["attention_mask"] for b in batch], dim=0)
1226
+ y = torch.stack([b["y"] for b in batch], dim=0)
1227
+ return {"input_ids": input_ids, "attention_mask": attn, "y": y}
1228
+
1229
+ class TextPropertyOracle(nn.Module):
1230
+ def __init__(self, encoder_dir: Optional[str], vocab_size: Optional[int] = None, y_dim: int = 1):
1231
+ super().__init__()
1232
+ enc_src = None
1233
+ if encoder_dir is not None and os.path.isdir(encoder_dir):
1234
+ enc_src = encoder_dir
1235
+ elif os.path.isdir(BEST_PSMILES_DIR):
1236
+ enc_src = BEST_PSMILES_DIR
1237
+ else:
1238
+ enc_src = "microsoft/deberta-v2-xlarge"
1239
+ self.encoder = PSMILESDebertaEncoder(model_dir_or_name=enc_src, vocab_size=vocab_size)
1240
+ h = getattr(self.encoder, "out_dim", DEBERTA_HIDDEN)
1241
+ self.head = nn.Sequential(nn.Linear(h, 256), nn.ReLU(), nn.Dropout(0.1), nn.Linear(256, y_dim))
1242
+ def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
1243
+ h = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
1244
+ return self.head(h)
1245
+
1246
+ def move_prop_batch_to_device(batch: dict, device: str):
1247
+ batch["input_ids"] = batch["input_ids"].to(device)
1248
+ batch["attention_mask"] = batch["attention_mask"].to(device)
1249
+ batch["y"] = batch["y"].to(device)
1250
+
1251
+ def train_prop_oracle_one_epoch(model: TextPropertyOracle, dl: DataLoader, opt, scaler_amp, device: str):
1252
+ model.train()
1253
+ total = 0.0; n = 0
1254
+ for batch in dl:
1255
+ move_prop_batch_to_device(batch, device)
1256
+ y = batch["y"]
1257
+ opt.zero_grad(set_to_none=True)
1258
+ with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
1259
+ y_hat = model(batch["input_ids"], batch["attention_mask"])
1260
+ loss = F.smooth_l1_loss(y_hat, y, beta=1.0)
1261
+ if USE_AMP:
1262
+ scaler_amp.scale(loss).backward()
1263
+ scaler_amp.unscale_(opt)
1264
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
1265
+ scaler_amp.step(opt)
1266
+ scaler_amp.update()
1267
+ else:
1268
+ loss.backward()
1269
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
1270
+ opt.step()
1271
+ bs = y.size(0)
1272
+ total += float(loss.item()) * bs; n += bs
1273
+ return total / max(1, n)
1274
+
1275
+ @torch.no_grad()
1276
+ def eval_prop_oracle(model: TextPropertyOracle, dl: DataLoader, device: str):
1277
+ model.eval()
1278
+ total = 0.0; n = 0
1279
+ for batch in dl:
1280
+ move_prop_batch_to_device(batch, device)
1281
+ y = batch["y"]
1282
+ with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
1283
+ y_hat = model(batch["input_ids"], batch["attention_mask"])
1284
+ loss = F.smooth_l1_loss(y_hat, y, beta=1.0)
1285
+ bs = y.size(0); total += float(loss.item()) * bs; n += bs
1286
+ return total / max(1, n)
1287
+
1288
+ def train_property_oracle_per_fold(train_samples: List[dict], val_samples: List[dict], psmiles_tokenizer, device: str, max_len: int = PSMILES_MAX_LEN) -> Optional[TextPropertyOracle]:
1289
+ if psmiles_tokenizer is None:
1290
+ return None
1291
+ try:
1292
+ model = TextPropertyOracle(
1293
+ encoder_dir=BEST_PSMILES_DIR if os.path.isdir(BEST_PSMILES_DIR) else None,
1294
+ vocab_size=getattr(psmiles_tokenizer, "vocab_size", None),
1295
+ y_dim=1
1296
+ ).to(device)
1297
+ except Exception as e:
1298
+ print(f"[WARN] Could not initialize auxiliary property predictor: {e}")
1299
+ return None
1300
+ for p in model.encoder.parameters(): p.requires_grad = False
1301
+ for p in model.head.parameters(): p.requires_grad = True
1302
+ ds_tr = PSMILESPropertyDataset(train_samples, psmiles_tokenizer, max_len=max_len)
1303
+ ds_va = PSMILESPropertyDataset(val_samples, psmiles_tokenizer, max_len=max_len)
1304
+ dl_tr = DataLoader(ds_tr, batch_size=PROP_PRED_BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=psmiles_prop_collate_fn)
1305
+ dl_va = DataLoader(ds_va, batch_size=PROP_PRED_BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=psmiles_prop_collate_fn)
1306
+ opt = torch.optim.AdamW([p for p in model.parameters() if p.requires_grad], lr=PROP_PRED_LR, weight_decay=PROP_PRED_WEIGHT_DECAY)
1307
+ scaler_amp = torch.cuda.amp.GradScaler(enabled=USE_AMP)
1308
+ best_val = float("inf"); best_state = None; no_imp = 0
1309
+ for epoch in range(1, PROP_PRED_EPOCHS + 1):
1310
+ tr = train_prop_oracle_one_epoch(model, dl_tr, opt, scaler_amp, device)
1311
+ va = eval_prop_oracle(model, dl_va, device)
1312
+ if va < best_val - 1e-8:
1313
+ best_val = va; no_imp = 0
1314
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
1315
+ else:
1316
+ no_imp += 1
1317
+ if no_imp >= PROP_PRED_PATIENCE:
1318
+ break
1319
+ if best_state is not None:
1320
+ model.load_state_dict({k: v.to(device) for k, v in best_state.items()}, strict=False)
1321
+ try: model.aux_val_loss = float(best_val)
1322
+ except Exception: pass
1323
+ return model
1324
+
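The epoch loop in `train_property_oracle_per_fold` implements early stopping with patience: the best validation loss is tracked, a counter grows on every epoch without improvement, and training stops once the counter reaches the patience limit. The sketch below isolates that control flow over a precomputed list of validation losses. The function name and the loss values are hypothetical.

```python
from typing import List, Tuple


def best_epoch_with_patience(
    val_losses: List[float], patience: int, min_delta: float = 1e-8
) -> Tuple[int, float, int]:
    """Return (best_epoch, best_val, epochs_run), with 1-based epoch numbers."""
    best_val = float("inf")
    best_epoch = 0
    no_improvement = 0
    epochs_run = 0
    for epoch, val in enumerate(val_losses, start=1):
        epochs_run = epoch
        if val < best_val - min_delta:
            # Improvement: remember it and reset the patience counter.
            best_val, best_epoch, no_improvement = val, epoch, 0
        else:
            no_improvement += 1
            if no_improvement >= patience:
                break
    return best_epoch, best_val, epochs_run


result = best_epoch_with_patience([1.0, 0.8, 0.9, 0.85, 0.7], patience=2)
```

In this example training stops after epoch 4, so the lower loss at epoch 5 is never seen; that is the trade-off a small patience value makes.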
1325
+ @torch.no_grad()
1326
+ def oracle_predict_scaled(oracle: Optional[TextPropertyOracle], psmiles_tokenizer, psmiles_list: List[str], device: str, max_len: int = PSMILES_MAX_LEN) -> Optional[np.ndarray]:
1327
+ if oracle is None or psmiles_tokenizer is None:
1328
+ return None
1329
+ if not psmiles_list:
1330
+ return np.array([], dtype=np.float32)
1331
+ oracle.eval()
1332
+ ys = []; bs = 32
1333
+ for i in range(0, len(psmiles_list), bs):
1334
+ chunk = psmiles_list[i:i+bs]
1335
+ enc = psmiles_tokenizer(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
1336
+ input_ids = enc["input_ids"].to(device)
1337
+ attn = enc["attention_mask"].to(device).bool()
1338
+ with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
1339
+ y_hat = oracle(input_ids, attn)
1340
+ ys.append(y_hat.detach().cpu().numpy().reshape(-1))
1341
+ return np.concatenate(ys, axis=0) if ys else np.array([], dtype=np.float32)
1342
+
1343
+ # =============================================================================
1344
+ # PolyBART-style latent property model (per property)
1345
+ # =============================================================================
1346
+
1347
+ @dataclass
1348
+ class LatentPropertyModel:
1349
+ y_scaler: StandardScaler
1350
+ pca: Optional[PCA]
1351
+ gpr: GaussianProcessRegressor
1352
+
1353
+ def fit_latent_property_model(z_train: np.ndarray, y_train: np.ndarray, y_scaler: StandardScaler) -> LatentPropertyModel:
1354
+ y_train = np.array(y_train, dtype=np.float32).reshape(-1, 1)
1355
+ y_s = y_scaler.transform(y_train).reshape(-1).astype(np.float32)
1356
+ z_use = z_train.astype(np.float32)
1357
+ pca = None
1358
+ if USE_PCA_BEFORE_GPR:
1359
+ ncomp = int(min(PCA_DIM, z_use.shape[0] - 1, z_use.shape[1]))
1360
+ ncomp = max(2, ncomp)
1361
+ pca = PCA(n_components=ncomp, random_state=0)
1362
+ z_use = pca.fit_transform(z_use)
1363
+ kernel = C(1.0, (1e-3, 1e3)) * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2)) + WhiteKernel(noise_level=1e-3, noise_level_bounds=(1e-6, 1e-1))
1364
+ gpr = GaussianProcessRegressor(kernel=kernel, alpha=GPR_ALPHA, normalize_y=True, random_state=0, n_restarts_optimizer=2)
1365
+ gpr.fit(z_use, y_s)
1366
+ return LatentPropertyModel(y_scaler=y_scaler, pca=pca, gpr=gpr)
+
+def predict_latent_property(model: LatentPropertyModel, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    z_use = z.astype(np.float32)
+    if model.pca is not None:
+        z_use = model.pca.transform(z_use)
+    y_s = model.gpr.predict(z_use, return_std=False)
+    y_s = np.array(y_s, dtype=np.float32).reshape(-1)
+    y_u = model.y_scaler.inverse_transform(y_s.reshape(-1, 1)).reshape(-1)
+    return y_s, y_u
+
+# =============================================================================
+# Train / eval loops (decoder)
+# =============================================================================
+
+def train_one_epoch_decoder(model: CLConditionedSelfiesTEDGenerator, dl: DataLoader, optimizer, scaler_amp, device: str):
+    model.train()
+    total = 0.0; n = 0; ce_sum = 0.0
+    for batch in dl:
+        move_latent_batch_to_device(batch, device)
+        z = batch["z"]; labels = batch["labels"]
+        optimizer.zero_grad(set_to_none=True)
+        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
+            out = model.forward_train(z, labels)
+            loss = out["loss"]
+        if USE_AMP:
+            scaler_amp.scale(loss).backward()
+            scaler_amp.unscale_(optimizer)
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            scaler_amp.step(optimizer)
+            scaler_amp.update()
+        else:
+            loss.backward()
+            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+            optimizer.step()
+        bs = z.size(0)
+        total += float(loss.item()) * bs; ce_sum += float(out["ce"].item()) * bs; n += bs
+    return {"loss": total / max(1, n), "ce": ce_sum / max(1, n)}
+
+@torch.no_grad()
+def evaluate_decoder(model: CLConditionedSelfiesTEDGenerator, dl: DataLoader, device: str):
+    model.eval()
+    total = 0.0; n = 0; ce_sum = 0.0
+    for batch in dl:
+        move_latent_batch_to_device(batch, device)
+        z = batch["z"]; labels = batch["labels"]
+        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
+            out = model.forward_train(z, labels)
+            loss = out["loss"]
+        bs = z.size(0)
+        total += float(loss.item()) * bs; ce_sum += float(out["ce"].item()) * bs; n += bs
+    return {"loss": total / max(1, n), "ce": ce_sum / max(1, n)}
+
+# =============================================================================
+# Generation / filtering (per target value, per property)
+# =============================================================================
+
+def compute_diversity_morgan(smiles_list: List[str], radius: int = 2, nbits: int = 2048, p: float = 1.0) -> Optional[float]:
+    if not RDKit_AVAILABLE:
+        return None
+    try:
+        p = float(p)
+        if not np.isfinite(p) or p <= 0:
+            p = 1.0
+    except Exception:
+        p = 1.0
+    # De-duplicate while preserving input order
+    uniq = []
+    seen = set()
+    for smi in smiles_list:
+        smi = str(smi).strip()
+        if not smi or smi in seen:
+            continue
+        seen.add(smi)
+        uniq.append(smi)
+    fps = []
+    for smi in uniq:
+        try:
+            mol = Chem.MolFromSmiles(smi)
+            if mol is None:
+                continue
+            # With catchErrors=True, SanitizeMol reports failure through its return flag instead of raising.
+            if Chem.SanitizeMol(mol, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
+                continue
+            fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits)
+            fps.append(fp)
+        except Exception:
+            continue
+    if len(fps) < 2:
+        return 0.0 if len(fps) == 1 else None
+    # Pairwise Tanimoto similarities, each raised to the power p (generalized mean)
+    sims_p = []
+    for i in range(len(fps)):
+        for j in range(i + 1, len(fps)):
+            try:
+                s = float(DataStructs.TanimotoSimilarity(fps[i], fps[j]))
+                sims_p.append(s ** p)
+            except Exception:
+                continue
+    if not sims_p:
+        return None
+    mean_sim_p = float(np.mean(sims_p))
+    try:
+        mean_sim = mean_sim_p ** (1.0 / p)
+    except Exception:
+        mean_sim = float(np.mean([float(DataStructs.TanimotoSimilarity(fps[i], fps[j])) for i in range(len(fps)) for j in range(i + 1, len(fps))]))
+    return float(1.0 - mean_sim)
+
+@torch.no_grad()
+def decode_from_latents(generator: CLConditionedSelfiesTEDGenerator, z: torch.Tensor, n_samples: int = 1) -> List[str]:
+    outs = generator.generate(z=z, num_return_sequences=int(n_samples), max_len=GEN_MAX_LEN,
+                              top_p=GEN_TOP_P, temperature=GEN_TEMPERATURE, repetition_penalty=GEN_REPETITION_PENALTY)
+    return outs
+
+def polybart_style_generate_for_target(
+    target_y_scaled: float,
+    prop_model: LatentPropertyModel,
+    cl_encoder: MultiModalCLPolymerEncoder,
+    generator: CLConditionedSelfiesTEDGenerator,
+    train_seed_pool: List[dict],
+    train_targets_set: set,
+    n_seeds: int = 8,
+    n_noise: int = N_FOLD_NOISE_SAMPLING,
+    noise_std: float = LATENT_NOISE_STD_GEN,
+    prop_tol_scaled: float = PROP_TOL_SCALED,
+    oracle: Optional[TextPropertyOracle] = None,
+    psmiles_tokenizer=None,
+) -> Dict[str, Any]:
+
+    def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
+        n = np.linalg.norm(x, axis=-1, keepdims=True)
+        return x / np.clip(n, eps, None)
+
+    ys = np.array([float(d["y_scaled"]) for d in train_seed_pool], dtype=np.float32)
+    diffs = np.abs(ys - float(target_y_scaled))
+    order = np.argsort(diffs)
+    chosen = [train_seed_pool[i] for i in order[:max(1, int(n_seeds))]]
+
+    z_seed = cl_encoder.encode_multimodal(chosen, batch_size=32, device=DEVICE)
+    if z_seed.shape[0] == 0:
+        return {"generated": [], "metrics": {}}
+
+    z_list = []
+    for i in range(z_seed.shape[0]):
+        z0 = z_seed[i].astype(np.float32)
+        for _ in range(int(n_noise)):
+            z = z0 + np.random.randn(z0.shape[0]).astype(np.float32) * float(noise_std)
+            z = _l2_normalize_np(z.reshape(1, -1)).reshape(-1)
+            z_list.append(z)
+
+    z_all = np.stack(z_list, axis=0).astype(np.float32)
+    z_t = torch.tensor(z_all, dtype=torch.float32, device=DEVICE)
+
+    pselfies = decode_from_latents(generator, z_t, n_samples=1)
+
+    valid_psmiles = []
+    valid_flags, poly_flags = [], []
+
+    for s in pselfies:
+        s = _ensure_two_at_endpoints(_selfies_compact(s))
+        psm = pselfies_to_psmiles(s) if (RDKit_AVAILABLE and SELFIES_AVAILABLE) else None
+        if psm is None:
+            valid_flags.append(False); poly_flags.append(False)
+            continue
+        psm_can = canonicalize_psmiles(psm)
+        ok = chem_validity_psmiles(psm_can) if psm_can else False
+        poly_ok = polymer_validity_psmiles_strict(psm_can) if psm_can else False
+        valid_flags.append(bool(ok)); poly_flags.append(bool(poly_ok))
+        if ok and poly_ok and psm_can:
+            valid_psmiles.append(psm_can)
+
+    uniq_valid = sorted(set(valid_psmiles))
+    novelty_valid = [1.0 if s not in train_targets_set else 0.0 for s in uniq_valid] if uniq_valid else []
+
+    n_valid_poly = int(len(valid_psmiles))
+    uniqueness_valid_unique = float(len(uniq_valid)) / float(max(1, n_valid_poly)) if n_valid_poly > 0 else 0.0
+
+    if uniq_valid:
+        z_cand = cl_encoder.encode_psmiles(uniq_valid, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
+    else:
+        z_cand = np.zeros((0, cl_encoder.out_dim), dtype=np.float32)
+
+    yhat_s, yhat_u = (np.array([], dtype=np.float32), np.array([], dtype=np.float32))
+    if z_cand.shape[0] > 0:
+        yhat_s, yhat_u = predict_latent_property(prop_model, z_cand)
+
+    keep, keep_pred_scaled, keep_pred_unscaled = [], [], []
+    for i, psm in enumerate(uniq_valid):
+        if abs(float(yhat_s[i]) - float(target_y_scaled)) <= float(prop_tol_scaled):
+            keep.append(psm)
+            keep_pred_scaled.append(float(yhat_s[i]))
+            keep_pred_unscaled.append(float(yhat_u[i]))
+
+    novelty_keep = [1.0 if s not in train_targets_set else 0.0 for s in keep] if keep else []
+
+    aux_pred_scaled = None
+    if VERIFY_GENERATED_PROPERTIES and oracle is not None and psmiles_tokenizer is not None and keep:
+        aux = oracle_predict_scaled(oracle, psmiles_tokenizer, keep, DEVICE, PSMILES_MAX_LEN)
+        aux_pred_scaled = aux.tolist() if aux is not None else None
+
+    at_smiles = []
+    if RDKit_AVAILABLE and keep:
+        for psm in keep:
+            at_smi = psmiles_to_at_smiles(psm, root_at=False)
+            if at_smi is not None:
+                at_smiles.append(at_smi)
+    div = compute_diversity_morgan(at_smiles) if at_smiles else None
+
+    metrics = {
+        "n_total": int(len(pselfies)),
+        "validity": float(np.mean(valid_flags)) if valid_flags else 0.0,
+        "polymer_validity": float(np.mean(poly_flags)) if poly_flags else 0.0,
+        "n_valid_unique": int(len(uniq_valid)),
+        "novelty_valid_unique": float(np.mean(novelty_valid)) if novelty_valid else 0.0,
+        "uniqueness_valid_unique": float(uniqueness_valid_unique),
+        "n_kept_property_filtered": int(len(keep)),
+        "novelty_kept": float(np.mean(novelty_keep)) if novelty_keep else 0.0,
+        "diversity": float(div) if div is not None else 0.0,
+    }
+
+    return {
+        "generated": keep,
+        "pred_scaled_kept": keep_pred_scaled,
+        "pred_unscaled_kept": keep_pred_unscaled,
+        "aux_pred_scaled": aux_pred_scaled,
+        "metrics": metrics,
+    }
+
+# =============================================================================
+# Data assembly (per property)
+# =============================================================================
+
+def build_polymer_records(df: pd.DataFrame, prop_col: str) -> List[dict]:
+    """
+    Build records for a single property:
+      - require valid polymer psmiles (chem + polymer validity)
+      - carry pselfies, modalities, and the *single* target value
+    """
+    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
+        raise RuntimeError("RDKit + selfies are required for this PolyBART-style pipeline.")
+
+    recs = []
+    for _, row in df.iterrows():
+        psmiles_raw = str(row.get("psmiles", "")).strip()
+        if not psmiles_raw:
+            continue
+        psm_can = canonicalize_psmiles(psmiles_raw)
+        if not psm_can:
+            continue
+        if not chem_validity_psmiles(psm_can):
+            continue
+        if not polymer_validity_psmiles_strict(psm_can):
+            continue
+
+        val = row.get(prop_col, None)
+        if val is None:
+            continue
+        try:
+            y = float(val)
+            if not np.isfinite(y):
+                continue
+        except Exception:
+            continue
+
+        pself = psmiles_to_pselfies(psm_can)
+        if pself is None:
+            continue
+
+        recs.append({
+            "psmiles": psm_can,
+            "pselfies": pself,
+            "y": y,
+            "graph": row.get("graph", None),
+            "geometry": row.get("geometry", None),
+            "fingerprints": row.get("fingerprints", None),
+        })
+    return recs
+
+# =============================================================================
+# Best-fold artifact saving (per property, G1-style)
+# =============================================================================
+
+def save_best_fold_artifacts_for_property(
+    property_name: str,
+    fold_idx: int,
+    decoder_state: Dict[str, torch.Tensor],
+    prop_model: Optional[LatentPropertyModel],
+    scaler: Optional[StandardScaler],
+    best_val_loss: float,
+    generations_payload: List[dict],
+):
+    safe_prop = property_name.replace(" ", "_")
+    prop_dir = os.path.join(OUTPUT_MODELS_DIR, safe_prop)
+    os.makedirs(prop_dir, exist_ok=True)
+
+    # Decoder weights (state dict only; like G1 best_state)
+    decoder_path = os.path.join(prop_dir, f"decoder_best_fold{fold_idx+1}.pt")
+    torch.save(decoder_state, decoder_path)
+
+    # Scaler + GPR
+    try:
+        import joblib
+    except Exception:
+        joblib = None
+    if joblib is not None:
+        if scaler is not None:
+            joblib.dump(scaler, os.path.join(prop_dir, f"standardscaler_{safe_prop}.joblib"))
+        if prop_model is not None:
+            joblib.dump(prop_model, os.path.join(prop_dir, f"gpr_psmiles_{safe_prop}.joblib"))
+
+    # Meta (mirror G1 spirit)
+    meta = {
+        "property": property_name,
+        "best_fold": int(fold_idx + 1),
+        "best_val_loss": float(best_val_loss),
+        "selfies_ted_model": str(SELFIES_TED_MODEL_NAME),
+        "cl_emb_dim": int(CL_EMB_DIM),
+        "mem_len": 4,
+        "tol_scaled": float(PROP_TOL_SCALED),
+        "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
+        "optimizer": "AdamW",
+        "lr": float(LEARNING_RATE),
+        "weight_decay": float(WEIGHT_DECAY),
+        "lr_scheduler": "CosineAnnealingLR",
+        "epochs": int(NUM_EPOCHS),
+        "batch_size": int(BATCH_SIZE),
+        "patience": int(PATIENCE),
+    }
+    try:
+        with open(os.path.join(prop_dir, "meta.json"), "w", encoding="utf-8") as f:
+            json.dump(meta, f, indent=2)
+    except Exception:
+        pass
+
+    # Save generations (jsonl) for this best fold
+    out_path = os.path.join(OUTPUT_GENERATIONS_DIR, f"{safe_prop}_best_fold{fold_idx+1}_generated_psmiles.jsonl")
+    try:
+        with open(out_path, "w", encoding="utf-8") as fh:
+            for r in generations_payload:
+                fh.write(json.dumps(make_json_serializable({"property": property_name, "best_fold": fold_idx+1, **r})) + "\n")
+    except Exception as e:
+        print(f"[WARN] Could not write generations for '{property_name}': {e}")
1707
+
1708
+ # =============================================================================
1709
+ # Main per-property CV loop (single-task; mirrors G1)
1710
+ # =============================================================================
1711
+
1712
+ def run_inverse_design_single_property(
1713
+ df: pd.DataFrame,
1714
+ property_name: str,
1715
+ prop_col: str,
1716
+ cl_encoder: MultiModalCLPolymerEncoder,
1717
+ selfies_tok,
1718
+ selfies_model
1719
+ ) -> Dict[str, Any]:
1720
+
1721
+ # Build all valid polymers for this property
1722
+ polymers = build_polymer_records(df, prop_col)
1723
+ if len(polymers) < 200:
1724
+ print(f"[WARN] Too few samples for '{property_name}': {len(polymers)}. Results may be unstable.")
1725
+ if len(polymers) < 50:
1726
+ print(f"[WARN] Skipping '{property_name}' due to insufficient samples.")
1727
+ return {"property": property_name, "runs": [], "agg": None, "n_samples": len(polymers)}
1728
+
1729
+ indices = np.arange(len(polymers))
1730
+ kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
1731
+
1732
+ runs = []
1733
+ best_overall_val = float("inf")
1734
+ best_bundle = None
1735
+
1736
+ # For novelty computation
1737
+ all_targets_set = set(p["psmiles"] for p in polymers)
1738
+
1739
+ for fold_idx, (trainval_idx, test_idx) in enumerate(kf.split(indices)):
1740
+ seed = 42 + fold_idx
1741
+ set_seed(seed)
1742
+ print(f"\n=== {property_name} | fold {fold_idx+1}/{NUM_FOLDS} ===")
1743
+
1744
+ trainval_polys = [polymers[i] for i in trainval_idx]
1745
+ test_polys = [polymers[i] for i in test_idx]
1746
+
1747
+ # Train/val split within trainval (same spirit as G1: separate val)
1748
+ tr_idx, va_idx = train_test_split(np.arange(len(trainval_polys)), test_size=0.10, random_state=seed, shuffle=True)
1749
+ train_polys = [copy.deepcopy(trainval_polys[i]) for i in tr_idx]
1750
+ val_polys = [copy.deepcopy(trainval_polys[i]) for i in va_idx]
1751
+
1752
+ # Scaler fit on TRAIN targets only (G1-style)
1753
+ sc = StandardScaler()
1754
+ sc.fit(np.array([p["y"] for p in train_polys], dtype=np.float32).reshape(-1, 1))
1755
+
1756
+ # Train datasets (latent-to-SELFIES) use multimodal encoding for seeds
1757
+ def _to_rec(p):
1758
+ return {
1759
+ "psmiles": p["psmiles"],
1760
+ "pselfies": p["pselfies"],
1761
+ "graph": p.get("graph", None),
1762
+ "geometry": p.get("geometry", None),
1763
+ "fingerprints": p.get("fingerprints", None),
1764
+ }
1765
+
1766
+ ds_train = LatentToPSELFIESDataset([_to_rec(p) for p in train_polys], cl_encoder, selfies_tok,
1767
+ max_len=GEN_MAX_LEN, latent_noise_std=LATENT_NOISE_STD_TRAIN,
1768
+ cache_embeddings=True, use_multimodal=True)
1769
+ ds_val = LatentToPSELFIESDataset([_to_rec(p) for p in val_polys], cl_encoder, selfies_tok,
1770
+ max_len=GEN_MAX_LEN, latent_noise_std=0.0,
1771
+ cache_embeddings=True, use_multimodal=True)
1772
+
1773
+ dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=latent_collate)
1774
+ dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=latent_collate)
+
+        # Fit GPR on PSMILES latent for THIS property fold (train only)
+        y_tr = [float(p["y"]) for p in train_polys]
+        psm_tr = [p["psmiles"] for p in train_polys]
+        z_tr = cl_encoder.encode_psmiles(psm_tr, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
+        prop_model = fit_latent_property_model(z_tr, np.array(y_tr, dtype=np.float32), y_scaler=sc)
+        print(f"[INFO] Fit PSMILES-latent GPR for '{property_name}' fold {fold_idx+1} (n={len(y_tr)}).")
+
+        # Optional aux oracle (scaled target)
+        oracle = None
+        if VERIFY_GENERATED_PROPERTIES and len(train_polys) >= 200 and len(val_polys) >= 50:
+            tr_s, va_s = [], []
+            for p in train_polys:
+                y_s = float(sc.transform(np.array([[p["y"]]], dtype=np.float32))[0, 0])
+                tr_s.append({"psmiles": p["psmiles"], "target": p["y"], "target_scaled": y_s})
+            for p in val_polys:
+                y_s = float(sc.transform(np.array([[p["y"]]], dtype=np.float32))[0, 0])
+                va_s.append({"psmiles": p["psmiles"], "target": p["y"], "target_scaled": y_s})
+            try:
+                oracle = train_property_oracle_per_fold(tr_s, va_s, cl_encoder.psm_tok, DEVICE, PSMILES_MAX_LEN)
+                print(f"[INFO] Trained aux verification predictor for '{property_name}' fold {fold_idx+1}.")
+            except Exception as e:
+                print(f"[WARN] Oracle training failed for '{property_name}': {e}")
+                oracle = None
+
+        # Fresh decoder per fold (single-task) + optimizer (single head)
+        selfies_tok_f, selfies_model_f = load_selfies_ted_and_tokenizer(SELFIES_TED_MODEL_NAME)
+        decoder = CLConditionedSelfiesTEDGenerator(selfies_tok_f, selfies_model_f, cl_emb_dim=CL_EMB_DIM, mem_len=4).to(DEVICE)
+        optimizer, scheduler = create_optimizer_and_scheduler_decoder(decoder)
+        scaler_amp = torch.cuda.amp.GradScaler(enabled=USE_AMP)
+
+        best_val = float("inf")
+        best_state = None
+        no_improve = 0
+
+        for epoch in range(1, NUM_EPOCHS + 1):
+            tr = train_one_epoch_decoder(decoder, dl_train, optimizer, scaler_amp, DEVICE)
+            va = evaluate_decoder(decoder, dl_val, DEVICE)
+            try: scheduler.step()
+            except Exception: pass
+            try:
+                lr = float(optimizer.param_groups[0]["lr"])
+                print(f"[{property_name}] fold {fold_idx+1}/{NUM_FOLDS} epoch {epoch:03d} | lr={lr:.2e} | train={tr['loss']:.6f} | val={va['loss']:.6f}")
+            except Exception:
+                print(f"[{property_name}] fold {fold_idx+1}/{NUM_FOLDS} epoch {epoch:03d} | train={tr['loss']:.6f} | val={va['loss']:.6f}")
+
+            if va["loss"] < best_val - 1e-8:
+                best_val = va["loss"]; no_improve = 0
+                best_state = {k: v.detach().cpu().clone() for k, v in decoder.state_dict().items()}
+            else:
+                no_improve += 1
+                if no_improve >= PATIENCE:
+                    print(f"[{property_name}] Early stopping at epoch {epoch} (patience={PATIENCE}).")
+                    break
+
+        if best_state is None:
+            print(f"[WARN] No best state saved for {property_name} fold {fold_idx+1}; skipping fold.")
+            continue
+
+        decoder.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=False)
+
+        # Prepare seed pool in scaled space
+        seed_pool = []
+        for p in train_polys:
+            y_s = float(sc.transform(np.array([[p["y"]]], dtype=np.float32))[0, 0])
+            seed_pool.append({
+                "psmiles": p["psmiles"],
+                "y_scaled": y_s,
+                "graph": p.get("graph", None),
+                "geometry": p.get("geometry", None),
+                "fingerprints": p.get("fingerprints", None),
+            })
+
+        # Train target set (novelty)
+        train_targets_set = set(ps["psmiles"] for ps in train_polys)
+
+        # Compose test targets (scaled); sample up to 64 to keep compute modest
+        ys_test_scaled = []
+        for p in test_polys:
+            ys_test_scaled.append(float(sc.transform(np.array([[p["y"]]], dtype=np.float32))[0, 0]))
+        ys_test_scaled = np.array(ys_test_scaled, dtype=np.float32)
+        if len(ys_test_scaled) > 64:
+            ys_test_scaled = np.random.choice(ys_test_scaled, size=64, replace=False)
+
+        # Generate per target; collect metrics and candidate payload
+        all_valid, all_poly, all_kept, success_scaled, mae_best, diversity_vals = [], [], [], [], [], []
+        novelty_vals, uniqueness_vals = [], []
+        per_target_records = []
+
+        for y_t in ys_test_scaled:
+            out = polybart_style_generate_for_target(
+                target_y_scaled=float(y_t),
+                prop_model=prop_model,
+                cl_encoder=cl_encoder,
+                generator=decoder,
+                train_seed_pool=seed_pool,
+                train_targets_set=train_targets_set,
+                n_seeds=8,
+                n_noise=min(N_FOLD_NOISE_SAMPLING, 16),
+                noise_std=LATENT_NOISE_STD_GEN,
+                prop_tol_scaled=PROP_TOL_SCALED,
+                oracle=oracle,
+                psmiles_tokenizer=cl_encoder.psm_tok,
+            )
+            m = out["metrics"]
+            all_valid.append(float(m.get("validity", 0.0)))
+            all_poly.append(float(m.get("polymer_validity", 0.0)))
+            all_kept.append(int(m.get("n_kept_property_filtered", 0)))
+            diversity_vals.append(float(m.get("diversity", 0.0)))
+            success_scaled.append(1.0 if int(m.get("n_kept_property_filtered", 0)) > 0 else 0.0)
+            novelty_vals.append(float(m.get("novelty_kept", 0.0)))
+            uniqueness_vals.append(float(m.get("uniqueness_valid_unique", 0.0)))
+
+            if out["generated"]:
+                z_keep = cl_encoder.encode_psmiles(out["generated"], max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
+                y_pred_s, _ = predict_latent_property(prop_model, z_keep)
+                if len(y_pred_s):
+                    err = np.abs(y_pred_s - float(y_t))
+                    mae_best.append(float(np.min(err)))
+                else:
+                    mae_best.append(float("inf"))
+            else:
+                mae_best.append(float("inf"))
+
+            target_y_unscaled = float(sc.inverse_transform(np.array([[float(y_t)]], dtype=np.float32))[0, 0])
+            aux_list = out.get("aux_pred_scaled", None)
+            if aux_list is not None and not isinstance(aux_list, list):
+                aux_list = None
+
+            candidates = []
+            gen_list = out.get("generated", []) or []
+            pred_s_list = out.get("pred_scaled_kept", []) or []
+            pred_u_list = out.get("pred_unscaled_kept", []) or []
+            for i_c, psm in enumerate(gen_list):
+                cand = {
+                    "psmiles": str(psm),
+                    "pred_scaled": float(pred_s_list[i_c]) if i_c < len(pred_s_list) else None,
+                    "pred_unscaled": float(pred_u_list[i_c]) if i_c < len(pred_u_list) else None,
+                    "aux_pred_scaled": float(aux_list[i_c]) if (aux_list is not None and i_c < len(aux_list)) else None,
+                }
+                candidates.append(cand)
+
+            scaler_meta = {
+                "scaler_type": "StandardScaler",
+                "mean_": getattr(sc, "mean_", None),
+                "scale_": getattr(sc, "scale_", None),
+                "with_mean": bool(getattr(sc, "with_mean", True)),
+                "with_std": bool(getattr(sc, "with_std", True)),
+            }
+
+            per_target_records.append({
+                "target_y_scaled": float(y_t),
+                "target_y_unscaled": float(target_y_unscaled),
+                "tol_scaled": float(PROP_TOL_SCALED),
+                "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
+                "scaler_meta": scaler_meta,
+                "candidates": candidates,
+                "metrics": m
+            })
+
+        def _finite(xs):
+            return [x for x in xs if np.isfinite(x)]
+
+        metrics_fold = {
+            "validity_mean": float(np.mean(all_valid)) if all_valid else 0.0,
+            "polymer_validity_mean": float(np.mean(all_poly)) if all_poly else 0.0,
+            "avg_n_kept": float(np.mean(all_kept)) if all_kept else 0.0,
+            "success_at_k_scaled": float(np.mean(success_scaled)) if success_scaled else 0.0,
+            "mae_best_scaled": float(np.mean(_finite(mae_best))) if _finite(mae_best) else 0.0,
+            "diversity_mean": float(np.mean(diversity_vals)) if diversity_vals else 0.0,
+            "novelty_mean": float(np.mean(novelty_vals)) if novelty_vals else 0.0,
+            "uniqueness_mean": float(np.mean(uniqueness_vals)) if uniqueness_vals else 0.0,
+            "tol_scaled": float(PROP_TOL_SCALED),
+            "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
+        }
+
+        run_record = {
+            "property": property_name,
+            "fold": int(fold_idx + 1),
+            "seed": int(seed),
+            "n_train": int(len(train_polys)),
+            "n_val": int(len(val_polys)),
+            "n_test": int(len(test_polys)),
+            "best_val_loss": float(best_val),
+            "gen_metrics": metrics_fold
+        }
+        runs.append(run_record)
+
+        with open(OUTPUT_RESULTS, "a", encoding="utf-8") as fh:
+            fh.write(json.dumps(make_json_serializable(run_record)) + "\n")
+
+        # Track best fold (by lowest val loss) and save G1-style artifacts
+        if best_val < best_overall_val - 1e-8:
+            best_overall_val = best_val
+            best_bundle = {
+                "fold": int(fold_idx + 1),
+                "decoder_state": best_state,
+                "prop_model": prop_model,
+                "scaler": sc,
+                "best_val_loss": float(best_val),
+                "generations": per_target_records
+            }
+            save_best_fold_artifacts_for_property(
+                property_name=property_name,
+                fold_idx=fold_idx,
+                decoder_state=best_state,
+                prop_model=prop_model,
+                scaler=sc,
+                best_val_loss=best_val,
+                generations_payload=per_target_records
+            )
+            print(f"[INFO] Saved best-fold artifacts for '{property_name}' (fold {fold_idx+1})")
+
+    # Aggregate across folds (G1-style)
+    if not runs:
+        return {"property": property_name, "runs": [], "agg": None, "n_samples": len(polymers)}
+
+    def _collect(key):
+        xs = [float(r["gen_metrics"].get(key, 0.0)) for r in runs if r.get("gen_metrics", None) is not None]
+        return (float(np.mean(xs)) if xs else 0.0, float(np.std(xs)) if xs else 0.0)
+
+    agg = {}
+    for k in ["validity_mean", "polymer_validity_mean", "avg_n_kept", "success_at_k_scaled",
+              "mae_best_scaled", "diversity_mean", "novelty_mean", "uniqueness_mean"]:
+        m, s = _collect(k)
+        agg[k] = {"mean": m, "std": s}
+    agg["tol_scaled"] = float(PROP_TOL_SCALED)
+    agg["tol_unscaled_abs"] = float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None
+
+    # Write AGG line for this property (like G1)
+    with open(OUTPUT_RESULTS, "a", encoding="utf-8") as fh:
+        fh.write("AGG_PROPERTY: " + json.dumps(make_json_serializable({property_name: agg})) + "\n")
+
+    return {"property": property_name, "runs": runs, "agg": agg, "n_samples": len(polymers)}
+
+# =============================================================================
+# Tokenizer for PSMILES (matching your preference)
+# =============================================================================
+
+def build_psmiles_tokenizer():
+    try:
+        spm_path = "spm_5M.model"
+        if Path(spm_path).exists():
+            print(f"[Tokenizer] Using SentencePiece model: {spm_path}")
+            tok = DebertaV2Tokenizer(vocab_file=spm_path, do_lower_case=False)
+            if tok.pad_token is None:
+                tok.add_special_tokens({"pad_token": "<pad>"})
+            if tok.mask_token is None:
+                tok.add_special_tokens({"mask_token": "<mask>"})
+            tok.pad_token = tok.pad_token if tok.pad_token is not None else "<pad>"
+            tok.mask_token = tok.mask_token if tok.mask_token is not None else "<mask>"
+            return tok
+    except Exception as e:
+        print("Warning: Deberta tokenizer creation failed:", e)
+        return None
+
+# =============================================================================
+# Entrypoint (single-task per property; G1-style logging and summary)
+# =============================================================================
+
+def main():
+    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
+        raise RuntimeError("This script requires RDKit and selfies. Install them before running.")
+
+    if os.path.exists(OUTPUT_RESULTS):
+        backup = OUTPUT_RESULTS + ".bak"
+        shutil.copy(OUTPUT_RESULTS, backup)
+        print(f"[INFO] Existing {OUTPUT_RESULTS} backed up to {backup}")
+    open(OUTPUT_RESULTS, "w", encoding="utf-8").close()
+
+    if not os.path.isfile(POLYINFO_PATH):
+        raise FileNotFoundError(f"PolyInfo file not found at {POLYINFO_PATH}")
+
+    df = pd.read_csv(POLYINFO_PATH, engine="python")
+    found = find_property_columns(df.columns)
+    prop_map = {req: found.get(req) for req in REQUESTED_PROPERTIES}
+    print(f"[INFO] Property-to-column map: {prop_map}")
+    print(f"[INFO] RDKit_AVAILABLE={RDKit_AVAILABLE}, SELFIES_AVAILABLE={SELFIES_AVAILABLE}")
+    print(f"[INFO] VERIFY_GENERATED_PROPERTIES={VERIFY_GENERATED_PROPERTIES} (tol_scaled={PROP_TOL_SCALED}, tol_unscaled_abs={PROP_TOL_UNSCALED_ABS})")
+    print(f"[INFO] USE_AMP={USE_AMP}, DEVICE={DEVICE}, NUM_WORKERS={NUM_WORKERS}")
+    print(f"[INFO] SELFIES_TED_MODEL_NAME={SELFIES_TED_MODEL_NAME}")
+    print(f"[INFO] CL encoder dir: {PRETRAINED_MULTIMODAL_DIR}")
+    print(f"[INFO] Decoder FT params: batch_size={BATCH_SIZE}, epochs={NUM_EPOCHS}, patience={PATIENCE}, "
+          f"optimizer=AdamW, lr={LEARNING_RATE}, weight_decay={WEIGHT_DECAY}, scheduler=CosineAnnealingLR, eta_min={COSINE_ETA_MIN}")
+
+    # Build PSMILES tokenizer for CL text encoder
+    psmiles_tok = build_psmiles_tokenizer()
+    if psmiles_tok is None:
+        raise RuntimeError("Failed to build PSMILES tokenizer.")
+
+    # Multimodal CL encoder (frozen for seeds; same as your prior script)
+    cl_encoder = MultiModalCLPolymerEncoder(
+        psmiles_tokenizer=psmiles_tok,
+        emb_dim=CL_EMB_DIM,
+        cl_weights_dir=PRETRAINED_MULTIMODAL_DIR,
+        use_gine=True, use_schnet=True, use_fp=True, use_psmiles=True
+    ).to(DEVICE)
+    cl_encoder.freeze_cl_encoders()
+
+    # Load SELFIES-TED backbone once (weights re-used when instantiating decoders per fold)
+    selfies_tok, selfies_model = load_selfies_ted_and_tokenizer(SELFIES_TED_MODEL_NAME)
+    print(f"[INFO] Loaded SELFIES-TED backbone: {SELFIES_TED_MODEL_NAME}")
+
+    overall = {"per_property": {}}
+
+    # Single-task loop per property (G1-style)
+    for pname in REQUESTED_PROPERTIES:
+        pcol = prop_map.get(pname, None)
+        if pcol is None:
+            print(f"[WARN] Could not find a column for requested property '{pname}'. Skipping.")
+            continue
+        print(f"\n=== Running PolyBART-style inverse design (single-task) for property: {pname} (col='{pcol}') ===")
+        res = run_inverse_design_single_property(df, pname, pcol, cl_encoder, selfies_tok, selfies_model)
+        overall["per_property"][pname] = res
+
+    # Final summary (aggregated per property)
+    final_agg = {}
+    for pname, info in overall["per_property"].items():
+        final_agg[pname] = info.get("agg", None)
+
+    with open(OUTPUT_RESULTS, "a", encoding="utf-8") as fh:
+        fh.write("\nFINAL_SUMMARY\n")
+        fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
+        fh.write("\n")
+
+    print(f"\n[DONE] Results written to: {OUTPUT_RESULTS}")
+    print(f"[DONE] Best models in: {OUTPUT_MODELS_DIR}")
+    print(f"[DONE] Best-fold generations in: {OUTPUT_GENERATIONS_DIR}")
+
+if __name__ == "__main__":
+    main()