| import base64 |
| import io |
| import time |
|
|
| import torch |
|
|
# Atom symbols for QM9; an atom-class index stored in X selects the symbol.
QM9_ATOM_TYPES = ["C", "N", "O", "F"]

# Upper bound (in decoded bytes) on a client-supplied MultiProx state blob.
STATE_BLOB_MAX_BYTES = 10 << 20  # 10 MiB

# Every key a serialized MultiProx state must carry to be accepted.
REQUIRED_STATE_KEYS = set(
    (
        "X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
        "n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step",
    )
)
|
|
|
|
| |
| |
| |
|
|
def _is_discrete(model):
    """Return True when *model* is a ``DiscreteDenoisingDiffusion`` instance.

    The import is deferred to call time (presumably to avoid an import-time
    dependency or cycle on the model module).
    """
    from diffusion_model_discrete import DiscreteDenoisingDiffusion as _DiscreteModel

    return isinstance(model, _DiscreteModel)
|
|
|
|
def _build_node_mask(n_nodes, n_max, model):
    """Mark which of the first ``n_max`` node slots are real for each graph.

    Entry ``[b, i]`` is set when ``i < n_nodes[b]`` (assumes ``n_nodes`` is a
    1-D count tensor). Discrete models get the ``torch.bool`` mask as-is;
    continuous models get it as float32.
    """
    positions = torch.arange(n_max, device=n_nodes.device)
    valid = positions[None, :] < n_nodes[:, None]
    if _is_discrete(model):
        return valid
    return valid.float()
|
|
|
|
def _sample_initial_noise(model, n_max, node_mask):
    """Draw the starting noise z_T for reverse diffusion.

    Discrete models sample from ``model.limit_dist``. Continuous models sample
    Gaussian-style feature noise with shapes
    ``(bs, n_max, Xdim_output)``, ``(bs, n_max, n_max, Edim_output)`` and
    ``(bs, ydim_output)``, where ``bs`` is ``node_mask.shape[0]``.
    """
    from src.diffusion import diffusion_utils

    if not _is_discrete(model):
        batch = node_mask.shape[0]
        return diffusion_utils.sample_feature_noise(
            X_size=(batch, n_max, model.Xdim_output),
            E_size=(batch, n_max, n_max, model.Edim_output),
            y_size=(batch, model.ydim_output),
            node_mask=node_mask,
        )
    return diffusion_utils.sample_discrete_feature_noise(
        limit_dist=model.limit_dist,
        node_mask=node_mask,
    )
|
|
|
|
def _denoising_step(model, s_t, t_t, X, E, y, node_mask):
    """Take one reverse-diffusion step from time ``t_t`` to ``s_t``.

    Returns a 5-tuple ``(X_soft, E_soft, y_soft, X_int, E_int)``.
    The first three are the model-space state to feed into the next step.
    The last two are a collapsed version of that state for rendering:
    - discrete models: the sampled discrete state cast to ``long``;
    - continuous models: the output of ``utils.unnormalize(..., collapse=True)``.
    """
    if not _is_discrete(model):
        from src import utils

        z_s = model.sample_p_zs_given_zt(
            s=s_t, t=t_t, X_t=X, E_t=E, y_t=y, node_mask=node_mask)
        collapsed = utils.unnormalize(
            z_s.X, z_s.E, z_s.y,
            model.norm_values, model.norm_biases, node_mask, collapse=True)
        return z_s.X, z_s.E, z_s.y, collapsed.X, collapsed.E

    soft, hard = model.sample_p_zs_given_zt(s_t, t_t, X, E, y, node_mask)
    return soft.X, soft.E, soft.y, hard.X.long(), hard.E.long()
|
|
|
|
def _gibbs_aggregate(model, X):
    """Reduce the chain axis (dim 1) of *X* to a single consensus tensor.

    Continuous models average the chains. Discrete models take the
    element-wise median (``torch.median`` returns the lower middle value
    for an even count).
    """
    if not _is_discrete(model):
        return X.mean(dim=1)
    return X.median(dim=1).values
