# --- Monkeypatch for Gradio 4.44.0 bug ---
# Fix: TypeError / APIInfoParseError when additionalProperties is True (bool)
# in gradio_client.utils when processing JSON schemas.
# Installed before `gradio` itself is imported below.
import gradio_client.utils as _gc_utils

_orig_get_type = _gc_utils.get_type


def _patched_get_type(schema):
    """Return "Any" for boolean JSON schemas, else defer to gradio_client."""
    if isinstance(schema, bool):
        return "Any"
    return _orig_get_type(schema)


_gc_utils.get_type = _patched_get_type

_orig_json_schema = _gc_utils._json_schema_to_python_type


def _patched_json_schema(schema, defs=None):
    """Return "Any" for boolean JSON schemas, else defer to gradio_client."""
    if isinstance(schema, bool):
        return "Any"
    return _orig_json_schema(schema, defs)


_gc_utils._json_schema_to_python_type = _patched_json_schema
# --- End monkeypatch ---

import io
import json
import os
import pickle
import tempfile
import traceback
import zipfile
from collections import Counter

import gradio as gr
import numpy as np
import pandas as pd
import scipy.spatial
import torch
from rdkit import Chem
from rdkit.Chem import AllChem

from MolecularDiffusion.core import Engine
from MolecularDiffusion.utils import seed_everything

# ── GLOBALS ────────────────────────────────────────────────────
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model")

# Standard atom vocab for QM9/GEOM used in training
ATOM_VOCAB = ["H", "B", "C", "N", "O", "F", "Al", "Si", "P", "S", "Cl", "As", "Se", "Br", "I", "Hg", "Bi"]

DEVICE = torch.device("cpu")  # Force CPU for free spaces

# Track currently loaded model (written only by load_model).
TASK = None
LOADED_MODEL_NAME = None


def discover_models():
    """Return the sorted names of model/ subfolders containing an edm_chem.pkl."""
    if not os.path.isdir(MODEL_DIR):
        return []
    return [
        name
        for name in sorted(os.listdir(MODEL_DIR))
        if os.path.isfile(os.path.join(MODEL_DIR, name, "edm_chem.pkl"))
    ]


AVAILABLE_MODELS = discover_models()
print(f"Discovered models: {AVAILABLE_MODELS}")


def get_condition_names(task):
    """Return the conditioning property names of *task* as a list.

    Returns [] when no task is loaded or the task has no (or a None)
    ``condition`` attribute. A single string is treated as one property name
    rather than being split into characters.
    """
    if task is None:
        return []
    cond = getattr(task, "condition", None)
    if cond is None:
        return []
    if isinstance(cond, str):
        return [cond]
    return list(cond)


def _is_blank(raw):
    """True for None, NaN/NA scalars and whitespace-only values."""
    if raw is None:
        return True
    if pd.api.types.is_scalar(raw) and pd.isna(raw):
        return True
    return str(raw).strip() == ""


def _coerce_condition_value(name, raw, required):
    """Convert one table cell to float, or return None for an empty optional cell.

    Raises:
        ValueError: if the cell is empty but required, or is not numeric.
    """
    if _is_blank(raw):
        if required:
            raise ValueError(f"Target value required for property '{name}'.")
        return None
    try:
        return float(raw)
    except (TypeError, ValueError, OverflowError) as exc:
        raise ValueError(f"Invalid numeric value for property '{name}': {raw}") from exc


def parse_condition_row(df_like, expected_names, required=True):
    """Parse per-property numeric values from a gr.Dataframe payload.

    Two layouts are accepted:
      * Layout A (legacy): a DataFrame with one column per property; the
        first row holds the values.
      * Layout B (current): rows of ``[property, value]``.

    Args:
        df_like: pandas DataFrame or list of rows (or None).
        expected_names: property names, in the order the model expects.
        required: if True every property must have a value.

    Returns:
        A list of floats aligned with *expected_names*, or [] when there is
        nothing to parse (no names, or an optional table left entirely empty).

    Raises:
        ValueError: on missing required values, non-numeric values, or an
            optional table that is only partially filled.
    """
    if not expected_names:
        return []
    if df_like is None:
        if required:
            raise ValueError("Missing conditioning inputs.")
        return []

    if isinstance(df_like, pd.DataFrame) and all(name in df_like.columns for name in expected_names):
        # Layout A (legacy): one row, one column per property.
        first_row = None if df_like.empty else df_like.iloc[0]
        raw_by_name = {
            name: (None if first_row is None else first_row[name]) for name in expected_names
        }
    else:
        # Layout B (current): two columns [property, value], one row per property.
        rows = df_like.values.tolist() if isinstance(df_like, pd.DataFrame) else (df_like or [])
        raw_by_name = {}
        for row in rows:
            if len(row) < 2 or row[0] is None:
                continue
            raw_by_name[str(row[0])] = row[1]

    values = []
    missing = []
    for name in expected_names:
        value = _coerce_condition_value(name, raw_by_name.get(name), required)
        if value is None:
            missing.append(name)
        else:
            values.append(value)

    # Values are matched to properties by position, so returning only some of
    # them would shift the remaining values onto the wrong properties.
    if values and missing:
        raise ValueError(
            "Provide values for all properties or leave them all empty; "
            f"missing: {', '.join(missing)}."
        )
    return values


