Spaces:
Sleeping
Sleeping
File size: 6,412 Bytes
import gradio as gr
import torch
from transformers import AutoTokenizer, T5EncoderModel
import matplotlib.pyplot as plt
import numpy as np
from llmprop_model import T5Predictor
from llmprop_utils import replace_bond_lengths_with_num, replace_bond_angles_with_ang
# Run inference on CPU: HuggingFace free-tier Spaces provide no GPU.
device = torch.device("cpu") # HuggingFace free tier is CPU only

# Project-local tokenizer trained on a modified portion of C4 plus TextEdge.
tokenizer = AutoTokenizer.from_pretrained("tokenizers/t5_tokenizer_trained_on_modified_part_of_C4_and_textedge")
# Extra special tokens: [CLS] pooling marker; [NUM]/[ANG] placeholders that
# the llmprop_utils preprocessing substitutes for bond lengths and angles.
tokenizer.add_tokens(["[CLS]", "[NUM]", "[ANG]"])

base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
# Grow the embedding table to cover the three tokens added above.
base_model.resize_token_embeddings(len(tokenizer))

model = T5Predictor(base_model, 512, drop_rate=0.2, pooling="cls")
model.load_state_dict(
    torch.load(
        "best_checkpoint_for_band_gap.pt",
        map_location=device
    ),
    # NOTE(review): strict=False silently ignores missing/unexpected keys —
    # confirm the checkpoint really matches this architecture.
    strict=False
)
model.to(device)
model.eval()  # inference only: disables dropout

# Target-normalization statistics from the training set; predict() uses them
# to de-normalize the model's raw regression output back into eV.
TRAIN_MEAN = 1.0258
TRAIN_STD = 1.5106
def predict(description):
    """Predict the band gap (eV) for one crystal-structure description.

    Returns a ``(result, details)`` pair of display strings for the UI:
    the formatted band gap, and a material-type/confidence summary.
    """
    # Guard clause: blank input gets a prompt instead of a prediction.
    if not description.strip():
        return "Please enter a crystal description.", ""

    # Preprocess exactly as during training: bond lengths -> [NUM] first,
    # then bond angles -> [ANG].
    processed = replace_bond_angles_with_ang(
        replace_bond_lengths_with_num(description)
    )

    batch = tokenizer(
        "[CLS] " + processed,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )

    with torch.no_grad():
        _, raw = model(
            batch["input_ids"].to(device),
            batch["attention_mask"].to(device),
        )

    # De-normalize the regression output and clamp to the physical range.
    band_gap = max(0.0, raw.squeeze().cpu().item() * TRAIN_STD + TRAIN_MEAN)

    # Classify by magnitude; each entry is an exclusive upper bound in eV.
    for upper, material_type, confidence in (
        (0.1, "Metal (zero or near-zero band gap)", "High"),
        (1.0, "Narrow gap semiconductor", "Medium"),
        (3.0, "Semiconductor", "Medium"),
    ):
        if band_gap < upper:
            break
    else:
        material_type = "Wide gap semiconductor / Insulator"
        confidence = "Low"

    result = f"Predicted Band Gap: {band_gap:.4f} eV"
    details = f"Material Type: {material_type}\nConfidence: {confidence}"
    return result, details
