Spaces:
Running
Running
| import os | |
| import re | |
| import sys | |
| import csv | |
| import json | |
| import math | |
| import time | |
| import copy | |
| import random | |
| import shutil | |
| import warnings | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import List, Dict, Optional, Tuple, Any | |
| import numpy as np | |
| import pandas as pd | |
| from sklearn.model_selection import train_test_split, KFold | |
| from sklearn.preprocessing import StandardScaler | |
| from sklearn.decomposition import PCA | |
| from sklearn.gaussian_process import GaussianProcessRegressor | |
| from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, WhiteKernel | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import Dataset, DataLoader | |
# Increase the csv module's per-field size limit (helps when JSON blobs are
# stored in CSV cells). csv.field_size_limit raises OverflowError when the
# requested value is rejected, in which case a signed 32-bit maximum is used
# instead.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)
| # HF Transformers (SELFIES-TED) | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| from transformers.modeling_outputs import BaseModelOutput | |
| # Shared encoders/helpers from PolyFusion | |
| from PolyFusion.GINE import GineEncoder | |
| from PolyFusion.SchNet import NodeSchNetWrapper | |
| from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder | |
| from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer | |
# Optional chemistry dependencies (recommended).
# The two flags record whether each import succeeded; code in this file checks
# them before touching `Chem` or `sf`.
RDKit_AVAILABLE = False
SELFIES_AVAILABLE = False
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem, DataStructs
    RDKit_AVAILABLE = True
except Exception:
    # Any failure while importing RDKit is treated as "not available".
    RDKit_AVAILABLE = False
try:
    import selfies as sf
    SELFIES_AVAILABLE = True
except Exception:
    # Any failure while importing selfies is treated as "not available".
    SELFIES_AVAILABLE = False
| # ============================================================================= | |
| # Configuration (paths are placeholders; replace with your actual filesystem paths) | |
| # ============================================================================= | |
class Config:
    """Filesystem locations used by the inverse-design pipeline.

    All path values are placeholders and must be replaced with real paths.
    The ``OUTPUT_*`` accessors are read-only properties derived from
    ``OUTPUT_DIR`` so they can be used as plain attributes (for example
    ``os.makedirs(cfg.OUTPUT_MODELS_DIR)``) and always follow the current
    value of ``OUTPUT_DIR``.
    """

    # -------------------------------------------------------------------------
    # Input data and pretrained artifacts (placeholders)
    # -------------------------------------------------------------------------
    BASE_DIR: str = "/path/to/Polymer_Foundational_Model"
    POLYINFO_PATH: str = "/path/to/polyinfo_with_modalities.csv"
    # Multimodal CL checkpoint (for the fused encoder)
    PRETRAINED_MULTIMODAL_DIR: str = "/path/to/multimodal_output/best"
    # Unimodal encoder checkpoints
    BEST_GINE_DIR: str = "/path/to/gin_output/best"
    BEST_SCHNET_DIR: str = "/path/to/schnet_output/best"
    BEST_FP_DIR: str = "/path/to/fingerprint_mlm_output/best"
    BEST_PSMILES_DIR: str = "/path/to/polybert_output/best"
    # SentencePiece model for PSMILES tokenizer (placeholder)
    SPM_MODEL_PATH: str = "/path/to/spm.model"
    # -------------------------------------------------------------------------
    # Output folders
    # -------------------------------------------------------------------------
    OUTPUT_DIR: str = "/path/to/multimodal_inverse_design_output"

    # The accessors below must be properties: other code in this file passes
    # them directly to os.makedirs, which requires a path, not a bound method.
    @property
    def OUTPUT_RESULTS(self) -> str:
        """Path of the text file that collects the run's results."""
        return os.path.join(self.OUTPUT_DIR, "inverse_design_results.txt")

    @property
    def OUTPUT_MODELS_DIR(self) -> str:
        """Directory under ``OUTPUT_DIR`` for the best models."""
        return os.path.join(self.OUTPUT_DIR, "best_models")

    @property
    def OUTPUT_GENERATIONS_DIR(self) -> str:
        """Directory under ``OUTPUT_DIR`` for best-fold generations."""
        return os.path.join(self.OUTPUT_DIR, "best_fold_generations")
# Module-wide configuration instance.
CFG = Config()
# Properties to run (names are matched against dataframe columns by
# find_property_columns).
REQUESTED_PROPERTIES = [
    "density",
    "glass transition",
    "melting",
    "thermal decomposition",
]
# -------------------------------------------------------------------------
# Model sizes / dims (match CL encoder + pretraining)
# -------------------------------------------------------------------------
CL_EMB_DIM = 600
MAX_ATOMIC_Z = 85
# One past the largest atomic number, reserved as the mask id.
MASK_ATOM_ID = MAX_ATOMIC_Z + 1
# GINE params
NODE_EMB_DIM = 300
EDGE_EMB_DIM = 300
NUM_GNN_LAYERS = 5
# SchNet params
SCHNET_NUM_GAUSSIANS = 50
SCHNET_NUM_INTERACTIONS = 6
SCHNET_CUTOFF = 10.0
SCHNET_MAX_NEIGHBORS = 64
SCHNET_HIDDEN = 600
# Fingerprint params
FP_LENGTH = 2048
MASK_TOKEN_ID_FP = 2
VOCAB_SIZE_FP = 3
# DeBERTa params
DEBERTA_HIDDEN = 600
PSMILES_MAX_LEN = 128
# SELFIES-TED generation limits
GEN_MAX_LEN = 256
GEN_MIN_LEN = 10
# -------------------------------------------------------------------------
# Decoder fine-tuning schedule (single head)
# -------------------------------------------------------------------------
BATCH_SIZE = 32
NUM_EPOCHS = 100
PATIENCE = 10
WEIGHT_DECAY = 0.0
LEARNING_RATE = 1e-4
COSINE_ETA_MIN = 1e-6
# Noise injection (latent space)
LATENT_NOISE_STD_TRAIN = 0.10  # training-time denoising std
LATENT_NOISE_STD_GEN = 0.15  # generation-time exploration std
N_FOLD_NOISE_SAMPLING = 16  # sampling multiplicity around each seed embedding
# Sampling config (decoder)
GEN_TOP_P = 0.92
GEN_TEMPERATURE = 1.0
GEN_REPETITION_PENALTY = 1.05
# Cross-validation
NUM_FOLDS = 5
# Property guidance tolerance (scaled space)
PROP_TOL_SCALED = 0.5
PROP_TOL_UNSCALED_ABS = None
# GPR settings (PSMILES latent)
USE_PCA_BEFORE_GPR = True
PCA_DIM = 64
GPR_ALPHA = 1e-6
# Verification (optional auxiliary predictor trained per fold)
VERIFY_GENERATED_PROPERTIES = True
PROP_PRED_EPOCHS = 20
PROP_PRED_PATIENCE = 5
PROP_PRED_BATCH_SIZE = 32
PROP_PRED_LR = 3e-4
PROP_PRED_WEIGHT_DECAY = 0.0
# Runtime device selection; mixed precision is enabled only when CUDA exists.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_AMP = bool(torch.cuda.is_available())
AMP_DTYPE = torch.float16
# No DataLoader worker processes on Windows (os.name == "nt"), one elsewhere.
NUM_WORKERS = 0 if os.name == "nt" else 1
# Silences every UserWarning process-wide from this point on.
warnings.filterwarnings("ignore", category=UserWarning)
def ensure_output_dirs(cfg: "Config") -> None:
    """Make sure the three output directories named by *cfg* exist.

    Reads ``OUTPUT_DIR``, ``OUTPUT_MODELS_DIR`` and ``OUTPUT_GENERATIONS_DIR``
    from *cfg* as attributes and creates each one; directories that already
    exist are left alone.
    """
    targets = (cfg.OUTPUT_DIR, cfg.OUTPUT_MODELS_DIR, cfg.OUTPUT_GENERATIONS_DIR)
    for directory in targets:
        os.makedirs(directory, exist_ok=True)
