File size: 6,412 Bytes
752cdaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import torch
from transformers import AutoTokenizer, T5EncoderModel
import matplotlib.pyplot as plt
import numpy as np
from llmprop_model import T5Predictor
from llmprop_utils import replace_bond_lengths_with_num, replace_bond_angles_with_ang

# ---- Module-level inference setup (runs once at import time) ----
device = torch.device("cpu")  # HuggingFace free tier is CPU only

# Custom T5 tokenizer; [CLS]/[NUM]/[ANG] are extra special tokens used by
# the LLM-Prop preprocessing (CLS pooling, masked bond lengths/angles).
tokenizer = AutoTokenizer.from_pretrained("tokenizers/t5_tokenizer_trained_on_modified_part_of_C4_and_textedge")
tokenizer.add_tokens(["[CLS]", "[NUM]", "[ANG]"])

# Resize embeddings AFTER adding tokens so the new token ids get embedding rows.
base_model = T5EncoderModel.from_pretrained("google/t5-v1_1-small")
base_model.resize_token_embeddings(len(tokenizer))

model = T5Predictor(base_model, 512, drop_rate=0.2, pooling="cls")
# NOTE(review): strict=False silently ignores missing/unexpected checkpoint
# keys — confirm the checkpoint really matches this architecture.
model.load_state_dict(
    torch.load(
        "best_checkpoint_for_band_gap.pt",
        map_location=device
    ),
    strict=False
)
model.to(device)
model.eval()  # disable dropout for deterministic inference

# Mean/std of the training targets, used to de-normalize model outputs (eV).
TRAIN_MEAN = 1.0258
TRAIN_STD  = 1.5106

def predict(description):
    """Predict a crystal's band gap (eV) from a text description.

    Args:
        description: Natural-language crystal structure description, or
            None/empty (Gradio passes None when the textbox is cleared).

    Returns:
        Tuple of (result, details) display strings for the Gradio outputs.
    """
    # Guard None before calling .strip() — a cleared Gradio textbox yields None.
    if not description or not description.strip():
        return "Please enter a crystal description.", ""
    # Mask literal bond lengths/angles with [NUM]/[ANG], matching the
    # preprocessing applied at training time.
    text = replace_bond_lengths_with_num(description)
    text = replace_bond_angles_with_ang(text)
    encoded = tokenizer(
        "[CLS] " + text,
        add_special_tokens=True,
        padding="max_length",
        truncation=True,
        max_length=256,
        return_tensors="pt"
    )
    input_ids      = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)
    with torch.no_grad():
        _, prediction = model(input_ids, attention_mask)
    # De-normalize: the model was trained on z-scored band-gap targets.
    band_gap = (prediction.squeeze().cpu().item() * TRAIN_STD) + TRAIN_MEAN
    band_gap = max(0.0, band_gap)  # band gaps are physically non-negative
    # Coarse material classification from the predicted gap magnitude.
    if band_gap < 0.1:
        material_type = "Metal (zero or near-zero band gap)"
        confidence    = "High"
    elif band_gap < 1.0:
        material_type = "Narrow gap semiconductor"
        confidence    = "Medium"
    elif band_gap < 3.0:
        material_type = "Semiconductor"
        confidence    = "Medium"
    else:
        material_type = "Wide gap semiconductor / Insulator"
        confidence    = "Low"
    result  = f"Predicted Band Gap: {band_gap:.4f} eV"
    details = f"Material Type: {material_type}\nConfidence: {confidence}"
    return result, details