# ── MODEL LOADING ──────────────────────────────────────────────
def _attach_dataset_stats(task, model_subdir):
    """Attach size/property distributions from edm_stat.pkl to *task*, if present."""
    stat_path = os.path.join(model_subdir, "edm_stat.pkl")
    if not os.path.exists(stat_path):
        print("Warning: edm_stat.pkl not found. Size sampling might fail.")
        return
    # pickle can execute arbitrary code on load: only ship trusted model folders.
    with open(stat_path, "rb") as f:
        stats = pickle.load(f)
    task.node_dist_model = stats.get("node")
    if "prop" in stats:
        task.prop_dist_model = stats["prop"]


def _patch_for_cpu(task):
    """Best-effort cleanup of CUDA leftovers on an unpickled model.

    The pickled model might have components with persistent 'cuda' device
    attributes or cached CUDA tensors that don't get moved by .to(device),
    so they are cleaned up manually here.
    """
    print("Patching model for CPU compatibility...")

    # 1. Force task.device to CPU if it reports anything else.
    try:
        if task.device.type != "cpu":
            print(" Monkeypatching task.device")
            try:
                task.device = DEVICE
            except AttributeError:
                # Read-only property: override it on the class (affects every
                # instance of this class in the process).
                type(task).device = property(lambda self: torch.device("cpu"))
    except Exception:
        # Deliberately best-effort: the task may not expose `device` at all.
        pass

    # 2. Fix per-module attributes and clear caches.
    for _name, module in task.named_modules():
        # Clear _edges_dict cache (contains CUDA tensors)
        if hasattr(module, "_edges_dict"):
            module._edges_dict = {}
        # Fix a stored 'device' value (leave properties alone).
        if hasattr(module, "device") and not isinstance(getattr(type(module), "device", None), property):
            try:
                module.device = DEVICE
            except Exception:
                pass


def load_model(model_name):
    """Load model/<model_name>/ and make it the active global TASK.

    On success sets TASK and LOADED_MODEL_NAME and returns the task. On
    failure prints the error and traceback, leaves both globals unchanged
    (any previously loaded model stays active) and returns None.
    """
    global TASK, LOADED_MODEL_NAME

    print(f"Loading model '{model_name}' on {DEVICE}...")
    try:
        model_subdir = os.path.join(MODEL_DIR, model_name)
        chkpt_path = os.path.join(model_subdir, "edm_chem.pkl")
        if not os.path.exists(chkpt_path):
            raise FileNotFoundError(f"Checkpoint not found at {chkpt_path}")

        # Initialize empty engine, then load from the pickle checkpoint.
        # interference_mode=True skips optimizer/dataset loading
        engine = Engine(None, None, None, None, None)
        engine = engine.load_from_checkpoint(chkpt_path, interference_mode=True)
        task = engine.model

        _attach_dataset_stats(task, model_subdir)

        # Set vocab manually if missing
        if getattr(task, "atom_vocab", None) is None:
            task.atom_vocab = ATOM_VOCAB

        task.to(DEVICE)
        task.eval()
        _patch_for_cpu(task)

        print(f"Model '{model_name}' loaded successfully!")
        TASK = task
        LOADED_MODEL_NAME = model_name
        return task
    except Exception as e:
        print(f"CRITICAL ERROR loading model '{model_name}': {e}")
        traceback.print_exc()
        return None


# Load first available model at startup
if AVAILABLE_MODELS:
    load_model(AVAILABLE_MODELS[0])
else:
    print("WARNING: No models found in model/ directory!")


