Rtx09 committed on
Commit
4bd82f4
·
verified ·
1 Parent(s): e42c9a6

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +295 -0
app.py ADDED
@@ -0,0 +1,295 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ╔══════════════════════════════════════════════════════════════════════╗
3
+ ║ TRIADS — Interactive Alloy Yield Strength Predictor ║
4
+ ║ Gradio App for the TRIADS V13A SOTA Ensemble ║
5
+ ║ ║
6
+ ║ Run locally: python app.py ║
7
+ ║ HF Spaces: Auto-detected and hosted ║
8
+ ╚══════════════════════════════════════════════════════════════════════╝
9
+ """
10
+
11
+ import os
12
+ import warnings
13
+ warnings.filterwarnings("ignore")
14
+
15
+ import numpy as np
16
+ import torch
17
+ import gradio as gr
18
+ from pymatgen.core import Composition
19
+
20
+ from model_arch import DeepHybridTRM, ExpandedFeaturizer
21
+
22
+
23
+ # ══════════════════════════════════════════════════════════════════════
24
+ # 1. GLOBAL MODEL LOADING
25
+ # ══════════════════════════════════════════════════════════════════════
26
+
27
# Module-level start-up: locate the ensemble checkpoint, rebuild every model
# stored in it, and construct the featurizer. Runs once at import time and
# defines CKPT_PATH, CONFIG, SEEDS, N_MODELS, MODELS and FEATURIZER.
print("⚙️ Initializing TRIADS V13A Ensemble...")

CKPT_PATH = "triads_v13a_ensemble.pt"

# Use a checkpoint in the working directory when present; otherwise fetch it
# from the HuggingFace Hub.
if not os.path.exists(CKPT_PATH):
    try:
        # Imported only on this path, i.e. only when no local file exists.
        from huggingface_hub import hf_hub_download
        print(" Downloading checkpoint from HuggingFace...")
        CKPT_PATH = hf_hub_download(
            repo_id="Rtx09/TRIADS",
            filename="triads_v13a_ensemble.pt",
        )
    except Exception as e:
        # Chain the original error so the real cause (missing package,
        # network failure, ...) stays visible in the traceback.
        raise FileNotFoundError(
            f"Could not find or download checkpoint: {e}") from e

# NOTE(review): torch.load unpickles the file, and the file may have been
# downloaded. Consider passing weights_only=True once it is confirmed that the
# checkpoint contains only tensors and plain Python containers.
ckpt = torch.load(CKPT_PATH, map_location="cpu")
CONFIG = ckpt["config"]
SEEDS = ckpt["seeds"]
N_MODELS = ckpt["n_models"]

# Rebuild one model per stored state dict, in the checkpoint's own order.
MODELS = []
for state_dict in ckpt["ensemble_weights"].values():
    m = DeepHybridTRM(**CONFIG)
    m.load_state_dict(state_dict)
    m.eval()  # inference only
    MODELS.append(m)

if not MODELS:
    # Fail with a clear message instead of an IndexError on MODELS[0] below.
    raise RuntimeError(
        f"Checkpoint {CKPT_PATH!r} contains no ensemble weights")

print(f" ✓ Loaded {len(MODELS)} models ({N_MODELS} expected)")
print(f" ✓ Architecture: {ckpt['model_name']} ({sum(p.numel() for p in MODELS[0].parameters()):,} params)")

# Build the featurizer. The original author's note says it downloads Mat2Vec
# on first run — not verifiable from this file.
print(" Loading featurizer (Magpie + Mat2Vec + Matminer)...")
FEATURIZER = ExpandedFeaturizer()
print(" ✓ Featurizer ready\n")
64
+
65
+
66
+ # ══════════════════════════════════════════════════════════════════════
67
+ # 2. PREDICTION LOGIC
68
+ # ══════════════════════════════════════════════════════════════════════
69
+
70
def _composition_breakdown(elements):
    """Render an element -> amount mapping as a markdown bar chart.

    Elements are listed from largest to smallest share; each bar is 50
    characters wide, one filled cell per 2 % of the composition.
    """
    total = sum(elements.values())
    lines = []
    for el, amt in sorted(elements.items(), key=lambda x: -x[1]):
        pct = (amt / total) * 100
        filled = int(pct / 2)
        bar = "█" * filled + "░" * (50 - filled)
        lines.append(f"**{el:>3s}** `{bar}` {pct:5.1f}%")
    return "### 🧪 Composition Breakdown\n\n" + "\n\n".join(lines)


def _confidence_note(mean, std):
    """Return the confidence sentence derived from the ensemble's spread."""
    # Coefficient of variation in percent; a zero mean is treated as 100 %.
    cv = (std / abs(mean)) * 100 if mean != 0 else 100
    if cv < 3:
        return f"🟢 **High confidence** — models strongly agree (CV = {cv:.1f}%)"
    if cv < 8:
        return f"🟡 **Moderate confidence** — some model disagreement (CV = {cv:.1f}%)"
    return f"🔴 **Low confidence** — significant model disagreement (CV = {cv:.1f}%)\n\n> This composition may be outside the training distribution."


