| import os |
| import zipfile |
| import logging |
| import torch |
| import torch.nn.functional as F |
| import numpy as np |
| from io import BytesIO |
| from PIL import Image |
|
|
| import gradio as gr |
| from torch_geometric.data import Data as PyGData |
| import matplotlib |
| matplotlib.use('Agg') |
|
|
| from rdkit import Chem |
| from rdkit.Chem import Draw, AllChem, MolFromSmiles |
|
|
| |
| |
| |
# Configure root logging once for the whole app: INFO and above, with each
# record rendered as "time - logger name - level - message".
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)
|
|
# Limit block splitting in PyTorch's CUDA caching allocator
# (max_split_size_mb, documented by PyTorch as a fragmentation mitigation).
# The variable must be set before the first CUDA allocation.
# A value already provided by the deployment environment is respected rather
# than overwritten, and the effective value is logged.
os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "max_split_size_mb:128")
logger.info(
    f"GPU memory optimization: PYTORCH_CUDA_ALLOC_CONF="
    f"{os.environ['PYTORCH_CUDA_ALLOC_CONF']}"
)
|
|
| |
| |
| |
# Make sure the model checkpoints are on disk before load_models() runs.
# Checkpoints are loaded from the models/ directory (see the paths in
# load_models), so probe that location. The previous probe of the working
# directory never matched, which forced a re-extract on every start.
# Extraction only happens when the archive is actually present. If both the
# checkpoint and the archive are missing, load_models() reports the missing
# file itself.
if not os.path.exists("models/best_model-B-6000-185.pth") and os.path.exists("models.zip"):
    logger.info("Unzipping model archive...")
    try:
        with zipfile.ZipFile("models.zip", 'r') as z:
            z.extractall(".")
        logger.info("Model archive unzipped successfully.")
    except Exception as e:
        # Log for the operator, then fail startup: the app cannot run
        # without its models.
        logger.error(f"Failed to unzip models.zip: {e}")
        raise
|
|
| |
| |
| |
# Project-local model code: the GAT architecture, the SMILES -> graph
# featurizer and the attention renderer. A failed import is logged and then
# re-raised so the app does not start half-initialized.
try:
    from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
except ImportError as exc:
    logger.error(f"Failed to import model_utils: {exc}")
    raise
else:
    logger.info("Imported model_utils successfully.")
|
|
| |
| |
| |
# Run on the first CUDA GPU when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
logger.info(f"Using device: {device}")
if device.type == "cuda":
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
|
|
| |
| |
| |
def load_models():
    """Build one EnhancedGAT classifier per property and load its weights.

    Checkpoints are read from ``models/``. If one is missing and
    ``models.zip`` exists in the working directory, the archive is extracted
    in place first.

    Returns:
        dict mapping "Elastic", "Plastic" and "Brittle" to a model that is in
        eval mode and placed on ``device``.

    Raises:
        FileNotFoundError: if a checkpoint is missing and cannot be obtained
            from ``models.zip``.
    """
    # add_safe_globals exists only in newer torch (2.4+) and only affects
    # weights_only=True loads. Import it defensively: an unconditional import
    # would raise ImportError on old torch before the weights_only TypeError
    # fallback below could ever be reached.
    try:
        from torch.serialization import add_safe_globals
    except ImportError:
        add_safe_globals = None
    if add_safe_globals is not None:
        import numpy.core.multiarray
        add_safe_globals([numpy.core.multiarray.scalar])

    # property name -> (checkpoint path, number of output logits)
    specs = {
        "Elastic": ("models/best_model-E-500-68.pth", 2),
        "Plastic": ("models/best_model-P-5000-180.pth", 2),
        "Brittle": ("models/best_model-B-6000-185.pth", 2),
    }

    models = {}
    for name, (path, out_dim) in specs.items():
        if not os.path.exists(path):
            if not os.path.exists("models.zip"):
                raise FileNotFoundError(f"Missing model file: {path}")
            logger.info("Extracting models.zip...")
            with zipfile.ZipFile("models.zip", 'r') as z:
                z.extractall(".")
            # The archive may not contain this checkpoint. Fail with a clear
            # message instead of an opaque torch.load error.
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing model file: {path}")

        model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)

        # SECURITY: weights_only=False uses the full pickle loader, which can
        # execute arbitrary code. Only load checkpoints from a trusted source.
        try:
            state = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # Older torch (< 1.13) has no `weights_only` keyword.
            state = torch.load(path, map_location=device)

        # A checkpoint is either a bare state_dict or a training dict that
        # wraps it under "model_state_dict".
        state_dict = state.get("model_state_dict", state)
        model.load_state_dict(state_dict)
        model.eval().to(device)
        models[name] = model
        logger.info(f"{name} model loaded successfully.")

    return models


