sobinalosious92 commited on
Commit
a22718f
·
verified ·
1 Parent(s): 83601ad

Delete src

Browse files
src/conv.py DELETED
@@ -1,258 +0,0 @@
1
- # conv.py
2
- # Clean, dependency-light graph encoder blocks for molecular GNNs.
3
- # - Single source of truth for convolution choices: "gine", "gin", "gcn"
4
- # - Edge attributes are supported for "gine" (recommended for chemistry)
5
- # - No duplication with PyG built-ins; everything wraps torch_geometric.nn
6
- # - Consistent encoder API: GNNEncoder(...).forward(x, edge_index, edge_attr, batch) -> graph embedding [B, emb_dim]
7
-
8
- from __future__ import annotations
9
- from typing import Literal, Optional
10
-
11
- import torch
12
- import torch.nn as nn
13
- import torch.nn.functional as F
14
- from torch_geometric.nn import (
15
- GINEConv,
16
- GINConv,
17
- GCNConv,
18
- global_mean_pool,
19
- global_add_pool,
20
- global_max_pool,
21
- )
22
-
23
-
24
def get_activation(name: str) -> nn.Module:
    """Return a fresh activation module for *name* (case-insensitive).

    Supported: "relu", "gelu", "silu", "leaky_relu"/"lrelu" (negative slope 0.1).

    Raises:
        ValueError: for any unrecognized activation name.
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "leaky_relu": lambda: nn.LeakyReLU(0.1),
        "lrelu": lambda: nn.LeakyReLU(0.1),
    }
    key = name.lower()
    try:
        return factories[key]()
    except KeyError:
        # Matches the original message, which reports the lowercased name.
        raise ValueError(f"Unknown activation: {key}") from None
35
-
36
-
37
class MLP(nn.Module):
    """Small feed-forward network used inside GNN layers and projections.

    Stacks `num_layers` Linear layers; the chosen activation (and optional
    dropout) sits between hidden layers only — the final Linear output is
    returned raw.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int = 2,
        act: str = "relu",
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        assert num_layers >= 1
        widths = [in_dim, *([hidden_dim] * (num_layers - 1)), out_dim]
        modules: list[nn.Module] = []
        last_linear = len(widths) - 2
        for i, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(d_in, d_out, bias=bias))
            if i < last_linear:
                # Nonlinearity (and dropout) only between layers, never after
                # the output layer.
                modules.append(get_activation(act))
                if dropout > 0:
                    modules.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to a [..., in_dim] tensor, returning [..., out_dim]."""
        return self.net(x)
63
-
64
-
65
class NodeProjector(nn.Module):
    """Maps raw node features to the model embedding size.

    When the input width already equals ``emb_dim`` this is the identity;
    otherwise a single Linear followed by the chosen activation.
    """

    def __init__(self, in_dim_node: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        if in_dim_node != emb_dim:
            self.proj = nn.Sequential(
                nn.Linear(in_dim_node, emb_dim),
                get_activation(act),
            )
        else:
            # No projection needed — pass features through untouched.
            self.proj = nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project [N, in_dim_node] node features to [N, emb_dim]."""
        return self.proj(x)
79
-
80
-
81
class EdgeProjector(nn.Module):
    """Maps raw edge attributes to the model embedding size (used by GINE).

    Raises:
        ValueError: if ``in_dim_edge`` is not positive — GINE needs real
            edge attributes to project.
    """

    def __init__(self, in_dim_edge: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        if in_dim_edge <= 0:
            raise ValueError("in_dim_edge must be > 0 when using edge attributes")
        self.proj = nn.Sequential(nn.Linear(in_dim_edge, emb_dim), get_activation(act))

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        """Project [E, in_dim_edge] edge attributes to [E, emb_dim]."""
        return self.proj(e)
94
-
95
-
96
class GNNEncoder(nn.Module):
    """
    Backbone GNN with selectable conv type.

    gnn_type:
      - "gine": chemistry-ready, uses edge_attr (recommended)
      - "gin" : ignores edge_attr, strong node MPNN
      - "gcn" : ignores edge_attr, fast spectral conv
    norm: "batch" | "layer" | "none"
    readout: "mean" | "sum" | "max"

    All conv layers map emb_dim -> emb_dim; raw node (and, for GINE, edge)
    features are projected to emb_dim once up front.
    """

    def __init__(
        self,
        in_dim_node: int,
        emb_dim: int,
        num_layers: int = 5,
        gnn_type: Literal["gine", "gin", "gcn"] = "gine",
        in_dim_edge: int = 0,
        act: str = "relu",
        dropout: float = 0.0,
        residual: bool = True,
        norm: Literal["batch", "layer", "none"] = "batch",
        readout: Literal["mean", "sum", "max"] = "mean",
    ):
        super().__init__()
        assert num_layers >= 1

        self.gnn_type = gnn_type.lower()
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.residual = residual
        self.dropout_p = float(dropout)
        self.readout = readout.lower()

        # Raw node features -> emb_dim; edge projector exists only for GINE.
        self.node_proj = NodeProjector(in_dim_node, emb_dim, act=act)
        self.edge_proj: Optional[EdgeProjector] = None

        if self.gnn_type == "gine":
            # GINE consumes edge features during message passing, so they
            # are mandatory for this conv type.
            if in_dim_edge <= 0:
                raise ValueError(
                    "gine selected but in_dim_edge <= 0. Provide edge attributes or switch gnn_type."
                )
            self.edge_proj = EdgeProjector(in_dim_edge, emb_dim, act=act)

        # Build conv stack
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()

        for _ in range(num_layers):
            if self.gnn_type == "gine":
                # edge_attr must be projected to emb_dim
                nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0)
                conv = GINEConv(nn_mlp)
            elif self.gnn_type == "gin":
                nn_mlp = MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0)
                conv = GINConv(nn_mlp)
            elif self.gnn_type == "gcn":
                conv = GCNConv(emb_dim, emb_dim, add_self_loops=True, normalize=True)
            else:
                raise ValueError(f"Unknown gnn_type: {gnn_type}")
            self.convs.append(conv)

            # One normalization module per conv layer, applied before the
            # shared activation in forward().
            if norm == "batch":
                self.norms.append(nn.BatchNorm1d(emb_dim))
            elif norm == "layer":
                self.norms.append(nn.LayerNorm(emb_dim))
            elif norm == "none":
                self.norms.append(nn.Identity())
            else:
                raise ValueError(f"Unknown norm: {norm}")

        self.act = get_activation(act)

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool node embeddings [N, emb_dim] into graph embeddings [B, emb_dim]."""
        if self.readout == "mean":
            return global_mean_pool(x, batch)
        if self.readout == "sum":
            return global_add_pool(x, batch)
        if self.readout == "max":
            return global_max_pool(x, batch)
        raise ValueError(f"Unknown readout: {self.readout}")

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        batch: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Returns a graph-level embedding of shape [B, emb_dim].
        If batch is None, assumes a single graph and creates a zero batch vector.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        # Project features (ensure float dtype)
        x = x.float()
        x = self.node_proj(x)

        e = None
        if self.gnn_type == "gine":
            if edge_attr is None:
                raise ValueError("GINE requires edge_attr, but got None.")
            e = self.edge_proj(edge_attr.float())

        # Message passing
        h = x
        for conv, norm in zip(self.convs, self.norms):
            if self.gnn_type == "gcn":
                h_next = conv(h, edge_index)  # GCNConv ignores edge_attr
            elif self.gnn_type == "gin":
                h_next = conv(h, edge_index)  # GINConv ignores edge_attr
            else:  # gine
                h_next = conv(h, edge_index, e)

            h_next = norm(h_next)
            h_next = self.act(h_next)

            # Residual add: h and h_next are both [N, emb_dim] here, so the
            # shape check always passes when residual is enabled.
            if self.residual and h_next.shape == h.shape:
                h = h + h_next
            else:
                h = h_next

            if self.dropout_p > 0:
                h = F.dropout(h, p=self.dropout_p, training=self.training)

        g = self._readout(h, batch)
        return g  # [B, emb_dim]
226
-
227
-
228
def build_gnn_encoder(
    in_dim_node: int,
    emb_dim: int,
    num_layers: int = 5,
    gnn_type: Literal["gine", "gin", "gcn"] = "gine",
    in_dim_edge: int = 0,
    act: str = "relu",
    dropout: float = 0.0,
    residual: bool = True,
    norm: Literal["batch", "layer", "none"] = "batch",
    readout: Literal["mean", "sum", "max"] = "mean",
) -> GNNEncoder:
    """
    Factory to create a GNNEncoder with a consistent, minimal API.
    Prefer calling this from model.py so encoder construction is centralized.
    """
    encoder_kwargs = {
        "in_dim_node": in_dim_node,
        "emb_dim": emb_dim,
        "num_layers": num_layers,
        "gnn_type": gnn_type,
        "in_dim_edge": in_dim_edge,
        "act": act,
        "dropout": dropout,
        "residual": residual,
        "norm": norm,
        "readout": readout,
    }
    return GNNEncoder(**encoder_kwargs)


__all__ = ["GNNEncoder", "build_gnn_encoder"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/data_builder.py DELETED
@@ -1,818 +0,0 @@
1
- # data_builder.py
2
- from __future__ import annotations
3
-
4
- from pathlib import Path
5
- from typing import Dict, List, Optional, Tuple, Sequence
6
- import json
7
- import warnings
8
-
9
- import numpy as np
10
- import pandas as pd
11
- import torch
12
- from torch.utils.data import Dataset
13
- from torch_geometric.data import Data
14
-
15
- # RDKit is required
16
- from rdkit import Chem
17
- from rdkit.Chem.rdchem import HybridizationType, BondType, BondStereo
18
-
19
# ---------------------------------------------------------
# Fidelity handling
# ---------------------------------------------------------

# Canonical fidelity ordering (lower-case): experimental data first, then
# DFT, molecular dynamics, and group-contribution estimates. Positions in
# this list define the global `fid_idx` attached in build_long_table.
FID_PRIORITY = ["exp", "dft", "md", "gc"]  # internal lower-case canonical order
24
-
25
-
26
- def _norm_fid(fid: str) -> str:
27
- return fid.strip().lower()
28
-
29
-
30
- def _ensure_targets_order(requested: Sequence[str]) -> List[str]:
31
- seen = set()
32
- ordered = []
33
- for t in requested:
34
- key = t.strip()
35
- if key in seen:
36
- continue
37
- seen.add(key)
38
- ordered.append(key)
39
- return ordered
40
-
41
-
42
# ---------------------------------------------------------
# RDKit featurization
# ---------------------------------------------------------

# Known element vocabulary; anything else maps to an extra "other" bucket
# at encode time (see atom_features).
_ATOMS = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
_ATOM2IDX = {s: i for i, s in enumerate(_ATOMS)}
# Hybridization states we one-hot encode (plus an "other" bucket at encode time).
_HYBS = [HybridizationType.SP, HybridizationType.SP2, HybridizationType.SP3, HybridizationType.SP3D, HybridizationType.SP3D2]
_HYB2IDX = {h: i for i, h in enumerate(_HYBS)}
# Bond stereo descriptors in a fixed order for one-hot encoding.
_BOND_STEREOS = [
    BondStereo.STEREONONE,
    BondStereo.STEREOANY,
    BondStereo.STEREOZ,
    BondStereo.STEREOE,
    BondStereo.STEREOCIS,
    BondStereo.STEREOTRANS,
]
_STEREO2IDX = {s: i for i, s in enumerate(_BOND_STEREOS)}
59
-
60
-
61
- def _one_hot(index: int, size: int) -> List[float]:
62
- v = [0.0] * size
63
- if 0 <= index < size:
64
- v[index] = 1.0
65
- return v
66
-
67
-
68
def atom_features(atom: Chem.Atom) -> List[float]:
    """
    Featurize one RDKit atom into a fixed-length float vector (length 35).

    Blocks, in order: element one-hot (10 known + "other" = 11), degree
    one-hot bucketed at 5+ (6), formal charge one-hot clamped to [-2, +2]
    (5), aromatic flag (1), in-ring flag (1), hybridization one-hot with
    "other" (6), total-H count capped at 4 (5).
    """
    # Element one-hot with "other"
    elem_idx = _ATOM2IDX.get(atom.GetSymbol(), None)
    elem_oh = _one_hot(elem_idx if elem_idx is not None else len(_ATOMS), len(_ATOMS) + 1)

    # Degree one-hot up to 5 (bucket 5+)
    deg = min(int(atom.GetDegree()), 5)
    deg_oh = _one_hot(deg, 6)

    # Formal charge one-hot in [-2,-1,0,+1,+2]
    fc = max(-2, min(2, int(atom.GetFormalCharge())))
    fc_oh = _one_hot(fc + 2, 5)

    # Aromatic, in ring flags
    aromatic = [1.0 if atom.GetIsAromatic() else 0.0]
    in_ring = [1.0 if atom.IsInRing() else 0.0]

    # Hybridization one-hot with "other"
    hyb_idx = _HYB2IDX.get(atom.GetHybridization(), None)
    hyb_oh = _one_hot(hyb_idx if hyb_idx is not None else len(_HYBS), len(_HYBS) + 1)

    # Implicit H count capped at 4
    imp_h = min(int(atom.GetTotalNumHs(includeNeighbors=True)), 4)
    imp_h_oh = _one_hot(imp_h, 5)

    # length: 11+6+5+1+1+6+5 = 35 (element has 11 buckets incl. "other")
    feats = elem_oh + deg_oh + fc_oh + aromatic + in_ring + hyb_oh + imp_h_oh
    return feats
96
-
97
-
98
def bond_features(bond: Chem.Bond) -> List[float]:
    """
    Featurize one RDKit bond into a 12-dim float vector: four bond-type
    flags (single/double/triple/aromatic), a conjugation flag, a ring flag,
    then a 6-way stereo one-hot (unknown stereo maps to STEREONONE).
    """
    bt = bond.GetBondType()
    single = 1.0 if bt == BondType.SINGLE else 0.0
    double = 1.0 if bt == BondType.DOUBLE else 0.0
    triple = 1.0 if bt == BondType.TRIPLE else 0.0
    aromatic = 1.0 if bt == BondType.AROMATIC else 0.0
    conj = 1.0 if bond.GetIsConjugated() else 0.0
    in_ring = 1.0 if bond.IsInRing() else 0.0
    stereo_oh = _one_hot(_STEREO2IDX.get(bond.GetStereo(), 0), len(_BOND_STEREOS))
    # length: 4 + 1 + 1 + 6 = 12
    return [single, double, triple, aromatic, conj, in_ring] + stereo_oh
109
-
110
-
111
def featurize_smiles(smiles: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Convert a SMILES string into graph tensors.

    Returns:
        x: [N_atoms, 35] float node features (see atom_features)
        edge_index: [2, 2*N_bonds] long indices, both directions per bond
        edge_attr: [2*N_bonds, 12] float bond features (see bond_features)

    Raises:
        ValueError: if RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit failed to parse SMILES: {smiles}")

    # Nodes
    x = torch.tensor([atom_features(a) for a in mol.GetAtoms()], dtype=torch.float32)

    # Edges (bidirectional)
    rows, cols, eattr = [], [], []
    for b in mol.GetBonds():
        i, j = b.GetBeginAtomIdx(), b.GetEndAtomIdx()
        bf = bond_features(b)
        rows.extend([i, j])
        cols.extend([j, i])
        eattr.extend([bf, bf])

    if not rows:
        # single-atom molecules, add a dummy self-loop edge
        # NOTE(review): the hard-coded 12 must stay in sync with the length
        # of bond_features(); consider deriving it instead.
        rows, cols = [0], [0]
        eattr = [[0.0] * 12]

    edge_index = torch.tensor([rows, cols], dtype=torch.long)
    edge_attr = torch.tensor(eattr, dtype=torch.float32)
    return x, edge_index, edge_attr
136
-
137
-
138
- # ---------------------------------------------------------
139
- # CSV discovery and reading
140
- # ---------------------------------------------------------
141
-
142
def discover_target_fid_csvs(
    root: Path,
    targets: Sequence[str],
    fidelities: Sequence[str],
) -> Dict[tuple[str, str], Path]:
    """
    Discover CSV files for (target, fidelity) pairs.

    Supported layouts (case-insensitive):

    1) {root}/{fid}/{target}.csv
       e.g. datafull/MD/SHEAR.csv, datafull/exp/cp.csv

    2) {root}/{target}_{fid}.csv
       e.g. datafull/SHEAR_MD.csv, datafull/cp_exp.csv

    Matching is STRICT:
      - target and fid must appear as full '_' tokens in the stem
      - no substring matching, so 'he' will NOT match 'shear_md.csv'

    Returns a mapping keyed by (target, fid) with the first matching path;
    pairs with no matching CSV are simply absent from the result.
    """
    root = Path(root)
    targets = _ensure_targets_order(targets)
    fids_lc = [_norm_fid(f) for f in fidelities]

    # Collect all CSVs under root
    all_paths = list(root.rglob("*.csv"))

    # Pre-index: (parent_name_lower, stem_lower, tokens_lower)
    indexed = []
    for p in all_paths:
        parent = p.parent.name.lower()
        stem = p.stem.lower()  # filename without extension
        tokens = stem.split("_")
        # stem is already lower-cased; the extra lower() here is a harmless no-op.
        tokens_l = [t.lower() for t in tokens]
        indexed.append((p, parent, stem, tokens_l))

    mapping: Dict[tuple[str, str], Path] = {}

    for fid in fids_lc:
        fid_l = fid.strip().lower()

        for tgt in targets:
            tgt_l = tgt.strip().lower()

            # ---- 1) Prefer explicit folder layout: {root}/{fid}/{target}.csv ----
            # parent == fid AND stem == target (case-insensitive)
            folder_matches = [
                p for (p, parent, stem, tokens_l) in indexed
                if parent == fid_l and stem == tgt_l
            ]
            if folder_matches:
                # If you ever get more than one, it's a config problem
                if len(folder_matches) > 1:
                    warnings.warn(
                        f"[discover_target_fid_csvs] Multiple matches for "
                        f"target='{tgt}' fid='{fid}' under folder layout: "
                        + ", ".join(str(p) for p in folder_matches)
                    )
                mapping[(tgt, fid)] = folder_matches[0]
                continue

            # ---- 2) Fallback: {target}_{fid}.csv anywhere under root ----
            # require BOTH tgt and fid as full '_' tokens
            token_matches = [
                p for (p, parent, stem, tokens_l) in indexed
                if (tgt_l in tokens_l) and (fid_l in tokens_l)
            ]

            if token_matches:
                if len(token_matches) > 1:
                    warnings.warn(
                        f"[discover_target_fid_csvs] Multiple token matches for "
                        f"target='{tgt}' fid='{fid}': "
                        + ", ".join(str(p) for p in token_matches)
                    )
                mapping[(tgt, fid)] = token_matches[0]
                continue

            # If neither layout exists, we simply do not add (tgt, fid) to mapping.
            # build_long_table will just skip that combination.
            # You can enable a warning if you want:
            # warnings.warn(f"[discover_target_fid_csvs] No CSV for target='{tgt}', fid='{fid}'")

    return mapping
226
-
227
-
228
def read_target_csv(path: Path, target: str) -> pd.DataFrame:
    """
    Load one labeled-SMILES CSV and return a frame with columns ["smiles", target].

    Accepts:
      - a case-insensitive 'smiles' column
      - a value column named '{target}' exactly, or (case-insensitively)
        'value', 'y', or the lower-cased target name
    Non-numeric and missing values are dropped; duplicate SMILES are
    averaged (with a warning).
    """
    df = pd.read_csv(path)

    # Locate and canonicalize the SMILES column.
    smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None)
    if smiles_col is None:
        raise ValueError(f"{path} must contain a 'smiles' column.")
    df = df.rename(columns={smiles_col: "smiles"})

    # Locate the value column: exact target name wins, then the first
    # case-insensitive match among the accepted aliases (column order).
    if target in df.columns:
        val_col = target
    else:
        accepted = ("value", "y", target.lower())
        val_col = next((c for c in df.columns if c.lower() in accepted), None)
    if val_col is None:
        raise ValueError(f"{path} must contain a '{target}' column or one of ['value','y'].")

    # Keep only what we need, coerce to numeric, and drop unusable rows.
    df = df[["smiles", val_col]].copy()
    df = df.dropna(subset=[val_col])
    df[val_col] = pd.to_numeric(df[val_col], errors="coerce")
    df = df.dropna(subset=[val_col])

    # Collapse repeated SMILES to their mean label.
    if df.duplicated(subset=["smiles"]).any():
        warnings.warn(f"[data_builder] Duplicates by SMILES in {path}. Averaging duplicates.")
        df = df.groupby("smiles", as_index=False)[val_col].mean()

    return df.rename(columns={val_col: target})
266
-
267
-
268
def build_long_table(root: Path, targets: Sequence[str], fidelities: Sequence[str]) -> pd.DataFrame:
    """
    Build the long-form label table with columns [smiles, fid, fid_idx, target, value].

    Discovers one CSV per (target, fidelity) pair under `root`, stacks them,
    and attaches a numeric fidelity index following FID_PRIORITY. Unknown
    fidelities trigger a warning and are indexed after the known ones.

    Raises:
        FileNotFoundError: if no CSV was found for any requested pair.
    """
    targets = _ensure_targets_order(targets)
    fids_lc = [_norm_fid(f) for f in fidelities]

    # FIX: pass the normalized fidelity list to discovery. The original
    # computed fids_lc and then ignored it; behavior is identical because
    # discovery normalizes internally, but this removes the dead local.
    mapping = discover_target_fid_csvs(root, targets, fids_lc)
    if not mapping:
        raise FileNotFoundError(f"No CSVs found under {root} for the given targets and fidelities.")

    long_rows = []
    for (tgt, fid), path in mapping.items():
        df = read_target_csv(path, tgt)
        df["fid"] = _norm_fid(fid)
        df["target"] = tgt
        df = df.rename(columns={tgt: "value"})
        long_rows.append(df[["smiles", "fid", "target", "value"]])

    long = pd.concat(long_rows, axis=0, ignore_index=True)

    # attach fid index by priority; unknown fidelities go after the known ones
    fid2idx = {f: i for i, f in enumerate(FID_PRIORITY)}
    long["fid"] = long["fid"].str.lower()
    unknown = sorted(set(long["fid"]) - set(fid2idx.keys()))
    if unknown:
        warnings.warn(f"[data_builder] Unknown fidelities found: {unknown}. Appending after known ones.")
        start = len(fid2idx)
        for i, f in enumerate(unknown):
            fid2idx[f] = start + i

    long["fid_idx"] = long["fid"].map(fid2idx)
    return long
301
-
302
-
303
def pivot_to_rows_by_smiles_fid(long: pd.DataFrame, targets: Sequence[str]) -> pd.DataFrame:
    """
    Pivot the long table [smiles, fid, fid_idx, target, value] into one row
    per (smiles, fid) with one wide column per target. Targets with no data
    become all-NaN columns so every requested target is always present.
    """
    targets = _ensure_targets_order(targets)
    wide = (
        long.pivot_table(index=["smiles", "fid", "fid_idx"], columns="target", values="value", aggfunc="mean")
        .reset_index()
    )

    # Guarantee a column for every requested target, even if no CSV had it.
    for t in (t for t in targets if t not in wide.columns):
        wide[t] = np.nan

    return wide[["smiles", "fid", "fid_idx", *targets]]
318
-
319
-
320
- # ---------------------------------------------------------
321
- # Grouped split by SMILES and transforms/normalization
322
- # ---------------------------------------------------------
323
-
324
def grouped_split_by_smiles(
    df_rows: pd.DataFrame,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Split row indices into (train, val, test) grouped by unique SMILES, so a
    molecule never appears in more than one split. Returns index arrays into
    df_rows. The split is deterministic for a given seed.
    """
    # Shuffle the unique molecules once (same rng sequence as before).
    rng = np.random.default_rng(seed)
    shuffled = rng.permutation(df_rows["smiles"].drop_duplicates().values)

    n = len(shuffled)
    n_test = int(round(n * test_ratio))
    n_val = int(round(n * val_ratio))

    # Test takes the front, then val, then everything else is train.
    test_smiles = set(shuffled[:n_test])
    val_smiles = set(shuffled[n_test:n_test + n_val])
    train_smiles = set(shuffled[n_test + n_val:])

    def _rows_for(smiles_set: set) -> np.ndarray:
        return df_rows.index[df_rows["smiles"].isin(smiles_set)].to_numpy()

    return _rows_for(train_smiles), _rows_for(val_smiles), _rows_for(test_smiles)
