File size: 26,843 Bytes
930ea3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
# src/discovery.py
from __future__ import annotations

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Dict, List, Optional, Tuple

import numpy as np
import pandas as pd
from rdkit import Chem, DataStructs
from rdkit.Chem import AllChem
from . import sascorer

# Reuse your canonicalizer if you want; otherwise keep local
def canonicalize_smiles(smiles: str) -> Optional[str]:
    s = (smiles or "").strip()
    if not s:
        return None
    m = Chem.MolFromSmiles(s)
    if m is None:
        return None
    return Chem.MolToSmiles(m, canonical=True)


# -------------------------
# Spec schema (minimal v0)
# -------------------------
@dataclass
class DiscoverySpec:
    dataset: List[str]  # ["PI1M_PROPERTY.parquet", "POLYINFO_PROPERTY.parquet"]
    polyinfo: str  # "POLYINFO_PROPERTY.parquet"
    polyinfo_csv: str  # "POLYINFO.csv"

    hard_constraints: Dict[str, Dict[str, float]]  # { "tg": {"min": 400}, "tc": {"max": 0.3} }
    objectives: List[Dict[str, str]]  # [{"property":"cp","goal":"maximize"}, ...]

    max_pool: int = 200000         # legacy (kept for compatibility; aligned to pareto_max)
    pareto_max: int = 50000        # cap points used for Pareto + diversity fingerprinting
    max_candidates: int = 30       # final output size
    max_pareto_fronts: int = 5     # how many Pareto layers to keep for candidate pool
    min_distance: float = 0.30     # diversity threshold in Tanimoto distance
    fingerprint: str = "morgan"    # morgan only for now
    random_seed: int = 7
    use_canonical_smiles: bool = True
    use_full_data: bool = False
    trust_weights: Dict[str, float] | None = None
    selection_weights: Dict[str, float] | None = None


# -------------------------
# Column mapping
# -------------------------
def mean_col(prop_key: str) -> str:
    return f"mean_{prop_key.lower()}"

def std_col(prop_key: str) -> str:
    return f"std_{prop_key.lower()}"


def normalize_weights(weights: Dict[str, float], defaults: Dict[str, float]) -> Dict[str, float]:
    out: Dict[str, float] = {}
    for k, v in defaults.items():
        try:
            vv = float(weights.get(k, v))
        except Exception:
            vv = float(v)
        out[k] = max(0.0, vv)
    s = float(sum(out.values()))
    if s <= 0.0:
        return defaults.copy()
    return {k: float(v / s) for k, v in out.items()}

def spec_from_dict(obj: dict, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
    pareto_max = int(obj.get("pareto_max", 50000))
    return DiscoverySpec(
        dataset=list(dataset_path),
        polyinfo=polyinfo_path,
        polyinfo_csv=polyinfo_csv_path,
        hard_constraints=obj.get("hard_constraints", {}),
        objectives=obj.get("objectives", []),
        # Legacy field kept for compatibility; effectively collapsed to pareto_max.
        max_pool=pareto_max,
        pareto_max=pareto_max,
        max_candidates=int(obj.get("max_candidates", 30)),
        max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)),
        min_distance=float(obj.get("min_distance", 0.30)),
        fingerprint=str(obj.get("fingerprint", "morgan")),
        random_seed=int(obj.get("random_seed", 7)),
        use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)),
        use_full_data=bool(obj.get("use_full_data", False)),
        trust_weights=obj.get("trust_weights"),
        selection_weights=obj.get("selection_weights"),
    )

