Andrej Janchevski commited on
Commit Β·
c50d15c
1
Parent(s): c1b3cc7
docs(postman): add thesis-optimal multiprox params and force-unlock endpoint
Browse files- QM9 multiprox: t=0.5, t'=0.004 (Table 4.3.1 best valid config)
- Comm20 multiprox: t=0.4, t'=0.1 (Table C.2.1 best orbit/degree)
- Add POST /debug/force-unlock request to Health folder
- Auto-chaining test scripts for multiprox state between init/continue
- Update README with SSE protocol and new endpoints
- Update backend_multiproxan plan to reflect current design decisions
- .claude/plans/backend_multiproxan.md +233 -350
- docs/postman/collection.json +180 -17
- src/backend/README.md +31 -3
.claude/plans/backend_multiproxan.md
CHANGED
|
@@ -13,14 +13,19 @@ The key challenge is that `sample_batch` / `sample_batch_gibbs` write to disk vi
|
|
| 13 |
| `src/research/MultiProxAn/src/diffusion_model_discrete.py` | Strip all `wandb` imports and calls |
|
| 14 |
| `src/research/MultiProxAn/src/diffusion_model.py` | Strip all `wandb` imports and calls (incl. one unguarded call) |
|
| 15 |
| `src/research/MultiProxAn/src/utils.py` | Remove `import wandb` and `setup_wandb()` function |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
| `src/backend/research_api/settings.py` | Add `MultiProxAn/src/` to `sys.path` |
|
| 17 |
-
| `src/backend/api/services/graphgen_inference.py` | **New file**
|
| 18 |
-
| `src/backend/api/services/registry.py` | Add `_graphgen_models` cache, `
|
| 19 |
-
| `src/backend/api/views/graph_generation.py` | Add `GraphGenGenerateView`, `GraphGenContinueView` |
|
| 20 |
-
| `src/backend/api/
|
| 21 |
-
| `src/backend/
|
| 22 |
-
| `src/backend/
|
| 23 |
-
| `
|
|
|
|
| 24 |
|
| 25 |
## Model Differences: Discrete vs Continuous
|
| 26 |
|
|
@@ -30,60 +35,164 @@ The key challenge is that `sample_batch` / `sample_batch_gibbs` write to disk vi
|
|
| 30 |
|---|---|---|
|
| 31 |
| **`node_mask` dtype** | `bool` | `float32` (`.float()` required) |
|
| 32 |
| **Initial noise** | `sample_discrete_feature_noise(limit_dist=model.limit_dist, node_mask=node_mask)` | `sample_feature_noise(X_size=(1,n,Xdim), E_size=(1,n,n,Edim), y_size=(1,ydim), node_mask=node_mask)` |
|
| 33 |
-
| **`sample_p_zs_given_zt` returns** | `(sampled_s, discrete_sampled_s)`
|
| 34 |
-
| **Chain frame render** | use `discrete_sampled_s.X/E` directly | `utils.unnormalize(z_s.X, z_s.E, z_s.y, model.norm_values, model.norm_biases, node_mask, collapse=True)`
|
| 35 |
-
| **Final graph collapse** | `PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)` | `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` (runs another forward pass + unnormalize) |
|
| 36 |
| **Gibbs ensemble aggregation** | `torch.median(X, dim=1).values` | `torch.mean(X, dim=1)` |
|
| 37 |
|
|
|
|
|
|
|
| 38 |
The refinement loop in `run_multiprox_step` uses `sample_p_zs_given_zt` to update `cur_X/E/y` at each step; only the first element of the tuple (or the single return value) is needed there. `_collapse_final` is called once at the end for rendering.
|
| 39 |
|
| 40 |
## Research Code Reused (read-only)
|
| 41 |
|
| 42 |
| Symbol | Location | How used |
|
| 43 |
|---|---|---|
|
| 44 |
-
| `DiscreteDenoisingDiffusion` | `src/diffusion_model_discrete.py` | Loaded via `
|
| 45 |
-
| `LiftedDenoisingDiffusion` | `src/diffusion_model.py` | Loaded via `
|
| 46 |
| `model.node_dist.sample_n(1, device)` | both models | Sample number of nodes |
|
| 47 |
-
| `diffusion_utils.sample_discrete_feature_noise(limit_dist, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise
|
| 48 |
-
| `diffusion_utils.sample_feature_noise(X_size, E_size, y_size, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise
|
| 49 |
-
| `model.sample_p_zs_given_zt(s, t, X, E, y, node_mask)` | both models | One denoising step (return varies by model type
|
| 50 |
| `model.apply_noise(X, E, y, node_mask, gibbs=True)` | both models | Re-apply Gibbs noise; uses `model.gibbs_fixed_t_2` internally |
|
| 51 |
-
| `PlaceHolder.mask(node_mask, collapse=True)` | `src/utils.py` | Final collapse
|
| 52 |
-
| `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` | `diffusion_model.py` | Final collapse
|
| 53 |
-
| `utils.unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=True)` | `src/utils.py` | Chain frame rendering
|
| 54 |
-
| `model.norm_values`, `model.norm_biases` | `diffusion_model.py` | Unnormalization factors
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
|
| 56 |
## Implementation
|
| 57 |
|
| 58 |
-
### Step 0: Patch research code
|
| 59 |
|
| 60 |
-
`wandb` is imported at module level in three files; the import alone causes `ImportError` at
|
| 61 |
-
`load_from_checkpoint` time. One call in `diffusion_model.py` is **unguarded** (no `if wandb.run:`
|
| 62 |
-
check) and would crash at inference even if the import somehow succeeded.
|
| 63 |
|
| 64 |
**`src/research/MultiProxAn/src/diffusion_model_discrete.py`**:
|
| 65 |
- Remove `import wandb` (line 9)
|
| 66 |
-
- Remove the two `utils.setup_wandb(self.cfg)` call sites
|
| 67 |
-
- Remove all `if wandb.run: wandb.log(...)` / `wandb.run.summary[...]` blocks
|
| 68 |
|
| 69 |
**`src/research/MultiProxAn/src/diffusion_model.py`**:
|
| 70 |
- Remove `import wandb` (line 9)
|
| 71 |
- Remove the two `utils.setup_wandb(self.cfg)` call sites
|
| 72 |
-
- Remove all guarded `if wandb.run: wandb.log(...)`
|
| 73 |
-
- Remove the **unguarded** `wandb.log({...})` call at line 590 (inside `compute_val_loss` / `compute_test_loss`); the surrounding NLL computation and `return` must be preserved β only the `wandb.log(...)` statement is deleted
|
| 74 |
|
| 75 |
**`src/research/MultiProxAn/src/utils.py`**:
|
| 76 |
-
- Remove `import wandb`
|
| 77 |
-
- Remove the entire `setup_wandb(cfg)` function (lines 134β139)
|
| 78 |
-
|
| 79 |
-
After these edits `wandb` is no longer needed and must be removed from `requirements.txt` (Step 6).
|
| 80 |
|
| 81 |
---
|
| 82 |
|
| 83 |
-
### Step
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
-
|
|
|
|
|
|
|
| 86 |
|
|
|
|
| 87 |
```python
|
| 88 |
_MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
|
| 89 |
for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
|
|
@@ -95,108 +204,48 @@ for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
|
|
| 95 |
|
| 96 |
### Step 2: New `graphgen_inference.py`
|
| 97 |
|
| 98 |
-
Completely independent of Django/registry. Receives a loaded model and
|
| 99 |
|
| 100 |
**Module structure:**
|
| 101 |
|
| 102 |
```
|
| 103 |
graphgen_inference.py
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
```
|
| 130 |
|
| 131 |
#### Model-type helpers
|
| 132 |
|
| 133 |
-
```
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
return isinstance(model, DiscreteDenoisingDiffusion)
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
def _build_node_mask(n_nodes, n_max, model):
|
| 140 |
-
"""bool for discrete, float32 for continuous."""
|
| 141 |
-
arange = torch.arange(n_max, device=n_nodes.device).unsqueeze(0)
|
| 142 |
-
mask = arange < n_nodes.unsqueeze(1) # (1, n_max) bool
|
| 143 |
-
return mask if _is_discrete(model) else mask.float()
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
def _sample_initial_noise(model, n_max, node_mask):
|
| 147 |
-
"""Discrete: sample_discrete_feature_noise. Continuous: sample_feature_noise."""
|
| 148 |
-
if _is_discrete(model):
|
| 149 |
-
return diffusion_utils.sample_discrete_feature_noise(
|
| 150 |
-
limit_dist=model.limit_dist, node_mask=node_mask)
|
| 151 |
-
else:
|
| 152 |
-
bs = node_mask.shape[0]
|
| 153 |
-
return diffusion_utils.sample_feature_noise(
|
| 154 |
-
X_size=(bs, n_max, model.Xdim_output),
|
| 155 |
-
E_size=(bs, n_max, n_max, model.Edim_output),
|
| 156 |
-
y_size=(bs, model.ydim_output),
|
| 157 |
-
node_mask=node_mask)
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
def _denoising_step(model, s_t, t_t, X, E, y, node_mask):
|
| 161 |
-
"""Run one denoising step.
|
| 162 |
-
Returns (X_soft, E_soft, y_soft, X_int, E_int) where X/E_int are collapsed integer tensors
|
| 163 |
-
suitable for rendering. X/E_soft are the continuous activations to feed into the next step.
|
| 164 |
-
"""
|
| 165 |
-
if _is_discrete(model):
|
| 166 |
-
sampled_s, discrete_s = model.sample_p_zs_given_zt(s_t, t_t, X, E, y, node_mask)
|
| 167 |
-
return sampled_s.X, sampled_s.E, sampled_s.y, discrete_s.X, discrete_s.E
|
| 168 |
-
else:
|
| 169 |
-
z_s = model.sample_p_zs_given_zt(s=s_t, t=t_t, X_t=X, E_t=E, y_t=y, node_mask=node_mask)
|
| 170 |
-
unnorm = utils.unnormalize(
|
| 171 |
-
z_s.X, z_s.E, z_s.y,
|
| 172 |
-
model.norm_values, model.norm_biases, node_mask, collapse=True)
|
| 173 |
-
return z_s.X, z_s.E, z_s.y, unnorm.X, unnorm.E
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
def _gibbs_aggregate(model, X):
|
| 177 |
-
"""Aggregate ensemble: median for discrete, mean for continuous."""
|
| 178 |
-
if _is_discrete(model):
|
| 179 |
-
return torch.median(X, dim=1).values
|
| 180 |
-
else:
|
| 181 |
-
return torch.mean(X, dim=1)
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
def _collapse_final(model, X, E, y, node_mask):
|
| 185 |
-
"""Collapse continuous activations to integer indices for rendering.
|
| 186 |
-
Discrete: PlaceHolder.mask(collapse=True). Continuous: sample_discrete_graph_given_z0.
|
| 187 |
-
Returns (X_int, E_int) tensors.
|
| 188 |
-
"""
|
| 189 |
-
if _is_discrete(model):
|
| 190 |
-
final = PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
|
| 191 |
-
return final.X, final.E
|
| 192 |
-
else:
|
| 193 |
-
final = model.sample_discrete_graph_given_z0(X, E, y, node_mask)
|
| 194 |
-
return final.X, final.E
|
| 195 |
-
```
|
| 196 |
|
| 197 |
#### `run_standard_generation`
|
| 198 |
|
| 199 |
-
Re-implements the denoising loop
|
| 200 |
|
| 201 |
```python
|
| 202 |
def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id):
|
|
@@ -219,121 +268,31 @@ def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dat
|
|
| 219 |
step = diffusion_steps - 1 - s_idx
|
| 220 |
if step % frame_interval == 0 or s_idx == 0:
|
| 221 |
gif_frames.append(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
|
|
|
|
| 222 |
|
| 223 |
X_final, E_final = _collapse_final(model, X, E, y, node_mask)
|
| 224 |
-
|
| 225 |
-
return image_b64, _frames_to_gif_b64(gif_frames), elapsed_ms
|
| 226 |
```
|
| 227 |
|
| 228 |
-
#### API parameter
|
| 229 |
|
| 230 |
| API param | Model attribute | Role |
|
| 231 |
|---|---|---|
|
| 232 |
| `t` | `model.gibbs_fixed_t_2` | Fixed noise level for the Gibbs chain; used by `apply_noise(gibbs=True)` and as `fixed_t_norm` in the inner denoising step |
|
| 233 |
-
| `t_prime` | `model.gibbs_fixed_t_1` | Refinement target; `P = int((gibbs_fixed_t_2
|
| 234 |
| `gibbs_chain_freq` | `model.gibbs_chain_freq` | Inner Gibbs steps per `/continue` call (same attribute name in checkpoint) |
|
| 235 |
| `m` | `model.gibbs_M` | Ensemble size |
|
| 236 |
-
| `n` | `model.gibbs_N` | Number of outer Gibbs rounds
|
| 237 |
|
| 238 |
-
All parameters are always explicit in the API request
|
| 239 |
|
| 240 |
#### `run_multiprox_init`
|
| 241 |
|
| 242 |
-
Initialises the M-member ensemble
|
| 243 |
-
Gibbs iterations run per `/continue` call β `βM / gibbs_chain_freqβ` calls complete one outer round.
|
| 244 |
-
Default `gibbs_chain_freq = max(1, m // 10)` (10% of ensemble size):
|
| 245 |
-
|
| 246 |
-
```python
|
| 247 |
-
def run_multiprox_init(model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id):
|
| 248 |
-
device = next(model.parameters()).device
|
| 249 |
-
n_nodes = model.node_dist.sample_n(1, device) if num_nodes is None else torch.tensor([num_nodes], ...)
|
| 250 |
-
n_max = n_nodes.item()
|
| 251 |
-
node_mask = _build_node_mask(n_nodes, n_max, model)
|
| 252 |
-
|
| 253 |
-
# Sample M independent initial noise graphs
|
| 254 |
-
z_samples = [_sample_initial_noise(model, n_max, node_mask) for _ in range(m)]
|
| 255 |
-
X = torch.stack([z.X for z in z_samples], dim=1) # (1, M, n_max, Xdim)
|
| 256 |
-
E = torch.stack([z.E for z in z_samples], dim=1)
|
| 257 |
-
y = torch.stack([z.y for z in z_samples], dim=1)
|
| 258 |
-
|
| 259 |
-
# Step 0 image: aggregate ensemble β collapse β render
|
| 260 |
-
agg_X = _gibbs_aggregate(model, X) # (1, n_max, Xdim)
|
| 261 |
-
agg_E = _gibbs_aggregate(model, E)
|
| 262 |
-
agg_y = _gibbs_aggregate(model, y.float())
|
| 263 |
-
X_int, E_int = _collapse_final(model, agg_X, agg_E, agg_y, node_mask)
|
| 264 |
-
image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
|
| 265 |
-
|
| 266 |
-
state = {"X": X.cpu(), "E": E.cpu(), "y": y.cpu(), "n_nodes": n_nodes.cpu(),
|
| 267 |
-
"dataset_id": dataset_id, "model_type": None, # filled by registry
|
| 268 |
-
"T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
|
| 269 |
-
"gibbs_chain_freq": gibbs_chain_freq, "inner_step": 0, "step": 0}
|
| 270 |
-
return state, image_b64, elapsed_ms
|
| 271 |
-
```
|
| 272 |
|
| 273 |
#### `run_multiprox_step`
|
| 274 |
|
| 275 |
-
Runs `gibbs_chain_freq` inner Gibbs iterations
|
| 276 |
-
(one full outer round), runs the tβt_prime refinement and resets. On intermediate calls (round
|
| 277 |
-
still in progress), renders the raw ensemble aggregate β no refinement, faster response.
|
| 278 |
-
|
| 279 |
-
```python
|
| 280 |
-
def run_multiprox_step(model, state_dict, dataset_id):
|
| 281 |
-
# Unpack X (1,M,n_max,Xdim), E, y, n_nodes, t, t_prime, T, n, m, gibbs_chain_freq, inner_step, step
|
| 282 |
-
device = next(model.parameters()).device
|
| 283 |
-
X, E, y = state["X"].to(device), state["E"].to(device), state["y"].to(device)
|
| 284 |
-
n_nodes = state["n_nodes"].to(device)
|
| 285 |
-
n_max = X.shape[2]
|
| 286 |
-
node_mask = _build_node_mask(n_nodes, n_max, model)
|
| 287 |
-
|
| 288 |
-
fixed_t = t * torch.ones((1,1), dtype=torch.float, device=device)
|
| 289 |
-
fixed_s = fixed_t - (1.0 / T)
|
| 290 |
-
|
| 291 |
-
# How many inner Gibbs steps to run this call (may be less than gibbs_chain_freq at end of round)
|
| 292 |
-
steps_this_call = min(gibbs_chain_freq, m - inner_step)
|
| 293 |
-
|
| 294 |
-
with torch.no_grad():
|
| 295 |
-
for i in range(steps_this_call):
|
| 296 |
-
k = inner_step + i
|
| 297 |
-
avg_X = _gibbs_aggregate(model, X) # (1, n_max, Xdim)
|
| 298 |
-
avg_E = _gibbs_aggregate(model, E)
|
| 299 |
-
avg_y = _gibbs_aggregate(model, y.float())
|
| 300 |
-
denoised_X, denoised_E, denoised_y, _, _ = _denoising_step(
|
| 301 |
-
model, fixed_s, fixed_t, avg_X, avg_E, avg_y, node_mask)
|
| 302 |
-
old_t2 = model.gibbs_fixed_t_2
|
| 303 |
-
model.gibbs_fixed_t_2 = t # override per-request (lock held by registry)
|
| 304 |
-
noisy = model.apply_noise(denoised_X, denoised_E, denoised_y, node_mask, gibbs=True)
|
| 305 |
-
model.gibbs_fixed_t_2 = old_t2
|
| 306 |
-
X[:, k], E[:, k], y[:, k] = noisy["X_t"], noisy["E_t"], noisy["y_t"]
|
| 307 |
-
|
| 308 |
-
new_inner_step = inner_step + steps_this_call
|
| 309 |
-
round_complete = new_inner_step >= m
|
| 310 |
-
if round_complete:
|
| 311 |
-
new_inner_step = 0
|
| 312 |
-
new_step = step + 1
|
| 313 |
-
else:
|
| 314 |
-
new_step = step
|
| 315 |
-
done = round_complete and new_step >= n
|
| 316 |
-
|
| 317 |
-
# Refinement pass runs on every call β always produce a clean render.
|
| 318 |
-
# Uses the current ensemble aggregate (regardless of round_complete).
|
| 319 |
-
P = int((t - t_prime) * T) + 1
|
| 320 |
-
cur_X = _gibbs_aggregate(model, X)
|
| 321 |
-
cur_E = _gibbs_aggregate(model, E)
|
| 322 |
-
cur_y = _gibbs_aggregate(model, y.float())
|
| 323 |
-
for j in range(P):
|
| 324 |
-
s_ref = (t - (j + 1) / T) * torch.ones((1,1), dtype=torch.float, device=device)
|
| 325 |
-
t_ref = (t - j / T) * torch.ones((1,1), dtype=torch.float, device=device)
|
| 326 |
-
cur_X, cur_E, cur_y, _, _ = _denoising_step(
|
| 327 |
-
model, s_ref, t_ref, cur_X, cur_E, cur_y, node_mask)
|
| 328 |
-
X_int, E_int = _collapse_final(model, cur_X, cur_E, cur_y, node_mask)
|
| 329 |
-
|
| 330 |
-
image_b64 = _pil_to_b64(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
|
| 331 |
-
updated_state = {**state_dict, "X": X.cpu(), "E": E.cpu(), "y": y.cpu(),
|
| 332 |
-
"step": new_step, "inner_step": new_inner_step}
|
| 333 |
-
return updated_state, image_b64, round_complete, done, elapsed_ms
|
| 334 |
-
```
|
| 335 |
-
|
| 336 |
-
Response includes `round_complete: bool` (full M sweep done) and `done: bool` (`n` outer rounds complete β frontend should stop calling continue).
|
| 337 |
|
| 338 |
**State blob (encode/decode):**
|
| 339 |
|
|
@@ -342,167 +301,88 @@ STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
|
|
| 342 |
REQUIRED_STATE_KEYS = {"X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
|
| 343 |
"n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step"}
|
| 344 |
|
| 345 |
-
# encode: torch.save(state_dict, BytesIO)
|
| 346 |
-
# decode: base64
|
| 347 |
-
# Raise ValueError on bad base64, oversized blob, missing keys, wrong tensor dims.
|
| 348 |
-
# Registry converts ValueError β InvalidRequestError(400).
|
| 349 |
```
|
| 350 |
|
| 351 |
-
**Visualization:**
|
| 352 |
-
|
| 353 |
-
- `_render_qm9`: `RWMol`, `Atom(symbol)` from `QM9_ATOM_TYPES=["C","N","O","F"]`, `AddBond` with `{1:SINGLE, 2:DOUBLE, 3:TRIPLE, 4:AROMATIC}` (index 0 = no bond, skip), `MolToImage(size=(300,300))` β PIL
|
| 354 |
-
- `_render_comm20`: `nx.Graph` from adjacency (edge index > 0 = present), `nx.draw_networkx` with matplotlib Agg backend, `fig.savefig(BytesIO, "png")` β PIL
|
| 355 |
-
- `_pil_to_b64`: `img.save(BytesIO, "PNG")` β `"data:image/png;base64," + b64encode`
|
| 356 |
-
- `_frames_to_gif_b64`: `frames[0].save(BytesIO, "GIF", save_all=True, append_images=..., duration=150, loop=0)` β `"data:image/gif;base64," + b64encode`
|
| 357 |
-
|
| 358 |
---
|
| 359 |
|
| 360 |
### Step 3: `registry.py` additions
|
| 361 |
|
| 362 |
-
**
|
| 363 |
|
|
|
|
| 364 |
```python
|
| 365 |
-
self.
|
|
|
|
|
|
|
| 366 |
```
|
| 367 |
|
| 368 |
-
**`
|
| 369 |
-
|
| 370 |
-
```python
|
| 371 |
-
# Defer imports inside method β sys.path is set at app startup, not at Django module import time
|
| 372 |
-
if model_type == "discrete":
|
| 373 |
-
from diffusion_model_discrete import DiscreteDenoisingDiffusion as cls
|
| 374 |
-
else:
|
| 375 |
-
from diffusion_model import LiftedDenoisingDiffusion as cls
|
| 376 |
-
|
| 377 |
-
suffix = "_c" if model_type == "continuous" else ""
|
| 378 |
-
ckpt_path = Path(settings.MULTIPROXAN_DIR) / "checkpoints" / f"{dataset_id}{suffix}.ckpt"
|
| 379 |
-
|
| 380 |
-
model = cls.load_from_checkpoint(
|
| 381 |
-
str(ckpt_path), map_location="cpu",
|
| 382 |
-
train_metrics=None, sampling_metrics=None, visualization_tools=None,
|
| 383 |
-
)
|
| 384 |
-
model.eval()
|
| 385 |
-
self._graphgen_models[(dataset_id, model_type)] = model
|
| 386 |
-
```
|
| 387 |
|
| 388 |
-
**
|
| 389 |
-
|
| 390 |
-
`save_hyperparameters(ignore=['train_metrics', 'sampling_metrics'])` saves `dataset_infos`,
|
| 391 |
-
`extra_features`, `domain_features`, and `visualization_tools` into the checkpoint. The three `None`
|
| 392 |
-
overrides are sufficient β the remaining hparams are restored from the checkpoint via pickle.
|
| 393 |
-
|
| 394 |
-
**Why unpickling is safe without data files:** Python pickle reconstructs objects by restoring
|
| 395 |
-
`__dict__` without running `__init__` and without any file I/O. The pre-computed tensors inside
|
| 396 |
-
`dataset_infos` (`n_nodes`, `node_types`, `edge_types`, `nodes_dist`) all restore correctly. The
|
| 397 |
-
stored datamodule reference is inert at inference time β we never call `train_dataloader()` or any
|
| 398 |
-
method that touches data files.
|
| 399 |
-
|
| 400 |
-
**Why file paths resolve correctly on any deployment:** all datamodules compute their data root as:
|
| 401 |
-
```python
|
| 402 |
-
base_path = pathlib.Path(os.path.realpath(__file__)).parents[2] # β MultiProxAn/
|
| 403 |
-
root_path = os.path.join(base_path, self.datadir) # β MultiProxAn/data/qm9
|
| 404 |
-
```
|
| 405 |
-
`__file__` is the source file on the current machine, so paths are always relative to the research
|
| 406 |
-
code location. `self.datadir` (a relative string like `"data/qm9"`) pickles and restores as-is.
|
| 407 |
-
No path patching is needed.
|
| 408 |
|
| 409 |
-
`
|
| 410 |
-
they unpickle cleanly.
|
| 411 |
|
| 412 |
-
**`
|
| 413 |
-
|
| 414 |
-
```python
|
| 415 |
-
# Acquire _inference_lock (non-blocking) β raise InferenceBusy() if locked
|
| 416 |
-
# Load model via _load_graphgen_model (may raise ModelUnavailable)
|
| 417 |
-
# standard: run_standard_generation β return {dataset_id, model_type, sampling_mode, image, chain_gif, inference_time_ms}
|
| 418 |
-
# multiprox: run_multiprox_init β state["model_type"] = model_type β encode_state_blob
|
| 419 |
-
# β return {step:0, image, state, inference_time_ms}
|
| 420 |
-
# finally: release lock
|
| 421 |
-
```
|
| 422 |
-
|
| 423 |
-
**`graphgen_continue(self, state_b64)`**:
|
| 424 |
-
|
| 425 |
-
```python
|
| 426 |
-
# decode_state_blob OUTSIDE lock (fail-fast) β raise InvalidRequestError on ValueError
|
| 427 |
-
# Acquire _inference_lock β raise InferenceBusy() if locked
|
| 428 |
-
# Load model via _load_graphgen_model(state["dataset_id"], state["model_type"])
|
| 429 |
-
# run_multiprox_step β encode_state_blob β return {step, image, state, inference_time_ms}
|
| 430 |
-
# finally: release lock
|
| 431 |
-
```
|
| 432 |
|
| 433 |
---
|
| 434 |
|
| 435 |
### Step 4: `views/graph_generation.py` additions
|
| 436 |
|
| 437 |
-
**`GraphGenGenerateView.post`**
|
| 438 |
- Require `dataset_id`, `model_type`, `sampling_mode`
|
| 439 |
-
- `dataset_id
|
| 440 |
-
- Check `registry.graphgen_checkpoints_available
|
| 441 |
- `diffusion_steps` clamped to [50, 1000], `chain_frames` to [10, 30]
|
| 442 |
- `num_nodes`: optional int, validated against `GRAPHGEN_DATASETS[dataset_id]["max_nodes"]`
|
| 443 |
-
- multiprox only: `multiprox_params` required; validate `0 <
|
| 444 |
-
-
|
| 445 |
|
| 446 |
-
**`GraphGenContinueView.post`**
|
| 447 |
- Require `state` (non-empty string)
|
| 448 |
-
-
|
|
|
|
|
|
|
| 449 |
|
| 450 |
---
|
| 451 |
|
| 452 |
-
### Step 5: `
|
| 453 |
|
| 454 |
-
```
|
| 455 |
-
# Update import line 12:
|
| 456 |
-
from api.views.graph_generation import (
|
| 457 |
-
GraphGenDatasetsView, GraphGenSamplingModesView,
|
| 458 |
-
GraphGenGenerateView, GraphGenContinueView,
|
| 459 |
-
)
|
| 460 |
|
| 461 |
-
|
| 462 |
-
path("graph-generation/generate", GraphGenGenerateView.as_view()),
|
| 463 |
-
path("graph-generation/continue", GraphGenContinueView.as_view()),
|
| 464 |
-
```
|
| 465 |
|
| 466 |
---
|
| 467 |
|
| 468 |
-
### Step 6: `
|
| 469 |
-
|
| 470 |
-
| Package | Why needed | Already present? |
|
| 471 |
-
|---|---|---|
|
| 472 |
-
| `Pillow` | PIL `Image`, `BytesIO` ops in `graphgen_inference.py` rendering | Transitive dep of `matplotlib` and `torchvision`, but not explicit β add for clarity |
|
| 473 |
-
| `overrides` | Pulled in by research code transitively | Present in research `requirements.txt`; not in backend β add to be safe |
|
| 474 |
|
| 475 |
-
`
|
| 476 |
-
|
| 477 |
-
|
|
|
|
|
|
|
| 478 |
|
| 479 |
---
|
| 480 |
|
| 481 |
-
### Step 7:
|
| 482 |
|
| 483 |
-
|
|
|
|
|
|
|
|
|
|
| 484 |
|
| 485 |
-
|
| 486 |
|
| 487 |
-
|
| 488 |
-
1. `Standard QM9 Discrete` β POST generate `{dataset_id:"qm9", model_type:"discrete", sampling_mode:"standard", diffusion_steps:50, chain_frames:10}`
|
| 489 |
-
2. `Standard QM9 Continuous` β POST generate `{dataset_id:"qm9", model_type:"continuous", sampling_mode:"standard", diffusion_steps:50, chain_frames:10}`
|
| 490 |
-
3. `Standard comm20 Discrete` β POST generate `{dataset_id:"comm20", model_type:"discrete", sampling_mode:"standard", diffusion_steps:500, chain_frames:10}`
|
| 491 |
-
4. `Standard comm20 Continuous` β POST generate `{dataset_id:"comm20", model_type:"continuous", sampling_mode:"standard", diffusion_steps:500, chain_frames:10}`
|
| 492 |
|
| 493 |
-
|
| 494 |
-
5. `MultiProx QM9 Discrete Init` β POST generate `{dataset_id:"qm9", model_type:"discrete", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
|
| 495 |
-
6. `MultiProx QM9 Continuous Init` β POST generate `{dataset_id:"qm9", model_type:"continuous", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
|
| 496 |
-
7. `MultiProx comm20 Discrete Init` β POST generate `{dataset_id:"comm20", model_type:"discrete", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
|
| 497 |
-
8. `MultiProx comm20 Continuous Init` β POST generate `{dataset_id:"comm20", model_type:"continuous", sampling_mode:"multiprox", multiprox_params:{n:10, m:100, t:0.5, t_prime:0.1}}`
|
| 498 |
|
| 499 |
-
**
|
| 500 |
-
9. `MultiProx QM9 Discrete Continue` β POST continue `{state:"{{graphgen_state_qm9_discrete}}"}`
|
| 501 |
-
10. `MultiProx QM9 Continuous Continue` β POST continue `{state:"{{graphgen_state_qm9_continuous}}"}`
|
| 502 |
-
11. `MultiProx comm20 Discrete Continue` β POST continue `{state:"{{graphgen_state_comm20_discrete}}"}`
|
| 503 |
-
12. `MultiProx comm20 Continuous Continue` β POST continue `{state:"{{graphgen_state_comm20_continuous}}"}`
|
| 504 |
|
| 505 |
-
|
|
|
|
|
|
|
| 506 |
|
| 507 |
---
|
| 508 |
|
|
@@ -510,16 +390,19 @@ Each `continue` request uses a distinct collection variable (set manually from t
|
|
| 510 |
|
| 511 |
| Risk | Mitigation |
|
| 512 |
|---|---|
|
| 513 |
-
| `weights_only=
|
| 514 |
-
| `
|
| 515 |
| `model.gibbs_fixed_t_2` attribute override | Only safe because `_inference_lock` ensures single-threaded inference. Save/restore pattern is used. |
|
| 516 |
-
|
|
|
|
|
| 517 |
|
| 518 |
## Verification
|
| 519 |
|
| 520 |
-
1. Django shell: `from diffusion_model_discrete import DiscreteDenoisingDiffusion`
|
| 521 |
-
2. `GET /graph-generation/datasets`
|
| 522 |
-
3.
|
| 523 |
-
4.
|
| 524 |
-
5.
|
| 525 |
-
6.
|
|
|
|
|
|
|
|
|
| 13 |
| `src/research/MultiProxAn/src/diffusion_model_discrete.py` | Strip all `wandb` imports and calls |
|
| 14 |
| `src/research/MultiProxAn/src/diffusion_model.py` | Strip all `wandb` imports and calls (incl. one unguarded call) |
|
| 15 |
| `src/research/MultiProxAn/src/utils.py` | Remove `import wandb` and `setup_wandb()` function |
|
| 16 |
+
| `src/research/MultiProxAn/src/analysis/spectre_utils.py` | Wrap `graph_tool`, `pyemd`, `pygsp`, `dist_helper` imports in try/except |
|
| 17 |
+
| `src/research/MultiProxAn/src/metrics/molecular_metrics.py` | Guard optional metric imports with try/except |
|
| 18 |
+
| `src/research/MultiProxAn/src/metrics/molecular_metrics_discrete.py` | Guard optional metric imports with try/except |
|
| 19 |
+
| `src/research/MultiProxAn/src/metrics/train_metrics.py` | Guard optional metric imports with try/except |
|
| 20 |
| `src/backend/research_api/settings.py` | Add `MultiProxAn/src/` to `sys.path` |
|
| 21 |
+
| `src/backend/api/services/graphgen_inference.py` | **New file** -- all inference logic + rendering |
|
| 22 |
+
| `src/backend/api/services/registry.py` | Add `_graphgen_models` cache, `_safe_load_lightning_checkpoint`, `graphgen_generate_stream`, `graphgen_continue_stream`, `force_release_inference_lock` |
|
| 23 |
+
| `src/backend/api/views/graph_generation.py` | Add `GraphGenGenerateView`, `GraphGenContinueView`, SSE streaming helpers |
|
| 24 |
+
| `src/backend/api/views/health.py` | Add inference lock status to health endpoint, add `ForceUnlockView` |
|
| 25 |
+
| `src/backend/api/urls.py` | Wire 3 new routes (generate, continue, debug/force-unlock) |
|
| 26 |
+
| `src/backend/requirements.txt` | Add `Pillow`, `overrides` |
|
| 27 |
+
| `src/backend/README.md` | Mark generate/continue as implemented, document SSE protocol |
|
| 28 |
+
| `docs/postman/collection.json` | Add 12 example requests, auto-chaining test scripts |
|
| 29 |
|
| 30 |
## Model Differences: Discrete vs Continuous
|
| 31 |
|
|
|
|
| 35 |
|---|---|---|
|
| 36 |
| **`node_mask` dtype** | `bool` | `float32` (`.float()` required) |
|
| 37 |
| **Initial noise** | `sample_discrete_feature_noise(limit_dist=model.limit_dist, node_mask=node_mask)` | `sample_feature_noise(X_size=(1,n,Xdim), E_size=(1,n,n,Edim), y_size=(1,ydim), node_mask=node_mask)` |
|
| 38 |
+
| **`sample_p_zs_given_zt` returns** | `(sampled_s, discrete_sampled_s)` -- 2-tuple; `discrete_sampled_s.X/E` are already collapsed integers | single `z_s` PlaceHolder with continuous floats |
|
| 39 |
+
| **Chain frame render** | use `discrete_sampled_s.X/E` directly (`.long()` required -- see below) | `utils.unnormalize(z_s.X, z_s.E, z_s.y, model.norm_values, model.norm_biases, node_mask, collapse=True)` -> `.X/.E` are integers |
|
| 40 |
+
| **Final graph collapse** | `PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)` (`.long()` required) | `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` (runs another forward pass + unnormalize) |
|
| 41 |
| **Gibbs ensemble aggregation** | `torch.median(X, dim=1).values` | `torch.mean(X, dim=1)` |
|
| 42 |
|
| 43 |
+
**Critical `.long()` fix**: The discrete model's `sample_p_zs_given_zt` calls `.type_as(y_t)` internally, which can cast collapsed integer indices to float. Both `_denoising_step` and `_collapse_final` must apply `.long()` on the discrete path to prevent `TypeError: list indices must be integers or slices, not float` in the rendering functions.
|
| 44 |
+
|
| 45 |
The refinement loop in `run_multiprox_step` uses `sample_p_zs_given_zt` to update `cur_X/E/y` at each step; only the first element of the tuple (or the single return value) is needed there. `_collapse_final` is called once at the end for rendering.
|
| 46 |
|
| 47 |
## Research Code Reused (read-only)
|
| 48 |
|
| 49 |
| Symbol | Location | How used |
|
| 50 |
|---|---|---|
|
| 51 |
+
| `DiscreteDenoisingDiffusion` | `src/diffusion_model_discrete.py` | Loaded via `_safe_load_lightning_checkpoint` for `model_type=discrete` |
|
| 52 |
+
| `LiftedDenoisingDiffusion` | `src/diffusion_model.py` | Loaded via `_safe_load_lightning_checkpoint` for `model_type=continuous` |
|
| 53 |
| `model.node_dist.sample_n(1, device)` | both models | Sample number of nodes |
|
| 54 |
+
| `diffusion_utils.sample_discrete_feature_noise(limit_dist, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise -- discrete only |
|
| 55 |
+
| `diffusion_utils.sample_feature_noise(X_size, E_size, y_size, node_mask)` | `src/diffusion/diffusion_utils.py` | Initial noise -- continuous only |
|
| 56 |
+
| `model.sample_p_zs_given_zt(s, t, X, E, y, node_mask)` | both models | One denoising step (return varies by model type -- see table above) |
|
| 57 |
| `model.apply_noise(X, E, y, node_mask, gibbs=True)` | both models | Re-apply Gibbs noise; uses `model.gibbs_fixed_t_2` internally |
|
| 58 |
+
| `PlaceHolder.mask(node_mask, collapse=True)` | `src/utils.py` | Final collapse -- discrete only |
|
| 59 |
+
| `model.sample_discrete_graph_given_z0(X, E, y, node_mask)` | `diffusion_model.py` | Final collapse -- continuous only |
|
| 60 |
+
| `utils.unnormalize(X, E, y, norm_values, norm_biases, node_mask, collapse=True)` | `src/utils.py` | Chain frame rendering -- continuous only |
|
| 61 |
+
| `model.norm_values`, `model.norm_biases` | `diffusion_model.py` | Unnormalization factors -- continuous only |
|
| 62 |
+
|
| 63 |
+
## Design Decisions
|
| 64 |
+
|
| 65 |
+
### Streaming Protocol: Server-Sent Events (SSE)
|
| 66 |
+
|
| 67 |
+
Both generate and continue endpoints use **SSE** (`text/event-stream`) instead of plain JSON or NDJSON. This enables real-time progress streaming with preview images in Postman's native SSE viewer.
|
| 68 |
+
|
| 69 |
+
**Event types:**
|
| 70 |
+
- `event: progress` -- metadata (phase, step, total_steps, elapsed_ms)
|
| 71 |
+
- `event: preview` -- base64 PNG image of the current graph state (separate event so Postman shows clean image updates)
|
| 72 |
+
- `event: result` -- final JSON payload (image, chain_gif, state, timing)
|
| 73 |
+
|
| 74 |
+
**SSE helpers** in `views/graph_generation.py`:
|
| 75 |
+
```python
|
| 76 |
+
def _streaming_sse_response(gen):
|
| 77 |
+
resp = StreamingHttpResponse(_sse_iter(gen), content_type="text/event-stream")
|
| 78 |
+
resp["Cache-Control"] = "no-cache"
|
| 79 |
+
resp["X-Accel-Buffering"] = "no" # nginx
|
| 80 |
+
return resp
|
| 81 |
+
|
| 82 |
+
def _sse_iter(gen):
|
| 83 |
+
for event in gen:
|
| 84 |
+
etype = event.get("type", "message")
|
| 85 |
+
preview = event.pop("preview", None)
|
| 86 |
+
yield f"event: {etype}\ndata: {json.dumps(event, separators=(',', ':'))}\n\n"
|
| 87 |
+
if preview:
|
| 88 |
+
yield f"event: preview\ndata: {preview}\n\n"
|
| 89 |
+
```
|
| 90 |
+
|
| 91 |
+
All three inference generators (`run_standard_generation`, `run_multiprox_init`, `run_multiprox_step`) yield progress dicts with optional `preview` keys, then a final `result` dict.
|
| 92 |
+
|
| 93 |
+
### Safe Checkpoint Loading: `_safe_load_lightning_checkpoint`
|
| 94 |
+
|
| 95 |
+
`load_from_checkpoint` must **not** be used for MultiProxAn models. The comm20 checkpoints were trained with DDP (DistributedDataParallel), causing three cascading failures:
|
| 96 |
+
|
| 97 |
+
1. **DDP `__setstate__`/`__getstate__`** -- unpickling checkpoint hparams tries to reconstruct DDP wrappers, which requires an active process group. Fix: monkey-patch `__setstate__` and `__getstate__` during `torch.load`.
|
| 98 |
+
2. **`save_hyperparameters` deepcopy crash** -- Lightning's `save_hyperparameters` deepcopies hparams containing DDP-wrapped datamodule objects, causing a hard process crash (exit code -1). Fix: patch `save_hyperparameters` to a no-op during model construction.
|
| 99 |
+
3. **CUDA OOM** -- the full checkpoint (with heavy pickled hparams objects) crashes if loaded directly to GPU. Fix: `map_location="cpu"` first, then `model.to(device)` after state_dict loading.
|
| 100 |
+
|
| 101 |
+
The solution in `registry.py`:
|
| 102 |
+
```python
|
| 103 |
+
def _safe_load_lightning_checkpoint(cls, ckpt_path):
|
| 104 |
+
# 1. Patch DDP setstate/getstate to simple dict update/return
|
| 105 |
+
# 2. torch.load(ckpt_path, map_location="cpu", weights_only=False)
|
| 106 |
+
# 3. Extract hyper_parameters, null out train_metrics/sampling_metrics/visualization_tools
|
| 107 |
+
# 4. Patch save_hyperparameters to no-op
|
| 108 |
+
# 5. Construct model via cls(**hparams)
|
| 109 |
+
# 6. load_state_dict(strict=False), del ckpt, model.to(device), model.eval()
|
| 110 |
+
```
|
| 111 |
+
|
| 112 |
+
### Graph Rendering: PIL + networkx (no matplotlib)
|
| 113 |
+
|
| 114 |
+
`_render_comm20` uses **PIL + networkx `spring_layout`** instead of matplotlib. Matplotlib's GUI backend initializes in the calling thread and crashes with exit code -1 when called from Django's request threads on Windows. Even with `matplotlib.use("Agg")`, the backend initialization races with Django's threading model.
|
| 115 |
+
|
| 116 |
+
```python
|
| 117 |
+
def _render_comm20(X_int, E_int):
|
| 118 |
+
import networkx as nx
|
| 119 |
+
from PIL import Image, ImageDraw
|
| 120 |
+
# Build networkx Graph from adjacency matrix
|
| 121 |
+
# spring_layout(G, seed=42) for deterministic positions
|
| 122 |
+
# PIL drawing: lines for edges, ellipses for nodes (#2ecc71 green, #1a7a42 outline)
|
| 123 |
+
# Returns PIL.Image (300x300)
|
| 124 |
+
```
|
| 125 |
+
|
| 126 |
+
QM9 rendering uses RDKit's `MolToImage` (no matplotlib dependency).
|
| 127 |
+
|
| 128 |
+
### Inference Lock Management
|
| 129 |
+
|
| 130 |
+
A single `threading.Lock` (`_inference_lock`) protects all inference endpoints (COINs predict, graphgen generate, graphgen continue). Design choices:
|
| 131 |
+
|
| 132 |
+
- **Non-blocking acquire** -- returns 429 `INFERENCE_BUSY` immediately if locked
|
| 133 |
+
- **Owner tracking** -- `_inference_lock_owner` stores a description string (e.g. `"graphgen_generate comm20/discrete/standard"`) for debugging
|
| 134 |
+
- **Lock held by generator** -- the lock is acquired eagerly (before the generator starts) and released in the generator's `finally` block, so it is held for the entire streaming duration
|
| 135 |
+
- **Force-release endpoint** -- `POST /debug/force-unlock` (DEBUG mode only) for stuck locks from killed clients. Also exposed in the health endpoint's `inference_lock` field
|
| 136 |
+
|
| 137 |
+
All three lock acquisition sites follow the same pattern:
|
| 138 |
+
```python
|
| 139 |
+
if not self._inference_lock.acquire(blocking=False):
|
| 140 |
+
raise InferenceBusy()
|
| 141 |
+
self._inference_lock_owner = f"<endpoint> <context>"
|
| 142 |
+
try:
|
| 143 |
+
model = self._load_<model>(...)
|
| 144 |
+
except Exception:
|
| 145 |
+
self._inference_lock_owner = None
|
| 146 |
+
self._inference_lock.release()
|
| 147 |
+
raise
|
| 148 |
+
|
| 149 |
+
def _gen():
|
| 150 |
+
try:
|
| 151 |
+
...yield events...
|
| 152 |
+
finally:
|
| 153 |
+
self._inference_lock_owner = None
|
| 154 |
+
self._inference_lock.release()
|
| 155 |
+
return _gen()
|
| 156 |
+
```
|
| 157 |
+
|
| 158 |
+
### Research Code Import Guards
|
| 159 |
+
|
| 160 |
+
Several research code modules import heavy optional dependencies (`graph_tool`, `pyemd`, `pygsp`, `dist_helper`) at module level. These are needed only for metric computation during training, not inference. Since they are transitively imported during checkpoint unpickling (via pickled datamodule/metric objects), they must be guarded with `try/except ImportError` to avoid crashing on systems where these packages are not installed.
|
| 161 |
+
|
| 162 |
+
Files to patch: `spectre_utils.py`, `molecular_metrics.py`, `molecular_metrics_discrete.py`, `train_metrics.py`.
|
| 163 |
|
| 164 |
## Implementation
|
| 165 |
|
| 166 |
+
### Step 0: Patch research code -- remove `wandb`
|
| 167 |
|
| 168 |
+
`wandb` is imported at module level in three files; the import alone causes `ImportError` at checkpoint load time. One call in `diffusion_model.py` is **unguarded** (no `if wandb.run:` check) and would crash at inference even if the import somehow succeeded.
|
|
|
|
|
|
|
| 169 |
|
| 170 |
**`src/research/MultiProxAn/src/diffusion_model_discrete.py`**:
|
| 171 |
- Remove `import wandb` (line 9)
|
| 172 |
+
- Remove the two `utils.setup_wandb(self.cfg)` call sites
|
| 173 |
+
- Remove all `if wandb.run: wandb.log(...)` / `wandb.run.summary[...]` blocks
|
| 174 |
|
| 175 |
**`src/research/MultiProxAn/src/diffusion_model.py`**:
|
| 176 |
- Remove `import wandb` (line 9)
|
| 177 |
- Remove the two `utils.setup_wandb(self.cfg)` call sites
|
| 178 |
+
- Remove all guarded `if wandb.run:` blocks + one unguarded `wandb.log(...)` call
|
|
|
|
| 179 |
|
| 180 |
**`src/research/MultiProxAn/src/utils.py`**:
|
| 181 |
+
- Remove `import wandb` and the entire `setup_wandb(cfg)` function
|
|
|
|
|
|
|
|
|
|
| 182 |
|
| 183 |
---
|
| 184 |
|
| 185 |
+
### Step 0b: Guard heavy optional imports in research code
|
| 186 |
+
|
| 187 |
+
Wrap `graph_tool`, `pyemd`, `pygsp`, and `dist_helper` imports in `try/except` in:
|
| 188 |
+
- `spectre_utils.py` -- all four imports
|
| 189 |
+
- `molecular_metrics.py`, `molecular_metrics_discrete.py`, `train_metrics.py` -- cascading metric deps
|
| 190 |
|
| 191 |
+
---
|
| 192 |
+
|
| 193 |
+
### Step 1: `settings.py` -- sys.path fix
|
| 194 |
|
| 195 |
+
Both `MultiProxAn/` and `MultiProxAn/src/` must be on `sys.path`:
|
| 196 |
```python
|
| 197 |
_MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
|
| 198 |
for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
|
|
|
|
| 204 |
|
| 205 |
### Step 2: New `graphgen_inference.py`
|
| 206 |
|
| 207 |
+
Completely independent of Django/registry. Receives a loaded model and yields progress/result dicts.
|
| 208 |
|
| 209 |
**Module structure:**
|
| 210 |
|
| 211 |
```
|
| 212 |
graphgen_inference.py
|
| 213 |
+
+-- Constants: QM9_ATOM_TYPES, STATE_BLOB_MAX_BYTES, REQUIRED_STATE_KEYS
|
| 214 |
+
|
|
| 215 |
+
+-- # Model-type helpers
|
| 216 |
+
+-- _is_discrete(model) -> bool
|
| 217 |
+
+-- _build_node_mask(n_nodes, n_max, model) -> bool or float32 tensor
|
| 218 |
+
+-- _sample_initial_noise(model, n_max, node_mask) -> PlaceHolder
|
| 219 |
+
+-- _denoising_step(model, s_t, t_t, X, E, y, node_mask) -> (X_soft, E_soft, y_soft, X_int, E_int)
|
| 220 |
+
+-- _gibbs_aggregate(model, X) -> tensor [median (discrete) or mean (continuous)]
|
| 221 |
+
+-- _collapse_final(model, X, E, y, node_mask) -> (X_int, E_int)
|
| 222 |
+
|
|
| 223 |
+
+-- # Main inference generators (yield progress dicts, then result dict)
|
| 224 |
+
+-- run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id)
|
| 225 |
+
+-- run_multiprox_init(model, num_nodes, n, m, t, t_prime, gibbs_chain_freq, dataset_id)
|
| 226 |
+
+-- run_multiprox_step(model, state_dict, dataset_id)
|
| 227 |
+
|
|
| 228 |
+
+-- # State serialisation
|
| 229 |
+
+-- encode_state_blob(state_dict) -> str
|
| 230 |
+
+-- decode_state_blob(b64_str) -> dict
|
| 231 |
+
|
|
| 232 |
+
+-- # Visualisation
|
| 233 |
+
+-- render_graph(X_int, E_int, dataset_id) -> PIL.Image
|
| 234 |
+
+-- _render_qm9(X_int, E_int) -> PIL.Image [RDKit]
|
| 235 |
+
+-- _render_comm20(X_int, E_int) -> PIL.Image [PIL + networkx spring_layout]
|
| 236 |
+
+-- _pil_to_b64(img) -> str
|
| 237 |
+
+-- _frames_to_gif_b64(frames) -> str
|
| 238 |
```
|
| 239 |
|
| 240 |
#### Model-type helpers
|
| 241 |
|
| 242 |
+
`_denoising_step` returns `(X_soft, E_soft, y_soft, X_int, E_int)`. For discrete models, `.long()` is applied to `discrete_s.X` and `discrete_s.E` to counteract the `.type_as(y_t)` float cast. For continuous models, `utils.unnormalize` with `collapse=True` produces the integer tensors.
|
| 243 |
+
|
| 244 |
+
`_collapse_final` also applies `.long()` on the discrete path for the same reason.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
#### `run_standard_generation`
|
| 247 |
|
| 248 |
+
Re-implements the denoising loop -- does NOT call `sample_batch`. Uses `s_idx / diffusion_steps` for normalized time. Yields `progress` events at each step (with `preview` at frame intervals) and a final `result` event with `image` and `chain_gif`.
|
| 249 |
|
| 250 |
```python
|
| 251 |
def run_standard_generation(model, num_nodes, diffusion_steps, chain_frames, dataset_id):
|
|
|
|
| 268 |
step = diffusion_steps - 1 - s_idx
|
| 269 |
if step % frame_interval == 0 or s_idx == 0:
|
| 270 |
gif_frames.append(render_graph(X_int[0, :n_max], E_int[0, :n_max, :n_max], dataset_id))
|
| 271 |
+
yield {"type": "progress", "phase": "denoise", "step": step + 1, ...}
|
| 272 |
|
| 273 |
X_final, E_final = _collapse_final(model, X, E, y, node_mask)
|
| 274 |
+
yield {"type": "result", "image": ..., "chain_gif": ..., "inference_time_ms": ...}
|
|
|
|
| 275 |
```
|
| 276 |
|
| 277 |
+
#### API parameter -> model attribute mapping
|
| 278 |
|
| 279 |
| API param | Model attribute | Role |
|
| 280 |
|---|---|---|
|
| 281 |
| `t` | `model.gibbs_fixed_t_2` | Fixed noise level for the Gibbs chain; used by `apply_noise(gibbs=True)` and as `fixed_t_norm` in the inner denoising step |
|
| 282 |
+
| `t_prime` | `model.gibbs_fixed_t_1` | Refinement target; `P = int((gibbs_fixed_t_2 - gibbs_fixed_t_1) * T) + 1` steps denoise from `t` down to `t_prime` |
|
| 283 |
| `gibbs_chain_freq` | `model.gibbs_chain_freq` | Inner Gibbs steps per `/continue` call (same attribute name in checkpoint) |
|
| 284 |
| `m` | `model.gibbs_M` | Ensemble size |
|
| 285 |
+
| `n` | `model.gibbs_N` | Number of outer Gibbs rounds; session is complete when `step == n` |
|
| 286 |
|
| 287 |
+
All parameters are always explicit in the API request -- no fallback to checkpoint values. The `apply_noise(gibbs=True)` call reads `self.gibbs_fixed_t_2` directly, so we must set `model.gibbs_fixed_t_2 = t` before calling it (and restore afterwards). This is safe because `_inference_lock` ensures single-threaded access.
|
| 288 |
|
| 289 |
#### `run_multiprox_init`
|
| 290 |
|
| 291 |
+
Initialises the M-member ensemble. Yields `progress` events during noise sampling, then a `result` event with the step-0 image and state dict. Default `gibbs_chain_freq = max(1, m // 10)`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 292 |
|
| 293 |
#### `run_multiprox_step`
|
| 294 |
|
| 295 |
+
Runs `gibbs_chain_freq` inner Gibbs iterations. Yields `progress` events during Gibbs phase (with preview of ensemble aggregate) and refinement phase (t -> t_prime denoising). The refinement pass always runs to produce a clean render. Returns `round_complete` and `done` flags.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 296 |
|
| 297 |
**State blob (encode/decode):**
|
| 298 |
|
|
|
|
| 301 |
REQUIRED_STATE_KEYS = {"X", "E", "y", "n_nodes", "dataset_id", "model_type", "T",
|
| 302 |
"n", "m", "t", "t_prime", "gibbs_chain_freq", "inner_step", "step"}
|
| 303 |
|
| 304 |
+
# encode: torch.save(state_dict, BytesIO) -> base64 string
|
| 305 |
+
# decode: base64 -> bytes -> torch.load(weights_only=False) -> validate keys + X.dim()==4, E.dim()==5
|
|
|
|
|
|
|
| 306 |
```
|
| 307 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 308 |
---
|
| 309 |
|
| 310 |
### Step 3: `registry.py` additions
|
| 311 |
|
| 312 |
+
**`_safe_load_lightning_checkpoint(cls, ckpt_path)`** -- top-level function, used by `_load_graphgen_model`.
|
| 313 |
|
| 314 |
+
**New instance variables** in `ModelRegistry.__init__`:
|
| 315 |
```python
|
| 316 |
+
self._inference_lock = threading.Lock()
|
| 317 |
+
self._inference_lock_owner = None # description string for debugging
|
| 318 |
+
self._graphgen_models = {} # (dataset_id, model_type) -> loaded eval-mode model
|
| 319 |
```
|
| 320 |
|
| 321 |
+
**`force_release_inference_lock(self)`** -- emergency release for stuck locks.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 322 |
|
| 323 |
+
**`_load_graphgen_model(self, dataset_id, model_type)`** -- lazy load with caching. Uses `_safe_load_lightning_checkpoint` instead of `load_from_checkpoint`.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
|
| 325 |
+
**`graphgen_generate_stream(self, ...)`** -- acquires lock, loads model, returns a generator. The generator yields SSE-compatible dicts and releases the lock in `finally`. Encodes state blob for multiprox mode.
|
|
|
|
| 326 |
|
| 327 |
+
**`graphgen_continue_stream(self, state_b64)`** -- decodes state blob eagerly (fail-fast before lock), acquires lock, loads model, returns generator.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 328 |
|
| 329 |
---
|
| 330 |
|
| 331 |
### Step 4: `views/graph_generation.py` additions
|
| 332 |
|
| 333 |
+
**`GraphGenGenerateView.post`** -- validation + dispatch:
|
| 334 |
- Require `dataset_id`, `model_type`, `sampling_mode`
|
| 335 |
+
- `dataset_id` in `GRAPHGEN_DATASETS`, `model_type` in `{"discrete","continuous"}`, `sampling_mode` in `{"standard","multiprox"}`
|
| 336 |
+
- Check `registry.graphgen_checkpoints_available` contains `model_type` -> `ModelUnavailable(503)` if not
|
| 337 |
- `diffusion_steps` clamped to [50, 1000], `chain_frames` to [10, 30]
|
| 338 |
- `num_nodes`: optional int, validated against `GRAPHGEN_DATASETS[dataset_id]["max_nodes"]`
|
| 339 |
+
- multiprox only: `multiprox_params` required; validate `0 < t_prime <= t <= 1`, `2 <= m <= 100`, `n >= 1`, `1 <= gibbs_chain_freq <= m`
|
| 340 |
+
- Return `_streaming_sse_response(gen)`
|
| 341 |
|
| 342 |
+
**`GraphGenContinueView.post`** -- minimal validation:
|
| 343 |
- Require `state` (non-empty string)
|
| 344 |
+
- Return `_streaming_sse_response(gen)`
|
| 345 |
+
|
| 346 |
+
**SSE helpers** (`_streaming_sse_response`, `_sse_iter`) -- convert generator dicts to SSE event stream with anti-buffering headers. Preview images split into separate `event: preview` events.
|
| 347 |
|
| 348 |
---
|
| 349 |
|
| 350 |
+
### Step 5: `views/health.py` additions
|
| 351 |
|
| 352 |
+
**`HealthView`** -- include `inference_lock` status (`locked`, `owner`) in response.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 353 |
|
| 354 |
+
**`ForceUnlockView`** -- DEBUG-only POST endpoint at `debug/force-unlock`.
|
|
|
|
|
|
|
|
|
|
| 355 |
|
| 356 |
---
|
| 357 |
|
| 358 |
+
### Step 6: `urls.py`
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 359 |
|
| 360 |
+
```python
|
| 361 |
+
path("graph-generation/generate", GraphGenGenerateView.as_view()),
|
| 362 |
+
path("graph-generation/continue", GraphGenContinueView.as_view()),
|
| 363 |
+
path("debug/force-unlock", ForceUnlockView.as_view()),
|
| 364 |
+
```
|
| 365 |
|
| 366 |
---
|
| 367 |
|
| 368 |
+
### Step 7: `requirements.txt` -- new dependencies
|
| 369 |
|
| 370 |
+
| Package | Why needed |
|
| 371 |
+
|---|---|
|
| 372 |
+
| `Pillow` | PIL rendering in `graphgen_inference.py` |
|
| 373 |
+
| `overrides` | Transitive research code dependency |
|
| 374 |
|
| 375 |
+
`wandb` is **not** added -- stripped from research code.
|
| 376 |
|
| 377 |
+
---
|
|
|
|
|
|
|
|
|
|
|
|
|
| 378 |
|
| 379 |
+
### Step 8: README + Postman
|
|
|
|
|
|
|
|
|
|
|
|
|
| 380 |
|
| 381 |
+
**`README.md`**: Document SSE streaming protocol with three event types. Add force-unlock endpoint to endpoint table. Update health endpoint description to include lock status.
|
|
|
|
|
|
|
|
|
|
|
|
|
| 382 |
|
| 383 |
+
**`docs/postman/collection.json`**:
|
| 384 |
+
- 12 graph-generation requests (4 standard generate, 4 multiprox init, 4 continue variants consolidated into 1 using `{{multiprox_state}}`)
|
| 385 |
+
- **Auto-chaining test scripts**: all 4 multiprox init requests have post-response scripts that parse the SSE `result` event and save the `state` field to the `multiprox_state` collection variable. The consolidated continue endpoint reads `{{multiprox_state}}` and also updates it from its result, enabling repeated continue calls.
|
| 386 |
|
| 387 |
---
|
| 388 |
|
|
|
|
| 390 |
|
| 391 |
| Risk | Mitigation |
|
| 392 |
|---|---|
|
| 393 |
+
| `weights_only=False` in `torch.load` for state blobs | State dict originates from our own server response. The blob is opaque base64 from the client but was generated server-side via `encode_state_blob`. Size limit (10 MB) and key validation provide basic bounds checking. |
|
| 394 |
+
| `_safe_load_lightning_checkpoint` bypasses Lightning's loading | Manual load + `load_state_dict(strict=False)` is correct for inference. Missing keys (if any) default to initialized values. |
|
| 395 |
| `model.gibbs_fixed_t_2` attribute override | Only safe because `_inference_lock` ensures single-threaded inference. Save/restore pattern is used. |
|
| 396 |
+
| SSE connection dropped mid-stream | Generator's `finally` block releases the lock. If the generator is garbage-collected without completing, Python still runs the finally block. |
|
| 397 |
+
| Research code imports heavy optional deps at module level | Guarded with `try/except ImportError` in spectre_utils.py, molecular_metrics*.py, train_metrics.py. |
|
| 398 |
|
| 399 |
## Verification
|
| 400 |
|
| 401 |
+
1. Django shell: `from diffusion_model_discrete import DiscreteDenoisingDiffusion` -- no `ModuleNotFoundError`
|
| 402 |
+
2. `GET /graph-generation/datasets` returns `available_model_types` (regression check)
|
| 403 |
+
3. All 4 standard generate combinations (qm9/comm20 x discrete/continuous) return SSE streams with progress, preview, and result events
|
| 404 |
+
4. All 4 multiprox init combinations return step-0 image and state blob
|
| 405 |
+
5. Continue endpoint chains correctly -- state updates across init -> continue x 3
|
| 406 |
+
6. Force-unlock endpoint releases stuck locks in DEBUG mode
|
| 407 |
+
7. Health endpoint shows lock status and owner
|
| 408 |
+
8. Error paths: unknown `dataset_id` -> 400, corrupted `state` -> 400, concurrent requests -> 429, `t_prime > t` -> 400
|
docs/postman/collection.json
CHANGED
|
@@ -8,6 +8,10 @@
|
|
| 8 |
{
|
| 9 |
"key": "base_url",
|
| 10 |
"value": "http://localhost:8000/api/v1"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
}
|
| 12 |
],
|
| 13 |
"item": [
|
|
@@ -52,6 +56,24 @@
|
|
| 52 |
},
|
| 53 |
"description": "List the 3 research methods."
|
| 54 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 55 |
}
|
| 56 |
]
|
| 57 |
},
|
|
@@ -362,60 +384,201 @@
|
|
| 362 |
"name": "Graph Generation - Inference",
|
| 363 |
"item": [
|
| 364 |
{
|
| 365 |
-
"name": "POST /graph-generation/generate (standard)",
|
| 366 |
"request": {
|
| 367 |
"method": "POST",
|
| 368 |
-
"header": [
|
| 369 |
-
{ "key": "Content-Type", "value": "application/json" }
|
| 370 |
-
],
|
| 371 |
"body": {
|
| 372 |
"mode": "raw",
|
| 373 |
-
"raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\":
|
| 374 |
},
|
| 375 |
"url": {
|
| 376 |
"raw": "{{base_url}}/graph-generation/generate",
|
| 377 |
"host": ["{{base_url}}"],
|
| 378 |
"path": ["graph-generation", "generate"]
|
| 379 |
},
|
| 380 |
-
"description": "Standard denoising
|
| 381 |
}
|
| 382 |
},
|
| 383 |
{
|
| 384 |
-
"name": "POST /graph-generation/generate (
|
| 385 |
"request": {
|
| 386 |
"method": "POST",
|
| 387 |
-
"header": [
|
| 388 |
-
|
| 389 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 390 |
"body": {
|
| 391 |
"mode": "raw",
|
| 392 |
-
"raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"
|
| 393 |
},
|
| 394 |
"url": {
|
| 395 |
"raw": "{{base_url}}/graph-generation/generate",
|
| 396 |
"host": ["{{base_url}}"],
|
| 397 |
"path": ["graph-generation", "generate"]
|
| 398 |
},
|
| 399 |
-
"description": "MultiProx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
}
|
| 401 |
},
|
| 402 |
{
|
| 403 |
"name": "POST /graph-generation/continue",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 404 |
"request": {
|
| 405 |
"method": "POST",
|
| 406 |
-
"header": [
|
| 407 |
-
{ "key": "Content-Type", "value": "application/json" }
|
| 408 |
-
],
|
| 409 |
"body": {
|
| 410 |
"mode": "raw",
|
| 411 |
-
"raw": "{\n \"state\": \"
|
| 412 |
},
|
| 413 |
"url": {
|
| 414 |
"raw": "{{base_url}}/graph-generation/continue",
|
| 415 |
"host": ["{{base_url}}"],
|
| 416 |
"path": ["graph-generation", "continue"]
|
| 417 |
},
|
| 418 |
-
"description": "Advance MultiProx chain by
|
| 419 |
}
|
| 420 |
}
|
| 421 |
]
|
|
|
|
| 8 |
{
|
| 9 |
"key": "base_url",
|
| 10 |
"value": "http://localhost:8000/api/v1"
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"key": "multiprox_state",
|
| 14 |
+
"value": ""
|
| 15 |
}
|
| 16 |
],
|
| 17 |
"item": [
|
|
|
|
| 56 |
},
|
| 57 |
"description": "List the 3 research methods."
|
| 58 |
}
|
| 59 |
+
},
|
| 60 |
+
{
|
| 61 |
+
"name": "POST /debug/force-unlock",
|
| 62 |
+
"request": {
|
| 63 |
+
"method": "POST",
|
| 64 |
+
"header": [
|
| 65 |
+
{
|
| 66 |
+
"key": "Content-Type",
|
| 67 |
+
"value": "application/json"
|
| 68 |
+
}
|
| 69 |
+
],
|
| 70 |
+
"url": {
|
| 71 |
+
"raw": "{{base_url}}/debug/force-unlock",
|
| 72 |
+
"host": ["{{base_url}}"],
|
| 73 |
+
"path": ["debug", "force-unlock"]
|
| 74 |
+
},
|
| 75 |
+
"description": "Force-release a stuck inference lock. Only available in DEBUG mode."
|
| 76 |
+
}
|
| 77 |
}
|
| 78 |
]
|
| 79 |
},
|
|
|
|
| 384 |
"name": "Graph Generation - Inference",
|
| 385 |
"item": [
|
| 386 |
{
|
| 387 |
+
"name": "POST /graph-generation/generate (standard, QM9, discrete)",
|
| 388 |
"request": {
|
| 389 |
"method": "POST",
|
| 390 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
|
|
|
|
|
|
| 391 |
"body": {
|
| 392 |
"mode": "raw",
|
| 393 |
+
"raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
|
| 394 |
},
|
| 395 |
"url": {
|
| 396 |
"raw": "{{base_url}}/graph-generation/generate",
|
| 397 |
"host": ["{{base_url}}"],
|
| 398 |
"path": ["graph-generation", "generate"]
|
| 399 |
},
|
| 400 |
+
"description": "Standard denoising on QM9 molecules (discrete model). Returns final PNG image + chain GIF."
|
| 401 |
}
|
| 402 |
},
|
| 403 |
{
|
| 404 |
+
"name": "POST /graph-generation/generate (standard, QM9, continuous)",
|
| 405 |
"request": {
|
| 406 |
"method": "POST",
|
| 407 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
| 408 |
+
"body": {
|
| 409 |
+
"mode": "raw",
|
| 410 |
+
"raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
|
| 411 |
+
},
|
| 412 |
+
"url": {
|
| 413 |
+
"raw": "{{base_url}}/graph-generation/generate",
|
| 414 |
+
"host": ["{{base_url}}"],
|
| 415 |
+
"path": ["graph-generation", "generate"]
|
| 416 |
+
},
|
| 417 |
+
"description": "Standard denoising on QM9 molecules (continuous/lifted model). Returns final PNG image + chain GIF."
|
| 418 |
+
}
|
| 419 |
+
},
|
| 420 |
+
{
|
| 421 |
+
"name": "POST /graph-generation/generate (standard, comm20, discrete)",
|
| 422 |
+
"request": {
|
| 423 |
+
"method": "POST",
|
| 424 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
| 425 |
+
"body": {
|
| 426 |
+
"mode": "raw",
|
| 427 |
+
"raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
|
| 428 |
+
},
|
| 429 |
+
"url": {
|
| 430 |
+
"raw": "{{base_url}}/graph-generation/generate",
|
| 431 |
+
"host": ["{{base_url}}"],
|
| 432 |
+
"path": ["graph-generation", "generate"]
|
| 433 |
+
},
|
| 434 |
+
"description": "Standard denoising on Community20 graphs (discrete model). Returns final PNG image + chain GIF."
|
| 435 |
+
}
|
| 436 |
+
},
|
| 437 |
+
{
|
| 438 |
+
"name": "POST /graph-generation/generate (standard, comm20, continuous)",
|
| 439 |
+
"request": {
|
| 440 |
+
"method": "POST",
|
| 441 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
| 442 |
+
"body": {
|
| 443 |
+
"mode": "raw",
|
| 444 |
+
"raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"standard\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"chain_frames\": 30\n}"
|
| 445 |
+
},
|
| 446 |
+
"url": {
|
| 447 |
+
"raw": "{{base_url}}/graph-generation/generate",
|
| 448 |
+
"host": ["{{base_url}}"],
|
| 449 |
+
"path": ["graph-generation", "generate"]
|
| 450 |
+
},
|
| 451 |
+
"description": "Standard denoising on Community20 graphs (continuous/lifted model). Returns final PNG image + chain GIF."
|
| 452 |
+
}
|
| 453 |
+
},
|
| 454 |
+
{
|
| 455 |
+
"name": "POST /graph-generation/generate (multiprox init, QM9, discrete)",
|
| 456 |
+
"event": [
|
| 457 |
+
{
|
| 458 |
+
"listen": "test",
|
| 459 |
+
"script": {
|
| 460 |
+
"type": "text/javascript",
|
| 461 |
+
"exec": ["// Extract state from the SSE result event and store as collection variable", "var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) {", " pm.collectionVariables.set('multiprox_state', result.state);", " console.log('State saved to {{multiprox_state}} (' + result.state.length + ' chars)');", " }", " } catch (e) { console.log('Failed to parse result event: ' + e); }", " break;", " }", "}"]
|
| 462 |
+
}
|
| 463 |
+
}
|
| 464 |
+
],
|
| 465 |
+
"request": {
|
| 466 |
+
"method": "POST",
|
| 467 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
| 468 |
+
"body": {
|
| 469 |
+
"mode": "raw",
|
| 470 |
+
"raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.5,\n \"t_prime\": 0.004,\n \"gibbs_chain_freq\": 10\n }\n}"
|
| 471 |
+
},
|
| 472 |
+
"url": {
|
| 473 |
+
"raw": "{{base_url}}/graph-generation/generate",
|
| 474 |
+
"host": ["{{base_url}}"],
|
| 475 |
+
"path": ["graph-generation", "generate"]
|
| 476 |
+
},
|
| 477 |
+
"description": "MultiProx Gibbs init on QM9 (discrete). Best params from thesis Table 4.3.1: t=50%, t'=0.4% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
|
| 478 |
+
}
|
| 479 |
+
},
|
| 480 |
+
{
|
| 481 |
+
"name": "POST /graph-generation/generate (multiprox init, QM9, continuous)",
|
| 482 |
+
"event": [
|
| 483 |
+
{
|
| 484 |
+
"listen": "test",
|
| 485 |
+
"script": {
|
| 486 |
+
"type": "text/javascript",
|
| 487 |
+
"exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
|
| 488 |
+
}
|
| 489 |
+
}
|
| 490 |
+
],
|
| 491 |
+
"request": {
|
| 492 |
+
"method": "POST",
|
| 493 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
| 494 |
+
"body": {
|
| 495 |
+
"mode": "raw",
|
| 496 |
+
"raw": "{\n \"dataset_id\": \"qm9\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.5,\n \"t_prime\": 0.004,\n \"gibbs_chain_freq\": 10\n }\n}"
|
| 497 |
+
},
|
| 498 |
+
"url": {
|
| 499 |
+
"raw": "{{base_url}}/graph-generation/generate",
|
| 500 |
+
"host": ["{{base_url}}"],
|
| 501 |
+
"path": ["graph-generation", "generate"]
|
| 502 |
+
},
|
| 503 |
+
"description": "MultiProx Gibbs init on QM9 (continuous/lifted). Best params from thesis Table 4.3.1: t=50%, t'=0.4% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
|
| 504 |
+
}
|
| 505 |
+
},
|
| 506 |
+
{
|
| 507 |
+
"name": "POST /graph-generation/generate (multiprox init, comm20, discrete)",
|
| 508 |
+
"event": [
|
| 509 |
+
{
|
| 510 |
+
"listen": "test",
|
| 511 |
+
"script": {
|
| 512 |
+
"type": "text/javascript",
|
| 513 |
+
"exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
|
| 514 |
+
}
|
| 515 |
+
}
|
| 516 |
+
],
|
| 517 |
+
"request": {
|
| 518 |
+
"method": "POST",
|
| 519 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
| 520 |
"body": {
|
| 521 |
"mode": "raw",
|
| 522 |
+
"raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"discrete\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.4,\n \"t_prime\": 0.1,\n \"gibbs_chain_freq\": 10\n }\n}"
|
| 523 |
},
|
| 524 |
"url": {
|
| 525 |
"raw": "{{base_url}}/graph-generation/generate",
|
| 526 |
"host": ["{{base_url}}"],
|
| 527 |
"path": ["graph-generation", "generate"]
|
| 528 |
},
|
| 529 |
+
"description": "MultiProx Gibbs init on Community20 (discrete). Best params from thesis Table C.2.1: t=40%, t'=10% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
|
| 530 |
+
}
|
| 531 |
+
},
|
| 532 |
+
{
|
| 533 |
+
"name": "POST /graph-generation/generate (multiprox init, comm20, continuous)",
|
| 534 |
+
"event": [
|
| 535 |
+
{
|
| 536 |
+
"listen": "test",
|
| 537 |
+
"script": {
|
| 538 |
+
"type": "text/javascript",
|
| 539 |
+
"exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
|
| 540 |
+
}
|
| 541 |
+
}
|
| 542 |
+
],
|
| 543 |
+
"request": {
|
| 544 |
+
"method": "POST",
|
| 545 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
| 546 |
+
"body": {
|
| 547 |
+
"mode": "raw",
|
| 548 |
+
"raw": "{\n \"dataset_id\": \"comm20\",\n \"model_type\": \"continuous\",\n \"sampling_mode\": \"multiprox\",\n \"num_nodes\": null,\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.4,\n \"t_prime\": 0.1,\n \"gibbs_chain_freq\": 10\n }\n}"
|
| 549 |
+
},
|
| 550 |
+
"url": {
|
| 551 |
+
"raw": "{{base_url}}/graph-generation/generate",
|
| 552 |
+
"host": ["{{base_url}}"],
|
| 553 |
+
"path": ["graph-generation", "generate"]
|
| 554 |
+
},
|
| 555 |
+
"description": "MultiProx Gibbs init on Community20 (continuous/lifted). Best params from thesis Table C.2.1: t=40%, t'=10% of T. Returns step 0 image + state blob. State is auto-saved to {{multiprox_state}}."
|
| 556 |
}
|
| 557 |
},
|
| 558 |
{
|
| 559 |
"name": "POST /graph-generation/continue",
|
| 560 |
+
"event": [
|
| 561 |
+
{
|
| 562 |
+
"listen": "test",
|
| 563 |
+
"script": {
|
| 564 |
+
"type": "text/javascript",
|
| 565 |
+
"exec": ["// Update state for chaining multiple continue calls", "var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) {", " pm.collectionVariables.set('multiprox_state', result.state);", " console.log('State updated (done=' + result.done + ', step=' + result.step + ')');", " }", " } catch (e) {}", " break;", " }", "}"]
|
| 566 |
+
}
|
| 567 |
+
}
|
| 568 |
+
],
|
| 569 |
"request": {
|
| 570 |
"method": "POST",
|
| 571 |
+
"header": [{ "key": "Content-Type", "value": "application/json" }],
|
|
|
|
|
|
|
| 572 |
"body": {
|
| 573 |
"mode": "raw",
|
| 574 |
+
"raw": "{\n \"state\": \"{{multiprox_state}}\"\n}"
|
| 575 |
},
|
| 576 |
"url": {
|
| 577 |
"raw": "{{base_url}}/graph-generation/continue",
|
| 578 |
"host": ["{{base_url}}"],
|
| 579 |
"path": ["graph-generation", "continue"]
|
| 580 |
},
|
| 581 |
+
"description": "Advance MultiProx chain by gibbs_chain_freq inner steps. Uses {{multiprox_state}} from the last init/continue call. Can be fired repeatedly to chain steps."
|
| 582 |
}
|
| 583 |
}
|
| 584 |
]
|
src/backend/README.md
CHANGED
|
@@ -45,6 +45,7 @@ The API is served at `http://localhost:8000/api/v1/`.
|
|
| 45 |
| `DJANGO_SECRET_KEY` | `dev-insecure-key-change-in-production` | Django secret key. **Set in production.** |
|
| 46 |
| `DJANGO_DEBUG` | `True` | Enable debug mode. Set to `False` in production. |
|
| 47 |
| `DJANGO_ALLOWED_HOSTS` | `localhost,127.0.0.1` | Comma-separated allowed hosts. |
|
|
|
|
| 48 |
|
| 49 |
## Startup Sequence
|
| 50 |
|
|
@@ -65,8 +66,9 @@ All endpoints are prefixed with `/api/v1/`.
|
|
| 65 |
|
| 66 |
| Method | Path | Description |
|
| 67 |
|---|---|---|
|
| 68 |
-
| `GET` | `/health` | Service health + model availability |
|
| 69 |
| `GET` | `/methods` | List the 3 research methods |
|
|
|
|
| 70 |
|
| 71 |
### COINs β KG Reasoning
|
| 72 |
|
|
@@ -86,8 +88,8 @@ All endpoints are prefixed with `/api/v1/`.
|
|
| 86 |
|---|---|---|
|
| 87 |
| `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
|
| 88 |
| `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
|
| 89 |
-
| `POST` | `/graph-generation/generate` | Generate a graph (
|
| 90 |
-
| `POST` | `/graph-generation/continue` |
|
| 91 |
|
| 92 |
### KG Anomaly Correction
|
| 93 |
|
|
@@ -98,6 +100,32 @@ All endpoints are prefixed with `/api/v1/`.
|
|
| 98 |
| `POST` | `/kg-anomaly/correct` | Run correction (not yet implemented) |
|
| 99 |
| `POST` | `/kg-anomaly/continue` | Continue MultiProx correction (not yet implemented) |
|
| 100 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
## Project Structure
|
| 102 |
|
| 103 |
```
|
|
|
|
| 45 |
| `DJANGO_SECRET_KEY` | `dev-insecure-key-change-in-production` | Django secret key. **Set in production.** |
|
| 46 |
| `DJANGO_DEBUG` | `True` | Enable debug mode. Set to `False` in production. |
|
| 47 |
| `DJANGO_ALLOWED_HOSTS` | `localhost,127.0.0.1` | Comma-separated allowed hosts. |
|
| 48 |
+
| `TORCH_DEVICE` | Auto (`cuda:0` if available, else `cpu`) | PyTorch device for model inference. |
|
| 49 |
|
| 50 |
## Startup Sequence
|
| 51 |
|
|
|
|
| 66 |
|
| 67 |
| Method | Path | Description |
|
| 68 |
|---|---|---|
|
| 69 |
+
| `GET` | `/health` | Service health + model availability + inference lock status |
|
| 70 |
| `GET` | `/methods` | List the 3 research methods |
|
| 71 |
+
| `POST` | `/debug/force-unlock` | Release stuck inference lock (debug mode only) |
|
| 72 |
|
| 73 |
### COINs β KG Reasoning
|
| 74 |
|
|
|
|
| 88 |
|---|---|---|
|
| 89 |
| `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
|
| 90 |
| `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
|
| 91 |
+
| `POST` | `/graph-generation/generate` | **Streaming NDJSON.** Generate a graph (standard denoising or MultiProx Gibbs init) |
|
| 92 |
+
| `POST` | `/graph-generation/continue` | **Streaming NDJSON.** Advance a MultiProx Gibbs session by one step |
|
| 93 |
|
| 94 |
### KG Anomaly Correction
|
| 95 |
|
|
|
|
| 100 |
| `POST` | `/kg-anomaly/correct` | Run correction (not yet implemented) |
|
| 101 |
| `POST` | `/kg-anomaly/continue` | Continue MultiProx correction (not yet implemented) |
|
| 102 |
|
| 103 |
+
## Streaming Inference Protocol (SSE)
|
| 104 |
+
|
| 105 |
+
The graph generation endpoints (`/generate`, `/continue`) return **Server-Sent Events** (`text/event-stream`). Three event types are emitted:
|
| 106 |
+
|
| 107 |
+
**`event: progress`** β phase/step metadata (no images):
|
| 108 |
+
```
|
| 109 |
+
event: progress
|
| 110 |
+
data: {"type":"progress","phase":"denoise","step":42,"total_steps":500,"elapsed_ms":2100}
|
| 111 |
+
```
|
| 112 |
+
|
| 113 |
+
**`event: preview`** β base64 PNG of the graph's current state, emitted at key frames:
|
| 114 |
+
```
|
| 115 |
+
event: preview
|
| 116 |
+
data: data:image/png;base64,...
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
Preview frequency: `denoise` emits at `chain_frames` intervals (~30 over 500 steps), `gibbs` emits every inner step, `refine` emits every ~10% of steps.
|
| 120 |
+
|
| 121 |
+
**`event: result`** β final payload with image, chain GIF, and timing:
|
| 122 |
+
```
|
| 123 |
+
event: result
|
| 124 |
+
data: {"type":"result","dataset_id":"qm9","model_type":"discrete","sampling_mode":"standard","image":"data:image/png;base64,...","chain_gif":"data:image/gif;base64,...","inference_time_ms":25000}
|
| 125 |
+
```
|
| 126 |
+
|
| 127 |
+
Phases: `denoise` (standard generation loop), `noise_init` (multiprox init noise sampling), `gibbs` (multiprox inner Gibbs steps), `refine` (multiprox refinement denoising).
|
| 128 |
+
|
| 129 |
## Project Structure
|
| 130 |
|
| 131 |
```
|