# Commit 89947eb (author: iflp1908sl): now supports repeated generation runs
# with a fixed batch size per run.
# --- Monkeypatch for Gradio 4.44.0 bug ---
# gradio_client.utils expects every JSON-schema node to be a dict, but a schema
# may be a bare boolean (e.g. `additionalProperties: true`). That raises
# TypeError / APIInfoParseError while building the API info. Both helpers below
# short-circuit boolean schemas to "Any" and otherwise delegate unchanged.
import gradio_client.utils as _gc_utils

_orig_get_type = _gc_utils.get_type
_orig_json_schema = _gc_utils._json_schema_to_python_type


def _patched_get_type(schema):
    """Return "Any" for a boolean schema, else defer to gradio_client."""
    return "Any" if isinstance(schema, bool) else _orig_get_type(schema)


def _patched_json_schema(schema, defs=None):
    """Return "Any" for a boolean schema, else defer to gradio_client."""
    return "Any" if isinstance(schema, bool) else _orig_json_schema(schema, defs)


_gc_utils.get_type = _patched_get_type
_gc_utils._json_schema_to_python_type = _patched_json_schema
# --- End monkeypatch ---
import gradio as gr
import torch
import pickle
import os, zipfile, io
import tempfile
import numpy as np
import scipy.spatial
import pandas as pd
from rdkit import Chem
from rdkit.Chem import AllChem
from MolecularDiffusion.core import Engine
from MolecularDiffusion.utils import seed_everything
# ── GLOBALS ────────────────────────────────────────────────────
# Each model lives in its own sub-folder of ./model (next to this file).
MODEL_DIR = os.path.join(os.path.dirname(__file__), "model")
# Standard atom vocab for QM9/GEOM used in training. It is used when a checkpoint
# carries no atom_vocab of its own; list position = one-hot index of the element.
ATOM_VOCAB = "H B C N O F Al Si P S Cl As Se Br I Hg Bi".split()
# Inference always runs on the CPU (free Spaces hardware).
DEVICE = torch.device("cpu")
# The currently loaded task and the model folder it came from. Both are None
# until load_model() succeeds.
TASK, LOADED_MODEL_NAME = None, None
def discover_models(model_dir=None):
    """List model sub-folders that contain an ``edm_chem.pkl`` checkpoint.

    Args:
        model_dir: Directory to scan. Defaults to the module-level
            ``MODEL_DIR`` (the ``model/`` folder next to this file), so
            existing zero-argument callers are unaffected.

    Returns:
        Sub-folder names sorted alphabetically. The list is empty when the
        directory does not exist.
    """
    root = MODEL_DIR if model_dir is None else model_dir
    if not os.path.isdir(root):
        return []
    # isfile() implies the parent is a directory. It also rejects a folder
    # that merely happens to be named edm_chem.pkl.
    return [
        name
        for name in sorted(os.listdir(root))
        if os.path.isfile(os.path.join(root, name, "edm_chem.pkl"))
    ]
# Scan model/ once at import time; the UI dropdown is built from this list.
AVAILABLE_MODELS = discover_models()
print("Discovered models: {}".format(AVAILABLE_MODELS))
def get_condition_names(task):
    """Return the names of the properties the task is conditioned on.

    Args:
        task: The loaded model/task, or None.

    Returns:
        A new list built from ``task.condition``. It is empty when there is no
        task, the attribute is absent, or it is None.
    """
    if task is None:
        return []
    names = getattr(task, "condition", [])
    return [] if names is None else list(names)
def _is_missing_condition_value(raw):
    """Return True when a table cell should be treated as "not filled in".

    Covers None, blank or whitespace-only strings, and every pandas/numpy
    missing marker (NaN of any float width, ``pd.NA``, ``NaT``).
    """
    if raw is None:
        return True
    if isinstance(raw, str):
        return raw.strip() == ""
    try:
        return bool(pd.isna(raw))
    except (TypeError, ValueError):
        # Array-like cell: pd.isna returns an array with ambiguous truthiness.
        # Let float() reject it with a proper message instead.
        return False


def _coerce_condition_values(lookup, expected_names, required):
    """Convert raw cells to floats, in ``expected_names`` order.

    Args:
        lookup: Callable mapping a property name to its raw cell (or None).
        expected_names: Property names, in the order the model expects.
        required: When True, a missing cell raises. Otherwise it is skipped.

    Returns:
        List of floats for the properties that have a value.

    Raises:
        ValueError: A required value is missing, or a value is not numeric.
    """
    values = []
    for name in expected_names:
        raw = lookup(name)
        if _is_missing_condition_value(raw):
            if required:
                raise ValueError(f"Target value required for property '{name}'.")
            continue
        try:
            values.append(float(raw))
        except (TypeError, ValueError, OverflowError) as exc:
            raise ValueError(f"Invalid numeric value for property '{name}': {raw}") from exc
    return values


def parse_condition_row(df_like, expected_names, required=True):
    """Extract conditioning values from a ``gr.Dataframe`` payload.

    Two table layouts are accepted:

    * Layout A (legacy): a DataFrame whose columns include every expected
      property name. Values are read from its first row.
    * Layout B (current): rows of ``[property, value]``, either as a
      DataFrame or as a list of lists.

    Args:
        df_like: DataFrame, list of rows, or None.
        expected_names: Property names, in the order the model expects.
        required: When True, every property must have a numeric value.
            When False, missing cells are silently skipped.

    Returns:
        Floats in ``expected_names`` order, or ``[]`` when nothing is expected.

    Raises:
        ValueError: Inputs are missing while required, or a value is not numeric.
    """
    if not expected_names:
        return []
    if df_like is None:
        if required:
            raise ValueError("Missing conditioning inputs.")
        return []

    if isinstance(df_like, pd.DataFrame):
        # Layout A (legacy): one row, one column per property.
        if all(name in df_like.columns for name in expected_names):
            first_row = None if df_like.empty else df_like.iloc[0]
            return _coerce_condition_values(
                lambda name: None if first_row is None else first_row[name],
                expected_names,
                required,
            )
        rows = df_like.values.tolist()
    else:
        rows = df_like or []

    # Layout B (current): two columns [property, value], one row per property.
    value_map = {}
    for row in rows:
        if len(row) < 2 or row[0] is None:
            continue
        value_map[str(row[0])] = row[1]
    return _coerce_condition_values(value_map.get, expected_names, required)
