Spaces:
Sleeping
Sleeping
# --- Monkeypatch for Gradio 4.44.0 bug ---
# gradio_client.utils expects every JSON-schema node to be a dict, but a schema
# node can be a bare boolean (e.g. ``"additionalProperties": true``). That raised
# TypeError / APIInfoParseError while building the API info, so both helpers are
# wrapped to map boolean schemas to "Any" and defer everything else unchanged.
import gradio_client.utils as _gc_utils

_orig_get_type = _gc_utils.get_type
_orig_json_schema = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """Return "Any" for a boolean schema; otherwise call the original get_type."""
    return "Any" if isinstance(schema, bool) else _orig_get_type(schema)


def _patched_json_schema(schema, defs=None):
    """Return "Any" for a boolean schema; otherwise call the original converter."""
    if not isinstance(schema, bool):
        return _orig_json_schema(schema, defs)
    return "Any"


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema
# --- End monkeypatch ---
| import gradio as gr | |
| import torch | |
| import pickle | |
| import os, zipfile, io | |
| import tempfile | |
| import numpy as np | |
| import scipy.spatial | |
| import pandas as pd | |
| from rdkit import Chem | |
| from rdkit.Chem import AllChem | |
| from MolecularDiffusion.core import Engine | |
| from MolecularDiffusion.utils import seed_everything | |
# ── GLOBALS ──────────────────────────────────────────────────────
# Checkpoints live in ./model/<name>/ next to this script.
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model")
# Element symbols indexed by one-hot atom type (QM9/GEOM vocabulary used in training).
ATOM_VOCAB = "H B C N O F Al Si P S Cl As Se Br I Hg Bi".split()
# Free Spaces have no GPU, so everything is forced onto the CPU.
DEVICE = torch.device("cpu")
# Currently loaded model task and the model/ folder it came from (None until loaded).
TASK = None
LOADED_MODEL_NAME = None
def discover_models():
    """List model/ subfolders (sorted by name) that contain an edm_chem.pkl checkpoint.

    Returns an empty list when MODEL_DIR does not exist.
    """
    if not os.path.isdir(MODEL_DIR):
        return []
    return [
        entry
        for entry in sorted(os.listdir(MODEL_DIR))
        if os.path.isdir(os.path.join(MODEL_DIR, entry))
        and os.path.exists(os.path.join(MODEL_DIR, entry, "edm_chem.pkl"))
    ]
# Scanned once at import; feeds the model dropdown and the startup load below.
AVAILABLE_MODELS = discover_models()
print("Discovered models: {}".format(AVAILABLE_MODELS))
def get_condition_names(task):
    """Return the task's conditioning property names as a new list.

    Yields [] when there is no task, the task has no ``condition`` attribute,
    or that attribute is None.
    """
    cond = None if task is None else getattr(task, "condition", [])
    return [] if cond is None else list(cond)
def _is_blank_cell(raw):
    """Return True if a table cell holds no value: None, NaN/NA, or blank text."""
    if raw is None:
        return True
    # pd.isna also recognises pd.NA and non-builtin NaNs such as np.float32("nan");
    # is_scalar keeps a stray list/array cell from producing an ambiguous truth value.
    if pd.api.types.is_scalar(raw) and pd.isna(raw):
        return True
    return str(raw).strip() == ""


def _coerce_condition_values(raw_values, expected_names, required):
    """Convert raw cells (aligned one-to-one with ``expected_names``) to floats.

    Raises:
        ValueError: a required value is missing, a value is not numeric, or an
            optional row is only partially filled (a shorter list could no
            longer be matched back to the property names).
    """
    out = []
    missing = []
    for name, raw in zip(expected_names, raw_values):
        if _is_blank_cell(raw):
            if required:
                raise ValueError(f"Target value required for property '{name}'.")
            missing.append(name)
            continue
        try:
            out.append(float(raw))
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"Invalid numeric value for property '{name}': {raw}") from exc
    if out and missing:
        raise ValueError(
            "Values must be given for all properties or none; missing: "
            + ", ".join(missing)
        )
    return out


def parse_condition_row(df_like, expected_names, required=True):
    """Parse per-property conditioning values from a gr.Dataframe payload.

    Two layouts are accepted:
      * Layout A (legacy): a DataFrame with one column per property; row 0 is read.
      * Layout B (current): rows of ``[property, value]`` (DataFrame or list of
        lists); rows with fewer than two cells or a None property are ignored.

    Args:
        df_like: pandas DataFrame, list of rows, or None.
        expected_names: property names, in the order the values are returned.
        required: when True every property needs a value; when False the
            values may all be blank (returns []), but not only some of them.

    Returns:
        list[float] ordered like ``expected_names``, or [] when nothing is
        expected / nothing optional was supplied.

    Raises:
        ValueError: on missing required inputs, non-numeric values, or a
            partially filled optional row.
    """
    if not expected_names:
        return []
    if df_like is None:
        if required:
            raise ValueError("Missing conditioning inputs.")
        return []

    is_frame = isinstance(df_like, pd.DataFrame)
    if is_frame and all(name in df_like.columns for name in expected_names):
        # Layout A (legacy): one row, one column per property.
        first = None if df_like.empty else df_like.iloc[0]
        raw_values = [None if first is None else first[name] for name in expected_names]
    else:
        # Layout B (current): two columns [property, value], one row per property.
        rows = df_like.values.tolist() if is_frame else (df_like or [])
        value_map = {}
        for row in rows:
            if len(row) < 2 or row[0] is None:
                continue
            value_map[str(row[0])] = row[1]
        raw_values = [value_map.get(name) for name in expected_names]

    return _coerce_condition_values(raw_values, expected_names, required)
# ── MODEL LOADING ──────────────────────────────────────────────────
def _attach_dataset_stats(task, model_subdir):
    """Attach the node-count sampler (and property sampler, if present) from edm_stat.pkl."""
    stat_path = os.path.join(model_subdir, "edm_stat.pkl")
    if not os.path.exists(stat_path):
        print("Warning: edm_stat.pkl not found. Size sampling might fail.")
        return
    # NOTE(security): pickle.load can execute arbitrary code. This is acceptable only
    # because the file ships in the Space's own model/ folder and never comes from users.
    with open(stat_path, "rb") as f:
        stats = pickle.load(f)
    task.node_dist_model = stats.get("node")
    if "prop" in stats:
        task.prop_dist_model = stats.get("prop")


