QQ2S3R commited on
Commit
793eee2
·
verified ·
1 Parent(s): b9eda4c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +38 -29
app.py CHANGED
@@ -1,8 +1,8 @@
1
-
2
  import os
3
  import zipfile
4
  import logging
5
  import torch
 
6
  import numpy as np
7
  from io import BytesIO
8
  from PIL import Image
@@ -10,7 +10,7 @@ from PIL import Image
10
  import gradio as gr
11
  from torch_geometric.data import Data as PyGData
12
  import matplotlib
13
- matplotlib.use('Agg') # fix backend threading issue
14
 
15
  from rdkit import Chem
16
  from rdkit.Chem import Draw, AllChem, MolFromSmiles
@@ -62,10 +62,10 @@ if torch.cuda.is_available():
62
  # Model Loading
63
  # ----------------------------
64
  def load_models():
65
- import numpy.core.multiarray
66
  from torch.serialization import add_safe_globals
 
67
 
68
- # 显式允许加载某些 numpy 序列化类型(仅当你信任模型来源)
69
  add_safe_globals([numpy.core.multiarray.scalar])
70
 
71
  specs = {
@@ -75,11 +75,10 @@ def load_models():
75
  }
76
 
77
  models = {}
78
-
79
  for name, (path, out_dim) in specs.items():
80
  if not os.path.exists(path):
81
  if os.path.exists("models.zip"):
82
- logging.info("Unzipping models.zip...")
83
  with zipfile.ZipFile("models.zip", 'r') as z:
84
  z.extractall(".")
85
  else:
@@ -87,29 +86,30 @@ def load_models():
87
 
88
  model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)
89
 
90
- # 🔧 明确禁用 weights_only 以兼容新 PyTorch 版本
91
  try:
92
  state = torch.load(path, map_location=device, weights_only=False)
93
  except TypeError:
94
- # 兼容旧版本 PyTorch(无 weights_only 参数)
95
  state = torch.load(path, map_location=device)
96
 
97
  state_dict = state.get("model_state_dict", state)
98
  model.load_state_dict(state_dict)
99
  model.eval().to(device)
100
  models[name] = model
101
- logging.info(f"{name} model loaded successfully.")
102
 
103
  return models
104
 
105
-
106
  models = load_models()
107
 
108
  # ----------------------------
109
  # Prediction Function
110
  # ----------------------------
111
  def predict_all(smiles: str):
112
- """Run Elastic, Plastic, Brittle predictions; return text & PIL images."""
 
 
 
 
113
  atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
114
  x = torch.tensor(atom_feats, dtype=torch.float)
115
  edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
@@ -117,17 +117,26 @@ def predict_all(smiles: str):
117
  data = PyGData(x=x, edge_index=edge_index, edge_attr=edge_attr,
118
  smiles=[smiles], batch=torch.zeros(x.size(0), dtype=torch.long))
119
 
120
- results = []
 
 
121
  for name in ["Elastic", "Plastic", "Brittle"]:
122
- buf, pred = visualize_single_molecule(models[name], data, device, name)
123
- if buf:
124
- buf.seek(0)
125
- img = Image.open(buf)
126
- text = f"{name}: {int(pred)}"
127
- results.append((text, img))
128
- else:
129
- results.append((f"{name} prediction failed", None))
130
- return (*results[0], *results[1], *results[2])
 
 
 
 
 
 
 
131
 
132
  # ----------------------------
133
  # Molecule Builder Utilities
@@ -201,31 +210,31 @@ def update_bonds_list(mol):
201
  # ----------------------------
202
  with gr.Blocks(title="CrystalGAT", css="""
203
  .gradio-container {max-width:800px; margin:auto}
204
- .block {padding:1em}
205
  .gr-button {margin:0.2em}
206
  """) as demo:
 
207
  gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")
208
 
209
  with gr.Tab("SMILES Input"):
210
  smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
211
  predict1 = gr.Button("Predict")
212
-
213
- with gr.Tab("Manual Molecules Construction"):
214
  state = gr.State(init_molecule())
215
  status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")
216
-
217
  with gr.Row():
218
- with gr.Column(scale=1):
219
  atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
220
  add_a = gr.Button("Add Atom")
221
  atom_tbl = gr.Dataframe(headers=["ID","Type"], datatype=["number","str"], interactive=False)
222
- with gr.Column(scale=1):
223
  a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
224
  a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
225
  bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
226
  add_b = gr.Button("Add Bond")
227
  bond_tbl = gr.Dataframe(headers=["Atom1","Atom2","Type"], datatype=["str","str","str"], interactive=False)
228
-
229
  with gr.Row():
230
  clear = gr.Button("Clear All")
231
  make = gr.Button("Generate SMILES")