# ── HELPERS ────────────────────────────────────────────────────
def tensors_to_xyz_string(one_hot_i, x_i, node_mask_i, atom_vocab):
    """Convert single-molecule tensors to an XYZ string.

    Atom types are the argmax of *one_hot_i*; indices outside *atom_vocab*
    become "X". Only the first ``node_mask_i.sum()`` atoms are written, which
    assumes the mask's active atoms come first.
    """
    atom_indices = torch.argmax(one_hot_i, dim=1)
    n_atoms = int(node_mask_i.squeeze(-1).sum().item())
    lines = [f"{n_atoms}", "Generated by MolecularDiffusion"]
    for k in range(n_atoms):
        idx = atom_indices[k].item()
        symbol = atom_vocab[idx] if idx < len(atom_vocab) else "X"
        px, py, pz = (float(c) for c in x_i[k][:3])
        lines.append(f"{symbol} {px:.6f} {py:.6f} {pz:.6f}")
    return "\n".join(lines)


def parse_composition(xyz_str):
    """Summarise an XYZ string as ``{"Atoms": n, "Formula": formula}``.

    Elements are listed alphabetically (not Hill order); counts of 1 are
    omitted. Malformed input yields ``{"Atoms": 0, "Formula": "Error"}``.
    """
    lines = xyz_str.strip().split("\n")
    try:
        n_atoms = int(lines[0])
        counts = Counter(line.split()[0] for line in lines[2:2 + n_atoms])
    except (ValueError, IndexError):
        return {"Atoms": 0, "Formula": "Error"}
    formula_str = "".join(f"{sym}{n if n > 1 else ''}" for sym, n in sorted(counts.items()))
    return {"Atoms": n_atoms, "Formula": formula_str}


def build_adjacency_matrix(positions: np.ndarray, atomic_numbers: list, scale: float = 1.25) -> np.ndarray:
    """Build a binary adjacency matrix from scaled covalent radii.

    Atoms i and j are bonded when their distance is at most
    ``scale * (r_i + r_j)`` and greater than 0.05 (which excludes an atom
    bonding to itself or to an exact duplicate).
    """
    from ase.data import covalent_radii

    radii = np.array([covalent_radii[z] for z in atomic_numbers])
    dist_matrix = scipy.spatial.distance.cdist(positions, positions)
    thresholds = (radii[:, None] + radii[None, :]) * scale
    return ((dist_matrix <= thresholds) & (dist_matrix > 0.05)).astype(int)


def create_xyz_zip(xyz_strings):
    """Write every XYZ string into a temporary zip; return its path (None if empty)."""
    if not xyz_strings:
        return None
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp_zip:
        zip_path = tmp_zip.name
    with zipfile.ZipFile(zip_path, "w") as zf:
        for i, xyz_content in enumerate(xyz_strings):
            zf.writestr(f"molecule_{i:03d}.xyz", xyz_content)
    return zip_path


def _pdb_atom_name(elem):
    """4-char PDB atom-name field: 1-letter elements start in column 14."""
    return f" {elem:<3}" if len(elem) == 1 else f"{elem:<4}"[:4]


def xyz_to_pdb(xyz_str):
    """Convert an XYZ string to fixed-column PDB HETATM records (plus END).

    Lines with fewer than four fields or non-numeric coordinates are skipped.
    Returns "" when the XYZ has no atom lines.
    """
    lines = xyz_str.strip().splitlines()
    if len(lines) < 3:
        return ""
    records = []
    for line in lines[2:]:  # skip count + comment header
        parts = line.split()
        if len(parts) < 4:
            continue
        elem = parts[0]
        try:
            x, y, z = (float(v) for v in parts[1:4])
        except ValueError:
            continue
        serial = len(records) + 1
        # Columns: 1-6 record, 7-11 serial, 13-16 name, 18-20 resName,
        # 23-26 resSeq, 31-54 xyz, 55-60 occupancy, 61-66 B, 77-78 element.
        records.append(
            f"HETATM{serial:>5} {_pdb_atom_name(elem)} UNK  {1:>4}    "
            f"{x:8.3f}{y:8.3f}{z:8.3f}{1.0:6.2f}{0.0:6.2f}          {elem:>2}"
        )
    if not records:
        return ""
    records.append("END")
    return "\n".join(records)


