""" TRIADS — Multi-Benchmark Materials Property Prediction HuggingFace Gradio App (Production Redux) Covers all 6 Matbench benchmarks: 1. matbench_steels — Yield Strength (MPa) 2. matbench_expt_gap — Band Gap (eV) 3. matbench_ismetal — Metallicity (ROC-AUC) 4. matbench_glass — Glass Forming Ability 5. matbench_jdft2d — Exfoliation Energy (meV/atom) 6. matbench_phonons — Peak Phonon Frequency (cm⁻¹) """ import os import warnings import urllib.request import json import traceback warnings.filterwarnings("ignore") import numpy as np import torch import torch.nn as nn import gradio as gr from huggingface_hub import hf_hub_download # ───────────────────────────────────────────────────────────────── # CONFIG # ───────────────────────────────────────────────────────────────── REPO_ID = "Rtx09/TRIADS" # Used only if local weights are missing BENCHMARK_INFO = { "steels": { "title": "🔩 Steel Yield Strength", "description": "Predict yield strength (MPa) of steel alloys from composition.", "unit": "MPa", "example": "Fe0.7Cr0.15Ni0.15", "examples": ["Fe0.7Cr0.15Ni0.15", "Fe0.8C0.02Mn0.1Si0.05Cr0.03", "Fe0.6Ni0.25Mo0.1Cr0.05"], "task": "regression", "result": "91.20 ± 12.23 MPa MAE (5-fold, 5-seed ensemble)", }, "expt_gap": { "title": "⚡ Experimental Band Gap", "description": "Predict experimental electronic band gap (eV) from composition.", "unit": "eV", "example": "TiO2", "examples": ["TiO2", "GaN", "ZnO", "Si", "CdS"], "task": "regression", "result": "0.3068 ± 0.0082 eV MAE (5-fold, composition-only)", }, "ismetal": { "title": "🔮 Metallicity Classifier", "description": "Predict whether a material is metallic or non-metallic from composition.", "unit": "probability (1 = metal)", "example": "Cu", "examples": ["Cu", "SiO2", "Fe3O4", "BaTiO3", "Al"], "task": "classification", "result": "0.9655 ± 0.0029 ROC-AUC (5-fold, composition-only)", }, "glass": { "title": "🪟 Glass Forming Ability", "description": "Predict metallic glass forming ability from alloy composition.", "unit": 
"probability (1 = glass former)", "example": "Cu46Zr54", "examples": ["Cu46Zr54", "Fe80B20", "Al86Ni7La6Y1", "Pd40Cu30Ni10P20"], "task": "classification", "result": "0.9285 ± 0.0063 ROC-AUC (5-fold, 5-seed ensemble)", }, "jdft2d": { "title": "📐 Exfoliation Energy", "description": "Predict exfoliation energy (meV/atom) of 2D materials from structure+composition.", "unit": "meV/atom", "example": "MoS2", "examples": ["MoS2", "WSe2", "BN", "graphene (C)", "MoTe2"], "task": "regression", "result": "35.89 ± 12.40 meV/atom MAE (5-fold, 5-seed ensemble)", }, "phonons": { "title": "🎵 Phonon Peak Frequency", "description": "Predict peak phonon frequency (cm⁻¹) from crystal structure.", "unit": "cm⁻¹", "example": "Si (diamond cubic)", "examples": ["Si", "GaAs", "MgO", "BN (wurtzite)", "TiO2 (rutile)"], "task": "regression", "result": "41.91 ± 4.04 cm⁻¹ MAE (5-fold, gate-halt GraphTRIADS)", }, } # ───────────────────────────────────────────────────────────────── # MODEL DEFINITIONS (inlined for self-contained app) # ───────────────────────────────────────────────────────────────── class DeepHybridTRM(nn.Module): """ HybridTRIADS — composition-only tasks. Shared across: steels, expt_gap, ismetal, glass, jdft2d. 
""" def __init__(self, n_props=22, stat_dim=6, n_extra=0, mat2vec_dim=200, d_attn=64, nhead=4, d_hidden=96, ff_dim=150, dropout=0.2, max_steps=20, **kw): super().__init__() self.max_steps, self.D = max_steps, d_hidden self.n_props, self.stat_dim, self.n_extra = n_props, stat_dim, n_extra self.tok_proj = nn.Sequential( nn.Linear(stat_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU()) self.m2v_proj = nn.Sequential( nn.Linear(mat2vec_dim, d_attn), nn.LayerNorm(d_attn), nn.GELU()) self.sa1 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True) self.sa1_n = nn.LayerNorm(d_attn) self.sa1_ff = nn.Sequential( nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_attn*2, d_attn)) self.sa1_fn = nn.LayerNorm(d_attn) self.sa2 = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True) self.sa2_n = nn.LayerNorm(d_attn) self.sa2_ff = nn.Sequential( nn.Linear(d_attn, d_attn*2), nn.GELU(), nn.Dropout(dropout), nn.Linear(d_attn*2, d_attn)) self.sa2_fn = nn.LayerNorm(d_attn) self.ca = nn.MultiheadAttention(d_attn, nhead, dropout=dropout, batch_first=True) self.ca_n = nn.LayerNorm(d_attn) pool_in = d_attn + (n_extra if n_extra > 0 else 0) self.pool = nn.Sequential( nn.Linear(pool_in, d_hidden), nn.LayerNorm(d_hidden), nn.GELU()) self.z_up = nn.Sequential( nn.Linear(d_hidden*3, ff_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden)) self.y_up = nn.Sequential( nn.Linear(d_hidden*2, ff_dim), nn.GELU(), nn.Dropout(dropout), nn.Linear(ff_dim, d_hidden), nn.LayerNorm(d_hidden)) self.head = nn.Linear(d_hidden, 1) self._init() def _init(self): for m in self.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.zeros_(m.bias) def _attention(self, x): B = x.size(0) mg_dim = self.n_props * self.stat_dim if self.n_extra > 0: extra = x[:, mg_dim:mg_dim + self.n_extra] m2v = x[:, mg_dim + self.n_extra:] else: extra, m2v = None, x[:, mg_dim:] tok = 
self.tok_proj(x[:, :mg_dim].view(B, self.n_props, self.stat_dim)) ctx = self.m2v_proj(m2v).unsqueeze(1) tok = self.sa1_n(tok + self.sa1(tok, tok, tok)[0]) tok = self.sa1_fn(tok + self.sa1_ff(tok)) tok = self.sa2_n(tok + self.sa2(tok, tok, tok)[0]) tok = self.sa2_fn(tok + self.sa2_ff(tok)) tok = self.ca_n(tok + self.ca(tok, ctx, ctx)[0]) pooled = tok.mean(dim=1) if extra is not None: pooled = torch.cat([pooled, extra], dim=-1) return self.pool(pooled) def forward(self, x, deep_supervision=False): B = x.size(0) xp = self._attention(x) z = torch.zeros(B, self.D, device=x.device) y = torch.zeros(B, self.D, device=x.device) step_preds = [] for _ in range(self.max_steps): z = z + self.z_up(torch.cat([xp, y, z], -1)) y = y + self.y_up(torch.cat([y, z], -1)) step_preds.append(self.head(y).squeeze(1)) return step_preds if deep_supervision else step_preds[-1] # ───────────────────────────────────────────────────────────────── # FEATURIZER (composition-only, shared across HybridTRIADS tasks) # ───────────────────────────────────────────────────────────────── _featurizer_cache = {} _mat2vec_cache = {} _featurizer_err = None def _get_featurizer(): """Lazy-load the ExpandedFeaturizer (downloads Mat2Vec once).""" global _featurizer_err if "main" in _featurizer_cache: return _featurizer_cache["main"] try: from matminer.featurizers.composition import ( ElementProperty, ElementFraction, Stoichiometry, ValenceOrbital, IonProperty, BandCenter ) from matminer.featurizers.base import MultipleFeaturizer from gensim.models import Word2Vec from sklearn.preprocessing import StandardScaler GCS = "https://storage.googleapis.com/mat2vec/" M2V_FILES = [ "pretrained_embeddings", "pretrained_embeddings.wv.vectors.npy", "pretrained_embeddings.trainables.syn1neg.npy", ] # Use /tmp for writable cache if current dir is read-only cache_dir = os.path.join(os.getcwd(), "mat2vec_cache") try: os.makedirs(cache_dir, exist_ok=True) # Test write access test_file = os.path.join(cache_dir, ".test") with 
open(test_file, 'w') as f: f.write('1') os.remove(test_file) except Exception: cache_dir = "/tmp/mat2vec_cache" os.makedirs(cache_dir, exist_ok=True) for f in M2V_FILES: p = os.path.join(cache_dir, f) if not os.path.exists(p): print(f"Downloading {f}...") urllib.request.urlretrieve(GCS + f, p) # Magpie preset can fail if Figshare is down try: ep = ElementProperty.from_preset("magpie") except Exception as e: print(f"Magpie download failed, retrying once: {e}") import time time.sleep(2) ep = ElementProperty.from_preset("magpie") m2v = Word2Vec.load(os.path.join(cache_dir, "pretrained_embeddings")) emb = {w: m2v.wv[w] for w in m2v.wv.index_to_key} extra = MultipleFeaturizer([ElementFraction(), Stoichiometry(), ValenceOrbital(), IonProperty(), BandCenter()]) _featurizer_cache["main"] = (ep, m2v, emb, extra) return _featurizer_cache["main"] except Exception as e: _featurizer_err = f"{type(e).__name__}: {str(e)}\n{traceback.format_exc()}" print(f"CRITICAL Featurizer Error: {_featurizer_err}") return None def featurize_composition(formula: str): """Featurize a chemical formula into the TRIADS feature vector.""" from pymatgen.core import Composition result = _get_featurizer() if result is None: return None, f"Featurizer initialization failed.\nError: {_featurizer_err}" ep, m2v, emb, extra = result try: comp = Composition(formula) except Exception as e: return None, f"Invalid formula: '{formula}' | {str(e)}" try: mg = np.array(ep.featurize(comp), np.float32) except Exception as e: mg = np.zeros(len(ep.feature_labels()), np.float32) try: ex = np.array(extra.featurize(comp), np.float32) ex = np.nan_to_num(ex, nan=0.0) except Exception as e: ex = np.zeros(50, np.float32) # Mat2Vec pooled v, t = np.zeros(200, np.float32), 0.0 for s, f in comp.get_el_amt_dict().items(): if s in emb: v += f * emb[s] t += f m2v_vec = v / max(t, 1e-8) mg = np.nan_to_num(mg, nan=0.0) feat = np.concatenate([mg, ex, m2v_vec]) return feat.astype(np.float32), None # 
# ─────────────────────────────────────────────────────────────────
# WEIGHT LOADING (lazy, cached)
# ─────────────────────────────────────────────────────────────────

_fold_models = {}  # benchmark -> list[nn.Module]

_MODEL_CONFIGS = {
    "steels":   dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),
    "expt_gap": dict(d_attn=64, d_hidden=96, ff_dim=150, dropout=0.20, max_steps=20),
    "ismetal":  dict(d_attn=24, d_hidden=48, ff_dim=72, dropout=0.20, max_steps=16),
    "glass":    dict(d_attn=24, d_hidden=48, ff_dim=72, dropout=0.20, max_steps=16),
    "jdft2d":   dict(d_attn=32, d_hidden=64, ff_dim=96, dropout=0.20, max_steps=16),
}

_HF_PATHS = {
    "steels": "weights/steels/weights.pt",
    "expt_gap": "weights/expt_gap/weights.pt",
    "ismetal": "weights/is_metal/weights.pt",
    "glass": "weights/glass/weights.pt",
    "jdft2d": "weights/jdft2d/weights.pt",
    "phonons": "weights/phonons/weights.pt",
}


def _local_weight_path(rel_path):
    """Return an existing local copy of ``rel_path``, or ``None``.

    Looks next to this script first, then relative to the current working
    directory, so bundled weights are found wherever the app is launched.
    """
    candidates = []
    script = globals().get("__file__")
    if script:
        candidates.append(os.path.join(os.path.dirname(os.path.abspath(script)), rel_path))
    candidates.append(rel_path)
    for candidate in candidates:
        if os.path.exists(candidate):
            return candidate
    return None


def _load_benchmark_models(benchmark: str):
    """Load (once) and return the fold ensemble for ``benchmark``.

    Returns a list of ``DeepHybridTRM`` models in eval mode, or ``None`` for
    "phonons" (no composition model) or when loading fails (error and
    traceback are printed). Failures are not cached, so later calls retry.
    """
    if benchmark in _fold_models:
        return _fold_models[benchmark]
    if benchmark == "phonons":
        return None

    try:
        rel_path = _HF_PATHS[benchmark]
        # 1. Prefer a local copy bundled with the app.
        path = _local_weight_path(rel_path)
        if path is None:
            # 2. Fall back to the HuggingFace Hub.
            print(f"Local weight {rel_path} missing. Attempting hf_hub_download...")
            path = hf_hub_download(repo_id=REPO_ID, filename=rel_path)

        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Acceptable only because the file comes from this app's bundle or
        # REPO_ID; never point this at untrusted checkpoints.
        ckpt = torch.load(path, map_location="cpu", weights_only=False)

        # Either {"folds": [...], "n_extra": k} or a single fold's entry.
        if isinstance(ckpt, dict):
            fold_entries = ckpt.get("folds", [ckpt])
            n_extra = ckpt.get("n_extra", 0)
        else:
            fold_entries, n_extra = [ckpt], 0
        if not fold_entries:
            raise ValueError(f"checkpoint {path} contains no folds")

        cfg = {**_MODEL_CONFIGS[benchmark], "n_extra": n_extra}
        models = []
        for entry in fold_entries:
            model = DeepHybridTRM(**cfg)
            state = entry.get("model_state", entry) if isinstance(entry, dict) else entry
            model.load_state_dict(state)
            model.eval()
            models.append(model)

        _fold_models[benchmark] = models
        return models
    except Exception as e:
        print(f"Error loading {benchmark} weights: {e}\n{traceback.format_exc()}")
        return None


def _ensemble_predict(benchmark: str, x: np.ndarray, is_classification: bool = False):
    """Average the fold models' predictions for one feature vector ``x``.

    For classification each fold's output is passed through a sigmoid before
    averaging, so the result is a probability. Returns ``(value, None)`` or
    ``(None, error_message)``.
    """
    models = _load_benchmark_models(benchmark)
    if not models:
        return None, "Weights could not be loaded. See logs."

    xt = torch.tensor(x[None], dtype=torch.float32)  # single-row batch
    preds = []
    try:
        with torch.no_grad():
            for model in models:
                out = model(xt)
                if is_classification:
                    out = torch.sigmoid(out)
                preds.append(out.item())
    except RuntimeError as e:
        # Typically a feature-width mismatch between featurizer and checkpoint;
        # report it in the UI instead of crashing the request.
        print(f"Inference error for {benchmark}: {e}\n{traceback.format_exc()}")
        return None, f"Inference failed: {e}"
    return float(np.mean(preds)), None


# ─────────────────────────────────────────────────────────────────
# PREDICTION FUNCTIONS
# ─────────────────────────────────────────────────────────────────

def _featurize_and_predict(benchmark, formula, is_classification=False):
    """Shared pipeline: returns ``(prediction, None)`` or ``(None, error_markdown)``."""
    feat, err = featurize_composition(formula)
    if err:
        return None, f"❌ Error: {err}"
    pred, err = _ensemble_predict(benchmark, feat, is_classification)
    if err:
        return None, f"❌ {err}"
    return pred, None


def predict_steels(formula: str):
    """Yield strength (MPa) -> (headline markdown, context markdown)."""
    pred, err = _featurize_and_predict("steels", formula)
    if err:
        return err, ""
    return f"### {pred:.1f} MPa", f"**{pred:.1f} MPa** yield strength"


def predict_expt_gap(formula: str):
    """Experimental band gap (eV) -> (headline markdown, context markdown)."""
    pred, err = _featurize_and_predict("expt_gap", formula)
    if err:
        return err, ""
    return f"### {pred:.3f} eV", f"**{pred:.3f} eV** band gap"


def predict_ismetal(formula: str):
    """Probability of being metallic (threshold 0.5) -> (headline, label)."""
    pred, err = _featurize_and_predict("ismetal", formula, True)
    if err:
        return err, ""
    label = "🔩 METALLIC" if pred > 0.5 else "💎 NON-METALLIC"
    # Headline states it is a probability, so low values don't read as "Metal".
    return f"### P(metal) = {pred:.3f}", f"{label} (p={pred:.3f})"


def predict_glass(formula: str):
    """Probability of glass formation (threshold 0.5) -> (headline, label)."""
    pred, err = _featurize_and_predict("glass", formula, True)
    if err:
        return err, ""
    label = "🪟 GLASS-FORMER" if pred > 0.5 else "❌ CRYSTALLINE"
    return f"### P(glass former) = {pred:.3f}", f"{label} (p={pred:.3f})"


def predict_jdft2d(formula: str):
    """Exfoliation energy (meV/atom) -> (headline markdown, context markdown)."""
    pred, err = _featurize_and_predict("jdft2d", formula)
    if err:
        return err, ""
    return f"### {pred:.1f} meV/atom", f"**{pred:.1f} meV/atom** exfoliation"


PHONONS_INFO = """
## 🎵 Phonon Peak Frequency