# ── MODEL LOADING ──────────────────────────────────────────────
def _attach_sampling_stats(task, model_subdir):
    """Attach the size/property distributions from ``edm_stat.pkl`` when present."""
    stat_path = os.path.join(model_subdir, "edm_stat.pkl")
    if not os.path.exists(stat_path):
        print("Warning: edm_stat.pkl not found. Size sampling might fail.")
        return
    # NOTE(security): pickle.load can execute arbitrary code. Only place
    # trusted files in model/.
    with open(stat_path, "rb") as f:
        stats = pickle.load(f)
    task.node_dist_model = stats.get("node")
    if "prop" in stats:
        task.prop_dist_model = stats.get("prop")


def _patch_task_for_cpu(task):
    """Move ``task`` to DEVICE and scrub CUDA leftovers from the unpickled model.

    The pickled model might keep persistent 'cuda' device attributes or cached
    CUDA tensors that ``.to(device)`` does not move, so they are fixed by hand.
    """
    task.to(DEVICE)
    task.eval()
    print("Patching model for CPU compatibility...")
    # 1. Force task.device to CPU. If it is a read-only property, override it
    #    on the class instead. Best effort: any failure is ignored on purpose.
    try:
        if task.device.type != 'cpu':
            print(" Monkeypatching task.device")
            try:
                task.device = DEVICE
            except AttributeError:
                type(task).device = property(lambda self: DEVICE)
    except Exception:
        pass
    # 2. Walk every submodule: clear cached edge tensors and overwrite plain
    #    (non-property) `device` attributes.
    for _name, module in task.named_modules():
        if hasattr(module, "_edges_dict"):
            module._edges_dict = {}
        if hasattr(module, "device") and not isinstance(
            getattr(type(module), "device", None), property
        ):
            try:
                module.device = DEVICE
            except Exception:
                pass


def load_model(model_name):
    """Load ``model/<model_name>/edm_chem.pkl`` and make it the active model.

    On success, sets the module globals ``TASK`` and ``LOADED_MODEL_NAME``.
    On failure, both are reset to None. This stops callers from silently
    sampling from the previously loaded model under the new name.

    Args:
        model_name: Name of a sub-folder of ``MODEL_DIR``.

    Returns:
        The loaded task, or None if loading failed (the error is printed).
    """
    global TASK, LOADED_MODEL_NAME
    print(f"Loading model '{model_name}' on {DEVICE}...")
    try:
        # Inside the try so that a None or invalid name is reported like any
        # other load failure.
        model_subdir = os.path.join(MODEL_DIR, model_name)
        # Initialize empty engine
        engine = Engine(None, None, None, None, None)
        chkpt_path = os.path.join(model_subdir, "edm_chem.pkl")
        if not os.path.exists(chkpt_path):
            raise FileNotFoundError(f"Checkpoint not found at {chkpt_path}")
        # interference_mode=True skips optimizer/dataset loading (kwarg name as
        # used by Engine.load_from_checkpoint).
        engine = engine.load_from_checkpoint(chkpt_path, interference_mode=True)
        task = engine.model
        _attach_sampling_stats(task, model_subdir)
        # Fall back to the default vocab when the checkpoint has none.
        if getattr(task, "atom_vocab", None) is None:
            task.atom_vocab = list(ATOM_VOCAB)
        _patch_task_for_cpu(task)
    except Exception as e:
        print(f"CRITICAL ERROR loading model '{model_name}': {e}")
        import traceback
        traceback.print_exc()
        TASK = None
        LOADED_MODEL_NAME = None
        return None
    print(f"Model '{model_name}' loaded successfully!")
    TASK = task
    LOADED_MODEL_NAME = model_name
    return task
# Load first available model at startup
if not AVAILABLE_MODELS:
    print("WARNING: No models found in model/ directory!")
else:
    # discover_models() returns folder names in sorted order.
    load_model(AVAILABLE_MODELS[0])
# ── HELPERS ────────────────────────────────────────────────────
def tensors_to_xyz_string(one_hot_i, x_i, node_mask_i, atom_vocab):
    """Convert one sampled molecule's tensors into an XYZ-format string.

    Args:
        one_hot_i: Per-atom type scores; argmax over dim 1 gives the index
            into ``atom_vocab``.
        x_i: Per-atom coordinates; the first three columns are written.
        node_mask_i: Per-atom mask of shape (N,) or (N, 1); non-zero marks a
            real atom. Atoms are selected by mask, so they need not be a
            contiguous prefix.
        atom_vocab: Sequence of element symbols. Out-of-range indices are
            written as "X".

    Returns:
        The XYZ text: atom count, a comment line, then one
        ``"<symbol> <x> <y> <z>"`` line per atom (6 decimals).
    """
    keep = node_mask_i.detach().cpu().reshape(-1) != 0
    type_ids = torch.argmax(one_hot_i.detach().cpu(), dim=1)[keep].tolist()
    # Convert to Python floats once instead of per element.
    coords = x_i.detach().cpu()[keep][:, :3].tolist()
    lines = [f"{len(type_ids)}", "Generated by MolecularDiffusion"]
    for type_id, (px, py, pz) in zip(type_ids, coords):
        symbol = atom_vocab[type_id] if type_id < len(atom_vocab) else "X"
        lines.append(f"{symbol} {px:.6f} {py:.6f} {pz:.6f}")
    return "\n".join(lines)
def parse_composition(xyz_str):
    """Summarise an XYZ string as its atom count and chemical formula.

    The formula uses Hill order: with carbon present, C comes first, then H,
    then the other elements alphabetically. Without carbon, all elements
    (including H) are alphabetical. Counts of 1 are omitted.

    Args:
        xyz_str: XYZ text (atom-count line, comment line, atom lines).

    Returns:
        ``{"Atoms": n, "Formula": str}``, or ``{"Atoms": 0, "Formula": "Error"}``
        when the header or an atom line cannot be parsed.
    """
    from collections import Counter

    lines = xyz_str.strip().split("\n")
    try:
        n_atoms = int(lines[0])
        counts = Counter(line.split()[0] for line in lines[2:2 + n_atoms])
    except (ValueError, IndexError):
        return {"Atoms": 0, "Formula": "Error"}

    if "C" in counts:
        order = ["C"] + (["H"] if "H" in counts else [])
        order += sorted(sym for sym in counts if sym not in ("C", "H"))
    else:
        order = sorted(counts)
    formula = "".join(f"{sym}{counts[sym] if counts[sym] > 1 else ''}" for sym in order)
    return {"Atoms": n_atoms, "Formula": formula}
