Rtx09 committed on
Commit
8a82d34
·
0 Parent(s):

TRIADS — 6-benchmark weights + model code + Gradio app


Benchmarks:
- matbench_steels: 91.20 MPa (HybridTRIADS V13A, 225K, 5-fold 5-seed avg)
- matbench_expt_gap: 0.3068 eV (HybridTRIADS V3, 100K)
- matbench_expt_ismetal: 0.9655 AUC (HybridTRIADS, 44K, best comp-only)
- matbench_glass: 0.9285 AUC (HybridTRIADS, 44K, 5-seed)
- matbench_jdft2d: 35.89 meV (HybridTRIADS V4, 75K, 5-fold 5-seed avg)
- matbench_phonons: 41.91 cm-1 (GraphTRIADS V6, 247K, gate-halt)

.gitattributes ADDED
@@ -0,0 +1,2 @@
*.pt filter=lfs diff=lfs merge=lfs -text
weights/** filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,160 @@
---
license: mit
language: en
tags:
- materials-science
- machine-learning
- pytorch
- matbench
- small-data
- attention
- recursive
- crystal
- gradio
datasets:
- matbench
metrics:
- mae
- roc_auc
model-index:
- name: TRIADS
  results:
  - task:
      type: regression
      name: Yield Strength Prediction (MPa)
    dataset:
      name: matbench_steels
      type: matbench
    metrics:
    - type: mae
      value: 91.20
      name: MAE (MPa)
  - task:
      type: regression
      name: Band Gap Prediction (eV)
    dataset:
      name: matbench_expt_gap
      type: matbench
    metrics:
    - type: mae
      value: 0.3068
      name: MAE (eV)
  - task:
      type: classification
      name: Metallicity Classification
    dataset:
      name: matbench_expt_ismetal
      type: matbench
    metrics:
    - type: roc_auc
      value: 0.9655
      name: ROC-AUC
  - task:
      type: classification
      name: Glass Forming Ability
    dataset:
      name: matbench_glass
      type: matbench
    metrics:
    - type: roc_auc
      value: 0.9285
      name: ROC-AUC
  - task:
      type: regression
      name: Exfoliation Energy (meV/atom)
    dataset:
      name: matbench_jdft2d
      type: matbench
    metrics:
    - type: mae
      value: 35.89
      name: MAE (meV/atom)
  - task:
      type: regression
      name: Peak Phonon Frequency (cm⁻¹)
    dataset:
      name: matbench_phonons
      type: matbench
    metrics:
    - type: mae
      value: 41.91
      name: MAE (cm⁻¹)
---

# TRIADS — Materials Property Prediction Across 6 Matbench Benchmarks

**TRIADS (Tiny Recursive Information-Attention with Deep Supervision)** is a parameter-efficient recursive architecture for materials property prediction, purpose-built for the **small-data regime** (312–5,680 samples).

[![GitHub](https://img.shields.io/badge/GitHub-Code-black?logo=github)](https://github.com/Rtx09x/TRIADS)
[![Paper](https://img.shields.io/badge/Paper-PDF-red)](https://github.com/Rtx09x/TRIADS/raw/main/TRIADS_Final.pdf)

## Live Demo

Try the interactive demo with all 6 benchmarks → **[Launch App](https://huggingface.co/spaces/Rtx09/TRIADS)**

## Results Summary

| Task | N | TRIADS | Params | Rank |
|---|---|---|---|---|
| `matbench_steels` (yield strength) | 312 | **91.20 MPa** | 225K | #3 |
| `matbench_expt_gap` (band gap) | 4,604 | **0.3068 eV** | 100K | #2 composition-only |
| `matbench_expt_ismetal` (metal?) | 4,921 | **0.9655 ROC-AUC** | 100K | **#1** composition-only |
| `matbench_glass` (glass forming) | 5,680 | **0.9285 ROC-AUC** | 44K | #2 |
| `matbench_jdft2d` (exfol. energy) | 636 | **35.89 meV/atom** | 75K | **#1** no-pretraining |
| `matbench_phonons` (phonon freq.) | 1,265 | **41.91 cm⁻¹** | 247K | **#1** no-pretraining |

## Two Model Variants

### HybridTRIADS (composition tasks: steels, gap, ismetal, glass, jdft2d)

- **Input:** chemical formula → Magpie + Mat2Vec (composition tokens)
- **Core:** 2-layer self-attention cell, iterated T = 16–20 times with shared weights
- **Training:** per-cycle deep supervision (w_t ∝ t)

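The per-cycle deep supervision scheme can be sketched in a few lines. This is a minimal illustration, not the repo's training code; the function name `deep_supervision_loss` and the use of squared error are assumptions for the sketch:

```python
import numpy as np

# Sketch of per-cycle deep supervision with w_t proportional to t:
# every recursion step's prediction is supervised, and later steps
# carry proportionally more weight in the total loss.
def deep_supervision_loss(step_preds, target):
    T = len(step_preds)
    w = np.arange(1, T + 1, dtype=float)   # w_t ∝ t
    w /= w.sum()                           # normalize weights to sum to 1
    per_step = [(p - target) ** 2 for p in step_preds]  # per-cycle squared error
    return float(sum(wt * l for wt, l in zip(w, per_step)))
```

With this weighting, early cycles are still trained but the final cycles, whose output is actually returned at inference, dominate the loss.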
### GraphTRIADS (structural task: phonons)

- **Input:** CIF/structure → 3-order hierarchical crystal graph (atoms, bonds, triplet angles, dihedral angles)
- **Core:** hierarchical GNN message-passing stack inside the shared recursive cell
- **Halting:** gate-based adaptive halting (4–16 cycles per sample)

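Gate-based adaptive halting can be sketched as follows. This is a hedged illustration of the control flow only; `run_with_gate_halt`, `step_fn`, and `gate_fn` are hypothetical names, not the repo's GraphTRIADS implementation:

```python
# Sketch of gate-based adaptive halting: iterate a shared-weight cell and
# stop once a learned gate's halting score crosses a threshold, with the
# cycle count clamped to the 4-16 range described above.
def run_with_gate_halt(step_fn, gate_fn, state, min_cycles=4, max_cycles=16,
                       threshold=0.5):
    for t in range(1, max_cycles + 1):
        state = step_fn(state)                       # one recursive cycle
        if t >= min_cycles and gate_fn(state) > threshold:
            break                                    # gate says: halt here
    return state, t
```

Easy samples halt near the minimum cycle count, while harder samples use the full budget; this is what makes the per-sample compute adaptive.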
## Pretrained Checkpoints

Weights are organized by benchmark. Download via `huggingface_hub`:

```python
from huggingface_hub import hf_hub_download
import torch

# Download one benchmark's weights (contains all folds compiled)
ckpt = torch.load(
    hf_hub_download("Rtx09/TRIADS", "steels/weights.pt"),
    map_location="cpu",
)
# ckpt['folds']   -> list of fold dicts, each with 'model_state' and 'test_mae'
# ckpt['n_extra'] -> int (needed for model init)
# ckpt['config']  -> dict (d_attn, d_hidden, ff_dim, dropout, max_steps)
```

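Since each `weights.pt` stores a per-fold `test_mae`, you can, for example, select the strongest fold instead of ensembling. A tiny illustrative helper (not part of the repo) that relies only on the checkpoint layout documented above:

```python
# Illustrative helper: pick the fold with the lowest held-out MAE from
# the ckpt['folds'] layout (keys as documented for weights.pt).
def best_fold(ckpt):
    return min(ckpt["folds"], key=lambda f: f["test_mae"])
```

Note that for reporting, Matbench scores average over all five folds; single-fold selection is only useful for quick experiments.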
### Checkpoint Index

| Benchmark | File | Folds | Notes |
|---|---|---|---|
| matbench_steels | `steels/weights.pt` | 5 | HybridTRIADS V13A · 225K · 5-seed ensemble compiled |
| matbench_expt_gap | `expt_gap/weights.pt` | 5 | HybridTRIADS V3 · 100K |
| matbench_expt_ismetal | `is_metal/weights.pt` | 5 | HybridTRIADS · 100K |
| matbench_glass | `glass/weights.pt` | 5 | HybridTRIADS · 44K |
| matbench_jdft2d | `jdft2d/weights.pt` | 5 | HybridTRIADS V4 · 75K · 5-seed ensemble compiled |
| matbench_phonons | `phonons/weights.pt` | 5 | GraphTRIADS V6 · 247K · also needs `phonons/dataset.pt` |

## Citation

```bibtex
@misc{tiwari2026triads,
  author = {Rudra Tiwari},
  title  = {TRIADS: Tiny Recursive Information-Attention with Deep Supervision},
  year   = {2026},
  url    = {https://github.com/Rtx09x/TRIADS}
}
```

## License

MIT License — see [GitHub repository](https://github.com/Rtx09x/TRIADS/blob/main/LICENSE).
app.py ADDED
@@ -0,0 +1,658 @@
"""
TRIADS — Multi-Benchmark Materials Property Prediction
HuggingFace Gradio App

Covers all 6 Matbench benchmarks:
1. matbench_steels — Yield Strength (MPa)
2. matbench_expt_gap — Band Gap (eV)
3. matbench_expt_ismetal — Metallicity (ROC-AUC)
4. matbench_glass — Glass Forming Ability
5. matbench_jdft2d — Exfoliation Energy (meV/atom)
6. matbench_phonons — Peak Phonon Frequency (cm⁻¹)
"""

import os
import warnings
import urllib.request

warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download

# ─────────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────────

REPO_ID = "Rtx09/TRIADS"

BENCHMARK_INFO = {
    "steels": {
        "title": "🔩 Steel Yield Strength",
        "description": "Predict yield strength (MPa) of steel alloys from composition.",
        "unit": "MPa",
        "example": "Fe0.7Cr0.15Ni0.15",
        "examples": ["Fe0.7Cr0.15Ni0.15", "Fe0.8C0.02Mn0.1Si0.05Cr0.03", "Fe0.6Ni0.25Mo0.1Cr0.05"],
        "task": "regression",
        "result": "91.20 ± 12.23 MPa MAE (5-fold, 5-seed ensemble)",
    },
    "expt_gap": {
        "title": "⚡ Experimental Band Gap",
        "description": "Predict experimental electronic band gap (eV) from composition.",
        "unit": "eV",
        "example": "TiO2",
        "examples": ["TiO2", "GaN", "ZnO", "Si", "CdS"],
        "task": "regression",
        "result": "0.3068 ± 0.0082 eV MAE (5-fold, composition-only)",
    },
    "ismetal": {
        "title": "🔮 Metallicity Classifier",
        "description": "Predict whether a material is metallic or non-metallic from composition.",
        "unit": "probability (1 = metal)",
        "example": "Cu",
        "examples": ["Cu", "SiO2", "Fe3O4", "BaTiO3", "Al"],
        "task": "classification",
        "result": "0.9655 ± 0.0029 ROC-AUC (5-fold, composition-only)",
    },
    "glass": {
        "title": "🪟 Glass Forming Ability",
        "description": "Predict metallic glass forming ability from alloy composition.",
        "unit": "probability (1 = glass former)",
        "example": "Cu46Zr54",
        "examples": ["Cu46Zr54", "Fe80B20", "Al86Ni7La6Y1", "Pd40Cu30Ni10P20"],
        "task": "classification",
        "result": "0.9285 ± 0.0063 ROC-AUC (5-fold, 5-seed ensemble)",
    },
    "jdft2d": {
        "title": "📐 Exfoliation Energy",
        "description": "Predict exfoliation energy (meV/atom) of 2D materials from structure+composition.",
        "unit": "meV/atom",
        "example": "MoS2",
        "examples": ["MoS2", "WSe2", "BN", "graphene (C)", "MoTe2"],
        "task": "regression",
        "result": "35.89 ± 12.40 meV/atom MAE (5-fold, 5-seed ensemble)",
    },
    "phonons": {
        "title": "🎵 Phonon Peak Frequency",
        "description": "Predict peak phonon frequency (cm⁻¹) from crystal structure.",
        "unit": "cm⁻¹",
        "example": "Si (diamond cubic)",
        "examples": ["Si", "GaAs", "MgO", "BN (wurtzite)", "TiO2 (rutile)"],
        "task": "regression",
        "result": "41.91 ± 4.04 cm⁻¹ MAE (5-fold, gate-halt GraphTRIADS)",
    },
}


# ─────────────────────────────────────────────────────────────────
# MODEL DEFINITIONS (inlined for self-contained app)
# ─────────────────────────────────────────────────────────────────

class DeepHybridTRM(nn.Module):
    """
    HybridTRIADS — composition-only tasks.
    Shared across: steels, expt_gap, ismetal, glass, jdft2d.
    """
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=20, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra

        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn * 2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn * 2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn * 2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn * 2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        self.z_up = nn.Sequential(
            nn.Linear(d_hidden * 3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden * 2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra, m2v = None, x[:, mg_dim:]

        tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)

        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))
        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))
        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)
        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)
        return self.pool(pooled)

    def forward(self, x, deep_supervision=False):
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)
        step_preds = []
        for _ in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
        return step_preds if deep_supervision else step_preds[-1]


# ─────────────────────────────────────────────────────────────────
# FEATURIZER (composition-only, shared across HybridTRIADS tasks)
# ─────────────────────────────────────────────────────────────────

_featurizer_cache = {}
_mat2vec_cache = {}


def _get_featurizer():
    """Lazy-load the ExpandedFeaturizer (downloads Mat2Vec once)."""
    if "main" in _featurizer_cache:
        return _featurizer_cache["main"]

    try:
        from matminer.featurizers.composition import (
            ElementProperty, ElementFraction, Stoichiometry,
            ValenceOrbital, IonProperty, BandCenter
        )
        from matminer.featurizers.base import MultipleFeaturizer
        from gensim.models import Word2Vec

        GCS = "https://storage.googleapis.com/mat2vec/"
        M2V_FILES = [
            "pretrained_embeddings",
            "pretrained_embeddings.wv.vectors.npy",
            "pretrained_embeddings.trainables.syn1neg.npy",
        ]
        os.makedirs("mat2vec_cache", exist_ok=True)
        for f in M2V_FILES:
            p = os.path.join("mat2vec_cache", f)
            if not os.path.exists(p):
                urllib.request.urlretrieve(GCS + f, p)

        ep = ElementProperty.from_preset("magpie")
        m2v = Word2Vec.load("mat2vec_cache/pretrained_embeddings")
        emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key}
        extra = MultipleFeaturizer([ElementFraction(), Stoichiometry(),
                                    ValenceOrbital(), IonProperty(), BandCenter()])

        _featurizer_cache["main"] = (ep, m2v, emb, extra)
        return _featurizer_cache["main"]

    except Exception:
        return None


def featurize_composition(formula: str):
    """Featurize a chemical formula into the TRIADS feature vector."""
    from pymatgen.core import Composition

    result = _get_featurizer()
    if result is None:
        return None, "Featurizer not available (matminer/gensim missing or Mat2Vec download failed)"

    ep, m2v, emb, extra = result

    try:
        comp = Composition(formula)
    except Exception:
        return None, f"Invalid formula: '{formula}'"

    try:
        mg = np.array(ep.featurize(comp), np.float32)
    except Exception:
        mg = np.zeros(len(ep.feature_labels()), np.float32)

    try:
        ex = np.array(extra.featurize(comp), np.float32)
        ex = np.nan_to_num(ex, nan=0.0)
    except Exception:
        ex = np.zeros(50, np.float32)

    # Mat2Vec pooled: composition-weighted mean of element embeddings
    v, t = np.zeros(200, np.float32), 0.0
    for s, f in comp.get_el_amt_dict().items():
        if s in emb:
            v += f * emb[s]
        t += f
    m2v_vec = v / max(t, 1e-8)

    mg = np.nan_to_num(mg, nan=0.0)
    feat = np.concatenate([mg, ex, m2v_vec])
    return feat.astype(np.float32), None


# ─────────────────────────────────────────────────────────────────
# WEIGHT LOADING (lazy, cached)
# ─────────────────────────────────────────────────────────────────

# weights.pt format (one file per benchmark on HuggingFace):
# {
#   'folds': [ {'model_state': OrderedDict, 'test_mae': float}, ... ],  # len == n_folds
#   'n_extra': int,
#   'config': {'d_attn': int, 'd_hidden': int, 'ff_dim': int,
#              'dropout': float, 'max_steps': int},
#   'benchmark': str,
# }

_fold_models = {}  # benchmark -> list[nn.Module] (one entry per fold)

_MODEL_CONFIGS = {
    # These MUST match the architecture configs baked into the saved weights.pt files.
    # Values verified by inspecting ckpt['config'] from each weights.pt directly.
    "steels":   dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),
    "expt_gap": dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),  # V3 s42 (actual weights)
    "ismetal":  dict(d_attn=24, d_hidden=48, ff_dim=72,  dropout=0.20, max_steps=16),  # 100K actual
    "glass":    dict(d_attn=24, d_hidden=48, ff_dim=72,  dropout=0.20, max_steps=16),  # actual weights
    "jdft2d":   dict(d_attn=32, d_hidden=64, ff_dim=96,  dropout=0.20, max_steps=16),  # V4-75K actual
}

_HF_PATHS = {
    "steels": "steels/weights.pt",
    "expt_gap": "expt_gap/weights.pt",
    "ismetal": "is_metal/weights.pt",
    "glass": "glass/weights.pt",
    "jdft2d": "jdft2d/weights.pt",
    "phonons": "phonons/weights.pt",
}


def _load_benchmark_models(benchmark: str):
    """
    Download benchmark/weights.pt once, build one nn.Module per fold,
    cache the list in _fold_models[benchmark].
    Returns list[nn.Module] or None on failure.
    """
    if benchmark in _fold_models:
        return _fold_models[benchmark]

    if benchmark == "phonons":
        # Phonons needs structure input — no composition-only inference
        return None

    try:
        path = hf_hub_download(repo_id=REPO_ID, filename=_HF_PATHS[benchmark])
        ckpt = torch.load(path, map_location="cpu", weights_only=False)

        # Accept both old per-fold dicts and the new compiled format
        fold_entries = ckpt.get("folds", [ckpt])  # fallback: single-fold legacy
        n_extra = ckpt.get("n_extra", 0)
        cfg = {**_MODEL_CONFIGS[benchmark], "n_extra": n_extra}

        models = []
        for entry in fold_entries:
            m = DeepHybridTRM(**cfg)
            # Each entry is either {'model_state': state_dict, ...} or a raw state_dict
            sd = entry.get("model_state", entry) if isinstance(entry, dict) else entry
            m.load_state_dict(sd)
            m.eval()
            models.append(m)

        _fold_models[benchmark] = models
        return models

    except Exception:
        return None


def _ensemble_predict(benchmark: str, x: np.ndarray,
                      is_classification: bool = False):
    """Run inference through all fold models, return averaged prediction."""
    models = _load_benchmark_models(benchmark)
    if not models:
        return None, "Could not load model weights. Are they uploaded to HuggingFace?"

    xt = torch.tensor(x[None], dtype=torch.float32)
    preds = []
    for m in models:
        with torch.no_grad():
            out = m(xt).item()
        if is_classification:
            out = 1.0 / (1.0 + np.exp(-out))  # logit -> probability
        preds.append(out)

    return float(np.mean(preds)), None


# ─────────────────────────────────────────────────────────────────
# PREDICTION FUNCTIONS (one per benchmark tab)
# ─────────────────────────────────────────────────────────────────

def _status_bar(benchmark_key: str):
    info = BENCHMARK_INFO[benchmark_key]
    return (f"📊 **Benchmark result:** {info['result']}\n\n"
            "*Weights will be downloaded from HuggingFace on first prediction.*")


def predict_steels(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("steels", feat, is_classification=False)
    if err:
        return f"❌ {err}", ""

    context = (f"**{pred:.1f} MPa** yield strength\n\n"
               "> TRIADS benchmark MAE: 91.20 MPa | "
               "CrabNet: 107.32 MPa | Darwin: 123.29 MPa")
    return f"### {pred:.1f} MPa", context


def predict_expt_gap(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("expt_gap", feat, is_classification=False)
    if err:
        return f"❌ {err}", ""

    metal_class = ("Likely metallic (Eg ≈ 0)" if pred < 0.3 else
                   "Small gap semiconductor" if pred < 1.5 else
                   "Wide-gap semiconductor/insulator")
    context = (f"**{pred:.3f} eV** band gap\n\n"
               f"Classification: {metal_class}\n\n"
               "> TRIADS benchmark MAE: 0.3068 eV | Darwin: 0.2865 eV")
    return f"### {pred:.3f} eV", context


def predict_ismetal(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("ismetal", feat, is_classification=True)
    if err:
        return f"❌ {err}", ""

    label = "🔩 **METALLIC**" if pred > 0.5 else "💎 **NON-METALLIC**"
    pct = pred * 100 if pred > 0.5 else (1 - pred) * 100
    confidence = "high" if pct > 80 else "moderate" if pct > 60 else "uncertain"
    context = (f"{label} (confidence: {confidence}, p={pred:.3f})\n\n"
               "> TRIADS benchmark ROC-AUC: 0.9655 (best composition-only model)")
    return f"### {pred:.3f} probability of being metallic", context


def predict_glass(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("glass", feat, is_classification=True)
    if err:
        return f"❌ {err}", ""

    label = "🪟 **Likely glass-former**" if pred > 0.5 else "❌ **Unlikely glass-former**"
    context = (f"{label} (p={pred:.3f})\n\n"
               "> TRIADS benchmark ROC-AUC: 0.9285 | MODNet: 0.9603")
    return f"### {pred:.3f} glass-forming probability", context


def predict_jdft2d(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("jdft2d", feat, is_classification=False)
    if err:
        return f"❌ {err}", ""

    ease = "Easy to exfoliate" if pred < 50 else "Moderate" if pred < 150 else "Hard to exfoliate"
    context = (f"**{pred:.1f} meV/atom** exfoliation energy\n\n"
               f"Exfoliatability: {ease}\n\n"
               "> TRIADS benchmark MAE: 35.89 meV/atom (best no-pretraining)")
    return f"### {pred:.1f} meV/atom", context


def predict_phonons_placeholder(formula: str):
    return ("### ⚠️ Phonons — Structure Required",
            "GraphTRIADS for phonons requires a crystal structure (CIF file), "
            "not just a formula. The pretrained weights are available at "
            "`huggingface.co/Rtx09/TRIADS` under `phonons/`.\n\n"
            "> Benchmark MAE: 41.91 cm⁻¹ (gate-halt GraphTRIADS V6, 247K params)")
463
+ # ─────────────────────────────────────────────────────────────────
464
+ # GRADIO INTERFACE
465
+ # ─────────────────────────────────────────────────────────────────
466
+
467
+ CSS = """
468
+ .gr-box { border-radius: 12px !important; }
469
+ .tab-nav button { font-weight: 600; font-size: 14px; }
470
+ #result_text { font-size: 1.5rem; font-weight: 700; color: #6366f1; }
471
+ .benchmark-badge {
472
+ background: #1e293b; color: #94a3b8; border-radius: 8px;
473
+ padding: 8px 14px; font-family: monospace; font-size: 12px;
474
+ }
475
+ footer { display: none !important; }
476
+ """
477
+
478
+ def build_interface():
479
+ with gr.Blocks(css=CSS, title="TRIADS — Materials Property Prediction") as demo:
480
+
481
+ gr.Markdown("""
482
+ # ⚡ TRIADS — Materials Property Prediction
483
+ **Tiny Recursive Information-Attention with Deep Supervision**
484
+ Six Matbench benchmarks · Parameter-efficient · Small-data specialist
485
+
486
+ Select a benchmark tab below to predict a material property.
487
+ """)
488
+
489
+ with gr.Tabs():
490
+
491
+ # ── TAB 1: STEELS ───────────────────────────────────���─────────
492
+ with gr.Tab("🔩 Steel Yield"):
493
+ with gr.Row():
494
+ with gr.Column(scale=1):
495
+ gr.Markdown("### Alloy Yield Strength (MPa)")
496
+ gr.Markdown("Input an alloy composition (elemental fractions must sum to 1).")
497
+ formula_s = gr.Textbox(
498
+ label="Alloy formula",
499
+ placeholder="e.g. Fe0.7Cr0.15Ni0.15",
500
+ value="Fe0.7Cr0.15Ni0.15"
501
+ )
502
+ gr.Examples(
503
+ examples=["Fe0.7Cr0.15Ni0.15", "Fe0.8C0.02Mn0.1Si0.05Cr0.03",
504
+ "Fe0.6Ni0.25Mo0.1Cr0.05"],
505
+ inputs=formula_s
506
+ )
507
+ btn_s = gr.Button("Predict Yield Strength", variant="primary")
508
+ with gr.Column(scale=1):
509
+ out_s = gr.Markdown(elem_id="result_text")
510
+ ctx_s = gr.Markdown()
511
+ gr.Markdown(
512
+ "📊 TRIADS V13A · 225K params · 5-seed ensemble · **91.20 MPa MAE**",
513
+ elem_classes="benchmark-badge"
514
+ )
515
+ btn_s.click(predict_steels, inputs=formula_s, outputs=[out_s, ctx_s])
516
+
517
+ # ── TAB 2: BAND GAP ───────────────────────────────────────────
518
+ with gr.Tab("⚡ Band Gap"):
519
+ with gr.Row():
520
+ with gr.Column(scale=1):
521
+ gr.Markdown("### Experimental Band Gap (eV)")
522
+ gr.Markdown("Input a chemical composition formula.")
523
+ formula_g = gr.Textbox(
524
+ label="Composition",
525
+ placeholder="e.g. TiO2",
526
+ value="TiO2"
527
+ )
528
+ gr.Examples(
529
+ examples=["TiO2", "GaN", "ZnO", "Si", "CdS", "SrTiO3"],
530
+ inputs=formula_g
531
+ )
532
+ btn_g = gr.Button("Predict Band Gap", variant="primary")
533
+ with gr.Column(scale=1):
534
+ out_g = gr.Markdown(elem_id="result_text")
535
+ ctx_g = gr.Markdown()
536
+ gr.Markdown(
537
+ "📊 TRIADS V3 · 100K params · **0.3068 eV MAE** (best comp-only)",
538
+ elem_classes="benchmark-badge"
539
+ )
540
+ btn_g.click(predict_expt_gap, inputs=formula_g, outputs=[out_g, ctx_g])
541
+
542
+ # ── TAB 3: METALLICITY ────────────────────────────────────────
543
+ with gr.Tab("🔮 Metallicity"):
544
+ with gr.Row():
545
+ with gr.Column(scale=1):
546
+ gr.Markdown("### Metal vs. Non-metal Classifier")
547
+ gr.Markdown("Predicts electronic metallicity from composition.")
548
+ formula_m = gr.Textbox(
549
+ label="Composition",
550
+ placeholder="e.g. Cu",
551
+ value="Cu"
552
+ )
553
+ gr.Examples(
554
+ examples=["Cu", "SiO2", "Fe3O4", "BaTiO3", "Al", "MgO", "NiO"],
555
+ inputs=formula_m
556
+ )
557
+ btn_m = gr.Button("Classify Metallicity", variant="primary")
558
+ with gr.Column(scale=1):
559
+ out_m = gr.Markdown(elem_id="result_text")
560
+ ctx_m = gr.Markdown()
561
+ gr.Markdown(
562
+ "📊 TRIADS 100K · **0.9655 ROC-AUC** · Best composition-only (beats GPTChem 1B+)",
563
+ elem_classes="benchmark-badge"
564
+ )
565
+ btn_m.click(predict_ismetal, inputs=formula_m, outputs=[out_m, ctx_m])
566
+
567
+ # ── TAB 4: GLASS FORMING ──────────────────────────────────────
568
+ with gr.Tab("🪟 Glass Forming"):
569
+ with gr.Row():
570
+ with gr.Column(scale=1):
571
+ gr.Markdown("### Metallic Glass Forming Ability")
572
+ gr.Markdown("Predicts glass forming probability from alloy composition.")
573
+ formula_gf = gr.Textbox(
574
+ label="Alloy composition",
575
+ placeholder="e.g. Cu46Zr54",
576
+ value="Cu46Zr54"
577
+ )
578
+ gr.Examples(
579
+ examples=["Cu46Zr54", "Fe80B20", "Al86Ni7La6Y1", "Pd40Cu30Ni10P20"],
580
+ inputs=formula_gf
581
+ )
582
+ btn_gf = gr.Button("Predict Glass Forming", variant="primary")
583
+ with gr.Column(scale=1):
584
+ out_gf = gr.Markdown(elem_id="result_text")
585
+ ctx_gf = gr.Markdown()
586
+ gr.Markdown(
587
+ "📊 TRIADS 44K · 5-seed ensemble · **0.9285 ROC-AUC**",
588
+ elem_classes="benchmark-badge"
589
+ )
590
+ btn_gf.click(predict_glass, inputs=formula_gf, outputs=[out_gf, ctx_gf])
591
+
592
+ # ── TAB 5: JDFT2D ─────────────────────────────────────────────
593
+ with gr.Tab("📐 JDFT2D"):
594
+ with gr.Row():
595
+ with gr.Column(scale=1):
596
+ gr.Markdown("### 2D Material Exfoliation Energy (meV/atom)")
597
+                     gr.Markdown("Predicts how easily a layered 2D material can be exfoliated.")
+                     formula_j = gr.Textbox(
+                         label="Composition",
+                         placeholder="e.g. MoS2",
+                         value="MoS2"
+                     )
+                     gr.Examples(
+                         examples=["MoS2", "WSe2", "BN", "MoTe2", "WS2"],
+                         inputs=formula_j
+                     )
+                     btn_j = gr.Button("Predict Exfoliation Energy", variant="primary")
+                 with gr.Column(scale=1):
+                     out_j = gr.Markdown(elem_id="result_text")
+                     ctx_j = gr.Markdown()
+                     gr.Markdown(
+                         "📊 TRIADS V4 · 75K params · 5-seed ensemble · **35.89 meV/atom MAE**",
+                         elem_classes="benchmark-badge"
+                     )
+             btn_j.click(predict_jdft2d, inputs=formula_j, outputs=[out_j, ctx_j])
+
+         # ── TAB 6: PHONONS ────────────────────────────────────────────
+         with gr.Tab("🎵 Phonons"):
+             with gr.Row():
+                 with gr.Column(scale=1):
+                     gr.Markdown("### Peak Phonon Frequency (cm⁻¹)")
+                     gr.Markdown(
+                         "GraphTRIADS V6 predicts phonon peak frequency from crystal structure.\n\n"
+                         "⚠️ **Structure required.** This model requires a full crystal "
+                         "structure (CIF) rather than composition alone. Enter a composition "
+                         "below to get benchmark context, or see the GitHub repo for full "
+                         "structure-based inference."
+                     )
+                     formula_ph = gr.Textbox(
+                         label="Formula (for context only)",
+                         placeholder="e.g. Si",
+                         value="Si"
+                     )
+                     btn_ph = gr.Button("Show Benchmark Info", variant="primary")
+                 with gr.Column(scale=1):
+                     out_ph = gr.Markdown(elem_id="result_text")
+                     ctx_ph = gr.Markdown()
+                     gr.Markdown(
+                         "📊 GraphTRIADS V6 · 247K params · Gate-halt · **41.91 cm⁻¹ MAE**",
+                         elem_classes="benchmark-badge"
+                     )
+             btn_ph.click(predict_phonons_placeholder, inputs=formula_ph, outputs=[out_ph, ctx_ph])
+
+         # ── FOOTER ──────────────────────────────────────────────────────
+         gr.Markdown("""
+         ---
+         **TRIADS** · [GitHub](https://github.com/Rtx09x/TRIADS) · MIT License · Rudra Tiwari, 2026
+
+         *All benchmarks use the exact Matbench 5-fold CV protocol (random\_state=18012019).
+         Predictions are ensemble averages across 5 folds (fold-specific scalers approximated at inference).*
+         """)
+
+     return demo
+
+
+ if __name__ == "__main__":
+     demo = build_interface()
+     demo.launch(share=False)
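The footer's claim about the fold protocol can be checked directly. A minimal sketch (not part of the commit) of the Matbench v0.1 fold generation the app describes, using the dataset size 4,921 from matbench_expt_is_metal as listed in the commit message:

```python
import numpy as np
from sklearn.model_selection import KFold

# Exact Matbench v0.1 fold generation: 5 shuffled folds, fixed seed
n_samples = 4921  # matbench_expt_is_metal size (from the commit message)
kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
folds = list(kfold.split(np.arange(n_samples)))

# Sanity checks mirroring the leakage verification in classification_model.py
test_idx = np.concatenate([te for _, te in folds])
assert len(folds) == 5
assert len(set(test_idx)) == n_samples  # every sample appears in exactly one test fold
```

Because the seed is fixed, these folds are byte-for-byte reproducible across machines, which is what allows the per-fold checkpoints in `weights/` to be compared against the official leaderboard splits.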
model_code/__init__.py ADDED
@@ -0,0 +1,9 @@
+ # TRIADS model_code package
+ # Import the model classes for convenience
+
+ from .steels_model import DeepHybridTRM as SteelsModel
+ from .expt_gap_model import DeepHybridTRM as ExptGapModel
+ from .classification_model import DeepHybridTRM as ClassificationModel
+ from .jdft2d_model import DeepHybridTRM as Jdft2dModel
+
+ __all__ = ["SteelsModel", "ExptGapModel", "ClassificationModel", "Jdft2dModel"]
model_code/classification_model.py ADDED
@@ -0,0 +1,734 @@
+ """
+ +=============================================================+
+ | TRIADS — Classification Benchmarks (Combined)               |
+ |                                                             |
+ | 1. matbench_expt_is_metal (4,921) — Metal vs Non-metal      |
+ | 2. matbench_glass (5,680) — Metallic Glass Forming          |
+ |                                                             |
+ | 44K model | BCEWithLogitsLoss | ROCAUC | Single Seed        |
+ | Seeds: [42, 123, 456, 789, 1024]                            |
+ | Folds: KFold(5, shuffle=True, random_state=18012019)        |
+ | ^^^ exact matbench v0.1 fold generation ^^^                 |
+ +=============================================================+
+
+ DEPENDENCIES (run before executing):
+     pip install matminer pymatgen gensim tqdm scikit-learn torch
+
+ USAGE:
+     python classification_benchmarks.py   # runs both sequentially
+ """
+
+ import os, copy, json, time, logging, warnings, urllib.request, shutil
+ warnings.filterwarnings('ignore')
+
+ import numpy as np
+ import pandas as pd
+ from tqdm import tqdm
+ from sklearn.metrics import roc_auc_score
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
+
+ from sklearn.model_selection import KFold
+ from sklearn.preprocessing import StandardScaler
+ from pymatgen.core import Composition
+ from matminer.featurizers.composition import ElementProperty
+ from gensim.models import Word2Vec
+
+ logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
+ log = logging.getLogger("TRIADS-CLS")
+
+ BATCH_SIZE = 64
+ # Single seed first — test before committing to the full ensemble
+ SEEDS = [42]
+ # Uncomment below for the 5-seed ensemble once the single seed looks good:
+ # SEEDS = [42, 123, 456, 789, 1024]
+
+ # ~44K config — smaller to prevent overfitting
+ MODEL_CFG = dict(
+     d_attn=24, nhead=4, d_hidden=48, ff_dim=72,
+     dropout=0.20, max_steps=16,
+ )
+
+ # Matbench v0.1 exact fold seed — DO NOT CHANGE
+ MATBENCH_FOLD_SEED = 18012019
+
+
+ # ======================================================================
+ # FAST TENSOR DATALOADER
+ # ======================================================================
+
+ class FastTensorDataLoader:
+     def __init__(self, *tensors, batch_size=64, shuffle=False):
+         assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
+         self.tensors = tensors
+         self.dataset_len = tensors[0].shape[0]
+         self.batch_size = batch_size
+         self.shuffle = shuffle
+         self.n_batches = (self.dataset_len + batch_size - 1) // batch_size
+
+     def __iter__(self):
+         if self.shuffle:
+             idx = torch.randperm(self.dataset_len, device=self.tensors[0].device)
+             self.tensors = tuple(t[idx] for t in self.tensors)
+         self.i = 0
+         return self
+
+     def __next__(self):
+         if self.i >= self.dataset_len:
+             raise StopIteration
+         batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
+         self.i += self.batch_size
+         return batch
+
+     def __len__(self):
+         return self.n_batches
+
+
+ # ======================================================================
+ # FEATURIZERS
+ # ======================================================================
+
+ _ORBITAL_ENERGIES = {
+     'H': {'1s': -13.6}, 'He': {'1s': -24.6},
+     'Li': {'2s': -5.4}, 'Be': {'2s': -9.3},
+     'B': {'2s': -14.0, '2p': -8.3}, 'C': {'2s': -19.4, '2p': -11.3},
+     'N': {'2s': -25.6, '2p': -14.5}, 'O': {'2s': -32.4, '2p': -13.6},
+     'F': {'2s': -40.2, '2p': -17.4}, 'Ne': {'2s': -48.5, '2p': -21.6},
+     'Na': {'3s': -5.1}, 'Mg': {'3s': -7.6},
+     'Al': {'3s': -11.3, '3p': -6.0}, 'Si': {'3s': -15.0, '3p': -8.2},
+     'P': {'3s': -18.7, '3p': -10.5}, 'S': {'3s': -22.7, '3p': -10.4},
+     'Cl': {'3s': -25.3, '3p': -13.0}, 'Ar': {'3s': -29.2, '3p': -15.8},
+     'K': {'4s': -4.3}, 'Ca': {'4s': -6.1},
+     'Sc': {'4s': -6.6, '3d': -8.0}, 'Ti': {'4s': -6.8, '3d': -8.5},
+     'V': {'4s': -6.7, '3d': -8.3}, 'Cr': {'4s': -6.8, '3d': -8.7},
+     'Mn': {'4s': -7.4, '3d': -9.5}, 'Fe': {'4s': -7.9, '3d': -10.0},
+     'Co': {'4s': -7.9, '3d': -10.0}, 'Ni': {'4s': -7.6, '3d': -10.0},
+     'Cu': {'4s': -7.7, '3d': -11.7}, 'Zn': {'4s': -9.4, '3d': -17.3},
+     'Ga': {'4s': -12.6, '4p': -6.0}, 'Ge': {'4s': -15.6, '4p': -7.9},
+     'As': {'4s': -18.6, '4p': -9.8}, 'Se': {'4s': -21.1, '4p': -9.8},
+     'Br': {'4s': -24.0, '4p': -11.8}, 'Kr': {'4s': -27.5, '4p': -14.0},
+     'Rb': {'5s': -4.2}, 'Sr': {'5s': -5.7},
+     'Y': {'5s': -6.5, '4d': -7.4}, 'Zr': {'5s': -6.8, '4d': -8.3},
+     'Nb': {'5s': -6.9, '4d': -8.5}, 'Mo': {'5s': -7.1, '4d': -8.9},
+     'Ru': {'5s': -7.4, '4d': -8.7}, 'Rh': {'5s': -7.5, '4d': -8.8},
+     'Pd': {'4d': -8.3}, 'Ag': {'5s': -7.6, '4d': -12.3},
+     'Cd': {'5s': -9.0, '4d': -16.7}, 'In': {'5s': -12.0, '5p': -5.8},
+     'Sn': {'5s': -14.6, '5p': -7.3}, 'Sb': {'5s': -16.5, '5p': -8.6},
+     'Te': {'5s': -19.0, '5p': -9.0}, 'I': {'5s': -21.1, '5p': -10.5},
+     'Xe': {'5s': -23.4, '5p': -12.1}, 'Cs': {'6s': -3.9}, 'Ba': {'6s': -5.2},
+     'La': {'6s': -5.6, '5d': -7.5},
+     'Ce': {'6s': -5.5, '5d': -7.3, '4f': -7.0},
+     'Hf': {'6s': -7.0, '5d': -8.1}, 'Ta': {'6s': -7.9, '5d': -9.6},
+     'W': {'6s': -8.0, '5d': -9.8}, 'Re': {'6s': -7.9, '5d': -9.2},
+     'Os': {'6s': -8.4, '5d': -10.0}, 'Ir': {'6s': -9.1, '5d': -10.7},
+     'Pt': {'6s': -9.0, '5d': -10.5}, 'Au': {'6s': -9.2, '5d': -12.8},
+     'Pb': {'6s': -15.0, '6p': -7.4}, 'Bi': {'6s': -16.7, '6p': -7.3},
+ }
+
+
+ def _compute_homo_lumo_gap(comp):
+     elements = comp.get_el_amt_dict()
+     highest_occ, all_energies = [], []
+     for el, frac in elements.items():
+         if el not in _ORBITAL_ENERGIES:
+             return np.array([0.0, 0.0, 0.0], dtype=np.float32)
+         orbs = _ORBITAL_ENERGIES[el]
+         highest_occ.append((max(orbs.values()), frac))
+         all_energies.extend(orbs.values())
+     if not highest_occ:
+         return np.array([0.0, 0.0, 0.0], dtype=np.float32)
+     homo = sum(e * f for e, f in highest_occ) / sum(f for _, f in highest_occ)
+     above = [e for e in all_energies if e > homo]
+     lumo = min(above) if above else homo + 1.0
+     return np.array([homo, lumo, lumo - homo], dtype=np.float32)
+
+
+ class _BaseFeaturizer:
+     """Shared Mat2Vec loading and Magpie featurization."""
+     GCS = "https://storage.googleapis.com/mat2vec/"
+     FILES = ["pretrained_embeddings",
+              "pretrained_embeddings.wv.vectors.npy",
+              "pretrained_embeddings.trainables.syn1neg.npy"]
+
+     def __init__(self, cache="mat2vec_cache"):
+         self.ep_magpie = ElementProperty.from_preset("magpie")
+         self.n_mg = len(self.ep_magpie.feature_labels())
+         self.n_extra = None
+         self.scaler = None
+
+         os.makedirs(cache, exist_ok=True)
+         for f in self.FILES:
+             p = os.path.join(cache, f)
+             if not os.path.exists(p):
+                 log.info(f"  Downloading {f}...")
+                 urllib.request.urlretrieve(self.GCS + f, p)
+         self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
+         self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}
+
+     def _pool(self, c):
+         v, t = np.zeros(200, np.float32), 0.0
+         for s, f in c.get_el_amt_dict().items():
+             if s in self.emb:
+                 v += f * self.emb[s]; t += f
+         return v / max(t, 1e-8)
+
+     def featurize_all(self, comps):
+         out = []
+         test_ex = self._featurize_extra(comps[0])
+         self.n_extra = len(test_ex)
+         total = self.n_mg + self.n_extra + 200
+         log.info(f"Features: {self.n_mg} Magpie + "
+                  f"{self.n_extra} Extra + 200 Mat2Vec = {total}d")
+         for c in tqdm(comps, desc="  Featurizing", leave=False):
+             try:
+                 mg = np.array(self.ep_magpie.featurize(c), np.float32)
+             except Exception:
+                 mg = np.zeros(self.n_mg, np.float32)
+             ex = self._featurize_extra(c)
+             out.append(np.concatenate([
+                 np.nan_to_num(mg, nan=0.0),
+                 np.nan_to_num(ex, nan=0.0),
+                 self._pool(c)
+             ]))
+         return np.array(out)
+
+     def fit_scaler(self, X):
+         self.scaler = StandardScaler().fit(X)
+
+     def transform(self, X):
+         if not self.scaler:
+             return X
+         return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)
+
+
+ class MetallicityFeaturizer(_BaseFeaturizer):
+     """354d — keeps HOMO/LUMO gap + BandCenter (relevant to metallicity)."""
+     def __init__(self, cache="mat2vec_cache"):
+         super().__init__(cache)
+         from matminer.featurizers.composition import (
+             Stoichiometry, ValenceOrbital, IonProperty, BandCenter
+         )
+         from matminer.featurizers.composition.element import TMetalFraction
+         self.extra_featurizers = [
+             ("Stoichiometry", Stoichiometry()),
+             ("ValenceOrbital", ValenceOrbital()),
+             ("IonProperty", IonProperty()),
+             ("BandCenter", BandCenter()),
+             ("TMetalFraction", TMetalFraction()),
+         ]
+         self._extra_sizes = {}
+         for name, ftzr in self.extra_featurizers:
+             try:
+                 self._extra_sizes[name] = len(ftzr.feature_labels())
+             except Exception:
+                 self._extra_sizes[name] = None
+
+     def _featurize_extra(self, comp):
+         parts = []
+         for name, ftzr in self.extra_featurizers:
+             try:
+                 vals = np.array(ftzr.featurize(comp), np.float32)
+                 parts.append(np.nan_to_num(vals, nan=0.0))
+                 if self._extra_sizes.get(name) is None:
+                     self._extra_sizes[name] = len(vals)
+             except Exception:
+                 sz = self._extra_sizes.get(name, 0) or 1
+                 parts.append(np.zeros(sz, np.float32))
+         parts.append(_compute_homo_lumo_gap(comp))
+         return np.concatenate(parts)
+
+
+ class GlassFeaturizer(_BaseFeaturizer):
+     """~351d — removes BandCenter & HOMO/LUMO (irrelevant to glass forming)."""
+     def __init__(self, cache="mat2vec_cache"):
+         super().__init__(cache)
+         from matminer.featurizers.composition import (
+             Stoichiometry, ValenceOrbital, IonProperty
+         )
+         from matminer.featurizers.composition.element import TMetalFraction
+         self.extra_featurizers = [
+             ("Stoichiometry", Stoichiometry()),
+             ("ValenceOrbital", ValenceOrbital()),
+             ("IonProperty", IonProperty()),
+             ("TMetalFraction", TMetalFraction()),
+         ]
+         self._extra_sizes = {}
+         for name, ftzr in self.extra_featurizers:
+             try:
+                 self._extra_sizes[name] = len(ftzr.feature_labels())
+             except Exception:
+                 self._extra_sizes[name] = None
+
+     def _featurize_extra(self, comp):
+         parts = []
+         for name, ftzr in self.extra_featurizers:
+             try:
+                 vals = np.array(ftzr.featurize(comp), np.float32)
+                 parts.append(np.nan_to_num(vals, nan=0.0))
+                 if self._extra_sizes.get(name) is None:
+                     self._extra_sizes[name] = len(vals)
+             except Exception:
+                 sz = self._extra_sizes.get(name, 0) or 1
+                 parts.append(np.zeros(sz, np.float32))
+         return np.concatenate(parts)
+
+
+ # ======================================================================
+ # MODEL — DeepHybridTRM (~44K params with MODEL_CFG above)
+ # ======================================================================
+
+ class DeepHybridTRM(nn.Module):
+     def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
+                  d_attn=32, nhead=4, d_hidden=64, ff_dim=96,
+                  dropout=0.15, max_steps=16, **kw):
+         super().__init__()
+         self.max_steps, self.D = max_steps, d_hidden
+         self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra
+
+         self.tok_proj = nn.Sequential(
+             nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
+         self.m2v_proj = nn.Sequential(
+             nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
+
+         self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
+         self.sa1_n = nn.LayerNorm(d_attn)
+         self.sa1_ff = nn.Sequential(
+             nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(d_attn*2, d_attn))
+         self.sa1_fn = nn.LayerNorm(d_attn)
+
+         self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
+         self.sa2_n = nn.LayerNorm(d_attn)
+         self.sa2_ff = nn.Sequential(
+             nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(d_attn*2, d_attn))
+         self.sa2_fn = nn.LayerNorm(d_attn)
+
+         self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
+         self.ca_n = nn.LayerNorm(d_attn)
+
+         pool_in = d_attn + (n_extra if n_extra > 0 else 0)
+         self.pool = nn.Sequential(
+             nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
+
+         self.z_up = nn.Sequential(
+             nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
+         self.y_up = nn.Sequential(
+             nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
+         self.head = nn.Linear(d_hidden, 1)
+         self._init()
+
+     def _init(self):
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.xavier_uniform_(m.weight)
+                 if m.bias is not None:
+                     nn.init.zeros_(m.bias)
+
+     def _attention(self, x):
+         B = x.size(0)
+         mg_dim = self.n_props * self.stat_dim
+         if self.n_extra > 0:
+             extra = x[:, mg_dim:mg_dim + self.n_extra]
+             m2v = x[:, mg_dim + self.n_extra:]
+         else:
+             extra, m2v = None, x[:, mg_dim:]
+         tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
+         ctx = self.m2v_proj(m2v).unsqueeze(1)
+         tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
+         tok = self.sa1_fn(tok + self.sa1_ff(tok))
+         tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
+         tok = self.sa2_fn(tok + self.sa2_ff(tok))
+         tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
+         pooled = tok.mean(dim=1)
+         if extra is not None:
+             pooled = torch.cat([pooled, extra], dim=-1)
+         return self.pool(pooled)
+
+     def forward(self, x, deep_supervision=False):
+         B = x.size(0)
+         xp = self._attention(x)
+         z = torch.zeros(B, self.D, device=x.device)
+         y = torch.zeros(B, self.D, device=x.device)
+         step_preds = []
+         for s in range(self.max_steps):
+             z = z + self.z_up(torch.cat([xp, y, z], -1))
+             y = y + self.y_up(torch.cat([y, z], -1))
+             step_preds.append(self.head(y).squeeze(1))
+         return step_preds if deep_supervision else step_preds[-1]
+
+     def count_parameters(self):
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+
+ # ======================================================================
+ # LOSS + UTILS
+ # ======================================================================
+
+ def deep_supervision_loss_bce(step_preds, targets):
+     preds = torch.stack(step_preds)
+     n = preds.shape[0]
+     w = torch.arange(1, n + 1, device=preds.device, dtype=preds.dtype)
+     w = w / w.sum()
+     per_step = torch.stack([
+         F.binary_cross_entropy_with_logits(preds[i], targets, reduction='mean')
+         for i in range(n)
+     ])
+     return (w * per_step).sum()
+
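The per-step weights used by `deep_supervision_loss_bce` above form a normalized linear ramp, so later recursion steps dominate the loss while early steps still receive gradient. A torch-free restatement of that schedule (sketch only, not part of the commit):

```python
import numpy as np

def step_weights(n_steps):
    # Same schedule as deep_supervision_loss_bce: weights 1..n,
    # normalized to sum to 1, so step n carries n times the weight of step 1.
    w = np.arange(1, n_steps + 1, dtype=float)
    return w / w.sum()

w = step_weights(16)  # max_steps=16 in MODEL_CFG
```

With 16 steps the final prediction carries 16/136 of the total loss weight versus 1/136 for the first step, which biases training toward the fully refined output without discarding intermediate supervision.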
+
+ def strat_split_cls(targets, val_size=0.15, seed=42):
+     tr, vl = [], []
+     rng = np.random.RandomState(seed)
+     for cls in [0, 1]:
+         m = np.where(targets == cls)[0]
+         if len(m) == 0:
+             continue
+         n = max(1, int(len(m) * val_size))
+         c = rng.choice(m, n, replace=False)
+         vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
+     return np.array(tr), np.array(vl)
+
+
+ @torch.inference_mode()
+ def predict_proba(model, dl):
+     model.eval()
+     preds = []
+     for bx, _ in dl:
+         preds.append(torch.sigmoid(model(bx)).cpu())
+     return torch.cat(preds)
+
+
+ # ======================================================================
+ # TRAINING
+ # ======================================================================
+
+ def train_fold(model, tr_dl, vl_dl, device,
+                epochs=300, swa_start=200, fold=1, seed=42, label="100K"):
+     opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
+     sch = torch.optim.lr_scheduler.CosineAnnealingLR(
+         opt, T_max=swa_start, eta_min=1e-4)
+     swa_m = AveragedModel(model)
+     swa_s = SWALR(opt, swa_lr=5e-4)
+     swa_on = False
+     best_v, best_w = float('-inf'), None
+
+     pbar = tqdm(range(epochs), desc=f"  [{label}|s{seed}] F{fold}/5",
+                 leave=False, ncols=120)
+     for ep in pbar:
+         model.train()
+         epoch_loss, n_batches = 0.0, 0
+         for bx, by in tr_dl:
+             sp = model(bx, deep_supervision=True)
+             loss = deep_supervision_loss_bce(sp, by)
+             opt.zero_grad(set_to_none=True)
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             opt.step()
+             epoch_loss += loss.item()
+             n_batches += 1
+
+         model.eval()
+         vp_list, vt_list = [], []
+         with torch.inference_mode():
+             for bx, by in vl_dl:
+                 vp_list.append(torch.sigmoid(model(bx)).cpu())
+                 vt_list.append(by.cpu())
+         vp = torch.cat(vp_list).numpy()
+         vt = torch.cat(vt_list).numpy()
+         try:
+             val_auc = roc_auc_score(vt, vp)
+         except ValueError:
+             val_auc = 0.5
+
+         if ep < swa_start:
+             sch.step()
+             if val_auc > best_v:
+                 best_v = val_auc
+                 best_w = copy.deepcopy(model.state_dict())
+         else:
+             if not swa_on:
+                 swa_on = True
+             swa_m.update_parameters(model); swa_s.step()
+
+         if ep % 10 == 0 or ep == epochs - 1:
+             pbar.set_postfix(Best=f'{best_v:.4f}', Ph='SWA' if swa_on else 'COS',
+                              Loss=f'{epoch_loss/max(n_batches,1):.4f}',
+                              AUC=f'{val_auc:.4f}')
+
+     if swa_on:
+         update_bn(tr_dl, swa_m, device=device)
+         model.load_state_dict(swa_m.module.state_dict())
+     else:
+         model.load_state_dict(best_w)
+     return best_v, model
+
+
+ # ======================================================================
+ # GENERIC BENCHMARK RUNNER
+ # ======================================================================
+
+ def run_classification_benchmark(
+     dataset_name, target_col, featurizer_cls,
+     model_dir, summary_file, baseline_name, baseline_auc,
+     device
+ ):
+     """Run a full 5-seed ensemble classification benchmark."""
+     t0 = time.time()
+
+     # ── LOAD ─────────────────────────────────────────────────────────
+     print(f"\n Loading {dataset_name}...")
+     from matminer.datasets import load_dataset
+     df = load_dataset(dataset_name)
+
+     targets_all = np.array(df[target_col].astype(float).tolist(), np.float32)
+
+     # Handle different column names
+     if 'composition' in df.columns:
+         comps_all = [Composition(c) for c in df['composition'].tolist()]
+     elif 'structure' in df.columns:
+         comps_all = [s.composition for s in df['structure'].tolist()]
+     elif 'formula' in df.columns:
+         comps_all = [Composition(str(f)) for f in df['formula'].tolist()]
+     else:
+         raise ValueError(f"Cannot find composition column in {df.columns.tolist()}")
+
+     n_pos = int(targets_all.sum())
+     n_neg = len(targets_all) - n_pos
+     print(f" Dataset: {len(comps_all)} samples ({n_pos} positive, {n_neg} negative)")
+     print(f" Class balance: {n_pos/len(targets_all)*100:.1f}% positive")
+
+     # ── FEATURIZE (once) ─────────────────────────────────────────────
+     t_feat = time.time()
+     feat = featurizer_cls()
+     X_all = feat.featurize_all(comps_all)
+     n_extra = feat.n_extra
+     print(f" Features: {X_all.shape} (n_extra={n_extra})")
+     print(f" Featurization: {time.time()-t_feat:.1f}s")
+
+     # ── FOLDS — exact matbench v0.1 splits ───────────────────────────
+     kfold = KFold(n_splits=5, shuffle=True, random_state=MATBENCH_FOLD_SEED)
+     folds = list(kfold.split(comps_all))
+
+     # Verify zero leakage
+     all_test_indices = []
+     for fi, (tv, te) in enumerate(folds):
+         assert len(set(tv) & set(te)) == 0, f"Fold {fi}: train/test overlap!"
+         all_test_indices.extend(te.tolist())
+     assert len(set(all_test_indices)) == len(comps_all), "Not all samples covered!"
+     assert len(all_test_indices) == len(comps_all), "Duplicate test samples!"
+     print(" 5 folds verified: zero leakage, full coverage, no duplicates ✓\n")
+
+     # ── MODEL INFO ───────────────────────────────────────────────────
+     model_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
+                     mat2vec_dim=200, **MODEL_CFG)
+     test_model = DeepHybridTRM(**model_kw)
+     n_params = test_model.count_parameters()
+     del test_model
+     print(f" Model: {n_params:,} params (44K config)")
+
+     # ── TRAIN ALL SEEDS ──────────────────────────────────────────────
+     os.makedirs(model_dir, exist_ok=True)
+     all_seed_aucs = {}
+     all_fold_probs = {}
+     all_fold_targets = {}
+
+     for seed in SEEDS:
+         print(f"\n {'─'*3} Seed {seed} {'─'*40}")
+         t_seed = time.time()
+         seed_aucs = {}
+
+         for fi, (tv_i, te_i) in enumerate(folds):
+             tri, vli = strat_split_cls(targets_all[tv_i], 0.15, seed + fi)
+             feat.fit_scaler(X_all[tv_i][tri])
+
+             tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
+             tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
+             vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
+             vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
+             te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
+             te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)
+
+             tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
+             vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
+             te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)
+
+             torch.manual_seed(seed + fi)
+             np.random.seed(seed + fi)
+             if device.type == 'cuda':
+                 torch.cuda.manual_seed(seed + fi)
+
+             model = DeepHybridTRM(**model_kw).to(device)
+             bv, model = train_fold(model, tr_dl, vl_dl, device,
+                                    epochs=300, swa_start=200,
+                                    fold=fi+1, seed=seed, label="44K")
+
+             probs = predict_proba(model, te_dl)
+             auc = roc_auc_score(te_y.cpu().numpy(), probs.numpy())
+             seed_aucs[fi] = auc
+
+             if fi not in all_fold_probs:
+                 all_fold_probs[fi] = {}
+                 all_fold_targets[fi] = te_y.cpu()
+             all_fold_probs[fi][seed] = probs
+
+             torch.save({
+                 'model_state': model.state_dict(),
+                 'test_auc': auc, 'fold': fi+1, 'seed': seed,
+                 'n_extra': n_extra,
+             }, f'{model_dir}/{dataset_name}_100K_s{seed}_f{fi+1}.pt')
+
+             del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
+             if device.type == 'cuda':
+                 torch.cuda.empty_cache()
+
+         avg_s = np.mean(list(seed_aucs.values()))
+         all_seed_aucs[seed] = seed_aucs
+         dt = time.time() - t_seed
+         print(f"\n Seed {seed}: avg={avg_s:.4f} | "
+               f"{[f'{seed_aucs[i]:.4f}' for i in range(5)]} ({dt:.0f}s)")
+
+     # ── ENSEMBLE ─────────────────────────────────────────────────────
+     ens_aucs = {}
+     for fi in range(5):
+         probs_stack = torch.stack([all_fold_probs[fi][s] for s in SEEDS])
+         ens_prob = probs_stack.mean(dim=0)
+         ens_aucs[fi] = roc_auc_score(
+             all_fold_targets[fi].numpy(), ens_prob.numpy())
+
+     single_avgs = [np.mean(list(all_seed_aucs[s].values())) for s in SEEDS]
+     single_mean = np.mean(single_avgs)
+     single_std = np.std(single_avgs)
+     ens_mean = np.mean(list(ens_aucs.values()))
+     ens_std = np.std(list(ens_aucs.values()))
+
+     tt = time.time() - t0
+
+     print(f"""
+ {'='*72}
+  FINAL RESULTS — TRIADS on {dataset_name} (ROCAUC)
+ {'='*72}
+
+  Per-seed results:""")
+     for seed in SEEDS:
+         sm = all_seed_aucs[seed]
+         avg_s = np.mean(list(sm.values()))
+         print(f"   Seed {seed:>4}: {avg_s:.4f} | "
+               f"{[f'{sm[i]:.4f}' for i in range(5)]}")
+
+     print(f"""
+  Single-seed avg: {single_mean:.4f} ± {single_std:.4f}
+  5-Seed Ensemble: {ens_mean:.4f} ± {ens_std:.4f}
+  Per-fold ens:    {[f'{ens_aucs[i]:.4f}' for i in range(5)]}
+
+  {'Model':<40} {'ROCAUC':>10}
+  {'─'*53}
+  {baseline_name:<40} {baseline_auc:>10}
+  {'TRIADS (44K, 5-seed ens)':<40} {f'{ens_mean:.4f}':>10}  ← US
+  {'─'*53}
+
+  Total time: {tt/60:.1f} min
+  Saved: {model_dir}/
+ """)
+
+     summary = {
+         'dataset': dataset_name,
+         'task': 'classification',
+         'metric': 'ROCAUC',
+         'samples': len(comps_all),
+         'class_balance': f'{n_pos} positive / {n_neg} negative',
+         'model_config': MODEL_CFG,
+         'params': n_params,
+         'seeds': SEEDS,
+         'fold_seed': MATBENCH_FOLD_SEED,
+         'per_seed': {str(s): {str(k): round(v, 4) for k, v in m.items()}
+                      for s, m in all_seed_aucs.items()},
+         'single_seed_avg': round(single_mean, 4),
+         'single_seed_std': round(single_std, 4),
+         'ensemble_aucs': {str(k): round(v, 4) for k, v in ens_aucs.items()},
+         'ensemble_avg': round(ens_mean, 4),
+         'ensemble_std': round(ens_std, 4),
+         'total_time_min': round(tt/60, 1),
+     }
+     with open(summary_file, 'w') as f:
+         json.dump(summary, f, indent=2)
+     print(f" Saved: {summary_file}")
+
+     shutil.make_archive(model_dir, 'zip', '.', model_dir)
+     print(f" Saved: {model_dir}.zip")
+
+     return ens_mean
+
+
+ # ======================================================================
+ # MAIN — RUN BOTH SEQUENTIALLY
+ # ======================================================================
+
+ if __name__ == '__main__':
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     if device.type == 'cuda':
+         gm = torch.cuda.get_device_properties(0).total_memory / 1e9
+         print(f" GPU: {torch.cuda.get_device_name(0)} ({gm:.1f} GB)")
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.benchmark = True
+
+     print(f"""
+ ╔══════════════════════════════════════════════════════════╗
+ ║  TRIADS Classification Benchmarks                        ║
+ ║  44K model | 5-Seed Ensemble | BCEWithLogitsLoss         ║
+ ║  Fold seed: {MATBENCH_FOLD_SEED} (matbench v0.1 standard)            ║
+ ╠══════════════════════════════════════════════════════════╣
+ ║  1. matbench_expt_is_metal (4,921 samples)               ║
+ ║  2. matbench_glass        (5,680 samples)                ║
+ ╚══════════════════════════════════════════════════════════╝
+ """)
+
+     t_total = time.time()
+     results = {}
+
+     # ── BENCHMARK 1: expt_is_metal ───────────────────────────────────
+     print("\n" + "█"*72)
+     print(" BENCHMARK 1/2: matbench_expt_is_metal")
+     print("█"*72)
+
+     auc1 = run_classification_benchmark(
+         dataset_name="matbench_expt_is_metal",
+         target_col="is_metal",
+         featurizer_cls=MetallicityFeaturizer,
+         model_dir="is_metal_models",
+         summary_file="is_metal_summary.json",
+         baseline_name="AMMExpress v2020",
+         baseline_auc="0.9209",
+         device=device,
+     )
+     results['is_metal'] = auc1
+
+     # ── BENCHMARK 2: glass ───────────────────────────────────────────
+     print("\n" + "█"*72)
+     print(" BENCHMARK 2/2: matbench_glass")
+     print("█"*72)
+
+     auc2 = run_classification_benchmark(
+         dataset_name="matbench_glass",
+         target_col="gfa",
+         featurizer_cls=GlassFeaturizer,
+         model_dir="glass_models",
+         summary_file="glass_summary.json",
+         baseline_name="MODNet v0.1.12",
+         baseline_auc="0.9603",
+         device=device,
+     )
+     results['glass'] = auc2
+
+     # ── COMBINED SUMMARY ─────────────────────────────────────────────
+     tt = time.time() - t_total
+     print(f"""
+
+ {'='*72}
+  COMBINED RESULTS — ALL CLASSIFICATION BENCHMARKS
+ {'='*72}
+
+  {'Dataset':<30} {'Baseline':>10} {'TRIADS':>10}
+  {'─'*53}
+  {'matbench_expt_is_metal':<30} {'0.9209':>10} {f'{auc1:.4f}':>10}
+  {'matbench_glass':<30} {'0.9603':>10} {f'{auc2:.4f}':>10}
+  {'─'*53}
+
+  Grand total time: {tt/60:.1f} min ({tt/3600:.1f} hrs)
+
+  ALL TRIADS BENCHMARKS:
+  ─────────────────────
+  steels:   91.20 MPa (#1-2)
+  expt_gap: 0.3068 eV (#2)
+  jdft2d:   35.89 meV/atom (#3)
+  is_metal: {auc1:.4f} ROCAUC
+  glass:    {auc2:.4f} ROCAUC
+ """)
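The runner above writes one checkpoint per seed-fold pair via `torch.save`. A hypothetical helper (not in the commit; names are illustrative) that enumerates the resulting file layout, using the same f-string pattern as the `torch.save` call:

```python
# Hypothetical helper: enumerate the checkpoint paths written by
# run_classification_benchmark, mirroring its pattern
# '{model_dir}/{dataset_name}_100K_s{seed}_f{fold}.pt'.
def checkpoint_paths(model_dir, dataset, seeds=(42, 123, 456, 789, 1024), n_folds=5):
    return [f"{model_dir}/{dataset}_100K_s{s}_f{f}.pt"
            for s in seeds for f in range(1, n_folds + 1)]

paths = checkpoint_paths("is_metal_models", "matbench_expt_is_metal")
assert len(paths) == 25  # 5 seeds x 5 folds = 25 checkpoints per benchmark
assert paths[0] == "is_metal_models/matbench_expt_is_metal_100K_s42_f1.pt"
```

With the full 5-seed ensemble this yields 25 checkpoints per classification benchmark; inference averages the sigmoid probabilities across seeds within each fold, as the ensemble block above does.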
model_code/expt_gap_model.py ADDED
@@ -0,0 +1,579 @@
1
+ """
2
+ +=============================================================+
3
+ | TRIADS V3 on matbench_expt_gap |
4
+ | 2x T4 GPU Parallel Training (auto-fallback to 1 GPU) |
5
+ | 4 Models: Steps(16,20) x Dropout(0.15,0.20) |
6
+ | Proven arch: d_attn=64, d_hidden=96 | batch_size=64 |
7
+ | FastTensorDataLoader | Clean output |
8
+ +=============================================================+
9
+ """
10
+
11
+ import os, copy, json, time, logging, warnings, urllib.request
12
+ warnings.filterwarnings('ignore')
13
+
14
+ import numpy as np
15
+ import pandas as pd
16
+
17
+ import matplotlib
18
+ matplotlib.use('Agg')
19
+ import matplotlib.pyplot as plt
20
+
21
+ from tqdm import tqdm
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+ import torch.nn.functional as F
26
+ from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
27
+
28
+ from sklearn.model_selection import KFold
29
+ from sklearn.preprocessing import StandardScaler
30
+ from pymatgen.core import Composition
31
+ from matminer.featurizers.composition import ElementProperty
32
+ from gensim.models import Word2Vec
33
+
34
+ logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
35
+ log = logging.getLogger("TRIADS-V3")
36
+
37
+ SEEDS = [42]
38
+ BATCH_SIZE = 64
39
+
40
+ BASELINES = {
41
+ 'Darwin': 0.2865,
42
+ 'Ax/SAASBO CrabNet': 0.3310,
43
+ 'MODNet v0.1.12': 0.3327,
44
+ 'AMMExpress v2020': 0.4161,
45
+ 'CrabNet': 0.4427,
46
+ 'RF-SCM/Magpie': 0.5205,
47
+ 'Dummy': 1.0280,
48
+ }
49
+ V1_BEST = {'EG-A (V1)': 0.3510, 'EG-B (V1)': 0.3616}
50
+
51
+ # Use all available Kaggle CPU cores (4 vCPUs) for PyTorch operations
52
+ torch.set_num_threads(4) # 4 vCPUs on Kaggle
53
+ torch.set_num_interop_threads(2) # 2 physical cores
54
+
55
+
56
+ # ======================================================================
57
+ # FAST TENSOR DATALOADER
58
+ # ======================================================================
59
+
60
+ class FastTensorDataLoader:
61
+ """Zero-CPU DataLoader. Entire dataset in GPU VRAM."""
62
+ def __init__(self, *tensors, batch_size=64, shuffle=False):
63
+ assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
64
+ self.tensors = tensors
65
+ self.dataset_len = tensors[0].shape[0]
66
+ self.batch_size = batch_size
67
+ self.shuffle = shuffle
68
+ self.n_batches = (self.dataset_len + batch_size - 1) // batch_size
69
+
70
+ def __iter__(self):
71
+ if self.shuffle:
72
+ idx = torch.randperm(self.dataset_len, device=self.tensors[0].device)
73
+ self.tensors = tuple(t[idx] for t in self.tensors)
74
+ self.i = 0
75
+ return self
76
+
77
+ def __next__(self):
78
+ if self.i >= self.dataset_len:
79
+ raise StopIteration
80
+ batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
81
+ self.i += self.batch_size
82
+ return batch
83
+
84
+ def __len__(self):
85
+ return self.n_batches
86
+
87
+
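The loader above computes its batch count by ceil division and lets tensor slicing produce a short final batch. A torch-free sketch of that same arithmetic (the helper name `batch_sizes` is illustrative, not from the file):

```python
def batch_sizes(n_rows, batch_size):
    # Mirror FastTensorDataLoader: n_batches = ceil(n_rows / batch_size).
    n_batches = (n_rows + batch_size - 1) // batch_size
    # Slicing t[i:i + batch_size] yields a short last batch when
    # n_rows is not a multiple of batch_size.
    sizes = [min(batch_size, n_rows - b * batch_size) for b in range(n_batches)]
    return n_batches, sizes

n_batches, sizes = batch_sizes(150, 64)
# 150 rows at batch_size=64 -> 3 batches of sizes 64, 64, 22
```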
88
+ # ======================================================================
89
+ # FEATURIZER
90
+ # ======================================================================
91
+
92
+ class ExpandedFeaturizer:
93
+ GCS = "https://storage.googleapis.com/mat2vec/"
94
+ FILES = ["pretrained_embeddings",
95
+ "pretrained_embeddings.wv.vectors.npy",
96
+ "pretrained_embeddings.trainables.syn1neg.npy"]
97
+
98
+ def __init__(self, cache="mat2vec_cache"):
99
+ from matminer.featurizers.composition import (
100
+ ElementFraction, Stoichiometry, ValenceOrbital,
101
+ IonProperty, BandCenter
102
+ )
103
+ from matminer.featurizers.base import MultipleFeaturizer
104
+ self.ep_magpie = ElementProperty.from_preset("magpie")
105
+ self.n_mg = len(self.ep_magpie.feature_labels())
106
+ self.extra_feats = MultipleFeaturizer([
107
+ ElementFraction(), Stoichiometry(), ValenceOrbital(),
108
+ IonProperty(), BandCenter(),
109
+ ])
110
+ self.n_extra = None
111
+ self.scaler = None
112
+ os.makedirs(cache, exist_ok=True)
113
+ for f in self.FILES:
114
+ p = os.path.join(cache, f)
115
+ if not os.path.exists(p):
116
+ log.info(f" Downloading {f}...")
117
+ urllib.request.urlretrieve(self.GCS + f, p)
118
+ self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
119
+ self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}
120
+
121
+ def _pool(self, c):
122
+ v, t = np.zeros(200, np.float32), 0.0
123
+ for s, f in c.get_el_amt_dict().items():
124
+ if s in self.emb: v += f * self.emb[s]; t += f
125
+ return v / max(t, 1e-8)
126
+
127
+ def featurize_all(self, comps):
128
+ out = []
129
+ for c in tqdm(comps, desc=" Featurizing", leave=False):
130
+ try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
131
+ except: mg = np.zeros(self.n_mg, np.float32)
132
+ try: ex = np.array(self.extra_feats.featurize(c), np.float32)
133
+ except: ex = np.zeros(self.n_extra or 200, np.float32)
134
+ if self.n_extra is None:
135
+ self.n_extra = len(ex)
136
+ log.info(f"Features: {self.n_mg} Magpie + {self.n_extra} Extra + 200 Mat2Vec")
137
+ out.append(np.concatenate([
138
+ np.nan_to_num(mg, nan=0.0),
139
+ np.nan_to_num(ex, nan=0.0),
140
+ self._pool(c)
141
+ ]))
142
+ return np.array(out)
143
+
144
+ def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
145
+ def transform(self, X):
146
+ if not self.scaler: return X
147
+ return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)
148
+
149
+
150
+ # ======================================================================
151
+ # MODEL — DeepHybridTRM (V13A proven architecture)
152
+ # ======================================================================
153
+
154
+ class DeepHybridTRM(nn.Module):
155
+ def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
156
+ d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
157
+ dropout=0.2, max_steps=20, **kw):
158
+ super().__init__()
159
+ self.max_steps, self.D = max_steps, d_hidden
160
+ self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra
161
+
162
+ self.tok_proj = nn.Sequential(
163
+ nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
164
+ self.m2v_proj = nn.Sequential(
165
+ nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
166
+
167
+ self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
168
+ self.sa1_n = nn.LayerNorm(d_attn)
169
+ self.sa1_ff = nn.Sequential(
170
+ nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
171
+ nn.Linear(d_attn*2, d_attn))
172
+ self.sa1_fn = nn.LayerNorm(d_attn)
173
+
174
+ self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
175
+ self.sa2_n = nn.LayerNorm(d_attn)
176
+ self.sa2_ff = nn.Sequential(
177
+ nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
178
+ nn.Linear(d_attn*2, d_attn))
179
+ self.sa2_fn = nn.LayerNorm(d_attn)
180
+
181
+ self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
182
+ self.ca_n = nn.LayerNorm(d_attn)
183
+
184
+ pool_in = d_attn + (n_extra if n_extra > 0 else 0)
185
+ self.pool = nn.Sequential(
186
+ nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
187
+
188
+ self.z_up = nn.Sequential(
189
+ nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
190
+ nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
191
+ self.y_up = nn.Sequential(
192
+ nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
193
+ nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
194
+ self.head = nn.Linear(d_hidden, 1)
195
+ self._init()
196
+
197
+ def _init(self):
198
+ for m in self.modules():
199
+ if isinstance(m, nn.Linear):
200
+ nn.init.xavier_uniform_(m.weight)
201
+ if m.bias is not None: nn.init.zeros_(m.bias)
202
+
203
+ def _attention(self, x):
204
+ B = x.size(0)
205
+ mg_dim = self.n_props * self.stat_dim
206
+ if self.n_extra > 0:
207
+ extra = x[:, mg_dim:mg_dim + self.n_extra]
208
+ m2v = x[:, mg_dim + self.n_extra:]
209
+ else:
210
+ extra, m2v = None, x[:, mg_dim:]
211
+
212
+ tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
213
+ ctx = self.m2v_proj(m2v).unsqueeze(1)
214
+
215
+ tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
216
+ tok = self.sa1_fn(tok + self.sa1_ff(tok))
217
+ tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
218
+ tok = self.sa2_fn(tok + self.sa2_ff(tok))
219
+ tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
220
+
221
+ pooled = tok.mean(dim=1)
222
+ if extra is not None:
223
+ pooled = torch.cat([pooled, extra], dim=-1)
224
+ return self.pool(pooled)
225
+
226
+ def forward(self, x, deep_supervision=False):
227
+ B = x.size(0)
228
+ xp = self._attention(x)
229
+ z = torch.zeros(B, self.D, device=x.device)
230
+ y = torch.zeros(B, self.D, device=x.device)
231
+ step_preds = []
232
+ for s in range(self.max_steps):
233
+ z = z + self.z_up(torch.cat([xp, y, z], -1))
234
+ y = y + self.y_up(torch.cat([y, z], -1))
235
+ step_preds.append(self.head(y).squeeze(1))
236
+ return step_preds if deep_supervision else step_preds[-1]
237
+
238
+ def count_parameters(self):
239
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
240
+
241
+
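The core of `forward()` is a two-stream recursion: `z` integrates `(xp, y, z)`, `y` integrates `(y, z)`, and the head reads `y` after every step, so the loop emits one prediction per step for deep supervision. A scalar toy version (the `0.5 * (...)` updates are hypothetical stand-ins for the real MLPs):

```python
def recurse(x, max_steps=4):
    # z integrates (x, y, z); y integrates (y, z); the head reads y each step.
    z, y, preds = 0.0, 0.0, []
    for _ in range(max_steps):
        z = z + 0.5 * (x + y + z)   # stand-in for z_up(cat(xp, y, z))
        y = y + 0.5 * (y + z)       # stand-in for y_up(cat(y, z))
        preds.append(y)             # stand-in for head(y)
    return preds

preds = recurse(1.0, max_steps=4)
# len(preds) == max_steps; the last entry is what forward() returns
# when deep_supervision=False
```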
242
+ # ======================================================================
243
+ # LOSS + UTILS
244
+ # ======================================================================
245
+
246
+ def deep_supervision_loss(step_preds, targets):
247
+ n = len(step_preds)
248
+ weights = [(i+1) for i in range(n)]
249
+ tw = sum(weights)
250
+ return sum((w/tw) * F.l1_loss(p, targets) for p, w in zip(step_preds, weights))
251
+
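`deep_supervision_loss` weights step i (1-based) by i / (1 + 2 + ... + n), so later recursion steps dominate. A small check of that weighting, with plain floats standing in for the per-step L1 terms:

```python
def step_weights(n):
    # Linearly increasing weights, normalized by the sum 1 + 2 + ... + n.
    total = n * (n + 1) // 2
    return [(i + 1) / total for i in range(n)]

w = step_weights(4)
# [0.1, 0.2, 0.3, 0.4]: weights sum to 1 and grow toward the final step
loss = sum(wi * li for wi, li in zip(w, [1.0, 0.8, 0.6, 0.5]))
```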
252
+
253
+ def strat_split(targets, val_size=0.15, seed=42):
254
+ bins = np.percentile(targets, [25, 50, 75])
255
+ lbl = np.digitize(targets, bins)
256
+ tr, vl = [], []
257
+ rng = np.random.RandomState(seed)
258
+ for b in range(4):
259
+ m = np.where(lbl == b)[0]
260
+ if len(m) == 0: continue
261
+ n = max(1, int(len(m) * val_size))
262
+ c = rng.choice(m, n, replace=False)
263
+ vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
264
+ return np.array(tr), np.array(vl)
265
+
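`strat_split` stratifies on target quartiles and samples `val_size` of each bin for validation. A numpy-free mirror of that logic (the helper `strat_split_sketch` is illustrative only) makes the disjointness and per-bin sampling easy to verify:

```python
import random

def strat_split_sketch(targets, val_size=0.15, seed=42):
    # Quartile cut points, then sample val_size of each bin for validation.
    srt = sorted(targets)
    cuts = [srt[int(q * (len(srt) - 1))] for q in (0.25, 0.5, 0.75)]
    bins = {}
    for i, t in enumerate(targets):
        bins.setdefault(sum(t > c for c in cuts), []).append(i)
    rng = random.Random(seed)
    tr, vl = [], []
    for idxs in bins.values():
        chosen = set(rng.sample(idxs, max(1, int(len(idxs) * val_size))))
        vl += [i for i in idxs if i in chosen]
        tr += [i for i in idxs if i not in chosen]
    return tr, vl

tr, vl = strat_split_sketch(list(range(100)))
# train/val are disjoint and together cover every index
```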
266
+
267
+ def predict(model, dl):
268
+ model.eval(); preds = []
269
+ with torch.no_grad():
270
+ for bx, _ in dl:
271
+ preds.append(model(bx).cpu())
272
+ return torch.cat(preds)
273
+
274
+
275
+ # ======================================================================
276
+ # TRAINING — clean, simple, V1-style
277
+ # ======================================================================
278
+
279
+ def train_fold(model, tr_dl, vl_dl, device,
280
+ epochs=300, swa_start=200, fold=1, name="", gpu_tag=""):
281
+ opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
282
+ sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
283
+ swa_m = AveragedModel(model)
284
+ swa_s = SWALR(opt, swa_lr=5e-4)
285
+ swa_on = False
286
+ best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
287
+ hist = {'train': [], 'val': []}
288
+ use_amp = (device.type == 'cuda')
289
+ scaler = torch.amp.GradScaler('cuda', enabled=use_amp)
290
+
291
+ pbar = tqdm(range(epochs), desc=f" {gpu_tag}[{name}] F{fold}/5",
292
+ leave=False, ncols=120)
293
+ for ep in pbar:
294
+ model.train(); tl = 0.0
295
+ for bx, by in tr_dl:
296
+ with torch.amp.autocast('cuda', enabled=use_amp):
297
+ sp = model(bx, deep_supervision=True)
298
+ loss = deep_supervision_loss(sp, by)
299
+ opt.zero_grad(set_to_none=True)
300
+ scaler.scale(loss).backward()
301
+ scaler.unscale_(opt)
302
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
303
+ scaler.step(opt)
304
+ scaler.update()
305
+ tl += F.l1_loss(sp[-1], by).item() * len(by)
306
+ tl /= tr_dl.dataset_len
307
+
308
+ model.eval(); vl = 0.0
309
+ with torch.no_grad():
310
+ with torch.amp.autocast('cuda', enabled=use_amp):
311
+ for bx, by in vl_dl:
312
+ vl += F.l1_loss(model(bx), by).item() * len(by)
313
+ vl /= vl_dl.dataset_len
314
+ hist['train'].append(tl); hist['val'].append(vl)
315
+
316
+ if ep < swa_start:
317
+ sch.step()
318
+ if vl < best_v:
319
+ best_v = vl
320
+ best_w = copy.deepcopy(model.state_dict())
321
+ else:
322
+ if not swa_on: swa_on = True
323
+ swa_m.update_parameters(model); swa_s.step()
324
+
325
+ pbar.set_postfix(Best=f'{best_v:.4f}', Ph='SWA' if swa_on else 'COS',
326
+ Tr=f'{tl:.4f}', Val=f'{vl:.4f}')
327
+
328
+ if swa_on:
329
+ update_bn(tr_dl, swa_m, device=device)
330
+ model.load_state_dict(swa_m.module.state_dict())
331
+ else:
332
+ model.load_state_dict(best_w)
333
+ return best_v, model, hist
334
+
335
+
336
+ # ======================================================================
337
+ # GPU WORKER — trains assigned models on one GPU
338
+ # ======================================================================
339
+
340
+ def gpu_worker(gpu_id, config_list, X_all, targets_all, folds, n_extra,
341
+ result_file):
342
+ device = torch.device(f'cuda:{gpu_id}')
343
+ torch.cuda.set_device(gpu_id)
344
+ tag = f"[GPU{gpu_id}] "
345
+
346
+ print(f"\n {tag}Started on {torch.cuda.get_device_name(gpu_id)}")
347
+ print(f" {tag}Models: {[c[0] for c in config_list]}")
348
+
349
+ feat = ExpandedFeaturizer()
350
+ results = {}
351
+
352
+ for ci, (cname, model_kw) in enumerate(config_list):
353
+ print(f"\n {tag}{'='*50}")
354
+ print(f" {tag}[{ci+1}/{len(config_list)}] {cname}")
355
+ print(f" {tag}{'='*50}")
356
+
357
+ seed = SEEDS[0]
358
+ fold_maes = []
359
+
360
+ for fi, (tv_i, te_i) in enumerate(folds):
361
+ print(f"\n {tag}-- [{cname}] Fold {fi+1}/5 " + "-"*20)
362
+
363
+ tri, vli = strat_split(targets_all[tv_i], 0.15, seed + fi)
364
+ feat.fit_scaler(X_all[tv_i][tri])
365
+
366
+ tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
367
+ tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
368
+ vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
369
+ vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
370
+ te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
371
+ te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)
372
+
373
+ tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
374
+ vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
375
+ te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)
376
+
377
+ torch.manual_seed(seed + fi)
378
+ np.random.seed(seed + fi)
379
+ torch.cuda.manual_seed(seed + fi)
380
+
381
+ model = DeepHybridTRM(**model_kw).to(device)
382
+ if fi == 0:
383
+ print(f" {tag}Params: {model.count_parameters():,}")
384
+
385
+ bv, model, hist = train_fold(
386
+ model, tr_dl, vl_dl, device,
387
+ epochs=300, swa_start=200, fold=fi+1, name=cname, gpu_tag=tag)
388
+
389
+ pred = predict(model, te_dl)
390
+ mae = F.l1_loss(pred, te_y.cpu()).item()
391
+ print(f" {tag}Fold {fi+1} TEST: {mae:.4f} eV (val best: {bv:.4f})")
392
+
393
+ fold_maes.append(mae)
394
+ os.makedirs('expt_gap_models_v3', exist_ok=True)
395
+ torch.save({
396
+ 'model_state': model.state_dict(),
397
+ 'test_mae': mae, 'config': cname, 'seed': seed,
398
+ 'fold': fi+1, 'n_extra': n_extra,
399
+ }, f'expt_gap_models_v3/{cname}_s{seed}_f{fi+1}.pt')
400
+
401
+ del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
402
+ torch.cuda.empty_cache()
403
+
404
+ avg = float(np.mean(fold_maes))
405
+ std = float(np.std(fold_maes))
406
+ results[cname] = {'avg': avg, 'std': std, 'folds': fold_maes}
407
+
408
+ print(f"\n {tag}=== {cname} ===")
409
+ print(f" {tag} 5-Fold Avg MAE: {avg:.4f} +/- {std:.4f} eV")
410
+ print(f" {tag} Per-fold: {[f'{m:.4f}' for m in fold_maes]}")
411
+
412
+ with open(result_file, 'w') as f:
413
+ json.dump(results, f)
414
+ print(f"\n {tag}DONE. Saved to {result_file}")
415
+
416
+
417
+ # ======================================================================
418
+ # MAIN
419
+ # ======================================================================
420
+
421
+ def run_benchmark():
422
+ t0 = time.time()
423
+
424
+ print(f"""
425
+ +==========================================================+
426
+ | TRIADS V3 -- P100 | FastTensorDataLoader |
427
+ | 4 Models: Steps(16,20) x Dropout(0.15,0.20) |
428
+ | d_attn=64, d_hidden=96 (proven V1 arch) |
429
+ | batch_size={BATCH_SIZE} | All CPU cores active |
430
+ +==========================================================+
431
+ """)
432
+
433
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
434
+ if device.type == 'cuda':
435
+ try: gm = torch.cuda.get_device_properties(0).total_memory / 1e9
436
+ except: gm = 0
437
+ print(f" GPU: {torch.cuda.get_device_name(0)} ({gm:.1f} GB)")
438
+ print(f" CPU threads: {torch.get_num_threads()} | Interop: {torch.get_num_interop_threads()}")
439
+ torch.backends.cuda.matmul.allow_tf32 = True
440
+ torch.backends.cudnn.benchmark = True
441
+
442
+ # ---- LOAD + FEATURIZE ----
443
+ print("\n Loading matbench_expt_gap...")
444
+ from matminer.datasets import load_dataset
445
+ df = load_dataset("matbench_expt_gap")
446
+ targets_all = np.array(df['gap expt'].tolist(), np.float32)
447
+ comps_all = [Composition(c) for c in df['composition'].tolist()]
448
+ print(f" Dataset: {len(comps_all)} samples")
449
+
450
+ feat = ExpandedFeaturizer()
451
+ X_all = feat.featurize_all(comps_all)
452
+ n_extra = feat.n_extra
453
+ print(f" Features: {X_all.shape}")
454
+
455
+ kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
456
+ folds = list(kfold.split(comps_all))
457
+ for fi, (tv, te) in enumerate(folds):
458
+ assert len(set(tv) & set(te)) == 0
459
+ print(" 5 folds verified: zero leakage")
460
+
461
+ # ---- CONFIGS ----
462
+ base = dict(n_props=22, stat_dim=6, n_extra=n_extra, mat2vec_dim=200,
463
+ d_attn=64, nhead=4, d_hidden=96, ff_dim=150)
464
+
465
+ all_configs = [
466
+ ('V3-S16-D15', {**base, 'max_steps': 16, 'dropout': 0.15}),
467
+ ('V3-S16-D20', {**base, 'max_steps': 16, 'dropout': 0.20}),
468
+ ('V3-S20-D15', {**base, 'max_steps': 20, 'dropout': 0.15}),
469
+ ('V3-S20-D20', {**base, 'max_steps': 20, 'dropout': 0.20}),
470
+ ]
471
+
472
+ print(f"\n {'Config':<16} {'Params':>10} {'Steps':>6} {'Drop':>6}")
473
+ for cn, kw in all_configs:
474
+ m = DeepHybridTRM(**kw); print(f" {cn:<16} {m.count_parameters():>10,} {kw['max_steps']:>6} {kw['dropout']:>6.2f}"); del m
475
+
476
+ # ---- TRAIN ----
477
+ all_results = {}
478
+
479
+ for ci, (cname, model_kw) in enumerate(all_configs):
480
+ print(f"\n {'='*60}")
481
+ print(f" [{ci+1}/4] {cname}")
482
+ print(f" {'='*60}")
483
+
484
+ seed = SEEDS[0]
485
+ fold_maes = []
486
+
487
+ for fi, (tv_i, te_i) in enumerate(folds):
488
+ print(f"\n -- [{cname}] Fold {fi+1}/5 " + "-"*30)
489
+ tri, vli = strat_split(targets_all[tv_i], 0.15, seed + fi)
490
+ feat.fit_scaler(X_all[tv_i][tri])
491
+
492
+ tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
493
+ tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
494
+ vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
495
+ vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
496
+ te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
497
+ te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)
498
+
499
+ tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
500
+ vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
501
+ te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)
502
+
503
+ torch.manual_seed(seed + fi); np.random.seed(seed + fi)
504
+ if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)
505
+
506
+ model = DeepHybridTRM(**model_kw).to(device)
507
+ if fi == 0: print(f" Params: {model.count_parameters():,}")
508
+
509
+ bv, model, hist = train_fold(model, tr_dl, vl_dl, device,
510
+ epochs=300, swa_start=200, fold=fi+1, name=cname)
511
+
512
+ pred = predict(model, te_dl)
513
+ mae = F.l1_loss(pred, te_y.cpu()).item()
514
+ print(f" Fold {fi+1} TEST: {mae:.4f} eV (val: {bv:.4f})")
515
+ fold_maes.append(mae)
516
+
517
+ os.makedirs('expt_gap_models_v3', exist_ok=True)
518
+ torch.save({
519
+ 'model_state': model.state_dict(),
520
+ 'test_mae': mae, 'config': cname, 'seed': seed,
521
+ 'fold': fi+1, 'n_extra': n_extra,
522
+ }, f'expt_gap_models_v3/{cname}_s{seed}_f{fi+1}.pt')
523
+
524
+ del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
525
+ if device.type == 'cuda': torch.cuda.empty_cache()
526
+
527
+ avg = float(np.mean(fold_maes))
528
+ std = float(np.std(fold_maes))
529
+ all_results[cname] = {'avg': avg, 'std': std, 'folds': fold_maes}
530
+ print(f"\n === {cname}: {avg:.4f} +/- {std:.4f} eV ===")
531
+
532
+ # ======== FINAL RESULTS ========
533
+ tt = time.time() - t0
534
+ print(f"\n{'='*72}")
535
+ print(f" FINAL LEADERBOARD -- TRIADS V3 (5-Fold Avg MAE, eV)")
536
+ print(f"{'='*72}")
537
+ print(f" {'Model':<20} {'MAE':>10} {'Std':>8} Notes")
538
+ print(f" {'-'*60}")
539
+
540
+ for n, r in sorted(all_results.items(), key=lambda x: x[1]['avg']):
541
+ tag = (" <-- DARWIN BEATEN!" if r['avg'] < 0.2865 else
542
+ " <-- Top 3!" if r['avg'] < 0.3327 else
543
+ " <-- Beats V1!" if r['avg'] < 0.3510 else
544
+ " <-- Beats AMMExp" if r['avg'] < 0.4161 else "")
545
+ print(f" {n:<20} {r['avg']:>10.4f} {r['std']:>8.4f}{tag}")
546
+
547
+ print(f" {'-'*60}")
548
+ for vn, vm in sorted(V1_BEST.items(), key=lambda x: x[1]):
549
+ print(f" {vn:<20} {vm:>10.4f} (V1)")
550
+ for bn, bv in sorted(BASELINES.items(), key=lambda x: x[1]):
551
+ print(f" {bn:<20} {bv:>10.4f}")
552
+
553
+ # Per-fold
554
+ names = sorted(all_results.keys())
555
+ print(f"\n PER-FOLD:")
556
+ hdr = f" {'Fold':<6}"; [hdr := hdr + f" {cn:>14}" for cn in names]
557
+ print(hdr)
558
+ for fi in range(5):
559
+ row = f" F{fi+1:<5}"; [row := row + f" {all_results[cn]['folds'][fi]:>14.4f}" for cn in names]
560
+ print(row)
561
+
562
+ print(f"\n HP GRID: {'D=0.15':>10} {'D=0.20':>10}")
563
+ for s in [16, 20]:
564
+ d15 = all_results.get(f'V3-S{s}-D15', {}).get('avg', 0)
565
+ d20 = all_results.get(f'V3-S{s}-D20', {}).get('avg', 0)
566
+ print(f" S={s:>2} {d15:>10.4f} {d20:>10.4f}")
567
+
568
+ print(f"\n Total: {tt/60:.1f} min")
569
+
570
+ s = {'version': 'EG-V3', 'batch_size': BATCH_SIZE,
571
+ 'total_min': round(tt/60, 1), 'models': all_results,
572
+ 'baselines': BASELINES, 'v1': V1_BEST}
573
+ with open('expt_gap_summary_v3.json', 'w') as f:
574
+ json.dump(s, f, indent=2)
575
+ print(" Saved: expt_gap_summary_v3.json")
576
+
577
+
578
+ if __name__ == '__main__':
579
+ run_benchmark()
model_code/jdft2d_model.py ADDED
@@ -0,0 +1,589 @@
1
+ """
2
+ +=============================================================+
3
+ | TRIADS V4 on matbench_jdft2d — 5-Seed Ensemble |
4
+ | Exfoliation Energy (meV/atom) — 636 samples |
5
+ | |
6
+ | Structural + Composition features (~361d) |
7
+ | 75K model (d_attn=32, d_hidden=64) | dropout=0.20 |
8
+ | Seeds: [42, 123, 456, 789, 1024] |
9
+ | Target: Kaggle P100 | ~30 min |
10
+ +=============================================================+
11
+ """
12
+
13
+ import os, copy, json, time, logging, warnings, urllib.request, shutil
14
+ warnings.filterwarnings('ignore')
15
+
16
+ import numpy as np
17
+ import pandas as pd
18
+ from tqdm import tqdm
19
+
20
+ import torch
21
+ import torch.nn as nn
22
+ import torch.nn.functional as F
23
+ from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
24
+
25
+ from sklearn.model_selection import KFold
26
+ from sklearn.preprocessing import StandardScaler
27
+ from pymatgen.core import Composition
28
+ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
29
+ from matminer.featurizers.composition import ElementProperty
30
+ from gensim.models import Word2Vec
31
+
32
+ logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
33
+ log = logging.getLogger("TRIADS-jdft2d")
34
+
35
+ BATCH_SIZE = 64
36
+ SEEDS = [42, 123, 456, 789, 1024]
37
+
38
+ # 75K config — best for 636 samples
39
+ MODEL_CFG = dict(
40
+ d_attn=32, nhead=4, d_hidden=64, ff_dim=96,
41
+ dropout=0.20, max_steps=16,
42
+ )
43
+
44
+ V1_BEST = {'V1 (100K, comp-only)': 45.8045}
45
+ V2_BEST = {'V2 (44K, comp-only)': 46.5889}
46
+ V3_BEST = {'V3 (75K, +struct, single)': 37.0033}
47
+
48
+
49
+ # ======================================================================
50
+ # FAST TENSOR DATALOADER
51
+ # ======================================================================
52
+
53
+ class FastTensorDataLoader:
54
+ def __init__(self, *tensors, batch_size=64, shuffle=False):
55
+ assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
56
+ self.tensors = tensors
57
+ self.dataset_len = tensors[0].shape[0]
58
+ self.batch_size = batch_size
59
+ self.shuffle = shuffle
60
+ self.n_batches = (self.dataset_len + batch_size - 1) // batch_size
61
+
62
+ def __iter__(self):
63
+ if self.shuffle:
64
+ idx = torch.randperm(self.dataset_len, device=self.tensors[0].device)
65
+ self.tensors = tuple(t[idx] for t in self.tensors)
66
+ self.i = 0
67
+ return self
68
+
69
+ def __next__(self):
70
+ if self.i >= self.dataset_len:
71
+ raise StopIteration
72
+ batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
73
+ self.i += self.batch_size
74
+ return batch
75
+
76
+ def __len__(self):
77
+ return self.n_batches
78
+
79
+
80
+ # ======================================================================
81
+ # FEATURIZER — Composition + Structural (~361d)
82
+ # ======================================================================
83
+
84
+ def _extract_structural_features(structure):
85
+ feats = []
86
+ try:
87
+ lat = structure.lattice
88
+ feats.extend([lat.a, lat.b, lat.c, lat.alpha, lat.beta, lat.gamma])
89
+ feats.append(structure.volume / max(len(structure), 1))
90
+ feats.append(structure.density)
91
+ feats.append(float(len(structure)))
92
+ try:
93
+ sga = SpacegroupAnalyzer(structure, symprec=0.1)
94
+ feats.append(float(sga.get_space_group_number()))
95
+ except:
96
+ feats.append(0.0)
97
+ try:
98
+ total_vol = sum(
99
+ (4/3) * np.pi * site.specie.atomic_radius**3
100
+ for site in structure if hasattr(site.specie, 'atomic_radius')
101
+ and site.specie.atomic_radius is not None
102
+ )
103
+ feats.append(total_vol / structure.volume if structure.volume > 0 else 0.0)
104
+ except:
105
+ feats.append(0.0)
106
+ except:
107
+ feats = [0.0] * 11
108
+ return np.array(feats, dtype=np.float32)
109
+
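The last structural feature above is an atomic packing fraction: summed atomic sphere volumes (4/3·π·r³) over the cell volume. A standalone check of that formula (radii and cell volume below are made-up numbers, not dataset values):

```python
import math

def packing_fraction(radii, cell_volume):
    # Sum of atomic sphere volumes (4/3 * pi * r^3) over the cell volume,
    # with the same zero-volume guard as the featurizer above.
    spheres = sum((4.0 / 3.0) * math.pi * r ** 3 for r in radii)
    return spheres / cell_volume if cell_volume > 0 else 0.0

pf = packing_fraction([1.0, 1.0], 20.0)
# two unit-radius atoms in a 20 A^3 cell -> roughly 0.419
```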
110
+
111
+ class ExfoliationFeaturizer:
112
+ GCS = "https://storage.googleapis.com/mat2vec/"
113
+ FILES = ["pretrained_embeddings",
114
+ "pretrained_embeddings.wv.vectors.npy",
115
+ "pretrained_embeddings.trainables.syn1neg.npy"]
116
+
117
+ def __init__(self, cache="mat2vec_cache"):
118
+ from matminer.featurizers.composition import (
119
+ Stoichiometry, ValenceOrbital, IonProperty
120
+ )
121
+ from matminer.featurizers.composition.element import TMetalFraction
122
+
123
+ self.ep_magpie = ElementProperty.from_preset("magpie")
124
+ self.n_mg = len(self.ep_magpie.feature_labels())
125
+
126
+ self.extra_featurizers = [
127
+ ("Stoichiometry", Stoichiometry()),
128
+ ("ValenceOrbital", ValenceOrbital()),
129
+ ("IonProperty", IonProperty()),
130
+ ("TMetalFraction", TMetalFraction()),
131
+ ]
132
+
133
+ self._extra_sizes = {}
134
+ for name, ftzr in self.extra_featurizers:
135
+ try: self._extra_sizes[name] = len(ftzr.feature_labels())
136
+ except: self._extra_sizes[name] = None
137
+
138
+ self.n_extra = None
139
+ self.scaler = None
140
+
141
+ os.makedirs(cache, exist_ok=True)
142
+ for f in self.FILES:
143
+ p = os.path.join(cache, f)
144
+ if not os.path.exists(p):
145
+ log.info(f" Downloading {f}...")
146
+ urllib.request.urlretrieve(self.GCS + f, p)
147
+ self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
148
+ self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}
149
+
150
+ def _pool(self, c):
151
+ v, t = np.zeros(200, np.float32), 0.0
152
+ for s, f in c.get_el_amt_dict().items():
153
+ if s in self.emb: v += f * self.emb[s]; t += f
154
+ return v / max(t, 1e-8)
155
+
156
+ def _featurize_extra(self, comp, structure=None):
157
+ parts = []
158
+ for name, ftzr in self.extra_featurizers:
159
+ try:
160
+ vals = np.array(ftzr.featurize(comp), np.float32)
161
+ parts.append(np.nan_to_num(vals, nan=0.0))
162
+ if self._extra_sizes.get(name) is None:
163
+ self._extra_sizes[name] = len(vals)
164
+ except:
165
+ sz = self._extra_sizes.get(name, 0) or 1
166
+ parts.append(np.zeros(sz, np.float32))
167
+ if structure is not None:
168
+ parts.append(_extract_structural_features(structure))
169
+ else:
170
+ parts.append(np.zeros(11, np.float32))
171
+ return np.concatenate(parts)
172
+
173
+ def featurize_all(self, comps, structures=None):
174
+ out = []
175
+ test_struct = structures[0] if structures else None
176
+ test_ex = self._featurize_extra(comps[0], test_struct)
177
+ self.n_extra = len(test_ex)
178
+ total = self.n_mg + self.n_extra + 200
179
+ comp_extras = sum(self._extra_sizes.get(n, 0) or 0
180
+ for n, _ in self.extra_featurizers)
181
+ log.info(f"Features: {self.n_mg} Magpie + {comp_extras} CompExtra + "
182
+ f"11 Structural + 200 Mat2Vec = {total}d")
183
+ for i, c in enumerate(tqdm(comps, desc=" Featurizing", leave=False)):
184
+ struct = structures[i] if structures else None
185
+ try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
186
+ except: mg = np.zeros(self.n_mg, np.float32)
187
+ ex = self._featurize_extra(c, struct)
188
+ out.append(np.concatenate([
189
+ np.nan_to_num(mg, nan=0.0),
190
+ np.nan_to_num(ex, nan=0.0),
191
+ self._pool(c)
192
+ ]))
193
+ return np.array(out)
194
+
195
+ def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
196
+ def transform(self, X):
197
+ if not self.scaler: return X
198
+ return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)
199
+
+
+ # ======================================================================
+ # MODEL
+ # ======================================================================
+
+ class DeepHybridTRM(nn.Module):
+     def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
+                  d_attn=32, nhead=4, d_hidden=64, ff_dim=96,
+                  dropout=0.15, max_steps=16, **kw):
+         super().__init__()
+         self.max_steps, self.D = max_steps, d_hidden
+         self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra
+
+         self.tok_proj = nn.Sequential(
+             nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
+         self.m2v_proj = nn.Sequential(
+             nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
+
+         self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
+         self.sa1_n = nn.LayerNorm(d_attn)
+         self.sa1_ff = nn.Sequential(
+             nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(d_attn*2, d_attn))
+         self.sa1_fn = nn.LayerNorm(d_attn)
+
+         self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
+         self.sa2_n = nn.LayerNorm(d_attn)
+         self.sa2_ff = nn.Sequential(
+             nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(d_attn*2, d_attn))
+         self.sa2_fn = nn.LayerNorm(d_attn)
+
+         self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
+         self.ca_n = nn.LayerNorm(d_attn)
+
+         pool_in = d_attn + (n_extra if n_extra > 0 else 0)
+         self.pool = nn.Sequential(
+             nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
+
+         self.z_up = nn.Sequential(
+             nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
+         self.y_up = nn.Sequential(
+             nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
+             nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
+         self.head = nn.Linear(d_hidden, 1)
+         self._init()
+
+     def _init(self):
+         for m in self.modules():
+             if isinstance(m, nn.Linear):
+                 nn.init.xavier_uniform_(m.weight)
+                 if m.bias is not None: nn.init.zeros_(m.bias)
+
+     def _attention(self, x):
+         B = x.size(0)
+         mg_dim = self.n_props * self.stat_dim
+         if self.n_extra > 0:
+             extra = x[:, mg_dim:mg_dim + self.n_extra]
+             m2v = x[:, mg_dim + self.n_extra:]
+         else:
+             extra, m2v = None, x[:, mg_dim:]
+
+         tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
+         ctx = self.m2v_proj(m2v).unsqueeze(1)
+
+         tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
+         tok = self.sa1_fn(tok + self.sa1_ff(tok))
+         tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
+         tok = self.sa2_fn(tok + self.sa2_ff(tok))
+         tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
+
+         pooled = tok.mean(dim=1)
+         if extra is not None:
+             pooled = torch.cat([pooled, extra], dim=-1)
+         return self.pool(pooled)
+
+     def forward(self, x, deep_supervision=False):
+         B = x.size(0)
+         xp = self._attention(x)
+         z = torch.zeros(B, self.D, device=x.device)
+         y = torch.zeros(B, self.D, device=x.device)
+         step_preds = []
+         for s in range(self.max_steps):
+             z = z + self.z_up(torch.cat([xp, y, z], -1))
+             y = y + self.y_up(torch.cat([y, z], -1))
+             step_preds.append(self.head(y).squeeze(1))
+         return step_preds if deep_supervision else step_preds[-1]
+
+     def count_parameters(self):
+         return sum(p.numel() for p in self.parameters() if p.requires_grad)
+
+
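The recursive refinement loop in `DeepHybridTRM.forward` can be sketched in isolation. This is a minimal sketch with toy dimensions and plain `nn.Linear` layers standing in for the real `z_up`/`y_up` MLPs and attention encoder; it shows the control flow only, not the actual model.

```python
import torch
import torch.nn as nn

D, B, STEPS = 8, 4, 3                # toy sizes (hypothetical, not MODEL_CFG)
z_up = nn.Linear(3 * D, D)           # stand-in for the z_up MLP
y_up = nn.Linear(2 * D, D)           # stand-in for the y_up MLP
head = nn.Linear(D, 1)

xp = torch.randn(B, D)               # stand-in for pooled attention features
z = torch.zeros(B, D)                # latent state
y = torch.zeros(B, D)                # answer state
step_preds = []
for _ in range(STEPS):
    z = z + z_up(torch.cat([xp, y, z], dim=-1))  # latent update sees input, answer, latent
    y = y + y_up(torch.cat([y, z], dim=-1))      # answer update sees answer, latent
    step_preds.append(head(y).squeeze(1))        # one scalar prediction per step
```

Every step emits a prediction, which is what `deep_supervision=True` exposes to the loss below.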
293
+ # ======================================================================
+ # LOSS + UTILS
+ # ======================================================================
+
+ def deep_supervision_loss(step_preds, targets):
+     preds = torch.stack(step_preds)
+     n = preds.shape[0]
+     w = torch.arange(1, n + 1, device=preds.device, dtype=preds.dtype)
+     w = w / w.sum()
+     per_step = (preds - targets.unsqueeze(0)).abs().mean(dim=1)
+     return (w * per_step).sum()
+
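A quick numeric sketch of the step weighting in `deep_supervision_loss`: the weights increase linearly so later recursion steps count more, and they are normalized to sum to 1, which keeps the loss on the same scale as a plain MAE. (Toy values, not benchmark data.)

```python
import torch

n = 4
w = torch.arange(1, n + 1, dtype=torch.float32)
w = w / w.sum()                      # [0.1, 0.2, 0.3, 0.4] — later steps weigh more

preds = torch.zeros(n, 5)            # n steps, batch of 5, all predicting 0
targets = torch.ones(5)              # true value 1 everywhere
per_step = (preds - targets.unsqueeze(0)).abs().mean(dim=1)  # MAE of 1.0 per step
loss = (w * per_step).sum()          # weighted sum → still 1.0, the plain MAE
```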
305
+
+ def strat_split(targets, val_size=0.15, seed=42):
+     bins = np.percentile(targets, [25, 50, 75])
+     lbl = np.digitize(targets, bins)
+     tr, vl = [], []
+     rng = np.random.RandomState(seed)
+     for b in range(4):
+         m = np.where(lbl == b)[0]
+         if len(m) == 0: continue
+         n = max(1, int(len(m) * val_size))
+         c = rng.choice(m, n, replace=False)
+         vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
+     return np.array(tr), np.array(vl)
+
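The quartile-stratified split above can be exercised on synthetic targets to check its two guarantees: every index lands in exactly one of train/val, and each quartile bin contributes to validation. This sketch repeats the same logic on dummy data (values chosen for illustration only):

```python
import numpy as np

targets = np.arange(100, dtype=np.float32)   # dummy regression targets
bins = np.percentile(targets, [25, 50, 75])  # quartile edges
lbl = np.digitize(targets, bins)             # bin label 0..3 per sample
rng = np.random.RandomState(42)
tr, vl = [], []
for b in range(4):
    m = np.where(lbl == b)[0]
    n = max(1, int(len(m) * 0.15))           # 15% of each bin to validation
    c = rng.choice(m, n, replace=False)
    vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
```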
319
+
+ @torch.inference_mode()
+ def predict(model, dl):
+     model.eval()
+     preds = []
+     for bx, _ in dl:
+         preds.append(model(bx).cpu())
+     return torch.cat(preds)
+
+
+ # ======================================================================
+ # TRAINING
+ # ======================================================================
+
+ def train_fold(model, tr_dl, vl_dl, device,
+                epochs=300, swa_start=200, fold=1, seed=42):
+     opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
+     sch = torch.optim.lr_scheduler.CosineAnnealingLR(
+         opt, T_max=swa_start, eta_min=1e-4)
+     swa_m = AveragedModel(model)
+     swa_s = SWALR(opt, swa_lr=5e-4)
+     swa_on = False
+     best_v, best_w = float('inf'), None
+
+     pbar = tqdm(range(epochs), desc=f" [75K|s{seed}] F{fold}/5",
+                 leave=False, ncols=120)
+     for ep in pbar:
+         model.train()
+         epoch_loss = torch.tensor(0.0, device=device)
+         n_samples = 0
+
+         for bx, by in tr_dl:
+             sp = model(bx, deep_supervision=True)
+             loss = deep_supervision_loss(sp, by)
+             opt.zero_grad(set_to_none=True)
+             loss.backward()
+             torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
+             opt.step()
+             with torch.no_grad():
+                 epoch_loss += (sp[-1] - by).abs().sum()
+                 n_samples += len(by)
+
+         model.eval()
+         val_loss = torch.tensor(0.0, device=device)
+         val_n = 0
+         with torch.inference_mode():
+             for bx, by in vl_dl:
+                 val_loss += (model(bx) - by).abs().sum()
+                 val_n += len(by)
+
+         tl = epoch_loss.item() / n_samples
+         vl = val_loss.item() / val_n
+
+         if ep < swa_start:
+             sch.step()
+             if vl < best_v:
+                 best_v = vl
+                 best_w = copy.deepcopy(model.state_dict())
+         else:
+             if not swa_on: swa_on = True
+             swa_m.update_parameters(model); swa_s.step()
+
+         if ep % 10 == 0 or ep == epochs - 1:
+             pbar.set_postfix(Best=f'{best_v:.2f}', Ph='SWA' if swa_on else 'COS',
+                              Tr=f'{tl:.2f}', Val=f'{vl:.2f}')
+
+     if swa_on:
+         update_bn(tr_dl, swa_m, device=device)
+         model.load_state_dict(swa_m.module.state_dict())
+     else:
+         model.load_state_dict(best_w)
+     return best_v, model
+
392
+
+ # ======================================================================
+ # MAIN — 5-SEED ENSEMBLE
+ # ======================================================================
+
+ def run_benchmark():
+     t0 = time.time()
+
+     print(f"""
+     +==========================================================+
+     | TRIADS V4 — matbench_jdft2d (5-Seed Ensemble)            |
+     | Structural + Composition features (~361d)                |
+     | 75K model | dropout=0.20                                 |
+     | Seeds: {SEEDS}                                           |
+     +==========================================================+
+     """)
+
+     device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+     if device.type == 'cuda':
+         gm = torch.cuda.get_device_properties(0).total_memory / 1e9
+         print(f" GPU: {torch.cuda.get_device_name(0)} ({gm:.1f} GB)")
+         torch.backends.cuda.matmul.allow_tf32 = True
+         torch.backends.cudnn.benchmark = True
+
+     # ── LOAD DATASET ──────────────────────────────────────────────────
+     print("\n Loading matbench_jdft2d...")
+     from matminer.datasets import load_dataset
+     df = load_dataset("matbench_jdft2d")
+     targets_all = np.array(df['exfoliation_en'].tolist(), np.float32)
+     structures_all = df['structure'].tolist()
+     comps_all = [s.composition for s in structures_all]
+     print(f" Dataset: {len(comps_all)} samples")
+
+     # ── FEATURIZE (once) ─────────────────────────────────────────────
+     t_feat = time.time()
+     feat = ExfoliationFeaturizer()
+     X_all = feat.featurize_all(comps_all, structures_all)
+     n_extra = feat.n_extra
+     print(f" Features: {X_all.shape} (n_extra={n_extra})")
+     print(f" Featurization: {time.time()-t_feat:.1f}s")
+
+     # ── FOLDS ────────────────────────────────────────────────────────
+     kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
+     folds = list(kfold.split(comps_all))
+     for fi, (tv, te) in enumerate(folds):
+         assert len(set(tv) & set(te)) == 0
+     print(" 5 folds verified: zero leakage\n")
439
+
+     # ── MODEL INFO ───────────────────────────────────────────────────
+     model_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
+                     mat2vec_dim=200, **MODEL_CFG)
+     test_model = DeepHybridTRM(**model_kw)
+     n_params = test_model.count_parameters()
+     del test_model
+     print(f" Model: {n_params:,} params")
+     print(f" Config: d_attn={MODEL_CFG['d_attn']}, d_hidden={MODEL_CFG['d_hidden']}, "
+           f"ff_dim={MODEL_CFG['ff_dim']}, dropout={MODEL_CFG['dropout']}\n")
+
+     # ── TRAIN ALL SEEDS ──────────────────────────────────────────────
+     model_dir = 'jdft2d_models_v4'
+     os.makedirs(model_dir, exist_ok=True)
+
+     # Store predictions and MAEs per seed
+     all_seed_maes = {}     # {seed: {fold: mae}}
+     all_fold_preds = {}    # {fold: {seed: predictions}}
+     all_fold_targets = {}  # {fold: targets}
+
+     for seed in SEEDS:
+         print(f"\n {'─'*3} Seed {seed} {'─'*40}")
+         t_seed = time.time()
+         seed_maes = {}
+
+         for fi, (tv_i, te_i) in enumerate(folds):
+             tri, vli = strat_split(targets_all[tv_i], 0.15, seed + fi)
+             feat.fit_scaler(X_all[tv_i][tri])
+
+             tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
+             tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
+             vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
+             vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
+             te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
+             te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)
+
+             tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
+             vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
+             te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)
+
+             torch.manual_seed(seed + fi)
+             np.random.seed(seed + fi)
+             if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)
+
+             model = DeepHybridTRM(**model_kw).to(device)
+             bv, model = train_fold(model, tr_dl, vl_dl, device,
+                                    epochs=300, swa_start=200,
+                                    fold=fi+1, seed=seed)
+
+             pred = predict(model, te_dl)
+             mae = F.l1_loss(pred, te_y.cpu()).item()
+             seed_maes[fi] = mae
+
+             # Store for ensemble
+             if fi not in all_fold_preds:
+                 all_fold_preds[fi] = {}
+                 all_fold_targets[fi] = te_y.cpu()
+             all_fold_preds[fi][seed] = pred
+
+             torch.save({
+                 'model_state': model.state_dict(),
+                 'test_mae': mae, 'fold': fi+1, 'seed': seed,
+                 'n_extra': n_extra,
+             }, f'{model_dir}/jdft2d_75K_s{seed}_f{fi+1}.pt')
+
+             del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
+             if device.type == 'cuda': torch.cuda.empty_cache()
+
+         avg_s = np.mean(list(seed_maes.values()))
+         all_seed_maes[seed] = seed_maes
+         dt = time.time() - t_seed
+         print(f"\n Seed {seed}: avg={avg_s:.4f} | "
+               f"{[f'{seed_maes[i]:.4f}' for i in range(5)]} ({dt:.0f}s)")
+
513
+     # ── ENSEMBLE ─────────────────────────────────────────────────────
+     ens_maes = {}
+     for fi in range(5):
+         preds_stack = torch.stack([all_fold_preds[fi][s] for s in SEEDS])
+         ens_pred = preds_stack.mean(dim=0)
+         ens_maes[fi] = F.l1_loss(ens_pred, all_fold_targets[fi]).item()
+
+     single_avgs = [np.mean(list(all_seed_maes[s].values())) for s in SEEDS]
+     single_mean = np.mean(single_avgs)
+     single_std = np.std(single_avgs)
+     ens_mean = np.mean(list(ens_maes.values()))
+     ens_std = np.std(list(ens_maes.values()))
+     ens_drop = (1 - ens_mean / single_mean) * 100
+
+     # ── RESULTS ──────────────────────────────────────────────────────
+     tt = time.time() - t0
+
+     print(f"""
+     {'='*72}
+     FINAL RESULTS — TRIADS V4 on matbench_jdft2d
+     {'='*72}
+
+     Per-seed results:""")
+
+     for seed in SEEDS:
+         sm = all_seed_maes[seed]
+         avg_s = np.mean(list(sm.values()))
+         print(f" Seed {seed:>4}: {avg_s:.4f} | "
+               f"{[f'{sm[i]:.4f}' for i in range(5)]}")
+
+     print(f"""
+     Single-seed avg: {single_mean:.4f} ± {single_std:.4f}
+     5-Seed Ensemble: {ens_mean:.4f} ± {ens_std:.4f} (↓{ens_drop:.1f}% from single)
+     Per-fold ens:    {[f'{ens_maes[i]:.4f}' for i in range(5)]}
+
+     {'Model':<40} {'MAE(meV/atom)':>15}
+     {'─'*58}
+     {'MODNet v0.1.12':<40} {'33.1918':>15}
+     {'TRIADS V3 (75K, +struct, single)':<40} {'37.0033':>15}
+     {'TRIADS V4 (75K, +struct, 5-seed ens)':<40} {f'{ens_mean:.4f}':>15}  ← NEW
+     {'TRIADS V1 (100K, comp-only)':<40} {'45.8045':>15}
+     {'─'*58}
+
+     Total time: {tt/60:.1f} min
+     Saved: {model_dir}/
+     """)
+
+     # ── SAVE ─────────────────────────────────────────────────────────
+     summary = {
+         'version': 'jdft2d-V4-ensemble',
+         'dataset': 'matbench_jdft2d',
+         'samples': len(comps_all),
+         'target_unit': 'meV/atom',
+         'model_config': MODEL_CFG,
+         'params': n_params,
+         'seeds': SEEDS,
+         'per_seed': {str(s): {str(k): round(v, 4) for k, v in m.items()}
+                      for s, m in all_seed_maes.items()},
+         'single_seed_avg': round(single_mean, 4),
+         'single_seed_std': round(single_std, 4),
+         'ensemble_maes': {str(k): round(v, 4) for k, v in ens_maes.items()},
+         'ensemble_avg': round(ens_mean, 4),
+         'ensemble_std': round(ens_std, 4),
+         'ensemble_improvement': f'{ens_drop:.1f}%',
+         'total_time_min': round(tt/60, 1),
+     }
+     with open('jdft2d_summary_v4.json', 'w') as f:
+         json.dump(summary, f, indent=2)
+     print(" Saved: jdft2d_summary_v4.json")
+
+     # Zip models
+     shutil.make_archive(model_dir, 'zip', '.', model_dir)
+     print(f" Saved: {model_dir}.zip (download this!)")
+
+
+ if __name__ == '__main__':
+     run_benchmark()
model_code/phonons_dataset_builder.py ADDED
@@ -0,0 +1,749 @@
+ """
+ +=============================================================+
+ | V6 Physics-Featurized Phonon Dataset Builder                |
+ | Architecture-Agnostic | Rich Physics | 3-Order Graphs       |
+ |                                                             |
+ | Features per atom: 18d (element physics + coords + local)   |
+ | Features per bond: 8d physics + 40d RBF + 3d direction      |
+ | Order 2 (angles): 8d angle RBF                              |
+ | Order 3 (dihedrals): 8d dihedral RBF                        |
+ | Composition: MAGPIE + mat2vec + matminer extras             |
+ | Global physics: Debye temp, force constants, etc.           |
+ |                                                             |
+ | ⚠ NO SCALING — raw features. Scale at training time only.   |
+ +=============================================================+
+
+ DEPENDENCIES:
+     pip install matminer pymatgen gensim tqdm scikit-learn torch numpy
+
+ USAGE:
+     python build_phonons_v6_dataset.py
+     -> Outputs: phonons_v6_dataset.pt
+ """
+
+ import os, time, math, warnings, urllib.request, logging
+ from collections import defaultdict
+ warnings.filterwarnings('ignore')
+
+ import numpy as np
+ import torch
+ from tqdm import tqdm
+ from sklearn.model_selection import KFold
+
+ logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
+ log = logging.getLogger("V6-BUILD")
+
36
+ # ═══════════════════════════════════════════════════════════════
+ # CONFIGURATION
+ # ═══════════════════════════════════════════════════════════════
+
+ CUTOFF = 8.0
+ MAX_NEIGHBORS = 12
+ N_RBF_DIST = 40
+ N_RBF_ANGLE = 8
+ N_RBF_DIHEDRAL = 8
+ MAX_QUADS = 50000        # cap dihedrals per crystal for memory
+ FOLD_SEED = 18012019     # matbench v0.1 protocol
+ N_FOLDS = 5
+
+ N_ELEM_FEAT = 12         # from lookup table
+ N_ATOM_COMPUTED = 6      # frac_coords(3) + coord_num(1) + avg_nn(1) + std_nn(1)
+ N_ATOM_FEAT = N_ELEM_FEAT + N_ATOM_COMPUTED  # 18
+ N_BOND_PHYSICS = 8
+ N_GLOBAL_PHYS = 15
+
+
56
+ # ═══════════════════════════════════════════════════════════════
+ # GAUSSIAN RADIAL BASIS FUNCTIONS
+ # ═══════════════════════════════════════════════════════════════
+
+ def gaussian_rbf(values, n_bins, vmin, vmax):
+     """Fixed Gaussian expansion. No learnable parameters."""
+     centers = torch.linspace(vmin, vmax, n_bins)
+     gamma = 1.0 / ((vmax - vmin) / n_bins) ** 2
+     return torch.exp(-gamma * (values.unsqueeze(-1) - centers.unsqueeze(0)) ** 2)
+
+
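A small sketch of what the RBF expansion produces: each scalar distance becomes an `n_bins`-dimensional vector that peaks at the nearest Gaussian center. The formula below repeats `gaussian_rbf` verbatim so the example is self-contained; the sample distances are made up.

```python
import torch

def gaussian_rbf(values, n_bins, vmin, vmax):
    centers = torch.linspace(vmin, vmax, n_bins)
    gamma = 1.0 / ((vmax - vmin) / n_bins) ** 2
    return torch.exp(-gamma * (values.unsqueeze(-1) - centers.unsqueeze(0)) ** 2)

d = torch.tensor([0.0, 4.0, 8.0])            # three bond lengths in Å (toy values)
rbf = gaussian_rbf(d, n_bins=40, vmin=0.0, vmax=8.0)
# rbf has shape [3, 40]; the 0 Å bond activates the first center,
# the 8 Å bond the last one.
```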
67
+ # ═══════════════════════════════════════════════════════════════
+ # ELEMENT PHYSICS LOOKUP TABLE
+ # ═══════════════════════════════════════════════════════════════
+
+ def build_element_table():
+     """
+     Build [103, 12] lookup table of per-element physical properties.
+     Z=0 is padding. Uses pymatgen Element data.
+
+     Columns: mass, 1/sqrt(mass), electronegativity, atomic_radius,
+              covalent_radius, ionization_energy, electron_affinity,
+              valence_electrons, group, period, block, is_metal
+     """
+     from pymatgen.core.periodic_table import Element
+
+     block_map = {'s': 0., 'p': 1., 'd': 2., 'f': 3.}
+     table = torch.zeros(103, N_ELEM_FEAT)
+
+     for z in range(1, 103):
+         try:
+             el = Element.from_Z(z)
+             mass = float(el.atomic_mass) if el.atomic_mass else 1.0
+             chi = float(el.X) if el.X is not None else 0.0
+             ar = float(el.atomic_radius) if el.atomic_radius is not None else 1.5
+             # Covalent radius proxy
+             try:
+                 cr = float(el.average_ionic_radius) if el.average_ionic_radius and float(el.average_ionic_radius) > 0 else ar
+             except Exception:
+                 cr = ar
+             # First ionization energy
+             ie = 0.0
+             try:
+                 ies = el.ionization_energies
+                 if isinstance(ies, dict) and 1 in ies and ies[1] is not None:
+                     ie = float(ies[1])
+                 elif isinstance(ies, (list, tuple)) and len(ies) > 1 and ies[1] is not None:
+                     ie = float(ies[1])
+             except Exception:
+                 pass
+             # Electron affinity
+             ea = 0.0
+             try:
+                 if el.electron_affinity is not None:
+                     ea = float(el.electron_affinity)
+             except Exception:
+                 pass
+             # Group, period, valence electrons
+             g = int(el.group) if el.group is not None else 0
+             p = int(el.row) if el.row is not None else 0
+             ve = g if g <= 2 else (g - 10 if g >= 13 else 2)
+             bl = block_map.get(el.block, 0.) if hasattr(el, 'block') and el.block else 0.
+             im = 1.0 if el.is_metal else 0.0
+
+             table[z] = torch.tensor([
+                 mass, 1.0 / math.sqrt(max(mass, 0.01)), chi, ar, cr,
+                 ie, ea, float(ve), float(g), float(p), bl, im
+             ])
+         except Exception:
+             table[z] = torch.tensor([1., 1., 0., 1.5, 1.5, 0., 0., 0., 0., 0., 0., 0.])
+
+     return table
+
129
+
+ # ═══════════════════════════════════════════════════════════════
+ # CRYSTAL GRAPH BUILDER (Orders 1, 2, 3)
+ # ═══════════════════════════════════════════════════════════════
+
+ def _empty_graph(atom_z, atom_features, n_atoms):
+     """Fallback for crystals with no neighbors found."""
+     return {
+         'atom_z': atom_z,
+         'atom_features': atom_features,
+         'n_atoms': n_atoms,
+         'edge_index': torch.zeros(2, 1, dtype=torch.long),
+         'edge_dist': torch.zeros(1),
+         'edge_rbf': torch.zeros(1, N_RBF_DIST),
+         'edge_vec': torch.zeros(1, 3),
+         'edge_physics': torch.zeros(1, N_BOND_PHYSICS),
+         'n_edges': 1,
+         'triplet_index': torch.zeros(2, 0, dtype=torch.long),
+         'angle_rbf': torch.zeros(0, N_RBF_ANGLE),
+         'n_triplets': 0,
+         'quad_index': torch.zeros(2, 0, dtype=torch.long),
+         'dihedral_rbf': torch.zeros(0, N_RBF_DIHEDRAL),
+         'n_quads': 0,
+     }
+
154
+
+ def build_crystal_graph(structure, elem_table):
+     """
+     Build a complete 3-order crystal graph for a single structure.
+
+     Returns dict with atom features, edge features + physics,
+     triplets (angles), and quads (dihedrals).
+
+     ✅ ZERO DATA LEAKAGE: uses ONLY this structure's geometry.
+     """
+     n_atoms = len(structure)
+     atom_z = torch.tensor([site.specie.Z for site in structure], dtype=torch.long)
+
+     # Element lookup features [N, 12]
+     atom_elem_feat = elem_table[atom_z.clamp(0, 102)]
+
+     # Fractional coordinates [N, 3]
+     frac_coords = torch.tensor(
+         [site.frac_coords for site in structure], dtype=torch.float32
+     )
+
+     # ── NEIGHBOR FINDING ──────────────────────────────────────
+     src_list, dst_list, dist_list, vec_list = [], [], [], []
+     nn_dists_per_atom = defaultdict(list)
+
+     try:
+         all_nbrs = structure.get_all_neighbors(CUTOFF)
+         for i, nbrs in enumerate(all_nbrs):
+             nbrs_sorted = sorted(nbrs, key=lambda x: x.nn_distance)[:MAX_NEIGHBORS]
+             for nbr in nbrs_sorted:
+                 src_list.append(i)
+                 dst_list.append(nbr.index)
+                 dist_list.append(nbr.nn_distance)
+                 vec_list.append(nbr.coords - structure[i].coords)
+                 nn_dists_per_atom[i].append(nbr.nn_distance)
+     except Exception as e:
+         log.warning(f" Neighbor finding failed: {e}")
+
+     # Per-atom coordination stats
+     coord_nums = torch.zeros(n_atoms)
+     avg_nn_dists = torch.zeros(n_atoms)
+     std_nn_dists = torch.zeros(n_atoms)
+     for i in range(n_atoms):
+         ds = nn_dists_per_atom.get(i, [])
+         coord_nums[i] = len(ds)
+         if ds:
+             avg_nn_dists[i] = np.mean(ds)
+             std_nn_dists[i] = np.std(ds) if len(ds) > 1 else 0.0
+
+     # Combined atom features [N, 18]
+     atom_features = torch.cat([
+         atom_elem_feat,               # [N, 12]
+         frac_coords,                  # [N, 3]
+         coord_nums.unsqueeze(-1),     # [N, 1]
+         avg_nn_dists.unsqueeze(-1),   # [N, 1]
+         std_nn_dists.unsqueeze(-1),   # [N, 1]
+     ], dim=-1)                        # [N, 18]
+
+     if len(src_list) == 0:
+         return _empty_graph(atom_z, atom_features, n_atoms)
214
+
+     # ── EDGE FEATURES (Order 1) ───────────────────────────────
+     edge_index = torch.tensor([src_list, dst_list], dtype=torch.long)
+     edge_dist = torch.tensor(dist_list, dtype=torch.float32)
+     raw_vecs = torch.tensor(np.array(vec_list), dtype=torch.float32)
+     n_edges = edge_index.shape[1]
+
+     edge_rbf = gaussian_rbf(edge_dist, N_RBF_DIST, 0.0, CUTOFF)
+     norms = raw_vecs.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+     edge_vec = raw_vecs / norms
+
+     # ── BOND PHYSICS FEATURES [E, 8] ─────────────────────────
+     z_src = atom_z[edge_index[0]]  # [E]
+     z_dst = atom_z[edge_index[1]]  # [E]
+
+     m_src = elem_table[z_src.clamp(0, 102), 0]    # mass
+     m_dst = elem_table[z_dst.clamp(0, 102), 0]
+     chi_src = elem_table[z_src.clamp(0, 102), 2]  # electronegativity
+     chi_dst = elem_table[z_dst.clamp(0, 102), 2]
+     r_src = elem_table[z_src.clamp(0, 102), 3]    # atomic radius
+     r_dst = elem_table[z_dst.clamp(0, 102), 3]
+
+     d = edge_dist.clamp(min=0.01)
+
+     # Vectorized bond physics computation
+     chi_prod = (chi_src * chi_dst).clamp(min=0.01)
+     k_est = torch.sqrt(chi_prod) / (d * d)                   # force constant
+     mu = (m_src * m_dst) / (m_src + m_dst).clamp(min=0.01)   # reduced mass
+     omega = torch.sqrt(k_est / mu.clamp(min=0.01))           # Einstein freq
+     delta_chi = (chi_src - chi_dst).abs()                    # EN difference
+     ionicity = delta_chi * delta_chi                         # bond ionicity
+     r_ratio = (r_src + r_dst) / d                            # radius sum ratio
+     m_ratio = torch.min(m_src, m_dst) / torch.max(m_src, m_dst).clamp(min=0.01)
+     inv_d = 1.0 / d                                          # inverse distance
+
+     edge_physics = torch.stack([
+         k_est, mu, omega, delta_chi, ionicity, r_ratio, m_ratio, inv_d
+     ], dim=-1)  # [E, 8]
+
+     # ── TRIPLETS / ANGLES (Order 2) ───────────────────────────
+     dst_np = edge_index[1].numpy()
+     dest_to_edges = defaultdict(list)
+     for e_idx in range(n_edges):
+         dest_to_edges[int(dst_np[e_idx])].append(e_idx)
+
+     trip_ij, trip_kj = [], []
+     for j, edge_list in dest_to_edges.items():
+         for idx_ij in edge_list:
+             for idx_kj in edge_list:
+                 if idx_ij != idx_kj:
+                     trip_ij.append(idx_ij)
+                     trip_kj.append(idx_kj)
+
+     if trip_ij:
+         triplet_index = torch.tensor([trip_ij, trip_kj], dtype=torch.long)
+         v_ij = edge_vec[triplet_index[0]]
+         v_kj = edge_vec[triplet_index[1]]
+         cos_theta = (v_ij * v_kj).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7)
+         angles = torch.acos(cos_theta)
+         angle_rbf_t = gaussian_rbf(angles, N_RBF_ANGLE, 0.0, math.pi)
+         n_triplets = triplet_index.shape[1]
+     else:
+         triplet_index = torch.zeros(2, 0, dtype=torch.long)
+         angle_rbf_t = torch.zeros(0, N_RBF_ANGLE)
+         n_triplets = 0
+
+     # ── QUADS / DIHEDRALS (Order 3) ───────────────────────────
+     quad_index, dihedral_rbf_t, n_quads = _compute_quads(
+         triplet_index, n_triplets, edge_vec, trip_ij, trip_kj
+     )
+
+     return {
+         'atom_z': atom_z,
+         'atom_features': atom_features,
+         'n_atoms': n_atoms,
+         'edge_index': edge_index,
+         'edge_dist': edge_dist,
+         'edge_rbf': edge_rbf,
+         'edge_vec': edge_vec,
+         'edge_physics': edge_physics,
+         'n_edges': n_edges,
+         'triplet_index': triplet_index,
+         'angle_rbf': angle_rbf_t,
+         'n_triplets': n_triplets,
+         'quad_index': quad_index,
+         'dihedral_rbf': dihedral_rbf_t,
+         'n_quads': n_quads,
+     }
302
+
+
+ def _compute_quads(triplet_index, n_triplets, edge_vec, trip_ij, trip_kj):
+     """Compute Order 3: pairs of triplets sharing a bond (dihedrals)."""
+     if n_triplets == 0:
+         return (torch.zeros(2, 0, dtype=torch.long),
+                 torch.zeros(0, N_RBF_DIHEDRAL), 0)
+
+     # For each edge, which triplets reference it?
+     edge_to_trips = defaultdict(list)
+     for t_idx in range(n_triplets):
+         edge_to_trips[trip_ij[t_idx]].append(t_idx)
+         edge_to_trips[trip_kj[t_idx]].append(t_idx)
+
+     quad_src, quad_dst = [], []
+     for edge_idx, tlist in edge_to_trips.items():
+         for i in range(len(tlist)):
+             for j in range(len(tlist)):
+                 if tlist[i] != tlist[j]:
+                     quad_src.append(tlist[i])
+                     quad_dst.append(tlist[j])
+                 if len(quad_src) >= MAX_QUADS:
+                     break
+             if len(quad_src) >= MAX_QUADS:
+                 break
+         if len(quad_src) >= MAX_QUADS:
+             break
+
+     if not quad_src:
+         return (torch.zeros(2, 0, dtype=torch.long),
+                 torch.zeros(0, N_RBF_DIHEDRAL), 0)
+
+     quad_index = torch.tensor([quad_src, quad_dst], dtype=torch.long)
+
+     # Dihedral angle = angle between planes of the two triplets
+     v_a1 = edge_vec[triplet_index[0, quad_index[0]]]
+     v_a2 = edge_vec[triplet_index[1, quad_index[0]]]
+     v_b1 = edge_vec[triplet_index[0, quad_index[1]]]
+     v_b2 = edge_vec[triplet_index[1, quad_index[1]]]
+
+     n_a = torch.cross(v_a1, v_a2, dim=-1)
+     n_b = torch.cross(v_b1, v_b2, dim=-1)
+     n_a = n_a / n_a.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+     n_b = n_b / n_b.norm(dim=-1, keepdim=True).clamp(min=1e-8)
+
+     cos_dih = (n_a * n_b).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7)
+     dihedrals = torch.acos(cos_dih)
+     dihedral_rbf_t = gaussian_rbf(dihedrals, N_RBF_DIHEDRAL, 0.0, math.pi)
+
+     return quad_index, dihedral_rbf_t, quad_index.shape[1]
+
+
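The per-bond physics proxies in `build_crystal_graph` are simple closed forms, so they can be checked by hand. This sketch repeats the three formulas on made-up scalar inputs (a symmetric light diatomic; the numbers are illustrative, not real element data):

```python
import math

def bond_physics(m1, m2, chi1, chi2, d):
    """Hypothetical scalar version of the vectorized bond physics above."""
    k = math.sqrt(max(chi1 * chi2, 0.01)) / (d * d)   # force constant proxy
    mu = m1 * m2 / max(m1 + m2, 0.01)                 # reduced mass
    omega = math.sqrt(k / max(mu, 0.01))              # Einstein frequency estimate
    return k, mu, omega

# Equal masses of 1 and equal electronegativity 2.2 at 1 Å separation:
# k = sqrt(2.2 * 2.2) / 1 = 2.2, mu = 0.5, omega = sqrt(2.2 / 0.5)
k, mu, omega = bond_physics(m1=1.0, m2=1.0, chi1=2.2, chi2=2.2, d=1.0)
```

Heavier atoms lower `mu` relative to the masses and so lower `omega`, which is the qualitative trend these features hand to the model.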
354
+ # ═══════════════════════════════════════════════════════════════
+ # GLOBAL PHYSICS FEATURES (per crystal)
+ # ═══════════════════════════════════════════════════════════════
+
+ def compute_global_physics(graph, structure, elem_table):
+     """
+     Compute 15 global physics features from a crystal graph.
+
+     Features:
+         0: avg_force_constant      7: avg_coordination
+         1: std_force_constant      8: density
+         2: avg_reduced_mass        9: volume_per_atom
+         3: mass_variance          10: packing_fraction
+         4: avg_einstein_freq      11: avg_bond_length
+         5: electronegativity_var  12: std_bond_length
+         6: debye_temp_estimate    13: max_atomic_mass
+                                   14: min_atomic_mass
+     """
+     ep = graph['edge_physics']  # [E, 8]
+     n_atoms = graph['n_atoms']
+     atom_z = graph['atom_z']
+
+     # From bond physics
+     k_vals = ep[:, 0]      # force constants
+     mu_vals = ep[:, 1]     # reduced masses
+     omega_vals = ep[:, 2]  # Einstein frequencies
+     dists = graph['edge_dist']
+
+     feats = torch.zeros(N_GLOBAL_PHYS)
+
+     if graph['n_edges'] > 0 and dists.shape[0] > 0:
+         feats[0] = k_vals.mean()
+         feats[1] = k_vals.std() if k_vals.shape[0] > 1 else 0.0
+         feats[2] = mu_vals.mean()
+         feats[4] = omega_vals.mean()
+         feats[11] = dists.mean()
+         feats[12] = dists.std() if dists.shape[0] > 1 else 0.0
+
+     # Mass statistics
+     masses = elem_table[atom_z.clamp(0, 102), 0]
+     feats[3] = masses.var() if n_atoms > 1 else 0.0
+     feats[13] = masses.max()
+     feats[14] = masses.min()
+
+     # Electronegativity variance
+     chis = elem_table[atom_z.clamp(0, 102), 2]
+     feats[5] = chis.var() if n_atoms > 1 else 0.0
+
+     # Debye temperature estimate: Θ_D ∝ sqrt(k_avg / m_avg)
+     m_avg = masses.mean()
+     k_avg = feats[0]
+     feats[6] = math.sqrt(float(k_avg / max(m_avg, 0.01)))
+
+     # Coordination
+     feats[7] = graph['atom_features'][:, N_ELEM_FEAT + 3].mean()  # coord_num column
+
+     # Structural
+     try:
+         feats[8] = structure.density
+         feats[9] = structure.volume / max(n_atoms, 1)
+         # Packing fraction
+         total_vol = sum(
+             (4 / 3) * math.pi * (float(site.specie.atomic_radius) ** 3)
+             for site in structure
+             if hasattr(site.specie, 'atomic_radius') and site.specie.atomic_radius is not None
+         )
+         feats[10] = total_vol / structure.volume if structure.volume > 0 else 0.0
+     except Exception:
+         pass
+
+     return feats
+
426
+
427
+ # ═══════════════════════════════════════════════════════════════
428
+ # STRUCTURAL FEATURES (per crystal)
429
+ # ═══════════════════════════════════════════════════════════════
430
+
431
+ def compute_structural_features(structure):
432
+ """
433
+ Compute 11 structural features: lattice params + symmetry.
434
+ Same as previous versions for backward compatibility.
435
+ """
436
+ from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
437
+
438
+ feats = np.zeros(11, dtype=np.float32)
439
+ try:
440
+ lat = structure.lattice
441
+ feats[0:6] = [lat.a, lat.b, lat.c, lat.alpha, lat.beta, lat.gamma]
442
+ feats[6] = structure.volume / max(len(structure), 1)
443
+ feats[7] = structure.density
444
+ feats[8] = float(len(structure))
445
+ try:
446
+ sga = SpacegroupAnalyzer(structure, symprec=0.1)
447
+ feats[9] = float(sga.get_space_group_number())
448
+ except Exception:
449
+ feats[9] = 0.0
450
+ try:
451
+ total_vol = sum(
452
+ (4 / 3) * np.pi * site.specie.atomic_radius ** 3
453
+ for site in structure
454
+ if hasattr(site.specie, 'atomic_radius') and site.specie.atomic_radius is not None
455
+ )
456
+ feats[10] = total_vol / structure.volume if structure.volume > 0 else 0.0
457
+ except Exception:
458
+ feats[10] = 0.0
459
+ except Exception:
460
+ pass
461
+ return feats
462
+
463
+
464
+ # ═══════════════════════════════════════════════════════════════
465
+ # COMPOSITION FEATURIZER (MAGPIE + mat2vec + matminer extras)
466
+ # ═══════════════════════════════════════════════════════════════
467
+
468
+ class CompositionFeaturizer:
469
+ """
470
+ Builds rich composition features per crystal:
471
+ - MAGPIE elemental properties (132d: 22 props × 6 stats)
472
+ - Extra matminer (Stoichiometry, ValenceOrbital, IonProperty, TMetalFraction)
473
+ - Structural features (11d)
474
+ - mat2vec embeddings (200d)
475
+
476
+ ✅ ALL features are deterministic per-sample. No cross-sample info.
477
+ """
478
+ M2V_URL = "https://storage.googleapis.com/mat2vec/"
479
+ M2V_FILES = [
480
+ "pretrained_embeddings",
481
+ "pretrained_embeddings.wv.vectors.npy",
482
+ "pretrained_embeddings.trainables.syn1neg.npy",
483
+ ]
484
+
485
+ def __init__(self, cache="mat2vec_cache"):
486
+ from matminer.featurizers.composition import (
487
+ ElementProperty, Stoichiometry, ValenceOrbital, IonProperty
488
+ )
489
+ from matminer.featurizers.composition.element import TMetalFraction
490
+ from gensim.models import Word2Vec
491
+
492
+ self.ep_magpie = ElementProperty.from_preset("magpie")
493
+ self.n_magpie = len(self.ep_magpie.feature_labels())
494
+
495
+ self.extra_ftzrs = [
496
+ ("Stoichiometry", Stoichiometry()),
497
+ ("ValenceOrbital", ValenceOrbital()),
498
+ ("IonProperty", IonProperty()),
499
+ ("TMetalFraction", TMetalFraction()),
500
+ ]
501
+ self._extra_sizes = {}
502
+ for name, ft in self.extra_ftzrs:
503
+ try:
504
+ self._extra_sizes[name] = len(ft.feature_labels())
505
+ except Exception:
506
+ self._extra_sizes[name] = None
507
+
508
+ # Download mat2vec
509
+ os.makedirs(cache, exist_ok=True)
510
+ for f in self.M2V_FILES:
511
+ p = os.path.join(cache, f)
512
+ if not os.path.exists(p):
513
+ log.info(f" Downloading mat2vec: {f}...")
514
+ urllib.request.urlretrieve(self.M2V_URL + f, p)
515
+ m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
516
+ self.emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key}
517
+
518
+ self.n_extra = None # determined on first call
519
+
520
+ def _pool_m2v(self, comp):
521
+ v, t = np.zeros(200, np.float32), 0.0
522
+ for s, f in comp.get_el_amt_dict().items():
523
+ if s in self.emb:
524
+ v += f * self.emb[s]
525
+ t += f
526
+ return v / max(t, 1e-8)
527
+
528
+ def _featurize_extras(self, comp):
529
+ parts = []
530
+ for name, ft in self.extra_ftzrs:
531
+ try:
532
+ vals = np.array(ft.featurize(comp), np.float32)
533
+ parts.append(np.nan_to_num(vals, nan=0.0))
534
+ if self._extra_sizes.get(name) is None:
535
+ self._extra_sizes[name] = len(vals)
536
+ except Exception:
537
+ sz = self._extra_sizes.get(name, 0) or 1
538
+ parts.append(np.zeros(sz, np.float32))
539
+ return np.concatenate(parts)
540
+
541
+ def featurize_all(self, compositions, structures):
542
+ """Return [N, D_comp] array of all composition features."""
543
+ # Determine dimensions from first sample
544
+ test_extras = self._featurize_extras(compositions[0])
545
+ self.n_extra = len(test_extras)
546
+ struct_feats_dim = 11
547
+ total_dim = self.n_magpie + self.n_extra + struct_feats_dim + 200
548
+
549
+ log.info(f" Composition features: {self.n_magpie} MAGPIE + "
550
+ f"{self.n_extra} Extras + 11 Structural + 200 mat2vec = {total_dim}d")
551
+
552
+ out = []
553
+ for i, comp in enumerate(tqdm(compositions, desc=" Featurizing compositions", leave=False)):
554
+ # MAGPIE
555
+ try:
556
+ mg = np.array(self.ep_magpie.featurize(comp), np.float32)
557
+ except Exception:
558
+ mg = np.zeros(self.n_magpie, np.float32)
559
+ mg = np.nan_to_num(mg, nan=0.0)
560
+
561
+ # Extra matminer
562
+ ex = self._featurize_extras(comp)
563
+
564
+ # Structural
565
+ sf = compute_structural_features(structures[i])
566
+
567
+ # mat2vec
568
+ m2v = self._pool_m2v(comp)
569
+
570
+ out.append(np.concatenate([mg, ex, sf, m2v]))
571
+
572
+ return np.array(out, dtype=np.float32)
573
+
574
+
575
+ # ═══════════════════════════════════════════════════════════════
576
+ # MAIN — BUILD AND SAVE
577
+ # ═══════════════════════════════════════════════════════════════
578
+
579
+ def main():
580
+ t0 = time.time()
581
+ print("""
582
+ +==========================================================+
583
+ | V6 Physics-Featurized Phonon Dataset Builder |
584
+ | 3-Order Graphs | Bond Physics | Architecture-Agnostic |
585
+ | ⚠ NO SCALING — raw features only |
586
+ +==========================================================+
587
+ """)
588
+
589
+ # ── LOAD MATBENCH DATA ────────────────────────────────────
590
+ print(" Loading matbench_phonons...")
591
+ from matminer.datasets import load_dataset
592
+ df = load_dataset("matbench_phonons")
593
+ targets = np.array(df['last phdos peak'].tolist(), np.float32)
594
+ structures = df['structure'].tolist()
595
+ compositions = [s.composition for s in structures]
596
+ N = len(structures)
597
+ print(f" Loaded: {N} samples")
598
+ print(f" Target range: {targets.min():.1f} – {targets.max():.1f} cm⁻¹")
599
+
600
+ # ── BUILD ELEMENT TABLE ───────────────────────────────────
601
+ print("\n Building element physics table...")
602
+ elem_table = build_element_table()
603
+ print(f" Element table: {elem_table.shape} (Z=0..102, {N_ELEM_FEAT} features)")
604
+
605
+ # ── BUILD CRYSTAL GRAPHS ─────────────────────────────────
606
+ print(f"\n Building 3-order crystal graphs ({MAX_NEIGHBORS}-NN, cutoff={CUTOFF}Å)...")
607
+ graphs = []
608
+ global_physics_list = []
609
+
610
+ for i, struct in enumerate(tqdm(structures, desc=" Building graphs")):
611
+ g = build_crystal_graph(struct, elem_table)
612
+ gp = compute_global_physics(g, struct, elem_table)
613
+ graphs.append(g)
614
+ global_physics_list.append(gp)
615
+
616
+ # Stats
617
+ n_atoms_list = [g['n_atoms'] for g in graphs]
618
+ n_edges_list = [g['n_edges'] for g in graphs]
619
+ n_trips_list = [g['n_triplets'] for g in graphs]
620
+ n_quads_list = [g['n_quads'] for g in graphs]
621
+ print(f" Graphs built:")
622
+ print(f" Atoms/crystal: min={min(n_atoms_list)}, max={max(n_atoms_list)}, "
623
+ f"mean={np.mean(n_atoms_list):.1f}")
624
+ print(f" Edges/crystal: min={min(n_edges_list)}, max={max(n_edges_list)}, "
625
+ f"mean={np.mean(n_edges_list):.1f}")
626
+ print(f" Triplets/crystal: min={min(n_trips_list)}, max={max(n_trips_list)}, "
627
+ f"mean={np.mean(n_trips_list):.1f}")
628
+ print(f" Quads/crystal: min={min(n_quads_list)}, max={max(n_quads_list)}, "
629
+ f"mean={np.mean(n_quads_list):.1f}")
630
+
631
+ global_physics = torch.stack(global_physics_list)
632
+ print(f" Global physics: {global_physics.shape}")
633
+
634
+ # ── COMPOSITION FEATURES ─────────────────────────────────
635
+ print("\n Computing composition features...")
636
+ feat = CompositionFeaturizer()
637
+ comp_features = feat.featurize_all(compositions, structures)
638
+ print(f" Composition features shape: {comp_features.shape}")
639
+
640
+ # ── FOLD INDICES (strict matbench protocol) ──────────────
641
+ print(f"\n Computing 5-fold split indices (seed={FOLD_SEED})...")
642
+ kf = KFold(N_FOLDS, shuffle=True, random_state=FOLD_SEED)
643
+ fold_indices = [(train_idx.tolist(), test_idx.tolist())
644
+ for train_idx, test_idx in kf.split(range(N))]
645
+
646
+ # Verify zero leakage
647
+ for fi, (tr, te) in enumerate(fold_indices):
648
+ overlap = set(tr) & set(te)
649
+ assert len(overlap) == 0, f"DATA LEAK in fold {fi}: {len(overlap)} shared indices!"
650
+ assert len(tr) + len(te) == N, f"Fold {fi}: missing samples!"
651
+ print(" ✅ All folds verified: ZERO data leakage")
652
+
653
+ # ── FEATURE DIMENSION INFO ───────────────────────────────
654
+ n_magpie = feat.n_magpie
655
+ n_extra = feat.n_extra
656
+ feature_info = {
657
+ 'atom_features_dim': N_ATOM_FEAT,
658
+ 'atom_features_layout': [
659
+ 'mass', '1/sqrt_mass', 'electronegativity', 'atomic_radius',
660
+ 'covalent_radius', 'ionization_energy', 'electron_affinity',
661
+ 'valence_electrons', 'group', 'period', 'block', 'is_metal',
662
+ 'frac_x', 'frac_y', 'frac_z',
663
+ 'coordination_num', 'avg_nn_dist', 'std_nn_dist',
664
+ ],
665
+ 'edge_physics_dim': N_BOND_PHYSICS,
666
+ 'edge_physics_layout': [
667
+ 'force_constant', 'reduced_mass', 'einstein_freq',
668
+ 'en_difference', 'ionicity', 'radius_sum_ratio',
669
+ 'mass_ratio', 'inverse_distance',
670
+ ],
671
+ 'edge_rbf_dim': N_RBF_DIST,
672
+ 'angle_rbf_dim': N_RBF_ANGLE,
673
+ 'dihedral_rbf_dim': N_RBF_DIHEDRAL,
674
+ 'global_physics_dim': N_GLOBAL_PHYS,
675
+ 'global_physics_layout': [
676
+ 'avg_force_constant', 'std_force_constant', 'avg_reduced_mass',
677
+ 'mass_variance', 'avg_einstein_freq', 'en_variance',
678
+ 'debye_temp_estimate', 'avg_coordination', 'density',
679
+ 'volume_per_atom', 'packing_fraction', 'avg_bond_length',
680
+ 'std_bond_length', 'max_atomic_mass', 'min_atomic_mass',
681
+ ],
682
+ 'comp_magpie_range': (0, n_magpie),
683
+ 'comp_extras_range': (n_magpie, n_magpie + n_extra),
684
+ 'comp_structural_range': (n_magpie + n_extra, n_magpie + n_extra + 11),
685
+ 'comp_mat2vec_range': (n_magpie + n_extra + 11, n_magpie + n_extra + 11 + 200),
686
+ 'comp_total_dim': comp_features.shape[1],
687
+ }
688
+
689
+ # ── SAVE ─────────────────────────────────────────────────
690
+ save_path = "phonons_v6_dataset.pt"
691
+ save_data = {
692
+ # Per-crystal data
693
+ 'graphs': graphs,
694
+ 'comp_features': torch.tensor(comp_features, dtype=torch.float32),
695
+ 'global_physics': global_physics,
696
+ 'targets': torch.tensor(targets, dtype=torch.float32),
697
+
698
+ # Fold indices
699
+ 'fold_indices': fold_indices,
700
+ 'fold_seed': FOLD_SEED,
701
+
702
+ # Metadata
703
+ 'n_samples': N,
704
+ 'feature_info': feature_info,
705
+ 'element_table': elem_table,
706
+ 'config': {
707
+ 'cutoff': CUTOFF,
708
+ 'max_neighbors': MAX_NEIGHBORS,
709
+ 'n_rbf_dist': N_RBF_DIST,
710
+ 'n_rbf_angle': N_RBF_ANGLE,
711
+ 'n_rbf_dihedral': N_RBF_DIHEDRAL,
712
+ 'max_quads': MAX_QUADS,
713
+ 'fold_seed': FOLD_SEED,
714
+ 'n_folds': N_FOLDS,
715
+ },
716
+ }
717
+ torch.save(save_data, save_path)
718
+
719
+ size_mb = os.path.getsize(save_path) / 1e6
720
+ dt = time.time() - t0
721
+ print(f"\n ✅ Saved: {save_path} ({size_mb:.1f} MB)")
722
+ print(f" Total time: {dt:.1f}s")
723
+
724
+ # ── SUMMARY ──────────────────────────────────────────────
725
+ print(f"""
726
+ ╔══════════════════════════════════════════════════════════╗
727
+ ║ Dataset Summary ║
728
+ ╠══════════════════════════════════════════════════════════╣
729
+ ║ Samples: {N:>6} ║
730
+ ║ Atom features: {N_ATOM_FEAT:>6}d (12 elem + 3 coord + 3 local) ║
731
+ ║ Bond RBF: {N_RBF_DIST:>6}d ║
732
+ ║ Bond physics: {N_BOND_PHYSICS:>6}d (k, μ, ω, Δχ, ...) ║
733
+ ║ Angle RBF: {N_RBF_ANGLE:>6}d ║
734
+ ║ Dihedral RBF: {N_RBF_DIHEDRAL:>6}d ║
735
+ ║ Composition: {comp_features.shape[1]:>6}d (MAGPIE+extras+struct+m2v)║
736
+ ║ Global physics: {N_GLOBAL_PHYS:>6}d ║
737
+ ║ Folds: {N_FOLDS:>6} (seed={FOLD_SEED}) ║
738
+ ║ File size: {size_mb:>5.1f} MB ║
739
+ ╚══════════════════════════════════════════════════════════╝
740
+
741
+ ⚠ Remember: NO scaling applied. Apply StandardScaler at
742
+ training time using ONLY train-fold indices!
743
+
744
+ Architecture-agnostic: plug ANY model on top of this dataset.
745
+ """)
746
+
747
+
748
+ if __name__ == '__main__':
749
+ main()
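As the builder's closing banner warns, no scaling is baked into the saved dataset: the scaler must be fit per fold on train indices only. A minimal leakage-free sketch (the helper name and the toy array are illustrative, not part of the saved dataset API):

```python
import numpy as np
from sklearn.preprocessing import StandardScaler

def scale_fold(features, train_idx, test_idx):
    """Fit the scaler on the train fold only, then transform both splits.

    `features` is any [N, D] array (e.g. the saved comp_features);
    fitting on train_idx alone keeps test statistics out of training.
    """
    scaler = StandardScaler().fit(features[train_idx])
    return scaler.transform(features[train_idx]), scaler.transform(features[test_idx])

# Toy usage with random data standing in for the saved features
X = np.random.RandomState(0).randn(10, 4)
Xtr, Xte = scale_fold(X, np.arange(8), np.arange(8, 10))
```

The test set is transformed with train-fold statistics, so its columns are generally not zero-mean: that asymmetry is exactly what the strict matbench protocol requires.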
model_code/phonons_model.py ADDED
@@ -0,0 +1,839 @@
1
+ """
2
+ +=============================================================+
3
+ | TRIADS V6 — Graph Attention TRM + Gate-Based Halting |
4
+ | |
5
+ | Single model: Gate-halt (4-16 adaptive cycles) |
6
+ | d=56, 4 heads, gated residuals, deep supervision |
7
+ | SWA last 50 ep | 200 epochs |
8
+ | |
9
+ | Loads: phonons_v6_dataset.pt |
10
+ +=============================================================+
11
+
12
+ DEPENDENCIES (dataset already pre-computed, no matminer needed):
13
+ pip install torch numpy scikit-learn tqdm
14
+ (all pre-installed on Kaggle)
15
+
16
+ USAGE:
17
+ python phonons_model.py
18
+ """
19
+
20
+ import os, copy, json, time, math, warnings, threading
21
+ from collections import defaultdict
22
+ warnings.filterwarnings('ignore')
23
+ import numpy as np
24
+ import torch
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+ from torch.optim.swa_utils import AveragedModel, SWALR
28
+ from sklearn.preprocessing import StandardScaler
29
+
30
+ # Notebook dashboard (falls back to plain console logging outside IPython)
31
+ try:
32
+ from IPython.display import display, HTML, clear_output
33
+ IN_NOTEBOOK = True
34
+ except ImportError:
35
+ IN_NOTEBOOK = False
36
+
37
+
38
+ # ═══════════════════════════════════════════════════════════════
39
+ # CONFIG
40
+ # ═══════════════════════════════════════════════════════════════
41
+
42
+ D = 56
43
+ N_HEADS = 4
44
+ N_WARMUP = 1 # 1 unshared warm-up (param budget)
45
+ N_ANGLE_RBF = 8
46
+ DROPOUT = 0.1
47
+ BATCH_SIZE = 64
48
+ EPOCHS = 200
49
+ SWA_START = 150
50
+ LR = 5e-4
51
+ WD = 1e-4
52
+ SEEDS = [42]
53
+
54
+ # Gate-halt model
55
+ MIN_CYCLES = 4
56
+ MAX_CYCLES = 16
57
+ GATE_HALT_THR = 0.05 # halt when max gate < this
58
+ GATE_SPARSITY = 0.001 # encourage gates to close
59
+
60
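The two gate-halt constants above combine into a simple stopping rule inside the reasoning loop: run at least MIN_CYCLES, then halt as soon as the output gate saturates closed. A minimal sketch of that rule; `cycles_run` and its gate trace are hypothetical, for illustration only:

```python
# Halting rule: at least MIN_CYCLES cycles, then stop once the max
# y-gate activation drops below GATE_HALT_THR (gate has closed).
MIN_CYCLES, GATE_HALT_THR = 4, 0.05

def cycles_run(gate_maxima):
    """gate_maxima[t] = max y-gate activation observed at cycle t (toy trace)."""
    for cyc, gmax in enumerate(gate_maxima):
        if cyc >= MIN_CYCLES - 1 and gmax < GATE_HALT_THR:
            return cyc + 1          # halted early at this cycle
    return len(gate_maxima)         # ran the full budget

n = cycles_run([0.9, 0.5, 0.2, 0.04, 0.01])  # gate closes at cycle 4
```

The GATE_SPARSITY term in the loss nudges gates toward zero, so traces like the one above become more common as training progresses.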
+ BASELINES = {
61
+ 'MEGNet': 28.76, 'ALIGNN': 29.34, 'MODNet': 45.39,
62
+ 'CrabNet': 47.09, 'TRIADS V4': 56.33, 'TRIADS V3.1': 63.00,
63
+ 'TRIADS V1': 71.82, 'Dummy': 323.76,
64
+ }
65
+
66
+
67
+ # ═══════════════════════════════════════════════════════════════
68
+ # SCATTER
69
+ # ═══════════════════════════════════════════════════════════════
70
+
71
+ def scatter_sum(src, idx, dim_size):
72
+ out = torch.zeros(dim_size, src.shape[-1], dtype=src.dtype, device=src.device)
73
+ out.scatter_add_(0, idx.unsqueeze(-1).expand_as(src), src)
74
+ return out
75
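scatter_sum routes per-edge messages into per-node accumulators via `scatter_add_`; a tiny self-contained check (the function is reproduced from above so the snippet runs on its own):

```python
import torch

def scatter_sum(src, idx, dim_size):
    # Sum rows of src into dim_size slots selected by idx
    out = torch.zeros(dim_size, src.shape[-1], dtype=src.dtype, device=src.device)
    out.scatter_add_(0, idx.unsqueeze(-1).expand_as(src), src)
    return out

# Three messages routed to two destination nodes
src = torch.tensor([[1.0], [2.0], [3.0]])
idx = torch.tensor([0, 0, 1])
out = scatter_sum(src, idx, dim_size=2)  # node 0 gets 1+2, node 1 gets 3
```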
+
76
+
77
+ # ═══════════════════════════════════════════════════════════════
78
+ # COLLATION + DATALOADER
79
+ # ═══════════════════════════════════════════════════════════════
80
+
81
+ def collate(graphs, comp, glob_phys, targets, indices, device):
82
+ az, af = [], []
83
+ ei, rb, vc, ph = [], [], [], []
84
+ tr, an = [], []
85
+ ba, na_list = [], []
86
+ a_off, e_off = 0, 0
87
+
88
+ for k, i in enumerate(indices):
89
+ g = graphs[i]
90
+ na, ne = g['n_atoms'], g['n_edges']
91
+ az.append(g['atom_z'])
92
+ af.append(g['atom_features'])
93
+ ei.append(g['edge_index'] + a_off)
94
+ rb.append(g['edge_rbf']); vc.append(g['edge_vec']); ph.append(g['edge_physics'])
95
+ tr.append(g['triplet_index'] + e_off)
96
+ an.append(g['angle_rbf'])
97
+ ba.append(torch.full((na,), k, dtype=torch.long))
98
+ na_list.append(na)
99
+ a_off += na; e_off += ne
100
+
101
+ return (
102
+ comp[indices].to(device),
103
+ glob_phys[indices].to(device),
104
+ {
105
+ 'atom_z': torch.cat(az).to(device),
106
+ 'atom_feat': torch.cat(af).to(device),
107
+ 'ei': torch.cat(ei, 1).to(device),
108
+ 'rbf': torch.cat(rb).to(device),
109
+ 'vec': torch.cat(vc).to(device),
110
+ 'phys': torch.cat(ph).to(device),
111
+ 'triplets': torch.cat(tr, 1).to(device),
112
+ 'angle_feat': torch.cat(an).to(device),
113
+ 'batch': torch.cat(ba).to(device),
114
+ 'n_crystals': len(indices),
115
+ 'n_atoms': na_list,
116
+ },
117
+ targets[indices].to(device),
118
+ )
119
+
120
+
121
+ class Loader:
122
+ def __init__(self, graphs, comp, gp, tgt, idx, bs, dev, shuf=False):
123
+ self.g, self.c, self.gp, self.t = graphs, comp, gp, tgt
124
+ self.idx, self.bs, self.dev, self.shuf = np.array(idx), bs, dev, shuf
125
+
126
+ def __iter__(self):
127
+ i = self.idx.copy()
128
+ if self.shuf: np.random.shuffle(i)
129
+ self._b = [i[j:j+self.bs] for j in range(0, len(i), self.bs)]
130
+ self._p = 0; return self
131
+
132
+ def __next__(self):
133
+ if self._p >= len(self._b): raise StopIteration
134
+ b = self._b[self._p]; self._p += 1
135
+ return collate(self.g, self.c, self.gp, self.t, b, self.dev)
136
+
137
+ def __len__(self): return (len(self.idx) + self.bs - 1) // self.bs
138
+
139
+
140
+ # ═══════════════════════════════════════════════════════════════
141
+ # GRAPH MESSAGE PASSING LAYER (Line Graph style)
142
+ # ═══════════════════════════════════════════════════════════════
143
+
144
+ class GraphMPLayer(nn.Module):
145
+ """Bond update (line graph) + Atom update (edge-gated)."""
146
+
147
+ def __init__(self, d, n_angle=N_ANGLE_RBF, dropout=DROPOUT):
148
+ super().__init__()
149
+ # Phase 1: Bond update from angular neighbors
150
+ self.bond_msg = nn.Sequential(nn.Linear(d*2 + n_angle, d), nn.SiLU())
151
+ self.bond_gate = nn.Sequential(nn.Linear(d*2 + n_angle, d), nn.Sigmoid())
152
+ self.bond_up = nn.Sequential(nn.Linear(d*2, d), nn.LayerNorm(d), nn.SiLU(), nn.Dropout(dropout))
153
+ # Phase 2: Atom update from bonds
154
+ self.atom_msg = nn.Sequential(nn.Linear(d*3, d), nn.SiLU())
155
+ self.atom_gate = nn.Sequential(nn.Linear(d*3, d), nn.Sigmoid())
156
+ self.atom_up = nn.Sequential(nn.Linear(d*2, d), nn.LayerNorm(d), nn.SiLU(), nn.Dropout(dropout))
157
+
158
+ def forward(self, atoms, bonds, ei, triplets, angle_feat):
159
+ # Phase 1: bonds learn from angular neighbors
160
+ if triplets.shape[1] > 0:
161
+ b_ij, b_kj = bonds[triplets[0]], bonds[triplets[1]]
162
+ inp = torch.cat([b_ij, b_kj, angle_feat], -1)
163
+ msg = self.bond_msg(inp) * self.bond_gate(inp)
164
+ agg = torch.zeros(bonds.size(0), bonds.size(1), dtype=msg.dtype, device=msg.device)
165
+ agg.scatter_add_(0, triplets[0].unsqueeze(-1).expand_as(msg), msg)
166
+ bonds = bonds + self.bond_up(torch.cat([bonds, agg], -1))
167
+ # Phase 2: atoms aggregate from bonds
168
+ inp = torch.cat([atoms[ei[0]], atoms[ei[1]], bonds], -1)
169
+ msg = self.atom_msg(inp) * self.atom_gate(inp)
170
+ agg = scatter_sum(msg, ei[1], atoms.size(0))
171
+ atoms = atoms + self.atom_up(torch.cat([atoms, agg], -1))
172
+ return atoms, bonds
173
+
174
+
175
+ # ═══════════════════════════════════════════════════════════════
176
+ # PHONON V6 MODEL
177
+ # ═══════════════════════════════════════════════════════════════
178
+
179
+ class PhononV6(nn.Module):
180
+ """
181
+ Graph Attention TRM for phonon prediction.
182
+
183
+ mode='fixed': Fixed n_cycles TRM cycles (Model 1)
184
+ mode='gate_halt': Gate-based implicit halting (Model 2)
185
+ """
186
+
187
+ def __init__(self, comp_dim, global_phys_dim=15, d=D,
188
+ mode='gate_halt', n_cycles=MAX_CYCLES,
189
+ min_cycles=MIN_CYCLES, max_cycles=MAX_CYCLES,
190
+ n_warmup=N_WARMUP, n_heads=N_HEADS, dropout=DROPOUT):
191
+ super().__init__()
192
+ self.d = d
193
+ self.mode = mode
194
+ self.total_cycles = n_cycles if mode == 'fixed' else max_cycles
195
+ self.min_cycles = min_cycles if mode == 'gate_halt' else self.total_cycles
196
+
197
+ # Feature layout (from V6 dataset: 132 magpie + extras + 11 struct + 200 m2v)
198
+ self.n_magpie = 132
199
+ self.n_extra = comp_dim - 132 - 11 - 200
200
+ self.n_comp_tokens = 22 + 1 + 1 # 22 magpie + 1 extra + 1 m2v = 24
201
+
202
+ # ── Input Encoding ────────────────────────────────────
203
+ self.atom_embed = nn.Embedding(103, d)
204
+ self.atom_feat_proj = nn.Linear(18, d)
205
+ self.rbf_enc = nn.Linear(40, d)
206
+ self.vec_enc = nn.Linear(3, d)
207
+ self.phys_enc = nn.Linear(8, d)
208
+
209
+ # ── Composition Token Projections ─────────────────────
210
+ self.magpie_proj = nn.Linear(6, d)
211
+ self.extra_proj = nn.Linear(max(self.n_extra, 1), d)
212
+ self.m2v_proj = nn.Linear(200, d)
213
+
214
+ # ── Context (structural + global physics) ─────────────
215
+ self.ctx_proj = nn.Linear(11 + global_phys_dim, d)
216
+
217
+ # ── Token Type Embeddings ─────────────────────────────
218
+ self.type_embed = nn.Embedding(2, d)
219
+
220
+ # ── Warm-up Layers (unshared) ─────────────────────────
221
+ self.warmup = nn.ModuleList([GraphMPLayer(d, N_ANGLE_RBF, dropout) for _ in range(n_warmup)])
222
+ self.warmup_out = nn.Sequential(nn.Linear(d, d), nn.LayerNorm(d), nn.SiLU())
223
+
224
+ # ── Shared TRM Block ──────────────────────────────────
225
+ # Graph MP (shared)
226
+ self.trm_gnn = GraphMPLayer(d, N_ANGLE_RBF, dropout)
227
+
228
+ # Self-Attention
229
+ self.sa = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
230
+ self.sa_n = nn.LayerNorm(d)
231
+ self.sa_ff = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Dropout(dropout), nn.Linear(d, d))
232
+ self.sa_fn = nn.LayerNorm(d)
233
+
234
+ # Cross-Attention
235
+ self.ca = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
236
+ self.ca_n = nn.LayerNorm(d)
237
+
238
+ # ── State Update (Gated Residuals) ───────────────────
239
+ self.z_proj = nn.Linear(d*3, d)
240
+ self.z_up = nn.Sequential(nn.Linear(d*2, d), nn.SiLU(), nn.Linear(d, d))
241
+ self.z_gate = nn.Sequential(nn.Linear(d*2, d), nn.Sigmoid())
242
+ self.y_up = nn.Sequential(nn.Linear(d*2, d), nn.SiLU(), nn.Linear(d, d))
243
+ self.y_gate = nn.Sequential(nn.Linear(d*2, d), nn.Sigmoid())
244
+
245
+ # ── Output Head ───────────────────────────────────────
246
+ self.head = nn.Sequential(nn.Linear(d, d//2), nn.SiLU(), nn.Linear(d//2, 1))
247
+
248
+ self._init_weights()
249
+
250
+ def _init_weights(self):
251
+ for m in self.modules():
252
+ if isinstance(m, nn.Linear):
253
+ nn.init.xavier_uniform_(m.weight)
254
+ if m.bias is not None: nn.init.zeros_(m.bias)
255
+
256
+ def forward(self, comp, glob_phys, g, deep_supervision=False):
257
+ B = g['n_crystals']
258
+ ei = g['ei']
259
+ dev = comp.device
260
+
261
+ # ══════════════════════════════════════════════════════
262
+ # INPUT ENCODING
263
+ # ══════════════════════════════════════════════════════
264
+
265
+ # Atom features
266
+ atoms = self.atom_embed(g['atom_z'].clamp(0, 102)) + self.atom_feat_proj(g['atom_feat'])
267
+
268
+ # Bond features: distance (direction-gated) + physics
269
+ bonds = self.rbf_enc(g['rbf']) * torch.tanh(self.vec_enc(g['vec'])) + self.phys_enc(g['phys'])
270
+
271
+ triplets = g['triplets']
272
+ angle_feat = g['angle_feat']
273
+
274
+ # ══════════════════════════════════════════════════════
275
+ # WARM-UP (N_WARMUP unshared graph layers)
276
+ # ══════════════════════════════════════════════════════
277
+
278
+ for layer in self.warmup:
279
+ atoms, bonds = layer(atoms, bonds, ei, triplets, angle_feat)
280
+ atoms = self.warmup_out(atoms)
281
+
282
+ # ══════════════════════════════════════════════════════
283
+ # COMPOSITION TOKENS (24 total)
284
+ # ══════════════════════════════════════════════════════
285
+
286
+ magpie = comp[:, :132].view(B, 22, 6)
287
+ extras = comp[:, 132:132+self.n_extra]
288
+ s_meta = comp[:, 132+self.n_extra:132+self.n_extra+11]
289
+ m2v = comp[:, -200:]
290
+
291
+ mag_tok = self.magpie_proj(magpie) # [B, 22, d]
292
+ ext_tok = self.extra_proj(extras).unsqueeze(1) # [B, 1, d]
293
+ m2v_tok = self.m2v_proj(m2v).unsqueeze(1) # [B, 1, d]
294
+ comp_tok = torch.cat([mag_tok, ext_tok, m2v_tok], 1) # [B, 24, d]
295
+
296
+ comp_tok = comp_tok + self.type_embed.weight[0]
297
+
298
+ # Context vector (structural + global physics)
299
+ ctx = self.ctx_proj(torch.cat([s_meta, glob_phys], -1)) # [B, d]
300
+
301
+ # ══════════════════════════════════════════════════════
302
+ # TRM REASONING LOOP
303
+ # ══════════════════════════════════════════════════════
304
+
305
+ z = torch.zeros(B, self.d, device=dev)
306
+ y = torch.zeros(B, self.d, device=dev)
307
+ preds = []
308
+ n_atoms = g['n_atoms']
309
+ self._gate_sparsity = 0. # track gate magnitudes for regularizer
310
+
311
+ for cyc in range(self.total_cycles):
312
+ # ── Phase 1+2: Graph MP (shared weights) ──────────
313
+ atoms, bonds = self.trm_gnn(atoms, bonds, ei, triplets, angle_feat)
314
+
315
+ # ── Pad atoms for attention ───────────────────────
316
+ ma = max(n_atoms)
317
+ atom_tok = atoms.new_zeros(B, ma, self.d)
318
+ atom_mask = torch.ones(B, ma, dtype=torch.bool, device=dev)
319
+ off = 0
320
+ for i, na in enumerate(n_atoms):
321
+ atom_tok[i, :na] = atoms[off:off+na]
322
+ atom_mask[i, :na] = False
323
+ off += na
324
+ atom_tok = atom_tok + self.type_embed.weight[1]
325
+
326
+ # ── Phase 3: Joint Self-Attention ─────────────────
327
+ all_tok = torch.cat([comp_tok, atom_tok], 1)
328
+ full_mask = torch.cat([
329
+ torch.zeros(B, self.n_comp_tokens, dtype=torch.bool, device=dev),
330
+ atom_mask
331
+ ], 1)
332
+
333
+ sa_out = self.sa(all_tok, all_tok, all_tok, key_padding_mask=full_mask)[0]
334
+ all_tok = self.sa_n(all_tok + sa_out)
335
+ all_tok = self.sa_fn(all_tok + self.sa_ff(all_tok))
336
+
337
+ comp_tok = all_tok[:, :self.n_comp_tokens]
338
+ atom_tok = all_tok[:, self.n_comp_tokens:]
339
+
340
+ # ── Phase 4: Cross-Attention (comp queries atoms) ─
341
+ ca_out = self.ca(comp_tok, atom_tok, atom_tok, key_padding_mask=atom_mask)[0]
342
+ comp_tok = self.ca_n(comp_tok + ca_out)
343
+
344
+ # ── Unpad atoms back to flat ──────────────────────
345
+ parts = [atom_tok[i, :n_atoms[i]] for i in range(B)]
346
+ atoms = torch.cat(parts, 0)
347
+
348
+ # ── Phase 5: State Update (Gated Residuals) ───────
349
+ xp = comp_tok.mean(dim=1) # [B, d]
350
+
351
+ z_inp = self.z_proj(torch.cat([xp, ctx, y], -1))
352
+ z_cand = self.z_up(torch.cat([z_inp, z], -1))
353
+ z_g = self.z_gate(torch.cat([z_inp, z], -1))
354
+ z = z + z_g * z_cand
355
+
356
+ y_cand = self.y_up(torch.cat([y, z], -1))
357
+ y_g = self.y_gate(torch.cat([y, z], -1))
358
+ y = y + y_g * y_cand
359
+ # Track gate sparsity (mean of all gate activations)
360
+ self._gate_sparsity = self._gate_sparsity + (z_g.mean() + y_g.mean()) / 2
361
+
362
+ preds.append(self.head(y).squeeze(-1))
363
+
364
+ # ── Phase 6: Gate-Based Halting ────────────────────
365
+ if self.mode == 'gate_halt' and cyc >= self.min_cycles - 1:
366
+ if y_g.max().item() < GATE_HALT_THR:
367
+ break
368
+
369
+ # Normalize gate sparsity by number of cycles actually run
370
+ n_run = len(preds)
371
+ self._gate_sparsity = self._gate_sparsity / max(n_run, 1)
372
+
373
+ return preds if deep_supervision else preds[-1]
374
+
375
+ def count_parameters(self):
376
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
377
+
378
+
379
+ # ═══════════════════════════════════════════════════════════════
380
+ # LOSS FUNCTIONS
381
+ # ═══════════════════════════════════════════════════════════════
382
+
383
+ def deep_sup_loss(preds_list, targets):
384
+ """Linearly-weighted deep supervision: later cycles get more weight."""
385
+ p = torch.stack(preds_list)
386
+ w = torch.arange(1, p.shape[0]+1, device=p.device, dtype=p.dtype)
387
+ w = w / w.sum()
388
+ return (w * (p - targets.unsqueeze(0)).abs().mean(1)).sum()
389
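The linear weighting in deep_sup_loss gives cycle t a weight of t divided by the sum 1 + 2 + ... + T, so later cycles dominate the loss. The weights for a 4-cycle run, shown with plain NumPy for clarity:

```python
import numpy as np

# Linear deep-supervision weights for T reasoning cycles
T = 4
w = np.arange(1, T + 1, dtype=float)
w /= w.sum()  # normalized so the weights sum to 1
```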
+
390
+
391
+ def gate_halt_loss(preds_list, targets, gate_sparsity):
392
+ """Deep supervision + gate sparsity to encourage early halting."""
393
+ return deep_sup_loss(preds_list, targets) + GATE_SPARSITY * gate_sparsity
394
+
395
+
396
+ # ═══════════════════════════════════════════════════════════════
397
+ # STRATIFIED SPLIT (within train fold → train/val)
398
+ # ═══════════════════════════════════════════════════════════════
399
+
400
+ def strat_split(t, vf=0.15, seed=42):
401
+ bins = np.digitize(t, np.percentile(t, [25, 50, 75]))
402
+ tr, vl = [], []
403
+ rng = np.random.RandomState(seed)
404
+ for b in range(4):
405
+ m = np.where(bins == b)[0]
406
+ if len(m) == 0: continue
407
+ n = max(1, int(len(m) * vf))
408
+ c = rng.choice(m, n, replace=False)
409
+ vl.extend(c.tolist())
410
+ tr.extend(np.setdiff1d(m, c).tolist())
411
+ return np.array(tr), np.array(vl)
412
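strat_split bins targets by quartile and holds out vf of each bin, so the validation set mirrors the target distribution. A self-contained sanity check on synthetic targets (the function is reproduced from above; the 200-sample array is illustrative):

```python
import numpy as np

def strat_split(t, vf=0.15, seed=42):
    """Quartile-stratified train/val split (mirrors the function above)."""
    bins = np.digitize(t, np.percentile(t, [25, 50, 75]))
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(bins == b)[0]
        if len(m) == 0:
            continue
        n = max(1, int(len(m) * vf))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist())
        tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)

targets = np.random.RandomState(0).uniform(0, 100, 200)
tr, vl = strat_split(targets)  # disjoint, covers all 200 samples
```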
+
413
+
414
+ # ═══════════════════════════════════════════════════════════════
415
+ # LIVE DASHBOARD (IPython HTML — works in Kaggle/Jupyter)
416
+ # ═══════════════════════════════════════════════════════════════
417
+
418
+ _print_lock = threading.Lock()
419
+
420
+ # Shared state updated by training threads, read by dashboard
421
+ _dash_state = {
422
+ 'GH': {'fold': 0, 'ep': 0, 'tr': float('inf'), 'val': float('inf'),
423
+ 'best': float('inf'), 'best_ep': 0, 'lr': 0., 'eta_m': 0,
424
+ 'ep_s': 0., 'swa': False, 'done': False, 'test_mae': None},
425
+ }
426
+ _dash_log = [] # Accumulates milestone messages
427
+
428
+
429
+ def _log(msg):
430
+ with _print_lock:
431
+ _dash_log.append(msg)
432
+ if not IN_NOTEBOOK:
433
+ print(msg, flush=True)
434
+
435
+
436
+ def _render_html():
437
+ """Build an HTML table from _dash_state + recent log lines."""
438
+ css = (
439
+ '<style>'
440
+ '.tv6{font-family:monospace;font-size:13px;border-collapse:collapse;width:100%}'
441
+ '.tv6 th{background:#1a1a2e;color:#e94560;padding:6px 10px;text-align:right;border-bottom:2px solid #e94560}'
442
+ '.tv6 td{padding:5px 10px;text-align:right;border-bottom:1px solid #333}'
443
+ '.tv6 tr:nth-child(odd){background:#16213e}'
444
+ '.tv6 tr:nth-child(even){background:#0f3460}'
445
+ '.tv6 td:first-child,.tv6 th:first-child{text-align:left;font-weight:bold;color:#e9c46a}'
446
+ '.tv6 .best{color:#2ecc71;font-weight:bold}'
447
+ '.tv6 .done{color:#2ecc71}'
448
+ '.tv6 .swa{color:#9b59b6}'
449
+ '.tv6 .training{color:#f39c12}'
450
+ '.tv6 .waiting{color:#636e72}'
451
+ '.logbox{font-family:monospace;font-size:12px;color:#dfe6e9;background:#0a0a0a;'
452
+ 'padding:8px 12px;margin-top:8px;border-radius:6px;max-height:200px;overflow-y:auto}'
453
+ '</style>'
454
+ )
455
+ rows = ''
456
+ for name, s in _dash_state.items():
457
+ if s['done'] and s['test_mae']:
458
+ status = f'<span class="done">✅ {s["test_mae"]:.1f}</span>'
459
+ elif s['swa']:
460
+ status = '<span class="swa">SWA</span>'
461
+ elif s['ep'] == 0:
462
+ status = '<span class="waiting">Waiting</span>'
463
+ else:
464
+ status = '<span class="training">▶ Training</span>'
465
+ ep_str = f"{s['ep']}/{EPOCHS}" if s['ep'] else '-'
466
+ tr_str = f"{s['tr']:.1f}" if s['tr'] < 1e6 else '-'
467
+ val_str = f"{s['val']:.1f}" if s['val'] < 1e6 else '-'
468
+ best_str = f'<span class="best">{s["best"]:.1f}@{s["best_ep"]}</span>' if s['best'] < 1e6 else '-'
469
+ lr_str = f"{s['lr']:.0e}" if s['lr'] > 0 else '-'
470
+ eps_str = f"{s['ep_s']:.1f}" if s['ep_s'] > 0 else '-'
471
+ eta_str = f"{s['eta_m']:.0f}m" if s['eta_m'] > 0 else '-'
472
+ fold_str = str(s['fold']) if s['fold'] else '-'
473
+ rows += (f'<tr><td>{name}</td><td>{fold_str}</td><td>{ep_str}</td>'
474
+ f'<td>{tr_str}</td><td>{val_str}</td><td>{best_str}</td>'
475
+ f'<td>{lr_str}</td><td>{eps_str}</td><td>{eta_str}</td>'
476
+ f'<td>{status}</td></tr>')
477
+ table = (
478
+ f'{css}<h3 style="color:#e94560;font-family:monospace;margin:4px 0">⚡ TRIADS V6 — Live Dashboard</h3>'
479
+ f'<table class="tv6"><tr><th>Model</th><th>Fold</th><th>Epoch</th>'
480
+ f'<th>Train MAE</th><th>Val MAE</th><th>Best MAE</th>'
481
+ f'<th>LR</th><th>s/ep</th><th>ETA</th><th>Status</th></tr>{rows}</table>'
482
+ )
483
+ # Show last 8 log messages
484
+ if _dash_log:
485
+ log_html = '<br>'.join(_dash_log[-8:])
486
+ table += f'<div class="logbox">{log_html}</div>'
487
+ return table
488
+
489
+
490
+ class Dashboard:
491
+ """Background thread that re-renders an HTML table every 5s in-place."""
492
+ def __init__(self):
493
+ self._stop = threading.Event()
494
+ self._thread = None
495
+
496
+ def start(self):
497
+ if not IN_NOTEBOOK:
498
+ return
499
+ self._stop.clear()
500
+ self._thread = threading.Thread(target=self._run, daemon=True)
501
+ self._thread.start()
502
+
503
+ def stop(self):
504
+ if not IN_NOTEBOOK or self._thread is None:
505
+ return
506
+ self._stop.set()
507
+ self._thread.join(timeout=10)
508
+ # Final render
509
+ clear_output(wait=True)
510
+ display(HTML(_render_html()))
511
+
512
+ def _run(self):
513
+ while not self._stop.is_set():
514
+ try:
515
+ clear_output(wait=True)
516
+ display(HTML(_render_html()))
517
+ except Exception:
518
+ pass
519
+ self._stop.wait(5)
520
+
521
+
522
+ _dashboard = Dashboard()
523
+
524
+
525
+ def train_fold_core(model, tr_loader, vl_loader, device, fold, seed,
526
+ model_name, tgt_mean=0., tgt_std=1., log_every=10):
527
+ """
528
+ Train one model on one device, with structured line logging.
529
+ Returns (best_val_mae, model_with_best_weights).
530
+ """
531
+ opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
532
+ # Cosine scheduler with 10-epoch linear warmup
533
+ WARMUP_EP = 10
534
+ def lr_lambda(ep):
535
+ if ep < WARMUP_EP: return (ep + 1) / WARMUP_EP
536
+ progress = (ep - WARMUP_EP) / max(1, EPOCHS - WARMUP_EP)
537
+ return 0.5 * (1 + math.cos(math.pi * progress)) * (1 - 1e-5/LR) + 1e-5/LR
538
+ sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)
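The warmup-plus-cosine multiplier above can be sketched and sanity-checked in isolation; `LR`, `EPOCHS`, and `WARMUP_EP` below are illustrative stand-ins for the script's constants, not the actual training configuration:

```python
import math

# Illustrative stand-ins for the script's LR / EPOCHS / WARMUP_EP constants
LR, EPOCHS, WARMUP_EP = 1e-3, 100, 10

def lr_lambda(ep):
    if ep < WARMUP_EP:
        return (ep + 1) / WARMUP_EP          # linear warmup: 0.1, 0.2, ..., 1.0
    progress = (ep - WARMUP_EP) / max(1, EPOCHS - WARMUP_EP)
    floor = 1e-5 / LR                        # multiplier floor so lr never drops below 1e-5
    return 0.5 * (1 + math.cos(math.pi * progress)) * (1 - floor) + floor

assert abs(lr_lambda(0) - 0.1) < 1e-12              # first warmup epoch
assert abs(lr_lambda(WARMUP_EP - 1) - 1.0) < 1e-12  # warmup peak
assert lr_lambda(EPOCHS - 1) >= 1e-5 / LR           # cosine decays toward the floor
```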
539
+
540
+ swa_model = AveragedModel(model)
541
+ swa_sch = SWALR(opt, swa_lr=1e-4)
542
+
543
+ bv, bw, best_ep = float('inf'), None, 0
544
+ fold_start = time.time()
545
+
546
+ for ep in range(EPOCHS):
547
+ ep_start = time.time()
548
+ use_swa = ep >= SWA_START
549
+
550
+ # ── TRAIN ─────────────────────────────────────────────
551
+ model.train()
552
+ te, tn = 0., 0
553
+ for cb, gb, g_batch, tb in tr_loader:
554
+ sp = model(cb, gb, g_batch, True)
555
+ if model.mode == 'gate_halt':
556
+ loss = gate_halt_loss(sp, tb, model._gate_sparsity)
557
+ else:
558
+ loss = deep_sup_loss(sp, tb)
559
+ opt.zero_grad(set_to_none=True)
560
+ loss.backward()
561
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
562
+ opt.step()
563
+ with torch.no_grad():
564
+ te += ((sp[-1] * tgt_std + tgt_mean) - (tb * tgt_std + tgt_mean)).abs().sum().item()
565
+ tn += len(tb)
566
+
567
+ if use_swa:
568
+ swa_model.update_parameters(model)
569
+ swa_sch.step()
570
+ else:
571
+ sch.step()
572
+
573
+ # ── VALIDATE ──────────────────────────────────────────
574
+ eval_m = swa_model if use_swa and ep == EPOCHS - 1 else model
575
+ eval_m.eval()
576
+ ve, vn = 0., 0
577
+ with torch.inference_mode():
578
+ for cb, gb, g_batch, tb in vl_loader:
579
+ pred = eval_m(cb, gb, g_batch)
580
+ ve += ((pred * tgt_std + tgt_mean) - (tb * tgt_std + tgt_mean)).abs().sum().item()
581
+ vn += len(tb)
582
+
583
+ train_mae = te / max(tn, 1)
584
+ val_mae = ve / max(vn, 1)
585
+ ep_time = time.time() - ep_start
586
+
587
+ if val_mae < bv:
588
+ bv = val_mae
589
+ # Save the weights that were actually evaluated (unwrap SWA's AveragedModel)
+ src = eval_m.module if isinstance(eval_m, AveragedModel) else eval_m
+ bw = copy.deepcopy(src.state_dict())
590
+ best_ep = ep + 1
591
+
592
+ # ── UPDATE DASHBOARD STATE (every epoch) ────────────
593
+ lr_now = opt.param_groups[0]['lr']
594
+ eta_m = (EPOCHS - ep - 1) * ep_time / 60
595
+ _dash_state[model_name].update({
596
+ 'fold': fold, 'ep': ep + 1,
597
+ 'tr': train_mae, 'val': val_mae,
598
+ 'best': bv, 'best_ep': best_ep,
599
+ 'lr': lr_now, 'ep_s': ep_time,
600
+ 'eta_m': eta_m, 'swa': use_swa,
601
+ })
602
+
603
+ # ── PLAIN LOG (fallback / milestone prints) ───────────
604
+ if not IN_NOTEBOOK and ((ep + 1) % log_every == 0 or ep == 0 or ep == EPOCHS - 1):
605
+ swa_tag = ' SWA' if use_swa else ''
606
+ _log(f" [{model_name}|F{fold}] ep {ep+1:>3}/{EPOCHS}"
607
+ f" │ Tr={train_mae:>6.1f} Val={val_mae:>6.1f}"
608
+ f" Best={bv:>6.1f}@{best_ep:<3}"
609
+ f" │ lr={lr_now:.0e}{swa_tag}"
610
+ f" │ {ep_time:.1f}s/ep ETA {eta_m:.0f}m")
611
+
612
+ model.load_state_dict(bw)
613
+ total_time = time.time() - fold_start
614
+ _log(f" [{model_name}|F{fold}] ✅ Done in {total_time/60:.1f}m │ Best Val MAE = {bv:.2f} @ epoch {best_ep}")
615
+
616
+ return bv, model
617
+
618
+
619
+ def evaluate_model(model, test_loader, device, tgt_mean=0., tgt_std=1.):
620
+ """Evaluate model MAE on test set (returns MAE in original scale)."""
621
+ model.eval()
622
+ ee, en_ = 0., 0
623
+ with torch.inference_mode():
624
+ for cb, gb, g_batch, tb in test_loader:
625
+ pred = model(cb, gb, g_batch) * tgt_std + tgt_mean
626
+ real = tb * tgt_std + tgt_mean
627
+ ee += (pred - real).abs().sum().item()
628
+ en_ += len(tb)
629
+ return ee / max(en_, 1)
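Because targets are normalized for training, `evaluate_model` de-normalizes both prediction and target before taking the absolute error; the resulting MAE is exactly `tgt_std` times the normalized-scale MAE. A small numpy check of that identity (all values below are synthetic):

```python
import numpy as np

rng = np.random.default_rng(0)
pred_n = rng.normal(size=100)      # synthetic normalized predictions
tgt_n = rng.normal(size=100)       # synthetic normalized targets
mean, std = 600.0, 150.0           # illustrative target mean / std

# De-normalized MAE, as computed inside evaluate_model
mae_orig = np.abs((pred_n * std + mean) - (tgt_n * std + mean)).mean()
# The shared mean cancels, so this equals std times the normalized MAE
assert np.isclose(mae_orig, std * np.abs(pred_n - tgt_n).mean())
```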
630
+
631
+
632
+ # ═══════════════════════════════════════════════════════════════
633
+ # DUAL-GPU PARALLEL TRAINING
634
+ # ═══════════════════════════════════════════════════════════════
635
+
636
+ def _train_worker(model, tr_loader, vl_loader, te_loader, device,
637
+ fold, seed, model_name, result_dict, key,
638
+ tgt_mean=0., tgt_std=1.):
639
+ """Thread worker: train + evaluate one model on one GPU."""
640
+ try:
641
+ _, best_model = train_fold_core(
642
+ model, tr_loader, vl_loader, device, fold, seed, model_name,
643
+ tgt_mean=tgt_mean, tgt_std=tgt_std
644
+ )
645
+ mae = evaluate_model(best_model, te_loader, device, tgt_mean, tgt_std)
646
+ result_dict[key] = mae
647
+ _dash_state[model_name]['test_mae'] = mae
648
+ _dash_state[model_name]['done'] = True
649
+ _log(f" [{model_name}|F{fold}] 🏆 Test MAE = {mae:.2f} cm⁻¹")
650
+ del best_model
651
+ except Exception as e:
652
+ import traceback
653
+ _log(f" [{model_name}|F{fold}] ❌ ERROR: {e}\n{traceback.format_exc()}")
654
+ result_dict[key] = float('inf')
655
+ _dash_state[model_name]['done'] = True
656
+ finally:
657
+ if device.type == 'cuda':
658
+ torch.cuda.empty_cache()
659
+
660
+
661
+ # ═══════════════════════════════════════════════════════════════
662
+ # MAIN
663
+ # ═══════════════════════════════════════════════════════════════
664
+
665
+ def main():
666
+ t0 = time.time()
667
+
668
+ n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
669
+
670
+ print(f"""
671
+ ╔══════════════════════════════════════════════════════════╗
672
+ ║ TRIADS V6 — Graph-TRM + Gate-Based Halting ║
673
+ ║ ║
674
+ ║ Gate-halt: {MIN_CYCLES}-{MAX_CYCLES} adaptive cycles, d={D} ║
675
+ ║ Deep supervision │ SWA (last {EPOCHS-SWA_START} ep) │ {EPOCHS} ep ║
676
+ ╚══════════════════════════════════════════════════════════╝
677
+ """)
678
+
679
+ device = torch.device('cuda:0' if n_gpus > 0 else 'cpu')
680
+ if n_gpus > 0:
681
+ name = torch.cuda.get_device_name(0)
682
+ mem = torch.cuda.get_device_properties(0).total_memory / 1e9
683
+ print(f" GPU: {name} ({mem:.1f} GB)")
684
+ torch.backends.cuda.matmul.allow_tf32 = True
685
+ torch.backends.cudnn.benchmark = True
686
+ else:
687
+ print(" ⚠ No GPU — training will be slow")
688
+
689
+ # ── LOAD DATASET ──────────────────────────────────────────
690
+ kaggle_path = "/kaggle/input/datasets/rudratiwari0099x/phonons-training-dataset/phonons_v6_dataset.pt"
691
+ local_path = "phonons_v6_dataset.pt"
692
+ ds_path = kaggle_path if os.path.exists(kaggle_path) else local_path
693
+ print(f" Loading {ds_path}...")
694
+ data = torch.load(ds_path, weights_only=False)
695
+ graphs = data['graphs']
696
+ comp_all = data['comp_features']
697
+ glob_phys = data['global_physics']
698
+ tgt_all = data['targets']
699
+ fold_indices = data['fold_indices']
700
+ N = data['n_samples']
701
+ comp_dim = comp_all.shape[1]
702
+ gp_dim = glob_phys.shape[1]
703
+ print(f" Dataset: {N} samples | comp_dim: {comp_dim} | global_phys: {gp_dim}")
704
+
705
+ # ── VERIFY FOLDS ──────────────────────────────────────────
706
+ for fi, (tr, te) in enumerate(fold_indices):
707
+ assert len(set(tr) & set(te)) == 0, f"LEAK in fold {fi}!"
708
+ print(" 5 folds: zero leakage ✓")
709
+
710
+ # ── MODEL SIZE CHECK ─────────────────────────────────────
711
+ m_test = PhononV6(comp_dim, gp_dim, mode='gate_halt',
712
+ min_cycles=MIN_CYCLES, max_cycles=MAX_CYCLES)
713
+ n_params = m_test.count_parameters()
714
+ print(f" Model (Gate-Halt TRM): {n_params:,} params")
715
+ del m_test
716
+ print()
717
+
718
+ # ── TRAINING ──────────────────────────────────────────────
719
+ tnp = tgt_all.numpy()
720
+ results = {}
721
+
722
+ _dashboard.start()
723
+ try:
724
+ for seed in SEEDS:
725
+ print(f" {'═'*3} Seed {seed} {'═'*55}")
726
+ ts = time.time()
727
+ fold_maes = {}
728
+
729
+ for fi, (tv_idx, te_idx) in enumerate(fold_indices):
730
+ tv_idx, te_idx = np.array(tv_idx), np.array(te_idx)
731
+ print(f"\n ┌─ Fold {fi+1}/5 {'─'*50}")
732
+
733
+ # Train/val split within train fold
734
+ tri, vli = strat_split(tnp[tv_idx], 0.15, seed + fi)
735
+
736
+ # Normalize targets (from train split ONLY — zero leakage)
737
+ tgt_mean = float(tgt_all[tv_idx[tri]].mean())
738
+ tgt_std = float(tgt_all[tv_idx[tri]].std()) + 1e-8
739
+ tgt_norm = (tgt_all - tgt_mean) / tgt_std
740
+ print(f" │ Target norm: mean={tgt_mean:.1f} std={tgt_std:.1f}")
741
+
742
+ # Scale features (ONLY from train split — zero leakage)
743
+ sc = StandardScaler().fit(comp_all[tv_idx[tri]].numpy())
744
+ cs = torch.tensor(
745
+ np.nan_to_num(sc.transform(comp_all.numpy()), nan=0.).astype(np.float32)
746
+ )
747
+ sc_gp = StandardScaler().fit(glob_phys[tv_idx[tri]].numpy())
748
+ gps = torch.tensor(
749
+ np.nan_to_num(sc_gp.transform(glob_phys.numpy()), nan=0.).astype(np.float32)
750
+ )
751
+
752
+ # Seed for reproducibility
753
+ torch.manual_seed(seed + fi)
754
+ np.random.seed(seed + fi)
755
+ if n_gpus > 0:
756
+ torch.cuda.manual_seed_all(seed + fi)
757
+
758
+ # Create model
759
+ model = PhononV6(comp_dim, gp_dim, mode='gate_halt',
760
+ min_cycles=MIN_CYCLES,
761
+ max_cycles=MAX_CYCLES).to(device)
762
+
763
+ # Build loaders with NORMALIZED targets
764
+ trl = Loader(graphs, cs, gps, tgt_norm, tv_idx[tri], BATCH_SIZE, device, True)
765
+ vll = Loader(graphs, cs, gps, tgt_norm, tv_idx[vli], BATCH_SIZE, device, False)
766
+ tel = Loader(graphs, cs, gps, tgt_norm, te_idx, BATCH_SIZE, device, False)
767
+
768
+ # Reset dashboard
769
+ _dash_state['GH']['done'] = False
770
+
771
+ # Train
772
+ _, best_model = train_fold_core(
773
+ model, trl, vll, device, fi+1, seed, "GH",
774
+ tgt_mean=tgt_mean, tgt_std=tgt_std
775
+ )
776
+ mae = evaluate_model(best_model, tel, device, tgt_mean, tgt_std)
777
+ fold_maes[fi] = mae
778
+ _dash_state['GH']['test_mae'] = mae
779
+ _dash_state['GH']['done'] = True
780
+ _log(f" [GH|F{fi+1}] 🏆 Test MAE = {mae:.2f} cm⁻¹")
781
+
782
+ # ── SAVE WEIGHTS ─────────────────────────────────────
783
+ os.makedirs('phonons_models_v6', exist_ok=True)
784
+ torch.save({
785
+ 'model_state': best_model.state_dict(),
786
+ 'test_mae': mae,
787
+ 'fold': fi + 1,
788
+ 'seed': seed,
789
+ 'comp_dim': comp_dim,
790
+ 'gp_dim': gp_dim,
791
+ }, f'phonons_models_v6/phonons_v6_s{seed}_f{fi+1}.pt')
792
+ _log(f" [GH|F{fi+1}] 💾 Saved phonons_models_v6/phonons_v6_s{seed}_f{fi+1}.pt")
793
+ # ─────────────────────────────────────────────────────
794
+
795
+ print(f" └─ Fold {fi+1} done │ MAE = {fold_maes[fi]:.2f} cm⁻¹")
796
+
797
+ del model, best_model
798
+ if n_gpus > 0: torch.cuda.empty_cache()
799
+
800
+ avg = np.mean(list(fold_maes.values()))
801
+ results[seed] = fold_maes
802
+ elapsed = time.time() - ts
803
+ print(f"\n Seed {seed} │ Avg MAE: {avg:.2f} │ {elapsed/60:.1f} min")
804
+
805
+ finally:
806
+ _dashboard.stop()
807
+
808
+ # ── FINAL RESULTS ─────────────────────────────────────────
809
+ fa = np.mean([np.mean(list(v.values())) for v in results.values()])
810
+
811
+ print(f"""
812
+ {'='*62}
813
+ FINAL RESULTS — V6 Gate-Halt TRM
814
+ {'='*62}
815
+
816
+ {'Model':<45} {'MAE':>10}
817
+ {'─'*57}""")
818
+ for n, v in sorted(BASELINES.items(), key=lambda x: x[1]):
819
+ beaten = ' ← BEATEN!' if fa < v else ''
820
+ print(f" {n:<45} {v:>10.2f}{beaten}")
821
+ print(f" {'V6 Gate-Halt TRM ('+str(n_params//1000)+'K, '+str(MIN_CYCLES)+'-'+str(MAX_CYCLES)+' cycles)':<45} {fa:>10.2f} ← OURS")
822
+ print(f" {'─'*57}")
823
+ print(f" Total time: {(time.time()-t0)/60:.1f} min")
824
+
825
+ # ── SAVE ──────────────────────────────────────────────────
826
+ res = {
827
+ 'model': 'V6-Gate-Halt-TRM', 'params': n_params,
828
+ 'cycles': f'{MIN_CYCLES}-{MAX_CYCLES}',
829
+ 'avg_mae': round(fa, 2),
830
+ 'per_fold': {str(s): {str(k): round(v, 2) for k,v in m.items()}
831
+ for s,m in results.items()},
832
+ }
833
+ with open('phonons_v6_results.json', 'w') as f:
834
+ json.dump(res, f, indent=2)
835
+ print(" Saved: phonons_v6_results.json\n")
836
+
837
+
838
+ if __name__ == '__main__':
839
+ main()
model_code/steels_model.py ADDED
@@ -0,0 +1,1056 @@
1
+ """
2
+ ╔══════════════════════════════════════════════════════════════════════╗
3
+ ║ TRM-MatSci V13 — 2-Layer SA + Multi-Seed Ensemble ║
4
+ ║ Dataset: matbench_steels │ 312 samples │ 5-Fold Nested CV ║
5
+ ║ ║
6
+ ║ V13A 2-Layer Self-Attention + Standard Deep Supervision ║
7
+ ║ d_attn=64, nhead=4, d_hidden=96, ff_dim=150, 20 steps ║
8
+ ║ Expanded features (Magpie + Mat2Vec + Extra descriptors) ║
9
+ ║ 2nd SA layer for higher-order property interactions ║
10
+ ║ 5-seed ensemble (avg predictions across seeds) ║
11
+ ║ ║
12
+ ║ V13B Same 2-Layer SA + Confidence-Weighted Deep Supervision ║
13
+ ║ 22 steps, confidence_head learns which step to trust ║
14
+ ║ 5-seed ensemble (avg predictions across seeds) ║
15
+ ║ ║
16
+ ║ All models: Deep Supervision + SWA + AdamW + 300 epochs ║
17
+ ║ Baseline: V12A = 95.99 MPa (current best) ║
18
+ ╚══════════════════════════════════════════════════════════════════════╝
19
+ """
20
+
21
+ import os, copy, json, time, logging, warnings, shutil, urllib.request
22
+ warnings.filterwarnings('ignore')
23
+
24
+ import numpy as np
25
+ import pandas as pd
26
+
27
+ import matplotlib
28
+ matplotlib.use('Agg')
29
+ import matplotlib.pyplot as plt
30
+ import matplotlib.gridspec as gridspec
31
+
32
+ from tqdm import tqdm
33
+
34
+ import torch
35
+ import torch.nn as nn
36
+ import torch.nn.functional as F
37
+ from torch.utils.data import Dataset, DataLoader
38
+ import torch.optim as optim
39
+ from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
40
+
41
+ from sklearn.model_selection import KFold
42
+ from sklearn.preprocessing import StandardScaler
43
+ from pymatgen.core import Composition
44
+ from matminer.featurizers.composition import ElementProperty
45
+ from gensim.models import Word2Vec
46
+
47
+ logging.basicConfig(level=logging.INFO, format='%(name)s │ %(message)s')
48
+ log = logging.getLogger("TRM13")
49
+
50
+ # Seeds for multi-seed ensemble
51
+ SEEDS = [42, 123, 7, 0, 99]
52
+ N_SEEDS = len(SEEDS)
53
+
54
+ BASELINES = {
55
+ 'TPOT-Mat': 79.9468,
56
+ 'AutoML-Mat': 82.3043,
57
+ 'MODNet': 87.7627,
58
+ 'RF-SCM/Magpie': 103.5125,
59
+ 'V12A (best)': 95.9900,
60
+ 'V11B': 102.3003,
61
+ 'V10A': 103.2867,
62
+ 'CrabNet': 107.3160,
63
+ 'Darwin': 123.2932,
64
+ }
65
+
66
+
67
+ # ══════════════════════════════════════════════════════════════════════
68
+ # 1. FEATURIZER + DATASET
69
+ # ══════════════════════════════════════════════════════════════════════
70
+
71
+ class ExpandedFeaturizer:
72
+ """Magpie (22 props × 6 stats) + Extra matminer descriptors + Mat2Vec (200d).
73
+
74
+ Extra descriptors: ElementFraction, Stoichiometry, ValenceOrbital,
75
+ IonProperty, BandCenter — all concatenated as a flat vector between
76
+ the Magpie block and Mat2Vec.
77
+ """
78
+ GCS = "https://storage.googleapis.com/mat2vec/"
79
+ FILES = ["pretrained_embeddings",
80
+ "pretrained_embeddings.wv.vectors.npy",
81
+ "pretrained_embeddings.trainables.syn1neg.npy"]
82
+
83
+ def __init__(self, cache="mat2vec_cache"):
84
+ from matminer.featurizers.composition import (
85
+ ElementFraction, Stoichiometry, ValenceOrbital,
86
+ IonProperty, BandCenter
87
+ )
88
+ from matminer.featurizers.base import MultipleFeaturizer
89
+
90
+ self.ep_magpie = ElementProperty.from_preset("magpie")
91
+ self.n_mg = len(self.ep_magpie.feature_labels())
92
+
93
+ self.extra_feats = MultipleFeaturizer([
94
+ ElementFraction(),
95
+ Stoichiometry(),
96
+ ValenceOrbital(),
97
+ IonProperty(),
98
+ BandCenter(),
99
+ ])
100
+ self.n_extra = None # detected at featurize time
101
+
102
+ self.scaler = None
103
+ os.makedirs(cache, exist_ok=True)
104
+ for f in self.FILES:
105
+ p = os.path.join(cache, f)
106
+ if not os.path.exists(p):
107
+ log.info(f" Downloading {f}...")
108
+ urllib.request.urlretrieve(self.GCS + f, p)
109
+ self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
110
+ self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}
111
+
112
+ def _pool(self, c):
113
+ v, t = np.zeros(200, np.float32), 0.0
114
+ for s, f in c.get_el_amt_dict().items():
115
+ if s in self.emb: v += f * self.emb[s]; t += f
116
+ return v / max(t, 1e-8)
117
+
118
+ def featurize_all(self, comps):
119
+ out = []
120
+ for c in tqdm(comps, desc=" Featurizing (expanded)", leave=False):
121
+ try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
122
+ except Exception: mg = np.zeros(self.n_mg, np.float32)
123
+
124
+ try:
125
+ ex = np.array(self.extra_feats.featurize(c), np.float32)
126
+ except Exception:
127
+ # 200 is only a guess if the very first composition fails before n_extra is known
+ ex = np.zeros(self.n_extra or 200, np.float32)
128
+
129
+ if self.n_extra is None:
130
+ self.n_extra = len(ex)
131
+ log.info(f"Expanded features: {self.n_mg} Magpie + "
132
+ f"{self.n_extra} Extra + 200 Mat2Vec = "
133
+ f"{self.n_mg + self.n_extra + 200}d")
134
+
135
+ out.append(np.concatenate([
136
+ np.nan_to_num(mg, nan=0.0),
137
+ np.nan_to_num(ex, nan=0.0),
138
+ self._pool(c)
139
+ ]))
140
+ return np.array(out)
141
+
142
+ def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
143
+ def transform(self, X):
144
+ if self.scaler is None: return X
145
+ return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)
146
+
147
+
148
+ class DSData(Dataset):
149
+ def __init__(self, X, y):
150
+ self.X = torch.tensor(X, dtype=torch.float32)
151
+ self.y = torch.tensor(np.array(y, np.float32))
152
+ def __len__(self): return len(self.y)
153
+ def __getitem__(self, i): return self.X[i], self.y[i]
154
+
155
+
156
+ # ══════════════════════════════════════════════════════════════════════
157
+ # 2. MODELS — with 2-Layer Self-Attention
158
+ # ══════════════════════════════════════════════════════════════════════
159
+
160
+ class DeepHybridTRM(nn.Module):
161
+ """V13A: 2-Layer SA Hybrid-TRM with Standard Deep Supervision.
162
+
163
+ Key difference from V12A's HybridTRM:
164
+ - TWO self-attention layers (SA1 → FF1 → SA2 → FF2 → CA)
165
+ - Each SA layer has its own residual + LayerNorm + FF block
166
+ - This enables higher-order property interaction modeling
167
+ (e.g., "correlation between electronegativity-range AND
168
+ atomic-radius-mean" requires composing two rounds of attention)
169
+ - Cross-attention (CA) to Mat2Vec context remains after SA stack
170
+
171
+ Everything else (MLP reasoning loop, deep supervision, SWA)
172
+ is identical to V12A.
173
+ """
174
+ def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
175
+ d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
176
+ dropout=0.2, max_steps=20, **kw):
177
+ super().__init__()
178
+ self.max_steps, self.D = max_steps, d_hidden
179
+ self.n_props, self.stat_dim = n_props, stat_dim
180
+ self.n_extra = n_extra
181
+
182
+ # ── Attention feature extractor (2-Layer SA) ──────────────────
183
+ self.tok_proj = nn.Sequential(
184
+ nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
185
+ self.m2v_proj = nn.Sequential(
186
+ nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
187
+
188
+ # Self-Attention Layer 1
189
+ self.sa1 = nn.MultiheadAttention(
190
+ d_attn, nhead, dropout=dropout, batch_first=True)
191
+ self.sa1_n = nn.LayerNorm(d_attn)
192
+ self.sa1_ff = nn.Sequential(
193
+ nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
194
+ nn.Linear(d_attn*2, d_attn))
195
+ self.sa1_fn = nn.LayerNorm(d_attn)
196
+
197
+ # Self-Attention Layer 2 (NEW — captures higher-order interactions)
198
+ self.sa2 = nn.MultiheadAttention(
199
+ d_attn, nhead, dropout=dropout, batch_first=True)
200
+ self.sa2_n = nn.LayerNorm(d_attn)
201
+ self.sa2_ff = nn.Sequential(
202
+ nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
203
+ nn.Linear(d_attn*2, d_attn))
204
+ self.sa2_fn = nn.LayerNorm(d_attn)
205
+
206
+ # Cross-Attention to Mat2Vec context (after SA stack)
207
+ self.ca = nn.MultiheadAttention(
208
+ d_attn, nhead, dropout=dropout, batch_first=True)
209
+ self.ca_n = nn.LayerNorm(d_attn)
210
+
211
+ # Pool with optional extra feature injection
212
+ pool_in = d_attn + (n_extra if n_extra > 0 else 0)
213
+ self.pool = nn.Sequential(
214
+ nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
215
+
216
+ # MLP-TRM recursive reasoning (shared weights)
217
+ self.z_up = nn.Sequential(
218
+ nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
219
+ nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
220
+ self.y_up = nn.Sequential(
221
+ nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
222
+ nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
223
+ self.head = nn.Linear(d_hidden, 1)
224
+ self._init()
225
+
226
+ def _init(self):
227
+ for m in self.modules():
228
+ if isinstance(m, nn.Linear):
229
+ nn.init.xavier_uniform_(m.weight)
230
+ if m.bias is not None: nn.init.zeros_(m.bias)
231
+
232
+ def _attention(self, x):
233
+ B = x.size(0)
234
+ mg_dim = self.n_props * self.stat_dim
235
+ mg = x[:, :mg_dim]
236
+
237
+ if self.n_extra > 0:
238
+ extra = x[:, mg_dim:mg_dim + self.n_extra]
239
+ m2v = x[:, mg_dim + self.n_extra:]
240
+ else:
241
+ extra = None
242
+ m2v = x[:, mg_dim:]
243
+
244
+ tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
245
+ ctx = self.m2v_proj(m2v).unsqueeze(1)
246
+
247
+ # SA Layer 1: learn pairwise property interactions
248
+ tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
249
+ tok = self.sa1_fn(tok + self.sa1_ff(tok))
250
+
251
+ # SA Layer 2: learn higher-order property interactions
252
+ tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
253
+ tok = self.sa2_fn(tok + self.sa2_ff(tok))
254
+
255
+ # Cross-Attention to Mat2Vec chemistry context
256
+ tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
257
+
258
+ pooled = tok.mean(dim=1) # [B, d_attn]
259
+
260
+ if extra is not None:
261
+ pooled = torch.cat([pooled, extra], dim=-1)
262
+
263
+ return self.pool(pooled) # [B, d_hidden]
264
+
265
+ def forward(self, x, deep_supervision=False, return_trajectory=False):
266
+ B = x.size(0)
267
+ xp = self._attention(x)
268
+ z = torch.zeros(B, self.D, device=x.device)
269
+ y = torch.zeros(B, self.D, device=x.device)
270
+ step_preds = []
271
+ for _ in range(self.max_steps):
272
+ z = z + self.z_up(torch.cat([xp, y, z], -1))
273
+ y = y + self.y_up(torch.cat([y, z], -1))
274
+ step_preds.append(self.head(y).squeeze(1))
275
+ if deep_supervision:
276
+ return step_preds
277
+ elif return_trajectory:
278
+ return step_preds[-1], step_preds
279
+ else:
280
+ return step_preds[-1]
281
+
282
+ def count_parameters(self):
283
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
284
+
285
+
286
+ class DeepConfidenceHybridTRM(nn.Module):
287
+ """V13B: 2-Layer SA Hybrid-TRM with Confidence-Weighted Deep Supervision.
288
+
289
+ Same 2-layer SA feature extractor as DeepHybridTRM, but with:
290
+ - confidence_head that learns which recursion step to trust
291
+ - Final prediction = softmax(confidence) · step_preds
292
+ - No ponder cost (avoids V11C's failure)
293
+ - 22 recursion steps (vs 20 for V13A)
294
+ """
295
+ def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
296
+ d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
297
+ dropout=0.2, max_steps=22, **kw):
298
+ super().__init__()
299
+ self.max_steps, self.D = max_steps, d_hidden
300
+ self.n_props, self.stat_dim = n_props, stat_dim
301
+ self.n_extra = n_extra
302
+
303
+ # ── Attention feature extractor (2-Layer SA) ──────────────────
304
+ self.tok_proj = nn.Sequential(
305
+ nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
306
+ self.m2v_proj = nn.Sequential(
307
+ nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
308
+
309
+ # Self-Attention Layer 1
310
+ self.sa1 = nn.MultiheadAttention(
311
+ d_attn, nhead, dropout=dropout, batch_first=True)
312
+ self.sa1_n = nn.LayerNorm(d_attn)
313
+ self.sa1_ff = nn.Sequential(
314
+ nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
315
+ nn.Linear(d_attn*2, d_attn))
316
+ self.sa1_fn = nn.LayerNorm(d_attn)
317
+
318
+ # Self-Attention Layer 2 (higher-order interactions)
319
+ self.sa2 = nn.MultiheadAttention(
320
+ d_attn, nhead, dropout=dropout, batch_first=True)
321
+ self.sa2_n = nn.LayerNorm(d_attn)
322
+ self.sa2_ff = nn.Sequential(
323
+ nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
324
+ nn.Linear(d_attn*2, d_attn))
325
+ self.sa2_fn = nn.LayerNorm(d_attn)
326
+
327
+ # Cross-Attention to Mat2Vec context
328
+ self.ca = nn.MultiheadAttention(
329
+ d_attn, nhead, dropout=dropout, batch_first=True)
330
+ self.ca_n = nn.LayerNorm(d_attn)
331
+
332
+ # Pool with optional extra feature injection
333
+ pool_in = d_attn + (n_extra if n_extra > 0 else 0)
334
+ self.pool = nn.Sequential(
335
+ nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
336
+
337
+ # MLP-TRM recursive reasoning (shared weights)
338
+ self.z_up = nn.Sequential(
339
+ nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
340
+ nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
341
+ self.y_up = nn.Sequential(
342
+ nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
343
+ nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
344
+ self.head = nn.Linear(d_hidden, 1)
345
+
346
+ # ── Confidence head: learns which step to trust ──────────────
347
+ self.confidence_head = nn.Sequential(
348
+ nn.Linear(d_hidden, d_hidden // 2), nn.GELU(),
349
+ nn.Linear(d_hidden // 2, 1)) # raw logit, softmaxed later
350
+
351
+ self._init()
352
+
353
+ def _init(self):
354
+ for m in self.modules():
355
+ if isinstance(m, nn.Linear):
356
+ nn.init.xavier_uniform_(m.weight)
357
+ if m.bias is not None: nn.init.zeros_(m.bias)
358
+ with torch.no_grad():
359
+ nn.init.zeros_(self.confidence_head[-1].bias)
360
+
361
+ def _attention(self, x):
362
+ B = x.size(0)
363
+ mg_dim = self.n_props * self.stat_dim
364
+ mg = x[:, :mg_dim]
365
+
366
+ if self.n_extra > 0:
367
+ extra = x[:, mg_dim:mg_dim + self.n_extra]
368
+ m2v = x[:, mg_dim + self.n_extra:]
369
+ else:
370
+ extra = None
371
+ m2v = x[:, mg_dim:]
372
+
373
+ tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
374
+ ctx = self.m2v_proj(m2v).unsqueeze(1)
375
+
376
+ # SA Layer 1
377
+ tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
378
+ tok = self.sa1_fn(tok + self.sa1_ff(tok))
379
+
380
+ # SA Layer 2
381
+ tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
382
+ tok = self.sa2_fn(tok + self.sa2_ff(tok))
383
+
384
+ # Cross-Attention
385
+ tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
386
+
387
+ pooled = tok.mean(dim=1)
388
+
389
+ if extra is not None:
390
+ pooled = torch.cat([pooled, extra], dim=-1)
391
+
392
+ return self.pool(pooled)
393
+
394
+ def forward(self, x, deep_supervision=False, return_confidence=False):
395
+ """Forward pass.
396
+
397
+ Returns:
398
+ deep_supervision=True: (step_preds, confidence_logits)
399
+ deep_supervision=False, return_confidence=False:
400
+ weighted_pred: [B] confidence-weighted prediction
401
+ deep_supervision=False, return_confidence=True:
402
+ (weighted_pred, confidence_weights): [B], [B, max_steps]
403
+ """
404
+ B = x.size(0)
405
+ xp = self._attention(x)
406
+ z = torch.zeros(B, self.D, device=x.device)
407
+ y = torch.zeros(B, self.D, device=x.device)
408
+
409
+ step_preds = []
410
+ conf_logits = []
411
+
412
+ for _ in range(self.max_steps):
413
+ z = z + self.z_up(torch.cat([xp, y, z], -1))
414
+ y = y + self.y_up(torch.cat([y, z], -1))
415
+ step_preds.append(self.head(y).squeeze(1))
416
+ conf_logits.append(self.confidence_head(y).squeeze(1))
417
+
418
+ conf_logits = torch.stack(conf_logits, dim=1) # [B, max_steps]
419
+
420
+ if deep_supervision:
421
+ return step_preds, conf_logits
422
+
423
+ # Confidence-weighted prediction
424
+ conf_weights = F.softmax(conf_logits, dim=1) # [B, max_steps]
425
+ preds_stack = torch.stack(step_preds, dim=1) # [B, max_steps]
426
+ weighted_pred = (preds_stack * conf_weights).sum(dim=1) # [B]
427
+
428
+ if return_confidence:
429
+ return weighted_pred, conf_weights
430
+ return weighted_pred
431
+
432
+ def count_parameters(self):
433
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
434
+
435
+
436
+ # ══════════════════════════════════════════════════════════════════════
437
+ # 3. LOSS FUNCTIONS
438
+ # ══════════════════════════════════════════════════════════════════════
439
+
440
+ def deep_supervision_loss(step_preds, targets):
441
+ """Linear-weighted L1 loss across all recursion steps."""
442
+ n = len(step_preds)
443
+ weights = [(i + 1) for i in range(n)]
444
+ total_w = sum(weights)
445
+ loss = 0.0
446
+ for pred, w in zip(step_preds, weights):
447
+ loss += (w / total_w) * F.l1_loss(pred, targets)
448
+ return loss
449
+
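As a sanity check, the linear weighting above can be reproduced with plain Python (a torch-free sketch of the same math, not part of the training script):

```python
# Torch-free sketch of the linear-weighted deep-supervision loss:
# step i (1-indexed) gets weight i, normalized so the weights sum to 1,
# so later recursion steps dominate the loss.
def l1(pred, target):
    return sum(abs(p - t) for p, t in zip(pred, target)) / len(target)

def ds_loss(step_preds, targets):
    n = len(step_preds)
    weights = [i + 1 for i in range(n)]          # 1, 2, ..., n
    total = sum(weights)
    return sum((w / total) * l1(p, targets) for p, w in zip(step_preds, weights))

targets = [1.0, 2.0]
steps = [[2.0, 3.0], [1.5, 2.5], [1.25, 2.25]]   # per-step L1 = 1.0, 0.5, 0.25
print(ds_loss(steps, targets))                    # (1*1.0 + 2*0.5 + 3*0.25)/6 ≈ 0.4583
```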
450
+
451
+ def confidence_ds_loss(step_preds, targets, conf_logits):
452
+ """Advanced Deep Supervision: standard DS + confidence-weighted L1.
453
+
454
+ Components:
455
+ 1. Standard linear-weighted deep supervision on all steps
456
+ 2. L1 loss on the confidence-weighted final prediction
457
+ """
458
+ ds = deep_supervision_loss(step_preds, targets)
459
+
460
+ conf_weights = F.softmax(conf_logits, dim=1) # [B, max_steps]
461
+ preds_stack = torch.stack(step_preds, dim=1) # [B, max_steps]
462
+ weighted_pred = (preds_stack * conf_weights).sum(dim=1)
463
+ conf_loss = F.l1_loss(weighted_pred, targets)
464
+
465
+ return ds + conf_loss
466
+
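The confidence-weighted prediction that the second loss term supervises can be worked through on toy numbers (a standalone sketch with made-up logits, for one sample):

```python
import math

# Torch-free sketch of the confidence-weighted prediction: softmax over
# per-step confidence logits, then a weighted sum of the step predictions.
def softmax(xs):
    m = max(xs)
    es = [math.exp(x - m) for x in xs]
    s = sum(es)
    return [e / s for e in es]

step_preds = [900.0, 1010.0, 1000.0]   # one sample, three recursion steps (MPa)
conf_logits = [-2.0, 1.0, 3.0]         # model trusts the later steps more
w = softmax(conf_logits)
weighted_pred = sum(p * wi for p, wi in zip(step_preds, w))
print(round(weighted_pred, 1))         # dominated by the two trusted steps
```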
467
+
468
+ # ══════════════════════════════════════════════════════════════════════
469
+ # 4. UTILS + TRAINING
470
+ # ══════════════════════════════════════════════════════════════════════
471
+
472
+ def strat_split(targets, val_size=0.15, seed=42):
473
+ bins = np.percentile(targets, [25, 50, 75])
474
+ lbl = np.digitize(targets, bins)
475
+ tr, vl = [], []
476
+ rng = np.random.RandomState(seed)
477
+ for b in range(4):
478
+ m = np.where(lbl == b)[0]
479
+ if len(m) == 0: continue
480
+ n = max(1, int(len(m) * val_size))
481
+ c = rng.choice(m, n, replace=False)
482
+ vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
483
+ return np.array(tr), np.array(vl)
484
+
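Usage sketch for the quartile-stratified split above, on a synthetic skewed target (the function body is copied from the script; the lognormal data is a stand-in):

```python
import numpy as np

# Quartile-stratified train/val split: bin targets at the 25/50/75th
# percentiles, then sample ~val_size of each bin into validation.
def strat_split(targets, val_size=0.15, seed=42):
    bins = np.percentile(targets, [25, 50, 75])
    lbl = np.digitize(targets, bins)              # 4 quartile bins: 0..3
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(lbl == b)[0]
        if len(m) == 0:
            continue
        n = max(1, int(len(m) * val_size))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist())
        tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)

y = np.random.RandomState(0).lognormal(size=200)  # synthetic skewed target
tr, vl = strat_split(y, val_size=0.15)
print(len(tr), len(vl))  # 172 28 — each 50-sample quartile gives int(50*0.15)=7 val
```

Because the validation fraction is drawn per bin, every quartile of the target distribution is represented in both splits even when the target is heavily skewed.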
485
+
486
+ def train_fold_standard(model, tr_dl, vl_dl, device,
487
+ epochs=300, swa_start=200, fold=1, name=""):
488
+ """Training with standard deep supervision."""
489
+ opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
490
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
491
+ swa_m = AveragedModel(model)
492
+ swa_s = SWALR(opt, swa_lr=5e-4)
493
+ swa_on = False
494
+ best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
495
+ hist = {'train': [], 'val': []}
496
+
497
+ pbar = tqdm(range(epochs), desc=f" [{name}] F{fold}/5",
498
+ leave=False, ncols=120)
499
+ for ep in pbar:
500
+ model.train(); tl = 0.0
501
+ for bx, by in tr_dl:
502
+ bx, by = bx.to(device), by.to(device)
503
+ step_preds = model(bx, deep_supervision=True)
504
+ loss = deep_supervision_loss(step_preds, by)
505
+ opt.zero_grad(set_to_none=True); loss.backward()
506
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
507
+ opt.step()
508
+ tl += F.l1_loss(step_preds[-1], by).item() * len(by)
509
+ tl /= len(tr_dl.dataset)
510
+
511
+ model.eval(); vl = 0.0
512
+ with torch.no_grad():
513
+ for bx, by in vl_dl:
514
+ bx, by = bx.to(device), by.to(device)
515
+ pred = model(bx)
516
+ vl += F.l1_loss(pred, by).item() * len(by)
517
+ vl /= len(vl_dl.dataset)
518
+ hist['train'].append(tl); hist['val'].append(vl)
519
+
520
+ if ep < swa_start:
521
+ sch.step()
522
+ if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
523
+ else:
524
+ if not swa_on: swa_on = True
525
+ swa_m.update_parameters(model); swa_s.step()
526
+
527
+ pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
528
+ Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')
529
+
530
+ if swa_on:
531
+ update_bn(tr_dl, swa_m, device=device)
532
+ model.load_state_dict(swa_m.module.state_dict())
533
+ else:
534
+ model.load_state_dict(best_w)
535
+ return best_v, model, hist
536
+
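The two-phase learning-rate schedule used above (cosine annealing for the first `swa_start` epochs, then SWA) can be sketched as a closed-form function. This is a simplification: `SWALR` actually anneals toward `swa_lr` over a few epochs rather than jumping to it.

```python
import math

# Sketch of the schedule in train_fold_standard: cosine annealing from
# lr=1e-3 down to eta_min=1e-4 over swa_start epochs, then a (roughly)
# constant SWA learning rate of 5e-4.
def lr_at(epoch, lr0=1e-3, eta_min=1e-4, swa_start=200, swa_lr=5e-4):
    if epoch < swa_start:
        # CosineAnnealingLR with T_max=swa_start
        return eta_min + 0.5 * (lr0 - eta_min) * (1 + math.cos(math.pi * epoch / swa_start))
    return swa_lr

print(lr_at(0))    # 0.001   (start of cosine phase)
print(lr_at(100))  # 0.00055 (halfway: midpoint of lr0 and eta_min)
print(lr_at(250))  # 0.0005  (SWA phase)
```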
537
+
538
+ def train_fold_confidence(model, tr_dl, vl_dl, device,
539
+ epochs=300, swa_start=200, fold=1, name=""):
540
+ """Training with confidence-weighted deep supervision."""
541
+ opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
542
+ sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
543
+ swa_m = AveragedModel(model)
544
+ swa_s = SWALR(opt, swa_lr=5e-4)
545
+ swa_on = False
546
+ best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
547
+ hist = {'train': [], 'val': []}
548
+
549
+ pbar = tqdm(range(epochs), desc=f" [{name}] F{fold}/5",
550
+ leave=False, ncols=120)
551
+ for ep in pbar:
552
+ model.train(); tl = 0.0
553
+ for bx, by in tr_dl:
554
+ bx, by = bx.to(device), by.to(device)
555
+ step_preds, conf_logits = model(bx, deep_supervision=True)
556
+ loss = confidence_ds_loss(step_preds, by, conf_logits)
557
+ opt.zero_grad(set_to_none=True); loss.backward()
558
+ torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
559
+ opt.step()
560
+ # Track confidence-weighted MAE for display
561
+ with torch.no_grad():
562
+ cw = F.softmax(conf_logits, dim=1)
563
+ ps = torch.stack(step_preds, dim=1)
564
+ wp = (ps * cw).sum(dim=1)
565
+ tl += F.l1_loss(wp, by).item() * len(by)
566
+ tl /= len(tr_dl.dataset)
567
+
568
+ model.eval(); vl = 0.0
569
+ with torch.no_grad():
570
+ for bx, by in vl_dl:
571
+ bx, by = bx.to(device), by.to(device)
572
+ pred = model(bx)  # confidence-weighted prediction by default
573
+ vl += F.l1_loss(pred, by).item() * len(by)
574
+ vl /= len(vl_dl.dataset)
575
+ hist['train'].append(tl); hist['val'].append(vl)
576
+
577
+ if ep < swa_start:
578
+ sch.step()
579
+ if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
580
+ else:
581
+ if not swa_on: swa_on = True
582
+ swa_m.update_parameters(model); swa_s.step()
583
+
584
+ pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
585
+ Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')
586
+
587
+ if swa_on:
588
+ update_bn(tr_dl, swa_m, device=device)
589
+ model.load_state_dict(swa_m.module.state_dict())
590
+ else:
591
+ model.load_state_dict(best_w)
592
+ return best_v, model, hist
593
+
594
+
595
+ def predict(model, dl, device):
596
+ model.eval(); preds = []
597
+ with torch.no_grad():
598
+ for bx, _ in dl:
599
+ preds.append(model(bx.to(device)).cpu())
600
+ return torch.cat(preds)
601
+
602
+
603
+ def predict_confidence(model, dl, device):
604
+ """Predict using confidence model, also return per-step weights."""
605
+ model.eval()
606
+ all_preds, all_weights = [], []
607
+ with torch.no_grad():
608
+ for bx, _ in dl:
609
+ pred, weights = model(bx.to(device), return_confidence=True)
610
+ all_preds.append(pred.cpu())
611
+ all_weights.append(weights.cpu())
612
+ return torch.cat(all_preds), torch.cat(all_weights)
613
+
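The per-step weights returned here feed the "avg peak step" statistic logged during training. A torch-free sketch of that summary, on made-up weights for two samples:

```python
# Given [n_samples, max_steps] softmax confidence weights, the "peak step"
# is the per-sample argmax (reported 1-indexed in the logs), and the
# average peak step summarizes where the model trusts its predictions.
weights = [
    [0.05, 0.15, 0.80],   # this sample trusts step 3
    [0.10, 0.60, 0.30],   # this sample trusts step 2
]
peaks = [max(range(len(w)), key=w.__getitem__) + 1 for w in weights]
avg_peak = sum(peaks) / len(peaks)
print(peaks, avg_peak)  # [3, 2] 2.5
```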
614
+
615
+ def get_targets(dl):
616
+ tgts = []
617
+ for _, by in dl: tgts.append(by)
618
+ return torch.cat(tgts)
619
+
620
+
621
+ # ══════════════════════════════════════════════════════════════════════
622
+ # 5. MAIN BENCHMARK — Multi-Seed Ensemble
623
+ # ══════════════════════════════════════════════════════════════════════
624
+
625
+ def run_benchmark():
626
+ t0 = time.time()
627
+ print("\n" + "═"*72)
628
+ print(" TRM-MatSci V13 │ 2-Layer SA + Multi-Seed Ensemble │ matbench_steels")
629
+ print(" V13A: 2-Layer SA + expanded features + standard DS (5-seed ensemble)")
630
+ print(" V13B: 2-Layer SA + expanded features + confidence DS (5-seed ensemble)")
631
+ print(f" Seeds: {SEEDS}")
632
+ print("═"*72 + "\n")
633
+
634
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
635
+ if device.type == 'cuda':
636
+ log.info(f"GPU: {torch.cuda.get_device_name(0)} "
637
+ f"({torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB)")
638
+ torch.backends.cuda.matmul.allow_tf32 = True
639
+ torch.backends.cudnn.benchmark = True
640
+
641
+ log.info("Loading matbench_steels...")
642
+ from matminer.datasets import load_dataset
643
+ df = load_dataset("matbench_steels")
644
+ comps_raw = df['composition'].tolist()
645
+ targets_all = np.array(df['yield strength'].tolist(), np.float32)
646
+ comps_all = [Composition(c) for c in comps_raw]
647
+
648
+ # ── FEATURIZE ─────────────────────────────────────────────────────
649
+ log.info("Computing EXPANDED features...")
650
+ feat = ExpandedFeaturizer()
651
+ X_all = feat.featurize_all(comps_all)
652
+ n_extra = feat.n_extra
653
+ log.info(f"Features: {X_all.shape} (n_extra={n_extra})")
654
+
655
+ kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
656
+ folds = list(kfold.split(comps_all))
657
+ os.makedirs('trm_models_v13', exist_ok=True)
658
+ dl_kw = dict(batch_size=32, num_workers=0)
659
+
660
+ # ── CONFIGS ───────────────────────────────────────────────────────
661
+ shared_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
662
+ mat2vec_dim=200, d_attn=64, nhead=4,
663
+ d_hidden=96, ff_dim=150, dropout=0.2)
664
+
665
+ configs = {
666
+ 'V13A-2xSA-StdDS': {
667
+ 'model_cls': DeepHybridTRM,
668
+ 'model_kw': {**shared_kw, 'max_steps': 20},
669
+ 'train_fn': train_fold_standard,
670
+ 'predict_fn': predict,
671
+ 'is_confidence': False,
672
+ },
673
+ 'V13B-2xSA-ConfDS': {
674
+ 'model_cls': DeepConfidenceHybridTRM,
675
+ 'model_kw': {**shared_kw, 'max_steps': 22},
676
+ 'train_fn': train_fold_confidence,
677
+ 'predict_fn': None, # uses predict_confidence
678
+ 'is_confidence': True,
679
+ },
680
+ }
681
+
682
+ # Print param counts
683
+ print(f"\n {'Config':<24} {'Params':>10} {'Steps':>8} {'Seeds':>6}")
684
+ print(f" {'─'*54}")
685
+ for cname, cfg in configs.items():
686
+ _m = cfg['model_cls'](**cfg['model_kw'])
687
+ np_ = _m.count_parameters(); del _m
688
+ cfg['n_params'] = np_
689
+ steps = cfg['model_kw']['max_steps']
690
+ print(f" {cname:<24} {np_:>10,} {steps:>8} {N_SEEDS:>6}")
691
+ print()
692
+
693
+ # ── TRAIN + EVALUATE (Multi-Seed) ─────────────────────────────────
694
+ all_results = {}
695
+ all_hists = {}
696
+ all_conf_weights = {}
697
+
698
+ for cname, cfg in configs.items():
699
+ print(f"\n{'▓'*72}")
700
+ print(f" {cname} — {N_SEEDS}-Seed Ensemble")
701
+ print(f"{'▓'*72}")
702
+
703
+ # Store per-seed, per-fold predictions and MAEs
704
+ seed_fold_preds = {s: {} for s in SEEDS} # seed -> {fold_idx: preds_tensor}
705
+ seed_fold_maes = {s: [] for s in SEEDS} # seed -> [mae_f1, ..., mae_f5]
706
+ fold_hists = [] # collect from first seed only
707
+ fold_conf_w = [] # collect from first seed only
708
+
709
+ for si, seed in enumerate(SEEDS):
710
+ print(f"\n ╔═══ Seed {seed} ({si+1}/{N_SEEDS}) ═══╗")
711
+
712
+ for fi, (tv_i, te_i) in enumerate(folds):
713
+ print(f"\n ── [{cname} seed={seed}] Fold {fi+1}/5 {'─'*30}")
714
+
715
+ tri, vli = strat_split(targets_all[tv_i], 0.15, seed+fi)
716
+ feat.fit_scaler(X_all[tv_i][tri])
717
+ tr_s = feat.transform(X_all[tv_i][tri])
718
+ vl_s = feat.transform(X_all[tv_i][vli])
719
+ te_s = feat.transform(X_all[te_i])
720
+
721
+ pin = device.type == 'cuda'
722
+ tr_dl = DataLoader(DSData(tr_s, targets_all[tv_i][tri]), shuffle=True,
723
+ pin_memory=pin, **dl_kw)
724
+ vl_dl = DataLoader(DSData(vl_s, targets_all[tv_i][vli]), shuffle=False,
725
+ pin_memory=pin, **dl_kw)
726
+ te_dl = DataLoader(DSData(te_s, targets_all[te_i]), shuffle=False,
727
+ pin_memory=pin, **dl_kw)
728
+ te_tgt = get_targets(te_dl)
729
+
730
+ torch.manual_seed(seed + fi); np.random.seed(seed + fi)
731
+ if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)
732
+
733
+ model = cfg['model_cls'](**cfg['model_kw']).to(device)
734
+ bv, model, hist = cfg['train_fn'](model, tr_dl, vl_dl, device,
735
+ fold=fi+1,
736
+ name=f"{cname}[s{seed}]")
737
+
738
+ # Save hist only for first seed
739
+ if si == 0:
740
+ fold_hists.append(hist)
741
+
742
+ # Predict
743
+ if cfg['is_confidence']:
744
+ pred, conf_w = predict_confidence(model, te_dl, device)
745
+ if si == 0:
746
+ fold_conf_w.append(conf_w)
747
+ avg_peak = conf_w.argmax(dim=1).float().mean().item() + 1
748
+ mae = F.l1_loss(pred, te_tgt).item()
749
+ log.info(f" [s{seed}] F{fi+1}: MAE={mae:.2f} "
750
+ f"(val {bv:.2f}, avg peak step={avg_peak:.1f})")
751
+ else:
752
+ pred = cfg['predict_fn'](model, te_dl, device)
753
+ mae = F.l1_loss(pred, te_tgt).item()
754
+ log.info(f" [s{seed}] F{fi+1}: MAE={mae:.2f} (val {bv:.2f})")
755
+
756
+ seed_fold_preds[seed][fi] = pred
757
+ seed_fold_maes[seed].append(mae)
758
+
759
+ torch.save({'model_state': model.state_dict(), 'test_mae': mae,
760
+ 'config': cname, 'seed': seed},
761
+ f'trm_models_v13/{cname}_seed{seed}_fold{fi+1}.pt')
762
+
763
+ # Free GPU memory
764
+ del model
+ if device.type == 'cuda': torch.cuda.empty_cache()
765
+
766
+ seed_avg = float(np.mean(seed_fold_maes[seed]))
767
+ print(f" ╚═══ Seed {seed} avg: {seed_avg:.2f} MPa ═══╝")
768
+
769
+ # ── Compute ensemble predictions ──────────────────────────────
770
+ ensemble_fold_maes = []
771
+ for fi, (tv_i, te_i) in enumerate(folds):
772
+ te_tgt_np = targets_all[te_i]
773
+ te_tgt_t = torch.tensor(te_tgt_np, dtype=torch.float32)
774
+
775
+ # Average predictions across all seeds for this fold
776
+ all_seed_preds = torch.stack([seed_fold_preds[s][fi] for s in SEEDS])
777
+ ensemble_pred = all_seed_preds.mean(dim=0)
778
+
779
+ ens_mae = F.l1_loss(ensemble_pred, te_tgt_t).item()
780
+ ensemble_fold_maes.append(ens_mae)
781
+
782
+ ens_avg = float(np.mean(ensemble_fold_maes))
783
+ ens_std = float(np.std(ensemble_fold_maes))
784
+
785
+ # Also compute per-seed averages for reporting
786
+ per_seed_avgs = {s: float(np.mean(seed_fold_maes[s])) for s in SEEDS}
787
+ best_single_seed = min(per_seed_avgs.items(), key=lambda x: x[1])
788
+
789
+ all_results[cname] = {
790
+ 'avg': ens_avg, 'std': ens_std, 'folds': ensemble_fold_maes,
791
+ 'params': cfg['n_params'],
792
+ 'per_seed_avgs': per_seed_avgs,
793
+ 'per_seed_folds': {str(s): seed_fold_maes[s] for s in SEEDS},
794
+ 'best_single_seed': best_single_seed[0],
795
+ 'best_single_mae': best_single_seed[1],
796
+ }
797
+ all_hists[cname] = fold_hists
798
+ if fold_conf_w:
799
+ all_conf_weights[cname] = fold_conf_w
800
+
801
+ print(f"\n ═══ {cname} ═══")
802
+ print(f" Ensemble ({N_SEEDS}-seed avg): {ens_avg:.4f} ±{ens_std:.4f} MPa")
803
+ print(f" Best single seed ({best_single_seed[0]}): "
804
+ f"{best_single_seed[1]:.4f} MPa")
805
+ for s in SEEDS:
806
+ print(f" Seed {s:>3}: {per_seed_avgs[s]:.2f} MPa "
807
+ f"folds={[f'{m:.1f}' for m in seed_fold_maes[s]]}")
808
+
809
+ # ══════════════════════════════════════════════════════════════════
810
+ # FINAL RESULTS
811
+ # ══════════════════════════════════════════════════════════════════
812
+
813
+ tt = time.time() - t0
814
+ print(f"\n{'═'*72}")
815
+ print(f" FINAL LEADERBOARD — matbench_steels V13 (5-Fold Avg MAE)")
816
+ print(f"{'═'*72}")
817
+ print(f" {'Model':<26} {'Params':>10} {'MAE(MPa)':>10} {'±Std':>8} Notes")
818
+ print(f" {'─'*72}")
819
+ for n, r in sorted(all_results.items(), key=lambda x: x[1]['avg']):
820
+ tag = (" ← BEATS MODNet 🏆" if r['avg'] < 87.76 else
821
+ " ← BEATS V12A ✓" if r['avg'] < 95.99 else
822
+ " ← BEATS RF-SCM ✓" if r['avg'] < 103.51 else
823
+ " ← BEATS DARWIN ✓" if r['avg'] < 123.29 else "")
824
+ print(f" {n+' (ens)':<26} {r['params']:>9,} "
825
+ f"{r['avg']:>10.4f} {r['std']:>8.4f}{tag}")
826
+ print(f" {n+' (best 1)':<26} {'':>10} "
827
+ f"{r['best_single_mae']:>10.4f} {'':>8} seed={r['best_single_seed']}")
828
+ print(f" {'─'*72}")
829
+ for bn, bv in sorted(BASELINES.items(), key=lambda x: x[1]):
830
+ print(f" {bn:<26} {'baseline':>10} {bv:>10.4f}")
831
+ print(f"\n Total time: {tt/60:.1f} minutes ({N_SEEDS} seeds × 2 configs × 5 folds)")
832
+
833
+ # Per-fold ensemble breakdown
834
+ print(f"\n{'═'*72}")
835
+ print(f" PER-FOLD ENSEMBLE BREAKDOWN")
836
+ print(f"{'═'*72}")
837
+ cnames = list(all_results.keys())
838
+ header = f" {'Fold':<6}"
839
+ for cn in cnames:
840
+ header += f" {cn:>20}"
841
+ print(header)
842
+ print(f" {'─'*52}")
843
+ for fi in range(5):
844
+ row = f" {fi+1:<6}"
845
+ for cn in cnames:
846
+ row += f" {all_results[cn]['folds'][fi]:>20.2f}"
847
+ print(row)
848
+
849
+ # Per-seed breakdown
850
+ print(f"\n{'═'*72}")
851
+ print(f" PER-SEED BREAKDOWN")
852
+ print(f"{'═'*72}")
853
+ for cn in cnames:
854
+ r = all_results[cn]
855
+ print(f"\n {cn}:")
856
+ header = f" {'Seed':<6}"
857
+ for fi in range(5):
858
+ header += f" {'F'+str(fi+1):>8}"
859
+ header += f" {'Avg':>8}"
860
+ print(header)
861
+ print(f" {'─'*52}")
862
+ for s in SEEDS:
863
+ row = f" {s:<6}"
864
+ for mae in r['per_seed_folds'][str(s)]:
865
+ row += f" {mae:>8.2f}"
866
+ row += f" {r['per_seed_avgs'][s]:>8.2f}"
867
+ print(row)
868
+ print(f" {'─'*52}")
869
+ row = f" {'ENS':<6}"
870
+ for mae in r['folds']:
871
+ row += f" {mae:>8.2f}"
872
+ row += f" {r['avg']:>8.2f}"
873
+ print(row)
874
+
875
+ # Confidence stats
876
+ if all_conf_weights:
877
+ print(f"\n Confidence Step Selection Summary:")
878
+ for cn, fw_list in all_conf_weights.items():
879
+ all_w = torch.cat(fw_list, dim=0)
880
+ avg_w = all_w.mean(dim=0)
881
+ peak_step = avg_w.argmax().item() + 1
882
+ avg_peak = all_w.argmax(dim=1).float().mean().item() + 1
883
+ print(f" {cn}: avg peak step={avg_peak:.1f}, "
884
+ f"population peak=step {peak_step}")
885
+ print()
886
+
887
+ generate_plots(all_results, all_hists, all_conf_weights)
888
+ save_summary(all_results, all_hists, all_conf_weights, tt)
889
+ return all_results
890
+
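Why seed-averaging in `run_benchmark` tends to beat the best single seed: averaging several noisy predictors before scoring cancels independent error, so the ensemble MAE is lower than the mean per-seed MAE. A synthetic toy (not matbench data):

```python
import random

# Synthetic demonstration of seed-ensembling: each "seed" predicts the
# target plus independent zero-mean noise; averaging the predictions
# shrinks the noise by roughly 1/sqrt(n_seeds).
random.seed(0)
target = [100.0] * 1000
seeds = 5
preds = [[100.0 + random.gauss(0, 10) for _ in range(1000)] for _ in range(seeds)]

def mae(p, t):
    return sum(abs(a - b) for a, b in zip(p, t)) / len(t)

per_seed = sum(mae(p, target) for p in preds) / seeds
ens = [sum(col) / seeds for col in zip(*preds)]   # average across seeds
print(round(per_seed, 2), round(mae(ens, target), 2))  # ensemble MAE is markedly lower
```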
891
+
892
+ # ══════════════════════════════════════════════════════════════════════
893
+ # 6. PLOTS
894
+ # ══════════════════════════════════════════════════════════════════════
895
+
896
+ PAL = {'V13A-2xSA-StdDS': '#1565C0', 'V13B-2xSA-ConfDS': '#E65100'}
897
+
898
+ def generate_plots(all_results, all_hists, all_conf_weights):
899
+ names = list(all_results.keys())
900
+ avgs = [all_results[n]['avg'] for n in names]
901
+ stds = [all_results[n]['std'] for n in names]
902
+ cols = [PAL.get(n, '#888') for n in names]
903
+
904
+ fig = plt.figure(figsize=(22, 18))
905
+ gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.35, wspace=0.30)
906
+
907
+ # ── Plot 1: Bar chart vs baselines ────────────────────────────────
908
+ ax1 = fig.add_subplot(gs[0, 0])
909
+
910
+ # Show both ensemble and best-single-seed bars
911
+ x_pos = np.arange(len(names))
912
+ w = 0.35
913
+ ens_bars = ax1.bar(x_pos - w/2, avgs, w, yerr=stds, capsize=6,
914
+ color=cols, alpha=0.88, edgecolor='white',
915
+ linewidth=1.5, label='Ensemble')
916
+ best_singles = [all_results[n]['best_single_mae'] for n in names]
917
+ single_bars = ax1.bar(x_pos + w/2, best_singles, w, capsize=6,
918
+ color=cols, alpha=0.45, edgecolor='white',
919
+ linewidth=1.5, label='Best Single Seed',
920
+ hatch='//')
921
+
922
+ for bv, c, ls, lb in [
923
+ (87.76, '#F57F17', '--', 'MODNet (87.76)'),
924
+ (95.99, '#4CAF50', '-.', 'V12A (95.99)'),
925
+ (102.30, '#9E9E9E', '-.', 'V11B (102.30)'),
926
+ (103.51, '#B0BEC5', ':', 'RF-SCM (103.51)'),
927
+ (107.32, '#FF9800', ':', 'CrabNet (107.32)'),
928
+ ]:
929
+ ax1.axhline(bv, color=c, linestyle=ls, linewidth=1.8, label=lb, alpha=0.85)
930
+ for bar, m, s in zip(ens_bars, avgs, stds):
931
+ ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+s+1,
932
+ f'{m:.1f}', ha='center', fontsize=11, fontweight='bold')
933
+ for bar, m in zip(single_bars, best_singles):
934
+ ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+1,
935
+ f'{m:.1f}', ha='center', fontsize=9, fontstyle='italic',
936
+ alpha=0.7)
937
+
938
+ ax1.set_xticks(x_pos)
939
+ ax1.set_xticklabels(names, fontsize=8)
940
+ ax1.legend(fontsize=6, loc='upper right')
941
+ ax1.set_ylabel('MAE (MPa)'); ax1.set_ylim(0, max(avgs)*1.6)
942
+ ax1.set_title('V13 Results vs Baselines (Ensemble + Best Single)',
943
+ fontsize=11, fontweight='bold')
944
+ ax1.grid(axis='y', alpha=0.3)
945
+
946
+ # ── Plot 2: Per-fold grouped bars ─────────────────────────────────
947
+ ax2 = fig.add_subplot(gs[0, 1])
948
+ x = np.arange(1, 6)
949
+ w = 0.35
950
+ for i, (n, col) in enumerate(zip(names, cols)):
951
+ fold_vals = all_results[n]['folds']
952
+ ax2.bar(x + (i - 0.5) * w, fold_vals, w, color=col, alpha=0.8,
953
+ label=n + ' (ens)', edgecolor='white')
954
+ ax2.axhline(95.99, color='#4CAF50', ls='-.', lw=1.5, label='V12A (95.99)')
955
+ ax2.axhline(87.76, color='#F57F17', ls='--', lw=1.5, label='MODNet (87.76)')
956
+ ax2.set_xlabel('Fold'); ax2.set_ylabel('MAE (MPa)')
957
+ ax2.set_xticks(x); ax2.set_xticklabels([f'F{i}' for i in range(1,6)])
958
+ ax2.set_title('Per-Fold Ensemble Breakdown', fontweight='bold')
959
+ ax2.legend(fontsize=7); ax2.grid(axis='y', alpha=0.2)
960
+
961
+ # ── Plot 3: Training/Val loss curves ──────────────────────────────
962
+ ax3 = fig.add_subplot(gs[1, 0])
963
+ for cname, col in PAL.items():
964
+ if cname not in all_hists: continue
965
+ for fi, h in enumerate(all_hists[cname]):
966
+ lb_tr = f'{cname} train' if fi == 0 else None
967
+ lb_vl = f'{cname} val' if fi == 0 else None
968
+ ax3.plot(h['train'], alpha=0.3, lw=0.8, color=col, label=lb_tr)
969
+ ax3.plot(h['val'], alpha=0.7, lw=1.2, color=col, label=lb_vl,
970
+ linestyle='--')
971
+ ax3.axhline(95.99, color='#4CAF50', ls='-.', lw=1.2, label='V12A (95.99)')
972
+ ax3.axvline(200, color='#4CAF50', ls='--', lw=1.2, alpha=0.6, label='SWA start')
973
+ ax3.set_xlabel('Epoch'); ax3.set_ylabel('MAE (MPa)')
974
+ ax3.set_title('Training Curves (seed 0, all folds)', fontweight='bold')
975
+ ax3.legend(fontsize=6, ncol=2); ax3.grid(alpha=0.2)
976
+ ax3.set_ylim(0, 300)
977
+
978
+ # ── Plot 4: Per-seed scatter / Confidence ─────────────────────────
979
+ ax4 = fig.add_subplot(gs[1, 1])
980
+ if all_conf_weights:
981
+ for cn, fw_list in all_conf_weights.items():
982
+ all_w = torch.cat(fw_list, dim=0)
983
+ avg_w = all_w.mean(dim=0).numpy()
984
+ steps = np.arange(1, len(avg_w)+1)
985
+ ax4.bar(steps, avg_w, color=PAL.get(cn, '#E65100'), alpha=0.8,
986
+ label=f'{cn} avg confidence', edgecolor='white')
987
+ std_w = all_w.std(dim=0).numpy()
988
+ ax4.errorbar(steps, avg_w, yerr=std_w, fmt='none',
989
+ ecolor='#333', capsize=2, alpha=0.5)
990
+ ax4.set_xlabel('Recursion Step')
991
+ ax4.set_ylabel('Confidence Weight (softmax)')
992
+ ax4.set_title('V13B: Where the Model Trusts Its Predictions',
993
+ fontweight='bold')
994
+ ax4.legend(fontsize=8)
995
+ ax4.grid(axis='y', alpha=0.2)
996
+ else:
997
+ # Show per-seed MAE scatter if no confidence model
998
+ for i, (cn, col) in enumerate(zip(names, cols)):
999
+ r = all_results[cn]
1000
+ seed_avgs = [r['per_seed_avgs'][s] for s in SEEDS]
1001
+ ax4.scatter(SEEDS, seed_avgs, s=80, c=col, alpha=0.8,
1002
+ label=f'{cn} per-seed', zorder=5,
1003
+ edgecolors='white', linewidth=1)
1004
+ ax4.axhline(r['avg'], color=col, ls='--', lw=1.5, alpha=0.6,
1005
+ label=f'{cn} ensemble={r["avg"]:.2f}')
1006
+ ax4.axhline(95.99, color='#4CAF50', ls=':', lw=1, alpha=0.5, label='V12A')
1007
+ ax4.set_xlabel('Random Seed')
1008
+ ax4.set_ylabel('5-Fold Avg MAE (MPa)')
1009
+ ax4.set_title('Per-Seed vs Ensemble Performance', fontweight='bold')
1010
+ ax4.legend(fontsize=7); ax4.grid(alpha=0.2)
1011
+
1012
+ fig.suptitle('TRM-MatSci V13 │ 2-Layer SA + Multi-Seed Ensemble │ matbench_steels',
1013
+ fontsize=14, fontweight='bold', y=1.01)
1014
+ fig.savefig('trm_results_v13.png', dpi=150, bbox_inches='tight')
1015
+ plt.close(fig); log.info("✓ Saved: trm_results_v13.png")
1016
+
1017
+
1018
+ def save_summary(all_results, all_hists, all_conf_weights, total_s):
1019
+ # Prepare confidence info
1020
+ conf_info = {}
1021
+ for cn, fw_list in all_conf_weights.items():
1022
+ all_w = torch.cat(fw_list, dim=0)
1023
+ conf_info[cn] = {
1024
+ 'avg_weights': all_w.mean(dim=0).numpy().round(4).tolist(),
1025
+ 'avg_peak_step': float(all_w.argmax(dim=1).float().mean().item() + 1),
1026
+ }
1027
+
1028
+ s = {
1029
+ 'version': 'V13', 'task': 'matbench_steels',
1030
+ 'strategy': '2-Layer SA + Multi-Seed Ensemble',
1031
+ 'seeds': SEEDS,
1032
+ 'n_seeds': N_SEEDS,
1033
+ 'total_min': round(total_s/60, 1),
1034
+ 'models': {},
1035
+ 'confidence': conf_info,
1036
+ }
1037
+ for n, r in all_results.items():
1038
+ s['models'][n] = {
1039
+ 'ensemble_avg': round(r['avg'], 4),
1040
+ 'ensemble_std': round(r['std'], 4),
1041
+ 'ensemble_folds': [round(x, 4) for x in r['folds']],
1042
+ 'params': r['params'],
1043
+ 'best_single_seed': r['best_single_seed'],
1044
+ 'best_single_mae': round(r['best_single_mae'], 4),
1045
+ 'per_seed_avgs': {str(k): round(v, 4) for k, v in r['per_seed_avgs'].items()},
1046
+ }
1047
+
1048
+ with open('trm_models_v13/summary_v13.json', 'w') as f:
1049
+ json.dump(s, f, indent=2, default=str)
1050
+ log.info("✓ Saved: summary_v13.json")
1051
+
1052
+
1053
+ if __name__ == '__main__':
1054
+ results = run_benchmark()
1055
+ shutil.make_archive("trm_v13_all", "zip", "trm_models_v13")
1056
+ log.info("✓ Created trm_v13_all.zip")
requirements.txt ADDED
@@ -0,0 +1,10 @@
1
+ torch>=2.0
2
+ pymatgen>=2024.1.1
3
+ matminer>=0.9.0
4
+ gensim>=4.0.0
5
+ scikit-learn>=1.3.0
6
+ numpy>=1.24.0
7
+ pandas>=2.0.0
8
+ tqdm>=4.65.0
9
+ huggingface_hub>=0.20.0
10
+ gradio>=4.0.0
weights/README.md ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d2a5bec16529a25e4d500eea32ec1c9aaff2d12b3a014220f4c0303a75fffa04
3
+ size 1165
weights/expt_gap/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2f4658f262e0f3501e5716c35184fbcc86a4bc28765fbcbcc34756ce1ebf0976
3
+ size 2111183
weights/glass/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f6f173c5a305bcee0ec837e7b6a58802b9f88b3745349913721757dc7d1e2c77
3
+ size 966543
weights/is_metal/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:06ed5d20f532f9193aed736f92cb94a7b181ea6b347e959dc7612100f3ff073c
3
+ size 970383
weights/jdft2d/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:21d3e4c4728e18e473b4860d81bec77a2a1633540f89c35160811ec9625c4569
3
+ size 1598799
weights/phonons/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:47fe2ab26addf64bfc1e78c6f6e9b02e408ee290b4f51fb91d85d5b270c51193
3
+ size 6170267
weights/steels/weights.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:64ee164db899d44365bea3a67ef258d7e122144d4088357dd56af10a1c0af838
3
+ size 4574159