# LLM-Prop / app.py
# Author: varshith1110 — uploaded via huggingface_hub (commit 752cdaf, verified)
import gradio as gr
import torch
from transformers import AutoTokenizer, T5EncoderModel
import matplotlib.pyplot as plt
import numpy as np
from llmprop_model import T5Predictor
from llmprop_utils import replace_bond_lengths_with_num, replace_bond_angles_with_ang
# Inference runs on CPU only (HuggingFace free tier has no GPU).
device = torch.device("cpu") # HuggingFace free tier is CPU only
# Custom T5 tokenizer shipped with the Space; extended with the special
# tokens the preprocessing step emits ([NUM]/[ANG]) plus a [CLS] pooling token.
tokenizer = AutoTokenizer.from_pretrained("tokenizers/t5_tokenizer_trained_on_modified_part_of_C4_and_textedge")
tokenizer.add_tokens(["[CLS]", "[NUM]", "[ANG]"])
# Encoder-only T5 backbone; embeddings resized to cover the added tokens.
base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
base_model.resize_token_embeddings(len(tokenizer))
# Regression head on top of the encoder, pooled on the [CLS] token.
model = T5Predictor(base_model, 512, drop_rate=0.2, pooling="cls")
# strict=False tolerates key mismatches between checkpoint and model
# (e.g. buffers absent from the saved state dict).
model.load_state_dict(
torch.load(
"best_checkpoint_for_band_gap.pt",
map_location=device
),
strict=False
)
model.to(device)
model.eval()
# z-score statistics of the training targets; predict() uses them to
# de-normalize the model output back to eV.
TRAIN_MEAN = 1.0258
TRAIN_STD = 1.5106
def predict(description):
    """Predict the band gap (in eV) of a crystal from its text description.

    Parameters
    ----------
    description : str | None
        Natural-language description of the crystal structure.

    Returns
    -------
    tuple[str, str]
        (prediction line, material-type/confidence details) for the two
        Gradio output textboxes.
    """
    # Guard the empty textbox; Gradio can also deliver None for a cleared
    # input, which would crash on .strip().
    if not description or not description.strip():
        return "Please enter a crystal description.", ""

    # Preprocess exactly as during training: numeric bond lengths and bond
    # angles are replaced by the special [NUM] / [ANG] tokens.
    text = replace_bond_lengths_with_num(description)
    text = replace_bond_angles_with_ang(text)

    # Prepend [CLS] so the model's "cls" pooling has a token to pool on.
    encoded = tokenizer(
        "[CLS] " + text,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        # T5Predictor returns a pair; the regression output is the second
        # element — TODO confirm against llmprop_model.
        _, prediction = model(input_ids, attention_mask)

    # De-normalize: training targets were z-scored with the train-set stats.
    band_gap = (prediction.squeeze().cpu().item() * TRAIN_STD) + TRAIN_MEAN
    band_gap = max(0.0, band_gap)  # physical band gaps are non-negative

    # Coarse material classification from the predicted gap magnitude.
    if band_gap < 0.1:
        material_type = "Metal (zero or near-zero band gap)"
        confidence = "High"
    elif band_gap < 1.0:
        material_type = "Narrow gap semiconductor"
        confidence = "Medium"
    elif band_gap < 3.0:
        material_type = "Semiconductor"
        confidence = "Medium"
    else:
        material_type = "Wide gap semiconductor / Insulator"
        confidence = "Low"

    result = f"Predicted Band Gap: {band_gap:.4f} eV"
    details = f"Material Type: {material_type}\nConfidence: {confidence}"
    return result, details
def results_chart():
    """Build the two-panel test-sample performance figure.

    Left panel: predicted vs. true band gap scatter with a perfect-prediction
    diagonal. Right panel: per-material absolute error bars against the
    overall test MAE line.

    Returns
    -------
    matplotlib.figure.Figure
        The composed figure, handed to gr.Plot.
    """
    # Hard-coded results for five held-out test materials.
    true_vals = [0.00, 0.00, 2.68, 2.62, 0.43]
    pred_vals = [0.00, 0.02, 1.31, 1.55, 0.01]
    labels = ["Sc2Co3Al", "PdCu11N4", "AgClO4", "Mg3Si4(BiO7)2", "LiCuSO4F"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Panel 1: predicted vs. true scatter.
    ax1.scatter(true_vals, pred_vals, color="#38bdf8", s=100, zorder=5)
    max_val = max(max(true_vals), max(pred_vals)) + 0.5
    ax1.plot([0, max_val], [0, max_val], "r--", label="Perfect prediction")
    for i, label in enumerate(labels):
        ax1.annotate(label, (true_vals[i], pred_vals[i]),
                     textcoords="offset points", xytext=(5, 5), fontsize=8)
    ax1.set_xlabel("True Band Gap (eV)")
    ax1.set_ylabel("Predicted Band Gap (eV)")
    ax1.set_title("Predicted vs True Band Gap")
    ax1.legend()

    # Panel 2: absolute error per material vs. the overall test MAE.
    errors = [abs(t - p) for t, p in zip(true_vals, pred_vals)]
    ax2.bar(labels, errors, color="#38bdf8", edgecolor="#0284c7")
    ax2.axhline(y=0.6678, color="red", linestyle="--", label="Test MAE (0.6678 eV)")
    ax2.set_xlabel("Material")
    ax2.set_ylabel("Absolute Error (eV)")
    ax2.set_title("Prediction Error by Material")
    ax2.legend()

    # Rotate the bar-chart tick labels explicitly on ax2 rather than via
    # plt.xticks(), which silently targets whatever the "current" axes is.
    plt.setp(ax2.get_xticklabels(), rotation=15, ha="right")
    plt.tight_layout()
    return fig
# Canned input descriptions (insulator, metal, semiconductor) surfaced
# through gr.Examples in the Demo tab.
examples = [
["Silicon dioxide crystallizes in the orthorhombic structure. Si-O bond lengths range from 1.6 to 1.7 Angstroms. Bond angles are approximately 109 degrees."],
["Copper crystallizes in the face-centered cubic structure. Cu-Cu bond lengths are 2.55 Angstroms. Metallic bonding."],
["Zinc sulfide crystallizes in the cubic zinc blende structure. Zn-S bond lengths are 2.34 Angstroms. Coordination number is 4."],
]
# --- Gradio UI: three tabs (live demo, results summary, about page). ---
with gr.Blocks(theme=gr.themes.Soft(), title="LLM-Prop") as demo:
    gr.Markdown("""
# LLM-Prop: Crystal Band Gap Predictor
**Predicting material band gaps from natural language crystal structure descriptions using T5 Transformer**
""")
    with gr.Tabs():
        # Tab 1: interactive prediction.
        with gr.Tab("Demo"):
            gr.Markdown("### Paste a crystal structure description to predict its band gap")
            with gr.Row():
                with gr.Column():
                    inp = gr.Textbox(
                        label="Crystal Structure Description",
                        placeholder="e.g. Silicon dioxide crystallizes in the orthorhombic structure...",
                        lines=6
                    )
                    btn = gr.Button("Predict Band Gap", variant="primary")
                    gr.Examples(examples=examples, inputs=inp)
                with gr.Column():
                    out_result = gr.Textbox(label="Prediction", lines=2)
                    out_details = gr.Textbox(label="Material Type & Confidence", lines=3)
            btn.click(fn=predict, inputs=inp, outputs=[out_result, out_details])
        # Tab 2: static metrics table and the performance chart.
        with gr.Tab("Results"):
            gr.Markdown("### Model Performance on Test Samples")
            gr.Markdown("""
| Metric | Value |
|---|---|
| Test MAE | 0.6678 eV |
| Best Val MAE | 0.6393 eV |
| Training epochs | 15 |
| Training samples | 125,098 |
| Model parameters | 35.3M |
""")
            # Figure is rendered once at app start, not per-request.
            gr.Plot(value=results_chart())
        # Tab 3: project description and reference.
        with gr.Tab("About"):
            gr.Markdown("""
## About LLM-Prop
**LLM-Prop** is a deep learning system that predicts the band gap of crystalline
materials from natural language descriptions of their crystal structures.
### Model Architecture
- **Base model:** T5-small encoder (Google)
- **Parameters:** 35.3 million
- **Input:** Crystal structure text (max 256 tokens)
- **Output:** Band gap in eV (regression)
### Dataset
- **Training:** 125,098 samples
- **Validation:** 9,945 samples
- **Test:** 11,531 samples
### Results
- **Test MAE:** 0.6678 eV
- **Baseline MAE:** ~1.5 eV
- **Improvement over baseline:** ~55%
### Reference
Based on the paper: [LLM-Prop](https://www.nature.com/articles/s41524-025-01536-2)
""")

demo.launch()