"""
╔══════════════════════════════════════════════════════════════════════╗
║  TRIADS — Interactive Alloy Yield Strength Predictor                 ║
║  Gradio App for the TRIADS V13A SOTA Ensemble                        ║
║                                                                      ║
║  Run locally:  python app.py                                         ║
║  HF Spaces:    Auto-detected and hosted                              ║
╚══════════════════════════════════════════════════════════════════════╝
"""
|
| |
|
| | import os
|
| | import warnings
|
| | warnings.filterwarnings("ignore")
|
| |
|
| | import numpy as np
|
| | import torch
|
| | import gradio as gr
|
| | from pymatgen.core import Composition
|
| |
|
| | from model_arch import DeepHybridTRM, ExpandedFeaturizer
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
# ---------------------------------------------------------------------------
# Startup: locate the checkpoint, rebuild every ensemble member, and create
# the featurizer.  This runs once, at import time.
# ---------------------------------------------------------------------------
print("⚛️ Initializing TRIADS V13A Ensemble...")

CKPT_PATH = "triads_v13a_ensemble.pt"

# Use a checkpoint in the working directory when present; otherwise fetch it
# from the HuggingFace Hub (hf_hub_download returns the local cached path).
if not os.path.exists(CKPT_PATH):
    try:
        from huggingface_hub import hf_hub_download

        print(" Downloading checkpoint from HuggingFace...")
        CKPT_PATH = hf_hub_download(
            repo_id="Rtx09/TRIADS",
            filename="triads_v13a_ensemble.pt",
        )
    except Exception as e:
        # Chain explicitly so the underlying import/network error stays
        # attached as __cause__ in the traceback.
        raise FileNotFoundError(
            f"Could not find or download checkpoint: {e}") from e

# SECURITY NOTE(review): torch.load unpickles the file, and the checkpoint may
# have been downloaded.  Only load checkpoints from a trusted repository;
# consider `weights_only=True` once the checkpoint contents are confirmed to
# be compatible with it.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
CONFIG = ckpt["config"]
SEEDS = ckpt["seeds"]
N_MODELS = ckpt["n_models"]

# One model per stored state dict, kept in checkpoint order.  The per-seed
# table built at prediction time slices MODELS positionally, so this order
# must not be changed.
MODELS = []
for state_dict in ckpt["ensemble_weights"].values():
    m = DeepHybridTRM(n_extra=200)
    m.load_state_dict(state_dict)
    m.eval()  # inference only
    MODELS.append(m)

print(f" ✓ Loaded {len(MODELS)} models ({N_MODELS} expected)")
print(f" ✓ Architecture: {ckpt['model_name']} ({sum(p.numel() for p in MODELS[0].parameters()):,} params)")

print(" Loading featurizer (Magpie + Mat2Vec + Matminer)...")
FEATURIZER = ExpandedFeaturizer()
print(" ✓ Featurizer ready\n")
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
def _composition_breakdown(elements):
    """Render element amounts as Markdown percentage bars, largest first."""
    total = sum(elements.values())
    rows = []
    for el, amt in sorted(elements.items(), key=lambda item: -item[1]):
        pct = (amt / total) * 100
        filled = int(pct / 2)  # 50-cell bar, one cell per 2 %
        # Distinct glyphs for the filled and empty parts; with identical
        # glyphs every bar looks the same regardless of the percentage.
        bar = "█" * filled + "░" * (50 - filled)
        rows.append(f"**{el:>3s}** `{bar}` {pct:5.1f}%")
    return "### 🧪 Composition Breakdown\n\n" + "\n\n".join(rows)


def _confidence_note(mean, std):
    """Turn the ensemble's coefficient of variation into a confidence line."""
    # A zero mean makes the ratio undefined; treat it as maximal disagreement.
    cv = (std / abs(mean)) * 100 if mean != 0 else 100
    if cv < 3:
        return f"🟢 **High confidence** — models strongly agree (CV = {cv:.1f}%)"
    if cv < 8:
        return f"🟡 **Moderate confidence** — some model disagreement (CV = {cv:.1f}%)"
    return (
        f"🔴 **Low confidence** — significant model disagreement (CV = {cv:.1f}%)"
        "\n\n> This composition may be outside the training distribution."
    )