# -------------------------
# Parquet loading (safe)
# -------------------------
def load_parquet_columns(path: str | List[str], columns: List[str]) -> pd.DataFrame:
    """
    Load only requested columns from Parquet (critical for 1M rows).
    Accepts a single path or a list of paths and concatenates rows.
    """
    def _load_one(fp: str, req_cols: List[str]) -> pd.DataFrame:
        available: list[str]
        try:
            import pyarrow.parquet as pq

            pf = pq.ParquetFile(fp)
            available = [str(c) for c in pf.schema.names]
        except Exception:
            # If schema probing fails, fall back to direct read with requested columns.
            return pd.read_parquet(fp, columns=req_cols)

        available_set = set(available)
        lower_to_actual = {c.lower(): c for c in available}

        # Resolve requested names against actual parquet schema.
        resolved: dict[str, str] = {}
        for req in req_cols:
            if req in available_set:
                resolved[req] = req
                continue
            alt = lower_to_actual.get(str(req).lower())
            if alt is not None:
                resolved[req] = alt

        use_cols = sorted(set(resolved.values()))
        if not use_cols:
            return pd.DataFrame(columns=req_cols)

        out = pd.read_parquet(fp, columns=use_cols)
        for req in req_cols:
            src = resolved.get(req)
            if src is None:
                out[req] = np.nan
            elif src != req:
                out[req] = out[src]
        return out[req_cols]

    if isinstance(path, (list, tuple)):
        frames = [_load_one(p, columns) for p in path]
        if not frames:
            return pd.DataFrame(columns=columns)
        return pd.concat(frames, ignore_index=True)
    return _load_one(path, columns)


def normalize_smiles(smiles: str, use_canonical_smiles: bool) -> Optional[str]:
    s = (smiles or "").strip()
    if not s:
        return None
    if not use_canonical_smiles:
        # Skip RDKit parsing entirely in fast mode.
        return s
    m = Chem.MolFromSmiles(s)
    if m is None:
        return None
    if use_canonical_smiles:
        return Chem.MolToSmiles(m, canonical=True)
    return s


def load_polyinfo_index(polyinfo_csv_path: str, use_canonical_smiles: bool = True) -> pd.DataFrame:
    """
    Expected CSV columns: SMILES, Polymer_Class, polymer_name (or common variants).
    Returns dataframe with index on smiles_key and columns polymer_name/polymer_class.
    """
    df = pd.read_csv(polyinfo_csv_path)

    # normalize column names
    cols = {c: c for c in df.columns}
    # map typical names
    if "SMILES" in cols:
        df = df.rename(columns={"SMILES": "smiles"})
    elif "smiles" not in df.columns:
        raise ValueError(f"{polyinfo_csv_path} missing SMILES/smiles column")

    if "Polymer_Name" in df.columns:
        df = df.rename(columns={"Polymer_Name": "polymer_name"})
    if "polymer_Name" in df.columns:
        df = df.rename(columns={"polymer_Name": "polymer_name"})
    if "Polymer_Class" in df.columns:
        df = df.rename(columns={"Polymer_Class": "polymer_class"})

    if "polymer_name" not in df.columns:
        df["polymer_name"] = pd.NA
    if "polymer_class" not in df.columns:
        df["polymer_class"] = pd.NA

    df["smiles_key"] = df["smiles"].astype(str).map(lambda s: normalize_smiles(s, use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).drop_duplicates("smiles_key")
    df = df.set_index("smiles_key", drop=True)
    return df[["polymer_name", "polymer_class"]]


# -------------------------
# Pareto (2–3 objectives)
# -------------------------
def pareto_front_mask(X: np.ndarray) -> np.ndarray:
    """
    Returns mask for nondominated points.
    X: (N, M), all objectives assumed to be minimized.
    For maximize objectives, we invert before calling this.
    """
    N = X.shape[0]
    is_efficient = np.ones(N, dtype=bool)
    for i in range(N):
        if not is_efficient[i]:
            continue
        # any point that is <= in all dims and < in at least one dominates
        dominates = np.all(X <= X[i], axis=1) & np.any(X < X[i], axis=1)
        # if a point dominates i, mark i inefficient
        if np.any(dominates):
            is_efficient[i] = False
            continue
        # otherwise, i may dominate others
        dominated_by_i = np.all(X[i] <= X, axis=1) & np.any(X[i] < X, axis=1)
        is_efficient[dominated_by_i] = False
        is_efficient[i] = True
    return is_efficient


def pareto_layers(X: np.ndarray, max_layers: int = 10) -> np.ndarray:
    """
    Returns layer index per point: 1 = Pareto front, 2 = second layer, ...
    Unassigned points beyond max_layers get 0.
    """
    N = X.shape[0]
    layers = np.zeros(N, dtype=int)
    remaining = np.arange(N)

    layer = 1
    while remaining.size > 0 and layer <= max_layers:
        mask = pareto_front_mask(X[remaining])
        front_idx = remaining[mask]
        layers[front_idx] = layer
        remaining = remaining[~mask]
        layer += 1
    return layers


def pareto_front_mask_chunked(
    X: np.ndarray,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int], None]] = None,
) -> np.ndarray:
    """
    Exact global Pareto front mask via chunk-local front reduction + global reconcile.
    This is exact for front-1:
      1) compute exact local front within each chunk
      2) union local fronts
      3) compute exact front on the union
    """
    N = X.shape[0]
    if N <= chunk_size:
        if progress_callback is not None:
            progress_callback(1, 1)
        return pareto_front_mask(X)

    local_front_idx = []
    total_chunks = (N + chunk_size - 1) // chunk_size
    done_chunks = 0
    for start in range(0, N, chunk_size):
        end = min(start + chunk_size, N)
        idx = np.arange(start, end)
        mask_local = pareto_front_mask(X[idx])
        local_front_idx.append(idx[mask_local])
        done_chunks += 1
        if progress_callback is not None:
            progress_callback(done_chunks, total_chunks)

    if not local_front_idx:
        return np.zeros(N, dtype=bool)

    reduced_idx = np.concatenate(local_front_idx)
    reduced_mask = pareto_front_mask(X[reduced_idx])
    front_idx = reduced_idx[reduced_mask]

    out = np.zeros(N, dtype=bool)
    out[front_idx] = True
    return out