def save_to_format(xyz_str, idx, fmt="pdb"):
    """Save a molecule as PDB or XYZ (case-insensitive *fmt*) and return the path."""
    fmt = fmt.lower()
    content = xyz_to_pdb(xyz_str) if fmt == "pdb" else xyz_str
    with tempfile.NamedTemporaryFile(
        suffix=f".{fmt}", prefix=f"mol_{idx:03d}_", delete=False, mode="w"
    ) as tmp:
        tmp.write(content)
    return tmp.name


# ── GENERATION FUNCTION ───────────────────────────────────────
def _empty_outputs():
    """Outputs that clear the viewer text, selector, state, zip and table."""
    return "", gr.update(choices=[], value=None), [], None, None


def _diffusion_steps_holder(task):
    """Return the object carrying the diffusion step count `T`, or None.

    Handles both EGCL (T on task.model) and other architectures (T on task).
    """
    inner = getattr(task, "model", None)
    if inner is not None and hasattr(inner, "T"):
        return inner
    if hasattr(task, "T"):
        return task
    return None


def _sample_sizes(task, size_mode, fixed_size, count):
    """Return a long tensor of *count* molecule sizes on DEVICE.

    "Auto" modes sample from the task's node distribution when available,
    falling back to uniform sizes in [10, 30); otherwise every size is
    *fixed_size*.
    """
    if "Auto" in size_mode:
        node_dist = getattr(task, "node_dist_model", None)
        if node_dist is not None:
            sizes = node_dist.sample(count)
        else:
            sizes = torch.randint(10, 30, (count,))
    else:
        sizes = torch.tensor([fixed_size] * count)
    return sizes.to(DEVICE).long()


def _zero_guidance_target(z, t):
    """Placeholder guidance target returning zeros per sample.

    Gradient guidance is disabled (gg_scale=0.0, guidance_ver="cfg") where this
    is used, so presumably it only satisfies the sampler's signature — confirm.
    """
    return torch.zeros(z.size(0), device=z.device, dtype=z.dtype)


@torch.no_grad()
def generate(
    model_name,
    num_molecules,
    num_runs,
    size_mode,
    fixed_size,
    diffusion_steps,
    seed,
    target_values_df,
    negative_values_df,
    cfg_scale,
):
    """Sample molecules and build every output of the Generate button.

    Returns a 5-tuple: (first XYZ string, molecule-selector update, list of
    XYZ strings, zip path, summary DataFrame). If the selected model is not
    loaded or sampling fails, an "empty" tuple is returned instead.

    Raises:
        gr.Error: when the conditioning tables are missing or invalid, so
            the message is shown to the user.
    """
    # Reload model if selection changed
    if model_name != LOADED_MODEL_NAME:
        print(f"Switching model: {LOADED_MODEL_NAME} -> {model_name}")
        load_model(model_name)

    # load_model keeps the previous model on failure; never sample from a
    # model other than the one selected in the UI.
    if TASK is None or LOADED_MODEL_NAME != model_name:
        print(f"Model '{model_name}' is not loaded; skipping generation.")
        return _empty_outputs()

    n_mol = int(num_molecules)
    n_runs = int(num_runs)
    total_requested = n_mol * n_runs
    print(
        f"Generating {total_requested} molecules as {num_runs} run(s) x batch {num_molecules} "
        f"with '{model_name}' (Steps: {diffusion_steps}, Seed: {seed})..."
    )

    # Resolve conditioning inputs up front so user errors surface in the UI.
    condition_names = get_condition_names(TASK)
    is_conditional = bool(condition_names)
    target_values = negative_values = None
    if is_conditional:
        try:
            target_values = parse_condition_row(target_values_df, condition_names, required=True)
            negative_values = parse_condition_row(negative_values_df, condition_names, required=False)
        except ValueError as exc:
            raise gr.Error(str(exc)) from exc

    seed_everything(int(seed))

    # Temporarily override the number of diffusion steps.
    steps_holder = _diffusion_steps_holder(TASK)
    original_T = None
    if steps_holder is not None:
        original_T = steps_holder.T
        steps_holder.T = int(diffusion_steps)

    try:
        xyz_strings = []
        summary_rows = []
        for _ in range(n_runs):
            nodesxsample = _sample_sizes(TASK, size_mode, fixed_size, n_mol)
            if is_conditional:
                one_hot, _charges, x, node_mask = TASK.sample_guidance_conitional(
                    target_function=_zero_guidance_target,
                    target_value=target_values,
                    negative_target_value=negative_values,
                    nodesxsample=nodesxsample,
                    gg_scale=0.0,
                    cfg_scale=float(cfg_scale),
                    guidance_ver="cfg",
                    n_frames=0,
                    fix_noise=False,
                )
            else:
                one_hot, _charges, x, node_mask = TASK.sample(
                    nodesxsample=nodesxsample,
                    mode="ddpm",
                    n_frames=0,
                    fix_noise=False,
                )

            for i in range(n_mol):
                xyz_str = tensors_to_xyz_string(one_hot[i], x[i], node_mask[i], TASK.atom_vocab)
                xyz_strings.append(xyz_str)
                summary_rows.append(parse_composition(xyz_str))

        if not xyz_strings:
            return _empty_outputs()

        # Save zip for bulk download and build the table with "Name" first.
        zip_path = create_xyz_zip(xyz_strings)
        choices = [f"Molecule {i + 1}" for i in range(len(xyz_strings))]
        df_out = pd.DataFrame(
            [{"Name": name, **row} for name, row in zip(choices, summary_rows)]
        )

        return (
            xyz_strings[0],                                 # current_xyz (triggers JS)
            gr.update(choices=choices, value=choices[0]),   # selector
            xyz_strings,                                    # raw_xyz_state
            zip_path,                                       # download_all
            df_out,                                         # table
        )
    except Exception:
        traceback.print_exc()
        return _empty_outputs()
    finally:
        # Restore original T so later runs are not affected.
        if original_T is not None:
            steps_holder.T = original_T