The **TRIADS V6 Graph-TRM** achieves **41.91 ± 4.04 cm⁻¹ MAE** on Matbench phonons, using a gate-based halting Graph Neural Network that adaptively runs 4–16 message-passing cycles.

### Architecture
- **Gate-based halting**: 4–16 adaptive GNN cycles (halts when gate activations drop below threshold)
- **Graph Attention TRM**: line-graph bond updates + joint self-attention + cross-attention
- **Input**: Full crystal structure — atom positions, bond distances, angles (requires CIF/POSCAR)

### Why no live demo?
The phonons model requires a **pre-computed crystal graph** (atom positions, bond lengths, bond angles). Composition-only featurization is insufficient for phonon prediction — structural details like bond stiffness and crystal symmetry are essential.

### Benchmark Results
| Model | MAE (cm⁻¹) |
|---|---|
| **TRIADS V6 (ours)** | **41.91 ± 4.04** |
| MEGNet | 28.76 |
| ALIGNN | 29.34 |
| MODNet | 45.39 |
| CrabNet | 47.09 |
| TRIADS V4 | 56.33 |

> **Note**: MEGNet and ALIGNN use full DFT structural relaxation data.
> TRIADS V6 achieves competitive performance with a simpler, more parameter-efficient Graph-TRM architecture (< 50K parameters).
"""

# ─────────────────────────────────────────────────────────────────
# INTERFACE
# ─────────────────────────────────────────────────────────────────

# .result_text is a class (not an id): it styles the result in every tab,
# and DOM ids must be unique.
CSS = """
.result_text { font-size: 1.5rem; font-weight: 700; color: #6366f1; }
.benchmark-badge { background: #1e293b; color: #94a3b8; border-radius: 8px; padding: 8px; }
"""

# (tab label, default formula, prediction function) for each live-demo tab.
_PREDICT_TABS = (
    ("🔩 Steel Yield", "Fe0.7Cr0.15Ni0.15", predict_steels),
    ("⚡ Band Gap", "TiO2", predict_expt_gap),
    ("🔮 Metallicity", "Cu", predict_ismetal),
    ("🪟 Glass Forming", "Cu46Zr54", predict_glass),
    ("📐 JDFT2D", "MoS2", predict_jdft2d),
)


def build():
    """Build the Gradio UI: one tab per composition model plus a Phonons info tab."""
    with gr.Blocks(css=CSS, title="TRIADS") as demo:
        gr.Markdown("# ⚡ TRIADS — Materials Property Prediction")
        gr.Markdown("Recursive Information-Attention with Deep Supervision for all Matbench benchmarks.")
        with gr.Tabs():
            for tab_label, default_formula, predict_fn in _PREDICT_TABS:
                with gr.Tab(tab_label):
                    formula_box = gr.Textbox(label="Formula", value=default_formula)
                    predict_btn = gr.Button("Predict", variant="primary")
                    result_md = gr.Markdown(elem_classes="result_text")
                    context_md = gr.Markdown()
                    predict_btn.click(predict_fn, formula_box, [result_md, context_md])
                    # Pressing Enter in the textbox also runs the prediction.
                    formula_box.submit(predict_fn, formula_box, [result_md, context_md])
            with gr.Tab("🎵 Phonons"):
                gr.Markdown(PHONONS_INFO)
    return demo


if __name__ == "__main__":
    build().launch()