# data_builder.py
from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Optional, Tuple, Sequence
import json
import warnings

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset
from torch_geometric.data import Data

# RDKit is required
from rdkit import Chem
from rdkit.Chem.rdchem import HybridizationType, BondType, BondStereo

# ---------------------------------------------------------
# Fidelity handling
# ---------------------------------------------------------
FID_PRIORITY = ["exp", "dft", "md", "gc"]  # internal lower-case canonical order


def _norm_fid(fid: str) -> str:
    """Return the canonical spelling of a fidelity name (stripped, lower-case)."""
    return fid.strip().lower()


def _ensure_targets_order(requested: Sequence[str]) -> List[str]:
    """Strip each target name and drop repeats, keeping first-seen order."""
    # dict preserves insertion order, so this is an ordered de-duplication.
    return list(dict.fromkeys(t.strip() for t in requested))


# ---------------------------------------------------------
# RDKit featurization
# ---------------------------------------------------------
_ATOMS = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
_ATOM2IDX = {s: i for i, s in enumerate(_ATOMS)}

_HYBS = [
    HybridizationType.SP,
    HybridizationType.SP2,
    HybridizationType.SP3,
    HybridizationType.SP3D,
    HybridizationType.SP3D2,
]
_HYB2IDX = {h: i for i, h in enumerate(_HYBS)}

_BOND_STEREOS = [
    BondStereo.STEREONONE,
    BondStereo.STEREOANY,
    BondStereo.STEREOZ,
    BondStereo.STEREOE,
    BondStereo.STEREOCIS,
    BondStereo.STEREOTRANS,
]
_STEREO2IDX = {s: i for i, s in enumerate(_BOND_STEREOS)}

# Width of one bond feature vector: 4 bond-type flags + conjugated + in-ring
# + stereo one-hot. Kept next to the tables it is derived from so the dummy
# edge in featurize_smiles cannot drift out of sync with bond_features.
_EDGE_FEATURE_DIM = 6 + len(_BOND_STEREOS)


def _one_hot(index: int, size: int) -> List[float]:
    """Return a list of ``size`` floats with 1.0 at ``index``.

    An out-of-range ``index`` yields an all-zero vector instead of raising.
    """
    v = [0.0] * size
    if 0 <= index < size:
        v[index] = 1.0
    return v


def atom_features(atom: Chem.Atom) -> List[float]:
    """Encode one RDKit atom as a flat list of 35 floats.

    Layout (in order): element one-hot over ``_ATOMS`` plus an "other" bucket
    (11), degree 0..5+ (6), formal charge clipped to [-2, 2] (5), aromatic
    flag (1), in-ring flag (1), hybridization over ``_HYBS`` plus "other" (6),
    hydrogen count 0..4+ (5).
    """
    # Element one-hot; unknown symbols fall into the trailing "other" slot.
    elem_oh = _one_hot(_ATOM2IDX.get(atom.GetSymbol(), len(_ATOMS)), len(_ATOMS) + 1)

    # Degree one-hot up to 5 (bucket 5+)
    deg_oh = _one_hot(min(int(atom.GetDegree()), 5), 6)

    # Formal charge one-hot in [-2,-1,0,+1,+2]; values outside are clipped.
    fc = max(-2, min(2, int(atom.GetFormalCharge())))
    fc_oh = _one_hot(fc + 2, 5)

    # Aromatic, in ring flags
    aromatic = [1.0 if atom.GetIsAromatic() else 0.0]
    in_ring = [1.0 if atom.IsInRing() else 0.0]

    # Hybridization one-hot; anything not in _HYBS goes to the "other" slot.
    hyb_oh = _one_hot(_HYB2IDX.get(atom.GetHybridization(), len(_HYBS)), len(_HYBS) + 1)

    # Hydrogen count from GetTotalNumHs(includeNeighbors=True), capped at 4.
    # (This is the total count reported by RDKit, not only implicit Hs.)
    n_h = min(int(atom.GetTotalNumHs(includeNeighbors=True)), 4)
    n_h_oh = _one_hot(n_h, 5)

    # length: 11+6+5+1+1+6+5 = 35 (element has 11 buckets incl. "other")
    return elem_oh + deg_oh + fc_oh + aromatic + in_ring + hyb_oh + n_h_oh


