import base64
import io
import time

import torch

QM9_ATOM_TYPES = ["C", "N", "O", "F"]

STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024  # 10 MB

# Keys that every serialized MultiProx state dict must carry.
REQUIRED_STATE_KEYS = {
    "X", "E", "y", "n_nodes", "dataset_id", "model_type",
    "T", "n", "m", "t", "t_prime", "gibbs_chain_freq",
    "inner_step", "step",
}


# ---------------------------------------------------------------------------
# Model-type helpers
# ---------------------------------------------------------------------------

def _is_discrete(model):
    """Return True when *model* is a ``DiscreteDenoisingDiffusion`` instance."""
    # Imported lazily so this module can be loaded without the model package.
    from diffusion_model_discrete import DiscreteDenoisingDiffusion
    return isinstance(model, DiscreteDenoisingDiffusion)


def _build_node_mask(n_nodes, n_max, model):
    """Build the node mask: bool for discrete models, float32 for continuous.

    Entry ``[b, i]`` is set when ``i < n_nodes[b]``.
    """
    arange = torch.arange(n_max, device=n_nodes.device).unsqueeze(0)
    mask = arange < n_nodes.unsqueeze(1)  # (len(n_nodes), n_max) bool
    return mask if _is_discrete(model) else mask.float()


def _sample_initial_noise(model, n_max, node_mask):
    """Draw the starting noise sample using the sampler that matches *model*."""
    from src.diffusion import diffusion_utils

    if _is_discrete(model):
        return diffusion_utils.sample_discrete_feature_noise(
            limit_dist=model.limit_dist, node_mask=node_mask)

    bs = node_mask.shape[0]
    return diffusion_utils.sample_feature_noise(
        X_size=(bs, n_max, model.Xdim_output),
        E_size=(bs, n_max, n_max, model.Edim_output),
        y_size=(bs, model.ydim_output),
        node_mask=node_mask)


def _denoising_step(model, s_t, t_t, X, E, y, node_mask):
    """One denoising step. Returns (X_soft, E_soft, y_soft, X_int, E_int)."""
    if _is_discrete(model):
        sampled_s, discrete_s = model.sample_p_zs_given_zt(s_t, t_t, X, E, y, node_mask)
        # .type_as(y_t) in the model can cast collapsed ints to float — force back to long
        return sampled_s.X, sampled_s.E, sampled_s.y, discrete_s.X.long(), discrete_s.E.long()

    from src import utils

    z_s = model.sample_p_zs_given_zt(s=s_t, t=t_t, X_t=X, E_t=E, y_t=y, node_mask=node_mask)
    unnorm = utils.unnormalize(
        z_s.X, z_s.E, z_s.y, model.norm_values, model.norm_biases,
        node_mask, collapse=True)
    return z_s.X, z_s.E, z_s.y, unnorm.X, unnorm.E


def _gibbs_aggregate(model, X):
    """Reduce the chain axis (dim 1): median for discrete models, mean otherwise."""
    if _is_discrete(model):
        return torch.median(X, dim=1).values
    return torch.mean(X, dim=1)


def _aggregate_chains(model, X, E, y):
    """Aggregate X, E and y over the chain axis; y is cast to float first."""
    return (
        _gibbs_aggregate(model, X),
        _gibbs_aggregate(model, E),
        _gibbs_aggregate(model, y.float()),
    )


def _collapse_final(model, X, E, y, node_mask):
    """Returns (X_int, E_int) integer tensors.

    Symmetrize E first: MultiProx aggregation (mean / median over multiple
    chains) can introduce ULP-level asymmetry that survives into pred_E and
    breaks the model's strict ``assert (pred_E == pred_E.T).all()`` on some
    BLAS / vectorization stacks (notably the Linux ``+cu118`` torch wheel
    inside the deployment container, while the same code runs fine on the
    Windows wheel in dev). Symmetrizing here is a no-op when the input is
    already symmetric and a 1-line invariant fix when it isn't.
    """
    E = (E + E.transpose(1, 2)) / 2
    if _is_discrete(model):
        from src.utils import PlaceHolder
        final = PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
        return final.X.long(), final.E.long()

    final = model.sample_discrete_graph_given_z0(X, E, y, node_mask)
    return final.X, final.E


