Spaces:
Running
Running
| import gradio as gr | |
| import torch | |
| from transformers import AutoTokenizer, T5EncoderModel | |
| import matplotlib.pyplot as plt | |
| import numpy as np | |
| from llmprop_model import T5Predictor | |
| from llmprop_utils import replace_bond_lengths_with_num, replace_bond_angles_with_ang | |
import warnings

# Everything runs on CPU: the Hugging Face free tier provides no GPU.
device = torch.device("cpu")

# Tokenizer shipped with the Space. The three marker tokens are appended to
# its vocabulary and the encoder's embedding matrix is resized to the new
# vocabulary size so token ids line up with the embedding rows of the
# fine-tuned checkpoint (assumes the checkpoint was trained with the same
# token additions — confirm).
tokenizer = AutoTokenizer.from_pretrained(
    "tokenizers/t5_tokenizer_trained_on_modified_part_of_C4_and_textedge"
)
tokenizer.add_tokens(["[CLS]", "[NUM]", "[ANG]"])

base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
base_model.resize_token_embeddings(len(tokenizer))

# Regression model over the T5 encoder. The positional 512 is presumably the
# encoder hidden size expected by T5Predictor — confirm in llmprop_model.
model = T5Predictor(base_model, 512, drop_rate=0.2, pooling="cls")

_DATA_PARALLEL_PREFIX = "module."


def _load_state_dict(path, map_location):
    """Load a state dict from *path*, dropping a DataParallel ``module.`` prefix.

    A checkpoint saved from a ``torch.nn.DataParallel`` wrapper prefixes every
    key with ``module.``. Loaded with ``strict=False``, none of those keys
    would match and the model would silently keep its untrained weights. The
    prefix is stripped only when *every* key carries it, so ordinary
    checkpoints pass through unchanged.

    Args:
        path: Filesystem path of the checkpoint file.
        map_location: Device passed through to ``torch.load``.

    Returns:
        The (possibly re-keyed) object returned by ``torch.load``.
    """
    state_dict = torch.load(path, map_location=map_location)
    if (
        isinstance(state_dict, dict)
        and state_dict
        and all(key.startswith(_DATA_PARALLEL_PREFIX) for key in state_dict)
    ):
        prefix_len = len(_DATA_PARALLEL_PREFIX)
        state_dict = {key[prefix_len:]: value for key, value in state_dict.items()}
    return state_dict


_load_result = model.load_state_dict(
    _load_state_dict("best_checkpoint_for_band_gap.pt", device),
    strict=False,
)
# strict=False tolerates partial matches. Without this report, a mismatched
# checkpoint would go unnoticed and parts of the model (e.g. the regression
# head) would silently keep their randomly initialised weights.
if _load_result.missing_keys or _load_result.unexpected_keys:
    warnings.warn(
        "Checkpoint does not fully match the model: "
        f"{len(_load_result.missing_keys)} missing key(s) "
        f"{_load_result.missing_keys[:5]}, "
        f"{len(_load_result.unexpected_keys)} unexpected key(s) "
        f"{_load_result.unexpected_keys[:5]}",
        RuntimeWarning,
    )

model.to(device)
model.eval()

# predict() maps raw model outputs back to eV as output * TRAIN_STD + TRAIN_MEAN.
# These are presumably the mean/std of the training-set band gaps used to
# normalise targets — confirm against the training run behind the checkpoint.
TRAIN_MEAN = 1.0258
TRAIN_STD = 1.5106
# (exclusive upper bound in eV, material type, confidence label), checked in
# order. The confidence label is a fixed heuristic per range; it is not
# derived from the model output.
_BAND_GAP_CLASSES = (
    (0.1, "Metal (zero or near-zero band gap)", "High"),
    (1.0, "Narrow gap semiconductor", "Medium"),
    (3.0, "Semiconductor", "Medium"),
)
# Fallback for everything not below the last bound (including NaN).
_WIDE_GAP_CLASS = ("Wide gap semiconductor / Insulator", "Low")


