File size: 10,528 Bytes
36df9cc
 
76c5703
36df9cc
793eee2
36df9cc
76c5703
36df9cc
76c5703
36df9cc
 
 
793eee2
76c5703
36df9cc
8ab481c
36df9cc
76c5703
 
 
f900cab
 
 
 
36df9cc
 
53c434b
76c5703
36df9cc
76c5703
 
 
53c434b
76c5703
36df9cc
76c5703
 
 
36df9cc
76c5703
36df9cc
 
76c5703
 
 
36df9cc
 
76c5703
36df9cc
76c5703
36df9cc
 
76c5703
 
 
36df9cc
76c5703
36df9cc
76c5703
36df9cc
76c5703
 
 
36df9cc
b9eda4c
793eee2
b9eda4c
793eee2
b9eda4c
 
76c5703
53c434b
 
f900cab
36df9cc
b9eda4c
36df9cc
76c5703
 
b9eda4c
793eee2
b9eda4c
 
f900cab
b9eda4c
76c5703
 
b9eda4c
b5c7e53
76c5703
b9eda4c
76c5703
b9eda4c
f900cab
 
 
 
793eee2
b9eda4c
36df9cc
 
8ab481c
36df9cc
76c5703
 
 
f900cab
793eee2
 
 
 
 
76c5703
 
f900cab
 
76c5703
 
 
793eee2
 
 
f900cab
793eee2
 
 
 
 
 
 
 
 
 
30807b6
793eee2
ae8870e
793eee2
 
 
36df9cc
76c5703
 
 
b5c7e53
76c5703
b5e3512
 
8ab481c
b5e3512
76c5703
 
 
 
 
 
 
 
 
 
 
 
 
 
b5e3512
76c5703
f900cab
76c5703
 
f900cab
76c5703
 
 
 
 
 
 
 
b5e3512
76c5703
f900cab
b5e3512
76c5703
 
 
f900cab
b5c7e53
76c5703
 
f900cab
76c5703
 
8ab481c
76c5703
 
b5c7e53
8ab481c
76c5703
 
8ab481c
76c5703
8ab481c
76c5703
 
 
b5c7e53
8ab481c
b5e3512
76c5703
 
 
 
 
 
 
793eee2
76c5703
 
 
 
 
793eee2
 
f900cab
76c5703
793eee2
9bb2bc4
793eee2
76c5703
 
 
793eee2
76c5703
 
 
 
 
793eee2
9bb2bc4
76c5703
 
 
 
 
 
f900cab
793eee2
53c434b
f900cab
76c5703
53c434b
f900cab
76c5703
53c434b
f900cab
76c5703
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36df9cc
 
8ab481c
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
import os
import zipfile
import logging
import torch
import torch.nn.functional as F
import numpy as np
from io import BytesIO
from PIL import Image

import gradio as gr
from torch_geometric.data import Data as PyGData
import matplotlib
matplotlib.use('Agg')

from rdkit import Chem
from rdkit.Chem import Draw, AllChem, MolFromSmiles

# ----------------------------
# Logging & GPU Configuration
# ----------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Limit the CUDA caching allocator's block-split size to reduce fragmentation.
# setdefault keeps any value the operator already exported instead of silently
# overwriting it. PyTorch reads this variable when the CUDA allocator is first
# initialised, so setting it after `import torch` still takes effect as long as
# no CUDA memory has been allocated yet.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")
logger.info(
    "Using GPU memory config: PYTORCH_CUDA_ALLOC_CONF=%s",
    os.environ["PYTORCH_CUDA_ALLOC_CONF"],
)

# ----------------------------
# Unzip Model Files if Needed
# ----------------------------
# Extract the bundled checkpoints unless they are already on disk. The check
# uses the same models/ location that load_models reads from. The previous
# bare-filename check in the working directory never matched that layout, so
# it re-extracted on every start and crashed when models.zip had been removed
# even though models/ was already populated.
if not os.path.exists(os.path.join("models", "best_model-B-6000-185.pth")):
    logger.info("Unzipping model archive...")
    try:
        with zipfile.ZipFile("models.zip", 'r') as z:
            z.extractall(".")
        logger.info("Model archive unzipped successfully.")
    except Exception as e:
        # Log for context, then propagate: the app cannot run without models.
        logger.error(f"Failed to unzip models.zip: {e}")
        raise

