Delete src
Browse files- src/conv.py +0 -258
- src/data_builder.py +0 -818
- src/discover_llm.py +0 -829
- src/discovery.py +0 -767
- src/fpscores.pkl.gz +0 -3
- src/lookup.py +0 -222
- src/model.py +0 -312
- src/predictor.py +0 -193
- src/predictor_multitask.py +0 -209
- src/predictor_router.py +0 -45
- src/rnn_smiles/__init__.py +0 -22
- src/rnn_smiles/generator.py +0 -175
- src/rnn_smiles/rnn.py +0 -89
- src/rnn_smiles/utils.py +0 -15
- src/rnn_smiles/vocabulary.py +0 -69
- src/sascorer.py +0 -192
- src/streamlit_app.py +0 -40
- src/ui_style.py +0 -1003
- src/utils.py +0 -338
src/conv.py
DELETED
|
@@ -1,258 +0,0 @@
|
|
| 1 |
-
# conv.py
|
| 2 |
-
# Clean, dependency-light graph encoder blocks for molecular GNNs.
|
| 3 |
-
# - Single source of truth for convolution choices: "gine", "gin", "gcn"
|
| 4 |
-
# - Edge attributes are supported for "gine" (recommended for chemistry)
|
| 5 |
-
# - No duplication with PyG built-ins; everything wraps torch_geometric.nn
|
| 6 |
-
# - Consistent encoder API: GNNEncoder(...).forward(x, edge_index, edge_attr, batch) -> graph embedding [B, emb_dim]
|
| 7 |
-
|
| 8 |
-
from __future__ import annotations
|
| 9 |
-
from typing import Literal, Optional
|
| 10 |
-
|
| 11 |
-
import torch
|
| 12 |
-
import torch.nn as nn
|
| 13 |
-
import torch.nn.functional as F
|
| 14 |
-
from torch_geometric.nn import (
|
| 15 |
-
GINEConv,
|
| 16 |
-
GINConv,
|
| 17 |
-
GCNConv,
|
| 18 |
-
global_mean_pool,
|
| 19 |
-
global_add_pool,
|
| 20 |
-
global_max_pool,
|
| 21 |
-
)
|
| 22 |
-
|
| 23 |
-
|
| 24 |
-
def get_activation(name: str) -> nn.Module:
    """Return a fresh activation module for the given name.

    Supported (case-insensitive): "relu", "gelu", "silu", and
    "leaky_relu"/"lrelu" (negative slope 0.1).

    Raises:
        ValueError: if the name is not recognized.
    """
    factories = {
        "relu": nn.ReLU,
        "gelu": nn.GELU,
        "silu": nn.SiLU,
        "leaky_relu": lambda: nn.LeakyReLU(0.1),
        "lrelu": lambda: nn.LeakyReLU(0.1),
    }
    try:
        return factories[name.lower()]()
    except KeyError:
        raise ValueError(f"Unknown activation: {name}") from None
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
class MLP(nn.Module):
    """Small MLP used inside GNN layers and projections.

    Each hidden Linear is followed by an activation and (if dropout > 0)
    a Dropout; the final Linear has neither.
    """

    def __init__(
        self,
        in_dim: int,
        hidden_dim: int,
        out_dim: int,
        num_layers: int = 2,
        act: str = "relu",
        dropout: float = 0.0,
        bias: bool = True,
    ):
        super().__init__()
        assert num_layers >= 1
        widths = [in_dim, *([hidden_dim] * (num_layers - 1)), out_dim]
        modules: list[nn.Module] = []
        last_linear = len(widths) - 2
        for k, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            modules.append(nn.Linear(d_in, d_out, bias=bias))
            # no activation/dropout after the final Linear
            if k < last_linear:
                modules.append(get_activation(act))
                if dropout > 0:
                    modules.append(nn.Dropout(dropout))
        self.net = nn.Sequential(*modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to an [N, in_dim] tensor; returns [N, out_dim]."""
        return self.net(x)
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
class NodeProjector(nn.Module):
    """Maps raw node features to the model embedding size.

    When the input width already equals ``emb_dim`` the projection is a
    no-op (Identity); otherwise it is Linear followed by an activation.
    """

    def __init__(self, in_dim_node: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        same_width = in_dim_node == emb_dim
        self.proj = (
            nn.Identity()
            if same_width
            else nn.Sequential(nn.Linear(in_dim_node, emb_dim), get_activation(act))
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.proj(x)
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
class EdgeProjector(nn.Module):
    """Maps raw edge attributes to the model embedding size (for GINE).

    Raises:
        ValueError: if ``in_dim_edge`` is not positive — edge attributes are
            mandatory for this projector.
    """

    def __init__(self, in_dim_edge: int, emb_dim: int, act: str = "relu"):
        super().__init__()
        if in_dim_edge <= 0:
            raise ValueError("in_dim_edge must be > 0 when using edge attributes")
        self.proj = nn.Sequential(
            nn.Linear(in_dim_edge, emb_dim),
            get_activation(act),
        )

    def forward(self, e: torch.Tensor) -> torch.Tensor:
        return self.proj(e)
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
class GNNEncoder(nn.Module):
    """
    Backbone GNN with a selectable convolution type.

    gnn_type:
        - "gine": chemistry-ready, uses edge_attr (recommended)
        - "gin" : ignores edge_attr, strong node MPNN
        - "gcn" : ignores edge_attr, fast spectral conv
    norm: "batch" | "layer" | "none"
    readout: "mean" | "sum" | "max"
    """

    def __init__(
        self,
        in_dim_node: int,
        emb_dim: int,
        num_layers: int = 5,
        gnn_type: Literal["gine", "gin", "gcn"] = "gine",
        in_dim_edge: int = 0,
        act: str = "relu",
        dropout: float = 0.0,
        residual: bool = True,
        norm: Literal["batch", "layer", "none"] = "batch",
        readout: Literal["mean", "sum", "max"] = "mean",
    ):
        super().__init__()
        assert num_layers >= 1

        self.gnn_type = gnn_type.lower()
        self.emb_dim = emb_dim
        self.num_layers = num_layers
        self.residual = residual
        self.dropout_p = float(dropout)
        self.readout = readout.lower()

        self.node_proj = NodeProjector(in_dim_node, emb_dim, act=act)
        self.edge_proj: Optional[EdgeProjector] = None
        if self.gnn_type == "gine":
            # GINE is the only conv here that consumes edge attributes, so
            # requiring them up front fails fast on misconfiguration.
            if in_dim_edge <= 0:
                raise ValueError(
                    "gine selected but in_dim_edge <= 0. Provide edge attributes or switch gnn_type."
                )
            self.edge_proj = EdgeProjector(in_dim_edge, emb_dim, act=act)

        # One (conv, norm) pair per message-passing layer.
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _ in range(num_layers):
            if self.gnn_type == "gine":
                # GINE needs its internal MLP; edge_attr is projected to emb_dim
                conv = GINEConv(MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0))
            elif self.gnn_type == "gin":
                conv = GINConv(MLP(emb_dim, emb_dim, emb_dim, num_layers=2, act=act, dropout=0.0))
            elif self.gnn_type == "gcn":
                conv = GCNConv(emb_dim, emb_dim, add_self_loops=True, normalize=True)
            else:
                raise ValueError(f"Unknown gnn_type: {gnn_type}")
            self.convs.append(conv)

            if norm == "batch":
                norm_layer: nn.Module = nn.BatchNorm1d(emb_dim)
            elif norm == "layer":
                norm_layer = nn.LayerNorm(emb_dim)
            elif norm == "none":
                norm_layer = nn.Identity()
            else:
                raise ValueError(f"Unknown norm: {norm}")
            self.norms.append(norm_layer)

        self.act = get_activation(act)

    def _readout(self, x: torch.Tensor, batch: torch.Tensor) -> torch.Tensor:
        """Pool node embeddings into one vector per graph."""
        pools = {
            "mean": global_mean_pool,
            "sum": global_add_pool,
            "max": global_max_pool,
        }
        try:
            return pools[self.readout](x, batch)
        except KeyError:
            raise ValueError(f"Unknown readout: {self.readout}") from None

    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        edge_attr: Optional[torch.Tensor],
        batch: Optional[torch.Tensor],
    ) -> torch.Tensor:
        """
        Returns a graph-level embedding of shape [B, emb_dim].
        If batch is None, assumes a single graph and creates a zero batch vector.
        """
        if batch is None:
            batch = x.new_zeros(x.size(0), dtype=torch.long)

        # Project raw features into the embedding space (float dtype enforced).
        h = self.node_proj(x.float())

        e = None
        if self.gnn_type == "gine":
            if edge_attr is None:
                raise ValueError("GINE requires edge_attr, but got None.")
            e = self.edge_proj(edge_attr.float())

        for conv, norm_layer in zip(self.convs, self.norms):
            if self.gnn_type == "gine":
                out = conv(h, edge_index, e)
            else:
                out = conv(h, edge_index)  # gin/gcn ignore edge attributes

            out = self.act(norm_layer(out))

            # Residual connection only when shapes line up.
            h = h + out if (self.residual and out.shape == h.shape) else out

            if self.dropout_p > 0:
                h = F.dropout(h, p=self.dropout_p, training=self.training)

        return self._readout(h, batch)  # [B, emb_dim]
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def build_gnn_encoder(
    in_dim_node: int,
    emb_dim: int,
    num_layers: int = 5,
    gnn_type: Literal["gine", "gin", "gcn"] = "gine",
    in_dim_edge: int = 0,
    act: str = "relu",
    dropout: float = 0.0,
    residual: bool = True,
    norm: Literal["batch", "layer", "none"] = "batch",
    readout: Literal["mean", "sum", "max"] = "mean",
) -> GNNEncoder:
    """
    Factory to create a GNNEncoder with a consistent, minimal API.
    Prefer calling this from model.py so encoder construction is centralized.
    """
    config = dict(
        in_dim_node=in_dim_node,
        emb_dim=emb_dim,
        num_layers=num_layers,
        gnn_type=gnn_type,
        in_dim_edge=in_dim_edge,
        act=act,
        dropout=dropout,
        residual=residual,
        norm=norm,
        readout=readout,
    )
    return GNNEncoder(**config)
|
| 256 |
-
|
| 257 |
-
|
| 258 |
-
# Explicit public API for `from conv import *`.
__all__ = ["GNNEncoder", "build_gnn_encoder"]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/data_builder.py
DELETED
|
@@ -1,818 +0,0 @@
|
|
| 1 |
-
# data_builder.py
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from typing import Dict, List, Optional, Tuple, Sequence
|
| 6 |
-
import json
|
| 7 |
-
import warnings
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import pandas as pd
|
| 11 |
-
import torch
|
| 12 |
-
from torch.utils.data import Dataset
|
| 13 |
-
from torch_geometric.data import Data
|
| 14 |
-
|
| 15 |
-
# RDKit is required
|
| 16 |
-
from rdkit import Chem
|
| 17 |
-
from rdkit.Chem.rdchem import HybridizationType, BondType, BondStereo
|
| 18 |
-
|
| 19 |
-
# ---------------------------------------------------------
|
| 20 |
-
# Fidelity handling
|
| 21 |
-
# ---------------------------------------------------------
|
| 22 |
-
|
| 23 |
-
# Canonical fidelity ordering; index position becomes `fid_idx` in the long
# table (lower index = higher priority). Presumably exp=experimental,
# dft=DFT, md=molecular dynamics, gc=group contribution — TODO confirm.
FID_PRIORITY = ["exp", "dft", "md", "gc"]  # internal lower-case canonical order
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _norm_fid(fid: str) -> str:
|
| 27 |
-
return fid.strip().lower()
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def _ensure_targets_order(requested: Sequence[str]) -> List[str]:
|
| 31 |
-
seen = set()
|
| 32 |
-
ordered = []
|
| 33 |
-
for t in requested:
|
| 34 |
-
key = t.strip()
|
| 35 |
-
if key in seen:
|
| 36 |
-
continue
|
| 37 |
-
seen.add(key)
|
| 38 |
-
ordered.append(key)
|
| 39 |
-
return ordered
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
# ---------------------------------------------------------
|
| 43 |
-
# RDKit featurization
|
| 44 |
-
# ---------------------------------------------------------
|
| 45 |
-
|
| 46 |
-
# Element vocabulary for the atom one-hot; symbols outside this list fall
# into an extra "other" bucket (see atom_features).
_ATOMS = ["H", "C", "N", "O", "F", "P", "S", "Cl", "Br", "I"]
_ATOM2IDX = {s: i for i, s in enumerate(_ATOMS)}
# Hybridization states with a dedicated one-hot slot; others map to "other".
_HYBS = [HybridizationType.SP, HybridizationType.SP2, HybridizationType.SP3, HybridizationType.SP3D, HybridizationType.SP3D2]
_HYB2IDX = {h: i for i, h in enumerate(_HYBS)}
# Bond stereo descriptors, one-hot encoded in bond_features.
_BOND_STEREOS = [
    BondStereo.STEREONONE,
    BondStereo.STEREOANY,
    BondStereo.STEREOZ,
    BondStereo.STEREOE,
    BondStereo.STEREOCIS,
    BondStereo.STEREOTRANS,
]
_STEREO2IDX = {s: i for i, s in enumerate(_BOND_STEREOS)}
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def _one_hot(index: int, size: int) -> List[float]:
|
| 62 |
-
v = [0.0] * size
|
| 63 |
-
if 0 <= index < size:
|
| 64 |
-
v[index] = 1.0
|
| 65 |
-
return v
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
def atom_features(atom: Chem.Atom) -> List[float]:
    """Encode a single RDKit atom as a flat float feature vector.

    Layout (35 dims): element one-hot incl. "other" (11) + degree one-hot
    capped at 5 (6) + formal charge clipped to [-2, 2] (5) + aromatic flag
    (1) + in-ring flag (1) + hybridization one-hot incl. "other" (6) +
    total-H count capped at 4 (5).
    """
    # Unknown elements/hybridizations are routed to the trailing "other" slot.
    elem_oh = _one_hot(_ATOM2IDX.get(atom.GetSymbol(), len(_ATOMS)), len(_ATOMS) + 1)
    deg_oh = _one_hot(min(int(atom.GetDegree()), 5), 6)
    charge = max(-2, min(2, int(atom.GetFormalCharge())))
    fc_oh = _one_hot(charge + 2, 5)
    flags = [float(atom.GetIsAromatic()), float(atom.IsInRing())]
    hyb_oh = _one_hot(_HYB2IDX.get(atom.GetHybridization(), len(_HYBS)), len(_HYBS) + 1)
    h_oh = _one_hot(min(int(atom.GetTotalNumHs(includeNeighbors=True)), 4), 5)
    return elem_oh + deg_oh + fc_oh + flags + hyb_oh + h_oh
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
def bond_features(bond: Chem.Bond) -> List[float]:
    """Encode a single RDKit bond as a 12-dim float vector.

    Layout: bond-type flags (single/double/triple/aromatic, 4) +
    conjugated flag (1) + in-ring flag (1) + stereo one-hot (6).
    """
    bt = bond.GetBondType()
    type_flags = [
        float(bt == BondType.SINGLE),
        float(bt == BondType.DOUBLE),
        float(bt == BondType.TRIPLE),
        float(bt == BondType.AROMATIC),
    ]
    extra = [float(bond.GetIsConjugated()), float(bond.IsInRing())]
    # Unrecognized stereo falls back to slot 0 (STEREONONE).
    stereo_oh = _one_hot(_STEREO2IDX.get(bond.GetStereo(), 0), len(_BOND_STEREOS))
    return type_flags + extra + stereo_oh
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
def featurize_smiles(smiles: str) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Build (node features, edge index, edge attributes) tensors for a SMILES.

    Edges are duplicated in both directions. A single-atom molecule gets one
    zero-feature self-loop so downstream message passing has an edge to use.

    Raises:
        ValueError: if RDKit cannot parse the SMILES string.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"RDKit failed to parse SMILES: {smiles}")

    # Node feature matrix [num_atoms, 35]
    x = torch.tensor([atom_features(a) for a in mol.GetAtoms()], dtype=torch.float32)

    # Bidirectional edges with duplicated bond features
    src: List[int] = []
    dst: List[int] = []
    feats: List[List[float]] = []
    for bond in mol.GetBonds():
        a, b = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
        bf = bond_features(bond)
        src += [a, b]
        dst += [b, a]
        feats += [bf, bf]

    if not src:
        # single-atom molecules: one dummy self-loop with zero features
        src, dst, feats = [0], [0], [[0.0] * 12]

    edge_index = torch.tensor([src, dst], dtype=torch.long)
    edge_attr = torch.tensor(feats, dtype=torch.float32)
    return x, edge_index, edge_attr
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
# ---------------------------------------------------------
|
| 139 |
-
# CSV discovery and reading
|
| 140 |
-
# ---------------------------------------------------------
|
| 141 |
-
|
| 142 |
-
def discover_target_fid_csvs(
    root: Path,
    targets: Sequence[str],
    fidelities: Sequence[str],
) -> Dict[tuple[str, str], Path]:
    """
    Discover CSV files for (target, fidelity) pairs.

    Supported layouts (case-insensitive):

    1) {root}/{fid}/{target}.csv
       e.g. datafull/MD/SHEAR.csv, datafull/exp/cp.csv

    2) {root}/{target}_{fid}.csv
       e.g. datafull/SHEAR_MD.csv, datafull/cp_exp.csv

    Matching is STRICT:
    - target and fid must appear as full '_' tokens in the stem
    - no substring matching, so 'he' will NOT match 'shear_md.csv'
    """
    root = Path(root)
    targets = _ensure_targets_order(targets)
    fids_lc = [_norm_fid(f) for f in fidelities]

    # Index every CSV once: (path, parent dir name, stem, '_'-tokens), lower-cased.
    indexed = [
        (p, p.parent.name.lower(), p.stem.lower(), p.stem.lower().split("_"))
        for p in root.rglob("*.csv")
    ]

    mapping: Dict[tuple[str, str], Path] = {}

    for fid in fids_lc:
        fid_l = fid.strip().lower()
        for tgt in targets:
            tgt_l = tgt.strip().lower()

            # ---- 1) Prefer explicit folder layout: {root}/{fid}/{target}.csv ----
            folder_matches = [
                p for (p, parent, stem, _) in indexed
                if parent == fid_l and stem == tgt_l
            ]
            if folder_matches:
                # More than one hit means the data layout is ambiguous.
                if len(folder_matches) > 1:
                    warnings.warn(
                        f"[discover_target_fid_csvs] Multiple matches for "
                        f"target='{tgt}' fid='{fid}' under folder layout: "
                        + ", ".join(str(p) for p in folder_matches)
                    )
                mapping[(tgt, fid)] = folder_matches[0]
                continue

            # ---- 2) Fallback: {target}_{fid}.csv anywhere under root ----
            # Both names must be present as full '_' tokens.
            token_matches = [
                p for (p, _, _, toks) in indexed
                if tgt_l in toks and fid_l in toks
            ]
            if token_matches:
                if len(token_matches) > 1:
                    warnings.warn(
                        f"[discover_target_fid_csvs] Multiple token matches for "
                        f"target='{tgt}' fid='{fid}': "
                        + ", ".join(str(p) for p in token_matches)
                    )
                mapping[(tgt, fid)] = token_matches[0]

            # Missing combinations are silently skipped; build_long_table just
            # omits that (target, fid) pair.

    return mapping
|
| 226 |
-
|
| 227 |
-
|
| 228 |
-
def read_target_csv(path: Path, target: str) -> pd.DataFrame:
    """
    Accepts:
    - 'smiles' column (case-insensitive)
    - value column named '{target}' or one of ['value','y' or lower-case target]
    Deduplicates by SMILES with mean.
    """
    df = pd.read_csv(path)

    # Locate the SMILES column case-insensitively and normalize its name.
    smiles_col = next((c for c in df.columns if c.lower() == "smiles"), None)
    if smiles_col is None:
        raise ValueError(f"{path} must contain a 'smiles' column.")
    df = df.rename(columns={smiles_col: "smiles"})

    # Locate the value column: exact target name first, then common fallbacks.
    if target in df.columns:
        val_col = target
    else:
        val_col = next(
            (c for c in df.columns if c.lower() in ("value", "y", target.lower())),
            None,
        )
    if val_col is None:
        raise ValueError(f"{path} must contain a '{target}' column or one of ['value','y'].")

    df = df[["smiles", val_col]].copy()
    df = df.dropna(subset=[val_col])
    df[val_col] = pd.to_numeric(df[val_col], errors="coerce")
    df = df.dropna(subset=[val_col])

    # Collapse repeated SMILES to the mean of their values.
    if df.duplicated(subset=["smiles"]).any():
        warnings.warn(f"[data_builder] Duplicates by SMILES in {path}. Averaging duplicates.")
        df = df.groupby("smiles", as_index=False)[val_col].mean()

    return df.rename(columns={val_col: target})
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
def build_long_table(root: Path, targets: Sequence[str], fidelities: Sequence[str]) -> pd.DataFrame:
    """
    Returns long-form table with columns: [smiles, fid, fid_idx, target, value]
    """
    targets = _ensure_targets_order(targets)

    mapping = discover_target_fid_csvs(root, targets, fidelities)
    if not mapping:
        raise FileNotFoundError(f"No CSVs found under {root} for the given targets and fidelities.")

    frames = []
    for (tgt, fid), csv_path in mapping.items():
        part = read_target_csv(csv_path, tgt)
        part["fid"] = _norm_fid(fid)
        part["target"] = tgt
        part = part.rename(columns={tgt: "value"})
        frames.append(part[["smiles", "fid", "target", "value"]])

    long = pd.concat(frames, axis=0, ignore_index=True)

    # Map each fidelity to its priority index; unknown fidelities are appended
    # after the known ones (with a warning) so they still get a stable index.
    fid2idx = {f: i for i, f in enumerate(FID_PRIORITY)}
    long["fid"] = long["fid"].str.lower()
    unknown = sorted(set(long["fid"]) - set(fid2idx.keys()))
    if unknown:
        warnings.warn(f"[data_builder] Unknown fidelities found: {unknown}. Appending after known ones.")
        base = len(fid2idx)
        fid2idx.update({f: base + i for i, f in enumerate(unknown)})

    long["fid_idx"] = long["fid"].map(fid2idx)
    return long
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
def pivot_to_rows_by_smiles_fid(long: pd.DataFrame, targets: Sequence[str]) -> pd.DataFrame:
    """
    Input: long table [smiles, fid, fid_idx, target, value]
    Output: row-per-(smiles, fid) with wide columns for each target
    """
    targets = _ensure_targets_order(targets)
    wide = (
        long.pivot_table(index=["smiles", "fid", "fid_idx"], columns="target", values="value", aggfunc="mean")
        .reset_index()
    )

    # Guarantee every requested target column exists, even when no data was found.
    for t in targets:
        if t not in wide.columns:
            wide[t] = np.nan

    return wide[["smiles", "fid", "fid_idx", *targets]]
|
| 318 |
-
|
| 319 |
-
|
| 320 |
-
# ---------------------------------------------------------
|
| 321 |
-
# Grouped split by SMILES and transforms/normalization
|
| 322 |
-
# ---------------------------------------------------------
|
| 323 |
-
|
| 324 |
-
def grouped_split_by_smiles(
    df_rows: pd.DataFrame,
    val_ratio: float = 0.1,
    test_ratio: float = 0.1,
    seed: int = 42,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Split row indices into train/val/test so each SMILES lands in exactly one split.

    Unique SMILES are shuffled with a seeded RNG and partitioned by ratio;
    every row whose SMILES falls in a partition goes to that split. Returns
    (train_idx, val_idx, test_idx) as arrays of ``df_rows`` index labels.
    """
    shuffled = np.random.default_rng(seed).permutation(
        df_rows["smiles"].drop_duplicates().values
    )

    n = len(shuffled)
    n_test = int(round(n * test_ratio))
    n_val = int(round(n * val_ratio))

    # Partition order matches the original slicing: test first, then val, rest train.
    partitions = (
        set(shuffled[n_test + n_val:]),        # train
        set(shuffled[n_test:n_test + n_val]),  # val
        set(shuffled[:n_test]),                # test
    )
    train_idx, val_idx, test_idx = (
        df_rows.index[df_rows["smiles"].isin(group)].to_numpy() for group in partitions
    )
    return train_idx, val_idx, test_idx
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
# ---------------- Enhanced TargetScaler with per-task transforms ----------------
|
| 349 |
-
|
| 350 |
-
class TargetScaler:
    """
    Per-task transform + standardization fitted on the training split only.

    - transforms[t] in {"identity","log10"}
    - eps[t] is added before log for numerical safety (only used if transforms[t]=="log10")
    - mean/std are computed in the *transformed* domain
    """

    def __init__(self, transforms: Optional[Sequence[str]] = None, eps: Optional[Sequence[float] | torch.Tensor] = None):
        self.mean: Optional[torch.Tensor] = None  # [T] (transformed domain)
        self.std: Optional[torch.Tensor] = None   # [T] (transformed domain)
        self.transforms: List[str] = [str(t).lower() for t in transforms] if transforms is not None else []
        if eps is None:
            self.eps: Optional[torch.Tensor] = None
        else:
            self.eps = torch.as_tensor(eps, dtype=torch.float32)
        self._tiny = 1e-12  # clamp floor applied before log10

    def _ensure_cfg(self, T: int):
        """Fall back to identity transforms / zero eps when config is missing or mismatched."""
        if not self.transforms or len(self.transforms) != T:
            self.transforms = ["identity"] * T
        if self.eps is None or self.eps.numel() != T:
            self.eps = torch.zeros(T, dtype=torch.float32)

    def _auto_eps(self, y: torch.Tensor, mask: torch.Tensor, T: int) -> torch.Tensor:
        """Auto epsilon per task: 0.1 * smallest positive training value (log10 tasks only)."""
        eps_vals: List[float] = []
        y_np = y.detach().cpu().numpy()
        m_np = mask.detach().cpu().numpy().astype(bool)
        for t in range(T):
            if self.transforms[t] != "log10":
                eps_vals.append(0.0)
                continue
            vals = y_np[m_np[:, t], t]
            pos = vals[vals > 0]
            if pos.size == 0:
                eps_vals.append(1e-8)
            else:
                eps_vals.append(0.1 * float(max(np.min(pos), 1e-8)))
        return torch.tensor(eps_vals, dtype=torch.float32)

    def _forward_transform_only(self, y: torch.Tensor) -> torch.Tensor:
        """
        Apply per-task transforms *before* standardization.
        y: [N, T] in original units. Returns transformed y_tf in same shape.
        """
        out = y.clone()
        T = out.size(1)
        self._ensure_cfg(T)
        for t in range(T):
            if self.transforms[t] == "log10":
                # clamp keeps log10 finite even for zero/negative inputs
                out[:, t] = torch.log10(torch.clamp(out[:, t] + self.eps[t], min=self._tiny))
        return out

    def _inverse_transform_only(self, y_tf: torch.Tensor) -> torch.Tensor:
        """
        Inverse the per-task transform (no standardization here).
        y_tf: [N, T] in transformed units.
        """
        out = y_tf.clone()
        T = out.size(1)
        self._ensure_cfg(T)
        for t in range(T):
            if self.transforms[t] == "log10":
                out[:, t] = (10.0 ** out[:, t]) - self.eps[t]
        return out

    def fit(self, y: torch.Tensor, mask: torch.Tensor):
        """
        y: [N, T] original units; mask: [N, T] bool
        Chooses eps automatically if not provided; mean/std computed in transformed space.
        """
        T = y.size(1)

        # Transforms must exist before eps selection (which depends on them).
        if not self.transforms or len(self.transforms) != T:
            self.transforms = ["identity"] * T

        # BUG FIX: choose the automatic epsilon BEFORE _ensure_cfg zero-fills
        # self.eps. Previously _ensure_cfg ran first, so this branch was dead
        # code and log10 tasks always used eps=0.
        if self.eps is None or self.eps.numel() != T:
            self.eps = self._auto_eps(y, mask, T)
        self._ensure_cfg(T)  # now a no-op; kept as a safety net

        # Masked mean/std in the transformed domain.
        y_tf = self._forward_transform_only(y)
        eps = 1e-8  # variance floor so std is never ~0
        y_masked = torch.where(mask, y_tf, torch.zeros_like(y_tf))
        counts = mask.sum(dim=0).clamp_min(1)
        mean = y_masked.sum(dim=0) / counts
        var = ((torch.where(mask, y_tf - mean, torch.zeros_like(y_tf))) ** 2).sum(dim=0) / counts
        std = torch.sqrt(var + eps)
        self.mean, self.std = mean, std

    def transform(self, y: torch.Tensor) -> torch.Tensor:
        """Original units -> standardized transformed space. y: [N, T]."""
        y_tf = self._forward_transform_only(y)
        return (y_tf - self.mean) / self.std

    def inverse(self, y_std: torch.Tensor) -> torch.Tensor:
        """
        Inverse standardization + inverse transform → original units.
        y_std: [N, T] in standardized-transformed space
        """
        y_tf = y_std * self.std + self.mean
        return self._inverse_transform_only(y_tf)

    def state_dict(self) -> Dict[str, torch.Tensor | List[str]]:
        """Serializable snapshot of the fitted configuration."""
        return {
            "mean": self.mean,
            "std": self.std,
            "transforms": self.transforms,
            "eps": self.eps,
        }

    def load_state_dict(self, state: Dict[str, torch.Tensor | List[str]]):
        """Restore a snapshot produced by state_dict()."""
        self.mean = state["mean"]
        self.std = state["std"]
        self.transforms = [str(t) for t in state.get("transforms", [])]
        eps = state.get("eps", None)
        self.eps = torch.as_tensor(eps, dtype=torch.float32) if eps is not None else None
