import os
import re
import sys
import csv
import json
import math
import time
import copy
import random
import shutil
import warnings
from dataclasses import dataclass
from pathlib import Path
from typing import List, Dict, Optional, Tuple, Any

import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sklearn.preprocessing import StandardScaler
from sklearn.decomposition import PCA
from sklearn.gaussian_process import GaussianProcessRegressor
from sklearn.gaussian_process.kernels import RBF, ConstantKernel as C, WhiteKernel

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Increase CSV field size limit safely (helps when JSON blobs are stored in CSV cells).
# sys.maxsize is tried first; on an OverflowError the limit falls back to 2**31 - 1.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

# HF Transformers (SELFIES-TED)
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
from transformers.modeling_outputs import BaseModelOutput

# Shared encoders/helpers from PolyFusion
from PolyFusion.GINE import GineEncoder
from PolyFusion.SchNet import NodeSchNetWrapper
from PolyFusion.Transformer import PooledFingerprintEncoder as FingerprintEncoder
from PolyFusion.DeBERTav2 import PSMILESDebertaEncoder, build_psmiles_tokenizer

# Optional chemistry dependencies (recommended).
# The *_AVAILABLE flags are read by the chemistry helpers in this file, which
# return None/False instead of failing when a library is missing.
RDKit_AVAILABLE = False
SELFIES_AVAILABLE = False
try:
    from rdkit import Chem
    from rdkit.Chem import AllChem, DataStructs
    RDKit_AVAILABLE = True
except Exception:
    RDKit_AVAILABLE = False

try:
    import selfies as sf
    SELFIES_AVAILABLE = True
except Exception:
    SELFIES_AVAILABLE = False


# =============================================================================
# Configuration (paths are placeholders; replace with your actual filesystem paths)
# =============================================================================
@dataclass(frozen=True)
class Config:
    """Immutable bundle of filesystem locations used by this script.

    Every default is a placeholder path. The output file/directory names are
    exposed as read-only properties derived from ``OUTPUT_DIR``.
    """

    # -------------------------------------------------------------------------
    # Input data and pretrained artifacts (placeholders)
    # -------------------------------------------------------------------------
    BASE_DIR: str = "/path/to/Polymer_Foundational_Model"
    POLYINFO_PATH: str = "/path/to/polyinfo_with_modalities.csv"

    # Multimodal CL checkpoint (for the fused encoder)
    PRETRAINED_MULTIMODAL_DIR: str = "/path/to/multimodal_output/best"

    # Unimodal encoder checkpoints
    BEST_GINE_DIR: str = "/path/to/gin_output/best"
    BEST_SCHNET_DIR: str = "/path/to/schnet_output/best"
    BEST_FP_DIR: str = "/path/to/fingerprint_mlm_output/best"
    BEST_PSMILES_DIR: str = "/path/to/polybert_output/best"

    # SentencePiece model for PSMILES tokenizer (placeholder)
    SPM_MODEL_PATH: str = "/path/to/spm.model"

    # -------------------------------------------------------------------------
    # Output folders
    # -------------------------------------------------------------------------
    OUTPUT_DIR: str = "/path/to/multimodal_inverse_design_output"

    @property
    def OUTPUT_RESULTS(self) -> str:
        # Text results file located directly under OUTPUT_DIR.
        return os.path.join(self.OUTPUT_DIR, "inverse_design_results.txt")

    @property
    def OUTPUT_MODELS_DIR(self) -> str:
        # Sub-directory of OUTPUT_DIR named "best_models".
        return os.path.join(self.OUTPUT_DIR, "best_models")

    @property
    def OUTPUT_GENERATIONS_DIR(self) -> str:
        # Sub-directory of OUTPUT_DIR named "best_fold_generations".
        return os.path.join(self.OUTPUT_DIR, "best_fold_generations")


# Module-wide configuration instance (also used for default arguments below).
CFG = Config()

# Properties to run
REQUESTED_PROPERTIES = [
    "density",
    "glass transition",
    "melting",
    "thermal decomposition",
]

# -------------------------------------------------------------------------
# Model sizes / dims (match CL encoder + pretraining)
# -------------------------------------------------------------------------
CL_EMB_DIM = 600
MAX_ATOMIC_Z = 85
MASK_ATOM_ID = MAX_ATOMIC_Z + 1

# GINE params
NODE_EMB_DIM = 300
EDGE_EMB_DIM = 300
NUM_GNN_LAYERS = 5

# SchNet params
SCHNET_NUM_GAUSSIANS = 50
SCHNET_NUM_INTERACTIONS = 6
SCHNET_CUTOFF = 10.0
SCHNET_MAX_NEIGHBORS = 64
SCHNET_HIDDEN = 600

# Fingerprint params
FP_LENGTH = 2048
MASK_TOKEN_ID_FP = 2
VOCAB_SIZE_FP = 3

# DeBERTa params
DEBERTA_HIDDEN = 600
PSMILES_MAX_LEN = 128

# SELFIES-TED generation limits
GEN_MAX_LEN = 256
GEN_MIN_LEN = 10

# -------------------------------------------------------------------------
# Decoder fine-tuning schedule (single head)
# -------------------------------------------------------------------------
BATCH_SIZE = 32
NUM_EPOCHS = 100
PATIENCE = 10
WEIGHT_DECAY = 0.0
LEARNING_RATE = 1e-4
COSINE_ETA_MIN = 1e-6

# Noise injection (latent space)
LATENT_NOISE_STD_TRAIN = 0.10  # training-time denoising std
LATENT_NOISE_STD_GEN = 0.15  # generation-time exploration std
N_FOLD_NOISE_SAMPLING = 16  # sampling multiplicity around each seed embedding

# Sampling config (decoder)
GEN_TOP_P = 0.92
GEN_TEMPERATURE = 1.0
GEN_REPETITION_PENALTY = 1.05

# Cross-validation
NUM_FOLDS = 5

# Property guidance tolerance (scaled space)
PROP_TOL_SCALED = 0.5
PROP_TOL_UNSCALED_ABS = None

# GPR settings (PSMILES latent)
USE_PCA_BEFORE_GPR = True
PCA_DIM = 64
GPR_ALPHA = 1e-6

# Verification (optional auxiliary predictor trained per fold)
VERIFY_GENERATED_PROPERTIES = True
PROP_PRED_EPOCHS = 20
PROP_PRED_PATIENCE = 5
PROP_PRED_BATCH_SIZE = 32
PROP_PRED_LR = 3e-4
PROP_PRED_WEIGHT_DECAY = 0.0

# Runtime placement: CUDA (with AMP) when a GPU is visible, otherwise CPU.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
USE_AMP = bool(torch.cuda.is_available())
AMP_DTYPE = torch.float16
NUM_WORKERS = 0 if os.name == "nt" else 1

warnings.filterwarnings("ignore", category=UserWarning)


def ensure_output_dirs(cfg: "Config") -> None:
    """Make sure the three output directories named by *cfg* exist.

    Creates ``cfg.OUTPUT_DIR``, ``cfg.OUTPUT_MODELS_DIR`` and
    ``cfg.OUTPUT_GENERATIONS_DIR`` in that order; directories that are already
    present are left untouched.
    """
    wanted = (cfg.OUTPUT_DIR, cfg.OUTPUT_MODELS_DIR, cfg.OUTPUT_GENERATIONS_DIR)
    for directory in wanted:
        os.makedirs(directory, exist_ok=True)


# =============================================================================
# Utilities
# =============================================================================
def _safe_json_load(x):
    """Best-effort JSON parsing for CSV cells.

    Args:
        x: ``None``, an already-decoded ``dict``/``list``, ``bytes``, or any
            object whose ``str()`` is (possibly single-quoted) JSON.

    Returns:
        The decoded object, or ``None`` when the input is empty or cannot be
        parsed. ``dict``/``list`` inputs are returned unchanged.
    """
    if x is None:
        return None
    if isinstance(x, (dict, list)):
        return x
    if isinstance(x, (bytes, bytearray)):
        # str(b"...") would yield the "b'...'" repr, which never parses.
        x = bytes(x).decode("utf-8", errors="replace")
    s = str(x).strip()
    if not s:
        return None
    # Second candidate is a crude repair for Python-repr style single quotes.
    for candidate in (s, s.replace("'", '"')):
        try:
            return json.loads(candidate)
        except (ValueError, RecursionError):
            continue
    return None