@@ -234,7 +243,7 @@ with gr.Blocks(title="CrystalGAT", css="""
234
 
235
  predict2 = gr.Button("Predict on Built Molecule")
236
 
237
- # Prediction Outputs
238
  with gr.Row():
239
  e_txt = gr.Text(label="Elastic")
240
  e_img = gr.Image(type="pil", label="Elastic Attention")
 
 
1
  import os
2
  import zipfile
3
  import logging
4
  import torch
5
+ import torch.nn.functional as F
6
  import numpy as np
7
  from io import BytesIO
8
  from PIL import Image
 
10
  import gradio as gr
11
  from torch_geometric.data import Data as PyGData
12
  import matplotlib
13
+ matplotlib.use('Agg')
14
 
15
  from rdkit import Chem
16
  from rdkit.Chem import Draw, AllChem, MolFromSmiles
 
62
  # Model Loading
63
  # ----------------------------
64
  def load_models():
 
65
  from torch.serialization import add_safe_globals
66
+ import numpy.core.multiarray
67
 
68
+ # allow safe numpy objects if needed
69
  add_safe_globals([numpy.core.multiarray.scalar])
70
 
71
  specs = {
 
75
  }
76
 
77
  models = {}
 
78
  for name, (path, out_dim) in specs.items():
79
  if not os.path.exists(path):
80
  if os.path.exists("models.zip"):
81
+ logger.info("Extracting models.zip...")
82
  with zipfile.ZipFile("models.zip", 'r') as z:
83
  z.extractall(".")
84
  else:
 
86
 
87
  model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)
88
 
 
89
  try:
90
  state = torch.load(path, map_location=device, weights_only=False)
91
  except TypeError:
 
92
  state = torch.load(path, map_location=device)
93
 
94
  state_dict = state.get("model_state_dict", state)
95
  model.load_state_dict(state_dict)
96
  model.eval().to(device)
97
  models[name] = model
98
+ logger.info(f"{name} model loaded successfully.")
99
 
100
  return models
101
 
 
102
  models = load_models()
103
 
104
  # ----------------------------
105
  # Prediction Function
106
  # ----------------------------
107
  def predict_all(smiles: str):
108
+ """
109
+ Run predictions for Elastic, Plastic, Brittle.
110
+ Use threshold 0.5 for Elastic/Brittle, 0.3 for Plastic.
111
+ Return (text, PIL image) for each.
112
+ """
113
  atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
114
  x = torch.tensor(atom_feats, dtype=torch.float)
115
  edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
 
117
  data = PyGData(x=x, edge_index=edge_index, edge_attr=edge_attr,
118
  smiles=[smiles], batch=torch.zeros(x.size(0), dtype=torch.long))
119
 
120
+ outputs = []
121
+ thresholds = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}
122
+
123
  for name in ["Elastic", "Plastic", "Brittle"]:
124
+ model = models[name]
125
+ with torch.no_grad():
126
+ logits = model(data)
127
+ # assume binary classification: two outputs
128
+ if logits.dim() == 1 or logits.size(1) == 1:
129
+ prob = torch.sigmoid(logits).item()
130
+ else:
131
+ prob = F.softmax(logits, dim=1)[0, 1].item()
132
+ label = int(prob >= thresholds[name])
133
+ # get visualization buffer
134
+ buf, _ = visualize_single_molecule(model, data, device, name)
135
+ img = Image.open(buf) if buf else None
136
+ outputs.append((f"{name}: {label} (p={prob:.2f})", img))
137
+
138
+ # flatten to 6 outputs
139
+ return (*outputs[0], *outputs[1], *outputs[2])
140
 
141
  # ----------------------------
142
  # Molecule Builder Utilities
 
210
  # ----------------------------
211
  with gr.Blocks(title="CrystalGAT", css="""
212
  .gradio-container {max-width:800px; margin:auto}
 
213
  .gr-button {margin:0.2em}
214
  """) as demo:
215
+
216
  gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")
217
 
218
  with gr.Tab("SMILES Input"):
219
  smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
220
  predict1 = gr.Button("Predict")
221
+
222
+ with gr.Tab("Manual Molecule Construction"):
223
  state = gr.State(init_molecule())
224
  status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")
225
+
226
  with gr.Row():
227
+ with gr.Column():
228
  atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
229
  add_a = gr.Button("Add Atom")
230
  atom_tbl = gr.Dataframe(headers=["ID","Type"], datatype=["number","str"], interactive=False)
231
+ with gr.Column():
232
  a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
233
  a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
234
  bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
235
  add_b = gr.Button("Add Bond")
236
  bond_tbl = gr.Dataframe(headers=["Atom1","Atom2","Type"], datatype=["str","str","str"], interactive=False)
237
+
238
  with gr.Row():
239
  clear = gr.Button("Clear All")
240
  make = gr.Button("Generate SMILES")
 
243
 
244
  predict2 = gr.Button("Predict on Built Molecule")
245
 
246
+ # Outputs
247
  with gr.Row():
248
  e_txt = gr.Text(label="Elastic")
249
  e_img = gr.Image(type="pil", label="Elastic Attention")