def _seed_breakdown(all_preds, seeds):
    """Render the per-seed markdown table of individual model predictions.

    The number of folds per seed is derived from the data rather than fixed
    at 5, so the table stays correct for any evenly divisible ensemble.
    Assumes predictions are grouped seed by seed, in the order of ``seeds``
    — TODO confirm against how the checkpoint was written.
    """
    lines = ["### 🌱 Per-Seed Predictions\n"]
    n_preds = len(all_preds)
    n_seeds = len(seeds)
    if n_preds == 0 or n_seeds == 0 or n_preds % n_seeds != 0:
        # Slicing would silently mislabel or drop predictions; say so instead.
        lines.append(
            f"*Per-seed table unavailable: {n_preds} predictions cannot be "
            f"split evenly across {n_seeds} seeds.*"
        )
        return "\n".join(lines)

    folds = n_preds // n_seeds
    fold_headers = " | ".join(f"Fold {i}" for i in range(1, folds + 1))
    lines.append(f"| Seed | {fold_headers} | **Avg** |")
    lines.append("|:-----|" + "-------:|" * folds + "--------:|")
    for si, seed in enumerate(seeds):
        fold_preds = all_preds[si * folds:(si + 1) * folds]
        avg = np.mean(fold_preds)
        vals = " | ".join(f"{p:.1f}" for p in fold_preds)
        lines.append(f"| {seed} | {vals} | **{avg:.1f}** |")
    return "\n".join(lines)