def pareto_layers_chunked(
    X: np.ndarray,
    max_layers: int = 10,
    chunk_size: int = 100000,
    progress_callback: Optional[Callable[[int, int, int], None]] = None,
) -> np.ndarray:
    """
    Exact Pareto layers using repeated exact chunked front extraction.
    """
    N = X.shape[0]
    layers = np.zeros(N, dtype=int)
    remaining = np.arange(N)
    layer = 1

    while remaining.size > 0 and layer <= max_layers:
        def on_chunk(done: int, total: int) -> None:
            if progress_callback is not None:
                progress_callback(layer, done, total)

        mask = pareto_front_mask_chunked(X[remaining], chunk_size=chunk_size, progress_callback=on_chunk)
        front_idx = remaining[mask]
        layers[front_idx] = layer
        remaining = remaining[~mask]
        layer += 1

    return layers


# -------------------------
# Fingerprints & diversity
# -------------------------
def morgan_fp(smiles: str, radius: int = 2, nbits: int = 2048):
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        return None
    return AllChem.GetMorganFingerprintAsBitVect(m, radius, nBits=nbits)

def tanimoto_distance(fp1, fp2) -> float:
    return 1.0 - DataStructs.TanimotoSimilarity(fp1, fp2)

def greedy_diverse_select(
    smiles_list: List[str],
    scores: np.ndarray,
    max_k: int,
    min_dist: float,
) -> List[int]:
    """
    Greedy selection by descending score, enforcing min Tanimoto distance.
    Returns indices into smiles_list.
    """
    fps = []
    valid_idx = []
    for i, s in enumerate(smiles_list):
        fp = morgan_fp(s)
        if fp is not None:
            fps.append(fp)
            valid_idx.append(i)

    if not valid_idx:
        return []

    # rank candidates (higher score first)
    order = np.argsort(-scores[valid_idx])
    selected_global = []
    selected_fps = []

    for oi in order:
        i = valid_idx[oi]
        fp_i = fps[oi]  # aligned with valid_idx
        ok = True
        for fp_j in selected_fps:
            if tanimoto_distance(fp_i, fp_j) < min_dist:
                ok = False
                break
        if ok:
            selected_global.append(i)
            selected_fps.append(fp_i)
        if len(selected_global) >= max_k:
            break

    return selected_global