def bond_features(bond: Chem.Bond) -> List[float]:
    """Encode one RDKit bond as a flat list of ``_EDGE_FEATURE_DIM`` (12) floats.

    Layout: single/double/triple/aromatic flags (4), conjugated (1),
    in-ring (1), stereo one-hot over ``_BOND_STEREOS`` (6). A stereo value
    not present in the table is encoded as index 0 (STEREONONE).
    """
    bt = bond.GetBondType()
    type_flags = [
        1.0 if bt == kind else 0.0
        for kind in (BondType.SINGLE, BondType.DOUBLE, BondType.TRIPLE, BondType.AROMATIC)
    ]
    conj = 1.0 if bond.GetIsConjugated() else 0.0
    in_ring = 1.0 if bond.IsInRing() else 0.0
    stereo_oh = _one_hot(_STEREO2IDX.get(bond.GetStereo(), 0), len(_BOND_STEREOS))
    # length: 4 + 1 + 1 + 6 = 12
    return type_flags + [conj, in_ring] + stereo_oh


def featurize_smiles(smiles: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Turn a SMILES string into ``(x, edge_index, edge_attr)`` tensors.

    Returns:
        x: float32 ``[n_atoms, 35]`` node features (see ``atom_features``).
        edge_index: long ``[2, n_edges]``; every bond appears in both directions.
        edge_attr: float32 ``[n_edges, 12]`` (see ``bond_features``).

    Raises:
        ValueError: if RDKit cannot parse ``smiles`` or the parsed molecule
            has no atoms.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit failed to parse SMILES: {smiles}")
    if mol.GetNumAtoms() == 0:
        # A zero-atom molecule (e.g. from an empty string) would give a 1-D
        # empty x and a self-loop on a node that does not exist.
        raise ValueError(f"SMILES contains no atoms: {smiles!r}")

    # Nodes
    x = torch.tensor([atom_features(a) for a in mol.GetAtoms()], dtype=torch.float32)

    # Edges (bidirectional): each bond contributes i->j and j->i with the
    # same feature vector.
    rows, cols, eattr = [], [], []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bf = bond_features(b)
        rows.extend([i, j])
        cols.extend([j, i])
        eattr.extend([bf, bf])

    if not rows:
        # No bonds at all (e.g. a single atom): add one dummy all-zero
        # self-loop on node 0 so edge tensors are never empty.
        rows, cols = [0], [0]
        eattr = [[0.0] * _EDGE_FEATURE_DIM]

    edge_index = torch.tensor([rows, cols], dtype=torch.long)
    edge_attr = torch.tensor(eattr, dtype=torch.float32)
    return x, edge_index, edge_attr


# ---------------------------------------------------------
# CSV
# discovery and reading
# ---------------------------------------------------------
def discover_target_fid_csvs(
    root: Path,
    targets: Sequence[str],
    fidelities: Sequence[str],
) -> Dict[tuple[str, str], Path]:
    """
    Discover CSV files for (target, fidelity) pairs.

    Supported layouts (case-insensitive):
      1) {root}/{fid}/{target}.csv      e.g. datafull/MD/SHEAR.csv, datafull/exp/cp.csv
      2) {root}/{target}_{fid}.csv      e.g. datafull/SHEAR_MD.csv, datafull/cp_exp.csv

    Matching is STRICT:
      - target and fid must appear as full '_' tokens in the stem
      - no substring matching, so 'he' will NOT match 'shear_md.csv'

    Layout 1 is preferred; layout 2 is only tried when layout 1 finds nothing.
    Pairs with no matching file are simply absent from the result.

    NOTE: in layout 2 the target must be a single '_' token, so a target name
    that itself contains '_' can only be found through layout 1.

    Returns:
        Mapping ``(target, lower-cased fidelity) -> path``. Targets keep the
        spelling given by the caller (stripped, de-duplicated).
    """
    root = Path(root)
    targets = _ensure_targets_order(targets)
    fids_lc = [_norm_fid(f) for f in fidelities]

    # Pre-index every CSV once: (path, parent_dir_lower, stem_lower, stem_tokens).
    # Sorted so that "first match wins" does not depend on filesystem order.
    indexed = []
    for p in sorted(root.rglob("*.csv")):
        stem = p.stem.lower()  # filename without extension
        indexed.append((p, p.parent.name.lower(), stem, stem.split("_")))

    mapping: Dict[tuple[str, str], Path] = {}
    for fid in fids_lc:
        for tgt in targets:
            tgt_l = tgt.lower()

            # ---- 1) Prefer explicit folder layout: {root}/{fid}/{target}.csv ----
            folder_matches = [
                p for (p, parent, stem, _tokens) in indexed
                if parent == fid and stem == tgt_l
            ]
            if folder_matches:
                # More than one hit means the data directory is ambiguous.
                if len(folder_matches) > 1:
                    warnings.warn(
                        f"[discover_target_fid_csvs] Multiple matches for "
                        f"target='{tgt}' fid='{fid}' under folder layout: "
                        + ", ".join(str(p) for p in folder_matches)
                    )
                mapping[(tgt, fid)] = folder_matches[0]
                continue

            # ---- 2) Fallback: {target}_{fid}.csv anywhere under root ----
            # require BOTH tgt and fid as full '_' tokens
            token_matches = [
                p for (p, _parent, _stem, tokens) in indexed
                if tgt_l in tokens and fid in tokens
            ]
            if token_matches:
                if len(token_matches) > 1:
                    warnings.warn(
                        f"[discover_target_fid_csvs] Multiple token matches for "
                        f"target='{tgt}' fid='{fid}': "
                        + ", ".join(str(p) for p in token_matches)
                    )
                mapping[(tgt, fid)] = token_matches[0]

            # If neither layout exists, (tgt, fid) is left out of the mapping
            # and build_long_table skips that combination.

    return mapping


def read_target_csv(path: Path, target: str) -> pd.DataFrame:
    """
    Read one (target, fidelity) CSV into a two-column frame ``[smiles, <target>]``.

    Accepts:
      - 'smiles' column (case-insensitive)
      - value column named '{target}' or one of ['value','y' or lower-case target]
        (case-insensitive; the first matching column wins)

    Rows whose value is missing or not numeric are dropped.
    Deduplicates by SMILES with mean (a warning is emitted when this happens).

    Raises:
        ValueError: if the smiles column or the value column cannot be found.
    """
    df = pd.read_csv(path)

    # smiles column
    smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None)
    if smiles_col is None:
        raise ValueError(f"{path} must contain a 'smiles' column.")
    df = df.rename(columns={smiles_col: "smiles"})

    # value column: exact target name first, then the accepted aliases
    if target in df.columns:
        val_col = target
    else:
        aliases = ("value", "y", target.lower())
        val_col = next((c for c in df.columns if c.lower() in aliases), None)
    if val_col is None:
        raise ValueError(f"{path} must contain a '{target}' column or one of ['value','y'].")

    df = df[["smiles", val_col]].copy()
    # Non-numeric entries become NaN and are dropped together with blanks.
    df[val_col] = pd.to_numeric(df[val_col], errors="coerce")
    df = df.dropna(subset=[val_col])

    # Deduplicate SMILES by mean
    if df.duplicated(subset=["smiles"]).any():
        warnings.warn(f"[data_builder] Duplicates by SMILES in {path}. Averaging duplicates.")
        df = df.groupby("smiles", as_index=False)[val_col].mean()

    return df.rename(columns={val_col: target})