# ── 3Dmol.js JavaScript ───────────────────────────────────────
# NOTE(review): this head snippet is empty in this copy of the file, yet the
# JS below calls global loadMolecule / refreshViewer / clearSelections, which
# presumably must be defined here (plus the 3Dmol.js <script>). TODO restore.
THREEDMOL_HEAD = """ """

# Background color mapping (also injected into the JS handlers below).
BG_COLORS = {"White": "0xffffff", "Black": "0x000000", "Grey": "0x333333"}
_BG_COLORS_JS = json.dumps(BG_COLORS)

# JS to call loadMolecule when current_xyz changes
LOAD_JS = """
(xyz, style, bg, labels, hideH) => {
    var bgHex = __BG_COLORS__[bg] || '0xffffff';
    setTimeout(() => { loadMolecule(xyz, style, bgHex, labels, hideH); }, 200);
    return [xyz, style, bg, labels, hideH];
}
""".replace("__BG_COLORS__", _BG_COLORS_JS)

REFRESH_JS = """
(style, bg, labels, hideH) => {
    var bgHex = __BG_COLORS__[bg] || '0xffffff';
    refreshViewer(style, bgHex, labels, hideH);
    return [style, bg, labels, hideH];
}
""".replace("__BG_COLORS__", _BG_COLORS_JS)


# ── UI CALLBACKS ──────────────────────────────────────────────
def _condition_rows(names, fill):
    """[property, value] rows for *names*; a single placeholder row if none."""
    if names:
        return [[name, fill] for name in names]
    return [["property", None]]


def _condition_table_update(names, fill):
    """gr.update payload resetting a [property, value] table to *names*."""
    rows = _condition_rows(names, fill)
    return gr.update(
        headers=["property", "value"],
        value=rows,
        row_count=(len(rows), "fixed"),
        col_count=(2, "fixed"),
        datatype=["str", "number"],
    )


def _selected_index(choice, raw_xyzs):
    """0-based index from a "Molecule N" label, or None if invalid/out of range."""
    if not raw_xyzs or choice is None:
        return None
    try:
        idx = int(str(choice).split()[-1]) - 1
    except (ValueError, IndexError):
        return None
    return idx if 0 <= idx < len(raw_xyzs) else None


def toggle_fixed_size(choice):
    """Show the atoms-per-molecule slider only in "Fixed size" mode."""
    return gr.update(visible=(choice == "Fixed size"))


