import os
import zipfile
import logging
import torch
import torch.nn.functional as F
import numpy as np
from io import BytesIO
from PIL import Image
import gradio as gr
from torch_geometric.data import Data as PyGData
import matplotlib

# Non-interactive raster backend; selected before RDKit's drawing code is imported.
matplotlib.use('Agg')
from rdkit import Chem
from rdkit.Chem import Draw, AllChem, MolFromSmiles

# ----------------------------
# Logging & GPU Configuration
# ----------------------------
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

# Read by PyTorch's CUDA caching allocator, so it must be set before any CUDA allocation.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
logger.info("Set GPU memory optimization: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")

# ----------------------------
# Model Files
# ----------------------------
MODEL_ARCHIVE = "models.zip"
# property name -> (checkpoint path, output_dim passed to EnhancedGAT)
MODEL_SPECS = {
    "Elastic": ("models/best_model-E-500-68.pth", 2),
    "Plastic": ("models/best_model-P-5000-180.pth", 2),
    "Brittle": ("models/best_model-B-6000-185.pth", 2),
}


def _extract_model_archive():
    """Extract MODEL_ARCHIVE into the current working directory."""
    logger.info("Unzipping model archive...")
    with zipfile.ZipFile(MODEL_ARCHIVE, 'r') as archive:
        archive.extractall(".")
    logger.info("Model archive unzipped successfully.")


def _missing_model_files():
    """Return the checkpoint paths listed in MODEL_SPECS that are not on disk."""
    return [path for path, _ in MODEL_SPECS.values() if not os.path.exists(path)]


# ----------------------------
# Unzip Model Files if Needed
# ----------------------------
# Check the same paths load_models() reads (the archive is expected to unpack
# into ./models/). Previously a root-level file was checked, which forced a
# re-extract on every start and crashed when models.zip was absent even though
# ./models/ existed. If both are missing, load_models() raises a clear error.
if _missing_model_files() and os.path.exists(MODEL_ARCHIVE):
    try:
        _extract_model_archive()
    except Exception:
        logger.exception(f"Failed to unzip {MODEL_ARCHIVE}")
        raise

# ----------------------------
# Import Model Utilities
# ----------------------------
try:
    from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
    logger.info("Imported model_utils successfully.")
except ImportError as e:
    logger.error(f"Failed to import model_utils: {e}")
    raise

# ----------------------------
# Device Setup
# ----------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
logger.info(f"Using device: {device}")
if torch.cuda.is_available():
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")


# ----------------------------
# Model Loading
# ----------------------------
def _register_numpy_safe_globals():
    """Allow-list numpy scalars for weights_only=True torch.load, when supported.

    Only relevant for weights_only=True loads; load_models() passes
    weights_only=False, so this is purely defensive. It is a no-op on torch
    builds without torch.serialization.add_safe_globals.
    """
    try:
        from torch.serialization import add_safe_globals
    except ImportError:
        return
    try:
        from numpy._core.multiarray import scalar as np_scalar  # NumPy >= 2
    except ImportError:
        from numpy.core.multiarray import scalar as np_scalar  # NumPy 1.x
    add_safe_globals([np_scalar])


def load_models():
    """Load the Elastic, Plastic and Brittle EnhancedGAT classifiers.

    Missing checkpoints are first extracted from MODEL_ARCHIVE when it exists.

    Returns:
        dict mapping property name -> model in eval mode on ``device``.

    Raises:
        FileNotFoundError: if a checkpoint is still missing after extraction.
    """
    _register_numpy_safe_globals()

    loaded = {}
    for name, (path, out_dim) in MODEL_SPECS.items():
        if not os.path.exists(path) and os.path.exists(MODEL_ARCHIVE):
            logger.info("Extracting models.zip...")
            _extract_model_archive()
        if not os.path.exists(path):
            raise FileNotFoundError(f"Missing model file: {path}")

        model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)
        try:
            # weights_only=False unpickles arbitrary objects: only load trusted checkpoints.
            checkpoint = torch.load(path, map_location=device, weights_only=False)
        except TypeError:  # older torch without the weights_only argument
            checkpoint = torch.load(path, map_location=device)
        # Checkpoints may wrap the weights under "model_state_dict" or be a bare state dict.
        state_dict = checkpoint.get("model_state_dict", checkpoint)
        model.load_state_dict(state_dict)
        model.eval().to(device)
        loaded[name] = model
        logger.info(f"{name} model loaded successfully.")
    return loaded


models = load_models()

# ----------------------------
# Prediction Function
# ----------------------------
PROPERTY_NAMES = ("Elastic", "Plastic", "Brittle")
# Decision threshold applied to each model's positive-class probability.
THRESHOLDS = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}


