TRIADS: Tiny Recursive Information-Attention with Deep Supervision

A High-Precision Deep Learning Architecture for Alloy Yield Strength Prediction on Sparse Datasets

Matbench GitHub License: MIT PyTorch

πŸ“¦ This is the model weights repository. For full training code, research documentation, and experimental history, see the GitHub repository.


Overview

TRIADS is a novel deep learning architecture that combines self-attention-based feature extraction with recursive MLP reasoning and deep supervision to predict the yield strength of steel alloys with state-of-the-art accuracy. Developed through 15 iterative versions and approximately 200 trained models, TRIADS achieves a Mean Absolute Error of 91.20 MPa on the Matbench Steels benchmarkβ€”surpassing established baselines including Random Forest models, CrabNet, and Darwin, and competing directly with neural network and AutoML approaches that rely on orders-of-magnitude more parameters or extensive pretraining.

The architecture was designed from the ground up to operate on micro-scale datasets (N=312), where conventional deep learning approaches chronically overfit. TRIADS addresses this challenge through a combination of engineered compositional features, structured attention, shared-weight recursive reasoning, and a deep supervision training protocol that together enable a 224K-parameter model to generalize effectively where million-parameter models fail.

This repository contains a single consolidated checkpoint packaging all 25 models (5 seeds Γ— 5 folds) along with an evaluation script to reproduce the SOTA result end-to-end.


Benchmark Results (Matbench Steels)

Model MAE (MPa) Type Parameters
AutoGluon 77.03 Stacked Ensemble (AutoML) β€”
TPOT-Mat 79.95 AutoML Pipeline β€”
MODNet v0.1.12 87.76 Neural Network β€”
RF-Regex Steels 90.58 Random Forest β€”
TRIADS V13A (Ours) 91.20 Hybrid-TRM + Deep Supervision 224,685
RF-SCM/Magpie 103.51 Random Forest + Magpie β€”
CrabNet 107.31 Transformer (Pretrained on 300K+) ~1M+
Darwin 123.29 Evolutionary Algorithm β€”

Peak per-fold performance of 80.55 MPa observed during official 5-fold nested cross-validationβ€”surpassing TPOT-Mat (79.95) on that data split, achieved with a 224K-parameter model trained entirely from scratch.


πŸš€ Interactive Demo

Try the model without any setup using the bundled Gradio app:

pip install gradio
python app.py

The app provides:

  • Instant predictions β€” enter any steel alloy formula
  • Composition breakdown β€” visual element distribution
  • Ensemble statistics β€” mean, std, range, and confidence assessment
  • Per-seed details β€” transparency into how each model contributes

Reproduce the SOTA Result

You can reproduce the 91.20 MPa result with a single command:

pip install torch pymatgen matminer gensim huggingface_hub scikit-learn tqdm
python evaluate.py

This will:

  1. Download the checkpoint from this repository automatically
  2. Load the matbench_steels dataset (312 samples)
  3. Compute expanded features (Magpie + Mat2Vec + Matminer)
  4. Run official 5-fold nested cross-validation with 5-seed ensemble averaging
  5. Report per-fold and overall MAE

You can also evaluate with a local checkpoint:

python evaluate.py --checkpoint path/to/triads_v13a_ensemble.pt

Architecture

TRIADS operates through four sequential processing stages. Full architecture details and design rationale are documented in the GitHub repository.

Stage 1: Compositional Featurization (~462 features)

  • Magpie Descriptors (132d): 22 elemental properties Γ— 6 statistics
  • Mat2Vec Embeddings (200d): Pretrained from 3M+ materials science abstracts
  • Extended Matminer Descriptors (~130d): ElementFraction, Stoichiometry, ValenceOrbital, IonProperty, BandCenter

Stage 2: 2-Layer Self-Attention Feature Extraction

22 property tokens projected into 64d attention space β†’ 2Γ— self-attention for higher-order property interactions β†’ cross-attention to Mat2Vec chemical semantics.

Stage 3: Recursive MLP Reasoning (The TRM Core)

Shared-weight 2-layer MLP iteratively refines reasoning and prediction states over 20 recursive steps β€” providing the depth of a 40-layer network with zero additional parameters.

For t = 1 to 20:
    zβ‚œ = zβ‚œβ‚‹β‚ + MLP_z(zβ‚œβ‚‹β‚, yβ‚œβ‚‹β‚, x_pooled)     # Refine reasoning
    yβ‚œ = yβ‚œβ‚‹β‚ + MLP_y(yβ‚œβ‚‹β‚, zβ‚œ)                   # Refine prediction

