ZZZCCCYYY committed on
Commit
e7faed8
·
verified ·
1 Parent(s): 82aab83

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +285 -0
app.py ADDED
@@ -0,0 +1,285 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import zipfile
3
+ import logging
4
+ import torch
5
+ import torch.nn.functional as F
6
+ import numpy as np
7
+ from io import BytesIO
8
+ from PIL import Image
9
+
10
+ import gradio as gr
11
+ from torch_geometric.data import Data as PyGData
12
+ import matplotlib
13
+ matplotlib.use('Agg')
14
+
15
+ from rdkit import Chem
16
+ from rdkit.Chem import Draw, AllChem, MolFromSmiles
17
+
18
+ # ----------------------------
19
+ # Logging & GPU Configuration
20
+ # ----------------------------
21
# Configure root logging at INFO level; every record carries a timestamp,
# the logger name and the level name ahead of the message.
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by every function in this file.
logger = logging.getLogger(__name__)

# Ask PyTorch's CUDA allocator to limit split block size to 128 MB.
# NOTE(review): torch is already imported at this point; presumably the
# allocator reads this variable lazily on first CUDA use - confirm it
# actually takes effect when set here.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
logger.info("Set GPU memory optimization: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")
29
+
30
+ # ----------------------------
31
+ # Unzip Model Files if Needed
32
+ # ----------------------------
33
# Extract models.zip only when the Brittle checkpoint is not already on disk.
# load_models() reads the weights from the models/ directory, so that location
# is checked too; the bare filename is still accepted for backward
# compatibility. Without the models/ check, already-extracted weights would
# trigger a re-extract on every start (or a crash when models.zip is absent).
if not any(
    os.path.exists(candidate)
    for candidate in (
        "best_model-B-6000-185.pth",
        "models/best_model-B-6000-185.pth",
    )
):
    logger.info("Unzipping model archive...")
    try:
        with zipfile.ZipFile("models.zip", 'r') as z:
            z.extractall(".")
        logger.info("Model archive unzipped successfully.")
    except Exception as e:
        # Startup cannot continue without weights: log the cause, then re-raise.
        logger.error(f"Failed to unzip models.zip: {e}")
        raise
42
+
43
+ # ----------------------------
44
+ # Import Model Utilities
45
+ # ----------------------------
46
try:
    # Project-local helpers: EnhancedGAT is instantiated in load_models();
    # smiles_to_graph and visualize_single_molecule are used by predict_all().
    from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
    logger.info("Imported model_utils successfully.")
except ImportError as e:
    # The app cannot run without these helpers: log the cause and abort startup.
    logger.error(f"Failed to import model_utils: {e}")
    raise
52
+
53
+ # ----------------------------
54
+ # Device Setup
55
+ # ----------------------------
56
# Use the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
if torch.cuda.is_available():
    # Record the name of CUDA device 0 for diagnostics.
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
60
+
61
+ # ----------------------------
62
+ # Model Loading
63
+ # ----------------------------
64
def load_models():
    """Load the Elastic, Plastic and Brittle classifiers and return them by name.

    For each class the checkpoint under ``models/`` is located (extracting
    ``models.zip`` into the working directory if the file is missing), an
    ``EnhancedGAT`` is built, the checkpoint weights are loaded onto the
    module-level ``device``, and the network is switched to eval mode.

    Returns:
        dict: class name ("Elastic", "Plastic", "Brittle") -> loaded model.

    Raises:
        FileNotFoundError: a checkpoint is missing and is still missing after
            any attempt to extract ``models.zip``.
    """
    # Best-effort allow-listing of numpy's scalar reconstructor. Older torch
    # releases have no add_safe_globals; guarding the import keeps the
    # TypeError fallback further down reachable on those releases.
    try:
        from torch.serialization import add_safe_globals
        import numpy.core.multiarray

        add_safe_globals([numpy.core.multiarray.scalar])
    except (ImportError, AttributeError) as e:
        logger.warning(f"Could not register numpy safe globals: {e}")

    # class name -> (checkpoint path, number of output logits)
    specs = {
        "Elastic": ("models/best_model-E-500-68.pth", 2),
        "Plastic": ("models/best_model-P-5000-180.pth", 2),
        "Brittle": ("models/best_model-B-6000-185.pth", 2),
    }

    loaded = {}
    for name, (path, out_dim) in specs.items():
        if not os.path.exists(path):
            if os.path.exists("models.zip"):
                logger.info("Extracting models.zip...")
                with zipfile.ZipFile("models.zip", 'r') as z:
                    z.extractall(".")
            # Re-check after extraction: the archive may not contain this
            # file, and a clear error beats a failure inside torch.load.
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing model file: {path}")

        model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)

        # NOTE(security): weights_only=False performs a full pickle load and
        # can execute arbitrary code. This is acceptable only because the
        # checkpoints ship with the app; never point it at untrusted files.
        try:
            state = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # torch versions whose torch.load() has no weights_only parameter.
            state = torch.load(path, map_location=device)

        # Accept either a training checkpoint that wraps the weights under
        # "model_state_dict" or a bare state dict.
        state_dict = state.get("model_state_dict", state)
        model.load_state_dict(state_dict)
        model.eval().to(device)
        loaded[name] = model
        logger.info(f"{name} model loaded successfully.")

    return loaded