def _classify_band_gap(band_gap):
    """Return ``(material_type, confidence)`` for a band gap given in eV."""
    for upper_bound, material_type, confidence in _BAND_GAP_CLASSES:
        if band_gap < upper_bound:
            return material_type, confidence
    return _WIDE_GAP_CLASS


def predict(description):
    """Predict the band gap of a crystal from its text description.

    Args:
        description: Free-text crystal structure description from the UI.
            ``None`` (e.g. a null sent through the Gradio API) is treated
            the same as an empty string.

    Returns:
        A ``(result, details)`` pair of display strings: the predicted band
        gap in eV, and a coarse material type with a heuristic confidence
        label. For blank input, a prompt message and an empty string.
    """
    if description is None or not description.strip():
        return "Please enter a crystal description.", ""

    # LLM-Prop preprocessing helpers; presumably they replace numeric bond
    # lengths / angles with the [NUM] / [ANG] tokens registered on the
    # tokenizer — confirm in llmprop_utils.
    text = replace_bond_lengths_with_num(description)
    text = replace_bond_angles_with_ang(text)

    # The model is built with pooling="cls", so the [CLS] marker is prepended.
    encoded = tokenizer(
        "[CLS] " + text,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        _, prediction = model(input_ids, attention_mask)

    # Undo the target normalisation, then clamp: a band gap cannot be negative.
    band_gap = prediction.squeeze().cpu().item() * TRAIN_STD + TRAIN_MEAN
    band_gap = max(0.0, band_gap)

    material_type, confidence = _classify_band_gap(band_gap)
    result = f"Predicted Band Gap: {band_gap:.4f} eV"
    details = f"Material Type: {material_type}\nConfidence: {confidence}"
    return result, details
def results_chart(true_vals=None, pred_vals=None, labels=None, test_mae=0.6678):
    """Build the Results-tab figure: predicted-vs-true scatter and per-sample errors.

    Every argument is optional. The defaults reproduce the five hand-picked
    test samples and the reported test MAE shown in the Results tab.

    Args:
        true_vals: Reference band gaps in eV.
        pred_vals: Predicted band gaps in eV, aligned with ``true_vals``.
        labels: Material names, aligned with ``true_vals``.
        test_mae: MAE in eV, drawn as a dashed reference line on the error plot.

    Returns:
        A matplotlib Figure with two side-by-side subplots.

    Raises:
        ValueError: If the three sequences are empty or differ in length.
    """
    if true_vals is None:
        true_vals = (0.00, 0.00, 2.68, 2.62, 0.43)
    if pred_vals is None:
        pred_vals = (0.00, 0.02, 1.31, 1.55, 0.01)
    if labels is None:
        labels = ("Sc2Co3Al", "PdCu11N4", "AgClO4", "Mg3Si4(BiO7)2", "LiCuSO4F")
    true_vals, pred_vals, labels = list(true_vals), list(pred_vals), list(labels)
    if not true_vals or not len(true_vals) == len(pred_vals) == len(labels):
        raise ValueError(
            "true_vals, pred_vals and labels must be non-empty and equally long"
        )

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Left: scatter against the y = x "perfect prediction" diagonal.
    ax1.scatter(true_vals, pred_vals, color="#38bdf8", s=100, zorder=5)
    max_val = max(max(true_vals), max(pred_vals)) + 0.5
    ax1.plot([0, max_val], [0, max_val], "r--", label="Perfect prediction")
    for label, true_val, pred_val in zip(labels, true_vals, pred_vals):
        ax1.annotate(label, (true_val, pred_val),
                     textcoords="offset points", xytext=(5, 5), fontsize=8)
    ax1.set_xlabel("True Band Gap (eV)")
    ax1.set_ylabel("Predicted Band Gap (eV)")
    ax1.set_title("Predicted vs True Band Gap")
    ax1.legend()

    # Right: absolute error per material against the overall test MAE.
    errors = [abs(t - p) for t, p in zip(true_vals, pred_vals)]
    ax2.bar(labels, errors, color="#38bdf8", edgecolor="#0284c7")
    ax2.axhline(y=test_mae, color="red", linestyle="--",
                label=f"Test MAE ({test_mae} eV)")
    ax2.set_xlabel("Material")
    ax2.set_ylabel("Absolute Error (eV)")
    ax2.set_title("Prediction Error by Material")
    ax2.legend()
    # Rotate this subplot's tick labels explicitly. plt.xticks() would act on
    # whichever axes pyplot currently considers "current".
    plt.setp(ax2.get_xticklabels(), rotation=15, ha="right")

    fig.tight_layout()
    return fig
# Sample crystal descriptions offered in the Demo tab. Each example is wrapped
# in its own one-element list: one value per example, for the single
# description textbox the examples are attached to.
_EXAMPLE_DESCRIPTIONS = (
    "Silicon dioxide crystallizes in the orthorhombic structure. "
    "Si-O bond lengths range from 1.6 to 1.7 Angstroms. "
    "Bond angles are approximately 109 degrees.",
    "Copper crystallizes in the face-centered cubic structure. "
    "Cu-Cu bond lengths are 2.55 Angstroms. "
    "Metallic bonding.",
    "Zinc sulfide crystallizes in the cubic zinc blende structure. "
    "Zn-S bond lengths are 2.34 Angstroms. "
    "Coordination number is 4.",
)
examples = [[description] for description in _EXAMPLE_DESCRIPTIONS]
# Markdown bodies are kept flush-left: lines indented by four or more spaces
# would render as markdown code blocks.
_HEADER_MD = """
# LLM-Prop: Crystal Band Gap Predictor
**Predicting material band gaps from natural language crystal structure descriptions using T5 Transformer**
"""

_RESULTS_MD = """
| Metric | Value |
|---|---|
| Test MAE | 0.6678 eV |
| Best Val MAE | 0.6393 eV |
| Training epochs | 15 |
| Training samples | 125,098 |
| Model parameters | 35.3M |
"""

_ABOUT_MD = """
## About LLM-Prop
**LLM-Prop** is a deep learning system that predicts the band gap of crystalline
materials from natural language descriptions of their crystal structures.
### Model Architecture
- **Base model:** T5-small encoder (Google)
- **Parameters:** 35.3 million
- **Input:** Crystal structure text (max 256 tokens)
- **Output:** Band gap in eV (regression)
### Dataset
- **Training:** 125,098 samples
- **Validation:** 9,945 samples
- **Test:** 11,531 samples
### Results
- **Test MAE:** 0.6678 eV
- **Baseline MAE:** ~1.5 eV
- **Improvement over baseline:** ~55%
### Reference
Based on the paper: [LLM-Prop](https://www.nature.com/articles/s41524-025-01536-2)
"""

# Three-tab UI: interactive prediction, static results, and project notes.
with gr.Blocks(theme=gr.themes.Soft(), title="LLM-Prop") as demo:
    gr.Markdown(_HEADER_MD)
    with gr.Tabs():
        with gr.Tab("Demo"):
            gr.Markdown("### Paste a crystal structure description to predict its band gap")
            with gr.Row():
                with gr.Column():
                    inp = gr.Textbox(
                        label="Crystal Structure Description",
                        placeholder="e.g. Silicon dioxide crystallizes in the orthorhombic structure...",
                        lines=6,
                    )
                    btn = gr.Button("Predict Band Gap", variant="primary")
                    gr.Examples(examples=examples, inputs=inp)
                with gr.Column():
                    out_result = gr.Textbox(label="Prediction", lines=2)
                    out_details = gr.Textbox(label="Material Type & Confidence", lines=3)
            # predict() returns (result, details), matching these two outputs.
            btn.click(fn=predict, inputs=inp, outputs=[out_result, out_details])
        with gr.Tab("Results"):
            gr.Markdown("### Model Performance on Test Samples")
            gr.Markdown(_RESULTS_MD)
            # The chart is static, so it is rendered once at build time.
            gr.Plot(value=results_chart())
        with gr.Tab("About"):
            gr.Markdown(_ABOUT_MD)

# Start the server only when this file is run as a script. Importing the
# module (from tests or another script) then builds `demo` without blocking.
# NOTE(review): confirm the Space's start command runs this file as __main__
# (e.g. `python app.py`).
if __name__ == "__main__":
    demo.launch()