Spaces:
Sleeping
Sleeping
| import os | |
| import zipfile | |
| import logging | |
| import torch | |
| import torch.nn.functional as F | |
| import numpy as np | |
| from io import BytesIO | |
| from PIL import Image | |
| import gradio as gr | |
| from torch_geometric.data import Data as PyGData | |
| import matplotlib | |
| matplotlib.use('Agg') | |
| from rdkit import Chem | |
| from rdkit.Chem import Draw, AllChem, MolFromSmiles | |
# ----------------------------
# Logging & GPU Configuration
# ----------------------------
# Root logging: INFO and above, each record stamped with time, logger name
# and level.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

# Allocator setting consumed by PyTorch's CUDA caching allocator.
# NOTE(review): this is set after `import torch`; presumably it is still read
# before the first CUDA allocation -- confirm against the PyTorch docs.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"
logger.info("Set GPU memory optimization: PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:128")
# ----------------------------
# Unzip Model Files if Needed
# ----------------------------
# The checkpoints are read from the "models/" directory (see the paths passed
# to torch.load in load_models), so the sentinel file is looked up there.
# Checking the bare filename in the working directory never matched an
# already-extracted "models/" tree, which forced a re-extraction (or a crash
# when models.zip was absent) on every start.
if not os.path.exists("models/best_model-B-6000-185.pth"):
    logger.info("Unzipping model archive...")
    try:
        with zipfile.ZipFile("models.zip", 'r') as z:
            z.extractall(".")
    except Exception as e:
        # Startup cannot continue without the checkpoints: log and re-raise.
        logger.error(f"Failed to unzip models.zip: {e}")
        raise
    else:
        logger.info("Model archive unzipped successfully.")
# ----------------------------
# Import Model Utilities
# ----------------------------
# model_utils is a project-local module; a failed import is logged and then
# re-raised because nothing below can work without these three names.
try:
    from model_utils import EnhancedGAT, smiles_to_graph, visualize_single_molecule
except ImportError as import_err:
    logger.error(f"Failed to import model_utils: {import_err}")
    raise
else:
    logger.info("Imported model_utils successfully.")
# ----------------------------
# Device Setup
# ----------------------------
# Prefer CUDA when PyTorch reports it as available, otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logger.info(f"Using device: {device}")
if device.type == "cuda":
    # Only device index 0 is reported.
    logger.info(f"GPU: {torch.cuda.get_device_name(0)}")
# ----------------------------
# Model Loading
# ----------------------------
def load_models():
    """Instantiate one EnhancedGAT per property and load its checkpoint.

    Each checkpoint is read from the "models/" directory. When a checkpoint
    is missing, "models.zip" (if present) is extracted once into the working
    directory and the path is checked again.

    Returns:
        dict: maps "Elastic", "Plastic" and "Brittle" to a model that has
        been switched to eval mode and moved to the module-level `device`.

    Raises:
        FileNotFoundError: a checkpoint is missing and could not be
            recovered from "models.zip".
    """
    # Allow-listing the numpy scalar only matters for weights_only loading
    # and the helper does not exist in older torch releases, so treat it as
    # best effort. An unconditional import here would make the TypeError
    # fallback below (meant for older torch) unreachable.
    try:
        from torch.serialization import add_safe_globals
        import numpy.core.multiarray
        add_safe_globals([numpy.core.multiarray.scalar])
    except (ImportError, AttributeError) as e:
        logger.info(f"Skipping add_safe_globals registration: {e}")

    # name -> (checkpoint path, number of output units)
    specs = {
        "Elastic": ("models/best_model-E-500-68.pth", 2),
        "Plastic": ("models/best_model-P-5000-180.pth", 2),
        "Brittle": ("models/best_model-B-6000-185.pth", 2),
    }
    models = {}
    archive_extracted = False
    for name, (path, out_dim) in specs.items():
        if not os.path.exists(path):
            # Extract the archive at most once per call.
            if not archive_extracted and os.path.exists("models.zip"):
                logger.info("Extracting models.zip...")
                with zipfile.ZipFile("models.zip", 'r') as z:
                    z.extractall(".")
                archive_extracted = True
            # Re-check: the archive may be absent or may not contain this file.
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing model file: {path}")

        model = EnhancedGAT(input_dim=12, hidden_dim=512, output_dim=out_dim, num_heads=8)

        # SECURITY: weights_only=False unpickles arbitrary Python objects.
        # Only load checkpoints that come from a trusted source.
        try:
            state = torch.load(path, map_location=device, weights_only=False)
        except TypeError:
            # Older torch versions do not accept the weights_only keyword.
            state = torch.load(path, map_location=device)

        # Checkpoints may wrap the weights under "model_state_dict" or be the
        # state dict itself.
        state_dict = state.get("model_state_dict", state)
        model.load_state_dict(state_dict)
        model.eval().to(device)
        models[name] = model
        logger.info(f"{name} model loaded successfully.")
    return models