|
|
|
|
def _collapse_final(model, X, E, y, node_mask):
    """Turn a model-space state into integer node / edge tensors.

    E is first made exactly symmetric across its two node axes (1 and 2).
    MultiProx aggregation (mean / median over chains) can leave ULP-level
    asymmetry in E. That asymmetry trips the model's strict
    ``assert (pred_E == pred_E.T).all()`` on some BLAS / vectorization stacks
    (observed with the Linux ``+cu118`` torch wheel in the deployment
    container but not with the Windows wheel used in dev). Averaging E with
    its transpose changes nothing for an already symmetric E. Because
    float addition is commutative, the result is bit-for-bit symmetric.

    Discrete models collapse via ``PlaceHolder.mask(collapse=True)`` and are
    cast to ``long``. Continuous models defer to
    ``model.sample_discrete_graph_given_z0``.
    """
    E = (E + E.transpose(1, 2)) / 2
    if not _is_discrete(model):
        collapsed = model.sample_discrete_graph_given_z0(X, E, y, node_mask)
        return collapsed.X, collapsed.E

    from src.utils import PlaceHolder

    collapsed = PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
    return collapsed.X.long(), collapsed.E.long()
|
|
|
|
| |
| |
| |
|
|
def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id):
    """Sample one graph with the full reverse-diffusion chain, streaming progress.

    Args:
        model: loaded diffusion model. The device of its first parameter is
            used for every tensor created here.
        num_nodes: node count to generate, or ``None`` to draw one from
            ``model.node_dist``.
        diffusion_steps: number of reverse steps. Step ``s_idx`` denoises
            from ``(s_idx + 1) / diffusion_steps`` to ``s_idx / diffusion_steps``.
        chain_frames: roughly how many frames to capture for the chain GIF.
            Values below 1 are treated as 1, so 0 no longer raises
            ``ZeroDivisionError``.
        dataset_id: forwarded to ``render_graph`` to pick the renderer.

    Yields:
        One ``{"type": "progress", "phase": "denoise", ...}`` event per step.
        Steps that capture a GIF frame also carry a ``"preview"`` data URI.
        The last event is a single ``{"type": "result", "image", "chain_gif",
        "inference_time_ms"}``.
    """
    device = next(model.parameters()).device
    if num_nodes is None:
        n_nodes = model.node_dist.sample_n(1, device)
    else:
        n_nodes = torch.tensor([num_nodes], dtype=torch.long, device=device)
    n_max = n_nodes.item()
    node_mask = _build_node_mask(n_nodes, n_max, model)

    z_T = _sample_initial_noise(model, n_max, node_mask)
    X, E, y = z_T.X, z_T.E, z_T.y

    # max(1, ...) on chain_frames: 0 would otherwise divide by zero.
    frame_interval = max(1, diffusion_steps // max(1, chain_frames))
    gif_frames = []

    t0 = time.time()
    with torch.no_grad():
        for s_idx in reversed(range(diffusion_steps)):
            s_t = (s_idx / diffusion_steps) * torch.ones((1, 1), device=device)
            t_t = ((s_idx + 1) / diffusion_steps) * torch.ones((1, 1), device=device)
            X, E, y, X_int, E_int = _denoising_step(model, s_t, t_t, X, E, y, node_mask)
            step = diffusion_steps - 1 - s_idx
            # Capture every frame_interval-th step, plus always the last one.
            is_frame = step % frame_interval == 0 or s_idx == 0
            if is_frame:
                frame_img = render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id)
                gif_frames.append(frame_img)
            event = {
                "type": "progress",
                "phase": "denoise",
                "step": step + 1,
                "total_steps": diffusion_steps,
                "elapsed_ms": int((time.time() - t0) * 1000),
            }
            if is_frame:
                event["preview"] = _pil_to_b64(frame_img)
            yield event

        # Keep the final collapse under no_grad as well. For continuous models
        # it calls model.sample_discrete_graph_given_z0, which is opaque here
        # and may run the network.
        X_final, E_final = _collapse_final(model, X, E, y, node_mask)

    image_b64 = _pil_to_b64(render_graph(X_final[0, :n_max], E_final[0, :n_max, :n_max], dataset_id))
    elapsed_ms = int((time.time() - t0) * 1000)
    yield {
        "type": "result",
        "image": image_b64,
        "chain_gif": _frames_to_gif_b64(gif_frames),
        "inference_time_ms": elapsed_ms,
    }
