Andrej Janchevski commited on
Commit ·
4f1e196
1
Parent(s): 0b4eacc
Implement backend discovery endpoints
Browse files- .claude/plans/backend_init.md +58 -22
- CLAUDE.md +6 -0
- docs/postman/collection.json +16 -16
- src/backend/README.md +127 -0
- src/backend/api/apps.py +10 -0
- src/backend/api/pagination.py +8 -0
- src/backend/api/services/constants.py +297 -0
- src/backend/api/services/registry.py +304 -2
- src/backend/api/urls.py +26 -1
- src/backend/api/views/coins.py +106 -0
- src/backend/api/views/graph_generation.py +29 -0
- src/backend/api/views/health.py +6 -1
- src/backend/api/views/kg_anomaly.py +42 -0
- src/backend/requirements.txt +36 -0
- src/backend/research_api/settings.py +8 -0
- src/research/COINs-KGGeneration/graph_completion/graphs/load_graph.py +5 -34
- src/research/COINs-KGGeneration/graph_completion/graphs/preprocess.py +13 -1
- src/research/COINs-KGGeneration/graph_data/load_data.py +3 -112
.claude/plans/backend_init.md
CHANGED
|
@@ -10,6 +10,7 @@ Build the Django backend foundation: project scaffolding, dataset/entity/relatio
|
|
| 10 |
src/backend/
|
| 11 |
manage.py
|
| 12 |
requirements.txt
|
|
|
|
| 13 |
research_api/
|
| 14 |
__init__.py
|
| 15 |
settings.py
|
|
@@ -51,17 +52,38 @@ The research `Loader` needs: `igraph`, `pandas`, `numpy`, `torch` (for device ha
|
|
| 51 |
- Checks `_all_checkpoint_dirs_populated()` first — skips download if all dirs have files
|
| 52 |
- Uses `gdown.download_folder()` to a `.checkpoint_staging` dir, then `_distribute_checkpoints()` moves files to correct locations
|
| 53 |
- Gracefully handles failures (logs warning, continues with local files)
|
| 54 |
-
2. **
|
| 55 |
-
3. **Scans checkpoint directories** for `.tar` / `.ckpt` files to determine availability flags (does NOT load model weights)
|
| 56 |
- COINs: parses `{dataset}_{algorithm}.tar` filenames
|
| 57 |
- MultiProxAn: `{dataset}.ckpt` = discrete, `{dataset}_c.ckpt` = continuous
|
| 58 |
- DiGress KG: `{dataset}.ckpt` = generate, `{dataset}_correct.ckpt` = correct
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
### Checkpoint Download (on each boot)
|
| 67 |
|
|
@@ -87,6 +109,15 @@ Download logic in `registry.py`:
|
|
| 87 |
- MultiProxAn (graph generation): glob `{MULTIPROXAN_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_c.ckpt`
|
| 88 |
- DiGress KG (anomaly correction): glob `{DIGRESS_KG_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_correct.ckpt`
|
| 89 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 90 |
## Django Settings
|
| 91 |
|
| 92 |
- `DATABASES = {}` — no database
|
|
@@ -205,7 +236,7 @@ cd src/analysis/orca && g++ -O2 -std=c++11 -o orca orca.cpp
|
|
| 205 |
|
| 206 |
### Backend requirements.txt
|
| 207 |
|
| 208 |
-
The backend unifies both into one Python 3.9 environment. Since the COINs `Loader` is the only part reused for init, we pin to the DiGress/MultiProxAn versions and ensure COINs code runs under them.
|
| 209 |
|
| 210 |
```
|
| 211 |
# Django
|
|
@@ -216,11 +247,11 @@ django-cors-headers==4.*
|
|
| 216 |
# Checkpoint download from Google Drive
|
| 217 |
gdown>=4.7
|
| 218 |
|
| 219 |
-
# PyTorch (CPU
|
| 220 |
-
-
|
| 221 |
-
torch==2.0.1+
|
| 222 |
-
torchvision==0.15.2+
|
| 223 |
-
torchaudio==2.0.2+
|
| 224 |
|
| 225 |
# PyTorch Geometric (must match torch version)
|
| 226 |
torch-geometric==2.3.1
|
|
@@ -239,6 +270,7 @@ matplotlib==3.7.1
|
|
| 239 |
imageio==2.31.1
|
| 240 |
torchmetrics==0.11.4
|
| 241 |
tqdm==4.65.0
|
|
|
|
| 242 |
|
| 243 |
# Conda-only deps (must be pre-installed, not in pip requirements):
|
| 244 |
# rdkit==2023.03.2 (conda create -c conda-forge)
|
|
@@ -263,25 +295,29 @@ tqdm==4.65.0
|
|
| 263 |
12. Hook `ModelRegistry.initialize()` in `AppConfig.ready()`
|
| 264 |
13. Add research `Loader` integration to registry for dataset loading + subgraph generation
|
| 265 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 266 |
## Verification
|
| 267 |
|
| 268 |
1. `pip install -r requirements.txt` — all dependencies install cleanly into `website_c` env
|
| 269 |
2. `python manage.py check` — Django system check passes, 0 issues
|
| 270 |
-
3. `python manage.py runserver` — registry downloads checkpoints
|
| 271 |
-
4. `GET /api/v1/` — returns API name, version, description, full endpoint directory
|
| 272 |
5. `GET /api/v1/health` — returns `{"status": "ok", "models_loaded": {"coins": bool, "multiproxan": bool, "kg_anomaly": bool}}`
|
| 273 |
6. `GET /api/v1/methods` — returns 3 methods
|
| 274 |
7. `GET /api/v1/coins/datasets` — returns 3 datasets with correct entity/relation counts
|
| 275 |
-
8. `GET /api/v1/coins/datasets/
|
| 276 |
-
9. `GET /api/v1/coins/datasets/
|
| 277 |
-
10. `GET /api/v1/coins/datasets/
|
| 278 |
11. `GET /api/v1/coins/datasets/unknown/entities` — `{"error": {"code": "NOT_FOUND", ...}}`
|
| 279 |
12. `GET /api/v1/coins/models` — 6 algorithms, filtered by actually available checkpoints
|
| 280 |
13. `GET /api/v1/coins/query-structures` — 7 templates (1p, 2p, 3p, 2i, 3i, ip, pi) matching API spec
|
| 281 |
-
14. `GET /api/v1/graph-generation/datasets` — QM9 + Community20 with model type availability
|
| 282 |
15. `GET /api/v1/graph-generation/sampling-modes` — standard + multiprox with parameter specs
|
| 283 |
-
16. `GET /api/v1/kg-anomaly/datasets` — 3 datasets with availability from
|
| 284 |
-
17. `GET /api/v1/kg-anomaly/datasets/
|
| 285 |
18. Import `docs/postman/collection.json` into Postman, set environment, run all GET requests against running server
|
| 286 |
|
| 287 |
## Critical Source Files
|
|
|
|
| 10 |
src/backend/
|
| 11 |
manage.py
|
| 12 |
requirements.txt
|
| 13 |
+
README.md
|
| 14 |
research_api/
|
| 15 |
__init__.py
|
| 16 |
settings.py
|
|
|
|
| 52 |
- Checks `_all_checkpoint_dirs_populated()` first — skips download if all dirs have files
|
| 53 |
- Uses `gdown.download_folder()` to a `.checkpoint_staging` dir, then `_distribute_checkpoints()` moves files to correct locations
|
| 54 |
- Gracefully handles failures (logs warning, continues with local files)
|
| 55 |
+
2. **Scans checkpoint directories** for `.tar` / `.ckpt` files to determine availability flags (does NOT load model weights)
|
|
|
|
| 56 |
- COINs: parses `{dataset}_{algorithm}.tar` filenames
|
| 57 |
- MultiProxAn: `{dataset}.ckpt` = discrete, `{dataset}_c.ckpt` = continuous
|
| 58 |
- DiGress KG: `{dataset}.ckpt` = generate, `{dataset}_correct.ckpt` = correct
|
| 59 |
+
3. **Loads one lightweight COINs Loader per dataset** using vanilla seeds, then frees heavy arrays
|
| 60 |
+
4. **Generates sample subgraphs** for KG anomaly using the Loader's DFS-based context subgraph partitioning, capped via `max_samples` to avoid O(subgraphs²) memory
|
| 61 |
+
|
| 62 |
+
### Memory Management
|
| 63 |
+
|
| 64 |
+
Each Loader's `load_graph()` produces heavy arrays not needed for discovery endpoints:
|
| 65 |
+
- `node_neighbours`: shape `(num_nodes, num_relations, num_neighbours)` — ~275MB for freebase alone
|
| 66 |
+
- `com_neighbours`, `node_adjacency`, `com_adjacency`: large dicts/arrays
|
| 67 |
+
- `graph` (igraph object), degree arrays, importance arrays, machine assignments
|
| 68 |
+
|
| 69 |
+
After init, `_free_heavy_arrays()` sets all these to `None`. Only dataset name maps, train edge data, graph indexes, and community metadata are kept.
|
| 70 |
+
|
| 71 |
+
The DFS subgraph partitioning (`get_context_subgraph_samples_dfs`) allocates an O(subgraphs²) edge matrix. A `max_samples` parameter caps partitioning to only process enough nodes to produce the needed samples. The sampler's `context_subgraphs_nodes`/`context_subgraphs_edges` are also freed after sample generation.
|
| 72 |
+
|
| 73 |
+
Django's dev server auto-reloader spawns two processes both calling `ready()`. The outer (file watcher) process is skipped via `RUN_MAIN` env var check in `apps.py` to avoid double memory usage.
|
| 74 |
+
|
| 75 |
+
### Per-Algorithm Checkpoint Configs
|
| 76 |
+
|
| 77 |
+
Different algorithms were trained with different random seeds, producing different Leiden community structures:
|
| 78 |
+
- **Vanilla group** (transe/distmult/complex/rotate): shared seed per dataset
|
| 79 |
+
- **Q2B, KBGAT, nell_complex**: individual seeds from experiment runs
|
| 80 |
+
|
| 81 |
+
Community structure groups:
|
| 82 |
+
- freebase: transe/distmult/complex/rotate (1092 com) | q2b (1030) | kbgat (1025)
|
| 83 |
+
- wordnet: transe/distmult/complex/rotate (66 com) | q2b (74) | kbgat (88)
|
| 84 |
+
- nell: transe/distmult/rotate (282 com) | complex (188) | q2b (282, diff sizes) | kbgat (275)
|
| 85 |
+
|
| 86 |
+
`_CHECKPOINT_SEEDS` maps `(dataset, algorithm)` → training seed. `get_checkpoint_config(dataset_id, algorithm)` returns the full config. At startup only one Loader per dataset (vanilla seed) loads for discovery. Per-algorithm Loaders for inference load on demand.
|
| 87 |
|
| 88 |
### Checkpoint Download (on each boot)
|
| 89 |
|
|
|
|
| 109 |
- MultiProxAn (graph generation): glob `{MULTIPROXAN_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_c.ckpt`
|
| 110 |
- DiGress KG (anomaly correction): glob `{DIGRESS_KG_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_correct.ckpt`
|
| 111 |
|
| 112 |
+
### Sample Subgraph Generation
|
| 113 |
+
|
| 114 |
+
For KG anomaly `sample-subgraphs` endpoint:
|
| 115 |
+
- `Sampler.get_context_subgraph_samples_dfs()` in `graph_completion/graphs/preprocess.py` — DFS-based partitioning of the graph into sized subgraphs
|
| 116 |
+
- Returns `List[ContextSubgraph]` tuples of `(subgraph_row, subgraph_col, nodes_row, nodes_col, edges)`
|
| 117 |
+
- `max_samples` parameter caps how many subgraphs to partition (avoids full-graph traversal)
|
| 118 |
+
- Called once per dataset at startup, stored as `SubgraphInfo`, converted to API response format with entity/relation names resolved from the Loader's name maps
|
| 119 |
+
- Sampler's internal partitioning data freed after generation
|
| 120 |
+
|
| 121 |
## Django Settings
|
| 122 |
|
| 123 |
- `DATABASES = {}` — no database
|
|
|
|
| 236 |
|
| 237 |
### Backend requirements.txt
|
| 238 |
|
| 239 |
+
The backend unifies both into one Python 3.9 environment. Since the COINs `Loader` is the only part reused for init, we pin to the DiGress/MultiProxAn versions and ensure COINs code runs under them. PyTorch CUDA 11.8 for local dev, falls back to CPU at runtime on PythonAnywhere.
|
| 240 |
|
| 241 |
```
|
| 242 |
# Django
|
|
|
|
| 247 |
# Checkpoint download from Google Drive
|
| 248 |
gdown>=4.7
|
| 249 |
|
| 250 |
+
# PyTorch with CUDA 11.8 (falls back to CPU at runtime if no GPU present)
|
| 251 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
| 252 |
+
torch==2.0.1+cu118
|
| 253 |
+
torchvision==0.15.2+cu118
|
| 254 |
+
torchaudio==2.0.2+cu118
|
| 255 |
|
| 256 |
# PyTorch Geometric (must match torch version)
|
| 257 |
torch-geometric==2.3.1
|
|
|
|
| 270 |
imageio==2.31.1
|
| 271 |
torchmetrics==0.11.4
|
| 272 |
tqdm==4.65.0
|
| 273 |
+
scikit-learn>=1.0
|
| 274 |
|
| 275 |
# Conda-only deps (must be pre-installed, not in pip requirements):
|
| 276 |
# rdkit==2023.03.2 (conda create -c conda-forge)
|
|
|
|
| 295 |
12. Hook `ModelRegistry.initialize()` in `AppConfig.ready()`
|
| 296 |
13. Add research `Loader` integration to registry for dataset loading + subgraph generation
|
| 297 |
|
| 298 |
+
## Known Issues
|
| 299 |
+
|
| 300 |
+
- **Leiden non-reproducibility**: igraph 0.9.11's `community_leiden()` has no seed parameter. Re-running produces different community structures than the original training runs. For inference, the original cached `subgraphing.gz` files from the training server may be needed, or a matching igraph version that reproduces the results.
|
| 301 |
+
|
| 302 |
## Verification
|
| 303 |
|
| 304 |
1. `pip install -r requirements.txt` — all dependencies install cleanly into `website_c` env
|
| 305 |
2. `python manage.py check` — Django system check passes, 0 issues
|
| 306 |
+
3. `python manage.py runserver` — registry downloads checkpoints (or skips), scans dirs, logs entity/relation counts
|
| 307 |
+
4. `GET /api/v1/` — returns API name, version, description, full endpoint directory
|
| 308 |
5. `GET /api/v1/health` — returns `{"status": "ok", "models_loaded": {"coins": bool, "multiproxan": bool, "kg_anomaly": bool}}`
|
| 309 |
6. `GET /api/v1/methods` — returns 3 methods
|
| 310 |
7. `GET /api/v1/coins/datasets` — returns 3 datasets with correct entity/relation counts
|
| 311 |
+
8. `GET /api/v1/coins/datasets/wordnet/entities?q=dog&page=1&page_size=10` — filtered, paginated results
|
| 312 |
+
9. `GET /api/v1/coins/datasets/wordnet/relations?q=hyper` — substring search on relation names
|
| 313 |
+
10. `GET /api/v1/coins/datasets/wordnet/sample-triples?count=5` — 5 random triples with resolved names
|
| 314 |
11. `GET /api/v1/coins/datasets/unknown/entities` — `{"error": {"code": "NOT_FOUND", ...}}`
|
| 315 |
12. `GET /api/v1/coins/models` — 6 algorithms, filtered by actually available checkpoints
|
| 316 |
13. `GET /api/v1/coins/query-structures` — 7 templates (1p, 2p, 3p, 2i, 3i, ip, pi) matching API spec
|
| 317 |
+
14. `GET /api/v1/graph-generation/datasets` — QM9 + Community20 with model type availability
|
| 318 |
15. `GET /api/v1/graph-generation/sampling-modes` — standard + multiprox with parameter specs
|
| 319 |
+
16. `GET /api/v1/kg-anomaly/datasets` — 3 datasets with availability from checkpoint scan
|
| 320 |
+
17. `GET /api/v1/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3` — 3 pre-computed subgraphs with entity/relation names
|
| 321 |
18. Import `docs/postman/collection.json` into Postman, set environment, run all GET requests against running server
|
| 322 |
|
| 323 |
## Critical Source Files
|
CLAUDE.md
CHANGED
|
@@ -35,3 +35,9 @@ PythonAnywhere and hosted by them. API requests only from the front-end are allo
|
|
| 35 |
system boot
|
| 36 |
- `src/backend`: Django backend and endpoint files location
|
| 37 |
- `src/frontend`: Vue.js and Semantic UI frontend files location
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
system boot
|
| 36 |
- `src/backend`: Django backend and endpoint files location
|
| 37 |
- `src/frontend`: Vue.js and Semantic UI frontend files location
|
| 38 |
+
|
| 39 |
+
## Backend Documentation
|
| 40 |
+
|
| 41 |
+
When making any backend change (new endpoint, config change, dependency, startup behavior), update `src/backend/README.md`
|
| 42 |
+
to reflect the change. Keep the endpoint table, project structure, and startup sequence sections current.
|
| 43 |
+
|
docs/postman/collection.json
CHANGED
|
@@ -77,13 +77,13 @@
|
|
| 77 |
"method": "GET",
|
| 78 |
"header": [],
|
| 79 |
"url": {
|
| 80 |
-
"raw": "{{base_url}}/coins/datasets/
|
| 81 |
"host": ["{{base_url}}"],
|
| 82 |
-
"path": ["coins", "datasets", "
|
| 83 |
"query": [
|
| 84 |
{ "key": "page", "value": "1" },
|
| 85 |
{ "key": "page_size", "value": "10" },
|
| 86 |
-
{ "key": "q", "value": "
|
| 87 |
]
|
| 88 |
},
|
| 89 |
"description": "Paginated, searchable entity list."
|
|
@@ -95,13 +95,13 @@
|
|
| 95 |
"method": "GET",
|
| 96 |
"header": [],
|
| 97 |
"url": {
|
| 98 |
-
"raw": "{{base_url}}/coins/datasets/
|
| 99 |
"host": ["{{base_url}}"],
|
| 100 |
-
"path": ["coins", "datasets", "
|
| 101 |
"query": [
|
| 102 |
{ "key": "page", "value": "1" },
|
| 103 |
{ "key": "page_size", "value": "10" },
|
| 104 |
-
{ "key": "q", "value": "
|
| 105 |
]
|
| 106 |
},
|
| 107 |
"description": "Paginated, searchable relation list."
|
|
@@ -113,9 +113,9 @@
|
|
| 113 |
"method": "GET",
|
| 114 |
"header": [],
|
| 115 |
"url": {
|
| 116 |
-
"raw": "{{base_url}}/coins/datasets/
|
| 117 |
"host": ["{{base_url}}"],
|
| 118 |
-
"path": ["coins", "datasets", "
|
| 119 |
"query": [
|
| 120 |
{ "key": "count", "value": "5" }
|
| 121 |
]
|
|
@@ -163,7 +163,7 @@
|
|
| 163 |
],
|
| 164 |
"body": {
|
| 165 |
"mode": "raw",
|
| 166 |
-
"raw": "{\n \"dataset_id\": \"
|
| 167 |
},
|
| 168 |
"url": {
|
| 169 |
"raw": "{{base_url}}/coins/predict",
|
|
@@ -182,7 +182,7 @@
|
|
| 182 |
],
|
| 183 |
"body": {
|
| 184 |
"mode": "raw",
|
| 185 |
-
"raw": "{\n \"dataset_id\": \"
|
| 186 |
},
|
| 187 |
"url": {
|
| 188 |
"raw": "{{base_url}}/coins/predict",
|
|
@@ -201,7 +201,7 @@
|
|
| 201 |
],
|
| 202 |
"body": {
|
| 203 |
"mode": "raw",
|
| 204 |
-
"raw": "{\n \"dataset_id\": \"
|
| 205 |
},
|
| 206 |
"url": {
|
| 207 |
"raw": "{{base_url}}/coins/predict",
|
|
@@ -328,9 +328,9 @@
|
|
| 328 |
"method": "GET",
|
| 329 |
"header": [],
|
| 330 |
"url": {
|
| 331 |
-
"raw": "{{base_url}}/kg-anomaly/datasets/
|
| 332 |
"host": ["{{base_url}}"],
|
| 333 |
-
"path": ["kg-anomaly", "datasets", "
|
| 334 |
"query": [
|
| 335 |
{ "key": "count", "value": "3" }
|
| 336 |
]
|
|
@@ -352,7 +352,7 @@
|
|
| 352 |
],
|
| 353 |
"body": {
|
| 354 |
"mode": "raw",
|
| 355 |
-
"raw": "{\n \"dataset_id\": \"
|
| 356 |
},
|
| 357 |
"url": {
|
| 358 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
@@ -371,7 +371,7 @@
|
|
| 371 |
],
|
| 372 |
"body": {
|
| 373 |
"mode": "raw",
|
| 374 |
-
"raw": "{\n \"dataset_id\": \"
|
| 375 |
},
|
| 376 |
"url": {
|
| 377 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
@@ -390,7 +390,7 @@
|
|
| 390 |
],
|
| 391 |
"body": {
|
| 392 |
"mode": "raw",
|
| 393 |
-
"raw": "{\n \"dataset_id\": \"
|
| 394 |
},
|
| 395 |
"url": {
|
| 396 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
|
|
| 77 |
"method": "GET",
|
| 78 |
"header": [],
|
| 79 |
"url": {
|
| 80 |
+
"raw": "{{base_url}}/coins/datasets/wordnet/entities?page=1&page_size=10&q=dog",
|
| 81 |
"host": ["{{base_url}}"],
|
| 82 |
+
"path": ["coins", "datasets", "wordnet", "entities"],
|
| 83 |
"query": [
|
| 84 |
{ "key": "page", "value": "1" },
|
| 85 |
{ "key": "page_size", "value": "10" },
|
| 86 |
+
{ "key": "q", "value": "dog", "description": "Substring search filter" }
|
| 87 |
]
|
| 88 |
},
|
| 89 |
"description": "Paginated, searchable entity list."
|
|
|
|
| 95 |
"method": "GET",
|
| 96 |
"header": [],
|
| 97 |
"url": {
|
| 98 |
+
"raw": "{{base_url}}/coins/datasets/wordnet/relations?page=1&page_size=10&q=hyper",
|
| 99 |
"host": ["{{base_url}}"],
|
| 100 |
+
"path": ["coins", "datasets", "wordnet", "relations"],
|
| 101 |
"query": [
|
| 102 |
{ "key": "page", "value": "1" },
|
| 103 |
{ "key": "page_size", "value": "10" },
|
| 104 |
+
{ "key": "q", "value": "hyper", "description": "Substring search filter" }
|
| 105 |
]
|
| 106 |
},
|
| 107 |
"description": "Paginated, searchable relation list."
|
|
|
|
| 113 |
"method": "GET",
|
| 114 |
"header": [],
|
| 115 |
"url": {
|
| 116 |
+
"raw": "{{base_url}}/coins/datasets/wordnet/sample-triples?count=5",
|
| 117 |
"host": ["{{base_url}}"],
|
| 118 |
+
"path": ["coins", "datasets", "wordnet", "sample-triples"],
|
| 119 |
"query": [
|
| 120 |
{ "key": "count", "value": "5" }
|
| 121 |
]
|
|
|
|
| 163 |
],
|
| 164 |
"body": {
|
| 165 |
"mode": "raw",
|
| 166 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"algorithm\": \"rotate\",\n \"query_structure\": \"1p\",\n \"anchors\": {\"a\": 11754},\n \"relations\": {\"r1\": 3},\n \"top_k\": 10\n}"
|
| 167 |
},
|
| 168 |
"url": {
|
| 169 |
"raw": "{{base_url}}/coins/predict",
|
|
|
|
| 182 |
],
|
| 183 |
"body": {
|
| 184 |
"mode": "raw",
|
| 185 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"algorithm\": \"q2b\",\n \"query_structure\": \"2i\",\n \"anchors\": {\"a1\": 11754, \"a2\": 5142},\n \"relations\": {\"r1\": 3, \"r2\": 1},\n \"top_k\": 10\n}"
|
| 186 |
},
|
| 187 |
"url": {
|
| 188 |
"raw": "{{base_url}}/coins/predict",
|
|
|
|
| 201 |
],
|
| 202 |
"body": {
|
| 203 |
"mode": "raw",
|
| 204 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"algorithm\": \"q2b\",\n \"query_structure\": \"ip\",\n \"anchors\": {\"a1\": 11754, \"a2\": 5142},\n \"relations\": {\"r1\": 3, \"r2\": 1, \"r3\": 2},\n \"top_k\": 10\n}"
|
| 205 |
},
|
| 206 |
"url": {
|
| 207 |
"raw": "{{base_url}}/coins/predict",
|
|
|
|
| 328 |
"method": "GET",
|
| 329 |
"header": [],
|
| 330 |
"url": {
|
| 331 |
+
"raw": "{{base_url}}/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3",
|
| 332 |
"host": ["{{base_url}}"],
|
| 333 |
+
"path": ["kg-anomaly", "datasets", "wordnet", "sample-subgraphs"],
|
| 334 |
"query": [
|
| 335 |
{ "key": "count", "value": "3" }
|
| 336 |
]
|
|
|
|
| 352 |
],
|
| 353 |
"body": {
|
| 354 |
"mode": "raw",
|
| 355 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3},\n {\"entity_id\": 8142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3},\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 1}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
|
| 356 |
},
|
| 357 |
"url": {
|
| 358 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
|
|
| 371 |
],
|
| 372 |
"body": {
|
| 373 |
"mode": "raw",
|
| 374 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"generate\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
|
| 375 |
},
|
| 376 |
"url": {
|
| 377 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
|
|
|
| 390 |
],
|
| 391 |
"body": {
|
| 392 |
"mode": "raw",
|
| 393 |
+
"raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"multiprox\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3}\n ]\n },\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"m\": 10,\n \"t\": 0.5,\n \"t_prime\": 0.1\n }\n}"
|
| 394 |
},
|
| 395 |
"url": {
|
| 396 |
"raw": "{{base_url}}/kg-anomaly/correct",
|
src/backend/README.md
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Backend — Django REST API
|
| 2 |
+
|
| 3 |
+
Stateless REST API serving the PhD research models. No database — PyTorch checkpoints are loaded into memory at startup and used to answer each request independently.
|
| 4 |
+
|
| 5 |
+
## Prerequisites
|
| 6 |
+
|
| 7 |
+
1. **Conda environment** with pre-installed system deps:
|
| 8 |
+
```bash
|
| 9 |
+
conda create -n website python=3.9
|
| 10 |
+
conda activate website
|
| 11 |
+
conda install -c conda-forge rdkit=2023.03.2 graph-tool=2.45
|
| 12 |
+
```
|
| 13 |
+
|
| 14 |
+
2. **Pip dependencies**:
|
| 15 |
+
```bash
|
| 16 |
+
pip install -r requirements.txt
|
| 17 |
+
```
|
| 18 |
+
|
| 19 |
+
3. **Model checkpoints** — downloaded automatically from Google Drive on first boot (via `gdown`). Alternatively, manually place `.tar` / `.ckpt` files in:
|
| 20 |
+
- `src/research/COINs-KGGeneration/graph_completion/checkpoints/` (COINs: `{dataset}_{algorithm}.tar`)
|
| 21 |
+
- `src/research/COINs-KGGeneration/graph_generation/checkpoints/` (KG anomaly: `{dataset}.ckpt`, `{dataset}_correct.ckpt`)
|
| 22 |
+
- `src/research/MultiProxAn/checkpoints/` (graph generation: `{dataset}.ckpt`, `{dataset}_c.ckpt`)
|
| 23 |
+
|
| 24 |
+
4. **Dataset files** — the raw KG data files must be present under `src/research/COINs-KGGeneration/data/` (FB15k-237, WN18RR, NELL-995).
|
| 25 |
+
|
| 26 |
+
## Running
|
| 27 |
+
|
| 28 |
+
From `src/backend/`:
|
| 29 |
+
|
| 30 |
+
```bash
|
| 31 |
+
# Development server
|
| 32 |
+
python manage.py runserver 8000
|
| 33 |
+
|
| 34 |
+
# With custom settings
|
| 35 |
+
DJANGO_DEBUG=True DJANGO_SECRET_KEY=my-secret python manage.py runserver
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
The API is served at `http://localhost:8000/api/v1/`.
|
| 39 |
+
|
| 40 |
+
## Environment Variables
|
| 41 |
+
|
| 42 |
+
| Variable | Default | Description |
|
| 43 |
+
|---|---|---|
|
| 44 |
+
| `DJANGO_SECRET_KEY` | `dev-insecure-key-change-in-production` | Django secret key. **Set in production.** |
|
| 45 |
+
| `DJANGO_DEBUG` | `True` | Enable debug mode. Set to `False` in production. |
|
| 46 |
+
| `DJANGO_ALLOWED_HOSTS` | `localhost,127.0.0.1` | Comma-separated allowed hosts. |
|
| 47 |
+
|
| 48 |
+
## Startup Sequence
|
| 49 |
+
|
| 50 |
+
On boot (`ApiConfig.ready()`), the `ModelRegistry` initializes:
|
| 51 |
+
|
| 52 |
+
1. **Download checkpoints** from Google Drive if not already present locally
|
| 53 |
+
2. **Scan checkpoint directories** to detect available models per method
|
| 54 |
+
3. **Load lightweight COINs Loaders** — one per dataset (freebase, wordnet, nell), loading graph data, name maps, and train/val/test splits. Heavy arrays (node neighbours ~275MB each, community neighbours, adjacency dicts) are freed after initialization to keep memory low.
|
| 55 |
+
4. **Generate sample subgraphs** for KG anomaly using the COINs Loaders
|
| 56 |
+
|
| 57 |
+
All model weights (COINs inference, graph generation, KG anomaly) are loaded lazily at first inference request.
|
| 58 |
+
|
| 59 |
+
## API Endpoints
|
| 60 |
+
|
| 61 |
+
All endpoints are prefixed with `/api/v1/`.
|
| 62 |
+
|
| 63 |
+
### Health & Discovery
|
| 64 |
+
|
| 65 |
+
| Method | Path | Description |
|
| 66 |
+
|---|---|---|
|
| 67 |
+
| `GET` | `/health` | Service health + model availability |
|
| 68 |
+
| `GET` | `/methods` | List the 3 research methods |
|
| 69 |
+
|
| 70 |
+
### COINs — KG Reasoning
|
| 71 |
+
|
| 72 |
+
| Method | Path | Description |
|
| 73 |
+
|---|---|---|
|
| 74 |
+
| `GET` | `/coins/datasets` | List datasets with entity/relation counts |
|
| 75 |
+
| `GET` | `/coins/datasets/{id}/entities` | Paginated entity search (`?q=&page=&page_size=`) |
|
| 76 |
+
| `GET` | `/coins/datasets/{id}/relations` | Paginated relation search (`?q=&page=&page_size=`) |
|
| 77 |
+
| `GET` | `/coins/datasets/{id}/sample-triples` | Random training triples (`?count=10`) |
|
| 78 |
+
| `GET` | `/coins/models` | Available algorithms + supported query structures |
|
| 79 |
+
| `GET` | `/coins/query-structures` | Query graph templates for frontend rendering |
|
| 80 |
+
| `POST` | `/coins/predict` | Run inference (not yet implemented) |
|
| 81 |
+
|
| 82 |
+
### Graph Generation — MultiProxAn
|
| 83 |
+
|
| 84 |
+
| Method | Path | Description |
|
| 85 |
+
|---|---|---|
|
| 86 |
+
| `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
|
| 87 |
+
| `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
|
| 88 |
+
| `POST` | `/graph-generation/generate` | Generate a graph (not yet implemented) |
|
| 89 |
+
| `POST` | `/graph-generation/continue` | Continue MultiProx generation (not yet implemented) |
|
| 90 |
+
|
| 91 |
+
### KG Anomaly Correction
|
| 92 |
+
|
| 93 |
+
| Method | Path | Description |
|
| 94 |
+
|---|---|---|
|
| 95 |
+
| `GET` | `/kg-anomaly/datasets` | List datasets with correction models |
|
| 96 |
+
| `GET` | `/kg-anomaly/datasets/{id}/sample-subgraphs` | Pre-computed example subgraphs (`?count=5`) |
|
| 97 |
+
| `POST` | `/kg-anomaly/correct` | Run correction (not yet implemented) |
|
| 98 |
+
| `POST` | `/kg-anomaly/continue` | Continue MultiProx correction (not yet implemented) |
|
| 99 |
+
|
| 100 |
+
## Project Structure
|
| 101 |
+
|
| 102 |
+
```
|
| 103 |
+
src/backend/
|
| 104 |
+
manage.py
|
| 105 |
+
requirements.txt
|
| 106 |
+
research_api/ # Django project settings
|
| 107 |
+
settings.py
|
| 108 |
+
urls.py
|
| 109 |
+
wsgi.py
|
| 110 |
+
api/ # Django app
|
| 111 |
+
apps.py # Triggers ModelRegistry.initialize() on startup
|
| 112 |
+
urls.py # Route definitions
|
| 113 |
+
pagination.py # Shared pagination helper
|
| 114 |
+
exceptions.py # Custom error envelope
|
| 115 |
+
services/
|
| 116 |
+
constants.py # Dataset metadata, model configs, query structures
|
| 117 |
+
registry.py # ModelRegistry — checkpoint download, scanning, Loader init
|
| 118 |
+
views/
|
| 119 |
+
health.py # /health, /methods
|
| 120 |
+
coins.py # /coins/* endpoints
|
| 121 |
+
graph_generation.py # /graph-generation/* endpoints
|
| 122 |
+
kg_anomaly.py # /kg-anomaly/* endpoints
|
| 123 |
+
```
|
| 124 |
+
|
| 125 |
+
## Testing with Postman
|
| 126 |
+
|
| 127 |
+
Import the collection and environment from `docs/postman/` to test all discovery endpoints.
|
src/backend/api/apps.py
CHANGED
|
@@ -6,6 +6,16 @@ class ApiConfig(AppConfig):
|
|
| 6 |
default_auto_field = "django.db.models.BigAutoField"
|
| 7 |
|
| 8 |
def ready(self):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
from api.services.registry import ModelRegistry
|
| 10 |
|
| 11 |
ModelRegistry.initialize()
|
|
|
|
| 6 |
default_auto_field = "django.db.models.BigAutoField"
|
| 7 |
|
| 8 |
def ready(self):
|
| 9 |
+
import os
|
| 10 |
+
import sys
|
| 11 |
+
|
| 12 |
+
# Django's runserver auto-reloader spawns two processes, both calling ready().
|
| 13 |
+
# The inner (serving) process has RUN_MAIN="true"; the outer (watcher) does not.
|
| 14 |
+
# Skip initialization in the outer process to avoid double memory usage.
|
| 15 |
+
uses_reloader = "runserver" in sys.argv and "--noreload" not in sys.argv
|
| 16 |
+
if uses_reloader and os.environ.get("RUN_MAIN") != "true":
|
| 17 |
+
return
|
| 18 |
+
|
| 19 |
from api.services.registry import ModelRegistry
|
| 20 |
|
| 21 |
ModelRegistry.initialize()
|
src/backend/api/pagination.py
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
def paginate_list(items, page=1, page_size=50):
|
| 2 |
+
"""Paginate a list with 1-indexed pages. Returns (page_items, total)."""
|
| 3 |
+
page = max(1, page)
|
| 4 |
+
page_size = max(1, min(200, page_size))
|
| 5 |
+
total = len(items)
|
| 6 |
+
start = (page - 1) * page_size
|
| 7 |
+
end = start + page_size
|
| 8 |
+
return items[start:end], total
|
src/backend/api/services/constants.py
ADDED
|
@@ -0,0 +1,297 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
METHODS = [
|
| 2 |
+
{
|
| 3 |
+
"id": "coins",
|
| 4 |
+
"name": "COINs - Knowledge Graph Reasoning",
|
| 5 |
+
"thesis_section": "3.1",
|
| 6 |
+
"description": (
|
| 7 |
+
"Community-Informed Graph Embeddings (COINs) for scalable knowledge graph link prediction "
|
| 8 |
+
"and complex query answering. Uses community detection to localize embedding computation, "
|
| 9 |
+
"achieving significant speedups over full-graph methods."
|
| 10 |
+
),
|
| 11 |
+
},
|
| 12 |
+
{
|
| 13 |
+
"id": "multiproxan",
|
| 14 |
+
"name": "MultiProxAn - Graph Generation",
|
| 15 |
+
"thesis_section": "4.3",
|
| 16 |
+
"description": (
|
| 17 |
+
"Discrete denoising diffusion model for graph generation with MultiProx sampling. "
|
| 18 |
+
"Generates molecular graphs (QM9) and synthetic community graphs using iterative "
|
| 19 |
+
"multi-measurement Gibbs sampling for improved sample quality."
|
| 20 |
+
),
|
| 21 |
+
},
|
| 22 |
+
{
|
| 23 |
+
"id": "kg_anomaly",
|
| 24 |
+
"name": "KG Anomaly Correction",
|
| 25 |
+
"thesis_section": "4.4",
|
| 26 |
+
"description": (
|
| 27 |
+
"Diffusion-based knowledge graph subgraph correction. Applies the DiGress denoising "
|
| 28 |
+
"diffusion model to knowledge graph subgraphs to detect and correct anomalous edges."
|
| 29 |
+
),
|
| 30 |
+
},
|
| 31 |
+
]
|
| 32 |
+
|
| 33 |
+
COINS_DATASET_META = {
|
| 34 |
+
"freebase": {
|
| 35 |
+
"name": "FB15k-237",
|
| 36 |
+
"description": "Subset of Freebase knowledge base with 237 relation types",
|
| 37 |
+
"data_dir": "FB15k-237",
|
| 38 |
+
},
|
| 39 |
+
"wordnet": {
|
| 40 |
+
"name": "WN18RR",
|
| 41 |
+
"description": "Subset of WordNet lexical database with 11 relation types",
|
| 42 |
+
"data_dir": "WN18RR",
|
| 43 |
+
},
|
| 44 |
+
"nell": {
|
| 45 |
+
"name": "NELL-995",
|
| 46 |
+
"description": "Never-Ending Language Learner knowledge base with 200 relation types",
|
| 47 |
+
"data_dir": "NELL-995",
|
| 48 |
+
},
|
| 49 |
+
}
|
| 50 |
+
|
| 51 |
+
COINS_MODELS = [
|
| 52 |
+
{
|
| 53 |
+
"algorithm": "transe",
|
| 54 |
+
"name": "TransE",
|
| 55 |
+
"description": "Translation-based embedding model",
|
| 56 |
+
"supported_query_structures": ["1p"],
|
| 57 |
+
},
|
| 58 |
+
{
|
| 59 |
+
"algorithm": "distmult",
|
| 60 |
+
"name": "DistMult",
|
| 61 |
+
"description": "Bilinear diagonal embedding model",
|
| 62 |
+
"supported_query_structures": ["1p"],
|
| 63 |
+
},
|
| 64 |
+
{
|
| 65 |
+
"algorithm": "complex",
|
| 66 |
+
"name": "ComplEx",
|
| 67 |
+
"description": "Complex-valued embedding model",
|
| 68 |
+
"supported_query_structures": ["1p"],
|
| 69 |
+
},
|
| 70 |
+
{
|
| 71 |
+
"algorithm": "rotate",
|
| 72 |
+
"name": "RotatE",
|
| 73 |
+
"description": "Rotation-based embedding model in complex space",
|
| 74 |
+
"supported_query_structures": ["1p"],
|
| 75 |
+
},
|
| 76 |
+
{
|
| 77 |
+
"algorithm": "q2b",
|
| 78 |
+
"name": "Query2Box",
|
| 79 |
+
"description": "Box embedding model for complex logical queries",
|
| 80 |
+
"supported_query_structures": ["1p", "2p", "3p", "2i", "3i", "ip", "pi"],
|
| 81 |
+
},
|
| 82 |
+
{
|
| 83 |
+
"algorithm": "kbgat",
|
| 84 |
+
"name": "KBGAT",
|
| 85 |
+
"description": "Knowledge base graph attention network",
|
| 86 |
+
"supported_query_structures": ["1p"],
|
| 87 |
+
},
|
| 88 |
+
]
|
| 89 |
+
|
| 90 |
+
QUERY_STRUCTURES = [
|
| 91 |
+
{
|
| 92 |
+
"id": "1p",
|
| 93 |
+
"name": "Single Hop",
|
| 94 |
+
"description": "Direct link prediction: who/what is connected to the anchor via this relation?",
|
| 95 |
+
"nodes": [
|
| 96 |
+
{"id": "a", "type": "anchor", "label": "Anchor"},
|
| 97 |
+
{"id": "t", "type": "target", "label": "?"},
|
| 98 |
+
],
|
| 99 |
+
"edges": [
|
| 100 |
+
{"id": "r1", "source": "a", "target": "t", "label": "Relation"},
|
| 101 |
+
],
|
| 102 |
+
},
|
| 103 |
+
{
|
| 104 |
+
"id": "2p",
|
| 105 |
+
"name": "Two Hop",
|
| 106 |
+
"description": "Two-step chain: anchor -> variable -> target",
|
| 107 |
+
"nodes": [
|
| 108 |
+
{"id": "a", "type": "anchor", "label": "Anchor"},
|
| 109 |
+
{"id": "v1", "type": "variable", "label": "Variable"},
|
| 110 |
+
{"id": "t", "type": "target", "label": "?"},
|
| 111 |
+
],
|
| 112 |
+
"edges": [
|
| 113 |
+
{"id": "r1", "source": "a", "target": "v1", "label": "Relation 1"},
|
| 114 |
+
{"id": "r2", "source": "v1", "target": "t", "label": "Relation 2"},
|
| 115 |
+
],
|
| 116 |
+
},
|
| 117 |
+
{
|
| 118 |
+
"id": "3p",
|
| 119 |
+
"name": "Three Hop",
|
| 120 |
+
"description": "Three-step chain: anchor -> v1 -> v2 -> target",
|
| 121 |
+
"nodes": [
|
| 122 |
+
{"id": "a", "type": "anchor", "label": "Anchor"},
|
| 123 |
+
{"id": "v1", "type": "variable", "label": "Variable 1"},
|
| 124 |
+
{"id": "v2", "type": "variable", "label": "Variable 2"},
|
| 125 |
+
{"id": "t", "type": "target", "label": "?"},
|
| 126 |
+
],
|
| 127 |
+
"edges": [
|
| 128 |
+
{"id": "r1", "source": "a", "target": "v1", "label": "Relation 1"},
|
| 129 |
+
{"id": "r2", "source": "v1", "target": "v2", "label": "Relation 2"},
|
| 130 |
+
{"id": "r3", "source": "v2", "target": "t", "label": "Relation 3"},
|
| 131 |
+
],
|
| 132 |
+
},
|
| 133 |
+
{
|
| 134 |
+
"id": "2i",
|
| 135 |
+
"name": "Two Intersection",
|
| 136 |
+
"description": "Intersection of two single-hop queries sharing the same target",
|
| 137 |
+
"nodes": [
|
| 138 |
+
{"id": "a1", "type": "anchor", "label": "Anchor 1"},
|
| 139 |
+
{"id": "a2", "type": "anchor", "label": "Anchor 2"},
|
| 140 |
+
{"id": "t", "type": "target", "label": "?"},
|
| 141 |
+
],
|
| 142 |
+
"edges": [
|
| 143 |
+
{"id": "r1", "source": "a1", "target": "t", "label": "Relation 1"},
|
| 144 |
+
{"id": "r2", "source": "a2", "target": "t", "label": "Relation 2"},
|
| 145 |
+
],
|
| 146 |
+
},
|
| 147 |
+
{
|
| 148 |
+
"id": "3i",
|
| 149 |
+
"name": "Three Intersection",
|
| 150 |
+
"description": "Intersection of three single-hop queries sharing the same target",
|
| 151 |
+
"nodes": [
|
| 152 |
+
{"id": "a1", "type": "anchor", "label": "Anchor 1"},
|
| 153 |
+
{"id": "a2", "type": "anchor", "label": "Anchor 2"},
|
| 154 |
+
{"id": "a3", "type": "anchor", "label": "Anchor 3"},
|
| 155 |
+
{"id": "t", "type": "target", "label": "?"},
|
| 156 |
+
],
|
| 157 |
+
"edges": [
|
| 158 |
+
{"id": "r1", "source": "a1", "target": "t", "label": "Relation 1"},
|
| 159 |
+
{"id": "r2", "source": "a2", "target": "t", "label": "Relation 2"},
|
| 160 |
+
{"id": "r3", "source": "a3", "target": "t", "label": "Relation 3"},
|
| 161 |
+
],
|
| 162 |
+
},
|
| 163 |
+
{
|
| 164 |
+
"id": "ip",
|
| 165 |
+
"name": "Intersection then Projection",
|
| 166 |
+
"description": "Two anchors intersect, then the result projects via a third relation to the target",
|
| 167 |
+
"nodes": [
|
| 168 |
+
{"id": "a1", "type": "anchor", "label": "Anchor 1"},
|
| 169 |
+
{"id": "a2", "type": "anchor", "label": "Anchor 2"},
|
| 170 |
+
{"id": "v1", "type": "variable", "label": "Variable"},
|
| 171 |
+
{"id": "t", "type": "target", "label": "?"},
|
| 172 |
+
],
|
| 173 |
+
"edges": [
|
| 174 |
+
{"id": "r1", "source": "a1", "target": "v1", "label": "Relation 1"},
|
| 175 |
+
{"id": "r2", "source": "a2", "target": "v1", "label": "Relation 2"},
|
| 176 |
+
{"id": "r3", "source": "v1", "target": "t", "label": "Relation 3"},
|
| 177 |
+
],
|
| 178 |
+
},
|
| 179 |
+
{
|
| 180 |
+
"id": "pi",
|
| 181 |
+
"name": "Projection then Intersection",
|
| 182 |
+
"description": "One anchor projects then intersects with a direct connection from a second anchor",
|
| 183 |
+
"nodes": [
|
| 184 |
+
{"id": "a1", "type": "anchor", "label": "Anchor 1"},
|
| 185 |
+
{"id": "v1", "type": "variable", "label": "Variable"},
|
| 186 |
+
{"id": "a2", "type": "anchor", "label": "Anchor 2"},
|
| 187 |
+
{"id": "t", "type": "target", "label": "?"},
|
| 188 |
+
],
|
| 189 |
+
"edges": [
|
| 190 |
+
{"id": "r1", "source": "a1", "target": "v1", "label": "Relation 1"},
|
| 191 |
+
{"id": "r2", "source": "v1", "target": "t", "label": "Relation 2"},
|
| 192 |
+
{"id": "r3", "source": "a2", "target": "t", "label": "Relation 3"},
|
| 193 |
+
],
|
| 194 |
+
},
|
| 195 |
+
]
|
| 196 |
+
|
| 197 |
+
GRAPHGEN_DATASETS = {
|
| 198 |
+
"qm9": {
|
| 199 |
+
"name": "QM9",
|
| 200 |
+
"type": "molecular",
|
| 201 |
+
"description": "Small organic molecules with up to 9 heavy atoms (C, N, O, F)",
|
| 202 |
+
"node_types": ["C", "N", "O", "F"],
|
| 203 |
+
"edge_types": ["none", "single", "double", "triple", "aromatic"],
|
| 204 |
+
"max_nodes": 9,
|
| 205 |
+
},
|
| 206 |
+
"comm20": {
|
| 207 |
+
"name": "Community20",
|
| 208 |
+
"type": "synthetic",
|
| 209 |
+
"description": "Synthetic community-structured graphs with 12-20 nodes",
|
| 210 |
+
"node_types": ["node"],
|
| 211 |
+
"edge_types": ["none", "edge"],
|
| 212 |
+
"max_nodes": 20,
|
| 213 |
+
},
|
| 214 |
+
}
|
| 215 |
+
|
| 216 |
+
GRAPHGEN_SAMPLING_MODES = [
|
| 217 |
+
{
|
| 218 |
+
"id": "standard",
|
| 219 |
+
"name": "Standard Denoising",
|
| 220 |
+
"description": "Iterative denoising from T to 0. Full quality, slower.",
|
| 221 |
+
"parameters": [
|
| 222 |
+
{
|
| 223 |
+
"name": "diffusion_steps",
|
| 224 |
+
"type": "integer",
|
| 225 |
+
"description": "Number of diffusion steps T",
|
| 226 |
+
"default": 500,
|
| 227 |
+
"min": 50,
|
| 228 |
+
"max": 1000,
|
| 229 |
+
},
|
| 230 |
+
{
|
| 231 |
+
"name": "chain_frames",
|
| 232 |
+
"type": "integer",
|
| 233 |
+
"description": "Number of denoising snapshots in the GIF",
|
| 234 |
+
"default": 20,
|
| 235 |
+
"min": 10,
|
| 236 |
+
"max": 30,
|
| 237 |
+
},
|
| 238 |
+
],
|
| 239 |
+
},
|
| 240 |
+
{
|
| 241 |
+
"id": "multiprox",
|
| 242 |
+
"name": "MultiProx Sampling",
|
| 243 |
+
"description": (
|
| 244 |
+
"Multi-measurement Gibbs sampling with proximal steps. "
|
| 245 |
+
"Step-by-step generation with controllable noise levels."
|
| 246 |
+
),
|
| 247 |
+
"parameters": [
|
| 248 |
+
{
|
| 249 |
+
"name": "diffusion_steps",
|
| 250 |
+
"type": "integer",
|
| 251 |
+
"description": "Number of diffusion steps T",
|
| 252 |
+
"default": 500,
|
| 253 |
+
"min": 50,
|
| 254 |
+
"max": 1000,
|
| 255 |
+
},
|
| 256 |
+
{
|
| 257 |
+
"name": "m",
|
| 258 |
+
"type": "integer",
|
| 259 |
+
"description": "Number of parallel samples per multi-measurement step",
|
| 260 |
+
"default": 10,
|
| 261 |
+
"min": 2,
|
| 262 |
+
"max": 100,
|
| 263 |
+
},
|
| 264 |
+
{
|
| 265 |
+
"name": "t",
|
| 266 |
+
"type": "float",
|
| 267 |
+
"description": "First noise level (normalized, 0-1)",
|
| 268 |
+
"default": 0.5,
|
| 269 |
+
"min": 0.0,
|
| 270 |
+
"max": 1.0,
|
| 271 |
+
},
|
| 272 |
+
{
|
| 273 |
+
"name": "t_prime",
|
| 274 |
+
"type": "float",
|
| 275 |
+
"description": "Second noise level (normalized, 0-1). Must satisfy t_prime <= t.",
|
| 276 |
+
"default": 0.1,
|
| 277 |
+
"min": 0.0,
|
| 278 |
+
"max": 1.0,
|
| 279 |
+
},
|
| 280 |
+
],
|
| 281 |
+
},
|
| 282 |
+
]
|
| 283 |
+
|
| 284 |
+
KG_ANOMALY_DATASET_META = {
|
| 285 |
+
"freebase": {
|
| 286 |
+
"name": "FB15k-237",
|
| 287 |
+
"description": "Diffusion model trained on Freebase subgraphs",
|
| 288 |
+
},
|
| 289 |
+
"wordnet": {
|
| 290 |
+
"name": "WN18RR",
|
| 291 |
+
"description": "Diffusion model trained on WordNet subgraphs",
|
| 292 |
+
},
|
| 293 |
+
"nell": {
|
| 294 |
+
"name": "NELL-995",
|
| 295 |
+
"description": "Diffusion model trained on NELL subgraphs",
|
| 296 |
+
},
|
| 297 |
+
}
|
src/backend/api/services/registry.py
CHANGED
|
@@ -1,14 +1,16 @@
|
|
| 1 |
import logging
|
| 2 |
import os
|
|
|
|
| 3 |
from pathlib import Path
|
| 4 |
|
| 5 |
from django.conf import settings
|
| 6 |
|
|
|
|
|
|
|
| 7 |
logger = logging.getLogger(__name__)
|
| 8 |
|
| 9 |
GDRIVE_FOLDER_ID = "14Bf8fi4KJn0rDdh9y8EFyA5b8OpQyXWi"
|
| 10 |
|
| 11 |
-
# Subfolder IDs within the Google Drive folder (from gdown listing)
|
| 12 |
GDRIVE_SUBFOLDERS = {
|
| 13 |
"coins": {
|
| 14 |
"folder_name": "COINs-KGGeneration/checkpoints_coins",
|
|
@@ -27,6 +29,104 @@ GDRIVE_SUBFOLDERS = {
|
|
| 27 |
},
|
| 28 |
}
|
| 29 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
|
| 31 |
class ModelRegistry:
|
| 32 |
_instance = None
|
|
@@ -35,6 +135,8 @@ class ModelRegistry:
|
|
| 35 |
self.coins_checkpoints_available = {}
|
| 36 |
self.graphgen_checkpoints_available = {}
|
| 37 |
self.kg_anomaly_checkpoints_available = {}
|
|
|
|
|
|
|
| 38 |
|
| 39 |
@classmethod
|
| 40 |
def get(cls):
|
|
@@ -49,14 +151,19 @@ class ModelRegistry:
|
|
| 49 |
instance = cls()
|
| 50 |
instance._download_checkpoints()
|
| 51 |
instance._scan_checkpoints()
|
|
|
|
|
|
|
| 52 |
cls._instance = instance
|
| 53 |
logger.info(
|
| 54 |
-
"ModelRegistry initialized: coins=%s, multiproxan=%s, kg_anomaly=%s",
|
| 55 |
instance.is_coins_loaded(),
|
| 56 |
instance.is_graphgen_loaded(),
|
| 57 |
instance.is_kg_anomaly_loaded(),
|
|
|
|
| 58 |
)
|
| 59 |
|
|
|
|
|
|
|
| 60 |
def _download_checkpoints(self):
|
| 61 |
"""Download checkpoints from Google Drive if not already present locally."""
|
| 62 |
if self._all_checkpoint_dirs_populated():
|
|
@@ -116,6 +223,8 @@ class ModelRegistry:
|
|
| 116 |
logger.info("Installing checkpoint: %s -> %s", src_file.name, dest_dir)
|
| 117 |
src_file.replace(dest_file)
|
| 118 |
|
|
|
|
|
|
|
| 119 |
def _scan_checkpoints(self):
|
| 120 |
self._scan_coins_checkpoints()
|
| 121 |
self._scan_graphgen_checkpoints()
|
|
@@ -161,6 +270,199 @@ class ModelRegistry:
|
|
| 161 |
self.kg_anomaly_checkpoints_available.setdefault(name, []).append("generate")
|
| 162 |
logger.info("DiGress KG checkpoints: %s", self.kg_anomaly_checkpoints_available)
|
| 163 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 164 |
def is_coins_loaded(self):
|
| 165 |
return bool(self.coins_checkpoints_available)
|
| 166 |
|
|
|
|
| 1 |
import logging
|
| 2 |
import os
|
| 3 |
+
import random
|
| 4 |
from pathlib import Path
|
| 5 |
|
| 6 |
from django.conf import settings
|
| 7 |
|
| 8 |
+
from api.services.constants import COINS_DATASET_META
|
| 9 |
+
|
| 10 |
logger = logging.getLogger(__name__)
|
| 11 |
|
| 12 |
GDRIVE_FOLDER_ID = "14Bf8fi4KJn0rDdh9y8EFyA5b8OpQyXWi"
|
| 13 |
|
|
|
|
| 14 |
GDRIVE_SUBFOLDERS = {
|
| 15 |
"coins": {
|
| 16 |
"folder_name": "COINs-KGGeneration/checkpoints_coins",
|
|
|
|
| 29 |
},
|
| 30 |
}
|
| 31 |
|
| 32 |
+
# Shared sampler hyperparameters used across all COINs experiments
|
| 33 |
+
_SAMPLER_HPARS = {
|
| 34 |
+
"query_structure": ["1p"],
|
| 35 |
+
"num_negative_samples": 128,
|
| 36 |
+
"num_neighbours": 10,
|
| 37 |
+
"random_walk_length": 10,
|
| 38 |
+
"context_radius": 2,
|
| 39 |
+
"pagerank_importances": True,
|
| 40 |
+
"walks_relation_specific": True,
|
| 41 |
+
}
|
| 42 |
+
|
| 43 |
+
# Per-dataset base config (leiden resolution, loader hyperparams)
|
| 44 |
+
_DATASET_BASE = {
|
| 45 |
+
"freebase": {
|
| 46 |
+
"leiden_resolution": 5.0e-3,
|
| 47 |
+
"loader_hpars": {
|
| 48 |
+
"dataset_name": "freebase", "simulated": False,
|
| 49 |
+
"sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
|
| 50 |
+
},
|
| 51 |
+
},
|
| 52 |
+
"wordnet": {
|
| 53 |
+
"leiden_resolution": None, # computed as 1/num_nodes
|
| 54 |
+
"loader_hpars": {
|
| 55 |
+
"dataset_name": "wordnet", "simulated": False,
|
| 56 |
+
"sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
|
| 57 |
+
},
|
| 58 |
+
},
|
| 59 |
+
"nell": {
|
| 60 |
+
"leiden_resolution": 2.0e-5,
|
| 61 |
+
"loader_hpars": {
|
| 62 |
+
"dataset_name": "nell", "simulated": False,
|
| 63 |
+
"sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
|
| 64 |
+
},
|
| 65 |
+
},
|
| 66 |
+
}
|
| 67 |
+
|
| 68 |
+
# Training seed per (dataset, algorithm), from PhD Thesis/Code experiment runs.
|
| 69 |
+
# Verified by matching num_communities in test_log.txt against checkpoint embedding shapes.
|
| 70 |
+
# Groups sharing the same community structure:
|
| 71 |
+
# freebase: transe/distmult/complex/rotate (1092 com) | q2b (1030) | kbgat (1025)
|
| 72 |
+
# wordnet: transe/distmult/complex/rotate (66 com) | q2b (74) | kbgat (88)
|
| 73 |
+
# nell: transe/distmult/rotate (282 com) | complex (188) | q2b (282, diff sizes) | kbgat (275)
|
| 74 |
+
_CHECKPOINT_SEEDS = {
|
| 75 |
+
("freebase", "transe"): 4089853924, ("freebase", "distmult"): 4089853924,
|
| 76 |
+
("freebase", "complex"): 4089853924, ("freebase", "rotate"): 4089853924,
|
| 77 |
+
("freebase", "q2b"): 1503136574, ("freebase", "kbgat"): 123456789,
|
| 78 |
+
("wordnet", "transe"): 1919180054, ("wordnet", "distmult"): 1919180054,
|
| 79 |
+
("wordnet", "complex"): 1919180054, ("wordnet", "rotate"): 1919180054,
|
| 80 |
+
("wordnet", "q2b"): 3312854056, ("wordnet", "kbgat"): 123456789,
|
| 81 |
+
("nell", "transe"): 3192206669, ("nell", "distmult"): 3192206669,
|
| 82 |
+
("nell", "rotate"): 3192206669,
|
| 83 |
+
("nell", "complex"): 2409194445, ("nell", "q2b"): 3793326028, ("nell", "kbgat"): 123456789,
|
| 84 |
+
}
|
| 85 |
+
|
| 86 |
+
# Vanilla seeds per dataset — used for metadata Loaders (entity/relation search, sample triples).
|
| 87 |
+
VANILLA_SEEDS = {
|
| 88 |
+
"freebase": 4089853924,
|
| 89 |
+
"wordnet": 1919180054,
|
| 90 |
+
"nell": 3192206669,
|
| 91 |
+
}
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def get_checkpoint_config(dataset_id, algorithm):
|
| 95 |
+
"""Return the full config for a specific (dataset, algorithm) checkpoint."""
|
| 96 |
+
base = _DATASET_BASE[dataset_id]
|
| 97 |
+
seed = _CHECKPOINT_SEEDS.get((dataset_id, algorithm))
|
| 98 |
+
if seed is None:
|
| 99 |
+
seed = VANILLA_SEEDS[dataset_id]
|
| 100 |
+
return {"seed": seed, **base}
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _free_heavy_arrays(loader):
|
| 104 |
+
"""Free memory-intensive arrays from a Loader that aren't needed for discovery endpoints."""
|
| 105 |
+
loader.node_neighbours = None
|
| 106 |
+
loader.com_neighbours = None
|
| 107 |
+
loader.node_adjacency = None
|
| 108 |
+
loader.com_adjacency = None
|
| 109 |
+
loader.label_community_edge_freqs = None
|
| 110 |
+
loader.label_community_edge_freqs_index = None
|
| 111 |
+
loader.machines = None
|
| 112 |
+
loader.graph = None
|
| 113 |
+
loader.node_importances = None
|
| 114 |
+
loader.neighbour_importances = None
|
| 115 |
+
loader.out_degrees = None
|
| 116 |
+
loader.in_degrees = None
|
| 117 |
+
loader.degrees = None
|
| 118 |
+
loader.node_degree_type_freqs = None
|
| 119 |
+
loader.relation_freqs = None
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
class SubgraphInfo:
|
| 123 |
+
"""Holds pre-computed sample subgraphs for a KG anomaly dataset."""
|
| 124 |
+
|
| 125 |
+
__slots__ = ("subgraphs",)
|
| 126 |
+
|
| 127 |
+
def __init__(self, subgraphs):
|
| 128 |
+
self.subgraphs = subgraphs
|
| 129 |
+
|
| 130 |
|
| 131 |
class ModelRegistry:
|
| 132 |
_instance = None
|
|
|
|
| 135 |
self.coins_checkpoints_available = {}
|
| 136 |
self.graphgen_checkpoints_available = {}
|
| 137 |
self.kg_anomaly_checkpoints_available = {}
|
| 138 |
+
self.loaders = {} # dataset_id -> lightweight Loader for discovery
|
| 139 |
+
self.kg_anomaly_subgraphs = {} # dataset_id -> SubgraphInfo
|
| 140 |
|
| 141 |
@classmethod
|
| 142 |
def get(cls):
|
|
|
|
| 151 |
instance = cls()
|
| 152 |
instance._download_checkpoints()
|
| 153 |
instance._scan_checkpoints()
|
| 154 |
+
instance._load_all_loaders()
|
| 155 |
+
instance._generate_sample_subgraphs()
|
| 156 |
cls._instance = instance
|
| 157 |
logger.info(
|
| 158 |
+
"ModelRegistry initialized: coins=%s, multiproxan=%s, kg_anomaly=%s, loaders=%s",
|
| 159 |
instance.is_coins_loaded(),
|
| 160 |
instance.is_graphgen_loaded(),
|
| 161 |
instance.is_kg_anomaly_loaded(),
|
| 162 |
+
list(instance.loaders.keys()),
|
| 163 |
)
|
| 164 |
|
| 165 |
+
# ---- Checkpoint download -------------------------------------------
|
| 166 |
+
|
| 167 |
def _download_checkpoints(self):
|
| 168 |
"""Download checkpoints from Google Drive if not already present locally."""
|
| 169 |
if self._all_checkpoint_dirs_populated():
|
|
|
|
| 223 |
logger.info("Installing checkpoint: %s -> %s", src_file.name, dest_dir)
|
| 224 |
src_file.replace(dest_file)
|
| 225 |
|
| 226 |
+
# ---- Checkpoint scanning -------------------------------------------
|
| 227 |
+
|
| 228 |
def _scan_checkpoints(self):
|
| 229 |
self._scan_coins_checkpoints()
|
| 230 |
self._scan_graphgen_checkpoints()
|
|
|
|
| 270 |
self.kg_anomaly_checkpoints_available.setdefault(name, []).append("generate")
|
| 271 |
logger.info("DiGress KG checkpoints: %s", self.kg_anomaly_checkpoints_available)
|
| 272 |
|
| 273 |
+
# ---- Loader initialization -----------------------------------------
|
| 274 |
+
|
| 275 |
+
def _load_all_loaders(self):
|
| 276 |
+
"""Initialize one lightweight Loader per dataset for discovery endpoints.
|
| 277 |
+
|
| 278 |
+
Loads dataset, name maps, train/val/test split, and graph indexes.
|
| 279 |
+
Heavy arrays (node_neighbours, com_neighbours, adjacency dicts) are freed
|
| 280 |
+
after startup to save memory. Full Loaders for inference are loaded on demand.
|
| 281 |
+
"""
|
| 282 |
+
coins_root = str(Path(settings.COINS_DATA_DIR).parent)
|
| 283 |
+
original_cwd = os.getcwd()
|
| 284 |
+
try:
|
| 285 |
+
os.chdir(coins_root)
|
| 286 |
+
from graph_completion.graphs.load_graph import Loader, LoaderHpars
|
| 287 |
+
|
| 288 |
+
for dataset_id in _DATASET_BASE:
|
| 289 |
+
seed = VANILLA_SEEDS[dataset_id]
|
| 290 |
+
config = get_checkpoint_config(dataset_id, "transe")
|
| 291 |
+
try:
|
| 292 |
+
logger.info("Initializing Loader for %s (seed=%d)...", dataset_id, seed)
|
| 293 |
+
loader = LoaderHpars.from_dict(config["loader_hpars"]).make()
|
| 294 |
+
|
| 295 |
+
leiden_resolution = config["leiden_resolution"]
|
| 296 |
+
if leiden_resolution is None:
|
| 297 |
+
dataset_obj = Loader.datasets[dataset_id]
|
| 298 |
+
dataset_obj.load_from_disk()
|
| 299 |
+
leiden_resolution = 1.0 / len(dataset_obj.node_data)
|
| 300 |
+
dataset_obj.unload_from_memory()
|
| 301 |
+
|
| 302 |
+
loader.load_graph(
|
| 303 |
+
seed=seed, device="cpu", val_size=0.01, test_size=0.02,
|
| 304 |
+
community_method="leiden", leiden_resolution=leiden_resolution,
|
| 305 |
+
)
|
| 306 |
+
# Free heavy arrays not needed for discovery endpoints
|
| 307 |
+
_free_heavy_arrays(loader)
|
| 308 |
+
self.loaders[dataset_id] = loader
|
| 309 |
+
logger.info(
|
| 310 |
+
"Loader ready for %s: %d entities, %d relations, %d train triples",
|
| 311 |
+
dataset_id, loader.num_nodes, loader.num_relations, len(loader.train_edge_data),
|
| 312 |
+
)
|
| 313 |
+
except Exception:
|
| 314 |
+
logger.exception("Failed to initialize Loader for %s", dataset_id)
|
| 315 |
+
finally:
|
| 316 |
+
os.chdir(original_cwd)
|
| 317 |
+
|
| 318 |
+
# ---- Loader accessor helpers ---------------------------------------
|
| 319 |
+
|
| 320 |
+
def get_loader(self, dataset_id):
|
| 321 |
+
"""Return the metadata Loader for a dataset, or None."""
|
| 322 |
+
return self.loaders.get(dataset_id)
|
| 323 |
+
|
| 324 |
+
def get_entity_count(self, dataset_id):
|
| 325 |
+
loader = self.loaders.get(dataset_id)
|
| 326 |
+
return loader.num_nodes if loader else 0
|
| 327 |
+
|
| 328 |
+
def get_relation_count(self, dataset_id):
|
| 329 |
+
loader = self.loaders.get(dataset_id)
|
| 330 |
+
return loader.num_relations if loader else 0
|
| 331 |
+
|
| 332 |
+
def get_inverted_name_maps(self, dataset_id):
|
| 333 |
+
"""Return (inv_node_names, inv_node_types, inv_relation_names) Series for a dataset."""
|
| 334 |
+
loader = self.loaders.get(dataset_id)
|
| 335 |
+
if loader is None:
|
| 336 |
+
return None, None, None
|
| 337 |
+
return loader.dataset.get_inverted_name_maps()
|
| 338 |
+
|
| 339 |
+
def search_entities(self, dataset_id, query=None, page=1, page_size=50):
|
| 340 |
+
"""Search entities by substring, return paginated (id, name) list and total."""
|
| 341 |
+
loader = self.loaders.get(dataset_id)
|
| 342 |
+
if loader is None:
|
| 343 |
+
return [], 0
|
| 344 |
+
inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
|
| 345 |
+
items = [(int(idx), str(name)) for idx, name in inv_nodes.items()]
|
| 346 |
+
if query:
|
| 347 |
+
q = query.lower()
|
| 348 |
+
items = [(eid, name) for eid, name in items if q in name.lower()]
|
| 349 |
+
total = len(items)
|
| 350 |
+
start = (max(1, page) - 1) * page_size
|
| 351 |
+
return items[start:start + page_size], total
|
| 352 |
+
|
| 353 |
+
def search_relations(self, dataset_id, query=None, page=1, page_size=50):
|
| 354 |
+
"""Search relations by substring, return paginated (id, name) list and total."""
|
| 355 |
+
loader = self.loaders.get(dataset_id)
|
| 356 |
+
if loader is None:
|
| 357 |
+
return [], 0
|
| 358 |
+
_, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 359 |
+
items = [(int(idx), str(name)) for idx, name in inv_relations.items()]
|
| 360 |
+
if query:
|
| 361 |
+
q = query.lower()
|
| 362 |
+
items = [(rid, name) for rid, name in items if q in name.lower()]
|
| 363 |
+
total = len(items)
|
| 364 |
+
start = (max(1, page) - 1) * page_size
|
| 365 |
+
return items[start:start + page_size], total
|
| 366 |
+
|
| 367 |
+
def sample_triples(self, dataset_id, count=10):
|
| 368 |
+
"""Return random triples with resolved entity/relation names."""
|
| 369 |
+
loader = self.loaders.get(dataset_id)
|
| 370 |
+
if loader is None:
|
| 371 |
+
return []
|
| 372 |
+
inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 373 |
+
edge_data = loader.train_edge_data
|
| 374 |
+
count = min(count, len(edge_data))
|
| 375 |
+
|
| 376 |
+
indices = random.sample(range(len(edge_data)), count)
|
| 377 |
+
|
| 378 |
+
result = []
|
| 379 |
+
for i in indices:
|
| 380 |
+
row = edge_data.iloc[i]
|
| 381 |
+
h, r, t = int(row.s), int(row.r), int(row.t)
|
| 382 |
+
result.append({
|
| 383 |
+
"head": {"id": h, "name": str(inv_nodes.get(h, h))},
|
| 384 |
+
"relation": {"id": r, "name": str(inv_relations.get(r, r))},
|
| 385 |
+
"tail": {"id": t, "name": str(inv_nodes.get(t, t))},
|
| 386 |
+
})
|
| 387 |
+
return result
|
| 388 |
+
|
| 389 |
+
# ---- Sample subgraph generation ------------------------------------
|
| 390 |
+
|
| 391 |
+
def _generate_sample_subgraphs(self):
|
| 392 |
+
"""Generate sample subgraphs for KG anomaly using the Loader's context subgraph DFS."""
|
| 393 |
+
for dataset_id in COINS_DATASET_META:
|
| 394 |
+
loader = self.loaders.get(dataset_id)
|
| 395 |
+
if loader is None:
|
| 396 |
+
continue
|
| 397 |
+
try:
|
| 398 |
+
subgraphs = self._build_sample_subgraphs(dataset_id, loader)
|
| 399 |
+
self.kg_anomaly_subgraphs[dataset_id] = SubgraphInfo(subgraphs)
|
| 400 |
+
logger.info("Generated %d sample subgraphs for %s", len(subgraphs), dataset_id)
|
| 401 |
+
except Exception:
|
| 402 |
+
logger.exception("Failed to generate sample subgraphs for %s", dataset_id)
|
| 403 |
+
|
| 404 |
+
def _build_sample_subgraphs(self, dataset_id, loader, num_subgraphs=20, max_graph_size=10):
|
| 405 |
+
"""Build sample subgraphs using the Sampler's DFS-based context subgraph partitioning."""
|
| 406 |
+
inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
|
| 407 |
+
node_types = loader.dataset.node_data.type.values
|
| 408 |
+
|
| 409 |
+
# Use the Sampler's DFS partitioning to get context subgraphs
|
| 410 |
+
samples = loader.sampler.get_context_subgraph_samples_dfs(
|
| 411 |
+
max_graph_size, loader.graph_indexes, loader.num_nodes,
|
| 412 |
+
max_samples=num_subgraphs * 5, disable_tqdm=True,
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
subgraphs = []
|
| 416 |
+
for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples:
|
| 417 |
+
if len(subgraphs) >= num_subgraphs:
|
| 418 |
+
break
|
| 419 |
+
if len(edges) < 3:
|
| 420 |
+
continue
|
| 421 |
+
|
| 422 |
+
if subgraph_row == subgraph_col:
|
| 423 |
+
sg_nodes = nodes_row
|
| 424 |
+
else:
|
| 425 |
+
sg_nodes = nodes_row + nodes_col
|
| 426 |
+
|
| 427 |
+
node_idx = {n: i for i, n in enumerate(sg_nodes)}
|
| 428 |
+
|
| 429 |
+
nodes = []
|
| 430 |
+
for n in sg_nodes:
|
| 431 |
+
type_id = int(node_types[n]) if n < len(node_types) else 0
|
| 432 |
+
nodes.append({
|
| 433 |
+
"entity_id": n,
|
| 434 |
+
"entity_name": str(inv_nodes.get(n, n)),
|
| 435 |
+
"type_id": type_id,
|
| 436 |
+
})
|
| 437 |
+
|
| 438 |
+
edge_list = []
|
| 439 |
+
for h, r, t in edges:
|
| 440 |
+
if h in node_idx and t in node_idx:
|
| 441 |
+
edge_list.append({
|
| 442 |
+
"source_idx": node_idx[h],
|
| 443 |
+
"target_idx": node_idx[t],
|
| 444 |
+
"relation_id": r,
|
| 445 |
+
"relation_name": str(inv_relations.get(r, r)),
|
| 446 |
+
"entity_name_source": str(inv_nodes.get(h, h)),
|
| 447 |
+
"entity_name_target": str(inv_nodes.get(t, t)),
|
| 448 |
+
})
|
| 449 |
+
|
| 450 |
+
subgraphs.append({
|
| 451 |
+
"id": f"sample_{len(subgraphs) + 1}",
|
| 452 |
+
"num_nodes": len(nodes),
|
| 453 |
+
"num_edges": len(edge_list),
|
| 454 |
+
"nodes": nodes,
|
| 455 |
+
"edges": edge_list,
|
| 456 |
+
})
|
| 457 |
+
|
| 458 |
+
# Free the partitioning data stored on the sampler
|
| 459 |
+
loader.sampler.context_subgraphs_nodes = None
|
| 460 |
+
loader.sampler.context_subgraphs_edges = None
|
| 461 |
+
|
| 462 |
+
return subgraphs
|
| 463 |
+
|
| 464 |
+
# ---- Status --------------------------------------------------------
|
| 465 |
+
|
| 466 |
def is_coins_loaded(self):
    """Report whether any COINs checkpoints were discovered on disk.

    Returns:
        bool: True when ``self.coins_checkpoints_available`` is truthy
        (i.e. at least one dataset has a discovered checkpoint).
    """
    checkpoints = self.coins_checkpoints_available
    return True if checkpoints else False
|
| 468 |
|
src/backend/api/urls.py
CHANGED
|
@@ -1,8 +1,33 @@
|
|
| 1 |
from django.urls import path
|
| 2 |
|
| 3 |
-
from api.views.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
urlpatterns = [
|
|
|
|
| 6 |
path("", ApiRootView.as_view()),
|
| 7 |
path("health", HealthView.as_view()),
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
]
|
|
|
|
| 1 |
"""URL routing for the research API: health/discovery, COINs, graph generation, KG anomaly."""
from django.urls import path

from api.views.coins import (
    CoinsDatasetsView,
    CoinsEntitiesView,
    CoinsModelsView,
    CoinsQueryStructuresView,
    CoinsRelationsView,
    CoinsSampleTriplesView,
)
from api.views.graph_generation import GraphGenDatasetsView, GraphGenSamplingModesView
from api.views.health import ApiRootView, HealthView, MethodsView
from api.views.kg_anomaly import KgAnomalyDatasetsView, KgAnomalySampleSubgraphsView

urlpatterns = [
    # Health & discovery
    path("", ApiRootView.as_view()),
    path("health", HealthView.as_view()),
    path("methods", MethodsView.as_view()),
    # COINs
    path("coins/datasets", CoinsDatasetsView.as_view()),
    path("coins/datasets/<str:dataset_id>/entities", CoinsEntitiesView.as_view()),
    path("coins/datasets/<str:dataset_id>/relations", CoinsRelationsView.as_view()),
    path("coins/datasets/<str:dataset_id>/sample-triples", CoinsSampleTriplesView.as_view()),
    path("coins/models", CoinsModelsView.as_view()),
    path("coins/query-structures", CoinsQueryStructuresView.as_view()),
    # Graph generation
    path("graph-generation/datasets", GraphGenDatasetsView.as_view()),
    path("graph-generation/sampling-modes", GraphGenSamplingModesView.as_view()),
    # KG anomaly
    path("kg-anomaly/datasets", KgAnomalyDatasetsView.as_view()),
    path("kg-anomaly/datasets/<str:dataset_id>/sample-subgraphs", KgAnomalySampleSubgraphsView.as_view()),
]
|
src/backend/api/views/coins.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rest_framework.response import Response
|
| 2 |
+
from rest_framework.views import APIView
|
| 3 |
+
|
| 4 |
+
from api.exceptions import NotFoundError
|
| 5 |
+
from api.services.constants import COINS_DATASET_META, COINS_MODELS, QUERY_STRUCTURES
|
| 6 |
+
from api.services.registry import ModelRegistry
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _require_loader(dataset_id):
    """Return the singleton registry once *dataset_id* is validated.

    Raises:
        NotFoundError: when the dataset id is not a known COINs dataset,
            or when its Loader has not been loaded into the registry.
    """
    if dataset_id not in COINS_DATASET_META:
        raise NotFoundError(f"Dataset '{dataset_id}' not found")
    registry = ModelRegistry.get()
    loader = registry.get_loader(dataset_id)
    if loader is None:
        raise NotFoundError(f"Dataset '{dataset_id}' data not loaded")
    return registry
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
class CoinsDatasetsView(APIView):
    """Discovery endpoint listing every COINs dataset with basic statistics."""

    def get(self, request):
        """Return all datasets with entity/relation counts from the registry."""
        registry = ModelRegistry.get()
        datasets = [
            {
                "id": dataset_id,
                "name": meta["name"],
                "num_entities": registry.get_entity_count(dataset_id),
                "num_relations": registry.get_relation_count(dataset_id),
                "description": meta["description"],
            }
            for dataset_id, meta in COINS_DATASET_META.items()
        ]
        return Response({"datasets": datasets})
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
class CoinsEntitiesView(APIView):
    """Paginated (optionally filtered) listing of a COINs dataset's entities."""

    def get(self, request, dataset_id):
        """Return one page of entities.

        Query params:
            q: optional search string forwarded to the registry.
            page: 1-based page number; malformed or < 1 falls back to 1.
            page_size: items per page, clamped to [1, 200]; malformed falls back to 50.
        """
        registry = _require_loader(dataset_id)
        q = request.query_params.get("q", None)
        # Parse defensively: a malformed ?page= / ?page_size= must not surface as a 500.
        try:
            page = int(request.query_params.get("page", 1))
        except (TypeError, ValueError):
            page = 1
        page = max(1, page)
        try:
            page_size = int(request.query_params.get("page_size", 50))
        except (TypeError, ValueError):
            page_size = 50
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_entities(dataset_id, q, page, page_size)

        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "entities": [{"id": eid, "name": name} for eid, name in page_items],
        })
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class CoinsRelationsView(APIView):
    """Paginated (optionally filtered) listing of a COINs dataset's relations."""

    def get(self, request, dataset_id):
        """Return one page of relations.

        Query params:
            q: optional search string forwarded to the registry.
            page: 1-based page number; malformed or < 1 falls back to 1.
            page_size: items per page, clamped to [1, 200]; malformed falls back to 50.
        """
        registry = _require_loader(dataset_id)
        q = request.query_params.get("q", None)
        # Parse defensively: a malformed ?page= / ?page_size= must not surface as a 500.
        try:
            page = int(request.query_params.get("page", 1))
        except (TypeError, ValueError):
            page = 1
        page = max(1, page)
        try:
            page_size = int(request.query_params.get("page_size", 50))
        except (TypeError, ValueError):
            page_size = 50
        page_size = max(1, min(200, page_size))

        page_items, total = registry.search_relations(dataset_id, q, page, page_size)

        return Response({
            "dataset_id": dataset_id,
            "total": total,
            "page": page,
            "page_size": page_size,
            "relations": [{"id": rid, "name": name} for rid, name in page_items],
        })
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class CoinsSampleTriplesView(APIView):
    """Return a random sample of triples from a COINs dataset."""

    def get(self, request, dataset_id):
        """Sample triples; ``count`` is clamped to [1, 50], default 10.

        A malformed ?count= falls back to the default instead of raising 500.
        """
        registry = _require_loader(dataset_id)
        try:
            count = int(request.query_params.get("count", 10))
        except (TypeError, ValueError):
            count = 10
        count = max(1, min(50, count))

        return Response({
            "dataset_id": dataset_id,
            "triples": registry.sample_triples(dataset_id, count),
        })
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
class CoinsModelsView(APIView):
    """Describe COINs query-answering models and where each has checkpoints."""

    def get(self, request):
        """List every model with the datasets whose checkpoints include its algorithm."""
        registry = ModelRegistry.get()
        models = []
        for model in COINS_MODELS:
            # A dataset is usable for this model iff the model's algorithm
            # appears among that dataset's discovered checkpoints.
            available_datasets = [
                dataset_id
                for dataset_id in COINS_DATASET_META
                if model["algorithm"] in registry.coins_checkpoints_available.get(dataset_id, [])
            ]
            models.append({
                "algorithm": model["algorithm"],
                "name": model["name"],
                "description": model["description"],
                "supported_query_structures": model["supported_query_structures"],
                "available_datasets": available_datasets,
            })
        return Response({"models": models})
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class CoinsQueryStructuresView(APIView):
    """Static listing of the complex-query structures supported by COINs."""

    def get(self, request):
        """Return the constant catalogue of query structures."""
        return Response({"query_structures": QUERY_STRUCTURES})
|
src/backend/api/views/graph_generation.py
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rest_framework.response import Response
|
| 2 |
+
from rest_framework.views import APIView
|
| 3 |
+
|
| 4 |
+
from api.services.constants import GRAPHGEN_DATASETS, GRAPHGEN_SAMPLING_MODES
|
| 5 |
+
from api.services.registry import ModelRegistry
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class GraphGenDatasetsView(APIView):
    """List graph-generation datasets plus which model types have checkpoints."""

    def get(self, request):
        """Return dataset metadata merged with per-dataset checkpoint availability."""
        registry = ModelRegistry.get()
        datasets = [
            {
                "id": dataset_id,
                "name": meta["name"],
                "type": meta["type"],
                "description": meta["description"],
                "node_types": meta["node_types"],
                "edge_types": meta["edge_types"],
                "max_nodes": meta["max_nodes"],
                "available_model_types": registry.graphgen_checkpoints_available.get(dataset_id, []),
            }
            for dataset_id, meta in GRAPHGEN_DATASETS.items()
        ]
        return Response({"datasets": datasets})
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class GraphGenSamplingModesView(APIView):
    """Static listing of the supported graph-generation sampling modes."""

    def get(self, request):
        """Return the constant catalogue of sampling modes."""
        return Response({"sampling_modes": GRAPHGEN_SAMPLING_MODES})
|
src/backend/api/views/health.py
CHANGED
|
@@ -1,7 +1,7 @@
|
|
| 1 |
from rest_framework.response import Response
|
| 2 |
-
from rest_framework.reverse import reverse
|
| 3 |
from rest_framework.views import APIView
|
| 4 |
|
|
|
|
| 5 |
from api.services.registry import ModelRegistry
|
| 6 |
|
| 7 |
|
|
@@ -50,3 +50,8 @@ class HealthView(APIView):
|
|
| 50 |
"kg_anomaly": registry.is_kg_anomaly_loaded(),
|
| 51 |
},
|
| 52 |
})
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
from rest_framework.response import Response
|
|
|
|
| 2 |
from rest_framework.views import APIView
|
| 3 |
|
| 4 |
+
from api.services.constants import METHODS
|
| 5 |
from api.services.registry import ModelRegistry
|
| 6 |
|
| 7 |
|
|
|
|
| 50 |
"kg_anomaly": registry.is_kg_anomaly_loaded(),
|
| 51 |
},
|
| 52 |
})
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class MethodsView(APIView):
    """Discovery endpoint returning the static catalogue of research methods."""

    def get(self, request):
        """Return the METHODS constant as-is."""
        return Response({"methods": METHODS})
|
src/backend/api/views/kg_anomaly.py
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from rest_framework.response import Response
|
| 2 |
+
from rest_framework.views import APIView
|
| 3 |
+
|
| 4 |
+
from api.exceptions import NotFoundError
|
| 5 |
+
from api.services.constants import KG_ANOMALY_DATASET_META
|
| 6 |
+
from api.services.registry import ModelRegistry
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class KgAnomalyDatasetsView(APIView):
    """List KG-anomaly datasets together with their available tasks."""

    def get(self, request):
        """Return dataset metadata merged with discovered checkpoint tasks."""
        registry = ModelRegistry.get()
        datasets = [
            {
                "id": dataset_id,
                "name": meta["name"],
                "description": meta["description"],
                "available_tasks": registry.kg_anomaly_checkpoints_available.get(dataset_id, []),
            }
            for dataset_id, meta in KG_ANOMALY_DATASET_META.items()
        ]
        return Response({"datasets": datasets})
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class KgAnomalySampleSubgraphsView(APIView):
    """Serve precomputed sample subgraphs for a KG-anomaly dataset."""

    def get(self, request, dataset_id):
        """Return up to ``count`` sample subgraphs (clamped to [1, 10], default 5).

        Raises:
            NotFoundError: unknown dataset id, or no subgraphs precomputed for it.
        """
        if dataset_id not in KG_ANOMALY_DATASET_META:
            raise NotFoundError(f"Dataset '{dataset_id}' not found")

        registry = ModelRegistry.get()
        sg_info = registry.kg_anomaly_subgraphs.get(dataset_id)
        if sg_info is None:
            raise NotFoundError(f"No sample subgraphs available for dataset '{dataset_id}'")

        # Parse defensively: a malformed ?count= must not surface as a 500.
        try:
            count = int(request.query_params.get("count", 5))
        except (TypeError, ValueError):
            count = 5
        count = max(1, min(10, count))

        return Response({
            "dataset_id": dataset_id,
            "subgraphs": sg_info.subgraphs[:count],
        })
|
src/backend/requirements.txt
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Django
|
| 2 |
+
django==4.2.*
|
| 3 |
+
djangorestframework==3.14.*
|
| 4 |
+
django-cors-headers==4.*
|
| 5 |
+
|
| 6 |
+
# Checkpoint download from Google Drive
|
| 7 |
+
gdown>=4.7
|
| 8 |
+
|
| 9 |
+
# PyTorch with CUDA 11.8 (falls back to CPU at runtime if no GPU present)
|
| 10 |
+
--extra-index-url https://download.pytorch.org/whl/cu118
|
| 11 |
+
torch==2.0.1+cu118
|
| 12 |
+
torchvision==0.15.2+cu118
|
| 13 |
+
torchaudio==2.0.2+cu118
|
| 14 |
+
|
| 15 |
+
# PyTorch Geometric (must match torch version)
|
| 16 |
+
torch-geometric==2.3.1
|
| 17 |
+
|
| 18 |
+
# Research shared deps
|
| 19 |
+
pytorch-lightning==2.0.4
|
| 20 |
+
hydra-core==1.3.2
|
| 21 |
+
omegaconf==2.3.0
|
| 22 |
+
numpy==1.23
|
| 23 |
+
pandas==1.4
|
| 24 |
+
scipy==1.11.0
|
| 25 |
+
igraph==0.9.11
|
| 26 |
+
PyYAML==6.0
|
| 27 |
+
networkx==2.8.7
|
| 28 |
+
matplotlib==3.7.1
|
| 29 |
+
imageio==2.31.1
|
| 30 |
+
torchmetrics==0.11.4
|
| 31 |
+
tqdm==4.65.0
|
| 32 |
+
scikit-learn>=1.0
|
| 33 |
+
|
| 34 |
+
# Conda-only deps (must be pre-installed, not in pip requirements):
|
| 35 |
+
# rdkit==2023.03.2 (conda install -c conda-forge)
|
| 36 |
+
# graph-tool==2.45 (conda install -c conda-forge)
|
src/backend/research_api/settings.py
CHANGED
|
@@ -1,9 +1,17 @@
|
|
| 1 |
import os
|
|
|
|
| 2 |
from pathlib import Path
|
| 3 |
|
| 4 |
BASE_DIR = Path(__file__).resolve().parent.parent
|
| 5 |
PROJECT_ROOT = BASE_DIR.parent.parent # Website root
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
SECRET_KEY = os.environ.get("DJANGO_SECRET_KEY", "dev-insecure-key-change-in-production")
|
| 8 |
DEBUG = os.environ.get("DJANGO_DEBUG", "True").lower() in ("true", "1")
|
| 9 |
ALLOWED_HOSTS = os.environ.get("DJANGO_ALLOWED_HOSTS", "localhost,127.0.0.1").split(",")
|
|
|
|
| 1 |
"""Django settings for the research_api backend (env-driven, dev defaults)."""
import os
import sys
from pathlib import Path

# src/backend/ — the Django project directory.
BASE_DIR = Path(__file__).resolve().parent.parent
PROJECT_ROOT = BASE_DIR.parent.parent  # Website root

# Add research repos to sys.path so their modules can be imported
_COINS_KG_ROOT = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration")
_MULTIPROXAN_ROOT = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn")
for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT):
    if _path not in sys.path:
        # Prepend so the research packages take precedence over any
        # same-named modules already importable.
        sys.path.insert(0, _path)

# All three settings are environment-overridable; defaults are dev-only.
SECRET_KEY = os.environ.get("DJANGO_SECRET_KEY", "dev-insecure-key-change-in-production")
DEBUG = os.environ.get("DJANGO_DEBUG", "True").lower() in ("true", "1")
ALLOWED_HOSTS = os.environ.get("DJANGO_ALLOWED_HOSTS", "localhost,127.0.0.1").split(",")
|
src/research/COINs-KGGeneration/graph_completion/graphs/load_graph.py
CHANGED
|
@@ -23,14 +23,14 @@ from graph_completion.graphs.queries import balanced_partition, balanced_partiti
|
|
| 23 |
from graph_completion.utils import AbstractConf
|
| 24 |
from graph_data.load_data import CoDExL, Dataset, FreeBase, HighEnergyPhysicsCitations, NELL, \
|
| 25 |
OGBBioKG, OGBCitation2, OGBLSCWikiKG90Mv2, \
|
| 26 |
-
PatentCitations, RedditHyperlinks,
|
| 27 |
from graph_data.serialization import load_object, save_object
|
| 28 |
|
| 29 |
|
| 30 |
class Loader:
|
| 31 |
datasets = {"reddit": RedditHyperlinks(), "hep": HighEnergyPhysicsCitations(), "patent": PatentCitations(),
|
| 32 |
"ogbl-citation2": OGBCitation2(), "ogbl-biokg": OGBBioKG(), "ogb-lsc-wikikg90m2": OGBLSCWikiKG90Mv2(),
|
| 33 |
-
"
|
| 34 |
"codex-l": CoDExL(), "yago3-10": YAGO310()}
|
| 35 |
transport_cities = [path.split(os_path_sep)[-1] for path in glob("data/transport/*") if "." not in path[2:]]
|
| 36 |
for city in transport_cities:
|
|
@@ -74,7 +74,7 @@ class Loader:
|
|
| 74 |
self.inter_community_map: np.ndarray = None
|
| 75 |
self.com_neighbours: Optional[np.ndarray] = None
|
| 76 |
self.node_neighbours: Optional[np.ndarray] = None
|
| 77 |
-
self.num_machines = min(num_machines, device_count())
|
| 78 |
self.machines: np.ndarray = None
|
| 79 |
|
| 80 |
self.graph_indexes: Iterable[AdjacencyIndex] = None
|
|
@@ -96,37 +96,8 @@ class Loader:
|
|
| 96 |
self.leiden_resolution = leiden_resolution
|
| 97 |
|
| 98 |
self.dataset = Loader.datasets[self.dataset_name]
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
self.dataset.node_names_map = pd.read_csv(
|
| 102 |
-
f"data/swisscom/node_names_orgs.csv",
|
| 103 |
-
header=0, index_col=0, squeeze=True, encoding="utf-8"
|
| 104 |
-
)
|
| 105 |
-
self.dataset.node_data = pd.read_csv(
|
| 106 |
-
f"data/swisscom/nodes_orgs.csv",
|
| 107 |
-
header=0, index_col=None, parse_dates=["time", ], encoding="utf-8"
|
| 108 |
-
)
|
| 109 |
-
self.dataset.edge_data = pd.read_csv(
|
| 110 |
-
f"data/swisscom/relations_orgs.csv",
|
| 111 |
-
header=0, index_col=None, parse_dates=["time", ], encoding="utf-8"
|
| 112 |
-
)
|
| 113 |
-
elif self.dataset_name == "swisscom":
|
| 114 |
-
entrypoints = pd.concat([
|
| 115 |
-
pd.read_csv(f"{customer_orgs_file}", header=0, squeeze=True, encoding="utf-8")
|
| 116 |
-
for customer_orgs_file in glob("data/swisscom/customer_orgs_*")
|
| 117 |
-
], ignore_index=True).drop_duplicates().sort_values(ignore_index=True)
|
| 118 |
-
swisscom_subgraphs = [Swisscom(entrypoint) for entrypoint in entrypoints]
|
| 119 |
-
self.dataset.load_from_disk(swisscom_subgraphs)
|
| 120 |
-
self.dataset.node_names_map.to_csv(
|
| 121 |
-
f"data/swisscom/node_names_orgs.csv",
|
| 122 |
-
header=True, index=True, encoding="utf-8")
|
| 123 |
-
self.dataset.node_data.to_csv(f"data/swisscom/nodes_orgs.csv",
|
| 124 |
-
header=True, index=False, encoding="utf-8")
|
| 125 |
-
self.dataset.edge_data.to_csv(f"data/swisscom/relations_orgs.csv",
|
| 126 |
-
header=True, index=False, encoding="utf-8")
|
| 127 |
-
else:
|
| 128 |
-
self.dataset.load_from_disk()
|
| 129 |
-
self.dataset.time_sort_and_numerize()
|
| 130 |
self.graph = Graph(self.dataset_name)
|
| 131 |
self.graph.update_graph(self.dataset.node_data, self.dataset.edge_data)
|
| 132 |
print("Computing required metrics...")
|
|
|
|
| 23 |
from graph_completion.utils import AbstractConf
|
| 24 |
from graph_data.load_data import CoDExL, Dataset, FreeBase, HighEnergyPhysicsCitations, NELL, \
|
| 25 |
OGBBioKG, OGBCitation2, OGBLSCWikiKG90Mv2, \
|
| 26 |
+
PatentCitations, RedditHyperlinks, Transport, WordNet, YAGO310
|
| 27 |
from graph_data.serialization import load_object, save_object
|
| 28 |
|
| 29 |
|
| 30 |
class Loader:
|
| 31 |
datasets = {"reddit": RedditHyperlinks(), "hep": HighEnergyPhysicsCitations(), "patent": PatentCitations(),
|
| 32 |
"ogbl-citation2": OGBCitation2(), "ogbl-biokg": OGBBioKG(), "ogb-lsc-wikikg90m2": OGBLSCWikiKG90Mv2(),
|
| 33 |
+
"freebase": FreeBase(), "wordnet": WordNet(), "nell": NELL(),
|
| 34 |
"codex-l": CoDExL(), "yago3-10": YAGO310()}
|
| 35 |
transport_cities = [path.split(os_path_sep)[-1] for path in glob("data/transport/*") if "." not in path[2:]]
|
| 36 |
for city in transport_cities:
|
|
|
|
| 74 |
self.inter_community_map: np.ndarray = None
|
| 75 |
self.com_neighbours: Optional[np.ndarray] = None
|
| 76 |
self.node_neighbours: Optional[np.ndarray] = None
|
| 77 |
+
self.num_machines = max(1, min(num_machines, device_count() or 1))
|
| 78 |
self.machines: np.ndarray = None
|
| 79 |
|
| 80 |
self.graph_indexes: Iterable[AdjacencyIndex] = None
|
|
|
|
| 96 |
self.leiden_resolution = leiden_resolution
|
| 97 |
|
| 98 |
self.dataset = Loader.datasets[self.dataset_name]
|
| 99 |
+
self.dataset.load_from_disk()
|
| 100 |
+
self.dataset.time_sort_and_numerize()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 101 |
self.graph = Graph(self.dataset_name)
|
| 102 |
self.graph.update_graph(self.dataset.node_data, self.dataset.edge_data)
|
| 103 |
print("Computing required metrics...")
|
src/research/COINs-KGGeneration/graph_completion/graphs/preprocess.py
CHANGED
|
@@ -741,15 +741,21 @@ class Sampler:
|
|
| 741 |
|
| 742 |
def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
|
| 743 |
num_nodes: int, allow_disc: bool = False,
|
|
|
|
| 744 |
disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
|
| 745 |
_, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
|
| 746 |
assignment = -np.ones(num_nodes, dtype=int)
|
| 747 |
subgraph = 0
|
| 748 |
subgraph_size = 0
|
|
|
|
|
|
|
|
|
|
| 749 |
progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
|
| 750 |
disable=disable_tqdm)
|
| 751 |
|
| 752 |
for i in range(num_nodes):
|
|
|
|
|
|
|
| 753 |
if assignment[i] >= 0:
|
| 754 |
continue
|
| 755 |
stack = [i, ]
|
|
@@ -759,6 +765,7 @@ class Sampler:
|
|
| 759 |
continue
|
| 760 |
assignment[v] = subgraph
|
| 761 |
subgraph_size += 1
|
|
|
|
| 762 |
if subgraph_size == max_graph_size:
|
| 763 |
subgraph_size = 0
|
| 764 |
subgraph += 1
|
|
@@ -781,11 +788,16 @@ class Sampler:
|
|
| 781 |
progress_bar.close()
|
| 782 |
subgraphs_nodes = [[] for _ in range(subgraph)]
|
| 783 |
subgraphs_edges = [[[] for _ in range(subgraph)] for _ in range(subgraph)]
|
| 784 |
-
|
|
|
|
| 785 |
subgraphs_nodes[assignment[v]].append(v)
|
| 786 |
for s, n_s in tqdm(adj_s_to_t.items(), desc="Assigning edges to subgraphs", leave=False, disable=disable_tqdm):
|
|
|
|
|
|
|
| 787 |
for r, n_s_r in n_s.items():
|
| 788 |
for t in n_s_r:
|
|
|
|
|
|
|
| 789 |
subgraphs_edges[assignment[s]][assignment[t]].append((s, r, t))
|
| 790 |
samples = [(k, l, subgraphs_nodes[k], subgraphs_nodes[l], subgraphs_edges[k][l])
|
| 791 |
for k in range(subgraph) for l in range(subgraph) if len(subgraphs_edges[k][l]) >= 5]
|
|
|
|
| 741 |
|
| 742 |
def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
|
| 743 |
num_nodes: int, allow_disc: bool = False,
|
| 744 |
+
max_samples: int = 0,
|
| 745 |
disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
|
| 746 |
_, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
|
| 747 |
assignment = -np.ones(num_nodes, dtype=int)
|
| 748 |
subgraph = 0
|
| 749 |
subgraph_size = 0
|
| 750 |
+
nodes_assigned = 0
|
| 751 |
+
# If max_samples is set, only partition enough nodes to produce that many subgraphs
|
| 752 |
+
max_subgraphs = max_samples * 2 if max_samples > 0 else 0
|
| 753 |
progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
|
| 754 |
disable=disable_tqdm)
|
| 755 |
|
| 756 |
for i in range(num_nodes):
|
| 757 |
+
if max_subgraphs > 0 and subgraph >= max_subgraphs:
|
| 758 |
+
break
|
| 759 |
if assignment[i] >= 0:
|
| 760 |
continue
|
| 761 |
stack = [i, ]
|
|
|
|
| 765 |
continue
|
| 766 |
assignment[v] = subgraph
|
| 767 |
subgraph_size += 1
|
| 768 |
+
nodes_assigned += 1
|
| 769 |
if subgraph_size == max_graph_size:
|
| 770 |
subgraph_size = 0
|
| 771 |
subgraph += 1
|
|
|
|
| 788 |
progress_bar.close()
|
| 789 |
subgraphs_nodes = [[] for _ in range(subgraph)]
|
| 790 |
subgraphs_edges = [[[] for _ in range(subgraph)] for _ in range(subgraph)]
|
| 791 |
+
assigned_nodes = set(np.where(assignment >= 0)[0])
|
| 792 |
+
for v in tqdm(assigned_nodes, desc="Assigning nodes to subgraphs", leave=False, disable=disable_tqdm):
|
| 793 |
subgraphs_nodes[assignment[v]].append(v)
|
| 794 |
for s, n_s in tqdm(adj_s_to_t.items(), desc="Assigning edges to subgraphs", leave=False, disable=disable_tqdm):
|
| 795 |
+
if assignment[s] < 0:
|
| 796 |
+
continue
|
| 797 |
for r, n_s_r in n_s.items():
|
| 798 |
for t in n_s_r:
|
| 799 |
+
if assignment[t] < 0:
|
| 800 |
+
continue
|
| 801 |
subgraphs_edges[assignment[s]][assignment[t]].append((s, r, t))
|
| 802 |
samples = [(k, l, subgraphs_nodes[k], subgraphs_nodes[l], subgraphs_edges[k][l])
|
| 803 |
for k in range(subgraph) for l in range(subgraph) if len(subgraphs_edges[k][l]) >= 5]
|
src/research/COINs-KGGeneration/graph_data/load_data.py
CHANGED
|
@@ -327,7 +327,7 @@ class Transport(Dataset):
|
|
| 327 |
|
| 328 |
class FreeBase(Dataset):
|
| 329 |
def __init__(self):
|
| 330 |
-
super().__init__("fb15k-237", "data/
|
| 331 |
|
| 332 |
def load_from_disk(self, *args, **kwargs):
|
| 333 |
train_edges = pd.read_csv(f"{self.file_location}/train.txt",
|
|
@@ -366,7 +366,7 @@ class FreeBase(Dataset):
|
|
| 366 |
|
| 367 |
class WordNet(Dataset):
|
| 368 |
def __init__(self):
|
| 369 |
-
super().__init__("wn18rr", "data/
|
| 370 |
|
| 371 |
def load_from_disk(self, *args, **kwargs):
|
| 372 |
train_edges = pd.read_csv(f"{self.file_location}/train.txt",
|
|
@@ -405,7 +405,7 @@ class WordNet(Dataset):
|
|
| 405 |
|
| 406 |
class NELL(Dataset):
|
| 407 |
def __init__(self):
|
| 408 |
-
super().__init__("nell-995", "data/
|
| 409 |
|
| 410 |
def load_from_disk(self, *args, **kwargs):
|
| 411 |
edge_data = pd.read_csv(f"{self.file_location}/raw.kb",
|
|
@@ -525,115 +525,6 @@ class YAGO310(Dataset):
|
|
| 525 |
return self.clean_split(train_edges), self.clean_split(valid_edges), self.clean_split(test_edges)
|
| 526 |
|
| 527 |
|
| 528 |
-
swisscom_node_types_map = pd.read_csv(f"data/swisscom_node_types.csv",
|
| 529 |
-
header=None, index_col=None, encoding="utf-8",
|
| 530 |
-
squeeze=True).reset_index(name="type").set_index("type")["index"]
|
| 531 |
-
swisscom_relation_types_map = pd.read_csv(f"data/swisscom_relation_types.csv",
|
| 532 |
-
header=None, index_col=None, encoding="utf-8",
|
| 533 |
-
squeeze=True).reset_index(name="r").set_index("r")["index"]
|
| 534 |
-
|
| 535 |
-
|
| 536 |
-
class Swisscom(Dataset):
|
| 537 |
-
def __init__(self, customer_org: str):
|
| 538 |
-
super().__init__(f"swisscom-{customer_org}", f"data/swisscom/{customer_org}")
|
| 539 |
-
self.customer_org = customer_org
|
| 540 |
-
self.node_types_map = swisscom_node_types_map
|
| 541 |
-
self.relation_names_map = swisscom_relation_types_map
|
| 542 |
-
|
| 543 |
-
def load_from_disk(self):
|
| 544 |
-
def compute_final_node_label(row):
|
| 545 |
-
if "Alert" in str(row["labels"]):
|
| 546 |
-
return "Alert"
|
| 547 |
-
elif str(row["labels"]) in ["Base", "-", ""] or pd.isna(row["n.inventory_type"]):
|
| 548 |
-
return "-"
|
| 549 |
-
else:
|
| 550 |
-
return str(row["n.inventory_type"]).capitalize()
|
| 551 |
-
|
| 552 |
-
node_data = pd.read_csv(glob(f"{self.file_location}/nodes_*")[-1],
|
| 553 |
-
sep=",", header=0, index_col=None, encoding="utf-8",
|
| 554 |
-
usecols=["n.id", "labels", "n.inventory_type"])
|
| 555 |
-
node_data = node_data.assign(type=node_data.apply(compute_final_node_label, axis=1), time=pd.NaT)
|
| 556 |
-
self.node_data = node_data.drop(columns=["labels", "n.inventory_type"]).rename(columns={"n.id": "n"})
|
| 557 |
-
del node_data
|
| 558 |
-
|
| 559 |
-
edge_data = pd.read_csv(glob(f"{self.file_location}/relations_*")[-1],
|
| 560 |
-
sep=",", header=0, index_col=None, encoding="utf-8",
|
| 561 |
-
usecols=["r.source_item_id", "relation_type_enriched", "r.target_item_id"])
|
| 562 |
-
self.edge_data = edge_data.rename(columns={"r.source_item_id": "s",
|
| 563 |
-
"relation_type_enriched": "r",
|
| 564 |
-
"r.target_item_id": "t"}).assign(time=pd.NaT).dropna(how="all")
|
| 565 |
-
del edge_data
|
| 566 |
-
|
| 567 |
-
|
| 568 |
-
class SwisscomBig(Dataset):
|
| 569 |
-
def __init__(self):
|
| 570 |
-
super().__init__(f"swisscom-big", f"data/swisscom/swisscom-big")
|
| 571 |
-
self.node_types_map = swisscom_node_types_map
|
| 572 |
-
self.relation_names_map = swisscom_relation_types_map
|
| 573 |
-
|
| 574 |
-
def load_from_disk(self, subgraphs: List[Swisscom]):
|
| 575 |
-
self.node_data = pd.DataFrame(columns=["n", "type", "time"]).astype(
|
| 576 |
-
{"n": "int", "type": "int", "time": "datetime64[ns]"}
|
| 577 |
-
)
|
| 578 |
-
self.edge_data = pd.DataFrame(columns=["s", "r", "t", "time"]).astype(
|
| 579 |
-
{"s": "int", "r": "int", "t": "int", "time": "datetime64[ns]"}
|
| 580 |
-
)
|
| 581 |
-
self.node_names_map = pd.Series(name="index")
|
| 582 |
-
self.node_names_map.index.name = "n"
|
| 583 |
-
|
| 584 |
-
for dataset in tqdm(subgraphs, desc="Merging subgraphs into the big graph", leave=False):
|
| 585 |
-
dataset.load_from_disk()
|
| 586 |
-
is_new_node = ~dataset.node_data.n.isin(self.node_names_map.index)
|
| 587 |
-
new_nodes = dataset.node_data.loc[is_new_node]
|
| 588 |
-
self.node_names_map = self.node_names_map.append(
|
| 589 |
-
len(self.node_names_map)
|
| 590 |
-
+ new_nodes.n.reset_index(name="n", drop=True).reset_index().set_index("n")["index"]
|
| 591 |
-
)
|
| 592 |
-
dataset.node_data.loc[is_new_node, "n"] = new_nodes.n.map(self.node_names_map)
|
| 593 |
-
dataset.node_data.loc[is_new_node, "type"] = new_nodes.type.map(self.node_types_map)
|
| 594 |
-
self.node_data = pd.concat((self.node_data, dataset.node_data.loc[is_new_node]), ignore_index=True)
|
| 595 |
-
new_edges = dataset.edge_data
|
| 596 |
-
new_edges.s = new_edges.s.map(self.node_names_map.rename("s"))
|
| 597 |
-
new_edges.r = new_edges.r.map(self.relation_names_map)
|
| 598 |
-
new_edges.t = new_edges.t.map(self.node_names_map.rename("t"))
|
| 599 |
-
if len(self.edge_data) > 0:
|
| 600 |
-
new_edges = new_edges.loc[~new_edges.s.isin(self.edge_data.s)
|
| 601 |
-
| ~new_edges.t.isin(self.edge_data.t)
|
| 602 |
-
| ~new_edges.r.isin(self.edge_data.r)]
|
| 603 |
-
self.edge_data = pd.concat((self.edge_data, new_edges), ignore_index=True)
|
| 604 |
-
dataset.unload_from_memory()
|
| 605 |
-
|
| 606 |
-
|
| 607 |
-
class SwisscomCommunity(Dataset):
|
| 608 |
-
def __init__(self, community_id: int):
|
| 609 |
-
self.community_id = community_id
|
| 610 |
-
super().__init__(f"swisscom-community-{community_id}", f"data/swisscom/community-{community_id}")
|
| 611 |
-
self.node_types_map = swisscom_node_types_map
|
| 612 |
-
self.relation_names_map = swisscom_relation_types_map
|
| 613 |
-
|
| 614 |
-
def load_from_disk(self):
|
| 615 |
-
return
|
| 616 |
-
|
| 617 |
-
|
| 618 |
-
class SwisscomSubgraphOverlap(Dataset):
|
| 619 |
-
def __init__(self, customer_org: str, customer_org_2: str):
|
| 620 |
-
self.customer_org = customer_org
|
| 621 |
-
self.customer_org_2 = customer_org_2
|
| 622 |
-
super().__init__(f"swisscom-subgraph-overlap-{customer_org}-{customer_org_2}",
|
| 623 |
-
f"data/swisscom/subgraph-overlap-{customer_org}-{customer_org_2}")
|
| 624 |
-
self.node_types_map = swisscom_node_types_map
|
| 625 |
-
self.relation_names_map = swisscom_relation_types_map
|
| 626 |
-
|
| 627 |
-
def load_from_disk(self, subgraph_1: Swisscom, subgraph_2: Swisscom):
|
| 628 |
-
subgraph_1.load_from_disk()
|
| 629 |
-
subgraph_2.load_from_disk()
|
| 630 |
-
self.node_data = subgraph_1.node_data.merge(subgraph_2.node_data, how="inner", on=["n", "type"])
|
| 631 |
-
self.node_data.drop(columns="time_y", inplace=True)
|
| 632 |
-
self.node_data.rename(columns={"time_x": "time"}, inplace=True)
|
| 633 |
-
|
| 634 |
-
self.edge_data = subgraph_1.edge_data.merge(subgraph_2.edge_data, how="inner", on=["s", "r", "t"])
|
| 635 |
-
self.edge_data.drop(columns="time_y", inplace=True)
|
| 636 |
-
self.edge_data.rename(columns={"time_x": "time"}, inplace=True)
|
| 637 |
subgraph_1.unload_from_memory()
|
| 638 |
subgraph_2.unload_from_memory()
|
| 639 |
|
|
|
|
| 327 |
|
| 328 |
class FreeBase(Dataset):
|
| 329 |
def __init__(self):
|
| 330 |
+
super().__init__("fb15k-237", "data/FB15k-237")
|
| 331 |
|
| 332 |
def load_from_disk(self, *args, **kwargs):
|
| 333 |
train_edges = pd.read_csv(f"{self.file_location}/train.txt",
|
|
|
|
| 366 |
|
| 367 |
class WordNet(Dataset):
|
| 368 |
def __init__(self):
|
| 369 |
+
super().__init__("wn18rr", "data/WN18RR")
|
| 370 |
|
| 371 |
def load_from_disk(self, *args, **kwargs):
|
| 372 |
train_edges = pd.read_csv(f"{self.file_location}/train.txt",
|
|
|
|
| 405 |
|
| 406 |
class NELL(Dataset):
|
| 407 |
def __init__(self):
|
| 408 |
+
super().__init__("nell-995", "data/NELL-995")
|
| 409 |
|
| 410 |
def load_from_disk(self, *args, **kwargs):
|
| 411 |
edge_data = pd.read_csv(f"{self.file_location}/raw.kb",
|
|
|
|
| 525 |
return self.clean_split(train_edges), self.clean_split(valid_edges), self.clean_split(test_edges)
|
| 526 |
|
| 527 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 528 |
subgraph_1.unload_from_memory()
|
| 529 |
subgraph_2.unload_from_memory()
|
| 530 |
|