models = load_models()
|
|
| |
| |
| |
def predict_all(smiles: str):
    """Run the Elastic, Plastic and Brittle classifiers on one molecule.

    For each property, the model's positive-class probability is compared
    with a per-property threshold to give a 0/1 label. The thresholds are
    0.5 for Elastic/Brittle and 0.3 for Plastic. An attention visualization
    is also rendered for each property.

    Args:
        smiles: SMILES string of the molecule. Surrounding whitespace is
            ignored.

    Returns:
        A flat 6-tuple ``(elastic_text, elastic_img, plastic_text, plastic_img,
        brittle_text, brittle_img)``. Each text is ``"<Name>: <label>"`` and
        each image is a PIL image or None. Blank or unparsable input returns
        a message in the first slot and empty values elsewhere, so the UI
        never receives an exception.
    """
    smiles = (smiles or "").strip()
    if not smiles:
        return ("Enter SMILES", None, "", None, "", None)
    # Validate with RDKit up front so bad input produces a message rather
    # than an exception from the graph featurizer.
    if MolFromSmiles(smiles) is None:
        return (f"Invalid SMILES: {smiles}", None, "", None, "", None)

    atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
    x = torch.tensor(atom_feats, dtype=torch.float)
    edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    # A single graph: every node belongs to batch 0. The models were moved to
    # `device` in load_models, so the inputs must live there too, otherwise
    # a CUDA model receives CPU tensors.
    data = PyGData(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        smiles=[smiles],
        batch=torch.zeros(x.size(0), dtype=torch.long),
    ).to(device)

    thresholds = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}
    outputs = []
    for name, threshold in thresholds.items():
        model = models[name]
        with torch.no_grad():
            logits = model(data)
        # A single-logit head is read through a sigmoid. A multi-logit head
        # uses softmax, and column 1 of the (only) row is the positive class.
        if logits.dim() == 1 or logits.size(1) == 1:
            prob = torch.sigmoid(logits).item()
        else:
            prob = F.softmax(logits, dim=1)[0, 1].item()
        label = int(prob >= threshold)

        buf, _ = visualize_single_molecule(model, data, device, name)
        img = Image.open(buf) if buf else None
        outputs.append((f"{name}: {label}", img))

    # Flatten [(text, img), ...] into the 6 values the Gradio outputs expect.
    return tuple(value for pair in outputs for value in pair)
|
|
| |
| |
| |
# Element symbols offered by the "Atom Type" dropdown of the molecule builder.
ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
# Bond labels offered by the "Bond Type" dropdown. generate_smiles maps these
# exact strings to RDKit bond types.
BOND_TYPES = ["Single", "Double", "Triple"]
|
|
def init_molecule():
    """Return a new, empty molecule record for the builder.

    The record is a dict with an ``"atoms"`` list and a ``"bonds"`` list.
    Fresh lists are created on every call, so records never share state.
    """
    return dict(atoms=[], bonds=[])
|
|
def add_atom(mol, atom_type):
    """Append an atom of element ``atom_type`` to ``mol`` in place.

    The new atom's id is the current atom count, so ids run 0, 1, 2, ...
    in insertion order. The same ``mol`` object is returned.
    """
    atom_list = mol["atoms"]
    next_id = len(atom_list)
    atom_list.append({"id": next_id, "type": atom_type})
    return mol