| # ============================================================================= | |
| # Utilities | |
| # ============================================================================= | |
| def _safe_json_load(x): | |
| """Robust JSON parsing for CSV cells that may contain dict/list JSON (or slightly malformed strings).""" | |
| if x is None: | |
| return None | |
| if isinstance(x, (dict, list)): | |
| return x | |
| s = str(x).strip() | |
| if not s: | |
| return None | |
| try: | |
| return json.loads(s) | |
| except Exception: | |
| try: | |
| return json.loads(s.replace("'", '"')) | |
| except Exception: | |
| return None | |
def set_seed(seed: int):
    """Seed Python, NumPy and PyTorch RNGs with *seed* (best effort).

    CUDA generators are seeded too when CUDA is available. cuDNN benchmark
    mode is switched on afterwards, so results are not guaranteed to be
    reproducible across GPUs or drivers.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    try:
        # Keep cuDNN fast; reproducibility across GPUs/drivers is not guaranteed.
        torch.backends.cudnn.benchmark = True
    except Exception:
        pass
def make_json_serializable(obj):
    """Convert common scientific objects into JSON-serializable Python types.

    Conversions:
      - dict: keys and values are converted recursively. A key whose converted
        form is not a JSON-compatible scalar (for example a tuple key, which
        converts to an unhashable list) is replaced by ``str(key)``.
      - list / tuple / set: list of converted items.
      - numpy array: ``tolist()``; numpy scalar: ``item()``.
      - torch tensor: detached CPU ``tolist()`` (``None`` if that fails).
      - pandas Timestamp / Timedelta: ``str(obj)``.
      - float / int / str / bool / None: returned unchanged.
      - anything else: ``str(obj)``.
    """
    scalar_types = (float, int, str, bool, type(None))
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            new_key = make_json_serializable(key)
            if not isinstance(new_key, scalar_types):
                # Lists/dicts are unhashable and are not valid JSON keys.
                new_key = str(key)
            converted[new_key] = make_json_serializable(value)
        return converted
    if isinstance(obj, (list, tuple, set)):
        return [make_json_serializable(item) for item in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        try:
            return obj.item()
        except Exception:
            return float(obj)
    if isinstance(obj, torch.Tensor):
        try:
            return obj.detach().cpu().tolist()
        except Exception:
            return None
    if isinstance(obj, (pd.Timestamp, pd.Timedelta)):
        return str(obj)
    if isinstance(obj, scalar_types):
        return obj
    return str(obj)
def find_property_columns(columns, requested=None):
    """
    Heuristically map requested property names to dataframe columns.

    Args:
        columns: iterable of column labels. Labels are compared through
            ``str(label).lower()``, so non-string labels are tolerated; the
            original label object is what gets returned.
        requested: property names to look up. Defaults to the module-level
            ``REQUESTED_PROPERTIES`` list.

    Returns:
        Dict mapping each requested name to the matched column label, or to
        ``None`` when nothing matched.

    Notes:
    - Uses lowercase matching; prefers token-level matches (column name split
      on underscores/whitespace) and falls back to a substring match.
    - Special-case: columns mentioning "cohesive" are excluded when searching
      for "density".
    - A message is printed only for the substring-fallback outcome.
    """
    if requested is None:
        requested = REQUESTED_PROPERTIES
    lowered = {str(c).lower(): c for c in columns}
    found = {}
    for req in requested:
        req_low = req.lower().strip()
        skip_cohesive = req_low == "density"
        exact = None
        # First, attempt a token match
        for c_low, c_orig in lowered.items():
            tokens = set(c_low.replace("_", " ").split())
            if req_low in tokens or c_low == req_low:
                if skip_cohesive and "cohesive" in c_low:
                    continue
                exact = c_orig
                break
        if exact is not None:
            found[req] = exact
            continue
        # Fallback: substring match
        candidates = [
            c_orig
            for c_low, c_orig in lowered.items()
            if req_low in c_low and not (skip_cohesive and "cohesive" in c_low)
        ]
        chosen = candidates[0] if candidates else None
        found[req] = chosen
        if chosen is None:
            print(f"[WARN] Could not match requested property '{req}' to any column.")
        else:
            print(f"[INFO] Mapped requested property '{req}' -> column '{chosen}'")
    return found
| # ============================================================================= | |
| # Graph / geometry / fingerprint parsing for multimodal CL encoding | |
| # ============================================================================= | |
def _gine_to_number(value, caster, default):
    """Cast *value* with *caster*, returning *default* when the cast fails."""
    try:
        return caster(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _parse_graph_for_gine(graph_field):
    """
    Convert a stored 'graph' JSON blob into the tensor inputs expected by GineEncoder.

    Expected keys (all read with ``dict.get``):
      - ``node_features``: list of dicts with ``atomic_num`` (or
        ``atomic_number``), ``chirality`` and ``formal_charge``.
      - ``edge_indices``: either ``[[u, v], ...]`` or ``[[srcs...], [dsts...]]``;
        when the key is absent, ``adjacency_matrix`` is scanned instead.
      - ``edge_features``: list of dicts with ``bond_type``, ``stereo`` and
        ``is_conjugated``; used only when its length equals the edge count.

    Returns:
        Dict with ``z`` (long), ``chirality`` (float), ``formal_charge``
        (float), ``edge_index`` (long, 2 x E) and ``edge_attr`` (float, E x 3),
        or ``None`` if the graph is missing or has no usable node features.

    Values that cannot be converted to numbers become 0 / 0.0, for node and
    edge features alike. Unparseable or ragged edge lists yield an empty edge
    set rather than an exception.
    """
    gf = _safe_json_load(graph_field)
    if not isinstance(gf, dict):
        return None
    node_features = gf.get("node_features", None)
    if not node_features or not isinstance(node_features, list):
        return None

    atomic_nums, chirality_vals, formal_charges = [], [], []
    for nf in node_features:
        # NOTE(review): non-dict entries are dropped, which shifts the indices
        # of later nodes relative to the edge list - confirm this is intended.
        if not isinstance(nf, dict):
            continue
        an = nf.get("atomic_num", nf.get("atomic_number", 0))
        atomic_nums.append(_gine_to_number(an, int, 0))
        chirality_vals.append(_gine_to_number(nf.get("chirality", 0), float, 0.0))
        formal_charges.append(_gine_to_number(nf.get("formal_charge", 0), float, 0.0))
    if not atomic_nums:
        return None

    edge_indices_raw = gf.get("edge_indices", None)
    edge_features_raw = gf.get("edge_features", None)
    srcs, dsts = [], []
    # Handle two common representations:
    #   (a) edge_indices = [[u,v], [u,v], ...]
    #   (b) edge_indices = [[srcs...], [dsts...]]
    # A first element of length 2 is always read as form (a).
    if edge_indices_raw is None:
        adj = gf.get("adjacency_matrix", None)
        if isinstance(adj, list):
            for i_r, row_adj in enumerate(adj):
                if not isinstance(row_adj, list):
                    continue
                for j, val in enumerate(row_adj):
                    if val:
                        srcs.append(i_r)
                        dsts.append(j)
    else:
        try:
            if isinstance(edge_indices_raw, list) and len(edge_indices_raw) > 0:
                if isinstance(edge_indices_raw[0], list) and len(edge_indices_raw[0]) == 2:
                    srcs = [int(p[0]) for p in edge_indices_raw]
                    dsts = [int(p[1]) for p in edge_indices_raw]
                elif len(edge_indices_raw) == 2:
                    srcs = [int(x) for x in edge_indices_raw[0]]
                    dsts = [int(x) for x in edge_indices_raw[1]]
        except (TypeError, ValueError, IndexError, KeyError, OverflowError):
            srcs, dsts = [], []
    if len(srcs) != len(dsts):
        # Ragged source/destination lists cannot form a 2 x E tensor.
        srcs, dsts = [], []

    if not srcs:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 3), dtype=torch.float)
    else:
        edge_index = torch.tensor([srcs, dsts], dtype=torch.long)
        # Edge attributes (bond_type, stereo, is_conjugated) if present; else zeros
        if isinstance(edge_features_raw, list) and len(edge_features_raw) == len(srcs):
            rows = []
            for ef in edge_features_raw:
                if isinstance(ef, dict):
                    rows.append(
                        [
                            _gine_to_number(ef.get("bond_type", 0), float, 0.0),
                            _gine_to_number(ef.get("stereo", 0), float, 0.0),
                            1.0 if ef.get("is_conjugated", False) else 0.0,
                        ]
                    )
                else:
                    rows.append([0.0, 0.0, 0.0])
            edge_attr = torch.tensor(rows, dtype=torch.float)
        else:
            edge_attr = torch.zeros((len(srcs), 3), dtype=torch.float)

    return {
        "z": torch.tensor(atomic_nums, dtype=torch.long),
        "chirality": torch.tensor(chirality_vals, dtype=torch.float),
        "formal_charge": torch.tensor(formal_charges, dtype=torch.float),
        "edge_index": edge_index,
        "edge_attr": edge_attr,
    }
def _parse_geometry_for_schnet(geom_field):
    """
    Convert stored 'geometry' JSON blob into SchNet inputs:
    - best_conformer.atomic_numbers -> z (long tensor, one entry per atom)
    - best_conformer.coordinates -> pos (float tensor, one xyz row per atom)

    Returns None if missing/malformed: no ``best_conformer`` dict, non-list or
    empty fields, mismatched lengths, values that cannot be turned into
    tensors, or coordinates that are not N rows of 3 numbers.
    """
    gf = _safe_json_load(geom_field)
    conf = gf.get("best_conformer", None) if isinstance(gf, dict) else None
    if not isinstance(conf, dict):
        return None
    atomic = conf.get("atomic_numbers", [])
    coords = conf.get("coordinates", [])
    if not (isinstance(atomic, list) and isinstance(coords, list)):
        return None
    if len(atomic) == 0 or len(atomic) != len(coords):
        return None
    try:
        z = torch.tensor(atomic, dtype=torch.long)
        pos = torch.tensor(coords, dtype=torch.float)
    except (TypeError, ValueError, RuntimeError):
        # Ragged rows or non-numeric entries: treat as malformed.
        return None
    if z.dim() != 1 or pos.dim() != 2 or pos.size(1) != 3:
        return None
    return {"z": z, "pos": pos}
def _parse_fingerprints(fp_field, fp_len: int = 2048):
    """
    Parse a fingerprint field (either list or dict containing 'morgan_r3_bits') into:
    - input_ids: LongTensor [fp_len] with 0/1 bits
    - attention_mask: BoolTensor [fp_len] all True

    Bits beyond ``fp_len`` are dropped and short inputs are zero-padded.
    String bits count as 1 only for "1"/"True"/"true" (after stripping
    whitespace and quotes); numeric bits count as 1 when their integer part is
    non-zero. Missing, unparseable or non-sequence input, and non-finite
    floats, all map to 0 bits instead of raising.
    """

    def _as_bit(b) -> int:
        if isinstance(b, str):
            cleaned = b.strip().strip('"').strip("'")
            return 1 if cleaned in ("1", "True", "true") else 0
        if isinstance(b, (int, np.integer)):
            return 1 if b != 0 else 0
        if isinstance(b, (float, np.floating)):
            # int() raises on NaN/inf, so those are treated as an unset bit.
            return 1 if math.isfinite(b) and int(b) != 0 else 0
        return 0

    fp = _safe_json_load(fp_field)
    if isinstance(fp, dict):
        raw = fp.get("morgan_r3_bits", None)
    elif isinstance(fp, list):
        raw = fp
    else:
        raw = None
    if not isinstance(raw, (list, tuple, str)):
        # None or a non-sliceable value: fall back to an all-zero fingerprint.
        raw = []

    bits = [_as_bit(b) for b in raw[:fp_len]]
    if len(bits) < fp_len:
        bits.extend([0] * (fp_len - len(bits)))
    return {
        "input_ids": torch.tensor(bits, dtype=torch.long),
        "attention_mask": torch.ones(fp_len, dtype=torch.bool),
    }
| # ============================================================================= | |
| # PSELFIES utilities (polymer-safe SELFIES encoding with endpoint markers) | |
| # ============================================================================= | |
# Matches one bracketed SELFIES symbol such as "[C]" or "[=Branch1]".
_SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")


def _split_selfies_tokens(selfies_str: str) -> List[str]:
    """Return the SELFIES symbols of *selfies_str* as a list.

    Non-string or empty input yields ``[]``. When the ``selfies`` package is
    available its ``split_selfies`` is applied to the space-free string and
    only non-empty string tokens are kept; if that call fails, or the package
    is missing, the bracket regex is applied to the original string instead.
    """
    if not isinstance(selfies_str, str) or len(selfies_str) == 0:
        return []
    if SELFIES_AVAILABLE:
        try:
            pieces = sf.split_selfies(selfies_str.replace(" ", ""))
            return [tok for tok in pieces if isinstance(tok, str) and tok]
        except Exception:
            # Fall through to the regex splitter below.
            pass
    return _SELFIES_TOKEN_RE.findall(selfies_str)
| def _selfies_for_tokenizer(selfies_str: str) -> str: | |
| """Normalize SELFIES formatting so the HF tokenizer sees token boundaries.""" | |
| s = str(selfies_str).strip() | |
| if not s: | |
| return "" | |
| s = s.replace(" ", "") | |
| s = s.replace("][", "] [") | |
| return s | |
| def _selfies_compact(selfies_str: str) -> str: | |
| """Remove spaces and trim.""" | |
| return str(selfies_str).replace(" ", "").strip() | |
def _ensure_two_at_endpoints(selfies_str: str) -> str:
    """
    Normalize the number of [At] endpoint markers in a SELFIES string to two.

    The string is compacted and tokenized; if no tokens result, the compacted
    string is returned unchanged. Otherwise:
      - no [At]: one is added at the start and one at the end;
      - one [At]: a second one is appended at the end;
      - two [At]: tokens are left as they are;
      - more than two: only the first and last [At] are kept.
    The tokens are then joined back without separators.

    NOTE(review): in the single-marker case the new marker is always appended,
    so an input whose only [At] is already last ends with two adjacent markers.
    """
    compact = _selfies_compact(selfies_str)
    tokens = _split_selfies_tokens(compact)
    if not tokens:
        return compact
    marker = "[At]"
    positions = [idx for idx, tok in enumerate(tokens) if tok == marker]
    count = len(positions)
    if count == 0:
        tokens = [marker, *tokens, marker]
    elif count == 1:
        tokens = [*tokens, marker]
    elif count > 2:
        keep = (positions[0], positions[-1])
        tokens = [tok for idx, tok in enumerate(tokens) if tok != marker or idx in keep]
    return "".join(tokens)
def psmiles_to_at_smiles(psmiles: str, root_at: bool = True) -> Optional[str]:
    """
    Convert polymer PSMILES (two [*]) into RDKit SMILES where [*] is represented as element At (Z=85).
    This allows SELFIES encoding/decoding while preserving polymer endpoints.

    Args:
        psmiles: polymer SMILES whose dummy atoms mark the endpoints.
        root_at: when True and at least one dummy atom was replaced, the
            output SMILES is rooted at the first replaced atom.

    Returns:
        Canonical SMILES with At in place of every dummy atom, or ``None`` when
        RDKit is unavailable, parsing fails, sanitization fails, or any other
        error occurs.
    """
    if not RDKit_AVAILABLE:
        return None
    try:
        parsed = Chem.MolFromSmiles(psmiles)
        if parsed is None:
            return None
        rw = Chem.RWMol(parsed)
        at_indices = []
        for atom in rw.GetAtoms():
            if atom.GetAtomicNum() != 0:
                continue
            atom.SetAtomicNum(85)
            # Best-effort cleanup of the replaced atom; a failing setter must
            # not abort the conversion.
            for setter, value in (
                (atom.SetNoImplicit, True),
                (atom.SetNumExplicitHs, 0),
                (atom.SetFormalCharge, 0),
            ):
                try:
                    setter(value)
                except Exception:
                    pass
            at_indices.append(int(atom.GetIdx()))
        mol = rw.GetMol()
        # With catchErrors=True SanitizeMol does not raise; it returns the
        # first failed operation (SANITIZE_NONE on success), so the result
        # has to be inspected explicitly.
        if Chem.SanitizeMol(mol, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
            return None
        if root_at and at_indices:
            try:
                return Chem.MolToSmiles(mol, canonical=True, rootedAtAtom=at_indices[0])
            except Exception:
                # Rooted output failed; fall back to the unrooted form below.
                pass
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None
def at_smiles_to_psmiles(at_smiles: str) -> Optional[str]:
    """Inverse of psmiles_to_at_smiles: convert At (Z=85) back to polymer [*] endpoints.

    Returns:
        Canonical SMILES with every At replaced by a dummy atom and each
        ``[*]`` written as ``*``, or ``None`` when RDKit is unavailable,
        parsing fails, sanitization fails, or any other error occurs.
    """
    if not RDKit_AVAILABLE:
        return None
    try:
        parsed = Chem.MolFromSmiles(at_smiles)
        if parsed is None:
            return None
        rw = Chem.RWMol(parsed)
        for atom in rw.GetAtoms():
            if atom.GetAtomicNum() != 85:
                continue
            atom.SetAtomicNum(0)
            # Best-effort cleanup of the replaced atom; a failing setter must
            # not abort the conversion.
            for setter, value in (
                (atom.SetNoImplicit, True),
                (atom.SetNumExplicitHs, 0),
                (atom.SetFormalCharge, 0),
            ):
                try:
                    setter(value)
                except Exception:
                    pass
        mol2 = rw.GetMol()
        # With catchErrors=True SanitizeMol returns the failed operation
        # instead of raising, so the return value must be checked.
        if Chem.SanitizeMol(mol2, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
            return None
        can = Chem.MolToSmiles(mol2, canonical=True)
        return can.replace("[*]", "*")
    except Exception:
        return None
def smiles_to_pselfies(smiles: str) -> Optional[str]:
    """Canonicalize *smiles* with RDKit and encode the result as SELFIES.

    Returns ``None`` when RDKit or selfies is unavailable, when the SMILES
    cannot be parsed, when encoding raises, or when the encoder does not
    produce a non-empty string.
    """
    if not RDKit_AVAILABLE or not SELFIES_AVAILABLE:
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        encoded = sf.encoder(Chem.MolToSmiles(mol, canonical=True))
    except Exception:
        return None
    if isinstance(encoded, str) and len(encoded) > 0:
        return encoded
    return None
def psmiles_to_pselfies(psmiles: str) -> Optional[str]:
    """Run the PSMILES -> At-SMILES -> SELFIES chain and fix up the endpoint markers.

    Returns ``None`` when RDKit or selfies is unavailable, or as soon as either
    conversion step yields ``None``.
    """
    if not RDKit_AVAILABLE or not SELFIES_AVAILABLE:
        return None
    at_smiles = psmiles_to_at_smiles(psmiles, root_at=True)
    encoded = smiles_to_pselfies(at_smiles) if at_smiles is not None else None
    if encoded is None:
        return None
    return _ensure_two_at_endpoints(encoded)
def selfies_to_smiles(selfies_str: str) -> Optional[str]:
    """Decode SELFIES -> SMILES and canonicalize with RDKit.

    Returns:
        Canonical SMILES, or ``None`` when RDKit or selfies is unavailable,
        the decoder yields no string, RDKit cannot parse the result,
        sanitization fails, or any other error occurs.
    """
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        return None
    try:
        decoded = sf.decoder(_selfies_compact(selfies_str))
        if not isinstance(decoded, str) or not decoded:
            return None
        mol = Chem.MolFromSmiles(decoded)
        if mol is None:
            return None
        # With catchErrors=True SanitizeMol returns the failed operation
        # instead of raising, so the return value must be checked.
        if Chem.SanitizeMol(mol, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
            return None
        return Chem.MolToSmiles(mol, canonical=True)
    except Exception:
        return None
def pselfies_to_psmiles(selfies_str: str) -> Optional[str]:
    """Run the SELFIES -> At-SMILES -> PSMILES chain.

    Returns ``None`` when RDKit or selfies is unavailable or when the SELFIES
    string cannot be decoded; otherwise returns whatever the At-to-PSMILES
    conversion returns.
    """
    if not RDKit_AVAILABLE or not SELFIES_AVAILABLE:
        return None
    decoded = selfies_to_smiles(selfies_str)
    return None if decoded is None else at_smiles_to_psmiles(decoded)
def canonicalize_psmiles(psmiles: str) -> Optional[str]:
    """RDKit-canonicalize PSMILES (best effort).

    Returns:
        - ``None`` for empty input (after stringifying and stripping);
        - the stripped input unchanged when RDKit is unavailable;
        - otherwise the canonical SMILES with each ``[*]`` written as ``*``,
          or ``None`` when parsing or sanitization fails or any other error
          occurs.
    """
    psmiles = str(psmiles).strip()
    if not psmiles:
        return None
    if not RDKit_AVAILABLE:
        return psmiles
    try:
        mol = Chem.MolFromSmiles(psmiles)
        if mol is None:
            return None
        # With catchErrors=True SanitizeMol returns the failed operation
        # instead of raising, so the return value must be checked.
        if Chem.SanitizeMol(mol, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
            return None
        return Chem.MolToSmiles(mol, canonical=True).replace("[*]", "*")
    except Exception:
        return None
def chem_validity_psmiles(psmiles: str) -> bool:
    """Basic chemical validity check via RDKit parse + sanitize.

    Returns ``True`` only when RDKit is available, the stripped input is
    non-empty, RDKit parses it, and sanitization reports no failed operation.
    Every other outcome, including unexpected errors, returns ``False``.
    """
    if not RDKit_AVAILABLE:
        return False
    try:
        text = str(psmiles).strip()
        if not text:
            return False
        mol = Chem.MolFromSmiles(text)
        if mol is None:
            return False
        # With catchErrors=True SanitizeMol returns the failed operation
        # instead of raising; SANITIZE_NONE means every step succeeded.
        return Chem.SanitizeMol(mol, catchErrors=True) == Chem.SanitizeFlags.SANITIZE_NONE
    except Exception:
        return False
def polymer_validity_psmiles_strict(psmiles: str) -> bool:
    """
    Strict polymer validity:
    - the string parses and sanitizes with RDKit
    - exactly two [*] atoms (atomic number 0)
    - each [*] has total degree 1 (a single attachment)

    Returns ``False`` when RDKit is unavailable, the input is empty, any check
    fails, or an unexpected error occurs.
    """
    if not RDKit_AVAILABLE:
        return False
    try:
        text = str(psmiles).strip()
        if not text:
            return False
        mol = Chem.MolFromSmiles(text)
        if mol is None:
            return False
        # With catchErrors=True SanitizeMol returns the failed operation
        # instead of raising, so the return value must be checked.
        if Chem.SanitizeMol(mol, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
            return False
        stars = [atom for atom in mol.GetAtoms() if atom.GetAtomicNum() == 0]
        if len(stars) != 2:
            return False
        return all(atom.GetTotalDegree() == 1 for atom in stars)
    except Exception:
        return False
| # ============================================================================= | |
| # CL encoder (multimodal) + fusion pooling | |
| # ============================================================================= | |
def resolve_cl_checkpoint_path(cl_weights_dir: str) -> Optional[str]:
    """Locate a checkpoint file given a directory (or a direct file path).

    Resolution order:
      1. ``None`` input -> ``None``.
      2. An existing file path is returned unchanged.
      3. Anything that is not an existing directory -> ``None``.
      4. Inside the directory, the first existing file among
         ``pytorch_model.bin``, ``model.pt``, ``best.pt``, ``state_dict.pt``.
      5. Otherwise the first entry, in sorted order, matching ``*.bin`` and
         then ``*.pt``.
      6. ``None`` when nothing matched.
    """
    if cl_weights_dir is None:
        return None
    if os.path.isfile(cl_weights_dir):
        return cl_weights_dir
    if not os.path.isdir(cl_weights_dir):
        return None

    preferred_names = ("pytorch_model.bin", "model.pt", "best.pt", "state_dict.pt")
    for filename in preferred_names:
        candidate = os.path.join(cl_weights_dir, filename)
        if os.path.isfile(candidate):
            return candidate

    root = Path(cl_weights_dir)
    for pattern in ("*.bin", "*.pt"):
        matches = sorted(root.glob(pattern))
        if matches:
            return str(matches[0])
    return None
def load_state_dict_any(ckpt_path: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint file and return the state dict it contains.

    The file is loaded onto the CPU. If the loaded object is a dict holding a
    dict under ``"state_dict"`` or, failing that, ``"model_state_dict"``, that
    inner dict is returned; any other dict is returned as-is.

    Raises:
        RuntimeError: if the loaded object is not a dict.
    """
    # SECURITY NOTE: torch.load unpickles the file; only load checkpoints from
    # trusted sources.
    checkpoint = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(checkpoint, dict):
        raise RuntimeError(f"Checkpoint at {ckpt_path} did not contain a state_dict-like dict.")
    for wrapper_key in ("state_dict", "model_state_dict"):
        if wrapper_key in checkpoint and isinstance(checkpoint[wrapper_key], dict):
            return checkpoint[wrapper_key]
    return checkpoint