|
| 460 |
-
|
| 461 |
-
|
| 462 |
-
def auto_select_task_transforms(
|
| 463 |
-
y_train: torch.Tensor, # [N, T] original units (train split only)
|
| 464 |
-
mask_train: torch.Tensor, # [N, T] bool
|
| 465 |
-
task_names: Sequence[str],
|
| 466 |
-
*,
|
| 467 |
-
min_pos_frac: float = 0.95, # ≥95% of labels positive
|
| 468 |
-
orders_threshold: float = 2.0, # ≥2 orders of magnitude between p95 and p5
|
| 469 |
-
tiny: float = 1e-12,
|
| 470 |
-
) -> tuple[List[str], torch.Tensor]:
|
| 471 |
-
"""
|
| 472 |
-
Decide per-task transform: "log10" if (mostly-positive AND large dynamic range), else "identity".
|
| 473 |
-
Returns (transforms, eps_vector) where eps is only used for log tasks.
|
| 474 |
-
"""
|
| 475 |
-
Y = y_train.detach().cpu().numpy()
|
| 476 |
-
M = mask_train.detach().cpu().numpy().astype(bool)
|
| 477 |
-
|
| 478 |
-
transforms: List[str] = []
|
| 479 |
-
eps_vals: List[float] = []
|
| 480 |
-
|
| 481 |
-
for t in range(Y.shape[1]):
|
| 482 |
-
yt = Y[M[:, t], t]
|
| 483 |
-
if yt.size == 0:
|
| 484 |
-
transforms.append("identity")
|
| 485 |
-
eps_vals.append(0.0)
|
| 486 |
-
continue
|
| 487 |
-
|
| 488 |
-
pos_frac = (yt > 0).mean()
|
| 489 |
-
p5 = float(np.percentile(yt, 5))
|
| 490 |
-
p95 = float(np.percentile(yt, 95))
|
| 491 |
-
denom = max(p5, tiny)
|
| 492 |
-
dyn_orders = float(np.log10(max(p95 / denom, 1.0)))
|
| 493 |
-
use_log = (pos_frac >= min_pos_frac) and (dyn_orders >= orders_threshold)
|
| 494 |
-
|
| 495 |
-
if use_log:
|
| 496 |
-
pos_vals = yt[yt > 0]
|
| 497 |
-
if pos_vals.size == 0:
|
| 498 |
-
eps_vals.append(1e-8)
|
| 499 |
-
else:
|
| 500 |
-
eps_vals.append(0.1 * float(max(np.min(pos_vals), 1e-8)))
|
| 501 |
-
transforms.append("log10")
|
| 502 |
-
else:
|
| 503 |
-
transforms.append("identity")
|
| 504 |
-
eps_vals.append(0.0)
|
| 505 |
-
|
| 506 |
-
return transforms, torch.tensor(eps_vals, dtype=torch.float32)
|
| 507 |
-
|
| 508 |
-
|
| 509 |
-
# ---------------------------------------------------------
|
| 510 |
-
# Dataset
|
| 511 |
-
# ---------------------------------------------------------
|
| 512 |
-
|
| 513 |
-
class MultiFidelityMoleculeDataset(Dataset):
|
| 514 |
-
"""
|
| 515 |
-
Each item is a PyG Data with:
|
| 516 |
-
- x: [N_nodes, F_node]
|
| 517 |
-
- edge_index: [2, N_edges]
|
| 518 |
-
- edge_attr: [N_edges, F_edge]
|
| 519 |
-
- y: [T] normalized targets (zeros where missing)
|
| 520 |
-
- y_mask: [T] bool mask of present targets
|
| 521 |
-
- fid_idx: [1] long
|
| 522 |
-
- .smiles and .fid_str added for debugging
|
| 523 |
-
|
| 524 |
-
Targets are kept in the exact order provided by the user.
|
| 525 |
-
"""
|
| 526 |
-
def __init__(
|
| 527 |
-
self,
|
| 528 |
-
rows: pd.DataFrame,
|
| 529 |
-
targets: Sequence[str],
|
| 530 |
-
scaler: Optional[TargetScaler],
|
| 531 |
-
smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
|
| 532 |
-
):
|
| 533 |
-
super().__init__()
|
| 534 |
-
self.rows = rows.reset_index(drop=True).copy()
|
| 535 |
-
self.targets = _ensure_targets_order(targets)
|
| 536 |
-
self.scaler = scaler
|
| 537 |
-
self.smiles_graph_cache = smiles_graph_cache
|
| 538 |
-
|
| 539 |
-
# Build y and mask tensors
|
| 540 |
-
ys, masks = [], []
|
| 541 |
-
for _, r in self.rows.iterrows():
|
| 542 |
-
yv, mv = [], []
|
| 543 |
-
for t in self.targets:
|
| 544 |
-
v = r[t]
|
| 545 |
-
if pd.isna(v):
|
| 546 |
-
yv.append(np.nan)
|
| 547 |
-
mv.append(False)
|
| 548 |
-
else:
|
| 549 |
-
yv.append(float(v))
|
| 550 |
-
mv.append(True)
|
| 551 |
-
ys.append(yv)
|
| 552 |
-
masks.append(mv)
|
| 553 |
-
|
| 554 |
-
y = torch.tensor(np.array(ys, dtype=np.float32)) # [N, T]
|
| 555 |
-
mask = torch.tensor(np.array(masks, dtype=np.bool_))
|
| 556 |
-
|
| 557 |
-
if scaler is not None and scaler.mean is not None:
|
| 558 |
-
y_norm = torch.where(mask, scaler.transform(y), torch.zeros_like(y))
|
| 559 |
-
else:
|
| 560 |
-
y_norm = y
|
| 561 |
-
|
| 562 |
-
self.y = y_norm
|
| 563 |
-
self.mask = mask
|
| 564 |
-
|
| 565 |
-
# Input dims
|
| 566 |
-
any_smiles = self.rows.iloc[0]["smiles"]
|
| 567 |
-
x0, _, e0 = smiles_graph_cache[any_smiles]
|
| 568 |
-
self.in_dim_node = x0.shape[1]
|
| 569 |
-
self.in_dim_edge = e0.shape[1]
|
| 570 |
-
|
| 571 |
-
# Fidelity metadata for reference (local indexing in this dataset)
|
| 572 |
-
self.fids = sorted(
|
| 573 |
-
self.rows["fid"].str.lower().unique().tolist(),
|
| 574 |
-
key=lambda f: (FID_PRIORITY + [f]).index(f) if f in FID_PRIORITY else len(FID_PRIORITY),
|
| 575 |
-
)
|
| 576 |
-
self.fid2idx = {f: i for i, f in enumerate(self.fids)}
|
| 577 |
-
self.rows["fid_idx_local"] = self.rows["fid"].str.lower().map(self.fid2idx)
|
| 578 |
-
|
| 579 |
-
def __len__(self) -> int:
|
| 580 |
-
return len(self.rows)
|
| 581 |
-
|
| 582 |
-
def __getitem__(self, idx: int) -> Data:
|
| 583 |
-
idx = int(idx)
|
| 584 |
-
r = self.rows.iloc[idx]
|
| 585 |
-
smi = r["smiles"]
|
| 586 |
-
|
| 587 |
-
x, edge_index, edge_attr = self.smiles_graph_cache[smi]
|
| 588 |
-
# Ensure [1, T] so batches become [B, T]
|
| 589 |
-
y_i = self.y[idx].clone().unsqueeze(0) # [1, T]
|
| 590 |
-
m_i = self.mask[idx].clone().unsqueeze(0) # [1, T]
|
| 591 |
-
fid_idx = int(r["fid_idx_local"])
|
| 592 |
-
|
| 593 |
-
d = Data(
|
| 594 |
-
x=x.clone(),
|
| 595 |
-
edge_index=edge_index.clone(),
|
| 596 |
-
edge_attr=edge_attr.clone(),
|
| 597 |
-
y=y_i,
|
| 598 |
-
y_mask=m_i,
|
| 599 |
-
fid_idx=torch.tensor([fid_idx], dtype=torch.long),
|
| 600 |
-
)
|
| 601 |
-
d.smiles = smi
|
| 602 |
-
d.fid_str = r["fid"]
|
| 603 |
-
return d
|
| 604 |
-
|
| 605 |
-
|
| 606 |
-
def subsample_train_indices(
|
| 607 |
-
rows: pd.DataFrame,
|
| 608 |
-
train_idx: np.ndarray,
|
| 609 |
-
*,
|
| 610 |
-
target: Optional[str],
|
| 611 |
-
fidelity: Optional[str],
|
| 612 |
-
pct: float = 1.0,
|
| 613 |
-
seed: int = 137,
|
| 614 |
-
) -> np.ndarray:
|
| 615 |
-
"""
|
| 616 |
-
Return a filtered train_idx that keeps only a 'pct' fraction (0<pct<=1)
|
| 617 |
-
of TRAIN rows for the specified (target, fidelity) block. Selection is
|
| 618 |
-
deterministic by unique SMILES. Rows outside the block are untouched.
|
| 619 |
-
|
| 620 |
-
rows: wide table with columns ["smiles","fid","fid_idx", <targets...>]
|
| 621 |
-
"""
|
| 622 |
-
if target is None or fidelity is None or pct >= 0.999:
|
| 623 |
-
return train_idx
|
| 624 |
-
|
| 625 |
-
if target not in rows.columns:
|
| 626 |
-
return train_idx
|
| 627 |
-
|
| 628 |
-
fid_lc = fidelity.strip().lower()
|
| 629 |
-
|
| 630 |
-
# Identify TRAIN rows in the specified block: matching fid and having a label for 'target'
|
| 631 |
-
train_rows = rows.iloc[train_idx]
|
| 632 |
-
block_mask = (train_rows["fid"].str.lower() == fid_lc) & (~train_rows[target].isna())
|
| 633 |
-
if not bool(block_mask.any()):
|
| 634 |
-
return train_idx # nothing to subsample
|
| 635 |
-
|
| 636 |
-
# Sample by unique SMILES (stable & grouped)
|
| 637 |
-
smiles_all = pd.Index(train_rows.loc[block_mask, "smiles"].unique())
|
| 638 |
-
n_all = len(smiles_all)
|
| 639 |
-
if n_all == 0:
|
| 640 |
-
return train_idx
|
| 641 |
-
|
| 642 |
-
if pct <= 0.0:
|
| 643 |
-
pct = 0.0001
|
| 644 |
-
n_keep = max(1, int(round(pct * n_all)))
|
| 645 |
-
|
| 646 |
-
rng = np.random.RandomState(int(seed))
|
| 647 |
-
smiles_sorted = np.array(sorted(smiles_all.tolist()))
|
| 648 |
-
keep_smiles = set(rng.choice(smiles_sorted, size=n_keep, replace=False).tolist())
|
| 649 |
-
|
| 650 |
-
# Keep all non-block rows; within block keep selected SMILES
|
| 651 |
-
keep_mask_local = (~block_mask) | (train_rows["smiles"].isin(keep_smiles))
|
| 652 |
-
kept_train_idx = train_rows.index[keep_mask_local].to_numpy()
|
| 653 |
-
return kept_train_idx
|
| 654 |
-
|
| 655 |
-
|
| 656 |
-
# ---------------------------------------------------------
|
| 657 |
-
# High-level builder
|
| 658 |
-
# ---------------------------------------------------------
|
| 659 |
-
|
| 660 |
-
def build_dataset_from_dir(
|
| 661 |
-
root_dir: str | Path,
|
| 662 |
-
targets: Sequence[str],
|
| 663 |
-
fidelities: Sequence[str] = ("exp", "dft", "md", "gc"),
|
| 664 |
-
val_ratio: float = 0.1,
|
| 665 |
-
test_ratio: float = 0.1,
|
| 666 |
-
seed: int = 42,
|
| 667 |
-
save_splits_path: Optional[str | Path] = None,
|
| 668 |
-
# Optional subsampling of a (target, fidelity) block in TRAIN
|
| 669 |
-
subsample_target: Optional[str] = None,
|
| 670 |
-
subsample_fidelity: Optional[str] = None,
|
| 671 |
-
subsample_pct: float = 1.0,
|
| 672 |
-
subsample_seed: int = 137,
|
| 673 |
-
# -------- NEW: auto/explicit log transforms --------
|
| 674 |
-
auto_log: bool = True,
|
| 675 |
-
log_orders_threshold: float = 2.0,
|
| 676 |
-
log_min_pos_frac: float = 0.95,
|
| 677 |
-
explicit_log_targets: Optional[Sequence[str]] = None, # e.g. ["permeability"]
|
| 678 |
-
) -> tuple[MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, MultiFidelityMoleculeDataset, TargetScaler]:
|
| 679 |
-
"""
|
| 680 |
-
Returns train_ds, val_ds, test_ds, scaler.
|
| 681 |
-
|
| 682 |
-
- Discovers CSVs for requested targets and fidelities
|
| 683 |
-
- Builds a row-per-(smiles,fid) table with columns for each target
|
| 684 |
-
- Splits by unique SMILES to avoid leakage across fidelity or targets
|
| 685 |
-
- Fits transform+normalization on the training split only, applies to val/test
|
| 686 |
-
- Builds RDKit graphs once per unique SMILES and reuses them
|
| 687 |
-
|
| 688 |
-
NEW:
|
| 689 |
-
- Auto per-task transform selection ("log10" vs "identity") by criteria
|
| 690 |
-
- Optional explicit override via explicit_log_targets
|
| 691 |
-
"""
|
| 692 |
-
root = Path(root_dir)
|
| 693 |
-
targets = _ensure_targets_order(targets)
|
| 694 |
-
fids_lc = [_norm_fid(f) for f in fidelities]
|
| 695 |
-
|
| 696 |
-
# Build long and pivot to rows
|
| 697 |
-
long = build_long_table(root, targets, fids_lc)
|
| 698 |
-
rows = pivot_to_rows_by_smiles_fid(long, targets)
|
| 699 |
-
|
| 700 |
-
# Deterministic grouped split by SMILES
|
| 701 |
-
if save_splits_path is not None and Path(save_splits_path).exists():
|
| 702 |
-
with open(save_splits_path, "r") as f:
|
| 703 |
-
split_obj = json.load(f)
|
| 704 |
-
train_smiles = set(split_obj["train_smiles"])
|
| 705 |
-
val_smiles = set(split_obj["val_smiles"])
|
| 706 |
-
test_smiles = set(split_obj["test_smiles"])
|
| 707 |
-
train_idx = rows.index[rows["smiles"].isin(train_smiles)].to_numpy()
|
| 708 |
-
val_idx = rows.index[rows["smiles"].isin(val_smiles)].to_numpy()
|
| 709 |
-
test_idx = rows.index[rows["smiles"].isin(test_smiles)].to_numpy()
|
| 710 |
-
else:
|
| 711 |
-
train_idx, val_idx, test_idx = grouped_split_by_smiles(rows, val_ratio=val_ratio, test_ratio=test_ratio, seed=seed)
|
| 712 |
-
if save_splits_path is not None:
|
| 713 |
-
split_obj = {
|
| 714 |
-
"train_smiles": rows.iloc[train_idx]["smiles"].drop_duplicates().tolist(),
|
| 715 |
-
"val_smiles": rows.iloc[val_idx]["smiles"].drop_duplicates().tolist(),
|
| 716 |
-
"test_smiles": rows.iloc[test_idx]["smiles"].drop_duplicates().tolist(),
|
| 717 |
-
"seed": seed,
|
| 718 |
-
"val_ratio": val_ratio,
|
| 719 |
-
"test_ratio": test_ratio,
|
| 720 |
-
}
|
| 721 |
-
Path(save_splits_path).parent.mkdir(parents=True, exist_ok=True)
|
| 722 |
-
with open(save_splits_path, "w") as f:
|
| 723 |
-
json.dump(split_obj, f, indent=2)
|
| 724 |
-
|
| 725 |
-
# Build RDKit graphs once per unique SMILES
|
| 726 |
-
uniq_smiles = rows["smiles"].drop_duplicates().tolist()
|
| 727 |
-
smiles_graph_cache: Dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = {}
|
| 728 |
-
for smi in uniq_smiles:
|
| 729 |
-
try:
|
| 730 |
-
x, edge_index, edge_attr = featurize_smiles(smi)
|
| 731 |
-
smiles_graph_cache[smi] = (x, edge_index, edge_attr)
|
| 732 |
-
except Exception as e:
|
| 733 |
-
warnings.warn(f"[data_builder] Dropping SMILES due to RDKit parse error: {smi} ({e})")
|
| 734 |
-
|
| 735 |
-
# Filter rows to those that featurized successfully
|
| 736 |
-
rows = rows[rows["smiles"].isin(smiles_graph_cache.keys())].reset_index(drop=True)
|
| 737 |
-
|
| 738 |
-
# Re-map indices after filtering using smiles membership
|
| 739 |
-
train_idx = rows.index[rows["smiles"].isin(set(rows.iloc[train_idx]["smiles"]))].to_numpy()
|
| 740 |
-
val_idx = rows.index[rows["smiles"].isin(set(rows.iloc[val_idx]["smiles"]))].to_numpy()
|
| 741 |
-
test_idx = rows.index[rows["smiles"].isin(set(rows.iloc[test_idx]["smiles"]))].to_numpy()
|
| 742 |
-
|
| 743 |
-
# Optional subsampling (train only) for a specific (target, fidelity) block
|
| 744 |
-
train_idx = subsample_train_indices(
|
| 745 |
-
rows,
|
| 746 |
-
train_idx,
|
| 747 |
-
target=subsample_target,
|
| 748 |
-
fidelity=subsample_fidelity,
|
| 749 |
-
pct=float(subsample_pct),
|
| 750 |
-
seed=int(subsample_seed),
|
| 751 |
-
)
|
| 752 |
-
|
| 753 |
-
# Fit scaler on training split only
|
| 754 |
-
def build_y_mask(df_slice: pd.DataFrame) -> tuple[torch.Tensor, torch.Tensor]:
|
| 755 |
-
ys, ms = [], []
|
| 756 |
-
for _, r in df_slice.iterrows():
|
| 757 |
-
yv, mv = [], []
|
| 758 |
-
for t in targets:
|
| 759 |
-
v = r[t]
|
| 760 |
-
if pd.isna(v):
|
| 761 |
-
yv.append(np.nan)
|
| 762 |
-
mv.append(False)
|
| 763 |
-
else:
|
| 764 |
-
yv.append(float(v))
|
| 765 |
-
mv.append(True)
|
| 766 |
-
ys.append(yv)
|
| 767 |
-
ms.append(mv)
|
| 768 |
-
y = torch.tensor(np.array(ys, dtype=np.float32))
|
| 769 |
-
mask = torch.tensor(np.array(ms, dtype=np.bool_))
|
| 770 |
-
return y, mask
|
| 771 |
-
|
| 772 |
-
y_train, mask_train = build_y_mask(rows.iloc[train_idx])
|
| 773 |
-
|
| 774 |
-
# Decide transforms per task
|
| 775 |
-
if explicit_log_targets:
|
| 776 |
-
explicit_set = set(explicit_log_targets)
|
| 777 |
-
transforms = [("log10" if t in explicit_set else "identity") for t in targets]
|
| 778 |
-
eps_vec = None # will be auto-chosen in scaler.fit if not provided
|
| 779 |
-
elif auto_log:
|
| 780 |
-
transforms, eps_vec = auto_select_task_transforms(
|
| 781 |
-
y_train,
|
| 782 |
-
mask_train,
|
| 783 |
-
targets,
|
| 784 |
-
min_pos_frac=float(log_min_pos_frac),
|
| 785 |
-
orders_threshold=float(log_orders_threshold),
|
| 786 |
-
)
|
| 787 |
-
else:
|
| 788 |
-
transforms, eps_vec = (["identity"] * len(targets), None)
|
| 789 |
-
|
| 790 |
-
scaler = TargetScaler(transforms=transforms, eps=eps_vec)
|
| 791 |
-
scaler.fit(y_train, mask_train)
|
| 792 |
-
|
| 793 |
-
# Build datasets
|
| 794 |
-
train_rows = rows.iloc[train_idx].reset_index(drop=True)
|
| 795 |
-
val_rows = rows.iloc[val_idx].reset_index(drop=True)
|
| 796 |
-
test_rows = rows.iloc[test_idx].reset_index(drop=True)
|
| 797 |
-
|
| 798 |
-
train_ds = MultiFidelityMoleculeDataset(train_rows, targets, scaler, smiles_graph_cache)
|
| 799 |
-
val_ds = MultiFidelityMoleculeDataset(val_rows, targets, scaler, smiles_graph_cache)
|
| 800 |
-
test_ds = MultiFidelityMoleculeDataset(test_rows, targets, scaler, smiles_graph_cache)
|
| 801 |
-
|
| 802 |
-
return train_ds, val_ds, test_ds, scaler
|
| 803 |
-
|
| 804 |
-
|
| 805 |
-
__all__ = [
|
| 806 |
-
"build_dataset_from_dir",
|
| 807 |
-
"discover_target_fid_csvs",
|
| 808 |
-
"read_target_csv",
|
| 809 |
-
"build_long_table",
|
| 810 |
-
"pivot_to_rows_by_smiles_fid",
|
| 811 |
-
"grouped_split_by_smiles",
|
| 812 |
-
"TargetScaler",
|
| 813 |
-
"MultiFidelityMoleculeDataset",
|
| 814 |
-
"atom_features",
|
| 815 |
-
"bond_features",
|
| 816 |
-
"featurize_smiles",
|
| 817 |
-
"auto_select_task_transforms",
|
| 818 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/discover_llm.py
DELETED
|
@@ -1,829 +0,0 @@
|
|
| 1 |
-
# src/discovery.py
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
import json
|
| 5 |
-
from dataclasses import dataclass
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Callable, Dict, List, Optional, Tuple
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import pandas as pd
|
| 11 |
-
from rdkit import Chem, DataStructs
|
| 12 |
-
from rdkit.Chem import AllChem
|
| 13 |
-
from . import sascorer
|
| 14 |
-
|
| 15 |
-
# Reuse your canonicalizer if you want; otherwise keep local
|
| 16 |
-
def canonicalize_smiles(smiles: str) -> Optional[str]:
|
| 17 |
-
s = (smiles or "").strip()
|
| 18 |
-
if not s:
|
| 19 |
-
return None
|
| 20 |
-
m = Chem.MolFromSmiles(s)
|
| 21 |
-
if m is None:
|
| 22 |
-
return None
|
| 23 |
-
return Chem.MolToSmiles(m, canonical=True)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# -------------------------
|
| 27 |
-
# Spec schema (minimal v0)
|
| 28 |
-
# -------------------------
|
| 29 |
-
@dataclass
|
| 30 |
-
class DiscoverySpec:
|
| 31 |
-
dataset: List[str] # ["PI1M_PROPERTY.parquet", "POLYINFO_PROPERTY.parquet"]
|
| 32 |
-
polyinfo: str # "POLYINFO_PROPERTY.parquet"
|
| 33 |
-
polyinfo_csv: str # "POLYINFO.csv"
|
| 34 |
-
|
| 35 |
-
hard_constraints: Dict[str, Dict[str, float]] # { "tg": {"min": 400}, "tc": {"max": 0.3} }
|
| 36 |
-
objectives: List[Dict[str, str]] # [{"property":"cp","goal":"maximize"}, ...]
|
| 37 |
-
|
| 38 |
-
max_pool: int = 200000 # legacy (kept for compatibility; aligned to pareto_max)
|
| 39 |
-
pareto_max: int = 50000 # cap points used for Pareto + diversity fingerprinting
|
| 40 |
-
max_candidates: int = 30 # final output size
|
| 41 |
-
max_pareto_fronts: int = 5 # how many Pareto layers to keep for candidate pool
|
| 42 |
-
min_distance: float = 0.30 # diversity threshold in Tanimoto distance
|
| 43 |
-
fingerprint: str = "morgan" # morgan only for now
|
| 44 |
-
random_seed: int = 7
|
| 45 |
-
use_canonical_smiles: bool = True
|
| 46 |
-
use_full_data: bool = False
|
| 47 |
-
trust_weights: Dict[str, float] | None = None
|
| 48 |
-
selection_weights: Dict[str, float] | None = None
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
# -------------------------
|
| 52 |
-
# Property metadata (local to discovery_llm)
|
| 53 |
-
# -------------------------
|
| 54 |
-
PROPERTY_META: Dict[str, Dict[str, str]] = {
|
| 55 |
-
# Thermal
|
| 56 |
-
"tm": {"name": "Melting temperature", "unit": "K"},
|
| 57 |
-
"tg": {"name": "Glass transition temperature", "unit": "K"},
|
| 58 |
-
"td": {"name": "Thermal diffusivity", "unit": "m^2/s"},
|
| 59 |
-
"tc": {"name": "Thermal conductivity", "unit": "W/m-K"},
|
| 60 |
-
"cp": {"name": "Specific heat capacity", "unit": "J/kg-K"},
|
| 61 |
-
# Mechanical
|
| 62 |
-
"young": {"name": "Young's modulus", "unit": "GPa"},
|
| 63 |
-
"shear": {"name": "Shear modulus", "unit": "GPa"},
|
| 64 |
-
"bulk": {"name": "Bulk modulus", "unit": "GPa"},
|
| 65 |
-
"poisson": {"name": "Poisson ratio", "unit": "-"},
|
| 66 |
-
# Transport
|
| 67 |
-
"visc": {"name": "Viscosity", "unit": "Pa-s"},
|
| 68 |
-
"dif": {"name": "Diffusivity", "unit": "cm^2/s"},
|
| 69 |
-
# Gas permeability
|
| 70 |
-
"phe": {"name": "He permeability", "unit": "Barrer"},
|
| 71 |
-
"ph2": {"name": "H2 permeability", "unit": "Barrer"},
|
| 72 |
-
"pco2": {"name": "CO2 permeability", "unit": "Barrer"},
|
| 73 |
-
"pn2": {"name": "N2 permeability", "unit": "Barrer"},
|
| 74 |
-
"po2": {"name": "O2 permeability", "unit": "Barrer"},
|
| 75 |
-
"pch4": {"name": "CH4 permeability", "unit": "Barrer"},
|
| 76 |
-
# Electronic / Optical
|
| 77 |
-
"alpha": {"name": "Polarizability", "unit": "a.u."},
|
| 78 |
-
"homo": {"name": "HOMO energy", "unit": "eV"},
|
| 79 |
-
"lumo": {"name": "LUMO energy", "unit": "eV"},
|
| 80 |
-
"bandgap": {"name": "Band gap", "unit": "eV"},
|
| 81 |
-
"mu": {"name": "Dipole moment", "unit": "Debye"},
|
| 82 |
-
"etotal": {"name": "Total electronic energy", "unit": "eV"},
|
| 83 |
-
"ri": {"name": "Refractive index", "unit": "-"},
|
| 84 |
-
"dc": {"name": "Dielectric constant", "unit": "-"},
|
| 85 |
-
"pe": {"name": "Permittivity", "unit": "-"},
|
| 86 |
-
# Structural / Physical
|
| 87 |
-
"rg": {"name": "Radius of gyration", "unit": "A"},
|
| 88 |
-
"rho": {"name": "Density", "unit": "g/cm^3"},
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
# -------------------------
|
| 93 |
-
# Column mapping
|
| 94 |
-
# -------------------------
|
| 95 |
-
def mean_col(prop_key: str) -> str:
|
| 96 |
-
return f"mean_{prop_key.lower()}"
|
| 97 |
-
|
| 98 |
-
def std_col(prop_key: str) -> str:
|
| 99 |
-
return f"std_{prop_key.lower()}"
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
def normalize_weights(weights: Dict[str, float], defaults: Dict[str, float]) -> Dict[str, float]:
|
| 103 |
-
out: Dict[str, float] = {}
|
| 104 |
-
for k, v in defaults.items():
|
| 105 |
-
try:
|
| 106 |
-
vv = float(weights.get(k, v))
|
| 107 |
-
except Exception:
|
| 108 |
-
vv = float(v)
|
| 109 |
-
out[k] = max(0.0, vv)
|
| 110 |
-
s = float(sum(out.values()))
|
| 111 |
-
if s <= 0.0:
|
| 112 |
-
return defaults.copy()
|
| 113 |
-
return {k: float(v / s) for k, v in out.items()}
|
| 114 |
-
|
| 115 |
-
def spec_from_dict(obj: dict, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
|
| 116 |
-
pareto_max = int(obj.get("pareto_max", 50000))
|
| 117 |
-
return DiscoverySpec(
|
| 118 |
-
dataset=list(dataset_path),
|
| 119 |
-
polyinfo=polyinfo_path,
|
| 120 |
-
polyinfo_csv=polyinfo_csv_path,
|
| 121 |
-
hard_constraints=obj.get("hard_constraints", {}),
|
| 122 |
-
objectives=obj.get("objectives", []),
|
| 123 |
-
# Legacy field kept for compatibility; effectively collapsed to pareto_max.
|
| 124 |
-
max_pool=pareto_max,
|
| 125 |
-
pareto_max=pareto_max,
|
| 126 |
-
max_candidates=int(obj.get("max_candidates", 30)),
|
| 127 |
-
max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)),
|
| 128 |
-
min_distance=float(obj.get("min_distance", 0.30)),
|
| 129 |
-
fingerprint=str(obj.get("fingerprint", "morgan")),
|
| 130 |
-
random_seed=int(obj.get("random_seed", 7)),
|
| 131 |
-
use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)),
|
| 132 |
-
use_full_data=bool(obj.get("use_full_data", False)),
|
| 133 |
-
trust_weights=obj.get("trust_weights"),
|
| 134 |
-
selection_weights=obj.get("selection_weights"),
|
| 135 |
-
)
|
| 136 |
-
|
| 137 |
-
# -------------------------
|
| 138 |
-
# Parquet loading (safe)
|
| 139 |
-
# -------------------------
|
| 140 |
-
def load_parquet_columns(path: str | List[str], columns: List[str]) -> pd.DataFrame:
|
| 141 |
-
"""
|
| 142 |
-
Load only requested columns from Parquet (critical for 1M rows).
|
| 143 |
-
Accepts a single path or a list of paths and concatenates rows.
|
| 144 |
-
"""
|
| 145 |
-
def _load_one(fp: str, req_cols: List[str]) -> pd.DataFrame:
|
| 146 |
-
available: list[str]
|
| 147 |
-
try:
|
| 148 |
-
import pyarrow.parquet as pq
|
| 149 |
-
|
| 150 |
-
pf = pq.ParquetFile(fp)
|
| 151 |
-
available = [str(c) for c in pf.schema.names]
|
| 152 |
-
except Exception:
|
| 153 |
-
# If schema probing fails, fall back to direct read with requested columns.
|
| 154 |
-
return pd.read_parquet(fp, columns=req_cols)
|
| 155 |
-
|
| 156 |
-
available_set = set(available)
|
| 157 |
-
lower_to_actual = {c.lower(): c for c in available}
|
| 158 |
-
|
| 159 |
-
# Resolve requested names against actual parquet schema.
|
| 160 |
-
resolved: dict[str, str] = {}
|
| 161 |
-
for req in req_cols:
|
| 162 |
-
if req in available_set:
|
| 163 |
-
resolved[req] = req
|
| 164 |
-
continue
|
| 165 |
-
alt = lower_to_actual.get(str(req).lower())
|
| 166 |
-
if alt is not None:
|
| 167 |
-
resolved[req] = alt
|
| 168 |
-
|
| 169 |
-
use_cols = sorted(set(resolved.values()))
|
| 170 |
-
if not use_cols:
|
| 171 |
-
return pd.DataFrame(columns=req_cols)
|
| 172 |
-
|
| 173 |
-
out = pd.read_parquet(fp, columns=use_cols)
|
| 174 |
-
for req in req_cols:
|
| 175 |
-
src = resolved.get(req)
|
| 176 |
-
if src is None:
|
| 177 |
-
out[req] = np.nan
|
| 178 |
-
elif src != req:
|
| 179 |
-
out[req] = out[src]
|
| 180 |
-
return out[req_cols]
|
| 181 |
-
|
| 182 |
-
if isinstance(path, (list, tuple)):
|
| 183 |
-
frames = [_load_one(p, columns) for p in path]
|
| 184 |
-
if not frames:
|
| 185 |
-
return pd.DataFrame(columns=columns)
|
| 186 |
-
return pd.concat(frames, ignore_index=True)
|
| 187 |
-
return _load_one(path, columns)
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
def normalize_smiles(smiles: str, use_canonical_smiles: bool) -> Optional[str]:
|
| 191 |
-
s = (smiles or "").strip()
|
| 192 |
-
if not s:
|
| 193 |
-
return None
|
| 194 |
-
if not use_canonical_smiles:
|
| 195 |
-
# Skip RDKit parsing entirely in fast mode.
|
| 196 |
-
return s
|
| 197 |
-
m = Chem.MolFromSmiles(s)
|
| 198 |
-
if m is None:
|
| 199 |
-
return None
|
| 200 |
-
if use_canonical_smiles:
|
| 201 |
-
return Chem.MolToSmiles(m, canonical=True)
|
| 202 |
-
return s
|
| 203 |
-
|
| 204 |
-
|
| 205 |
-
def load_polyinfo_index(polyinfo_csv_path: str, use_canonical_smiles: bool = True) -> pd.DataFrame:
|
| 206 |
-
"""
|
| 207 |
-
Expected CSV columns: SMILES, Polymer_Class, polymer_name (or common variants).
|
| 208 |
-
Returns dataframe with index on smiles_key and columns polymer_name/polymer_class.
|
| 209 |
-
"""
|
| 210 |
-
df = pd.read_csv(polyinfo_csv_path)
|
| 211 |
-
|
| 212 |
-
# normalize column names
|
| 213 |
-
cols = {c: c for c in df.columns}
|
| 214 |
-
# map typical names
|
| 215 |
-
if "SMILES" in cols:
|
| 216 |
-
df = df.rename(columns={"SMILES": "smiles"})
|
| 217 |
-
elif "smiles" not in df.columns:
|
| 218 |
-
raise ValueError(f"{polyinfo_csv_path} missing SMILES/smiles column")
|
| 219 |
-
|
| 220 |
-
if "Polymer_Name" in df.columns:
|
| 221 |
-
df = df.rename(columns={"Polymer_Name": "polymer_name"})
|
| 222 |
-
if "polymer_Name" in df.columns:
|
| 223 |
-
df = df.rename(columns={"polymer_Name": "polymer_name"})
|
| 224 |
-
if "Polymer_Class" in df.columns:
|
| 225 |
-
df = df.rename(columns={"Polymer_Class": "polymer_class"})
|
| 226 |
-
|
| 227 |
-
if "polymer_name" not in df.columns:
|
| 228 |
-
df["polymer_name"] = pd.NA
|
| 229 |
-
if "polymer_class" not in df.columns:
|
| 230 |
-
df["polymer_class"] = pd.NA
|
| 231 |
-
|
| 232 |
-
df["smiles_key"] = df["smiles"].astype(str).map(lambda s: normalize_smiles(s, use_canonical_smiles))
|
| 233 |
-
df = df.dropna(subset=["smiles_key"]).drop_duplicates("smiles_key")
|
| 234 |
-
df = df.set_index("smiles_key", drop=True)
|
| 235 |
-
return df[["polymer_name", "polymer_class"]]
|
| 236 |
-
|
| 237 |
-
|
| 238 |
-
# -------------------------
|
| 239 |
-
# Pareto (2–3 objectives)
|
| 240 |
-
# -------------------------
|
| 241 |
-
def pareto_front_mask(X: np.ndarray) -> np.ndarray:
|
| 242 |
-
"""
|
| 243 |
-
Returns mask for nondominated points.
|
| 244 |
-
X: (N, M), all objectives assumed to be minimized.
|
| 245 |
-
For maximize objectives, we invert before calling this.
|
| 246 |
-
"""
|
| 247 |
-
N = X.shape[0]
|
| 248 |
-
is_efficient = np.ones(N, dtype=bool)
|
| 249 |
-
for i in range(N):
|
| 250 |
-
if not is_efficient[i]:
|
| 251 |
-
continue
|
| 252 |
-
# any point that is <= in all dims and < in at least one dominates
|
| 253 |
-
dominates = np.all(X <= X[i], axis=1) & np.any(X < X[i], axis=1)
|
| 254 |
-
# if a point dominates i, mark i inefficient
|
| 255 |
-
if np.any(dominates):
|
| 256 |
-
is_efficient[i] = False
|
| 257 |
-
continue
|
| 258 |
-
# otherwise, i may dominate others
|
| 259 |
-
dominated_by_i = np.all(X[i] <= X, axis=1) & np.any(X[i] < X, axis=1)
|
| 260 |
-
is_efficient[dominated_by_i] = False
|
| 261 |
-
is_efficient[i] = True
|
| 262 |
-
return is_efficient
|
| 263 |
-
|
| 264 |
-
|
| 265 |
-
def pareto_layers(X: np.ndarray, max_layers: int = 10) -> np.ndarray:
    """
    Peel successive Pareto fronts and return the layer index of each row:
    1 = first (best) front, 2 = second, and so on.

    Rows not reached within `max_layers` peels keep layer 0.
    """
    out = np.zeros(X.shape[0], dtype=int)
    pool = np.arange(X.shape[0])
    for depth in range(1, max_layers + 1):
        if pool.size == 0:
            break
        front = pareto_front_mask(X[pool])
        out[pool[front]] = depth
        pool = pool[~front]
    return out