def _patch_task_for_cpu(task):
    """Move ``task`` to DEVICE and scrub CUDA leftovers that ``.to()`` does not touch.

    The pickled model may keep a plain ``device`` attribute pointing at CUDA and
    ``_edges_dict`` caches holding CUDA tensors, so both are reset here.
    """
    task.to(DEVICE)
    task.eval()
    print("Patching model for CPU compatibility...")
    # 1. Make task.device report CPU. Best effort: failures are ignored on purpose.
    try:
        if task.device.type != 'cpu':
            print(" Monkeypatching task.device")
            try:
                task.device = DEVICE
            except AttributeError:
                # Read-only property: override it on the class (affects every instance).
                type(task).device = property(lambda self: torch.device("cpu"))
    except Exception:
        pass
    # 2. Clear cached CUDA tensors and overwrite stored (non-property) device attributes.
    for _, module in task.named_modules():
        if hasattr(module, "_edges_dict"):
            module._edges_dict = {}
        if hasattr(module, "device") and not isinstance(
            getattr(type(module), "device", None), property
        ):
            try:
                module.device = DEVICE
            except Exception:
                pass


def load_model(model_name):
    """Load model/<model_name>/ into the module globals TASK / LOADED_MODEL_NAME.

    Reads edm_chem.pkl (required) and edm_stat.pkl (optional), fills in the
    default atom vocabulary if the checkpoint has none, and forces the model
    onto the CPU.

    Returns:
        The loaded task, or None on failure. Errors are printed, not raised,
        and on failure TASK and LOADED_MODEL_NAME are both left as None.
    """
    global TASK, LOADED_MODEL_NAME
    # Drop the previous model first. This lets it be garbage-collected before the
    # new checkpoint is read, and after a failed load callers never keep
    # sampling from the old model under the newly requested name.
    TASK = None
    LOADED_MODEL_NAME = None
    model_subdir = os.path.join(MODEL_DIR, model_name)
    print(f"Loading model '{model_name}' on {DEVICE}...")
    try:
        # Start from an empty Engine and restore it from the pickle checkpoint.
        # interference_mode=True skips optimizer/dataset loading.
        engine = Engine(None, None, None, None, None)
        chkpt_path = os.path.join(model_subdir, "edm_chem.pkl")
        if not os.path.exists(chkpt_path):
            raise FileNotFoundError(f"Checkpoint not found at {chkpt_path}")
        engine = engine.load_from_checkpoint(
            chkpt_path, interference_mode=True
        )
        task = engine.model

        _attach_dataset_stats(task, model_subdir)

        # Set vocab manually if missing
        if not hasattr(task, 'atom_vocab') or task.atom_vocab is None:
            task.atom_vocab = ATOM_VOCAB

        _patch_task_for_cpu(task)

        print(f"Model '{model_name}' loaded successfully!")
        TASK = task
        LOADED_MODEL_NAME = model_name
        return task
    except Exception as e:
        print(f"CRITICAL ERROR loading model '{model_name}': {e}")
        import traceback
        traceback.print_exc()
        return None
# Load the first discovered model at import time so the UI starts ready.
if not AVAILABLE_MODELS:
    print("WARNING: No models found in model/ directory!")
else:
    load_model(AVAILABLE_MODELS[0])
# ── HELPERS ───────────────────────────────────────────────────────
def tensors_to_xyz_string(one_hot_i, x_i, node_mask_i, atom_vocab):
    """Convert one sampled molecule's tensors to an XYZ-format string.

    Args:
        one_hot_i: per-node atom-type scores; the argmax over dim 1 indexes ``atom_vocab``.
        x_i: per-node coordinates; the first three columns are written.
        node_mask_i: per-node mask, shape (N,) or (N, 1); non-zero marks a real atom.
        atom_vocab: element symbols; indices past its end are written as "X".

    Returns:
        XYZ text: the atom count, a comment line, then "<symbol> <x> <y> <z>" per atom.
    """
    atom_types = torch.argmax(one_hot_i, dim=1).tolist()
    coords = x_i.tolist()
    # Use the mask's actual positions instead of assuming the real atoms are a
    # contiguous prefix of the padded tensors.
    keep = torch.nonzero(node_mask_i.reshape(-1) != 0).flatten().tolist()

    lines = [f"{len(keep)}", "Generated by MolecularDiffusion"]
    for k in keep:
        idx = atom_types[k]
        symbol = atom_vocab[idx] if idx < len(atom_vocab) else "X"
        px, py, pz = coords[k][:3]
        lines.append(f"{symbol} {px:.6f} {py:.6f} {pz:.6f}")
    return "\n".join(lines)
def parse_composition(xyz_str):
    """Summarise an XYZ string as {"Atoms": count, "Formula": formula}.

    The formula uses Hill order: when carbon is present, C first, then H, then
    the remaining elements alphabetically; otherwise all elements are listed
    alphabetically. Counts of 1 are omitted. Malformed input (non-integer
    count line, empty atom line) yields {"Atoms": 0, "Formula": "Error"}.
    """
    from collections import Counter

    lines = xyz_str.strip().split('\n')
    try:
        n_atoms = int(lines[0])
        counts = Counter(line.split()[0] for line in lines[2:2 + n_atoms])
    except (ValueError, IndexError):
        return {"Atoms": 0, "Formula": "Error"}

    if "C" in counts:
        order = ["C"] + (["H"] if "H" in counts else [])
        order += sorted(sym for sym in counts if sym not in ("C", "H"))
    else:
        order = sorted(counts)
    formula_str = "".join(f"{sym}{counts[sym] if counts[sym] > 1 else ''}" for sym in order)
    return {"Atoms": n_atoms, "Formula": formula_str}
