xiEcho commited on
Commit
9bc705c
·
verified ·
1 Parent(s): 0393dd6

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +145 -0
app.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import torch
3
+ from mattergpt_wrapper import MatterGPTWrapper, SimpleTokenizer
4
+ import os
5
+ from slices.core import SLICES
6
+ from pymatgen.core.structure import Structure
7
+ from pymatgen.io.cif import CifWriter
8
+ from pymatgen.io.ase import AseAtomsAdaptor
9
+ from ase.io import write as ase_write
10
+ import tempfile
11
+ import time
12
+ # 设置PyTorch使用的线程数
13
+ torch.set_num_threads(2)
14
+ def load_quantized_model(model_path):
15
+ model = MatterGPTWrapper.from_pretrained(model_path)
16
+ model.to('cpu')
17
+ model.eval()
18
+ quantized_model = torch.quantization.quantize_dynamic(
19
+ model, {torch.nn.Linear}, dtype=torch.qint8
20
+ )
21
+ return quantized_model
22
+
23
+
24
+ # Load and quantize the model
25
+ model_path = "./"
26
+ quantized_model = load_quantized_model(model_path)
27
+ quantized_model.to("cpu")
28
+ quantized_model.eval()
29
+ # Load the tokenizer
30
+ tokenizer_path = "Voc_prior"
31
+ tokenizer = SimpleTokenizer(tokenizer_path)
32
+
33
+ # Initialize SLICES backend
34
+ try:
35
+ backend = SLICES(relax_model="chgnet",fmax=0.4,steps=25)
36
+
37
+ except Exception as e:
38
+ backend = SLICES(relax_model=None)
39
+
40
+
41
+
42
+ def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
43
+ condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
44
+ context = '>'
45
+ x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
46
+
47
+ with torch.no_grad():
48
+ generated = quantized_model.generate(x, prop=condition, max_length=max_length,
49
+ temperature=temperature, do_sample=do_sample,
50
+ top_k=top_k, top_p=top_p)
51
+
52
+ return tokenizer.decode(generated[0].tolist())
53
+
54
+ def generate_slices(formation_energy, band_gap):
55
+ return generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap,
56
+ quantized_model.config.block_size, 1.2, True, 0, 0.9)
57
+ def wrap_structure(structure):
58
+ """Wrap all atoms back into the unit cell."""
59
+ for i, site in enumerate(structure):
60
+ frac_coords = site.frac_coords % 1.0
61
+ structure.replace(i, species=site.species, coords=frac_coords, coords_are_cartesian=False)
62
+ return structure
63
+
64
+ def convert_and_visualize(slices_string):
65
+ try:
66
+ structure, energy = backend.SLICES2structure(slices_string)
67
+
68
+ # Wrap atoms back into the unit cell
69
+ structure = wrap_structure(structure)
70
+
71
+ # Generate CIF and save to temporary file
72
+ cif_file = tempfile.NamedTemporaryFile(mode='w', suffix='.cif', delete=False)
73
+ cif_writer = CifWriter(structure)
74
+ cif_writer.write_file(cif_file.name)
75
+
76
+ # Generate structure summary
77
+ summary = f"Formula: {structure.composition.reduced_formula}\n"
78
+ summary += f"Number of sites: {len(structure)}\n"
79
+ summary += f"Lattice parameters: a={structure.lattice.a:.3f}, b={structure.lattice.b:.3f}, c={structure.lattice.c:.3f}\n"
80
+ summary += f"Angles: alpha={structure.lattice.alpha:.2f}, beta={structure.lattice.beta:.2f}, gamma={structure.lattice.gamma:.2f}\n"
81
+ summary += f"Volume: {structure.volume:.3f} ų\n"
82
+ summary += f"Density: {structure.density:.3f} g/cm³"
83
+
84
+ # Generate structure image using ASE and save to temporary file
85
+ atoms = AseAtomsAdaptor.get_atoms(structure)
86
+ image_file = tempfile.NamedTemporaryFile(suffix='.png', delete=False)
87
+ ase_write(image_file.name, atoms, format='png', rotation='10x,10y,10z')
88
+
89
+
90
+ return cif_file.name, image_file.name, summary, f"Conversion successful. Energy: {energy:.4f} eV/atom", True
91
+ except Exception as e:
92
+
93
+ return "", "", "", f"Conversion failed. Error: {str(e)}", False
94
+
95
+ def generate_and_convert(formation_energy, band_gap):
96
+ max_attempts = 5
97
+ start_time = time.time()
98
+ max_time = 300 # 5 minutes maximum execution time
99
+
100
+ for attempt in range(max_attempts):
101
+ if time.time() - start_time > max_time:
102
+ return "Exceeded maximum execution time", "", "", "", "Generation and conversion failed due to timeout"
103
+
104
+ slices_string = generate_slices(formation_energy, band_gap)
105
+ cif_file, image_file, structure_summary, status, success = convert_and_visualize(slices_string)
106
+
107
+ if success:
108
+ return slices_string, cif_file, image_file, structure_summary, f"Successful on attempt {attempt + 1}: {status}"
109
+
110
+
111
+ if attempt == max_attempts - 1:
112
+ return slices_string, "", "", "", f"Failed after {max_attempts} attempts: {status}"
113
+
114
+ return "Failed to generate valid SLICES string", "", "", "", "Generation failed"
115
+
116
+ # Create the Gradio interface
117
+ with gr.Blocks() as iface:
118
+ gr.Markdown("# Crystal Inverse Designer: From Properties to Structures")
119
+
120
+ with gr.Row():
121
+ with gr.Column():
122
+ gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
123
+ gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")
124
+ gr.Markdown("**Allow 1-2 minutes for completion using 2 CPUs.**")
125
+
126
+ with gr.Row():
127
+ with gr.Column(scale=2):
128
+ band_gap = gr.Number(label="Band Gap (eV)", value=2.0)
129
+ formation_energy = gr.Number(label="Formation Energy (eV/atom)", value=-1.0)
130
+ generate_button = gr.Button("Generate")
131
+
132
+ with gr.Column(scale=3):
133
+ slices_output = gr.Textbox(label="Generated SLICES String")
134
+ cif_output = gr.File(label="Download CIF", file_types=[".cif"])
135
+ structure_image = gr.Image(label="Structure Visualization")
136
+ structure_summary = gr.Textbox(label="Structure Summary", lines=6)
137
+ conversion_status = gr.Textbox(label="Conversion Status")
138
+
139
+ generate_button.click(
140
+ generate_and_convert,
141
+ inputs=[formation_energy, band_gap],
142
+ outputs=[slices_output, cif_output, structure_image, structure_summary, conversion_status]
143
+ )
144
+
145
+ iface.launch(share=True)