Andrej Janchevski committed on
Commit
aaf56bb
·
1 Parent(s): acde928

feat(kg-anomaly): add correct/continue endpoints with SSE streaming

Implements the two remaining inference endpoints for the KG anomaly
feature, following the MultiProxAn graph-generation pattern:

- POST /kg-anomaly/correct: standard denoising (correct/generate tasks)
  or MultiProx Gibbs init; returns an SSE stream with progress, preview,
  and terminal result events (before/after images, chain GIF, diff).
- POST /kg-anomaly/continue: advances a MultiProx session by one step.
- Adds kg_anomaly_inference.py with tensor building, change detection,
  directed subgraph rendering (PIL + networkx), and an apply_edge_noise
  helper that forward-diffuses edges in a task-aware way to produce
  demo input.
- Extends GET /kg-anomaly/datasets/{id}/sample-subgraphs with
  noise_level/task/seed query params so callers can fetch pre-noised
  subgraphs ready for correction.
- Registry: adds a DiscreteDenoisingDiffusionKG loader that reconstructs
  dataset_infos from checkpoint state_dict shapes + COINs experiment.
- Adds graph_generation/src to sys.path for bare imports inside the
  research module.
- Updates the OpenAPI spec, Postman collection (pre-noised example bodies,
  noised-subgraph GET variants, auto-chaining for multiprox), and the
  backend README endpoint table.
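The "change detection" step mentioned above can be pictured with a minimal sketch. It assumes edges are dicts with the `source_idx`, `target_idx`, and `relation_id` keys used in the Postman request bodies; the function name `edge_diff` and the `added`/`changed` counters are hypothetical (only a `removed` field is visible in this diff), and the real kg_anomaly_inference.py may compute its diff differently.

```python
from typing import Dict, List, Tuple

Edge = Dict[str, int]  # keys: source_idx, target_idx, relation_id


def edge_diff(before: List[Edge], after: List[Edge]) -> Dict[str, int]:
    """Count edge changes between two directed, typed edge lists (hypothetical helper)."""
    # Key each edge by its directed (source, target) pair; the value is the relation.
    old: Dict[Tuple[int, int], int] = {
        (e["source_idx"], e["target_idx"]): e["relation_id"] for e in before
    }
    new: Dict[Tuple[int, int], int] = {
        (e["source_idx"], e["target_idx"]): e["relation_id"] for e in after
    }
    added = sum(1 for key in new if key not in old)
    removed = sum(1 for key in old if key not in new)
    # An edge that survives but carries a different relation counts as changed.
    changed = sum(1 for key in new if key in old and new[key] != old[key])
    return {"added": added, "removed": removed, "changed": changed}
```

Keying on the ordered pair keeps the comparison directed, so `0 -> 1` and `1 -> 0` are treated as different edges.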

docs/api.yaml CHANGED
@@ -62,6 +62,40 @@ paths:
62
  schema:
63
  $ref: "#/components/schemas/MethodsResponse"
64
 
65
  # -- COINs -----------------------------------------------------------
66
  /coins/datasets:
67
  get:
@@ -279,15 +313,19 @@ paths:
279
  post:
280
  operationId: graphGenGenerate
281
  tags: [graph-generation]
282
- summary: Generate a graph
283
  description: |
284
- **Standard mode**: runs full diffusion (T->0), returns animated GIF of
285
- the denoising chain + final PNG. Frontend plays the GIF once.
 
 
 
 
 
286
 
287
- **MultiProx mode**: starts a session, runs the first Gibbs iteration,
288
- returns step 0 image + an opaque `state` blob. Use the
289
- `/graph-generation/continue` endpoint with that state to advance
290
- one step at a time.
291
  requestBody:
292
  required: true
293
  content:
@@ -318,30 +356,11 @@ paths:
318
  t_prime: 0.1
319
  responses:
320
  "200":
321
- description: Generated graph (standard) or session step 0 (multiprox)
322
  content:
323
- application/json:
324
  schema:
325
- oneOf:
326
- - $ref: "#/components/schemas/GraphGenStandardResponse"
327
- - $ref: "#/components/schemas/GraphGenMultiProxResponse"
328
- examples:
329
- standard:
330
- summary: Standard generation result
331
- value:
332
- dataset_id: qm9
333
- model_type: discrete
334
- sampling_mode: standard
335
- image: "data:image/png;base64,..."
336
- chain_gif: "data:image/gif;base64,..."
337
- inference_time_ms: 3200
338
- multiprox:
339
- summary: MultiProx session started
340
- value:
341
- state: "base64-encoded-diffusion-state..."
342
- step: 0
343
- image: "data:image/png;base64,..."
344
- inference_time_ms: 800
345
  "400":
346
  $ref: "#/components/responses/InvalidRequest"
347
  "429":
@@ -353,11 +372,14 @@ paths:
353
  post:
354
  operationId: graphGenContinue
355
  tags: [graph-generation]
356
- summary: Advance MultiProx generation by one step
357
  description: |
358
- Advances the MultiProx multi-measurement chain by one Gibbs iteration.
359
- The client must send back the opaque `state` from the previous step's
360
- response. This keeps the API fully stateless - no server-side sessions.
 
 
 
361
  requestBody:
362
  required: true
363
  content:
@@ -366,11 +388,11 @@ paths:
366
  $ref: "#/components/schemas/GraphGenContinueRequest"
367
  responses:
368
  "200":
369
- description: Step result
370
  content:
371
- application/json:
372
  schema:
373
- $ref: "#/components/schemas/GraphGenMultiProxResponse"
374
  "400":
375
  $ref: "#/components/responses/InvalidRequest"
376
  "429":
@@ -396,7 +418,13 @@ paths:
396
  operationId: getKgAnomalySampleSubgraphs
397
  tags: [kg-anomaly]
398
  summary: Get example subgraphs for correction
399
- description: Returns pre-computed example subgraphs from the test set.
 
 
 
 
 
 
400
  parameters:
401
  - $ref: "#/components/parameters/KgAnomalyDatasetId"
402
  - name: count
@@ -407,6 +435,32 @@ paths:
407
  maximum: 10
408
  default: 5
409
  description: Number of sample subgraphs to return
410
  responses:
411
  "200":
412
  description: Sample subgraphs
@@ -421,12 +475,19 @@ paths:
421
  post:
422
  operationId: kgAnomalyCorrect
423
  tags: [kg-anomaly]
424
- summary: Correct a KG subgraph
425
  description: |
426
- **Standard mode**: runs full diffusion correction, returns animated GIF
427
- of the process + before/after images + structured diff.
 
 
428
 
429
- **MultiProx mode** (future): starts a session for step-by-step correction.
 
 
 
 
 
430
  requestBody:
431
  required: true
432
  content:
@@ -476,13 +537,11 @@ paths:
476
  t_prime: 0.1
477
  responses:
478
  "200":
479
- description: Correction result
480
  content:
481
- application/json:
482
  schema:
483
- oneOf:
484
- - $ref: "#/components/schemas/KgAnomalyStandardResponse"
485
- - $ref: "#/components/schemas/KgAnomalyMultiProxResponse"
486
  "400":
487
  $ref: "#/components/responses/InvalidRequest"
488
  "404":
@@ -498,11 +557,13 @@ paths:
498
  post:
499
  operationId: kgAnomalyContinue
500
  tags: [kg-anomaly]
501
- summary: Advance MultiProx correction by one step (future)
502
  description: |
503
- Advances the MultiProx correction chain by one Gibbs iteration.
504
- The client must send back the opaque `state` from the previous step's
505
- response. Returns the updated subgraph image and current diff.
 
 
506
  requestBody:
507
  required: true
508
  content:
@@ -511,11 +572,11 @@ paths:
511
  $ref: "#/components/schemas/KgAnomalyContinueRequest"
512
  responses:
513
  "200":
514
- description: Step result
515
  content:
516
- application/json:
517
  schema:
518
- $ref: "#/components/schemas/KgAnomalyMultiProxResponse"
519
  "400":
520
  $ref: "#/components/responses/InvalidRequest"
521
  "429":
@@ -1291,6 +1352,36 @@ components:
1291
  format: float
1292
  example: 800
1293
 
1294
  # -- KG Anomaly Correction: Discovery --
1295
  KgAnomalyDatasetsResponse:
1296
  type: object
@@ -1411,7 +1502,7 @@ components:
1411
  $ref: "#/components/schemas/SamplingModeEnum"
1412
  task:
1413
  $ref: "#/components/schemas/KgAnomalyTaskEnum"
1414
- default: inpaint
1415
  description: |
1416
  "generate" = ignore the input subgraph edges and generate a new subgraph from scratch.
1417
  "correct" (default) = keep fixed edges unchanged, only correct the masked (anomalous) edges.
@@ -1570,3 +1661,33 @@ components:
1570
  removed:
1571
  type: integer
1572
  example: 0
 
62
  schema:
63
  $ref: "#/components/schemas/MethodsResponse"
64
 
65
+ /debug/force-unlock:
66
+ post:
67
+ operationId: forceUnlockInferenceLock
68
+ tags: [health]
69
+ summary: Release a stuck inference lock (debug only)
70
+ description: |
71
+ Forcibly releases the global inference lock. Only available when
72
+ the server is running with `DJANGO_DEBUG=True`; returns `403` in
73
+ production. Use when a crashed request left the lock held and
74
+ subsequent requests are returning `429 INFERENCE_BUSY`.
75
+ responses:
76
+ "200":
77
+ description: Lock release result
78
+ content:
79
+ application/json:
80
+ schema:
81
+ type: object
82
+ required: [released]
83
+ properties:
84
+ released:
85
+ type: boolean
86
+ description: True if a held lock was released; false if the lock was already free.
87
+ example: true
88
+ "403":
89
+ description: Not available outside debug mode
90
+ content:
91
+ application/json:
92
+ schema:
93
+ type: object
94
+ properties:
95
+ error:
96
+ type: string
97
+ example: only available in debug mode
98
+
99
  # -- COINs -----------------------------------------------------------
100
  /coins/datasets:
101
  get:
 
313
  post:
314
  operationId: graphGenGenerate
315
  tags: [graph-generation]
316
+ summary: Generate a graph (SSE streaming)
317
  description: |
318
+ Server-Sent Events stream (`text/event-stream`). Emits `progress`
319
+ events during diffusion, optional `preview` events with intermediate
320
+ PNGs, and a terminal `result` event whose `data` payload is the JSON
321
+ described below.
322
+
323
+ **Standard mode**: runs full diffusion (T->0); terminal `result`
324
+ payload conforms to `GraphGenStandardResponse` (animated GIF + final PNG).
325
 
326
+ **MultiProx mode**: runs the first Gibbs iteration; terminal `result`
327
+ payload conforms to `GraphGenMultiProxResponse` and includes an opaque
328
+ `state` blob to be passed to `/graph-generation/continue`.
 
329
  requestBody:
330
  required: true
331
  content:
 
356
  t_prime: 0.1
357
  responses:
358
  "200":
359
+ description: SSE stream of progress/preview events terminated by a result event
360
  content:
361
+ text/event-stream:
362
  schema:
363
+ $ref: "#/components/schemas/GraphGenSseStream"
364
  "400":
365
  $ref: "#/components/responses/InvalidRequest"
366
  "429":
 
372
  post:
373
  operationId: graphGenContinue
374
  tags: [graph-generation]
375
+ summary: Advance MultiProx generation by one step (SSE streaming)
376
  description: |
377
+ SSE stream (`text/event-stream`). Advances the MultiProx
378
+ multi-measurement chain by one Gibbs iteration. The client must send
379
+ back the opaque `state` from the previous step's `result` event.
380
+ Emits `progress` and `preview` events, then a terminal `result` event
381
+ with the `GraphGenMultiProxResponse` payload (including the updated
382
+ `state`). The API remains fully stateless -- no server-side sessions.
383
  requestBody:
384
  required: true
385
  content:
 
388
  $ref: "#/components/schemas/GraphGenContinueRequest"
389
  responses:
390
  "200":
391
+ description: SSE stream of progress/preview events terminated by a result event
392
  content:
393
+ text/event-stream:
394
  schema:
395
+ $ref: "#/components/schemas/GraphGenSseStream"
396
  "400":
397
  $ref: "#/components/responses/InvalidRequest"
398
  "429":
 
418
  operationId: getKgAnomalySampleSubgraphs
419
  tags: [kg-anomaly]
420
  summary: Get example subgraphs for correction
421
+ description: |
422
+ Returns pre-computed example subgraphs from the test set. When
423
+ `noise_level` is supplied, the model's forward diffusion is applied
424
+ to each subgraph's edges so the caller receives a corrupted input
425
+ ready for `/kg-anomaly/correct`. For `task=correct` only the edges
426
+ inside the inpaint mask (second half of nodes) are noised; for
427
+ `task=generate` every edge is noised.
428
  parameters:
429
  - $ref: "#/components/parameters/KgAnomalyDatasetId"
430
  - name: count
 
435
  maximum: 10
436
  default: 5
437
  description: Number of sample subgraphs to return
438
+ - name: noise_level
439
+ in: query
440
+ required: false
441
+ schema:
442
+ type: number
443
+ minimum: 0.0
444
+ exclusiveMinimum: true
445
+ maximum: 1.0
446
+ description: |
447
+ Fraction of the full diffusion horizon T at which to sample
448
+ noised edges (e.g. 0.4 for moderate corruption). Omit to receive
449
+ the clean subgraphs.
450
+ - name: task
451
+ in: query
452
+ required: false
453
+ schema:
454
+ type: string
455
+ enum: [correct, generate]
456
+ default: correct
457
+ description: Task the noise should align with. Ignored if noise_level is not set.
458
+ - name: seed
459
+ in: query
460
+ required: false
461
+ schema:
462
+ type: integer
463
+ description: Optional RNG seed for reproducible noise.
464
  responses:
465
  "200":
466
  description: Sample subgraphs
 
475
  post:
476
  operationId: kgAnomalyCorrect
477
  tags: [kg-anomaly]
478
+ summary: Correct a KG subgraph (SSE streaming)
479
  description: |
480
+ Server-Sent Events stream (`text/event-stream`). Emits `progress`
481
+ events during diffusion, optional `preview` events with intermediate
482
+ PNGs, and a terminal `result` event whose `data` payload is the JSON
483
+ described below.
484
 
485
+ **Standard mode**: runs full diffusion correction; terminal `result`
486
+ payload conforms to `KgAnomalyStandardResponse`.
487
+
488
+ **MultiProx mode**: runs the first Gibbs iteration; terminal `result`
489
+ payload conforms to `KgAnomalyMultiProxResponse` and includes an
490
+ opaque `state` blob to be passed to `/kg-anomaly/continue`.
491
  requestBody:
492
  required: true
493
  content:
 
537
  t_prime: 0.1
538
  responses:
539
  "200":
540
+ description: SSE stream of progress/preview events terminated by a result event
541
  content:
542
+ text/event-stream:
543
  schema:
544
+ $ref: "#/components/schemas/KgAnomalySseStream"
 
 
545
  "400":
546
  $ref: "#/components/responses/InvalidRequest"
547
  "404":
 
557
  post:
558
  operationId: kgAnomalyContinue
559
  tags: [kg-anomaly]
560
+ summary: Advance MultiProx correction by one step (SSE streaming)
561
  description: |
562
+ SSE stream (`text/event-stream`). Advances the MultiProx correction
563
+ chain by one Gibbs iteration. The client must send back the opaque
564
+ `state` from the previous step's `result` event. Emits `progress`
565
+ and `preview` events, then a terminal `result` event with the
566
+ `KgAnomalyMultiProxResponse` payload (including the updated `state`).
567
  requestBody:
568
  required: true
569
  content:
 
572
  $ref: "#/components/schemas/KgAnomalyContinueRequest"
573
  responses:
574
  "200":
575
+ description: SSE stream of progress/preview events terminated by a result event
576
  content:
577
+ text/event-stream:
578
  schema:
579
+ $ref: "#/components/schemas/KgAnomalySseStream"
580
  "400":
581
  $ref: "#/components/responses/InvalidRequest"
582
  "429":
 
1352
  format: float
1353
  example: 800
1354
 
1355
+ # -- Graph Generation SSE stream ----------------------------------
1356
+ GraphGenSseStream:
1357
+ type: string
1358
+ description: |
1359
+ SSE text stream. Each event is `event: <name>\ndata: <payload>\n\n`.
1360
+
1361
+ * `event: progress` -- payload is `GraphGenProgressEvent` JSON.
1362
+ * `event: preview` -- payload is a raw `data:image/png;base64,...` data URI
1363
+ (intermediate graph snapshot; not JSON).
1364
+ * `event: result` -- payload is a `GraphGenStandardResponse` (standard mode)
1365
+ or `GraphGenMultiProxResponse` (multiprox mode / continue) JSON.
1366
+ * `event: error` -- payload is an error object with `code` and `message`.
1367
+
1368
+ GraphGenProgressEvent:
1369
+ type: object
1370
+ required: [type, stage]
1371
+ properties:
1372
+ type:
1373
+ type: string
1374
+ enum: [progress]
1375
+ stage:
1376
+ type: string
1377
+ description: Current phase (e.g. "denoise", "noise", "refine")
1378
+ step:
1379
+ type: integer
1380
+ description: Current step within the stage
1381
+ total:
1382
+ type: integer
1383
+ description: Total steps in the stage
1384
+
1385
  # -- KG Anomaly Correction: Discovery --
1386
  KgAnomalyDatasetsResponse:
1387
  type: object
 
1502
  $ref: "#/components/schemas/SamplingModeEnum"
1503
  task:
1504
  $ref: "#/components/schemas/KgAnomalyTaskEnum"
1505
+ default: correct
1506
  description: |
1507
  "generate" = ignore the input subgraph edges and generate a new subgraph from scratch.
1508
  "correct" (default) = keep fixed edges unchanged, only correct the masked (anomalous) edges.
 
1661
  removed:
1662
  type: integer
1663
  example: 0
1664
+
1665
+ # -- KG Anomaly SSE stream ----------------------------------------
1666
+ KgAnomalySseStream:
1667
+ type: string
1668
+ description: |
1669
+ SSE text stream. Each event is `event: <name>\ndata: <payload>\n\n`.
1670
+
1671
+ * `event: progress` -- payload is `KgAnomalyProgressEvent` JSON.
1672
+ * `event: preview` -- payload is a raw `data:image/png;base64,...` data URI
1673
+ (intermediate subgraph snapshot; not JSON).
1674
+ * `event: result` -- payload is a `KgAnomalyStandardResponse` (standard mode)
1675
+ or `KgAnomalyMultiProxResponse` (multiprox mode / continue) JSON.
1676
+ * `event: error` -- payload is an error object with `code` and `message`.
1677
+
1678
+ KgAnomalyProgressEvent:
1679
+ type: object
1680
+ required: [type, stage]
1681
+ properties:
1682
+ type:
1683
+ type: string
1684
+ enum: [progress]
1685
+ stage:
1686
+ type: string
1687
+ description: Current phase (e.g. "denoise", "noise", "refine")
1688
+ step:
1689
+ type: integer
1690
+ description: Current step within the stage
1691
+ total:
1692
+ type: integer
1693
+ description: Total steps in the stage
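The stream layout described by `GraphGenSseStream` and `KgAnomalySseStream` (`event: <name>\ndata: <payload>\n\n`, with JSON payloads for every event except `preview`) can be consumed along the following lines. This is a sketch of a client-side parser over a fully received body, not the project's own client code; the function name `parse_sse` is an assumption.

```python
import json
from typing import Any, Iterator, Tuple


def parse_sse(text: str) -> Iterator[Tuple[str, Any]]:
    """Yield (event_name, payload) pairs from a complete SSE response body."""
    # Events are separated by a blank line.
    for block in text.replace("\r\n", "\n").split("\n\n"):
        event = "message"  # SSE default when no `event:` field is present
        data_lines = []
        for line in block.split("\n"):
            if line.startswith("event:"):
                event = line[len("event:"):].strip()
            elif line.startswith("data:"):
                value = line[len("data:"):]
                # SSE strips a single leading space after the colon.
                data_lines.append(value[1:] if value.startswith(" ") else value)
        if not data_lines:
            continue
        data = "\n".join(data_lines)
        # `preview` carries a raw data URI; other events carry JSON.
        yield event, (data if event == "preview" else json.loads(data))
```

A caller would typically ignore `progress` and `preview` events it does not render and stop at the first `result` or `error` event.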
docs/postman/collection.json CHANGED
@@ -612,7 +612,45 @@
612
  { "key": "count", "value": "3" }
613
  ]
614
  },
615
- "description": "Pre-computed example subgraphs for correction."
616
  }
617
  }
618
  ]
@@ -629,7 +667,7 @@
629
  ],
630
  "body": {
631
  "mode": "raw",
632
- "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3},\n {\"entity_id\": 8142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3},\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 1}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
633
  },
634
  "url": {
635
  "raw": "{{base_url}}/kg-anomaly/correct",
@@ -648,7 +686,7 @@
648
  ],
649
  "body": {
650
  "mode": "raw",
651
- "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"generate\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
652
  },
653
  "url": {
654
  "raw": "{{base_url}}/kg-anomaly/correct",
@@ -659,7 +697,16 @@
659
  }
660
  },
661
  {
662
- "name": "POST /kg-anomaly/correct (multiprox, future)",
663
  "request": {
664
  "method": "POST",
665
  "header": [
@@ -667,18 +714,27 @@
667
  ],
668
  "body": {
669
  "mode": "raw",
670
- "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"multiprox\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3}\n ]\n },\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"m\": 10,\n \"t\": 0.5,\n \"t_prime\": 0.1\n }\n}"
671
  },
672
  "url": {
673
  "raw": "{{base_url}}/kg-anomaly/correct",
674
  "host": ["{{base_url}}"],
675
  "path": ["kg-anomaly", "correct"]
676
  },
677
- "description": "MultiProx correction. Returns step 0 + state."
678
  }
679
  },
680
  {
681
  "name": "POST /kg-anomaly/continue",
682
  "request": {
683
  "method": "POST",
684
  "header": [
@@ -686,14 +742,14 @@
686
  ],
687
  "body": {
688
  "mode": "raw",
689
- "raw": "{\n \"state\": \"<paste state from previous response>\"\n}"
690
  },
691
  "url": {
692
  "raw": "{{base_url}}/kg-anomaly/continue",
693
  "host": ["{{base_url}}"],
694
  "path": ["kg-anomaly", "continue"]
695
  },
696
- "description": "Advance MultiProx correction by one step."
697
  }
698
  }
699
  ]
 
612
  { "key": "count", "value": "3" }
613
  ]
614
  },
615
+ "description": "Pre-computed example subgraphs for correction (clean)."
616
+ }
617
+ },
618
+ {
619
+ "name": "GET /kg-anomaly/datasets/{id}/sample-subgraphs (noised, correct)",
620
+ "request": {
621
+ "method": "GET",
622
+ "header": [],
623
+ "url": {
624
+ "raw": "{{base_url}}/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3&noise_level=0.4&task=correct&seed=42",
625
+ "host": ["{{base_url}}"],
626
+ "path": ["kg-anomaly", "datasets", "wordnet", "sample-subgraphs"],
627
+ "query": [
628
+ { "key": "count", "value": "3" },
629
+ { "key": "noise_level", "value": "0.4" },
630
+ { "key": "task", "value": "correct" },
631
+ { "key": "seed", "value": "42" }
632
+ ]
633
+ },
634
+ "description": "Pre-noised example subgraphs for the 'correct' task (only inpaint-mask region is corrupted)."
635
+ }
636
+ },
637
+ {
638
+ "name": "GET /kg-anomaly/datasets/{id}/sample-subgraphs (noised, generate)",
639
+ "request": {
640
+ "method": "GET",
641
+ "header": [],
642
+ "url": {
643
+ "raw": "{{base_url}}/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3&noise_level=0.4&task=generate&seed=43",
644
+ "host": ["{{base_url}}"],
645
+ "path": ["kg-anomaly", "datasets", "wordnet", "sample-subgraphs"],
646
+ "query": [
647
+ { "key": "count", "value": "3" },
648
+ { "key": "noise_level", "value": "0.4" },
649
+ { "key": "task", "value": "generate" },
650
+ { "key": "seed", "value": "43" }
651
+ ]
652
+ },
653
+ "description": "Pre-noised example subgraphs for the 'generate' task (all edges corrupted)."
654
  }
655
  }
656
  ]
 
667
  ],
668
  "body": {
669
  "mode": "raw",
670
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 28155, \"type_id\": 1},\n {\"entity_id\": 29348, \"type_id\": 4},\n {\"entity_id\": 29358, \"type_id\": 1},\n {\"entity_id\": 36247, \"type_id\": 1},\n {\"entity_id\": 36248, \"type_id\": 4},\n {\"entity_id\": 36855, \"type_id\": 1},\n {\"entity_id\": 36858, \"type_id\": 4},\n {\"entity_id\": 36860, \"type_id\": 4},\n {\"entity_id\": 36881, \"type_id\": 1},\n {\"entity_id\": 39993, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 1, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 6, \"relation_id\": 2},\n {\"source_idx\": 5, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 8, \"relation_id\": 5},\n {\"source_idx\": 5, \"target_idx\": 9, \"relation_id\": 4},\n {\"source_idx\": 6, \"target_idx\": 4, \"relation_id\": 3},\n {\"source_idx\": 6, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 7, \"target_idx\": 5, \"relation_id\": 10},\n {\"source_idx\": 7, \"target_idx\": 6, \"relation_id\": 10},\n {\"source_idx\": 8, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 6, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 9, \"relation_id\": 1},\n {\"source_idx\": 9, \"target_idx\": 0, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 6, \"relation_id\": 
10},\n {\"source_idx\": 9, \"target_idx\": 7, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 8, \"relation_id\": 10}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
671
  },
672
  "url": {
673
  "raw": "{{base_url}}/kg-anomaly/correct",
 
686
  ],
687
  "body": {
688
  "mode": "raw",
689
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"generate\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 28155, \"type_id\": 1},\n {\"entity_id\": 29348, \"type_id\": 4},\n {\"entity_id\": 29358, \"type_id\": 1},\n {\"entity_id\": 36247, \"type_id\": 1},\n {\"entity_id\": 36248, \"type_id\": 4},\n {\"entity_id\": 36855, \"type_id\": 1},\n {\"entity_id\": 36858, \"type_id\": 4},\n {\"entity_id\": 36860, \"type_id\": 4},\n {\"entity_id\": 36881, \"type_id\": 1},\n {\"entity_id\": 39993, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 2, \"relation_id\": 10},\n {\"source_idx\": 0, \"target_idx\": 4, \"relation_id\": 2},\n {\"source_idx\": 1, \"target_idx\": 0, \"relation_id\": 0},\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 2},\n {\"source_idx\": 1, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 1, \"target_idx\": 5, \"relation_id\": 9},\n {\"source_idx\": 1, \"target_idx\": 7, \"relation_id\": 7},\n {\"source_idx\": 1, \"target_idx\": 8, \"relation_id\": 3},\n {\"source_idx\": 2, \"target_idx\": 0, \"relation_id\": 7},\n {\"source_idx\": 2, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 4, \"relation_id\": 5},\n {\"source_idx\": 2, \"target_idx\": 8, \"relation_id\": 10},\n {\"source_idx\": 2, \"target_idx\": 9, \"relation_id\": 2},\n {\"source_idx\": 3, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 5, \"relation_id\": 7},\n {\"source_idx\": 3, \"target_idx\": 6, \"relation_id\": 6},\n {\"source_idx\": 3, \"target_idx\": 7, \"relation_id\": 0},\n {\"source_idx\": 4, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 3, \"relation_id\": 6},\n {\"source_idx\": 4, \"target_idx\": 6, \"relation_id\": 7},\n {\"source_idx\": 4, \"target_idx\": 7, \"relation_id\": 7},\n {\"source_idx\": 5, \"target_idx\": 4, \"relation_id\": 
2},\n {\"source_idx\": 5, \"target_idx\": 6, \"relation_id\": 2},\n {\"source_idx\": 5, \"target_idx\": 7, \"relation_id\": 6},\n {\"source_idx\": 5, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 9, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 0, \"relation_id\": 5},\n {\"source_idx\": 6, \"target_idx\": 3, \"relation_id\": 7},\n {\"source_idx\": 6, \"target_idx\": 4, \"relation_id\": 3},\n {\"source_idx\": 6, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 7, \"relation_id\": 0},\n {\"source_idx\": 6, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 9, \"relation_id\": 4},\n {\"source_idx\": 7, \"target_idx\": 2, \"relation_id\": 10},\n {\"source_idx\": 7, \"target_idx\": 5, \"relation_id\": 5},\n {\"source_idx\": 7, \"target_idx\": 6, \"relation_id\": 5},\n {\"source_idx\": 8, \"target_idx\": 0, \"relation_id\": 0},\n {\"source_idx\": 8, \"target_idx\": 2, \"relation_id\": 6},\n {\"source_idx\": 8, \"target_idx\": 4, \"relation_id\": 10},\n {\"source_idx\": 8, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 6, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 9, \"relation_id\": 0},\n {\"source_idx\": 9, \"target_idx\": 0, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 1, \"relation_id\": 8},\n {\"source_idx\": 9, \"target_idx\": 3, \"relation_id\": 5},\n {\"source_idx\": 9, \"target_idx\": 7, \"relation_id\": 5},\n {\"source_idx\": 9, \"target_idx\": 8, \"relation_id\": 1}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
690
  },
691
  "url": {
692
  "raw": "{{base_url}}/kg-anomaly/correct",
 
697
  }
698
  },
699
  {
700
+ "name": "POST /kg-anomaly/correct (multiprox init)",
701
+ "event": [
702
+ {
703
+ "listen": "test",
704
+ "script": {
705
+ "type": "text/javascript",
706
+ "exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
707
+ }
708
+ }
709
+ ],
710
  "request": {
711
  "method": "POST",
712
  "header": [
 
714
  ],
715
  "body": {
716
  "mode": "raw",
717
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"multiprox\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 28155, \"type_id\": 1},\n {\"entity_id\": 29348, \"type_id\": 4},\n {\"entity_id\": 29358, \"type_id\": 1},\n {\"entity_id\": 36247, \"type_id\": 1},\n {\"entity_id\": 36248, \"type_id\": 4},\n {\"entity_id\": 36855, \"type_id\": 1},\n {\"entity_id\": 36858, \"type_id\": 4},\n {\"entity_id\": 36860, \"type_id\": 4},\n {\"entity_id\": 36881, \"type_id\": 1},\n {\"entity_id\": 39993, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 1, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 6, \"relation_id\": 2},\n {\"source_idx\": 5, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 8, \"relation_id\": 5},\n {\"source_idx\": 5, \"target_idx\": 9, \"relation_id\": 4},\n {\"source_idx\": 6, \"target_idx\": 4, \"relation_id\": 3},\n {\"source_idx\": 6, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 7, \"target_idx\": 5, \"relation_id\": 10},\n {\"source_idx\": 7, \"target_idx\": 6, \"relation_id\": 10},\n {\"source_idx\": 8, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 6, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 9, \"relation_id\": 1},\n {\"source_idx\": 9, \"target_idx\": 0, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 6, \"relation_id\": 
10},\n {\"source_idx\": 9, \"target_idx\": 7, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 8, \"relation_id\": 10}\n ]\n },\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.4,\n \"t_prime\": 0.1,\n \"gibbs_chain_freq\": 10\n }\n}"
718
  },
719
  "url": {
720
  "raw": "{{base_url}}/kg-anomaly/correct",
721
  "host": ["{{base_url}}"],
722
  "path": ["kg-anomaly", "correct"]
723
  },
724
+ "description": "MultiProx Gibbs init on wordnet correction. SSE stream; the result event's state blob is auto-saved to {{multiprox_state}}."
725
  }
726
  },
727
  {
728
  "name": "POST /kg-anomaly/continue",
729
+ "event": [
730
+ {
731
+ "listen": "test",
732
+ "script": {
733
+ "type": "text/javascript",
734
+ "exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) {", " pm.collectionVariables.set('multiprox_state', result.state);", " console.log('State updated (done=' + result.done + ', step=' + result.step + ')');", " }", " } catch (e) {}", " break;", " }", "}"]
735
+ }
736
+ }
737
+ ],
738
  "request": {
739
  "method": "POST",
740
  "header": [
 
742
  ],
743
  "body": {
744
  "mode": "raw",
745
+ "raw": "{\n \"state\": \"{{multiprox_state}}\"\n}"
746
  },
747
  "url": {
748
  "raw": "{{base_url}}/kg-anomaly/continue",
749
  "host": ["{{base_url}}"],
750
  "path": ["kg-anomaly", "continue"]
751
  },
752
+ "description": "Advance MultiProx correction by one step. Uses {{multiprox_state}}; can be chained repeatedly."
753
  }
754
  }
755
  ]
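The auto-chaining that the Postman test scripts perform (save `state` from the `result` event, send it back to `/kg-anomaly/continue`) can be expressed outside Postman roughly as follows. This is a sketch under assumptions: `post` is a hypothetical callable that sends the request and returns the decoded `result` payload, and the payload is assumed to carry `state`, `step`, and `done` fields, as the script's log line suggests.

```python
from typing import Callable, List

# Hypothetical transport: POSTs `body` to `path`, reads the SSE stream,
# and returns the decoded payload of the terminal `result` event.
Post = Callable[[str, dict], dict]


def run_multiprox(post: Post, init_body: dict, max_steps: int = 50) -> List[dict]:
    """Start a MultiProx session and advance it until `done` or `max_steps` continues."""
    results = [post("/kg-anomaly/correct", init_body)]
    while len(results) <= max_steps and not results[-1].get("done", False):
        # The API is stateless: the opaque blob from the last step is the only session handle.
        state = results[-1]["state"]
        results.append(post("/kg-anomaly/continue", {"state": state}))
    return results
```

Capping the loop with `max_steps` guards against a payload that never reports `done`.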
src/backend/README.md CHANGED
@@ -88,17 +88,17 @@ All endpoints are prefixed with `/api/v1/`.
  |---|---|---|
  | `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
  | `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
- | `POST` | `/graph-generation/generate` | **Streaming NDJSON.** Generate a graph (standard denoising or MultiProx Gibbs init) |
- | `POST` | `/graph-generation/continue` | **Streaming NDJSON.** Advance a MultiProx Gibbs session by one step |
+ | `POST` | `/graph-generation/generate` | **Streaming SSE.** Generate a graph (standard denoising or MultiProx Gibbs init) |
+ | `POST` | `/graph-generation/continue` | **Streaming SSE.** Advance a MultiProx Gibbs session by one step |
 
  ### KG Anomaly Correction
 
  | Method | Path | Description |
  |---|---|---|
  | `GET` | `/kg-anomaly/datasets` | List datasets with correction models |
- | `GET` | `/kg-anomaly/datasets/{id}/sample-subgraphs` | Pre-computed example subgraphs (`?count=5`) |
- | `POST` | `/kg-anomaly/correct` | Run correction (not yet implemented) |
- | `POST` | `/kg-anomaly/continue` | Continue MultiProx correction (not yet implemented) |
+ | `GET` | `/kg-anomaly/datasets/{id}/sample-subgraphs` | Pre-computed example subgraphs (`?count=5&noise_level=0.4&task=correct&seed=42`); noise is task-aware |
+ | `POST` | `/kg-anomaly/correct` | **Streaming SSE.** Correct/regenerate a KG subgraph (standard denoising or MultiProx Gibbs init) |
+ | `POST` | `/kg-anomaly/continue` | **Streaming SSE.** Advance a MultiProx correction session by one step |
 
  ## Streaming Inference Protocol (SSE)
 
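
Reviewer note: the extended `sample-subgraphs` row above lists four query parameters. A client could assemble that URL roughly as sketched below. The helper name `sample_subgraphs_url` and the host in the usage note are assumptions; the `/api/v1` prefix comes from the README, and the `noise_level` range and the two task names mirror `apply_edge_noise` in this commit. Whether the server rejects other values the same way is not shown in this diff.

```python
from urllib.parse import quote, urlencode

API_PREFIX = "/api/v1"  # prefix stated in the backend README


def sample_subgraphs_url(base_url, dataset_id, count=5,
                         noise_level=None, task=None, seed=None):
    """Build the GET URL for pre-noised sample subgraphs (hypothetical helper)."""
    params = {"count": int(count)}
    if noise_level is not None:
        # Same range check that apply_edge_noise performs.
        if not (0.0 < noise_level <= 1.0):
            raise ValueError("noise_level must be in (0, 1]")
        params["noise_level"] = noise_level
    if task is not None:
        # Assumption: only the two tasks named in this commit are accepted.
        if task not in ("correct", "generate"):
            raise ValueError("task must be 'correct' or 'generate'")
        params["task"] = task
    if seed is not None:
        params["seed"] = int(seed)
    path = (f"{API_PREFIX}/kg-anomaly/datasets/"
            f"{quote(str(dataset_id), safe='')}/sample-subgraphs")
    return f"{base_url.rstrip('/')}{path}?{urlencode(params)}"
```

For example, `sample_subgraphs_url("http://localhost:8000", "wordnet", noise_level=0.4, task="correct", seed=42)` reproduces the query string shown in the table, with `localhost:8000` standing in for wherever the backend is actually served.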
src/backend/api/services/kg_anomaly_inference.py ADDED
@@ -0,0 +1,708 @@
1
+ import base64
2
+ import io
3
+ import time
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ from api.services.graphgen_inference import (
9
+ _frames_to_gif_b64, _pil_to_b64,
10
+ )
11
+
12
+ STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024 # 10 MB
13
+ REQUIRED_STATE_KEYS = {
14
+ "X_given", "E", "y", "n_nodes", "dataset_id", "task", "X_index", "X_c",
15
+ "is_bip", "original_E_int", "T", "n", "m", "t", "t_prime",
16
+ "gibbs_chain_freq", "inner_step", "step",
17
+ }
18
+
19
+ CHANGE_COLORS = {
20
+ "unchanged": "#888888",
21
+ "modified": "#e67e22",
22
+ "added": "#27ae60",
23
+ "removed": "#e74c3c",
24
+ }
25
+
26
+
27
+ # ---------------------------------------------------------------------------
28
+ # Input subgraph -> tensors
29
+ # ---------------------------------------------------------------------------
30
+
31
+ def build_kg_tensors(subgraph, loader, model):
32
+ """Convert API subgraph payload into model-ready tensors.
33
+
34
+ Returns a dict with X_given, E_given, y_given, X_index, X_c, n_nodes,
35
+ is_bip, node_mask — all on CPU, batch size 1.
36
+ """
37
+ nodes = subgraph["nodes"]
38
+ edges = subgraph["edges"]
39
+ n = len(nodes)
40
+
41
+ X_given = torch.zeros(1, n, model.Xdim_output, dtype=torch.float32)
42
+ for i, node in enumerate(nodes):
43
+ type_id = int(node.get("type_id", 0))
44
+ if 0 <= type_id < model.Xdim_output:
45
+ X_given[0, i, type_id] = 1.0
46
+ else:
47
+ X_given[0, i, 0] = 1.0
48
+
49
+ E_given = torch.zeros(1, n, n, model.Edim_output, dtype=torch.float32)
50
+ # Default to class 0 ("no edge") everywhere
51
+ E_given[0, :, :, 0] = 1.0
52
+ for e in edges:
53
+ src = int(e["source_idx"])
54
+ tgt = int(e["target_idx"])
55
+ rel = int(e["relation_id"])
56
+ e_class = rel + 1
57
+ if not (0 <= src < n and 0 <= tgt < n):
58
+ continue
59
+ if not (1 <= e_class < model.Edim_output):
60
+ continue
61
+ E_given[0, src, tgt, :] = 0.0
62
+ E_given[0, src, tgt, e_class] = 1.0
63
+
64
+ y_given = torch.zeros(1, 0, dtype=torch.float32)
65
+
66
+ X_index = torch.zeros(1, n, dtype=torch.long)
67
+ for i, node in enumerate(nodes):
68
+ X_index[0, i] = int(node["entity_id"])
69
+
70
+ X_c = torch.zeros(1, n, dtype=torch.long)
71
+ communities = getattr(loader, "communities", None)
72
+ if communities is not None:
73
+ for i, node in enumerate(nodes):
74
+ eid = int(node["entity_id"])
75
+ if 0 <= eid < len(communities):
76
+ X_c[0, i] = int(communities[eid])
77
+
78
+ n_nodes = torch.tensor([n], dtype=torch.long)
79
+ is_bip = torch.tensor([n > 20], dtype=torch.bool)
80
+ node_mask = torch.ones(1, n, dtype=torch.bool)
81
+
82
+ return {
83
+ "X_given": X_given, "E_given": E_given, "y_given": y_given,
84
+ "X_index": X_index, "X_c": X_c, "n_nodes": n_nodes,
85
+ "is_bip": is_bip, "node_mask": node_mask,
86
+ }
87
+
88
+
89
+ def _to_device(t, device):
90
+ return t.to(device) if isinstance(t, torch.Tensor) else t
91
+
92
+
93
+ def apply_edge_noise(model, tensors, task, noise_level, seed=None):
94
+ """Forward-diffuse the given subgraph's edges at t = noise_level * T.
95
+
96
+ For task="correct", only edges inside the inpaint mask (the second half of
97
+ nodes) are noised, matching what the correction endpoint will regenerate.
98
+ For task="generate", every edge slot is noised.
99
+
100
+ Returns a new list of {source_idx, target_idx, relation_id} dicts.
101
+ """
102
+ from graph_generation.src.utils import get_inpaint_mask
103
+ from graph_generation.src.diffusion import diffusion_utils
104
+
105
+ if not (0.0 < noise_level <= 1.0):
106
+ raise ValueError("noise_level must be in (0, 1]")
107
+
108
+ device = next(model.parameters()).device
109
+ X = tensors["X_given"].to(device)
110
+ E = tensors["E_given"].to(device)
111
+ node_mask = tensors["node_mask"].to(device)
112
+ is_bip = tensors["is_bip"].to(device)
113
+ n = int(tensors["n_nodes"].item())
114
+
115
+ if task == "generate":
116
+ bs, n_max = node_mask.shape
117
+ inpaint_mask = torch.ones(
118
+ bs, n_max, n_max, model.Edim_output, dtype=torch.bool, device=device)
119
+ else:
120
+ inpaint_mask = get_inpaint_mask(node_mask, is_bip, model.Edim_output, device)
121
+
122
+ T = model.T
123
+ t_int = torch.tensor([[int(noise_level * T)]], dtype=torch.float, device=device)
124
+ t_float = t_int / T
125
+ alpha_t_bar = model.noise_schedule.get_alpha_bar(t_normalized=t_float)
126
+ Qtb = model.transition_model.get_Qt_bar(alpha_t_bar, device=device)
127
+
128
+ probX = X @ Qtb.X
129
+ probE = E @ Qtb.E.unsqueeze(1)
130
+
131
+ if seed is not None:
132
+ torch.manual_seed(int(seed))
133
+
134
+ sampled = diffusion_utils.sample_discrete_features(
135
+ probX=probX, probE=probE, node_mask=node_mask)
136
+ E_noised = F.one_hot(sampled.E, num_classes=model.Edim_output).float()
137
+ E_mixed = E_noised * inpaint_mask + E * (~inpaint_mask)
138
+ E_int = E_mixed[0].argmax(dim=-1).cpu()
139
+
140
+ edges = []
141
+ for i in range(n):
142
+ for j in range(n):
143
+ if i == j:
144
+ continue
145
+ cls = int(E_int[i, j])
146
+ if cls == 0:
147
+ continue
148
+ edges.append({
149
+ "source_idx": i, "target_idx": j, "relation_id": cls - 1,
150
+ })
151
+ return edges
152
+
153
+
154
+ # ---------------------------------------------------------------------------
155
+ # Change detection
156
+ # ---------------------------------------------------------------------------
157
+
158
+ def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
159
+ """Compute before/after edge diff for a directed KG subgraph.
160
+
161
+ original_E_int / corrected_E_int: 2-D int tensors (n, n) where 0 = no edge
162
+ and classes 1..N are relation types. Returns {"edges": [...], "summary": {...}}.
163
+ """
164
+ _, _, inv_relations = loader.dataset.get_inverted_name_maps()
165
+
166
+ edges = []
167
+ summary = {"added": 0, "removed": 0, "modified": 0, "unchanged": 0}
168
+
169
+ orig = original_E_int.cpu().tolist()
170
+ corr = corrected_E_int.cpu().tolist()
171
+
172
+ for i in range(num_nodes):
173
+ for j in range(num_nodes):
174
+ if i == j:
175
+ continue
176
+ o = int(orig[i][j])
177
+ c = int(corr[i][j])
178
+ if o == 0 and c == 0:
179
+ continue
180
+ if o == c:
181
+ summary["unchanged"] += 1
182
+ edges.append({
183
+ "source_idx": i, "target_idx": j, "change": "unchanged",
184
+ "relation_id": c - 1,
185
+ "relation_name": str(inv_relations.get(c - 1, c - 1)),
186
+ })
187
+ continue
188
+ if o == 0 and c > 0:
189
+ summary["added"] += 1
190
+ edges.append({
191
+ "source_idx": i, "target_idx": j, "change": "added",
192
+ "relation_id": c - 1,
193
+ "relation_name": str(inv_relations.get(c - 1, c - 1)),
194
+ })
195
+ elif o > 0 and c == 0:
196
+ summary["removed"] += 1
197
+ edges.append({
198
+ "source_idx": i, "target_idx": j, "change": "removed",
199
+ "original_relation_id": o - 1,
200
+ "original_relation_name": str(inv_relations.get(o - 1, o - 1)),
201
+ })
202
+ else:
203
+ summary["modified"] += 1
204
+ edges.append({
205
+ "source_idx": i, "target_idx": j, "change": "modified",
206
+ "original_relation_id": o - 1,
207
+ "original_relation_name": str(inv_relations.get(o - 1, o - 1)),
208
+ "relation_id": c - 1,
209
+ "relation_name": str(inv_relations.get(c - 1, c - 1)),
210
+ })
211
+
212
+ return {"edges": edges, "summary": summary}
213
+
214
+
215
+ # ---------------------------------------------------------------------------
216
+ # Rendering
217
+ # ---------------------------------------------------------------------------
218
+
219
+ def _format_entity_label(dataset_id, name):
220
+ s = str(name)
221
+ if dataset_id == "freebase":
222
+ s = s.replace("/m/", "")
223
+ elif dataset_id == "wordnet":
224
+ s = s.split(".")[0]
225
+ else:
226
+ if "concept" in s:
227
+ parts = s.split(":")
228
+ s = parts[-2] if "new" in s and len(parts) >= 2 else parts[-1]
229
+ if len(s) > 14:
230
+ s = s[:13] + "…"
231
+ return s
232
+
233
+
234
+ def _format_relation_label(dataset_id, name):
235
+ s = str(name)
236
+ if dataset_id == "freebase":
237
+ parts = s.split(".")
238
+ s = ".".join(["_".join(p.split("/")[-2:]) for p in parts])
239
+ elif dataset_id == "wordnet":
240
+ s = s[1:] if s.startswith("_") else s
241
+ else:
242
+ if "concept" in s:
243
+ parts = s.split(":")
244
+ s = parts[-2] if "new" in s and len(parts) >= 2 else parts[-1]
245
+ if len(s) > 16:
246
+ s = s[:15] + "…"
247
+ return s
248
+
249
+
250
+ def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=None):
251
+ """Render a directed KG subgraph as a PIL image using networkx + PIL.
252
+
253
+ Does not use matplotlib (same reason as graphgen_inference: Windows thread safety).
254
+ """
255
+ import networkx as nx
256
+ from PIL import Image, ImageDraw, ImageFont
257
+
258
+ inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
259
+
260
+ e = E_int.cpu().tolist()
261
+ xi = X_index.cpu().tolist()
262
+
263
+ G = nx.DiGraph()
264
+ for i in range(num_nodes):
265
+ G.add_node(i)
266
+ for i in range(num_nodes):
267
+ for j in range(num_nodes):
268
+ if i == j:
269
+ continue
270
+ if int(e[i][j]) > 0:
271
+ G.add_edge(i, j, rel=int(e[i][j]) - 1)
272
+
273
+ pos = nx.spring_layout(G, seed=42)
274
+
275
+ # Build change lookup: (i, j) -> change_type
276
+ change_lookup = {}
277
+ if changes is not None:
278
+ for entry in changes.get("edges", []):
279
+ change_lookup[(entry["source_idx"], entry["target_idx"])] = entry["change"]
280
+
281
+ size = 500
282
+ margin = 50
283
+ scale = (size - 2 * margin) / 2
284
+ cx, cy = size / 2, size / 2
285
+ pixel_pos = {k: (cx + v[0] * scale, cy + v[1] * scale) for k, v in pos.items()}
286
+
287
+ img = Image.new("RGB", (size, size), "white")
288
+ draw = ImageDraw.Draw(img)
289
+ try:
290
+ font = ImageFont.truetype("arial.ttf", 11)
291
+ small_font = ImageFont.truetype("arial.ttf", 9)
292
+ except (OSError, IOError):
293
+ font = ImageFont.load_default()
294
+ small_font = font
295
+
296
+ node_r = 10
297
+
298
+ # Draw edges first (so nodes overlay them)
299
+ # Include "removed" edges from change_lookup even if not in G
300
+ all_edges = set((i, j) for i, j in G.edges())
301
+ if changes is not None:
302
+ for (i, j), ct in change_lookup.items():
303
+ if ct == "removed":
304
+ all_edges.add((i, j))
305
+
306
+ for (i, j) in all_edges:
307
+ change_type = change_lookup.get((i, j))
308
+ color = CHANGE_COLORS.get(change_type, "#444444") if changes is not None else "#444444"
309
+ dashed = (change_type == "removed")
310
+ x0, y0 = pixel_pos[i]
311
+ x1, y1 = pixel_pos[j]
312
+ # Shorten line to not overlap node circles
313
+ dx, dy = x1 - x0, y1 - y0
314
+ dist = max(1.0, (dx * dx + dy * dy) ** 0.5)
315
+ ux, uy = dx / dist, dy / dist
316
+ sx, sy = x0 + ux * node_r, y0 + uy * node_r
317
+ ex, ey = x1 - ux * node_r, y1 - uy * node_r
318
+ if dashed:
319
+ _draw_dashed(draw, (sx, sy), (ex, ey), color, width=2, dash=6)
320
+ else:
321
+ draw.line([(sx, sy), (ex, ey)], fill=color, width=2)
322
+ # Arrowhead
323
+ _draw_arrowhead(draw, (ex, ey), (ux, uy), color)
324
+ # Relation label
325
+ if (i, j) in G.edges():
326
+ rel_id = G.edges[(i, j)]["rel"]
327
+ rel_name = _format_relation_label(dataset_id, inv_relations.get(rel_id, rel_id))
328
+ mx, my = (sx + ex) / 2, (sy + ey) / 2
329
+ draw.text((mx + 3, my - 5), rel_name, fill=color, font=small_font)
330
+
331
+ # Draw nodes
332
+ for i in range(num_nodes):
333
+ x, y = pixel_pos[i]
334
+ draw.ellipse([x - node_r, y - node_r, x + node_r, y + node_r],
335
+ fill="#2ecc71", outline="#1a7a42")
336
+ eid = int(xi[i]) if i < len(xi) else i
337
+ label = _format_entity_label(dataset_id, inv_nodes.get(eid, eid))
338
+ draw.text((x + node_r + 2, y - 6), label, fill="#111111", font=font)
339
+
340
+ return img
341
+
342
+
343
+ def _draw_arrowhead(draw, tip, direction, color):
344
+ import math
345
+ ux, uy = direction
346
+ angle = math.atan2(uy, ux)
347
+ ah_len = 7
348
+ ah_angle = math.radians(25)
349
+ x, y = tip
350
+ x1 = x - ah_len * math.cos(angle - ah_angle)
351
+ y1 = y - ah_len * math.sin(angle - ah_angle)
352
+ x2 = x - ah_len * math.cos(angle + ah_angle)
353
+ y2 = y - ah_len * math.sin(angle + ah_angle)
354
+ draw.polygon([(x, y), (x1, y1), (x2, y2)], fill=color)
355
+
356
+
357
+ def _draw_dashed(draw, start, end, color, width=2, dash=6):
358
+ x0, y0 = start
359
+ x1, y1 = end
360
+ dx, dy = x1 - x0, y1 - y0
361
+ dist = max(1.0, (dx * dx + dy * dy) ** 0.5)
362
+ steps = int(dist // dash)
363
+ ux, uy = dx / dist, dy / dist
364
+ for k in range(steps):
365
+ if k % 2 == 1:
366
+ continue
367
+ sx = x0 + ux * dash * k
368
+ sy = y0 + uy * dash * k
369
+ ex = x0 + ux * dash * min(k + 1, steps)
370
+ ey = y0 + uy * dash * min(k + 1, steps)
371
+ draw.line([(sx, sy), (ex, ey)], fill=color, width=width)
372
+
373
+
374
+ # ---------------------------------------------------------------------------
375
+ # Shared inference helpers
376
+ # ---------------------------------------------------------------------------
377
+
378
+ def _build_inpaint_mask(task, node_mask, is_bip, E_out_dim, device):
379
+ from graph_generation.src.utils import get_inpaint_mask
380
+ if task == "generate":
381
+ bs, n_max = node_mask.shape
382
+ return torch.ones(bs, n_max, n_max, E_out_dim, dtype=torch.bool, device=device)
383
+ return get_inpaint_mask(node_mask, is_bip, E_out_dim, device)
384
+
385
+
386
+ def _sample_initial_noise_kg(model, node_mask):
387
+ from graph_generation.src.diffusion import diffusion_utils
388
+ return diffusion_utils.sample_discrete_feature_noise(
389
+ limit_dist=model.limit_dist, node_mask=node_mask)
390
+
391
+
392
+ def _collapse_final_kg(model, X, E, y, node_mask):
393
+ from graph_generation.src.utils import PlaceHolder
394
+ final = PlaceHolder(X=X, E=E, y=y).mask(node_mask, collapse=True)
395
+ return final.X.long(), final.E.long()
396
+
397
+
398
+ # ---------------------------------------------------------------------------
399
+ # Standard correction / generation
400
+ # ---------------------------------------------------------------------------
401
+
402
+ def run_standard_correction(model, tensors, dataset_id, task, loader,
403
+ diffusion_steps, chain_frames):
404
+ device = next(model.parameters()).device
405
+ X_given = tensors["X_given"].to(device)
406
+ E_given = tensors["E_given"].to(device)
407
+ y_given = tensors["y_given"].to(device)
408
+ X_index = tensors["X_index"].to(device)
409
+ is_bip = tensors["is_bip"].to(device)
410
+ n_nodes = tensors["n_nodes"].to(device)
411
+ node_mask = tensors["node_mask"].to(device)
412
+ n_max = n_nodes.item()
413
+
414
+ inpaint_mask = _build_inpaint_mask(
415
+ task, node_mask, is_bip, model.Edim_output, device)
416
+
417
+ original_E_int = E_given[0].argmax(dim=-1).long() # (n, n)
418
+ original_img = render_kg_subgraph(
419
+ original_E_int, n_max, X_index[0], dataset_id, loader, changes=None)
420
+
421
+ model_T = model.T
422
+ step_stride = max(1, model_T // diffusion_steps)
423
+ total_loop_steps = (model_T + step_stride - 1) // step_stride
424
+ frame_interval = max(1, total_loop_steps // chain_frames)
425
+
426
+ with torch.no_grad():
427
+ z_T = _sample_initial_noise_kg(model, node_mask)
428
+ if task != "generate":
429
+ z_T.E = z_T.E * inpaint_mask + E_given * (~inpaint_mask)
430
+ X, E, y = X_given, z_T.E, y_given
431
+
432
+ gif_frames = []
433
+ t0 = time.time()
434
+ emitted = 0
435
+ for s_idx in reversed(range(0, model_T, step_stride)):
436
+ t_idx = min(s_idx + step_stride, model_T)
437
+ s_t = (s_idx / model_T) * torch.ones((1, 1), device=device)
438
+ t_t = (t_idx / model_T) * torch.ones((1, 1), device=device)
439
+ sampled_s, discrete_s = model.sample_p_zs_given_zt(
440
+ s_t, t_t, X, E, y, X_index, node_mask, inpaint_mask)
441
+ X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
442
+ emitted += 1
443
+ is_frame = (emitted % frame_interval == 0) or (s_idx == 0)
444
+ E_int_prev = discrete_s.E[0].long()
445
+ event = {
446
+ "type": "progress",
447
+ "phase": "denoise",
448
+ "step": emitted,
449
+ "total_steps": total_loop_steps,
450
+ "elapsed_ms": int((time.time() - t0) * 1000),
451
+ }
452
+ if is_frame:
453
+ frame = render_kg_subgraph(
454
+ E_int_prev, n_max, X_index[0], dataset_id, loader)
455
+ gif_frames.append(frame)
456
+ event["preview"] = _pil_to_b64(frame)
457
+ yield event
458
+
459
+ X_final, E_final = _collapse_final_kg(model, X, E, y, node_mask)
460
+
461
+ corrected_E_int = E_final[0]
462
+ changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
463
+ corrected_img = render_kg_subgraph(
464
+ corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
465
+
466
+ elapsed_ms = int((time.time() - t0) * 1000)
467
+ yield {
468
+ "type": "result",
469
+ "original_image": _pil_to_b64(original_img),
470
+ "corrected_image": _pil_to_b64(corrected_img),
471
+ "chain_gif": _frames_to_gif_b64(gif_frames),
472
+ "changes": changes,
473
+ "inference_time_ms": elapsed_ms,
474
+ }
475
+
476
+
477
+ # ---------------------------------------------------------------------------
478
+ # MultiProx correction / generation
479
+ # ---------------------------------------------------------------------------
480
+
481
+ def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
482
+ n, m, t, t_prime, gibbs_chain_freq):
483
+ device = next(model.parameters()).device
484
+ X_given = tensors["X_given"].to(device)
485
+ E_given = tensors["E_given"].to(device)
486
+ y_given = tensors["y_given"].to(device)
487
+ X_index = tensors["X_index"].to(device)
488
+ X_c = tensors["X_c"].to(device)
489
+ is_bip = tensors["is_bip"].to(device)
490
+ n_nodes = tensors["n_nodes"].to(device)
491
+ node_mask = tensors["node_mask"].to(device)
492
+ n_max = n_nodes.item()
493
+
494
+ inpaint_mask = _build_inpaint_mask(
495
+ task, node_mask, is_bip, model.Edim_output, device)
496
+ original_E_int = E_given[0].argmax(dim=-1).long()
497
+ original_img = render_kg_subgraph(
498
+ original_E_int, n_max, X_index[0], dataset_id, loader, changes=None)
499
+
500
+ t0 = time.time()
501
+ # Sample initial noise for each of M Gibbs chains
502
+ z_samples = []
503
+ with torch.no_grad():
504
+ for i in range(m):
505
+ z_i = _sample_initial_noise_kg(model, node_mask)
506
+ if task != "generate":
507
+ z_i.E = z_i.E * inpaint_mask + E_given * (~inpaint_mask)
508
+ z_samples.append(z_i)
509
+ if (i + 1) % max(1, m // 10) == 0 or i == m - 1:
510
+ yield {
511
+ "type": "progress",
512
+ "phase": "noise_init",
513
+ "step": i + 1,
514
+ "total_steps": m,
515
+ "elapsed_ms": int((time.time() - t0) * 1000),
516
+ }
517
+
518
+ # Stack to (1, M, n, ...) tensors
519
+ E_ens = torch.stack([z.E for z in z_samples], dim=1) # (1, M, n, n, Edim)
520
+ y_ens = torch.stack([z.y for z in z_samples], dim=1) # (1, M, ydim)
521
+
522
+ # Aggregate for preview
523
+ agg_E = torch.median(E_ens, dim=1).values
524
+ agg_y = torch.median(y_ens.float(), dim=1).values
525
+ X_int, E_int = _collapse_final_kg(model, X_given, agg_E, agg_y, node_mask)
526
+ corrected_E_int = E_int[0]
527
+ changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
528
+ preview_img = render_kg_subgraph(
529
+ corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
530
+ elapsed_ms = int((time.time() - t0) * 1000)
531
+
532
+ state = {
533
+ "X_given": X_given.cpu(),
534
+ "E": E_ens.cpu(),
535
+ "y": y_ens.cpu(),
536
+ "n_nodes": n_nodes.cpu(),
537
+ "dataset_id": dataset_id,
538
+ "task": task,
539
+ "X_index": X_index.cpu(),
540
+ "X_c": X_c.cpu(),
541
+ "is_bip": is_bip.cpu(),
542
+ "original_E_int": original_E_int.cpu(),
543
+ "T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
544
+ "gibbs_chain_freq": gibbs_chain_freq,
545
+ "inner_step": 0, "step": 0,
546
+ }
547
+ yield {
548
+ "type": "result",
549
+ "state": state,
550
+ "original_image": _pil_to_b64(original_img),
551
+ "image": _pil_to_b64(preview_img),
552
+ "changes": changes,
553
+ "inference_time_ms": elapsed_ms,
554
+ }
555
+
556
+
557
+ def run_multiprox_correction_step(model, state, loader):
558
+ device = next(model.parameters()).device
559
+ dataset_id = state["dataset_id"]
560
+ task = state["task"]
561
+ X_given = state["X_given"].to(device)
562
+ E = state["E"].to(device)
563
+ y = state["y"].to(device)
564
+ X_index = state["X_index"].to(device)
565
+ is_bip = state["is_bip"].to(device)
566
+ n_nodes = state["n_nodes"].to(device)
567
+ original_E_int = state["original_E_int"].to(device)
568
+
569
+ T = state["T"]
570
+ n = state["n"]
571
+ m = state["m"]
572
+ t = state["t"]
573
+ t_prime = state["t_prime"]
574
+ gibbs_chain_freq = state["gibbs_chain_freq"]
575
+ inner_step = state["inner_step"]
576
+ step = state["step"]
577
+
578
+ n_max = int(n_nodes.item())
579
+ node_mask = torch.ones(1, n_max, dtype=torch.bool, device=device)
580
+ inpaint_mask = _build_inpaint_mask(task, node_mask, is_bip, model.Edim_output, device)
581
+
582
+ fixed_t_norm = t * torch.ones((1, 1), dtype=torch.float, device=device)
583
+ fixed_s_norm = fixed_t_norm - (1.0 / T)
584
+
585
+ steps_this_call = min(gibbs_chain_freq, m - inner_step)
586
+
587
+ t0 = time.time()
588
+ with torch.no_grad():
589
+ for i in range(steps_this_call):
590
+ k = inner_step + i
591
+ avg_E = torch.median(E, dim=1).values
592
+ avg_y = torch.median(y.float(), dim=1).values
593
+ denoised, _ = model.sample_p_zs_given_zt(
594
+ fixed_s_norm, fixed_t_norm, X_given, avg_E, avg_y,
595
+ X_index, node_mask, inpaint_mask)
596
+
597
+ old_t2 = model.gibbs_fixed_t_2
598
+ model.gibbs_fixed_t_2 = t # safe: inference lock held by registry
599
+ noisy = model.apply_noise(
600
+ denoised.X, denoised.E, denoised.y, node_mask, inpaint_mask, gibbs=True)
601
+ model.gibbs_fixed_t_2 = old_t2
602
+
603
+ E[:, k] = noisy["E_t"]
604
+ y[:, k] = noisy["y_t"]
605
+
606
+ # Preview aggregate state
607
+ prev_E = torch.median(E, dim=1).values
608
+ prev_y = torch.median(y.float(), dim=1).values
609
+ _, prev_Ei = _collapse_final_kg(model, X_given, prev_E, prev_y, node_mask)
610
+ preview_img = render_kg_subgraph(
611
+ prev_Ei[0], n_max, X_index[0], dataset_id, loader)
612
+ yield {
613
+ "type": "progress",
614
+ "phase": "gibbs",
615
+ "step": i + 1,
616
+ "total_steps": steps_this_call,
617
+ "elapsed_ms": int((time.time() - t0) * 1000),
618
+ "preview": _pil_to_b64(preview_img),
619
+ }
620
+
621
+ new_inner_step = inner_step + steps_this_call
622
+ round_complete = new_inner_step >= m
623
+ if round_complete:
624
+ new_inner_step = 0
625
+ new_step = step + 1
626
+ else:
627
+ new_step = step
628
+ done = round_complete and new_step >= n
629
+
630
+ # Refinement pass — always produce a clean render
631
+ P = int((t - t_prime) * T) + 1
632
+ P = max(P, 1)
633
+ refine_preview_interval = max(1, P // 10)
634
+ cur_E = torch.median(E, dim=1).values
635
+ cur_y = torch.median(y.float(), dim=1).values
636
+ cur_X = X_given
637
+ for j in range(P):
638
+ s_ref = (t - (j + 1) / T) * torch.ones((1, 1), dtype=torch.float, device=device)
639
+ t_ref = (t - j / T) * torch.ones((1, 1), dtype=torch.float, device=device)
640
+ sampled, discrete_s = model.sample_p_zs_given_zt(
641
+ s_ref, t_ref, cur_X, cur_E, cur_y, X_index, node_mask, inpaint_mask)
642
+ cur_X, cur_E, cur_y = sampled.X, sampled.E, sampled.y
643
+ is_frame = (j + 1) % refine_preview_interval == 0 or j == P - 1
644
+ event = {
645
+ "type": "progress",
646
+ "phase": "refine",
647
+ "step": j + 1,
648
+ "total_steps": P,
649
+ "elapsed_ms": int((time.time() - t0) * 1000),
650
+ }
651
+ if is_frame:
652
+ event["preview"] = _pil_to_b64(render_kg_subgraph(
653
+ discrete_s.E[0].long(), n_max, X_index[0], dataset_id, loader))
654
+ yield event
655
+
656
+ X_int, E_int = _collapse_final_kg(model, cur_X, cur_E, cur_y, node_mask)
657
+
658
+ corrected_E_int = E_int[0]
659
+ changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
660
+ corrected_img = render_kg_subgraph(
661
+ corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
662
+ elapsed_ms = int((time.time() - t0) * 1000)
663
+
664
+ updated_state = {
665
+ **state,
666
+ "E": E.cpu(), "y": y.cpu(),
667
+ "step": new_step, "inner_step": new_inner_step,
668
+ }
669
+ yield {
670
+ "type": "result",
671
+ "state": updated_state,
672
+ "image": _pil_to_b64(corrected_img),
673
+ "changes": changes,
674
+ "round_complete": round_complete,
675
+ "done": done,
676
+ "inference_time_ms": elapsed_ms,
677
+ }
678
+
679
+
680
+ # ---------------------------------------------------------------------------
681
+ # State blob serialisation
682
+ # ---------------------------------------------------------------------------
683
+
684
+ def encode_state_blob(state):
685
+ buf = io.BytesIO()
686
+ torch.save(state, buf)
687
+ return base64.b64encode(buf.getvalue()).decode("ascii")
688
+
689
+
+ def decode_state_blob(b64_str):
+     """Decode and validate a client-supplied MultiProx state blob."""
+     try:
+         raw = base64.b64decode(b64_str)
+     except Exception as exc:
+         raise ValueError("state is not valid base64") from exc
+     if len(raw) > STATE_BLOB_MAX_BYTES:
+         raise ValueError(f"state blob exceeds {STATE_BLOB_MAX_BYTES // (1024 * 1024)} MB limit")
+     try:
+         # The blob round-trips through the client, so restrict unpickling to
+         # tensors and plain containers instead of arbitrary objects.
+         state = torch.load(io.BytesIO(raw), weights_only=True)
+     except Exception as exc:
+         raise ValueError(f"state could not be deserialized: {exc}") from exc
+     if not isinstance(state, dict):
+         raise ValueError("state must be a dict")
+     missing = REQUIRED_STATE_KEYS - set(state.keys())
+     if missing:
+         raise ValueError(f"state missing keys: {missing}")
+     if not isinstance(state["E"], torch.Tensor) or state["E"].dim() != 5:
+         raise ValueError("state['E'] must be a 5-D tensor")
+     if not isinstance(state["X_given"], torch.Tensor) or state["X_given"].dim() != 3:
+         raise ValueError("state['X_given'] must be a 3-D tensor")
+     return state
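
Reviewer note: the `inner_step` / `step` bookkeeping in `run_multiprox_correction_step` determines how many `/kg-anomaly/continue` calls a client needs before `done` becomes true. The sketch below replays only that counter logic in plain Python, with no model or tensors. The function names `advance` and `calls_until_done` are hypothetical, and the argument guard is an addition that the original code does not perform.

```python
def advance(step, inner_step, n, m, gibbs_chain_freq):
    """Replay the counter update from run_multiprox_correction_step."""
    steps_this_call = min(gibbs_chain_freq, m - inner_step)
    new_inner_step = inner_step + steps_this_call
    round_complete = new_inner_step >= m
    if round_complete:
        # A full sweep over the m ensemble slots finished: start the next round.
        new_inner_step = 0
        new_step = step + 1
    else:
        new_step = step
    done = round_complete and new_step >= n
    return new_step, new_inner_step, round_complete, done


def calls_until_done(n, m, gibbs_chain_freq):
    """Count continue calls until `done`, starting from the init state."""
    # Guard added for this sketch only; without it a frequency of 0 never ends.
    if n < 1 or m < 1 or gibbs_chain_freq < 1:
        raise ValueError("n, m and gibbs_chain_freq must all be >= 1")
    step, inner_step, calls, done = 0, 0, 0, False
    while not done:
        step, inner_step, _, done = advance(step, inner_step, n, m, gibbs_chain_freq)
        calls += 1
    return calls
```

With the parameters from the Postman example body (`n=10`, `m=100`, `gibbs_chain_freq=10`), each round takes 10 calls and there are 10 rounds, so this sketch counts 100 continue calls. Each of those calls also runs the refinement pass, so the total cost is higher than the Gibbs steps alone suggest.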
src/backend/api/services/registry.py CHANGED
@@ -276,6 +276,7 @@ class ModelRegistry:
276
  self._coins_experiments = {} # (dataset_id, algorithm) -> Experiment
277
  self._coins_loaders = {} # (dataset_id, seed, leiden_resolution) -> full Loader
278
  self._graphgen_models = {} # (dataset_id, model_type) -> loaded eval-mode model
 
279
 
280
  def force_release_inference_lock(self):
281
  """Emergency release for a stuck inference lock (e.g. client disconnect)."""
@@ -465,9 +466,12 @@ class ModelRegistry:
465
  seed=seed, device="cpu", val_size=0.01, test_size=0.02,
466
  community_method="leiden", leiden_resolution=leiden_resolution,
467
  )
468
- # Free heavy arrays not needed for discovery endpoints
469
- _free_heavy_arrays(loader)
470
  self.loaders[dataset_id] = loader
471
  logger.info(
472
  "Loader ready for %s: %d entities, %d relations, %d train triples",
473
  dataset_id, loader.num_nodes, loader.num_relations, len(loader.train_edge_data),
@@ -885,6 +889,239 @@ class ModelRegistry:
885
 
886
  return _gen()
887
 
888
  # ---- COINs inference ---------------------------------------------------
889
 
890
  def coins_predict(self, dataset_id, algorithm, query_structure_id,
 
276
  self._coins_experiments = {} # (dataset_id, algorithm) -> Experiment
277
  self._coins_loaders = {} # (dataset_id, seed, leiden_resolution) -> full Loader
278
  self._graphgen_models = {} # (dataset_id, model_type) -> loaded eval-mode model
279
+ self._kg_anomaly_models = {} # (dataset_id, task) -> loaded eval-mode model
280
 
281
  def force_release_inference_lock(self):
282
  """Emergency release for a stuck inference lock (e.g. client disconnect)."""
 
466
  seed=seed, device="cpu", val_size=0.01, test_size=0.02,
467
  community_method="leiden", leiden_resolution=leiden_resolution,
468
  )
 
 
469
  self.loaders[dataset_id] = loader
470
+ # Share this loader with _load_coins_experiment so experiments for the
471
+ # same (dataset, seed, leiden_resolution) reuse it instead of reloading
472
+ # the graph. Heavy arrays stay populated — they're needed by full
473
+ # experiments (embedder/sampler/ranker) and by KG anomaly inference.
474
+ self._coins_loaders[(dataset_id, seed, leiden_resolution)] = loader
475
  logger.info(
476
  "Loader ready for %s: %d entities, %d relations, %d train triples",
477
  dataset_id, loader.num_nodes, loader.num_relations, len(loader.train_edge_data),
 
889
 
890
  return _gen()
891
 
892
+ # ---- KG anomaly (DiGress KG) inference --------------------------------
893
+
894
+ def _load_kg_anomaly_model(self, dataset_id, task):
895
+ """Load the DiGress KG checkpoint for (dataset_id, task), cached.
896
+
897
+ The KG checkpoint pickles only ``cfg`` via ``save_hyperparameters('cfg')``,
898
+ so we must reconstruct ``dataset_infos``, ``extra_features`` and
899
+ ``domain_features`` before constructing the model. Dims are inferred from
900
+ state_dict shapes; kg_experiment comes from the matching COINs experiment.
901
+ """
902
+ key = (dataset_id, task)
903
+ if key in self._kg_anomaly_models:
904
+ return self._kg_anomaly_models[key]
905
+
906
+ import torch
907
+ import torch.nn.parallel.distributed as _ddp_mod
908
+
909
+ suffix = "_correct" if task == "correct" else ""
910
+ ckpt_path = Path(settings.DIGRESS_KG_DIR) / "checkpoints" / f"{dataset_id}{suffix}.ckpt"
911
+ if not ckpt_path.exists():
912
+ from api.exceptions import ModelUnavailable
913
+ raise ModelUnavailable(f"KG anomaly checkpoint not found: {ckpt_path.name}")
914
+
915
+ logger.info("Loading KG anomaly model: dataset=%s task=%s", dataset_id, task)
916
+
917
+ # Load to CPU with DDP patching (same strategy as _safe_load_lightning_checkpoint)
918
+ _orig_set = _ddp_mod.DistributedDataParallel.__setstate__
919
+ _orig_get = _ddp_mod.DistributedDataParallel.__getstate__
920
+ _ddp_mod.DistributedDataParallel.__setstate__ = lambda self, state: self.__dict__.update(state)
921
+ _ddp_mod.DistributedDataParallel.__getstate__ = lambda self: self.__dict__
922
+ try:
923
+ ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
924
+ finally:
925
+ _ddp_mod.DistributedDataParallel.__setstate__ = _orig_set
926
+ _ddp_mod.DistributedDataParallel.__getstate__ = _orig_get
927
+
928
+ hparams = ckpt.get("hyper_parameters", {})
929
+ cfg = hparams.get("cfg") if isinstance(hparams, dict) else getattr(hparams, "cfg", None)
930
+ if cfg is None:
931
+ raise RuntimeError(f"KG anomaly checkpoint {ckpt_path.name} is missing 'cfg' in hyper_parameters")
932
+ state_dict = ckpt["state_dict"]
933
+
934
+ # Ensure the model's task matches the endpoint task.
935
+ try:
936
+ cfg.model.task = task
937
+ except Exception:
938
+ pass # best effort: if the config rejects the assignment (e.g. OmegaConf struct or read-only mode), keep the checkpoint's task
939
+
940
+ # Infer dims from state_dict
941
+ edim_output = state_dict["model.mlp_out_E.2.weight"].shape[0]
942
+ input_dim_x = state_dict["model.mlp_in_X.0.weight"].shape[1]
943
+ input_dim_e = state_dict["model.mlp_in_E.0.weight"].shape[1]
944
+ input_dim_y = state_dict["model.mlp_in_y.0.weight"].shape[1]
945
+
946
+ # Load COINs experiment — needed for kg_experiment and for num_node_types
947
+ experiment = self._load_coins_experiment(dataset_id, "transe")
948
+ xdim_output = experiment.loader.num_node_types
949
+ # Sanity: input_dim_e should equal edim_output (no extra E features for KG)
950
+ if input_dim_e != edim_output:
951
+ logger.warning(
952
+ "Unexpected mlp_in_E dim %d != edim_output %d for %s/%s",
953
+ input_dim_e, edim_output, dataset_id, task,
954
+ )
955
+
956
+ # Build mock dataset_infos
957
+ from graph_generation.src.diffusion.distributions import DistributionNodes
958
+ from graph_generation.src.diffusion.extra_features import (
959
+ DummyExtraFeatures, ExtraFeatures,
960
+ )
961
+
962
+ # max_num_nodes from dataset name (e.g. "freebase_20" -> 20, then *2 per kg_dataset.py)
963
+ try:
964
+ base_max = int(cfg.dataset.name.split("_")[-1])
965
+ except (AttributeError, ValueError):
966
+ base_max = 20
967
+ max_num_nodes = base_max * 2
968
+
969
+ # Histogram for DistributionNodes — uniform over possible node counts
970
+ n_hist = torch.ones(max_num_nodes + 1)
971
+ n_hist[:2] = 0 # at least 2 nodes
972
+ nodes_dist = DistributionNodes(n_hist)
973
+
974
+ class _MockDataModule:
975
+ def __init__(self, kg_experiment, max_num_nodes):
976
+ self.kg_experiment = kg_experiment
977
+ self.max_num_nodes = max_num_nodes
978
+
979
+ class _MockDatasetInfos:
980
+ pass
981
+
982
+ dataset_infos = _MockDatasetInfos()
983
+ dataset_infos.datamodule = _MockDataModule(experiment, max_num_nodes)
984
+ dataset_infos.input_dims = {"X": input_dim_x, "E": input_dim_e, "y": input_dim_y}
985
+ dataset_infos.output_dims = {"X": xdim_output, "E": edim_output, "y": 0}
986
+ dataset_infos.nodes_dist = nodes_dist
987
+ dataset_infos.max_n_nodes = max_num_nodes
988
+ dataset_infos.node_types = torch.ones(xdim_output, dtype=torch.float32)
989
+ dataset_infos.edge_types = torch.ones(edim_output, dtype=torch.float32)
990
+
991
+ # extra_features per cfg
992
+ extra_features_type = getattr(cfg.model, "extra_features", None)
993
+ if cfg.model.type == "discrete" and extra_features_type is not None:
994
+ extra_features = ExtraFeatures(extra_features_type, dataset_info=dataset_infos)
995
+ else:
996
+ extra_features = DummyExtraFeatures()
997
+ domain_features = DummyExtraFeatures()
998
+
999
+ from diffusion_model_discrete_kg import DiscreteDenoisingDiffusionKG as cls
1000
+
1001
+ _orig_save = cls.save_hyperparameters
1002
+ cls.save_hyperparameters = lambda self, *a, **kw: None
1003
+ try:
1004
+ model = cls(cfg, dataset_infos, None, None, None, extra_features, domain_features)
1005
+ finally:
1006
+ cls.save_hyperparameters = _orig_save
1007
+
1008
+ missing, unexpected = model.load_state_dict(state_dict, strict=False)
1009
+ if missing:
1010
+ logger.debug("KG anomaly state_dict missing keys: %d (e.g. %s)",
1011
+ len(missing), missing[:3])
1012
+ if unexpected:
1013
+ logger.debug("KG anomaly state_dict unexpected keys: %d (e.g. %s)",
1014
+ len(unexpected), unexpected[:3])
1015
+
1016
+ del ckpt
1017
+ model.to(settings.TORCH_DEVICE)
1018
+ model.eval()
1019
+ self._kg_anomaly_models[key] = model
1020
+ logger.info("KG anomaly model ready: dataset=%s task=%s", dataset_id, task)
1021
+ return model
1022
+
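Review note: `_load_kg_anomaly_model` swaps out class attributes twice (the DDP `__setstate__`/`__getstate__` pair and `save_hyperparameters`) and restores them in `finally`. A minimal sketch of the same pattern as a reusable context manager follows; `patched` and the toy `Model` class are hypothetical names for illustration, not part of this commit.

```python
from contextlib import contextmanager


@contextmanager
def patched(obj, name, replacement):
    """Temporarily replace obj.<name>, restoring the original even on error."""
    original = getattr(obj, name)
    setattr(obj, name, replacement)
    try:
        yield
    finally:
        setattr(obj, name, original)


class Model:
    """Toy stand-in for a class whose constructor calls a method we want to skip."""

    def __init__(self):
        self.save_hyperparameters()

    def save_hyperparameters(self):
        raise RuntimeError("not available outside training")


# Inside the block the method is a no-op, so construction succeeds.
with patched(Model, "save_hyperparameters", lambda self, *a, **kw: None):
    model = Model()

# Outside the block the original method is back in place.
try:
    Model()
    restored = False
except RuntimeError:
    restored = True
```

The patch is applied to the class, so it is visible process-wide while the block runs; any other thread constructing the same class in that window would also get the no-op.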
1023
+ def kg_anomaly_correct_stream(self, dataset_id, task, sampling_mode, subgraph,
1024
+ diffusion_steps, chain_frames, multiprox_params):
1025
+ """Return a generator of SSE event dicts for /kg-anomaly/correct."""
1026
+ from api.exceptions import InferenceBusy
1027
+ from api.services.kg_anomaly_inference import (
1028
+ build_kg_tensors, encode_state_blob,
1029
+ run_multiprox_correction_init, run_standard_correction,
1030
+ )
1031
+ if not self._inference_lock.acquire(blocking=False):
1032
+ raise InferenceBusy()
1033
+ self._inference_lock_owner = f"kg_anomaly_correct {dataset_id}/{task}/{sampling_mode}"
1034
+ try:
1035
+ model = self._load_kg_anomaly_model(dataset_id, task)
1036
+ loader = self.loaders.get(dataset_id)
1037
+ tensors = build_kg_tensors(subgraph, loader, model)
1038
+ except Exception:
1039
+ self._inference_lock_owner = None
1040
+ self._inference_lock.release()
1041
+ raise
1042
+
1043
+ def _gen():
1044
+ try:
1045
+ if sampling_mode == "standard":
1046
+ for event in run_standard_correction(
1047
+ model, tensors, dataset_id, task, loader,
1048
+ diffusion_steps, chain_frames):
1049
+ if event["type"] == "result":
1050
+ event.update({
1051
+ "dataset_id": dataset_id,
1052
+ "task": task,
1053
+ "sampling_mode": sampling_mode,
1054
+ })
1055
+ yield event
1056
+ else:
1057
+ n = multiprox_params["n"]
1058
+ m = multiprox_params["m"]
1059
+ t = multiprox_params["t"]
1060
+ t_prime = multiprox_params["t_prime"]
1061
+ gibbs_chain_freq = multiprox_params["gibbs_chain_freq"]
1062
+ for event in run_multiprox_correction_init(
1063
+ model, tensors, dataset_id, task, loader,
1064
+ n, m, t, t_prime, gibbs_chain_freq):
1065
+ if event["type"] == "result":
1066
+ state = event.pop("state")
1067
+ event.update({
1068
+ "dataset_id": dataset_id,
1069
+ "task": task,
1070
+ "sampling_mode": sampling_mode,
1071
+ "step": 0,
1072
+ "round_complete": False,
1073
+ "done": False,
1074
+ "state": encode_state_blob(state),
1075
+ })
1076
+ yield event
1077
+ finally:
1078
+ self._inference_lock_owner = None
1079
+ self._inference_lock.release()
1080
+
1081
+ return _gen()
1082
+
1083
+ def kg_anomaly_continue_stream(self, state_b64):
1084
+ """Return a generator of SSE event dicts for /kg-anomaly/continue."""
1085
+ from api.exceptions import InferenceBusy, InvalidRequestError
1086
+ from api.services.kg_anomaly_inference import (
1087
+ decode_state_blob, encode_state_blob, run_multiprox_correction_step,
1088
+ )
1089
+ try:
1090
+ state = decode_state_blob(state_b64)
1091
+ except ValueError as exc:
1092
+ raise InvalidRequestError(str(exc))
1093
+
1094
+ if not self._inference_lock.acquire(blocking=False):
1095
+ raise InferenceBusy()
1096
+ self._inference_lock_owner = (
1097
+ f"kg_anomaly_continue {state['dataset_id']}/{state['task']}"
1098
+ )
1099
+ try:
1100
+ model = self._load_kg_anomaly_model(state["dataset_id"], state["task"])
1101
+ loader = self.loaders.get(state["dataset_id"])
1102
+ except Exception:
1103
+ self._inference_lock_owner = None
1104
+ self._inference_lock.release()
1105
+ raise
1106
+
1107
+ def _gen():
1108
+ try:
1109
+ for event in run_multiprox_correction_step(model, state, loader):
1110
+ if event["type"] == "result":
1111
+ updated_state = event.pop("state")
1112
+ event.update({
1113
+ "dataset_id": updated_state["dataset_id"],
1114
+ "task": updated_state["task"],
1115
+ "step": updated_state["step"],
1116
+ "state": encode_state_blob(updated_state),
1117
+ })
1118
+ yield event
1119
+ finally:
1120
+ self._inference_lock_owner = None
1121
+ self._inference_lock.release()
1122
+
1123
+ return _gen()
1124
+
1125
  # ---- COINs inference ---------------------------------------------------
1126
 
1127
  def coins_predict(self, dataset_id, algorithm, query_structure_id,
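Review note: both new streaming methods acquire the inference lock before returning the generator and release it in the generator's `finally`. The sketch below isolates that pattern with only the standard library; the `Registry` and `InferenceBusy` classes here are simplified stand-ins, not the project's real ones.

```python
import threading


class InferenceBusy(Exception):
    """Simplified stand-in for api.exceptions.InferenceBusy."""


class Registry:
    def __init__(self):
        self._inference_lock = threading.Lock()

    def stream(self, n_events):
        # Fail fast instead of queueing: a second caller gets an error at once.
        if not self._inference_lock.acquire(blocking=False):
            raise InferenceBusy()

        def _gen():
            try:
                for i in range(n_events):
                    yield {"type": "progress", "step": i}
                yield {"type": "result"}
            finally:
                # Runs on exhaustion, on an exception, or when a started
                # generator is closed.
                self._inference_lock.release()

        return _gen()


registry = Registry()
events = registry.stream(2)
first = next(events)                 # the lock is held while the stream is open
try:
    registry.stream(1)
    busy = False
except InferenceBusy:
    busy = True
rest = list(events)                  # draining the stream releases the lock
released = not registry._inference_lock.locked()

# Caveat: closing a generator that was never started does not run its body,
# so the finally clause is skipped and the lock stays held.
unstarted = registry.stream(1)
unstarted.close()
still_locked = registry._inference_lock.locked()
registry._inference_lock.release()   # manual cleanup for this demo
```

The last case is one way a lock can get stuck if a response is built but never iterated, which is presumably the situation `force_release_inference_lock` exists for.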
src/backend/api/urls.py CHANGED
@@ -13,7 +13,12 @@ from api.views.graph_generation import (
13
  GraphGenContinueView, GraphGenDatasetsView, GraphGenGenerateView, GraphGenSamplingModesView,
14
  )
15
  from api.views.health import ApiRootView, ForceUnlockView, HealthView, MethodsView
16
- from api.views.kg_anomaly import KgAnomalyDatasetsView, KgAnomalySampleSubgraphsView
17
 
18
  urlpatterns = [
19
  # Health & discovery
@@ -37,4 +42,6 @@ urlpatterns = [
37
  # KG anomaly
38
  path("kg-anomaly/datasets", KgAnomalyDatasetsView.as_view()),
39
  path("kg-anomaly/datasets/<str:dataset_id>/sample-subgraphs", KgAnomalySampleSubgraphsView.as_view()),
40
  ]
 
13
  GraphGenContinueView, GraphGenDatasetsView, GraphGenGenerateView, GraphGenSamplingModesView,
14
  )
15
  from api.views.health import ApiRootView, ForceUnlockView, HealthView, MethodsView
16
+ from api.views.kg_anomaly import (
17
+ KgAnomalyContinueView,
18
+ KgAnomalyCorrectView,
19
+ KgAnomalyDatasetsView,
20
+ KgAnomalySampleSubgraphsView,
21
+ )
22
 
23
  urlpatterns = [
24
  # Health & discovery
 
42
  # KG anomaly
43
  path("kg-anomaly/datasets", KgAnomalyDatasetsView.as_view()),
44
  path("kg-anomaly/datasets/<str:dataset_id>/sample-subgraphs", KgAnomalySampleSubgraphsView.as_view()),
45
+ path("kg-anomaly/correct", KgAnomalyCorrectView.as_view()),
46
+ path("kg-anomaly/continue", KgAnomalyContinueView.as_view()),
47
  ]
src/backend/api/views/kg_anomaly.py CHANGED
@@ -1,9 +1,11 @@
1
  from rest_framework.response import Response
2
  from rest_framework.views import APIView
3
 
4
- from api.exceptions import NotFoundError
5
  from api.services.constants import KG_ANOMALY_DATASET_META
 
6
  from api.services.registry import ModelRegistry
 
7
 
8
 
9
  class KgAnomalyDatasetsView(APIView):
@@ -34,9 +36,139 @@ class KgAnomalySampleSubgraphsView(APIView):
34
  count = int(request.query_params.get("count", 5))
35
  count = max(1, min(10, count))
36
 
37
- subgraphs = sg_info.subgraphs[:count]
38
 
39
  return Response({
40
  "dataset_id": dataset_id,
41
  "subgraphs": subgraphs,
42
  })
 
1
  from rest_framework.response import Response
2
  from rest_framework.views import APIView
3
 
4
+ from api.exceptions import InvalidRequestError, ModelUnavailable, NotFoundError
5
  from api.services.constants import KG_ANOMALY_DATASET_META
6
+ from api.services.kg_anomaly_inference import apply_edge_noise, build_kg_tensors
7
  from api.services.registry import ModelRegistry
8
+ from api.views.graph_generation import _streaming_sse_response
9
 
10
 
11
  class KgAnomalyDatasetsView(APIView):
 
36
  count = int(request.query_params.get("count", 5))
37
  count = max(1, min(10, count))
38
 
39
+ subgraphs = [dict(sg) for sg in sg_info.subgraphs[:count]]
40
+
41
+ noise_level_raw = request.query_params.get("noise_level")
42
+ if noise_level_raw is not None:
43
+ try:
44
+ noise_level = float(noise_level_raw)
45
+ except ValueError:
46
+ raise InvalidRequestError("'noise_level' must be a float in (0, 1]")
47
+ if not (0.0 < noise_level <= 1.0):
48
+ raise InvalidRequestError("'noise_level' must be in (0, 1]")
49
+
50
+ task = request.query_params.get("task", "correct")
51
+ if task not in ("correct", "generate"):
52
+ raise InvalidRequestError("'task' must be 'correct' or 'generate'")
53
+
54
+ available = registry.kg_anomaly_checkpoints_available.get(dataset_id, [])
55
+ if task not in available:
56
+ raise ModelUnavailable(
57
+ f"No '{task}' checkpoint available for dataset '{dataset_id}'")
58
+
59
+ seed_raw = request.query_params.get("seed")
60
+ seed = int(seed_raw) if seed_raw is not None else None
61
+
62
+ loader = registry.loaders[dataset_id]
63
+ model = registry._load_kg_anomaly_model(dataset_id, task)
64
+
65
+ for i, sg in enumerate(subgraphs):
66
+ offset_seed = None if seed is None else seed + i
67
+ tensors = build_kg_tensors(sg, loader, model)
68
+ sg["edges"] = apply_edge_noise(model, tensors, task, noise_level, offset_seed)
69
 
70
  return Response({
71
  "dataset_id": dataset_id,
72
  "subgraphs": subgraphs,
73
  })
74
+
75
+
76
+ def _validate_subgraph(subgraph):
77
+ if not isinstance(subgraph, dict):
78
+ raise InvalidRequestError("'subgraph' must be an object with 'nodes' and 'edges'")
79
+ nodes = subgraph.get("nodes")
80
+ edges = subgraph.get("edges")
81
+ if not isinstance(nodes, list) or not (2 <= len(nodes) <= 20):
82
+ raise InvalidRequestError("'subgraph.nodes' must be a list of 2 to 20 items")
83
+ if not isinstance(edges, list):
84
+ raise InvalidRequestError("'subgraph.edges' must be a list")
85
+ n = len(nodes)
86
+ for i, node in enumerate(nodes):
87
+ if not isinstance(node, dict) or "entity_id" not in node:
88
+ raise InvalidRequestError(f"subgraph.nodes[{i}] must have 'entity_id'")
89
+ for i, e in enumerate(edges):
90
+ if not isinstance(e, dict):
91
+ raise InvalidRequestError(f"subgraph.edges[{i}] must be an object")
92
+ for field in ("source_idx", "target_idx", "relation_id"):
93
+ if field not in e:
94
+ raise InvalidRequestError(f"subgraph.edges[{i}] missing '{field}'")
95
+ if not (0 <= int(e["source_idx"]) < n and 0 <= int(e["target_idx"]) < n):
96
+ raise InvalidRequestError(f"subgraph.edges[{i}] has out-of-range node index")
97
+ if int(e["source_idx"]) == int(e["target_idx"]):
98
+ raise InvalidRequestError(f"subgraph.edges[{i}] is a self-loop (not allowed)")
99
+
100
+
101
+ class KgAnomalyCorrectView(APIView):
102
+ def post(self, request):
103
+ data = request.data
104
+ registry = ModelRegistry.get()
105
+
106
+ dataset_id = data.get("dataset_id")
107
+ if dataset_id not in KG_ANOMALY_DATASET_META:
108
+ raise InvalidRequestError(
109
+ f"Unknown dataset_id '{dataset_id}'. Valid: {list(KG_ANOMALY_DATASET_META)}")
110
+
111
+ task = data.get("task", "correct")
112
+ if task not in ("correct", "generate"):
113
+ raise InvalidRequestError("task must be 'correct' or 'generate'")
114
+
115
+ available = registry.kg_anomaly_checkpoints_available.get(dataset_id, [])
116
+ if task not in available:
117
+ raise ModelUnavailable(
118
+ f"No '{task}' checkpoint available for dataset '{dataset_id}'")
119
+
120
+ sampling_mode = data.get("sampling_mode")
121
+ if sampling_mode not in ("standard", "multiprox"):
122
+ raise InvalidRequestError("sampling_mode must be 'standard' or 'multiprox'")
123
+
124
+ subgraph = data.get("subgraph")
125
+ _validate_subgraph(subgraph)
126
+
127
+ if sampling_mode == "standard":
128
+ diffusion_steps = min(max(int(data.get("diffusion_steps", 500)), 50), 1000)
129
+ chain_frames = min(max(int(data.get("chain_frames", 20)), 10), 30)
130
+ gen = registry.kg_anomaly_correct_stream(
131
+ dataset_id, task, sampling_mode, subgraph,
132
+ diffusion_steps, chain_frames, None)
133
+ else:
134
+ mp = data.get("multiprox_params")
135
+ if not mp or not isinstance(mp, dict):
136
+ raise InvalidRequestError("multiprox_params is required for multiprox sampling_mode")
137
+
138
+ m = int(mp.get("m", 100))
139
+ if not (2 <= m <= 100):
140
+ raise InvalidRequestError("multiprox_params.m must be in [2, 100]")
141
+
142
+ n = int(mp.get("n", 10))
143
+ if n < 1:
144
+ raise InvalidRequestError("multiprox_params.n must be >= 1")
145
+
146
+ t = float(mp.get("t", 0.5))
147
+ t_prime = float(mp.get("t_prime", 0.1))
148
+ if not (0 < t_prime <= t <= 1):
149
+ raise InvalidRequestError(
150
+ "multiprox_params must satisfy 0 < t_prime <= t <= 1")
151
+
152
+ gibbs_chain_freq = int(mp.get("gibbs_chain_freq", max(1, m // 10)))
153
+ if not (1 <= gibbs_chain_freq <= m):
154
+ raise InvalidRequestError(
155
+ f"multiprox_params.gibbs_chain_freq must be in [1, {m}]")
156
+
157
+ multiprox_params = {
158
+ "n": n, "m": m, "t": t, "t_prime": t_prime,
159
+ "gibbs_chain_freq": gibbs_chain_freq,
160
+ }
161
+ gen = registry.kg_anomaly_correct_stream(
162
+ dataset_id, task, sampling_mode, subgraph,
163
+ None, None, multiprox_params)
164
+
165
+ return _streaming_sse_response(gen)
166
+
167
+
168
+ class KgAnomalyContinueView(APIView):
169
+ def post(self, request):
170
+ state_b64 = request.data.get("state")
171
+ if not state_b64 or not isinstance(state_b64, str):
172
+ raise InvalidRequestError("'state' is required and must be a non-empty string")
173
+ gen = ModelRegistry.get().kg_anomaly_continue_stream(state_b64)
174
+ return _streaming_sse_response(gen)
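Review note: `KgAnomalyCorrectView` treats its parameters in two different ways. The standard-mode values (`diffusion_steps`, `chain_frames`) are silently clamped into range, while the MultiProx values are rejected when out of range. A standalone sketch of both behaviours is below. It raises `ValueError` where the view raises `InvalidRequestError`, and the helper names `clamp_int` and `check_multiprox` are hypothetical. Unlike the view, `clamp_int` also treats an explicit `None` as "absent".

```python
def clamp_int(raw, default, lo, hi):
    """Parse an int-like value and clamp it into [lo, hi]."""
    value = default if raw is None else int(raw)
    return min(max(value, lo), hi)


def check_multiprox(mp):
    """Apply the same defaults and range checks as the view, rejecting bad input."""
    m = int(mp.get("m", 100))
    if not (2 <= m <= 100):
        raise ValueError("m must be in [2, 100]")
    n = int(mp.get("n", 10))
    if n < 1:
        raise ValueError("n must be >= 1")
    t = float(mp.get("t", 0.5))
    t_prime = float(mp.get("t_prime", 0.1))
    if not (0 < t_prime <= t <= 1):
        raise ValueError("need 0 < t_prime <= t <= 1")
    # The default chain frequency depends on m, so it is resolved after m.
    freq = int(mp.get("gibbs_chain_freq", max(1, m // 10)))
    if not (1 <= freq <= m):
        raise ValueError(f"gibbs_chain_freq must be in [1, {m}]")
    return {"n": n, "m": m, "t": t, "t_prime": t_prime, "gibbs_chain_freq": freq}


steps_high = clamp_int(5000, 500, 50, 1000)   # clamped down to the upper bound
frames_low = clamp_int("3", 20, 10, 30)       # clamped up to the lower bound
frames_default = clamp_int(None, 20, 10, 30)  # default used when absent
defaults = check_multiprox({})

try:
    check_multiprox({"t": 0.1, "t_prime": 0.5})
    rejected = False
except ValueError:
    rejected = True
```

In both the sketch and the view, a non-numeric value such as `"abc"` makes `int()` or `float()` raise; whether that surfaces as a 400 or a 500 depends on how the framework handles the uncaught exception.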
src/backend/research_api/settings.py CHANGED
@@ -7,9 +7,10 @@ PROJECT_ROOT = BASE_DIR.parent.parent # Website root
7
 
8
  # Add research repos to sys.path so their modules can be imported
9
  _COINS_KG_ROOT = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration")
 
10
  _MULTIPROXAN_ROOT = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn")
11
  _MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
12
- for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
13
  if _path not in sys.path:
14
  sys.path.insert(0, _path)
15
 
 
7
 
8
  # Add research repos to sys.path so their modules can be imported
9
  _COINS_KG_ROOT = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration")
10
+ _DIGRESS_KG_SRC = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration" / "graph_generation" / "src")
11
  _MULTIPROXAN_ROOT = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn")
12
  _MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
13
+ for _path in (_COINS_KG_ROOT, _DIGRESS_KG_SRC, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
14
  if _path not in sys.path:
15
  sys.path.insert(0, _path)
16
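Review note: because the loop calls `sys.path.insert(0, _path)` for each entry in turn, the entries end up on `sys.path` in reverse tuple order. The sketch below reproduces the loop on a plain list to make the resulting precedence visible; the short path strings are placeholders for the real directories.

```python
def prepend_paths(search_path, new_paths):
    """Return a copy of search_path after applying the settings.py loop."""
    result = list(search_path)
    for path in new_paths:
        if path not in result:
            # Each insert goes to the front, so later entries outrank earlier ones.
            result.insert(0, path)
    return result


order = prepend_paths(
    ["site-packages"],
    ["coins_root", "digress_kg_src", "multiproxan_root", "multiproxan_src"],
)
```

Here `order` comes out as `["multiproxan_src", "multiproxan_root", "digress_kg_src", "coins_root", "site-packages"]`. If `graph_generation/src` and `MultiProxAn/src` ever contained a top-level module with the same name (not verified for this repository), a bare import would resolve to the MultiProxAn copy.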