def predict_all(smiles: str):
    """Run the Elastic, Plastic and Brittle classifiers on one SMILES string.

    A property is labelled 1 when its positive-class probability is at least
    THRESHOLDS[name] (0.5 for Elastic/Brittle, 0.3 for Plastic), else 0.

    Args:
        smiles: SMILES string; surrounding whitespace is ignored.

    Returns:
        Six values: (Elastic text, Elastic image, Plastic text, Plastic image,
        Brittle text, Brittle image). Texts look like "Elastic: 1"; images are
        PIL images from visualize_single_molecule, or None. Empty input yields
        ("Enter SMILES", None, "", None, "", None). Input that smiles_to_graph
        cannot featurize yields an "Invalid SMILES: ..." message in the first slot.
    """
    smiles = (smiles or "").strip()
    if not smiles:
        return ("Enter SMILES", None, "", None, "", None)

    try:
        atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
    except Exception:
        # UI boundary: report bad input to the user instead of surfacing a traceback.
        logger.exception(f"Failed to featurize SMILES {smiles!r}")
        return (f"Invalid SMILES: {smiles}", None, "", None, "", None)

    x = torch.tensor(atom_feats, dtype=torch.float)
    edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
    edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
    data = PyGData(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        smiles=[smiles],
        batch=torch.zeros(x.size(0), dtype=torch.long),  # one graph: every node in batch 0
    )
    # load_models() put the models on `device`, so the forward pass needs the
    # graph there too. A copy is used so visualize_single_molecule still gets
    # the original CPU graph (it receives `device` and presumably handles
    # placement itself — TODO confirm).
    model_input = data.clone().to(device)

    outputs = []
    for name in PROPERTY_NAMES:
        model = models[name]
        with torch.no_grad():
            logits = model(model_input)
        if logits.dim() == 1 or logits.size(1) == 1:
            prob = torch.sigmoid(logits).item()  # single-logit binary head
        else:
            prob = F.softmax(logits, dim=1)[0, 1].item()  # probability of class 1
        label = int(prob >= THRESHOLDS[name])

        buf, _ = visualize_single_molecule(model, data, device, name)
        img = Image.open(buf) if buf else None
        outputs.extend((f"{name}: {label}", img))
    return tuple(outputs)


# ----------------------------
# Molecule Builder Utilities
# ----------------------------
ATOM_TYPES = ["C", "N", "O", "S", "P", "F", "Cl", "Br", "I", "H"]
BOND_TYPES = ["Single", "Double", "Triple"]
BOND_MAP = {
    "Single": Chem.BondType.SINGLE,
    "Double": Chem.BondType.DOUBLE,
    "Triple": Chem.BondType.TRIPLE,
}


def init_molecule():
    """Return an empty builder state: {"atoms": [], "bonds": []}."""
    return {"atoms": [], "bonds": []}


def add_atom(mol, atom_type):
    """Append an atom of ``atom_type`` with the next sequential id; mutates and returns ``mol``."""
    mol["atoms"].append({"id": len(mol["atoms"]), "type": atom_type})
    return mol


def add_bond(mol, a1_sel, a2_sel, b_type):
    """Add a bond between two selected atoms; mutates and returns ``mol``.

    Args:
        mol: builder state from init_molecule().
        a1_sel, a2_sel: dropdown values formatted "<id>: <type>" (as produced by
            update_atom_dropdowns); an empty selection leaves ``mol`` unchanged.
        b_type: one of BOND_TYPES.

    Self-bonds and bonds between an already-bonded pair are ignored.
    """
    if not a1_sel or not a2_sel:
        return mol
    i1, i2 = int(a1_sel.split(":")[0]), int(a2_sel.split(":")[0])
    if i1 == i2:
        # RDKit's AddBond rejects self-bonds, which would make generate_smiles
        # fail for this molecule until it is cleared.
        return mol
    new_pair = {i1, i2}
    if any({b["atom1"], b["atom2"]} == new_pair for b in mol["bonds"]):
        return mol
    mol["bonds"].append({"atom1": i1, "atom2": i2, "type": b_type})
    return mol


def generate_smiles(mol):
    """Build an RDKit molecule from the builder state and return its SMILES.

    Returns "" (and logs the error) when RDKit cannot build or sanitize the
    structure, e.g. an atom exceeding its allowed valence.
    """
    try:
        rw = Chem.RWMol()
        id_map = {}
        for atom in mol["atoms"]:
            id_map[atom["id"]] = rw.AddAtom(Chem.Atom(atom["type"]))
        for b in mol["bonds"]:
            rw.AddBond(id_map[b["atom1"]], id_map[b["atom2"]], BOND_MAP[b["type"]])
        rw.UpdatePropertyCache()
        Chem.SanitizeMol(rw)
        return Chem.MolToSmiles(rw)
    except Exception as e:
        logger.error(f"SMILES generation failed: {e}")
        return ""