def build_adjacency_matrix(positions: np.ndarray, atomic_numbers: list, scale: float = 1.25) -> np.ndarray:
    """Build a binary adjacency matrix from scaled covalent radii.

    Atoms i and j are bonded when their distance is at most
    ``(r_i + r_j) * scale`` and greater than 0.05. The lower bound excludes
    the diagonal and (near-)coincident atoms.

    Args:
        positions: Coordinates, one row of (x, y, z) per atom.
        atomic_numbers: Atomic number of each atom, in the same order.
        scale: Multiplier applied to the summed covalent radii.

    Returns:
        An (n, n) integer array of 0/1 entries; (0, 0) for an empty molecule.
    """
    n = len(atomic_numbers)
    if n == 0:
        # cdist rejects empty/1-D input, so short-circuit empty molecules.
        return np.zeros((0, 0), dtype=int)
    from ase.data import covalent_radii

    radii = covalent_radii[np.asarray(atomic_numbers, dtype=int)]
    coords = np.asarray(positions, dtype=float).reshape(n, 3)
    dist_matrix = scipy.spatial.distance.cdist(coords, coords)
    thresholds = (radii[:, None] + radii[None, :]) * scale
    return ((dist_matrix <= thresholds) & (dist_matrix > 0.05)).astype(int)
def create_xyz_zip(xyz_strings):
    """Bundle XYZ strings into a deflate-compressed zip archive on disk.

    Args:
        xyz_strings: XYZ contents, written as ``molecule_000.xyz``,
            ``molecule_001.xyz``, ... in list order.

    Returns:
        Path of the persistent temporary ``.zip`` file, or None when the list
        is empty. The caller (Gradio) serves the file; it is not deleted here.
    """
    if not xyz_strings:
        return None
    with tempfile.NamedTemporaryFile(suffix=".zip", delete=False) as tmp_zip:
        with zipfile.ZipFile(tmp_zip, "w", compression=zipfile.ZIP_DEFLATED) as zf:
            for i, xyz_content in enumerate(xyz_strings):
                zf.writestr(f"molecule_{i:03d}.xyz", xyz_content)
    return tmp_zip.name
# ── EXPORT HELPERS (XYZ → PDB, temporary files for download) ──
def xyz_to_pdb(xyz_str):
    """Convert an XYZ string into fixed-column PDB ``HETATM`` records.

    Records follow the PDB v3.3 ATOM/HETATM column layout, so column-based
    readers parse the coordinates correctly. Every atom goes into residue
    ``UNK 1`` with occupancy 1.00 and B-factor 0.00. Atom lines with fewer
    than four fields or non-numeric coordinates are skipped. Serial numbers
    stay consecutive over the atoms actually written.

    Args:
        xyz_str: XYZ text (atom-count line, comment line, atom lines).

    Returns:
        Newline-joined HETATM records, or "" when there are no atom lines.
    """
    lines = xyz_str.strip().splitlines()
    if len(lines) < 3:
        return ""
    records = []
    for line in lines[2:]:  # skip the atom-count and comment lines
        parts = line.split()
        if len(parts) < 4:
            continue
        try:
            x, y, z = (float(v) for v in parts[1:4])
        except ValueError:
            continue
        elem = parts[0].upper()[:2]
        # One-letter elements start at column 14, two-letter ones at column 13.
        name = (elem if len(elem) > 1 else " " + elem)[:4]
        serial = len(records) + 1
        records.append(
            "HETATM"                        # 1-6   record name
            f"{serial:>5}"                  # 7-11  atom serial number
            " "                             # 12
            f"{name:<4}"                    # 13-16 atom name
            " "                             # 17    alternate location
            "UNK"                           # 18-20 residue name
            " "                             # 21
            " "                             # 22    chain identifier
            f"{1:>4}"                       # 23-26 residue sequence number
            " "                             # 27    insertion code
            "   "                           # 28-30
            f"{x:8.3f}{y:8.3f}{z:8.3f}"     # 31-54 orthogonal coordinates
            f"{1.0:6.2f}"                   # 55-60 occupancy
            f"{0.0:6.2f}"                   # 61-66 temperature factor
            "          "                    # 67-76
            f"{elem:>2}"                    # 77-78 element symbol
        )
    return "\n".join(records)
def save_to_format(xyz_str, idx, fmt="pdb"):
    """Write one molecule to a persistent temporary file for download.

    Args:
        xyz_str: The molecule as XYZ text.
        idx: 0-based molecule index, used in the filename prefix ``mol_NNN_``.
        fmt: "pdb" (converted via ``xyz_to_pdb``) or anything else, which is
            written as raw XYZ. Case-insensitive; the lowercased value is also
            the file suffix.

    Returns:
        Path of the written file. It is not deleted here.
    """
    fmt_key = fmt.lower()
    content = xyz_to_pdb(xyz_str) if fmt_key == "pdb" else xyz_str
    # The context manager closes the handle even if the write fails.
    with tempfile.NamedTemporaryFile(
        suffix=f".{fmt_key}",
        prefix=f"mol_{idx:03d}_",
        delete=False,
        mode="w",
        encoding="utf-8",
    ) as tmp:
        tmp.write(content)
    return tmp.name
# ── GENERATION FUNCTION ───────────────────────────────────────
def _empty_generate_outputs():
    """Outputs that reset the UI after a failed or impossible generation."""
    return "", gr.update(choices=[], value=None), [], None, None