def build_adjacency_matrix(positions: np.ndarray, atomic_numbers: list, scale: float = 1.25) -> np.ndarray:
    """Return an (n, n) 0/1 bond matrix guessed from 3D positions.

    Atoms i and j are bonded when 0.05 < d_ij <= scale * (r_i + r_j), with r
    taken from ASE's covalent-radius table. The lower bound excludes self-pairs
    and (near-)coincident atoms.

    Args:
        positions: array-like of per-atom coordinates, one row per atom.
        atomic_numbers: atomic number of each atom (same order as ``positions``).
        scale: tolerance factor applied to the summed covalent radii.
    """
    n = len(atomic_numbers)
    if n == 0:
        # cdist rejects the 1-D array np.array([]) produces for an empty molecule.
        return np.zeros((0, 0), dtype=int)
    from ase.data import covalent_radii

    coords = np.asarray(positions, dtype=float).reshape(n, -1)
    radii = np.array([covalent_radii[z] for z in atomic_numbers])
    dist_matrix = scipy.spatial.distance.cdist(coords, coords)
    thresholds = (radii[:, None] + radii[None, :]) * scale
    return ((dist_matrix <= thresholds) & (dist_matrix > 0.05)).astype(int)
def create_xyz_zip(xyz_strings):
    """Bundle XYZ strings into a temporary .zip (molecule_000.xyz, molecule_001.xyz, ...).

    Returns the archive path, which the caller owns, or None when there is
    nothing to write.
    """
    if not xyz_strings:
        return None
    # delete=False keeps the file after the handle closes; only its name is needed.
    with tempfile.NamedTemporaryFile(suffix='.zip', delete=False) as handle:
        archive_path = handle.name
    with zipfile.ZipFile(archive_path, 'w') as archive:
        for index, content in enumerate(xyz_strings):
            archive.writestr(f"molecule_{index:03d}.xyz", content)
    return archive_path
# ── FILE EXPORT ──────────────────────────────────────────────────
def xyz_to_pdb(xyz_str):
    """Convert an XYZ string to minimal, column-aligned PDB text.

    Each parseable atom line becomes a HETATM record (residue UNK 1, occupancy
    1.00, B-factor 0.00) laid out on the fixed PDB v3.3 columns, with
    consecutive serial numbers and a closing END record. Atom lines with fewer
    than four fields or non-numeric coordinates are skipped.

    Returns:
        The PDB text, or "" when there are no atom lines or none could be parsed.
    """
    lines = xyz_str.strip().splitlines()
    if len(lines) < 3:
        return ""
    records = []
    for line in lines[2:]:  # skip the count and comment lines
        parts = line.split()
        if len(parts) < 4:
            continue
        try:
            x, y, z = (float(v) for v in parts[1:4])
        except ValueError:
            continue
        elem = parts[0].upper()
        # Atom name fills columns 13-16; one-letter elements conventionally start at column 14.
        name = f" {elem:<3}" if len(elem) == 1 else f"{elem:<4}"
        serial = len(records) + 1
        # Columns: 1-6 record, 7-11 serial, 13-16 name, 18-20 resName, 23-26 resSeq,
        # 31-54 x/y/z, 55-60 occupancy, 61-66 B-factor, 77-78 element.
        records.append(
            f"HETATM{serial:>5} {name} UNK  {1:>4}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}{'':10}{elem:>2}"
        )
    if not records:
        return ""
    records.append("END")
    return "\n".join(records)
def save_to_format(xyz_str, idx, fmt="pdb"):
    """Write one molecule to a temporary file as PDB or XYZ and return its path.

    Args:
        xyz_str: the molecule in XYZ format.
        idx: zero-based molecule index, used in the file-name prefix (mol_003_...).
        fmt: "pdb" (converted via xyz_to_pdb) or anything else, which writes the
            XYZ text unchanged; case-insensitive and also used as the suffix.

    The file is kept on disk (delete=False) so it can be offered for download;
    the caller owns cleanup.
    """
    fmt = fmt.lower()
    content = xyz_to_pdb(xyz_str) if fmt == "pdb" else xyz_str
    # The context manager closes the handle even if the write raises.
    with tempfile.NamedTemporaryFile(
        mode="w",
        suffix=f".{fmt}",
        prefix=f"mol_{idx:03d}_",
        delete=False,
        encoding="utf-8",
    ) as handle:
        handle.write(content)
    return handle.name
# ── GENERATION FUNCTION ─────────────────────────────────────────────
def _empty_generation_outputs():
    """Return generate()'s 5 outputs for "nothing generated"."""
    return "", gr.update(choices=[], value=None), [], None, None


def _override_diffusion_steps(task, steps):
    """Set the sampler's step count T to ``steps``.

    Returns:
        ``(owner, previous_T)`` so the caller can restore T on the same object,
        or ``(None, None)`` when neither ``task.model.T`` nor ``task.T`` exists.
    """
    # Handle both EGCL-style wrappers (T on task.model) and tasks holding T directly.
    if hasattr(task, 'model') and hasattr(task.model, 'T'):
        owner = task.model
    elif hasattr(task, 'T'):
        owner = task
    else:
        return None, None
    previous_T = owner.T
    owner.T = int(steps)
    return owner, previous_T


def _sample_node_counts(task, size_mode, fixed_size, batch):
    """Return a LongTensor on DEVICE with the atom count of each molecule in a batch."""
    if "Auto" in size_mode:
        # edm_stat.pkl may be missing, so node_dist_model is not guaranteed to exist.
        node_dist = getattr(task, "node_dist_model", None)
        if node_dist is not None:
            counts = node_dist.sample(batch)
        else:
            counts = torch.randint(10, 30, (batch,))
    else:
        counts = torch.tensor([fixed_size] * batch)
    return counts.to(DEVICE).long()