def _resolve_node_count(model, num_nodes):
    """Return ``(device, n_nodes, n_max)`` for a single-graph batch.

    When *num_nodes* is None the count is sampled from ``model.node_dist``;
    otherwise it is wrapped in a length-1 long tensor on the model's device.
    """
    device = next(model.parameters()).device
    if num_nodes is None:
        n_nodes = model.node_dist.sample_n(1, device)
    else:
        n_nodes = torch.tensor([num_nodes], dtype=torch.long, device=device)
    return device, n_nodes, n_nodes.item()


def _elapsed_ms(t0):
    """Whole milliseconds elapsed since *t0* (a ``time.perf_counter()`` reading)."""
    return int((time.perf_counter() - t0) * 1000)


# ---------------------------------------------------------------------------
# Main inference generators — yield progress dicts, then a result dict
# ---------------------------------------------------------------------------

def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id):
    """Run the plain reverse-diffusion chain for one graph.

    Generator. Yields one ``{"type": "progress", "phase": "denoise", ...}``
    dict per diffusion step (with a ``"preview"`` data-URL on frame steps),
    then a final ``{"type": "result", ...}`` dict holding the rendered image,
    the chain GIF (or None when no frames were captured) and the elapsed time.

    Args:
        model: diffusion model; must expose ``parameters()`` and the sampling
            methods used by the helpers above.
        num_nodes: node count, or None to sample it from ``model.node_dist``.
        diffusion_steps: number of reverse steps to run.
        chain_frames: approximate number of preview frames wanted. A falsy
            value (0 / None) no longer divides by zero; only the first and
            last steps are then captured.
        dataset_id: forwarded to ``render_graph`` to pick the renderer.
    """
    device, n_nodes, n_max = _resolve_node_count(model, num_nodes)
    node_mask = _build_node_mask(n_nodes, n_max, model)

    z_T = _sample_initial_noise(model, n_max, node_mask)
    X, E, y = z_T.X, z_T.E, z_T.y

    if chain_frames:
        frame_interval = max(1, diffusion_steps // chain_frames)
    else:
        # Guard against ZeroDivisionError: fall back to first + last frame only.
        frame_interval = max(1, diffusion_steps)

    gif_frames = []
    t0 = time.perf_counter()

    with torch.no_grad():
        for s_idx in reversed(range(diffusion_steps)):
            s_t = (s_idx / diffusion_steps) * torch.ones((1, 1), device=device)
            t_t = ((s_idx + 1) / diffusion_steps) * torch.ones((1, 1), device=device)
            X, E, y, X_int, E_int = _denoising_step(model, s_t, t_t, X, E, y, node_mask)

            step = diffusion_steps - 1 - s_idx
            # Always capture the last step so the GIF ends on the final state.
            is_frame = step % frame_interval == 0 or s_idx == 0

            event = {
                "type": "progress",
                "phase": "denoise",
                "step": step + 1,
                "total_steps": diffusion_steps,
            }
            if is_frame:
                frame_img = render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id)
                gif_frames.append(frame_img)
            event["elapsed_ms"] = _elapsed_ms(t0)
            if is_frame:
                event["preview"] = _pil_to_b64(frame_img)
            yield event

        X_final, E_final = _collapse_final(model, X, E, y, node_mask)

    image_b64 = _pil_to_b64(
        render_graph(X_final[0, :n_max], E_final[0, :n_max, :n_max], dataset_id))
    elapsed_ms = _elapsed_ms(t0)

    yield {
        "type": "result",
        "image": image_b64,
        "chain_gif": _frames_to_gif_b64(gif_frames),
        "inference_time_ms": elapsed_ms,
    }