def _diffusion_steps_owner(task):
    """Return the object whose ``T`` attribute sets the number of diffusion steps.

    Handles both EGCL-style wrappers (``task.model.T``) and tasks that carry
    ``T`` directly. Returns None when neither exposes ``T``.
    """
    if hasattr(task, "model") and hasattr(task.model, "T"):
        return task.model
    if hasattr(task, "T"):
        return task
    return None


def _sample_node_counts(task, batch_size, size_mode, fixed_size):
    """Draw the atom count of every molecule in one sampling run.

    "Auto" modes sample from ``task.node_dist_model`` when it exists and is
    not None. Otherwise they fall back to uniform counts in [10, 30). Any other
    mode gives every molecule ``fixed_size`` atoms.

    Returns:
        A long tensor of counts on DEVICE.
    """
    if "Auto" in size_mode:
        node_dist = getattr(task, "node_dist_model", None)
        if node_dist is not None:
            counts = node_dist.sample(batch_size)
        else:
            counts = torch.randint(10, 30, (batch_size,))
    else:
        counts = torch.tensor([fixed_size] * batch_size)
    return counts.to(DEVICE).long()


@torch.no_grad()
def generate(
    model_name,
    num_molecules,
    num_runs,
    size_mode,
    fixed_size,
    diffusion_steps,
    seed,
    target_values_df,
    negative_values_df,
    cfg_scale,
):
    """Sample ``num_runs`` batches of ``num_molecules`` molecules for the UI.

    Switches to ``model_name`` if needed. It temporarily overrides the number
    of diffusion steps and restores it afterwards, even on error. Models with
    conditioning properties use CFG-guided sampling with the target/negative
    tables; other models use plain DDPM sampling.

    Returns:
        ``(first_xyz, selector_update, all_xyz, zip_path, summary_df)``,
        matching the Generate button's outputs. On failure, returns the empty
        outputs and shows a ``gr.Warning`` toast with the reason (for example
        a missing required target value). The traceback is printed to the
        server log.
    """
    if model_name != LOADED_MODEL_NAME:
        print(f"Switching model: {LOADED_MODEL_NAME} -> {model_name}")
        load_model(model_name)
    task = TASK  # pin the model for this call
    if task is None:
        gr.Warning(f"Model '{model_name}' is not loaded; see the server log for details.")
        return _empty_generate_outputs()

    batch_size = int(num_molecules)
    run_count = int(num_runs)
    print(
        f"Generating {batch_size * run_count} molecules as {num_runs} run(s) x batch {num_molecules} "
        f"with '{model_name}' (Steps: {diffusion_steps}, Seed: {seed})..."
    )

    steps_owner = None
    original_T = None
    try:
        seed_everything(int(seed))

        # 1. Override diffusion steps (restored in `finally`).
        owner = _diffusion_steps_owner(task)
        if owner is not None:
            original_T = owner.T
            steps_owner = owner
            owner.T = int(diffusion_steps)

        # 2. Resolve conditioning inputs once for all runs.
        condition_names = get_condition_names(task)
        target_values = negative_values = None
        if condition_names:
            target_values = parse_condition_row(
                target_values_df, condition_names, required=True
            )
            negative_values = parse_condition_row(
                negative_values_df, condition_names, required=False
            )

        def target_fn(z, t):
            # Zero guidance target. Presumably inert because gg_scale=0.0 —
            # TODO confirm against sample_guidance_conitional.
            return torch.zeros(z.size(0), device=z.device, dtype=z.dtype)

        # 3. Sample every run and convert each molecule to XYZ.
        xyz_strings = []
        summary_rows = []
        for _ in range(run_count):
            nodesxsample = _sample_node_counts(task, batch_size, size_mode, fixed_size)
            if condition_names:
                one_hot, _charges, x, node_mask = task.sample_guidance_conitional(
                    target_function=target_fn,
                    target_value=target_values,
                    negative_target_value=negative_values,
                    nodesxsample=nodesxsample,
                    gg_scale=0.0,
                    cfg_scale=float(cfg_scale),
                    guidance_ver="cfg",
                    n_frames=0,
                    fix_noise=False,
                )
            else:
                one_hot, _charges, x, node_mask = task.sample(
                    nodesxsample=nodesxsample,
                    mode="ddpm",
                    n_frames=0,
                    fix_noise=False,
                )
            for i in range(batch_size):
                xyz_str = tensors_to_xyz_string(one_hot[i], x[i], node_mask[i], task.atom_vocab)
                xyz_strings.append(xyz_str)
                summary_rows.append(parse_composition(xyz_str))

        if not xyz_strings:
            raise ValueError("No molecules were generated.")

        # 4. Package outputs: zip for bulk download, table with "Name" first.
        zip_path = create_xyz_zip(xyz_strings)
        choices = [f"Molecule {i + 1}" for i in range(len(xyz_strings))]
        df_out = pd.DataFrame(summary_rows)
        df_out.insert(0, "Name", choices)
        return (
            xyz_strings[0],                                # current_xyz (triggers JS)
            gr.update(choices=choices, value=choices[0]),  # selector
            xyz_strings,                                   # raw_xyz_state
            zip_path,                                      # download_all
            df_out,                                        # table
        )
    except Exception as exc:
        import traceback
        traceback.print_exc()
        gr.Warning(f"Generation failed: {exc}")
        return _empty_generate_outputs()
    finally:
        if steps_owner is not None:
            steps_owner.T = original_T