def generate(
    model_name,
    num_molecules,
    num_runs,
    size_mode,
    fixed_size,
    diffusion_steps,
    seed,
    target_values_df,
    negative_values_df,
    cfg_scale,
):
    """Sample molecules with the selected model and build every UI output.

    Runs ``num_runs`` sampling passes of ``num_molecules`` molecules each. It uses
    CFG-guided sampling when the model has conditioning properties and plain
    DDPM sampling otherwise. The model's step count is temporarily set to
    ``diffusion_steps`` and restored afterwards.

    Returns:
        A 5-tuple: (first molecule's XYZ text, selector update, list of all XYZ
        strings, path to a .zip of all XYZ files, summary DataFrame). On any
        failure the empty variant is returned and the error is shown to the user
        as a warning.
    """
    global TASK, LOADED_MODEL_NAME
    # Reload model if selection changed
    if model_name != LOADED_MODEL_NAME:
        print(f"Switching model: {LOADED_MODEL_NAME} -> {model_name}")
        load_model(model_name)
    # A failed switch must never fall back to sampling with a different model.
    if TASK is None or LOADED_MODEL_NAME != model_name:
        return _empty_generation_outputs()
    task = TASK  # pin one model object for the whole call

    total_requested = int(num_molecules) * int(num_runs)
    print(
        f"Generating {total_requested} molecules as {num_runs} run(s) x batch {num_molecules} "
        f"with '{model_name}' (Steps: {diffusion_steps}, Seed: {seed})..."
    )
    batch = int(num_molecules)

    owner, previous_T = None, None
    try:
        seed_everything(int(seed))
        owner, previous_T = _override_diffusion_steps(task, diffusion_steps)

        # Resolve conditioning mode/inputs once for all runs.
        condition_names = get_condition_names(task)
        is_conditional = bool(condition_names)
        if is_conditional:
            target_values = parse_condition_row(
                target_values_df, condition_names, required=True
            )
            negative_values = parse_condition_row(
                negative_values_df, condition_names, required=False
            )

        def zero_target(z, t):
            # Placeholder guidance target; presumably unused since gg_scale=0.0 below — verify.
            return torch.zeros(z.size(0), device=z.device, dtype=z.dtype)

        xyz_strings = []
        for _ in range(int(num_runs)):
            nodesxsample = _sample_node_counts(task, size_mode, fixed_size, batch)
            if is_conditional:
                one_hot, _charges, x, node_mask = task.sample_guidance_conitional(
                    target_function=zero_target,
                    target_value=target_values,
                    negative_target_value=negative_values,
                    nodesxsample=nodesxsample,
                    gg_scale=0.0,
                    cfg_scale=float(cfg_scale),
                    guidance_ver="cfg",
                    n_frames=0,
                    fix_noise=False,
                )
            else:
                one_hot, _charges, x, node_mask = task.sample(
                    nodesxsample=nodesxsample,
                    mode="ddpm",
                    n_frames=0,
                    fix_noise=False,
                )
            xyz_strings.extend(
                tensors_to_xyz_string(one_hot[i], x[i], node_mask[i], task.atom_vocab)
                for i in range(batch)
            )
        if not xyz_strings:
            raise RuntimeError("No molecules were generated.")

        # Bulk download archive plus a summary table with "Name" as the first column.
        zip_path = create_xyz_zip(xyz_strings)
        summary_rows = [
            {"Name": f"Molecule {number}", **parse_composition(xyz)}
            for number, xyz in enumerate(xyz_strings, start=1)
        ]
        df_out = pd.DataFrame(summary_rows)
        choices = [row["Name"] for row in summary_rows]
        return (
            xyz_strings[0],                                 # current_xyz (triggers JS)
            gr.update(choices=choices, value=choices[0]),   # selector
            xyz_strings,                                    # raw_xyz_state
            zip_path,                                       # download_all
            df_out,                                         # table
        )
    except Exception as e:
        import traceback
        traceback.print_exc()
        # Surface the reason (e.g. a missing target value) instead of failing silently.
        gr.Warning(f"Generation failed: {e}")
        return _empty_generation_outputs()
    finally:
        # Restore the original step count on the exact object that was modified.
        if owner is not None:
            owner.T = previous_T