# ----------------------------
# Import Model Utilities
# ----------------------------
# Project-local helpers: the GAT model class, SMILES -> graph conversion and
# the attention visualizer. A failed import is logged and re-raised because
# nothing below can work without them.
try:
    from model_utils import (
        EnhancedGAT,
        smiles_to_graph,
        visualize_single_molecule,
    )
except ImportError as exc:
    logger.error(f"Failed to import model_utils: {exc}")
    raise
else:
    logger.info("Imported model_utils successfully.")

# ----------------------------
# Device Setup
# ----------------------------
# Prefer the first CUDA device when one is available, otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
if device.type == "cuda":
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")

# ----------------------------
# Model Loading
# ----------------------------
def _register_numpy_safe_globals():
    """Best-effort allow-listing of numpy scalars for torch's weights_only unpickler.

    ``torch.serialization.add_safe_globals`` only exists in newer torch
    releases. ``numpy.core`` is deprecated in NumPy 2.x. Any import failure is
    therefore logged at debug level and ignored rather than aborting startup.
    """
    try:
        from torch.serialization import add_safe_globals
        import numpy.core.multiarray
    except (ImportError, AttributeError) as e:
        logger.debug(f"Skipping torch safe-globals registration: {e}")
        return
    add_safe_globals([numpy.core.multiarray.scalar])


def load_models():
    """Build the three EnhancedGAT classifiers and load their checkpoints.

    If a checkpoint is missing, ``models.zip`` is extracted into the working
    directory, at most once per call.

    Returns:
        dict mapping "Elastic", "Plastic" and "Brittle" to an EnhancedGAT in
        eval mode on ``device``.

    Raises:
        FileNotFoundError: if a checkpoint is missing and cannot be recovered
            from ``models.zip``.
    """
    _register_numpy_safe_globals()

    # name -> (checkpoint path, output_dim)
    specs = {
        "Elastic": ("models/best_model-E-500-68.pth", 2),
        "Plastic": ("models/best_model-P-5000-180.pth", 2),
        "Brittle": ("models/best_model-B-6000-185.pth", 2),
    }

    models = {}
    extracted = False
    for name, (path, out_dim) in specs.items():
        if not os.path.exists(path) and not extracted and os.path.exists("models.zip"):
            logger.info("Extracting models.zip...")
            with zipfile.ZipFile("models.zip", 'r') as z:
                z.extractall(".")
            extracted = True
        # Re-check after a possible extraction: the archive may not contain this
        # checkpoint, and a clear error beats an opaque torch.load failure.
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing model file: {path}")

        model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)

        # Full unpickling is requested explicitly. Older torch versions without
        # the `weights_only` keyword raise TypeError and fall back to their
        # default behaviour.
        # NOTE(review): weights_only=False executes arbitrary pickled code —
        # only load trusted checkpoint files here.
        try:
            state = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            state = torch.load(path, map_location=device)

        # Accept either a training checkpoint dict with a "model_state_dict"
        # entry or a bare state_dict.
        state_dict = state.get("model_state_dict", state)
        model.load_state_dict(state_dict)
        model.eval().to(device)
        models[name] = model
        logger.info(f"{name} model loaded successfully.")

    return models

models = load_models()

