File size: 6,694 Bytes
e42c9a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a09246c
 
e42c9a6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
"""

╔══════════════════════════════════════════════════════════════════════╗

β•‘  TRIADS V13A β€” SOTA Evaluation Script                               β•‘

β•‘  Reproduces the 91.20 MPa MAE on matbench_steels                    β•‘

β•‘                                                                      β•‘

β•‘  Usage:                                                              β•‘

β•‘    python evaluate.py                                                β•‘

β•‘    python evaluate.py --checkpoint path/to/triads_v13a_ensemble.pt   β•‘

β•‘                                                                      β•‘

β•‘  This script will:                                                   β•‘

β•‘    1. Download the checkpoint from HuggingFace (if not provided)     β•‘

β•‘    2. Load the matbench_steels dataset (312 samples)                 β•‘

β•‘    3. Compute expanded features (Magpie + Mat2Vec + Matminer)        β•‘

β•‘    4. Run official 5-fold nested CV with the 5-seed ensemble         β•‘

β•‘    5. Report per-fold and overall MAE                                β•‘

β•‘                                                                      β•‘

β•‘  Expected output: ~91.20 MPa MAE                                    β•‘

β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•

"""

import os
import argparse
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader
from sklearn.model_selection import KFold

from model_arch import DeepHybridTRM, ExpandedFeaturizer, DSData


def _ensemble_fold_mae(seed_preds, targets, fold_idx):
    """Average the per-seed predictions of one fold and return their MAE.

    Args:
        seed_preds: List of prediction tensors, one per ensemble member, in
            the same sample order as ``targets``.
        targets: Tensor of ground-truth values for the fold.
        fold_idx: Zero-based fold index; used only in the error message.

    Returns:
        Mean absolute error of the seed-averaged prediction, as a float.

    Raises:
        RuntimeError: If ``seed_preds`` is empty, i.e. there is nothing to
            average for this fold.
    """
    if not seed_preds:
        raise RuntimeError(
            f"No ensemble members available for fold {fold_idx + 1}; "
            "the checkpoint is missing every seed for this fold.")
    ensemble_pred = torch.stack(seed_preds).mean(dim=0)
    # Flatten both sides so an (N, 1) model output cannot silently broadcast
    # against (N,) targets inside l1_loss and yield an (N, N) error matrix.
    return F.l1_loss(ensemble_pred.reshape(-1), targets.reshape(-1)).item()