def build_long_table(root: Path, targets: Sequence[str], fidelities: Sequence[str]) -> pd.DataFrame:
    """
    Returns long-form table with columns: [smiles, fid, fid_idx, target, value]

    ``fid`` is lower-case. ``fid_idx`` follows ``FID_PRIORITY``; fidelities not
    listed there are numbered after the known ones in alphabetical order
    (with a warning).

    Raises:
        FileNotFoundError: if no CSV is found for any (target, fidelity) pair.
    """
    targets = _ensure_targets_order(targets)
    mapping = discover_target_fid_csvs(root, targets, fidelities)
    if not mapping:
        raise FileNotFoundError(f"No CSVs found under {root} for the given targets and fidelities.")

    frames = []
    for (tgt, fid), path in mapping.items():
        df = read_target_csv(path, tgt)
        df["fid"] = _norm_fid(fid)
        df["target"] = tgt
        df = df.rename(columns={tgt: "value"})
        frames.append(df[["smiles", "fid", "target", "value"]])

    table = pd.concat(frames, axis=0, ignore_index=True)

    # attach fid index by priority
    fid2idx = {f: i for i, f in enumerate(FID_PRIORITY)}
    unknown = sorted(set(table["fid"]) - set(fid2idx))
    if unknown:
        warnings.warn(f"[data_builder] Unknown fidelities found: {unknown}. Appending after known ones.")
        start = len(fid2idx)
        for i, f in enumerate(unknown):
            fid2idx[f] = start + i

    table["fid_idx"] = table["fid"].map(fid2idx)
    return table


