Upload app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
import torch
|
| 3 |
+
from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
|
| 4 |
+
import os
|
| 5 |
+
from slices.core import SLICES
|
| 6 |
+
from pymatgen.core.structure import Structure
|
| 7 |
+
from pymatgen.io.cif import CifWriter
|
| 8 |
+
from pymatgen.io.ase import AseAtomsAdaptor
|
| 9 |
+
from ase.io import write as ase_write
|
| 10 |
+
import tempfile
|
| 11 |
+
import time
|
| 12 |
+
# 设置PyTorch使用的线程数
|
| 13 |
+
torch.set_num_threads(2)
|
| 14 |
+
def load_quantized_model(model_path):
|
| 15 |
+
model = MatterGPTWrapper.from_pretrained(model_path)
|
| 16 |
+
model.to('cpu')
|
| 17 |
+
model.eval()
|
| 18 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
| 19 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
| 20 |
+
)
|
| 21 |
+
return quantized_model
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# Load and quantize the model
|
| 25 |
+
model_path = "./"
|
| 26 |
+
quantized_model = load_quantized_model(model_path)
|
| 27 |
+
quantized_model.to("cpu")
|
| 28 |
+
quantized_model.eval()
|
| 29 |
+
# Load the tokenizer
|
| 30 |
+
tokenizer_path = "Voc_prior"
|
| 31 |
+
tokenizer = SimpleTokenizer(tokenizer_path)
|
| 32 |
+
|
| 33 |
+
# Initialize SLICES backend
|
| 34 |
+
try:
|
| 35 |
+
backend = SLICES(relax_model="chgnet",fmax=0.4,steps=25)
|
| 36 |
+
|
| 37 |
+
except Exception as e:
|
| 38 |
+
backend = SLICES(relax_model=None)
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
|
| 43 |
+
condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
|
| 44 |
+
context = '>'
|
| 45 |
+
x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
|
| 46 |
+
|
| 47 |
+
with torch.no_grad():
|
| 48 |
+
generated = quantized_model.generate(x, prop=condition, max_length=max_length,
|
| 49 |
+
temperature=temperature, do_sample=do_sample,
|
| 50 |
+
top_k=top_k, top_p=top_p)
|
| 51 |
+
|
| 52 |
+
return tokenizer.decode(generated[0].tolist())
|
| 53 |
+
|
| 54 |
+
def generate_slices(formation_energy, band_gap):
|
| 55 |
+
return generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap,
|
| 56 |
+
quantized_model.config.block_size, 1.2, True, 0, 0.9)
|
| 57 |
+
def wrap_structure(structure):
|
| 58 |
+
"""Wrap all atoms back into the unit cell."""
|
| 59 |
+
for i, site in enumerate(structure):
|
| 60 |
+
frac_coords = site.frac_coords % 1.0
|
| 61 |
+
structure.replace(i, species=site.species, coords=frac_coords, coords_are_cartesian=False)
|
| 62 |
+
return structure
|
| 63 |
+
|
| 64 |
+
def convert_and_visualize(slices_string):
|
| 65 |
+
try:
|
| 66 |
+
structure, energy = backend.SLICES2structure(slices_string)
|
| 67 |
+
|
| 68 |
+
# Wrap atoms back into the unit cell
|
| 69 |
+
structure = wrap_structure(structure)
|
| 70 |
+
|
| 71 |
+
# Generate CIF and save to temporary file
|
| 72 |
+
cif_file = tempfile.NamedTemporaryFile(mode='w', suffix='.cif', delete=False)
|
| 73 |
+
cif_writer = CifWriter(structure)
|
| 74 |
+
cif_writer.write_file(cif_file.name)
|
| 75 |
+
|
| 76 |
+
# Generate structure summary
|
| 77 |
+
summary = f"Formula: {structure.composition.reduced_formula}\n"
|
| 78 |
+
summary += f"Number of sites: {len(structure)}\n"
|
| 79 |
+
summary += f"Lattice parameters: a={structure.lattice.a:.3f}, b={structure.lattice.b:.3f}, c={structure.lattice.c:.3f}\n"
|
| 80 |
+
summary += f"Angles: alpha={structure.lattice.alpha:.2f}, beta={structure.lattice.beta:.2f}, gamma={structure.lattice.gamma:.2f}\n"
|
| 81 |
+
summary += f"Volume: {structure.volume:.3f} ų\n"
|
| 82 |
+
summary += f"Density: {structure.density:.3f} g/cm³"
|
| 83 |
+
|
| 84 |
+
# Generate structure image using ASE and save to temporary file
|
| 85 |
+
atoms = AseAtomsAdaptor.get_atoms(structure)
|
| 86 |
+
image_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
|
| 87 |
+
ase_write(image_file.name, atoms, format='png', rotation='10x,10y,10z')
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
return cif_file.name, image_file.name, summary, f"Conversion successful. Energy: {energy:.4f} eV/atom", True
|
| 91 |
+
except Exception as e:
|
| 92 |
+
|
| 93 |
+
return "", "", "", f"Conversion failed. Error: {str(e)}", False
|
| 94 |
+
|
| 95 |
+
def generate_and_convert(formation_energy, band_gap):
|
| 96 |
+
max_attempts = 5
|
| 97 |
+
start_time = time.time()
|
| 98 |
+
max_time = 300 # 5 minutes maximum execution time
|
| 99 |
+
|
| 100 |
+
for attempt in range(max_attempts):
|
| 101 |
+
if time.time() - start_time > max_time:
|
| 102 |
+
return "Exceeded maximum execution time", "", "", "", "Generation and conversion failed due to timeout"
|
| 103 |
+
|
| 104 |
+
slices_string = generate_slices(formation_energy, band_gap)
|
| 105 |
+
cif_file, image_file, structure_summary, status, success = convert_and_visualize(slices_string)
|
| 106 |
+
|
| 107 |
+
if success:
|
| 108 |
+
return slices_string, cif_file, image_file, structure_summary, f"Successful on attempt {attempt + 1}: {status}"
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
if attempt == max_attempts - 1:
|
| 112 |
+
return slices_string, "", "", "", f"Failed after {max_attempts} attempts: {status}"
|
| 113 |
+
|
| 114 |
+
return "Failed to generate valid SLICES string", "", "", "", "Generation failed"
|
| 115 |
+
|
| 116 |
+
# Create the Gradio interface
|
| 117 |
+
with gr.Blocks() as iface:
|
| 118 |
+
gr.Markdown("# Crystal Inverse Designer: From Properties to Structures")
|
| 119 |
+
|
| 120 |
+
with gr.Row():
|
| 121 |
+
with gr.Column():
|
| 122 |
+
gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
|
| 123 |
+
gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")
|
| 124 |
+
gr.Markdown("**Allow 1-2 minutes for completion using 2 CPUs.**")
|
| 125 |
+
|
| 126 |
+
with gr.Row():
|
| 127 |
+
with gr.Column(scale=2):
|
| 128 |
+
band_gap = gr.Number(label="Band Gap (eV)", value=2.0)
|
| 129 |
+
formation_energy = gr.Number(label="Formation Energy (eV/atom)", value=-1.0)
|
| 130 |
+
generate_button = gr.Button("Generate")
|
| 131 |
+
|
| 132 |
+
with gr.Column(scale=3):
|
| 133 |
+
slices_output = gr.Textbox(label="Generated SLICES String")
|
| 134 |
+
cif_output = gr.File(label="Download CIF", file_types=[".cif"])
|
| 135 |
+
structure_image = gr.Image(label="Structure Visualization")
|
| 136 |
+
structure_summary = gr.Textbox(label="Structure Summary", lines=6)
|
| 137 |
+
conversion_status = gr.Textbox(label="Conversion Status")
|
| 138 |
+
|
| 139 |
+
generate_button.click(
|
| 140 |
+
generate_and_convert,
|
| 141 |
+
inputs=[formation_energy, band_gap],
|
| 142 |
+
outputs=[slices_output, cif_output, structure_image, structure_summary, conversion_status]
|
| 143 |
+
)
|
| 144 |
+
|
| 145 |
+
iface.launch(share=True)
|