Andrej Janchevski commited on
Commit
acde928
·
1 Parent(s): c50d15c

docs(plan): add kg-anomaly correction inference endpoints plan

Browse files

Plan for POST /kg-anomaly/correct and POST /kg-anomaly/continue wrapping
DiscreteDenoisingDiffusionKG. Documents model-loading strategy with
reconstructed dataset_infos (KG checkpoints only pickle cfg via
save_hyperparameters('cfg'), unlike MultiProxAn), SSE streaming for both
standard and multiprox modes, change detection, and KG subgraph rendering.

Files changed (1) hide show
  1. .claude/plans/backend_kg_anomaly.md +261 -0
.claude/plans/backend_kg_anomaly.md ADDED
@@ -0,0 +1,261 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # KG Anomaly Correction Inference Endpoints
2
+
3
+ Implement `POST /kg-anomaly/correct` and `POST /kg-anomaly/continue` -- the final two inference endpoints.
4
+ These wrap the DiGress KG diffusion model (`DiscreteDenoisingDiffusionKG`) for correcting/generating knowledge graph subgraph edges, following the same architecture as the MultiProxAn graph generation endpoints.
5
+
6
+ ## Context
7
+
8
+ Discovery endpoints (`GET /kg-anomaly/datasets`, `GET /kg-anomaly/datasets/{id}/sample-subgraphs`) are already implemented. Checkpoint scanning is already in `registry.py`. The two remaining endpoints are the core inference: standard denoising and MultiProx step-by-step Gibbs sampling, both using SSE streaming (updated from the original JSON spec per user request).
9
+
10
+ The MultiProxAn graph generation endpoints (`graphgen_inference.py`, `graph_generation.py` views, `registry.py` methods) are the direct template. Key differences for KG anomaly:
11
+ - Model class: `DiscreteDenoisingDiffusionKG` (not `DiscreteDenoisingDiffusion`)
12
+ - Inputs: subgraph nodes/edges from API (not generated from noise) -- X is given, only E is diffused
13
+ - Extra model args: `X_index` (entity IDs), `inpaint_mask` (which edges to correct)
14
+ - Task modes: "correct" (inpaint masked edges) vs "generate" (regenerate all edges)
15
+ - Checkpoint only saves `cfg` via `save_hyperparameters('cfg')` -- must reconstruct `dataset_infos` from a COINs experiment + state_dict shapes
16
+ - Response includes `changes` diff (before/after edge comparison)
17
+
18
+ ## Assumptions and Constraints
19
+
20
+ - Checkpoints: `{dataset_id}.ckpt` (generate) and `{dataset_id}_correct.ckpt` (correct) in `src/research/COINs-KGGeneration/graph_generation/checkpoints/`
21
+ - Only discrete KG models (no continuous variant for KG anomaly)
22
+ - `loader.communities` (numpy array, entity_id -> community_id) is available on the lightweight Loader (not freed)
23
+ - COINs experiment for the same dataset must load successfully (needed for kg_experiment in model constructor)
24
+ - E tensor uses class 0 for "no edge", classes 1..N for actual relation types; API `relation_id` is 0-indexed for actual relations: `E_class = relation_id + 1`
25
+ - `graph_generation/src` must be added to `sys.path` for bare imports (`from diffusion.noise_schedule import ...`, `from metrics.abstract_metrics import ...`)
26
+
27
+ ## Scope
28
+
29
+ **Included:**
30
+ - `POST /kg-anomaly/correct` (standard + multiprox modes, both SSE streaming)
31
+ - `POST /kg-anomaly/continue` (multiprox continuation, SSE streaming)
32
+ - KG subgraph rendering (PIL + networkx, directed graph with entity/relation labels, color-coded changes)
33
+ - Change detection (before/after edge diff)
34
+ - Model loading with reconstructed dataset_infos
35
+ - sys.path update for `graph_generation/src`
36
+ - API spec update (`docs/api.yaml`) to reflect SSE streaming
37
+ - Postman collection update with auto-chaining test scripts
38
+ - Backend README update
39
+
40
+ **Excluded:**
41
+ - Frontend implementation
42
+ - Continuous model variant (only discrete exists for KG)
43
+
44
+ ## Design
45
+
46
+ ### Model Loading Strategy
47
+
48
+ The checkpoint's `hyper_parameters` only contains `cfg` (from `save_hyperparameters('cfg')`). We must reconstruct all other constructor args:
49
+
50
+ 1. **Load checkpoint to CPU** with DDP patching (same as `_safe_load_lightning_checkpoint`)
51
+ 2. **Extract `cfg`** from `hyper_parameters`
52
+ 3. **Infer dims from state_dict shapes:**
53
+ - `Xdim_output = ckpt["state_dict"]["transition_model.u_x"].shape[1]`
54
+ - `Edim_output = ckpt["state_dict"]["transition_model.u_e"].shape[1]`
55
+ - `input_dims['X'] = ckpt["state_dict"]["model.mlp_in_X.0.weight"].shape[1]`
56
+ - `input_dims['E'] = ckpt["state_dict"]["model.mlp_in_E.0.weight"].shape[1]`
57
+ - `input_dims['y'] = ckpt["state_dict"]["model.mlp_in_y.0.weight"].shape[1]`
58
+ - `output_dims = {'X': Xdim_output, 'E': Edim_output, 'y': 0}`
59
+ 4. **Extract marginal distributions** from state_dict buffers:
60
+ - `node_types = ckpt["state_dict"]["transition_model.u_x"].squeeze(0)`
61
+ - `edge_types = ckpt["state_dict"]["transition_model.u_e"].squeeze(0)`
62
+ 5. **Load COINs experiment** via `_load_coins_experiment(dataset_id, "transe")` for the `kg_experiment`
63
+ 6. **Build mock `dataset_infos`** with dims, distributions, `nodes_dist`, and `datamodule.kg_experiment`
64
+ 7. **Build `extra_features`** from `cfg.model.extra_features` (ExtraFeatures or DummyExtraFeatures)
65
+ 8. **Construct model**, patch `save_hyperparameters` to no-op, load state_dict, move to device, eval mode
66
+
67
+ ### Input Subgraph -> Tensor Conversion
68
+
69
+ `build_kg_tensors(subgraph, dataset_id, loader, model)` -> `(X_given, E_given, y_given, X_index, X_c, n_nodes, is_bip, node_mask)`
70
+
71
+ - **X_given** `(1, n, Xdim_output)`: one-hot from `type_id`. `X[0, i, type_id] = 1.0`
72
+ - **E_given** `(1, n, n, Edim_output)`: init all to class 0 (no edge). For each edge: `E[0, src, tgt, 0] = 0; E[0, src, tgt, relation_id + 1] = 1.0`
73
+ - **y_given** `(1, 0)`: empty tensor (consistent with `kg_dataset.py` line 73)
74
+ - **X_index** `(1, n)`: entity IDs from `nodes[i]["entity_id"]`
75
+ - **X_c** `(1, n)`: `loader.communities[entity_id]` for each node
76
+ - **n_nodes** `tensor([n])`, **is_bip** `tensor([n > 20])`, **node_mask** `ones(1, n, dtype=bool)`
77
+
78
+ ### Change Detection
79
+
80
+ `compute_changes(original_E_int, corrected_E_int, num_nodes, loader)` -> `{"edges": [...], "summary": {...}}`
81
+
82
+ For each directed pair `(i, j)` where `i != j`:
83
+ - `orig = original_E_int[i, j]`, `corr = corrected_E_int[i, j]` (integer class indices)
84
+ - Both 0: skip (no edge in either)
85
+ - Same nonzero: `"unchanged"`
86
+ - orig=0, corr>0: `"added"`
87
+ - orig>0, corr=0: `"removed"` (with `original_relation_id = orig - 1`)
88
+ - Both nonzero, different: `"modified"` (with `original_relation_id = orig - 1`)
89
+
90
+ Relation names resolved via `loader.dataset.get_inverted_name_maps()`. API `relation_id = E_class - 1`.
91
+
92
+ ### KG Subgraph Rendering
93
+
94
+ `render_kg_subgraph(X_int, E_int, num_nodes, X_index, dataset_id, loader, changes=None)` -> PIL.Image
95
+
96
+ - PIL + networkx `spring_layout` (no matplotlib, same pattern as `_render_comm20`)
97
+ - Directed graph (`nx.DiGraph`), directed arrows for edges
98
+ - Node labels: entity names from `inv_nodes[entity_id]` (truncated)
99
+ - Edge labels: relation names from `inv_relations[relation_id]`
100
+ - Image size: 500x500
101
+ - Color coding when `changes` provided:
102
+ - unchanged: gray `#888888`
103
+ - modified: orange `#e67e22`
104
+ - added: green `#27ae60`
105
+ - removed: red dashed `#e74c3c`
106
+
107
+ ### SSE Streaming (updated from original JSON spec)
108
+
109
+ Both endpoints use SSE streaming like graph generation, with `event: progress`, `event: preview`, and `event: result`. The `result` event for standard mode includes `original_image`, `corrected_image`, `chain_gif`, `changes`. For multiprox: `step`, `image`, `state`, `changes`.
110
+
111
+ ### State Blob (MultiProx)
112
+
113
+ Keys:
114
+ ```
115
+ X_given, E, y, n_nodes, dataset_id, task, X_index, X_c, is_bip,
116
+ original_E_int, T, n, m, t, t_prime, gibbs_chain_freq, inner_step, step
117
+ ```
118
+
119
+ Note: `X_given` (not `X`) because node features are fixed. `E` is the `(1, M, n, n, Edim)` ensemble tensor.
120
+
121
+ ## Implementation Steps
122
+
123
+ ### Step 1: Add `graph_generation/src` to sys.path
124
+
125
+ **File:** `src/backend/research_api/settings.py`
126
+
127
+ Add `_DIGRESS_KG_SRC = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration" / "graph_generation" / "src")` to the sys.path block. Required for bare imports in `diffusion_model_discrete_kg.py`:
128
+ - `from diffusion.noise_schedule import ...`
129
+ - `from metrics.abstract_metrics import ...`
130
+
131
+ ### Step 2: Create `src/backend/api/services/kg_anomaly_inference.py`
132
+
133
+ New file mirroring `graphgen_inference.py`. Contains:
134
+
135
+ - `build_kg_tensors()` -- input subgraph -> tensors
136
+ - `compute_changes()` -- before/after edge diff
137
+ - `render_kg_subgraph()` -- PIL + networkx directed graph rendering
138
+ - `run_standard_correction()` -- generator yielding progress/preview/result events
139
+ - `run_multiprox_correction_init()` -- generator yielding noise init progress + result with state
140
+ - `run_multiprox_correction_step()` -- generator yielding gibbs + refine progress + result with updated state
141
+ - `encode_state_blob()` / `decode_state_blob()` -- same pattern as graphgen
142
+ - `_pil_to_b64()` / `_frames_to_gif_b64()` -- import from graphgen_inference (reuse)
143
+
144
+ Standard correction generator logic (following research code `sample_batch`):
145
+ ```
146
+ 1. Move tensors to device
147
+ 2. Build node_mask, inpaint_mask (via get_inpaint_mask for "correct" task, all-ones for "generate")
148
+ 3. Sample initial noise for E only; apply inpaint mask: z_T.E * mask + E_given * (~mask)
149
+ 4. Set X = X_given (nodes are given, not diffused)
150
+ 5. Denoising loop s_int from T-1 to 0:
151
+ - model.sample_p_zs_given_zt(s, t, X, E, y, X_index, node_mask, inpaint_mask)
152
+ - yield progress + preview at frame intervals
153
+ 6. Collapse final: PlaceHolder(X, E, y).mask(node_mask, collapse=True)
154
+ 7. Compute changes, render original + corrected images, build GIF
155
+ 8. yield result event
156
+ ```
157
+
158
+ MultiProx init/step follows same structure as graphgen but with:
159
+ - X_index and inpaint_mask passed to `sample_p_zs_given_zt`
160
+ - `model.apply_noise(..., inpaint_mask, gibbs=True)` for re-noising
161
+ - Changes computed at each step for the result event
162
+
163
+ ### Step 3: Add model loading to `src/backend/api/services/registry.py`
164
+
165
+ Add to `ModelRegistry.__init__`:
166
+ ```python
167
+ self._kg_anomaly_models = {} # (dataset_id, task) -> loaded eval-mode model
168
+ ```
169
+
170
+ Add `_load_kg_anomaly_model(self, dataset_id, task)`:
171
+ 1. Check cache `_kg_anomaly_models[(dataset_id, task)]`
172
+ 2. Determine ckpt path: `DIGRESS_KG_DIR / checkpoints / {dataset_id}[_correct].ckpt`
173
+ 3. Load checkpoint to CPU with DDP patching
174
+ 4. Extract `cfg` from `hyper_parameters`
175
+ 5. Infer dims from state_dict shapes (transition_model buffers, model MLP weights)
176
+ 6. Load COINs experiment via `self._load_coins_experiment(dataset_id, "transe")`
177
+ 7. Build mock dataset_infos (MockDataModule with kg_experiment, MockDatasetInfos with dims/distributions)
178
+ 8. Build ExtraFeatures from cfg
179
+ 9. Construct `DiscreteDenoisingDiffusionKG(cfg, dataset_infos, None, None, None, extra_features, domain_features)`
180
+ 10. Load state_dict (strict=False), move to device, eval mode
181
+ 11. Cache and return
182
+
183
+ Add `kg_anomaly_correct_stream(self, ...)`:
184
+ - Acquire inference lock eagerly
185
+ - Load model, build tensors from subgraph
186
+ - Return generator (standard or multiprox init), with lock released in finally block
187
+
188
+ Add `kg_anomaly_continue_stream(self, state_b64)`:
189
+ - Decode state blob eagerly, acquire lock, load model
190
+ - Return generator wrapping `run_multiprox_correction_step`, with lock released in finally block
191
+
192
+ ### Step 4: Add views to `src/backend/api/views/kg_anomaly.py`
193
+
194
+ Add `KgAnomalyCorrectView`:
195
+ - Validate: dataset_id, sampling_mode, task (default "correct"), subgraph (nodes 2-20, edge indices valid), checkpoint availability
196
+ - Map API task "correct" -> model task "inpaint", "generate" -> "generate"
197
+ - For standard: clamp diffusion_steps [50, 1000], chain_frames [10, 30]
198
+ - For multiprox: validate m, t, t_prime, gibbs_chain_freq (same rules as GraphGenGenerateView)
199
+ - Return `_streaming_sse_response(gen)` (import SSE helpers from graph_generation views, or duplicate)
200
+
201
+ Add `KgAnomalyContinueView`:
202
+ - Validate state string
203
+ - Return `_streaming_sse_response(gen)`
204
+
205
+ ### Step 5: Update `src/backend/api/urls.py`
206
+
207
+ Add:
208
+ ```python
209
+ path("kg-anomaly/correct", KgAnomalyCorrectView.as_view()),
210
+ path("kg-anomaly/continue", KgAnomalyContinueView.as_view()),
211
+ ```
212
+
213
+ ### Step 6: Update `docs/api.yaml`
214
+
215
+ Change `/kg-anomaly/correct` and `/kg-anomaly/continue` response schemas from regular JSON to SSE streaming (text/event-stream), matching the graph-generation endpoint pattern. Add SSE event schemas for progress, preview, and result events.
216
+
217
+ ### Step 7: Update `docs/postman/collection.json`
218
+
219
+ Add auto-chaining test scripts to the multiprox correct request (parse SSE result event, save state to `{{multiprox_state}}`). Update continue request to use `{{multiprox_state}}`.
220
+
221
+ ### Step 8: Update `src/backend/README.md`
222
+
223
+ Add `/kg-anomaly/correct` and `/kg-anomaly/continue` to the endpoint table, mark as streaming SSE.
224
+
225
+ ## Critical Files
226
+
227
+ | File | Action | Purpose |
228
+ |------|--------|---------|
229
+ | `src/backend/research_api/settings.py` | Modify | Add `graph_generation/src` to sys.path |
230
+ | `src/backend/api/services/kg_anomaly_inference.py` | Create | Core inference, rendering, change detection |
231
+ | `src/backend/api/services/registry.py` | Modify | Model loading, stream orchestration, lock management |
232
+ | `src/backend/api/views/kg_anomaly.py` | Modify | Add CorrectView and ContinueView |
233
+ | `src/backend/api/urls.py` | Modify | Wire new routes |
234
+ | `docs/api.yaml` | Modify | SSE streaming for kg-anomaly endpoints |
235
+ | `docs/postman/collection.json` | Modify | Auto-chaining for multiprox |
236
+ | `src/backend/README.md` | Modify | Endpoint documentation |
237
+
238
+ ## Research Code Referenced (read-only)
239
+
240
+ | File | What we use |
241
+ |------|-------------|
242
+ | `src/research/COINs-KGGeneration/graph_generation/src/diffusion_model_discrete_kg.py` | `DiscreteDenoisingDiffusionKG` class, `sample_p_zs_given_zt`, `apply_noise` signatures |
243
+ | `src/research/COINs-KGGeneration/graph_generation/src/utils.py` | `get_inpaint_mask`, `PlaceHolder` |
244
+ | `src/research/COINs-KGGeneration/graph_generation/src/diffusion/diffusion_utils.py` | `sample_discrete_feature_noise` |
245
+ | `src/research/COINs-KGGeneration/graph_generation/src/diffusion/extra_features.py` | `ExtraFeatures`, `DummyExtraFeatures` |
246
+ | `src/research/COINs-KGGeneration/graph_generation/src/diffusion/distributions.py` | `DistributionNodes` |
247
+ | `src/research/COINs-KGGeneration/graph_generation/src/datasets/abstract_dataset.py` | `AbstractDatasetInfos` for mock pattern |
248
+ | `src/research/COINs-KGGeneration/graph_generation/src/main.py` | Reference loading sequence (lines 82-138) |
249
+ | `src/research/COINs-KGGeneration/graph_completion/graphs/load_graph.py` | `loader.communities` array |
250
+
251
+ ## Verification
252
+
253
+ 1. `python manage.py runserver 8000` boots without errors (model scanning detects KG anomaly checkpoints)
254
+ 2. `GET /kg-anomaly/datasets` returns `available_tasks: ["generate", "correct"]` for each dataset
255
+ 3. Standard correct (wordnet, correct task): returns SSE stream with progress, preview frames, and result with `original_image`, `corrected_image`, `chain_gif`, `changes`
256
+ 4. Standard generate (wordnet, generate task): same but inpaint_mask is all-ones (all edges regenerated)
257
+ 5. MultiProx init (wordnet): returns SSE result with `state` blob, `step: 0`, `changes`
258
+ 6. MultiProx continue: state round-trips correctly, `step` increments, `round_complete`/`done` flags work
259
+ 7. Inference lock: second concurrent request returns 429 INFERENCE_BUSY
260
+ 8. Postman collection: all 4 KG anomaly inference requests pass with auto-chaining
261
+ 9. All 3 datasets (freebase, wordnet, nell) x 2 tasks (correct, generate) work