def mol_to_png(xyz_str):
    """Render a 2D PNG preview with RDKit; return its path or None on failure.

    Connectivity is guessed from covalent radii and every bond is drawn as
    single, so bond orders in the preview are not meaningful.
    """
    from rdkit.Chem import Draw

    try:
        lines = xyz_str.strip().splitlines()
        if len(lines) < 3:
            return None

        mol = Chem.RWMol()
        positions = []
        for line in lines[2:]:
            parts = line.split()
            if len(parts) < 4:
                continue
            mol.AddAtom(Chem.Atom(parts[0]))
            positions.append([float(p) for p in parts[1:4]])

        adj = build_adjacency_matrix(
            np.array(positions), [a.GetAtomicNum() for a in mol.GetAtoms()], scale=1.2
        )
        # Upper triangle only, so each bond is added once (row-major order).
        for i, j in zip(*np.nonzero(np.triu(adj, k=1))):
            mol.AddBond(int(i), int(j), Chem.BondType.SINGLE)

        rd_mol = mol.GetMol()
        AllChem.Compute2DCoords(rd_mol)
        img = Draw.MolToImage(rd_mol, size=(400, 400))

        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
            png_path = tmp_img.name
        img.save(png_path)
        return png_path
    except Exception as e:
        print(f"Drawing error: {e}")
        return None


def select_molecule(choice, raw_xyzs, fmt):
    """Return (download path, PNG path, XYZ) for the selected molecule."""
    idx = _selected_index(choice, raw_xyzs)
    if idx is None:
        return None, None, ""
    xyz = raw_xyzs[idx]
    return save_to_format(xyz, idx, fmt), mol_to_png(xyz), xyz


def update_dl(choice, raw_xyzs, fmt):
    """Re-export the selected molecule when the download format changes."""
    idx = _selected_index(choice, raw_xyzs)
    return None if idx is None else save_to_format(raw_xyzs[idx], idx, fmt)


def on_model_change(model_name):
    """Load the newly selected model and reset the conditioning controls."""
    if model_name and model_name != LOADED_MODEL_NAME:
        load_model(model_name)

    condition_names = get_condition_names(TASK)
    is_conditional = bool(condition_names)
    mode_msg = "Mode: **Conditional**" if is_conditional else "Mode: **Unconditional**"
    return (
        mode_msg,
        gr.update(visible=is_conditional),
        _condition_table_update(condition_names, 0.0),
        _condition_table_update(condition_names, None),
        gr.update(value=1.0),
    )


