| | """
|
| | ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| | β TRM-MatSci V13 β 2-Layer SA + Multi-Seed Ensemble β
|
| | β Dataset: matbench_steels β 312 samples β 5-Fold Nested CV β
|
| | β β
|
| | β V13A 2-Layer Self-Attention + Standard Deep Supervision β
|
| | β d_attn=64, nhead=4, d_hidden=96, ff_dim=150, 20 steps β
|
| | β Expanded features (Magpie + Mat2Vec + Extra descriptors) β
|
| | β 2nd SA layer for higher-order property interactions β
|
| | β 5-seed ensemble (avg predictions across seeds) β
|
| | β β
|
| | β V13B Same 2-Layer SA + Confidence-Weighted Deep Supervision β
|
| | β 22 steps, confidence_head learns which step to trust β
|
| | β 5-seed ensemble (avg predictions across seeds) β
|
| | β β
|
| | β All models: Deep Supervision + SWA + AdamW + 300 epochs β
|
| | β Baseline: V12A = 95.99 MPa (current best) β
|
| | ββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
|
| | """
|

import os, copy, json, time, logging, warnings, shutil, urllib.request
warnings.filterwarnings('ignore')

import numpy as np
import pandas as pd

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

from tqdm import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torch.optim as optim
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from pymatgen.core import Composition
from matminer.featurizers.composition import ElementProperty
from gensim.models import Word2Vec

logging.basicConfig(level=logging.INFO, format='%(name)s — %(message)s')
log = logging.getLogger("TRM13")


SEEDS = [42, 123, 7, 0, 99]
N_SEEDS = len(SEEDS)

BASELINES = {
    'TPOT-Mat': 79.9468,
    'AutoML-Mat': 82.3043,
    'MODNet': 87.7627,
    'RF-SCM/Magpie': 103.5125,
    'V12A (best)': 95.9900,
    'V11B': 102.3003,
    'V10A': 103.2867,
    'CrabNet': 107.3160,
    'Darwin': 123.2932,
}


class ExpandedFeaturizer:
    """Magpie (22 props × 6 stats) + Extra matminer descriptors + Mat2Vec (200d).

    Extra descriptors: ElementFraction, Stoichiometry, ValenceOrbital,
    IonProperty, BandCenter — all concatenated as a flat vector between
    the Magpie block and Mat2Vec.
    """
    GCS = "https://storage.googleapis.com/mat2vec/"
    FILES = ["pretrained_embeddings",
             "pretrained_embeddings.wv.vectors.npy",
             "pretrained_embeddings.trainables.syn1neg.npy"]

    def __init__(self, cache="mat2vec_cache"):
        from matminer.featurizers.composition import (
            ElementFraction, Stoichiometry, ValenceOrbital,
            IonProperty, BandCenter
        )
        from matminer.featurizers.base import MultipleFeaturizer

        self.ep_magpie = ElementProperty.from_preset("magpie")
        self.n_mg = len(self.ep_magpie.feature_labels())

        self.extra_feats = MultipleFeaturizer([
            ElementFraction(),
            Stoichiometry(),
            ValenceOrbital(),
            IonProperty(),
            BandCenter(),
        ])
        self.n_extra = None

        self.scaler = None
        os.makedirs(cache, exist_ok=True)
        for f in self.FILES:
            p = os.path.join(cache, f)
            if not os.path.exists(p):
                log.info(f" Downloading {f}...")
                urllib.request.urlretrieve(self.GCS + f, p)
        self.m2v = Word2Vec.load(os.path.join(cache, "pretrained_embeddings"))
        self.emb = {w: self.m2v.wv[w] for w in self.m2v.wv.index_to_key}

    def _pool(self, c):
        v, t = np.zeros(200, np.float32), 0.0
        for s, f in c.get_el_amt_dict().items():
            if s in self.emb: v += f * self.emb[s]; t += f
        return v / max(t, 1e-8)

    def featurize_all(self, comps):
        out = []
        for c in tqdm(comps, desc=" Featurizing (expanded)", leave=False):
            try: mg = np.array(self.ep_magpie.featurize(c), np.float32)
            except Exception: mg = np.zeros(self.n_mg, np.float32)

            try:
                ex = np.array(self.extra_feats.featurize(c), np.float32)
            except Exception:
                ex = np.zeros(self.n_extra or 200, np.float32)

            if len(ex) > 200:
                ex = ex[:200]
            elif len(ex) < 200:
                ex = np.pad(ex, (0, 200 - len(ex)))

            if self.n_extra is None:
                self.n_extra = 200
                log.info(f"Expanded features: {self.n_mg} Magpie + "
                         f"200 Extra (forced) + 200 Mat2Vec = "
                         f"{self.n_mg + 200 + 200}d")

            out.append(np.concatenate([
                np.nan_to_num(mg, nan=0.0),
                np.nan_to_num(ex, nan=0.0),
                self._pool(c)
            ]))
        return np.array(out)

    def fit_scaler(self, X): self.scaler = StandardScaler().fit(X)
    def transform(self, X):
        if self.scaler is None: return X
        return np.nan_to_num(self.scaler.transform(X), nan=0.0).astype(np.float32)
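

# A minimal layout sketch (hypothetical helper, not called by the benchmark):
# shows how the flat vector from ExpandedFeaturizer.featurize_all decomposes.
# The 132-dim Magpie block assumes the default preset (22 properties × 6 stats).
def _demo_feature_layout(n_mg=132, n_extra=200, n_m2v=200):
    x = np.zeros(n_mg + n_extra + n_m2v, np.float32)
    mg = x[:n_mg]                       # Magpie element-property stats
    extra = x[n_mg:n_mg + n_extra]      # extra descriptors, forced to 200d
    m2v = x[n_mg + n_extra:]            # fraction-weighted Mat2Vec pooling
    assert (len(mg), len(extra), len(m2v)) == (n_mg, n_extra, n_m2v)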


class DSData(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.y = torch.tensor(np.array(y, np.float32))
    def __len__(self): return len(self.y)
    def __getitem__(self, i): return self.X[i], self.y[i]


class DeepHybridTRM(nn.Module):
    """V13A: 2-Layer SA Hybrid-TRM with Standard Deep Supervision.

    Key difference from V12A's HybridTRM:
      - TWO self-attention layers (SA1 → FF1 → SA2 → FF2 → CA)
      - Each SA layer has its own residual + LayerNorm + FF block
      - This enables higher-order property interaction modeling
        (e.g., "correlation between electronegativity-range AND
        atomic-radius-mean" requires composing two rounds of attention)
      - Cross-attention (CA) to Mat2Vec context remains after SA stack

    Everything else (MLP reasoning loop, deep supervision, SWA)
    is identical to V12A.
    """
    def __init__(self, n_props=22, stat_dim=6, n_extra=200, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=20, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim = n_props, stat_dim
        self.n_extra = n_extra

        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        self.sa1 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        self.sa2 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        self.ca = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)
        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        mg = x[:, :mg_dim]

        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra = None
            m2v = x[:, mg_dim:]

        tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)
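        # Added note: tok is [B, n_props, d_attn] property tokens; ctx is
        # [B, 1, d_attn] Mat2Vec context used as key/value in the CA below.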

        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))

        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))

        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)

        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)

        return self.pool(pooled)

    def forward(self, x, deep_supervision=False, return_trajectory=False):
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)
        step_preds = []
        for _ in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
        if deep_supervision:
            return step_preds
        elif return_trajectory:
            return step_preds[-1], step_preds
        else:
            return step_preds[-1]

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
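

# A shape-check sketch (hypothetical helper, not called by the benchmark),
# assuming the V13A layout of 22×6 Magpie + 200 extra + 200 Mat2Vec dims.
def _demo_deep_hybrid_shapes():
    m = DeepHybridTRM(n_props=22, stat_dim=6, n_extra=200, mat2vec_dim=200)
    x = torch.randn(4, 22 * 6 + 200 + 200)
    steps = m(x, deep_supervision=True)   # one prediction per recursion step
    assert len(steps) == m.max_steps and steps[-1].shape == (4,)
    assert m(x).shape == (4,)             # default path: last-step prediction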


class DeepConfidenceHybridTRM(nn.Module):
    """V13B: 2-Layer SA Hybrid-TRM with Confidence-Weighted Deep Supervision.

    Same 2-layer SA feature extractor as DeepHybridTRM, but with:
      - confidence_head that learns which recursion step to trust
      - Final prediction = softmax(confidence) · step_preds
      - No ponder cost (avoids V11C's failure)
      - 22 recursion steps (vs 20 for V13A)
    """
    def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200,
                 d_attn=64, nhead=4, d_hidden=96, ff_dim=150,
                 dropout=0.2, max_steps=22, **kw):
        super().__init__()
        self.max_steps, self.D = max_steps, d_hidden
        self.n_props, self.stat_dim = n_props, stat_dim
        self.n_extra = n_extra

        self.tok_proj = nn.Sequential(
            nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())
        self.m2v_proj = nn.Sequential(
            nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU())

        self.sa1 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa1_n = nn.LayerNorm(d_attn)
        self.sa1_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa1_fn = nn.LayerNorm(d_attn)

        self.sa2 = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.sa2_n = nn.LayerNorm(d_attn)
        self.sa2_ff = nn.Sequential(
            nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(d_attn*2, d_attn))
        self.sa2_fn = nn.LayerNorm(d_attn)

        self.ca = nn.MultiheadAttention(
            d_attn, nhead, dropout=dropout, batch_first=True)
        self.ca_n = nn.LayerNorm(d_attn)

        pool_in = d_attn + (n_extra if n_extra > 0 else 0)
        self.pool = nn.Sequential(
            nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU())

        self.z_up = nn.Sequential(
            nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.y_up = nn.Sequential(
            nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout),
            nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden))
        self.head = nn.Linear(d_hidden, 1)

        self.confidence_head = nn.Sequential(
            nn.Linear(d_hidden, d_hidden // 2), nn.GELU(),
            nn.Linear(d_hidden // 2, 1))

        self._init()

    def _init(self):
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None: nn.init.zeros_(m.bias)
        with torch.no_grad():
            nn.init.zeros_(self.confidence_head[-1].bias)

    def _attention(self, x):
        B = x.size(0)
        mg_dim = self.n_props * self.stat_dim
        mg = x[:, :mg_dim]

        if self.n_extra > 0:
            extra = x[:, mg_dim:mg_dim + self.n_extra]
            m2v = x[:, mg_dim + self.n_extra:]
        else:
            extra = None
            m2v = x[:, mg_dim:]

        tok = self.tok_proj(mg.view(B, self.n_props, self.stat_dim))
        ctx = self.m2v_proj(m2v).unsqueeze(1)

        tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0])
        tok = self.sa1_fn(tok + self.sa1_ff(tok))

        tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0])
        tok = self.sa2_fn(tok + self.sa2_ff(tok))

        tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0])

        pooled = tok.mean(dim=1)

        if extra is not None:
            pooled = torch.cat([pooled, extra], dim=-1)

        return self.pool(pooled)

    def forward(self, x, deep_supervision=False, return_confidence=False):
        """Forward pass.

        Returns:
            deep_supervision=True: (step_preds, confidence_logits)
            deep_supervision=False, return_confidence=False:
                weighted_pred: [B] confidence-weighted prediction
            deep_supervision=False, return_confidence=True:
                (weighted_pred, confidence_weights): [B], [B, max_steps]
        """
        B = x.size(0)
        xp = self._attention(x)
        z = torch.zeros(B, self.D, device=x.device)
        y = torch.zeros(B, self.D, device=x.device)

        step_preds = []
        conf_logits = []

        for _ in range(self.max_steps):
            z = z + self.z_up(torch.cat([xp, y, z], -1))
            y = y + self.y_up(torch.cat([y, z], -1))
            step_preds.append(self.head(y).squeeze(1))
            conf_logits.append(self.confidence_head(y).squeeze(1))

        conf_logits = torch.stack(conf_logits, dim=1)

        if deep_supervision:
            return step_preds, conf_logits

        conf_weights = F.softmax(conf_logits, dim=1)
        preds_stack = torch.stack(step_preds, dim=1)
        weighted_pred = (preds_stack * conf_weights).sum(dim=1)

        if return_confidence:
            return weighted_pred, conf_weights
        return weighted_pred

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
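

# A consistency sketch (hypothetical helper, not called by the benchmark):
# checks that the confidence-weighted forward output equals the manual
# softmax(conf_logits) · step_preds combination described in the docstring.
def _demo_confidence_weighting():
    m = DeepConfidenceHybridTRM(n_props=22, stat_dim=6, n_extra=200).eval()
    x = torch.randn(3, 22 * 6 + 200 + 200)
    with torch.no_grad():
        steps, logits = m(x, deep_supervision=True)
        manual = (torch.stack(steps, 1) * F.softmax(logits, 1)).sum(1)
        auto, w = m(x, return_confidence=True)
    assert torch.allclose(manual, auto) and w.shape == (3, m.max_steps)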


def deep_supervision_loss(step_preds, targets):
    """Linear-weighted L1 loss across all recursion steps."""
    n = len(step_preds)
    weights = [(i + 1) for i in range(n)]
    total_w = sum(weights)
    loss = 0.0
    for pred, w in zip(step_preds, weights):
        loss += (w / total_w) * F.l1_loss(pred, targets)
    return loss
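

# A worked example (hypothetical helper, not called by the benchmark): with
# n = 4 steps the linear weights are [1, 2, 3, 4] / 10, so the final step
# carries 4× the weight of the first.
def _demo_ds_weights(n=4):
    w = [(i + 1) for i in range(n)]
    assert [x / sum(w) for x in w] == [0.1, 0.2, 0.3, 0.4]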


def confidence_ds_loss(step_preds, targets, conf_logits):
    """Advanced Deep Supervision: standard DS + confidence-weighted L1.

    Components:
      1. Standard linear-weighted deep supervision on all steps
      2. L1 loss on the confidence-weighted final prediction
    """
    ds = deep_supervision_loss(step_preds, targets)

    conf_weights = F.softmax(conf_logits, dim=1)
    preds_stack = torch.stack(step_preds, dim=1)
    weighted_pred = (preds_stack * conf_weights).sum(dim=1)
    conf_loss = F.l1_loss(weighted_pred, targets)

    return ds + conf_loss
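

# A numeric sketch (hypothetical helper, not called by the benchmark): with
# uniform confidence logits the weighted prediction is the plain step mean,
# so for steps [1, 3] and target 2 the confidence term vanishes.
def _demo_conf_loss_terms():
    steps = [torch.tensor([1.0]), torch.tensor([3.0])]
    logits = torch.zeros(1, 2)                      # uniform -> weights [.5, .5]
    tgt = torch.tensor([2.0])
    ds = deep_supervision_loss(steps, tgt)          # (1/3)|1-2| + (2/3)|3-2| = 1
    total = confidence_ds_loss(steps, tgt, logits)  # ds + |mean(1,3) - 2| = ds
    assert torch.isclose(total, ds)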


def strat_split(targets, val_size=0.15, seed=42):
    bins = np.percentile(targets, [25, 50, 75])
    lbl = np.digitize(targets, bins)
    tr, vl = [], []
    rng = np.random.RandomState(seed)
    for b in range(4):
        m = np.where(lbl == b)[0]
        if len(m) == 0: continue
        n = max(1, int(len(m) * val_size))
        c = rng.choice(m, n, replace=False)
        vl.extend(c.tolist()); tr.extend(np.setdiff1d(m, c).tolist())
    return np.array(tr), np.array(vl)
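

# A behavior sketch (hypothetical helper, not called by the benchmark):
# strat_split holds out ~val_size of each quartile bin, preserving the
# target distribution even on a dataset as small as 312 samples.
def _demo_strat_split():
    y = np.arange(100, dtype=np.float32)
    tr, vl = strat_split(y, val_size=0.2, seed=0)
    assert len(vl) == 20 and len(tr) == 80
    assert len(np.intersect1d(tr, vl)) == 0   # disjoint train/val indices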


def train_fold_standard(model, tr_dl, vl_dl, device,
                        epochs=300, swa_start=200, fold=1, name=""):
    """Training with standard deep supervision."""
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
    swa_m = AveragedModel(model)
    swa_s = SWALR(opt, swa_lr=5e-4)
    swa_on = False
    best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
    hist = {'train': [], 'val': []}

    pbar = tqdm(range(epochs), desc=f" [{name}] F{fold}/5",
                leave=False, ncols=120)
    for ep in pbar:
        model.train(); tl = 0.0
        for bx, by in tr_dl:
            bx, by = bx.to(device), by.to(device)
            step_preds = model(bx, deep_supervision=True)
            loss = deep_supervision_loss(step_preds, by)
            opt.zero_grad(set_to_none=True); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
            tl += F.l1_loss(step_preds[-1], by).item() * len(by)
        tl /= len(tr_dl.dataset)

        model.eval(); vl = 0.0
        with torch.no_grad():
            for bx, by in vl_dl:
                bx, by = bx.to(device), by.to(device)
                pred = model(bx)
                vl += F.l1_loss(pred, by).item() * len(by)
        vl /= len(vl_dl.dataset)
        hist['train'].append(tl); hist['val'].append(vl)
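
        # Phase 1 (ep < swa_start): cosine-annealed LR, checkpoint best-val weights.
        # Phase 2 (ep >= swa_start): accumulate the SWA running average at swa_lr.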
        if ep < swa_start:
            sch.step()
            if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
        else:
            if not swa_on: swa_on = True
            swa_m.update_parameters(model); swa_s.step()

        pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
                         Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')

    if swa_on:
        update_bn(tr_dl, swa_m, device=device)
        model.load_state_dict(swa_m.module.state_dict())
    else:
        model.load_state_dict(best_w)
    return best_v, model, hist


def train_fold_confidence(model, tr_dl, vl_dl, device,
                          epochs=300, swa_start=200, fold=1, name=""):
    """Training with confidence-weighted deep supervision."""
    opt = optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-4)
    sch = optim.lr_scheduler.CosineAnnealingLR(opt, T_max=swa_start, eta_min=1e-4)
    swa_m = AveragedModel(model)
    swa_s = SWALR(opt, swa_lr=5e-4)
    swa_on = False
    best_v, best_w = float('inf'), copy.deepcopy(model.state_dict())
    hist = {'train': [], 'val': []}

    pbar = tqdm(range(epochs), desc=f" [{name}] F{fold}/5",
                leave=False, ncols=120)
    for ep in pbar:
        model.train(); tl = 0.0
        for bx, by in tr_dl:
            bx, by = bx.to(device), by.to(device)
            step_preds, conf_logits = model(bx, deep_supervision=True)
            loss = confidence_ds_loss(step_preds, by, conf_logits)
            opt.zero_grad(set_to_none=True); loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()

            with torch.no_grad():
                cw = F.softmax(conf_logits, dim=1)
                ps = torch.stack(step_preds, dim=1)
                wp = (ps * cw).sum(dim=1)
                tl += F.l1_loss(wp, by).item() * len(by)
        tl /= len(tr_dl.dataset)

        model.eval(); vl = 0.0
        with torch.no_grad():
            for bx, by in vl_dl:
                bx, by = bx.to(device), by.to(device)
                pred = model(bx)
                vl += F.l1_loss(pred, by).item() * len(by)
        vl /= len(vl_dl.dataset)
        hist['train'].append(tl); hist['val'].append(vl)

        if ep < swa_start:
            sch.step()
            if vl < best_v: best_v, best_w = vl, copy.deepcopy(model.state_dict())
        else:
            if not swa_on: swa_on = True
            swa_m.update_parameters(model); swa_s.step()

        pbar.set_postfix(Tr=f'{tl:.1f}', Val=f'{vl:.1f}',
                         Best=f'{best_v:.1f}', Ph='SWA' if swa_on else 'COS')

    if swa_on:
        update_bn(tr_dl, swa_m, device=device)
        model.load_state_dict(swa_m.module.state_dict())
    else:
        model.load_state_dict(best_w)
    return best_v, model, hist


def predict(model, dl, device):
    model.eval(); preds = []
    with torch.no_grad():
        for bx, _ in dl:
            preds.append(model(bx.to(device)).cpu())
    return torch.cat(preds)


def predict_confidence(model, dl, device):
    """Predict using confidence model, also return per-step weights."""
    model.eval()
    all_preds, all_weights = [], []
    with torch.no_grad():
        for bx, _ in dl:
            pred, weights = model(bx.to(device), return_confidence=True)
            all_preds.append(pred.cpu())
            all_weights.append(weights.cpu())
    return torch.cat(all_preds), torch.cat(all_weights)


def get_targets(dl):
    tgts = []
    for _, by in dl: tgts.append(by)
    return torch.cat(tgts)
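

# An averaging sketch (hypothetical helper, not called by the benchmark):
# the multi-seed ensemble in run_benchmark below is a plain mean over the
# per-seed prediction tensors, one row per seed.
def _demo_seed_ensemble():
    preds = torch.tensor([[100., 110.],    # seed A predictions
                          [120.,  90.]])   # seed B predictions
    assert torch.equal(preds.mean(dim=0), torch.tensor([110., 100.]))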


def run_benchmark():
    t0 = time.time()
    print("\n" + "═"*72)
    print(" TRM-MatSci V13 — 2-Layer SA + Multi-Seed Ensemble — matbench_steels")
    print(" V13A: 2-Layer SA + expanded features + standard DS (5-seed ensemble)")
    print(" V13B: 2-Layer SA + expanded features + confidence DS (5-seed ensemble)")
    print(f" Seeds: {SEEDS}")
    print("═"*72 + "\n")

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    if device.type == 'cuda':
        log.info(f"GPU: {torch.cuda.get_device_name(0)} "
                 f"({torch.cuda.get_device_properties(0).total_memory/1e9:.1f} GB)")
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.benchmark = True

    log.info("Loading matbench_steels...")
    from matminer.datasets import load_dataset
    df = load_dataset("matbench_steels")
    comps_raw = df['composition'].tolist()
    targets_all = np.array(df['yield strength'].tolist(), np.float32)
    comps_all = [Composition(c) for c in comps_raw]

    log.info("Computing EXPANDED features...")
    feat = ExpandedFeaturizer()
    X_all = feat.featurize_all(comps_all)
    n_extra = feat.n_extra
    log.info(f"Features: {X_all.shape} (n_extra={n_extra})")

    kfold = KFold(n_splits=5, shuffle=True, random_state=18012019)
    folds = list(kfold.split(comps_all))
    os.makedirs('trm_models_v13', exist_ok=True)
    dl_kw = dict(batch_size=32, num_workers=0)

    shared_kw = dict(n_props=22, stat_dim=6, n_extra=n_extra,
                     mat2vec_dim=200, d_attn=64, nhead=4,
                     d_hidden=96, ff_dim=150, dropout=0.2)

    configs = {
        'V13A-2xSA-StdDS': {
            'model_cls': DeepHybridTRM,
            'model_kw': {**shared_kw, 'max_steps': 20},
            'train_fn': train_fold_standard,
            'predict_fn': predict,
            'is_confidence': False,
        },
        'V13B-2xSA-ConfDS': {
            'model_cls': DeepConfidenceHybridTRM,
            'model_kw': {**shared_kw, 'max_steps': 22},
            'train_fn': train_fold_confidence,
            'predict_fn': None,
            'is_confidence': True,
        },
    }

    print(f"\n {'Config':<24} {'Params':>10} {'Steps':>8} {'Seeds':>6}")
    print(f" {'─'*54}")
    for cname, cfg in configs.items():
        _m = cfg['model_cls'](**cfg['model_kw'])
        np_ = _m.count_parameters(); del _m
        cfg['n_params'] = np_
        steps = cfg['model_kw']['max_steps']
        print(f" {cname:<24} {np_:>10,} {steps:>8} {N_SEEDS:>6}")
    print()

    all_results = {}
    all_hists = {}
    all_conf_weights = {}

    for cname, cfg in configs.items():
        print(f"\n{'═'*72}")
        print(f" {cname} — {N_SEEDS}-Seed Ensemble")
        print(f"{'═'*72}")

        seed_fold_preds = {s: {} for s in SEEDS}
        seed_fold_maes = {s: [] for s in SEEDS}
        fold_hists = []
        fold_conf_w = []

        for si, seed in enumerate(SEEDS):
            print(f"\n ──── Seed {seed} ({si+1}/{N_SEEDS}) ────")

            for fi, (tv_i, te_i) in enumerate(folds):
                print(f"\n ── [{cname} seed={seed}] Fold {fi+1}/5 {'─'*30}")

                tri, vli = strat_split(targets_all[tv_i], 0.15, seed+fi)
                feat.fit_scaler(X_all[tv_i][tri])
                tr_s = feat.transform(X_all[tv_i][tri])
                vl_s = feat.transform(X_all[tv_i][vli])
                te_s = feat.transform(X_all[te_i])

                pin = device.type == 'cuda'
                tr_dl = DataLoader(DSData(tr_s, targets_all[tv_i][tri]),
                                   shuffle=True, pin_memory=pin, **dl_kw)
                vl_dl = DataLoader(DSData(vl_s, targets_all[tv_i][vli]),
                                   shuffle=False, pin_memory=pin, **dl_kw)
                te_dl = DataLoader(DSData(te_s, targets_all[te_i]),
                                   shuffle=False, pin_memory=pin, **dl_kw)
                te_tgt = get_targets(te_dl)

                torch.manual_seed(seed + fi); np.random.seed(seed + fi)
                if device.type == 'cuda': torch.cuda.manual_seed(seed + fi)

                model = cfg['model_cls'](**cfg['model_kw']).to(device)
                bv, model, hist = cfg['train_fn'](model, tr_dl, vl_dl, device,
                                                  fold=fi+1,
                                                  name=f"{cname}[s{seed}]")

                if si == 0:
                    fold_hists.append(hist)

                if cfg['is_confidence']:
                    pred, conf_w = predict_confidence(model, te_dl, device)
                    if si == 0:
                        fold_conf_w.append(conf_w)
                    avg_peak = conf_w.argmax(dim=1).float().mean().item() + 1
                    mae = F.l1_loss(pred, te_tgt).item()
                    log.info(f" [s{seed}] F{fi+1}: MAE={mae:.2f} "
                             f"(val {bv:.2f}, avg peak step={avg_peak:.1f})")
                else:
                    pred = cfg['predict_fn'](model, te_dl, device)
                    mae = F.l1_loss(pred, te_tgt).item()
                    log.info(f" [s{seed}] F{fi+1}: MAE={mae:.2f} (val {bv:.2f})")

                seed_fold_preds[seed][fi] = pred
                seed_fold_maes[seed].append(mae)

                torch.save({'model_state': model.state_dict(), 'test_mae': mae,
                            'config': cname, 'seed': seed},
                           f'trm_models_v13/{cname}_seed{seed}_fold{fi+1}.pt')

                del model
                if device.type == 'cuda':
                    torch.cuda.empty_cache()

            seed_avg = float(np.mean(seed_fold_maes[seed]))
            print(f" ──── Seed {seed} avg: {seed_avg:.2f} MPa ────")

        ensemble_fold_maes = []
        for fi, (tv_i, te_i) in enumerate(folds):
            te_tgt_np = targets_all[te_i]
            te_tgt_t = torch.tensor(te_tgt_np, dtype=torch.float32)

            all_seed_preds = torch.stack([seed_fold_preds[s][fi] for s in SEEDS])
            ensemble_pred = all_seed_preds.mean(dim=0)

            ens_mae = F.l1_loss(ensemble_pred, te_tgt_t).item()
            ensemble_fold_maes.append(ens_mae)

        ens_avg = float(np.mean(ensemble_fold_maes))
        ens_std = float(np.std(ensemble_fold_maes))

        per_seed_avgs = {s: float(np.mean(seed_fold_maes[s])) for s in SEEDS}
        best_single_seed = min(per_seed_avgs.items(), key=lambda x: x[1])

        all_results[cname] = {
            'avg': ens_avg, 'std': ens_std, 'folds': ensemble_fold_maes,
            'params': cfg['n_params'],
            'per_seed_avgs': per_seed_avgs,
            'per_seed_folds': {str(s): seed_fold_maes[s] for s in SEEDS},
            'best_single_seed': best_single_seed[0],
            'best_single_mae': best_single_seed[1],
        }
        all_hists[cname] = fold_hists
        if fold_conf_w:
            all_conf_weights[cname] = fold_conf_w

        print(f"\n ─── {cname} ───")
        print(f" Ensemble ({N_SEEDS}-seed avg): {ens_avg:.4f} ±{ens_std:.4f} MPa")
        print(f" Best single seed ({best_single_seed[0]}): "
              f"{best_single_seed[1]:.4f} MPa")
        for s in SEEDS:
            print(f"   Seed {s:>3}: {per_seed_avgs[s]:.2f} MPa "
                  f"folds={[f'{m:.1f}' for m in seed_fold_maes[s]]}")

    tt = time.time() - t0
    print(f"\n{'═'*72}")
    print(f" FINAL LEADERBOARD — matbench_steels V13 (5-Fold Avg MAE)")
    print(f"{'═'*72}")
    print(f" {'Model':<26} {'Params':>10} {'MAE(MPa)':>10} {'±Std':>8} Notes")
    print(f" {'─'*72}")
    for n, r in sorted(all_results.items(), key=lambda x: x[1]['avg']):
        tag = (" ← BEATS MODNet 🎉" if r['avg'] < 87.76 else
               " ← BEATS V12A ✓" if r['avg'] < 95.99 else
               " ← BEATS RF-SCM ✓" if r['avg'] < 103.51 else
               " ← BEATS DARWIN ✓" if r['avg'] < 123.29 else "")
        print(f" {n+' (ens)':<26} {r['params']:>10,} "
              f"{r['avg']:>10.4f} {r['std']:>8.4f}{tag}")
        print(f" {n+' (best 1)':<26} {'':>10} "
              f"{r['best_single_mae']:>10.4f} {'':>8} seed={r['best_single_seed']}")
    print(f" {'─'*72}")
    for bn, bv in sorted(BASELINES.items(), key=lambda x: x[1]):
        print(f" {bn:<26} {'baseline':>10} {bv:>10.4f}")
    print(f"\n Total time: {tt/60:.1f} minutes ({N_SEEDS} seeds × 2 configs × 5 folds)")

    print(f"\n{'═'*72}")
    print(f" PER-FOLD ENSEMBLE BREAKDOWN")
    print(f"{'═'*72}")
    cnames = list(all_results.keys())
    header = f" {'Fold':<6}"
    for cn in cnames:
        header += f" {cn:>20}"
    print(header)
    print(f" {'─'*52}")
    for fi in range(5):
        row = f" {fi+1:<6}"
        for cn in cnames:
            row += f" {all_results[cn]['folds'][fi]:>20.2f}"
        print(row)

    print(f"\n{'═'*72}")
    print(f" PER-SEED BREAKDOWN")
    print(f"{'═'*72}")
    for cn in cnames:
        r = all_results[cn]
        print(f"\n {cn}:")
        header = f" {'Seed':<6}"
        for fi in range(5):
            header += f" {'F'+str(fi+1):>8}"
        header += f" {'Avg':>8}"
        print(header)
        print(f" {'─'*52}")
        for s in SEEDS:
            row = f" {s:<6}"
            for mae in r['per_seed_folds'][str(s)]:
                row += f" {mae:>8.2f}"
            row += f" {r['per_seed_avgs'][s]:>8.2f}"
            print(row)
        print(f" {'─'*52}")
        row = f" {'ENS':<6}"
        for mae in r['folds']:
            row += f" {mae:>8.2f}"
        row += f" {r['avg']:>8.2f}"
        print(row)

    if all_conf_weights:
        print(f"\n Confidence Step Selection Summary:")
        for cn, fw_list in all_conf_weights.items():
            all_w = torch.cat(fw_list, dim=0)
            avg_w = all_w.mean(dim=0)
            peak_step = avg_w.argmax().item() + 1
            avg_peak = all_w.argmax(dim=1).float().mean().item() + 1
            print(f"   {cn}: avg peak step={avg_peak:.1f}, "
                  f"population peak=step {peak_step}")
        print()

    generate_plots(all_results, all_hists, all_conf_weights)
    save_summary(all_results, all_hists, all_conf_weights, tt)
    return all_results


PAL = {'V13A-2xSA-StdDS': '#1565C0', 'V13B-2xSA-ConfDS': '#E65100'}


def generate_plots(all_results, all_hists, all_conf_weights):
    names = list(all_results.keys())
    avgs = [all_results[n]['avg'] for n in names]
    stds = [all_results[n]['std'] for n in names]
    cols = [PAL.get(n, '#888') for n in names]

    fig = plt.figure(figsize=(22, 18))
    gs = gridspec.GridSpec(2, 2, figure=fig, hspace=0.35, wspace=0.30)

    ax1 = fig.add_subplot(gs[0, 0])

    x_pos = np.arange(len(names))
    w = 0.35
    ens_bars = ax1.bar(x_pos - w/2, avgs, w, yerr=stds, capsize=6,
                       color=cols, alpha=0.88, edgecolor='white',
                       linewidth=1.5, label='Ensemble')
    best_singles = [all_results[n]['best_single_mae'] for n in names]
    single_bars = ax1.bar(x_pos + w/2, best_singles, w,
                          color=cols, alpha=0.45, edgecolor='white',
                          linewidth=1.5, label='Best Single Seed',
                          hatch='//')

    for bv, c, ls, lb in [
        (87.76, '#F57F17', '--', 'MODNet (87.76)'),
        (95.99, '#4CAF50', '-.', 'V12A (95.99)'),
        (102.30, '#9E9E9E', '-.', 'V11B (102.30)'),
        (103.51, '#B0BEC5', ':', 'RF-SCM (103.51)'),
        (107.32, '#FF9800', ':', 'CrabNet (107.32)'),
    ]:
        ax1.axhline(bv, color=c, linestyle=ls, linewidth=1.8, label=lb, alpha=0.85)
    for bar, m, s in zip(ens_bars, avgs, stds):
        ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+s+1,
                 f'{m:.1f}', ha='center', fontsize=11, fontweight='bold')
    for bar, m in zip(single_bars, best_singles):
        ax1.text(bar.get_x()+bar.get_width()/2, bar.get_height()+1,
                 f'{m:.1f}', ha='center', fontsize=9, fontstyle='italic',
                 alpha=0.7)

    ax1.set_xticks(x_pos)
    ax1.set_xticklabels(names, fontsize=8)
    ax1.legend(fontsize=6, loc='upper right')
    ax1.set_ylabel('MAE (MPa)'); ax1.set_ylim(0, max(avgs)*1.6)
    ax1.set_title('V13 Results vs Baselines (Ensemble + Best Single)',
                  fontsize=11, fontweight='bold')
    ax1.grid(axis='y', alpha=0.3)

    ax2 = fig.add_subplot(gs[0, 1])
    x = np.arange(1, 6)
    w = 0.35
    for i, (n, col) in enumerate(zip(names, cols)):
        fold_vals = all_results[n]['folds']
        ax2.bar(x + (i - 0.5) * w, fold_vals, w, color=col, alpha=0.8,
                label=n + ' (ens)', edgecolor='white')
    ax2.axhline(95.99, color='#4CAF50', ls='-.', lw=1.5, label='V12A (95.99)')
    ax2.axhline(87.76, color='#F57F17', ls='--', lw=1.5, label='MODNet (87.76)')
    ax2.set_xlabel('Fold'); ax2.set_ylabel('MAE (MPa)')
    ax2.set_xticks(x); ax2.set_xticklabels([f'F{i}' for i in range(1, 6)])
    ax2.set_title('Per-Fold Ensemble Breakdown', fontweight='bold')
    ax2.legend(fontsize=7); ax2.grid(axis='y', alpha=0.2)

    ax3 = fig.add_subplot(gs[1, 0])
    for cname, col in PAL.items():
        if cname not in all_hists: continue
        for fi, h in enumerate(all_hists[cname]):
            lb_tr = f'{cname} train' if fi == 0 else None
            lb_vl = f'{cname} val' if fi == 0 else None
            ax3.plot(h['train'], alpha=0.3, lw=0.8, color=col, label=lb_tr)
            ax3.plot(h['val'], alpha=0.7, lw=1.2, color=col, label=lb_vl,
                     linestyle='--')
    ax3.axhline(95.99, color='#4CAF50', ls='-.', lw=1.2, label='V12A (95.99)')
    ax3.axvline(200, color='#4CAF50', ls='--', lw=1.2, alpha=0.6, label='SWA start')
    ax3.set_xlabel('Epoch'); ax3.set_ylabel('MAE (MPa)')
    ax3.set_title('Training Curves (seed 42, all folds)', fontweight='bold')
    ax3.legend(fontsize=6, ncol=2); ax3.grid(alpha=0.2)
    ax3.set_ylim(0, 300)

    ax4 = fig.add_subplot(gs[1, 1])
    if all_conf_weights:
        for cn, fw_list in all_conf_weights.items():
            all_w = torch.cat(fw_list, dim=0)
            avg_w = all_w.mean(dim=0).numpy()
            steps = np.arange(1, len(avg_w)+1)
            ax4.bar(steps, avg_w, color=PAL.get(cn, '#E65100'), alpha=0.8,
                    label=f'{cn} avg confidence', edgecolor='white')
            std_w = all_w.std(dim=0).numpy()
            ax4.errorbar(steps, avg_w, yerr=std_w, fmt='none',
                         ecolor='#333', capsize=2, alpha=0.5)
        ax4.set_xlabel('Recursion Step')
        ax4.set_ylabel('Confidence Weight (softmax)')
        ax4.set_title('V13B: Where the Model Trusts Its Predictions',
                      fontweight='bold')
        ax4.legend(fontsize=8)
        ax4.grid(axis='y', alpha=0.2)
    else:
        for i, (cn, col) in enumerate(zip(names, cols)):
            r = all_results[cn]
            seed_avgs = [r['per_seed_avgs'][s] for s in SEEDS]
            ax4.scatter(SEEDS, seed_avgs, s=80, c=col, alpha=0.8,
                        label=f'{cn} per-seed', zorder=5,
                        edgecolors='white', linewidth=1)
            ax4.axhline(r['avg'], color=col, ls='--', lw=1.5, alpha=0.6,
                        label=f'{cn} ensemble={r["avg"]:.2f}')
        ax4.axhline(95.99, color='#4CAF50', ls=':', lw=1, alpha=0.5, label='V12A')
        ax4.set_xlabel('Random Seed')
        ax4.set_ylabel('5-Fold Avg MAE (MPa)')
        ax4.set_title('Per-Seed vs Ensemble Performance', fontweight='bold')
        ax4.legend(fontsize=7); ax4.grid(alpha=0.2)

    fig.suptitle('TRM-MatSci V13 — 2-Layer SA + Multi-Seed Ensemble — matbench_steels',
                 fontsize=14, fontweight='bold', y=1.01)
    fig.savefig('trm_results_v13.png', dpi=150, bbox_inches='tight')
    plt.close(fig); log.info("✓ Saved: trm_results_v13.png")


def save_summary(all_results, all_hists, all_conf_weights, total_s):
    conf_info = {}
    for cn, fw_list in all_conf_weights.items():
        all_w = torch.cat(fw_list, dim=0)
        conf_info[cn] = {
            'avg_weights': all_w.mean(dim=0).numpy().round(4).tolist(),
            'avg_peak_step': float(all_w.argmax(dim=1).float().mean().item() + 1),
        }

    s = {
        'version': 'V13', 'task': 'matbench_steels',
        'strategy': '2-Layer SA + Multi-Seed Ensemble',
        'seeds': SEEDS,
        'n_seeds': N_SEEDS,
        'total_min': round(total_s/60, 1),
        'models': {},
        'confidence': conf_info,
    }
    for n, r in all_results.items():
        s['models'][n] = {
            'ensemble_avg': round(r['avg'], 4),
            'ensemble_std': round(r['std'], 4),
            'ensemble_folds': [round(x, 4) for x in r['folds']],
            'params': r['params'],
            'best_single_seed': r['best_single_seed'],
            'best_single_mae': round(r['best_single_mae'], 4),
            'per_seed_avgs': {str(k): round(v, 4) for k, v in r['per_seed_avgs'].items()},
        }

    with open('trm_models_v13/summary_v13.json', 'w') as f:
        json.dump(s, f, indent=2, default=str)
    log.info("✓ Saved: summary_v13.json")


if __name__ == '__main__':
    results = run_benchmark()
    shutil.make_archive("trm_v13_all", "zip", "trm_models_v13")
    log.info("✓ Created trm_v13_all.zip")
|