# ── 3Dmol.js JavaScript ───────────────────────────────────────
# Injected into the page <head>: loads 3Dmol.js and defines the viewer,
# selection, and measurement functions called from LOAD_JS / REFRESH_JS and the
# "Reset Selections" button. Inside this Python string, JS newlines are written
# as String.fromCharCode(10) and non-ASCII symbols as HTML entities, so no
# backslash escapes or encoding issues arise.
THREEDMOL_HEAD = """
<script src="https://unpkg.com/3dmol/build/3Dmol-min.js"></script>
<style>
#mol3d-container { width: 100%; height: 480px; position: relative; border: 1px solid #ccc; border-radius: 8px; overflow: hidden; cursor: crosshair; }
#measurement-display { background: #f9f9f9; padding: 10px; border-radius: 8px; border: 1px solid #ddd; font-family: monospace; font-size: 0.9em; min-height: 100px; }
#measurement-status b { color: #2c3e50; }
</style>
<script>
var _viewer = null;
var _currentXYZ = null;
var _selectedAtoms = [];   // Array of {atom, index}, in selection order
var _overlayShapes = [];   // Spheres and dashed lines drawn for the selection
var _overlayLabels = [];   // Name labels drawn for the selection

// Global error handler: prepend uncaught JS errors to the debug textbox.
window.onerror = function(msg, url, line, col, error) {
    var el = document.querySelector('#viewer-debug-log textarea');
    if (el) {
        el.value = "JS ERROR: " + msg + String.fromCharCode(10) + el.value;
        el.dispatchEvent(new Event('input', { bubbles: true }));
    }
    return false;
};

function logDebug(msg) {
    console.log(msg);
    var el = document.querySelector('#viewer-debug-log textarea');
    if (el) {
        el.value = msg + String.fromCharCode(10) + el.value;
        el.dispatchEvent(new Event('input', { bubbles: true }));
    }
}

// Create the viewer once 3Dmol.js and the container exist. It never clears an
// existing viewer and never schedules timers. Returns true when _viewer is usable.
function ensureViewer() {
    if (_viewer) return true;
    if (typeof $3Dmol === 'undefined') return false;
    var el = document.getElementById('mol3d-container');
    if (!el) return false;
    try {
        _viewer = $3Dmol.createViewer(el, {backgroundColor: '0xffffff'});
        logDebug("Viewer created.");
    } catch (e) {
        _viewer = null;
        logDebug("CRITICAL: Failed to create viewer: " + e);
    }
    return !!_viewer;
}

// Poll until 3Dmol.js is loaded, then create the viewer (or clear an existing one).
function initViewer() {
    if (typeof $3Dmol === 'undefined') {
        setTimeout(initViewer, 200);
        return;
    }
    logDebug("Initializing 3Dmol viewer...");
    if (!document.getElementById('mol3d-container')) { logDebug("ERROR: Viewer container not found!"); return; }
    if (_viewer) {
        try { _viewer.clear(); } catch(e) { logDebug("Error clearing viewer: " + e); }
        _selectedAtoms = [];
        _overlayShapes = [];
        _overlayLabels = [];
    } else {
        ensureViewer();
    }
}

function updateStatus(msg) {
    var el = document.getElementById('measurement-status');
    if (el) { el.innerHTML = msg; }
}

function getAtomName(a) {
    return a.elem + (a.serial || (a.index + 1));
}

// Small 3-vector helpers for the measurements below.
function _vsub(a, b) { return {x: a.x - b.x, y: a.y - b.y, z: a.z - b.z}; }
function _vdot(a, b) { return a.x * b.x + a.y * b.y + a.z * b.z; }
function _vcross(a, b) { return {x: a.y * b.z - a.z * b.y, y: a.z * b.x - a.x * b.z, z: a.x * b.y - a.y * b.x}; }
function _vlen(a) { return Math.sqrt(_vdot(a, a)); }

function measureDistance(a1, a2) {
    return _vlen(_vsub(a1, a2));
}

// Angle a1-a2-a3 in degrees. The cosine is clamped to [-1, 1] so rounding on
// (anti)collinear atoms cannot make acos return NaN.
function measureAngle(a1, a2, a3) {
    var v1 = _vsub(a1, a2), v2 = _vsub(a3, a2);
    var cosA = _vdot(v1, v2) / (_vlen(v1) * _vlen(v2));
    return Math.acos(Math.max(-1, Math.min(1, cosA))) * 180 / Math.PI;
}

// Signed torsion p1-p2-p3-p4 in degrees, IUPAC sign convention:
// atan2(|b2| * b1.(b2 x b3), (b1 x b2).(b2 x b3)).
function measureDihedral(p1, p2, p3, p4) {
    var b1 = _vsub(p2, p1), b2 = _vsub(p3, p2), b3 = _vsub(p4, p3);
    var n2 = _vcross(b2, b3);
    var y = _vlen(b2) * _vdot(b1, n2);
    var x = _vdot(_vcross(b1, b2), n2);
    return Math.atan2(y, x) * 180 / Math.PI;
}

// Remove only what we drew for the selection. Element labels added by
// addPropertyLabels stay on screen.
function clearOverlays() {
    if (_viewer) {
        _overlayShapes.forEach(function(s) { try { _viewer.removeShape(s); } catch (e) {} });
        _overlayLabels.forEach(function(l) { try { _viewer.removeLabel(l); } catch (e) {} });
    }
    _overlayShapes = [];
    _overlayLabels = [];
}

// Redraw selection highlights and update the measurement panel.
function syncMeasurements() {
    var atoms = _selectedAtoms.map(function(sa) { return sa.atom; });
    var n = atoms.length;
    if (_viewer) {
        clearOverlays();
        try {
            atoms.forEach(function(a) {
                _overlayShapes.push(_viewer.addSphere({center:{x:a.x, y:a.y, z:a.z}, radius:0.55, color:'#ffd400', opacity:0.6, clickable:false}));
                _overlayLabels.push(_viewer.addLabel(getAtomName(a), {position:{x:a.x, y:a.y, z:a.z}, fontColor:'black', backgroundColor:'yellow', backgroundOpacity:0.6, fontSize:12, borderColor:'#b58900', borderThickness:1}));
            });
            for (var i = 0; i + 1 < n; i++) {
                var p = atoms[i], q = atoms[i + 1];
                _overlayShapes.push(_viewer.addLine({start:{x:p.x, y:p.y, z:p.z}, end:{x:q.x, y:q.y, z:q.z}, color:'yellow', dashed:true}));
            }
        } catch (err) {
            logDebug("syncMeasurements error: " + err);
        }
    }
    var names = atoms.map(function(a) { return getAtomName(a); }).join(', ');
    var statusText = "<b>Selected Atoms:</b> " + (names || "None");
    if (n == 2) statusText += "<br><b>Distance:</b> " + measureDistance(atoms[0], atoms[1]).toFixed(3) + " &Aring;";
    if (n == 3) statusText += "<br><b>Angle:</b> " + measureAngle(atoms[0], atoms[1], atoms[2]).toFixed(2) + "&deg;";
    if (n == 4) statusText += "<br><b>Dihedral:</b> " + measureDihedral(atoms[0], atoms[1], atoms[2], atoms[3]).toFixed(2) + "&deg;";
    updateStatus(statusText);
    if (_viewer) _viewer.render();
}

function getMouseButton(event) {
    if (!event) return 0;
    if (typeof event.button === 'number') return event.button;
    if (typeof event.which === 'number') {
        if (event.which === 3) return 2;
        if (event.which === 2) return 1;
        if (event.which === 1) return 0;
    }
    return 0;
}

// Left click selects (max 4 atoms). Right click deselects the atom, or clears
// everything when the atom is not selected.
function handleAtomClick(atom, event) {
    var btn = getMouseButton(event);
    console.log("handleAtomClick triggered:", atom ? atom.index : "no atom", btn);
    if (!atom || atom.index === undefined) return;
    var idx = _selectedAtoms.findIndex(function(sa) { return sa.index === atom.index; });
    if (btn === 2) {
        if (event && event.preventDefault) event.preventDefault();
        if (idx !== -1) {
            console.log("Deselecting atom:", atom.index);
            _selectedAtoms.splice(idx, 1);
        } else {
            console.log("Right click on unselected atom - clearing all.");
            _selectedAtoms = [];
        }
    } else {
        if (_selectedAtoms.length >= 4) { console.log("Max 4 atoms selected."); return; }
        if (idx !== -1) { console.log("Atom already selected."); return; }
        console.log("Selecting atom:", atom.index);
        _selectedAtoms.push({atom: atom, index: atom.index});
    }
    syncMeasurements();
}

function clearSelections() {
    _selectedAtoms = [];
    syncMeasurements();
}

function loadMolecule(xyzStr, style, bg, showLabels, hideH) {
    // Retry until the library and container are ready. ensureViewer never
    // clears the viewer, so retries cannot wipe an already-loaded molecule.
    if (!ensureViewer()) {
        setTimeout(function() { loadMolecule(xyzStr, style, bg, showLabels, hideH); }, 200);
        return;
    }
    _currentXYZ = xyzStr;
    _viewer.clear();
    _selectedAtoms = [];
    _overlayShapes = [];
    _overlayLabels = [];
    updateStatus("<b>Selected Atoms:</b> None");
    if (!xyzStr || xyzStr.length < 5) return;
    _viewer.addModel(xyzStr, 'xyz');
    applyStyle(style, hideH);
    if (showLabels) {
        _viewer.addPropertyLabels('elem', {}, {fontSize: 12, fontColor: 'black', backgroundOpacity: 0.3, backgroundColor: 'white', alignment: 'center'});
    }
    _viewer.setBackgroundColor(bg);
    // Route atom clicks on the whole model to our handler.
    try {
        _viewer.setClickable({}, true, function(atom, viewer, event) {
            handleAtomClick(atom, event);
        });
    } catch (e) {
        logDebug("Clickable setup failed: " + e);
    }
    // Disable the context menu so right-click can deselect/clear.
    var el = document.getElementById('mol3d-container');
    if (el) {
        el.oncontextmenu = function(e) { e.preventDefault(); return false; };
        // Right-click on empty space clears all selections (attached once).
        if (!el.dataset.listener) {
            el.dataset.listener = "true";
            el.addEventListener('mousedown', function(event) {
                if (getMouseButton(event) === 2) {
                    try {
                        var rect = el.getBoundingClientRect();
                        var x = event.clientX - rect.left;
                        var y = event.clientY - rect.top;
                        var picked = _viewer ? _viewer.pick({x:x, y:y}) : null;
                        if (!picked || !picked.atom) {
                            clearSelections();
                        }
                    } catch (e) {
                        clearSelections();
                    }
                }
            });
        }
    }
    _viewer.zoomTo();
    _viewer.render();
}

function applyStyle(style, hideH) {
    if (!_viewer) return;
    _viewer.setStyle({}, {});
    var sel = hideH ? {elem: ['C','N','O','S','P','F','Cl','Br','I','B','Si','Se','As','Al','Hg','Bi']} : {};
    var common = {
        clickable: true,
        hoverable: true
    };
    if (hideH) { _viewer.setStyle({elem: 'H'}, {}); }
    switch(style) {
        case 'Ball and Stick':
            _viewer.setStyle(sel, {stick: {radius: 0.15, colorscheme: 'Jmol', ...common}, sphere: {scale: 0.25, colorscheme: 'Jmol', ...common}});
            break;
        case 'Licorice':
            _viewer.setStyle(sel, {stick: {radius: 0.3, colorscheme: 'Jmol', ...common}});
            break;
        case 'Sphere':
            _viewer.setStyle(sel, {sphere: {colorscheme: 'Jmol', ...common}});
            break;
        case 'Stick':
            _viewer.setStyle(sel, {stick: {colorscheme: 'Jmol', ...common}});
            break;
        default:
            _viewer.setStyle(sel, {stick: {radius: 0.15, colorscheme: 'Jmol', ...common}, sphere: {scale: 0.25, colorscheme: 'Jmol', ...common}});
    }
    _viewer.render();
}

function refreshViewer(style, bg, showLabels, hideH) {
    if (!_viewer || !_currentXYZ) return;
    loadMolecule(_currentXYZ, style, bg, showLabels, hideH);
}
</script>
"""
# Background color mapping
# Background choice -> 3Dmol hex colour string (the same mapping is inlined in LOAD_JS/REFRESH_JS).
BG_COLORS = dict(White="0xffffff", Black="0x000000", Grey="0x333333")
# ── GRADIO UI ──────────────────────────────────────────────────
def _condition_table_rows(condition_names, fill_value):
    """Build rows for a two-column ``[property, value]`` conditioning table.

    Returns one ``[name, fill_value]`` row per property, or a single
    ``["property", None]`` placeholder row when there are no properties.
    """
    if condition_names:
        return [[name, fill_value] for name in condition_names]
    return [["property", None]]


