"""LLM-Prop Gradio demo.

Predicts the band gap (in eV) of a crystalline material from a natural
language description of its crystal structure, using a fine-tuned T5
encoder (see https://www.nature.com/articles/s41524-025-01536-2).
"""

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import torch
from transformers import AutoTokenizer, T5EncoderModel

from llmprop_model import T5Predictor
from llmprop_utils import replace_bond_lengths_with_num, replace_bond_angles_with_ang

device = torch.device("cpu")  # HuggingFace free tier is CPU only

# Tokenizer must carry the same extra special tokens the checkpoint was
# trained with ([CLS] pooling token, [NUM]/[ANG] numeric placeholders).
tokenizer = AutoTokenizer.from_pretrained(
    "tokenizers/t5_tokenizer_trained_on_modified_part_of_C4_and_textedge"
)
tokenizer.add_tokens(["[CLS]", "[NUM]", "[ANG]"])

base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
# Embedding table must match the enlarged vocabulary before loading weights.
base_model.resize_token_embeddings(len(tokenizer))

model = T5Predictor(base_model, 512, drop_rate=0.2, pooling="cls")
# strict=False: tolerate key mismatches between the checkpoint and this
# wrapper (NOTE(review): silently ignores missing/unexpected keys — confirm
# the checkpoint actually covers every trainable parameter).
model.load_state_dict(
    torch.load("best_checkpoint_for_band_gap.pt", map_location=device),
    strict=False,
)
model.to(device)
model.eval()

# Band-gap normalization statistics from the training set; model outputs are
# z-scores that must be de-normalized back to eV.
TRAIN_MEAN = 1.0258
TRAIN_STD = 1.5106


def predict(description):
    """Predict the band gap for one crystal-structure description.

    Args:
        description: Free-text crystal structure description (may be
            None or empty if the textbox was cleared).

    Returns:
        A ``(result, details)`` pair of display strings: the predicted
        band gap in eV, and a coarse material-type classification with a
        confidence label.
    """
    # Guard both None (cleared Gradio textbox) and whitespace-only input;
    # the original `description.strip()` raised AttributeError on None.
    if not description or not description.strip():
        return "Please enter a crystal description.", ""

    # Apply the same preprocessing as training: replace numeric bond
    # lengths / bond angles with the [NUM] / [ANG] special tokens.
    text = replace_bond_lengths_with_num(description)
    text = replace_bond_angles_with_ang(text)

    encoded = tokenizer(
        "[CLS] " + text,  # explicit [CLS] prefix for "cls" pooling
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt",
    )
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    with torch.no_grad():
        _, prediction = model(input_ids, attention_mask)

    # De-normalize the z-score back to eV and clamp: a physical band gap
    # cannot be negative.
    band_gap = (prediction.squeeze().cpu().item() * TRAIN_STD) + TRAIN_MEAN
    band_gap = max(0.0, band_gap)

    # Coarse classification by conventional band-gap ranges.
    if band_gap < 0.1:
        material_type = "Metal (zero or near-zero band gap)"
        confidence = "High"
    elif band_gap < 1.0:
        material_type = "Narrow gap semiconductor"
        confidence = "Medium"
    elif band_gap < 3.0:
        material_type = "Semiconductor"
        confidence = "Medium"
    else:
        material_type = "Wide gap semiconductor / Insulator"
        confidence = "Low"

    result = f"Predicted Band Gap: {band_gap:.4f} eV"
    details = f"Material Type: {material_type}\nConfidence: {confidence}"
    return result, details