models = load_models()
# ----------------------------
# Prediction Function
# ----------------------------
def predict_all(smiles: str):
    """Run the Elastic, Plastic and Brittle classifiers on one SMILES string.

    The positive-class probability is thresholded at 0.5 for Elastic and
    Brittle and at 0.3 for Plastic.

    Args:
        smiles: SMILES string of the molecule to classify.

    Returns:
        A flat 6-tuple ``(elastic_text, elastic_img, plastic_text,
        plastic_img, brittle_text, brittle_img)``. Each text is
        ``"<name>: <0|1>"`` and each image is a PIL image, or None when the
        visualizer returned no buffer. On blank input or on any failure every
        text is "Error: Invalid SMILES" and every image is None.
    """
    error_result = ("Error: Invalid SMILES", None) * 3

    # Reject blank input up front instead of handing it to the featurizer.
    if not smiles or not smiles.strip():
        return error_result
    smiles = smiles.strip()

    try:
        atom_feats, (rows, cols, edge_attr), _ = smiles_to_graph(smiles)
        x = torch.tensor(atom_feats, dtype=torch.float)
        edge_index = torch.tensor(np.vstack((rows, cols)), dtype=torch.long)
        edge_attr = torch.tensor(edge_attr, dtype=torch.float).unsqueeze(1)
        # Single-graph "batch": every node is assigned to graph 0.
        data = PyGData(x=x, edge_index=edge_index, edge_attr=edge_attr,
                       smiles=[smiles], batch=torch.zeros(x.size(0), dtype=torch.long))

        # The models live on `device`, so inference needs the graph there too.
        # A clone is moved so that `data` is handed to the visualizer exactly
        # as it was built.
        model_input = data.clone().to(device)

        outputs = []
        thresholds = {"Elastic": 0.5, "Plastic": 0.3, "Brittle": 0.5}
        for name in ["Elastic", "Plastic", "Brittle"]:
            model = models[name]
            with torch.no_grad():
                logits = model(model_input)
            if logits.dim() == 1 or logits.size(1) == 1:
                # Single-logit head: sigmoid gives the positive probability.
                prob = torch.sigmoid(logits).item()
            else:
                # Multi-logit head: column 1 is taken as the positive class.
                prob = F.softmax(logits, dim=1)[0, 1].item()
            label = int(prob >= thresholds[name])
            # The visualizer returns (buffer, extra); only the buffer is used.
            buf, _ = visualize_single_molecule(model, data, device, name)
            img = Image.open(buf) if buf else None
            outputs.append((f"{name}: {label}", img))
        # Flatten [(text, img)] * 3 into the 6 values the UI expects.
        return (*outputs[0], *outputs[1], *outputs[2])
    except Exception as e:
        # UI boundary: log with traceback and show a generic error per output.
        logger.exception(f"Prediction failed: {e}")
        return error_result
# ----------------------------
# Molecule Builder Utilities
# ----------------------------
# Element symbols offered in the "Atom Type" dropdown.
ATOM_TYPES = [
    "C", "N", "O", "S", "P",
    "F", "Cl", "Br", "I", "H",
]
# Bond orders offered in the "Bond Type" dropdown.
BOND_TYPES = [
    "Single",
    "Double",
    "Triple",
]
def init_molecule():
    """Return a fresh, empty molecule: no atoms and no bonds."""
    return dict(atoms=[], bonds=[])
def add_atom(mol, atom_type):
    """Append an atom of `atom_type` to `mol` in place and return `mol`.

    The new atom's id is the number of atoms already present.
    """
    atoms = mol["atoms"]
    next_id = len(atoms)
    atoms.append({"id": next_id, "type": atom_type})
    return mol