def _condition_table_update(condition_names, fill_value):
    """Return a ``gr.update`` that reshapes a conditioning table to fixed rows."""
    rows = _condition_table_rows(condition_names, fill_value)
    return gr.update(
        headers=["property", "value"],
        value=rows,
        row_count=(len(rows), "fixed"),
        col_count=(2, "fixed"),
        datatype=["str", "number"],
    )


def toggle_fixed_size(choice):
    """Show the atoms-per-molecule slider only when "Fixed size" is selected."""
    return gr.update(visible=(choice == "Fixed size"))


def mol_to_png(xyz_str):
    """Render a 2D depiction of an XYZ molecule to a temporary PNG file.

    Bonds are guessed with ``build_adjacency_matrix`` (radii scale 1.2) and
    drawn as single bonds. The molecule is never sanitized, because guessed
    bonds may violate valence rules.

    Returns:
        Path to the PNG file, or None when the XYZ cannot be parsed or drawn.
    """
    from rdkit.Chem import Draw

    try:
        lines = xyz_str.strip().splitlines()
        if len(lines) < 3:
            return None
        mol = Chem.RWMol()
        pos = []
        for line in lines[2:]:
            parts = line.split()
            if len(parts) >= 4:
                mol.AddAtom(Chem.Atom(parts[0]))
                pos.append([float(parts[1]), float(parts[2]), float(parts[3])])
        adj = build_adjacency_matrix(
            np.array(pos), [a.GetAtomicNum() for a in mol.GetAtoms()], scale=1.2
        )
        # Upper triangle only, so each bond is added once.
        for i, j in zip(*np.nonzero(np.triu(adj, k=1))):
            mol.AddBond(int(i), int(j), Chem.BondType.SINGLE)
        rd_mol = mol.GetMol()
        # Unsanitized molecules lack valence and ring caches. Initialise them
        # leniently so coordinate generation and drawing do not fail.
        rd_mol.UpdatePropertyCache(strict=False)
        Chem.FastFindRings(rd_mol)
        AllChem.Compute2DCoords(rd_mol)
        img = Draw.MolToImage(rd_mol, size=(400, 400))
        with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as tmp_img:
            img.save(tmp_img, format="PNG")
        return tmp_img.name
    except Exception as e:
        print(f"Drawing error: {e}")
        return None