# -------------------------
# Trust score (lightweight, robust)
# -------------------------
def internal_consistency_penalty(row: pd.Series) -> float:
    """
    Very simple physics/validity checks. Penalty in [0,1].
    Adjust/add rules later.
    """
    viol = 0
    total = 0

    def chk(cond: bool):
        nonlocal viol, total
        total += 1
        if not cond:
            viol += 1

    # positivity checks if present
    for p in ["cp", "tc", "rho", "dif", "visc", "tg", "tm", "bandgap"]:
        c = mean_col(p)
        if c in row.index and pd.notna(row[c]):
            if p in ["bandgap", "tg", "tm"]:
                chk(float(row[c]) >= 0.0)
            else:
                chk(float(row[c]) > 0.0)

    # Poisson ratio bounds if present
    if mean_col("poisson") in row.index and pd.notna(row[mean_col("poisson")]):
        v = float(row[mean_col("poisson")])
        chk(0.0 <= v <= 0.5)

    # Tg <= Tm if both present
    if mean_col("tg") in row.index and mean_col("tm") in row.index:
        if pd.notna(row[mean_col("tg")]) and pd.notna(row[mean_col("tm")]):
            chk(float(row[mean_col("tg")]) <= float(row[mean_col("tm")]))

    if total == 0:
        return 0.0
    return viol / total


def synthesizability_score(smiles: str) -> float:
    """
    RDKit SA-score based synthesizability proxy in [0,1].
    SA-score is ~[1 (easy), 10 (hard)].
    We map: 1 -> 1.0, 10 -> 0.0
    """
    m = Chem.MolFromSmiles(smiles)
    if m is None:
        return 0.0

    # Guard against unexpected scorer failures / None for edge-case molecules.
    try:
        sa_raw = sascorer.calculateScore(m)
    except Exception:
        return 0.0
    if sa_raw is None:
        return 0.0

    sa = float(sa_raw)  # ~ 1..10
    s_syn = 1.0 - (sa - 1.0) / 9.0          # linear map to [0,1]
    return float(np.clip(s_syn, 0.0, 1.0))


def compute_trust_scores(
    df: pd.DataFrame,
    real_fps: List,
    real_smiles: List[str],
    trust_weights: Dict[str, float] | None = None,
) -> np.ndarray:
    """
    Trust score in [0,1] (higher = more trustworthy / lower risk).
    Components:
      - distance to nearest real polymer (fingerprint distance)
      - internal consistency penalty
      - uncertainty penalty (if std columns exist)
      - synthesizability
    """
    N = len(df)
    trust = np.zeros(N, dtype=float)
    tw_defaults = {"real": 0.45, "consistency": 0.25, "uncertainty": 0.10, "synth": 0.20}
    tw = normalize_weights(trust_weights or {}, tw_defaults)

    # nearest-real distance (expensive if done naively)
    # We do it only for the (small) post-filter set, which is safe.
    smiles_col = "smiles_key" if "smiles_key" in df.columns else "smiles_canon"
    for i in range(N):
        s = df.iloc[i][smiles_col]
        fp = morgan_fp(s)
        if fp is None or not real_fps:
            d_real = 1.0
        else:
            sims = DataStructs.BulkTanimotoSimilarity(fp, real_fps)
            d_real = 1.0 - float(max(sims))  # distance to nearest

        # internal consistency
        pen_cons = internal_consistency_penalty(df.iloc[i])

        # uncertainty: average normalized std for any std_* columns present
        std_cols = [c for c in df.columns if c.startswith("std_")]
        if std_cols:
            std_vals = df.iloc[i][std_cols].astype(float)
            std_vals = std_vals.replace([np.inf, -np.inf], np.nan).dropna()
            pen_unc = float(np.clip(std_vals.mean() / (std_vals.mean() + 1.0), 0.0, 1.0)) if len(std_vals) else 0.0
        else:
            pen_unc = 0.0

        # synthesizability heuristic
        s_syn = synthesizability_score(s)

        # Combine (tunable weights)
        # lower distance to real is better -> convert to score
        s_real = 1.0 - np.clip(d_real, 0.0, 1.0)

        trust[i] = (
            tw["real"] * s_real +
            tw["consistency"] * (1.0 - pen_cons) +
            tw["uncertainty"] * (1.0 - pen_unc) +
            tw["synth"] * s_syn
        )

    trust = np.clip(trust, 0.0, 1.0)
    return trust