|
|
def add_bond(mol, a1_sel, a2_sel, b_type):
    """Add (or retype) a bond between two selected atoms of ``mol``, in place.

    Args:
        mol: molecule dict with "atoms" and "bonds" lists.
        a1_sel, a2_sel: dropdown selections formatted as "<id>: <type>". The
            integer before the colon is the atom id.
        b_type: bond label from the "Bond Type" dropdown ("Single",
            "Double" or "Triple").

    Returns:
        ``mol`` itself. It is unchanged when a selection is empty, or when
        both selections name the same atom: RDKit refuses self-bonds, and
        storing one would make every later SMILES generation fail. If the
        pair is already bonded, the existing bond's type is replaced by
        ``b_type`` instead of the request being silently ignored.
    """
    if not a1_sel or not a2_sel:
        return mol
    i1, i2 = int(a1_sel.split(":")[0]), int(a2_sel.split(":")[0])
    if i1 == i2:
        return mol
    pair = {i1, i2}
    for bond in mol["bonds"]:
        if {bond["atom1"], bond["atom2"]} == pair:
            bond["type"] = b_type
            return mol
    mol["bonds"].append({"atom1": i1, "atom2": i2, "type": b_type})
    return mol
|
|
def generate_smiles(mol):
    """Convert the builder's molecule dict into a canonical SMILES string.

    Atoms are added in list order, and each bond connects the RDKit indices
    of its two endpoint ids. The molecule is then sanitized before the
    SMILES is written. Any failure is logged and reported as an empty
    string, for example an unknown element, an invalid valence or a bond to
    a missing atom id.
    """
    try:
        editable = Chem.RWMol()
        rdkit_index = {}
        for entry in mol["atoms"]:
            rdkit_index[entry["id"]] = editable.AddAtom(Chem.Atom(entry["type"]))

        order_for = {
            "Single": Chem.BondType.SINGLE,
            "Double": Chem.BondType.DOUBLE,
            "Triple": Chem.BondType.TRIPLE,
        }
        for link in mol["bonds"]:
            start = rdkit_index[link["atom1"]]
            end = rdkit_index[link["atom2"]]
            editable.AddBond(start, end, order_for[link["type"]])

        editable.UpdatePropertyCache()
        Chem.SanitizeMol(editable)
        return Chem.MolToSmiles(editable)
    except Exception as e:
        logger.error(f"SMILES generation failed: {e}")
        return ""
|
|
def visualize_molecule(mol):
    """Render the builder's molecule as a 300x300 PIL image.

    Returns None when no SMILES can be generated for ``mol`` or RDKit cannot
    parse the result.
    """
    smiles = generate_smiles(mol)
    parsed = MolFromSmiles(smiles) if smiles else None
    if parsed is None:
        return None
    AllChem.Compute2DCoords(parsed)
    return Draw.MolToImage(parsed, size=(300, 300))
|
|
def update_atom_dropdowns(mol):
    """Refresh both bond-endpoint dropdowns with the molecule's atoms.

    Each choice reads "<id>: <type>", and both dropdowns are reset to no
    selection. Returns one Gradio update per dropdown.
    """
    labels = [f"{entry['id']}: {entry['type']}" for entry in mol["atoms"]]
    first, second = (gr.update(choices=labels, value=None) for _ in range(2))
    return first, second
|
|
def update_atoms_list(mol):
    """Return the atom table rows as ``[id, type]`` lists, in atom order."""
    rows = []
    for entry in mol["atoms"]:
        rows.append([entry["id"], entry["type"]])
    return rows
|
|
def update_bonds_list(mol):
    """Return the bond table rows ``["<id>: <type>", "<id>: <type>", bond_type]``.

    Each endpoint is labelled with its atom id and element, which is looked
    up in ``mol["atoms"]``. An id with no matching atom raises StopIteration.
    """
    atoms = mol["atoms"]

    def endpoint_label(atom_id):
        element = next(a["type"] for a in atoms if a["id"] == atom_id)
        return f"{atom_id}: {element}"

    return [
        [endpoint_label(link["atom1"]), endpoint_label(link["atom2"]), link["type"]]
        for link in mol["bonds"]
    ]
|
|
| |
| |
| |
def _predict_or_prompt(smiles):
    """Return predictions for ``smiles``, or a prompt tuple when it is blank.

    Both Predict buttons use this, so a blank input always yields the same
    6-value (text, image) x 3 shape instead of an exception.
    """
    if not smiles or not smiles.strip():
        return ("Enter SMILES", None, "", None, "", None)
    return predict_all(smiles)


