Andrej Janchevski committed on
Commit
c50d15c
·
1 Parent(s): c1b3cc7

docs(postman): add thesis-optimal multiprox params and force-unlock endpoint


- QM9 multiprox: t=0.5, t'=0.004 (Table 4.3.1 best valid config)
- Comm20 multiprox: t=0.4, t'=0.1 (Table C.2.1 best orbit/degree)
- Add POST /debug/force-unlock request to Health folder
- Auto-chaining test scripts for multiprox state between init/continue
- Update README with SSE protocol and new endpoints
- Update backend_multiproxan plan to reflect current design decisions

.claude/plans/backend_multiproxan.md CHANGED
@@ -13,14 +13,19 @@ The key challenge is that `sample_batch` / `sample_batch_gibbs` write to disk vi
13
  | `src/research/MultiProxAn/src/diffusion_model_discrete.py` | Strip all `wandb` imports and calls |
14
  | `src/research/MultiProxAn/src/diffusion_model.py` | Strip all `wandb` imports and calls (incl. one unguarded call) |
15
  | `src/research/MultiProxAn/src/utils.py` | Remove `import wandb` and `setup_wandb()` function |
 
 
 
 
16
  | `src/backend/research_api/settings.py` | Add `MultiProxAn/src/` to `sys.path` |
17
- | `src/backend/api/services/graphgen_inference.py` | **New file** -- all inference logic + rendering |
18
- | `src/backend/api/services/registry.py` | Add `_graphgen_models` cache, `_load_graphgen_model`, `graphgen_generate`, `graphgen_continue` |
19
- | `src/backend/api/views/graph_generation.py` | Add `GraphGenGenerateView`, `GraphGenContinueView` |
20
- | `src/backend/api/urls.py` | Wire 2 new routes + update import |
21
- | `src/backend/requirements.txt` | Add `wandb`, `Pillow`, `overrides` |
22
- | `src/backend/README.md` | Mark generate/continue as implemented |
23
- | `docs/postman/collection.json` | Add 12 example requests |
 
24
 
25
  ## Model Differences: Discrete vs Continuous
26
 
@@ -30,60 +35,164 @@ The key challenge is that `sample_batch` / `sample_batch_gibbs` write to disk vi
30
  |---|---|---|
31
  | **`node_mask` dtype** | `bool` | `float32` (`.float()` required) |
32
  | **Initial noise** | `sample_discrete_feature_noise(limit_dist=model.limit_dist, node_mask=node_mask)` | `sample_feature_noise(X_size=(1,n,Xdim), E_size=(1,n,n,Edim), y_size=(1,ydim), node_mask=node_mask)` |
- | **`sample_p_zs_given_zt` returns** | `(sampled_s, discrete_sampled_s)` -- 2-tuple; `discrete_sampled_s.X/E` are already collapsed integers | single `z_s` PlaceHolder with continuous floats |
- | **Chain frame render** | use `discrete_sampled_s.X/E` directly | `utils.unnormalize(z_s.X, z_s.E, z_s.y, model.norm_values, model.norm_biases, node_mask, collapse=True)` -> `.X/.E` are integers |
- | **Final graph collapse** | `PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)` | `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` (runs another forward pass + unnormalize) |
36
  | **Gibbs ensemble aggregation** | `torch.median(X, dim=1).values` | `torch.mean(X, dim=1)` |
37
 
 
 
38
  The refinement loop in `run_multiprox_step` uses `sample_p_zs_given_zt` to update `cur_X/E/y` at each step; only the first element of the tuple (or the single return value) is needed there. `_collapse_final` is called once at the end for rendering.
39
 
40
  ## Research Code Reused (read-only)
41
 
42
  | Symbol | Location | How used |
43
  |---|---|---|
44
- | `DiscreteDenoisingDiffusion` | `src/diffusion_model_discrete.py` | Loaded via `load_from_checkpoint` for `model_type=discrete` |
45
- | `LiftedDenoisingDiffusion` | `src/diffusion_model.py` | Loaded via `load_from_checkpoint` for `model_type=continuous` |
46
  | `model.node_dist.sample_n(1, device)` | both models | Sample number of nodes |
- | `diffusion_utils.sample_discrete_feature_noise(limit_dist, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise -- discrete only |
- | `diffusion_utils.sample_feature_noise(X_size, E_size, y_size, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise -- continuous only |
- | `model.sample_p_zs_given_zt(s, t, X, E, y, node_mask)` | both models | One denoising step (return varies by model type -- see table above) |
  | `model.apply_noise(X, E, y, node_mask, gibbs=True)` | both models | Re-apply Gibbs noise; uses `model.gibbs_fixed_t_2` internally |
- | `PlaceHolder.mask(node_mask, collapse=True)` | `src/utils.py` | Final collapse -- discrete only |
- | `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` | `diffusion_model.py` | Final collapse -- continuous only |
- | `utils.unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=True)` | `src/utils.py` | Chain frame rendering -- continuous only |
- | `model.norm_values`, `model.norm_biases` | `diffusion_model.py` | Unnormalization factors -- continuous only |
 
55
 
56
  ## Implementation
57
 
58
- ### Step 0: Patch research code -- remove `wandb`
59
 
60
- `wandb` is imported at module level in three files; the import alone causes `ImportError` at
61
- `load_from_checkpoint` time. One call in `diffusion_model.py` is **unguarded** (no `if wandb.run:`
62
- check) and would crash at inference even if the import somehow succeeded.
63
 
64
  **`src/research/MultiProxAn/src/diffusion_model_discrete.py`**:
65
  - Remove `import wandb` (line 9)
66
- - Remove the two `utils.setup_wandb(self.cfg)` call sites (inside `on_train_start` and `on_test_start`)
67
- - Remove all `if wandb.run: wandb.log(...)` / `wandb.run.summary[...]` blocks (5 blocks, all guarded -- safe to delete entirely)
68
 
69
  **`src/research/MultiProxAn/src/diffusion_model.py`**:
70
  - Remove `import wandb` (line 9)
71
  - Remove the two `utils.setup_wandb(self.cfg)` call sites
72
- - Remove all guarded `if wandb.run: wandb.log(...)` blocks (5 blocks)
73
- - Remove the **unguarded** `wandb.log({...})` call at line 590 (inside `compute_val_loss` / `compute_test_loss`); the surrounding NLL computation and `return` must be preserved -- only the `wandb.log(...)` statement is deleted
74
 
75
  **`src/research/MultiProxAn/src/utils.py`**:
76
- - Remove `import wandb` (line 7)
77
- - Remove the entire `setup_wandb(cfg)` function (lines 134–139)
78
-
79
- After these edits `wandb` is no longer needed and must be removed from `requirements.txt` (Step 6).
80
 
81
  ---
82
 
83
- ### Step 1: `settings.py` -- sys.path fix
 
 
 
 
84
 
85
- The research code uses two import styles simultaneously (`from diffusion.noise_schedule import ...` needs `MultiProxAn/src/` on path; `from src import utils` needs `MultiProxAn/` on path). Extend the existing loop at lines 9–13:
 
 
86
 
 
87
  ```python
88
  _MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
89
  for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
@@ -95,108 +204,48 @@ for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
95
 
96
  ### Step 2: New `graphgen_inference.py`
97
 
98
- Completely independent of Django/registry. Receives a loaded model and returns bytes/dicts.
99
 
100
  **Module structure:**
101
 
102
  ```
103
  graphgen_inference.py
- ├── Constants: QM9_ATOM_TYPES, STATE_BLOB_MAX_BYTES, REQUIRED_STATE_KEYS
- │
- ├── # Model-type helpers (called by all three main functions)
- ├── _is_discrete(model) -> bool
- ├── _build_node_mask(n_nodes, n_max, model) -> bool or float32 tensor
- ├── _sample_initial_noise(model, n_max, node_mask) -> PlaceHolder
- ├── _denoising_step(model, s_t, t_t, X, E, y, node_mask) -> (X_soft, E_soft, y_soft, X_int, E_int)
- ├── _gibbs_aggregate(model, X) -> tensor [median (discrete) or mean (continuous)]
- ├── _collapse_final(model, X, E, y, node_mask) -> (X_int, E_int)
- │
- ├── # Main inference functions
- ├── run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id)
- ├── run_multiprox_init(model, num_nodes, m, t, t_prime, dataset_id)
- ├── run_multiprox_step(model, state_dict, dataset_id)
- │
- ├── # State serialisation
- ├── encode_state_blob(state_dict) -> str
- ├── decode_state_blob(b64_str) -> dict
- │
- └── # Visualisation
-     ├── render_graph(X_int, E_int, dataset_id) -> PIL.Image
-     ├── _render_qm9(X_int, E_int) -> PIL.Image [RDKit]
-     ├── _render_comm20(X_int, E_int) -> PIL.Image [networkx + matplotlib Agg]
-     ├── _pil_to_b64(img) -> str
-     └── _frames_to_gif_b64(frames) -> str
129
  ```
130
 
131
  #### Model-type helpers
132
 
- ```python
- def _is_discrete(model):
-     from diffusion_model_discrete import DiscreteDenoisingDiffusion
-     return isinstance(model, DiscreteDenoisingDiffusion)
-
-
- def _build_node_mask(n_nodes, n_max, model):
-     """bool for discrete, float32 for continuous."""
-     arange = torch.arange(n_max, device=n_nodes.device).unsqueeze(0)
-     mask = arange < n_nodes.unsqueeze(1)  # (1, n_max) bool
-     return mask if _is_discrete(model) else mask.float()
-
-
- def _sample_initial_noise(model, n_max, node_mask):
-     """Discrete: sample_discrete_feature_noise. Continuous: sample_feature_noise."""
-     if _is_discrete(model):
-         return diffusion_utils.sample_discrete_feature_noise(
-             limit_dist=model.limit_dist, node_mask=node_mask)
-     else:
-         bs = node_mask.shape[0]
-         return diffusion_utils.sample_feature_noise(
-             X_size=(bs, n_max, model.Xdim_output),
-             E_size=(bs, n_max, n_max, model.Edim_output),
-             y_size=(bs, model.ydim_output),
-             node_mask=node_mask)
-
-
- def _denoising_step(model, s_t, t_t, X, E, y, node_mask):
-     """Run one denoising step.
-     Returns (X_soft, E_soft, y_soft, X_int, E_int) where X/E_int are collapsed integer tensors
-     suitable for rendering. X/E_soft are the continuous activations to feed into the next step.
-     """
-     if _is_discrete(model):
-         sampled_s, discrete_s = model.sample_p_zs_given_zt(s_t, t_t, X, E, y, node_mask)
-         return sampled_s.X, sampled_s.E, sampled_s.y, discrete_s.X, discrete_s.E
-     else:
-         z_s = model.sample_p_zs_given_zt(s=s_t, t=t_t, X_t=X, E_t=E, y_t=y, node_mask=node_mask)
-         unnorm = utils.unnormalize(
-             z_s.X, z_s.E, z_s.y,
-             model.norm_values, model.norm_biases, node_mask, collapse=True)
-         return z_s.X, z_s.E, z_s.y, unnorm.X, unnorm.E
-
-
- def _gibbs_aggregate(model, X):
-     """Aggregate ensemble: median for discrete, mean for continuous."""
-     if _is_discrete(model):
-         return torch.median(X, dim=1).values
-     else:
-         return torch.mean(X, dim=1)
-
-
- def _collapse_final(model, X, E, y, node_mask):
-     """Collapse continuous activations to integer indices for rendering.
-     Discrete: PlaceHolder.mask(collapse=True). Continuous: sample_discrete_graph_given_z0.
-     Returns (X_int, E_int) tensors.
-     """
-     if _is_discrete(model):
-         final = PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
-         return final.X, final.E
-     else:
-         final = model.sample_discrete_graph_given_z0(X, E, y, node_mask)
-         return final.X, final.E
- ```
196
 
197
  #### `run_standard_generation`
198
 
199
- Re-implements the denoising loop -- do NOT call `sample_batch`. Uses `s_idx / diffusion_steps` for normalized time (valid because `sample_p_zs_given_zt` takes normalized floats in [0,1] regardless of `model.T`):
200
 
201
  ```python
202
  def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id):
@@ -219,121 +268,31 @@ def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dat
219
  step = diffusion_steps - 1 - s_idx
220
  if step % frame_interval == 0 or s_idx == 0:
221
  gif_frames.append(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
 
222
 
223
  X_final, E_final = _collapse_final(model, X, E, y, node_mask)
224
- image_b64 = _pil_to_b64(render_graph(X_final[0, :n_max], E_final[0, :n_max, :n_max], dataset_id))
225
- return image_b64, _frames_to_gif_b64(gif_frames), elapsed_ms
226
  ```
227
 
228
- #### API parameter -> model attribute mapping
229
 
230
  | API param | Model attribute | Role |
231
  |---|---|---|
232
  | `t` | `model.gibbs_fixed_t_2` | Fixed noise level for the Gibbs chain; used by `apply_noise(gibbs=True)` and as `fixed_t_norm` in the inner denoising step |
233
- | `t_prime` | `model.gibbs_fixed_t_1` | Refinement target; `P = int((gibbs_fixed_t_2 − gibbs_fixed_t_1) × T) + 1` steps denoise from `t` down to `t_prime` |
234
  | `gibbs_chain_freq` | `model.gibbs_chain_freq` | Inner Gibbs steps per `/continue` call (same attribute name in checkpoint) |
235
  | `m` | `model.gibbs_M` | Ensemble size |
236
- | `n` | `model.gibbs_N` | Number of outer Gibbs rounds (full sweeps over M); session is complete when `step == n` |
237
 
238
- All parameters are always explicit in the API request -- no fallback to checkpoint values. If the user changes any parameter on the frontend, it clears the state and starts a fresh `generate` call. The `apply_noise(gibbs=True)` call reads `self.gibbs_fixed_t_2` directly, so we must set `model.gibbs_fixed_t_2 = t` before calling it (and restore afterwards). This is safe because `_inference_lock` ensures single-threaded access.
239
 
240
  #### `run_multiprox_init`
241
 
- Initialises the M-member ensemble; no Gibbs steps run yet. `gibbs_chain_freq` controls how many inner
- Gibbs iterations run per `/continue` call -- `⌈M / gibbs_chain_freq⌉` calls complete one outer round.
- Default `gibbs_chain_freq = max(1, m // 10)` (10% of ensemble size):
245
-
- ```python
- def run_multiprox_init(model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id):
-     device = next(model.parameters()).device
-     n_nodes = model.node_dist.sample_n(1, device) if num_nodes is None else torch.tensor([num_nodes], ...)
-     n_max = n_nodes.item()
-     node_mask = _build_node_mask(n_nodes, n_max, model)
-
-     # Sample M independent initial noise graphs
-     z_samples = [_sample_initial_noise(model, n_max, node_mask) for _ in range(m)]
-     X = torch.stack([z.X for z in z_samples], dim=1)  # (1, M, n_max, Xdim)
-     E = torch.stack([z.E for z in z_samples], dim=1)
-     y = torch.stack([z.y for z in z_samples], dim=1)
-
-     # Step 0 image: aggregate ensemble -> collapse -> render
-     agg_X = _gibbs_aggregate(model, X)  # (1, n_max, Xdim)
-     agg_E = _gibbs_aggregate(model, E)
-     agg_y = _gibbs_aggregate(model, y.float())
-     X_int, E_int = _collapse_final(model, agg_X, agg_E, agg_y, node_mask)
-     image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
-
-     state = {"X": X.cpu(), "E": E.cpu(), "y": y.cpu(), "n_nodes": n_nodes.cpu(),
-              "dataset_id": dataset_id, "model_type": None,  # filled by registry
-              "T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
-              "gibbs_chain_freq": gibbs_chain_freq, "inner_step": 0, "step": 0}
-     return state, image_b64, elapsed_ms
- ```
272
 
273
  #### `run_multiprox_step`
274
 
- Runs `gibbs_chain_freq` inner Gibbs iterations, advancing `inner_step`. When `inner_step` reaches `m`
- (one full outer round), runs the t→t_prime refinement and resets. On intermediate calls (round
- still in progress), renders the raw ensemble aggregate -- no refinement, faster response.
278
-
- ```python
- def run_multiprox_step(model, state_dict, dataset_id):
-     # Unpack X (1,M,n_max,Xdim), E, y, n_nodes, t, t_prime, T, n, m, gibbs_chain_freq, inner_step, step
-     device = next(model.parameters()).device
-     X, E, y = state_dict["X"].to(device), state_dict["E"].to(device), state_dict["y"].to(device)
-     n_nodes = state_dict["n_nodes"].to(device)
-     t, t_prime, T = state_dict["t"], state_dict["t_prime"], state_dict["T"]
-     n, m, gibbs_chain_freq = state_dict["n"], state_dict["m"], state_dict["gibbs_chain_freq"]
-     inner_step, step = state_dict["inner_step"], state_dict["step"]
-     n_max = X.shape[2]
-     node_mask = _build_node_mask(n_nodes, n_max, model)
-
-     fixed_t = t * torch.ones((1,1), dtype=torch.float, device=device)
-     fixed_s = fixed_t - (1.0 / T)
-
-     # How many inner Gibbs steps to run this call (may be less than gibbs_chain_freq at end of round)
-     steps_this_call = min(gibbs_chain_freq, m - inner_step)
-
-     with torch.no_grad():
-         for i in range(steps_this_call):
-             k = inner_step + i
-             avg_X = _gibbs_aggregate(model, X)  # (1, n_max, Xdim)
-             avg_E = _gibbs_aggregate(model, E)
-             avg_y = _gibbs_aggregate(model, y.float())
-             denoised_X, denoised_E, denoised_y, _, _ = _denoising_step(
-                 model, fixed_s, fixed_t, avg_X, avg_E, avg_y, node_mask)
-             old_t2 = model.gibbs_fixed_t_2
-             model.gibbs_fixed_t_2 = t  # override per-request (lock held by registry)
-             noisy = model.apply_noise(denoised_X, denoised_E, denoised_y, node_mask, gibbs=True)
-             model.gibbs_fixed_t_2 = old_t2
-             X[:, k], E[:, k], y[:, k] = noisy["X_t"], noisy["E_t"], noisy["y_t"]
-
-         new_inner_step = inner_step + steps_this_call
-         round_complete = new_inner_step >= m
-         if round_complete:
-             new_inner_step = 0
-             new_step = step + 1
-         else:
-             new_step = step
-         done = round_complete and new_step >= n
-
-         # Refinement pass runs on every call -- always produce a clean render.
-         # Uses the current ensemble aggregate (regardless of round_complete).
-         P = int((t - t_prime) * T) + 1
-         cur_X = _gibbs_aggregate(model, X)
-         cur_E = _gibbs_aggregate(model, E)
-         cur_y = _gibbs_aggregate(model, y.float())
-         for j in range(P):
-             s_ref = (t - (j + 1) / T) * torch.ones((1,1), dtype=torch.float, device=device)
-             t_ref = (t - j / T) * torch.ones((1,1), dtype=torch.float, device=device)
-             cur_X, cur_E, cur_y, _, _ = _denoising_step(
-                 model, s_ref, t_ref, cur_X, cur_E, cur_y, node_mask)
-         X_int, E_int = _collapse_final(model, cur_X, cur_E, cur_y, node_mask)
-
-     image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
-     updated_state = {**state_dict, "X": X.cpu(), "E": E.cpu(), "y": y.cpu(),
-                      "step": new_step, "inner_step": new_inner_step}
-     return updated_state, image_b64, round_complete, done, elapsed_ms
- ```
335
-
336
- Response includes `round_complete: bool` (full M sweep done) and `done: bool` (`n` outer rounds complete -- frontend should stop calling continue).
337
 
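The `inner_step` / `step` bookkeeping behind `round_complete` and `done` can be checked in isolation, without a model. A minimal sketch, assuming the counter logic described for `run_multiprox_step`; `advance_counters` and `calls_until_done` are hypothetical helper names used only here:

```python
def advance_counters(inner_step, step, m, n, gibbs_chain_freq):
    """Return (new_inner_step, new_step, round_complete, done) after one /continue call."""
    # The last call of a round may run fewer than gibbs_chain_freq steps.
    steps_this_call = min(gibbs_chain_freq, m - inner_step)
    new_inner_step = inner_step + steps_this_call
    round_complete = new_inner_step >= m
    if round_complete:
        new_inner_step = 0
        step += 1
    done = round_complete and step >= n
    return new_inner_step, step, round_complete, done


def calls_until_done(m, n, gibbs_chain_freq):
    """Count how many /continue calls a full session needs."""
    inner_step, step, calls, done = 0, 0, 0, False
    while not done:
        inner_step, step, _, done = advance_counters(
            inner_step, step, m, n, gibbs_chain_freq)
        calls += 1
    return calls
```

With `m=100`, `n=10` and `gibbs_chain_freq=10`, each outer round takes 10 calls, so a full session takes 100 `/continue` calls.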
338
  **State blob (encode/decode):**
339
 
@@ -342,167 +301,88 @@ STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
342
  REQUIRED_STATE_KEYS = {"X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
343
  "n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step"}
344
 
- # encode: torch.save(state_dict, BytesIO) -> base64 string
- # decode: base64 -> bytes -> torch.load(weights_only=True) -> validate keys + X.dim()==4, E.dim()==5
- # Raise ValueError on bad base64, oversized blob, missing keys, wrong tensor dims.
- # Registry converts ValueError -> InvalidRequestError(400).
349
  ```
350
 
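The validation order for the state blob (base64, then size, then keys) can be sketched with the standard library alone. This is an illustration under stated assumptions, not the planned implementation: `json` stands in for `torch.save` / `torch.load(weights_only=True)` so that the example runs without PyTorch, and the tensor-dimension checks are omitted.

```python
import base64
import json

STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024  # 10 MB, as in the plan
REQUIRED_STATE_KEYS = {"X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
                       "n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step"}


def encode_state_blob(state_dict):
    """Serialise the state to a base64 string (JSON stand-in for torch.save)."""
    raw = json.dumps(state_dict).encode("utf-8")
    return base64.b64encode(raw).decode("ascii")


def decode_state_blob(b64_str):
    """Decode and validate a state blob; raise ValueError on any problem."""
    try:
        raw = base64.b64decode(b64_str, validate=True)
    except ValueError as exc:  # binascii.Error is a ValueError subclass
        raise ValueError("state is not valid base64") from exc
    if len(raw) > STATE_BLOB_MAX_BYTES:
        raise ValueError("state blob exceeds size limit")
    try:
        state = json.loads(raw)
    except ValueError as exc:  # covers JSON and UTF-8 decoding errors
        raise ValueError("state blob is not decodable") from exc
    if not isinstance(state, dict):
        raise ValueError("state blob must decode to a dict")
    missing = REQUIRED_STATE_KEYS - state.keys()
    if missing:
        raise ValueError(f"state blob missing keys: {sorted(missing)}")
    return state
```

Raising a single exception type keeps the registry's conversion to `InvalidRequestError(400)` to one `except ValueError` clause.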
351
- **Visualization:**
352
-
- - `_render_qm9`: `RWMol`, `Atom(symbol)` from `QM9_ATOM_TYPES=["C","N","O","F"]`, `AddBond` with `{1:SINGLE, 2:DOUBLE, 3:TRIPLE, 4:AROMATIC}` (index 0 = no bond, skip), `MolToImage(size=(300,300))` -> PIL
- - `_render_comm20`: `nx.Graph` from adjacency (edge index > 0 = present), `nx.draw_networkx` with matplotlib Agg backend, `fig.savefig(BytesIO, "png")` -> PIL
- - `_pil_to_b64`: `img.save(BytesIO, "PNG")` -> `"data:image/png;base64," + b64encode`
- - `_frames_to_gif_b64`: `frames[0].save(BytesIO, "GIF", save_all=True, append_images=..., duration=150, loop=0)` -> `"data:image/gif;base64," + b64encode`
357
-
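Both `_pil_to_b64` and `_frames_to_gif_b64` end by prefixing base64 output with a `data:` URI header. A minimal sketch of that last step, operating on already-encoded image bytes rather than PIL images so that it runs with the standard library only; `to_data_uri` is a hypothetical helper name:

```python
import base64


def to_data_uri(raw, mime):
    """Wrap already-encoded image bytes in a base64 data URI."""
    return f"data:{mime};base64," + base64.b64encode(raw).decode("ascii")


# The 8-byte PNG signature stands in for a full image here.
png_uri = to_data_uri(b"\x89PNG\r\n\x1a\n", "image/png")
```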
358
  ---
359
 
360
  ### Step 3: `registry.py` additions
361
 
362
- **New instance variable** in `ModelRegistry.__init__`:
363
 
 
364
  ```python
365
- self._graphgen_models = {} # (dataset_id, model_type) -> loaded eval-mode model
 
 
366
  ```
367
 
368
- **`_load_graphgen_model(self, dataset_id, model_type)`** -- lazy load with caching:
369
-
- ```python
- # Defer imports inside method -- sys.path is set at app startup, not at Django module import time
- if model_type == "discrete":
-     from diffusion_model_discrete import DiscreteDenoisingDiffusion as cls
- else:
-     from diffusion_model import LiftedDenoisingDiffusion as cls
-
- suffix = "_c" if model_type == "continuous" else ""
- ckpt_path = Path(settings.MULTIPROXAN_DIR) / "checkpoints" / f"{dataset_id}{suffix}.ckpt"
-
- model = cls.load_from_checkpoint(
-     str(ckpt_path), map_location="cpu",
-     train_metrics=None, sampling_metrics=None, visualization_tools=None,
- )
- model.eval()
- self._graphgen_models[(dataset_id, model_type)] = model
- ```
387
 
388
- **Dataset loader / hyperparameter restoration notes:**
389
-
390
- `save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])` saves `dataset_infos`,
391
- `extra_features`, `domain_features`, and `visualization_tools` into the checkpoint. The three `None`
392
- overrides are sufficient -- the remaining hparams are restored from the checkpoint via pickle.
393
-
394
- **Why unpickling is safe without data files:** Python pickle reconstructs objects by restoring
395
- `__dict__` without running `__init__` and without any file I/O. The pre-computed tensors inside
396
- `dataset_infos` (`n_nodes`, `node_types`, `edge_types`, `nodes_dist`) all restore correctly. The
397
- stored datamodule reference is inert at inference time -- we never call `train_dataloader()` or any
398
- method that touches data files.
399
-
400
- **Why file paths resolve correctly on any deployment:** all datamodules compute their data root as:
401
- ```python
- base_path = pathlib.Path(os.path.realpath(__file__)).parents[2]  # -> MultiProxAn/
- root_path = os.path.join(base_path, self.datadir)  # -> MultiProxAn/data/qm9
404
- ```
405
- `__file__` is the source file on the current machine, so paths are always relative to the research
406
- code location. `self.datadir` (a relative string like `"data/qm9"`) pickles and restores as-is.
407
- No path patching is needed.
408
 
409
- `extra_features` and `domain_features` are lightweight stateless callables with no file I/O --
410
- they unpickle cleanly.
411
 
412
- **`graphgen_generate(self, dataset_id, model_type, sampling_mode, num_nodes, diffusion_steps, chain_frames, multiprox_params)`**:
413
-
- ```python
- # Acquire _inference_lock (non-blocking) -> raise InferenceBusy() if locked
- # Load model via _load_graphgen_model (may raise ModelUnavailable)
- # standard: run_standard_generation -> return {dataset_id, model_type, sampling_mode, image, chain_gif, inference_time_ms}
- # multiprox: run_multiprox_init -> state["model_type"] = model_type -> encode_state_blob
- #            -> return {step:0, image, state, inference_time_ms}
- # finally: release lock
- ```
422
-
423
- **`graphgen_continue(self, state_b64)`**:
424
-
- ```python
- # decode_state_blob OUTSIDE lock (fail-fast) -> raise InvalidRequestError on ValueError
- # Acquire _inference_lock -> raise InferenceBusy() if locked
- # Load model via _load_graphgen_model(state["dataset_id"], state["model_type"])
- # run_multiprox_step -> encode_state_blob -> return {step, image, state, inference_time_ms}
- # finally: release lock
- ```
432
 
433
  ---
434
 
435
  ### Step 4: `views/graph_generation.py` additions
436
 
437
- **`GraphGenGenerateView.post`** -- validation + dispatch:
438
  - Require `dataset_id`, `model_type`, `sampling_mode`
439
- - `dataset_id ∈ GRAPHGEN_DATASETS`, `model_type ∈ {"discrete","continuous"}`, `sampling_mode ∈ {"standard","multiprox"}`
440
- - Check `registry.graphgen_checkpoints_available.get(dataset_id, [])` contains `model_type` -> `ModelUnavailable(503)` if not
441
  - `diffusion_steps` clamped to [50, 1000], `chain_frames` to [10, 30]
442
  - `num_nodes`: optional int, validated against `GRAPHGEN_DATASETS[dataset_id]["max_nodes"]`
443
- - multiprox only: `multiprox_params` required; validate `0 <= t_prime <= t <= 1`, `2 <= m <= 100` (default: `100`), `n >= 1` (default: `10`), `1 <= gibbs_chain_freq <= m` (default: `max(1, m // 10)`)
444
- - Call `registry.graphgen_generate(...)` -> `Response(result)`
445
 
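The `multiprox_params` rules listed above can be expressed as one validation function. A minimal sketch: `validate_multiprox_params` is a hypothetical name, and treating `t` and `t_prime` as required (the plan lists no defaults for them) is an assumption.

```python
def validate_multiprox_params(params):
    """Apply the defaults and range checks for multiprox_params; raise ValueError if invalid."""
    try:
        t, t_prime = float(params["t"]), float(params["t_prime"])
    except KeyError as exc:
        raise ValueError(f"missing multiprox parameter: {exc.args[0]}") from exc
    m = int(params.get("m", 100))
    n = int(params.get("n", 10))
    # The default depends on m, so it is computed after m is resolved.
    gibbs_chain_freq = int(params.get("gibbs_chain_freq", max(1, m // 10)))
    if not 0 <= t_prime <= t <= 1:
        raise ValueError("require 0 <= t_prime <= t <= 1")
    if not 2 <= m <= 100:
        raise ValueError("require 2 <= m <= 100")
    if n < 1:
        raise ValueError("require n >= 1")
    if not 1 <= gibbs_chain_freq <= m:
        raise ValueError("require 1 <= gibbs_chain_freq <= m")
    return {"t": t, "t_prime": t_prime, "m": m, "n": n,
            "gibbs_chain_freq": gibbs_chain_freq}
```

In the view, a `ValueError` from this function would be translated into a 400 response, matching the `t_prime > t` error path in the verification list.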
446
- **`GraphGenContinueView.post`** -- minimal validation:
447
  - Require `state` (non-empty string)
448
- - Call `registry.graphgen_continue(state_b64)` -> `Response(result)`
 
 
449
 
450
  ---
451
 
452
- ### Step 5: `urls.py`
453
 
454
- ```python
455
- # Update import line 12:
456
- from api.views.graph_generation import (
457
- GraphGenDatasetsView, GraphGenSamplingModesView,
458
- GraphGenGenerateView, GraphGenContinueView,
459
- )
460
 
461
- # Add after "graph-generation/sampling-modes" route:
462
- path("graph-generation/generate", GraphGenGenerateView.as_view()),
463
- path("graph-generation/continue", GraphGenContinueView.as_view()),
464
- ```
465
 
466
  ---
467
 
468
- ### Step 6: `requirements.txt` -- new dependencies
469
-
470
- | Package | Why needed | Already present? |
471
- |---|---|---|
- | `Pillow` | PIL `Image`, `BytesIO` ops in `graphgen_inference.py` rendering | Transitive dep of `matplotlib` and `torchvision`, but not explicit -- add for clarity |
- | `overrides` | Pulled in by research code transitively | Present in research `requirements.txt`; not in backend -- add to be safe |
474
 
475
- `wandb` is **not** added -- it is stripped from the research code in Step 0. Pin `overrides==7.3.1` to match the research repo. Add under a `# MultiProxAn graph generation` comment.
476
-
477
- **Python 3.9 compatibility verdict:** No issues. Research code syntax requires Python ≥ 3.6 only (f-strings; no walrus operator, match statements, or `f"{var=}"` debug syntax). PyTorch 2.0.1 and pytorch-lightning 2.0.4 both officially support Python 3.8–3.11. `torch.load(weights_only=True)` in our state blob decoder is safe on torch 2.0.1 (default is `False`; `True` is supported for plain tensor dicts since torch 1.13). The research code's own `torch.load` calls (checkpoint files, dataset caches) do not use `weights_only` -- this is fine for our trusted local files on torch 2.0.1, which does not warn or break on the missing kwarg.
 
 
478
 
479
  ---
480
 
481
- ### Step 7: README + Postman
482
 
483
- **`README.md`**: Remove "not yet implemented" annotations from generate and continue rows in the endpoint table.
 
 
 
484
 
485
- **`docs/postman/collection.json`**: Add 12 new requests to the graph-generation folder, covering all combinations of dataset × model_type × sampling_mode × endpoint. `continue` applies to multiprox only.
486
 
487
- **Standard / generate (4 requests):**
- 1. `Standard QM9 Discrete` -- POST generate `{dataset_id:"qm9", model_type:"discrete", sampling_mode:"standard", diffusion_steps:50, chain_frames:10}`
- 2. `Standard QM9 Continuous` -- POST generate `{dataset_id:"qm9", model_type:"continuous", sampling_mode:"standard", diffusion_steps:50, chain_frames:10}`
- 3. `Standard comm20 Discrete` -- POST generate `{dataset_id:"comm20", model_type:"discrete", sampling_mode:"standard", diffusion_steps:500, chain_frames:10}`
- 4. `Standard comm20 Continuous` -- POST generate `{dataset_id:"comm20", model_type:"continuous", sampling_mode:"standard", diffusion_steps:500, chain_frames:10}`
492
 
493
- **MultiProx / generate (4 requests):**
- 5. `MultiProx QM9 Discrete Init` -- POST generate `{dataset_id:"qm9", model_type:"discrete", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
- 6. `MultiProx QM9 Continuous Init` -- POST generate `{dataset_id:"qm9", model_type:"continuous", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
- 7. `MultiProx comm20 Discrete Init` -- POST generate `{dataset_id:"comm20", model_type:"discrete", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
- 8. `MultiProx comm20 Continuous Init` -- POST generate `{dataset_id:"comm20", model_type:"continuous", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
498
 
- **MultiProx / continue (4 requests -- one per dataset × model_type):**
- 9. `MultiProx QM9 Discrete Continue` -- POST continue `{state:"{{graphgen_state_qm9_discrete}}"}`
- 10. `MultiProx QM9 Continuous Continue` -- POST continue `{state:"{{graphgen_state_qm9_continuous}}"}`
- 11. `MultiProx comm20 Discrete Continue` -- POST continue `{state:"{{graphgen_state_comm20_discrete}}"}`
- 12. `MultiProx comm20 Continuous Continue` -- POST continue `{state:"{{graphgen_state_comm20_continuous}}"}`
504
 
505
- Each `continue` request uses a distinct collection variable (set manually from the corresponding init response's `state` field).
 
 
506
 
507
  ---
508
 
@@ -510,16 +390,19 @@ Each `continue` request uses a distinct collection variable (set manually from t
510
 
511
  | Risk | Mitigation |
512
  |---|---|
513
- | `weights_only=True` in `torch.load` blocks primitive scalars | State dict contains only tensors + str/int/float; PyTorch 2.x whitelist covers these. If a version tightens further, fall back to `weights_only=False` with a comment explaining the controlled trust boundary (blob originates from our own server response). |
514
- | `load_from_checkpoint` raises on pickle restoration | Unlikely -- Python pickle reconstructs objects without file I/O and datamodule paths resolve via `__file__` relative to the research code. If it does occur, inspect `ckpt['hyper_parameters']` to identify which object fails to deserialize. |
515
  | `model.gibbs_fixed_t_2` attribute override | Only safe because `_inference_lock` ensures single-threaded inference. Save/restore pattern is used. |
516
- | matplotlib not thread-safe | Rendering only called inside `_inference_lock`, so effectively single-threaded. |
 
517
 
518
  ## Verification
519
 
- 1. Django shell: `from diffusion_model_discrete import DiscreteDenoisingDiffusion` -> no `ModuleNotFoundError`
- 2. `GET /graph-generation/datasets` still returns `available_model_types` (regression)
- 3. `POST /graph-generation/generate` (standard, comm20, `diffusion_steps=50, chain_frames=10`) -> `image` starts with `data:image/png;base64,`, `chain_gif` starts with `data:image/gif;base64,`
- 4. `POST /graph-generation/generate` (multiprox, comm20) -> `step=0`, `state` is non-empty base64 string
- 5. `POST /graph-generation/continue` with `state` from step 4 -> `step=1`, different `image`, new `state`
- 6. Error paths: unknown `dataset_id` -> 400, corrupted `state` -> 400, concurrent requests -> 429, `t_prime > t` -> 400
 
 
 
13
  | `src/research/MultiProxAn/src/diffusion_model_discrete.py` | Strip all `wandb` imports and calls |
14
  | `src/research/MultiProxAn/src/diffusion_model.py` | Strip all `wandb` imports and calls (incl. one unguarded call) |
15
  | `src/research/MultiProxAn/src/utils.py` | Remove `import wandb` and `setup_wandb()` function |
16
+ | `src/research/MultiProxAn/src/analysis/spectre_utils.py` | Wrap `graph_tool`, `pyemd`, `pygsp`, `dist_helper` imports in try/except |
17
+ | `src/research/MultiProxAn/src/metrics/molecular_metrics.py` | Guard optional metric imports with try/except |
18
+ | `src/research/MultiProxAn/src/metrics/molecular_metrics_discrete.py` | Guard optional metric imports with try/except |
19
+ | `src/research/MultiProxAn/src/metrics/train_metrics.py` | Guard optional metric imports with try/except |
20
  | `src/backend/research_api/settings.py` | Add `MultiProxAn/src/` to `sys.path` |
21
+ | `src/backend/api/services/graphgen_inference.py` | **New file** -- all inference logic + rendering |
22
+ | `src/backend/api/services/registry.py` | Add `_graphgen_models` cache, `_safe_load_lightning_checkpoint`, `graphgen_generate_stream`, `graphgen_continue_stream`, `force_release_inference_lock` |
23
+ | `src/backend/api/views/graph_generation.py` | Add `GraphGenGenerateView`, `GraphGenContinueView`, SSE streaming helpers |
24
+ | `src/backend/api/views/health.py` | Add inference lock status to health endpoint, add `ForceUnlockView` |
25
+ | `src/backend/api/urls.py` | Wire 3 new routes (generate, continue, debug/force-unlock) |
26
+ | `src/backend/requirements.txt` | Add `Pillow`, `overrides` |
27
+ | `src/backend/README.md` | Mark generate/continue as implemented, document SSE protocol |
28
+ | `docs/postman/collection.json` | Add 12 example requests, auto-chaining test scripts |
29
 
30
  ## Model Differences: Discrete vs Continuous
31
 
 
35
  |---|---|---|
36
  | **`node_mask` dtype** | `bool` | `float32` (`.float()` required) |
37
  | **Initial noise** | `sample_discrete_feature_noise(limit_dist=model.limit_dist, node_mask=node_mask)` | `sample_feature_noise(X_size=(1,n,Xdim), E_size=(1,n,n,Edim), y_size=(1,ydim), node_mask=node_mask)` |
38
+ | **`sample_p_zs_given_zt` returns** | `(sampled_s, discrete_sampled_s)` -- 2-tuple; `discrete_sampled_s.X/E` are already collapsed integers | single `z_s` PlaceHolder with continuous floats |
39
+ | **Chain frame render** | use `discrete_sampled_s.X/E` directly (`.long()` required -- see below) | `utils.unnormalize(z_s.X, z_s.E, z_s.y, model.norm_values, model.norm_biases, node_mask, collapse=True)` -> `.X/.E` are integers |
40
+ | **Final graph collapse** | `PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)` (`.long()` required) | `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` (runs another forward pass + unnormalize) |
41
  | **Gibbs ensemble aggregation** | `torch.median(X, dim=1).values` | `torch.mean(X, dim=1)` |
42
 
43
+ **Critical `.long()` fix**: The discrete model's `sample_p_zs_given_zt` calls `.type_as(y_t)` internally, which can cast collapsed integer indices to float. Both `_denoising_step` and `_collapse_final` must apply `.long()` on the discrete path to prevent `TypeError: list indices must be integers or slices, not float` in the rendering functions.
44
+
45
  The refinement loop in `run_multiprox_step` uses `sample_p_zs_given_zt` to update `cur_X/E/y` at each step; only the first element of the tuple (or the single return value) is needed there. `_collapse_final` is called once at the end for rendering.
46
 
47
  ## Research Code Reused (read-only)
48
 
49
  | Symbol | Location | How used |
50
  |---|---|---|
51
+ | `DiscreteDenoisingDiffusion` | `src/diffusion_model_discrete.py` | Loaded via `_safe_load_lightning_checkpoint` for `model_type=discrete` |
52
+ | `LiftedDenoisingDiffusion` | `src/diffusion_model.py` | Loaded via `_safe_load_lightning_checkpoint` for `model_type=continuous` |
53
  | `model.node_dist.sample_n(1, device)` | both models | Sample number of nodes |
54
+ | `diffusion_utils.sample_discrete_feature_noise(limit_dist, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise -- discrete only |
55
+ | `diffusion_utils.sample_feature_noise(X_size, E_size, y_size, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise -- continuous only |
56
+ | `model.sample_p_zs_given_zt(s, t, X, E, y, node_mask)` | both models | One denoising step (return varies by model type -- see table above) |
57
  | `model.apply_noise(X, E, y, node_mask, gibbs=True)` | both models | Re-apply Gibbs noise; uses `model.gibbs_fixed_t_2` internally |
58
+ | `PlaceHolder.mask(node_mask, collapse=True)` | `src/utils.py` | Final collapse -- discrete only |
59
+ | `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` | `diffusion_model.py` | Final collapse -- continuous only |
60
+ | `utils.unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=True)` | `src/utils.py` | Chain frame rendering -- continuous only |
61
+ | `model.norm_values`, `model.norm_biases` | `diffusion_model.py` | Unnormalization factors -- continuous only |
62
+
63
+ ## Design Decisions
64
+
65
+ ### Streaming Protocol: Server-Sent Events (SSE)
66
+
67
+ Both generate and continue endpoints use **SSE** (`text/event-stream`) instead of plain JSON or NDJSON. This enables real-time progress streaming with preview images in Postman's native SSE viewer.
68
+
69
+ **Event types:**
70
+ - `event: progress` -- metadata (phase, step, total_steps, elapsed_ms)
71
+ - `event: preview` -- base64 PNG image of the current graph state (separate event so Postman shows clean image updates)
72
+ - `event: result` -- final JSON payload (image, chain_gif, state, timing)
73
+
74
+ **SSE helpers** in `views/graph_generation.py`:
75
+ ```python
76
+ def _streaming_sse_response(gen):
77
+ resp = StreamingHttpResponse(_sse_iter(gen), content_type="text/event-stream")
78
+ resp["Cache-Control"] = "no-cache"
79
+ resp["X-Accel-Buffering"] = "no" # nginx
80
+ return resp
81
+
82
+ def _sse_iter(gen):
83
+ for event in gen:
84
+ etype = event.get("type", "message")
85
+ preview = event.pop("preview", None)
86
+ yield f"event: {etype}\ndata: {json.dumps(event, separators=(',', ':'))}\n\n"
87
+ if preview:
88
+ yield f"event: preview\ndata: {preview}\n\n"
89
+ ```
90
+
91
+ All three inference generators (`run_standard_generation`, `run_multiprox_init`, `run_multiprox_step`) yield progress dicts with optional `preview` keys, then a final `result` dict.
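
To make the wire format concrete, here is a minimal standalone sketch of the same splitting logic with Django removed. The name `sse_frames` and the sample events are illustrative assumptions, not code from the repository; the sketch also copies each dict before popping `preview`, which the helper above does not do.

```python
import json
from typing import Dict, Iterable, Iterator


def sse_frames(events: Iterable[Dict]) -> Iterator[str]:
    """Format event dicts as SSE frames, splitting `preview` into its own event."""
    for event in events:
        event = dict(event)  # copy so the caller's dict is not mutated
        etype = event.get("type", "message")
        preview = event.pop("preview", None)
        # Compact JSON keeps each frame on a single `data:` line.
        yield f"event: {etype}\ndata: {json.dumps(event, separators=(',', ':'))}\n\n"
        if preview:
            yield f"event: preview\ndata: {preview}\n\n"


# Hypothetical events, shaped like the progress/result dicts described above.
sample_events = [
    {"type": "progress", "phase": "denoise", "step": 1, "total_steps": 2,
     "preview": "data:image/png;base64,AAAA"},
    {"type": "result", "image": "data:image/png;base64,BBBB"},
]
frames = list(sse_frames(sample_events))
```

Each frame ends with a blank line, which is how an SSE client recognises the end of one event.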
92
+
93
+ ### Safe Checkpoint Loading: `_safe_load_lightning_checkpoint`
94
+
95
+ `load_from_checkpoint` must **not** be used for MultiProxAn models. The comm20 checkpoints were trained with DDP (DistributedDataParallel), causing three cascading failures:
96
+
97
+ 1. **DDP `__setstate__`/`__getstate__`** -- unpickling checkpoint hparams tries to reconstruct DDP wrappers, which requires an active process group. Fix: monkey-patch `__setstate__` and `__getstate__` during `torch.load`.
98
+ 2. **`save_hyperparameters` deepcopy crash** -- Lightning's `save_hyperparameters` deepcopies hparams containing DDP-wrapped datamodule objects, causing a hard process crash (exit code -1). Fix: patch `save_hyperparameters` to a no-op during model construction.
99
+ 3. **CUDA OOM** -- loading the full checkpoint (including its heavy pickled hparams objects) directly onto the GPU can exhaust CUDA memory. Fix: load with `map_location="cpu"` first, then call `model.to(device)` after the state dict is loaded.
100
+
101
+ The solution in `registry.py`:
102
+ ```python
103
+ def _safe_load_lightning_checkpoint(cls, ckpt_path):
104
+ # 1. Patch DDP setstate/getstate to simple dict update/return
105
+ # 2. torch.load(ckpt_path, map_location="cpu", weights_only=False)
106
+ # 3. Extract hyper_parameters, null out train_metrics/sampling_metrics/visualization_tools
107
+ # 4. Patch save_hyperparameters to no-op
108
+ # 5. Construct model via cls(**hparams)
109
+ # 6. load_state_dict(strict=False), del ckpt, model.to(device), model.eval()
110
+ ```
111
+
112
+ ### Graph Rendering: PIL + networkx (no matplotlib)
113
+
114
+ `_render_comm20` uses **PIL + networkx `spring_layout`** instead of matplotlib. Matplotlib's GUI backend initializes in the calling thread and crashes with exit code -1 when called from Django's request threads on Windows. Even with `matplotlib.use("Agg")`, the backend initialization races with Django's threading model.
115
+
116
+ ```python
117
+ def _render_comm20(X_int, E_int):
118
+ import networkx as nx
119
+ from PIL import Image, ImageDraw
120
+ # Build networkx Graph from adjacency matrix
121
+ # spring_layout(G, seed=42) for deterministic positions
122
+ # PIL drawing: lines for edges, ellipses for nodes (#2ecc71 green, #1a7a42 outline)
123
+ # Returns PIL.Image (300x300)
124
+ ```
125
+
126
+ QM9 rendering uses RDKit's `MolToImage` (no matplotlib dependency).
127
+
128
+ ### Inference Lock Management
129
+
130
+ A single `threading.Lock` (`_inference_lock`) protects all inference endpoints (COINs predict, graphgen generate, graphgen continue). Design choices:
131
+
132
+ - **Non-blocking acquire** -- returns 429 `INFERENCE_BUSY` immediately if locked
133
+ - **Owner tracking** -- `_inference_lock_owner` stores a description string (e.g. `"graphgen_generate comm20/discrete/standard"`) for debugging
134
+ - **Lock held by generator** -- the lock is acquired eagerly (before the generator starts) and released in the generator's `finally` block, so it is held for the entire streaming duration
135
+ - **Force-release endpoint** -- `POST /debug/force-unlock` (DEBUG mode only) releases locks left stuck by killed clients. The lock status is also exposed in the health endpoint's `inference_lock` field
136
+
137
+ All three lock acquisition sites follow the same pattern:
138
+ ```python
139
+ if not self._inference_lock.acquire(blocking=False):
140
+ raise InferenceBusy()
141
+ self._inference_lock_owner = f"<endpoint> <context>"
142
+ try:
143
+ model = self._load_<model>(...)
144
+ except Exception:
145
+ self._inference_lock_owner = None
146
+ self._inference_lock.release()
147
+ raise
148
+
149
+ def _gen():
150
+ try:
151
+ ...yield events...
152
+ finally:
153
+ self._inference_lock_owner = None
154
+ self._inference_lock.release()
155
+ return _gen()
156
+ ```
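
The pattern above can be exercised in isolation. The following is a minimal runnable sketch, assuming a stand-in class `LockDemo` (hypothetical, not the real `ModelRegistry`) with the model-loading step omitted; it shows that a second request is rejected while a stream is open and that closing a started generator releases the lock.

```python
import threading


class InferenceBusy(Exception):
    """Raised when another inference request already holds the lock."""


class LockDemo:
    """Hypothetical stand-in for the registry; only the locking logic is kept."""

    def __init__(self):
        self._inference_lock = threading.Lock()
        self._inference_lock_owner = None

    def stream(self, owner, n_events=3):
        # Eager, non-blocking acquire: fail fast instead of queueing.
        if not self._inference_lock.acquire(blocking=False):
            raise InferenceBusy()
        self._inference_lock_owner = owner

        def _gen():
            try:
                for i in range(n_events):
                    yield {"type": "progress", "step": i + 1}
            finally:
                # Runs on normal completion and on close() of a started generator.
                self._inference_lock_owner = None
                self._inference_lock.release()

        return _gen()


demo = LockDemo()
gen = demo.stream("demo request")
first = next(gen)  # the generator is now suspended inside the try block
try:
    demo.stream("second request")
    busy = False
except InferenceBusy:
    busy = True
gen.close()  # simulates a dropped client: GeneratorExit triggers the finally block
released = not demo._inference_lock.locked()
```

One caveat worth keeping in mind: a generator that is closed before its first `next()` never enters the `try` block, so in that case the lock would remain held until it is released by other means.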
157
+
158
+ ### Research Code Import Guards
159
+
160
+ Several research code modules import heavy optional dependencies (`graph_tool`, `pyemd`, `pygsp`, `dist_helper`) at module level. These are needed only for metric computation during training, not inference. Since they are transitively imported during checkpoint unpickling (via pickled datamodule/metric objects), they must be guarded with `try/except ImportError` to avoid crashing on systems where these packages are not installed.
161
+
162
+ Files to patch: `spectre_utils.py`, `molecular_metrics.py`, `molecular_metrics_discrete.py`, `train_metrics.py`.
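
A guard of this kind can be sketched as follows. The helpers `optional_import` and `require` are hypothetical names introduced for illustration; the actual patch may simply wrap each import statement in `try/except ImportError` inline.

```python
import importlib


def optional_import(name):
    """Return the imported module, or None if it is not installed."""
    try:
        return importlib.import_module(name)
    except ImportError:  # also covers ModuleNotFoundError
        return None


def require(module, name):
    """Fail late, and with a clear message, only when a metric actually needs the module."""
    if module is None:
        raise RuntimeError(f"{name} is required for this metric but is not installed")
    return module


# Module-level guard: importing this file no longer crashes when graph_tool is absent.
graph_tool = optional_import("graph_tool")
```

Inference code paths never call `require`, so they run without the optional packages; training-time metric code would call it at the point of use.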
163
 
164
  ## Implementation
165
 
166
+ ### Step 0: Patch research code -- remove `wandb`
167
 
168
+ `wandb` is imported at module level in three files; the import alone causes `ImportError` at checkpoint load time. One call in `diffusion_model.py` is **unguarded** (no `if wandb.run:` check) and would crash at inference even if the import somehow succeeded.
 
 
169
 
170
  **`src/research/MultiProxAn/src/diffusion_model_discrete.py`**:
171
  - Remove `import wandb` (line 9)
172
+ - Remove the two `utils.setup_wandb(self.cfg)` call sites
173
+ - Remove all `if wandb.run: wandb.log(...)` / `wandb.run.summary[...]` blocks
174
 
175
  **`src/research/MultiProxAn/src/diffusion_model.py`**:
176
  - Remove `import wandb` (line 9)
177
  - Remove the two `utils.setup_wandb(self.cfg)` call sites
178
+ - Remove all guarded `if wandb.run:` blocks + one unguarded `wandb.log(...)` call
 
179
 
180
  **`src/research/MultiProxAn/src/utils.py`**:
181
+ - Remove `import wandb` and the entire `setup_wandb(cfg)` function
 
 
 
182
 
183
  ---
184
 
185
+ ### Step 0b: Guard heavy optional imports in research code
186
+
187
+ Wrap `graph_tool`, `pyemd`, `pygsp`, and `dist_helper` imports in `try/except` in:
188
+ - `spectre_utils.py` -- all four imports
189
+ - `molecular_metrics.py`, `molecular_metrics_discrete.py`, `train_metrics.py` -- cascading metric deps
190
 
191
+ ---
192
+
193
+ ### Step 1: `settings.py` -- sys.path fix
194
 
195
+ Both `MultiProxAn/` and `MultiProxAn/src/` must be on `sys.path`:
196
  ```python
197
  _MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
198
  for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
 
204
 
205
  ### Step 2: New `graphgen_inference.py`
206
 
207
+ Completely independent of Django/registry. Receives a loaded model and yields progress/result dicts.
208
 
209
  **Module structure:**
210
 
211
  ```
212
  graphgen_inference.py
213
+ +-- Constants: QM9_ATOM_TYPES, STATE_BLOB_MAX_BYTES, REQUIRED_STATE_KEYS
214
+ |
215
+ +-- # Model-type helpers
216
+ +-- _is_discrete(model) -> bool
217
+ +-- _build_node_mask(n_nodes, n_max, model) -> bool or float32 tensor
218
+ +-- _sample_initial_noise(model, n_max, node_mask) -> PlaceHolder
219
+ +-- _denoising_step(model, s_t, t_t, X, E, y, node_mask) -> (X_soft, E_soft, y_soft, X_int, E_int)
220
+ +-- _gibbs_aggregate(model, X) -> tensor [median (discrete) or mean (continuous)]
221
+ +-- _collapse_final(model, X, E, y, node_mask) -> (X_int, E_int)
222
+ |
223
+ +-- # Main inference generators (yield progress dicts, then result dict)
224
+ +-- run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id)
225
+ +-- run_multiprox_init(model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id)
226
+ +-- run_multiprox_step(model, state_dict, dataset_id)
227
+ |
228
+ +-- # State serialisation
229
+ +-- encode_state_blob(state_dict) -> str
230
+ +-- decode_state_blob(b64_str) -> dict
231
+ |
232
+ +-- # Visualisation
233
+ +-- render_graph(X_int, E_int, dataset_id) -> PIL.Image
234
+ +-- _render_qm9(X_int, E_int) -> PIL.Image [RDKit]
235
+ +-- _render_comm20(X_int, E_int) -> PIL.Image [PIL + networkx spring_layout]
236
+ +-- _pil_to_b64(img) -> str
237
+ +-- _frames_to_gif_b64(frames) -> str
238
  ```
239
 
240
  #### Model-type helpers
241
 
242
+ `_denoising_step` returns `(X_soft, E_soft, y_soft, X_int, E_int)`. For discrete models, `.long()` is applied to `discrete_s.X` and `discrete_s.E` to counteract the `.type_as(y_t)` float cast. For continuous models, `utils.unnormalize` with `collapse=True` produces the integer tensors.
243
+
244
+ `_collapse_final` also applies `.long()` on the discrete path for the same reason.
 
245
 
246
  #### `run_standard_generation`
247
 
248
+ Re-implements the denoising loop -- does NOT call `sample_batch`. Uses `s_idx / diffusion_steps` for normalized time. Yields `progress` events at each step (with `preview` at frame intervals) and a final `result` event with `image` and `chain_gif`.
249
 
250
  ```python
251
  def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id):
 
268
  step = diffusion_steps - 1 - s_idx
269
  if step % frame_interval == 0 or s_idx == 0:
270
  gif_frames.append(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
271
+ yield {"type": "progress", "phase": "denoise", "step": step + 1, ...}
272
 
273
  X_final, E_final = _collapse_final(model, X, E, y, node_mask)
274
+ yield {"type": "result", "image": ..., "chain_gif": ..., "inference_time_ms": ...}
 
275
  ```
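
As a rough illustration of the frame selection in the loop above, the sketch below lists which steps would contribute a GIF frame. The definition `frame_interval = max(1, diffusion_steps // chain_frames)` and the descending `s_idx` loop are assumptions, since those lines are not shown in this excerpt; `frame_steps` is a hypothetical helper name.

```python
def frame_steps(diffusion_steps, chain_frames):
    """Return the step indices at which a chain frame would be captured."""
    frame_interval = max(1, diffusion_steps // chain_frames)  # assumed definition
    captured = []
    for s_idx in reversed(range(diffusion_steps)):  # assumed loop direction
        step = diffusion_steps - 1 - s_idx
        # Same condition as above: every frame_interval-th step, plus the last one.
        if step % frame_interval == 0 or s_idx == 0:
            captured.append(step)
    return captured


steps = frame_steps(diffusion_steps=50, chain_frames=10)
```

Under these assumptions the final denoising step is always captured, so the GIF can contain one more frame than `chain_frames`.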
276
 
277
+ #### API parameter -> model attribute mapping
278
 
279
  | API param | Model attribute | Role |
280
  |---|---|---|
281
  | `t` | `model.gibbs_fixed_t_2` | Fixed noise level for the Gibbs chain; used by `apply_noise(gibbs=True)` and as `fixed_t_norm` in the inner denoising step |
282
+ | `t_prime` | `model.gibbs_fixed_t_1` | Refinement target; `P = int((gibbs_fixed_t_2 - gibbs_fixed_t_1) * T) + 1` steps denoise from `t` down to `t_prime` |
283
  | `gibbs_chain_freq` | `model.gibbs_chain_freq` | Inner Gibbs steps per `/continue` call (same attribute name in checkpoint) |
284
  | `m` | `model.gibbs_M` | Ensemble size |
285
+ | `n` | `model.gibbs_N` | Number of outer Gibbs rounds; session is complete when `step == n` |
286
 
287
+ All parameters are always explicit in the API request -- no fallback to checkpoint values. The `apply_noise(gibbs=True)` call reads `self.gibbs_fixed_t_2` directly, so we must set `model.gibbs_fixed_t_2 = t` before calling it (and restore afterwards). This is safe because `_inference_lock` ensures single-threaded access.
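
The save/restore step and the refinement step count can be sketched as below. `override_attr` and `refinement_steps` are hypothetical helper names, and the `SimpleNamespace` object stands in for a loaded model; `T = 500` is an assumption taken from the `diffusion_steps` value used in the Postman examples.

```python
from contextlib import contextmanager
from types import SimpleNamespace


def refinement_steps(t, t_prime, T):
    """P = int((t - t_prime) * T) + 1, as given in the mapping table."""
    return int((t - t_prime) * T) + 1


@contextmanager
def override_attr(obj, name, value):
    """Temporarily set obj.<name> = value and restore the old value on exit."""
    old = getattr(obj, name)
    setattr(obj, name, value)
    try:
        yield obj
    finally:
        setattr(obj, name, old)


model = SimpleNamespace(gibbs_fixed_t_2=0.9)  # stand-in for a checkpoint value
with override_attr(model, "gibbs_fixed_t_2", 0.4):
    inside = model.gibbs_fixed_t_2
    # model.apply_noise(X, E, y, node_mask, gibbs=True) would be called here
steps = refinement_steps(t=0.4, t_prime=0.1, T=500)
```

With the comm20 values `t=0.4`, `t_prime=0.1` and the assumed `T=500`, the formula gives 151 refinement steps. The restore happens in `finally`, so the checkpoint value comes back even if the noise call raises.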
288
 
289
  #### `run_multiprox_init`
290
 
291
+ Initialises the M-member ensemble. Yields `progress` events during noise sampling, then a `result` event with the step-0 image and state dict. Default `gibbs_chain_freq = max(1, m // 10)`.
 
292
 
293
  #### `run_multiprox_step`
294
 
295
+ Runs `gibbs_chain_freq` inner Gibbs iterations. Yields `progress` events during Gibbs phase (with preview of ensemble aggregate) and refinement phase (t -> t_prime denoising). The refinement pass always runs to produce a clean render. Returns `round_complete` and `done` flags.
 
296
 
297
  **State blob (encode/decode):**
298
 
 
301
  REQUIRED_STATE_KEYS = {"X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
302
  "n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step"}
303
 
304
+ # encode: torch.save(state_dict, BytesIO) -> base64 string
305
+ # decode: base64 -> bytes -> torch.load(weights_only=False) -> validate keys + X.dim()==4, E.dim()==5
 
 
306
  ```
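
The checks that can run before any deserialisation are sketched below using only the standard library. The `torch.load` step is deliberately left out, `decode_blob_bytes` and `missing_state_keys` are hypothetical helper names, and the exact byte value of the 10 MB limit is an assumption.

```python
import base64
import binascii

STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024  # assumed exact value of the 10 MB limit
REQUIRED_STATE_KEYS = {"X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
                       "n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step"}


def decode_blob_bytes(b64_str):
    """Reject oversized or malformed base64 before anything is unpickled."""
    # base64 expands data by 4/3, so the encoded length can be bounded first.
    if len(b64_str) > (STATE_BLOB_MAX_BYTES * 4) // 3 + 4:
        raise ValueError("state blob too large")
    try:
        raw = base64.b64decode(b64_str, validate=True)
    except (binascii.Error, ValueError) as exc:
        raise ValueError("state blob is not valid base64") from exc
    if len(raw) > STATE_BLOB_MAX_BYTES:
        raise ValueError("state blob too large")
    return raw


def missing_state_keys(state):
    """Return the required keys that are absent from a decoded state dict."""
    return sorted(REQUIRED_STATE_KEYS - set(state))
```

A view could map either `ValueError` or a non-empty `missing_state_keys` result to the 400 response listed under Verification.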
307
 
 
 
 
 
 
 
 
308
  ---
309
 
310
  ### Step 3: `registry.py` additions
311
 
312
+ **`_safe_load_lightning_checkpoint(cls, ckpt_path)`** -- top-level function, used by `_load_graphgen_model`.
313
 
314
+ **New instance variables** in `ModelRegistry.__init__`:
315
  ```python
316
+ self._inference_lock = threading.Lock()
317
+ self._inference_lock_owner = None # description string for debugging
318
+ self._graphgen_models = {} # (dataset_id, model_type) -> loaded eval-mode model
319
  ```
320
 
321
+ **`force_release_inference_lock(self)`** -- emergency release for stuck locks.
 
322
 
323
+ **`_load_graphgen_model(self, dataset_id, model_type)`** -- lazy load with caching. Uses `_safe_load_lightning_checkpoint` instead of `load_from_checkpoint`.
 
324
 
325
+ **`graphgen_generate_stream(self, ...)`** -- acquires lock, loads model, returns a generator. The generator yields SSE-compatible dicts and releases the lock in `finally`. Encodes state blob for multiprox mode.
 
326
 
327
+ **`graphgen_continue_stream(self, state_b64)`** -- decodes state blob eagerly (fail-fast before lock), acquires lock, loads model, returns generator.
 
328
 
329
  ---
330
 
331
  ### Step 4: `views/graph_generation.py` additions
332
 
333
+ **`GraphGenGenerateView.post`** -- validation + dispatch:
334
  - Require `dataset_id`, `model_type`, `sampling_mode`
335
+ - `dataset_id` in `GRAPHGEN_DATASETS`, `model_type` in `{"discrete","continuous"}`, `sampling_mode` in `{"standard","multiprox"}`
336
+ - Check `registry.graphgen_checkpoints_available` contains `model_type` -> `ModelUnavailable(503)` if not
337
  - `diffusion_steps` clamped to [50, 1000], `chain_frames` to [10, 30]
338
  - `num_nodes`: optional int, validated against `GRAPHGEN_DATASETS[dataset_id]["max_nodes"]`
339
+ - multiprox only: `multiprox_params` required; validate `0 < t_prime <= t <= 1`, `2 <= m <= 100`, `n >= 1`, `1 <= gibbs_chain_freq <= m`
340
+ - Return `_streaming_sse_response(gen)`
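
The numeric rules in the list above could be checked along these lines. This is a sketch, not the view's actual code: `clamp` and `validate_multiprox_params` are hypothetical names, and treating `gibbs_chain_freq` as optional with the default `max(1, m // 10)` is an assumption based on the `run_multiprox_init` description.

```python
def clamp(value, lo, hi):
    """Clamp value into [lo, hi], e.g. diffusion_steps into [50, 1000]."""
    return max(lo, min(hi, value))


def validate_multiprox_params(params):
    """Return a list of error strings; an empty list means the params are valid."""
    errors = [f"missing multiprox_params.{k}"
              for k in ("n", "m", "t", "t_prime") if k not in params]
    if errors:
        return errors
    n, m, t, t_prime = params["n"], params["m"], params["t"], params["t_prime"]
    freq = params.get("gibbs_chain_freq", max(1, m // 10))  # assumed default
    if not (0 < t_prime <= t <= 1):
        errors.append("require 0 < t_prime <= t <= 1")
    if not (2 <= m <= 100):
        errors.append("require 2 <= m <= 100")
    if n < 1:
        errors.append("require n >= 1")
    if not (1 <= freq <= m):
        errors.append("require 1 <= gibbs_chain_freq <= m")
    return errors


ok = validate_multiprox_params(
    {"n": 10, "m": 100, "t": 0.4, "t_prime": 0.1, "gibbs_chain_freq": 10})
bad = validate_multiprox_params({"n": 10, "m": 100, "t": 0.4, "t_prime": 0.5})
```

Collecting all errors instead of stopping at the first lets a single 400 response report every invalid field.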
341
 
342
+ **`GraphGenContinueView.post`** -- minimal validation:
343
  - Require `state` (non-empty string)
344
+ - Return `_streaming_sse_response(gen)`
345
+
346
+ **SSE helpers** (`_streaming_sse_response`, `_sse_iter`) -- convert generator dicts to SSE event stream with anti-buffering headers. Preview images split into separate `event: preview` events.
347
 
348
  ---
349
 
350
+ ### Step 5: `views/health.py` additions
351
 
352
+ **`HealthView`** -- include `inference_lock` status (`locked`, `owner`) in response.
 
 
 
 
 
353
 
354
+ **`ForceUnlockView`** -- DEBUG-only POST endpoint at `debug/force-unlock`.
 
 
 
355
 
356
  ---
357
 
358
+ ### Step 6: `urls.py`
 
 
 
 
 
359
 
360
+ ```python
361
+ path("graph-generation/generate", GraphGenGenerateView.as_view()),
362
+ path("graph-generation/continue", GraphGenContinueView.as_view()),
363
+ path("debug/force-unlock", ForceUnlockView.as_view()),
364
+ ```
365
 
366
  ---
367
 
368
+ ### Step 7: `requirements.txt` -- new dependencies
369
 
370
+ | Package | Why needed |
371
+ |---|---|
372
+ | `Pillow` | PIL rendering in `graphgen_inference.py` |
373
+ | `overrides` | Transitive research code dependency |
374
 
375
+ `wandb` is **not** added -- stripped from research code.
376
 
377
+ ---
 
 
 
 
378
 
379
+ ### Step 8: README + Postman
 
 
 
 
380
 
381
+ **`README.md`**: Document SSE streaming protocol with three event types. Add force-unlock endpoint to endpoint table. Update health endpoint description to include lock status.
 
 
 
 
382
 
383
+ **`docs/postman/collection.json`**:
384
+ - 12 graph-generation requests (4 standard generate, 4 multiprox init, 4 continue variants consolidated into 1 using `{{multiprox_state}}`)
385
+ - **Auto-chaining test scripts**: all 4 multiprox init requests have post-response scripts that parse the SSE `result` event and save the `state` field to the `multiprox_state` collection variable. The consolidated continue endpoint reads `{{multiprox_state}}` and also updates it from its result, enabling repeated continue calls.
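
For clients other than Postman, the same extraction can be sketched in Python. `extract_result_state` is a hypothetical helper that mirrors the test script's logic: find the `event: result` line, strip the `data: ` prefix from the next line, and read `state` from the JSON payload.

```python
import json


def extract_result_state(sse_body):
    """Return the `state` field of the first `event: result` frame, or None."""
    lines = sse_body.split("\n")
    for i, line in enumerate(lines):
        if line.strip() == "event: result" and i + 1 < len(lines):
            data_line = lines[i + 1]
            if data_line.startswith("data: "):
                data_line = data_line[len("data: "):]
            try:
                result = json.loads(data_line)
            except json.JSONDecodeError:
                return None
            return result.get("state") if isinstance(result, dict) else None
    return None


# Hypothetical response body with one progress frame and one result frame.
body = (
    'event: progress\ndata: {"type":"progress","step":1}\n\n'
    'event: result\ndata: {"type":"result","step":0,"state":"QUJD"}\n\n'
)
state = extract_result_state(body)
```

The returned string can be sent unchanged as the `state` field of the next `/graph-generation/continue` request.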
386
 
387
  ---
388
 
 
390
 
391
  | Risk | Mitigation |
392
  |---|---|
393
+ | `weights_only=False` in `torch.load` for state blobs | State dict originates from our own server response. The blob is opaque base64 from the client but was generated server-side via `encode_state_blob`. Size limit (10 MB) and key validation provide basic bounds checking. |
394
+ | `_safe_load_lightning_checkpoint` bypasses Lightning's loading | Manual load + `load_state_dict(strict=False)` is correct for inference. Missing keys (if any) default to initialized values. |
395
  | `model.gibbs_fixed_t_2` attribute override | Only safe because `_inference_lock` ensures single-threaded inference. Save/restore pattern is used. |
396
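+ | SSE connection dropped mid-stream | The generator's `finally` block releases the lock. When a started generator is closed or garbage-collected before completing, Python raises `GeneratorExit` at the suspended `yield`, so the `finally` block still runs. A generator that is discarded before its first `next()` never enters the `try`; the DEBUG-only force-unlock endpoint is the fallback for that case. |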
397
+ | Research code imports heavy optional deps at module level | Guarded with `try/except ImportError` in spectre_utils.py, molecular_metrics*.py, train_metrics.py. |
398
 
399
  ## Verification
400
 
401
+ 1. Django shell: `from diffusion_model_discrete import DiscreteDenoisingDiffusion` -- no `ModuleNotFoundError`
402
+ 2. `GET /graph-generation/datasets` returns `available_model_types` (regression check)
403
+ 3. All 4 standard generate combinations (qm9/comm20 x discrete/continuous) return SSE streams with progress, preview, and result events
404
+ 4. All 4 multiprox init combinations return step-0 image and state blob
405
+ 5. Continue endpoint chains correctly -- state updates across init -> continue x 3
406
+ 6. Force-unlock endpoint releases stuck locks in DEBUG mode
407
+ 7. Health endpoint shows lock status and owner
408
+ 8. Error paths: unknown `dataset_id` -> 400, corrupted `state` -> 400, concurrent requests -> 429, `t_prime > t` -> 400
docs/postman/collection.json CHANGED
@@ -8,6 +8,10 @@
8
  {
9
  "key": "base_url",
10
  "value": "http://localhost:8000/api/v1"
 
 
 
 
11
  }
12
  ],
13
  "item": [
@@ -52,6 +56,24 @@
52
  },
53
  "description": "List the 3 research methods."
54
  }
 
55
  }
56
  ]
57
  },
@@ -362,60 +384,201 @@
362
  "name": "Graph Generation - Inference",
363
  "item": [
364
  {
365
- "name": "POST /graph-generation/generate (standard)",
366
  "request": {
367
  "method": "POST",
368
- "header": [
369
- { "key": "Content-Type", "value": "application/json" }
370
- ],
371
  "body": {
372
  "mode": "raw",
373
- "raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
374
  },
375
  "url": {
376
  "raw": "{{base_url}}/graph-generation/generate",
377
  "host": ["{{base_url}}"],
378
  "path": ["graph-generation", "generate"]
379
  },
380
- "description": "Standard denoising graph generation. Returns final image + chain GIF."
381
  }
382
  },
383
  {
384
- "name": "POST /graph-generation/generate (multiprox)",
385
  "request": {
386
  "method": "POST",
387
- "header": [
388
- { "key": "Content-Type", "value": "application/json" }
389
- ],
 
390
  "body": {
391
  "mode": "raw",
392
- "raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"m\": 10,\n \"t\": 0.5,\n \"t_prime\": 0.1\n }\n}"
393
  },
394
  "url": {
395
  "raw": "{{base_url}}/graph-generation/generate",
396
  "host": ["{{base_url}}"],
397
  "path": ["graph-generation", "generate"]
398
  },
399
- "description": "MultiProx generation. Returns step 0 + state for continue calls."
 
400
  }
401
  },
402
  {
403
  "name": "POST /graph-generation/continue",
 
 
 
 
 
 
 
 
 
404
  "request": {
405
  "method": "POST",
406
- "header": [
407
- { "key": "Content-Type", "value": "application/json" }
408
- ],
409
  "body": {
410
  "mode": "raw",
411
- "raw": "{\n \"state\": \"<paste state from previous response>\"\n}"
412
  },
413
  "url": {
414
  "raw": "{{base_url}}/graph-generation/continue",
415
  "host": ["{{base_url}}"],
416
  "path": ["graph-generation", "continue"]
417
  },
418
- "description": "Advance MultiProx chain by one step. Paste the state from the previous response."
419
  }
420
  }
421
  ]
 
8
  {
9
  "key": "base_url",
10
  "value": "http://localhost:8000/api/v1"
11
+ },
12
+ {
13
+ "key": "multiprox_state",
14
+ "value": ""
15
  }
16
  ],
17
  "item": [
 
56
  },
57
  "description": "List the 3 research methods."
58
  }
59
+ },
60
+ {
61
+ "name": "POST /debug/force-unlock",
62
+ "request": {
63
+ "method": "POST",
64
+ "header": [
65
+ {
66
+ "key": "Content-Type",
67
+ "value": "application/json"
68
+ }
69
+ ],
70
+ "url": {
71
+ "raw": "{{base_url}}/debug/force-unlock",
72
+ "host": ["{{base_url}}"],
73
+ "path": ["debug", "force-unlock"]
74
+ },
75
+ "description": "Force-release a stuck inference lock. Only available in DEBUG mode."
76
+ }
77
  }
78
  ]
79
  },
 
384
  "name": "Graph Generation - Inference",
385
  "item": [
386
  {
387
+ "name": "POST /graph-generation/generate (standard, QM9, discrete)",
388
  "request": {
389
  "method": "POST",
390
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
 
 
391
  "body": {
392
  "mode": "raw",
393
+ "raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
394
  },
395
  "url": {
396
  "raw": "{{base_url}}/graph-generation/generate",
397
  "host": ["{{base_url}}"],
398
  "path": ["graph-generation", "generate"]
399
  },
400
+ "description": "Standard denoising on QM9 molecules (discrete model). Returns final PNG image + chain GIF."
401
  }
402
  },
403
  {
404
+ "name": "POST /graph-generation/generate (standard, QM9, continuous)",
405
  "request": {
406
  "method": "POST",
407
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
408
+ "body": {
409
+ "mode": "raw",
410
+ "raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
411
+ },
412
+ "url": {
413
+ "raw": "{{base_url}}/graph-generation/generate",
414
+ "host": ["{{base_url}}"],
415
+ "path": ["graph-generation", "generate"]
416
+ },
417
+ "description": "Standard denoising on QM9 molecules (continuous/lifted model). Returns final PNG image + chain GIF."
418
+ }
419
+ },
420
+ {
421
+ "name": "POST /graph-generation/generate (standard, comm20, discrete)",
422
+ "request": {
423
+ "method": "POST",
424
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
425
+ "body": {
426
+ "mode": "raw",
427
+ "raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
428
+ },
429
+ "url": {
430
+ "raw": "{{base_url}}/graph-generation/generate",
431
+ "host": ["{{base_url}}"],
432
+ "path": ["graph-generation", "generate"]
433
+ },
434
+ "description": "Standard denoising on Community20 graphs (discrete model). Returns final PNG image + chain GIF."
435
+ }
436
+ },
437
+ {
438
+ "name": "POST /graph-generation/generate (standard, comm20, continuous)",
439
+ "request": {
440
+ "method": "POST",
441
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
442
+ "body": {
443
+ "mode": "raw",
444
+ "raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
445
+ },
446
+ "url": {
447
+ "raw": "{{base_url}}/graph-generation/generate",
448
+ "host": ["{{base_url}}"],
449
+ "path": ["graph-generation", "generate"]
450
+ },
451
+ "description": "Standard denoising on Community20 graphs (continuous/lifted model). Returns final PNG image + chain GIF."
452
+ }
453
+ },
454
+ {
455
+ "name": "POST /graph-generation/generate (multiprox init, QM9, discrete)",
456
+ "event": [
457
+ {
458
+ "listen": "test",
459
+ "script": {
460
+ "type": "text/javascript",
461
+ "exec": ["// Extract state from the SSE result event and store as collection variable", "var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) {", " pm.collectionVariables.set('multiprox_state', result.state);", " console.log('State saved to {{multiprox_state}} (' + result.state.length + ' chars)');", " }", " } catch (e) { console.log('Failed to parse result event: ' + e); }", " break;", " }", "}"]
462
+ }
463
+ }
464
+ ],
465
+ "request": {
466
+ "method": "POST",
467
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
468
+ "body": {
469
+ "mode": "raw",
470
+ "raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.5,\n \"t_prime\": 0.004,\n \"gibbs_chain_freq\": 10\n }\n}"
471
+ },
472
+ "url": {
473
+ "raw": "{{base_url}}/graph-generation/generate",
474
+ "host": ["{{base_url}}"],
475
+ "path": ["graph-generation", "generate"]
476
+ },
477
+ "description": "MultiProx Gibbs init on QM9 (discrete). Best params from thesis Table 4.3.1: t=50%, t'=0.4% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
478
+ }
479
+ },
480
+ {
481
+ "name": "POST /graph-generation/generate (multiprox init, QM9, continuous)",
482
+ "event": [
483
+ {
484
+ "listen": "test",
485
+ "script": {
486
+ "type": "text/javascript",
487
+ "exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
488
+ }
489
+ }
490
+ ],
491
+ "request": {
492
+ "method": "POST",
493
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
494
+ "body": {
495
+ "mode": "raw",
496
+ "raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.5,\n \"t_prime\": 0.004,\n \"gibbs_chain_freq\": 10\n }\n}"
497
+ },
498
+ "url": {
499
+ "raw": "{{base_url}}/graph-generation/generate",
500
+ "host": ["{{base_url}}"],
501
+ "path": ["graph-generation", "generate"]
502
+ },
503
+ "description": "MultiProx Gibbs init on QM9 (continuous/lifted). Best params from thesis Table 4.3.1: t=50%, t'=0.4% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
504
+ }
505
+ },
506
+ {
507
+ "name": "POST /graph-generation/generate (multiprox init, comm20, discrete)",
508
+ "event": [
509
+ {
510
+ "listen": "test",
511
+ "script": {
512
+ "type": "text/javascript",
513
+ "exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
514
+ }
515
+ }
516
+ ],
517
+ "request": {
518
+ "method": "POST",
519
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
520
  "body": {
521
  "mode": "raw",
522
+ "raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.4,\n \"t_prime\": 0.1,\n \"gibbs_chain_freq\": 10\n }\n}"
523
  },
524
  "url": {
525
  "raw": "{{base_url}}/graph-generation/generate",
526
  "host": ["{{base_url}}"],
527
  "path": ["graph-generation", "generate"]
528
  },
529
+ "description": "MultiProx Gibbs init on Community20 (discrete). Best params from thesis Table C.2.1: t=40%, t'=10% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
530
+ }
531
+ },
532
+ {
533
+ "name": "POST /graph-generation/generate (multiprox init, comm20, continuous)",
534
+ "event": [
535
+ {
536
+ "listen": "test",
537
+ "script": {
538
+ "type": "text/javascript",
539
+ "exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
540
+ }
541
+ }
542
+ ],
543
+ "request": {
544
+ "method": "POST",
545
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
546
+ "body": {
547
+ "mode": "raw",
548
+ "raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.4,\n \"t_prime\": 0.1,\n \"gibbs_chain_freq\": 10\n }\n}"
549
+ },
550
+ "url": {
551
+ "raw": "{{base_url}}/graph-generation/generate",
552
+ "host": ["{{base_url}}"],
553
+ "path": ["graph-generation", "generate"]
554
+ },
555
+ "description": "MultiProx Gibbs init on Community20 (continuous/lifted). Best params from thesis Table C.2.1: t=40%, t'=10% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
556
  }
557
  },
558
  {
559
  "name": "POST /graph-generation/continue",
560
+ "event": [
561
+ {
562
+ "listen": "test",
563
+ "script": {
564
+ "type": "text/javascript",
565
+ "exec": ["// Update state for chaining multiple continue calls", "var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) {", " pm.collectionVariables.set('multiprox_state', result.state);", " console.log('State updated (done=' + result.done + ', step=' + result.step + ')');", " }", " } catch (e) {}", " break;", " }", "}"]
566
+ }
567
+ }
568
+ ],
569
  "request": {
570
  "method": "POST",
571
+ "header": [{ "key": "Content-Type", "value": "application/json" }],
 
 
572
  "body": {
573
  "mode": "raw",
574
+ "raw": "{\n \"state\": \"{{multiprox_state}}\"\n}"
575
  },
576
  "url": {
577
  "raw": "{{base_url}}/graph-generation/continue",
578
  "host": ["{{base_url}}"],
579
  "path": ["graph-generation", "continue"]
580
  },
581
+ "description": "Advance MultiProx chain by gibbs_chain_freq inner steps. Uses {{multiprox_state}} from the last init/continue call. Can be fired repeatedly to chain steps."
582
  }
583
  }
584
  ]
src/backend/README.md CHANGED
@@ -45,6 +45,7 @@ The API is served at `http://localhost:8000/api/v1/`.
45
  | `DJANGO_SECRET_KEY` | `dev-insecure-key-change-in-production` | Django secret key. **Set in production.** |
46
  | `DJANGO_DEBUG` | `True` | Enable debug mode. Set to `False` in production. |
47
  | `DJANGO_ALLOWED_HOSTS` | `localhost,127.0.0.1` | Comma-separated allowed hosts. |
 
48
 
49
  ## Startup Sequence
50
 
@@ -65,8 +66,9 @@ All endpoints are prefixed with `/api/v1/`.
65
 
66
  | Method | Path | Description |
67
  |---|---|---|
68
- | `GET` | `/health` | Service health + model availability |
69
  | `GET` | `/methods` | List the 3 research methods |
 
70
 
71
  ### COINs - KG Reasoning
72
 
@@ -86,8 +88,8 @@ All endpoints are prefixed with `/api/v1/`.
86
  |---|---|---|
87
  | `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
88
  | `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
89
- | `POST` | `/graph-generation/generate` | Generate a graph (not yet implemented) |
90
- | `POST` | `/graph-generation/continue` | Continue MultiProx generation (not yet implemented) |
91
 
92
  ### KG Anomaly Correction
93
 
@@ -98,6 +100,32 @@ All endpoints are prefixed with `/api/v1/`.
98
  | `POST` | `/kg-anomaly/correct` | Run correction (not yet implemented) |
99
  | `POST` | `/kg-anomaly/continue` | Continue MultiProx correction (not yet implemented) |
100
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
101
  ## Project Structure
102
 
103
  ```
 
45
  | `DJANGO_SECRET_KEY` | `dev-insecure-key-change-in-production` | Django secret key. **Set in production.** |
46
  | `DJANGO_DEBUG` | `True` | Enable debug mode. Set to `False` in production. |
47
  | `DJANGO_ALLOWED_HOSTS` | `localhost,127.0.0.1` | Comma-separated allowed hosts. |
48
+ | `TORCH_DEVICE` | Auto (`cuda:0` if available, else `cpu`) | PyTorch device for model inference. |
49
 
50
  ## Startup Sequence
51
 
 
66
 
67
  | Method | Path | Description |
68
  |---|---|---|
69
+ | `GET` | `/health` | Service health + model availability + inference lock status |
70
  | `GET` | `/methods` | List the 3 research methods |
71
+ | `POST` | `/debug/force-unlock` | Release stuck inference lock (debug mode only) |
72
 
73
  ### COINs - KG Reasoning
74
 
 
88
  |---|---|---|
89
  | `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
90
  | `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
91
+ | `POST` | `/graph-generation/generate` | **Streaming SSE.** Generate a graph (standard denoising or MultiProx Gibbs init) |
92
+ | `POST` | `/graph-generation/continue` | **Streaming SSE.** Advance a MultiProx Gibbs session by `gibbs_chain_freq` inner steps |
93
 
94
  ### KG Anomaly Correction
95
 
 
100
  | `POST` | `/kg-anomaly/correct` | Run correction (not yet implemented) |
101
  | `POST` | `/kg-anomaly/continue` | Continue MultiProx correction (not yet implemented) |
102
 
103
+ ## Streaming Inference Protocol (SSE)
104
+
105
+ The graph generation endpoints (`/graph-generation/generate`, `/graph-generation/continue`) return **Server-Sent Events** (`text/event-stream`). Three event types are emitted:
106
+
107
+ **`event: progress`** - phase/step metadata (no images):
108
+ ```
109
+ event: progress
110
+ data: {"type":"progress","phase":"denoise","step":42,"total_steps":500,"elapsed_ms":2100}
111
+ ```
112
+
113
+ **`event: preview`** - base64 PNG of the graph's current state, emitted at key frames:
114
+ ```
115
+ event: preview
116
+ data: data:image/png;base64,...
117
+ ```
118
+
119
+ Preview frequency: `denoise` emits at `chain_frames` intervals (about 30 previews over 500 steps), `gibbs` emits after every inner step, and `refine` emits roughly every 10% of its steps.
120
+
121
+ **`event: result`** - final payload with image, chain GIF, and timing:
122
+ ```
123
+ event: result
124
+ data: {"type":"result","dataset_id":"qm9","model_type":"discrete","sampling_mode":"standard","image":"data:image/png;base64,...","chain_gif":"data:image/gif;base64,...","inference_time_ms":25000}
125
+ ```
126
+
127
+ Phases: `denoise` (standard generation loop), `noise_init` (multiprox init noise sampling), `gibbs` (multiprox inner Gibbs steps), `refine` (multiprox refinement denoising).
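
A client can consume this stream with a small line-based parser. The sketch below is illustrative only: `iter_sse_events` and `final_result` are hypothetical helper names, not part of the backend, and it assumes events are separated by blank lines as in standard SSE framing. The `state`, `done`, and `step` fields in the sample are taken from the Postman test scripts for the MultiProx requests and are assumed to appear in the `result` payload for that sampling mode.

```python
import json
from typing import Iterable, Iterator, Optional, Tuple


def iter_sse_events(lines: Iterable[str]) -> Iterator[Tuple[str, str]]:
    """Yield (event_name, data) pairs from an iterable of SSE text lines."""
    event, data = "message", []
    for raw in lines:
        line = raw.rstrip("\r\n")
        if line == "":
            # A blank line terminates the current event.
            if data:
                yield event, "\n".join(data)
            event, data = "message", []
        elif line.startswith(":"):
            continue  # comment / keep-alive line
        elif line.startswith("event:"):
            event = line[len("event:"):].strip()
        elif line.startswith("data:"):
            value = line[len("data:"):]
            # Only one optional leading space belongs to the framing.
            data.append(value[1:] if value.startswith(" ") else value)
    if data:
        # Stream ended without a trailing blank line.
        yield event, "\n".join(data)


def final_result(lines: Iterable[str]) -> Optional[dict]:
    """Return the decoded `result` payload, or None if none was emitted."""
    result = None
    for event, data in iter_sse_events(lines):
        if event == "result":
            result = json.loads(data)
    return result


# Hand-written sample stream (not real server output).
SAMPLE = (
    'event: progress\n'
    'data: {"type":"progress","phase":"gibbs","step":1,"total_steps":10,"elapsed_ms":50}\n'
    '\n'
    'event: preview\n'
    'data: data:image/png;base64,AAAA\n'
    '\n'
    'event: result\n'
    'data: {"type":"result","state":"abc","done":false,"step":10}\n'
    '\n'
)

if __name__ == "__main__":
    payload = final_result(SAMPLE.splitlines())
    print(payload["state"], payload["done"], payload["step"])
```

With an HTTP client such as `requests`, the same parser could be fed from `response.iter_lines(decode_unicode=True)` on a streamed POST. To chain MultiProx calls, the `state` value from one `result` event would be sent as the `state` field of the next `/graph-generation/continue` request, mirroring what the Postman test scripts do with `{{multiprox_state}}`.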
128
+
129
  ## Project Structure
130
 
131
  ```