def evaluate(checkpoint_path=None):
    """Reproduce the TRIADS V13A matbench_steels MAE from a saved ensemble.

    Loads the ensemble checkpoint (downloading it from the HuggingFace Hub
    repo ``Rtx09/TRIADS`` when no path is given) and featurizes the
    matbench_steels compositions. It then rebuilds the official 5-fold split
    (``KFold(shuffle=True, random_state=18012019)``, the same split used for
    training). For each fold, it averages the predictions of every
    ``seed{seed}_fold{k}`` model stored in the checkpoint.

    Args:
        checkpoint_path: Path to ``triads_v13a_ensemble.pt``. If ``None``,
            the file is fetched with ``huggingface_hub.hf_hub_download``.

    Returns:
        The mean of the per-fold MAEs, in MPa.

    Raises:
        KeyError: If the checkpoint lacks one of the expected top-level keys
            (``seeds``, ``n_folds``, ``ensemble_weights``, ``model_name``).
        RuntimeError: If the checkpoint holds no model for some fold.
    """
    # ── 1. Load checkpoint ────────────────────────────────────────────
    if checkpoint_path is None:
        from huggingface_hub import hf_hub_download
        print("Downloading checkpoint from HuggingFace...")
        checkpoint_path = hf_hub_download(
            repo_id="Rtx09/TRIADS",
            filename="triads_v13a_ensemble.pt"
        )

    print(f"Loading checkpoint: {checkpoint_path}")
    # NOTE(review): torch.load unpickles the file. For checkpoints from an
    # untrusted source prefer torch.load(..., weights_only=True) (the default
    # since PyTorch 2.6) so that loading cannot execute arbitrary code.
    ckpt = torch.load(checkpoint_path, map_location="cpu")
    seeds = ckpt["seeds"]
    n_folds = ckpt["n_folds"]
    ensemble_weights = ckpt["ensemble_weights"]
    print(f"  {ckpt['model_name']} — {len(ensemble_weights)} models "
          f"({len(seeds)} seeds × {n_folds} folds)")

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"  Device: {device}")

    # ── 2. Load matbench_steels ───────────────────────────────────────
    print("\nLoading matbench_steels dataset...")
    from matminer.datasets import load_dataset
    from pymatgen.core import Composition

    df = load_dataset("matbench_steels")
    comps_raw = df["composition"].tolist()
    targets_all = np.array(df["yield strength"].tolist(), np.float32)
    comps_all = [Composition(c) for c in comps_raw]
    print(f"  {len(comps_all)} samples loaded")

    # ── 3. Featurize ──────────────────────────────────────────────────
    print("\nComputing expanded features (Magpie + Mat2Vec + Matminer)...")
    feat = ExpandedFeaturizer()
    X_all = feat.featurize_all(comps_all)
    print(f"  Feature shape: {X_all.shape}")

    # ── 4. Official 5-fold CV (same split as training) ────────────────
    kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
    folds = list(kfold.split(comps_all))

    print(f"\n{'═'*60}")
    print(f"  Running 5-Fold Evaluation with {len(seeds)}-Seed Ensemble")
    print(f"{'═'*60}\n")

    fold_maes = []

    for fi, (tv_i, te_i) in enumerate(folds):
        # Fit the scaler on this fold's training rows only, so no test-set
        # statistics leak into the features.
        feat.fit_scaler(X_all[tv_i])
        te_s = feat.transform(X_all[te_i])

        te_dl = DataLoader(
            DSData(te_s, targets_all[te_i]),
            batch_size=32, shuffle=False, num_workers=0
        )
        te_tgt = torch.tensor(targets_all[te_i], dtype=torch.float32)

        # Collect predictions from every seed trained on this fold.
        seed_preds = []
        for seed in seeds:
            key = f"seed{seed}_fold{fi+1}"
            if key not in ensemble_weights:
                print(f"  WARNING: Missing {key}, skipping")
                continue

            # Force n_extra=200 to match pool.0.weight shape [96, 264] (64+200=264)
            model = DeepHybridTRM(n_extra=200).to(device)
            model.load_state_dict(ensemble_weights[key])
            model.eval()

            with torch.no_grad():
                preds = [model(bx.to(device)).cpu() for bx, _ in te_dl]
            seed_preds.append(torch.cat(preds))

            # Free the model before building the next one to keep peak
            # GPU memory at a single ensemble member.
            del model
            if device.type == "cuda":
                torch.cuda.empty_cache()

        # Ensemble: average across seeds, then score against the targets.
        fold_mae = _ensemble_fold_mae(seed_preds, te_tgt, fi)
        fold_maes.append(fold_mae)

        print(f"  Fold {fi+1}/5: MAE = {fold_mae:.2f} MPa  "
              f"({len(seed_preds)} seeds)")

    # ── 5. Report ─────────────────────────────────────────────────────
    avg_mae = np.mean(fold_maes)
    std_mae = np.std(fold_maes)  # population std over the folds (ddof=0)

    print(f"\n{'═'*60}")
    print("  TRIADS V13A — Final Result")
    print(f"  {'─'*40}")
    print(f"  Per-fold MAE: {[f'{m:.2f}' for m in fold_maes]}")
    print(f"  Average MAE:  {avg_mae:.2f} ± {std_mae:.2f} MPa")
    print(f"  Best fold:    {min(fold_maes):.2f} MPa")
    print(f"{'═'*60}")

    return avg_mae


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Reproduce TRIADS V13A SOTA result (91.20 MPa)")
    parser.add_argument(
        "--checkpoint", type=str, default=None,
        help="Path to triads_v13a_ensemble.pt "
             "(downloads from HuggingFace if not provided)")
    args = parser.parse_args()
    evaluate(args.checkpoint)