def pivot_to_rows_by_smiles_fid(long: pd.DataFrame, targets: Sequence[str]) -> pd.DataFrame:
    """
    Input: long table [smiles, fid, fid_idx, target, value]
    Output: row-per-(smiles,fid) with wide columns for each target

    Targets with no data at all still get a column (all NaN). Output columns
    are ``[smiles, fid, fid_idx] + targets`` in the requested target order.
    """
    targets = _ensure_targets_order(targets)
    key_cols = ["smiles", "fid", "fid_idx"]
    wide = long.pivot_table(index=key_cols, columns="target", values="value", aggfunc="mean")
    wide = wide.reset_index()

    for t in targets:
        if t not in wide.columns:
            wide[t] = np.nan

    return wide[key_cols + list(targets)]


# ---------------------------------------------------------
# Grouped split by SMILES and transforms/normalization
# ---------------------------------------------------------
def grouped_split_by_smiles(
    df_rows: pd.DataFrame,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split rows into train/val/test so that a SMILES never spans two splits.

    Unique SMILES are shuffled with ``np.random.default_rng(seed)``; the first
    ``round(n * test_ratio)`` go to test, the next ``round(n * val_ratio)`` to
    val, the rest to train.

    Returns:
        ``(train_idx, val_idx, test_idx)`` as arrays of ``df_rows.index``
        labels. These coincide with row positions only when ``df_rows`` has a
        default RangeIndex.
    """
    uniq = df_rows["smiles"].drop_duplicates().values
    rng = np.random.default_rng(seed)
    uniq = rng.permutation(uniq)

    n = len(uniq)
    n_test = int(round(n * test_ratio))
    n_val = int(round(n * val_ratio))

    test_smiles = set(uniq[:n_test])
    val_smiles = set(uniq[n_test:n_test + n_val])
    train_smiles = set(uniq[n_test + n_val:])

    def _labels_for(group: set) -> np.ndarray:
        return df_rows.index[df_rows["smiles"].isin(group)].to_numpy()

    return _labels_for(train_smiles), _labels_for(val_smiles), _labels_for(test_smiles)


def _auto_log_eps(values: np.ndarray) -> float:
    """Offset added before log10: 0.1 * smallest positive value (floored at 1e-8)."""
    positive = values[values > 0]
    if positive.size == 0:
        return 1e-8
    return 0.1 * float(max(np.min(positive), 1e-8))


def _targets_to_tensors(df: pd.DataFrame, targets: Sequence[str]) -> tuple[torch.Tensor, torch.Tensor]:
    """Extract target columns as ``(y, mask)`` tensors of shape ``[N, T]``.

    ``y`` is float32 with NaN where a label is missing; ``mask`` is bool and
    True where a label is present. The shape stays ``[0, T]`` for an empty
    frame.
    """
    values = df.loc[:, list(targets)].to_numpy(dtype=np.float64, na_value=np.nan)
    present = ~np.isnan(values)
    return torch.tensor(values.astype(np.float32)), torch.tensor(present)


# ---------------- Enhanced TargetScaler with per-task transforms ----------------
class TargetScaler:
    """
    Per-task transform + standardization fitted on the training split only.

    - transforms[t] in {"identity","log10"}
    - eps[t] is added before log for numerical safety (only used if transforms[t]=="log10")
    - mean/std are computed in the *transformed* domain
    """

    def __init__(self, transforms: Optional[Sequence[str]] = None, eps: Optional[Sequence[float] | torch.Tensor] = None):
        self.mean: Optional[torch.Tensor] = None  # [T] (transformed domain)
        self.std: Optional[torch.Tensor] = None   # [T] (transformed domain)
        self.transforms: List[str] = [str(t).lower() for t in transforms] if transforms is not None else []
        self.eps: Optional[torch.Tensor] = (
            None if eps is None else torch.as_tensor(eps, dtype=torch.float32)
        )
        self._tiny = 1e-12  # lower clamp applied before log10

    def _ensure_cfg(self, T: int):
        """Make ``transforms`` and ``eps`` have length ``T``.

        A missing or wrong-length ``transforms`` is replaced by all "identity";
        a missing or wrong-length ``eps`` by zeros.
        """
        if not self.transforms or len(self.transforms) != T:
            self.transforms = ["identity"] * T
        if self.eps is None or self.eps.numel() != T:
            self.eps = torch.zeros(T, dtype=torch.float32)

    def _require_fitted(self) -> None:
        """Raise a clear error when mean/std have not been set yet."""
        if self.mean is None or self.std is None:
            raise RuntimeError("TargetScaler is not fitted; call fit() or load_state_dict() first.")

    def _forward_transform_only(self, y: torch.Tensor) -> torch.Tensor:
        """
        Apply per-task transforms *before* standardization.
        y: [N, T] in original units. Returns transformed y_tf in same shape.
        """
        out = y.clone()
        T = out.size(1)
        self._ensure_cfg(T)
        for t in range(T):
            if self.transforms[t] == "log10":
                # Clamp so non-positive inputs give a finite (very negative) log.
                out[:, t] = torch.log10(torch.clamp(out[:, t] + self.eps[t], min=self._tiny))
        return out

    def _inverse_transform_only(self, y_tf: torch.Tensor) -> torch.Tensor:
        """
        Inverse the per-task transform (no standardization here).
        y_tf: [N, T] in transformed units.
        """
        out = y_tf.clone()
        T = out.size(1)
        self._ensure_cfg(T)
        for t in range(T):
            if self.transforms[t] == "log10":
                out[:, t] = (10.0 ** out[:, t]) - self.eps[t]
        return out

    def fit(self, y: torch.Tensor, mask: torch.Tensor):
        """
        y: [N, T] original units; mask: [N, T] bool
        Chooses eps automatically if not provided; mean/std computed in transformed space.

        Auto eps for a log10 task is 0.1 * the smallest positive masked label
        (1e-8 if there is none); identity tasks get 0.
        Entries where ``mask`` is False are ignored (they may be NaN).
        """
        T = y.size(1)
        # Decide this BEFORE _ensure_cfg: it fills a missing eps with zeros,
        # which would otherwise make the auto-selection below unreachable.
        needs_auto_eps = self.eps is None or self.eps.numel() != T
        self._ensure_cfg(T)

        if needs_auto_eps:
            y_np = y.detach().cpu().numpy()
            m_np = mask.detach().cpu().numpy().astype(bool)
            eps_vals = [
                _auto_log_eps(y_np[m_np[:, t], t]) if self.transforms[t] == "log10" else 0.0
                for t in range(T)
            ]
            self.eps = torch.tensor(eps_vals, dtype=torch.float32)

        y_tf = self._forward_transform_only(y)
        zeros = torch.zeros_like(y_tf)
        counts = mask.sum(dim=0).clamp_min(1)  # avoid 0/0 for tasks with no labels
        mean = torch.where(mask, y_tf, zeros).sum(dim=0) / counts
        var = (torch.where(mask, y_tf - mean, zeros) ** 2).sum(dim=0) / counts
        std = torch.sqrt(var + 1e-8)  # small floor keeps std strictly positive
        self.mean, self.std = mean, std

    def transform(self, y: torch.Tensor) -> torch.Tensor:
        """Map ``y`` ([N, T], original units) to standardized-transformed space."""
        self._require_fitted()
        y_tf = self._forward_transform_only(y)
        return (y_tf - self.mean) / self.std

    def inverse(self, y_std: torch.Tensor) -> torch.Tensor:
        """
        Inverse standardization + inverse transform → original units.
        y_std: [N, T] in standardized-transformed space
        """
        self._require_fitted()
        y_tf = y_std * self.std + self.mean
        return self._inverse_transform_only(y_tf)

    def state_dict(self) -> Dict[str, torch.Tensor | List[str]]:
        """Return the fitted state (mean, std, transforms, eps) as a plain dict."""
        return {
            "mean": self.mean,
            "std": self.std,
            "transforms": self.transforms,
            "eps": self.eps,
        }

    def load_state_dict(self, state: Dict[str, torch.Tensor | List[str]]):
        """Restore state produced by ``state_dict``; 'transforms' and 'eps' are optional."""
        self.mean = state["mean"]
        self.std = state["std"]
        self.transforms = [str(t) for t in state.get("transforms", [])]
        eps = state.get("eps", None)
        self.eps = torch.as_tensor(eps, dtype=torch.float32) if eps is not None else None


def auto_select_task_transforms(
    y_train: torch.Tensor,           # [N, T] original units (train split only)
    mask_train: torch.Tensor,        # [N, T] bool
    task_names: Sequence[str],
    *,
    min_pos_frac: float = 0.95,      # ≥95% of labels positive
    orders_threshold: float = 2.0,   # ≥2 orders of magnitude between p95 and p5
    tiny: float = 1e-12,
) -> tuple[List[str], torch.Tensor]:
    """
    Decide per-task transform: "log10" if (mostly-positive AND large dynamic range), else "identity".
    Returns (transforms, eps_vector) where eps is only used for log tasks.

    Dynamic range is ``log10(p95 / max(p5, tiny))`` over the masked labels,
    floored at 0. Tasks with no labels get "identity". ``task_names`` is
    accepted for interface symmetry and is not read.
    """
    Y = y_train.detach().cpu().numpy()
    M = mask_train.detach().cpu().numpy().astype(bool)

    transforms: List[str] = []
    eps_vals: List[float] = []

    for t in range(Y.shape[1]):
        yt = Y[M[:, t], t]
        use_log = False
        if yt.size > 0:
            pos_frac = (yt > 0).mean()
            p5 = float(np.percentile(yt, 5))
            p95 = float(np.percentile(yt, 95))
            dyn_orders = float(np.log10(max(p95 / max(p5, tiny), 1.0)))
            use_log = (pos_frac >= min_pos_frac) and (dyn_orders >= orders_threshold)

        if use_log:
            transforms.append("log10")
            eps_vals.append(_auto_log_eps(yt))
        else:
            transforms.append("identity")
            eps_vals.append(0.0)

    return transforms, torch.tensor(eps_vals, dtype=torch.float32)


# ---------------------------------------------------------
# Dataset
# ---------------------------------------------------------
class MultiFidelityMoleculeDataset(Dataset):
    """
    Each item is a PyG Data with:
      - x: [N_nodes, F_node]
      - edge_index: [2, N_edges]
      - edge_attr: [N_edges, F_edge]
      - y: [1, T] normalized targets (zeros where missing)
      - y_mask: [1, T] bool mask of present targets
      - fid_idx: [1] long
      - .smiles and .fid_str added for debugging

    Targets are kept in the exact order provided by the user.

    NOTE(review): ``fid_idx`` is numbered over the fidelities present in THIS
    dataset's rows only. If a split lacks a fidelity, its numbering differs
    from the other splits; the global numbering is in the ``fid_idx`` column
    of ``rows``. Confirm which one the model expects.
    """

    def __init__(
        self,
        rows: pd.DataFrame,
        targets: Sequence[str],
        scaler: Optional[TargetScaler],
        smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    ):
        super().__init__()
        self.rows = rows.reset_index(drop=True).copy()
        self.targets = _ensure_targets_order(targets)
        self.scaler = scaler
        self.smiles_graph_cache = smiles_graph_cache

        # Build y and mask tensors, both [N, T] (also for N == 0).
        y, mask = _targets_to_tensors(self.rows, self.targets)
        if scaler is not None and scaler.mean is not None:
            y = scaler.transform(y)
        # Missing labels become 0 whether or not a scaler is applied; the mask
        # is the only source of truth for presence.
        self.y = torch.where(mask, y, torch.zeros_like(y))
        self.mask = mask

        # Input dims, read from one cached graph. An empty split falls back to
        # any cached graph; with no graph at all the dims are 0.
        if len(self.rows) > 0:
            sample = smiles_graph_cache[self.rows.iloc[0]["smiles"]]
        else:
            sample = next(iter(smiles_graph_cache.values()), None)
        self.in_dim_node = sample[0].shape[1] if sample is not None else 0
        self.in_dim_edge = sample[2].shape[1] if sample is not None else 0

        # Fidelity metadata for reference (local indexing in this dataset):
        # known fidelities in FID_PRIORITY order, unknown ones after them in
        # first-seen order.
        fid_lc = self.rows["fid"].str.lower()
        self.fids = sorted(
            fid_lc.unique().tolist(),
            key=lambda f: FID_PRIORITY.index(f) if f in FID_PRIORITY else len(FID_PRIORITY),
        )
        self.fid2idx = {f: i for i, f in enumerate(self.fids)}
        self.rows["fid_idx_local"] = fid_lc.map(self.fid2idx)

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> Data:
        idx = int(idx)
        r = self.rows.iloc[idx]
        smi = r["smiles"]

        x, edge_index, edge_attr = self.smiles_graph_cache[smi]
        # Ensure [1, T] so batches become [B, T]
        y_i = self.y[idx].clone().unsqueeze(0)      # [1, T]
        m_i = self.mask[idx].clone().unsqueeze(0)   # [1, T]
        fid_idx = int(r["fid_idx_local"])

        # Graph tensors are cloned so the shared cache is never mutated by
        # downstream transforms.
        d = Data(
            x=x.clone(),
            edge_index=edge_index.clone(),
            edge_attr=edge_attr.clone(),
            y=y_i,
            y_mask=m_i,
            fid_idx=torch.tensor([fid_idx], dtype=torch.long),
        )
        d.smiles = smi
        d.fid_str = r["fid"]
        return d


def subsample_train_indices(
    rows: pd.DataFrame,
    train_idx: np.ndarray,
    *,
    target: Optional[str],
    fidelity: Optional[str],
    pct: float = 1.0,
    seed: int = 137,
) -> np.ndarray:
    """
    Return a filtered train_idx that keeps only a 'pct' fraction (0, 1] of the
    unique SMILES in one (target, fidelity) block of the training rows.

    The block is every training row whose fidelity equals ``fidelity``
    (case-insensitive) and that has a label for ``target``. Rows outside the
    block are always kept; block rows are kept only if their SMILES was drawn.
    At least one SMILES is always kept, and ``pct <= 0`` is treated as 0.0001.

    ``train_idx`` is returned unchanged when ``target`` or ``fidelity`` is
    None, ``pct >= 0.999``, ``target`` is not a column of ``rows``, or the
    block is empty.

    ``train_idx`` holds row positions in ``rows``; so does the result.
    """
    if target is None or fidelity is None or pct >= 0.999:
        return train_idx
    if target not in rows.columns:
        return train_idx

    fid_lc = fidelity.strip().lower()

    # Identify TRAIN rows in the specified block: matching fid and having a label for 'target'
    train_rows = rows.iloc[train_idx]
    block_mask = (train_rows["fid"].str.lower() == fid_lc) & (~train_rows[target].isna())
    if not bool(block_mask.any()):
        return train_idx  # nothing to subsample

    # Sample by unique SMILES (stable & grouped)
    smiles_all = pd.Index(train_rows.loc[block_mask, "smiles"].unique())
    n_all = len(smiles_all)
    if n_all == 0:
        return train_idx

    if pct <= 0.0:
        pct = 0.0001
    n_keep = max(1, int(round(pct * n_all)))

    # Sorting first makes the draw independent of row order for a given seed.
    rng = np.random.RandomState(int(seed))
    smiles_sorted = np.array(sorted(smiles_all.tolist()))
    keep_smiles = set(rng.choice(smiles_sorted, size=n_keep, replace=False).tolist())

    # Keep all non-block rows; within block keep selected SMILES
    keep_mask_local = (~block_mask) | (train_rows["smiles"].isin(keep_smiles))
    # Index the positional input so the result is positional as well.
    return np.asarray(train_idx)[keep_mask_local.to_numpy()]


# ---------------------------------------------------------
# High-level builder
# ---------------------------------------------------------
def build_dataset_from_dir(
    root_dir: str | Path,
    targets: Sequence[str],
    fidelities: Sequence[str] = ("exp", "dft", "md", "gc"),
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
    save_splits_path: Optional[str | Path] = None,
    # Optional subsampling of a (target, fidelity) block in TRAIN
    subsample_target: Optional[str] = None,
    subsample_fidelity: Optional[str] = None,
    subsample_pct: float = 1.0,
    subsample_seed: int = 137,
    # -------- NEW: auto/explicit log transforms --------
    auto_log: bool = True,
    log_orders_threshold: float = 2.0,
    log_min_pos_frac: float = 0.95,
    explicit_log_targets: Optional[Sequence[str]] = None,  # e.g. ["permeability"]
) -> tuple[MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, TargetScaler]:
    """
    Returns train_ds, val_ds, test_ds, scaler.

    - Discovers CSVs for requested targets and fidelities
    - Builds a row-per-(smiles,fid) table with columns for each target
    - Splits by unique SMILES to avoid leakage across fidelity or targets
    - Fits transform+normalization on the training split only, applies to val/test
    - Builds RDKit graphs once per unique SMILES and reuses them

    Splits: if ``save_splits_path`` exists it is loaded (JSON with
    ``train_smiles`` / ``val_smiles`` / ``test_smiles``); otherwise a new
    split is drawn and, if a path was given, written there.

    SMILES that fail featurization are dropped with a warning.

    Transforms: a non-empty ``explicit_log_targets`` forces "log10" for those
    targets and "identity" for the rest (eps chosen by the scaler); otherwise
    ``auto_log`` selects per task via ``auto_select_task_transforms``;
    otherwise everything is "identity".
    """
    root = Path(root_dir)
    targets = _ensure_targets_order(targets)
    fids_lc = [_norm_fid(f) for f in fidelities]

    # Build long and pivot to rows
    long = build_long_table(root, targets, fids_lc)
    rows = pivot_to_rows_by_smiles_fid(long, targets)

    # Deterministic grouped split by SMILES. Split membership is held as SMILES
    # sets so that it survives the row filtering further down.
    split_names = ("train", "val", "test")
    splits_file = Path(save_splits_path) if save_splits_path is not None else None
    if splits_file is not None and splits_file.exists():
        with open(splits_file, "r") as f:
            split_obj = json.load(f)
        split_smiles = {name: set(split_obj[f"{name}_smiles"]) for name in split_names}

        # Rows whose SMILES is in no stored split end up in no dataset.
        assigned = split_smiles["train"] | split_smiles["val"] | split_smiles["test"]
        n_unassigned = int((~rows["smiles"].isin(assigned)).sum())
        if n_unassigned:
            warnings.warn(
                f"[data_builder] {n_unassigned} row(s) have SMILES not present in "
                f"{splits_file}; they are excluded from all splits."
            )
    else:
        split_idx = grouped_split_by_smiles(rows, val_ratio=val_ratio, test_ratio=test_ratio, seed=seed)
        ordered = {
            name: rows.iloc[idx]["smiles"].drop_duplicates().tolist()
            for name, idx in zip(split_names, split_idx)
        }
        split_smiles = {name: set(smiles) for name, smiles in ordered.items()}
        if splits_file is not None:
            split_obj = {
                "train_smiles": ordered["train"],
                "val_smiles": ordered["val"],
                "test_smiles": ordered["test"],
                "seed": seed,
                "val_ratio": val_ratio,
                "test_ratio": test_ratio,
            }
            splits_file.parent.mkdir(parents=True, exist_ok=True)
            with open(splits_file, "w") as f:
                json.dump(split_obj, f, indent=2)

    # Build RDKit graphs once per unique SMILES
    smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
    for smi in rows["smiles"].drop_duplicates().tolist():
        try:
            smiles_graph_cache[smi] = featurize_smiles(smi)
        except Exception as e:
            # Deliberately broad: any featurization failure drops the molecule
            # instead of aborting the whole build.
            warnings.warn(f"[data_builder] Dropping SMILES due to RDKit parse error: {smi} ({e})")

    # Filter rows to those that featurized successfully
    rows = rows[rows["smiles"].isin(set(smiles_graph_cache))].reset_index(drop=True)

    # Recompute row positions on the filtered frame from SMILES membership.
    # (Positions computed before the filter no longer line up with it.)
    smiles_series = rows["smiles"]
    train_idx = np.flatnonzero(smiles_series.isin(split_smiles["train"]).to_numpy())
    val_idx = np.flatnonzero(smiles_series.isin(split_smiles["val"]).to_numpy())
    test_idx = np.flatnonzero(smiles_series.isin(split_smiles["test"]).to_numpy())

    # Optional subsampling (train only) for a specific (target, fidelity) block
    train_idx = subsample_train_indices(
        rows,
        train_idx,
        target=subsample_target,
        fidelity=subsample_fidelity,
        pct=float(subsample_pct),
        seed=int(subsample_seed),
    )

    # Fit scaler on training split only
    y_train, mask_train = _targets_to_tensors(rows.iloc[train_idx], targets)

    # Decide transforms per task
    if explicit_log_targets:
        explicit_set = set(explicit_log_targets)
        transforms = [("log10" if t in explicit_set else "identity") for t in targets]
        eps_vec = None  # auto-chosen in scaler.fit
    elif auto_log:
        transforms, eps_vec = auto_select_task_transforms(
            y_train,
            mask_train,
            targets,
            min_pos_frac=float(log_min_pos_frac),
            orders_threshold=float(log_orders_threshold),
        )
    else:
        transforms, eps_vec = (["identity"] * len(targets), None)

    scaler = TargetScaler(transforms=transforms, eps=eps_vec)
    scaler.fit(y_train, mask_train)

    # Build datasets
    train_rows = rows.iloc[train_idx].reset_index(drop=True)
    val_rows = rows.iloc[val_idx].reset_index(drop=True)
    test_rows = rows.iloc[test_idx].reset_index(drop=True)

    train_ds = MultiFidelityMoleculeDataset(train_rows, targets, scaler, smiles_graph_cache)
    val_ds = MultiFidelityMoleculeDataset(val_rows, targets, scaler, smiles_graph_cache)
    test_ds = MultiFidelityMoleculeDataset(test_rows, targets, scaler, smiles_graph_cache)

    return train_ds, val_ds, test_ds, scaler


__all__ = [
    "build_dataset_from_dir",
    "discover_target_fid_csvs",
    "read_target_csv",
    "build_long_table",
    "pivot_to_rows_by_smiles_fid",
    "grouped_split_by_smiles",
    "TargetScaler",
    "MultiFidelityMoleculeDataset",
    "atom_features",
    "bond_features",
    "featurize_smiles",
    "auto_select_task_transforms",
]