346
-
347
-
348
- # ---------------- Enhanced TargetScaler with per-task transforms ----------------
349
-
350
- class TargetScaler:
351
- """
352
- Per-task transform + standardization fitted on the training split only.
353
-
354
- - transforms[t] in {"identity","log10"}
355
- - eps[t] is added before log for numerical safety (only used if transforms[t]=="log10")
356
- - mean/std are computed in the *transformed* domain
357
- """
358
- def __init__(self, transforms: Optional[Sequence[str]] = None, eps: Optional[Sequence[float] | torch.Tensor] = None):
359
- self.mean: Optional[torch.Tensor] = None # [T] (transformed domain)
360
- self.std: Optional[torch.Tensor] = None # [T] (transformed domain)
361
- self.transforms: List[str] = [str(t).lower() for t in transforms] if transforms is not None else []
362
- if eps is None:
363
- self.eps: Optional[torch.Tensor] = None
364
- else:
365
- self.eps = torch.as_tensor(eps, dtype=torch.float32)
366
- self._tiny = 1e-12
367
-
368
- def _ensure_cfg(self, T: int):
369
- if not self.transforms or len(self.transforms) != T:
370
- self.transforms = ["identity"] * T
371
- if self.eps is None or self.eps.numel() != T:
372
- self.eps = torch.zeros(T, dtype=torch.float32)
373
-
374
- def _forward_transform_only(self, y: torch.Tensor) -> torch.Tensor:
375
- """
376
- Apply per-task transforms *before* standardization.
377
- y: [N, T] in original units. Returns transformed y_tf in same shape.
378
- """
379
- out = y.clone()
380
- T = out.size(1)
381
- self._ensure_cfg(T)
382
- for t in range(T):
383
- if self.transforms[t] == "log10":
384
- out[:, t] = torch.log10(torch.clamp(out[:, t] + self.eps[t], min=self._tiny))
385
- return out
386
-
387
- def _inverse_transform_only(self, y_tf: torch.Tensor) -> torch.Tensor:
388
- """
389
- Inverse the per-task transform (no standardization here).
390
- y_tf: [N, T] in transformed units.
391
- """
392
- out = y_tf.clone()
393
- T = out.size(1)
394
- self._ensure_cfg(T)
395
- for t in range(T):
396
- if self.transforms[t] == "log10":
397
- out[:, t] = (10.0 ** out[:, t]) - self.eps[t]
398
- return out
399
-
400
- def fit(self, y: torch.Tensor, mask: torch.Tensor):
401
- """
402
- y: [N, T] original units; mask: [N, T] bool
403
- Chooses eps automatically if not provided; mean/std computed in transformed space.
404
- """
405
- T = y.size(1)
406
- self._ensure_cfg(T)
407
-
408
- if self.eps is None or self.eps.numel() != T:
409
- # Auto epsilon: 0.1 * min positive per task (robust)
410
- eps_vals: List[float] = []
411
- y_np = y.detach().cpu().numpy()
412
- m_np = mask.detach().cpu().numpy().astype(bool)
413
- for t in range(T):
414
- if self.transforms[t] != "log10":
415
- eps_vals.append(0.0)
416
- continue
417
- vals = y_np[m_np[:, t], t]
418
- pos = vals[vals > 0]
419
- if pos.size == 0:
420
- eps_vals.append(1e-8)
421
- else:
422
- eps_vals.append(0.1 * float(max(np.min(pos), 1e-8)))
423
- self.eps = torch.tensor(eps_vals, dtype=torch.float32)
424
-
425
- y_tf = self._forward_transform_only(y)
426
- eps = 1e-8
427
- y_masked = torch.where(mask, y_tf, torch.zeros_like(y_tf))
428
- counts = mask.sum(dim=0).clamp_min(1)
429
- mean = y_masked.sum(dim=0) / counts
430
- var = ((torch.where(mask, y_tf - mean, torch.zeros_like(y_tf))) ** 2).sum(dim=0) / counts
431
- std = torch.sqrt(var + eps)
432
- self.mean, self.std = mean, std
433
-
434
- def transform(self, y: torch.Tensor) -> torch.Tensor:
435
- y_tf = self._forward_transform_only(y)
436
- return (y_tf - self.mean) / self.std
437
-
438
- def inverse(self, y_std: torch.Tensor) -> torch.Tensor:
439
- """
440
- Inverse standardization + inverse transform → original units.
441
- y_std: [N, T] in standardized-transformed space
442
- """
443
- y_tf = y_std * self.std + self.mean
444
- return self._inverse_transform_only(y_tf)
445
-
446
- def state_dict(self) -> Dict[str, torch.Tensor | List[str]]:
447
- return {
448
- "mean": self.mean,
449
- "std": self.std,
450
- "transforms": self.transforms,
451
- "eps": self.eps,
452
- }
453
-
454
- def load_state_dict(self, state: Dict[str, torch.Tensor | List[str]]):
455
- self.mean = state["mean"]
456
- self.std = state["std"]
457
- self.transforms = [str(t) for t in state.get("transforms", [])]
458
- eps = state.get("eps", None)
459
- self.eps = torch.as_tensor(eps, dtype=torch.float32) if eps is not None else None
460
-
461
-
462
def auto_select_task_transforms(
    y_train: torch.Tensor,      # [N, T] original units (train split only)
    mask_train: torch.Tensor,   # [N, T] bool
    task_names: Sequence[str],
    *,
    min_pos_frac: float = 0.95,     # ≥95% of labels positive
    orders_threshold: float = 2.0,  # ≥2 orders of magnitude between p95 and p5
    tiny: float = 1e-12,
) -> tuple[List[str], torch.Tensor]:
    """
    Decide per-task transform: "log10" if (mostly-positive AND large dynamic
    range), else "identity". Returns (transforms, eps_vector); eps is only
    meaningful for log tasks (0.1 * smallest positive label, floored at 1e-8).
    """
    Y = y_train.detach().cpu().numpy()
    M = mask_train.detach().cpu().numpy().astype(bool)

    transforms: List[str] = []
    eps_vals: List[float] = []

    for t in range(Y.shape[1]):
        labeled = Y[M[:, t], t]
        if labeled.size == 0:
            # No training labels for this task — leave it untransformed.
            transforms.append("identity")
            eps_vals.append(0.0)
            continue

        mostly_positive = (labeled > 0).mean() >= min_pos_frac
        p5 = float(np.percentile(labeled, 5))
        p95 = float(np.percentile(labeled, 95))
        spread_orders = float(np.log10(max(p95 / max(p5, tiny), 1.0)))

        if mostly_positive and spread_orders >= orders_threshold:
            transforms.append("log10")
            positives = labeled[labeled > 0]
            if positives.size == 0:
                eps_vals.append(1e-8)
            else:
                eps_vals.append(0.1 * float(max(np.min(positives), 1e-8)))
        else:
            transforms.append("identity")
            eps_vals.append(0.0)

    return transforms, torch.tensor(eps_vals, dtype=torch.float32)
507
-
508
-
509
- # ---------------------------------------------------------
510
- # Dataset
511
- # ---------------------------------------------------------
512
-
513
class MultiFidelityMoleculeDataset(Dataset):
    """
    Each item is a PyG Data with:
      - x: [N_nodes, F_node]
      - edge_index: [2, N_edges]
      - edge_attr: [N_edges, F_edge]
      - y: [1, T] normalized targets (zeros where missing)
      - y_mask: [1, T] bool mask of present targets
      - fid_idx: [1] long (dataset-local fidelity index)
      - .smiles and .fid_str added for debugging

    Targets are kept in the exact order provided by the user.
    """
    def __init__(
        self,
        rows: pd.DataFrame,
        targets: Sequence[str],
        scaler: Optional[TargetScaler],
        smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    ):
        super().__init__()
        # rows: wide table, one row per (smiles, fid), one column per target.
        self.rows = rows.reset_index(drop=True).copy()
        self.targets = _ensure_targets_order(targets)
        self.scaler = scaler
        # Shared cache of pre-featurized graphs keyed by SMILES (built upstream).
        self.smiles_graph_cache = smiles_graph_cache

        # Build y and mask tensors
        ys, masks = [], []
        for _, r in self.rows.iterrows():
            yv, mv = [], []
            for t in self.targets:
                v = r[t]
                if pd.isna(v):
                    # Missing label: NaN placeholder; the mask marks it absent.
                    yv.append(np.nan)
                    mv.append(False)
                else:
                    yv.append(float(v))
                    mv.append(True)
            ys.append(yv)
            masks.append(mv)

        y = torch.tensor(np.array(ys, dtype=np.float32))  # [N, T]
        mask = torch.tensor(np.array(masks, dtype=np.bool_))

        # Normalize only when a fitted scaler is available; masked-out
        # entries are zeroed so NaNs never reach downstream losses.
        if scaler is not None and scaler.mean is not None:
            y_norm = torch.where(mask, scaler.transform(y), torch.zeros_like(y))
        else:
            y_norm = y

        self.y = y_norm
        self.mask = mask

        # Input dims
        # Read feature widths off the first cached graph (all graphs share them).
        any_smiles = self.rows.iloc[0]["smiles"]
        x0, _, e0 = smiles_graph_cache[any_smiles]
        self.in_dim_node = x0.shape[1]
        self.in_dim_edge = e0.shape[1]

        # Fidelity metadata for reference (local indexing in this dataset).
        # Known fidelities sort by their FID_PRIORITY position; all unknown
        # ones share the same key and therefore fall after the known ones.
        self.fids = sorted(
            self.rows["fid"].str.lower().unique().tolist(),
            key=lambda f: (FID_PRIORITY + [f]).index(f) if f in FID_PRIORITY else len(FID_PRIORITY),
        )
        self.fid2idx = {f: i for i, f in enumerate(self.fids)}
        self.rows["fid_idx_local"] = self.rows["fid"].str.lower().map(self.fid2idx)

    def __len__(self) -> int:
        """Number of (smiles, fidelity) rows."""
        return len(self.rows)

    def __getitem__(self, idx: int) -> Data:
        """Assemble one PyG Data item; tensors are cloned so callers may mutate."""
        idx = int(idx)
        r = self.rows.iloc[idx]
        smi = r["smiles"]

        x, edge_index, edge_attr = self.smiles_graph_cache[smi]
        # Ensure [1, T] so batches become [B, T]
        y_i = self.y[idx].clone().unsqueeze(0)  # [1, T]
        m_i = self.mask[idx].clone().unsqueeze(0)  # [1, T]
        fid_idx = int(r["fid_idx_local"])

        d = Data(
            x=x.clone(),
            edge_index=edge_index.clone(),
            edge_attr=edge_attr.clone(),
            y=y_i,
            y_mask=m_i,
            fid_idx=torch.tensor([fid_idx], dtype=torch.long),
        )
        d.smiles = smi
        d.fid_str = r["fid"]
        return d
604
-
605
-
606
def subsample_train_indices(
    rows: pd.DataFrame,
    train_idx: np.ndarray,
    *,
    target: Optional[str],
    fidelity: Optional[str],
    pct: float = 1.0,
    seed: int = 137,
) -> np.ndarray:
    """
    Return a filtered train_idx that keeps only a 'pct' fraction (0<pct<=1)
    of TRAIN rows for the specified (target, fidelity) block. Selection is
    deterministic by unique SMILES. Rows outside the block are untouched.

    rows: wide table with columns ["smiles","fid","fid_idx", <targets...>]
    """
    # No-op conditions: nothing specified, (nearly) full fraction requested,
    # or an unknown target column.
    if target is None or fidelity is None or pct >= 0.999:
        return train_idx
    if target not in rows.columns:
        return train_idx

    fid_lc = fidelity.strip().lower()

    # The "block" = TRAIN rows with the matching fidelity AND a label for target.
    train_rows = rows.iloc[train_idx]
    block_mask = (train_rows["fid"].str.lower() == fid_lc) & (~train_rows[target].isna())
    if not bool(block_mask.any()):
        return train_idx  # nothing to subsample

    # Sample whole molecules (unique SMILES) so the selection stays grouped.
    smiles_all = pd.Index(train_rows.loc[block_mask, "smiles"].unique())
    n_all = len(smiles_all)
    if n_all == 0:
        return train_idx

    effective_pct = pct if pct > 0.0 else 0.0001
    n_keep = max(1, int(round(effective_pct * n_all)))

    # Deterministic draw: sort the SMILES first, then sample without replacement.
    rng = np.random.RandomState(int(seed))
    smiles_sorted = np.array(sorted(smiles_all.tolist()))
    keep_smiles = set(rng.choice(smiles_sorted, size=n_keep, replace=False).tolist())

    # Keep every non-block row; within the block, keep only selected SMILES.
    keep_mask_local = (~block_mask) | (train_rows["smiles"].isin(keep_smiles))
    return train_rows.index[keep_mask_local].to_numpy()