Final output: yield_strength = Linear(yβ‚‚β‚€)

Stage 4: Deep Supervision (Training)

L1 loss at every recursion step with linearly increasing weights β€” the single most impactful design decision, reducing MAE by 24 MPa and acting as a regularizer that enables architectural scaling on small datasets.

5-Seed Ensemble (Inference)

Final prediction averaged across 5 independently trained models (seeds 42, 123, 7, 0, 99), yielding a 5.57 MPa improvement (96.77 β†’ 91.20 MPa).


Model Details

Property Value
Architecture DeepHybridTRM (2-Layer SA + Recursive MLP)
Parameters 224,685
Input Features ~462 (Magpie + Mat2Vec + Matminer)
Attention d_attn=64, nhead=4, 2Γ— Self-Attention + Cross-Attention
Reasoning d_hidden=96, ff_dim=150, 20 recursive steps
Ensemble 5 seeds Γ— 5 folds = 25 models
Checkpoint Size ~22 MB (consolidated)
Training Hardware Kaggle P100 GPU
Training Time ~25 min per seed

Loading the Checkpoint Programmatically

import torch
from huggingface_hub import hf_hub_download
from model_arch import DeepHybridTRM

# Download
ckpt_path = hf_hub_download(repo_id="Rtx09/TRIADS", filename="triads_v13a_ensemble.pt")
ckpt = torch.load(ckpt_path, map_location="cpu")
config = ckpt["config"]

# Load a specific model (e.g., seed42, fold1)
model = DeepHybridTRM(**config)
model.load_state_dict(ckpt["ensemble_weights"]["seed42_fold1"])
model.eval()

# Load ALL models for full ensemble
models = {}
for key, state_dict in ckpt["ensemble_weights"].items():
    m = DeepHybridTRM(**config)
    m.load_state_dict(state_dict)
    m.eval()
    models[key] = m

The Research Journey: From 184 to 91 MPa

TRIADS was developed through 15 major versions. Full details in Architecture Evolution, Hyperparameter Studies, and Performance Logs.

Version Focus MAE (MPa) Key Finding
V1 Baseline TRM 184.38 Input representation is the bottleneck
V2 Element-as-token Transformer 388.58 ❌ 312 samples insufficient for raw attention
V3 Magpie feature engineering 130.33 Engineered features shatter the 184 MPa ceiling
V5 SWA + Property-token attention 128.98 Property tokens unlock attention
V7 Scaled Hybrid-TRM 127.08 First time attention surpasses pure MLP
V10 Deep Supervision 103.28 Core breakthrough β€” beat Darwin, CrabNet
V12 Expanded features + scaling 95.99 First sub-100 MPa; features + capacity coupling
V13 2-layer SA + 5-seed ensemble 91.20 Project SOTA β€” 50.5% error reduction from V1

All training scripts (V1–V15), raw JSON metrics, and result visualizations are available in the GitHub repository.


Key Technical Insights

  1. Input Representation > Model Capacity: On sparse datasets, how you represent the input matters more than how you process it. Replacing fraction-weighted sums with Magpie descriptors dropped MAE by 54 MPa without any architectural change.

  2. Attention Requires Structure: Element-as-token attention: 388 MPa. Property-token attention: 165 MPa. A 223 MPa improvement from input restructuring alone.

  3. Deep Supervision as Regularization: Enables architectural scaling β€” the same d_attn=64 overfits by +28 MPa without DS and achieves SOTA with it.

  4. The Recursive Mechanism Works: Smooth MAE descent from ~1400 MPa at step 1 to ~184 MPa at step 16 β€” each pass demonstrably refines the prediction.

  5. Features and Architecture Must Be Co-Scaled: Expanded features failed on small architectures; large architectures failed without expanded features. The breakthrough required both.

Detailed ablation analysis available in Hyperparameter Studies.


Repository Contents

File Description
triads_v13a_ensemble.pt Consolidated checkpoint (25 models, ~22 MB)
model_arch.py Full architecture code (DeepHybridTRM + ExpandedFeaturizer)
evaluate.py Self-contained script to reproduce the 91.20 MPa result
app.py Interactive Gradio demo for non-technical users
README.md This model card

Citation

@software{tiwari2026triads,
  author = {Rudra Tiwari},
  title = {TRIADS: Tiny Recursive Information-Attention with Deep Supervision for Alloy Strength Prediction},
  year = {2026},
  url = {https://github.com/Rtx09x/TRIADS}
}

License

This project is licensed under the MIT License. See the GitHub repository for full license text.

Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model isn't deployed by any Inference Provider. πŸ™‹ Ask for provider support