Commit 8a82d34 — TRIADS — 6-benchmark weights + model code + Gradio app

Benchmarks:
- matbench_steels: 91.20 MPa (HybridTRIADS V13A, 225K, 5-fold 5-seed avg)
- matbench_expt_gap: 0.3068 eV (HybridTRIADS V3, 100K)
- matbench_expt_ismetal: 0.9655 AUC (HybridTRIADS, 44K, best comp-only)
- matbench_glass: 0.9285 AUC (HybridTRIADS, 44K, 5-seed)
- matbench_jdft2d: 35.89 meV (HybridTRIADS V4, 75K, 5-fold 5-seed avg)
- matbench_phonons: 41.91 cm-1 (GraphTRIADS V6, 247K, gate-halt)
- .gitattributes +2 -0
- README.md +160 -0
- app.py +658 -0
- model_code/__init__.py +9 -0
- model_code/classification_model.py +734 -0
- model_code/expt_gap_model.py +579 -0
- model_code/jdft2d_model.py +589 -0
- model_code/phonons_dataset_builder.py +749 -0
- model_code/phonons_model.py +839 -0
- model_code/steels_model.py +1056 -0
- requirements.txt +10 -0
- weights/README.md +3 -0
- weights/expt_gap/weights.pt +3 -0
- weights/glass/weights.pt +3 -0
- weights/is_metal/weights.pt +3 -0
- weights/jdft2d/weights.pt +3 -0
- weights/phonons/weights.pt +3 -0
- weights/steels/weights.pt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,2 @@
*.pt filter=lfs diff=lfs merge=lfs -text
weights/** filter=lfs diff=lfs merge=lfs -text
README.md
ADDED
|
@@ -0,0 +1,160 @@
---
license: mit
language: en
tags:
- materials-science
- machine-learning
- pytorch
- matbench
- small-data
- attention
- recursive
- crystal
- gradio
datasets:
- matbench
metrics:
- mae
- roc_auc
model-index:
- name: TRIADS
  results:
  - task:
      type: regression
      name: Yield Strength Prediction (MPa)
    dataset:
      name: matbench_steels
      type: matbench
    metrics:
    - type: mae
      value: 91.20
      name: MAE (MPa)
  - task:
      type: regression
      name: Band Gap Prediction (eV)
    dataset:
      name: matbench_expt_gap
      type: matbench
    metrics:
    - type: mae
      value: 0.3068
      name: MAE (eV)
  - task:
      type: classification
      name: Metallicity Classification
    dataset:
      name: matbench_expt_ismetal
      type: matbench
    metrics:
    - type: roc_auc
      value: 0.9655
      name: ROC-AUC
  - task:
      type: classification
      name: Glass Forming Ability
    dataset:
      name: matbench_glass
      type: matbench
    metrics:
    - type: roc_auc
      value: 0.9285
      name: ROC-AUC
  - task:
      type: regression
      name: Exfoliation Energy (meV/atom)
    dataset:
      name: matbench_jdft2d
      type: matbench
    metrics:
    - type: mae
      value: 35.89
      name: MAE (meV/atom)
  - task:
      type: regression
      name: Peak Phonon Frequency (cm⁻¹)
    dataset:
      name: matbench_phonons
      type: matbench
    metrics:
    - type: mae
      value: 41.91
      name: MAE (cm⁻¹)
---

# TRIADS — Materials Property Prediction Across 6 Matbench Benchmarks

**TRIADS (Tiny Recursive Information-Attention with Deep Supervision)** is a parameter-efficient recursive architecture for materials property prediction, purpose-built for the **small-data regime** (312–5,680 samples).

[GitHub](https://github.com/Rtx09x/TRIADS) ·
[Paper (PDF)](https://github.com/Rtx09x/TRIADS/raw/main/TRIADS_Final.pdf)

## Live Demo

Try the interactive demo with all 6 benchmarks → **[Launch App](https://huggingface.co/spaces/Rtx09/TRIADS)**

## Results Summary

| Task | N | TRIADS | Params | Rank |
|---|---|---|---|---|
| `matbench_steels` (yield strength) | 312 | **91.20 MPa** | 225K | #3 |
| `matbench_expt_gap` (band gap) | 4,604 | **0.3068 eV** | 100K | #2 composition-only |
| `matbench_expt_ismetal` (metal?) | 4,921 | **0.9655 ROC-AUC** | 100K | **#1** composition-only |
| `matbench_glass` (glass forming) | 5,680 | **0.9285 ROC-AUC** | 44K | #2 |
| `matbench_jdft2d` (exfol. energy) | 636 | **35.89 meV/atom** | 75K | **#1** no-pretraining |
| `matbench_phonons` (phonon freq.) | 1,265 | **41.91 cm⁻¹** | 247K | **#1** no-pretraining |

## Two Model Variants

### HybridTRIADS (composition tasks: steels, gap, ismetal, glass, jdft2d)

- Input: chemical formula → Magpie + Mat2Vec (composition tokens)
- Core: 2-layer self-attention cell, iterated T = 16–20 times with shared weights
- Training: per-cycle deep supervision (w_t ∝ t)

### GraphTRIADS (structural task: phonons)

- Input: CIF/structure → 3-order hierarchical crystal graph (atoms, bonds, triplet angles, dihedral angles)
- Core: hierarchical GNN message-passing stack inside the shared recursive cell
- Halting: gate-based adaptive halting (4–16 cycles per sample)

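The per-cycle deep supervision above can be sketched as a weighted loss over the recursion's intermediate predictions. This is a minimal illustrative sketch, assuming MAE as the per-cycle criterion; the released training code may differ in details:

```python
import torch
import torch.nn.functional as F

def deep_supervision_loss(step_preds, target):
    """Weighted sum of per-cycle losses with w_t ∝ t (later cycles count more)."""
    T = len(step_preds)
    w = torch.arange(1, T + 1, dtype=torch.float32)
    w = w / w.sum()  # normalize so the weights sum to 1
    per_cycle = torch.stack([F.l1_loss(p, target) for p in step_preds])
    return (w * per_cycle).sum()
```

With T = 20 cycles the last prediction gets weight 20/210 ≈ 0.095 while the first gets 1/210, so the optimizer focuses on the converged output without letting early cycles drift.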
## Pretrained Checkpoints

Weights are organized by benchmark. Download via `huggingface_hub`:

```python
from huggingface_hub import hf_hub_download
import torch

# Download one benchmark's weights (contains all folds compiled)
ckpt = torch.load(
    hf_hub_download("Rtx09/TRIADS", "steels/weights.pt"),
    map_location="cpu"
)
# ckpt['folds']   -> list of fold dicts, each with 'model_state' and 'test_mae'
# ckpt['n_extra'] -> int (needed for model init)
# ckpt['config']  -> dict (d_attn, d_hidden, ff_dim, dropout, max_steps)
```

### Checkpoint Index

| Benchmark | File | Folds | Notes |
|---|---|---|---|
| matbench_steels | `steels/weights.pt` | 5 | HybridTRIADS V13A · 225K · 5-seed ensemble compiled |
| matbench_expt_gap | `expt_gap/weights.pt` | 5 | HybridTRIADS V3 · 100K |
| matbench_expt_ismetal | `is_metal/weights.pt` | 5 | HybridTRIADS · 100K |
| matbench_glass | `glass/weights.pt` | 5 | HybridTRIADS · 44K |
| matbench_jdft2d | `jdft2d/weights.pt` | 5 | HybridTRIADS V4 · 75K · 5-seed ensemble compiled |
| matbench_phonons | `phonons/weights.pt` | 5 | GraphTRIADS V6 · 247K · also needs `phonons/dataset.pt` |

## Citation

```bibtex
@article{tiwari2026triads,
  author = {Rudra Tiwari},
  title  = {TRIADS: Tiny Recursive Information-Attention with Deep Supervision},
  year   = {2026},
  url    = {https://github.com/Rtx09x/TRIADS}
}
```

## License

MIT License — see [GitHub repository](https://github.com/Rtx09x/TRIADS/blob/main/LICENSE).
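Given the compiled `folds` format above, ensemble inference is just loading each fold and averaging its output. The helper below is an illustrative sketch, not part of the released code; it assumes `model_cls` accepts `n_extra` plus the keys stored in `ckpt['config']`:

```python
import torch

def ensemble_predict(ckpt, model_cls, x):
    """Average predictions over all compiled folds of one benchmark checkpoint."""
    preds = []
    for fold in ckpt["folds"]:
        model = model_cls(n_extra=ckpt["n_extra"], **ckpt["config"])
        model.load_state_dict(fold["model_state"])
        model.eval()  # disable dropout for inference
        with torch.no_grad():
            preds.append(model(x))
    return torch.stack(preds).mean(dim=0)
```

Averaging the 5 fold models is how the reported benchmark numbers' ensemble variants are served in the demo app below.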
app.py
ADDED
|
@@ -0,0 +1,658 @@
"""
TRIADS — Multi-Benchmark Materials Property Prediction
HuggingFace Gradio App

Covers all 6 Matbench benchmarks:
1. matbench_steels       — Yield Strength (MPa)
2. matbench_expt_gap     — Band Gap (eV)
3. matbench_expt_ismetal — Metallicity (ROC-AUC)
4. matbench_glass        — Glass Forming Ability
5. matbench_jdft2d       — Exfoliation Energy (meV/atom)
6. matbench_phonons      — Peak Phonon Frequency (cm⁻¹)
"""

import os
import warnings
import urllib.request
import json

warnings.filterwarnings("ignore")

import numpy as np
import torch
import torch.nn as nn
import gradio as gr
from huggingface_hub import hf_hub_download

# ─────────────────────────────────────────────────────────────────
# CONFIG
# ─────────────────────────────────────────────────────────────────

REPO_ID = "Rtx09/TRIADS"

BENCHMARK_INFO = {
    "steels": {
        "title": "🔩 Steel Yield Strength",
        "description": "Predict yield strength (MPa) of steel alloys from composition.",
        "unit": "MPa",
        "example": "Fe0.7Cr0.15Ni0.15",
        "examples": ["Fe0.7Cr0.15Ni0.15", "Fe0.8C0.02Mn0.1Si0.05Cr0.03", "Fe0.6Ni0.25Mo0.1Cr0.05"],
        "task": "regression",
        "result": "91.20 ± 12.23 MPa MAE (5-fold, 5-seed ensemble)",
    },
    "expt_gap": {
        "title": "⚡ Experimental Band Gap",
        "description": "Predict experimental electronic band gap (eV) from composition.",
        "unit": "eV",
        "example": "TiO2",
        "examples": ["TiO2", "GaN", "ZnO", "Si", "CdS"],
        "task": "regression",
        "result": "0.3068 ± 0.0082 eV MAE (5-fold, composition-only)",
    },
    "ismetal": {
        "title": "🔮 Metallicity Classifier",
        "description": "Predict whether a material is metallic or non-metallic from composition.",
        "unit": "probability (1 = metal)",
        "example": "Cu",
        "examples": ["Cu", "SiO2", "Fe3O4", "BaTiO3", "Al"],
        "task": "classification",
        "result": "0.9655 ± 0.0029 ROC-AUC (5-fold, composition-only)",
    },
    "glass": {
        "title": "🪟 Glass Forming Ability",
        "description": "Predict metallic glass forming ability from alloy composition.",
        "unit": "probability (1 = glass former)",
        "example": "Cu46Zr54",
        "examples": ["Cu46Zr54", "Fe80B20", "Al86Ni7La6Y1", "Pd40Cu30Ni10P20"],
        "task": "classification",
        "result": "0.9285 ± 0.0063 ROC-AUC (5-fold, 5-seed ensemble)",
    },
    "jdft2d": {
        "title": "📐 Exfoliation Energy",
        "description": "Predict exfoliation energy (meV/atom) of 2D materials from structure+composition.",
        "unit": "meV/atom",
        "example": "MoS2",
        "examples": ["MoS2", "WSe2", "BN", "graphene (C)", "MoTe2"],
        "task": "regression",
        "result": "35.89 ± 12.40 meV/atom MAE (5-fold, 5-seed ensemble)",
    },
    "phonons": {
        "title": "🎵 Phonon Peak Frequency",
        "description": "Predict peak phonon frequency (cm⁻¹) from crystal structure.",
        "unit": "cm⁻¹",
        "example": "Si (diamond cubic)",
        "examples": ["Si", "GaAs", "MgO", "BN (wurtzite)", "TiO2 (rutile)"],
        "task": "regression",
        "result": "41.91 ± 4.04 cm⁻¹ MAE (5-fold, gate-halt GraphTRIADS)",
    },
}


# ─────────────────────────────────────────────────────────────────
# MODEL DEFINITIONS (inlined for self-contained app)
# ─────────────────────────────────────────────────────────────────

class DeepHybridTRM(nn.Module):
    """
    HybridTRIADS — composition-only tasks.
    Shared across: steels, expt_gap, ismetal, glass, jdft2d.
    """
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=20, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra

        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra, m2v = None, x[:, mg_dim:]

        tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)

        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))
        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))
        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)
        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)
        return self.pool(pooled)

    def forward(self, x, deep_supervision=False):
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)
        step_preds = []
        for _ in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
        return step_preds if deep_supervision else step_preds[-1]

# ─────────────────────────────────────────────────────────────────
# FEATURIZER (composition-only, shared across HybridTRIADS tasks)
# ─────────────────────────────────────────────────────────────────

_featurizer_cache = {}
_mat2vec_cache = {}


def _get_featurizer():
    """Lazy-load the ExpandedFeaturizer (downloads Mat2Vec once)."""
    if "main" in _featurizer_cache:
        return _featurizer_cache["main"]

    try:
        from matminer.featurizers.composition import (
            ElementProperty, ElementFraction, Stoichiometry,
            ValenceOrbital, IonProperty, BandCenter
        )
        from matminer.featurizers.base import MultipleFeaturizer
        from gensim.models import Word2Vec

        GCS = "https://storage.googleapis.com/mat2vec/"
        M2V_FILES = [
            "pretrained_embeddings",
            "pretrained_embeddings.wv.vectors.npy",
            "pretrained_embeddings.trainables.syn1neg.npy",
        ]
        os.makedirs("mat2vec_cache", exist_ok=True)
        for f in M2V_FILES:
            p = os.path.join("mat2vec_cache", f)
            if not os.path.exists(p):
                urllib.request.urlretrieve(GCS + f, p)

        ep = ElementProperty.from_preset("magpie")
        m2v = Word2Vec.load("mat2vec_cache/pretrained_embeddings")
        emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key}
        extra = MultipleFeaturizer([ElementFraction(), Stoichiometry(),
                                    ValenceOrbital(), IonProperty(), BandCenter()])

        _featurizer_cache["main"] = (ep, m2v, emb, extra)
        return _featurizer_cache["main"]

    except Exception:
        return None


def featurize_composition(formula: str):
    """Featurize a chemical formula into the TRIADS feature vector."""
    from pymatgen.core import Composition

    result = _get_featurizer()
    if result is None:
        # _get_featurizer swallowed its exception, so no exception object is
        # available here (the original formatted an undefined `e`).
        return None, "Featurizer not available (matminer/gensim/Mat2Vec failed to load)."

    ep, m2v, emb, extra = result

    try:
        comp = Composition(formula)
    except Exception:
        return None, f"Invalid formula: '{formula}'"

    try:
        mg = np.array(ep.featurize(comp), np.float32)
    except Exception:
        mg = np.zeros(len(ep.feature_labels()), np.float32)

    try:
        ex = np.array(extra.featurize(comp), np.float32)
        ex = np.nan_to_num(ex, nan=0.0)
    except Exception:
        ex = np.zeros(50, np.float32)

    # Mat2Vec pooled
    v, t = np.zeros(200, np.float32), 0.0
    for s, f in comp.get_el_amt_dict().items():
        if s in emb:
            v += f * emb[s]
            t += f
    m2v_vec = v / max(t, 1e-8)

    mg = np.nan_to_num(mg, nan=0.0)
    feat = np.concatenate([mg, ex, m2v_vec])
    return feat.astype(np.float32), None

# ─────────────────────────────────────────────────────────────────
# WEIGHT LOADING (lazy, cached)
# ─────────────────────────────────────────────────────────────────

# weights.pt format (one file per benchmark on HuggingFace):
# {
#   'folds': [ {'model_state': OrderedDict, 'test_mae': float}, ... ],  # len == n_folds
#   'n_extra': int,
#   'config': {'d_attn': int, 'd_hidden': int, 'ff_dim': int,
#              'dropout': float, 'max_steps': int},
#   'benchmark': str,
# }

_fold_models = {}  # benchmark -> list[nn.Module] (one entry per fold)

_MODEL_CONFIGS = {
    # These MUST match the architecture configs baked into the saved weights.pt files.
    # Values verified by inspecting ckpt['config'] from each weights.pt directly.
    "steels":   dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),
    "expt_gap": dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),  # V3 s42 (actual weights)
    "ismetal":  dict(d_attn=24, d_hidden=48, ff_dim=72,  dropout=0.20, max_steps=16),  # 100K actual
    "glass":    dict(d_attn=24, d_hidden=48, ff_dim=72,  dropout=0.20, max_steps=16),  # actual weights
    "jdft2d":   dict(d_attn=32, d_hidden=64, ff_dim=96,  dropout=0.20, max_steps=16),  # V4-75K actual
}

_HF_PATHS = {
    "steels": "steels/weights.pt",
    "expt_gap": "expt_gap/weights.pt",
    "ismetal": "is_metal/weights.pt",
    "glass": "glass/weights.pt",
    "jdft2d": "jdft2d/weights.pt",
    "phonons": "phonons/weights.pt",
}


def _load_benchmark_models(benchmark: str):
    """
    Download benchmark/weights.pt once, build one nn.Module per fold,
    cache the list in _fold_models[benchmark].
    Returns list[nn.Module] or None on failure.
    """
    if benchmark in _fold_models:
        return _fold_models[benchmark]

    if benchmark == "phonons":
        # Phonons needs structure input — no composition-only inference
        return None

    try:
        path = hf_hub_download(repo_id=REPO_ID, filename=_HF_PATHS[benchmark])
        ckpt = torch.load(path, map_location="cpu", weights_only=False)

        # Accept both old per-fold dicts and the new compiled format
        fold_entries = ckpt.get("folds", [ckpt])  # fallback: single-fold legacy
        n_extra = ckpt.get("n_extra", 0)
        cfg = {**_MODEL_CONFIGS[benchmark], "n_extra": n_extra}

        models = []
        for entry in fold_entries:
            m = DeepHybridTRM(**cfg)
            # Handle both {'model_state': sd} and raw state_dict formats
            sd = entry.get("model_state", entry) if isinstance(entry, dict) else entry
            m.load_state_dict(sd)
            m.eval()
            models.append(m)

        _fold_models[benchmark] = models
        return models

    except Exception:
        return None

def _ensemble_predict(benchmark: str, x: np.ndarray,
                      is_classification: bool = False):
    """Run inference through all fold models, return averaged prediction."""
    models = _load_benchmark_models(benchmark)
    if not models:
        return None, "Could not load model weights. Are they uploaded to HuggingFace?"

    xt = torch.tensor(x[None], dtype=torch.float32)
    preds = []
    for m in models:
        with torch.no_grad():
            out = m(xt).item()
        if is_classification:
            out = torch.sigmoid(torch.tensor(out)).item()
        preds.append(out)

    return float(np.mean(preds)), None


# ─────────────────────────────────────────────────────────────────
# PREDICTION FUNCTIONS (one per benchmark tab)
# ─────────────────────────────────────────────────────────────────

def _status_bar(benchmark_key: str):
    info = BENCHMARK_INFO[benchmark_key]
    return (f"📊 **Benchmark result:** {info['result']}\n\n"
            f"*Weights will be downloaded from HuggingFace on first prediction.*")


def predict_steels(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("steels", feat, is_classification=False)
    if err:
        return f"❌ {err}", ""

    context = (f"**{pred:.1f} MPa** yield strength\n\n"
               f"> TRIADS benchmark MAE: 91.20 MPa | "
               f"CrabNet: 107.32 MPa | Darwin: 123.29 MPa")
    return f"### {pred:.1f} MPa", context


def predict_expt_gap(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("expt_gap", feat, is_classification=False)
    if err:
        return f"❌ {err}", ""

    metal_class = "Likely metallic (Eg ≈ 0)" if pred < 0.3 else (
        "Small gap semiconductor" if pred < 1.5 else
        "Wide-gap semiconductor/insulator")
    context = (f"**{pred:.3f} eV** band gap\n\n"
               f"Classification: {metal_class}\n\n"
               f"> TRIADS benchmark MAE: 0.3068 eV | Darwin: 0.2865 eV")
    return f"### {pred:.3f} eV", context


def predict_ismetal(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("ismetal", feat, is_classification=True)
    if err:
        return f"❌ {err}", ""

    label = "🔩 **METALLIC**" if pred > 0.5 else "💎 **NON-METALLIC**"
    pct = pred * 100 if pred > 0.5 else (1 - pred) * 100
    confidence = "high" if pct > 80 else "moderate" if pct > 60 else "uncertain"
    context = (f"{label} (confidence: {confidence}, p={pred:.3f})\n\n"
               f"> TRIADS benchmark ROC-AUC: 0.9655 (best composition-only model)")
    return f"### {pred:.3f} probability of being metallic", context


def predict_glass(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("glass", feat, is_classification=True)
    if err:
        return f"❌ {err}", ""

    label = "🪟 **Likely glass-former**" if pred > 0.5 else "❌ **Unlikely glass-former**"
    context = (f"{label} (p={pred:.3f})\n\n"
               f"> TRIADS benchmark ROC-AUC: 0.9285 | MODNet: 0.9603")
    return f"### {pred:.3f} glass-forming probability", context


def predict_jdft2d(formula: str):
    feat, err = featurize_composition(formula)
    if err:
        return f"❌ Error: {err}", ""

    pred, err = _ensemble_predict("jdft2d", feat, is_classification=False)
    if err:
        return f"❌ {err}", ""

    ease = "Easy to exfoliate" if pred < 50 else "Moderate" if pred < 150 else "Hard to exfoliate"
    context = (f"**{pred:.1f} meV/atom** exfoliation energy\n\n"
               f"Exfoliatability: {ease}\n\n"
               f"> TRIADS benchmark MAE: 35.89 meV/atom (best no-pretraining)")
    return f"### {pred:.1f} meV/atom", context


def predict_phonons_placeholder(formula: str):
    return ("### ⚠️ Phonons — Structure Required",
            "GraphTRIADS for phonons requires a crystal structure (CIF file), "
            "not just a formula. The pretrained weights are available at "
            "`huggingface.co/Rtx09/TRIADS` under `phonons/`.\n\n"
            "> Benchmark MAE: 41.91 cm⁻¹ (gate-halt GraphTRIADS V6, 247K params)")

| 463 |
+
# ─────────────────────────────────────────────────────────────────
|
| 464 |
+
# GRADIO INTERFACE
|
| 465 |
+
# ─────────────────────────────────────────────────────────────────
|
| 466 |
+
|
| 467 |
+
CSS = """
|
| 468 |
+
.gr-box { border-radius: 12px !important; }
|
| 469 |
+
.tab-nav button { font-weight: 600; font-size: 14px; }
|
| 470 |
+
#result_text { font-size: 1.5rem; font-weight: 700; color: #6366f1; }
|
| 471 |
+
.benchmark-badge {
|
| 472 |
+
background: #1e293b; color: #94a3b8; border-radius: 8px;
|
| 473 |
+
padding: 8px 14px; font-family: monospace; font-size: 12px;
|
| 474 |
+
}
|
| 475 |
+
footer { display: none !important; }
|
| 476 |
+
"""
|
| 477 |
+
|
| 478 |
+
def build_interface():
|
| 479 |
+
with gr.Blocks(css=CSS, title="TRIADS — Materials Property Prediction") as demo:
|
| 480 |
+
|
| 481 |
+
gr.Markdown("""
|
| 482 |
+
# ⚡ TRIADS — Materials Property Prediction
|
| 483 |
+
**Tiny Recursive Information-Attention with Deep Supervision**
|
| 484 |
+
Six Matbench benchmarks · Parameter-efficient · Small-data specialist
|
| 485 |
+
|
| 486 |
+
Select a benchmark tab below to predict a material property.
|
| 487 |
+
""")
|
| 488 |
+
|
| 489 |
+
with gr.Tabs():
|
| 490 |
+
|
| 491 |
+
# ── TAB 1: STEELS ───────────────────────────────────���─────────
|
| 492 |
+
with gr.Tab("🔩 Steel Yield"):
|
| 493 |
+
with gr.Row():
|
| 494 |
+
with gr.Column(scale=1):
|
| 495 |
+
gr.Markdown("### Alloy Yield Strength (MPa)")
|
| 496 |
+
gr.Markdown("Input an alloy composition (elemental fractions must sum to 1).")
|
| 497 |
+
formula_s = gr.Textbox(
|
| 498 |
+
label="Alloy formula",
|
| 499 |
+
placeholder="e.g. Fe0.7Cr0.15Ni0.15",
|
| 500 |
+
value="Fe0.7Cr0.15Ni0.15"
|
| 501 |
+
)
|
| 502 |
+
gr.Examples(
|
| 503 |
+
examples=["Fe0.7Cr0.15Ni0.15", "Fe0.8C0.02Mn0.1Si0.05Cr0.03",
|
| 504 |
+
"Fe0.6Ni0.25Mo0.1Cr0.05"],
|
| 505 |
+
inputs=formula_s
|
| 506 |
+
)
|
| 507 |
+
btn_s = gr.Button("Predict Yield Strength", variant="primary")
|
| 508 |
+
with gr.Column(scale=1):
|
| 509 |
+
out_s = gr.Markdown(elem_id="result_text")
|
| 510 |
+
ctx_s = gr.Markdown()
|
| 511 |
+
gr.Markdown(
|
| 512 |
+
"📊 TRIADS V13A · 225K params · 5-seed ensemble · **91.20 MPa MAE**",
|
| 513 |
+
elem_classes="benchmark-badge"
|
| 514 |
+
)
|
| 515 |
+
btn_s.click(predict_steels, inputs=formula_s, outputs=[out_s, ctx_s])
|
| 516 |
+
|
| 517 |
+
# ── TAB 2: BAND GAP ───────────────────────────────────────────
|
| 518 |
+
with gr.Tab("⚡ Band Gap"):
|
| 519 |
+
with gr.Row():
|
| 520 |
+
with gr.Column(scale=1):
|
| 521 |
+
gr.Markdown("### Experimental Band Gap (eV)")
|
| 522 |
+
gr.Markdown("Input a chemical composition formula.")
|
| 523 |
+
formula_g = gr.Textbox(
|
| 524 |
+
label="Composition",
|
| 525 |
+
placeholder="e.g. TiO2",
|
| 526 |
+
value="TiO2"
|
| 527 |
+
)
|
| 528 |
+
gr.Examples(
|
| 529 |
+
examples=["TiO2", "GaN", "ZnO", "Si", "CdS", "SrTiO3"],
|
| 530 |
+
inputs=formula_g
|
| 531 |
+
)
|
| 532 |
+
btn_g = gr.Button("Predict Band Gap", variant="primary")
|
| 533 |
+
with gr.Column(scale=1):
|
| 534 |
+
out_g = gr.Markdown(elem_id="result_text")
|
| 535 |
+
ctx_g = gr.Markdown()
|
| 536 |
+
gr.Markdown(
|
| 537 |
+
"📊 TRIADS V3 · 100K params · **0.3068 eV MAE** (best comp-only)",
|
| 538 |
+
elem_classes="benchmark-badge"
|
| 539 |
+
)
|
| 540 |
+
btn_g.click(predict_expt_gap, inputs=formula_g, outputs=[out_g, ctx_g])
|
| 541 |
+
|
| 542 |
+
# ── TAB 3: METALLICITY ────────────────────────────────────────
|
| 543 |
+
with gr.Tab("🔮 Metallicity"):
|
| 544 |
+
with gr.Row():
|
| 545 |
+
with gr.Column(scale=1):
|
| 546 |
+
gr.Markdown("### Metal vs. Non-metal Classifier")
|
| 547 |
+
gr.Markdown("Predicts electronic metallicity from composition.")
|
| 548 |
+
formula_m = gr.Textbox(
|
| 549 |
+
label="Composition",
|
| 550 |
+
placeholder="e.g. Cu",
|
| 551 |
+
value="Cu"
|
| 552 |
+
)
|
| 553 |
+
gr.Examples(
|
| 554 |
+
examples=["Cu", "SiO2", "Fe3O4", "BaTiO3", "Al", "MgO", "NiO"],
|
| 555 |
+
inputs=formula_m
|
| 556 |
+
)
|
| 557 |
+
btn_m = gr.Button("Classify Metallicity", variant="primary")
|
| 558 |
+
with gr.Column(scale=1):
|
| 559 |
+
out_m = gr.Markdown(elem_id="result_text")
|
| 560 |
+
ctx_m = gr.Markdown()
|
| 561 |
+
gr.Markdown(
|
| 562 |
+
"📊 TRIADS 100K · **0.9655 ROC-AUC** · Best composition-only (beats GPTChem 1B+)",
|
| 563 |
+
elem_classes="benchmark-badge"
|
| 564 |
+
)
|
| 565 |
+
btn_m.click(predict_ismetal, inputs=formula_m, outputs=[out_m, ctx_m])
|
| 566 |
+
|
| 567 |
+
# ── TAB 4: GLASS FORMING ──────────────────────────────────────
|
| 568 |
+
with gr.Tab("🪟 Glass Forming"):
|
| 569 |
+
with gr.Row():
|
| 570 |
+
with gr.Column(scale=1):
|
| 571 |
+
gr.Markdown("### Metallic Glass Forming Ability")
|
| 572 |
+
gr.Markdown("Predicts glass forming probability from alloy composition.")
|
| 573 |
+
formula_gf = gr.Textbox(
|
| 574 |
+
label="Alloy composition",
|
| 575 |
+
placeholder="e.g. Cu46Zr54",
|
| 576 |
+
value="Cu46Zr54"
|
| 577 |
+
)
|
| 578 |
+
gr.Examples(
|
| 579 |
+
examples=["Cu46Zr54", "Fe80B20", "Al86Ni7La6Y1", "Pd40Cu30Ni10P20"],
|
| 580 |
+
inputs=formula_gf
|
| 581 |
+
)
|
| 582 |
+
btn_gf = gr.Button("Predict Glass Forming", variant="primary")
|
| 583 |
+
with gr.Column(scale=1):
|
| 584 |
+
out_gf = gr.Markdown(elem_id="result_text")
|
| 585 |
+
ctx_gf = gr.Markdown()
|
| 586 |
+
gr.Markdown(
|
| 587 |
+
"📊 TRIADS 44K · 5-seed ensemble · **0.9285 ROC-AUC**",
|
| 588 |
+
elem_classes="benchmark-badge"
|
| 589 |
+
)
|
| 590 |
+
btn_gf.click(predict_glass, inputs=formula_gf, outputs=[out_gf, ctx_gf])
|
| 591 |
+
|
| 592 |
+
# ── TAB 5: JDFT2D ─────────────────────────────────────────────
|
| 593 |
+
with gr.Tab("📐 JDFT2D"):
|
| 594 |
+
with gr.Row():
|
| 595 |
+
with gr.Column(scale=1):
|
| 596 |
+
gr.Markdown("### 2D Material Exfoliation Energy (meV/atom)")
|
| 597 |
+
gr.Markdown("Predicts how easily a layered 2D material can be exfoliated.")
|
| 598 |
+
formula_j = gr.Textbox(
|
| 599 |
+
label="Composition",
|
| 600 |
+
placeholder="e.g. MoS2",
|
| 601 |
+
value="MoS2"
|
| 602 |
+
)
|
| 603 |
+
gr.Examples(
|
| 604 |
+
examples=["MoS2", "WSe2", "BN", "MoTe2", "WS2"],
|
| 605 |
+
inputs=formula_j
|
| 606 |
+
)
|
| 607 |
+
btn_j = gr.Button("Predict Exfoliation Energy", variant="primary")
|
| 608 |
+
with gr.Column(scale=1):
|
| 609 |
+
out_j = gr.Markdown(elem_id="result_text")
|
| 610 |
+
ctx_j = gr.Markdown()
|
| 611 |
+
gr.Markdown(
|
| 612 |
+
"📊 TRIADS V4 · 75K params · 5-seed ensemble · **35.89 meV/atom MAE**",
|
| 613 |
+
elem_classes="benchmark-badge"
|
| 614 |
+
)
|
| 615 |
+
btn_j.click(predict_jdft2d, inputs=formula_j, outputs=[out_j, ctx_j])
|
| 616 |
+
|
| 617 |
+
# ── TAB 6: PHONONS ────────────────────────────────────────────
|
| 618 |
+
with gr.Tab("🎵 Phonons"):
|
| 619 |
+
with gr.Row():
|
| 620 |
+
with gr.Column(scale=1):
|
| 621 |
+
gr.Markdown("### Peak Phonon Frequency (cm⁻¹)")
|
| 622 |
+
gr.Markdown(
|
| 623 |
+
"GraphTRIADS V6 predicts phonon peak frequency from crystal structure.\n\n"
|
| 624 |
+
"⚠️ **Structure required.** This model requires a full crystal "
|
| 625 |
+
"structure (CIF) rather than composition alone. Enter a composition "
|
| 626 |
+
"below to get a benchmark context, or see the GitHub repo for full "
|
| 627 |
+
"structure-based inference."
|
| 628 |
+
)
|
| 629 |
+
formula_ph = gr.Textbox(
|
| 630 |
+
label="Formula (for context only)",
|
| 631 |
+
placeholder="e.g. Si",
|
| 632 |
+
value="Si"
|
| 633 |
+
)
|
| 634 |
+
btn_ph = gr.Button("Show Benchmark Info", variant="primary")
|
| 635 |
+
with gr.Column(scale=1):
|
| 636 |
+
out_ph = gr.Markdown(elem_id="result_text")
|
| 637 |
+
ctx_ph = gr.Markdown()
|
| 638 |
+
gr.Markdown(
|
| 639 |
+
"📊 GraphTRIADS V6 · 247K params · Gate-halt · **41.91 cm⁻¹ MAE**",
|
| 640 |
+
elem_classes="benchmark-badge"
|
| 641 |
+
)
|
| 642 |
+
btn_ph.click(predict_phonons_placeholder, inputs=formula_ph, outputs=[out_ph, ctx_ph])
|
| 643 |
+
|
| 644 |
+
# ── FOOTER ──────────────────────────────────────────────────────
|
| 645 |
+
gr.Markdown("""
|
| 646 |
+
---
|
| 647 |
+
**TRIADS** · [GitHub](https://github.com/Rtx09x/TRIADS) · MIT License · Rudra Tiwari, 2026
|
| 648 |
+
|
| 649 |
+
*All benchmarks use exact Matbench 5-fold CV protocol (random\_state=18012019).
|
| 650 |
+
Predictions are ensemble averages across 5 folds (fold-specific scalers approximated at inference).*
|
| 651 |
+
""")
|
| 652 |
+
|
| 653 |
+
return demo
|
| 654 |
+
|
| 655 |
+
|
| 656 |
+
if __name__ == "__main__":
|
| 657 |
+
demo = build_interface()
|
| 658 |
+
demo.launch(share=False)
|
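The classification predictors above all route through `_ensemble_predict`, which (per the footer note) averages the outputs of the five per-fold models. A minimal sketch of that averaging pattern for a classifier, assuming sigmoid-probability averaging over fold logits; `ensemble_predict_sketch` and the sample logits are hypothetical stand-ins, not the app's actual helper:

```python
import numpy as np

def ensemble_predict_sketch(fold_logits):
    """Average sigmoid probabilities across per-fold model logits (hypothetical helper)."""
    fold_logits = np.asarray(fold_logits, dtype=np.float64)
    probs = 1.0 / (1.0 + np.exp(-fold_logits))  # sigmoid of each fold model's logit
    return probs.mean(axis=0)                   # ensemble-averaged probability

# Five hypothetical fold models scoring one composition
p = ensemble_predict_sketch([0.0, 2.0, -2.0, 1.0, 0.5])
```

Averaging probabilities (rather than logits) keeps each fold's contribution bounded in [0, 1], so one overconfident fold cannot dominate the ensemble.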
model_code/__init__.py
ADDED
@@ -0,0 +1,9 @@
# TRIADS model_code package
# Import the model classes for convenience

from .steels_model import DeepHybridTRM as SteelsModel
from .expt_gap_model import DeepHybridTRM as ExptGapModel
from .classification_model import DeepHybridTRM as ClassificationModel
from .jdft2d_model import DeepHybridTRM as Jdft2dModel

__all__ = ["SteelsModel", "ExptGapModel", "ClassificationModel", "Jdft2dModel"]
model_code/classification_model.py
ADDED
@@ -0,0 +1,734 @@
"""
+=============================================================+
|  TRIADS — Classification Benchmarks (Combined)              |
|                                                             |
|  1. matbench_expt_is_metal (4,921) — Metal vs Non-metal     |
|  2. matbench_glass (5,680) — Metallic Glass Forming         |
|                                                             |
|  44K model | BCEWithLogitsLoss | ROC-AUC | Single Seed      |
|  Seeds: [42, 123, 456, 789, 1024]                           |
|  Folds: KFold(5, shuffle=True, random_state=18012019)       |
|       ^^^ exact matbench v0.1 fold generation ^^^           |
+=============================================================+

DEPENDENCIES (run before executing):
    pip install matminer pymatgen gensim tqdm scikit-learn torch

USAGE:
    python classification_benchmarks.py   # runs both sequentially
"""

import os, copy, json, time, logging, warnings, urllib.request, shutil
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd
from tqdm import tqdm
from sklearn.metrics import roc_auc_score

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from pymatgen.core import Composition
from matminer.featurizers.composition import ElementProperty
from gensim.models import Word2Vec

logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
log = logging.getLogger("TRIADS-CLS")

BATCH_SIZE = 64
# Single seed first — test before committing to full ensemble
SEEDS = [42]
# Uncomment below for 5-seed ensemble after single seed looks good:
# SEEDS = [42, 123, 456, 789, 1024]

# ~44K config — smaller to prevent overfitting
MODEL_CFG = dict(
    d_attn=24, nhead=4, d_hidden=48, ff_dim=72,
    dropout=0.20, max_steps=16,
)

# Matbench v0.1 exact fold seed — DO NOT CHANGE
MATBENCH_FOLD_SEED = 18012019


# ======================================================================
# FAST TENSOR DATALOADER
# ======================================================================

class FastTensorDataLoader:
    def __init__(self, *tensors, batch_size=64, shuffle=False):
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors
        self.dataset_len = tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_batches = (self.dataset_len + batch_size - 1) // batch_size

    def __iter__(self):
        if self.shuffle:
            idx = torch.randperm(self.dataset_len, device=self.tensors[0].device)
            self.tensors = tuple(t[idx] for t in self.tensors)
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches

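`FastTensorDataLoader` above avoids the per-item indexing and collation overhead of `torch.utils.data.DataLoader` by slicing whole batches out of pre-stacked tensors. A self-contained sketch of the same batch-slicing arithmetic (numpy stands in for torch here, and the data is made up):

```python
import numpy as np

X = np.arange(10, dtype=np.float32).reshape(10, 1)  # 10 samples, 1 feature
y = np.arange(10, dtype=np.float32)
batch_size = 4

# Ceil division, exactly as in FastTensorDataLoader.n_batches
n_batches = (X.shape[0] + batch_size - 1) // batch_size
batches = [(X[i:i + batch_size], y[i:i + batch_size])
           for i in range(0, X.shape[0], batch_size)]
sizes = [len(bx) for bx, _ in batches]  # last batch may be short
```

For small tabular datasets like these (~5K rows) the tensors fit in memory, so whole-batch slicing is markedly faster than a worker-based loader.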
# ======================================================================
# FEATURIZERS
# ======================================================================

_ORBITAL_ENERGIES = {
    'H': {'1s': -13.6}, 'He': {'1s': -24.6},
    'Li': {'2s': -5.4}, 'Be': {'2s': -9.3},
    'B': {'2s': -14.0, '2p': -8.3}, 'C': {'2s': -19.4, '2p': -11.3},
    'N': {'2s': -25.6, '2p': -14.5}, 'O': {'2s': -32.4, '2p': -13.6},
    'F': {'2s': -40.2, '2p': -17.4}, 'Ne': {'2s': -48.5, '2p': -21.6},
    'Na': {'3s': -5.1}, 'Mg': {'3s': -7.6},
    'Al': {'3s': -11.3, '3p': -6.0}, 'Si': {'3s': -15.0, '3p': -8.2},
    'P': {'3s': -18.7, '3p': -10.5}, 'S': {'3s': -22.7, '3p': -10.4},
    'Cl': {'3s': -25.3, '3p': -13.0}, 'Ar': {'3s': -29.2, '3p': -15.8},
    'K': {'4s': -4.3}, 'Ca': {'4s': -6.1},
    'Sc': {'4s': -6.6, '3d': -8.0}, 'Ti': {'4s': -6.8, '3d': -8.5},
    'V': {'4s': -6.7, '3d': -8.3}, 'Cr': {'4s': -6.8, '3d': -8.7},
    'Mn': {'4s': -7.4, '3d': -9.5}, 'Fe': {'4s': -7.9, '3d': -10.0},
    'Co': {'4s': -7.9, '3d': -10.0}, 'Ni': {'4s': -7.6, '3d': -10.0},
    'Cu': {'4s': -7.7, '3d': -11.7}, 'Zn': {'4s': -9.4, '3d': -17.3},
    'Ga': {'4s': -12.6, '4p': -6.0}, 'Ge': {'4s': -15.6, '4p': -7.9},
    'As': {'4s': -18.6, '4p': -9.8}, 'Se': {'4s': -21.1, '4p': -9.8},
    'Br': {'4s': -24.0, '4p': -11.8}, 'Kr': {'4s': -27.5, '4p': -14.0},
    'Rb': {'5s': -4.2}, 'Sr': {'5s': -5.7},
    'Y': {'5s': -6.5, '4d': -7.4}, 'Zr': {'5s': -6.8, '4d': -8.3},
    'Nb': {'5s': -6.9, '4d': -8.5}, 'Mo': {'5s': -7.1, '4d': -8.9},
    'Ru': {'5s': -7.4, '4d': -8.7}, 'Rh': {'5s': -7.5, '4d': -8.8},
    'Pd': {'4d': -8.3}, 'Ag': {'5s': -7.6, '4d': -12.3},
    'Cd': {'5s': -9.0, '4d': -16.7}, 'In': {'5s': -12.0, '5p': -5.8},
    'Sn': {'5s': -14.6, '5p': -7.3}, 'Sb': {'5s': -16.5, '5p': -8.6},
    'Te': {'5s': -19.0, '5p': -9.0}, 'I': {'5s': -21.1, '5p': -10.5},
    'Xe': {'5s': -23.4, '5p': -12.1}, 'Cs': {'6s': -3.9}, 'Ba': {'6s': -5.2},
    'La': {'6s': -5.6, '5d': -7.5},
    'Ce': {'6s': -5.5, '5d': -7.3, '4f': -7.0},
    'Hf': {'6s': -7.0, '5d': -8.1}, 'Ta': {'6s': -7.9, '5d': -9.6},
    'W': {'6s': -8.0, '5d': -9.8}, 'Re': {'6s': -7.9, '5d': -9.2},
    'Os': {'6s': -8.4, '5d': -10.0}, 'Ir': {'6s': -9.1, '5d': -10.7},
    'Pt': {'6s': -9.0, '5d': -10.5}, 'Au': {'6s': -9.2, '5d': -12.8},
    'Pb': {'6s': -15.0, '6p': -7.4}, 'Bi': {'6s': -16.7, '6p': -7.3},
}


def _compute_homo_lumo_gap(comp):
    elements = comp.get_el_amt_dict()
    highest_occ, all_energies = [], []
    for el, frac in elements.items():
        if el not in _ORBITAL_ENERGIES:
            return np.array([0.0, 0.0, 0.0], dtype=np.float32)
        orbs = _ORBITAL_ENERGIES[el]
        highest_occ.append((max(orbs.values()), frac))
        all_energies.extend(orbs.values())
    if not highest_occ:
        return np.array([0.0, 0.0, 0.0], dtype=np.float32)
    homo = sum(e * f for e, f in highest_occ) / sum(f for _, f in highest_occ)
    above = [e for e in all_energies if e > homo]
    lumo = min(above) if above else homo + 1.0
    return np.array([homo, lumo, lumo - homo], dtype=np.float32)


class _BaseFeaturizer:
    """Shared Mat2Vec loading and Magpie featurization."""
    GCS = "https://storage.googleapis.com/mat2vec/"
    FILES = ["pretrained_embeddings",
             "pretrained_embeddings.wv.vectors.npy",
             "pretrained_embeddings.trainables.syn1neg.npy"]

    def __init__(self, cache="mat2vec_cache"):
        self.ep_magpie = ElementProperty.from_preset("magpie")
        self.n_mg = len(self.ep_magpie.feature_labels())
        self.n_extra = None
        self.scaler = None

        os.makedirs(cache, exist_ok=True)
        for f in self.FILES:
            p = os.path.join(cache, f)
            if not os.path.exists(p):
                log.info(f"  Downloading {f}...")
                urllib.request.urlretrieve(self.GCS + f, p)
        self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
        self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}

    def _pool(self, c):
        v, t = np.zeros(200, np.float32), 0.0
        for s, f in c.get_el_amt_dict().items():
            if s in self.emb: v += f * self.emb[s]; t += f
        return v / max(t, 1e-8)

    def featurize_all(self, comps):
        out = []
        test_ex = self._featurize_extra(comps[0])
        self.n_extra = len(test_ex)
        total = self.n_mg + self.n_extra + 200
        log.info(f"Features: {self.n_mg} Magpie + "
                 f"{self.n_extra} Extra + 200 Mat2Vec = {total}d")
        for c in tqdm(comps, desc="  Featurizing", leave=False):
            try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
            except: mg = np.zeros(self.n_mg, np.float32)
            ex = self._featurize_extra(c)
            out.append(np.concatenate([
                np.nan_to_num(mg, nan=0.0),
                np.nan_to_num(ex, nan=0.0),
                self._pool(c)
            ]))
        return np.array(out)

    def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
    def transform(self, X):
        if not self.scaler: return X
        return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)


class MetallicityFeaturizer(_BaseFeaturizer):
    """354d — keeps HOMO/LUMO gap + BandCenter (relevant to metallicity)."""
    def __init__(self, cache="mat2vec_cache"):
        super().__init__(cache)
        from matminer.featurizers.composition import (
            Stoichiometry, ValenceOrbital, IonProperty, BandCenter
        )
        from matminer.featurizers.composition.element import TMetalFraction
        self.extra_featurizers = [
            ("Stoichiometry", Stoichiometry()),
            ("ValenceOrbital", ValenceOrbital()),
            ("IonProperty", IonProperty()),
            ("BandCenter", BandCenter()),
            ("TMetalFraction", TMetalFraction()),
        ]
        self._extra_sizes = {}
        for name, ftzr in self.extra_featurizers:
            try: self._extra_sizes[name] = len(ftzr.feature_labels())
            except: self._extra_sizes[name] = None

    def _featurize_extra(self, comp):
        parts = []
        for name, ftzr in self.extra_featurizers:
            try:
                vals = np.array(ftzr.featurize(comp), np.float32)
                parts.append(np.nan_to_num(vals, nan=0.0))
                if self._extra_sizes.get(name) is None:
                    self._extra_sizes[name] = len(vals)
            except:
                sz = self._extra_sizes.get(name, 0) or 1
                parts.append(np.zeros(sz, np.float32))
        parts.append(_compute_homo_lumo_gap(comp))
        return np.concatenate(parts)


class GlassFeaturizer(_BaseFeaturizer):
    """~351d — removes BandCenter & HOMO/LUMO (irrelevant to glass forming)."""
    def __init__(self, cache="mat2vec_cache"):
        super().__init__(cache)
        from matminer.featurizers.composition import (
            Stoichiometry, ValenceOrbital, IonProperty
        )
        from matminer.featurizers.composition.element import TMetalFraction
        self.extra_featurizers = [
            ("Stoichiometry", Stoichiometry()),
            ("ValenceOrbital", ValenceOrbital()),
            ("IonProperty", IonProperty()),
            ("TMetalFraction", TMetalFraction()),
        ]
        self._extra_sizes = {}
        for name, ftzr in self.extra_featurizers:
            try: self._extra_sizes[name] = len(ftzr.feature_labels())
            except: self._extra_sizes[name] = None

    def _featurize_extra(self, comp):
        parts = []
        for name, ftzr in self.extra_featurizers:
            try:
                vals = np.array(ftzr.featurize(comp), np.float32)
                parts.append(np.nan_to_num(vals, nan=0.0))
                if self._extra_sizes.get(name) is None:
                    self._extra_sizes[name] = len(vals)
            except:
                sz = self._extra_sizes.get(name, 0) or 1
                parts.append(np.zeros(sz, np.float32))
        return np.concatenate(parts)


# ======================================================================
# MODEL — DeepHybridTRM (100K params)
# ======================================================================

class DeepHybridTRM(nn.Module):
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=32, nhead=4, d_hidden=64, ff_dim=96,
                 dropout=0.15, max_steps=16, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra

        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra, m2v = None, x[:, mg_dim:]
        tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)
        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))
        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))
        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
        pooled = tok.mean(dim=1)
        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)
        return self.pool(pooled)

    def forward(self, x, deep_supervision=False):
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)
        step_preds = []
        for s in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
        return step_preds if deep_supervision else step_preds[-1]

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)


# ======================================================================
# LOSS + UTILS
# ======================================================================

def deep_supervision_loss_bce(step_preds, targets):
    preds = torch.stack(step_preds)
    n = preds.shape[0]
    w = torch.arange(1, n + 1, device=preds.device, dtype=preds.dtype)
    w = w / w.sum()
    per_step = torch.stack([
        F.binary_cross_entropy_with_logits(preds[i], targets, reduction='mean')
        for i in range(n)
    ])
    return (w * per_step).sum()

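`deep_supervision_loss_bce` weights the per-step BCE losses by a normalized linear ramp, so later recursion steps count for more of the total loss. A quick numeric check of that ramp for the configured `max_steps=16` (a standalone numpy sketch, not the training code itself):

```python
import numpy as np

n_steps = 16  # max_steps in MODEL_CFG
w = np.arange(1, n_steps + 1, dtype=np.float64)
w = w / w.sum()  # normalized linear ramp: step s gets weight s / (n(n+1)/2)

first_share = w[0]    # weight on the first recursion step
final_share = w[-1]   # weight on the last (deepest) step: 16/136
```

With 16 steps the weights sum over 1..16 = 136, so the final step carries 16/136 ≈ 11.8% of the loss while the first carries under 1%: early steps are supervised gently and the deepest prediction dominates.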
def strat_split_cls(targets, val_size=0.15, seed=42):
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for cls in [0, 1]:
        m = np.where(targets == cls)[0]
        if len(m) == 0: continue
        n = max(1, int(len(m) * val_size))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)

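`strat_split_cls` samples `val_size` of each class independently, so the validation split preserves the class balance of imbalanced data. A quick check of the resulting proportions, re-implementing the same per-class sampling inline on made-up labels:

```python
import numpy as np

targets = np.array([0] * 80 + [1] * 20)   # imbalanced binary labels
rng = np.random.RandomState(42)
tr, vl = [], []
for cls in [0, 1]:
    m = np.where(targets == cls)[0]
    n = max(1, int(len(m) * 0.15))        # 15% of each class, at least 1 sample
    c = rng.choice(m, n, replace=False)
    vl.extend(c.tolist())
    tr.extend(np.setdiff1d(m, c).tolist())

# Validation set keeps the 80:20 class ratio (12 vs 3 samples)
val_counts = np.bincount(targets[np.array(vl)], minlength=2)
```

A plain random 15% split of 100 samples could, by chance, draw very few minority-class examples; per-class sampling guarantees the 12:3 ratio here.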
| 386 |
+
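The linearly increasing step weights used by `deep_supervision_loss_bce` can be sanity-checked in isolation. This NumPy sketch (the per-step losses are made-up stand-ins) shows that the weights normalize to 1, that the final refinement step carries the largest share, and that the weighted loss favors later steps:

```python
import numpy as np

# Step i (1-based) gets weight i / sum(1..n), so later refinement steps
# dominate the loss while early steps still receive a training signal.
n_steps = 20
w = np.arange(1, n_steps + 1, dtype=np.float64)
w = w / w.sum()

# Hypothetical per-step losses that shrink as the recurrence refines y.
per_step_loss = np.linspace(0.9, 0.3, n_steps)
weighted_loss = float((w * per_step_loss).sum())
```

Because the weights increase while the per-step losses decrease, the weighted loss is below the unweighted mean.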
@torch.inference_mode()
def predict_proba(model, dl):
    model.eval()
    preds = []
    for bx, _ in dl:
        preds.append(torch.sigmoid(model(bx)).cpu())
    return torch.cat(preds)

# ======================================================================
# TRAINING
# ======================================================================

def train_fold(model, tr_dl, vl_dl, device,
               epochs=300, swa_start=200, fold=1, seed=42, label="100K"):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(
        opt, T_max=swa_start, eta_min=1e-4)
    swa_m = AveragedModel(model)
    swa_s = SWALR(opt, swa_lr=5e-4)
    swa_on = False
    best_v, best_w = float('-inf'), None

    pbar = tqdm(range(epochs), desc=f" [{label}|s{seed}] F{fold}/5",
                leave=False, ncols=120)
    for ep in pbar:
        model.train()
        epoch_loss, n_batches = 0.0, 0
        for bx, by in tr_dl:
            sp = model(bx, deep_supervision=True)
            loss = deep_supervision_loss_bce(sp, by)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            epoch_loss += loss.item()
            n_batches += 1

        model.eval()
        vp_list, vt_list = [], []
        with torch.inference_mode():
            for bx, by in vl_dl:
                vp_list.append(torch.sigmoid(model(bx)).cpu())
                vt_list.append(by.cpu())
        vp = torch.cat(vp_list).numpy()
        vt = torch.cat(vt_list).numpy()
        try: val_auc = roc_auc_score(vt, vp)
        except ValueError: val_auc = 0.5  # single-class validation split

        if ep < swa_start:
            sch.step()
            if val_auc > best_v:
                best_v = val_auc
                best_w = copy.deepcopy(model.state_dict())
        else:
            if not swa_on: swa_on = True
            swa_m.update_parameters(model); swa_s.step()

        if ep % 10 == 0 or ep == epochs - 1:
            pbar.set_postfix(Best=f'{best_v:.4f}', Ph='SWA' if swa_on else 'COS',
                             Loss=f'{epoch_loss/max(n_batches,1):.4f}',
                             AUC=f'{val_auc:.4f}')

    if swa_on:
        update_bn(tr_dl, swa_m, device=device)
        model.load_state_dict(swa_m.module.state_dict())
    else:
        model.load_state_dict(best_w)
    return best_v, model

# ======================================================================
# GENERIC BENCHMARK RUNNER
# ======================================================================

def run_classification_benchmark(
    dataset_name, target_col, featurizer_cls,
    model_dir, summary_file, baseline_name, baseline_auc,
    device
):
    """Run a full 5-seed ensemble classification benchmark."""
    t0 = time.time()

    # ── LOAD ─────────────────────────────────────────────────────────
    print(f"\n Loading {dataset_name}...")
    from matminer.datasets import load_dataset
    df = load_dataset(dataset_name)

    targets_all = np.array(df[target_col].astype(float).tolist(), np.float32)

    # Handle different column names
    if 'composition' in df.columns:
        comps_all = [Composition(c) for c in df['composition'].tolist()]
    elif 'structure' in df.columns:
        comps_all = [s.composition for s in df['structure'].tolist()]
    elif 'formula' in df.columns:
        comps_all = [Composition(str(f)) for f in df['formula'].tolist()]
    else:
        raise ValueError(f"Cannot find composition column in {df.columns.tolist()}")

    n_pos = int(targets_all.sum())
    n_neg = len(targets_all) - n_pos
    print(f" Dataset: {len(comps_all)} samples ({n_pos} positive, {n_neg} negative)")
    print(f" Class balance: {n_pos/len(targets_all)*100:.1f}% positive")

    # ── FEATURIZE (once) ─────────────────────────────────────────────
    t_feat = time.time()
    feat = featurizer_cls()
    X_all = feat.featurize_all(comps_all)
    n_extra = feat.n_extra
    print(f" Features: {X_all.shape} (n_extra={n_extra})")
    print(f" Featurization: {time.time()-t_feat:.1f}s")

    # ── FOLDS — exact matbench v0.1 splits ───────────────────────────
    kfold = KFold(n_splits=5, shuffle=True, random_state=MATBENCH_FOLD_SEED)
    folds = list(kfold.split(comps_all))

    # Verify zero leakage
    all_test_indices = []
    for fi, (tv, te) in enumerate(folds):
        assert len(set(tv) & set(te)) == 0, f"Fold {fi}: train/test overlap!"
        all_test_indices.extend(te.tolist())
    assert len(set(all_test_indices)) == len(comps_all), "Not all samples covered!"
    assert len(all_test_indices) == len(comps_all), "Duplicate test samples!"
    print(" 5 folds verified: zero leakage, full coverage, no duplicates ✓\n")

    # ── MODEL INFO ───────────────────────────────────────────────────
    model_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
                    mat2vec_dim=200, **MODEL_CFG)
    test_model = DeepHybridTRM(**model_kw)
    n_params = test_model.count_parameters()
    del test_model
    print(f" Model: {n_params:,} params (100K config)")

    # ── TRAIN ALL SEEDS ──────────────────────────────────────────────
    os.makedirs(model_dir, exist_ok=True)
    all_seed_aucs = {}
    all_fold_probs = {}
    all_fold_targets = {}

    for seed in SEEDS:
        print(f"\n {'─'*3} Seed {seed} {'─'*40}")
        t_seed = time.time()
        seed_aucs = {}

        for fi, (tv_i, te_i) in enumerate(folds):
            tri, vli = strat_split_cls(targets_all[tv_i], 0.15, seed + fi)
            feat.fit_scaler(X_all[tv_i][tri])

            tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
            tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
            vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
            vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
            te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
            te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)

            tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
            vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
            te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)

            torch.manual_seed(seed + fi)
            np.random.seed(seed + fi)
            if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)

            model = DeepHybridTRM(**model_kw).to(device)
            bv, model = train_fold(model, tr_dl, vl_dl, device,
                                   epochs=300, swa_start=200,
                                   fold=fi+1, seed=seed, label="44K")

            probs = predict_proba(model, te_dl)
            auc = roc_auc_score(te_y.cpu().numpy(), probs.numpy())
            seed_aucs[fi] = auc

            if fi not in all_fold_probs:
                all_fold_probs[fi] = {}
                all_fold_targets[fi] = te_y.cpu()
            all_fold_probs[fi][seed] = probs

            torch.save({
                'model_state': model.state_dict(),
                'test_auc': auc, 'fold': fi+1, 'seed': seed,
                'n_extra': n_extra,
            }, f'{model_dir}/{dataset_name}_100K_s{seed}_f{fi+1}.pt')

            del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
            if device.type == 'cuda': torch.cuda.empty_cache()

        avg_s = np.mean(list(seed_aucs.values()))
        all_seed_aucs[seed] = seed_aucs
        dt = time.time() - t_seed
        print(f"\n Seed {seed}: avg={avg_s:.4f} | "
              f"{[f'{seed_aucs[i]:.4f}' for i in range(5)]} ({dt:.0f}s)")

    # ── ENSEMBLE ─────────────────────────────────────────────────────
    ens_aucs = {}
    for fi in range(5):
        probs_stack = torch.stack([all_fold_probs[fi][s] for s in SEEDS])
        ens_prob = probs_stack.mean(dim=0)
        ens_aucs[fi] = roc_auc_score(
            all_fold_targets[fi].numpy(), ens_prob.numpy())

    single_avgs = [np.mean(list(all_seed_aucs[s].values())) for s in SEEDS]
    single_mean = np.mean(single_avgs)
    single_std = np.std(single_avgs)
    ens_mean = np.mean(list(ens_aucs.values()))
    ens_std = np.std(list(ens_aucs.values()))

    tt = time.time() - t0

    print(f"""
{'='*72}
 FINAL RESULTS — TRIADS on {dataset_name} (ROCAUC)
{'='*72}

 Per-seed results:""")
    for seed in SEEDS:
        sm = all_seed_aucs[seed]
        avg_s = np.mean(list(sm.values()))
        print(f" Seed {seed:>4}: {avg_s:.4f} | "
              f"{[f'{sm[i]:.4f}' for i in range(5)]}")

    print(f"""
 Single-seed avg: {single_mean:.4f} ± {single_std:.4f}
 5-Seed Ensemble: {ens_mean:.4f} ± {ens_std:.4f}
 Per-fold ens:    {[f'{ens_aucs[i]:.4f}' for i in range(5)]}

 {'Model':<40} {'ROCAUC':>10}
 {'─'*53}
 {baseline_name:<40} {baseline_auc:>10}
 {'TRIADS (44K, 5-seed ens)':<40} {f'{ens_mean:.4f}':>10} ← US
 {'─'*53}

 Total time: {tt/60:.1f} min
 Saved: {model_dir}/
""")

    summary = {
        'dataset': dataset_name,
        'task': 'classification',
        'metric': 'ROCAUC',
        'samples': len(comps_all),
        'class_balance': f'{n_pos} positive / {n_neg} negative',
        'model_config': MODEL_CFG,
        'params': n_params,
        'seeds': SEEDS,
        'fold_seed': MATBENCH_FOLD_SEED,
        'per_seed': {str(s): {str(k): round(v, 4) for k, v in m.items()}
                     for s, m in all_seed_aucs.items()},
        'single_seed_avg': round(single_mean, 4),
        'single_seed_std': round(single_std, 4),
        'ensemble_aucs': {str(k): round(v, 4) for k, v in ens_aucs.items()},
        'ensemble_avg': round(ens_mean, 4),
        'ensemble_std': round(ens_std, 4),
        'total_time_min': round(tt/60, 1),
    }
    with open(summary_file, 'w') as f:
        json.dump(summary, f, indent=2)
    print(f" Saved: {summary_file}")

    shutil.make_archive(model_dir, 'zip', '.', model_dir)
    print(f" Saved: {model_dir}.zip")

    return ens_mean

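The ENSEMBLE step above reduces to a plain mean over the per-seed probability vectors for each fold. A small NumPy sketch with random stand-in probabilities (shapes chosen for illustration only):

```python
import numpy as np

# probs_stack mirrors torch.stack([all_fold_probs[fi][s] for s in SEEDS]):
# shape (n_seeds, n_test); the ensemble prediction is the seed-wise mean.
rng = np.random.default_rng(0)
n_seeds, n_test = 5, 8
probs_stack = rng.uniform(0.0, 1.0, size=(n_seeds, n_test))
ens_prob = probs_stack.mean(axis=0)
```

The ensemble probability for each test sample always lies between the per-seed extremes, which is why seed averaging smooths out individual-seed noise without changing well-agreed predictions.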
# ======================================================================
# MAIN — RUN BOTH SEQUENTIALLY
# ======================================================================

if __name__ == '__main__':
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        gm = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f" GPU: {torch.cuda.get_device_name(0)} ({gm:.1f} GB)")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    print(f"""
╔══════════════════════════════════════════════════════════╗
║ TRIADS Classification Benchmarks                         ║
║ 44K model | 5-Seed Ensemble | BCEWithLogitsLoss          ║
║ Fold seed: {MATBENCH_FOLD_SEED} (matbench v0.1 standard) ║
╠══════════════════════════════════════════════════════════╣
║ 1. matbench_expt_is_metal (4,921 samples)                ║
║ 2. matbench_glass         (5,680 samples)                ║
╚══════════════════════════════════════════════════════════╝
""")

    t_total = time.time()
    results = {}

    # ── BENCHMARK 1: expt_is_metal ───────────────────────────────────
    print("\n" + "█"*72)
    print(" BENCHMARK 1/2: matbench_expt_is_metal")
    print("█"*72)

    auc1 = run_classification_benchmark(
        dataset_name="matbench_expt_is_metal",
        target_col="is_metal",
        featurizer_cls=MetallicityFeaturizer,
        model_dir="is_metal_models",
        summary_file="is_metal_summary.json",
        baseline_name="AMMExpress v2020",
        baseline_auc="0.9209",
        device=device,
    )
    results['is_metal'] = auc1

    # ── BENCHMARK 2: glass ───────────────────────────────────────────
    print("\n" + "█"*72)
    print(" BENCHMARK 2/2: matbench_glass")
    print("█"*72)

    auc2 = run_classification_benchmark(
        dataset_name="matbench_glass",
        target_col="gfa",
        featurizer_cls=GlassFeaturizer,
        model_dir="glass_models",
        summary_file="glass_summary.json",
        baseline_name="MODNet v0.1.12",
        baseline_auc="0.9603",
        device=device,
    )
    results['glass'] = auc2

    # ── COMBINED SUMMARY ─────────────────────────────────────────────
    tt = time.time() - t_total
    print(f"""
{'='*72}
 COMBINED RESULTS — ALL CLASSIFICATION BENCHMARKS
{'='*72}

 {'Dataset':<30} {'Baseline':>10} {'TRIADS':>10}
 {'─'*53}
 {'matbench_expt_is_metal':<30} {'0.9209':>10} {f'{auc1:.4f}':>10}
 {'matbench_glass':<30} {'0.9603':>10} {f'{auc2:.4f}':>10}
 {'─'*53}

 Grand total time: {tt/60:.1f} min ({tt/3600:.1f} hrs)

 ALL TRIADS BENCHMARKS:
 ─────────────────────
   steels:   91.20 MPa  (#1-2)
   expt_gap: 0.3068 eV  (#2)
   jdft2d:   35.89 meV/atom (#3)
   is_metal: {auc1:.4f} ROCAUC
   glass:    {auc2:.4f} ROCAUC
""")
model_code/expt_gap_model.py
ADDED
@@ -0,0 +1,579 @@
"""
+=============================================================+
|  TRIADS V3 on matbench_expt_gap                             |
|  2x T4 GPU Parallel Training (auto-fallback to 1 GPU)       |
|  4 Models: Steps(16,20) x Dropout(0.15,0.20)                |
|  Proven arch: d_attn=64, d_hidden=96 | batch_size=64        |
|  FastTensorDataLoader | Clean output                        |
+=============================================================+
"""

import os, copy, json, time, logging, warnings, urllib.request
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from pymatgen.core import Composition
from matminer.featurizers.composition import ElementProperty
from gensim.models import Word2Vec

logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
log = logging.getLogger("TRIADS-V3")

SEEDS = [42]
BATCH_SIZE = 64

BASELINES = {
    'Darwin': 0.2865,
    'Ax/SAASBO CrabNet': 0.3310,
    'MODNet v0.1.12': 0.3327,
    'AMMExpress v2020': 0.4161,
    'CrabNet': 0.4427,
    'RF-SCM/Magpie': 0.5205,
    'Dummy': 1.0280,
}
V1_BEST = {'EG-A (V1)': 0.3510, 'EG-B (V1)': 0.3616}

# Pin PyTorch threading to the available hardware
torch.set_num_threads(4)          # 4 vCPUs on Kaggle
torch.set_num_interop_threads(2)  # 2 physical cores

# ======================================================================
# FAST TENSOR DATALOADER
# ======================================================================

class FastTensorDataLoader:
    """Zero-CPU DataLoader. Entire dataset in GPU VRAM."""
    def __init__(self, *tensors, batch_size=64, shuffle=False):
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors
        self.dataset_len = tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_batches = (self.dataset_len + batch_size - 1) // batch_size

    def __iter__(self):
        if self.shuffle:
            idx = torch.randperm(self.dataset_len, device=self.tensors[0].device)
            self.tensors = tuple(t[idx] for t in self.tensors)
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches

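The batch bookkeeping in `FastTensorDataLoader` is ceiling division plus fixed-stride slicing. A plain-Python sketch (a list stands in for a GPU tensor; the slicing semantics are the same) shows the ragged final batch:

```python
# A list stands in for a tensor of 10 rows.
data = list(range(10))
batch_size = 4

# n_batches = ceil(dataset_len / batch_size), written with integer math
# exactly as in __init__.
n_batches = (len(data) + batch_size - 1) // batch_size

# __next__ slices [i : i + batch_size]; the last batch may be short.
batches = [data[i:i + batch_size] for i in range(0, len(data), batch_size)]
```

With 10 rows and batch size 4 this yields three batches of sizes 4, 4, and 2, matching what `__len__` reports.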
# ======================================================================
# FEATURIZER
# ======================================================================

class ExpandedFeaturizer:
    GCS = "https://storage.googleapis.com/mat2vec/"
    FILES = ["pretrained_embeddings",
             "pretrained_embeddings.wv.vectors.npy",
             "pretrained_embeddings.trainables.syn1neg.npy"]

    def __init__(self, cache="mat2vec_cache"):
        from matminer.featurizers.composition import (
            ElementFraction, Stoichiometry, ValenceOrbital,
            IonProperty, BandCenter
        )
        from matminer.featurizers.base import MultipleFeaturizer
        self.ep_magpie = ElementProperty.from_preset("magpie")
        self.n_mg = len(self.ep_magpie.feature_labels())
        self.extra_feats = MultipleFeaturizer([
            ElementFraction(), Stoichiometry(), ValenceOrbital(),
            IonProperty(), BandCenter(),
        ])
        self.n_extra = None
        self.scaler = None
        os.makedirs(cache, exist_ok=True)
        for f in self.FILES:
            p = os.path.join(cache, f)
            if not os.path.exists(p):
                log.info(f" Downloading {f}...")
                urllib.request.urlretrieve(self.GCS + f, p)
        self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
        self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}

    def _pool(self, c):
        v, t = np.zeros(200, np.float32), 0.0
        for s, f in c.get_el_amt_dict().items():
            if s in self.emb: v += f * self.emb[s]; t += f
        return v / max(t, 1e-8)

    def featurize_all(self, comps):
        out = []
        for c in tqdm(comps, desc=" Featurizing", leave=False):
            try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
            except Exception: mg = np.zeros(self.n_mg, np.float32)
            try: ex = np.array(self.extra_feats.featurize(c), np.float32)
            except Exception: ex = np.zeros(self.n_extra or 200, np.float32)
            if self.n_extra is None:
                self.n_extra = len(ex)
                log.info(f"Features: {self.n_mg} Magpie + {self.n_extra} Extra + 200 Mat2Vec")
            out.append(np.concatenate([
                np.nan_to_num(mg, nan=0.0),
                np.nan_to_num(ex, nan=0.0),
                self._pool(c)
            ]))
        return np.array(out)

    def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
    def transform(self, X):
        if not self.scaler: return X
        return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)

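`_pool` is a composition-weighted average of per-element Mat2Vec vectors. This NumPy sketch uses tiny hypothetical 2-d embeddings and an Fe2O3-like amount dict (the real model uses 200-d Mat2Vec vectors and pymatgen's `get_el_amt_dict`):

```python
import numpy as np

# Hypothetical 2-d element embeddings; the real ones are 200-d Mat2Vec.
emb = {"Fe": np.array([1.0, 0.0]), "O": np.array([0.0, 1.0])}
amounts = {"Fe": 2.0, "O": 3.0}  # stand-in for Composition("Fe2O3")

# v = sum_i f_i * e_i / sum_i f_i, over elements with a known embedding,
# exactly as in ExpandedFeaturizer._pool.
v, t = np.zeros(2), 0.0
for el, f in amounts.items():
    if el in emb:
        v += f * emb[el]
        t += f
pooled = v / max(t, 1e-8)
```

Elements without an embedding simply drop out of both the numerator and the normalizer, so the pooled vector stays a proper weighted mean of the elements that were found.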
# ======================================================================
# MODEL — DeepHybridTRM (V13A proven architecture)
# ======================================================================

class DeepHybridTRM(nn.Module):
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=20, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra

        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra, m2v = None, x[:, mg_dim:]

        tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)

        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))
        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))
        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)
        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)
        return self.pool(pooled)

    def forward(self, x, deep_supervision=False):
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)
        step_preds = []
        for s in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
        return step_preds if deep_supervision else step_preds[-1]

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

# ======================================================================
# LOSS + UTILS
# ======================================================================

def deep_supervision_loss(step_preds, targets):
    n = len(step_preds)
    weights = [(i+1) for i in range(n)]
    tw = sum(weights)
    return sum((w/tw) * F.l1_loss(p, targets) for p, w in zip(step_preds, weights))


def strat_split(targets, val_size=0.15, seed=42):
    bins = np.percentile(targets, [25, 50, 75])
    lbl = np.digitize(targets, bins)
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(lbl == b)[0]
        if len(m) == 0: continue
        n = max(1, int(len(m) * val_size))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)


def predict(model, dl):
    model.eval(); preds = []
    with torch.no_grad():
        for bx, _ in dl:
            preds.append(model(bx).cpu())
    return torch.cat(preds)

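`strat_split` stratifies a regression target by its quartiles so the validation set tracks the target distribution. Inlining the same logic on synthetic uniform targets (the target values here are made up) confirms a disjoint split with roughly 15% of each quartile held out:

```python
import numpy as np

# Synthetic stand-in for a regression target column.
rng = np.random.RandomState(42)
targets = rng.uniform(0.0, 5.0, size=400)
val_size = 0.15

# Bin at the 25/50/75th percentiles, as in strat_split.
bins = np.percentile(targets, [25, 50, 75])
lbl = np.digitize(targets, bins)

# Sample val_size of each quartile bin into validation.
tr, vl = [], []
for b in range(4):
    m = np.where(lbl == b)[0]
    if len(m) == 0:
        continue
    n = max(1, int(len(m) * val_size))
    c = rng.choice(m, n, replace=False)
    vl.extend(c.tolist())
    tr.extend(np.setdiff1d(m, c).tolist())
tr, vl = np.array(tr), np.array(vl)
```

Every index lands in exactly one of the two arrays, and each quartile contributes about 15 of its ~100 members to validation.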
# ======================================================================
# TRAINING — clean, simple, V1-style
# ======================================================================

def train_fold(model, tr_dl, vl_dl, device,
               epochs=300, swa_start=200, fold=1, name="", gpu_tag=""):
    opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = torch.optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
    swa_m = AveragedModel(model)
    swa_s = SWALR(opt, swa_lr=5e-4)
    swa_on = False
    best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
    hist = {'train': [], 'val': []}
    use_amp = (device.type == 'cuda')
    scaler = torch.amp.GradScaler('cuda', enabled=use_amp)

    pbar = tqdm(range(epochs), desc=f" {gpu_tag}[{name}] F{fold}/5",
                leave=False, ncols=120)
    for ep in pbar:
        model.train(); tl = 0.0
        for bx, by in tr_dl:
            with torch.amp.autocast('cuda', enabled=use_amp):
                sp = model(bx, deep_supervision=True)
                loss = deep_supervision_loss(sp, by)
            opt.zero_grad(set_to_none=True)
            scaler.scale(loss).backward()
            scaler.unscale_(opt)
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            scaler.step(opt)
            scaler.update()
            tl += F.l1_loss(sp[-1], by).item() * len(by)
        tl /= tr_dl.dataset_len

        model.eval(); vl = 0.0
        with torch.no_grad():
            with torch.amp.autocast('cuda', enabled=use_amp):
                for bx, by in vl_dl:
                    vl += F.l1_loss(model(bx), by).item() * len(by)
        vl /= vl_dl.dataset_len
        hist['train'].append(tl); hist['val'].append(vl)

        if ep < swa_start:
            sch.step()
            if vl < best_v:
                best_v = vl
                best_w = copy.deepcopy(model.state_dict())
        else:
            if not swa_on: swa_on = True
            swa_m.update_parameters(model); swa_s.step()

        pbar.set_postfix(Best=f'{best_v:.4f}', Ph='SWA' if swa_on else 'COS',
                         Tr=f'{tl:.4f}', Val=f'{vl:.4f}')

    if swa_on:
        update_bn(tr_dl, swa_m, device=device)
        model.load_state_dict(swa_m.module.state_dict())
    else:
        model.load_state_dict(best_w)
    return best_v, model, hist

# ======================================================================
|
| 337 |
+
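The running `tl` / `vl` accumulators in `train_fold` weight each batch's mean L1 loss by its batch size before dividing by the dataset length, so the epoch metric stays exact even when the last batch is ragged. A torch-free sketch of that bookkeeping (the helper name is illustrative, not from the script):

```python
def weighted_mean(batch_means, batch_sizes):
    # sum of (per-batch mean * batch size) recovers the per-sample total,
    # which is what accumulating `loss.item() * len(by)` does above
    total = sum(m * n for m, n in zip(batch_means, batch_sizes))
    return total / sum(batch_sizes)

# two full batches of 4 plus a ragged batch of 2 -> exact per-sample mean
assert abs(weighted_mean([1.0, 2.0, 5.0], [4, 4, 2]) - 2.2) < 1e-12
```

A naive mean of the three per-batch means (8/3 ≈ 2.67) would over-weight the short final batch.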
# GPU WORKER — trains assigned models on one GPU
|
| 338 |
+
# ======================================================================
|
| 339 |
+
|
| 340 |
+
def gpu_worker(gpu_id, config_list, X_all, targets_all, folds, n_extra,
|
| 341 |
+
result_file):
|
| 342 |
+
device = torch.device(f'cuda:{gpu_id}')
|
| 343 |
+
torch.cuda.set_device(gpu_id)
|
| 344 |
+
tag = f"[GPU{gpu_id}] "
|
| 345 |
+
|
| 346 |
+
print(f"\n {tag}Started on {torch.cuda.get_device_name(gpu_id)}")
|
| 347 |
+
print(f" {tag}Models: {[c[0] for c in config_list]}")
|
| 348 |
+
|
| 349 |
+
feat = ExpandedFeaturizer()
|
| 350 |
+
results = {}
|
| 351 |
+
|
| 352 |
+
for ci, (cname, model_kw) in enumerate(config_list):
|
| 353 |
+
print(f"\n {tag}{'='*50}")
|
| 354 |
+
print(f" {tag}[{ci+1}/{len(config_list)}] {cname}")
|
| 355 |
+
print(f" {tag}{'='*50}")
|
| 356 |
+
|
| 357 |
+
seed = SEEDS[0]
|
| 358 |
+
fold_maes = []
|
| 359 |
+
|
| 360 |
+
for fi, (tv_i, te_i) in enumerate(folds):
|
| 361 |
+
print(f"\n {tag}-- [{cname}] Fold {fi+1}/5 " + "-"*20)
|
| 362 |
+
|
| 363 |
+
tri, vli = strat_split(targets_all[tv_i], 0.15, seed + fi)
|
| 364 |
+
feat.fit_scaler(X_all[tv_i][tri])
|
| 365 |
+
|
| 366 |
+
tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
|
| 367 |
+
tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
|
| 368 |
+
vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
|
| 369 |
+
vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
|
| 370 |
+
te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
|
| 371 |
+
te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)
|
| 372 |
+
|
| 373 |
+
tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
|
| 374 |
+
vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
|
| 375 |
+
te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)
|
| 376 |
+
|
| 377 |
+
torch.manual_seed(seed + fi)
|
| 378 |
+
np.random.seed(seed + fi)
|
| 379 |
+
torch.cuda.manual_seed(seed + fi)
|
| 380 |
+
|
| 381 |
+
model = DeepHybridTRM(**model_kw).to(device)
|
| 382 |
+
if fi == 0:
|
| 383 |
+
print(f" {tag}Params: {model.count_parameters():,}")
|
| 384 |
+
|
| 385 |
+
bv, model, hist = train_fold(
|
| 386 |
+
model, tr_dl, vl_dl, device,
|
| 387 |
+
epochs=300, swa_start=200, fold=fi+1, name=cname, gpu_tag=tag)
|
| 388 |
+
|
| 389 |
+
pred = predict(model, te_dl)
|
| 390 |
+
mae = F.l1_loss(pred, te_y.cpu()).item()
|
| 391 |
+
print(f" {tag}Fold {fi+1} TEST: {mae:.4f} eV (val best: {bv:.4f})")
|
| 392 |
+
|
| 393 |
+
fold_maes.append(mae)
|
| 394 |
+
os.makedirs('expt_gap_models_v3', exist_ok=True)
|
| 395 |
+
torch.save({
|
| 396 |
+
'model_state': model.state_dict(),
|
| 397 |
+
'test_mae': mae, 'config': cname, 'seed': seed,
|
| 398 |
+
'fold': fi+1, 'n_extra': n_extra,
|
| 399 |
+
}, f'expt_gap_models_v3/{cname}_s{seed}_f{fi+1}.pt')
|
| 400 |
+
|
| 401 |
+
del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
|
| 402 |
+
torch.cuda.empty_cache()
|
| 403 |
+
|
| 404 |
+
avg = float(np.mean(fold_maes))
|
| 405 |
+
std = float(np.std(fold_maes))
|
| 406 |
+
results[cname] = {'avg': avg, 'std': std, 'folds': fold_maes}
|
| 407 |
+
|
| 408 |
+
print(f"\n {tag}=== {cname} ===")
|
| 409 |
+
print(f" {tag} 5-Fold Avg MAE: {avg:.4f} +/- {std:.4f} eV")
|
| 410 |
+
print(f" {tag} Per-fold: {[f'{m:.4f}' for m in fold_maes]}")
|
| 411 |
+
|
| 412 |
+
with open(result_file, 'w') as f:
|
| 413 |
+
json.dump(results, f)
|
| 414 |
+
print(f"\n {tag}DONE. Saved to {result_file}")
|
| 415 |
+
|
| 416 |
+
|
| 417 |
+
# ======================================================================
# MAIN
# ======================================================================

def run_benchmark():
    t0 = time.time()

    print(f"""
    +==========================================================+
    |  TRIADS V3 -- P100 | FastTensorDataLoader                |
    |  4 Models: Steps(16,20) x Dropout(0.15,0.20)             |
    |  d_attn=64, d_hidden=96 (proven V1 arch)                 |
    |  batch_size={BATCH_SIZE} | All CPU cores active          |
    +==========================================================+
    """)

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        try: gm = torch.cuda.get_device_properties(0).total_memory / 1e9
        except: gm = 0
        print(f" GPU: {torch.cuda.get_device_name(0)} ({gm:.1f} GB)")
        print(f" CPU threads: {torch.get_num_threads()} | Interop: {torch.get_num_interop_threads()}")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    # ---- LOAD + FEATURIZE ----
    print("\n Loading matbench_expt_gap...")
    from matminer.datasets import load_dataset
    df = load_dataset("matbench_expt_gap")
    targets_all = np.array(df['gap expt'].tolist(), np.float32)
    comps_all = [Composition(c) for c in df['composition'].tolist()]
    print(f" Dataset: {len(comps_all)} samples")

    feat = ExpandedFeaturizer()
    X_all = feat.featurize_all(comps_all)
    n_extra = feat.n_extra
    print(f" Features: {X_all.shape}")

    kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
    folds = list(kfold.split(comps_all))
    for fi, (tv, te) in enumerate(folds):
        assert len(set(tv) & set(te)) == 0
    print(" 5 folds verified: zero leakage")

    # ---- CONFIGS ----
    base = dict(n_props=22, stat_dim=6, n_extra=n_extra, mat2vec_dim=200,
                d_attn=64, nhead=4, d_hidden=96, ff_dim=150)

    all_configs = [
        ('V3-S16-D15', {**base, 'max_steps': 16, 'dropout': 0.15}),
        ('V3-S16-D20', {**base, 'max_steps': 16, 'dropout': 0.20}),
        ('V3-S20-D15', {**base, 'max_steps': 20, 'dropout': 0.15}),
        ('V3-S20-D20', {**base, 'max_steps': 20, 'dropout': 0.20}),
    ]

    print(f"\n {'Config':<16} {'Params':>10} {'Steps':>6} {'Drop':>6}")
    for cn, kw in all_configs:
        m = DeepHybridTRM(**kw)
        print(f" {cn:<16} {m.count_parameters():>10,} {kw['max_steps']:>6} {kw['dropout']:>6.2f}")
        del m

    # ---- TRAIN ----
    all_results = {}

    for ci, (cname, model_kw) in enumerate(all_configs):
        print(f"\n {'='*60}")
        print(f" [{ci+1}/4] {cname}")
        print(f" {'='*60}")

        seed = SEEDS[0]
        fold_maes = []

        for fi, (tv_i, te_i) in enumerate(folds):
            print(f"\n -- [{cname}] Fold {fi+1}/5 " + "-"*30)
            tri, vli = strat_split(targets_all[tv_i], 0.15, seed + fi)
            feat.fit_scaler(X_all[tv_i][tri])

            tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
            tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
            vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
            vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
            te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
            te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)

            tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
            vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
            te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)

            torch.manual_seed(seed + fi); np.random.seed(seed + fi)
            if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)

            model = DeepHybridTRM(**model_kw).to(device)
            if fi == 0: print(f" Params: {model.count_parameters():,}")

            bv, model, hist = train_fold(model, tr_dl, vl_dl, device,
                                         epochs=300, swa_start=200, fold=fi+1, name=cname)

            pred = predict(model, te_dl)
            mae = F.l1_loss(pred, te_y.cpu()).item()
            print(f" Fold {fi+1} TEST: {mae:.4f} eV (val: {bv:.4f})")
            fold_maes.append(mae)

            os.makedirs('expt_gap_models_v3', exist_ok=True)
            torch.save({
                'model_state': model.state_dict(),
                'test_mae': mae, 'config': cname, 'seed': seed,
                'fold': fi+1, 'n_extra': n_extra,
            }, f'expt_gap_models_v3/{cname}_s{seed}_f{fi+1}.pt')

            del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
            if device.type == 'cuda': torch.cuda.empty_cache()

        avg = float(np.mean(fold_maes))
        std = float(np.std(fold_maes))
        all_results[cname] = {'avg': avg, 'std': std, 'folds': fold_maes}
        print(f"\n === {cname}: {avg:.4f} +/- {std:.4f} eV ===")

    # ======== FINAL RESULTS ========
    tt = time.time() - t0
    print(f"\n{'='*72}")
    print(f" FINAL LEADERBOARD -- TRIADS V3 (5-Fold Avg MAE, eV)")
    print(f"{'='*72}")
    print(f" {'Model':<20} {'MAE':>10} {'Std':>8} Notes")
    print(f" {'-'*60}")

    for n, r in sorted(all_results.items(), key=lambda x: x[1]['avg']):
        tag = (" <-- DARWIN BEATEN!" if r['avg'] < 0.2865 else
               " <-- Top 3!" if r['avg'] < 0.3327 else
               " <-- Beats V1!" if r['avg'] < 0.3510 else
               " <-- Beats AMMExp" if r['avg'] < 0.4161 else "")
        print(f" {n:<20} {r['avg']:>10.4f} {r['std']:>8.4f}{tag}")

    print(f" {'-'*60}")
    for vn, vm in sorted(V1_BEST.items(), key=lambda x: x[1]):
        print(f" {vn:<20} {vm:>10.4f} (V1)")
    for bn, bv in sorted(BASELINES.items(), key=lambda x: x[1]):
        print(f" {bn:<20} {bv:>10.4f}")

    # Per-fold
    names = sorted(all_results.keys())
    print(f"\n PER-FOLD:")
    hdr = f" {'Fold':<6}" + "".join(f" {cn:>14}" for cn in names)
    print(hdr)
    for fi in range(5):
        row = f" F{fi+1:<5}" + "".join(f" {all_results[cn]['folds'][fi]:>14.4f}" for cn in names)
        print(row)

    print(f"\n HP GRID: {'D=0.15':>10} {'D=0.20':>10}")
    for s in [16, 20]:
        d15 = all_results.get(f'V3-S{s}-D15', {}).get('avg', 0)
        d20 = all_results.get(f'V3-S{s}-D20', {}).get('avg', 0)
        print(f" S={s:>2} {d15:>10.4f} {d20:>10.4f}")

    print(f"\n Total: {tt/60:.1f} min")

    s = {'version': 'EG-V3', 'batch_size': BATCH_SIZE,
         'total_min': round(tt/60, 1), 'models': all_results,
         'baselines': BASELINES, 'v1': V1_BEST}
    with open('expt_gap_summary_v3.json', 'w') as f:
        json.dump(s, f, indent=2)
    print(" Saved: expt_gap_summary_v3.json")


if __name__ == '__main__':
    run_benchmark()
model_code/jdft2d_model.py
ADDED
@@ -0,0 +1,589 @@
"""
|
| 2 |
+
+=============================================================+
|
| 3 |
+
| TRIADS V4 on matbench_jdft2d — 5-Seed Ensemble |
|
| 4 |
+
| Exfoliation Energy (meV/atom) — 636 samples |
|
| 5 |
+
| |
|
| 6 |
+
| Structural + Composition features (~361d) |
|
| 7 |
+
| 75K model (d_attn=32, d_hidden=64) | dropout=0.20 |
|
| 8 |
+
| Seeds: [42, 123, 456, 789, 1024] |
|
| 9 |
+
| Target: Kaggle P100 | ~30 min |
|
| 10 |
+
+=============================================================+
|
| 11 |
+
"""
|
| 12 |
+
|
| 13 |
+
import os, copy, json, time, logging, warnings, urllib.request, shutil
|
| 14 |
+
warnings.filterwarnings('ignore')
|
| 15 |
+
|
| 16 |
+
import numpy as np
|
| 17 |
+
import pandas as pd
|
| 18 |
+
from tqdm import tqdm
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import torch.nn as nn
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
|
| 24 |
+
|
| 25 |
+
from sklearn.model_selection import KFold
|
| 26 |
+
from sklearn.preprocessing import StandardScaler
|
| 27 |
+
from pymatgen.core import Composition
|
| 28 |
+
from pymatgen.symmetry.analyzer import SpacegroupAnalyzer
|
| 29 |
+
from matminer.featurizers.composition import ElementProperty
|
| 30 |
+
from gensim.models import Word2Vec
|
| 31 |
+
|
| 32 |
+
logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
|
| 33 |
+
log = logging.getLogger("TRIADS-jdft2d")
|
| 34 |
+
|
| 35 |
+
BATCH_SIZE = 64
|
| 36 |
+
SEEDS = [42, 123, 456, 789, 1024]
|
| 37 |
+
|
| 38 |
+
# 75K config — best for 636 samples
|
| 39 |
+
MODEL_CFG = dict(
|
| 40 |
+
d_attn=32, nhead=4, d_hidden=64, ff_dim=96,
|
| 41 |
+
dropout=0.20, max_steps=16,
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
V1_BEST = {'V1 (100K, comp-only)': 45.8045}
|
| 45 |
+
V2_BEST = {'V2 (44K, comp-only)': 46.5889}
|
| 46 |
+
V3_BEST = {'V3 (75K, +struct, single)': 37.0033}
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
# ======================================================================
# FAST TENSOR DATALOADER
# ======================================================================

class FastTensorDataLoader:
    def __init__(self, *tensors, batch_size=64, shuffle=False):
        assert all(t.shape[0] == tensors[0].shape[0] for t in tensors)
        self.tensors = tensors
        self.dataset_len = tensors[0].shape[0]
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.n_batches = (self.dataset_len + batch_size - 1) // batch_size

    def __iter__(self):
        if self.shuffle:
            idx = torch.randperm(self.dataset_len, device=self.tensors[0].device)
            self.tensors = tuple(t[idx] for t in self.tensors)
        self.i = 0
        return self

    def __next__(self):
        if self.i >= self.dataset_len:
            raise StopIteration
        batch = tuple(t[self.i:self.i + self.batch_size] for t in self.tensors)
        self.i += self.batch_size
        return batch

    def __len__(self):
        return self.n_batches

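FastTensorDataLoader keeps whole tensors staged on one device and yields contiguous slices, skipping `torch.utils.data`'s per-sample collate path. A torch-free toy of the same batching contract (the class name here is made up for illustration): batches are plain slices, `__len__` is a ceiling division, and the ragged final batch is still yielded.

```python
class TinyLoader:
    """List-backed stand-in for FastTensorDataLoader's slicing logic."""
    def __init__(self, data, batch_size=4):
        self.data, self.bs = data, batch_size

    def __iter__(self):
        # contiguous slices; the last one may be shorter than batch_size
        for i in range(0, len(self.data), self.bs):
            yield self.data[i:i + self.bs]

    def __len__(self):
        # ceil(len / batch_size) without floats
        return (len(self.data) + self.bs - 1) // self.bs

dl = TinyLoader(list(range(10)), batch_size=4)
batches = list(dl)
assert len(dl) == 3 and batches[-1] == [8, 9]
```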
# ======================================================================
# FEATURIZER — Composition + Structural (~361d)
# ======================================================================

def _extract_structural_features(structure):
    feats = []
    try:
        lat = structure.lattice
        feats.extend([lat.a, lat.b, lat.c, lat.alpha, lat.beta, lat.gamma])
        feats.append(structure.volume / max(len(structure), 1))
        feats.append(structure.density)
        feats.append(float(len(structure)))
        try:
            sga = SpacegroupAnalyzer(structure, symprec=0.1)
            feats.append(float(sga.get_space_group_number()))
        except:
            feats.append(0.0)
        try:
            total_vol = sum(
                (4/3) * np.pi * site.specie.atomic_radius**3
                for site in structure if hasattr(site.specie, 'atomic_radius')
                and site.specie.atomic_radius is not None
            )
            feats.append(total_vol / structure.volume if structure.volume > 0 else 0.0)
        except:
            feats.append(0.0)
    except:
        feats = [0.0] * 11
    return np.array(feats, dtype=np.float32)

class ExfoliationFeaturizer:
    GCS = "https://storage.googleapis.com/mat2vec/"
    FILES = ["pretrained_embeddings",
             "pretrained_embeddings.wv.vectors.npy",
             "pretrained_embeddings.trainables.syn1neg.npy"]

    def __init__(self, cache="mat2vec_cache"):
        from matminer.featurizers.composition import (
            Stoichiometry, ValenceOrbital, IonProperty
        )
        from matminer.featurizers.composition.element import TMetalFraction

        self.ep_magpie = ElementProperty.from_preset("magpie")
        self.n_mg = len(self.ep_magpie.feature_labels())

        self.extra_featurizers = [
            ("Stoichiometry", Stoichiometry()),
            ("ValenceOrbital", ValenceOrbital()),
            ("IonProperty", IonProperty()),
            ("TMetalFraction", TMetalFraction()),
        ]

        self._extra_sizes = {}
        for name, ftzr in self.extra_featurizers:
            try: self._extra_sizes[name] = len(ftzr.feature_labels())
            except: self._extra_sizes[name] = None

        self.n_extra = None
        self.scaler = None

        os.makedirs(cache, exist_ok=True)
        for f in self.FILES:
            p = os.path.join(cache, f)
            if not os.path.exists(p):
                log.info(f" Downloading {f}...")
                urllib.request.urlretrieve(self.GCS + f, p)
        self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
        self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}

    def _pool(self, c):
        v, t = np.zeros(200, np.float32), 0.0
        for s, f in c.get_el_amt_dict().items():
            if s in self.emb: v += f * self.emb[s]; t += f
        return v / max(t, 1e-8)

    def _featurize_extra(self, comp, structure=None):
        parts = []
        for name, ftzr in self.extra_featurizers:
            try:
                vals = np.array(ftzr.featurize(comp), np.float32)
                parts.append(np.nan_to_num(vals, nan=0.0))
                if self._extra_sizes.get(name) is None:
                    self._extra_sizes[name] = len(vals)
            except:
                sz = self._extra_sizes.get(name, 0) or 1
                parts.append(np.zeros(sz, np.float32))
        if structure is not None:
            parts.append(_extract_structural_features(structure))
        else:
            parts.append(np.zeros(11, np.float32))
        return np.concatenate(parts)

    def featurize_all(self, comps, structures=None):
        out = []
        test_struct = structures[0] if structures else None
        test_ex = self._featurize_extra(comps[0], test_struct)
        self.n_extra = len(test_ex)
        total = self.n_mg + self.n_extra + 200
        comp_extras = sum(self._extra_sizes.get(n, 0) or 0
                          for n, _ in self.extra_featurizers)
        log.info(f"Features: {self.n_mg} Magpie + {comp_extras} CompExtra + "
                 f"11 Structural + 200 Mat2Vec = {total}d")
        for i, c in enumerate(tqdm(comps, desc=" Featurizing", leave=False)):
            struct = structures[i] if structures else None
            try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
            except: mg = np.zeros(self.n_mg, np.float32)
            ex = self._featurize_extra(c, struct)
            out.append(np.concatenate([
                np.nan_to_num(mg, nan=0.0),
                np.nan_to_num(ex, nan=0.0),
                self._pool(c)
            ]))
        return np.array(out)

    def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)

    def transform(self, X):
        if not self.scaler: return X
        return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)

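`_pool` collapses a formula into one fixed-size vector: a fraction-weighted mean of per-element mat2vec embeddings, normalized by the total fraction actually covered. A dependency-free toy of that reduction with 2-d stand-in embeddings (the real ones are 200-d):

```python
def pool(el_amt, emb, dim=2):
    # fraction-weighted mean over elements that have an embedding,
    # mirroring ExfoliationFeaturizer._pool above
    v, t = [0.0] * dim, 0.0
    for el, frac in el_amt.items():
        if el in emb:
            v = [vi + frac * ei for vi, ei in zip(v, emb[el])]
            t += frac
    return [vi / max(t, 1e-8) for vi in v]

toy_emb = {'Mo': [1.0, 0.0], 'S': [0.0, 1.0]}   # toy 2-d embeddings
p = pool({'Mo': 1, 'S': 2}, toy_emb)            # MoS2 -> [1/3, 2/3]
assert abs(p[0] - 1/3) < 1e-9 and abs(p[1] - 2/3) < 1e-9
```

The `max(t, 1e-8)` guard means a composition with no known elements pools to (near) zero instead of dividing by zero.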
# ======================================================================
# MODEL
# ======================================================================

class DeepHybridTRM(nn.Module):
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=32, nhead=4, d_hidden=64, ff_dim=96,
                 dropout=0.15, max_steps=16, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra

        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra, m2v = None, x[:, mg_dim:]

        tok = self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)

        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))
        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))
        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)
        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)
        return self.pool(pooled)

    def forward(self, x, deep_supervision=False):
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)
        step_preds = []
        for s in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
        return step_preds if deep_supervision else step_preds[-1]

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

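The recurrence in `forward` alternates two residual updates — `z` reads `(xp, y, z)`, then `y` reads `(y, z)` — and the head emits a prediction at every step, which is what `deep_supervision_loss` consumes. A scalar caricature of that control flow (the update rules here are arbitrary stand-ins for the learned `z_up`/`y_up` blocks, chosen only so the toy converges):

```python
def trm_steps(x, max_steps=4):
    z = y = 0.0
    preds = []
    for _ in range(max_steps):
        z = z + 0.5 * (x + y - z)   # stand-in for z = z + z_up([xp, y, z])
        y = y + 0.5 * (z - y)       # stand-in for y = y + y_up([y, z])
        preds.append(y)             # head(y) recorded at every step
    return preds

preds = trm_steps(1.0)
assert len(preds) == 4             # one prediction per recurrence step
assert preds == sorted(preds)      # refinement is monotone in this toy
```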
# ======================================================================
# LOSS + UTILS
# ======================================================================

def deep_supervision_loss(step_preds, targets):
    preds = torch.stack(step_preds)
    n = preds.shape[0]
    w = torch.arange(1, n + 1, device=preds.device, dtype=preds.dtype)
    w = w / w.sum()
    per_step = (preds - targets.unsqueeze(0)).abs().mean(dim=1)
    return (w * per_step).sum()

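The weights in `deep_supervision_loss` are just a normalized ramp `1..n`, so later recurrence steps dominate the loss while early steps still get a training signal. Checking the n=4 case without torch:

```python
def step_weights(n):
    # same weighting as deep_supervision_loss: arange(1, n+1) / its sum
    w = list(range(1, n + 1))
    s = sum(w)
    return [x / s for x in w]

w = step_weights(4)
assert w == [0.1, 0.2, 0.3, 0.4]     # 1..4 divided by 10
assert abs(sum(w) - 1.0) < 1e-12     # weights form a distribution
```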

def strat_split(targets, val_size=0.15, seed=42):
    bins = np.percentile(targets, [25, 50, 75])
    lbl = np.digitize(targets, bins)
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(lbl == b)[0]
        if len(m) == 0: continue
        n = max(1, int(len(m) * val_size))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)


@torch.inference_mode()
def predict(model, dl):
    model.eval()
    preds = []
    for bx, _ in dl:
        preds.append(model(bx).cpu())
    return torch.cat(preds)

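`strat_split` stratifies the continuous target by quartile: `np.digitize` against the 25/50/75th-percentile cut points assigns each sample a bin 0..3, and the validation fraction is then drawn from every bin, so even a 15% split covers the full target range. A numpy-free check of the binning step (the cut points below are the interpolated percentiles of 0..99, precomputed by hand):

```python
def quartile_label(x, cuts):
    # count of cut points at or below x == np.digitize(x, cuts) for scalars
    return sum(x >= c for c in cuts)

targets = list(range(100))
cuts = [24.75, 49.5, 74.25]   # 25/50/75th percentiles of 0..99
labels = [quartile_label(x, cuts) for x in targets]
assert [labels.count(b) for b in range(4)] == [25, 25, 25, 25]
```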
# ======================================================================
|
| 330 |
+
# TRAINING
|
| 331 |
+
# ======================================================================
|
| 332 |
+
|
| 333 |
+
def train_fold(model, tr_dl, vl_dl, device,
|
| 334 |
+
epochs=300, swa_start=200, fold=1, seed=42):
|
| 335 |
+
opt = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
|
| 336 |
+
sch = torch.optim.lr_scheduler.CosineAnnealingLR(
|
| 337 |
+
opt, T_max=swa_start, eta_min=1e-4)
|
| 338 |
+
swa_m = AveragedModel(model)
|
| 339 |
+
swa_s = SWALR(opt, swa_lr=5e-4)
|
| 340 |
+
swa_on = False
|
| 341 |
+
best_v, best_w = float('inf'), None
|
| 342 |
+
|
| 343 |
+
pbar = tqdm(range(epochs), desc=f" [75K|s{seed}] F{fold}/5",
|
| 344 |
+
leave=False, ncols=120)
|
| 345 |
+
for ep in pbar:
|
| 346 |
+
model.train()
|
| 347 |
+
epoch_loss = torch.tensor(0.0, device=device)
|
| 348 |
+
n_samples = 0
|
| 349 |
+
|
| 350 |
+
for bx, by in tr_dl:
|
| 351 |
+
sp = model(bx, deep_supervision=True)
|
| 352 |
+
loss = deep_supervision_loss(sp, by)
|
| 353 |
+
opt.zero_grad(set_to_none=True)
|
| 354 |
+
loss.backward()
|
| 355 |
+
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
|
| 356 |
+
opt.step()
|
| 357 |
+
with torch.no_grad():
|
| 358 |
+
epoch_loss += (sp[-1] - by).abs().sum()
|
| 359 |
+
n_samples += len(by)
|
| 360 |
+
|
| 361 |
+
model.eval()
|
| 362 |
+
val_loss = torch.tensor(0.0, device=device)
|
| 363 |
+
val_n = 0
|
| 364 |
+
with torch.inference_mode():
|
| 365 |
+
for bx, by in vl_dl:
|
| 366 |
+
val_loss += (model(bx) - by).abs().sum()
|
| 367 |
+
val_n += len(by)
|
| 368 |
+
|
| 369 |
+
tl = epoch_loss.item() / n_samples
|
| 370 |
+
vl = val_loss.item() / val_n
|
| 371 |
+
|
| 372 |
+
if ep < swa_start:
|
| 373 |
+
sch.step()
|
| 374 |
+
if vl < best_v:
|
| 375 |
+
best_v = vl
|
| 376 |
+
best_w = copy.deepcopy(model.state_dict())
|
| 377 |
+
else:
|
| 378 |
+
if not swa_on: swa_on = True
|
| 379 |
+
swa_m.update_parameters(model); swa_s.step()
|
| 380 |
+
|
| 381 |
+
if ep % 10 == 0 or ep == epochs - 1:
|
| 382 |
+
pbar.set_postfix(Best=f'{best_v:.2f}', Ph='SWA' if swa_on else 'COS',
|
| 383 |
+
Tr=f'{tl:.2f}', Val=f'{vl:.2f}')
|
| 384 |
+
|
| 385 |
+
if swa_on:
|
| 386 |
+
update_bn(tr_dl, swa_m, device=device)
|
| 387 |
+
model.load_state_dict(swa_m.module.state_dict())
|
| 388 |
+
else:
|
| 389 |
+
model.load_state_dict(best_w)
|
| 390 |
+
return best_v, model
|
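The two-phase schedule in `train_fold` can be seen without any model: the cosine phase keeps the single best checkpoint by validation loss, while the SWA phase keeps a running mean of every post-switch snapshot (which is what `AveragedModel.update_parameters` maintains). A minimal sketch with scalar stand-in "weights" (the function name `run_schedule` and the min-as-best stand-in are illustrative, not from this repo):

```python
def run_schedule(weight_snapshots, swa_start):
    """Sketch of train_fold's checkpoint logic with scalar weights."""
    best_w = None            # cosine phase: best snapshot so far
    swa_avg, swa_n = 0.0, 0  # SWA phase: running mean of snapshots
    for ep, w in enumerate(weight_snapshots):
        if ep < swa_start:
            # stand-in for "lowest validation loss wins"
            best_w = w if best_w is None else min(best_w, w)
        else:
            swa_n += 1
            swa_avg += (w - swa_avg) / swa_n  # incremental mean
    # Mirrors the end of train_fold: use the SWA average if it ran,
    # otherwise fall back to the best cosine-phase checkpoint.
    return swa_avg if swa_n > 0 else best_w

final = run_schedule([5.0, 3.0, 4.0, 2.0, 4.0, 6.0], swa_start=3)
print(final)  # mean of [2.0, 4.0, 6.0] = 4.0
```

Note that once the SWA phase starts at `swa_start`, every subsequent epoch contributes to the average, so a late noisy epoch is diluted rather than discarded.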


# ======================================================================
# MAIN — 5-SEED ENSEMBLE
# ======================================================================

def run_benchmark():
    t0 = time.time()

    print(f"""
+==========================================================+
|   TRIADS V4 — matbench_jdft2d (5-Seed Ensemble)          |
|   Structural + Composition features (~361d)              |
|   75K model | dropout=0.20                               |
|   Seeds: {SEEDS}                                         |
+==========================================================+
""")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        gm = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f" GPU: {torch.cuda.get_device_name(0)} ({gm:.1f} GB)")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    # ── LOAD DATASET ──────────────────────────────────────────────────
    print("\n Loading matbench_jdft2d...")
    from matminer.datasets import load_dataset
    df = load_dataset("matbench_jdft2d")
    targets_all = np.array(df['exfoliation_en'].tolist(), np.float32)
    structures_all = df['structure'].tolist()
    comps_all = [s.composition for s in structures_all]
    print(f" Dataset: {len(comps_all)} samples")

    # ── FEATURIZE (once) ─────────────────────────────────────────────
    t_feat = time.time()
    feat = ExfoliationFeaturizer()
    X_all = feat.featurize_all(comps_all, structures_all)
    n_extra = feat.n_extra
    print(f" Features: {X_all.shape} (n_extra={n_extra})")
    print(f" Featurization: {time.time()-t_feat:.1f}s")

    # ── FOLDS ────────────────────────────────────────────────────────
    kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
    folds = list(kfold.split(comps_all))
    for fi, (tv, te) in enumerate(folds):
        assert len(set(tv) & set(te)) == 0
    print(" 5 folds verified: zero leakage\n")

    # ── MODEL INFO ───────────────────────────────────────────────────
    model_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
                    mat2vec_dim=200, **MODEL_CFG)
    test_model = DeepHybridTRM(**model_kw)
    n_params = test_model.count_parameters()
    del test_model
    print(f" Model: {n_params:,} params")
    print(f" Config: d_attn={MODEL_CFG['d_attn']}, d_hidden={MODEL_CFG['d_hidden']}, "
          f"ff_dim={MODEL_CFG['ff_dim']}, dropout={MODEL_CFG['dropout']}\n")

    # ── TRAIN ALL SEEDS ──────────────────────────────────────────────
    model_dir = 'jdft2d_models_v4'
    os.makedirs(model_dir, exist_ok=True)

    # Store predictions and MAEs per seed
    all_seed_maes = {}     # {seed: {fold: mae}}
    all_fold_preds = {}    # {fold: {seed: predictions}}
    all_fold_targets = {}  # {fold: targets}

    for seed in SEEDS:
        print(f"\n {'─'*3} Seed {seed} {'─'*40}")
        t_seed = time.time()
        seed_maes = {}

        for fi, (tv_i, te_i) in enumerate(folds):
            tri, vli = strat_split(targets_all[tv_i], 0.15, seed + fi)
            feat.fit_scaler(X_all[tv_i][tri])

            tr_x = torch.tensor(feat.transform(X_all[tv_i][tri]), dtype=torch.float32).to(device)
            tr_y = torch.tensor(targets_all[tv_i][tri], dtype=torch.float32).to(device)
            vl_x = torch.tensor(feat.transform(X_all[tv_i][vli]), dtype=torch.float32).to(device)
            vl_y = torch.tensor(targets_all[tv_i][vli], dtype=torch.float32).to(device)
            te_x = torch.tensor(feat.transform(X_all[te_i]), dtype=torch.float32).to(device)
            te_y = torch.tensor(targets_all[te_i], dtype=torch.float32).to(device)

            tr_dl = FastTensorDataLoader(tr_x, tr_y, batch_size=BATCH_SIZE, shuffle=True)
            vl_dl = FastTensorDataLoader(vl_x, vl_y, batch_size=BATCH_SIZE, shuffle=False)
            te_dl = FastTensorDataLoader(te_x, te_y, batch_size=BATCH_SIZE, shuffle=False)

            torch.manual_seed(seed + fi)
            np.random.seed(seed + fi)
            if device.type == 'cuda':
                torch.cuda.manual_seed(seed + fi)

            model = DeepHybridTRM(**model_kw).to(device)
            bv, model = train_fold(model, tr_dl, vl_dl, device,
                                   epochs=300, swa_start=200,
                                   fold=fi+1, seed=seed)

            pred = predict(model, te_dl)
            mae = F.l1_loss(pred, te_y.cpu()).item()
            seed_maes[fi] = mae

            # Store for ensemble
            if fi not in all_fold_preds:
                all_fold_preds[fi] = {}
                all_fold_targets[fi] = te_y.cpu()
            all_fold_preds[fi][seed] = pred

            torch.save({
                'model_state': model.state_dict(),
                'test_mae': mae, 'fold': fi+1, 'seed': seed,
                'n_extra': n_extra,
            }, f'{model_dir}/jdft2d_75K_s{seed}_f{fi+1}.pt')

            del model, tr_x, tr_y, vl_x, vl_y, te_x, te_y
            if device.type == 'cuda':
                torch.cuda.empty_cache()

        avg_s = np.mean(list(seed_maes.values()))
        all_seed_maes[seed] = seed_maes
        dt = time.time() - t_seed
        print(f"\n Seed {seed}: avg={avg_s:.4f} | "
              f"{[f'{seed_maes[i]:.4f}' for i in range(5)]} ({dt:.0f}s)")

    # ── ENSEMBLE ─────────────────────────────────────────────────────
    ens_maes = {}
    for fi in range(5):
        preds_stack = torch.stack([all_fold_preds[fi][s] for s in SEEDS])
        ens_pred = preds_stack.mean(dim=0)
        ens_maes[fi] = F.l1_loss(ens_pred, all_fold_targets[fi]).item()

    single_avgs = [np.mean(list(all_seed_maes[s].values())) for s in SEEDS]
    single_mean = np.mean(single_avgs)
    single_std = np.std(single_avgs)
    ens_mean = np.mean(list(ens_maes.values()))
    ens_std = np.std(list(ens_maes.values()))
    ens_drop = (1 - ens_mean / single_mean) * 100

    # ── RESULTS ──────────────────────────────────────────────────────
    tt = time.time() - t0

    print(f"""
{'='*72}
 FINAL RESULTS — TRIADS V4 on matbench_jdft2d
{'='*72}

 Per-seed results:""")

    for seed in SEEDS:
        sm = all_seed_maes[seed]
        avg_s = np.mean(list(sm.values()))
        print(f"   Seed {seed:>4}: {avg_s:.4f} | "
              f"{[f'{sm[i]:.4f}' for i in range(5)]}")

    print(f"""
 Single-seed avg: {single_mean:.4f} ± {single_std:.4f}
 5-Seed Ensemble: {ens_mean:.4f} ± {ens_std:.4f} (↓{ens_drop:.1f}% from single)
 Per-fold ens:    {[f'{ens_maes[i]:.4f}' for i in range(5)]}

 {'Model':<40} {'MAE(meV/atom)':>15}
 {'─'*58}
 {'MODNet v0.1.12':<40} {'33.1918':>15}
 {'TRIADS V3 (75K, +struct, single)':<40} {'37.0033':>15}
 {'TRIADS V4 (75K, +struct, 5-seed ens)':<40} {f'{ens_mean:.4f}':>15}  ← NEW
 {'TRIADS V1 (100K, comp-only)':<40} {'45.8045':>15}
 {'─'*58}

 Total time: {tt/60:.1f} min
 Saved: {model_dir}/
""")

    # ── SAVE ─────────────────────────────────────────────────────────
    summary = {
        'version': 'jdft2d-V4-ensemble',
        'dataset': 'matbench_jdft2d',
        'samples': len(comps_all),
        'target_unit': 'meV/atom',
        'model_config': MODEL_CFG,
        'params': n_params,
        'seeds': SEEDS,
        'per_seed': {str(s): {str(k): round(v, 4) for k, v in m.items()}
                     for s, m in all_seed_maes.items()},
        'single_seed_avg': round(single_mean, 4),
        'single_seed_std': round(single_std, 4),
        'ensemble_maes': {str(k): round(v, 4) for k, v in ens_maes.items()},
        'ensemble_avg': round(ens_mean, 4),
        'ensemble_std': round(ens_std, 4),
        'ensemble_improvement': f'{ens_drop:.1f}%',
        'total_time_min': round(tt/60, 1),
    }
    with open('jdft2d_summary_v4.json', 'w') as f:
        json.dump(summary, f, indent=2)
    print(" Saved: jdft2d_summary_v4.json")

    # Zip models
    shutil.make_archive(model_dir, 'zip', '.', model_dir)
    print(f" Saved: {model_dir}.zip (download this!)")


if __name__ == '__main__':
    run_benchmark()
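The ensemble step in `run_benchmark` simply averages the five seeds' prediction vectors per fold before scoring. A small numpy sketch of why that tends to lower MAE (synthetic targets and noise, not benchmark data):

```python
import numpy as np

rng = np.random.default_rng(0)
y = rng.normal(size=1000)  # synthetic targets

# Five "seeds": same underlying signal, independent noise per seed
preds = [y + rng.normal(scale=0.5, size=y.size) for _ in range(5)]

single_maes = [np.abs(p - y).mean() for p in preds]
ens_mae = np.abs(np.stack(preds).mean(axis=0) - y).mean()

# Averaging independent errors shrinks them (~1/sqrt(5) for i.i.d. noise)
print(ens_mae < np.mean(single_maes))  # True
```

In the real run the per-seed errors are correlated (same folds, same features), so the observed gain is smaller than the i.i.d. 1/sqrt(5) bound, but the mechanism is the same.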
model_code/phonons_dataset_builder.py
ADDED
|
@@ -0,0 +1,749 @@
"""
+=============================================================+
|   V6 Physics-Featurized Phonon Dataset Builder              |
|   Architecture-Agnostic | Rich Physics | 3-Order Graphs     |
|                                                             |
|   Features per atom: 18d (element physics + coords + local) |
|   Features per bond: 8d physics + 40d RBF + 3d direction    |
|   Order 2 (angles): 8d angle RBF                            |
|   Order 3 (dihedrals): 8d dihedral RBF                      |
|   Composition: MAGPIE + mat2vec + matminer extras           |
|   Global physics: Debye temp, force constants, etc.         |
|                                                             |
|   ⚠ NO SCALING — raw features. Scale at training time only. |
+=============================================================+

DEPENDENCIES:
    pip install matminer pymatgen gensim tqdm scikit-learn torch numpy

USAGE:
    python build_phonons_v6_dataset.py
    -> Outputs: phonons_v6_dataset.pt
"""

import os, time, math, warnings, urllib.request, logging
from collections import defaultdict
warnings.filterwarnings('ignore')

import numpy as np
import torch
from tqdm import tqdm
from sklearn.model_selection import KFold

logging.basicConfig(level=logging.INFO, format='%(name)s | %(message)s')
log = logging.getLogger("V6-BUILD")

# ═══════════════════════════════════════════════════════════════
# CONFIGURATION
# ═══════════════════════════════════════════════════════════════

CUTOFF = 8.0
MAX_NEIGHBORS = 12
N_RBF_DIST = 40
N_RBF_ANGLE = 8
N_RBF_DIHEDRAL = 8
MAX_QUADS = 50000        # cap dihedrals per crystal for memory
FOLD_SEED = 18012019     # matbench v0.1 protocol
N_FOLDS = 5

N_ELEM_FEAT = 12         # from lookup table
N_ATOM_COMPUTED = 6      # frac_coords(3) + coord_num(1) + avg_nn(1) + std_nn(1)
N_ATOM_FEAT = N_ELEM_FEAT + N_ATOM_COMPUTED  # 18
N_BOND_PHYSICS = 8
N_GLOBAL_PHYS = 15

# ═══════════════════════════════════════════════════════════════
# GAUSSIAN RADIAL BASIS FUNCTIONS
# ═══════════════════════════════════════════════════════════════

def gaussian_rbf(values, n_bins, vmin, vmax):
    """Fixed Gaussian expansion. No learnable parameters."""
    centers = torch.linspace(vmin, vmax, n_bins)
    gamma = 1.0 / ((vmax - vmin) / n_bins) ** 2
    return torch.exp(-gamma * (values.unsqueeze(-1) - centers.unsqueeze(0)) ** 2)
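`gaussian_rbf` expands each scalar into a soft one-hot vector over fixed bins: each output dimension peaks when the input sits on that bin's center. The same expansion in plain numpy (a sketch for illustration; the builder itself uses the torch version above):

```python
import numpy as np

def gaussian_rbf_np(values, n_bins, vmin, vmax):
    # Fixed Gaussian expansion, mirroring the torch version:
    # one Gaussian bump per evenly spaced center in [vmin, vmax]
    centers = np.linspace(vmin, vmax, n_bins)
    gamma = 1.0 / ((vmax - vmin) / n_bins) ** 2
    return np.exp(-gamma * (values[:, None] - centers[None, :]) ** 2)

d = np.array([0.0, 4.0, 8.0])  # bond lengths in Å
rbf = gaussian_rbf_np(d, n_bins=40, vmin=0.0, vmax=8.0)
print(rbf.shape)  # (3, 40)
# Each row attains its maximum at the bin center nearest its distance
print(rbf[0].argmax(), rbf[2].argmax())  # 0 39
```

This turns a raw distance into a smooth, differentiable encoding a network can weight per-bin, without any learnable parameters in the featurizer.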

# ═══════════════════════════════════════════════════════════════
# ELEMENT PHYSICS LOOKUP TABLE
# ═══════════════════════════════════════════════════════════════

def build_element_table():
    """
    Build [103, 12] lookup table of per-element physical properties.
    Z=0 is padding. Uses pymatgen Element data.

    Columns: mass, 1/sqrt(mass), electronegativity, atomic_radius,
             covalent_radius, ionization_energy, electron_affinity,
             valence_electrons, group, period, block, is_metal
    """
    from pymatgen.core.periodic_table import Element

    block_map = {'s': 0., 'p': 1., 'd': 2., 'f': 3.}
    table = torch.zeros(103, N_ELEM_FEAT)

    for z in range(1, 103):
        try:
            el = Element.from_Z(z)
            mass = float(el.atomic_mass) if el.atomic_mass else 1.0
            chi = float(el.X) if el.X is not None else 0.0
            ar = float(el.atomic_radius) if el.atomic_radius is not None else 1.5
            # Covalent radius proxy
            try:
                cr = float(el.average_ionic_radius) if el.average_ionic_radius and float(el.average_ionic_radius) > 0 else ar
            except Exception:
                cr = ar
            # First ionization energy
            ie = 0.0
            try:
                ies = el.ionization_energies
                if isinstance(ies, dict) and 1 in ies and ies[1] is not None:
                    ie = float(ies[1])
                elif isinstance(ies, (list, tuple)) and len(ies) > 1 and ies[1] is not None:
                    ie = float(ies[1])
            except Exception:
                pass
            # Electron affinity
            ea = 0.0
            try:
                if el.electron_affinity is not None:
                    ea = float(el.electron_affinity)
            except Exception:
                pass
            # Group, period, valence electrons
            g = int(el.group) if el.group is not None else 0
            p = int(el.row) if el.row is not None else 0
            ve = g if g <= 2 else (g - 10 if g >= 13 else 2)
            bl = block_map.get(el.block, 0.) if hasattr(el, 'block') and el.block else 0.
            im = 1.0 if el.is_metal else 0.0

            table[z] = torch.tensor([
                mass, 1.0 / math.sqrt(max(mass, 0.01)), chi, ar, cr,
                ie, ea, float(ve), float(g), float(p), bl, im
            ])
        except Exception:
            table[z] = torch.tensor([1., 1., 0., 1.5, 1.5, 0., 0., 0., 0., 0., 0., 0.])

    return table

# ═══════════════════════════════════════════════════════════════
# CRYSTAL GRAPH BUILDER (Orders 1, 2, 3)
# ═══════════════════════════════════════════════════════════════

def _empty_graph(atom_z, atom_features, n_atoms):
    """Fallback for crystals with no neighbors found."""
    return {
        'atom_z': atom_z,
        'atom_features': atom_features,
        'n_atoms': n_atoms,
        'edge_index': torch.zeros(2, 1, dtype=torch.long),
        'edge_dist': torch.zeros(1),
        'edge_rbf': torch.zeros(1, N_RBF_DIST),
        'edge_vec': torch.zeros(1, 3),
        'edge_physics': torch.zeros(1, N_BOND_PHYSICS),
        'n_edges': 1,
        'triplet_index': torch.zeros(2, 0, dtype=torch.long),
        'angle_rbf': torch.zeros(0, N_RBF_ANGLE),
        'n_triplets': 0,
        'quad_index': torch.zeros(2, 0, dtype=torch.long),
        'dihedral_rbf': torch.zeros(0, N_RBF_DIHEDRAL),
        'n_quads': 0,
    }


def build_crystal_graph(structure, elem_table):
    """
    Build a complete 3-order crystal graph for a single structure.

    Returns dict with atom features, edge features + physics,
    triplets (angles), and quads (dihedrals).

    ✅ ZERO DATA LEAKAGE: uses ONLY this structure's geometry.
    """
    n_atoms = len(structure)
    atom_z = torch.tensor([site.specie.Z for site in structure], dtype=torch.long)

    # Element lookup features [N, 12]
    atom_elem_feat = elem_table[atom_z.clamp(0, 102)]

    # Fractional coordinates [N, 3]
    frac_coords = torch.tensor(
        [site.frac_coords for site in structure], dtype=torch.float32
    )

    # ── NEIGHBOR FINDING ──────────────────────────────────────
    src_list, dst_list, dist_list, vec_list = [], [], [], []
    nn_dists_per_atom = defaultdict(list)

    try:
        all_nbrs = structure.get_all_neighbors(CUTOFF)
        for i, nbrs in enumerate(all_nbrs):
            nbrs_sorted = sorted(nbrs, key=lambda x: x.nn_distance)[:MAX_NEIGHBORS]
            for nbr in nbrs_sorted:
                src_list.append(i)
                dst_list.append(nbr.index)
                dist_list.append(nbr.nn_distance)
                vec_list.append(nbr.coords - structure[i].coords)
                nn_dists_per_atom[i].append(nbr.nn_distance)
    except Exception as e:
        log.warning(f"  Neighbor finding failed: {e}")

    # Per-atom coordination stats
    coord_nums = torch.zeros(n_atoms)
    avg_nn_dists = torch.zeros(n_atoms)
    std_nn_dists = torch.zeros(n_atoms)
    for i in range(n_atoms):
        ds = nn_dists_per_atom.get(i, [])
        coord_nums[i] = len(ds)
        if ds:
            avg_nn_dists[i] = np.mean(ds)
            std_nn_dists[i] = np.std(ds) if len(ds) > 1 else 0.0

    # Combined atom features [N, 18]
    atom_features = torch.cat([
        atom_elem_feat,               # [N, 12]
        frac_coords,                  # [N, 3]
        coord_nums.unsqueeze(-1),     # [N, 1]
        avg_nn_dists.unsqueeze(-1),   # [N, 1]
        std_nn_dists.unsqueeze(-1),   # [N, 1]
    ], dim=-1)                        # [N, 18]

    if len(src_list) == 0:
        return _empty_graph(atom_z, atom_features, n_atoms)

    # ── EDGE FEATURES (Order 1) ───────────────────────────────
    edge_index = torch.tensor([src_list, dst_list], dtype=torch.long)
    edge_dist = torch.tensor(dist_list, dtype=torch.float32)
    raw_vecs = torch.tensor(np.array(vec_list), dtype=torch.float32)
    n_edges = edge_index.shape[1]

    edge_rbf = gaussian_rbf(edge_dist, N_RBF_DIST, 0.0, CUTOFF)
    norms = raw_vecs.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    edge_vec = raw_vecs / norms

    # ── BOND PHYSICS FEATURES [E, 8] ─────────────────────────
    z_src = atom_z[edge_index[0]]  # [E]
    z_dst = atom_z[edge_index[1]]  # [E]

    m_src = elem_table[z_src.clamp(0, 102), 0]    # mass
    m_dst = elem_table[z_dst.clamp(0, 102), 0]
    chi_src = elem_table[z_src.clamp(0, 102), 2]  # electronegativity
    chi_dst = elem_table[z_dst.clamp(0, 102), 2]
    r_src = elem_table[z_src.clamp(0, 102), 3]    # atomic radius
    r_dst = elem_table[z_dst.clamp(0, 102), 3]

    d = edge_dist.clamp(min=0.01)

    # Vectorized bond physics computation
    chi_prod = (chi_src * chi_dst).clamp(min=0.01)
    k_est = torch.sqrt(chi_prod) / (d * d)                   # force constant
    mu = (m_src * m_dst) / (m_src + m_dst).clamp(min=0.01)   # reduced mass
    omega = torch.sqrt(k_est / mu.clamp(min=0.01))           # Einstein freq
    delta_chi = (chi_src - chi_dst).abs()                    # EN difference
    ionicity = delta_chi * delta_chi                         # bond ionicity
    r_ratio = (r_src + r_dst) / d                            # radius sum ratio
    m_ratio = torch.min(m_src, m_dst) / torch.max(m_src, m_dst).clamp(min=0.01)
    inv_d = 1.0 / d                                          # inverse distance

    edge_physics = torch.stack([
        k_est, mu, omega, delta_chi, ionicity, r_ratio, m_ratio, inv_d
    ], dim=-1)  # [E, 8]

    # ── TRIPLETS / ANGLES (Order 2) ───────────────────────────
    dst_np = edge_index[1].numpy()
    dest_to_edges = defaultdict(list)
    for e_idx in range(n_edges):
        dest_to_edges[int(dst_np[e_idx])].append(e_idx)

    trip_ij, trip_kj = [], []
    for j, edge_list in dest_to_edges.items():
        for idx_ij in edge_list:
            for idx_kj in edge_list:
                if idx_ij != idx_kj:
                    trip_ij.append(idx_ij)
                    trip_kj.append(idx_kj)

    if trip_ij:
        triplet_index = torch.tensor([trip_ij, trip_kj], dtype=torch.long)
        v_ij = edge_vec[triplet_index[0]]
        v_kj = edge_vec[triplet_index[1]]
        cos_theta = (v_ij * v_kj).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7)
        angles = torch.acos(cos_theta)
        angle_rbf_t = gaussian_rbf(angles, N_RBF_ANGLE, 0.0, math.pi)
        n_triplets = triplet_index.shape[1]
    else:
        triplet_index = torch.zeros(2, 0, dtype=torch.long)
        angle_rbf_t = torch.zeros(0, N_RBF_ANGLE)
        n_triplets = 0

    # ── QUADS / DIHEDRALS (Order 3) ───────────────────────────
    quad_index, dihedral_rbf_t, n_quads = _compute_quads(
        triplet_index, n_triplets, edge_vec, trip_ij, trip_kj
    )

    return {
        'atom_z': atom_z,
        'atom_features': atom_features,
        'n_atoms': n_atoms,
        'edge_index': edge_index,
        'edge_dist': edge_dist,
        'edge_rbf': edge_rbf,
        'edge_vec': edge_vec,
        'edge_physics': edge_physics,
        'n_edges': n_edges,
        'triplet_index': triplet_index,
        'angle_rbf': angle_rbf_t,
        'n_triplets': n_triplets,
        'quad_index': quad_index,
        'dihedral_rbf': dihedral_rbf_t,
        'n_quads': n_quads,
    }
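The bond-physics block in `build_crystal_graph` uses textbook diatomic relations: reduced mass μ = m₁m₂/(m₁+m₂) and Einstein frequency ω = sqrt(k/μ), with the force constant k estimated from electronegativities and distance. The same arithmetic for a single bond in plain python (masses and χ values are approximate illustrative inputs):

```python
import math

def bond_physics(m1, m2, chi1, chi2, d):
    # Scalar version of the vectorized torch computation above
    k_est = math.sqrt(max(chi1 * chi2, 0.01)) / (d * d)  # force-constant proxy
    mu = (m1 * m2) / max(m1 + m2, 0.01)                  # reduced mass
    omega = math.sqrt(k_est / max(mu, 0.01))             # Einstein frequency
    return k_est, mu, omega

# H (m≈1.008, χ≈2.20) bonded to O (m≈16.00, χ≈3.44) at d≈0.96 Å
k, mu, om = bond_physics(1.008, 16.00, 2.20, 3.44, 0.96)
print(round(mu, 3))  # 0.948 — reduced mass is dominated by the lighter atom
```

Because μ is bounded by the lighter mass, hydride-containing bonds get high Einstein frequencies, which is the physical signal these features carry into the phonon model.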
def _compute_quads(triplet_index, n_triplets, edge_vec, trip_ij, trip_kj):
    """Compute Order 3: pairs of triplets sharing a bond (dihedrals)."""
    if n_triplets == 0:
        return (torch.zeros(2, 0, dtype=torch.long),
                torch.zeros(0, N_RBF_DIHEDRAL), 0)

    # For each edge, which triplets reference it?
    edge_to_trips = defaultdict(list)
    for t_idx in range(n_triplets):
        edge_to_trips[trip_ij[t_idx]].append(t_idx)
        edge_to_trips[trip_kj[t_idx]].append(t_idx)

    quad_src, quad_dst = [], []
    for edge_idx, tlist in edge_to_trips.items():
        for i in range(len(tlist)):
            for j in range(len(tlist)):
                if tlist[i] != tlist[j]:
                    quad_src.append(tlist[i])
                    quad_dst.append(tlist[j])
                if len(quad_src) >= MAX_QUADS:
                    break
            if len(quad_src) >= MAX_QUADS:
                break
        if len(quad_src) >= MAX_QUADS:
            break

    if not quad_src:
        return (torch.zeros(2, 0, dtype=torch.long),
                torch.zeros(0, N_RBF_DIHEDRAL), 0)

    quad_index = torch.tensor([quad_src, quad_dst], dtype=torch.long)

    # Dihedral angle = angle between planes of the two triplets
    v_a1 = edge_vec[triplet_index[0, quad_index[0]]]
    v_a2 = edge_vec[triplet_index[1, quad_index[0]]]
    v_b1 = edge_vec[triplet_index[0, quad_index[1]]]
    v_b2 = edge_vec[triplet_index[1, quad_index[1]]]

    n_a = torch.cross(v_a1, v_a2, dim=-1)
    n_b = torch.cross(v_b1, v_b2, dim=-1)
    n_a = n_a / n_a.norm(dim=-1, keepdim=True).clamp(min=1e-8)
    n_b = n_b / n_b.norm(dim=-1, keepdim=True).clamp(min=1e-8)

    cos_dih = (n_a * n_b).sum(-1).clamp(-1 + 1e-7, 1 - 1e-7)
    dihedrals = torch.acos(cos_dih)
    dihedral_rbf_t = gaussian_rbf(dihedrals, N_RBF_DIHEDRAL, 0.0, math.pi)

    return quad_index, dihedral_rbf_t, quad_index.shape[1]
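The dihedral in `_compute_quads` is the angle between the two triplet planes, obtained from normalized cross-product normals. A quick numpy check of that formula with two known perpendicular planes (a standalone sketch, not repo code):

```python
import numpy as np

def dihedral(v_a1, v_a2, v_b1, v_b2):
    # Angle between the planes spanned by (v_a1, v_a2) and (v_b1, v_b2),
    # mirroring the normal-vector formula in _compute_quads
    n_a = np.cross(v_a1, v_a2)
    n_b = np.cross(v_b1, v_b2)
    n_a = n_a / max(np.linalg.norm(n_a), 1e-8)
    n_b = n_b / max(np.linalg.norm(n_b), 1e-8)
    cos_d = np.clip(np.dot(n_a, n_b), -1 + 1e-7, 1 - 1e-7)
    return np.arccos(cos_d)

x, y, z = np.eye(3)
# xy-plane vs xz-plane: normals along z and -y, so the angle is 90°
print(np.degrees(dihedral(x, y, x, z)))  # 90.0
```

The same 1e-7 clamp used in the builder keeps `arccos` finite-gradient at exactly parallel or antiparallel planes, at the cost of a sub-0.1° bias at the extremes.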
| 354 |
+
# ═══════════════════════════════════════════════════════════════
|
| 355 |
+
# GLOBAL PHYSICS FEATURES (per crystal)
|
| 356 |
+
# ═══════════════════════════════════════════════════════════════
|
| 357 |
+
|
| 358 |
+
def compute_global_physics(graph, structure, elem_table):
|
| 359 |
+
"""
|
| 360 |
+
Compute 15 global physics features from a crystal graph.
|
| 361 |
+
|
| 362 |
+
Features:
|
| 363 |
+
0: avg_force_constant 7: avg_coordination
|
| 364 |
+
1: std_force_constant 8: density
|
| 365 |
+
2: avg_reduced_mass 9: volume_per_atom
|
| 366 |
+
3: mass_variance 10: packing_fraction
|
| 367 |
+
4: avg_einstein_freq 11: avg_bond_length
|
| 368 |
+
5: electronegativity_var 12: std_bond_length
|
| 369 |
+
6: debye_temp_estimate 13: max_atomic_mass
|
| 370 |
+
14: min_atomic_mass
|
| 371 |
+
"""
    ep = graph['edge_physics']  # [E, 8]
    n_atoms = graph['n_atoms']
    atom_z = graph['atom_z']

    # From bond physics
    k_vals = ep[:, 0]      # force constants
    mu_vals = ep[:, 1]     # reduced masses
    omega_vals = ep[:, 2]  # Einstein frequencies
    dists = graph['edge_dist']

    feats = torch.zeros(N_GLOBAL_PHYS)

    if graph['n_edges'] > 0 and dists.shape[0] > 0:
        feats[0] = k_vals.mean()
        feats[1] = k_vals.std() if k_vals.shape[0] > 1 else 0.0
        feats[2] = mu_vals.mean()
        feats[4] = omega_vals.mean()
        feats[11] = dists.mean()
        feats[12] = dists.std() if dists.shape[0] > 1 else 0.0

    # Mass statistics
    masses = elem_table[atom_z.clamp(0, 102), 0]
    feats[3] = masses.var() if n_atoms > 1 else 0.0
    feats[13] = masses.max()
    feats[14] = masses.min()

    # Electronegativity variance
    chis = elem_table[atom_z.clamp(0, 102), 2]
    feats[5] = chis.var() if n_atoms > 1 else 0.0

    # Debye temperature estimate: Θ_D ∝ sqrt(k_avg / m_avg)
    m_avg = masses.mean()
    k_avg = feats[0]
    feats[6] = math.sqrt(float(k_avg / max(m_avg, 0.01)))

    # Coordination
    feats[7] = graph['atom_features'][:, N_ELEM_FEAT + 3].mean()  # coord_num column

    # Structural
    try:
        feats[8] = structure.density
        feats[9] = structure.volume / max(n_atoms, 1)
        # Packing fraction
        total_vol = sum(
            (4 / 3) * math.pi * (float(site.specie.atomic_radius) ** 3)
            for site in structure
            if hasattr(site.specie, 'atomic_radius') and site.specie.atomic_radius is not None
        )
        feats[10] = total_vol / structure.volume if structure.volume > 0 else 0.0
    except Exception:
        pass

    return feats
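The Debye-temperature proxy computed into `feats[6]` reduces to Θ_D ∝ sqrt(k_avg / m_avg), with the average mass clamped away from zero. A stdlib-only sketch of that scaling (the numeric inputs here are hypothetical, chosen only to show the behavior):

```python
import math

def debye_estimate(k_avg: float, m_avg: float) -> float:
    # Proxy used above: Theta_D ~ sqrt(k / m), with m clamped to 0.01
    # to avoid division by zero for pathological near-zero masses.
    return math.sqrt(k_avg / max(m_avg, 0.01))

# Stiffer bonds raise the estimate; heavier atoms lower it.
light = debye_estimate(k_avg=100.0, m_avg=12.0)   # carbon-like mass (hypothetical)
heavy = debye_estimate(k_avg=100.0, m_avg=200.0)  # heavy-metal-like mass (hypothetical)
```

This captures the qualitative trend the feature encodes: light, stiff lattices (e.g. diamond) have high Debye temperatures; heavy, soft ones low.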


# ═══════════════════════════════════════════════════════════════
# STRUCTURAL FEATURES (per crystal)
# ═══════════════════════════════════════════════════════════════

def compute_structural_features(structure):
    """
    Compute 11 structural features: lattice parameters + symmetry.
    Same as previous versions for backward compatibility.
    """
    from pymatgen.symmetry.analyzer import SpacegroupAnalyzer

    feats = np.zeros(11, dtype=np.float32)
    try:
        lat = structure.lattice
        feats[0:6] = [lat.a, lat.b, lat.c, lat.alpha, lat.beta, lat.gamma]
        feats[6] = structure.volume / max(len(structure), 1)
        feats[7] = structure.density
        feats[8] = float(len(structure))
        try:
            sga = SpacegroupAnalyzer(structure, symprec=0.1)
            feats[9] = float(sga.get_space_group_number())
        except Exception:
            feats[9] = 0.0
        try:
            total_vol = sum(
                (4 / 3) * np.pi * site.specie.atomic_radius ** 3
                for site in structure
                if hasattr(site.specie, 'atomic_radius') and site.specie.atomic_radius is not None
            )
            feats[10] = total_vol / structure.volume if structure.volume > 0 else 0.0
        except Exception:
            feats[10] = 0.0
    except Exception:
        pass
    return feats


# ═══════════════════════════════════════════════════════════════
# COMPOSITION FEATURIZER (MAGPIE + mat2vec + matminer extras)
# ═══════════════════════════════════════════════════════════════

class CompositionFeaturizer:
    """
    Builds rich composition features per crystal:
    - MAGPIE elemental properties (132d: 22 props × 6 stats)
    - Extra matminer (Stoichiometry, ValenceOrbital, IonProperty, TMetalFraction)
    - Structural features (11d)
    - mat2vec embeddings (200d)

    ✅ ALL features are deterministic per-sample. No cross-sample info.
    """
    M2V_URL = "https://storage.googleapis.com/mat2vec/"
    M2V_FILES = [
        "pretrained_embeddings",
        "pretrained_embeddings.wv.vectors.npy",
        "pretrained_embeddings.trainables.syn1neg.npy",
    ]

    def __init__(self, cache="mat2vec_cache"):
        from matminer.featurizers.composition import (
            ElementProperty, Stoichiometry, ValenceOrbital, IonProperty
        )
        from matminer.featurizers.composition.element import TMetalFraction
        from gensim.models import Word2Vec

        self.ep_magpie = ElementProperty.from_preset("magpie")
        self.n_magpie = len(self.ep_magpie.feature_labels())

        self.extra_ftzrs = [
            ("Stoichiometry", Stoichiometry()),
            ("ValenceOrbital", ValenceOrbital()),
            ("IonProperty", IonProperty()),
            ("TMetalFraction", TMetalFraction()),
        ]
        self._extra_sizes = {}
        for name, ft in self.extra_ftzrs:
            try:
                self._extra_sizes[name] = len(ft.feature_labels())
            except Exception:
                self._extra_sizes[name] = None

        # Download mat2vec
        os.makedirs(cache, exist_ok=True)
        for f in self.M2V_FILES:
            p = os.path.join(cache, f)
            if not os.path.exists(p):
                log.info(f"  Downloading mat2vec: {f}...")
                urllib.request.urlretrieve(self.M2V_URL + f, p)
        m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
        self.emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key}

        self.n_extra = None  # determined on first call

    def _pool_m2v(self, comp):
        v, t = np.zeros(200, np.float32), 0.0
        for s, f in comp.get_el_amt_dict().items():
            if s in self.emb:
                v += f * self.emb[s]
                t += f
        return v / max(t, 1e-8)
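`_pool_m2v` is a stoichiometry-weighted mean of per-element embedding vectors, skipping elements missing from the vocabulary. A stdlib sketch of the same pooling with toy 3-d embeddings (the embedding values here are hypothetical, not real mat2vec vectors):

```python
def pool_embeddings(el_amt, emb, dim=3):
    # Weighted mean of element embeddings; weights are element amounts.
    v = [0.0] * dim
    t = 0.0
    for sym, amt in el_amt.items():
        if sym in emb:  # unknown elements are skipped, as in _pool_m2v
            v = [vi + amt * ei for vi, ei in zip(v, emb[sym])]
            t += amt
    return [vi / max(t, 1e-8) for vi in v]

toy_emb = {'Si': [1.0, 0.0, 0.0], 'O': [0.0, 1.0, 0.0]}
pooled = pool_embeddings({'Si': 1, 'O': 2}, toy_emb)  # SiO2: (1*Si + 2*O) / 3
```

The `max(t, 1e-8)` guard mirrors the original: a composition with no known elements pools to (near-)zero instead of dividing by zero.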

    def _featurize_extras(self, comp):
        parts = []
        for name, ft in self.extra_ftzrs:
            try:
                vals = np.array(ft.featurize(comp), np.float32)
                parts.append(np.nan_to_num(vals, nan=0.0))
                if self._extra_sizes.get(name) is None:
                    self._extra_sizes[name] = len(vals)
            except Exception:
                sz = self._extra_sizes.get(name, 0) or 1
                parts.append(np.zeros(sz, np.float32))
        return np.concatenate(parts)

    def featurize_all(self, compositions, structures):
        """Return [N, D_comp] array of all composition features."""
        # Determine dimensions from first sample
        test_extras = self._featurize_extras(compositions[0])
        self.n_extra = len(test_extras)
        struct_feats_dim = 11
        total_dim = self.n_magpie + self.n_extra + struct_feats_dim + 200

        log.info(f"  Composition features: {self.n_magpie} MAGPIE + "
                 f"{self.n_extra} Extras + 11 Structural + 200 mat2vec = {total_dim}d")

        out = []
        for i, comp in enumerate(tqdm(compositions, desc="  Featurizing compositions", leave=False)):
            # MAGPIE
            try:
                mg = np.array(self.ep_magpie.featurize(comp), np.float32)
            except Exception:
                mg = np.zeros(self.n_magpie, np.float32)
            mg = np.nan_to_num(mg, nan=0.0)

            # Extra matminer
            ex = self._featurize_extras(comp)

            # Structural
            sf = compute_structural_features(structures[i])

            # mat2vec
            m2v = self._pool_m2v(comp)

            out.append(np.concatenate([mg, ex, sf, m2v]))

        return np.array(out, dtype=np.float32)


# ═══════════════════════════════════════════════════════════════
# MAIN — BUILD AND SAVE
# ═══════════════════════════════════════════════════════════════

def main():
    t0 = time.time()
    print("""
    +==========================================================+
    |     V6 Physics-Featurized Phonon Dataset Builder         |
    |  3-Order Graphs | Bond Physics | Architecture-Agnostic   |
    |            ⚠ NO SCALING — raw features only              |
    +==========================================================+
    """)

    # ── LOAD MATBENCH DATA ────────────────────────────────────
    print("  Loading matbench_phonons...")
    from matminer.datasets import load_dataset
    df = load_dataset("matbench_phonons")
    targets = np.array(df['last phdos peak'].tolist(), np.float32)
    structures = df['structure'].tolist()
    compositions = [s.composition for s in structures]
    N = len(structures)
    print(f"  Loaded: {N} samples")
    print(f"  Target range: {targets.min():.1f} – {targets.max():.1f} cm⁻¹")

    # ── BUILD ELEMENT TABLE ───────────────────────────────────
    print("\n  Building element physics table...")
    elem_table = build_element_table()
    print(f"  Element table: {elem_table.shape} (Z=0..102, {N_ELEM_FEAT} features)")

    # ── BUILD CRYSTAL GRAPHS ─────────────────────────────────
    print(f"\n  Building 3-order crystal graphs ({MAX_NEIGHBORS}-NN, cutoff={CUTOFF}Å)...")
    graphs = []
    global_physics_list = []

    for i, struct in enumerate(tqdm(structures, desc="  Building graphs")):
        g = build_crystal_graph(struct, elem_table)
        gp = compute_global_physics(g, struct, elem_table)
        graphs.append(g)
        global_physics_list.append(gp)

    # Stats
    n_atoms_list = [g['n_atoms'] for g in graphs]
    n_edges_list = [g['n_edges'] for g in graphs]
    n_trips_list = [g['n_triplets'] for g in graphs]
    n_quads_list = [g['n_quads'] for g in graphs]
    print("  Graphs built:")
    print(f"    Atoms/crystal:    min={min(n_atoms_list)}, max={max(n_atoms_list)}, "
          f"mean={np.mean(n_atoms_list):.1f}")
    print(f"    Edges/crystal:    min={min(n_edges_list)}, max={max(n_edges_list)}, "
          f"mean={np.mean(n_edges_list):.1f}")
    print(f"    Triplets/crystal: min={min(n_trips_list)}, max={max(n_trips_list)}, "
          f"mean={np.mean(n_trips_list):.1f}")
    print(f"    Quads/crystal:    min={min(n_quads_list)}, max={max(n_quads_list)}, "
          f"mean={np.mean(n_quads_list):.1f}")

    global_physics = torch.stack(global_physics_list)
    print(f"  Global physics: {global_physics.shape}")

    # ── COMPOSITION FEATURES ─────────────────────────────────
    print("\n  Computing composition features...")
    feat = CompositionFeaturizer()
    comp_features = feat.featurize_all(compositions, structures)
    print(f"  Composition features shape: {comp_features.shape}")

    # ── FOLD INDICES (strict matbench protocol) ──────────────
    print(f"\n  Computing 5-fold split indices (seed={FOLD_SEED})...")
    kf = KFold(N_FOLDS, shuffle=True, random_state=FOLD_SEED)
    fold_indices = [(train_idx.tolist(), test_idx.tolist())
                    for train_idx, test_idx in kf.split(range(N))]

    # Verify zero leakage
    for fi, (tr, te) in enumerate(fold_indices):
        overlap = set(tr) & set(te)
        assert len(overlap) == 0, f"DATA LEAK in fold {fi}: {len(overlap)} shared indices!"
        assert len(tr) + len(te) == N, f"Fold {fi}: missing samples!"
    print("  ✅ All folds verified: ZERO data leakage")

    # ── FEATURE DIMENSION INFO ───────────────────────────────
    n_magpie = feat.n_magpie
    n_extra = feat.n_extra
    feature_info = {
        'atom_features_dim': N_ATOM_FEAT,
        'atom_features_layout': [
            'mass', '1/sqrt_mass', 'electronegativity', 'atomic_radius',
            'covalent_radius', 'ionization_energy', 'electron_affinity',
            'valence_electrons', 'group', 'period', 'block', 'is_metal',
            'frac_x', 'frac_y', 'frac_z',
            'coordination_num', 'avg_nn_dist', 'std_nn_dist',
        ],
        'edge_physics_dim': N_BOND_PHYSICS,
        'edge_physics_layout': [
            'force_constant', 'reduced_mass', 'einstein_freq',
            'en_difference', 'ionicity', 'radius_sum_ratio',
            'mass_ratio', 'inverse_distance',
        ],
        'edge_rbf_dim': N_RBF_DIST,
        'angle_rbf_dim': N_RBF_ANGLE,
        'dihedral_rbf_dim': N_RBF_DIHEDRAL,
        'global_physics_dim': N_GLOBAL_PHYS,
        'global_physics_layout': [
            'avg_force_constant', 'std_force_constant', 'avg_reduced_mass',
            'mass_variance', 'avg_einstein_freq', 'en_variance',
            'debye_temp_estimate', 'avg_coordination', 'density',
            'volume_per_atom', 'packing_fraction', 'avg_bond_length',
            'std_bond_length', 'max_atomic_mass', 'min_atomic_mass',
        ],
        'comp_magpie_range': (0, n_magpie),
        'comp_extras_range': (n_magpie, n_magpie + n_extra),
        'comp_structural_range': (n_magpie + n_extra, n_magpie + n_extra + 11),
        'comp_mat2vec_range': (n_magpie + n_extra + 11, n_magpie + n_extra + 11 + 200),
        'comp_total_dim': comp_features.shape[1],
    }

    # ── SAVE ─────────────────────────────────────────────────
    save_path = "phonons_v6_dataset.pt"
    save_data = {
        # Per-crystal data
        'graphs': graphs,
        'comp_features': torch.tensor(comp_features, dtype=torch.float32),
        'global_physics': global_physics,
        'targets': torch.tensor(targets, dtype=torch.float32),

        # Fold indices
        'fold_indices': fold_indices,
        'fold_seed': FOLD_SEED,

        # Metadata
        'n_samples': N,
        'feature_info': feature_info,
        'element_table': elem_table,
        'config': {
            'cutoff': CUTOFF,
            'max_neighbors': MAX_NEIGHBORS,
            'n_rbf_dist': N_RBF_DIST,
            'n_rbf_angle': N_RBF_ANGLE,
            'n_rbf_dihedral': N_RBF_DIHEDRAL,
            'max_quads': MAX_QUADS,
            'fold_seed': FOLD_SEED,
            'n_folds': N_FOLDS,
        },
    }
    torch.save(save_data, save_path)

    size_mb = os.path.getsize(save_path) / 1e6
    dt = time.time() - t0
    print(f"\n  ✅ Saved: {save_path} ({size_mb:.1f} MB)")
    print(f"  Total time: {dt:.1f}s")

    # ── SUMMARY ──────────────────────────────────────────────
    print(f"""
    ╔══════════════════════════════════════════════════════════╗
    ║                    Dataset Summary                       ║
    ╠══════════════════════════════════════════════════════════╣
    ║  Samples:         {N:>6}                                 ║
    ║  Atom features:   {N_ATOM_FEAT:>6}d (12 elem + 3 coord + 3 local) ║
    ║  Bond RBF:        {N_RBF_DIST:>6}d                       ║
    ║  Bond physics:    {N_BOND_PHYSICS:>6}d (k, μ, ω, Δχ, ...)║
    ║  Angle RBF:       {N_RBF_ANGLE:>6}d                      ║
    ║  Dihedral RBF:    {N_RBF_DIHEDRAL:>6}d                   ║
    ║  Composition:     {comp_features.shape[1]:>6}d (MAGPIE+extras+struct+m2v)║
    ║  Global physics:  {N_GLOBAL_PHYS:>6}d                    ║
    ║  Folds:           {N_FOLDS:>6} (seed={FOLD_SEED})        ║
    ║  File size:       {size_mb:>5.1f} MB                     ║
    ╚══════════════════════════════════════════════════════════╝

    ⚠ Remember: NO scaling applied. Apply StandardScaler at
      training time using ONLY train-fold indices!

    Architecture-agnostic: plug ANY model on top of this dataset.
    """)


if __name__ == '__main__':
    main()
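The fold-verification step in `main()` enforces two invariants per fold: train and test indices are disjoint, and together they cover all N samples. A stdlib sketch of the same check on a toy 2-fold split (the index lists below are hypothetical):

```python
def verify_folds(fold_indices, n):
    # Same invariants as the builder's leakage check:
    # (1) no index in both train and test, (2) train + test == all samples.
    for fi, (tr, te) in enumerate(fold_indices):
        assert not (set(tr) & set(te)), f"DATA LEAK in fold {fi}"
        assert len(tr) + len(te) == n, f"Fold {fi}: missing samples"
    return True

folds = [([0, 1, 2], [3, 4, 5]), ([3, 4, 5], [0, 1, 2])]
ok = verify_folds(folds, 6)
```

A fold like `([0, 1], [1, 2])` would trip the first assertion, which is exactly the leak the builder guards against before saving fold indices.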
model_code/phonons_model.py
ADDED
@@ -0,0 +1,839 @@
"""
+=============================================================+
|  TRIADS V6 — Graph Attention TRM + Gate-Based Halting       |
|                                                             |
|  Single model: Gate-halt (4-16 adaptive cycles)             |
|  d=56, 4 heads, gated residuals, deep supervision           |
|  SWA last 50 ep | 200 epochs                                |
|                                                             |
|  Loads: phonons_v6_dataset.pt                               |
+=============================================================+

DEPENDENCIES (dataset already pre-computed, no matminer needed):
    pip install torch numpy scikit-learn tqdm
    (all pre-installed on Kaggle)

USAGE:
    python phonons_v6.py
"""

import os, copy, json, time, math, warnings, threading
from collections import defaultdict
warnings.filterwarnings('ignore')
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.optim.swa_utils import AveragedModel, SWALR
from sklearn.preprocessing import StandardScaler

# Notebook dashboard (IPython is always available on Kaggle)
try:
    from IPython.display import display, HTML, clear_output
    IN_NOTEBOOK = True
except ImportError:
    IN_NOTEBOOK = False


# ═══════════════════════════════════════════════════════════════
# CONFIG
# ═══════════════════════════════════════════════════════════════

D = 56
N_HEADS = 4
N_WARMUP = 1        # 1 unshared warm-up layer (param budget)
N_ANGLE_RBF = 8
DROPOUT = 0.1
BATCH_SIZE = 64
EPOCHS = 200
SWA_START = 150
LR = 5e-4
WD = 1e-4
SEEDS = [42]

# Gate-halt model
MIN_CYCLES = 4
MAX_CYCLES = 16
GATE_HALT_THR = 0.05    # halt when max gate < this
GATE_SPARSITY = 0.001   # encourage gates to close

BASELINES = {
    'MEGNet': 28.76, 'ALIGNN': 29.34, 'MODNet': 45.39,
    'CrabNet': 47.09, 'TRIADS V4': 56.33, 'TRIADS V3.1': 63.00,
    'TRIADS V1': 71.82, 'Dummy': 323.76,
}


# ═══════════════════════════════════════════════════════════════
# SCATTER
# ═══════════════════════════════════════════════════════════════

def scatter_sum(src, idx, dim_size):
    out = torch.zeros(dim_size, src.shape[-1], dtype=src.dtype, device=src.device)
    out.scatter_add_(0, idx.unsqueeze(-1).expand_as(src), src)
    return out
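`scatter_sum` sums the row vectors of `src` into output rows selected by `idx` (the semantics of `Tensor.scatter_add_` along dim 0). A stdlib reference version, useful as a mental model or a unit-test oracle:

```python
def scatter_sum_ref(src, idx, dim_size):
    # src: list of equal-length rows; idx[i] is the destination row for src[i].
    out = [[0.0] * len(src[0]) for _ in range(dim_size)]
    for row, dest in zip(src, idx):
        for j, v in enumerate(row):
            out[dest][j] += v
    return out

# rows 0 and 2 of src land in out[0]; row 1 lands in out[1]
res = scatter_sum_ref([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [0, 1, 0], 2)
# res == [[6.0, 8.0], [3.0, 4.0]]
```

In the model this is how per-edge messages are pooled onto their destination atoms.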


# ═══════════════════════════════════════════════════════════════
# COLLATION + DATALOADER
# ═══════════════════════════════════════════════════════════════

def collate(graphs, comp, glob_phys, targets, indices, device):
    az, af = [], []
    ei, rb, vc, ph = [], [], [], []
    tr, an = [], []
    ba, na_list = [], []
    a_off, e_off = 0, 0

    for k, i in enumerate(indices):
        g = graphs[i]
        na, ne = g['n_atoms'], g['n_edges']
        az.append(g['atom_z'])
        af.append(g['atom_features'])
        ei.append(g['edge_index'] + a_off)
        rb.append(g['edge_rbf']); vc.append(g['edge_vec']); ph.append(g['edge_physics'])
        tr.append(g['triplet_index'] + e_off)
        an.append(g['angle_rbf'])
        ba.append(torch.full((na,), k, dtype=torch.long))
        na_list.append(na)
        a_off += na; e_off += ne

    return (
        comp[indices].to(device),
        glob_phys[indices].to(device),
        {
            'atom_z': torch.cat(az).to(device),
            'atom_feat': torch.cat(af).to(device),
            'ei': torch.cat(ei, 1).to(device),
            'rbf': torch.cat(rb).to(device),
            'vec': torch.cat(vc).to(device),
            'phys': torch.cat(ph).to(device),
            'triplets': torch.cat(tr, 1).to(device),
            'angle_feat': torch.cat(an).to(device),
            'batch': torch.cat(ba).to(device),
            'n_crystals': len(indices),
            'n_atoms': na_list,
        },
        targets[indices].to(device),
    )


class Loader:
    def __init__(self, graphs, comp, gp, tgt, idx, bs, dev, shuf=False):
        self.g, self.c, self.gp, self.t = graphs, comp, gp, tgt
        self.idx, self.bs, self.dev, self.shuf = np.array(idx), bs, dev, shuf

    def __iter__(self):
        i = self.idx.copy()
        if self.shuf: np.random.shuffle(i)
        self._b = [i[j:j+self.bs] for j in range(0, len(i), self.bs)]
        self._p = 0; return self

    def __next__(self):
        if self._p >= len(self._b): raise StopIteration
        b = self._b[self._p]; self._p += 1
        return collate(self.g, self.c, self.gp, self.t, b, self.dev)

    def __len__(self): return (len(self.idx) + self.bs - 1) // self.bs
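`Loader` slices a (possibly shuffled) index array into chunks of `bs`; the last batch may be short, and `__len__` is the matching ceiling division. A stdlib sketch of just that batching logic:

```python
def make_batches(indices, bs):
    # Same slicing as Loader.__iter__: step through indices in strides of bs.
    return [indices[j:j + bs] for j in range(0, len(indices), bs)]

batches = make_batches(list(range(10)), 4)
n_batches = (10 + 4 - 1) // 4   # ceiling division, same formula as Loader.__len__
# batches == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```

The offset bookkeeping in `collate` (`a_off`, `e_off`) then shifts each crystal's edge and triplet indices so the batched graphs live in one disjoint index space.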


# ═══════════════════════════════════════════════════════════════
# GRAPH MESSAGE PASSING LAYER (Line Graph style)
# ═══════════════════════════════════════════════════════════════

class GraphMPLayer(nn.Module):
    """Bond update (line graph) + Atom update (edge-gated)."""

    def __init__(self, d, n_angle=N_ANGLE_RBF, dropout=DROPOUT):
        super().__init__()
        # Phase 1: Bond update from angular neighbors
        self.bond_msg = nn.Sequential(nn.Linear(d*2 + n_angle, d), nn.SiLU())
        self.bond_gate = nn.Sequential(nn.Linear(d*2 + n_angle, d), nn.Sigmoid())
        self.bond_up = nn.Sequential(nn.Linear(d*2, d), nn.LayerNorm(d), nn.SiLU(), nn.Dropout(dropout))
        # Phase 2: Atom update from bonds
        self.atom_msg = nn.Sequential(nn.Linear(d*3, d), nn.SiLU())
        self.atom_gate = nn.Sequential(nn.Linear(d*3, d), nn.Sigmoid())
        self.atom_up = nn.Sequential(nn.Linear(d*2, d), nn.LayerNorm(d), nn.SiLU(), nn.Dropout(dropout))

    def forward(self, atoms, bonds, ei, triplets, angle_feat):
        # Phase 1: bonds learn from angular neighbors
        if triplets.shape[1] > 0:
            b_ij, b_kj = bonds[triplets[0]], bonds[triplets[1]]
            inp = torch.cat([b_ij, b_kj, angle_feat], -1)
            msg = self.bond_msg(inp) * self.bond_gate(inp)
            agg = torch.zeros(bonds.size(0), bonds.size(1), dtype=torch.float32, device=msg.device)
            agg.scatter_add_(0, triplets[0].unsqueeze(-1).expand_as(msg), msg)
            bonds = bonds + self.bond_up(torch.cat([bonds, agg], -1))
        # Phase 2: atoms aggregate from bonds
        inp = torch.cat([atoms[ei[0]], atoms[ei[1]], bonds], -1)
        msg = self.atom_msg(inp) * self.atom_gate(inp)
        agg = scatter_sum(msg, ei[1], atoms.size(0))
        atoms = atoms + self.atom_up(torch.cat([atoms, agg], -1))
        return atoms, bonds
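Both phases above use multiplicative sigmoid gating, `msg = f(inp) * σ(g(inp))`: the gate maps each message channel into (0, 1), so the network can learn to suppress uninformative messages per channel. A stdlib illustration on one scalar channel (the weights are hypothetical, and the raw-message branch is simplified to a plain linear map rather than the layer's Linear+SiLU):

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def gated_message(inp, w_msg, w_gate):
    # One scalar channel of msg = bond_msg(inp) * bond_gate(inp):
    # the sigmoid gate in (0, 1) attenuates the raw message.
    raw = w_msg * inp
    gate = sigmoid(w_gate * inp)
    return raw * gate

open_m = gated_message(2.0, w_msg=1.0, w_gate=10.0)    # gate ≈ 1 → message passes
closed_m = gated_message(2.0, w_msg=1.0, w_gate=-10.0) # gate ≈ 0 → message suppressed
```

This is the same mechanism the gate-halt loop later relies on: when gates close (magnitudes near 0), the state update contributes almost nothing, which is the halting signal.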


# ═══════════════════════════════════════════════════════════════
# PHONON V6 MODEL
# ═══════════════════════════════════════════════════════════════

class PhononV6(nn.Module):
    """
    Graph Attention TRM for phonon prediction.

    mode='fixed':     Fixed n_cycles TRM cycles (Model 1)
    mode='gate_halt': Gate-based implicit halting (Model 2)
    """

    def __init__(self, comp_dim, global_phys_dim=15, d=D,
                 mode='gate_halt', n_cycles=MAX_CYCLES,
                 min_cycles=MIN_CYCLES, max_cycles=MAX_CYCLES,
                 n_warmup=N_WARMUP, n_heads=N_HEADS, dropout=DROPOUT):
        super().__init__()
        self.d = d
        self.mode = mode
        self.total_cycles = n_cycles if mode == 'fixed' else max_cycles
        self.min_cycles = min_cycles if mode == 'gate_halt' else self.total_cycles

        # Feature layout (from V6 dataset: 132 magpie + extras + 11 struct + 200 m2v)
        self.n_magpie = 132
        self.n_extra = comp_dim - 132 - 11 - 200
        self.n_comp_tokens = 22 + 1 + 1  # 22 magpie + 1 extra + 1 m2v = 24

        # ── Input Encoding ────────────────────────────────────
        self.atom_embed = nn.Embedding(103, d)
        self.atom_feat_proj = nn.Linear(18, d)
        self.rbf_enc = nn.Linear(40, d)
        self.vec_enc = nn.Linear(3, d)
        self.phys_enc = nn.Linear(8, d)

        # ── Composition Token Projections ─────────────────────
        self.magpie_proj = nn.Linear(6, d)
        self.extra_proj = nn.Linear(max(self.n_extra, 1), d)
        self.m2v_proj = nn.Linear(200, d)

        # ── Context (structural + global physics) ─────────────
        self.ctx_proj = nn.Linear(11 + global_phys_dim, d)

        # ── Token Type Embeddings ─────────────────────────────
        self.type_embed = nn.Embedding(2, d)

        # ── Warm-up Layers (unshared) ─────────────────────────
        self.warmup = nn.ModuleList([GraphMPLayer(d, N_ANGLE_RBF, dropout) for _ in range(n_warmup)])
        self.warmup_out = nn.Sequential(nn.Linear(d, d), nn.LayerNorm(d), nn.SiLU())

        # ── Shared TRM Block ──────────────────────────────────
        # Graph MP (shared)
        self.trm_gnn = GraphMPLayer(d, N_ANGLE_RBF, dropout)

        # Self-Attention
        self.sa = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.sa_n = nn.LayerNorm(d)
        self.sa_ff = nn.Sequential(nn.Linear(d, d), nn.GELU(), nn.Dropout(dropout), nn.Linear(d, d))
        self.sa_fn = nn.LayerNorm(d)

        # Cross-Attention
        self.ca = nn.MultiheadAttention(d, n_heads, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d)

        # ── State Update (Gated Residuals) ───────────────────
        self.z_proj = nn.Linear(d*3, d)
        self.z_up = nn.Sequential(nn.Linear(d*2, d), nn.SiLU(), nn.Linear(d, d))
        self.z_gate = nn.Sequential(nn.Linear(d*2, d), nn.Sigmoid())
        self.y_up = nn.Sequential(nn.Linear(d*2, d), nn.SiLU(), nn.Linear(d, d))
        self.y_gate = nn.Sequential(nn.Linear(d*2, d), nn.Sigmoid())

        # ── Output Head ───────────────────────────────────────
        self.head = nn.Sequential(nn.Linear(d, d//2), nn.SiLU(), nn.Linear(d//2, 1))

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def forward(self, comp, glob_phys, g, deep_supervision=False):
        B = g['n_crystals']
        ei = g['ei']
        dev = comp.device

        # ══════════════════════════════════════════════════════
        # INPUT ENCODING
        # ══════════════════════════════════════════════════════

        # Atom features
        atoms = self.atom_embed(g['atom_z'].clamp(0, 102)) + self.atom_feat_proj(g['atom_feat'])

        # Bond features: distance (direction-gated) + physics
        bonds = self.rbf_enc(g['rbf']) * torch.tanh(self.vec_enc(g['vec'])) + self.phys_enc(g['phys'])

        triplets = g['triplets']
        angle_feat = g['angle_feat']

        # ══════════════════════════════════════════════════════
        # WARM-UP (unshared graph layers)
        # ══════════════════════════════════════════════════════

        for layer in self.warmup:
            atoms, bonds = layer(atoms, bonds, ei, triplets, angle_feat)
        atoms = self.warmup_out(atoms)

        # ══════════════════════════════════════════════════════
        # COMPOSITION TOKENS (24 total)
        # ══════════════════════════════════════════════════════

        magpie = comp[:, :132].view(B, 22, 6)
        extras = comp[:, 132:132+self.n_extra]
        s_meta = comp[:, 132+self.n_extra:132+self.n_extra+11]
        m2v = comp[:, -200:]

        mag_tok = self.magpie_proj(magpie)                    # [B, 22, d]
        ext_tok = self.extra_proj(extras).unsqueeze(1)        # [B, 1, d]
        m2v_tok = self.m2v_proj(m2v).unsqueeze(1)             # [B, 1, d]
        comp_tok = torch.cat([mag_tok, ext_tok, m2v_tok], 1)  # [B, 24, d]

        comp_tok = comp_tok + self.type_embed.weight[0]

        # Context vector (structural + global physics)
        ctx = self.ctx_proj(torch.cat([s_meta, glob_phys], -1))  # [B, d]

        # ══════════════════════════════════════════════════════
        # TRM REASONING LOOP
        # ══════════════════════════════════════════════════════

        z = torch.zeros(B, self.d, device=dev)
        y = torch.zeros(B, self.d, device=dev)
        preds = []
        n_atoms = g['n_atoms']
        self._gate_sparsity = 0.  # track gate magnitudes for regularizer
|
| 310 |
+
|
| 311 |
+
for cyc in range(self.total_cycles):
|
| 312 |
+
# ── Phase 1+2: Graph MP (shared weights) ──────────
|
| 313 |
+
atoms, bonds = self.trm_gnn(atoms, bonds, ei, triplets, angle_feat)
|
| 314 |
+
|
| 315 |
+
# ── Pad atoms for attention ─────────────────��─────
|
| 316 |
+
ma = max(n_atoms)
|
| 317 |
+
atom_tok = atoms.new_zeros(B, ma, self.d)
|
| 318 |
+
atom_mask = torch.ones(B, ma, dtype=torch.bool, device=dev)
|
| 319 |
+
off = 0
|
| 320 |
+
for i, na in enumerate(n_atoms):
|
| 321 |
+
atom_tok[i, :na] = atoms[off:off+na]
|
| 322 |
+
atom_mask[i, :na] = False
|
| 323 |
+
off += na
|
| 324 |
+
atom_tok = atom_tok + self.type_embed.weight[1]
|
| 325 |
+
|
| 326 |
+
# ── Phase 3: Joint Self-Attention ─────────────────
|
| 327 |
+
all_tok = torch.cat([comp_tok, atom_tok], 1)
|
| 328 |
+
full_mask = torch.cat([
|
| 329 |
+
torch.zeros(B, self.n_comp_tokens, dtype=torch.bool, device=dev),
|
| 330 |
+
atom_mask
|
| 331 |
+
], 1)
|
| 332 |
+
|
| 333 |
+
sa_out = self.sa(all_tok, all_tok, all_tok, key_padding_mask=full_mask)[0]
|
| 334 |
+
all_tok = self.sa_n(all_tok + sa_out)
|
| 335 |
+
all_tok = self.sa_fn(all_tok + self.sa_ff(all_tok))
|
| 336 |
+
|
| 337 |
+
comp_tok = all_tok[:, :self.n_comp_tokens]
|
| 338 |
+
atom_tok = all_tok[:, self.n_comp_tokens:]
|
| 339 |
+
|
| 340 |
+
# ── Phase 4: Cross-Attention (comp queries atoms) ─
|
| 341 |
+
ca_out = self.ca(comp_tok, atom_tok, atom_tok, key_padding_mask=atom_mask)[0]
|
| 342 |
+
comp_tok = self.ca_n(comp_tok + ca_out)
|
| 343 |
+
|
| 344 |
+
# ── Unpad atoms back to flat ──────────────────────
|
| 345 |
+
parts = [atom_tok[i, :n_atoms[i]] for i in range(B)]
|
| 346 |
+
atoms = torch.cat(parts, 0)
|
| 347 |
+
|
| 348 |
+
# ── Phase 5: State Update (Gated Residuals) ───────
|
| 349 |
+
xp = comp_tok.mean(dim=1) # [B, d]
|
| 350 |
+
|
| 351 |
+
z_inp = self.z_proj(torch.cat([xp, ctx, y], -1))
|
| 352 |
+
z_cand = self.z_up(torch.cat([z_inp, z], -1))
|
| 353 |
+
z_g = self.z_gate(torch.cat([z_inp, z], -1))
|
| 354 |
+
z = z + z_g * z_cand
|
| 355 |
+
|
| 356 |
+
y_cand = self.y_up(torch.cat([y, z], -1))
|
| 357 |
+
y_g = self.y_gate(torch.cat([y, z], -1))
|
| 358 |
+
y = y + y_g * y_cand
|
| 359 |
+
# Track gate sparsity (mean of all gate activations)
|
| 360 |
+
self._gate_sparsity = self._gate_sparsity + (z_g.mean() + y_g.mean()) / 2
|
| 361 |
+
|
| 362 |
+
preds.append(self.head(y).squeeze(-1))
|
| 363 |
+
|
| 364 |
+
# ── Phase 6: Gate-Based Halting ────────────────────
|
| 365 |
+
if self.mode == 'gate_halt' and cyc >= self.min_cycles - 1:
|
| 366 |
+
if y_g.max().item() < GATE_HALT_THR:
|
| 367 |
+
break
|
| 368 |
+
|
| 369 |
+
# Normalize gate sparsity by number of cycles actually run
|
| 370 |
+
n_run = len(preds)
|
| 371 |
+
self._gate_sparsity = self._gate_sparsity / max(n_run, 1)
|
| 372 |
+
|
| 373 |
+
return preds if deep_supervision else preds[-1]
|
| 374 |
+
|
| 375 |
+
def count_parameters(self):
|
| 376 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
| 377 |
+
|
| 378 |
+
|
| 379 |
+
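The Phase 5/6 mechanics above (a sigmoid gate in (0, 1) scaling a candidate residual, with halting once every gate is near zero) can be illustrated on plain arrays. This is a hypothetical numpy sketch of the same idea, not the model's learned layers; `gated_update` and the toy logit values are illustrative only.

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def gated_update(state, cand_logits, gate_logits):
    """One gated residual step: state <- state + g * candidate.

    g in (0, 1) decides how much of the candidate update is applied;
    g ~ 0 leaves that coordinate of the state untouched."""
    cand = np.tanh(cand_logits)
    g = sigmoid(gate_logits)
    return state + g * cand, g

state = np.zeros(4)
new_state, g = gated_update(state,
                            np.array([1.0, -1.0, 2.0, 0.0]),   # candidate logits
                            np.array([-10.0, -10.0, 3.0, 0.0]))  # gate logits
# Coordinates with near-closed gates (logit -10) barely move, while the
# open gate (logit 3) passes most of its candidate through. The model's
# halting rule is the same observation: once max(g) < GATE_HALT_THR,
# further cycles would no longer change the state meaningfully.
```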
# ═══════════════════════════════════════════════════════════════
# LOSS FUNCTIONS
# ═══════════════════════════════════════════════════════════════

def deep_sup_loss(preds_list, targets):
    """Linearly-weighted deep supervision: later cycles get more weight."""
    p = torch.stack(preds_list)
    w = torch.arange(1, p.shape[0]+1, device=p.device, dtype=p.dtype)
    w = w / w.sum()
    return (w * (p - targets.unsqueeze(0)).abs().mean(1)).sum()


def gate_halt_loss(preds_list, targets, gate_sparsity):
    """Deep supervision + gate sparsity to encourage early halting."""
    return deep_sup_loss(preds_list, targets) + GATE_SPARSITY * gate_sparsity

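The weighting in `deep_sup_loss` is easy to state in numpy: with S supervised cycles, cycle i (1-indexed) gets weight i / (1 + 2 + … + S), so the last prediction dominates but every cycle still receives gradient. A standalone mirror of the torch version (same math, numpy types):

```python
import numpy as np

def deep_sup_loss_np(step_preds, targets):
    """Numpy mirror of deep_sup_loss: linearly-weighted per-step MAE."""
    p = np.stack(step_preds)                       # [S, B]
    w = np.arange(1, p.shape[0] + 1, dtype=float)  # 1, 2, ..., S
    w = w / w.sum()                                # weights sum to 1
    per_step_mae = np.abs(p - targets[None, :]).mean(axis=1)
    return float((w * per_step_mae).sum())

targets = np.zeros(2)
preds = [np.full(2, 3.0), np.full(2, 2.0), np.full(2, 1.0)]  # step MAEs 3, 2, 1
loss = deep_sup_loss_np(preds, targets)
# weights are [1/6, 2/6, 3/6] -> loss = (1*3 + 2*2 + 3*1) / 6 = 10/6
```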
# ═══════════════════════════════════════════════════════════════
# STRATIFIED SPLIT (within train fold → train/val)
# ═══════════════════════════════════════════════════════════════

def strat_split(t, vf=0.15, seed=42):
    bins = np.digitize(t, np.percentile(t, [25, 50, 75]))
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(bins == b)[0]
        if len(m) == 0: continue
        n = max(1, int(len(m) * vf))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist())
        tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)

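`strat_split` bins the targets into quartiles with `np.digitize` over the 25/50/75 percentiles and draws ~`vf` of each bin into validation, so every quartile of the target range is represented. The function is repeated here verbatim so the example runs standalone; the toy target vector is illustrative:

```python
import numpy as np

def strat_split(t, vf=0.15, seed=42):
    """Quartile-stratified train/val split (copied from the listing above)."""
    bins = np.digitize(t, np.percentile(t, [25, 50, 75]))  # 4 quartile bins 0..3
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(bins == b)[0]
        if len(m) == 0:
            continue
        n = max(1, int(len(m) * vf))       # ~vf of each bin goes to validation
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist())
        tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)

t = np.arange(100, dtype=float)            # 25 samples per quartile
tr, vl = strat_split(t, vf=0.15, seed=0)
# int(25 * 0.15) = 3 validation samples drawn per bin -> 12 val / 88 train,
# with the two index sets disjoint and covering all 100 samples.
```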
# ═══════════════════════════════════════════════════════════════
# LIVE DASHBOARD (IPython HTML — works in Kaggle/Jupyter)
# ═══════════════════════════════════════════════════════════════

_print_lock = threading.Lock()

# Shared state updated by training threads, read by dashboard
_dash_state = {
    'GH': {'fold': 0, 'ep': 0, 'tr': float('inf'), 'val': float('inf'),
           'best': float('inf'), 'best_ep': 0, 'lr': 0., 'eta_m': 0,
           'ep_s': 0., 'swa': False, 'done': False, 'test_mae': None},
}
_dash_log = []  # Accumulates milestone messages


def _log(msg):
    with _print_lock:
        _dash_log.append(msg)
        if not IN_NOTEBOOK:
            print(msg, flush=True)


def _render_html():
    """Build an HTML table from _dash_state + recent log lines."""
    css = (
        '<style>'
        '.tv6{font-family:monospace;font-size:13px;border-collapse:collapse;width:100%}'
        '.tv6 th{background:#1a1a2e;color:#e94560;padding:6px 10px;text-align:right;border-bottom:2px solid #e94560}'
        '.tv6 td{padding:5px 10px;text-align:right;border-bottom:1px solid #333}'
        '.tv6 tr:nth-child(odd){background:#16213e}'
        '.tv6 tr:nth-child(even){background:#0f3460}'
        '.tv6 td:first-child,.tv6 th:first-child{text-align:left;font-weight:bold;color:#e9c46a}'
        '.tv6 .best{color:#2ecc71;font-weight:bold}'
        '.tv6 .done{color:#2ecc71}'
        '.tv6 .swa{color:#9b59b6}'
        '.tv6 .training{color:#f39c12}'
        '.tv6 .waiting{color:#636e72}'
        '.logbox{font-family:monospace;font-size:12px;color:#dfe6e9;background:#0a0a0a;'
        'padding:8px 12px;margin-top:8px;border-radius:6px;max-height:200px;overflow-y:auto}'
        '</style>'
    )
    rows = ''
    for name, s in _dash_state.items():
        if s['done'] and s['test_mae'] is not None:
            status = f'<span class="done">✅ {s["test_mae"]:.1f}</span>'
        elif s['swa']:
            status = '<span class="swa">SWA</span>'
        elif s['ep'] == 0:
            status = '<span class="waiting">Waiting</span>'
        else:
            status = '<span class="training">▶ Training</span>'
        ep_str = f"{s['ep']}/{EPOCHS}" if s['ep'] else '-'
        tr_str = f"{s['tr']:.1f}" if s['tr'] < 1e6 else '-'
        val_str = f"{s['val']:.1f}" if s['val'] < 1e6 else '-'
        best_str = f'<span class="best">{s["best"]:.1f}@{s["best_ep"]}</span>' if s['best'] < 1e6 else '-'
        lr_str = f"{s['lr']:.0e}" if s['lr'] > 0 else '-'
        eps_str = f"{s['ep_s']:.1f}" if s['ep_s'] > 0 else '-'
        eta_str = f"{s['eta_m']:.0f}m" if s['eta_m'] > 0 else '-'
        fold_str = str(s['fold']) if s['fold'] else '-'
        rows += (f'<tr><td>{name}</td><td>{fold_str}</td><td>{ep_str}</td>'
                 f'<td>{tr_str}</td><td>{val_str}</td><td>{best_str}</td>'
                 f'<td>{lr_str}</td><td>{eps_str}</td><td>{eta_str}</td>'
                 f'<td>{status}</td></tr>')
    table = (
        f'{css}<h3 style="color:#e94560;font-family:monospace;margin:4px 0">⚡ TRIADS V6 — Live Dashboard</h3>'
        f'<table class="tv6"><tr><th>Model</th><th>Fold</th><th>Epoch</th>'
        f'<th>Train MAE</th><th>Val MAE</th><th>Best MAE</th>'
        f'<th>LR</th><th>s/ep</th><th>ETA</th><th>Status</th></tr>{rows}</table>'
    )
    # Show last 8 log messages
    if _dash_log:
        log_html = '<br>'.join(_dash_log[-8:])
        table += f'<div class="logbox">{log_html}</div>'
    return table

class Dashboard:
    """Background thread that re-renders an HTML table every 5s in-place."""
    def __init__(self):
        self._stop = threading.Event()
        self._thread = None

    def start(self):
        if not IN_NOTEBOOK:
            return
        self._stop.clear()
        self._thread = threading.Thread(target=self._run, daemon=True)
        self._thread.start()

    def stop(self):
        if not IN_NOTEBOOK or self._thread is None:
            return
        self._stop.set()
        self._thread.join(timeout=10)
        # Final render
        clear_output(wait=True)
        display(HTML(_render_html()))

    def _run(self):
        while not self._stop.is_set():
            try:
                clear_output(wait=True)
                display(HTML(_render_html()))
            except Exception:
                pass
            self._stop.wait(5)


_dashboard = Dashboard()

def train_fold_core(model, tr_loader, vl_loader, device, fold, seed,
                    model_name, tgt_mean=0., tgt_std=1., log_every=10):
    """
    Train one model on one device, with SWA + structured line logging.
    Returns (best_val_mae, model_with_best_weights).
    """
    opt = torch.optim.AdamW(model.parameters(), lr=LR, weight_decay=WD)
    # Cosine scheduler with 10-epoch linear warmup
    WARMUP_EP = 10
    def lr_lambda(ep):
        if ep < WARMUP_EP: return (ep + 1) / WARMUP_EP
        progress = (ep - WARMUP_EP) / max(1, EPOCHS - WARMUP_EP)
        return 0.5 * (1 + math.cos(math.pi * progress)) * (1 - 1e-5/LR) + 1e-5/LR
    sch = torch.optim.lr_scheduler.LambdaLR(opt, lr_lambda)

    swa_model = AveragedModel(model)
    swa_sch = SWALR(opt, swa_lr=1e-4)

    bv, bw, best_ep = float('inf'), None, 0
    fold_start = time.time()

    for ep in range(EPOCHS):
        ep_start = time.time()
        use_swa = ep >= SWA_START

        # ── TRAIN ─────────────────────────────────────────────
        model.train()
        te, tn = 0., 0
        for cb, gb, g_batch, tb in tr_loader:
            sp = model(cb, gb, g_batch, True)
            if model.mode == 'gate_halt':
                loss = gate_halt_loss(sp, tb, model._gate_sparsity)
            else:
                loss = deep_sup_loss(sp, tb)
            opt.zero_grad(set_to_none=True)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
            opt.step()
            with torch.no_grad():
                te += ((sp[-1] * tgt_std + tgt_mean) - (tb * tgt_std + tgt_mean)).abs().sum().item()
                tn += len(tb)

        if use_swa:
            swa_model.update_parameters(model)
            swa_sch.step()
        else:
            sch.step()

        # ── VALIDATE ──────────────────────────────────────────
        eval_m = swa_model if use_swa and ep == EPOCHS - 1 else model
        eval_m.eval()
        ve, vn = 0., 0
        with torch.inference_mode():
            for cb, gb, g_batch, tb in vl_loader:
                pred = eval_m(cb, gb, g_batch)
                ve += ((pred * tgt_std + tgt_mean) - (tb * tgt_std + tgt_mean)).abs().sum().item()
                vn += len(tb)

        train_mae = te / max(tn, 1)
        val_mae = ve / max(vn, 1)
        ep_time = time.time() - ep_start

        if val_mae < bv:
            bv = val_mae
            bw = copy.deepcopy(model.state_dict())
            best_ep = ep + 1

        # ── UPDATE DASHBOARD STATE (every epoch) ────────────
        lr_now = opt.param_groups[0]['lr']
        eta_m = (EPOCHS - ep - 1) * ep_time / 60
        _dash_state[model_name].update({
            'fold': fold, 'ep': ep + 1,
            'tr': train_mae, 'val': val_mae,
            'best': bv, 'best_ep': best_ep,
            'lr': lr_now, 'ep_s': ep_time,
            'eta_m': eta_m, 'swa': use_swa,
        })

        # ── PLAIN LOG (fallback / milestone prints) ───────────
        if not IN_NOTEBOOK and ((ep + 1) % log_every == 0 or ep == 0 or ep == EPOCHS - 1):
            swa_tag = ' SWA' if use_swa else ''
            _log(f" [{model_name}|F{fold}] ep {ep+1:>3}/{EPOCHS}"
                 f" │ Tr={train_mae:>6.1f} Val={val_mae:>6.1f}"
                 f" Best={bv:>6.1f}@{best_ep:<3}"
                 f" │ lr={lr_now:.0e}{swa_tag}"
                 f" │ {ep_time:.1f}s/ep ETA {eta_m:.0f}m")

    model.load_state_dict(bw)
    total_time = time.time() - fold_start
    _log(f" [{model_name}|F{fold}] ✅ Done in {total_time/60:.1f}m │ Best Val MAE = {bv:.2f} @ epoch {best_ep}")

    return bv, model

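The `lr_lambda` schedule above is a linear warmup to the base LR over 10 epochs, then a cosine decay to a 1e-5 floor. It can be checked in isolation with stdlib math; `LR`, `EPOCHS`, and `MIN_LR` below are illustrative stand-ins for the script's config constants:

```python
import math

LR = 1e-3          # illustrative values, not the script's actual config
EPOCHS = 100
WARMUP_EP = 10
MIN_LR = 1e-5

def lr_lambda(ep):
    """Multiplier applied to LR: linear warmup, then cosine decay to MIN_LR."""
    if ep < WARMUP_EP:
        return (ep + 1) / WARMUP_EP
    progress = (ep - WARMUP_EP) / max(1, EPOCHS - WARMUP_EP)
    return 0.5 * (1 + math.cos(math.pi * progress)) * (1 - MIN_LR / LR) + MIN_LR / LR

# Epoch 0 starts at 10% of LR, warmup reaches the full LR by epoch 9,
# and the cosine tail lands exactly on the MIN_LR floor at the end.
mults = [lr_lambda(ep) for ep in range(EPOCHS + 1)]
```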
def evaluate_model(model, test_loader, device, tgt_mean=0., tgt_std=1.):
    """Evaluate model MAE on test set (returns MAE in original scale)."""
    model.eval()
    ee, en_ = 0., 0
    with torch.inference_mode():
        for cb, gb, g_batch, tb in test_loader:
            pred = model(cb, gb, g_batch) * tgt_std + tgt_mean
            real = tb * tgt_std + tgt_mean
            ee += (pred - real).abs().sum().item()
            en_ += len(tb)
    return ee / max(en_, 1)

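Both the training and evaluation loops de-normalize predictions and targets with the same affine transform before computing MAE. Because that transform is applied to both sides, the original-scale MAE is just the normalized-scale MAE times the target standard deviation. A quick numpy check (the statistics and data are arbitrary):

```python
import numpy as np

mean, std = 250.0, 80.0            # arbitrary target statistics
rng = np.random.RandomState(0)
pred_n = rng.randn(64)             # predictions in normalized scale
tgt_n = rng.randn(64)              # targets in normalized scale

mae_norm = np.abs(pred_n - tgt_n).mean()
mae_orig = np.abs((pred_n * std + mean) - (tgt_n * std + mean)).mean()
# The mean cancels and std factors out of the absolute difference,
# so mae_orig == std * mae_norm (up to floating-point rounding).
```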
# ═══════════════════════════════════════════════════════════════
# DUAL-GPU PARALLEL TRAINING
# ═══════════════════════════════════════════════════════════════

def _train_worker(model, tr_loader, vl_loader, te_loader, device,
                  fold, seed, model_name, result_dict, key,
                  tgt_mean=0., tgt_std=1.):
    """Thread worker: train + evaluate one model on one GPU."""
    try:
        _, best_model = train_fold_core(
            model, tr_loader, vl_loader, device, fold, seed, model_name,
            tgt_mean=tgt_mean, tgt_std=tgt_std
        )
        mae = evaluate_model(best_model, te_loader, device, tgt_mean, tgt_std)
        result_dict[key] = mae
        _dash_state[model_name]['test_mae'] = mae
        _dash_state[model_name]['done'] = True
        _log(f" [{model_name}|F{fold}] 🏆 Test MAE = {mae:.2f} cm⁻¹")
        del best_model
    except Exception as e:
        import traceback
        _log(f" [{model_name}|F{fold}] ❌ ERROR: {e}\n{traceback.format_exc()}")
        result_dict[key] = float('inf')
        _dash_state[model_name]['done'] = True
    finally:
        if device.type == 'cuda':
            torch.cuda.empty_cache()

# ═══════════════════════════════════════════════════════════════
# MAIN
# ═══════════════════════════════════════════════════════════════

def main():
    t0 = time.time()

    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0

    print(f"""
    ╔══════════════════════════════════════════════════════════╗
    ║       TRIADS V6 — Graph-TRM + Gate-Based Halting         ║
    ║                                                          ║
    ║   Gate-halt: {MIN_CYCLES}-{MAX_CYCLES} adaptive cycles, d={D}                     ║
    ║   Deep supervision │ SWA (last {EPOCHS-SWA_START} ep) │ {EPOCHS} ep              ║
    ╚══════════════════════════════════════════════════════════╝
    """)

    device = torch.device('cuda:0' if n_gpus > 0 else 'cpu')
    if n_gpus > 0:
        name = torch.cuda.get_device_name(0)
        mem = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f" GPU: {name} ({mem:.1f} GB)")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True
    else:
        print(" ⚠ No GPU — training will be slow")

    # ── LOAD DATASET ──────────────────────────────────────────
    kaggle_path = "/kaggle/input/datasets/rudratiwari0099x/phonons-training-dataset/phonons_v6_dataset.pt"
    local_path = "phonons_v6_dataset.pt"
    ds_path = kaggle_path if os.path.exists(kaggle_path) else local_path
    print(f" Loading {ds_path}...")
    data = torch.load(ds_path, weights_only=False)
    graphs = data['graphs']
    comp_all = data['comp_features']
    glob_phys = data['global_physics']
    tgt_all = data['targets']
    fold_indices = data['fold_indices']
    N = data['n_samples']
    comp_dim = comp_all.shape[1]
    gp_dim = glob_phys.shape[1]
    print(f" Dataset: {N} samples | comp_dim: {comp_dim} | global_phys: {gp_dim}")

    # ── VERIFY FOLDS ──────────────────────────────────────────
    for fi, (tr, te) in enumerate(fold_indices):
        assert len(set(tr) & set(te)) == 0, f"LEAK in fold {fi}!"
    print(" 5 folds: zero leakage ✓")

    # ── MODEL SIZE CHECK ─────────────────────────────────────
    m_test = PhononV6(comp_dim, gp_dim, mode='gate_halt',
                      min_cycles=MIN_CYCLES, max_cycles=MAX_CYCLES)
    n_params = m_test.count_parameters()
    print(f" Model (Gate-Halt TRM): {n_params:,} params")
    del m_test
    print()

    # ── TRAINING ──────────────────────────────────────────────
    tnp = tgt_all.numpy()
    results = {}

    _dashboard.start()
    try:
        for seed in SEEDS:
            print(f" {'═'*3} Seed {seed} {'═'*55}")
            ts = time.time()
            fold_maes = {}

            for fi, (tv_idx, te_idx) in enumerate(fold_indices):
                tv_idx, te_idx = np.array(tv_idx), np.array(te_idx)
                print(f"\n ┌─ Fold {fi+1}/5 {'─'*50}")

                # Train/val split within train fold
                tri, vli = strat_split(tnp[tv_idx], 0.15, seed + fi)

                # Normalize targets (from train split ONLY — zero leakage)
                tgt_mean = float(tgt_all[tv_idx[tri]].mean())
                tgt_std = float(tgt_all[tv_idx[tri]].std()) + 1e-8
                tgt_norm = (tgt_all - tgt_mean) / tgt_std
                print(f" │ Target norm: mean={tgt_mean:.1f} std={tgt_std:.1f}")

                # Scale features (ONLY from train split — zero leakage)
                sc = StandardScaler().fit(comp_all[tv_idx[tri]].numpy())
                cs = torch.tensor(
                    np.nan_to_num(sc.transform(comp_all.numpy()), nan=0.).astype(np.float32)
                )
                sc_gp = StandardScaler().fit(glob_phys[tv_idx[tri]].numpy())
                gps = torch.tensor(
                    np.nan_to_num(sc_gp.transform(glob_phys.numpy()), nan=0.).astype(np.float32)
                )

                # Seed for reproducibility
                torch.manual_seed(seed + fi)
                np.random.seed(seed + fi)
                if n_gpus > 0:
                    torch.cuda.manual_seed_all(seed + fi)

                # Create model
                model = PhononV6(comp_dim, gp_dim, mode='gate_halt',
                                 min_cycles=MIN_CYCLES,
                                 max_cycles=MAX_CYCLES).to(device)

                # Build loaders with NORMALIZED targets
                trl = Loader(graphs, cs, gps, tgt_norm, tv_idx[tri], BATCH_SIZE, device, True)
                vll = Loader(graphs, cs, gps, tgt_norm, tv_idx[vli], BATCH_SIZE, device, False)
                tel = Loader(graphs, cs, gps, tgt_norm, te_idx, BATCH_SIZE, device, False)

                # Reset dashboard
                _dash_state['GH']['done'] = False

                # Train
                _, best_model = train_fold_core(
                    model, trl, vll, device, fi+1, seed, "GH",
                    tgt_mean=tgt_mean, tgt_std=tgt_std
                )
                mae = evaluate_model(best_model, tel, device, tgt_mean, tgt_std)
                fold_maes[fi] = mae
                _dash_state['GH']['test_mae'] = mae
                _dash_state['GH']['done'] = True
                _log(f" [GH|F{fi+1}] 🏆 Test MAE = {mae:.2f} cm⁻¹")

                # ── SAVE WEIGHTS ─────────────────────────────────────
                os.makedirs('phonons_models_v6', exist_ok=True)
                torch.save({
                    'model_state': best_model.state_dict(),
                    'test_mae': mae,
                    'fold': fi + 1,
                    'seed': seed,
                    'comp_dim': comp_dim,
                    'gp_dim': gp_dim,
                }, f'phonons_models_v6/phonons_v6_s{seed}_f{fi+1}.pt')
                _log(f" [GH|F{fi+1}] 💾 Saved phonons_models_v6/phonons_v6_s{seed}_f{fi+1}.pt")
                # ─────────────────────────────────────────────────────

                print(f" └─ Fold {fi+1} done │ MAE = {fold_maes[fi]:.2f} cm⁻¹")

                del model, best_model
                if n_gpus > 0: torch.cuda.empty_cache()

            avg = np.mean(list(fold_maes.values()))
            results[seed] = fold_maes
            elapsed = time.time() - ts
            print(f"\n Seed {seed} │ Avg MAE: {avg:.2f} │ {elapsed/60:.1f} min")

    finally:
        _dashboard.stop()

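The fold loop fits every scaler and the target statistics on the training split only, then applies them to all rows. A small numpy check of why that ordering matters: statistics fit on all rows would let a test outlier shift the mean and std, leaking test information into the training features. The data below is a made-up illustration:

```python
import numpy as np

X = np.array([[1.0], [2.0], [3.0], [100.0]])  # last row plays the "test" sample
train_idx = np.array([0, 1, 2])

# Fit statistics on the train rows only, then transform everything
# (this is what StandardScaler().fit(train).transform(all) does).
mu = X[train_idx].mean(axis=0)                # 2.0 — unaffected by the outlier
sd = X[train_idx].std(axis=0) + 1e-8
X_scaled = (X - mu) / sd

# The train rows are centered around zero, while the unseen test outlier
# stays extreme. Fitting on all four rows would instead pull mu toward the
# outlier and inflate sd, silently leaking test statistics into training.
```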
    # ── FINAL RESULTS ─────────────────────────────────────────
    fa = np.mean([np.mean(list(v.values())) for v in results.values()])

    print(f"""
    {'='*62}
    FINAL RESULTS — V6 Gate-Halt TRM
    {'='*62}

    {'Model':<45} {'MAE':>10}
    {'─'*57}""")
    for n, v in sorted(BASELINES.items(), key=lambda x: x[1]):
        beaten = ' ← BEATEN!' if fa < v else ''
        print(f" {n:<45} {v:>10.2f}{beaten}")
    print(f" {'V6 Gate-Halt TRM ('+str(n_params//1000)+'K, '+str(MIN_CYCLES)+'-'+str(MAX_CYCLES)+' cycles)':<45} {fa:>10.2f} ← OURS")
    print(f" {'─'*57}")
    print(f" Total time: {(time.time()-t0)/60:.1f} min")

    # ── SAVE ──────────────────────────────────────────────────
    res = {
        'model': 'V6-Gate-Halt-TRM', 'params': n_params,
        'cycles': f'{MIN_CYCLES}-{MAX_CYCLES}',
        'avg_mae': round(fa, 2),
        'per_fold': {str(s): {str(k): round(v, 2) for k, v in m.items()}
                     for s, m in results.items()},
    }
    with open('phonons_v6_results.json', 'w') as f:
        json.dump(res, f, indent=2)
    print(" Saved: phonons_v6_results.json\n")


if __name__ == '__main__':
    main()

model_code/steels_model.py
ADDED
@@ -0,0 +1,1056 @@
"""
╔══════════════════════════════════════════════════════════════════════╗
║  TRM-MatSci V13 — 2-Layer SA + Multi-Seed Ensemble                   ║
║  Dataset: matbench_steels │ 312 samples │ 5-Fold Nested CV           ║
║                                                                      ║
║  V13A  2-Layer Self-Attention + Standard Deep Supervision            ║
║        d_attn=64, nhead=4, d_hidden=96, ff_dim=150, 20 steps         ║
║        Expanded features (Magpie + Mat2Vec + Extra descriptors)      ║
║        2nd SA layer for higher-order property interactions           ║
║        5-seed ensemble (avg predictions across seeds)                ║
║                                                                      ║
║  V13B  Same 2-Layer SA + Confidence-Weighted Deep Supervision        ║
║        22 steps, confidence_head learns which step to trust          ║
║        5-seed ensemble (avg predictions across seeds)                ║
║                                                                      ║
║  All models: Deep Supervision + SWA + AdamW + 300 epochs             ║
║  Baseline: V12A = 95.99 MPa (current best)                           ║
╚══════════════════════════════════════════════════════════════════════╝
"""

import os, copy, json, time, logging, warnings, shutil, urllib.request
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from pymatgen.core import Composition
from matminer.featurizers.composition import ElementProperty
from gensim.models import Word2Vec

logging.basicConfig(level=logging.INFO, format='%(name)s │ %(message)s')
log = logging.getLogger("TRM13")

# Seeds for the multi-seed ensemble
SEEDS = [42, 123, 7, 0, 99]
N_SEEDS = len(SEEDS)

BASELINES = {
    'TPOT-Mat': 79.9468,
    'AutoML-Mat': 82.3043,
    'MODNet': 87.7627,
    'RF-SCM/Magpie': 103.5125,
    'V12A (best)': 95.9900,
    'V11B': 102.3003,
    'V10A': 103.2867,
    'CrabNet': 107.3160,
    'Darwin': 123.2932,
}

# ══════════════════════════════════════════════════════════════════════
# 1. FEATURIZER + DATASET
# ══════════════════════════════════════════════════════════════════════

class ExpandedFeaturizer:
    """Magpie (22 props × 6 stats) + Extra matminer descriptors + Mat2Vec (200d).

    Extra descriptors: ElementFraction, Stoichiometry, ValenceOrbital,
    IonProperty, BandCenter — all concatenated as a flat vector between
    the Magpie block and Mat2Vec.
    """
    GCS = "https://storage.googleapis.com/mat2vec/"
    FILES = ["pretrained_embeddings",
             "pretrained_embeddings.wv.vectors.npy",
             "pretrained_embeddings.trainables.syn1neg.npy"]

    def __init__(self, cache="mat2vec_cache"):
        from matminer.featurizers.composition import (
            ElementFraction, Stoichiometry, ValenceOrbital,
            IonProperty, BandCenter
        )
        from matminer.featurizers.base import MultipleFeaturizer

        self.ep_magpie = ElementProperty.from_preset("magpie")
        self.n_mg = len(self.ep_magpie.feature_labels())

        self.extra_feats = MultipleFeaturizer([
            ElementFraction(),
            Stoichiometry(),
            ValenceOrbital(),
            IonProperty(),
            BandCenter(),
        ])
        self.n_extra = None  # detected at featurize time

        self.scaler = None
        os.makedirs(cache, exist_ok=True)
        for f in self.FILES:
            p = os.path.join(cache, f)
            if not os.path.exists(p):
                log.info(f"  Downloading {f}...")
                urllib.request.urlretrieve(self.GCS + f, p)
        self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
        self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}

    def _pool(self, c):
        # Fraction-weighted mean of per-element Mat2Vec embeddings
        v, t = np.zeros(200, np.float32), 0.0
        for s, f in c.get_el_amt_dict().items():
            if s in self.emb: v += f * self.emb[s]; t += f
        return v / max(t, 1e-8)

    def featurize_all(self, comps):
        out = []
        for c in tqdm(comps, desc="  Featurizing (expanded)", leave=False):
            try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
            except Exception: mg = np.zeros(self.n_mg, np.float32)

            try:
                ex = np.array(self.extra_feats.featurize(c), np.float32)
            except Exception:
                ex = np.zeros(self.n_extra or 200, np.float32)

            if self.n_extra is None:
                self.n_extra = len(ex)
                log.info(f"Expanded features: {self.n_mg} Magpie + "
                         f"{self.n_extra} Extra + 200 Mat2Vec = "
                         f"{self.n_mg + self.n_extra + 200}d")

            out.append(np.concatenate([
                np.nan_to_num(mg, nan=0.0),
                np.nan_to_num(ex, nan=0.0),
                self._pool(c)
            ]))
        return np.array(out)

    def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
    def transform(self, X):
        if not self.scaler: return X
        return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)

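The `_pool` reduction collapses a composition to one vector by averaging element embeddings weighted by stoichiometric amounts. A minimal numpy sketch of the same arithmetic, with toy 3-d vectors standing in for the 200-d Mat2Vec embeddings and a hypothetical Fe2O3-style amount dict standing in for `Composition.get_el_amt_dict()`:

```python
import numpy as np

# Toy stand-ins: 3-d "embeddings" and an Fe2O3-style amount dict.
emb = {"Fe": np.array([1.0, 0.0, 2.0]), "O": np.array([0.0, 1.0, 1.0])}
amounts = {"Fe": 2.0, "O": 3.0}

def pool(amounts, emb, dim=3):
    v, t = np.zeros(dim), 0.0
    for sym, amt in amounts.items():
        if sym in emb:          # elements missing from the vocabulary are skipped
            v += amt * emb[sym]
            t += amt
    return v / max(t, 1e-8)     # normalize by the total counted amount

vec = pool(amounts, emb)        # (2*[1,0,2] + 3*[0,1,1]) / 5 = [0.4, 0.6, 1.4]
```

Skipped elements shrink `t` as well, so the result stays a proper weighted mean over whatever was actually embedded.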
class DSData(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(np.array(y, np.float32))
    def __len__(self): return len(self.y)
    def __getitem__(self, i): return self.X[i], self.y[i]

# ══════════════════════════════════════════════════════════════════════
# 2. MODELS — with 2-Layer Self-Attention
# ══════════════════════════════════════════════════════════════════════

class DeepHybridTRM(nn.Module):
    """V13A: 2-Layer SA Hybrid-TRM with Standard Deep Supervision.

    Key difference from V12A's HybridTRM:
      - TWO self-attention layers (SA1 → FF1 → SA2 → FF2 → CA)
      - Each SA layer has its own residual + LayerNorm + FF block
      - This enables higher-order property interaction modeling
        (e.g., "correlation between electronegativity-range AND
        atomic-radius-mean" requires composing two rounds of attention)
      - Cross-attention (CA) to Mat2Vec context remains after SA stack

    Everything else (MLP reasoning loop, deep supervision, SWA)
    is identical to V12A.
    """
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=20, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim = n_props, stat_dim
        self.n_extra = n_extra

        # ── Attention feature extractor (2-Layer SA) ──────────────────
        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        # Self-Attention Layer 1
        self.sa1 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        # Self-Attention Layer 2 (NEW — captures higher-order interactions)
        self.sa2 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        # Cross-Attention to Mat2Vec context (after SA stack)
        self.ca = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        # Pool with optional extra feature injection
        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        # MLP-TRM recursive reasoning (shared weights)
        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        mg = x[:, :mg_dim]

        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra = None
            m2v = x[:, mg_dim:]

        tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)

        # SA Layer 1: learn pairwise property interactions
        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))

        # SA Layer 2: learn higher-order property interactions
        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))

        # Cross-Attention to Mat2Vec chemistry context
        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)  # [B, d_attn]

        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)

        return self.pool(pooled)  # [B, d_hidden]

    def forward(self, x, deep_supervision=False, return_trajectory=False):
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)
        step_preds = []
        for _ in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
        if deep_supervision:
            return step_preds
        elif return_trajectory:
            return step_preds[-1], step_preds
        else:
            return step_preds[-1]

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

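The forward pass applies the same shared `z_up`/`y_up` blocks for `max_steps` iterations, accumulating residual updates on a latent state `z` and an answer state `y`, and emitting one scalar prediction per step. A toy numpy sketch of that recursion, with random linear maps standing in for the learned MLPs and deliberately small sizes (the real model uses `d_hidden=96` and 20 steps):

```python
import numpy as np

rng = np.random.default_rng(0)
B, D, steps = 4, 8, 5                      # toy batch, hidden size, step count

# Random linear maps stand in for the learned z_up / y_up / head MLPs.
Wz = rng.normal(scale=0.1, size=(3 * D, D))
Wy = rng.normal(scale=0.1, size=(2 * D, D))
Wh = rng.normal(scale=0.1, size=(D, 1))

xp = rng.normal(size=(B, D))               # pooled attention features (fixed input)
z = np.zeros((B, D))                       # latent "scratchpad" state
y = np.zeros((B, D))                       # answer state

step_preds = []
for _ in range(steps):
    z = z + np.concatenate([xp, y, z], axis=-1) @ Wz   # z reads (x, y, z)
    y = y + np.concatenate([y, z], axis=-1) @ Wy       # y reads (y, z)
    step_preds.append((y @ Wh)[:, 0])                  # one scalar per sample per step

assert len(step_preds) == steps and step_preds[0].shape == (B,)
```

Because every step emits a prediction, deep supervision can attach a loss to each element of `step_preds` rather than only the last one.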
class DeepConfidenceHybridTRM(nn.Module):
    """V13B: 2-Layer SA Hybrid-TRM with Confidence-Weighted Deep Supervision.

    Same 2-layer SA feature extractor as DeepHybridTRM, but with:
      - confidence_head that learns which recursion step to trust
      - Final prediction = softmax(confidence) · step_preds
      - No ponder cost (avoids V11C's failure)
      - 22 recursion steps (vs 20 for V13A)
    """
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=22, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim = n_props, stat_dim
        self.n_extra = n_extra

        # ── Attention feature extractor (2-Layer SA) ──────────────────
        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        # Self-Attention Layer 1
        self.sa1 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        # Self-Attention Layer 2 (higher-order interactions)
        self.sa2 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        # Cross-Attention to Mat2Vec context
        self.ca = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        # Pool with optional extra feature injection
        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        # MLP-TRM recursive reasoning (shared weights)
        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)

        # ── Confidence head: learns which step to trust ──────────────
        self.confidence_head = nn.Sequential(
            nn.Linear(d_hidden, d_hidden // 2), nn.GELU(),
            nn.Linear(d_hidden // 2, 1))  # raw logit, softmaxed later

        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
        with torch.no_grad():
            nn.init.zeros_(self.confidence_head[-1].bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        mg = x[:, :mg_dim]

        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra = None
            m2v = x[:, mg_dim:]

        tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)

        # SA Layer 1
        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))

        # SA Layer 2
        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))

        # Cross-Attention
        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)

        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)

        return self.pool(pooled)

    def forward(self, x, deep_supervision=False, return_confidence=False):
        """Forward pass.

        Returns:
            deep_supervision=True: (step_preds, confidence_logits)
            deep_supervision=False, return_confidence=False:
                weighted_pred: [B] confidence-weighted prediction
            deep_supervision=False, return_confidence=True:
                (weighted_pred, confidence_weights): [B], [B, max_steps]
        """
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)

        step_preds = []
        conf_logits = []

        for _ in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
            conf_logits.append(self.confidence_head(y).squeeze(1))

        conf_logits = torch.stack(conf_logits, dim=1)  # [B, max_steps]

        if deep_supervision:
            return step_preds, conf_logits

        # Confidence-weighted prediction
        conf_weights = F.softmax(conf_logits, dim=1)      # [B, max_steps]
        preds_stack = torch.stack(step_preds, dim=1)      # [B, max_steps]
        weighted_pred = (preds_stack * conf_weights).sum(dim=1)  # [B]

        if return_confidence:
            return weighted_pred, conf_weights
        return weighted_pred

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

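The confidence-weighted blend replaces "take the last step" with a softmax-weighted average over all step predictions, so the model can down-weight steps it has learned not to trust. A numpy sketch of the same math with hypothetical numbers (one sample, four steps; the logit of 3.0 marks the step the head trusts most):

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

# Hypothetical per-step predictions (MPa) and confidence logits for one sample.
step_preds = np.array([[100.0, 120.0, 130.0, 128.0]])   # [B=1, steps=4]
conf_logits = np.array([[-2.0, 0.0, 3.0, 1.0]])         # head favours step 3

w = softmax(conf_logits, axis=1)        # weights sum to 1 per sample
pred = (step_preds * w).sum(axis=1)     # blended prediction, pulled toward step 3
```

Here `w` concentrates about 84% of the mass on the third step, so `pred` lands near 129 rather than at either extreme.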
# ══════════════════════════════════════════════════════════════════════
# 3. LOSS FUNCTIONS
# ══════════════════════════════════════════════════════════════════════

def deep_supervision_loss(step_preds, targets):
    """Linear-weighted L1 loss across all recursion steps."""
    n = len(step_preds)
    weights = [(i + 1) for i in range(n)]
    total_w = sum(weights)
    loss = 0.0
    for pred, w in zip(step_preds, weights):
        loss += (w / total_w) * F.l1_loss(pred, targets)
    return loss

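With linear weights w_i proportional to the step index, the last of n steps receives n times the weight of the first, while every intermediate step still gets a gradient signal. A pure-Python sketch with hypothetical per-step L1 errors (the MAE values are illustrative, not measured):

```python
# Linear deep-supervision weights for n = 5 recursion steps.
n = 5
weights = [i + 1 for i in range(n)]           # [1, 2, 3, 4, 5]
total = sum(weights)                          # 15
norm = [w / total for w in weights]           # normalized: sums to 1

# Hypothetical per-step L1 errors: later steps are usually more accurate.
step_mae = [30.0, 18.0, 10.0, 7.0, 6.0]
loss = sum(w * m for w, m in zip(norm, step_mae))
# loss = (1*30 + 2*18 + 3*10 + 4*7 + 5*6) / 15 = 154/15 ≈ 10.27
```

The normalized weights keep the loss on the same scale as a plain L1 loss, so the learning rate does not need retuning when the step count changes (20 for V13A vs 22 for V13B).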
def confidence_ds_loss(step_preds, targets, conf_logits):
    """Advanced Deep Supervision: standard DS + confidence-weighted L1.

    Components:
      1. Standard linear-weighted deep supervision on all steps
      2. L1 loss on the confidence-weighted final prediction
    """
    ds = deep_supervision_loss(step_preds, targets)

    conf_weights = F.softmax(conf_logits, dim=1)   # [B, max_steps]
    preds_stack = torch.stack(step_preds, dim=1)   # [B, max_steps]
    weighted_pred = (preds_stack * conf_weights).sum(dim=1)
    conf_loss = F.l1_loss(weighted_pred, targets)

    return ds + conf_loss

# ══════════════════════════════════════════════════════════════════════
# 4. UTILS + TRAINING
# ══════════════════════════════════════════════════════════════════════

def strat_split(targets, val_size=0.15, seed=42):
    """Quartile-stratified train/val index split on a continuous target."""
    bins = np.percentile(targets, [25, 50, 75])
    lbl = np.digitize(targets, bins)
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(lbl == b)[0]
        if len(m) == 0: continue
        n = max(1, int(len(m) * val_size))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)

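Because the split draws a fixed fraction from each target quartile, the validation set covers the full yield-strength range instead of clustering in the bulk of the distribution. A self-contained numpy demonstration of the same logic (the function body mirrors `strat_split` above; the toy targets are 100 evenly spaced values, so each quartile holds 25 samples and int(25 * 0.15) = 3 go to validation per quartile):

```python
import numpy as np

def strat_split(targets, val_size=0.15, seed=42):
    # Mirror of the quartile-stratified split above, for demonstration.
    bins = np.percentile(targets, [25, 50, 75])
    lbl = np.digitize(targets, bins)          # quartile label 0..3 per sample
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(lbl == b)[0]
        if len(m) == 0: continue
        n = max(1, int(len(m) * val_size))    # at least one sample per quartile
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)

targets = np.linspace(100.0, 2000.0, 100)     # toy "yield strengths" (MPa)
tr, vl = strat_split(targets)
```

The indices partition the data: every sample lands in exactly one of the two sets, and each quartile contributes equally to validation.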
def train_fold_standard(model, tr_dl, vl_dl, device,
                        epochs=300, swa_start=200, fold=1, name=""):
    """Training with standard deep supervision."""
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
    swa_m = AveragedModel(model)
    swa_s = SWALR(opt, swa_lr=5e-4)
    swa_on = False
    best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
    hist = {'train': [], 'val': []}

    pbar = tqdm(range(epochs), desc=f"  [{name}] F{fold}/5",
                leave=False, ncols=120)
    for ep in pbar:
        model.train(); tl = 0.0
        for bx, by in tr_dl:
            bx, by = bx.to(device), by.to(device)
            step_preds = model(bx, deep_supervision=True)
            loss = deep_supervision_loss(step_preds, by)
            opt.zero_grad(set_to_none=True); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            tl += F.l1_loss(step_preds[-1], by).item() * len(by)
        tl /= len(tr_dl.dataset)

        model.eval(); vl = 0.0
        with torch.no_grad():
            for bx, by in vl_dl:
                bx, by = bx.to(device), by.to(device)
                pred = model(bx)
                vl += F.l1_loss(pred, by).item() * len(by)
        vl /= len(vl_dl.dataset)
        hist['train'].append(tl); hist['val'].append(vl)

        if ep < swa_start:
            sch.step()
            if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
        else:
            if not swa_on: swa_on = True
            swa_m.update_parameters(model); swa_s.step()

        pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
                         Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')

    if swa_on:
        update_bn(tr_dl, swa_m, device=device)
        model.load_state_dict(swa_m.module.state_dict())
    else:
        model.load_state_dict(best_w)
    return best_v, model, hist


def train_fold_confidence(model, tr_dl, vl_dl, device,
                          epochs=300, swa_start=200, fold=1, name=""):
    """Training with confidence-weighted deep supervision."""
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
    swa_m = AveragedModel(model)
    swa_s = SWALR(opt, swa_lr=5e-4)
    swa_on = False
    best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
    hist = {'train': [], 'val': []}

    pbar = tqdm(range(epochs), desc=f"  [{name}] F{fold}/5",
                leave=False, ncols=120)
    for ep in pbar:
        model.train(); tl = 0.0
        for bx, by in tr_dl:
            bx, by = bx.to(device), by.to(device)
            step_preds, conf_logits = model(bx, deep_supervision=True)
            loss = confidence_ds_loss(step_preds, by, conf_logits)
            opt.zero_grad(set_to_none=True); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            # Track confidence-weighted MAE for display
            with torch.no_grad():
                cw = F.softmax(conf_logits, dim=1)
                ps = torch.stack(step_preds, dim=1)
                wp = (ps * cw).sum(dim=1)
                tl += F.l1_loss(wp, by).item() * len(by)
        tl /= len(tr_dl.dataset)

        model.eval(); vl = 0.0
        with torch.no_grad():
            for bx, by in vl_dl:
                bx, by = bx.to(device), by.to(device)
                pred = model(bx)  # uses confidence-weighted by default
                vl += F.l1_loss(pred, by).item() * len(by)
        vl /= len(vl_dl.dataset)
        hist['train'].append(tl); hist['val'].append(vl)

        if ep < swa_start:
            sch.step()
            if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
        else:
            if not swa_on: swa_on = True
            swa_m.update_parameters(model); swa_s.step()

        pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
                         Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')

    if swa_on:
        update_bn(tr_dl, swa_m, device=device)
        model.load_state_dict(swa_m.module.state_dict())
    else:
        model.load_state_dict(best_w)
    return best_v, model, hist

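Both trainers run cosine annealing from lr=1e-3 down to eta_min=1e-4 over the first `swa_start` epochs, then switch to SWA at a constant swa_lr=5e-4 while averaging weights. The cosine phase follows the standard closed form of `CosineAnnealingLR`; a sketch with the hyperparameter values used above:

```python
import math

lr0, eta_min, T = 1e-3, 1e-4, 200   # base LR, floor, swa_start (values from above)

def cosine_lr(epoch):
    # Closed form of CosineAnnealingLR for epoch in [0, T]
    return eta_min + 0.5 * (lr0 - eta_min) * (1 + math.cos(math.pi * epoch / T))

start = cosine_lr(0)      # = lr0: full learning rate at epoch 0
mid = cosine_lr(T // 2)   # ≈ midpoint between lr0 and eta_min
end = cosine_lr(T)        # = eta_min: floor reached exactly when SWA takes over
```

After epoch `T`, `SWALR` raises the LR back to 5e-4 and holds it there, which keeps the SWA phase exploring a flat region of the loss surface while `AveragedModel` accumulates the running weight average.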
def predict(model, dl, device):
    model.eval(); preds = []
    with torch.no_grad():
        for bx, _ in dl:
            preds.append(model(bx.to(device)).cpu())
    return torch.cat(preds)


def predict_confidence(model, dl, device):
    """Predict using confidence model, also return per-step weights."""
    model.eval()
    all_preds, all_weights = [], []
    with torch.no_grad():
        for bx, _ in dl:
            pred, weights = model(bx.to(device), return_confidence=True)
            all_preds.append(pred.cpu())
            all_weights.append(weights.cpu())
    return torch.cat(all_preds), torch.cat(all_weights)


def get_targets(dl):
    tgts = []
    for _, by in dl: tgts.append(by)
    return torch.cat(tgts)

# ══════════════════════════════════════════════════════════════════════
# 5. MAIN BENCHMARK — Multi-Seed Ensemble
# ══════════════════════════════════════════════════════════════════════

def run_benchmark():
    t0 = time.time()
    print("\n" + "═"*72)
    print(" TRM-MatSci V13 │ 2-Layer SA + Multi-Seed Ensemble │ matbench_steels")
    print(" V13A: 2-Layer SA + expanded features + standard DS (5-seed ensemble)")
    print(" V13B: 2-Layer SA + expanded features + confidence DS (5-seed ensemble)")
    print(f" Seeds: {SEEDS}")
    print("═"*72 + "\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        log.info(f"GPU: {torch.cuda.get_device_name(0)} "
                 f"({torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB)")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    log.info("Loading matbench_steels...")
    from matminer.datasets import load_dataset
    df = load_dataset("matbench_steels")
    comps_raw = df['composition'].tolist()
    targets_all = np.array(df['yield strength'].tolist(), np.float32)
    comps_all = [Composition(c) for c in comps_raw]

    # ── FEATURIZE ─────────────────────────────────────────────────────
    log.info("Computing EXPANDED features...")
    feat = ExpandedFeaturizer()
    X_all = feat.featurize_all(comps_all)
    n_extra = feat.n_extra
    log.info(f"Features: {X_all.shape} (n_extra={n_extra})")

    kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
    folds = list(kfold.split(comps_all))
    os.makedirs('trm_models_v13', exist_ok=True)
    dl_kw = dict(batch_size=32, num_workers=0)

    # ── CONFIGS ───────────────────────────────────────────────────────
    shared_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
                     mat2vec_dim=200, d_attn=64, nhead=4,
                     d_hidden=96, ff_dim=150, dropout=0.2)

    configs = {
        'V13A-2xSA-StdDS': {
            'model_cls': DeepHybridTRM,
            'model_kw': {**shared_kw, 'max_steps': 20},
            'train_fn': train_fold_standard,
            'predict_fn': predict,
            'is_confidence': False,
        },
        'V13B-2xSA-ConfDS': {
            'model_cls': DeepConfidenceHybridTRM,
            'model_kw': {**shared_kw, 'max_steps': 22},
            'train_fn': train_fold_confidence,
            'predict_fn': None,  # uses predict_confidence
            'is_confidence': True,
        },
    }

    # Print param counts
    print(f"\n  {'Config':<24} {'Params':>10} {'Steps':>8} {'Seeds':>6}")
    print(f"  {'─'*54}")
    for cname, cfg in configs.items():
        _m = cfg['model_cls'](**cfg['model_kw'])
        np_ = _m.count_parameters(); del _m
        cfg['n_params'] = np_
        steps = cfg['model_kw']['max_steps']
        print(f"  {cname:<24} {np_:>10,} {steps:>8} {N_SEEDS:>6}")
    print()

    # ── TRAIN + EVALUATE (Multi-Seed) ─────────────────────────────────
    all_results = {}
    all_hists = {}
    all_conf_weights = {}

    for cname, cfg in configs.items():
        print(f"\n{'▓'*72}")
        print(f" {cname} — {N_SEEDS}-Seed Ensemble")
        print(f"{'▓'*72}")

        # Store per-seed, per-fold predictions and MAEs
        seed_fold_preds = {s: {} for s in SEEDS}  # seed -> {fold_idx: preds_tensor}
        seed_fold_maes = {s: [] for s in SEEDS}   # seed -> [mae_f1, ..., mae_f5]
        fold_hists = []   # collect from first seed only
        fold_conf_w = []  # collect from first seed only

        for si, seed in enumerate(SEEDS):
            print(f"\n  ╔═══ Seed {seed} ({si+1}/{N_SEEDS}) ═══╗")

            for fi, (tv_i, te_i) in enumerate(folds):
                print(f"\n  ── [{cname} seed={seed}] Fold {fi+1}/5 {'─'*30}")

                tri, vli = strat_split(targets_all[tv_i], 0.15, seed+fi)
                feat.fit_scaler(X_all[tv_i][tri])
                tr_s = feat.transform(X_all[tv_i][tri])
                vl_s = feat.transform(X_all[tv_i][vli])
                te_s = feat.transform(X_all[te_i])

                pin = device.type == 'cuda'
                tr_dl = DataLoader(DSData(tr_s, targets_all[tv_i][tri]), shuffle=True,
                                   pin_memory=pin, **dl_kw)
                vl_dl = DataLoader(DSData(vl_s, targets_all[tv_i][vli]), shuffle=False,
                                   pin_memory=pin, **dl_kw)
                te_dl = DataLoader(DSData(te_s, targets_all[te_i]), shuffle=False,
                                   pin_memory=pin, **dl_kw)
                te_tgt = get_targets(te_dl)

                torch.manual_seed(seed + fi); np.random.seed(seed + fi)
                if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)

                model = cfg['model_cls'](**cfg['model_kw']).to(device)
                bv, model, hist = cfg['train_fn'](model, tr_dl, vl_dl, device,
                                                  fold=fi+1,
                                                  name=f"{cname}[s{seed}]")

                # Save hist only for first seed
                if si == 0:
                    fold_hists.append(hist)

                # Predict
                if cfg['is_confidence']:
                    pred, conf_w = predict_confidence(model, te_dl, device)
                    if si == 0:
                        fold_conf_w.append(conf_w)
                    avg_peak = conf_w.argmax(dim=1).float().mean().item() + 1
                    mae = F.l1_loss(pred, te_tgt).item()
                    log.info(f"  [s{seed}] F{fi+1}: MAE={mae:.2f} "
                             f"(val {bv:.2f}, avg peak step={avg_peak:.1f})")
                else:
                    pred = cfg['predict_fn'](model, te_dl, device)
                    mae = F.l1_loss(pred, te_tgt).item()
                    log.info(f"  [s{seed}] F{fi+1}: MAE={mae:.2f} (val {bv:.2f})")

                seed_fold_preds[seed][fi] = pred
                seed_fold_maes[seed].append(mae)

                torch.save({'model_state': model.state_dict(), 'test_mae': mae,
                            'config': cname, 'seed': seed},
                           f'trm_models_v13/{cname}_seed{seed}_fold{fi+1}.pt')

                # Free GPU memory
                del model
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

            seed_avg = float(np.mean(seed_fold_maes[seed]))
            print(f"  ╚═══ Seed {seed} avg: {seed_avg:.2f} MPa ═══╝")

        # ── Compute ensemble predictions ──────────────────────────────
        ensemble_fold_maes = []
        for fi, (tv_i, te_i) in enumerate(folds):
            te_tgt_np = targets_all[te_i]
            te_tgt_t = torch.tensor(te_tgt_np, dtype=torch.float32)

            # Average predictions across all seeds for this fold
            all_seed_preds = torch.stack([seed_fold_preds[s][fi] for s in SEEDS])
            ensemble_pred = all_seed_preds.mean(dim=0)

            ens_mae = F.l1_loss(ensemble_pred, te_tgt_t).item()
            ensemble_fold_maes.append(ens_mae)

        ens_avg = float(np.mean(ensemble_fold_maes))
        ens_std = float(np.std(ensemble_fold_maes))

        # Also compute per-seed averages for reporting
        per_seed_avgs = {s: float(np.mean(seed_fold_maes[s])) for s in SEEDS}
        best_single_seed = min(per_seed_avgs.items(), key=lambda x: x[1])
|
| 788 |
+
|
| 789 |
+
all_results[cname] = {
|
| 790 |
+
'avg': ens_avg, 'std': ens_std, 'folds': ensemble_fold_maes,
|
| 791 |
+
'params': cfg['n_params'],
|
| 792 |
+
'per_seed_avgs': per_seed_avgs,
|
| 793 |
+
'per_seed_folds': {str(s): seed_fold_maes[s] for s in SEEDS},
|
| 794 |
+
'best_single_seed': best_single_seed[0],
|
| 795 |
+
'best_single_mae': best_single_seed[1],
|
| 796 |
+
}
|
| 797 |
+
all_hists[cname] = fold_hists
|
| 798 |
+
if fold_conf_w:
|
| 799 |
+
all_conf_weights[cname] = fold_conf_w
|
| 800 |
+
|
| 801 |
+
print(f"\n ═══ {cname} ═══")
|
| 802 |
+
print(f" Ensemble ({N_SEEDS}-seed avg): {ens_avg:.4f} ±{ens_std:.4f} MPa")
|
| 803 |
+
print(f" Best single seed ({best_single_seed[0]}): "
|
| 804 |
+
f"{best_single_seed[1]:.4f} MPa")
|
| 805 |
+
for s in SEEDS:
|
| 806 |
+
print(f" Seed {s:>3}: {per_seed_avgs[s]:.2f} MPa "
|
| 807 |
+
f"folds={[f'{m:.1f}' for m in seed_fold_maes[s]]}")
|
| 808 |
+
|
    # ══════════════════════════════════════════════════════════════════
    # FINAL RESULTS
    # ══════════════════════════════════════════════════════════════════

    tt = time.time() - t0
    print(f"\n{'═'*72}")
    print(f" FINAL LEADERBOARD — matbench_steels V13 (5-Fold Avg MAE)")
    print(f"{'═'*72}")
    print(f" {'Model':<26} {'Params':>10} {'MAE(MPa)':>10} {'±Std':>8} Notes")
    print(f" {'─'*72}")
    for n, r in sorted(all_results.items(), key=lambda x: x[1]['avg']):
        tag = (" ← BEATS MODNet 🏆" if r['avg'] < 87.76 else
               " ← BEATS V12A ✓" if r['avg'] < 95.99 else
               " ← BEATS RF-SCM ✓" if r['avg'] < 103.51 else
               " ← BEATS DARWIN ✓" if r['avg'] < 123.29 else "")
        print(f" {n+' (ens)':<26} {r['params']:>9,} "
              f"{r['avg']:>10.4f} {r['std']:>8.4f}{tag}")
        print(f" {n+' (best 1)':<26} {'':>10} "
              f"{r['best_single_mae']:>10.4f} {'':>8} seed={r['best_single_seed']}")
    print(f" {'─'*72}")
    for bn, bv in sorted(BASELINES.items(), key=lambda x: x[1]):
        print(f" {bn:<26} {'baseline':>10} {bv:>10.4f}")
    print(f"\n Total time: {tt/60:.1f} minutes ({N_SEEDS} seeds × 2 configs × 5 folds)")

    # Per-fold ensemble breakdown
    print(f"\n{'═'*72}")
    print(f" PER-FOLD ENSEMBLE BREAKDOWN")
    print(f"{'═'*72}")
    cnames = list(all_results.keys())
    header = f" {'Fold':<6}"
    for cn in cnames:
        header += f" {cn:>20}"
    print(header)
    print(f" {'─'*52}")
    for fi in range(5):
        row = f" {fi+1:<6}"
        for cn in cnames:
            row += f" {all_results[cn]['folds'][fi]:>20.2f}"
        print(row)

    # Per-seed breakdown
    print(f"\n{'═'*72}")
    print(f" PER-SEED BREAKDOWN")
    print(f"{'═'*72}")
    for cn in cnames:
        r = all_results[cn]
        print(f"\n {cn}:")
        header = f" {'Seed':<6}"
        for fi in range(5):
            header += f" {'F'+str(fi+1):>8}"
        header += f" {'Avg':>8}"
        print(header)
        print(f" {'─'*52}")
        for s in SEEDS:
            row = f" {s:<6}"
            for mae in r['per_seed_folds'][str(s)]:
                row += f" {mae:>8.2f}"
            row += f" {r['per_seed_avgs'][s]:>8.2f}"
            print(row)
        print(f" {'─'*52}")
        row = f" {'ENS':<6}"
        for mae in r['folds']:
            row += f" {mae:>8.2f}"
        row += f" {r['avg']:>8.2f}"
        print(row)

    # Confidence stats
    if all_conf_weights:
        print(f"\n Confidence Step Selection Summary:")
        for cn, fw_list in all_conf_weights.items():
            all_w = torch.cat(fw_list, dim=0)
            avg_w = all_w.mean(dim=0)
            peak_step = avg_w.argmax().item() + 1
            avg_peak = all_w.argmax(dim=1).float().mean().item() + 1
            print(f" {cn}: avg peak step={avg_peak:.1f}, "
                  f"population peak=step {peak_step}")
    print()

    generate_plots(all_results, all_hists, all_conf_weights)
    save_summary(all_results, all_hists, all_conf_weights, tt)
    return all_results

# ══════════════════════════════════════════════════════════════════════
# 6. PLOTS
# ══════════════════════════════════════════════════════════════════════

PAL = {'V13A-2xSA-StdDS': '#1565C0', 'V13B-2xSA-ConfDS': '#E65100'}

def generate_plots(all_results, all_hists, all_conf_weights):
    names = list(all_results.keys())
    avgs = [all_results[n]['avg'] for n in names]
    stds = [all_results[n]['std'] for n in names]
    cols = [PAL.get(n, '#888') for n in names]

    fig = plt.figure(figsize=(22, 18))
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.35, wspace=0.30)

    # ── Plot 1: Bar chart vs baselines ────────────────────────────────
    ax1 = fig.add_subplot(gs[0, 0])

    # Show both ensemble and best-single-seed bars
    x_pos = np.arange(len(names))
    w = 0.35
    ens_bars = ax1.bar(x_pos - w/2, avgs, w, yerr=stds, capsize=6,
                       color=cols, alpha=0.88, edgecolor='white',
                       linewidth=1.5, label='Ensemble')
    best_singles = [all_results[n]['best_single_mae'] for n in names]
    single_bars = ax1.bar(x_pos + w/2, best_singles, w, capsize=6,
                          color=cols, alpha=0.45, edgecolor='white',
                          linewidth=1.5, label='Best Single Seed',
                          hatch='//')

    for bv, c, ls, lb in [
        (87.76, '#F57F17', '--', 'MODNet (87.76)'),
        (95.99, '#4CAF50', '-.', 'V12A (95.99)'),
        (102.30, '#9E9E9E', '-.', 'V11B (102.30)'),
        (103.51, '#B0BEC5', ':', 'RF-SCM (103.51)'),
        (107.32, '#FF9800', ':', 'CrabNet (107.32)'),
    ]:
        ax1.axhline(bv, color=c, linestyle=ls, linewidth=1.8, label=lb, alpha=0.85)
    for bar, m, s in zip(ens_bars, avgs, stds):
        ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+s+1,
                 f'{m:.1f}', ha='center', fontsize=11, fontweight='bold')
    for bar, m in zip(single_bars, best_singles):
        ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+1,
                 f'{m:.1f}', ha='center', fontsize=9, fontstyle='italic',
                 alpha=0.7)

    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(names, fontsize=8)
    ax1.legend(fontsize=6, loc='upper right')
    ax1.set_ylabel('MAE (MPa)'); ax1.set_ylim(0, max(avgs)*1.6)
    ax1.set_title('V13 Results vs Baselines (Ensemble + Best Single)',
                  fontsize=11, fontweight='bold')
    ax1.grid(axis='y', alpha=0.3)

    # ── Plot 2: Per-fold grouped bars ─────────────────────────────────
    ax2 = fig.add_subplot(gs[0, 1])
    x = np.arange(1, 6)
    w = 0.35
    for i, (n, col) in enumerate(zip(names, cols)):
        fold_vals = all_results[n]['folds']
        ax2.bar(x + (i - 0.5) * w, fold_vals, w, color=col, alpha=0.8,
                label=n + ' (ens)', edgecolor='white')
    ax2.axhline(95.99, color='#4CAF50', ls='-.', lw=1.5, label='V12A (95.99)')
    ax2.axhline(87.76, color='#F57F17', ls='--', lw=1.5, label='MODNet (87.76)')
    ax2.set_xlabel('Fold'); ax2.set_ylabel('MAE (MPa)')
    ax2.set_xticks(x); ax2.set_xticklabels([f'F{i}' for i in range(1, 6)])
    ax2.set_title('Per-Fold Ensemble Breakdown', fontweight='bold')
    ax2.legend(fontsize=7); ax2.grid(axis='y', alpha=0.2)

    # ── Plot 3: Training/Val loss curves ──────────────────────────────
    ax3 = fig.add_subplot(gs[1, 0])
    for cname, col in PAL.items():
        if cname not in all_hists:
            continue
        for fi, h in enumerate(all_hists[cname]):
            lb_tr = f'{cname} train' if fi == 0 else None
            lb_vl = f'{cname} val' if fi == 0 else None
            ax3.plot(h['train'], alpha=0.3, lw=0.8, color=col, label=lb_tr)
            ax3.plot(h['val'], alpha=0.7, lw=1.2, color=col, label=lb_vl,
                     linestyle='--')
    ax3.axhline(95.99, color='#4CAF50', ls='-.', lw=1.2, label='V12A (95.99)')
    ax3.axvline(200, color='#4CAF50', ls='--', lw=1.2, alpha=0.6, label='SWA start')
    ax3.set_xlabel('Epoch'); ax3.set_ylabel('MAE (MPa)')
    ax3.set_title('Training Curves (seed 0, all folds)', fontweight='bold')
    ax3.legend(fontsize=6, ncol=2); ax3.grid(alpha=0.2)
    ax3.set_ylim(0, 300)

    # ── Plot 4: Per-seed scatter / Confidence ─────────────────────────
    ax4 = fig.add_subplot(gs[1, 1])
    if all_conf_weights:
        for cn, fw_list in all_conf_weights.items():
            all_w = torch.cat(fw_list, dim=0)
            avg_w = all_w.mean(dim=0).numpy()
            steps = np.arange(1, len(avg_w)+1)
            ax4.bar(steps, avg_w, color=PAL.get(cn, '#E65100'), alpha=0.8,
                    label=f'{cn} avg confidence', edgecolor='white')
            std_w = all_w.std(dim=0).numpy()
            ax4.errorbar(steps, avg_w, yerr=std_w, fmt='none',
                         ecolor='#333', capsize=2, alpha=0.5)
        ax4.set_xlabel('Recursion Step')
        ax4.set_ylabel('Confidence Weight (softmax)')
        ax4.set_title('V13B: Where the Model Trusts Its Predictions',
                      fontweight='bold')
        ax4.legend(fontsize=8)
        ax4.grid(axis='y', alpha=0.2)
    else:
        # Show per-seed MAE scatter if no confidence model
        for i, (cn, col) in enumerate(zip(names, cols)):
            r = all_results[cn]
            seed_avgs = [r['per_seed_avgs'][s] for s in SEEDS]
            ax4.scatter(SEEDS, seed_avgs, s=80, c=col, alpha=0.8,
                        label=f'{cn} per-seed', zorder=5,
                        edgecolors='white', linewidth=1)
            ax4.axhline(r['avg'], color=col, ls='--', lw=1.5, alpha=0.6,
                        label=f'{cn} ensemble={r["avg"]:.2f}')
        ax4.axhline(95.99, color='#4CAF50', ls=':', lw=1, alpha=0.5, label='V12A')
        ax4.set_xlabel('Random Seed')
        ax4.set_ylabel('5-Fold Avg MAE (MPa)')
        ax4.set_title('Per-Seed vs Ensemble Performance', fontweight='bold')
        ax4.legend(fontsize=7); ax4.grid(alpha=0.2)

    fig.suptitle('TRM-MatSci V13 │ 2-Layer SA + Multi-Seed Ensemble │ matbench_steels',
                 fontsize=14, fontweight='bold', y=1.01)
    fig.savefig('trm_results_v13.png', dpi=150, bbox_inches='tight')
    plt.close(fig); log.info("✓ Saved: trm_results_v13.png")


def save_summary(all_results, all_hists, all_conf_weights, total_s):
    # Prepare confidence info
    conf_info = {}
    for cn, fw_list in all_conf_weights.items():
        all_w = torch.cat(fw_list, dim=0)
        conf_info[cn] = {
            'avg_weights': all_w.mean(dim=0).numpy().round(4).tolist(),
            'avg_peak_step': float(all_w.argmax(dim=1).float().mean().item() + 1),
        }

    s = {
        'version': 'V13', 'task': 'matbench_steels',
        'strategy': '2-Layer SA + Multi-Seed Ensemble',
        'seeds': SEEDS,
        'n_seeds': N_SEEDS,
        'total_min': round(total_s/60, 1),
        'models': {},
        'confidence': conf_info,
    }
    for n, r in all_results.items():
        s['models'][n] = {
            'ensemble_avg': round(r['avg'], 4),
            'ensemble_std': round(r['std'], 4),
            'ensemble_folds': [round(x, 4) for x in r['folds']],
            'params': r['params'],
            'best_single_seed': r['best_single_seed'],
            'best_single_mae': round(r['best_single_mae'], 4),
            'per_seed_avgs': {str(k): round(v, 4) for k, v in r['per_seed_avgs'].items()},
        }

    with open('trm_models_v13/summary_v13.json', 'w') as f:
        json.dump(s, f, indent=2, default=str)
    log.info("✓ Saved: summary_v13.json")


if __name__ == '__main__':
    results = run_benchmark()
    shutil.make_archive("trm_v13_all", "zip", "trm_models_v13")
    log.info("✓ Created trm_v13_all.zip")
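The benchmark above averages the per-seed test predictions for each fold before scoring, which is why the ensemble MAE in the leaderboard comes out below the per-seed averages. A minimal, self-contained numpy sketch of that averaging step (synthetic targets and noise in place of real model outputs):

```python
import numpy as np

rng = np.random.default_rng(0)
SEEDS = [0, 1, 2, 3, 4]

# Synthetic stand-ins: fold targets in MPa plus independent per-seed prediction noise.
y_true = rng.normal(1200.0, 300.0, size=62)
seed_preds = {s: y_true + rng.normal(0.0, 90.0, size=y_true.shape) for s in SEEDS}

# Per-seed MAE, then average the predictions across seeds (as in the fold loop).
per_seed_mae = {s: float(np.abs(seed_preds[s] - y_true).mean()) for s in SEEDS}
ensemble_pred = np.stack([seed_preds[s] for s in SEEDS]).mean(axis=0)
ensemble_mae = float(np.abs(ensemble_pred - y_true).mean())

print(f"per-seed MAEs: {[round(m, 1) for m in per_seed_mae.values()]}")
print(f"ensemble MAE : {ensemble_mae:.1f}")
```

Because the seeds' errors are largely independent, the averaged prediction has lower noise, mirroring the gap between the "(ens)" and "(best 1)" rows in the leaderboard.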
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch>=2.0
pymatgen>=2024.1.1
matminer>=0.9.0
gensim>=4.0.0
scikit-learn>=1.3.0
numpy>=1.24.0
pandas>=2.0.0
tqdm>=4.65.0
huggingface_hub>=0.20.0
gradio>=4.0.0
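These are minimum-version pins, so reproducing the environment only needs `pip install -r requirements.txt`. A small stdlib-only sketch (the helper name is illustrative) for checking which pinned packages are already importable locally:

```python
from importlib import metadata

def installed_version(spec):
    """Return the installed version for a 'name>=min' requirement, or None if absent."""
    name = spec.split(">=")[0].strip()
    try:
        return metadata.version(name)
    except metadata.PackageNotFoundError:
        return None

# Probe a few of the pinned requirements from the file above.
for spec in ["torch>=2.0", "numpy>=1.24.0", "gradio>=4.0.0"]:
    ver = installed_version(spec)
    print(f"{spec:<18} -> {ver or 'not installed'}")
```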
weights/README.md
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:d2a5bec16529a25e4d500eea32ec1c9aaff2d12b3a014220f4c0303a75fffa04
size 1165

weights/expt_gap/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:2f4658f262e0f3501e5716c35184fbcc86a4bc28765fbcbcc34756ce1ebf0976
size 2111183

weights/glass/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:f6f173c5a305bcee0ec837e7b6a58802b9f88b3745349913721757dc7d1e2c77
size 966543

weights/is_metal/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:06ed5d20f532f9193aed736f92cb94a7b181ea6b347e959dc7612100f3ff073c
size 970383

weights/jdft2d/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:21d3e4c4728e18e473b4860d81bec77a2a1633540f89c35160811ec9625c4569
size 1598799

weights/phonons/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:47fe2ab26addf64bfc1e78c6f6e9b02e408ee290b4f51fb91d85d5b270c51193
size 6170267

weights/steels/weights.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:64ee164db899d44365bea3a67ef258d7e122144d4088357dd56af10a1c0af838
size 4574159
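Each `weights/*.pt` entry above is a Git LFS pointer rather than the tensor data itself: three lines giving the spec version, a sha256 object id, and the payload size in bytes. A minimal parser sketch (the function name is illustrative), using the steels pointer as input:

```python
def parse_lfs_pointer(text):
    """Parse a Git LFS pointer file (version / oid / size lines) into a dict."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    # oid has the form "sha256:<64-char hex digest>"
    algo, _, digest = fields["oid"].partition(":")
    return {"version": fields["version"], "algo": algo,
            "digest": digest, "size": int(fields["size"])}

pointer = """version https://git-lfs.github.com/spec/v1
oid sha256:64ee164db899d44365bea3a67ef258d7e122144d4088357dd56af10a1c0af838
size 4574159
"""
info = parse_lfs_pointer(pointer)
print(info["algo"], info["size"])  # → sha256 4574159
```

Comparing `digest` against the `hashlib.sha256` of a downloaded weight file is a quick integrity check after `git lfs pull`.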