def results_chart():
    """Build the two-panel matplotlib figure shown on the Results tab."""
    labels = ["Sc2Co3Al", "PdCu11N4", "AgClO4", "Mg3Si4(BiO7)2", "LiCuSO4F"]
    true_vals = [0.00, 0.00, 2.68, 2.62, 0.43]
    pred_vals = [0.00, 0.02, 1.31, 1.55, 0.01]

    fig, (scatter_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: predicted vs. true values with a y = x reference line.
    scatter_ax.scatter(true_vals, pred_vals, color="#38bdf8", s=100, zorder=5)
    limit = max(true_vals + pred_vals) + 0.5
    scatter_ax.plot([0, limit], [0, limit], "r--", label="Perfect prediction")
    for name, tv, pv in zip(labels, true_vals, pred_vals):
        scatter_ax.annotate(name, (tv, pv),
                            textcoords="offset points", xytext=(5, 5), fontsize=8)
    scatter_ax.set_xlabel("True Band Gap (eV)")
    scatter_ax.set_ylabel("Predicted Band Gap (eV)")
    scatter_ax.set_title("Predicted vs True Band Gap")
    scatter_ax.legend()

    # Right panel: per-material absolute error vs. the overall test MAE.
    abs_errors = [abs(tv - pv) for tv, pv in zip(true_vals, pred_vals)]
    error_ax.bar(labels, abs_errors, color="#38bdf8", edgecolor="#0284c7")
    error_ax.axhline(y=0.6678, color="red", linestyle="--", label="Test MAE (0.6678 eV)")
    error_ax.set_xlabel("Material")
    error_ax.set_ylabel("Absolute Error (eV)")
    error_ax.set_title("Prediction Error by Material")
    error_ax.legend()

    # plt.xticks acts on the current axes (the bar chart created last);
    # rotates its material labels so they don't overlap.
    plt.xticks(rotation=15, ha="right")
    plt.tight_layout()
    return fig
# Ready-made demo inputs for gr.Examples; each inner list supplies the one
# Textbox input of the Demo tab.
examples = [
    ["Silicon dioxide crystallizes in the orthorhombic structure. Si-O bond lengths range from 1.6 to 1.7 Angstroms. Bond angles are approximately 109 degrees."],
    ["Copper crystallizes in the face-centered cubic structure. Cu-Cu bond lengths are 2.55 Angstroms. Metallic bonding."],
    ["Zinc sulfide crystallizes in the cubic zinc blende structure. Zn-S bond lengths are 2.34 Angstroms. Coordination number is 4."],
]
# ---- Gradio UI -------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft(), title="LLM-Prop") as demo:
    gr.Markdown("""
# LLM-Prop: Crystal Band Gap Predictor
**Predicting material band gaps from natural language crystal structure descriptions using T5 Transformer**
""")
    with gr.Tabs():
        # Tab 1: interactive prediction.
        with gr.Tab("Demo"):
            gr.Markdown("### Paste a crystal structure description to predict its band gap")
            with gr.Row():
                with gr.Column():
                    inp = gr.Textbox(
                        label="Crystal Structure Description",
                        placeholder="e.g. Silicon dioxide crystallizes in the orthorhombic structure...",
                        lines=6
                    )
                    btn = gr.Button("Predict Band Gap", variant="primary")
                    gr.Examples(examples=examples, inputs=inp)
                with gr.Column():
                    out_result = gr.Textbox(label="Prediction", lines=2)
                    out_details = gr.Textbox(label="Material Type & Confidence", lines=3)
            # Wire the button to predict(); outputs fill both result boxes.
            btn.click(fn=predict, inputs=inp, outputs=[out_result, out_details])
        # Tab 2: static metrics table plus the pre-rendered chart.
        with gr.Tab("Results"):
            gr.Markdown("### Model Performance on Test Samples")
            gr.Markdown("""
| Metric | Value |
|---|---|
| Test MAE | 0.6678 eV |
| Best Val MAE | 0.6393 eV |
| Training epochs | 15 |
| Training samples | 125,098 |
| Model parameters | 35.3M |
""")
            # The figure is built once at startup (static, not per-request).
            gr.Plot(value=results_chart())
        # Tab 3: project description.
        with gr.Tab("About"):
            gr.Markdown("""
## About LLM-Prop
**LLM-Prop** is a deep learning system that predicts the band gap of crystalline
materials from natural language descriptions of their crystal structures.
### Model Architecture
- **Base model:** T5-small encoder (Google)
- **Parameters:** 35.3 million
- **Input:** Crystal structure text (max 256 tokens)
- **Output:** Band gap in eV (regression)
### Dataset
- **Training:** 125,098 samples
- **Validation:** 9,945 samples
- **Test:** 11,531 samples
### Results
- **Test MAE:** 0.6678 eV
- **Baseline MAE:** ~1.5 eV
- **Improvement over baseline:** ~55%
### Reference
Based on the paper: [LLM-Prop](https://www.nature.com/articles/s41524-025-01536-2)
""")
demo.launch()