def safe_load_into_module(module: nn.Module, sd: Dict[str, torch.Tensor], strict: bool = False) -> Tuple[int, int]:
    """Load *sd* into *module* and report how many keys did not line up.

    Returns:
        ``(n_missing, n_unexpected)``: the lengths of the ``missing_keys`` and
        ``unexpected_keys`` lists on the object returned by
        ``load_state_dict`` (0 for either attribute that is absent).
    """
    result = module.load_state_dict(sd, strict=strict)
    n_missing = len(getattr(result, "missing_keys", []))
    n_unexpected = len(getattr(result, "unexpected_keys", []))
    return n_missing, n_unexpected
class PolyFusionModule(nn.Module):
    """
    Tiny fusion transformer:
    - pre-norm self-attention over modality tokens (residual)
    - pre-norm feed-forward block (residual)
    - learned query pooling (attention weights -> pooled representation)

    Args:
        d_model: token width; must be divisible by ``nhead``.
        nhead: number of attention heads.
        ffn_mult: hidden width of the feed-forward block as a multiple of
            ``d_model``.
        dropout: dropout probability used in attention and the feed-forward
            block.
    """

    def __init__(self, d_model: int, nhead: int = 8, ffn_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_mult * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_mult * d_model, d_model),
            nn.Dropout(dropout),
        )
        self.pool_ln = nn.LayerNorm(d_model)
        self.pool_q = nn.Parameter(torch.randn(d_model))

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Fuse the tokens in ``x`` into one vector per batch row.

        Args:
            x: token tensor of shape [B, T, d_model].
            mask: boolean tensor of shape [B, T]; True marks a valid token.

        Returns:
            Pooled tensor of shape [B, d_model].

        NOTE(review): a row whose mask is entirely False leaves attention with
        no valid key; callers should provide at least one valid token per row.
        """
        # mask: True for valid tokens; MultiheadAttention uses key_padding_mask where True means "ignore"
        key_padding = ~mask
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding)
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))
        # query pooling
        x = self.pool_ln(x)
        q = self.pool_q.unsqueeze(0).unsqueeze(-1)  # [1, d, 1]
        scores = torch.matmul(x, q).squeeze(-1)  # [B, T]
        # Use the dtype's own minimum as the fill value: a fixed -1e9 cannot be
        # represented in float16 and makes masked_fill raise under half precision.
        scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
        w = torch.softmax(scores, dim=-1).unsqueeze(-1)
        pooled = (x * w).sum(dim=1)  # [B, d]
        return pooled
| class MultiModalCLPolymerEncoder(nn.Module): | |
| """ | |
| Frozen multimodal encoder used as the conditioning interface: | |
| - encodes any subset of modalities (graph/geometry/fingerprint/psmiles) | |
| - projects each modality into a shared CL embedding space | |
| - fuses available modality tokens into a single normalized vector | |
| """ | |
def __init__(
    self,
    psmiles_tokenizer,
    emb_dim: int = CL_EMB_DIM,
    cl_weights_dir: Optional[str] = CFG.PRETRAINED_MULTIMODAL_DIR,
    use_gine: bool = True,
    use_schnet: bool = True,
    use_fp: bool = True,
    use_psmiles: bool = True,
):
    """Build the per-modality encoders, their projections and the fusion module.

    Args:
        psmiles_tokenizer: tokenizer stored on the instance; its
            ``vocab_size`` attribute (300 if absent) is passed to the PSMILES
            encoder as ``vocab_fallback``.
        emb_dim: width of the shared embedding space every modality is
            projected into.
        cl_weights_dir: directory or file to load a multimodal checkpoint
            from; when nothing is found the initialized weights are kept.
        use_gine / use_schnet / use_fp / use_psmiles: enable the respective
            encoder. The GINE, SchNet and fingerprint encoders are disabled
            (set to None, with a warning printed) if their constructor raises;
            the PSMILES encoder constructor is not guarded.
    """
    super().__init__()
    self.psm_tok = psmiles_tokenizer
    self.emb_dim = int(emb_dim)
    # Encoders default to None and stay None when disabled or failing to build.
    self.gine = None
    self.schnet = None
    self.fp = None
    self.psmiles = None
    if use_gine:
        try:
            self.gine = GineEncoder(NODE_EMB_DIM, EDGE_EMB_DIM, NUM_GNN_LAYERS, MAX_ATOMIC_Z)
        except Exception as e:
            print(f"[CL][WARN] Disabling GINE encoder: {e}")
            self.gine = None
    if use_schnet:
        try:
            self.schnet = NodeSchNetWrapper(
                SCHNET_HIDDEN, SCHNET_NUM_INTERACTIONS, SCHNET_NUM_GAUSSIANS, SCHNET_CUTOFF, SCHNET_MAX_NEIGHBORS
            )
        except Exception as e:
            print(f"[CL][WARN] Disabling SchNet encoder: {e}")
            self.schnet = None
    if use_fp:
        try:
            # NOTE(review): 256 is presumably the encoder's hidden/output width,
            # since proj_fp below takes 256 input features - confirm against
            # PooledFingerprintEncoder.
            self.fp = FingerprintEncoder(VOCAB_SIZE_FP, 256, FP_LENGTH, 4, 8, 1024, 0.1)
        except Exception as e:
            print(f"[CL][WARN] Disabling fingerprint encoder: {e}")
            self.fp = None
    if use_psmiles:
        # Use the configured checkpoint directory only if it exists on disk;
        # otherwise None is passed as model_dir_or_name.
        enc_src = CFG.BEST_PSMILES_DIR if (CFG.BEST_PSMILES_DIR and os.path.isdir(CFG.BEST_PSMILES_DIR)) else None
        self.psmiles = PSMILESDebertaEncoder(
            model_dir_or_name=enc_src,
            vocab_fallback=int(getattr(psmiles_tokenizer, "vocab_size", 300)),
        )
    # Projection layers into shared CL space (None for a disabled encoder)
    self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
    self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
    self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
    self.proj_psmiles = nn.Linear(DEBERTA_HIDDEN, self.emb_dim) if self.psmiles is not None else None
    self.dropout = nn.Dropout(0.1)
    self.out_dim = self.emb_dim
    self.fusion = PolyFusionModule(d_model=self.emb_dim, nhead=8, ffn_mult=4, dropout=0.1)
    # Optionally load a trained multimodal CL checkpoint
    self._load_multimodal_cl_checkpoint(cl_weights_dir)
def _load_multimodal_cl_checkpoint(self, cl_weights_dir: Optional[str]):
    """Load matching weights from a multimodal CL checkpoint, if one exists.

    When no checkpoint can be resolved from *cl_weights_dir*, an info message
    is printed and the module keeps its initialized weights. Otherwise only
    checkpoint entries whose key exists in this module and whose shape equals
    the module's tensor shape are loaded (non-strict), and a summary line is
    printed.
    """
    ckpt_path = resolve_cl_checkpoint_path(cl_weights_dir) if cl_weights_dir else None
    if ckpt_path is None:
        print(f"[CL][INFO] No multimodal CL checkpoint found at '{cl_weights_dir}'. Using initialized weights.")
        return

    checkpoint_sd = load_state_dict_any(ckpt_path)
    own_sd = self.state_dict()

    def _is_compatible(key, value) -> bool:
        # Keep an entry only if this module has the key and shapes agree
        # (entries without a shape attribute are accepted on key alone).
        if key not in own_sd:
            return False
        target = own_sd[key]
        if hasattr(value, "shape") and hasattr(target, "shape"):
            return tuple(value.shape) == tuple(target.shape)
        return True

    # Load only compatible keys (shape match) to be robust across versions
    filtered = {key: value for key, value in checkpoint_sd.items() if _is_compatible(key, value)}
    missing, unexpected = safe_load_into_module(self, filtered, strict=False)
    print(
        f"[CL][INFO] Loaded multimodal CL checkpoint '{ckpt_path}'. "
        f"loaded_keys={len(filtered)} missing={missing} unexpected={unexpected}"
    )
def freeze_cl_encoders(self):
    """Put the modality encoders and the fusion module in eval mode and stop their gradients.

    Each non-None encoder (gine, schnet, fp, psmiles) and then the fusion
    module is switched to eval mode and has ``requires_grad`` cleared on all
    of its parameters; a message is printed for each.

    NOTE(review): the projection layers and the dropout layer are not touched
    here - confirm whether they are meant to stay trainable.
    """
    named_encoders = (
        ("gine", self.gine),
        ("schnet", self.schnet),
        ("fp", self.fp),
        ("psmiles", self.psmiles),
    )
    for name, enc in named_encoders:
        if enc is not None:
            enc.eval()
            enc.requires_grad_(False)
            print(f"[CL][INFO] Froze {name} encoder parameters.")
    self.fusion.eval()
    self.fusion.requires_grad_(False)
    print("[CL][INFO] Froze fusion module parameters.")
def forward_multimodal(self, batch_mods: dict) -> torch.Tensor:
    """
    Encode a batch holding any subset of modalities into normalized CL embeddings.

    ``batch_mods`` is read with the keys "gine", "schnet", "fp" and "psmiles";
    each entry is either ``None`` or a dict of tensors. Every modality whose
    encoder exists and whose main tensor is non-empty contributes one projected
    token; the tokens are stacked along dim 1, passed to ``self.fusion`` with an
    all-True mask, and the result is L2-normalized. When no modality yields a
    token, a zero tensor of shape ``(B, self.emb_dim)`` is returned instead.
    """
    device = next(self.parameters()).device

    def _on_device_or_none(section: dict, key: str):
        # Optional inputs: forwarded only when they really are tensors.
        value = section.get(key, None)
        return value.to(device) if isinstance(value, torch.Tensor) else None

    # Batch size: prefer the token modalities (first dim of input_ids), then
    # the node-level modalities (largest batch index + 1), else fall back to 1.
    B = None
    for name in ("fp", "psmiles"):
        section = batch_mods.get(name, None)
        if section is not None and isinstance(section.get("input_ids", None), torch.Tensor):
            B = int(section["input_ids"].size(0))
            break
    if B is None:
        for name in ("gine", "schnet"):
            section = batch_mods.get(name, None)
            if section is not None and isinstance(section.get("batch", None), torch.Tensor):
                node_to_graph = section["batch"]
                B = int(node_to_graph.max().item() + 1) if node_to_graph.numel() > 0 else 1
                break
    if B is None:
        B = 1

    tokens: List[torch.Tensor] = []

    # GINE token
    graph_in = batch_mods.get("gine", None)
    if self.gine is not None and graph_in is not None:
        atom_z = graph_in.get("z", None)
        if isinstance(atom_z, torch.Tensor) and atom_z.numel() > 0:
            graph_emb = self.gine(
                atom_z.to(device),
                _on_device_or_none(graph_in, "chirality"),
                _on_device_or_none(graph_in, "formal_charge"),
                graph_in.get("edge_index", torch.empty((2, 0), dtype=torch.long)).to(device),
                graph_in.get("edge_attr", torch.zeros((0, 3), dtype=torch.float)).to(device),
                _on_device_or_none(graph_in, "batch"),
            )
            tokens.append(self.dropout(self.proj_gine(graph_emb)))

    # SchNet token
    geom_in = batch_mods.get("schnet", None)
    if self.schnet is not None and geom_in is not None:
        atom_z = geom_in.get("z", None)
        if isinstance(atom_z, torch.Tensor) and atom_z.numel() > 0:
            geom_emb = self.schnet(
                atom_z.to(device),
                geom_in["pos"].to(device),
                _on_device_or_none(geom_in, "batch"),
            )
            tokens.append(self.dropout(self.proj_schnet(geom_emb)))

    # Fingerprint token
    fp_in = batch_mods.get("fp", None)
    if self.fp is not None and fp_in is not None:
        fp_ids = fp_in.get("input_ids", None)
        if isinstance(fp_ids, torch.Tensor) and fp_ids.numel() > 0:
            fp_emb = self.fp(fp_ids.to(device), _on_device_or_none(fp_in, "attention_mask"))
            tokens.append(self.dropout(self.proj_fp(fp_emb)))

    # PSMILES token
    text_in = batch_mods.get("psmiles", None)
    if self.psmiles is not None and text_in is not None:
        text_ids = text_in.get("input_ids", None)
        if isinstance(text_ids, torch.Tensor) and text_ids.numel() > 0:
            text_emb = self.psmiles(text_ids.to(device), _on_device_or_none(text_in, "attention_mask"))
            tokens.append(self.dropout(self.proj_psmiles(text_emb)))

    if not tokens:
        # No modality produced a token; return a safe zero embedding.
        return F.normalize(torch.zeros((B, self.emb_dim), device=device), dim=-1)

    stacked = torch.stack(tokens, dim=1)  # [B, T, d]
    all_present = torch.ones((B, stacked.size(1)), dtype=torch.bool, device=device)
    return F.normalize(self.fusion(stacked, all_present), dim=-1)
def encode_psmiles(
    self,
    psmiles_list: List[str],
    max_len: int = PSMILES_MAX_LEN,
    batch_size: int = 64,
    device: str = DEVICE,
) -> np.ndarray:
    """
    Encode PSMILES strings with the PSMILES branch only.

    Each chunk of ``batch_size`` strings is tokenized (padded / truncated to
    ``max_len``), passed through ``self.psmiles`` and ``self.proj_psmiles`` and
    L2-normalized. The module is put in eval mode and moved to ``device`` as a
    side effect.

    Args:
        psmiles_list: strings to encode (each item is passed through ``str``).
        max_len: tokenizer max length.
        batch_size: number of strings per forward pass.
        device: torch device string.

    Returns:
        The per-chunk embeddings concatenated along axis 0, or an empty
        ``(0, self.emb_dim)`` float32 array when ``psmiles_list`` is empty.

    Raises:
        RuntimeError: if the tokenizer, encoder or projection is missing.
    """
    self.eval()
    if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
        raise RuntimeError("PSMILES tokenizer/encoder/projection not available.")
    dev = torch.device(device)
    self.to(dev)

    outs = []
    # Inference only: the result leaves as NumPy, so recording an autograd
    # graph for every chunk would just hold activations in memory.
    with torch.no_grad():
        for start in range(0, len(psmiles_list), batch_size):
            chunk = [str(x) for x in psmiles_list[start : start + batch_size]]
            enc = self.psm_tok(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
            input_ids = enc["input_ids"].to(dev)
            attn = enc["attention_mask"].to(dev).bool()
            emb_p = self.psmiles(input_ids, attn)
            z = F.normalize(self.proj_psmiles(emb_p), dim=-1)
            outs.append(z.detach().cpu().numpy())

    if not outs:
        return np.zeros((0, self.emb_dim), dtype=np.float32)
    return np.concatenate(outs, axis=0)
def encode_multimodal(self, records: List[dict], batch_size: int = 32, device: str = DEVICE) -> np.ndarray:
    """
    Encode a list of records that may contain any subset of modalities.

    For each chunk of ``batch_size`` records the PSMILES strings are tokenized
    (when a tokenizer is attached), fingerprints are parsed and stacked, and the
    graph / geometry entries are concatenated into node-level batches before
    ``forward_multimodal`` is called. The module is put in eval mode and moved
    to ``device`` as a side effect.

    Args:
        records: dicts read with the keys "psmiles", "fingerprints", "graph"
            and "geometry".
        batch_size: number of records per forward pass.
        device: torch device string.

    Returns:
        The ``forward_multimodal`` outputs concatenated along axis 0, or an
        empty ``(0, self.emb_dim)`` float32 array when ``records`` is empty.
    """
    self.eval()
    dev = torch.device(device)
    self.to(dev)

    outs = []
    for start in range(0, len(records), batch_size):
        chunk = records[start : start + batch_size]

        # PSMILES tokenization
        psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
        p_enc = None
        if self.psm_tok is not None:
            p_enc = self.psm_tok(
                psmiles_texts,
                truncation=True,
                padding="max_length",
                max_length=PSMILES_MAX_LEN,
                return_tensors="pt",
            )

        # Fingerprints
        fp_id_rows, fp_mask_rows = [], []
        for r in chunk:
            parsed_fp = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
            fp_id_rows.append(parsed_fp["input_ids"])
            fp_mask_rows.append(parsed_fp["attention_mask"])
        fp_ids = torch.stack(fp_id_rows, dim=0)
        fp_attn = torch.stack(fp_mask_rows, dim=0)

        # GINE batch assembly (concat nodes; keep per-graph batch indices).
        # NOTE(review): records without a graph are skipped but keep their
        # position index, so the batch indices can have gaps; whether the
        # encoder still returns one row per record then depends on GineEncoder
        # -- confirm.
        gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
        node_offset = 0
        for bi, r in enumerate(chunk):
            g = _parse_graph_for_gine(r.get("graph", None))
            if g is None or g["z"].numel() == 0:
                continue
            n = g["z"].size(0)
            gine_all["z"].append(g["z"])
            gine_all["chirality"].append(g["chirality"])
            gine_all["formal_charge"].append(g["formal_charge"])
            gine_all["batch"].append(torch.full((n,), bi, dtype=torch.long))
            ei = g["edge_index"]
            ea = g["edge_attr"]
            if ei is not None and ei.numel() > 0:
                # Shift local node ids into the concatenated node numbering.
                gine_all["edge_index"].append(ei + node_offset)
                gine_all["edge_attr"].append(ea)
            node_offset += n

        gine_batch = None
        if gine_all["z"]:
            if gine_all["edge_index"]:
                ei_b = torch.cat(gine_all["edge_index"], dim=1)
                ea_b = torch.cat(gine_all["edge_attr"], dim=0)
            else:
                ei_b = torch.empty((2, 0), dtype=torch.long)
                ea_b = torch.zeros((0, 3), dtype=torch.float)
            gine_batch = {
                "z": torch.cat(gine_all["z"], dim=0),
                "chirality": torch.cat(gine_all["chirality"], dim=0),
                "formal_charge": torch.cat(gine_all["formal_charge"], dim=0),
                "edge_index": ei_b,
                "edge_attr": ea_b,
                "batch": torch.cat(gine_all["batch"], dim=0),
            }

        # SchNet batch assembly (concat atoms; keep per-structure batch indices)
        sch_all_z, sch_all_pos, sch_all_batch = [], [], []
        for bi, r in enumerate(chunk):
            s = _parse_geometry_for_schnet(r.get("geometry", None))
            if s is None or s["z"].numel() == 0:
                continue
            n = s["z"].size(0)
            sch_all_z.append(s["z"])
            sch_all_pos.append(s["pos"])
            sch_all_batch.append(torch.full((n,), bi, dtype=torch.long))

        schnet_batch = None
        if sch_all_z:
            schnet_batch = {
                "z": torch.cat(sch_all_z, dim=0),
                "pos": torch.cat(sch_all_pos, dim=0),
                "batch": torch.cat(sch_all_batch, dim=0),
            }

        batch_mods = {
            "gine": gine_batch,
            "schnet": schnet_batch,
            "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
            "psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None,
        }

        # Inference only: the embeddings leave as NumPy, so do not record an
        # autograd graph for the forward pass.
        with torch.no_grad():
            z = self.forward_multimodal(batch_mods)
        outs.append(z.detach().cpu().numpy())

    if not outs:
        return np.zeros((0, self.emb_dim), dtype=np.float32)
    return np.concatenate(outs, axis=0)
| # ============================================================================= | |
| # SELFIES-TED decoder conditioned on CL embeddings | |
| # ============================================================================= | |
# Hugging Face id of the SELFIES-TED checkpoint; the SELFIES_TED_MODEL_NAME
# environment variable overrides the default.
SELFIES_TED_MODEL_NAME = os.getenv("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")
# Optional Hugging Face access token; None when HF_TOKEN is not set.
HF_TOKEN = os.getenv("HF_TOKEN")
| def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0): | |
| """Retry wrapper for HF downloads (useful when the hub is flaky or rate-limited).""" | |
| last_err = None | |
| for t in range(max_tries): | |
| try: | |
| return load_fn() | |
| except Exception as e: | |
| last_err = e | |
| sleep_s = base_sleep * (1.6 ** t) + random.random() | |
| print(f"[HF][WARN] Load attempt {t+1}/{max_tries} failed: {e} | retrying in {sleep_s:.1f}s") | |
| time.sleep(sleep_s) | |
| raise RuntimeError(f"Failed to load model from HF after {max_tries} attempts. Last error: {last_err}") | |
def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
    """
    Fetch the SELFIES-TED tokenizer and seq2seq model from Hugging Face.

    Both downloads go through ``_hf_load_with_retries`` (up to 5 attempts each)
    and use ``HF_TOKEN`` for authentication.

    Returns:
        ``(tokenizer, model)``.
    """
    tokenizer = _hf_load_with_retries(
        lambda: AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True),
        max_tries=5,
    )
    seq2seq = _hf_load_with_retries(
        lambda: AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN),
        max_tries=5,
    )
    return tokenizer, seq2seq
class CLConditionedSelfiesTEDGenerator(nn.Module):
    """
    Condition SELFIES-TED on CL embeddings by:
    - mapping CL vector -> d_model
    - expanding into a short "memory" sequence (mem_len)
    - passing this as encoder_outputs to the seq2seq model
    """

    def __init__(self, tok, seq2seq_model, cl_emb_dim: int = CL_EMB_DIM, mem_len: int = 4):
        """
        Args:
            tok: tokenizer used to decode generated ids (and to look up the
                pad / eos token ids).
            seq2seq_model: the wrapped seq2seq model; ``config.d_model`` sets
                the width of the conditioning memory (1024 if absent).
            cl_emb_dim: dimensionality of the incoming CL latent.
            mem_len: number of memory positions the latent is expanded into.
        """
        super().__init__()
        self.tok = tok
        self.model = seq2seq_model
        self.mem_len = int(mem_len)
        d_model = int(getattr(self.model.config, "d_model", 1024))
        # Two-layer MLP lifting the CL latent to the decoder width.
        self.cl_to_d = nn.Sequential(
            nn.Linear(cl_emb_dim, d_model),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(d_model, d_model),
        )
        # Learned offset per memory position so the copies are distinguishable.
        self.mem_pos = nn.Embedding(self.mem_len, d_model)

    def build_encoder_outputs(self, z: torch.Tensor) -> Tuple[BaseModelOutput, torch.Tensor]:
        """
        Turn CL latents into a stand-in for the encoder output.

        Args:
            z: latent batch; the first dimension is the batch size.

        Returns:
            ``(encoder_outputs, attention_mask)`` where ``last_hidden_state`` is
            ``[B, mem_len, d_model]`` and the mask is all ones, ``[B, mem_len]``.
        """
        device = z.device
        B = z.size(0)
        d = self.cl_to_d(z)  # [B, d_model]
        d = d.unsqueeze(1).expand(B, self.mem_len, d.size(-1)).contiguous()
        pos = torch.arange(self.mem_len, device=device).unsqueeze(0).expand(B, -1)
        d = d + self.mem_pos(pos)
        attn = torch.ones((B, self.mem_len), dtype=torch.long, device=device)
        return BaseModelOutput(last_hidden_state=d), attn

    def forward_train(self, z: torch.Tensor, labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Teacher-forced training step (labels are decoder targets).

        Returns:
            ``{"loss": loss, "ce": loss.detach()}`` with the loss reported by
            the wrapped model.
        """
        enc_out, attn = self.build_encoder_outputs(z)
        out = self.model(encoder_outputs=enc_out, attention_mask=attn, labels=labels)
        loss = out.loss
        return {"loss": loss, "ce": loss.detach()}

    @torch.no_grad()
    def generate(
        self,
        z: torch.Tensor,
        num_return_sequences: int = 1,
        max_len: int = GEN_MAX_LEN,
        top_p: float = GEN_TOP_P,
        temperature: float = GEN_TEMPERATURE,
        repetition_penalty: float = GEN_REPETITION_PENALTY,
    ) -> List[str]:
        """
        Stochastic (nucleus-sampling) decoding from a batch of CL latents.

        Runs without gradient tracking, including the conditioning MLP, since
        the output is decoded text. Puts the module in eval mode.

        Args:
            z: latent batch; moved to the module's device.
            num_return_sequences: samples drawn per latent.
            max_len: maximum generated length.
            top_p, temperature, repetition_penalty: sampling controls passed to
                the wrapped model's ``generate``.

        Returns:
            Decoded strings after ``_selfies_compact`` and
            ``_ensure_two_at_endpoints`` post-processing.
        """
        self.eval()
        z = z.to(next(self.parameters()).device)
        enc_out, attn = self.build_encoder_outputs(z)
        pad_id = self.tok.pad_token_id
        eos_id = self.tok.eos_token_id
        gen = self.model.generate(
            encoder_outputs=enc_out,
            attention_mask=attn,
            do_sample=True,
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=int(num_return_sequences),
            max_length=int(max_len),
            min_length=int(GEN_MIN_LEN),
            pad_token_id=int(pad_id) if pad_id is not None else None,
            eos_token_id=int(eos_id) if eos_id is not None else None,
        )
        outs = self.tok.batch_decode(gen, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return [_ensure_two_at_endpoints(_selfies_compact(o)) for o in outs]
def create_optimizer_and_scheduler_decoder(model: CLConditionedSelfiesTEDGenerator):
    """
    Build the AdamW optimizer and cosine-annealing LR schedule for decoder
    fine-tuning.

    Every parameter of ``model`` is first marked trainable. Hyper-parameters
    come from the module constants LEARNING_RATE, WEIGHT_DECAY, NUM_EPOCHS and
    COSINE_ETA_MIN.

    Returns:
        ``(optimizer, scheduler)``.
    """
    trainable = list(model.parameters())
    for param in trainable:
        param.requires_grad = True

    optimizer = torch.optim.AdamW(trainable, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
        optimizer,
        T_max=NUM_EPOCHS,
        eta_min=COSINE_ETA_MIN,
    )
    return optimizer, scheduler
| # ============================================================================= | |
| # Datasets for latent-to-SELFIES training | |
| # ============================================================================= | |
class LatentToPSELFIESDataset(Dataset):
    """
    Pairs a CL latent with its tokenized PSELFIES target.

    Each item provides:
    - z: CL embedding of the record (Gaussian noise optionally added for
      denoising-style training)
    - labels: token ids of the PSELFIES target, with pad positions set to -100
    - psmiles / pselfies_raw: the source strings, for bookkeeping
    """

    def __init__(
        self,
        records: List[dict],
        cl_encoder: MultiModalCLPolymerEncoder,
        selfies_tok,
        max_len: int = GEN_MAX_LEN,
        latent_noise_std: float = 0.0,
        cache_embeddings: bool = True,
        renormalize_after_noise: bool = True,
        use_multimodal: bool = True,
    ):
        self.records = records
        self.cl_encoder = cl_encoder
        self.tok = selfies_tok
        self.max_len = int(max_len)
        self.latent_noise_std = float(latent_noise_std)
        self.renorm = bool(renormalize_after_noise)
        self.use_multimodal = bool(use_multimodal)

        # Pad id used for label masking; falls back to 1 when the tokenizer
        # does not define one.
        pad_token_id = getattr(self.tok, "pad_token_id", None)
        self.pad_id = 1 if pad_token_id is None else int(pad_token_id)

        # Encoding everything once up front avoids running the CL encoder for
        # every __getitem__ call during decoder training.
        self._cache = None
        if cache_embeddings:
            if self.use_multimodal:
                latents = self.cl_encoder.encode_multimodal(self.records, batch_size=32, device=DEVICE)
            else:
                texts = [str(rec.get("psmiles", "")) for rec in self.records]
                latents = self.cl_encoder.encode_psmiles(texts, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
            self._cache = latents.astype(np.float32)

    def __len__(self):
        return len(self.records)

    def __getitem__(self, idx):
        record = self.records[idx]
        target_text = _selfies_for_tokenizer(str(record["pselfies"]).strip())

        # Latent: take the cached row, or encode this single record now.
        if self._cache is None:
            if self.use_multimodal:
                latent_np = self.cl_encoder.encode_multimodal([record], batch_size=1, device=DEVICE)
            else:
                text = str(record.get("psmiles", "")).strip()
                latent_np = self.cl_encoder.encode_psmiles([text], max_len=PSMILES_MAX_LEN, batch_size=1, device=DEVICE)
            z = torch.tensor(latent_np[0], dtype=torch.float32)
        else:
            z = torch.tensor(self._cache[idx], dtype=torch.float32)

        # Optional denoising perturbation, re-projected to unit norm on request.
        if self.latent_noise_std > 0:
            z = z + torch.randn_like(z) * self.latent_noise_std
            if self.renorm:
                z = F.normalize(z, dim=-1)

        # Tokenize the target; pad positions become -100 for the CE loss.
        encoded = self.tok(
            target_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors=None,
        )
        token_ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
        labels = token_ids.masked_fill(token_ids == self.pad_id, -100)

        return {
            "z": z,
            "labels": labels,
            "psmiles": str(record.get("psmiles", "")).strip(),
            "pselfies_raw": _selfies_compact(record["pselfies"]),
        }
def latent_collate(batch: List[dict]) -> dict:
    """
    Merge dataset items into one batch.

    "z" and "labels" are stacked along a new leading dimension; "psmiles" and
    "pselfies_raw" are gathered into plain lists in item order.
    """

    def _column(key: str) -> list:
        return [item[key] for item in batch]

    stacked_z = torch.stack(_column("z"), dim=0)
    stacked_labels = torch.stack(_column("labels"), dim=0)
    return {
        "z": stacked_z,
        "labels": stacked_labels,
        "psmiles": _column("psmiles"),
        "pselfies_raw": _column("pselfies_raw"),
    }
def move_latent_batch_to_device(batch: dict, device: str):
    """Move the "z" and "labels" tensors of ``batch`` to ``device``, in place."""
    for key in ("z", "labels"):
        batch[key] = batch[key].to(device)
| # ============================================================================= | |
| # Aux PSMILES property oracle (optional) | |
| # ============================================================================= | |
class PSMILESPropertyDataset(Dataset):
    """
    Text regression dataset mapping a PSMILES string to one scalar target.

    The target is read from "target_scaled", falling back to "target" and then
    to 0.0.
    """

    def __init__(self, samples: List[dict], psmiles_tokenizer, max_len: int = PSMILES_MAX_LEN):
        self.samples = samples
        self.tok = psmiles_tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = str(sample.get("psmiles", "")).strip()
        target = float(sample.get("target_scaled", sample.get("target", 0.0)))
        encoded = self.tok(text, truncation=True, padding="max_length", max_length=self.max_len)

        ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
        mask = torch.tensor(encoded["attention_mask"], dtype=torch.bool)
        # Wrapped in a list so that stacking items gives a [B, 1] target tensor.
        y = torch.tensor([target], dtype=torch.float32)
        return {"input_ids": ids, "attention_mask": mask, "y": y}
def psmiles_prop_collate_fn(batch: List[dict]):
    """
    Stack the per-item "input_ids", "attention_mask" and "y" tensors along a
    new leading batch dimension and return them under the same keys.
    """
    collated = {}
    for key in ("input_ids", "attention_mask", "y"):
        collated[key] = torch.stack([item[key] for item in batch], dim=0)
    return collated
class TextPropertyOracle(nn.Module):
    """
    Small text regressor used for verification.

    A PSMILES encoder (``PSMILESDebertaEncoder``) feeds an MLP head that maps
    the encoder output to ``y_dim`` values. This class does not freeze anything
    itself; callers decide which parameters to train.
    """

    def __init__(self, encoder_dir: Optional[str], vocab_size: Optional[int] = None, y_dim: int = 1):
        super().__init__()

        # Encoder weights: explicit directory first, then the configured
        # best-PSMILES directory, otherwise the public DeBERTa checkpoint name.
        given_dir_usable = encoder_dir is not None and os.path.isdir(encoder_dir)
        if given_dir_usable:
            source = encoder_dir
        elif os.path.isdir(CFG.BEST_PSMILES_DIR):
            source = CFG.BEST_PSMILES_DIR
        else:
            source = "microsoft/deberta-v2-xlarge"

        fallback_vocab = 300 if vocab_size is None else int(vocab_size)
        self.encoder = PSMILESDebertaEncoder(
            model_dir_or_name=source,
            vocab_fallback=fallback_vocab,
        )

        hidden = getattr(self.encoder, "out_dim", DEBERTA_HIDDEN)
        layers = [
            nn.Linear(hidden, 256),
            nn.ReLU(),
            nn.Dropout(0.1),
            nn.Linear(256, y_dim),
        ]
        self.head = nn.Sequential(*layers)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode the token batch and apply the regression head."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return self.head(encoded)
def move_prop_batch_to_device(batch: dict, device: str):
    """Move "input_ids", "attention_mask" and "y" of ``batch`` to ``device``, in place."""
    for key in ("input_ids", "attention_mask", "y"):
        batch[key] = batch[key].to(device)
def train_prop_oracle_one_epoch(model: TextPropertyOracle, dl: DataLoader, opt, scaler_amp, device: str):
    """
    Run one training epoch of the property oracle with a smooth-L1 loss.

    Gradients are clipped to norm 1.0; when ``USE_AMP`` is set the step goes
    through ``scaler_amp`` (scale, unscale, clip, step, update).

    Returns:
        The sample-weighted mean loss over the epoch (0.0 for an empty loader).
    """
    model.train()
    loss_sum = 0.0
    seen = 0
    for batch in dl:
        move_prop_batch_to_device(batch, device)
        target = batch["y"]
        opt.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
            pred = model(batch["input_ids"], batch["attention_mask"])
            loss = F.smooth_l1_loss(pred, target, beta=1.0)

        if not USE_AMP:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        else:
            scaler_amp.scale(loss).backward()
            # Unscale first so the clip threshold applies to true gradients.
            scaler_amp.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler_amp.step(opt)
            scaler_amp.update()

        count = target.size(0)
        loss_sum += float(loss.item()) * count
        seen += count
    return loss_sum / max(1, seen)
def eval_prop_oracle(model: TextPropertyOracle, dl: DataLoader, device: str):
    """
    Compute the sample-weighted mean smooth-L1 loss of the oracle over ``dl``.

    The model is put in eval mode and the pass runs without gradient tracking,
    since nothing is back-propagated here.

    Returns:
        Mean loss as a float (0.0 for an empty loader).
    """
    model.eval()
    total = 0.0
    n = 0
    with torch.no_grad():
        for batch in dl:
            move_prop_batch_to_device(batch, device)
            y = batch["y"]
            with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
                y_hat = model(batch["input_ids"], batch["attention_mask"])
                loss = F.smooth_l1_loss(y_hat, y, beta=1.0)
            bs = y.size(0)
            total += float(loss.item()) * bs
            n += bs
    return total / max(1, n)
def train_property_oracle_per_fold(
    train_samples: List[dict],
    val_samples: List[dict],
    psmiles_tokenizer,
    device: str,
    max_len: int = PSMILES_MAX_LEN,
) -> Optional[TextPropertyOracle]:
    """
    Fit the auxiliary property oracle for one fold (verification only).

    Only the MLP head is trained; the encoder's parameters are frozen. Training
    stops early after ``PROP_PRED_PATIENCE`` epochs without validation
    improvement, and the best-scoring weights are restored at the end.

    Returns:
        The trained oracle with ``aux_val_loss`` set to its best validation
        loss, or ``None`` when no tokenizer is given or the model cannot be
        constructed.
    """
    if psmiles_tokenizer is None:
        return None

    try:
        model = TextPropertyOracle(
            encoder_dir=CFG.BEST_PSMILES_DIR if os.path.isdir(CFG.BEST_PSMILES_DIR) else None,
            vocab_size=getattr(psmiles_tokenizer, "vocab_size", None),
            y_dim=1,
        ).to(device)
    except Exception as e:
        print(f"[ORACLE][WARN] Could not initialize auxiliary property predictor: {e}")
        return None

    # Encoder frozen, head trainable (fast + stable).
    for param in model.encoder.parameters():
        param.requires_grad = False
    for param in model.head.parameters():
        param.requires_grad = True

    loader_kwargs = dict(
        batch_size=PROP_PRED_BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=psmiles_prop_collate_fn,
    )
    train_loader = DataLoader(
        PSMILESPropertyDataset(train_samples, psmiles_tokenizer, max_len=max_len),
        shuffle=True,
        **loader_kwargs,
    )
    val_loader = DataLoader(
        PSMILESPropertyDataset(val_samples, psmiles_tokenizer, max_len=max_len),
        shuffle=False,
        **loader_kwargs,
    )

    head_params = [p for p in model.parameters() if p.requires_grad]
    opt = torch.optim.AdamW(head_params, lr=PROP_PRED_LR, weight_decay=PROP_PRED_WEIGHT_DECAY)
    scaler_amp = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    best_val = float("inf")
    best_state = None
    stale_epochs = 0
    for _epoch in range(1, PROP_PRED_EPOCHS + 1):
        train_prop_oracle_one_epoch(model, train_loader, opt, scaler_amp, device)
        val_loss = eval_prop_oracle(model, val_loader, device)

        if val_loss < best_val - 1e-8:
            best_val = val_loss
            stale_epochs = 0
            # CPU snapshot so the best weights do not occupy device memory.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            continue

        stale_epochs += 1
        if stale_epochs >= PROP_PRED_PATIENCE:
            break

    if best_state is not None:
        model.load_state_dict({k: v.to(device) for k, v in best_state.items()}, strict=False)

    try:
        model.aux_val_loss = float(best_val)
    except Exception:
        pass
    return model
def oracle_predict_scaled(
    oracle: Optional[TextPropertyOracle],
    psmiles_tokenizer,
    psmiles_list: List[str],
    device: str,
    max_len: int = PSMILES_MAX_LEN,
) -> Optional[np.ndarray]:
    """
    Batch-predict scaled properties with the auxiliary oracle.

    Args:
        oracle: trained oracle, or ``None``.
        psmiles_tokenizer: tokenizer matching the oracle's encoder, or ``None``.
        psmiles_list: PSMILES strings to score.
        device: device the tokenized inputs are moved to.
        max_len: tokenizer max length.

    Returns:
        ``None`` when the oracle or tokenizer is missing; otherwise a flat
        float32 array of predictions (empty for an empty input list).
    """
    if oracle is None or psmiles_tokenizer is None:
        return None
    if not psmiles_list:
        return np.array([], dtype=np.float32)

    oracle.eval()
    ys = []
    bs = 32
    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        for i in range(0, len(psmiles_list), bs):
            chunk = psmiles_list[i : i + bs]
            enc = psmiles_tokenizer(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
            input_ids = enc["input_ids"].to(device)
            attn = enc["attention_mask"].to(device).bool()
            with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
                y_hat = oracle(input_ids, attn)
            # Under autocast the output may be half / bfloat16; NumPy has no
            # bfloat16, so convert to float32 before leaving torch.
            ys.append(y_hat.detach().float().cpu().numpy().reshape(-1))
    return np.concatenate(ys, axis=0) if ys else np.array([], dtype=np.float32)
| # ============================================================================= | |
| # Latent property model (per property) | |
| # ============================================================================= | |
@dataclass
class LatentPropertyModel:
    """
    Bundle of the fitted objects needed to predict one property from latents.

    Declared as a dataclass so it can be built with keyword arguments
    (``LatentPropertyModel(y_scaler=..., pca=..., gpr=...)``).
    """

    # Scaler used to map property values to / from the scaled space.
    y_scaler: StandardScaler
    # Optional PCA applied to latents before the regressor; None when unused.
    pca: Optional[PCA]
    # Gaussian-process regressor: (reduced) latent -> scaled property.
    gpr: GaussianProcessRegressor
def fit_latent_property_model(z_train: np.ndarray, y_train: np.ndarray, y_scaler: StandardScaler) -> LatentPropertyModel:
    """
    Fit a Gaussian-process regressor from latents to the scaled property.

    Targets are scaled with the given ``y_scaler`` (``transform`` is called,
    not ``fit``). When ``USE_PCA_BEFORE_GPR`` is set, latents are first reduced
    with PCA to at most ``PCA_DIM`` components (never fewer than 2). The kernel
    is Constant * RBF + White noise.

    Returns:
        A ``LatentPropertyModel`` holding the scaler, the PCA (or ``None``) and
        the fitted regressor.
    """
    targets = np.array(y_train, dtype=np.float32).reshape(-1, 1)
    scaled_targets = y_scaler.transform(targets).reshape(-1).astype(np.float32)
    features = z_train.astype(np.float32)

    reducer = None
    if USE_PCA_BEFORE_GPR:
        n_samples, n_dims = features.shape[0], features.shape[1]
        n_components = max(2, int(min(PCA_DIM, n_samples - 1, n_dims)))
        reducer = PCA(n_components=n_components, random_state=0)
        features = reducer.fit_transform(features)

    signal_kernel = C(1.0, (1e-3, 1e3)) * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))
    noise_kernel = WhiteKernel(noise_level=1e-3, noise_level_bounds=(1e-6, 1e-1))
    regressor = GaussianProcessRegressor(
        kernel=signal_kernel + noise_kernel,
        alpha=GPR_ALPHA,
        normalize_y=True,
        random_state=0,
        n_restarts_optimizer=2,
    )
    regressor.fit(features, scaled_targets)
    return LatentPropertyModel(y_scaler=y_scaler, pca=reducer, gpr=regressor)
def predict_latent_property(model: LatentPropertyModel, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """
    Predict the property for candidate latents.

    The latents are cast to float32, passed through ``model.pca`` when one is
    set, and scored by ``model.gpr``; the scaled predictions are then mapped
    back with ``model.y_scaler.inverse_transform``.

    Returns:
        ``(scaled, unscaled)`` as flat arrays.
    """
    features = z.astype(np.float32)
    reducer = model.pca
    if reducer is not None:
        features = reducer.transform(features)

    raw_prediction = model.gpr.predict(features, return_std=False)
    scaled = np.array(raw_prediction, dtype=np.float32).reshape(-1)
    unscaled = model.y_scaler.inverse_transform(scaled.reshape(-1, 1)).reshape(-1)
    return scaled, unscaled
| # ============================================================================= | |
| # Train / eval loops (decoder) | |
| # ============================================================================= | |
def train_one_epoch_decoder(model: CLConditionedSelfiesTEDGenerator, dl: DataLoader, optimizer, scaler_amp, device: str):
    """
    Run one epoch of teacher-forced decoder fine-tuning.

    Gradients are clipped to norm 1.0; when ``USE_AMP`` is set the step goes
    through ``scaler_amp`` (scale, unscale, clip, step, update).

    Returns:
        ``{"loss": ..., "ce": ...}`` with sample-weighted means over the epoch.
    """
    model.train()
    loss_sum = 0.0
    ce_total = 0.0
    seen = 0
    for batch in dl:
        move_latent_batch_to_device(batch, device)
        latents, targets = batch["z"], batch["labels"]
        optimizer.zero_grad(set_to_none=True)

        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
            out = model.forward_train(latents, targets)
            loss = out["loss"]

        if not USE_AMP:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        else:
            scaler_amp.scale(loss).backward()
            # Unscale first so the clip threshold applies to true gradients.
            scaler_amp.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler_amp.step(optimizer)
            scaler_amp.update()

        count = latents.size(0)
        loss_sum += float(loss.item()) * count
        ce_total += float(out["ce"].item()) * count
        seen += count

    denom = max(1, seen)
    return {"loss": loss_sum / denom, "ce": ce_total / denom}
def evaluate_decoder(model: CLConditionedSelfiesTEDGenerator, dl: DataLoader, device: str):
    """
    Compute the validation loss used for early stopping.

    The model is put in eval mode and the pass runs without gradient tracking,
    since nothing is back-propagated here.

    Returns:
        ``{"loss": ..., "ce": ...}`` with sample-weighted means over ``dl``.
    """
    model.eval()
    total = 0.0
    n = 0
    ce_sum = 0.0
    with torch.no_grad():
        for batch in dl:
            move_latent_batch_to_device(batch, device)
            z = batch["z"]
            labels = batch["labels"]
            with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
                out = model.forward_train(z, labels)
                loss = out["loss"]
            bs = z.size(0)
            total += float(loss.item()) * bs
            ce_sum += float(out["ce"].item()) * bs
            n += bs
    return {"loss": total / max(1, n), "ce": ce_sum / max(1, n)}
| # ============================================================================= | |
| # Generation / filtering (per target value, per property) | |
| # ============================================================================= | |
def compute_diversity_morgan(smiles_list: List[str], radius: int = 2, nbits: int = 2048, p: float = 1.0) -> Optional[float]:
    """
    Diversity = 1 - (power mean of pairwise Tanimoto similarity).

    Similarities are taken between Morgan bit-vector fingerprints of the unique
    SMILES that RDKit can parse. Each similarity is raised to ``p``, the values
    are averaged and the ``1/p`` root is applied, so ``p = 1`` is the plain
    mean. A ``p`` that is not a finite positive number is replaced by 1.0.

    Returns:
        ``None`` if RDKit is unavailable or no fingerprint / similarity could be
        computed, ``0.0`` if exactly one fingerprint is available, otherwise the
        diversity value.
    """
    if not RDKit_AVAILABLE:
        return None

    try:
        p = float(p)
    except Exception:
        p = 1.0
    else:
        if not np.isfinite(p) or p <= 0:
            p = 1.0

    # Order-preserving de-duplication of the stripped, non-empty strings.
    stripped = (str(item).strip() for item in smiles_list)
    unique_smiles = [smi for smi in dict.fromkeys(stripped) if smi]

    fingerprints = []
    for smi in unique_smiles:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                continue
            Chem.SanitizeMol(mol, catchErrors=True)
            fingerprints.append(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits))
        except Exception:
            # Anything RDKit cannot process is simply left out.
            continue

    if len(fingerprints) == 1:
        return 0.0
    if not fingerprints:
        return None

    powered_sims = []
    for i, fp_a in enumerate(fingerprints):
        for fp_b in fingerprints[i + 1 :]:
            try:
                similarity = float(DataStructs.TanimotoSimilarity(fp_a, fp_b))
                powered_sims.append(similarity ** p)
            except Exception:
                continue
    if not powered_sims:
        return None

    mean_of_powers = float(np.mean(powered_sims))
    try:
        mean_sim = mean_of_powers ** (1.0 / p)
    except Exception:
        # Fall back to the plain arithmetic mean of the similarities.
        plain = [
            float(DataStructs.TanimotoSimilarity(fingerprints[i], fingerprints[j]))
            for i in range(len(fingerprints))
            for j in range(i + 1, len(fingerprints))
        ]
        mean_sim = float(np.mean(plain))
    return float(1.0 - mean_sim)
def decode_from_latents(generator: CLConditionedSelfiesTEDGenerator, z: torch.Tensor, n_samples: int = 1) -> List[str]:
    """
    Decode PSELFIES strings from a batch of CL latents.

    Thin wrapper around ``generator.generate`` that fixes the sampling settings
    to the module-level GEN_* constants and draws ``n_samples`` per latent.
    """
    sample_count = int(n_samples)
    sampling = dict(
        max_len=GEN_MAX_LEN,
        top_p=GEN_TOP_P,
        temperature=GEN_TEMPERATURE,
        repetition_penalty=GEN_REPETITION_PENALTY,
    )
    return generator.generate(z=z, num_return_sequences=sample_count, **sampling)
def generate_for_target(
    target_y_scaled: float,
    prop_model: LatentPropertyModel,
    cl_encoder: MultiModalCLPolymerEncoder,
    generator: CLConditionedSelfiesTEDGenerator,
    train_seed_pool: List[dict],
    train_targets_set: set,
    n_seeds: int = 8,
    n_noise: int = N_FOLD_NOISE_SAMPLING,
    noise_std: float = LATENT_NOISE_STD_GEN,
    prop_tol_scaled: float = PROP_TOL_SCALED,
    oracle: Optional[TextPropertyOracle] = None,
    psmiles_tokenizer=None,
) -> Dict[str, Any]:
    """
    Core generation routine for a single target property value (scaled):
    1) Pick the seed polymers whose scaled property is closest to the target.
    2) Encode seeds (multimodal) -> latent vectors.
    3) Add Gaussian noise to latents (exploration), renormalize to unit L2 norm.
    4) Decode to PSELFIES -> convert to polymer PSMILES.
    5) Filter by polymer/chem validity and property closeness (via the latent
       property model applied to PSMILES latents).
    6) Compute novelty/uniqueness/diversity metrics; optionally score with aux oracle.

    Args:
        target_y_scaled: Target property value in scaled space.
        prop_model: Model passed to `predict_latent_property` for filtering.
        cl_encoder: Encoder providing `encode_multimodal`, `encode_psmiles`, `out_dim`.
        generator: Decoder passed to `decode_from_latents`.
        train_seed_pool: Seed records; each must carry a "y_scaled" entry.
        train_targets_set: Training PSMILES used for the novelty check.
        n_seeds: Number of nearest seeds to use (at least one is always used).
        n_noise: Number of noisy latents sampled per seed.
        noise_std: Standard deviation of the Gaussian latent noise.
        prop_tol_scaled: Maximum |prediction - target| (scaled) for a candidate to be kept.
        oracle: Optional auxiliary property oracle.
        psmiles_tokenizer: Tokenizer for the oracle; the oracle is skipped without it.

    Returns:
        Dict with "generated" and "metrics". When candidates were decoded it also
        contains "pred_scaled_kept", "pred_unscaled_kept" and "aux_pred_scaled".
        If there is nothing to decode (no seeds, no seed latents, or
        n_noise <= 0) the result is {"generated": [], "metrics": {}}.
    """
    def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        n = np.linalg.norm(x, axis=-1, keepdims=True)
        # Clip the norm so an all-zero vector does not divide by zero.
        return x / np.clip(n, eps, None)

    # Nothing to seed from: same outcome as an encoder returning no latents.
    if not train_seed_pool:
        return {"generated": [], "metrics": {}}

    # Choose nearest seeds by property distance (scaled)
    ys = np.array([float(d["y_scaled"]) for d in train_seed_pool], dtype=np.float32)
    diffs = np.abs(ys - float(target_y_scaled))
    order = np.argsort(diffs)
    chosen = [train_seed_pool[i] for i in order[: max(1, int(n_seeds))]]

    # Encode chosen seeds using multimodal encoder
    z_seed = cl_encoder.encode_multimodal(chosen, batch_size=32, device=DEVICE)
    if z_seed.shape[0] == 0:
        return {"generated": [], "metrics": {}}

    # With no noise samples requested there is nothing to stack or decode
    # (np.stack raises on an empty list).
    n_perturb = int(n_noise)
    if n_perturb <= 0:
        return {"generated": [], "metrics": {}}

    # Sample noise around each seed latent.
    # The per-sample loop keeps the NumPy RNG draw order explicit.
    z_list = []
    for i in range(z_seed.shape[0]):
        z0 = z_seed[i].astype(np.float32)
        for _ in range(n_perturb):
            z = z0 + np.random.randn(z0.shape[0]).astype(np.float32) * float(noise_std)
            z = _l2_normalize_np(z.reshape(1, -1)).reshape(-1)
            z_list.append(z)
    z_all = np.stack(z_list, axis=0).astype(np.float32)
    z_t = torch.tensor(z_all, dtype=torch.float32, device=DEVICE)

    # Decode to PSELFIES
    pselfies = decode_from_latents(generator, z_t, n_samples=1)

    # Convert to polymer PSMILES; record validity flags
    can_convert = RDKit_AVAILABLE and SELFIES_AVAILABLE
    valid_psmiles = []
    valid_flags, poly_flags = [], []
    for s in pselfies:
        s = _ensure_two_at_endpoints(_selfies_compact(s))
        psm = pselfies_to_psmiles(s) if can_convert else None
        if psm is None:
            valid_flags.append(False)
            poly_flags.append(False)
            continue
        psm_can = canonicalize_psmiles(psm)
        ok = chem_validity_psmiles(psm_can) if psm_can else False
        poly_ok = polymer_validity_psmiles_strict(psm_can) if psm_can else False
        valid_flags.append(bool(ok))
        poly_flags.append(bool(poly_ok))
        if ok and poly_ok and psm_can:
            valid_psmiles.append(psm_can)

    uniq_valid = sorted(set(valid_psmiles))
    novelty_valid = [1.0 if s not in train_targets_set else 0.0 for s in uniq_valid]
    n_valid_poly = len(valid_psmiles)
    uniqueness_valid_unique = float(len(uniq_valid)) / float(n_valid_poly) if n_valid_poly > 0 else 0.0

    # Property prediction via the latent property model on PSMILES latents (for filtering)
    if uniq_valid:
        z_cand = cl_encoder.encode_psmiles(uniq_valid, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
    else:
        z_cand = np.zeros((0, cl_encoder.out_dim), dtype=np.float32)
    yhat_s, yhat_u = (np.array([], dtype=np.float32), np.array([], dtype=np.float32))
    if z_cand.shape[0] > 0:
        yhat_s, yhat_u = predict_latent_property(prop_model, z_cand)

    keep, keep_pred_scaled, keep_pred_unscaled = [], [], []
    if len(yhat_s) == len(uniq_valid) and len(yhat_u) == len(uniq_valid):
        for psm, y_s, y_u in zip(uniq_valid, yhat_s, yhat_u):
            if abs(float(y_s) - float(target_y_scaled)) <= float(prop_tol_scaled):
                keep.append(psm)
                keep_pred_scaled.append(float(y_s))
                keep_pred_unscaled.append(float(y_u))
    elif uniq_valid:
        # Predictions cannot be aligned with candidates (encoder returned a
        # different number of rows); indexing would raise or mis-assign, so
        # no candidate is kept for this target.
        print(
            f"[GEN][WARN] Got {len(yhat_s)} property predictions for "
            f"{len(uniq_valid)} candidates; skipping property filter for this target."
        )
    novelty_keep = [1.0 if s not in train_targets_set else 0.0 for s in keep]

    # Optional aux oracle prediction for additional sanity checking
    aux_pred_scaled = None
    if VERIFY_GENERATED_PROPERTIES and oracle is not None and psmiles_tokenizer is not None and keep:
        aux = oracle_predict_scaled(oracle, psmiles_tokenizer, keep, DEVICE, PSMILES_MAX_LEN)
        aux_pred_scaled = aux.tolist() if aux is not None else None

    # Diversity computed on At-SMILES (to avoid polymer "*" parsing issues)
    at_smiles = []
    if RDKit_AVAILABLE and keep:
        for psm in keep:
            at_smi = psmiles_to_at_smiles(psm, root_at=False)
            if at_smi is not None:
                at_smiles.append(at_smi)
    div = compute_diversity_morgan(at_smiles) if at_smiles else None

    metrics = {
        "n_total": int(len(pselfies)),
        "validity": float(np.mean(valid_flags)) if valid_flags else 0.0,
        "polymer_validity": float(np.mean(poly_flags)) if poly_flags else 0.0,
        "n_valid_unique": int(len(uniq_valid)),
        "novelty_valid_unique": float(np.mean(novelty_valid)) if novelty_valid else 0.0,
        "uniqueness_valid_unique": float(uniqueness_valid_unique),
        "n_kept_property_filtered": int(len(keep)),
        "novelty_kept": float(np.mean(novelty_keep)) if novelty_keep else 0.0,
        "diversity": float(div) if div is not None else 0.0,
    }
    return {
        "generated": keep,
        "pred_scaled_kept": keep_pred_scaled,
        "pred_unscaled_kept": keep_pred_unscaled,
        "aux_pred_scaled": aux_pred_scaled,
        "metrics": metrics,
    }
| # ============================================================================= | |
| # Data assembly (per property) | |
| # ============================================================================= | |
def build_polymer_records(df: pd.DataFrame, prop_col: str) -> List[dict]:
    """
    Build records for a single property.

    A row is kept only if:
    - it has a non-missing, non-empty "psmiles" cell,
    - its `prop_col` value converts to a finite float,
    - its canonical PSMILES is chemically valid and strictly polymer-valid,
    - a PSELFIES string can be derived from it (decoder target).

    The optional "graph", "geometry" and "fingerprints" cells are carried
    along unchanged for multimodal seed encoding.

    Args:
        df: Source table; read through `row.get`, so absent columns yield None.
        prop_col: Name of the column holding the property value.

    Returns:
        List of dicts with keys "psmiles", "pselfies", "y", "graph",
        "geometry", "fingerprints".

    Raises:
        RuntimeError: If RDKit or selfies is unavailable.
    """
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        raise RuntimeError("RDKit + selfies are required for this pipeline.")

    def _is_missing(cell) -> bool:
        # pandas represents empty cells as float NaN; str(NaN) would be the
        # non-empty text "nan", so missing cells must be detected before str().
        return cell is None or (isinstance(cell, float) and np.isnan(cell))

    recs = []
    for _, row in df.iterrows():
        raw_psmiles = row.get("psmiles", None)
        if _is_missing(raw_psmiles):
            continue
        psmiles_raw = str(raw_psmiles).strip()
        if not psmiles_raw:
            continue

        # Cheap property check first so rows without a usable target never
        # reach the chemistry checks below.
        val = row.get(prop_col, None)
        if val is None:
            continue
        try:
            y = float(val)
        except (TypeError, ValueError, OverflowError):
            continue
        if not np.isfinite(y):
            continue

        psm_can = canonicalize_psmiles(psmiles_raw)
        if not psm_can:
            continue
        if not chem_validity_psmiles(psm_can):
            continue
        if not polymer_validity_psmiles_strict(psm_can):
            continue

        pself = psmiles_to_pselfies(psm_can)
        if pself is None:
            continue

        recs.append(
            {
                "psmiles": psm_can,
                "pselfies": pself,
                "y": y,
                "graph": row.get("graph", None),
                "geometry": row.get("geometry", None),
                "fingerprints": row.get("fingerprints", None),
            }
        )
    return recs
| # ============================================================================= | |
| # Best-fold artifact saving (per property) | |
| # ============================================================================= | |
def save_best_fold_artifacts_for_property(
    property_name: str,
    fold_idx: int,
    decoder_state: Dict[str, torch.Tensor],
    prop_model: Optional[LatentPropertyModel],
    scaler: Optional[StandardScaler],
    best_val_loss: float,
    generations_payload: List[dict],
):
    """
    Persist the best fold for a property:
    - decoder state_dict (torch.save; failures propagate)
    - scaler + property model (joblib, if available)
    - meta.json describing hyperparams
    - jsonl generations payload for traceability

    Everything except the decoder checkpoint is best-effort: a failure is
    reported with a "[SAVE][WARN]" message and does not abort the run.

    Args:
        property_name: Property name; spaces become "_" in paths.
        fold_idx: Zero-based fold index (file names use fold_idx + 1).
        decoder_state: Decoder state_dict to checkpoint.
        prop_model: Fitted latent property model, or None to skip saving it.
        scaler: Fitted target scaler, or None to skip saving it.
        best_val_loss: Validation loss of the checkpointed decoder.
        generations_payload: Per-target generation records, written one per line.
    """
    safe_prop = property_name.replace(" ", "_")
    prop_dir = os.path.join(CFG.OUTPUT_MODELS_DIR, safe_prop)
    os.makedirs(prop_dir, exist_ok=True)

    decoder_path = os.path.join(prop_dir, f"decoder_best_fold{fold_idx+1}.pt")
    torch.save(decoder_state, decoder_path)

    try:
        import joblib
    except ImportError:
        joblib = None
    if joblib is None:
        if scaler is not None or prop_model is not None:
            print(f"[SAVE][WARN] joblib not available; scaler/GPR for '{property_name}' were not saved.")
    else:
        if scaler is not None:
            joblib.dump(scaler, os.path.join(prop_dir, f"standardscaler_{safe_prop}.joblib"))
        if prop_model is not None:
            joblib.dump(prop_model, os.path.join(prop_dir, f"gpr_psmiles_{safe_prop}.joblib"))

    meta = {
        "property": property_name,
        "best_fold": int(fold_idx + 1),
        "best_val_loss": float(best_val_loss),
        "selfies_ted_model": str(SELFIES_TED_MODEL_NAME),
        "cl_emb_dim": int(CL_EMB_DIM),
        "mem_len": 4,
        "tol_scaled": float(PROP_TOL_SCALED),
        "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
        "optimizer": "AdamW",
        "lr": float(LEARNING_RATE),
        "weight_decay": float(WEIGHT_DECAY),
        "lr_scheduler": "CosineAnnealingLR",
        "epochs": int(NUM_EPOCHS),
        "batch_size": int(BATCH_SIZE),
        "patience": int(PATIENCE),
    }
    try:
        with open(os.path.join(prop_dir, "meta.json"), "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
    except Exception as e:
        # Metadata is not required to reload the checkpoint; report and continue.
        print(f"[SAVE][WARN] Could not write meta.json for '{property_name}': {e}")

    out_path = os.path.join(CFG.OUTPUT_GENERATIONS_DIR, f"{safe_prop}_best_fold{fold_idx+1}_generated_psmiles.jsonl")
    try:
        # Create the target directory here as well so this function does not
        # depend on the caller having prepared it.
        os.makedirs(CFG.OUTPUT_GENERATIONS_DIR, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as fh:
            for r in generations_payload:
                fh.write(json.dumps(make_json_serializable({"property": property_name, "best_fold": fold_idx + 1, **r})) + "\n")
    except Exception as e:
        print(f"[SAVE][WARN] Could not write generations for '{property_name}': {e}")
| # ============================================================================= | |
| # Main per-property CV loop (single-task) | |
| # ============================================================================= | |
def run_inverse_design_single_property(
    df: pd.DataFrame,
    property_name: str,
    prop_col: str,
    cl_encoder: MultiModalCLPolymerEncoder,
    selfies_tok,
    selfies_model,
) -> Dict[str, Any]:
    """
    Run NUM_FOLDS-fold CV for a single property and log fold-level metrics.

    Per fold: split train/val/test, fit the target scaler and the latent
    property model on TRAIN only, optionally train the aux oracle, fine-tune a
    freshly loaded decoder with early stopping, then generate candidates for
    (at most 64) test targets and aggregate generation metrics. The fold with
    the lowest decoder validation loss so far is saved to disk immediately.

    Args:
        df: Source table passed to `build_polymer_records`.
        property_name: Human-readable property name used in logs and outputs.
        prop_col: Column of `df` holding the property values.
        cl_encoder: Encoder used for conditioning, seeding and property filtering.
        selfies_tok: Tokenizer handed to the decoder training datasets.
        selfies_model: Not used here; a fresh backbone is loaded per fold.

    Returns:
        Dict with "property", "runs" (per-fold records), "agg" (mean/std per
        metric across folds, or None if no fold completed) and "n_samples".
    """
    polymers = build_polymer_records(df, prop_col)
    if len(polymers) < 200:
        print(f"[{property_name}][WARN] Only {len(polymers)} usable samples; results may be noisy.")
    if len(polymers) < 50:
        print(f"[{property_name}][WARN] Skipping due to insufficient usable samples (<50).")
        return {"property": property_name, "runs": [], "agg": None, "n_samples": len(polymers)}

    def _to_rec(p):
        # Record layout handed to the latent dataset (no property value).
        return {
            "psmiles": p["psmiles"],
            "pselfies": p["pselfies"],
            "graph": p.get("graph", None),
            "geometry": p.get("geometry", None),
            "fingerprints": p.get("fingerprints", None),
        }

    def _scaled_targets(fitted_scaler, polys):
        # Scale all targets of `polys` with one transform call.
        if not polys:
            return []
        col = np.array([p["y"] for p in polys], dtype=np.float32).reshape(-1, 1)
        return [float(v) for v in fitted_scaler.transform(col)[:, 0]]

    def _finite(xs):
        return [x for x in xs if np.isfinite(x)]

    indices = np.arange(len(polymers))
    kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)
    runs = []
    best_overall_val = float("inf")

    for fold_idx, (trainval_idx, test_idx) in enumerate(kf.split(indices)):
        seed = 42 + fold_idx
        set_seed(seed)
        print(f"\n[{property_name}] Fold {fold_idx+1}/{NUM_FOLDS} | seed={seed}")

        trainval_polys = [polymers[i] for i in trainval_idx]
        test_polys = [polymers[i] for i in test_idx]

        # Train/val split within trainval
        tr_idx, va_idx = train_test_split(np.arange(len(trainval_polys)), test_size=0.10, random_state=seed, shuffle=True)
        train_polys = [copy.deepcopy(trainval_polys[i]) for i in tr_idx]
        val_polys = [copy.deepcopy(trainval_polys[i]) for i in va_idx]

        # Scale property targets using TRAIN only
        sc = StandardScaler()
        sc.fit(np.array([p["y"] for p in train_polys], dtype=np.float32).reshape(-1, 1))
        y_train_scaled = _scaled_targets(sc, train_polys)

        # The scaler does not change after fitting, so describe it once per fold.
        scaler_meta = {
            "scaler_type": "StandardScaler",
            "mean_": getattr(sc, "mean_", None),
            "scale_": getattr(sc, "scale_", None),
            "with_mean": bool(getattr(sc, "with_mean", True)),
            "with_std": bool(getattr(sc, "with_std", True)),
        }

        # Decoder training datasets (cache CL embeddings for speed)
        ds_train = LatentToPSELFIESDataset(
            [_to_rec(p) for p in train_polys],
            cl_encoder,
            selfies_tok,
            max_len=GEN_MAX_LEN,
            latent_noise_std=LATENT_NOISE_STD_TRAIN,
            cache_embeddings=True,
            use_multimodal=True,
        )
        ds_val = LatentToPSELFIESDataset(
            [_to_rec(p) for p in val_polys],
            cl_encoder,
            selfies_tok,
            max_len=GEN_MAX_LEN,
            latent_noise_std=0.0,
            cache_embeddings=True,
            use_multimodal=True,
        )
        dl_train = DataLoader(ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=latent_collate)
        dl_val = DataLoader(ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True, collate_fn=latent_collate)

        # Fit GPR on PSMILES latents for this property (train only)
        y_tr = [float(p["y"]) for p in train_polys]
        psm_tr = [p["psmiles"] for p in train_polys]
        z_tr = cl_encoder.encode_psmiles(psm_tr, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
        prop_model = fit_latent_property_model(z_tr, np.array(y_tr, dtype=np.float32), y_scaler=sc)
        print(f"[{property_name}] Fit PSMILES-latent GPR (n_train={len(y_tr)})")

        # Optional aux oracle (scaled)
        oracle = None
        if VERIFY_GENERATED_PROPERTIES and len(train_polys) >= 200 and len(val_polys) >= 50:
            tr_s = [
                {"psmiles": p["psmiles"], "target": p["y"], "target_scaled": y_s}
                for p, y_s in zip(train_polys, y_train_scaled)
            ]
            va_s = [
                {"psmiles": p["psmiles"], "target": p["y"], "target_scaled": y_s}
                for p, y_s in zip(val_polys, _scaled_targets(sc, val_polys))
            ]
            try:
                oracle = train_property_oracle_per_fold(tr_s, va_s, cl_encoder.psm_tok, DEVICE, PSMILES_MAX_LEN)
                print(f"[{property_name}] Trained aux oracle (val_loss={getattr(oracle, 'aux_val_loss', None)})")
            except Exception as e:
                print(f"[{property_name}][WARN] Oracle training failed: {e}")
                oracle = None

        # Fresh decoder per fold + optimizer
        selfies_tok_f, selfies_model_f = load_selfies_ted_and_tokenizer(SELFIES_TED_MODEL_NAME)
        decoder = CLConditionedSelfiesTEDGenerator(selfies_tok_f, selfies_model_f, cl_emb_dim=CL_EMB_DIM, mem_len=4).to(DEVICE)
        optimizer, scheduler = create_optimizer_and_scheduler_decoder(decoder)
        scaler_amp = torch.cuda.amp.GradScaler(enabled=USE_AMP)

        best_val = float("inf")
        best_state = None
        no_improve = 0
        for epoch in range(1, NUM_EPOCHS + 1):
            tr = train_one_epoch_decoder(decoder, dl_train, optimizer, scaler_amp, DEVICE)
            va = evaluate_decoder(decoder, dl_val, DEVICE)
            # Best-effort: a scheduler problem must not abort training.
            try:
                scheduler.step()
            except Exception:
                pass
            try:
                lr = float(optimizer.param_groups[0]["lr"])
                print(
                    f"[{property_name}] fold {fold_idx+1}/{NUM_FOLDS} | epoch {epoch:03d} | "
                    f"lr={lr:.2e} | train_loss={tr['loss']:.6f} | val_loss={va['loss']:.6f}"
                )
            except Exception:
                print(
                    f"[{property_name}] fold {fold_idx+1}/{NUM_FOLDS} | epoch {epoch:03d} | "
                    f"train_loss={tr['loss']:.6f} | val_loss={va['loss']:.6f}"
                )
            if va["loss"] < best_val - 1e-8:
                best_val = va["loss"]
                no_improve = 0
                # Snapshot on CPU so the best weights survive further training.
                best_state = {k: v.detach().cpu().clone() for k, v in decoder.state_dict().items()}
            else:
                no_improve += 1
                if no_improve >= PATIENCE:
                    print(f"[{property_name}] Early stopping (no val improvement for {PATIENCE} epochs).")
                    break

        if best_state is None:
            print(f"[{property_name}][WARN] No best state captured; skipping this fold.")
            continue
        decoder.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=False)

        # Seed pool for generation (scaled property values, plus modalities for encoding)
        seed_pool = [
            {
                "psmiles": p["psmiles"],
                "y_scaled": y_s,
                "graph": p.get("graph", None),
                "geometry": p.get("geometry", None),
                "fingerprints": p.get("fingerprints", None),
            }
            for p, y_s in zip(train_polys, y_train_scaled)
        ]
        train_targets_set = set(ps["psmiles"] for ps in train_polys)

        # Compose test targets (scaled); subsample for runtime control
        ys_test_scaled = np.array(_scaled_targets(sc, test_polys), dtype=np.float32)
        if len(ys_test_scaled) > 64:
            ys_test_scaled = np.random.choice(ys_test_scaled, size=64, replace=False)

        # Generate per target
        all_valid, all_poly, all_kept, success_scaled, mae_best, diversity_vals = [], [], [], [], [], []
        novelty_vals, uniqueness_vals = [], []
        per_target_records = []
        for y_t in ys_test_scaled:
            out = generate_for_target(
                target_y_scaled=float(y_t),
                prop_model=prop_model,
                cl_encoder=cl_encoder,
                generator=decoder,
                train_seed_pool=seed_pool,
                train_targets_set=train_targets_set,
                n_seeds=8,
                n_noise=min(N_FOLD_NOISE_SAMPLING, 16),
                noise_std=LATENT_NOISE_STD_GEN,
                prop_tol_scaled=PROP_TOL_SCALED,
                oracle=oracle,
                psmiles_tokenizer=cl_encoder.psm_tok,
            )
            m = out["metrics"]
            n_kept = int(m.get("n_kept_property_filtered", 0))
            all_valid.append(float(m.get("validity", 0.0)))
            all_poly.append(float(m.get("polymer_validity", 0.0)))
            all_kept.append(n_kept)
            diversity_vals.append(float(m.get("diversity", 0.0)))
            success_scaled.append(1.0 if n_kept > 0 else 0.0)
            novelty_vals.append(float(m.get("novelty_kept", 0.0)))
            uniqueness_vals.append(float(m.get("uniqueness_valid_unique", 0.0)))

            gen_list = out.get("generated", []) or []
            pred_s_list = out.get("pred_scaled_kept", []) or []
            pred_u_list = out.get("pred_unscaled_kept", []) or []

            # Best error among kept candidates. The generator already returns
            # the scaled predictions it filtered on, so reuse them instead of
            # encoding and predicting the same candidates a second time.
            kept_pred_s = pred_s_list
            if gen_list and len(kept_pred_s) != len(gen_list):
                # Predictions missing or misaligned: recompute them.
                z_keep = cl_encoder.encode_psmiles(gen_list, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
                kept_pred_s, _ = predict_latent_property(prop_model, z_keep)
            if gen_list and len(kept_pred_s):
                err = np.abs(np.asarray(kept_pred_s, dtype=np.float64) - float(y_t))
                mae_best.append(float(np.min(err)))
            else:
                mae_best.append(float("inf"))

            target_y_unscaled = float(sc.inverse_transform(np.array([[float(y_t)]], dtype=np.float32))[0, 0])

            aux_list = out.get("aux_pred_scaled", None)
            if aux_list is not None and not isinstance(aux_list, list):
                aux_list = None

            candidates = []
            for i_c, psm in enumerate(gen_list):
                candidates.append(
                    {
                        "psmiles": str(psm),
                        "pred_scaled": float(pred_s_list[i_c]) if i_c < len(pred_s_list) else None,
                        "pred_unscaled": float(pred_u_list[i_c]) if i_c < len(pred_u_list) else None,
                        "aux_pred_scaled": float(aux_list[i_c]) if (aux_list is not None and i_c < len(aux_list)) else None,
                    }
                )

            per_target_records.append(
                {
                    "target_y_scaled": float(y_t),
                    "target_y_unscaled": float(target_y_unscaled),
                    "tol_scaled": float(PROP_TOL_SCALED),
                    "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
                    # Copy so every record owns its own dict.
                    "scaler_meta": dict(scaler_meta),
                    "candidates": candidates,
                    "metrics": m,
                }
            )

        finite_mae = _finite(mae_best)
        metrics_fold = {
            "validity_mean": float(np.mean(all_valid)) if all_valid else 0.0,
            "polymer_validity_mean": float(np.mean(all_poly)) if all_poly else 0.0,
            "avg_n_kept": float(np.mean(all_kept)) if all_kept else 0.0,
            "success_at_k_scaled": float(np.mean(success_scaled)) if success_scaled else 0.0,
            # NOTE: reported as 0.0 when no target had any kept candidate.
            "mae_best_scaled": float(np.mean(finite_mae)) if finite_mae else 0.0,
            "diversity_mean": float(np.mean(diversity_vals)) if diversity_vals else 0.0,
            "novelty_mean": float(np.mean(novelty_vals)) if novelty_vals else 0.0,
            "uniqueness_mean": float(np.mean(uniqueness_vals)) if uniqueness_vals else 0.0,
            "tol_scaled": float(PROP_TOL_SCALED),
            "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
        }
        run_record = {
            "property": property_name,
            "fold": int(fold_idx + 1),
            "seed": int(seed),
            "n_train": int(len(train_polys)),
            "n_val": int(len(val_polys)),
            "n_test": int(len(test_polys)),
            "best_val_loss": float(best_val),
            "gen_metrics": metrics_fold,
        }
        runs.append(run_record)
        with open(CFG.OUTPUT_RESULTS, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(make_json_serializable(run_record)) + "\n")

        # Save best fold artifacts by lowest validation loss
        if best_val < best_overall_val - 1e-8:
            best_overall_val = best_val
            save_best_fold_artifacts_for_property(
                property_name=property_name,
                fold_idx=fold_idx,
                decoder_state=best_state,
                prop_model=prop_model,
                scaler=sc,
                best_val_loss=best_val,
                generations_payload=per_target_records,
            )
            print(f"[{property_name}] Saved best-fold artifacts (fold {fold_idx+1}, val_loss={best_val:.6f}).")

    # Aggregate across folds
    if not runs:
        return {"property": property_name, "runs": [], "agg": None, "n_samples": len(polymers)}

    def _collect(key):
        xs = [float(r["gen_metrics"].get(key, 0.0)) for r in runs if r.get("gen_metrics", None) is not None]
        return (float(np.mean(xs)) if xs else 0.0, float(np.std(xs)) if xs else 0.0)

    agg = {}
    for k in [
        "validity_mean",
        "polymer_validity_mean",
        "avg_n_kept",
        "success_at_k_scaled",
        "mae_best_scaled",
        "diversity_mean",
        "novelty_mean",
        "uniqueness_mean",
    ]:
        mean_k, std_k = _collect(k)
        agg[k] = {"mean": mean_k, "std": std_k}
    agg["tol_scaled"] = float(PROP_TOL_SCALED)
    agg["tol_unscaled_abs"] = float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None

    with open(CFG.OUTPUT_RESULTS, "a", encoding="utf-8") as fh:
        fh.write("AGG_PROPERTY: " + json.dumps(make_json_serializable({property_name: agg})) + "\n")
    return {"property": property_name, "runs": runs, "agg": agg, "n_samples": len(polymers)}
| # ============================================================================= | |
| # Entrypoint (single-task per property) | |
| # ============================================================================= | |
def main():
    """
    Entry point: run single-task inverse design for every requested property.

    Steps: prepare output directories, load the PolyInfo CSV and map the
    requested properties to columns, build the PSMILES tokenizer, the frozen
    multimodal CL encoder and the SELFIES-TED backbone, reset the results file,
    run the per-property CV loop, and append a final summary.

    The results file is only backed up and truncated after all setup has
    succeeded, and only a non-empty file is backed up. This way a run that
    fails during setup neither wipes the previous results nor causes the next
    run to overwrite the ".bak" copy with an empty file.

    Raises:
        RuntimeError: If RDKit/selfies are missing or the tokenizer cannot be built.
        FileNotFoundError: If the PolyInfo CSV does not exist.
    """
    ensure_output_dirs(CFG)
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        raise RuntimeError("This script requires RDKit and selfies. Install them before running.")

    # Load dataset
    if not os.path.isfile(CFG.POLYINFO_PATH):
        raise FileNotFoundError(f"PolyInfo CSV not found: {CFG.POLYINFO_PATH}")
    df = pd.read_csv(CFG.POLYINFO_PATH, engine="python")
    found = find_property_columns(df.columns)
    prop_map = {req: found.get(req) for req in REQUESTED_PROPERTIES}

    rule = "=" * 80
    print("\n" + rule)
    print("[RUN] Inverse design (single-task per property)")
    print(rule)
    print(f"[ENV] RDKit_AVAILABLE={RDKit_AVAILABLE} | SELFIES_AVAILABLE={SELFIES_AVAILABLE}")
    print(f"[ENV] DEVICE={DEVICE} | USE_AMP={USE_AMP} | NUM_WORKERS={NUM_WORKERS}")
    print(f"[DATA] POLYINFO_PATH={CFG.POLYINFO_PATH}")
    print(f"[DATA] Property map: {prop_map}")
    print(f"[CL] CL checkpoint dir: {CFG.PRETRAINED_MULTIMODAL_DIR}")
    print(f"[DEC] SELFIES_TED_MODEL_NAME={SELFIES_TED_MODEL_NAME}")
    print(
        f"[DEC] FT params: batch={BATCH_SIZE}, epochs={NUM_EPOCHS}, patience={PATIENCE}, "
        f"lr={LEARNING_RATE}, wd={WEIGHT_DECAY}, sched=CosineAnnealingLR(eta_min={COSINE_ETA_MIN})"
    )
    print(f"[GEN] Latent noise: train_std={LATENT_NOISE_STD_TRAIN}, gen_std={LATENT_NOISE_STD_GEN}, n_noise={N_FOLD_NOISE_SAMPLING}")
    print(f"[GEN] Filter tol: scaled={PROP_TOL_SCALED}, abs={PROP_TOL_UNSCALED_ABS}")
    print(f"[AUX] VERIFY_GENERATED_PROPERTIES={VERIFY_GENERATED_PROPERTIES}")
    print(rule + "\n")

    # Build PSMILES tokenizer for CL text encoder
    psmiles_tok = build_psmiles_tokenizer(spm_path=CFG.SPM_MODEL_PATH, max_len=PSMILES_MAX_LEN)
    if psmiles_tok is None:
        raise RuntimeError("Failed to build PSMILES tokenizer (check SPM_MODEL_PATH).")

    # Multimodal CL encoder (frozen; used as conditioning interface)
    cl_encoder = MultiModalCLPolymerEncoder(
        psmiles_tokenizer=psmiles_tok,
        emb_dim=CL_EMB_DIM,
        cl_weights_dir=CFG.PRETRAINED_MULTIMODAL_DIR,
        use_gine=True,
        use_schnet=True,
        use_fp=True,
        use_psmiles=True,
    ).to(DEVICE)
    cl_encoder.freeze_cl_encoders()

    # Load SELFIES-TED backbone
    selfies_tok, selfies_model = load_selfies_ted_and_tokenizer(SELFIES_TED_MODEL_NAME)
    print(f"[HF][INFO] Loaded SELFIES-TED backbone: {SELFIES_TED_MODEL_NAME}")

    # Reset results file (only now that setup succeeded; keep the last
    # non-empty results as a backup).
    results_path = CFG.OUTPUT_RESULTS
    if os.path.exists(results_path) and os.path.getsize(results_path) > 0:
        backup = results_path + ".bak"
        shutil.copy(results_path, backup)
        print(f"[IO][INFO] Backed up existing results file to: {backup}")
    with open(results_path, "w", encoding="utf-8"):
        pass

    overall = {"per_property": {}}

    # Single-task loop per property
    for pname in REQUESTED_PROPERTIES:
        pcol = prop_map.get(pname, None)
        if pcol is None:
            print(f"[{pname}][WARN] No column match found; skipping.")
            continue
        print(f"\n>>> Property: '{pname}' | column='{pcol}'")
        res = run_inverse_design_single_property(df, pname, pcol, cl_encoder, selfies_tok, selfies_model)
        overall["per_property"][pname] = res

    # Final summary (aggregated per property)
    final_agg = {pname: info.get("agg", None) for pname, info in overall["per_property"].items()}
    with open(results_path, "a", encoding="utf-8") as fh:
        fh.write("\nFINAL_SUMMARY\n")
        fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
        fh.write("\n")

    print("\n" + rule)
    print("Finished inverse design runs.")
    print(f"Results file: {CFG.OUTPUT_RESULTS}")
    print(f"Best models dir: {CFG.OUTPUT_MODELS_DIR}")
    print(f"Best-fold generations dir: {CFG.OUTPUT_GENERATIONS_DIR}")
    print(rule)
# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()