Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
|
@@ -9,13 +9,23 @@ from pymatgen.io.ase import AseAtomsAdaptor
|
|
| 9 |
from ase.io import write as ase_write
|
| 10 |
import tempfile
|
| 11 |
import time
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
-
# Load the model
|
| 15 |
model_path = "./"
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
# Load the tokenizer
|
| 20 |
tokenizer_path = "Voc_prior"
|
| 21 |
tokenizer = SimpleTokenizer(tokenizer_path)
|
|
@@ -27,22 +37,23 @@ try:
|
|
| 27 |
except Exception as e:
|
| 28 |
backend = SLICES(relax_model=None)
|
| 29 |
|
| 30 |
-
|
|
|
|
|
|
|
| 31 |
condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
|
| 32 |
context = '>'
|
| 33 |
x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
|
| 34 |
|
| 35 |
with torch.no_grad():
|
| 36 |
-
generated =
|
| 37 |
-
|
| 38 |
-
|
| 39 |
|
| 40 |
return tokenizer.decode(generated[0].tolist())
|
| 41 |
|
| 42 |
def generate_slices(formation_energy, band_gap):
|
| 43 |
-
return
|
| 44 |
-
|
| 45 |
-
|
| 46 |
def wrap_structure(structure):
|
| 47 |
"""Wrap all atoms back into the unit cell."""
|
| 48 |
for i, site in enumerate(structure):
|
|
@@ -109,7 +120,7 @@ with gr.Blocks() as iface:
|
|
| 109 |
with gr.Row():
|
| 110 |
with gr.Column():
|
| 111 |
gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
|
| 112 |
-
gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure.**")
|
| 113 |
|
| 114 |
with gr.Row():
|
| 115 |
with gr.Column(scale=2):
|
|
|
|
| 9 |
from ase.io import write as ase_write
|
| 10 |
import tempfile
|
| 11 |
import time
|
| 12 |
+
# 设置PyTorch使用的线程数
|
| 13 |
+
torch.set_num_threads(2)
|
| 14 |
+
def load_quantized_model(model_path):
|
| 15 |
+
model = MatterGPTWrapper.from_pretrained(model_path)
|
| 16 |
+
model.to('cpu')
|
| 17 |
+
model.eval()
|
| 18 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
| 19 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
| 20 |
+
)
|
| 21 |
+
return quantized_model
|
| 22 |
|
| 23 |
|
| 24 |
+
# Load and quantize the model
|
| 25 |
model_path = "./"
|
| 26 |
+
quantized_model = load_quantized_model(model_path)
|
| 27 |
+
quantized_model.to("cpu")
|
| 28 |
+
quantized_model.eval()
|
| 29 |
# Load the tokenizer
|
| 30 |
tokenizer_path = "Voc_prior"
|
| 31 |
tokenizer = SimpleTokenizer(tokenizer_path)
|
|
|
|
| 37 |
except Exception as e:
|
| 38 |
backend = SLICES(relax_model=None)
|
| 39 |
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap, max_length, temperature, do_sample, top_k, top_p):
|
| 43 |
condition = torch.tensor([[float(formation_energy), float(band_gap)]], dtype=torch.float32)
|
| 44 |
context = '>'
|
| 45 |
x = torch.tensor([[tokenizer.stoi[context]]], dtype=torch.long)
|
| 46 |
|
| 47 |
with torch.no_grad():
|
| 48 |
+
generated = quantized_model.generate(x, prop=condition, max_length=max_length,
|
| 49 |
+
temperature=temperature, do_sample=do_sample,
|
| 50 |
+
top_k=top_k, top_p=top_p)
|
| 51 |
|
| 52 |
return tokenizer.decode(generated[0].tolist())
|
| 53 |
|
| 54 |
def generate_slices(formation_energy, band_gap):
|
| 55 |
+
return generate_slices_quantized(quantized_model, tokenizer, formation_energy, band_gap,
|
| 56 |
+
quantized_model.config.block_size, 1.2, True, 0, 0.9)
|
|
|
|
| 57 |
def wrap_structure(structure):
|
| 58 |
"""Wrap all atoms back into the unit cell."""
|
| 59 |
for i, site in enumerate(structure):
|
|
|
|
| 120 |
with gr.Row():
|
| 121 |
with gr.Column():
|
| 122 |
gr.Image("Figure1.png", label="De novo crystal generation by MatterGPT targeting desired Eg, Ef", width=1000, height=300)
|
| 123 |
+
gr.Markdown("**Enter desired properties to inversely design materials (encoded in SLICES), then decode it into crystal structure. Take 1-2 mins to finish with 2 cpus**")
|
| 124 |
|
| 125 |
with gr.Row():
|
| 126 |
with gr.Column(scale=2):
|