Andrej Janchevski committed
Commit
4f1e196
1 Parent(s): 0b4eacc

Implement backend discovery endpoints

.claude/plans/backend_init.md CHANGED
@@ -10,6 +10,7 @@ Build the Django backend foundation: project scaffolding, dataset/entity/relatio
10
  src/backend/
11
  manage.py
12
  requirements.txt
 
13
  research_api/
14
  __init__.py
15
  settings.py
@@ -51,17 +52,38 @@ The research `Loader` needs: `igraph`, `pandas`, `numpy`, `torch` (for device ha
51
  - Checks `_all_checkpoint_dirs_populated()` first — skips download if all dirs have files
52
  - Uses `gdown.download_folder()` to a `.checkpoint_staging` dir, then `_distribute_checkpoints()` moves files to correct locations
53
  - Gracefully handles failures (logs warning, continues with local files)
54
- 2. **Loads all 3 COINs datasets** using the research `Loader`
55
- 3. **Scans checkpoint directories** for `.tar` / `.ckpt` files to determine availability flags (does NOT load model weights)
56
  - COINs: parses `{dataset}_{algorithm}.tar` filenames
57
  - MultiProxAn: `{dataset}.ckpt` = discrete, `{dataset}_c.ckpt` = continuous
58
  - DiGress KG: `{dataset}.ckpt` = generate, `{dataset}_correct.ckpt` = correct
59
- 4. **Generates sample subgraphs** for KG anomaly using the existing research code:
60
- - `Loader.get_context_subgraph_dataset(max_graph_size, device)` DFS-based partitioning of the full graph into sized subgraphs
61
- - Internally calls `Sampler.get_context_subgraph_samples_dfs()` in `graph_completion/graphs/preprocess.py` (line 742)
62
- - Returns `List[ContextSubgraphData]` (torch_geometric Data objects with `x`, `x_index`, `edge_index`, `edge_attr`, `subgraph_row`, `subgraph_col`)
63
- - `context_subgraphs_to_tensors()` (preprocess.py line 809) converts raw DFS samples to tensors
64
- - For the API sample-subgraphs endpoint: call this once per dataset at startup, store a subset, convert to the API response format (entity names + relation names resolved from the Loader's name maps)
65
 
66
  ### Checkpoint Download (on each boot)
67
 
@@ -87,6 +109,15 @@ Download logic in `registry.py`:
87
  - MultiProxAn (graph generation): glob `{MULTIPROXAN_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_c.ckpt`
88
  - DiGress KG (anomaly correction): glob `{DIGRESS_KG_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_correct.ckpt`
89
 
90
  ## Django Settings
91
 
92
  - `DATABASES = {}` — no database
@@ -205,7 +236,7 @@ cd src/analysis/orca && g++ -O2 -std=c++11 -o orca orca.cpp
205
 
206
  ### Backend requirements.txt
207
 
208
- The backend unifies both into one Python 3.9 environment. Since the COINs `Loader` is the only part reused for init, we pin to the DiGress/MultiProxAn versions and ensure COINs code runs under them. For PythonAnywhere (CPU only), use CPU torch wheels.
209
 
210
  ```
211
  # Django
@@ -216,11 +247,11 @@ django-cors-headers==4.*
216
  # Checkpoint download from Google Drive
217
  gdown>=4.7
218
 
219
- # PyTorch (CPU only for PythonAnywhere)
220
- -f https://download.pytorch.org/whl/cpu/torch_stable.html
221
- torch==2.0.1+cpu
222
- torchvision==0.15.2+cpu
223
- torchaudio==2.0.2+cpu
224
 
225
  # PyTorch Geometric (must match torch version)
226
  torch-geometric==2.3.1
@@ -239,6 +270,7 @@ matplotlib==3.7.1
239
  imageio==2.31.1
240
  torchmetrics==0.11.4
241
  tqdm==4.65.0
 
242
 
243
  # Conda-only deps (must be pre-installed, not in pip requirements):
244
  # rdkit==2023.03.2 (conda create -c conda-forge)
@@ -263,25 +295,29 @@ tqdm==4.65.0
263
  12. Hook `ModelRegistry.initialize()` in `AppConfig.ready()`
264
  13. Add research `Loader` integration to registry for dataset loading + subgraph generation
265
 
266
  ## Verification
267
 
268
  1. `pip install -r requirements.txt` — all dependencies install cleanly into `website_c` env
269
  2. `python manage.py check` — Django system check passes, 0 issues
270
- 3. `python manage.py runserver` — registry downloads checkpoints from Google Drive (or skips if present), scans all 3 checkpoint dirs, logs entity/relation counts
271
- 4. `GET /api/v1/` — returns API name, version, description, full endpoint directory with absolute URLs
272
  5. `GET /api/v1/health` — returns `{"status": "ok", "models_loaded": {"coins": bool, "multiproxan": bool, "kg_anomaly": bool}}`
273
  6. `GET /api/v1/methods` — returns 3 methods
274
  7. `GET /api/v1/coins/datasets` — returns 3 datasets with correct entity/relation counts
275
- 8. `GET /api/v1/coins/datasets/freebase/entities?q=location&page=1&page_size=10` — filtered, paginated results
276
- 9. `GET /api/v1/coins/datasets/freebase/relations?q=country` — substring search on relation names
277
- 10. `GET /api/v1/coins/datasets/freebase/sample-triples?count=5` — 5 random triples with resolved names
278
  11. `GET /api/v1/coins/datasets/unknown/entities` — `{"error": {"code": "NOT_FOUND", ...}}`
279
  12. `GET /api/v1/coins/models` — 6 algorithms, filtered by actually available checkpoints
280
  13. `GET /api/v1/coins/query-structures` — 7 templates (1p, 2p, 3p, 2i, 3i, ip, pi) matching API spec
281
- 14. `GET /api/v1/graph-generation/datasets` — QM9 + Community20 with model type availability from MultiProxAn checkpoint scan
282
  15. `GET /api/v1/graph-generation/sampling-modes` — standard + multiprox with parameter specs
283
- 16. `GET /api/v1/kg-anomaly/datasets` — 3 datasets with availability from DiGress KG checkpoint scan
284
- 17. `GET /api/v1/kg-anomaly/datasets/freebase/sample-subgraphs?count=3` — 3 pre-computed subgraphs with entity/relation names
285
  18. Import `docs/postman/collection.json` into Postman, set environment, run all GET requests against running server
286
 
287
  ## Critical Source Files
 
10
  src/backend/
11
  manage.py
12
  requirements.txt
13
+ README.md
14
  research_api/
15
  __init__.py
16
  settings.py
 
52
  - Checks `_all_checkpoint_dirs_populated()` first — skips download if all dirs have files
53
  - Uses `gdown.download_folder()` to a `.checkpoint_staging` dir, then `_distribute_checkpoints()` moves files to correct locations
54
  - Gracefully handles failures (logs warning, continues with local files)
55
+ 2. **Scans checkpoint directories** for `.tar` / `.ckpt` files to determine availability flags (does NOT load model weights) — see the sketch after this list
 
56
  - COINs: parses `{dataset}_{algorithm}.tar` filenames
57
  - MultiProxAn: `{dataset}.ckpt` = discrete, `{dataset}_c.ckpt` = continuous
58
  - DiGress KG: `{dataset}.ckpt` = generate, `{dataset}_correct.ckpt` = correct
59
+ 3. **Loads one lightweight COINs Loader per dataset** using vanilla seeds, then frees heavy arrays
60
+ 4. **Generates sample subgraphs** for KG anomaly using the Loader's DFS-based context subgraph partitioning, capped via `max_samples` to avoid O(subgraphs²) memory
61
+
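To make the scan in step 2 concrete, here is a minimal sketch of the filename-based availability check, assuming the directory layouts listed above; the function names and returned dict shape are illustrative, not the registry's actual API (that lands in `api/services/registry.py` later in this commit).

```python
from pathlib import Path

def scan_coins_checkpoints(checkpoint_dir):
    """Map dataset -> algorithms from `{dataset}_{algorithm}.tar` filenames."""
    available = {}
    for f in Path(checkpoint_dir).glob("*.tar"):
        dataset, _, algorithm = f.stem.rpartition("_")  # "freebase_rotate" -> ("freebase", "rotate")
        if dataset and algorithm:
            available.setdefault(dataset, []).append(algorithm)
    return available

def scan_digress_kg_checkpoints(checkpoint_dir):
    """`{dataset}.ckpt` = generate, `{dataset}_correct.ckpt` = correct."""
    available = {}
    for f in Path(checkpoint_dir).glob("*.ckpt"):
        if f.stem.endswith("_correct"):
            available.setdefault(f.stem[:-len("_correct")], []).append("correct")
        else:
            available.setdefault(f.stem, []).append("generate")
    return available
```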
62
+ ### Memory Management
63
+
64
+ Each Loader's `load_graph()` produces heavy arrays not needed for discovery endpoints:
65
+ - `node_neighbours`: shape `(num_nodes, num_relations, num_neighbours)` — ~275MB for freebase alone
66
+ - `com_neighbours`, `node_adjacency`, `com_adjacency`: large dicts/arrays
67
+ - `graph` (igraph object), degree arrays, importance arrays, machine assignments
68
+
69
+ After init, `_free_heavy_arrays()` sets all these to `None`. Only dataset name maps, train edge data, graph indexes, and community metadata are kept.
70
+
71
+ The DFS subgraph partitioning (`get_context_subgraph_samples_dfs`) allocates an O(subgraphs²) edge matrix. A `max_samples` parameter caps partitioning to only process enough nodes to produce the needed samples. The sampler's `context_subgraphs_nodes`/`context_subgraphs_edges` are also freed after sample generation.
72
+
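As a hypothetical illustration of the cap (the real implementation is `Sampler.get_context_subgraph_samples_dfs` in `preprocess.py`), a DFS partitioner with `max_samples` simply stops emitting groups early instead of covering the whole graph:

```python
def dfs_partition(adjacency, max_graph_size, max_samples=None):
    """Yield node groups of at most max_graph_size; stop after max_samples groups."""
    visited, produced = set(), 0
    for root in adjacency:
        if root in visited:
            continue
        group, stack = [], [root]
        while stack and len(group) < max_graph_size:
            node = stack.pop()
            if node in visited:
                continue
            visited.add(node)
            group.append(node)
            stack.extend(n for n in adjacency[node] if n not in visited)
        yield group
        produced += 1
        if max_samples is not None and produced >= max_samples:
            return  # early exit avoids building the O(subgraphs^2) edge matrix for the full graph

# e.g. dfs_partition({0: [1], 1: [0, 2], 2: [1], 3: []}, max_graph_size=2, max_samples=1)
```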
73
+ Django's dev server auto-reloader spawns two processes, both calling `ready()`. The outer (file-watcher) process is skipped via a `RUN_MAIN` env var check in `apps.py` to avoid double memory usage.
74
+
75
+ ### Per-Algorithm Checkpoint Configs
76
+
77
+ Different algorithms were trained with different random seeds, producing different Leiden community structures:
78
+ - **Vanilla group** (transe/distmult/complex/rotate): shared seed per dataset
79
+ - **Q2B, KBGAT, nell_complex**: individual seeds from experiment runs
80
+
81
+ Community structure groups:
82
+ - freebase: transe/distmult/complex/rotate (1092 com) | q2b (1030) | kbgat (1025)
83
+ - wordnet: transe/distmult/complex/rotate (66 com) | q2b (74) | kbgat (88)
84
+ - nell: transe/distmult/rotate (282 com) | complex (188) | q2b (282, diff sizes) | kbgat (275)
85
+
86
+ `_CHECKPOINT_SEEDS` maps `(dataset, algorithm)` → training seed. `get_checkpoint_config(dataset_id, algorithm)` returns the full config. At startup only one Loader per dataset (vanilla seed) loads for discovery. Per-algorithm Loaders for inference load on demand.
87
 
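For illustration, the seed lookup described above behaves like this (the implementation appears later in this commit, in `api/services/registry.py`):

```python
from api.services.registry import VANILLA_SEEDS, get_checkpoint_config

cfg = get_checkpoint_config("nell", "complex")
assert cfg["seed"] == 2409194445                # individual seed from the experiment runs
assert get_checkpoint_config("nell", "rotate")["seed"] == VANILLA_SEEDS["nell"]  # vanilla group
```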
88
  ### Checkpoint Download (on each boot)
89
 
 
109
  - MultiProxAn (graph generation): glob `{MULTIPROXAN_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_c.ckpt`
110
  - DiGress KG (anomaly correction): glob `{DIGRESS_KG_DIR}/checkpoints/{dataset}.ckpt` and `{dataset}_correct.ckpt`
111
 
112
+ ### Sample Subgraph Generation
113
+
114
+ For KG anomaly `sample-subgraphs` endpoint:
115
+ - `Sampler.get_context_subgraph_samples_dfs()` in `graph_completion/graphs/preprocess.py` — DFS-based partitioning of the graph into sized subgraphs
116
+ - Returns `List[ContextSubgraph]` tuples of `(subgraph_row, subgraph_col, nodes_row, nodes_col, edges)`
117
+ - `max_samples` parameter caps how many subgraphs to partition (avoids full-graph traversal)
118
+ - Called once per dataset at startup, stored as `SubgraphInfo`, converted to API response format with entity/relation names resolved from the Loader's name maps (see the sketch below)
119
+ - Sampler's internal partitioning data freed after generation
120
+
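A condensed sketch of that tuple-to-response conversion (the full version is `ModelRegistry._build_sample_subgraphs` later in this commit; `inv_nodes`/`inv_relations` are the Loader's inverted name maps):

```python
def subgraph_to_response(sample, inv_nodes, inv_relations):
    subgraph_row, subgraph_col, nodes_row, nodes_col, edges = sample
    nodes = nodes_row if subgraph_row == subgraph_col else nodes_row + nodes_col
    index = {n: i for i, n in enumerate(nodes)}
    return {
        "nodes": [{"entity_id": n, "entity_name": str(inv_nodes.get(n, n))} for n in nodes],
        "edges": [
            {"source_idx": index[h], "target_idx": index[t],
             "relation_id": r, "relation_name": str(inv_relations.get(r, r))}
            for h, r, t in edges if h in index and t in index
        ],
    }
```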
121
  ## Django Settings
122
 
123
  - `DATABASES = {}` — no database
 
236
 
237
  ### Backend requirements.txt
238
 
239
+ The backend unifies both into one Python 3.9 environment. Since the COINs `Loader` is the only part reused for init, we pin to the DiGress/MultiProxAn versions and ensure COINs code runs under them. PyTorch CUDA 11.8 for local dev, falls back to CPU at runtime on PythonAnywhere.
240
 
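The runtime CPU fallback amounts to choosing the device once and mapping checkpoint tensors onto it at load time — a minimal sketch (the checkpoint path is illustrative):

```python
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
state = torch.load("checkpoints/freebase_rotate.tar", map_location=device)
```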
241
  ```
242
  # Django
 
247
  # Checkpoint download from Google Drive
248
  gdown>=4.7
249
 
250
+ # PyTorch with CUDA 11.8 (falls back to CPU at runtime if no GPU present)
251
+ --extra-index-url https://download.pytorch.org/whl/cu118
252
+ torch==2.0.1+cu118
253
+ torchvision==0.15.2+cu118
254
+ torchaudio==2.0.2+cu118
255
 
256
  # PyTorch Geometric (must match torch version)
257
  torch-geometric==2.3.1
 
270
  imageio==2.31.1
271
  torchmetrics==0.11.4
272
  tqdm==4.65.0
273
+ scikit-learn>=1.0
274
 
275
  # Conda-only deps (must be pre-installed, not in pip requirements):
276
  # rdkit==2023.03.2 (conda create -c conda-forge)
 
295
  12. Hook `ModelRegistry.initialize()` in `AppConfig.ready()`
296
  13. Add research `Loader` integration to registry for dataset loading + subgraph generation
297
 
298
+ ## Known Issues
299
+
300
+ - **Leiden non-reproducibility**: igraph 0.9.11's `community_leiden()` has no seed parameter. Re-running produces different community structures than the original training runs. For inference, the original cached `subgraphing.gz` files from the training server may be needed, or a matching igraph version that reproduces the results.
301
+
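One hedged mitigation to try: python-igraph draws randomness from Python's `random` module by default, so seeding it should make `community_leiden()` deterministic *within one igraph build* — though this still will not recover the original training-time structures produced by a different version:

```python
import random
import igraph

random.seed(4089853924)                        # e.g. the freebase vanilla seed
g = igraph.Graph.Erdos_Renyi(n=100, p=0.05)    # stand-in graph for illustration
clusters = g.community_leiden(objective_function="modularity",
                              resolution_parameter=5.0e-3)  # igraph 0.9.x keyword
```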
302
  ## Verification
303
 
304
  1. `pip install -r requirements.txt` — all dependencies install cleanly into `website_c` env
305
  2. `python manage.py check` — Django system check passes, 0 issues
306
+ 3. `python manage.py runserver` — registry downloads checkpoints (or skips), scans dirs, logs entity/relation counts
307
+ 4. `GET /api/v1/` — returns API name, version, description, full endpoint directory
308
  5. `GET /api/v1/health` — returns `{"status": "ok", "models_loaded": {"coins": bool, "multiproxan": bool, "kg_anomaly": bool}}`
309
  6. `GET /api/v1/methods` — returns 3 methods
310
  7. `GET /api/v1/coins/datasets` — returns 3 datasets with correct entity/relation counts
311
+ 8. `GET /api/v1/coins/datasets/wordnet/entities?q=dog&page=1&page_size=10` — filtered, paginated results
312
+ 9. `GET /api/v1/coins/datasets/wordnet/relations?q=hyper` — substring search on relation names
313
+ 10. `GET /api/v1/coins/datasets/wordnet/sample-triples?count=5` — 5 random triples with resolved names
314
  11. `GET /api/v1/coins/datasets/unknown/entities` — `{"error": {"code": "NOT_FOUND", ...}}`
315
  12. `GET /api/v1/coins/models` — 6 algorithms, filtered by actually available checkpoints
316
  13. `GET /api/v1/coins/query-structures` — 7 templates (1p, 2p, 3p, 2i, 3i, ip, pi) matching API spec
317
+ 14. `GET /api/v1/graph-generation/datasets` — QM9 + Community20 with model type availability
318
  15. `GET /api/v1/graph-generation/sampling-modes` — standard + multiprox with parameter specs
319
+ 16. `GET /api/v1/kg-anomaly/datasets` — 3 datasets with availability from checkpoint scan
320
+ 17. `GET /api/v1/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3` — 3 pre-computed subgraphs with entity/relation names
321
  18. Import `docs/postman/collection.json` into Postman, set environment, run all GET requests against running server
322
 
323
  ## Critical Source Files
CLAUDE.md CHANGED
@@ -35,3 +35,9 @@ PythonAnywhere and hosted by them. API requests only from the front-end are allo
35
  system boot
36
  - `src/backend`: Django backend and endpoint files location
37
  - `src/frontend`: Vue.js and Semantic UI frontend files location
35
  system boot
36
  - `src/backend`: Django backend and endpoint files location
37
  - `src/frontend`: Vue.js and Semantic UI frontend files location
38
+
39
+ ## Backend Documentation
40
+
41
+ When making any backend change (new endpoint, config change, dependency, startup behavior), update `src/backend/README.md`
42
+ to reflect the change. Keep the endpoint table, project structure, and startup sequence sections current.
43
+
docs/postman/collection.json CHANGED
@@ -77,13 +77,13 @@
77
  "method": "GET",
78
  "header": [],
79
  "url": {
80
- "raw": "{{base_url}}/coins/datasets/freebase/entities?page=1&page_size=10&q=location",
81
  "host": ["{{base_url}}"],
82
- "path": ["coins", "datasets", "freebase", "entities"],
83
  "query": [
84
  { "key": "page", "value": "1" },
85
  { "key": "page_size", "value": "10" },
86
- { "key": "q", "value": "location", "description": "Substring search filter" }
87
  ]
88
  },
89
  "description": "Paginated, searchable entity list."
@@ -95,13 +95,13 @@
95
  "method": "GET",
96
  "header": [],
97
  "url": {
98
- "raw": "{{base_url}}/coins/datasets/freebase/relations?page=1&page_size=10&q=country",
99
  "host": ["{{base_url}}"],
100
- "path": ["coins", "datasets", "freebase", "relations"],
101
  "query": [
102
  { "key": "page", "value": "1" },
103
  { "key": "page_size", "value": "10" },
104
- { "key": "q", "value": "country", "description": "Substring search filter" }
105
  ]
106
  },
107
  "description": "Paginated, searchable relation list."
@@ -113,9 +113,9 @@
113
  "method": "GET",
114
  "header": [],
115
  "url": {
116
- "raw": "{{base_url}}/coins/datasets/freebase/sample-triples?count=5",
117
  "host": ["{{base_url}}"],
118
- "path": ["coins", "datasets", "freebase", "sample-triples"],
119
  "query": [
120
  { "key": "count", "value": "5" }
121
  ]
@@ -163,7 +163,7 @@
163
  ],
164
  "body": {
165
  "mode": "raw",
166
- "raw": "{\n \"dataset_id\": \"freebase\",\n \"algorithm\": \"rotate\",\n \"query_structure\": \"1p\",\n \"anchors\": {\"a\": 123},\n \"relations\": {\"r1\": 45},\n \"top_k\": 10\n}"
167
  },
168
  "url": {
169
  "raw": "{{base_url}}/coins/predict",
@@ -182,7 +182,7 @@
182
  ],
183
  "body": {
184
  "mode": "raw",
185
- "raw": "{\n \"dataset_id\": \"freebase\",\n \"algorithm\": \"q2b\",\n \"query_structure\": \"2i\",\n \"anchors\": {\"a1\": 123, \"a2\": 456},\n \"relations\": {\"r1\": 45, \"r2\": 78},\n \"top_k\": 10\n}"
186
  },
187
  "url": {
188
  "raw": "{{base_url}}/coins/predict",
@@ -201,7 +201,7 @@
201
  ],
202
  "body": {
203
  "mode": "raw",
204
- "raw": "{\n \"dataset_id\": \"freebase\",\n \"algorithm\": \"q2b\",\n \"query_structure\": \"ip\",\n \"anchors\": {\"a1\": 123, \"a2\": 456},\n \"relations\": {\"r1\": 45, \"r2\": 78, \"r3\": 12},\n \"top_k\": 10\n}"
205
  },
206
  "url": {
207
  "raw": "{{base_url}}/coins/predict",
@@ -328,9 +328,9 @@
328
  "method": "GET",
329
  "header": [],
330
  "url": {
331
- "raw": "{{base_url}}/kg-anomaly/datasets/freebase/sample-subgraphs?count=3",
332
  "host": ["{{base_url}}"],
333
- "path": ["kg-anomaly", "datasets", "freebase", "sample-subgraphs"],
334
  "query": [
335
  { "key": "count", "value": "3" }
336
  ]
@@ -352,7 +352,7 @@
352
  ],
353
  "body": {
354
  "mode": "raw",
355
- "raw": "{\n \"dataset_id\": \"freebase\",\n \"sampling_mode\": \"standard\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 123, \"type_id\": 2},\n {\"entity_id\": 456, \"type_id\": 1},\n {\"entity_id\": 789, \"type_id\": 0}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 45},\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 12}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
356
  },
357
  "url": {
358
  "raw": "{{base_url}}/kg-anomaly/correct",
@@ -371,7 +371,7 @@
371
  ],
372
  "body": {
373
  "mode": "raw",
374
- "raw": "{\n \"dataset_id\": \"freebase\",\n \"sampling_mode\": \"standard\",\n \"task\": \"generate\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 123, \"type_id\": 2},\n {\"entity_id\": 456, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 45}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
375
  },
376
  "url": {
377
  "raw": "{{base_url}}/kg-anomaly/correct",
@@ -390,7 +390,7 @@
390
  ],
391
  "body": {
392
  "mode": "raw",
393
- "raw": "{\n \"dataset_id\": \"freebase\",\n \"sampling_mode\": \"multiprox\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 123, \"type_id\": 2},\n {\"entity_id\": 456, \"type_id\": 1}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 45}\n ]\n },\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"m\": 10,\n \"t\": 0.5,\n \"t_prime\": 0.1\n }\n}"
394
  },
395
  "url": {
396
  "raw": "{{base_url}}/kg-anomaly/correct",
 
77
  "method": "GET",
78
  "header": [],
79
  "url": {
80
+ "raw": "{{base_url}}/coins/datasets/wordnet/entities?page=1&page_size=10&q=dog",
81
  "host": ["{{base_url}}"],
82
+ "path": ["coins", "datasets", "wordnet", "entities"],
83
  "query": [
84
  { "key": "page", "value": "1" },
85
  { "key": "page_size", "value": "10" },
86
+ { "key": "q", "value": "dog", "description": "Substring search filter" }
87
  ]
88
  },
89
  "description": "Paginated, searchable entity list."
 
95
  "method": "GET",
96
  "header": [],
97
  "url": {
98
+ "raw": "{{base_url}}/coins/datasets/wordnet/relations?page=1&page_size=10&q=hyper",
99
  "host": ["{{base_url}}"],
100
+ "path": ["coins", "datasets", "wordnet", "relations"],
101
  "query": [
102
  { "key": "page", "value": "1" },
103
  { "key": "page_size", "value": "10" },
104
+ { "key": "q", "value": "hyper", "description": "Substring search filter" }
105
  ]
106
  },
107
  "description": "Paginated, searchable relation list."
 
113
  "method": "GET",
114
  "header": [],
115
  "url": {
116
+ "raw": "{{base_url}}/coins/datasets/wordnet/sample-triples?count=5",
117
  "host": ["{{base_url}}"],
118
+ "path": ["coins", "datasets", "wordnet", "sample-triples"],
119
  "query": [
120
  { "key": "count", "value": "5" }
121
  ]
 
163
  ],
164
  "body": {
165
  "mode": "raw",
166
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"algorithm\": \"rotate\",\n \"query_structure\": \"1p\",\n \"anchors\": {\"a\": 11754},\n \"relations\": {\"r1\": 3},\n \"top_k\": 10\n}"
167
  },
168
  "url": {
169
  "raw": "{{base_url}}/coins/predict",
 
182
  ],
183
  "body": {
184
  "mode": "raw",
185
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"algorithm\": \"q2b\",\n \"query_structure\": \"2i\",\n \"anchors\": {\"a1\": 11754, \"a2\": 5142},\n \"relations\": {\"r1\": 3, \"r2\": 1},\n \"top_k\": 10\n}"
186
  },
187
  "url": {
188
  "raw": "{{base_url}}/coins/predict",
 
201
  ],
202
  "body": {
203
  "mode": "raw",
204
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"algorithm\": \"q2b\",\n \"query_structure\": \"ip\",\n \"anchors\": {\"a1\": 11754, \"a2\": 5142},\n \"relations\": {\"r1\": 3, \"r2\": 1, \"r3\": 2},\n \"top_k\": 10\n}"
205
  },
206
  "url": {
207
  "raw": "{{base_url}}/coins/predict",
 
328
  "method": "GET",
329
  "header": [],
330
  "url": {
331
+ "raw": "{{base_url}}/kg-anomaly/datasets/wordnet/sample-subgraphs?count=3",
332
  "host": ["{{base_url}}"],
333
+ "path": ["kg-anomaly", "datasets", "wordnet", "sample-subgraphs"],
334
  "query": [
335
  { "key": "count", "value": "3" }
336
  ]
 
352
  ],
353
  "body": {
354
  "mode": "raw",
355
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3},\n {\"entity_id\": 8142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3},\n {\"source_idx\": 1, \"target_idx\": 2, \"relation_id\": 1}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
356
  },
357
  "url": {
358
  "raw": "{{base_url}}/kg-anomaly/correct",
 
371
  ],
372
  "body": {
373
  "mode": "raw",
374
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"standard\",\n \"task\": \"generate\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3}\n ]\n },\n \"diffusion_steps\": 500,\n \"chain_frames\": 20\n}"
375
  },
376
  "url": {
377
  "raw": "{{base_url}}/kg-anomaly/correct",
 
390
  ],
391
  "body": {
392
  "mode": "raw",
393
+ "raw": "{\n \"dataset_id\": \"wordnet\",\n \"sampling_mode\": \"multiprox\",\n \"task\": \"correct\",\n \"subgraph\": {\n \"nodes\": [\n {\"entity_id\": 11754, \"type_id\": 3},\n {\"entity_id\": 5142, \"type_id\": 3}\n ],\n \"edges\": [\n {\"source_idx\": 0, \"target_idx\": 1, \"relation_id\": 3}\n ]\n },\n \"diffusion_steps\": 500,\n \"multiprox_params\": {\n \"m\": 10,\n \"t\": 0.5,\n \"t_prime\": 0.1\n }\n}"
394
  },
395
  "url": {
396
  "raw": "{{base_url}}/kg-anomaly/correct",
src/backend/README.md ADDED
@@ -0,0 +1,127 @@
1
+ # Backend — Django REST API
2
+
3
+ Stateless REST API serving the PhD research models. No database — dataset metadata is loaded into memory at startup, model checkpoints are loaded on first use, and each request is answered independently.
4
+
5
+ ## Prerequisites
6
+
7
+ 1. **Conda environment** with pre-installed system deps:
8
+ ```bash
9
+ conda create -n website python=3.9
10
+ conda activate website
11
+ conda install -c conda-forge rdkit=2023.03.2 graph-tool=2.45
12
+ ```
13
+
14
+ 2. **Pip dependencies**:
15
+ ```bash
16
+ pip install -r requirements.txt
17
+ ```
18
+
19
+ 3. **Model checkpoints** — downloaded automatically from Google Drive on first boot (via `gdown`). Alternatively, manually place `.tar` / `.ckpt` files in:
20
+ - `src/research/COINs-KGGeneration/graph_completion/checkpoints/` (COINs: `{dataset}_{algorithm}.tar`)
21
+ - `src/research/COINs-KGGeneration/graph_generation/checkpoints/` (KG anomaly: `{dataset}.ckpt`, `{dataset}_correct.ckpt`)
22
+ - `src/research/MultiProxAn/checkpoints/` (graph generation: `{dataset}.ckpt`, `{dataset}_c.ckpt`)
23
+
24
+ 4. **Dataset files** — the raw KG data files must be present under `src/research/COINs-KGGeneration/data/` (FB15k-237, WN18RR, NELL-995).
25
+
26
+ ## Running
27
+
28
+ From `src/backend/`:
29
+
30
+ ```bash
31
+ # Development server
32
+ python manage.py runserver 8000
33
+
34
+ # With custom settings
35
+ DJANGO_DEBUG=True DJANGO_SECRET_KEY=my-secret python manage.py runserver
36
+ ```
37
+
38
+ The API is served at `http://localhost:8000/api/v1/`.
39
+
40
+ ## Environment Variables
41
+
42
+ | Variable | Default | Description |
43
+ |---|---|---|
44
+ | `DJANGO_SECRET_KEY` | `dev-insecure-key-change-in-production` | Django secret key. **Set in production.** |
45
+ | `DJANGO_DEBUG` | `True` | Enable debug mode. Set to `False` in production. |
46
+ | `DJANGO_ALLOWED_HOSTS` | `localhost,127.0.0.1` | Comma-separated allowed hosts. |
47
+
48
+ ## Startup Sequence
49
+
50
+ On boot (`ApiConfig.ready()`), the `ModelRegistry` initializes:
51
+
52
+ 1. **Download checkpoints** from Google Drive if not already present locally
53
+ 2. **Scan checkpoint directories** to detect available models per method
54
+ 3. **Load lightweight COINs Loaders** — one per dataset (freebase, wordnet, nell), loading graph data, name maps, and train/val/test splits. Heavy arrays (node neighbours ~275MB each, community neighbours, adjacency dicts) are freed after initialization to keep memory low.
55
+ 4. **Generate sample subgraphs** for KG anomaly using the COINs Loaders
56
+
57
+ All model weights (COINs inference, graph generation, KG anomaly) are loaded lazily at first inference request.
58
+
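That lazy loading could follow a plain cache-on-first-use pattern — a sketch, not the registry's actual code (`load_fn` is a hypothetical checkpoint-loading callable):

```python
class LazyWeights:
    def __init__(self):
        self._models = {}                        # (dataset_id, algorithm) -> loaded model

    def get_model(self, dataset_id, algorithm, load_fn):
        key = (dataset_id, algorithm)
        if key not in self._models:
            self._models[key] = load_fn(dataset_id, algorithm)  # first request pays the load cost
        return self._models[key]
```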
59
+ ## API Endpoints
60
+
61
+ All endpoints are prefixed with `/api/v1/`.
62
+
63
+ ### Health & Discovery
64
+
65
+ | Method | Path | Description |
66
+ |---|---|---|
67
+ | `GET` | `/health` | Service health + model availability |
68
+ | `GET` | `/methods` | List the 3 research methods |
69
+
70
+ ### COINs — KG Reasoning
71
+
72
+ | Method | Path | Description |
73
+ |---|---|---|
74
+ | `GET` | `/coins/datasets` | List datasets with entity/relation counts |
75
+ | `GET` | `/coins/datasets/{id}/entities` | Paginated entity search (`?q=&page=&page_size=`) |
76
+ | `GET` | `/coins/datasets/{id}/relations` | Paginated relation search (`?q=&page=&page_size=`) |
77
+ | `GET` | `/coins/datasets/{id}/sample-triples` | Random training triples (`?count=10`) |
78
+ | `GET` | `/coins/models` | Available algorithms + supported query structures |
79
+ | `GET` | `/coins/query-structures` | Query graph templates for frontend rendering |
80
+ | `POST` | `/coins/predict` | Run inference (not yet implemented) |
81
+
82
+ ### Graph Generation — MultiProxAn
83
+
84
+ | Method | Path | Description |
85
+ |---|---|---|
86
+ | `GET` | `/graph-generation/datasets` | List graph types with node/edge types |
87
+ | `GET` | `/graph-generation/sampling-modes` | Sampling strategies with parameter specs |
88
+ | `POST` | `/graph-generation/generate` | Generate a graph (not yet implemented) |
89
+ | `POST` | `/graph-generation/continue` | Continue MultiProx generation (not yet implemented) |
90
+
91
+ ### KG Anomaly Correction
92
+
93
+ | Method | Path | Description |
94
+ |---|---|---|
95
+ | `GET` | `/kg-anomaly/datasets` | List datasets with correction models |
96
+ | `GET` | `/kg-anomaly/datasets/{id}/sample-subgraphs` | Pre-computed example subgraphs (`?count=5`) |
97
+ | `POST` | `/kg-anomaly/correct` | Run correction (not yet implemented) |
98
+ | `POST` | `/kg-anomaly/continue` | Continue MultiProx correction (not yet implemented) |
99
+
100
+ ## Project Structure
101
+
102
+ ```
103
+ src/backend/
104
+ manage.py
105
+ requirements.txt
106
+ research_api/ # Django project settings
107
+ settings.py
108
+ urls.py
109
+ wsgi.py
110
+ api/ # Django app
111
+ apps.py # Triggers ModelRegistry.initialize() on startup
112
+ urls.py # Route definitions
113
+ pagination.py # Shared pagination helper
114
+ exceptions.py # Custom error envelope
115
+ services/
116
+ constants.py # Dataset metadata, model configs, query structures
117
+ registry.py # ModelRegistry — checkpoint download, scanning, Loader init
118
+ views/
119
+ health.py # /health, /methods
120
+ coins.py # /coins/* endpoints
121
+ graph_generation.py # /graph-generation/* endpoints
122
+ kg_anomaly.py # /kg-anomaly/* endpoints
123
+ ```
124
+
125
+ ## Testing with Postman
126
+
127
+ Import the collection and environment from `docs/postman/` to test all discovery endpoints.
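For a scripted alternative to Postman, a stdlib-only smoke test against a running dev server (assumes `localhost:8000`) might look like:

```python
import json
import urllib.request

BASE = "http://localhost:8000/api/v1"
for path in ("/health", "/methods", "/coins/datasets"):
    with urllib.request.urlopen(BASE + path) as resp:
        print(path, json.load(resp))
```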
src/backend/api/apps.py CHANGED
@@ -6,6 +6,16 @@ class ApiConfig(AppConfig):
6
  default_auto_field = "django.db.models.BigAutoField"
7
 
8
  def ready(self):
9
  from api.services.registry import ModelRegistry
10
 
11
  ModelRegistry.initialize()
 
6
  default_auto_field = "django.db.models.BigAutoField"
7
 
8
  def ready(self):
9
+ import os
10
+ import sys
11
+
12
+ # Django's runserver auto-reloader spawns two processes, both calling ready().
13
+ # The inner (serving) process has RUN_MAIN="true"; the outer (watcher) does not.
14
+ # Skip initialization in the outer process to avoid double memory usage.
15
+ uses_reloader = "runserver" in sys.argv and "--noreload" not in sys.argv
16
+ if uses_reloader and os.environ.get("RUN_MAIN") != "true":
17
+ return
18
+
19
  from api.services.registry import ModelRegistry
20
 
21
  ModelRegistry.initialize()
src/backend/api/pagination.py ADDED
@@ -0,0 +1,8 @@
1
+ def paginate_list(items, page=1, page_size=50):
2
+ """Paginate a list with 1-indexed pages. Returns (page_items, total)."""
3
+ page = max(1, page)
4
+ page_size = max(1, min(200, page_size))
5
+ total = len(items)
6
+ start = (page - 1) * page_size
7
+ end = start + page_size
8
+ return items[start:end], total
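Usage of the helper above, including its clamping behaviour:

```python
items = list(range(120))
page_items, total = paginate_list(items, page=3, page_size=50)
assert (page_items, total) == (list(range(100, 120)), 120)
page_items, _ = paginate_list(items, page=0, page_size=500)  # clamped to page=1, page_size=200
assert page_items == items
```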
src/backend/api/services/constants.py ADDED
@@ -0,0 +1,297 @@
1
+ METHODS = [
2
+ {
3
+ "id": "coins",
4
+ "name": "COINs - Knowledge Graph Reasoning",
5
+ "thesis_section": "3.1",
6
+ "description": (
7
+ "Community-Informed Graph Embeddings (COINs) for scalable knowledge graph link prediction "
8
+ "and complex query answering. Uses community detection to localize embedding computation, "
9
+ "achieving significant speedups over full-graph methods."
10
+ ),
11
+ },
12
+ {
13
+ "id": "multiproxan",
14
+ "name": "MultiProxAn - Graph Generation",
15
+ "thesis_section": "4.3",
16
+ "description": (
17
+ "Discrete denoising diffusion model for graph generation with MultiProx sampling. "
18
+ "Generates molecular graphs (QM9) and synthetic community graphs using iterative "
19
+ "multi-measurement Gibbs sampling for improved sample quality."
20
+ ),
21
+ },
22
+ {
23
+ "id": "kg_anomaly",
24
+ "name": "KG Anomaly Correction",
25
+ "thesis_section": "4.4",
26
+ "description": (
27
+ "Diffusion-based knowledge graph subgraph correction. Applies the DiGress denoising "
28
+ "diffusion model to knowledge graph subgraphs to detect and correct anomalous edges."
29
+ ),
30
+ },
31
+ ]
32
+
33
+ COINS_DATASET_META = {
34
+ "freebase": {
35
+ "name": "FB15k-237",
36
+ "description": "Subset of Freebase knowledge base with 237 relation types",
37
+ "data_dir": "FB15k-237",
38
+ },
39
+ "wordnet": {
40
+ "name": "WN18RR",
41
+ "description": "Subset of WordNet lexical database with 11 relation types",
42
+ "data_dir": "WN18RR",
43
+ },
44
+ "nell": {
45
+ "name": "NELL-995",
46
+ "description": "Never-Ending Language Learner knowledge base with 200 relation types",
47
+ "data_dir": "NELL-995",
48
+ },
49
+ }
50
+
51
+ COINS_MODELS = [
52
+ {
53
+ "algorithm": "transe",
54
+ "name": "TransE",
55
+ "description": "Translation-based embedding model",
56
+ "supported_query_structures": ["1p"],
57
+ },
58
+ {
59
+ "algorithm": "distmult",
60
+ "name": "DistMult",
61
+ "description": "Bilinear diagonal embedding model",
62
+ "supported_query_structures": ["1p"],
63
+ },
64
+ {
65
+ "algorithm": "complex",
66
+ "name": "ComplEx",
67
+ "description": "Complex-valued embedding model",
68
+ "supported_query_structures": ["1p"],
69
+ },
70
+ {
71
+ "algorithm": "rotate",
72
+ "name": "RotatE",
73
+ "description": "Rotation-based embedding model in complex space",
74
+ "supported_query_structures": ["1p"],
75
+ },
76
+ {
77
+ "algorithm": "q2b",
78
+ "name": "Query2Box",
79
+ "description": "Box embedding model for complex logical queries",
80
+ "supported_query_structures": ["1p", "2p", "3p", "2i", "3i", "ip", "pi"],
81
+ },
82
+ {
83
+ "algorithm": "kbgat",
84
+ "name": "KBGAT",
85
+ "description": "Knowledge base graph attention network",
86
+ "supported_query_structures": ["1p"],
87
+ },
88
+ ]
89
+
90
+ QUERY_STRUCTURES = [
91
+ {
92
+ "id": "1p",
93
+ "name": "Single Hop",
94
+ "description": "Direct link prediction: who/what is connected to the anchor via this relation?",
95
+ "nodes": [
96
+ {"id": "a", "type": "anchor", "label": "Anchor"},
97
+ {"id": "t", "type": "target", "label": "?"},
98
+ ],
99
+ "edges": [
100
+ {"id": "r1", "source": "a", "target": "t", "label": "Relation"},
101
+ ],
102
+ },
103
+ {
104
+ "id": "2p",
105
+ "name": "Two Hop",
106
+ "description": "Two-step chain: anchor -> variable -> target",
107
+ "nodes": [
108
+ {"id": "a", "type": "anchor", "label": "Anchor"},
109
+ {"id": "v1", "type": "variable", "label": "Variable"},
110
+ {"id": "t", "type": "target", "label": "?"},
111
+ ],
112
+ "edges": [
113
+ {"id": "r1", "source": "a", "target": "v1", "label": "Relation 1"},
114
+ {"id": "r2", "source": "v1", "target": "t", "label": "Relation 2"},
115
+ ],
116
+ },
117
+ {
118
+ "id": "3p",
119
+ "name": "Three Hop",
120
+ "description": "Three-step chain: anchor -> v1 -> v2 -> target",
121
+ "nodes": [
122
+ {"id": "a", "type": "anchor", "label": "Anchor"},
123
+ {"id": "v1", "type": "variable", "label": "Variable 1"},
124
+ {"id": "v2", "type": "variable", "label": "Variable 2"},
125
+ {"id": "t", "type": "target", "label": "?"},
126
+ ],
127
+ "edges": [
128
+ {"id": "r1", "source": "a", "target": "v1", "label": "Relation 1"},
129
+ {"id": "r2", "source": "v1", "target": "v2", "label": "Relation 2"},
130
+ {"id": "r3", "source": "v2", "target": "t", "label": "Relation 3"},
131
+ ],
132
+ },
133
+ {
134
+ "id": "2i",
135
+ "name": "Two Intersection",
136
+ "description": "Intersection of two single-hop queries sharing the same target",
137
+ "nodes": [
138
+ {"id": "a1", "type": "anchor", "label": "Anchor 1"},
139
+ {"id": "a2", "type": "anchor", "label": "Anchor 2"},
140
+ {"id": "t", "type": "target", "label": "?"},
141
+ ],
142
+ "edges": [
143
+ {"id": "r1", "source": "a1", "target": "t", "label": "Relation 1"},
144
+ {"id": "r2", "source": "a2", "target": "t", "label": "Relation 2"},
145
+ ],
146
+ },
147
+ {
148
+ "id": "3i",
149
+ "name": "Three Intersection",
150
+ "description": "Intersection of three single-hop queries sharing the same target",
151
+ "nodes": [
152
+ {"id": "a1", "type": "anchor", "label": "Anchor 1"},
153
+ {"id": "a2", "type": "anchor", "label": "Anchor 2"},
154
+ {"id": "a3", "type": "anchor", "label": "Anchor 3"},
155
+ {"id": "t", "type": "target", "label": "?"},
156
+ ],
157
+ "edges": [
158
+ {"id": "r1", "source": "a1", "target": "t", "label": "Relation 1"},
159
+ {"id": "r2", "source": "a2", "target": "t", "label": "Relation 2"},
160
+ {"id": "r3", "source": "a3", "target": "t", "label": "Relation 3"},
161
+ ],
162
+ },
163
+ {
164
+ "id": "ip",
165
+ "name": "Intersection then Projection",
166
+ "description": "Two anchors intersect, then the result projects via a third relation to the target",
167
+ "nodes": [
168
+ {"id": "a1", "type": "anchor", "label": "Anchor 1"},
169
+ {"id": "a2", "type": "anchor", "label": "Anchor 2"},
170
+ {"id": "v1", "type": "variable", "label": "Variable"},
171
+ {"id": "t", "type": "target", "label": "?"},
172
+ ],
173
+ "edges": [
174
+ {"id": "r1", "source": "a1", "target": "v1", "label": "Relation 1"},
175
+ {"id": "r2", "source": "a2", "target": "v1", "label": "Relation 2"},
176
+ {"id": "r3", "source": "v1", "target": "t", "label": "Relation 3"},
177
+ ],
178
+ },
179
+ {
180
+ "id": "pi",
181
+ "name": "Projection then Intersection",
182
+ "description": "One anchor projects then intersects with a direct connection from a second anchor",
183
+ "nodes": [
184
+ {"id": "a1", "type": "anchor", "label": "Anchor 1"},
185
+ {"id": "v1", "type": "variable", "label": "Variable"},
186
+ {"id": "a2", "type": "anchor", "label": "Anchor 2"},
187
+ {"id": "t", "type": "target", "label": "?"},
188
+ ],
189
+ "edges": [
190
+ {"id": "r1", "source": "a1", "target": "v1", "label": "Relation 1"},
191
+ {"id": "r2", "source": "v1", "target": "t", "label": "Relation 2"},
192
+ {"id": "r3", "source": "a2", "target": "t", "label": "Relation 3"},
193
+ ],
194
+ },
195
+ ]
196
+
197
+ GRAPHGEN_DATASETS = {
198
+ "qm9": {
199
+ "name": "QM9",
200
+ "type": "molecular",
201
+ "description": "Small organic molecules with up to 9 heavy atoms (C, N, O, F)",
202
+ "node_types": ["C", "N", "O", "F"],
203
+ "edge_types": ["none", "single", "double", "triple", "aromatic"],
204
+ "max_nodes": 9,
205
+ },
206
+ "comm20": {
207
+ "name": "Community20",
208
+ "type": "synthetic",
209
+ "description": "Synthetic community-structured graphs with 12-20 nodes",
210
+ "node_types": ["node"],
211
+ "edge_types": ["none", "edge"],
212
+ "max_nodes": 20,
213
+ },
214
+ }
215
+
216
+ GRAPHGEN_SAMPLING_MODES = [
217
+ {
218
+ "id": "standard",
219
+ "name": "Standard Denoising",
220
+ "description": "Iterative denoising from T to 0. Full quality, slower.",
221
+ "parameters": [
222
+ {
223
+ "name": "diffusion_steps",
224
+ "type": "integer",
225
+ "description": "Number of diffusion steps T",
226
+ "default": 500,
227
+ "min": 50,
228
+ "max": 1000,
229
+ },
230
+ {
231
+ "name": "chain_frames",
232
+ "type": "integer",
233
+ "description": "Number of denoising snapshots in the GIF",
234
+ "default": 20,
235
+ "min": 10,
236
+ "max": 30,
237
+ },
238
+ ],
239
+ },
240
+ {
241
+ "id": "multiprox",
242
+ "name": "MultiProx Sampling",
243
+ "description": (
244
+ "Multi-measurement Gibbs sampling with proximal steps. "
245
+ "Step-by-step generation with controllable noise levels."
246
+ ),
247
+ "parameters": [
248
+ {
249
+ "name": "diffusion_steps",
250
+ "type": "integer",
251
+ "description": "Number of diffusion steps T",
252
+ "default": 500,
253
+ "min": 50,
254
+ "max": 1000,
255
+ },
256
+ {
257
+ "name": "m",
258
+ "type": "integer",
259
+ "description": "Number of parallel samples per multi-measurement step",
260
+ "default": 10,
261
+ "min": 2,
262
+ "max": 100,
263
+ },
264
+ {
265
+ "name": "t",
266
+ "type": "float",
267
+ "description": "First noise level (normalized, 0-1)",
268
+ "default": 0.5,
269
+ "min": 0.0,
270
+ "max": 1.0,
271
+ },
272
+ {
273
+ "name": "t_prime",
274
+ "type": "float",
275
+ "description": "Second noise level (normalized, 0-1). Must satisfy t_prime <= t.",
276
+ "default": 0.1,
277
+ "min": 0.0,
278
+ "max": 1.0,
279
+ },
280
+ ],
281
+ },
282
+ ]
283
+
284
+ KG_ANOMALY_DATASET_META = {
285
+ "freebase": {
286
+ "name": "FB15k-237",
287
+ "description": "Diffusion model trained on Freebase subgraphs",
288
+ },
289
+ "wordnet": {
290
+ "name": "WN18RR",
291
+ "description": "Diffusion model trained on WordNet subgraphs",
292
+ },
293
+ "nell": {
294
+ "name": "NELL-995",
295
+ "description": "Diffusion model trained on NELL subgraphs",
296
+ },
297
+ }
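As a hedged illustration of how the `QUERY_STRUCTURES` templates above can drive request validation (the helper below is not part of this commit's views), a `/coins/predict` payload must supply exactly the anchors and relations its structure declares — compare the Postman bodies earlier in this commit:

```python
def validate_query_payload(payload, structures=QUERY_STRUCTURES):
    structure = next(s for s in structures if s["id"] == payload["query_structure"])
    anchor_ids = {n["id"] for n in structure["nodes"] if n["type"] == "anchor"}
    relation_ids = {e["id"] for e in structure["edges"]}
    if set(payload["anchors"]) != anchor_ids:
        raise ValueError(f"expected anchors {sorted(anchor_ids)}")
    if set(payload["relations"]) != relation_ids:
        raise ValueError(f"expected relations {sorted(relation_ids)}")

validate_query_payload({
    "query_structure": "2i",
    "anchors": {"a1": 11754, "a2": 5142},
    "relations": {"r1": 3, "r2": 1},
})
```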
src/backend/api/services/registry.py CHANGED
@@ -1,14 +1,16 @@
1
  import logging
2
  import os
 
3
  from pathlib import Path
4
 
5
  from django.conf import settings
6
 
 
 
7
  logger = logging.getLogger(__name__)
8
 
9
  GDRIVE_FOLDER_ID = "14Bf8fi4KJn0rDdh9y8EFyA5b8OpQyXWi"
10
 
11
- # Subfolder IDs within the Google Drive folder (from gdown listing)
12
  GDRIVE_SUBFOLDERS = {
13
  "coins": {
14
  "folder_name": "COINs-KGGeneration/checkpoints_coins",
@@ -27,6 +29,104 @@ GDRIVE_SUBFOLDERS = {
27
  },
28
  }
29
 
30
 
31
  class ModelRegistry:
32
  _instance = None
@@ -35,6 +135,8 @@ class ModelRegistry:
35
  self.coins_checkpoints_available = {}
36
  self.graphgen_checkpoints_available = {}
37
  self.kg_anomaly_checkpoints_available = {}
 
 
38
 
39
  @classmethod
40
  def get(cls):
@@ -49,14 +151,19 @@ class ModelRegistry:
49
  instance = cls()
50
  instance._download_checkpoints()
51
  instance._scan_checkpoints()
 
 
52
  cls._instance = instance
53
  logger.info(
54
- "ModelRegistry initialized: coins=%s, multiproxan=%s, kg_anomaly=%s",
55
  instance.is_coins_loaded(),
56
  instance.is_graphgen_loaded(),
57
  instance.is_kg_anomaly_loaded(),
 
58
  )
59
 
 
 
60
  def _download_checkpoints(self):
61
  """Download checkpoints from Google Drive if not already present locally."""
62
  if self._all_checkpoint_dirs_populated():
@@ -116,6 +223,8 @@ class ModelRegistry:
116
  logger.info("Installing checkpoint: %s -> %s", src_file.name, dest_dir)
117
  src_file.replace(dest_file)
118
 
 
 
119
  def _scan_checkpoints(self):
120
  self._scan_coins_checkpoints()
121
  self._scan_graphgen_checkpoints()
@@ -161,6 +270,199 @@ class ModelRegistry:
161
  self.kg_anomaly_checkpoints_available.setdefault(name, []).append("generate")
162
  logger.info("DiGress KG checkpoints: %s", self.kg_anomaly_checkpoints_available)
163
 
164
  def is_coins_loaded(self):
165
  return bool(self.coins_checkpoints_available)
166
 
 
1
  import logging
2
  import os
3
+ import random
4
  from pathlib import Path
5
 
6
  from django.conf import settings
7
 
8
+ from api.services.constants import COINS_DATASET_META
9
+
10
  logger = logging.getLogger(__name__)
11
 
12
  GDRIVE_FOLDER_ID = "14Bf8fi4KJn0rDdh9y8EFyA5b8OpQyXWi"
13
 
 
14
  GDRIVE_SUBFOLDERS = {
15
  "coins": {
16
  "folder_name": "COINs-KGGeneration/checkpoints_coins",
 
29
  },
30
  }
31
 
32
+ # Shared sampler hyperparameters used across all COINs experiments
33
+ _SAMPLER_HPARS = {
34
+ "query_structure": ["1p"],
35
+ "num_negative_samples": 128,
36
+ "num_neighbours": 10,
37
+ "random_walk_length": 10,
38
+ "context_radius": 2,
39
+ "pagerank_importances": True,
40
+ "walks_relation_specific": True,
41
+ }
42
+
43
+ # Per-dataset base config (leiden resolution, loader hyperparams)
44
+ _DATASET_BASE = {
45
+ "freebase": {
46
+ "leiden_resolution": 5.0e-3,
47
+ "loader_hpars": {
48
+ "dataset_name": "freebase", "simulated": False,
49
+ "sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
50
+ },
51
+ },
52
+ "wordnet": {
53
+ "leiden_resolution": None, # computed as 1/num_nodes
54
+ "loader_hpars": {
55
+ "dataset_name": "wordnet", "simulated": False,
56
+ "sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
57
+ },
58
+ },
59
+ "nell": {
60
+ "leiden_resolution": 2.0e-5,
61
+ "loader_hpars": {
62
+ "dataset_name": "nell", "simulated": False,
63
+ "sample_source": "smore", "sampler_hpars": _SAMPLER_HPARS,
64
+ },
65
+ },
66
+ }
67
+
68
+ # Training seed per (dataset, algorithm), from PhD Thesis/Code experiment runs.
69
+ # Verified by matching num_communities in test_log.txt against checkpoint embedding shapes.
70
+ # Groups sharing the same community structure:
71
+ # freebase: transe/distmult/complex/rotate (1092 com) | q2b (1030) | kbgat (1025)
72
+ # wordnet: transe/distmult/complex/rotate (66 com) | q2b (74) | kbgat (88)
73
+ # nell: transe/distmult/rotate (282 com) | complex (188) | q2b (282, diff sizes) | kbgat (275)
74
+ _CHECKPOINT_SEEDS = {
75
+ ("freebase", "transe"): 4089853924, ("freebase", "distmult"): 4089853924,
76
+ ("freebase", "complex"): 4089853924, ("freebase", "rotate"): 4089853924,
77
+ ("freebase", "q2b"): 1503136574, ("freebase", "kbgat"): 123456789,
78
+ ("wordnet", "transe"): 1919180054, ("wordnet", "distmult"): 1919180054,
79
+ ("wordnet", "complex"): 1919180054, ("wordnet", "rotate"): 1919180054,
80
+ ("wordnet", "q2b"): 3312854056, ("wordnet", "kbgat"): 123456789,
81
+ ("nell", "transe"): 3192206669, ("nell", "distmult"): 3192206669,
82
+ ("nell", "rotate"): 3192206669,
83
+ ("nell", "complex"): 2409194445, ("nell", "q2b"): 3793326028, ("nell", "kbgat"): 123456789,
84
+ }
85
+
86
+ # Vanilla seeds per dataset — used for metadata Loaders (entity/relation search, sample triples).
87
+ VANILLA_SEEDS = {
88
+ "freebase": 4089853924,
89
+ "wordnet": 1919180054,
90
+ "nell": 3192206669,
91
+ }
92
+
93
+
94
+ def get_checkpoint_config(dataset_id, algorithm):
95
+ """Return the full config for a specific (dataset, algorithm) checkpoint."""
96
+ base = _DATASET_BASE[dataset_id]
97
+ seed = _CHECKPOINT_SEEDS.get((dataset_id, algorithm))
98
+ if seed is None:
99
+ seed = VANILLA_SEEDS[dataset_id]
100
+ return {"seed": seed, **base}
101
+
102
+
103
+ def _free_heavy_arrays(loader):
104
+ """Free memory-intensive arrays from a Loader that aren't needed for discovery endpoints."""
105
+ loader.node_neighbours = None
106
+ loader.com_neighbours = None
107
+ loader.node_adjacency = None
108
+ loader.com_adjacency = None
109
+ loader.label_community_edge_freqs = None
110
+ loader.label_community_edge_freqs_index = None
111
+ loader.machines = None
112
+ loader.graph = None
113
+ loader.node_importances = None
114
+ loader.neighbour_importances = None
115
+ loader.out_degrees = None
116
+ loader.in_degrees = None
117
+ loader.degrees = None
118
+ loader.node_degree_type_freqs = None
119
+ loader.relation_freqs = None
120
+
121
+
122
+ class SubgraphInfo:
123
+ """Holds pre-computed sample subgraphs for a KG anomaly dataset."""
124
+
125
+ __slots__ = ("subgraphs",)
126
+
127
+ def __init__(self, subgraphs):
128
+ self.subgraphs = subgraphs
129
+
130
 
131
  class ModelRegistry:
132
  _instance = None
 
135
  self.coins_checkpoints_available = {}
136
  self.graphgen_checkpoints_available = {}
137
  self.kg_anomaly_checkpoints_available = {}
138
+ self.loaders = {} # dataset_id -> lightweight Loader for discovery
139
+ self.kg_anomaly_subgraphs = {} # dataset_id -> SubgraphInfo
140
 
141
  @classmethod
142
  def get(cls):
 
151
  instance = cls()
152
  instance._download_checkpoints()
153
  instance._scan_checkpoints()
154
+ instance._load_all_loaders()
155
+ instance._generate_sample_subgraphs()
156
  cls._instance = instance
157
  logger.info(
158
+ "ModelRegistry initialized: coins=%s, multiproxan=%s, kg_anomaly=%s, loaders=%s",
159
  instance.is_coins_loaded(),
160
  instance.is_graphgen_loaded(),
161
  instance.is_kg_anomaly_loaded(),
162
+ list(instance.loaders.keys()),
163
  )
164
 
165
+ # ---- Checkpoint download -------------------------------------------
166
+
167
  def _download_checkpoints(self):
168
  """Download checkpoints from Google Drive if not already present locally."""
169
  if self._all_checkpoint_dirs_populated():
 
223
  logger.info("Installing checkpoint: %s -> %s", src_file.name, dest_dir)
224
  src_file.replace(dest_file)
225
 
226
+ # ---- Checkpoint scanning -------------------------------------------
227
+
228
  def _scan_checkpoints(self):
229
  self._scan_coins_checkpoints()
230
  self._scan_graphgen_checkpoints()
 
270
  self.kg_anomaly_checkpoints_available.setdefault(name, []).append("generate")
271
  logger.info("DiGress KG checkpoints: %s", self.kg_anomaly_checkpoints_available)
272
 
273
+ # ---- Loader initialization -----------------------------------------
274
+
275
+ def _load_all_loaders(self):
276
+ """Initialize one lightweight Loader per dataset for discovery endpoints.
277
+
278
+ Loads dataset, name maps, train/val/test split, and graph indexes.
279
+ Heavy arrays (node_neighbours, com_neighbours, adjacency dicts) are freed
280
+ after startup to save memory. Full Loaders for inference are loaded on demand.
281
+ """
282
+ coins_root = str(Path(settings.COINS_DATA_DIR).parent)
283
+ original_cwd = os.getcwd()
284
+ try:
285
+ os.chdir(coins_root)
286
+ from graph_completion.graphs.load_graph import Loader, LoaderHpars
287
+
288
+ for dataset_id in _DATASET_BASE:
289
+ seed = VANILLA_SEEDS[dataset_id]
290
+ config = get_checkpoint_config(dataset_id, "transe")
291
+ try:
292
+ logger.info("Initializing Loader for %s (seed=%d)...", dataset_id, seed)
293
+ loader = LoaderHpars.from_dict(config["loader_hpars"]).make()
294
+
295
+ leiden_resolution = config["leiden_resolution"]
296
+ if leiden_resolution is None:
297
+ dataset_obj = Loader.datasets[dataset_id]
298
+ dataset_obj.load_from_disk()
299
+ leiden_resolution = 1.0 / len(dataset_obj.node_data)
300
+ dataset_obj.unload_from_memory()
301
+
302
+ loader.load_graph(
303
+ seed=seed, device="cpu", val_size=0.01, test_size=0.02,
304
+ community_method="leiden", leiden_resolution=leiden_resolution,
305
+ )
306
+ # Free heavy arrays not needed for discovery endpoints
307
+ _free_heavy_arrays(loader)
308
+ self.loaders[dataset_id] = loader
309
+ logger.info(
310
+ "Loader ready for %s: %d entities, %d relations, %d train triples",
311
+ dataset_id, loader.num_nodes, loader.num_relations, len(loader.train_edge_data),
312
+ )
313
+ except Exception:
314
+ logger.exception("Failed to initialize Loader for %s", dataset_id)
315
+ finally:
316
+ os.chdir(original_cwd)
317
+
318
+ # ---- Loader accessor helpers ---------------------------------------
319
+
320
+ def get_loader(self, dataset_id):
321
+ """Return the metadata Loader for a dataset, or None."""
322
+ return self.loaders.get(dataset_id)
323
+
324
+ def get_entity_count(self, dataset_id):
325
+ loader = self.loaders.get(dataset_id)
326
+ return loader.num_nodes if loader else 0
327
+
328
+ def get_relation_count(self, dataset_id):
329
+ loader = self.loaders.get(dataset_id)
330
+ return loader.num_relations if loader else 0
331
+
332
+ def get_inverted_name_maps(self, dataset_id):
333
+ """Return (inv_node_names, inv_node_types, inv_relation_names) Series for a dataset."""
334
+ loader = self.loaders.get(dataset_id)
335
+ if loader is None:
336
+ return None, None, None
337
+ return loader.dataset.get_inverted_name_maps()
338
+
339
+ def search_entities(self, dataset_id, query=None, page=1, page_size=50):
340
+ """Search entities by substring, return paginated (id, name) list and total."""
341
+ loader = self.loaders.get(dataset_id)
342
+ if loader is None:
343
+ return [], 0
344
+ inv_nodes, _, _ = loader.dataset.get_inverted_name_maps()
345
+ items = [(int(idx), str(name)) for idx, name in inv_nodes.items()]
346
+ if query:
347
+ q = query.lower()
348
+ items = [(eid, name) for eid, name in items if q in name.lower()]
349
+ total = len(items)
350
+ start = (max(1, page) - 1) * page_size
351
+ return items[start:start + page_size], total
352
+
353
+ def search_relations(self, dataset_id, query=None, page=1, page_size=50):
354
+ """Search relations by substring, return paginated (id, name) list and total."""
355
+ loader = self.loaders.get(dataset_id)
356
+ if loader is None:
357
+ return [], 0
358
+ _, _, inv_relations = loader.dataset.get_inverted_name_maps()
359
+ items = [(int(idx), str(name)) for idx, name in inv_relations.items()]
360
+ if query:
361
+ q = query.lower()
362
+ items = [(rid, name) for rid, name in items if q in name.lower()]
363
+ total = len(items)
364
+ start = (max(1, page) - 1) * page_size
365
+ return items[start:start + page_size], total
366
+
367
+ def sample_triples(self, dataset_id, count=10):
368
+ """Return random triples with resolved entity/relation names."""
369
+ loader = self.loaders.get(dataset_id)
370
+ if loader is None:
371
+ return []
372
+ inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
373
+ edge_data = loader.train_edge_data
374
+ count = min(count, len(edge_data))
375
+
376
+ indices = random.sample(range(len(edge_data)), count)
377
+
378
+ result = []
379
+ for i in indices:
380
+ row = edge_data.iloc[i]
381
+ h, r, t = int(row.s), int(row.r), int(row.t)
382
+ result.append({
383
+ "head": {"id": h, "name": str(inv_nodes.get(h, h))},
384
+ "relation": {"id": r, "name": str(inv_relations.get(r, r))},
385
+ "tail": {"id": t, "name": str(inv_nodes.get(t, t))},
386
+ })
387
+ return result
388
+
389
+ # ---- Sample subgraph generation ------------------------------------
390
+
391
+ def _generate_sample_subgraphs(self):
392
+ """Generate sample subgraphs for KG anomaly using the Loader's context subgraph DFS."""
393
+ for dataset_id in COINS_DATASET_META:
394
+ loader = self.loaders.get(dataset_id)
395
+ if loader is None:
396
+ continue
397
+ try:
398
+ subgraphs = self._build_sample_subgraphs(dataset_id, loader)
399
+ self.kg_anomaly_subgraphs[dataset_id] = SubgraphInfo(subgraphs)
400
+ logger.info("Generated %d sample subgraphs for %s", len(subgraphs), dataset_id)
401
+ except Exception:
402
+ logger.exception("Failed to generate sample subgraphs for %s", dataset_id)
403
+
404
+ def _build_sample_subgraphs(self, dataset_id, loader, num_subgraphs=20, max_graph_size=10):
405
+ """Build sample subgraphs using the Sampler's DFS-based context subgraph partitioning."""
406
+ inv_nodes, _, inv_relations = loader.dataset.get_inverted_name_maps()
407
+ node_types = loader.dataset.node_data.type.values
408
+
409
+ # Use the Sampler's DFS partitioning to get context subgraphs
410
+ samples = loader.sampler.get_context_subgraph_samples_dfs(
411
+ max_graph_size, loader.graph_indexes, loader.num_nodes,
412
+ max_samples=num_subgraphs * 5, disable_tqdm=True,
413
+ )
414
+
415
+ subgraphs = []
416
+ for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples:
417
+ if len(subgraphs) >= num_subgraphs:
418
+ break
419
+ if len(edges) < 3:
420
+ continue
421
+
422
+ if subgraph_row == subgraph_col:
423
+ sg_nodes = nodes_row
424
+ else:
425
+ sg_nodes = nodes_row + nodes_col
426
+
427
+ node_idx = {n: i for i, n in enumerate(sg_nodes)}
428
+
429
+ nodes = []
430
+ for n in sg_nodes:
431
+ type_id = int(node_types[n]) if n < len(node_types) else 0
432
+ nodes.append({
433
+ "entity_id": n,
434
+ "entity_name": str(inv_nodes.get(n, n)),
435
+ "type_id": type_id,
436
+ })
437
+
438
+ edge_list = []
439
+ for h, r, t in edges:
440
+ if h in node_idx and t in node_idx:
441
+ edge_list.append({
442
+ "source_idx": node_idx[h],
443
+ "target_idx": node_idx[t],
444
+ "relation_id": r,
445
+ "relation_name": str(inv_relations.get(r, r)),
446
+ "entity_name_source": str(inv_nodes.get(h, h)),
447
+ "entity_name_target": str(inv_nodes.get(t, t)),
448
+ })
449
+
450
+ subgraphs.append({
451
+ "id": f"sample_{len(subgraphs) + 1}",
452
+ "num_nodes": len(nodes),
453
+ "num_edges": len(edge_list),
454
+ "nodes": nodes,
455
+ "edges": edge_list,
456
+ })
457
+
458
+ # Free the partitioning data stored on the sampler
459
+ loader.sampler.context_subgraphs_nodes = None
460
+ loader.sampler.context_subgraphs_edges = None
461
+
462
+ return subgraphs
463
+
464
+ # ---- Status --------------------------------------------------------
465
+
466
  def is_coins_loaded(self):
467
  return bool(self.coins_checkpoints_available)
468
 
src/backend/api/urls.py CHANGED
@@ -1,8 +1,33 @@
1
  from django.urls import path
2
 
3
- from api.views.health import ApiRootView, HealthView
4
 
5
  urlpatterns = [
 
6
  path("", ApiRootView.as_view()),
7
  path("health", HealthView.as_view()),
8
  ]
 
1
  from django.urls import path
2
 
3
+ from api.views.coins import (
4
+ CoinsDatasetsView,
5
+ CoinsEntitiesView,
6
+ CoinsModelsView,
7
+ CoinsQueryStructuresView,
8
+ CoinsRelationsView,
9
+ CoinsSampleTriplesView,
10
+ )
11
+ from api.views.graph_generation import GraphGenDatasetsView, GraphGenSamplingModesView
12
+ from api.views.health import ApiRootView, HealthView, MethodsView
13
+ from api.views.kg_anomaly import KgAnomalyDatasetsView, KgAnomalySampleSubgraphsView
14
 
15
  urlpatterns = [
16
+ # Health & discovery
17
  path("", ApiRootView.as_view()),
18
  path("health", HealthView.as_view()),
19
+ path("methods", MethodsView.as_view()),
20
+ # COINs
21
+ path("coins/datasets", CoinsDatasetsView.as_view()),
22
+ path("coins/datasets/<str:dataset_id>/entities", CoinsEntitiesView.as_view()),
23
+ path("coins/datasets/<str:dataset_id>/relations", CoinsRelationsView.as_view()),
24
+ path("coins/datasets/<str:dataset_id>/sample-triples", CoinsSampleTriplesView.as_view()),
25
+ path("coins/models", CoinsModelsView.as_view()),
26
+ path("coins/query-structures", CoinsQueryStructuresView.as_view()),
27
+ # Graph generation
28
+ path("graph-generation/datasets", GraphGenDatasetsView.as_view()),
29
+ path("graph-generation/sampling-modes", GraphGenSamplingModesView.as_view()),
30
+ # KG anomaly
31
+ path("kg-anomaly/datasets", KgAnomalyDatasetsView.as_view()),
32
+ path("kg-anomaly/datasets/<str:dataset_id>/sample-subgraphs", KgAnomalySampleSubgraphsView.as_view()),
33
  ]
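A quick smoke test for the new routes, as a sketch: it assumes `api.urls` is mounted at the URL root (prepend the prefix, e.g. `/api`, if it is not) and that `research_api.settings` is importable.

```
# Hypothetical smoke test for the discovery endpoints using Django's test client.
import os

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "research_api.settings")

import django

django.setup()

from django.test import Client

client = Client()
for route in (
    "/health",
    "/methods",
    "/coins/datasets",
    "/coins/models",
    "/coins/query-structures",
    "/graph-generation/datasets",
    "/graph-generation/sampling-modes",
    "/kg-anomaly/datasets",
):
    response = client.get(route)
    assert response.status_code == 200, (route, response.status_code)
```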
src/backend/api/views/coins.py ADDED
@@ -0,0 +1,106 @@
1
+ from rest_framework.response import Response
2
+ from rest_framework.views import APIView
3
+
4
+ from api.exceptions import NotFoundError
5
+ from api.services.constants import COINS_DATASET_META, COINS_MODELS, QUERY_STRUCTURES
6
+ from api.services.registry import ModelRegistry
7
+
8
+
9
+ def _require_loader(dataset_id):
10
+ """Validate dataset_id and ensure its Loader is available."""
11
+ if dataset_id not in COINS_DATASET_META:
12
+ raise NotFoundError(f"Dataset '{dataset_id}' not found")
13
+ registry = ModelRegistry.get()
14
+ if registry.get_loader(dataset_id) is None:
15
+ raise NotFoundError(f"Dataset '{dataset_id}' data not loaded")
16
+ return registry
17
+
18
+
19
+ class CoinsDatasetsView(APIView):
20
+ def get(self, request):
21
+ registry = ModelRegistry.get()
22
+ datasets = []
23
+ for dataset_id, meta in COINS_DATASET_META.items():
24
+ datasets.append({
25
+ "id": dataset_id,
26
+ "name": meta["name"],
27
+ "num_entities": registry.get_entity_count(dataset_id),
28
+ "num_relations": registry.get_relation_count(dataset_id),
29
+ "description": meta["description"],
30
+ })
31
+ return Response({"datasets": datasets})
32
+
33
+
34
+ class CoinsEntitiesView(APIView):
35
+ def get(self, request, dataset_id):
36
+ registry = _require_loader(dataset_id)
37
+ q = request.query_params.get("q", None)
38
+ page = max(1, int(request.query_params.get("page", 1)))
39
+ page_size = int(request.query_params.get("page_size", 50))
40
+ page_size = max(1, min(200, page_size))
41
+
42
+ page_items, total = registry.search_entities(dataset_id, q, page, page_size)
43
+
44
+ return Response({
45
+ "dataset_id": dataset_id,
46
+ "total": total,
47
+ "page": page,
48
+ "page_size": page_size,
49
+ "entities": [{"id": eid, "name": name} for eid, name in page_items],
50
+ })
51
+
52
+
53
+ class CoinsRelationsView(APIView):
54
+ def get(self, request, dataset_id):
55
+ registry = _require_loader(dataset_id)
56
+ q = request.query_params.get("q", None)
57
+ page = max(1, int(request.query_params.get("page", 1)))
58
+ page_size = int(request.query_params.get("page_size", 50))
59
+ page_size = max(1, min(200, page_size))
60
+
61
+ page_items, total = registry.search_relations(dataset_id, q, page, page_size)
62
+
63
+ return Response({
64
+ "dataset_id": dataset_id,
65
+ "total": total,
66
+ "page": page,
67
+ "page_size": page_size,
68
+ "relations": [{"id": rid, "name": name} for rid, name in page_items],
69
+ })
70
+
71
+
72
+ class CoinsSampleTriplesView(APIView):
73
+ def get(self, request, dataset_id):
74
+ registry = _require_loader(dataset_id)
75
+ count = int(request.query_params.get("count", 10))
76
+ count = max(1, min(50, count))
77
+
78
+ return Response({
79
+ "dataset_id": dataset_id,
80
+ "triples": registry.sample_triples(dataset_id, count),
81
+ })
82
+
83
+
84
+ class CoinsModelsView(APIView):
85
+ def get(self, request):
86
+ registry = ModelRegistry.get()
87
+ models = []
88
+ for model in COINS_MODELS:
89
+ available_datasets = []
90
+ for dataset_id in COINS_DATASET_META:
91
+ algos = registry.coins_checkpoints_available.get(dataset_id, [])
92
+ if model["algorithm"] in algos:
93
+ available_datasets.append(dataset_id)
94
+ models.append({
95
+ "algorithm": model["algorithm"],
96
+ "name": model["name"],
97
+ "description": model["description"],
98
+ "supported_query_structures": model["supported_query_structures"],
99
+ "available_datasets": available_datasets,
100
+ })
101
+ return Response({"models": models})
102
+
103
+
104
+ class CoinsQueryStructuresView(APIView):
105
+ def get(self, request):
106
+ return Response({"query_structures": QUERY_STRUCTURES})
src/backend/api/views/graph_generation.py ADDED
@@ -0,0 +1,29 @@
1
+ from rest_framework.response import Response
2
+ from rest_framework.views import APIView
3
+
4
+ from api.services.constants import GRAPHGEN_DATASETS, GRAPHGEN_SAMPLING_MODES
5
+ from api.services.registry import ModelRegistry
6
+
7
+
8
+ class GraphGenDatasetsView(APIView):
9
+ def get(self, request):
10
+ registry = ModelRegistry.get()
11
+ datasets = []
12
+ for dataset_id, meta in GRAPHGEN_DATASETS.items():
13
+ available_types = registry.graphgen_checkpoints_available.get(dataset_id, [])
14
+ datasets.append({
15
+ "id": dataset_id,
16
+ "name": meta["name"],
17
+ "type": meta["type"],
18
+ "description": meta["description"],
19
+ "node_types": meta["node_types"],
20
+ "edge_types": meta["edge_types"],
21
+ "max_nodes": meta["max_nodes"],
22
+ "available_model_types": available_types,
23
+ })
24
+ return Response({"datasets": datasets})
25
+
26
+
27
+ class GraphGenSamplingModesView(APIView):
28
+ def get(self, request):
29
+ return Response({"sampling_modes": GRAPHGEN_SAMPLING_MODES})
src/backend/api/views/health.py CHANGED
@@ -1,7 +1,7 @@
1
  from rest_framework.response import Response
2
- from rest_framework.reverse import reverse
3
  from rest_framework.views import APIView
4
 
 
5
  from api.services.registry import ModelRegistry
6
 
7
 
@@ -50,3 +50,8 @@ class HealthView(APIView):
50
  "kg_anomaly": registry.is_kg_anomaly_loaded(),
51
  },
52
  })
 
1
  from rest_framework.response import Response
 
2
  from rest_framework.views import APIView
3
 
4
+ from api.services.constants import METHODS
5
  from api.services.registry import ModelRegistry
6
 
7
 
 
50
  "kg_anomaly": registry.is_kg_anomaly_loaded(),
51
  },
52
  })
53
+
54
+
55
+ class MethodsView(APIView):
56
+ def get(self, request):
57
+ return Response({"methods": METHODS})
src/backend/api/views/kg_anomaly.py ADDED
@@ -0,0 +1,42 @@
1
+ from rest_framework.response import Response
2
+ from rest_framework.views import APIView
3
+
4
+ from api.exceptions import NotFoundError
5
+ from api.services.constants import KG_ANOMALY_DATASET_META
6
+ from api.services.registry import ModelRegistry
7
+
8
+
9
+ class KgAnomalyDatasetsView(APIView):
10
+ def get(self, request):
11
+ registry = ModelRegistry.get()
12
+ datasets = []
13
+ for dataset_id, meta in KG_ANOMALY_DATASET_META.items():
14
+ available = registry.kg_anomaly_checkpoints_available.get(dataset_id, [])
15
+ datasets.append({
16
+ "id": dataset_id,
17
+ "name": meta["name"],
18
+ "description": meta["description"],
19
+ "available_tasks": available,
20
+ })
21
+ return Response({"datasets": datasets})
22
+
23
+
24
+ class KgAnomalySampleSubgraphsView(APIView):
25
+ def get(self, request, dataset_id):
26
+ if dataset_id not in KG_ANOMALY_DATASET_META:
27
+ raise NotFoundError(f"Dataset '{dataset_id}' not found")
28
+
29
+ registry = ModelRegistry.get()
30
+ sg_info = registry.kg_anomaly_subgraphs.get(dataset_id)
31
+ if sg_info is None:
32
+ raise NotFoundError(f"No sample subgraphs available for dataset '{dataset_id}'")
33
+
34
+ count = int(request.query_params.get("count", 5))
35
+ count = max(1, min(10, count))
36
+
37
+ subgraphs = sg_info.subgraphs[:count]
38
+
39
+ return Response({
40
+ "dataset_id": dataset_id,
41
+ "subgraphs": subgraphs,
42
+ })
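End to end, a request such as `GET /kg-anomaly/datasets/<dataset_id>/sample-subgraphs?count=1` would return a body shaped like this (dataset id and counts are illustrative; the subgraph dicts are the ones built in `registry._build_sample_subgraphs`):

```
# Illustrative response body; "fb15k-237" is a placeholder dataset id.
response_body = {
    "dataset_id": "fb15k-237",
    "subgraphs": [
        {
            "id": "sample_1",
            "num_nodes": 5,
            "num_edges": 4,
            "nodes": [],  # entity dicts as built in _build_sample_subgraphs
            "edges": [],  # edge dicts with resolved entity/relation names
        },
    ],
}
```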
src/backend/requirements.txt ADDED
@@ -0,0 +1,36 @@
1
+ # Django
2
+ django==4.2.*
3
+ djangorestframework==3.14.*
4
+ django-cors-headers==4.*
5
+
6
+ # Checkpoint download from Google Drive
7
+ gdown>=4.7
8
+
9
+ # PyTorch with CUDA 11.8 (falls back to CPU at runtime if no GPU present)
10
+ --extra-index-url https://download.pytorch.org/whl/cu118
11
+ torch==2.0.1+cu118
12
+ torchvision==0.15.2+cu118
13
+ torchaudio==2.0.2+cu118
14
+
15
+ # PyTorch Geometric (must match torch version)
16
+ torch-geometric==2.3.1
17
+
18
+ # Research shared deps
19
+ pytorch-lightning==2.0.4
20
+ hydra-core==1.3.2
21
+ omegaconf==2.3.0
22
+ numpy==1.23
23
+ pandas==1.4
24
+ scipy==1.11.0
25
+ igraph==0.9.11
26
+ PyYAML==6.0
27
+ networkx==2.8.7
28
+ matplotlib==3.7.1
29
+ imageio==2.31.1
30
+ torchmetrics==0.11.4
31
+ tqdm==4.65.0
32
+ scikit-learn>=1.0
33
+
34
+ # Conda-only deps (must be pre-installed, not in pip requirements):
35
+ # rdkit==2023.03.2 (conda create -c conda-forge)
36
+ # graph-tool==2.45 (conda install -c conda-forge)
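The cu118 wheels do run on CPU-only hosts; a one-liner to confirm the fallback after installing:

```
# Verify the CUDA build degrades to CPU when no GPU is present.
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(torch.__version__, device)  # e.g. "2.0.1+cu118 cpu" on a CPU-only host
```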
src/backend/research_api/settings.py CHANGED
@@ -1,9 +1,17 @@
1
  import os
 
2
  from pathlib import Path
3
 
4
  BASE_DIR = Path(__file__).resolve().parent.parent
5
  PROJECT_ROOT = BASE_DIR.parent.parent # Website root
6
 
7
  SECRET_KEY = os.environ.get("DJANGO_SECRET_KEY", "dev-insecure-key-change-in-production")
8
  DEBUG = os.environ.get("DJANGO_DEBUG", "True").lower() in ("true", "1")
9
  ALLOWED_HOSTS = os.environ.get("DJANGO_ALLOWED_HOSTS", "localhost,127.0.0.1").split(",")
 
1
  import os
2
+ import sys
3
  from pathlib import Path
4
 
5
  BASE_DIR = Path(__file__).resolve().parent.parent
6
  PROJECT_ROOT = BASE_DIR.parent.parent # Website root
7
 
8
+ # Add research repos to sys.path so their modules can be imported
9
+ _COINS_KG_ROOT = str(PROJECT_ROOT / "src" / "research" / "COINs-KGGeneration")
10
+ _MULTIPROXAN_ROOT = str(PROJECT_ROOT / "src" / "research" / "MultiProxAn")
11
+ for _path in (_COINS_KG_ROOT, _MULTIPROXAN_ROOT):
12
+ if _path not in sys.path:
13
+ sys.path.insert(0, _path)
14
+
15
  SECRET_KEY = os.environ.get("DJANGO_SECRET_KEY", "dev-insecure-key-change-in-production")
16
  DEBUG = os.environ.get("DJANGO_DEBUG", "True").lower() in ("true", "1")
17
  ALLOWED_HOSTS = os.environ.get("DJANGO_ALLOWED_HOSTS", "localhost,127.0.0.1").split(",")
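To confirm the `sys.path` wiring, a sketch (run it from the directory the research code expects as its working directory, since the repos use relative `data/` paths):

```
# Sanity check: importing settings should make the research packages importable.
import os

os.environ.setdefault("DJANGO_SETTINGS_MODULE", "research_api.settings")

import django

django.setup()  # loads settings, which prepends the research repos to sys.path

from graph_completion.graphs.load_graph import Loader  # from COINs-KGGeneration

print(sorted(Loader.datasets))
```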
src/research/COINs-KGGeneration/graph_completion/graphs/load_graph.py CHANGED
@@ -23,14 +23,14 @@ from graph_completion.graphs.queries import balanced_partition, balanced_partiti
23
  from graph_completion.utils import AbstractConf
24
  from graph_data.load_data import CoDExL, Dataset, FreeBase, HighEnergyPhysicsCitations, NELL, \
25
  OGBBioKG, OGBCitation2, OGBLSCWikiKG90Mv2, \
26
- PatentCitations, RedditHyperlinks, Swisscom, SwisscomBig, Transport, WordNet, YAGO310
27
  from graph_data.serialization import load_object, save_object
28
 
29
 
30
  class Loader:
31
  datasets = {"reddit": RedditHyperlinks(), "hep": HighEnergyPhysicsCitations(), "patent": PatentCitations(),
32
  "ogbl-citation2": OGBCitation2(), "ogbl-biokg": OGBBioKG(), "ogb-lsc-wikikg90m2": OGBLSCWikiKG90Mv2(),
33
- "swisscom": SwisscomBig(), "freebase": FreeBase(), "wordnet": WordNet(), "nell": NELL(),
34
  "codex-l": CoDExL(), "yago3-10": YAGO310()}
35
  transport_cities = [path.split(os_path_sep)[-1] for path in glob("data/transport/*") if "." not in path[2:]]
36
  for city in transport_cities:
@@ -74,7 +74,7 @@ class Loader:
74
  self.inter_community_map: np.ndarray = None
75
  self.com_neighbours: Optional[np.ndarray] = None
76
  self.node_neighbours: Optional[np.ndarray] = None
77
- self.num_machines = min(num_machines, device_count())
78
  self.machines: np.ndarray = None
79
 
80
  self.graph_indexes: Iterable[AdjacencyIndex] = None
@@ -96,37 +96,8 @@ class Loader:
96
  self.leiden_resolution = leiden_resolution
97
 
98
  self.dataset = Loader.datasets[self.dataset_name]
99
-
100
- if self.dataset_name == "swisscom" and glob("data/swisscom/relations_orgs*"):
101
- self.dataset.node_names_map = pd.read_csv(
102
- f"data/swisscom/node_names_orgs.csv",
103
- header=0, index_col=0, squeeze=True, encoding="utf-8"
104
- )
105
- self.dataset.node_data = pd.read_csv(
106
- f"data/swisscom/nodes_orgs.csv",
107
- header=0, index_col=None, parse_dates=["time", ], encoding="utf-8"
108
- )
109
- self.dataset.edge_data = pd.read_csv(
110
- f"data/swisscom/relations_orgs.csv",
111
- header=0, index_col=None, parse_dates=["time", ], encoding="utf-8"
112
- )
113
- elif self.dataset_name == "swisscom":
114
- entrypoints = pd.concat([
115
- pd.read_csv(f"{customer_orgs_file}", header=0, squeeze=True, encoding="utf-8")
116
- for customer_orgs_file in glob("data/swisscom/customer_orgs_*")
117
- ], ignore_index=True).drop_duplicates().sort_values(ignore_index=True)
118
- swisscom_subgraphs = [Swisscom(entrypoint) for entrypoint in entrypoints]
119
- self.dataset.load_from_disk(swisscom_subgraphs)
120
- self.dataset.node_names_map.to_csv(
121
- f"data/swisscom/node_names_orgs.csv",
122
- header=True, index=True, encoding="utf-8")
123
- self.dataset.node_data.to_csv(f"data/swisscom/nodes_orgs.csv",
124
- header=True, index=False, encoding="utf-8")
125
- self.dataset.edge_data.to_csv(f"data/swisscom/relations_orgs.csv",
126
- header=True, index=False, encoding="utf-8")
127
- else:
128
- self.dataset.load_from_disk()
129
- self.dataset.time_sort_and_numerize()
130
  self.graph = Graph(self.dataset_name)
131
  self.graph.update_graph(self.dataset.node_data, self.dataset.edge_data)
132
  print("Computing required metrics...")
 
23
  from graph_completion.utils import AbstractConf
24
  from graph_data.load_data import CoDExL, Dataset, FreeBase, HighEnergyPhysicsCitations, NELL, \
25
  OGBBioKG, OGBCitation2, OGBLSCWikiKG90Mv2, \
26
+ PatentCitations, RedditHyperlinks, Transport, WordNet, YAGO310
27
  from graph_data.serialization import load_object, save_object
28
 
29
 
30
  class Loader:
31
  datasets = {"reddit": RedditHyperlinks(), "hep": HighEnergyPhysicsCitations(), "patent": PatentCitations(),
32
  "ogbl-citation2": OGBCitation2(), "ogbl-biokg": OGBBioKG(), "ogb-lsc-wikikg90m2": OGBLSCWikiKG90Mv2(),
33
+ "freebase": FreeBase(), "wordnet": WordNet(), "nell": NELL(),
34
  "codex-l": CoDExL(), "yago3-10": YAGO310()}
35
  transport_cities = [path.split(os_path_sep)[-1] for path in glob("data/transport/*") if "." not in path[2:]]
36
  for city in transport_cities:
 
74
  self.inter_community_map: np.ndarray = None
75
  self.com_neighbours: Optional[np.ndarray] = None
76
  self.node_neighbours: Optional[np.ndarray] = None
77
+ self.num_machines = max(1, min(num_machines, device_count() or 1))
78
  self.machines: np.ndarray = None
79
 
80
  self.graph_indexes: Iterable[AdjacencyIndex] = None
 
96
  self.leiden_resolution = leiden_resolution
97
 
98
  self.dataset = Loader.datasets[self.dataset_name]
99
+ self.dataset.load_from_disk()
100
+ self.dataset.time_sort_and_numerize()
 
101
  self.graph = Graph(self.dataset_name)
102
  self.graph.update_graph(self.dataset.node_data, self.dataset.edge_data)
103
  print("Computing required metrics...")
src/research/COINs-KGGeneration/graph_completion/graphs/preprocess.py CHANGED
@@ -741,15 +741,21 @@ class Sampler:
741
 
742
  def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
743
  num_nodes: int, allow_disc: bool = False,
 
744
  disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
745
  _, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
746
  assignment = -np.ones(num_nodes, dtype=int)
747
  subgraph = 0
748
  subgraph_size = 0
 
 
 
749
  progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
750
  disable=disable_tqdm)
751
 
752
  for i in range(num_nodes):
 
 
753
  if assignment[i] >= 0:
754
  continue
755
  stack = [i, ]
@@ -759,6 +765,7 @@ class Sampler:
759
  continue
760
  assignment[v] = subgraph
761
  subgraph_size += 1
 
762
  if subgraph_size == max_graph_size:
763
  subgraph_size = 0
764
  subgraph += 1
@@ -781,11 +788,16 @@ class Sampler:
781
  progress_bar.close()
782
  subgraphs_nodes = [[] for _ in range(subgraph)]
783
  subgraphs_edges = [[[] for _ in range(subgraph)] for _ in range(subgraph)]
784
- for v in tqdm(range(num_nodes), desc="Assigning nodes to subgraphs", leave=False, disable=disable_tqdm):
 
785
  subgraphs_nodes[assignment[v]].append(v)
786
  for s, n_s in tqdm(adj_s_to_t.items(), desc="Assigning edges to subgraphs", leave=False, disable=disable_tqdm):
 
 
787
  for r, n_s_r in n_s.items():
788
  for t in n_s_r:
 
 
789
  subgraphs_edges[assignment[s]][assignment[t]].append((s, r, t))
790
  samples = [(k, l, subgraphs_nodes[k], subgraphs_nodes[l], subgraphs_edges[k][l])
791
  for k in range(subgraph) for l in range(subgraph) if len(subgraphs_edges[k][l]) >= 5]
 
741
 
742
  def get_context_subgraph_samples_dfs(self, max_graph_size: int, graph_indexes: Iterable[AdjacencyIndex],
743
  num_nodes: int, allow_disc: bool = False,
744
+ max_samples: int = 0,
745
  disable_tqdm: bool = False) -> Iterable[ContextSubgraph]:
746
  _, adj_s_to_t, adj_t_to_s, _, _ = graph_indexes
747
  assignment = -np.ones(num_nodes, dtype=int)
748
  subgraph = 0
749
  subgraph_size = 0
750
+ nodes_assigned = 0
751
+ # If max_samples is set, only partition enough nodes to produce that many subgraphs
752
+ max_subgraphs = max_samples * 2 if max_samples > 0 else 0
753
  progress_bar = tqdm(desc="Assigning nodes to context subgraphs", total=num_nodes, leave=False,
754
  disable=disable_tqdm)
755
 
756
  for i in range(num_nodes):
757
+ if max_subgraphs > 0 and subgraph >= max_subgraphs:
758
+ break
759
  if assignment[i] >= 0:
760
  continue
761
  stack = [i, ]
 
765
  continue
766
  assignment[v] = subgraph
767
  subgraph_size += 1
768
+ nodes_assigned += 1
769
  if subgraph_size == max_graph_size:
770
  subgraph_size = 0
771
  subgraph += 1
 
788
  progress_bar.close()
789
  subgraphs_nodes = [[] for _ in range(subgraph)]
790
  subgraphs_edges = [[[] for _ in range(subgraph)] for _ in range(subgraph)]
791
+ assigned_nodes = np.flatnonzero(assignment >= 0)  # ascending order, matching the original full range scan
792
+ for v in tqdm(assigned_nodes, desc="Assigning nodes to subgraphs", leave=False, disable=disable_tqdm):
793
  subgraphs_nodes[assignment[v]].append(v)
794
  for s, n_s in tqdm(adj_s_to_t.items(), desc="Assigning edges to subgraphs", leave=False, disable=disable_tqdm):
795
+ if assignment[s] < 0:
796
+ continue
797
  for r, n_s_r in n_s.items():
798
  for t in n_s_r:
799
+ if assignment[t] < 0:
800
+ continue
801
  subgraphs_edges[assignment[s]][assignment[t]].append((s, r, t))
802
  samples = [(k, l, subgraphs_nodes[k], subgraphs_nodes[l], subgraphs_edges[k][l])
803
  for k in range(subgraph) for l in range(subgraph) if len(subgraphs_edges[k][l]) >= 5]
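Typical use of the new `max_samples` cap, assuming a `loader` initialised as in the registry startup code (so `sampler`, `graph_indexes` and `num_nodes` are populated):

```
# Partition only enough of the graph for ~100 samples instead of all nodes.
samples = loader.sampler.get_context_subgraph_samples_dfs(
    max_graph_size=10,
    graph_indexes=loader.graph_indexes,
    num_nodes=loader.num_nodes,
    max_samples=100,  # caps the DFS at 2 * 100 subgraphs
    disable_tqdm=True,
)
for subgraph_row, subgraph_col, nodes_row, nodes_col, edges in samples[:3]:
    print(subgraph_row, subgraph_col, len(nodes_row), len(edges))
```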
src/research/COINs-KGGeneration/graph_data/load_data.py CHANGED
@@ -327,7 +327,7 @@ class Transport(Dataset):
327
 
328
  class FreeBase(Dataset):
329
  def __init__(self):
330
- super().__init__("fb15k-237", "data/academic/FB15k-237")
331
 
332
  def load_from_disk(self, *args, **kwargs):
333
  train_edges = pd.read_csv(f"{self.file_location}/train.txt",
@@ -366,7 +366,7 @@ class FreeBase(Dataset):
366
 
367
  class WordNet(Dataset):
368
  def __init__(self):
369
- super().__init__("wn18rr", "data/academic/WN18RR")
370
 
371
  def load_from_disk(self, *args, **kwargs):
372
  train_edges = pd.read_csv(f"{self.file_location}/train.txt",
@@ -405,7 +405,7 @@ class WordNet(Dataset):
405
 
406
  class NELL(Dataset):
407
  def __init__(self):
408
- super().__init__("nell-995", "data/academic/NELL-995")
409
 
410
  def load_from_disk(self, *args, **kwargs):
411
  edge_data = pd.read_csv(f"{self.file_location}/raw.kb",
@@ -525,115 +525,6 @@ class YAGO310(Dataset):
525
  return self.clean_split(train_edges), self.clean_split(valid_edges), self.clean_split(test_edges)
526
 
527
 
528
- swisscom_node_types_map = pd.read_csv(f"data/swisscom_node_types.csv",
529
- header=None, index_col=None, encoding="utf-8",
530
- squeeze=True).reset_index(name="type").set_index("type")["index"]
531
- swisscom_relation_types_map = pd.read_csv(f"data/swisscom_relation_types.csv",
532
- header=None, index_col=None, encoding="utf-8",
533
- squeeze=True).reset_index(name="r").set_index("r")["index"]
534
-
535
-
536
- class Swisscom(Dataset):
537
- def __init__(self, customer_org: str):
538
- super().__init__(f"swisscom-{customer_org}", f"data/swisscom/{customer_org}")
539
- self.customer_org = customer_org
540
- self.node_types_map = swisscom_node_types_map
541
- self.relation_names_map = swisscom_relation_types_map
542
-
543
- def load_from_disk(self):
544
- def compute_final_node_label(row):
545
- if "Alert" in str(row["labels"]):
546
- return "Alert"
547
- elif str(row["labels"]) in ["Base", "-", ""] or pd.isna(row["n.inventory_type"]):
548
- return "-"
549
- else:
550
- return str(row["n.inventory_type"]).capitalize()
551
-
552
- node_data = pd.read_csv(glob(f"{self.file_location}/nodes_*")[-1],
553
- sep=",", header=0, index_col=None, encoding="utf-8",
554
- usecols=["n.id", "labels", "n.inventory_type"])
555
- node_data = node_data.assign(type=node_data.apply(compute_final_node_label, axis=1), time=pd.NaT)
556
- self.node_data = node_data.drop(columns=["labels", "n.inventory_type"]).rename(columns={"n.id": "n"})
557
- del node_data
558
-
559
- edge_data = pd.read_csv(glob(f"{self.file_location}/relations_*")[-1],
560
- sep=",", header=0, index_col=None, encoding="utf-8",
561
- usecols=["r.source_item_id", "relation_type_enriched", "r.target_item_id"])
562
- self.edge_data = edge_data.rename(columns={"r.source_item_id": "s",
563
- "relation_type_enriched": "r",
564
- "r.target_item_id": "t"}).assign(time=pd.NaT).dropna(how="all")
565
- del edge_data
566
-
567
-
568
- class SwisscomBig(Dataset):
569
- def __init__(self):
570
- super().__init__(f"swisscom-big", f"data/swisscom/swisscom-big")
571
- self.node_types_map = swisscom_node_types_map
572
- self.relation_names_map = swisscom_relation_types_map
573
-
574
- def load_from_disk(self, subgraphs: List[Swisscom]):
575
- self.node_data = pd.DataFrame(columns=["n", "type", "time"]).astype(
576
- {"n": "int", "type": "int", "time": "datetime64[ns]"}
577
- )
578
- self.edge_data = pd.DataFrame(columns=["s", "r", "t", "time"]).astype(
579
- {"s": "int", "r": "int", "t": "int", "time": "datetime64[ns]"}
580
- )
581
- self.node_names_map = pd.Series(name="index")
582
- self.node_names_map.index.name = "n"
583
-
584
- for dataset in tqdm(subgraphs, desc="Merging subgraphs into the big graph", leave=False):
585
- dataset.load_from_disk()
586
- is_new_node = ~dataset.node_data.n.isin(self.node_names_map.index)
587
- new_nodes = dataset.node_data.loc[is_new_node]
588
- self.node_names_map = self.node_names_map.append(
589
- len(self.node_names_map)
590
- + new_nodes.n.reset_index(name="n", drop=True).reset_index().set_index("n")["index"]
591
- )
592
- dataset.node_data.loc[is_new_node, "n"] = new_nodes.n.map(self.node_names_map)
593
- dataset.node_data.loc[is_new_node, "type"] = new_nodes.type.map(self.node_types_map)
594
- self.node_data = pd.concat((self.node_data, dataset.node_data.loc[is_new_node]), ignore_index=True)
595
- new_edges = dataset.edge_data
596
- new_edges.s = new_edges.s.map(self.node_names_map.rename("s"))
597
- new_edges.r = new_edges.r.map(self.relation_names_map)
598
- new_edges.t = new_edges.t.map(self.node_names_map.rename("t"))
599
- if len(self.edge_data) > 0:
600
- new_edges = new_edges.loc[~new_edges.s.isin(self.edge_data.s)
601
- | ~new_edges.t.isin(self.edge_data.t)
602
- | ~new_edges.r.isin(self.edge_data.r)]
603
- self.edge_data = pd.concat((self.edge_data, new_edges), ignore_index=True)
604
- dataset.unload_from_memory()
605
-
606
-
607
- class SwisscomCommunity(Dataset):
608
- def __init__(self, community_id: int):
609
- self.community_id = community_id
610
- super().__init__(f"swisscom-community-{community_id}", f"data/swisscom/community-{community_id}")
611
- self.node_types_map = swisscom_node_types_map
612
- self.relation_names_map = swisscom_relation_types_map
613
-
614
- def load_from_disk(self):
615
- return
616
-
617
-
618
- class SwisscomSubgraphOverlap(Dataset):
619
- def __init__(self, customer_org: str, customer_org_2: str):
620
- self.customer_org = customer_org
621
- self.customer_org_2 = customer_org_2
622
- super().__init__(f"swisscom-subgraph-overlap-{customer_org}-{customer_org_2}",
623
- f"data/swisscom/subgraph-overlap-{customer_org}-{customer_org_2}")
624
- self.node_types_map = swisscom_node_types_map
625
- self.relation_names_map = swisscom_relation_types_map
626
-
627
- def load_from_disk(self, subgraph_1: Swisscom, subgraph_2: Swisscom):
628
- subgraph_1.load_from_disk()
629
- subgraph_2.load_from_disk()
630
- self.node_data = subgraph_1.node_data.merge(subgraph_2.node_data, how="inner", on=["n", "type"])
631
- self.node_data.drop(columns="time_y", inplace=True)
632
- self.node_data.rename(columns={"time_x": "time"}, inplace=True)
633
-
634
- self.edge_data = subgraph_1.edge_data.merge(subgraph_2.edge_data, how="inner", on=["s", "r", "t"])
635
- self.edge_data.drop(columns="time_y", inplace=True)
636
- self.edge_data.rename(columns={"time_x": "time"}, inplace=True)
637
  subgraph_1.unload_from_memory()
638
  subgraph_2.unload_from_memory()
639
 
 
327
 
328
  class FreeBase(Dataset):
329
  def __init__(self):
330
+ super().__init__("fb15k-237", "data/FB15k-237")
331
 
332
  def load_from_disk(self, *args, **kwargs):
333
  train_edges = pd.read_csv(f"{self.file_location}/train.txt",
 
366
 
367
  class WordNet(Dataset):
368
  def __init__(self):
369
+ super().__init__("wn18rr", "data/WN18RR")
370
 
371
  def load_from_disk(self, *args, **kwargs):
372
  train_edges = pd.read_csv(f"{self.file_location}/train.txt",
 
405
 
406
  class NELL(Dataset):
407
  def __init__(self):
408
+ super().__init__("nell-995", "data/NELL-995")
409
 
410
  def load_from_disk(self, *args, **kwargs):
411
  edge_data = pd.read_csv(f"{self.file_location}/raw.kb",
 
525
  return self.clean_split(train_edges), self.clean_split(valid_edges), self.clean_split(test_edges)
526
 
527
 
528
  subgraph_1.unload_from_memory()
529
  subgraph_2.unload_from_memory()
530
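With the `data/academic/...` prefixes dropped, the academic KG datasets now resolve directly under `data/`; assuming `Dataset` stores its second constructor argument as `file_location` (as the read paths in `load_from_disk` suggest):

```
# After this change the datasets point at data/<name> rather than data/academic/<name>.
from graph_data.load_data import FreeBase, NELL, WordNet

for dataset in (FreeBase(), WordNet(), NELL()):
    print(dataset.file_location)  # data/FB15k-237, data/WN18RR, data/NELL-995
```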