# ── 3Dmol.js JavaScript ──────────────────────────────────────────
# Injected into the page <head>. It loads 3Dmol.js from unpkg and defines the
# page-level helpers (loadMolecule, refreshViewer, clearSelections, ...) that the
# Gradio js= callbacks call. Clicking selects up to 4 atoms and reports the
# distance (2 atoms), angle (3) or dihedral (4, IUPAC sign convention).
# NOTE: the JS deliberately avoids backslashes (String.fromCharCode(10) instead of
# a newline escape) because this is a regular Python string, and uses HTML
# entities for the Angstrom and degree signs so they survive any page encoding.
THREEDMOL_HEAD = """
<script src="https://unpkg.com/3dmol/build/3Dmol-min.js"></script>
<style>
#mol3d-container { width: 100%; height: 480px; position: relative; border: 1px solid #ccc; border-radius: 8px; overflow: hidden; cursor: crosshair; }
#measurement-display { background: #f9f9f9; padding: 10px; border-radius: 8px; border: 1px solid #ddd; font-family: monospace; font-size: 0.9em; min-height: 100px; }
#measurement-status b { color: #2c3e50; }
</style>
<script>
var _viewer = null;
var _currentXYZ = null;
var _selectedAtoms = [];   // Array of {atom, index, sphereID, labelID}
var _measureShapes = [];   // Dashed measurement lines currently drawn

// Global error handler: prepend uncaught JS errors to the debug textbox.
window.onerror = function(msg, url, line, col, error) {
    var el = document.querySelector('#viewer-debug-log textarea');
    if (el) {
        el.value = "JS ERROR: " + msg + String.fromCharCode(10) + el.value;
        el.dispatchEvent(new Event('input', { bubbles: true }));
    }
    return false;
};

function logDebug(msg) {
    console.log(msg);
    var el = document.querySelector('#viewer-debug-log textarea');
    if (el) {
        el.value = msg + String.fromCharCode(10) + el.value;
        el.dispatchEvent(new Event('input', { bubbles: true }));
    }
}

function initViewer() {
    // Poll until the 3Dmol library has loaded.
    if (typeof $3Dmol === 'undefined') {
        setTimeout(initViewer, 200);
        return;
    }
    logDebug("Initializing 3Dmol viewer...");
    var el = document.getElementById('mol3d-container');
    if (!el) { logDebug("ERROR: Viewer container not found!"); return; }
    if (_viewer) {
        try { _viewer.clear(); } catch(e) { logDebug("Error clearing viewer: " + e); }
    } else {
        try {
            _viewer = $3Dmol.createViewer(el, {backgroundColor: '0xffffff'});
            logDebug("Viewer created.");
        } catch(e) {
            logDebug("CRITICAL: Failed to create viewer: " + e);
        }
    }
}

function updateStatus(msg) {
    var el = document.getElementById('measurement-status');
    if (el) { el.innerHTML = msg; }
}

function getAtomName(a) {
    return a.elem + (a.serial || (a.index + 1));
}

// Draw the highlight sphere and name label for one selected atom.
function markAtom(a) {
    var sphereID = _viewer.addSphere({center:{x:a.x, y:a.y, z:a.z}, radius:0.55, color:'#ffd400', opacity:0.6, clickable:false});
    var labelID = _viewer.addLabel(getAtomName(a), {position:{x:a.x, y:a.y, z:a.z}, fontColor:'black', backgroundColor:'yellow', backgroundOpacity:0.6, fontSize:12, borderColor:'#b58900', borderThickness:1});
    return {atom: a, index: a.index, sphereID: sphereID, labelID: labelID};
}

// Remove one selection's highlight sphere and label.
function unmarkAtom(sa) {
    if (sa.sphereID) _viewer.removeShape(sa.sphereID);
    if (sa.labelID) _viewer.removeLabel(sa.labelID);
}

function vecSub(a, b) { return {x: a.x - b.x, y: a.y - b.y, z: a.z - b.z}; }
function vecDot(a, b) { return a.x * b.x + a.y * b.y + a.z * b.z; }
function vecLen(v) { return Math.sqrt(vecDot(v, v)); }
function vecCross(a, b) { return {x: a.y * b.z - a.z * b.y, y: a.z * b.x - a.x * b.z, z: a.x * b.y - a.y * b.x}; }

// Redraw selection markers and measurement lines, then refresh the status panel.
function syncMeasurements() {
    // "Reset Selections" can fire before any molecule (and viewer) exists.
    if (!_viewer) {
        _selectedAtoms = [];
        updateStatus("<b>Selected Atoms:</b> None");
        return;
    }
    try {
        // Remove only our own markers/lines so element labels from "Show atom labels" survive.
        _selectedAtoms.forEach(unmarkAtom);
        _measureShapes.forEach(s => _viewer.removeShape(s));
        _measureShapes = [];
        _selectedAtoms = _selectedAtoms.map(sa => markAtom(sa.atom));
    } catch (e) {
        logDebug("syncMeasurements error: " + e);
    }

    var n = _selectedAtoms.length;
    var names = _selectedAtoms.map(sa => getAtomName(sa.atom)).join(', ');
    var statusText = "<b>Selected Atoms:</b> " + (names || "None");

    // Dashed line between each pair of consecutively picked atoms.
    for (let i = 0; i + 1 < n; i++) {
        let a1 = _selectedAtoms[i].atom, a2 = _selectedAtoms[i + 1].atom;
        _measureShapes.push(_viewer.addLine({start:{x:a1.x, y:a1.y, z:a1.z}, end:{x:a2.x, y:a2.y, z:a2.z}, color:'yellow', dashed:true}));
    }
    if (n == 2) {
        let d = vecLen(vecSub(_selectedAtoms[0].atom, _selectedAtoms[1].atom)).toFixed(3);
        statusText += "<br><b>Distance:</b> " + d + " &Aring;";
    }
    if (n == 3) {
        let a1 = _selectedAtoms[0].atom, a2 = _selectedAtoms[1].atom, a3 = _selectedAtoms[2].atom;
        let v1 = vecSub(a1, a2), v2 = vecSub(a3, a2);
        // Clamp: rounding can push |cos| slightly above 1, which makes acos return NaN.
        let c = Math.max(-1, Math.min(1, vecDot(v1, v2) / (vecLen(v1) * vecLen(v2))));
        statusText += "<br><b>Angle:</b> " + (Math.acos(c) * 180 / Math.PI).toFixed(2) + "&deg;";
    }
    if (n == 4) {
        let p = _selectedAtoms.map(sa => sa.atom);
        let b1 = vecSub(p[1], p[0]), b2 = vecSub(p[2], p[1]), b3 = vecSub(p[3], p[2]);
        let n1 = vecCross(b1, b2), n2 = vecCross(b2, b3);
        // IUPAC sign convention: phi = atan2(|b2| * b1.(b2 x b3), (b1 x b2).(b2 x b3)).
        let chi = (Math.atan2(vecLen(b2) * vecDot(b1, n2), vecDot(n1, n2)) * 180 / Math.PI).toFixed(2);
        statusText += "<br><b>Dihedral:</b> " + chi + "&deg;";
    }
    updateStatus(statusText);
    _viewer.render();
}

function getMouseButton(event) {
    if (!event) return 0;
    if (typeof event.button === 'number') return event.button;
    if (typeof event.which === 'number') {
        if (event.which === 3) return 2;
        if (event.which === 2) return 1;
        if (event.which === 1) return 0;
    }
    return 0;
}

function handleAtomClick(atom, event) {
    var btn = getMouseButton(event);
    console.log("handleAtomClick triggered:", atom ? atom.index : "no atom", btn);
    if (!atom || atom.index === undefined) return;
    if (btn === 2) {
        // Right click: deselect this atom, or clear everything if it was not selected.
        if (event && event.preventDefault) event.preventDefault();
        let idx = _selectedAtoms.findIndex(sa => sa.index === atom.index);
        if (idx !== -1) {
            console.log("Deselecting atom:", atom.index);
            unmarkAtom(_selectedAtoms[idx]);
            _selectedAtoms.splice(idx, 1);
        } else {
            console.log("Right click on unselected atom - clearing all.");
            clearSelections();
        }
    } else {
        // Left click: select (max 4, no duplicates).
        if (_selectedAtoms.length >= 4) { console.log("Max 4 atoms selected."); return; }
        if (_selectedAtoms.some(sa => sa.index === atom.index)) { console.log("Atom already selected."); return; }
        console.log("Selecting atom:", atom.index);
        _selectedAtoms.push(markAtom(atom));
    }
    syncMeasurements();
}

function clearSelections() {
    if (_viewer) { _selectedAtoms.forEach(unmarkAtom); }
    _selectedAtoms = [];
    syncMeasurements();
}

function loadMolecule(xyzStr, style, bg, showLabels, hideH) {
    if (typeof $3Dmol === 'undefined' || !_viewer) {
        initViewer();
        if (!_viewer) {
            setTimeout(() => { loadMolecule(xyzStr, style, bg, showLabels, hideH); }, 200);
            return;
        }
    }
    _currentXYZ = xyzStr;
    _viewer.clear();
    _selectedAtoms = [];
    _measureShapes = [];
    updateStatus("<b>Selected Atoms:</b> None");
    if (!xyzStr || xyzStr.length < 5) return;
    _viewer.addModel(xyzStr, 'xyz');
    applyStyle(style, hideH);
    if (showLabels) {
        _viewer.addPropertyLabels('elem', {}, {fontSize: 12, fontColor: 'black', backgroundOpacity: 0.3, backgroundColor: 'white', alignment: 'center'});
    }
    _viewer.setBackgroundColor(bg);
    // Route atom clicks to our handler via a viewer-level clickable selection.
    try {
        _viewer.setClickable({}, true, function(atom, viewer, event) {
            handleAtomClick(atom, event);
        });
    } catch (e) {
        logDebug("Clickable setup failed: " + e);
    }
    // Disable the browser context menu so right-click can be used for deselection.
    var el = document.getElementById('mol3d-container');
    if (el) {
        el.oncontextmenu = function(e) { e.preventDefault(); return false; };
        // Manual right-click listener (attached once) to clear on empty space.
        if (!el.dataset.listener) {
            el.dataset.listener = "true";
            el.addEventListener('mousedown', function(event) {
                if (getMouseButton(event) === 2) {
                    try {
                        var rect = el.getBoundingClientRect();
                        var x = event.clientX - rect.left;
                        var y = event.clientY - rect.top;
                        var picked = _viewer ? _viewer.pick({x:x, y:y}) : null;
                        if (!picked || !picked.atom) {
                            clearSelections();
                        }
                    } catch (e) {
                        clearSelections();
                    }
                }
            });
        }
    }
    _viewer.zoomTo();
    _viewer.render();
}

function applyStyle(style, hideH) {
    if (!_viewer) return;
    _viewer.setStyle({}, {});
    var sel = hideH ? {elem: ['C','N','O','S','P','F','Cl','Br','I','B','Si','Se','As','Al','Hg','Bi']} : {};
    var common = {
        clickable: true,
        hoverable: true
    };
    if (hideH) { _viewer.setStyle({elem: 'H'}, {}); }
    switch(style) {
        case 'Ball and Stick':
            _viewer.setStyle(sel, {stick: {radius: 0.15, colorscheme: 'Jmol', ...common}, sphere: {scale: 0.25, colorscheme: 'Jmol', ...common}});
            break;
        case 'Licorice':
            _viewer.setStyle(sel, {stick: {radius: 0.3, colorscheme: 'Jmol', ...common}});
            break;
        case 'Sphere':
            _viewer.setStyle(sel, {sphere: {colorscheme: 'Jmol', ...common}});
            break;
        case 'Stick':
            _viewer.setStyle(sel, {stick: {colorscheme: 'Jmol', ...common}});
            break;
        default:
            _viewer.setStyle(sel, {stick: {radius: 0.15, colorscheme: 'Jmol', ...common}, sphere: {scale: 0.25, colorscheme: 'Jmol', ...common}});
    }
    _viewer.render();
}

function refreshViewer(style, bg, showLabels, hideH) {
    if (!_viewer || !_currentXYZ) return;
    loadMolecule(_currentXYZ, style, bg, showLabels, hideH);
}
</script>
"""
# UI background label -> 3Dmol hex colour string (the JS callbacks carry the same table).
BG_COLORS = {
    "White": "0xffffff",
    "Black": "0x000000",
    "Grey": "0x333333",
}
| # ββ GRADIO UI ββββββββββββββββββββββββββββββββββββββββββββββββββ | |
| with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL_HEAD) as demo: | |
| gr.Markdown( | |
| """ | |
| # π§ͺ MolCraftDiffusion | |
| Generate novel 3D molecular structures using the **MolecularDiffusion** framework. | |
| """ | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1, min_width=300): | |
| with gr.Group(): | |
| gr.Markdown("### βοΈ Settings") | |
| model_selector = gr.Dropdown( | |
| choices=AVAILABLE_MODELS, | |
| value=AVAILABLE_MODELS[0] if AVAILABLE_MODELS else None, | |
| label="Model", | |
| interactive=True | |
| ) | |
| initial_condition_names = get_condition_names(TASK) | |
| initial_is_conditional = len(initial_condition_names) > 0 | |
| mode_status = gr.Markdown( | |
| "Mode: **Conditional**" if initial_is_conditional else "Mode: **Unconditional**" | |
| ) | |
| num_mol = gr.Slider(1, 12, value=4, step=1, label="Number of Molecules") | |
| num_runs = gr.Slider(1, 20, value=1, step=1, label="Number of Runs") | |
| size_mode = gr.Radio( | |
| ["Auto (from training data)", "Fixed size"], | |
| value="Auto (from training data)", | |
| label="Molecule Size Strategy" | |
| ) | |
| fixed_size = gr.Slider( | |
| 5, 50, value=15, step=1, | |
| label="Atoms per Molecule", | |
| visible=False | |
| ) | |
| def toggle_fixed_size(choice): | |
| return gr.update(visible=(choice == "Fixed size")) | |
| size_mode.change( | |
| fn=toggle_fixed_size, | |
| inputs=[size_mode], | |
| outputs=[fixed_size] | |
| ) | |
| diffusion_steps = gr.Slider( | |
| 50, 1000, value=300, step=50, | |
| label="Diffusion Steps (Higher = Slower but better)" | |
| ) | |
| seed = gr.Number(value=42, label="Random Seed", precision=0) | |
| with gr.Accordion("π― Conditional CFG Controls", open=True, visible=initial_is_conditional) as cond_controls: | |
| gr.Markdown("Fill target values for all properties. Negative values are optional.") | |
| target_values = gr.Dataframe( | |
| headers=["property", "value"], | |
| value=[[name, 0.0] for name in initial_condition_names] if initial_is_conditional else [["property", None]], | |
| row_count=(len(initial_condition_names), "fixed") if initial_is_conditional else (1, "fixed"), | |
| col_count=(2, "fixed"), | |
| datatype=["str", "number"], | |
| label="Target values (required)", | |
| height=280, | |
| interactive=True, | |
| ) | |
| negative_values = gr.Dataframe( | |
| headers=["property", "value"], | |
| value=[[name, None] for name in initial_condition_names] if initial_is_conditional else [["property", None]], | |
| row_count=(len(initial_condition_names), "fixed") if initial_is_conditional else (1, "fixed"), | |
| col_count=(2, "fixed"), | |
| datatype=["str", "number"], | |
| label="Negative values (optional)", | |
| height=280, | |
| interactive=True, | |
| ) | |
| cfg_scale = gr.Slider( | |
| 0.0, | |
| 10.0, | |
| value=1.0, | |
| step=0.1, | |
| label="CFG scale", | |
| interactive=True, | |
| ) | |
| with gr.Accordion("π¨ Viewer Options", open=True): | |
| mol_style = gr.Radio( | |
| ["Ball and Stick", "Licorice", "Sphere", "Stick"], | |
| value="Ball and Stick", | |
| label="3D Style" | |
| ) | |
| bg_color = gr.Radio( | |
| ["White", "Black", "Grey"], | |
| value="White", | |
| label="Background" | |
| ) | |
| show_labels = gr.Checkbox(label="Show atom labels", value=False) | |
| hide_h = gr.Checkbox(label="Hide hydrogens", value=False) | |
| dl_format = gr.Radio( | |
| ["PDB", "XYZ"], | |
| value="PDB", | |
| label="Download Format" | |
| ) | |
| btn = gr.Button("π Generate Molecules", variant="primary", size="lg") | |
| with gr.Column(scale=2): | |
| gr.Markdown("### 𧬠Generated Molecules") | |
| with gr.Row(): | |
| mol_selector = gr.Dropdown( | |
| label="Select Molecule", choices=[], interactive=True, scale=3 | |
| ) | |
| single_dl = gr.File(label="Download Selection", scale=2) | |
| # 3Dmol.js viewer container | |
| with gr.Row(): | |
| with gr.Column(scale=4): | |
| viewer_html = gr.HTML( | |
| value='<div id="mol3d-container"></div>' | |
| ) | |
| with gr.Column(scale=1): | |
| gr.Markdown("#### π Measurements") | |
| status_md = gr.HTML( | |
| value='<b>Selected Atoms:</b> None', | |
| elem_id="measurement-status" | |
| ) | |
| gr.Markdown( | |
| "*Click atoms to select (max 4).*\n" | |
| "*Right-click empty space to clear all.*" | |
| ) | |
| reset_btn = gr.Button("Reset Selections", size="sm") | |
| debug_log = gr.Textbox(label="Viewer Debug Log", interactive=False, elem_id="viewer-debug-log", lines=3) | |
| with gr.Row(): | |
| preview = gr.Image(label="2D Preview (PNG)", scale=1) | |
| with gr.Row(): | |
| table = gr.Dataframe(label="Properties") | |
| download_all = gr.File(label="Download All (.zip)") | |
| # Hidden states | |
| raw_xyz_state = gr.State([]) # List of raw XYZ strings | |
| current_xyz = gr.Textbox(visible=False) # Current molecule XYZ for JS | |
def mol_to_png(xyz_str):
    """Render a 2D depiction of an XYZ molecule to a temporary PNG and return its path.

    Bonds are guessed from the 3D distances with build_adjacency_matrix
    (scale 1.2), and every bond is drawn as a single bond. Returns None when the
    input has fewer than three lines, or when RDKit cannot build or draw the
    molecule (the error is printed).
    """
    try:
        lines = xyz_str.strip().splitlines()
        if len(lines) < 3:
            return None
        from rdkit.Chem import Draw

        mol = Chem.RWMol()
        pos = []
        for line in lines[2:]:
            parts = line.split()
            if len(parts) >= 4:
                mol.AddAtom(Chem.Atom(parts[0]))
                pos.append([float(parts[1]), float(parts[2]), float(parts[3])])
        atomic_numbers = [a.GetAtomicNum() for a in mol.GetAtoms()]
        adj = build_adjacency_matrix(np.array(pos), atomic_numbers, scale=1.2)
        # Upper triangle only, so each bond is added once (row-major order, i < j).
        for i, j in zip(*np.nonzero(np.triu(adj, k=1))):
            mol.AddBond(int(i), int(j), Chem.BondType.SINGLE)
        rd_mol = mol.GetMol()
        AllChem.Compute2DCoords(rd_mol)
        img = Draw.MolToImage(rd_mol, size=(400, 400))
        # Only reserve a file name here. Closing the handle first avoids leaking it
        # and lets PIL reopen the path (impossible while still open on Windows).
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
            png_path = tmp_img.name
        img.save(png_path)
        return png_path
    except Exception as e:
        print(f"Drawing error: {e}")
        return None