def _choice_to_index(choice, count):
    """Map a "Molecule N" dropdown label to a 0-based index, or None if invalid."""
    if choice is None:
        return None
    try:
        idx = int(str(choice).split()[-1]) - 1
    except (ValueError, IndexError):
        return None
    return idx if 0 <= idx < count else None


def select_molecule(choice, raw_xyzs, fmt):
    """Return ``(download_path, png_path, xyz)`` for the selected molecule.

    Returns ``(None, None, "")`` when nothing valid is selected.
    """
    if not raw_xyzs:
        return None, None, ""
    idx = _choice_to_index(choice, len(raw_xyzs))
    if idx is None:
        return None, None, ""
    xyz = raw_xyzs[idx]
    return save_to_format(xyz, idx, fmt), mol_to_png(xyz), xyz


def update_dl(choice, raw_xyzs, fmt):
    """Re-export the selected molecule in ``fmt``; None when nothing valid is selected."""
    if not raw_xyzs:
        return None
    idx = _choice_to_index(choice, len(raw_xyzs))
    return None if idx is None else save_to_format(raw_xyzs[idx], idx, fmt)


def on_model_change(model_name):
    """Load the newly selected model and reshape the conditioning controls."""
    if model_name and model_name != LOADED_MODEL_NAME:
        load_model(model_name)
    condition_names = get_condition_names(TASK)
    is_conditional = bool(condition_names)
    return (
        "Mode: **Conditional**" if is_conditional else "Mode: **Unconditional**",
        gr.update(visible=is_conditional),
        _condition_table_update(condition_names, 0.0),
        _condition_table_update(condition_names, None),
        gr.update(value=1.0),
    )


# JS run after current_xyz is set: map the background name to a 3Dmol hex
# string and call loadMolecule after a short delay.
LOAD_JS = """
(xyz, style, bg, labels, hideH) => {
    var bgHex = {'White':'0xffffff','Black':'0x000000','Grey':'0x333333'}[bg] || '0xffffff';
    setTimeout(() => { loadMolecule(xyz, style, bgHex, labels, hideH); }, 200);
    return [xyz, style, bg, labels, hideH];
}
"""
# JS run when a viewer option changes: reload the current molecule with new options.
REFRESH_JS = """
(style, bg, labels, hideH) => {
    var bgHex = {'White':'0xffffff','Black':'0x000000','Grey':'0x333333'}[bg] || '0xffffff';
    refreshViewer(style, bgHex, labels, hideH);
    return [style, bg, labels, hideH];
}
"""