def set_seed(seed: int):
    """Set random seeds for reproducibility (best effort).

    Seeds ``random``, NumPy and PyTorch (CPU and, when available, all CUDA
    devices). cuDNN benchmarking is deliberately left enabled for speed.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    try:
        # Keep cuDNN fast; reproducibility across GPUs/drivers is not guaranteed.
        torch.backends.cudnn.benchmark = True
    except Exception:
        pass


def make_json_serializable(obj):
    """Convert common scientific objects into JSON-serializable Python types.

    Handles (recursively) dicts, lists/tuples/sets, NumPy arrays and scalars,
    torch tensors and pandas timestamps/timedeltas. Anything unrecognised is
    converted with ``str()``.

    Dictionary keys that do not convert to a JSON-compatible primitive (for
    example tuple keys, which would otherwise become unhashable lists) are
    replaced by ``str(key)``.
    """
    primitives = (float, int, str, bool, type(None))
    if isinstance(obj, dict):
        converted = {}
        for key, value in obj.items():
            new_key = make_json_serializable(key)
            if not isinstance(new_key, primitives):
                new_key = str(key)
            converted[new_key] = make_json_serializable(value)
        return converted
    if isinstance(obj, (list, tuple, set)):
        return [make_json_serializable(x) for x in obj]
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    if isinstance(obj, np.generic):
        try:
            return obj.item()
        except Exception:
            return float(obj)
    if isinstance(obj, torch.Tensor):
        try:
            return obj.detach().cpu().tolist()
        except Exception:
            return None
    if isinstance(obj, (pd.Timestamp, pd.Timedelta)):
        return str(obj)
    if isinstance(obj, primitives):
        return obj
    return str(obj)


def find_property_columns(columns, requested=None):
    """
    Heuristically map requested property names to dataframe columns.

    Args:
        columns: iterable of column labels (non-string labels are compared via ``str()``).
        requested: property names to look up; defaults to ``REQUESTED_PROPERTIES``.

    Returns:
        Dict mapping each requested name to the matched original column label,
        or ``None`` when nothing matched.

    Notes:
      - Uses lowercase matching; prefers token-level matches.
      - Special-case: exclude "cohesive energy" when searching for "density".
    """
    if requested is None:
        requested = REQUESTED_PROPERTIES
    lowered = {str(c).lower(): c for c in columns}
    found = {}
    for req in requested:
        req_low = req.lower().strip()
        exact = None
        # First, attempt a token match
        for c_low, c_orig in lowered.items():
            tokens = set(c_low.replace('_', ' ').split())
            if req_low in tokens or c_low == req_low:
                if req_low == "density" and "cohesive" in c_low:
                    continue
                exact = c_orig
                break
        if exact is not None:
            found[req] = exact
            continue

        # Fallback: substring match
        candidates = [c_orig for c_low, c_orig in lowered.items() if req_low in c_low]
        if req_low == "density":
            candidates = [c for c in candidates if "cohesive" not in str(c).lower()]
        chosen = candidates[0] if candidates else None
        found[req] = chosen
        if chosen is None:
            print(f"[WARN] Could not match requested property '{req}' to any column.")
        else:
            print(f"[INFO] Mapped requested property '{req}' -> column '{chosen}'")
    return found


# =============================================================================
# Graph / geometry / fingerprint parsing for multimodal CL encoding
# =============================================================================
def _coerce_number(value, cast, default):
    """Return ``cast(value)``, or *default* when the value cannot be converted."""
    try:
        return cast(value)
    except (TypeError, ValueError, OverflowError):
        return default


def _edge_lists_from_graph(gf):
    """Extract ``(srcs, dsts)`` integer lists from a decoded graph dict.

    Two ``edge_indices`` layouts are understood, checked in this order:
      (a) ``[[u, v], [u, v], ...]`` (list of pairs)
      (b) ``[[srcs...], [dsts...]]`` (COO rows)
    A two-edge graph is ambiguous between the two and is read as (a).
    When ``edge_indices`` is absent, ``adjacency_matrix`` is used instead.
    Malformed input yields two empty lists.
    """
    raw = gf.get("edge_indices", None)
    if raw is None:
        srcs, dsts = [], []
        adj = gf.get("adjacency_matrix", None)
        if isinstance(adj, list):
            for i, row in enumerate(adj):
                if not isinstance(row, list):
                    continue
                for j, val in enumerate(row):
                    if val:
                        srcs.append(i)
                        dsts.append(j)
        return srcs, dsts

    if not isinstance(raw, list) or len(raw) == 0:
        return [], []
    try:
        if isinstance(raw[0], list) and len(raw[0]) == 2:
            return [int(p[0]) for p in raw], [int(p[1]) for p in raw]
        if len(raw) == 2:
            srcs = [int(x) for x in raw[0]]
            dsts = [int(x) for x in raw[1]]
            # Ragged COO rows cannot form a (2, E) tensor; treat as "no edges".
            if len(srcs) != len(dsts):
                return [], []
            return srcs, dsts
    except (TypeError, ValueError, IndexError, KeyError, OverflowError):
        return [], []
    return [], []


def _parse_graph_for_gine(graph_field):
    """
    Convert a stored 'graph' JSON blob into the tensor inputs expected by GineEncoder.

    Returns:
        ``None`` if the graph is missing or malformed, otherwise a dict with
        ``z`` (long, [N]), ``chirality`` and ``formal_charge`` (float, [N]),
        ``edge_index`` (long, [2, E]) and ``edge_attr`` (float, [E, 3] holding
        bond_type, stereo, is_conjugated).

    Robustness notes:
      - A non-dict entry inside ``node_features`` becomes an all-zero node
        instead of being dropped, so node positions stay aligned with the
        indices used in ``edge_index``.
      - Edges that reference a node outside ``[0, N)`` are discarded together
        with their feature row.
      - Non-numeric edge feature values fall back to 0.0.
    """
    gf = _safe_json_load(graph_field)
    if not isinstance(gf, dict):
        return None

    node_features = gf.get("node_features", None)
    if not node_features or not isinstance(node_features, list):
        return None
    if not any(isinstance(nf, dict) for nf in node_features):
        return None

    atomic_nums, chirality_vals, formal_charges = [], [], []
    for nf in node_features:
        if not isinstance(nf, dict):
            nf = {}  # placeholder node keeps edge indices valid
        atomic_nums.append(_coerce_number(nf.get("atomic_num", nf.get("atomic_number", 0)), int, 0))
        chirality_vals.append(_coerce_number(nf.get("chirality", 0), float, 0.0))
        formal_charges.append(_coerce_number(nf.get("formal_charge", 0), float, 0.0))

    num_nodes = len(atomic_nums)
    srcs, dsts = _edge_lists_from_graph(gf)

    # Edge attributes (bond_type, stereo, is_conjugated) if present; else zeros
    edge_features_raw = gf.get("edge_features", None)
    has_feats = isinstance(edge_features_raw, list) and len(edge_features_raw) == len(srcs)

    kept_src, kept_dst, attr_rows = [], [], []
    for k, (u, v) in enumerate(zip(srcs, dsts)):
        if not (0 <= u < num_nodes and 0 <= v < num_nodes):
            continue
        kept_src.append(u)
        kept_dst.append(v)
        ef = edge_features_raw[k] if has_feats else None
        if isinstance(ef, dict):
            attr_rows.append(
                (
                    _coerce_number(ef.get("bond_type", 0), float, 0.0),
                    _coerce_number(ef.get("stereo", 0), float, 0.0),
                    1.0 if ef.get("is_conjugated", False) else 0.0,
                )
            )
        else:
            attr_rows.append((0.0, 0.0, 0.0))

    if kept_src:
        edge_index = torch.tensor([kept_src, kept_dst], dtype=torch.long)
        edge_attr = torch.tensor(attr_rows, dtype=torch.float)
    else:
        edge_index = torch.empty((2, 0), dtype=torch.long)
        edge_attr = torch.zeros((0, 3), dtype=torch.float)

    return {
        "z": torch.tensor(atomic_nums, dtype=torch.long),
        "chirality": torch.tensor(chirality_vals, dtype=torch.float),
        "formal_charge": torch.tensor(formal_charges, dtype=torch.float),
        "edge_index": edge_index,
        "edge_attr": edge_attr,
    }
def _parse_geometry_for_schnet(geom_field):
    """
    Convert stored 'geometry' JSON blob into SchNet inputs:
      - atomic_numbers -> z   (long tensor, [N])
      - coordinates    -> pos (float tensor, [N, 3])

    Returns None if missing/malformed. "Malformed" includes ragged or
    non-numeric lists, coordinates that are not N x 3, and non-finite
    coordinates; these previously raised or slipped through.
    """
    # Already-decoded dicts are used directly (the JSON helper would return them unchanged).
    gf = geom_field if isinstance(geom_field, dict) else _safe_json_load(geom_field)
    if not isinstance(gf, dict):
        return None

    conf = gf.get("best_conformer", None)
    if not isinstance(conf, dict):
        return None

    atomic = conf.get("atomic_numbers", [])
    coords = conf.get("coordinates", [])
    if not (isinstance(atomic, list) and isinstance(coords, list)):
        return None
    if len(atomic) == 0 or len(atomic) != len(coords):
        return None

    try:
        z = torch.tensor(atomic, dtype=torch.long)
        pos = torch.tensor(coords, dtype=torch.float)
    except (TypeError, ValueError, RuntimeError):
        return None
    if z.dim() != 1 or pos.dim() != 2 or pos.size(-1) != 3:
        return None
    if not bool(torch.isfinite(pos).all()):
        return None
    return {"z": z, "pos": pos}


def _normalize_fp_bit(b) -> int:
    """Map one raw fingerprint entry to 0 or 1 (unknown or invalid entries become 0)."""
    if isinstance(b, str):
        bc = b.strip().strip('"').strip("'")
        return 1 if bc in ("1", "True", "true") else 0
    if isinstance(b, np.bool_):
        return int(bool(b))
    if isinstance(b, (int, np.integer, float, np.floating)):
        try:
            return 1 if int(b) != 0 else 0
        except (ValueError, OverflowError):
            # NaN / inf cannot be truncated to an int.
            return 0
    return 0


def _parse_fingerprints(fp_field, fp_len: int = 2048):
    """
    Parse a fingerprint field (either list or dict containing 'morgan_r3_bits') into:
      - input_ids: LongTensor [fp_len] with 0/1 bits
      - attention_mask: BoolTensor [fp_len] all True

    Missing or unusable fingerprints produce an all-zero ``input_ids``. Inputs
    longer than ``fp_len`` are truncated; shorter ones are zero-padded.
    """
    fp = fp_field if isinstance(fp_field, (dict, list)) else _safe_json_load(fp_field)

    bits = None
    if isinstance(fp, dict):
        bits = fp.get("morgan_r3_bits", None)
    elif isinstance(fp, list):
        bits = fp

    # Only sequences can be sliced/iterated as bit vectors; anything else counts as missing.
    if not isinstance(bits, (list, tuple, str)):
        bits = [0] * fp_len
    else:
        norm = [_normalize_fp_bit(b) for b in bits[:fp_len]]
        if len(norm) < fp_len:
            norm.extend([0] * (fp_len - len(norm)))
        bits = norm

    return {
        "input_ids": torch.tensor(bits, dtype=torch.long),
        "attention_mask": torch.ones(fp_len, dtype=torch.bool),
    }


# =============================================================================
# PSELFIES utilities (polymer-safe SELFIES encoding with endpoint markers)
# =============================================================================
_SELFIES_TOKEN_RE = re.compile(r"\[[^\[\]]+\]")


def _split_selfies_tokens(selfies_str: str) -> List[str]:
    """Split a SELFIES string into tokens; prefers selfies.split_selfies if available.

    Returns an empty list for non-string or empty input. When the ``selfies``
    package is unavailable or its splitter raises, a bracket-matching regex is
    used instead.
    """
    if not isinstance(selfies_str, str) or len(selfies_str) == 0:
        return []
    if SELFIES_AVAILABLE:
        try:
            toks = list(sf.split_selfies(selfies_str.replace(" ", "")))
            return [t for t in toks if isinstance(t, str) and t]
        except Exception:
            # Deliberate best-effort: fall through to the regex splitter.
            pass
    return _SELFIES_TOKEN_RE.findall(selfies_str)
def _selfies_for_tokenizer(selfies_str: str) -> str:
    """Normalize SELFIES formatting so the HF tokenizer sees token boundaries.

    All spaces are removed and a single space is re-inserted between adjacent
    bracket tokens (``"][" -> "] ["``). Blank input yields ``""``.
    """
    s = str(selfies_str).strip()
    if not s:
        return ""
    s = s.replace(" ", "")
    s = s.replace("][", "] [")
    return s


def _selfies_compact(selfies_str: str) -> str:
    """Remove spaces and trim."""
    return str(selfies_str).replace(" ", "").strip()


def _ensure_two_at_endpoints(selfies_str: str) -> str:
    """
    Ensure polymer endpoints exist: enforce exactly two [At] tokens (one at each end).
    This is used as a polymerization marker compatible with the At-based conversion.

    Rules:
      - no [At]: one is added at each end;
      - one [At]: the missing one is added at the opposite end (prepended when
        the existing marker is already the last token, appended otherwise);
      - more than two: only the first and last occurrences are kept;
      - exactly two: returned unchanged (compacted).
    Strings that yield no tokens are returned compacted and otherwise untouched.
    """
    s = _selfies_compact(selfies_str)
    toks = _split_selfies_tokens(s)
    if not toks:
        return s

    at = "[At]"
    at_pos = [i for i, t in enumerate(toks) if t == at]

    if len(at_pos) == 0:
        toks = [at] + toks + [at]
    elif len(at_pos) == 1:
        # Appending after a trailing marker would produce "...[At][At]",
        # i.e. both endpoints on the same side of the chain.
        if len(toks) > 1 and at_pos[0] == len(toks) - 1:
            toks = [at] + toks
        else:
            toks = toks + [at]
    elif len(at_pos) > 2:
        keep = {at_pos[0], at_pos[-1]}
        toks = [t for i, t in enumerate(toks) if t != at or i in keep]
    return "".join(toks)


def psmiles_to_at_smiles(psmiles: str, root_at: bool = True) -> Optional[str]:
    """
    Convert polymer PSMILES (two [*]) into RDKit SMILES where [*] is represented as element At (Z=85).
    This allows SELFIES encoding/decoding while preserving polymer endpoints.

    Args:
        psmiles: polymer SMILES containing ``*`` dummy atoms.
        root_at: when True, the canonical SMILES is rooted at the first converted atom.

    Returns:
        Canonical At-SMILES, or ``None`` when RDKit is unavailable, parsing
        fails, or sanitization reports a failure.
    """
    if not RDKit_AVAILABLE:
        return None
    try:
        mol = Chem.MolFromSmiles(psmiles)
        if mol is None:
            return None
        mol = Chem.RWMol(mol)
        at_indices = []
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() != 0:
                continue
            atom.SetAtomicNum(85)
            # Each setter is best-effort; a failing one must not abort the conversion.
            for setter, value in (
                (atom.SetNoImplicit, True),
                (atom.SetNumExplicitHs, 0),
                (atom.SetFormalCharge, 0),
            ):
                try:
                    setter(value)
                except Exception:
                    pass
            at_indices.append(int(atom.GetIdx()))

        mol = mol.GetMol()
        try:
            # With catchErrors=True RDKit reports failure through the return
            # value instead of raising, so the flag has to be checked.
            failed = Chem.SanitizeMol(mol, catchErrors=True)
        except Exception:
            return None
        if failed != Chem.SanitizeFlags.SANITIZE_NONE:
            return None

        if root_at and len(at_indices) > 0:
            try:
                can = Chem.MolToSmiles(mol, canonical=True, rootedAtAtom=at_indices[0])
            except Exception:
                can = Chem.MolToSmiles(mol, canonical=True)
        else:
            can = Chem.MolToSmiles(mol, canonical=True)
        return can
    except Exception:
        return None


def at_smiles_to_psmiles(at_smiles: str) -> Optional[str]:
    """Inverse of psmiles_to_at_smiles: convert At (Z=85) back to polymer [*] endpoints.

    Returns the canonical SMILES with ``[*]`` rewritten as ``*``, or ``None``
    when RDKit is unavailable, parsing fails, or sanitization reports a failure.
    """
    if not RDKit_AVAILABLE:
        return None
    try:
        mol = Chem.MolFromSmiles(at_smiles)
        if mol is None:
            return None
        rw = Chem.RWMol(mol)
        for atom in rw.GetAtoms():
            if atom.GetAtomicNum() != 85:
                continue
            atom.SetAtomicNum(0)
            for setter, value in (
                (atom.SetNoImplicit, True),
                (atom.SetNumExplicitHs, 0),
                (atom.SetFormalCharge, 0),
            ):
                try:
                    setter(value)
                except Exception:
                    pass

        mol2 = rw.GetMol()
        try:
            failed = Chem.SanitizeMol(mol2, catchErrors=True)
        except Exception:
            return None
        if failed != Chem.SanitizeFlags.SANITIZE_NONE:
            return None

        can = Chem.MolToSmiles(mol2, canonical=True)
        can = can.replace("[*]", "*")
        return can
    except Exception:
        return None


def smiles_to_pselfies(smiles: str) -> Optional[str]:
    """Encode RDKit-canonical SMILES into SELFIES.

    Returns ``None`` when RDKit or selfies is unavailable, the SMILES does not
    parse, or the encoder fails or returns an empty result.
    """
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        return None
    try:
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            return None
        can = Chem.MolToSmiles(mol, canonical=True)
        s = sf.encoder(can)
        if not isinstance(s, str) or len(s) == 0:
            return None
        return s
    except Exception:
        return None
def psmiles_to_pselfies(psmiles: str) -> Optional[str]:
    """Convert polymer PSMILES -> At-SMILES -> PSELFIES, ensuring endpoint markers.

    Returns ``None`` when RDKit or selfies is unavailable or any conversion
    step fails.
    """
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        return None
    at_smiles = psmiles_to_at_smiles(psmiles, root_at=True)
    if at_smiles is None:
        return None
    s = smiles_to_pselfies(at_smiles)
    if s is None:
        return None
    return _ensure_two_at_endpoints(s)


def selfies_to_smiles(selfies_str: str) -> Optional[str]:
    """Decode SELFIES -> SMILES and canonicalize with RDKit.

    Returns ``None`` when a dependency is missing, decoding yields nothing,
    RDKit cannot parse the result, or sanitization reports a failure.
    """
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        return None
    try:
        s = _selfies_compact(selfies_str)
        smi = sf.decoder(s)
        if not isinstance(smi, str) or len(smi) == 0:
            return None
        mol = Chem.MolFromSmiles(smi)
        if mol is None:
            return None
        try:
            # catchErrors=True reports failure via the return value, not an exception.
            failed = Chem.SanitizeMol(mol, catchErrors=True)
        except Exception:
            return None
        if failed != Chem.SanitizeFlags.SANITIZE_NONE:
            return None
        can = Chem.MolToSmiles(mol, canonical=True)
        return can
    except Exception:
        return None


def pselfies_to_psmiles(selfies_str: str) -> Optional[str]:
    """Decode PSELFIES -> At-SMILES -> polymer PSMILES."""
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        return None
    at_smiles = selfies_to_smiles(selfies_str)
    if at_smiles is None:
        return None
    return at_smiles_to_psmiles(at_smiles)


def canonicalize_psmiles(psmiles: str) -> Optional[str]:
    """RDKit-canonicalize PSMILES (best effort).

    Returns ``None`` for blank input, unparsable SMILES or a reported
    sanitization failure. Without RDKit the stripped input is returned as-is.
    ``[*]`` is rewritten as ``*`` in the canonical form.
    """
    psmiles = str(psmiles).strip()
    if not psmiles:
        return None
    if not RDKit_AVAILABLE:
        return psmiles
    try:
        mol = Chem.MolFromSmiles(psmiles)
        if mol is None:
            return None
        try:
            failed = Chem.SanitizeMol(mol, catchErrors=True)
        except Exception:
            return None
        if failed != Chem.SanitizeFlags.SANITIZE_NONE:
            return None
        can = Chem.MolToSmiles(mol, canonical=True)
        can = can.replace("[*]", "*")
        return can
    except Exception:
        return None


def chem_validity_psmiles(psmiles: str) -> bool:
    """Basic chemical validity check via RDKit parse + sanitize.

    Returns ``False`` when RDKit is unavailable, the string is blank, parsing
    fails, or sanitization reports a failure.
    """
    if not RDKit_AVAILABLE:
        return False
    try:
        s = str(psmiles).strip()
        if not s:
            return False
        mol = Chem.MolFromSmiles(s)
        if mol is None:
            return False
        try:
            failed = Chem.SanitizeMol(mol, catchErrors=True)
        except Exception:
            return False
        return failed == Chem.SanitizeFlags.SANITIZE_NONE
    except Exception:
        return False
def polymer_validity_psmiles_strict(psmiles: str) -> bool:
    """
    Strict polymer validity:
      - exactly two [*] atoms
      - each [*] has degree 1 (a single attachment)

    Returns ``False`` when RDKit is unavailable, the input is blank or
    unparsable, or sanitization reports a failure.
    """
    if not RDKit_AVAILABLE:
        return False
    try:
        s = str(psmiles).strip()
        if not s:
            return False
        mol = Chem.MolFromSmiles(s)
        if mol is None:
            return False
        try:
            # catchErrors=True reports failure via the return value, not an exception.
            failed = Chem.SanitizeMol(mol, catchErrors=True)
        except Exception:
            return False
        if failed != Chem.SanitizeFlags.SANITIZE_NONE:
            return False

        stars = [a for a in mol.GetAtoms() if a.GetAtomicNum() == 0]
        if len(stars) != 2:
            return False
        return all(a.GetTotalDegree() == 1 for a in stars)
    except Exception:
        return False


# =============================================================================
# CL encoder (multimodal) + fusion pooling
# =============================================================================
def resolve_cl_checkpoint_path(cl_weights_dir: str) -> Optional[str]:
    """Resolve a checkpoint file inside a directory (or accept a file path directly).

    Lookup order inside a directory: ``pytorch_model.bin``, ``model.pt``,
    ``best.pt``, ``state_dict.pt``, then the alphabetically first ``*.bin``
    and finally the alphabetically first ``*.pt``.

    Returns:
        The resolved path, or ``None`` when the argument is ``None``, does not
        exist, or the directory holds no candidate file.
    """
    if cl_weights_dir is None:
        return None
    if os.path.isfile(cl_weights_dir):
        return cl_weights_dir
    if not os.path.isdir(cl_weights_dir):
        return None

    preferred = ("pytorch_model.bin", "model.pt", "best.pt", "state_dict.pt")
    for filename in preferred:
        candidate = os.path.join(cl_weights_dir, filename)
        if os.path.isfile(candidate):
            return candidate

    for pattern in ("*.bin", "*.pt"):
        matches = sorted(Path(cl_weights_dir).glob(pattern))
        if matches:
            return str(matches[0])
    return None


def load_state_dict_any(ckpt_path: str) -> Dict[str, torch.Tensor]:
    """Load a checkpoint that may wrap the model state dict under common keys.

    The wrapper keys ``state_dict``, ``model_state_dict`` and ``model`` are
    unwrapped (in that order) when their value is a dict; otherwise the loaded
    dict itself is returned.

    Raises:
        RuntimeError: when the file does not contain a dict.
    """
    # SECURITY NOTE: torch.load unpickles the file. Depending on the installed
    # torch version this may execute arbitrary code from an untrusted
    # checkpoint; only load files from trusted sources (or pass
    # weights_only=True where the torch version supports it).
    obj = torch.load(ckpt_path, map_location="cpu")
    if not isinstance(obj, dict):
        raise RuntimeError(f"Checkpoint at {ckpt_path} did not contain a state_dict-like dict.")
    for wrapper_key in ("state_dict", "model_state_dict", "model"):
        inner = obj.get(wrapper_key)
        if isinstance(inner, dict):
            return inner
    return obj
def safe_load_into_module(module: nn.Module, sd: Dict[str, torch.Tensor], strict: bool = False) -> Tuple[int, int]:
    """Load a (possibly partial) state dict and return counts of missing/unexpected keys.

    Args:
        module: target module.
        sd: state dict to load.
        strict: forwarded to ``module.load_state_dict``.

    Returns:
        ``(num_missing_keys, num_unexpected_keys)`` as reported by ``load_state_dict``.
    """
    incompatible = module.load_state_dict(sd, strict=strict)
    missing = getattr(incompatible, "missing_keys", [])
    unexpected = getattr(incompatible, "unexpected_keys", [])
    return len(missing), len(unexpected)


class PolyFusionModule(nn.Module):
    """
    Tiny fusion transformer:
      - self-attention over modality tokens
      - learned query pooling (attention weights -> pooled representation)

    Args:
        d_model: token width.
        nhead: number of attention heads.
        ffn_mult: hidden-width multiplier of the feed-forward block.
        dropout: dropout probability used in attention and feed-forward layers.
    """

    def __init__(self, d_model: int, nhead: int = 8, ffn_mult: int = 4, dropout: float = 0.1):
        super().__init__()
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
        self.ln2 = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(
            nn.Linear(d_model, ffn_mult * d_model),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ffn_mult * d_model, d_model),
            nn.Dropout(dropout),
        )
        self.pool_ln = nn.LayerNorm(d_model)
        self.pool_q = nn.Parameter(torch.randn(d_model))

    def forward(self, x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Fuse modality tokens into one vector per sample.

        Args:
            x: token tensor of shape [B, T, d_model].
            mask: boolean tensor [B, T]; True marks a valid token.

        Returns:
            Pooled tensor of shape [B, d_model].
        """
        # mask: True for valid tokens; MultiheadAttention uses key_padding_mask where True means "ignore"
        key_padding = ~mask
        h = self.ln1(x)
        attn_out, _ = self.attn(h, h, h, key_padding_mask=key_padding)
        x = x + attn_out
        x = x + self.ffn(self.ln2(x))

        # query pooling
        x = self.pool_ln(x)
        q = self.pool_q.unsqueeze(0).unsqueeze(-1)  # [1, d, 1]
        scores = torch.matmul(x, q).squeeze(-1)  # [B, T]
        # Use the dtype's own minimum: a literal -1e9 is not representable in
        # float16 (the AMP dtype configured for this script) and would overflow.
        scores = scores.masked_fill(~mask, torch.finfo(scores.dtype).min)
        w = torch.softmax(scores, dim=-1).unsqueeze(-1)
        pooled = (x * w).sum(dim=1)  # [B, d]
        return pooled