|
| 282 |
-
|
| 283 |
-
|
| 284 |
-
def pareto_front_mask_chunked(
    X: np.ndarray,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> np.ndarray:
    """
    Exact global Pareto-front mask computed chunk by chunk.

    Strategy (exact for the first front):
      1) compute the exact local front inside each chunk,
      2) union the local fronts,
      3) run one exact pass over the (much smaller) union.

    progress_callback(done_chunks, total_chunks) fires after each chunk.
    """
    n_points = X.shape[0]
    # Small inputs need no chunking at all.
    if n_points <= chunk_size:
        if progress_callback is not None:
            progress_callback(1, 1)
        return pareto_front_mask(X)

    n_chunks = (n_points + chunk_size - 1) // chunk_size
    survivors = []
    for chunk_no, lo in enumerate(range(0, n_points, chunk_size), start=1):
        hi = min(lo + chunk_size, n_points)
        chunk_idx = np.arange(lo, hi)
        local = pareto_front_mask(X[chunk_idx])
        survivors.append(chunk_idx[local])
        if progress_callback is not None:
            progress_callback(chunk_no, n_chunks)

    # Defensive: unreachable when n_points > chunk_size, kept for safety.
    if not survivors:
        return np.zeros(n_points, dtype=bool)

    pooled = np.concatenate(survivors)
    global_front = pooled[pareto_front_mask(X[pooled])]

    mask = np.zeros(n_points, dtype=bool)
    mask[global_front] = True
    return mask
|
| 324 |
-
|
| 325 |
-
|
| 326 |
-
def pareto_layers_chunked(
    X: np.ndarray,
    max_layers: int = 10,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int, int], None]] = None,
) -> np.ndarray:
    """
    Exact Pareto layers via repeated exact chunked front extraction.

    progress_callback(layer, done_chunks, total_chunks) reports progress
    within each peel.
    """
    out = np.zeros(X.shape[0], dtype=int)
    pool = np.arange(X.shape[0])
    depth = 1

    while pool.size > 0 and depth <= max_layers:
        def forward(done: int, total: int) -> None:
            # `depth` is read at call time; safe because the callback only
            # fires synchronously inside this iteration of the while loop.
            if progress_callback is not None:
                progress_callback(depth, done, total)

        front = pareto_front_mask_chunked(X[pool], chunk_size=chunk_size, progress_callback=forward)
        out[pool[front]] = depth
        pool = pool[~front]
        depth += 1

    return out
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
# -------------------------
|
| 355 |
-
# Fingerprints & diversity
|
| 356 |
-
# -------------------------
|
| 357 |
-
def morgan_fp(smiles: str, radius: int = 2, nbits: int = 2048):
    """Morgan bit-vector fingerprint for a SMILES string, or None if unparsable."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits)
|
| 362 |
-
|
| 363 |
-
def tanimoto_distance(fp1, fp2) -> float:
    """Tanimoto distance (1 - similarity) between two RDKit fingerprints."""
    similarity = DataStructs.TanimotoSimilarity(fp1, fp2)
    return 1.0 - similarity
|
| 365 |
-
|
| 366 |
-
def greedy_diverse_select(
    smiles_list: List[str],
    scores: np.ndarray,
    max_k: int,
    min_dist: float,
) -> List[int]:
    """
    Greedy diversity selection by descending score.

    Candidates are visited best-score-first; one is kept only if its Tanimoto
    distance to every already-selected fingerprint is at least `min_dist`.
    At most `max_k` indices are returned.

    Args:
        smiles_list: candidate SMILES (unparsable entries are skipped).
        scores: one score per entry of smiles_list; higher is better.
        max_k: maximum number of selections.
        min_dist: minimum pairwise Tanimoto distance among selections.

    Returns:
        Indices into smiles_list of the selected candidates.
    """
    # Fix: the original appended one candidate before checking the size cap,
    # so max_k <= 0 still returned a single item. Guard explicitly.
    if max_k <= 0:
        return []

    fps = []
    valid_idx = []
    for i, s in enumerate(smiles_list):
        fp = morgan_fp(s)
        if fp is not None:
            fps.append(fp)
            valid_idx.append(i)

    if not valid_idx:
        return []

    # rank candidates (higher score first)
    order = np.argsort(-scores[valid_idx])
    selected_global: List[int] = []
    selected_fps = []

    for oi in order:
        fp_i = fps[oi]  # fps is aligned with valid_idx
        # reject anything too close to an already-selected fingerprint
        if any(tanimoto_distance(fp_i, fp_j) < min_dist for fp_j in selected_fps):
            continue
        selected_global.append(valid_idx[oi])
        selected_fps.append(fp_i)
        if len(selected_global) >= max_k:
            break

    return selected_global
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
# -------------------------
|
| 410 |
-
# Trust score (lightweight, robust)
|
| 411 |
-
# -------------------------
|
| 412 |
-
def internal_consistency_penalty(row: pd.Series) -> float:
    """
    Fraction of violated sanity checks for one candidate row, in [0, 1].

    Checks run only when the relevant mean_* column is present and non-null;
    when no check applies the penalty is 0.0. Rules are intentionally simple
    and can be extended later.
    """
    checks_total = 0
    checks_failed = 0

    def record(passed: bool) -> None:
        nonlocal checks_total, checks_failed
        checks_total += 1
        if not passed:
            checks_failed += 1

    # positivity checks if present
    for prop in ["cp", "tc", "rho", "dif", "visc", "tg", "tm", "bandgap"]:
        col = mean_col(prop)
        if col in row.index and pd.notna(row[col]):
            value = float(row[col])
            # temperatures / bandgap may legitimately be zero; others must be > 0
            record(value >= 0.0 if prop in ["bandgap", "tg", "tm"] else value > 0.0)

    # Poisson ratio must stay within its physical bounds
    poisson_c = mean_col("poisson")
    if poisson_c in row.index and pd.notna(row[poisson_c]):
        record(0.0 <= float(row[poisson_c]) <= 0.5)

    # glass transition should not exceed melting point
    tg_c = mean_col("tg")
    tm_c = mean_col("tm")
    if tg_c in row.index and tm_c in row.index:
        if pd.notna(row[tg_c]) and pd.notna(row[tm_c]):
            record(float(row[tg_c]) <= float(row[tm_c]))

    if checks_total == 0:
        return 0.0
    return checks_failed / checks_total
|
| 448 |
-
|
| 449 |
-
|
| 450 |
-
def synthesizability_score(smiles: str) -> float:
    """
    Synthesizability proxy in [0, 1] derived from the RDKit SA-score.

    SA-score spans roughly 1 (easy) .. 10 (hard); it is mapped linearly so
    that 1 -> 1.0 and 10 -> 0.0. Unparsable SMILES or scorer failures yield 0.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return 0.0

    # The scorer can raise (or return None) on edge-case molecules.
    try:
        raw = sascorer.calculateScore(mol)
    except Exception:
        return 0.0
    if raw is None:
        return 0.0

    mapped = 1.0 - (float(raw) - 1.0) / 9.0  # linear map of ~1..10 onto [0, 1]
    return float(np.clip(mapped, 0.0, 1.0))
|
| 471 |
-
|
| 472 |
-
|
| 473 |
-
def compute_trust_scores(
    df: pd.DataFrame,
    real_fps: List,
    real_smiles: List[str],
    trust_weights: Dict[str, float] | None = None,
) -> np.ndarray:
    """
    Trust score in [0,1] (higher = more trustworthy / lower risk).

    Components (weighted, weights normalized to sum 1):
      - distance to nearest real polymer (fingerprint distance)
      - internal consistency penalty
      - uncertainty penalty (if std columns exist)
      - synthesizability

    Args:
        df: candidate rows; must carry a "smiles_key" (or "smiles_canon")
            column and any mean_*/std_* columns to be checked.
        real_fps: fingerprints of known real polymers (may be empty).
        real_smiles: SMILES of known real polymers.
            NOTE(review): this parameter is not read in the body below —
            kept for interface compatibility; confirm with callers.
        trust_weights: optional per-component weight overrides.

    Returns:
        np.ndarray of shape (len(df),) with one trust score per row.
    """
    N = len(df)
    trust = np.zeros(N, dtype=float)
    # Default component weights; user overrides are renormalized to sum 1.
    tw_defaults = {"real": 0.45, "consistency": 0.25, "uncertainty": 0.10, "synth": 0.20}
    tw = normalize_weights(trust_weights or {}, tw_defaults)

    # nearest-real distance (expensive if done naively)
    # We do it only for the (small) post-filter set, which is safe.
    smiles_col = "smiles_key" if "smiles_key" in df.columns else "smiles_canon"
    for i in range(N):
        s = df.iloc[i][smiles_col]
        fp = morgan_fp(s)
        if fp is None or not real_fps:
            # Unfingerprintable candidate or no reference set: treat as
            # maximally far from any real polymer.
            d_real = 1.0
        else:
            sims = DataStructs.BulkTanimotoSimilarity(fp, real_fps)
            d_real = 1.0 - float(max(sims))  # distance to nearest

        # internal consistency
        pen_cons = internal_consistency_penalty(df.iloc[i])

        # uncertainty: average normalized std for any std_* columns present
        std_cols = [c for c in df.columns if c.startswith("std_")]
        if std_cols:
            std_vals = df.iloc[i][std_cols].astype(float)
            std_vals = std_vals.replace([np.inf, -np.inf], np.nan).dropna()
            # x / (x + 1) maps mean std in [0, inf) smoothly into [0, 1)
            pen_unc = float(np.clip(std_vals.mean() / (std_vals.mean() + 1.0), 0.0, 1.0)) if len(std_vals) else 0.0
        else:
            pen_unc = 0.0

        # synthesizability heuristic
        s_syn = synthesizability_score(s)

        # Combine (tunable weights)
        # lower distance to real is better -> convert to score
        s_real = 1.0 - np.clip(d_real, 0.0, 1.0)

        trust[i] = (
            tw["real"] * s_real +
            tw["consistency"] * (1.0 - pen_cons) +
            tw["uncertainty"] * (1.0 - pen_unc) +
            tw["synth"] * s_syn
        )

    trust = np.clip(trust, 0.0, 1.0)
    return trust
|
| 532 |
-
|
| 533 |
-
|
| 534 |
-
# -------------------------
|
| 535 |
-
# Main pipeline
|
| 536 |
-
# -------------------------
|
| 537 |
-
def run_discovery(
    spec: DiscoverySpec,
    progress_callback: Optional[Callable[[str, float], None]] = None,
) -> Tuple[pd.DataFrame, Dict[str, float], pd.DataFrame]:
    """
    End-to-end candidate discovery pipeline.

    Steps: load only the needed parquet columns, normalize SMILES, apply
    hard min/max constraints, peel Pareto layers on the objectives, score
    candidates for trust, and pick a diverse final set.

    Args:
        spec: run configuration (datasets, constraints, objectives, knobs).
        progress_callback: optional fn(step_label, fraction_done in [0, 1]).

    Returns:
        (selected_df, stats, plot_df) where selected_df holds the final
        candidates, stats is a dict of pipeline counts, and plot_df carries
        objective columns + pareto_layer for plotting.
    """
    def report(step: str, pct: float) -> None:
        # No-op when the caller did not ask for progress updates.
        if progress_callback is not None:
            progress_callback(step, pct)

    rng = np.random.default_rng(spec.random_seed)

    # 1) Determine required columns
    report("Preparing columns…", 0.02)
    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]

    needed_props = sorted(set(obj_props + cons_props))
    cols = ["SMILES"] + [mean_col(p) for p in needed_props]

    # include std columns if available (not required, but used for trust)
    std_cols = [std_col(p) for p in needed_props]
    cols += std_cols

    # 2) Load only needed columns
    report("Loading data from parquet…", 0.05)
    df = load_parquet_columns(spec.dataset, columns=[c for c in cols if c != "SMILES"] + ["SMILES"])
    # normalize
    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})
    normalize_step = "Canonicalizing SMILES…" if spec.use_canonical_smiles else "Skipping SMILES normalization…"
    report(normalize_step, 0.10)
    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # 3) Hard constraints
    report("Applying constraints…", 0.22)
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            # if missing, nothing can satisfy
            df = df.iloc[0:0]
            break
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    n_after = len(df)
    if n_after == 0:
        # Nothing survives the constraints: return empty results early.
        empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0}
        return df, empty_stats, pd.DataFrame()

    n_pool = len(df)

    # 5) Prepare objective matrix for Pareto
    report("Building objective matrix…", 0.30)
    # convert to minimization: maximize => negate
    X = []
    resolved_objectives = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        if c not in df.columns:
            continue
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
        resolved_objectives.append({"property": prop, "goal": goal})
    if not X:
        # Fallback to first available mean_* column to keep pipeline runnable.
        fallback_col = next((c for c in df.columns if str(c).startswith("mean_")), None)
        if fallback_col is None:
            empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0}
            return df.iloc[0:0], empty_stats, pd.DataFrame()
        X = [df[fallback_col].to_numpy(dtype=float) * -1.0]
        resolved_objectives = [{"property": fallback_col.replace("mean_", ""), "goal": "maximize"}]
    X = np.stack(X, axis=1)  # (N, M)
    obj_props = [o["property"] for o in resolved_objectives]

    # Pareto cap before computing layers (optional safety)
    if spec.use_full_data:
        report("Using full dataset (no Pareto cap)…", 0.35)
    elif len(df) > spec.pareto_max:
        # Random subsample keeps Pareto + fingerprinting tractable.
        idx = rng.choice(len(df), size=spec.pareto_max, replace=False)
        df = df.iloc[idx].reset_index(drop=True)
        X = X[idx]

    # 6) Pareto layers (only 5 layers needed for candidate pool)
    report("Computing Pareto layers…", 0.40)
    pareto_start = 0.40
    pareto_end = 0.54
    max_layers_for_pool = max(1, int(spec.max_pareto_fronts))
    pareto_chunk_ref = {"chunks_per_layer": None}

    def on_pareto_chunk(layer_i: int, done_chunks: int, total_chunks: int) -> None:
        # Translate (layer, chunk) progress into a single fraction in
        # [pareto_start, pareto_end] for the overall progress bar.
        if pareto_chunk_ref["chunks_per_layer"] is None:
            pareto_chunk_ref["chunks_per_layer"] = max(1, int(total_chunks))
        ref_chunks = pareto_chunk_ref["chunks_per_layer"]
        total_units = max_layers_for_pool * ref_chunks
        done_units = min(total_units, ((layer_i - 1) * ref_chunks) + done_chunks)
        pareto_pct = int(round(100.0 * done_units / max(1, total_units)))

        layer_progress = done_chunks / max(1, total_chunks)
        overall = ((layer_i - 1) + layer_progress) / max_layers_for_pool
        pct = pareto_start + (pareto_end - pareto_start) * min(1.0, max(0.0, overall))
        report(
            f"Computing Pareto layers… {pareto_pct}% (Layer {layer_i}/{max_layers_for_pool}, chunk {done_chunks}/{total_chunks})",
            pct,
        )

    layers = pareto_layers_chunked(
        X,
        max_layers=max_layers_for_pool,
        chunk_size=100000,
        progress_callback=on_pareto_chunk,
    )
    report("Computing Pareto layers…", pareto_end)
    df["pareto_layer"] = layers
    plot_df = df[["smiles_key"] + [mean_col(p) for p in obj_props] + ["pareto_layer"]].copy()
    plot_df = plot_df.rename(columns={"smiles_key": "SMILES"})

    # Keep first few layers as candidate pool (avoid huge set)
    cand = df[df["pareto_layer"].between(1, max_layers_for_pool)].copy()
    if cand.empty:
        cand = df[df["pareto_layer"] == 1].copy()
    cand = cand.reset_index(drop=True)
    n_pareto = len(cand)

    # 7) Load real polymer metadata and fingerprints (from POLYINFO.csv)
    report("Loading POLYINFO index…", 0.55)
    polyinfo = load_polyinfo_index(spec.polyinfo_csv, use_canonical_smiles=spec.use_canonical_smiles)
    real_smiles = polyinfo.index.to_list()

    report("Building real-polymer fingerprints…", 0.60)
    real_fps = []
    for s in real_smiles:
        fp = morgan_fp(s)
        if fp is not None:
            real_fps.append(fp)

    # 8) Trust score on candidate pool (safe size)
    report("Computing trust scores…", 0.70)
    trust = compute_trust_scores(
        cand,
        real_fps=real_fps,
        real_smiles=real_smiles,
        trust_weights=spec.trust_weights,
    )
    cand["trust_score"] = trust

    # 9) Diversity selection on candidate pool
    report("Diversity selection…", 0.88)
    # score for selection: prioritize Pareto layer 1 then trust
    # higher is better
    sw_defaults = {"pareto": 0.60, "trust": 0.40}
    sw = normalize_weights(spec.selection_weights or {}, sw_defaults)
    # Layer 1 gets bonus 1.0, deeper layers get proportionally less.
    pareto_bonus = (
        (max_layers_for_pool + 1) - np.clip(cand["pareto_layer"].to_numpy(dtype=int), 1, max_layers_for_pool)
    ) / float(max_layers_for_pool)
    sel_score = sw["pareto"] * pareto_bonus + sw["trust"] * cand["trust_score"].to_numpy(dtype=float)

    chosen_idx = greedy_diverse_select(
        smiles_list=cand["smiles_key"].tolist(),
        scores=sel_score,
        max_k=spec.max_candidates,
        min_dist=spec.min_distance,
    )
    out = cand.iloc[chosen_idx].copy().reset_index(drop=True)

    # 10) Attach Polymer_Name/Class if available (only for matches)
    report("Finalizing results…", 0.96)
    out = out.set_index("smiles_key", drop=False)
    out = out.join(polyinfo, how="left")
    out = out.reset_index(drop=True)

    # 11) Make a clean output bundle with requested columns
    # Keep SMILES (canonical), name/class, pareto layer, trust score, properties used
    keep = ["smiles_key", "polymer_name", "polymer_class", "pareto_layer", "trust_score"]
    for p in needed_props:
        mc = mean_col(p)
        sc = std_col(p)
        if mc in out.columns:
            keep.append(mc)
        if sc in out.columns:
            keep.append(sc)

    out = out[keep].rename(columns={"smiles_key": "SMILES"})

    # NOTE(review): "n_total" uses len(df) AFTER filtering/capping, so it
    # equals the Pareto-input size rather than the raw dataset size — confirm
    # whether downstream consumers expect the pre-filter count instead.
    stats = {
        "n_total": float(len(df)),
        "n_after_constraints": float(n_after),
        "n_pool": float(n_pool),
        "n_pareto_pool": float(n_pareto),
        "n_selected": float(len(out)),
    }
    report("Done.", 1.0)
    return out, stats, plot_df
|
| 736 |
-
|
| 737 |
-
|
| 738 |
-
def build_pareto_plot_df(spec: DiscoverySpec, max_plot_points: int = 30000) -> pd.DataFrame:
    """
    Returns a small dataframe for plotting (sampled), with objective columns and pareto_layer.
    Does NOT compute trust/diversity. Safe for live plotting.

    Args:
        spec: run configuration (datasets, constraints, objectives).
        max_plot_points: upper bound on rows in the returned frame
            (further capped by spec.pareto_max).

    Returns:
        DataFrame with "SMILES", "pareto_layer", and the resolved objective
        mean_* columns; empty when constraints eliminate everything.
    """
    rng = np.random.default_rng(spec.random_seed)

    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]
    needed_props = sorted(set(obj_props + cons_props))

    cols = ["SMILES"] + [mean_col(p) for p in needed_props]
    df = load_parquet_columns(spec.dataset, columns=cols)

    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})

    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # Hard constraints
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            # A constrained property with no column: nothing can satisfy it.
            return df.iloc[0:0]
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    if len(df) == 0:
        return df

    # Pareto cap for plotting
    plot_cap = min(int(max_plot_points), int(spec.pareto_max))
    if len(df) > plot_cap:
        # Random subsample keeps the plot responsive.
        idx = rng.choice(len(df), size=plot_cap, replace=False)
        df = df.iloc[idx].reset_index(drop=True)

    # Build objective matrix (minimization)
    X = []
    resolved_obj_props = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        if c not in df.columns:
            continue
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
        resolved_obj_props.append(prop)
    if not X:
        # Fallback to the first mean_* column so plotting stays possible.
        fallback_col = next((c for c in df.columns if str(c).startswith("mean_")), None)
        if fallback_col is None:
            return df.iloc[0:0]
        X = [df[fallback_col].to_numpy(dtype=float) * -1.0]
        resolved_obj_props = [fallback_col.replace("mean_", "")]
    X = np.stack(X, axis=1)

    df["pareto_layer"] = pareto_layers(X, max_layers=5)

    # Return only what plotting needs
    keep = ["smiles_key", "pareto_layer"] + [mean_col(p) for p in resolved_obj_props]
    out = df[keep].rename(columns={"smiles_key": "SMILES"})
    return out
|
| 806 |
-
|
| 807 |
-
|
| 808 |
-
def parse_spec(text: str, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
    """Parse a JSON spec string into a DiscoverySpec bound to the given data paths."""
    cfg = json.loads(text)
    # max_pool is a legacy field and is collapsed onto pareto_max.
    cap = int(cfg.get("pareto_max", 50000))

    return DiscoverySpec(
        dataset=list(dataset_path),
        polyinfo=polyinfo_path,
        polyinfo_csv=polyinfo_csv_path,
        hard_constraints=cfg.get("hard_constraints", {}),
        objectives=cfg.get("objectives", []),
        max_pool=cap,
        pareto_max=cap,
        max_candidates=int(cfg.get("max_candidates", 30)),
        max_pareto_fronts=int(cfg.get("max_pareto_fronts", 5)),
        min_distance=float(cfg.get("min_distance", 0.30)),
        fingerprint=str(cfg.get("fingerprint", "morgan")),
        random_seed=int(cfg.get("random_seed", 7)),
        use_canonical_smiles=not bool(cfg.get("skip_smiles_canonicalization", True)),
        use_full_data=bool(cfg.get("use_full_data", False)),
        trust_weights=cfg.get("trust_weights"),
        selection_weights=cfg.get("selection_weights"),
    )
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/discovery.py
DELETED
|
@@ -1,767 +0,0 @@
|
|
| 1 |
-
# src/discovery.py
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
import json
|
| 5 |
-
from dataclasses import dataclass
|
| 6 |
-
from pathlib import Path
|
| 7 |
-
from typing import Callable, Dict, List, Optional, Tuple
|
| 8 |
-
|
| 9 |
-
import numpy as np
|
| 10 |
-
import pandas as pd
|
| 11 |
-
from rdkit import Chem, DataStructs
|
| 12 |
-
from rdkit.Chem import AllChem
|
| 13 |
-
from . import sascorer
|
| 14 |
-
|
| 15 |
-
# Reuse your canonicalizer if you want; otherwise keep local
|
| 16 |
-
def canonicalize_smiles(smiles: str) -> Optional[str]:
    """Return the RDKit-canonical form of `smiles`, or None if blank/unparsable."""
    text = (smiles or "").strip()
    if not text:
        return None
    mol = Chem.MolFromSmiles(text)
    if mol is None:
        return None
    return Chem.MolToSmiles(mol, canonical=True)
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
# -------------------------
|
| 27 |
-
# Spec schema (minimal v0)
|
| 28 |
-
# -------------------------
|
| 29 |
-
@dataclass
class DiscoverySpec:
    """Configuration for one discovery run: data paths, constraints, objectives, and tuning knobs."""

    dataset: List[str]  # ["PI1M_PROPERTY.parquet", "POLYINFO_PROPERTY.parquet"]
    polyinfo: str  # "POLYINFO_PROPERTY.parquet"
    polyinfo_csv: str  # "POLYINFO.csv"

    hard_constraints: Dict[str, Dict[str, float]]  # { "tg": {"min": 400}, "tc": {"max": 0.3} }
    objectives: List[Dict[str, str]]  # [{"property":"cp","goal":"maximize"}, ...]

    max_pool: int = 200000  # legacy (kept for compatibility; aligned to pareto_max)
    pareto_max: int = 50000  # cap points used for Pareto + diversity fingerprinting
    max_candidates: int = 30  # final output size
    max_pareto_fronts: int = 5  # how many Pareto layers to keep for candidate pool
    min_distance: float = 0.30  # diversity threshold in Tanimoto distance
    fingerprint: str = "morgan"  # morgan only for now
    random_seed: int = 7  # seeds numpy RNG for sampling reproducibility
    use_canonical_smiles: bool = True  # canonicalize SMILES keys via RDKit
    use_full_data: bool = False  # skip the pareto_max subsampling cap
    trust_weights: Dict[str, float] | None = None  # overrides for trust components
    selection_weights: Dict[str, float] | None = None  # overrides for selection score
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
# -------------------------
|
| 52 |
-
# Column mapping
|
| 53 |
-
# -------------------------
|
| 54 |
-
def mean_col(prop_key: str) -> str:
    """Name of the column holding the mean prediction for a property key."""
    return "mean_" + prop_key.lower()
|
| 56 |
-
|
| 57 |
-
def std_col(prop_key: str) -> str:
    """Name of the column holding the prediction std-dev for a property key."""
    return "std_" + prop_key.lower()
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
def normalize_weights(weights: Dict[str, float], defaults: Dict[str, float]) -> Dict[str, float]:
    """
    Merge user weights over `defaults`, clamp to >= 0, and rescale to sum 1.

    Non-numeric user values fall back to the default; if everything clamps
    to zero, a copy of `defaults` is returned unchanged.
    """
    clamped: Dict[str, float] = {}
    for key, fallback in defaults.items():
        try:
            value = float(weights.get(key, fallback))
        except Exception:
            value = float(fallback)
        clamped[key] = max(0.0, value)

    total = float(sum(clamped.values()))
    if total <= 0.0:
        return defaults.copy()
    return {key: float(value / total) for key, value in clamped.items()}
|
| 73 |
-
|
| 74 |
-
def spec_from_dict(obj: dict, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
    """Build a DiscoverySpec from an already-parsed JSON dict plus file paths.

    Missing keys fall back to the documented defaults. Note the inverted
    flag: ``skip_smiles_canonicalization`` (default True) maps onto
    ``use_canonical_smiles`` — so canonicalization is OFF unless the spec
    explicitly sets ``skip_smiles_canonicalization: false``.
    """
    pareto_max = int(obj.get("pareto_max", 50000))
    return DiscoverySpec(
        dataset=list(dataset_path),
        polyinfo=polyinfo_path,
        polyinfo_csv=polyinfo_csv_path,
        hard_constraints=obj.get("hard_constraints", {}),
        objectives=obj.get("objectives", []),
        # Legacy field kept for compatibility; effectively collapsed to pareto_max.
        max_pool=pareto_max,
        pareto_max=pareto_max,
        max_candidates=int(obj.get("max_candidates", 30)),
        max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)),
        min_distance=float(obj.get("min_distance", 0.30)),
        fingerprint=str(obj.get("fingerprint", "morgan")),
        random_seed=int(obj.get("random_seed", 7)),
        use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)),
        use_full_data=bool(obj.get("use_full_data", False)),
        trust_weights=obj.get("trust_weights"),
        selection_weights=obj.get("selection_weights"),
    )
|
| 95 |
-
|
| 96 |
-
# -------------------------
|
| 97 |
-
# Parquet loading (safe)
|
| 98 |
-
# -------------------------
|
| 99 |
-
def load_parquet_columns(path: str | List[str], columns: List[str]) -> pd.DataFrame:
    """
    Load only requested columns from Parquet (critical for 1M rows).
    Accepts a single path or a list of paths and concatenates rows.

    Requested names are matched case-insensitively against the file's
    schema; columns absent from the file come back filled with NaN, so
    the result always exposes exactly ``columns`` in order.
    """
    def _load_one(fp: str, req_cols: List[str]) -> pd.DataFrame:
        # Probe the schema first so we never ask read_parquet for a
        # column that does not exist (which would raise).
        available: list[str]
        try:
            import pyarrow.parquet as pq

            pf = pq.ParquetFile(fp)
            available = [str(c) for c in pf.schema.names]
        except Exception:
            # If schema probing fails, fall back to direct read with requested columns.
            return pd.read_parquet(fp, columns=req_cols)

        available_set = set(available)
        lower_to_actual = {c.lower(): c for c in available}

        # Resolve requested names against actual parquet schema.
        resolved: dict[str, str] = {}
        for req in req_cols:
            if req in available_set:
                resolved[req] = req
                continue
            alt = lower_to_actual.get(str(req).lower())
            if alt is not None:
                resolved[req] = alt

        use_cols = sorted(set(resolved.values()))
        if not use_cols:
            # Nothing requested exists in this file: empty frame, right columns.
            return pd.DataFrame(columns=req_cols)

        out = pd.read_parquet(fp, columns=use_cols)
        # Alias case-variant columns and pad missing ones with NaN.
        for req in req_cols:
            src = resolved.get(req)
            if src is None:
                out[req] = np.nan
            elif src != req:
                out[req] = out[src]
        return out[req_cols]

    if isinstance(path, (list, tuple)):
        frames = [_load_one(p, columns) for p in path]
        if not frames:
            return pd.DataFrame(columns=columns)
        return pd.concat(frames, ignore_index=True)
    return _load_one(path, columns)
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
def normalize_smiles(smiles: str, use_canonical_smiles: bool) -> Optional[str]:
    """Return a normalized SMILES key, or None for blank/unparseable input.

    Fast mode (``use_canonical_smiles=False``) only strips whitespace and
    never touches RDKit. Otherwise the string is parsed and re-emitted in
    RDKit canonical form; parse failures yield None.

    Fix: the original ended with a redundant ``if use_canonical_smiles``
    guard plus an unreachable ``return s`` (the False branch had already
    returned above); canonicalization is unconditional at that point.
    """
    s = (smiles or "").strip()
    if not s:
        return None
    if not use_canonical_smiles:
        # Skip RDKit parsing entirely in fast mode.
        return s
    m = Chem.MolFromSmiles(s)
    if m is None:
        return None
    return Chem.MolToSmiles(m, canonical=True)
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
def load_polyinfo_index(polyinfo_csv_path: str, use_canonical_smiles: bool = True) -> pd.DataFrame:
    """
    Expected CSV columns: SMILES, Polymer_Class, polymer_name (or common variants).
    Returns dataframe with index on smiles_key and columns polymer_name/polymer_class.

    Raises ValueError when no SMILES/smiles column is present. Rows whose
    SMILES cannot be normalized are dropped; duplicate keys keep the first row.
    """
    df = pd.read_csv(polyinfo_csv_path)

    # normalize column names
    cols = {c: c for c in df.columns}
    # map typical names
    if "SMILES" in cols:
        df = df.rename(columns={"SMILES": "smiles"})
    elif "smiles" not in df.columns:
        raise ValueError(f"{polyinfo_csv_path} missing SMILES/smiles column")

    if "Polymer_Name" in df.columns:
        df = df.rename(columns={"Polymer_Name": "polymer_name"})
    if "polymer_Name" in df.columns:
        df = df.rename(columns={"polymer_Name": "polymer_name"})
    if "Polymer_Class" in df.columns:
        df = df.rename(columns={"Polymer_Class": "polymer_class"})

    # Guarantee both metadata columns exist even if absent from the CSV,
    # so downstream joins never KeyError.
    if "polymer_name" not in df.columns:
        df["polymer_name"] = pd.NA
    if "polymer_class" not in df.columns:
        df["polymer_class"] = pd.NA

    df["smiles_key"] = df["smiles"].astype(str).map(lambda s: normalize_smiles(s, use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).drop_duplicates("smiles_key")
    df = df.set_index("smiles_key", drop=True)
    return df[["polymer_name", "polymer_class"]]
|
| 195 |
-
|
| 196 |
-
|
| 197 |
-
# -------------------------
|
| 198 |
-
# Pareto (2–3 objectives)
|
| 199 |
-
# -------------------------
|
| 200 |
-
def pareto_front_mask(X: np.ndarray) -> np.ndarray:
    """Boolean mask of nondominated rows of X (all objectives minimized).

    Row j dominates row i when X[j] <= X[i] componentwise with strict
    inequality in at least one component. Maximize-objectives must be
    negated by the caller before use.
    """
    n_points = X.shape[0]
    efficient = np.ones(n_points, dtype=bool)
    for i in range(n_points):
        if not efficient[i]:
            continue
        xi = X[i]
        # Anything that dominates point i disqualifies it immediately.
        dominated_here = np.all(X <= xi, axis=1) & np.any(X < xi, axis=1)
        if np.any(dominated_here):
            efficient[i] = False
            continue
        # Point i survives; knock out every point it dominates.
        beats = np.all(xi <= X, axis=1) & np.any(xi < X, axis=1)
        efficient[beats] = False
        efficient[i] = True
    return efficient
|
| 222 |
-
|
| 223 |
-
|
| 224 |
-
def pareto_layers(X: np.ndarray, max_layers: int = 10) -> np.ndarray:
    """Assign each point a Pareto layer: 1 = front, 2 = next front, ...

    Points not reached within ``max_layers`` peels keep layer 0.
    """
    layer_of = np.zeros(X.shape[0], dtype=int)
    candidates = np.arange(X.shape[0])
    for depth in range(1, max_layers + 1):
        if candidates.size == 0:
            break
        on_front = pareto_front_mask(X[candidates])
        layer_of[candidates[on_front]] = depth
        candidates = candidates[~on_front]
    return layer_of
|
| 241 |
-
|
| 242 |
-
|
| 243 |
-
def pareto_front_mask_chunked(
    X: np.ndarray,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> np.ndarray:
    """
    Exact global Pareto front mask via chunk-local front reduction + global reconcile.
    This is exact for front-1:
      1) compute exact local front within each chunk
      2) union local fronts
      3) compute exact front on the union

    A point dominated inside its own chunk is dominated globally, so step 3
    on the (much smaller) union recovers the exact global front.
    ``progress_callback``, if given, is invoked as (done_chunks, total_chunks).
    """
    N = X.shape[0]
    if N <= chunk_size:
        # Small enough for one exact pass; still report completion.
        if progress_callback is not None:
            progress_callback(1, 1)
        return pareto_front_mask(X)

    local_front_idx = []
    total_chunks = (N + chunk_size - 1) // chunk_size
    done_chunks = 0
    for start in range(0, N, chunk_size):
        end = min(start + chunk_size, N)
        idx = np.arange(start, end)
        mask_local = pareto_front_mask(X[idx])
        local_front_idx.append(idx[mask_local])
        done_chunks += 1
        if progress_callback is not None:
            progress_callback(done_chunks, total_chunks)

    if not local_front_idx:
        return np.zeros(N, dtype=bool)

    # Reconcile: exact front over the union of chunk-local fronts.
    reduced_idx = np.concatenate(local_front_idx)
    reduced_mask = pareto_front_mask(X[reduced_idx])
    front_idx = reduced_idx[reduced_mask]

    out = np.zeros(N, dtype=bool)
    out[front_idx] = True
    return out
|
| 283 |
-
|
| 284 |
-
|
| 285 |
-
def pareto_layers_chunked(
    X: np.ndarray,
    max_layers: int = 10,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int, int], None]] = None,
) -> np.ndarray:
    """
    Exact Pareto layers using repeated exact chunked front extraction.

    Returns the layer index per point (1 = first front); points beyond
    ``max_layers`` peels keep layer 0. ``progress_callback``, if given, is
    invoked as (layer, done_chunks, total_chunks).
    """
    N = X.shape[0]
    layers = np.zeros(N, dtype=int)
    remaining = np.arange(N)
    layer = 1

    while remaining.size > 0 and layer <= max_layers:
        # Adapter closing over the current layer so the chunked front
        # extractor's (done, total) callback gains layer context.
        def on_chunk(done: int, total: int) -> None:
            if progress_callback is not None:
                progress_callback(layer, done, total)

        mask = pareto_front_mask_chunked(X[remaining], chunk_size=chunk_size, progress_callback=on_chunk)
        front_idx = remaining[mask]
        layers[front_idx] = layer
        remaining = remaining[~mask]
        layer += 1

    return layers
|
| 311 |
-
|
| 312 |
-
|
| 313 |
-
# -------------------------
|
| 314 |
-
# Fingerprints & diversity
|
| 315 |
-
# -------------------------
|
| 316 |
-
def morgan_fp(smiles: str, radius: int = 2, nbits: int = 2048):
    """Morgan/ECFP bit vector for a SMILES string, or None when it fails to parse."""
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(mol, radius, nBits=nbits)
|
| 321 |
-
|
| 322 |
-
def tanimoto_distance(fp1, fp2) -> float:
    """Tanimoto distance (1 - similarity) between two RDKit fingerprints."""
    similarity = DataStructs.TanimotoSimilarity(fp1, fp2)
    return 1.0 - similarity