# ----------------------------
# Prediction Function
# ----------------------------
def predict_all(smiles: str):
    """Run the Elastic, Plastic and Brittle classifiers on one SMILES string.

    Each model's positive-class probability is compared with a per-model
    threshold (0.5 for Elastic/Brittle, 0.3 for Plastic) to give a 0/1 label.

    Args:
        smiles: SMILES string; surrounding whitespace is ignored.

    Returns:
        A flat 6-tuple ``(elastic_text, elastic_img, plastic_text, plastic_img,
        brittle_text, brittle_img)`` matching the Gradio output components.
        Text entries look like ``"Elastic: 1"``. Image entries are PIL images,
        or ``None`` when the visualizer returns no buffer. Empty input, or a
        failure while building the graph, puts a message in the first slot and
        leaves the other slots empty.
    """
    smiles = (smiles or "").strip()
    if not smiles:
        return ("Enter SMILES", None, "", None, "", None)

    try:
        data = _smiles_to_pyg(smiles)
    except Exception as e:  # UI boundary: report the problem instead of crashing
        logger.exception("Could not build a graph for SMILES %r", smiles)
        return (f"Invalid SMILES: {e}", None, "", None, "", None)

    # The models were moved to `device` in load_models; the graph must live
    # there too, otherwise a CUDA model receives CPU tensors. PyG's Data.to
    # moves the tensor attributes and leaves `smiles` (a list) untouched.
    data = data.to(device)

    thresholds = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}
    outputs = []
    for name in ["Elastic", "Plastic", "Brittle"]:
        model = models[name]
        with torch.no_grad():
            logits = model(data)
            # One output column (or a 0/1-D tensor) is treated as a binary
            # logit; otherwise softmax column 1 is the positive class.
            if logits.dim() <= 1 or logits.size(1) == 1:
                prob = torch.sigmoid(logits).item()
            else:
                prob = F.softmax(logits, dim=1)[0, 1].item()
        label = int(prob >= thresholds[name])

        buf, _ = visualize_single_molecule(model, data, device, name)
        img = None
        if buf:
            # Presumably a BytesIO; rewind in case the helper left it at the end.
            if hasattr(buf, "seek"):
                buf.seek(0)
            img = Image.open(buf)
        outputs.append((f"{name}: {label}", img))

    # Flatten the three (text, image) pairs into the six Gradio outputs.
    return (*outputs[0], *outputs[1], *outputs[2])


def _smiles_to_pyg(smiles):
    """Convert a SMILES string into a single-graph PyG ``Data`` object (CPU tensors)."""
    atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
    x = torch.tensor(atom_feats, dtype=torch.float)
    edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    # All nodes belong to graph 0, so the batch vector is all zeros.
    return PyGData(x=x, edge_index=edge_index, edge_attr=edge_attr,
                   smiles=[smiles], batch=torch.zeros(x.size(0), dtype=torch.long))

# ----------------------------
# Molecule Builder Utilities
# ----------------------------
# Element symbols offered in the "Atom Type" dropdown; generate_smiles passes
# each one to rdkit's Chem.Atom.
ATOM_TYPES = "C N O S P F Cl Br I H".split()
# Bond orders offered in the "Bond Type" dropdown; generate_smiles maps these
# names onto rdkit bond types.
BOND_TYPES = "Single Double Triple".split()

def init_molecule():
    """Return a fresh, empty molecule-builder state.

    The state is a dict holding an ``"atoms"`` list and a ``"bonds"`` list.
    New lists are created on every call, so states never share storage.
    """
    return dict(atoms=[], bonds=[])

def add_atom(mol, atom_type):
    """Append an atom of element ``atom_type`` to the builder state ``mol``.

    Atom ids are assigned sequentially from the current atom count (atoms are
    never removed individually). An empty selection (``None`` or ``""``) is
    ignored, mirroring add_bond's handling of empty selections. This keeps a
    non-element out of the state, where it would later break ``Chem.Atom``.

    Mutates and returns ``mol``.
    """
    if not atom_type:
        return mol
    mol["atoms"].append({"id": len(mol["atoms"]), "type": atom_type})
    return mol