class MultiModalCLPolymerEncoder(nn.Module):
    """
    Frozen multimodal encoder used as the conditioning interface:
      - encodes any subset of modalities (graph/geometry/fingerprint/psmiles)
      - projects each modality into a shared CL embedding space
      - fuses available modality tokens into a single normalized vector

    Construction builds the requested unimodal encoders, one linear projection
    per available encoder, the fusion module, and finally tries to load a
    multimodal checkpoint from ``cl_weights_dir``.
    """

    def __init__(
        self,
        psmiles_tokenizer,
        emb_dim: int = CL_EMB_DIM,
        cl_weights_dir: Optional[str] = CFG.PRETRAINED_MULTIMODAL_DIR,
        use_gine: bool = True,
        use_schnet: bool = True,
        use_fp: bool = True,
        use_psmiles: bool = True,
    ):
        super().__init__()
        self.psm_tok = psmiles_tokenizer
        self.emb_dim = int(emb_dim)

        self.gine = None
        self.schnet = None
        self.fp = None
        self.psmiles = None

        # GINE / SchNet / fingerprint encoders are optional: a constructor
        # failure only disables that modality (with a printed warning).
        if use_gine:
            try:
                self.gine = GineEncoder(NODE_EMB_DIM, EDGE_EMB_DIM, NUM_GNN_LAYERS, MAX_ATOMIC_Z)
            except Exception as e:
                print(f"[CL][WARN] Disabling GINE encoder: {e}")
                self.gine = None

        if use_schnet:
            try:
                self.schnet = NodeSchNetWrapper(
                    SCHNET_HIDDEN, SCHNET_NUM_INTERACTIONS, SCHNET_NUM_GAUSSIANS, SCHNET_CUTOFF, SCHNET_MAX_NEIGHBORS
                )
            except Exception as e:
                print(f"[CL][WARN] Disabling SchNet encoder: {e}")
                self.schnet = None

        if use_fp:
            try:
                self.fp = FingerprintEncoder(VOCAB_SIZE_FP, 256, FP_LENGTH, 4, 8, 1024, 0.1)
            except Exception as e:
                print(f"[CL][WARN] Disabling fingerprint encoder: {e}")
                self.fp = None

        # NOTE(review): unlike the three encoders above, the PSMILES encoder is
        # built without a try/except, so a failure here propagates to the caller.
        if use_psmiles:
            enc_src = CFG.BEST_PSMILES_DIR if (CFG.BEST_PSMILES_DIR and os.path.isdir(CFG.BEST_PSMILES_DIR)) else None
            self.psmiles = PSMILESDebertaEncoder(
                model_dir_or_name=enc_src,
                vocab_fallback=int(getattr(psmiles_tokenizer, "vocab_size", 300)),
            )

        # Projection layers into shared CL space
        # (the fingerprint projection's input width 256 matches the second
        # positional argument passed to FingerprintEncoder above).
        self.proj_gine = nn.Linear(NODE_EMB_DIM, self.emb_dim) if self.gine is not None else None
        self.proj_schnet = nn.Linear(SCHNET_HIDDEN, self.emb_dim) if self.schnet is not None else None
        self.proj_fp = nn.Linear(256, self.emb_dim) if self.fp is not None else None
        self.proj_psmiles = nn.Linear(DEBERTA_HIDDEN, self.emb_dim) if self.psmiles is not None else None

        self.dropout = nn.Dropout(0.1)
        self.out_dim = self.emb_dim

        self.fusion = PolyFusionModule(d_model=self.emb_dim, nhead=8, ffn_mult=4, dropout=0.1)

        # Optionally load a trained multimodal CL checkpoint
        self._load_multimodal_cl_checkpoint(cl_weights_dir)

    def _load_multimodal_cl_checkpoint(self, cl_weights_dir: Optional[str]) -> None:
        """Load matching weights from a multimodal checkpoint, if one can be resolved.

        Only keys that exist in this module with an identical shape are loaded;
        everything else keeps its initialized value. Prints a summary either way.
        """
        ckpt_path = resolve_cl_checkpoint_path(cl_weights_dir) if cl_weights_dir else None
        if ckpt_path is None:
            print(f"[CL][INFO] No multimodal CL checkpoint found at '{cl_weights_dir}'. Using initialized weights.")
            return

        sd = load_state_dict_any(ckpt_path)
        model_sd = self.state_dict()

        # Load only compatible keys (shape match) to be robust across versions
        filtered = {}
        for k, v in sd.items():
            if k not in model_sd:
                continue
            if hasattr(v, "shape") and hasattr(model_sd[k], "shape") and tuple(v.shape) != tuple(model_sd[k].shape):
                continue
            filtered[k] = v

        # Because `filtered` only holds keys present in this module, the
        # "unexpected" count reported below is computed on the filtered dict,
        # not on the raw checkpoint.
        missing, unexpected = safe_load_into_module(self, filtered, strict=False)
        print(
            f"[CL][INFO] Loaded multimodal CL checkpoint '{ckpt_path}'. "
            f"loaded_keys={len(filtered)} missing={missing} unexpected={unexpected}"
        )

    def freeze_cl_encoders(self) -> None:
        """Freeze encoders and fusion module (decoder training should not update them).

        NOTE(review): the ``proj_*`` linear layers are not touched here and keep
        ``requires_grad=True`` — confirm whether that is intended.
        """
        for name, enc in [("gine", self.gine), ("schnet", self.schnet), ("fp", self.fp), ("psmiles", self.psmiles)]:
            if enc is None:
                continue
            enc.eval()
            for p in enc.parameters():
                p.requires_grad = False
            print(f"[CL][INFO] Froze {name} encoder parameters.")

        self.fusion.eval()
        for p in self.fusion.parameters():
            p.requires_grad = False
        print("[CL][INFO] Froze fusion module parameters.")

    def forward_multimodal(self, batch_mods: dict) -> torch.Tensor:
        """Encode a batch containing any subset of modalities and return normalized CL embeddings.

        ``batch_mods`` may hold the keys "gine", "schnet", "fp" and "psmiles",
        each either ``None`` or a dict of tensors. Every available modality
        contributes one projected token; the tokens are fused and L2-normalized.
        When no modality yields a token, a zero tensor of shape [B, emb_dim] is
        passed through ``F.normalize`` and returned.
        """
        device = next(self.parameters()).device

        # Infer batch size from whichever modality is present
        # (priority: fingerprint, PSMILES, then max graph/structure index + 1).
        if batch_mods.get("fp", None) is not None and isinstance(batch_mods["fp"].get("input_ids", None), torch.Tensor):
            B = int(batch_mods["fp"]["input_ids"].size(0))
        elif batch_mods.get("psmiles", None) is not None and isinstance(batch_mods["psmiles"].get("input_ids", None), torch.Tensor):
            B = int(batch_mods["psmiles"]["input_ids"].size(0))
        else:
            if batch_mods.get("gine", None) is not None and isinstance(batch_mods["gine"].get("batch", None), torch.Tensor):
                B = int(batch_mods["gine"]["batch"].max().item() + 1) if batch_mods["gine"]["batch"].numel() > 0 else 1
            elif batch_mods.get("schnet", None) is not None and isinstance(batch_mods["schnet"].get("batch", None), torch.Tensor):
                B = int(batch_mods["schnet"]["batch"].max().item() + 1) if batch_mods["schnet"]["batch"].numel() > 0 else 1
            else:
                B = 1

        tokens: List[torch.Tensor] = []

        def _append_token(z_token: torch.Tensor):
            tokens.append(z_token)

        # GINE token
        if self.gine is not None and batch_mods.get("gine", None) is not None:
            g = batch_mods["gine"]
            if isinstance(g.get("z", None), torch.Tensor) and g["z"].numel() > 0:
                emb_g = self.gine(
                    g["z"].to(device),
                    g.get("chirality", torch.zeros_like(g["z"], dtype=torch.float)).to(device)
                    if isinstance(g.get("chirality", None), torch.Tensor)
                    else None,
                    g.get("formal_charge", torch.zeros_like(g["z"], dtype=torch.float)).to(device)
                    if isinstance(g.get("formal_charge", None), torch.Tensor)
                    else None,
                    g.get("edge_index", torch.empty((2, 0), dtype=torch.long)).to(device),
                    g.get("edge_attr", torch.zeros((0, 3), dtype=torch.float)).to(device),
                    g.get("batch", None).to(device) if isinstance(g.get("batch", None), torch.Tensor) else None,
                )
                zg = self.proj_gine(emb_g)
                zg = self.dropout(zg)
                _append_token(zg)

        # SchNet token
        if self.schnet is not None and batch_mods.get("schnet", None) is not None:
            s = batch_mods["schnet"]
            if isinstance(s.get("z", None), torch.Tensor) and s["z"].numel() > 0:
                emb_s = self.schnet(
                    s["z"].to(device),
                    s["pos"].to(device),
                    s.get("batch", None).to(device) if isinstance(s.get("batch", None), torch.Tensor) else None,
                )
                zs = self.proj_schnet(emb_s)
                zs = self.dropout(zs)
                _append_token(zs)

        # Fingerprint token
        if self.fp is not None and batch_mods.get("fp", None) is not None:
            f = batch_mods["fp"]
            if isinstance(f.get("input_ids", None), torch.Tensor) and f["input_ids"].numel() > 0:
                emb_f = self.fp(
                    f["input_ids"].to(device),
                    f.get("attention_mask", None).to(device) if isinstance(f.get("attention_mask", None), torch.Tensor) else None,
                )
                zf = self.proj_fp(emb_f)
                zf = self.dropout(zf)
                _append_token(zf)

        # PSMILES token
        # NOTE(review): the attention mask is forwarded as-is here, whereas
        # encode_psmiles converts it with .bool() — confirm the encoder accepts both.
        if self.psmiles is not None and batch_mods.get("psmiles", None) is not None:
            p = batch_mods["psmiles"]
            if isinstance(p.get("input_ids", None), torch.Tensor) and p["input_ids"].numel() > 0:
                emb_p = self.psmiles(
                    p["input_ids"].to(device),
                    p.get("attention_mask", None).to(device) if isinstance(p.get("attention_mask", None), torch.Tensor) else None,
                )
                zp = self.proj_psmiles(emb_p)
                zp = self.dropout(zp)
                _append_token(zp)

        if not tokens:
            # No modalities present; return a safe zero vector
            z = torch.zeros((B, self.emb_dim), device=device)
            return F.normalize(z, dim=-1)

        # torch.stack needs every token tensor to have the same shape, and the
        # mask below marks every token of every sample as valid.
        # NOTE(review): encode_multimodal skips records without a graph/geometry,
        # so the graph/SchNet "batch" indices can have gaps or stop short of B-1.
        # Presumably the encoders then return fewer than B rows (or empty rows
        # that are still treated as valid tokens) — verify against GineEncoder /
        # NodeSchNetWrapper pooling before relying on mixed-availability batches.
        X = torch.stack(tokens, dim=1)  # [B, T, d]
        mask = torch.ones((B, X.size(1)), dtype=torch.bool, device=device)
        pooled = self.fusion(X, mask)
        pooled = F.normalize(pooled, dim=-1)
        return pooled

    @torch.no_grad()
    def encode_psmiles(
        self,
        psmiles_list: List[str],
        max_len: int = PSMILES_MAX_LEN,
        batch_size: int = 64,
        device: str = DEVICE,
    ) -> np.ndarray:
        """Encode PSMILES strings with the PSMILES branch only.

        The fusion module is not used here: each string is tokenized, encoded,
        projected with ``proj_psmiles`` and L2-normalized. The module is put in
        eval mode and moved to ``device`` as a side effect.

        Returns:
            Array of shape [len(psmiles_list), emb_dim] (empty input gives a
            float32 array with zero rows).

        Raises:
            RuntimeError: when the tokenizer, encoder or projection is missing.
        """
        self.eval()
        if self.psm_tok is None or self.psmiles is None or self.proj_psmiles is None:
            raise RuntimeError("PSMILES tokenizer/encoder/projection not available.")

        dev = torch.device(device)
        self.to(dev)

        outs = []
        for i in range(0, len(psmiles_list), batch_size):
            chunk = [str(x) for x in psmiles_list[i : i + batch_size]]
            enc = self.psm_tok(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
            input_ids = enc["input_ids"].to(dev)
            attn = enc["attention_mask"].to(dev).bool()
            emb_p = self.psmiles(input_ids, attn)
            z = self.proj_psmiles(emb_p)
            z = F.normalize(z, dim=-1)
            outs.append(z.detach().cpu().numpy())

        return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)

    @torch.no_grad()
    def encode_multimodal(self, records: List[dict], batch_size: int = 32, device: str = DEVICE) -> np.ndarray:
        """Encode a list of records that may contain any subset of modalities.

        Each record is read through the keys "psmiles", "fingerprints", "graph"
        and "geometry". Records are processed in chunks of ``batch_size``; the
        module is put in eval mode and moved to ``device`` as a side effect.

        Returns:
            Array of fused embeddings, one row per record (empty input gives a
            float32 array with zero rows).
        """
        self.eval()
        dev = torch.device(device)
        self.to(dev)

        outs = []
        for i in range(0, len(records), batch_size):
            chunk = records[i : i + batch_size]

            # PSMILES tokenization
            # (a missing "psmiles" value is tokenized as the empty string)
            psmiles_texts = [str(r.get("psmiles", "")) for r in chunk]
            p_enc = None
            if self.psm_tok is not None:
                p_enc = self.psm_tok(
                    psmiles_texts,
                    truncation=True,
                    padding="max_length",
                    max_length=PSMILES_MAX_LEN,
                    return_tensors="pt",
                )

            # Fingerprints
            # (_parse_fingerprints always returns tensors, so every record has a row)
            fp_ids, fp_attn = [], []
            for r in chunk:
                f = _parse_fingerprints(r.get("fingerprints", None), fp_len=FP_LENGTH)
                fp_ids.append(f["input_ids"])
                fp_attn.append(f["attention_mask"])
            fp_ids = torch.stack(fp_ids, dim=0)
            fp_attn = torch.stack(fp_attn, dim=0)

            # GINE batch assembly (concat nodes; keep per-graph batch indices)
            gine_all = {"z": [], "chirality": [], "formal_charge": [], "edge_index": [], "edge_attr": [], "batch": []}
            node_offset = 0
            for bi, r in enumerate(chunk):
                g = _parse_graph_for_gine(r.get("graph", None))
                if g is None or g["z"].numel() == 0:
                    # Record contributes no nodes; its index `bi` is simply absent from "batch".
                    continue
                n = g["z"].size(0)
                gine_all["z"].append(g["z"])
                gine_all["chirality"].append(g["chirality"])
                gine_all["formal_charge"].append(g["formal_charge"])
                gine_all["batch"].append(torch.full((n,), bi, dtype=torch.long))

                ei = g["edge_index"]
                ea = g["edge_attr"]
                if ei is not None and ei.numel() > 0:
                    # Shift local node ids into the concatenated node numbering.
                    gine_all["edge_index"].append(ei + node_offset)
                    gine_all["edge_attr"].append(ea)
                node_offset += n

            gine_batch = None
            if len(gine_all["z"]) > 0:
                z_b = torch.cat(gine_all["z"], dim=0)
                ch_b = torch.cat(gine_all["chirality"], dim=0)
                fc_b = torch.cat(gine_all["formal_charge"], dim=0)
                b_b = torch.cat(gine_all["batch"], dim=0)
                if len(gine_all["edge_index"]) > 0:
                    ei_b = torch.cat(gine_all["edge_index"], dim=1)
                    ea_b = torch.cat(gine_all["edge_attr"], dim=0)
                else:
                    ei_b = torch.empty((2, 0), dtype=torch.long)
                    ea_b = torch.zeros((0, 3), dtype=torch.float)
                gine_batch = {
                    "z": z_b,
                    "chirality": ch_b,
                    "formal_charge": fc_b,
                    "edge_index": ei_b,
                    "edge_attr": ea_b,
                    "batch": b_b,
                }

            # SchNet batch assembly (concat atoms; keep per-structure batch indices)
            sch_all_z, sch_all_pos, sch_all_batch = [], [], []
            for bi, r in enumerate(chunk):
                s = _parse_geometry_for_schnet(r.get("geometry", None))
                if s is None or s["z"].numel() == 0:
                    continue
                n = s["z"].size(0)
                sch_all_z.append(s["z"])
                sch_all_pos.append(s["pos"])
                sch_all_batch.append(torch.full((n,), bi, dtype=torch.long))

            schnet_batch = None
            if len(sch_all_z) > 0:
                schnet_batch = {
                    "z": torch.cat(sch_all_z, dim=0),
                    "pos": torch.cat(sch_all_pos, dim=0),
                    "batch": torch.cat(sch_all_batch, dim=0),
                }

            batch_mods = {
                "gine": gine_batch,
                "schnet": schnet_batch,
                "fp": {"input_ids": fp_ids, "attention_mask": fp_attn},
                "psmiles": {"input_ids": p_enc["input_ids"], "attention_mask": p_enc["attention_mask"]} if p_enc is not None else None,
            }

            z = self.forward_multimodal(batch_mods)
            outs.append(z.detach().cpu().numpy())

        return np.concatenate(outs, axis=0) if outs else np.zeros((0, self.emb_dim), dtype=np.float32)
