Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +11 -0
- GCond/.gitignore +140 -0
- GCond/GCond.png +0 -0
- GCond/KDD22_DosCond/README.md +38 -0
- GCond/README.md +139 -0
- GCond/__pycache__/configs.cpython-312.pyc +0 -0
- GCond/__pycache__/utils.cpython-312.pyc +0 -0
- GCond/__pycache__/utils_graphsaint.cpython-312.pyc +0 -0
- GCond/configs.py +24 -0
- GCond/coreset/__init__.py +3 -0
- GCond/coreset/__pycache__/__init__.cpython-312.pyc +0 -0
- GCond/coreset/__pycache__/all_methods.cpython-312.pyc +0 -0
- GCond/coreset/all_methods.py +212 -0
- GCond/gcond_agent_induct.py +327 -0
- GCond/gcond_agent_transduct.py +326 -0
- GCond/models/__pycache__/gcn.cpython-312.pyc +0 -0
- GCond/models/gat.py +312 -0
- GCond/models/gcn.py +404 -0
- GCond/models/myappnp.py +344 -0
- GCond/models/myappnp1.py +348 -0
- GCond/models/mycheby.py +417 -0
- GCond/models/mygatconv.py +203 -0
- GCond/models/mygraphsage.py +353 -0
- GCond/models/parametrized_adj.py +88 -0
- GCond/models/sgc.py +290 -0
- GCond/models/sgc_multi.py +315 -0
- GCond/requirements.txt +11 -0
- GCond/res/cross/empty +1 -0
- GCond/script.sh +4 -0
- GCond/scripts/run_cross.sh +16 -0
- GCond/scripts/run_main.sh +26 -0
- GCond/scripts/script_cross.sh +7 -0
- GCond/test_other_arcs.py +55 -0
- GCond/tester_other_arcs.py +258 -0
- GCond/train_coreset.py +117 -0
- GCond/train_coreset_induct.py +119 -0
- GCond/train_gcond_induct.py +61 -0
- GCond/train_gcond_transduct.py +57 -0
- GCond/utils.py +383 -0
- GCond/utils_graphsaint.py +220 -0
- requirements.txt +2 -2
- src/2.1_lrmc_bilevel.py +2 -2
- src/2.2_lrmc_bilevel.py +2 -1
- src/2.3_lrmc_bilevel.py +2 -1
- src/2.4_lrmc_bilevel.py +2 -1
- src/2.5_lrmc_bilevel.py +540 -0
- src/2.6_lrmc_summary.py +691 -0
- src/2_epsilon_seed_sweep.py +215 -0
- src/2_lrmc_bilevel.py +2 -2
- src/2_random_seed_sweep.sh +47 -0
.gitignore
CHANGED
|
@@ -47,3 +47,14 @@ GraphUNets/data/*
|
|
| 47 |
temp*
|
| 48 |
|
| 49 |
*.tex
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
temp*
|
| 48 |
|
| 49 |
*.tex
|
| 50 |
+
|
| 51 |
+
src/data/*
|
| 52 |
+
|
| 53 |
+
src/cora_seeds/*
|
| 54 |
+
|
| 55 |
+
.venv/*
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
.ipynb_checkpoints/
|
| 59 |
+
|
| 60 |
+
*.dot
|
GCond/.gitignore
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Auxiliary file on MacOS
|
| 2 |
+
.DS_Store
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
models/__pycache__/
|
| 7 |
+
*.py[co
|
| 8 |
+
dataset
|
| 9 |
+
|
| 10 |
+
# C extensions
|
| 11 |
+
*.so
|
| 12 |
+
|
| 13 |
+
# Distribution / packaging
|
| 14 |
+
.Python
|
| 15 |
+
data/
|
| 16 |
+
build/
|
| 17 |
+
develop-eggs/
|
| 18 |
+
dist/
|
| 19 |
+
downloads/
|
| 20 |
+
eggs/
|
| 21 |
+
.eggs/
|
| 22 |
+
lib/
|
| 23 |
+
lib64/
|
| 24 |
+
parts/
|
| 25 |
+
sdist/
|
| 26 |
+
var/
|
| 27 |
+
wheels/
|
| 28 |
+
pip-wheel-metadata/
|
| 29 |
+
share/python-wheels/
|
| 30 |
+
*.egg-info/
|
| 31 |
+
.installed.cfg
|
| 32 |
+
*.egg
|
| 33 |
+
MANIFEST
|
| 34 |
+
|
| 35 |
+
# PyInstaller
|
| 36 |
+
# Usually these files are written by a python script from a template
|
| 37 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 38 |
+
*.manifest
|
| 39 |
+
*.spec
|
| 40 |
+
|
| 41 |
+
# Installer logs
|
| 42 |
+
pip-log.txt
|
| 43 |
+
pip-delete-this-directory.txt
|
| 44 |
+
|
| 45 |
+
# Unit test / coverage reports
|
| 46 |
+
htmlcov/
|
| 47 |
+
.tox/
|
| 48 |
+
.nox/
|
| 49 |
+
.coverage
|
| 50 |
+
.coverage.*
|
| 51 |
+
.cache
|
| 52 |
+
nosetests.xml
|
| 53 |
+
coverage.xml
|
| 54 |
+
*.cover
|
| 55 |
+
*.py,cover
|
| 56 |
+
.hypothesis/
|
| 57 |
+
.pytest_cache/
|
| 58 |
+
|
| 59 |
+
# Translations
|
| 60 |
+
*.mo
|
| 61 |
+
*.pot
|
| 62 |
+
|
| 63 |
+
# Django stuff:
|
| 64 |
+
*.log
|
| 65 |
+
local_settings.py
|
| 66 |
+
db.sqlite3
|
| 67 |
+
db.sqlite3-journal
|
| 68 |
+
|
| 69 |
+
# Flask stuff:
|
| 70 |
+
instance/
|
| 71 |
+
.webassets-cache
|
| 72 |
+
|
| 73 |
+
# Scrapy stuff:
|
| 74 |
+
.scrapy
|
| 75 |
+
|
| 76 |
+
# Sphinx documentation
|
| 77 |
+
docs/_build/
|
| 78 |
+
|
| 79 |
+
# PyBuilder
|
| 80 |
+
target/
|
| 81 |
+
|
| 82 |
+
# Jupyter Notebook
|
| 83 |
+
.ipynb_checkpoints
|
| 84 |
+
|
| 85 |
+
# IPython
|
| 86 |
+
profile_default/
|
| 87 |
+
ipython_config.py
|
| 88 |
+
|
| 89 |
+
# pyenv
|
| 90 |
+
.python-version
|
| 91 |
+
|
| 92 |
+
# pipenv
|
| 93 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 94 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 95 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 96 |
+
# install all needed dependencies.
|
| 97 |
+
#Pipfile.lock
|
| 98 |
+
|
| 99 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 100 |
+
__pypackages__/
|
| 101 |
+
|
| 102 |
+
# Celery stuff
|
| 103 |
+
celerybeat-schedule
|
| 104 |
+
celerybeat.pid
|
| 105 |
+
|
| 106 |
+
# SageMath parsed files
|
| 107 |
+
*.sage.py
|
| 108 |
+
|
| 109 |
+
# Environments
|
| 110 |
+
.env
|
| 111 |
+
.venv
|
| 112 |
+
env/
|
| 113 |
+
venv/
|
| 114 |
+
ENV/
|
| 115 |
+
env.bak/
|
| 116 |
+
venv.bak/
|
| 117 |
+
|
| 118 |
+
# Spyder project settings
|
| 119 |
+
.spyderproject
|
| 120 |
+
.spyproject
|
| 121 |
+
|
| 122 |
+
# Rope project settings
|
| 123 |
+
.ropeproject
|
| 124 |
+
|
| 125 |
+
# mkdocs documentation
|
| 126 |
+
/site
|
| 127 |
+
|
| 128 |
+
# mypy
|
| 129 |
+
.mypy_cache/
|
| 130 |
+
.dmypy.json
|
| 131 |
+
dmypy.json
|
| 132 |
+
|
| 133 |
+
# Pyre type checker
|
| 134 |
+
.pyre/
|
| 135 |
+
*$py.class
|
| 136 |
+
|
| 137 |
+
def __init__(self):
|
| 138 |
+
pass
|
| 139 |
+
|
| 140 |
+
|
GCond/GCond.png
ADDED
|
GCond/KDD22_DosCond/README.md
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# DosCond
|
| 2 |
+
|
| 3 |
+
[KDD 2022] A PyTorch Implementation for ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746) under node classification setting. For graph classification setting, please refer to [https://github.com/amazon-research/DosCond](https://github.com/amazon-research/DosCond).
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
Abstract
|
| 7 |
+
----
|
| 8 |
+
As training deep learning models on large dataset takes a lot of time and resources, it is desired to construct a small synthetic dataset with which we can train deep learning models sufficiently. There are recent works that have explored solutions on condensing image datasets through complex bi-level optimization. For instance, dataset condensation (DC) matches network gradients w.r.t. large-real data and small-synthetic data, where the network weights are optimized for multiple steps at each outer iteration. However, existing approaches have their inherent limitations: (1) they are not directly applicable to graphs where the data is discrete; and (2) the condensation process is computationally expensive due to the involved nested optimization. To bridge the gap, we investigate efficient dataset condensation tailored for graph datasets where we model the discrete graph structure as a probabilistic model. We further propose a one-step gradient matching scheme, which performs gradient matching for only one single step without training the network weights.
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
Here we do not implement the discrete structure learning, but only borrow the idea from ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746) to perform one-step gradient matching, which significantly speeds up the condensation process.
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
Essentially, we can run the following commands:
|
| 15 |
+
```
|
| 16 |
+
python train_gcond_transduct.py --dataset citeseer --nlayers=2 --lr_feat=1e-3 --lr_adj=1e-3 --r=0.5 --sgc=0 --dis=mse --one_step=1 --epochs=3000
|
| 17 |
+
python train_gcond_transduct.py --dataset cora --nlayers=2 --lr_feat=1e-3 --lr_adj=1e-3 --r=0.5 --sgc=0 --dis=mse --gpu_id=2 --one_step=1 --epochs=5000
|
| 18 |
+
python train_gcond_transduct.py --dataset pubmed --nlayers=2 --lr_feat=1e-3 --lr_adj=1e-3 --r=0.5 --sgc=0 --dis=mse --gpu_id=2 --one_step=1 --epochs=2000
|
| 19 |
+
python train_gcond_transduct.py --dataset ogbn-arxiv --nlayers=2 --lr_feat=1e-2 --lr_adj=2e-2 --r=0.001 --sgc=1 --dis=ours --gpu_id=2 --one_step=1 --epochs=1000
|
| 20 |
+
python train_gcond_induct.py --dataset flickr --nlayers=2 --lr_feat=5e-3 --lr_adj=5e-3 --r=0.001 --sgc=0 --dis=mse --gpu_id=3 --one_step=1 --epochs=1000
|
| 21 |
+
```
|
| 22 |
+
Note that using smaller learning rate and larger epochs can get even higher performance.
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
## Cite
|
| 26 |
+
For more information, you can take a look at the [paper](https://arxiv.org/abs/2206.07746).
|
| 27 |
+
|
| 28 |
+
If you find this repo to be useful, please cite our paper. Thank you.
|
| 29 |
+
```
|
| 30 |
+
@inproceedings{jin2022condensing,
|
| 31 |
+
title={Condensing Graphs via One-Step Gradient Matching},
|
| 32 |
+
author={Jin, Wei and Tang, Xianfeng and Jiang, Haoming and Li, Zheng and Zhang, Danqing and Tang, Jiliang and Yin, Bing},
|
| 33 |
+
booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
|
| 34 |
+
pages={720--730},
|
| 35 |
+
year={2022}
|
| 36 |
+
}
|
| 37 |
+
```
|
| 38 |
+
|
GCond/README.md
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# GCond
|
| 2 |
+
[ICLR 2022] The PyTorch implementation for ["Graph Condensation for Graph Neural Networks"](https://cse.msu.edu/~jinwei2/files/GCond.pdf) is provided under the main directory.
|
| 3 |
+
|
| 4 |
+
[KDD 2022] The implementation for ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746) is shown in the `KDD22_DosCond` directory. See [link](https://github.com/ChandlerBang/GCond/tree/main/KDD22_DosCond).
|
| 5 |
+
|
| 6 |
+
[IJCAI 2024] Please read our recent survey ["A Comprehensive Survey on Graph Reduction: Sparsification, Coarsening, and Condensation"](https://arxiv.org/abs/2402.03358) for a detailed review of graph reduction techniques!
|
| 7 |
+
|
| 8 |
+
[ArXiv 2024] We released a benchmarking framework for graph condensation ["GC4NC: A Benchmark Framework for Graph Condensation with New Insights"](https://arxiv.org/abs/2406.16715), including **robustness**, **privacy preservation**, NAS performance, property analysis, etc!
|
| 9 |
+
|
| 10 |
+
Abstract
|
| 11 |
+
----
|
| 12 |
+
We propose and study the problem of graph condensation for graph neural networks (GNNs). Specifically, we aim to condense the large, original graph into a small, synthetic, and highly-informative graph, such that GNNs trained on the small graph and large graph have comparable performance. Extensive experiments have demonstrated the effectiveness of the proposed framework in condensing different graph datasets into informative smaller graphs. In particular, we are able to approximate the original test accuracy by 95.3% on Reddit, 99.8% on Flickr and 99.0% on Citeseer, while reducing their graph size by more than 99.9%, and the condensed graphs can be used to train various GNN architectures.
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
![]()
|
| 16 |
+
|
| 17 |
+
<div align=center><img src="https://github.com/ChandlerBang/GCond/blob/main/GCond.png" width="800"/></div>
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
## A Nice Survey Paper
|
| 21 |
+
Please check out our survey paper below, which summarizes the recent advances in graph condensation.
|
| 22 |
+
|
| 23 |
+
[[A Comprehensive Survey on Graph Reduction: Sparsification, Coarsening, and Condensation]](https://arxiv.org/abs/2402.03358)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
## Requirements
|
| 28 |
+
Please see [requirements.txt](https://github.com/ChandlerBang/GCond/blob/main/requirements.txt).
|
| 29 |
+
```
|
| 30 |
+
torch==1.7.0
|
| 31 |
+
torch_geometric==1.6.3
|
| 32 |
+
scipy==1.6.2
|
| 33 |
+
numpy==1.19.2
|
| 34 |
+
ogb==1.3.0
|
| 35 |
+
tqdm==4.59.0
|
| 36 |
+
torch_sparse==0.6.9
|
| 37 |
+
deeprobust==0.2.4
|
| 38 |
+
scikit_learn==1.0.2
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Download Datasets
|
| 42 |
+
For cora, citeseer and pubmed, the code will directly download them; so no extra script is needed.
|
| 43 |
+
For reddit, flickr and arxiv, we use the datasets provided by [GraphSAINT](https://github.com/GraphSAINT/GraphSAINT).
|
| 44 |
+
They are available on [Google Drive link](https://drive.google.com/open?id=1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz) (alternatively, [BaiduYun link (code: f1ao)](https://pan.baidu.com/s/1SOb0SiSAXavwAcNqkttwcg)). Rename the folder to `data` at the root directory. Note that the links are provided by GraphSAINT team.
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
## Run the code
|
| 50 |
+
For transductive setting, please run the following command:
|
| 51 |
+
```
|
| 52 |
+
python train_gcond_transduct.py --dataset cora --nlayers=2 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=0.5
|
| 53 |
+
```
|
| 54 |
+
where `r` indicates the ratio of condensed samples to the labeled samples. For instance, there are only 140 labeled nodes in Cora dataset, so `r=0.5` indicates the number of condensed samples are 70, **which corresponds to r=2.6%=70/2710 in our paper**. Thus, the parameter `r` is different from the real reduction rate in the paper for the transductive setting, please see the following table for the correspondence.
|
| 55 |
+
|
| 56 |
+
| | `r` in the code | `r` in the paper (real reduction rate) |
|
| 57 |
+
|--------------|-------------------|---------------------|
|
| 58 |
+
| Transductive | Cora, r=0.5 | Cora, r=2.6% |
|
| 59 |
+
| Transductive | Citeseer, r=0.5 | Citeseer, r=1.8% |
|
| 60 |
+
| Transductive | Ogbn-arxiv, r= 0.005 | Ogbn-arxiv, r=0.25% |
|
| 61 |
+
| Transductive | Pubmed, r=0.5 | Pubmed, r=0.3% |
|
| 62 |
+
| Inductive | Flickr, r=0.01 | Flickr, r=1% |
|
| 63 |
+
| Inductive | Reddit, r=0.001 | Reddit, r=0.1% |
|
| 64 |
+
|
| 65 |
+
For inductive setting, please run the following command:
|
| 66 |
+
```
|
| 67 |
+
python train_gcond_induct.py --dataset flickr --nlayers=2 --lr_feat=0.01 --gpu_id=0 --lr_adj=0.01 --r=0.005 --epochs=1000 --outer=10 --inner=1
|
| 68 |
+
```
|
| 69 |
+
|
| 70 |
+
## Reproduce the performance
|
| 71 |
+
The generated graphs are saved in the folder `saved_ours`; you can directly load them to test the performance.
|
| 72 |
+
|
| 73 |
+
For Table 2, run `bash scripts/run_main.sh`.
|
| 74 |
+
|
| 75 |
+
For Table 3, run `bash scripts/run_cross.sh`.
|
| 76 |
+
|
| 77 |
+
## [Faster Condensation!] One-Step Gradient Matching
|
| 78 |
+
From the KDD'22 paper ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746), we know that performing gradient matching for only one step can also achieve a good performance while significantly accelerating the condensation process. Hence, we can run the following command to perform one-step gradient matching, which is essentially much faster than the original version:
|
| 79 |
+
```
|
| 80 |
+
python train_gcond_transduct.py --dataset citeseer --nlayers=2 --lr_feat=1e-2 --lr_adj=1e-2 --r=0.5 \
|
| 81 |
+
--sgc=0 --dis=mse --gpu_id=2 --one_step=1 --epochs=3000
|
| 82 |
+
```
|
| 83 |
+
For more commands, please go to [`KDD22_DosCond`](https://github.com/ChandlerBang/GCond/tree/main/KDD22_DosCond).
|
| 84 |
+
|
| 85 |
+
**[Note]: I found that sometimes using MSE loss for gradient matching can be more stable than using `ours` loss**, and it gives more flexibility on the model used in condensation (using GCN as the backbone can also generate good condensed graphs).
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
## Whole Dataset Performance
|
| 89 |
+
When we do coreset selection, we need to first train the model on the whole dataset. Thus we can obtain the performance of the whole dataset by running `train_coreset.py` and `train_coreset_induct.py`:
|
| 90 |
+
```
|
| 91 |
+
python train_coreset.py --dataset cora --r=0.01 --method=random
|
| 92 |
+
python train_coreset_induct.py --dataset flickr --r=0.01 --method=random
|
| 93 |
+
```
|
| 94 |
+
|
| 95 |
+
## Coreset Performance
|
| 96 |
+
Run the following code to get the coreset performance for transductive setting.
|
| 97 |
+
```
|
| 98 |
+
python train_coreset.py --dataset cora --r=0.01 --method=herding
|
| 99 |
+
python train_coreset.py --dataset cora --r=0.01 --method=random
|
| 100 |
+
python train_coreset.py --dataset cora --r=0.01 --method=kcenter
|
| 101 |
+
```
|
| 102 |
+
Similarly, run the following code for the inductive setting.
|
| 103 |
+
```
|
| 104 |
+
python train_coreset_induct.py --dataset flickr --r=0.01 --method=kcenter
|
| 105 |
+
```
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
## Cite
|
| 109 |
+
If you find this repo to be useful, please cite our three papers. Thank you!
|
| 110 |
+
```
|
| 111 |
+
@inproceedings{
|
| 112 |
+
jin2022graph,
|
| 113 |
+
title={Graph Condensation for Graph Neural Networks},
|
| 114 |
+
author={Wei Jin and Lingxiao Zhao and Shichang Zhang and Yozen Liu and Jiliang Tang and Neil Shah},
|
| 115 |
+
booktitle={International Conference on Learning Representations},
|
| 116 |
+
year={2022},
|
| 117 |
+
url={https://openreview.net/forum?id=WLEx3Jo4QaB}
|
| 118 |
+
}
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
```
|
| 122 |
+
@inproceedings{jin2022condensing,
|
| 123 |
+
title={Condensing Graphs via One-Step Gradient Matching},
|
| 124 |
+
author={Jin, Wei and Tang, Xianfeng and Jiang, Haoming and Li, Zheng and Zhang, Danqing and Tang, Jiliang and Yin, Bing},
|
| 125 |
+
booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
|
| 126 |
+
pages={720--730},
|
| 127 |
+
year={2022}
|
| 128 |
+
}
|
| 129 |
+
```
|
| 130 |
+
|
| 131 |
+
```
|
| 132 |
+
@article{hashemi2024comprehensive,
|
| 133 |
+
title={A Comprehensive Survey on Graph Reduction: Sparsification, Coarsening, and Condensation},
|
| 134 |
+
author={Hashemi, Mohammad and Gong, Shengbo and Ni, Juntong and Fan, Wenqi and Prakash, B Aditya and Jin, Wei},
|
| 135 |
+
journal={International Joint Conference on Artificial Intelligence (IJCAI)},
|
| 136 |
+
year={2024}
|
| 137 |
+
}
|
| 138 |
+
```
|
| 139 |
+
|
GCond/__pycache__/configs.cpython-312.pyc
ADDED
|
Binary file (706 Bytes). View file
|
|
|
GCond/__pycache__/utils.cpython-312.pyc
ADDED
|
Binary file (20.5 kB). View file
|
|
|
GCond/__pycache__/utils_graphsaint.cpython-312.pyc
ADDED
|
Binary file (12.2 kB). View file
|
|
|
GCond/configs.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''Configuration'''
|
| 2 |
+
|
| 3 |
+
def load_config(args):
|
| 4 |
+
dataset = args.dataset
|
| 5 |
+
if dataset in ['flickr']:
|
| 6 |
+
args.nlayers = 2
|
| 7 |
+
args.hidden = 256
|
| 8 |
+
args.weight_decay = 5e-3
|
| 9 |
+
args.dropout = 0.0
|
| 10 |
+
|
| 11 |
+
if dataset in ['reddit']:
|
| 12 |
+
args.nlayers = 2
|
| 13 |
+
args.hidden = 256
|
| 14 |
+
args.weight_decay = 0e-4
|
| 15 |
+
args.dropout = 0
|
| 16 |
+
|
| 17 |
+
if dataset in ['ogbn-arxiv']:
|
| 18 |
+
args.hidden = 256
|
| 19 |
+
args.weight_decay = 0
|
| 20 |
+
args.dropout = 0
|
| 21 |
+
|
| 22 |
+
return args
|
| 23 |
+
|
| 24 |
+
|
GCond/coreset/__init__.py
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .all_methods import KCenter, Herding, Random, LRMC
|
| 2 |
+
|
| 3 |
+
__all__ = ['KCenter', 'Herding', 'Random', 'LRMC']
|
GCond/coreset/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (260 Bytes). View file
|
|
|
GCond/coreset/__pycache__/all_methods.cpython-312.pyc
ADDED
|
Binary file (12.1 kB). View file
|
|
|
GCond/coreset/all_methods.py
ADDED
|
@@ -0,0 +1,212 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
import json
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class Base:
|
| 7 |
+
|
| 8 |
+
def __init__(self, data, args, device='cuda', **kwargs):
|
| 9 |
+
self.data = data
|
| 10 |
+
self.args = args
|
| 11 |
+
self.device = device
|
| 12 |
+
n = int(data.feat_train.shape[0] * args.reduction_rate)
|
| 13 |
+
d = data.feat_train.shape[1]
|
| 14 |
+
self.nnodes_syn = n
|
| 15 |
+
self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)
|
| 16 |
+
|
| 17 |
+
def generate_labels_syn(self, data):
|
| 18 |
+
from collections import Counter
|
| 19 |
+
counter = Counter(data.labels_train)
|
| 20 |
+
num_class_dict = {}
|
| 21 |
+
n = len(data.labels_train)
|
| 22 |
+
|
| 23 |
+
sorted_counter = sorted(counter.items(), key=lambda x:x[1])
|
| 24 |
+
sum_ = 0
|
| 25 |
+
labels_syn = []
|
| 26 |
+
self.syn_class_indices = {}
|
| 27 |
+
for ix, (c, num) in enumerate(sorted_counter):
|
| 28 |
+
if ix == len(sorted_counter) - 1:
|
| 29 |
+
num_class_dict[c] = int(n * self.args.reduction_rate) - sum_
|
| 30 |
+
self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
|
| 31 |
+
labels_syn += [c] * num_class_dict[c]
|
| 32 |
+
else:
|
| 33 |
+
num_class_dict[c] = max(int(num * self.args.reduction_rate), 1)
|
| 34 |
+
sum_ += num_class_dict[c]
|
| 35 |
+
self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
|
| 36 |
+
labels_syn += [c] * num_class_dict[c]
|
| 37 |
+
|
| 38 |
+
self.num_class_dict = num_class_dict
|
| 39 |
+
return labels_syn
|
| 40 |
+
|
| 41 |
+
def select(self):
|
| 42 |
+
return
|
| 43 |
+
|
| 44 |
+
class KCenter(Base):
|
| 45 |
+
|
| 46 |
+
def __init__(self, data, args, device='cuda', **kwargs):
|
| 47 |
+
super(KCenter, self).__init__(data, args, device='cuda', **kwargs)
|
| 48 |
+
|
| 49 |
+
def select(self, embeds, inductive=False):
|
| 50 |
+
# feature: embeds
|
| 51 |
+
# kcenter # class by class
|
| 52 |
+
num_class_dict = self.num_class_dict
|
| 53 |
+
if inductive:
|
| 54 |
+
idx_train = np.arange(len(self.data.idx_train))
|
| 55 |
+
else:
|
| 56 |
+
idx_train = self.data.idx_train
|
| 57 |
+
labels_train = self.data.labels_train
|
| 58 |
+
idx_selected = []
|
| 59 |
+
|
| 60 |
+
for class_id, cnt in num_class_dict.items():
|
| 61 |
+
idx = idx_train[labels_train==class_id]
|
| 62 |
+
feature = embeds[idx]
|
| 63 |
+
mean = torch.mean(feature, dim=0, keepdim=True)
|
| 64 |
+
# dis = distance(feature, mean)[:,0]
|
| 65 |
+
dis = torch.cdist(feature, mean)[:,0]
|
| 66 |
+
rank = torch.argsort(dis)
|
| 67 |
+
idx_centers = rank[:1].tolist()
|
| 68 |
+
for i in range(cnt-1):
|
| 69 |
+
feature_centers = feature[idx_centers]
|
| 70 |
+
dis_center = torch.cdist(feature, feature_centers)
|
| 71 |
+
dis_min, _ = torch.min(dis_center, dim=-1)
|
| 72 |
+
id_max = torch.argmax(dis_min).item()
|
| 73 |
+
idx_centers.append(id_max)
|
| 74 |
+
|
| 75 |
+
idx_selected.append(idx[idx_centers])
|
| 76 |
+
# return np.array(idx_selected).reshape(-1)
|
| 77 |
+
return np.hstack(idx_selected)
|
| 78 |
+
|
| 79 |
+
|
| 80 |
+
class Herding(Base):
|
| 81 |
+
|
| 82 |
+
def __init__(self, data, args, device='cuda', **kwargs):
|
| 83 |
+
super(Herding, self).__init__(data, args, device='cuda', **kwargs)
|
| 84 |
+
|
| 85 |
+
def select(self, embeds, inductive=False):
|
| 86 |
+
num_class_dict = self.num_class_dict
|
| 87 |
+
if inductive:
|
| 88 |
+
idx_train = np.arange(len(self.data.idx_train))
|
| 89 |
+
else:
|
| 90 |
+
idx_train = self.data.idx_train
|
| 91 |
+
labels_train = self.data.labels_train
|
| 92 |
+
idx_selected = []
|
| 93 |
+
|
| 94 |
+
# herding # class by class
|
| 95 |
+
for class_id, cnt in num_class_dict.items():
|
| 96 |
+
idx = idx_train[labels_train==class_id]
|
| 97 |
+
features = embeds[idx]
|
| 98 |
+
mean = torch.mean(features, dim=0, keepdim=True)
|
| 99 |
+
selected = []
|
| 100 |
+
idx_left = np.arange(features.shape[0]).tolist()
|
| 101 |
+
|
| 102 |
+
for i in range(cnt):
|
| 103 |
+
det = mean*(i+1) - torch.sum(features[selected], dim=0)
|
| 104 |
+
dis = torch.cdist(det, features[idx_left])
|
| 105 |
+
id_min = torch.argmin(dis)
|
| 106 |
+
selected.append(idx_left[id_min])
|
| 107 |
+
del idx_left[id_min]
|
| 108 |
+
idx_selected.append(idx[selected])
|
| 109 |
+
# return np.array(idx_selected).reshape(-1)
|
| 110 |
+
return np.hstack(idx_selected)
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
class Random(Base):
|
| 114 |
+
|
| 115 |
+
def __init__(self, data, args, device='cuda', **kwargs):
|
| 116 |
+
super(Random, self).__init__(data, args, device='cuda', **kwargs)
|
| 117 |
+
|
| 118 |
+
def select(self, embeds, inductive=False):
|
| 119 |
+
num_class_dict = self.num_class_dict
|
| 120 |
+
if inductive:
|
| 121 |
+
idx_train = np.arange(len(self.data.idx_train))
|
| 122 |
+
else:
|
| 123 |
+
idx_train = self.data.idx_train
|
| 124 |
+
|
| 125 |
+
labels_train = self.data.labels_train
|
| 126 |
+
idx_selected = []
|
| 127 |
+
|
| 128 |
+
for class_id, cnt in num_class_dict.items():
|
| 129 |
+
idx = idx_train[labels_train==class_id]
|
| 130 |
+
selected = np.random.permutation(idx)
|
| 131 |
+
idx_selected.append(selected[:cnt])
|
| 132 |
+
|
| 133 |
+
# return np.array(idx_selected).reshape(-1)
|
| 134 |
+
return np.hstack(idx_selected)
|
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class LRMC(Base):
|
| 138 |
+
"""
|
| 139 |
+
Coreset selection using precomputed seed nodes from the Laplacian‑Integrated
|
| 140 |
+
Relaxed Maximal Clique (L‑RMC) algorithm. Seed nodes are read from a JSON
|
| 141 |
+
file specified by ``args.lrmc_seeds_path`` and used to preferentially select
|
| 142 |
+
training examples. Per‑class reduction counts are respected: if a class has
|
| 143 |
+
fewer seeds than required, random training nodes from that class are added
|
| 144 |
+
until the quota is met.
|
| 145 |
+
"""
|
| 146 |
+
|
| 147 |
+
def __init__(self, data, args, device='cuda', **kwargs):
|
| 148 |
+
super(LRMC, self).__init__(data, args, device=device, **kwargs)
|
| 149 |
+
seeds_path = getattr(args, 'lrmc_seeds_path', None)
|
| 150 |
+
if seeds_path is None:
|
| 151 |
+
raise ValueError(
|
| 152 |
+
"LRMC method selected but no path to seed file provided. "
|
| 153 |
+
"Please specify --lrmc_seeds_path when running the training script."
|
| 154 |
+
)
|
| 155 |
+
self.seed_nodes = self._load_seed_nodes(seeds_path)
|
| 156 |
+
|
| 157 |
+
def _load_seed_nodes(self, path: str):
|
| 158 |
+
# Parse seed nodes from JSON file (supports 'seed_nodes' or 'members').
|
| 159 |
+
with open(path, 'r') as f:
|
| 160 |
+
js = json.load(f)
|
| 161 |
+
clusters = js.get('clusters', [])
|
| 162 |
+
if not clusters:
|
| 163 |
+
raise ValueError(f"No clusters found in L‑RMC seeds file {path}")
|
| 164 |
+
def _cluster_length(c):
|
| 165 |
+
nodes = c.get('seed_nodes') or c.get('members') or []
|
| 166 |
+
return len(nodes)
|
| 167 |
+
best_cluster = max(clusters, key=_cluster_length)
|
| 168 |
+
nodes = best_cluster.get('seed_nodes') or best_cluster.get('members') or []
|
| 169 |
+
seed_nodes = []
|
| 170 |
+
for u in nodes:
|
| 171 |
+
try:
|
| 172 |
+
uid = int(u)
|
| 173 |
+
except Exception:
|
| 174 |
+
continue
|
| 175 |
+
zero_idx = uid - 1
|
| 176 |
+
if zero_idx >= 0:
|
| 177 |
+
seed_nodes.append(zero_idx)
|
| 178 |
+
else:
|
| 179 |
+
if uid >= 0:
|
| 180 |
+
seed_nodes.append(uid)
|
| 181 |
+
seed_nodes = sorted(set(seed_nodes))
|
| 182 |
+
return seed_nodes
|
| 183 |
+
|
| 184 |
+
def select(self, embeds, inductive=False):
|
| 185 |
+
# Determine training indices depending on the inductive setting.
|
| 186 |
+
if inductive:
|
| 187 |
+
idx_train = np.arange(len(self.data.idx_train))
|
| 188 |
+
labels_train = self.data.labels_train
|
| 189 |
+
else:
|
| 190 |
+
idx_train = self.data.idx_train
|
| 191 |
+
labels_train = self.data.labels_train
|
| 192 |
+
num_class_dict = self.num_class_dict
|
| 193 |
+
idx_selected = []
|
| 194 |
+
seed_set = set(self.seed_nodes)
|
| 195 |
+
# Pick seed nodes per class; fill remainder with random nodes if needed.
|
| 196 |
+
for class_id, cnt in num_class_dict.items():
|
| 197 |
+
class_mask = (labels_train == class_id)
|
| 198 |
+
class_indices = idx_train[class_mask]
|
| 199 |
+
seed_in_class = [u for u in class_indices if u in seed_set]
|
| 200 |
+
selected = seed_in_class[:min(len(seed_in_class), cnt)]
|
| 201 |
+
remaining_required = cnt - len(selected)
|
| 202 |
+
if remaining_required > 0:
|
| 203 |
+
remaining_candidates = [u for u in class_indices if u not in selected]
|
| 204 |
+
if len(remaining_candidates) <= remaining_required:
|
| 205 |
+
additional = remaining_candidates
|
| 206 |
+
else:
|
| 207 |
+
additional = np.random.choice(remaining_candidates, remaining_required, replace=False).tolist()
|
| 208 |
+
selected += additional
|
| 209 |
+
idx_selected.append(np.array(selected))
|
| 210 |
+
return np.hstack(idx_selected)
|
| 211 |
+
|
| 212 |
+
|
GCond/gcond_agent_induct.py
ADDED
|
@@ -0,0 +1,327 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.nn import Parameter
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from utils import match_loss, regularization, row_normalize_tensor
|
| 8 |
+
import deeprobust.graph.utils as utils
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from models.gcn import GCN
|
| 13 |
+
from models.sgc import SGC
|
| 14 |
+
from models.sgc_multi import SGC as SGC1
|
| 15 |
+
from models.parametrized_adj import PGE
|
| 16 |
+
import scipy.sparse as sp
|
| 17 |
+
from torch_sparse import SparseTensor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GCond:
    """Graph condensation agent (inductive setting).

    Learns a small synthetic graph (features ``feat_syn``, structure
    generator ``pge``, fixed labels ``labels_syn``) whose training
    gradients match those of the real training subgraph.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        self.data = data
        self.args = args
        self.device = device

        # Synthetic node count is a fixed fraction of the training set.
        n = int(len(data.idx_train) * args.reduction_rate)
        d = data.feat_train.shape[1]
        self.nnodes_syn = n
        # Learnable synthetic node features.
        self.feat_syn = nn.Parameter(torch.FloatTensor(n, d).to(device))
        # PGE maps synthetic features to a synthetic adjacency matrix.
        self.pge = PGE(nfeat=d, nnodes=n, device=device, args=args).to(device)

        self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)
        self.reset_parameters()
        # Separate optimizers: features and structure are updated in
        # alternating phases (see train()).
        self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=args.lr_feat)
        self.optimizer_pge = torch.optim.Adam(self.pge.parameters(), lr=args.lr_adj)
        print('adj_syn:', (n,n), 'feat_syn:', self.feat_syn.shape)

    def reset_parameters(self):
        # Random-normal initialization; real features are copied in later
        # by train() via get_sub_adj_feat().
        self.feat_syn.data.copy_(torch.randn(self.feat_syn.size()))

    def generate_labels_syn(self, data):
        """Build the synthetic label list, keeping class proportions.

        Classes are processed from rarest to most common; each gets
        ``max(int(count * reduction_rate), 1)`` synthetic nodes, and the
        most common class absorbs the rounding remainder so the total is
        exactly ``int(n * reduction_rate)``. Also records
        ``self.syn_class_indices`` (contiguous [start, end) per class)
        and ``self.num_class_dict``.
        """
        from collections import Counter
        counter = Counter(data.labels_train)
        num_class_dict = {}
        n = len(data.labels_train)

        sorted_counter = sorted(counter.items(), key=lambda x:x[1])
        sum_ = 0
        labels_syn = []
        self.syn_class_indices = {}

        for ix, (c, num) in enumerate(sorted_counter):
            if ix == len(sorted_counter) - 1:
                # Last (largest) class takes whatever budget remains.
                num_class_dict[c] = int(n * self.args.reduction_rate) - sum_
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]
            else:
                num_class_dict[c] = max(int(num * self.args.reduction_rate), 1)
                sum_ += num_class_dict[c]
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]

        self.num_class_dict = num_class_dict
        return labels_syn

    def test_with_val(self, verbose=True):
        """Train a fresh GCN on the synthetic graph and report test accuracy.

        Returns a list ``res`` of accuracies (test accuracy, then a
        placeholder train accuracy of 0).
        """
        res = []

        data, device = self.data, self.device
        feat_syn, pge, labels_syn = self.feat_syn.detach(), \
                self.pge, self.labels_syn
        # with_bn = True if args.dataset in ['ogbn-arxiv'] else False
        dropout = 0.5 if self.args.dataset in ['reddit'] else 0
        model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=dropout,
                    weight_decay=5e-4, nlayers=2,
                    nclass=data.nclass, device=device).to(device)

        adj_syn = pge.inference(feat_syn)
        args = self.args

        if args.save:
            torch.save(adj_syn, f'saved_ours/adj_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')
            torch.save(feat_syn, f'saved_ours/feat_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')

        # noval=True: inductive evaluation uses no validation split here.
        noval = True
        model.fit_with_val(feat_syn, adj_syn, labels_syn, data,
                     train_iters=600, normalize=True, verbose=False, noval=noval)

        model.eval()
        labels_test = torch.LongTensor(data.labels_test).cuda()

        output = model.predict(data.feat_test, data.adj_test)

        loss_test = F.nll_loss(output, labels_test)
        acc_test = utils.accuracy(output, labels_test)
        res.append(acc_test.item())
        if verbose:
            print("Test set results:",
                  "loss= {:.4f}".format(loss_test.item()),
                  "accuracy= {:.4f}".format(acc_test.item()))
        # Edge mass and density of the learned synthetic adjacency.
        print(adj_syn.sum(), adj_syn.sum()/(adj_syn.shape[0]**2))

        # Dead code: thresholding of adj_syn is disabled.
        if False:
            if self.args.dataset == 'ogbn-arxiv':
                thresh = 0.6
            elif self.args.dataset == 'reddit':
                thresh = 0.91
            else:
                thresh = 0.7

        labels_train = torch.LongTensor(data.labels_train).cuda()
        output = model.predict(data.feat_train, data.adj_train)
        # loss_train = F.nll_loss(output, labels_train)
        # acc_train = utils.accuracy(output, labels_train)
        # Train metrics are intentionally zeroed (skipped for speed).
        loss_train = torch.tensor(0)
        acc_train = torch.tensor(0)
        if verbose:
            print("Train set results:",
                  "loss= {:.4f}".format(loss_train.item()),
                  "accuracy= {:.4f}".format(acc_train.item()))
        res.append(acc_train.item())
        return res

    def train(self, verbose=True):
        """Optimize the synthetic graph by gradient matching.

        Alternates between (a) outer steps that match per-class GNN
        gradients on real vs. synthetic data and update the synthetic
        graph, and (b) inner steps that train the GNN on the current
        synthetic graph. Periodically evaluates via test_with_val().
        """
        args = self.args
        data = self.data
        feat_syn, pge, labels_syn = self.feat_syn, self.pge, self.labels_syn
        features, adj, labels = data.feat_train, data.adj_train, data.labels_train
        syn_class_indices = self.syn_class_indices
        features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        # Initialize synthetic features from real per-class samples.
        feat_sub, adj_sub = self.get_sub_adj_feat(features)
        self.feat_syn.data.copy_(feat_sub)

        if utils.is_sparse_tensor(adj):
            adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
        else:
            adj_norm = utils.normalize_adj_tensor(adj)

        adj = adj_norm
        # Convert to torch_sparse.SparseTensor for neighbor sampling.
        adj = SparseTensor(row=adj._indices()[0], col=adj._indices()[1],
                value=adj._values(), sparse_sizes=adj.size()).t()


        outer_loop, inner_loop = get_loops(args)

        for it in range(args.epochs+1):
            loss_avg = 0
            # A fresh (re-initialized) GNN each epoch, as in dataset
            # condensation: match gradients across many random inits.
            if args.sgc==1:
                model = SGC(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                            nclass=data.nclass, dropout=args.dropout,
                            nlayers=args.nlayers, with_bn=False,
                            device=self.device).to(self.device)
            elif args.sgc==2:
                model = SGC1(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                             nclass=data.nclass, dropout=args.dropout,
                             nlayers=args.nlayers, with_bn=False,
                             device=self.device).to(self.device)

            else:
                model = GCN(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                            nclass=data.nclass, dropout=args.dropout, nlayers=args.nlayers,
                            device=self.device).to(self.device)

            model.initialize()

            model_parameters = list(model.parameters())

            optimizer_model = torch.optim.Adam(model_parameters, lr=args.lr_model)
            model.train()

            for ol in range(outer_loop):
                adj_syn = pge(self.feat_syn)
                adj_syn_norm = utils.normalize_adj_tensor(adj_syn, sparse=False)
                feat_syn_norm = feat_syn

                # If the model has BatchNorm, freeze its running stats
                # (computed on real data) before gradient matching.
                BN_flag = False
                for module in model.modules():
                    if 'BatchNorm' in module._get_name(): #BatchNorm
                        BN_flag = True
                if BN_flag:
                    model.train() # for updating the mu, sigma of BatchNorm
                    output_real = model.forward(features, adj_norm)
                    for module in model.modules():
                        if 'BatchNorm' in module._get_name(): #BatchNorm
                            module.eval() # fix mu and sigma of every BatchNorm layer

                loss = torch.tensor(0.0).to(self.device)
                for c in range(data.nclass):
                    if c not in self.num_class_dict:
                        continue

                    # Sample a real mini-batch of class c via neighbor sampling.
                    batch_size, n_id, adjs = data.retrieve_class_sampler(
                            c, adj, transductive=False, args=args)

                    if args.nlayers == 1:
                        adjs = [adjs]
                    adjs = [adj.to(self.device) for adj in adjs]
                    output = model.forward_sampler(features[n_id], adjs)
                    loss_real = F.nll_loss(output, labels[n_id[:batch_size]])
                    gw_real = torch.autograd.grad(loss_real, model_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))

                    # Last propagation layer restricted to class-c synthetic rows.
                    ind = syn_class_indices[c]
                    if args.nlayers == 1:
                        adj_syn_norm_list = [adj_syn_norm[ind[0]: ind[1]]]
                    else:
                        adj_syn_norm_list = [adj_syn_norm]*(args.nlayers-1) + \
                                [adj_syn_norm[ind[0]: ind[1]]]

                    output_syn = model.forward_sampler_syn(feat_syn, adj_syn_norm_list)
                    loss_syn = F.nll_loss(output_syn, labels_syn[ind[0]: ind[1]])

                    # create_graph=True: gradients of the matching loss flow
                    # back into feat_syn / pge.
                    gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
                    # Weight each class by its share of the synthetic budget.
                    coeff = self.num_class_dict[c] / max(self.num_class_dict.values())
                    loss += coeff * match_loss(gw_syn, gw_real, args, device=self.device)

                loss_avg += loss.item()
                # TODO: regularize
                if args.alpha > 0:
                    loss_reg = args.alpha * regularization(adj_syn, utils.tensor2onehot(labels_syn))
                # else:
                else:
                    loss_reg = torch.tensor(0)

                loss = loss + loss_reg

                # update sythetic graph
                self.optimizer_feat.zero_grad()
                self.optimizer_pge.zero_grad()
                loss.backward()
                # Alternate: first 10 of every 50 epochs update structure,
                # the rest update features.
                if it % 50 < 10:
                    self.optimizer_pge.step()
                else:
                    self.optimizer_feat.step()

                if args.debug and ol % 5 ==0:
                    print('Gradient matching loss:', loss.item())

                if ol == outer_loop - 1:
                    # print('loss_reg:', loss_reg.item())
                    # print('Gradient matching loss:', loss.item())
                    break


            # Inner loop: train the GNN on the (frozen) synthetic graph.
            feat_syn_inner = feat_syn.detach()
            adj_syn_inner = pge.inference(feat_syn)
            adj_syn_inner_norm = utils.normalize_adj_tensor(adj_syn_inner, sparse=False)
            feat_syn_inner_norm = feat_syn_inner
            for j in range(inner_loop):
                optimizer_model.zero_grad()
                output_syn_inner = model.forward(feat_syn_inner_norm, adj_syn_inner_norm)
                loss_syn_inner = F.nll_loss(output_syn_inner, labels_syn)
                loss_syn_inner.backward()
                optimizer_model.step() # update gnn param

            loss_avg /= (data.nclass*outer_loop)
            if it % 50 == 0:
                print('Epoch {}, loss_avg: {}'.format(it, loss_avg))

            eval_epochs = [100, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 3000, 4000, 5000]

            if verbose and it in eval_epochs:
                # if verbose and (it+1) % 500 == 0:
                res = []
                # Large datasets are evaluated once; small ones averaged over 3 runs.
                runs = 1 if args.dataset in ['ogbn-arxiv', 'reddit', 'flickr'] else 3
                for i in range(runs):
                    # self.test()
                    res.append(self.test_with_val())
                res = np.array(res)
                print('Test:',
                        repr([res.mean(0), res.std(0)]))



    def get_sub_adj_feat(self, features):
        """Sample real per-class features to initialize the synthetic graph.

        Returns (features of sampled nodes, kNN adjacency built from
        cosine similarity with k=2 neighbors kept per row).
        """
        data = self.data
        args = self.args
        idx_selected = []

        from collections import Counter;
        counter = Counter(self.labels_syn.cpu().numpy())

        for c in range(data.nclass):
            tmp = data.retrieve_class(c, num=counter[c])
            tmp = list(tmp)
            idx_selected = idx_selected + tmp
        idx_selected = np.array(idx_selected).reshape(-1)
        features = features[idx_selected]

        # adj_knn = torch.zeros((data.nclass*args.nsamples, data.nclass*args.nsamples)).to(self.device)
        # for i in range(data.nclass):
        #     idx = np.arange(i*args.nsamples, i*args.nsamples+args.nsamples)
        #     adj_knn[np.ix_(idx, idx)] = 1

        from sklearn.metrics.pairwise import cosine_similarity
        # features[features!=0] = 1
        k = 2
        sims = cosine_similarity(features.cpu().numpy())
        # Zero the diagonal so a node is not its own neighbor.
        sims[(np.arange(len(sims)), np.arange(len(sims)))] = 0
        for i in range(len(sims)):
            indices_argsort = np.argsort(sims[i])
            # Keep only the top-k similarities per row.
            sims[i, indices_argsort[: -k]] = 0
        adj_knn = torch.FloatTensor(sims).to(self.device)
        return features, adj_knn
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
def get_loops(args):
    """Return empirically-good (outer_loop, inner_loop) counts (inductive).

    ``one_step`` overrides everything with (10, 0); reddit/flickr defer
    to the command-line ``--outer``/``--inner`` values; all other
    datasets use fixed constants.
    """
    if args.one_step:
        return 10, 0

    dataset = args.dataset
    if dataset in ('reddit', 'flickr'):
        # These large graphs take user-specified loop counts.
        return args.outer, args.inner

    fixed_loops = {
        'ogbn-arxiv': (20, 0),
        'cora': (20, 10),
        'citeseer': (20, 5),  # at least 200 epochs
    }
    return fixed_loops.get(dataset, (20, 5))
|
| 327 |
+
|
GCond/gcond_agent_transduct.py
ADDED
|
@@ -0,0 +1,326 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import torch
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
import torch.optim as optim
|
| 5 |
+
from torch.nn import Parameter
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
from utils import match_loss, regularization, row_normalize_tensor
|
| 8 |
+
import deeprobust.graph.utils as utils
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
import numpy as np
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from models.gcn import GCN
|
| 13 |
+
from models.sgc import SGC
|
| 14 |
+
from models.sgc_multi import SGC as SGC1
|
| 15 |
+
from models.parametrized_adj import PGE
|
| 16 |
+
import scipy.sparse as sp
|
| 17 |
+
from torch_sparse import SparseTensor
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class GCond:
    """Graph condensation agent (transductive setting).

    Learns a small synthetic graph (features ``feat_syn``, structure
    generator ``pge``, fixed labels ``labels_syn``) by matching the
    gradients a GNN sees on the full graph against those on the
    synthetic graph.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        self.data = data
        self.args = args
        self.device = device

        # n = data.nclass * args.nsamples
        # Synthetic node count is a fraction of the training node count.
        n = int(data.feat_train.shape[0] * args.reduction_rate)
        # from collections import Counter; print(Counter(data.labels_train))

        d = data.feat_train.shape[1]
        self.nnodes_syn = n
        # Learnable synthetic node features.
        self.feat_syn = nn.Parameter(torch.FloatTensor(n, d).to(device))
        # PGE maps synthetic features to a synthetic adjacency matrix.
        self.pge = PGE(nfeat=d, nnodes=n, device=device,args=args).to(device)

        self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)

        self.reset_parameters()
        # Separate optimizers: features and structure are updated in
        # alternating phases (see train()).
        self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=args.lr_feat)
        self.optimizer_pge = torch.optim.Adam(self.pge.parameters(), lr=args.lr_adj)
        print('adj_syn:', (n,n), 'feat_syn:', self.feat_syn.shape)

    def reset_parameters(self):
        # Random-normal initialization; real features are copied in later
        # by train() via get_sub_adj_feat().
        self.feat_syn.data.copy_(torch.randn(self.feat_syn.size()))

    def generate_labels_syn(self, data):
        """Build the synthetic label list, keeping class proportions.

        Classes are processed from rarest to most common; each gets
        ``max(int(count * reduction_rate), 1)`` synthetic nodes, and the
        most common class absorbs the rounding remainder. Records
        ``self.syn_class_indices`` ([start, end) per class) and
        ``self.num_class_dict``.
        """
        from collections import Counter
        counter = Counter(data.labels_train)
        num_class_dict = {}
        n = len(data.labels_train)

        sorted_counter = sorted(counter.items(), key=lambda x:x[1])
        sum_ = 0
        labels_syn = []
        self.syn_class_indices = {}
        for ix, (c, num) in enumerate(sorted_counter):
            if ix == len(sorted_counter) - 1:
                # Last (largest) class takes whatever budget remains.
                num_class_dict[c] = int(n * self.args.reduction_rate) - sum_
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]
            else:
                num_class_dict[c] = max(int(num * self.args.reduction_rate), 1)
                sum_ += num_class_dict[c]
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]

        self.num_class_dict = num_class_dict
        return labels_syn


    def test_with_val(self, verbose=True):
        """Train a fresh GCN on the synthetic graph; report train/test accuracy.

        Returns ``res = [train_acc, test_acc]``; test accuracy is
        computed transductively on the full graph.
        """
        res = []

        data, device = self.data, self.device
        feat_syn, pge, labels_syn = self.feat_syn.detach(), \
                self.pge, self.labels_syn

        # with_bn = True if args.dataset in ['ogbn-arxiv'] else False
        model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=0.5,
                    weight_decay=5e-4, nlayers=2,
                    nclass=data.nclass, device=device).to(device)

        # ogbn-arxiv uses a different regularization setup (no weight decay).
        if self.args.dataset in ['ogbn-arxiv']:
            model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=0.5,
                        weight_decay=0e-4, nlayers=2, with_bn=False,
                        nclass=data.nclass, device=device).to(device)

        adj_syn = pge.inference(feat_syn)
        args = self.args

        if self.args.save:
            torch.save(adj_syn, f'saved_ours/adj_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')
            torch.save(feat_syn, f'saved_ours/feat_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')

        # lr_adj == 0 means structure is not learned: evaluate with a
        # graph-less (all-zero adjacency) synthetic set.
        if self.args.lr_adj == 0:
            n = len(labels_syn)
            adj_syn = torch.zeros((n, n))

        model.fit_with_val(feat_syn, adj_syn, labels_syn, data,
                     train_iters=600, normalize=True, verbose=False)

        model.eval()
        labels_test = torch.LongTensor(data.labels_test).cuda()

        labels_train = torch.LongTensor(data.labels_train).cuda()
        output = model.predict(data.feat_train, data.adj_train)
        loss_train = F.nll_loss(output, labels_train)
        acc_train = utils.accuracy(output, labels_train)
        if verbose:
            print("Train set results:",
                  "loss= {:.4f}".format(loss_train.item()),
                  "accuracy= {:.4f}".format(acc_train.item()))
        res.append(acc_train.item())

        # Full graph
        output = model.predict(data.feat_full, data.adj_full)
        loss_test = F.nll_loss(output[data.idx_test], labels_test)
        acc_test = utils.accuracy(output[data.idx_test], labels_test)
        res.append(acc_test.item())
        if verbose:
            print("Test set results:",
                  "loss= {:.4f}".format(loss_test.item()),
                  "accuracy= {:.4f}".format(acc_test.item()))
        return res

    def train(self, verbose=True):
        """Optimize the synthetic graph by gradient matching (transductive).

        Alternates outer steps (match per-class GNN gradients, update
        synthetic features/structure) with inner steps (train the GNN on
        the synthetic graph). Periodically evaluates via test_with_val().
        """
        args = self.args
        data = self.data
        feat_syn, pge, labels_syn = self.feat_syn, self.pge, self.labels_syn
        # Transductive: gradients are computed on the full graph.
        features, adj, labels = data.feat_full, data.adj_full, data.labels_full
        idx_train = data.idx_train

        syn_class_indices = self.syn_class_indices

        features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)

        # Initialize synthetic features from real per-class samples.
        feat_sub, adj_sub = self.get_sub_adj_feat(features)
        self.feat_syn.data.copy_(feat_sub)

        if utils.is_sparse_tensor(adj):
            adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
        else:
            adj_norm = utils.normalize_adj_tensor(adj)

        adj = adj_norm
        # Convert to torch_sparse.SparseTensor for neighbor sampling.
        adj = SparseTensor(row=adj._indices()[0], col=adj._indices()[1],
                value=adj._values(), sparse_sizes=adj.size()).t()


        outer_loop, inner_loop = get_loops(args)
        # NOTE(review): loss_avg is reset outside the epoch loop here
        # (unlike the inductive agent) yet divided per epoch — the
        # printed average accumulates across epochs; confirm intended.
        loss_avg = 0

        for it in range(args.epochs+1):
            # A fresh (re-initialized) GNN each epoch.
            if args.dataset in ['ogbn-arxiv']:
                model = SGC1(nfeat=feat_syn.shape[1], nhid=self.args.hidden,
                            dropout=0.0, with_bn=False,
                            weight_decay=0e-4, nlayers=2,
                            nclass=data.nclass,
                            device=self.device).to(self.device)
            else:
                if args.sgc == 1:
                    model = SGC(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                                nclass=data.nclass, dropout=args.dropout,
                                nlayers=args.nlayers, with_bn=False,
                                device=self.device).to(self.device)
                else:
                    model = GCN(nfeat=data.feat_train.shape[1], nhid=args.hidden,
                                nclass=data.nclass, dropout=args.dropout, nlayers=args.nlayers,
                                device=self.device).to(self.device)


            model.initialize()

            model_parameters = list(model.parameters())

            optimizer_model = torch.optim.Adam(model_parameters, lr=args.lr_model)
            model.train()

            for ol in range(outer_loop):
                adj_syn = pge(self.feat_syn)
                adj_syn_norm = utils.normalize_adj_tensor(adj_syn, sparse=False)
                feat_syn_norm = feat_syn

                # If the model has BatchNorm, freeze its running stats
                # (computed on real data) before gradient matching.
                BN_flag = False
                for module in model.modules():
                    if 'BatchNorm' in module._get_name(): #BatchNorm
                        BN_flag = True
                if BN_flag:
                    model.train() # for updating the mu, sigma of BatchNorm
                    output_real = model.forward(features, adj_norm)
                    for module in model.modules():
                        if 'BatchNorm' in module._get_name(): #BatchNorm
                            module.eval() # fix mu and sigma of every BatchNorm layer

                loss = torch.tensor(0.0).to(self.device)
                for c in range(data.nclass):
                    # Sample a real mini-batch of class c via neighbor sampling.
                    batch_size, n_id, adjs = data.retrieve_class_sampler(
                            c, adj, transductive=True, args=args)
                    if args.nlayers == 1:
                        adjs = [adjs]

                    adjs = [adj.to(self.device) for adj in adjs]
                    output = model.forward_sampler(features[n_id], adjs)
                    loss_real = F.nll_loss(output, labels[n_id[:batch_size]])

                    gw_real = torch.autograd.grad(loss_real, model_parameters)
                    gw_real = list((_.detach().clone() for _ in gw_real))
                    output_syn = model.forward(feat_syn, adj_syn_norm)

                    # Synthetic loss restricted to class-c rows.
                    ind = syn_class_indices[c]
                    loss_syn = F.nll_loss(
                            output_syn[ind[0]: ind[1]],
                            labels_syn[ind[0]: ind[1]])
                    # create_graph=True: gradients of the matching loss flow
                    # back into feat_syn / pge.
                    gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
                    # Weight each class by its share of the synthetic budget.
                    coeff = self.num_class_dict[c] / max(self.num_class_dict.values())
                    loss += coeff * match_loss(gw_syn, gw_real, args, device=self.device)

                loss_avg += loss.item()
                # TODO: regularize
                if args.alpha > 0:
                    loss_reg = args.alpha * regularization(adj_syn, utils.tensor2onehot(labels_syn))
                else:
                    loss_reg = torch.tensor(0)

                loss = loss + loss_reg

                # update sythetic graph
                self.optimizer_feat.zero_grad()
                self.optimizer_pge.zero_grad()
                loss.backward()
                # Alternate: first 10 of every 50 epochs update structure,
                # the rest update features.
                if it % 50 < 10:
                    self.optimizer_pge.step()
                else:
                    self.optimizer_feat.step()

                if args.debug and ol % 5 ==0:
                    print('Gradient matching loss:', loss.item())

                if ol == outer_loop - 1:
                    # print('loss_reg:', loss_reg.item())
                    # print('Gradient matching loss:', loss.item())
                    break

            # Inner loop: train the GNN on the (frozen) synthetic graph.
            feat_syn_inner = feat_syn.detach()
            adj_syn_inner = pge.inference(feat_syn_inner)
            adj_syn_inner_norm = utils.normalize_adj_tensor(adj_syn_inner, sparse=False)
            feat_syn_inner_norm = feat_syn_inner
            for j in range(inner_loop):
                optimizer_model.zero_grad()
                output_syn_inner = model.forward(feat_syn_inner_norm, adj_syn_inner_norm)
                loss_syn_inner = F.nll_loss(output_syn_inner, labels_syn)
                loss_syn_inner.backward()
                # print(loss_syn_inner.item())
                optimizer_model.step() # update gnn param


            loss_avg /= (data.nclass*outer_loop)
            if it % 50 == 0:
                print('Epoch {}, loss_avg: {}'.format(it, loss_avg))

            eval_epochs = [400, 600, 800, 1000, 1200, 1600, 2000, 3000, 4000, 5000]

            if verbose and it in eval_epochs:
                # if verbose and (it+1) % 50 == 0:
                res = []
                # ogbn-arxiv is evaluated once; small datasets averaged over 3 runs.
                runs = 1 if args.dataset in ['ogbn-arxiv'] else 3
                for i in range(runs):
                    if args.dataset in ['ogbn-arxiv']:
                        res.append(self.test_with_val())
                    else:
                        res.append(self.test_with_val())

                res = np.array(res)
                print('Train/Test Mean Accuracy:',
                        repr([res.mean(0), res.std(0)]))

    def get_sub_adj_feat(self, features):
        """Sample real per-class features to initialize the synthetic graph.

        Returns (features of sampled training nodes, kNN adjacency built
        from cosine similarity with k=2 neighbors kept per row).
        """
        data = self.data
        args = self.args
        idx_selected = []

        from collections import Counter;
        counter = Counter(self.labels_syn.cpu().numpy())

        for c in range(data.nclass):
            tmp = data.retrieve_class(c, num=counter[c])
            tmp = list(tmp)
            idx_selected = idx_selected + tmp
        idx_selected = np.array(idx_selected).reshape(-1)
        # retrieve_class returns positions within the training split.
        features = features[self.data.idx_train][idx_selected]

        # adj_knn = torch.zeros((data.nclass*args.nsamples, data.nclass*args.nsamples)).to(self.device)
        # for i in range(data.nclass):
        #     idx = np.arange(i*args.nsamples, i*args.nsamples+args.nsamples)
        #     adj_knn[np.ix_(idx, idx)] = 1

        from sklearn.metrics.pairwise import cosine_similarity
        # features[features!=0] = 1
        k = 2
        sims = cosine_similarity(features.cpu().numpy())
        # Zero the diagonal so a node is not its own neighbor.
        sims[(np.arange(len(sims)), np.arange(len(sims)))] = 0
        for i in range(len(sims)):
            indices_argsort = np.argsort(sims[i])
            # Keep only the top-k similarities per row.
            sims[i, indices_argsort[: -k]] = 0
        adj_knn = torch.FloatTensor(sims).to(self.device)
        return features, adj_knn
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
def get_loops(args):
    """Return empirically-good (outer_loop, inner_loop) counts (transductive).

    ``one_step`` condenses in a single matching step (5 outer steps for
    ogbn-arxiv, 1 otherwise); ogbn-arxiv otherwise defers to the
    command-line ``--outer``/``--inner``; remaining datasets use fixed
    constants.
    """
    dataset = args.dataset

    if args.one_step:
        return (5, 0) if dataset == 'ogbn-arxiv' else (1, 0)

    if dataset == 'ogbn-arxiv':
        # Large graph: loop counts come from the command line.
        return args.outer, args.inner

    fixed_loops = {
        'cora': (20, 15),  # sgc
        'citeseer': (20, 15),
        'physics': (20, 10),
    }
    return fixed_loops.get(dataset, (20, 10))
|
| 326 |
+
|
GCond/models/__pycache__/gcn.cpython-312.pyc
ADDED
|
Binary file (21 kB). View file
|
|
|
GCond/models/gat.py
ADDED
|
@@ -0,0 +1,312 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Extended from https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation
|
| 3 |
+
"""
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
import math
|
| 7 |
+
import torch
|
| 8 |
+
import torch.optim as optim
|
| 9 |
+
from torch.nn.parameter import Parameter
|
| 10 |
+
from torch.nn.modules.module import Module
|
| 11 |
+
from deeprobust.graph import utils
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
from torch_geometric.nn import SGConv
|
| 14 |
+
from torch_geometric.nn import APPNP as ModuleAPPNP
|
| 15 |
+
# from torch_geometric.nn import GATConv
|
| 16 |
+
from .mygatconv import GATConv
|
| 17 |
+
import numpy as np
|
| 18 |
+
import scipy.sparse as sp
|
| 19 |
+
|
| 20 |
+
from torch.nn import Linear
|
| 21 |
+
from itertools import repeat
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class GAT(torch.nn.Module):
    """Two-layer Graph Attention Network built on the local GATConv.

    Layer 1 maps nfeat -> nhid with ``heads`` attention heads whose outputs
    are concatenated, so layer 2 consumes nhid * heads features; layer 2 maps
    to nclass with ``output_heads`` heads averaged (concat=False). forward()
    returns log-softmax scores.

    NOTE(review): when ``kwargs['dataset']`` is supplied, the constructor
    overrides the attention dropout (and sometimes self.dropout /
    self.weight_decay) with per-dataset values, so the ``dropout`` argument
    is effectively ignored in that case.
    """

    def __init__(self, nfeat, nhid, nclass, heads=8, output_heads=1, dropout=0.5, lr=0.01,
                 weight_decay=5e-4, with_bias=True, device=None, **kwargs):

        super(GAT, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.dropout = dropout          # inter-layer dropout used in forward()
        self.lr = lr
        self.weight_decay = weight_decay

        # Per-dataset hyper-parameter overrides. The local `dropout` below is
        # the dropout handed to the GATConv layers (attention dropout); it is
        # distinct from self.dropout applied between layers in forward().
        if 'dataset' in kwargs:
            if kwargs['dataset'] in ['ogbn-arxiv']:
                dropout = 0.7 # arxiv
            elif kwargs['dataset'] in ['reddit']:
                dropout = 0.05; self.dropout = 0.1; self.weight_decay = 5e-4
            elif kwargs['dataset'] in ['citeseer']:
                dropout = 0.7
                self.weight_decay = 5e-4
            elif kwargs['dataset'] in ['flickr']:
                dropout = 0.8
            else:
                dropout = 0.7 # cora, citeseer, reddit
        else:
            dropout = 0.7
        # First attention layer: concatenates `heads` head outputs.
        self.conv1 = GATConv(
            nfeat,
            nhid,
            heads=heads,
            dropout=dropout,
            bias=with_bias)

        # Output layer: averages `output_heads` heads (concat=False).
        self.conv2 = GATConv(
            nhid * heads,
            nclass,
            heads=output_heads,
            concat=False,
            dropout=dropout,
            bias=with_bias)

        # Cached outputs of the best checkpoint found during fit().
        self.output = None
        self.best_model = None
        self.best_output = None

    def forward(self, data):
        """Return per-node log-softmax class scores.

        `data` is a PyG Data object carrying x, edge_index and edge_weight,
        as produced by Dpr2Pyg below.
        """
        x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = F.elu(self.conv1(x, edge_index, edge_weight=edge_weight))
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index, edge_weight=edge_weight)
        return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of GAT.
        """
        self.conv1.reset_parameters()
        self.conv2.reset_parameters()

    def fit(self, feat, adj, labels, idx, data=None, train_iters=600, initialize=True, verbose=False, patience=None, noval=False, **kwargs):
        """Train on graph (feat, adj, labels) and keep the checkpoint that
        performs best on validation.

        Parameters
        ----------
        feat, adj, labels : training graph (e.g. the condensed one).
        idx : unused; kept for signature compatibility with sibling models.
        data : full dataset object providing validation/test graphs and
            labels_val.
        noval : if True, validate on the dedicated (feat_val, adj_val) graph
            where every node is a validation node; otherwise validate on the
            full graph restricted to data.idx_val.
        """
        data_train = GraphData(feat, adj, labels)
        data_train = Dpr2Pyg(data_train)[0]

        # NOTE(review): data_test is built but never used below (test-time
        # evaluation is commented out).
        data_test = Dpr2Pyg(GraphData(data.feat_test, data.adj_test, None))[0]

        if noval:
            data_val = GraphData(data.feat_val, data.adj_val, None)
            data_val = Dpr2Pyg(data_val)[0]
        else:
            data_val = GraphData(data.feat_full, data.adj_full, None)
            data_val = Dpr2Pyg(data_val)[0]

        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if initialize:
            self.initialize()

        # Multi-dimensional y means multi-label classification -> BCE loss,
        # otherwise standard NLL on the log-softmax output.
        if len(data_train.y.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        data_train.y = data_train.y.float() if self.multi_label else data_train.y

        if verbose:
            print('=== training gat model ===')

        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        best_acc_val = 0
        best_loss_val = 100
        for i in range(train_iters):
            # 10x learning-rate decay at iteration 1500 (only reached when
            # train_iters > 1500).
            if i in [1500]:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(data_train)
            loss_train = self.loss(output, data_train.y)
            loss_train.backward()
            optimizer.step()

            # Validate every iteration; checkpoint on best val loss OR best
            # val accuracy (the accuracy check runs second, so on a tie of
            # both conditions the accuracy-based snapshot wins).
            with torch.no_grad():
                self.eval()

                output = self.forward(data_val)
                if noval:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if loss_val < best_loss_val:
                    best_loss_val = loss_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        # NOTE(review): `weights` is unbound if train_iters == 0.
        self.load_state_dict(weights)

    def test(self, data_test):
        """Evaluate GCN performance
        """
        # NOTE(review): `evaluate` is not imported in this module and
        # `self.args` is never set on this class — this method looks unused
        # and would raise if called; confirm before relying on it.
        self.eval()
        with torch.no_grad():
            output = self.forward(data_test)
            evaluate(output, data_test.y, self.args)

    @torch.no_grad()
    def predict(self, feat, adj):
        """Return log-probabilities for graph (feat, adj) in eval mode."""
        self.eval()
        data = GraphData(feat, adj, None)
        data = Dpr2Pyg(data)[0]
        return self.forward(data)

    @torch.no_grad()
    def predict_unnorm(self, feat, adj):
        # Same conversion path as predict(); Dpr2Pyg.process adds self-loops
        # but applies no other normalization in either case.
        self.eval()
        data = GraphData(feat, adj, None)
        data = Dpr2Pyg(data)[0]

        return self.forward(data)
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
class GraphData:
    """Minimal DeepRobust-style graph container.

    Bundles features/adjacency/labels plus optional train/val/test node-index
    splits; downstream code (e.g. Dpr2Pyg) reads these attributes directly.
    """

    def __init__(self, features, adj, labels, idx_train=None, idx_val=None, idx_test=None):
        # Store every argument verbatim under its attribute name; the split
        # indices stay None when not supplied.
        fields = {
            'adj': adj,
            'features': features,
            'labels': labels,
            'idx_train': idx_train,
            'idx_val': idx_val,
            'idx_test': idx_test,
        }
        for attr_name, attr_value in fields.items():
            setattr(self, attr_name, attr_value)
|
| 222 |
+
|
| 223 |
+
|
| 224 |
+
from torch_geometric.data import InMemoryDataset, Data
|
| 225 |
+
import scipy.sparse as sp
|
| 226 |
+
|
| 227 |
+
class Dpr2Pyg(InMemoryDataset):
    """Wrap a DeepRobust-style GraphData object as a single-graph PyG dataset.

    The graph is converted eagerly in __init__ (no files are read or
    written); index [0] yields the converted torch_geometric Data object.
    NOTE(review): conversion moves tensors to CUDA unconditionally — this
    class assumes a GPU is available.
    """

    def __init__(self, dpr_data, transform=None, **kwargs):
        root = 'data/' # dummy root; does not mean anything
        self.dpr_data = dpr_data
        super(Dpr2Pyg, self).__init__(root, transform)
        # Convert immediately and collate the single graph into the
        # (data, slices) storage format InMemoryDataset expects.
        pyg_data = self.process()
        self.data, self.slices = self.collate([pyg_data])
        self.transform = transform

    def process____(self):
        # Dead/alternative variant of process() kept around (note the name
        # mangling): converts without self-loops and without edge weights.
        dpr_data = self.dpr_data
        try:
            # torch.Tensor path: nonzero() gives an (nnz, 2) index tensor.
            edge_index = torch.LongTensor(dpr_data.adj.nonzero().cpu()).cuda().T
        except:
            # scipy path: nonzero() gives a (rows, cols) tuple.
            edge_index = torch.LongTensor(dpr_data.adj.nonzero()).cuda()
        # by default, the features in pyg data is dense
        try:
            x = torch.FloatTensor(dpr_data.features.cpu()).float().cuda()
        except:
            x = torch.FloatTensor(dpr_data.features).float().cuda()
        try:
            y = torch.LongTensor(dpr_data.labels.cpu()).cuda()
        except:
            # labels may be None (see GraphData usage) or already a tensor.
            y = dpr_data.labels

        data = Data(x=x, edge_index=edge_index, y=y)
        data.train_mask = None
        data.val_mask = None
        data.test_mask = None
        return data

    def process(self):
        """Convert self.dpr_data into a PyG Data with self-loops added and
        per-edge weights taken from the (weighted) adjacency matrix."""
        dpr_data = self.dpr_data
        if type(dpr_data.adj) == torch.Tensor:
            # Dense torch adjacency (e.g. the learned synthetic graph):
            # add identity for self-loops, then read weights at the
            # surviving nonzero positions.
            adj_selfloop = dpr_data.adj + torch.eye(dpr_data.adj.shape[0]).cuda()
            edge_index_selfloop = adj_selfloop.nonzero().T
            edge_index = edge_index_selfloop
            edge_weight = adj_selfloop[edge_index_selfloop[0], edge_index_selfloop[1]]
        else:
            # scipy sparse adjacency path.
            adj_selfloop = dpr_data.adj + sp.eye(dpr_data.adj.shape[0])
            edge_index = torch.LongTensor(adj_selfloop.nonzero()).cuda()
            edge_weight = torch.FloatTensor(adj_selfloop[adj_selfloop.nonzero()]).cuda()

        # by default, the features in pyg data is dense
        try:
            x = torch.FloatTensor(dpr_data.features.cpu()).float().cuda()
        except:
            x = torch.FloatTensor(dpr_data.features).float().cuda()
        try:
            y = torch.LongTensor(dpr_data.labels.cpu()).cuda()
        except:
            # labels may be None or already on-device.
            y = dpr_data.labels


        data = Data(x=x, edge_index=edge_index, y=y, edge_weight=edge_weight)
        data.train_mask = None
        data.val_mask = None
        data.test_mask = None
        return data

    def get(self, idx):
        # Re-slice the collated storage back into a standalone Data object
        # (standard InMemoryDataset.get re-implementation).
        data = self.data.__class__()

        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            # Slice only along the concatenation dimension of this key.
            s = list(repeat(slice(None), item.dim()))
            s[self.data.__cat_dim__(key, item)] = slice(slices[idx],
                                                        slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        # Placeholder: no raw files are ever read.
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        # Placeholder: nothing is persisted to disk.
        return ['data.pt']

    def _download(self):
        # No download step; the data is supplied in-memory.
        pass
|
| 312 |
+
|
GCond/models/gcn.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from torch.nn.parameter import Parameter
|
| 7 |
+
from torch.nn.modules.module import Module
|
| 8 |
+
from deeprobust.graph import utils
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
from sklearn.metrics import f1_score
|
| 11 |
+
from torch.nn import init
|
| 12 |
+
import torch_sparse
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class GraphConvolution(Module):
    """Simple GCN layer, similar to https://github.com/tkipf/pygcn

    Computes ``adj @ (input @ weight) (+ bias)``, supporting sparse inputs,
    sparse adjacency, and torch_sparse.SparseTensor adjacency.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            # Bug fix: the bias Parameter used to be created unconditionally,
            # so with_bias=False had no effect. Register None so the existing
            # `if self.bias is not None` guards in reset_parameters() and
            # forward() actually skip it.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)];
        # weight.T.size(1) == in_features.
        stdv = 1. / math.sqrt(self.weight.T.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input, adj):
        """ Graph Convolutional Layer forward function
        """
        # Feature transform: use sparse matmul when the input is sparse.
        if input.data.is_sparse:
            support = torch.spmm(input, self.weight)
        else:
            support = torch.mm(input, self.weight)
        # Neighborhood aggregation: dispatch on the adjacency representation.
        if isinstance(adj, torch_sparse.SparseTensor):
            output = torch_sparse.matmul(adj, support)
        else:
            output = torch.spmm(adj, support)
        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
class GCN(nn.Module):
    """Multi-layer GCN with optional batch-norm, used both as a test model
    and as the model trained on condensed graphs.

    The `nlayers` GraphConvolution layers map nfeat -> nhid -> ... -> nclass;
    forward() ends in sigmoid (multi-label) or log-softmax. Loss choice, data
    normalization and validation-based checkpointing live in fit() /
    fit_with_val().
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True, with_bn=False, device=None):

        super(GCN, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        self.layers = nn.ModuleList([])

        if nlayers == 1:
            # Single-layer model maps input straight to class scores.
            self.layers.append(GraphConvolution(nfeat, nclass, with_bias=with_bias))
        else:
            # One BatchNorm per hidden layer (none after the output layer).
            if with_bn:
                self.bns = torch.nn.ModuleList()
                self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(GraphConvolution(nfeat, nhid, with_bias=with_bias))
            for i in range(nlayers-2):
                self.layers.append(GraphConvolution(nhid, nhid, with_bias=with_bias))
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(GraphConvolution(nhid, nclass, with_bias=with_bias))

        self.dropout = dropout
        self.lr = lr
        # Linear (no-ReLU) variant intentionally disables weight decay.
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bn = with_bn
        self.with_bias = with_bias
        # State filled in by fit()/predict(): cached outputs, normalized
        # adjacency, features, and the multi-label flag.
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None

    def forward(self, x, adj):
        """Full-graph forward pass over a (normalized) adjacency."""
        for ix, layer in enumerate(self.layers):
            x = layer(x, adj)
            if ix != len(self.layers) - 1:
                # BN -> ReLU -> dropout between layers; skipped on the last.
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Mini-batch forward for neighbor-sampled subgraphs.

        `adjs` is one (adj, e_id, size) triple per layer, as produced by a
        PyG NeighborSampler; layer ix consumes the ix-th sampled adjacency.
        """
        for ix, (adj, _, size) in enumerate(adjs):
            x = self.layers[ix](x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Like forward_sampler, but `adjs` is a plain sequence of per-layer
        adjacencies (synthetic-graph path), without e_id/size metadata."""
        for ix, (adj) in enumerate(adjs):
            x = self.layers[ix](x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)


    def initialize(self):
        """Initialize parameters of GCN.
        """
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, **kwargs):
        """Train on (features, adj, labels) using idx_train; if idx_val is
        given, checkpoint on validation accuracy, else keep the final model.

        `normalize=True` symmetrically normalizes the adjacency first;
        `kwargs['feat_norm']` additionally row-normalizes shifted features.
        """
        if initialize:
            self.initialize()

        # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            # Shift to non-negative before row-normalizing.
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        # Multi-dimensional labels -> multi-label task with BCE on sigmoid
        # outputs; otherwise NLL on log-softmax outputs.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels


        if idx_val is not None:
            self._train_with_val2(labels, idx_train, idx_val, train_iters, verbose)
        else:
            self._train_without_val2(labels, idx_train, train_iters, verbose)

    def _train_without_val2(self, labels, idx_train, train_iters, verbose):
        """Plain training loop with no validation; keeps the final model."""
        self.train()
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
        for i in range(train_iters):
            # 10x learning-rate decay at the halfway point.
            if i == train_iters // 2:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()
            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        self.eval()
        output = self.forward(self.features, self.adj_norm)
        self.output = output

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        '''Train on (features, adj, labels) — typically the condensed graph —
        and validate against the full dataset object `data`.

        data : full data class (provides feat_full/adj_full or
            feat_val/adj_val plus labels_val).
        noval : if True, validate on the dedicated validation graph where
            every node is a validation node.
        '''
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Training loop that checkpoints on best validation accuracy.

        Trains on ALL nodes of the stored (condensed) graph — note the loss
        uses the full `output`/`labels`, not a train mask.
        """
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0

        for i in range(train_iters):
            # 10x learning-rate decay at the halfway point.
            if i == train_iters // 2:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            # Validate on the full (or dedicated validation) graph.
            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)

                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        # NOTE(review): `weights` is unbound if train_iters == 0.
        self.load_state_dict(weights)


    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()


    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            # New inputs replace the stored ones (side effect on self).
            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like predict(), but uses `adj` as-is without normalization."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)


    def _train_with_val2(self, labels, idx_train, idx_val, train_iters, verbose):
        """Training loop over idx_train with idx_val checkpointing
        (best validation accuracy wins)."""
        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_loss_val = 100
        best_acc_val = 0

        for i in range(train_iters):
            # 10x learning-rate decay at the halfway point.
            if i == train_iters // 2:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = F.nll_loss(output[idx_train], labels[idx_train])
            loss_train.backward()
            optimizer.step()

            if verbose and i % 10 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            self.eval()
            output = self.forward(self.features, self.adj_norm)
            loss_val = F.nll_loss(output[idx_val], labels[idx_val])
            acc_val = utils.accuracy(output[idx_val], labels[idx_val])

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        # NOTE(review): `weights` is unbound if train_iters == 0.
        self.load_state_dict(weights)
|
GCond/models/myappnp.py
ADDED
|
@@ -0,0 +1,344 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""multiple transformaiton and multiple propagation"""
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.nn.parameter import Parameter
|
| 8 |
+
from torch.nn.modules.module import Module
|
| 9 |
+
from deeprobust.graph import utils
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from sklearn.metrics import f1_score
|
| 12 |
+
from torch.nn import init
|
| 13 |
+
import torch_sparse
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class APPNP(nn.Module):
    """APPNP (Approximate Personalized Propagation of Neural Predictions).

    Applies `ntrans` linear transformations ("multiple transformation"),
    then `nlayers` (= K) rounds of personalized-PageRank-style propagation
    ("multiple propagation"):

        x <- (1 - alpha) * A_hat @ x + alpha * h

    where `h` is the transformed feature matrix and ``alpha`` is the
    teleport (restart) probability.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 ntrans=1, with_bias=True, with_bn=False, device=None):

        super(APPNP, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass
        self.alpha = 0.1  # teleport probability of the PPR propagation

        # NOTE(review): batch norm is forcibly disabled regardless of the
        # `with_bn` argument passed by the caller; kept as-is to preserve
        # the original behavior.
        with_bn = False

        self.layers = nn.ModuleList([])
        if ntrans == 1:
            self.layers.append(MyLinear(nfeat, nclass))
        else:
            self.layers.append(MyLinear(nfeat, nhid))
            if with_bn:
                self.bns = torch.nn.ModuleList()
                self.bns.append(nn.BatchNorm1d(nhid))
            for i in range(ntrans-2):
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
                self.layers.append(MyLinear(nhid, nhid))
            self.layers.append(MyLinear(nhid, nclass))

        self.nlayers = nlayers  # number of propagation steps K
        self.weight_decay = weight_decay
        self.dropout = dropout
        self.lr = lr
        self.with_bn = with_bn
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None
        self.sparse_dropout = SparseDropout(dprob=0)

    def forward(self, x, adj):
        """Transform `x`, then propagate K times over normalized `adj`."""
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        h = x  # teleport target: the transformed features
        # here nlayers means K
        for i in range(self.nlayers):
            adj_drop = adj  # edge dropout disabled (see self.sparse_dropout)
            x = torch.spmm(adj_drop, x)
            x = x * (1 - self.alpha)
            x = x + self.alpha * h

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Forward pass over a list of sampled (sub)adjacencies, where each
        element of `adjs` is a `(adj, edge_index, size)` triple as produced
        by a neighbor sampler; `h` is truncated to the target nodes."""
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        h = x
        for ix, (adj, _, size) in enumerate(adjs):
            adj_drop = adj
            h = h[: size[1]]  # keep only the target nodes of this hop
            x = torch_sparse.matmul(adj_drop, x)
            x = x * (1 - self.alpha)
            x = x + self.alpha * h

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Forward pass for synthetic-graph sampling: plain propagation over
        each adjacency in `adjs`, without the PPR teleport term."""
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        for ix, (adj) in enumerate(adjs):
            x = torch_sparse.matmul(adj, x)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of GCN.
        """
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        '''Fit the model, selecting the best weights on a validation set.

        data: full data class providing `feat_full`/`adj_full` (or
        `feat_val`/`adj_val` when `noval=True`), `labels_val` and `idx_val`.
        '''
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        if len(labels.shape) > 1:
            # multi-label setting: sigmoid outputs + BCE loss
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            # validate directly on the validation subgraph
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Train on `self.features`/`self.adj_norm`, keeping the parameters
        with the best validation accuracy.

        Fix: `weights` is snapshotted before the loop so `load_state_dict`
        cannot fail with an unbound name if accuracy never improves.
        """
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full

        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            if i == train_iters // 2:
                # one-shot learning-rate decay halfway through training
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like `predict`, but uses `adj` as-is without normalization."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
|
| 294 |
+
class MyLinear(Module):
    """Plain affine layer (y = x W + b), adapted from https://github.com/tkipf/pygcn.

    Accepts either a dense or a sparse input matrix.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init scaled by fan-in (weight has shape [in, out],
        # so dim 0 is the fan-in).
        bound = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        # Pick the matmul that matches the input's layout.
        matmul = torch.spmm if input.data.is_sparse else torch.mm
        output = matmul(input, self.weight)
        return output if self.bias is None else output + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__} ({self.in_features} -> {self.out_features})'
|
| 331 |
+
|
| 332 |
+
class SparseDropout(torch.nn.Module):
    """Dropout applied to the stored values of a sparse COO tensor.

    Surviving values are rescaled by 1/keep_prob (inverted dropout).
    """

    def __init__(self, dprob=0.5):
        super(SparseDropout, self).__init__()
        self.kprob = 1 - dprob  # keep probability

    def forward(self, x, training):
        if not training:
            return x
        # Bernoulli(kprob) mask over the nonzero values.
        keep = (torch.rand(x._values().size()) + self.kprob).floor().type(torch.bool)
        indices = x._indices()[:, keep]
        values = x._values()[keep] * (1.0 / self.kprob)
        return torch.sparse.FloatTensor(indices, values, x.size())
|
GCond/models/myappnp1.py
ADDED
|
@@ -0,0 +1,348 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""multiple transformaiton and multiple propagation"""
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.nn.parameter import Parameter
|
| 8 |
+
from torch.nn.modules.module import Module
|
| 9 |
+
from deeprobust.graph import utils
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from sklearn.metrics import f1_score
|
| 12 |
+
from torch.nn import init
|
| 13 |
+
import torch_sparse
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class APPNP1(nn.Module):
    """APPNP variant with a fixed two-layer MLP transformation followed by
    `nlayers` (= K) rounds of personalized-PageRank-style propagation:

        x <- (1 - alpha) * A_hat @ x + alpha * h

    Unlike `APPNP`, the transformation depth is fixed (nfeat -> nhid -> nclass)
    and weight decay is zeroed when `with_relu` is False.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True, with_bn=False, device=None):

        super(APPNP1, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass
        self.alpha = 0.1  # teleport probability of the PPR propagation

        if with_bn:
            self.bns = torch.nn.ModuleList()
            self.bns.append(nn.BatchNorm1d(nhid))

        # Fixed two-layer transformation: nfeat -> nhid -> nclass.
        self.layers = nn.ModuleList([])
        self.layers.append(MyLinear(nfeat, nhid))
        self.layers.append(MyLinear(nhid, nclass))

        self.nlayers = nlayers  # number of propagation steps K
        self.dropout = dropout
        self.lr = lr
        if not with_relu:
            # linear model: no regularization
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bn = with_bn
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None
        self.sparse_dropout = SparseDropout(dprob=0)

    def forward(self, x, adj):
        """Transform `x`, then propagate K times over normalized `adj`."""
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        h = x  # teleport target: the transformed features
        # here nlayers means K
        for i in range(self.nlayers):
            adj_drop = adj  # edge dropout disabled (see self.sparse_dropout)
            x = torch.spmm(adj_drop, x)
            x = x * (1 - self.alpha)
            x = x + self.alpha * h

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Forward pass over sampled (sub)adjacencies, where each element of
        `adjs` is a `(adj, edge_index, size)` triple from a neighbor sampler;
        `h` is truncated to the target nodes of each hop."""
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        h = x
        for ix, (adj, _, size) in enumerate(adjs):
            adj_drop = adj
            h = h[: size[1]]  # keep only the target nodes of this hop
            x = torch_sparse.matmul(adj_drop, x)
            x = x * (1 - self.alpha)
            x = x + self.alpha * h

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Forward pass for synthetic-graph sampling: plain propagation over
        each adjacency in `adjs`, without the PPR teleport term."""
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        for ix, (adj) in enumerate(adjs):
            x = torch_sparse.matmul(adj, x)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def initialize(self):
        """Initialize parameters of GCN.
        """
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        '''Fit the model, selecting the best weights on a validation set.

        data: full data class providing `feat_full`/`adj_full` (or
        `feat_val`/`adj_val` when `noval=True`), `labels_val` and `idx_val`.
        '''
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        if len(labels.shape) > 1:
            # multi-label setting: sigmoid outputs + BCE loss
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            # validate directly on the validation subgraph
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Train on `self.features`/`self.adj_norm`, keeping the parameters
        with the best validation accuracy.

        Fix: `weights` is snapshotted before the loop so `load_state_dict`
        cannot fail with an unbound name if accuracy never improves.
        """
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            if i == train_iters // 2:
                # one-shot learning-rate decay halfway through training
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like `predict`, but uses `adj` as-is without normalization."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
class MyLinear(Module):
    """Plain affine layer (y = x W + b), adapted from https://github.com/tkipf/pygcn.

    Accepts either a dense or a sparse input matrix.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init scaled by fan-in (weight has shape [in, out],
        # so dim 0 is the fan-in).
        bound = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        # Pick the matmul that matches the input's layout.
        matmul = torch.spmm if input.data.is_sparse else torch.mm
        output = matmul(input, self.weight)
        return output if self.bias is None else output + self.bias

    def __repr__(self):
        return f'{self.__class__.__name__} ({self.in_features} -> {self.out_features})'
|
| 335 |
+
|
| 336 |
+
class SparseDropout(torch.nn.Module):
    """Dropout applied to the stored values of a sparse COO tensor.

    Surviving values are rescaled by 1/keep_prob (inverted dropout).
    """

    def __init__(self, dprob=0.5):
        super(SparseDropout, self).__init__()
        self.kprob = 1 - dprob  # keep probability

    def forward(self, x, training):
        if not training:
            return x
        # Bernoulli(kprob) mask over the nonzero values.
        keep = (torch.rand(x._values().size()) + self.kprob).floor().type(torch.bool)
        indices = x._indices()[:, keep]
        values = x._values()[keep] * (1.0 / self.kprob)
        return torch.sparse.FloatTensor(indices, values, x.size())
|
GCond/models/mycheby.py
ADDED
|
@@ -0,0 +1,417 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from torch.nn.parameter import Parameter
|
| 7 |
+
from torch.nn.modules.module import Module
|
| 8 |
+
from deeprobust.graph import utils
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
from sklearn.metrics import f1_score
|
| 11 |
+
from torch.nn import init
|
| 12 |
+
import torch_sparse
|
| 13 |
+
from torch_geometric.nn.inits import zeros
|
| 14 |
+
import scipy.sparse as sp
|
| 15 |
+
import numpy as np
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ChebConvolution(Module):
    """Chebyshev spectral graph convolution layer (ChebNet-style),
    adapted from https://github.com/tkipf/pygcn.

    Output is ``sum_k lin_k(T_k(adj) @ x)`` where T_k follows the
    Chebyshev recurrence T_0(x) = x, T_1 = adj @ x,
    T_k = 2 * adj @ T_{k-1} - T_{k-2}.
    NOTE(review): ``adj`` is assumed to be the rescaled propagation
    matrix prepared by the caller (self-loops removed, normalized) —
    confirm against the training code in this repo.
    """

    def __init__(self, in_features, out_features, with_bias=True, single_param=True, K=2):
        """K: number of Chebyshev terms (one linear map per term).
        Set single_param to True to share lins[0] across all terms,
        which alleviates the overfitting issue."""
        super(ChebConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # One bias-free linear map per polynomial order; the layer bias
        # (if any) is added once, after summing all terms.
        self.lins = torch.nn.ModuleList([
            MyLinear(in_features, out_features, with_bias=False) for _ in range(K)])
        # self.lins = torch.nn.ModuleList([
        #     MyLinear(in_features, out_features, with_bias=True) for _ in range(K)])
        if with_bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        # When True, every term reuses lins[0] (weight sharing).
        self.single_param = single_param
        self.reset_parameters()

    def reset_parameters(self):
        # Re-initialize each term's linear map; zero the shared bias
        # (torch_geometric's `zeros` is a no-op when bias is None).
        for lin in self.lins:
            lin.reset_parameters()
        zeros(self.bias)

    def forward(self, input, adj, size=None):
        """ Graph Convolutional Layer forward function

        input: node feature matrix.
        adj: dense tensor or torch_sparse.SparseTensor propagation matrix.
        size: optional (num_src, num_dst) pair for sampled bipartite
            blocks; only the first size[1] rows form the T_0 term.
        """
        # support = torch.mm(input, self.weight_l)
        x = input
        # T_0 term: identity (restricted to target nodes when sampling).
        Tx_0 = x[:size[1]] if size is not None else x
        Tx_1 = x # dummy
        output = self.lins[0](Tx_0)

        # T_1 term: one propagation step.
        if len(self.lins) > 1:
            if isinstance(adj, torch_sparse.SparseTensor):
                Tx_1 = torch_sparse.matmul(adj, x)
            else:
                Tx_1 = torch.spmm(adj, x)

            if self.single_param:
                output = output + self.lins[0](Tx_1)
            else:
                output = output + self.lins[1](Tx_1)

        # Higher-order terms via the Chebyshev recurrence.
        # NOTE(review): when size is not None and K > 2, Tx_0 has size[1]
        # rows while Tx_2 has all rows, so `2.*Tx_2 - Tx_0` would shape-
        # mismatch; the sampled path appears to assume K <= 2 — confirm.
        for lin in self.lins[2:]:
            if self.single_param:
                lin = self.lins[0]
            if isinstance(adj, torch_sparse.SparseTensor):
                Tx_2 = torch_sparse.matmul(adj, Tx_1)
            else:
                Tx_2 = torch.spmm(adj, Tx_1)
            # T_k = 2 * adj @ T_{k-1} - T_{k-2}
            Tx_2 = 2. * Tx_2 - Tx_0
            output = output + lin.forward(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        # e.g. "ChebConvolution (1433 -> 16)"
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class Cheby(nn.Module):
    """Multi-layer Chebyshev GCN with a deeprobust-style training loop.

    Stacks ``nlayers`` ChebConvolution layers (optionally with batch
    norm), trains with Adam, and keeps the weights that achieve the best
    validation accuracy.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                with_relu=True, with_bias=True, with_bn=False, device=None):

        super(Cheby, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        self.layers = nn.ModuleList([])

        if nlayers == 1:
            # Single layer maps features straight to class scores.
            self.layers.append(ChebConvolution(nfeat, nclass, with_bias=with_bias))
        else:
            # First BN is created before the first layer; subsequent BNs
            # are appended inside the hidden-layer loop.
            if with_bn:
                self.bns = torch.nn.ModuleList()
                self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(ChebConvolution(nfeat, nhid, with_bias=with_bias))
            for i in range(nlayers-2):
                self.layers.append(ChebConvolution(nhid, nhid, with_bias=with_bias))
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(ChebConvolution(nhid, nclass, with_bias=with_bias))

        # self.lin = MyLinear(nhid, nclass, with_bias=True)

        # dropout = 0.5
        self.dropout = dropout
        self.lr = lr
        self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bn = with_bn
        self.with_bias = with_bias
        # Filled in by fit_with_val / _train_with_val:
        self.output = None       # cached predictions at best validation step
        self.best_model = None
        self.best_output = None
        self.adj_norm = None     # normalized training adjacency
        self.features = None     # training features
        self.multi_label = None  # True -> sigmoid/BCE, False -> log_softmax/NLL

    def forward(self, x, adj):
        """Full-graph forward: log-probabilities (or sigmoids when
        multi-label)."""
        for ix, layer in enumerate(self.layers):
            # x = F.dropout(x, 0.2, training=self.training)
            x = layer(x, adj)
            # BN/ReLU/dropout on every layer except the output layer.
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)
        # x = F.dropout(x, 0.5, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Mini-batch forward over NeighborSampler-style (adj, e_id, size)
        triples, one per layer."""
        # TODO: do we need normalization?
        # for ix, layer in enumerate(self.layers):
        for ix, (adj, _, size) in enumerate(adjs):
            # x_target = x[: size[1]]
            # x = self.layers[ix]((x, x_target), edge_index)
            # adj = adj.to(self.device)
            x = self.layers[ix](x, adj, size=size)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Forward over a plain per-layer list of adjacencies (synthetic
        condensed-graph path — no size/bipartite bookkeeping)."""
        for ix, (adj) in enumerate(adjs):
            x = self.layers[ix](x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)


    def initialize(self):
        """Initialize parameters of GCN.
        """
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        '''Train on (features, adj, labels), selecting by validation.

        data: full data class providing feat_full/adj_full (or
        feat_val/adj_val when noval=True) and labels_val.
        noval: despite the name, True routes validation through
        data.adj_val / data.feat_val instead of the full graph —
        NOTE(review): confirm intended semantics with callers.
        patience: accepted but unused here.
        '''
        if initialize:
            self.initialize()

        # features, adj, labels = data.feat_train, data.adj_train, data.labels_train

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        # Chebyshev filters operate on the graph without self-loops.
        adj = adj - torch.eye(adj.shape[0]).to(self.device) # cheby
        if normalize:
            adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj


        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            # Shift to non-negative before row-normalizing.
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        # 2-D label matrix -> multi-label (BCE); 1-D vector -> NLL.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Adam training loop; restores the weights with the best
        validation accuracy when done."""
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        # adj_full = adj_full - sp.eye(adj_full.shape[0])
        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        best_loss_val = 100

        for i in range(train_iters):
            # Halfway through, drop the learning rate by 10x
            # (fresh optimizer, so Adam moments are reset too).
            if i == train_iters // 2:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    # Validation graph already restricted to val nodes.
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                # if loss_val < best_loss_val:
                #     best_loss_val = loss_val
                #     self.output = output
                #     weights = deepcopy(self.state_dict())
                # print(best_loss_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())
                    # print(best_acc_val)

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        # NOTE(review): if acc_val never exceeds 0, `weights` is unbound
        # here and this raises NameError — confirm acceptable.
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()


    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            # adj = adj-sp.eye(adj.shape[0])
            # adj[0,0]=0

            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            adj = utils.to_scipy(adj)

            # NOTE(review): the Chebyshev-style normalization computed
            # into the local `adj` below is never used — the return uses
            # self.adj_norm set above. Dead code or a bug; confirm.
            adj = adj-sp.eye(adj.shape[0])
            mx = normalize_adj(adj)
            adj = utils.sparse_mx_to_torch_sparse_tensor(mx).to(self.device)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like predict(), but uses `adj` as-is (no normalization)."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
|
| 352 |
+
|
| 353 |
+
class MyLinear(Module):
    """Dense affine layer (adapted from https://github.com/tkipf/pygcn).

    Computes ``input @ weight (+ bias)``; accepts sparse or dense input.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super().__init__()
        self.in_features = in_features
        self.out_features = out_features
        # Weight stored as (in, out) so forward needs no transpose.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not with_bias:
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-b, b] with b = 1/sqrt(fan_in).

        The transpose makes this fan-in based, unlike the classic
        pygcn 1/sqrt(out_features) scheme.
        """
        bound = 1. / math.sqrt(self.weight.T.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        """Return ``input @ weight (+ bias)``; sparse inputs use spmm."""
        matmul = torch.spmm if input.data.is_sparse else torch.mm
        out = matmul(input, self.weight)
        return out if self.bias is None else out + self.bias

    def __repr__(self):
        """e.g. ``MyLinear (16 -> 8)``."""
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.in_features, self.out_features)
|
| 390 |
+
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def normalize_adj(mx):
    """Symmetrically normalize a sparse adjacency matrix:

        A' = (D + I)^-1/2 @ (A + I) @ (D + I)^-1/2

    Parameters
    ----------
    mx : scipy.sparse.csr_matrix
        matrix to be normalized (assumed to lack self-loops; they are
        added here before computing degrees)

    Returns
    -------
    scipy.sparse matrix
        normalized matrix
    """
    # TODO: maybe using coo format would be better?
    # isinstance (not `type(mx) is sp.lil.lil_matrix`): accepts
    # subclasses and avoids the private/deprecated scipy module path.
    if not isinstance(mx, sp.lil_matrix):
        mx = mx.tolil()
    # Add self-loops before computing degrees.
    mx = mx + sp.eye(mx.shape[0])
    rowsum = np.array(mx.sum(1))
    # D^{-1/2}; isolated rows (degree 0) would yield inf -> zero them.
    d_inv_sqrt = np.power(rowsum, -0.5).flatten()
    d_inv_sqrt[np.isinf(d_inv_sqrt)] = 0.
    d_mat_inv_sqrt = sp.diags(d_inv_sqrt)
    mx = d_mat_inv_sqrt.dot(mx)
    mx = mx.dot(d_mat_inv_sqrt)
    return mx
|
GCond/models/mygatconv.py
ADDED
|
@@ -0,0 +1,203 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Union, Tuple, Optional
|
| 2 |
+
from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
|
| 3 |
+
OptTensor)
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from torch import Tensor
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from torch.nn import Parameter, Linear
|
| 9 |
+
from torch_sparse import SparseTensor, set_diag
|
| 10 |
+
from torch_geometric.nn.conv import MessagePassing
|
| 11 |
+
from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
|
| 12 |
+
|
| 13 |
+
from torch_geometric.nn.inits import glorot, zeros
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are computed as

    .. math::
        \alpha_{i,j} =
        \frac{
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_j]
        \right)\right)}
        {\sum_{k \in \mathcal{N}(i) \cup \{ i \}}
        \exp\left(\mathrm{LeakyReLU}\left(\mathbf{a}^{\top}
        [\mathbf{\Theta}\mathbf{x}_i \, \Vert \, \mathbf{\Theta}\mathbf{x}_k]
        \right)\right)}.

    This copy additionally supports per-edge scalar weights via the
    ``edge_weight`` argument of :meth:`forward` (applied multiplicatively
    to messages).

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        # Single int input -> shared projection for source and target;
        # (src, dst) tuple -> separate projections (bipartite graphs).
        if isinstance(in_channels, int):
            self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
            self.lin_r = self.lin_l
        else:
            self.lin_l = Linear(in_channels[0], heads * out_channels, False)
            self.lin_r = Linear(in_channels[1], heads * out_channels, False)

        # Per-head attention vectors for source (l) and target (r) halves.
        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()
        # Cached per-edge weights; (re)filled in forward when a weight
        # tensor matching the current edge count is supplied.
        self.edge_weight = None

    def reset_parameters(self):
        # Glorot for projections/attention vectors, zero bias.
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights=None, edge_weight=None):
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
            edge_weight (Tensor, optional): per-edge scalar weights applied
                to messages; cached on the module between calls.
        """
        H, C = self.heads, self.out_channels

        x_l: OptTensor = None
        x_r: OptTensor = None
        alpha_l: OptTensor = None
        alpha_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
            # Shared projection; per-head attention logits for both ends.
            x_l = x_r = self.lin_l(x).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            alpha_r = (x_r * self.att_r).sum(dim=-1)
        else:
            # Bipartite input: (source features, target features).
            x_l, x_r = x[0], x[1]
            assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
            x_l = self.lin_l(x_l).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)
                alpha_r = (x_r * self.att_r).sum(dim=-1)


        assert x_l is not None
        assert alpha_l is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                if size is not None:
                    num_nodes = min(size[0], size[1])
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

                # NOTE(review): edge_weight is cached on the module and only
                # replaced when the edge count changes — appears to assume
                # the same weights are reused across calls on one graph;
                # self-loops added above get no corresponding weight entry
                # unless the caller padded edge_weight. Confirm.
                if edge_weight is not None:
                    if self.edge_weight is None:
                        self.edge_weight = edge_weight

                    if edge_index.size(1) != self.edge_weight.shape[0]:
                        self.edge_weight = edge_weight

            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)

        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

        # Attention coefficients stashed by message() during propagate.
        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Edge logit = src + dst attention halves; softmax per target node.
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        # Stash pre-dropout coefficients for return_attention_weights.
        self._alpha = alpha
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Optional per-edge scalar weight on top of attention.
        if self.edge_weight is not None:
            x_j = self.edge_weight.view(-1, 1, 1) * x_j
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        # e.g. "GATConv(1433, 8, heads=8)"
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
|
| 203 |
+
|
GCond/models/mygraphsage.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from torch.nn.parameter import Parameter
|
| 7 |
+
from torch.nn.modules.module import Module
|
| 8 |
+
from deeprobust.graph import utils
|
| 9 |
+
from copy import deepcopy
|
| 10 |
+
from sklearn.metrics import f1_score
|
| 11 |
+
from torch.nn import init
|
| 12 |
+
import torch_sparse
|
| 13 |
+
from torch_geometric.data import NeighborSampler
|
| 14 |
+
from torch_sparse import SparseTensor
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
class SageConvolution(Module):
    """GraphSAGE convolution layer (deeprobust/pygcn style).

    Computes ``adj @ (input @ weight_l) (+ bias_l)`` and, when
    ``root_weight`` is enabled, additionally adds the root ("self")
    transformation ``input @ weight_r (+ bias_r)``.

    Parameters
    ----------
    in_features : int
        Dimensionality of the input node features.
    out_features : int
        Dimensionality of the output node features.
    with_bias : bool
        Whether to learn the additive biases. Bug fix: this flag used to be
        ignored — biases were always created and always added. It is now
        honored; the default (True) preserves the previous behavior and
        state-dict layout.
    root_weight : bool
        Whether to add the separate root-node transformation in `forward`.
    """

    def __init__(self, in_features, out_features, with_bias=True, root_weight=False):
        super(SageConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.root_weight = root_weight
        self.weight_l = Parameter(torch.FloatTensor(in_features, out_features))
        self.weight_r = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias_l = Parameter(torch.FloatTensor(out_features))
            self.bias_r = Parameter(torch.FloatTensor(out_features))
        else:
            # register_parameter(None) keeps the attributes present so that
            # `self.bias_l is None` checks work and state_dict stays clean.
            self.register_parameter('bias_l', None)
            self.register_parameter('bias_r', None)
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-stdv, stdv] with stdv = 1/sqrt(in_features).

        Note: ``self.weight_l.T.size(1)`` equals ``in_features`` (the original
        pygcn code used fan-out; this repo deliberately uses fan-in).
        """
        stdv = 1. / math.sqrt(self.weight_l.T.size(1))
        self.weight_l.data.uniform_(-stdv, stdv)
        if self.bias_l is not None:
            self.bias_l.data.uniform_(-stdv, stdv)

        stdv = 1. / math.sqrt(self.weight_r.T.size(1))
        self.weight_r.data.uniform_(-stdv, stdv)
        if self.bias_r is not None:
            self.bias_r.data.uniform_(-stdv, stdv)

    def forward(self, input, adj, size=None):
        """Graph convolutional layer forward function.

        Parameters
        ----------
        input : torch.Tensor
            Node feature matrix (dense or torch-sparse).
        adj : torch.Tensor or torch_sparse.SparseTensor
            (Normalized) adjacency used for neighborhood aggregation.
        size : tuple or None
            Sampler layer sizes; when given, only the first ``size[1]`` rows
            of ``input`` are target nodes for the root transformation.
        """
        if input.data.is_sparse:
            support = torch.spmm(input, self.weight_l)
        else:
            support = torch.mm(input, self.weight_l)
        # Dispatch on torch.Tensor first: torch_sparse.SparseTensor is not a
        # torch.Tensor subclass, so this is equivalent to the old
        # isinstance(adj, torch_sparse.SparseTensor) check, but the plain
        # tensor path no longer evaluates the torch_sparse name at all.
        if isinstance(adj, torch.Tensor):
            output = torch.spmm(adj, support)
        else:
            output = torch_sparse.matmul(adj, support)
        if self.bias_l is not None:
            output = output + self.bias_l

        if self.root_weight:
            # With a sampler, only the first size[1] rows are target nodes.
            root = input[:size[1]] if size is not None else input
            output = output + root @ self.weight_r
            if self.bias_r is not None:
                output = output + self.bias_r

        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
class GraphSage(nn.Module):
    """GraphSAGE classifier with a deeprobust-style fit/predict interface.

    Stacks ``SageConvolution`` layers (optionally interleaved with
    BatchNorm and ReLU/dropout). ``fit_with_val`` trains with
    torch_geometric ``NeighborSampler`` mini-batches and restores the
    weights that achieved the best validation accuracy.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True, with_bn=False, device=None):
        # nfeat/nhid/nclass: input, hidden and output feature dimensions.
        # nlayers: number of SageConvolution layers; nlayers == 1 builds a
        #   single nfeat -> nclass layer (with_bn is then not used).
        super(GraphSage, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        self.layers = nn.ModuleList([])

        if nlayers == 1:
            self.layers.append(SageConvolution(nfeat, nclass, with_bias=with_bias))
        else:
            if with_bn:
                # One BatchNorm per hidden layer; the last (output) layer
                # gets no bn/relu/dropout (see forward).
                self.bns = torch.nn.ModuleList()
                self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(SageConvolution(nfeat, nhid, with_bias=with_bias))
            for i in range(nlayers-2):
                self.layers.append(SageConvolution(nhid, nhid, with_bias=with_bias))
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(SageConvolution(nhid, nclass, with_bias=with_bias))

        self.dropout = dropout
        self.lr = lr
        if not with_relu:
            # Linearized model: no weight decay (convention used across this repo).
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bn = with_bn
        self.with_bias = with_bias
        # Filled in during fit_with_val / predict.
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None

    def forward(self, x, adj):
        """Full-graph forward pass; returns sigmoid (multi-label) or
        log-softmax (single-label) scores."""
        for ix, layer in enumerate(self.layers):
            x = layer(x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Mini-batch forward: `adjs` is the per-layer (adj, e_id, size)
        list produced by torch_geometric's NeighborSampler; one bipartite
        adjacency per layer."""
        # TODO: do we need normalization?
        # for ix, layer in enumerate(self.layers):
        for ix, (adj, _, size) in enumerate(adjs):
            # x_target = x[: size[1]]
            # x = self.layers[ix]((x, x_target), edge_index)
            # adj = adj.to(self.device)
            x = self.layers[ix](x, adj, size=size)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Forward for the synthetic (condensed) graph: `adjs` is a plain
        list with one adjacency per layer (no sampler sizes)."""
        for ix, (adj) in enumerate(adjs):
            x = self.layers[ix](x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)


    def initialize(self):
        """Re-initialize all layer (and BatchNorm) parameters.
        """
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        """Train on (features, adj, labels) and model-select on validation.

        data: full dataset wrapper (provides feat_full/adj_full, idx_val,
        labels_val, and feat_val/adj_val when noval=True).
        noval=True evaluates directly on the validation subgraph
        (inductive setting) instead of indexing the full graph.
        `patience` is accepted for API compatibility but unused here.
        """
        if initialize:
            self.initialize()

        # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            # Shift to non-negative before row-normalizing.
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        # 2-D label matrix => multi-label task with BCE; otherwise NLL.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Inner training loop: NeighborSampler mini-batches on the training
        graph, full-graph evaluation each iteration, best-on-val snapshot."""
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        adj_norm = self.adj_norm
        node_idx = torch.arange(adj_norm.shape[0]).long()

        # Convert the (dense) normalized adjacency to a torch_sparse
        # SparseTensor for the sampler. NOTE(review): the indexing
        # adj_norm[edge_index[0], edge_index[1]] assumes adj_norm is dense
        # here — confirm against callers.
        edge_index = adj_norm.nonzero().T
        adj_norm = SparseTensor(row=edge_index[0], col=edge_index[1],
                value=adj_norm[edge_index[0], edge_index[1]], sparse_sizes=adj_norm.size()).t()
        # edge_index = adj_norm._indices()
        # adj_norm = SparseTensor(row=edge_index[0], col=edge_index[1],
        #         value=adj_norm._values(), sparse_sizes=adj_norm.size()).t()

        if adj_norm.density() > 0.5: # if the weighted graph is too dense, we need a larger neighborhood size
            sizes = [30, 20]
        else:
            sizes = [5, 5]
        # batch_size=len(node_idx): one batch covers all nodes per epoch.
        train_loader = NeighborSampler(adj_norm,
                node_idx=node_idx,
                sizes=sizes, batch_size=len(node_idx),
                num_workers=0, return_e_id=False,
                num_nodes=adj_norm.size(0),
                shuffle=True)

        best_acc_val = 0
        for i in range(train_iters):
            if i == train_iters // 2:
                # Step decay: drop lr by 10x halfway through training.
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            # optimizer.zero_grad()
            # output = self.forward(self.features, self.adj_norm)
            # loss_train = self.loss(output, labels)
            # loss_train.backward()
            # optimizer.step()

            for batch_size, n_id, adjs in train_loader:
                adjs = [adj.to(self.device) for adj in adjs]
                optimizer.zero_grad()
                out = self.forward_sampler(self.features[n_id], adjs)
                # Only the first batch_size entries of n_id are target nodes.
                loss_train = F.nll_loss(out, labels[n_id[:batch_size]])
                loss_train.backward()
                optimizer.step()


            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    # Snapshot the best weights seen so far.
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()


    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            # Note: overwrites the stored features/adj_norm as a side effect.
            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Same as `predict` but uses `adj` as-is (no normalization)."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
|
| 353 |
+
|
GCond/models/parametrized_adj.py
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch.nn as nn
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import math
|
| 4 |
+
import torch
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from torch.nn.parameter import Parameter
|
| 7 |
+
from torch.nn.modules.module import Module
|
| 8 |
+
from itertools import product
|
| 9 |
+
import numpy as np
|
| 10 |
+
|
| 11 |
+
class PGE(nn.Module):
    """Parametrized graph-structure generator (GCond's PGE).

    Scores every ordered node pair with a small MLP over the concatenated
    endpoint features, then returns a dense, symmetric, self-loop-free
    adjacency with entries in (0, 1).
    """

    def __init__(self, nfeat, nnodes, nhid=128, nlayers=3, device=None, args=None):
        super(PGE, self).__init__()
        # Dataset-specific capacity overrides.
        if args.dataset in ['ogbn-arxiv', 'arxiv', 'flickr']:
            nhid = 256
        if args.dataset in ['reddit']:
            nhid = 256
            if args.reduction_rate == 0.01:
                nhid = 128
            nlayers = 3

        # Edge MLP: Linear(2*nfeat -> nhid) [+ hidden Linear(nhid -> nhid)]
        # -> Linear(nhid -> 1); BatchNorm+ReLU after every non-final layer.
        self.layers = nn.ModuleList([nn.Linear(nfeat*2, nhid)])
        self.bns = torch.nn.ModuleList([nn.BatchNorm1d(nhid)])
        for _ in range(nlayers - 2):
            self.layers.append(nn.Linear(nhid, nhid))
            self.bns.append(nn.BatchNorm1d(nhid))
        self.layers.append(nn.Linear(nhid, 1))

        # All ordered (i, j) node pairs, stored as a (2, nnodes**2) array.
        self.edge_index = np.array(list(product(range(nnodes), range(nnodes)))).T
        self.nnodes = nnodes
        self.device = device
        self.reset_parameters()
        self.cnt = 0
        self.args = args

    def _mlp(self, h):
        # Apply the edge MLP; bn + relu on every layer except the last.
        last = len(self.layers) - 1
        for ix, layer in enumerate(self.layers):
            h = layer(h)
            if ix != last:
                h = F.relu(self.bns[ix](h))
        return h

    def forward(self, x, inference=False):
        """Return the generated (nnodes, nnodes) adjacency for features x.

        `inference` is accepted for API compatibility; the computation is
        identical either way (see `inference()` for the no-grad wrapper).
        """
        rows, cols = self.edge_index[0], self.edge_index[1]
        if self.args.dataset == 'reddit' and self.args.reduction_rate >= 0.01:
            # Reddit's pair list is large; score it in 5 chunks to bound memory.
            chunks = np.array_split(np.arange(self.edge_index.shape[1]), 5)
            scores = torch.cat([self._mlp(torch.cat([x[rows[c]], x[cols[c]]], axis=1))
                                for c in chunks])
        else:
            scores = self._mlp(torch.cat([x[rows], x[cols]], axis=1))

        adj = scores.reshape(self.nnodes, self.nnodes)
        adj = (adj + adj.T)/2                        # symmetrize
        adj = torch.sigmoid(adj)                     # squash into (0, 1)
        adj = adj - torch.diag(torch.diag(adj, 0))   # zero the diagonal (no self-loops)
        return adj

    @torch.no_grad()
    def inference(self, x):
        """Forward pass without gradient tracking."""
        # self.eval()
        return self.forward(x, inference=True)

    def reset_parameters(self):
        def _reset(m):
            if isinstance(m, (nn.Linear, nn.BatchNorm1d)):
                m.reset_parameters()
        self.apply(_reset)
|
| 88 |
+
|
GCond/models/sgc.py
ADDED
|
@@ -0,0 +1,290 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
'''one transformation with multiple propagation'''
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.nn.parameter import Parameter
|
| 8 |
+
from torch.nn.modules.module import Module
|
| 9 |
+
from deeprobust.graph import utils
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from sklearn.metrics import f1_score
|
| 12 |
+
from torch.nn import init
|
| 13 |
+
import torch_sparse
|
| 14 |
+
|
| 15 |
+
class GraphConvolution(Module):
    """Simple GCN layer, similar to https://github.com/tkipf/pygcn

    Computes ``adj @ (input @ weight) + bias``.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): `with_bias` is accepted but not consulted — a bias is
        # always created. This is kept as-is because SGC.forward reads
        # `self.conv.bias` directly and adds it unconditionally.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-bound, bound] with bound = 1/sqrt(in_features)
        (``weight.T.size(1)`` equals ``in_features``)."""
        # stdv = 1. / math.sqrt(self.weight.size(1))
        bound = 1. / math.sqrt(self.weight.T.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Graph Convolutional Layer forward function.

        `input` may be dense or torch-sparse; `adj` may be a torch.Tensor
        or a torch_sparse.SparseTensor.
        """
        if input.data.is_sparse:
            xw = torch.spmm(input, self.weight)
        else:
            xw = torch.mm(input, self.weight)
        # torch_sparse.SparseTensor is not a torch.Tensor subclass, so
        # dispatching on torch.Tensor first is equivalent to the original
        # isinstance(adj, torch_sparse.SparseTensor) check.
        if isinstance(adj, torch.Tensor):
            aggregated = torch.spmm(adj, xw)
        else:
            aggregated = torch_sparse.matmul(adj, xw)
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

    def __repr__(self):
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.in_features, self.out_features)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class SGC(nn.Module):
    """Simple Graph Convolution: one linear transformation followed by
    ``nlayers`` propagation steps (feature smoothing), with a
    deeprobust-style fit/predict interface.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 with_relu=True, with_bias=True, with_bn=False, device=None):
        # nhid/dropout are accepted for interface parity with the other
        # models but are not used by SGC's forward pass.
        super(SGC, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        # Single transformation; its weight/bias are applied manually in forward.
        self.conv = GraphConvolution(nfeat, nclass, with_bias=with_bias)

        self.nlayers = nlayers
        self.dropout = dropout
        self.lr = lr
        if not with_relu:
            # Linearized model: no weight decay (convention used across this repo).
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        if with_bn:
            print('Warning: SGC does not have bn!!!')
        # SGC never uses BatchNorm regardless of the flag.
        self.with_bn = False
        self.with_bias = with_bias
        # Filled in during fit_with_val / predict.
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None

    def forward(self, x, adj):
        """Transform once (x @ W), propagate `nlayers` times (adj @ x),
        add bias, then sigmoid (multi-label) or log-softmax."""
        weight = self.conv.weight
        bias = self.conv.bias
        x = torch.mm(x, weight)
        for i in range(self.nlayers):
            x = torch.spmm(adj, x)
        x = x + bias
        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Sampler variant: one bipartite torch_sparse adjacency per hop
        (from torch_geometric's NeighborSampler)."""
        weight = self.conv.weight
        bias = self.conv.bias
        x = torch.mm(x, weight)
        for ix, (adj, _, size) in enumerate(adjs):
            x = torch_sparse.matmul(adj, x)
        x = x + bias
        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Synthetic-graph variant: `adjs` is a plain list of adjacencies
        (dense torch.Tensor or torch_sparse SparseTensor), one per hop."""
        weight = self.conv.weight
        bias = self.conv.bias
        x = torch.mm(x, weight)
        for ix, (adj) in enumerate(adjs):
            if type(adj) == torch.Tensor:
                x = adj @ x
            else:
                x = torch_sparse.matmul(adj, x)
        x = x + bias
        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def initialize(self):
        """Re-initialize the transformation parameters.
        """
        self.conv.reset_parameters()
        if self.with_bn:
            # Unreachable in practice: with_bn is forced to False in __init__.
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        """Train on (features, adj, labels) and model-select on validation.

        data: full dataset wrapper (provides feat_full/adj_full, idx_val,
        labels_val, and feat_val/adj_val when noval=True).
        `patience` is accepted for API compatibility but unused here.
        """
        if initialize:
            self.initialize()

        # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            # Shift to non-negative before row-normalizing.
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        # 2-D label matrix => multi-label task with BCE; otherwise NLL.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Inner full-batch training loop with best-on-validation snapshot."""
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full

        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0

        for i in range(train_iters):
            if i == train_iters // 2:
                # Step decay: drop lr by 10x halfway through training.
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    # Snapshot the best weights seen so far.
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)


    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()


    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            # Note: overwrites the stored features/adj_norm as a side effect.
            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Same as `predict` but uses `adj` as-is (no normalization)."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
|
| 289 |
+
|
| 290 |
+
|
GCond/models/sgc_multi.py
ADDED
|
@@ -0,0 +1,315 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""multiple transformaiton and multiple propagation"""
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
import torch.optim as optim
|
| 7 |
+
from torch.nn.parameter import Parameter
|
| 8 |
+
from torch.nn.modules.module import Module
|
| 9 |
+
from deeprobust.graph import utils
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
from sklearn.metrics import f1_score
|
| 12 |
+
from torch.nn import init
|
| 13 |
+
import torch_sparse
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class SGC(nn.Module):
    """SGC-style model: ``ntrans`` feature transformations followed by
    ``nlayers`` propagation steps through the (normalized) adjacency.

    Parameters
    ----------
    nfeat : int
        Input feature dimension.
    nhid : int
        Hidden dimension (used only when ``ntrans > 1``).
    nclass : int
        Number of output classes.
    nlayers : int
        Number of propagation steps (applications of ``adj @ x``).
    dropout : float
        Dropout rate applied between transformation layers.
    lr, weight_decay : float
        Optimizer hyper-parameters used by ``fit_with_val``.
    ntrans : int
        Number of linear transformation layers.
    with_bias : bool
        Whether the linear layers carry a bias term.
    with_bn : bool
        Whether to apply BatchNorm1d between transformation layers.
    device : str or torch.device
        Required; device used for training and evaluation.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                 ntrans=2, with_bias=True, with_bn=False, device=None):
        """nlayers indicates the number of propagations"""
        super(SGC, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        if with_bn:
            # Create the container unconditionally so initialize() can iterate
            # it even when ntrans == 1 (the original crashed in that case).
            self.bns = torch.nn.ModuleList()

        self.layers = nn.ModuleList([])
        if ntrans == 1:
            self.layers.append(MyLinear(nfeat, nclass, with_bias=with_bias))
        else:
            # Fix: forward `with_bias` to the linear layers (it was previously
            # stored but silently ignored).
            self.layers.append(MyLinear(nfeat, nhid, with_bias=with_bias))
            if with_bn:
                self.bns.append(nn.BatchNorm1d(nhid))
            for i in range(ntrans - 2):
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
                self.layers.append(MyLinear(nhid, nhid, with_bias=with_bias))
            self.layers.append(MyLinear(nhid, nclass, with_bias=with_bias))

        self.nlayers = nlayers
        self.dropout = dropout
        self.lr = lr
        self.with_bn = with_bn
        self.with_bias = with_bias
        self.weight_decay = weight_decay
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None

    def _transform(self, x):
        """Apply the transformation stack (BN / ReLU / dropout between layers)."""
        for ix, layer in enumerate(self.layers):
            x = layer(x)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)
        return x

    def _readout(self, x):
        """Sigmoid for multi-label targets, log-softmax otherwise."""
        if self.multi_label:
            return torch.sigmoid(x)
        return F.log_softmax(x, dim=1)

    def forward(self, x, adj):
        """Full-graph forward pass; ``adj`` is a (sparse) normalized adjacency."""
        x = self._transform(x)
        for i in range(self.nlayers):
            x = torch.spmm(adj, x)
        return self._readout(x)

    def forward_sampler(self, x, adjs):
        """Forward pass for neighbor-sampled mini-batches.

        ``adjs`` is an iterable of ``(adj, e_id, size)`` triples as produced by
        PyG's NeighborSampler; only the adjacency is used here.
        """
        x = self._transform(x)
        for ix, (adj, _, size) in enumerate(adjs):
            x = torch_sparse.matmul(adj, x)
        return self._readout(x)

    def forward_sampler_syn(self, x, adjs):
        """Forward pass over a list of dense or torch_sparse adjacencies."""
        x = self._transform(x)
        for adj in adjs:
            if type(adj) == torch.Tensor:
                x = adj @ x
            else:
                x = torch_sparse.matmul(adj, x)
        return self._readout(x)

    def initialize(self):
        """Re-initialize all learnable parameters of the model."""
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True,
                     verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        """Train on ``(features, adj, labels)`` keeping the model that performs
        best on the validation split of ``data``.

        Parameters
        ----------
        features, adj, labels :
            Training graph; converted to tensors on ``self.device`` if needed.
        data :
            Full data object providing validation splits (``data`` class from
            the surrounding project).
        train_iters : int
            Number of epochs; the learning rate is decayed 10x halfway through.
        initialize : bool
            Whether to re-initialize parameters first.
        normalize : bool
            Whether to symmetrically normalize ``adj``.
        noval : bool
            When True, validate on ``data.feat_val``/``data.adj_val`` as a
            separate (inductive) graph instead of indexing the full graph.
        """
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features - features.min())

        self.adj_norm = adj_norm
        self.features = features

        # A 2-D label matrix signals a multi-label problem (sigmoid + BCE);
        # otherwise standard single-label NLL on log-softmax outputs.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Inner training loop with validation-accuracy model selection."""
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full

        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        # Fix: snapshot the initial weights so load_state_dict() below is
        # well-defined even when validation accuracy never improves over 0
        # (the original raised UnboundLocalError in that case).
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            if i == train_iters // 2:
                # One-time 10x learning-rate decay halfway through training.
                lr = self.lr * 0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate performance on the test set.

        Parameters
        ----------
        idx_test :
            Node testing indices into the stored full-graph output.

        Returns
        -------
        float
            Test accuracy.
        """
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """Predict with an *unnormalized* adjacency (normalized internally).

        Parameters
        ----------
        features :
            Node features. If both ``features`` and ``adj`` are None, the
            features/adjacency stored during training are reused.
        adj :
            Adjacency matrix; normalized here before the forward pass.

        Returns
        -------
        torch.FloatTensor
            Log-probabilities (or sigmoid scores in the multi-label case).
        """
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Predict using ``adj`` as-is, without normalization."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
| 275 |
+
|
| 276 |
+
|
| 277 |
+
class MyLinear(Module):
    """Simple Linear layer, modified from https://github.com/tkipf/pygcn

    Applies ``y = x @ W (+ b)`` where ``W`` has shape
    ``(in_features, out_features)``; sparse inputs are supported.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            # Register an explicit None so `self.bias` is always defined.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init scaled by 1/sqrt(fan_in); weight.shape[0] == in_features
        # (equivalent to the original weight.T.size(1)).
        bound = 1. / math.sqrt(self.weight.shape[0])
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        # Pick the sparse or dense matmul depending on the input layout.
        matmul = torch.spmm if input.data.is_sparse else torch.mm
        out = matmul(input, self.weight)
        return out if self.bias is None else out + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"
GCond/requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torch_geometric
|
| 3 |
+
scipy
|
| 4 |
+
numpy
|
| 5 |
+
ogb
|
| 6 |
+
tqdm
|
| 7 |
+
torch_sparse
|
| 8 |
+
torchvision
|
| 9 |
+
configs
|
| 10 |
+
deeprobust
|
| 11 |
+
scikit_learn
|
GCond/res/cross/empty
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
GCond/script.sh
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
python train_cond_tranduct_sampler.py --dataset cora --mlp=0 --nlayers=2 --sgc=1 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=0.5 --seed=1000
|
| 2 |
+
|
| 3 |
+
python train_cond_tranduct_sampler.py --dataset ogbn-arxiv --mlp=0 --nlayers=2 --sgc=1 --lr_feat=0.01 --gpu_id=3 --lr_adj=0.01 --r=0.02 --seed=0 --inner=3 --epochs=1000 --save=0
|
| 4 |
+
|
GCond/scripts/run_cross.sh
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
for r in 0.001 0.005 0.01
|
| 2 |
+
do
|
| 3 |
+
bash scripts/script_cross.sh flickr ${r} 0
|
| 4 |
+
bash scripts/script_cross.sh ogbn-arxiv ${r} 0
|
| 5 |
+
done
|
| 6 |
+
|
| 7 |
+
for r in 0.25 0.5 1
|
| 8 |
+
do
|
| 9 |
+
bash scripts/script_cross.sh citeseer ${r} 0
|
| 10 |
+
bash scripts/script_cross.sh cora ${r} 0
|
| 11 |
+
done
|
| 12 |
+
|
| 13 |
+
for r in 0.001 0.0005 0.002
|
| 14 |
+
do
|
| 15 |
+
bash scripts/script_cross.sh reddit ${r} 0
|
| 16 |
+
done
|
GCond/scripts/run_main.sh
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
for r in 0.25 0.5 1
|
| 2 |
+
do
|
| 3 |
+
python train_gcond_transduct.py --dataset cora --nlayers=2 --sgc=1 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=${r} --seed=1 --epoch=600 --save=0
|
| 4 |
+
done
|
| 5 |
+
|
| 6 |
+
for r in 0.25 0.5 1
|
| 7 |
+
do
|
| 8 |
+
python train_gcond_transduct.py --dataset citeseer --nlayers=2 --sgc=1 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=${r} --seed=1 --epoch=600 --save=0
|
| 9 |
+
done
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
for r in 0.001 0.005 0.01
|
| 13 |
+
do
|
| 14 |
+
python train_gcond_transduct.py --dataset ogbn-arxiv --nlayers=2 --sgc=1 --lr_feat=0.01 --gpu_id=3 --lr_adj=0.01 --r=${r} --seed=1 --inner=3 --epochs=1000 --save=0
|
| 15 |
+
done
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
for r in 0.001 0.005 0.01
|
| 19 |
+
do
|
| 20 |
+
python train_gcond_induct.py --dataset flickr --sgc=2 --nlayers=2 --lr_feat=0.005 --lr_adj=0.005 --r=${r} --seed=1 --gpu_id=0 --epochs=1000 --inner=1 --outer=10 --save=0
|
| 21 |
+
done
|
| 22 |
+
|
| 23 |
+
for r in 0.001 0.005 0.0005 0.002
|
| 24 |
+
do
|
| 25 |
+
python train_gcond_induct.py --dataset reddit --sgc=1 --nlayers=2 --lr_feat=0.1 --lr_adj=0.1 --r=${r} --seed=1 --gpu_id=0 --epochs=1000 --inner=1 --outer=10 --save=0
|
| 26 |
+
done
|
GCond/scripts/script_cross.sh
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
dataset=${1}
|
| 2 |
+
r=${2}
|
| 3 |
+
gpu_id=${3}
|
| 4 |
+
for s in 0 1 2 3 4
|
| 5 |
+
do
|
| 6 |
+
python test_other_arcs.py --dataset ${dataset} --gpu_id=${gpu_id} --r=${r} --seed=${s} --nruns=10 >> res/flickr/${1}_${2}.out
|
| 7 |
+
done
|
GCond/test_other_arcs.py
ADDED
|
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
from deeprobust.graph.data import Dataset
|
| 3 |
+
import numpy as np
|
| 4 |
+
import random
|
| 5 |
+
import time
|
| 6 |
+
import argparse
|
| 7 |
+
import torch
|
| 8 |
+
from utils import *
|
| 9 |
+
import torch.nn.functional as F
|
| 10 |
+
from tester_other_arcs import Evaluator
|
| 11 |
+
from utils_graphsaint import DataGraphSAINT
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
parser = argparse.ArgumentParser()
|
| 15 |
+
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
|
| 16 |
+
parser.add_argument('--dataset', type=str, default='cora')
|
| 17 |
+
parser.add_argument('--nlayers', type=int, default=2)
|
| 18 |
+
parser.add_argument('--hidden', type=int, default=256)
|
| 19 |
+
parser.add_argument('--keep_ratio', type=float, default=1)
|
| 20 |
+
parser.add_argument('--reduction_rate', type=float, default=1)
|
| 21 |
+
parser.add_argument('--weight_decay', type=float, default=0.0)
|
| 22 |
+
parser.add_argument('--dropout', type=float, default=0.0)
|
| 23 |
+
parser.add_argument('--normalize_features', type=bool, default=True)
|
| 24 |
+
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
|
| 25 |
+
parser.add_argument('--mlp', type=int, default=0)
|
| 26 |
+
parser.add_argument('--inner', type=int, default=0)
|
| 27 |
+
parser.add_argument('--epsilon', type=float, default=-1)
|
| 28 |
+
parser.add_argument('--nruns', type=int, default=20)
|
| 29 |
+
args = parser.parse_args()
|
| 30 |
+
|
| 31 |
+
torch.cuda.set_device(args.gpu_id)
|
| 32 |
+
|
| 33 |
+
# random seed setting
|
| 34 |
+
random.seed(args.seed)
|
| 35 |
+
np.random.seed(args.seed)
|
| 36 |
+
torch.manual_seed(args.seed)
|
| 37 |
+
torch.cuda.manual_seed(args.seed)
|
| 38 |
+
|
| 39 |
+
if args.dataset in ['cora', 'citeseer']:
|
| 40 |
+
args.epsilon = 0.05
|
| 41 |
+
else:
|
| 42 |
+
args.epsilon = 0.01
|
| 43 |
+
|
| 44 |
+
print(args)
|
| 45 |
+
|
| 46 |
+
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
|
| 47 |
+
if args.dataset in data_graphsaint:
|
| 48 |
+
data = DataGraphSAINT(args.dataset)
|
| 49 |
+
data_full = data.data_full
|
| 50 |
+
else:
|
| 51 |
+
data_full = get_dataset(args.dataset, args.normalize_features)
|
| 52 |
+
data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)
|
| 53 |
+
|
| 54 |
+
agent = Evaluator(data, args, device='cuda')
|
| 55 |
+
agent.train()
|
GCond/tester_other_arcs.py
ADDED
|
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sys
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.optim as optim
|
| 6 |
+
from torch.nn import Parameter
|
| 7 |
+
import torch.nn.functional as F
|
| 8 |
+
from utils import match_loss, regularization, row_normalize_tensor
|
| 9 |
+
import deeprobust.graph.utils as utils
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
import numpy as np
|
| 12 |
+
from tqdm import tqdm
|
| 13 |
+
from models.gcn import GCN
|
| 14 |
+
from models.sgc import SGC
|
| 15 |
+
from models.sgc_multi import SGC as SGC1
|
| 16 |
+
from models.myappnp import APPNP
|
| 17 |
+
from models.myappnp1 import APPNP1
|
| 18 |
+
from models.mycheby import Cheby
|
| 19 |
+
from models.mygraphsage import GraphSage
|
| 20 |
+
from models.gat import GAT
|
| 21 |
+
import scipy.sparse as sp
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
class Evaluator:
    """Evaluate a saved condensed graph on several GNN architectures.

    Loads the synthetic adjacency/features produced by condensation from
    ``saved_ours/`` and trains GCN / GraphSage / SGC / MLP / APPNP / Cheby /
    GAT models on them, reporting accuracy on the original graph.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        self.data = data
        self.args = args
        self.device = device
        # Number of synthetic nodes is a fixed fraction of the training set.
        n = int(data.feat_train.shape[0] * args.reduction_rate)
        d = data.feat_train.shape[1]
        self.nnodes_syn = n
        self.adj_param= nn.Parameter(torch.FloatTensor(n, n).to(device))
        self.feat_syn = nn.Parameter(torch.FloatTensor(n, d).to(device))
        self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)
        self.reset_parameters()
        print('adj_param:', self.adj_param.shape, 'feat_syn:', self.feat_syn.shape)

    def reset_parameters(self):
        # Standard-normal init; the learned values are later loaded from disk
        # in get_syn_data, so these are placeholders.
        self.adj_param.data.copy_(torch.randn(self.adj_param.size()))
        self.feat_syn.data.copy_(torch.randn(self.feat_syn.size()))

    def generate_labels_syn(self, data):
        """Build the synthetic label vector, roughly preserving class ratios.

        Classes are processed from rarest to most frequent; each gets
        ``max(int(num * reduction_rate), 1)`` synthetic nodes, and the most
        frequent class absorbs the rounding remainder so the total is exactly
        ``int(n * reduction_rate)``.
        """
        from collections import Counter
        counter = Counter(data.labels_train)
        num_class_dict = {}
        n = len(data.labels_train)

        sorted_counter = sorted(counter.items(), key=lambda x:x[1])
        sum_ = 0
        labels_syn = []
        # class -> [start, end) row range of that class in the synthetic set
        self.syn_class_indices = {}
        for ix, (c, num) in enumerate(sorted_counter):
            if ix == len(sorted_counter) - 1:
                # Last (most frequent) class takes whatever budget remains.
                num_class_dict[c] = int(n * self.args.reduction_rate) - sum_
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]
            else:
                num_class_dict[c] = max(int(num * self.args.reduction_rate), 1)
                sum_ += num_class_dict[c]
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]

        self.num_class_dict = num_class_dict
        return labels_syn


    def test_gat(self, nlayers, model_type, verbose=False):
        """Train a GAT on the condensed graph and evaluate on the real graph.

        Returns a list ``[test_acc, train_acc]``.
        NOTE(review): this mutates ``self.args.epsilon`` as a side effect.
        """
        res = []
        args = self.args

        if args.dataset in ['cora', 'citeseer']:
            args.epsilon = 0.5 # Make the graph sparser as GAT does not work well on dense graph
        else:
            args.epsilon = 0.01

        print('======= testing %s' % model_type)
        data, device = self.data, self.device


        feat_syn, adj_syn, labels_syn = self.get_syn_data(model_type)
        # with_bn = True if self.args.dataset in ['ogbn-arxiv'] else False
        with_bn = False
        # NOTE(review): `model` is only assigned when model_type == 'GAT';
        # any other value raises NameError below — confirm this is intended.
        if model_type == 'GAT':
            model = GAT(nfeat=feat_syn.shape[1], nhid=16, heads=16, dropout=0.0,
                        weight_decay=0e-4, nlayers=self.args.nlayers, lr=0.001,
                        nclass=data.nclass, device=device, dataset=self.args.dataset).to(device)


        # Inductive datasets have no full-graph validation indices.
        noval = True if args.dataset in ['reddit', 'flickr'] else False
        model.fit(feat_syn, adj_syn, labels_syn, np.arange(len(feat_syn)), noval=noval, data=data,
                  train_iters=10000 if noval else 3000, normalize=True, verbose=verbose)

        model.eval()
        labels_test = torch.LongTensor(data.labels_test).cuda()

        if args.dataset in ['reddit', 'flickr']:
            # Inductive evaluation on the held-out test graph.
            output = model.predict(data.feat_test, data.adj_test)
            loss_test = F.nll_loss(output, labels_test)
            acc_test = utils.accuracy(output, labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        else:
            # Full graph
            output = model.predict(data.feat_full, data.adj_full)
            loss_test = F.nll_loss(output[data.idx_test], labels_test)
            acc_test = utils.accuracy(output[data.idx_test], labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        labels_train = torch.LongTensor(data.labels_train).cuda()
        output = model.predict(data.feat_train, data.adj_train)
        loss_train = F.nll_loss(output, labels_train)
        acc_train = utils.accuracy(output, labels_train)
        if verbose:
            print("Train set results:",
                  "loss= {:.4f}".format(loss_train.item()),
                  "accuracy= {:.4f}".format(acc_train.item()))
        res.append(acc_train.item())
        return res

    def get_syn_data(self, model_type=None):
        """Load the saved condensed features/adjacency for the current run.

        Reads ``saved_ours/{adj,feat}_{dataset}_{rate}_{seed}.pt``, zeroes the
        adjacency for MLP evaluation, and sparsifies by dropping entries below
        ``args.epsilon``.
        """
        data, device = self.data, self.device
        feat_syn, adj_param, labels_syn = self.feat_syn.detach(), \
                self.adj_param.detach(), self.labels_syn

        args = self.args
        adj_syn = torch.load(f'saved_ours/adj_{args.dataset}_{args.reduction_rate}_{args.seed}.pt', map_location='cuda')
        feat_syn = torch.load(f'saved_ours/feat_{args.dataset}_{args.reduction_rate}_{args.seed}.pt', map_location='cuda')

        if model_type == 'MLP':
            adj_syn = adj_syn.to(self.device)
            # MLP ignores structure: replace the adjacency with all zeros.
            adj_syn = adj_syn - adj_syn
        else:
            adj_syn = adj_syn.to(self.device)

        print('Sum:', adj_syn.sum(), adj_syn.sum()/(adj_syn.shape[0]**2))
        print('Sparsity:', adj_syn.nonzero().shape[0]/(adj_syn.shape[0]**2))

        if self.args.epsilon > 0:
            # Truncate small edge weights to sparsify the learned graph.
            adj_syn[adj_syn < self.args.epsilon] = 0
            print('Sparsity after truncating:', adj_syn.nonzero().shape[0]/(adj_syn.shape[0]**2))
        feat_syn = feat_syn.to(self.device)

        # edge_index = adj_syn.nonzero().T
        # adj_syn = torch.sparse.FloatTensor(edge_index, adj_syn[edge_index[0], edge_index[1]], adj_syn.size())

        return feat_syn, adj_syn, labels_syn


    def test(self, nlayers, model_type, verbose=True):
        """Train ``model_type`` on the condensed graph, evaluate on the real one.

        Returns a list ``[test_acc, train_acc]``.
        NOTE(review): `fit_with_val` is called with verbose=True regardless of
        the `verbose` parameter — confirm whether that is intended.
        """
        res = []

        args = self.args
        data, device = self.data, self.device

        feat_syn, adj_syn, labels_syn = self.get_syn_data(model_type)

        print('======= testing %s' % model_type)
        if model_type == 'MLP':
            # MLP is a GCN evaluated on an identity/zero adjacency.
            model_class = GCN
        else:
            model_class = eval(model_type)
        weight_decay = 5e-4
        dropout = 0.5 if args.dataset in ['reddit'] else 0

        model = model_class(nfeat=feat_syn.shape[1], nhid=args.hidden, dropout=dropout,
                            weight_decay=weight_decay, nlayers=nlayers,
                            nclass=data.nclass, device=device).to(device)

        # with_bn = True if self.args.dataset in ['ogbn-arxiv'] else False
        if args.dataset in ['ogbn-arxiv', 'arxiv']:
            # Rebuild without dropout for arxiv (overrides the model above).
            model = model_class(nfeat=feat_syn.shape[1], nhid=args.hidden, dropout=0.,
                                weight_decay=weight_decay, nlayers=nlayers, with_bn=False,
                                nclass=data.nclass, device=device).to(device)

        noval = True if args.dataset in ['reddit', 'flickr'] else False
        model.fit_with_val(feat_syn, adj_syn, labels_syn, data,
                           train_iters=600, normalize=True, verbose=True, noval=noval)

        model.eval()
        labels_test = torch.LongTensor(data.labels_test).cuda()

        if model_type == 'MLP':
            # Feed an identity adjacency so propagation is a no-op.
            output = model.predict_unnorm(data.feat_test, sp.eye(len(data.feat_test)))
        else:
            output = model.predict(data.feat_test, data.adj_test)

        if args.dataset in ['reddit', 'flickr']:
            loss_test = F.nll_loss(output, labels_test)
            acc_test = utils.accuracy(output, labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        # if not args.dataset in ['reddit', 'flickr']:
        else:
            # Full graph
            output = model.predict(data.feat_full, data.adj_full)
            loss_test = F.nll_loss(output[data.idx_test], labels_test)
            acc_test = utils.accuracy(output[data.idx_test], labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test full set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        labels_train = torch.LongTensor(data.labels_train).cuda()
        output = model.predict(data.feat_train, data.adj_train)
        loss_train = F.nll_loss(output, labels_train)
        acc_train = utils.accuracy(output, labels_train)
        if verbose:
            print("Train set results:",
                  "loss= {:.4f}".format(loss_train.item()),
                  "accuracy= {:.4f}".format(acc_train.item()))
        res.append(acc_train.item())
        return res

    def train(self, verbose=True):
        """Run ``nruns`` evaluations per architecture and print mean/std accuracy."""
        args = self.args
        data = self.data

        final_res = {}
        runs = self.args.nruns

        for model_type in ['GCN', 'GraphSage', 'SGC1', 'MLP', 'APPNP1', 'Cheby']:
            res = []
            nlayer = 2
            for i in range(runs):
                res.append(self.test(nlayer, verbose=False, model_type=model_type))
            res = np.array(res)
            print('Test/Train Mean Accuracy:',
                    repr([res.mean(0), res.std(0)]))
            final_res[model_type] = [res.mean(0), res.std(0)]


        # GAT needs a separate path (different fit signature and epsilon).
        print('=== testing GAT')
        res = []
        nlayer = 2
        for i in range(runs):
            res.append(self.test_gat(verbose=True, nlayers=nlayer, model_type='GAT'))
        res = np.array(res)
        print('Layer:', nlayer)
        print('Test/Full Test/Train Mean Accuracy:',
                repr([res.mean(0), res.std(0)]))
        final_res['GAT'] = [res.mean(0), res.std(0)]

        print('Final result:', final_res)
GCond/train_coreset.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a GCN on coreset-selected nodes (transductive source graph).

Pipeline: train a GCN on the full graph, use its predictions as node
embeddings for coreset selection (kcenter / herding / random / lrmc),
then repeatedly retrain a GCN on the induced coreset subgraph and report
mean/std test accuracy over several runs.
"""
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
import sys
from deeprobust.graph.utils import *
import torch.nn.functional as F
from configs import load_config
from utils import *
from utils_graphsaint import DataGraphSAINT
from models.gcn import GCN
from coreset import KCenter, Herding, Random, LRMC
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--hidden', type=int, default=256)
# NOTE(review): argparse's type=bool treats ANY non-empty string (even
# "False") as True; kept as-is so the CLI contract does not change.
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
# Fixed copy-pasted help text (was 'Random seed.').
parser.add_argument('--nlayers', type=int, default=2, help='Number of GCN layers.')
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--inductive', type=int, default=1)
parser.add_argument('--save', type=int, default=0)
parser.add_argument('--method', type=str, choices=['kcenter', 'herding', 'random', 'lrmc'])
parser.add_argument('--lrmc_seeds_path', type=str, default=None,
                    help='Path to a JSON file containing L‑RMC seed nodes. Required when method=lrmc.')
parser.add_argument('--reduction_rate', type=float, required=True)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)
args = load_config(args)
print(args)

# random seed setting: make each run reproducible for a given --seed
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)


# GraphSAINT-style datasets ship their own loader; everything else goes
# through the Planetoid/OGB path in utils.get_dataset.
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

features = data_full.features
adj = data_full.adj
labels = data_full.labels
idx_train = data_full.idx_train
idx_val = data_full.idx_val
idx_test = data_full.idx_test

# Setup GCN Model trained on the full graph; its predictions serve as
# embeddings for the coreset selection below.
device = 'cuda'
# Honor the --hidden flag (was hard-coded to 256, making the flag a no-op;
# the default value is unchanged).
model = GCN(nfeat=features.shape[1], nhid=args.hidden, nclass=labels.max()+1,
            device=device, weight_decay=args.weight_decay)

model = model.to(device)
model.fit(features, adj, labels, idx_train, idx_val, train_iters=600, verbose=False)

model.eval()
# You can use the inner function of model to test
model.test(idx_test)

embeds = model.predict().detach()

if args.method == 'kcenter':
    agent = KCenter(data, args, device='cuda')
elif args.method == 'herding':
    agent = Herding(data, args, device='cuda')
elif args.method == 'random':
    agent = Random(data, args, device='cuda')
elif args.method == 'lrmc':
    if args.lrmc_seeds_path is None:
        raise ValueError("--lrmc_seeds_path must be specified when method='lrmc'")
    agent = LRMC(data, args, device='cuda')

idx_selected = agent.select(embeds)


# Induced subgraph on the selected coreset nodes.
feat_train = features[idx_selected]
adj_train = adj[np.ix_(idx_selected, idx_selected)]

labels_train = labels[idx_selected]

if args.save:
    np.save(f'saved/idx_{args.dataset}_{args.reduction_rate}_{args.method}_{args.seed}.npy', idx_selected)


# Retrain from scratch several times to average out initialization noise.
res = []
runs = 10
for _ in tqdm(range(runs)):
    model.initialize()
    model.fit_with_val(feat_train, adj_train, labels_train, data,
                       train_iters=600, normalize=True, verbose=False)

    model.eval()
    labels_test = torch.LongTensor(data.labels_test).cuda()

    # Full graph
    output = model.predict(data.feat_full, data.adj_full)
    loss_test = F.nll_loss(output[data.idx_test], labels_test)
    acc_test = accuracy(output[data.idx_test], labels_test)
    res.append(acc_test.item())

res = np.array(res)
print('Mean accuracy:', repr([res.mean(), res.std()]))
|
| 117 |
+
|
GCond/train_coreset_induct.py
ADDED
|
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Train a GCN on coreset-selected nodes (inductive setting).

Trains a GCN on the inductive training subgraph, uses its predictions as
embeddings for coreset selection, then retrains on the selected coreset
and reports mean/std test accuracy over several runs.
"""
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
import sys
import deeprobust.graph.utils as utils
import torch.nn.functional as F
from configs import load_config
from utils import *
from utils_graphsaint import DataGraphSAINT
from models.gcn import GCN
from coreset import KCenter, Herding, Random, LRMC
from tqdm import tqdm

parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--hidden', type=int, default=256)
# NOTE(review): argparse's type=bool treats ANY non-empty string (even
# "False") as True; kept as-is so the CLI contract does not change.
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
# Fixed copy-pasted help text (was 'Random seed.').
parser.add_argument('--nlayers', type=int, default=2, help='Number of GCN layers.')
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--inductive', type=int, default=1)
parser.add_argument('--mlp', type=int, default=0)
parser.add_argument('--method', type=str, choices=['kcenter', 'herding', 'random', 'lrmc'])
parser.add_argument('--lrmc_seeds_path', type=str, default=None,
                    help='Path to a JSON file containing L‑RMC seed nodes. Required when method=lrmc.')
parser.add_argument('--reduction_rate', type=float, required=True)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)
args = load_config(args)
print(args)

# random seed setting: make each run reproducible for a given --seed
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# GraphSAINT-style datasets ship their own loader; everything else goes
# through the Planetoid/OGB path in utils.get_dataset.
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

feat_train = data.feat_train
adj_train = data.adj_train
labels_train = data.labels_train


# Setup GCN Model trained on the inductive training subgraph.
device = 'cuda'
# Honor the --hidden flag (was hard-coded to 256, making the flag a no-op;
# the default value is unchanged).
model = GCN(nfeat=feat_train.shape[1], nhid=args.hidden, nclass=labels_train.max()+1,
            device=device, weight_decay=args.weight_decay)

model = model.to(device)

model.fit_with_val(feat_train, adj_train, labels_train, data,
                   train_iters=600, normalize=True, verbose=False)

model.eval()
labels_test = torch.LongTensor(data.labels_test).cuda()
feat_test, adj_test = data.feat_test, data.adj_test

# Predictions on the training subgraph drive the coreset selection below.
embeds = model.predict().detach()

output = model.predict(feat_test, adj_test)
loss_test = F.nll_loss(output, labels_test)
acc_test = utils.accuracy(output, labels_test)
print("Test set results:",
      "loss= {:.4f}".format(loss_test.item()),
      "accuracy= {:.4f}".format(acc_test.item()))


if args.method == 'kcenter':
    agent = KCenter(data, args, device='cuda')
elif args.method == 'herding':
    agent = Herding(data, args, device='cuda')
elif args.method == 'random':
    agent = Random(data, args, device='cuda')
elif args.method == 'lrmc':
    if args.lrmc_seeds_path is None:
        raise ValueError("--lrmc_seeds_path must be specified when method='lrmc'")
    agent = LRMC(data, args, device='cuda')

idx_selected = agent.select(embeds, inductive=True)

# Induced subgraph on the selected coreset nodes.
feat_train = feat_train[idx_selected]
adj_train = adj_train[np.ix_(idx_selected, idx_selected)]

labels_train = labels_train[idx_selected]

# Retrain from scratch several times to average out initialization noise.
res = []
print('shape of feat_train:', feat_train.shape)
runs = 10
for _ in tqdm(range(runs)):
    model.initialize()
    model.fit_with_val(feat_train, adj_train, labels_train, data,
                       train_iters=600, normalize=True, verbose=False, noval=True)

    model.eval()
    labels_test = torch.LongTensor(data.labels_test).cuda()

    output = model.predict(feat_test, adj_test)
    loss_test = F.nll_loss(output, labels_test)
    acc_test = utils.accuracy(output, labels_test)
    res.append(acc_test.item())
res = np.array(res)
print('Mean accuracy:', repr([res.mean(), res.std()]))
|
| 119 |
+
|
GCond/train_gcond_induct.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entry point for GCond graph condensation in the inductive setting."""
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
from utils import *
import torch.nn.functional as F
from gcond_agent_induct import GCond
from utils_graphsaint import DataGraphSAINT


parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--dis_metric', type=str, default='ours')
parser.add_argument('--epochs', type=int, default=600)
parser.add_argument('--nlayers', type=int, default=3)
parser.add_argument('--hidden', type=int, default=256)
parser.add_argument('--lr_adj', type=float, default=0.01)
parser.add_argument('--lr_feat', type=float, default=0.01)
parser.add_argument('--lr_model', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--reduction_rate', type=float, default=0.01)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--alpha', type=float, default=0, help='regularization term.')
parser.add_argument('--debug', type=int, default=0)
parser.add_argument('--sgc', type=int, default=1)
parser.add_argument('--inner', type=int, default=0)
parser.add_argument('--outer', type=int, default=20)
parser.add_argument('--option', type=int, default=0)
parser.add_argument('--save', type=int, default=0)
parser.add_argument('--label_rate', type=float, default=1)
parser.add_argument('--one_step', type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)

# Seed every RNG the pipeline relies on so runs are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(args.seed)

print(args)

# GraphSAINT datasets come with their own inductive partition; the
# Planetoid/OGB datasets are converted to an inductive split here.
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset, label_rate=args.label_rate)
    data_full = data.data_full
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

# Run the condensation loop.
agent = GCond(data, args, device='cuda')
agent.train()
|
GCond/train_gcond_transduct.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Entry point for GCond graph condensation in the transductive setting."""
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
from utils import *
import torch.nn.functional as F
from gcond_agent_transduct import GCond
from utils_graphsaint import DataGraphSAINT

parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--dis_metric', type=str, default='ours')
parser.add_argument('--epochs', type=int, default=2000)
parser.add_argument('--nlayers', type=int, default=3)
parser.add_argument('--hidden', type=int, default=256)
parser.add_argument('--lr_adj', type=float, default=0.01)
parser.add_argument('--lr_feat', type=float, default=0.01)
parser.add_argument('--lr_model', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--reduction_rate', type=float, default=1)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--alpha', type=float, default=0, help='regularization term.')
parser.add_argument('--debug', type=int, default=0)
parser.add_argument('--sgc', type=int, default=1)
parser.add_argument('--inner', type=int, default=0)
parser.add_argument('--outer', type=int, default=20)
parser.add_argument('--save', type=int, default=0)
parser.add_argument('--one_step', type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)

# Seed every RNG the pipeline relies on so runs are reproducible.
for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed):
    seed_fn(args.seed)

print(args)

# GraphSAINT datasets come with their own loader; Planetoid/OGB datasets
# are loaded through utils.get_dataset and wrapped by Transd2Ind.
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

# Run the condensation loop.
agent = GCond(data, args, device='cuda')
agent.train()
|
GCond/utils.py
ADDED
|
@@ -0,0 +1,383 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os.path as osp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import scipy.sparse as sp
|
| 4 |
+
import torch
|
| 5 |
+
import torch_geometric.transforms as T
|
| 6 |
+
from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
|
| 7 |
+
from deeprobust.graph.data import Dataset
|
| 8 |
+
from deeprobust.graph.utils import get_train_val_test
|
| 9 |
+
from torch_geometric.utils import train_test_split_edges
|
| 10 |
+
from sklearn.model_selection import train_test_split
|
| 11 |
+
from sklearn import metrics
|
| 12 |
+
import numpy as np
|
| 13 |
+
import torch.nn.functional as F
|
| 14 |
+
from sklearn.preprocessing import StandardScaler
|
| 15 |
+
from deeprobust.graph.utils import *
|
| 16 |
+
from torch_geometric.data import NeighborSampler
|
| 17 |
+
from torch_geometric.utils import add_remaining_self_loops, to_undirected
|
| 18 |
+
from torch_geometric.datasets import Planetoid
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def get_dataset(name, normalize_features=False, transform=None, if_dpr=True):
    """Load a benchmark graph dataset and return it in DeepRobust format.

    Supports the Planetoid trio ('cora', 'citeseer', 'pubmed') and
    'ogbn-arxiv'; anything else raises NotImplementedError.  Optionally
    attaches feature normalization and/or a user transform before the
    conversion.  (`if_dpr` is accepted for interface compatibility but
    unused — kept as in the original.)
    """
    path = osp.join(osp.dirname(osp.realpath(__file__)), 'data', name)
    if name in ['cora', 'citeseer', 'pubmed']:
        dataset = Planetoid(path, name)
    elif name in ['ogbn-arxiv']:
        dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    else:
        raise NotImplementedError

    # Attach normalization and/or the caller-supplied transform.
    if normalize_features:
        norm = T.NormalizeFeatures()
        dataset.transform = norm if transform is None else T.Compose([norm, transform])
    elif transform is not None:
        dataset.transform = transform

    dpr_data = Pyg2Dpr(dataset)
    if name in ['ogbn-arxiv']:
        # the features are different from the features provided by GraphSAINT;
        # normalize features following graphsaint: standardize with statistics
        # computed on the training nodes only.
        feat, idx_train = dpr_data.features, dpr_data.idx_train
        scaler = StandardScaler()
        scaler.fit(feat[idx_train])
        dpr_data.features = scaler.transform(feat)

    return dpr_data
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
class Pyg2Dpr(Dataset):
    """Convert a PyTorch-Geometric dataset into a DeepRobust-style Dataset.

    Builds a scipy CSR adjacency from edge_index, copies features/labels
    to numpy, and derives train/val/test indices from (in order of
    preference) boolean masks, OGB splits, or a fresh random split.
    """

    def __init__(self, pyg_data, **kwargs):
        # Only OGB datasets expose predefined splits this way; for others
        # the call fails and we fall back below.  `splits = None` replaces
        # the original reliance on the name staying undefined (NameError).
        splits = None
        try:
            splits = pyg_data.get_idx_split()
        except Exception:  # was a bare `except:`; don't swallow SystemExit/KeyboardInterrupt
            pass

        dataset_name = pyg_data.name
        pyg_data = pyg_data[0]
        n = pyg_data.num_nodes

        if dataset_name == 'ogbn-arxiv':  # symmetrization
            pyg_data.edge_index = to_undirected(pyg_data.edge_index, pyg_data.num_nodes)

        self.adj = sp.csr_matrix((np.ones(pyg_data.edge_index.shape[1]),
            (pyg_data.edge_index[0], pyg_data.edge_index[1])), shape=(n, n))

        self.features = pyg_data.x.numpy()
        self.labels = pyg_data.y.numpy()

        if len(self.labels.shape) == 2 and self.labels.shape[1] == 1:
            self.labels = self.labels.reshape(-1)  # ogb-arxiv needs to reshape

        if hasattr(pyg_data, 'train_mask'):
            # for fixed split (Planetoid-style boolean masks)
            self.idx_train = mask_to_index(pyg_data.train_mask, n)
            self.idx_val = mask_to_index(pyg_data.val_mask, n)
            self.idx_test = mask_to_index(pyg_data.test_mask, n)
            self.name = 'Pyg2Dpr'
        else:
            try:
                # for ogb: use the precomputed splits (splits may be None or
                # malformed, in which case we fall through)
                self.idx_train = splits['train']
                self.idx_val = splits['valid']
                self.idx_test = splits['test']
                self.name = 'Pyg2Dpr'
            except Exception:  # was a bare `except:`
                # for other datasets: create a stratified random split
                self.idx_train, self.idx_val, self.idx_test = get_train_val_test(
                    nnodes=n, val_size=0.1, test_size=0.8, stratify=self.labels)
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def mask_to_index(index, size):
    """Convert a boolean mask of length `size` into an array of node indices."""
    return np.arange(size)[index]
|
| 96 |
+
|
| 97 |
+
def index_to_mask(index, size):
    """Convert an index list/array into a boolean mask tensor of length `size`."""
    mask = torch.zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
class Transd2Ind:
    """Repackage a transductive dataset as an inductive split.

    Keeps the full graph (adj_full / feat_full / labels_full) for
    evaluation and builds induced train/val/test subgraphs.  Also offers
    per-class node retrieval and NeighborSampler-based mini-batching
    used by the condensation agents.
    """

    def __init__(self, dpr_data, keep_ratio):
        """
        Parameters
        ----------
        dpr_data : DeepRobust-style dataset (adj, features, labels, idx_*).
        keep_ratio : float
            Fraction of training nodes to keep (label-stratified).
            NOTE: when keep_ratio < 1 only the *_train members are reduced;
            self.idx_train still holds the full original training indices
            (original behavior, preserved for compatibility).
        """
        idx_train, idx_val, idx_test = dpr_data.idx_train, dpr_data.idx_val, dpr_data.idx_test
        adj, features, labels = dpr_data.adj, dpr_data.features, dpr_data.labels
        self.nclass = labels.max() + 1
        self.adj_full, self.feat_full, self.labels_full = adj, features, labels
        self.idx_train = np.array(idx_train)
        self.idx_val = np.array(idx_val)
        self.idx_test = np.array(idx_test)

        if keep_ratio < 1:
            # Subsample the training set while preserving label proportions.
            idx_train, _ = train_test_split(idx_train,
                                            random_state=None,
                                            train_size=keep_ratio,
                                            test_size=1 - keep_ratio,
                                            stratify=labels[idx_train])

        # Induced subgraphs per split.
        self.adj_train = adj[np.ix_(idx_train, idx_train)]
        self.adj_val = adj[np.ix_(idx_val, idx_val)]
        self.adj_test = adj[np.ix_(idx_test, idx_test)]
        print('size of adj_train:', self.adj_train.shape)
        print('#edges in adj_train:', self.adj_train.sum())

        self.labels_train = labels[idx_train]
        self.labels_val = labels[idx_val]
        self.labels_test = labels[idx_test]

        self.feat_train = features[idx_train]
        self.feat_val = features[idx_val]
        self.feat_test = features[idx_test]

        # Lazily-built caches for the retrieval helpers below.
        self.class_dict = None    # per-class boolean masks over training nodes
        self.samplers = None      # NeighborSampler(s) per class (or per depth)
        self.class_dict2 = None   # per-class node index arrays

    def retrieve_class(self, c, num=256):
        """Return up to `num` randomly permuted training-local indices of class c."""
        if self.class_dict is None:
            self.class_dict = {}
            for i in range(self.nclass):
                self.class_dict['class_%s' % i] = (self.labels_train == i)
        idx = np.arange(len(self.labels_train))
        idx = idx[self.class_dict['class_%s' % c]]
        return np.random.permutation(idx)[:num]

    def _init_class_dict2(self, transductive):
        # Per-class indices: full-graph ids when transductive, training-local
        # positions otherwise.
        self.class_dict2 = {}
        for i in range(self.nclass):
            if transductive:
                idx = self.idx_train[self.labels_train == i]
            else:
                idx = np.arange(len(self.labels_train))[self.labels_train == i]
            self.class_dict2[i] = idx

    def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None):
        """Sample a NeighborSampler mini-batch of up to `num` class-c nodes.

        Raises
        ------
        ValueError
            If args.nlayers is outside 1..5 (previously this crashed with a
            NameError on the undefined `sizes` variable).
        """
        if self.class_dict2 is None:
            self._init_class_dict2(transductive)

        # Neighbor fan-out per GNN layer, keyed by depth.
        layer_sizes = {1: [15], 2: [10, 5], 3: [15, 10, 5],
                       4: [15, 10, 5, 5], 5: [15, 10, 5, 5, 5]}
        if args.nlayers not in layer_sizes:
            raise ValueError('unsupported number of layers: %s' % args.nlayers)
        sizes = layer_sizes[args.nlayers]

        if self.samplers is None:
            self.samplers = []
            for i in range(self.nclass):
                node_idx = torch.LongTensor(self.class_dict2[i])
                self.samplers.append(NeighborSampler(adj,
                                     node_idx=node_idx,
                                     sizes=sizes, batch_size=num,
                                     num_workers=12, return_e_id=False,
                                     num_nodes=adj.size(0),
                                     shuffle=True))
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[c].sample(batch)
        return out

    def retrieve_class_multi_sampler(self, c, adj, transductive, num=256, args=None):
        """Like retrieve_class_sampler, but keeps one sampler set per depth
        (index 0: 1-layer fan-out, index 1: 2-layer fan-out); the depth used
        is args.nlayers - 1, matching the original indexing."""
        if self.class_dict2 is None:
            self._init_class_dict2(transductive)

        if self.samplers is None:
            self.samplers = []
            for l in range(2):
                layer_samplers = []
                sizes = [15] if l == 0 else [10, 5]
                for i in range(self.nclass):
                    node_idx = torch.LongTensor(self.class_dict2[i])
                    layer_samplers.append(NeighborSampler(adj,
                                          node_idx=node_idx,
                                          sizes=sizes, batch_size=num,
                                          num_workers=12, return_e_id=False,
                                          num_nodes=adj.size(0),
                                          shuffle=True))
                self.samplers.append(layer_samplers)
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[args.nlayers - 1][c].sample(batch)
        return out
|
| 215 |
+
|
| 216 |
+
|
| 217 |
+
|
| 218 |
+
def match_loss(gw_syn, gw_real, args, device):
    """Gradient-matching distance between synthetic and real gradient lists.

    Parameters
    ----------
    gw_syn, gw_real : sequences of tensors (one per model parameter).
    args : namespace with `dis_metric` in {'ours', 'mse', 'cos'}.
    device : torch device for the accumulator.

    Returns
    -------
    torch.Tensor scalar distance.

    Raises
    ------
    ValueError
        For an unknown `args.dis_metric` (previously this called exit(),
        killing the whole process from library code).
    """
    dis = torch.tensor(0.0).to(device)

    if args.dis_metric == 'ours':
        # Layer-wise cosine distance, accumulated per parameter tensor.
        for gwr, gws in zip(gw_real, gw_syn):
            dis += distance_wb(gwr, gws)

    elif args.dis_metric == 'mse':
        # Flatten everything into one vector and take squared error.
        gw_real_vec = torch.cat([g.reshape(-1) for g in gw_real], dim=0)
        gw_syn_vec = torch.cat([g.reshape(-1) for g in gw_syn], dim=0)
        dis = torch.sum((gw_syn_vec - gw_real_vec)**2)

    elif args.dis_metric == 'cos':
        # Global cosine distance over the flattened gradients.
        gw_real_vec = torch.cat([g.reshape(-1) for g in gw_real], dim=0)
        gw_syn_vec = torch.cat([g.reshape(-1) for g in gw_syn], dim=0)
        dis = 1 - torch.sum(gw_real_vec * gw_syn_vec, dim=-1) / (
            torch.norm(gw_real_vec, dim=-1) * torch.norm(gw_syn_vec, dim=-1) + 0.000001)

    else:
        raise ValueError('DC error: unknown distance function')

    return dis
|
| 252 |
+
|
| 253 |
+
def distance_wb(gwr, gws):
    """Row-wise cosine distance between two gradient tensors.

    Tensors are first reshaped to 2-D (conv 4-D and 3-D tensors are
    flattened per output channel; 2-D weights are transposed so rows are
    input-side vectors).  1-D tensors (biases/norm params) are skipped
    and contribute 0, as in the original implementation.
    """
    shape = gwr.shape
    ndim = len(shape)

    if ndim == 2:
        # linear, out*in: compare along the transposed rows
        gwr, gws = gwr.T, gws.T

    if ndim == 4:
        # conv, out*in*h*w: flatten per output channel
        gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3])
        gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3])
    elif ndim == 3:
        # layernorm, C*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2])
        gws = gws.reshape(shape[0], shape[1] * shape[2])
    elif ndim == 1:
        # batchnorm/instancenorm C, or bias: intentionally ignored
        return 0

    cos = torch.sum(gwr * gws, dim=-1) / (
        torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 0.000001)
    return torch.sum(1 - cos)
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
def calc_f1(y_true, y_pred, is_sigmoid):
    """Return (micro-F1, macro-F1).

    When `is_sigmoid` is True, `y_pred` holds per-class probabilities and is
    thresholded at 0.5 **in place**; otherwise predictions are taken as the
    argmax over classes.
    """
    if is_sigmoid:
        y_pred[y_pred > 0.5] = 1
        y_pred[y_pred <= 0.5] = 0
    else:
        y_pred = np.argmax(y_pred, axis=1)
    micro = metrics.f1_score(y_true, y_pred, average="micro")
    macro = metrics.f1_score(y_true, y_pred, average="macro")
    return micro, macro
|
| 287 |
+
|
| 288 |
+
def evaluate(output, labels, args):
    """Print test-set metrics.

    GraphSAINT-style datasets are scored with micro/macro F1 (sigmoid
    thresholding for multi-label 2-D targets, argmax otherwise); all other
    datasets are scored with NLL loss and accuracy.
    """
    data_graphsaint = ['yelp', 'ppi', 'ppi-large', 'flickr', 'reddit', 'amazon']
    if args.dataset in data_graphsaint:
        labels = labels.cpu().numpy()
        output = output.cpu().numpy()
        multilabel = len(labels.shape) > 1
        micro, macro = calc_f1(labels, output, is_sigmoid=multilabel)
        print("Test set results:", "F1-micro= {:.4f}".format(micro),
              "F1-macro= {:.4f}".format(macro))
    else:
        loss_test = F.nll_loss(output, labels)
        acc_test = accuracy(output, labels)
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
    return
|
| 306 |
+
|
| 307 |
+
|
| 308 |
+
from torchvision import datasets, transforms
|
| 309 |
+
def get_mnist(data_path):
    """Load MNIST and wrap the training split as an edge-free PyG dataset.

    Each training image becomes one node whose feature vector is the
    flattened, mean/std-normalized image; the adjacency is an explicit
    all-zero sparse matrix (no edges). All node indices serve as the
    train/val/test splits simultaneously.

    Removed the unused locals of the original (``channel``, ``im_size``,
    ``num_classes``, ``class_names``); the test split is still downloaded
    so the data directory stays complete.
    """
    mean = [0.1307]
    std = [0.3081]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform)  # no augmentation
    # Downloaded for completeness even though only the train split is used here.
    datasets.MNIST(data_path, train=False, download=True, transform=transform)

    labels = []
    feat = []
    for x, y in dst_train:
        feat.append(x.view(1, -1))
        labels.append(y)
    feat = torch.cat(feat, axis=0).numpy()

    from utils_graphsaint import GraphData
    adj = sp.eye(len(feat))
    idx = np.arange(len(feat))
    # adj - adj is a sparse all-zero matrix: the "graph" has no edges.
    dpr_data = GraphData(adj - adj, feat, labels, idx, idx, idx)
    from deeprobust.graph.data import Dpr2Pyg
    return Dpr2Pyg(dpr_data)
|
| 332 |
+
|
| 333 |
+
def regularization(adj, x, eig_real=None):
    """Structural regularizer; currently only the feature-smoothing term fLf.

    ``eig_real`` is accepted for interface compatibility but unused.
    """
    # An L1 penalty on adj was tried and disabled:
    # loss += torch.norm(adj, p=1)
    return feature_smoothing(adj, x)
|
| 339 |
+
|
| 340 |
+
def maxdegree(adj):
    """Hinge penalty when the largest weighted degree exceeds half the node count."""
    num_nodes = adj.shape[0]
    peak_degree = adj.sum(1).max()
    return F.relu(peak_degree / num_nodes - 0.5)
|
| 343 |
+
|
| 344 |
+
def sparsity2(adj):
    """Sparsity penalty: Frobenius norm of adj scaled by node count.

    A log-degree term is also computed but is currently weighted by zero,
    so only the Frobenius term contributes.
    """
    num_nodes = adj.shape[0]
    degree_term = -torch.log(adj.sum(1)).sum() / num_nodes
    frobenius_term = torch.norm(adj) / num_nodes
    return 0 * degree_term + frobenius_term
|
| 349 |
+
|
| 350 |
+
def sparsity(adj):
    """Hinge penalty once total edge weight exceeds 1% of n^2 entries."""
    num_nodes = adj.shape[0]
    budget = num_nodes * num_nodes * 0.01
    return F.relu(adj.sum() - budget)
    # normalized variant (unused): F.relu(adj.sum() - budget) / num_nodes**2
|
| 355 |
+
|
| 356 |
+
def feature_smoothing(adj, X):
|
| 357 |
+
adj = (adj.t() + adj)/2
|
| 358 |
+
rowsum = adj.sum(1)
|
| 359 |
+
r_inv = rowsum.flatten()
|
| 360 |
+
D = torch.diag(r_inv)
|
| 361 |
+
L = D - adj
|
| 362 |
+
|
| 363 |
+
r_inv = r_inv + 1e-8
|
| 364 |
+
r_inv = r_inv.pow(-1/2).flatten()
|
| 365 |
+
r_inv[torch.isinf(r_inv)] = 0.
|
| 366 |
+
r_mat_inv = torch.diag(r_inv)
|
| 367 |
+
# L = r_mat_inv @ L
|
| 368 |
+
L = r_mat_inv @ L @ r_mat_inv
|
| 369 |
+
|
| 370 |
+
XLXT = torch.matmul(torch.matmul(X.t(), L), X)
|
| 371 |
+
loss_smooth_feat = torch.trace(XLXT)
|
| 372 |
+
# loss_smooth_feat = loss_smooth_feat / (adj.shape[0]**2)
|
| 373 |
+
return loss_smooth_feat
|
| 374 |
+
|
| 375 |
+
def row_normalize_tensor(mx):
    """Scale every row of ``mx`` so it sums to one.

    Rows that sum to zero produce inf entries: the guard was deliberately
    left disabled in the original (# r_inv[torch.isinf(r_inv)] = 0.).
    """
    inv_rowsum = mx.sum(1).pow(-1).flatten()
    return torch.diag(inv_rowsum) @ mx
|
| 382 |
+
|
| 383 |
+
|
GCond/utils_graphsaint.py
ADDED
|
@@ -0,0 +1,220 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import scipy.sparse as sp
|
| 2 |
+
import numpy as np
|
| 3 |
+
import sys
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from sklearn.preprocessing import StandardScaler
|
| 7 |
+
from torch_geometric.data import InMemoryDataset, Data
|
| 8 |
+
import torch
|
| 9 |
+
from itertools import repeat
|
| 10 |
+
from torch_geometric.data import NeighborSampler
|
| 11 |
+
|
| 12 |
+
class DataGraphSAINT:
    '''Datasets used in the GraphSAINT paper (plus ogbn-arxiv in the same layout).

    Loads ``adj_full.npz``, ``feats.npy``, ``role.json`` and
    ``class_map.json`` from ``data/<dataset>/``, standardizes features with
    train-split statistics only, and exposes per-class sampling helpers.

    Fixes over the original:
      * ``np.int`` (removed in NumPy >= 1.24) replaced by builtin ``int``.
      * ``sizes`` now has a fallback so it is always bound in
        ``retrieve_class_sampler``.
      * ``self.samplers`` stays index-aligned with class ids even when a
        class has no training nodes (the original ``continue`` shifted
        every subsequent index).
    '''

    def __init__(self, dataset, **kwargs):
        dataset_str = 'data/' + dataset + '/'
        adj_full = sp.load_npz(dataset_str + 'adj_full.npz')
        self.nnodes = adj_full.shape[0]
        if dataset == 'ogbn-arxiv':
            # ogbn-arxiv is stored as a directed graph: symmetrize and binarize.
            adj_full = adj_full + adj_full.T
            adj_full[adj_full > 1] = 1

        role = json.load(open(dataset_str + 'role.json', 'r'))
        idx_train = role['tr']
        idx_test = role['te']
        idx_val = role['va']

        if 'label_rate' in kwargs:
            # Optionally keep only a prefix fraction of the training nodes.
            label_rate = kwargs['label_rate']
            if label_rate < 1:
                idx_train = idx_train[:int(label_rate * len(idx_train))]

        self.adj_train = adj_full[np.ix_(idx_train, idx_train)]
        self.adj_val = adj_full[np.ix_(idx_val, idx_val)]
        self.adj_test = adj_full[np.ix_(idx_test, idx_test)]

        feat = np.load(dataset_str + 'feats.npy')
        # Standardize features using statistics of the training split only
        # (avoids train/test leakage).
        feat_train = feat[idx_train]
        scaler = StandardScaler()
        scaler.fit(feat_train)
        feat = scaler.transform(feat)

        self.feat_train = feat[idx_train]
        self.feat_val = feat[idx_val]
        self.feat_test = feat[idx_test]

        class_map = json.load(open(dataset_str + 'class_map.json', 'r'))
        labels = self.process_labels(class_map)

        self.labels_train = labels[idx_train]
        self.labels_val = labels[idx_val]
        self.labels_test = labels[idx_test]

        self.data_full = GraphData(adj_full, feat, labels, idx_train, idx_val, idx_test)
        self.class_dict = None
        self.class_dict2 = None

        self.adj_full = adj_full
        self.feat_full = feat
        self.labels_full = labels
        self.idx_train = np.array(idx_train)
        self.idx_val = np.array(idx_val)
        self.idx_test = np.array(idx_test)
        self.samplers = None

    def process_labels(self, class_map):
        """Build the node-label array from the class_map JSON.

        Multi-label datasets store a list per node (a multi-hot matrix is
        returned); single-label datasets store an int per node (a 1-D int
        array shifted to start at 0 is returned). Sets ``self.nclass``.
        """
        num_vertices = self.nnodes
        if isinstance(list(class_map.values())[0], list):
            num_classes = len(list(class_map.values())[0])
            self.nclass = num_classes
            class_arr = np.zeros((num_vertices, num_classes))
            for k, v in class_map.items():
                class_arr[int(k)] = v
        else:
            # np.int was removed in NumPy 1.24; the builtin int is equivalent.
            class_arr = np.zeros(num_vertices, dtype=int)
            for k, v in class_map.items():
                class_arr[int(k)] = v
            class_arr = class_arr - class_arr.min()
            self.nclass = int(class_arr.max()) + 1
        return class_arr

    def retrieve_class(self, c, num=256):
        """Return up to ``num`` random training positions whose label is ``c``."""
        if self.class_dict is None:
            # Cache one boolean membership mask per class, keyed 'class_<i>'.
            self.class_dict = {}
            for i in range(self.nclass):
                self.class_dict['class_%s' % i] = (self.labels_train == i)
        idx = np.arange(len(self.labels_train))
        idx = idx[self.class_dict['class_%s' % c]]
        return np.random.permutation(idx)[:num]

    def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None):
        """Sample a neighborhood batch for class ``c`` using a cached
        ``NeighborSampler`` per class.

        ``transductive`` selects whether class membership is indexed by
        global node ids (True) or by positions within the training split.
        """
        # Per-layer fan-out; default keeps `sizes` bound for unexpected
        # depths/options (the original raised NameError for nlayers not in {1,2}).
        sizes = [10, 5]
        if args.nlayers == 1:
            sizes = [30]
        if args.nlayers == 2:
            if args.dataset in ['reddit', 'flickr']:
                if args.option == 0:
                    sizes = [15, 8]
                if args.option == 1:
                    sizes = [20, 10]
                if args.option == 2:
                    sizes = [25, 10]
            else:
                sizes = [10, 5]

        if self.class_dict2 is None:
            print(sizes)
            self.class_dict2 = {}
            for i in range(self.nclass):
                if transductive:
                    idx_train = np.array(self.idx_train)
                    idx = idx_train[self.labels_train == i]
                else:
                    idx = np.arange(len(self.labels_train))[self.labels_train == i]
                self.class_dict2[i] = idx

        if self.samplers is None:
            self.samplers = []
            for i in range(self.nclass):
                node_idx = torch.LongTensor(self.class_dict2[i])
                if len(node_idx) == 0:
                    # Placeholder keeps self.samplers[c] aligned with class ids.
                    self.samplers.append(None)
                    continue

                self.samplers.append(NeighborSampler(adj,
                                                     node_idx=node_idx,
                                                     sizes=sizes, batch_size=num,
                                                     num_workers=8, return_e_id=False,
                                                     num_nodes=adj.size(0),
                                                     shuffle=True))
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[c].sample(batch)
        return out
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class GraphData:
    """Lightweight record bundling a graph with its train/val/test splits."""

    def __init__(self, adj, features, labels, idx_train, idx_val, idx_test):
        self.adj = adj              # graph adjacency (sparse or dense)
        self.features = features    # node feature matrix
        self.labels = labels        # node labels
        self.idx_train = idx_train  # training node indices
        self.idx_val = idx_val      # validation node indices
        self.idx_test = idx_test    # test node indices
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class Data2Pyg:
    """Convert a DPR-style dataset bundle (train/val/test) into PyG graphs
    moved onto the target device."""

    def __init__(self, data, device='cuda', transform=None, **kwargs):
        def _to_pyg(split):
            # Convert one split and move it to the target device.
            return Dpr2Pyg(split, transform=transform)[0].to(device)

        self.data_train = _to_pyg(data.data_train)
        self.data_val = _to_pyg(data.data_val)
        self.data_test = _to_pyg(data.data_test)
        self.nclass = data.nclass
        self.nfeat = data.nfeat
        self.class_dict = None

    def retrieve_class(self, c, num=256):
        """Return up to ``num`` random training-node indices with label ``c``."""
        if self.class_dict is None:
            # Cache one boolean membership mask per class, keyed 'class_<i>'.
            self.class_dict = {
                'class_%s' % i: (self.data_train.y == i).cpu().numpy()
                for i in range(self.nclass)
            }
        candidates = np.arange(len(self.data_train.y))
        candidates = candidates[self.class_dict['class_%s' % c]]
        return np.random.permutation(candidates)[:num]
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class Dpr2Pyg(InMemoryDataset):
    """Adapter exposing a DeepRobust-style graph (adj/features/labels) as a
    single-graph PyG ``InMemoryDataset``.

    NOTE(review): ``get`` re-implements slicing against ``self.data.keys`` and
    ``__cat_dim__``, which tracks the internal layout of older
    torch_geometric versions — confirm against the installed release.
    """

    def __init__(self, dpr_data, transform=None, **kwargs):
        root = 'data/' # dummy root; does not mean anything
        self.dpr_data = dpr_data
        super(Dpr2Pyg, self).__init__(root, transform)
        # Build the single Data object eagerly and collate it so the usual
        # InMemoryDataset indexing machinery works.
        pyg_data = self.process()
        self.data, self.slices = self.collate([pyg_data])
        self.transform = transform

    def process(self):
        """Convert the wrapped DPR graph into a torch_geometric ``Data``."""
        dpr_data = self.dpr_data
        # nonzero() of the adjacency yields the COO edge list.
        edge_index = torch.LongTensor(dpr_data.adj.nonzero())
        # by default, the features in pyg data is dense
        if sp.issparse(dpr_data.features):
            x = torch.FloatTensor(dpr_data.features.todense()).float()
        else:
            x = torch.FloatTensor(dpr_data.features).float()
        y = torch.LongTensor(dpr_data.labels)
        data = Data(x=x, edge_index=edge_index, y=y)
        # Masks are intentionally unset; callers assign their own splits.
        data.train_mask = None
        data.val_mask = None
        data.test_mask = None
        return data


    def get(self, idx):
        """Return graph ``idx`` by manually slicing the collated storage."""
        data = self.data.__class__()

        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            # Slice only along the concatenation dimension of this attribute.
            s = list(repeat(slice(None), item.dim()))
            s[self.data.__cat_dim__(key, item)] = slice(slices[idx],
                                                        slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        # Dummy names: raw files are never read (data arrives in memory).
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def _download(self):
        # No download step; the graph is supplied via `dpr_data`.
        pass
|
| 219 |
+
|
| 220 |
+
|
requirements.txt
CHANGED
|
@@ -1,4 +1,4 @@
|
|
| 1 |
-
torch
|
| 2 |
torch-scatter
|
| 3 |
torch-sparse
|
| 4 |
-
torch-geometric
|
|
|
|
|
|
|
|
|
| 1 |
torch-scatter
|
| 2 |
torch-sparse
|
| 3 |
+
torch-geometric
|
| 4 |
+
rich
|
src/2.1_lrmc_bilevel.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
-
# lrmc_bilevel.py
|
| 2 |
# Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
|
| 3 |
-
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
|
| 4 |
#
|
| 5 |
# Usage examples:
|
| 6 |
# python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant baseline
|
|
@@ -28,6 +26,8 @@ from torch_sparse import coalesce, spspmm
|
|
| 28 |
from torch_geometric.datasets import Planetoid
|
| 29 |
from torch_geometric.nn import GCNConv
|
| 30 |
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# ---------------------------
|
| 33 |
# Utilities: edges and seeds
|
|
|
|
|
|
|
| 1 |
# Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
|
|
|
|
| 2 |
#
|
| 3 |
# Usage examples:
|
| 4 |
# python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant baseline
|
|
|
|
| 26 |
from torch_geometric.datasets import Planetoid
|
| 27 |
from torch_geometric.nn import GCNConv
|
| 28 |
|
| 29 |
+
from rich import print
|
| 30 |
+
|
| 31 |
|
| 32 |
# ---------------------------
|
| 33 |
# Utilities: edges and seeds
|
src/2.2_lrmc_bilevel.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# 2.1_lrmc_bilevel.py
|
| 2 |
# Top-1 LRMC ablation with debug guards so seeds differences are visible.
|
| 3 |
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
|
| 4 |
|
|
@@ -16,6 +15,8 @@ from torch_sparse import coalesce, spspmm
|
|
| 16 |
from torch_geometric.datasets import Planetoid
|
| 17 |
from torch_geometric.nn import GCNConv
|
| 18 |
|
|
|
|
|
|
|
| 19 |
|
| 20 |
# ---------------------------
|
| 21 |
# Utilities: edges and seeds
|
|
|
|
|
|
|
| 1 |
# Top-1 LRMC ablation with debug guards so seeds differences are visible.
|
| 2 |
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
|
| 3 |
|
|
|
|
| 15 |
from torch_geometric.datasets import Planetoid
|
| 16 |
from torch_geometric.nn import GCNConv
|
| 17 |
|
| 18 |
+
from rich import print
|
| 19 |
+
|
| 20 |
|
| 21 |
# ---------------------------
|
| 22 |
# Utilities: edges and seeds
|
src/2.3_lrmc_bilevel.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# 2.3_lrmc_bilevel.py
|
| 2 |
# Top-1 LRMC ablation with: cluster refinement (k-core), gated residual fusion,
|
| 3 |
# sparsified cluster graph (drop self-loops + per-row top-k), and A + γA² mix.
|
| 4 |
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
|
|
@@ -17,6 +16,8 @@ from torch_sparse import coalesce, spspmm
|
|
| 17 |
from torch_geometric.datasets import Planetoid
|
| 18 |
from torch_geometric.nn import GCNConv
|
| 19 |
|
|
|
|
|
|
|
| 20 |
|
| 21 |
# ---------------------------
|
| 22 |
# Utilities: edges and seeds
|
|
|
|
|
|
|
| 1 |
# Top-1 LRMC ablation with: cluster refinement (k-core), gated residual fusion,
|
| 2 |
# sparsified cluster graph (drop self-loops + per-row top-k), and A + γA² mix.
|
| 3 |
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
|
|
|
|
| 16 |
from torch_geometric.datasets import Planetoid
|
| 17 |
from torch_geometric.nn import GCNConv
|
| 18 |
|
| 19 |
+
from rich import print
|
| 20 |
+
|
| 21 |
|
| 22 |
# ---------------------------
|
| 23 |
# Utilities: edges and seeds
|
src/2.4_lrmc_bilevel.py
CHANGED
|
@@ -1,4 +1,3 @@
|
|
| 1 |
-
# lrmc_bilevel.py
|
| 2 |
# Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
|
| 3 |
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
|
| 4 |
#
|
|
@@ -30,6 +29,8 @@ from torch_geometric.datasets import Planetoid
|
|
| 30 |
from torch_geometric.nn import GCNConv
|
| 31 |
from torch_geometric.utils import subgraph, degree # Added for stability score
|
| 32 |
|
|
|
|
|
|
|
| 33 |
# ---------------------------
|
| 34 |
# Utilities: edges and seeds
|
| 35 |
# ---------------------------
|
|
|
|
|
|
|
| 1 |
# Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
|
| 2 |
# Requires: torch, torch_geometric, torch_scatter, torch_sparse
|
| 3 |
#
|
|
|
|
| 29 |
from torch_geometric.nn import GCNConv
|
| 30 |
from torch_geometric.utils import subgraph, degree # Added for stability score
|
| 31 |
|
| 32 |
+
from rich import print
|
| 33 |
+
|
| 34 |
# ---------------------------
|
| 35 |
# Utilities: edges and seeds
|
| 36 |
# ---------------------------
|
src/2.5_lrmc_bilevel.py
ADDED
|
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
L-RMC Anchored GCN vs. Plain GCN (dynamic robustness evaluation)
|
| 3 |
+
==============================================================
|
| 4 |
+
|
| 5 |
+
This script trains a baseline two‑layer GCN and a new **anchor‑gated** GCN on
|
| 6 |
+
Planetoid citation networks (Cora/Citeseer/Pubmed). The anchor‑gated GCN uses
|
| 7 |
+
the top‑1 L‑RMC cluster (loaded from a provided JSON file) as a *decentralized
|
| 8 |
+
core*. During message passing it blends standard neighborhood aggregation
|
| 9 |
+
(`h_base`) with aggregation restricted to the core (`h_core`) via a per‑node
|
| 10 |
+
gating network. Cross‑boundary edges are optionally down‑weighted by a
|
| 11 |
+
damping factor `γ`.
|
| 12 |
+
|
| 13 |
+
After training on the static graph, the script evaluates *robustness over
|
| 14 |
+
time*. Starting from the original adjacency, it repeatedly performs random
|
| 15 |
+
edge rewires (removes a fraction of existing edges and adds the same number
|
| 16 |
+
of random new edges) and measures test accuracy at each step **without
|
| 17 |
+
retraining**. The area under the accuracy–time curve (AUC‑AT) is reported
|
| 18 |
+
for both the baseline and the anchored model. A higher AUC‑AT indicates
|
| 19 |
+
longer resilience to graph churn.
|
| 20 |
+
|
| 21 |
+
Usage examples::
|
| 22 |
+
|
| 23 |
+
# Train only baseline and report dynamic AUC
|
| 24 |
+
python 2.5_lrmc_bilevel.py --dataset Cora --seeds path/to/lrmc_seeds.json --variant baseline
|
| 25 |
+
|
| 26 |
+
# Train baseline and anchor models, evaluate AUC‑over‑time on 30 steps with 5% rewiring
|
| 27 |
+
python 2.5_lrmc_bilevel.py --dataset Cora --seeds path/to/lrmc_seeds.json --variant anchor \
|
| 28 |
+
--dynamic_steps 30 --flip_fraction 0.05 --gamma 0.8
|
| 29 |
+
|
| 30 |
+
Notes
|
| 31 |
+
-----
|
| 32 |
+
* The seeds JSON must contain an entry ``"clusters"`` with a list of clusters; the
|
| 33 |
+
cluster with maximum (score, size) is chosen as the core.
|
| 34 |
+
* For fairness, both models are trained on the identical training mask and
|
| 35 |
+
evaluated on the same dynamic perturbations.
|
| 36 |
+
* Random rewiring is undirected: an edge (u,v) is treated as the same as (v,u).
|
| 37 |
+
* Cross‑boundary damping and the gating network use only structural
|
| 38 |
+
information; features are left unchanged during perturbations.
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
import argparse
|
| 42 |
+
import json
|
| 43 |
+
import random
|
| 44 |
+
from pathlib import Path
|
| 45 |
+
from typing import Tuple, List, Optional, Set
|
| 46 |
+
|
| 47 |
+
import torch
|
| 48 |
+
import torch.nn as nn
|
| 49 |
+
import torch.nn.functional as F
|
| 50 |
+
from torch import Tensor
|
| 51 |
+
from torch_geometric.datasets import Planetoid
|
| 52 |
+
from torch_geometric.nn import GCNConv
|
| 53 |
+
|
| 54 |
+
from rich import print
|
| 55 |
+
|
| 56 |
+
# -----------------------------------------------------------------------------
|
| 57 |
+
# Utilities for loading LRMC core assignment
|
| 58 |
+
# -----------------------------------------------------------------------------
|
| 59 |
+
|
| 60 |
+
def _pick_top1_cluster(obj: dict) -> List[int]:
|
| 61 |
+
"""
|
| 62 |
+
From LRMC JSON with structure {"clusters":[{"seed_nodes":[...],"score":float,...},...]}
|
| 63 |
+
choose the cluster with the highest (score, size) and return its members as
|
| 64 |
+
0‑indexed integers. If no clusters exist, returns an empty list.
|
| 65 |
+
"""
|
| 66 |
+
clusters = obj.get("clusters", [])
|
| 67 |
+
if not clusters:
|
| 68 |
+
return []
|
| 69 |
+
# Choose by highest score, tie‑break by size
|
| 70 |
+
best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("seed_nodes", []))))
|
| 71 |
+
return [nid - 1 for nid in best.get("seed_nodes", [])]
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def load_top1_assignment(seeds_json: str, n_nodes: int) -> Tuple[Tensor, Tensor]:
    """
    Load the top-1 LRMC cluster from ``seeds_json`` for a graph of
    ``n_nodes`` nodes.

    Returns
    -------
    core_mask : bool Tensor [N]
        True for nodes belonging to the top-1 LRMC cluster.
    core_nodes : Long Tensor
        Sorted, de-duplicated indices of the core nodes.

    Nodes outside the core form the periphery; with no clusters in the
    JSON the core is empty.
    """
    payload = json.loads(Path(seeds_json).read_text())
    members = _pick_top1_cluster(payload)
    core_nodes = torch.tensor(sorted(set(members)), dtype=torch.long)
    core_mask = torch.zeros(n_nodes, dtype=torch.bool)
    if core_nodes.numel() > 0:
        core_mask[core_nodes] = True
    return core_mask, core_nodes
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
# -----------------------------------------------------------------------------
|
| 95 |
+
# Baseline GCN: standard two‑layer GCN
|
| 96 |
+
# -----------------------------------------------------------------------------
|
| 97 |
+
|
| 98 |
+
class GCN2(nn.Module):
    """Baseline two-layer GCN: conv -> ReLU -> dropout -> conv."""

    def __init__(self, in_dim: int, hid_dim: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
        # GCNConv adds self-loops implicitly (add_self_loops=True default).
        hidden = self.conv1(x, edge_index, edge_weight)
        hidden = F.relu(hidden)
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index, edge_weight)
|
| 113 |
+
|
| 114 |
+
|
| 115 |
+
# -----------------------------------------------------------------------------
|
| 116 |
+
# Anchor‑gated GCN
|
| 117 |
+
# -----------------------------------------------------------------------------
|
| 118 |
+
|
| 119 |
+
class AnchorGCN(nn.Module):
    """
    A two‑layer GCN that injects a core‑restricted aggregation channel and
    down‑weights edges crossing the core boundary. After the first GCN layer
    computes base features, a gating network mixes them with features
    aggregated only among core neighbors.

    Parameters
    ----------
    in_dim : int
        Dimensionality of input node features.
    hid_dim : int
        Dimensionality of hidden layer.
    out_dim : int
        Number of output classes.
    core_mask : Tensor[bool]
        Boolean mask indicating which nodes belong to the L‑RMC core.
    gamma : float, optional
        Damping factor for edges that connect core and non‑core nodes.
        Values <1.0 reduce the influence of boundary edges. Default is 1.0
        (no damping).
    dropout : float, optional
        Dropout probability applied after the first layer.
    """

    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 out_dim: int,
                 core_mask: Tensor,
                 gamma: float = 1.0,
                 dropout: float = 0.5):
        super().__init__()
        # Detached copy so the mask is never part of the autograd graph.
        self.core_mask = core_mask.clone().detach()
        self.gamma = float(gamma)
        self.dropout = dropout

        # Base and core convolutions for the first layer.
        # Base conv uses self loops; core conv disables self loops to avoid
        # spurious core contributions on non‑core nodes.
        self.base1 = GCNConv(in_dim, hid_dim, add_self_loops=True)
        self.core1 = GCNConv(in_dim, hid_dim, add_self_loops=False)

        # Second layer: standard GCN on mixed features
        self.conv2 = GCNConv(hid_dim, out_dim)

        # Gating network: maps 3 structural features to α ∈ [0,1] per node.
        self.gate = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def _compute_edge_weights(self, edge_index: Tensor) -> Tensor:
        """
        Given an edge index (two‑row tensor), return a weight tensor of ones
        multiplied by ``gamma`` for edges with exactly one endpoint in the core.
        Self loops (if present) are untouched. Edge weights are 1 for base
        edges and <1 for cross‑boundary edges.
        """
        # gamma >= 1.0 short-circuits: damping disabled, all weights are 1.
        if self.gamma >= 1.0:
            return torch.ones(edge_index.size(1), dtype=torch.float32, device=edge_index.device)
        src, dst = edge_index[0], edge_index[1]
        in_core_src = self.core_mask[src]
        in_core_dst = self.core_mask[dst]
        # XOR: exactly one endpoint inside the core marks a boundary edge.
        cross = in_core_src ^ in_core_dst
        w = torch.ones(edge_index.size(1), dtype=torch.float32, device=edge_index.device)
        w[cross] *= self.gamma
        return w

    def _compute_structural_features(self, edge_index: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute structural features used by the gating network:

        * `in_core` – 1 if node in core, else 0
        * `frac_core_nbrs` – fraction of neighbors that are in the core
        * `is_boundary` – 1 if node has both core and non‑core neighbors

        The features are returned as a tuple of three tensors of shape [N,1].
        Nodes with zero degree get frac_core_nbrs=0 and is_boundary=0.
        """
        N = self.core_mask.size(0)
        device = edge_index.device
        # Degree and core neighbor counts
        src = edge_index[0]
        dst = edge_index[1]
        deg = torch.zeros(N, dtype=torch.float32, device=device)
        core_deg = torch.zeros(N, dtype=torch.float32, device=device)
        # Count contributions of directed edges; duplicates will double‑count but
        # the ratio remains stable if the graph is symmetric.
        deg.index_add_(0, src, torch.ones_like(src, dtype=torch.float32))
        # Count core neighbors: only increment source if destination is core
        core_flags = self.core_mask[dst].float()
        core_deg.index_add_(0, src, core_flags)
        # Avoid division by zero
        frac_core = torch.zeros(N, dtype=torch.float32, device=device)
        nonzero = deg > 0
        frac_core[nonzero] = core_deg[nonzero] / deg[nonzero]
        # Determine boundary: at least one core neighbor AND at least one non‑core neighbor
        has_core = core_deg > 0
        has_non_core = (deg - core_deg) > 0
        is_boundary = (has_core & has_non_core).float()
        in_core = self.core_mask.float()
        return in_core.view(-1, 1), frac_core.view(-1, 1), is_boundary.view(-1, 1)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # Compute dynamic edge weights (for base channels) using damping
        w = self._compute_edge_weights(edge_index)
        # First layer: base aggregation (standard neighbors with self loops)
        h_base = self.base1(x, edge_index, w)
        h_base = F.relu(h_base)

        # First layer: core aggregation (only core neighbors, no self loops)
        # Extract edges where both endpoints are core
        src, dst = edge_index
        mask_core_edges = self.core_mask[src] & self.core_mask[dst]
        ei_core = edge_index[:, mask_core_edges]
        # If no core edges exist, h_core will be zeros
        if ei_core.numel() == 0:
            h_core = torch.zeros_like(h_base)
        else:
            h_core = self.core1(x, ei_core)
            h_core = F.relu(h_core)

        # Structural features for gating
        in_core, frac_core, is_boundary = self._compute_structural_features(edge_index)
        feats = torch.cat([in_core, frac_core, is_boundary], dim=1)
        alpha = self.gate(feats).view(-1)  # shape [N]
        # Force α=0 for nodes with no core neighbors to avoid modifying true periphery.
        # Nodes with frac_core == 0 have zero core neighbors by construction.
        no_core_neighbors = (frac_core.view(-1) == 0)
        alpha = torch.where(no_core_neighbors, torch.zeros_like(alpha), alpha)

        # Mix base and core features; h_final = h_base + α (h_core - h_base)
        # Equivalent to (1-α)*h_base + α*h_core
        h1 = h_base + alpha.unsqueeze(1) * (h_core - h_base)

        h1 = F.dropout(h1, p=self.dropout, training=self.training)
        # Second layer: standard GCN with the same damping weights
        out = self.conv2(h1, edge_index, w)
        return out
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
# The deg_for_division helper is no longer used but left here for completeness.
|
| 264 |
+
def deg_for_division(edge_index: Tensor, num_nodes: int) -> Tensor:
    """Per-node out-degree (occurrences as edge source), returned as float32."""
    sources = edge_index[0]
    return torch.bincount(sources, minlength=num_nodes).to(torch.float32)
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
# -----------------------------------------------------------------------------
|
| 272 |
+
# Training and evaluation routines
|
| 273 |
+
# -----------------------------------------------------------------------------
|
| 274 |
+
|
| 275 |
+
@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Fraction of correctly classified nodes within *mask*."""
    predictions = logits[mask].argmax(dim=1)
    hits = (predictions == y[mask]).float()
    return hits.mean().item()
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
def train_model(model: nn.Module,
                data,
                epochs: int = 200,
                lr: float = 0.01,
                weight_decay: float = 5e-4) -> None:
    """
    Standard training loop for either baseline or anchor models.

    Trains with Adam + cross-entropy on ``data.train_mask``, tracks the best
    validation accuracy, and restores the best-validation weights at the end.

    Parameters
    ----------
    model : nn.Module
        Model taking ``(x, edge_index)`` and returning per-node logits.
    data : Data
        PyG-style object with ``x``, ``edge_index``, ``y``, ``train_mask``,
        ``val_mask`` attributes.
    epochs, lr, weight_decay : training hyperparameters.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val = 0.0
    best_state = None
    for ep in range(1, epochs + 1):
        model.train()
        opt.zero_grad(set_to_none=True)
        logits = model(data.x, data.edge_index)
        loss = F.cross_entropy(logits[data.train_mask], data.y[data.train_mask])
        loss.backward()
        opt.step()
        # Evaluate on validation. Fix: run the forward pass under no_grad so
        # pure evaluation does not build an autograd graph (the original only
        # guarded `accuracy`, not the forward itself).
        model.eval()
        with torch.no_grad():
            logits_val = model(data.x, data.edge_index)
        val_acc = accuracy(logits_val, data.y, data.val_mask)
        if val_acc > best_val:
            best_val = val_acc
            best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
@torch.no_grad()
def evaluate_model(model: nn.Module, data) -> dict:
    """
    Evaluate a trained model on the train, val, and test masks.

    Fix: runs under ``torch.no_grad()`` so no autograd graph is built during
    pure evaluation (consistent with the decorated ``accuracy`` helper).
    Returns a dict with keys ``train``, ``val``, ``test``.
    """
    model.eval()
    logits = model(data.x, data.edge_index)
    return {
        "train": accuracy(logits, data.y, data.train_mask),
        "val": accuracy(logits, data.y, data.val_mask),
        "test": accuracy(logits, data.y, data.test_mask),
    }
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
# -----------------------------------------------------------------------------
|
| 322 |
+
# Dynamic graph perturbation utilities
|
| 323 |
+
# -----------------------------------------------------------------------------
|
| 324 |
+
|
| 325 |
+
def undirected_edge_set(edge_index: Tensor) -> Set[Tuple[int, int]]:
    """
    Collapse a directed edge index into a set of undirected edges, each stored
    as a (u, v) tuple with u < v. Self loops are dropped.
    """
    endpoint_pairs = zip(edge_index[0].tolist(), edge_index[1].tolist())
    return {(min(u, v), max(u, v)) for u, v in endpoint_pairs if u != v}
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
def edge_set_to_index(edges: Set[Tuple[int, int]], num_nodes: int) -> Tensor:
    """
    Convert an undirected edge set into a directed edge_index tensor of shape
    [2, 2*|edges|] by emitting both (u, v) and (v, u) per undirected edge.
    Self loops are omitted; GCNConv adds them automatically.

    Fix: edges are sorted before emission, so the column order of the result
    is deterministic instead of depending on set-iteration order.

    Parameters
    ----------
    edges : set of (u, v) tuples with u < v.
    num_nodes : int
        Unused; kept for interface compatibility with callers.
    """
    if not edges:
        return torch.empty(2, 0, dtype=torch.long)
    src_list = []
    dst_list = []
    for u, v in sorted(edges):  # deterministic ordering
        src_list.extend((u, v))
        dst_list.extend((v, u))
    return torch.tensor([src_list, dst_list], dtype=torch.long)
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def random_rewire(edges: Set[Tuple[int, int]], num_nodes: int, n_changes: int, rng: random.Random) -> Set[Tuple[int, int]]:
    """
    Remove up to n_changes random edges and add the same number of random new
    edges (never self loops, never duplicates). Addition attempts are capped
    at 10x n_changes, so slightly fewer edges may be added on dense graphs.
    Returns a new set; the input set is not modified.
    """
    result = set(edges)  # work on a copy
    n_changes = min(n_changes, len(result))  # cannot remove more than we have

    # Deletions: one sample() call, matching the original RNG consumption.
    for victim in rng.sample(list(result), n_changes):
        result.remove(victim)

    # Additions: retry loop with a bounded attempt budget.
    added = 0
    attempts = 0
    budget = n_changes * 10
    while added < n_changes and attempts < budget:
        u = rng.randrange(num_nodes)
        v = rng.randrange(num_nodes)
        attempts += 1
        if u == v:
            continue
        candidate = (u, v) if u < v else (v, u)
        if candidate not in result:
            result.add(candidate)
            added += 1
    return result
|
| 387 |
+
|
| 388 |
+
|
| 389 |
+
def auc_over_time(acc_list: List[float]) -> float:
    """
    Area under an accuracy-time curve via the trapezoidal rule, for samples at
    t = 0, 1, ..., T. The area is normalized by T, so a constant accuracy of
    1.0 yields AUC = 1.0.

    Fix: a single-sample list previously raised ZeroDivisionError
    (``area / (len - 1)`` with len == 1); it now returns the lone accuracy.
    An empty list returns 0.0.
    """
    if not acc_list:
        return 0.0
    if len(acc_list) == 1:
        # Degenerate curve: no interval to integrate over.
        return acc_list[0]
    area = sum((acc_list[i] + acc_list[i - 1]) / 2.0 for i in range(1, len(acc_list)))
    return area / (len(acc_list) - 1)
|
| 401 |
+
|
| 402 |
+
|
| 403 |
+
@torch.no_grad()
def evaluate_dynamic_auc(model: nn.Module,
                         data,
                         core_mask: Tensor,
                         steps: int = 30,
                         flip_fraction: float = 0.05,
                         rng_seed: int = 1234) -> List[float]:
    """
    Evaluate a model's test accuracy over a sequence of random edge rewiring steps.

    Fix: the whole evaluation now runs under ``torch.no_grad()`` — the original
    built an autograd graph on every forward pass for no reason.

    Parameters
    ----------
    model : nn.Module
        A trained model that accepts (x, edge_index) and returns logits.
    data : Data
        PyG data object with attributes x, y, test_mask; ``data.edge_index``
        provides the initial adjacency.
    core_mask : Tensor[bool]
        Boolean mask of core nodes. Anchor models use their own stored mask;
        the baseline ignores this argument entirely.
    steps : int, optional
        Number of rewiring steps; the accuracy at t=0 is recorded before any
        rewiring. Default: 30.
    flip_fraction : float, optional
        Fraction of existing undirected edges removed/added per step.
        Default: 0.05.
    rng_seed : int, optional
        Seed for the rewiring RNG, for reproducibility. Default: 1234.

    Returns
    -------
    List[float]
        Test accuracies of length ``steps + 1`` (including t=0).
    """
    base_edges = undirected_edge_set(data.edge_index)
    # Undirected edges rewired per step (always at least one).
    n_changes = max(1, int(flip_fraction * len(base_edges)))
    model.eval()
    rng = random.Random(rng_seed)
    cur_edges = set(base_edges)
    accuracies = []
    # t = 0: accuracy on the unperturbed graph (tensor moved to data's device).
    ei = edge_set_to_index(cur_edges, data.num_nodes).to(data.x.device)
    accuracies.append(accuracy(model(data.x, ei), data.y, data.test_mask))
    # Rewiring steps.
    for _ in range(1, steps + 1):
        cur_edges = random_rewire(cur_edges, data.num_nodes, n_changes, rng)
        ei = edge_set_to_index(cur_edges, data.num_nodes).to(data.x.device)
        accuracies.append(accuracy(model(data.x, ei), data.y, data.test_mask))
    return accuracies
|
| 463 |
+
|
| 464 |
+
|
| 465 |
+
# -----------------------------------------------------------------------------
|
| 466 |
+
# Main entrypoint
|
| 467 |
+
# -----------------------------------------------------------------------------
|
| 468 |
+
|
| 469 |
+
def main():
    """CLI entry point: train baseline and/or anchor GCN, then measure dynamic-robustness AUC."""
    parser = argparse.ArgumentParser(description="L‑RMC anchored GCN vs. baseline with dynamic evaluation.")
    parser.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"],
                        help="Planetoid dataset to load.")
    parser.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON (for core extraction).")
    parser.add_argument("--variant", choices=["baseline", "anchor"], default="anchor", help="Which variant to run.")
    parser.add_argument("--hidden", type=int, default=64, help="Hidden dimension.")
    parser.add_argument("--epochs", type=int, default=200, help="Number of training epochs.")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=5e-4, help="Weight decay (L2).")
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability.")
    parser.add_argument("--gamma", type=float, default=1.0, help="Damping factor γ for cross‑boundary edges (anchor only).")
    parser.add_argument("--dynamic_steps", type=int, default=30, help="Number of dynamic rewiring steps for AUC evaluation.")
    parser.add_argument("--flip_fraction", type=float, default=0.05, help="Fraction of edges rewired at each step.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for PyTorch.")
    args = parser.parse_args()

    # Seed both torch and python RNGs for reproducible training and rewiring.
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Load the Planetoid dataset (downloaded into ./data/<name> on first use).
    dataset = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = dataset[0]
    in_dim = dataset.num_node_features
    out_dim = dataset.num_classes
    num_nodes = data.num_nodes

    # Load the LRMC core assignment (boolean mask + index tensor of core nodes).
    core_mask, core_nodes = load_top1_assignment(args.seeds, num_nodes)
    print(f"Loaded core of size {core_nodes.numel()} from {args.seeds}.")

    if args.variant == "baseline":
        # Baseline-only path: train, report static splits, then dynamic AUC, and exit.
        baseline = GCN2(in_dim, args.hidden, out_dim, dropout=args.dropout)
        train_model(baseline, data, epochs=args.epochs, lr=args.lr, weight_decay=args.wd)
        res = evaluate_model(baseline, data)
        print(f"Baseline GCN: train={res['train']:.4f} val={res['val']:.4f} test={res['test']:.4f}")
        # Accuracy-over-time under random rewiring, summarized as normalized AUC.
        accs = evaluate_dynamic_auc(baseline, data, core_mask, steps=args.dynamic_steps,
                                    flip_fraction=args.flip_fraction, rng_seed=args.seed)
        auc = auc_over_time(accs)
        print(f"Baseline dynamic AUC‑AT (steps={args.dynamic_steps}, flip={args.flip_fraction}): {auc:.4f}")
        return

    # ----- Train both baseline and anchor variants -----
    # Baseline GCN (same hyperparameters as the anchor model for a fair comparison).
    baseline = GCN2(in_dim, args.hidden, out_dim, dropout=args.dropout)
    train_model(baseline, data, epochs=args.epochs, lr=args.lr, weight_decay=args.wd)
    res_base = evaluate_model(baseline, data)
    print(f"Baseline GCN: train={res_base['train']:.4f} val={res_base['val']:.4f} test={res_base['test']:.4f}")
    # Anchor model: gated core/base aggregation with γ-damped cross-boundary edges.
    anchor = AnchorGCN(in_dim, args.hidden, out_dim,
                       core_mask=core_mask,
                       gamma=args.gamma,
                       dropout=args.dropout)
    train_model(anchor, data, epochs=args.epochs, lr=args.lr, weight_decay=args.wd)
    res_anchor = evaluate_model(anchor, data)
    print(f"Anchor‑GCN: train={res_anchor['train']:.4f} val={res_anchor['val']:.4f} test={res_anchor['test']:.4f}")
    # Dynamic evaluation: identical rewiring seed for both models so they see
    # the same perturbation sequence.
    accs_base = evaluate_dynamic_auc(baseline, data, core_mask, steps=args.dynamic_steps,
                                     flip_fraction=args.flip_fraction, rng_seed=args.seed)
    accs_anchor = evaluate_dynamic_auc(anchor, data, core_mask, steps=args.dynamic_steps,
                                       flip_fraction=args.flip_fraction, rng_seed=args.seed)
    auc_base = auc_over_time(accs_base)
    auc_anchor = auc_over_time(accs_anchor)
    print(f"Dynamic AUC‑AT (steps={args.dynamic_steps}, flip={args.flip_fraction}):")
    print(f" Baseline : {auc_base:.4f}\n Anchor : {auc_anchor:.4f}")
|
| 537 |
+
|
| 538 |
+
|
| 539 |
+
# Script entry point.
if __name__ == "__main__":
    main()
|
src/2.6_lrmc_summary.py
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Train a GCN on an L‑RMC subgraph and compare to a full‑graph baseline.
|
| 3 |
+
|
| 4 |
+
Modes:
|
| 5 |
+
- core_mode=forward : Train on core subgraph, then forward on full graph (your current approach).
|
| 6 |
+
- core_mode=appnp : Train on core subgraph, then seed logits on core and APPNP‑propagate on full graph.
|
| 7 |
+
|
| 8 |
+
Extras:
|
| 9 |
+
- --expand_core_with_train : Make sure all training labels lie inside the core
|
| 10 |
+
(C' = C ∪ train_idx) for fair train‑time comparison.
|
| 11 |
+
- --warm_ft_epochs N : Optional short finetune on the full graph starting
|
| 12 |
+
from the core model's weights (measure time‑to‑target).
|
| 13 |
+
|
| 14 |
+
It prints:
|
| 15 |
+
- Dataset stats
|
| 16 |
+
- Core size and coverage of train/val/test inside the core
|
| 17 |
+
- Train/Val/Test accuracy for baseline and core model
|
| 18 |
+
- Wall‑clock times
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import argparse
|
| 22 |
+
import json
|
| 23 |
+
import time
|
| 24 |
+
import random
|
| 25 |
+
from statistics import mean, stdev
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Dict
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
import torch.nn.functional as F
|
| 31 |
+
from torch import nn, Tensor
|
| 32 |
+
from torch_geometric.datasets import Planetoid
|
| 33 |
+
from torch_geometric.nn import GCNConv, APPNP
|
| 34 |
+
from torch_geometric.utils import subgraph
|
| 35 |
+
|
| 36 |
+
# ------------------------------------------------------------
|
| 37 |
+
# Rich imports
|
| 38 |
+
# ------------------------------------------------------------
|
| 39 |
+
from rich.console import Console
|
| 40 |
+
from rich.table import Table
|
| 41 |
+
from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
|
| 42 |
+
|
| 43 |
+
# Rich console instance
|
| 44 |
+
console = Console()
|
| 45 |
+
|
| 46 |
+
# ------------------------------------------------------------
|
| 47 |
+
# Utilities
|
| 48 |
+
# ------------------------------------------------------------
|
| 49 |
+
def load_top1_assignment(seeds_json: str, n_nodes: int) -> torch.Tensor:
    """
    Read an LRMC seeds JSON of the form
        {"clusters": [{"seed_nodes": [...], "score": float, ...}, ...]}
    pick the cluster maximizing (score, size), and return a boolean core mask
    of length n_nodes. Seed node ids are always treated as 1-indexed.
    Returns an all-False mask when no clusters are present.
    """
    payload = json.loads(Path(seeds_json).read_text())
    clusters = payload.get("clusters", [])
    mask = torch.zeros(n_nodes, dtype=torch.bool)
    if not clusters:
        return mask

    def rank(cluster):
        # Primary key: score; tie-break: cluster size.
        return (float(cluster.get("score", 0.0)), len(cluster.get("seed_nodes", [])))

    winner = max(clusters, key=rank)
    zero_based = (int(x) - 1 for x in winner.get("seed_nodes", []))  # 1-indexed -> 0-indexed
    valid = sorted({i for i in zero_based if 0 <= i < n_nodes})
    if valid:
        mask[torch.tensor(valid, dtype=torch.long)] = True
    return mask
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def coverage_counts(core_mask: torch.Tensor, train_mask: torch.Tensor,
                    val_mask: torch.Tensor, test_mask: torch.Tensor) -> Dict[str, int]:
    """Core size plus how many train/val/test nodes fall inside the core."""
    def overlap(split_mask: torch.Tensor) -> int:
        return int((core_mask & split_mask).sum().item())

    return {
        "core_size": int(core_mask.sum().item()),
        "train_in_core": overlap(train_mask),
        "val_in_core": overlap(val_mask),
        "test_in_core": overlap(test_mask),
    }
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Classification accuracy restricted to the nodes selected by *mask*."""
    hits = logits[mask].argmax(dim=1) == y[mask]
    return hits.float().mean().item()
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
def set_seed(seed: int):
    """
    Seed every RNG the experiments touch — python `random`, numpy (when
    installed), and torch CPU/CUDA — and force deterministic cuDNN kernels.
    """
    random.seed(seed)
    try:
        import numpy as np  # numpy is an optional dependency here
        np.random.seed(seed)
    except Exception:
        pass
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Deterministic cuDNN behaviour where applicable.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
# ------------------------------------------------------------
|
| 103 |
+
# Models
|
| 104 |
+
# ------------------------------------------------------------
|
| 105 |
+
class GCN2(nn.Module):
    """Plain two-layer GCN: conv -> ReLU -> dropout -> conv."""

    def __init__(self, in_dim: int, hid: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        # Attribute names kept for state_dict compatibility.
        self.c1 = GCNConv(in_dim, hid)
        self.c2 = GCNConv(hid, out_dim)
        self.dropout = dropout

    def forward(self, x, ei):
        hidden = torch.relu(self.c1(x, ei))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.c2(hidden, ei)
|
| 118 |
+
|
| 119 |
+
|
| 120 |
+
# ------------------------------------------------------------
|
| 121 |
+
# Training / evaluation
|
| 122 |
+
# ------------------------------------------------------------
|
| 123 |
+
@torch.no_grad()
def eval_all(model: nn.Module, data) -> Dict[str, float]:
    """Accuracy on each of the train/val/test splits of *data*."""
    model.eval()
    logits = model(data.x, data.edge_index)
    splits = {"train": data.train_mask, "val": data.val_mask, "test": data.test_mask}
    return {name: accuracy(logits, data.y, m) for name, m in splits.items()}
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
def train(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, patience=100):
    """
    Train with Adam + cross-entropy, early-stopping on validation accuracy,
    and restore the best-validation weights before returning. A transient
    rich progress bar shows epoch/val progress.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    best_val = -1.0
    best_state = None
    stale = 0

    with Progress(
        SpinnerColumn(),
        "[progress.description]{task.description}",
        TimeElapsedColumn(),
        transient=True,
    ) as progress:
        task = progress.add_task("Training", total=epochs)

        for ep in range(1, epochs + 1):
            # One optimization step on the training split.
            model.train()
            opt.zero_grad(set_to_none=True)
            logits = model(data.x, data.edge_index)
            F.cross_entropy(logits[data.train_mask], data.y[data.train_mask]).backward()
            opt.step()

            # Validation accuracy drives early stopping.
            with torch.no_grad():
                val = accuracy(model(data.x, data.edge_index), data.y, data.val_mask)

            if val > best_val:
                best_val, stale = val, 0
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            else:
                stale += 1
                if stale >= patience:
                    break

            progress.update(task, advance=1, description=f"Epoch {ep} | val={val:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def subset_data(data, nodes_idx: torch.Tensor):
    """
    Build the induced subgraph on `nodes_idx`: nodes are relabelled to a
    compact range and x/y/masks are sliced to that set. Returns a shallow
    copy of the same Data type.
    """
    idx = nodes_idx.to(torch.long)
    relabelled_ei, _ = subgraph(idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
    sub = type(data)()
    # Slice node-aligned tensors to the selected nodes.
    for attr in ("x", "y", "train_mask", "val_mask", "test_mask"):
        setattr(sub, attr, getattr(data, attr)[idx])
    sub.edge_index = relabelled_ei
    sub.num_nodes = sub.x.size(0)
    return sub
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
# ------------------------------------------------------------
|
| 193 |
+
# APPNP seeding (Mode B)
|
| 194 |
+
# ------------------------------------------------------------
|
| 195 |
+
def appnp_seed_propagate(logits_seed: Tensor, edge_index: Tensor, K=10, alpha=0.1) -> Tensor:
    """
    Diffuse core-seeded logits over the full graph with APPNP.
    `logits_seed` is [N, C] with zero rows outside the core; APPNP has no
    trainable parameters, so this is pure propagation.
    """
    propagator = APPNP(K=K, alpha=alpha)
    return propagator(logits_seed, edge_index)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
# ------------------------------------------------------------
|
| 205 |
+
# Main
|
| 206 |
+
# ------------------------------------------------------------
|
| 207 |
+
def main():
|
| 208 |
+
p = argparse.ArgumentParser()
|
| 209 |
+
p.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
|
| 210 |
+
p.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON")
|
| 211 |
+
p.add_argument("--hidden", type=int, default=64)
|
| 212 |
+
p.add_argument("--dropout", type=float, default=0.5)
|
| 213 |
+
p.add_argument("--epochs", type=int, default=200)
|
| 214 |
+
p.add_argument("--lr", type=float, default=0.01)
|
| 215 |
+
p.add_argument("--wd", type=float, default=5e-4)
|
| 216 |
+
p.add_argument("--patience", type=int, default=100)
|
| 217 |
+
p.add_argument("--core_mode", choices=["forward", "appnp"], default="forward",
|
| 218 |
+
help="How to evaluate the core model on the full graph.")
|
| 219 |
+
p.add_argument("--alpha", type=float, default=0.1, help="APPNP teleport prob (Mode B).")
|
| 220 |
+
p.add_argument("--K", type=int, default=10, help="APPNP steps (Mode B).")
|
| 221 |
+
p.add_argument("--expand_core_with_train", action="store_true",
|
| 222 |
+
help="Expand LRMC core with all training nodes (C' = C ∪ train_idx).")
|
| 223 |
+
p.add_argument("--warm_ft_epochs", type=int, default=0,
|
| 224 |
+
help="If >0, run a short finetune on the FULL graph starting from the core model.")
|
| 225 |
+
p.add_argument("--warm_ft_lr", type=float, default=0.005)
|
| 226 |
+
p.add_argument("--runs", type=int, default=1,
|
| 227 |
+
help="Number of runs with different seeds to average results.")
|
| 228 |
+
p.add_argument("-o", "--output_json", type=str, default=None,
|
| 229 |
+
help="If set, save all computed metrics and settings to this JSON file.")
|
| 230 |
+
args = p.parse_args()
|
| 231 |
+
|
| 232 |
+
# ------------------------------------------------------------
|
| 233 |
+
# Load data
|
| 234 |
+
# ------------------------------------------------------------
|
| 235 |
+
ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
|
| 236 |
+
data = ds[0]
|
| 237 |
+
n, e = data.num_nodes, data.edge_index.size(1) // 2
|
| 238 |
+
|
| 239 |
+
console.print(f"[bold cyan]Dataset: {args.dataset} | Nodes: {n} | Edges: {e}[/bold cyan]")
|
| 240 |
+
|
| 241 |
+
# Results accumulator for optional JSON output
|
| 242 |
+
results = {
|
| 243 |
+
"args": {
|
| 244 |
+
k: (float(v) if isinstance(v, float) else v)
|
| 245 |
+
for k, v in vars(args).items()
|
| 246 |
+
if k != "output_json"
|
| 247 |
+
},
|
| 248 |
+
"dataset": {
|
| 249 |
+
"name": args.dataset,
|
| 250 |
+
"num_nodes": int(n),
|
| 251 |
+
"num_edges": int(e),
|
| 252 |
+
},
|
| 253 |
+
}
|
| 254 |
+
|
| 255 |
+
def maybe_save_results():
|
| 256 |
+
"""Write results to JSON if the user requested it."""
|
| 257 |
+
if not args.output_json:
|
| 258 |
+
return
|
| 259 |
+
out_path = Path(args.output_json)
|
| 260 |
+
try:
|
| 261 |
+
out_path.parent.mkdir(parents=True, exist_ok=True)
|
| 262 |
+
except Exception:
|
| 263 |
+
pass
|
| 264 |
+
with out_path.open("w") as f:
|
| 265 |
+
json.dump(results, f, indent=2)
|
| 266 |
+
|
| 267 |
+
# ------------------------------------------------------------
|
| 268 |
+
# Load LRMC core
|
| 269 |
+
# ------------------------------------------------------------
|
| 270 |
+
core_mask = load_top1_assignment(args.seeds, n)
|
| 271 |
+
if args.expand_core_with_train:
|
| 272 |
+
core_mask = core_mask | data.train_mask
|
| 273 |
+
|
| 274 |
+
C_idx = torch.nonzero(core_mask, as_tuple=False).view(-1)
|
| 275 |
+
frac = 100.0 * C_idx.numel() / n
|
| 276 |
+
cov = coverage_counts(core_mask, data.train_mask, data.val_mask, data.test_mask)
|
| 277 |
+
|
| 278 |
+
console.print(f"[bold green]Loaded LRMC core of size {cov['core_size']} (≈{frac:.2f}% of the graph) from {args.seeds}[/bold green]")
|
| 279 |
+
|
| 280 |
+
# Record core coverage info
|
| 281 |
+
results["core"] = {
|
| 282 |
+
"source": str(args.seeds),
|
| 283 |
+
"expanded_with_train": bool(args.expand_core_with_train),
|
| 284 |
+
"size": int(cov["core_size"]),
|
| 285 |
+
"fraction": float(frac / 100.0),
|
| 286 |
+
"coverage": {
|
| 287 |
+
"train_in_core": int(cov["train_in_core"]),
|
| 288 |
+
"val_in_core": int(cov["val_in_core"]),
|
| 289 |
+
"test_in_core": int(cov["test_in_core"]),
|
| 290 |
+
},
|
| 291 |
+
}
|
| 292 |
+
|
| 293 |
+
# Coverage table
|
| 294 |
+
cov_table = Table(title="LRMC Core Coverage")
|
| 295 |
+
cov_table.add_column("Metric", style="cyan")
|
| 296 |
+
cov_table.add_column("Count", style="magenta")
|
| 297 |
+
cov_table.add_row("Core Size", str(cov["core_size"]))
|
| 298 |
+
cov_table.add_row("Train in Core", str(cov["train_in_core"]))
|
| 299 |
+
cov_table.add_row("Val in Core", str(cov["val_in_core"]))
|
| 300 |
+
cov_table.add_row("Test in Core", str(cov["test_in_core"]))
|
| 301 |
+
console.print(cov_table)
|
| 302 |
+
|
| 303 |
+
# ------------------------------------------------------------
|
| 304 |
+
# Single-run or multi-run execution
|
| 305 |
+
# ------------------------------------------------------------
|
| 306 |
+
if args.runs == 1:
|
| 307 |
+
# ---------------------
|
| 308 |
+
# Baseline (full graph)
|
| 309 |
+
# ---------------------
|
| 310 |
+
set_seed(0)
|
| 311 |
+
t0 = time.perf_counter()
|
| 312 |
+
base = GCN2(in_dim=ds.num_node_features,
|
| 313 |
+
hid=args.hidden,
|
| 314 |
+
out_dim=ds.num_classes,
|
| 315 |
+
dropout=args.dropout)
|
| 316 |
+
train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
|
| 317 |
+
base_metrics = eval_all(base, data)
|
| 318 |
+
t1 = time.perf_counter()
|
| 319 |
+
|
| 320 |
+
console.print("\n[bold]Baseline (trained on full graph):[/bold]")
|
| 321 |
+
base_table = Table(show_header=True, header_style="bold magenta")
|
| 322 |
+
base_table.add_column("Metric", style="cyan")
|
| 323 |
+
base_table.add_column("Value", style="magenta")
|
| 324 |
+
base_table.add_row("Train Accuracy", f"{base_metrics['train']:.4f}")
|
| 325 |
+
base_table.add_row("Validation Accuracy", f"{base_metrics['val']:.4f}")
|
| 326 |
+
base_table.add_row("Test Accuracy", f"{base_metrics['test']:.4f}")
|
| 327 |
+
base_table.add_row("Time (s)", f"{t1 - t0:.2f}")
|
| 328 |
+
console.print(base_table)
|
| 329 |
+
|
| 330 |
+
# Save baseline single-run metrics
|
| 331 |
+
results["single_run"] = {
|
| 332 |
+
"baseline": {
|
| 333 |
+
"train": float(base_metrics["train"]),
|
| 334 |
+
"val": float(base_metrics["val"]),
|
| 335 |
+
"test": float(base_metrics["test"]),
|
| 336 |
+
"time_s": float(t1 - t0),
|
| 337 |
+
}
|
| 338 |
+
}
|
| 339 |
+
|
| 340 |
+
# ---------------------
|
| 341 |
+
# Core model (train on subgraph)
|
| 342 |
+
# ---------------------
|
| 343 |
+
if C_idx.numel() == 0:
|
| 344 |
+
console.print("[bold yellow]LRMC core is empty; skipping core model.[/bold yellow]")
|
| 345 |
+
results["core_empty"] = True
|
| 346 |
+
maybe_save_results()
|
| 347 |
+
return
|
| 348 |
+
|
| 349 |
+
data_C = subset_data(data, C_idx)
|
| 350 |
+
mC = GCN2(in_dim=ds.num_node_features,
|
| 351 |
+
hid=args.hidden,
|
| 352 |
+
out_dim=ds.num_classes,
|
| 353 |
+
dropout=args.dropout)
|
| 354 |
+
|
| 355 |
+
t2 = time.perf_counter()
|
| 356 |
+
train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
|
| 357 |
+
t3 = time.perf_counter()
|
| 358 |
+
|
| 359 |
+
# Evaluate core model on FULL graph
|
| 360 |
+
if args.core_mode == "forward":
|
| 361 |
+
# Mode A: run a standard forward pass on the full graph
|
| 362 |
+
mC.eval()
|
| 363 |
+
logits_full = mC(data.x, data.edge_index)
|
| 364 |
+
else:
|
| 365 |
+
# Mode B: seed logits on core and propagate with APPNP
|
| 366 |
+
mC.eval()
|
| 367 |
+
with torch.no_grad():
|
| 368 |
+
logits_C = mC(data_C.x, data_C.edge_index) # [|C|, num_classes]
|
| 369 |
+
logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
|
| 370 |
+
logits_seed[C_idx] = logits_C
|
| 371 |
+
logits_full = appnp_seed_propagate(logits_seed,
|
| 372 |
+
data.edge_index,
|
| 373 |
+
K=args.K,
|
| 374 |
+
alpha=args.alpha)
|
| 375 |
+
|
| 376 |
+
core_metrics = {
|
| 377 |
+
"train": accuracy(logits_full, data.y, data.train_mask),
|
| 378 |
+
"val": accuracy(logits_full, data.y, data.val_mask),
|
| 379 |
+
"test": accuracy(logits_full, data.y, data.test_mask),
|
| 380 |
+
}
|
| 381 |
+
|
| 382 |
+
console.print("\n[bold]LRMC‑core model (trained on core, evaluated on full graph):[/bold]")
|
| 383 |
+
core_table = Table(show_header=True, header_style="bold magenta")
|
| 384 |
+
core_table.add_column("Metric", style="cyan")
|
| 385 |
+
core_table.add_column("Value", style="magenta")
|
| 386 |
+
core_table.add_row("Train Accuracy", f"{core_metrics['train']:.4f}")
|
| 387 |
+
core_table.add_row("Validation Accuracy", f"{core_metrics['val']:.4f}")
|
| 388 |
+
core_table.add_row("Test Accuracy", f"{core_metrics['test']:.4f}")
|
| 389 |
+
core_table.add_row("Core Training Time (s)", f"{t3 - t2:.2f}")
|
| 390 |
+
speedup = (t1 - t0) / (t3 - t2 + 1e-9)
|
| 391 |
+
core_table.add_row("Speedup vs. Baseline", f"{speedup:.2f}×")
|
| 392 |
+
console.print(core_table)
|
| 393 |
+
|
| 394 |
+
# Save core single-run metrics
|
| 395 |
+
results["single_run"]["core_model"] = {
|
| 396 |
+
"mode": str(args.core_mode),
|
| 397 |
+
"train": float(core_metrics["train"]),
|
| 398 |
+
"val": float(core_metrics["val"]),
|
| 399 |
+
"test": float(core_metrics["test"]),
|
| 400 |
+
"core_train_time_s": float(t3 - t2),
|
| 401 |
+
"speedup_vs_baseline": float(speedup),
|
| 402 |
+
}
|
| 403 |
+
|
| 404 |
+
# ------------------------------------------------------------------------------------
|
| 405 |
+
|
| 406 |
+
console.print("\n[bold]Model Comparison: Baseline vs. L-RMC-core[/bold]")
|
| 407 |
+
|
| 408 |
+
# Create comparison table
|
| 409 |
+
comparison_table = Table(title="Performance Comparison", show_header=True, header_style="bold magenta")
|
| 410 |
+
comparison_table.add_column("Metric", style="cyan")
|
| 411 |
+
comparison_table.add_column("Baseline", style="magenta")
|
| 412 |
+
comparison_table.add_column("L-RMC-core", style="green")
|
| 413 |
+
comparison_table.add_column("Speedup", style="yellow")
|
| 414 |
+
|
| 415 |
+
# Add performance metrics
|
| 416 |
+
for metric in ["train", "val", "test"]:
|
| 417 |
+
comparison_table.add_row(
|
| 418 |
+
f"{metric.capitalize()} Accuracy",
|
| 419 |
+
f"{base_metrics[metric]:.4f}",
|
| 420 |
+
f"{core_metrics[metric]:.4f}",
|
| 421 |
+
"" # Speedup is not applicable for accuracy
|
| 422 |
+
)
|
| 423 |
+
|
| 424 |
+
# Add timing and speedup
|
| 425 |
+
baseline_time = t1 - t0
|
| 426 |
+
core_time = t3 - t2
|
| 427 |
+
speedup = baseline_time / core_time if core_time > 0 else float('inf')
|
| 428 |
+
|
| 429 |
+
comparison_table.add_row(
|
| 430 |
+
"Training Time (s)",
|
| 431 |
+
f"{baseline_time:.2f}",
|
| 432 |
+
f"{core_time:.2f}",
|
| 433 |
+
f"{speedup:.2f}x"
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
comparison_table.add_row(
|
| 437 |
+
"Speedup",
|
| 438 |
+
"1x",
|
| 439 |
+
f"{speedup:.2f}x",
|
| 440 |
+
""
|
| 441 |
+
)
|
| 442 |
+
|
| 443 |
+
console.print(comparison_table)
|
| 444 |
+
|
| 445 |
+
# Optional warm‑start finetune (single run)
|
| 446 |
+
if args.warm_ft_epochs > 0:
|
| 447 |
+
warm = GCN2(in_dim=ds.num_node_features,
|
| 448 |
+
hid=args.hidden,
|
| 449 |
+
out_dim=ds.num_classes,
|
| 450 |
+
dropout=args.dropout)
|
| 451 |
+
warm.load_state_dict(mC.state_dict())
|
| 452 |
+
|
| 453 |
+
t4 = time.perf_counter()
|
| 454 |
+
train(warm, data,
|
| 455 |
+
epochs=args.warm_ft_epochs,
|
| 456 |
+
lr=args.warm_ft_lr,
|
| 457 |
+
wd=args.wd,
|
| 458 |
+
patience=args.warm_ft_epochs + 1)
|
| 459 |
+
t5 = time.perf_counter()
|
| 460 |
+
warm_metrics = eval_all(warm, data)
|
| 461 |
+
|
| 462 |
+
console.print("\n[bold]Warm‑start finetune (start from core model, train on FULL graph):[/bold]")
|
| 463 |
+
warm_table = Table(show_header=True, header_style="bold magenta")
|
| 464 |
+
warm_table.add_column("Metric", style="cyan")
|
| 465 |
+
warm_table.add_column("Value", style="magenta")
|
| 466 |
+
warm_table.add_row("Train Accuracy", f"{warm_metrics['train']:.4f}")
|
| 467 |
+
warm_table.add_row("Validation Accuracy", f"{warm_metrics['val']:.4f}")
|
| 468 |
+
warm_table.add_row("Test Accuracy", f"{warm_metrics['test']:.4f}")
|
| 469 |
+
warm_table.add_row("Finetune Time (s)", f"{t5 - t4:.2f}")
|
| 470 |
+
warm_table.add_row("Total (core train + warm)", f"{(t3 - t2 + t5 - t4):.2f}s")
|
| 471 |
+
console.print(warm_table)
|
| 472 |
+
|
| 473 |
+
# Save warm single-run metrics
|
| 474 |
+
results["single_run"]["warm_finetune"] = {
|
| 475 |
+
"train": float(warm_metrics["train"]),
|
| 476 |
+
"val": float(warm_metrics["val"]),
|
| 477 |
+
"test": float(warm_metrics["test"]),
|
| 478 |
+
"finetune_time_s": float(t5 - t4),
|
| 479 |
+
"total_time_s": float((t3 - t2) + (t5 - t4)),
|
| 480 |
+
}
|
| 481 |
+
|
| 482 |
+
# Emit results for single-run
|
| 483 |
+
maybe_save_results()
|
| 484 |
+
else:
|
| 485 |
+
# --------------------------------------------------------
|
| 486 |
+
# Multi-run: average metrics across different seeds
|
| 487 |
+
# --------------------------------------------------------
|
| 488 |
+
runs = args.runs
|
| 489 |
+
console.print(f"\n[bold]Running {runs} seeds and averaging results[/bold]")
|
| 490 |
+
|
| 491 |
+
# Storage for metrics across runs
|
| 492 |
+
base_train, base_val, base_test, base_time = [], [], [], []
|
| 493 |
+
core_train, core_val, core_test, core_time = [], [], [], []
|
| 494 |
+
speedups = []
|
| 495 |
+
|
| 496 |
+
warm_train, warm_val, warm_test, warm_time, warm_total_time = [], [], [], [], []
|
| 497 |
+
|
| 498 |
+
data_C = subset_data(data, C_idx) if C_idx.numel() > 0 else None
|
| 499 |
+
results["core_empty"] = data_C is None
|
| 500 |
+
|
| 501 |
+
for r in range(runs):
|
| 502 |
+
set_seed(r)
|
| 503 |
+
|
| 504 |
+
# Baseline
|
| 505 |
+
t0 = time.perf_counter()
|
| 506 |
+
base = GCN2(in_dim=ds.num_node_features,
|
| 507 |
+
hid=args.hidden,
|
| 508 |
+
out_dim=ds.num_classes,
|
| 509 |
+
dropout=args.dropout)
|
| 510 |
+
train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
|
| 511 |
+
bm = eval_all(base, data)
|
| 512 |
+
t1 = time.perf_counter()
|
| 513 |
+
|
| 514 |
+
base_train.append(bm["train"]) ; base_val.append(bm["val"]) ; base_test.append(bm["test"]) ; base_time.append(t1 - t0)
|
| 515 |
+
|
| 516 |
+
# Core model
|
| 517 |
+
if data_C is None:
|
| 518 |
+
continue # no core available
|
| 519 |
+
|
| 520 |
+
t2 = time.perf_counter()
|
| 521 |
+
mC = GCN2(in_dim=ds.num_node_features,
|
| 522 |
+
hid=args.hidden,
|
| 523 |
+
out_dim=ds.num_classes,
|
| 524 |
+
dropout=args.dropout)
|
| 525 |
+
train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
|
| 526 |
+
t3 = time.perf_counter()
|
| 527 |
+
|
| 528 |
+
if args.core_mode == "forward":
|
| 529 |
+
mC.eval()
|
| 530 |
+
logits_full = mC(data.x, data.edge_index)
|
| 531 |
+
else:
|
| 532 |
+
mC.eval()
|
| 533 |
+
with torch.no_grad():
|
| 534 |
+
logits_C = mC(data_C.x, data_C.edge_index)
|
| 535 |
+
logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
|
| 536 |
+
logits_seed[C_idx] = logits_C
|
| 537 |
+
logits_full = appnp_seed_propagate(logits_seed,
|
| 538 |
+
data.edge_index,
|
| 539 |
+
K=args.K,
|
| 540 |
+
alpha=args.alpha)
|
| 541 |
+
|
| 542 |
+
cm = {
|
| 543 |
+
"train": accuracy(logits_full, data.y, data.train_mask),
|
| 544 |
+
"val": accuracy(logits_full, data.y, data.val_mask),
|
| 545 |
+
"test": accuracy(logits_full, data.y, data.test_mask),
|
| 546 |
+
}
|
| 547 |
+
|
| 548 |
+
core_train.append(cm["train"]) ; core_val.append(cm["val"]) ; core_test.append(cm["test"]) ; core_time.append(t3 - t2)
|
| 549 |
+
speedups.append((t1 - t0) / (t3 - t2 + 1e-9))
|
| 550 |
+
|
| 551 |
+
# Optional warm finetune per run
|
| 552 |
+
if args.warm_ft_epochs > 0:
|
| 553 |
+
warm = GCN2(in_dim=ds.num_node_features,
|
| 554 |
+
hid=args.hidden,
|
| 555 |
+
out_dim=ds.num_classes,
|
| 556 |
+
dropout=args.dropout)
|
| 557 |
+
warm.load_state_dict(mC.state_dict())
|
| 558 |
+
|
| 559 |
+
t4 = time.perf_counter()
|
| 560 |
+
train(warm, data,
|
| 561 |
+
epochs=args.warm_ft_epochs,
|
| 562 |
+
lr=args.warm_ft_lr,
|
| 563 |
+
wd=args.wd,
|
| 564 |
+
patience=args.warm_ft_epochs + 1)
|
| 565 |
+
t5 = time.perf_counter()
|
| 566 |
+
wm = eval_all(warm, data)
|
| 567 |
+
warm_train.append(wm["train"]) ; warm_val.append(wm["val"]) ; warm_test.append(wm["test"]) ; warm_time.append(t5 - t4)
|
| 568 |
+
warm_total_time.append((t3 - t2) + (t5 - t4))
|
| 569 |
+
|
| 570 |
+
# Helper to format mean ± std
|
| 571 |
+
def fmt(values, prec=4):
|
| 572 |
+
if not values:
|
| 573 |
+
return "n/a"
|
| 574 |
+
if len(values) == 1:
|
| 575 |
+
return f"{values[0]:.{prec}f}"
|
| 576 |
+
try:
|
| 577 |
+
return f"{mean(values):.{prec}f} ± {stdev(values):.{prec}f}"
|
| 578 |
+
except Exception:
|
| 579 |
+
m = sum(values) / len(values)
|
| 580 |
+
var = sum((v - m) ** 2 for v in values) / max(1, len(values) - 1)
|
| 581 |
+
return f"{m:.{prec}f} ± {var ** 0.5:.{prec}f}"
|
| 582 |
+
|
| 583 |
+
def stats(values):
|
| 584 |
+
"""Return dict with list, mean, std, count for JSON."""
|
| 585 |
+
d = {
|
| 586 |
+
"values": [float(v) for v in values],
|
| 587 |
+
"count": int(len(values)),
|
| 588 |
+
}
|
| 589 |
+
if len(values) >= 1:
|
| 590 |
+
d["mean"] = float(mean(values))
|
| 591 |
+
if len(values) >= 2:
|
| 592 |
+
d["std"] = float(stdev(values))
|
| 593 |
+
else:
|
| 594 |
+
d["std"] = None
|
| 595 |
+
return d
|
| 596 |
+
|
| 597 |
+
# Baseline summary
|
| 598 |
+
console.print("\n[bold]Baseline (averaged over runs):[/bold]")
|
| 599 |
+
base_table = Table(show_header=True, header_style="bold magenta")
|
| 600 |
+
base_table.add_column("Metric", style="cyan")
|
| 601 |
+
base_table.add_column("Mean ± Std", style="magenta")
|
| 602 |
+
base_table.add_row("Train Accuracy", fmt(base_train))
|
| 603 |
+
base_table.add_row("Validation Accuracy", fmt(base_val))
|
| 604 |
+
base_table.add_row("Test Accuracy", fmt(base_test))
|
| 605 |
+
base_table.add_row("Time (s)", fmt(base_time, prec=2))
|
| 606 |
+
console.print(base_table)
|
| 607 |
+
|
| 608 |
+
# Save baseline multi-run summary
|
| 609 |
+
results["multi_run"] = {
|
| 610 |
+
"runs": int(runs),
|
| 611 |
+
"baseline": {
|
| 612 |
+
"train": stats(base_train),
|
| 613 |
+
"val": stats(base_val),
|
| 614 |
+
"test": stats(base_test),
|
| 615 |
+
"time_s": stats(base_time),
|
| 616 |
+
}
|
| 617 |
+
}
|
| 618 |
+
|
| 619 |
+
if data_C is None:
|
| 620 |
+
console.print("[bold yellow]LRMC core is empty; no core runs to average.[/bold yellow]")
|
| 621 |
+
maybe_save_results()
|
| 622 |
+
return
|
| 623 |
+
|
| 624 |
+
# Core summary
|
| 625 |
+
console.print("\n[bold]LRMC‑core (averaged over runs):[/bold]")
|
| 626 |
+
core_table = Table(show_header=True, header_style="bold magenta")
|
| 627 |
+
core_table.add_column("Metric", style="cyan")
|
| 628 |
+
core_table.add_column("Mean ± Std", style="magenta")
|
| 629 |
+
core_table.add_row("Train Accuracy", fmt(core_train))
|
| 630 |
+
core_table.add_row("Validation Accuracy", fmt(core_val))
|
| 631 |
+
core_table.add_row("Test Accuracy", fmt(core_test))
|
| 632 |
+
core_table.add_row("Core Training Time (s)", fmt(core_time, prec=2))
|
| 633 |
+
core_table.add_row("Speedup vs. Baseline", fmt(speedups, prec=2))
|
| 634 |
+
console.print(core_table)
|
| 635 |
+
|
| 636 |
+
# Save core multi-run summary
|
| 637 |
+
results["multi_run"]["core_model"] = {
|
| 638 |
+
"mode": str(args.core_mode),
|
| 639 |
+
"train": stats(core_train),
|
| 640 |
+
"val": stats(core_val),
|
| 641 |
+
"test": stats(core_test),
|
| 642 |
+
"core_train_time_s": stats(core_time),
|
| 643 |
+
"speedup_vs_baseline": stats(speedups),
|
| 644 |
+
}
|
| 645 |
+
|
| 646 |
+
# Comparison summary
|
| 647 |
+
console.print("\n[bold]Model Comparison (averaged): Baseline vs. L-RMC-core[/bold]")
|
| 648 |
+
comparison_table = Table(title="Performance Comparison (Mean ± Std)", show_header=True, header_style="bold magenta")
|
| 649 |
+
comparison_table.add_column("Metric", style="cyan")
|
| 650 |
+
comparison_table.add_column("Baseline", style="magenta")
|
| 651 |
+
comparison_table.add_column("L-RMC-core", style="green")
|
| 652 |
+
comparison_table.add_column("Speedup", style="yellow")
|
| 653 |
+
|
| 654 |
+
for metric, b_vals, c_vals in [
|
| 655 |
+
("Train Accuracy", base_train, core_train),
|
| 656 |
+
("Validation Accuracy", base_val, core_val),
|
| 657 |
+
("Test Accuracy", base_test, core_test),
|
| 658 |
+
]:
|
| 659 |
+
comparison_table.add_row(metric, fmt(b_vals), fmt(c_vals), "")
|
| 660 |
+
|
| 661 |
+
comparison_table.add_row("Training Time (s)", fmt(base_time, prec=2), fmt(core_time, prec=2), fmt(speedups, prec=2))
|
| 662 |
+
comparison_table.add_row("Speedup", "1x", fmt(speedups, prec=2), "")
|
| 663 |
+
console.print(comparison_table)
|
| 664 |
+
|
| 665 |
+
# Optional warm summary
|
| 666 |
+
if args.warm_ft_epochs > 0 and warm_time:
|
| 667 |
+
console.print("\n[bold]Warm‑start finetune (averaged over runs):[/bold]")
|
| 668 |
+
warm_table = Table(show_header=True, header_style="bold magenta")
|
| 669 |
+
warm_table.add_column("Metric", style="cyan")
|
| 670 |
+
warm_table.add_column("Mean ± Std", style="magenta")
|
| 671 |
+
warm_table.add_row("Train Accuracy", fmt(warm_train))
|
| 672 |
+
warm_table.add_row("Validation Accuracy", fmt(warm_val))
|
| 673 |
+
warm_table.add_row("Test Accuracy", fmt(warm_test))
|
| 674 |
+
warm_table.add_row("Finetune Time (s)", fmt(warm_time, prec=2))
|
| 675 |
+
warm_table.add_row("Total (core train + warm)", fmt(warm_total_time, prec=2))
|
| 676 |
+
console.print(warm_table)
|
| 677 |
+
|
| 678 |
+
# Save warm multi-run summary
|
| 679 |
+
results["multi_run"]["warm_finetune"] = {
|
| 680 |
+
"train": stats(warm_train),
|
| 681 |
+
"val": stats(warm_val),
|
| 682 |
+
"test": stats(warm_test),
|
| 683 |
+
"finetune_time_s": stats(warm_time),
|
| 684 |
+
"total_time_s": stats(warm_total_time),
|
| 685 |
+
}
|
| 686 |
+
|
| 687 |
+
# Emit results for multi-run
|
| 688 |
+
maybe_save_results()
|
| 689 |
+
|
| 690 |
+
if __name__ == "__main__":
|
| 691 |
+
main()
|
src/2_epsilon_seed_sweep.py
ADDED
|
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import shutil
|
| 4 |
+
from typing import List, Set, Dict, Tuple
|
| 5 |
+
from decimal import Decimal, getcontext
|
| 6 |
+
from rich import print
|
| 7 |
+
|
| 8 |
+
from generate_lrmc_seeds import build_lrmc_single_graph
|
| 9 |
+
|
| 10 |
+
def get_seed_nodes(seeds_path: str) -> Set[int]:
    """Collect every seed node listed in a seeds JSON file.

    Each cluster may list its nodes under either the 'seed_nodes' or the
    'members' key; both are accepted.  On any read/parse failure an empty
    set is returned and the error is reported.
    """
    collected: Set[int] = set()
    try:
        with open(seeds_path, 'r') as handle:
            payload = json.load(handle)

        for cluster in payload.get('clusters', []):
            nodes = cluster.get('seed_nodes')
            if nodes is None:
                nodes = cluster.get('members', [])
            collected.update(nodes)

        return collected
    except (FileNotFoundError, json.JSONDecodeError, KeyError) as exc:
        print(f"[red]Error reading {seeds_path}: {exc}[/red]")
        return set()
|
| 31 |
+
|
| 32 |
+
def _format_eps_label(val: Decimal) -> str:
|
| 33 |
+
"""Return a stable, unique string label for epsilon values.
|
| 34 |
+
|
| 35 |
+
- Use integer string for integral values (e.g., '50000').
|
| 36 |
+
- Otherwise, use a compact decimal without trailing zeros.
|
| 37 |
+
This avoids duplicates like many '1e+04' when step is small.
|
| 38 |
+
"""
|
| 39 |
+
# Normalize to remove exponent if integral
|
| 40 |
+
if val == val.to_integral_value():
|
| 41 |
+
return str(val.to_integral_value())
|
| 42 |
+
# Use 'f' then strip trailing zeros/decimal point for uniqueness and readability
|
| 43 |
+
s = format(val, 'f')
|
| 44 |
+
if '.' in s:
|
| 45 |
+
s = s.rstrip('0').rstrip('.')
|
| 46 |
+
return s
|
| 47 |
+
|
| 48 |
+
def generate_epsilon_range(start: float, end: float, step: float) -> List[str]:
    """Generate epsilon values as unique, stable string labels.

    Decimal arithmetic is used so that repeated addition does not drift
    and distinct values never collapse onto the same label.

    Raises:
        ValueError: if *step* is not strictly positive.
    """
    if step <= 0:
        raise ValueError("epsilon_step must be > 0")

    getcontext().prec = 28
    lo = Decimal(str(start))
    hi = Decimal(str(end))
    inc = Decimal(str(step))

    labels: List[str] = []
    current = lo
    # Tiny slack so the endpoint survives decimal rounding.
    while current <= hi + Decimal('1e-18'):
        integral = current.to_integral_value()
        if current == integral:
            label = str(integral)
        else:
            label = format(current, 'f')
            if '.' in label:
                label = label.rstrip('0').rstrip('.')
        if not labels or labels[-1] != label:
            labels.append(label)
        current += inc
    return labels
|
| 70 |
+
|
| 71 |
+
def run_epsilon_sweep(input_edgelist: str, out_dir: str, levels: int,
                      epsilon_start: float = 1e4, epsilon_end: float = 5e5,
                      epsilon_step: float = 1e4, cleanup_duplicates: bool = True):
    """
    Run LRMC for multiple epsilon values, optionally removing duplicate results.

    Args:
        input_edgelist: Path to input edgelist file
        out_dir: Output directory
        levels: Number of levels to build
        epsilon_start: Starting epsilon value (default: 1e4)
        epsilon_end: Ending epsilon value (default: 5e5)
        epsilon_step: Step size for epsilon (default: 1e4)
        cleanup_duplicates: Whether to remove duplicate seed sets (default: True)

    Raises:
        FileNotFoundError: if the input edgelist cannot be located.
    """
    print(f"[blue]Starting epsilon sweep from {epsilon_start} to {epsilon_end} with step {epsilon_step}[/blue]")

    # Preflight: check input edgelist path and fix a common typo ('.tx' -> '.txt').
    if not os.path.isfile(input_edgelist):
        fixed_path = None
        if input_edgelist.endswith('.tx') and os.path.isfile(input_edgelist + 't'):
            fixed_path = input_edgelist + 't'
        elif input_edgelist.endswith('.txt'):
            # Try relative to CWD if a bare filename was intended
            alt = os.path.join(os.getcwd(), input_edgelist)
            if os.path.isfile(alt):
                fixed_path = alt
        if fixed_path:
            print(f"[yellow]Input edgelist not found at '{input_edgelist}'. Using '{fixed_path}' instead.[/yellow]")
            input_edgelist = fixed_path
        else:
            raise FileNotFoundError(f"Input edgelist not found: '{input_edgelist}'. Did you mean '.txt'?")

    # Generate epsilon labels (stable strings, no float drift).
    epsilons = generate_epsilon_range(epsilon_start, epsilon_end, epsilon_step)
    print(f"[blue]Will test {len(epsilons)} epsilon values: {epsilons}[/blue]")

    # Map each distinct (sorted) seed-node tuple to the first epsilon that produced it.
    seen_seed_sets: Dict[Tuple[int, ...], str] = {}

    for epsilon in epsilons:
        print(f"[yellow]Processing epsilon: {epsilon}[/yellow]")
        temp_out_dir = f"{out_dir}_temp_{epsilon}"
        try:
            # Run LRMC into a temporary directory for this epsilon.
            seeds_path = build_lrmc_single_graph(
                input_edgelist=input_edgelist,
                out_dir=temp_out_dir,
                levels=levels,
                epsilon=epsilon
            )

            seed_nodes = get_seed_nodes(seeds_path)
            seed_nodes_tuple = tuple(sorted(seed_nodes))
            print(f"[green]Epsilon {epsilon}: Found {len(seed_nodes)} unique seed nodes[/green]")

            duplicate_of = seen_seed_sets.get(seed_nodes_tuple)
            if duplicate_of is not None and cleanup_duplicates:
                # BUGFIX: the cleanup_duplicates flag was previously accepted
                # but ignored — duplicates were always removed. It is honored now.
                print(f"[yellow]Duplicate seed set found! Epsilon {epsilon} has same seeds as {duplicate_of}[/yellow]")
                print(f"[yellow]Removing duplicate results for epsilon {epsilon}[/yellow]")
                if os.path.exists(temp_out_dir):
                    shutil.rmtree(temp_out_dir)
                continue

            # Record only the first epsilon that produced this seed set.
            if duplicate_of is None:
                seen_seed_sets[seed_nodes_tuple] = epsilon

            # Move results to their final location.
            final_out_dir = f"{out_dir}_epsilon_{epsilon}"
            if os.path.exists(final_out_dir):
                shutil.rmtree(final_out_dir)
            shutil.move(temp_out_dir, final_out_dir)

            # Move seeds_<epsilon>.json into the shared stage0 directory.
            seeds_file = os.path.join(final_out_dir, "stage0", f"seeds_{epsilon}.json")
            if os.path.exists(seeds_file):
                stage0_dir = os.path.join(out_dir, "stage0")
                os.makedirs(stage0_dir, exist_ok=True)
                dest_path = os.path.join(stage0_dir, f"seeds_{epsilon}.json")
                shutil.move(seeds_file, dest_path)
                # BUGFIX: this message used to reference stage0_dir outside the
                # branch that defined it (NameError when the seeds file was
                # missing) and nested double quotes inside a double-quoted
                # f-string (SyntaxError before Python 3.12).
                print(f"[green]Unique results saved to {dest_path}[/green]")

        except Exception as e:
            print(f"[red]Error processing epsilon {epsilon}: {e}[/red]")
            # Clean up temporary directory if it exists
            if os.path.exists(temp_out_dir):
                shutil.rmtree(temp_out_dir)

    # Print summary
    print("\n[blue]--- Summary ---[/blue]")
    print(f"[blue]Total epsilon values tested: {len(epsilons)}[/blue]")
    print(f"[blue]Unique seed sets found: {len(seen_seed_sets)}[/blue]")
    print(f"[blue]Duplicates removed: {len(epsilons) - len(seen_seed_sets)}[/blue]")

    if seen_seed_sets:
        print("\n[green]Unique epsilon values kept:[/green]")
        for seed_tuple, eps in sorted(seen_seed_sets.items()):
            print(f"  {eps}: {len(seed_tuple)} seed nodes")
|
| 181 |
+
|
| 182 |
+
def main():
    """Parse command-line arguments and launch the epsilon sweep."""
    import argparse

    parser = argparse.ArgumentParser(description="Run LRMC epsilon sweep with duplicate removal")
    # Declarative option table keeps the CLI surface easy to scan.
    option_specs = [
        ('--input_edgelist', dict(type=str, required=True,
                                  help='Path to input edgelist file')),
        ('--out_dir', dict(type=str, required=True,
                           help='Base output directory (results will be saved as out_dir_epsilon_X)')),
        ('--levels', dict(type=int, required=True,
                          help='Number of levels to build')),
        ('--epsilon_start', dict(type=float, default=1e4,
                                 help='Starting epsilon value (default: 1e4)')),
        ('--epsilon_end', dict(type=float, default=5e5,
                               help='Ending epsilon value (default: 5e5)')),
        ('--epsilon_step', dict(type=float, default=1e4,
                                help='Epsilon step size (default: 1e4)')),
        ('--no_cleanup', dict(action='store_true',
                              help='Do not remove duplicates (keep all results)')),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)

    args = parser.parse_args()

    run_epsilon_sweep(
        input_edgelist=args.input_edgelist,
        out_dir=args.out_dir,
        levels=args.levels,
        epsilon_start=args.epsilon_start,
        epsilon_end=args.epsilon_end,
        epsilon_step=args.epsilon_step,
        cleanup_duplicates=not args.no_cleanup,
    )

if __name__ == '__main__':
    main()
|
src/2_lrmc_bilevel.py
CHANGED
|
@@ -1,6 +1,4 @@
|
|
| 1 |
-
# lrmc_bilevel.py
|
| 2 |
# Bi-level Node↔Cluster message passing with fixed LRMC seeds
|
| 3 |
-
# Requires: torch, torch_geometric, torch_sparse
|
| 4 |
|
| 5 |
import argparse, json, os
|
| 6 |
from pathlib import Path
|
|
@@ -19,6 +17,8 @@ from torch_geometric.loader import DataLoader
|
|
| 19 |
from torch_geometric.datasets import Planetoid, TUDataset
|
| 20 |
from torch_geometric.nn import GCNConv, global_mean_pool
|
| 21 |
|
|
|
|
|
|
|
| 22 |
|
| 23 |
# ---------------------------
|
| 24 |
# Utilities: edges and seeds
|
|
|
|
|
|
|
| 1 |
# Bi-level Node↔Cluster message passing with fixed LRMC seeds
|
|
|
|
| 2 |
|
| 3 |
import argparse, json, os
|
| 4 |
from pathlib import Path
|
|
|
|
| 17 |
from torch_geometric.datasets import Planetoid, TUDataset
|
| 18 |
from torch_geometric.nn import GCNConv, global_mean_pool
|
| 19 |
|
| 20 |
+
from rich import print
|
| 21 |
+
|
| 22 |
|
| 23 |
# ---------------------------
|
| 24 |
# Utilities: edges and seeds
|
src/2_random_seed_sweep.sh
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env bash
|
| 2 |
+
# ------------------------------------------------------------
|
| 3 |
+
# generate_lrmc_seeds.sh
|
| 4 |
+
#
|
| 5 |
+
# Run generate_lrmc_seeds.py for num_nodes = 1, 101, 201, …,
|
| 6 |
+
# 2601, 2701 and then once more for 2708.
|
| 7 |
+
#
|
| 8 |
+
# Usage: ./generate_lrmc_seeds.sh
|
| 9 |
+
# ------------------------------------------------------------
|
| 10 |
+
|
| 11 |
+
set -euo pipefail # Exit on error, undefined var, or pipe failure
|
| 12 |
+
|
| 13 |
+
# ------------------------------------------------------------------
|
| 14 |
+
# Configuration – adjust these if your file layout changes
|
| 15 |
+
# ------------------------------------------------------------------
|
| 16 |
+
INPUT_EDGELIST="cora_seeds/edgelist.txt"
|
| 17 |
+
OUT_DIR="cora_seeds"
|
| 18 |
+
LEVELS=1
|
| 19 |
+
BASELINE="random"
|
| 20 |
+
|
| 21 |
+
# ------------------------------------------------------------------
|
| 22 |
+
# Main loop – 1 … 2708 stepping by 100
|
| 23 |
+
# ------------------------------------------------------------------
|
| 24 |
+
echo "Starting loop over num_nodes = 1, 101, 201, …, 2601, 2701 …"
|
| 25 |
+
|
| 26 |
+
for NUM_NODES in $(seq 1 100 2708); do
|
| 27 |
+
echo ">>> num_nodes=$NUM_NODES"
|
| 28 |
+
python3 generate_lrmc_seeds.py \
|
| 29 |
+
--input_edgelist "$INPUT_EDGELIST" \
|
| 30 |
+
--out_dir "$OUT_DIR" \
|
| 31 |
+
--levels "$LEVELS" \
|
| 32 |
+
--baseline "$BASELINE" \
|
| 33 |
+
--num_nodes "$NUM_NODES"
|
| 34 |
+
done
|
| 35 |
+
|
| 36 |
+
# ------------------------------------------------------------------
|
| 37 |
+
# Explicitly run for 2708 (not hit by the 100‑step sequence)
|
| 38 |
+
# ------------------------------------------------------------------
|
| 39 |
+
echo ">>> num_nodes=2708 (explicitly added)"
|
| 40 |
+
python3 generate_lrmc_seeds.py \
|
| 41 |
+
--input_edgelist "$INPUT_EDGELIST" \
|
| 42 |
+
--out_dir "$OUT_DIR" \
|
| 43 |
+
--levels "$LEVELS" \
|
| 44 |
+
--baseline "$BASELINE" \
|
| 45 |
+
--num_nodes 2708
|
| 46 |
+
|
| 47 |
+
echo "All done!"
|