def add_bond(mol, a1_sel, a2_sel, b_type):
    """Add a bond between two selected atoms, mutating and returning `mol`.

    Args:
        mol: molecule dict with "atoms" and "bonds" lists.
        a1_sel, a2_sel: dropdown selections of the form "<id>: <type>".
        b_type: bond type label stored on the new bond.

    Returns:
        `mol`, unchanged when either selection is empty, both selections name
        the same atom, an id is not present in `mol["atoms"]`, or the pair is
        already bonded (in either direction).
    """
    if not a1_sel or not a2_sel:
        return mol
    # The atom id is the integer before the first colon of the label.
    i1, i2 = int(a1_sel.split(":")[0]), int(a2_sel.split(":")[0])

    # An atom cannot be bonded to itself.
    if i1 == i2:
        return mol

    # Ignore selections that do not refer to an existing atom.
    known_ids = {a["id"] for a in mol["atoms"]}
    if i1 not in known_ids or i2 not in known_ids:
        return mol

    # Bonds are undirected: compare endpoints as sets.
    pair = {i1, i2}
    if any({b["atom1"], b["atom2"]} == pair for b in mol["bonds"]):
        return mol

    mol["bonds"].append({"atom1": i1, "atom2": i2, "type": b_type})
    return mol
def generate_smiles(mol):
    """Convert the builder molecule into a SMILES string via RDKit.

    Returns the SMILES text, or "" when any step raises (the error is logged).
    """
    try:
        editable = Chem.RWMol()
        # Builder atom id -> index assigned by RDKit when the atom is added.
        rd_index = {
            spec["id"]: editable.AddAtom(Chem.Atom(spec["type"]))
            for spec in mol["atoms"]
        }
        bond_orders = {
            "Single": Chem.BondType.SINGLE,
            "Double": Chem.BondType.DOUBLE,
            "Triple": Chem.BondType.TRIPLE,
        }
        for bond in mol["bonds"]:
            begin = rd_index[bond["atom1"]]
            end = rd_index[bond["atom2"]]
            editable.AddBond(begin, end, bond_orders[bond["type"]])
        editable.UpdatePropertyCache()
        Chem.SanitizeMol(editable)
        return Chem.MolToSmiles(editable)
    except Exception as e:
        logger.error(f"SMILES generation failed: {e}")
        return ""
def visualize_molecule(mol):
    """Render the builder molecule as a 300x300 image.

    Returns None when no SMILES could be generated or RDKit cannot parse it.
    """
    smiles = generate_smiles(mol)
    parsed = MolFromSmiles(smiles) if smiles else None
    if parsed is None:
        return None
    AllChem.Compute2DCoords(parsed)
    return Draw.MolToImage(parsed, size=(300, 300))
def update_atom_dropdowns(mol):
    """Build updates for the two atom dropdowns.

    Both dropdowns get the same "<id>: <type>" choices and a cleared value.
    """
    labels = []
    for atom in mol["atoms"]:
        labels.append(f"{atom['id']}: {atom['type']}")
    first = gr.update(choices=labels, value=None)
    second = gr.update(choices=labels, value=None)
    return first, second
def update_atoms_list(mol):
    """Return the atoms as [id, type] rows for the atom table."""
    rows = []
    for atom in mol["atoms"]:
        rows.append([atom["id"], atom["type"]])
    return rows
def update_bonds_list(mol):
    """Return the bonds as ["<id>: <type>", "<id>: <type>", bond_type] rows.

    Atom types are resolved through a lookup table built once, rather than a
    scan of the atom list per bond endpoint. An endpoint whose id is not in
    `mol["atoms"]` is shown with type "?" instead of raising StopIteration.
    """
    type_by_id = {}
    for atom in mol["atoms"]:
        # setdefault keeps the first atom seen for an id.
        type_by_id.setdefault(atom["id"], atom["type"])

    def label(atom_id):
        return f"{atom_id}: {type_by_id.get(atom_id, '?')}"

    return [[label(b["atom1"]), label(b["atom2"]), b["type"]] for b in mol["bonds"]]