# =============================================================================
# SELFIES-TED decoder conditioned on CL embeddings
# =============================================================================
# Hub model id and access token are read from the environment once, at import time.
SELFIES_TED_MODEL_NAME = os.environ.get("SELFIES_TED_MODEL_NAME", "ibm-research/materials.selfies-ted")
HF_TOKEN = os.environ.get("HF_TOKEN", None)


def _hf_load_with_retries(load_fn, max_tries: int = 5, base_sleep: float = 2.0):
    """Call ``load_fn`` until it succeeds, backing off between failed attempts.

    Args:
        load_fn: Zero-argument callable that performs the load and returns its result.
        max_tries: Maximum number of attempts.
        base_sleep: Base delay in seconds; attempt ``t`` (0-based) waits
            ``base_sleep * 1.6 ** t`` plus up to one second of random jitter.

    Returns:
        Whatever ``load_fn`` returns on the first successful attempt.

    Raises:
        RuntimeError: If every attempt failed. The last underlying exception is
            attached as ``__cause__``.
    """
    last_err = None
    for t in range(max_tries):
        try:
            return load_fn()
        except Exception as e:
            last_err = e
            if t + 1 >= max_tries:
                # Final attempt: nothing left to retry, so do not waste time sleeping.
                print(f"[HF][WARN] Load attempt {t+1}/{max_tries} failed: {e} | no retries left")
                break
            sleep_s = base_sleep * (1.6 ** t) + random.random()
            print(f"[HF][WARN] Load attempt {t+1}/{max_tries} failed: {e} | retrying in {sleep_s:.1f}s")
            time.sleep(sleep_s)
    raise RuntimeError(
        f"Failed to load model from HF after {max_tries} attempts. Last error: {last_err}"
    ) from last_err
def load_selfies_ted_and_tokenizer(model_name: str = SELFIES_TED_MODEL_NAME):
    """Fetch the SELFIES-TED tokenizer and seq2seq model, retrying each load.

    Returns:
        ``(tokenizer, model)`` as produced by ``AutoTokenizer`` and
        ``AutoModelForSeq2SeqLM``.
    """
    tokenizer = _hf_load_with_retries(
        lambda: AutoTokenizer.from_pretrained(model_name, token=HF_TOKEN, use_fast=True),
        max_tries=5,
    )
    seq2seq = _hf_load_with_retries(
        lambda: AutoModelForSeq2SeqLM.from_pretrained(model_name, token=HF_TOKEN),
        max_tries=5,
    )
    return tokenizer, seq2seq


class CLConditionedSelfiesTEDGenerator(nn.Module):
    """
    Seq2seq decoder driven by a CL embedding instead of an encoded input text.

    The CL vector is projected to the model width, repeated into ``mem_len``
    memory slots, offset by a learned per-slot embedding, and handed to the
    wrapped model as its ``encoder_outputs``.
    """

    def __init__(self, tok, seq2seq_model, cl_emb_dim: int = CL_EMB_DIM, mem_len: int = 4):
        super().__init__()
        self.tok = tok
        self.model = seq2seq_model
        self.mem_len = int(mem_len)

        # Fall back to 1024 when the wrapped config does not expose d_model.
        hidden = int(getattr(self.model.config, "d_model", 1024))
        projection_layers = [
            nn.Linear(cl_emb_dim, hidden),
            nn.Tanh(),
            nn.Dropout(0.1),
            nn.Linear(hidden, hidden),
        ]
        self.cl_to_d = nn.Sequential(*projection_layers)
        self.mem_pos = nn.Embedding(self.mem_len, hidden)

    def build_encoder_outputs(self, z: torch.Tensor) -> Tuple[BaseModelOutput, torch.Tensor]:
        """Turn a batch of CL latents into pseudo encoder states plus an all-ones mask."""
        batch_size = z.size(0)
        projected = self.cl_to_d(z)  # [B, d_model]
        memory = projected.unsqueeze(1).expand(batch_size, self.mem_len, projected.size(-1)).contiguous()
        slot_ids = torch.arange(self.mem_len, device=z.device).unsqueeze(0).expand(batch_size, -1)
        memory = memory + self.mem_pos(slot_ids)
        memory_mask = torch.ones((batch_size, self.mem_len), dtype=torch.long, device=z.device)
        return BaseModelOutput(last_hidden_state=memory), memory_mask

    def forward_train(self, z: torch.Tensor, labels: torch.Tensor) -> Dict[str, torch.Tensor]:
        """Run one teacher-forced pass; ``labels`` are the decoder targets.

        Returns a dict with the differentiable ``"loss"`` and a detached copy under ``"ce"``.
        """
        encoder_outputs, memory_mask = self.build_encoder_outputs(z)
        result = self.model(encoder_outputs=encoder_outputs, attention_mask=memory_mask, labels=labels)
        return {"loss": result.loss, "ce": result.loss.detach()}

    @torch.no_grad()
    def generate(
        self,
        z: torch.Tensor,
        num_return_sequences: int = 1,
        max_len: int = GEN_MAX_LEN,
        top_p: float = GEN_TOP_P,
        temperature: float = GEN_TEMPERATURE,
        repetition_penalty: float = GEN_REPETITION_PENALTY,
    ) -> List[str]:
        """Sample sequences for a batch of CL latents and post-process the decoded text.

        Note: switches the module to eval mode and leaves it there.
        """
        self.eval()
        z = z.to(next(self.parameters()).device)
        encoder_outputs, memory_mask = self.build_encoder_outputs(z)

        pad_id = self.tok.pad_token_id
        eos_id = self.tok.eos_token_id
        token_ids = self.model.generate(
            encoder_outputs=encoder_outputs,
            attention_mask=memory_mask,
            do_sample=True,
            top_p=float(top_p),
            temperature=float(temperature),
            repetition_penalty=float(repetition_penalty),
            num_return_sequences=int(num_return_sequences),
            max_length=int(max_len),
            min_length=int(GEN_MIN_LEN),
            pad_token_id=None if pad_id is None else int(pad_id),
            eos_token_id=None if eos_id is None else int(eos_id),
        )
        decoded = self.tok.batch_decode(token_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        return [_ensure_two_at_endpoints(_selfies_compact(text)) for text in decoded]


def create_optimizer_and_scheduler_decoder(model: CLConditionedSelfiesTEDGenerator):
    """Unfreeze every parameter and build the AdamW + CosineAnnealingLR pair.

    Returns:
        ``(optimizer, scheduler)``.
    """
    for param in model.parameters():
        param.requires_grad = True
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=NUM_EPOCHS, eta_min=COSINE_ETA_MIN)
    return optimizer, scheduler