101
+
102
# Load all three classifiers once at import time; predict_all() looks them up by name.
models = load_models()
103
+
104
+ # ----------------------------
105
+ # Prediction Function
106
+ # ----------------------------
107
def predict_all(smiles: str):
    """
    Predict the Elastic, Plastic and Brittle classes for one SMILES string.

    The molecule is featurized with smiles_to_graph, wrapped in a single-graph
    PyG Data object and passed through each loaded model. The positive-class
    probability is thresholded at 0.5 for Elastic/Brittle and 0.3 for Plastic.

    Returns a flat 6-tuple
    (elastic_text, elastic_img, plastic_text, plastic_img, brittle_text, brittle_img).
    Each text is "<Name>: <0 or 1>"; each image is a PIL image opened from the
    buffer returned by visualize_single_molecule, or None if no buffer came back.
    Blank input returns the placeholder ("Enter SMILES", None, "", None, "", None),
    the same row the built-molecule button shows when no SMILES is available.
    """
    # Guard blank input so an empty textbox yields a prompt, not a traceback.
    if not smiles or not smiles.strip():
        return ("Enter SMILES", None, "", None, "", None)

    atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
    x = torch.tensor(atom_feats, dtype=torch.float)
    edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    data = PyGData(x=x, edge_index=edge_index, edge_attr=edge_attr,
                   smiles=[smiles], batch=torch.zeros(x.size(0), dtype=torch.long))

    # The models were moved to `device` at load time but the tensors above are
    # built on the CPU, so run the forward pass on a device-resident copy.
    # The original `data` is left untouched for visualize_single_molecule,
    # which receives `device` itself.
    # NOTE(review): assumes EnhancedGAT.forward does not move its input - confirm.
    device_data = data.clone().to(device)

    # Insertion order fixes the output order: Elastic, Plastic, Brittle.
    thresholds = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}

    flat = []
    for name, cutoff in thresholds.items():
        model = models[name]
        with torch.no_grad():
            logits = model(device_data)
        # Single-logit heads are read through a sigmoid; two-logit heads use
        # the softmax probability of class index 1.
        if logits.dim() == 1 or logits.size(1) == 1:
            prob = torch.sigmoid(logits).item()
        else:
            prob = F.softmax(logits, dim=1)[0, 1].item()
        label = int(prob >= cutoff)
        # The second value returned by the visualizer is not used here.
        buf, _ = visualize_single_molecule(model, data, device, name)
        img = Image.open(buf) if buf else None
        flat.extend((f"{name}: {label}", img))

    # Six values, matching the six Gradio output components.
    return tuple(flat)
140
+
141
+ # ----------------------------
142
+ # Molecule Builder Utilities
143
+ # ----------------------------
144
# Element symbols offered in the "Atom Type" dropdown of the molecule builder.
ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
# Bond-order names offered in the "Bond Type" dropdown; generate_smiles()
# maps these names to RDKit bond types.
BOND_TYPES = ["Single", "Double", "Triple"]
146
+
147
def init_molecule():
    """Return a fresh molecule record containing no atoms and no bonds."""
    return dict(atoms=[], bonds=[])
149
+
150
def add_atom(mol, atom_type):
    """Append an atom of ``atom_type`` to ``mol`` in place and return ``mol``.

    The new atom's id is its position in the atom list at insertion time.
    """
    atoms = mol["atoms"]
    next_id = len(atoms)
    atoms.append({"id": next_id, "type": atom_type})
    return mol
