feat(kg-anomaly): add correct/continue endpoints with SSE streaming
Browse files
Implements the two remaining inference endpoints for the KG anomaly
feature, following the MultiProxAn graph-generation pattern:
- POST /kg-anomaly/correct: standard denoising (correct/generate tasks)
or MultiProx Gibbs init; returns SSE stream with progress, preview,
and terminal result events (before/after images, chain GIF, diff).
- POST /kg-anomaly/continue: advances a MultiProx session one step.
- Adds kg_anomaly_inference.py with tensor building, change detection,
directed subgraph rendering (PIL + networkx), and an apply_edge_noise
helper that applies task-aware forward diffusion to edges for demo input.
- Extends GET /kg-anomaly/datasets/{id}/sample-subgraphs with
noise_level/task/seed query params so callers can fetch pre-noised
subgraphs ready for correction.
- Registry: adds DiscreteDenoisingDiffusionKG loader that reconstructs
dataset_infos from checkpoint state_dict shapes + COINs experiment.
- Adds graph_generation/src to sys.path for bare imports inside the
research module.
- Updates OpenAPI spec, Postman collection (pre-noised example bodies,
noised-subgraph GET variants, auto-chaining for multiprox), and the
backend README endpoint table.
- docs/api.yaml +175 -54
- docs/postman/collection.json +64 -8
- src/backend/README.md +5 -5
- src/backend/api/services/kg_anomaly_inference.py +708 -0
- src/backend/api/services/registry.py +239 -2
- src/backend/api/urls.py +8 -1
- src/backend/api/views/kg_anomaly.py +134 -2
- src/backend/research_api/settings.py +2 -1
|
@@ -62,6 +62,40 @@ paths:
|
|
| 62 |
schema:
|
| 63 |
$ref: "#/components/schemas/MethodsResponse"
|
| 64 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
# -- COINs -----------------------------------------------------------
|
| 66 |
/coins/datasets:
|
| 67 |
get:
|
|
@@ -279,15 +313,19 @@ paths:
|
|
| 279 |
post:
|
| 280 |
operationId: graphGenGenerate
|
| 281 |
tags: [graph-generation]
|
| 282 |
-
summary: Generate a graph
|
| 283 |
description: |
|
| 284 |
-
|
| 285 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 286 |
|
| 287 |
-
**MultiProx mode**:
|
| 288 |
-
|
| 289 |
-
`
|
| 290 |
-
one step at a time.
|
| 291 |
requestBody:
|
| 292 |
required: true
|
| 293 |
content:
|
|
@@ -318,30 +356,11 @@ paths:
|
|
| 318 |
t_prime: 0.1
|
| 319 |
responses:
|
| 320 |
"200":
|
| 321 |
-
description:
|
| 322 |
content:
|
| 323 |
-
|
| 324 |
schema:
|
| 325 |
-
|
| 326 |
-
- $ref: "#/components/schemas/GraphGenStandardResponse"
|
| 327 |
-
- $ref: "#/components/schemas/GraphGenMultiProxResponse"
|
| 328 |
-
examples:
|
| 329 |
-
standard:
|
| 330 |
-
summary: Standard generation result
|
| 331 |
-
value:
|
| 332 |
-
dataset_id: qm9
|
| 333 |
-
model_type: discrete
|
| 334 |
-
sampling_mode: standard
|
| 335 |
-
image: "data:image/png;base64,..."
|
| 336 |
-
chain_gif: "data:image/gif;base64,..."
|
| 337 |
-
inference_time_ms: 3200
|
| 338 |
-
multiprox:
|
| 339 |
-
summary: MultiProx session started
|
| 340 |
-
value:
|
| 341 |
-
state: "base64-encoded-diffusion-state..."
|
| 342 |
-
step: 0
|
| 343 |
-
image: "data:image/png;base64,..."
|
| 344 |
-
inference_time_ms: 800
|
| 345 |
"400":
|
| 346 |
$ref: "#/components/responses/InvalidRequest"
|
| 347 |
"429":
|
|
@@ -353,11 +372,14 @@ paths:
|
|
| 353 |
post:
|
| 354 |
operationId: graphGenContinue
|
| 355 |
tags: [graph-generation]
|
| 356 |
-
summary: Advance MultiProx generation by one step
|
| 357 |
description: |
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
|
|
|
|
|
|
|
|
|
| 361 |
requestBody:
|
| 362 |
required: true
|
| 363 |
content:
|
|
@@ -366,11 +388,11 @@ paths:
|
|
| 366 |
$ref: "#/components/schemas/GraphGenContinueRequest"
|
| 367 |
responses:
|
| 368 |
"200":
|
| 369 |
-
description:
|
| 370 |
content:
|
| 371 |
-
|
| 372 |
schema:
|
| 373 |
-
$ref: "#/components/schemas/
|
| 374 |
"400":
|
| 375 |
$ref: "#/components/responses/InvalidRequest"
|
| 376 |
"429":
|
|
@@ -396,7 +418,13 @@ paths:
|
|
| 396 |
operationId: getKgAnomalySampleSubgraphs
|
| 397 |
tags: [kg-anomaly]
|
| 398 |
summary: Get example subgraphs for correction
|
| 399 |
-
description:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 400 |
parameters:
|
| 401 |
- $ref: "#/components/parameters/KgAnomalyDatasetId"
|
| 402 |
- name: count
|
|
@@ -407,6 +435,32 @@ paths:
|
|
| 407 |
maximum: 10
|
| 408 |
default: 5
|
| 409 |
description: Number of sample subgraphs to return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 410 |
responses:
|
| 411 |
"200":
|
| 412 |
description: Sample subgraphs
|
|
@@ -421,12 +475,19 @@ paths:
|
|
| 421 |
post:
|
| 422 |
operationId: kgAnomalyCorrect
|
| 423 |
tags: [kg-anomaly]
|
| 424 |
-
summary: Correct a KG subgraph
|
| 425 |
description: |
|
| 426 |
-
|
| 427 |
-
|
|
|
|
|
|
|
| 428 |
|
| 429 |
-
**
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 430 |
requestBody:
|
| 431 |
required: true
|
| 432 |
content:
|
|
@@ -476,13 +537,11 @@ paths:
|
|
| 476 |
t_prime: 0.1
|
| 477 |
responses:
|
| 478 |
"200":
|
| 479 |
-
description:
|
| 480 |
content:
|
| 481 |
-
|
| 482 |
schema:
|
| 483 |
-
|
| 484 |
-
- $ref: "#/components/schemas/KgAnomalyStandardResponse"
|
| 485 |
-
- $ref: "#/components/schemas/KgAnomalyMultiProxResponse"
|
| 486 |
"400":
|
| 487 |
$ref: "#/components/responses/InvalidRequest"
|
| 488 |
"404":
|
|
@@ -498,11 +557,13 @@ paths:
|
|
| 498 |
post:
|
| 499 |
operationId: kgAnomalyContinue
|
| 500 |
tags: [kg-anomaly]
|
| 501 |
-
summary: Advance MultiProx correction by one step (
|
| 502 |
description: |
|
| 503 |
-
Advances the MultiProx correction
|
| 504 |
-
The client must send back the opaque
|
| 505 |
-
|
|
|
|
|
|
|
| 506 |
requestBody:
|
| 507 |
required: true
|
| 508 |
content:
|
|
@@ -511,11 +572,11 @@ paths:
|
|
| 511 |
$ref: "#/components/schemas/KgAnomalyContinueRequest"
|
| 512 |
responses:
|
| 513 |
"200":
|
| 514 |
-
description:
|
| 515 |
content:
|
| 516 |
-
|
| 517 |
schema:
|
| 518 |
-
$ref: "#/components/schemas/
|
| 519 |
"400":
|
| 520 |
$ref: "#/components/responses/InvalidRequest"
|
| 521 |
"429":
|
|
@@ -1291,6 +1352,36 @@ components:
|
|
| 1291 |
format: float
|
| 1292 |
example: 800
|
| 1293 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1294 |
# -- KG Anomaly Correction: Discovery --
|
| 1295 |
KgAnomalyDatasetsResponse:
|
| 1296 |
type: object
|
|
@@ -1411,7 +1502,7 @@ components:
|
|
| 1411 |
$ref: "#/components/schemas/SamplingModeEnum"
|
| 1412 |
task:
|
| 1413 |
$ref: "#/components/schemas/KgAnomalyTaskEnum"
|
| 1414 |
-
default:
|
| 1415 |
description: |
|
| 1416 |
"generate" = ignore the input subgraph edges and generate a new subgraph from scratch.
|
| 1417 |
"correct" (default) = keep fixed edges unchanged, only correct the masked (anomalous) edges.
|
|
@@ -1570,3 +1661,33 @@ components:
|
|
| 1570 |
removed:
|
| 1571 |
type: integer
|
| 1572 |
example: 0
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 62 |
schema:
|
| 63 |
$ref: "#/components/schemas/MethodsResponse"
|
| 64 |
|
| 65 |
+
/debug/force-unlock:
|
| 66 |
+
post:
|
| 67 |
+
operationId: forceUnlockInferenceLock
|
| 68 |
+
tags: [health]
|
| 69 |
+
summary: Release a stuck inference lock (debug only)
|
| 70 |
+
description: |
|
| 71 |
+
Forcibly releases the global inference lock. Only available when
|
| 72 |
+
the server is running with `DJANGO_DEBUG=True`; returns `403` in
|
| 73 |
+
production. Use when a crashed request left the lock held and
|
| 74 |
+
subsequent requests are returning `429 INFERENCE_BUSY`.
|
| 75 |
+
responses:
|
| 76 |
+
"200":
|
| 77 |
+
description: Lock release result
|
| 78 |
+
content:
|
| 79 |
+
application/json:
|
| 80 |
+
schema:
|
| 81 |
+
type: object
|
| 82 |
+
required: [released]
|
| 83 |
+
properties:
|
| 84 |
+
released:
|
| 85 |
+
type: boolean
|
| 86 |
+
description: True if a held lock was released; false if the lock was already free.
|
| 87 |
+
example: true
|
| 88 |
+
"403":
|
| 89 |
+
description: Not available outside debug mode
|
| 90 |
+
content:
|
| 91 |
+
application/json:
|
| 92 |
+
schema:
|
| 93 |
+
type: object
|
| 94 |
+
properties:
|
| 95 |
+
error:
|
| 96 |
+
type: string
|
| 97 |
+
example: only available in debug mode
|
| 98 |
+
|
| 99 |
# -- COINs -----------------------------------------------------------
|
| 100 |
/coins/datasets:
|
| 101 |
get:
|
|
|
|
| 313 |
post:
|
| 314 |
operationId: graphGenGenerate
|
| 315 |
tags: [graph-generation]
|
| 316 |
+
summary: Generate a graph (SSE streaming)
|
| 317 |
description: |
|
| 318 |
+
Server-Sent Events stream (`text/event-stream`). Emits `progress`
|
| 319 |
+
events during diffusion, optional `preview` events with intermediate
|
| 320 |
+
PNGs, and a terminal `result` event whose `data` payload is the JSON
|
| 321 |
+
described below.
|
| 322 |
+
|
| 323 |
+
**Standard mode**: runs full diffusion (T->0); terminal `result`
|
| 324 |
+
payload conforms to `GraphGenStandardResponse` (animated GIF + final PNG).
|
| 325 |
|
| 326 |
+
**MultiProx mode**: runs the first Gibbs iteration; terminal `result`
|
| 327 |
+
payload conforms to `GraphGenMultiProxResponse` and includes an opaque
|
| 328 |
+
`state` blob to be passed to `/graph-generation/continue`.
|
|
|
|
| 329 |
requestBody:
|
| 330 |
required: true
|
| 331 |
content:
|
|
|
|
| 356 |
t_prime: 0.1
|
| 357 |
responses:
|
| 358 |
"200":
|
| 359 |
+
description: SSE stream of progress/preview events terminated by a result event
|
| 360 |
content:
|
| 361 |
+
text/event-stream:
|
| 362 |
schema:
|
| 363 |
+
$ref: "#/components/schemas/GraphGenSseStream"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 364 |
"400":
|
| 365 |
$ref: "#/components/responses/InvalidRequest"
|
| 366 |
"429":
|
|
|
|
| 372 |
post:
|
| 373 |
operationId: graphGenContinue
|
| 374 |
tags: [graph-generation]
|
| 375 |
+
summary: Advance MultiProx generation by one step (SSE streaming)
|
| 376 |
description: |
|
| 377 |
+
SSE stream (`text/event-stream`). Advances the MultiProx
|
| 378 |
+
multi-measurement chain by one Gibbs iteration. The client must send
|
| 379 |
+
back the opaque `state` from the previous step's `result` event.
|
| 380 |
+
Emits `progress` and `preview` events, then a terminal `result` event
|
| 381 |
+
with the `GraphGenMultiProxResponse` payload (including the updated
|
| 382 |
+
`state`). The API remains fully stateless -- no server-side sessions.
|
| 383 |
requestBody:
|
| 384 |
required: true
|
| 385 |
content:
|
|
|
|
| 388 |
$ref: "#/components/schemas/GraphGenContinueRequest"
|
| 389 |
responses:
|
| 390 |
"200":
|
| 391 |
+
description: SSE stream of progress/preview events terminated by a result event
|
| 392 |
content:
|
| 393 |
+
text/event-stream:
|
| 394 |
schema:
|
| 395 |
+
$ref: "#/components/schemas/GraphGenSseStream"
|
| 396 |
"400":
|
| 397 |
$ref: "#/components/responses/InvalidRequest"
|
| 398 |
"429":
|
|
|
|
| 418 |
operationId: getKgAnomalySampleSubgraphs
|
| 419 |
tags: [kg-anomaly]
|
| 420 |
summary: Get example subgraphs for correction
|
| 421 |
+
description: |
|
| 422 |
+
Returns pre-computed example subgraphs from the test set. When
|
| 423 |
+
`noise_level` is supplied, the model's forward diffusion is applied
|
| 424 |
+
to each subgraph's edges so the caller receives a corrupted input
|
| 425 |
+
ready for `/kg-anomaly/correct`. For `task=correct` only the edges
|
| 426 |
+
inside the inpaint mask (second half of nodes) are noised; for
|
| 427 |
+
`task=generate` every edge is noised.
|
| 428 |
parameters:
|
| 429 |
- $ref: "#/components/parameters/KgAnomalyDatasetId"
|
| 430 |
- name: count
|
|
|
|
| 435 |
maximum: 10
|
| 436 |
default: 5
|
| 437 |
description: Number of sample subgraphs to return
|
| 438 |
+
- name: noise_level
|
| 439 |
+
in: query
|
| 440 |
+
required: false
|
| 441 |
+
schema:
|
| 442 |
+
type: number
|
| 443 |
+
minimum: 0.0
|
| 444 |
+
exclusiveMinimum: true
|
| 445 |
+
maximum: 1.0
|
| 446 |
+
description: |
|
| 447 |
+
Fraction of the full diffusion horizon T at which to sample
|
| 448 |
+
noised edges (e.g. 0.4 for moderate corruption). Omit to receive
|
| 449 |
+
the clean subgraphs.
|
| 450 |
+
- name: task
|
| 451 |
+
in: query
|
| 452 |
+
required: false
|
| 453 |
+
schema:
|
| 454 |
+
type: string
|
| 455 |
+
enum: [correct, generate]
|
| 456 |
+
default: correct
|
| 457 |
+
description: Task the noise should align with. Ignored if noise_level is not set.
|
| 458 |
+
- name: seed
|
| 459 |
+
in: query
|
| 460 |
+
required: false
|
| 461 |
+
schema:
|
| 462 |
+
type: integer
|
| 463 |
+
description: Optional RNG seed for reproducible noise.
|
| 464 |
responses:
|
| 465 |
"200":
|
| 466 |
description: Sample subgraphs
|
|
|
|
| 475 |
post:
|
| 476 |
operationId: kgAnomalyCorrect
|
| 477 |
tags: [kg-anomaly]
|
| 478 |
+
summary: Correct a KG subgraph (SSE streaming)
|
| 479 |
description: |
|
| 480 |
+
Server-Sent Events stream (`text/event-stream`). Emits `progress`
|
| 481 |
+
events during diffusion, optional `preview` events with intermediate
|
| 482 |
+
PNGs, and a terminal `result` event whose `data` payload is the JSON
|
| 483 |
+
described below.
|
| 484 |
|
| 485 |
+
**Standard mode**: runs full diffusion correction; terminal `result`
|
| 486 |
+
payload conforms to `KgAnomalyStandardResponse`.
|
| 487 |
+
|
| 488 |
+
**MultiProx mode**: runs the first Gibbs iteration; terminal `result`
|
| 489 |
+
payload conforms to `KgAnomalyMultiProxResponse` and includes an
|
| 490 |
+
opaque `state` blob to be passed to `/kg-anomaly/continue`.
|
| 491 |
requestBody:
|
| 492 |
required: true
|
| 493 |
content:
|
|
|
|
| 537 |
t_prime: 0.1
|
| 538 |
responses:
|
| 539 |
"200":
|
| 540 |
+
description: SSE stream of progress/preview events terminated by a result event
|
| 541 |
content:
|
| 542 |
+
text/event-stream:
|
| 543 |
schema:
|
| 544 |
+
$ref: "#/components/schemas/KgAnomalySseStream"
|
|
|
|
|
|
|
| 545 |
"400":
|
| 546 |
$ref: "#/components/responses/InvalidRequest"
|
| 547 |
"404":
|
|
|
|
| 557 |
post:
|
| 558 |
operationId: kgAnomalyContinue
|
| 559 |
tags: [kg-anomaly]
|
| 560 |
+
summary: Advance MultiProx correction by one step (SSE streaming)
|
| 561 |
description: |
|
| 562 |
+
SSE stream (`text/event-stream`). Advances the MultiProx correction
|
| 563 |
+
chain by one Gibbs iteration. The client must send back the opaque
|
| 564 |
+
`state` from the previous step's `result` event. Emits `progress`
|
| 565 |
+
and `preview` events, then a terminal `result` event with the
|
| 566 |
+
`KgAnomalyMultiProxResponse` payload (including the updated `state`).
|
| 567 |
requestBody:
|
| 568 |
required: true
|
| 569 |
content:
|
|
|
|
| 572 |
$ref: "#/components/schemas/KgAnomalyContinueRequest"
|
| 573 |
responses:
|
| 574 |
"200":
|
| 575 |
+
description: SSE stream of progress/preview events terminated by a result event
|
| 576 |
content:
|
| 577 |
+
text/event-stream:
|
| 578 |
schema:
|
| 579 |
+
$ref: "#/components/schemas/KgAnomalySseStream"
|
| 580 |
"400":
|
| 581 |
$ref: "#/components/responses/InvalidRequest"
|
| 582 |
"429":
|
|
|
|
| 1352 |
format: float
|
| 1353 |
example: 800
|
| 1354 |
|
| 1355 |
+
# -- Graph Generation SSE stream ----------------------------------
|
| 1356 |
+
GraphGenSseStream:
|
| 1357 |
+
type: string
|
| 1358 |
+
description: |
|
| 1359 |
+
SSE text stream. Each event is `event: <name>\ndata: <payload>\n\n`.
|
| 1360 |
+
|
| 1361 |
+
* `event: progress` -- payload is `GraphGenProgressEvent` JSON.
|
| 1362 |
+
* `event: preview` -- payload is a raw `data:image/png;base64,...` data URI
|
| 1363 |
+
(intermediate graph snapshot; not JSON).
|
| 1364 |
+
* `event: result` -- payload is a `GraphGenStandardResponse` (standard mode)
|
| 1365 |
+
or `GraphGenMultiProxResponse` (multiprox mode / continue) JSON.
|
| 1366 |
+
* `event: error` -- payload is an error object with `code` and `message`.
|
| 1367 |
+
|
| 1368 |
+
GraphGenProgressEvent:
|
| 1369 |
+
type: object
|
| 1370 |
+
required: [type, stage]
|
| 1371 |
+
properties:
|
| 1372 |
+
type:
|
| 1373 |
+
type: string
|
| 1374 |
+
enum: [progress]
|
| 1375 |
+
stage:
|
| 1376 |
+
type: string
|
| 1377 |
+
description: Current phase (e.g. "denoise", "noise", "refine")
|
| 1378 |
+
step:
|
| 1379 |
+
type: integer
|
| 1380 |
+
description: Current step within the stage
|
| 1381 |
+
total:
|
| 1382 |
+
type: integer
|
| 1383 |
+
description: Total steps in the stage
|
| 1384 |
+
|
| 1385 |
# -- KG Anomaly Correction: Discovery --
|
| 1386 |
KgAnomalyDatasetsResponse:
|
| 1387 |
type: object
|
|
|
|
| 1502 |
$ref: "#/components/schemas/SamplingModeEnum"
|
| 1503 |
task:
|
| 1504 |
$ref: "#/components/schemas/KgAnomalyTaskEnum"
|
| 1505 |
+
default: correct
|
| 1506 |
description: |
|
| 1507 |
"generate" = ignore the input subgraph edges and generate a new subgraph from scratch.
|
| 1508 |
"correct" (default) = keep fixed edges unchanged, only correct the masked (anomalous) edges.
|
|
|
|
| 1661 |
removed:
|
| 1662 |
type: integer
|
| 1663 |
example: 0
|
| 1664 |
+
|
| 1665 |
+
# -- KG Anomaly SSE stream ----------------------------------------
|
| 1666 |
+
KgAnomalySseStream:
|
| 1667 |
+
type: string
|
| 1668 |
+
description: |
|
| 1669 |
+
SSE text stream. Each event is `event: <name>\ndata: <payload>\n\n`.
|
| 1670 |
+
|
| 1671 |
+
* `event: progress` -- payload is `KgAnomalyProgressEvent` JSON.
|
| 1672 |
+
* `event: preview` -- payload is a raw `data:image/png;base64,...` data URI
|
| 1673 |
+
(intermediate subgraph snapshot; not JSON).
|
| 1674 |
+
* `event: result` -- payload is a `KgAnomalyStandardResponse` (standard mode)
|
| 1675 |
+
or `KgAnomalyMultiProxResponse` (multiprox mode / continue) JSON.
|
| 1676 |
+
* `event: error` -- payload is an error object with `code` and `message`.
|
| 1677 |
+
|
| 1678 |
+
KgAnomalyProgressEvent:
|
| 1679 |
+
type: object
|
| 1680 |
+
required: [type, stage]
|
| 1681 |
+
properties:
|
| 1682 |
+
type:
|
| 1683 |
+
type: string
|
| 1684 |
+
enum: [progress]
|
| 1685 |
+
stage:
|
| 1686 |
+
type: string
|
| 1687 |
+
description: Current phase (e.g. "denoise", "noise", "refine")
|
| 1688 |
+
step:
|
| 1689 |
+
type: integer
|
| 1690 |
+
description: Current step within the stage
|
| 1691 |
+
total:
|
| 1692 |
+
type: integer
|
| 1693 |
+
description: Total steps in the stage
|
|
@@ -612,7 +612,45 @@
|
|
| 612 |
{ "key": "count", "value": "3" }
|
| 613 |
]
|
| 614 |
},
|
| 615 |
-
"description": "Pre-computed example subgraphs for correction."
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 616 |
}
|
| 617 |
}
|
| 618 |
]
|
|
@@ -629,7 +667,7 @@
|
|
| 629 |
],
|
| 630 |
"body": {
|
| 631 |
"mode": "raw",
|
| 632 |
-
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\":
|
| 633 |
},
|
| 634 |
"url": {
|
| 635 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
@@ -648,7 +686,7 @@
|
|
| 648 |
],
|
| 649 |
"body": {
|
| 650 |
"mode": "raw",
|
| 651 |
-
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"generate\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\":
|
| 652 |
},
|
| 653 |
"url": {
|
| 654 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
@@ -659,7 +697,16 @@
|
|
| 659 |
}
|
| 660 |
},
|
| 661 |
{
|
| 662 |
-
"name": "POST /kg-anomaly/correct (multiprox
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 663 |
"request": {
|
| 664 |
"method": "POST",
|
| 665 |
"header": [
|
|
@@ -667,18 +714,27 @@
|
|
| 667 |
],
|
| 668 |
"body": {
|
| 669 |
"mode": "raw",
|
| 670 |
-
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"multiprox\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\":
|
| 671 |
},
|
| 672 |
"url": {
|
| 673 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
| 674 |
"host": ["{{base_url}}"],
|
| 675 |
"path": ["kg-anomaly", "correct"]
|
| 676 |
},
|
| 677 |
-
"description": "MultiProx correction.
|
| 678 |
}
|
| 679 |
},
|
| 680 |
{
|
| 681 |
"name": "POST /kg-anomaly/continue",
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 682 |
"request": {
|
| 683 |
"method": "POST",
|
| 684 |
"header": [
|
|
@@ -686,14 +742,14 @@
|
|
| 686 |
],
|
| 687 |
"body": {
|
| 688 |
"mode": "raw",
|
| 689 |
-
"raw": "{\n \"state\": \"
|
| 690 |
},
|
| 691 |
"url": {
|
| 692 |
"raw": "{{base_url}}/kg-anomaly/continue",
|
| 693 |
"host": ["{{base_url}}"],
|
| 694 |
"path": ["kg-anomaly", "continue"]
|
| 695 |
},
|
| 696 |
-
"description": "Advance MultiProx correction by one step."
|
| 697 |
}
|
| 698 |
}
|
| 699 |
]
|
|
|
|
| 612 |
{ "key": "count", "value": "3" }
|
| 613 |
]
|
| 614 |
},
|
| 615 |
+
"description": "Pre-computed example subgraphs for correction (clean)."
|
| 616 |
+
}
|
| 617 |
+
},
|
| 618 |
+
{
|
| 619 |
+
"name": "GET /kg-anomaly/datasets/{id}/sample-subgraphs (noised, correct)",
|
| 620 |
+
"request": {
|
| 621 |
+
"method": "GET",
|
| 622 |
+
"header": [],
|
| 623 |
+
"url": {
|
| 624 |
+
"raw": "{{base_url}}/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3&noise_level=0.4&task=correct&seed=42",
|
| 625 |
+
"host": ["{{base_url}}"],
|
| 626 |
+
"path": ["kg-anomaly", "datasets", "wordnet", "sample-subgraphs"],
|
| 627 |
+
"query": [
|
| 628 |
+
{ "key": "count", "value": "3" },
|
| 629 |
+
{ "key": "noise_level", "value": "0.4" },
|
| 630 |
+
{ "key": "task", "value": "correct" },
|
| 631 |
+
{ "key": "seed", "value": "42" }
|
| 632 |
+
]
|
| 633 |
+
},
|
| 634 |
+
"description": "Pre-noised example subgraphs for the 'correct' task (only inpaint-mask region is corrupted)."
|
| 635 |
+
}
|
| 636 |
+
},
|
| 637 |
+
{
|
| 638 |
+
"name": "GET /kg-anomaly/datasets/{id}/sample-subgraphs (noised, generate)",
|
| 639 |
+
"request": {
|
| 640 |
+
"method": "GET",
|
| 641 |
+
"header": [],
|
| 642 |
+
"url": {
|
| 643 |
+
"raw": "{{base_url}}/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3&noise_level=0.4&task=generate&seed=43",
|
| 644 |
+
"host": ["{{base_url}}"],
|
| 645 |
+
"path": ["kg-anomaly", "datasets", "wordnet", "sample-subgraphs"],
|
| 646 |
+
"query": [
|
| 647 |
+
{ "key": "count", "value": "3" },
|
| 648 |
+
{ "key": "noise_level", "value": "0.4" },
|
| 649 |
+
{ "key": "task", "value": "generate" },
|
| 650 |
+
{ "key": "seed", "value": "43" }
|
| 651 |
+
]
|
| 652 |
+
},
|
| 653 |
+
"description": "Pre-noised example subgraphs for the 'generate' task (all edges corrupted)."
|
| 654 |
}
|
| 655 |
}
|
| 656 |
]
|
|
|
|
| 667 |
],
|
| 668 |
"body": {
|
| 669 |
"mode": "raw",
|
| 670 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 28155, \"type_id\": 1},\n {\"entity_id\": 29348, \"type_id\": 4},\n {\"entity_id\": 29358, \"type_id\": 1},\n {\"entity_id\": 36247, \"type_id\": 1},\n {\"entity_id\": 36248, \"type_id\": 4},\n {\"entity_id\": 36855, \"type_id\": 1},\n {\"entity_id\": 36858, \"type_id\": 4},\n {\"entity_id\": 36860, \"type_id\": 4},\n {\"entity_id\": 36881, \"type_id\": 1},\n {\"entity_id\": 39993, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 1, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 6, \"relation_id\": 2},\n {\"source_idx\": 5, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 8, \"relation_id\": 5},\n {\"source_idx\": 5, \"target_idx\": 9, \"relation_id\": 4},\n {\"source_idx\": 6, \"target_idx\": 4, \"relation_id\": 3},\n {\"source_idx\": 6, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 7, \"target_idx\": 5, \"relation_id\": 10},\n {\"source_idx\": 7, \"target_idx\": 6, \"relation_id\": 10},\n {\"source_idx\": 8, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 6, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 9, \"relation_id\": 1},\n {\"source_idx\": 9, \"target_idx\": 0, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 6, \"relation_id\": 
10},\n {\"source_idx\": 9, \"target_idx\": 7, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 8, \"relation_id\": 10}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
|
| 671 |
},
|
| 672 |
"url": {
|
| 673 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
|
|
| 686 |
],
|
| 687 |
"body": {
|
| 688 |
"mode": "raw",
|
| 689 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"generate\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 28155, \"type_id\": 1},\n {\"entity_id\": 29348, \"type_id\": 4},\n {\"entity_id\": 29358, \"type_id\": 1},\n {\"entity_id\": 36247, \"type_id\": 1},\n {\"entity_id\": 36248, \"type_id\": 4},\n {\"entity_id\": 36855, \"type_id\": 1},\n {\"entity_id\": 36858, \"type_id\": 4},\n {\"entity_id\": 36860, \"type_id\": 4},\n {\"entity_id\": 36881, \"type_id\": 1},\n {\"entity_id\": 39993, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 2, \"relation_id\": 10},\n {\"source_idx\": 0, \"target_idx\": 4, \"relation_id\": 2},\n {\"source_idx\": 1, \"target_idx\": 0, \"relation_id\": 0},\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 2},\n {\"source_idx\": 1, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 1, \"target_idx\": 5, \"relation_id\": 9},\n {\"source_idx\": 1, \"target_idx\": 7, \"relation_id\": 7},\n {\"source_idx\": 1, \"target_idx\": 8, \"relation_id\": 3},\n {\"source_idx\": 2, \"target_idx\": 0, \"relation_id\": 7},\n {\"source_idx\": 2, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 4, \"relation_id\": 5},\n {\"source_idx\": 2, \"target_idx\": 8, \"relation_id\": 10},\n {\"source_idx\": 2, \"target_idx\": 9, \"relation_id\": 2},\n {\"source_idx\": 3, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 5, \"relation_id\": 7},\n {\"source_idx\": 3, \"target_idx\": 6, \"relation_id\": 6},\n {\"source_idx\": 3, \"target_idx\": 7, \"relation_id\": 0},\n {\"source_idx\": 4, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 3, \"relation_id\": 6},\n {\"source_idx\": 4, \"target_idx\": 6, \"relation_id\": 7},\n {\"source_idx\": 4, \"target_idx\": 7, \"relation_id\": 7},\n {\"source_idx\": 5, \"target_idx\": 4, \"relation_id\": 
2},\n {\"source_idx\": 5, \"target_idx\": 6, \"relation_id\": 2},\n {\"source_idx\": 5, \"target_idx\": 7, \"relation_id\": 6},\n {\"source_idx\": 5, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 9, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 0, \"relation_id\": 5},\n {\"source_idx\": 6, \"target_idx\": 3, \"relation_id\": 7},\n {\"source_idx\": 6, \"target_idx\": 4, \"relation_id\": 3},\n {\"source_idx\": 6, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 7, \"relation_id\": 0},\n {\"source_idx\": 6, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 9, \"relation_id\": 4},\n {\"source_idx\": 7, \"target_idx\": 2, \"relation_id\": 10},\n {\"source_idx\": 7, \"target_idx\": 5, \"relation_id\": 5},\n {\"source_idx\": 7, \"target_idx\": 6, \"relation_id\": 5},\n {\"source_idx\": 8, \"target_idx\": 0, \"relation_id\": 0},\n {\"source_idx\": 8, \"target_idx\": 2, \"relation_id\": 6},\n {\"source_idx\": 8, \"target_idx\": 4, \"relation_id\": 10},\n {\"source_idx\": 8, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 6, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 9, \"relation_id\": 0},\n {\"source_idx\": 9, \"target_idx\": 0, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 1, \"relation_id\": 8},\n {\"source_idx\": 9, \"target_idx\": 3, \"relation_id\": 5},\n {\"source_idx\": 9, \"target_idx\": 7, \"relation_id\": 5},\n {\"source_idx\": 9, \"target_idx\": 8, \"relation_id\": 1}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
|
| 690 |
},
|
| 691 |
"url": {
|
| 692 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
|
|
| 697 |
}
|
| 698 |
},
|
| 699 |
{
|
| 700 |
+
"name": "POST /kg-anomaly/correct (multiprox init)",
|
| 701 |
+
"event": [
|
| 702 |
+
{
|
| 703 |
+
"listen": "test",
|
| 704 |
+
"script": {
|
| 705 |
+
"type": "text/javascript",
|
| 706 |
+
"exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) { pm.collectionVariables.set('multiprox_state', result.state); }", " } catch (e) {}", " break;", " }", "}"]
|
| 707 |
+
}
|
| 708 |
+
}
|
| 709 |
+
],
|
| 710 |
"request": {
|
| 711 |
"method": "POST",
|
| 712 |
"header": [
|
|
|
|
| 714 |
],
|
| 715 |
"body": {
|
| 716 |
"mode": "raw",
|
| 717 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"multiprox\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 28155, \"type_id\": 1},\n {\"entity_id\": 29348, \"type_id\": 4},\n {\"entity_id\": 29358, \"type_id\": 1},\n {\"entity_id\": 36247, \"type_id\": 1},\n {\"entity_id\": 36248, \"type_id\": 4},\n {\"entity_id\": 36855, \"type_id\": 1},\n {\"entity_id\": 36858, \"type_id\": 4},\n {\"entity_id\": 36860, \"type_id\": 4},\n {\"entity_id\": 36881, \"type_id\": 1},\n {\"entity_id\": 39993, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 1, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 2, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 1, \"relation_id\": 1},\n {\"source_idx\": 3, \"target_idx\": 4, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 2, \"relation_id\": 1},\n {\"source_idx\": 4, \"target_idx\": 3, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 6, \"relation_id\": 2},\n {\"source_idx\": 5, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 5, \"target_idx\": 8, \"relation_id\": 5},\n {\"source_idx\": 5, \"target_idx\": 9, \"relation_id\": 4},\n {\"source_idx\": 6, \"target_idx\": 4, \"relation_id\": 3},\n {\"source_idx\": 6, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 6, \"target_idx\": 8, \"relation_id\": 1},\n {\"source_idx\": 7, \"target_idx\": 5, \"relation_id\": 10},\n {\"source_idx\": 7, \"target_idx\": 6, \"relation_id\": 10},\n {\"source_idx\": 8, \"target_idx\": 5, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 6, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 7, \"relation_id\": 1},\n {\"source_idx\": 8, \"target_idx\": 9, \"relation_id\": 1},\n {\"source_idx\": 9, \"target_idx\": 0, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 6, \"relation_id\": 
10},\n {\"source_idx\": 9, \"target_idx\": 7, \"relation_id\": 3},\n {\"source_idx\": 9, \"target_idx\": 8, \"relation_id\": 10}\n ]\n },\n \"multiprox_params\": {\n \"n\": 10,\n \"m\": 100,\n \"t\": 0.4,\n \"t_prime\": 0.1,\n \"gibbs_chain_freq\": 10\n }\n}"
|
| 718 |
},
|
| 719 |
"url": {
|
| 720 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
| 721 |
"host": ["{{base_url}}"],
|
| 722 |
"path": ["kg-anomaly", "correct"]
|
| 723 |
},
|
| 724 |
+
"description": "MultiProx Gibbs init on wordnet correction. SSE stream; the result event's state blob is auto-saved to {{multiprox_state}}."
|
| 725 |
}
|
| 726 |
},
|
| 727 |
{
|
| 728 |
"name": "POST /kg-anomaly/continue",
|
| 729 |
+
"event": [
|
| 730 |
+
{
|
| 731 |
+
"listen": "test",
|
| 732 |
+
"script": {
|
| 733 |
+
"type": "text/javascript",
|
| 734 |
+
"exec": ["var body = pm.response.text();", "var lines = body.split('\\n');", "for (var i = 0; i < lines.length; i++) {", " if (lines[i].trim() === 'event: result' && i + 1 < lines.length) {", " var dataLine = lines[i + 1].replace(/^data: /, '');", " try {", " var result = JSON.parse(dataLine);", " if (result.state) {", " pm.collectionVariables.set('multiprox_state', result.state);", " console.log('State updated (done=' + result.done + ', step=' + result.step + ')');", " }", " } catch (e) {}", " break;", " }", "}"]
|
| 735 |
+
}
|
| 736 |
+
}
|
| 737 |
+
],
|
| 738 |
"request": {
|
| 739 |
"method": "POST",
|
| 740 |
"header": [
|
|
|
|
| 742 |
],
|
| 743 |
"body": {
|
| 744 |
"mode": "raw",
|
| 745 |
+
"raw": "{\n \"state\": \"{{multiprox_state}}\"\n}"
|
| 746 |
},
|
| 747 |
"url": {
|
| 748 |
"raw": "{{base_url}}/kg-anomaly/continue",
|
| 749 |
"host": ["{{base_url}}"],
|
| 750 |
"path": ["kg-anomaly", "continue"]
|
| 751 |
},
|
| 752 |
+
"description": "Advance MultiProx correction by one step. Uses {{multiprox_state}}; can be chained repeatedly."
|
| 753 |
}
|
| 754 |
}
|
| 755 |
]
|
|
@@ -88,17 +88,17 @@ All endpoints are prefixed with `/api/v1/`.
|
|
| 88 |
|---|---|---|
|
| 89 |
| `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
|
| 90 |
| `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
|
| 91 |
-
| `POST` | `/graph-generation/generate` | **Streaming
|
| 92 |
-
| `POST` | `/graph-generation/continue` | **Streaming
|
| 93 |
|
| 94 |
### KG Anomaly Correction
|
| 95 |
|
| 96 |
| Method | Path | Description |
|
| 97 |
|---|---|---|
|
| 98 |
| `GET` | `/kg-anomaly/datasets` | List datasets with correction models |
|
| 99 |
-
| `GET` | `/kg-anomaly/datasets/{id}/sample-subgraphs` | Pre-computed example subgraphs (`?count=5`) |
|
| 100 |
-
| `POST` | `/kg-anomaly/correct` |
|
| 101 |
-
| `POST` | `/kg-anomaly/continue` |
|
| 102 |
|
| 103 |
## Streaming Inference Protocol (SSE)
|
| 104 |
|
|
|
|
| 88 |
|---|---|---|
|
| 89 |
| `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
|
| 90 |
| `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
|
| 91 |
+
| `POST` | `/graph-generation/generate` | **Streaming SSE.** Generate a graph (standard denoising or MultiProx Gibbs init) |
|
| 92 |
+
| `POST` | `/graph-generation/continue` | **Streaming SSE.** Advance a MultiProx Gibbs session by one step |
|
| 93 |
|
| 94 |
### KG Anomaly Correction
|
| 95 |
|
| 96 |
| Method | Path | Description |
|
| 97 |
|---|---|---|
|
| 98 |
| `GET` | `/kg-anomaly/datasets` | List datasets with correction models |
|
| 99 |
+
| `GET` | `/kg-anomaly/datasets/{id}/sample-subgraphs` | Pre-computed example subgraphs (`?count=5&noise_level=0.4&task=correct&seed=42`); noise is task-aware |
|
| 100 |
+
| `POST` | `/kg-anomaly/correct` | **Streaming SSE.** Correct/regenerate a KG subgraph (standard denoising or MultiProx Gibbs init) |
|
| 101 |
+
| `POST` | `/kg-anomaly/continue` | **Streaming SSE.** Advance a MultiProx correction session by one step |
|
| 102 |
|
| 103 |
## Streaming Inference Protocol (SSE)
|
| 104 |
|
|
@@ -0,0 +1,708 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import base64
|
| 2 |
+
import io
|
| 3 |
+
import time
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
from api.services.graphgen_inference import (
|
| 9 |
+
_frames_to_gif_b64, _pil_to_b64,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
# Size limit, in bytes, for a serialized MultiProx state blob (10 MB).
# NOTE(review): the code that enforces this limit is not in this module's
# visible functions -- presumably the view/registry layer; confirm.
STATE_BLOB_MAX_BYTES = 10 * 1024 * 1024  # 10 MB
# Keys a MultiProx state dict is expected to carry. This is exactly the set of
# keys written into ``state`` by run_multiprox_correction_init.
REQUIRED_STATE_KEYS = {
    "X_given", "E", "y", "n_nodes", "dataset_id", "task", "X_index", "X_c",
    "is_bip", "original_E_int", "T", "n", "m", "t", "t_prime",
    "gibbs_chain_freq", "inner_step", "step",
}

# Edge colour (hex RGB) per change category produced by compute_changes;
# render_kg_subgraph looks edges up here when a ``changes`` dict is supplied.
CHANGE_COLORS = {
    "unchanged": "#888888",
    "modified": "#e67e22",
    "added": "#27ae60",
    "removed": "#e74c3c",
}
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
# ---------------------------------------------------------------------------
|
| 28 |
+
# Input subgraph -> tensors
|
| 29 |
+
# ---------------------------------------------------------------------------
|
| 30 |
+
|
| 31 |
+
def build_kg_tensors(subgraph, loader, model):
    """Turn an API subgraph payload into batch-size-1 CPU tensors.

    ``subgraph`` must provide ``"nodes"`` (dicts with ``entity_id`` and an
    optional ``type_id``) and ``"edges"`` (dicts with ``source_idx``,
    ``target_idx`` and ``relation_id``).  A node type outside
    ``[0, model.Xdim_output)`` falls back to class 0.  An edge whose endpoints
    are outside the node list, or whose class (``relation_id + 1``) is outside
    ``[1, model.Edim_output)``, is silently dropped.

    Returns a dict with the keys X_given, E_given, y_given, X_index, X_c,
    n_nodes, is_bip and node_mask.
    """
    node_list = subgraph["nodes"]
    edge_list = subgraph["edges"]
    num_nodes = len(node_list)

    # One-hot node types.
    X_given = torch.zeros(1, num_nodes, model.Xdim_output, dtype=torch.float32)
    for row, node in enumerate(node_list):
        node_type = int(node.get("type_id", 0))
        column = node_type if 0 <= node_type < model.Xdim_output else 0
        X_given[0, row, column] = 1.0

    # One-hot edge classes; class 0 ("no edge") is the default for every slot.
    E_given = torch.zeros(
        1, num_nodes, num_nodes, model.Edim_output, dtype=torch.float32)
    E_given[0, :, :, 0] = 1.0
    for edge in edge_list:
        head = int(edge["source_idx"])
        tail = int(edge["target_idx"])
        edge_class = int(edge["relation_id"]) + 1  # shift past "no edge"
        endpoints_ok = 0 <= head < num_nodes and 0 <= tail < num_nodes
        if endpoints_ok and 1 <= edge_class < model.Edim_output:
            E_given[0, head, tail, :] = 0.0
            E_given[0, head, tail, edge_class] = 1.0

    # No global features: an empty (1, 0) tensor.
    y_given = torch.zeros(1, 0, dtype=torch.float32)

    # Entity ids per node position.
    X_index = torch.zeros(1, num_nodes, dtype=torch.long)
    for row, node in enumerate(node_list):
        X_index[0, row] = int(node["entity_id"])

    # Community id per node, when the loader exposes a ``communities`` table;
    # entities outside the table keep community 0.
    X_c = torch.zeros(1, num_nodes, dtype=torch.long)
    community_table = getattr(loader, "communities", None)
    if community_table is not None:
        for row, node in enumerate(node_list):
            entity_id = int(node["entity_id"])
            if 0 <= entity_id < len(community_table):
                X_c[0, row] = int(community_table[entity_id])

    return {
        "X_given": X_given,
        "E_given": E_given,
        "y_given": y_given,
        "X_index": X_index,
        "X_c": X_c,
        "n_nodes": torch.tensor([num_nodes], dtype=torch.long),
        # NOTE(review): the flag is derived purely from "more than 20 nodes";
        # confirm this matches how the training data marks bipartite samples.
        "is_bip": torch.tensor([num_nodes > 20], dtype=torch.bool),
        "node_mask": torch.ones(1, num_nodes, dtype=torch.bool),
    }
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
def _to_device(t, device):
    """Move ``t`` to ``device`` if it is a tensor; return anything else as-is."""
    if not isinstance(t, torch.Tensor):
        return t
    return t.to(device)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def apply_edge_noise(model, tensors, task, noise_level, seed=None):
    """Forward-diffuse the given subgraph's edges at t = noise_level * T.

    For task="generate" every edge slot is noised.  For any other task only
    the slots selected by the inpaint mask (``get_inpaint_mask``) are noised;
    the remaining slots keep their given class, matching what the correction
    endpoint will regenerate.

    Args:
        model: diffusion model exposing ``T``, ``Edim_output``,
            ``noise_schedule``, ``transition_model`` and ``parameters()``.
        tensors: dict as returned by ``build_kg_tensors``.
        task: ``"generate"`` or a correction task name.
        noise_level: fraction of the diffusion horizon, in ``(0, 1]``.
        seed: optional integer.  When given, sampling is reproducible; the
            process-global RNG state is restored afterwards so that seeding
            one request does not make later, unrelated sampling deterministic.

    Returns:
        A new list of ``{source_idx, target_idx, relation_id}`` dicts for every
        off-diagonal slot whose noised class is not "no edge".

    Raises:
        ValueError: if ``noise_level`` is outside ``(0, 1]`` (NaN included).
    """
    # Validate before touching the model or importing the research package so
    # bad input fails fast and cheaply.
    if not (0.0 < noise_level <= 1.0):
        raise ValueError("noise_level must be in (0, 1]")

    from graph_generation.src.diffusion import diffusion_utils

    device = next(model.parameters()).device
    X = tensors["X_given"].to(device)
    E = tensors["E_given"].to(device)
    node_mask = tensors["node_mask"].to(device)
    is_bip = tensors["is_bip"].to(device)
    n = int(tensors["n_nodes"].item())

    # Same mask construction the correction endpoints use.
    inpaint_mask = _build_inpaint_mask(
        task, node_mask, is_bip, model.Edim_output, device)

    T = model.T
    t_step = torch.tensor(
        [[int(noise_level * T)]], dtype=torch.float, device=device)
    t_normalized = t_step / T

    # fork_rng snapshots the CPU (and, if used, the model's CUDA) generator
    # and restores it on exit; it is a no-op when no seed was requested.
    cuda_devices = [device] if device.type == "cuda" else []
    with torch.no_grad(), torch.random.fork_rng(
            devices=cuda_devices, enabled=seed is not None):
        alpha_t_bar = model.noise_schedule.get_alpha_bar(
            t_normalized=t_normalized)
        Qtb = model.transition_model.get_Qt_bar(alpha_t_bar, device=device)

        probX = X @ Qtb.X
        probE = E @ Qtb.E.unsqueeze(1)

        if seed is not None:
            torch.manual_seed(int(seed))

        sampled = diffusion_utils.sample_discrete_features(
            probX=probX, probE=probE, node_mask=node_mask)
        E_noised = F.one_hot(sampled.E, num_classes=model.Edim_output).float()
        # Noised values inside the mask, given values outside it.
        E_mixed = E_noised * inpaint_mask + E * (~inpaint_mask)

    edge_classes = E_mixed[0].argmax(dim=-1).cpu().tolist()

    edges = []
    for i in range(n):
        for j in range(n):
            if i == j:
                continue
            cls = int(edge_classes[i][j])
            if cls == 0:  # class 0 is "no edge"
                continue
            edges.append({
                "source_idx": i, "target_idx": j, "relation_id": cls - 1,
            })
    return edges
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
# ---------------------------------------------------------------------------
|
| 155 |
+
# Change detection
|
| 156 |
+
# ---------------------------------------------------------------------------
|
| 157 |
+
|
| 158 |
+
def compute_changes(original_E_int, corrected_E_int, num_nodes, loader):
    """Diff two directed edge-class matrices of a KG subgraph.

    ``original_E_int`` and ``corrected_E_int`` are 2-D integer tensors
    ``(n, n)`` in which 0 means "no edge" and class ``k >= 1`` stands for
    relation id ``k - 1``.  Only the first ``num_nodes`` rows/columns are
    inspected and the diagonal is ignored.

    Returns ``{"edges": [...], "summary": {...}}``.  Every entry of ``edges``
    carries ``source_idx``, ``target_idx`` and ``change`` (one of
    ``unchanged`` / ``added`` / ``removed`` / ``modified``), plus
    ``original_relation_*`` and/or ``relation_*`` fields as applicable.
    ``summary`` counts entries per change category.  Relation names come from
    ``loader.dataset.get_inverted_name_maps()``; an unknown id is rendered as
    its decimal string.
    """
    _, _, relation_names = loader.dataset.get_inverted_name_maps()

    def _relation_name(relation_id):
        return str(relation_names.get(relation_id, relation_id))

    before_rows = original_E_int.cpu().tolist()
    after_rows = corrected_E_int.cpu().tolist()

    counts = {"added": 0, "removed": 0, "modified": 0, "unchanged": 0}
    diff = []

    for src in range(num_nodes):
        for dst in range(num_nodes):
            if src == dst:
                continue
            before = int(before_rows[src][dst])
            after = int(after_rows[src][dst])
            if before == 0 and after == 0:
                continue  # no edge on either side: nothing to report

            if before == after:
                kind = "unchanged"
            elif before == 0 and after > 0:
                kind = "added"
            elif before > 0 and after == 0:
                kind = "removed"
            else:
                kind = "modified"
            counts[kind] += 1

            entry = {"source_idx": src, "target_idx": dst, "change": kind}
            if kind in ("removed", "modified"):
                entry["original_relation_id"] = before - 1
                entry["original_relation_name"] = _relation_name(before - 1)
            if kind != "removed":
                entry["relation_id"] = after - 1
                entry["relation_name"] = _relation_name(after - 1)
            diff.append(entry)

    return {"edges": diff, "summary": counts}
|
| 213 |
+
|
| 214 |
+
|
| 215 |
+
# ---------------------------------------------------------------------------
|
| 216 |
+
# Rendering
|
| 217 |
+
# ---------------------------------------------------------------------------
|
| 218 |
+
|
| 219 |
+
def _format_entity_label(dataset_id, name):
    """Shorten an entity name into a node label for the rendered image.

    * ``freebase``: every ``/m/`` occurrence is removed.
    * ``wordnet``: only the text before the first ``.`` is kept.
    * any other dataset: names containing ``concept`` are split on ``:``; the
      last segment is kept, or the second-to-last one when the name also
      contains ``new``.

    Labels longer than 14 characters are cut to 13 characters plus an ellipsis.
    """
    label = str(name)
    if dataset_id == "freebase":
        label = label.replace("/m/", "")
    elif dataset_id == "wordnet":
        label = label.partition(".")[0]
    elif "concept" in label:
        segments = label.split(":")
        take_penultimate = "new" in label and len(segments) >= 2
        label = segments[-2] if take_penultimate else segments[-1]

    if len(label) <= 14:
        return label
    return label[:13] + "…"
|
| 232 |
+
|
| 233 |
+
|
| 234 |
+
def _format_relation_label(dataset_id, name):
    """Shorten a relation name into an edge label for the rendered image.

    * ``freebase``: the name is split on ``.``; in each part only the last two
      ``/``-separated components are kept, joined by ``_``; the parts are then
      re-joined with ``.``.
    * ``wordnet``: a single leading ``_`` is dropped.
    * any other dataset: names containing ``concept`` are split on ``:``; the
      last segment is kept, or the second-to-last one when the name also
      contains ``new``.

    Labels longer than 16 characters are cut to 15 characters plus an ellipsis.
    """
    label = str(name)
    if dataset_id == "freebase":
        shortened = []
        for part in label.split("."):
            shortened.append("_".join(part.split("/")[-2:]))
        label = ".".join(shortened)
    elif dataset_id == "wordnet":
        if label.startswith("_"):
            label = label[1:]
    elif "concept" in label:
        segments = label.split(":")
        take_penultimate = "new" in label and len(segments) >= 2
        label = segments[-2] if take_penultimate else segments[-1]

    if len(label) <= 16:
        return label
    return label[:15] + "…"
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def render_kg_subgraph(E_int, num_nodes, X_index, dataset_id, loader, changes=None):
    """Render a directed KG subgraph as a 500x500 PIL image (networkx + PIL).

    Does not use matplotlib (same reason as graphgen_inference: Windows thread
    safety).

    ``E_int`` is an ``(n, n)`` integer tensor of edge classes (0 = no edge,
    ``k >= 1`` = relation id ``k - 1``) and ``X_index`` holds the entity id of
    each node position.  When ``changes`` (the dict returned by
    ``compute_changes``) is given, edges are coloured by change category and
    edges marked ``removed`` are drawn dashed even though they are absent from
    ``E_int``.  Node positions come from a spring layout with a fixed seed.
    """
    import networkx as nx
    from PIL import Image, ImageDraw, ImageFont

    entity_names, _, relation_names = loader.dataset.get_inverted_name_maps()

    adjacency = E_int.cpu().tolist()
    entity_ids = X_index.cpu().tolist()

    graph = nx.DiGraph()
    graph.add_nodes_from(range(num_nodes))
    for src in range(num_nodes):
        for dst in range(num_nodes):
            if src != dst and int(adjacency[src][dst]) > 0:
                graph.add_edge(src, dst, rel=int(adjacency[src][dst]) - 1)

    layout = nx.spring_layout(graph, seed=42)

    # (source, target) -> change category
    change_by_edge = {}
    if changes is not None:
        for entry in changes.get("edges", []):
            key = (entry["source_idx"], entry["target_idx"])
            change_by_edge[key] = entry["change"]

    # Map layout coordinates onto the canvas, leaving a margin on every side.
    size = 500
    margin = 50
    scale = (size - 2 * margin) / 2
    center_x, center_y = size / 2, size / 2
    pixel_pos = {
        node: (center_x + xy[0] * scale, center_y + xy[1] * scale)
        for node, xy in layout.items()
    }

    img = Image.new("RGB", (size, size), "white")
    draw = ImageDraw.Draw(img)
    try:
        font = ImageFont.truetype("arial.ttf", 11)
        small_font = ImageFont.truetype("arial.ttf", 9)
    except (OSError, IOError):
        # Font file unavailable: fall back to PIL's built-in font for both.
        font = ImageFont.load_default()
        small_font = font

    node_r = 10
    default_color = "#444444"

    # Edges are drawn before nodes so node circles sit on top of them.
    # "removed" edges are included although they are not part of the graph.
    drawable = set(graph.edges())
    if changes is not None:
        drawable.update(
            key for key, kind in change_by_edge.items() if kind == "removed")

    for (src, dst) in drawable:
        kind = change_by_edge.get((src, dst))
        if changes is None:
            color = default_color
        else:
            color = CHANGE_COLORS.get(kind, default_color)

        x0, y0 = pixel_pos[src]
        x1, y1 = pixel_pos[dst]
        # Trim the segment by one node radius at each end so it stops at the
        # circle outlines.
        dx, dy = x1 - x0, y1 - y0
        dist = max(1.0, (dx * dx + dy * dy) ** 0.5)
        ux, uy = dx / dist, dy / dist
        seg_start = (x0 + ux * node_r, y0 + uy * node_r)
        seg_end = (x1 - ux * node_r, y1 - uy * node_r)

        if kind == "removed":
            _draw_dashed(draw, seg_start, seg_end, color, width=2, dash=6)
        else:
            draw.line([seg_start, seg_end], fill=color, width=2)
        _draw_arrowhead(draw, seg_end, (ux, uy), color)

        # Relation label at the segment midpoint (only for edges that exist in
        # the rendered graph, i.e. not for removed ones).
        if graph.has_edge(src, dst):
            rel_id = graph[src][dst]["rel"]
            rel_name = _format_relation_label(
                dataset_id, relation_names.get(rel_id, rel_id))
            mid_x = (seg_start[0] + seg_end[0]) / 2
            mid_y = (seg_start[1] + seg_end[1]) / 2
            draw.text((mid_x + 3, mid_y - 5), rel_name, fill=color, font=small_font)

    # Nodes and their labels.
    for node in range(num_nodes):
        x, y = pixel_pos[node]
        draw.ellipse([x - node_r, y - node_r, x + node_r, y + node_r],
                     fill="#2ecc71", outline="#1a7a42")
        # Fall back to the positional index when no entity id is available.
        entity_id = int(entity_ids[node]) if node < len(entity_ids) else node
        label = _format_entity_label(
            dataset_id, entity_names.get(entity_id, entity_id))
        draw.text((x + node_r + 2, y - 6), label, fill="#111111", font=font)

    return img
|
| 341 |
+
|
| 342 |
+
|
| 343 |
+
def _draw_arrowhead(draw, tip, direction, color):
    """Draw a filled triangular arrowhead whose apex is at ``tip``.

    ``direction`` is the ``(ux, uy)`` vector the arrow points along.  The two
    base corners lie 7 px back from the tip, 25 degrees either side of the
    direction.  ``draw`` only needs a PIL-style ``polygon(points, fill=...)``.
    """
    import math

    ux, uy = direction
    x, y = tip
    heading = math.atan2(uy, ux)
    head_length = 7
    half_spread = math.radians(25)

    base_corners = []
    for side in (-half_spread, half_spread):
        base_corners.append((
            x - head_length * math.cos(heading + side),
            y - head_length * math.sin(heading + side),
        ))
    draw.polygon([(x, y), base_corners[0], base_corners[1]], fill=color)
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
def _draw_dashed(draw, start, end, color, width=2, dash=6):
    """Draw a dashed straight line from ``start`` to ``end``.

    Dashes and gaps are both ``dash`` pixels long, starting with a dash at
    ``start``.  The final dash is clipped to the segment end, so a segment
    shorter than one dash is still drawn (as a single short dash) instead of
    disappearing.  A zero-length segment draws nothing.

    Args:
        draw: object with a PIL-style ``line(points, fill=..., width=...)``.
        start, end: ``(x, y)`` endpoints.
        color: fill colour passed through to ``draw.line``.
        width: stroke width in pixels.
        dash: dash (and gap) length in pixels; must be positive.

    Raises:
        ValueError: if ``dash`` is not positive.
    """
    if dash <= 0:
        raise ValueError("dash must be positive")

    x0, y0 = start
    x1, y1 = end
    dx, dy = x1 - x0, y1 - y0
    length = (dx * dx + dy * dy) ** 0.5
    if length == 0:
        return
    ux, uy = dx / length, dy / length

    offset = 0.0
    while offset < length:
        # Clip the last dash so it never overshoots the end point.
        stop = min(offset + dash, length)
        draw.line(
            [(x0 + ux * offset, y0 + uy * offset),
             (x0 + ux * stop, y0 + uy * stop)],
            fill=color, width=width)
        offset += 2 * dash  # skip over the gap that follows each dash
|
| 372 |
+
|
| 373 |
+
|
| 374 |
+
# ---------------------------------------------------------------------------
|
| 375 |
+
# Shared inference helpers
|
| 376 |
+
# ---------------------------------------------------------------------------
|
| 377 |
+
|
| 378 |
+
def _build_inpaint_mask(task, node_mask, is_bip, E_out_dim, device):
    """Return the boolean edge mask selecting which slots the model may rewrite.

    For task ``"generate"`` the mask is all-True with shape
    ``(batch, n, n, E_out_dim)``, where ``(batch, n)`` is the shape of
    ``node_mask``.  For every other task the mask is delegated to
    ``get_inpaint_mask`` from the research package.
    """
    from graph_generation.src.utils import get_inpaint_mask

    if task != "generate":
        return get_inpaint_mask(node_mask, is_bip, E_out_dim, device)

    batch_size, max_nodes = node_mask.shape
    full_shape = (batch_size, max_nodes, max_nodes, E_out_dim)
    return torch.ones(full_shape, dtype=torch.bool, device=device)
|
| 384 |
+
|
| 385 |
+
|
| 386 |
+
def _sample_initial_noise_kg(model, node_mask):
    """Sample a starting noise graph from the model's limit distribution.

    Thin wrapper around ``diffusion_utils.sample_discrete_feature_noise``
    called with ``model.limit_dist`` and ``node_mask``; its result is returned
    unchanged.
    """
    from graph_generation.src.diffusion import diffusion_utils

    sample_noise = diffusion_utils.sample_discrete_feature_noise
    noise = sample_noise(limit_dist=model.limit_dist, node_mask=node_mask)
    return noise
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
def _collapse_final_kg(model, X, E, y, node_mask):
    """Collapse node/edge tensors to integer class ids.

    Wraps the tensors in a ``PlaceHolder``, applies
    ``mask(node_mask, collapse=True)`` and returns ``(X, E)`` cast to long.
    ``model`` is accepted for signature symmetry with the other helpers but is
    not used.
    """
    from graph_generation.src.utils import PlaceHolder

    holder = PlaceHolder(X=X, E=E, y=y)
    collapsed = holder.mask(node_mask, collapse=True)
    node_classes = collapsed.X.long()
    edge_classes = collapsed.E.long()
    return node_classes, edge_classes
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
# ---------------------------------------------------------------------------
|
| 399 |
+
# Standard correction / generation
|
| 400 |
+
# ---------------------------------------------------------------------------
|
| 401 |
+
|
| 402 |
+
def run_standard_correction(model, tensors, dataset_id, task, loader,
                            diffusion_steps, chain_frames):
    """Run standard reverse diffusion on a KG subgraph, streaming events.

    This is a generator.  It yields one ``{"type": "progress", ...}`` dict per
    denoising step (fields ``phase``, ``step``, ``total_steps``,
    ``elapsed_ms``, and a base64 ``preview`` image on frame steps) and finally
    one ``{"type": "result", ...}`` dict carrying ``original_image``,
    ``corrected_image``, ``chain_gif``, ``changes`` and ``inference_time_ms``.

    Args:
        model: diffusion model exposing ``T``, ``Edim_output``,
            ``sample_p_zs_given_zt`` and ``parameters()``.
        tensors: dict as returned by ``build_kg_tensors``.
        dataset_id: dataset name, used for label formatting.
        task: ``"generate"`` noises every edge slot; any other task keeps the
            given edges outside the inpaint mask.
        loader: dataset loader used for entity/relation names.
        diffusion_steps: requested number of denoising steps; the loop stride
            is ``max(1, model.T // diffusion_steps)``.  Must be non-zero.
        chain_frames: requested number of preview frames.  Must be non-zero.

    ``torch.no_grad()`` is entered only around the tensor computations and is
    never held across a ``yield``: grad mode is thread-local state, so keeping
    the context open while the generator is suspended would leave gradients
    disabled in the consumer's thread between events (and indefinitely if the
    stream is abandoned without being closed).
    """
    device = next(model.parameters()).device
    X_given = tensors["X_given"].to(device)
    E_given = tensors["E_given"].to(device)
    y_given = tensors["y_given"].to(device)
    X_index = tensors["X_index"].to(device)
    is_bip = tensors["is_bip"].to(device)
    n_nodes = tensors["n_nodes"].to(device)
    node_mask = tensors["node_mask"].to(device)
    n_max = n_nodes.item()

    inpaint_mask = _build_inpaint_mask(
        task, node_mask, is_bip, model.Edim_output, device)

    original_E_int = E_given[0].argmax(dim=-1).long()  # (n, n)
    original_img = render_kg_subgraph(
        original_E_int, n_max, X_index[0], dataset_id, loader, changes=None)

    model_T = model.T
    step_stride = max(1, model_T // diffusion_steps)
    total_loop_steps = (model_T + step_stride - 1) // step_stride  # ceil div
    frame_interval = max(1, total_loop_steps // chain_frames)

    with torch.no_grad():
        z_T = _sample_initial_noise_kg(model, node_mask)
        if task != "generate":
            # Keep the given edges wherever the mask does not allow rewriting.
            z_T.E = z_T.E * inpaint_mask + E_given * (~inpaint_mask)
    X, E, y = X_given, z_T.E, y_given

    gif_frames = []
    t0 = time.time()
    emitted = 0
    for s_idx in reversed(range(0, model_T, step_stride)):
        t_idx = min(s_idx + step_stride, model_T)
        s_t = (s_idx / model_T) * torch.ones((1, 1), device=device)
        t_t = (t_idx / model_T) * torch.ones((1, 1), device=device)
        with torch.no_grad():
            sampled_s, discrete_s = model.sample_p_zs_given_zt(
                s_t, t_t, X, E, y, X_index, node_mask, inpaint_mask)
        X, E, y = sampled_s.X, sampled_s.E, sampled_s.y
        emitted += 1

        event = {
            "type": "progress",
            "phase": "denoise",
            "step": emitted,
            "total_steps": total_loop_steps,
            "elapsed_ms": int((time.time() - t0) * 1000),
        }
        # Render a preview every ``frame_interval`` steps and always on the
        # final step (s_idx == 0), so the GIF ends on the last state.
        if (emitted % frame_interval == 0) or (s_idx == 0):
            frame = render_kg_subgraph(
                discrete_s.E[0].long(), n_max, X_index[0], dataset_id, loader)
            gif_frames.append(frame)
            event["preview"] = _pil_to_b64(frame)
        yield event

    with torch.no_grad():
        _, E_final = _collapse_final_kg(model, X, E, y, node_mask)

    corrected_E_int = E_final[0]
    changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
    corrected_img = render_kg_subgraph(
        corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)

    elapsed_ms = int((time.time() - t0) * 1000)
    yield {
        "type": "result",
        "original_image": _pil_to_b64(original_img),
        "corrected_image": _pil_to_b64(corrected_img),
        "chain_gif": _frames_to_gif_b64(gif_frames),
        "changes": changes,
        "inference_time_ms": elapsed_ms,
    }
|
| 475 |
+
|
| 476 |
+
|
| 477 |
+
# ---------------------------------------------------------------------------
|
| 478 |
+
# MultiProx correction / generation
|
| 479 |
+
# ---------------------------------------------------------------------------
|
| 480 |
+
|
| 481 |
+
def run_multiprox_correction_init(model, tensors, dataset_id, task, loader,
                                  n, m, t, t_prime, gibbs_chain_freq):
    """Initialise a MultiProx Gibbs session for a KG subgraph, streaming events.

    This is a generator.  It samples ``m`` initial noise graphs, yielding a
    ``{"type": "progress", "phase": "noise_init", ...}`` dict roughly every
    tenth sample and after the last one, then yields a single
    ``{"type": "result", ...}`` dict with:

    * ``state``: a dict of CPU tensors and scalars holding exactly the keys
      listed in ``REQUIRED_STATE_KEYS`` (``inner_step`` and ``step`` start
      at 0),
    * ``original_image`` / ``image``: base64 renders of the input graph and of
      the per-slot median of the ensemble,
    * ``changes``: diff between the input and that median graph,
    * ``inference_time_ms``.

    Args:
        model: diffusion model exposing ``T``, ``Edim_output``, ``limit_dist``
            and ``parameters()``.
        tensors: dict as returned by ``build_kg_tensors``.
        dataset_id: dataset name, used for label formatting.
        task: ``"generate"`` noises every edge slot; any other task keeps the
            given edges outside the inpaint mask.
        loader: dataset loader used for entity/relation names.
        n, t, t_prime, gibbs_chain_freq: MultiProx parameters; stored in
            ``state`` unchanged and not otherwise used here.
        m: number of noise samples in the ensemble.  Must be at least 1,
            because ``torch.stack`` fails on an empty list.

    ``torch.no_grad()`` is entered only around the tensor computations and is
    never held across a ``yield``: grad mode is thread-local state, so keeping
    the context open while the generator is suspended would leave gradients
    disabled in the consumer's thread between events.
    """
    device = next(model.parameters()).device
    X_given = tensors["X_given"].to(device)
    E_given = tensors["E_given"].to(device)
    y_given = tensors["y_given"].to(device)
    X_index = tensors["X_index"].to(device)
    X_c = tensors["X_c"].to(device)
    is_bip = tensors["is_bip"].to(device)
    n_nodes = tensors["n_nodes"].to(device)
    node_mask = tensors["node_mask"].to(device)
    n_max = n_nodes.item()

    inpaint_mask = _build_inpaint_mask(
        task, node_mask, is_bip, model.Edim_output, device)
    original_E_int = E_given[0].argmax(dim=-1).long()
    original_img = render_kg_subgraph(
        original_E_int, n_max, X_index[0], dataset_id, loader, changes=None)

    t0 = time.time()
    # Sample initial noise for each of the m ensemble members.
    z_samples = []
    progress_every = max(1, m // 10)
    for i in range(m):
        with torch.no_grad():
            z_i = _sample_initial_noise_kg(model, node_mask)
            if task != "generate":
                # Keep the given edges wherever the mask forbids rewriting.
                z_i.E = z_i.E * inpaint_mask + E_given * (~inpaint_mask)
        z_samples.append(z_i)
        if (i + 1) % progress_every == 0 or i == m - 1:
            yield {
                "type": "progress",
                "phase": "noise_init",
                "step": i + 1,
                "total_steps": m,
                "elapsed_ms": int((time.time() - t0) * 1000),
            }

    with torch.no_grad():
        # Stack along a new ensemble axis (dim=1).
        E_ens = torch.stack([z.E for z in z_samples], dim=1)  # (1, M, n, n, Edim)
        y_ens = torch.stack([z.y for z in z_samples], dim=1)  # (1, M, ydim)

        # Aggregate the ensemble (median over members) for the preview.
        agg_E = torch.median(E_ens, dim=1).values
        agg_y = torch.median(y_ens.float(), dim=1).values
        _, E_int = _collapse_final_kg(model, X_given, agg_E, agg_y, node_mask)

    corrected_E_int = E_int[0]
    changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
    preview_img = render_kg_subgraph(
        corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
    elapsed_ms = int((time.time() - t0) * 1000)

    state = {
        "X_given": X_given.cpu(),
        "E": E_ens.cpu(),
        "y": y_ens.cpu(),
        "n_nodes": n_nodes.cpu(),
        "dataset_id": dataset_id,
        "task": task,
        "X_index": X_index.cpu(),
        "X_c": X_c.cpu(),
        "is_bip": is_bip.cpu(),
        "original_E_int": original_E_int.cpu(),
        "T": model.T, "n": n, "m": m, "t": t, "t_prime": t_prime,
        "gibbs_chain_freq": gibbs_chain_freq,
        "inner_step": 0, "step": 0,
    }
    yield {
        "type": "result",
        "state": state,
        "original_image": _pil_to_b64(original_img),
        "image": _pil_to_b64(preview_img),
        "changes": changes,
        "inference_time_ms": elapsed_ms,
    }
|
| 555 |
+
|
| 556 |
+
|
| 557 |
+
def run_multiprox_correction_step(model, state, loader):
    """Advance a MultiProx correction session and yield SSE event dicts.

    Runs up to ``gibbs_chain_freq`` Gibbs updates on the ensemble stored in
    ``state`` (each update overwrites one ensemble slot), then a refinement
    pass of reverse-diffusion steps starting from the ensemble median, and
    finally yields a ``"result"`` event carrying the updated state.

    Args:
        model: diffusion model exposing ``sample_p_zs_given_zt``,
            ``apply_noise``, ``gibbs_fixed_t_2`` and ``Edim_output``.
        state: session dict with the keys read below; tensors are moved to
            the model's device here and returned on CPU.
        loader: dataset loader forwarded to rendering and change detection.

    Yields:
        ``"progress"`` dicts (phase ``"gibbs"``, then ``"refine"``) followed
        by a single ``"result"`` dict with ``state``, ``image``, ``changes``,
        ``round_complete``, ``done`` and ``inference_time_ms``.
    """
    device = next(model.parameters()).device
    dataset_id = state["dataset_id"]
    task = state["task"]
    X_given = state["X_given"].to(device)
    E = state["E"].to(device)
    y = state["y"].to(device)
    X_index = state["X_index"].to(device)
    is_bip = state["is_bip"].to(device)
    n_nodes = state["n_nodes"].to(device)
    original_E_int = state["original_E_int"].to(device)

    T = state["T"]
    n = state["n"]
    m = state["m"]
    t = state["t"]
    t_prime = state["t_prime"]
    gibbs_chain_freq = state["gibbs_chain_freq"]
    inner_step = state["inner_step"]
    step = state["step"]

    n_max = int(n_nodes.item())
    node_mask = torch.ones(1, n_max, dtype=torch.bool, device=device)
    inpaint_mask = _build_inpaint_mask(task, node_mask, is_bip, model.Edim_output, device)

    fixed_t_norm = t * torch.ones((1, 1), dtype=torch.float, device=device)
    fixed_s_norm = fixed_t_norm - (1.0 / T)

    # Never run past the end of the current round of m ensemble slots.
    steps_this_call = min(gibbs_chain_freq, m - inner_step)

    t0 = time.time()
    with torch.no_grad():
        for i in range(steps_this_call):
            k = inner_step + i
            avg_E = torch.median(E, dim=1).values
            avg_y = torch.median(y.float(), dim=1).values
            denoised, _ = model.sample_p_zs_given_zt(
                fixed_s_norm, fixed_t_norm, X_given, avg_E, avg_y,
                X_index, node_mask, inpaint_mask)

            old_t2 = model.gibbs_fixed_t_2
            model.gibbs_fixed_t_2 = t  # safe: inference lock held by registry
            try:
                noisy = model.apply_noise(
                    denoised.X, denoised.E, denoised.y, node_mask, inpaint_mask, gibbs=True)
            finally:
                # Restore even if apply_noise raises: the model object is
                # cached and reused by later requests.
                model.gibbs_fixed_t_2 = old_t2

            E[:, k] = noisy["E_t"]
            y[:, k] = noisy["y_t"]

            # Preview aggregate state
            prev_E = torch.median(E, dim=1).values
            prev_y = torch.median(y.float(), dim=1).values
            _, prev_Ei = _collapse_final_kg(model, X_given, prev_E, prev_y, node_mask)
            preview_img = render_kg_subgraph(
                prev_Ei[0], n_max, X_index[0], dataset_id, loader)
            yield {
                "type": "progress",
                "phase": "gibbs",
                "step": i + 1,
                "total_steps": steps_this_call,
                "elapsed_ms": int((time.time() - t0) * 1000),
                "preview": _pil_to_b64(preview_img),
            }

        new_inner_step = inner_step + steps_this_call
        round_complete = new_inner_step >= m
        if round_complete:
            new_inner_step = 0
            new_step = step + 1
        else:
            new_step = step
        done = round_complete and new_step >= n

        # Refinement pass — always produce a clean render
        P = int((t - t_prime) * T) + 1
        P = max(P, 1)
        refine_preview_interval = max(1, P // 10)
        cur_E = torch.median(E, dim=1).values
        cur_y = torch.median(y.float(), dim=1).values
        cur_X = X_given
        for j in range(P):
            s_ref = (t - (j + 1) / T) * torch.ones((1, 1), dtype=torch.float, device=device)
            t_ref = (t - j / T) * torch.ones((1, 1), dtype=torch.float, device=device)
            sampled, discrete_s = model.sample_p_zs_given_zt(
                s_ref, t_ref, cur_X, cur_E, cur_y, X_index, node_mask, inpaint_mask)
            cur_X, cur_E, cur_y = sampled.X, sampled.E, sampled.y
            # Rendering is only done on roughly every tenth step and the last one.
            is_frame = (j + 1) % refine_preview_interval == 0 or j == P - 1
            event = {
                "type": "progress",
                "phase": "refine",
                "step": j + 1,
                "total_steps": P,
                "elapsed_ms": int((time.time() - t0) * 1000),
            }
            if is_frame:
                event["preview"] = _pil_to_b64(render_kg_subgraph(
                    discrete_s.E[0].long(), n_max, X_index[0], dataset_id, loader))
            yield event

        _, E_int = _collapse_final_kg(model, cur_X, cur_E, cur_y, node_mask)

    corrected_E_int = E_int[0]
    changes = compute_changes(original_E_int, corrected_E_int, n_max, loader)
    corrected_img = render_kg_subgraph(
        corrected_E_int, n_max, X_index[0], dataset_id, loader, changes=changes)
    elapsed_ms = int((time.time() - t0) * 1000)

    # Only the ensemble tensors and the counters change between calls; the
    # refined sample is used for the render but is not written back.
    updated_state = {
        **state,
        "E": E.cpu(), "y": y.cpu(),
        "step": new_step, "inner_step": new_inner_step,
    }
    yield {
        "type": "result",
        "state": updated_state,
        "image": _pil_to_b64(corrected_img),
        "changes": changes,
        "round_complete": round_complete,
        "done": done,
        "inference_time_ms": elapsed_ms,
    }
|
| 678 |
+
|
| 679 |
+
|
| 680 |
+
# ---------------------------------------------------------------------------
|
| 681 |
+
# State blob serialisation
|
| 682 |
+
# ---------------------------------------------------------------------------
|
| 683 |
+
|
| 684 |
+
def encode_state_blob(state):
    """Serialise ``state`` with ``torch.save`` and return it as base64 ASCII text."""
    stream = io.BytesIO()
    torch.save(state, stream)
    payload = stream.getvalue()
    encoded = base64.b64encode(payload)
    return str(encoded, "ascii")
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def decode_state_blob(b64_str):
    """Decode and validate a base64 state blob produced by ``encode_state_blob``.

    Args:
        b64_str: base64 text of a ``torch.save``-serialised state dict.

    Returns:
        The deserialised state dict.

    Raises:
        ValueError: if the input is not valid base64, is larger than
            ``STATE_BLOB_MAX_BYTES``, cannot be deserialised, is not a dict,
            lacks any of ``REQUIRED_STATE_KEYS``, or carries ``E`` /
            ``X_given`` values that are not tensors of the expected rank.
    """
    try:
        raw = base64.b64decode(b64_str)
    except Exception as exc:
        raise ValueError("state is not valid base64") from exc
    if len(raw) > STATE_BLOB_MAX_BYTES:
        raise ValueError(f"state blob exceeds {STATE_BLOB_MAX_BYTES // (1024 * 1024)} MB limit")
    try:
        # SECURITY: the blob round-trips through the client, so it is
        # untrusted input. weights_only=False would unpickle arbitrary
        # objects (remote code execution); the restricted loader is used
        # instead, which still accepts tensors and plain Python scalars,
        # strings and containers.
        state = torch.load(io.BytesIO(raw), weights_only=True)
    except Exception as exc:
        raise ValueError(f"state could not be deserialized: {exc}") from exc
    if not isinstance(state, dict):
        # Without this check a non-dict payload fails on .keys() with
        # AttributeError, which callers do not translate into a client error.
        raise ValueError("state must deserialize to a dict")
    missing = REQUIRED_STATE_KEYS - set(state.keys())
    if missing:
        raise ValueError(f"state missing keys: {missing}")
    if not isinstance(state["E"], torch.Tensor) or state["E"].dim() != 5:
        raise ValueError("state['E'] must be a 5-D tensor")
    if not isinstance(state["X_given"], torch.Tensor) or state["X_given"].dim() != 3:
        raise ValueError("state['X_given'] must be a 3-D tensor")
    return state
|
|
@@ -276,6 +276,7 @@ class ModelRegistry:
|
|
| 276 |
self._coins_experiments = {} # (dataset_id, algorithm) -> Experiment
|
| 277 |
self._coins_loaders = {} # (dataset_id, seed, leiden_resolution) -> full Loader
|
| 278 |
self._graphgen_models = {} # (dataset_id, model_type) -> loaded eval-mode model
|
|
|
|
| 279 |
|
| 280 |
def force_release_inference_lock(self):
|
| 281 |
"""Emergency release for a stuck inference lock (e.g. client disconnect)."""
|
|
@@ -465,9 +466,12 @@ class ModelRegistry:
|
|
| 465 |
seed=seed, device="cpu", val_size=0.01, test_size=0.02,
|
| 466 |
community_method="leiden", leiden_resolution=leiden_resolution,
|
| 467 |
)
|
| 468 |
-
# Free heavy arrays not needed for discovery endpoints
|
| 469 |
-
_free_heavy_arrays(loader)
|
| 470 |
self.loaders[dataset_id] = loader
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 471 |
logger.info(
|
| 472 |
"Loader ready for %s: %d entities, %d relations, %d train triples",
|
| 473 |
dataset_id, loader.num_nodes, loader.num_relations, len(loader.train_edge_data),
|
|
@@ -885,6 +889,239 @@ class ModelRegistry:
|
|
| 885 |
|
| 886 |
return _gen()
|
| 887 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 888 |
# ---- COINs inference ---------------------------------------------------
|
| 889 |
|
| 890 |
def coins_predict(self, dataset_id, algorithm, query_structure_id,
|
|
|
|
| 276 |
self._coins_experiments = {} # (dataset_id, algorithm) -> Experiment
|
| 277 |
self._coins_loaders = {} # (dataset_id, seed, leiden_resolution) -> full Loader
|
| 278 |
self._graphgen_models = {} # (dataset_id, model_type) -> loaded eval-mode model
|
| 279 |
+
self._kg_anomaly_models = {} # (dataset_id, task) -> loaded eval-mode model
|
| 280 |
|
| 281 |
def force_release_inference_lock(self):
|
| 282 |
"""Emergency release for a stuck inference lock (e.g. client disconnect)."""
|
|
|
|
| 466 |
seed=seed, device="cpu", val_size=0.01, test_size=0.02,
|
| 467 |
community_method="leiden", leiden_resolution=leiden_resolution,
|
| 468 |
)
|
|
|
|
|
|
|
| 469 |
self.loaders[dataset_id] = loader
|
| 470 |
+
# Share this loader with _load_coins_experiment so experiments for the
|
| 471 |
+
# same (dataset, seed, leiden_resolution) reuse it instead of reloading
|
| 472 |
+
# the graph. Heavy arrays stay populated — they're needed by full
|
| 473 |
+
# experiments (embedder/sampler/ranker) and by KG anomaly inference.
|
| 474 |
+
self._coins_loaders[(dataset_id, seed, leiden_resolution)] = loader
|
| 475 |
logger.info(
|
| 476 |
"Loader ready for %s: %d entities, %d relations, %d train triples",
|
| 477 |
dataset_id, loader.num_nodes, loader.num_relations, len(loader.train_edge_data),
|
|
|
|
| 889 |
|
| 890 |
return _gen()
|
| 891 |
|
| 892 |
+
# ---- KG anomaly (DiGress KG) inference --------------------------------
|
| 893 |
+
|
| 894 |
+
def _load_kg_anomaly_model(self, dataset_id, task):
    """Load the DiGress KG checkpoint for (dataset_id, task), cached.

    The KG checkpoint pickles only ``cfg`` via ``save_hyperparameters('cfg')``,
    so we must reconstruct ``dataset_infos``, ``extra_features`` and
    ``domain_features`` before constructing the model. Dims are inferred from
    state_dict shapes; kg_experiment comes from the matching COINs experiment.

    Args:
        dataset_id: dataset identifier; also the checkpoint file stem.
        task: endpoint task; ``"correct"`` selects the ``_correct`` checkpoint,
            any other value selects the unsuffixed one.

    Returns:
        The eval-mode model, moved to ``settings.TORCH_DEVICE`` and cached in
        ``self._kg_anomaly_models``.

    Raises:
        ModelUnavailable: if the checkpoint file does not exist.
        RuntimeError: if the checkpoint has no ``cfg`` hyper-parameter.
    """
    key = (dataset_id, task)
    if key in self._kg_anomaly_models:
        return self._kg_anomaly_models[key]

    import torch
    import torch.nn.parallel.distributed as _ddp_mod

    suffix = "_correct" if task == "correct" else ""
    ckpt_path = Path(settings.DIGRESS_KG_DIR) / "checkpoints" / f"{dataset_id}{suffix}.ckpt"
    if not ckpt_path.exists():
        from api.exceptions import ModelUnavailable
        raise ModelUnavailable(f"KG anomaly checkpoint not found: {ckpt_path.name}")

    logger.info("Loading KG anomaly model: dataset=%s task=%s", dataset_id, task)

    # Load to CPU with DDP patching (same strategy as _safe_load_lightning_checkpoint)
    _orig_set = _ddp_mod.DistributedDataParallel.__setstate__
    _orig_get = _ddp_mod.DistributedDataParallel.__getstate__
    _ddp_mod.DistributedDataParallel.__setstate__ = lambda self, state: self.__dict__.update(state)
    _ddp_mod.DistributedDataParallel.__getstate__ = lambda self: self.__dict__
    try:
        # NOTE(review): weights_only=False unpickles arbitrary objects; the
        # checkpoint file must come from a trusted source.
        ckpt = torch.load(str(ckpt_path), map_location="cpu", weights_only=False)
    finally:
        _ddp_mod.DistributedDataParallel.__setstate__ = _orig_set
        _ddp_mod.DistributedDataParallel.__getstate__ = _orig_get

    hparams = ckpt.get("hyper_parameters", {})
    cfg = hparams.get("cfg") if isinstance(hparams, dict) else getattr(hparams, "cfg", None)
    if cfg is None:
        raise RuntimeError(f"KG anomaly checkpoint {ckpt_path.name} is missing 'cfg' in hyper_parameters")
    state_dict = ckpt["state_dict"]

    # Ensure the model's task matches the endpoint task.
    try:
        cfg.model.task = task
    except Exception:
        # Still best-effort (e.g. OmegaConf struct mode), but no longer
        # silent: the model may end up running with the checkpoint's own task.
        logger.warning(
            "Could not set cfg.model.task=%s for %s; keeping the checkpoint value",
            task, ckpt_path.name, exc_info=True,
        )

    # Infer dims from state_dict
    edim_output = state_dict["model.mlp_out_E.2.weight"].shape[0]
    input_dim_x = state_dict["model.mlp_in_X.0.weight"].shape[1]
    input_dim_e = state_dict["model.mlp_in_E.0.weight"].shape[1]
    input_dim_y = state_dict["model.mlp_in_y.0.weight"].shape[1]

    # Load COINs experiment — needed for kg_experiment and for num_node_types
    experiment = self._load_coins_experiment(dataset_id, "transe")
    xdim_output = experiment.loader.num_node_types
    # Sanity: input_dim_e should equal edim_output (no extra E features for KG)
    if input_dim_e != edim_output:
        logger.warning(
            "Unexpected mlp_in_E dim %d != edim_output %d for %s/%s",
            input_dim_e, edim_output, dataset_id, task,
        )

    # Build mock dataset_infos
    from graph_generation.src.diffusion.distributions import DistributionNodes
    from graph_generation.src.diffusion.extra_features import (
        DummyExtraFeatures, ExtraFeatures,
    )

    # max_num_nodes from dataset name (e.g. "freebase_20" -> 20, then *2 per kg_dataset.py)
    try:
        base_max = int(cfg.dataset.name.split("_")[-1])
    except (AttributeError, ValueError):
        base_max = 20
    max_num_nodes = base_max * 2

    # Histogram for DistributionNodes — uniform over possible node counts
    n_hist = torch.ones(max_num_nodes + 1)
    n_hist[:2] = 0  # at least 2 nodes
    nodes_dist = DistributionNodes(n_hist)

    class _MockDataModule:
        def __init__(self, kg_experiment, max_num_nodes):
            self.kg_experiment = kg_experiment
            self.max_num_nodes = max_num_nodes

    class _MockDatasetInfos:
        pass

    dataset_infos = _MockDatasetInfos()
    dataset_infos.datamodule = _MockDataModule(experiment, max_num_nodes)
    dataset_infos.input_dims = {"X": input_dim_x, "E": input_dim_e, "y": input_dim_y}
    dataset_infos.output_dims = {"X": xdim_output, "E": edim_output, "y": 0}
    dataset_infos.nodes_dist = nodes_dist
    dataset_infos.max_n_nodes = max_num_nodes
    dataset_infos.node_types = torch.ones(xdim_output, dtype=torch.float32)
    dataset_infos.edge_types = torch.ones(edim_output, dtype=torch.float32)

    # extra_features per cfg
    extra_features_type = getattr(cfg.model, "extra_features", None)
    if cfg.model.type == "discrete" and extra_features_type is not None:
        extra_features = ExtraFeatures(extra_features_type, dataset_info=dataset_infos)
    else:
        extra_features = DummyExtraFeatures()
    domain_features = DummyExtraFeatures()

    from diffusion_model_discrete_kg import DiscreteDenoisingDiffusionKG as cls

    # save_hyperparameters is stubbed out only while the model is constructed.
    _orig_save = cls.save_hyperparameters
    cls.save_hyperparameters = lambda self, *a, **kw: None
    try:
        model = cls(cfg, dataset_infos, None, None, None, extra_features, domain_features)
    finally:
        cls.save_hyperparameters = _orig_save

    # strict=False tolerates key mismatches, so report them at warning level:
    # a missing key means that parameter keeps its freshly initialised value.
    missing, unexpected = model.load_state_dict(state_dict, strict=False)
    if missing:
        logger.warning("KG anomaly state_dict missing keys: %d (e.g. %s)",
                       len(missing), missing[:3])
    if unexpected:
        logger.warning("KG anomaly state_dict unexpected keys: %d (e.g. %s)",
                       len(unexpected), unexpected[:3])

    del ckpt
    model.to(settings.TORCH_DEVICE)
    model.eval()
    self._kg_anomaly_models[key] = model
    logger.info("KG anomaly model ready: dataset=%s task=%s", dataset_id, task)
    return model
|
| 1022 |
+
|
| 1023 |
+
def kg_anomaly_correct_stream(self, dataset_id, task, sampling_mode, subgraph,
                              diffusion_steps, chain_frames, multiprox_params):
    """Build the SSE event generator behind /kg-anomaly/correct.

    The inference lock is taken without blocking (``InferenceBusy`` is raised
    when it is already held). It is released either here, if model loading or
    tensor building fails, or by the returned generator's ``finally`` clause.
    """
    from api.exceptions import InferenceBusy
    from api.services.kg_anomaly_inference import (
        build_kg_tensors, encode_state_blob,
        run_multiprox_correction_init, run_standard_correction,
    )

    got_lock = self._inference_lock.acquire(blocking=False)
    if not got_lock:
        raise InferenceBusy()
    self._inference_lock_owner = f"kg_anomaly_correct {dataset_id}/{task}/{sampling_mode}"
    try:
        model = self._load_kg_anomaly_model(dataset_id, task)
        loader = self.loaders.get(dataset_id)
        tensors = build_kg_tensors(subgraph, loader, model)
    except Exception:
        self._inference_lock_owner = None
        self._inference_lock.release()
        raise

    def _events():
        # Fields stamped onto every terminal "result" event, in this order.
        request_tags = {
            "dataset_id": dataset_id,
            "task": task,
            "sampling_mode": sampling_mode,
        }
        try:
            if sampling_mode == "standard":
                source = run_standard_correction(
                    model, tensors, dataset_id, task, loader,
                    diffusion_steps, chain_frames)
                for evt in source:
                    if evt["type"] == "result":
                        evt.update(request_tags)
                    yield evt
            else:
                source = run_multiprox_correction_init(
                    model, tensors, dataset_id, task, loader,
                    multiprox_params["n"],
                    multiprox_params["m"],
                    multiprox_params["t"],
                    multiprox_params["t_prime"],
                    multiprox_params["gibbs_chain_freq"])
                for evt in source:
                    if evt["type"] == "result":
                        # Swap the raw state dict for its encoded blob,
                        # re-inserted as the last key of the event.
                        blob = encode_state_blob(evt.pop("state"))
                        evt.update(request_tags)
                        evt.update({
                            "step": 0,
                            "round_complete": False,
                            "done": False,
                            "state": blob,
                        })
                    yield evt
        finally:
            self._inference_lock_owner = None
            self._inference_lock.release()

    return _events()
|
| 1082 |
+
|
| 1083 |
+
def kg_anomaly_continue_stream(self, state_b64):
    """Return a generator of SSE event dicts for /kg-anomaly/continue.

    Args:
        state_b64: base64 state blob returned by a previous result event.

    Returns:
        A generator yielding the events of one MultiProx step; on the
        terminal ``"result"`` event the raw state dict is replaced by its
        encoded blob and ``dataset_id`` / ``task`` / ``step`` are added.

    Raises:
        InvalidRequestError: if the blob cannot be decoded or validated.
        InferenceBusy: if the inference lock is already held.
    """
    from api.exceptions import InferenceBusy, InvalidRequestError
    from api.services.kg_anomaly_inference import (
        decode_state_blob, encode_state_blob, run_multiprox_correction_step,
    )
    try:
        state = decode_state_blob(state_b64)
    except ValueError as exc:
        # Chain the cause so the decoding failure stays visible in tracebacks.
        raise InvalidRequestError(str(exc)) from exc

    if not self._inference_lock.acquire(blocking=False):
        raise InferenceBusy()
    self._inference_lock_owner = (
        f"kg_anomaly_continue {state['dataset_id']}/{state['task']}"
    )
    try:
        model = self._load_kg_anomaly_model(state["dataset_id"], state["task"])
        loader = self.loaders.get(state["dataset_id"])
    except Exception:
        self._inference_lock_owner = None
        self._inference_lock.release()
        raise

    def _gen():
        # NOTE: the lock is released by this finally clause, which only runs
        # once the generator has been started by the caller.
        try:
            for event in run_multiprox_correction_step(model, state, loader):
                if event["type"] == "result":
                    updated_state = event.pop("state")
                    event.update({
                        "dataset_id": updated_state["dataset_id"],
                        "task": updated_state["task"],
                        "step": updated_state["step"],
                        "state": encode_state_blob(updated_state),
                    })
                yield event
        finally:
            self._inference_lock_owner = None
            self._inference_lock.release()

    return _gen()
|
| 1124 |
+
|
| 1125 |
# ---- COINs inference ---------------------------------------------------
|
| 1126 |
|
| 1127 |
def coins_predict(self, dataset_id, algorithm, query_structure_id,
|
|
@@ -13,7 +13,12 @@ from api.views.graph_generation import (
|
|
| 13 |
GraphGenContinueView, GraphGenDatasetsView, GraphGenGenerateView, GraphGenSamplingModesView,
|
| 14 |
)
|
| 15 |
from api.views.health import ApiRootView, ForceUnlockView, HealthView, MethodsView
|
| 16 |
-
from api.views.kg_anomaly import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
urlpatterns = [
|
| 19 |
# Health & discovery
|
|
@@ -37,4 +42,6 @@ urlpatterns = [
|
|
| 37 |
# KG anomaly
|
| 38 |
path("kg-anomaly/datasets", KgAnomalyDatasetsView.as_view()),
|
| 39 |
path("kg-anomaly/datasets/<str:dataset_id>/sample-subgraphs", KgAnomalySampleSubgraphsView.as_view()),
|
|
|
|
|
|
|
| 40 |
]
|
|
|
|
| 13 |
GraphGenContinueView, GraphGenDatasetsView, GraphGenGenerateView, GraphGenSamplingModesView,
|
| 14 |
)
|
| 15 |
from api.views.health import ApiRootView, ForceUnlockView, HealthView, MethodsView
|
| 16 |
+
from api.views.kg_anomaly import (
|
| 17 |
+
KgAnomalyContinueView,
|
| 18 |
+
KgAnomalyCorrectView,
|
| 19 |
+
KgAnomalyDatasetsView,
|
| 20 |
+
KgAnomalySampleSubgraphsView,
|
| 21 |
+
)
|
| 22 |
|
| 23 |
urlpatterns = [
|
| 24 |
# Health & discovery
|
|
|
|
| 42 |
# KG anomaly
|
| 43 |
path("kg-anomaly/datasets", KgAnomalyDatasetsView.as_view()),
|
| 44 |
path("kg-anomaly/datasets/<str:dataset_id>/sample-subgraphs", KgAnomalySampleSubgraphsView.as_view()),
|
| 45 |
+
path("kg-anomaly/correct", KgAnomalyCorrectView.as_view()),
|
| 46 |
+
path("kg-anomaly/continue", KgAnomalyContinueView.as_view()),
|
| 47 |
]
|
|
@@ -1,9 +1,11 @@
|
|
| 1 |
from rest_framework.response import Response
|
| 2 |
from rest_framework.views import APIView
|
| 3 |
|
| 4 |
-
from api.exceptions import NotFoundError
|
| 5 |
from api.services.constants import KG_ANOMALY_DATASET_META
|
|
|
|
| 6 |
from api.services.registry import ModelRegistry
|
|
|
|
| 7 |
|
| 8 |
|
| 9 |
class KgAnomalyDatasetsView(APIView):
|
|
@@ -34,9 +36,139 @@ class KgAnomalySampleSubgraphsView(APIView):
|
|
| 34 |
count = int(request.query_params.get("count", 5))
|
| 35 |
count = max(1, min(10, count))
|
| 36 |
|
| 37 |
-
subgraphs = sg_info.subgraphs[:count]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 38 |
|
| 39 |
return Response({
|
| 40 |
"dataset_id": dataset_id,
|
| 41 |
"subgraphs": subgraphs,
|
| 42 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from rest_framework.response import Response
|
| 2 |
from rest_framework.views import APIView
|
| 3 |
|
| 4 |
+
from api.exceptions import InvalidRequestError, ModelUnavailable, NotFoundError
|
| 5 |
from api.services.constants import KG_ANOMALY_DATASET_META
|
| 6 |
+
from api.services.kg_anomaly_inference import apply_edge_noise, build_kg_tensors
|
| 7 |
from api.services.registry import ModelRegistry
|
| 8 |
+
from api.views.graph_generation import _streaming_sse_response
|
| 9 |
|
| 10 |
|
| 11 |
class KgAnomalyDatasetsView(APIView):
|
|
|
|
| 36 |
count = int(request.query_params.get("count", 5))
|
| 37 |
count = max(1, min(10, count))
|
| 38 |
|
| 39 |
+
subgraphs = [dict(sg) for sg in sg_info.subgraphs[:count]]
|
| 40 |
+
|
| 41 |
+
noise_level_raw = request.query_params.get("noise_level")
|
| 42 |
+
if noise_level_raw is not None:
|
| 43 |
+
try:
|
| 44 |
+
noise_level = float(noise_level_raw)
|
| 45 |
+
except ValueError:
|
| 46 |
+
raise InvalidRequestError("'noise_level' must be a float in (0, 1]")
|
| 47 |
+
if not (0.0 < noise_level <= 1.0):
|
| 48 |
+
raise InvalidRequestError("'noise_level' must be in (0, 1]")
|
| 49 |
+
|
| 50 |
+
task = request.query_params.get("task", "correct")
|
| 51 |
+
if task not in ("correct", "generate"):
|
| 52 |
+
raise InvalidRequestError("'task' must be 'correct' or 'generate'")
|
| 53 |
+
|
| 54 |
+
available = registry.kg_anomaly_checkpoints_available.get(dataset_id, [])
|
| 55 |
+
if task not in available:
|
| 56 |
+
raise ModelUnavailable(
|
| 57 |
+
f"No '{task}' checkpoint available for dataset '{dataset_id}'")
|
| 58 |
+
|
| 59 |
+
seed_raw = request.query_params.get("seed")
|
| 60 |
+
seed = int(seed_raw) if seed_raw is not None else None
|
| 61 |
+
|
| 62 |
+
loader = registry.loaders[dataset_id]
|
| 63 |
+
model = registry._load_kg_anomaly_model(dataset_id, task)
|
| 64 |
+
|
| 65 |
+
for i, sg in enumerate(subgraphs):
|
| 66 |
+
offset_seed = None if seed is None else seed + i
|
| 67 |
+
tensors = build_kg_tensors(sg, loader, model)
|
| 68 |
+
sg["edges"] = apply_edge_noise(model, tensors, task, noise_level, offset_seed)
|
| 69 |
|
| 70 |
return Response({
|
| 71 |
"dataset_id": dataset_id,
|
| 72 |
"subgraphs": subgraphs,
|
| 73 |
})
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def _validate_subgraph(subgraph):
    """Validate the request's ``subgraph`` payload, raising on the first problem.

    Checks that ``subgraph`` is a dict with a ``nodes`` list of 2 to 20 dicts
    (each carrying ``entity_id``) and an ``edges`` list of dicts, each with
    ``source_idx``, ``target_idx`` and ``relation_id``, whose endpoints are
    integer indices into ``nodes`` and are not equal to each other.

    Raises:
        InvalidRequestError: describing the first violation found.
    """
    if not isinstance(subgraph, dict):
        raise InvalidRequestError("'subgraph' must be an object with 'nodes' and 'edges'")
    nodes = subgraph.get("nodes")
    edges = subgraph.get("edges")
    if not isinstance(nodes, list) or not (2 <= len(nodes) <= 20):
        raise InvalidRequestError("'subgraph.nodes' must be a list of 2 to 20 items")
    if not isinstance(edges, list):
        raise InvalidRequestError("'subgraph.edges' must be a list")
    n = len(nodes)
    for i, node in enumerate(nodes):
        if not isinstance(node, dict) or "entity_id" not in node:
            raise InvalidRequestError(f"subgraph.nodes[{i}] must have 'entity_id'")
    for i, e in enumerate(edges):
        if not isinstance(e, dict):
            raise InvalidRequestError(f"subgraph.edges[{i}] must be an object")
        for field in ("source_idx", "target_idx", "relation_id"):
            if field not in e:
                raise InvalidRequestError(f"subgraph.edges[{i}] missing '{field}'")
        try:
            source_idx = int(e["source_idx"])
            target_idx = int(e["target_idx"])
        except (TypeError, ValueError, OverflowError):
            # Non-numeric indices are a client error, not a server crash.
            raise InvalidRequestError(
                f"subgraph.edges[{i}] has a non-integer node index") from None
        if not (0 <= source_idx < n and 0 <= target_idx < n):
            raise InvalidRequestError(f"subgraph.edges[{i}] has out-of-range node index")
        if source_idx == target_idx:
            raise InvalidRequestError(f"subgraph.edges[{i}] is a self-loop (not allowed)")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
class KgAnomalyCorrectView(APIView):
    """POST /kg-anomaly/correct: validate the request and stream correction events."""

    @staticmethod
    def _coerce(source, key, default, cast, label):
        """Return ``cast(source.get(key, default))`` or raise InvalidRequestError.

        Keeps malformed numeric fields (strings, null, lists, infinities) from
        escaping as an unhandled ValueError/TypeError/OverflowError.
        """
        try:
            return cast(source.get(key, default))
        except (TypeError, ValueError, OverflowError):
            expected = "an integer" if cast is int else "a number"
            raise InvalidRequestError(f"{label} must be {expected}") from None

    def post(self, request):
        """Validate the body and return the SSE streaming response.

        Raises:
            InvalidRequestError: for an unknown dataset, task or sampling
                mode, a malformed subgraph, or invalid numeric parameters.
            ModelUnavailable: if no checkpoint exists for (dataset, task).
        """
        data = request.data
        registry = ModelRegistry.get()

        dataset_id = data.get("dataset_id")
        if dataset_id not in KG_ANOMALY_DATASET_META:
            raise InvalidRequestError(
                f"Unknown dataset_id '{dataset_id}'. Valid: {list(KG_ANOMALY_DATASET_META)}")

        task = data.get("task", "correct")
        if task not in ("correct", "generate"):
            raise InvalidRequestError("task must be 'correct' or 'generate'")

        available = registry.kg_anomaly_checkpoints_available.get(dataset_id, [])
        if task not in available:
            raise ModelUnavailable(
                f"No '{task}' checkpoint available for dataset '{dataset_id}'")

        sampling_mode = data.get("sampling_mode")
        if sampling_mode not in ("standard", "multiprox"):
            raise InvalidRequestError("sampling_mode must be 'standard' or 'multiprox'")

        subgraph = data.get("subgraph")
        _validate_subgraph(subgraph)

        if sampling_mode == "standard":
            # Out-of-range values are clamped rather than rejected.
            diffusion_steps = min(max(
                self._coerce(data, "diffusion_steps", 500, int, "diffusion_steps"), 50), 1000)
            chain_frames = min(max(
                self._coerce(data, "chain_frames", 20, int, "chain_frames"), 10), 30)
            gen = registry.kg_anomaly_correct_stream(
                dataset_id, task, sampling_mode, subgraph,
                diffusion_steps, chain_frames, None)
        else:
            mp = data.get("multiprox_params")
            if not mp or not isinstance(mp, dict):
                raise InvalidRequestError("multiprox_params is required for multiprox sampling_mode")

            m = self._coerce(mp, "m", 100, int, "multiprox_params.m")
            if not (2 <= m <= 100):
                raise InvalidRequestError("multiprox_params.m must be in [2, 100]")

            n = self._coerce(mp, "n", 10, int, "multiprox_params.n")
            if n < 1:
                raise InvalidRequestError("multiprox_params.n must be >= 1")

            t = self._coerce(mp, "t", 0.5, float, "multiprox_params.t")
            t_prime = self._coerce(mp, "t_prime", 0.1, float, "multiprox_params.t_prime")
            if not (0 < t_prime <= t <= 1):
                raise InvalidRequestError(
                    "multiprox_params must satisfy 0 < t_prime <= t <= 1")

            # The default depends on m, so it is computed after m is validated.
            gibbs_chain_freq = self._coerce(
                mp, "gibbs_chain_freq", max(1, m // 10), int,
                "multiprox_params.gibbs_chain_freq")
            if not (1 <= gibbs_chain_freq <= m):
                raise InvalidRequestError(
                    f"multiprox_params.gibbs_chain_freq must be in [1, {m}]")

            multiprox_params = {
                "n": n, "m": m, "t": t, "t_prime": t_prime,
                "gibbs_chain_freq": gibbs_chain_freq,
            }
            gen = registry.kg_anomaly_correct_stream(
                dataset_id, task, sampling_mode, subgraph,
                None, None, multiprox_params)

        return _streaming_sse_response(gen)
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class KgAnomalyContinueView(APIView):
    """POST /kg-anomaly/continue: stream the events of the next MultiProx step."""

    def post(self, request):
        """Require a non-empty string ``state`` and hand it to the registry."""
        blob = request.data.get("state")
        if not (blob and isinstance(blob, str)):
            raise InvalidRequestError("'state' is required and must be a non-empty string")
        registry = ModelRegistry.get()
        events = registry.kg_anomaly_continue_stream(blob)
        return _streaming_sse_response(events)
|
|
@@ -7,9 +7,10 @@ PROJECT_ROOT = BASE_DIR.parent.parent # Website root
|
|
| 7 |
|
| 8 |
# Add research repos to sys.path so their modules can be imported
|
| 9 |
_COINS_KG_ROOT = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration")
|
|
|
|
| 10 |
_MULTIPROXAN_ROOT = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn")
|
| 11 |
_MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
|
| 12 |
-
for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
|
| 13 |
if _path not in sys.path:
|
| 14 |
sys.path.insert(0, _path)
|
| 15 |
|
|
|
|
| 7 |
|
| 8 |
# Add research repos to sys.path so their modules can be imported
|
| 9 |
_COINS_KG_ROOT = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration")
|
| 10 |
+
_DIGRESS_KG_SRC = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration" / "graph_generation" / "src")
|
| 11 |
_MULTIPROXAN_ROOT = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn")
|
| 12 |
_MULTIPROXAN_SRC = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn" / "src")
|
| 13 |
+
for _path in (_COINS_KG_ROOT, _DIGRESS_KG_SRC, _MULTIPROXAN_ROOT, _MULTIPROXAN_SRC):
|
| 14 |
if _path not in sys.path:
|
| 15 |
sys.path.insert(0, _path)
|
| 16 |
|