def select_molecule(choice, raw_xyzs, fmt):
    """Resolve a "Molecule N" dropdown label to its download file, PNG preview and XYZ text.

    Returns (download path, PNG path, XYZ string), or (None, None, "") when
    nothing is selected, no molecules exist, or the label does not map to one.
    """
    nothing = (None, None, "")
    if choice is None or not raw_xyzs:
        return nothing
    try:
        position = int(choice.split()[-1]) - 1
    except (ValueError, IndexError):
        return nothing
    if position < 0 or position >= len(raw_xyzs):
        return nothing
    xyz_text = raw_xyzs[position]
    return save_to_format(xyz_text, position, fmt), mol_to_png(xyz_text), xyz_text
# JS to call loadMolecule when current_xyz changes
# Browser-side handler: maps the background name ('White'/'Black'/'Grey') to
# a hex string, falling back to white for anything else. It then calls
# loadMolecule after a 200 ms delay and returns its inputs unchanged.
# NOTE(review): loadMolecule / refreshViewer are not defined in this section;
# presumably they are injected into the page elsewhere -- verify.
LOAD_JS = """
(xyz, style, bg, labels, hideH) => {
var bgHex = {'White':'0xffffff','Black':'0x000000','Grey':'0x333333'}[bg] || '0xffffff';
setTimeout(() => { loadMolecule(xyz, style, bgHex, labels, hideH); }, 200);
return [xyz, style, bg, labels, hideH];
}
"""
# Browser-side re-render with new display options. It uses the same
# background-name mapping as LOAD_JS but calls refreshViewer immediately
# and does not reload the molecule.
REFRESH_JS = """
(style, bg, labels, hideH) => {
var bgHex = {'White':'0xffffff','Black':'0x000000','Grey':'0x333333'}[bg] || '0xffffff';
refreshViewer(style, bgHex, labels, hideH);
return [style, bg, labels, hideH];
}
"""
def on_model_change(model_name):
    """Load the newly selected model and reconfigure the conditioning widgets.

    ``load_model(model_name)`` is called unless ``model_name`` is empty or
    equals ``LOADED_MODEL_NAME``. The conditions the model accepts then come
    from ``get_condition_names(TASK)``.

    Returns:
        A 5-tuple of values/updates for ``(mode_status, cond_controls,
        target_values, negative_values, cfg_scale)``:

        - the mode label;
        - visibility of the conditional controls;
        - fixed-size ``property``/``value`` tables for target and negative
          conditions. There is one row per condition; targets default to
          0.0 and negatives to None. Unconditional models get a single
          placeholder row;
        - a reset of the guidance scale to 1.0.
    """
    # TASK / LOADED_MODEL_NAME are only read here, so no ``global``
    # declaration is needed. NOTE(review): presumably load_model updates
    # them -- confirm.
    if model_name and model_name != LOADED_MODEL_NAME:
        load_model(model_name)
    condition_names = get_condition_names(TASK)
    is_conditional = len(condition_names) > 0

    def _table_update(rows):
        # Shared layout for both condition tables: two fixed columns
        # (str, number) and a fixed row count equal to len(rows).
        return gr.update(
            headers=["property", "value"],
            value=rows,
            row_count=(len(rows), "fixed"),
            col_count=(2, "fixed"),
            datatype=["str", "number"],
        )

    if is_conditional:
        target_update = _table_update([[name, 0.0] for name in condition_names])
        negative_update = _table_update([[name, None] for name in condition_names])
        mode_msg = "Mode: **Conditional**"
    else:
        target_update = _table_update([["property", None]])
        negative_update = _table_update([["property", None]])
        mode_msg = "Mode: **Unconditional**"
    return (
        mode_msg,
        gr.update(visible=is_conditional),
        target_update,
        negative_update,
        gr.update(value=1.0),
    )