# =============================================================================
# Datasets for latent-to-SELFIES training
# =============================================================================
class LatentToPSELFIESDataset(Dataset):
    """
    Pairs a CL embedding with its tokenized PSELFIES target.

    Each item holds:
      - z: CL embedding of the record (Gaussian noise added when latent_noise_std > 0)
      - labels: token ids of the PSELFIES target, with pad positions set to -100
      - psmiles / pselfies_raw: the source strings, kept for bookkeeping
    """

    def __init__(
        self,
        records: List[dict],
        cl_encoder: MultiModalCLPolymerEncoder,
        selfies_tok,
        max_len: int = GEN_MAX_LEN,
        latent_noise_std: float = 0.0,
        cache_embeddings: bool = True,
        renormalize_after_noise: bool = True,
        use_multimodal: bool = True,
    ):
        self.records = records
        self.cl_encoder = cl_encoder
        self.tok = selfies_tok
        self.max_len = int(max_len)
        self.latent_noise_std = float(latent_noise_std)
        self.renorm = bool(renormalize_after_noise)
        self.use_multimodal = bool(use_multimodal)

        # Pad id falls back to 1 when the tokenizer does not define one.
        pad_token = getattr(self.tok, "pad_token_id", None)
        self.pad_id = 1 if pad_token is None else int(pad_token)

        # Encoding everything up front avoids calling the encoder per item later.
        self._cache = None
        if cache_embeddings:
            if self.use_multimodal:
                embeddings = self.cl_encoder.encode_multimodal(self.records, batch_size=32, device=DEVICE)
            else:
                texts = [str(rec.get("psmiles", "")) for rec in self.records]
                embeddings = self.cl_encoder.encode_psmiles(
                    texts, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE
                )
            self._cache = embeddings.astype(np.float32)

    def __len__(self):
        return len(self.records)

    def _latent_for(self, idx, record):
        """Return the latent for one record, from the cache or by encoding it now."""
        if self._cache is not None:
            return torch.tensor(self._cache[idx], dtype=torch.float32)
        if self.use_multimodal:
            encoded = self.cl_encoder.encode_multimodal([record], batch_size=1, device=DEVICE)
        else:
            text = str(record.get("psmiles", "")).strip()
            encoded = self.cl_encoder.encode_psmiles([text], max_len=PSMILES_MAX_LEN, batch_size=1, device=DEVICE)
        return torch.tensor(encoded[0], dtype=torch.float32)

    def __getitem__(self, idx):
        record = self.records[idx]
        target_text = _selfies_for_tokenizer(str(record["pselfies"]).strip())

        latent = self._latent_for(idx, record)

        # NOTE(review): the source indentation was lost; renormalisation is assumed
        # to apply only when noise was added (per `renormalize_after_noise`) - confirm.
        if self.latent_noise_std > 0:
            latent = latent + torch.randn_like(latent) * self.latent_noise_std
            if self.renorm:
                latent = F.normalize(latent, dim=-1)

        # Pad positions become -100 so the cross-entropy loss ignores them.
        encoded = self.tok(
            target_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors=None,
        )
        token_ids = torch.tensor(encoded["input_ids"], dtype=torch.long)
        return {
            "z": latent,
            "labels": token_ids.masked_fill(token_ids == self.pad_id, -100),
            "psmiles": str(record.get("psmiles", "")).strip(),
            "pselfies_raw": _selfies_compact(record["pselfies"]),
        }


def latent_collate(batch: List[dict]) -> dict:
    """Stack per-item latents and labels; pass the string fields through as lists."""
    return {
        "z": torch.stack([item["z"] for item in batch], dim=0),
        "labels": torch.stack([item["labels"] for item in batch], dim=0),
        "psmiles": [item["psmiles"] for item in batch],
        "pselfies_raw": [item["pselfies_raw"] for item in batch],
    }


def move_latent_batch_to_device(batch: dict, device: str):
    """Move the tensor entries of a collated latent batch to ``device`` in place."""
    for key in ("z", "labels"):
        batch[key] = batch[key].to(device)


# =============================================================================
# Aux PSMILES property oracle (optional)
# =============================================================================
class PSMILESPropertyDataset(Dataset):
    """Regression dataset mapping a PSMILES string to one scalar target."""

    def __init__(self, samples: List[dict], psmiles_tokenizer, max_len: int = PSMILES_MAX_LEN):
        self.samples = samples
        self.tok = psmiles_tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        sample = self.samples[idx]
        text = str(sample.get("psmiles", "")).strip()
        # Prefer the scaled target; fall back to the raw one, then to 0.0.
        target = float(sample.get("target_scaled", sample.get("target", 0.0)))
        encoded = self.tok(text, truncation=True, padding="max_length", max_length=self.max_len)
        return {
            "input_ids": torch.tensor(encoded["input_ids"], dtype=torch.long),
            "attention_mask": torch.tensor(encoded["attention_mask"], dtype=torch.bool),
            "y": torch.tensor([target], dtype=torch.float32),
        }


def psmiles_prop_collate_fn(batch: List[dict]):
    """Stack the three tensor fields of ``PSMILESPropertyDataset`` items."""
    return {
        "input_ids": torch.stack([item["input_ids"] for item in batch], dim=0),
        "attention_mask": torch.stack([item["attention_mask"] for item in batch], dim=0),
        "y": torch.stack([item["y"] for item in batch], dim=0),
    }