def predict_yield_strength(formula: str):
    """
    Full ensemble prediction pipeline.

    Parses ``formula`` into a composition, featurizes it, runs every model in
    MODELS and summarises the predictions.

    Args:
        formula: Chemical composition typed by the user, e.g.
            ``Fe0.7Cr0.15Ni0.15``.

    Returns:
        A 3-tuple of markdown strings
        ``(prediction text, per-seed table, composition breakdown)``.
        Failures are reported in the first element rather than raised;
        elements that could not be computed are empty strings.
    """
    if not formula or not formula.strip():
        return (
            "⚠️ Please enter a chemical composition.",
            "",
            ""
        )

    formula = formula.strip()

    # ── Parse composition ─────────────────────────────────────────
    try:
        comp = Composition(formula)
    except Exception as e:
        return (
            f"❌ Invalid composition: `{formula}`\n\n"
            f"Error: {str(e)}\n\n"
            f"**Tips:**\n"
            f"- Use element symbols: `Fe`, `Cr`, `Ni`, `C`, etc.\n"
            f"- Fractions must sum to ~1: `Fe0.7Cr0.2Ni0.1`\n"
            f"- Or use integer counts: `Fe70Cr20Ni10`",
            "",
            ""
        )

    # ── Composition breakdown ─────────────────────────────────────
    elements = comp.get_el_amt_dict()
    if not elements or sum(elements.values()) <= 0:
        # Nothing to normalise against; avoids a division by zero below.
        return (
            f"❌ Invalid composition: `{formula}`\n\n"
            f"No elements with a positive amount were found.",
            "",
            ""
        )
    comp_breakdown = _composition_breakdown(elements)

    # ── Featurize ─────────────────────────────────────────────────
    try:
        X = FEATURIZER.featurize_all([comp])
        X_tensor = torch.tensor(X, dtype=torch.float32)
    except Exception as e:
        return (
            f"❌ Featurization failed for `{formula}`:\n{str(e)}",
            "",
            comp_breakdown
        )

    # ── Ensemble prediction ───────────────────────────────────────
    if not MODELS:
        return (
            "❌ No models are loaded; cannot make a prediction.",
            "",
            comp_breakdown
        )

    with torch.no_grad():
        all_preds = np.array([model(X_tensor).item() for model in MODELS])

    if not np.all(np.isfinite(all_preds)):
        # Report the failure instead of rendering "nan MPa" as a prediction.
        return (
            f"❌ Prediction failed for `{formula}`: "
            f"the ensemble produced non-finite values.",
            "",
            comp_breakdown
        )

    ensemble_mean = np.mean(all_preds)
    ensemble_std = np.std(all_preds)
    pred_min = np.min(all_preds)
    pred_max = np.max(all_preds)

    # ── Format results ────────────────────────────────────────────
    result = (
        f"# 🎯 {ensemble_mean:.1f} MPa\n\n"
        f"**Predicted Yield Strength** for `{comp.reduced_formula}`\n\n"
        f"---\n\n"
        f"### 📊 Ensemble Statistics\n\n"
        f"| Metric | Value |\n"
        f"|:-------|------:|\n"
        f"| **Ensemble Mean** | **{ensemble_mean:.2f} MPa** |\n"
        f"| Ensemble Std Dev | ±{ensemble_std:.2f} MPa |\n"
        f"| Range | {pred_min:.2f} – {pred_max:.2f} MPa |\n"
        f"| Models Used | {len(all_preds)} |\n\n"
        f"---\n\n"
        f"### 🔍 Confidence\n\n"
    )

    # Confidence assessment based on ensemble agreement
    result += _confidence_note(ensemble_mean, ensemble_std)

    # ── Per-seed breakdown ────────────────────────────────────────
    seed_breakdown = _seed_breakdown(all_preds, SEEDS)

    return result, seed_breakdown, comp_breakdown
170
+
171
+
172
+ # ══════════════════════════════════════════════════════════════════════
173
+ # 3. GRADIO INTERFACE
174
+ # ══════════════════════════════════════════════════════════════════════
175
+
176
# Sample formulas offered in the UI. Each inner list is one example row with
# a single value, matching the single Textbox input passed to gr.Examples.
EXAMPLES = [
    ["Fe0.7Cr0.15Ni0.15"],
    ["Fe0.8C0.005Mn0.01Cr0.12Ni0.065"],
    ["Fe0.9Cr0.05Mo0.03V0.02"],
    ["Fe0.85Cr0.1Ni0.05"],
    ["Fe0.6Cr0.2Ni0.1Mo0.05Mn0.05"],
    ["Fe0.95C0.01Si0.02Mn0.02"],
]

# HTML blurb rendered with gr.HTML directly under the page title.
DESCRIPTION = """
<div style="text-align: center; max-width: 800px; margin: auto;">
<p style="font-size: 1.1em;">
A <strong>224K-parameter</strong> deep learning model achieving <strong>91.20 MPa MAE</strong> on the
<a href="https://matbench.materialsproject.org/" target="_blank">Matbench Steels</a> benchmark —
surpassing CrabNet, Darwin, and Random Forest baselines.
</p>
<p style="font-size: 0.95em; color: #888;">
Architecture: 2-Layer Self-Attention → Recursive MLP (20 steps) → Deep Supervision | 5-Seed Ensemble (25 models)
<br>
<a href="https://github.com/Rtx09x/TRIADS" target="_blank">📄 Paper & Code on GitHub</a> ·
<a href="https://huggingface.co/Rtx09/TRIADS" target="_blank">🤗 Model on HuggingFace</a>
</p>
</div>
"""