def _bond_status(atom1_sel, atom2_sel):
    """Status text after "Add Bond". add_bond does nothing without two selections."""
    if not atom1_sel or not atom2_sel:
        return "Select two atoms first."
    return "Bond added/updated."


with gr.Blocks(title="CrystalGAT", css="""
.gradio-container {max-width:800px; margin:auto}
.gr-button {margin:0.2em}
""") as demo:

    gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")

    # Tab 1: predict directly from a typed SMILES string.
    with gr.Tab("SMILES Input"):
        smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
        predict1 = gr.Button("Predict")

    # Tab 2: build a molecule atom by atom, then predict on its SMILES.
    with gr.Tab("Manual Molecule Construction"):
        # Gradio state holding the molecule dict created by init_molecule().
        state = gr.State(init_molecule())
        status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")

        with gr.Row():
            with gr.Column():
                atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
                add_a = gr.Button("Add Atom")
                atom_tbl = gr.Dataframe(headers=["ID", "Type"], datatype=["number", "str"], interactive=False)
            with gr.Column():
                a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
                a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
                bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
                add_b = gr.Button("Add Bond")
                bond_tbl = gr.Dataframe(headers=["Atom1", "Atom2", "Type"], datatype=["str", "str", "str"], interactive=False)

        with gr.Row():
            clear = gr.Button("Clear All")
            make = gr.Button("Generate SMILES")
        smi_out = gr.Textbox(label="SMILES Output", interactive=False)
        mol_img = gr.Image(type="pil", label="Molecule Preview")

        predict2 = gr.Button("Predict on Built Molecule")

    # Result panels, written by both Predict buttons.
    with gr.Row():
        e_txt = gr.Text(label="Elastic")
        e_img = gr.Image(type="pil", label="Elastic Attention")
    with gr.Row():
        p_txt = gr.Text(label="Plastic")
        p_img = gr.Image(type="pil", label="Plastic Attention")
    with gr.Row():
        b_txt = gr.Text(label="Brittle")
        b_img = gr.Image(type="pil", label="Brittle Attention")

    result_outputs = [e_txt, e_img, p_txt, p_img, b_txt, b_img]

    # A blank textbox gets a prompt instead of an exception from predict_all.
    predict1.click(fn=_predict_or_prompt, inputs=smi_in, outputs=result_outputs)

    (add_a.click(fn=add_atom, inputs=[state, atom_type], outputs=state)
        .then(fn=update_atoms_list, inputs=state, outputs=atom_tbl)
        .then(fn=update_atom_dropdowns, inputs=state, outputs=[a1, a2])
        .then(fn=lambda: "Atom added.", outputs=status))

    (add_b.click(fn=add_bond, inputs=[state, a1, a2, bond_type], outputs=state)
        .then(fn=update_bonds_list, inputs=state, outputs=bond_tbl)
        .then(fn=_bond_status, inputs=[a1, a2], outputs=status))

    # Also reset the generated SMILES and preview. Otherwise "Predict on
    # Built Molecule" would still predict the molecule that was just cleared.
    (clear.click(fn=init_molecule, outputs=state)
        .then(fn=lambda: ([], []), outputs=[atom_tbl, bond_tbl])
        .then(fn=lambda: (gr.update(choices=[], value=None), gr.update(choices=[], value=None)),
              outputs=[a1, a2])
        .then(fn=lambda: ("", None), outputs=[smi_out, mol_img])
        .then(fn=lambda: "Cleared all.", outputs=status))

    # generate_smiles returns "" on failure, so report that honestly.
    (make.click(fn=generate_smiles, inputs=state, outputs=smi_out)
        .then(fn=visualize_molecule, inputs=state, outputs=mol_img)
        .then(fn=lambda smi: "Molecule generated." if smi else "SMILES generation failed.",
              inputs=smi_out, outputs=status))

    predict2.click(fn=_predict_or_prompt, inputs=smi_out, outputs=result_outputs)
|
|
if __name__ == "__main__":
    # Serve the UI on every network interface, port 7860.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
    )
|
|