|
|
|
|
def run_multiprox_init(model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id):
    """Start a MultiProx session by drawing ``m`` independent noise chains.

    Args:
        model: loaded diffusion model. The device of its first parameter is used.
        num_nodes: node count, or ``None`` to draw one from ``model.node_dist``.
        n, m, t, t_prime, gibbs_chain_freq: MultiProx hyper-parameters. They
            are stored in the returned state for ``run_multiprox_step``.
            ``m`` is the number of chains.
        dataset_id: forwarded to ``render_graph`` and stored in the state.

    Yields:
        ``{"type": "progress", "phase": "noise_init", ...}`` events. These come
        roughly every ``m // 10`` chains and always after the last one.
        Then one ``{"type": "result", "state", "image", "inference_time_ms"}``.
        In ``state``, chains are stacked on dim 1 of ``X`` / ``E`` / ``y`` and
        all tensors are moved to CPU.

    Raises:
        ValueError: if ``m < 1``. With no chains there is nothing to stack or
            aggregate; previously this failed inside ``torch.stack``.
    """
    if m < 1:
        raise ValueError("m must be at least 1")

    device = next(model.parameters()).device
    if num_nodes is None:
        n_nodes = model.node_dist.sample_n(1, device)
    else:
        n_nodes = torch.tensor([num_nodes], dtype=torch.long, device=device)
    n_max = n_nodes.item()
    node_mask = _build_node_mask(n_nodes, n_max, model)

    t0 = time.time()
    progress_every = max(1, m // 10)
    z_samples = []
    for i in range(m):
        z_samples.append(_sample_initial_noise(model, n_max, node_mask))
        if (i + 1) % progress_every == 0 or i == m - 1:
            yield {
                "type": "progress",
                "phase": "noise_init",
                "step": i + 1,
                "total_steps": m,
                "elapsed_ms": int((time.time() - t0) * 1000),
            }

    X = torch.stack([z.X for z in z_samples], dim=1)
    E = torch.stack([z.E for z in z_samples], dim=1)
    y = torch.stack([z.y for z in z_samples], dim=1)

    # No yield happens inside this block. For continuous models,
    # _collapse_final calls model.sample_discrete_graph_given_z0, which may
    # run the network, so gradient tracking is disabled here.
    with torch.no_grad():
        agg_X = _gibbs_aggregate(model, X)
        agg_E = _gibbs_aggregate(model, E)
        agg_y = _gibbs_aggregate(model, y.float())
        X_int, E_int = _collapse_final(model, agg_X, agg_E, agg_y, node_mask)
    image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
    elapsed_ms = int((time.time() - t0) * 1000)

    state = {
        "X": X.cpu(), "E": E.cpu(), "y": y.cpu(), "n_nodes": n_nodes.cpu(),
        "dataset_id": dataset_id, "model_type": None,
        "T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
        "gibbs_chain_freq": gibbs_chain_freq, "inner_step": 0, "step": 0,
    }
    yield {
        "type": "result",
        "state": state,
        "image": image_b64,
        "inference_time_ms": elapsed_ms,
    }
|
|
|
|
def run_multiprox_step(model, state_dict, dataset_id):
    """Advance a MultiProx session: a batch of Gibbs updates, then a refine pass.

    Each Gibbs update ``k`` does four things:
    1. aggregates all chains (``_gibbs_aggregate`` over dim 1);
    2. denoises that consensus one step, from ``t`` to ``t - 1/T``;
    3. re-noises it back to ``t`` via ``model.apply_noise(..., gibbs=True)``;
    4. stores the result as chain ``k``.

    At most ``gibbs_chain_freq`` chains are updated per call. Once
    ``inner_step`` reaches ``m`` the round completes and ``step`` advances.
    Every call then refines the current consensus from ``t`` down to
    ``t_prime`` (inclusive) and renders it.

    Args:
        model: loaded diffusion model. Its ``gibbs_fixed_t_2`` attribute is
            set to ``t`` only for the duration of each ``apply_noise`` call.
            It is always restored, even if that call raises.
        state_dict: session state as produced by ``run_multiprox_init`` /
            ``decode_state_blob``. Its tensors are copied, never mutated.
        dataset_id: forwarded to ``render_graph``.

    Yields:
        - ``{"type": "progress", "phase": "gibbs", ...}`` with a ``"preview"``
          after each Gibbs update.
        - ``{"type": "progress", "phase": "refine", ...}`` per refine step,
          with ``"preview"`` on roughly every tenth and the last.
        - Finally ``{"type": "result", "state", "image", "round_complete",
          "done", "inference_time_ms"}``.
    """
    device = next(model.parameters()).device
    # copy=True: chains are overwritten in place below. Without it, a state
    # already on `device` (e.g. CPU) would alias, and mutate, the caller's tensors.
    X = state_dict["X"].to(device, copy=True)
    E = state_dict["E"].to(device, copy=True)
    y = state_dict["y"].to(device, copy=True)
    n_nodes = state_dict["n_nodes"].to(device)
    T = state_dict["T"]
    n = state_dict["n"]
    m = state_dict["m"]
    t = state_dict["t"]
    t_prime = state_dict["t_prime"]
    gibbs_chain_freq = state_dict["gibbs_chain_freq"]
    inner_step = state_dict["inner_step"]
    step = state_dict["step"]

    # Dim 1 indexes chains (see X[:, k] below); dim 2 indexes nodes.
    n_max = X.shape[2]
    node_mask = _build_node_mask(n_nodes, n_max, model)

    def consensus():
        """Aggregate the current X / E / y over the chain axis (dim 1)."""
        return (
            _gibbs_aggregate(model, X),
            _gibbs_aggregate(model, E),
            _gibbs_aggregate(model, y.float()),
        )

    fixed_t = t * torch.ones((1, 1), dtype=torch.float, device=device)
    fixed_s = fixed_t - (1.0 / T)

    # Clamp at 0 so a bad state (inner_step > m, or gibbs_chain_freq < 0)
    # can never move inner_step backwards.
    steps_this_call = max(0, min(gibbs_chain_freq, m - inner_step))

    # Round bookkeeping depends only on the counters, not on the sampling.
    new_inner_step = inner_step + steps_this_call
    round_complete = new_inner_step >= m
    if round_complete:
        new_inner_step = 0
        new_step = step + 1
    else:
        new_step = step
    done = round_complete and new_step >= n

    # Refine step count: t_ref walks t, t - 1/T, ..., t_prime (inclusive).
    # round() instead of a bare int(): (0.3 - 0.1) * 500 evaluates to
    # 99.999..., which int() truncates to 99, silently dropping a step.
    P = int(round((t - t_prime) * T)) + 1
    # The last refine step ends at s = t - P / T. Cap P so that s never drops
    # below 0, the bottom of the time grid used throughout this file.
    # Otherwise it would whenever t_prime < 1 / T.
    P = max(0, min(P, int(round(t * T))))
    refine_preview_interval = max(1, P // 10)

    t0 = time.time()
    with torch.no_grad():
        for i in range(steps_this_call):
            k = inner_step + i
            avg_X, avg_E, avg_y = consensus()
            denoised_X, denoised_E, denoised_y, _, _ = _denoising_step(
                model, fixed_s, fixed_t, avg_X, avg_E, avg_y, node_mask)
            old_t2 = model.gibbs_fixed_t_2
            model.gibbs_fixed_t_2 = t
            try:
                noisy = model.apply_noise(
                    denoised_X, denoised_E, denoised_y, node_mask, gibbs=True)
            finally:
                # Restore even on failure, so an exception cannot leave a
                # reused model instance pinned to this session's t.
                model.gibbs_fixed_t_2 = old_t2
            X[:, k] = noisy["X_t"]
            E[:, k] = noisy["E_t"]
            y[:, k] = noisy["y_t"]

            prev_X, prev_E, prev_y = consensus()
            prev_Xi, prev_Ei = _collapse_final(model, prev_X, prev_E, prev_y, node_mask)
            yield {
                "type": "progress",
                "phase": "gibbs",
                "step": i + 1,
                "total_steps": steps_this_call,
                "elapsed_ms": int((time.time() - t0) * 1000),
                "preview": _pil_to_b64(render_graph(prev_Xi[0, :n_max], prev_Ei[0, :n_max, :n_max], dataset_id)),
            }

        # Refine phase: this also runs the network, so it stays under no_grad.
        cur_X, cur_E, cur_y = consensus()
        for j in range(P):
            s_ref = (t - (j + 1) / T) * torch.ones((1, 1), dtype=torch.float, device=device)
            t_ref = (t - j / T) * torch.ones((1, 1), dtype=torch.float, device=device)
            cur_X, cur_E, cur_y, cur_Xi, cur_Ei = _denoising_step(
                model, s_ref, t_ref, cur_X, cur_E, cur_y, node_mask)
            is_frame = (j + 1) % refine_preview_interval == 0 or j == P - 1
            event = {
                "type": "progress",
                "phase": "refine",
                "step": j + 1,
                "total_steps": P,
                "elapsed_ms": int((time.time() - t0) * 1000),
            }
            if is_frame:
                event["preview"] = _pil_to_b64(
                    render_graph(cur_Xi[0, :n_max], cur_Ei[0, :n_max, :n_max], dataset_id))
            yield event

        X_int, E_int = _collapse_final(model, cur_X, cur_E, cur_y, node_mask)

    image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
    elapsed_ms = int((time.time() - t0) * 1000)
    updated_state = {
        **state_dict,
        "X": X.cpu(), "E": E.cpu(), "y": y.cpu(),
        "step": new_step, "inner_step": new_inner_step,
    }
    yield {
        "type": "result",
        "state": updated_state,
        "image": image_b64,
        "round_complete": round_complete,
        "done": done,
        "inference_time_ms": elapsed_ms,
    }
|
|
|
|
| |
| |
| |
|
|
def encode_state_blob(state_dict):
    """Serialize *state_dict* with ``torch.save`` and return it as ASCII base64 text."""
    with io.BytesIO() as stream:
        torch.save(state_dict, stream)
        payload = stream.getvalue()
    encoded = base64.b64encode(payload)
    return encoded.decode("ascii")
|
|
|
|
def decode_state_blob(b64_str):
    """Decode and validate a MultiProx state blob received from a client.

    This is the inverse of ``encode_state_blob``. The blob round-trips through
    the client, so its bytes are untrusted. It is therefore loaded with
    ``torch.load(..., weights_only=True)``. That restricted unpickler only
    rebuilds tensors and plain Python containers / scalars, which is all the
    state ever holds. SECURITY: do not switch back to ``weights_only=False``.
    That runs the full pickle machinery and lets a crafted blob execute
    arbitrary code on the server.

    Args:
        b64_str: base64 text (``str`` or ``bytes``) as produced by
            ``encode_state_blob``.

    Returns:
        The state dict. It contains every key in ``REQUIRED_STATE_KEYS``,
        with a 4-D tensor ``X``, a 5-D tensor ``E`` and tensors ``y`` and
        ``n_nodes``.

    Raises:
        ValueError: for invalid base64, an oversize blob, content that cannot
            be (safely) deserialized, a non-dict payload, missing keys, or
            tensors of the wrong type / rank.
    """
    try:
        raw = base64.b64decode(b64_str)
    except (TypeError, ValueError) as exc:  # binascii.Error subclasses ValueError
        raise ValueError("state is not valid base64") from exc
    if len(raw) > STATE_BLOB_MAX_BYTES:
        raise ValueError(f"state blob exceeds {STATE_BLOB_MAX_BYTES // (1024 * 1024)} MB limit")
    try:
        state = torch.load(io.BytesIO(raw), weights_only=True)
    except Exception as exc:
        # Trust boundary: any load failure, including a rejected global, is
        # reported as a client error.
        raise ValueError(f"state could not be deserialized: {exc}") from exc
    if not isinstance(state, dict):
        raise ValueError("state must be a dict")
    missing = REQUIRED_STATE_KEYS - set(state.keys())
    if missing:
        raise ValueError(f"state missing keys: {missing}")
    if not isinstance(state["X"], torch.Tensor) or state["X"].dim() != 4:
        raise ValueError("state['X'] must be a 4-D tensor")
    if not isinstance(state["E"], torch.Tensor) or state["E"].dim() != 5:
        raise ValueError("state['E'] must be a 5-D tensor")
    for key in ("y", "n_nodes"):
        if not isinstance(state[key], torch.Tensor):
            raise ValueError(f"state['{key}'] must be a tensor")
    return state
|
|
|
|
| |
| |
| |
|
|
def render_graph(X_int, E_int, dataset_id):
    """Render one generated graph to a PIL image.

    ``X_int`` is the 1-D per-node class tensor and ``E_int`` the 2-D
    edge-class matrix. ``"qm9"`` uses the RDKit molecule renderer. Every
    other ``dataset_id`` falls through to the community-graph renderer.
    """
    renderer = _render_qm9 if dataset_id == "qm9" else _render_comm20
    return renderer(X_int, E_int)
|
|
|
|
def _render_qm9(X_int, E_int):
    """Render a QM9 molecular graph as a 300x300 PIL image via RDKit.

    Args:
        X_int: 1-D tensor of atom-class indices. Negative entries (masked /
            padding nodes) are skipped. Indices beyond ``QM9_ATOM_TYPES``
            fall back to carbon.
        E_int: 2-D tensor of bond-class indices. Only the upper triangle is
            read. Classes 1-4 map to single / double / triple / aromatic
            bonds; anything else means "no bond".

    The molecule is built directly from model output and is never sanitized,
    so it may not be chemically valid. If the default draw raises, it is
    retried once with ``kekulize=False``. Re-issuing the identical call on an
    equivalent molecule object would most likely fail the same way.
    """
    from rdkit import Chem
    from rdkit.Chem import Draw
    from rdkit.Chem.rdchem import BondType

    bond_map = {1: BondType.SINGLE, 2: BondType.DOUBLE, 3: BondType.TRIPLE, 4: BondType.AROMATIC}

    mol = Chem.RWMol()
    x = X_int.cpu().tolist()
    valid_atoms = [i for i, a in enumerate(x) if a >= 0]
    idx_map = {}
    for i in valid_atoms:
        atom_sym = QM9_ATOM_TYPES[x[i]] if x[i] < len(QM9_ATOM_TYPES) else "C"
        idx_map[i] = mol.AddAtom(Chem.Atom(atom_sym))

    e = E_int.cpu().tolist()
    # valid_atoms is ascending, so pairing each atom only with later entries
    # visits each unordered pair once (upper triangle).
    for pos, i in enumerate(valid_atoms):
        for j in valid_atoms[pos + 1:]:
            bond_type = bond_map.get(e[i][j])
            if bond_type is not None:
                mol.AddBond(idx_map[i], idx_map[j], bond_type)

    drawable = mol.GetMol()
    try:
        return Draw.MolToImage(drawable, size=(300, 300))
    except Exception:
        # Presumably kekulization failed (e.g. aromatic bonds outside any
        # ring in a generated graph); draw the raw bonds without it.
        return Draw.MolToImage(drawable, size=(300, 300), kekulize=False)
|
|
|
|
# RGB anchors for _comm20_green_rgb: the palette value -1 maps to LOW,
# 0 to MID and +1 to HIGH.
_COMM20_GREEN_LOW, _COMM20_GREEN_MID, _COMM20_GREEN_HIGH = (
    (212, 237, 218),  # pale mint
    (82, 180, 120),   # vivid green
    (22, 80, 50),     # deep forest
)
|
|
|
|
def _comm20_green_rgb(t):
    """Map a value in [-1, 1] to an ``(r, g, b)`` tuple on the green palette.

    The input is clamped to [-1, 1] first. The palette runs from pale mint
    (-1) through vivid green (0) to deep forest (+1). Between anchors, each
    channel is linearly interpolated and truncated with ``int``.
    """
    value = max(-1.0, min(1.0, float(t)))
    if value >= 0:
        start, end, frac = _COMM20_GREEN_MID, _COMM20_GREEN_HIGH, value
    else:
        start, end, frac = _COMM20_GREEN_LOW, _COMM20_GREEN_MID, value + 1.0
    return tuple(int(lo + (hi - lo) * frac) for lo, hi in zip(start, end))
|
|
|
|
def _render_comm20(X_int, E_int):
    """Draw a community graph as a 720x720 undirected spring-layout picture.

    Follows MultiProxAn's ``visualize_non_molecule``:
    - keep only the largest connected component;
    - place nodes with ``spring_layout`` (seeded, so layouts are repeatable);
    - colour nodes by the second eigenvector (ascending eigenvalue order) of
      the normalized Laplacian, on the site's green palette;
    - draw grey edges and no labels.

    Drawing uses PIL + networkx, and the eigendecomposition uses
    ``torch.linalg.eigh``, to avoid the matplotlib / numpy MKL DLL conflicts
    seen on Windows. ``X_int`` is not used; only the adjacency in ``E_int``
    (entries > 0 above the diagonal) matters.
    """
    import networkx as nx
    from PIL import Image, ImageDraw

    adjacency = E_int.cpu().tolist()
    node_count = len(adjacency)
    full = nx.Graph()
    full.add_nodes_from(range(node_count))
    full.add_edges_from(
        (a, b)
        for a in range(node_count)
        for b in range(a + 1, node_count)
        if adjacency[a][b] > 0
    )

    ranked = sorted(nx.connected_components(full), key=len, reverse=True)
    graph = full.subgraph(ranked[0]).copy() if ranked else full

    side = 720
    canvas = Image.new("RGB", (side, side), "white")
    pen = ImageDraw.Draw(canvas)
    if graph.number_of_nodes() == 0:
        return canvas

    layout = nx.spring_layout(graph, iterations=100, seed=42)

    laplacian = torch.from_numpy(nx.normalized_laplacian_matrix(graph).toarray()).to(torch.float64)
    eigvecs = torch.linalg.eigh(laplacian)[1].numpy()
    column = 1 if eigvecs.shape[1] > 1 else 0
    spectral = eigvecs[:, column]
    peak = max(abs(spectral.min()), abs(spectral.max()), 1e-9)
    colours = {
        node: _comm20_green_rgb(spectral[idx] / peak)
        for idx, node in enumerate(graph.nodes())
    }

    pad = 60
    half_span = (side - 2 * pad) / 2
    centre_x, centre_y = side / 2, side / 2
    # Layout y grows upward; image y grows downward, hence the minus sign.
    screen = {
        node: (centre_x + xy[0] * half_span, centre_y - xy[1] * half_span)
        for node, xy in layout.items()
    }

    for a, b in graph.edges():
        pen.line([screen[a], screen[b]], fill="#9a9a9a", width=2)

    radius = 14
    for node, (px, py) in screen.items():
        pen.ellipse(
            [px - radius, py - radius, px + radius, py + radius],
            fill=colours[node], outline="#333333", width=2,
        )

    return canvas
|
|
|
|
def _pil_to_b64(img):
    """Encode a PIL image as a ``data:image/png;base64,...`` URI string."""
    with io.BytesIO() as sink:
        img.save(sink, format="PNG")
        encoded = base64.b64encode(sink.getvalue()).decode("ascii")
    return f"data:image/png;base64,{encoded}"
|
|
|
|
def _frames_to_gif_b64(frames):
    """Assemble PIL frames into a looping animated-GIF data URI.

    Returns ``None`` when *frames* is empty. Otherwise each frame is shown
    for 150 ms and the animation loops forever (``loop=0``).
    """
    if not frames:
        return None
    first, rest = frames[0], frames[1:]
    with io.BytesIO() as sink:
        first.save(
            sink,
            format="GIF",
            save_all=True,
            append_images=rest,
            duration=150,
            loop=0,
        )
        payload = sink.getvalue()
    return "data:image/gif;base64," + base64.b64encode(payload).decode("ascii")
|
|