|
| 324 |
-
|
| 325 |
-
def greedy_diverse_select(
    smiles_list: List[str],
    scores: np.ndarray,
    max_k: int,
    min_dist: float,
) -> List[int]:
    """
    Pick up to ``max_k`` indices in descending score order, rejecting any
    candidate whose Tanimoto distance to an already-selected fingerprint is
    below ``min_dist``. Entries whose SMILES fail to fingerprint are skipped.
    Returns indices into ``smiles_list``.
    """
    fingerprints = []
    usable: List[int] = []
    for pos, smi in enumerate(smiles_list):
        bv = morgan_fp(smi)
        if bv is not None:
            fingerprints.append(bv)
            usable.append(pos)

    if not usable:
        return []

    # Descending-score ranking over the fingerprintable subset.
    ranking = np.argsort(-scores[usable])
    picked: List[int] = []
    picked_fps = []

    for r in ranking:
        candidate_fp = fingerprints[r]  # aligned with usable
        too_close = False
        for kept_fp in picked_fps:
            if tanimoto_distance(candidate_fp, kept_fp) < min_dist:
                too_close = True
                break
        if too_close:
            continue
        picked.append(usable[r])
        picked_fps.append(candidate_fp)
        if len(picked) >= max_k:
            break

    return picked
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
# -------------------------
|
| 369 |
-
# Trust score (lightweight, robust)
|
| 370 |
-
# -------------------------
|
| 371 |
-
def internal_consistency_penalty(row: pd.Series) -> float:
    """
    Fraction of applicable sanity checks the row violates, in [0, 1].

    Checks (each only when the corresponding mean_* value is present):
      * positivity of property means (>= 0 for bandgap/tg/tm, > 0 otherwise),
      * Poisson ratio within [0, 0.5],
      * Tg <= Tm.
    Returns 0.0 when no check applies. Adjust/add rules later.
    """
    violations = 0
    applicable = 0

    def record(passed: bool) -> None:
        nonlocal violations, applicable
        applicable += 1
        if not passed:
            violations += 1

    # Positivity checks for any property mean that is present.
    for prop in ["cp", "tc", "rho", "dif", "visc", "tg", "tm", "bandgap"]:
        col = mean_col(prop)
        if col in row.index and pd.notna(row[col]):
            if prop in ["bandgap", "tg", "tm"]:
                record(float(row[col]) >= 0.0)
            else:
                record(float(row[col]) > 0.0)

    # Physically admissible Poisson ratio.
    poisson_col = mean_col("poisson")
    if poisson_col in row.index and pd.notna(row[poisson_col]):
        record(0.0 <= float(row[poisson_col]) <= 0.5)

    # Glass transition cannot exceed the melting point.
    tg_c, tm_c = mean_col("tg"), mean_col("tm")
    if tg_c in row.index and tm_c in row.index:
        if pd.notna(row[tg_c]) and pd.notna(row[tm_c]):
            record(float(row[tg_c]) <= float(row[tm_c]))

    if applicable == 0:
        return 0.0
    return violations / applicable