153
+
154
def add_bond(mol, a1_sel, a2_sel, b_type):
    """Add a bond between two selected atoms, in place, and return ``mol``.

    Args:
        mol: molecule record with "atoms" and "bonds" lists.
        a1_sel, a2_sel: dropdown selections of the form "<id>: <type>";
            a falsy value means nothing was selected.
        b_type: bond type name stored on the new bond.

    The molecule is returned unchanged when either selection is missing, when
    both selections name the same atom, when a selected id is not an atom of
    ``mol``, or when the two atoms are already bonded in either direction.
    """
    if not a1_sel or not a2_sel:
        return mol
    i1, i2 = int(a1_sel.split(":")[0]), int(a2_sel.split(":")[0])

    # A self-bond can never produce a valid molecule; ignore it up front.
    if i1 == i2:
        return mol

    # Ignore stale selections that point at atoms no longer in the molecule.
    known_ids = {a["id"] for a in mol["atoms"]}
    if i1 not in known_ids or i2 not in known_ids:
        return mol

    # Bonds are undirected: compare endpoints as unordered pairs.
    pair = {i1, i2}
    if any({b["atom1"], b["atom2"]} == pair for b in mol["bonds"]):
        return mol

    mol["bonds"].append({"atom1": i1, "atom2": i2, "type": b_type})
    return mol
162
+
163
def generate_smiles(mol):
    """Build an RDKit molecule from ``mol`` and return its SMILES string.

    Atoms are added in list order, bonds are added using the stored bond type
    names, and the molecule is sanitized before conversion. Any failure is
    logged and reported to the caller as an empty string.
    """
    try:
        # Translate the UI bond names to RDKit bond types.
        bond_orders = {
            "Single": Chem.BondType.SINGLE,
            "Double": Chem.BondType.DOUBLE,
            "Triple": Chem.BondType.TRIPLE,
        }
        editable = Chem.RWMol()

        # Builder atom id -> index assigned by RDKit.
        rd_index = {}
        for atom in mol["atoms"]:
            rd_index[atom["id"]] = editable.AddAtom(Chem.Atom(atom["type"]))

        for bond in mol["bonds"]:
            begin = rd_index[bond["atom1"]]
            end = rd_index[bond["atom2"]]
            editable.AddBond(begin, end, bond_orders[bond["type"]])

        editable.UpdatePropertyCache()
        Chem.SanitizeMol(editable)
        return Chem.MolToSmiles(editable)
    except Exception as e:
        logger.error(f"SMILES generation failed: {e}")
        return ""
181
+
182
def visualize_molecule(mol):
    """Render the built molecule as a 300x300 PIL image.

    Returns None when no SMILES can be generated for ``mol`` or when RDKit
    cannot parse the generated SMILES.
    """
    smiles = generate_smiles(mol)
    parsed = MolFromSmiles(smiles) if smiles else None
    if parsed is None:
        return None
    # 2D coordinates are computed before drawing.
    AllChem.Compute2DCoords(parsed)
    return Draw.MolToImage(parsed, size=(300, 300))
192
+
193
def update_atom_dropdowns(mol):
    """Return two Gradio updates that refresh both atom dropdowns.

    Each dropdown receives one "<id>: <type>" choice per atom in ``mol`` and
    has its current selection cleared.
    """
    labels = []
    for atom in mol["atoms"]:
        labels.append("{}: {}".format(atom["id"], atom["type"]))
    # One identical update per dropdown (Atom 1, Atom 2).
    return tuple(gr.update(choices=labels, value=None) for _ in range(2))
196
+
197
def update_atoms_list(mol):
    """Return the atoms of ``mol`` as ``[id, type]`` rows for the atom table."""
    rows = []
    for atom in mol["atoms"]:
        rows.append([atom["id"], atom["type"]])
    return rows
199
+
200
def update_bonds_list(mol):
    """Return the bonds of ``mol`` as rows for the bond table.

    Each row is ``["<id1>: <type1>", "<id2>: <type2>", bond_type]``. An
    endpoint whose id is not present in ``mol["atoms"]`` is shown with type
    "?" so a dangling bond cannot crash the table refresh.
    """
    # Build the id -> type lookup once instead of rescanning the atom list
    # for every bond endpoint. setdefault keeps the first atom for an id.
    type_by_id = {}
    for atom in mol["atoms"]:
        type_by_id.setdefault(atom["id"], atom["type"])

    def label(atom_id):
        return f"{atom_id}: {type_by_id.get(atom_id, '?')}"

    return [
        [label(bond["atom1"]), label(bond["atom2"]), bond["type"]]
        for bond in mol["bonds"]
    ]