def results_chart():
    """Build the two-panel results figure: predicted-vs-true scatter on the
    left, per-material absolute-error bars (vs test MAE) on the right."""
    labels    = ["Sc2Co3Al", "PdCu11N4", "AgClO4", "Mg3Si4(BiO7)2", "LiCuSO4F"]
    true_vals = [0.00, 0.00, 2.68, 2.62, 0.43]
    pred_vals = [0.00, 0.02, 1.31, 1.55, 0.01]

    fig, (scatter_ax, error_ax) = plt.subplots(1, 2, figsize=(12, 5))

    # Left panel: predicted vs true band gap with the y = x reference line.
    scatter_ax.scatter(true_vals, pred_vals, color="#38bdf8", s=100, zorder=5)
    limit = max(max(true_vals), max(pred_vals)) + 0.5
    scatter_ax.plot([0, limit], [0, limit], "r--", label="Perfect prediction")
    for name, tv, pv in zip(labels, true_vals, pred_vals):
        scatter_ax.annotate(name, (tv, pv),
                            textcoords="offset points", xytext=(5, 5), fontsize=8)
    scatter_ax.set_xlabel("True Band Gap (eV)")
    scatter_ax.set_ylabel("Predicted Band Gap (eV)")
    scatter_ax.set_title("Predicted vs True Band Gap")
    scatter_ax.legend()

    # Right panel: |true - pred| per material against the overall test MAE.
    abs_errors = [abs(tv - pv) for tv, pv in zip(true_vals, pred_vals)]
    error_ax.bar(labels, abs_errors, color="#38bdf8", edgecolor="#0284c7")
    error_ax.axhline(y=0.6678, color="red", linestyle="--", label="Test MAE (0.6678 eV)")
    error_ax.set_xlabel("Material")
    error_ax.set_ylabel("Absolute Error (eV)")
    error_ax.set_title("Prediction Error by Material")
    error_ax.legend()

    # plt.xticks targets the current axes (the bar chart, created last).
    plt.xticks(rotation=15, ha="right")
    plt.tight_layout()
    return fig

# Preset example descriptions shown under the input textbox in the Demo tab.
examples = [
    ["Silicon dioxide crystallizes in the orthorhombic structure. Si-O bond lengths range from 1.6 to 1.7 Angstroms. Bond angles are approximately 109 degrees."],
    ["Copper crystallizes in the face-centered cubic structure. Cu-Cu bond lengths are 2.55 Angstroms. Metallic bonding."],
    ["Zinc sulfide crystallizes in the cubic zinc blende structure. Zn-S bond lengths are 2.34 Angstroms. Coordination number is 4."],
]

# ---- Gradio UI: three tabs (live demo, results summary, project info) ----
with gr.Blocks(theme=gr.themes.Soft(), title="LLM-Prop") as demo:
    gr.Markdown("""
    # LLM-Prop: Crystal Band Gap Predictor
    **Predicting material band gaps from natural language crystal structure descriptions using T5 Transformer**
    """)
    with gr.Tabs():
        # Interactive prediction tab: textbox in, prediction strings out.
        with gr.Tab("Demo"):
            gr.Markdown("### Paste a crystal structure description to predict its band gap")
            with gr.Row():
                with gr.Column():
                    inp = gr.Textbox(
                        label="Crystal Structure Description",
                        placeholder="e.g. Silicon dioxide crystallizes in the orthorhombic structure...",
                        lines=6
                    )
                    btn = gr.Button("Predict Band Gap", variant="primary")
                    gr.Examples(examples=examples, inputs=inp)
                with gr.Column():
                    out_result  = gr.Textbox(label="Prediction", lines=2)
                    out_details = gr.Textbox(label="Material Type & Confidence", lines=3)
            # Wire the button to predict(); outputs map to the two textboxes.
            btn.click(fn=predict, inputs=inp, outputs=[out_result, out_details])
        # Static results tab: metrics table plus a pre-rendered chart.
        with gr.Tab("Results"):
            gr.Markdown("### Model Performance on Test Samples")
            gr.Markdown("""
            | Metric | Value |
            |---|---|
            | Test MAE | 0.6678 eV |
            | Best Val MAE | 0.6393 eV |
            | Training epochs | 15 |
            | Training samples | 125,098 |
            | Model parameters | 35.3M |
            """)
            # Chart is built once at app start, not per request.
            gr.Plot(value=results_chart())
        # Static about tab: architecture, dataset, and paper reference.
        with gr.Tab("About"):
            gr.Markdown("""
            ## About LLM-Prop
            **LLM-Prop** is a deep learning system that predicts the band gap of crystalline
            materials from natural language descriptions of their crystal structures.
            ### Model Architecture
            - **Base model:** T5-small encoder (Google)
            - **Parameters:** 35.3 million
            - **Input:** Crystal structure text (max 256 tokens)
            - **Output:** Band gap in eV (regression)
            ### Dataset
            - **Training:** 125,098 samples
            - **Validation:** 9,945 samples
            - **Test:** 11,531 samples
            ### Results
            - **Test MAE:** 0.6678 eV
            - **Baseline MAE:** ~1.5 eV
            - **Improvement over baseline:** ~55%
            ### Reference
            Based on the paper: [LLM-Prop](https://www.nature.com/articles/s41524-025-01536-2)
            """)

demo.launch()