def run_multiprox_init(model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id):
    """Initialise a MultiProx run with *m* independent noise samples.

    Generator. Yields ``"noise_init"`` progress dicts roughly every tenth of
    *m* (and on the last sample), then a ``"result"`` dict containing the
    fresh state dict (tensors moved to CPU, counters at 0), a rendered image
    of the aggregated noise and the elapsed time.

    ``state["model_type"]`` is left as None here; per the original comment it
    is filled in by the registry.
    """
    _device, n_nodes, n_max = _resolve_node_count(model, num_nodes)
    node_mask = _build_node_mask(n_nodes, n_max, model)

    t0 = time.perf_counter()
    report_every = max(1, m // 10)

    z_samples = []
    for i in range(m):
        z_samples.append(_sample_initial_noise(model, n_max, node_mask))
        if (i + 1) % report_every == 0 or i == m - 1:
            yield {
                "type": "progress",
                "phase": "noise_init",
                "step": i + 1,
                "total_steps": m,
                "elapsed_ms": _elapsed_ms(t0),
            }

    X = torch.stack([z.X for z in z_samples], dim=1)  # (1, M, n_max, Xdim)
    E = torch.stack([z.E for z in z_samples], dim=1)
    y = torch.stack([z.y for z in z_samples], dim=1)

    # _collapse_final may invoke the model; no gradients are needed here.
    with torch.no_grad():
        agg_X, agg_E, agg_y = _aggregate_chains(model, X, E, y)
        X_int, E_int = _collapse_final(model, agg_X, agg_E, agg_y, node_mask)

    image_b64 = _pil_to_b64(
        render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
    elapsed_ms = _elapsed_ms(t0)

    state = {
        "X": X.cpu(),
        "E": E.cpu(),
        "y": y.cpu(),
        "n_nodes": n_nodes.cpu(),
        "dataset_id": dataset_id,
        "model_type": None,  # filled by registry
        "T": model.T,
        "n": n,
        "m": m,
        "t": t,
        "t_prime": t_prime,
        "gibbs_chain_freq": gibbs_chain_freq,
        "inner_step": 0,
        "step": 0,
    }

    yield {
        "type": "result",
        "state": state,
        "image": image_b64,
        "inference_time_ms": elapsed_ms,
    }


def run_multiprox_step(model, state_dict, dataset_id):
    """Advance a MultiProx run by up to ``gibbs_chain_freq`` Gibbs updates.

    Generator. For each Gibbs update it aggregates the chains, denoises once
    at the fixed time *t*, re-noises the result and writes it into chain slot
    ``k``; a ``"gibbs"`` progress dict with a preview is yielded per update.
    It then always runs a refinement pass of ``int((t - t_prime) * T) + 1``
    denoising steps on the aggregated state (yielding ``"refine"`` progress
    dicts) and finally yields a ``"result"`` dict with the updated state,
    the rendered image, ``round_complete`` and ``done`` flags.

    Args:
        model: diffusion model; ``model.gibbs_fixed_t_2`` is temporarily
            overridden around each ``apply_noise`` call and always restored.
        state_dict: state as produced by ``run_multiprox_init`` or a previous
            call (see ``REQUIRED_STATE_KEYS``).
        dataset_id: forwarded to ``render_graph``.
    """
    device = next(model.parameters()).device

    X = state_dict["X"].to(device)
    E = state_dict["E"].to(device)
    y = state_dict["y"].to(device)
    n_nodes = state_dict["n_nodes"].to(device)

    T = state_dict["T"]
    n = state_dict["n"]
    m = state_dict["m"]
    t = state_dict["t"]
    t_prime = state_dict["t_prime"]
    gibbs_chain_freq = state_dict["gibbs_chain_freq"]
    inner_step = state_dict["inner_step"]
    step = state_dict["step"]

    n_max = X.shape[2]
    node_mask = _build_node_mask(n_nodes, n_max, model)

    fixed_t = t * torch.ones((1, 1), dtype=torch.float, device=device)
    fixed_s = fixed_t - (1.0 / T)

    # Never run past the end of the current round of m chain slots.
    steps_this_call = min(gibbs_chain_freq, m - inner_step)
    t0 = time.perf_counter()

    with torch.no_grad():
        for i in range(steps_this_call):
            k = inner_step + i
            avg_X, avg_E, avg_y = _aggregate_chains(model, X, E, y)

            denoised_X, denoised_E, denoised_y, _, _ = _denoising_step(
                model, fixed_s, fixed_t, avg_X, avg_E, avg_y, node_mask)

            old_t2 = model.gibbs_fixed_t_2
            model.gibbs_fixed_t_2 = t  # safe: _inference_lock held by registry
            try:
                noisy = model.apply_noise(
                    denoised_X, denoised_E, denoised_y, node_mask, gibbs=True)
            finally:
                # Restore even if apply_noise raises so the shared model is
                # never left with a stale override.
                model.gibbs_fixed_t_2 = old_t2

            X[:, k] = noisy["X_t"]
            E[:, k] = noisy["E_t"]
            y[:, k] = noisy["y_t"]

            # Preview: aggregate + collapse current Gibbs state
            prev_X, prev_E, prev_y = _aggregate_chains(model, X, E, y)
            prev_Xi, prev_Ei = _collapse_final(model, prev_X, prev_E, prev_y, node_mask)

            yield {
                "type": "progress",
                "phase": "gibbs",
                "step": i + 1,
                "total_steps": steps_this_call,
                "elapsed_ms": _elapsed_ms(t0),
                "preview": _pil_to_b64(
                    render_graph(prev_Xi[0, :n_max], prev_Ei[0, :n_max, :n_max], dataset_id)),
            }

        new_inner_step = inner_step + steps_this_call
        round_complete = new_inner_step >= m
        if round_complete:
            new_inner_step = 0
            new_step = step + 1
        else:
            new_step = step
        done = round_complete and new_step >= n

        # Refinement pass — always produce a clean render
        P = int((t - t_prime) * T) + 1
        refine_preview_interval = max(1, P // 10)

        cur_X, cur_E, cur_y = _aggregate_chains(model, X, E, y)

        for j in range(P):
            s_ref = (t - (j + 1) / T) * torch.ones((1, 1), dtype=torch.float, device=device)
            t_ref = (t - j / T) * torch.ones((1, 1), dtype=torch.float, device=device)
            cur_X, cur_E, cur_y, cur_Xi, cur_Ei = _denoising_step(
                model, s_ref, t_ref, cur_X, cur_E, cur_y, node_mask)

            is_frame = (j + 1) % refine_preview_interval == 0 or j == P - 1
            event = {
                "type": "progress",
                "phase": "refine",
                "step": j + 1,
                "total_steps": P,
                "elapsed_ms": _elapsed_ms(t0),
            }
            if is_frame:
                event["preview"] = _pil_to_b64(
                    render_graph(cur_Xi[0, :n_max], cur_Ei[0, :n_max, :n_max], dataset_id))
            yield event

        X_int, E_int = _collapse_final(model, cur_X, cur_E, cur_y, node_mask)

    image_b64 = _pil_to_b64(
        render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
    elapsed_ms = _elapsed_ms(t0)

    updated_state = {
        **state_dict,
        "X": X.cpu(),
        "E": E.cpu(),
        "y": y.cpu(),
        "step": new_step,
        "inner_step": new_inner_step,
    }

    yield {
        "type": "result",
        "state": updated_state,
        "image": image_b64,
        "round_complete": round_complete,
        "done": done,
        "inference_time_ms": elapsed_ms,
    }


# ---------------------------------------------------------------------------
# State blob serialisation
# ---------------------------------------------------------------------------

def encode_state_blob(state_dict):
    """Serialize *state_dict* with ``torch.save`` and return it as base64 text."""
    buf = io.BytesIO()
    torch.save(state_dict, buf)
    return base64.b64encode(buf.getvalue()).decode("ascii")


def decode_state_blob(b64_str):
    """Decode and validate a blob produced by ``encode_state_blob``.

    Returns:
        The deserialized state dict.

    Raises:
        ValueError: if the input is not valid base64, exceeds
            ``STATE_BLOB_MAX_BYTES``, cannot be deserialized, is not a dict,
            lacks any of ``REQUIRED_STATE_KEYS``, or has an ``X`` / ``E``
            entry that is not a 4-D / 5-D tensor respectively.
    """
    try:
        raw = base64.b64decode(b64_str)
    except Exception as exc:
        raise ValueError("state is not valid base64") from exc

    if len(raw) > STATE_BLOB_MAX_BYTES:
        raise ValueError(f"state blob exceeds {STATE_BLOB_MAX_BYTES // (1024 * 1024)} MB limit")

    # SECURITY NOTE(review): weights_only=False makes torch.load run the full
    # pickle machinery, which can execute arbitrary code. If this blob can be
    # supplied or tampered with by a client, switch to weights_only=True (the
    # state built in this module holds only tensors and plain Python values —
    # confirm for "T" and "model_type") or sign the blob server-side.
    try:
        state = torch.load(io.BytesIO(raw), weights_only=False)
    except Exception as exc:
        raise ValueError(f"state could not be deserialized: {exc}") from exc

    # A non-dict payload used to surface as AttributeError on .keys().
    if not isinstance(state, dict):
        raise ValueError("state must be a dict")

    missing = REQUIRED_STATE_KEYS - set(state.keys())
    if missing:
        raise ValueError(f"state missing keys: {missing}")

    if not isinstance(state["X"], torch.Tensor) or state["X"].dim() != 4:
        raise ValueError("state['X'] must be a 4-D tensor")
    if not isinstance(state["E"], torch.Tensor) or state["E"].dim() != 5:
        raise ValueError("state['E'] must be a 5-D tensor")

    return state
# ---------------------------------------------------------------------------
# Visualisation
# ---------------------------------------------------------------------------

def render_graph(X_int, E_int, dataset_id):
    """Render a single graph to PIL Image. X_int/E_int are 1-D / 2-D integer tensors.

    ``dataset_id == "qm9"`` selects the molecule renderer; every other value
    falls through to the community-graph renderer.
    """
    if dataset_id == "qm9":
        return _render_qm9(X_int, E_int)
    return _render_comm20(X_int, E_int)


def _render_qm9(X_int, E_int):
    """Draw a QM9 molecule as a 300x300 image via RDKit.

    Atoms with a negative type index are skipped; an index beyond
    ``QM9_ATOM_TYPES`` is drawn as carbon. Only the upper triangle of
    *E_int* is read, and only bond indices 1-4 create a bond.
    """
    from rdkit import Chem
    from rdkit.Chem import Draw
    from rdkit.Chem.rdchem import BondType

    bond_map = {
        1: BondType.SINGLE,
        2: BondType.DOUBLE,
        3: BondType.TRIPLE,
        4: BondType.AROMATIC,
    }

    mol = Chem.RWMol()
    atom_types = X_int.cpu().tolist()
    valid_atoms = [i for i, a in enumerate(atom_types) if a >= 0]

    # Tensor index -> RDKit atom index (they differ once atoms are skipped).
    idx_map = {}
    for i in valid_atoms:
        type_idx = atom_types[i]
        atom_sym = QM9_ATOM_TYPES[type_idx] if type_idx < len(QM9_ATOM_TYPES) else "C"
        idx_map[i] = mol.AddAtom(Chem.Atom(atom_sym))

    bonds = E_int.cpu().tolist()
    # valid_atoms is ascending, so slicing past `pos` visits each i < j pair once.
    for pos, i in enumerate(valid_atoms):
        for j in valid_atoms[pos + 1:]:
            bond_type_idx = bonds[i][j]
            if bond_type_idx > 0 and bond_type_idx in bond_map:
                mol.AddBond(idx_map[i], idx_map[j], bond_map[bond_type_idx])

    try:
        img = Draw.MolToImage(mol.GetMol(), size=(300, 300))
    except Exception:
        # If RDKit can't sanitize, draw the raw RWMol
        img = Draw.MolToImage(mol, size=(300, 300))
    return img


_COMM20_GREEN_LOW = (212, 237, 218)
_COMM20_GREEN_MID = (82, 180, 120)
_COMM20_GREEN_HIGH = (22, 80, 50)


def _comm20_green_rgb(t):
    """Map [-1, 1] → RGB via a green palette (pale mint → vivid green → deep forest).

    Values outside [-1, 1] are clamped. Negative values interpolate LOW→MID,
    non-negative values interpolate MID→HIGH; channels are truncated to int.
    """
    t = max(-1.0, min(1.0, float(t)))
    if t < 0:
        w = t + 1.0
        a, b = _COMM20_GREEN_LOW, _COMM20_GREEN_MID
    else:
        w = t
        a, b = _COMM20_GREEN_MID, _COMM20_GREEN_HIGH
    return (
        int(a[0] + (b[0] - a[0]) * w),
        int(a[1] + (b[1] - a[1]) * w),
        int(a[2] + (b[2] - a[2]) * w),
    )


def _render_comm20(X_int, E_int):
    """Render a community graph as an undirected spring-layout plot.

    Mirrors MultiProxAn's ``visualize_non_molecule``: largest connected
    component, spring_layout, normalized-Laplacian eigenvector for node
    colouring (swapped to a site-green palette), no labels, grey edges.

    Uses pure PIL + networkx and ``torch.linalg.eigh`` to avoid the
    matplotlib / numpy MKL DLL conflicts on Windows.

    *X_int* is accepted for signature parity with ``_render_qm9`` but is not
    read. Returns a 720x720 RGB image (blank white when there are no nodes).
    """
    import networkx as nx
    from PIL import Image, ImageDraw

    adjacency = E_int.cpu().tolist()
    n = len(adjacency)

    G = nx.Graph()
    G.add_nodes_from(range(n))
    for i in range(n):
        for j in range(i + 1, n):
            if adjacency[i][j] > 0:
                G.add_edge(i, j)

    # Largest connected component only (matches visualize_non_molecule(largest_component=True)).
    components = sorted(nx.connected_components(G), key=len, reverse=True)
    graph = G.subgraph(components[0]).copy() if components else G

    size = 720
    img = Image.new("RGB", (size, size), "white")
    draw = ImageDraw.Draw(img)

    if graph.number_of_nodes() == 0:
        return img

    # Fixed seed keeps the layout reproducible between calls.
    pos = nx.spring_layout(graph, iterations=100, seed=42)

    # Normalized Laplacian eigenvector for node colouring (torch avoids numpy MKL DLL clash).
    laplacian = nx.normalized_laplacian_matrix(graph).toarray()
    laplacian_t = torch.from_numpy(laplacian).to(torch.float64)
    _, U_t = torch.linalg.eigh(laplacian_t)
    U = U_t.numpy()
    # Use the second eigenvector column when there is one, else the only column.
    eigen_dim = 1 if U.shape[1] > 1 else 0
    vec = U[:, eigen_dim]
    # The 1e-9 floor avoids dividing by zero for an all-zero vector.
    m_abs = max(abs(vec.min()), abs(vec.max()), 1e-9)
    node_color = {
        node: _comm20_green_rgb(vec[row] / m_abs)
        for row, node in enumerate(graph.nodes())
    }

    margin = 60
    scale = (size - 2 * margin) / 2
    cx, cy = size / 2, size / 2
    # Image y grows downward, hence the sign flip on the layout's y coordinate.
    pixel_pos = {k: (cx + v[0] * scale, cy - v[1] * scale) for k, v in pos.items()}

    for i, j in graph.edges():
        draw.line([pixel_pos[i], pixel_pos[j]], fill="#9a9a9a", width=2)

    node_r = 14
    for k, (x, y) in pixel_pos.items():
        r, g, b = node_color[k]
        draw.ellipse(
            [x - node_r, y - node_r, x + node_r, y + node_r],
            fill=(r, g, b), outline="#333333", width=2)

    return img


def _pil_to_b64(img):
    """Encode *img* as PNG and return it as a ``data:image/png;base64,`` URL."""
    buf = io.BytesIO()
    img.save(buf, format="PNG")
    return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")


def _frames_to_gif_b64(frames):
    """Pack *frames* into a looping GIF data URL (150 ms per frame).

    Returns None when *frames* is empty.
    """
    if not frames:
        return None
    buf = io.BytesIO()
    frames[0].save(
        buf,
        format="GIF",
        save_all=True,
        append_images=frames[1:],
        duration=150,
        loop=0,
    )
    return "data:image/gif;base64," + base64.b64encode(buf.getvalue()).decode("ascii")