207
+
208
+ # ----------------------------
209
+ # Gradio Interface
210
+ # ----------------------------
211
# Top-level Gradio app: a SMILES tab, a manual molecule-builder tab, and one
# shared set of output widgets that both "Predict" buttons write to.
# NOTE(review): the layout nesting below was reconstructed from a flattened
# copy of the file; confirm which widgets belong inside each Row/Tab.
with gr.Blocks(title="CrystalGAT", css="""
.gradio-container {max-width:800px; margin:auto}
.gr-button {margin:0.2em}
""") as demo:

    gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")

    # Tab 1: predict directly from a typed SMILES string.
    with gr.Tab("SMILES Input"):
        smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
        predict1 = gr.Button("Predict")

    # Tab 2: assemble a molecule atom by atom and bond by bond.
    with gr.Tab("Manual Molecule Construction"):
        # Session state holding the {"atoms": [...], "bonds": [...]} record.
        state = gr.State(init_molecule())
        status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")

        with gr.Row():
            # Left column: pick an element and add it; table of current atoms.
            with gr.Column():
                atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
                add_a = gr.Button("Add Atom")
                atom_tbl = gr.Dataframe(headers=["ID","Type"], datatype=["number","str"], interactive=False)
            # Right column: pick two atoms and a bond type; table of current bonds.
            with gr.Column():
                a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
                a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
                bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
                add_b = gr.Button("Add Bond")
                bond_tbl = gr.Dataframe(headers=["Atom1","Atom2","Type"], datatype=["str","str","str"], interactive=False)

        with gr.Row():
            clear = gr.Button("Clear All")
            make = gr.Button("Generate SMILES")
        smi_out = gr.Textbox(label="SMILES Output", interactive=False)
        mol_img = gr.Image(type="pil", label="Molecule Preview")

        predict2 = gr.Button("Predict on Built Molecule")

    # Outputs
    # One text + image pair per class, in the order predict_all() returns them.
    with gr.Row():
        e_txt = gr.Text(label="Elastic")
        e_img = gr.Image(type="pil", label="Elastic Attention")
    with gr.Row():
        p_txt = gr.Text(label="Plastic")
        p_img = gr.Image(type="pil", label="Plastic Attention")
    with gr.Row():
        b_txt = gr.Text(label="Brittle")
        b_img = gr.Image(type="pil", label="Brittle Attention")

    # Event bindings
    # Typed SMILES goes straight to predict_all, which fills all six outputs.
    predict1.click(fn=predict_all, inputs=smi_in,
                   outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])

    # Add an atom, then refresh the atom table, both atom dropdowns and the status line.
    add_a.click(fn=add_atom, inputs=[state, atom_type], outputs=state)\
        .then(fn=update_atoms_list, inputs=state, outputs=atom_tbl)\
        .then(fn=update_atom_dropdowns, inputs=state, outputs=[a1, a2])\
        .then(fn=lambda: "Atom added.", outputs=status)

    # Add a bond, then refresh the bond table. The status text is set
    # unconditionally, even when add_bond left the molecule unchanged.
    add_b.click(fn=add_bond, inputs=[state, a1, a2, bond_type], outputs=state)\
        .then(fn=update_bonds_list, inputs=state, outputs=bond_tbl)\
        .then(fn=lambda: "Bond added/updated.", outputs=status)

    # Reset the molecule state, empty both tables and clear both atom dropdowns.
    clear.click(fn=init_molecule, outputs=state)\
        .then(fn=lambda: ([], []), outputs=[atom_tbl, bond_tbl])\
        .then(fn=lambda: (gr.update(choices=[], value=None), gr.update(choices=[], value=None)),
              outputs=[a1, a2])\
        .then(fn=lambda: "Cleared all.", outputs=status)

    # Generate the SMILES text, then the preview image. The status text is
    # set unconditionally, even when generation returned an empty string.
    make.click(fn=generate_smiles, inputs=state, outputs=smi_out)\
        .then(fn=visualize_molecule, inputs=state, outputs=mol_img)\
        .then(fn=lambda: "Molecule generated.", outputs=status)

    # Predict on the generated SMILES; with no SMILES yet, show a placeholder
    # row instead of calling predict_all.
    predict2.click(fn=lambda s: predict_all(s) if s else ("Enter SMILES", None, "", None, "", None),
                   inputs=smi_out,
                   outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])
283
+
284
if __name__ == "__main__":
    # Serve the app on all network interfaces, port 7860.
    demo.launch(server_name="0.0.0.0", server_port=7860)