# -------------------------
# Main pipeline
# -------------------------
def run_discovery(
    spec: DiscoverySpec,
    progress_callback: Optional[Callable[[str, float], None]] = None,
) -> Tuple[pd.DataFrame, Dict[str, float], pd.DataFrame]:
    def report(step: str, pct: float) -> None:
        if progress_callback is not None:
            progress_callback(step, pct)

    rng = np.random.default_rng(spec.random_seed)

    # 1) Determine required columns
    report("Preparing columns…", 0.02)
    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]

    needed_props = sorted(set(obj_props + cons_props))
    cols = ["SMILES"] + [mean_col(p) for p in needed_props]

    # include std columns if available (not required, but used for trust)
    std_cols = [std_col(p) for p in needed_props]
    cols += std_cols

    # 2) Load only needed columns
    report("Loading data from parquet…", 0.05)
    df = load_parquet_columns(spec.dataset, columns=[c for c in cols if c != "SMILES"] + ["SMILES"])
    # normalize
    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})
    normalize_step = "Canonicalizing SMILES…" if spec.use_canonical_smiles else "Skipping SMILES normalization…"
    report(normalize_step, 0.10)
    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # 3) Hard constraints
    report("Applying constraints…", 0.22)
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            # if missing, nothing can satisfy
            df = df.iloc[0:0]
            break
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    n_after = len(df)
    if n_after == 0:
        empty_stats = {"n_total": 0, "n_after_constraints": 0, "n_pool": 0, "n_pareto_pool": 0, "n_selected": 0}
        return df, empty_stats, pd.DataFrame()

    n_pool = len(df)

    # 5) Prepare objective matrix for Pareto
    report("Building objective matrix…", 0.30)
    # convert to minimization: maximize => negate
    X = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        if c not in df.columns:
            raise ValueError(f"Objective column missing: {c}")
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
    X = np.stack(X, axis=1)  # (N, M)

    # Pareto cap before computing layers (optional safety)
    if spec.use_full_data:
        report("Using full dataset (no Pareto cap)…", 0.35)
    elif len(df) > spec.pareto_max:
        idx = rng.choice(len(df), size=spec.pareto_max, replace=False)
        df = df.iloc[idx].reset_index(drop=True)
        X = X[idx]

    # 6) Pareto layers (only 5 layers needed for candidate pool)
    report("Computing Pareto layers…", 0.40)
    pareto_start = 0.40
    pareto_end = 0.54
    max_layers_for_pool = max(1, int(spec.max_pareto_fronts))
    pareto_chunk_ref = {"chunks_per_layer": None}

    def on_pareto_chunk(layer_i: int, done_chunks: int, total_chunks: int) -> None:
        if pareto_chunk_ref["chunks_per_layer"] is None:
            pareto_chunk_ref["chunks_per_layer"] = max(1, int(total_chunks))
        ref_chunks = pareto_chunk_ref["chunks_per_layer"]
        total_units = max_layers_for_pool * ref_chunks
        done_units = min(total_units, ((layer_i - 1) * ref_chunks) + done_chunks)
        pareto_pct = int(round(100.0 * done_units / max(1, total_units)))

        layer_progress = done_chunks / max(1, total_chunks)
        overall = ((layer_i - 1) + layer_progress) / max_layers_for_pool
        pct = pareto_start + (pareto_end - pareto_start) * min(1.0, max(0.0, overall))
        report(
            f"Computing Pareto layers… {pareto_pct}% (Layer {layer_i}/{max_layers_for_pool}, chunk {done_chunks}/{total_chunks})",
            pct,
        )

    layers = pareto_layers_chunked(
        X,
        max_layers=max_layers_for_pool,
        chunk_size=100000,
        progress_callback=on_pareto_chunk,
    )
    report("Computing Pareto layers…", pareto_end)
    df["pareto_layer"] = layers
    plot_df = df[["smiles_key"] + [mean_col(p) for p in obj_props] + ["pareto_layer"]].copy()
    plot_df = plot_df.rename(columns={"smiles_key": "SMILES"})

    # Keep first few layers as candidate pool (avoid huge set)
    cand = df[df["pareto_layer"].between(1, max_layers_for_pool)].copy()
    if cand.empty:
        cand = df[df["pareto_layer"] == 1].copy()
    cand = cand.reset_index(drop=True)
    n_pareto = len(cand)

    # 7) Load real polymer metadata and fingerprints (from POLYINFO.csv)
    report("Loading POLYINFO index…", 0.55)
    polyinfo = load_polyinfo_index(spec.polyinfo_csv, use_canonical_smiles=spec.use_canonical_smiles)
    real_smiles = polyinfo.index.to_list()

    report("Building real-polymer fingerprints…", 0.60)
    real_fps = []
    for s in real_smiles:
        fp = morgan_fp(s)
        if fp is not None:
            real_fps.append(fp)

    # 8) Trust score on candidate pool (safe size)
    report("Computing trust scores…", 0.70)
    trust = compute_trust_scores(
        cand,
        real_fps=real_fps,
        real_smiles=real_smiles,
        trust_weights=spec.trust_weights,
    )
    cand["trust_score"] = trust

    # 9) Diversity selection on candidate pool
    report("Diversity selection…", 0.88)
    # score for selection: prioritize Pareto layer 1 then trust
    # higher is better
    sw_defaults = {"pareto": 0.60, "trust": 0.40}
    sw = normalize_weights(spec.selection_weights or {}, sw_defaults)
    pareto_bonus = (
        (max_layers_for_pool + 1) - np.clip(cand["pareto_layer"].to_numpy(dtype=int), 1, max_layers_for_pool)
    ) / float(max_layers_for_pool)
    sel_score = sw["pareto"] * pareto_bonus + sw["trust"] * cand["trust_score"].to_numpy(dtype=float)

    chosen_idx = greedy_diverse_select(
        smiles_list=cand["smiles_key"].tolist(),
        scores=sel_score,
        max_k=spec.max_candidates,
        min_dist=spec.min_distance,
    )
    out = cand.iloc[chosen_idx].copy().reset_index(drop=True)

    # 10) Attach Polymer_Name/Class if available (only for matches)
    report("Finalizing results…", 0.96)
    out = out.set_index("smiles_key", drop=False)
    out = out.join(polyinfo, how="left")
    out = out.reset_index(drop=True)

    # 11) Make a clean output bundle with requested columns
    # Keep SMILES (canonical), name/class, pareto layer, trust score, properties used
    keep = ["smiles_key", "polymer_name", "polymer_class", "pareto_layer", "trust_score"]
    for p in needed_props:
        mc = mean_col(p)
        sc = std_col(p)
        if mc in out.columns:
            keep.append(mc)
        if sc in out.columns:
            keep.append(sc)

    out = out[keep].rename(columns={"smiles_key": "SMILES"})

    stats = {
        "n_total": float(len(df)),
        "n_after_constraints": float(n_after),
        "n_pool": float(n_pool),
        "n_pareto_pool": float(n_pareto),
        "n_selected": float(len(out)),
    }
    report("Done.", 1.0)
    return out, stats, plot_df