def results_chart():
    """Build the two-panel performance figure for the Results tab.

    Left panel: predicted vs. true band gap for five test materials with
    the y = x reference line. Right panel: per-material absolute error
    against the overall test MAE.

    Returns:
        The populated :class:`matplotlib.figure.Figure`.
    """
    true_vals = [0.00, 0.00, 2.68, 2.62, 0.43]
    pred_vals = [0.00, 0.02, 1.31, 1.55, 0.01]
    labels = ["Sc2Co3Al", "PdCu11N4", "AgClO4", "Mg3Si4(BiO7)2", "LiCuSO4F"]

    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 5))

    # Scatter of predictions vs. ground truth.
    ax1.scatter(true_vals, pred_vals, color="#38bdf8", s=100, zorder=5)
    max_val = max(max(true_vals), max(pred_vals)) + 0.5
    ax1.plot([0, max_val], [0, max_val], "r--", label="Perfect prediction")
    for i, label in enumerate(labels):
        ax1.annotate(
            label,
            (true_vals[i], pred_vals[i]),
            textcoords="offset points",
            xytext=(5, 5),
            fontsize=8,
        )
    ax1.set_xlabel("True Band Gap (eV)")
    ax1.set_ylabel("Predicted Band Gap (eV)")
    ax1.set_title("Predicted vs True Band Gap")
    ax1.legend()

    # Per-material absolute error with the test-set MAE as reference.
    errors = [abs(t - p) for t, p in zip(true_vals, pred_vals)]
    ax2.bar(labels, errors, color="#38bdf8", edgecolor="#0284c7")
    ax2.axhline(y=0.6678, color="red", linestyle="--", label="Test MAE (0.6678 eV)")
    ax2.set_xlabel("Material")
    ax2.set_ylabel("Absolute Error (eV)")
    ax2.set_title("Prediction Error by Material")
    ax2.legend()

    # Operate on this figure's artists directly instead of pyplot's
    # "current axes/figure" global state (plt.xticks / plt.tight_layout),
    # which targets whatever figure happens to be current and is fragile
    # inside a long-lived server process.
    plt.setp(ax2.get_xticklabels(), rotation=15, ha="right")
    fig.tight_layout()
    return fig


# Example descriptions shown under the input textbox.
examples = [
    ["Silicon dioxide crystallizes in the orthorhombic structure. Si-O bond lengths range from 1.6 to 1.7 Angstroms. Bond angles are approximately 109 degrees."],
    ["Copper crystallizes in the face-centered cubic structure. Cu-Cu bond lengths are 2.55 Angstroms. Metallic bonding."],
    ["Zinc sulfide crystallizes in the cubic zinc blende structure. Zn-S bond lengths are 2.34 Angstroms. Coordination number is 4."],
]

with gr.Blocks(theme=gr.themes.Soft(), title="LLM-Prop") as demo:
    gr.Markdown("""
    # LLM-Prop: Crystal Band Gap Predictor
    **Predicting material band gaps from natural language crystal structure descriptions using T5 Transformer**
    """)
    with gr.Tabs():
        with gr.Tab("Demo"):
            gr.Markdown("### Paste a crystal structure description to predict its band gap")
            with gr.Row():
                with gr.Column():
                    inp = gr.Textbox(
                        label="Crystal Structure Description",
                        placeholder="e.g. Silicon dioxide crystallizes in the orthorhombic structure...",
                        lines=6,
                    )
                    btn = gr.Button("Predict Band Gap", variant="primary")
                    gr.Examples(examples=examples, inputs=inp)
                with gr.Column():
                    out_result = gr.Textbox(label="Prediction", lines=2)
                    out_details = gr.Textbox(label="Material Type & Confidence", lines=3)
            btn.click(fn=predict, inputs=inp, outputs=[out_result, out_details])
        with gr.Tab("Results"):
            gr.Markdown("### Model Performance on Test Samples")
            gr.Markdown("""
            | Metric | Value |
            |---|---|
            | Test MAE | 0.6678 eV |
            | Best Val MAE | 0.6393 eV |
            | Training epochs | 15 |
            | Training samples | 125,098 |
            | Model parameters | 35.3M |
            """)
            gr.Plot(value=results_chart())
        with gr.Tab("About"):
            gr.Markdown("""
            ## About LLM-Prop

            **LLM-Prop** is a deep learning system that predicts the band gap of crystalline materials from natural language descriptions of their crystal structures.

            ### Model Architecture
            - **Base model:** T5-small encoder (Google)
            - **Parameters:** 35.3 million
            - **Input:** Crystal structure text (max 256 tokens)
            - **Output:** Band gap in eV (regression)

            ### Dataset
            - **Training:** 125,098 samples
            - **Validation:** 9,945 samples
            - **Test:** 11,531 samples

            ### Results
            - **Test MAE:** 0.6678 eV
            - **Baseline MAE:** ~1.5 eV
            - **Improvement over baseline:** ~55%

            ### Reference
            Based on the paper: [LLM-Prop](https://www.nature.com/articles/s41524-025-01536-2)
            """)

# Launch only when executed as a script; importing this module (e.g. from
# tests) should not start a server.
if __name__ == "__main__":
    demo.launch()