654
-
655
-
656
- # ---------------------------------------------------------
657
- # High-level builder
658
- # ---------------------------------------------------------
659
-
660
- def build_dataset_from_dir(
661
- root_dir: str | Path,
662
- targets: Sequence[str],
663
- fidelities: Sequence[str] = ("exp", "dft", "md", "gc"),
664
- val_ratio: float = 0.1,
665
- test_ratio: float = 0.1,
666
- seed: int = 42,
667
- save_splits_path: Optional[str | Path] = None,
668
- # Optional subsampling of a (target, fidelity) block in TRAIN
669
- subsample_target: Optional[str] = None,
670
- subsample_fidelity: Optional[str] = None,
671
- subsample_pct: float = 1.0,
672
- subsample_seed: int = 137,
673
- # -------- NEW: auto/explicit log transforms --------
674
- auto_log: bool = True,
675
- log_orders_threshold: float = 2.0,
676
- log_min_pos_frac: float = 0.95,
677
- explicit_log_targets: Optional[Sequence[str]] = None, # e.g. ["permeability"]
678
- ) -> tuple[MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, TargetScaler]:
679
- """
680
- Returns train_ds, val_ds, test_ds, scaler.
681
-
682
- - Discovers CSVs for requested targets and fidelities
683
- - Builds a row-per-(smiles,fid) table with columns for each target
684
- - Splits by unique SMILES to avoid leakage across fidelity or targets
685
- - Fits transform+normalization on the training split only, applies to val/test
686
- - Builds RDKit graphs once per unique SMILES and reuses them
687
-
688
- NEW:
689
- - Auto per-task transform selection ("log10" vs "identity") by criteria
690
- - Optional explicit override via explicit_log_targets
691
- """
692
- root = Path(root_dir)
693
- targets = _ensure_targets_order(targets)
694
- fids_lc = [_norm_fid(f) for f in fidelities]
695
-
696
- # Build long and pivot to rows
697
- long = build_long_table(root, targets, fids_lc)
698
- rows = pivot_to_rows_by_smiles_fid(long, targets)
699
-
700
- # Deterministic grouped split by SMILES
701
- if save_splits_path is not None and Path(save_splits_path).exists():
702
- with open(save_splits_path, "r") as f:
703
- split_obj = json.load(f)
704
- train_smiles = set(split_obj["train_smiles"])
705
- val_smiles = set(split_obj["val_smiles"])
706
- test_smiles = set(split_obj["test_smiles"])
707
- train_idx = rows.index[rows["smiles"].isin(train_smiles)].to_numpy()
708
- val_idx = rows.index[rows["smiles"].isin(val_smiles)].to_numpy()
709
- test_idx = rows.index[rows["smiles"].isin(test_smiles)].to_numpy()
710
- else:
711
- train_idx, val_idx, test_idx = grouped_split_by_smiles(rows, val_ratio=val_ratio, test_ratio=test_ratio, seed=seed)
712
- if save_splits_path is not None:
713
- split_obj = {
714
- "train_smiles": rows.iloc[train_idx]["smiles"].drop_duplicates().tolist(),
715
- "val_smiles": rows.iloc[val_idx]["smiles"].drop_duplicates().tolist(),
716
- "test_smiles": rows.iloc[test_idx]["smiles"].drop_duplicates().tolist(),
717
- "seed": seed,
718
- "val_ratio": val_ratio,
719
- "test_ratio": test_ratio,
720
- }
721
- Path(save_splits_path).parent.mkdir(parents=True, exist_ok=True)
722
- with open(save_splits_path, "w") as f:
723
- json.dump(split_obj, f, indent=2)
724
-
725
- # Build RDKit graphs once per unique SMILES
726
- uniq_smiles = rows["smiles"].drop_duplicates().tolist()
727
- smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
728
- for smi in uniq_smiles:
729
- try:
730
- x, edge_index, edge_attr = featurize_smiles(smi)
731
- smiles_graph_cache[smi] = (x, edge_index, edge_attr)
732
- except Exception as e:
733
- warnings.warn(f"[data_builder] Dropping SMILES due to RDKit parse error: {smi} ({e})")
734
-
735
- # Filter rows to those that featurized successfully
736
- rows = rows[rows["smiles"].isin(smiles_graph_cache.keys())].reset_index(drop=True)
737
-
738
- # Re-map indices after filtering using smiles membership
739
- train_idx = rows.index[rows["smiles"].isin(set(rows.iloc[train_idx]["smiles"]))].to_numpy()
740
- val_idx = rows.index[rows["smiles"].isin(set(rows.iloc[val_idx]["smiles"]))].to_numpy()
741
- test_idx = rows.index[rows["smiles"].isin(set(rows.iloc[test_idx]["smiles"]))].to_numpy()
742
-
743
- # Optional subsampling (train only) for a specific (target, fidelity) block
744
- train_idx = subsample_train_indices(
745
- rows,
746
- train_idx,
747
- target=subsample_target,
748
- fidelity=subsample_fidelity,
749
- pct=float(subsample_pct),
750
- seed=int(subsample_seed),
751
- )
752
-
753
- # Fit scaler on training split only
754
- def build_y_mask(df_slice: pd.DataFrame) -> tuple[torch.Tensor, torch.Tensor]:
755
- ys, ms = [], []
756
- for _, r in df_slice.iterrows():
757
- yv, mv = [], []
758
- for t in targets:
759
- v = r[t]
760
- if pd.isna(v):
761
- yv.append(np.nan)
762
- mv.append(False)
763
- else:
764
- yv.append(float(v))
765
- mv.append(True)
766
- ys.append(yv)
767
- ms.append(mv)
768
- y = torch.tensor(np.array(ys, dtype=np.float32))
769
- mask = torch.tensor(np.array(ms, dtype=np.bool_))
770
- return y, mask
771
-
772
- y_train, mask_train = build_y_mask(rows.iloc[train_idx])
773
-
774
- # Decide transforms per task
775
- if explicit_log_targets:
776
- explicit_set = set(explicit_log_targets)
777
- transforms = [("log10" if t in explicit_set else "identity") for t in targets]
778
- eps_vec = None # will be auto-chosen in scaler.fit if not provided
779
- elif auto_log:
780
- transforms, eps_vec = auto_select_task_transforms(
781
- y_train,
782
- mask_train,
783
- targets,
784
- min_pos_frac=float(log_min_pos_frac),
785
- orders_threshold=float(log_orders_threshold),
786
- )
787
- else:
788
- transforms, eps_vec = (["identity"] * len(targets), None)
789
-
790
- scaler = TargetScaler(transforms=transforms, eps=eps_vec)
791
- scaler.fit(y_train, mask_train)
792
-
793
- # Build datasets
794
- train_rows = rows.iloc[train_idx].reset_index(drop=True)
795
- val_rows = rows.iloc[val_idx].reset_index(drop=True)
796
- test_rows = rows.iloc[test_idx].reset_index(drop=True)
797
-
798
- train_ds = MultiFidelityMoleculeDataset(train_rows, targets, scaler, smiles_graph_cache)
799
- val_ds = MultiFidelityMoleculeDataset(val_rows, targets, scaler, smiles_graph_cache)
800
- test_ds = MultiFidelityMoleculeDataset(test_rows, targets, scaler, smiles_graph_cache)
801
-
802
- return train_ds, val_ds, test_ds, scaler
803
-
804
-
805
- __all__ = [
806
- "build_dataset_from_dir",
807
- "discover_target_fid_csvs",
808
- "read_target_csv",
809
- "build_long_table",
810
- "pivot_to_rows_by_smiles_fid",
811
- "grouped_split_by_smiles",
812
- "TargetScaler",
813
- "MultiFidelityMoleculeDataset",
814
- "atom_features",
815
- "bond_features",
816
- "featurize_smiles",
817
- "auto_select_task_transforms",
818
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/discover_llm.py DELETED
@@ -1,829 +0,0 @@
1
- # src/discovery.py
2
- from __future__ import annotations
3
-
4
- import json
5
- from dataclasses import dataclass
6
- from pathlib import Path
7
- from typing import Callable, Dict, List, Optional, Tuple
8
-
9
- import numpy as np
10
- import pandas as pd
11
- from rdkit import Chem, DataStructs
12
- from rdkit.Chem import AllChem
13
- from . import sascorer
14
-
15
- # Reuse your canonicalizer if you want; otherwise keep local
16
- def canonicalize_smiles(smiles: str) -> Optional[str]:
17
- s = (smiles or "").strip()
18
- if not s:
19
- return None
20
- m = Chem.MolFromSmiles(s)
21
- if m is None:
22
- return None
23
- return Chem.MolToSmiles(m, canonical=True)
24
-
25
-
26
- # -------------------------
27
- # Spec schema (minimal v0)
28
- # -------------------------
29
- @dataclass
30
- class DiscoverySpec:
31
- dataset: List[str] # ["PI1M_PROPERTY.parquet", "POLYINFO_PROPERTY.parquet"]
32
- polyinfo: str # "POLYINFO_PROPERTY.parquet"
33
- polyinfo_csv: str # "POLYINFO.csv"
34
-
35
- hard_constraints: Dict[str, Dict[str, float]] # { "tg": {"min": 400}, "tc": {"max": 0.3} }
36
- objectives: List[Dict[str, str]] # [{"property":"cp","goal":"maximize"}, ...]
37
-
38
- max_pool: int = 200000 # legacy (kept for compatibility; aligned to pareto_max)
39
- pareto_max: int = 50000 # cap points used for Pareto + diversity fingerprinting
40
- max_candidates: int = 30 # final output size
41
- max_pareto_fronts: int = 5 # how many Pareto layers to keep for candidate pool
42
- min_distance: float = 0.30 # diversity threshold in Tanimoto distance
43
- fingerprint: str = "morgan" # morgan only for now
44
- random_seed: int = 7
45
- use_canonical_smiles: bool = True
46
- use_full_data: bool = False
47
- trust_weights: Dict[str, float] | None = None
48
- selection_weights: Dict[str, float] | None = None
49
-
50
-
51
- # -------------------------
52
- # Property metadata (local to discovery_llm)
53
- # -------------------------
54
- PROPERTY_META: Dict[str, Dict[str, str]] = {
55
- # Thermal
56
- "tm": {"name": "Melting temperature", "unit": "K"},
57
- "tg": {"name": "Glass transition temperature", "unit": "K"},
58
- "td": {"name": "Thermal diffusivity", "unit": "m^2/s"},
59
- "tc": {"name": "Thermal conductivity", "unit": "W/m-K"},
60
- "cp": {"name": "Specific heat capacity", "unit": "J/kg-K"},
61
- # Mechanical
62
- "young": {"name": "Young's modulus", "unit": "GPa"},
63
- "shear": {"name": "Shear modulus", "unit": "GPa"},
64
- "bulk": {"name": "Bulk modulus", "unit": "GPa"},
65
- "poisson": {"name": "Poisson ratio", "unit": "-"},
66
- # Transport
67
- "visc": {"name": "Viscosity", "unit": "Pa-s"},
68
- "dif": {"name": "Diffusivity", "unit": "cm^2/s"},
69
- # Gas permeability
70
- "phe": {"name": "He permeability", "unit": "Barrer"},
71
- "ph2": {"name": "H2 permeability", "unit": "Barrer"},
72
- "pco2": {"name": "CO2 permeability", "unit": "Barrer"},
73
- "pn2": {"name": "N2 permeability", "unit": "Barrer"},
74
- "po2": {"name": "O2 permeability", "unit": "Barrer"},
75
- "pch4": {"name": "CH4 permeability", "unit": "Barrer"},
76
- # Electronic / Optical
77
- "alpha": {"name": "Polarizability", "unit": "a.u."},
78
- "homo": {"name": "HOMO energy", "unit": "eV"},
79
- "lumo": {"name": "LUMO energy", "unit": "eV"},
80
- "bandgap": {"name": "Band gap", "unit": "eV"},
81
- "mu": {"name": "Dipole moment", "unit": "Debye"},
82
- "etotal": {"name": "Total electronic energy", "unit": "eV"},
83
- "ri": {"name": "Refractive index", "unit": "-"},
84
- "dc": {"name": "Dielectric constant", "unit": "-"},
85
- "pe": {"name": "Permittivity", "unit": "-"},
86
- # Structural / Physical
87
- "rg": {"name": "Radius of gyration", "unit": "A"},
88
- "rho": {"name": "Density", "unit": "g/cm^3"},
89
- }
90
-
91
-
92
- # -------------------------
93
- # Column mapping
94
- # -------------------------
95
- def mean_col(prop_key: str) -> str:
96
- return f"mean_{prop_key.lower()}"
97
-
98
- def std_col(prop_key: str) -> str:
99
- return f"std_{prop_key.lower()}"
100
-
101
-
102
- def normalize_weights(weights: Dict[str, float], defaults: Dict[str, float]) -> Dict[str, float]:
103
- out: Dict[str, float] = {}
104
- for k, v in defaults.items():
105
- try:
106
- vv = float(weights.get(k, v))
107
- except Exception:
108
- vv = float(v)
109
- out[k] = max(0.0, vv)
110
- s = float(sum(out.values()))
111
- if s <= 0.0:
112
- return defaults.copy()
113
- return {k: float(v / s) for k, v in out.items()}
114
-
115
- def spec_from_dict(obj: dict, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
116
- pareto_max = int(obj.get("pareto_max", 50000))
117
- return DiscoverySpec(
118
- dataset=list(dataset_path),
119
- polyinfo=polyinfo_path,
120
- polyinfo_csv=polyinfo_csv_path,
121
- hard_constraints=obj.get("hard_constraints", {}),
122
- objectives=obj.get("objectives", []),
123
- # Legacy field kept for compatibility; effectively collapsed to pareto_max.
124
- max_pool=pareto_max,
125
- pareto_max=pareto_max,
126
- max_candidates=int(obj.get("max_candidates", 30)),
127
- max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)),
128
- min_distance=float(obj.get("min_distance", 0.30)),
129
- fingerprint=str(obj.get("fingerprint", "morgan")),
130
- random_seed=int(obj.get("random_seed", 7)),
131
- use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)),
132
- use_full_data=bool(obj.get("use_full_data", False)),
133
- trust_weights=obj.get("trust_weights"),
134
- selection_weights=obj.get("selection_weights"),
135
- )
136
-
137
- # -------------------------
138
- # Parquet loading (safe)
139
- # -------------------------
140
- def load_parquet_columns(path: str | List[str], columns: List[str]) -> pd.DataFrame:
141
- """
142
- Load only requested columns from Parquet (critical for 1M rows).
143
- Accepts a single path or a list of paths and concatenates rows.
144
- """
145
- def _load_one(fp: str, req_cols: List[str]) -> pd.DataFrame:
146
- available: list[str]
147
- try:
148
- import pyarrow.parquet as pq
149
-
150
- pf = pq.ParquetFile(fp)
151
- available = [str(c) for c in pf.schema.names]
152
- except Exception:
153
- # If schema probing fails, fall back to direct read with requested columns.
154
- return pd.read_parquet(fp, columns=req_cols)
155
-
156
- available_set = set(available)
157
- lower_to_actual = {c.lower(): c for c in available}
158
-
159
- # Resolve requested names against actual parquet schema.
160
- resolved: dict[str, str] = {}
161
- for req in req_cols:
162
- if req in available_set:
163
- resolved[req] = req
164
- continue
165
- alt = lower_to_actual.get(str(req).lower())
166
- if alt is not None:
167
- resolved[req] = alt
168
-
169
- use_cols = sorted(set(resolved.values()))
170
- if not use_cols:
171
- return pd.DataFrame(columns=req_cols)
172
-
173
- out = pd.read_parquet(fp, columns=use_cols)
174
- for req in req_cols:
175
- src = resolved.get(req)
176
- if src is None:
177
- out[req] = np.nan
178
- elif src != req:
179
- out[req] = out[src]
180
- return out[req_cols]
181
-
182
- if isinstance(path, (list, tuple)):
183
- frames = [_load_one(p, columns) for p in path]
184
- if not frames:
185
- return pd.DataFrame(columns=columns)
186
- return pd.concat(frames, ignore_index=True)
187
- return _load_one(path, columns)
188
-
189
-
190
- def normalize_smiles(smiles: str, use_canonical_smiles: bool) -> Optional[str]:
191
- s = (smiles or "").strip()
192
- if not s:
193
- return None
194
- if not use_canonical_smiles:
195
- # Skip RDKit parsing entirely in fast mode.
196
- return s
197
- m = Chem.MolFromSmiles(s)
198
- if m is None:
199
- return None
200
- if use_canonical_smiles:
201
- return Chem.MolToSmiles(m, canonical=True)
202
- return s
203
-
204
-
205
- def load_polyinfo_index(polyinfo_csv_path: str, use_canonical_smiles: bool = True) -> pd.DataFrame:
206
- """
207
- Expected CSV columns: SMILES, Polymer_Class, polymer_name (or common variants).
208
- Returns dataframe with index on smiles_key and columns polymer_name/polymer_class.
209
- """
210
- df = pd.read_csv(polyinfo_csv_path)
211
-
212
- # normalize column names
213
- cols = {c: c for c in df.columns}
214
- # map typical names
215
- if "SMILES" in cols:
216
- df = df.rename(columns={"SMILES": "smiles"})
217
- elif "smiles" not in df.columns:
218
- raise ValueError(f"{polyinfo_csv_path} missing SMILES/smiles column")
219
-
220
- if "Polymer_Name" in df.columns:
221
- df = df.rename(columns={"Polymer_Name": "polymer_name"})
222
- if "polymer_Name" in df.columns:
223
- df = df.rename(columns={"polymer_Name": "polymer_name"})
224
- if "Polymer_Class" in df.columns:
225
- df = df.rename(columns={"Polymer_Class": "polymer_class"})
226
-
227
- if "polymer_name" not in df.columns:
228
- df["polymer_name"] = pd.NA
229
- if "polymer_class" not in df.columns:
230
- df["polymer_class"] = pd.NA
231
-
232
- df["smiles_key"] = df["smiles"].astype(str).map(lambda s: normalize_smiles(s, use_canonical_smiles))
233
- df = df.dropna(subset=["smiles_key"]).drop_duplicates("smiles_key")
234
- df = df.set_index("smiles_key", drop=True)
235
- return df[["polymer_name", "polymer_class"]]
236
-
237
-
238
- # -------------------------
239
- # Pareto (2–3 objectives)
240
- # -------------------------
241
- def pareto_front_mask(X: np.ndarray) -> np.ndarray:
242
- """
243
- Returns mask for nondominated points.
244
- X: (N, M), all objectives assumed to be minimized.
245
- For maximize objectives, we invert before calling this.
246
- """
247
- N = X.shape[0]
248
- is_efficient = np.ones(N, dtype=bool)
249
- for i in range(N):
250
- if not is_efficient[i]:
251
- continue
252
- # any point that is <= in all dims and < in at least one dominates
253
- dominates = np.all(X <= X[i], axis=1) & np.any(X < X[i], axis=1)
254
- # if a point dominates i, mark i inefficient
255
- if np.any(dominates):
256
- is_efficient[i] = False
257
- continue
258
- # otherwise, i may dominate others
259
- dominated_by_i = np.all(X[i] <= X, axis=1) & np.any(X[i] < X, axis=1)
260
- is_efficient[dominated_by_i] = False
261
- is_efficient[i] = True
262
- return is_efficient
263
-
264
-
265
- def pareto_layers(X: np.ndarray, max_layers: int = 10) -> np.ndarray:
266
- """
267
- Returns layer index per point: 1 = Pareto front, 2 = second layer, ...
268
- Unassigned points beyond max_layers get 0.
269
- """
270
- N = X.shape[0]
271
- layers = np.zeros(N, dtype=int)
272
- remaining = np.arange(N)
273
-
274
- layer = 1
275
- while remaining.size > 0 and layer <= max_layers:
276
- mask = pareto_front_mask(X[remaining])
277
- front_idx = remaining[mask]
278
- layers[front_idx] = layer
279
- remaining = remaining[~mask]
280
- layer += 1
281
- return layers
282
-
283
-
284
- def pareto_front_mask_chunked(
285
- X: np.ndarray,
286
- chunk_size: int = 100000,
287
- progress_callback: Optional[Callable[[int, int], None]] = None,
288
- ) -> np.ndarray:
289
- """
290
- Exact global Pareto front mask via chunk-local front reduction + global reconcile.
291
- This is exact for front-1:
292
- 1) compute exact local front within each chunk
293
- 2) union local fronts
294
- 3) compute exact front on the union
295
- """
296
- N = X.shape[0]
297
- if N <= chunk_size:
298
- if progress_callback is not None:
299
- progress_callback(1, 1)
300
- return pareto_front_mask(X)
301
-
302
- local_front_idx = []
303
- total_chunks = (N + chunk_size - 1) // chunk_size
304
- done_chunks = 0
305
- for start in range(0, N, chunk_size):
306
- end = min(start + chunk_size, N)
307
- idx = np.arange(start, end)
308
- mask_local = pareto_front_mask(X[idx])
309
- local_front_idx.append(idx[mask_local])
310
- done_chunks += 1
311
- if progress_callback is not None:
312
- progress_callback(done_chunks, total_chunks)
313
-
314
- if not local_front_idx:
315
- return np.zeros(N, dtype=bool)
316
-
317
- reduced_idx = np.concatenate(local_front_idx)
318
- reduced_mask = pareto_front_mask(X[reduced_idx])
319
- front_idx = reduced_idx[reduced_mask]
320
-
321
- out = np.zeros(N, dtype=bool)
322
- out[front_idx] = True
323
- return out
324
-
325
-
326
- def pareto_layers_chunked(
327
- X: np.ndarray,
328
- max_layers: int = 10,
329
- chunk_size: int = 100000,
330
- progress_callback: Optional[Callable[[int, int, int], None]] = None,
331
- ) -> np.ndarray:
332
- """
333
- Exact Pareto layers using repeated exact chunked front extraction.
334
- """
335
- N = X.shape[0]
336
- layers = np.zeros(N, dtype=int)
337
- remaining = np.arange(N)
338
- layer = 1
339
-
340
- while remaining.size > 0 and layer <= max_layers:
341
- def on_chunk(done: int, total: int) -> None:
342
- if progress_callback is not None:
343
- progress_callback(layer, done, total)
344
-
345
- mask = pareto_front_mask_chunked(X[remaining], chunk_size=chunk_size, progress_callback=on_chunk)
346
- front_idx = remaining[mask]
347
- layers[front_idx] = layer
348
- remaining = remaining[~mask]
349
- layer += 1
350
-
351
- return layers
352
-
353
-
354
- # -------------------------
355
- # Fingerprints & diversity
356
- # -------------------------
357
- def morgan_fp(smiles: str, radius: int = 2, nbits: int = 2048):
358
- m = Chem.MolFromSmiles(smiles)
359
- if m is None:
360
- return None
361
- return AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=nbits)
362
-
363
- def tanimoto_distance(fp1, fp2) -> float:
364
- return 1.0 - DataStructs.TanimotoSimilarity(fp1, fp2)
365
-
366
- def greedy_diverse_select(
367
- smiles_list: List[str],
368
- scores: np.ndarray,
369
- max_k: int,
370
- min_dist: float,
371
- ) -> List[int]:
372
- """
373
- Greedy selection by descending score, enforcing min Tanimoto distance.
374
- Returns indices into smiles_list.
375
- """
376
- fps = []
377
- valid_idx = []
378
- for i, s in enumerate(smiles_list):
379
- fp = morgan_fp(s)
380
- if fp is not None:
381
- fps.append(fp)
382
- valid_idx.append(i)
383
-
384
- if not valid_idx:
385
- return []
386
-
387
- # rank candidates (higher score first)
388
- order = np.argsort(-scores[valid_idx])
389
- selected_global = []
390
- selected_fps = []
391
-
392
- for oi in order:
393
- i = valid_idx[oi]
394
- fp_i = fps[oi] # aligned with valid_idx
395
- ok = True
396
- for fp_j in selected_fps:
397
- if tanimoto_distance(fp_i, fp_j) < min_dist:
398
- ok = False
399
- break
400
- if ok:
401
- selected_global.append(i)
402
- selected_fps.append(fp_i)
403
- if len(selected_global) >= max_k:
404
- break
405
-
406
- return selected_global
407
-
408
-
409
- # -------------------------
410
- # Trust score (lightweight, robust)
411
- # -------------------------
412
- def internal_consistency_penalty(row: pd.Series) -> float:
413
- """
414
- Very simple physics/validity checks. Penalty in [0,1].
415
- Adjust/add rules later.
416
- """
417
- viol = 0
418
- total = 0
419
-
420
- def chk(cond: bool):
421
- nonlocal viol, total
422
- total += 1
423
- if not cond:
424
- viol += 1
425
-
426
- # positivity checks if present
427
- for p in ["cp", "tc", "rho", "dif", "visc", "tg", "tm", "bandgap"]:
428
- c = mean_col(p)
429
- if c in row.index and pd.notna(row[c]):
430
- if p in ["bandgap", "tg", "tm"]:
431
- chk(float(row[c]) >= 0.0)
432
- else:
433
- chk(float(row[c]) > 0.0)
434
-
435
- # Poisson ratio bounds if present
436
- if mean_col("poisson") in row.index and pd.notna(row[mean_col("poisson")]):
437
- v = float(row[mean_col("poisson")])
438
- chk(0.0 <= v <= 0.5)
439
-
440
- # Tg <= Tm if both present
441
- if mean_col("tg") in row.index and mean_col("tm") in row.index:
442
- if pd.notna(row[mean_col("tg")]) and pd.notna(row[mean_col("tm")]):
443
- chk(float(row[mean_col("tg")]) <= float(row[mean_col("tm")]))
444
-
445
- if total == 0:
446
- return 0.0
447
- return viol / total
448
-
449
-
450
- def synthesizability_score(smiles: str) -> float:
451
- """
452
- RDKit SA-score based synthesizability proxy in [0,1].
453
- SA-score is ~[1 (easy), 10 (hard)].
454
- We map: 1 -> 1.0, 10 -> 0.0
455
- """
456
- m = Chem.MolFromSmiles(smiles)
457
- if m is None:
458
- return 0.0
459
-
460
- # Guard against unexpected scorer failures / None for edge-case molecules.
461
- try:
462
- sa_raw = sascorer.calculateScore(m)
463
- except Exception:
464
- return 0.0
465
- if sa_raw is None:
466
- return 0.0
467
-
468
- sa = float(sa_raw) # ~ 1..10
469
- s_syn = 1.0 - (sa - 1.0) / 9.0 # linear map to [0,1]
470
- return float(np.clip(s_syn, 0.0, 1.0))
471
-
472
-
473
- def compute_trust_scores(
474
- df: pd.DataFrame,
475
- real_fps: List,
476
- real_smiles: List[str],
477
- trust_weights: Dict[str, float] | None = None,
478
- ) -> np.ndarray:
479
- """
480
- Trust score in [0,1] (higher = more trustworthy / lower risk).
481
- Components:
482
- - distance to nearest real polymer (fingerprint distance)
483
- - internal consistency penalty
484
- - uncertainty penalty (if std columns exist)
485
- - synthesizability
486
- """
487
- N = len(df)
488
- trust = np.zeros(N, dtype=float)
489
- tw_defaults = {"real": 0.45, "consistency": 0.25, "uncertainty": 0.10, "synth": 0.20}
490
- tw = normalize_weights(trust_weights or {}, tw_defaults)
491
-
492
- # nearest-real distance (expensive if done naively)
493
- # We do it only for the (small) post-filter set, which is safe.
494
- smiles_col = "smiles_key" if "smiles_key" in df.columns else "smiles_canon"
495
- for i in range(N):
496
- s = df.iloc[i][smiles_col]
497
- fp = morgan_fp(s)
498
- if fp is None or not real_fps:
499
- d_real = 1.0
500
- else:
501
- sims = DataStructs.BulkTanimotoSimilarity(fp, real_fps)
502
- d_real = 1.0 - float(max(sims)) # distance to nearest
503
-
504
- # internal consistency
505
- pen_cons = internal_consistency_penalty(df.iloc[i])
506
-
507
- # uncertainty: average normalized std for any std_* columns present
508
- std_cols = [c for c in df.columns if c.startswith("std_")]
509
- if std_cols:
510
- std_vals = df.iloc[i][std_cols].astype(float)
511
- std_vals = std_vals.replace([np.inf, -np.inf], np.nan).dropna()
512
- pen_unc = float(np.clip(std_vals.mean() / (std_vals.mean() + 1.0), 0.0, 1.0)) if len(std_vals) else 0.0
513
- else:
514
- pen_unc = 0.0
515
-
516
- # synthesizability heuristic
517
- s_syn = synthesizability_score(s)
518
-
519
- # Combine (tunable weights)
520
- # lower distance to real is better -> convert to score
521
- s_real = 1.0 - np.clip(d_real, 0.0, 1.0)
522
-
523
- trust[i] = (
524
- tw["real"] * s_real +
525
- tw["consistency"] * (1.0 - pen_cons) +
526
- tw["uncertainty"] * (1.0 - pen_unc) +
527
- tw["synth"] * s_syn
528
- )
529
-
530
- trust = np.clip(trust, 0.0, 1.0)
531
- return trust
532
-
533
-
534
- # -------------------------
535
- # Main pipeline
536
- # -------------------------
537
- def run_discovery(
538
- spec: DiscoverySpec,
539
- progress_callback: Optional[Callable[[str, float], None]] = None,
540
- ) -> Tuple[pd.DataFrame, Dict[str, float], pd.DataFrame]:
541
- def report(step: str, pct: float) -> None:
542
- if progress_callback is not None:
543
- progress_callback(step, pct)
544
-
545
- rng = np.random.default_rng(spec.random_seed)
546
-
547
- # 1) Determine required columns
548
- report("Preparing columns…", 0.02)
549
- obj_props = [o["property"].lower() for o in spec.objectives]
550
- cons_props = [p.lower() for p in spec.hard_constraints.keys()]
551
-
552
- needed_props = sorted(set(obj_props + cons_props))
553
- cols = ["SMILES"] + [mean_col(p) for p in needed_props]
554
-
555
- # include std columns if available (not required, but used for trust)
556
- std_cols = [std_col(p) for p in needed_props]
557
- cols += std_cols
558
-
559
- # 2) Load only needed columns
560
- report("Loading data from parquet…", 0.05)
561
- df = load_parquet_columns(spec.dataset, columns=[c for c in cols if c != "SMILES"] + ["SMILES"])
562
- # normalize
563
- if "SMILES" not in df.columns and "smiles" in df.columns:
564
- df = df.rename(columns={"smiles": "SMILES"})
565
- normalize_step = "Canonicalizing SMILES…" if spec.use_canonical_smiles else "Skipping SMILES normalization…"
566
- report(normalize_step, 0.10)
567
- df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
568
- df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)
569
-
570
- # 3) Hard constraints
571
- report("Applying constraints…", 0.22)
572
- for p, rule in spec.hard_constraints.items():
573
- p = p.lower()
574
- c = mean_col(p)
575
- if c not in df.columns:
576
- # if missing, nothing can satisfy
577
- df = df.iloc[0:0]
578
- break
579
- if "min" in rule:
580
- df = df[df[c] >= float(rule["min"])]
581
- if "max" in rule:
582
- df = df[df[c] <= float(rule["max"])]
583
-
584
- n_after = len(df)
585
- if n_after == 0:
586
- empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0}
587
- return df, empty_stats, pd.DataFrame()
588
-
589
- n_pool = len(df)
590
-
591
- # 5) Prepare objective matrix for Pareto
592
- report("Building objective matrix…", 0.30)
593
- # convert to minimization: maximize => negate
594
- X = []
595
- resolved_objectives = []
596
- for o in spec.objectives:
597
- prop = o["property"].lower()
598
- goal = o["goal"].lower()
599
- c = mean_col(prop)
600
- if c not in df.columns:
601
- continue
602
- v = df[c].to_numpy(dtype=float)
603
- if goal == "maximize":
604
- v = -v
605
- X.append(v)
606
- resolved_objectives.append({"property": prop, "goal": goal})
607
- if not X:
608
- # Fallback to first available mean_* column to keep pipeline runnable.
609
- fallback_col = next((c for c in df.columns if str(c).startswith("mean_")), None)
610
- if fallback_col is None:
611
- empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0}
612
- return df.iloc[0:0], empty_stats, pd.DataFrame()
613
- X = [df[fallback_col].to_numpy(dtype=float) * -1.0]
614
- resolved_objectives = [{"property": fallback_col.replace("mean_", ""), "goal": "maximize"}]
615
- X = np.stack(X, axis=1) # (N, M)
616
- obj_props = [o["property"] for o in resolved_objectives]
617
-
618
- # Pareto cap before computing layers (optional safety)
619
- if spec.use_full_data:
620
- report("Using full dataset (no Pareto cap)…", 0.35)
621
- elif len(df) > spec.pareto_max:
622
- idx = rng.choice(len(df), size=spec.pareto_max, replace=False)
623
- df = df.iloc[idx].reset_index(drop=True)
624
- X = X[idx]
625
-
626
- # 6) Pareto layers (only 5 layers needed for candidate pool)
627
- report("Computing Pareto layers…", 0.40)
628
- pareto_start = 0.40
629
- pareto_end = 0.54
630
- max_layers_for_pool = max(1, int(spec.max_pareto_fronts))
631
- pareto_chunk_ref = {"chunks_per_layer": None}
632
-
633
- def on_pareto_chunk(layer_i: int, done_chunks: int, total_chunks: int) -> None:
634
- if pareto_chunk_ref["chunks_per_layer"] is None:
635
- pareto_chunk_ref["chunks_per_layer"] = max(1, int(total_chunks))
636
- ref_chunks = pareto_chunk_ref["chunks_per_layer"]
637
- total_units = max_layers_for_pool * ref_chunks
638
- done_units = min(total_units, ((layer_i - 1) * ref_chunks) + done_chunks)
639
- pareto_pct = int(round(100.0 * done_units / max(1, total_units)))
640
-
641
- layer_progress = done_chunks / max(1, total_chunks)
642
- overall = ((layer_i - 1) + layer_progress) / max_layers_for_pool
643
- pct = pareto_start + (pareto_end - pareto_start) * min(1.0, max(0.0, overall))
644
- report(
645
- f"Computing Pareto layers… {pareto_pct}% (Layer {layer_i}/{max_layers_for_pool}, chunk {done_chunks}/{total_chunks})",
646
- pct,
647
- )
648
-
649
- layers = pareto_layers_chunked(
650
- X,
651
- max_layers=max_layers_for_pool,
652
- chunk_size=100000,
653
- progress_callback=on_pareto_chunk,
654
- )
655
- report("Computing Pareto layers…", pareto_end)
656
- df["pareto_layer"] = layers
657
- plot_df = df[["smiles_key"] + [mean_col(p) for p in obj_props] + ["pareto_layer"]].copy()
658
- plot_df = plot_df.rename(columns={"smiles_key": "SMILES"})
659
-
660
- # Keep first few layers as candidate pool (avoid huge set)
661
- cand = df[df["pareto_layer"].between(1, max_layers_for_pool)].copy()
662
- if cand.empty:
663
- cand = df[df["pareto_layer"] == 1].copy()
664
- cand = cand.reset_index(drop=True)
665
- n_pareto = len(cand)
666
-
667
- # 7) Load real polymer metadata and fingerprints (from POLYINFO.csv)
668
- report("Loading POLYINFO index…", 0.55)
669
- polyinfo = load_polyinfo_index(spec.polyinfo_csv, use_canonical_smiles=spec.use_canonical_smiles)
670
- real_smiles = polyinfo.index.to_list()
671
-
672
- report("Building real-polymer fingerprints…", 0.60)
673
- real_fps = []
674
- for s in real_smiles:
675
- fp = morgan_fp(s)
676
- if fp is not None:
677
- real_fps.append(fp)
678
-
679
- # 8) Trust score on candidate pool (safe size)
680
- report("Computing trust scores…", 0.70)
681
- trust = compute_trust_scores(
682
- cand,
683
- real_fps=real_fps,
684
- real_smiles=real_smiles,
685
- trust_weights=spec.trust_weights,
686
- )
687
- cand["trust_score"] = trust
688
-
689
- # 9) Diversity selection on candidate pool
690
- report("Diversity selection…", 0.88)
691
- # score for selection: prioritize Pareto layer 1 then trust
692
- # higher is better
693
- sw_defaults = {"pareto": 0.60, "trust": 0.40}
694
- sw = normalize_weights(spec.selection_weights or {}, sw_defaults)
695
- pareto_bonus = (
696
- (max_layers_for_pool + 1) - np.clip(cand["pareto_layer"].to_numpy(dtype=int), 1, max_layers_for_pool)
697
- ) / float(max_layers_for_pool)
698
- sel_score = sw["pareto"] * pareto_bonus + sw["trust"] * cand["trust_score"].to_numpy(dtype=float)
699
-
700
- chosen_idx = greedy_diverse_select(
701
- smiles_list=cand["smiles_key"].tolist(),
702
- scores=sel_score,
703
- max_k=spec.max_candidates,
704
- min_dist=spec.min_distance,
705
- )
706
- out = cand.iloc[chosen_idx].copy().reset_index(drop=True)
707
-
708
- # 10) Attach Polymer_Name/Class if available (only for matches)
709
- report("Finalizing results…", 0.96)
710
- out = out.set_index("smiles_key", drop=False)
711
- out = out.join(polyinfo, how="left")
712
- out = out.reset_index(drop=True)
713
-
714
- # 11) Make a clean output bundle with requested columns
715
- # Keep SMILES (canonical), name/class, pareto layer, trust score, properties used
716
- keep = ["smiles_key", "polymer_name", "polymer_class", "pareto_layer", "trust_score"]
717
- for p in needed_props:
718
- mc = mean_col(p)
719
- sc = std_col(p)
720
- if mc in out.columns:
721
- keep.append(mc)
722
- if sc in out.columns:
723
- keep.append(sc)
724
-
725
- out = out[keep].rename(columns={"smiles_key": "SMILES"})
726
-
727
- stats = {
728
- "n_total": float(len(df)),
729
- "n_after_constraints": float(n_after),
730
- "n_pool": float(n_pool),
731
- "n_pareto_pool": float(n_pareto),
732
- "n_selected": float(len(out)),
733
- }
734
- report("Done.", 1.0)
735
- return out, stats, plot_df
736
-
737
-
738
def build_pareto_plot_df(spec: DiscoverySpec, max_plot_points: int = 30000) -> pd.DataFrame:
    """
    Returns a small dataframe for plotting (sampled), with objective columns and pareto_layer.

    Does NOT compute trust/diversity. Safe for live plotting.

    Mirrors the filtering portion of run_discovery: load only needed columns,
    normalize SMILES, apply hard constraints, cap the row count, then label
    up to 5 Pareto layers. Returns an empty frame when a constraint column is
    missing or everything is filtered out.
    """
    rng = np.random.default_rng(spec.random_seed)

    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]
    needed_props = sorted(set(obj_props + cons_props))

    cols = ["SMILES"] + [mean_col(p) for p in needed_props]
    df = load_parquet_columns(spec.dataset, columns=cols)

    # Some datasets use lowercase "smiles"; normalize to "SMILES".
    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})

    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # Hard constraints: a missing column means nothing can satisfy the rule,
    # so an empty (but correctly-typed) frame is returned.
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            return df.iloc[0:0]
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    if len(df) == 0:
        return df

    # Pareto cap for plotting: never plot more than max_plot_points, and
    # never more than the run's own pareto_max.
    plot_cap = min(int(max_plot_points), int(spec.pareto_max))
    if len(df) > plot_cap:
        idx = rng.choice(len(df), size=plot_cap, replace=False)
        df = df.iloc[idx].reset_index(drop=True)

    # Build objective matrix (minimization; "maximize" goals are negated).
    # Objectives whose column is absent are silently skipped.
    X = []
    resolved_obj_props = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        if c not in df.columns:
            continue
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
        resolved_obj_props.append(prop)
    if not X:
        # No requested objective resolved: fall back to the first available
        # mean_* column, treated as a maximize objective (hence * -1.0).
        fallback_col = next((c for c in df.columns if str(c).startswith("mean_")), None)
        if fallback_col is None:
            return df.iloc[0:0]
        X = [df[fallback_col].to_numpy(dtype=float) * -1.0]
        resolved_obj_props = [fallback_col.replace("mean_", "")]
    X = np.stack(X, axis=1)

    df["pareto_layer"] = pareto_layers(X, max_layers=5)

    # Return only what plotting needs
    keep = ["smiles_key", "pareto_layer"] + [mean_col(p) for p in resolved_obj_props]
    out = df[keep].rename(columns={"smiles_key": "SMILES"})
    return out
806
-
807
-
808
def parse_spec(text: str, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
    """
    Parse a JSON spec string into a DiscoverySpec.

    Thin wrapper around spec_from_dict so the string and dict entry points
    cannot drift apart (the previous body duplicated every field mapping of
    spec_from_dict verbatim).

    Raises:
        json.JSONDecodeError: if *text* is not valid JSON.
    """
    return spec_from_dict(json.loads(text), dataset_path, polyinfo_path, polyinfo_csv_path)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/discovery.py DELETED
@@ -1,767 +0,0 @@
1
- # src/discovery.py
2
- from __future__ import annotations
3
-
4
- import json
5
- from dataclasses import dataclass
6
- from pathlib import Path
7
- from typing import Callable, Dict, List, Optional, Tuple
8
-
9
- import numpy as np
10
- import pandas as pd
11
- from rdkit import Chem, DataStructs
12
- from rdkit.Chem import AllChem
13
- from . import sascorer
14
-
15
# Reuse your canonicalizer if you want; otherwise keep local
def canonicalize_smiles(smiles: str) -> Optional[str]:
    """Return the RDKit-canonical form of *smiles*, or None for empty/unparseable input."""
    text = (smiles or "").strip()
    if not text:
        return None
    mol = Chem.MolFromSmiles(text)
    return None if mol is None else Chem.MolToSmiles(mol, canonical=True)
24
-
25
-
26
- # -------------------------
27
- # Spec schema (minimal v0)
28
- # -------------------------
29
@dataclass
class DiscoverySpec:
    """Configuration for one discovery run (typically parsed from a JSON spec)."""

    dataset: List[str]  # parquet paths, e.g. ["PI1M_PROPERTY.parquet", "POLYINFO_PROPERTY.parquet"]
    polyinfo: str  # "POLYINFO_PROPERTY.parquet"
    polyinfo_csv: str  # "POLYINFO.csv" (name/class metadata keyed by SMILES)

    hard_constraints: Dict[str, Dict[str, float]]  # { "tg": {"min": 400}, "tc": {"max": 0.3} }
    objectives: List[Dict[str, str]]  # [{"property":"cp","goal":"maximize"}, ...]

    max_pool: int = 200000  # legacy (kept for compatibility; aligned to pareto_max)
    pareto_max: int = 50000  # cap points used for Pareto + diversity fingerprinting
    max_candidates: int = 30  # final output size
    max_pareto_fronts: int = 5  # how many Pareto layers to keep for candidate pool
    min_distance: float = 0.30  # diversity threshold in Tanimoto distance
    fingerprint: str = "morgan"  # morgan only for now
    random_seed: int = 7
    use_canonical_smiles: bool = True  # False skips RDKit parsing (fast mode)
    use_full_data: bool = False  # True disables the pareto_max sampling cap
    trust_weights: Dict[str, float] | None = None  # overrides for trust-score component weights
    selection_weights: Dict[str, float] | None = None  # overrides for pareto/trust selection mix
49
-
50
-
51
- # -------------------------
52
- # Column mapping
53
- # -------------------------
54
def mean_col(prop_key: str) -> str:
    """Column name holding the mean value of a property (e.g. "Tg" -> "mean_tg")."""
    return "mean_" + prop_key.lower()
56
-
57
def std_col(prop_key: str) -> str:
    """Column name holding the std-dev of a property (e.g. "Tg" -> "std_tg")."""
    return "std_" + prop_key.lower()
59
-
60
-
61
def normalize_weights(weights: Dict[str, float], defaults: Dict[str, float]) -> Dict[str, float]:
    """
    Merge *weights* over *defaults*, clamp negatives to zero, and rescale to sum 1.

    Only keys present in *defaults* are considered; values that cannot be
    coerced to float fall back to the default. If everything clamps to zero,
    a copy of *defaults* is returned unchanged.
    """
    clamped: Dict[str, float] = {}
    for key, default_val in defaults.items():
        raw = weights.get(key, default_val)
        try:
            value = float(raw)
        except Exception:
            value = float(default_val)
        clamped[key] = value if value > 0.0 else 0.0
    total = float(sum(clamped.values()))
    if total <= 0.0:
        return defaults.copy()
    return {key: float(val / total) for key, val in clamped.items()}
73
-
74
def spec_from_dict(obj: dict, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
    """
    Build a DiscoverySpec from an already-parsed JSON dict.

    Missing keys fall back to the defaults shown below; pareto_max is also
    mirrored into the legacy max_pool field.
    """
    pareto_max = int(obj.get("pareto_max", 50000))
    return DiscoverySpec(
        dataset=list(dataset_path),
        polyinfo=polyinfo_path,
        polyinfo_csv=polyinfo_csv_path,
        hard_constraints=obj.get("hard_constraints", {}),
        objectives=obj.get("objectives", []),
        # Legacy field kept for compatibility; effectively collapsed to pareto_max.
        max_pool=pareto_max,
        pareto_max=pareto_max,
        max_candidates=int(obj.get("max_candidates", 30)),
        max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)),
        min_distance=float(obj.get("min_distance", 0.30)),
        fingerprint=str(obj.get("fingerprint", "morgan")),
        random_seed=int(obj.get("random_seed", 7)),
        # Canonicalization is opt-in: the JSON flag defaults to skipping it.
        use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)),
        use_full_data=bool(obj.get("use_full_data", False)),
        trust_weights=obj.get("trust_weights"),
        selection_weights=obj.get("selection_weights"),
    )