def build_pareto_plot_df(spec: DiscoverySpec, max_plot_points: int = 30000) -> pd.DataFrame:
    """
    Returns a small dataframe for plotting (sampled), with objective columns and pareto_layer.
    Does NOT compute trust/diversity. Safe for live plotting.
    """
    rng = np.random.default_rng(spec.random_seed)

    obj_props = [o["property"].lower() for o in spec.objectives]
    cons_props = [p.lower() for p in spec.hard_constraints.keys()]
    needed_props = sorted(set(obj_props + cons_props))

    cols = ["SMILES"] + [mean_col(p) for p in needed_props]
    df = load_parquet_columns(spec.dataset, columns=cols)

    if "SMILES" not in df.columns and "smiles" in df.columns:
        df = df.rename(columns={"smiles": "SMILES"})

    df["smiles_key"] = df["SMILES"].astype(str).map(lambda s: normalize_smiles(s, spec.use_canonical_smiles))
    df = df.dropna(subset=["smiles_key"]).reset_index(drop=True)

    # Hard constraints
    for p, rule in spec.hard_constraints.items():
        p = p.lower()
        c = mean_col(p)
        if c not in df.columns:
            return df.iloc[0:0]
        if "min" in rule:
            df = df[df[c] >= float(rule["min"])]
        if "max" in rule:
            df = df[df[c] <= float(rule["max"])]

    if len(df) == 0:
        return df

    # Pareto cap for plotting
    plot_cap = min(int(max_plot_points), int(spec.pareto_max))
    if len(df) > plot_cap:
        idx = rng.choice(len(df), size=plot_cap, replace=False)
        df = df.iloc[idx].reset_index(drop=True)

    # Build objective matrix (minimization)
    X = []
    for o in spec.objectives:
        prop = o["property"].lower()
        goal = o["goal"].lower()
        c = mean_col(prop)
        v = df[c].to_numpy(dtype=float)
        if goal == "maximize":
            v = -v
        X.append(v)
    X = np.stack(X, axis=1)

    df["pareto_layer"] = pareto_layers(X, max_layers=5)

    # Return only what plotting needs
    keep = ["smiles_key", "pareto_layer"] + [mean_col(p) for p in obj_props]
    out = df[keep].rename(columns={"smiles_key": "SMILES"})
    return out


def parse_spec(text: str, dataset_path: List[str], polyinfo_path: str, polyinfo_csv_path: str) -> DiscoverySpec:
    obj = json.loads(text)
    pareto_max = int(obj.get("pareto_max", 50000))

    return DiscoverySpec(
        dataset=list(dataset_path),
        polyinfo=polyinfo_path,
        polyinfo_csv=polyinfo_csv_path,
        hard_constraints=obj.get("hard_constraints", {}),
        objectives=obj.get("objectives", []),
        max_pool=pareto_max,
        pareto_max=pareto_max,
        max_candidates=int(obj.get("max_candidates", 30)),
        max_pareto_fronts=int(obj.get("max_pareto_fronts", 5)),
        min_distance=float(obj.get("min_distance", 0.30)),
        fingerprint=str(obj.get("fingerprint", "morgan")),
        random_seed=int(obj.get("random_seed", 7)),
        use_canonical_smiles=not bool(obj.get("skip_smiles_canonicalization", True)),
        use_full_data=bool(obj.get("use_full_data", False)),
        trust_weights=obj.get("trust_weights"),
        selection_weights=obj.get("selection_weights"),
    )