|
| 407 |
-
|
| 408 |
-
|
| 409 |
-
def synthesizability_score(smiles: str) -> float:
    """
    Map the RDKit SA score (~1 easy .. 10 hard) onto [0, 1], higher = easier.
    SA 1 -> 1.0, SA 10 -> 0.0, linearly, clipped to [0, 1].
    Returns 0.0 for unparseable SMILES or any scorer failure/None.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return 0.0

    # Guard against unexpected scorer failures / None for edge-case molecules.
    try:
        raw = sascorer.calculateScore(mol)
    except Exception:
        return 0.0
    if raw is None:
        return 0.0

    scaled = 1.0 - (float(raw) - 1.0) / 9.0  # linear map of 1..10 onto 1..0
    return float(np.clip(scaled, 0.0, 1.0))
|
| 430 |
-
|
| 431 |
-
|
| 432 |
-
def compute_trust_scores(
    df: pd.DataFrame,
    real_fps: List,
    real_smiles: List[str],
    trust_weights: Dict[str, float] | None = None,
) -> np.ndarray:
    """
    Trust score in [0,1] per row (higher = more trustworthy / lower risk).

    Weighted combination of:
      - similarity to the nearest real polymer (Morgan/Tanimoto vs real_fps),
      - internal consistency penalty (see internal_consistency_penalty),
      - uncertainty penalty averaged over any std_* columns,
      - synthesizability proxy (SA score).

    ``real_smiles`` is currently unused by the computation but kept for
    interface stability. The O(N * len(real_fps)) nearest-real scan is only
    run on the small post-filter candidate set, which is safe.

    Fix: hoisted the std_* column scan and the repeated ``df.iloc[i]``
    lookups out of the per-row loop (both were recomputed every iteration).
    """
    N = len(df)
    trust = np.zeros(N, dtype=float)
    tw_defaults = {"real": 0.45, "consistency": 0.25, "uncertainty": 0.10, "synth": 0.20}
    tw = normalize_weights(trust_weights or {}, tw_defaults)

    # Loop invariants: the column set never changes per row.
    std_cols = [c for c in df.columns if c.startswith("std_")]
    smiles_col = "smiles_key" if "smiles_key" in df.columns else "smiles_canon"

    for i in range(N):
        row = df.iloc[i]
        s = row[smiles_col]
        fp = morgan_fp(s)
        if fp is None or not real_fps:
            d_real = 1.0
        else:
            sims = DataStructs.BulkTanimotoSimilarity(fp, real_fps)
            d_real = 1.0 - float(max(sims))  # distance to nearest real polymer

        # internal consistency
        pen_cons = internal_consistency_penalty(row)

        # uncertainty: average normalized std over the available std_* values
        if std_cols:
            std_vals = row[std_cols].astype(float)
            std_vals = std_vals.replace([np.inf, -np.inf], np.nan).dropna()
            pen_unc = float(np.clip(std_vals.mean() / (std_vals.mean() + 1.0), 0.0, 1.0)) if len(std_vals) else 0.0
        else:
            pen_unc = 0.0

        # synthesizability heuristic
        s_syn = synthesizability_score(s)

        # lower distance to real is better -> convert to score
        s_real = 1.0 - np.clip(d_real, 0.0, 1.0)

        trust[i] = (
            tw["real"] * s_real
            + tw["consistency"] * (1.0 - pen_cons)
            + tw["uncertainty"] * (1.0 - pen_unc)
            + tw["synth"] * s_syn
        )

    return np.clip(trust, 0.0, 1.0)
|
| 491 |
-
|
| 492 |
-
|
| 493 |
-
# -------------------------
|
| 494 |
-
# Main pipeline
|
| 495 |
-
# -------------------------
|
| 496 |
-
def run_discovery(
    spec: DiscoverySpec,
    progress_callback: Optional[Callable[[str, float], None]] = None,
) -> Tuple[pd.DataFrame, Dict[str, float], pd.DataFrame]:
    """End-to-end discovery pipeline.

    Loads only the needed parquet columns, normalizes SMILES, applies hard
    constraints, computes (chunked, exact) Pareto layers on the objective
    matrix, scores the candidate pool for trust, and greedily selects a
    diverse final set.

    Returns (selected_df, stats, plot_df):
      - selected_df: SMILES, polymer name/class, pareto_layer, trust_score,
        and the mean_/std_ columns of every property used.
      - stats: counts at each pipeline stage (floats).
      - plot_df: SMILES + objective means + pareto_layer for plotting.

    ``progress_callback``, if given, is called as (step_label, fraction).
    """
    def report(step: str, pct: float) -> None:
        if progress_callback is not None:
            progress_callback(step, pct)

    rng = np.random.default_rng(spec.random_seed)

    # 1) Determine required columns
    report("Preparing columns…", 0.02)
    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]

    needed_props = sorted(set(obj_props + cons_props))
    cols = ["SMILES"] + [mean_col(p) for p in needed_props]

    # include std columns if available (not required, but used for trust)
    std_cols = [std_col(p) for p in needed_props]
    cols += std_cols

    # 2) Load only needed columns
    report("Loading data from parquet…", 0.05)
    df = load_parquet_columns(spec.dataset, columns=[c for c in cols if c != "SMILES"] + ["SMILES"])
    # normalize
    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})
    normalize_step = "Canonicalizing SMILES…" if spec.use_canonical_smiles else "Skipping SMILES normalization…"
    report(normalize_step, 0.10)
    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # 3) Hard constraints
    report("Applying constraints…", 0.22)
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            # if missing, nothing can satisfy
            df = df.iloc[0:0]
            break
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    n_after = len(df)
    if n_after == 0:
        empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0}
        return df, empty_stats, pd.DataFrame()

    n_pool = len(df)

    # 5) Prepare objective matrix for Pareto
    report("Building objective matrix…", 0.30)
    # convert to minimization: maximize => negate
    X = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        if c not in df.columns:
            raise ValueError(f"Objective column missing: {c}")
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
    X = np.stack(X, axis=1)  # (N, M)

    # Pareto cap before computing layers (optional safety)
    if spec.use_full_data:
        report("Using full dataset (no Pareto cap)…", 0.35)
    elif len(df) > spec.pareto_max:
        idx = rng.choice(len(df), size=spec.pareto_max, replace=False)
        df = df.iloc[idx].reset_index(drop=True)
        X = X[idx]

    # 6) Pareto layers (only 5 layers needed for candidate pool)
    report("Computing Pareto layers…", 0.40)
    pareto_start = 0.40
    pareto_end = 0.54
    max_layers_for_pool = max(1, int(spec.max_pareto_fronts))
    # Mutable cell so the nested callback can memoize the first layer's
    # chunk count and report a stable overall percentage.
    pareto_chunk_ref = {"chunks_per_layer": None}

    def on_pareto_chunk(layer_i: int, done_chunks: int, total_chunks: int) -> None:
        if pareto_chunk_ref["chunks_per_layer"] is None:
            pareto_chunk_ref["chunks_per_layer"] = max(1, int(total_chunks))
        ref_chunks = pareto_chunk_ref["chunks_per_layer"]
        total_units = max_layers_for_pool * ref_chunks
        done_units = min(total_units, ((layer_i - 1) * ref_chunks) + done_chunks)
        pareto_pct = int(round(100.0 * done_units / max(1, total_units)))

        layer_progress = done_chunks / max(1, total_chunks)
        overall = ((layer_i - 1) + layer_progress) / max_layers_for_pool
        pct = pareto_start + (pareto_end - pareto_start) * min(1.0, max(0.0, overall))
        report(
            f"Computing Pareto layers… {pareto_pct}% (Layer {layer_i}/{max_layers_for_pool}, chunk {done_chunks}/{total_chunks})",
            pct,
        )

    layers = pareto_layers_chunked(
        X,
        max_layers=max_layers_for_pool,
        chunk_size=100000,
        progress_callback=on_pareto_chunk,
    )
    report("Computing Pareto layers…", pareto_end)
    df["pareto_layer"] = layers
    plot_df = df[["smiles_key"] + [mean_col(p) for p in obj_props] + ["pareto_layer"]].copy()
    plot_df = plot_df.rename(columns={"smiles_key": "SMILES"})

    # Keep first few layers as candidate pool (avoid huge set)
    cand = df[df["pareto_layer"].between(1, max_layers_for_pool)].copy()
    if cand.empty:
        cand = df[df["pareto_layer"] == 1].copy()
    cand = cand.reset_index(drop=True)
    n_pareto = len(cand)

    # 7) Load real polymer metadata and fingerprints (from POLYINFO.csv)
    report("Loading POLYINFO index…", 0.55)
    polyinfo = load_polyinfo_index(spec.polyinfo_csv, use_canonical_smiles=spec.use_canonical_smiles)
    real_smiles = polyinfo.index.to_list()

    report("Building real-polymer fingerprints…", 0.60)
    real_fps = []
    for s in real_smiles:
        fp = morgan_fp(s)
        if fp is not None:
            real_fps.append(fp)

    # 8) Trust score on candidate pool (safe size)
    report("Computing trust scores…", 0.70)
    trust = compute_trust_scores(
        cand,
        real_fps=real_fps,
        real_smiles=real_smiles,
        trust_weights=spec.trust_weights,
    )
    cand["trust_score"] = trust

    # 9) Diversity selection on candidate pool
    report("Diversity selection…", 0.88)
    # score for selection: prioritize Pareto layer 1 then trust
    # higher is better
    sw_defaults = {"pareto": 0.60, "trust": 0.40}
    sw = normalize_weights(spec.selection_weights or {}, sw_defaults)
    pareto_bonus = (
        (max_layers_for_pool + 1) - np.clip(cand["pareto_layer"].to_numpy(dtype=int), 1, max_layers_for_pool)
    ) / float(max_layers_for_pool)
    sel_score = sw["pareto"] * pareto_bonus + sw["trust"] * cand["trust_score"].to_numpy(dtype=float)

    chosen_idx = greedy_diverse_select(
        smiles_list=cand["smiles_key"].tolist(),
        scores=sel_score,
        max_k=spec.max_candidates,
        min_dist=spec.min_distance,
    )
    out = cand.iloc[chosen_idx].copy().reset_index(drop=True)

    # 10) Attach Polymer_Name/Class if available (only for matches)
    report("Finalizing results…", 0.96)
    out = out.set_index("smiles_key", drop=False)
    out = out.join(polyinfo, how="left")
    out = out.reset_index(drop=True)

    # 11) Make a clean output bundle with requested columns
    # Keep SMILES (canonical), name/class, pareto layer, trust score, properties used
    keep = ["smiles_key", "polymer_name", "polymer_class", "pareto_layer", "trust_score"]
    for p in needed_props:
        mc = mean_col(p)
        sc = std_col(p)
        if mc in out.columns:
            keep.append(mc)
        if sc in out.columns:
            keep.append(sc)

    out = out[keep].rename(columns={"smiles_key": "SMILES"})

    # NOTE(review): "n_total" is taken from df AFTER constraint filtering and
    # the optional pareto_max subsample, so it is not the pre-filter row
    # count its name suggests — confirm intended semantics with consumers.
    stats = {
        "n_total": float(len(df)),
        "n_after_constraints": float(n_after),
        "n_pool": float(n_pool),
        "n_pareto_pool": float(n_pareto),
        "n_selected": float(len(out)),
    }
    report("Done.", 1.0)
    return out, stats, plot_df
|
| 684 |
-
|
| 685 |
-
|
| 686 |
-
def build_pareto_plot_df(spec: DiscoverySpec, max_plot_points: int = 30000) -> pd.DataFrame:
    """
    Returns a small dataframe for plotting (sampled), with objective columns and pareto_layer.
    Does NOT compute trust/diversity. Safe for live plotting.

    Mirrors the load/normalize/constrain steps of run_discovery, then caps
    the point count at min(max_plot_points, spec.pareto_max) by random
    sampling before assigning up to 5 Pareto layers.
    """
    rng = np.random.default_rng(spec.random_seed)

    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]
    needed_props = sorted(set(obj_props + cons_props))

    cols = ["SMILES"] + [mean_col(p) for p in needed_props]
    df = load_parquet_columns(spec.dataset, columns=cols)

    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})

    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # Hard constraints
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            # A missing constraint column means nothing can satisfy it.
            return df.iloc[0:0]
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    if len(df) == 0:
        return df

    # Pareto cap for plotting
    plot_cap = min(int(max_plot_points), int(spec.pareto_max))
    if len(df) > plot_cap:
        idx = rng.choice(len(df), size=plot_cap, replace=False)
        df = df.iloc[idx].reset_index(drop=True)

    # Build objective matrix (minimization; maximize objectives are negated)
    X = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
    X = np.stack(X, axis=1)

    df["pareto_layer"] = pareto_layers(X, max_layers=5)

    # Return only what plotting needs
    keep = ["smiles_key", "pareto_layer"] + [mean_col(p) for p in obj_props]
    out = df[keep].rename(columns={"smiles_key": "SMILES"})
    return out
|
| 744 |
-
|
| 745 |
-
|
| 746 |
-
def parse_spec(text: str, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
    """Parse a JSON spec string into a DiscoverySpec.

    Fix: the original duplicated the entire DiscoverySpec construction from
    spec_from_dict field-by-field; delegating instead keeps the two entry
    points from drifting apart. Behavior is unchanged — same defaults, same
    legacy max_pool collapse — and malformed input still raises
    json.JSONDecodeError.
    """
    return spec_from_dict(json.loads(text), dataset_path, polyinfo_path, polyinfo_csv_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/fpscores.pkl.gz
DELETED
|
@@ -1,3 +0,0 @@
|
|
| 1 |
-
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:9abb6f4c322d27fa05c8a1115a463bcef312d2bed0b447c347de33bfefa83316
|
| 3 |
-
size 132
|
|
|
|
|
|
|
|
|
|
|
|
src/lookup.py
DELETED
|
@@ -1,222 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
from rdkit import Chem
|
| 6 |
-
from rdkit import RDLogger
|
| 7 |
-
|
| 8 |
-
RDLogger.DisableLog("rdApp.*")
|
| 9 |
-
|
| 10 |
-
# ----------------------------
|
| 11 |
-
# Sources (property value files)
|
| 12 |
-
# ----------------------------
|
| 13 |
-
SOURCES = ["EXP", "MD", "DFT", "GC"]
|
| 14 |
-
|
| 15 |
-
SOURCE_LABELS = {
|
| 16 |
-
"EXP": "Experimental",
|
| 17 |
-
"MD": "Molecular Dynamics",
|
| 18 |
-
"DFT": "Density Functional Theory",
|
| 19 |
-
"GC": "Group Contribution",
|
| 20 |
-
}
|
| 21 |
-
|
| 22 |
-
# ----------------------------
|
| 23 |
-
# PolyInfo metadata file (name/class)
|
| 24 |
-
# ----------------------------
|
| 25 |
-
POLYINFO_FILE = "data/POLYINFO.csv" # contains: SMILES, Polymer_Class, Polymer_Name
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def canonicalize_smiles(smiles: str) -> str | None:
|
| 29 |
-
smiles = (smiles or "").strip()
|
| 30 |
-
if not smiles:
|
| 31 |
-
return None
|
| 32 |
-
mol = Chem.MolFromSmiles(smiles)
|
| 33 |
-
if mol is None:
|
| 34 |
-
return None
|
| 35 |
-
return Chem.MolToSmiles(mol, canonical=True)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# --- Property meta (full name + unit) ---
|
| 39 |
-
PROPERTY_META = {
|
| 40 |
-
# Thermal
|
| 41 |
-
"tm": {"name": "Melting temperature", "unit": "K"},
|
| 42 |
-
"tg": {"name": "Glass transition temperature", "unit": "K"},
|
| 43 |
-
"td": {"name": "Thermal diffusivity", "unit": "m^2/s"},
|
| 44 |
-
"tc": {"name": "Thermal conductivity", "unit": "W/m·K"},
|
| 45 |
-
"cp": {"name": "Specific heat capacity", "unit": "J/kg·K"},
|
| 46 |
-
# Mechanical
|
| 47 |
-
"young": {"name": "Young's modulus", "unit": "GPa"},
|
| 48 |
-
"shear": {"name": "Shear modulus", "unit": "GPa"},
|
| 49 |
-
"bulk": {"name": "Bulk modulus", "unit": "GPa"},
|
| 50 |
-
"poisson": {"name": "Poisson ratio", "unit": "-"},
|
| 51 |
-
# Transport
|
| 52 |
-
"visc": {"name": "Viscosity", "unit": "Pa·s"},
|
| 53 |
-
"dif": {"name": "Diffusivity", "unit": "cm^2/s"},
|
| 54 |
-
# Gas permeability
|
| 55 |
-
"phe": {"name": "He permeability", "unit": "Barrer"},
|
| 56 |
-
"ph2": {"name": "H2 permeability", "unit": "Barrer"},
|
| 57 |
-
"pco2": {"name": "CO2 permeability", "unit": "Barrer"},
|
| 58 |
-
"pn2": {"name": "N2 permeability", "unit": "Barrer"},
|
| 59 |
-
"po2": {"name": "O2 permeability", "unit": "Barrer"},
|
| 60 |
-
"pch4": {"name": "CH4 permeability", "unit": "Barrer"},
|
| 61 |
-
# Electronic / Optical
|
| 62 |
-
"alpha": {"name": "Polarizability", "unit": "a.u."},
|
| 63 |
-
"homo": {"name": "HOMO energy", "unit": "eV"},
|
| 64 |
-
"lumo": {"name": "LUMO energy", "unit": "eV"},
|
| 65 |
-
"bandgap": {"name": "Band gap", "unit": "eV"},
|
| 66 |
-
"mu": {"name": "Dipole moment", "unit": "Debye"},
|
| 67 |
-
"etotal": {"name": "Total electronic energy", "unit": "eV"},
|
| 68 |
-
"ri": {"name": "Refractive index", "unit": "-"},
|
| 69 |
-
"dc": {"name": "Dielectric constant", "unit": "-"},
|
| 70 |
-
"pe": {"name": "Permittivity", "unit": "-"},
|
| 71 |
-
# Structural / Physical
|
| 72 |
-
"rg": {"name": "Radius of gyration", "unit": "Å"},
|
| 73 |
-
"rho": {"name": "Density", "unit": "g/cm^3"},
|
| 74 |
-
}
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
@st.cache_data
|
| 78 |
-
def load_source_csv(source: str) -> pd.DataFrame:
|
| 79 |
-
"""
|
| 80 |
-
Loads data/{SOURCE}.csv, normalizes:
|
| 81 |
-
- SMILES column -> 'smiles'
|
| 82 |
-
- property columns -> lowercase
|
| 83 |
-
- adds 'smiles_canon'
|
| 84 |
-
"""
|
| 85 |
-
path = f"data/{source}.csv"
|
| 86 |
-
df = pd.read_csv(path)
|
| 87 |
-
|
| 88 |
-
# Normalize SMILES column name
|
| 89 |
-
if "SMILES" in df.columns:
|
| 90 |
-
df = df.rename(columns={"SMILES": "smiles"})
|
| 91 |
-
elif "smiles" not in df.columns:
|
| 92 |
-
raise ValueError(f"{path} missing SMILES column")
|
| 93 |
-
|
| 94 |
-
# Normalize property column names to lowercase
|
| 95 |
-
rename_map = {c: c.lower() for c in df.columns if c != "smiles"}
|
| 96 |
-
df = df.rename(columns=rename_map)
|
| 97 |
-
|
| 98 |
-
# Canonicalize SMILES
|
| 99 |
-
df["smiles_canon"] = df["smiles"].astype(str).apply(canonicalize_smiles)
|
| 100 |
-
df = df.dropna(subset=["smiles_canon"]).reset_index(drop=True)
|
| 101 |
-
|
| 102 |
-
return df
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
@st.cache_data
|
| 106 |
-
def build_index(df: pd.DataFrame) -> dict[str, int]:
|
| 107 |
-
"""canonical smiles -> row index (first occurrence)"""
|
| 108 |
-
idx: dict[str, int] = {}
|
| 109 |
-
for i, s in enumerate(df["smiles_canon"].tolist()):
|
| 110 |
-
if s and s not in idx:
|
| 111 |
-
idx[s] = i
|
| 112 |
-
return idx
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
@st.cache_data
|
| 116 |
-
def load_polyinfo_csv() -> pd.DataFrame:
|
| 117 |
-
"""
|
| 118 |
-
Loads data/POLYINFO.csv with columns:
|
| 119 |
-
SMILES, Polymer_Class, Polymer_Name
|
| 120 |
-
Adds canonical smiles column 'smiles_canon'.
|
| 121 |
-
Returns empty df if file missing.
|
| 122 |
-
"""
|
| 123 |
-
try:
|
| 124 |
-
df = pd.read_csv(POLYINFO_FILE)
|
| 125 |
-
except Exception:
|
| 126 |
-
return pd.DataFrame(columns=["smiles", "polymer_class", "polymer_name", "smiles_canon"])
|
| 127 |
-
|
| 128 |
-
# Normalize columns
|
| 129 |
-
if "SMILES" in df.columns:
|
| 130 |
-
df = df.rename(columns={"SMILES": "smiles"})
|
| 131 |
-
elif "smiles" not in df.columns:
|
| 132 |
-
# If the file doesn't have a SMILES column as expected, return empty gracefully
|
| 133 |
-
return pd.DataFrame(columns=["smiles", "polymer_class", "polymer_name", "smiles_canon"])
|
| 134 |
-
|
| 135 |
-
# Normalize expected meta columns
|
| 136 |
-
ren = {}
|
| 137 |
-
if "Polymer_Class" in df.columns:
|
| 138 |
-
ren["Polymer_Class"] = "polymer_class"
|
| 139 |
-
if "Polymer_Name" in df.columns:
|
| 140 |
-
ren["Polymer_Name"] = "polymer_name"
|
| 141 |
-
df = df.rename(columns=ren)
|
| 142 |
-
|
| 143 |
-
# Ensure the columns exist (even if missing in the file)
|
| 144 |
-
if "polymer_class" not in df.columns:
|
| 145 |
-
df["polymer_class"] = pd.NA
|
| 146 |
-
if "polymer_name" not in df.columns:
|
| 147 |
-
df["polymer_name"] = pd.NA
|
| 148 |
-
|
| 149 |
-
# Canonicalize smiles
|
| 150 |
-
df["smiles_canon"] = df["smiles"].astype(str).apply(canonicalize_smiles)
|
| 151 |
-
df = df.dropna(subset=["smiles_canon"]).reset_index(drop=True)
|
| 152 |
-
|
| 153 |
-
return df
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
@st.cache_data
|
| 157 |
-
def load_all_sources():
|
| 158 |
-
"""
|
| 159 |
-
Returns dict:
|
| 160 |
-
db["EXP"/"MD"/"DFT"/"GC"] = {"df": df, "idx": idx}
|
| 161 |
-
db["POLYINFO"] = {"df": df, "idx": idx}
|
| 162 |
-
"""
|
| 163 |
-
db = {}
|
| 164 |
-
for src in SOURCES:
|
| 165 |
-
df = load_source_csv(src)
|
| 166 |
-
idx = build_index(df)
|
| 167 |
-
db[src] = {"df": df, "idx": idx}
|
| 168 |
-
|
| 169 |
-
# PolyInfo metadata
|
| 170 |
-
pi_df = load_polyinfo_csv()
|
| 171 |
-
pi_idx = build_index(pi_df) if not pi_df.empty else {}
|
| 172 |
-
db["POLYINFO"] = {"df": pi_df, "idx": pi_idx}
|
| 173 |
-
|
| 174 |
-
return db
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
def get_value(db, source: str, smiles_canon: str, prop_key: str):
|
| 178 |
-
pack = db[source]
|
| 179 |
-
df, idx = pack["df"], pack["idx"]
|
| 180 |
-
row_i = idx.get(smiles_canon, None)
|
| 181 |
-
if row_i is None:
|
| 182 |
-
return None
|
| 183 |
-
if prop_key not in df.columns:
|
| 184 |
-
return None
|
| 185 |
-
val = df.iloc[row_i][prop_key]
|
| 186 |
-
if pd.isna(val):
|
| 187 |
-
return None
|
| 188 |
-
return float(val)
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
def get_polyinfo(db, smiles_canon: str) -> tuple[str | None, str | None]:
|
| 192 |
-
"""
|
| 193 |
-
Returns (polymer_name, polymer_class) if available, else (None, None).
|
| 194 |
-
No 'not available' text here.
|
| 195 |
-
"""
|
| 196 |
-
pack = db.get("POLYINFO", None)
|
| 197 |
-
if pack is None:
|
| 198 |
-
return None, None
|
| 199 |
-
|
| 200 |
-
df, idx = pack["df"], pack["idx"]
|
| 201 |
-
if df is None or df.empty:
|
| 202 |
-
return None, None
|
| 203 |
-
|
| 204 |
-
row_i = idx.get(smiles_canon, None)
|
| 205 |
-
if row_i is None:
|
| 206 |
-
return None, None
|
| 207 |
-
|
| 208 |
-
name = df.iloc[row_i].get("polymer_name", None)
|
| 209 |
-
cls = df.iloc[row_i].get("polymer_class", None)
|
| 210 |
-
|
| 211 |
-
# Clean up NA / empty
|
| 212 |
-
if pd.isna(name) or str(name).strip() == "":
|
| 213 |
-
name = None
|
| 214 |
-
else:
|
| 215 |
-
name = str(name).strip()
|
| 216 |
-
|
| 217 |
-
if pd.isna(cls) or str(cls).strip() == "":
|
| 218 |
-
cls = None
|
| 219 |
-
else:
|
| 220 |
-
cls = str(cls).strip()
|
| 221 |
-
|
| 222 |
-
return name, cls
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/model.py
DELETED
|
@@ -1,312 +0,0 @@
|
|
| 1 |
-
# model.py
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
from typing import List, Optional, Literal
|
| 5 |
-
|
| 6 |
-
import torch
|
| 7 |
-
import torch.nn as nn
|
| 8 |
-
import torch.nn.functional as F
|
| 9 |
-
from torch_geometric.data import Batch
|
| 10 |
-
|
| 11 |
-
from src.conv import build_gnn_encoder, GNNEncoder
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
def get_activation(name: str) -> nn.Module:
|
| 15 |
-
name = name.lower()
|
| 16 |
-
if name == "relu":
|
| 17 |
-
return nn.ReLU()
|
| 18 |
-
if name == "gelu":
|
| 19 |
-
return nn.GELU()
|
| 20 |
-
if name == "silu":
|
| 21 |
-
return nn.SiLU()
|
| 22 |
-
if name in ("leaky_relu", "lrelu"):
|
| 23 |
-
return nn.LeakyReLU(0.1)
|
| 24 |
-
raise ValueError(f"Unknown activation: {name}")
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
class FiLM(nn.Module):
|
| 28 |
-
"""
|
| 29 |
-
Simple FiLM: gamma, beta from condition vector; apply to features as (1+gamma)*h + beta
|
| 30 |
-
"""
|
| 31 |
-
def __init__(self, feat_dim: int, cond_dim: int):
|
| 32 |
-
super().__init__()
|
| 33 |
-
self.gamma = nn.Linear(cond_dim, feat_dim)
|
| 34 |
-
self.beta = nn.Linear(cond_dim, feat_dim)
|
| 35 |
-
|
| 36 |
-
def forward(self, h: torch.Tensor, cond: torch.Tensor) -> torch.Tensor:
|
| 37 |
-
g = self.gamma(cond)
|
| 38 |
-
b = self.beta(cond)
|
| 39 |
-
return (1.0 + g) * h + b
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
class TaskHead(nn.Module):
|
| 43 |
-
"""
|
| 44 |
-
Per-task MLP head. Input is concatenation of [graph_embed, optional task_embed].
|
| 45 |
-
Outputs either a mean only (scalar) or mean+logvar (heteroscedastic).
|
| 46 |
-
"""
|
| 47 |
-
def __init__(
|
| 48 |
-
self,
|
| 49 |
-
in_dim: int,
|
| 50 |
-
hidden_dim: int = 512,
|
| 51 |
-
depth: int = 2,
|
| 52 |
-
act: str = "relu",
|
| 53 |
-
dropout: float = 0.0,
|
| 54 |
-
heteroscedastic: bool = False,
|
| 55 |
-
):
|
| 56 |
-
super().__init__()
|
| 57 |
-
layers: List[nn.Module] = []
|
| 58 |
-
d = in_dim
|
| 59 |
-
for _ in range(depth):
|
| 60 |
-
layers.append(nn.Linear(d, hidden_dim))
|
| 61 |
-
layers.append(get_activation(act))
|
| 62 |
-
if dropout > 0:
|
| 63 |
-
layers.append(nn.Dropout(dropout))
|
| 64 |
-
d = hidden_dim
|
| 65 |
-
out_dim = 2 if heteroscedastic else 1
|
| 66 |
-
layers.append(nn.Linear(d, out_dim))
|
| 67 |
-
self.net = nn.Sequential(*layers)
|
| 68 |
-
self.hetero = heteroscedastic
|
| 69 |
-
|
| 70 |
-
def forward(self, z: torch.Tensor) -> torch.Tensor:
|
| 71 |
-
# returns [B, 1] or [B, 2] where [...,0] is mean and [...,1] is logvar if heteroscedastic
|
| 72 |
-
return self.net(z)
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
class MultiTaskMultiFidelityModel(nn.Module):
|
| 76 |
-
"""
|
| 77 |
-
General multi-task, multi-fidelity GNN.
|
| 78 |
-
|
| 79 |
-
- Any number of tasks (properties) via T = len(task_names)
|
| 80 |
-
- Any number of fidelities via num_fids
|
| 81 |
-
- Fidelity conditioning with an embedding and FiLM on the graph embedding
|
| 82 |
-
- Optional task embeddings concatenated into each task head input
|
| 83 |
-
- Single forward returning predictions [B, T] (means); if heteroscedastic, also returns log-variances
|
| 84 |
-
|
| 85 |
-
Expected input Batch fields (PyG):
|
| 86 |
-
- x : [N_nodes, F_node]
|
| 87 |
-
- edge_index : [2, N_edges]
|
| 88 |
-
- edge_attr : [N_edges, F_edge] (required if gnn_type="gine")
|
| 89 |
-
- batch : [N_nodes]
|
| 90 |
-
- fid_idx : [B] or [B, 1] long; integer fidelity per graph
|
| 91 |
-
|
| 92 |
-
Notes:
|
| 93 |
-
- Targets should already be normalized outside the model; apply inverse transform for plots.
|
| 94 |
-
- Loss weighting/equal-importance and curriculum happen in the trainer, not here.
|
| 95 |
-
"""
|
| 96 |
-
|
| 97 |
-
def __init__(
|
| 98 |
-
self,
|
| 99 |
-
in_dim_node: int,
|
| 100 |
-
in_dim_edge: int,
|
| 101 |
-
task_names: List[str],
|
| 102 |
-
num_fids: int,
|
| 103 |
-
gnn_type: Literal["gine", "gin", "gcn"] = "gine",
|
| 104 |
-
gnn_emb_dim: int = 256,
|
| 105 |
-
gnn_layers: int = 5,
|
| 106 |
-
gnn_norm: Literal["batch", "layer", "none"] = "batch",
|
| 107 |
-
gnn_readout: Literal["mean", "sum", "max"] = "mean",
|
| 108 |
-
gnn_act: str = "relu",
|
| 109 |
-
gnn_dropout: float = 0.0,
|
| 110 |
-
gnn_residual: bool = True,
|
| 111 |
-
# Fidelity conditioning
|
| 112 |
-
fid_emb_dim: int = 64,
|
| 113 |
-
use_film: bool = True,
|
| 114 |
-
# Task conditioning
|
| 115 |
-
use_task_embed: bool = True,
|
| 116 |
-
task_emb_dim: int = 32,
|
| 117 |
-
# Heads
|
| 118 |
-
head_hidden: int = 512,
|
| 119 |
-
head_depth: int = 2,
|
| 120 |
-
head_act: str = "relu",
|
| 121 |
-
head_dropout: float = 0.0,
|
| 122 |
-
heteroscedastic: bool = False,
|
| 123 |
-
# Optional homoscedastic task uncertainty (used in loss, kept here for checkpoint parity)
|
| 124 |
-
use_task_uncertainty: bool = False,
|
| 125 |
-
# Embedding regularization (used via regularization_loss)
|
| 126 |
-
fid_emb_l2: float = 0.0,
|
| 127 |
-
task_emb_l2: float = 0.0,
|
| 128 |
-
):
|
| 129 |
-
super().__init__()
|
| 130 |
-
self.task_names = list(task_names)
|
| 131 |
-
self.num_tasks = len(task_names)
|
| 132 |
-
self.num_fids = int(num_fids)
|
| 133 |
-
self.hetero = heteroscedastic
|
| 134 |
-
self.fid_emb_l2 = float(fid_emb_l2)
|
| 135 |
-
self.task_emb_l2 = float(task_emb_l2)
|
| 136 |
-
self.use_film = use_film
|
| 137 |
-
self.use_task_embed = use_task_embed
|
| 138 |
-
|
| 139 |
-
# Optional learned homoscedastic uncertainty per task (trainer may use it)
|
| 140 |
-
self.use_task_uncertainty = bool(use_task_uncertainty)
|
| 141 |
-
if self.use_task_uncertainty:
|
| 142 |
-
self.task_log_sigma2 = nn.Parameter(torch.zeros(self.num_tasks))
|
| 143 |
-
else:
|
| 144 |
-
self.task_log_sigma2 = None
|
| 145 |
-
|
| 146 |
-
# Encoder
|
| 147 |
-
self.encoder: GNNEncoder = build_gnn_encoder(
|
| 148 |
-
in_dim_node=in_dim_node,
|
| 149 |
-
emb_dim=gnn_emb_dim,
|
| 150 |
-
num_layers=gnn_layers,
|
| 151 |
-
gnn_type=gnn_type,
|
| 152 |
-
in_dim_edge=in_dim_edge,
|
| 153 |
-
act=gnn_act,
|
| 154 |
-
dropout=gnn_dropout,
|
| 155 |
-
residual=gnn_residual,
|
| 156 |
-
norm=gnn_norm,
|
| 157 |
-
readout=gnn_readout,
|
| 158 |
-
)
|
| 159 |
-
|
| 160 |
-
# Fidelity embedding + FiLM
|
| 161 |
-
self.fid_embed = nn.Embedding(self.num_fids, fid_emb_dim) if fid_emb_dim > 0 else None
|
| 162 |
-
self.film = FiLM(gnn_emb_dim, fid_emb_dim) if (use_film and fid_emb_dim > 0) else None
|
| 163 |
-
|
| 164 |
-
# --- Compute the true feature dim sent to heads ---
|
| 165 |
-
# If FiLM is ON: g stays [B, gnn_emb_dim]
|
| 166 |
-
# If FiLM is OFF but fid_embed exists: we CONCAT c → g becomes [B, gnn_emb_dim + fid_emb_dim]
|
| 167 |
-
self.gnn_out_dim = gnn_emb_dim + (fid_emb_dim if (self.fid_embed is not None and self.film is None) else 0)
|
| 168 |
-
|
| 169 |
-
# Task embeddings
|
| 170 |
-
self.task_embed = nn.Embedding(self.num_tasks, task_emb_dim) if (use_task_embed and task_emb_dim > 0) else None
|
| 171 |
-
|
| 172 |
-
# Per-task heads
|
| 173 |
-
head_in_dim = self.gnn_out_dim + (task_emb_dim if self.task_embed is not None else 0)
|
| 174 |
-
self.heads = nn.ModuleList([
|
| 175 |
-
TaskHead(
|
| 176 |
-
in_dim=head_in_dim,
|
| 177 |
-
hidden_dim=head_hidden,
|
| 178 |
-
depth=head_depth,
|
| 179 |
-
act=head_act,
|
| 180 |
-
dropout=head_dropout,
|
| 181 |
-
heteroscedastic=heteroscedastic,
|
| 182 |
-
) for _ in range(self.num_tasks)
|
| 183 |
-
])
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
def reset_parameters(self):
|
| 187 |
-
if self.fid_embed is not None:
|
| 188 |
-
nn.init.normal_(self.fid_embed.weight, mean=0.0, std=0.02)
|
| 189 |
-
if self.task_embed is not None:
|
| 190 |
-
nn.init.normal_(self.task_embed.weight, mean=0.0, std=0.02)
|
| 191 |
-
# Encoder/heads rely on their internal initializations.
|
| 192 |
-
|
| 193 |
-
def forward(self, data: Batch) -> dict:
|
| 194 |
-
"""
|
| 195 |
-
Returns:
|
| 196 |
-
{
|
| 197 |
-
"pred": [B, T] means,
|
| 198 |
-
"logvar": [B, T] optional if heteroscedastic,
|
| 199 |
-
"h": [B, D] graph embedding after FiLM (useful for diagnostics).
|
| 200 |
-
}
|
| 201 |
-
"""
|
| 202 |
-
x, edge_index = data.x, data.edge_index
|
| 203 |
-
edge_attr = getattr(data, "edge_attr", None)
|
| 204 |
-
batch = data.batch
|
| 205 |
-
if edge_attr is None and hasattr(self.encoder, "gnn_type") and self.encoder.gnn_type == "gine":
|
| 206 |
-
raise ValueError("GINE encoder requires edge_attr, but Batch.edge_attr is None.")
|
| 207 |
-
|
| 208 |
-
# Graph embedding
|
| 209 |
-
g = self.encoder(x, edge_index, edge_attr, batch) # [B, D]
|
| 210 |
-
|
| 211 |
-
# Fidelity conditioning
|
| 212 |
-
fid_idx = data.fid_idx.view(-1).long() # [B]
|
| 213 |
-
if self.fid_embed is not None:
|
| 214 |
-
c = self.fid_embed(fid_idx) # [B, C]
|
| 215 |
-
if self.film is not None:
|
| 216 |
-
g = self.film(g, c) # [B, D]
|
| 217 |
-
else:
|
| 218 |
-
g = torch.cat([g, c], dim=-1)
|
| 219 |
-
|
| 220 |
-
# Per-task heads
|
| 221 |
-
preds: List[torch.Tensor] = []
|
| 222 |
-
logvars: Optional[List[torch.Tensor]] = [] if self.hetero else None
|
| 223 |
-
for t_idx, head in enumerate(self.heads):
|
| 224 |
-
if self.task_embed is not None:
|
| 225 |
-
tvec = self.task_embed.weight[t_idx].unsqueeze(0).expand(g.size(0), -1)
|
| 226 |
-
z = torch.cat([g, tvec], dim=-1)
|
| 227 |
-
else:
|
| 228 |
-
z = g
|
| 229 |
-
out = head(z) # [B, 1] or [B, 2]
|
| 230 |
-
if self.hetero:
|
| 231 |
-
mu = out[..., 0:1]
|
| 232 |
-
lv = out[..., 1:2]
|
| 233 |
-
preds.append(mu)
|
| 234 |
-
logvars.append(lv) # type: ignore[arg-type]
|
| 235 |
-
else:
|
| 236 |
-
preds.append(out)
|
| 237 |
-
|
| 238 |
-
pred = torch.cat(preds, dim=-1) # [B, T]
|
| 239 |
-
result = {"pred": pred, "h": g}
|
| 240 |
-
if self.hetero and logvars is not None:
|
| 241 |
-
result["logvar"] = torch.cat(logvars, dim=-1) # [B, T]
|
| 242 |
-
return result
|
| 243 |
-
|
| 244 |
-
def regularization_loss(self) -> torch.Tensor:
|
| 245 |
-
"""
|
| 246 |
-
Optional small L2 on embeddings to keep them bounded.
|
| 247 |
-
"""
|
| 248 |
-
device = next(self.parameters()).device
|
| 249 |
-
reg = torch.zeros([], device=device)
|
| 250 |
-
if self.fid_embed is not None and self.fid_emb_l2 > 0:
|
| 251 |
-
reg = reg + self.fid_emb_l2 * (self.fid_embed.weight.pow(2).mean())
|
| 252 |
-
if self.task_embed is not None and self.task_emb_l2 > 0:
|
| 253 |
-
reg = reg + self.task_emb_l2 * (self.task_embed.weight.pow(2).mean())
|
| 254 |
-
return reg
|
| 255 |
-
|
| 256 |
-
|
| 257 |
-
def build_model(
|
| 258 |
-
*,
|
| 259 |
-
in_dim_node: int,
|
| 260 |
-
in_dim_edge: int,
|
| 261 |
-
task_names: List[str],
|
| 262 |
-
num_fids: int,
|
| 263 |
-
gnn_type: Literal["gine", "gin", "gcn"] = "gine",
|
| 264 |
-
gnn_emb_dim: int = 256,
|
| 265 |
-
gnn_layers: int = 5,
|
| 266 |
-
gnn_norm: Literal["batch", "layer", "none"] = "batch",
|
| 267 |
-
gnn_readout: Literal["mean", "sum", "max"] = "mean",
|
| 268 |
-
gnn_act: str = "relu",
|
| 269 |
-
gnn_dropout: float = 0.0,
|
| 270 |
-
gnn_residual: bool = True,
|
| 271 |
-
fid_emb_dim: int = 64,
|
| 272 |
-
use_film: bool = True,
|
| 273 |
-
use_task_embed: bool = True,
|
| 274 |
-
task_emb_dim: int = 32,
|
| 275 |
-
head_hidden: int = 512,
|
| 276 |
-
use_task_uncertainty: bool = False,
|
| 277 |
-
head_depth: int = 2,
|
| 278 |
-
head_act: str = "relu",
|
| 279 |
-
head_dropout: float = 0.0,
|
| 280 |
-
heteroscedastic: bool = False,
|
| 281 |
-
fid_emb_l2: float = 0.0,
|
| 282 |
-
task_emb_l2: float = 0.0,
|
| 283 |
-
) -> MultiTaskMultiFidelityModel:
|
| 284 |
-
"""
|
| 285 |
-
Factory to construct the multi-task, multi-fidelity model with a consistent API.
|
| 286 |
-
"""
|
| 287 |
-
return MultiTaskMultiFidelityModel(
|
| 288 |
-
in_dim_node=in_dim_node,
|
| 289 |
-
in_dim_edge=in_dim_edge,
|
| 290 |
-
task_names=task_names,
|
| 291 |
-
num_fids=num_fids,
|
| 292 |
-
gnn_type=gnn_type,
|
| 293 |
-
gnn_emb_dim=gnn_emb_dim,
|
| 294 |
-
gnn_layers=gnn_layers,
|
| 295 |
-
gnn_norm=gnn_norm,
|
| 296 |
-
gnn_readout=gnn_readout,
|
| 297 |
-
gnn_act=gnn_act,
|
| 298 |
-
gnn_dropout=gnn_dropout,
|
| 299 |
-
gnn_residual=gnn_residual,
|
| 300 |
-
fid_emb_dim=fid_emb_dim,
|
| 301 |
-
use_film=use_film,
|
| 302 |
-
use_task_embed=use_task_embed,
|
| 303 |
-
task_emb_dim=task_emb_dim,
|
| 304 |
-
head_hidden=head_hidden,
|
| 305 |
-
head_depth=head_depth,
|
| 306 |
-
head_act=head_act,
|
| 307 |
-
head_dropout=head_dropout,
|
| 308 |
-
heteroscedastic=heteroscedastic,
|
| 309 |
-
fid_emb_l2=fid_emb_l2,
|
| 310 |
-
task_emb_l2=task_emb_l2,
|
| 311 |
-
use_task_uncertainty=use_task_uncertainty,
|
| 312 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/predictor.py
DELETED
|
@@ -1,193 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import re
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from typing import Dict, List, Optional, Tuple
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
from torch_geometric.data import Data
|
| 10 |
-
|
| 11 |
-
from src.data_builder import featurize_smiles, TargetScaler
|
| 12 |
-
from src.model import build_model
|
| 13 |
-
from src.utils import to_device, apply_inverse_transform
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# -------------------------
|
| 17 |
-
# Unit correction (ML only)
|
| 18 |
-
# -------------------------
|
| 19 |
-
POST_SCALE = {
|
| 20 |
-
"td": 1e-7,
|
| 21 |
-
"dif": 1e-5,
|
| 22 |
-
"visc": 1e-3,
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _load_scaler_compat(path: Path) -> TargetScaler:
|
| 27 |
-
blob = torch.load(path, map_location="cpu")
|
| 28 |
-
if "mean" not in blob or "std" not in blob:
|
| 29 |
-
raise RuntimeError(f"Unrecognized target_scaler format: {path}")
|
| 30 |
-
|
| 31 |
-
ts = TargetScaler(
|
| 32 |
-
transforms=blob.get("transforms", None),
|
| 33 |
-
eps=blob.get("eps", None),
|
| 34 |
-
)
|
| 35 |
-
ts.load_state_dict({
|
| 36 |
-
"mean": blob["mean"].float(),
|
| 37 |
-
"std": blob["std"].float(),
|
| 38 |
-
"transforms": blob.get("transforms", ts.transforms),
|
| 39 |
-
"eps": blob.get("eps", ts.eps),
|
| 40 |
-
})
|
| 41 |
-
ts.targets = [str(t).lower() for t in blob.get("targets", [])]
|
| 42 |
-
return ts
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _infer_seed_from_name(path: Path) -> Optional[int]:
|
| 46 |
-
m = re.search(r"_([0-9]+)\.pt$", path.name)
|
| 47 |
-
return int(m.group(1)) if m else None
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _make_one_graph(smiles: str) -> Data:
|
| 51 |
-
x, edge_index, edge_attr = featurize_smiles(smiles)
|
| 52 |
-
d = Data(
|
| 53 |
-
x=x,
|
| 54 |
-
edge_index=edge_index,
|
| 55 |
-
edge_attr=edge_attr,
|
| 56 |
-
y=torch.zeros(1, 1),
|
| 57 |
-
y_mask=torch.zeros(1, 1, dtype=torch.bool),
|
| 58 |
-
fid_idx=torch.tensor([0], dtype=torch.long),
|
| 59 |
-
)
|
| 60 |
-
d.smiles = smiles
|
| 61 |
-
return d
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
class SingleTaskEnsemblePredictor:
|
| 65 |
-
"""
|
| 66 |
-
Single-task ensemble:
|
| 67 |
-
models/single_models/{prop}_single_model_{seed}.pt
|
| 68 |
-
models/single_models/{prop}_single_scalar_{seed}.pt
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
def __init__(self, models_dir: str = "models/single_models", device: str = "cpu"):
|
| 72 |
-
self.models_dir = Path(models_dir)
|
| 73 |
-
self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
|
| 74 |
-
self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {}
|
| 75 |
-
|
| 76 |
-
def available_seeds(self, prop: str) -> List[int]:
|
| 77 |
-
prop = prop.lower()
|
| 78 |
-
seeds = []
|
| 79 |
-
for p in self.models_dir.glob(f"{prop}_single_model_*.pt"):
|
| 80 |
-
s = _infer_seed_from_name(p)
|
| 81 |
-
if s is not None:
|
| 82 |
-
seeds.append(s)
|
| 83 |
-
return sorted(set(seeds))
|
| 84 |
-
|
| 85 |
-
def _load_one(self, prop: str, seed: int):
|
| 86 |
-
prop = prop.lower()
|
| 87 |
-
key = (prop, seed)
|
| 88 |
-
if key in self._cache:
|
| 89 |
-
return self._cache[key]
|
| 90 |
-
|
| 91 |
-
ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt"
|
| 92 |
-
scaler_path = self.models_dir / f"{prop}_single_scalar_{seed}.pt"
|
| 93 |
-
if not ckpt_path.exists() or not scaler_path.exists():
|
| 94 |
-
raise FileNotFoundError(f"Missing model/scaler for {prop} seed {seed}")
|
| 95 |
-
|
| 96 |
-
ckpt = torch.load(ckpt_path, map_location=self.device)
|
| 97 |
-
train_args = ckpt.get("args", {})
|
| 98 |
-
|
| 99 |
-
scaler = _load_scaler_compat(scaler_path)
|
| 100 |
-
task_names = list(getattr(scaler, "targets", [])) or [prop]
|
| 101 |
-
|
| 102 |
-
meta = {"train_args": train_args, "task_names": task_names}
|
| 103 |
-
self._cache[key] = (None, scaler, meta)
|
| 104 |
-
return self._cache[key]
|
| 105 |
-
|
| 106 |
-
def _build_model_if_needed(self, prop: str, seed: int, in_dim_node: int, in_dim_edge: int):
|
| 107 |
-
prop = prop.lower()
|
| 108 |
-
key = (prop, seed)
|
| 109 |
-
model, scaler, meta = self._cache[key]
|
| 110 |
-
if model is not None:
|
| 111 |
-
return model, scaler, meta
|
| 112 |
-
|
| 113 |
-
train_args = meta["train_args"]
|
| 114 |
-
task_names = meta["task_names"]
|
| 115 |
-
|
| 116 |
-
ckpt_path = self.models_dir / f"{prop}_single_model_{seed}.pt"
|
| 117 |
-
ckpt = torch.load(ckpt_path, map_location=self.device)
|
| 118 |
-
state_dict = ckpt["model"]
|
| 119 |
-
|
| 120 |
-
# infer num_fids from checkpoint
|
| 121 |
-
if "fid_embed.weight" in state_dict:
|
| 122 |
-
num_fids = state_dict["fid_embed.weight"].shape[0]
|
| 123 |
-
else:
|
| 124 |
-
num_fids = 1
|
| 125 |
-
|
| 126 |
-
model = build_model(
|
| 127 |
-
in_dim_node=in_dim_node,
|
| 128 |
-
in_dim_edge=in_dim_edge,
|
| 129 |
-
task_names=task_names,
|
| 130 |
-
num_fids=num_fids,
|
| 131 |
-
gnn_type=train_args.get("gnn_type", "gine"),
|
| 132 |
-
gnn_emb_dim=train_args.get("gnn_emb_dim", 256),
|
| 133 |
-
gnn_layers=train_args.get("gnn_layers", 5),
|
| 134 |
-
gnn_norm=train_args.get("gnn_norm", "batch"),
|
| 135 |
-
gnn_readout=train_args.get("gnn_readout", "mean"),
|
| 136 |
-
gnn_act=train_args.get("gnn_act", "relu"),
|
| 137 |
-
gnn_dropout=train_args.get("gnn_dropout", 0.0),
|
| 138 |
-
gnn_residual=train_args.get("gnn_residual", True),
|
| 139 |
-
fid_emb_dim=train_args.get("fid_emb_dim", 64),
|
| 140 |
-
use_film=train_args.get("use_film", True),
|
| 141 |
-
use_task_embed=train_args.get("use_task_embed", True),
|
| 142 |
-
task_emb_dim=train_args.get("task_emb_dim", 32),
|
| 143 |
-
head_hidden=train_args.get("head_hidden", 512),
|
| 144 |
-
head_depth=train_args.get("head_depth", 2),
|
| 145 |
-
head_act=train_args.get("head_act", "relu"),
|
| 146 |
-
head_dropout=train_args.get("head_dropout", 0.0),
|
| 147 |
-
heteroscedastic=train_args.get("heteroscedastic", False),
|
| 148 |
-
fid_emb_l2=0.0,
|
| 149 |
-
task_emb_l2=0.0,
|
| 150 |
-
use_task_uncertainty=train_args.get("task_uncertainty", False),
|
| 151 |
-
).to(self.device)
|
| 152 |
-
|
| 153 |
-
model.load_state_dict(state_dict, strict=True)
|
| 154 |
-
model.eval()
|
| 155 |
-
|
| 156 |
-
self._cache[key] = (model, scaler, meta)
|
| 157 |
-
return model, scaler, meta
|
| 158 |
-
|
| 159 |
-
def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]:
|
| 160 |
-
prop = prop.lower()
|
| 161 |
-
seeds = self.available_seeds(prop)
|
| 162 |
-
if not seeds:
|
| 163 |
-
return None, None, {}
|
| 164 |
-
|
| 165 |
-
try:
|
| 166 |
-
g = _make_one_graph(smiles)
|
| 167 |
-
except Exception:
|
| 168 |
-
return None, None, {}
|
| 169 |
-
|
| 170 |
-
in_dim_node = g.x.shape[1]
|
| 171 |
-
in_dim_edge = g.edge_attr.shape[1]
|
| 172 |
-
|
| 173 |
-
per_seed: Dict[int, float] = {}
|
| 174 |
-
with torch.no_grad():
|
| 175 |
-
for seed in seeds:
|
| 176 |
-
self._load_one(prop, seed)
|
| 177 |
-
model, scaler, meta = self._build_model_if_needed(prop, seed, in_dim_node, in_dim_edge)
|
| 178 |
-
|
| 179 |
-
batch = to_device(g, self.device)
|
| 180 |
-
out = model(batch)
|
| 181 |
-
pred_n = out["pred"] # [1, 1]
|
| 182 |
-
pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1)
|
| 183 |
-
val = float(pred[0])
|
| 184 |
-
|
| 185 |
-
# unit correction
|
| 186 |
-
val *= POST_SCALE.get(prop, 1.0)
|
| 187 |
-
|
| 188 |
-
per_seed[seed] = val
|
| 189 |
-
|
| 190 |
-
vals = np.array(list(per_seed.values()), dtype=float)
|
| 191 |
-
mean = float(vals.mean())
|
| 192 |
-
std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0
|
| 193 |
-
return mean, std, per_seed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/predictor_multitask.py
DELETED
|
@@ -1,209 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import re
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from typing import Dict, List, Optional, Tuple
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
from torch_geometric.data import Data
|
| 10 |
-
|
| 11 |
-
from src.data_builder import featurize_smiles, TargetScaler
|
| 12 |
-
from src.model import build_model
|
| 13 |
-
from src.utils import to_device, apply_inverse_transform
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
# -------------------------
|
| 17 |
-
# Unit correction (ML only)
|
| 18 |
-
# -------------------------
|
| 19 |
-
POST_SCALE = {
|
| 20 |
-
"td": 1e-7,
|
| 21 |
-
"dif": 1e-5,
|
| 22 |
-
"visc": 1e-3,
|
| 23 |
-
}
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
def _load_scaler_compat(path: Path) -> TargetScaler:
|
| 27 |
-
blob = torch.load(path, map_location="cpu")
|
| 28 |
-
if "mean" not in blob or "std" not in blob:
|
| 29 |
-
raise RuntimeError(f"Unrecognized target_scaler format: {path}")
|
| 30 |
-
|
| 31 |
-
ts = TargetScaler(
|
| 32 |
-
transforms=blob.get("transforms", None),
|
| 33 |
-
eps=blob.get("eps", None),
|
| 34 |
-
)
|
| 35 |
-
ts.load_state_dict({
|
| 36 |
-
"mean": blob["mean"].float(),
|
| 37 |
-
"std": blob["std"].float(),
|
| 38 |
-
"transforms": blob.get("transforms", ts.transforms),
|
| 39 |
-
"eps": blob.get("eps", ts.eps),
|
| 40 |
-
})
|
| 41 |
-
ts.targets = [str(t).lower() for t in blob.get("targets", [])]
|
| 42 |
-
return ts
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def _infer_seed(path: Path) -> Optional[int]:
|
| 46 |
-
m = re.search(r"_([0-9]+)\.pt$", path.name)
|
| 47 |
-
return int(m.group(1)) if m else None
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
def _make_one_graph(smiles: str, T: int, fid_idx: int = 0) -> Data:
|
| 51 |
-
x, edge_index, edge_attr = featurize_smiles(smiles)
|
| 52 |
-
d = Data(
|
| 53 |
-
x=x,
|
| 54 |
-
edge_index=edge_index,
|
| 55 |
-
edge_attr=edge_attr,
|
| 56 |
-
y=torch.zeros(1, T),
|
| 57 |
-
y_mask=torch.zeros(1, T, dtype=torch.bool),
|
| 58 |
-
fid_idx=torch.tensor([fid_idx], dtype=torch.long),
|
| 59 |
-
)
|
| 60 |
-
d.smiles = smiles
|
| 61 |
-
return d
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
class MultiTaskEnsemblePredictor:
|
| 65 |
-
"""
|
| 66 |
-
Multi-task ensemble:
|
| 67 |
-
models/multitask_models/{task}_model_{seed}.pt
|
| 68 |
-
models/multitask_models/{task}_scalar_{seed}.pt
|
| 69 |
-
"""
|
| 70 |
-
|
| 71 |
-
def __init__(self, models_dir: str = "models/multitask_models", device: str = "cpu"):
|
| 72 |
-
self.models_dir = Path(models_dir)
|
| 73 |
-
self.device = torch.device(device if device == "cuda" and torch.cuda.is_available() else "cpu")
|
| 74 |
-
self._cache: Dict[Tuple[str, int], Tuple[Optional[torch.nn.Module], TargetScaler, dict]] = {}
|
| 75 |
-
|
| 76 |
-
def available_seeds(self, task: str) -> List[int]:
|
| 77 |
-
task = task.strip().lower()
|
| 78 |
-
seeds = []
|
| 79 |
-
for p in self.models_dir.glob(f"{task}_model_*.pt"):
|
| 80 |
-
s = _infer_seed(p)
|
| 81 |
-
if s is not None:
|
| 82 |
-
seeds.append(s)
|
| 83 |
-
return sorted(set(seeds))
|
| 84 |
-
|
| 85 |
-
def _load_one_meta(self, task: str, seed: int):
|
| 86 |
-
task = task.strip().lower()
|
| 87 |
-
key = (task, seed)
|
| 88 |
-
if key in self._cache:
|
| 89 |
-
return self._cache[key]
|
| 90 |
-
|
| 91 |
-
ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
|
| 92 |
-
scaler_path = self.models_dir / f"{task}_scalar_{seed}.pt"
|
| 93 |
-
if not ckpt_path.exists() or not scaler_path.exists():
|
| 94 |
-
raise FileNotFoundError(f"Missing model/scaler for task={task} seed={seed}")
|
| 95 |
-
|
| 96 |
-
ckpt = torch.load(ckpt_path, map_location=self.device)
|
| 97 |
-
state_dict = ckpt["model"]
|
| 98 |
-
train_args = ckpt.get("args", {})
|
| 99 |
-
|
| 100 |
-
scaler = _load_scaler_compat(scaler_path)
|
| 101 |
-
task_names = list(getattr(scaler, "targets", []))
|
| 102 |
-
if not task_names:
|
| 103 |
-
raise RuntimeError(f"No targets found in scaler: {scaler_path}")
|
| 104 |
-
|
| 105 |
-
if "fid_embed.weight" in state_dict:
|
| 106 |
-
num_fids = state_dict["fid_embed.weight"].shape[0]
|
| 107 |
-
else:
|
| 108 |
-
num_fids = 1
|
| 109 |
-
|
| 110 |
-
meta = {
|
| 111 |
-
"train_args": train_args,
|
| 112 |
-
"task_names": task_names,
|
| 113 |
-
"num_fids": num_fids,
|
| 114 |
-
}
|
| 115 |
-
self._cache[key] = (None, scaler, meta)
|
| 116 |
-
return self._cache[key]
|
| 117 |
-
|
| 118 |
-
def _build_if_needed(self, task: str, seed: int, in_dim_node: int, in_dim_edge: int):
|
| 119 |
-
task = task.strip().lower()
|
| 120 |
-
key = (task, seed)
|
| 121 |
-
model, scaler, meta = self._cache[key]
|
| 122 |
-
if model is not None:
|
| 123 |
-
return model, scaler, meta
|
| 124 |
-
|
| 125 |
-
train_args = meta["train_args"]
|
| 126 |
-
task_names = meta["task_names"]
|
| 127 |
-
num_fids = meta["num_fids"]
|
| 128 |
-
|
| 129 |
-
model = build_model(
|
| 130 |
-
in_dim_node=in_dim_node,
|
| 131 |
-
in_dim_edge=in_dim_edge,
|
| 132 |
-
task_names=task_names,
|
| 133 |
-
num_fids=num_fids,
|
| 134 |
-
gnn_type=train_args.get("gnn_type", "gine"),
|
| 135 |
-
gnn_emb_dim=train_args.get("gnn_emb_dim", 256),
|
| 136 |
-
gnn_layers=train_args.get("gnn_layers", 5),
|
| 137 |
-
gnn_norm=train_args.get("gnn_norm", "batch"),
|
| 138 |
-
gnn_readout=train_args.get("gnn_readout", "mean"),
|
| 139 |
-
gnn_act=train_args.get("gnn_act", "relu"),
|
| 140 |
-
gnn_dropout=train_args.get("gnn_dropout", 0.0),
|
| 141 |
-
gnn_residual=train_args.get("gnn_residual", True),
|
| 142 |
-
fid_emb_dim=train_args.get("fid_emb_dim", 64),
|
| 143 |
-
use_film=train_args.get("use_film", True),
|
| 144 |
-
use_task_embed=train_args.get("use_task_embed", True),
|
| 145 |
-
task_emb_dim=train_args.get("task_emb_dim", 32),
|
| 146 |
-
head_hidden=train_args.get("head_hidden", 512),
|
| 147 |
-
head_depth=train_args.get("head_depth", 2),
|
| 148 |
-
head_act=train_args.get("head_act", "relu"),
|
| 149 |
-
head_dropout=train_args.get("head_dropout", 0.0),
|
| 150 |
-
heteroscedastic=train_args.get("heteroscedastic", False),
|
| 151 |
-
fid_emb_l2=0.0,
|
| 152 |
-
task_emb_l2=0.0,
|
| 153 |
-
use_task_uncertainty=train_args.get("task_uncertainty", False),
|
| 154 |
-
).to(self.device)
|
| 155 |
-
|
| 156 |
-
ckpt_path = self.models_dir / f"{task}_model_{seed}.pt"
|
| 157 |
-
ckpt = torch.load(ckpt_path, map_location=self.device)
|
| 158 |
-
model.load_state_dict(ckpt["model"], strict=True)
|
| 159 |
-
model.eval()
|
| 160 |
-
|
| 161 |
-
self._cache[key] = (model, scaler, meta)
|
| 162 |
-
return model, scaler, meta
|
| 163 |
-
|
| 164 |
-
def predict_mean_std(self, smiles: str, prop_key: str, task: str) -> Tuple[Optional[float], Optional[float], Dict[int, float]]:
|
| 165 |
-
task = task.strip().lower()
|
| 166 |
-
prop_key = prop_key.lower()
|
| 167 |
-
|
| 168 |
-
seeds = self.available_seeds(task)
|
| 169 |
-
if not seeds:
|
| 170 |
-
return None, None, {}
|
| 171 |
-
|
| 172 |
-
self._load_one_meta(task, seeds[0])
|
| 173 |
-
_, scaler0, meta0 = self._cache[(task, seeds[0])]
|
| 174 |
-
targets = list(meta0["task_names"]) # already lower()
|
| 175 |
-
if prop_key not in targets:
|
| 176 |
-
return None, None, {}
|
| 177 |
-
|
| 178 |
-
t_idx = targets.index(prop_key)
|
| 179 |
-
T = len(targets)
|
| 180 |
-
|
| 181 |
-
try:
|
| 182 |
-
g = _make_one_graph(smiles, T=T, fid_idx=0)
|
| 183 |
-
except Exception:
|
| 184 |
-
return None, None, {}
|
| 185 |
-
|
| 186 |
-
in_dim_node = g.x.shape[1]
|
| 187 |
-
in_dim_edge = g.edge_attr.shape[1]
|
| 188 |
-
|
| 189 |
-
per_seed: Dict[int, float] = {}
|
| 190 |
-
with torch.no_grad():
|
| 191 |
-
for seed in seeds:
|
| 192 |
-
self._load_one_meta(task, seed)
|
| 193 |
-
model, scaler, meta = self._build_if_needed(task, seed, in_dim_node, in_dim_edge)
|
| 194 |
-
|
| 195 |
-
batch = to_device(g, self.device)
|
| 196 |
-
out = model(batch)
|
| 197 |
-
pred_n = out["pred"] # [1, T]
|
| 198 |
-
pred = apply_inverse_transform(pred_n, scaler).cpu().numpy().reshape(-1)
|
| 199 |
-
val = float(pred[t_idx])
|
| 200 |
-
|
| 201 |
-
# unit correction
|
| 202 |
-
val *= POST_SCALE.get(prop_key, 1.0)
|
| 203 |
-
|
| 204 |
-
per_seed[seed] = val
|
| 205 |
-
|
| 206 |
-
vals = np.array(list(per_seed.values()), dtype=float)
|
| 207 |
-
mean = float(vals.mean())
|
| 208 |
-
std = float(vals.std(ddof=1)) if len(vals) > 1 else 0.0
|
| 209 |
-
return mean, std, per_seed
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/predictor_router.py
DELETED
|
@@ -1,45 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import json
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from typing import Dict, Optional, Tuple
|
| 6 |
-
|
| 7 |
-
from src.predictor import SingleTaskEnsemblePredictor
|
| 8 |
-
from src.predictor_multitask import MultiTaskEnsemblePredictor
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
class RouterPredictor:
|
| 12 |
-
"""
|
| 13 |
-
Routes each property to either:
|
| 14 |
-
- single-task ensemble (models/single_models)
|
| 15 |
-
- multitask ensemble (models/multitask_models/{task}_*)
|
| 16 |
-
based on models/best_model_map.json
|
| 17 |
-
"""
|
| 18 |
-
|
| 19 |
-
def __init__(
|
| 20 |
-
self,
|
| 21 |
-
map_path: str = "models/best_model_map.json",
|
| 22 |
-
single_dir: str = "models/single_models",
|
| 23 |
-
multitask_dir: str = "models/multitask_models",
|
| 24 |
-
device: str = "cpu",
|
| 25 |
-
):
|
| 26 |
-
self.map_path = Path(map_path)
|
| 27 |
-
self.map: Dict[str, dict] = json.load(open(self.map_path))
|
| 28 |
-
self.single = SingleTaskEnsemblePredictor(models_dir=single_dir, device=device)
|
| 29 |
-
self.multi = MultiTaskEnsemblePredictor(models_dir=multitask_dir, device=device)
|
| 30 |
-
|
| 31 |
-
def predict_mean_std(self, smiles: str, prop: str) -> Tuple[Optional[float], Optional[float], dict, str]:
|
| 32 |
-
prop = prop.lower()
|
| 33 |
-
cfg = self.map.get(prop, {"family": "single"})
|
| 34 |
-
|
| 35 |
-
fam = cfg.get("family", "single").lower()
|
| 36 |
-
if fam == "multitask":
|
| 37 |
-
task = str(cfg.get("task", "all")).lower()
|
| 38 |
-
mean, std, per_seed = self.multi.predict_mean_std(smiles, prop_key=prop, task=task)
|
| 39 |
-
label = f"multitask:{task}"
|
| 40 |
-
return mean, std, per_seed, label
|
| 41 |
-
|
| 42 |
-
# default: single
|
| 43 |
-
mean, std, per_seed = self.single.predict_mean_std(smiles, prop)
|
| 44 |
-
label = "single"
|
| 45 |
-
return mean, std, per_seed, label
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rnn_smiles/__init__.py
DELETED
|
@@ -1,22 +0,0 @@
|
|
| 1 |
-
"""RNN-based SMILES generation helpers for Streamlit pages."""
|
| 2 |
-
|
| 3 |
-
from .generator import (
|
| 4 |
-
canonicalize_smiles,
|
| 5 |
-
filter_novel_smiles,
|
| 6 |
-
generate_smiles,
|
| 7 |
-
load_existing_smiles_set,
|
| 8 |
-
load_rnn_model,
|
| 9 |
-
)
|
| 10 |
-
from .rnn import MultiGRU, RNN
|
| 11 |
-
from .vocabulary import Vocabulary
|
| 12 |
-
|
| 13 |
-
__all__ = [
|
| 14 |
-
"canonicalize_smiles",
|
| 15 |
-
"filter_novel_smiles",
|
| 16 |
-
"generate_smiles",
|
| 17 |
-
"load_existing_smiles_set",
|
| 18 |
-
"load_rnn_model",
|
| 19 |
-
"MultiGRU",
|
| 20 |
-
"RNN",
|
| 21 |
-
"Vocabulary",
|
| 22 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rnn_smiles/generator.py
DELETED
|
@@ -1,175 +0,0 @@
|
|
| 1 |
-
"""Streamlit integration helpers for RNN SMILES generation."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
from pathlib import Path
|
| 6 |
-
from typing import Iterable, Sequence
|
| 7 |
-
|
| 8 |
-
import pandas as pd
|
| 9 |
-
import streamlit as st
|
| 10 |
-
import torch
|
| 11 |
-
from rdkit import Chem, RDLogger
|
| 12 |
-
|
| 13 |
-
from .rnn import RNN
|
| 14 |
-
from .vocabulary import Vocabulary
|
| 15 |
-
|
| 16 |
-
RDLogger.DisableLog("rdApp.*")
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
def canonicalize_smiles(smiles: str) -> str | None:
|
| 20 |
-
s = (smiles or "").strip()
|
| 21 |
-
if not s:
|
| 22 |
-
return None
|
| 23 |
-
mol = Chem.MolFromSmiles(s)
|
| 24 |
-
if mol is None:
|
| 25 |
-
return None
|
| 26 |
-
return Chem.MolToSmiles(mol, canonical=True)
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
def _find_smiles_column(path: Path) -> str:
|
| 30 |
-
header = pd.read_csv(path, nrows=0)
|
| 31 |
-
for col in header.columns:
|
| 32 |
-
if str(col).strip().lower() == "smiles":
|
| 33 |
-
return col
|
| 34 |
-
raise ValueError(f"No SMILES column found in {path}")
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
def _load_checkpoint(path: Path, device: torch.device) -> dict:
|
| 38 |
-
# Support both new/old torch signatures while preferring secure load mode.
|
| 39 |
-
try:
|
| 40 |
-
state = torch.load(path, map_location=device, weights_only=True)
|
| 41 |
-
except TypeError:
|
| 42 |
-
state = torch.load(path, map_location=device)
|
| 43 |
-
if isinstance(state, dict) and isinstance(state.get("state_dict"), dict):
|
| 44 |
-
state = state["state_dict"]
|
| 45 |
-
if not isinstance(state, dict):
|
| 46 |
-
raise RuntimeError(f"Checkpoint does not contain a state dict: {path}")
|
| 47 |
-
return state
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
@st.cache_resource(show_spinner=False)
|
| 51 |
-
def load_rnn_model(ckpt_path: str | Path, voc_path: str | Path) -> tuple[RNN, Vocabulary]:
|
| 52 |
-
ckpt_path = Path(ckpt_path).expanduser().resolve()
|
| 53 |
-
voc_path = Path(voc_path).expanduser().resolve()
|
| 54 |
-
|
| 55 |
-
if not ckpt_path.exists():
|
| 56 |
-
raise FileNotFoundError(f"Checkpoint not found: {ckpt_path}")
|
| 57 |
-
if not voc_path.exists():
|
| 58 |
-
raise FileNotFoundError(f"Vocabulary not found: {voc_path}")
|
| 59 |
-
|
| 60 |
-
voc = Vocabulary(init_from_file=str(voc_path))
|
| 61 |
-
model = RNN(voc)
|
| 62 |
-
model_device = next(model.rnn.parameters()).device
|
| 63 |
-
state = _load_checkpoint(ckpt_path, model_device)
|
| 64 |
-
|
| 65 |
-
ckpt_vocab_size = None
|
| 66 |
-
if "embedding.weight" in state:
|
| 67 |
-
ckpt_vocab_size = int(state["embedding.weight"].shape[0])
|
| 68 |
-
if ckpt_vocab_size is not None and ckpt_vocab_size != voc.vocab_size:
|
| 69 |
-
raise RuntimeError(
|
| 70 |
-
f"Vocabulary size mismatch: voc has {voc.vocab_size} tokens, "
|
| 71 |
-
f"checkpoint expects {ckpt_vocab_size}. "
|
| 72 |
-
"Use the matching vocab file for this checkpoint."
|
| 73 |
-
)
|
| 74 |
-
|
| 75 |
-
model.rnn.load_state_dict(state)
|
| 76 |
-
model.rnn.eval()
|
| 77 |
-
return model, voc
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
def _sample_with_temperature(
|
| 81 |
-
model: RNN, voc: Vocabulary, batch_size: int, max_length: int, temperature: float
|
| 82 |
-
) -> torch.Tensor:
|
| 83 |
-
temp = max(float(temperature), 1e-6)
|
| 84 |
-
device = next(model.rnn.parameters()).device
|
| 85 |
-
start_token = torch.full((batch_size,), voc.vocab["GO"], dtype=torch.long, device=device)
|
| 86 |
-
h = model.rnn.init_h(batch_size)
|
| 87 |
-
x = start_token
|
| 88 |
-
|
| 89 |
-
sequences: list[torch.Tensor] = []
|
| 90 |
-
finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 91 |
-
|
| 92 |
-
for _ in range(max_length):
|
| 93 |
-
logits, h = model.rnn(x, h)
|
| 94 |
-
logits = logits / temp
|
| 95 |
-
prob = torch.softmax(logits, dim=1)
|
| 96 |
-
x = torch.multinomial(prob, 1).view(-1)
|
| 97 |
-
sequences.append(x.view(-1, 1))
|
| 98 |
-
finished = finished | (x == voc.vocab["EOS"])
|
| 99 |
-
if torch.all(finished):
|
| 100 |
-
break
|
| 101 |
-
|
| 102 |
-
if not sequences:
|
| 103 |
-
return torch.empty((batch_size, 0), dtype=torch.long, device=device)
|
| 104 |
-
return torch.cat(sequences, dim=1)
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
def generate_smiles(
|
| 108 |
-
model: RNN,
|
| 109 |
-
voc: Vocabulary,
|
| 110 |
-
n: int,
|
| 111 |
-
max_length: int,
|
| 112 |
-
temperature: float = 1.0,
|
| 113 |
-
) -> list[str]:
|
| 114 |
-
if n <= 0:
|
| 115 |
-
return []
|
| 116 |
-
max_length = max(int(max_length), 1)
|
| 117 |
-
|
| 118 |
-
with torch.no_grad():
|
| 119 |
-
if abs(float(temperature) - 1.0) < 1e-8:
|
| 120 |
-
seqs, _, _ = model.sample(int(n), max_length=max_length)
|
| 121 |
-
else:
|
| 122 |
-
seqs = _sample_with_temperature(
|
| 123 |
-
model,
|
| 124 |
-
voc,
|
| 125 |
-
int(n),
|
| 126 |
-
max_length,
|
| 127 |
-
float(temperature),
|
| 128 |
-
)
|
| 129 |
-
arr = seqs.detach().cpu().numpy()
|
| 130 |
-
|
| 131 |
-
output: list[str] = []
|
| 132 |
-
for seq in arr:
|
| 133 |
-
output.append(voc.decode(seq))
|
| 134 |
-
return output
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
def filter_novel_smiles(smiles: Iterable[str], existing: set[str]) -> list[str]:
|
| 138 |
-
novel: list[str] = []
|
| 139 |
-
seen: set[str] = set()
|
| 140 |
-
for smi in smiles:
|
| 141 |
-
canonical = canonicalize_smiles(smi)
|
| 142 |
-
if canonical is None:
|
| 143 |
-
continue
|
| 144 |
-
if canonical in seen:
|
| 145 |
-
continue
|
| 146 |
-
seen.add(canonical)
|
| 147 |
-
if canonical in existing:
|
| 148 |
-
continue
|
| 149 |
-
novel.append(canonical)
|
| 150 |
-
return novel
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
@st.cache_resource(show_spinner=False)
|
| 154 |
-
def load_existing_smiles_set(csv_paths: Sequence[str | Path], chunksize: int = 200_000) -> set[str]:
|
| 155 |
-
existing: set[str] = set()
|
| 156 |
-
for p in csv_paths:
|
| 157 |
-
path = Path(p)
|
| 158 |
-
if not path.exists():
|
| 159 |
-
continue
|
| 160 |
-
col = _find_smiles_column(path)
|
| 161 |
-
for chunk in pd.read_csv(path, usecols=[col], chunksize=int(chunksize)):
|
| 162 |
-
for smiles in chunk[col].astype(str):
|
| 163 |
-
canonical = canonicalize_smiles(smiles)
|
| 164 |
-
if canonical:
|
| 165 |
-
existing.add(canonical)
|
| 166 |
-
return existing
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
__all__ = [
|
| 170 |
-
"canonicalize_smiles",
|
| 171 |
-
"load_rnn_model",
|
| 172 |
-
"generate_smiles",
|
| 173 |
-
"filter_novel_smiles",
|
| 174 |
-
"load_existing_smiles_set",
|
| 175 |
-
]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rnn_smiles/rnn.py
DELETED
|
@@ -1,89 +0,0 @@
|
|
| 1 |
-
"""Core GRU model used for polymer SMILES generation."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
import torch.nn as nn
|
| 7 |
-
import torch.nn.functional as F
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class MultiGRU(nn.Module):
|
| 11 |
-
def __init__(self, vocab_size: int):
|
| 12 |
-
super().__init__()
|
| 13 |
-
self.embedding = nn.Embedding(vocab_size, 128)
|
| 14 |
-
self.gru_1 = nn.GRUCell(128, 512)
|
| 15 |
-
self.gru_2 = nn.GRUCell(512, 512)
|
| 16 |
-
self.gru_3 = nn.GRUCell(512, 512)
|
| 17 |
-
self.linear = nn.Linear(512, vocab_size)
|
| 18 |
-
|
| 19 |
-
def forward(self, x: torch.Tensor, h: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 20 |
-
x = self.embedding(x)
|
| 21 |
-
h_out = torch.zeros_like(h)
|
| 22 |
-
x = h_out[0] = self.gru_1(x, h[0])
|
| 23 |
-
x = h_out[1] = self.gru_2(x, h[1])
|
| 24 |
-
x = h_out[2] = self.gru_3(x, h[2])
|
| 25 |
-
x = self.linear(x)
|
| 26 |
-
return x, h_out
|
| 27 |
-
|
| 28 |
-
def init_h(self, batch_size: int) -> torch.Tensor:
|
| 29 |
-
device = next(self.parameters()).device
|
| 30 |
-
return torch.zeros(3, batch_size, 512, device=device)
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
def nll_loss(log_probs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
|
| 34 |
-
# Gather selected token log-probability for each sample in batch.
|
| 35 |
-
return log_probs.gather(1, targets.contiguous().view(-1, 1)).squeeze(1)
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
class RNN:
|
| 39 |
-
def __init__(self, voc):
|
| 40 |
-
self.rnn = MultiGRU(voc.vocab_size)
|
| 41 |
-
if torch.cuda.is_available():
|
| 42 |
-
self.rnn.cuda()
|
| 43 |
-
self.voc = voc
|
| 44 |
-
|
| 45 |
-
def likelihood(self, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
| 46 |
-
batch_size, seq_length = target.size()
|
| 47 |
-
device = target.device
|
| 48 |
-
start_token = torch.full((batch_size, 1), self.voc.vocab["GO"], dtype=torch.long, device=device)
|
| 49 |
-
x = torch.cat((start_token, target[:, :-1]), 1)
|
| 50 |
-
h = self.rnn.init_h(batch_size)
|
| 51 |
-
|
| 52 |
-
log_probs = torch.zeros(batch_size, device=device)
|
| 53 |
-
entropy = torch.zeros(batch_size, device=device)
|
| 54 |
-
for step in range(seq_length):
|
| 55 |
-
logits, h = self.rnn(x[:, step], h)
|
| 56 |
-
log_prob = F.log_softmax(logits, dim=1)
|
| 57 |
-
prob = F.softmax(logits, dim=1)
|
| 58 |
-
log_probs += nll_loss(log_prob, target[:, step])
|
| 59 |
-
entropy += -torch.sum((log_prob * prob), 1)
|
| 60 |
-
return log_probs, entropy
|
| 61 |
-
|
| 62 |
-
def sample(self, batch_size: int, max_length: int = 140) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 63 |
-
device = next(self.rnn.parameters()).device
|
| 64 |
-
start_token = torch.full((batch_size,), self.voc.vocab["GO"], dtype=torch.long, device=device)
|
| 65 |
-
h = self.rnn.init_h(batch_size)
|
| 66 |
-
x = start_token
|
| 67 |
-
|
| 68 |
-
sequences: list[torch.Tensor] = []
|
| 69 |
-
log_probs = torch.zeros(batch_size, device=device)
|
| 70 |
-
finished = torch.zeros(batch_size, dtype=torch.bool, device=device)
|
| 71 |
-
entropy = torch.zeros(batch_size, device=device)
|
| 72 |
-
|
| 73 |
-
for _ in range(max_length):
|
| 74 |
-
logits, h = self.rnn(x, h)
|
| 75 |
-
prob = F.softmax(logits, dim=1)
|
| 76 |
-
log_prob = F.log_softmax(logits, dim=1)
|
| 77 |
-
x = torch.multinomial(prob, 1).view(-1)
|
| 78 |
-
sequences.append(x.view(-1, 1))
|
| 79 |
-
log_probs += nll_loss(log_prob, x)
|
| 80 |
-
entropy += -torch.sum((log_prob * prob), 1)
|
| 81 |
-
finished = finished | (x == self.voc.vocab["EOS"])
|
| 82 |
-
if torch.all(finished):
|
| 83 |
-
break
|
| 84 |
-
|
| 85 |
-
if sequences:
|
| 86 |
-
stacked = torch.cat(sequences, 1)
|
| 87 |
-
else:
|
| 88 |
-
stacked = torch.empty((batch_size, 0), dtype=torch.long, device=device)
|
| 89 |
-
return stacked, log_probs, entropy
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rnn_smiles/utils.py
DELETED
|
@@ -1,15 +0,0 @@
|
|
| 1 |
-
"""Utility helpers used by the legacy-style RNN generator."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import numpy as np
|
| 6 |
-
import torch
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
def variable(tensor: torch.Tensor | np.ndarray) -> torch.Tensor:
|
| 10 |
-
"""Return a tensor on GPU when available."""
|
| 11 |
-
if isinstance(tensor, np.ndarray):
|
| 12 |
-
tensor = torch.from_numpy(tensor)
|
| 13 |
-
if torch.cuda.is_available():
|
| 14 |
-
return tensor.cuda()
|
| 15 |
-
return tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/rnn_smiles/vocabulary.py
DELETED
|
@@ -1,69 +0,0 @@
|
|
| 1 |
-
"""Token vocabulary used by the SMILES RNN."""
|
| 2 |
-
|
| 3 |
-
from __future__ import annotations
|
| 4 |
-
|
| 5 |
-
import re
|
| 6 |
-
|
| 7 |
-
import numpy as np
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
class Vocabulary:
|
| 11 |
-
def __init__(self, init_from_file: str | None = None, max_length: int | None = None):
|
| 12 |
-
self.special_tokens = ["EOS", "GO"]
|
| 13 |
-
self.additional_chars: set[str] = set()
|
| 14 |
-
self.chars = self.special_tokens
|
| 15 |
-
self.vocab_size = len(self.chars)
|
| 16 |
-
self.vocab = dict(zip(self.chars, range(len(self.chars))))
|
| 17 |
-
self.reversed_vocab = {v: k for k, v in self.vocab.items()}
|
| 18 |
-
self.max_length = max_length
|
| 19 |
-
if init_from_file:
|
| 20 |
-
self.init_from_file(init_from_file)
|
| 21 |
-
|
| 22 |
-
def encode(self, char_list: list[str]) -> np.ndarray:
|
| 23 |
-
smiles_matrix = np.zeros(len(char_list), dtype=np.float32)
|
| 24 |
-
for i, char in enumerate(char_list):
|
| 25 |
-
smiles_matrix[i] = self.vocab[char]
|
| 26 |
-
return smiles_matrix
|
| 27 |
-
|
| 28 |
-
def decode(self, matrix: np.ndarray) -> str:
|
| 29 |
-
chars: list[str] = []
|
| 30 |
-
eos_id = self.vocab["EOS"]
|
| 31 |
-
for i in matrix:
|
| 32 |
-
if int(i) == eos_id:
|
| 33 |
-
break
|
| 34 |
-
chars.append(self.reversed_vocab[int(i)])
|
| 35 |
-
return "".join(chars)
|
| 36 |
-
|
| 37 |
-
def tokenize(self, smiles: str) -> list[str]:
|
| 38 |
-
regex = r"(\[[^\[\]]{1,6}\])"
|
| 39 |
-
char_list = re.split(regex, smiles)
|
| 40 |
-
tokenized: list[str] = []
|
| 41 |
-
for char in char_list:
|
| 42 |
-
if not char:
|
| 43 |
-
continue
|
| 44 |
-
if char.startswith("["):
|
| 45 |
-
tokenized.append(char)
|
| 46 |
-
else:
|
| 47 |
-
tokenized.extend(list(char))
|
| 48 |
-
tokenized.append("EOS")
|
| 49 |
-
return tokenized
|
| 50 |
-
|
| 51 |
-
def add_characters(self, chars: list[str]) -> None:
|
| 52 |
-
for char in chars:
|
| 53 |
-
self.additional_chars.add(char)
|
| 54 |
-
char_list = sorted(list(self.additional_chars))
|
| 55 |
-
self.chars = char_list + self.special_tokens
|
| 56 |
-
self.vocab_size = len(self.chars)
|
| 57 |
-
self.vocab = dict(zip(self.chars, range(len(self.chars))))
|
| 58 |
-
self.reversed_vocab = {v: k for k, v in self.vocab.items()}
|
| 59 |
-
|
| 60 |
-
def init_from_file(self, file_path: str) -> None:
|
| 61 |
-
with open(file_path, "r", encoding="utf-8") as f:
|
| 62 |
-
chars = f.read().split()
|
| 63 |
-
self.add_characters(chars)
|
| 64 |
-
|
| 65 |
-
def __len__(self) -> int:
|
| 66 |
-
return len(self.chars)
|
| 67 |
-
|
| 68 |
-
def __str__(self) -> str:
|
| 69 |
-
return f"Vocabulary containing {len(self)} tokens: {self.chars}"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/sascorer.py
DELETED
|
@@ -1,192 +0,0 @@
|
|
| 1 |
-
#
|
| 2 |
-
# calculation of synthetic accessibility score as described in:
|
| 3 |
-
#
|
| 4 |
-
# Estimation of Synthetic Accessibility Score of Drug-like Molecules based on Molecular Complexity and Fragment Contributions
|
| 5 |
-
# Peter Ertl and Ansgar Schuffenhauer
|
| 6 |
-
# Journal of Cheminformatics 1:8 (2009)
|
| 7 |
-
# http://www.jcheminf.com/content/1/1/8
|
| 8 |
-
#
|
| 9 |
-
# several small modifications to the original paper are included
|
| 10 |
-
# particularly slightly different formula for marocyclic penalty
|
| 11 |
-
# and taking into account also molecule symmetry (fingerprint density)
|
| 12 |
-
#
|
| 13 |
-
# for a set of 10k diverse molecules the agreement between the original method
|
| 14 |
-
# as implemented in PipelinePilot and this implementation is r2 = 0.97
|
| 15 |
-
#
|
| 16 |
-
# peter ertl & greg landrum, september 2013
|
| 17 |
-
#
|
| 18 |
-
|
| 19 |
-
from rdkit import Chem
|
| 20 |
-
from rdkit.Chem import rdFingerprintGenerator, rdMolDescriptors
|
| 21 |
-
|
| 22 |
-
import math
|
| 23 |
-
import pickle
|
| 24 |
-
|
| 25 |
-
import os.path as op
|
| 26 |
-
|
| 27 |
-
_fscores = None
|
| 28 |
-
mfpgen = rdFingerprintGenerator.GetMorganGenerator(radius=2)
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
def readFragmentScores(name="fpscores.pkl.gz"):
|
| 32 |
-
import gzip
|
| 33 |
-
global _fscores
|
| 34 |
-
# generate the full path filename:
|
| 35 |
-
if name == "fpscores.pkl.gz":
|
| 36 |
-
name = op.join(op.dirname(__file__), name)
|
| 37 |
-
data = pickle.load(gzip.open(name))
|
| 38 |
-
outDict = {}
|
| 39 |
-
for i in data:
|
| 40 |
-
for j in range(1, len(i)):
|
| 41 |
-
outDict[i[j]] = float(i[0])
|
| 42 |
-
_fscores = outDict
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
def numBridgeheadsAndSpiro(mol, ri=None):
|
| 46 |
-
nSpiro = rdMolDescriptors.CalcNumSpiroAtoms(mol)
|
| 47 |
-
nBridgehead = rdMolDescriptors.CalcNumBridgeheadAtoms(mol)
|
| 48 |
-
return nBridgehead, nSpiro
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
def calculateScore(m):
|
| 52 |
-
|
| 53 |
-
if not m.GetNumAtoms():
|
| 54 |
-
return None
|
| 55 |
-
|
| 56 |
-
if _fscores is None:
|
| 57 |
-
readFragmentScores()
|
| 58 |
-
|
| 59 |
-
# fragment score
|
| 60 |
-
sfp = mfpgen.GetSparseCountFingerprint(m)
|
| 61 |
-
|
| 62 |
-
score1 = 0.
|
| 63 |
-
nf = 0
|
| 64 |
-
nze = sfp.GetNonzeroElements()
|
| 65 |
-
for id, count in nze.items():
|
| 66 |
-
nf += count
|
| 67 |
-
score1 += _fscores.get(id, -4) * count
|
| 68 |
-
|
| 69 |
-
score1 /= nf
|
| 70 |
-
|
| 71 |
-
# features score
|
| 72 |
-
nAtoms = m.GetNumAtoms()
|
| 73 |
-
nChiralCenters = len(Chem.FindMolChiralCenters(m, includeUnassigned=True))
|
| 74 |
-
ri = m.GetRingInfo()
|
| 75 |
-
nBridgeheads, nSpiro = numBridgeheadsAndSpiro(m, ri)
|
| 76 |
-
nMacrocycles = 0
|
| 77 |
-
for x in ri.AtomRings():
|
| 78 |
-
if len(x) > 8:
|
| 79 |
-
nMacrocycles += 1
|
| 80 |
-
|
| 81 |
-
sizePenalty = nAtoms**1.005 - nAtoms
|
| 82 |
-
stereoPenalty = math.log10(nChiralCenters + 1)
|
| 83 |
-
spiroPenalty = math.log10(nSpiro + 1)
|
| 84 |
-
bridgePenalty = math.log10(nBridgeheads + 1)
|
| 85 |
-
macrocyclePenalty = 0.
|
| 86 |
-
# ---------------------------------------
|
| 87 |
-
# This differs from the paper, which defines:
|
| 88 |
-
# macrocyclePenalty = math.log10(nMacrocycles+1)
|
| 89 |
-
# This form generates better results when 2 or more macrocycles are present
|
| 90 |
-
if nMacrocycles > 0:
|
| 91 |
-
macrocyclePenalty = math.log10(2)
|
| 92 |
-
|
| 93 |
-
score2 = 0. - sizePenalty - stereoPenalty - spiroPenalty - bridgePenalty - macrocyclePenalty
|
| 94 |
-
|
| 95 |
-
# correction for the fingerprint density
|
| 96 |
-
# not in the original publication, added in version 1.1
|
| 97 |
-
# to make highly symmetrical molecules easier to synthetise
|
| 98 |
-
score3 = 0.
|
| 99 |
-
numBits = len(nze)
|
| 100 |
-
if nAtoms > numBits:
|
| 101 |
-
score3 = math.log(float(nAtoms) / numBits) * .5
|
| 102 |
-
|
| 103 |
-
sascore = score1 + score2 + score3
|
| 104 |
-
|
| 105 |
-
# need to transform "raw" value into scale between 1 and 10
|
| 106 |
-
min = -4.0
|
| 107 |
-
max = 2.5
|
| 108 |
-
sascore = 11. - (sascore - min + 1) / (max - min) * 9.
|
| 109 |
-
|
| 110 |
-
# smooth the 10-end
|
| 111 |
-
if sascore > 8.:
|
| 112 |
-
sascore = 8. + math.log(sascore + 1. - 9.)
|
| 113 |
-
if sascore > 10.:
|
| 114 |
-
sascore = 10.0
|
| 115 |
-
elif sascore < 1.:
|
| 116 |
-
sascore = 1.0
|
| 117 |
-
|
| 118 |
-
return sascore
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
def processMols(mols):
|
| 122 |
-
print('smiles\tName\tsa_score')
|
| 123 |
-
for i, m in enumerate(mols):
|
| 124 |
-
if m is None:
|
| 125 |
-
continue
|
| 126 |
-
|
| 127 |
-
s = calculateScore(m)
|
| 128 |
-
|
| 129 |
-
smiles = Chem.MolToSmiles(m)
|
| 130 |
-
if s is None:
|
| 131 |
-
print(f"{smiles}\t{m.GetProp('_Name')}\t{s}")
|
| 132 |
-
else:
|
| 133 |
-
print(f"{smiles}\t{m.GetProp('_Name')}\t{s:3f}")
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
if __name__ == '__main__':
|
| 137 |
-
import sys
|
| 138 |
-
import time
|
| 139 |
-
|
| 140 |
-
t1 = time.time()
|
| 141 |
-
if len(sys.argv) == 2:
|
| 142 |
-
readFragmentScores()
|
| 143 |
-
else:
|
| 144 |
-
readFragmentScores(sys.argv[2])
|
| 145 |
-
t2 = time.time()
|
| 146 |
-
|
| 147 |
-
molFile = sys.argv[1]
|
| 148 |
-
if molFile.endswith("smi"):
|
| 149 |
-
suppl = Chem.SmilesMolSupplier(molFile)
|
| 150 |
-
elif molFile.endswith("sdf"):
|
| 151 |
-
suppl = Chem.SDMolSupplier(molFile)
|
| 152 |
-
else:
|
| 153 |
-
print(f"Unrecognized file extension for {molFile}")
|
| 154 |
-
sys.exit()
|
| 155 |
-
|
| 156 |
-
t3 = time.time()
|
| 157 |
-
processMols(suppl)
|
| 158 |
-
t4 = time.time()
|
| 159 |
-
|
| 160 |
-
print('Reading took %.2f seconds. Calculating took %.2f seconds' % ((t2 - t1), (t4 - t3)),
|
| 161 |
-
file=sys.stderr)
|
| 162 |
-
|
| 163 |
-
#
|
| 164 |
-
# Copyright (c) 2013, Novartis Institutes for BioMedical Research Inc.
|
| 165 |
-
# All rights reserved.
|
| 166 |
-
#
|
| 167 |
-
# Redistribution and use in source and binary forms, with or without
|
| 168 |
-
# modification, are permitted provided that the following conditions are
|
| 169 |
-
# met:
|
| 170 |
-
#
|
| 171 |
-
# * Redistributions of source code must retain the above copyright
|
| 172 |
-
# notice, this list of conditions and the following disclaimer.
|
| 173 |
-
# * Redistributions in binary form must reproduce the above
|
| 174 |
-
# copyright notice, this list of conditions and the following
|
| 175 |
-
# disclaimer in the documentation and/or other materials provided
|
| 176 |
-
# with the distribution.
|
| 177 |
-
# * Neither the name of Novartis Institutes for BioMedical Research Inc.
|
| 178 |
-
# nor the names of its contributors may be used to endorse or promote
|
| 179 |
-
# products derived from this software without specific prior written permission.
|
| 180 |
-
#
|
| 181 |
-
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
|
| 182 |
-
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
|
| 183 |
-
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
|
| 184 |
-
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
|
| 185 |
-
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
|
| 186 |
-
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
|
| 187 |
-
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 188 |
-
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
|
| 189 |
-
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 190 |
-
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 191 |
-
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 192 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/streamlit_app.py
DELETED
|
@@ -1,40 +0,0 @@
|
|
| 1 |
-
import altair as alt
|
| 2 |
-
import numpy as np
|
| 3 |
-
import pandas as pd
|
| 4 |
-
import streamlit as st
|
| 5 |
-
|
| 6 |
-
"""
|
| 7 |
-
# Welcome to Streamlit!
|
| 8 |
-
|
| 9 |
-
Edit `/streamlit_app.py` to customize this app to your heart's desire :heart:.
|
| 10 |
-
If you have any questions, checkout our [documentation](https://docs.streamlit.io) and [community
|
| 11 |
-
forums](https://discuss.streamlit.io).
|
| 12 |
-
|
| 13 |
-
In the meantime, below is an example of what you can do with just a few lines of code:
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
-
num_points = st.slider("Number of points in spiral", 1, 10000, 1100)
|
| 17 |
-
num_turns = st.slider("Number of turns in spiral", 1, 300, 31)
|
| 18 |
-
|
| 19 |
-
indices = np.linspace(0, 1, num_points)
|
| 20 |
-
theta = 2 * np.pi * num_turns * indices
|
| 21 |
-
radius = indices
|
| 22 |
-
|
| 23 |
-
x = radius * np.cos(theta)
|
| 24 |
-
y = radius * np.sin(theta)
|
| 25 |
-
|
| 26 |
-
df = pd.DataFrame({
|
| 27 |
-
"x": x,
|
| 28 |
-
"y": y,
|
| 29 |
-
"idx": indices,
|
| 30 |
-
"rand": np.random.randn(num_points),
|
| 31 |
-
})
|
| 32 |
-
|
| 33 |
-
st.altair_chart(alt.Chart(df, height=700, width=700)
|
| 34 |
-
.mark_point(filled=True)
|
| 35 |
-
.encode(
|
| 36 |
-
x=alt.X("x", axis=None),
|
| 37 |
-
y=alt.Y("y", axis=None),
|
| 38 |
-
color=alt.Color("idx", legend=None, scale=alt.Scale()),
|
| 39 |
-
size=alt.Size("rand", legend=None, scale=alt.Scale(range=[1, 150])),
|
| 40 |
-
))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/ui_style.py
DELETED
|
@@ -1,1003 +0,0 @@
|
|
| 1 |
-
import base64
|
| 2 |
-
import html
|
| 3 |
-
import os
|
| 4 |
-
from pathlib import Path
|
| 5 |
-
from urllib import request
|
| 6 |
-
|
| 7 |
-
import streamlit as st
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
def _icon_data_uri(filename: str) -> str:
|
| 11 |
-
icon_path = Path(__file__).resolve().parent.parent / "icons" / filename
|
| 12 |
-
if not icon_path.exists():
|
| 13 |
-
return ""
|
| 14 |
-
try:
|
| 15 |
-
encoded = base64.b64encode(icon_path.read_bytes()).decode("ascii")
|
| 16 |
-
except Exception:
|
| 17 |
-
return ""
|
| 18 |
-
return f"data:image/png;base64,{encoded}"
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
def _config_value(name: str, default: str = "") -> str:
|
| 22 |
-
try:
|
| 23 |
-
if name in st.secrets:
|
| 24 |
-
return str(st.secrets[name]).strip()
|
| 25 |
-
except Exception:
|
| 26 |
-
pass
|
| 27 |
-
return str(os.getenv(name, default)).strip()
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
def _build_sidebar_icon_css() -> str:
|
| 31 |
-
fallback = {
|
| 32 |
-
1: "🏠",
|
| 33 |
-
2: "🔎",
|
| 34 |
-
3: "📦",
|
| 35 |
-
4: "🧬",
|
| 36 |
-
5: "⚙️",
|
| 37 |
-
6: "🧠",
|
| 38 |
-
7: "✨",
|
| 39 |
-
8: "💬",
|
| 40 |
-
9: "📚",
|
| 41 |
-
}
|
| 42 |
-
icon_name = {
|
| 43 |
-
1: "home1.png",
|
| 44 |
-
2: "probe1.png",
|
| 45 |
-
3: "batch1.png",
|
| 46 |
-
4: "molecule1.png",
|
| 47 |
-
5: "manual1.png",
|
| 48 |
-
6: "ai1.png",
|
| 49 |
-
7: "rnn1.png",
|
| 50 |
-
8: "literature.png",
|
| 51 |
-
9: "feedback.png",
|
| 52 |
-
}
|
| 53 |
-
rules = [
|
| 54 |
-
'[data-testid="stSidebarNav"] ul li a { position: relative; padding-left: 3.25rem !important; }',
|
| 55 |
-
'[data-testid="stSidebarNav"] ul li a::before { content: ""; position: absolute; left: 12px; top: 50%; transform: translateY(-50%); width: 32px; height: 32px; background-size: contain; background-repeat: no-repeat; background-position: center; }',
|
| 56 |
-
]
|
| 57 |
-
for idx in range(1, 10):
|
| 58 |
-
uri = _icon_data_uri(icon_name[idx])
|
| 59 |
-
if uri:
|
| 60 |
-
rules.append(
|
| 61 |
-
'[data-testid="stSidebarNav"] ul li:nth-of-type(%d) a::before { content: ""; background-image: url("%s"); }'
|
| 62 |
-
% (idx, uri)
|
| 63 |
-
)
|
| 64 |
-
else:
|
| 65 |
-
emoji = fallback[idx]
|
| 66 |
-
rules.append(
|
| 67 |
-
'[data-testid="stSidebarNav"] ul li:nth-of-type(%d) a::before { content: "%s"; background-image: none; width: auto; height: auto; font-size: 1.4rem; }'
|
| 68 |
-
% (idx, emoji)
|
| 69 |
-
)
|
| 70 |
-
return "\n".join(rules)
|
| 71 |
-
|
| 72 |
-
|
| 73 |
-
def _log_visit_once_per_session() -> None:
|
| 74 |
-
if st.session_state.get("_visit_logged"):
|
| 75 |
-
return
|
| 76 |
-
webhook_url = _config_value("FEEDBACK_WEBHOOK_URL", "")
|
| 77 |
-
webhook_token = _config_value("FEEDBACK_WEBHOOK_TOKEN", "")
|
| 78 |
-
if not webhook_url:
|
| 79 |
-
return
|
| 80 |
-
endpoint = webhook_url
|
| 81 |
-
sep = "&" if "?" in webhook_url else "?"
|
| 82 |
-
endpoint = f"{webhook_url}{sep}event=visit"
|
| 83 |
-
if webhook_token:
|
| 84 |
-
endpoint = f"{endpoint}&token={webhook_token}"
|
| 85 |
-
try:
|
| 86 |
-
with request.urlopen(endpoint, timeout=3):
|
| 87 |
-
pass
|
| 88 |
-
except Exception:
|
| 89 |
-
pass
|
| 90 |
-
st.session_state["_visit_logged"] = True
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
def render_page_header(title: str, subtitle: str = "", badge: str = "") -> None:
|
| 94 |
-
title_html = html.escape(title)
|
| 95 |
-
subtitle_html = html.escape(subtitle) if subtitle else ""
|
| 96 |
-
badge_html = html.escape(badge) if badge else ""
|
| 97 |
-
|
| 98 |
-
st.markdown(
|
| 99 |
-
f"""
|
| 100 |
-
<section class="pp-page-header">
|
| 101 |
-
{"<span class='pp-badge'>" + badge_html + "</span>" if badge_html else ""}
|
| 102 |
-
<h1 class="pp-page-title">{title_html}</h1>
|
| 103 |
-
{"<p class='pp-page-subtitle'>" + subtitle_html + "</p>" if subtitle_html else ""}
|
| 104 |
-
</section>
|
| 105 |
-
""",
|
| 106 |
-
unsafe_allow_html=True,
|
| 107 |
-
)
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
def apply_global_style() -> None:
|
| 111 |
-
_log_visit_once_per_session()
|
| 112 |
-
icon_css = _build_sidebar_icon_css()
|
| 113 |
-
css = """
|
| 114 |
-
<style>
|
| 115 |
-
@import url('https://fonts.googleapis.com/css2?family=Manrope:wght@400;500;600;700;800&display=swap');
|
| 116 |
-
|
| 117 |
-
:root {
|
| 118 |
-
--pp-text: #133a2a;
|
| 119 |
-
--pp-muted: #4e6f5d;
|
| 120 |
-
--pp-primary: #1f8f52;
|
| 121 |
-
--pp-primary-2: #2cab67;
|
| 122 |
-
--pp-border: rgba(255, 255, 255, 0.96);
|
| 123 |
-
--pp-surface: rgba(251, 247, 255, 0.90);
|
| 124 |
-
--pp-shadow: 0 10px 22px rgba(70, 66, 110, 0.12);
|
| 125 |
-
--pp-panel-shadow:
|
| 126 |
-
0 14px 28px rgba(83, 73, 124, 0.16),
|
| 127 |
-
0 1px 0 rgba(255, 255, 255, 0.58) inset;
|
| 128 |
-
}
|
| 129 |
-
|
| 130 |
-
html, body, [class*="css"], [data-testid="stMarkdownContainer"] * {
|
| 131 |
-
font-family: "Manrope", "Avenir Next", "Segoe UI", sans-serif;
|
| 132 |
-
}
|
| 133 |
-
|
| 134 |
-
[data-testid="stAppViewContainer"] {
|
| 135 |
-
min-height: 100vh;
|
| 136 |
-
overflow: visible !important;
|
| 137 |
-
background:
|
| 138 |
-
radial-gradient(1300px 700px at -10% 105%, rgba(188, 113, 202, 0.62), transparent 62%),
|
| 139 |
-
radial-gradient(1200px 700px at 108% 104%, rgba(130, 170, 235, 0.50), transparent 62%),
|
| 140 |
-
linear-gradient(120deg, #d5c4e3 0%, #cdd4e9 45%, #c4e1ea 100%) !important;
|
| 141 |
-
}
|
| 142 |
-
|
| 143 |
-
.stApp,
|
| 144 |
-
[data-testid="stAppViewContainer"] > .main,
|
| 145 |
-
section[data-testid="stMain"] {
|
| 146 |
-
color: var(--pp-text);
|
| 147 |
-
background: transparent !important;
|
| 148 |
-
}
|
| 149 |
-
|
| 150 |
-
[data-testid="stAppViewContainer"] > .main {
|
| 151 |
-
margin: 0 !important;
|
| 152 |
-
border: none !important;
|
| 153 |
-
border-radius: 0 !important;
|
| 154 |
-
box-shadow: none !important;
|
| 155 |
-
background: transparent !important;
|
| 156 |
-
overflow: visible !important;
|
| 157 |
-
}
|
| 158 |
-
|
| 159 |
-
section[data-testid="stMain"] {
|
| 160 |
-
position: relative;
|
| 161 |
-
margin: 12px 14px 12px 18px;
|
| 162 |
-
min-height: calc(100vh - 24px) !important;
|
| 163 |
-
border-radius: 30px;
|
| 164 |
-
border: 1px solid rgba(255, 255, 255, 0.98);
|
| 165 |
-
border-right: 2px solid rgba(233, 223, 245, 0.98);
|
| 166 |
-
background: rgba(244, 239, 250, 0.96) !important;
|
| 167 |
-
box-shadow:
|
| 168 |
-
var(--pp-panel-shadow),
|
| 169 |
-
inset -1px 0 0 rgba(233, 223, 245, 0.95);
|
| 170 |
-
overflow: visible;
|
| 171 |
-
isolation: isolate;
|
| 172 |
-
}
|
| 173 |
-
|
| 174 |
-
section[data-testid="stMain"] {
|
| 175 |
-
overflow-y: visible !important;
|
| 176 |
-
overflow-x: hidden !important;
|
| 177 |
-
}
|
| 178 |
-
|
| 179 |
-
section[data-testid="stMain"]::before {
|
| 180 |
-
content: "";
|
| 181 |
-
position: absolute;
|
| 182 |
-
inset: 0;
|
| 183 |
-
border-radius: inherit;
|
| 184 |
-
pointer-events: none;
|
| 185 |
-
box-shadow: inset 0 1px 0 rgba(255, 255, 255, 0.58);
|
| 186 |
-
z-index: 0;
|
| 187 |
-
}
|
| 188 |
-
|
| 189 |
-
section[data-testid="stMain"]::after {
|
| 190 |
-
display: none;
|
| 191 |
-
}
|
| 192 |
-
|
| 193 |
-
section[data-testid="stMain"] > div,
|
| 194 |
-
[data-testid="stMainBlockContainer"] {
|
| 195 |
-
position: relative;
|
| 196 |
-
z-index: 1;
|
| 197 |
-
background: transparent !important;
|
| 198 |
-
}
|
| 199 |
-
|
| 200 |
-
[data-testid="stHeader"] {
|
| 201 |
-
background: transparent !important;
|
| 202 |
-
}
|
| 203 |
-
|
| 204 |
-
.block-container {
|
| 205 |
-
max-width: 1220px;
|
| 206 |
-
padding-top: 1.3rem;
|
| 207 |
-
padding-bottom: 2.4rem;
|
| 208 |
-
}
|
| 209 |
-
|
| 210 |
-
h1, h2, h3, h4, h5 {
|
| 211 |
-
letter-spacing: -0.02em;
|
| 212 |
-
}
|
| 213 |
-
|
| 214 |
-
p, li, label, [data-testid="stCaptionContainer"] {
|
| 215 |
-
color: var(--pp-muted);
|
| 216 |
-
}
|
| 217 |
-
|
| 218 |
-
a, a:visited {
|
| 219 |
-
color: #1f8f52 !important;
|
| 220 |
-
}
|
| 221 |
-
|
| 222 |
-
section[data-testid="stSidebar"],
|
| 223 |
-
[data-testid="stSidebar"] {
|
| 224 |
-
background: transparent !important;
|
| 225 |
-
border-right: none !important;
|
| 226 |
-
--pp-sidebar-width: 300px;
|
| 227 |
-
height: calc(100vh - 24px) !important;
|
| 228 |
-
min-width: var(--pp-sidebar-width) !important;
|
| 229 |
-
width: var(--pp-sidebar-width) !important;
|
| 230 |
-
max-width: var(--pp-sidebar-width) !important;
|
| 231 |
-
flex: 0 0 var(--pp-sidebar-width) !important;
|
| 232 |
-
flex-basis: var(--pp-sidebar-width) !important;
|
| 233 |
-
}
|
| 234 |
-
|
| 235 |
-
section[data-testid="stSidebar"][aria-expanded="true"] {
|
| 236 |
-
min-width: var(--pp-sidebar-width) !important;
|
| 237 |
-
width: var(--pp-sidebar-width) !important;
|
| 238 |
-
max-width: var(--pp-sidebar-width) !important;
|
| 239 |
-
flex: 0 0 var(--pp-sidebar-width) !important;
|
| 240 |
-
flex-basis: var(--pp-sidebar-width) !important;
|
| 241 |
-
}
|
| 242 |
-
|
| 243 |
-
[data-testid="stSidebar"] > div:first-child {
|
| 244 |
-
margin: 12px 0 12px 12px;
|
| 245 |
-
height: calc(100vh - 24px);
|
| 246 |
-
border-radius: 30px;
|
| 247 |
-
border: 1px solid rgba(255, 255, 255, 0.98);
|
| 248 |
-
background: rgba(241, 226, 248, 0.96) !important;
|
| 249 |
-
box-shadow: var(--pp-panel-shadow);
|
| 250 |
-
backdrop-filter: blur(6px);
|
| 251 |
-
overflow-y: auto;
|
| 252 |
-
overflow-x: hidden;
|
| 253 |
-
}
|
| 254 |
-
|
| 255 |
-
[data-testid="stSidebarUserContent"] {
|
| 256 |
-
padding-top: 0.25rem;
|
| 257 |
-
}
|
| 258 |
-
|
| 259 |
-
[data-testid="stSidebarNav"] ul li,
|
| 260 |
-
[data-testid="stSidebarNav"] ul li a,
|
| 261 |
-
[data-testid="stSidebarNav"] ul li button {
|
| 262 |
-
margin: 0 !important;
|
| 263 |
-
padding: 0 !important;
|
| 264 |
-
}
|
| 265 |
-
|
| 266 |
-
[data-testid="stSidebarNav"] ul li + li {
|
| 267 |
-
margin-top: 0.44rem !important;
|
| 268 |
-
}
|
| 269 |
-
|
| 270 |
-
[data-testid="stSidebarNav"] ul li a,
|
| 271 |
-
[data-testid="stSidebarNav"] ul li button {
|
| 272 |
-
font-size: 1.02rem !important;
|
| 273 |
-
font-family: "Inter", "Manrope", "Avenir Next", "Segoe UI", sans-serif !important;
|
| 274 |
-
font-weight: 600 !important;
|
| 275 |
-
color: #1f6b4a !important;
|
| 276 |
-
border-radius: 12px !important;
|
| 277 |
-
display: flex !important;
|
| 278 |
-
align-items: center !important;
|
| 279 |
-
justify-content: flex-start !important;
|
| 280 |
-
height: 44px !important;
|
| 281 |
-
min-height: 44px !important;
|
| 282 |
-
max-height: 44px !important;
|
| 283 |
-
line-height: 1 !important;
|
| 284 |
-
padding: 0 0.78rem 0 3.1rem !important;
|
| 285 |
-
border: 1px solid rgba(255, 255, 255, 0.98);
|
| 286 |
-
background: rgba(255, 255, 255, 0.78) !important;
|
| 287 |
-
box-shadow:
|
| 288 |
-
0 1px 0 rgba(255, 255, 255, 0.42) inset,
|
| 289 |
-
0 2px 7px rgba(80, 86, 131, 0.07);
|
| 290 |
-
transition: all 140ms ease;
|
| 291 |
-
box-sizing: border-box !important;
|
| 292 |
-
overflow: hidden !important;
|
| 293 |
-
}
|
| 294 |
-
|
| 295 |
-
[data-testid="stSidebarNav"] ul li a > div,
|
| 296 |
-
[data-testid="stSidebarNav"] ul li button > div {
|
| 297 |
-
margin: 0 !important;
|
| 298 |
-
padding: 0 !important;
|
| 299 |
-
display: flex !important;
|
| 300 |
-
align-items: center !important;
|
| 301 |
-
min-height: 0 !important;
|
| 302 |
-
line-height: 1 !important;
|
| 303 |
-
}
|
| 304 |
-
|
| 305 |
-
[data-testid="stSidebarNav"] ul li a span,
|
| 306 |
-
[data-testid="stSidebarNav"] ul li button span {
|
| 307 |
-
font-size: 0.95rem !important;
|
| 308 |
-
font-family: "Inter", "Manrope", "Avenir Next", "Segoe UI", sans-serif !important;
|
| 309 |
-
font-weight: 600 !important;
|
| 310 |
-
color: #1f6b4a !important;
|
| 311 |
-
line-height: 1.05 !important;
|
| 312 |
-
white-space: nowrap !important;
|
| 313 |
-
margin: 0 !important;
|
| 314 |
-
padding: 0 !important;
|
| 315 |
-
}
|
| 316 |
-
|
| 317 |
-
[data-testid="stSidebarNav"] ul li a:hover,
|
| 318 |
-
[data-testid="stSidebarNav"] ul li button:hover {
|
| 319 |
-
transform: translateY(-1px);
|
| 320 |
-
background: rgba(255, 255, 255, 0.90) !important;
|
| 321 |
-
box-shadow:
|
| 322 |
-
0 1px 0 rgba(255, 255, 255, 0.44) inset,
|
| 323 |
-
0 4px 10px rgba(85, 94, 142, 0.10);
|
| 324 |
-
}
|
| 325 |
-
|
| 326 |
-
[data-testid="stSidebarNav"] ul li a[aria-current="page"],
|
| 327 |
-
[data-testid="stSidebarNav"] ul li button[aria-current="page"] {
|
| 328 |
-
border: 1px solid rgba(34, 163, 93, 0.34) !important;
|
| 329 |
-
background: linear-gradient(100deg, #21a35e, #34c67a) !important;
|
| 330 |
-
box-shadow:
|
| 331 |
-
0 1px 0 rgba(255, 255, 255, 0.22) inset,
|
| 332 |
-
0 8px 16px rgba(34, 163, 93, 0.28);
|
| 333 |
-
}
|
| 334 |
-
|
| 335 |
-
[data-testid="stSidebarNav"] ul li a[aria-current="page"],
|
| 336 |
-
[data-testid="stSidebarNav"] ul li a[aria-current="page"] *,
|
| 337 |
-
[data-testid="stSidebarNav"] ul li button[aria-current="page"],
|
| 338 |
-
[data-testid="stSidebarNav"] ul li button[aria-current="page"] * {
|
| 339 |
-
color: #ffffff !important;
|
| 340 |
-
fill: #ffffff !important;
|
| 341 |
-
font-weight: 700 !important;
|
| 342 |
-
}
|
| 343 |
-
|
| 344 |
-
[data-testid="stSidebarNav"] ul li a[aria-current="page"]::before {
|
| 345 |
-
filter: brightness(0) invert(1);
|
| 346 |
-
opacity: 0.96;
|
| 347 |
-
}
|
| 348 |
-
|
| 349 |
-
__ICON_CSS__
|
| 350 |
-
|
| 351 |
-
.stTextInput > div > div > input,
|
| 352 |
-
.stTextArea textarea,
|
| 353 |
-
.stSelectbox [data-baseweb="select"] > div,
|
| 354 |
-
.stMultiSelect [data-baseweb="select"] > div,
|
| 355 |
-
.stNumberInput input {
|
| 356 |
-
border-radius: 12px !important;
|
| 357 |
-
border: 1px solid #d7deeb !important;
|
| 358 |
-
background: rgba(255, 255, 255, 0.87) !important;
|
| 359 |
-
box-shadow: none !important;
|
| 360 |
-
color: #173b2b !important;
|
| 361 |
-
}
|
| 362 |
-
|
| 363 |
-
.stTextInput > div > div > input:focus,
|
| 364 |
-
.stTextArea textarea:focus,
|
| 365 |
-
.stSelectbox [data-baseweb="select"] > div:focus-within,
|
| 366 |
-
.stMultiSelect [data-baseweb="select"] > div:focus-within,
|
| 367 |
-
.stNumberInput input:focus {
|
| 368 |
-
border-color: #21a35e !important;
|
| 369 |
-
box-shadow: 0 0 0 3px rgba(34, 163, 93, 0.18) !important;
|
| 370 |
-
}
|
| 371 |
-
|
| 372 |
-
/* Force dropdown/expanded menus to green-white accents */
|
| 373 |
-
[data-baseweb="popover"] [role="listbox"],
|
| 374 |
-
[data-baseweb="popover"] [data-baseweb="menu"],
|
| 375 |
-
div[role="listbox"] {
|
| 376 |
-
background: rgba(248, 252, 249, 0.98) !important;
|
| 377 |
-
border: 1px solid rgba(185, 214, 198, 0.95) !important;
|
| 378 |
-
border-radius: 12px !important;
|
| 379 |
-
box-shadow: 0 10px 24px rgba(44, 95, 67, 0.14) !important;
|
| 380 |
-
}
|
| 381 |
-
|
| 382 |
-
[data-baseweb="popover"] [role="option"],
|
| 383 |
-
[data-baseweb="popover"] li,
|
| 384 |
-
[data-baseweb="popover"] [data-baseweb="menu"] > div,
|
| 385 |
-
div[role="listbox"] [role="option"] {
|
| 386 |
-
background: transparent !important;
|
| 387 |
-
color: #173b2b !important;
|
| 388 |
-
}
|
| 389 |
-
|
| 390 |
-
[data-baseweb="popover"] [role="option"]:hover,
|
| 391 |
-
[data-baseweb="popover"] li:hover,
|
| 392 |
-
[data-baseweb="popover"] [data-highlighted="true"],
|
| 393 |
-
div[role="listbox"] [role="option"]:hover {
|
| 394 |
-
background: rgba(34, 163, 93, 0.10) !important;
|
| 395 |
-
}
|
| 396 |
-
|
| 397 |
-
[data-baseweb="popover"] [role="option"][aria-selected="true"],
|
| 398 |
-
[data-baseweb="popover"] li[aria-selected="true"],
|
| 399 |
-
[data-baseweb="popover"] [aria-selected="true"],
|
| 400 |
-
div[role="listbox"] [role="option"][aria-selected="true"] {
|
| 401 |
-
background: rgba(34, 163, 93, 0.18) !important;
|
| 402 |
-
color: #173b2b !important;
|
| 403 |
-
}
|
| 404 |
-
|
| 405 |
-
[data-baseweb="tag"] {
|
| 406 |
-
background: #2f9d62 !important;
|
| 407 |
-
border: 1px solid #288653 !important;
|
| 408 |
-
color: #ffffff !important;
|
| 409 |
-
}
|
| 410 |
-
|
| 411 |
-
[data-baseweb="tag"] *,
|
| 412 |
-
[data-baseweb="tag"] svg {
|
| 413 |
-
color: #ffffff !important;
|
| 414 |
-
fill: #ffffff !important;
|
| 415 |
-
}
|
| 416 |
-
|
| 417 |
-
/* Keep sliders/toggles green while page background stays blue */
|
| 418 |
-
.stSlider [data-baseweb="slider"] > div > div > div:first-child,
|
| 419 |
-
[data-baseweb="slider"] > div > div > div:first-child {
|
| 420 |
-
background-color: #1f8f52 !important;
|
| 421 |
-
}
|
| 422 |
-
|
| 423 |
-
.stSlider [data-baseweb="slider"] > div > div > div:last-child,
|
| 424 |
-
[data-baseweb="slider"] > div > div > div:last-child {
|
| 425 |
-
background-color: rgba(34, 163, 93, 0.30) !important;
|
| 426 |
-
}
|
| 427 |
-
|
| 428 |
-
[data-baseweb="slider"] [style*="rgb(79, 70, 229)"],
|
| 429 |
-
[data-baseweb="slider"] [style*="rgb(91, 80, 255)"],
|
| 430 |
-
[data-baseweb="slider"] [style*="rgb(67, 56, 202)"],
|
| 431 |
-
[data-baseweb="slider"] [style*="rgb("] {
|
| 432 |
-
background-color: #1f8f52 !important;
|
| 433 |
-
border-color: #1f8f52 !important;
|
| 434 |
-
}
|
| 435 |
-
|
| 436 |
-
[data-baseweb="slider"] [role="slider"] {
|
| 437 |
-
background-color: #1f8f52 !important;
|
| 438 |
-
border: 2px solid #ffffff !important;
|
| 439 |
-
box-shadow: 0 0 0 1px rgba(34, 163, 93, 0.35), 0 2px 6px rgba(34, 163, 93, 0.28) !important;
|
| 440 |
-
}
|
| 441 |
-
|
| 442 |
-
[data-baseweb="checkbox"] [aria-checked="true"] {
|
| 443 |
-
color: #1f8f52 !important;
|
| 444 |
-
}
|
| 445 |
-
|
| 446 |
-
[data-baseweb="checkbox"] [aria-checked="true"] > div {
|
| 447 |
-
background-color: #1f8f52 !important;
|
| 448 |
-
border-color: #1f8f52 !important;
|
| 449 |
-
}
|
| 450 |
-
|
| 451 |
-
input[type="checkbox"],
|
| 452 |
-
input[type="radio"] {
|
| 453 |
-
accent-color: #1f8f52 !important;
|
| 454 |
-
}
|
| 455 |
-
|
| 456 |
-
[data-baseweb="radio"] [aria-checked="true"] {
|
| 457 |
-
color: #1f8f52 !important;
|
| 458 |
-
}
|
| 459 |
-
|
| 460 |
-
.stButton > button,
|
| 461 |
-
.stDownloadButton > button,
|
| 462 |
-
[data-testid="baseButton-secondary"] {
|
| 463 |
-
border-radius: 999px !important;
|
| 464 |
-
border: 1px solid #d7dff2 !important;
|
| 465 |
-
font-weight: 500 !important;
|
| 466 |
-
min-height: 2.65rem;
|
| 467 |
-
padding: 0.3rem 1.08rem !important;
|
| 468 |
-
background: rgba(255, 255, 255, 0.94) !important;
|
| 469 |
-
transition: all 140ms ease;
|
| 470 |
-
}
|
| 471 |
-
|
| 472 |
-
.stButton > button[kind="primary"],
|
| 473 |
-
[data-testid="baseButton-primary"] {
|
| 474 |
-
background: linear-gradient(100deg, var(--pp-primary), var(--pp-primary-2)) !important;
|
| 475 |
-
color: #fff !important;
|
| 476 |
-
border: none !important;
|
| 477 |
-
box-shadow: 0 10px 22px rgba(31, 157, 85, 0.34);
|
| 478 |
-
}
|
| 479 |
-
|
| 480 |
-
.stButton > button[kind="primary"] *,
|
| 481 |
-
[data-testid="baseButton-primary"] * {
|
| 482 |
-
color: #fff !important;
|
| 483 |
-
fill: #fff !important;
|
| 484 |
-
}
|
| 485 |
-
|
| 486 |
-
[data-testid="stFormSubmitButton"] button,
|
| 487 |
-
[data-testid="stFormSubmitButton"] button * {
|
| 488 |
-
color: #fff !important;
|
| 489 |
-
fill: #fff !important;
|
| 490 |
-
}
|
| 491 |
-
|
| 492 |
-
.stButton > button:hover,
|
| 493 |
-
.stDownloadButton > button:hover {
|
| 494 |
-
transform: translateY(-1px);
|
| 495 |
-
box-shadow: 0 10px 20px rgba(44, 95, 67, 0.2);
|
| 496 |
-
}
|
| 497 |
-
|
| 498 |
-
div[data-testid="stVerticalBlockBorderWrapper"] {
|
| 499 |
-
border-radius: 18px !important;
|
| 500 |
-
border: 1px solid var(--pp-border) !important;
|
| 501 |
-
background: var(--pp-surface) !important;
|
| 502 |
-
box-shadow: var(--pp-shadow);
|
| 503 |
-
}
|
| 504 |
-
|
| 505 |
-
div[data-testid="stMetric"] {
|
| 506 |
-
background: rgba(255, 255, 255, 0.72);
|
| 507 |
-
border-radius: 14px;
|
| 508 |
-
border: 1px solid rgba(255, 255, 255, 0.84);
|
| 509 |
-
padding: 0.45rem 0.7rem;
|
| 510 |
-
}
|
| 511 |
-
|
| 512 |
-
div[data-testid="stMetric"] label {
|
| 513 |
-
color: #4d705d !important;
|
| 514 |
-
font-weight: 600 !important;
|
| 515 |
-
letter-spacing: 0.01em;
|
| 516 |
-
}
|
| 517 |
-
|
| 518 |
-
div[data-testid="stMetricValue"] {
|
| 519 |
-
color: #1f8f52 !important;
|
| 520 |
-
font-weight: 800 !important;
|
| 521 |
-
}
|
| 522 |
-
|
| 523 |
-
[data-testid="stDataFrame"],
|
| 524 |
-
[data-testid="stTable"] {
|
| 525 |
-
background: rgba(255, 255, 255, 0.78);
|
| 526 |
-
border-radius: 14px;
|
| 527 |
-
border: 1px solid rgba(255, 255, 255, 0.88);
|
| 528 |
-
overflow: hidden;
|
| 529 |
-
}
|
| 530 |
-
|
| 531 |
-
[data-testid="stDataFrame"] [role="grid"],
|
| 532 |
-
[data-testid="stDataFrame"] [role="rowgroup"],
|
| 533 |
-
[data-testid="stDataFrame"] [role="row"],
|
| 534 |
-
[data-testid="stDataFrame"] [role="gridcell"],
|
| 535 |
-
[data-testid="stDataFrame"] [role="columnheader"] {
|
| 536 |
-
background-color: rgba(247, 252, 248, 0.90) !important;
|
| 537 |
-
border-color: rgba(188, 213, 196, 0.55) !important;
|
| 538 |
-
}
|
| 539 |
-
|
| 540 |
-
[data-testid="stTable"] table,
|
| 541 |
-
[data-testid="stTable"] th,
|
| 542 |
-
[data-testid="stTable"] td {
|
| 543 |
-
background-color: rgba(248, 252, 249, 0.94) !important;
|
| 544 |
-
border-color: rgba(188, 213, 196, 0.62) !important;
|
| 545 |
-
}
|
| 546 |
-
|
| 547 |
-
.pp-page-header {
|
| 548 |
-
margin: 0.1rem 0 1.0rem 0;
|
| 549 |
-
}
|
| 550 |
-
|
| 551 |
-
.pp-page-title {
|
| 552 |
-
margin: 0.2rem 0 0.45rem 0;
|
| 553 |
-
font-size: clamp(1.95rem, 2.65vw, 3.0rem);
|
| 554 |
-
line-height: 1.1;
|
| 555 |
-
font-weight: 800;
|
| 556 |
-
color: #123726;
|
| 557 |
-
}
|
| 558 |
-
|
| 559 |
-
.pp-page-subtitle {
|
| 560 |
-
margin: 0;
|
| 561 |
-
max-width: 880px;
|
| 562 |
-
color: #4a6e5b;
|
| 563 |
-
font-size: 1.02rem;
|
| 564 |
-
line-height: 1.62;
|
| 565 |
-
}
|
| 566 |
-
|
| 567 |
-
.pp-badge {
|
| 568 |
-
display: inline-flex;
|
| 569 |
-
align-items: center;
|
| 570 |
-
gap: 0.38rem;
|
| 571 |
-
padding: 0.3rem 0.72rem;
|
| 572 |
-
border-radius: 999px;
|
| 573 |
-
border: 1px solid rgba(255, 255, 255, 0.92);
|
| 574 |
-
background: rgba(255, 255, 255, 0.72);
|
| 575 |
-
color: #3f6856;
|
| 576 |
-
font-size: 0.76rem;
|
| 577 |
-
font-weight: 700;
|
| 578 |
-
text-transform: uppercase;
|
| 579 |
-
letter-spacing: 0.04em;
|
| 580 |
-
}
|
| 581 |
-
|
| 582 |
-
.pp-hero {
|
| 583 |
-
border-radius: 22px;
|
| 584 |
-
border: 1px solid rgba(255, 255, 255, 0.84);
|
| 585 |
-
background:
|
| 586 |
-
linear-gradient(95deg, rgba(255, 255, 255, 0.74), rgba(255, 255, 255, 0.60)),
|
| 587 |
-
radial-gradient(800px 300px at 100% 0%, rgba(34, 163, 93, 0.16), transparent 72%);
|
| 588 |
-
box-shadow: 0 16px 32px rgba(56, 70, 121, 0.11);
|
| 589 |
-
padding: 1.45rem 1.5rem;
|
| 590 |
-
margin: 0.3rem 0 1.2rem;
|
| 591 |
-
}
|
| 592 |
-
|
| 593 |
-
.pp-hero-grid {
|
| 594 |
-
display: grid;
|
| 595 |
-
grid-template-columns: minmax(0, 4fr) minmax(130px, 1fr);
|
| 596 |
-
gap: 1.4rem;
|
| 597 |
-
align-items: center;
|
| 598 |
-
}
|
| 599 |
-
|
| 600 |
-
.pp-hero-title {
|
| 601 |
-
margin: 0.58rem 0 0.5rem;
|
| 602 |
-
color: #123726;
|
| 603 |
-
font-size: clamp(1.55rem, 2.12vw, 2.35rem);
|
| 604 |
-
font-weight: 800;
|
| 605 |
-
letter-spacing: -0.014em;
|
| 606 |
-
line-height: 1.2;
|
| 607 |
-
}
|
| 608 |
-
|
| 609 |
-
.pp-hero-copy {
|
| 610 |
-
margin: 0;
|
| 611 |
-
color: #4c6f5d;
|
| 612 |
-
max-width: 780px;
|
| 613 |
-
line-height: 1.66;
|
| 614 |
-
}
|
| 615 |
-
|
| 616 |
-
.pp-hero-logo {
|
| 617 |
-
display: flex;
|
| 618 |
-
justify-content: center;
|
| 619 |
-
}
|
| 620 |
-
|
| 621 |
-
.pp-hero-logo img {
|
| 622 |
-
width: 112px;
|
| 623 |
-
height: 112px;
|
| 624 |
-
border-radius: 18px;
|
| 625 |
-
border: 1px solid rgba(255, 255, 255, 0.88);
|
| 626 |
-
box-shadow: 0 12px 26px rgba(30, 49, 103, 0.2);
|
| 627 |
-
object-fit: contain;
|
| 628 |
-
background: rgba(255, 255, 255, 0.96);
|
| 629 |
-
padding: 8px;
|
| 630 |
-
}
|
| 631 |
-
|
| 632 |
-
.pp-stat-card {
|
| 633 |
-
border-radius: 16px;
|
| 634 |
-
border: 1px solid rgba(255, 255, 255, 0.9);
|
| 635 |
-
background: rgba(255, 255, 255, 0.78);
|
| 636 |
-
padding: 0.85rem 0.92rem;
|
| 637 |
-
box-shadow: 0 8px 22px rgba(56, 78, 145, 0.12);
|
| 638 |
-
}
|
| 639 |
-
|
| 640 |
-
.pp-stat-value {
|
| 641 |
-
margin: 0;
|
| 642 |
-
color: #2f9d62;
|
| 643 |
-
font-size: clamp(1.2rem, 1.8vw, 1.95rem);
|
| 644 |
-
font-weight: 800;
|
| 645 |
-
letter-spacing: -0.018em;
|
| 646 |
-
}
|
| 647 |
-
|
| 648 |
-
.pp-stat-label {
|
| 649 |
-
margin: 0.2rem 0 0;
|
| 650 |
-
color: #678a76;
|
| 651 |
-
font-size: 0.8rem;
|
| 652 |
-
text-transform: uppercase;
|
| 653 |
-
letter-spacing: 0.045em;
|
| 654 |
-
font-weight: 700;
|
| 655 |
-
}
|
| 656 |
-
|
| 657 |
-
.pp-kpi-strip {
|
| 658 |
-
border-radius: 28px;
|
| 659 |
-
border: 1px solid rgba(255, 255, 255, 0.9);
|
| 660 |
-
background: rgba(247, 251, 248, 0.86);
|
| 661 |
-
box-shadow: 0 12px 26px rgba(56, 78, 145, 0.08);
|
| 662 |
-
display: grid;
|
| 663 |
-
grid-template-columns: repeat(3, minmax(0, 1fr));
|
| 664 |
-
gap: 0;
|
| 665 |
-
padding: 1.45rem 1.15rem 1.25rem;
|
| 666 |
-
margin: 0.62rem 0 0.9rem 0;
|
| 667 |
-
}
|
| 668 |
-
|
| 669 |
-
.pp-kpi-item {
|
| 670 |
-
padding: 0.1rem 0.45rem 0.2rem;
|
| 671 |
-
text-align: center;
|
| 672 |
-
}
|
| 673 |
-
|
| 674 |
-
.pp-kpi-strip .pp-kpi-value {
|
| 675 |
-
margin: 0;
|
| 676 |
-
color: #2f9d62 !important;
|
| 677 |
-
font-size: 3.5rem !important;
|
| 678 |
-
font-weight: 800 !important;
|
| 679 |
-
line-height: 0.98 !important;
|
| 680 |
-
letter-spacing: -0.03em !important;
|
| 681 |
-
}
|
| 682 |
-
|
| 683 |
-
.pp-kpi-strip .pp-kpi-label {
|
| 684 |
-
margin: 0.75rem 0 0;
|
| 685 |
-
color: #6f8d7c !important;
|
| 686 |
-
font-size: 0.99rem !important;
|
| 687 |
-
font-weight: 700 !important;
|
| 688 |
-
letter-spacing: 0.12em !important;
|
| 689 |
-
text-transform: uppercase !important;
|
| 690 |
-
}
|
| 691 |
-
|
| 692 |
-
.pp-step-card {
|
| 693 |
-
border-radius: 16px;
|
| 694 |
-
border: 1px solid rgba(255, 255, 255, 0.9);
|
| 695 |
-
background: rgba(255, 255, 255, 0.77);
|
| 696 |
-
padding: 0.8rem 0.92rem;
|
| 697 |
-
min-height: 120px;
|
| 698 |
-
}
|
| 699 |
-
|
| 700 |
-
.pp-step-title {
|
| 701 |
-
margin: 0 0 0.26rem;
|
| 702 |
-
color: #1a4b36;
|
| 703 |
-
font-weight: 700;
|
| 704 |
-
font-size: 0.98rem;
|
| 705 |
-
}
|
| 706 |
-
|
| 707 |
-
.pp-step-copy {
|
| 708 |
-
margin: 0;
|
| 709 |
-
color: #557562;
|
| 710 |
-
font-size: 0.89rem;
|
| 711 |
-
line-height: 1.45;
|
| 712 |
-
}
|
| 713 |
-
|
| 714 |
-
.pp-module-card {
|
| 715 |
-
border-radius: 16px;
|
| 716 |
-
border: 1px solid rgba(255, 255, 255, 0.9);
|
| 717 |
-
background: rgba(255, 255, 255, 0.77);
|
| 718 |
-
box-shadow: 0 8px 22px rgba(56, 78, 145, 0.10);
|
| 719 |
-
padding: 1.02rem 1.05rem;
|
| 720 |
-
margin-bottom: 0.74rem;
|
| 721 |
-
min-height: 136px;
|
| 722 |
-
}
|
| 723 |
-
|
| 724 |
-
.pp-module-title {
|
| 725 |
-
margin: 0 0 0.32rem;
|
| 726 |
-
color: #1a4b36;
|
| 727 |
-
font-size: 1.02rem;
|
| 728 |
-
font-weight: 800;
|
| 729 |
-
line-height: 1.32;
|
| 730 |
-
}
|
| 731 |
-
|
| 732 |
-
.pp-module-copy {
|
| 733 |
-
margin: 0;
|
| 734 |
-
color: #557461;
|
| 735 |
-
font-size: 0.9rem;
|
| 736 |
-
line-height: 1.5;
|
| 737 |
-
}
|
| 738 |
-
|
| 739 |
-
.pp-api-card {
|
| 740 |
-
border-radius: 18px;
|
| 741 |
-
border: 1px solid rgba(255, 255, 255, 0.9);
|
| 742 |
-
background:
|
| 743 |
-
linear-gradient(120deg, rgba(255, 255, 255, 0.84), rgba(246, 251, 248, 0.76));
|
| 744 |
-
box-shadow: 0 10px 24px rgba(56, 78, 145, 0.10);
|
| 745 |
-
padding: 1rem 1.05rem 0.95rem;
|
| 746 |
-
margin: 0 0 0.8rem 0;
|
| 747 |
-
}
|
| 748 |
-
|
| 749 |
-
.pp-api-kicker {
|
| 750 |
-
display: inline-flex;
|
| 751 |
-
align-items: center;
|
| 752 |
-
padding: 0.28rem 0.72rem;
|
| 753 |
-
border-radius: 999px;
|
| 754 |
-
border: 1px solid rgba(188, 213, 196, 0.95);
|
| 755 |
-
background: rgba(255, 255, 255, 0.84);
|
| 756 |
-
color: #4a6e5b;
|
| 757 |
-
text-transform: uppercase;
|
| 758 |
-
letter-spacing: 0.08em;
|
| 759 |
-
font-size: 0.72rem;
|
| 760 |
-
font-weight: 700;
|
| 761 |
-
}
|
| 762 |
-
|
| 763 |
-
.pp-api-title {
|
| 764 |
-
margin: 0.72rem 0 0.26rem;
|
| 765 |
-
color: #123726;
|
| 766 |
-
font-size: clamp(1.2rem, 1.6vw, 1.55rem);
|
| 767 |
-
font-weight: 800;
|
| 768 |
-
letter-spacing: -0.02em;
|
| 769 |
-
}
|
| 770 |
-
|
| 771 |
-
.pp-api-copy {
|
| 772 |
-
margin: 0;
|
| 773 |
-
color: #557461;
|
| 774 |
-
font-size: 0.94rem;
|
| 775 |
-
line-height: 1.58;
|
| 776 |
-
max-width: 900px;
|
| 777 |
-
}
|
| 778 |
-
|
| 779 |
-
.pp-api-meta {
|
| 780 |
-
display: flex;
|
| 781 |
-
flex-wrap: wrap;
|
| 782 |
-
gap: 0.5rem;
|
| 783 |
-
margin: 0.25rem 0 0.15rem;
|
| 784 |
-
}
|
| 785 |
-
|
| 786 |
-
.pp-api-pill {
|
| 787 |
-
display: inline-flex;
|
| 788 |
-
align-items: center;
|
| 789 |
-
gap: 0.32rem;
|
| 790 |
-
padding: 0.34rem 0.72rem;
|
| 791 |
-
border-radius: 999px;
|
| 792 |
-
border: 1px solid rgba(188, 213, 196, 0.95);
|
| 793 |
-
background: rgba(248, 252, 249, 0.95);
|
| 794 |
-
color: #3f6856;
|
| 795 |
-
font-size: 0.8rem;
|
| 796 |
-
font-weight: 600;
|
| 797 |
-
line-height: 1.2;
|
| 798 |
-
}
|
| 799 |
-
|
| 800 |
-
.pp-api-pill strong {
|
| 801 |
-
color: #1f8f52;
|
| 802 |
-
font-weight: 800;
|
| 803 |
-
}
|
| 804 |
-
|
| 805 |
-
.pp-api-inline-head {
|
| 806 |
-
margin: 0.05rem 0 0.7rem 0;
|
| 807 |
-
}
|
| 808 |
-
|
| 809 |
-
.pp-main-card {
|
| 810 |
-
border-radius: 16px;
|
| 811 |
-
border: 1px solid rgba(255, 255, 255, 0.9);
|
| 812 |
-
background: rgba(255, 255, 255, 0.77);
|
| 813 |
-
box-shadow: 0 8px 22px rgba(56, 78, 145, 0.10);
|
| 814 |
-
padding: 1.1rem 1.08rem;
|
| 815 |
-
margin: 3.25rem 0 0.9rem 0;
|
| 816 |
-
}
|
| 817 |
-
|
| 818 |
-
.pp-main-grid {
|
| 819 |
-
display: grid;
|
| 820 |
-
grid-template-columns: minmax(0, 4fr) minmax(120px, 1fr);
|
| 821 |
-
gap: 1.1rem;
|
| 822 |
-
align-items: center;
|
| 823 |
-
}
|
| 824 |
-
|
| 825 |
-
.pp-main-card .pp-main-title {
|
| 826 |
-
margin: 0;
|
| 827 |
-
color: #000000 !important;
|
| 828 |
-
font-size: clamp(2.8rem, 4.2vw, 3.5rem) !important;
|
| 829 |
-
font-weight: 800 !important;
|
| 830 |
-
line-height: 0.98 !important;
|
| 831 |
-
letter-spacing: -0.03em !important;
|
| 832 |
-
}
|
| 833 |
-
|
| 834 |
-
.pp-main-copy {
|
| 835 |
-
margin: 0.5rem 0 0;
|
| 836 |
-
color: #557461;
|
| 837 |
-
font-size: 1.0rem;
|
| 838 |
-
line-height: 1.56;
|
| 839 |
-
max-width: 900px;
|
| 840 |
-
}
|
| 841 |
-
|
| 842 |
-
.pp-main-logo {
|
| 843 |
-
display: flex;
|
| 844 |
-
justify-content: center;
|
| 845 |
-
}
|
| 846 |
-
|
| 847 |
-
.pp-main-logo img {
|
| 848 |
-
width: 118px;
|
| 849 |
-
height: 118px;
|
| 850 |
-
border-radius: 16px;
|
| 851 |
-
border: 1px solid rgba(255, 255, 255, 0.62);
|
| 852 |
-
box-shadow: 0 10px 22px rgba(32, 92, 63, 0.20);
|
| 853 |
-
object-fit: contain;
|
| 854 |
-
background: linear-gradient(145deg, #2f8059, #1f9d55);
|
| 855 |
-
padding: 8px;
|
| 856 |
-
filter: drop-shadow(0 0 5px rgba(255, 255, 255, 0.25));
|
| 857 |
-
}
|
| 858 |
-
|
| 859 |
-
.pp-lab-card {
|
| 860 |
-
margin: 0.35rem 0 0.5rem 0;
|
| 861 |
-
border-radius: 20px;
|
| 862 |
-
border: 1px solid rgba(255, 255, 255, 0.92);
|
| 863 |
-
background:
|
| 864 |
-
linear-gradient(125deg, rgba(255, 255, 255, 0.82), rgba(242, 250, 245, 0.75));
|
| 865 |
-
box-shadow: 0 12px 26px rgba(44, 95, 67, 0.12);
|
| 866 |
-
padding: 1.3rem 1.35rem 1.25rem;
|
| 867 |
-
}
|
| 868 |
-
|
| 869 |
-
.pp-lab-kicker {
|
| 870 |
-
display: inline-flex;
|
| 871 |
-
align-items: center;
|
| 872 |
-
padding: 0.28rem 0.72rem;
|
| 873 |
-
border-radius: 999px;
|
| 874 |
-
background: rgba(255, 255, 255, 0.82);
|
| 875 |
-
border: 1px solid rgba(188, 213, 196, 0.95);
|
| 876 |
-
color: #4a6e5b;
|
| 877 |
-
text-transform: uppercase;
|
| 878 |
-
letter-spacing: 0.08em;
|
| 879 |
-
font-size: 0.74rem;
|
| 880 |
-
font-weight: 700;
|
| 881 |
-
}
|
| 882 |
-
|
| 883 |
-
.pp-lab-title {
|
| 884 |
-
margin: 0.72rem 0 0.26rem;
|
| 885 |
-
font-size: clamp(1.52rem, 2.1vw, 2.05rem);
|
| 886 |
-
font-weight: 800;
|
| 887 |
-
letter-spacing: -0.02em;
|
| 888 |
-
color: #123726;
|
| 889 |
-
}
|
| 890 |
-
|
| 891 |
-
.pp-lab-subtitle {
|
| 892 |
-
margin: 0;
|
| 893 |
-
color: #678a76;
|
| 894 |
-
font-size: 1.0rem;
|
| 895 |
-
font-weight: 600;
|
| 896 |
-
line-height: 1.45;
|
| 897 |
-
}
|
| 898 |
-
|
| 899 |
-
.pp-lab-copy {
|
| 900 |
-
margin: 1rem 0 1.15rem;
|
| 901 |
-
color: #4c6f5d;
|
| 902 |
-
font-size: 1.02rem;
|
| 903 |
-
line-height: 1.68;
|
| 904 |
-
max-width: 1080px;
|
| 905 |
-
}
|
| 906 |
-
|
| 907 |
-
.pp-lab-link {
|
| 908 |
-
display: inline-flex;
|
| 909 |
-
align-items: center;
|
| 910 |
-
justify-content: center;
|
| 911 |
-
text-decoration: none;
|
| 912 |
-
border-radius: 999px;
|
| 913 |
-
border: 1px solid rgba(34, 163, 93, 0.36);
|
| 914 |
-
background: linear-gradient(100deg, #21a35e, #34c67a);
|
| 915 |
-
color: #ffffff !important;
|
| 916 |
-
font-size: 0.95rem;
|
| 917 |
-
font-weight: 700;
|
| 918 |
-
letter-spacing: 0.01em;
|
| 919 |
-
padding: 0.56rem 1rem;
|
| 920 |
-
box-shadow: 0 8px 16px rgba(34, 163, 93, 0.25);
|
| 921 |
-
transition: transform 140ms ease, box-shadow 140ms ease, filter 140ms ease;
|
| 922 |
-
}
|
| 923 |
-
|
| 924 |
-
.pp-lab-link:hover {
|
| 925 |
-
color: #ffffff !important;
|
| 926 |
-
transform: translateY(-1px);
|
| 927 |
-
box-shadow: 0 12px 20px rgba(34, 163, 93, 0.3);
|
| 928 |
-
filter: saturate(1.02);
|
| 929 |
-
}
|
| 930 |
-
|
| 931 |
-
@media (max-width: 980px) {
|
| 932 |
-
[data-testid="stAppViewContainer"] {
|
| 933 |
-
height: auto;
|
| 934 |
-
overflow: visible !important;
|
| 935 |
-
}
|
| 936 |
-
|
| 937 |
-
section[data-testid="stMain"] {
|
| 938 |
-
margin: 8px;
|
| 939 |
-
height: auto !important;
|
| 940 |
-
border-radius: 16px;
|
| 941 |
-
}
|
| 942 |
-
|
| 943 |
-
[data-testid="stSidebar"] > div:first-child {
|
| 944 |
-
margin: 8px;
|
| 945 |
-
height: auto;
|
| 946 |
-
border-radius: 16px;
|
| 947 |
-
}
|
| 948 |
-
|
| 949 |
-
.pp-main-grid {
|
| 950 |
-
grid-template-columns: 1fr;
|
| 951 |
-
gap: 0.75rem;
|
| 952 |
-
}
|
| 953 |
-
|
| 954 |
-
.pp-kpi-strip {
|
| 955 |
-
grid-template-columns: 1fr;
|
| 956 |
-
gap: 0.42rem;
|
| 957 |
-
padding: 0.95rem 1rem 0.85rem;
|
| 958 |
-
}
|
| 959 |
-
|
| 960 |
-
.pp-kpi-item {
|
| 961 |
-
padding: 0.18rem 0.45rem;
|
| 962 |
-
}
|
| 963 |
-
|
| 964 |
-
.pp-kpi-strip .pp-kpi-value {
|
| 965 |
-
font-size: 3.05rem !important;
|
| 966 |
-
}
|
| 967 |
-
|
| 968 |
-
.pp-kpi-strip .pp-kpi-label {
|
| 969 |
-
font-size: 0.7rem !important;
|
| 970 |
-
}
|
| 971 |
-
|
| 972 |
-
.pp-main-logo {
|
| 973 |
-
justify-content: flex-start;
|
| 974 |
-
}
|
| 975 |
-
|
| 976 |
-
.pp-main-card {
|
| 977 |
-
margin-top: 0.7rem;
|
| 978 |
-
}
|
| 979 |
-
|
| 980 |
-
.pp-hero-grid {
|
| 981 |
-
grid-template-columns: 1fr;
|
| 982 |
-
gap: 0.85rem;
|
| 983 |
-
}
|
| 984 |
-
.pp-hero-logo {
|
| 985 |
-
justify-content: flex-start;
|
| 986 |
-
}
|
| 987 |
-
|
| 988 |
-
.pp-lab-card {
|
| 989 |
-
padding: 1.0rem;
|
| 990 |
-
}
|
| 991 |
-
|
| 992 |
-
.pp-lab-title {
|
| 993 |
-
margin-top: 0.58rem;
|
| 994 |
-
}
|
| 995 |
-
|
| 996 |
-
.pp-lab-copy {
|
| 997 |
-
margin-top: 0.78rem;
|
| 998 |
-
font-size: 0.96rem;
|
| 999 |
-
}
|
| 1000 |
-
}
|
| 1001 |
-
</style>
|
| 1002 |
-
"""
|
| 1003 |
-
st.markdown(css.replace("__ICON_CSS__", icon_css), unsafe_allow_html=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
src/utils.py
DELETED
|
@@ -1,338 +0,0 @@
|
|
| 1 |
-
# utils.py
|
| 2 |
-
from __future__ import annotations
|
| 3 |
-
|
| 4 |
-
from typing import Dict, List, Optional, Sequence, Literal
|
| 5 |
-
|
| 6 |
-
import math
|
| 7 |
-
import numpy as np
|
| 8 |
-
import torch
|
| 9 |
-
import torch.nn as nn
|
| 10 |
-
|
| 11 |
-
# Re-exported conveniences from data_builder
|
| 12 |
-
from src.data_builder import TargetScaler, grouped_split_by_smiles # noqa: F401
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
# ---------------------------------------------------------
|
| 16 |
-
# Seeding and device helpers
|
| 17 |
-
# ---------------------------------------------------------
|
| 18 |
-
|
| 19 |
-
def seed_everything(seed: int) -> None:
|
| 20 |
-
"""Deterministically seed Python, NumPy, and PyTorch (CPU/CUDA)."""
|
| 21 |
-
import random
|
| 22 |
-
random.seed(seed)
|
| 23 |
-
np.random.seed(seed)
|
| 24 |
-
torch.manual_seed(seed)
|
| 25 |
-
torch.cuda.manual_seed_all(seed)
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
def to_device(batch, device: torch.device):
|
| 29 |
-
"""Move a PyG Batch or simple dict of tensors to device."""
|
| 30 |
-
if hasattr(batch, "to"):
|
| 31 |
-
return batch.to(device)
|
| 32 |
-
if isinstance(batch, dict):
|
| 33 |
-
return {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()}
|
| 34 |
-
return batch
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
# ---------------------------------------------------------
|
| 38 |
-
# Masked metrics (canonical)
|
| 39 |
-
# ---------------------------------------------------------
|
| 40 |
-
|
| 41 |
-
def _safe_div(num: torch.Tensor, den: torch.Tensor) -> torch.Tensor:
|
| 42 |
-
den = torch.clamp(den, min=1e-12)
|
| 43 |
-
return num / den
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
def masked_mse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
|
| 47 |
-
reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor:
|
| 48 |
-
"""
|
| 49 |
-
pred/target: [B, T]; mask: [B, T] bool
|
| 50 |
-
"""
|
| 51 |
-
pred, target = pred.float(), target.float()
|
| 52 |
-
mask = mask.bool()
|
| 53 |
-
se = ((pred - target) ** 2) * mask
|
| 54 |
-
if reduction == "sum":
|
| 55 |
-
return se.sum()
|
| 56 |
-
return _safe_div(se.sum(), mask.sum().float())
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
def masked_mae(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor,
|
| 60 |
-
reduction: Literal["mean", "sum"] = "mean") -> torch.Tensor:
|
| 61 |
-
ae = (pred - target).abs() * mask
|
| 62 |
-
if reduction == "sum":
|
| 63 |
-
return ae.sum()
|
| 64 |
-
return _safe_div(ae.sum(), mask.sum().float())
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
def masked_rmse(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 68 |
-
return torch.sqrt(masked_mse(pred, target, mask, reduction="mean"))
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
def masked_r2(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
| 72 |
-
"""
|
| 73 |
-
Masked coefficient of determination across all elements jointly.
|
| 74 |
-
"""
|
| 75 |
-
pred, target = pred.float(), target.float()
|
| 76 |
-
mask = mask.bool()
|
| 77 |
-
count = mask.sum().float().clamp(min=1.0)
|
| 78 |
-
mean = _safe_div((target * mask).sum(), count)
|
| 79 |
-
sst = (((target - mean) ** 2) * mask).sum()
|
| 80 |
-
sse = (((target - pred) ** 2) * mask).sum()
|
| 81 |
-
return 1.0 - _safe_div(sse, sst.clamp(min=1e-12))
|
| 82 |
-
|
| 83 |
-
|
| 84 |
-
def masked_metrics_overall(pred: torch.Tensor, target: torch.Tensor, mask: torch.Tensor) -> Dict[str, float]:
|
| 85 |
-
return {
|
| 86 |
-
"rmse": float(masked_rmse(pred, target, mask).detach().cpu()),
|
| 87 |
-
"mae": float(masked_mae(pred, target, mask).detach().cpu()),
|
| 88 |
-
"r2": float(masked_r2(pred, target, mask).detach().cpu()),
|
| 89 |
-
}
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
def masked_metrics_per_task(
|
| 93 |
-
pred: torch.Tensor,
|
| 94 |
-
target: torch.Tensor,
|
| 95 |
-
mask: torch.Tensor,
|
| 96 |
-
task_names: Sequence[str],
|
| 97 |
-
) -> Dict[str, Dict[str, float]]:
|
| 98 |
-
"""
|
| 99 |
-
Per-task metrics using the same masked formulations.
|
| 100 |
-
"""
|
| 101 |
-
out: Dict[str, Dict[str, float]] = {}
|
| 102 |
-
for t, name in enumerate(task_names):
|
| 103 |
-
m = mask[:, t]
|
| 104 |
-
if m.any():
|
| 105 |
-
rmse = float(masked_rmse(pred[:, t:t+1], target[:, t:t+1], m.unsqueeze(1)).detach().cpu())
|
| 106 |
-
mae = float(masked_mae(pred[:, t:t+1], target[:, t:t+1], m.unsqueeze(1)).detach().cpu())
|
| 107 |
-
r2 = float(masked_r2(pred[:, t:t+1], target[:, t:t+1], m.unsqueeze(1)).detach().cpu())
|
| 108 |
-
else:
|
| 109 |
-
rmse = mae = r2 = float("nan")
|
| 110 |
-
out[name] = {"rmse": rmse, "mae": mae, "r2": r2}
|
| 111 |
-
return out
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
def masked_metrics_by_fidelity(
|
| 115 |
-
pred: torch.Tensor,
|
| 116 |
-
target: torch.Tensor,
|
| 117 |
-
mask: torch.Tensor,
|
| 118 |
-
fid_idx: torch.Tensor,
|
| 119 |
-
fid_names: Sequence[str],
|
| 120 |
-
task_names: Sequence[str], # kept for API parity; not used in overall-by-fid
|
| 121 |
-
) -> Dict[str, Dict[str, float]]:
|
| 122 |
-
"""
|
| 123 |
-
Overall metrics per fidelity (aggregated across tasks).
|
| 124 |
-
"""
|
| 125 |
-
out: Dict[str, Dict[str, float]] = {}
|
| 126 |
-
fid_idx = fid_idx.view(-1).long()
|
| 127 |
-
for i, fname in enumerate(fid_names):
|
| 128 |
-
sel = (fid_idx == i)
|
| 129 |
-
if sel.any():
|
| 130 |
-
p = pred[sel]
|
| 131 |
-
y = target[sel]
|
| 132 |
-
m = mask[sel]
|
| 133 |
-
out[fname] = masked_metrics_overall(p, y, m)
|
| 134 |
-
else:
|
| 135 |
-
out[fname] = {"rmse": float("nan"), "mae": float("nan"), "r2": float("nan")}
|
| 136 |
-
return out
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
# ---------------------------------------------------------
|
| 140 |
-
# Multitask, multi-fidelity loss (canonical)
|
| 141 |
-
# ---------------------------------------------------------
|
| 142 |
-
|
| 143 |
-
def gaussian_nll(mu: torch.Tensor, logvar: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
| 144 |
-
"""
|
| 145 |
-
Element-wise Gaussian NLL (no reduction).
|
| 146 |
-
Shapes: mu, logvar, target -> [B, T] (or broadcastable).
|
| 147 |
-
"""
|
| 148 |
-
logvar = torch.as_tensor(logvar, device=mu.device, dtype=mu.dtype)
|
| 149 |
-
logvar = logvar.clamp(min=-20.0, max=20.0) # numerical guard
|
| 150 |
-
var = torch.exp(logvar)
|
| 151 |
-
err2_over_var = (target - mu) ** 2 / var
|
| 152 |
-
nll = 0.5 * (err2_over_var + logvar + math.log(2.0 * math.pi)) # [B, T]
|
| 153 |
-
return nll
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
def loss_multitask_fidelity(
|
| 157 |
-
*,
|
| 158 |
-
pred: torch.Tensor, # [B, T] (or means if heteroscedastic)
|
| 159 |
-
target: torch.Tensor, # [B, T]
|
| 160 |
-
mask: torch.Tensor, # [B, T] bool
|
| 161 |
-
fid_idx: torch.Tensor, # [B] long (per-row fidelity index)
|
| 162 |
-
fid_loss_w: Sequence[float] | torch.Tensor | None, # [F] weights per fidelity
|
| 163 |
-
task_weights: Optional[Sequence[float] | torch.Tensor] = None, # [T]
|
| 164 |
-
hetero_logvar: Optional[torch.Tensor] = None, # [B, T] if heteroscedastic head
|
| 165 |
-
reduction: Literal["mean", "sum"] = "mean",
|
| 166 |
-
task_log_sigma2: Optional[torch.Tensor] = None, # [T] learned homoscedastic uncertainty
|
| 167 |
-
balanced: bool = True,
|
| 168 |
-
) -> torch.Tensor:
|
| 169 |
-
"""
|
| 170 |
-
Multi-task, multi-fidelity loss with *balanced per-task reduction* by default.
|
| 171 |
-
|
| 172 |
-
- If `hetero_logvar` is given: uses Gaussian NLL per element.
|
| 173 |
-
- Applies per-fidelity weights via `fid_idx`.
|
| 174 |
-
- Balanced reduction: compute mean loss per task first, then average across tasks
|
| 175 |
-
(optionally weight by `task_weights` or learned uncertainty `task_log_sigma2`).
|
| 176 |
-
- If `balanced=False`, uses legacy global reduction.
|
| 177 |
-
"""
|
| 178 |
-
B, T = pred.shape
|
| 179 |
-
pred = pred.float()
|
| 180 |
-
target = target.float()
|
| 181 |
-
mask = mask.bool()
|
| 182 |
-
fid_idx = fid_idx.view(-1).long()
|
| 183 |
-
|
| 184 |
-
# Task weights (optional)
|
| 185 |
-
if task_weights is None:
|
| 186 |
-
tw = pred.new_ones(T) # [T]
|
| 187 |
-
else:
|
| 188 |
-
tw = torch.as_tensor(task_weights, dtype=pred.dtype, device=pred.device)
|
| 189 |
-
assert tw.numel() == T, f"task_weights len {tw.numel()} != T {T}"
|
| 190 |
-
s = tw.sum().clamp(min=1e-12)
|
| 191 |
-
tw = tw * (T / s) # normalize to sum=T for stable scale
|
| 192 |
-
|
| 193 |
-
# Fidelity weights
|
| 194 |
-
if fid_loss_w is None:
|
| 195 |
-
fw = pred.new_ones(int(fid_idx.max().item()) + 1)
|
| 196 |
-
else:
|
| 197 |
-
fw = torch.as_tensor(fid_loss_w, dtype=pred.dtype, device=pred.device)
|
| 198 |
-
w_fid = fw[fid_idx].unsqueeze(1).expand(-1, T) # [B, T]
|
| 199 |
-
|
| 200 |
-
# Elementwise loss
|
| 201 |
-
if hetero_logvar is not None:
|
| 202 |
-
elem_loss = gaussian_nll(pred, hetero_logvar.float(), target) # [B, T]
|
| 203 |
-
else:
|
| 204 |
-
elem_loss = (pred - target) ** 2 # [B, T]
|
| 205 |
-
|
| 206 |
-
if not balanced:
|
| 207 |
-
# Legacy global reduction (label-count biased)
|
| 208 |
-
w_task = tw.view(1, T).expand(B, -1)
|
| 209 |
-
weighted = elem_loss * mask * w_task * w_fid
|
| 210 |
-
if reduction == "sum":
|
| 211 |
-
return weighted.sum()
|
| 212 |
-
denom = (mask * w_task * w_fid).sum().float().clamp(min=1e-12)
|
| 213 |
-
return weighted.sum() / denom
|
| 214 |
-
|
| 215 |
-
# -------- Balanced per-task reduction --------
|
| 216 |
-
# First compute a per-task average (exclude tw here)
|
| 217 |
-
num = (elem_loss * mask * w_fid).sum(dim=0) # [T]
|
| 218 |
-
den = (mask * w_fid).sum(dim=0).float().clamp(min=1e-12) # [T]
|
| 219 |
-
per_task_loss = num / den # [T]
|
| 220 |
-
|
| 221 |
-
# Optional manual task weights AFTER per-task averaging
|
| 222 |
-
if task_weights is not None:
|
| 223 |
-
per_task_loss = per_task_loss * tw
|
| 224 |
-
|
| 225 |
-
# Optional homoscedastic task-uncertainty weighting (Kendall & Gal)
|
| 226 |
-
if task_log_sigma2 is not None:
|
| 227 |
-
assert task_log_sigma2.numel() == T, f"task_log_sigma2 must be [T], got {task_log_sigma2.shape}"
|
| 228 |
-
sigma2 = torch.exp(task_log_sigma2) # [T]
|
| 229 |
-
per_task_loss = per_task_loss / (2.0 * sigma2) + 0.5 * torch.log(sigma2)
|
| 230 |
-
|
| 231 |
-
if reduction == "sum":
|
| 232 |
-
return per_task_loss.sum()
|
| 233 |
-
return per_task_loss.mean()
|
| 234 |
-
|
| 235 |
-
|
| 236 |
-
# ---------------------------------------------------------
|
| 237 |
-
# Curriculum scheduler for EXP fidelity
|
| 238 |
-
# ---------------------------------------------------------
|
| 239 |
-
|
| 240 |
-
def exp_weight_at_epoch(
|
| 241 |
-
epoch: int,
|
| 242 |
-
total_epochs: int,
|
| 243 |
-
schedule: Literal["none", "linear", "cosine"] = "none",
|
| 244 |
-
start: float = 0.6,
|
| 245 |
-
end: float = 1.0,
|
| 246 |
-
) -> float:
|
| 247 |
-
"""
|
| 248 |
-
Returns the EXP loss weight for a given epoch under the chosen schedule.
|
| 249 |
-
"""
|
| 250 |
-
if schedule == "none":
|
| 251 |
-
return float(end)
|
| 252 |
-
epoch = max(0, min(epoch, total_epochs))
|
| 253 |
-
if total_epochs <= 0:
|
| 254 |
-
return float(end)
|
| 255 |
-
t = epoch / float(total_epochs)
|
| 256 |
-
if schedule == "linear":
|
| 257 |
-
return float(start + (end - start) * t)
|
| 258 |
-
if schedule == "cosine":
|
| 259 |
-
cos_t = 0.5 - 0.5 * math.cos(math.pi * t) # 0->1 smoothly
|
| 260 |
-
return float(start + (end - start) * cos_t)
|
| 261 |
-
raise ValueError(f"Unknown schedule: {schedule}")
|
| 262 |
-
|
| 263 |
-
|
| 264 |
-
def make_fid_loss_weights(
|
| 265 |
-
fids: Sequence[str],
|
| 266 |
-
base_weights: Optional[Sequence[float]] = None,
|
| 267 |
-
exp_weight: Optional[float] = None,
|
| 268 |
-
) -> List[float]:
|
| 269 |
-
"""
|
| 270 |
-
Builds a per-fidelity weight vector aligned with dataset.fids order.
|
| 271 |
-
If exp_weight is provided, it overrides the weight for the 'exp' fidelity.
|
| 272 |
-
If base_weights is provided, it must match len(fids) and is used as a template.
|
| 273 |
-
"""
|
| 274 |
-
fids_lc = [f.lower() for f in fids]
|
| 275 |
-
F = len(fids_lc)
|
| 276 |
-
if base_weights is None:
|
| 277 |
-
w = [1.0] * F
|
| 278 |
-
else:
|
| 279 |
-
assert len(base_weights) == F, f"base_weights len {len(base_weights)} != {F}"
|
| 280 |
-
w = [float(x) for x in base_weights]
|
| 281 |
-
if exp_weight is not None and "exp" in fids_lc:
|
| 282 |
-
idx = fids_lc.index("exp")
|
| 283 |
-
w[idx] = float(exp_weight)
|
| 284 |
-
return w
|
| 285 |
-
|
| 286 |
-
|
| 287 |
-
# ---------------------------------------------------------
|
| 288 |
-
# Inference utilities
|
| 289 |
-
# ---------------------------------------------------------
|
| 290 |
-
|
| 291 |
-
def apply_inverse_transform(pred: torch.Tensor, scaler):
    """
    Undo target scaling for ``pred``, keeping everything on ``pred``'s device.

    Any scaler tensors present (``mean``, ``std``, and an optional non-None
    ``eps``) are moved in place onto the prediction's device before calling
    ``scaler.inverse``, so this works transparently for CPU/GPU tensors and
    for legacy scaler objects that lack some attributes.
    """
    device = pred.device

    # Relocate the scaler's buffers onto the prediction's device if needed.
    for attr in ("mean", "std"):
        if hasattr(scaler, attr):
            buf = getattr(scaler, attr)
            if buf.device != device:
                setattr(scaler, attr, buf.to(device))
    eps = getattr(scaler, "eps", None)
    if eps is not None and eps.device != device:
        scaler.eps = eps.to(device)

    return scaler.inverse(pred)
|
| 307 |
-
|
| 308 |
-
|
| 309 |
-
|
| 310 |
-
def ensure_2d(x: torch.Tensor) -> torch.Tensor:
    """Guarantee a [B, T] tensor: promote a 1-D input to a column, pass 2-D through."""
    # Only rank-1 inputs need reshaping; everything else is returned unchanged.
    return x.unsqueeze(1) if x.dim() == 1 else x
|
| 315 |
-
|
| 316 |
-
|
| 317 |
-
# ---------------------------------------------------------
|
| 318 |
-
# Simple test harness (optional)
|
| 319 |
-
# ---------------------------------------------------------
|
| 320 |
-
|
| 321 |
-
if __name__ == "__main__":
    # Lightweight smoke test: run the masked multi-fidelity loss with and
    # without task weights, then report overall masked metrics.
    torch.manual_seed(0)
    n_rows, n_cols = 5, 3
    predictions = torch.randn(n_rows, n_cols)
    targets = torch.randn(n_rows, n_cols)
    observed = torch.rand(n_rows, n_cols) > 0.3
    fidelity_idx = torch.randint(0, 4, (n_rows,))
    fidelity_weights = [1.0, 0.8, 0.6, 0.5]

    # Shared keyword arguments for both loss invocations.
    common = dict(
        pred=predictions,
        target=targets,
        mask=observed,
        fid_idx=fidelity_idx,
        fid_loss_w=fidelity_weights,
    )
    weighted = loss_multitask_fidelity(task_weights=[1.0, 2.0, 1.0], **common)
    unweighted = loss_multitask_fidelity(task_weights=None, **common)
    print("Loss with task weights:", float(weighted))
    print("Loss without task weights:", float(unweighted))

    print("Overall metrics:", masked_metrics_overall(predictions, targets, observed))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|