# ----------------------------
# Gradio Interface
# ----------------------------
# NOTE(review): the widget nesting below is reconstructed from a copy of the
# file whose indentation was lost -- confirm the layout against the running app.
with gr.Blocks(title="CrystalGAT", css="""
.gradio-container {max-width:800px; margin:auto}
.gr-button {margin:0.2em}
""") as demo:
    gr.Markdown("## CrystalGAT \nEnter a SMILES string or build a molecule to predict Elastic, Plastic, and Brittle classes with attention visualization.")

    # Tab 1: type a SMILES string directly.
    with gr.Tab("SMILES Input"):
        smi_in = gr.Textbox(label="SMILES", placeholder="e.g. CCO")
        predict1 = gr.Button("Predict")

    # Tab 2: assemble a molecule atom by atom, then convert it to SMILES.
    with gr.Tab("Manual Molecule Construction"):
        state = gr.State(init_molecule())
        status = gr.Textbox(label="Status", interactive=False, value="Start by adding atoms")
        with gr.Row():
            with gr.Column():
                atom_type = gr.Dropdown(label="Atom Type", choices=ATOM_TYPES, value="C")
                add_a = gr.Button("Add Atom")
                atom_tbl = gr.Dataframe(headers=["ID","Type"], datatype=["number","str"], interactive=False)
            with gr.Column():
                a1 = gr.Dropdown(label="Atom 1", choices=[], value=None)
                a2 = gr.Dropdown(label="Atom 2", choices=[], value=None)
                bond_type = gr.Dropdown(label="Bond Type", choices=BOND_TYPES, value="Single")
                add_b = gr.Button("Add Bond")
                bond_tbl = gr.Dataframe(headers=["Atom1","Atom2","Type"], datatype=["str","str","str"], interactive=False)
        with gr.Row():
            clear = gr.Button("Clear All")
            make = gr.Button("Generate SMILES")
        smi_out = gr.Textbox(label="SMILES Output", interactive=False)
        mol_img = gr.Image(type="pil", label="Molecule Preview")
        predict2 = gr.Button("Predict on Built Molecule")

    # Outputs
    with gr.Row():
        e_txt = gr.Text(label="Elastic")
        e_img = gr.Image(type="pil", label="Elastic Attention")
    with gr.Row():
        p_txt = gr.Text(label="Plastic")
        p_img = gr.Image(type="pil", label="Plastic Attention")
    with gr.Row():
        b_txt = gr.Text(label="Brittle")
        b_img = gr.Image(type="pil", label="Brittle Attention")

    # Event bindings
    predict1.click(fn=predict_all, inputs=smi_in,
                   outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])

    # Add atom: update state, then refresh the table and both atom dropdowns.
    (
        add_a.click(fn=add_atom, inputs=[state, atom_type], outputs=state)
        .then(fn=update_atoms_list, inputs=state, outputs=atom_tbl)
        .then(fn=update_atom_dropdowns, inputs=state, outputs=[a1, a2])
        .then(fn=lambda: "Atom added.", outputs=status)
    )

    # Add bond: update state, then refresh the bond table.
    (
        add_b.click(fn=add_bond, inputs=[state, a1, a2, bond_type], outputs=state)
        .then(fn=update_bonds_list, inputs=state, outputs=bond_tbl)
        .then(fn=lambda: "Bond added/updated.", outputs=status)
    )

    # Clear: reset state, tables and dropdowns. The generated SMILES and the
    # preview are reset as well, so "Predict on Built Molecule" cannot run on
    # a molecule that has just been cleared.
    (
        clear.click(fn=init_molecule, outputs=state)
        .then(fn=lambda: ([], []), outputs=[atom_tbl, bond_tbl])
        .then(fn=lambda: (gr.update(choices=[], value=None), gr.update(choices=[], value=None)),
              outputs=[a1, a2])
        .then(fn=lambda: ("", None), outputs=[smi_out, mol_img])
        .then(fn=lambda: "Cleared all.", outputs=status)
    )

    # Generate: build the SMILES and the preview; the status reflects whether
    # a SMILES string was actually produced (generate_smiles returns "" on
    # failure).
    (
        make.click(fn=generate_smiles, inputs=state, outputs=smi_out)
        .then(fn=visualize_molecule, inputs=state, outputs=mol_img)
        .then(fn=lambda s: "Molecule generated." if s
              else "Could not generate a valid molecule; check atoms and bonds.",
              inputs=smi_out, outputs=status)
    )

    predict2.click(fn=lambda s: predict_all(s) if s else ("Enter SMILES", None, "", None, "", None),
                   inputs=smi_out,
                   outputs=[e_txt, e_img, p_txt, p_img, b_txt, b_img])
if __name__ == "__main__":
    # Serve on all network interfaces, port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    demo.launch(**launch_options)