# HTML "how it works" footer rendered with gr.HTML at the bottom of the page.
ARTICLE = """
<div style="text-align: center; margin-top: 20px; padding: 20px; background: rgba(128,128,128,0.05); border-radius: 12px;">
<h3>How it works</h3>
<p>
<strong>1. Featurization:</strong> Your composition is converted into ~462 chemical features
(Magpie descriptors + Mat2Vec embeddings + Matminer descriptors).<br>
<strong>2. Attention:</strong> Two self-attention layers learn property interactions across 22 chemical property tokens.<br>
<strong>3. Recursive Reasoning:</strong> A shared-weight MLP refines the prediction over 20 iterative steps.<br>
<strong>4. Ensemble:</strong> 25 independently trained models (5 seeds × 5 folds) are averaged for the final prediction.
</p>
<p style="font-size: 0.85em; color: #888;">
Trained on the matbench_steels dataset (312 steel compositions).
Predictions are most reliable for compositions within the training distribution.
<br><br>
Built by <a href="https://github.com/Rtx09x" target="_blank">Rudra Tiwari</a> ·
Full research journey and ablation studies on <a href="https://github.com/Rtx09x/TRIADS" target="_blank">GitHub</a>
</p>
</div>
"""

# Stylesheet passed to gr.Blocks(css=...): caps the container width, centres
# it, and centres/enlarges <h1> headings.
CSS = """
.gradio-container {
max-width: 1100px !important;
margin: auto !important;
}
h1 {
text-align: center;
font-size: 2.2em !important;
margin-bottom: 0 !important;
}
"""
232
+
233
# Theme is assembled first so the Blocks constructor call stays compact.
_APP_THEME = gr.themes.Soft(
    primary_hue="emerald",
    secondary_hue="blue",
    neutral_hue="slate",
    font=gr.themes.GoogleFont("Inter"),
)

# Page layout: title and description, an input column beside the headline
# result, a detail row, and the explanatory footer.
# NOTE(review): nesting was reconstructed from a flattened listing — confirm
# that the detail row belongs directly under Blocks rather than inside the
# result column.
with gr.Blocks(
    title="TRIADS — Alloy Yield Strength Predictor",
    theme=_APP_THEME,
    css=CSS,
) as demo:

    gr.Markdown("# ⚛️ TRIADS Yield Strength Predictor")
    gr.HTML(DESCRIPTION)

    with gr.Row():
        # Left: formula entry, predict button and clickable examples.
        with gr.Column(scale=1):
            formula_input = gr.Textbox(
                label="Chemical Composition",
                placeholder="e.g., Fe0.7Cr0.15Ni0.15",
                info="Enter a steel alloy formula using element symbols and fractions.",
                lines=1,
                max_lines=1,
            )
            predict_btn = gr.Button(
                "🔬 Predict Yield Strength",
                variant="primary",
                size="lg",
            )
            gr.Examples(
                examples=EXAMPLES,
                inputs=formula_input,
                label="Example Compositions",
            )

        # Right: headline prediction and ensemble statistics.
        with gr.Column(scale=2):
            result_output = gr.Markdown(
                label="Prediction",
                value="*Enter a composition and click predict...*",
            )

    with gr.Row():
        with gr.Column():
            comp_output = gr.Markdown(label="Composition")
        with gr.Column():
            seed_output = gr.Markdown(label="Per-Seed Details")

    gr.HTML(ARTICLE)

    # Clicking the button and pressing Enter in the textbox run the same
    # prediction with the same inputs and outputs.
    for _register in (predict_btn.click, formula_input.submit):
        _register(
            fn=predict_yield_strength,
            inputs=[formula_input],
            outputs=[result_output, seed_output, comp_output],
        )


# Start the server only when executed as a script, not when imported.
if __name__ == "__main__":
    demo.launch(share=False)