# website/src/backend/api/services/graphgen_inference.py
import base64
import io
import time
import torch
QM9_ATOM_TYPES = ["C", "N", "O", "F"]
STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
REQUIRED_STATE_KEYS = {
"X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
"n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step",
}
# ---------------------------------------------------------------------------
# Model-type helpers
# ---------------------------------------------------------------------------
def _is_discrete(model):
from diffusion_model_discrete import DiscreteDenoisingDiffusion
return isinstance(model, DiscreteDenoisingDiffusion)
def _build_node_mask(n_nodes, n_max, model):
"""bool for discrete, float32 for continuous."""
arange = torch.arange(n_max, device=n_nodes.device).unsqueeze(0)
mask = arange < n_nodes.unsqueeze(1) # (1, n_max) bool
return mask if _is_discrete(model) else mask.float()
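# Example (illustrative, not taken from a live request): for n_nodes = tensor([3])
# and n_max = 5 the mask is tensor([[True, True, True, False, False]]); continuous
# models get the same mask cast to float32 (1.0 / 0.0).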
def _sample_initial_noise(model, n_max, node_mask):
from src.diffusion import diffusion_utils
if _is_discrete(model):
return diffusion_utils.sample_discrete_feature_noise(
limit_dist=model.limit_dist, node_mask=node_mask)
else:
bs = node_mask.shape[0]
return diffusion_utils.sample_feature_noise(
X_size=(bs, n_max, model.Xdim_output),
E_size=(bs, n_max, n_max, model.Edim_output),
y_size=(bs, model.ydim_output),
node_mask=node_mask)
def _denoising_step(model, s_t, t_t, X, E, y, node_mask):
"""One denoising step. Returns (X_soft, E_soft, y_soft, X_int, E_int)."""
if _is_discrete(model):
sampled_s, discrete_s = model.sample_p_zs_given_zt(s_t, t_t, X, E, y, node_mask)
        # .type_as(y_t) in the model can cast collapsed ints to float; force back to long
return sampled_s.X, sampled_s.E, sampled_s.y, discrete_s.X.long(), discrete_s.E.long()
else:
from src import utils
z_s = model.sample_p_zs_given_zt(s=s_t, t=t_t, X_t=X, E_t=E, y_t=y, node_mask=node_mask)
unnorm = utils.unnormalize(
z_s.X, z_s.E, z_s.y,
model.norm_values, model.norm_biases, node_mask, collapse=True)
return z_s.X, z_s.E, z_s.y, unnorm.X, unnorm.E
def _gibbs_aggregate(model, X):
if _is_discrete(model):
return torch.median(X, dim=1).values
else:
return torch.mean(X, dim=1)
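# Note: X stacks the M chain samples along dim=1, so for discrete (one-hot) states
# the per-component median behaves like a majority vote over chains and stays in
# {0, 1}, while the continuous mean yields soft values that are collapsed later.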
def _collapse_final(model, X, E, y, node_mask):
"""Returns (X_int, E_int) integer tensors.
Symmetrize E first: MultiProx aggregation (mean / median over multiple
chains) can introduce ULP-level asymmetry that survives into pred_E and
breaks the model's strict ``assert (pred_E == pred_E.T).all()`` on some
BLAS / vectorization stacks (notably the Linux ``+cu118`` torch wheel
inside the deployment container, while the same code runs fine on the
Windows wheel in dev). Symmetrizing here is a no-op when the input is
already symmetric and a 1-line invariant fix when it isn't.
"""
E = (E + E.transpose(1, 2)) / 2
if _is_discrete(model):
from src.utils import PlaceHolder
final = PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
return final.X.long(), final.E.long()
else:
final = model.sample_discrete_graph_given_z0(X, E, y, node_mask)
return final.X, final.E
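# A minimal sketch of the invariant the symmetrization protects (toy shapes and a
# hypothetical aggregated tensor `agg`; whether the asymmetry actually appears
# depends on the BLAS / vectorization stack, per the docstring above):
#
#     agg = chain_outputs.mean(dim=1)         # (1, n, n, de) aggregated edge logits
#     (agg == agg.transpose(1, 2)).all()      # may be False at ULP level
#     agg = (agg + agg.transpose(1, 2)) / 2
#     (agg == agg.transpose(1, 2)).all()      # now exact: fp addition is commutative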
# ---------------------------------------------------------------------------
# Main inference generators: yield progress dicts, then a result dict
# ---------------------------------------------------------------------------
def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id):
device = next(model.parameters()).device
if num_nodes is None:
n_nodes = model.node_dist.sample_n(1, device)
else:
n_nodes = torch.tensor([num_nodes], dtype=torch.long, device=device)
n_max = n_nodes.item()
node_mask = _build_node_mask(n_nodes, n_max, model)
z_T = _sample_initial_noise(model, n_max, node_mask)
X, E, y = z_T.X, z_T.E, z_T.y
frame_interval = max(1, diffusion_steps // chain_frames)
gif_frames = []
t0 = time.time()
with torch.no_grad():
        for s_idx in reversed(range(diffusion_steps)):
            # Normalized diffusion times: one reverse step from t = (s_idx + 1) / T
            # down to s = s_idx / T.
            s_t = (s_idx / diffusion_steps) * torch.ones((1, 1), device=device)
            t_t = ((s_idx + 1) / diffusion_steps) * torch.ones((1, 1), device=device)
X, E, y, X_int, E_int = _denoising_step(model, s_t, t_t, X, E, y, node_mask)
step = diffusion_steps - 1 - s_idx
is_frame = step % frame_interval == 0 or s_idx == 0
if is_frame:
frame_img = render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id)
gif_frames.append(frame_img)
event = {
"type": "progress",
"phase": "denoise",
"step": step + 1,
"total_steps": diffusion_steps,
"elapsed_ms": int((time.time() - t0) * 1000),
}
if is_frame:
event["preview"] = _pil_to_b64(frame_img)
yield event
X_final, E_final = _collapse_final(model, X, E, y, node_mask)
image_b64 = _pil_to_b64(render_graph(X_final[0, :n_max], E_final[0, :n_max, :n_max], dataset_id))
elapsed_ms = int((time.time() - t0) * 1000)
yield {
"type": "result",
"image": image_b64,
"chain_gif": _frames_to_gif_b64(gif_frames),
"inference_time_ms": elapsed_ms,
}
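# Usage sketch (hypothetical caller; `model` is assumed to come from the model
# registry with a loaded checkpoint):
#
#     for event in run_standard_generation(model, num_nodes=9,
#                                          diffusion_steps=500,
#                                          chain_frames=20, dataset_id="qm9"):
#         if event["type"] == "progress":
#             print(event["phase"], event["step"], "/", event["total_steps"])
#         else:
#             data_uri = event["image"]  # "data:image/png;base64,..."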
def run_multiprox_init(model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id):
device = next(model.parameters()).device
if num_nodes is None:
n_nodes = model.node_dist.sample_n(1, device)
else:
n_nodes = torch.tensor([num_nodes], dtype=torch.long, device=device)
n_max = n_nodes.item()
node_mask = _build_node_mask(n_nodes, n_max, model)
t0 = time.time()
z_samples = []
for i in range(m):
z_samples.append(_sample_initial_noise(model, n_max, node_mask))
if (i + 1) % max(1, m // 10) == 0 or i == m - 1:
yield {
"type": "progress",
"phase": "noise_init",
"step": i + 1,
"total_steps": m,
"elapsed_ms": int((time.time() - t0) * 1000),
}
X = torch.stack([z.X for z in z_samples], dim=1) # (1, M, n_max, Xdim)
E = torch.stack([z.E for z in z_samples], dim=1)
y = torch.stack([z.y for z in z_samples], dim=1)
agg_X = _gibbs_aggregate(model, X)
agg_E = _gibbs_aggregate(model, E)
agg_y = _gibbs_aggregate(model, y.float())
X_int, E_int = _collapse_final(model, agg_X, agg_E, agg_y, node_mask)
image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
elapsed_ms = int((time.time() - t0) * 1000)
state = {
"X": X.cpu(), "E": E.cpu(), "y": y.cpu(), "n_nodes": n_nodes.cpu(),
"dataset_id": dataset_id, "model_type": None, # filled by registry
"T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
"gibbs_chain_freq": gibbs_chain_freq, "inner_step": 0, "step": 0,
}
yield {
"type": "result",
"state": state,
"image": image_b64,
"inference_time_ms": elapsed_ms,
}
def run_multiprox_step(model, state_dict, dataset_id):
device = next(model.parameters()).device
X = state_dict["X"].to(device)
E = state_dict["E"].to(device)
y = state_dict["y"].to(device)
n_nodes = state_dict["n_nodes"].to(device)
T = state_dict["T"]
n = state_dict["n"]
m = state_dict["m"]
t = state_dict["t"]
t_prime = state_dict["t_prime"]
gibbs_chain_freq = state_dict["gibbs_chain_freq"]
inner_step = state_dict["inner_step"]
step = state_dict["step"]
n_max = X.shape[2]
node_mask = _build_node_mask(n_nodes, n_max, model)
fixed_t = t * torch.ones((1, 1), dtype=torch.float, device=device)
fixed_s = fixed_t - (1.0 / T)
steps_this_call = min(gibbs_chain_freq, m - inner_step)
t0 = time.time()
with torch.no_grad():
for i in range(steps_this_call):
k = inner_step + i
avg_X = _gibbs_aggregate(model, X)
avg_E = _gibbs_aggregate(model, E)
avg_y = _gibbs_aggregate(model, y.float())
denoised_X, denoised_E, denoised_y, _, _ = _denoising_step(
model, fixed_s, fixed_t, avg_X, avg_E, avg_y, node_mask)
            old_t2 = model.gibbs_fixed_t_2
            model.gibbs_fixed_t_2 = t  # safe: _inference_lock held by registry
            try:
                noisy = model.apply_noise(denoised_X, denoised_E, denoised_y, node_mask, gibbs=True)
            finally:
                # Restore even if apply_noise raises, so a failed request cannot
                # leave the shared model in a modified state.
                model.gibbs_fixed_t_2 = old_t2
X[:, k] = noisy["X_t"]
E[:, k] = noisy["E_t"]
y[:, k] = noisy["y_t"]
# Preview: aggregate + collapse current Gibbs state
prev_X = _gibbs_aggregate(model, X)
prev_E = _gibbs_aggregate(model, E)
prev_y = _gibbs_aggregate(model, y.float())
prev_Xi, prev_Ei = _collapse_final(model, prev_X, prev_E, prev_y, node_mask)
yield {
"type": "progress",
"phase": "gibbs",
"step": i + 1,
"total_steps": steps_this_call,
"elapsed_ms": int((time.time() - t0) * 1000),
"preview": _pil_to_b64(render_graph(prev_Xi[0, :n_max], prev_Ei[0, :n_max, :n_max], dataset_id)),
}
new_inner_step = inner_step + steps_this_call
round_complete = new_inner_step >= m
if round_complete:
new_inner_step = 0
new_step = step + 1
else:
new_step = step
done = round_complete and new_step >= n
    # Refinement pass: always produce a clean render
P = int((t - t_prime) * T) + 1
refine_preview_interval = max(1, P // 10)
cur_X = _gibbs_aggregate(model, X)
cur_E = _gibbs_aggregate(model, E)
cur_y = _gibbs_aggregate(model, y.float())
for j in range(P):
s_ref = (t - (j + 1) / T) * torch.ones((1, 1), dtype=torch.float, device=device)
t_ref = (t - j / T) * torch.ones((1, 1), dtype=torch.float, device=device)
cur_X, cur_E, cur_y, cur_Xi, cur_Ei = _denoising_step(
model, s_ref, t_ref, cur_X, cur_E, cur_y, node_mask)
is_frame = (j + 1) % refine_preview_interval == 0 or j == P - 1
event = {
"type": "progress",
"phase": "refine",
"step": j + 1,
"total_steps": P,
"elapsed_ms": int((time.time() - t0) * 1000),
}
if is_frame:
event["preview"] = _pil_to_b64(
render_graph(cur_Xi[0, :n_max], cur_Ei[0, :n_max, :n_max], dataset_id))
yield event
X_int, E_int = _collapse_final(model, cur_X, cur_E, cur_y, node_mask)
image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
elapsed_ms = int((time.time() - t0) * 1000)
updated_state = {
**state_dict,
"X": X.cpu(), "E": E.cpu(), "y": y.cpu(),
"step": new_step, "inner_step": new_inner_step,
}
yield {
"type": "result",
"state": updated_state,
"image": image_b64,
"round_complete": round_complete,
"done": done,
"inference_time_ms": elapsed_ms,
}
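# Usage sketch (hypothetical caller): one init call seeds the Gibbs chains, then
# step calls repeat until `done`. The registry normally fills `model_type` and
# round-trips the state through encode_state_blob / decode_state_blob between calls.
#
#     *_, init = run_multiprox_init(model, num_nodes=None, n=3, m=10,
#                                   t=0.5, t_prime=0.1, gibbs_chain_freq=5,
#                                   dataset_id="qm9")
#     state = init["state"]
#     while True:
#         *_, result = run_multiprox_step(model, state, "qm9")
#         state = result["state"]
#         if result["done"]:
#             break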
# ---------------------------------------------------------------------------
# State blob serialisation
# ---------------------------------------------------------------------------
def encode_state_blob(state_dict):
buf = io.BytesIO()
torch.save(state_dict, buf)
return base64.b64encode(buf.getvalue()).decode("ascii")
def decode_state_blob(b64_str):
try:
raw = base64.b64decode(b64_str)
except Exception:
raise ValueError("state is not valid base64")
if len(raw) > STATE_BLOB_MAX_BYTES:
raise ValueError(f"state blob exceeds {STATE_BLOB_MAX_BYTES // (1024 * 1024)} MB limit")
try:
        # The state holds only tensors and Python primitives, so the restricted
        # unpickler suffices and avoids executing arbitrary pickle payloads from
        # untrusted blobs.
        state = torch.load(io.BytesIO(raw), weights_only=True)
except Exception as exc:
raise ValueError(f"state could not be deserialized: {exc}") from exc
missing = REQUIRED_STATE_KEYS - set(state.keys())
if missing:
raise ValueError(f"state missing keys: {missing}")
if not isinstance(state["X"], torch.Tensor) or state["X"].dim() != 4:
raise ValueError("state['X'] must be a 4-D tensor")
if not isinstance(state["E"], torch.Tensor) or state["E"].dim() != 5:
raise ValueError("state['E'] must be a 5-D tensor")
return state
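# Round-trip sketch (toy tensors, not a real inference state; the shapes only
# need to satisfy the rank checks above):
#
#     blob = encode_state_blob({
#         "X": torch.zeros(1, 2, 5, 4), "E": torch.zeros(1, 2, 5, 5, 5),
#         "y": torch.zeros(1, 2, 1), "n_nodes": torch.tensor([5]),
#         "dataset_id": "qm9", "model_type": "discrete", "T": 500,
#         "n": 3, "m": 2, "t": 0.5, "t_prime": 0.1,
#         "gibbs_chain_freq": 5, "inner_step": 0, "step": 0,
#     })
#     state = decode_state_blob(blob)  # validates size, required keys, tensor ranks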
# ---------------------------------------------------------------------------
# Visualisation
# ---------------------------------------------------------------------------
def render_graph(X_int, E_int, dataset_id):
"""Render a single graph to PIL Image. X_int/E_int are 1-D / 2-D integer tensors."""
if dataset_id == "qm9":
return _render_qm9(X_int, E_int)
else:
return _render_comm20(X_int, E_int)
def _render_qm9(X_int, E_int):
from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem.rdchem import BondType
bond_map = {1: BondType.SINGLE, 2: BondType.DOUBLE, 3: BondType.TRIPLE, 4: BondType.AROMATIC}
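    # Edge value 0 encodes "no bond"; values outside bond_map are skipped below.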
mol = Chem.RWMol()
x = X_int.cpu().tolist()
    valid_atoms = [i for i, a in enumerate(x) if a >= 0]  # masked / padded nodes collapse to -1
idx_map = {}
for i in valid_atoms:
atom_sym = QM9_ATOM_TYPES[x[i]] if x[i] < len(QM9_ATOM_TYPES) else "C"
idx_map[i] = mol.AddAtom(Chem.Atom(atom_sym))
e = E_int.cpu().tolist()
for i in valid_atoms:
for j in valid_atoms:
if j <= i:
continue
bond_type_idx = e[i][j]
if bond_type_idx > 0 and bond_type_idx in bond_map:
mol.AddBond(idx_map[i], idx_map[j], bond_map[bond_type_idx])
try:
img = Draw.MolToImage(mol.GetMol(), size=(300, 300))
except Exception:
# If RDKit can't sanitize, draw the raw RWMol
img = Draw.MolToImage(mol, size=(300, 300))
return img
_COMM20_GREEN_LOW = (212, 237, 218)
_COMM20_GREEN_MID = (82, 180, 120)
_COMM20_GREEN_HIGH = (22, 80, 50)
def _comm20_green_rgb(t):
"""Map [-1, 1] β†’ RGB via a green palette (pale mint β†’ vivid green β†’ deep forest)."""
t = max(-1.0, min(1.0, float(t)))
if t < 0:
w = t + 1.0
a, b = _COMM20_GREEN_LOW, _COMM20_GREEN_MID
else:
w = t
a, b = _COMM20_GREEN_MID, _COMM20_GREEN_HIGH
return (
int(a[0] + (b[0] - a[0]) * w),
int(a[1] + (b[1] - a[1]) * w),
int(a[2] + (b[2] - a[2]) * w),
)
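# Spot check (follows directly from the interpolation above):
#   _comm20_green_rgb(-1.0) == (212, 237, 218)  # pale mint
#   _comm20_green_rgb( 0.0) == (82, 180, 120)   # vivid green
#   _comm20_green_rgb( 1.0) == (22, 80, 50)     # deep forest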
def _render_comm20(X_int, E_int):
"""Render a community graph as an undirected spring-layout plot.
Mirrors MultiProxAn's ``visualize_non_molecule``: largest connected
component, spring_layout, normalized-Laplacian eigenvector for node
colouring (swapped to a site-green palette), no labels, grey edges.
Uses pure PIL + networkx and ``torch.linalg.eigh`` to avoid the matplotlib /
numpy MKL DLL conflicts on Windows.
"""
import networkx as nx
from PIL import Image, ImageDraw
e = E_int.cpu().tolist()
n = len(e)
G = nx.Graph()
G.add_nodes_from(range(n))
for i in range(n):
for j in range(i + 1, n):
if e[i][j] > 0:
G.add_edge(i, j)
# Largest connected component only (matches visualize_non_molecule(largest_component=True)).
components = sorted(nx.connected_components(G), key=len, reverse=True)
graph = G.subgraph(components[0]).copy() if components else G
size = 720
img = Image.new("RGB", (size, size), "white")
draw = ImageDraw.Draw(img)
if graph.number_of_nodes() == 0:
return img
pos = nx.spring_layout(graph, iterations=100, seed=42)
# Normalized Laplacian eigenvector for node colouring (torch avoids numpy MKL DLL clash).
L = nx.normalized_laplacian_matrix(graph).toarray()
L_t = torch.from_numpy(L).to(torch.float64)
_, U_t = torch.linalg.eigh(L_t)
U = U_t.numpy()
    eigen_dim = 1 if U.shape[1] > 1 else 0  # Fiedler vector when the graph has more than one node
vec = U[:, eigen_dim]
m_abs = max(abs(vec.min()), abs(vec.max()), 1e-9)
    node_color = {node: _comm20_green_rgb(vec[i] / m_abs)
                  for i, node in enumerate(graph.nodes())}
margin = 60
scale = (size - 2 * margin) / 2
cx, cy = size / 2, size / 2
pixel_pos = {k: (cx + v[0] * scale, cy - v[1] * scale) for k, v in pos.items()}
for i, j in graph.edges():
draw.line([pixel_pos[i], pixel_pos[j]], fill="#9a9a9a", width=2)
node_r = 14
for k, (x, y) in pixel_pos.items():
r, g, b = node_color[k]
draw.ellipse([x - node_r, y - node_r, x + node_r, y + node_r],
fill=(r, g, b), outline="#333333", width=2)
return img
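# Quick check (illustrative): render a 4-cycle. Node features are ignored by the
# comm20 renderer; only the adjacency matters.
#
#     E = torch.tensor([[0, 1, 0, 1],
#                       [1, 0, 1, 0],
#                       [0, 1, 0, 1],
#                       [1, 0, 1, 0]])
#     _render_comm20(torch.zeros(4, dtype=torch.long), E).save("cycle4.png")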
def _pil_to_b64(img):
buf = io.BytesIO()
img.save(buf, format="PNG")
return "data:image/png;base64," + base64.b64encode(buf.getvalue()).decode("ascii")
def _frames_to_gif_b64(frames):
if not frames:
return None
buf = io.BytesIO()
frames[0].save(
buf, format="GIF", save_all=True,
append_images=frames[1:], duration=150, loop=0,
)
return "data:image/gif;base64," + base64.b64encode(buf.getvalue()).decode("ascii")