# Reconfigure the conditioning UI whenever a different model is chosen.
# The five outputs line up with the 5-tuple returned by on_model_change.
model_selector.change(
    fn=on_model_change,
    inputs=[model_selector],
    outputs=[mode_status, cond_controls, target_values, negative_values, cfg_scale],
)
# When molecule selection changes: update download + preview, then trigger JS
# Step 1 (Python): select_molecule fills the download file, the PNG preview
# and the hidden current_xyz textbox.
# Step 2 (browser only, fn=None): LOAD_JS renders current_xyz with the
# current display options.
mol_selector.change(
    fn=select_molecule,
    inputs=[mol_selector, raw_xyz_state, dl_format],
    outputs=[single_dl, preview, current_xyz],
).then(
    fn=None,
    inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
    outputs=None,
    js=LOAD_JS,
)
# When style/bg/labels/hideH changes: refresh viewer via JS only
# fn=None: no Python callback runs. Whichever control changes, all four
# current display options are passed to REFRESH_JS.
for ctrl in [mol_style, bg_color, show_labels, hide_h]:
    ctrl.change(
        fn=None,
        inputs=[mol_style, bg_color, show_labels, hide_h],
        outputs=None,
        js=REFRESH_JS,
    )
# When format changes: just update the download file
def update_dl(choice, raw_xyzs, fmt):
    """Re-export the currently selected molecule in the newly chosen format.

    Returns:
        The path produced by ``save_to_format`` for the molecule whose
        1-based number is the last token of ``choice``. Returns ``None``
        when there is nothing to export, the choice cannot be resolved to a
        valid index, or export raises ValueError/IndexError.
    """
    if raw_xyzs and choice is not None:
        try:
            position = int(choice.split()[-1]) - 1
            if position in range(len(raw_xyzs)):
                return save_to_format(raw_xyzs[position], position, fmt)
        except (ValueError, IndexError):
            pass
    return None