# ── GRADIO UI ──────────────────────────────────────────────────
with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL_HEAD) as demo:
    gr.Markdown(
        """
        # 🧪 MolCraftDiffusion
        Generate novel 3D molecular structures using the **MolecularDiffusion** framework.
        """
    )

    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                gr.Markdown("### ⚙️ Settings")

                model_selector = gr.Dropdown(
                    choices=AVAILABLE_MODELS,
                    value=AVAILABLE_MODELS[0] if AVAILABLE_MODELS else None,
                    label="Model",
                    interactive=True,
                )

                initial_condition_names = get_condition_names(TASK)
                initial_is_conditional = bool(initial_condition_names)
                mode_status = gr.Markdown(
                    "Mode: **Conditional**" if initial_is_conditional else "Mode: **Unconditional**"
                )

                num_mol = gr.Slider(1, 12, value=4, step=1, label="Number of Molecules")
                num_runs = gr.Slider(1, 20, value=1, step=1, label="Number of Runs")

                size_mode = gr.Radio(
                    ["Auto (from training data)", "Fixed size"],
                    value="Auto (from training data)",
                    label="Molecule Size Strategy",
                )
                fixed_size = gr.Slider(
                    5, 50, value=15, step=1, label="Atoms per Molecule", visible=False
                )
                size_mode.change(fn=toggle_fixed_size, inputs=[size_mode], outputs=[fixed_size])

                diffusion_steps = gr.Slider(
                    50, 1000, value=300, step=50,
                    label="Diffusion Steps (Higher = Slower but better)",
                )
                seed = gr.Number(value=42, label="Random Seed", precision=0)

            with gr.Accordion(
                "🎯 Conditional CFG Controls", open=True, visible=initial_is_conditional
            ) as cond_controls:
                gr.Markdown("Fill target values for all properties. Negative values are optional.")
                initial_target_rows = _condition_rows(initial_condition_names, 0.0)
                target_values = gr.Dataframe(
                    headers=["property", "value"],
                    value=initial_target_rows,
                    row_count=(len(initial_target_rows), "fixed"),
                    col_count=(2, "fixed"),
                    datatype=["str", "number"],
                    label="Target values (required)",
                    height=280,
                    interactive=True,
                )
                initial_negative_rows = _condition_rows(initial_condition_names, None)
                negative_values = gr.Dataframe(
                    headers=["property", "value"],
                    value=initial_negative_rows,
                    row_count=(len(initial_negative_rows), "fixed"),
                    col_count=(2, "fixed"),
                    datatype=["str", "number"],
                    label="Negative values (optional)",
                    height=280,
                    interactive=True,
                )
                cfg_scale = gr.Slider(
                    0.0, 10.0, value=1.0, step=0.1, label="CFG scale", interactive=True,
                )

            with gr.Accordion("🎨 Viewer Options", open=True):
                mol_style = gr.Radio(
                    ["Ball and Stick", "Licorice", "Sphere", "Stick"],
                    value="Ball and Stick",
                    label="3D Style",
                )
                bg_color = gr.Radio(["White", "Black", "Grey"], value="White", label="Background")
                show_labels = gr.Checkbox(label="Show atom labels", value=False)
                hide_h = gr.Checkbox(label="Hide hydrogens", value=False)
                dl_format = gr.Radio(["PDB", "XYZ"], value="PDB", label="Download Format")

            btn = gr.Button("🚀 Generate Molecules", variant="primary", size="lg")

        with gr.Column(scale=2):
            gr.Markdown("### 🧬 Generated Molecules")

            with gr.Row():
                mol_selector = gr.Dropdown(
                    label="Select Molecule", choices=[], interactive=True, scale=3
                )
                single_dl = gr.File(label="Download Selection", scale=2)

            # 3Dmol.js viewer container
            with gr.Row():
                with gr.Column(scale=4):
                    # NOTE(review): the viewer container markup is empty in this
                    # copy of the file; the JS helpers presumably render into a
                    # specific element that belongs here. TODO restore.
                    viewer_html = gr.HTML(value="")
                with gr.Column(scale=1):
                    gr.Markdown("#### 📏 Measurements")
                    status_md = gr.HTML(
                        value="Selected Atoms: None",
                        elem_id="measurement-status",
                    )
                    gr.Markdown(
                        "*Click atoms to select (max 4).*\n"
                        "*Right-click empty space to clear all.*"
                    )
                    reset_btn = gr.Button("Reset Selections", size="sm")
                    debug_log = gr.Textbox(
                        label="Viewer Debug Log", interactive=False,
                        elem_id="viewer-debug-log", lines=3,
                    )

            with gr.Row():
                preview = gr.Image(label="2D Preview (PNG)", scale=1)

            with gr.Row():
                table = gr.Dataframe(label="Properties")

            download_all = gr.File(label="Download All (.zip)")

    # Hidden states
    raw_xyz_state = gr.State([])              # List of raw XYZ strings
    current_xyz = gr.Textbox(visible=False)   # Current molecule XYZ for JS

    model_selector.change(
        fn=on_model_change,
        inputs=[model_selector],
        outputs=[mode_status, cond_controls, target_values, negative_values, cfg_scale],
    )

    # When molecule selection changes: update download + preview, then trigger JS
    mol_selector.change(
        fn=select_molecule,
        inputs=[mol_selector, raw_xyz_state, dl_format],
        outputs=[single_dl, preview, current_xyz],
    ).then(
        fn=None,
        inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
        outputs=None,
        js=LOAD_JS,
    )

    # When style/bg/labels/hideH changes: refresh viewer via JS only
    for ctrl in [mol_style, bg_color, show_labels, hide_h]:
        ctrl.change(
            fn=None,
            inputs=[mol_style, bg_color, show_labels, hide_h],
            outputs=None,
            js=REFRESH_JS,
        )

    # When format changes: just update the download file
    dl_format.change(
        fn=update_dl,
        inputs=[mol_selector, raw_xyz_state, dl_format],
        outputs=[single_dl],
    )

    # Reset button
    reset_btn.click(fn=None, inputs=[], outputs=[], js="clearSelections")

    # Generate button
    btn.click(
        fn=generate,
        inputs=[
            model_selector,
            num_mol,
            num_runs,
            size_mode,
            fixed_size,
            diffusion_steps,
            seed,
            target_values,
            negative_values,
            cfg_scale,
        ],
        outputs=[current_xyz, mol_selector, raw_xyz_state, download_all, table],
    ).then(
        fn=None,
        inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
        outputs=None,
        js=LOAD_JS,
    )

    gr.Markdown("---")
    gr.Markdown(
        "Powered by [MolecularDiffusion](https://github.com/pregHosh/MolCraftDiffusion). "
        "Model: EDM Pretrained on QM9/GEOM."
    )


if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0")