def visualize_molecule(mol):
    """Return a 300x300 PIL depiction of the builder state, or None if it is not valid."""
    smiles = generate_smiles(mol)
    if not smiles:
        return None
    m = MolFromSmiles(smiles)
    if m is None:
        return None
    AllChem.Compute2DCoords(m)
    return Draw.MolToImage(m, size=(300, 300))


def update_atom_dropdowns(mol):
    """Refresh both atom-selection dropdowns with "<id>: <type>" choices and clear their values."""
    choices = [f"{a['id']}: {a['type']}" for a in mol["atoms"]]
    return gr.update(choices=choices, value=None), gr.update(choices=choices, value=None)


def update_atoms_list(mol):
    """Return [[id, type], ...] rows for the atoms table."""
    return [[a["id"], a["type"]] for a in mol["atoms"]]


def update_bonds_list(mol):
    """Return ["<id>: <type>", "<id>: <type>", bond type] rows for the bonds table."""
    type_by_id = {a["id"]: a["type"] for a in mol["atoms"]}
    return [
        [f"{b['atom1']}: {type_by_id[b['atom1']]}",
         f"{b['atom2']}: {type_by_id[b['atom2']]}",
         b["type"]]
        for b in mol["bonds"]
    ]


# ----------------------------
# Gradio Interface
# ----------------------------
with gr.Blocks(title="CrystalGAT", css="""
.gradio-container {max-width:800px; margin:auto}
.gr-button {margin:0.2em}
""") as demo:
    gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")

    with gr.Tab("SMILES Input"):
        smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
        predict1 = gr.Button("Predict")

    with gr.Tab("Manual Molecule Construction"):
        state = gr.State(init_molecule())
        status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")
        with gr.Row():
            with gr.Column():
                atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
                add_a = gr.Button("Add Atom")
                atom_tbl = gr.Dataframe(headers=["ID", "Type"], datatype=["number", "str"], interactive=False)
            with gr.Column():
                a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
                a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
                bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
                add_b = gr.Button("Add Bond")
                bond_tbl = gr.Dataframe(headers=["Atom1", "Atom2", "Type"], datatype=["str", "str", "str"], interactive=False)
        with gr.Row():
            clear = gr.Button("Clear All")
            make = gr.Button("Generate SMILES")
        smi_out = gr.Textbox(label="SMILES Output", interactive=False)
        mol_img = gr.Image(type="pil", label="Molecule Preview")
        predict2 = gr.Button("Predict on Built Molecule")

    # Outputs
    with gr.Row():
        e_txt = gr.Text(label="Elastic")
        e_img = gr.Image(type="pil", label="Elastic Attention")
    with gr.Row():
        p_txt = gr.Text(label="Plastic")
        p_img = gr.Image(type="pil", label="Plastic Attention")
    with gr.Row():
        b_txt = gr.Text(label="Brittle")
        b_img = gr.Image(type="pil", label="Brittle Attention")

    prediction_outputs = [e_txt, e_img, p_txt, p_img, b_txt, b_img]

    # Event bindings
    predict1.click(fn=predict_all, inputs=smi_in, outputs=prediction_outputs)

    (
        add_a.click(fn=add_atom, inputs=[state, atom_type], outputs=state)
        .then(fn=update_atoms_list, inputs=state, outputs=atom_tbl)
        .then(fn=update_atom_dropdowns, inputs=state, outputs=[a1, a2])
        .then(fn=lambda: "Atom added.", outputs=status)
    )

    (
        add_b.click(fn=add_bond, inputs=[state, a1, a2, bond_type], outputs=state)
        .then(fn=update_bonds_list, inputs=state, outputs=bond_tbl)
        .then(fn=lambda: "Bond added/updated.", outputs=status)
    )

    (
        clear.click(fn=init_molecule, outputs=state)
        .then(fn=lambda: ([], []), outputs=[atom_tbl, bond_tbl])
        .then(fn=lambda: (gr.update(choices=[], value=None), gr.update(choices=[], value=None)),
              outputs=[a1, a2])
        # Drop the old SMILES/preview so "Predict on Built Molecule" cannot run
        # on a molecule that was just cleared.
        .then(fn=lambda: ("", None), outputs=[smi_out, mol_img])
        .then(fn=lambda: "Cleared all.", outputs=status)
    )

    (
        make.click(fn=generate_smiles, inputs=state, outputs=smi_out)
        .then(fn=visualize_molecule, inputs=state, outputs=mol_img)
        .then(fn=lambda: "Molecule generated.", outputs=status)
    )

    # predict_all itself answers empty input with ("Enter SMILES", None, "", None, "", None).
    predict2.click(fn=predict_all, inputs=smi_out, outputs=prediction_outputs)

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)