Update app.py
Browse files
app.py
CHANGED
|
@@ -1,8 +1,8 @@
|
|
| 1 |
-
|
| 2 |
import os
|
| 3 |
import zipfile
|
| 4 |
import logging
|
| 5 |
import torch
|
|
|
|
| 6 |
import numpy as np
|
| 7 |
from io import BytesIO
|
| 8 |
from PIL import Image
|
|
@@ -10,7 +10,7 @@ from PIL import Image
|
|
| 10 |
import gradio as gr
|
| 11 |
from torch_geometric.data import Data as PyGData
|
| 12 |
import matplotlib
|
| 13 |
-
matplotlib.use('Agg')
|
| 14 |
|
| 15 |
from rdkit import Chem
|
| 16 |
from rdkit.Chem import Draw, AllChem, MolFromSmiles
|
|
@@ -62,10 +62,10 @@ if torch.cuda.is_available():
|
|
| 62 |
# Model Loading
|
| 63 |
# ----------------------------
|
| 64 |
def load_models():
|
| 65 |
-
import numpy.core.multiarray
|
| 66 |
from torch.serialization import add_safe_globals
|
|
|
|
| 67 |
|
| 68 |
-
#
|
| 69 |
add_safe_globals([numpy.core.multiarray.scalar])
|
| 70 |
|
| 71 |
specs = {
|
|
@@ -75,11 +75,10 @@ def load_models():
|
|
| 75 |
}
|
| 76 |
|
| 77 |
models = {}
|
| 78 |
-
|
| 79 |
for name, (path, out_dim) in specs.items():
|
| 80 |
if not os.path.exists(path):
|
| 81 |
if os.path.exists("models.zip"):
|
| 82 |
-
|
| 83 |
with zipfile.ZipFile("models.zip", 'r') as z:
|
| 84 |
z.extractall(".")
|
| 85 |
else:
|
|
@@ -87,29 +86,30 @@ def load_models():
|
|
| 87 |
|
| 88 |
model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)
|
| 89 |
|
| 90 |
-
# 🔧 明确禁用 weights_only 以兼容新 PyTorch 版本
|
| 91 |
try:
|
| 92 |
state = torch.load(path, map_location=device, weights_only=False)
|
| 93 |
except TypeError:
|
| 94 |
-
# 兼容旧版本 PyTorch(无 weights_only 参数)
|
| 95 |
state = torch.load(path, map_location=device)
|
| 96 |
|
| 97 |
state_dict = state.get("model_state_dict", state)
|
| 98 |
model.load_state_dict(state_dict)
|
| 99 |
model.eval().to(device)
|
| 100 |
models[name] = model
|
| 101 |
-
|
| 102 |
|
| 103 |
return models
|
| 104 |
|
| 105 |
-
|
| 106 |
models = load_models()
|
| 107 |
|
| 108 |
# ----------------------------
|
| 109 |
# Prediction Function
|
| 110 |
# ----------------------------
|
| 111 |
def predict_all(smiles: str):
|
| 112 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
| 113 |
atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
|
| 114 |
x = torch.tensor(atom_feats, dtype=torch.float)
|
| 115 |
edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
|
|
@@ -117,17 +117,26 @@ def predict_all(smiles: str):
|
|
| 117 |
data = PyGData(x=x, edge_index=edge_index, edge_attr=edge_attr,
|
| 118 |
smiles=[smiles], batch=torch.zeros(x.size(0), dtype=torch.long))
|
| 119 |
|
| 120 |
-
|
|
|
|
|
|
|
| 121 |
for name in ["Elastic", "Plastic", "Brittle"]:
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 131 |
|
| 132 |
# ----------------------------
|
| 133 |
# Molecule Builder Utilities
|
|
@@ -201,31 +210,31 @@ def update_bonds_list(mol):
|
|
| 201 |
# ----------------------------
|
| 202 |
with gr.Blocks(title="CrystalGAT", css="""
|
| 203 |
.gradio-container {max-width:800px; margin:auto}
|
| 204 |
-
.block {padding:1em}
|
| 205 |
.gr-button {margin:0.2em}
|
| 206 |
""") as demo:
|
|
|
|
| 207 |
gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")
|
| 208 |
|
| 209 |
with gr.Tab("SMILES Input"):
|
| 210 |
smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
|
| 211 |
predict1 = gr.Button("Predict")
|
| 212 |
-
|
| 213 |
-
with gr.Tab("Manual
|
| 214 |
state = gr.State(init_molecule())
|
| 215 |
status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")
|
| 216 |
-
|
| 217 |
with gr.Row():
|
| 218 |
-
with gr.Column(
|
| 219 |
atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
|
| 220 |
add_a = gr.Button("Add Atom")
|
| 221 |
atom_tbl = gr.Dataframe(headers=["ID","Type"], datatype=["number","str"], interactive=False)
|
| 222 |
-
with gr.Column(
|
| 223 |
a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
|
| 224 |
a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
|
| 225 |
bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
|
| 226 |
add_b = gr.Button("Add Bond")
|
| 227 |
bond_tbl = gr.Dataframe(headers=["Atom1","Atom2","Type"], datatype=["str","str","str"], interactive=False)
|
| 228 |
-
|
| 229 |
with gr.Row():
|
| 230 |
clear = gr.Button("Clear All")
|
| 231 |
make = gr.Button("Generate SMILES")
|
|
@@ -234,7 +243,7 @@ with gr.Blocks(title="CrystalGAT", css="""
|
|
| 234 |
|
| 235 |
predict2 = gr.Button("Predict on Built Molecule")
|
| 236 |
|
| 237 |
-
#
|
| 238 |
with gr.Row():
|
| 239 |
e_txt = gr.Text(label="Elastic")
|
| 240 |
e_img = gr.Image(type="pil", label="Elastic Attention")
|
|
|
|
|
|
|
| 1 |
import os
|
| 2 |
import zipfile
|
| 3 |
import logging
|
| 4 |
import torch
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
import numpy as np
|
| 7 |
from io import BytesIO
|
| 8 |
from PIL import Image
|
|
|
|
| 10 |
import gradio as gr
|
| 11 |
from torch_geometric.data import Data as PyGData
|
| 12 |
import matplotlib
|
| 13 |
+
matplotlib.use('Agg')
|
| 14 |
|
| 15 |
from rdkit import Chem
|
| 16 |
from rdkit.Chem import Draw, AllChem, MolFromSmiles
|
|
|
|
| 62 |
# Model Loading
|
| 63 |
# ----------------------------
|
| 64 |
def load_models():
|
|
|
|
| 65 |
from torch.serialization import add_safe_globals
|
| 66 |
+
import numpy.core.multiarray
|
| 67 |
|
| 68 |
+
# allow safe numpy objects if needed
|
| 69 |
add_safe_globals([numpy.core.multiarray.scalar])
|
| 70 |
|
| 71 |
specs = {
|
|
|
|
| 75 |
}
|
| 76 |
|
| 77 |
models = {}
|
|
|
|
| 78 |
for name, (path, out_dim) in specs.items():
|
| 79 |
if not os.path.exists(path):
|
| 80 |
if os.path.exists("models.zip"):
|
| 81 |
+
logger.info("Extracting models.zip...")
|
| 82 |
with zipfile.ZipFile("models.zip", 'r') as z:
|
| 83 |
z.extractall(".")
|
| 84 |
else:
|
|
|
|
| 86 |
|
| 87 |
model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)
|
| 88 |
|
|
|
|
| 89 |
try:
|
| 90 |
state = torch.load(path, map_location=device, weights_only=False)
|
| 91 |
except TypeError:
|
|
|
|
| 92 |
state = torch.load(path, map_location=device)
|
| 93 |
|
| 94 |
state_dict = state.get("model_state_dict", state)
|
| 95 |
model.load_state_dict(state_dict)
|
| 96 |
model.eval().to(device)
|
| 97 |
models[name] = model
|
| 98 |
+
logger.info(f"{name} model loaded successfully.")
|
| 99 |
|
| 100 |
return models
|
| 101 |
|
|
|
|
| 102 |
models = load_models()
|
| 103 |
|
| 104 |
# ----------------------------
|
| 105 |
# Prediction Function
|
| 106 |
# ----------------------------
|
| 107 |
def predict_all(smiles: str):
|
| 108 |
+
"""
|
| 109 |
+
Run predictions for Elastic, Plastic, Brittle.
|
| 110 |
+
Use threshold 0.5 for Elastic/Brittle, 0.3 for Plastic.
|
| 111 |
+
Return (text, PIL image) for each.
|
| 112 |
+
"""
|
| 113 |
atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
|
| 114 |
x = torch.tensor(atom_feats, dtype=torch.float)
|
| 115 |
edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
|
|
|
|
| 117 |
data = PyGData(x=x, edge_index=edge_index, edge_attr=edge_attr,
|
| 118 |
smiles=[smiles], batch=torch.zeros(x.size(0), dtype=torch.long))
|
| 119 |
|
| 120 |
+
outputs = []
|
| 121 |
+
thresholds = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}
|
| 122 |
+
|
| 123 |
for name in ["Elastic", "Plastic", "Brittle"]:
|
| 124 |
+
model = models[name]
|
| 125 |
+
with torch.no_grad():
|
| 126 |
+
logits = model(data)
|
| 127 |
+
# assume binary classification: two outputs
|
| 128 |
+
if logits.dim() == 1 or logits.size(1) == 1:
|
| 129 |
+
prob = torch.sigmoid(logits).item()
|
| 130 |
+
else:
|
| 131 |
+
prob = F.softmax(logits, dim=1)[0, 1].item()
|
| 132 |
+
label = int(prob >= thresholds[name])
|
| 133 |
+
# get visualization buffer
|
| 134 |
+
buf, _ = visualize_single_molecule(model, data, device, name)
|
| 135 |
+
img = Image.open(buf) if buf else None
|
| 136 |
+
outputs.append((f"{name}: {label} (p={prob:.2f})", img))
|
| 137 |
+
|
| 138 |
+
# flatten to 6 outputs
|
| 139 |
+
return (*outputs[0], *outputs[1], *outputs[2])
|
| 140 |
|
| 141 |
# ----------------------------
|
| 142 |
# Molecule Builder Utilities
|
|
|
|
| 210 |
# ----------------------------
|
| 211 |
with gr.Blocks(title="CrystalGAT", css="""
|
| 212 |
.gradio-container {max-width:800px; margin:auto}
|
|
|
|
| 213 |
.gr-button {margin:0.2em}
|
| 214 |
""") as demo:
|
| 215 |
+
|
| 216 |
gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")
|
| 217 |
|
| 218 |
with gr.Tab("SMILES Input"):
|
| 219 |
smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
|
| 220 |
predict1 = gr.Button("Predict")
|
| 221 |
+
|
| 222 |
+
with gr.Tab("Manual Molecule Construction"):
|
| 223 |
state = gr.State(init_molecule())
|
| 224 |
status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")
|
| 225 |
+
|
| 226 |
with gr.Row():
|
| 227 |
+
with gr.Column():
|
| 228 |
atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
|
| 229 |
add_a = gr.Button("Add Atom")
|
| 230 |
atom_tbl = gr.Dataframe(headers=["ID","Type"], datatype=["number","str"], interactive=False)
|
| 231 |
+
with gr.Column():
|
| 232 |
a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
|
| 233 |
a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
|
| 234 |
bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
|
| 235 |
add_b = gr.Button("Add Bond")
|
| 236 |
bond_tbl = gr.Dataframe(headers=["Atom1","Atom2","Type"], datatype=["str","str","str"], interactive=False)
|
| 237 |
+
|
| 238 |
with gr.Row():
|
| 239 |
clear = gr.Button("Clear All")
|
| 240 |
make = gr.Button("Generate SMILES")
|
|
|
|
| 243 |
|
| 244 |
predict2 = gr.Button("Predict on Built Molecule")
|
| 245 |
|
| 246 |
+
# Outputs
|
| 247 |
with gr.Row():
|
| 248 |
e_txt = gr.Text(label="Elastic")
|
| 249 |
e_img = gr.Image(type="pil", label="Elastic Attention")
|