def add_bond(mol, a1_sel, a2_sel, b_type):
    """Add a bond between two atoms of ``mol``, or change its type if it exists.

    Args:
        mol: builder state with ``"atoms"`` and ``"bonds"`` lists.
        a1_sel, a2_sel: dropdown selections formatted ``"<id>: <type>"``.
        b_type: bond-type name, one of BOND_TYPES.

    Empty selections and self-bonds (both selections naming the same atom) are
    ignored. If the two atoms are already bonded, in either order, that bond's
    type is replaced with ``b_type``; this matches the UI's "Bond added/updated."
    status. Otherwise a new bond is appended.

    Mutates and returns ``mol``.
    """
    if not a1_sel or not a2_sel:
        return mol
    i1 = int(a1_sel.split(":")[0])
    i2 = int(a2_sel.split(":")[0])
    # RDKit refuses self-bonds. Storing one would make every later
    # generate_smiles call fail until the whole state is cleared.
    if i1 == i2:
        return mol
    pair = {i1, i2}
    for bond in mol["bonds"]:
        if {bond["atom1"], bond["atom2"]} == pair:
            bond["type"] = b_type
            return mol
    mol["bonds"].append({"atom1": i1, "atom2": i2, "type": b_type})
    return mol

def generate_smiles(mol):
    """Convert the builder state into a SMILES string using RDKit.

    Atoms are added in list order, then the bonds; the molecule is sanitized
    before being written with ``Chem.MolToSmiles``. Any exception (for example
    an unknown element or a sanitization failure) is logged and reported as
    an empty string.
    """
    try:
        rdkit_bond = {
            "Single": Chem.BondType.SINGLE,
            "Double": Chem.BondType.DOUBLE,
            "Triple": Chem.BondType.TRIPLE,
        }
        editable = Chem.RWMol()
        # Builder atom id -> RDKit atom index, filled in list order.
        index_of = {
            atom["id"]: editable.AddAtom(Chem.Atom(atom["type"]))
            for atom in mol["atoms"]
        }
        for bond in mol["bonds"]:
            editable.AddBond(
                index_of[bond["atom1"]],
                index_of[bond["atom2"]],
                rdkit_bond[bond["type"]],
            )
        editable.UpdatePropertyCache()
        Chem.SanitizeMol(editable)
        return Chem.MolToSmiles(editable)
    except Exception as e:
        logger.error(f"SMILES generation failed: {e}")
        return ""

def visualize_molecule(mol):
    """Render the builder state as a 300x300 PIL image.

    Returns ``None`` when no SMILES can be generated or when RDKit cannot
    re-parse the generated SMILES.
    """
    smiles = generate_smiles(mol)
    # An empty SMILES would parse to an empty molecule, so skip parsing it.
    rdkit_mol = MolFromSmiles(smiles) if smiles else None
    if rdkit_mol is None:
        return None
    AllChem.Compute2DCoords(rdkit_mol)
    return Draw.MolToImage(rdkit_mol, size=(300, 300))

def update_atom_dropdowns(mol):
    """Refresh both bond-endpoint dropdowns from the current atom list.

    Choices are labelled ``"<id>: <type>"`` (the format add_bond parses), and
    both current selections are cleared.
    """
    labels = [f"{atom['id']}: {atom['type']}" for atom in mol["atoms"]]
    return tuple(gr.update(choices=labels, value=None) for _ in range(2))

def update_atoms_list(mol):
    """Return the atom-table rows ``[[id, type], ...]`` in insertion order."""
    rows = []
    for atom in mol["atoms"]:
        rows.append([atom["id"], atom["type"]])
    return rows

def update_bonds_list(mol):
    """Return bond-table rows ``["<id>: <type>", "<id>: <type>", bond_type]``.

    Atom types are looked up once through an id -> type dict, rather than by
    a linear scan over all atoms for every bond endpoint.
    """
    type_of = {atom["id"]: atom["type"] for atom in mol["atoms"]}
    return [
        [
            f"{b['atom1']}: {type_of[b['atom1']]}",
            f"{b['atom2']}: {type_of[b['atom2']]}",
            b["type"],
        ]
        for b in mol["bonds"]
    ]