def _seed_table(preds, seeds, default_folds=5):
    """Build the per-seed Markdown table from positionally grouped predictions."""
    # Predictions are grouped by position: consecutive runs of `n_folds`
    # values belong to one seed.  Infer the run length from the ensemble size
    # instead of assuming 5, falling back to 5 when it cannot be inferred.
    n_seeds = len(seeds)
    n_folds = len(preds) // n_seeds if n_seeds else 0
    if n_folds == 0:
        n_folds = default_folds

    fold_headers = " | ".join(f"Fold {i}" for i in range(1, n_folds + 1))
    lines = [
        "### 🌱 Per-Seed Predictions\n",
        f"| Seed | {fold_headers} | **Avg** |",
        "|:-----|" + "-------:|" * n_folds + "--------:|",
    ]
    for idx, seed in enumerate(seeds):
        fold_preds = preds[idx * n_folds:(idx + 1) * n_folds]
        if len(fold_preds) == 0:
            continue  # nothing to average; avoids a NaN row
        vals = " | ".join(f"{p:.1f}" for p in fold_preds)
        lines.append(f"| {seed} | {vals} | **{np.mean(fold_preds):.1f}** |")
    return "\n".join(lines)


def predict_yield_strength(formula: str):
    """
    Full ensemble prediction pipeline.

    Parses ``formula`` with pymatgen's ``Composition``, featurizes it with the
    module-level ``FEATURIZER``, runs every model in ``MODELS`` and summarizes
    the individual predictions.

    Args:
        formula: Composition string as typed by the user,
            e.g. ``"Fe0.7Cr0.15Ni0.15"``.  Surrounding whitespace is ignored.

    Returns:
        A 3-tuple of Markdown strings ``(result, seed_breakdown,
        comp_breakdown)``: the headline prediction with ensemble statistics,
        the per-seed table, and the composition bar chart.  Failures are
        reported in the first element instead of being raised: the other two
        are empty for blank or unparsable input, while a featurization failure
        still returns the composition breakdown in the third slot.
    """
    if not formula or not formula.strip():
        return ("⚠️ Please enter a chemical composition.", "", "")

    formula = formula.strip()

    # Parsing can fail in many ways on free-form user text, so catch broadly
    # and show the message in the UI rather than crashing the handler.
    try:
        comp = Composition(formula)
    except Exception as e:
        return (
            f"❌ Invalid composition: `{formula}`\n\n"
            f"Error: {str(e)}\n\n"
            "**Tips:**\n"
            "- Use element symbols: `Fe`, `Cr`, `Ni`, `C`, etc.\n"
            "- Fractions must sum to ~1: `Fe0.7Cr0.2Ni0.1`\n"
            "- Or use integer counts: `Fe70Cr20Ni10`",
            "",
            "",
        )

    comp_breakdown = _composition_breakdown(comp.get_el_amt_dict())

    try:
        X = FEATURIZER.featurize_all([comp])
        X_tensor = torch.tensor(X, dtype=torch.float32)
    except Exception as e:
        return (
            f"❌ Featurization failed for `{formula}`:\n{str(e)}",
            "",
            comp_breakdown,
        )

    # One scalar per ensemble member, in MODELS order.
    with torch.no_grad():
        all_preds = np.array([model(X_tensor).item() for model in MODELS])

    ensemble_mean = np.mean(all_preds)
    ensemble_std = np.std(all_preds)
    pred_min = np.min(all_preds)
    pred_max = np.max(all_preds)

    result = (
        f"# 🎯 {ensemble_mean:.1f} MPa\n\n"
        f"**Predicted Yield Strength** for `{comp.reduced_formula}`\n\n"
        "---\n\n"
        "### 📊 Ensemble Statistics\n\n"
        "| Metric | Value |\n"
        "|:-------|------:|\n"
        f"| **Ensemble Mean** | **{ensemble_mean:.2f} MPa** |\n"
        f"| Ensemble Std Dev | ±{ensemble_std:.2f} MPa |\n"
        f"| Range | {pred_min:.2f} – {pred_max:.2f} MPa |\n"
        f"| Models Used | {len(all_preds)} |\n\n"
        "---\n\n"
        "### 📈 Confidence\n\n"
    )
    result += _confidence_note(ensemble_mean, ensemble_std)

    seed_breakdown = _seed_table(all_preds, SEEDS)

    return result, seed_breakdown, comp_breakdown
|
| |
|
| |
|
| |
|
| |
|
| |
|
| |
|
# Example formulas offered as one-click inputs; each row is the argument list
# for the single-textbox interface.
EXAMPLES = [
    ["Fe0.7Cr0.15Ni0.15"],
    ["Fe0.8C0.005Mn0.01Cr0.12Ni0.065"],
    ["Fe0.9Cr0.05Mo0.03V0.02"],
    ["Fe0.85Cr0.1Ni0.05"],
    ["Fe0.6Cr0.2Ni0.1Mo0.05Mn0.05"],
    ["Fe0.95C0.01Si0.02Mn0.02"],
]