95
-
96
- # -------------------------
97
- # Parquet loading (safe)
98
- # -------------------------
99
def load_parquet_columns(path: str | List[str], columns: List[str]) -> pd.DataFrame:
    """
    Load only requested columns from Parquet (critical for 1M rows).

    Accepts a single path or a list of paths and concatenates rows.
    Column names are resolved case-insensitively against each file's schema;
    requested columns absent from a file come back filled with NaN so every
    returned frame has exactly the requested columns, in order.
    """
    def _load_one(fp: str, req_cols: List[str]) -> pd.DataFrame:
        # Probe the schema first so we never ask pandas for a missing column.
        available: list[str]
        try:
            import pyarrow.parquet as pq

            pf = pq.ParquetFile(fp)
            available = [str(c) for c in pf.schema.names]
        except Exception:
            # If schema probing fails, fall back to direct read with requested columns.
            return pd.read_parquet(fp, columns=req_cols)

        available_set = set(available)
        lower_to_actual = {c.lower(): c for c in available}

        # Resolve requested names against actual parquet schema
        # (exact match first, then case-insensitive).
        resolved: dict[str, str] = {}
        for req in req_cols:
            if req in available_set:
                resolved[req] = req
                continue
            alt = lower_to_actual.get(str(req).lower())
            if alt is not None:
                resolved[req] = alt

        use_cols = sorted(set(resolved.values()))
        if not use_cols:
            # Nothing resolved: return an empty frame with the requested schema.
            return pd.DataFrame(columns=req_cols)

        out = pd.read_parquet(fp, columns=use_cols)
        # Backfill unresolved columns with NaN and alias case-mismatched ones.
        for req in req_cols:
            src = resolved.get(req)
            if src is None:
                out[req] = np.nan
            elif src != req:
                out[req] = out[src]
        return out[req_cols]

    if isinstance(path, (list, tuple)):
        frames = [_load_one(p, columns) for p in path]
        if not frames:
            return pd.DataFrame(columns=columns)
        return pd.concat(frames, ignore_index=True)
    return _load_one(path, columns)
147
-
148
-
149
def normalize_smiles(smiles: str, use_canonical_smiles: bool) -> Optional[str]:
    """
    Normalize a SMILES string for use as a join key.

    Returns the stripped input when use_canonical_smiles is False (fast mode,
    no RDKit parsing at all), the RDKit-canonical SMILES when True, or None
    for empty/unparseable input.
    """
    s = (smiles or "").strip()
    if not s:
        return None
    if not use_canonical_smiles:
        # Skip RDKit parsing entirely in fast mode.
        return s
    m = Chem.MolFromSmiles(s)
    if m is None:
        return None
    # use_canonical_smiles is necessarily True here; the previous re-check
    # and trailing `return s` were unreachable dead code.
    return Chem.MolToSmiles(m, canonical=True)
162
-
163
-
164
def load_polyinfo_index(polyinfo_csv_path: str, use_canonical_smiles: bool = True) -> pd.DataFrame:
    """
    Load the POLYINFO metadata CSV and index it by normalized SMILES.

    Expected CSV columns: SMILES, Polymer_Class, polymer_name (or common variants).
    Returns dataframe with index on smiles_key and columns polymer_name/polymer_class.
    Rows whose SMILES fail normalization are dropped; duplicate keys keep the
    first occurrence.

    Raises:
        ValueError: if the CSV has neither a SMILES nor a smiles column.
    """
    df = pd.read_csv(polyinfo_csv_path)

    # Normalize column names directly against df.columns.
    if "SMILES" in df.columns:
        df = df.rename(columns={"SMILES": "smiles"})
    elif "smiles" not in df.columns:
        raise ValueError(f"{polyinfo_csv_path} missing SMILES/smiles column")

    # Map the known header-case variants onto canonical lowercase names.
    if "Polymer_Name" in df.columns:
        df = df.rename(columns={"Polymer_Name": "polymer_name"})
    if "polymer_Name" in df.columns:
        df = df.rename(columns={"polymer_Name": "polymer_name"})
    if "Polymer_Class" in df.columns:
        df = df.rename(columns={"Polymer_Class": "polymer_class"})

    # Guarantee both metadata columns exist even if the CSV lacks them.
    if "polymer_name" not in df.columns:
        df["polymer_name"] = pd.NA
    if "polymer_class" not in df.columns:
        df["polymer_class"] = pd.NA

    df["smiles_key"] = df["smiles"].astype(str).map(lambda s: normalize_smiles(s, use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).drop_duplicates("smiles_key")
    df = df.set_index("smiles_key", drop=True)
    return df[["polymer_name", "polymer_class"]]
195
-
196
-
197
- # -------------------------
198
- # Pareto (2–3 objectives)
199
- # -------------------------
200
def pareto_front_mask(X: np.ndarray) -> np.ndarray:
    """
    Boolean mask of the nondominated rows of X.

    X: (N, M) with every objective to be MINIMIZED (callers negate
    maximize-goals before calling). A row is dropped when some other row is
    <= in every column and strictly < in at least one. Exact duplicates all
    stay on the front.
    """
    n = X.shape[0]
    keep = np.ones(n, dtype=bool)
    for i in range(n):
        if not keep[i]:
            continue
        row = X[i]
        # Anyone (even an already-dropped point) weakly better everywhere
        # and strictly better somewhere dominates row i.
        weakly_le = np.all(X <= row, axis=1)
        strictly_lt = np.any(X < row, axis=1)
        if bool(np.any(weakly_le & strictly_lt)):
            keep[i] = False
            continue
        # Row i survives; drop everything it dominates.
        weakly_ge = np.all(row <= X, axis=1)
        strictly_gt = np.any(row < X, axis=1)
        keep[weakly_ge & strictly_gt] = False
        keep[i] = True
    return keep
222
-
223
-
224
def pareto_layers(X: np.ndarray, max_layers: int = 10) -> np.ndarray:
    """
    Assign a Pareto layer to every row of X (objectives minimized).

    Layer 1 is the nondominated front; peeling it away exposes layer 2, and
    so on. Points not reached within max_layers keep layer 0.
    """
    out = np.zeros(X.shape[0], dtype=int)
    pending = np.arange(X.shape[0])
    for depth in range(1, max_layers + 1):
        if pending.size == 0:
            break
        on_front = pareto_front_mask(X[pending])
        out[pending[on_front]] = depth
        pending = pending[~on_front]
    return out
241
-
242
-
243
def pareto_front_mask_chunked(
    X: np.ndarray,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> np.ndarray:
    """
    Exact global Pareto front mask via chunk-local front reduction + global reconcile.

    This is exact for front-1:
        1) compute exact local front within each chunk
        2) union local fronts
        3) compute exact front on the union
    (A global front point is nondominated within its own chunk, so it always
    survives step 1 and is found in step 3.)

    progress_callback, if given, is called as (done_chunks, total_chunks)
    after each chunk of step 1.
    """
    N = X.shape[0]
    if N <= chunk_size:
        # Small enough: single exact pass, reported as one chunk.
        if progress_callback is not None:
            progress_callback(1, 1)
        return pareto_front_mask(X)

    local_front_idx = []
    total_chunks = (N + chunk_size - 1) // chunk_size
    done_chunks = 0
    for start in range(0, N, chunk_size):
        end = min(start + chunk_size, N)
        idx = np.arange(start, end)
        mask_local = pareto_front_mask(X[idx])
        local_front_idx.append(idx[mask_local])
        done_chunks += 1
        if progress_callback is not None:
            progress_callback(done_chunks, total_chunks)

    if not local_front_idx:
        return np.zeros(N, dtype=bool)

    # Reconcile: exact front over the union of the per-chunk survivors.
    reduced_idx = np.concatenate(local_front_idx)
    reduced_mask = pareto_front_mask(X[reduced_idx])
    front_idx = reduced_idx[reduced_mask]

    out = np.zeros(N, dtype=bool)
    out[front_idx] = True
    return out
283
-
284
-
285
def pareto_layers_chunked(
    X: np.ndarray,
    max_layers: int = 10,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int, int], None]] = None,
) -> np.ndarray:
    """
    Exact Pareto layers using repeated exact chunked front extraction.

    Same layer semantics as pareto_layers (1 = front, 0 = beyond max_layers),
    but each front is extracted with pareto_front_mask_chunked so very large
    inputs stay memory-bounded. progress_callback, if given, is called as
    (layer, done_chunks, total_chunks).
    """
    N = X.shape[0]
    layers = np.zeros(N, dtype=int)
    remaining = np.arange(N)
    layer = 1

    while remaining.size > 0 and layer <= max_layers:
        # NOTE: on_chunk closes over the loop variable `layer`; this is safe
        # because the callback only fires synchronously inside this iteration.
        def on_chunk(done: int, total: int) -> None:
            if progress_callback is not None:
                progress_callback(layer, done, total)

        mask = pareto_front_mask_chunked(X[remaining], chunk_size=chunk_size, progress_callback=on_chunk)
        front_idx = remaining[mask]
        layers[front_idx] = layer
        remaining = remaining[~mask]
        layer += 1

    return layers
311
-
312
-
313
- # -------------------------
314
- # Fingerprints & diversity
315
- # -------------------------
316
def morgan_fp(smiles: str, radius: int = 2, nbits: int = 2048):
    """Morgan bit-vector fingerprint for a SMILES, or None if it fails to parse."""
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=nbits)
321
-
322
def tanimoto_distance(fp1, fp2) -> float:
    """Tanimoto distance (1 - similarity) between two RDKit bit vectors; 0 = identical."""
    return 1.0 - DataStructs.TanimotoSimilarity(fp1, fp2)
324
-
325
def greedy_diverse_select(
    smiles_list: List[str],
    scores: np.ndarray,
    max_k: int,
    min_dist: float,
) -> List[int]:
    """
    Pick up to max_k indices of smiles_list, greedily by descending score,
    rejecting any candidate closer than min_dist (Tanimoto distance) to an
    already-picked one. SMILES that fail fingerprinting are skipped entirely.
    Returns indices into smiles_list.
    """
    fingerprints = []
    keepable: List[int] = []
    for pos, smi in enumerate(smiles_list):
        bv = morgan_fp(smi)
        if bv is None:
            continue
        fingerprints.append(bv)
        keepable.append(pos)

    if not keepable:
        return []

    # Highest score first; fingerprints stays index-aligned with keepable.
    ranking = np.argsort(-scores[keepable])
    picked: List[int] = []
    picked_fps = []

    for rank_pos in ranking:
        cand_fp = fingerprints[rank_pos]
        if any(tanimoto_distance(cand_fp, prev) < min_dist for prev in picked_fps):
            continue
        picked.append(keepable[rank_pos])
        picked_fps.append(cand_fp)
        if len(picked) >= max_k:
            break

    return picked
366
-
367
-
368
- # -------------------------
369
- # Trust score (lightweight, robust)
370
- # -------------------------
371
def internal_consistency_penalty(row: pd.Series) -> float:
    """
    Very simple physics/validity checks. Penalty in [0,1].

    Each applicable rule contributes one check; the penalty is the fraction
    of checks violated (0.0 when no rule applies). Adjust/add rules later.
    """
    viol = 0
    total = 0

    def chk(cond: bool):
        # Tally one check; count it as a violation when cond is False.
        nonlocal viol, total
        total += 1
        if not cond:
            viol += 1

    # Positivity checks if present: temperatures/bandgap may be zero,
    # the other transport/thermo properties must be strictly positive.
    for p in ["cp", "tc", "rho", "dif", "visc", "tg", "tm", "bandgap"]:
        c = mean_col(p)
        if c in row.index and pd.notna(row[c]):
            if p in ["bandgap", "tg", "tm"]:
                chk(float(row[c]) >= 0.0)
            else:
                chk(float(row[c]) > 0.0)

    # Poisson ratio bounds if present (isotropic-material range).
    if mean_col("poisson") in row.index and pd.notna(row[mean_col("poisson")]):
        v = float(row[mean_col("poisson")])
        chk(0.0 <= v <= 0.5)

    # Tg <= Tm if both present
    if mean_col("tg") in row.index and mean_col("tm") in row.index:
        if pd.notna(row[mean_col("tg")]) and pd.notna(row[mean_col("tm")]):
            chk(float(row[mean_col("tg")]) <= float(row[mean_col("tm")]))

    if total == 0:
        return 0.0
    return viol / total
407
-
408
-
409
def synthesizability_score(smiles: str) -> float:
    """
    RDKit SA-score based synthesizability proxy in [0,1] (higher = easier).

    SA-score is ~[1 (easy), 10 (hard)]; we map 1 -> 1.0 and 10 -> 0.0 linearly.
    Unparseable SMILES and scorer failures both score 0.0.
    """
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        return 0.0

    # Guard against unexpected scorer failures / None for edge-case molecules.
    try:
        sa_raw = sascorer.calculateScore(m)
    except Exception:
        return 0.0
    if sa_raw is None:
        return 0.0

    sa = float(sa_raw)  # ~ 1..10
    s_syn = 1.0 - (sa - 1.0) / 9.0  # linear map to [0,1]
    return float(np.clip(s_syn, 0.0, 1.0))
430
-
431
-
432
def compute_trust_scores(
    df: pd.DataFrame,
    real_fps: List,
    real_smiles: List[str],
    trust_weights: Dict[str, float] | None = None,
) -> np.ndarray:
    """
    Trust score in [0,1] per row (higher = more trustworthy / lower risk).

    Components (weights normalized via normalize_weights; defaults below):
      - real:        fingerprint similarity to the nearest real polymer
      - consistency: 1 - internal_consistency_penalty
      - uncertainty: 1 - penalty derived from the mean of any std_* columns
      - synth:       SA-score based synthesizability proxy

    real_smiles is currently unused but kept for interface compatibility.
    O(N * len(real_fps)); intended for the small post-filter candidate set.
    """
    N = len(df)
    trust = np.zeros(N, dtype=float)
    tw_defaults = {"real": 0.45, "consistency": 0.25, "uncertainty": 0.10, "synth": 0.20}
    tw = normalize_weights(trust_weights or {}, tw_defaults)

    smiles_col = "smiles_key" if "smiles_key" in df.columns else "smiles_canon"
    # Hoisted out of the row loop: the std_* column set is loop-invariant
    # and was previously recomputed for every row.
    std_cols = [c for c in df.columns if c.startswith("std_")]

    for i in range(N):
        row = df.iloc[i]
        s = row[smiles_col]

        # Distance to the nearest real polymer; 1.0 when there is no
        # fingerprint or no reference set.
        fp = morgan_fp(s)
        if fp is None or not real_fps:
            d_real = 1.0
        else:
            sims = DataStructs.BulkTanimotoSimilarity(fp, real_fps)
            d_real = 1.0 - float(max(sims))

        # internal consistency
        pen_cons = internal_consistency_penalty(row)

        # Uncertainty: mean of available std_* values squashed into [0,1).
        if std_cols:
            std_vals = row[std_cols].astype(float)
            std_vals = std_vals.replace([np.inf, -np.inf], np.nan).dropna()
            pen_unc = float(np.clip(std_vals.mean() / (std_vals.mean() + 1.0), 0.0, 1.0)) if len(std_vals) else 0.0
        else:
            pen_unc = 0.0

        # synthesizability heuristic
        s_syn = synthesizability_score(s)

        # Combine (tunable weights).
        # Lower distance to real is better -> convert to score.
        s_real = 1.0 - np.clip(d_real, 0.0, 1.0)

        trust[i] = (
            tw["real"] * s_real +
            tw["consistency"] * (1.0 - pen_cons) +
            tw["uncertainty"] * (1.0 - pen_unc) +
            tw["synth"] * s_syn
        )

    return np.clip(trust, 0.0, 1.0)
491
-
492
-
493
- # -------------------------
494
- # Main pipeline
495
- # -------------------------
496
def run_discovery(
    spec: DiscoverySpec,
    progress_callback: Optional[Callable[[str, float], None]] = None,
) -> Tuple[pd.DataFrame, Dict[str, float], pd.DataFrame]:
    """
    Full discovery pipeline: load -> constrain -> Pareto -> trust -> diversity.

    Returns (selected, stats, plot_df):
      selected - up to spec.max_candidates rows with SMILES, polymer
                 name/class, pareto_layer, trust_score and property columns
      stats    - float-valued counters for each pipeline stage
      plot_df  - all Pareto-labelled points for plotting
    progress_callback, if given, receives (step_label, fraction_done).
    """
    def report(step: str, pct: float) -> None:
        if progress_callback is not None:
            progress_callback(step, pct)

    rng = np.random.default_rng(spec.random_seed)

    # 1) Determine required columns
    report("Preparing columns…", 0.02)
    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]

    needed_props = sorted(set(obj_props + cons_props))
    cols = ["SMILES"] + [mean_col(p) for p in needed_props]

    # include std columns if available (not required, but used for trust)
    std_cols = [std_col(p) for p in needed_props]
    cols += std_cols

    # 2) Load only needed columns
    report("Loading data from parquet…", 0.05)
    df = load_parquet_columns(spec.dataset, columns=[c for c in cols if c != "SMILES"] + ["SMILES"])
    # normalize the SMILES column name, then build the join key
    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})
    normalize_step = "Canonicalizing SMILES…" if spec.use_canonical_smiles else "Skipping SMILES normalization…"
    report(normalize_step, 0.10)
    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # 3) Hard constraints
    report("Applying constraints…", 0.22)
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            # if missing, nothing can satisfy
            df = df.iloc[0:0]
            break
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    n_after = len(df)
    if n_after == 0:
        empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0}
        return df, empty_stats, pd.DataFrame()

    n_pool = len(df)

    # 5) Prepare objective matrix for Pareto
    report("Building objective matrix…", 0.30)
    # convert to minimization: maximize => negate
    X = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        if c not in df.columns:
            raise ValueError(f"Objective column missing: {c}")
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
    X = np.stack(X, axis=1)  # (N, M)

    # Pareto cap before computing layers (optional safety);
    # df and X are subsampled together so rows stay aligned.
    if spec.use_full_data:
        report("Using full dataset (no Pareto cap)…", 0.35)
    elif len(df) > spec.pareto_max:
        idx = rng.choice(len(df), size=spec.pareto_max, replace=False)
        df = df.iloc[idx].reset_index(drop=True)
        X = X[idx]

    # 6) Pareto layers (only 5 layers needed for candidate pool)
    report("Computing Pareto layers…", 0.40)
    pareto_start = 0.40
    pareto_end = 0.54
    max_layers_for_pool = max(1, int(spec.max_pareto_fronts))
    pareto_chunk_ref = {"chunks_per_layer": None}

    def on_pareto_chunk(layer_i: int, done_chunks: int, total_chunks: int) -> None:
        # Translate (layer, chunk) progress into one overall fraction within
        # [pareto_start, pareto_end]. The first layer's chunk count is used
        # as the per-layer reference for the percentage text.
        if pareto_chunk_ref["chunks_per_layer"] is None:
            pareto_chunk_ref["chunks_per_layer"] = max(1, int(total_chunks))
        ref_chunks = pareto_chunk_ref["chunks_per_layer"]
        total_units = max_layers_for_pool * ref_chunks
        done_units = min(total_units, ((layer_i - 1) * ref_chunks) + done_chunks)
        pareto_pct = int(round(100.0 * done_units / max(1, total_units)))

        layer_progress = done_chunks / max(1, total_chunks)
        overall = ((layer_i - 1) + layer_progress) / max_layers_for_pool
        pct = pareto_start + (pareto_end - pareto_start) * min(1.0, max(0.0, overall))
        report(
            f"Computing Pareto layers… {pareto_pct}% (Layer {layer_i}/{max_layers_for_pool}, chunk {done_chunks}/{total_chunks})",
            pct,
        )

    layers = pareto_layers_chunked(
        X,
        max_layers=max_layers_for_pool,
        chunk_size=100000,
        progress_callback=on_pareto_chunk,
    )
    report("Computing Pareto layers…", pareto_end)
    df["pareto_layer"] = layers
    plot_df = df[["smiles_key"] + [mean_col(p) for p in obj_props] + ["pareto_layer"]].copy()
    plot_df = plot_df.rename(columns={"smiles_key": "SMILES"})

    # Keep first few layers as candidate pool (avoid huge set)
    cand = df[df["pareto_layer"].between(1, max_layers_for_pool)].copy()
    if cand.empty:
        cand = df[df["pareto_layer"] == 1].copy()
    cand = cand.reset_index(drop=True)
    n_pareto = len(cand)

    # 7) Load real polymer metadata and fingerprints (from POLYINFO.csv)
    report("Loading POLYINFO index…", 0.55)
    polyinfo = load_polyinfo_index(spec.polyinfo_csv, use_canonical_smiles=spec.use_canonical_smiles)
    real_smiles = polyinfo.index.to_list()

    report("Building real-polymer fingerprints…", 0.60)
    real_fps = []
    for s in real_smiles:
        fp = morgan_fp(s)
        if fp is not None:
            real_fps.append(fp)

    # 8) Trust score on candidate pool (safe size)
    report("Computing trust scores…", 0.70)
    trust = compute_trust_scores(
        cand,
        real_fps=real_fps,
        real_smiles=real_smiles,
        trust_weights=spec.trust_weights,
    )
    cand["trust_score"] = trust

    # 9) Diversity selection on candidate pool
    report("Diversity selection…", 0.88)
    # score for selection: prioritize Pareto layer 1 then trust
    # higher is better; layer 1 -> bonus 1.0, deepest kept layer -> 1/max_layers
    sw_defaults = {"pareto": 0.60, "trust": 0.40}
    sw = normalize_weights(spec.selection_weights or {}, sw_defaults)
    pareto_bonus = (
        (max_layers_for_pool + 1) - np.clip(cand["pareto_layer"].to_numpy(dtype=int), 1, max_layers_for_pool)
    ) / float(max_layers_for_pool)
    sel_score = sw["pareto"] * pareto_bonus + sw["trust"] * cand["trust_score"].to_numpy(dtype=float)

    chosen_idx = greedy_diverse_select(
        smiles_list=cand["smiles_key"].tolist(),
        scores=sel_score,
        max_k=spec.max_candidates,
        min_dist=spec.min_distance,
    )
    out = cand.iloc[chosen_idx].copy().reset_index(drop=True)

    # 10) Attach Polymer_Name/Class if available (only for matches)
    report("Finalizing results…", 0.96)
    out = out.set_index("smiles_key", drop=False)
    out = out.join(polyinfo, how="left")
    out = out.reset_index(drop=True)

    # 11) Make a clean output bundle with requested columns
    # Keep SMILES (canonical), name/class, pareto layer, trust score, properties used
    keep = ["smiles_key", "polymer_name", "polymer_class", "pareto_layer", "trust_score"]
    for p in needed_props:
        mc = mean_col(p)
        sc = std_col(p)
        if mc in out.columns:
            keep.append(mc)
        if sc in out.columns:
            keep.append(sc)

    out = out[keep].rename(columns={"smiles_key": "SMILES"})

    # NOTE(review): "n_total" reflects df AFTER constraint filtering and the
    # optional pareto_max subsample, not the row count originally loaded —
    # confirm this is the intended meaning before relying on it.
    stats = {
        "n_total": float(len(df)),
        "n_after_constraints": float(n_after),
        "n_pool": float(n_pool),
        "n_pareto_pool": float(n_pareto),
        "n_selected": float(len(out)),
    }
    report("Done.", 1.0)
    return out, stats, plot_df