# ── GRADIO LAYOUT AND EVENTS ───────────────────────────────────
with gr.Blocks(title="MolCraftDiffusion", theme=gr.themes.Soft(), head=THREEDMOL_HEAD) as demo:
    gr.Markdown(
        "# 🧪 MolCraftDiffusion\n"
        "Generate novel 3D molecular structures using the **MolecularDiffusion** framework."
    )
    with gr.Row():
        with gr.Column(scale=1, min_width=300):
            with gr.Group():
                gr.Markdown("### ⚙️ Settings")
                model_selector = gr.Dropdown(
                    choices=AVAILABLE_MODELS,
                    value=AVAILABLE_MODELS[0] if AVAILABLE_MODELS else None,
                    label="Model",
                    interactive=True,
                )
                initial_condition_names = get_condition_names(TASK)
                initial_is_conditional = len(initial_condition_names) > 0
                mode_status = gr.Markdown(
                    "Mode: **Conditional**" if initial_is_conditional else "Mode: **Unconditional**"
                )
                num_mol = gr.Slider(1, 12, value=4, step=1, label="Number of Molecules")
                num_runs = gr.Slider(1, 20, value=1, step=1, label="Number of Runs")
                size_mode = gr.Radio(
                    ["Auto (from training data)", "Fixed size"],
                    value="Auto (from training data)",
                    label="Molecule Size Strategy",
                )
                fixed_size = gr.Slider(
                    5, 50, value=15, step=1,
                    label="Atoms per Molecule",
                    visible=False,
                )
                size_mode.change(
                    fn=toggle_fixed_size,
                    inputs=[size_mode],
                    outputs=[fixed_size],
                )
                diffusion_steps = gr.Slider(
                    50, 1000, value=300, step=50,
                    label="Diffusion Steps (Higher = Slower but better)",
                )
                seed = gr.Number(value=42, label="Random Seed", precision=0)
            with gr.Accordion("🎯 Conditional CFG Controls", open=True, visible=initial_is_conditional) as cond_controls:
                gr.Markdown("Fill target values for all properties. Negative values are optional.")
                initial_target_rows = _condition_table_rows(initial_condition_names, 0.0)
                target_values = gr.Dataframe(
                    headers=["property", "value"],
                    value=initial_target_rows,
                    row_count=(len(initial_target_rows), "fixed"),
                    col_count=(2, "fixed"),
                    datatype=["str", "number"],
                    label="Target values (required)",
                    height=280,
                    interactive=True,
                )
                initial_negative_rows = _condition_table_rows(initial_condition_names, None)
                negative_values = gr.Dataframe(
                    headers=["property", "value"],
                    value=initial_negative_rows,
                    row_count=(len(initial_negative_rows), "fixed"),
                    col_count=(2, "fixed"),
                    datatype=["str", "number"],
                    label="Negative values (optional)",
                    height=280,
                    interactive=True,
                )
                cfg_scale = gr.Slider(
                    0.0,
                    10.0,
                    value=1.0,
                    step=0.1,
                    label="CFG scale",
                    interactive=True,
                )
            with gr.Accordion("🎨 Viewer Options", open=True):
                mol_style = gr.Radio(
                    ["Ball and Stick", "Licorice", "Sphere", "Stick"],
                    value="Ball and Stick",
                    label="3D Style",
                )
                bg_color = gr.Radio(
                    ["White", "Black", "Grey"],
                    value="White",
                    label="Background",
                )
                show_labels = gr.Checkbox(label="Show atom labels", value=False)
                hide_h = gr.Checkbox(label="Hide hydrogens", value=False)
                dl_format = gr.Radio(
                    ["PDB", "XYZ"],
                    value="PDB",
                    label="Download Format",
                )
            btn = gr.Button("🚀 Generate Molecules", variant="primary", size="lg")
        with gr.Column(scale=2):
            gr.Markdown("### 🧬 Generated Molecules")
            with gr.Row():
                mol_selector = gr.Dropdown(
                    label="Select Molecule", choices=[], interactive=True, scale=3
                )
                single_dl = gr.File(label="Download Selection", scale=2)
            # 3Dmol.js viewer container (filled by loadMolecule from THREEDMOL_HEAD)
            with gr.Row():
                with gr.Column(scale=4):
                    viewer_html = gr.HTML(
                        value='<div id="mol3d-container"></div>'
                    )
                with gr.Column(scale=1):
                    gr.Markdown("#### 📏 Measurements")
                    status_md = gr.HTML(
                        value='<b>Selected Atoms:</b> None',
                        elem_id="measurement-status",
                    )
                    # Blank line between the two sentences: Markdown would
                    # otherwise join them into one paragraph.
                    gr.Markdown(
                        "*Click atoms to select (max 4).*\n\n"
                        "*Right-click empty space to clear all.*"
                    )
                    reset_btn = gr.Button("Reset Selections", size="sm")
                    debug_log = gr.Textbox(label="Viewer Debug Log", interactive=False, elem_id="viewer-debug-log", lines=3)
            with gr.Row():
                preview = gr.Image(label="2D Preview (PNG)", scale=1)
            with gr.Row():
                table = gr.Dataframe(label="Properties")
                download_all = gr.File(label="Download All (.zip)")

    # Hidden states
    raw_xyz_state = gr.State([])              # List of raw XYZ strings
    current_xyz = gr.Textbox(visible=False)   # Current molecule XYZ for JS

    model_selector.change(
        fn=on_model_change,
        inputs=[model_selector],
        outputs=[mode_status, cond_controls, target_values, negative_values, cfg_scale],
    )
    # Molecule selection changes: update download + preview, then the 3D viewer.
    mol_selector.change(
        fn=select_molecule,
        inputs=[mol_selector, raw_xyz_state, dl_format],
        outputs=[single_dl, preview, current_xyz],
    ).then(
        fn=None,
        inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
        outputs=None,
        js=LOAD_JS,
    )
    # Style/background/labels/hide-H changes: refresh the viewer via JS only.
    for ctrl in [mol_style, bg_color, show_labels, hide_h]:
        ctrl.change(
            fn=None,
            inputs=[mol_style, bg_color, show_labels, hide_h],
            outputs=None,
            js=REFRESH_JS,
        )
    # Format changes: re-export only the single-molecule download.
    dl_format.change(
        fn=update_dl,
        inputs=[mol_selector, raw_xyz_state, dl_format],
        outputs=[single_dl],
    )
    reset_btn.click(fn=None, inputs=[], outputs=[], js="clearSelections")
    # Generate, then refresh preview/download explicitly. When a new run keeps
    # "Molecule 1" selected, the dropdown value does not change, so
    # mol_selector.change would not fire and the preview would be stale.
    btn.click(
        fn=generate,
        inputs=[
            model_selector,
            num_mol,
            num_runs,
            size_mode,
            fixed_size,
            diffusion_steps,
            seed,
            target_values,
            negative_values,
            cfg_scale,
        ],
        outputs=[current_xyz, mol_selector, raw_xyz_state, download_all, table],
    ).then(
        fn=select_molecule,
        inputs=[mol_selector, raw_xyz_state, dl_format],
        outputs=[single_dl, preview, current_xyz],
    ).then(
        fn=None,
        inputs=[current_xyz, mol_style, bg_color, show_labels, hide_h],
        outputs=None,
        js=LOAD_JS,
    )
    gr.Markdown("---")
    gr.Markdown(
        "Powered by [MolecularDiffusion](https://github.com/pregHosh/MolCraftDiffusion). "
        "Model: EDM Pretrained on QM9/GEOM."
    )
if __name__ == "__main__":
    # "0.0.0.0" binds every network interface, so the app is reachable from
    # outside the host or container, not only via localhost.
    host = "0.0.0.0"
    demo.launch(server_name=host)