# HTML shown under the page title.
DESCRIPTION = """
<div style="text-align: center; max-width: 800px; margin: auto;">
<p style="font-size: 1.1em;">
A <strong>224K-parameter</strong> deep learning model achieving <strong>91.20 MPa MAE</strong> on the
<a href="https://matbench.materialsproject.org/" target="_blank">Matbench Steels</a> benchmark —
surpassing CrabNet, Darwin, and Random Forest baselines.
</p>
<p style="font-size: 0.95em; color: #888;">
Architecture: 2-Layer Self-Attention → Recursive MLP (20 steps) → Deep Supervision | 5-Seed Ensemble (25 models)
<br>
<a href="https://github.com/Rtx09x/TRIADS" target="_blank">📄 Paper & Code on GitHub</a> ·
<a href="https://huggingface.co/Rtx09/TRIADS" target="_blank">🤗 Model on HuggingFace</a>
</p>
</div>
"""

# HTML footer explaining the pipeline and the training data.
ARTICLE = """
<div style="text-align: center; margin-top: 20px; padding: 20px; background: rgba(128,128,128,0.05); border-radius: 12px;">
<h3>How it works</h3>
<p>
<strong>1. Featurization:</strong> Your composition is converted into ~462 chemical features
(Magpie descriptors + Mat2Vec embeddings + Matminer descriptors).<br>
<strong>2. Attention:</strong> Two self-attention layers learn property interactions across 22 chemical property tokens.<br>
<strong>3. Recursive Reasoning:</strong> A shared-weight MLP refines the prediction over 20 iterative steps.<br>
<strong>4. Ensemble:</strong> 25 independently trained models (5 seeds × 5 folds) are averaged for the final prediction.
</p>
<p style="font-size: 0.85em; color: #888;">
Trained on the matbench_steels dataset (312 steel compositions).
Predictions are most reliable for compositions within the training distribution.
<br><br>
Built by <a href="https://github.com/Rtx09x" target="_blank">Rudra Tiwari</a> ·
Full research journey and ablation studies on <a href="https://github.com/Rtx09x/TRIADS" target="_blank">GitHub</a>
</p>
</div>
"""

# Page-level style overrides passed to gr.Blocks.
CSS = """
.gradio-container {
max-width: 1100px !important;
margin: auto !important;
}
h1 {
text-align: center;
font-size: 2.2em !important;
margin-bottom: 0 !important;
}
"""
|
| |
|
# ---------------------------------------------------------------------------
# Gradio interface: input controls on the left, the headline prediction on the
# right, and the composition / per-seed panels in a second row underneath.
# ---------------------------------------------------------------------------
with gr.Blocks(
    title="TRIADS — Alloy Yield Strength Predictor",
    theme=gr.themes.Soft(
        primary_hue="emerald",
        secondary_hue="blue",
        neutral_hue="slate",
        font=gr.themes.GoogleFont("Inter"),
    ),
    css=CSS,
) as demo:

    gr.Markdown("# ⚛️ TRIADS Yield Strength Predictor")
    gr.HTML(DESCRIPTION)

    with gr.Row():
        with gr.Column(scale=1):
            formula_input = gr.Textbox(
                label="Chemical Composition",
                placeholder="e.g., Fe0.7Cr0.15Ni0.15",
                info="Enter a steel alloy formula using element symbols and fractions.",
                lines=1,
                max_lines=1,
            )
            predict_btn = gr.Button(
                "🔬 Predict Yield Strength",
                variant="primary",
                size="lg",
            )
            gr.Examples(
                examples=EXAMPLES,
                inputs=formula_input,
                label="Example Compositions",
            )

        with gr.Column(scale=2):
            result_output = gr.Markdown(
                label="Prediction",
                value="*Enter a composition and click predict...*",
            )

    with gr.Row():
        with gr.Column():
            comp_output = gr.Markdown(label="Composition")
        with gr.Column():
            seed_output = gr.Markdown(label="Per-Seed Details")

    gr.HTML(ARTICLE)

    # Clicking the button and pressing Enter in the textbox run the same
    # handler.  The outputs order matches the handler's return tuple:
    # (result, seed_breakdown, comp_breakdown).
    for register in (predict_btn.click, formula_input.submit):
        register(
            fn=predict_yield_strength,
            inputs=[formula_input],
            outputs=[result_output, seed_output, comp_output],
        )


if __name__ == "__main__":
    # Local run only; no public share link is requested.
    demo.launch(share=False)
|
| |
|