# ----------------------------
# Gradio Interface
# ----------------------------
with gr.Blocks(title="CrystalGAT", css="""
    .gradio-container {max-width:800px; margin:auto}
    .gr-button {margin:0.2em}
""") as demo:

    gr.Markdown("## CrystalGAT  \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")

    with gr.Tab("SMILES Input"):
        smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
        predict1 = gr.Button("Predict")

    with gr.Tab("Manual Molecule Construction"):
        # Builder state shaped like init_molecule(): {"atoms": [...], "bonds": [...]}.
        state = gr.State(init_molecule())
        status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")

        with gr.Row():
            with gr.Column():
                atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
                add_a = gr.Button("Add Atom")
                atom_tbl = gr.Dataframe(headers=["ID","Type"], datatype=["number","str"], interactive=False)
            with gr.Column():
                a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
                a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
                bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
                add_b = gr.Button("Add Bond")
                bond_tbl = gr.Dataframe(headers=["Atom1","Atom2","Type"], datatype=["str","str","str"], interactive=False)

        with gr.Row():
            clear = gr.Button("Clear All")
            make = gr.Button("Generate SMILES")
            smi_out = gr.Textbox(label="SMILES Output", interactive=False)
            mol_img = gr.Image(type="pil", label="Molecule Preview")

        predict2 = gr.Button("Predict on Built Molecule")

    # Outputs shared by both tabs: one (label text, attention image) row per model.
    with gr.Row():
        e_txt = gr.Text(label="Elastic")
        e_img = gr.Image(type="pil", label="Elastic Attention")
    with gr.Row():
        p_txt = gr.Text(label="Plastic")
        p_img = gr.Image(type="pil", label="Plastic Attention")
    with gr.Row():
        b_txt = gr.Text(label="Brittle")
        b_img = gr.Image(type="pil", label="Brittle Attention")

    # Event bindings
    predict1.click(fn=predict_all, inputs=smi_in,
                   outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])

    # Add an atom, then refresh the atom table, both bond dropdowns and the status.
    add_a.click(fn=add_atom, inputs=[state, atom_type], outputs=state)\
         .then(fn=update_atoms_list, inputs=state, outputs=atom_tbl)\
         .then(fn=update_atom_dropdowns, inputs=state, outputs=[a1, a2])\
         .then(fn=lambda: "Atom added.", outputs=status)

    # add_bond ignores the click when either endpoint is unselected; report
    # that instead of claiming success.
    add_b.click(fn=add_bond, inputs=[state, a1, a2, bond_type], outputs=state)\
         .then(fn=update_bonds_list, inputs=state, outputs=bond_tbl)\
         .then(fn=lambda x, y: "Bond added/updated." if x and y else "Select two atoms first.",
               inputs=[a1, a2], outputs=status)

    # Reset the state and every view derived from it. This includes the
    # generated SMILES and preview; otherwise "Predict on Built Molecule" would
    # keep predicting on the molecule that was just cleared.
    clear.click(fn=init_molecule, outputs=state)\
         .then(fn=lambda: ([], []), outputs=[atom_tbl, bond_tbl])\
         .then(fn=lambda: (gr.update(choices=[], value=None), gr.update(choices=[], value=None)),
               outputs=[a1, a2])\
         .then(fn=lambda: ("", None), outputs=[smi_out, mol_img])\
         .then(fn=lambda: "Cleared all.", outputs=status)

    # generate_smiles returns "" on failure; surface that in the status line.
    make.click(fn=generate_smiles, inputs=state, outputs=smi_out)\
        .then(fn=visualize_molecule, inputs=state, outputs=mol_img)\
        .then(fn=lambda s: "Molecule generated." if s else "Could not generate SMILES for the current molecule.",
              inputs=smi_out, outputs=status)

    predict2.click(fn=lambda s: predict_all(s) if s else ("Enter SMILES", None, "", None, "", None),
                   inputs=smi_out,
                   outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])

if __name__ == "__main__":
    # Listen on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)