TRIADS / model_arch.py
Rtx09's picture
Upload model_arch.py with huggingface_hub
3387bee verified
"""
╔══════════════════════════════════════════════════════════════════════╗
β•‘ TRM-MatSci V13 β€” 2-Layer SA + Multi-Seed Ensemble β•‘
β•‘ Dataset: matbench_steels β”‚ 312 samples β”‚ 5-Fold Nested CV β•‘
β•‘ β•‘
β•‘ V13A 2-Layer Self-Attention + Standard Deep Supervision β•‘
β•‘ d_attn=64, nhead=4, d_hidden=96, ff_dim=150, 20 steps β•‘
β•‘ Expanded features (Magpie + Mat2Vec + Extra descriptors) β•‘
β•‘ 2nd SA layer for higher-order property interactions β•‘
β•‘ 5-seed ensemble (avg predictions across seeds) β•‘
β•‘ β•‘
β•‘ V13B Same 2-Layer SA + Confidence-Weighted Deep Supervision β•‘
β•‘ 22 steps, confidence_head learns which step to trust β•‘
β•‘ 5-seed ensemble (avg predictions across seeds) β•‘
β•‘ β•‘
β•‘ All models: Deep Supervision + SWA + AdamW + 300 epochs β•‘
β•‘ Baseline: V12A = 95.99 MPa (current best) β•‘
β•šβ•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•β•
"""
import os, copy, json, time, logging, warnings, shutil, urllib.request
warnings.filterwarnings('ignore')
import numpy as np
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from tqdm import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from pymatgen.core import Composition
from matminer.featurizers.composition import ElementProperty
from gensim.models import Word2Vec
logging.basicConfig(level=logging.INFO, format='%(name)s β”‚ %(message)s')
log = logging.getLogger("TRM13")
# Seeds for multi-seed ensemble
SEEDS = [42, 123, 7, 0, 99]
N_SEEDS = len(SEEDS)
BASELINES = {
'TPOT-Mat': 79.9468,
'AutoML-Mat': 82.3043,
'MODNet': 87.7627,
'RF-SCM/Magpie': 103.5125,
'V12A (best)': 95.9900,
'V11B': 102.3003,
'V10A': 103.2867,
'CrabNet': 107.3160,
'Darwin': 123.2932,
}
# ══════════════════════════════════════════════════════════════════════
# 1. FEATURIZER + DATASET
# ══════════════════════════════════════════════════════════════════════
class ExpandedFeaturizer:
"""Magpie (22 props Γ— 6 stats) + Extra matminer descriptors + Mat2Vec (200d).
Extra descriptors: ElementFraction, Stoichiometry, ValenceOrbital,
IonProperty, BandCenter β€” all concatenated as a flat vector between
the Magpie block and Mat2Vec.
"""
GCS = "https://storage.googleapis.com/mat2vec/"
FILES = ["pretrained_embeddings",
"pretrained_embeddings.wv.vectors.npy",
"pretrained_embeddings.trainables.syn1neg.npy"]
def __init__(self, cache="mat2vec_cache"):
from matminer.featurizers.composition import (
ElementFraction, Stoichiometry, ValenceOrbital,
IonProperty, BandCenter
)
from matminer.featurizers.base import MultipleFeaturizer
self.ep_magpie = ElementProperty.from_preset("magpie")
self.n_mg = len(self.ep_magpie.feature_labels())
self.extra_feats = MultipleFeaturizer([
ElementFraction(),
Stoichiometry(),
ValenceOrbital(),
IonProperty(),
BandCenter(),
])
self.n_extra = None # detected at featurize time
self.scaler = None
os.makedirs(cache, exist_ok=True)
for f in self.FILES:
p = os.path.join(cache, f)
if not os.path.exists(p):
log.info(f" Downloading {f}...")
urllib.request.urlretrieve(self.GCS + f, p)
self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}
def _pool(self, c):
v, t = np.zeros(200, np.float32), 0.0
for s, f in c.get_el_amt_dict().items():
if s in self.emb: v += f * self.emb[s]; t += f
return v / max(t, 1e-8)
def featurize_all(self, comps):
out = []
for c in tqdm(comps, desc=" Featurizing (expanded)", leave=False):
try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
except: mg = np.zeros(self.n_mg, np.float32)
try:
ex = np.array(self.extra_feats.featurize(c), np.float32)
except:
ex = np.zeros(self.n_extra or 200, np.float32)
# Force n_extra to exactly 200 for V13A weight compatibility
# (64 attn + 200 extra = 264 pool input)
if len(ex) > 200:
ex = ex[:200]
elif len(ex) < 200:
ex = np.pad(ex, (0, 200 - len(ex)))
if self.n_extra is None:
self.n_extra = 200
log.info(f"Expanded features: {self.n_mg} Magpie + "
f"200 Extra (forced) + 200 Mat2Vec = "
f"{self.n_mg + 200 + 200}d")
out.append(np.concatenate([
np.nan_to_num(mg, nan=0.0),
np.nan_to_num(ex, nan=0.0),
self._pool(c)
]))
return np.array(out)
def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
def transform(self, X):
if not self.scaler: return X
return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)
class DSData(Dataset):
def __init__(self, X, y):
self.X = torch.tensor(X, dtype=torch.float32)
self.y = torch.tensor(np.array(y, np.float32))
def __len__(self): return len(self.y)
def __getitem__(self, i): return self.X[i], self.y[i]
# ══════════════════════════════════════════════════════════════════════
# 2. MODELS β€” with 2-Layer Self-Attention
# ══════════════════════════════════════════════════════════════════════
class DeepHybridTRM(nn.Module):
"""V13A: 2-Layer SA Hybrid-TRM with Standard Deep Supervision.
Key difference from V12A's HybridTRM:
- TWO self-attention layers (SA1 β†’ FF1 β†’ SA2 β†’ FF2 β†’ CA)
- Each SA layer has its own residual + LayerNorm + FF block
- This enables higher-order property interaction modeling
(e.g., "correlation between electronegativity-range AND
atomic-radius-mean" requires composing two rounds of attention)
- Cross-attention (CA) to Mat2Vec context remains after SA stack
Everything else (MLP reasoning loop, deep supervision, SWA)
is identical to V12A.
"""
def __init__(self, n_props=22, stat_dim=6, n_extra=200, mat2vec_dim=200,
d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
dropout=0.2, max_steps=20, **kw):
super().__init__()
self.max_steps, self.D = max_steps, d_hidden
self.n_props, self.stat_dim = n_props, stat_dim
self.n_extra = n_extra
# ── Attention feature extractor (2-Layer SA) ──────────────────
self.tok_proj = nn.Sequential(
nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
self.m2v_proj = nn.Sequential(
nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
# Self-Attention Layer 1
self.sa1 = nn.MultiheadAttention(
d_attn, nhead, dropout=dropout, batch_first=True)
self.sa1_n = nn.LayerNorm(d_attn)
self.sa1_ff = nn.Sequential(
nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
nn.Linear(d_attn*2, d_attn))
self.sa1_fn = nn.LayerNorm(d_attn)
# Self-Attention Layer 2 (NEW β€” captures higher-order interactions)
self.sa2 = nn.MultiheadAttention(
d_attn, nhead, dropout=dropout, batch_first=True)
self.sa2_n = nn.LayerNorm(d_attn)
self.sa2_ff = nn.Sequential(
nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
nn.Linear(d_attn*2, d_attn))
self.sa2_fn = nn.LayerNorm(d_attn)
# Cross-Attention to Mat2Vec context (after SA stack)
self.ca = nn.MultiheadAttention(
d_attn, nhead, dropout=dropout, batch_first=True)
self.ca_n = nn.LayerNorm(d_attn)
# Pool with optional extra feature injection
pool_in = d_attn + (n_extra if n_extra > 0 else 0)
self.pool = nn.Sequential(
nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
# MLP-TRM recursive reasoning (shared weights)
self.z_up = nn.Sequential(
nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
self.y_up = nn.Sequential(
nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
self.head = nn.Linear(d_hidden, 1)
self._init()
def _init(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
def _attention(self, x):
B = x.size(0)
mg_dim = self.n_props * self.stat_dim
mg = x[:, :mg_dim]
if self.n_extra > 0:
extra = x[:, mg_dim:mg_dim + self.n_extra]
m2v = x[:, mg_dim + self.n_extra:]
else:
extra = None
m2v = x[:, mg_dim:]
tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
ctx = self.m2v_proj(m2v).unsqueeze(1)
# SA Layer 1: learn pairwise property interactions
tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
tok = self.sa1_fn(tok + self.sa1_ff(tok))
# SA Layer 2: learn higher-order property interactions
tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
tok = self.sa2_fn(tok + self.sa2_ff(tok))
# Cross-Attention to Mat2Vec chemistry context
tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
pooled = tok.mean(dim=1) # [B, d_attn]
if extra is not None:
pooled = torch.cat([pooled, extra], dim=-1)
return self.pool(pooled) # [B, d_hidden]
def forward(self, x, deep_supervision=False, return_trajectory=False):
B = x.size(0)
xp = self._attention(x)
z = torch.zeros(B, self.D, device=x.device)
y = torch.zeros(B, self.D, device=x.device)
step_preds = []
for _ in range(self.max_steps):
z = z + self.z_up(torch.cat([xp, y, z], -1))
y = y + self.y_up(torch.cat([y, z], -1))
step_preds.append(self.head(y).squeeze(1))
if deep_supervision:
return step_preds
elif return_trajectory:
return step_preds[-1], step_preds
else:
return step_preds[-1]
def count_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
class DeepConfidenceHybridTRM(nn.Module):
"""V13B: 2-Layer SA Hybrid-TRM with Confidence-Weighted Deep Supervision.
Same 2-layer SA feature extractor as DeepHybridTRM, but with:
- confidence_head that learns which recursion step to trust
- Final prediction = softmax(confidence) Β· step_preds
- No ponder cost (avoids V11C's failure)
- 22 recursion steps (vs 20 for V13A)
"""
def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
dropout=0.2, max_steps=22, **kw):
super().__init__()
self.max_steps, self.D = max_steps, d_hidden
self.n_props, self.stat_dim = n_props, stat_dim
self.n_extra = n_extra
# ── Attention feature extractor (2-Layer SA) ──────────────────
self.tok_proj = nn.Sequential(
nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
self.m2v_proj = nn.Sequential(
nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
# Self-Attention Layer 1
self.sa1 = nn.MultiheadAttention(
d_attn, nhead, dropout=dropout, batch_first=True)
self.sa1_n = nn.LayerNorm(d_attn)
self.sa1_ff = nn.Sequential(
nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
nn.Linear(d_attn*2, d_attn))
self.sa1_fn = nn.LayerNorm(d_attn)
# Self-Attention Layer 2 (higher-order interactions)
self.sa2 = nn.MultiheadAttention(
d_attn, nhead, dropout=dropout, batch_first=True)
self.sa2_n = nn.LayerNorm(d_attn)
self.sa2_ff = nn.Sequential(
nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
nn.Linear(d_attn*2, d_attn))
self.sa2_fn = nn.LayerNorm(d_attn)
# Cross-Attention to Mat2Vec context
self.ca = nn.MultiheadAttention(
d_attn, nhead, dropout=dropout, batch_first=True)
self.ca_n = nn.LayerNorm(d_attn)
# Pool with optional extra feature injection
pool_in = d_attn + (n_extra if n_extra > 0 else 0)
self.pool = nn.Sequential(
nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())
# MLP-TRM recursive reasoning (shared weights)
self.z_up = nn.Sequential(
nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
self.y_up = nn.Sequential(
nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
self.head = nn.Linear(d_hidden, 1)
# ── Confidence head: learns which step to trust ──────────────
self.confidence_head = nn.Sequential(
nn.Linear(d_hidden, d_hidden // 2), nn.GELU(),
nn.Linear(d_hidden // 2, 1)) # raw logit, softmaxed later
self._init()
def _init(self):
for m in self.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None: nn.init.zeros_(m.bias)
with torch.no_grad():
nn.init.zeros_(self.confidence_head[-1].bias)
def _attention(self, x):
B = x.size(0)
mg_dim = self.n_props * self.stat_dim
mg = x[:, :mg_dim]
if self.n_extra > 0:
extra = x[:, mg_dim:mg_dim + self.n_extra]
m2v = x[:, mg_dim + self.n_extra:]
else:
extra = None
m2v = x[:, mg_dim:]
tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
ctx = self.m2v_proj(m2v).unsqueeze(1)
# SA Layer 1
tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
tok = self.sa1_fn(tok + self.sa1_ff(tok))
# SA Layer 2
tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
tok = self.sa2_fn(tok + self.sa2_ff(tok))
# Cross-Attention
tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])
pooled = tok.mean(dim=1)
if extra is not None:
pooled = torch.cat([pooled, extra], dim=-1)
return self.pool(pooled)
def forward(self, x, deep_supervision=False, return_confidence=False):
"""Forward pass.
Returns:
deep_supervision=True: (step_preds, confidence_logits)
deep_supervision=False, return_confidence=False:
weighted_pred: [B] confidence-weighted prediction
deep_supervision=False, return_confidence=True:
(weighted_pred, confidence_weights): [B], [B, max_steps]
"""
B = x.size(0)
xp = self._attention(x)
z = torch.zeros(B, self.D, device=x.device)
y = torch.zeros(B, self.D, device=x.device)
step_preds = []
conf_logits = []
for _ in range(self.max_steps):
z = z + self.z_up(torch.cat([xp, y, z], -1))
y = y + self.y_up(torch.cat([y, z], -1))
step_preds.append(self.head(y).squeeze(1))
conf_logits.append(self.confidence_head(y).squeeze(1))
conf_logits = torch.stack(conf_logits, dim=1) # [B, max_steps]
if deep_supervision:
return step_preds, conf_logits
# Confidence-weighted prediction
conf_weights = F.softmax(conf_logits, dim=1) # [B, max_steps]
preds_stack = torch.stack(step_preds, dim=1) # [B, max_steps]
weighted_pred = (preds_stack * conf_weights).sum(dim=1) # [B]
if return_confidence:
return weighted_pred, conf_weights
return weighted_pred
def count_parameters(self):
return sum(p.numel() for p in self.parameters() if p.requires_grad)
# ══════════════════════════════════════════════════════════════════════
# 3. LOSS FUNCTIONS
# ══════════════════════════════════════════════════════════════════════
def deep_supervision_loss(step_preds, targets):
"""Linear-weighted L1 loss across all recursion steps."""
n = len(step_preds)
weights = [(i + 1) for i in range(n)]
total_w = sum(weights)
loss = 0.0
for pred, w in zip(step_preds, weights):
loss += (w / total_w) * F.l1_loss(pred, targets)
return loss
def confidence_ds_loss(step_preds, targets, conf_logits):
"""Advanced Deep Supervision: standard DS + confidence-weighted L1.
Components:
1. Standard linear-weighted deep supervision on all steps
2. L1 loss on the confidence-weighted final prediction
"""
ds = deep_supervision_loss(step_preds, targets)
conf_weights = F.softmax(conf_logits, dim=1) # [B, max_steps]
preds_stack = torch.stack(step_preds, dim=1) # [B, max_steps]
weighted_pred = (preds_stack * conf_weights).sum(dim=1)
conf_loss = F.l1_loss(weighted_pred, targets)
return ds + conf_loss
# ══════════════════════════════════════════════════════════════════════
# 4. UTILS + TRAINING
# ══════════════════════════════════════════════════════════════════════
def strat_split(targets, val_size=0.15, seed=42):
bins = np.percentile(targets, [25, 50, 75])
lbl = np.digitize(targets, bins)
tr, vl = [], []
rng = np.random.RandomState(seed)
for b in range(4):
m = np.where(lbl == b)[0]
if len(m) == 0: continue
n = max(1, int(len(m) * val_size))
c = rng.choice(m, n, replace=False)
vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
return np.array(tr), np.array(vl)
def train_fold_standard(model, tr_dl, vl_dl, device,
epochs=300, swa_start=200, fold=1, name=""):
"""Training with standard deep supervision."""
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
swa_m = AveragedModel(model)
swa_s = SWALR(opt, swa_lr=5e-4)
swa_on = False
best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
hist = {'train': [], 'val': []}
pbar = tqdm(range(epochs), desc=f" [{name}] F{fold}/5",
leave=False, ncols=120)
for ep in pbar:
model.train(); tl = 0.0
for bx, by in tr_dl:
bx, by = bx.to(device), by.to(device)
step_preds = model(bx, deep_supervision=True)
loss = deep_supervision_loss(step_preds, by)
opt.zero_grad(set_to_none=True); loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()
tl += F.l1_loss(step_preds[-1], by).item() * len(by)
tl /= len(tr_dl.dataset)
model.eval(); vl = 0.0
with torch.no_grad():
for bx, by in vl_dl:
bx, by = bx.to(device), by.to(device)
pred = model(bx)
vl += F.l1_loss(pred, by).item() * len(by)
vl /= len(vl_dl.dataset)
hist['train'].append(tl); hist['val'].append(vl)
if ep < swa_start:
sch.step()
if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
else:
if not swa_on: swa_on = True
swa_m.update_parameters(model); swa_s.step()
pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')
if swa_on:
update_bn(tr_dl, swa_m, device=device)
model.load_state_dict(swa_m.module.state_dict())
else:
model.load_state_dict(best_w)
return best_v, model, hist
def train_fold_confidence(model, tr_dl, vl_dl, device,
epochs=300, swa_start=200, fold=1, name=""):
"""Training with confidence-weighted deep supervision."""
opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
swa_m = AveragedModel(model)
swa_s = SWALR(opt, swa_lr=5e-4)
swa_on = False
best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
hist = {'train': [], 'val': []}
pbar = tqdm(range(epochs), desc=f" [{name}] F{fold}/5",
leave=False, ncols=120)
for ep in pbar:
model.train(); tl = 0.0
for bx, by in tr_dl:
bx, by = bx.to(device), by.to(device)
step_preds, conf_logits = model(bx, deep_supervision=True)
loss = confidence_ds_loss(step_preds, by, conf_logits)
opt.zero_grad(set_to_none=True); loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
opt.step()
# Track confidence-weighted MAE for display
with torch.no_grad():
cw = F.softmax(conf_logits, dim=1)
ps = torch.stack(step_preds, dim=1)
wp = (ps * cw).sum(dim=1)
tl += F.l1_loss(wp, by).item() * len(by)
tl /= len(tr_dl.dataset)
model.eval(); vl = 0.0
with torch.no_grad():
for bx, by in vl_dl:
bx, by = bx.to(device), by.to(device)
pred = model(bx) # uses confidence-weighted by default
vl += F.l1_loss(pred, by).item() * len(by)
vl /= len(vl_dl.dataset)
hist['train'].append(tl); hist['val'].append(vl)
if ep < swa_start:
sch.step()
if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
else:
if not swa_on: swa_on = True
swa_m.update_parameters(model); swa_s.step()
pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')
if swa_on:
update_bn(tr_dl, swa_m, device=device)
model.load_state_dict(swa_m.module.state_dict())
else:
model.load_state_dict(best_w)
return best_v, model, hist
def predict(model, dl, device):
model.eval(); preds = []
with torch.no_grad():
for bx, _ in dl:
preds.append(model(bx.to(device)).cpu())
return torch.cat(preds)
def predict_confidence(model, dl, device):
"""Predict using confidence model, also return per-step weights."""
model.eval()
all_preds, all_weights = [], []
with torch.no_grad():
for bx, _ in dl:
pred, weights = model(bx.to(device), return_confidence=True)
all_preds.append(pred.cpu())
all_weights.append(weights.cpu())
return torch.cat(all_preds), torch.cat(all_weights)
def get_targets(dl):
tgts = []
for _, by in dl: tgts.append(by)
return torch.cat(tgts)
# ══════════════════════════════════════════════════════════════════════
# 5. MAIN BENCHMARK β€” Multi-Seed Ensemble
# ══════════════════════════════════════════════════════════════════════
def run_benchmark():
t0 = time.time()
print("\n" + "═"*72)
print(" TRM-MatSci V13 β”‚ 2-Layer SA + Multi-Seed Ensemble β”‚ matbench_steels")
print(" V13A: 2-Layer SA + expanded features + standard DS (5-seed ensemble)")
print(" V13B: 2-Layer SA + expanded features + confidence DS (5-seed ensemble)")
print(f" Seeds: {SEEDS}")
print("═"*72 + "\n")
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
log.info(f"GPU: {torch.cuda.get_device_name(0)} "
f"({torch.cuda.get_device_properties(0).total_mem/1e9:.1f} GB)")
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
log.info("Loading matbench_steels...")
from matminer.datasets import load_dataset
df = load_dataset("matbench_steels")
comps_raw = df['composition'].tolist()
targets_all = np.array(df['yield strength'].tolist(), np.float32)
comps_all = [Composition(c) for c in comps_raw]
# ── FEATURIZE ─────────────────────────────────────────────────────
log.info("Computing EXPANDED features...")
feat = ExpandedFeaturizer()
X_all = feat.featurize_all(comps_all)
n_extra = feat.n_extra
log.info(f"Features: {X_all.shape} (n_extra={n_extra})")
kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
folds = list(kfold.split(comps_all))
os.makedirs('trm_models_v13', exist_ok=True)
dl_kw = dict(batch_size=32, num_workers=0)
# ── CONFIGS ───────────────────────────────────────────────────────
shared_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
mat2vec_dim=200, d_attn=64, nhead=4,
d_hidden=96, ff_dim=150, dropout=0.2)
configs = {
'V13A-2xSA-StdDS': {
'model_cls': DeepHybridTRM,
'model_kw': {**shared_kw, 'max_steps': 20},
'train_fn': train_fold_standard,
'predict_fn': predict,
'is_confidence': False,
},
'V13B-2xSA-ConfDS': {
'model_cls': DeepConfidenceHybridTRM,
'model_kw': {**shared_kw, 'max_steps': 22},
'train_fn': train_fold_confidence,
'predict_fn': None, # uses predict_confidence
'is_confidence': True,
},
}
# Print param counts
print(f"\n {'Config':<24} {'Params':>10} {'Steps':>8} {'Seeds':>6}")
print(f" {'─'*54}")
for cname, cfg in configs.items():
_m = cfg['model_cls'](**cfg['model_kw'])
np_ = _m.count_parameters(); del _m
cfg['n_params'] = np_
steps = cfg['model_kw']['max_steps']
print(f" {cname:<24} {np_:>10,} {steps:>8} {N_SEEDS:>6}")
print()
# ── TRAIN + EVALUATE (Multi-Seed) ─────────────────────────────────
all_results = {}
all_hists = {}
all_conf_weights = {}
for cname, cfg in configs.items():
print(f"\n{'β–“'*72}")
print(f" {cname} β€” {N_SEEDS}-Seed Ensemble")
print(f"{'β–“'*72}")
# Store per-seed, per-fold predictions and MAEs
seed_fold_preds = {s: {} for s in SEEDS} # seed -> {fold_idx: preds_tensor}
seed_fold_maes = {s: [] for s in SEEDS} # seed -> [mae_f1, ..., mae_f5]
fold_hists = [] # collect from first seed only
fold_conf_w = [] # collect from first seed only
for si, seed in enumerate(SEEDS):
print(f"\n ╔═══ Seed {seed} ({si+1}/{N_SEEDS}) ═══╗")
for fi, (tv_i, te_i) in enumerate(folds):
print(f"\n ── [{cname} seed={seed}] Fold {fi+1}/5 {'─'*30}")
tri, vli = strat_split(targets_all[tv_i], 0.15, seed+fi)
feat.fit_scaler(X_all[tv_i][tri])
tr_s = feat.transform(X_all[tv_i][tri])
vl_s = feat.transform(X_all[tv_i][vli])
te_s = feat.transform(X_all[te_i])
pin = device.type == 'cuda'
tr_dl = DataLoader(DSData(tr_s, targets_all[tv_i][tri]), shuffle=True,
pin_memory=pin, **dl_kw)
vl_dl = DataLoader(DSData(vl_s, targets_all[tv_i][vli]), shuffle=False,
pin_memory=pin, **dl_kw)
te_dl = DataLoader(DSData(te_s, targets_all[te_i]), shuffle=False,
pin_memory=pin, **dl_kw)
te_tgt = get_targets(te_dl)
torch.manual_seed(seed + fi); np.random.seed(seed + fi)
if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)
model = cfg['model_cls'](**cfg['model_kw']).to(device)
bv, model, hist = cfg['train_fn'](model, tr_dl, vl_dl, device,
fold=fi+1,
name=f"{cname}[s{seed}]")
# Save hist only for first seed
if si == 0:
fold_hists.append(hist)
# Predict
if cfg['is_confidence']:
pred, conf_w = predict_confidence(model, te_dl, device)
if si == 0:
fold_conf_w.append(conf_w)
avg_peak = conf_w.argmax(dim=1).float().mean().item() + 1
mae = F.l1_loss(pred, te_tgt).item()
log.info(f" [s{seed}] F{fi+1}: MAE={mae:.2f} "
f"(val {bv:.2f}, avg peak step={avg_peak:.1f})")
else:
pred = cfg['predict_fn'](model, te_dl, device)
mae = F.l1_loss(pred, te_tgt).item()
log.info(f" [s{seed}] F{fi+1}: MAE={mae:.2f} (val {bv:.2f})")
seed_fold_preds[seed][fi] = pred
seed_fold_maes[seed].append(mae)
torch.save({'model_state': model.state_dict(), 'test_mae': mae,
'config': cname, 'seed': seed},
f'trm_models_v13/{cname}_seed{seed}_fold{fi+1}.pt')
# Free GPU memory
del model; torch.cuda.empty_cache() if device.type == 'cuda' else None
seed_avg = float(np.mean(seed_fold_maes[seed]))
print(f" β•šβ•β•β• Seed {seed} avg: {seed_avg:.2f} MPa ═══╝")
# ── Compute ensemble predictions ──────────────────────────────
ensemble_fold_maes = []
for fi, (tv_i, te_i) in enumerate(folds):
te_tgt_np = targets_all[te_i]
te_tgt_t = torch.tensor(te_tgt_np, dtype=torch.float32)
# Average predictions across all seeds for this fold
all_seed_preds = torch.stack([seed_fold_preds[s][fi] for s in SEEDS])
ensemble_pred = all_seed_preds.mean(dim=0)
ens_mae = F.l1_loss(ensemble_pred, te_tgt_t).item()
ensemble_fold_maes.append(ens_mae)
ens_avg = float(np.mean(ensemble_fold_maes))
ens_std = float(np.std(ensemble_fold_maes))
# Also compute per-seed averages for reporting
per_seed_avgs = {s: float(np.mean(seed_fold_maes[s])) for s in SEEDS}
best_single_seed = min(per_seed_avgs.items(), key=lambda x: x[1])
all_results[cname] = {
'avg': ens_avg, 'std': ens_std, 'folds': ensemble_fold_maes,
'params': cfg['n_params'],
'per_seed_avgs': per_seed_avgs,
'per_seed_folds': {str(s): seed_fold_maes[s] for s in SEEDS},
'best_single_seed': best_single_seed[0],
'best_single_mae': best_single_seed[1],
}
all_hists[cname] = fold_hists
if fold_conf_w:
all_conf_weights[cname] = fold_conf_w
print(f"\n ═══ {cname} ═══")
print(f" Ensemble ({N_SEEDS}-seed avg): {ens_avg:.4f} Β±{ens_std:.4f} MPa")
print(f" Best single seed ({best_single_seed[0]}): "
f"{best_single_seed[1]:.4f} MPa")
for s in SEEDS:
print(f" Seed {s:>3}: {per_seed_avgs[s]:.2f} MPa "
f"folds={[f'{m:.1f}' for m in seed_fold_maes[s]]}")
# ══════════════════════════════════════════════════════════════════
# FINAL RESULTS
# ══════════════════════════════════════════════════════════════════
tt = time.time() - t0
print(f"\n{'═'*72}")
print(f" FINAL LEADERBOARD β€” matbench_steels V13 (5-Fold Avg MAE)")
print(f"{'═'*72}")
print(f" {'Model':<26} {'Params':>10} {'MAE(MPa)':>10} {'Β±Std':>8} Notes")
print(f" {'─'*72}")
for n, r in sorted(all_results.items(), key=lambda x: x[1]['avg']):
tag = (" ← BEATS MODNet πŸ†" if r['avg'] < 87.76 else
" ← BEATS V12A βœ“" if r['avg'] < 95.99 else
" ← BEATS RF-SCM βœ“" if r['avg'] < 103.51 else
" ← BEATS DARWIN βœ“" if r['avg'] < 123.29 else "")
print(f" {n+' (ens)':<26} {r['params']:>9,} "
f"{r['avg']:>10.4f} {r['std']:>8.4f}{tag}")
print(f" {n+' (best 1)':<26} {'':>10} "
f"{r['best_single_mae']:>10.4f} {'':>8} seed={r['best_single_seed']}")
print(f" {'─'*72}")
for bn, bv in sorted(BASELINES.items(), key=lambda x: x[1]):
print(f" {bn:<26} {'baseline':>10} {bv:>10.4f}")
print(f"\n Total time: {tt/60:.1f} minutes ({N_SEEDS} seeds Γ— 2 configs Γ— 5 folds)")
# Per-fold ensemble breakdown
print(f"\n{'═'*72}")
print(f" PER-FOLD ENSEMBLE BREAKDOWN")
print(f"{'═'*72}")
cnames = list(all_results.keys())
header = f" {'Fold':<6}"
for cn in cnames:
header += f" {cn:>20}"
print(header)
print(f" {'─'*52}")
for fi in range(5):
row = f" {fi+1:<6}"
for cn in cnames:
row += f" {all_results[cn]['folds'][fi]:>20.2f}"
print(row)
# Per-seed breakdown
print(f"\n{'═'*72}")
print(f" PER-SEED BREAKDOWN")
print(f"{'═'*72}")
for cn in cnames:
r = all_results[cn]
print(f"\n {cn}:")
header = f" {'Seed':<6}"
for fi in range(5):
header += f" {'F'+str(fi+1):>8}"
header += f" {'Avg':>8}"
print(header)
print(f" {'─'*52}")
for s in SEEDS:
row = f" {s:<6}"
for mae in r['per_seed_folds'][str(s)]:
row += f" {mae:>8.2f}"
row += f" {r['per_seed_avgs'][s]:>8.2f}"
print(row)
print(f" {'─'*52}")
row = f" {'ENS':<6}"
for mae in r['folds']:
row += f" {mae:>8.2f}"
row += f" {r['avg']:>8.2f}"
print(row)
# Confidence stats
if all_conf_weights:
print(f"\n Confidence Step Selection Summary:")
for cn, fw_list in all_conf_weights.items():
all_w = torch.cat(fw_list, dim=0)
avg_w = all_w.mean(dim=0)
peak_step = avg_w.argmax().item() + 1
avg_peak = all_w.argmax(dim=1).float().mean().item() + 1
print(f" {cn}: avg peak step={avg_peak:.1f}, "
f"population peak=step {peak_step}")
print()
generate_plots(all_results, all_hists, all_conf_weights)
save_summary(all_results, all_hists, all_conf_weights, tt)
return all_results
# ══════════════════════════════════════════════════════════════════════
# 6. PLOTS
# ══════════════════════════════════════════════════════════════════════
PAL = {'V13A-2xSA-StdDS': '#1565C0', 'V13B-2xSA-ConfDS': '#E65100'}
def generate_plots(all_results, all_hists, all_conf_weights):
names = list(all_results.keys())
avgs = [all_results[n]['avg'] for n in names]
stds = [all_results[n]['std'] for n in names]
cols = [PAL.get(n, '#888') for n in names]
fig = plt.figure(figsize=(22, 18))
gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.35, wspace=0.30)
# ── Plot 1: Bar chart vs baselines ────────────────────────────────
ax1 = fig.add_subplot(gs[0, 0])
# Show both ensemble and best-single-seed bars
x_pos = np.arange(len(names))
w = 0.35
ens_bars = ax1.bar(x_pos - w/2, avgs, w, yerr=stds, capsize=6,
color=cols, alpha=0.88, edgecolor='white',
linewidth=1.5, label='Ensemble')
best_singles = [all_results[n]['best_single_mae'] for n in names]
single_bars = ax1.bar(x_pos + w/2, best_singles, w, capsize=6,
color=cols, alpha=0.45, edgecolor='white',
linewidth=1.5, label='Best Single Seed',
hatch='//')
for bv, c, ls, lb in [
(87.76, '#F57F17', '--', 'MODNet (87.76)'),
(95.99, '#4CAF50', '-.', 'V12A (95.99)'),
(102.30, '#9E9E9E', '-.', 'V11B (102.30)'),
(103.51, '#B0BEC5', ':', 'RF-SCM (103.51)'),
(107.32, '#FF9800', ':', 'CrabNet (107.32)'),
]:
ax1.axhline(bv, color=c, linestyle=ls, linewidth=1.8, label=lb, alpha=0.85)
for bar, m, s in zip(ens_bars, avgs, stds):
ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+s+1,
f'{m:.1f}', ha='center', fontsize=11, fontweight='bold')
for bar, m in zip(single_bars, best_singles):
ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+1,
f'{m:.1f}', ha='center', fontsize=9, fontstyle='italic',
alpha=0.7)
ax1.set_xticks(x_pos)
ax1.set_xticklabels(names, fontsize=8)
ax1.legend(fontsize=6, loc='upper right')
ax1.set_ylabel('MAE (MPa)'); ax1.set_ylim(0, max(avgs)*1.6)
ax1.set_title('V13 Results vs Baselines (Ensemble + Best Single)',
fontsize=11, fontweight='bold')
ax1.grid(axis='y', alpha=0.3)
# ── Plot 2: Per-fold grouped bars ─────────────────────────────────
ax2 = fig.add_subplot(gs[0, 1])
x = np.arange(1, 6)
w = 0.35
for i, (n, col) in enumerate(zip(names, cols)):
fold_vals = all_results[n]['folds']
ax2.bar(x + (i - 0.5) * w, fold_vals, w, color=col, alpha=0.8,
label=n + ' (ens)', edgecolor='white')
ax2.axhline(95.99, color='#4CAF50', ls='-.', lw=1.5, label='V12A (95.99)')
ax2.axhline(87.76, color='#F57F17', ls='--', lw=1.5, label='MODNet (87.76)')
ax2.set_xlabel('Fold'); ax2.set_ylabel('MAE (MPa)')
ax2.set_xticks(x); ax2.set_xticklabels([f'F{i}' for i in range(1,6)])
ax2.set_title('Per-Fold Ensemble Breakdown', fontweight='bold')
ax2.legend(fontsize=7); ax2.grid(axis='y', alpha=0.2)
# ── Plot 3: Training/Val loss curves ──────────────────────────────
ax3 = fig.add_subplot(gs[1, 0])
for cname, col in PAL.items():
if cname not in all_hists: continue
for fi, h in enumerate(all_hists[cname]):
lb_tr = f'{cname} train' if fi == 0 else None
lb_vl = f'{cname} val' if fi == 0 else None
ax3.plot(h['train'], alpha=0.3, lw=0.8, color=col, label=lb_tr)
ax3.plot(h['val'], alpha=0.7, lw=1.2, color=col, label=lb_vl,
linestyle='--')
ax3.axhline(95.99, color='#4CAF50', ls='-.', lw=1.2, label='V12A (95.99)')
ax3.axvline(200, color='#4CAF50', ls='--', lw=1.2, alpha=0.6, label='SWA start')
ax3.set_xlabel('Epoch'); ax3.set_ylabel('MAE (MPa)')
ax3.set_title('Training Curves (seed 0, all folds)', fontweight='bold')
ax3.legend(fontsize=6, ncol=2); ax3.grid(alpha=0.2)
ax3.set_ylim(0, 300)
# ── Plot 4: Per-seed scatter / Confidence ─────────────────────────
ax4 = fig.add_subplot(gs[1, 1])
if all_conf_weights:
for cn, fw_list in all_conf_weights.items():
all_w = torch.cat(fw_list, dim=0)
avg_w = all_w.mean(dim=0).numpy()
steps = np.arange(1, len(avg_w)+1)
ax4.bar(steps, avg_w, color=PAL.get(cn, '#E65100'), alpha=0.8,
label=f'{cn} avg confidence', edgecolor='white')
std_w = all_w.std(dim=0).numpy()
ax4.errorbar(steps, avg_w, yerr=std_w, fmt='none',
ecolor='#333', capsize=2, alpha=0.5)
ax4.set_xlabel('Recursion Step')
ax4.set_ylabel('Confidence Weight (softmax)')
ax4.set_title('V13B: Where the Model Trusts Its Predictions',
fontweight='bold')
ax4.legend(fontsize=8)
ax4.grid(axis='y', alpha=0.2)
else:
# Show per-seed MAE scatter if no confidence model
for i, (cn, col) in enumerate(zip(names, cols)):
r = all_results[cn]
seed_avgs = [r['per_seed_avgs'][s] for s in SEEDS]
ax4.scatter(SEEDS, seed_avgs, s=80, c=col, alpha=0.8,
label=f'{cn} per-seed', zorder=5,
edgecolors='white', linewidth=1)
ax4.axhline(r['avg'], color=col, ls='--', lw=1.5, alpha=0.6,
label=f'{cn} ensemble={r["avg"]:.2f}')
ax4.axhline(95.99, color='#4CAF50', ls=':', lw=1, alpha=0.5, label='V12A')
ax4.set_xlabel('Random Seed')
ax4.set_ylabel('5-Fold Avg MAE (MPa)')
ax4.set_title('Per-Seed vs Ensemble Performance', fontweight='bold')
ax4.legend(fontsize=7); ax4.grid(alpha=0.2)
fig.suptitle('TRM-MatSci V13 β”‚ 2-Layer SA + Multi-Seed Ensemble β”‚ matbench_steels',
fontsize=14, fontweight='bold', y=1.01)
fig.savefig('trm_results_v13.png', dpi=150, bbox_inches='tight')
plt.close(fig); log.info("βœ“ Saved: trm_results_v13.png")
def save_summary(all_results, all_hists, all_conf_weights, total_s):
# Prepare confidence info
conf_info = {}
for cn, fw_list in all_conf_weights.items():
all_w = torch.cat(fw_list, dim=0)
conf_info[cn] = {
'avg_weights': all_w.mean(dim=0).numpy().round(4).tolist(),
'avg_peak_step': float(all_w.argmax(dim=1).float().mean().item() + 1),
}
s = {
'version': 'V13', 'task': 'matbench_steels',
'strategy': '2-Layer SA + Multi-Seed Ensemble',
'seeds': SEEDS,
'n_seeds': N_SEEDS,
'total_min': round(total_s/60, 1),
'models': {},
'confidence': conf_info,
}
for n, r in all_results.items():
s['models'][n] = {
'ensemble_avg': round(r['avg'], 4),
'ensemble_std': round(r['std'], 4),
'ensemble_folds': [round(x, 4) for x in r['folds']],
'params': r['params'],
'best_single_seed': r['best_single_seed'],
'best_single_mae': round(r['best_single_mae'], 4),
'per_seed_avgs': {str(k): round(v, 4) for k, v in r['per_seed_avgs'].items()},
}
with open('trm_models_v13/summary_v13.json', 'w') as f:
json.dump(s, f, indent=2, default=str)
log.info("βœ“ Saved: summary_v13.json")
if __name__ == '__main__':
results = run_benchmark()
shutil.make_archive("trm_v13_all", "zip", "trm_models_v13")
log.info("βœ“ Created trm_v13_all.zip")