684
-
685
-
686
- def build_pareto_plot_df(spec: DiscoverySpec, max_plot_points: int = 30000) -> pd.DataFrame:
687
- """
688
- Returns a small dataframe for plotting (sampled), with objective columns and pareto_layer.
689
- Does NOT compute trust/diversity. Safe for live plotting.
690
- """
691
- rng = np.random.default_rng(spec.random_seed)
692
-
693
- obj_props = [o["property"].lower() for o in spec.objectives]
694
- cons_props = [p.lower() for p in spec.hard_constraints.keys()]
695
- needed_props = sorted(set(obj_props + cons_props))
696
-
697
- cols = ["SMILES"] + [mean_col(p) for p in needed_props]
698
- df = load_parquet_columns(spec.dataset, columns=cols)
699
-
700
- if "SMILES" not in df.columns and "smiles" in df.columns:
701
- df = df.rename(columns={"smiles": "SMILES"})
702
-
703
- df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
704
- df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)
705
-
706
- # Hard constraints
707
- for p, rule in spec.hard_constraints.items():
708
- p = p.lower()
709
- c = mean_col(p)
710
- if c not in df.columns:
711
- return df.iloc[0:0]
712
- if "min" in rule:
713
- df = df[df[c] >= float(rule["min"])]
714
- if "max" in rule:
715
- df = df[df[c] <= float(rule["max"])]
716
-
717
- if len(df) == 0:
718
- return df
719
-
720
- # Pareto cap for plotting
721
- plot_cap = min(int(max_plot_points), int(spec.pareto_max))
722
- if len(df) > plot_cap:
723
- idx = rng.choice(len(df), size=plot_cap, replace=False)
724
- df = df.iloc[idx].reset_index(drop=True)
725
-
726
- # Build objective matrix (minimization)
727
- X = []
728
- for o in spec.objectives:
729
- prop = o["property"].lower()
730
- goal = o["goal"].lower()
731
- c = mean_col(prop)
732
- v = df[c].to_numpy(dtype=float)
733
- if goal == "maximize":
734
- v = -v
735
- X.append(v)
736
- X = np.stack(X, axis=1)
737
-
738
- df["pareto_layer"] = pareto_layers(X, max_layers=5)
739
-
740
- # Return only what plotting needs
741
- keep = ["smiles_key", "pareto_layer"] + [mean_col(p) for p in obj_props]
742
- out = df[keep].rename(columns={"smiles_key": "SMILES"})
743
- return out
744
-
745
-
746
- def parse_spec(text: str, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
747
- obj = json.loads(text)
748
- pareto_max = int(obj.get("pareto_max", 50000))
749
-
750
- return DiscoverySpec(
751
- dataset=list(dataset_path),
752
- polyinfo=polyinfo_path,
753
- polyinfo_csv=polyinfo_csv_path,
754
- hard_constraints=obj.get("hard_constraints", {}),
755
- objectives=obj.get("objectives", []),
756
- max_pool=pareto_max,
757
- pareto_max=pareto_max,
758
- max_candidates=int(obj.get("max_candidates", 30)),
759
- max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)),
760
- min_distance=float(obj.get("min_distance", 0.30)),
761
- fingerprint=str(obj.get("fingerprint", "morgan")),
762
- random_seed=int(obj.get("random_seed", 7)),
763
- use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)),
764
- use_full_data=bool(obj.get("use_full_data", False)),
765
- trust_weights=obj.get("trust_weights"),
766
- selection_weights=obj.get("selection_weights"),
767
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/fpscores.pkl.gz DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:9abb6f4c322d27fa05c8a1115a463bcef312d2bed0b447c347de33bfefa83316
3
- size 132
 
 
 
 
src/lookup.py DELETED
@@ -1,222 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import pandas as pd
4
- import streamlit as st
5
- from rdkit import Chem
6
- from rdkit import RDLogger
7
-
8
- RDLogger.DisableLog("rdApp.*")
9
-
10
- # ----------------------------
11
- # Sources (property value files)
12
- # ----------------------------
13
- SOURCES = ["EXP", "MD", "DFT", "GC"]
14
-
15
- SOURCE_LABELS = {
16
- "EXP": "Experimental",
17
- "MD": "Molecular Dynamics",
18
- "DFT": "Density Functional Theory",
19
- "GC": "Group Contribution",
20
- }
21
-
22
- # ----------------------------
23
- # PolyInfo metadata file (name/class)
24
- # ----------------------------
25
- POLYINFO_FILE = "data/POLYINFO.csv" # contains: SMILES, Polymer_Class, Polymer_Name
26
-
27
-
28
- def canonicalize_smiles(smiles: str) -> str | None:
29
- smiles = (smiles or "").strip()
30
- if not smiles:
31
- return None
32
- mol = Chem.MolFromSmiles(smiles)
33
- if mol is None:
34
- return None
35
- return Chem.MolToSmiles(mol, canonical=True)
36
-
37
-
38
- # --- Property meta (full name + unit) ---
39
- PROPERTY_META = {
40
- # Thermal
41
- "tm": {"name": "Melting temperature", "unit": "K"},
42
- "tg": {"name": "Glass transition temperature", "unit": "K"},
43
- "td": {"name": "Thermal diffusivity", "unit": "m^2/s"},
44
- "tc": {"name": "Thermal conductivity", "unit": "W/m·K"},
45
- "cp": {"name": "Specific heat capacity", "unit": "J/kg·K"},
46
- # Mechanical
47
- "young": {"name": "Young's modulus", "unit": "GPa"},
48
- "shear": {"name": "Shear modulus", "unit": "GPa"},
49
- "bulk": {"name": "Bulk modulus", "unit": "GPa"},
50
- "poisson": {"name": "Poisson ratio", "unit": "-"},
51
- # Transport
52
- "visc": {"name": "Viscosity", "unit": "Pa·s"},
53
- "dif": {"name": "Diffusivity", "unit": "cm^2/s"},
54
- # Gas permeability
55
- "phe": {"name": "He permeability", "unit": "Barrer"},
56
- "ph2": {"name": "H2 permeability", "unit": "Barrer"},
57
- "pco2": {"name": "CO2 permeability", "unit": "Barrer"},
58
- "pn2": {"name": "N2 permeability", "unit": "Barrer"},
59
- "po2": {"name": "O2 permeability", "unit": "Barrer"},
60
- "pch4": {"name": "CH4 permeability", "unit": "Barrer"},
61
- # Electronic / Optical
62
- "alpha": {"name": "Polarizability", "unit": "a.u."},
63
- "homo": {"name": "HOMO energy", "unit": "eV"},
64
- "lumo": {"name": "LUMO energy", "unit": "eV"},
65
- "bandgap": {"name": "Band gap", "unit": "eV"},
66
- "mu": {"name": "Dipole moment", "unit": "Debye"},
67
- "etotal": {"name": "Total electronic energy", "unit": "eV"},
68
- "ri": {"name": "Refractive index", "unit": "-"},
69
- "dc": {"name": "Dielectric constant", "unit": "-"},
70
- "pe": {"name": "Permittivity", "unit": "-"},
71
- # Structural / Physical
72
- "rg": {"name": "Radius of gyration", "unit": "Å"},
73
- "rho": {"name": "Density", "unit": "g/cm^3"},
74
- }
75
-
76
-
77
- @st.cache_data
78
- def load_source_csv(source: str) -> pd.DataFrame:
79
- """
80
- Loads data/{SOURCE}.csv, normalizes:
81
- - SMILES column -> 'smiles'
82
- - property columns -> lowercase
83
- - adds 'smiles_canon'
84
- """
85
- path = f"data/{source}.csv"
86
- df = pd.read_csv(path)
87
-
88
- # Normalize SMILES column name
89
- if "SMILES" in df.columns:
90
- df = df.rename(columns={"SMILES": "smiles"})
91
- elif "smiles" not in df.columns:
92
- raise ValueError(f"{path} missing SMILES column")
93
-
94
- # Normalize property column names to lowercase
95
- rename_map = {c: c.lower() for c in df.columns if c != "smiles"}
96
- df = df.rename(columns=rename_map)
97
-
98
- # Canonicalize SMILES
99
- df["smiles_canon"] = df["smiles"].astype(str).apply(canonicalize_smiles)
100
- df = df.dropna(subset=["smiles_canon"]).reset_index(drop=True)
101
-
102
- return df
103
-
104
-
105
- @st.cache_data
106
- def build_index(df: pd.DataFrame) -> dict[str, int]:
107
- """canonical smiles -> row index (first occurrence)"""
108
- idx: dict[str, int] = {}
109
- for i, s in enumerate(df["smiles_canon"].tolist()):
110
- if s and s not in idx:
111
- idx[s] = i
112
- return idx
113
-
114
-
115
- @st.cache_data
116
- def load_polyinfo_csv() -> pd.DataFrame:
117
- """
118
- Loads data/POLYINFO.csv with columns:
119
- SMILES, Polymer_Class, Polymer_Name
120
- Adds canonical smiles column 'smiles_canon'.
121
- Returns empty df if file missing.
122
- """
123
- try:
124
- df = pd.read_csv(POLYINFO_FILE)
125
- except Exception:
126
- return pd.DataFrame(columns=["smiles", "polymer_class", "polymer_name", "smiles_canon"])
127
-
128
- # Normalize columns
129
- if "SMILES" in df.columns:
130
- df = df.rename(columns={"SMILES": "smiles"})
131
- elif "smiles" not in df.columns:
132
- # If the file doesn't have a SMILES column as expected, return empty gracefully
133
- return pd.DataFrame(columns=["smiles", "polymer_class", "polymer_name", "smiles_canon"])
134
-
135
- # Normalize expected meta columns
136
- ren = {}
137
- if "Polymer_Class" in df.columns:
138
- ren["Polymer_Class"] = "polymer_class"
139
- if "Polymer_Name" in df.columns:
140
- ren["Polymer_Name"] = "polymer_name"
141
- df = df.rename(columns=ren)
142
-
143
- # Ensure the columns exist (even if missing in the file)
144
- if "polymer_class" not in df.columns:
145
- df["polymer_class"] = pd.NA
146
- if "polymer_name" not in df.columns:
147
- df["polymer_name"] = pd.NA
148
-
149
- # Canonicalize smiles
150
- df["smiles_canon"] = df["smiles"].astype(str).apply(canonicalize_smiles)
151
- df = df.dropna(subset=["smiles_canon"]).reset_index(drop=True)
152
-
153
- return df
154
-
155
-
156
- @st.cache_data
157
- def load_all_sources():
158
- """
159
- Returns dict:
160
- db["EXP"/"MD"/"DFT"/"GC"] = {"df": df, "idx": idx}
161
- db["POLYINFO"] = {"df": df, "idx": idx}
162
- """
163
- db = {}
164
- for src in SOURCES:
165
- df = load_source_csv(src)
166
- idx = build_index(df)
167
- db[src] = {"df": df, "idx": idx}
168
-
169
- # PolyInfo metadata
170
- pi_df = load_polyinfo_csv()
171
- pi_idx = build_index(pi_df) if not pi_df.empty else {}
172
- db["POLYINFO"] = {"df": pi_df, "idx": pi_idx}
173
-
174
- return db
175
-
176
-
177
- def get_value(db, source: str, smiles_canon: str, prop_key: str):
178
- pack = db[source]
179
- df, idx = pack["df"], pack["idx"]
180
- row_i = idx.get(smiles_canon, None)
181
- if row_i is None:
182
- return None
183
- if prop_key not in df.columns:
184
- return None
185
- val = df.iloc[row_i][prop_key]
186
- if pd.isna(val):
187
- return None
188
- return float(val)
189
-
190
-
191
- def get_polyinfo(db, smiles_canon: str) -> tuple[str | None, str | None]:
192
- """
193
- Returns (polymer_name, polymer_class) if available, else (None, None).
194
- No 'not available' text here.
195
- """
196
- pack = db.get("POLYINFO", None)
197
- if pack is None:
198
- return None, None
199
-
200
- df, idx = pack["df"], pack["idx"]
201
- if df is None or df.empty:
202
- return None, None
203
-
204
- row_i = idx.get(smiles_canon, None)
205
- if row_i is None:
206
- return None, None
207
-
208
- name = df.iloc[row_i].get("polymer_name", None)
209
- cls = df.iloc[row_i].get("polymer_class", None)
210
-
211
- # Clean up NA / empty
212
- if pd.isna(name) or str(name).strip() == "":
213
- name = None
214
- else:
215
- name = str(name).strip()
216
-
217
- if pd.isna(cls) or str(cls).strip() == "":
218
- cls = None
219
- else:
220
- cls = str(cls).strip()
221
-
222
- return name, cls
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/model.py DELETED
@@ -1,312 +0,0 @@
1
- # model.py
2
- from __future__ import annotations
3
-
4
- from typing import List, Optional, Literal
5
-
6
- import torch
7
- import torch.nn as nn
8
- import torch.nn.functional as F
9
- from torch_geometric.data import Batch
10
-
11
- from src.conv import build_gnn_encoder, GNNEncoder
12
-
13
-
14
- def get_activation(name: str) -> nn.Module:
15
- name = name.lower()
16
- if name == "relu":
17
- return nn.ReLU()
18
- if name == "gelu":
19
- return nn.GELU()
20
- if name == "silu":
21
- return nn.SiLU()
22
- if name in ("leaky_relu", "lrelu"):
23
- return nn.LeakyReLU(0.1)
24
- raise ValueError(f"Unknown activation: {name}")
25
-
26
-
27
- class FiLM(nn.Module):
28
- """
29
- Simple FiLM: gamma, beta from condition vector; apply to features as (1+gamma)*h + beta
30
- """
31
- def __init__(self, feat_dim: int, cond_dim: int):
32
- super().__init__()
33
- self.gamma = nn.Linear(cond_dim, feat_dim)
34
- self.beta = nn.Linear(cond_dim, feat_dim)
35
-
36
- def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
37
- g = self.gamma(cond)
38
- b = self.beta(cond)
39
- return (1.0 + g) * h + b
40
-
41
-
42
- class TaskHead(nn.Module):
43
- """
44
- Per-task MLP head. Input is concatenation of [graph_embed, optional task_embed].
45
- Outputs either a mean only (scalar) or mean+logvar (heteroscedastic).
46
- """
47
- def __init__(
48
- self,
49
- in_dim: int,
50
- hidden_dim: int = 512,
51
- depth: int = 2,
52
- act: str = "relu",
53
- dropout: float = 0.0,
54
- heteroscedastic: bool = False,
55
- ):
56
- super().__init__()
57
- layers: List[nn.Module] = []
58
- d = in_dim
59
- for _ in range(depth):
60
- layers.append(nn.Linear(d, hidden_dim))
61
- layers.append(get_activation(act))
62
- if dropout > 0:
63
- layers.append(nn.Dropout(dropout))
64
- d = hidden_dim
65
- out_dim = 2 if heteroscedastic else 1
66
- layers.append(nn.Linear(d, out_dim))
67
- self.net = nn.Sequential(*layers)
68
- self.hetero = heteroscedastic
69
-
70
- def forward(self, z: torch.Tensor) -> torch.Tensor:
71
- # returns [B, 1] or [B, 2] where [...,0] is mean and [...,1] is logvar if heteroscedastic
72
- return self.net(z)
73
-
74
-
75
- class MultiTaskMultiFidelityModel(nn.Module):
76
- """
77
- General multi-task, multi-fidelity GNN.
78
-
79
- - Any number of tasks (properties) via T = len(task_names)
80
- - Any number of fidelities via num_fids
81
- - Fidelity conditioning with an embedding and FiLM on the graph embedding
82
- - Optional task embeddings concatenated into each task head input
83
- - Single forward returning predictions [B, T] (means); if heteroscedastic, also returns log-variances
84
-
85
- Expected input Batch fields (PyG):
86
- - x : [N_nodes, F_node]
87
- - edge_index : [2, N_edges]
88
- - edge_attr : [N_edges, F_edge] (required if gnn_type="gine")
89
- - batch : [N_nodes]
90
- - fid_idx : [B] or [B, 1] long; integer fidelity per graph
91
-
92
- Notes:
93
- - Targets should already be normalized outside the model; apply inverse transform for plots.
94
- - Loss weighting/equal-importance and curriculum happen in the trainer, not here.
95
- """
96
-
97
- def __init__(
98
- self,
99
- in_dim_node: int,
100
- in_dim_edge: int,
101
- task_names: List[str],
102
- num_fids: int,
103
- gnn_type: Literal["gine", "gin", "gcn"] = "gine",
104
- gnn_emb_dim: int = 256,
105
- gnn_layers: int = 5,
106
- gnn_norm: Literal["batch", "layer", "none"] = "batch",
107
- gnn_readout: Literal["mean", "sum", "max"] = "mean",
108
- gnn_act: str = "relu",
109
- gnn_dropout: float = 0.0,
110
- gnn_residual: bool = True,
111
- # Fidelity conditioning
112
- fid_emb_dim: int = 64,
113
- use_film: bool = True,
114
- # Task conditioning
115
- use_task_embed: bool = True,
116
- task_emb_dim: int = 32,
117
- # Heads
118
- head_hidden: int = 512,
119
- head_depth: int = 2,
120
- head_act: str = "relu",
121
- head_dropout: float = 0.0,
122
- heteroscedastic: bool = False,
123
- # Optional homoscedastic task uncertainty (used in loss, kept here for checkpoint parity)
124
- use_task_uncertainty: bool = False,
125
- # Embedding regularization (used via regularization_loss)
126
- fid_emb_l2: float = 0.0,
127
- task_emb_l2: float = 0.0,
128
- ):
129
- super().__init__()
130
- self.task_names = list(task_names)
131
- self.num_tasks = len(task_names)
132
- self.num_fids = int(num_fids)
133
- self.hetero = heteroscedastic
134
- self.fid_emb_l2 = float(fid_emb_l2)
135
- self.task_emb_l2 = float(task_emb_l2)
136
- self.use_film = use_film
137
- self.use_task_embed = use_task_embed
138
-
139
- # Optional learned homoscedastic uncertainty per task (trainer may use it)
140
- self.use_task_uncertainty = bool(use_task_uncertainty)
141
- if self.use_task_uncertainty:
142
- self.task_log_sigma2 = nn.Parameter(torch.zeros(self.num_tasks))
143
- else:
144
- self.task_log_sigma2 = None
145
-
146
- # Encoder
147
- self.encoder: GNNEncoder = build_gnn_encoder(
148
- in_dim_node=in_dim_node,
149
- emb_dim=gnn_emb_dim,
150
- num_layers=gnn_layers,
151
- gnn_type=gnn_type,
152
- in_dim_edge=in_dim_edge,
153
- act=gnn_act,
154
- dropout=gnn_dropout,
155
- residual=gnn_residual,
156
- norm=gnn_norm,
157
- readout=gnn_readout,
158
- )
159
-
160
- # Fidelity embedding + FiLM
161
- self.fid_embed = nn.Embedding(self.num_fids, fid_emb_dim) if fid_emb_dim > 0 else None
162
- self.film = FiLM(gnn_emb_dim, fid_emb_dim) if (use_film and fid_emb_dim > 0) else None
163
-
164
- # --- Compute the true feature dim sent to heads ---
165
- # If FiLM is ON: g stays [B, gnn_emb_dim]
166
- # If FiLM is OFF but fid_embed exists: we CONCAT c → g becomes [B, gnn_emb_dim + fid_emb_dim]
167
- self.gnn_out_dim = gnn_emb_dim + (fid_emb_dim if (self.fid_embed is not None and self.film is None) else 0)
168
-
169
- # Task embeddings
170
- self.task_embed = nn.Embedding(self.num_tasks, task_emb_dim) if (use_task_embed and task_emb_dim > 0) else None
171
-
172
- # Per-task heads
173
- head_in_dim = self.gnn_out_dim + (task_emb_dim if self.task_embed is not None else 0)
174
- self.heads = nn.ModuleList([
175
- TaskHead(
176
- in_dim=head_in_dim,
177
- hidden_dim=head_hidden,
178
- depth=head_depth,
179
- act=head_act,
180
- dropout=head_dropout,
181
- heteroscedastic=heteroscedastic,
182
- ) for _ in range(self.num_tasks)
183
- ])
184
-
185
-
186
- def reset_parameters(self):
187
- if self.fid_embed is not None:
188
- nn.init.normal_(self.fid_embed.weight, mean=0.0, std=0.02)
189
- if self.task_embed is not None:
190
- nn.init.normal_(self.task_embed.weight, mean=0.0, std=0.02)
191
- # Encoder/heads rely on their internal initializations.
192
-
193
- def forward(self, data: Batch) -> dict:
194
- """
195
- Returns:
196
- {
197
- "pred": [B, T] means,
198
- "logvar": [B, T] optional if heteroscedastic,
199
- "h": [B, D] graph embedding after FiLM (useful for diagnostics).
200
- }
201
- """
202
- x, edge_index = data.x, data.edge_index
203
- edge_attr = getattr(data, "edge_attr", None)
204
- batch = data.batch
205
- if edge_attr is None and hasattr(self.encoder, "gnn_type") and self.encoder.gnn_type == "gine":
206
- raise ValueError("GINE encoder requires edge_attr, but Batch.edge_attr is None.")
207
-
208
- # Graph embedding
209
- g = self.encoder(x, edge_index, edge_attr, batch) # [B, D]
210
-
211
- # Fidelity conditioning
212
- fid_idx = data.fid_idx.view(-1).long() # [B]
213
- if self.fid_embed is not None:
214
- c = self.fid_embed(fid_idx) # [B, C]
215
- if self.film is not None:
216
- g = self.film(g, c) # [B, D]
217
- else:
218
- g = torch.cat([g, c], dim=-1)
219
-
220
- # Per-task heads
221
- preds: List[torch.Tensor] = []
222
- logvars: Optional[List[torch.Tensor]] = [] if self.hetero else None
223
- for t_idx, head in enumerate(self.heads):
224
- if self.task_embed is not None:
225
- tvec = self.task_embed.weight[t_idx].unsqueeze(0).expand(g.size(0), -1)
226
- z = torch.cat([g, tvec], dim=-1)
227
- else:
228
- z = g
229
- out = head(z) # [B, 1] or [B, 2]
230
- if self.hetero:
231
- mu = out[..., 0:1]
232
- lv = out[..., 1:2]
233
- preds.append(mu)
234
- logvars.append(lv) # type: ignore[arg-type]
235
- else:
236
- preds.append(out)
237
-
238
- pred = torch.cat(preds, dim=-1) # [B, T]
239
- result = {"pred": pred, "h": g}
240
- if self.hetero and logvars is not None:
241
- result["logvar"] = torch.cat(logvars, dim=-1) # [B, T]
242
- return result
243
-
244
- def regularization_loss(self) -> torch.Tensor:
245
- """
246
- Optional small L2 on embeddings to keep them bounded.
247
- """
248
- device = next(self.parameters()).device
249
- reg = torch.zeros([], device=device)
250
- if self.fid_embed is not None and self.fid_emb_l2 > 0:
251
- reg = reg + self.fid_emb_l2 * (self.fid_embed.weight.pow(2).mean())
252
- if self.task_embed is not None and self.task_emb_l2 > 0:
253
- reg = reg + self.task_emb_l2 * (self.task_embed.weight.pow(2).mean())
254
- return reg
255
-
256
-
257
- def build_model(
258
- *,
259
- in_dim_node: int,
260
- in_dim_edge: int,
261
- task_names: List[str],
262
- num_fids: int,
263
- gnn_type: Literal["gine", "gin", "gcn"] = "gine",
264
- gnn_emb_dim: int = 256,
265
- gnn_layers: int = 5,
266
- gnn_norm: Literal["batch", "layer", "none"] = "batch",
267
- gnn_readout: Literal["mean", "sum", "max"] = "mean",
268
- gnn_act: str = "relu",
269
- gnn_dropout: float = 0.0,
270
- gnn_residual: bool = True,
271
- fid_emb_dim: int = 64,
272
- use_film: bool = True,
273
- use_task_embed: bool = True,
274
- task_emb_dim: int = 32,
275
- head_hidden: int = 512,
276
- use_task_uncertainty: bool = False,
277
- head_depth: int = 2,
278
- head_act: str = "relu",
279
- head_dropout: float = 0.0,
280
- heteroscedastic: bool = False,
281
- fid_emb_l2: float = 0.0,
282
- task_emb_l2: float = 0.0,
283
- ) -> MultiTaskMultiFidelityModel:
284
- """
285
- Factory to construct the multi-task, multi-fidelity model with a consistent API.
286
- """
287
- return MultiTaskMultiFidelityModel(
288
- in_dim_node=in_dim_node,
289
- in_dim_edge=in_dim_edge,
290
- task_names=task_names,
291
- num_fids=num_fids,
292
- gnn_type=gnn_type,
293
- gnn_emb_dim=gnn_emb_dim,
294
- gnn_layers=gnn_layers,
295
- gnn_norm=gnn_norm,
296
- gnn_readout=gnn_readout,
297
- gnn_act=gnn_act,
298
- gnn_dropout=gnn_dropout,
299
- gnn_residual=gnn_residual,
300
- fid_emb_dim=fid_emb_dim,
301
- use_film=use_film,
302
- use_task_embed=use_task_embed,
303
- task_emb_dim=task_emb_dim,
304
- head_hidden=head_hidden,
305
- head_depth=head_depth,
306
- head_act=head_act,
307
- head_dropout=head_dropout,
308
- heteroscedastic=heteroscedastic,
309
- fid_emb_l2=fid_emb_l2,
310
- task_emb_l2=task_emb_l2,
311
- use_task_uncertainty=use_task_uncertainty,
312
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/predictor.py DELETED
@@ -1,193 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import re
4
- from pathlib import Path
5
- from typing import Dict, List, Optional, Tuple
6
-
7
- import numpy as np
8
- import torch
9
- from torch_geometric.data import Data
10
-
11
- from src.data_builder import featurize_smiles, TargetScaler
12
- from src.model import build_model
13
- from src.utils import to_device, apply_inverse_transform
14
-
15
-
16
- # -------------------------
17
- # Unit correction (ML only)
18
- # -------------------------
19
- POST_SCALE = {
20
- "td": 1e-7,
21
- "dif": 1e-5,
22
- "visc": 1e-3,
23
- }
24
-
25
-
26
- def _load_scaler_compat(path: Path) -> TargetScaler:
27
- blob = torch.load(path, map_location="cpu")
28
- if "mean" not in blob or "std" not in blob:
29
- raise RuntimeError(f"Unrecognized target_scaler format: {path}")
30
-
31
- ts = TargetScaler(
32
- transforms=blob.get("transforms", None),
33
- eps=blob.get("eps", None),
34
- )
35
- ts.load_state_dict({
36
- "mean": blob["mean"].float(),
37
- "std": blob["std"].float(),
38
- "transforms": blob.get("transforms", ts.transforms),
39
- "eps": blob.get("eps", ts.eps),
40
- })
41
- ts.targets = [str(t).lower() for t in blob.get("targets", [])]
42
- return ts
43
-
44
-
45
- def _infer_seed_from_name(path: Path) -> Optional[int]:
46
- m = re.search(r"_([0-9]+)\.pt$", path.name)
47
- return int(m.group(1)) if m else None
48
-
49
-
50
- def _make_one_graph(smiles: str) -> Data:
51
- x, edge_index, edge_attr = featurize_smiles(smiles)
52
- d = Data(
53
- x=x,
54
- edge_index=edge_index,
55
- edge_attr=edge_attr,
56
- y=torch.zeros(1, 1),
57
- y_mask=torch.zeros(1, 1, dtype=torch.bool),
58
- fid_idx=torch.tensor([0], dtype=torch.long),
59
- )
60
- d.smiles = smiles
61
- return d
62
-
63
-
64
- class SingleTaskEnsemblePredictor:
65
- """
66
- Single-task ensemble:
67
- models/single_models/{prop}_single_model_{seed}.pt
68
- models/single_models/{prop}_single_scalar_{seed}.pt
69
- """
70
-
71
- def __init__(self, models_dir: str = "models/single_models", device: str = "cpu"):
72
- self.models_dir = Path(models_dir)
73
- self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
74
- self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {}
75
-
76
- def available_seeds(self, prop: str) -> List[int]:
77
- prop = prop.lower()
78
- seeds = []
79
- for p in self.models_dir.glob(f"{prop}_single_model_*.pt"):
80
- s = _infer_seed_from_name(p)
81
- if s is not None:
82
- seeds.append(s)
83
- return sorted(set(seeds))
84
-
85
- def _load_one(self, prop: str, seed: int):
86
- prop = prop.lower()
87
- key = (prop, seed)
88
- if key in self._cache:
89
- return self._cache[key]
90
-
91
- ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt"
92
- scaler_path = self.models_dir / f"{prop}_single_scalar_{seed}.pt"
93
- if not ckpt_path.exists() or not scaler_path.exists():
94
- raise FileNotFoundError(f"Missing model/scaler for {prop} seed {seed}")
95
-
96
- ckpt = torch.load(ckpt_path, map_location=self.device)
97
- train_args = ckpt.get("args", {})
98
-
99
- scaler = _load_scaler_compat(scaler_path)
100
- task_names = list(getattr(scaler, "targets", [])) or [prop]
101
-
102
- meta = {"train_args": train_args, "task_names": task_names}
103
- self._cache[key] = (None, scaler, meta)
104
- return self._cache[key]
105
-
106
- def _build_model_if_needed(self, prop: str, seed: int, in_dim_node: int, in_dim_edge: int):
107
- prop = prop.lower()
108
- key = (prop, seed)
109
- model, scaler, meta = self._cache[key]
110
- if model is not None:
111
- return model, scaler, meta
112
-
113
- train_args = meta["train_args"]
114
- task_names = meta["task_names"]
115
-
116
- ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt"
117
- ckpt = torch.load(ckpt_path, map_location=self.device)
118
- state_dict = ckpt["model"]
119
-
120
- # infer num_fids from checkpoint
121
- if "fid_embed.weight" in state_dict:
122
- num_fids = state_dict["fid_embed.weight"].shape[0]
123
- else:
124
- num_fids = 1
125
-
126
- model = build_model(
127
- in_dim_node=in_dim_node,
128
- in_dim_edge=in_dim_edge,
129
- task_names=task_names,
130
- num_fids=num_fids,
131
- gnn_type=train_args.get("gnn_type", "gine"),
132
- gnn_emb_dim=train_args.get("gnn_emb_dim", 256),
133
- gnn_layers=train_args.get("gnn_layers", 5),
134
- gnn_norm=train_args.get("gnn_norm", "batch"),
135
- gnn_readout=train_args.get("gnn_readout", "mean"),
136
- gnn_act=train_args.get("gnn_act", "relu"),
137
- gnn_dropout=train_args.get("gnn_dropout", 0.0),
138
- gnn_residual=train_args.get("gnn_residual", True),
139
- fid_emb_dim=train_args.get("fid_emb_dim", 64),
140
- use_film=train_args.get("use_film", True),
141
- use_task_embed=train_args.get("use_task_embed", True),
142
- task_emb_dim=train_args.get("task_emb_dim", 32),
143
- head_hidden=train_args.get("head_hidden", 512),
144
- head_depth=train_args.get("head_depth", 2),
145
- head_act=train_args.get("head_act", "relu"),
146
- head_dropout=train_args.get("head_dropout", 0.0),
147
- heteroscedastic=train_args.get("heteroscedastic", False),
148
- fid_emb_l2=0.0,
149
- task_emb_l2=0.0,
150
- use_task_uncertainty=train_args.get("task_uncertainty", False),
151
- ).to(self.device)
152
-
153
- model.load_state_dict(state_dict, strict=True)
154
- model.eval()
155
-
156
- self._cache[key] = (model, scaler, meta)
157
- return model, scaler, meta
158
-
159
- def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]:
160
- prop = prop.lower()
161
- seeds = self.available_seeds(prop)
162
- if not seeds:
163
- return None, None, {}
164
-
165
- try:
166
- g = _make_one_graph(smiles)
167
- except Exception:
168
- return None, None, {}
169
-
170
- in_dim_node = g.x.shape[1]
171
- in_dim_edge = g.edge_attr.shape[1]
172
-
173
- per_seed: Dict[int, float] = {}
174
- with torch.no_grad():
175
- for seed in seeds:
176
- self._load_one(prop, seed)
177
- model, scaler, meta = self._build_model_if_needed(prop, seed, in_dim_node, in_dim_edge)
178
-
179
- batch = to_device(g, self.device)
180
- out = model(batch)
181
- pred_n = out["pred"] # [1, 1]
182
- pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1)
183
- val = float(pred[0])
184
-
185
- # unit correction
186
- val *= POST_SCALE.get(prop, 1.0)
187
-
188
- per_seed[seed] = val
189
-
190
- vals = np.array(list(per_seed.values()), dtype=float)
191
- mean = float(vals.mean())
192
- std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0
193
- return mean, std, per_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/predictor_multitask.py DELETED
@@ -1,209 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import re
4
- from pathlib import Path
5
- from typing import Dict, List, Optional, Tuple
6
-
7
- import numpy as np
8
- import torch
9
- from torch_geometric.data import Data
10
-
11
- from src.data_builder import featurize_smiles, TargetScaler
12
- from src.model import build_model
13
- from src.utils import to_device, apply_inverse_transform
14
-
15
-
16
- # -------------------------
17
- # Unit correction (ML only)
18
- # -------------------------
19
- POST_SCALE = {
20
- "td": 1e-7,
21
- "dif": 1e-5,
22
- "visc": 1e-3,
23
- }
24
-
25
-
26
- def _load_scaler_compat(path: Path) -> TargetScaler:
27
- blob = torch.load(path, map_location="cpu")
28
- if "mean" not in blob or "std" not in blob:
29
- raise RuntimeError(f"Unrecognized target_scaler format: {path}")
30
-
31
- ts = TargetScaler(
32
- transforms=blob.get("transforms", None),
33
- eps=blob.get("eps", None),
34
- )
35
- ts.load_state_dict({
36
- "mean": blob["mean"].float(),
37
- "std": blob["std"].float(),
38
- "transforms": blob.get("transforms", ts.transforms),
39
- "eps": blob.get("eps", ts.eps),
40
- })
41
- ts.targets = [str(t).lower() for t in blob.get("targets", [])]
42
- return ts
43
-
44
-
45
- def _infer_seed(path: Path) -> Optional[int]:
46
- m = re.search(r"_([0-9]+)\.pt$", path.name)
47
- return int(m.group(1)) if m else None
48
-
49
-
50
- def _make_one_graph(smiles: str, T: int, fid_idx: int = 0) -> Data:
51
- x, edge_index, edge_attr = featurize_smiles(smiles)
52
- d = Data(
53
- x=x,
54
- edge_index=edge_index,
55
- edge_attr=edge_attr,
56
- y=torch.zeros(1, T),
57
- y_mask=torch.zeros(1, T, dtype=torch.bool),
58
- fid_idx=torch.tensor([fid_idx], dtype=torch.long),
59
- )
60
- d.smiles = smiles
61
- return d
62
-
63
-
64
- class MultiTaskEnsemblePredictor:
65
- """
66
- Multi-task ensemble:
67
- models/multitask_models/{task}_model_{seed}.pt
68
- models/multitask_models/{task}_scalar_{seed}.pt
69
- """
70
-
71
- def __init__(self, models_dir: str = "models/multitask_models", device: str = "cpu"):
72
- self.models_dir = Path(models_dir)
73
- self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
74
- self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {}
75
-
76
- def available_seeds(self, task: str) -> List[int]:
77
- task = task.strip().lower()
78
- seeds = []
79
- for p in self.models_dir.glob(f"{task}_model_*.pt"):
80
- s = _infer_seed(p)
81
- if s is not None:
82
- seeds.append(s)
83
- return sorted(set(seeds))
84
-
85
- def _load_one_meta(self, task: str, seed: int):
86
- task = task.strip().lower()
87
- key = (task, seed)
88
- if key in self._cache:
89
- return self._cache[key]
90
-
91
- ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
92
- scaler_path = self.models_dir / f"{task}_scalar_{seed}.pt"
93
- if not ckpt_path.exists() or not scaler_path.exists():
94
- raise FileNotFoundError(f"Missing model/scaler for task={task} seed={seed}")
95
-
96
- ckpt = torch.load(ckpt_path, map_location=self.device)
97
- state_dict = ckpt["model"]
98
- train_args = ckpt.get("args", {})
99
-
100
- scaler = _load_scaler_compat(scaler_path)
101
- task_names = list(getattr(scaler, "targets", []))
102
- if not task_names:
103
- raise RuntimeError(f"No targets found in scaler: {scaler_path}")
104
-
105
- if "fid_embed.weight" in state_dict:
106
- num_fids = state_dict["fid_embed.weight"].shape[0]
107
- else:
108
- num_fids = 1
109
-
110
- meta = {
111
- "train_args": train_args,
112
- "task_names": task_names,
113
- "num_fids": num_fids,
114
- }
115
- self._cache[key] = (None, scaler, meta)
116
- return self._cache[key]
117
-
118
- def _build_if_needed(self, task: str, seed: int, in_dim_node: int, in_dim_edge: int):
119
- task = task.strip().lower()
120
- key = (task, seed)
121
- model, scaler, meta = self._cache[key]
122
- if model is not None:
123
- return model, scaler, meta
124
-
125
- train_args = meta["train_args"]
126
- task_names = meta["task_names"]
127
- num_fids = meta["num_fids"]
128
-
129
- model = build_model(
130
- in_dim_node=in_dim_node,
131
- in_dim_edge=in_dim_edge,
132
- task_names=task_names,
133
- num_fids=num_fids,
134
- gnn_type=train_args.get("gnn_type", "gine"),
135
- gnn_emb_dim=train_args.get("gnn_emb_dim", 256),
136
- gnn_layers=train_args.get("gnn_layers", 5),
137
- gnn_norm=train_args.get("gnn_norm", "batch"),
138
- gnn_readout=train_args.get("gnn_readout", "mean"),
139
- gnn_act=train_args.get("gnn_act", "relu"),
140
- gnn_dropout=train_args.get("gnn_dropout", 0.0),
141
- gnn_residual=train_args.get("gnn_residual", True),
142
- fid_emb_dim=train_args.get("fid_emb_dim", 64),
143
- use_film=train_args.get("use_film", True),
144
- use_task_embed=train_args.get("use_task_embed", True),
145
- task_emb_dim=train_args.get("task_emb_dim", 32),
146
- head_hidden=train_args.get("head_hidden", 512),
147
- head_depth=train_args.get("head_depth", 2),
148
- head_act=train_args.get("head_act", "relu"),
149
- head_dropout=train_args.get("head_dropout", 0.0),
150
- heteroscedastic=train_args.get("heteroscedastic", False),
151
- fid_emb_l2=0.0,
152
- task_emb_l2=0.0,
153
- use_task_uncertainty=train_args.get("task_uncertainty", False),
154
- ).to(self.device)
155
-
156
- ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
157
- ckpt = torch.load(ckpt_path, map_location=self.device)
158
- model.load_state_dict(ckpt["model"], strict=True)
159
- model.eval()
160
-
161
- self._cache[key] = (model, scaler, meta)
162
- return model, scaler, meta
163
-
164
- def predict_mean_std(self, smiles: str, prop_key: str, task: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]:
165
- task = task.strip().lower()
166
- prop_key = prop_key.lower()
167
-
168
- seeds = self.available_seeds(task)
169
- if not seeds:
170
- return None, None, {}
171
-
172
- self._load_one_meta(task, seeds[0])
173
- _, scaler0, meta0 = self._cache[(task, seeds[0])]
174
- targets = list(meta0["task_names"]) # already lower()
175
- if prop_key not in targets:
176
- return None, None, {}
177
-
178
- t_idx = targets.index(prop_key)
179
- T = len(targets)
180
-
181
- try:
182
- g = _make_one_graph(smiles, T=T, fid_idx=0)
183
- except Exception:
184
- return None, None, {}
185
-
186
- in_dim_node = g.x.shape[1]
187
- in_dim_edge = g.edge_attr.shape[1]
188
-
189
- per_seed: Dict[int, float] = {}
190
- with torch.no_grad():
191
- for seed in seeds:
192
- self._load_one_meta(task, seed)
193
- model, scaler, meta = self._build_if_needed(task, seed, in_dim_node, in_dim_edge)
194
-
195
- batch = to_device(g, self.device)
196
- out = model(batch)
197
- pred_n = out["pred"] # [1, T]
198
- pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1)
199
- val = float(pred[t_idx])
200
-
201
- # unit correction
202
- val *= POST_SCALE.get(prop_key, 1.0)
203
-
204
- per_seed[seed] = val
205
-
206
- vals = np.array(list(per_seed.values()), dtype=float)
207
- mean = float(vals.mean())
208
- std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0
209
- return mean, std, per_seed
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/predictor_router.py DELETED
@@ -1,45 +0,0 @@
1
- from __future__ import annotations
2
-
3
- import json
4
- from pathlib import Path
5
- from typing import Dict, Optional, Tuple
6
-
7
- from src.predictor import SingleTaskEnsemblePredictor
8
- from src.predictor_multitask import MultiTaskEnsemblePredictor
9
-
10
-
11
- class RouterPredictor:
12
- """
13
- Routes each property to either:
14
- - single-task ensemble (models/single_models)
15
- - multitask ensemble (models/multitask_models/{task}_*)
16
- based on models/best_model_map.json
17
- """
18
-
19
- def __init__(
20
- self,
21
- map_path: str = "models/best_model_map.json",
22
- single_dir: str = "models/single_models",
23
- multitask_dir: str = "models/multitask_models",
24
- device: str = "cpu",
25
- ):
26
- self.map_path = Path(map_path)
27
- self.map: Dict[str, dict] = json.load(open(self.map_path))
28
- self.single = SingleTaskEnsemblePredictor(models_dir=single_dir, device=device)
29
- self.multi = MultiTaskEnsemblePredictor(models_dir=multitask_dir, device=device)
30
-
31
- def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], dict, str]:
32
- prop = prop.lower()
33
- cfg = self.map.get(prop, {"family": "single"})
34
-
35
- fam = cfg.get("family", "single").lower()
36
- if fam == "multitask":
37
- task = str(cfg.get("task", "all")).lower()
38
- mean, std, per_seed = self.multi.predict_mean_std(smiles, prop_key=prop, task=task)
39
- label = f"multitask:{task}"
40
- return mean, std, per_seed, label
41
-
42
- # default: single
43
- mean, std, per_seed = self.single.predict_mean_std(smiles, prop)
44
- label = "single"
45
- return mean, std, per_seed, label
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rnn_smiles/__init__.py DELETED
@@ -1,22 +0,0 @@
1
- """RNN-based SMILES generation helpers for Streamlit pages."""
2
-
3
- from .generator import (
4
- canonicalize_smiles,
5
- filter_novel_smiles,
6
- generate_smiles,
7
- load_existing_smiles_set,
8
- load_rnn_model,
9
- )
10
- from .rnn import MultiGRU, RNN
11
- from .vocabulary import Vocabulary
12
-
13
- __all__ = [
14
- "canonicalize_smiles",
15
- "filter_novel_smiles",
16
- "generate_smiles",
17
- "load_existing_smiles_set",
18
- "load_rnn_model",
19
- "MultiGRU",
20
- "RNN",
21
- "Vocabulary",
22
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rnn_smiles/generator.py DELETED
@@ -1,175 +0,0 @@
1
- """Streamlit integration helpers for RNN SMILES generation."""
2
-
3
- from __future__ import annotations
4
-
5
- from pathlib import Path
6
- from typing import Iterable, Sequence
7
-
8
- import pandas as pd
9
- import streamlit as st
10
- import torch
11
- from rdkit import Chem, RDLogger
12
-
13
- from .rnn import RNN
14
- from .vocabulary import Vocabulary
15
-
16
- RDLogger.DisableLog("rdApp.*")
17
-
18
-
19
- def canonicalize_smiles(smiles: str) -> str | None:
20
- s = (smiles or "").strip()
21
- if not s:
22
- return None
23
- mol = Chem.MolFromSmiles(s)
24
- if mol is None:
25
- return None
26
- return Chem.MolToSmiles(mol, canonical=True)
27
-
28
-
29
- def _find_smiles_column(path: Path) -> str:
30
- header = pd.read_csv(path, nrows=0)
31
- for col in header.columns:
32
- if str(col).strip().lower() == "smiles":
33
- return col
34
- raise ValueError(f"No SMILES column found in {path}")
35
-
36
-
37
- def _load_checkpoint(path: Path, device: torch.device) -> dict:
38
- # Support both new/old torch signatures while preferring secure load mode.
39
- try:
40
- state = torch.load(path, map_location=device, weights_only=True)
41
- except TypeError:
42
- state = torch.load(path, map_location=device)
43
- if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
44
- state = state["state_dict"]
45
- if not isinstance(state, dict):
46
- raise RuntimeError(f"Checkpoint does not contain a state dict: {path}")
47
- return state
48
-
49
-
50
- @st.cache_resource(show_spinner=False)
51
- def load_rnn_model(ckpt_path: str | Path, voc_path: str | Path) -> tuple[RNN, Vocabulary]:
52
- ckpt_path = Path(ckpt_path).expanduser().resolve()
53
- voc_path = Path(voc_path).expanduser().resolve()
54
-
55
- if not ckpt_path.exists():
56
- raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
57
- if not voc_path.exists():
58
- raise FileNotFoundError(f"Vocabulary not found: {voc_path}")
59
-
60
- voc = Vocabulary(init_from_file=str(voc_path))
61
- model = RNN(voc)
62
- model_device = next(model.rnn.parameters()).device
63
- state = _load_checkpoint(ckpt_path, model_device)
64
-
65
- ckpt_vocab_size = None
66
- if "embedding.weight" in state:
67
- ckpt_vocab_size = int(state["embedding.weight"].shape[0])
68
- if ckpt_vocab_size is not None and ckpt_vocab_size != voc.vocab_size:
69
- raise RuntimeError(
70
- f"Vocabulary size mismatch: voc has {voc.vocab_size} tokens, "
71
- f"checkpoint expects {ckpt_vocab_size}. "
72
- "Use the matching vocab file for this checkpoint."
73
- )
74
-
75
- model.rnn.load_state_dict(state)
76
- model.rnn.eval()
77
- return model, voc
78
-
79
-
80
- def _sample_with_temperature(
81
- model: RNN, voc: Vocabulary, batch_size: int, max_length: int, temperature: float
82
- ) -> torch.Tensor:
83
- temp = max(float(temperature), 1e-6)
84
- device = next(model.rnn.parameters()).device
85
- start_token = torch.full((batch_size,), voc.vocab["GO"], dtype=torch.long, device=device)
86
- h = model.rnn.init_h(batch_size)
87
- x = start_token
88
-
89
- sequences: list[torch.Tensor] = []
90
- finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
91
-
92
- for _ in range(max_length):
93
- logits, h = model.rnn(x, h)
94
- logits = logits / temp
95
- prob = torch.softmax(logits, dim=1)
96
- x = torch.multinomial(prob, 1).view(-1)
97
- sequences.append(x.view(-1, 1))
98
- finished = finished | (x == voc.vocab["EOS"])
99
- if torch.all(finished):
100
- break
101
-
102
- if not sequences:
103
- return torch.empty((batch_size, 0), dtype=torch.long, device=device)
104
- return torch.cat(sequences, dim=1)
105
-
106
-
107
- def generate_smiles(
108
- model: RNN,
109
- voc: Vocabulary,
110
- n: int,
111
- max_length: int,
112
- temperature: float = 1.0,
113
- ) -> list[str]:
114
- if n <= 0:
115
- return []
116
- max_length = max(int(max_length), 1)
117
-
118
- with torch.no_grad():
119
- if abs(float(temperature) - 1.0) < 1e-8:
120
- seqs, _, _ = model.sample(int(n), max_length=max_length)
121
- else:
122
- seqs = _sample_with_temperature(
123
- model,
124
- voc,
125
- int(n),
126
- max_length,
127
- float(temperature),
128
- )
129
- arr = seqs.detach().cpu().numpy()
130
-
131
- output: list[str] = []
132
- for seq in arr:
133
- output.append(voc.decode(seq))
134
- return output
135
-
136
-
137
- def filter_novel_smiles(smiles: Iterable[str], existing: set[str]) -> list[str]:
138
- novel: list[str] = []
139
- seen: set[str] = set()
140
- for smi in smiles:
141
- canonical = canonicalize_smiles(smi)
142
- if canonical is None:
143
- continue
144
- if canonical in seen:
145
- continue
146
- seen.add(canonical)
147
- if canonical in existing:
148
- continue
149
- novel.append(canonical)
150
- return novel
151
-
152
-
153
- @st.cache_resource(show_spinner=False)
154
- def load_existing_smiles_set(csv_paths: Sequence[str | Path], chunksize: int = 200_000) -> set[str]:
155
- existing: set[str] = set()
156
- for p in csv_paths:
157
- path = Path(p)
158
- if not path.exists():
159
- continue
160
- col = _find_smiles_column(path)
161
- for chunk in pd.read_csv(path, usecols=[col], chunksize=int(chunksize)):
162
- for smiles in chunk[col].astype(str):
163
- canonical = canonicalize_smiles(smiles)
164
- if canonical:
165
- existing.add(canonical)
166
- return existing
167
-
168
-
169
- __all__ = [
170
- "canonicalize_smiles",
171
- "load_rnn_model",
172
- "generate_smiles",
173
- "filter_novel_smiles",
174
- "load_existing_smiles_set",
175
- ]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rnn_smiles/rnn.py DELETED
@@ -1,89 +0,0 @@
1
- """Core GRU model used for polymer SMILES generation."""
2
-
3
- from __future__ import annotations
4
-
5
- import torch
6
- import torch.nn as nn
7
- import torch.nn.functional as F
8
-
9
-
10
- class MultiGRU(nn.Module):
11
- def __init__(self, vocab_size: int):
12
- super().__init__()
13
- self.embedding = nn.Embedding(vocab_size, 128)
14
- self.gru_1 = nn.GRUCell(128, 512)
15
- self.gru_2 = nn.GRUCell(512, 512)
16
- self.gru_3 = nn.GRUCell(512, 512)
17
- self.linear = nn.Linear(512, vocab_size)
18
-
19
- def forward(self, x: torch.Tensor, h: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
20
- x = self.embedding(x)
21
- h_out = torch.zeros_like(h)
22
- x = h_out[0] = self.gru_1(x, h[0])
23
- x = h_out[1] = self.gru_2(x, h[1])
24
- x = h_out[2] = self.gru_3(x, h[2])
25
- x = self.linear(x)
26
- return x, h_out
27
-
28
- def init_h(self, batch_size: int) -> torch.Tensor:
29
- device = next(self.parameters()).device
30
- return torch.zeros(3, batch_size, 512, device=device)
31
-
32
-
33
- def nll_loss(log_probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
34
- # Gather selected token log-probability for each sample in batch.
35
- return log_probs.gather(1, targets.contiguous().view(-1, 1)).squeeze(1)
36
-
37
-
38
- class RNN:
39
- def __init__(self, voc):
40
- self.rnn = MultiGRU(voc.vocab_size)
41
- if torch.cuda.is_available():
42
- self.rnn.cuda()
43
- self.voc = voc
44
-
45
- def likelihood(self, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
46
- batch_size, seq_length = target.size()
47
- device = target.device
48
- start_token = torch.full((batch_size, 1), self.voc.vocab["GO"], dtype=torch.long, device=device)
49
- x = torch.cat((start_token, target[:, :-1]), 1)
50
- h = self.rnn.init_h(batch_size)
51
-
52
- log_probs = torch.zeros(batch_size, device=device)
53
- entropy = torch.zeros(batch_size, device=device)
54
- for step in range(seq_length):
55
- logits, h = self.rnn(x[:, step], h)
56
- log_prob = F.log_softmax(logits, dim=1)
57
- prob = F.softmax(logits, dim=1)
58
- log_probs += nll_loss(log_prob, target[:, step])
59
- entropy += -torch.sum((log_prob * prob), 1)
60
- return log_probs, entropy
61
-
62
- def sample(self, batch_size: int, max_length: int = 140) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
63
- device = next(self.rnn.parameters()).device
64
- start_token = torch.full((batch_size,), self.voc.vocab["GO"], dtype=torch.long, device=device)
65
- h = self.rnn.init_h(batch_size)
66
- x = start_token
67
-
68
- sequences: list[torch.Tensor] = []
69
- log_probs = torch.zeros(batch_size, device=device)
70
- finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
71
- entropy = torch.zeros(batch_size, device=device)
72
-
73
- for _ in range(max_length):
74
- logits, h = self.rnn(x, h)
75
- prob = F.softmax(logits, dim=1)
76
- log_prob = F.log_softmax(logits, dim=1)
77
- x = torch.multinomial(prob, 1).view(-1)
78
- sequences.append(x.view(-1, 1))
79
- log_probs += nll_loss(log_prob, x)
80
- entropy += -torch.sum((log_prob * prob), 1)
81
- finished = finished | (x == self.voc.vocab["EOS"])
82
- if torch.all(finished):
83
- break
84
-
85
- if sequences:
86
- stacked = torch.cat(sequences, 1)
87
- else:
88
- stacked = torch.empty((batch_size, 0), dtype=torch.long, device=device)
89
- return stacked, log_probs, entropy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rnn_smiles/utils.py DELETED
@@ -1,15 +0,0 @@
1
- """Utility helpers used by the legacy-style RNN generator."""
2
-
3
- from __future__ import annotations
4
-
5
- import numpy as np
6
- import torch
7
-
8
-
9
- def variable(tensor: torch.Tensor | np.ndarray) -> torch.Tensor:
10
- """Return a tensor on GPU when available."""
11
- if isinstance(tensor, np.ndarray):
12
- tensor = torch.from_numpy(tensor)
13
- if torch.cuda.is_available():
14
- return tensor.cuda()
15
- return tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/rnn_smiles/vocabulary.py DELETED
@@ -1,69 +0,0 @@
1
- """Token vocabulary used by the SMILES RNN."""
2
-
3
- from __future__ import annotations
4
-
5
- import re
6
-
7
- import numpy as np
8
-
9
-
10
- class Vocabulary:
11
- def __init__(self, init_from_file: str | None = None, max_length: int | None = None):
12
- self.special_tokens = ["EOS", "GO"]
13
- self.additional_chars: set[str] = set()
14
- self.chars = self.special_tokens
15
- self.vocab_size = len(self.chars)
16
- self.vocab = dict(zip(self.chars, range(len(self.chars))))
17
- self.reversed_vocab = {v: k for k, v in self.vocab.items()}
18
- self.max_length = max_length
19
- if init_from_file:
20
- self.init_from_file(init_from_file)
21
-
22
- def encode(self, char_list: list[str]) -> np.ndarray:
23
- smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
24
- for i, char in enumerate(char_list):
25
- smiles_matrix[i] = self.vocab[char]
26
- return smiles_matrix
27
-
28
- def decode(self, matrix: np.ndarray) -> str:
29
- chars: list[str] = []
30
- eos_id = self.vocab["EOS"]
31
- for i in matrix:
32
- if int(i) == eos_id:
33
- break
34
- chars.append(self.reversed_vocab[int(i)])
35
- return "".join(chars)
36
-
37
- def tokenize(self, smiles: str) -> list[str]:
38
- regex = r"(\[[^\[\]]{1,6}\])"
39
- char_list = re.split(regex, smiles)
40
- tokenized: list[str] = []
41
- for char in char_list:
42
- if not char:
43
- continue
44
- if char.startswith("["):
45
- tokenized.append(char)
46
- else:
47
- tokenized.extend(list(char))
48
- tokenized.append("EOS")
49
- return tokenized
50
-
51
- def add_characters(self, chars: list[str]) -> None:
52
- for char in chars:
53
- self.additional_chars.add(char)
54
- char_list = sorted(list(self.additional_chars))
55
- self.chars = char_list + self.special_tokens
56
- self.vocab_size = len(self.chars)
57
- self.vocab = dict(zip(self.chars, range(len(self.chars))))
58
- self.reversed_vocab = {v: k for k, v in self.vocab.items()}
59
-
60
- def init_from_file(self, file_path: str) -> None:
61
- with open(file_path, "r", encoding="utf-8") as f:
62
- chars = f.read().split()
63
- self.add_characters(chars)
64
-
65
- def __len__(self) -> int:
66
- return len(self.chars)
67
-
68
- def __str__(self) -> str:
69
- return f"Vocabulary containing {len(self)} tokens: {self.chars}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/sascorer.py DELETED
@@ -1,192 +0,0 @@
1
- #
2
- # calculation of synthetic accessibility score as described in:
3
- #
4
- # Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
5
- # Peter Ertl and Ansgar Schuffenhauer
6
- # Journal of Cheminformatics 1:8 (2009)
7
- # http://www.jcheminf.com/content/1/1/8
8
- #
9
- # several small modifications to the original paper are included
10
- # particularly slightly different formula for marocyclic penalty
11
- # and taking into account also molecule symmetry (fingerprint density)
12
- #
13
- # for a set of 10k diverse molecules the agreement between the original method
14
- # as implemented in PipelinePilot and this implementation is r2 = 0.97
15
- #
16
- # peter ertl & greg landrum, september 2013
17
- #
18
-
19
- from rdkit import Chem
20
- from rdkit.Chem import rdFingerprintGenerator, rdMolDescriptors
21
-
22
- import math
23
- import pickle
24
-
25
- import os.path as op
26
-
27
- _fscores = None
28
- mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2)
29
-
30
-
31
- def readFragmentScores(name="fpscores.pkl.gz"):
32
- import gzip
33
- global _fscores
34
- # generate the full path filename:
35
- if name == "fpscores.pkl.gz":
36
- name = op.join(op.dirname(__file__), name)
37
- data = pickle.load(gzip.open(name))
38
- outDict = {}
39
- for i in data:
40
- for j in range(1, len(i)):
41
- outDict[i[j]] = float(i[0])
42
- _fscores = outDict
43
-
44
-
45
- def numBridgeheadsAndSpiro(mol, ri=None):
46
- nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
47
- nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
48
- return nBridgehead, nSpiro
49
-
50
-
51
- def calculateScore(m):
52
-
53
- if not m.GetNumAtoms():
54
- return None
55
-
56
- if _fscores is None:
57
- readFragmentScores()
58
-
59
- # fragment score
60
- sfp = mfpgen.GetSparseCountFingerprint(m)
61
-
62
- score1 = 0.
63
- nf = 0
64
- nze = sfp.GetNonzeroElements()
65
- for id, count in nze.items():
66
- nf += count
67
- score1 += _fscores.get(id, -4) * count
68
-
69
- score1 /= nf
70
-
71
- # features score
72
- nAtoms = m.GetNumAtoms()
73
- nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
74
- ri = m.GetRingInfo()
75
- nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
76
- nMacrocycles = 0
77
- for x in ri.AtomRings():
78
- if len(x) > 8:
79
- nMacrocycles += 1
80
-
81
- sizePenalty = nAtoms**1.005 - nAtoms
82
- stereoPenalty = math.log10(nChiralCenters + 1)
83
- spiroPenalty = math.log10(nSpiro + 1)
84
- bridgePenalty = math.log10(nBridgeheads + 1)
85
- macrocyclePenalty = 0.
86
- # ---------------------------------------
87
- # This differs from the paper, which defines:
88
- # macrocyclePenalty = math.log10(nMacrocycles+1)
89
- # This form generates better results when 2 or more macrocycles are present
90
- if nMacrocycles > 0:
91
- macrocyclePenalty = math.log10(2)
92
-
93
- score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
94
-
95
- # correction for the fingerprint density
96
- # not in the original publication, added in version 1.1
97
- # to make highly symmetrical molecules easier to synthetise
98
- score3 = 0.
99
- numBits = len(nze)
100
- if nAtoms > numBits:
101
- score3 = math.log(float(nAtoms) / numBits) * .5
102
-
103
- sascore = score1 + score2 + score3
104
-
105
- # need to transform "raw" value into scale between 1 and 10
106
- min = -4.0
107
- max = 2.5
108
- sascore = 11. - (sascore - min + 1) / (max - min) * 9.
109
-
110
- # smooth the 10-end
111
- if sascore > 8.:
112
- sascore = 8. + math.log(sascore + 1. - 9.)
113
- if sascore > 10.:
114
- sascore = 10.0
115
- elif sascore < 1.:
116
- sascore = 1.0
117
-
118
- return sascore
119
-
120
-
121
- def processMols(mols):
122
- print('smiles\tName\tsa_score')
123
- for i, m in enumerate(mols):
124
- if m is None:
125
- continue
126
-
127
- s = calculateScore(m)
128
-
129
- smiles = Chem.MolToSmiles(m)
130
- if s is None:
131
- print(f"{smiles}\t{m.GetProp('_Name')}\t{s}")
132
- else:
133
- print(f"{smiles}\t{m.GetProp('_Name')}\t{s:3f}")
134
-
135
-
136
- if __name__ == '__main__':
137
- import sys
138
- import time
139
-
140
- t1 = time.time()
141
- if len(sys.argv) == 2:
142
- readFragmentScores()
143
- else:
144
- readFragmentScores(sys.argv[2])
145
- t2 = time.time()
146
-
147
- molFile = sys.argv[1]
148
- if molFile.endswith("smi"):
149
- suppl = Chem.SmilesMolSupplier(molFile)
150
- elif molFile.endswith("sdf"):
151
- suppl = Chem.SDMolSupplier(molFile)
152
- else:
153
- print(f"Unrecognized file extension for {molFile}")
154
- sys.exit()
155
-
156
- t3 = time.time()
157
- processMols(suppl)
158
- t4 = time.time()
159
-
160
- print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
161
- file=sys.stderr)
162
-
163
- #
164
- # Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
165
- # All rights reserved.
166
- #
167
- # Redistribution and use in source and binary forms, with or without
168
- # modification, are permitted provided that the following conditions are
169
- # met:
170
- #
171
- # * Redistributions of source code must retain the above copyright
172
- # notice, this list of conditions and the following disclaimer.
173
- # * Redistributions in binary form must reproduce the above
174
- # copyright notice, this list of conditions and the following
175
- # disclaimer in the documentation and/or other materials provided
176
- # with the distribution.
177
- # * Neither the name of Novartis Institutes for BioMedical Research Inc.
178
- # nor the names of its contributors may be used to endorse or promote
179
- # products derived from this software without specific prior written permission.
180
- #
181
- # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
182
- # "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
183
- # LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
184
- # A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
185
- # OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
186
- # SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
187
- # LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
188
- # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
189
- # THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
190
- # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
191
- # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
192
- #
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/streamlit_app.py DELETED
@@ -1,40 +0,0 @@
1
- import altair as alt
2
- import numpy as np
3
- import pandas as pd
4
- import streamlit as st
5
-
6
- """
7
- # Welcome to Streamlit!
8
-
9
- Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
10
- If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
11
- forums](https://discuss.streamlit.io).
12
-
13
- In the meantime, below is an example of what you can do with just a few lines of code:
14
- """
15
-
16
- num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
17
- num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
18
-
19
- indices = np.linspace(0, 1, num_points)
20
- theta = 2 * np.pi * num_turns * indices
21
- radius = indices
22
-
23
- x = radius * np.cos(theta)
24
- y = radius * np.sin(theta)
25
-
26
- df = pd.DataFrame({
27
- "x": x,
28
- "y": y,
29
- "idx": indices,
30
- "rand": np.random.randn(num_points),
31
- })
32
-
33
- st.altair_chart(alt.Chart(df, height=700, width=700)
34
- .mark_point(filled=True)
35
- .encode(
36
- x=alt.X("x", axis=None),
37
- y=alt.Y("y", axis=None),
38
- color=alt.Color("idx", legend=None, scale=alt.Scale()),
39
- size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
40
- ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/ui_style.py DELETED
@@ -1,1003 +0,0 @@
1
- import base64
2
- import html
3
- import os
4
- from pathlib import Path
5
- from urllib import request
6
-
7
- import streamlit as st
8
-
9
-
10
- def _icon_data_uri(filename: str) -> str:
11
- icon_path = Path(__file__).resolve().parent.parent / "icons" / filename
12
- if not icon_path.exists():
13
- return ""
14
- try:
15
- encoded = base64.b64encode(icon_path.read_bytes()).decode("ascii")
16
- except Exception:
17
- return ""
18
- return f"data:image/png;base64,{encoded}"
19
-
20
-
21
- def _config_value(name: str, default: str = "") -> str:
22
- try:
23
- if name in st.secrets:
24
- return str(st.secrets[name]).strip()
25
- except Exception:
26
- pass
27
- return str(os.getenv(name, default)).strip()
28
-
29
-
30
- def _build_sidebar_icon_css() -> str:
31
- fallback = {
32
- 1: "🏠",
33
- 2: "🔎",
34
- 3: "📦",
35
- 4: "🧬",
36
- 5: "⚙️",
37
- 6: "🧠",
38
- 7: "✨",
39
- 8: "💬",
40
- 9: "📚",
41
- }
42
- icon_name = {
43
- 1: "home1.png",
44
- 2: "probe1.png",
45
- 3: "batch1.png",
46
- 4: "molecule1.png",
47
- 5: "manual1.png",
48
- 6: "ai1.png",
49
- 7: "rnn1.png",
50
- 8: "literature.png",
51
- 9: "feedback.png",
52
- }
53
- rules = [
54
- '[data-testid="stSidebarNav"] ul li a { position: relative; padding-left: 3.25rem !important; }',
55
- '[data-testid="stSidebarNav"] ul li a::before { content: ""; position: absolute; left: 12px; top: 50%; transform: translateY(-50%); width: 32px; height: 32px; background-size: contain; background-repeat: no-repeat; background-position: center; }',
56
- ]
57
- for idx in range(1, 10):
58
- uri = _icon_data_uri(icon_name[idx])
59
- if uri:
60
- rules.append(
61
- '[data-testid="stSidebarNav"] ul li:nth-of-type(%d) a::before { content: ""; background-image: url("%s"); }'
62
- % (idx, uri)
63
- )
64
- else:
65
- emoji = fallback[idx]
66
- rules.append(
67
- '[data-testid="stSidebarNav"] ul li:nth-of-type(%d) a::before { content: "%s"; background-image: none; width: auto; height: auto; font-size: 1.4rem; }'
68
- % (idx, emoji)
69
- )
70
- return "\n".join(rules)
71
-
72
-
73
- def _log_visit_once_per_session() -> None:
74
- if st.session_state.get("_visit_logged"):
75
- return
76
- webhook_url = _config_value("FEEDBACK_WEBHOOK_URL", "")
77
- webhook_token = _config_value("FEEDBACK_WEBHOOK_TOKEN", "")
78
- if not webhook_url:
79
- return
80
- endpoint = webhook_url
81
- sep = "&" if "?" in webhook_url else "?"
82
- endpoint = f"{webhook_url}{sep}event=visit"
83
- if webhook_token:
84
- endpoint = f"{endpoint}&token={webhook_token}"
85
- try:
86
- with request.urlopen(endpoint, timeout=3):
87
- pass
88
- except Exception:
89
- pass
90
- st.session_state["_visit_logged"] = True
91
-
92
-
93
- def render_page_header(title: str, subtitle: str = "", badge: str = "") -> None:
94
- title_html = html.escape(title)
95
- subtitle_html = html.escape(subtitle) if subtitle else ""
96
- badge_html = html.escape(badge) if badge else ""
97
-
98
- st.markdown(
99
- f"""
100
- <section class="pp-page-header">
101
- {"<span class='pp-badge'>" + badge_html + "</span>" if badge_html else ""}
102
- <h1 class="pp-page-title">{title_html}</h1>
103
- {"<p class='pp-page-subtitle'>" + subtitle_html + "</p>" if subtitle_html else ""}
104
- </section>
105
- """,
106
- unsafe_allow_html=True,
107
- )
108
-
109
-
110
- def apply_global_style() -> None:
111
- _log_visit_once_per_session()
112
- icon_css = _build_sidebar_icon_css()
113
- css = """
114
- <style>
115
- @import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700;800&display=swap');
116
-
117
- :root {
118
- --pp-text: #133a2a;
119
- --pp-muted: #4e6f5d;
120
- --pp-primary: #1f8f52;
121
- --pp-primary-2: #2cab67;
122
- --pp-border: rgba(255, 255, 255, 0.96);
123
- --pp-surface: rgba(251, 247, 255, 0.90);
124
- --pp-shadow: 0 10px 22px rgba(70, 66, 110, 0.12);
125
- --pp-panel-shadow:
126
- 0 14px 28px rgba(83, 73, 124, 0.16),
127
- 0 1px 0 rgba(255, 255, 255, 0.58) inset;
128
- }
129
-
130
- html, body, [class*="css"], [data-testid="stMarkdownContainer"] * {
131
- font-family: "Manrope", "Avenir Next", "Segoe UI", sans-serif;
132
- }
133
-
134
- [data-testid="stAppViewContainer"] {
135
- min-height: 100vh;
136
- overflow: visible !important;
137
- background:
138
- radial-gradient(1300px 700px at -10% 105%, rgba(188, 113, 202, 0.62), transparent 62%),
139
- radial-gradient(1200px 700px at 108% 104%, rgba(130, 170, 235, 0.50), transparent 62%),
140
- linear-gradient(120deg, #d5c4e3 0%, #cdd4e9 45%, #c4e1ea 100%) !important;
141
- }
142
-
143
- .stApp,
144
- [data-testid="stAppViewContainer"] > .main,
145
- section[data-testid="stMain"] {
146
- color: var(--pp-text);
147
- background: transparent !important;
148
- }
149
-
150
- [data-testid="stAppViewContainer"] > .main {
151
- margin: 0 !important;
152
- border: none !important;
153
- border-radius: 0 !important;
154
- box-shadow: none !important;
155
- background: transparent !important;
156
- overflow: visible !important;
157
- }
158
-
159
- section[data-testid="stMain"] {
160
- position: relative;
161
- margin: 12px 14px 12px 18px;
162
- min-height: calc(100vh - 24px) !important;
163
- border-radius: 30px;
164
- border: 1px solid rgba(255, 255, 255, 0.98);
165
- border-right: 2px solid rgba(233, 223, 245, 0.98);
166
- background: rgba(244, 239, 250, 0.96) !important;
167
- box-shadow:
168
- var(--pp-panel-shadow),
169
- inset -1px 0 0 rgba(233, 223, 245, 0.95);
170
- overflow: visible;
171
- isolation: isolate;
172
- }
173
-
174
- section[data-testid="stMain"] {
175
- overflow-y: visible !important;
176
- overflow-x: hidden !important;
177
- }
178
-
179
- section[data-testid="stMain"]::before {
180
- content: "";
181
- position: absolute;
182
- inset: 0;
183
- border-radius: inherit;
184
- pointer-events: none;
185
- box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.58);
186
- z-index: 0;
187
- }
188
-
189
- section[data-testid="stMain"]::after {
190
- display: none;
191
- }
192
-
193
- section[data-testid="stMain"] > div,
194
- [data-testid="stMainBlockContainer"] {
195
- position: relative;
196
- z-index: 1;
197
- background: transparent !important;
198
- }
199
-
200
- [data-testid="stHeader"] {
201
- background: transparent !important;
202
- }
203
-
204
- .block-container {
205
- max-width: 1220px;
206
- padding-top: 1.3rem;
207
- padding-bottom: 2.4rem;
208
- }
209
-
210
- h1, h2, h3, h4, h5 {
211
- letter-spacing: -0.02em;
212
- }
213
-
214
- p, li, label, [data-testid="stCaptionContainer"] {
215
- color: var(--pp-muted);
216
- }
217
-
218
- a, a:visited {
219
- color: #1f8f52 !important;
220
- }
221
-
222
- section[data-testid="stSidebar"],
223
- [data-testid="stSidebar"] {
224
- background: transparent !important;
225
- border-right: none !important;
226
- --pp-sidebar-width: 300px;
227
- height: calc(100vh - 24px) !important;
228
- min-width: var(--pp-sidebar-width) !important;
229
- width: var(--pp-sidebar-width) !important;
230
- max-width: var(--pp-sidebar-width) !important;
231
- flex: 0 0 var(--pp-sidebar-width) !important;
232
- flex-basis: var(--pp-sidebar-width) !important;
233
- }
234
-
235
- section[data-testid="stSidebar"][aria-expanded="true"] {
236
- min-width: var(--pp-sidebar-width) !important;
237
- width: var(--pp-sidebar-width) !important;
238
- max-width: var(--pp-sidebar-width) !important;
239
- flex: 0 0 var(--pp-sidebar-width) !important;
240
- flex-basis: var(--pp-sidebar-width) !important;
241
- }
242
-
243
- [data-testid="stSidebar"] > div:first-child {
244
- margin: 12px 0 12px 12px;
245
- height: calc(100vh - 24px);
246
- border-radius: 30px;
247
- border: 1px solid rgba(255, 255, 255, 0.98);
248
- background: rgba(241, 226, 248, 0.96) !important;
249
- box-shadow: var(--pp-panel-shadow);
250
- backdrop-filter: blur(6px);
251
- overflow-y: auto;
252
- overflow-x: hidden;
253
- }
254
-
255
- [data-testid="stSidebarUserContent"] {
256
- padding-top: 0.25rem;
257
- }
258
-
259
- [data-testid="stSidebarNav"] ul li,
260
- [data-testid="stSidebarNav"] ul li a,
261
- [data-testid="stSidebarNav"] ul li button {
262
- margin: 0 !important;
263
- padding: 0 !important;
264
- }
265
-
266
- [data-testid="stSidebarNav"] ul li + li {
267
- margin-top: 0.44rem !important;
268
- }
269
-
270
- [data-testid="stSidebarNav"] ul li a,
271
- [data-testid="stSidebarNav"] ul li button {
272
- font-size: 1.02rem !important;
273
- font-family: "Inter", "Manrope", "Avenir Next", "Segoe UI", sans-serif !important;
274
- font-weight: 600 !important;
275
- color: #1f6b4a !important;
276
- border-radius: 12px !important;
277
- display: flex !important;
278
- align-items: center !important;
279
- justify-content: flex-start !important;
280
- height: 44px !important;
281
- min-height: 44px !important;
282
- max-height: 44px !important;
283
- line-height: 1 !important;
284
- padding: 0 0.78rem 0 3.1rem !important;
285
- border: 1px solid rgba(255, 255, 255, 0.98);
286
- background: rgba(255, 255, 255, 0.78) !important;
287
- box-shadow:
288
- 0 1px 0 rgba(255, 255, 255, 0.42) inset,
289
- 0 2px 7px rgba(80, 86, 131, 0.07);
290
- transition: all 140ms ease;
291
- box-sizing: border-box !important;
292
- overflow: hidden !important;
293
- }
294
-
295
- [data-testid="stSidebarNav"] ul li a > div,
296
- [data-testid="stSidebarNav"] ul li button > div {
297
- margin: 0 !important;
298
- padding: 0 !important;
299
- display: flex !important;
300
- align-items: center !important;
301
- min-height: 0 !important;
302
- line-height: 1 !important;
303
- }
304
-
305
- [data-testid="stSidebarNav"] ul li a span,
306
- [data-testid="stSidebarNav"] ul li button span {
307
- font-size: 0.95rem !important;
308
- font-family: "Inter", "Manrope", "Avenir Next", "Segoe UI", sans-serif !important;
309
- font-weight: 600 !important;
310
- color: #1f6b4a !important;
311
- line-height: 1.05 !important;
312
- white-space: nowrap !important;
313
- margin: 0 !important;
314
- padding: 0 !important;
315
- }
316
-
317
- [data-testid="stSidebarNav"] ul li a:hover,
318
- [data-testid="stSidebarNav"] ul li button:hover {
319
- transform: translateY(-1px);
320
- background: rgba(255, 255, 255, 0.90) !important;
321
- box-shadow:
322
- 0 1px 0 rgba(255, 255, 255, 0.44) inset,
323
- 0 4px 10px rgba(85, 94, 142, 0.10);
324
- }
325
-
326
- [data-testid="stSidebarNav"] ul li a[aria-current="page"],
327
- [data-testid="stSidebarNav"] ul li button[aria-current="page"] {
328
- border: 1px solid rgba(34, 163, 93, 0.34) !important;
329
- background: linear-gradient(100deg, #21a35e, #34c67a) !important;
330
- box-shadow:
331
- 0 1px 0 rgba(255, 255, 255, 0.22) inset,
332
- 0 8px 16px rgba(34, 163, 93, 0.28);
333
- }
334
-
335
- [data-testid="stSidebarNav"] ul li a[aria-current="page"],
336
- [data-testid="stSidebarNav"] ul li a[aria-current="page"] *,
337
- [data-testid="stSidebarNav"] ul li button[aria-current="page"],
338
- [data-testid="stSidebarNav"] ul li button[aria-current="page"] * {
339
- color: #ffffff !important;
340
- fill: #ffffff !important;
341
- font-weight: 700 !important;
342
- }
343
-
344
- [data-testid="stSidebarNav"] ul li a[aria-current="page"]::before {
345
- filter: brightness(0) invert(1);
346
- opacity: 0.96;
347
- }
348
-
349
- __ICON_CSS__
350
-
351
- .stTextInput > div > div > input,
352
- .stTextArea textarea,
353
- .stSelectbox [data-baseweb="select"] > div,
354
- .stMultiSelect [data-baseweb="select"] > div,
355
- .stNumberInput input {
356
- border-radius: 12px !important;
357
- border: 1px solid #d7deeb !important;
358
- background: rgba(255, 255, 255, 0.87) !important;
359
- box-shadow: none !important;
360
- color: #173b2b !important;
361
- }
362
-
363
- .stTextInput > div > div > input:focus,
364
- .stTextArea textarea:focus,
365
- .stSelectbox [data-baseweb="select"] > div:focus-within,
366
- .stMultiSelect [data-baseweb="select"] > div:focus-within,
367
- .stNumberInput input:focus {
368
- border-color: #21a35e !important;
369
- box-shadow: 0 0 0 3px rgba(34, 163, 93, 0.18) !important;
370
- }
371
-
372
- /* Force dropdown/expanded menus to green-white accents */
373
- [data-baseweb="popover"] [role="listbox"],
374
- [data-baseweb="popover"] [data-baseweb="menu"],
375
- div[role="listbox"] {
376
- background: rgba(248, 252, 249, 0.98) !important;
377
- border: 1px solid rgba(185, 214, 198, 0.95) !important;
378
- border-radius: 12px !important;
379
- box-shadow: 0 10px 24px rgba(44, 95, 67, 0.14) !important;
380
- }
381
-
382
- [data-baseweb="popover"] [role="option"],
383
- [data-baseweb="popover"] li,
384
- [data-baseweb="popover"] [data-baseweb="menu"] > div,
385
- div[role="listbox"] [role="option"] {
386
- background: transparent !important;
387
- color: #173b2b !important;
388
- }
389
-
390
- [data-baseweb="popover"] [role="option"]:hover,
391
- [data-baseweb="popover"] li:hover,
392
- [data-baseweb="popover"] [data-highlighted="true"],
393
- div[role="listbox"] [role="option"]:hover {
394
- background: rgba(34, 163, 93, 0.10) !important;
395
- }
396
-
397
- [data-baseweb="popover"] [role="option"][aria-selected="true"],
398
- [data-baseweb="popover"] li[aria-selected="true"],
399
- [data-baseweb="popover"] [aria-selected="true"],
400
- div[role="listbox"] [role="option"][aria-selected="true"] {
401
- background: rgba(34, 163, 93, 0.18) !important;
402
- color: #173b2b !important;
403
- }
404
-
405
- [data-baseweb="tag"] {
406
- background: #2f9d62 !important;
407
- border: 1px solid #288653 !important;
408
- color: #ffffff !important;
409
- }
410
-
411
- [data-baseweb="tag"] *,
412
- [data-baseweb="tag"] svg {
413
- color: #ffffff !important;
414
- fill: #ffffff !important;
415
- }
416
-
417
- /* Keep sliders/toggles green while page background stays blue */
418
- .stSlider [data-baseweb="slider"] > div > div > div:first-child,
419
- [data-baseweb="slider"] > div > div > div:first-child {
420
- background-color: #1f8f52 !important;
421
- }
422
-
423
- .stSlider [data-baseweb="slider"] > div > div > div:last-child,
424
- [data-baseweb="slider"] > div > div > div:last-child {
425
- background-color: rgba(34, 163, 93, 0.30) !important;
426
- }
427
-
428
- [data-baseweb="slider"] [style*="rgb(79, 70, 229)"],
429
- [data-baseweb="slider"] [style*="rgb(91, 80, 255)"],
430
- [data-baseweb="slider"] [style*="rgb(67, 56, 202)"],
431
- [data-baseweb="slider"] [style*="rgb("] {
432
- background-color: #1f8f52 !important;
433
- border-color: #1f8f52 !important;
434
- }
435
-
436
- [data-baseweb="slider"] [role="slider"] {
437
- background-color: #1f8f52 !important;
438
- border: 2px solid #ffffff !important;
439
- box-shadow: 0 0 0 1px rgba(34, 163, 93, 0.35), 0 2px 6px rgba(34, 163, 93, 0.28) !important;
440
- }
441
-
442
- [data-baseweb="checkbox"] [aria-checked="true"] {
443
- color: #1f8f52 !important;
444
- }
445
-
446
- [data-baseweb="checkbox"] [aria-checked="true"] > div {
447
- background-color: #1f8f52 !important;
448
- border-color: #1f8f52 !important;
449
- }
450
-
451
- input[type="checkbox"],
452
- input[type="radio"] {
453
- accent-color: #1f8f52 !important;
454
- }
455
-
456
- [data-baseweb="radio"] [aria-checked="true"] {
457
- color: #1f8f52 !important;
458
- }
459
-
460
- .stButton > button,
461
- .stDownloadButton > button,
462
- [data-testid="baseButton-secondary"] {
463
- border-radius: 999px !important;
464
- border: 1px solid #d7dff2 !important;
465
- font-weight: 500 !important;
466
- min-height: 2.65rem;
467
- padding: 0.3rem 1.08rem !important;
468
- background: rgba(255, 255, 255, 0.94) !important;
469
- transition: all 140ms ease;
470
- }
471
-
472
- .stButton > button[kind="primary"],
473
- [data-testid="baseButton-primary"] {
474
- background: linear-gradient(100deg, var(--pp-primary), var(--pp-primary-2)) !important;
475
- color: #fff !important;
476
- border: none !important;
477
- box-shadow: 0 10px 22px rgba(31, 157, 85, 0.34);
478
- }
479
-
480
- .stButton > button[kind="primary"] *,
481
- [data-testid="baseButton-primary"] * {
482
- color: #fff !important;
483
- fill: #fff !important;
484
- }
485
-
486
- [data-testid="stFormSubmitButton"] button,
487
- [data-testid="stFormSubmitButton"] button * {
488
- color: #fff !important;
489
- fill: #fff !important;
490
- }
491
-
492
- .stButton > button:hover,
493
- .stDownloadButton > button:hover {
494
- transform: translateY(-1px);
495
- box-shadow: 0 10px 20px rgba(44, 95, 67, 0.2);
496
- }
497
-
498
- div[data-testid="stVerticalBlockBorderWrapper"] {
499
- border-radius: 18px !important;
500
- border: 1px solid var(--pp-border) !important;
501
- background: var(--pp-surface) !important;
502
- box-shadow: var(--pp-shadow);
503
- }
504
-
505
- div[data-testid="stMetric"] {
506
- background: rgba(255, 255, 255, 0.72);
507
- border-radius: 14px;
508
- border: 1px solid rgba(255, 255, 255, 0.84);
509
- padding: 0.45rem 0.7rem;
510
- }
511
-
512
- div[data-testid="stMetric"] label {
513
- color: #4d705d !important;
514
- font-weight: 600 !important;
515
- letter-spacing: 0.01em;
516
- }
517
-
518
- div[data-testid="stMetricValue"] {
519
- color: #1f8f52 !important;
520
- font-weight: 800 !important;
521
- }
522
-
523
- [data-testid="stDataFrame"],
524
- [data-testid="stTable"] {
525
- background: rgba(255, 255, 255, 0.78);
526
- border-radius: 14px;
527
- border: 1px solid rgba(255, 255, 255, 0.88);
528
- overflow: hidden;
529
- }
530
-
531
- [data-testid="stDataFrame"] [role="grid"],
532
- [data-testid="stDataFrame"] [role="rowgroup"],
533
- [data-testid="stDataFrame"] [role="row"],
534
- [data-testid="stDataFrame"] [role="gridcell"],
535
- [data-testid="stDataFrame"] [role="columnheader"] {
536
- background-color: rgba(247, 252, 248, 0.90) !important;
537
- border-color: rgba(188, 213, 196, 0.55) !important;
538
- }
539
-
540
- [data-testid="stTable"] table,
541
- [data-testid="stTable"] th,
542
- [data-testid="stTable"] td {
543
- background-color: rgba(248, 252, 249, 0.94) !important;
544
- border-color: rgba(188, 213, 196, 0.62) !important;
545
- }
546
-
547
- .pp-page-header {
548
- margin: 0.1rem 0 1.0rem 0;
549
- }
550
-
551
- .pp-page-title {
552
- margin: 0.2rem 0 0.45rem 0;
553
- font-size: clamp(1.95rem, 2.65vw, 3.0rem);
554
- line-height: 1.1;
555
- font-weight: 800;
556
- color: #123726;
557
- }
558
-
559
- .pp-page-subtitle {
560
- margin: 0;
561
- max-width: 880px;
562
- color: #4a6e5b;
563
- font-size: 1.02rem;
564
- line-height: 1.62;
565
- }
566
-
567
- .pp-badge {
568
- display: inline-flex;
569
- align-items: center;
570
- gap: 0.38rem;
571
- padding: 0.3rem 0.72rem;
572
- border-radius: 999px;
573
- border: 1px solid rgba(255, 255, 255, 0.92);
574
- background: rgba(255, 255, 255, 0.72);
575
- color: #3f6856;
576
- font-size: 0.76rem;
577
- font-weight: 700;
578
- text-transform: uppercase;
579
- letter-spacing: 0.04em;
580
- }
581
-
582
- .pp-hero {
583
- border-radius: 22px;
584
- border: 1px solid rgba(255, 255, 255, 0.84);
585
- background:
586
- linear-gradient(95deg, rgba(255, 255, 255, 0.74), rgba(255, 255, 255, 0.60)),
587
- radial-gradient(800px 300px at 100% 0%, rgba(34, 163, 93, 0.16), transparent 72%);
588
- box-shadow: 0 16px 32px rgba(56, 70, 121, 0.11);
589
- padding: 1.45rem 1.5rem;
590
- margin: 0.3rem 0 1.2rem;
591
- }
592
-
593
- .pp-hero-grid {
594
- display: grid;
595
- grid-template-columns: minmax(0, 4fr) minmax(130px, 1fr);
596
- gap: 1.4rem;
597
- align-items: center;
598
- }
599
-
600
- .pp-hero-title {
601
- margin: 0.58rem 0 0.5rem;
602
- color: #123726;
603
- font-size: clamp(1.55rem, 2.12vw, 2.35rem);
604
- font-weight: 800;
605
- letter-spacing: -0.014em;
606
- line-height: 1.2;
607
- }
608
-
609
- .pp-hero-copy {
610
- margin: 0;
611
- color: #4c6f5d;
612
- max-width: 780px;
613
- line-height: 1.66;
614
- }
615
-
616
- .pp-hero-logo {
617
- display: flex;
618
- justify-content: center;
619
- }
620
-
621
- .pp-hero-logo img {
622
- width: 112px;
623
- height: 112px;
624
- border-radius: 18px;
625
- border: 1px solid rgba(255, 255, 255, 0.88);
626
- box-shadow: 0 12px 26px rgba(30, 49, 103, 0.2);
627
- object-fit: contain;
628
- background: rgba(255, 255, 255, 0.96);
629
- padding: 8px;
630
- }
631
-
632
- .pp-stat-card {
633
- border-radius: 16px;
634
- border: 1px solid rgba(255, 255, 255, 0.9);
635
- background: rgba(255, 255, 255, 0.78);
636
- padding: 0.85rem 0.92rem;
637
- box-shadow: 0 8px 22px rgba(56, 78, 145, 0.12);
638
- }
639
-
640
- .pp-stat-value {
641
- margin: 0;
642
- color: #2f9d62;
643
- font-size: clamp(1.2rem, 1.8vw, 1.95rem);
644
- font-weight: 800;
645
- letter-spacing: -0.018em;
646
- }
647
-
648
- .pp-stat-label {
649
- margin: 0.2rem 0 0;
650
- color: #678a76;
651
- font-size: 0.8rem;
652
- text-transform: uppercase;
653
- letter-spacing: 0.045em;
654
- font-weight: 700;
655
- }
656
-
657
- .pp-kpi-strip {
658
- border-radius: 28px;
659
- border: 1px solid rgba(255, 255, 255, 0.9);
660
- background: rgba(247, 251, 248, 0.86);
661
- box-shadow: 0 12px 26px rgba(56, 78, 145, 0.08);
662
- display: grid;
663
- grid-template-columns: repeat(3, minmax(0, 1fr));
664
- gap: 0;
665
- padding: 1.45rem 1.15rem 1.25rem;
666
- margin: 0.62rem 0 0.9rem 0;
667
- }
668
-
669
- .pp-kpi-item {
670
- padding: 0.1rem 0.45rem 0.2rem;
671
- text-align: center;
672
- }
673
-
674
- .pp-kpi-strip .pp-kpi-value {
675
- margin: 0;
676
- color: #2f9d62 !important;
677
- font-size: 3.5rem !important;
678
- font-weight: 800 !important;
679
- line-height: 0.98 !important;
680
- letter-spacing: -0.03em !important;
681
- }
682
-
683
- .pp-kpi-strip .pp-kpi-label {
684
- margin: 0.75rem 0 0;
685
- color: #6f8d7c !important;
686
- font-size: 0.99rem !important;
687
- font-weight: 700 !important;
688
- letter-spacing: 0.12em !important;
689
- text-transform: uppercase !important;
690
- }
691
-
692
- .pp-step-card {
693
- border-radius: 16px;
694
- border: 1px solid rgba(255, 255, 255, 0.9);
695
- background: rgba(255, 255, 255, 0.77);
696
- padding: 0.8rem 0.92rem;
697
- min-height: 120px;
698
- }
699
-
700
- .pp-step-title {
701
- margin: 0 0 0.26rem;
702
- color: #1a4b36;
703
- font-weight: 700;
704
- font-size: 0.98rem;
705
- }
706
-
707
- .pp-step-copy {
708
- margin: 0;
709
- color: #557562;
710
- font-size: 0.89rem;
711
- line-height: 1.45;
712
- }
713
-
714
- .pp-module-card {
715
- border-radius: 16px;
716
- border: 1px solid rgba(255, 255, 255, 0.9);
717
- background: rgba(255, 255, 255, 0.77);
718
- box-shadow: 0 8px 22px rgba(56, 78, 145, 0.10);
719
- padding: 1.02rem 1.05rem;
720
- margin-bottom: 0.74rem;
721
- min-height: 136px;
722
- }
723
-
724
- .pp-module-title {
725
- margin: 0 0 0.32rem;
726
- color: #1a4b36;
727
- font-size: 1.02rem;
728
- font-weight: 800;
729
- line-height: 1.32;
730
- }
731
-
732
- .pp-module-copy {
733
- margin: 0;
734
- color: #557461;
735
- font-size: 0.9rem;
736
- line-height: 1.5;
737
- }
738
-
739
- .pp-api-card {
740
- border-radius: 18px;
741
- border: 1px solid rgba(255, 255, 255, 0.9);
742
- background:
743
- linear-gradient(120deg, rgba(255, 255, 255, 0.84), rgba(246, 251, 248, 0.76));
744
- box-shadow: 0 10px 24px rgba(56, 78, 145, 0.10);
745
- padding: 1rem 1.05rem 0.95rem;
746
- margin: 0 0 0.8rem 0;
747
- }
748
-
749
- .pp-api-kicker {
750
- display: inline-flex;
751
- align-items: center;
752
- padding: 0.28rem 0.72rem;
753
- border-radius: 999px;
754
- border: 1px solid rgba(188, 213, 196, 0.95);
755
- background: rgba(255, 255, 255, 0.84);
756
- color: #4a6e5b;
757
- text-transform: uppercase;
758
- letter-spacing: 0.08em;
759
- font-size: 0.72rem;
760
- font-weight: 700;
761
- }
762
-
763
- .pp-api-title {
764
- margin: 0.72rem 0 0.26rem;
765
- color: #123726;
766
- font-size: clamp(1.2rem, 1.6vw, 1.55rem);
767
- font-weight: 800;
768
- letter-spacing: -0.02em;
769
- }
770
-
771
- .pp-api-copy {
772
- margin: 0;
773
- color: #557461;
774
- font-size: 0.94rem;
775
- line-height: 1.58;
776
- max-width: 900px;
777
- }
778
-
779
- .pp-api-meta {
780
- display: flex;
781
- flex-wrap: wrap;
782
- gap: 0.5rem;
783
- margin: 0.25rem 0 0.15rem;
784
- }
785
-
786
- .pp-api-pill {
787
- display: inline-flex;
788
- align-items: center;
789
- gap: 0.32rem;
790
- padding: 0.34rem 0.72rem;
791
- border-radius: 999px;
792
- border: 1px solid rgba(188, 213, 196, 0.95);
793
- background: rgba(248, 252, 249, 0.95);
794
- color: #3f6856;
795
- font-size: 0.8rem;
796
- font-weight: 600;
797
- line-height: 1.2;
798
- }
799
-
800
- .pp-api-pill strong {
801
- color: #1f8f52;
802
- font-weight: 800;
803
- }
804
-
805
- .pp-api-inline-head {
806
- margin: 0.05rem 0 0.7rem 0;
807
- }
808
-
809
- .pp-main-card {
810
- border-radius: 16px;
811
- border: 1px solid rgba(255, 255, 255, 0.9);
812
- background: rgba(255, 255, 255, 0.77);
813
- box-shadow: 0 8px 22px rgba(56, 78, 145, 0.10);
814
- padding: 1.1rem 1.08rem;
815
- margin: 3.25rem 0 0.9rem 0;
816
- }
817
-
818
- .pp-main-grid {
819
- display: grid;
820
- grid-template-columns: minmax(0, 4fr) minmax(120px, 1fr);
821
- gap: 1.1rem;
822
- align-items: center;
823
- }
824
-
825
- .pp-main-card .pp-main-title {
826
- margin: 0;
827
- color: #000000 !important;
828
- font-size: clamp(2.8rem, 4.2vw, 3.5rem) !important;
829
- font-weight: 800 !important;
830
- line-height: 0.98 !important;
831
- letter-spacing: -0.03em !important;
832
- }
833
-
834
- .pp-main-copy {
835
- margin: 0.5rem 0 0;
836
- color: #557461;
837
- font-size: 1.0rem;
838
- line-height: 1.56;
839
- max-width: 900px;
840
- }
841
-
842
- .pp-main-logo {
843
- display: flex;
844
- justify-content: center;
845
- }
846
-
847
- .pp-main-logo img {
848
- width: 118px;
849
- height: 118px;
850
- border-radius: 16px;
851
- border: 1px solid rgba(255, 255, 255, 0.62);
852
- box-shadow: 0 10px 22px rgba(32, 92, 63, 0.20);
853
- object-fit: contain;
854
- background: linear-gradient(145deg, #2f8059, #1f9d55);
855
- padding: 8px;
856
- filter: drop-shadow(0 0 5px rgba(255, 255, 255, 0.25));
857
- }
858
-
859
- .pp-lab-card {
860
- margin: 0.35rem 0 0.5rem 0;
861
- border-radius: 20px;
862
- border: 1px solid rgba(255, 255, 255, 0.92);
863
- background:
864
- linear-gradient(125deg, rgba(255, 255, 255, 0.82), rgba(242, 250, 245, 0.75));
865
- box-shadow: 0 12px 26px rgba(44, 95, 67, 0.12);
866
- padding: 1.3rem 1.35rem 1.25rem;
867
- }
868
-
869
- .pp-lab-kicker {
870
- display: inline-flex;
871
- align-items: center;
872
- padding: 0.28rem 0.72rem;
873
- border-radius: 999px;
874
- background: rgba(255, 255, 255, 0.82);
875
- border: 1px solid rgba(188, 213, 196, 0.95);
876
- color: #4a6e5b;
877
- text-transform: uppercase;
878
- letter-spacing: 0.08em;
879
- font-size: 0.74rem;
880
- font-weight: 700;
881
- }
882
-
883
- .pp-lab-title {
884
- margin: 0.72rem 0 0.26rem;
885
- font-size: clamp(1.52rem, 2.1vw, 2.05rem);
886
- font-weight: 800;
887
- letter-spacing: -0.02em;
888
- color: #123726;
889
- }
890
-
891
- .pp-lab-subtitle {
892
- margin: 0;
893
- color: #678a76;
894
- font-size: 1.0rem;
895
- font-weight: 600;
896
- line-height: 1.45;
897
- }
898
-
899
- .pp-lab-copy {
900
- margin: 1rem 0 1.15rem;
901
- color: #4c6f5d;
902
- font-size: 1.02rem;
903
- line-height: 1.68;
904
- max-width: 1080px;
905
- }
906
-
907
- .pp-lab-link {
908
- display: inline-flex;
909
- align-items: center;
910
- justify-content: center;
911
- text-decoration: none;
912
- border-radius: 999px;
913
- border: 1px solid rgba(34, 163, 93, 0.36);
914
- background: linear-gradient(100deg, #21a35e, #34c67a);
915
- color: #ffffff !important;
916
- font-size: 0.95rem;
917
- font-weight: 700;
918
- letter-spacing: 0.01em;
919
- padding: 0.56rem 1rem;
920
- box-shadow: 0 8px 16px rgba(34, 163, 93, 0.25);
921
- transition: transform 140ms ease, box-shadow 140ms ease, filter 140ms ease;
922
- }
923
-
924
- .pp-lab-link:hover {
925
- color: #ffffff !important;
926
- transform: translateY(-1px);
927
- box-shadow: 0 12px 20px rgba(34, 163, 93, 0.3);
928
- filter: saturate(1.02);
929
- }
930
-
931
- @media (max-width: 980px) {
932
- [data-testid="stAppViewContainer"] {
933
- height: auto;
934
- overflow: visible !important;
935
- }
936
-
937
- section[data-testid="stMain"] {
938
- margin: 8px;
939
- height: auto !important;
940
- border-radius: 16px;
941
- }
942
-
943
- [data-testid="stSidebar"] > div:first-child {
944
- margin: 8px;
945
- height: auto;
946
- border-radius: 16px;
947
- }
948
-
949
- .pp-main-grid {
950
- grid-template-columns: 1fr;
951
- gap: 0.75rem;
952
- }
953
-
954
- .pp-kpi-strip {
955
- grid-template-columns: 1fr;
956
- gap: 0.42rem;
957
- padding: 0.95rem 1rem 0.85rem;
958
- }
959
-
960
- .pp-kpi-item {
961
- padding: 0.18rem 0.45rem;
962
- }
963
-
964
- .pp-kpi-strip .pp-kpi-value {
965
- font-size: 3.05rem !important;
966
- }
967
-
968
- .pp-kpi-strip .pp-kpi-label {
969
- font-size: 0.7rem !important;
970
- }
971
-
972
- .pp-main-logo {
973
- justify-content: flex-start;
974
- }
975
-
976
- .pp-main-card {
977
- margin-top: 0.7rem;
978
- }
979
-
980
- .pp-hero-grid {
981
- grid-template-columns: 1fr;
982
- gap: 0.85rem;
983
- }
984
- .pp-hero-logo {
985
- justify-content: flex-start;
986
- }
987
-
988
- .pp-lab-card {
989
- padding: 1.0rem;
990
- }
991
-
992
- .pp-lab-title {
993
- margin-top: 0.58rem;
994
- }
995
-
996
- .pp-lab-copy {
997
- margin-top: 0.78rem;
998
- font-size: 0.96rem;
999
- }
1000
- }
1001
- </style>
1002
- """
1003
- st.markdown(css.replace("__ICON_CSS__", icon_css), unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
src/utils.py DELETED
@@ -1,338 +0,0 @@
1
- # utils.py
2
- from __future__ import annotations
3
-
4
- from typing import Dict, List, Optional, Sequence, Literal
5
-
6
- import math
7
- import numpy as np
8
- import torch
9
- import torch.nn as nn
10
-
11
- # Re-exported conveniences from data_builder
12
- from src.data_builder import TargetScaler, grouped_split_by_smiles # noqa: F401
13
-
14
-
15
- # ---------------------------------------------------------
16
- # Seeding and device helpers
17
- # ---------------------------------------------------------
18
-
19
def seed_everything(seed: int) -> None:
    """Deterministically seed Python's `random`, NumPy, and PyTorch (CPU + all CUDA devices)."""
    import random

    # Apply the same seed to every RNG source, in a fixed order.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
26
-
27
-
28
def to_device(batch, device: torch.device):
    """Move a PyG Batch (anything exposing `.to`) or a plain dict of tensors onto `device`.

    Non-tensor dict values are passed through untouched; unknown types are returned as-is.
    """
    if hasattr(batch, "to"):
        return batch.to(device)
    if isinstance(batch, dict):
        moved = {}
        for key, val in batch.items():
            moved[key] = val.to(device) if torch.is_tensor(val) else val
        return moved
    return batch
35
-
36
-
37
- # ---------------------------------------------------------
38
- # Masked metrics (canonical)
39
- # ---------------------------------------------------------
40
-
41
- def _safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
42
- den = torch.clamp(den, min=1e-12)
43
- return num / den
44
-
45
-
46
def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
               reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor:
    """
    Masked mean-squared error.

    pred/target: [B, T]; mask: [B, T] bool
    Computed in fp32; "mean" divides by the number of unmasked elements.
    """
    mask = mask.bool()
    sq_err = (pred.float() - target.float()) ** 2
    total = (sq_err * mask).sum()
    if reduction == "sum":
        return total
    return _safe_div(total, mask.sum().float())
57
-
58
-
59
def masked_mae(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
               reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor:
    """
    Masked mean-absolute error.

    pred/target: [B, T]; mask: [B, T] bool
    "mean" divides by the number of unmasked elements.
    """
    # Cast for consistency with masked_mse: compute in fp32 and coerce the mask
    # to bool so integer/float masks behave identically across the two metrics.
    pred, target = pred.float(), target.float()
    mask = mask.bool()
    ae = (pred - target).abs() * mask
    if reduction == "sum":
        return ae.sum()
    return _safe_div(ae.sum(), mask.sum().float())
65
-
66
-
67
def masked_rmse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Root of the masked MSE (mean reduction)."""
    mse = masked_mse(pred, target, mask, reduction="mean")
    return mse.sqrt()
69
-
70
-
71
def masked_r2(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """
    Masked coefficient of determination, computed jointly over every unmasked element.
    """
    pred = pred.float()
    target = target.float()
    mask = mask.bool()
    n = mask.sum().float().clamp(min=1.0)
    # Mean of the labelled targets only.
    y_bar = _safe_div((target * mask).sum(), n)
    ss_tot = (((target - y_bar) ** 2) * mask).sum()
    ss_res = (((target - pred) ** 2) * mask).sum()
    return 1.0 - _safe_div(ss_res, ss_tot.clamp(min=1e-12))
82
-
83
-
84
def masked_metrics_overall(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> Dict[str, float]:
    """Scalar rmse/mae/r2 over all masked elements, returned as plain Python floats."""
    metrics: Dict[str, float] = {}
    for key, fn in (("rmse", masked_rmse), ("mae", masked_mae), ("r2", masked_r2)):
        metrics[key] = float(fn(pred, target, mask).detach().cpu())
    return metrics
90
-
91
-
92
def masked_metrics_per_task(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,
    task_names: Sequence[str],
) -> Dict[str, Dict[str, float]]:
    """
    rmse/mae/r2 for each task column separately, keyed by task name.

    Tasks with no labelled rows report NaN for all three metrics.
    """
    results: Dict[str, Dict[str, float]] = {}
    for col, name in enumerate(task_names):
        col_mask = mask[:, col]
        if not col_mask.any():
            results[name] = {"rmse": float("nan"), "mae": float("nan"), "r2": float("nan")}
            continue
        # Keep 2-D [B, 1] shapes so the masked metric helpers apply unchanged.
        p = pred[:, col:col + 1]
        y = target[:, col:col + 1]
        m = col_mask.unsqueeze(1)
        results[name] = {
            "rmse": float(masked_rmse(p, y, m).detach().cpu()),
            "mae": float(masked_mae(p, y, m).detach().cpu()),
            "r2": float(masked_r2(p, y, m).detach().cpu()),
        }
    return results
112
-
113
-
114
def masked_metrics_by_fidelity(
    pred: torch.Tensor,
    target: torch.Tensor,
    mask: torch.Tensor,
    fid_idx: torch.Tensor,
    fid_names: Sequence[str],
    task_names: Sequence[str],  # kept for API parity; not used in overall-by-fid
) -> Dict[str, Dict[str, float]]:
    """
    Overall (task-aggregated) rmse/mae/r2 for each fidelity level.

    Fidelities with no selected rows report NaN throughout.
    """
    fid_idx = fid_idx.view(-1).long()
    results: Dict[str, Dict[str, float]] = {}
    for level, fid_name in enumerate(fid_names):
        rows = fid_idx == level
        if rows.any():
            results[fid_name] = masked_metrics_overall(pred[rows], target[rows], mask[rows])
        else:
            results[fid_name] = {"rmse": float("nan"), "mae": float("nan"), "r2": float("nan")}
    return results
137
-
138
-
139
- # ---------------------------------------------------------
140
- # Multitask, multi-fidelity loss (canonical)
141
- # ---------------------------------------------------------
142
-
143
def gaussian_nll(mu: torch.Tensor, logvar: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """
    Element-wise Gaussian negative log-likelihood, no reduction.

    Shapes: mu, logvar, target -> [B, T] (or broadcastable).
    """
    # Coerce logvar onto mu's device/dtype and clamp for numerical stability.
    lv = torch.as_tensor(logvar, device=mu.device, dtype=mu.dtype).clamp(min=-20.0, max=20.0)
    var = lv.exp()
    return 0.5 * ((target - mu).pow(2) / var + lv + math.log(2.0 * math.pi))
154
-
155
-
156
def loss_multitask_fidelity(
    *,
    pred: torch.Tensor,   # [B, T] (or means if heteroscedastic)
    target: torch.Tensor,   # [B, T]
    mask: torch.Tensor,   # [B, T] bool
    fid_idx: torch.Tensor,   # [B] long (per-row fidelity index)
    fid_loss_w: Sequence[float] | torch.Tensor | None,   # [F] weights per fidelity
    task_weights: Optional[Sequence[float] | torch.Tensor] = None,   # [T]
    hetero_logvar: Optional[torch.Tensor] = None,   # [B, T] if heteroscedastic head
    reduction: Literal["mean", "sum"] = "mean",
    task_log_sigma2: Optional[torch.Tensor] = None,   # [T] learned homoscedastic uncertainty
    balanced: bool = True,
) -> torch.Tensor:
    """
    Multi-task, multi-fidelity loss with *balanced per-task reduction* by default.

    - If `hetero_logvar` is given: uses Gaussian NLL per element.
    - Applies per-fidelity weights via `fid_idx`.
    - Balanced reduction: compute mean loss per task first, then average across tasks
      (optionally weight by `task_weights` or learned uncertainty `task_log_sigma2`).
    - If `balanced=False`, uses legacy global reduction.

    Returns a scalar tensor (sum or mean depending on `reduction`).
    """
    B, T = pred.shape
    pred = pred.float()
    target = target.float()
    mask = mask.bool()
    fid_idx = fid_idx.view(-1).long()

    # Task weights (optional)
    if task_weights is None:
        tw = pred.new_ones(T)  # [T] — uniform when no explicit weights given
    else:
        tw = torch.as_tensor(task_weights, dtype=pred.dtype, device=pred.device)
        assert tw.numel() == T, f"task_weights len {tw.numel()} != T {T}"
        s = tw.sum().clamp(min=1e-12)
        tw = tw * (T / s)  # normalize to sum=T for stable scale

    # Fidelity weights — indexed per row via fid_idx, broadcast across tasks.
    if fid_loss_w is None:
        # NOTE(review): fid_idx.max() fails on an empty batch — presumably B >= 1 always; confirm at call sites.
        fw = pred.new_ones(int(fid_idx.max().item()) + 1)
    else:
        fw = torch.as_tensor(fid_loss_w, dtype=pred.dtype, device=pred.device)
    w_fid = fw[fid_idx].unsqueeze(1).expand(-1, T)  # [B, T]

    # Elementwise loss: Gaussian NLL when a heteroscedastic head provides logvars,
    # plain squared error otherwise.
    if hetero_logvar is not None:
        elem_loss = gaussian_nll(pred, hetero_logvar.float(), target)  # [B, T]
    else:
        elem_loss = (pred - target) ** 2  # [B, T]

    if not balanced:
        # Legacy global reduction (label-count biased): tasks with more labels
        # dominate, because the division is by the total weighted label count.
        w_task = tw.view(1, T).expand(B, -1)
        weighted = elem_loss * mask * w_task * w_fid
        if reduction == "sum":
            return weighted.sum()
        denom = (mask * w_task * w_fid).sum().float().clamp(min=1e-12)
        return weighted.sum() / denom

    # -------- Balanced per-task reduction --------
    # First compute a per-task average (exclude tw here; it is applied afterwards
    # so the average itself is unbiased by manual task weighting).
    num = (elem_loss * mask * w_fid).sum(dim=0)  # [T]
    den = (mask * w_fid).sum(dim=0).float().clamp(min=1e-12)  # [T]
    per_task_loss = num / den  # [T]

    # Optional manual task weights AFTER per-task averaging
    if task_weights is not None:
        per_task_loss = per_task_loss * tw

    # Optional homoscedastic task-uncertainty weighting (Kendall & Gal):
    # loss_t / (2*sigma_t^2) + 0.5*log(sigma_t^2), with sigma^2 learned per task.
    if task_log_sigma2 is not None:
        assert task_log_sigma2.numel() == T, f"task_log_sigma2 must be [T], got {task_log_sigma2.shape}"
        sigma2 = torch.exp(task_log_sigma2)  # [T]
        per_task_loss = per_task_loss / (2.0 * sigma2) + 0.5 * torch.log(sigma2)

    if reduction == "sum":
        return per_task_loss.sum()
    return per_task_loss.mean()
234
-
235
-
236
- # ---------------------------------------------------------
237
- # Curriculum scheduler for EXP fidelity
238
- # ---------------------------------------------------------
239
-
240
def exp_weight_at_epoch(
    epoch: int,
    total_epochs: int,
    schedule: Literal["none", "linear", "cosine"] = "none",
    start: float = 0.6,
    end: float = 1.0,
) -> float:
    """
    EXP-fidelity loss weight at `epoch` under the chosen ramp schedule.

    "none" always returns `end`; "linear"/"cosine" ramp from `start` to `end`
    over `total_epochs`.
    """
    if schedule == "none":
        return float(end)
    clamped = min(max(epoch, 0), total_epochs)
    if total_epochs <= 0:
        return float(end)
    frac = clamped / float(total_epochs)
    if schedule == "linear":
        ramp = frac
    elif schedule == "cosine":
        ramp = 0.5 - 0.5 * math.cos(math.pi * frac)  # smooth 0 -> 1
    else:
        raise ValueError(f"Unknown schedule: {schedule}")
    return float(start + (end - start) * ramp)
262
-
263
-
264
def make_fid_loss_weights(
    fids: Sequence[str],
    base_weights: Optional[Sequence[float]] = None,
    exp_weight: Optional[float] = None,
) -> List[float]:
    """
    Build a per-fidelity weight vector aligned with `fids` order.

    `base_weights` (if given) must match len(fids) and seeds the vector;
    otherwise all weights default to 1.0. `exp_weight` (if given) overrides
    the entry for the 'exp' fidelity (case-insensitive match).
    """
    lowered = [name.lower() for name in fids]
    n_fids = len(lowered)
    if base_weights is None:
        weights = [1.0] * n_fids
    else:
        assert len(base_weights) == n_fids, f"base_weights len {len(base_weights)} != {n_fids}"
        weights = [float(w) for w in base_weights]
    if exp_weight is not None and "exp" in lowered:
        weights[lowered.index("exp")] = float(exp_weight)
    return weights
285
-
286
-
287
- # ---------------------------------------------------------
288
- # Inference utilities
289
- # ---------------------------------------------------------
290
-
291
def apply_inverse_transform(pred: torch.Tensor, scaler):
    """
    Undo target scaling on `pred`, on whatever device `pred` lives on.

    Migrates the scaler's tensors (mean/std, and eps when present and non-None)
    onto pred's device first so CPU/GPU and legacy scalers all work.
    """
    device = pred.device

    for attr in ("mean", "std"):
        if hasattr(scaler, attr):
            value = getattr(scaler, attr)
            if value.device != device:
                setattr(scaler, attr, value.to(device))
    if hasattr(scaler, "eps"):
        eps = scaler.eps
        if eps is not None and eps.device != device:
            scaler.eps = eps.to(device)

    return scaler.inverse(pred)
307
-
308
-
309
-
310
def ensure_2d(x: torch.Tensor) -> torch.Tensor:
    """Promote a 1-D tensor to [B, 1]; pass tensors of other ranks through unchanged."""
    return x.unsqueeze(1) if x.dim() == 1 else x
315
-
316
-
317
- # ---------------------------------------------------------
318
- # Simple test harness (optional)
319
- # ---------------------------------------------------------
320
-
321
if __name__ == "__main__":
    # Minimal smoke test: exercise the loss (with/without task weights)
    # and the overall masked metrics on a small random batch.
    torch.manual_seed(0)
    n_rows, n_tasks = 5, 3
    predictions = torch.randn(n_rows, n_tasks)
    targets = torch.randn(n_rows, n_tasks)
    label_mask = torch.rand(n_rows, n_tasks) > 0.3
    fidelity_idx = torch.randint(0, 4, (n_rows,))
    fidelity_weights = [1.0, 0.8, 0.6, 0.5]
    per_task_weights = [1.0, 2.0, 1.0]

    loss_weighted = loss_multitask_fidelity(
        pred=predictions, target=targets, mask=label_mask,
        fid_idx=fidelity_idx, fid_loss_w=fidelity_weights, task_weights=per_task_weights,
    )
    loss_plain = loss_multitask_fidelity(
        pred=predictions, target=targets, mask=label_mask,
        fid_idx=fidelity_idx, fid_loss_w=fidelity_weights, task_weights=None,
    )
    print("Loss with task weights:", float(loss_weighted))
    print("Loss without task weights:", float(loss_plain))

    overall = masked_metrics_overall(predictions, targets, label_mask)
    print("Overall metrics:", overall)