class TextPropertyOracle(nn.Module):
    """
    Small regressor used for verification:
      - PSMILES encoder (DeBERTa variant)
      - MLP head producing ``y_dim`` outputs
    """

    def __init__(self, encoder_dir: Optional[str], vocab_size: Optional[int] = None, y_dim: int = 1):
        super().__init__()

        # Resolution order: explicit directory, configured checkpoint, hub model name.
        if encoder_dir is not None and os.path.isdir(encoder_dir):
            source = encoder_dir
        elif os.path.isdir(CFG.BEST_PSMILES_DIR):
            source = CFG.BEST_PSMILES_DIR
        else:
            source = "microsoft/deberta-v2-xlarge"

        fallback_vocab = 300 if vocab_size is None else int(vocab_size)
        self.encoder = PSMILESDebertaEncoder(model_dir_or_name=source, vocab_fallback=fallback_vocab)

        width = getattr(self.encoder, "out_dim", DEBERTA_HIDDEN)
        head_layers = [nn.Linear(width, 256), nn.ReLU(), nn.Dropout(0.1), nn.Linear(256, y_dim)]
        self.head = nn.Sequential(*head_layers)

    def forward(self, input_ids: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Encode the token ids and apply the regression head."""
        features = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        return self.head(features)


def move_prop_batch_to_device(batch: dict, device: str):
    """Move the tensor entries of a property batch to ``device`` in place."""
    for key in ("input_ids", "attention_mask", "y"):
        batch[key] = batch[key].to(device)


def train_prop_oracle_one_epoch(model: TextPropertyOracle, dl: DataLoader, opt, scaler_amp, device: str):
    """Train for one pass over ``dl``; return the sample-weighted mean smooth-L1 loss."""
    model.train()
    loss_sum = 0.0
    seen = 0
    for batch in dl:
        move_prop_batch_to_device(batch, device)
        targets = batch["y"]
        opt.zero_grad(set_to_none=True)

        # NOTE(review): source indentation was lost; the loss is assumed to be
        # computed inside the autocast context - confirm.
        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
            preds = model(batch["input_ids"], batch["attention_mask"])
            loss = F.smooth_l1_loss(preds, targets, beta=1.0)

        if USE_AMP:
            scaler_amp.scale(loss).backward()
            # Unscale first so the clipping threshold applies to true gradient norms.
            scaler_amp.unscale_(opt)
        else:
            loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        if USE_AMP:
            scaler_amp.step(opt)
            scaler_amp.update()
        else:
            opt.step()

        count = targets.size(0)
        loss_sum += float(loss.item()) * count
        seen += count
    return loss_sum / max(1, seen)


@torch.no_grad()
def eval_prop_oracle(model: TextPropertyOracle, dl: DataLoader, device: str):
    """Return the sample-weighted mean smooth-L1 loss over ``dl`` without updating weights."""
    model.eval()
    loss_sum = 0.0
    seen = 0
    for batch in dl:
        move_prop_batch_to_device(batch, device)
        targets = batch["y"]
        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
            preds = model(batch["input_ids"], batch["attention_mask"])
            loss = F.smooth_l1_loss(preds, targets, beta=1.0)
        count = targets.size(0)
        loss_sum += float(loss.item()) * count
        seen += count
    return loss_sum / max(1, seen)
def train_property_oracle_per_fold(
    train_samples: List[dict],
    val_samples: List[dict],
    psmiles_tokenizer,
    device: str,
    max_len: int = PSMILES_MAX_LEN,
) -> Optional[TextPropertyOracle]:
    """Fit the auxiliary scaled-property regressor for one fold.

    Only the MLP head is optimised; the encoder parameters are frozen. Training
    stops early after ``PROP_PRED_PATIENCE`` epochs without validation improvement
    and the best head weights are restored.

    Returns:
        The trained oracle (with ``aux_val_loss`` set), or ``None`` when no
        tokenizer is given or the model cannot be constructed.
    """
    if psmiles_tokenizer is None:
        return None

    try:
        model = TextPropertyOracle(
            encoder_dir=CFG.BEST_PSMILES_DIR if os.path.isdir(CFG.BEST_PSMILES_DIR) else None,
            vocab_size=getattr(psmiles_tokenizer, "vocab_size", None),
            y_dim=1,
        ).to(device)
    except Exception as e:
        print(f"[ORACLE][WARN] Could not initialize auxiliary property predictor: {e}")
        return None

    # Encoder stays fixed; only the head receives gradients.
    for param in model.encoder.parameters():
        param.requires_grad = False
    for param in model.head.parameters():
        param.requires_grad = True

    loader_kwargs = dict(
        batch_size=PROP_PRED_BATCH_SIZE,
        num_workers=NUM_WORKERS,
        pin_memory=True,
        collate_fn=psmiles_prop_collate_fn,
    )
    train_loader = DataLoader(
        PSMILESPropertyDataset(train_samples, psmiles_tokenizer, max_len=max_len),
        shuffle=True,
        **loader_kwargs,
    )
    val_loader = DataLoader(
        PSMILESPropertyDataset(val_samples, psmiles_tokenizer, max_len=max_len),
        shuffle=False,
        **loader_kwargs,
    )

    trainable = [param for param in model.parameters() if param.requires_grad]
    opt = torch.optim.AdamW(trainable, lr=PROP_PRED_LR, weight_decay=PROP_PRED_WEIGHT_DECAY)
    scaler_amp = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    best_val = float("inf")
    best_state = None
    stale_epochs = 0
    for _ in range(PROP_PRED_EPOCHS):
        train_prop_oracle_one_epoch(model, train_loader, opt, scaler_amp, device)
        val_loss = eval_prop_oracle(model, val_loader, device)

        if val_loss < best_val - 1e-8:
            best_val = val_loss
            stale_epochs = 0
            # Snapshot on CPU so the copy does not hold device memory.
            best_state = {name: tensor.detach().cpu().clone() for name, tensor in model.state_dict().items()}
            continue

        stale_epochs += 1
        if stale_epochs >= PROP_PRED_PATIENCE:
            break

    if best_state is not None:
        model.load_state_dict({name: tensor.to(device) for name, tensor in best_state.items()}, strict=False)

    # NOTE(review): source indentation was lost; this assignment is assumed to run
    # regardless of whether a best state was captured - confirm.
    try:
        model.aux_val_loss = float(best_val)
    except Exception:
        pass
    return model
@torch.no_grad()
def oracle_predict_scaled(
    oracle: Optional[TextPropertyOracle],
    psmiles_tokenizer,
    psmiles_list: List[str],
    device: str,
    max_len: int = PSMILES_MAX_LEN,
) -> Optional[np.ndarray]:
    """Predict scaled property values for a list of PSMILES with the auxiliary oracle.

    Returns:
        ``None`` when the oracle or tokenizer is missing; otherwise a flat float32
        array with one prediction per input string (empty for an empty list).
    """
    if oracle is None or psmiles_tokenizer is None:
        return None
    if not psmiles_list:
        return np.array([], dtype=np.float32)

    oracle.eval()
    ys = []
    bs = 32
    for i in range(0, len(psmiles_list), bs):
        chunk = psmiles_list[i : i + bs]
        enc = psmiles_tokenizer(chunk, truncation=True, padding="max_length", max_length=max_len, return_tensors="pt")
        input_ids = enc["input_ids"].to(device)
        attn = enc["attention_mask"].to(device).bool()
        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
            y_hat = oracle(input_ids, attn)
        # Under autocast the head output may be fp16/bf16; NumPy cannot represent
        # bfloat16, so cast to float32 before leaving torch. This also keeps the
        # dtype consistent with the empty-input return above.
        ys.append(y_hat.detach().float().cpu().numpy().reshape(-1))
    return np.concatenate(ys, axis=0) if ys else np.array([], dtype=np.float32)


# =============================================================================
# Latent property model (per property)
# =============================================================================
@dataclass
class LatentPropertyModel:
    """Bundle of the fitted pieces needed to predict a property from a latent."""

    # Scaler used to map raw property values to/from the scaled space.
    y_scaler: StandardScaler
    # Fitted PCA applied before the GPR, or None when PCA is disabled.
    pca: Optional[PCA]
    # Gaussian-process regressor fitted on (optionally PCA-reduced) latents.
    gpr: GaussianProcessRegressor


def fit_latent_property_model(z_train: np.ndarray, y_train: np.ndarray, y_scaler: StandardScaler) -> LatentPropertyModel:
    """
    Fit a GPR mapping (PSMILES latent) -> (scaled property).
    Uses optional PCA for stability when latent dim is large.

    ``y_scaler`` must already be fitted; it is only used to transform ``y_train``.
    """
    y_train = np.array(y_train, dtype=np.float32).reshape(-1, 1)
    y_s = y_scaler.transform(y_train).reshape(-1).astype(np.float32)

    z_use = z_train.astype(np.float32)
    pca = None
    if USE_PCA_BEFORE_GPR:
        # Cap components by the configured size, n_samples - 1 and the latent width,
        # but never go below 2.
        ncomp = int(min(PCA_DIM, z_use.shape[0] - 1, z_use.shape[1]))
        ncomp = max(2, ncomp)
        pca = PCA(n_components=ncomp, random_state=0)
        z_use = pca.fit_transform(z_use)

    kernel = (
        C(1.0, (1e-3, 1e3)) * RBF(length_scale=1.0, length_scale_bounds=(1e-2, 1e2))
        + WhiteKernel(noise_level=1e-3, noise_level_bounds=(1e-6, 1e-1))
    )
    gpr = GaussianProcessRegressor(kernel=kernel, alpha=GPR_ALPHA, normalize_y=True, random_state=0, n_restarts_optimizer=2)
    gpr.fit(z_use, y_s)
    return LatentPropertyModel(y_scaler=y_scaler, pca=pca, gpr=gpr)


def predict_latent_property(model: LatentPropertyModel, z: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
    """Predict properties for candidate latents.

    Returns:
        ``(y_scaled, y_unscaled)`` as flat arrays; the unscaled values come from
        inverting the model's scaler.
    """
    z_use = z.astype(np.float32)
    if model.pca is not None:
        z_use = model.pca.transform(z_use)
    y_s = model.gpr.predict(z_use, return_std=False)
    y_s = np.array(y_s, dtype=np.float32).reshape(-1)
    y_u = model.y_scaler.inverse_transform(y_s.reshape(-1, 1)).reshape(-1)
    return y_s, y_u


# =============================================================================
# Train / eval loops (decoder)
# =============================================================================
def train_one_epoch_decoder(model: CLConditionedSelfiesTEDGenerator, dl: DataLoader, optimizer, scaler_amp, device: str):
    """Run one epoch of teacher-forced decoder fine-tuning.

    Returns:
        Dict with sample-weighted mean ``"loss"`` and ``"ce"`` over the epoch.
    """
    model.train()
    total = 0.0
    n = 0
    ce_sum = 0.0
    for batch in dl:
        move_latent_batch_to_device(batch, device)
        z = batch["z"]
        labels = batch["labels"]
        optimizer.zero_grad(set_to_none=True)
        # NOTE(review): source indentation was lost; the loss lookup is assumed to
        # sit inside the autocast context - confirm.
        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
            out = model.forward_train(z, labels)
            loss = out["loss"]
        if USE_AMP:
            scaler_amp.scale(loss).backward()
            # Unscale before clipping so the threshold applies to true gradient norms.
            scaler_amp.unscale_(optimizer)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler_amp.step(optimizer)
            scaler_amp.update()
        else:
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            optimizer.step()
        bs = z.size(0)
        total += float(loss.item()) * bs
        ce_sum += float(out["ce"].item()) * bs
        n += bs
    return {"loss": total / max(1, n), "ce": ce_sum / max(1, n)}
@torch.no_grad()
def evaluate_decoder(model: CLConditionedSelfiesTEDGenerator, dl: DataLoader, device: str):
    """Compute the validation loss used for early stopping.

    Returns:
        Dict with sample-weighted mean ``"loss"`` and ``"ce"`` over ``dl``.
    """
    model.eval()
    total = 0.0
    n = 0
    ce_sum = 0.0
    for batch in dl:
        move_latent_batch_to_device(batch, device)
        z = batch["z"]
        labels = batch["labels"]
        with torch.cuda.amp.autocast(enabled=USE_AMP, dtype=AMP_DTYPE):
            out = model.forward_train(z, labels)
            loss = out["loss"]
        bs = z.size(0)
        total += float(loss.item()) * bs
        ce_sum += float(out["ce"].item()) * bs
        n += bs
    return {"loss": total / max(1, n), "ce": ce_sum / max(1, n)}


# =============================================================================
# Generation / filtering (per target value, per property)
# =============================================================================
def compute_diversity_morgan(smiles_list: List[str], radius: int = 2, nbits: int = 2048, p: float = 1.0) -> Optional[float]:
    """
    Diversity = 1 - M_p(Tanimoto), computed on Morgan fingerprints of unique valid SMILES,
    where M_p is the power mean of order ``p`` over all pairs (p = 1 is the plain mean).

    Args:
        smiles_list: Candidate SMILES; blanks and exact duplicates are ignored.
        radius: Morgan fingerprint radius.
        nbits: Fingerprint length in bits.
        p: Power-mean order; non-finite, non-positive or unparsable values fall back to 1.0.

    Returns:
        ``None`` if RDKit is unavailable or no molecule could be fingerprinted,
        ``0.0`` if exactly one molecule was fingerprinted, otherwise the diversity.
    """
    if not RDKit_AVAILABLE:
        return None

    try:
        p = float(p)
    except (TypeError, ValueError, OverflowError):
        p = 1.0
    if not np.isfinite(p) or p <= 0:
        p = 1.0

    # Order-preserving de-duplication of the stripped, non-empty strings.
    uniq = list(dict.fromkeys(s for s in (str(smi).strip() for smi in smiles_list) if s))

    fps = []
    for smi in uniq:
        try:
            mol = Chem.MolFromSmiles(smi)
            if mol is None:
                continue
            # With catchErrors=True, SanitizeMol reports failure through its return
            # value instead of raising, so the result has to be checked explicitly.
            if Chem.SanitizeMol(mol, catchErrors=True) != Chem.SanitizeFlags.SANITIZE_NONE:
                continue
            fps.append(AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits))
        except Exception:
            # Best effort: any RDKit failure on one molecule just drops that molecule.
            continue

    if len(fps) < 2:
        return 0.0 if len(fps) == 1 else None

    sims_p = []
    for i in range(len(fps)):
        for j in range(i + 1, len(fps)):
            try:
                s = float(DataStructs.TanimotoSimilarity(fps[i], fps[j]))
                sims_p.append(s ** p)
            except Exception:
                continue
    if not sims_p:
        return None

    # Base lies in [0, 1] and the exponent is positive, so this cannot raise.
    mean_sim = float(np.mean(sims_p)) ** (1.0 / p)
    return float(1.0 - mean_sim)


@torch.no_grad()
def decode_from_latents(generator: CLConditionedSelfiesTEDGenerator, z: torch.Tensor, n_samples: int = 1) -> List[str]:
    """Decode PSELFIES strings from a batch of CL latents using the module-level sampling settings."""
    return generator.generate(
        z=z,
        num_return_sequences=int(n_samples),
        max_len=GEN_MAX_LEN,
        top_p=GEN_TOP_P,
        temperature=GEN_TEMPERATURE,
        repetition_penalty=GEN_REPETITION_PENALTY,
    )
def generate_for_target(
    target_y_scaled: float,
    prop_model: LatentPropertyModel,
    cl_encoder: MultiModalCLPolymerEncoder,
    generator: CLConditionedSelfiesTEDGenerator,
    train_seed_pool: List[dict],
    train_targets_set: set,
    n_seeds: int = 8,
    n_noise: int = N_FOLD_NOISE_SAMPLING,
    noise_std: float = LATENT_NOISE_STD_GEN,
    prop_tol_scaled: float = PROP_TOL_SCALED,
    oracle: Optional[TextPropertyOracle] = None,
    psmiles_tokenizer=None,
) -> Dict[str, Any]:
    """
    Core generation routine for a single target property value (scaled):
      1) Pick seed polymers close to target (in scaled property space).
      2) Encode seeds (multimodal) -> latent vectors.
      3) Add Gaussian noise to latents (exploration), renormalize.
      4) Decode to PSELFIES -> convert to polymer PSMILES.
      5) Filter by polymer/chem validity and property closeness (via GPR on PSMILES latents).
      6) Compute novelty/uniqueness/diversity metrics; optionally score with aux oracle.

    Returns:
        Dict with keys ``generated`` (kept PSMILES), ``pred_scaled_kept``,
        ``pred_unscaled_kept``, ``aux_pred_scaled`` (list or None) and ``metrics``.
        When no seed could be encoded or ``n_noise`` is not positive, the lists are
        empty and ``metrics`` is an empty dict.
    """

    def _l2_normalize_np(x: np.ndarray, eps: float = 1e-12) -> np.ndarray:
        n = np.linalg.norm(x, axis=-1, keepdims=True)
        return x / np.clip(n, eps, None)

    def _empty_result() -> Dict[str, Any]:
        # Same key set as the normal return so callers can index either uniformly.
        return {
            "generated": [],
            "pred_scaled_kept": [],
            "pred_unscaled_kept": [],
            "aux_pred_scaled": None,
            "metrics": {},
        }

    # Choose nearest seeds by property distance (scaled)
    ys = np.array([float(d["y_scaled"]) for d in train_seed_pool], dtype=np.float32)
    diffs = np.abs(ys - float(target_y_scaled))
    order = np.argsort(diffs)
    chosen = [train_seed_pool[i] for i in order[: max(1, int(n_seeds))]]

    # Encode chosen seeds using multimodal encoder
    z_seed = cl_encoder.encode_multimodal(chosen, batch_size=32, device=DEVICE)
    n_draws = int(n_noise)
    if z_seed.shape[0] == 0 or n_draws <= 0:
        # Nothing to perturb; np.stack on an empty list would raise below.
        return _empty_result()

    # Sample noise around each seed latent
    z_list = []
    for i in range(z_seed.shape[0]):
        z0 = z_seed[i].astype(np.float32)
        for _ in range(n_draws):
            z = z0 + np.random.randn(z0.shape[0]).astype(np.float32) * float(noise_std)
            z = _l2_normalize_np(z.reshape(1, -1)).reshape(-1)
            z_list.append(z)
    z_all = np.stack(z_list, axis=0).astype(np.float32)
    z_t = torch.tensor(z_all, dtype=torch.float32, device=DEVICE)

    # Decode to PSELFIES
    pselfies = decode_from_latents(generator, z_t, n_samples=1)

    # Convert to polymer PSMILES; record validity flags
    valid_psmiles = []
    valid_flags, poly_flags = [], []
    for s in pselfies:
        s = _ensure_two_at_endpoints(_selfies_compact(s))
        psm = pselfies_to_psmiles(s) if (RDKit_AVAILABLE and SELFIES_AVAILABLE) else None
        if psm is None:
            valid_flags.append(False)
            poly_flags.append(False)
            continue
        psm_can = canonicalize_psmiles(psm)
        ok = chem_validity_psmiles(psm_can) if psm_can else False
        poly_ok = polymer_validity_psmiles_strict(psm_can) if psm_can else False
        valid_flags.append(bool(ok))
        poly_flags.append(bool(poly_ok))
        if ok and poly_ok and psm_can:
            valid_psmiles.append(psm_can)

    uniq_valid = sorted(set(valid_psmiles))
    novelty_valid = [1.0 if s not in train_targets_set else 0.0 for s in uniq_valid] if uniq_valid else []
    n_valid_poly = int(len(valid_psmiles))
    uniqueness_valid_unique = float(len(uniq_valid)) / float(max(1, n_valid_poly)) if n_valid_poly > 0 else 0.0

    # Property prediction via GPR on PSMILES latents (for filtering)
    if uniq_valid:
        z_cand = cl_encoder.encode_psmiles(uniq_valid, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
    else:
        # The encoder code visible in this file sizes its empty output with `emb_dim`;
        # accept either attribute and fall back to the seed latent width so the
        # "no valid candidate" path cannot fail on a missing attribute.
        latent_dim = getattr(cl_encoder, "out_dim", None)
        if latent_dim is None:
            latent_dim = getattr(cl_encoder, "emb_dim", z_seed.shape[1])
        z_cand = np.zeros((0, int(latent_dim)), dtype=np.float32)

    yhat_s, yhat_u = (np.array([], dtype=np.float32), np.array([], dtype=np.float32))
    if z_cand.shape[0] > 0:
        yhat_s, yhat_u = predict_latent_property(prop_model, z_cand)

    keep, keep_pred_scaled, keep_pred_unscaled = [], [], []
    for i, psm in enumerate(uniq_valid):
        if abs(float(yhat_s[i]) - float(target_y_scaled)) <= float(prop_tol_scaled):
            keep.append(psm)
            keep_pred_scaled.append(float(yhat_s[i]))
            keep_pred_unscaled.append(float(yhat_u[i]))

    novelty_keep = [1.0 if s not in train_targets_set else 0.0 for s in keep] if keep else []

    # Optional aux oracle prediction for additional sanity checking
    aux_pred_scaled = None
    if VERIFY_GENERATED_PROPERTIES and oracle is not None and psmiles_tokenizer is not None and keep:
        aux = oracle_predict_scaled(oracle, psmiles_tokenizer, keep, DEVICE, PSMILES_MAX_LEN)
        aux_pred_scaled = aux.tolist() if aux is not None else None

    # Diversity computed on At-SMILES (to avoid polymer "*" parsing issues)
    at_smiles = []
    if RDKit_AVAILABLE and keep:
        for psm in keep:
            at_smi = psmiles_to_at_smiles(psm, root_at=False)
            if at_smi is not None:
                at_smiles.append(at_smi)
    div = compute_diversity_morgan(at_smiles) if at_smiles else None

    metrics = {
        "n_total": int(len(pselfies)),
        "validity": float(np.mean(valid_flags)) if valid_flags else 0.0,
        "polymer_validity": float(np.mean(poly_flags)) if poly_flags else 0.0,
        "n_valid_unique": int(len(uniq_valid)),
        "novelty_valid_unique": float(np.mean(novelty_valid)) if novelty_valid else 0.0,
        "uniqueness_valid_unique": float(uniqueness_valid_unique),
        "n_kept_property_filtered": int(len(keep)),
        "novelty_kept": float(np.mean(novelty_keep)) if novelty_keep else 0.0,
        # A missing diversity value is reported as 0.0.
        "diversity": float(div) if div is not None else 0.0,
    }

    return {
        "generated": keep,
        "pred_scaled_kept": keep_pred_scaled,
        "pred_unscaled_kept": keep_pred_unscaled,
        "aux_pred_scaled": aux_pred_scaled,
        "metrics": metrics,
    }
# =============================================================================
# Data assembly (per property)
# =============================================================================
def build_polymer_records(df: pd.DataFrame, prop_col: str) -> List[dict]:
    """
    Assemble the usable records for one property column.

    A row is kept only if:
      - its PSMILES canonicalizes and passes both the chemical and the strict polymer check
      - its property value parses to a finite float
      - a PSELFIES string can be produced from the canonical PSMILES
    The optional modality columns (graph, geometry, fingerprints) are carried along.

    Raises:
        RuntimeError: If RDKit or selfies is not available.
    """
    if not (RDKit_AVAILABLE and SELFIES_AVAILABLE):
        raise RuntimeError("RDKit + selfies are required for this pipeline.")

    records = []
    for _, row in df.iterrows():
        raw_text = str(row.get("psmiles", "")).strip()
        if not raw_text:
            continue

        canonical = canonicalize_psmiles(raw_text)
        # Short-circuit keeps the checks in their original order.
        if not (canonical and chem_validity_psmiles(canonical) and polymer_validity_psmiles_strict(canonical)):
            continue

        raw_value = row.get(prop_col, None)
        if raw_value is None:
            continue
        try:
            value = float(raw_value)
        except Exception:
            continue
        if not np.isfinite(value):
            continue

        pselfies = psmiles_to_pselfies(canonical)
        if pselfies is None:
            continue

        record = {"psmiles": canonical, "pselfies": pselfies, "y": value}
        for modality in ("graph", "geometry", "fingerprints"):
            record[modality] = row.get(modality, None)
        records.append(record)
    return records
# =============================================================================
# Best-fold artifact saving (per property)
# =============================================================================
def save_best_fold_artifacts_for_property(
    property_name: str,
    fold_idx: int,
    decoder_state: Dict[str, torch.Tensor],
    prop_model: Optional[LatentPropertyModel],
    scaler: Optional[StandardScaler],
    best_val_loss: float,
    generations_payload: List[dict],
):
    """
    Persist the best fold for a property:
      - decoder state_dict
      - scaler + GPR (joblib, if available)
      - meta.json describing hyperparams
      - jsonl generations payload for traceability

    Args:
        property_name: Human-readable property name; spaces become underscores in paths.
        fold_idx: Zero-based fold index (written as one-based in file names and metadata).
        decoder_state: State dict saved with ``torch.save``.
        prop_model: Latent property model to dump, or None to skip.
        scaler: Target scaler to dump, or None to skip.
        best_val_loss: Validation loss recorded in the metadata.
        generations_payload: One dict per line of the output ``.jsonl`` file.

    Only the decoder checkpoint write may raise; failures of the metadata and
    generations writes are reported as warnings.
    """
    safe_prop = property_name.replace(" ", "_")
    prop_dir = os.path.join(CFG.OUTPUT_MODELS_DIR, safe_prop)
    os.makedirs(prop_dir, exist_ok=True)

    decoder_path = os.path.join(prop_dir, f"decoder_best_fold{fold_idx+1}.pt")
    torch.save(decoder_state, decoder_path)

    # joblib is optional; without it the sklearn objects are not persisted.
    try:
        import joblib
    except Exception:
        joblib = None

    if joblib is not None:
        if scaler is not None:
            joblib.dump(scaler, os.path.join(prop_dir, f"standardscaler_{safe_prop}.joblib"))
        if prop_model is not None:
            joblib.dump(prop_model, os.path.join(prop_dir, f"gpr_psmiles_{safe_prop}.joblib"))
    elif scaler is not None or prop_model is not None:
        print(f"[SAVE][WARN] joblib unavailable; scaler/GPR for '{property_name}' were not saved.")

    meta = {
        "property": property_name,
        "best_fold": int(fold_idx + 1),
        "best_val_loss": float(best_val_loss),
        "selfies_ted_model": str(SELFIES_TED_MODEL_NAME),
        "cl_emb_dim": int(CL_EMB_DIM),
        "mem_len": 4,
        "tol_scaled": float(PROP_TOL_SCALED),
        "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
        "optimizer": "AdamW",
        "lr": float(LEARNING_RATE),
        "weight_decay": float(WEIGHT_DECAY),
        "lr_scheduler": "CosineAnnealingLR",
        "epochs": int(NUM_EPOCHS),
        "batch_size": int(BATCH_SIZE),
        "patience": int(PATIENCE),
    }
    try:
        with open(os.path.join(prop_dir, "meta.json"), "w", encoding="utf-8") as f:
            json.dump(meta, f, indent=2)
    except Exception as e:
        # Still best-effort, but no longer silent.
        print(f"[SAVE][WARN] Could not write meta.json for '{property_name}': {e}")

    out_path = os.path.join(CFG.OUTPUT_GENERATIONS_DIR, f"{safe_prop}_best_fold{fold_idx+1}_generated_psmiles.jsonl")
    try:
        # Create the directory here so the write does not depend on earlier setup.
        os.makedirs(CFG.OUTPUT_GENERATIONS_DIR, exist_ok=True)
        with open(out_path, "w", encoding="utf-8") as fh:
            for r in generations_payload:
                fh.write(json.dumps(make_json_serializable({"property": property_name, "best_fold": fold_idx + 1, **r})) + "\n")
    except Exception as e:
        print(f"[SAVE][WARN] Could not write generations for '{property_name}': {e}")
# =============================================================================
# Main per-property CV loop (single-task)
# =============================================================================
def _invdesign_scale_targets(scaler, polys):
    """Scale the ``"y"`` value of every record with one batched ``transform`` call.

    Args:
        scaler: Fitted scaler exposing ``transform`` on an ``(n, 1)`` array.
        polys: Polymer record dicts; only the ``"y"`` key is read.

    Returns:
        1-D float32 array aligned with ``polys``. Empty input short-circuits to an
        empty array without calling the scaler.
    """
    if not polys:
        return np.zeros((0,), dtype=np.float32)
    raw = np.array([p["y"] for p in polys], dtype=np.float32).reshape(-1, 1)
    return np.asarray(scaler.transform(raw), dtype=np.float32).reshape(-1)


def _invdesign_finite_mean(values):
    """Return the mean of the finite entries of ``values`` (0.0 when there are none)."""
    finite = [v for v in values if np.isfinite(v)]
    return float(np.mean(finite)) if finite else 0.0


def _invdesign_train_decoder(property_name, fold_idx, dl_train, dl_val):
    """Train a fresh latent-conditioned decoder for one fold with early stopping.

    A new SELFIES-TED backbone is loaded for every call so folds do not share
    decoder weights. The best state (lowest validation loss) is kept as a CPU copy.

    Returns:
        ``(decoder, best_state, best_val)``. ``best_state`` is ``None`` when no epoch
        improved on ``inf`` (for example when every validation loss is NaN).
    """
    selfies_tok_f, selfies_model_f = load_selfies_ted_and_tokenizer(SELFIES_TED_MODEL_NAME)
    decoder = CLConditionedSelfiesTEDGenerator(
        selfies_tok_f, selfies_model_f, cl_emb_dim=CL_EMB_DIM, mem_len=4
    ).to(DEVICE)
    optimizer, scheduler = create_optimizer_and_scheduler_decoder(decoder)
    scaler_amp = torch.cuda.amp.GradScaler(enabled=USE_AMP)

    best_val = float("inf")
    best_state = None
    no_improve = 0
    scheduler_warned = False

    for epoch in range(1, NUM_EPOCHS + 1):
        tr = train_one_epoch_decoder(decoder, dl_train, optimizer, scaler_amp, DEVICE)
        va = evaluate_decoder(decoder, dl_val, DEVICE)

        # Stepping the scheduler stays best-effort, but a failure is reported once
        # instead of being silently swallowed.
        if scheduler is not None:
            try:
                scheduler.step()
            except Exception as exc:
                if not scheduler_warned:
                    print(f"[{property_name}][WARN] scheduler.step() failed (epoch {epoch}): {exc}")
                    scheduler_warned = True

        # Only the learning-rate lookup is optional; the loss values are always logged.
        try:
            lr_part = f"lr={float(optimizer.param_groups[0]['lr']):.2e} | "
        except (AttributeError, IndexError, KeyError, TypeError, ValueError):
            lr_part = ""
        print(
            f"[{property_name}] fold {fold_idx+1}/{NUM_FOLDS} | epoch {epoch:03d} | "
            f"{lr_part}train_loss={tr['loss']:.6f} | val_loss={va['loss']:.6f}"
        )

        if va["loss"] < best_val - 1e-8:
            best_val = va["loss"]
            no_improve = 0
            best_state = {k: v.detach().cpu().clone() for k, v in decoder.state_dict().items()}
        else:
            no_improve += 1
            if no_improve >= PATIENCE:
                print(f"[{property_name}] Early stopping (no val improvement for {PATIENCE} epochs).")
                break

    return decoder, best_state, best_val


def run_inverse_design_single_property(
    df: pd.DataFrame,
    property_name: str,
    prop_col: str,
    cl_encoder: MultiModalCLPolymerEncoder,
    selfies_tok,
    selfies_model,
) -> Dict[str, Any]:
    """
    Run NUM_FOLDS-fold CV for a single property and log fold-level metrics.

    Per fold: split train/val/test, fit a StandardScaler and a PSMILES-latent
    property model on the training split, optionally train an auxiliary oracle,
    fine-tune a fresh decoder with early stopping, then generate candidates for
    (a subsample of at most 64) scaled test targets. Each fold's record is appended
    to ``CFG.OUTPUT_RESULTS``; artifacts are saved whenever a fold achieves the
    lowest decoder validation loss seen so far.

    Args:
        df: Source table handed to ``build_polymer_records``.
        property_name: Human-readable property label used in logs and outputs.
        prop_col: Column of ``df`` holding the property values.
        cl_encoder: Encoder used for dataset embeddings and PSMILES latents.
        selfies_tok: Tokenizer passed to the latent-to-PSELFIES datasets.
        selfies_model: Not referenced here (a fresh backbone is loaded per fold);
            kept for interface compatibility.

    Returns:
        Dict with keys ``"property"``, ``"runs"`` (per-fold records), ``"agg"``
        (mean/std per metric, or ``None`` when nothing ran) and ``"n_samples"``.
    """
    polymers = build_polymer_records(df, prop_col)
    if len(polymers) < 200:
        print(f"[{property_name}][WARN] Only {len(polymers)} usable samples; results may be noisy.")
    if len(polymers) < 50:
        print(f"[{property_name}][WARN] Skipping due to insufficient usable samples (<50).")
        return {"property": property_name, "runs": [], "agg": None, "n_samples": len(polymers)}

    indices = np.arange(len(polymers))
    kf = KFold(n_splits=NUM_FOLDS, shuffle=True, random_state=42)

    runs = []
    best_overall_val = float("inf")

    for fold_idx, (trainval_idx, test_idx) in enumerate(kf.split(indices)):
        seed = 42 + fold_idx
        set_seed(seed)
        print(f"\n[{property_name}] Fold {fold_idx+1}/{NUM_FOLDS} | seed={seed}")

        trainval_polys = [polymers[i] for i in trainval_idx]
        test_polys = [polymers[i] for i in test_idx]

        # Train/val split within trainval
        tr_idx, va_idx = train_test_split(
            np.arange(len(trainval_polys)), test_size=0.10, random_state=seed, shuffle=True
        )
        train_polys = [copy.deepcopy(trainval_polys[i]) for i in tr_idx]
        val_polys = [copy.deepcopy(trainval_polys[i]) for i in va_idx]

        # Scale property targets using TRAIN only
        sc = StandardScaler()
        sc.fit(np.array([p["y"] for p in train_polys], dtype=np.float32).reshape(-1, 1))

        # Scaled train targets are reused by the oracle and the seed pool.
        train_y_scaled = _invdesign_scale_targets(sc, train_polys)

        # The scaler is fixed for the whole fold, so its metadata is built once.
        scaler_meta = {
            "scaler_type": "StandardScaler",
            "mean_": getattr(sc, "mean_", None),
            "scale_": getattr(sc, "scale_", None),
            "with_mean": bool(getattr(sc, "with_mean", True)),
            "with_std": bool(getattr(sc, "with_std", True)),
        }

        # Helper to format records for latent dataset
        def _to_rec(p):
            return {
                "psmiles": p["psmiles"],
                "pselfies": p["pselfies"],
                "graph": p.get("graph", None),
                "geometry": p.get("geometry", None),
                "fingerprints": p.get("fingerprints", None),
            }

        # Decoder training datasets (cache CL embeddings for speed)
        ds_train = LatentToPSELFIESDataset(
            [_to_rec(p) for p in train_polys],
            cl_encoder,
            selfies_tok,
            max_len=GEN_MAX_LEN,
            latent_noise_std=LATENT_NOISE_STD_TRAIN,
            cache_embeddings=True,
            use_multimodal=True,
        )
        ds_val = LatentToPSELFIESDataset(
            [_to_rec(p) for p in val_polys],
            cl_encoder,
            selfies_tok,
            max_len=GEN_MAX_LEN,
            latent_noise_std=0.0,
            cache_embeddings=True,
            use_multimodal=True,
        )
        dl_train = DataLoader(
            ds_train, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS,
            pin_memory=True, collate_fn=latent_collate,
        )
        dl_val = DataLoader(
            ds_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS,
            pin_memory=True, collate_fn=latent_collate,
        )

        # Fit GPR on PSMILES latents for this property (train only)
        y_tr = [float(p["y"]) for p in train_polys]
        psm_tr = [p["psmiles"] for p in train_polys]
        z_tr = cl_encoder.encode_psmiles(psm_tr, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
        prop_model = fit_latent_property_model(z_tr, np.array(y_tr, dtype=np.float32), y_scaler=sc)
        print(f"[{property_name}] Fit PSMILES-latent GPR (n_train={len(y_tr)})")

        # Optional aux oracle (scaled)
        oracle = None
        if VERIFY_GENERATED_PROPERTIES and len(train_polys) >= 200 and len(val_polys) >= 50:
            val_y_scaled = _invdesign_scale_targets(sc, val_polys)
            tr_s = [
                {"psmiles": p["psmiles"], "target": p["y"], "target_scaled": float(y_s)}
                for p, y_s in zip(train_polys, train_y_scaled)
            ]
            va_s = [
                {"psmiles": p["psmiles"], "target": p["y"], "target_scaled": float(y_s)}
                for p, y_s in zip(val_polys, val_y_scaled)
            ]
            try:
                oracle = train_property_oracle_per_fold(tr_s, va_s, cl_encoder.psm_tok, DEVICE, PSMILES_MAX_LEN)
                print(f"[{property_name}] Trained aux oracle (val_loss={getattr(oracle, 'aux_val_loss', None)})")
            except Exception as e:
                # The oracle is optional; generation proceeds without it.
                print(f"[{property_name}][WARN] Oracle training failed: {e}")
                oracle = None

        # Fresh decoder per fold + optimizer
        decoder, best_state, best_val = _invdesign_train_decoder(property_name, fold_idx, dl_train, dl_val)

        if best_state is None:
            print(f"[{property_name}][WARN] No best state captured; skipping this fold.")
            continue
        decoder.load_state_dict({k: v.to(DEVICE) for k, v in best_state.items()}, strict=False)

        # Seed pool for generation (scaled property values, plus modalities for encoding)
        seed_pool = [
            {
                "psmiles": p["psmiles"],
                "y_scaled": float(y_s),
                "graph": p.get("graph", None),
                "geometry": p.get("geometry", None),
                "fingerprints": p.get("fingerprints", None),
            }
            for p, y_s in zip(train_polys, train_y_scaled)
        ]
        train_targets_set = {p["psmiles"] for p in train_polys}

        # Compose test targets (scaled); subsample for runtime control
        ys_test_scaled = _invdesign_scale_targets(sc, test_polys)
        if len(ys_test_scaled) > 64:
            ys_test_scaled = np.random.choice(ys_test_scaled, size=64, replace=False)

        # Generate per target
        all_valid, all_poly, all_kept, success_scaled, mae_best, diversity_vals = [], [], [], [], [], []
        novelty_vals, uniqueness_vals = [], []
        per_target_records = []

        for y_t in ys_test_scaled:
            out = generate_for_target(
                target_y_scaled=float(y_t),
                prop_model=prop_model,
                cl_encoder=cl_encoder,
                generator=decoder,
                train_seed_pool=seed_pool,
                train_targets_set=train_targets_set,
                n_seeds=8,
                n_noise=min(N_FOLD_NOISE_SAMPLING, 16),
                noise_std=LATENT_NOISE_STD_GEN,
                prop_tol_scaled=PROP_TOL_SCALED,
                oracle=oracle,
                psmiles_tokenizer=cl_encoder.psm_tok,
            )

            m = out["metrics"]
            n_kept = int(m.get("n_kept_property_filtered", 0))
            all_valid.append(float(m.get("validity", 0.0)))
            all_poly.append(float(m.get("polymer_validity", 0.0)))
            all_kept.append(n_kept)
            diversity_vals.append(float(m.get("diversity", 0.0)))
            success_scaled.append(1.0 if n_kept > 0 else 0.0)
            novelty_vals.append(float(m.get("novelty_kept", 0.0)))
            uniqueness_vals.append(float(m.get("uniqueness_valid_unique", 0.0)))

            gen_list = out.get("generated", []) or []

            # Best error among kept candidates (inf marks "nothing to score")
            best_err = float("inf")
            if gen_list:
                z_keep = cl_encoder.encode_psmiles(gen_list, max_len=PSMILES_MAX_LEN, batch_size=64, device=DEVICE)
                y_pred_s, _ = predict_latent_property(prop_model, z_keep)
                if len(y_pred_s):
                    best_err = float(np.min(np.abs(y_pred_s - float(y_t))))
            mae_best.append(best_err)

            target_y_unscaled = float(sc.inverse_transform(np.array([[float(y_t)]], dtype=np.float32))[0, 0])

            aux_list = out.get("aux_pred_scaled", None)
            if aux_list is not None and not isinstance(aux_list, list):
                aux_list = None

            pred_s_list = out.get("pred_scaled_kept", []) or []
            pred_u_list = out.get("pred_unscaled_kept", []) or []

            # Prediction lists may be shorter than the candidate list; pad with None.
            candidates = [
                {
                    "psmiles": str(psm),
                    "pred_scaled": float(pred_s_list[i_c]) if i_c < len(pred_s_list) else None,
                    "pred_unscaled": float(pred_u_list[i_c]) if i_c < len(pred_u_list) else None,
                    "aux_pred_scaled": float(aux_list[i_c]) if (aux_list is not None and i_c < len(aux_list)) else None,
                }
                for i_c, psm in enumerate(gen_list)
            ]

            per_target_records.append(
                {
                    "target_y_scaled": float(y_t),
                    "target_y_unscaled": float(target_y_unscaled),
                    "tol_scaled": float(PROP_TOL_SCALED),
                    "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
                    "scaler_meta": dict(scaler_meta),
                    "candidates": candidates,
                    "metrics": m,
                }
            )

        # NOTE(review): "mae_best_scaled" is 0.0 when no target produced a finite
        # error, which is indistinguishable from a perfect match; kept as-is so the
        # results schema does not change.
        metrics_fold = {
            "validity_mean": float(np.mean(all_valid)) if all_valid else 0.0,
            "polymer_validity_mean": float(np.mean(all_poly)) if all_poly else 0.0,
            "avg_n_kept": float(np.mean(all_kept)) if all_kept else 0.0,
            "success_at_k_scaled": float(np.mean(success_scaled)) if success_scaled else 0.0,
            "mae_best_scaled": _invdesign_finite_mean(mae_best),
            "diversity_mean": float(np.mean(diversity_vals)) if diversity_vals else 0.0,
            "novelty_mean": float(np.mean(novelty_vals)) if novelty_vals else 0.0,
            "uniqueness_mean": float(np.mean(uniqueness_vals)) if uniqueness_vals else 0.0,
            "tol_scaled": float(PROP_TOL_SCALED),
            "tol_unscaled_abs": float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None,
        }

        run_record = {
            "property": property_name,
            "fold": int(fold_idx + 1),
            "seed": int(seed),
            "n_train": int(len(train_polys)),
            "n_val": int(len(val_polys)),
            "n_test": int(len(test_polys)),
            "best_val_loss": float(best_val),
            "gen_metrics": metrics_fold,
        }
        runs.append(run_record)

        with open(CFG.OUTPUT_RESULTS, "a", encoding="utf-8") as fh:
            fh.write(json.dumps(make_json_serializable(run_record)) + "\n")

        # Save best fold artifacts by lowest validation loss. Artifacts go straight
        # to disk, so nothing from the fold needs to be retained in memory.
        if best_val < best_overall_val - 1e-8:
            best_overall_val = best_val
            save_best_fold_artifacts_for_property(
                property_name=property_name,
                fold_idx=fold_idx,
                decoder_state=best_state,
                prop_model=prop_model,
                scaler=sc,
                best_val_loss=best_val,
                generations_payload=per_target_records,
            )
            print(f"[{property_name}] Saved best-fold artifacts (fold {fold_idx+1}, val_loss={best_val:.6f}).")

    # Aggregate across folds
    if not runs:
        return {"property": property_name, "runs": [], "agg": None, "n_samples": len(polymers)}

    def _collect(key):
        xs = [float(r["gen_metrics"].get(key, 0.0)) for r in runs if r.get("gen_metrics", None) is not None]
        return (float(np.mean(xs)) if xs else 0.0, float(np.std(xs)) if xs else 0.0)

    agg = {}
    for k in [
        "validity_mean",
        "polymer_validity_mean",
        "avg_n_kept",
        "success_at_k_scaled",
        "mae_best_scaled",
        "diversity_mean",
        "novelty_mean",
        "uniqueness_mean",
    ]:
        mean_k, std_k = _collect(k)
        agg[k] = {"mean": mean_k, "std": std_k}

    agg["tol_scaled"] = float(PROP_TOL_SCALED)
    agg["tol_unscaled_abs"] = float(PROP_TOL_UNSCALED_ABS) if PROP_TOL_UNSCALED_ABS is not None else None

    with open(CFG.OUTPUT_RESULTS, "a", encoding="utf-8") as fh:
        fh.write("AGG_PROPERTY: " + json.dumps(make_json_serializable({property_name: agg})) + "\n")

    return {"property": property_name, "runs": runs, "agg": agg, "n_samples": len(polymers)}
# =============================================================================
# Entrypoint (single-task per property)
# =============================================================================
def main():
    """Run the single-task inverse-design pipeline for every requested property.

    Steps: create output folders, check the chemistry dependencies, back up and
    truncate the results file, load the PolyInfo CSV, build the tokenizer and the
    frozen CL encoder, run the per-property CV routine for each requested property
    that has a matching column, then append a final summary to the results file.

    Raises:
        RuntimeError: If RDKit or selfies is unavailable, or if the PSMILES
            tokenizer could not be built.
        FileNotFoundError: If ``CFG.POLYINFO_PATH`` is not an existing file.
    """
    ensure_output_dirs(CFG)

    if not RDKit_AVAILABLE or not SELFIES_AVAILABLE:
        raise RuntimeError("This script requires RDKit and selfies. Install them before running.")

    results_path = CFG.OUTPUT_RESULTS
    banner = "=" * 80

    # Keep a copy of any previous results, then start from an empty file.
    if os.path.exists(results_path):
        backup = results_path + ".bak"
        shutil.copy(results_path, backup)
        print(f"[IO][INFO] Backed up existing results file to: {backup}")
    with open(results_path, "w", encoding="utf-8"):
        pass

    # Load dataset
    data_path = CFG.POLYINFO_PATH
    if not os.path.isfile(data_path):
        raise FileNotFoundError(f"PolyInfo CSV not found: {CFG.POLYINFO_PATH}")
    df = pd.read_csv(data_path, engine="python")

    matched_columns = find_property_columns(df.columns)
    prop_map = {}
    for requested in REQUESTED_PROPERTIES:
        prop_map[requested] = matched_columns.get(requested)

    # Run configuration summary
    print("\n" + banner)
    print("[RUN] Inverse design (single-task per property)")
    print(banner)
    print(f"[ENV] RDKit_AVAILABLE={RDKit_AVAILABLE} | SELFIES_AVAILABLE={SELFIES_AVAILABLE}")
    print(f"[ENV] DEVICE={DEVICE} | USE_AMP={USE_AMP} | NUM_WORKERS={NUM_WORKERS}")
    print(f"[DATA] POLYINFO_PATH={CFG.POLYINFO_PATH}")
    print(f"[DATA] Property map: {prop_map}")
    print(f"[CL] CL checkpoint dir: {CFG.PRETRAINED_MULTIMODAL_DIR}")
    print(f"[DEC] SELFIES_TED_MODEL_NAME={SELFIES_TED_MODEL_NAME}")
    print(
        f"[DEC] FT params: batch={BATCH_SIZE}, epochs={NUM_EPOCHS}, patience={PATIENCE}, "
        f"lr={LEARNING_RATE}, wd={WEIGHT_DECAY}, sched=CosineAnnealingLR(eta_min={COSINE_ETA_MIN})"
    )
    print(f"[GEN] Latent noise: train_std={LATENT_NOISE_STD_TRAIN}, gen_std={LATENT_NOISE_STD_GEN}, n_noise={N_FOLD_NOISE_SAMPLING}")
    print(f"[GEN] Filter tol: scaled={PROP_TOL_SCALED}, abs={PROP_TOL_UNSCALED_ABS}")
    print(f"[AUX] VERIFY_GENERATED_PROPERTIES={VERIFY_GENERATED_PROPERTIES}")
    print(banner + "\n")

    # Build PSMILES tokenizer for CL text encoder
    psmiles_tok = build_psmiles_tokenizer(spm_path=CFG.SPM_MODEL_PATH, max_len=PSMILES_MAX_LEN)
    if psmiles_tok is None:
        raise RuntimeError("Failed to build PSMILES tokenizer (check SPM_MODEL_PATH).")

    # Multimodal CL encoder (frozen; used as conditioning interface)
    cl_encoder = MultiModalCLPolymerEncoder(
        psmiles_tokenizer=psmiles_tok,
        emb_dim=CL_EMB_DIM,
        cl_weights_dir=CFG.PRETRAINED_MULTIMODAL_DIR,
        use_gine=True,
        use_schnet=True,
        use_fp=True,
        use_psmiles=True,
    )
    cl_encoder = cl_encoder.to(DEVICE)
    cl_encoder.freeze_cl_encoders()

    # Load SELFIES-TED backbone
    selfies_tok, selfies_model = load_selfies_ted_and_tokenizer(SELFIES_TED_MODEL_NAME)
    print(f"[HF][INFO] Loaded SELFIES-TED backbone: {SELFIES_TED_MODEL_NAME}")

    # Single-task loop per property
    per_property = {}
    for pname in REQUESTED_PROPERTIES:
        pcol = prop_map.get(pname)
        if pcol is None:
            print(f"[{pname}][WARN] No column match found; skipping.")
            continue

        print(f"\n>>> Property: '{pname}' | column='{pcol}'")
        per_property[pname] = run_inverse_design_single_property(
            df, pname, pcol, cl_encoder, selfies_tok, selfies_model
        )

    # Final summary (aggregated per property)
    final_agg = {pname: info.get("agg", None) for pname, info in per_property.items()}

    with open(results_path, "a", encoding="utf-8") as fh:
        fh.write("\nFINAL_SUMMARY\n")
        fh.write(json.dumps(make_json_serializable(final_agg), indent=2))
        fh.write("\n")

    print("\n" + banner)
    print("Finished inverse design runs.")
    print(f"Results file: {CFG.OUTPUT_RESULTS}")
    print(f"Best models dir: {CFG.OUTPUT_MODELS_DIR}")
    print(f"Best-fold generations dir: {CFG.OUTPUT_GENERATIONS_DIR}")
    print(banner)


if __name__ == "__main__":
    main()