# Changing the export format only regenerates the single-molecule download.
# The preview and 3D viewer are left as they are.
dl_format.change(
    fn=update_dl,
    inputs=[mol_selector, raw_xyz_state, dl_format],
    outputs=[single_dl],
)
# Reset button
# Browser-only (fn=None): Gradio calls the expression "clearSelections" as a
# JS function. NOTE(review): clearSelections is not defined in this section;
# presumably it is a global injected into the page elsewhere -- verify.
reset_btn.click(fn=None, inputs=[], outputs=[], js="clearSelections")
# Generate button
# Step 1 (Python): generate() takes the model choice, sampling settings and
# condition tables. It fills current_xyz, the molecule selector, the raw XYZ
# list, the zip download and the properties table.
# Step 2 (browser only, fn=None): LOAD_JS renders current_xyz in the viewer.
btn.click(
    fn=generate,
    inputs=[
        model_selector,
        num_mol,
        num_runs,
        size_mode,
        fixed_size,
        diffusion_steps,
        seed,
        target_values,
        negative_values,
        cfg_scale,
    ],
    outputs=[current_xyz, mol_selector, raw_xyz_state, download_all, table]
).then(
    fn=None,
    inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
    outputs=None,
    js=LOAD_JS,
)
# Page footer: separator rule plus project/model attribution.
gr.Markdown("---")
gr.Markdown(
    "Powered by [MolecularDiffusion](https://github.com/pregHosh/MolCraftDiffusion). "
    "Model: EDM Pretrained on QM9/GEOM."
)
# Script entry point. server_name="0.0.0.0" binds to all network interfaces,
# so the app is reachable from other hosts, not just localhost.
if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")