qingy2024 commited on
Commit
f74dd01
·
verified ·
1 Parent(s): a35be16

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +11 -0
  2. GCond/.gitignore +140 -0
  3. GCond/GCond.png +0 -0
  4. GCond/KDD22_DosCond/README.md +38 -0
  5. GCond/README.md +139 -0
  6. GCond/__pycache__/configs.cpython-312.pyc +0 -0
  7. GCond/__pycache__/utils.cpython-312.pyc +0 -0
  8. GCond/__pycache__/utils_graphsaint.cpython-312.pyc +0 -0
  9. GCond/configs.py +24 -0
  10. GCond/coreset/__init__.py +3 -0
  11. GCond/coreset/__pycache__/__init__.cpython-312.pyc +0 -0
  12. GCond/coreset/__pycache__/all_methods.cpython-312.pyc +0 -0
  13. GCond/coreset/all_methods.py +212 -0
  14. GCond/gcond_agent_induct.py +327 -0
  15. GCond/gcond_agent_transduct.py +326 -0
  16. GCond/models/__pycache__/gcn.cpython-312.pyc +0 -0
  17. GCond/models/gat.py +312 -0
  18. GCond/models/gcn.py +404 -0
  19. GCond/models/myappnp.py +344 -0
  20. GCond/models/myappnp1.py +348 -0
  21. GCond/models/mycheby.py +417 -0
  22. GCond/models/mygatconv.py +203 -0
  23. GCond/models/mygraphsage.py +353 -0
  24. GCond/models/parametrized_adj.py +88 -0
  25. GCond/models/sgc.py +290 -0
  26. GCond/models/sgc_multi.py +315 -0
  27. GCond/requirements.txt +11 -0
  28. GCond/res/cross/empty +1 -0
  29. GCond/script.sh +4 -0
  30. GCond/scripts/run_cross.sh +16 -0
  31. GCond/scripts/run_main.sh +26 -0
  32. GCond/scripts/script_cross.sh +7 -0
  33. GCond/test_other_arcs.py +55 -0
  34. GCond/tester_other_arcs.py +258 -0
  35. GCond/train_coreset.py +117 -0
  36. GCond/train_coreset_induct.py +119 -0
  37. GCond/train_gcond_induct.py +61 -0
  38. GCond/train_gcond_transduct.py +57 -0
  39. GCond/utils.py +383 -0
  40. GCond/utils_graphsaint.py +220 -0
  41. requirements.txt +2 -2
  42. src/2.1_lrmc_bilevel.py +2 -2
  43. src/2.2_lrmc_bilevel.py +2 -1
  44. src/2.3_lrmc_bilevel.py +2 -1
  45. src/2.4_lrmc_bilevel.py +2 -1
  46. src/2.5_lrmc_bilevel.py +540 -0
  47. src/2.6_lrmc_summary.py +691 -0
  48. src/2_epsilon_seed_sweep.py +215 -0
  49. src/2_lrmc_bilevel.py +2 -2
  50. src/2_random_seed_sweep.sh +47 -0
.gitignore CHANGED
@@ -47,3 +47,14 @@ GraphUNets/data/*
47
  temp*
48
 
49
  *.tex
 
 
 
 
 
 
 
 
 
 
 
 
47
  temp*
48
 
49
  *.tex
50
+
51
+ src/data/*
52
+
53
+ src/cora_seeds/*
54
+
55
+ .venv/*
56
+
57
+
58
+ .ipynb_checkpoints/
59
+
60
+ *.dot
GCond/.gitignore ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Auxillary file on MacOS
2
+ .DS_Store
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ models/__pycache__/
7
+ *.py[cod]
8
+ dataset
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ data/
16
+ build/
17
+ develop-eggs/
18
+ dist/
19
+ downloads/
20
+ eggs/
21
+ .eggs/
22
+ lib/
23
+ lib64/
24
+ parts/
25
+ sdist/
26
+ var/
27
+ wheels/
28
+ pip-wheel-metadata/
29
+ share/python-wheels/
30
+ *.egg-info/
31
+ .installed.cfg
32
+ *.egg
33
+ MANIFEST
34
+
35
+ # PyInstaller
36
+ # Usually these files are written by a python script from a template
37
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
38
+ *.manifest
39
+ *.spec
40
+
41
+ # Installer logs
42
+ pip-log.txt
43
+ pip-delete-this-directory.txt
44
+
45
+ # Unit test / coverage reports
46
+ htmlcov/
47
+ .tox/
48
+ .nox/
49
+ .coverage
50
+ .coverage.*
51
+ .cache
52
+ nosetests.xml
53
+ coverage.xml
54
+ *.cover
55
+ *.py,cover
56
+ .hypothesis/
57
+ .pytest_cache/
58
+
59
+ # Translations
60
+ *.mo
61
+ *.pot
62
+
63
+ # Django stuff:
64
+ *.log
65
+ local_settings.py
66
+ db.sqlite3
67
+ db.sqlite3-journal
68
+
69
+ # Flask stuff:
70
+ instance/
71
+ .webassets-cache
72
+
73
+ # Scrapy stuff:
74
+ .scrapy
75
+
76
+ # Sphinx documentation
77
+ docs/_build/
78
+
79
+ # PyBuilder
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100
+ __pypackages__/
101
+
102
+ # Celery stuff
103
+ celerybeat-schedule
104
+ celerybeat.pid
105
+
106
+ # SageMath parsed files
107
+ *.sage.py
108
+
109
+ # Environments
110
+ .env
111
+ .venv
112
+ env/
113
+ venv/
114
+ ENV/
115
+ env.bak/
116
+ venv.bak/
117
+
118
+ # Spyder project settings
119
+ .spyderproject
120
+ .spyproject
121
+
122
+ # Rope project settings
123
+ .ropeproject
124
+
125
+ # mkdocs documentation
126
+ /site
127
+
128
+ # mypy
129
+ .mypy_cache/
130
+ .dmypy.json
131
+ dmypy.json
132
+
133
+ # Pyre type checker
134
+ .pyre/
135
+ *$py.class
136
+
137
+
138
+
139
+
140
+
GCond/GCond.png ADDED
GCond/KDD22_DosCond/README.md ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # DosCond
2
+
3
+ [KDD 2022] A PyTorch Implementation for ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746) under node classification setting. For graph classification setting, please refer to [https://github.com/amazon-research/DosCond](https://github.com/amazon-research/DosCond).
4
+
5
+
6
+ Abstract
7
+ ----
8
+ As training deep learning models on large dataset takes a lot of time and resources, it is desired to construct a small synthetic dataset with which we can train deep learning models sufficiently. There are recent works that have explored solutions on condensing image datasets through complex bi-level optimization. For instance, dataset condensation (DC) matches network gradients w.r.t. large-real data and small-synthetic data, where the network weights are optimized for multiple steps at each outer iteration. However, existing approaches have their inherent limitations: (1) they are not directly applicable to graphs where the data is discrete; and (2) the condensation process is computationally expensive due to the involved nested optimization. To bridge the gap, we investigate efficient dataset condensation tailored for graph datasets where we model the discrete graph structure as a probabilistic model. We further propose a one-step gradient matching scheme, which performs gradient matching for only one single step without training the network weights.
9
+
10
+
11
+ Here we do not implement the discrete structure learning, but only borrow the idea from ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746) to perform one-step gradient matching, which significantly speeds up the condensation process.
12
+
13
+
14
+ Essentially, we can run the following commands:
15
+ ```
16
+ python train_gcond_transduct.py --dataset citeseer --nlayers=2 --lr_feat=1e-3 --lr_adj=1e-3 --r=0.5 --sgc=0 --dis=mse --one_step=1 --epochs=3000
17
+ python train_gcond_transduct.py --dataset cora --nlayers=2 --lr_feat=1e-3 --lr_adj=1e-3 --r=0.5 --sgc=0 --dis=mse --gpu_id=2 --one_step=1 --epochs=5000
18
+ python train_gcond_transduct.py --dataset pubmed --nlayers=2 --lr_feat=1e-3 --lr_adj=1e-3 --r=0.5 --sgc=0 --dis=mse --gpu_id=2 --one_step=1 --epochs=2000
19
+ python train_gcond_transduct.py --dataset ogbn-arxiv --nlayers=2 --lr_feat=1e-2 --lr_adj=2e-2 --r=0.001 --sgc=1 --dis=ours --gpu_id=2 --one_step=1 --epochs=1000
20
+ python train_gcond_induct.py --dataset flickr --nlayers=2 --lr_feat=5e-3 --lr_adj=5e-3 --r=0.001 --sgc=0 --dis=mse --gpu_id=3 --one_step=1 --epochs=1000
21
+ ```
22
+ Note that using smaller learning rate and larger epochs can get even higher performance.
23
+
24
+
25
+ ## Cite
26
+ For more information, you can take a look at the [paper](https://arxiv.org/abs/2206.07746).
27
+
28
+ If you find this repo to be useful, please cite our paper. Thank you.
29
+ ```
30
+ @inproceedings{jin2022condensing,
31
+ title={Condensing Graphs via One-Step Gradient Matching},
32
+ author={Jin, Wei and Tang, Xianfeng and Jiang, Haoming and Li, Zheng and Zhang, Danqing and Tang, Jiliang and Yin, Bing},
33
+ booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
34
+ pages={720--730},
35
+ year={2022}
36
+ }
37
+ ```
38
+
GCond/README.md ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # GCond
2
+ [ICLR 2022] The PyTorch implementation for ["Graph Condensation for Graph Neural Networks"](https://cse.msu.edu/~jinwei2/files/GCond.pdf) is provided under the main directory.
3
+
4
+ [KDD 2022] The implementation for ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746) is shown in the `KDD22_DosCond` directory. See [link](https://github.com/ChandlerBang/GCond/tree/main/KDD22_DosCond).
5
+
6
+ [IJCAI 2024] Please read our recent survey ["A Comprehensive Survey on Graph Reduction: Sparsification, Coarsening, and Condensation"](https://arxiv.org/abs/2402.03358) for a detailed review of graph reduction techniques!
7
+
8
+ [ArXiv 2024] We released a benchmarking framework for graph condensation ["GC4NC: A Benchmark Framework for Graph Condensation with New Insights"](https://arxiv.org/abs/2406.16715), including **robustness**, **privacy preservation**, NAS performance, property analysis, etc!
9
+
10
+ Abstract
11
+ ----
12
+ We propose and study the problem of graph condensation for graph neural networks (GNNs). Specifically, we aim to condense the large, original graph into a small, synthetic, and highly-informative graph, such that GNNs trained on the small graph and large graph have comparable performance. Extensive experiments have demonstrated the effectiveness of the proposed framework in condensing different graph datasets into informative smaller graphs. In particular, we are able to approximate the original test accuracy by 95.3% on Reddit, 99.8% on Flickr and 99.0% on Citeseer, while reducing their graph size by more than 99.9%, and the condensed graphs can be used to train various GNN architectures.
13
+
14
+
15
+ ![]()
16
+
17
+ <div align=center><img src="https://github.com/ChandlerBang/GCond/blob/main/GCond.png" width="800"/></div>
18
+
19
+
20
+ ## A Nice Survey Paper
21
+ Please check out our survey paper below, which summarizes the recent advances in graph condensation.
22
+
23
+ ![image](https://github.com/CurryTang/Towards-Graph-Foundation-Models-New-perspective-/assets/15672123/89a23a37-71d4-47f7-8949-7d859a41e369) [[A Comprehensive Survey on Graph Reduction: Sparsification, Coarsening, and Condensation]](https://arxiv.org/abs/2402.03358)
24
+
25
+
26
+
27
+ ## Requirements
28
+ Please see [requirements.txt](https://github.com/ChandlerBang/GCond/blob/main/requirements.txt).
29
+ ```
30
+ torch==1.7.0
31
+ torch_geometric==1.6.3
32
+ scipy==1.6.2
33
+ numpy==1.19.2
34
+ ogb==1.3.0
35
+ tqdm==4.59.0
36
+ torch_sparse==0.6.9
37
+ deeprobust==0.2.4
38
+ scikit_learn==1.0.2
39
+ ```
40
+
41
+ ## Download Datasets
42
+ For cora, citeseer and pubmed, the code will directly download them; so no extra script is needed.
43
+ For reddit, flickr and arxiv, we use the datasets provided by [GraphSAINT](https://github.com/GraphSAINT/GraphSAINT).
44
+ They are available on [Google Drive link](https://drive.google.com/open?id=1zycmmDES39zVlbVCYs88JTJ1Wm5FbfLz) (alternatively, [BaiduYun link (code: f1ao)](https://pan.baidu.com/s/1SOb0SiSAXavwAcNqkttwcg)). Rename the folder to `data` at the root directory. Note that the links are provided by GraphSAINT team.
45
+
46
+
47
+
48
+
49
+ ## Run the code
50
+ For transductive setting, please run the following command:
51
+ ```
52
+ python train_gcond_transduct.py --dataset cora --nlayers=2 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=0.5
53
+ ```
54
+ where `r` indicates the ratio of condensed samples to the labeled samples. For instance, there are only 140 labeled nodes in Cora dataset, so `r=0.5` indicates the number of condensed samples are 70, **which corresponds to r=2.6%=70/2710 in our paper**. Thus, the parameter `r` is different from the real reduction rate in the paper for the transductive setting, please see the following table for the correspondence.
55
+
56
+ | | `r` in the code | `r` in the paper (real reduction rate) |
57
+ |--------------|-------------------|---------------------|
58
+ | Transductive | Cora, r=0.5 | Cora, r=2.6% |
59
+ | Transductive | Citeseer, r=0.5 | Citeseer, r=1.8% |
60
+ | Transductive | Ogbn-arxiv, r= 0.005 | Ogbn-arxiv, r=0.25% |
61
+ | Transductive | Pubmed, r=0.5 | Pubmed, r=0.3% |
62
+ | Inductive | Flickr, r=0.01 | Flickr, r=1% |
63
+ | Inductive | Reddit, r=0.001 | Reddit, r=0.1% |
64
+
65
+ For inductive setting, please run the following command:
66
+ ```
67
+ python train_gcond_induct.py --dataset flickr --nlayers=2 --lr_feat=0.01 --gpu_id=0 --lr_adj=0.01 --r=0.005 --epochs=1000 --outer=10 --inner=1
68
+ ```
69
+
70
+ ## Reproduce the performance
71
+ The generated graphs are saved in the folder `saved_ours`; you can directly load them to test the performance.
72
+
73
+ For Table 2, run `bash scripts/run_main.sh`.
74
+
75
+ For Table 3, run `bash scripts/run_cross.sh`.
76
+
77
+ ## [Faster Condensation!] One-Step Gradient Matching
78
+ From the KDD'22 paper ["Condensing Graphs via One-Step Gradient Matching"](https://arxiv.org/abs/2206.07746), we know that performing gradient matching for only one step can also achieve a good performance while significantly accelerating the condensation process. Hence, we can run the following command to perform one-step gradient matching, which is essentially much faster than the original version:
79
+ ```
80
+ python train_gcond_transduct.py --dataset citeseer --nlayers=2 --lr_feat=1e-2 --lr_adj=1e-2 --r=0.5 \
81
+ --sgc=0 --dis=mse --gpu_id=2 --one_step=1 --epochs=3000
82
+ ```
83
+ For more commands, please go to [`KDD22_DosCond`](https://github.com/ChandlerBang/GCond/tree/main/KDD22_DosCond).
84
+
85
+ **[Note]: I found that sometimes using MSE loss for gradient matching can be more stable than using `ours` loss**, and it gives more flexibility on the model used in condensation (using GCN as the backbone can also generate good condensed graphs).
86
+
87
+
88
+ ## Whole Dataset Performance
89
+ When we do coreset selection, we need to first train the model on the whole dataset. Thus we can obtain the performance of the whole dataset by running `train_coreset.py` and `train_coreset_induct.py`:
90
+ ```
91
+ python train_coreset.py --dataset cora --r=0.01 --method=random
92
+ python train_coreset_induct.py --dataset flickr --r=0.01 --method=random
93
+ ```
94
+
95
+ ## Coreset Performance
96
+ Run the following code to get the coreset performance for transductive setting.
97
+ ```
98
+ python train_coreset.py --dataset cora --r=0.01 --method=herding
99
+ python train_coreset.py --dataset cora --r=0.01 --method=random
100
+ python train_coreset.py --dataset cora --r=0.01 --method=kcenter
101
+ ```
102
+ Similarly, run the following code for the inductive setting.
103
+ ```
104
+ python train_coreset_induct.py --dataset flickr --r=0.01 --method=kcenter
105
+ ```
106
+
107
+
108
+ ## Cite
109
+ If you find this repo to be useful, please cite our three papers. Thank you!
110
+ ```
111
+ @inproceedings{
112
+ jin2022graph,
113
+ title={Graph Condensation for Graph Neural Networks},
114
+ author={Wei Jin and Lingxiao Zhao and Shichang Zhang and Yozen Liu and Jiliang Tang and Neil Shah},
115
+ booktitle={International Conference on Learning Representations},
116
+ year={2022},
117
+ url={https://openreview.net/forum?id=WLEx3Jo4QaB}
118
+ }
119
+ ```
120
+
121
+ ```
122
+ @inproceedings{jin2022condensing,
123
+ title={Condensing Graphs via One-Step Gradient Matching},
124
+ author={Jin, Wei and Tang, Xianfeng and Jiang, Haoming and Li, Zheng and Zhang, Danqing and Tang, Jiliang and Yin, Bing},
125
+ booktitle={Proceedings of the 28th ACM SIGKDD Conference on Knowledge Discovery and Data Mining},
126
+ pages={720--730},
127
+ year={2022}
128
+ }
129
+ ```
130
+
131
+ ```
132
+ @article{hashemi2024comprehensive,
133
+ title={A Comprehensive Survey on Graph Reduction: Sparsification, Coarsening, and Condensation},
134
+ author={Hashemi, Mohammad and Gong, Shengbo and Ni, Juntong and Fan, Wenqi and Prakash, B Aditya and Jin, Wei},
135
+ journal={International Joint Conference on Artificial Intelligence (IJCAI)},
136
+ year={2024}
137
+ }
138
+ ```
139
+
GCond/__pycache__/configs.cpython-312.pyc ADDED
Binary file (706 Bytes). View file
 
GCond/__pycache__/utils.cpython-312.pyc ADDED
Binary file (20.5 kB). View file
 
GCond/__pycache__/utils_graphsaint.cpython-312.pyc ADDED
Binary file (12.2 kB). View file
 
GCond/configs.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
'''Configuration'''


def load_config(args):
    """Apply per-dataset hyperparameter overrides to ``args`` in place.

    Datasets without an entry below keep whatever values ``args`` already
    carries (e.g. from the command line). Returns the same ``args`` object.
    """
    # Per-dataset hyperparameter overrides; values mirror the original
    # hand-written if-chains exactly.
    overrides = {
        'flickr': dict(nlayers=2, hidden=256, weight_decay=5e-3, dropout=0.0),
        'reddit': dict(nlayers=2, hidden=256, weight_decay=0e-4, dropout=0),
        'ogbn-arxiv': dict(hidden=256, weight_decay=0, dropout=0),
    }
    for name, value in overrides.get(args.dataset, {}).items():
        setattr(args, name, value)
    return args
GCond/coreset/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .all_methods import KCenter, Herding, Random, LRMC
2
+
3
+ __all__ = ['KCenter', 'Herding', 'Random', 'LRMC']
GCond/coreset/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (260 Bytes). View file
 
GCond/coreset/__pycache__/all_methods.cpython-312.pyc ADDED
Binary file (12.1 kB). View file
 
GCond/coreset/all_methods.py ADDED
@@ -0,0 +1,212 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ import json
4
+
5
+
class Base:
    """Shared scaffolding for coreset-selection strategies.

    From the training labels and ``args.reduction_rate`` it derives how many
    nodes each class contributes to the reduced set, and materialises the
    corresponding synthetic label vector on ``device``.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        self.data = data
        self.args = args
        self.device = device
        n = int(data.feat_train.shape[0] * args.reduction_rate)
        d = data.feat_train.shape[1]
        self.nnodes_syn = n
        self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)

    def generate_labels_syn(self, data):
        """Build the reduced label vector, classes ordered by frequency.

        Side effects: fills ``self.num_class_dict`` (per-class quota) and
        ``self.syn_class_indices`` (per-class [start, end) span in the
        returned list).
        """
        from collections import Counter
        class_counts = Counter(data.labels_train)
        n_train = len(data.labels_train)
        rate = self.args.reduction_rate

        # Least frequent classes first; the most frequent class comes last.
        by_frequency = sorted(class_counts.items(), key=lambda item: item[1])
        num_class_dict = {}
        labels_syn = []
        self.syn_class_indices = {}
        allocated = 0
        for position, (cls, count) in enumerate(by_frequency):
            if position == len(by_frequency) - 1:
                # The largest class absorbs the rounding slack so the total
                # matches int(n_train * rate) exactly.
                quota = int(n_train * rate) - allocated
            else:
                # Every class keeps at least one representative.
                quota = max(int(count * rate), 1)
                allocated += quota
            num_class_dict[cls] = quota
            start = len(labels_syn)
            self.syn_class_indices[cls] = [start, start + quota]
            labels_syn.extend([cls] * quota)

        self.num_class_dict = num_class_dict
        return labels_syn

    def select(self):
        # Placeholder; concrete strategies override this.
        return
class KCenter(Base):
    """K-Center greedy coreset selection, performed class by class.

    Seeds each class with the point nearest the class centroid, then
    repeatedly adds the point farthest from its nearest chosen center.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        # Bug fix: forward the caller-supplied device instead of hard-coding
        # 'cuda' (the original ignored the `device` argument; LRMC already
        # forwards it correctly).
        super(KCenter, self).__init__(data, args, device=device, **kwargs)

    def select(self, embeds, inductive=False):
        """Return selected training-node indices based on embeddings.

        `embeds` is indexed by the (global or inductive-local) training ids;
        per-class counts come from ``self.num_class_dict``.
        """
        num_class_dict = self.num_class_dict
        if inductive:
            # In the inductive setting embeddings are positional over idx_train.
            idx_train = np.arange(len(self.data.idx_train))
        else:
            idx_train = self.data.idx_train
        labels_train = self.data.labels_train
        idx_selected = []

        for class_id, cnt in num_class_dict.items():
            idx = idx_train[labels_train == class_id]
            feature = embeds[idx]
            mean = torch.mean(feature, dim=0, keepdim=True)
            # Seed: point closest to the class centroid.
            dis = torch.cdist(feature, mean)[:, 0]
            rank = torch.argsort(dis)
            idx_centers = rank[:1].tolist()
            # Greedy k-center: add the point maximising the distance to its
            # nearest already-chosen center.
            for _ in range(cnt - 1):
                feature_centers = feature[idx_centers]
                dis_center = torch.cdist(feature, feature_centers)
                dis_min, _ = torch.min(dis_center, dim=-1)
                id_max = torch.argmax(dis_min).item()
                idx_centers.append(id_max)

            idx_selected.append(idx[idx_centers])
        return np.hstack(idx_selected)
class Herding(Base):
    """Herding coreset selection, performed class by class.

    Greedily picks points so that the running sum of selected features tracks
    the (scaled) class mean as closely as possible.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        # Bug fix: forward the caller-supplied device instead of hard-coding
        # 'cuda' (the original ignored the `device` argument; LRMC already
        # forwards it correctly).
        super(Herding, self).__init__(data, args, device=device, **kwargs)

    def select(self, embeds, inductive=False):
        """Return selected training-node indices based on embeddings."""
        num_class_dict = self.num_class_dict
        if inductive:
            # In the inductive setting embeddings are positional over idx_train.
            idx_train = np.arange(len(self.data.idx_train))
        else:
            idx_train = self.data.idx_train
        labels_train = self.data.labels_train
        idx_selected = []

        for class_id, cnt in num_class_dict.items():
            idx = idx_train[labels_train == class_id]
            features = embeds[idx]
            mean = torch.mean(features, dim=0, keepdim=True)
            selected = []
            idx_left = np.arange(features.shape[0]).tolist()

            for i in range(cnt):
                # Residual target: (i+1)*mean minus what is already selected;
                # choose the remaining point closest to it.
                det = mean * (i + 1) - torch.sum(features[selected], dim=0)
                dis = torch.cdist(det, features[idx_left])
                id_min = torch.argmin(dis).item()
                selected.append(idx_left[id_min])
                del idx_left[id_min]
            idx_selected.append(idx[selected])
        return np.hstack(idx_selected)
class Random(Base):
    """Uniform random coreset selection, performed class by class."""

    def __init__(self, data, args, device='cuda', **kwargs):
        # Bug fix: forward the caller-supplied device instead of hard-coding
        # 'cuda' (the original ignored the `device` argument; LRMC already
        # forwards it correctly).
        super(Random, self).__init__(data, args, device=device, **kwargs)

    def select(self, embeds, inductive=False):
        """Return `num_class_dict[c]` randomly chosen training ids per class.

        `embeds` is accepted for interface parity with the other strategies
        but is not used.
        """
        num_class_dict = self.num_class_dict
        if inductive:
            # In the inductive setting ids are positional over idx_train.
            idx_train = np.arange(len(self.data.idx_train))
        else:
            idx_train = self.data.idx_train

        labels_train = self.data.labels_train
        idx_selected = []

        for class_id, cnt in num_class_dict.items():
            idx = idx_train[labels_train == class_id]
            selected = np.random.permutation(idx)
            idx_selected.append(selected[:cnt])

        return np.hstack(idx_selected)
class LRMC(Base):
    """
    Coreset selection using precomputed seed nodes from the Laplacian-Integrated
    Relaxed Maximal Clique (L-RMC) algorithm. Seed nodes are read from a JSON
    file specified by ``args.lrmc_seeds_path`` and used to preferentially select
    training examples. Per-class reduction counts are respected: if a class has
    fewer seeds than required, random training nodes from that class are added
    until the quota is met.
    """

    def __init__(self, data, args, device='cuda', **kwargs):
        super(LRMC, self).__init__(data, args, device=device, **kwargs)
        # The seed file path is mandatory for this strategy; fail fast with an
        # actionable message rather than an AttributeError later.
        seeds_path = getattr(args, 'lrmc_seeds_path', None)
        if seeds_path is None:
            raise ValueError(
                "LRMC method selected but no path to seed file provided. "
                "Please specify --lrmc_seeds_path when running the training script."
            )
        self.seed_nodes = self._load_seed_nodes(seeds_path)

    def _load_seed_nodes(self, path: str):
        # Parse seed nodes from JSON file (supports 'seed_nodes' or 'members').
        # Expected shape: {"clusters": [{"seed_nodes": [...]} or {"members": [...]}, ...]}
        with open(path, 'r') as f:
            js = json.load(f)
        clusters = js.get('clusters', [])
        if not clusters:
            raise ValueError(f"No clusters found in L‑RMC seeds file {path}")
        def _cluster_length(c):
            # Cluster size = number of listed nodes under either key.
            nodes = c.get('seed_nodes') or c.get('members') or []
            return len(nodes)
        # Only the single largest cluster is used as the seed pool.
        best_cluster = max(clusters, key=_cluster_length)
        nodes = best_cluster.get('seed_nodes') or best_cluster.get('members') or []
        seed_nodes = []
        for u in nodes:
            try:
                uid = int(u)
            except Exception:
                # Skip entries that are not integer-like node ids.
                continue
            # Seed ids appear to be 1-based and are shifted to 0-based here;
            # an id of 0 is kept as index 0 (assumed already zero-based).
            # NOTE(review): with mixed conventions, ids 0 and 1 both map to
            # index 0 — confirm the seed file's indexing convention.
            zero_idx = uid - 1
            if zero_idx >= 0:
                seed_nodes.append(zero_idx)
            else:
                if uid >= 0:
                    seed_nodes.append(uid)
        # Deduplicate and return in ascending order.
        seed_nodes = sorted(set(seed_nodes))
        return seed_nodes

    def select(self, embeds, inductive=False):
        # Determine training indices depending on the inductive setting.
        # `embeds` is unused; it is accepted for interface parity with the
        # other coreset strategies.
        if inductive:
            idx_train = np.arange(len(self.data.idx_train))
            labels_train = self.data.labels_train
        else:
            idx_train = self.data.idx_train
            labels_train = self.data.labels_train
        num_class_dict = self.num_class_dict
        idx_selected = []
        seed_set = set(self.seed_nodes)
        # Pick seed nodes per class; fill remainder with random nodes if needed.
        for class_id, cnt in num_class_dict.items():
            class_mask = (labels_train == class_id)
            class_indices = idx_train[class_mask]
            # Seeds present in this class, capped at the class quota.
            seed_in_class = [u for u in class_indices if u in seed_set]
            selected = seed_in_class[:min(len(seed_in_class), cnt)]
            remaining_required = cnt - len(selected)
            if remaining_required > 0:
                # Top up with random non-selected nodes of the same class;
                # if too few candidates exist, take them all.
                remaining_candidates = [u for u in class_indices if u not in selected]
                if len(remaining_candidates) <= remaining_required:
                    additional = remaining_candidates
                else:
                    additional = np.random.choice(remaining_candidates, remaining_required, replace=False).tolist()
                selected += additional
            idx_selected.append(np.array(selected))
        return np.hstack(idx_selected)
GCond/gcond_agent_induct.py ADDED
@@ -0,0 +1,327 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.nn import Parameter
6
+ import torch.nn.functional as F
7
+ from utils import match_loss, regularization, row_normalize_tensor
8
+ import deeprobust.graph.utils as utils
9
+ from copy import deepcopy
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ from models.gcn import GCN
13
+ from models.sgc import SGC
14
+ from models.sgc_multi import SGC as SGC1
15
+ from models.parametrized_adj import PGE
16
+ import scipy.sparse as sp
17
+ from torch_sparse import SparseTensor
18
+
19
+
20
+ class GCond:
21
+
22
+ def __init__(self, data, args, device='cuda', **kwargs):
23
+ self.data = data
24
+ self.args = args
25
+ self.device = device
26
+
27
+ n = int(len(data.idx_train) * args.reduction_rate)
28
+ d = data.feat_train.shape[1]
29
+ self.nnodes_syn = n
30
+ self.feat_syn = nn.Parameter(torch.FloatTensor(n, d).to(device))
31
+ self.pge = PGE(nfeat=d, nnodes=n, device=device, args=args).to(device)
32
+
33
+ self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)
34
+ self.reset_parameters()
35
+ self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=args.lr_feat)
36
+ self.optimizer_pge = torch.optim.Adam(self.pge.parameters(), lr=args.lr_adj)
37
+ print('adj_syn:', (n,n), 'feat_syn:', self.feat_syn.shape)
38
+
39
+ def reset_parameters(self):
40
+ self.feat_syn.data.copy_(torch.randn(self.feat_syn.size()))
41
+
42
+ def generate_labels_syn(self, data):
43
+ from collections import Counter
44
+ counter = Counter(data.labels_train)
45
+ num_class_dict = {}
46
+ n = len(data.labels_train)
47
+
48
+ sorted_counter = sorted(counter.items(), key=lambda x:x[1])
49
+ sum_ = 0
50
+ labels_syn = []
51
+ self.syn_class_indices = {}
52
+
53
+ for ix, (c, num) in enumerate(sorted_counter):
54
+ if ix == len(sorted_counter) - 1:
55
+ num_class_dict[c] = int(n * self.args.reduction_rate) - sum_
56
+ self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
57
+ labels_syn += [c] * num_class_dict[c]
58
+ else:
59
+ num_class_dict[c] = max(int(num * self.args.reduction_rate), 1)
60
+ sum_ += num_class_dict[c]
61
+ self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
62
+ labels_syn += [c] * num_class_dict[c]
63
+
64
+ self.num_class_dict = num_class_dict
65
+ return labels_syn
66
+
67
+ def test_with_val(self, verbose=True):
68
+ res = []
69
+
70
+ data, device = self.data, self.device
71
+ feat_syn, pge, labels_syn = self.feat_syn.detach(), \
72
+ self.pge, self.labels_syn
73
+ # with_bn = True if args.dataset in ['ogbn-arxiv'] else False
74
+ dropout = 0.5 if self.args.dataset in ['reddit'] else 0
75
+ model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=dropout,
76
+ weight_decay=5e-4, nlayers=2,
77
+ nclass=data.nclass, device=device).to(device)
78
+
79
+ adj_syn = pge.inference(feat_syn)
80
+ args = self.args
81
+
82
+ if args.save:
83
+ torch.save(adj_syn, f'saved_ours/adj_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')
84
+ torch.save(feat_syn, f'saved_ours/feat_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')
85
+
86
+ noval = True
87
+ model.fit_with_val(feat_syn, adj_syn, labels_syn, data,
88
+ train_iters=600, normalize=True, verbose=False, noval=noval)
89
+
90
+ model.eval()
91
+ labels_test = torch.LongTensor(data.labels_test).cuda()
92
+
93
+ output = model.predict(data.feat_test, data.adj_test)
94
+
95
+ loss_test = F.nll_loss(output, labels_test)
96
+ acc_test = utils.accuracy(output, labels_test)
97
+ res.append(acc_test.item())
98
+ if verbose:
99
+ print("Test set results:",
100
+ "loss= {:.4f}".format(loss_test.item()),
101
+ "accuracy= {:.4f}".format(acc_test.item()))
102
+ print(adj_syn.sum(), adj_syn.sum()/(adj_syn.shape[0]**2))
103
+
104
+ if False:
105
+ if self.args.dataset == 'ogbn-arxiv':
106
+ thresh = 0.6
107
+ elif self.args.dataset == 'reddit':
108
+ thresh = 0.91
109
+ else:
110
+ thresh = 0.7
111
+
112
+ labels_train = torch.LongTensor(data.labels_train).cuda()
113
+ output = model.predict(data.feat_train, data.adj_train)
114
+ # loss_train = F.nll_loss(output, labels_train)
115
+ # acc_train = utils.accuracy(output, labels_train)
116
+ loss_train = torch.tensor(0)
117
+ acc_train = torch.tensor(0)
118
+ if verbose:
119
+ print("Train set results:",
120
+ "loss= {:.4f}".format(loss_train.item()),
121
+ "accuracy= {:.4f}".format(acc_train.item()))
122
+ res.append(acc_train.item())
123
+ return res
124
+
125
+ def train(self, verbose=True):
126
+ args = self.args
127
+ data = self.data
128
+ feat_syn, pge, labels_syn = self.feat_syn, self.pge, self.labels_syn
129
+ features, adj, labels = data.feat_train, data.adj_train, data.labels_train
130
+ syn_class_indices = self.syn_class_indices
131
+ features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
132
+ feat_sub, adj_sub = self.get_sub_adj_feat(features)
133
+ self.feat_syn.data.copy_(feat_sub)
134
+
135
+ if utils.is_sparse_tensor(adj):
136
+ adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
137
+ else:
138
+ adj_norm = utils.normalize_adj_tensor(adj)
139
+
140
+ adj = adj_norm
141
+ adj = SparseTensor(row=adj._indices()[0], col=adj._indices()[1],
142
+ value=adj._values(), sparse_sizes=adj.size()).t()
143
+
144
+
145
+ outer_loop, inner_loop = get_loops(args)
146
+
147
+ for it in range(args.epochs+1):
148
+ loss_avg = 0
149
+ if args.sgc==1:
150
+ model = SGC(nfeat=data.feat_train.shape[1], nhid=args.hidden,
151
+ nclass=data.nclass, dropout=args.dropout,
152
+ nlayers=args.nlayers, with_bn=False,
153
+ device=self.device).to(self.device)
154
+ elif args.sgc==2:
155
+ model = SGC1(nfeat=data.feat_train.shape[1], nhid=args.hidden,
156
+ nclass=data.nclass, dropout=args.dropout,
157
+ nlayers=args.nlayers, with_bn=False,
158
+ device=self.device).to(self.device)
159
+
160
+ else:
161
+ model = GCN(nfeat=data.feat_train.shape[1], nhid=args.hidden,
162
+ nclass=data.nclass, dropout=args.dropout, nlayers=args.nlayers,
163
+ device=self.device).to(self.device)
164
+
165
+ model.initialize()
166
+
167
+ model_parameters = list(model.parameters())
168
+
169
+ optimizer_model = torch.optim.Adam(model_parameters, lr=args.lr_model)
170
+ model.train()
171
+
172
+ for ol in range(outer_loop):
173
+ adj_syn = pge(self.feat_syn)
174
+ adj_syn_norm = utils.normalize_adj_tensor(adj_syn, sparse=False)
175
+ feat_syn_norm = feat_syn
176
+
177
+ BN_flag = False
178
+ for module in model.modules():
179
+ if 'BatchNorm' in module._get_name(): #BatchNorm
180
+ BN_flag = True
181
+ if BN_flag:
182
+ model.train() # for updating the mu, sigma of BatchNorm
183
+ output_real = model.forward(features, adj_norm)
184
+ for module in model.modules():
185
+ if 'BatchNorm' in module._get_name(): #BatchNorm
186
+ module.eval() # fix mu and sigma of every BatchNorm layer
187
+
188
+ loss = torch.tensor(0.0).to(self.device)
189
+ for c in range(data.nclass):
190
+ if c not in self.num_class_dict:
191
+ continue
192
+
193
+ batch_size, n_id, adjs = data.retrieve_class_sampler(
194
+ c, adj, transductive=False, args=args)
195
+
196
+ if args.nlayers == 1:
197
+ adjs = [adjs]
198
+ adjs = [adj.to(self.device) for adj in adjs]
199
+ output = model.forward_sampler(features[n_id], adjs)
200
+ loss_real = F.nll_loss(output, labels[n_id[:batch_size]])
201
+ gw_real = torch.autograd.grad(loss_real, model_parameters)
202
+ gw_real = list((_.detach().clone() for _ in gw_real))
203
+
204
+ ind = syn_class_indices[c]
205
+ if args.nlayers == 1:
206
+ adj_syn_norm_list = [adj_syn_norm[ind[0]: ind[1]]]
207
+ else:
208
+ adj_syn_norm_list = [adj_syn_norm]*(args.nlayers-1) + \
209
+ [adj_syn_norm[ind[0]: ind[1]]]
210
+
211
+ output_syn = model.forward_sampler_syn(feat_syn, adj_syn_norm_list)
212
+ loss_syn = F.nll_loss(output_syn, labels_syn[ind[0]: ind[1]])
213
+
214
+ gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
215
+ coeff = self.num_class_dict[c] / max(self.num_class_dict.values())
216
+ loss += coeff * match_loss(gw_syn, gw_real, args, device=self.device)
217
+
218
+ loss_avg += loss.item()
219
+ # TODO: regularize
220
+ if args.alpha > 0:
221
+ loss_reg = args.alpha * regularization(adj_syn, utils.tensor2onehot(labels_syn))
222
+ # else:
223
+ else:
224
+ loss_reg = torch.tensor(0)
225
+
226
+ loss = loss + loss_reg
227
+
228
+ # update sythetic graph
229
+ self.optimizer_feat.zero_grad()
230
+ self.optimizer_pge.zero_grad()
231
+ loss.backward()
232
+ if it % 50 < 10:
233
+ self.optimizer_pge.step()
234
+ else:
235
+ self.optimizer_feat.step()
236
+
237
+ if args.debug and ol % 5 ==0:
238
+ print('Gradient matching loss:', loss.item())
239
+
240
+ if ol == outer_loop - 1:
241
+ # print('loss_reg:', loss_reg.item())
242
+ # print('Gradient matching loss:', loss.item())
243
+ break
244
+
245
+
246
+ feat_syn_inner = feat_syn.detach()
247
+ adj_syn_inner = pge.inference(feat_syn)
248
+ adj_syn_inner_norm = utils.normalize_adj_tensor(adj_syn_inner, sparse=False)
249
+ feat_syn_inner_norm = feat_syn_inner
250
+ for j in range(inner_loop):
251
+ optimizer_model.zero_grad()
252
+ output_syn_inner = model.forward(feat_syn_inner_norm, adj_syn_inner_norm)
253
+ loss_syn_inner = F.nll_loss(output_syn_inner, labels_syn)
254
+ loss_syn_inner.backward()
255
+ optimizer_model.step() # update gnn param
256
+
257
+ loss_avg /= (data.nclass*outer_loop)
258
+ if it % 50 == 0:
259
+ print('Epoch {}, loss_avg: {}'.format(it, loss_avg))
260
+
261
+ eval_epochs = [100, 200, 400, 600, 800, 1000, 1200, 1400, 1600, 1800, 2000, 3000, 4000, 5000]
262
+
263
+ if verbose and it in eval_epochs:
264
+ # if verbose and (it+1) % 500 == 0:
265
+ res = []
266
+ runs = 1 if args.dataset in ['ogbn-arxiv', 'reddit', 'flickr'] else 3
267
+ for i in range(runs):
268
+ # self.test()
269
+ res.append(self.test_with_val())
270
+ res = np.array(res)
271
+ print('Test:',
272
+ repr([res.mean(0), res.std(0)]))
273
+
274
+
275
+
276
+ def get_sub_adj_feat(self, features):
277
+ data = self.data
278
+ args = self.args
279
+ idx_selected = []
280
+
281
+ from collections import Counter;
282
+ counter = Counter(self.labels_syn.cpu().numpy())
283
+
284
+ for c in range(data.nclass):
285
+ tmp = data.retrieve_class(c, num=counter[c])
286
+ tmp = list(tmp)
287
+ idx_selected = idx_selected + tmp
288
+ idx_selected = np.array(idx_selected).reshape(-1)
289
+ features = features[idx_selected]
290
+
291
+ # adj_knn = torch.zeros((data.nclass*args.nsamples, data.nclass*args.nsamples)).to(self.device)
292
+ # for i in range(data.nclass):
293
+ # idx = np.arange(i*args.nsamples, i*args.nsamples+args.nsamples)
294
+ # adj_knn[np.ix_(idx, idx)] = 1
295
+
296
+ from sklearn.metrics.pairwise import cosine_similarity
297
+ # features[features!=0] = 1
298
+ k = 2
299
+ sims = cosine_similarity(features.cpu().numpy())
300
+ sims[(np.arange(len(sims)), np.arange(len(sims)))] = 0
301
+ for i in range(len(sims)):
302
+ indices_argsort = np.argsort(sims[i])
303
+ sims[i, indices_argsort[: -k]] = 0
304
+ adj_knn = torch.FloatTensor(sims).to(self.device)
305
+ return features, adj_knn
306
+
307
+
308
+ def get_loops(args):
309
+ # Get the two hyper-parameters of outer-loop and inner-loop.
310
+ # The following values are empirically good.
311
+ if args.one_step:
312
+ return 10, 0
313
+
314
+ if args.dataset in ['ogbn-arxiv']:
315
+ return 20, 0
316
+ if args.dataset in ['reddit']:
317
+ return args.outer, args.inner
318
+ if args.dataset in ['flickr']:
319
+ return args.outer, args.inner
320
+ # return 10, 1
321
+ if args.dataset in ['cora']:
322
+ return 20, 10
323
+ if args.dataset in ['citeseer']:
324
+ return 20, 5 # at least 200 epochs
325
+ else:
326
+ return 20, 5
327
+
GCond/gcond_agent_transduct.py ADDED
@@ -0,0 +1,326 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torch.nn import Parameter
6
+ import torch.nn.functional as F
7
+ from utils import match_loss, regularization, row_normalize_tensor
8
+ import deeprobust.graph.utils as utils
9
+ from copy import deepcopy
10
+ import numpy as np
11
+ from tqdm import tqdm
12
+ from models.gcn import GCN
13
+ from models.sgc import SGC
14
+ from models.sgc_multi import SGC as SGC1
15
+ from models.parametrized_adj import PGE
16
+ import scipy.sparse as sp
17
+ from torch_sparse import SparseTensor
18
+
19
+
20
+ class GCond:
21
+
22
+ def __init__(self, data, args, device='cuda', **kwargs):
23
+ self.data = data
24
+ self.args = args
25
+ self.device = device
26
+
27
+ # n = data.nclass * args.nsamples
28
+ n = int(data.feat_train.shape[0] * args.reduction_rate)
29
+ # from collections import Counter; print(Counter(data.labels_train))
30
+
31
+ d = data.feat_train.shape[1]
32
+ self.nnodes_syn = n
33
+ self.feat_syn = nn.Parameter(torch.FloatTensor(n, d).to(device))
34
+ self.pge = PGE(nfeat=d, nnodes=n, device=device,args=args).to(device)
35
+
36
+ self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)
37
+
38
+ self.reset_parameters()
39
+ self.optimizer_feat = torch.optim.Adam([self.feat_syn], lr=args.lr_feat)
40
+ self.optimizer_pge = torch.optim.Adam(self.pge.parameters(), lr=args.lr_adj)
41
+ print('adj_syn:', (n,n), 'feat_syn:', self.feat_syn.shape)
42
+
43
+ def reset_parameters(self):
44
+ self.feat_syn.data.copy_(torch.randn(self.feat_syn.size()))
45
+
46
+ def generate_labels_syn(self, data):
47
+ from collections import Counter
48
+ counter = Counter(data.labels_train)
49
+ num_class_dict = {}
50
+ n = len(data.labels_train)
51
+
52
+ sorted_counter = sorted(counter.items(), key=lambda x:x[1])
53
+ sum_ = 0
54
+ labels_syn = []
55
+ self.syn_class_indices = {}
56
+ for ix, (c, num) in enumerate(sorted_counter):
57
+ if ix == len(sorted_counter) - 1:
58
+ num_class_dict[c] = int(n * self.args.reduction_rate) - sum_
59
+ self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
60
+ labels_syn += [c] * num_class_dict[c]
61
+ else:
62
+ num_class_dict[c] = max(int(num * self.args.reduction_rate), 1)
63
+ sum_ += num_class_dict[c]
64
+ self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
65
+ labels_syn += [c] * num_class_dict[c]
66
+
67
+ self.num_class_dict = num_class_dict
68
+ return labels_syn
69
+
70
+
71
+ def test_with_val(self, verbose=True):
72
+ res = []
73
+
74
+ data, device = self.data, self.device
75
+ feat_syn, pge, labels_syn = self.feat_syn.detach(), \
76
+ self.pge, self.labels_syn
77
+
78
+ # with_bn = True if args.dataset in ['ogbn-arxiv'] else False
79
+ model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=0.5,
80
+ weight_decay=5e-4, nlayers=2,
81
+ nclass=data.nclass, device=device).to(device)
82
+
83
+ if self.args.dataset in ['ogbn-arxiv']:
84
+ model = GCN(nfeat=feat_syn.shape[1], nhid=self.args.hidden, dropout=0.5,
85
+ weight_decay=0e-4, nlayers=2, with_bn=False,
86
+ nclass=data.nclass, device=device).to(device)
87
+
88
+ adj_syn = pge.inference(feat_syn)
89
+ args = self.args
90
+
91
+ if self.args.save:
92
+ torch.save(adj_syn, f'saved_ours/adj_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')
93
+ torch.save(feat_syn, f'saved_ours/feat_{args.dataset}_{args.reduction_rate}_{args.seed}.pt')
94
+
95
+ if self.args.lr_adj == 0:
96
+ n = len(labels_syn)
97
+ adj_syn = torch.zeros((n, n))
98
+
99
+ model.fit_with_val(feat_syn, adj_syn, labels_syn, data,
100
+ train_iters=600, normalize=True, verbose=False)
101
+
102
+ model.eval()
103
+ labels_test = torch.LongTensor(data.labels_test).cuda()
104
+
105
+ labels_train = torch.LongTensor(data.labels_train).cuda()
106
+ output = model.predict(data.feat_train, data.adj_train)
107
+ loss_train = F.nll_loss(output, labels_train)
108
+ acc_train = utils.accuracy(output, labels_train)
109
+ if verbose:
110
+ print("Train set results:",
111
+ "loss= {:.4f}".format(loss_train.item()),
112
+ "accuracy= {:.4f}".format(acc_train.item()))
113
+ res.append(acc_train.item())
114
+
115
+ # Full graph
116
+ output = model.predict(data.feat_full, data.adj_full)
117
+ loss_test = F.nll_loss(output[data.idx_test], labels_test)
118
+ acc_test = utils.accuracy(output[data.idx_test], labels_test)
119
+ res.append(acc_test.item())
120
+ if verbose:
121
+ print("Test set results:",
122
+ "loss= {:.4f}".format(loss_test.item()),
123
+ "accuracy= {:.4f}".format(acc_test.item()))
124
+ return res
125
+
126
+ def train(self, verbose=True):
127
+ args = self.args
128
+ data = self.data
129
+ feat_syn, pge, labels_syn = self.feat_syn, self.pge, self.labels_syn
130
+ features, adj, labels = data.feat_full, data.adj_full, data.labels_full
131
+ idx_train = data.idx_train
132
+
133
+ syn_class_indices = self.syn_class_indices
134
+
135
+ features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
136
+
137
+ feat_sub, adj_sub = self.get_sub_adj_feat(features)
138
+ self.feat_syn.data.copy_(feat_sub)
139
+
140
+ if utils.is_sparse_tensor(adj):
141
+ adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
142
+ else:
143
+ adj_norm = utils.normalize_adj_tensor(adj)
144
+
145
+ adj = adj_norm
146
+ adj = SparseTensor(row=adj._indices()[0], col=adj._indices()[1],
147
+ value=adj._values(), sparse_sizes=adj.size()).t()
148
+
149
+
150
+ outer_loop, inner_loop = get_loops(args)
151
+ loss_avg = 0
152
+
153
+ for it in range(args.epochs+1):
154
+ if args.dataset in ['ogbn-arxiv']:
155
+ model = SGC1(nfeat=feat_syn.shape[1], nhid=self.args.hidden,
156
+ dropout=0.0, with_bn=False,
157
+ weight_decay=0e-4, nlayers=2,
158
+ nclass=data.nclass,
159
+ device=self.device).to(self.device)
160
+ else:
161
+ if args.sgc == 1:
162
+ model = SGC(nfeat=data.feat_train.shape[1], nhid=args.hidden,
163
+ nclass=data.nclass, dropout=args.dropout,
164
+ nlayers=args.nlayers, with_bn=False,
165
+ device=self.device).to(self.device)
166
+ else:
167
+ model = GCN(nfeat=data.feat_train.shape[1], nhid=args.hidden,
168
+ nclass=data.nclass, dropout=args.dropout, nlayers=args.nlayers,
169
+ device=self.device).to(self.device)
170
+
171
+
172
+ model.initialize()
173
+
174
+ model_parameters = list(model.parameters())
175
+
176
+ optimizer_model = torch.optim.Adam(model_parameters, lr=args.lr_model)
177
+ model.train()
178
+
179
+ for ol in range(outer_loop):
180
+ adj_syn = pge(self.feat_syn)
181
+ adj_syn_norm = utils.normalize_adj_tensor(adj_syn, sparse=False)
182
+ feat_syn_norm = feat_syn
183
+
184
+ BN_flag = False
185
+ for module in model.modules():
186
+ if 'BatchNorm' in module._get_name(): #BatchNorm
187
+ BN_flag = True
188
+ if BN_flag:
189
+ model.train() # for updating the mu, sigma of BatchNorm
190
+ output_real = model.forward(features, adj_norm)
191
+ for module in model.modules():
192
+ if 'BatchNorm' in module._get_name(): #BatchNorm
193
+ module.eval() # fix mu and sigma of every BatchNorm layer
194
+
195
+ loss = torch.tensor(0.0).to(self.device)
196
+ for c in range(data.nclass):
197
+ batch_size, n_id, adjs = data.retrieve_class_sampler(
198
+ c, adj, transductive=True, args=args)
199
+ if args.nlayers == 1:
200
+ adjs = [adjs]
201
+
202
+ adjs = [adj.to(self.device) for adj in adjs]
203
+ output = model.forward_sampler(features[n_id], adjs)
204
+ loss_real = F.nll_loss(output, labels[n_id[:batch_size]])
205
+
206
+ gw_real = torch.autograd.grad(loss_real, model_parameters)
207
+ gw_real = list((_.detach().clone() for _ in gw_real))
208
+ output_syn = model.forward(feat_syn, adj_syn_norm)
209
+
210
+ ind = syn_class_indices[c]
211
+ loss_syn = F.nll_loss(
212
+ output_syn[ind[0]: ind[1]],
213
+ labels_syn[ind[0]: ind[1]])
214
+ gw_syn = torch.autograd.grad(loss_syn, model_parameters, create_graph=True)
215
+ coeff = self.num_class_dict[c] / max(self.num_class_dict.values())
216
+ loss += coeff * match_loss(gw_syn, gw_real, args, device=self.device)
217
+
218
+ loss_avg += loss.item()
219
+ # TODO: regularize
220
+ if args.alpha > 0:
221
+ loss_reg = args.alpha * regularization(adj_syn, utils.tensor2onehot(labels_syn))
222
+ else:
223
+ loss_reg = torch.tensor(0)
224
+
225
+ loss = loss + loss_reg
226
+
227
+ # update sythetic graph
228
+ self.optimizer_feat.zero_grad()
229
+ self.optimizer_pge.zero_grad()
230
+ loss.backward()
231
+ if it % 50 < 10:
232
+ self.optimizer_pge.step()
233
+ else:
234
+ self.optimizer_feat.step()
235
+
236
+ if args.debug and ol % 5 ==0:
237
+ print('Gradient matching loss:', loss.item())
238
+
239
+ if ol == outer_loop - 1:
240
+ # print('loss_reg:', loss_reg.item())
241
+ # print('Gradient matching loss:', loss.item())
242
+ break
243
+
244
+ feat_syn_inner = feat_syn.detach()
245
+ adj_syn_inner = pge.inference(feat_syn_inner)
246
+ adj_syn_inner_norm = utils.normalize_adj_tensor(adj_syn_inner, sparse=False)
247
+ feat_syn_inner_norm = feat_syn_inner
248
+ for j in range(inner_loop):
249
+ optimizer_model.zero_grad()
250
+ output_syn_inner = model.forward(feat_syn_inner_norm, adj_syn_inner_norm)
251
+ loss_syn_inner = F.nll_loss(output_syn_inner, labels_syn)
252
+ loss_syn_inner.backward()
253
+ # print(loss_syn_inner.item())
254
+ optimizer_model.step() # update gnn param
255
+
256
+
257
+ loss_avg /= (data.nclass*outer_loop)
258
+ if it % 50 == 0:
259
+ print('Epoch {}, loss_avg: {}'.format(it, loss_avg))
260
+
261
+ eval_epochs = [400, 600, 800, 1000, 1200, 1600, 2000, 3000, 4000, 5000]
262
+
263
+ if verbose and it in eval_epochs:
264
+ # if verbose and (it+1) % 50 == 0:
265
+ res = []
266
+ runs = 1 if args.dataset in ['ogbn-arxiv'] else 3
267
+ for i in range(runs):
268
+ if args.dataset in ['ogbn-arxiv']:
269
+ res.append(self.test_with_val())
270
+ else:
271
+ res.append(self.test_with_val())
272
+
273
+ res = np.array(res)
274
+ print('Train/Test Mean Accuracy:',
275
+ repr([res.mean(0), res.std(0)]))
276
+
277
+ def get_sub_adj_feat(self, features):
278
+ data = self.data
279
+ args = self.args
280
+ idx_selected = []
281
+
282
+ from collections import Counter;
283
+ counter = Counter(self.labels_syn.cpu().numpy())
284
+
285
+ for c in range(data.nclass):
286
+ tmp = data.retrieve_class(c, num=counter[c])
287
+ tmp = list(tmp)
288
+ idx_selected = idx_selected + tmp
289
+ idx_selected = np.array(idx_selected).reshape(-1)
290
+ features = features[self.data.idx_train][idx_selected]
291
+
292
+ # adj_knn = torch.zeros((data.nclass*args.nsamples, data.nclass*args.nsamples)).to(self.device)
293
+ # for i in range(data.nclass):
294
+ # idx = np.arange(i*args.nsamples, i*args.nsamples+args.nsamples)
295
+ # adj_knn[np.ix_(idx, idx)] = 1
296
+
297
+ from sklearn.metrics.pairwise import cosine_similarity
298
+ # features[features!=0] = 1
299
+ k = 2
300
+ sims = cosine_similarity(features.cpu().numpy())
301
+ sims[(np.arange(len(sims)), np.arange(len(sims)))] = 0
302
+ for i in range(len(sims)):
303
+ indices_argsort = np.argsort(sims[i])
304
+ sims[i, indices_argsort[: -k]] = 0
305
+ adj_knn = torch.FloatTensor(sims).to(self.device)
306
+ return features, adj_knn
307
+
308
+
309
+ def get_loops(args):
310
+ # Get the two hyper-parameters of outer-loop and inner-loop.
311
+ # The following values are empirically good.
312
+ if args.one_step:
313
+ if args.dataset =='ogbn-arxiv':
314
+ return 5, 0
315
+ return 1, 0
316
+ if args.dataset in ['ogbn-arxiv']:
317
+ return args.outer, args.inner
318
+ if args.dataset in ['cora']:
319
+ return 20, 15 # sgc
320
+ if args.dataset in ['citeseer']:
321
+ return 20, 15
322
+ if args.dataset in ['physics']:
323
+ return 20, 10
324
+ else:
325
+ return 20, 10
326
+
GCond/models/__pycache__/gcn.cpython-312.pyc ADDED
Binary file (21 kB). View file
 
GCond/models/gat.py ADDED
@@ -0,0 +1,312 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Extended from https://github.com/rusty1s/pytorch_geometric/tree/master/benchmark/citation
3
+ """
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import math
7
+ import torch
8
+ import torch.optim as optim
9
+ from torch.nn.parameter import Parameter
10
+ from torch.nn.modules.module import Module
11
+ from deeprobust.graph import utils
12
+ from copy import deepcopy
13
+ from torch_geometric.nn import SGConv
14
+ from torch_geometric.nn import APPNP as ModuleAPPNP
15
+ # from torch_geometric.nn import GATConv
16
+ from .mygatconv import GATConv
17
+ import numpy as np
18
+ import scipy.sparse as sp
19
+
20
+ from torch.nn import Linear
21
+ from itertools import repeat
22
+
23
+
24
+ class GAT(torch.nn.Module):
25
+
26
+ def __init__(self, nfeat, nhid, nclass, heads=8, output_heads=1, dropout=0.5, lr=0.01,
27
+ weight_decay=5e-4, with_bias=True, device=None, **kwargs):
28
+
29
+ super(GAT, self).__init__()
30
+
31
+ assert device is not None, "Please specify 'device'!"
32
+ self.device = device
33
+ self.dropout = dropout
34
+ self.lr = lr
35
+ self.weight_decay = weight_decay
36
+
37
+ if 'dataset' in kwargs:
38
+ if kwargs['dataset'] in ['ogbn-arxiv']:
39
+ dropout = 0.7 # arxiv
40
+ elif kwargs['dataset'] in ['reddit']:
41
+ dropout = 0.05; self.dropout = 0.1; self.weight_decay = 5e-4
42
+ # self.weight_decay = 5e-2; dropout=0.05; self.dropout=0.1
43
+ elif kwargs['dataset'] in ['citeseer']:
44
+ dropout = 0.7
45
+ self.weight_decay = 5e-4
46
+ elif kwargs['dataset'] in ['flickr']:
47
+ dropout = 0.8
48
+ # nhid=8; heads=8
49
+ # self.dropout=0.1
50
+ else:
51
+ dropout = 0.7 # cora, citeseer, reddit
52
+ else:
53
+ dropout = 0.7
54
+ self.conv1 = GATConv(
55
+ nfeat,
56
+ nhid,
57
+ heads=heads,
58
+ dropout=dropout,
59
+ bias=with_bias)
60
+
61
+ self.conv2 = GATConv(
62
+ nhid * heads,
63
+ nclass,
64
+ heads=output_heads,
65
+ concat=False,
66
+ dropout=dropout,
67
+ bias=with_bias)
68
+
69
+ self.output = None
70
+ self.best_model = None
71
+ self.best_output = None
72
+
73
+ # def forward(self, data):
74
+ # x, edge_index = data.x, data.edge_index
75
+ # x = F.dropout(x, p=self.dropout, training=self.training)
76
+ # x = F.elu(self.conv1(x, edge_index))
77
+ # x = F.dropout(x, p=self.dropout, training=self.training)
78
+ # x = self.conv2(x, edge_index)
79
+ # return F.log_softmax(x, dim=1)
80
+
81
+ def forward(self, data):
82
+ # x, edge_index = data.x, data.edge_index
83
+ x, edge_index, edge_weight = data.x, data.edge_index, data.edge_weight
84
+ x = F.dropout(x, p=self.dropout, training=self.training)
85
+ x = F.elu(self.conv1(x, edge_index, edge_weight=edge_weight))
86
+ # print(self.conv1.att_l.sum())
87
+ x = F.dropout(x, p=self.dropout, training=self.training)
88
+ x = self.conv2(x, edge_index, edge_weight=edge_weight)
89
+ return F.log_softmax(x, dim=1)
90
+
91
+
92
+ def initialize(self):
93
+ """Initialize parameters of GAT.
94
+ """
95
+ self.conv1.reset_parameters()
96
+ self.conv2.reset_parameters()
97
+
98
+
99
+ def fit(self, feat, adj, labels, idx, data=None, train_iters=600, initialize=True, verbose=False, patience=None, noval=False, **kwargs):
100
+
101
+ data_train = GraphData(feat, adj, labels)
102
+ data_train = Dpr2Pyg(data_train)[0]
103
+
104
+ data_test = Dpr2Pyg(GraphData(data.feat_test, data.adj_test, None))[0]
105
+
106
+ if noval:
107
+ data_val = GraphData(data.feat_val, data.adj_val, None)
108
+ data_val = Dpr2Pyg(data_val)[0]
109
+ else:
110
+ data_val = GraphData(data.feat_full, data.adj_full, None)
111
+ data_val = Dpr2Pyg(data_val)[0]
112
+
113
+ labels_val = torch.LongTensor(data.labels_val).to(self.device)
114
+
115
+ if initialize:
116
+ self.initialize()
117
+
118
+ if len(data_train.y.shape) > 1:
119
+ self.multi_label = True
120
+ self.loss = torch.nn.BCELoss()
121
+ else:
122
+ self.multi_label = False
123
+ self.loss = F.nll_loss
124
+
125
+
126
+ data_train.y = data_train.y.float() if self.multi_label else data_train.y
127
+ # data_val.y = data_val.y.float() if self.multi_label else data_val.y
128
+
129
+ if verbose:
130
+ print('=== training gat model ===')
131
+
132
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
133
+ best_acc_val = 0
134
+ best_loss_val = 100
135
+ for i in range(train_iters):
136
+ # if i == train_iters // 2:
137
+ if i in [1500]:
138
+ lr = self.lr*0.1
139
+ optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)
140
+
141
+ self.train()
142
+ optimizer.zero_grad()
143
+ output = self.forward(data_train)
144
+ loss_train = self.loss(output, data_train.y)
145
+ loss_train.backward()
146
+ optimizer.step()
147
+
148
+ with torch.no_grad():
149
+ self.eval()
150
+
151
+ output = self.forward(data_val)
152
+ if noval:
153
+ loss_val = F.nll_loss(output, labels_val)
154
+ acc_val = utils.accuracy(output, labels_val)
155
+ else:
156
+ loss_val = F.nll_loss(output[data.idx_val], labels_val)
157
+ acc_val = utils.accuracy(output[data.idx_val], labels_val)
158
+
159
+
160
+ if loss_val < best_loss_val:
161
+ best_loss_val = loss_val
162
+ self.output = output
163
+ weights = deepcopy(self.state_dict())
164
+
165
+ if acc_val > best_acc_val:
166
+ best_acc_val = acc_val
167
+ self.output = output
168
+ weights = deepcopy(self.state_dict())
169
+ # print(best_acc_val)
170
+ # output = self.forward(data_test)
171
+ # labels_test = torch.LongTensor(data.labels_test).to(self.device)
172
+ # loss_test = F.nll_loss(output, labels_test)
173
+ # acc_test = utils.accuracy(output, labels_test)
174
+ # print('acc_test:', acc_test.item())
175
+
176
+
177
+
178
+ if verbose and i % 100 == 0:
179
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
180
+
181
+ if verbose:
182
+ print('=== picking the best model according to the performance on validation ===')
183
+ self.load_state_dict(weights)
184
+
185
+ def test(self, data_test):
186
+ """Evaluate GCN performance
187
+ """
188
+ self.eval()
189
+ with torch.no_grad():
190
+ output = self.forward(data_test)
191
+ evaluate(output, data_test.y, self.args)
192
+
193
+ # @torch.no_grad()
194
+ # def predict(self, data):
195
+ # self.eval()
196
+ # return self.forward(data)
197
+ @torch.no_grad()
198
+ def predict(self, feat, adj):
199
+ self.eval()
200
+ data = GraphData(feat, adj, None)
201
+ data = Dpr2Pyg(data)[0]
202
+ return self.forward(data)
203
+
204
+ @torch.no_grad()
205
+ def predict_unnorm(self, feat, adj):
206
+ self.eval()
207
+ data = GraphData(feat, adj, None)
208
+ data = Dpr2Pyg(data)[0]
209
+
210
+ return self.forward(data)
211
+
212
+
213
+ class GraphData:
214
+
215
+ def __init__(self, features, adj, labels, idx_train=None, idx_val=None, idx_test=None):
216
+ self.adj = adj
217
+ self.features = features
218
+ self.labels = labels
219
+ self.idx_train = idx_train
220
+ self.idx_val = idx_val
221
+ self.idx_test = idx_test
222
+
223
+
224
+ from torch_geometric.data import InMemoryDataset, Data
225
+ import scipy.sparse as sp
226
+
227
+ class Dpr2Pyg(InMemoryDataset):
228
+
229
+ def __init__(self, dpr_data, transform=None, **kwargs):
230
+ root = 'data/' # dummy root; does not mean anything
231
+ self.dpr_data = dpr_data
232
+ super(Dpr2Pyg, self).__init__(root, transform)
233
+ pyg_data = self.process()
234
+ self.data, self.slices = self.collate([pyg_data])
235
+ self.transform = transform
236
+
237
+ def process____(self):
238
+ dpr_data = self.dpr_data
239
+ try:
240
+ edge_index = torch.LongTensor(dpr_data.adj.nonzero().cpu()).cuda().T
241
+ except:
242
+ edge_index = torch.LongTensor(dpr_data.adj.nonzero()).cuda()
243
+ # by default, the features in pyg data is dense
244
+ try:
245
+ x = torch.FloatTensor(dpr_data.features.cpu()).float().cuda()
246
+ except:
247
+ x = torch.FloatTensor(dpr_data.features).float().cuda()
248
+ try:
249
+ y = torch.LongTensor(dpr_data.labels.cpu()).cuda()
250
+ except:
251
+ y = dpr_data.labels
252
+
253
+ data = Data(x=x, edge_index=edge_index, y=y)
254
+ data.train_mask = None
255
+ data.val_mask = None
256
+ data.test_mask = None
257
+ return data
258
+
259
+ def process(self):
260
+ dpr_data = self.dpr_data
261
+ if type(dpr_data.adj) == torch.Tensor:
262
+ adj_selfloop = dpr_data.adj + torch.eye(dpr_data.adj.shape[0]).cuda()
263
+ edge_index_selfloop = adj_selfloop.nonzero().T
264
+ edge_index = edge_index_selfloop
265
+ edge_weight = adj_selfloop[edge_index_selfloop[0], edge_index_selfloop[1]]
266
+ else:
267
+ adj_selfloop = dpr_data.adj + sp.eye(dpr_data.adj.shape[0])
268
+ edge_index = torch.LongTensor(adj_selfloop.nonzero()).cuda()
269
+ edge_weight = torch.FloatTensor(adj_selfloop[adj_selfloop.nonzero()]).cuda()
270
+
271
+ # by default, the features in pyg data is dense
272
+ try:
273
+ x = torch.FloatTensor(dpr_data.features.cpu()).float().cuda()
274
+ except:
275
+ x = torch.FloatTensor(dpr_data.features).float().cuda()
276
+ try:
277
+ y = torch.LongTensor(dpr_data.labels.cpu()).cuda()
278
+ except:
279
+ y = dpr_data.labels
280
+
281
+
282
+ data = Data(x=x, edge_index=edge_index, y=y, edge_weight=edge_weight)
283
+ data.train_mask = None
284
+ data.val_mask = None
285
+ data.test_mask = None
286
+ return data
287
+
288
+ def get(self, idx):
289
+ data = self.data.__class__()
290
+
291
+ if hasattr(self.data, '__num_nodes__'):
292
+ data.num_nodes = self.data.__num_nodes__[idx]
293
+
294
+ for key in self.data.keys:
295
+ item, slices = self.data[key], self.slices[key]
296
+ s = list(repeat(slice(None), item.dim()))
297
+ s[self.data.__cat_dim__(key, item)] = slice(slices[idx],
298
+ slices[idx + 1])
299
+ data[key] = item[s]
300
+ return data
301
+
302
+ @property
303
+ def raw_file_names(self):
304
+ return ['some_file_1', 'some_file_2', ...]
305
+
306
+ @property
307
+ def processed_file_names(self):
308
+ return ['data.pt']
309
+
310
+ def _download(self):
311
+ pass
312
+
GCond/models/gcn.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import math
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.nn.parameter import Parameter
7
+ from torch.nn.modules.module import Module
8
+ from deeprobust.graph import utils
9
+ from copy import deepcopy
10
+ from sklearn.metrics import f1_score
11
+ from torch.nn import init
12
+ import torch_sparse
13
+
14
+
15
+ class GraphConvolution(Module):
16
+ """Simple GCN layer, similar to https://github.com/tkipf/pygcn
17
+ """
18
+
19
+ def __init__(self, in_features, out_features, with_bias=True):
20
+ super(GraphConvolution, self).__init__()
21
+ self.in_features = in_features
22
+ self.out_features = out_features
23
+ self.weight = Parameter(torch.FloatTensor(in_features, out_features))
24
+ self.bias = Parameter(torch.FloatTensor(out_features))
25
+ self.reset_parameters()
26
+
27
+ def reset_parameters(self):
28
+ stdv = 1. / math.sqrt(self.weight.T.size(1))
29
+ self.weight.data.uniform_(-stdv, stdv)
30
+ if self.bias is not None:
31
+ self.bias.data.uniform_(-stdv, stdv)
32
+
33
+ def forward(self, input, adj):
34
+ """ Graph Convolutional Layer forward function
35
+ """
36
+ if input.data.is_sparse:
37
+ support = torch.spmm(input, self.weight)
38
+ else:
39
+ support = torch.mm(input, self.weight)
40
+ if isinstance(adj, torch_sparse.SparseTensor):
41
+ output = torch_sparse.matmul(adj, support)
42
+ else:
43
+ output = torch.spmm(adj, support)
44
+ if self.bias is not None:
45
+ return output + self.bias
46
+ else:
47
+ return output
48
+
49
+ def __repr__(self):
50
+ return self.__class__.__name__ + ' (' \
51
+ + str(self.in_features) + ' -> ' \
52
+ + str(self.out_features) + ')'
53
+
54
+
55
+ class GCN(nn.Module):
56
+
57
+ def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
58
+ with_relu=True, with_bias=True, with_bn=False, device=None):
59
+
60
+ super(GCN, self).__init__()
61
+
62
+ assert device is not None, "Please specify 'device'!"
63
+ self.device = device
64
+ self.nfeat = nfeat
65
+ self.nclass = nclass
66
+
67
+ self.layers = nn.ModuleList([])
68
+
69
+ if nlayers == 1:
70
+ self.layers.append(GraphConvolution(nfeat, nclass, with_bias=with_bias))
71
+ else:
72
+ if with_bn:
73
+ self.bns = torch.nn.ModuleList()
74
+ self.bns.append(nn.BatchNorm1d(nhid))
75
+ self.layers.append(GraphConvolution(nfeat, nhid, with_bias=with_bias))
76
+ for i in range(nlayers-2):
77
+ self.layers.append(GraphConvolution(nhid, nhid, with_bias=with_bias))
78
+ if with_bn:
79
+ self.bns.append(nn.BatchNorm1d(nhid))
80
+ self.layers.append(GraphConvolution(nhid, nclass, with_bias=with_bias))
81
+
82
+ self.dropout = dropout
83
+ self.lr = lr
84
+ if not with_relu:
85
+ self.weight_decay = 0
86
+ else:
87
+ self.weight_decay = weight_decay
88
+ self.with_relu = with_relu
89
+ self.with_bn = with_bn
90
+ self.with_bias = with_bias
91
+ self.output = None
92
+ self.best_model = None
93
+ self.best_output = None
94
+ self.adj_norm = None
95
+ self.features = None
96
+ self.multi_label = None
97
+
98
+ def forward(self, x, adj):
99
+ for ix, layer in enumerate(self.layers):
100
+ x = layer(x, adj)
101
+ if ix != len(self.layers) - 1:
102
+ x = self.bns[ix](x) if self.with_bn else x
103
+ if self.with_relu:
104
+ x = F.relu(x)
105
+ x = F.dropout(x, self.dropout, training=self.training)
106
+
107
+ if self.multi_label:
108
+ return torch.sigmoid(x)
109
+ else:
110
+ return F.log_softmax(x, dim=1)
111
+
112
+ def forward_sampler(self, x, adjs):
113
+ # for ix, layer in enumerate(self.layers):
114
+ for ix, (adj, _, size) in enumerate(adjs):
115
+ x = self.layers[ix](x, adj)
116
+ if ix != len(self.layers) - 1:
117
+ x = self.bns[ix](x) if self.with_bn else x
118
+ if self.with_relu:
119
+ x = F.relu(x)
120
+ x = F.dropout(x, self.dropout, training=self.training)
121
+
122
+ if self.multi_label:
123
+ return torch.sigmoid(x)
124
+ else:
125
+ return F.log_softmax(x, dim=1)
126
+
127
+ def forward_sampler_syn(self, x, adjs):
128
+ for ix, (adj) in enumerate(adjs):
129
+ x = self.layers[ix](x, adj)
130
+ if ix != len(self.layers) - 1:
131
+ x = self.bns[ix](x) if self.with_bn else x
132
+ if self.with_relu:
133
+ x = F.relu(x)
134
+ x = F.dropout(x, self.dropout, training=self.training)
135
+
136
+ if self.multi_label:
137
+ return torch.sigmoid(x)
138
+ else:
139
+ return F.log_softmax(x, dim=1)
140
+
141
+
142
+ def initialize(self):
143
+ """Initialize parameters of GCN.
144
+ """
145
+ for layer in self.layers:
146
+ layer.reset_parameters()
147
+ if self.with_bn:
148
+ for bn in self.bns:
149
+ bn.reset_parameters()
150
+
151
+ def fit(self, features, adj, labels, idx_train, idx_val=None, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, **kwargs):
152
+
153
+ if initialize:
154
+ self.initialize()
155
+
156
+ # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
157
+ if type(adj) is not torch.Tensor:
158
+ features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
159
+ else:
160
+ features = features.to(self.device)
161
+ adj = adj.to(self.device)
162
+ labels = labels.to(self.device)
163
+
164
+ if normalize:
165
+ if utils.is_sparse_tensor(adj):
166
+ adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
167
+ else:
168
+ adj_norm = utils.normalize_adj_tensor(adj)
169
+ else:
170
+ adj_norm = adj
171
+
172
+ if 'feat_norm' in kwargs and kwargs['feat_norm']:
173
+ from utils import row_normalize_tensor
174
+ features = row_normalize_tensor(features-features.min())
175
+
176
+ self.adj_norm = adj_norm
177
+ self.features = features
178
+
179
+ if len(labels.shape) > 1:
180
+ self.multi_label = True
181
+ self.loss = torch.nn.BCELoss()
182
+ else:
183
+ self.multi_label = False
184
+ self.loss = F.nll_loss
185
+
186
+ labels = labels.float() if self.multi_label else labels
187
+ self.labels = labels
188
+
189
+
190
+ if idx_val is not None:
191
+ self._train_with_val2(labels, idx_train, idx_val, train_iters, verbose)
192
+ else:
193
+ self._train_without_val2(labels, idx_train, train_iters, verbose)
194
+
195
+ def _train_without_val2(self, labels, idx_train, train_iters, verbose):
196
+ self.train()
197
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
198
+ for i in range(train_iters):
199
+ if i == train_iters // 2:
200
+ lr = self.lr*0.1
201
+ optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)
202
+
203
+ optimizer.zero_grad()
204
+ output = self.forward(self.features, self.adj_norm)
205
+ loss_train = self.loss(output[idx_train], labels[idx_train])
206
+ loss_train.backward()
207
+ optimizer.step()
208
+ if verbose and i % 10 == 0:
209
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
210
+
211
+ self.eval()
212
+ output = self.forward(self.features, self.adj_norm)
213
+ self.output = output
214
+
215
+ def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
216
+ '''data: full data class'''
217
+ if initialize:
218
+ self.initialize()
219
+
220
+ if type(adj) is not torch.Tensor:
221
+ features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
222
+ else:
223
+ features = features.to(self.device)
224
+ adj = adj.to(self.device)
225
+ labels = labels.to(self.device)
226
+
227
+ if normalize:
228
+ if utils.is_sparse_tensor(adj):
229
+ adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
230
+ else:
231
+ adj_norm = utils.normalize_adj_tensor(adj)
232
+ else:
233
+ adj_norm = adj
234
+
235
+ if 'feat_norm' in kwargs and kwargs['feat_norm']:
236
+ from utils import row_normalize_tensor
237
+ features = row_normalize_tensor(features-features.min())
238
+
239
+ self.adj_norm = adj_norm
240
+ self.features = features
241
+
242
+ if len(labels.shape) > 1:
243
+ self.multi_label = True
244
+ self.loss = torch.nn.BCELoss()
245
+ else:
246
+ self.multi_label = False
247
+ self.loss = F.nll_loss
248
+
249
+ labels = labels.float() if self.multi_label else labels
250
+ self.labels = labels
251
+
252
+ if noval:
253
+ self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
254
+ else:
255
+ self._train_with_val(labels, data, train_iters, verbose)
256
+
257
+ def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
258
+ if adj_val:
259
+ feat_full, adj_full = data.feat_val, data.adj_val
260
+ else:
261
+ feat_full, adj_full = data.feat_full, data.adj_full
262
+ feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
263
+ adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
264
+ labels_val = torch.LongTensor(data.labels_val).to(self.device)
265
+
266
+ if verbose:
267
+ print('=== training gcn model ===')
268
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
269
+
270
+ best_acc_val = 0
271
+
272
+ for i in range(train_iters):
273
+ if i == train_iters // 2:
274
+ lr = self.lr*0.1
275
+ optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)
276
+
277
+ self.train()
278
+ optimizer.zero_grad()
279
+ output = self.forward(self.features, self.adj_norm)
280
+ loss_train = self.loss(output, labels)
281
+ loss_train.backward()
282
+ optimizer.step()
283
+
284
+ if verbose and i % 100 == 0:
285
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
286
+
287
+ with torch.no_grad():
288
+ self.eval()
289
+ output = self.forward(feat_full, adj_full_norm)
290
+
291
+ if adj_val:
292
+ loss_val = F.nll_loss(output, labels_val)
293
+ acc_val = utils.accuracy(output, labels_val)
294
+ else:
295
+ loss_val = F.nll_loss(output[data.idx_val], labels_val)
296
+ acc_val = utils.accuracy(output[data.idx_val], labels_val)
297
+
298
+ if acc_val > best_acc_val:
299
+ best_acc_val = acc_val
300
+ self.output = output
301
+ weights = deepcopy(self.state_dict())
302
+
303
+ if verbose:
304
+ print('=== picking the best model according to the performance on validation ===')
305
+ self.load_state_dict(weights)
306
+
307
+
308
+ def test(self, idx_test):
309
+ """Evaluate GCN performance on test set.
310
+ Parameters
311
+ ----------
312
+ idx_test :
313
+ node testing indices
314
+ """
315
+ self.eval()
316
+ output = self.predict()
317
+ # output = self.output
318
+ loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
319
+ acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
320
+ print("Test set results:",
321
+ "loss= {:.4f}".format(loss_test.item()),
322
+ "accuracy= {:.4f}".format(acc_test.item()))
323
+ return acc_test.item()
324
+
325
+
326
+ @torch.no_grad()
327
+ def predict(self, features=None, adj=None):
328
+ """By default, the inputs should be unnormalized adjacency
329
+ Parameters
330
+ ----------
331
+ features :
332
+ node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
333
+ adj :
334
+ adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
335
+ Returns
336
+ -------
337
+ torch.FloatTensor
338
+ output (log probabilities) of GCN
339
+ """
340
+
341
+ self.eval()
342
+ if features is None and adj is None:
343
+ return self.forward(self.features, self.adj_norm)
344
+ else:
345
+ if type(adj) is not torch.Tensor:
346
+ features, adj = utils.to_tensor(features, adj, device=self.device)
347
+
348
+ self.features = features
349
+ if utils.is_sparse_tensor(adj):
350
+ self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
351
+ else:
352
+ self.adj_norm = utils.normalize_adj_tensor(adj)
353
+ return self.forward(self.features, self.adj_norm)
354
+
355
+ @torch.no_grad()
356
+ def predict_unnorm(self, features=None, adj=None):
357
+ self.eval()
358
+ if features is None and adj is None:
359
+ return self.forward(self.features, self.adj_norm)
360
+ else:
361
+ if type(adj) is not torch.Tensor:
362
+ features, adj = utils.to_tensor(features, adj, device=self.device)
363
+
364
+ self.features = features
365
+ self.adj_norm = adj
366
+ return self.forward(self.features, self.adj_norm)
367
+
368
+
369
+ def _train_with_val2(self, labels, idx_train, idx_val, train_iters, verbose):
370
+ if verbose:
371
+ print('=== training gcn model ===')
372
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
373
+
374
+ best_loss_val = 100
375
+ best_acc_val = 0
376
+
377
+ for i in range(train_iters):
378
+ if i == train_iters // 2:
379
+ lr = self.lr*0.1
380
+ optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)
381
+
382
+ self.train()
383
+ optimizer.zero_grad()
384
+ output = self.forward(self.features, self.adj_norm)
385
+ loss_train = F.nll_loss(output[idx_train], labels[idx_train])
386
+ loss_train.backward()
387
+ optimizer.step()
388
+
389
+ if verbose and i % 10 == 0:
390
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
391
+
392
+ self.eval()
393
+ output = self.forward(self.features, self.adj_norm)
394
+ loss_val = F.nll_loss(output[idx_val], labels[idx_val])
395
+ acc_val = utils.accuracy(output[idx_val], labels[idx_val])
396
+
397
+ if acc_val > best_acc_val:
398
+ best_acc_val = acc_val
399
+ self.output = output
400
+ weights = deepcopy(self.state_dict())
401
+
402
+ if verbose:
403
+ print('=== picking the best model according to the performance on validation ===')
404
+ self.load_state_dict(weights)
GCond/models/myappnp.py ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """multiple transformaiton and multiple propagation"""
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import torch
6
+ import torch.optim as optim
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.modules.module import Module
9
+ from deeprobust.graph import utils
10
+ from copy import deepcopy
11
+ from sklearn.metrics import f1_score
12
+ from torch.nn import init
13
+ import torch_sparse
14
+
15
+
16
+ class APPNP(nn.Module):
17
+
18
+ def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
19
+ ntrans=1, with_bias=True, with_bn=False, device=None):
20
+
21
+ super(APPNP, self).__init__()
22
+
23
+ assert device is not None, "Please specify 'device'!"
24
+ self.device = device
25
+ self.nfeat = nfeat
26
+ self.nclass = nclass
27
+ self.alpha = 0.1
28
+
29
+ with_bn = False
30
+
31
+ self.layers = nn.ModuleList([])
32
+ if ntrans == 1:
33
+ self.layers.append(MyLinear(nfeat, nclass))
34
+ else:
35
+ self.layers.append(MyLinear(nfeat, nhid))
36
+ if with_bn:
37
+ self.bns = torch.nn.ModuleList()
38
+ self.bns.append(nn.BatchNorm1d(nhid))
39
+ for i in range(ntrans-2):
40
+ if with_bn:
41
+ self.bns.append(nn.BatchNorm1d(nhid))
42
+ self.layers.append(MyLinear(nhid, nhid))
43
+ self.layers.append(MyLinear(nhid, nclass))
44
+
45
+ self.nlayers = nlayers
46
+ self.weight_decay = weight_decay
47
+ self.dropout = dropout
48
+ self.lr = lr
49
+ self.with_bn = with_bn
50
+ self.with_bias = with_bias
51
+ self.output = None
52
+ self.best_model = None
53
+ self.best_output = None
54
+ self.adj_norm = None
55
+ self.features = None
56
+ self.multi_label = None
57
+ self.sparse_dropout = SparseDropout(dprob=0)
58
+
59
+ def forward(self, x, adj):
60
+ for ix, layer in enumerate(self.layers):
61
+ x = layer(x)
62
+ if ix != len(self.layers) - 1:
63
+ x = self.bns[ix](x) if self.with_bn else x
64
+ x = F.relu(x)
65
+ x = F.dropout(x, self.dropout, training=self.training)
66
+
67
+ h = x
68
+ # here nlayers means K
69
+ for i in range(self.nlayers):
70
+ # adj_drop = self.sparse_dropout(adj, training=self.training)
71
+ adj_drop = adj
72
+ x = torch.spmm(adj_drop, x)
73
+ x = x * (1 - self.alpha)
74
+ x = x + self.alpha * h
75
+
76
+
77
+ if self.multi_label:
78
+ return torch.sigmoid(x)
79
+ else:
80
+ return F.log_softmax(x, dim=1)
81
+
82
+ def forward_sampler(self, x, adjs):
83
+ for ix, layer in enumerate(self.layers):
84
+ x = layer(x)
85
+ if ix != len(self.layers) - 1:
86
+ x = self.bns[ix](x) if self.with_bn else x
87
+ x = F.relu(x)
88
+ x = F.dropout(x, self.dropout, training=self.training)
89
+
90
+ h = x
91
+ for ix, (adj, _, size) in enumerate(adjs):
92
+ # x_target = x[: size[1]]
93
+ # x = self.layers[ix]((x, x_target), edge_index)
94
+ # adj = adj.to(self.device)
95
+ # adj_drop = F.dropout(adj, p=self.dropout)
96
+ adj_drop = adj
97
+ h = h[: size[1]]
98
+ x = torch_sparse.matmul(adj_drop, x)
99
+ x = x * (1 - self.alpha)
100
+ x = x + self.alpha * h
101
+
102
+ if self.multi_label:
103
+ return torch.sigmoid(x)
104
+ else:
105
+ return F.log_softmax(x, dim=1)
106
+
107
+ def forward_sampler_syn(self, x, adjs):
108
+ for ix, layer in enumerate(self.layers):
109
+ x = layer(x)
110
+ if ix != len(self.layers) - 1:
111
+ x = self.bns[ix](x) if self.with_bn else x
112
+ x = F.relu(x)
113
+ x = F.dropout(x, self.dropout, training=self.training)
114
+
115
+ for ix, (adj) in enumerate(adjs):
116
+ # x_target = x[: size[1]]
117
+ # x = self.layers[ix]((x, x_target), edge_index)
118
+ # adj = adj.to(self.device)
119
+ x = torch_sparse.matmul(adj, x)
120
+
121
+ if self.multi_label:
122
+ return torch.sigmoid(x)
123
+ else:
124
+ return F.log_softmax(x, dim=1)
125
+
126
+
127
+ def initialize(self):
128
+ """Initialize parameters of GCN.
129
+ """
130
+ for layer in self.layers:
131
+ layer.reset_parameters()
132
+ if self.with_bn:
133
+ for bn in self.bns:
134
+ bn.reset_parameters()
135
+
136
+ def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
137
+ '''data: full data class'''
138
+ if initialize:
139
+ self.initialize()
140
+
141
+ # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
142
+ if type(adj) is not torch.Tensor:
143
+ features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
144
+ else:
145
+ features = features.to(self.device)
146
+ adj = adj.to(self.device)
147
+ labels = labels.to(self.device)
148
+
149
+ if normalize:
150
+ if utils.is_sparse_tensor(adj):
151
+ adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
152
+ else:
153
+ adj_norm = utils.normalize_adj_tensor(adj)
154
+ else:
155
+ adj_norm = adj
156
+
157
+ if 'feat_norm' in kwargs and kwargs['feat_norm']:
158
+ from utils import row_normalize_tensor
159
+ features = row_normalize_tensor(features-features.min())
160
+
161
+ self.adj_norm = adj_norm
162
+ self.features = features
163
+
164
+ if len(labels.shape) > 1:
165
+ self.multi_label = True
166
+ self.loss = torch.nn.BCELoss()
167
+ else:
168
+ self.multi_label = False
169
+ self.loss = F.nll_loss
170
+
171
+ labels = labels.float() if self.multi_label else labels
172
+ self.labels = labels
173
+
174
+ if noval:
175
+ # self._train_without_val(labels, data, train_iters, verbose)
176
+ # self._train_without_val(labels, data, train_iters, verbose)
177
+ self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
178
+ else:
179
+ self._train_with_val(labels, data, train_iters, verbose)
180
+
181
+ def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
182
+ if adj_val:
183
+ feat_full, adj_full = data.feat_val, data.adj_val
184
+ else:
185
+ feat_full, adj_full = data.feat_full, data.adj_full
186
+
187
+ feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
188
+ adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
189
+ labels_val = torch.LongTensor(data.labels_val).to(self.device)
190
+
191
+ if verbose:
192
+ print('=== training gcn model ===')
193
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
194
+
195
+ best_acc_val = 0
196
+
197
+ for i in range(train_iters):
198
+ if i == train_iters // 2:
199
+ lr = self.lr*0.1
200
+ optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)
201
+
202
+ self.train()
203
+ optimizer.zero_grad()
204
+ output = self.forward(self.features, self.adj_norm)
205
+ loss_train = self.loss(output, labels)
206
+ loss_train.backward()
207
+ optimizer.step()
208
+
209
+ if verbose and i % 100 == 0:
210
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
211
+
212
+ with torch.no_grad():
213
+ self.eval()
214
+ output = self.forward(feat_full, adj_full_norm)
215
+ if adj_val:
216
+ loss_val = F.nll_loss(output, labels_val)
217
+ acc_val = utils.accuracy(output, labels_val)
218
+ else:
219
+ loss_val = F.nll_loss(output[data.idx_val], labels_val)
220
+ acc_val = utils.accuracy(output[data.idx_val], labels_val)
221
+
222
+ if acc_val > best_acc_val:
223
+ best_acc_val = acc_val
224
+ self.output = output
225
+ weights = deepcopy(self.state_dict())
226
+
227
+ if verbose:
228
+ print('=== picking the best model according to the performance on validation ===')
229
+ self.load_state_dict(weights)
230
+
231
+
232
+ def test(self, idx_test):
233
+ """Evaluate GCN performance on test set.
234
+ Parameters
235
+ ----------
236
+ idx_test :
237
+ node testing indices
238
+ """
239
+ self.eval()
240
+ output = self.predict()
241
+ # output = self.output
242
+ loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
243
+ acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
244
+ print("Test set results:",
245
+ "loss= {:.4f}".format(loss_test.item()),
246
+ "accuracy= {:.4f}".format(acc_test.item()))
247
+ return acc_test.item()
248
+
249
+
250
+ @torch.no_grad()
251
+ def predict(self, features=None, adj=None):
252
+ """By default, the inputs should be unnormalized adjacency
253
+ Parameters
254
+ ----------
255
+ features :
256
+ node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
257
+ adj :
258
+ adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
259
+ Returns
260
+ -------
261
+ torch.FloatTensor
262
+ output (log probabilities) of GCN
263
+ """
264
+
265
+ self.eval()
266
+ if features is None and adj is None:
267
+ return self.forward(self.features, self.adj_norm)
268
+ else:
269
+ if type(adj) is not torch.Tensor:
270
+ features, adj = utils.to_tensor(features, adj, device=self.device)
271
+
272
+ self.features = features
273
+ if utils.is_sparse_tensor(adj):
274
+ self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
275
+ else:
276
+ self.adj_norm = utils.normalize_adj_tensor(adj)
277
+ return self.forward(self.features, self.adj_norm)
278
+
279
+ @torch.no_grad()
280
+ def predict_unnorm(self, features=None, adj=None):
281
+ self.eval()
282
+ if features is None and adj is None:
283
+ return self.forward(self.features, self.adj_norm)
284
+ else:
285
+ if type(adj) is not torch.Tensor:
286
+ features, adj = utils.to_tensor(features, adj, device=self.device)
287
+
288
+ self.features = features
289
+ self.adj_norm = adj
290
+ return self.forward(self.features, self.adj_norm)
291
+
292
+
293
+
294
+ class MyLinear(Module):
295
+ """Simple Linear layer, modified from https://github.com/tkipf/pygcn
296
+ """
297
+
298
+ def __init__(self, in_features, out_features, with_bias=True):
299
+ super(MyLinear, self).__init__()
300
+ self.in_features = in_features
301
+ self.out_features = out_features
302
+ self.weight = Parameter(torch.FloatTensor(in_features, out_features))
303
+ if with_bias:
304
+ self.bias = Parameter(torch.FloatTensor(out_features))
305
+ else:
306
+ self.register_parameter('bias', None)
307
+ self.reset_parameters()
308
+
309
+ def reset_parameters(self):
310
+ # stdv = 1. / math.sqrt(self.weight.size(1))
311
+ stdv = 1. / math.sqrt(self.weight.T.size(1))
312
+ self.weight.data.uniform_(-stdv, stdv)
313
+ if self.bias is not None:
314
+ self.bias.data.uniform_(-stdv, stdv)
315
+
316
+ def forward(self, input):
317
+ if input.data.is_sparse:
318
+ support = torch.spmm(input, self.weight)
319
+ else:
320
+ support = torch.mm(input, self.weight)
321
+ output = support
322
+ if self.bias is not None:
323
+ return output + self.bias
324
+ else:
325
+ return output
326
+
327
+ def __repr__(self):
328
+ return self.__class__.__name__ + ' (' \
329
+ + str(self.in_features) + ' -> ' \
330
+ + str(self.out_features) + ')'
331
+
332
+ class SparseDropout(torch.nn.Module):
333
+ def __init__(self, dprob=0.5):
334
+ super(SparseDropout, self).__init__()
335
+ self.kprob=1-dprob
336
+
337
+ def forward(self, x, training):
338
+ if training:
339
+ mask=((torch.rand(x._values().size())+(self.kprob)).floor()).type(torch.bool)
340
+ rc=x._indices()[:,mask]
341
+ val=x._values()[mask]*(1.0/self.kprob)
342
+ return torch.sparse.FloatTensor(rc, val, x.size())
343
+ else:
344
+ return x
GCond/models/myappnp1.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """multiple transformaiton and multiple propagation"""
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import torch
6
+ import torch.optim as optim
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.modules.module import Module
9
+ from deeprobust.graph import utils
10
+ from copy import deepcopy
11
+ from sklearn.metrics import f1_score
12
+ from torch.nn import init
13
+ import torch_sparse
14
+
15
+
16
+ class APPNP1(nn.Module):
17
+
18
+ def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
19
+ with_relu=True, with_bias=True, with_bn=False, device=None):
20
+
21
+ super(APPNP1, self).__init__()
22
+
23
+ assert device is not None, "Please specify 'device'!"
24
+ self.device = device
25
+ self.nfeat = nfeat
26
+ self.nclass = nclass
27
+ self.alpha = 0.1
28
+
29
+ if with_bn:
30
+ self.bns = torch.nn.ModuleList()
31
+ self.bns.append(nn.BatchNorm1d(nhid))
32
+
33
+ self.layers = nn.ModuleList([])
34
+ # self.layers.append(MyLinear(nfeat, nclass))
35
+ self.layers.append(MyLinear(nfeat, nhid))
36
+ # self.layers.append(MyLinear(nhid, nhid))
37
+ self.layers.append(MyLinear(nhid, nclass))
38
+
39
+ # if nlayers == 1:
40
+ # self.layers.append(nn.Linear(nfeat, nclass))
41
+ # else:
42
+ # self.layers.append(nn.Linear(nfeat, nhid))
43
+ # for i in range(nlayers-2):
44
+ # self.layers.append(nn.Linear(nhid, nhid))
45
+ # self.layers.append(nn.Linear(nhid, nclass))
46
+
47
+ self.nlayers = nlayers
48
+ self.dropout = dropout
49
+ self.lr = lr
50
+ if not with_relu:
51
+ self.weight_decay = 0
52
+ else:
53
+ self.weight_decay = weight_decay
54
+ self.with_relu = with_relu
55
+ self.with_bn = with_bn
56
+ self.with_bias = with_bias
57
+ self.output = None
58
+ self.best_model = None
59
+ self.best_output = None
60
+ self.adj_norm = None
61
+ self.features = None
62
+ self.multi_label = None
63
+ self.sparse_dropout = SparseDropout(dprob=0)
64
+
65
+ def forward(self, x, adj):
66
+ for ix, layer in enumerate(self.layers):
67
+ x = layer(x)
68
+ if ix != len(self.layers) - 1:
69
+ x = self.bns[ix](x) if self.with_bn else x
70
+ x = F.relu(x)
71
+ x = F.dropout(x, self.dropout, training=self.training)
72
+
73
+ h = x
74
+ # here nlayers means K
75
+ for i in range(self.nlayers):
76
+ # adj_drop = self.sparse_dropout(adj, training=self.training)
77
+ adj_drop = adj
78
+ x = torch.spmm(adj_drop, x)
79
+ x = x * (1 - self.alpha)
80
+ x = x + self.alpha * h
81
+
82
+
83
+ if self.multi_label:
84
+ return torch.sigmoid(x)
85
+ else:
86
+ return F.log_softmax(x, dim=1)
87
+
88
+ def forward_sampler(self, x, adjs):
89
+ for ix, layer in enumerate(self.layers):
90
+ x = layer(x)
91
+ if ix != len(self.layers) - 1:
92
+ x = self.bns[ix](x) if self.with_bn else x
93
+ x = F.relu(x)
94
+ x = F.dropout(x, self.dropout, training=self.training)
95
+
96
+ h = x
97
+ for ix, (adj, _, size) in enumerate(adjs):
98
+ # x_target = x[: size[1]]
99
+ # x = self.layers[ix]((x, x_target), edge_index)
100
+ # adj = adj.to(self.device)
101
+ # adj_drop = F.dropout(adj, p=self.dropout)
102
+ adj_drop = adj
103
+ h = h[: size[1]]
104
+ x = torch_sparse.matmul(adj_drop, x)
105
+ x = x * (1 - self.alpha)
106
+ x = x + self.alpha * h
107
+
108
+ if self.multi_label:
109
+ return torch.sigmoid(x)
110
+ else:
111
+ return F.log_softmax(x, dim=1)
112
+
113
+ def forward_sampler_syn(self, x, adjs):
114
+ for ix, layer in enumerate(self.layers):
115
+ x = layer(x)
116
+ if ix != len(self.layers) - 1:
117
+ x = self.bns[ix](x) if self.with_bn else x
118
+ x = F.relu(x)
119
+ x = F.dropout(x, self.dropout, training=self.training)
120
+
121
+ for ix, (adj) in enumerate(adjs):
122
+ # x_target = x[: size[1]]
123
+ # x = self.layers[ix]((x, x_target), edge_index)
124
+ # adj = adj.to(self.device)
125
+ x = torch_sparse.matmul(adj, x)
126
+
127
+ if self.multi_label:
128
+ return torch.sigmoid(x)
129
+ else:
130
+ return F.log_softmax(x, dim=1)
131
+
132
+
133
+ def initialize(self):
134
+ """Initialize parameters of GCN.
135
+ """
136
+ for layer in self.layers:
137
+ layer.reset_parameters()
138
+ if self.with_bn:
139
+ for bn in self.bns:
140
+ bn.reset_parameters()
141
+
142
def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
    """Train the model on (features, adj, labels) — typically a condensed
    graph — selecting the best weights on the validation split of `data`.

    Parameters
    ----------
    features, adj, labels :
        training graph; converted to tensors on ``self.device`` if needed
    data :
        full dataset holder providing feat_full/adj_full, idx_val,
        labels_val (and feat_val/adj_val for the inductive setting)
    train_iters : int
        number of training epochs
    initialize : bool
        re-initialize parameters before training
    verbose : bool
        print training progress
    normalize : bool
        symmetrically normalize `adj` before training
    patience :
        unused here; kept for interface compatibility with sibling models
    noval : bool
        when True, validate on data.feat_val/adj_val directly (inductive)
        instead of indexing into the full graph
    """
    if initialize:
        self.initialize()

    if type(adj) is not torch.Tensor:
        features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
    else:
        features = features.to(self.device)
        adj = adj.to(self.device)
        labels = labels.to(self.device)

    if normalize:
        if utils.is_sparse_tensor(adj):
            adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
        else:
            adj_norm = utils.normalize_adj_tensor(adj)
    else:
        adj_norm = adj

    # Optional row-normalization of min-shifted features, opt-in via kwargs.
    if 'feat_norm' in kwargs and kwargs['feat_norm']:
        from utils import row_normalize_tensor
        features = row_normalize_tensor(features-features.min())

    self.adj_norm = adj_norm
    self.features = features

    # A 2-D label matrix means a multi-label task (BCE on sigmoids);
    # a 1-D label vector means single-label (NLL on log-softmax).
    if len(labels.shape) > 1:
        self.multi_label = True
        self.loss = torch.nn.BCELoss()
    else:
        self.multi_label = False
        self.loss = F.nll_loss

    labels = labels.float() if self.multi_label else labels
    self.labels = labels

    if noval:
        self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
    else:
        self._train_with_val(labels, data, train_iters, verbose)
185
+
186
def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
    """Adam training loop with per-epoch validation.

    Evaluates on the validation split after every epoch, remembers the
    state_dict achieving the best validation accuracy, and restores it at
    the end (the learning rate is decayed once, halfway through training).

    Parameters
    ----------
    labels : torch.Tensor
        training labels aligned with self.features / self.adj_norm
    data :
        dataset holder providing feat_full/adj_full (or feat_val/adj_val
        when `adj_val` is True), idx_val and labels_val
    train_iters : int
        number of training epochs
    verbose : bool
        print progress every 100 epochs
    adj_val : bool
        when True, evaluate on the dedicated validation graph (inductive)
        instead of indexing into the full graph
    """
    if adj_val:
        feat_full, adj_full = data.feat_val, data.adj_val
    else:
        feat_full, adj_full = data.feat_full, data.adj_full
    feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
    adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
    labels_val = torch.LongTensor(data.labels_val).to(self.device)

    if verbose:
        print('=== training gcn model ===')
    optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

    best_acc_val = 0
    # Fix: seed `weights` with the current state so load_state_dict() below
    # cannot raise UnboundLocalError when no epoch improves on 0 accuracy.
    weights = deepcopy(self.state_dict())

    for i in range(train_iters):
        # Decay the learning rate once, halfway through training.
        if i == train_iters // 2:
            lr = self.lr*0.1
            optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

        self.train()
        optimizer.zero_grad()
        output = self.forward(self.features, self.adj_norm)
        loss_train = self.loss(output, labels)
        loss_train.backward()
        optimizer.step()

        if verbose and i % 100 == 0:
            print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

        with torch.no_grad():
            self.eval()
            output = self.forward(feat_full, adj_full_norm)
            if adj_val:
                loss_val = F.nll_loss(output, labels_val)
                acc_val = utils.accuracy(output, labels_val)
            else:
                loss_val = F.nll_loss(output[data.idx_val], labels_val)
                acc_val = utils.accuracy(output[data.idx_val], labels_val)

            if acc_val > best_acc_val:
                best_acc_val = acc_val
                self.output = output
                weights = deepcopy(self.state_dict())

    if verbose:
        print('=== picking the best model according to the performance on validation ===')
    self.load_state_dict(weights)
234
+
235
+
236
def test(self, idx_test):
    """Evaluate on the test split, print loss/accuracy, return accuracy.

    Parameters
    ----------
    idx_test :
        node testing indices
    """
    self.eval()
    output = self.predict()
    logits = output[idx_test]
    target = self.labels[idx_test]
    loss_test = F.nll_loss(logits, target)
    acc_test = utils.accuracy(logits, target)
    print("Test set results:",
          "loss= {:.4f}".format(loss_test.item()),
          "accuracy= {:.4f}".format(acc_test.item()))
    return acc_test.item()
252
+
253
+
254
@torch.no_grad()
def predict(self, features=None, adj=None):
    """Run a forward pass and return log-probabilities.

    `adj` is expected UNnormalized; it is normalized here and cached,
    along with `features`, for later calls. When both arguments are None,
    the features/adjacency stored during training are reused.

    Returns
    -------
    torch.FloatTensor
        output (log probabilities) of the model
    """
    self.eval()
    if features is None and adj is None:
        return self.forward(self.features, self.adj_norm)

    if type(adj) is not torch.Tensor:
        features, adj = utils.to_tensor(features, adj, device=self.device)

    self.features = features
    if utils.is_sparse_tensor(adj):
        self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
    else:
        self.adj_norm = utils.normalize_adj_tensor(adj)
    return self.forward(self.features, self.adj_norm)
282
+
283
@torch.no_grad()
def predict_unnorm(self, features=None, adj=None):
    """Like predict(), but uses `adj` exactly as given — no normalization."""
    self.eval()
    if features is None and adj is None:
        return self.forward(self.features, self.adj_norm)

    if type(adj) is not torch.Tensor:
        features, adj = utils.to_tensor(features, adj, device=self.device)
    self.features = features
    self.adj_norm = adj
    return self.forward(self.features, self.adj_norm)
295
+
296
+
297
+
298
class MyLinear(Module):
    """Simple Linear layer, modified from https://github.com/tkipf/pygcn

    Computes ``input @ weight (+ bias)``; sparse inputs go through
    ``torch.spmm``, dense inputs through ``torch.mm``.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init scaled by fan-in: weight is (in, out), so
        # weight.size(0) == weight.T.size(1) == in_features.
        stdv = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, input):
        # Pick the multiply that matches the input layout.
        matmul = torch.spmm if input.data.is_sparse else torch.mm
        out = matmul(input, self.weight)
        return out if self.bias is None else out + self.bias

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
335
+
336
class SparseDropout(torch.nn.Module):
    """Dropout for sparse COO tensors.

    Drops individual nonzero entries with probability `dprob` and rescales
    the survivors by 1/keep_prob (inverted dropout), returning a new sparse
    tensor of the same size.
    """

    def __init__(self, dprob=0.5):
        super(SparseDropout, self).__init__()
        # kprob is the keep probability.
        self.kprob = 1 - dprob

    def forward(self, x, training):
        if not training:
            return x
        values = x._values()
        # Bernoulli(kprob) keep-mask: rand in [0,1) plus kprob, floored.
        # Fix: generate it on the input's device — the original built a CPU
        # mask, which fails when indexing CUDA-resident indices/values.
        mask = ((torch.rand(values.size(), device=values.device) + self.kprob).floor()).type(torch.bool)
        rc = x._indices()[:, mask]
        val = values[mask] * (1.0 / self.kprob)
        return torch.sparse.FloatTensor(rc, val, x.size())
GCond/models/mycheby.py ADDED
@@ -0,0 +1,417 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import math
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.nn.parameter import Parameter
7
+ from torch.nn.modules.module import Module
8
+ from deeprobust.graph import utils
9
+ from copy import deepcopy
10
+ from sklearn.metrics import f1_score
11
+ from torch.nn import init
12
+ import torch_sparse
13
+ from torch_geometric.nn.inits import zeros
14
+ import scipy.sparse as sp
15
+ import numpy as np
16
+
17
+
18
class ChebConvolution(Module):
    """Chebyshev spectral graph convolution layer (ChebNet-style), written
    in the spirit of https://github.com/tkipf/pygcn.

    Output is sum_k W_k * T_k(adj) x, where T_k follows the Chebyshev
    recurrence T_k(x) = 2x T_{k-1}(x) - T_{k-2}(x).
    """

    def __init__(self, in_features, out_features, with_bias=True, single_param=True, K=2):
        """K is the filter order; set `single_param` to True to share one
        linear map across all orders and alleviate overfitting."""
        super(ChebConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # One (bias-free) linear map per Chebyshev order; the shared bias,
        # if any, is added once at the end of forward().
        self.lins = torch.nn.ModuleList([
            MyLinear(in_features, out_features, with_bias=False) for _ in range(K)])
        if with_bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.single_param = single_param
        self.reset_parameters()

    def reset_parameters(self):
        # zeros() from torch_geometric handles bias=None gracefully.
        for lin in self.lins:
            lin.reset_parameters()
        zeros(self.bias)

    def forward(self, input, adj, size=None):
        """Apply the Chebyshev filter.

        `adj` is the (rescaled) graph operator, either a
        torch_sparse.SparseTensor or a torch sparse tensor. With neighbor
        sampling, `size` restricts the zeroth-order term to the first
        size[1] target rows of the bipartite block.
        """
        x = input
        Tx_0 = x[:size[1]] if size is not None else x
        Tx_1 = x  # dummy
        output = self.lins[0](Tx_0)

        # First-order term: T_1 = adj @ x.
        if len(self.lins) > 1:
            if isinstance(adj, torch_sparse.SparseTensor):
                Tx_1 = torch_sparse.matmul(adj, x)
            else:
                Tx_1 = torch.spmm(adj, x)

            if self.single_param:
                output = output + self.lins[0](Tx_1)
            else:
                output = output + self.lins[1](Tx_1)

        # Higher orders via the recurrence T_k = 2*adj@T_{k-1} - T_{k-2}.
        for lin in self.lins[2:]:
            if self.single_param:
                lin = self.lins[0]  # weight sharing across orders
            if isinstance(adj, torch_sparse.SparseTensor):
                Tx_2 = torch_sparse.matmul(adj, Tx_1)
            else:
                Tx_2 = torch.spmm(adj, Tx_1)
            Tx_2 = 2. * Tx_2 - Tx_0
            output = output + lin.forward(Tx_2)
            Tx_0, Tx_1 = Tx_1, Tx_2

        if self.bias is not None:
            return output + self.bias
        else:
            return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
83
+
84
+
85
class Cheby(nn.Module):
    """ChebNet: a stack of ChebConvolution layers with optional batch-norm,
    plus fit/evaluate helpers mirroring the GCN wrapper in this repo.

    `fit_with_val` trains on a (possibly condensed) graph while model
    selection is done on the validation split of the full dataset.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
            with_relu=True, with_bias=True, with_bn=False, device=None):

        super(Cheby, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        self.layers = nn.ModuleList([])

        if nlayers == 1:
            self.layers.append(ChebConvolution(nfeat, nclass, with_bias=with_bias))
        else:
            if with_bn:
                self.bns = torch.nn.ModuleList()
                self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(ChebConvolution(nfeat, nhid, with_bias=with_bias))
            for i in range(nlayers-2):
                self.layers.append(ChebConvolution(nhid, nhid, with_bias=with_bias))
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(ChebConvolution(nhid, nclass, with_bias=with_bias))

        self.dropout = dropout
        self.lr = lr
        self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bn = with_bn
        self.with_bias = with_bias
        self.output = None
        self.best_model = None
        self.best_output = None
        self.adj_norm = None
        self.features = None
        self.multi_label = None

    def forward(self, x, adj):
        """Full-graph forward pass; returns sigmoid outputs for multi-label
        tasks and log-softmax otherwise."""
        for ix, layer in enumerate(self.layers):
            x = layer(x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Forward pass over neighbor-sampler bipartite blocks
        (adj, e_id, size) triples, one per layer."""
        for ix, (adj, _, size) in enumerate(adjs):
            x = self.layers[ix](x, adj, size=size)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Forward pass over synthetic sampled adjacencies (no size info)."""
        for ix, (adj) in enumerate(adjs):
            x = self.layers[ix](x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def initialize(self):
        """Re-initialize parameters of all conv and batch-norm layers."""
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        """Train on (features, adj, labels); pick the best weights on the
        validation split of `data` (the full data holder). `noval=True`
        validates on data.feat_val/adj_val directly (inductive setting)."""
        if initialize:
            self.initialize()

        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        # ChebNet operates on the adjacency without self-loops, so the
        # diagonal is stripped before normalization.
        # NOTE(review): this assumes a dense `adj` — a sparse torch tensor
        # would not support this subtraction directly; confirm callers.
        adj = adj - torch.eye(adj.shape[0]).to(self.device)  # cheby
        if normalize:
            adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        # Optional row-normalization of min-shifted features, opt-in.
        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        # 2-D labels -> multi-label with BCE; 1-D -> NLL loss.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        """Adam training loop; keeps the state_dict with the best validation
        accuracy and restores it at the end. The learning rate is decayed
        once, halfway through training."""
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0
        # Fix: seed `weights` with the current state so load_state_dict()
        # below cannot raise UnboundLocalError when no epoch improves on
        # 0 accuracy.
        weights = deepcopy(self.state_dict())

        for i in range(train_iters):
            if i == train_iters // 2:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate on the test split, print loss/accuracy, return accuracy.

        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()

    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency.

        When both arguments are None, the stored training features and
        adjacency are reused.

        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of the model
        """
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            # NOTE(review): the Chebyshev-specific normalization below
            # (strip self-loops, scipy normalize_adj) is computed into the
            # local `adj` but never used — forward() runs on self.adj_norm
            # set above. Preserved as-is to avoid a behavior change; confirm
            # whether this was meant to replace self.adj_norm.
            adj = utils.to_scipy(adj)
            adj = adj-sp.eye(adj.shape[0])
            mx = normalize_adj(adj)
            adj = utils.sparse_mx_to_torch_sparse_tensor(mx).to(self.device)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like predict(), but uses `adj` as given — no normalization."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
352
+
353
class MyLinear(Module):
    """Simple Linear layer, modified from https://github.com/tkipf/pygcn

    Plain affine map ``input @ weight (+ bias)``; handles both sparse and
    dense inputs.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Fan-in scaled uniform init: weight is (in, out), and
        # weight.T.size(1) == weight.size(0) == in_features.
        bound = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        if input.data.is_sparse:
            result = torch.spmm(input, self.weight)
        else:
            result = torch.mm(input, self.weight)
        if self.bias is None:
            return result
        return result + self.bias

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
390
+
391
+
392
+
393
def normalize_adj(mx):
    """Symmetrically normalize a sparse adjacency matrix:

        A' = (D + I)^-1/2 * ( A + I ) * (D + I)^-1/2

    Parameters
    ----------
    mx : scipy.sparse matrix
        adjacency to be normalized (without self-loops; they are added here)

    Returns
    -------
    scipy.sparse matrix
        normalized matrix
    """
    # Fix: `type(mx) is not sp.lil.lil_matrix` depends on scipy's private
    # submodule layout, which changed in newer releases; use the public
    # predicate instead.
    if not sp.isspmatrix_lil(mx):
        mx = mx.tolil()
    mx = mx + sp.eye(mx.shape[0])  # add self-loops
    rowsum = np.array(mx.sum(1))
    # D^{-1/2}; isolated nodes have degree 0 and would produce inf, which
    # is zeroed out below (warning suppressed — it is expected).
    with np.errstate(divide='ignore'):
        r_inv = np.power(rowsum, -1/2).flatten()
    r_inv[np.isinf(r_inv)] = 0.
    r_mat_inv = sp.diags(r_inv)
    mx = r_mat_inv.dot(mx)
    mx = mx.dot(r_mat_inv)
    return mx
GCond/models/mygatconv.py ADDED
@@ -0,0 +1,203 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Union, Tuple, Optional
2
+ from torch_geometric.typing import (OptPairTensor, Adj, Size, NoneType,
3
+ OptTensor)
4
+
5
+ import torch
6
+ from torch import Tensor
7
+ import torch.nn.functional as F
8
+ from torch.nn import Parameter, Linear
9
+ from torch_sparse import SparseTensor, set_diag
10
+ from torch_geometric.nn.conv import MessagePassing
11
+ from torch_geometric.utils import remove_self_loops, add_self_loops, softmax
12
+
13
+ from torch_geometric.nn.inits import glorot, zeros
14
+
15
+
16
class GATConv(MessagePassing):
    r"""The graph attentional operator from the `"Graph Attention Networks"
    <https://arxiv.org/abs/1710.10903>`_ paper, extended with an optional
    per-edge weight that rescales messages (used for condensed graphs).

    .. math::
        \mathbf{x}^{\prime}_i = \alpha_{i,i}\mathbf{\Theta}\mathbf{x}_{i} +
        \sum_{j \in \mathcal{N}(i)} \alpha_{i,j}\mathbf{\Theta}\mathbf{x}_{j},

    where the attention coefficients :math:`\alpha_{i,j}` are softmax-
    normalized LeakyReLU scores over each node's neighborhood (see paper).

    Args:
        in_channels (int or tuple): Size of each input sample. A tuple
            corresponds to the sizes of source and target dimensionalities.
        out_channels (int): Size of each output sample.
        heads (int, optional): Number of multi-head-attentions.
            (default: :obj:`1`)
        concat (bool, optional): If set to :obj:`False`, the multi-head
            attentions are averaged instead of concatenated.
            (default: :obj:`True`)
        negative_slope (float, optional): LeakyReLU angle of the negative
            slope. (default: :obj:`0.2`)
        dropout (float, optional): Dropout probability of the normalized
            attention coefficients which exposes each node to a stochastically
            sampled neighborhood during training. (default: :obj:`0`)
        add_self_loops (bool, optional): If set to :obj:`False`, will not add
            self-loops to the input graph. (default: :obj:`True`)
        bias (bool, optional): If set to :obj:`False`, the layer will not learn
            an additive bias. (default: :obj:`True`)
        **kwargs (optional): Additional arguments of
            :class:`torch_geometric.nn.conv.MessagePassing`.
    """
    _alpha: OptTensor

    def __init__(self, in_channels: Union[int, Tuple[int, int]],
                 out_channels: int, heads: int = 1, concat: bool = True,
                 negative_slope: float = 0.2, dropout: float = 0.0,
                 add_self_loops: bool = True, bias: bool = True, **kwargs):
        kwargs.setdefault('aggr', 'add')
        super(GATConv, self).__init__(node_dim=0, **kwargs)

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.heads = heads
        self.concat = concat
        self.negative_slope = negative_slope
        self.dropout = dropout
        self.add_self_loops = add_self_loops

        # Single int in_channels: one shared projection for source/target;
        # a tuple gets separate projections for a bipartite graph.
        if isinstance(in_channels, int):
            self.lin_l = Linear(in_channels, heads * out_channels, bias=False)
            self.lin_r = self.lin_l
        else:
            self.lin_l = Linear(in_channels[0], heads * out_channels, False)
            self.lin_r = Linear(in_channels[1], heads * out_channels, False)

        # Learnable attention vectors, one per head, for source (l) and
        # target (r) sides.
        self.att_l = Parameter(torch.Tensor(1, heads, out_channels))
        self.att_r = Parameter(torch.Tensor(1, heads, out_channels))

        if bias and concat:
            self.bias = Parameter(torch.Tensor(heads * out_channels))
        elif bias and not concat:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)

        self._alpha = None

        self.reset_parameters()
        # Optional per-edge scaling, cached across calls (see forward()).
        self.edge_weight = None

    def reset_parameters(self):
        glorot(self.lin_l.weight)
        glorot(self.lin_r.weight)
        glorot(self.att_l)
        glorot(self.att_r)
        zeros(self.bias)

    def forward(self, x: Union[Tensor, OptPairTensor], edge_index: Adj,
                size: Size = None, return_attention_weights=None, edge_weight=None):
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, NoneType) -> Tensor  # noqa
        # type: (Union[Tensor, OptPairTensor], Tensor, Size, bool) -> Tuple[Tensor, Tuple[Tensor, Tensor]]  # noqa
        # type: (Union[Tensor, OptPairTensor], SparseTensor, Size, bool) -> Tuple[Tensor, SparseTensor]  # noqa
        r"""
        Args:
            return_attention_weights (bool, optional): If set to :obj:`True`,
                will additionally return the tuple
                :obj:`(edge_index, attention_weights)`, holding the computed
                attention weights for each edge. (default: :obj:`None`)
        """
        H, C = self.heads, self.out_channels

        # Project features and compute per-node attention scores for the
        # source (l) and target (r) roles.
        x_l: OptTensor = None
        x_r: OptTensor = None
        alpha_l: OptTensor = None
        alpha_r: OptTensor = None
        if isinstance(x, Tensor):
            assert x.dim() == 2, 'Static graphs not supported in `GATConv`.'
            x_l = x_r = self.lin_l(x).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            alpha_r = (x_r * self.att_r).sum(dim=-1)
        else:
            x_l, x_r = x[0], x[1]
            assert x[0].dim() == 2, 'Static graphs not supported in `GATConv`.'
            x_l = self.lin_l(x_l).view(-1, H, C)
            alpha_l = (x_l * self.att_l).sum(dim=-1)
            if x_r is not None:
                x_r = self.lin_r(x_r).view(-1, H, C)
                alpha_r = (x_r * self.att_r).sum(dim=-1)

        assert x_l is not None
        assert alpha_l is not None

        if self.add_self_loops:
            if isinstance(edge_index, Tensor):
                num_nodes = x_l.size(0)
                if x_r is not None:
                    num_nodes = min(num_nodes, x_r.size(0))
                if size is not None:
                    num_nodes = min(size[0], size[1])
                edge_index, _ = remove_self_loops(edge_index)
                edge_index, _ = add_self_loops(edge_index, num_nodes=num_nodes)

                # Cache the edge weights; refresh them whenever the number
                # of edges changes (i.e. a different graph is passed in).
                # NOTE(review): edge weights supplied for a graph with the
                # SAME edge count as the cached one are ignored — confirm
                # that callers always pass consistent weights per graph.
                if edge_weight is not None:
                    if self.edge_weight is None:
                        self.edge_weight = edge_weight

                    if edge_index.size(1) != self.edge_weight.shape[0]:
                        self.edge_weight = edge_weight

            elif isinstance(edge_index, SparseTensor):
                edge_index = set_diag(edge_index)

        # propagate_type: (x: OptPairTensor, alpha: OptPairTensor)

        out = self.propagate(edge_index, x=(x_l, x_r),
                             alpha=(alpha_l, alpha_r), size=size)

        alpha = self._alpha
        self._alpha = None

        if self.concat:
            out = out.view(-1, self.heads * self.out_channels)
        else:
            out = out.mean(dim=1)

        if self.bias is not None:
            out += self.bias

        if isinstance(return_attention_weights, bool):
            assert alpha is not None
            if isinstance(edge_index, Tensor):
                return out, (edge_index, alpha)
            elif isinstance(edge_index, SparseTensor):
                return out, edge_index.set_value(alpha, layout='coo')
        else:
            return out

    def message(self, x_j: Tensor, alpha_j: Tensor, alpha_i: OptTensor,
                index: Tensor, ptr: OptTensor,
                size_i: Optional[int]) -> Tensor:
        # Attention score per edge: LeakyReLU(source + target score),
        # softmax-normalized over each target's incoming edges.
        alpha = alpha_j if alpha_i is None else alpha_j + alpha_i
        alpha = F.leaky_relu(alpha, self.negative_slope)
        alpha = softmax(alpha, index, ptr, size_i)
        self._alpha = alpha  # stashed for return_attention_weights
        alpha = F.dropout(alpha, p=self.dropout, training=self.training)

        # Optional per-edge rescaling of the message (condensed-graph case).
        if self.edge_weight is not None:
            x_j = self.edge_weight.view(-1, 1, 1) * x_j
        return x_j * alpha.unsqueeze(-1)

    def __repr__(self):
        return '{}({}, {}, heads={})'.format(self.__class__.__name__,
                                             self.in_channels,
                                             self.out_channels, self.heads)
203
+
GCond/models/mygraphsage.py ADDED
@@ -0,0 +1,353 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import math
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.nn.parameter import Parameter
7
+ from torch.nn.modules.module import Module
8
+ from deeprobust.graph import utils
9
+ from copy import deepcopy
10
+ from sklearn.metrics import f1_score
11
+ from torch.nn import init
12
+ import torch_sparse
13
+ from torch_geometric.data import NeighborSampler
14
+ from torch_sparse import SparseTensor
15
+
16
+
17
class SageConvolution(Module):
    """GraphSAGE convolution layer (mean-style aggregator), similar in spirit
    to https://github.com/tkipf/pygcn.

    Computes ``adj @ (input @ W_l) + b_l`` and, when ``root_weight`` is set,
    adds a separate root transformation ``input @ W_r + b_r``.
    """

    def __init__(self, in_features, out_features, with_bias=True, root_weight=False):
        super(SageConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight_l = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias_l = Parameter(torch.FloatTensor(out_features))
        else:
            # Fix: `with_bias` used to be silently ignored (biases were always
            # created and always added). Register None so the layer can be
            # genuinely bias-free; default behavior (with_bias=True) unchanged.
            self.register_parameter('bias_l', None)
        self.weight_r = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias_r = Parameter(torch.FloatTensor(out_features))
        else:
            self.register_parameter('bias_r', None)
        self.root_weight = root_weight
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)].

        Note: ``weight_l.T.size(1) == in_features`` — this deliberately
        deviates from pygcn's ``weight.size(1)`` (out_features) scaling.
        """
        stdv = 1. / math.sqrt(self.weight_l.T.size(1))
        self.weight_l.data.uniform_(-stdv, stdv)
        if self.bias_l is not None:
            self.bias_l.data.uniform_(-stdv, stdv)

        stdv = 1. / math.sqrt(self.weight_r.T.size(1))
        self.weight_r.data.uniform_(-stdv, stdv)
        if self.bias_r is not None:
            self.bias_r.data.uniform_(-stdv, stdv)

    def forward(self, input, adj, size=None):
        """Graph Convolutional Layer forward function.

        Parameters
        ----------
        input : node feature matrix (dense or torch sparse).
        adj : aggregation matrix; torch (sparse) tensor or torch_sparse.SparseTensor.
        size : optional (num_src, num_dst) pair from neighbor sampling; the
            destination nodes are the first ``size[1]`` rows of ``input``.
        """
        # Feature transform with the matmul matching the input layout.
        if input.data.is_sparse:
            support = torch.spmm(input, self.weight_l)
        else:
            support = torch.mm(input, self.weight_l)
        # Neighbourhood aggregation.
        if isinstance(adj, torch_sparse.SparseTensor):
            output = torch_sparse.matmul(adj, support)
        else:
            output = torch.spmm(adj, support)
        if self.bias_l is not None:
            output = output + self.bias_l

        if self.root_weight:
            # Separate transformation of the (destination) root nodes.
            root = input[:size[1]] if size is not None else input
            output = output + root @ self.weight_r
            if self.bias_r is not None:
                output = output + self.bias_r

        return output

    def __repr__(self):
        return self.__class__.__name__ + ' (' \
               + str(self.in_features) + ' -> ' \
               + str(self.out_features) + ')'
71
+
72
+
73
class GraphSage(nn.Module):
    """GraphSAGE-style GNN built from SageConvolution layers.

    Mirrors the GCN wrapper API used elsewhere in this repo: `fit_with_val`
    trains with neighbour sampling and keeps the checkpoint with the best
    validation accuracy; `predict`/`test` evaluate full-batch.
    """

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                with_relu=True, with_bias=True, with_bn=False, device=None):

        super(GraphSage, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        self.layers = nn.ModuleList([])

        # Single-layer model maps features straight to classes; otherwise
        # build nfeat -> nhid -> ... -> nclass with optional batch norm.
        if nlayers == 1:
            self.layers.append(SageConvolution(nfeat, nclass, with_bias=with_bias))
        else:
            if with_bn:
                self.bns = torch.nn.ModuleList()
                self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(SageConvolution(nfeat, nhid, with_bias=with_bias))
            for i in range(nlayers-2):
                self.layers.append(SageConvolution(nhid, nhid, with_bias=with_bias))
                if with_bn:
                    self.bns.append(nn.BatchNorm1d(nhid))
            self.layers.append(SageConvolution(nhid, nclass, with_bias=with_bias))

        self.dropout = dropout
        self.lr = lr
        # Convention carried over from the GCN wrapper: disable weight decay
        # when ReLU is disabled.
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        self.with_bn = with_bn
        self.with_bias = with_bias
        self.output = None       # cached full-graph output of the best checkpoint
        self.best_model = None
        self.best_output = None
        self.adj_norm = None     # normalized training adjacency (set in fit_with_val)
        self.features = None     # training features (set in fit_with_val)
        self.multi_label = None  # set in fit_with_val from the label shape

    def forward(self, x, adj):
        """Full-batch forward pass; sigmoid for multi-label, else log-softmax."""
        for ix, layer in enumerate(self.layers):
            x = layer(x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Forward pass over NeighborSampler output: one (adj, e_id, size)
        bipartite adjacency per layer."""
        # TODO: do we need normalization?
        # for ix, layer in enumerate(self.layers):
        for ix, (adj, _, size) in enumerate(adjs):
            # x_target = x[: size[1]]
            # x = self.layers[ix]((x, x_target), edge_index)
            # adj = adj.to(self.device)
            x = self.layers[ix](x, adj, size=size)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Forward pass over a plain list of adjacencies (synthetic graph),
        one per layer, without sampling sizes."""
        for ix, (adj) in enumerate(adjs):
            x = self.layers[ix](x, adj)
            if ix != len(self.layers) - 1:
                x = self.bns[ix](x) if self.with_bn else x
                if self.with_relu:
                    x = F.relu(x)
                x = F.dropout(x, self.dropout, training=self.training)

        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)


    def initialize(self):
        """Reset parameters of every layer (and batch norm, if present).
        """
        for layer in self.layers:
            layer.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        '''data: full data class.

        Trains on (features, adj, labels) and keeps the checkpoint with the
        best validation accuracy on `data`'s validation split.
        `noval=True` validates on data.adj_val instead of data.idx_val.
        `patience` is currently unused.
        '''
        if initialize:
            self.initialize()

        # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        # Multi-label targets use sigmoid + BCE; single-label uses NLL.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        # Pick the graph used for validation scoring.
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full
        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        adj_norm = self.adj_norm
        node_idx = torch.arange(adj_norm.shape[0]).long()

        # Convert the normalized adjacency to a torch_sparse.SparseTensor so
        # it can feed torch_geometric's NeighborSampler.
        edge_index = adj_norm.nonzero().T
        adj_norm = SparseTensor(row=edge_index[0], col=edge_index[1],
                value=adj_norm[edge_index[0], edge_index[1]], sparse_sizes=adj_norm.size()).t()
        # edge_index = adj_norm._indices()
        # adj_norm = SparseTensor(row=edge_index[0], col=edge_index[1],
        #         value=adj_norm._values(), sparse_sizes=adj_norm.size()).t()

        if adj_norm.density() > 0.5: # if the weighted graph is too dense, we need a larger neighborhood size
            sizes = [30, 20]
        else:
            sizes = [5, 5]
        train_loader = NeighborSampler(adj_norm,
                node_idx=node_idx,
                sizes=sizes, batch_size=len(node_idx),
                num_workers=0, return_e_id=False,
                num_nodes=adj_norm.size(0),
                shuffle=True)

        best_acc_val = 0
        for i in range(train_iters):
            # One-time learning-rate decay halfway through training.
            if i == train_iters // 2:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            # optimizer.zero_grad()
            # output = self.forward(self.features, self.adj_norm)
            # loss_train = self.loss(output, labels)
            # loss_train.backward()
            # optimizer.step()

            for batch_size, n_id, adjs in train_loader:
                adjs = [adj.to(self.device) for adj in adjs]
                optimizer.zero_grad()
                out = self.forward_sampler(self.features[n_id], adjs)
                loss_train = F.nll_loss(out, labels[n_id[:batch_size]])
                loss_train.backward()
                optimizer.step()


            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            # Track the checkpoint with the best validation accuracy.
            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)

    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()


    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            # Normalize the provided adjacency before prediction.
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like `predict`, but uses the given adjacency as-is (no normalization)."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
353
+
GCond/models/parametrized_adj.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch.nn.functional as F
3
+ import math
4
+ import torch
5
+ import torch.optim as optim
6
+ from torch.nn.parameter import Parameter
7
+ from torch.nn.modules.module import Module
8
+ from itertools import product
9
+ import numpy as np
10
+
11
+ class PGE(nn.Module):
12
+
13
+ def __init__(self, nfeat, nnodes, nhid=128, nlayers=3, device=None, args=None):
14
+ super(PGE, self).__init__()
15
+ if args.dataset in ['ogbn-arxiv', 'arxiv', 'flickr']:
16
+ nhid = 256
17
+ if args.dataset in ['reddit']:
18
+ nhid = 256
19
+ if args.reduction_rate==0.01:
20
+ nhid = 128
21
+ nlayers = 3
22
+ # nhid = 128
23
+
24
+ self.layers = nn.ModuleList([])
25
+ self.layers.append(nn.Linear(nfeat*2, nhid))
26
+ self.bns = torch.nn.ModuleList()
27
+ self.bns.append(nn.BatchNorm1d(nhid))
28
+ for i in range(nlayers-2):
29
+ self.layers.append(nn.Linear(nhid, nhid))
30
+ self.bns.append(nn.BatchNorm1d(nhid))
31
+ self.layers.append(nn.Linear(nhid, 1))
32
+
33
+ edge_index = np.array(list(product(range(nnodes), range(nnodes))))
34
+ self.edge_index = edge_index.T
35
+ self.nnodes = nnodes
36
+ self.device = device
37
+ self.reset_parameters()
38
+ self.cnt = 0
39
+ self.args = args
40
+ self.nnodes = nnodes
41
+
42
+ def forward(self, x, inference=False):
43
+ if self.args.dataset == 'reddit' and self.args.reduction_rate >= 0.01:
44
+ edge_index = self.edge_index
45
+ n_part = 5
46
+ splits = np.array_split(np.arange(edge_index.shape[1]), n_part)
47
+ edge_embed = []
48
+ for idx in splits:
49
+ tmp_edge_embed = torch.cat([x[edge_index[0][idx]],
50
+ x[edge_index[1][idx]]], axis=1)
51
+ for ix, layer in enumerate(self.layers):
52
+ tmp_edge_embed = layer(tmp_edge_embed)
53
+ if ix != len(self.layers) - 1:
54
+ tmp_edge_embed = self.bns[ix](tmp_edge_embed)
55
+ tmp_edge_embed = F.relu(tmp_edge_embed)
56
+ edge_embed.append(tmp_edge_embed)
57
+ edge_embed = torch.cat(edge_embed)
58
+ else:
59
+ edge_index = self.edge_index
60
+ edge_embed = torch.cat([x[edge_index[0]],
61
+ x[edge_index[1]]], axis=1)
62
+ for ix, layer in enumerate(self.layers):
63
+ edge_embed = layer(edge_embed)
64
+ if ix != len(self.layers) - 1:
65
+ edge_embed = self.bns[ix](edge_embed)
66
+ edge_embed = F.relu(edge_embed)
67
+
68
+ adj = edge_embed.reshape(self.nnodes, self.nnodes)
69
+
70
+ adj = (adj + adj.T)/2
71
+ adj = torch.sigmoid(adj)
72
+ adj = adj - torch.diag(torch.diag(adj, 0))
73
+ return adj
74
+
75
+ @torch.no_grad()
76
+ def inference(self, x):
77
+ # self.eval()
78
+ adj_syn = self.forward(x, inference=True)
79
+ return adj_syn
80
+
81
+ def reset_parameters(self):
82
+ def weight_reset(m):
83
+ if isinstance(m, nn.Linear):
84
+ m.reset_parameters()
85
+ if isinstance(m, nn.BatchNorm1d):
86
+ m.reset_parameters()
87
+ self.apply(weight_reset)
88
+
GCond/models/sgc.py ADDED
@@ -0,0 +1,290 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''one transformation with multiple propagation'''
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import torch
6
+ import torch.optim as optim
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.modules.module import Module
9
+ from deeprobust.graph import utils
10
+ from copy import deepcopy
11
+ from sklearn.metrics import f1_score
12
+ from torch.nn import init
13
+ import torch_sparse
14
+
15
class GraphConvolution(Module):
    """Simple GCN layer, similar to https://github.com/tkipf/pygcn

    Computes ``adj @ (input @ weight) + bias``.
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        # NOTE(review): `with_bias` is currently ignored and a bias is always
        # created; SGC.forward reads `self.conv.bias` unconditionally, so it
        # must not be None here.
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init in [-1/sqrt(in_features), 1/sqrt(in_features)];
        # weight.T.size(1) == in_features (pygcn uses out_features instead).
        bound = 1. / math.sqrt(self.weight.T.size(1))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """ Graph Convolutional Layer forward function
        """
        # Feature transform with the matmul matching the input layout.
        mm = torch.spmm if input.data.is_sparse else torch.mm
        support = mm(input, self.weight)
        # Neighbourhood aggregation.
        if isinstance(adj, torch_sparse.SparseTensor):
            aggregated = torch_sparse.matmul(adj, support)
        else:
            aggregated = torch.spmm(adj, support)
        if self.bias is None:
            return aggregated
        return aggregated + self.bias

    def __repr__(self):
        return '{} ({} -> {})'.format(self.__class__.__name__,
                                      self.in_features,
                                      self.out_features)
54
+
55
+
56
class SGC(nn.Module):
    """Simplified Graph Convolution: one linear transformation followed by
    `nlayers` propagation steps with the (normalized) adjacency."""

    def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
                with_relu=True, with_bias=True, with_bn=False, device=None):

        super(SGC, self).__init__()

        assert device is not None, "Please specify 'device'!"
        self.device = device
        self.nfeat = nfeat
        self.nclass = nclass

        # Single transformation layer; propagation happens in forward().
        self.conv = GraphConvolution(nfeat, nclass, with_bias=with_bias)

        self.nlayers = nlayers  # number of propagation steps
        self.dropout = dropout
        self.lr = lr
        # Convention carried over from the GCN wrapper: disable weight decay
        # when ReLU is disabled.
        if not with_relu:
            self.weight_decay = 0
        else:
            self.weight_decay = weight_decay
        self.with_relu = with_relu
        # SGC has no hidden layers, so batch norm is not supported.
        if with_bn:
            print('Warning: SGC does not have bn!!!')
        self.with_bn = False
        self.with_bias = with_bias
        self.output = None       # cached full-graph output of the best checkpoint
        self.best_model = None
        self.best_output = None
        self.adj_norm = None     # normalized training adjacency (set in fit_with_val)
        self.features = None     # training features (set in fit_with_val)
        self.multi_label = None  # set in fit_with_val from the label shape

    def forward(self, x, adj):
        """Transform once, then propagate `nlayers` times; bias added last."""
        weight = self.conv.weight
        bias = self.conv.bias
        x = torch.mm(x, weight)
        for i in range(self.nlayers):
            x = torch.spmm(adj, x)
        x = x + bias
        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler(self, x, adjs):
        """Forward pass over NeighborSampler output: one (adj, e_id, size)
        bipartite adjacency per propagation step."""
        weight = self.conv.weight
        bias = self.conv.bias
        x = torch.mm(x, weight)
        for ix, (adj, _, size) in enumerate(adjs):
            x = torch_sparse.matmul(adj, x)
        x = x + bias
        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def forward_sampler_syn(self, x, adjs):
        """Forward pass over a plain list of adjacencies (synthetic graph),
        one per propagation step; supports dense or SparseTensor adjs."""
        weight = self.conv.weight
        bias = self.conv.bias
        x = torch.mm(x, weight)
        for ix, (adj) in enumerate(adjs):
            if type(adj) == torch.Tensor:
                x = adj @ x
            else:
                x = torch_sparse.matmul(adj, x)
        x = x + bias
        if self.multi_label:
            return torch.sigmoid(x)
        else:
            return F.log_softmax(x, dim=1)

    def initialize(self):
        """Reset parameters of the transformation layer.
        """
        self.conv.reset_parameters()
        if self.with_bn:
            for bn in self.bns:
                bn.reset_parameters()

    def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
        '''data: full data class.

        Trains on (features, adj, labels) and keeps the checkpoint with the
        best validation accuracy on `data`'s validation split.
        `noval=True` validates on data.adj_val instead of data.idx_val.
        `patience` is currently unused.
        '''
        if initialize:
            self.initialize()

        # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
        if type(adj) is not torch.Tensor:
            features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
        else:
            features = features.to(self.device)
            adj = adj.to(self.device)
            labels = labels.to(self.device)

        if normalize:
            if utils.is_sparse_tensor(adj):
                adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                adj_norm = utils.normalize_adj_tensor(adj)
        else:
            adj_norm = adj

        if 'feat_norm' in kwargs and kwargs['feat_norm']:
            from utils import row_normalize_tensor
            features = row_normalize_tensor(features-features.min())

        self.adj_norm = adj_norm
        self.features = features

        # Multi-label targets use sigmoid + BCE; single-label uses NLL.
        if len(labels.shape) > 1:
            self.multi_label = True
            self.loss = torch.nn.BCELoss()
        else:
            self.multi_label = False
            self.loss = F.nll_loss

        labels = labels.float() if self.multi_label else labels
        self.labels = labels

        if noval:
            self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
        else:
            self._train_with_val(labels, data, train_iters, verbose)

    def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
        # Pick the graph used for validation scoring.
        if adj_val:
            feat_full, adj_full = data.feat_val, data.adj_val
        else:
            feat_full, adj_full = data.feat_full, data.adj_full

        feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
        adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
        labels_val = torch.LongTensor(data.labels_val).to(self.device)

        if verbose:
            print('=== training gcn model ===')
        optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)

        best_acc_val = 0

        for i in range(train_iters):
            # One-time learning-rate decay halfway through training.
            if i == train_iters // 2:
                lr = self.lr*0.1
                optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)

            self.train()
            optimizer.zero_grad()
            output = self.forward(self.features, self.adj_norm)
            loss_train = self.loss(output, labels)
            loss_train.backward()
            optimizer.step()

            if verbose and i % 100 == 0:
                print('Epoch {}, training loss: {}'.format(i, loss_train.item()))

            # Track the checkpoint with the best validation accuracy.
            with torch.no_grad():
                self.eval()
                output = self.forward(feat_full, adj_full_norm)
                if adj_val:
                    loss_val = F.nll_loss(output, labels_val)
                    acc_val = utils.accuracy(output, labels_val)
                else:
                    loss_val = F.nll_loss(output[data.idx_val], labels_val)
                    acc_val = utils.accuracy(output[data.idx_val], labels_val)

                if acc_val > best_acc_val:
                    best_acc_val = acc_val
                    self.output = output
                    weights = deepcopy(self.state_dict())

        if verbose:
            print('=== picking the best model according to the performance on validation ===')
        self.load_state_dict(weights)


    def test(self, idx_test):
        """Evaluate GCN performance on test set.
        Parameters
        ----------
        idx_test :
            node testing indices
        """
        self.eval()
        output = self.predict()
        # output = self.output
        loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
        acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
        return acc_test.item()


    @torch.no_grad()
    def predict(self, features=None, adj=None):
        """By default, the inputs should be unnormalized adjacency
        Parameters
        ----------
        features :
            node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        adj :
            adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
        Returns
        -------
        torch.FloatTensor
            output (log probabilities) of GCN
        """

        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            # Normalize the provided adjacency before prediction.
            if utils.is_sparse_tensor(adj):
                self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
            else:
                self.adj_norm = utils.normalize_adj_tensor(adj)
            return self.forward(self.features, self.adj_norm)

    @torch.no_grad()
    def predict_unnorm(self, features=None, adj=None):
        """Like `predict`, but uses the given adjacency as-is (no normalization)."""
        self.eval()
        if features is None and adj is None:
            return self.forward(self.features, self.adj_norm)
        else:
            if type(adj) is not torch.Tensor:
                features, adj = utils.to_tensor(features, adj, device=self.device)

            self.features = features
            self.adj_norm = adj
            return self.forward(self.features, self.adj_norm)
289
+
290
+
GCond/models/sgc_multi.py ADDED
@@ -0,0 +1,315 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """multiple transformaiton and multiple propagation"""
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ import torch
6
+ import torch.optim as optim
7
+ from torch.nn.parameter import Parameter
8
+ from torch.nn.modules.module import Module
9
+ from deeprobust.graph import utils
10
+ from copy import deepcopy
11
+ from sklearn.metrics import f1_score
12
+ from torch.nn import init
13
+ import torch_sparse
14
+
15
+
16
+ class SGC(nn.Module):
17
+
18
+ def __init__(self, nfeat, nhid, nclass, nlayers=2, dropout=0.5, lr=0.01, weight_decay=5e-4,
19
+ ntrans=2, with_bias=True, with_bn=False, device=None):
20
+
21
+ """nlayers indicates the number of propagations"""
22
+ super(SGC, self).__init__()
23
+
24
+ assert device is not None, "Please specify 'device'!"
25
+ self.device = device
26
+ self.nfeat = nfeat
27
+ self.nclass = nclass
28
+
29
+
30
+ self.layers = nn.ModuleList([])
31
+ if ntrans == 1:
32
+ self.layers.append(MyLinear(nfeat, nclass))
33
+ else:
34
+ self.layers.append(MyLinear(nfeat, nhid))
35
+ if with_bn:
36
+ self.bns = torch.nn.ModuleList()
37
+ self.bns.append(nn.BatchNorm1d(nhid))
38
+ for i in range(ntrans-2):
39
+ if with_bn:
40
+ self.bns.append(nn.BatchNorm1d(nhid))
41
+ self.layers.append(MyLinear(nhid, nhid))
42
+ self.layers.append(MyLinear(nhid, nclass))
43
+
44
+ self.nlayers = nlayers
45
+ self.dropout = dropout
46
+ self.lr = lr
47
+ self.with_bn = with_bn
48
+ self.with_bias = with_bias
49
+ self.weight_decay = weight_decay
50
+ self.output = None
51
+ self.best_model = None
52
+ self.best_output = None
53
+ self.adj_norm = None
54
+ self.features = None
55
+ self.multi_label = None
56
+
57
+ def forward(self, x, adj):
58
+ for ix, layer in enumerate(self.layers):
59
+ x = layer(x)
60
+ if ix != len(self.layers) - 1:
61
+ x = self.bns[ix](x) if self.with_bn else x
62
+ x = F.relu(x)
63
+ x = F.dropout(x, self.dropout, training=self.training)
64
+
65
+ for i in range(self.nlayers):
66
+ x = torch.spmm(adj, x)
67
+
68
+ if self.multi_label:
69
+ return torch.sigmoid(x)
70
+ else:
71
+ return F.log_softmax(x, dim=1)
72
+
73
+ def forward_sampler(self, x, adjs):
74
+ for ix, layer in enumerate(self.layers):
75
+ x = layer(x)
76
+ if ix != len(self.layers) - 1:
77
+ x = self.bns[ix](x) if self.with_bn else x
78
+ x = F.relu(x)
79
+ x = F.dropout(x, self.dropout, training=self.training)
80
+
81
+ for ix, (adj, _, size) in enumerate(adjs):
82
+ # x_target = x[: size[1]]
83
+ # x = self.layers[ix]((x, x_target), edge_index)
84
+ # adj = adj.to(self.device)
85
+ x = torch_sparse.matmul(adj, x)
86
+
87
+ if self.multi_label:
88
+ return torch.sigmoid(x)
89
+ else:
90
+ return F.log_softmax(x, dim=1)
91
+
92
+ def forward_sampler_syn(self, x, adjs):
93
+ for ix, layer in enumerate(self.layers):
94
+ x = layer(x)
95
+ if ix != len(self.layers) - 1:
96
+ x = self.bns[ix](x) if self.with_bn else x
97
+ x = F.relu(x)
98
+ x = F.dropout(x, self.dropout, training=self.training)
99
+
100
+ for ix, (adj) in enumerate(adjs):
101
+ if type(adj) == torch.Tensor:
102
+ x = adj @ x
103
+ else:
104
+ x = torch_sparse.matmul(adj, x)
105
+
106
+ if self.multi_label:
107
+ return torch.sigmoid(x)
108
+ else:
109
+ return F.log_softmax(x, dim=1)
110
+
111
+
112
+ def initialize(self):
113
+ """Initialize parameters of GCN.
114
+ """
115
+ for layer in self.layers:
116
+ layer.reset_parameters()
117
+ if self.with_bn:
118
+ for bn in self.bns:
119
+ bn.reset_parameters()
120
+
121
+ def fit_with_val(self, features, adj, labels, data, train_iters=200, initialize=True, verbose=False, normalize=True, patience=None, noval=False, **kwargs):
122
+ '''data: full data class'''
123
+ if initialize:
124
+ self.initialize()
125
+
126
+ # features, adj, labels = data.feat_train, data.adj_train, data.labels_train
127
+ if type(adj) is not torch.Tensor:
128
+ features, adj, labels = utils.to_tensor(features, adj, labels, device=self.device)
129
+ else:
130
+ features = features.to(self.device)
131
+ adj = adj.to(self.device)
132
+ labels = labels.to(self.device)
133
+
134
+ if normalize:
135
+ if utils.is_sparse_tensor(adj):
136
+ adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
137
+ else:
138
+ adj_norm = utils.normalize_adj_tensor(adj)
139
+ else:
140
+ adj_norm = adj
141
+
142
+ if 'feat_norm' in kwargs and kwargs['feat_norm']:
143
+ from utils import row_normalize_tensor
144
+ features = row_normalize_tensor(features-features.min())
145
+
146
+ self.adj_norm = adj_norm
147
+ self.features = features
148
+
149
+ if len(labels.shape) > 1:
150
+ self.multi_label = True
151
+ self.loss = torch.nn.BCELoss()
152
+ else:
153
+ self.multi_label = False
154
+ self.loss = F.nll_loss
155
+
156
+ labels = labels.float() if self.multi_label else labels
157
+ self.labels = labels
158
+
159
+ if noval:
160
+ self._train_with_val(labels, data, train_iters, verbose, adj_val=True)
161
+ else:
162
+ self._train_with_val(labels, data, train_iters, verbose)
163
+
164
+ def _train_with_val(self, labels, data, train_iters, verbose, adj_val=False):
165
+ if adj_val:
166
+ feat_full, adj_full = data.feat_val, data.adj_val
167
+ else:
168
+ feat_full, adj_full = data.feat_full, data.adj_full
169
+
170
+ feat_full, adj_full = utils.to_tensor(feat_full, adj_full, device=self.device)
171
+ adj_full_norm = utils.normalize_adj_tensor(adj_full, sparse=True)
172
+ labels_val = torch.LongTensor(data.labels_val).to(self.device)
173
+
174
+ if verbose:
175
+ print('=== training gcn model ===')
176
+ optimizer = optim.Adam(self.parameters(), lr=self.lr, weight_decay=self.weight_decay)
177
+
178
+ best_acc_val = 0
179
+
180
+ for i in range(train_iters):
181
+ if i == train_iters // 2:
182
+ lr = self.lr*0.1
183
+ optimizer = optim.Adam(self.parameters(), lr=lr, weight_decay=self.weight_decay)
184
+
185
+ self.train()
186
+ optimizer.zero_grad()
187
+ output = self.forward(self.features, self.adj_norm)
188
+ loss_train = self.loss(output, labels)
189
+ loss_train.backward()
190
+ optimizer.step()
191
+
192
+ if verbose and i % 100 == 0:
193
+ print('Epoch {}, training loss: {}'.format(i, loss_train.item()))
194
+
195
+ with torch.no_grad():
196
+ self.eval()
197
+ output = self.forward(feat_full, adj_full_norm)
198
+ if adj_val:
199
+ loss_val = F.nll_loss(output, labels_val)
200
+ acc_val = utils.accuracy(output, labels_val)
201
+ else:
202
+ loss_val = F.nll_loss(output[data.idx_val], labels_val)
203
+ acc_val = utils.accuracy(output[data.idx_val], labels_val)
204
+
205
+ if acc_val > best_acc_val:
206
+ best_acc_val = acc_val
207
+ self.output = output
208
+ weights = deepcopy(self.state_dict())
209
+
210
+ if verbose:
211
+ print('=== picking the best model according to the performance on validation ===')
212
+ self.load_state_dict(weights)
213
+
214
+
215
+ def test(self, idx_test):
216
+ """Evaluate GCN performance on test set.
217
+ Parameters
218
+ ----------
219
+ idx_test :
220
+ node testing indices
221
+ """
222
+ self.eval()
223
+ output = self.predict()
224
+ # output = self.output
225
+ loss_test = F.nll_loss(output[idx_test], self.labels[idx_test])
226
+ acc_test = utils.accuracy(output[idx_test], self.labels[idx_test])
227
+ print("Test set results:",
228
+ "loss= {:.4f}".format(loss_test.item()),
229
+ "accuracy= {:.4f}".format(acc_test.item()))
230
+ return acc_test.item()
231
+
232
+
233
+ @torch.no_grad()
234
+ def predict(self, features=None, adj=None):
235
+ """By default, the inputs should be unnormalized adjacency
236
+ Parameters
237
+ ----------
238
+ features :
239
+ node features. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
240
+ adj :
241
+ adjcency matrix. If `features` and `adj` are not given, this function will use previous stored `features` and `adj` from training to make predictions.
242
+ Returns
243
+ -------
244
+ torch.FloatTensor
245
+ output (log probabilities) of GCN
246
+ """
247
+
248
+ self.eval()
249
+ if features is None and adj is None:
250
+ return self.forward(self.features, self.adj_norm)
251
+ else:
252
+ if type(adj) is not torch.Tensor:
253
+ features, adj = utils.to_tensor(features, adj, device=self.device)
254
+
255
+ self.features = features
256
+ if utils.is_sparse_tensor(adj):
257
+ self.adj_norm = utils.normalize_adj_tensor(adj, sparse=True)
258
+ else:
259
+ self.adj_norm = utils.normalize_adj_tensor(adj)
260
+ return self.forward(self.features, self.adj_norm)
261
+
262
+ @torch.no_grad()
263
+ def predict_unnorm(self, features=None, adj=None):
264
+ self.eval()
265
+ if features is None and adj is None:
266
+ return self.forward(self.features, self.adj_norm)
267
+ else:
268
+ if type(adj) is not torch.Tensor:
269
+ features, adj = utils.to_tensor(features, adj, device=self.device)
270
+
271
+ self.features = features
272
+ self.adj_norm = adj
273
+ return self.forward(self.features, self.adj_norm)
274
+
275
+
276
+
277
class MyLinear(Module):
    """Simple Linear layer, modified from https://github.com/tkipf/pygcn
    """

    def __init__(self, in_features, out_features, with_bias=True):
        super(MyLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if with_bias:
            self.bias = Parameter(torch.FloatTensor(out_features))
        else:
            # Register a None parameter so `self.bias` resolves cleanly.
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        # Uniform init scaled by fan-in: weight.size(0) == in_features,
        # the same value as 1/sqrt(weight.T.size(1)) in the original.
        bound = 1. / math.sqrt(self.weight.size(0))
        self.weight.data.uniform_(-bound, bound)
        if self.bias is not None:
            self.bias.data.uniform_(-bound, bound)

    def forward(self, input):
        # spmm for sparse feature matrices, plain mm otherwise.
        matmul = torch.spmm if input.data.is_sparse else torch.mm
        output = matmul(input, self.weight)
        return output if self.bias is None else output + self.bias

    def __repr__(self):
        return '{} ({} -> {})'.format(
            self.__class__.__name__, self.in_features, self.out_features)
315
+
GCond/requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ torch
2
+ torch_geometric
3
+ scipy
4
+ numpy
5
+ ogb
6
+ tqdm
7
+ torch_sparse
8
+ torchvision
9
+ configs
10
+ deeprobust
11
+ scikit_learn
GCond/res/cross/empty ADDED
@@ -0,0 +1 @@
 
 
1
+
GCond/script.sh ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ python train_cond_tranduct_sampler.py --dataset cora --mlp=0 --nlayers=2 --sgc=1 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=0.5 --seed=1000
2
+
3
+ python train_cond_tranduct_sampler.py --dataset ogbn-arxiv --mlp=0 --nlayers=2 --sgc=1 --lr_feat=0.01 --gpu_id=3 --lr_adj=0.01 --r=0.02 --seed=0 --inner=3 --epochs=1000 --save=0
4
+
GCond/scripts/run_cross.sh ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ for r in 0.001 0.005 0.01
2
+ do
3
+ bash scripts/script_cross.sh flickr ${r} 0
4
+ bash scripts/script_cross.sh ogbn-arxiv ${r} 0
5
+ done
6
+
7
+ for r in 0.25 0.5 1
8
+ do
9
+ bash scripts/script_cross.sh citeseer ${r} 0
10
+ bash scripts/script_cross.sh cora ${r} 0
11
+ done
12
+
13
+ for r in 0.001 0.0005 0.002
14
+ do
15
+ bash scripts/script_cross.sh reddit ${r} 0
16
+ done
GCond/scripts/run_main.sh ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ for r in 0.25 0.5 1
2
+ do
3
+ python train_gcond_transduct.py --dataset cora --nlayers=2 --sgc=1 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=${r} --seed=1 --epoch=600 --save=0
4
+ done
5
+
6
+ for r in 0.25 0.5 1
7
+ do
8
+ python train_gcond_transduct.py --dataset citeseer --nlayers=2 --sgc=1 --lr_feat=1e-4 --gpu_id=0 --lr_adj=1e-4 --r=${r} --seed=1 --epoch=600 --save=0
9
+ done
10
+
11
+
12
+ for r in 0.001 0.005 0.01
13
+ do
14
+ python train_gcond_transduct.py --dataset ogbn-arxiv --nlayers=2 --sgc=1 --lr_feat=0.01 --gpu_id=3 --lr_adj=0.01 --r=${r} --seed=1 --inner=3 --epochs=1000 --save=0
15
+ done
16
+
17
+
18
+ for r in 0.001 0.005 0.01
19
+ do
20
+ python train_gcond_induct.py --dataset flickr --sgc=2 --nlayers=2 --lr_feat=0.005 --lr_adj=0.005 --r=${r} --seed=1 --gpu_id=0 --epochs=1000 --inner=1 --outer=10 --save=0
21
+ done
22
+
23
+ for r in 0.001 0.005 0.0005 0.002
24
+ do
25
+ python train_gcond_induct.py --dataset reddit --sgc=1 --nlayers=2 --lr_feat=0.1 --lr_adj=0.1 --r=${r} --seed=1 --gpu_id=0 --epochs=1000 --inner=1 --outer=10 --save=0
26
+ done
GCond/scripts/script_cross.sh ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ dataset=${1}
2
+ r=${2}
3
+ gpu_id=${3}
4
+ for s in 0 1 2 3 4
5
+ do
6
+ python test_other_arcs.py --dataset ${dataset} --gpu_id=${gpu_id} --r=${r} --seed=${s} --nruns=10 >> res/flickr/${1}_${2}.out
7
+ done
GCond/test_other_arcs.py ADDED
@@ -0,0 +1,55 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
from utils import *
import torch.nn.functional as F
from tester_other_arcs import Evaluator
from utils_graphsaint import DataGraphSAINT


# Evaluate a saved condensed graph across multiple GNN architectures
# (GCN, GraphSage, SGC, MLP, APPNP, Cheby, GAT) via tester_other_arcs.Evaluator.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--nlayers', type=int, default=2)
parser.add_argument('--hidden', type=int, default=256)
parser.add_argument('--keep_ratio', type=float, default=1)
parser.add_argument('--reduction_rate', type=float, default=1)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.0)
# NOTE(review): argparse type=bool parses any non-empty string (even "False")
# as True; only the default is reliable here.
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--seed', type=int, default=0, help='Random seed.')
parser.add_argument('--mlp', type=int, default=0)
parser.add_argument('--inner', type=int, default=0)
parser.add_argument('--epsilon', type=float, default=-1)
parser.add_argument('--nruns', type=int, default=20)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)

# random seed setting
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

# Edge-weight truncation threshold for the condensed adjacency; this
# unconditionally overrides any value passed on the command line.
if args.dataset in ['cora', 'citeseer']:
    args.epsilon = 0.05
else:
    args.epsilon = 0.01

print(args)

# flickr/reddit/ogbn-arxiv ship in GraphSAINT format; other datasets come
# from deeprobust and are wrapped into an inductive split.
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

agent = Evaluator(data, args, device='cuda')
agent.train()
GCond/tester_other_arcs.py ADDED
@@ -0,0 +1,258 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import numpy as np
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.optim as optim
6
+ from torch.nn import Parameter
7
+ import torch.nn.functional as F
8
+ from utils import match_loss, regularization, row_normalize_tensor
9
+ import deeprobust.graph.utils as utils
10
+ from copy import deepcopy
11
+ import numpy as np
12
+ from tqdm import tqdm
13
+ from models.gcn import GCN
14
+ from models.sgc import SGC
15
+ from models.sgc_multi import SGC as SGC1
16
+ from models.myappnp import APPNP
17
+ from models.myappnp1 import APPNP1
18
+ from models.mycheby import Cheby
19
+ from models.mygraphsage import GraphSage
20
+ from models.gat import GAT
21
+ import scipy.sparse as sp
22
+
23
+
24
class Evaluator:
    """Evaluate a saved condensed graph (tensors under saved_ours/) by
    training several GNN architectures on it and testing on the original
    graph."""

    def __init__(self, data, args, device='cuda', **kwargs):
        self.data = data
        self.args = args
        self.device = device
        # Synthetic graph size: reduction_rate * number of training nodes.
        n = int(data.feat_train.shape[0] * args.reduction_rate)
        d = data.feat_train.shape[1]
        self.nnodes_syn = n
        self.adj_param = nn.Parameter(torch.FloatTensor(n, n).to(device))
        self.feat_syn = nn.Parameter(torch.FloatTensor(n, d).to(device))
        self.labels_syn = torch.LongTensor(self.generate_labels_syn(data)).to(device)
        self.reset_parameters()
        print('adj_param:', self.adj_param.shape, 'feat_syn:', self.feat_syn.shape)

    def reset_parameters(self):
        # Random placeholders; the actual condensed graph is loaded from
        # disk in get_syn_data.
        self.adj_param.data.copy_(torch.randn(self.adj_param.size()))
        self.feat_syn.data.copy_(torch.randn(self.feat_syn.size()))

    def generate_labels_syn(self, data):
        """Build the synthetic label vector: per-class counts proportional
        to the training label distribution (at least 1 per class); the most
        frequent class absorbs the rounding remainder so the total is
        exactly int(n * reduction_rate)."""
        from collections import Counter
        counter = Counter(data.labels_train)
        num_class_dict = {}
        n = len(data.labels_train)

        # Process classes from rarest to most frequent.
        sorted_counter = sorted(counter.items(), key=lambda x: x[1])
        sum_ = 0
        labels_syn = []
        self.syn_class_indices = {}  # class -> [start, end) range in labels_syn
        for ix, (c, num) in enumerate(sorted_counter):
            if ix == len(sorted_counter) - 1:
                # Last (largest) class takes whatever budget remains.
                num_class_dict[c] = int(n * self.args.reduction_rate) - sum_
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]
            else:
                num_class_dict[c] = max(int(num * self.args.reduction_rate), 1)
                sum_ += num_class_dict[c]
                self.syn_class_indices[c] = [len(labels_syn), len(labels_syn) + num_class_dict[c]]
                labels_syn += [c] * num_class_dict[c]

        self.num_class_dict = num_class_dict
        return labels_syn

    def test_gat(self, nlayers, model_type, verbose=False):
        """Train a GAT on the condensed graph; return collected test (and
        train) accuracies as a list."""
        res = []
        args = self.args

        if args.dataset in ['cora', 'citeseer']:
            args.epsilon = 0.5  # Make the graph sparser as GAT does not work well on dense graph
        else:
            args.epsilon = 0.01

        print('======= testing %s' % model_type)
        data, device = self.data, self.device

        feat_syn, adj_syn, labels_syn = self.get_syn_data(model_type)
        # with_bn = True if self.args.dataset in ['ogbn-arxiv'] else False
        with_bn = False
        if model_type == 'GAT':
            model = GAT(nfeat=feat_syn.shape[1], nhid=16, heads=16, dropout=0.0,
                        weight_decay=0e-4, nlayers=self.args.nlayers, lr=0.001,
                        nclass=data.nclass, device=device, dataset=self.args.dataset).to(device)

        # reddit/flickr are inductive: validate directly on the val graph.
        noval = True if args.dataset in ['reddit', 'flickr'] else False
        model.fit(feat_syn, adj_syn, labels_syn, np.arange(len(feat_syn)), noval=noval, data=data,
                  train_iters=10000 if noval else 3000, normalize=True, verbose=verbose)

        model.eval()
        labels_test = torch.LongTensor(data.labels_test).cuda()

        if args.dataset in ['reddit', 'flickr']:
            output = model.predict(data.feat_test, data.adj_test)
            loss_test = F.nll_loss(output, labels_test)
            acc_test = utils.accuracy(output, labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        else:
            # Full graph
            output = model.predict(data.feat_full, data.adj_full)
            loss_test = F.nll_loss(output[data.idx_test], labels_test)
            acc_test = utils.accuracy(output[data.idx_test], labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        labels_train = torch.LongTensor(data.labels_train).cuda()
        output = model.predict(data.feat_train, data.adj_train)
        loss_train = F.nll_loss(output, labels_train)
        acc_train = utils.accuracy(output, labels_train)
        if verbose:
            print("Train set results:",
                  "loss= {:.4f}".format(loss_train.item()),
                  "accuracy= {:.4f}".format(acc_train.item()))
        res.append(acc_train.item())
        return res

    def get_syn_data(self, model_type=None):
        """Load the saved condensed (feat, adj) tensors for this
        dataset/reduction_rate/seed; zero the adjacency for MLP and truncate
        small edge weights when args.epsilon > 0."""
        data, device = self.data, self.device
        feat_syn, adj_param, labels_syn = self.feat_syn.detach(), \
                self.adj_param.detach(), self.labels_syn

        args = self.args
        adj_syn = torch.load(f'saved_ours/adj_{args.dataset}_{args.reduction_rate}_{args.seed}.pt', map_location='cuda')
        feat_syn = torch.load(f'saved_ours/feat_{args.dataset}_{args.reduction_rate}_{args.seed}.pt', map_location='cuda')

        if model_type == 'MLP':
            adj_syn = adj_syn.to(self.device)
            # MLP ignores graph structure: evaluate with an all-zero adjacency.
            adj_syn = adj_syn - adj_syn
        else:
            adj_syn = adj_syn.to(self.device)

        print('Sum:', adj_syn.sum(), adj_syn.sum()/(adj_syn.shape[0]**2))
        print('Sparsity:', adj_syn.nonzero().shape[0]/(adj_syn.shape[0]**2))

        if self.args.epsilon > 0:
            # Drop edges below the threshold to sparsify the learned graph.
            adj_syn[adj_syn < self.args.epsilon] = 0
            print('Sparsity after truncating:', adj_syn.nonzero().shape[0]/(adj_syn.shape[0]**2))
        feat_syn = feat_syn.to(self.device)

        # edge_index = adj_syn.nonzero().T
        # adj_syn = torch.sparse.FloatTensor(edge_index, adj_syn[edge_index[0], edge_index[1]], adj_syn.size())

        return feat_syn, adj_syn, labels_syn

    def test(self, nlayers, model_type, verbose=True):
        """Train `model_type` (resolved via eval(); 'MLP' is a GCN on a
        zeroed adjacency) on the condensed graph, evaluate on the original
        graph, and return [acc_test, acc_train]."""
        res = []

        args = self.args
        data, device = self.data, self.device

        feat_syn, adj_syn, labels_syn = self.get_syn_data(model_type)

        print('======= testing %s' % model_type)
        if model_type == 'MLP':
            model_class = GCN
        else:
            # Resolve the class from its imported name (GCN/GraphSage/...).
            model_class = eval(model_type)
        weight_decay = 5e-4
        dropout = 0.5 if args.dataset in ['reddit'] else 0

        model = model_class(nfeat=feat_syn.shape[1], nhid=args.hidden, dropout=dropout,
                    weight_decay=weight_decay, nlayers=nlayers,
                    nclass=data.nclass, device=device).to(device)

        # with_bn = True if self.args.dataset in ['ogbn-arxiv'] else False
        # ogbn-arxiv gets a second construction with dropout=0 and no BN.
        if args.dataset in ['ogbn-arxiv', 'arxiv']:
            model = model_class(nfeat=feat_syn.shape[1], nhid=args.hidden, dropout=0.,
                        weight_decay=weight_decay, nlayers=nlayers, with_bn=False,
                        nclass=data.nclass, device=device).to(device)

        noval = True if args.dataset in ['reddit', 'flickr'] else False
        model.fit_with_val(feat_syn, adj_syn, labels_syn, data,
                     train_iters=600, normalize=True, verbose=True, noval=noval)

        model.eval()
        labels_test = torch.LongTensor(data.labels_test).cuda()

        if model_type == 'MLP':
            # Identity adjacency, unnormalized: pure feature-based prediction.
            output = model.predict_unnorm(data.feat_test, sp.eye(len(data.feat_test)))
        else:
            output = model.predict(data.feat_test, data.adj_test)

        if args.dataset in ['reddit', 'flickr']:
            loss_test = F.nll_loss(output, labels_test)
            acc_test = utils.accuracy(output, labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        # if not args.dataset in ['reddit', 'flickr']:
        else:
            # Full graph (the test-subgraph prediction above is discarded here).
            output = model.predict(data.feat_full, data.adj_full)
            loss_test = F.nll_loss(output[data.idx_test], labels_test)
            acc_test = utils.accuracy(output[data.idx_test], labels_test)
            res.append(acc_test.item())
            if verbose:
                print("Test full set results:",
                      "loss= {:.4f}".format(loss_test.item()),
                      "accuracy= {:.4f}".format(acc_test.item()))

        labels_train = torch.LongTensor(data.labels_train).cuda()
        output = model.predict(data.feat_train, data.adj_train)
        loss_train = F.nll_loss(output, labels_train)
        acc_train = utils.accuracy(output, labels_train)
        if verbose:
            print("Train set results:",
                  "loss= {:.4f}".format(loss_train.item()),
                  "accuracy= {:.4f}".format(acc_train.item()))
        res.append(acc_train.item())
        return res

    def train(self, verbose=True):
        """Run args.nruns evaluations per architecture and print the
        mean/std of the accuracy vectors collected by test/test_gat."""
        args = self.args
        data = self.data

        final_res = {}
        runs = self.args.nruns

        for model_type in ['GCN', 'GraphSage', 'SGC1', 'MLP', 'APPNP1', 'Cheby']:
            res = []
            nlayer = 2
            for i in range(runs):
                res.append(self.test(nlayer, verbose=False, model_type=model_type))
            res = np.array(res)
            print('Test/Train Mean Accuracy:',
                    repr([res.mean(0), res.std(0)]))
            final_res[model_type] = [res.mean(0), res.std(0)]

        print('=== testing GAT')
        res = []
        nlayer = 2
        for i in range(runs):
            res.append(self.test_gat(verbose=True, nlayers=nlayer, model_type='GAT'))
        res = np.array(res)
        print('Layer:', nlayer)
        print('Test/Full Test/Train Mean Accuracy:',
                repr([res.mean(0), res.std(0)]))
        final_res['GAT'] = [res.mean(0), res.std(0)]

        print('Final result:', final_res)
258
+
GCond/train_coreset.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
import sys
from deeprobust.graph.utils import *
import torch.nn.functional as F
from configs import load_config
from utils import *
from utils_graphsaint import DataGraphSAINT
from models.gcn import GCN
from coreset import KCenter, Herding, Random, LRMC
from tqdm import tqdm

# Transductive coreset baseline: train a GCN on the full graph, select a
# coreset of nodes with the chosen method, then repeatedly retrain on the
# coreset and report mean/std test accuracy.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--hidden', type=int, default=256)
# NOTE(review): argparse type=bool treats any non-empty string as True;
# only the default is reliable here.
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--nlayers', type=int, default=2, help='Number of GCN layers.')
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--inductive', type=int, default=1)
parser.add_argument('--save', type=int, default=0)
parser.add_argument('--method', type=str, choices=['kcenter', 'herding', 'random', 'lrmc'])
parser.add_argument('--lrmc_seeds_path', type=str, default=None,
                    help='Path to a JSON file containing L‑RMC seed nodes. Required when method=lrmc.')
parser.add_argument('--reduction_rate', type=float, required=True)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)
args = load_config(args)
print(args)

# random seed setting
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)


data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

features = data_full.features
adj = data_full.adj
labels = data_full.labels
idx_train = data_full.idx_train
idx_val = data_full.idx_val
idx_test = data_full.idx_test

# Setup GCN Model
device = 'cuda'
model = GCN(nfeat=features.shape[1], nhid=256, nclass=labels.max()+1, device=device, weight_decay=args.weight_decay)

model = model.to(device)
model.fit(features, adj, labels, idx_train, idx_val, train_iters=600, verbose=False)

model.eval()
# You can use the inner function of model to test
model.test(idx_test)

# Node representations (model outputs) used by the selection methods.
embeds = model.predict().detach()

if args.method == 'kcenter':
    agent = KCenter(data, args, device='cuda')
elif args.method == 'herding':
    agent = Herding(data, args, device='cuda')
elif args.method == 'random':
    agent = Random(data, args, device='cuda')
elif args.method == 'lrmc':
    if args.lrmc_seeds_path is None:
        raise ValueError("--lrmc_seeds_path must be specified when method='lrmc'")
    agent = LRMC(data, args, device='cuda')
else:
    # FIX: previously a missing/unknown --method left `agent` unbound and
    # crashed later with a NameError at agent.select().
    raise ValueError(f"Unknown coreset method: {args.method!r}")

idx_selected = agent.select(embeds)


# Induced subgraph on the selected nodes.
feat_train = features[idx_selected]
adj_train = adj[np.ix_(idx_selected, idx_selected)]

labels_train = labels[idx_selected]

if args.save:
    np.save(f'saved/idx_{args.dataset}_{args.reduction_rate}_{args.method}_{args.seed}.npy', idx_selected)


# Retrain on the selected coreset several times and average test accuracy.
res = []
runs = 10
for _ in tqdm(range(runs)):
    model.initialize()
    model.fit_with_val(feat_train, adj_train, labels_train, data,
                       train_iters=600, normalize=True, verbose=False)

    model.eval()
    labels_test = torch.LongTensor(data.labels_test).cuda()

    # Full graph
    output = model.predict(data.feat_full, data.adj_full)
    loss_test = F.nll_loss(output[data.idx_test], labels_test)
    acc_test = accuracy(output[data.idx_test], labels_test)
    res.append(acc_test.item())

res = np.array(res)
print('Mean accuracy:', repr([res.mean(), res.std()]))
+
GCond/train_coreset_induct.py ADDED
@@ -0,0 +1,119 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
import sys
import deeprobust.graph.utils as utils
import torch.nn.functional as F
from configs import load_config
from utils import *
from utils_graphsaint import DataGraphSAINT
from models.gcn import GCN
from coreset import KCenter, Herding, Random, LRMC
from tqdm import tqdm

# Inductive coreset baseline: train a GCN on the training subgraph, select
# a coreset with the chosen method, retrain on it and report test accuracy.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--hidden', type=int, default=256)
# NOTE(review): argparse type=bool treats any non-empty string as True;
# only the default is reliable here.
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--lr', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=5e-4)
parser.add_argument('--dropout', type=float, default=0.5)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--nlayers', type=int, default=2, help='Number of GCN layers.')
parser.add_argument('--epochs', type=int, default=400)
parser.add_argument('--inductive', type=int, default=1)
parser.add_argument('--mlp', type=int, default=0)
parser.add_argument('--method', type=str, choices=['kcenter', 'herding', 'random', 'lrmc'])
parser.add_argument('--lrmc_seeds_path', type=str, default=None,
                    help='Path to a JSON file containing L‑RMC seed nodes. Required when method=lrmc.')
parser.add_argument('--reduction_rate', type=float, required=True)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)
args = load_config(args)
print(args)

# random seed setting
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

feat_train = data.feat_train
adj_train = data.adj_train
labels_train = data.labels_train


# Setup GCN Model
device = 'cuda'
model = GCN(nfeat=feat_train.shape[1], nhid=256, nclass=labels_train.max()+1, device=device, weight_decay=args.weight_decay)

model = model.to(device)

model.fit_with_val(feat_train, adj_train, labels_train, data,
             train_iters=600, normalize=True, verbose=False)

model.eval()
labels_test = torch.LongTensor(data.labels_test).cuda()
feat_test, adj_test = data.feat_test, data.adj_test

# Representations of the training nodes, used by the selection methods.
embeds = model.predict().detach()

# Baseline accuracy of the full-training-set model.
output = model.predict(feat_test, adj_test)
loss_test = F.nll_loss(output, labels_test)
acc_test = utils.accuracy(output, labels_test)
print("Test set results:",
      "loss= {:.4f}".format(loss_test.item()),
      "accuracy= {:.4f}".format(acc_test.item()))


if args.method == 'kcenter':
    agent = KCenter(data, args, device='cuda')
elif args.method == 'herding':
    agent = Herding(data, args, device='cuda')
elif args.method == 'random':
    agent = Random(data, args, device='cuda')
elif args.method == 'lrmc':
    if args.lrmc_seeds_path is None:
        raise ValueError("--lrmc_seeds_path must be specified when method='lrmc'")
    agent = LRMC(data, args, device='cuda')
else:
    # FIX: previously a missing/unknown --method left `agent` unbound and
    # crashed later with a NameError at agent.select().
    raise ValueError(f"Unknown coreset method: {args.method!r}")

idx_selected = agent.select(embeds, inductive=True)

# Induced training subgraph on the selected nodes.
feat_train = feat_train[idx_selected]
adj_train = adj_train[np.ix_(idx_selected, idx_selected)]

labels_train = labels_train[idx_selected]

res = []
print('shape of feat_train:', feat_train.shape)
runs = 10
for _ in tqdm(range(runs)):
    model.initialize()
    model.fit_with_val(feat_train, adj_train, labels_train, data,
                 train_iters=600, normalize=True, verbose=False, noval=True)

    model.eval()
    labels_test = torch.LongTensor(data.labels_test).cuda()

    output = model.predict(feat_test, adj_test)
    loss_test = F.nll_loss(output, labels_test)
    acc_test = utils.accuracy(output, labels_test)
    res.append(acc_test.item())
res = np.array(res)
print('Mean accuracy:', repr([res.mean(), res.std()]))
+
GCond/train_gcond_induct.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
from utils import *
import torch.nn.functional as F
from gcond_agent_induct import GCond
from utils_graphsaint import DataGraphSAINT


# Entry point for inductive graph condensation (GCond) on GraphSAINT-format
# datasets (flickr/reddit/ogbn-arxiv) or deeprobust datasets.
parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--dis_metric', type=str, default='ours')
parser.add_argument('--epochs', type=int, default=600)
parser.add_argument('--nlayers', type=int, default=3)
parser.add_argument('--hidden', type=int, default=256)
parser.add_argument('--lr_adj', type=float, default=0.01)
parser.add_argument('--lr_feat', type=float, default=0.01)
parser.add_argument('--lr_model', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.0)
# NOTE(review): argparse type=bool treats any non-empty string as True;
# only the default is reliable here.
parser.add_argument('--normalize_features', type=bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--reduction_rate', type=float, default=0.01)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--alpha', type=float, default=0, help='regularization term.')
parser.add_argument('--debug', type=int, default=0)
parser.add_argument('--sgc', type=int, default=1)
parser.add_argument('--inner', type=int, default=0)
parser.add_argument('--outer', type=int, default=20)
parser.add_argument('--option', type=int, default=0)
parser.add_argument('--save', type=int, default=0)
parser.add_argument('--label_rate', type=float, default=1)
parser.add_argument('--one_step', type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)

# random seed setting
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

print(args)

data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    # data = DataGraphSAINT(args.dataset)
    # label_rate is only used for the GraphSAINT-format datasets.
    data = DataGraphSAINT(args.dataset, label_rate=args.label_rate)
    data_full = data.data_full
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

agent = GCond(data, args, device='cuda')

agent.train()
GCond/train_gcond_transduct.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Entry point: train GCond (graph condensation) in the transductive setting."""
from deeprobust.graph.data import Dataset
import numpy as np
import random
import time
import argparse
import torch
from utils import *
import torch.nn.functional as F
from gcond_agent_transduct import GCond
from utils_graphsaint import DataGraphSAINT


def str2bool(value):
    """Parse a boolean command-line value.

    FIX: the original used ``type=bool`` for --normalize_features, but
    ``bool('False')`` is True — any non-empty string on the command line was
    treated as True. Defaults (real bools) are passed through unchanged.
    """
    if isinstance(value, bool):
        return value
    return str(value).lower() in ('yes', 'true', 't', '1')


parser = argparse.ArgumentParser()
parser.add_argument('--gpu_id', type=int, default=0, help='gpu id')
parser.add_argument('--dataset', type=str, default='cora')
parser.add_argument('--dis_metric', type=str, default='ours')  # gradient-matching distance
parser.add_argument('--epochs', type=int, default=2000)
parser.add_argument('--nlayers', type=int, default=3)
parser.add_argument('--hidden', type=int, default=256)
parser.add_argument('--lr_adj', type=float, default=0.01)
parser.add_argument('--lr_feat', type=float, default=0.01)
parser.add_argument('--lr_model', type=float, default=0.01)
parser.add_argument('--weight_decay', type=float, default=0.0)
parser.add_argument('--dropout', type=float, default=0.0)
parser.add_argument('--normalize_features', type=str2bool, default=True)
parser.add_argument('--keep_ratio', type=float, default=1.0)
parser.add_argument('--reduction_rate', type=float, default=1)
parser.add_argument('--seed', type=int, default=15, help='Random seed.')
parser.add_argument('--alpha', type=float, default=0, help='regularization term.')
parser.add_argument('--debug', type=int, default=0)
parser.add_argument('--sgc', type=int, default=1)
parser.add_argument('--inner', type=int, default=0)
parser.add_argument('--outer', type=int, default=20)
parser.add_argument('--save', type=int, default=0)
parser.add_argument('--one_step', type=int, default=0)
args = parser.parse_args()

torch.cuda.set_device(args.gpu_id)

# random seed setting (python / numpy / torch CPU + current GPU)
random.seed(args.seed)
np.random.seed(args.seed)
torch.manual_seed(args.seed)
torch.cuda.manual_seed(args.seed)

print(args)

# GraphSAINT-style datasets ship their own role-based splits; everything
# else is loaded transductively and adapted with Transd2Ind.
data_graphsaint = ['flickr', 'reddit', 'ogbn-arxiv']
if args.dataset in data_graphsaint:
    data = DataGraphSAINT(args.dataset)
    data_full = data.data_full
else:
    data_full = get_dataset(args.dataset, args.normalize_features)
    data = Transd2Ind(data_full, keep_ratio=args.keep_ratio)

agent = GCond(data, args, device='cuda')

agent.train()
GCond/utils.py ADDED
@@ -0,0 +1,383 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os.path as osp
2
+ import numpy as np
3
+ import scipy.sparse as sp
4
+ import torch
5
+ import torch_geometric.transforms as T
6
+ from ogb.nodeproppred import PygNodePropPredDataset, Evaluator
7
+ from deeprobust.graph.data import Dataset
8
+ from deeprobust.graph.utils import get_train_val_test
9
+ from torch_geometric.utils import train_test_split_edges
10
+ from sklearn.model_selection import train_test_split
11
+ from sklearn import metrics
12
+ import numpy as np
13
+ import torch.nn.functional as F
14
+ from sklearn.preprocessing import StandardScaler
15
+ from deeprobust.graph.utils import *
16
+ from torch_geometric.data import NeighborSampler
17
+ from torch_geometric.utils import add_remaining_self_loops, to_undirected
18
+ from torch_geometric.datasets import Planetoid
19
+
20
+
21
def get_dataset(name, normalize_features=False, transform=None, if_dpr=True):
    """Load a benchmark graph dataset and convert it to DeepRobust format.

    Supports the Planetoid citation graphs ('cora', 'citeseer', 'pubmed')
    and 'ogbn-arxiv'; anything else raises NotImplementedError. For
    ogbn-arxiv the node features are standardized with statistics computed
    on the training split only (following GraphSAINT).
    """
    root = osp.join(osp.dirname(osp.realpath(__file__)), 'data', name)
    if name in ('cora', 'citeseer', 'pubmed'):
        dataset = Planetoid(root, name)
    elif name == 'ogbn-arxiv':
        dataset = PygNodePropPredDataset(name='ogbn-arxiv')
    else:
        raise NotImplementedError

    # Assemble the transform pipeline: normalization (if requested) runs
    # first, then any user-supplied transform.
    pipeline = []
    if normalize_features:
        pipeline.append(T.NormalizeFeatures())
    if transform is not None:
        pipeline.append(transform)
    if len(pipeline) == 1:
        dataset.transform = pipeline[0]
    elif len(pipeline) == 2:
        dataset.transform = T.Compose(pipeline)

    dpr_data = Pyg2Dpr(dataset)
    if name == 'ogbn-arxiv':
        # the features are different from the features provided by GraphSAINT;
        # normalize them the same way GraphSAINT does (train-set statistics)
        scaler = StandardScaler()
        scaler.fit(dpr_data.features[dpr_data.idx_train])
        dpr_data.features = scaler.transform(dpr_data.features)

    return dpr_data
49
+
50
+
51
class Pyg2Dpr(Dataset):
    """Convert a PyTorch Geometric dataset into a DeepRobust-style Dataset.

    Populates ``adj`` (scipy CSR), ``features``, ``labels`` and the
    ``idx_train`` / ``idx_val`` / ``idx_test`` splits. Splits come, in order
    of preference, from boolean masks on the Data object (Planetoid-style),
    from ``get_idx_split()`` (OGB-style), or from a random stratified split.

    FIX: the original used two bare ``except:`` clauses that silently
    swallowed every error (including NameError on the unbound ``splits``);
    the split selection is now explicit.
    """

    def __init__(self, pyg_data, **kwargs):
        # OGB datasets expose fixed splits via get_idx_split(); other
        # datasets (e.g. Planetoid) do not have this method.
        try:
            splits = pyg_data.get_idx_split()
        except Exception:
            splits = None

        dataset_name = pyg_data.name
        pyg_data = pyg_data[0]
        n = pyg_data.num_nodes

        if dataset_name == 'ogbn-arxiv':  # symmetrization
            pyg_data.edge_index = to_undirected(pyg_data.edge_index, pyg_data.num_nodes)

        self.adj = sp.csr_matrix((np.ones(pyg_data.edge_index.shape[1]),
            (pyg_data.edge_index[0], pyg_data.edge_index[1])), shape=(n, n))

        self.features = pyg_data.x.numpy()
        self.labels = pyg_data.y.numpy()

        if len(self.labels.shape) == 2 and self.labels.shape[1] == 1:
            self.labels = self.labels.reshape(-1)  # ogbn-arxiv labels come as (N, 1)

        if hasattr(pyg_data, 'train_mask'):
            # fixed split stored as boolean masks (Planetoid-style)
            self.idx_train = mask_to_index(pyg_data.train_mask, n)
            self.idx_val = mask_to_index(pyg_data.val_mask, n)
            self.idx_test = mask_to_index(pyg_data.test_mask, n)
        elif splits is not None:
            # OGB-style split dictionary
            self.idx_train = splits['train']
            self.idx_val = splits['valid']
            self.idx_test = splits['test']
        else:
            # no split information available: fall back to a random
            # stratified 10/10/80 split
            self.idx_train, self.idx_val, self.idx_test = get_train_val_test(
                nnodes=n, val_size=0.1, test_size=0.8, stratify=self.labels)
        # set unconditionally (the original skipped it on the random-split path)
        self.name = 'Pyg2Dpr'
91
+
92
+
93
def mask_to_index(index, size):
    """Convert a boolean mask over ``size`` elements into an index array."""
    return np.arange(size)[index]
96
+
97
def index_to_mask(index, size):
    """Inverse of mask_to_index: build a boolean mask of length ``size``
    that is True exactly at the positions listed in ``index``."""
    mask = torch.zeros(size, dtype=torch.bool)
    mask[index] = True
    return mask
101
+
102
+
103
+
104
class Transd2Ind:
    """Adapt a transductive (single full graph) dataset to the inductive setting.

    Builds per-split induced subgraphs (edges crossing the train/val/test
    boundary are discarded) together with per-split features and labels, and
    lazily caches class-conditional neighbor samplers for mini-batching.
    """
    # transductive setting to inductive setting

    def __init__(self, dpr_data, keep_ratio):
        # dpr_data: DeepRobust-style dataset (adj, features, labels + splits);
        # keep_ratio: fraction of training nodes to keep (< 1 triggers a
        # stratified subsample of the training set).
        idx_train, idx_val, idx_test = dpr_data.idx_train, dpr_data.idx_val, dpr_data.idx_test
        adj, features, labels = dpr_data.adj, dpr_data.features, dpr_data.labels
        self.nclass = labels.max()+1
        self.adj_full, self.feat_full, self.labels_full = adj, features, labels
        self.idx_train = np.array(idx_train)
        self.idx_val = np.array(idx_val)
        self.idx_test = np.array(idx_test)

        if keep_ratio < 1:
            # NOTE(review): self.idx_train above keeps the *full* training
            # indices while everything below uses the reduced subset —
            # confirm this asymmetry is intended.
            idx_train, _ = train_test_split(idx_train,
                                            random_state=None,
                                            train_size=keep_ratio,
                                            test_size=1-keep_ratio,
                                            stratify=labels[idx_train])

        # induced subgraphs per split (cross-split edges are dropped)
        self.adj_train = adj[np.ix_(idx_train, idx_train)]
        self.adj_val = adj[np.ix_(idx_val, idx_val)]
        self.adj_test = adj[np.ix_(idx_test, idx_test)]
        print('size of adj_train:', self.adj_train.shape)
        print('#edges in adj_train:', self.adj_train.sum())

        self.labels_train = labels[idx_train]
        self.labels_val = labels[idx_val]
        self.labels_test = labels[idx_test]

        self.feat_train = features[idx_train]
        self.feat_val = features[idx_val]
        self.feat_test = features[idx_test]

        # lazy caches for class-conditional sampling
        self.class_dict = None
        self.samplers = None
        self.class_dict2 = None

    def retrieve_class(self, c, num=256):
        """Return up to ``num`` random training-set positions of class ``c``."""
        if self.class_dict is None:
            # boolean membership mask per class, built on first use
            self.class_dict = {}
            for i in range(self.nclass):
                self.class_dict['class_%s'%i] = (self.labels_train == i)
        idx = np.arange(len(self.labels_train))
        idx = idx[self.class_dict['class_%s'%c]]
        return np.random.permutation(idx)[:num]

    def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None):
        """Sample a neighborhood mini-batch of up to ``num`` nodes of class ``c``
        using one lazily-built NeighborSampler per class."""
        if self.class_dict2 is None:
            self.class_dict2 = {}
            for i in range(self.nclass):
                if transductive:
                    # global node ids within the full graph
                    idx = self.idx_train[self.labels_train == i]
                else:
                    # positions within the training subgraph
                    idx = np.arange(len(self.labels_train))[self.labels_train==i]
                self.class_dict2[i] = idx

        # fan-out per GNN layer; NOTE(review): `sizes` is unbound if
        # args.nlayers > 5 — confirm callers never exceed 5 layers.
        if args.nlayers == 1:
            sizes = [15]
        if args.nlayers == 2:
            sizes = [10, 5]
            # sizes = [-1, -1]
        if args.nlayers == 3:
            sizes = [15, 10, 5]
        if args.nlayers == 4:
            sizes = [15, 10, 5, 5]
        if args.nlayers == 5:
            sizes = [15, 10, 5, 5, 5]

        if self.samplers is None:
            # one NeighborSampler per class, built on the first call (and
            # therefore frozen to the first call's `sizes`)
            self.samplers = []
            for i in range(self.nclass):
                node_idx = torch.LongTensor(self.class_dict2[i])
                self.samplers.append(NeighborSampler(adj,
                                     node_idx=node_idx,
                                     sizes=sizes, batch_size=num,
                                     num_workers=12, return_e_id=False,
                                     num_nodes=adj.size(0),
                                     shuffle=True))
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[c].sample(batch)
        return out

    def retrieve_class_multi_sampler(self, c, adj, transductive, num=256, args=None):
        """Like retrieve_class_sampler, but keeps one sampler set per depth
        bucket (1-layer fan-out [15] and 2-layer fan-out [10, 5]) and selects
        the bucket by args.nlayers."""
        if self.class_dict2 is None:
            self.class_dict2 = {}
            for i in range(self.nclass):
                if transductive:
                    idx = self.idx_train[self.labels_train == i]
                else:
                    idx = np.arange(len(self.labels_train))[self.labels_train==i]
                self.class_dict2[i] = idx

        if self.samplers is None:
            self.samplers = []
            for l in range(2):
                layer_samplers = []
                sizes = [15] if l == 0 else [10, 5]
                for i in range(self.nclass):
                    node_idx = torch.LongTensor(self.class_dict2[i])
                    layer_samplers.append(NeighborSampler(adj,
                                          node_idx=node_idx,
                                          sizes=sizes, batch_size=num,
                                          num_workers=12, return_e_id=False,
                                          num_nodes=adj.size(0),
                                          shuffle=True))
                self.samplers.append(layer_samplers)
        batch = np.random.permutation(self.class_dict2[c])[:num]
        # NOTE(review): indexes buckets by args.nlayers-1 but only two depth
        # buckets exist — assumes nlayers in {1, 2}; confirm.
        out = self.samplers[args.nlayers-1][c].sample(batch)
        return out
215
+
216
+
217
+
218
def match_loss(gw_syn, gw_real, args, device):
    """Gradient-matching distance between synthetic and real gradient lists.

    args.dis_metric selects the metric: 'ours' sums the per-layer
    distance_wb distances; 'mse' is the squared L2 distance of the flattened
    gradients; 'cos' is one minus their cosine similarity. Unknown metrics
    terminate the program.
    """
    dis = torch.tensor(0.0).to(device)

    if args.dis_metric == 'ours':
        # layer-wise distance, accumulated over all parameter tensors
        for layer in range(len(gw_real)):
            dis += distance_wb(gw_real[layer], gw_syn[layer])
    elif args.dis_metric == 'mse':
        real_flat = torch.cat([g.reshape((-1)) for g in gw_real], dim=0)
        syn_flat = torch.cat([g.reshape((-1)) for g in gw_syn], dim=0)
        dis = torch.sum((syn_flat - real_flat)**2)
    elif args.dis_metric == 'cos':
        real_flat = torch.cat([g.reshape((-1)) for g in gw_real], dim=0)
        syn_flat = torch.cat([g.reshape((-1)) for g in gw_syn], dim=0)
        dis = 1 - torch.sum(real_flat * syn_flat, dim=-1) / (
            torch.norm(real_flat, dim=-1) * torch.norm(syn_flat, dim=-1) + 0.000001)
    else:
        exit('DC error: unknown distance function')

    return dis
252
+
253
def distance_wb(gwr, gws):
    """Per-layer gradient distance: sum over units of (1 - cosine similarity).

    2-D (linear) gradients are compared column-wise (transposed first);
    3-D and 4-D gradients are flattened to (out, rest); 1-D gradients
    (bias / norm parameters) are skipped entirely and contribute 0.
    """
    shape = gwr.shape
    ndim = len(shape)

    if ndim == 1:
        # too few entries for a meaningful per-row cosine; ignore this layer
        return 0

    if ndim == 2:  # linear, out*in — compare input-dimension-wise
        gwr = gwr.T
        gws = gws.T
    elif ndim == 3:  # layernorm-style, C*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2])
        gws = gws.reshape(shape[0], shape[1] * shape[2])
    elif ndim == 4:  # conv, out*in*h*w
        gwr = gwr.reshape(shape[0], shape[1] * shape[2] * shape[3])
        gws = gws.reshape(shape[0], shape[1] * shape[2] * shape[3])

    cos_sim = torch.sum(gwr * gws, dim=-1) / (
        torch.norm(gwr, dim=-1) * torch.norm(gws, dim=-1) + 0.000001)
    return torch.sum(1 - cos_sim)
277
+
278
+
279
+
280
def calc_f1(y_true, y_pred, is_sigmoid):
    """Return (micro-F1, macro-F1).

    is_sigmoid=False: y_pred holds class scores and is argmax-ed per row.
    is_sigmoid=True: y_pred holds probabilities and is thresholded at 0.5
    — NOTE this thresholding mutates y_pred *in place*.
    """
    if is_sigmoid:
        y_pred[y_pred > 0.5] = 1
        y_pred[y_pred <= 0.5] = 0
    else:
        y_pred = np.argmax(y_pred, axis=1)
    micro = metrics.f1_score(y_true, y_pred, average="micro")
    macro = metrics.f1_score(y_true, y_pred, average="macro")
    return micro, macro
287
+
288
def evaluate(output, labels, args):
    """Print test-set metrics and return None.

    GraphSAINT-style datasets report micro/macro F1 (sigmoid thresholding
    when labels are multi-label, argmax otherwise); all other datasets
    report NLL loss and accuracy.
    """
    graphsaint_datasets = ['yelp', 'ppi', 'ppi-large', 'flickr', 'reddit', 'amazon']
    if args.dataset in graphsaint_datasets:
        labels = labels.cpu().numpy()
        output = output.cpu().numpy()
        # multi-label targets arrive as 2-D indicator matrices
        multilabel = len(labels.shape) > 1
        micro, macro = calc_f1(labels, output, is_sigmoid=multilabel)
        print("Test set results:", "F1-micro= {:.4f}".format(micro),
              "F1-macro= {:.4f}".format(macro))
    else:
        loss_test = F.nll_loss(output, labels)
        acc_test = accuracy(output, labels)
        print("Test set results:",
              "loss= {:.4f}".format(loss_test.item()),
              "accuracy= {:.4f}".format(acc_test.item()))
    return
306
+
307
+
308
+ from torchvision import datasets, transforms
309
def get_mnist(data_path):
    """Load MNIST and wrap it as an edgeless 'graph' dataset (PyG format).

    Each training image becomes a node whose feature is the flattened,
    normalized pixel vector; the same index set is reused for the
    train/val/test splits.
    """
    channel = 1          # NOTE(review): channel/im_size/dst_test/class_names
    im_size = (28, 28)   # are computed but unused here — confirm intended.
    num_classes = 10
    mean = [0.1307]      # standard MNIST normalization statistics
    std = [0.3081]
    transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=mean, std=std)])
    dst_train = datasets.MNIST(data_path, train=True, download=True, transform=transform) # no augmentation
    dst_test = datasets.MNIST(data_path, train=False, download=True, transform=transform)
    class_names = [str(c) for c in range(num_classes)]

    labels = []
    feat = []
    for x, y in dst_train:
        feat.append(x.view(1, -1))  # flatten each image into one row
        labels.append(y)
    feat = torch.cat(feat, axis=0).numpy()
    from utils_graphsaint import GraphData
    adj = sp.eye(len(feat))
    idx = np.arange(len(feat))
    # adj - adj yields an all-zero sparse matrix, i.e. a graph with no edges
    dpr_data = GraphData(adj-adj, feat, labels, idx, idx, idx)
    from deeprobust.graph.data import Dpr2Pyg
    return Dpr2Pyg(dpr_data)
332
+
333
def regularization(adj, x, eig_real=None):
    """Graph regularizer; currently just the feature-smoothness term f^T L f.

    ``eig_real`` is accepted for interface compatibility but unused.
    """
    loss = 0
    # an L1 sparsity term on adj was considered here but is disabled
    loss += feature_smoothing(adj, x)
    return loss
339
+
340
def maxdegree(adj):
    """Hinge penalty on the largest normalized degree: relu(max_deg/n - 0.5)."""
    n = adj.shape[0]
    largest_degree = max(adj.sum(1))
    return F.relu(largest_degree / n - 0.5)
343
+
344
def sparsity2(adj):
    """Sparsity penalty: Frobenius norm of adj scaled by 1/n.

    A log-degree term is computed for reference but weighted 0 (disabled).
    """
    n = adj.shape[0]
    degree_term = - torch.log(adj.sum(1)).sum() / n
    frobenius_term = torch.norm(adj) / n
    return 0 * degree_term + frobenius_term
349
+
350
def sparsity(adj):
    """Hinge penalty on total edge weight above a budget of 1% of n^2."""
    n = adj.shape[0]
    budget = n * n * 0.01
    return F.relu(adj.sum() - budget)
355
+
356
def feature_smoothing(adj, X):
    """Dirichlet energy tr(X^T L_sym X) of features X on the symmetrized graph.

    Uses the symmetrically normalized Laplacian D^{-1/2}(D - A)D^{-1/2}
    with a small epsilon on the degrees so zero-degree rows stay finite.
    """
    sym_adj = (adj.t() + adj) / 2
    degrees = sym_adj.sum(1).flatten()
    laplacian = torch.diag(degrees) - sym_adj

    # D^{-1/2} with epsilon; any residual infinities are zeroed out
    d_inv_sqrt = (degrees + 1e-8).pow(-1/2).flatten()
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0.
    d_mat = torch.diag(d_inv_sqrt)
    norm_laplacian = d_mat @ laplacian @ d_mat

    return torch.trace(X.t() @ norm_laplacian @ X)
374
+
375
def row_normalize_tensor(mx):
    """Scale each row of mx so it sums to 1.

    Rows that sum to zero produce inf entries (no guard, matching the
    original behavior).
    """
    inv_rowsum = mx.sum(1).pow(-1).flatten()
    return torch.diag(inv_rowsum) @ mx
382
+
383
+
GCond/utils_graphsaint.py ADDED
@@ -0,0 +1,220 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import scipy.sparse as sp
2
+ import numpy as np
3
+ import sys
4
+ import json
5
+ import os
6
+ from sklearn.preprocessing import StandardScaler
7
+ from torch_geometric.data import InMemoryDataset, Data
8
+ import torch
9
+ from itertools import repeat
10
+ from torch_geometric.data import NeighborSampler
11
+
12
class DataGraphSAINT:
    '''Datasets used in the GraphSAINT paper (e.g. flickr, reddit, ogbn-arxiv).

    Loads adjacency, features, labels and role-based splits from
    ``data/<name>/`` and exposes per-split induced subgraphs plus
    class-conditional neighbor samplers for mini-batching.
    '''

    def __init__(self, dataset, **kwargs):
        dataset_str = 'data/' + dataset + '/'
        adj_full = sp.load_npz(dataset_str + 'adj_full.npz')
        self.nnodes = adj_full.shape[0]
        if dataset == 'ogbn-arxiv':
            # symmetrize the directed citation graph and binarize weights
            adj_full = adj_full + adj_full.T
            adj_full[adj_full > 1] = 1

        role = json.load(open(dataset_str + 'role.json', 'r'))
        idx_train = role['tr']
        idx_test = role['te']
        idx_val = role['va']

        if 'label_rate' in kwargs:
            # optionally keep only a prefix of the training nodes
            label_rate = kwargs['label_rate']
            if label_rate < 1:
                idx_train = idx_train[:int(label_rate * len(idx_train))]

        # induced subgraphs per split (cross-split edges are dropped)
        self.adj_train = adj_full[np.ix_(idx_train, idx_train)]
        self.adj_val = adj_full[np.ix_(idx_val, idx_val)]
        self.adj_test = adj_full[np.ix_(idx_test, idx_test)]

        feat = np.load(dataset_str + 'feats.npy')
        # standardize features using training-set statistics only
        feat_train = feat[idx_train]
        scaler = StandardScaler()
        scaler.fit(feat_train)
        feat = scaler.transform(feat)

        self.feat_train = feat[idx_train]
        self.feat_val = feat[idx_val]
        self.feat_test = feat[idx_test]

        class_map = json.load(open(dataset_str + 'class_map.json', 'r'))
        labels = self.process_labels(class_map)

        self.labels_train = labels[idx_train]
        self.labels_val = labels[idx_val]
        self.labels_test = labels[idx_test]

        self.data_full = GraphData(adj_full, feat, labels, idx_train, idx_val, idx_test)
        # lazy caches for class-conditional sampling
        self.class_dict = None
        self.class_dict2 = None

        self.adj_full = adj_full
        self.feat_full = feat
        self.labels_full = labels
        self.idx_train = np.array(idx_train)
        self.idx_val = np.array(idx_val)
        self.idx_test = np.array(idx_test)
        self.samplers = None

    def process_labels(self, class_map):
        """Build the label array from a {node_id: label or one-hot list} map.

        Multi-label maps yield an (N, C) indicator matrix; single-label maps
        yield a 0-based integer vector. Sets ``self.nclass`` as a side effect.
        """
        num_vertices = self.nnodes
        if isinstance(list(class_map.values())[0], list):
            num_classes = len(list(class_map.values())[0])
            self.nclass = num_classes
            class_arr = np.zeros((num_vertices, num_classes))
            for k, v in class_map.items():
                class_arr[int(k)] = v
        else:
            # FIX: np.int was deprecated in NumPy 1.20 and removed in 1.24;
            # the builtin int is the documented replacement.
            class_arr = np.zeros(num_vertices, dtype=int)
            for k, v in class_map.items():
                class_arr[int(k)] = v
            class_arr = class_arr - class_arr.min()  # shift labels to start at 0
            self.nclass = max(class_arr) + 1
        return class_arr

    def retrieve_class(self, c, num=256):
        """Return up to ``num`` random training-set positions of class ``c``."""
        if self.class_dict is None:
            self.class_dict = {}
            for i in range(self.nclass):
                self.class_dict['class_%s' % i] = (self.labels_train == i)
        idx = np.arange(len(self.labels_train))
        idx = idx[self.class_dict['class_%s' % c]]
        return np.random.permutation(idx)[:num]

    def retrieve_class_sampler(self, c, adj, transductive, num=256, args=None):
        """Sample a neighborhood mini-batch of up to ``num`` nodes of class ``c``.

        NOTE(review): ``sizes`` is only assigned for args.nlayers in {1, 2}
        (unbound otherwise), and empty classes are skipped when building
        ``self.samplers``, which would misalign sampler index and class id —
        confirm neither case can occur for the supported datasets.
        """
        if args.nlayers == 1:
            sizes = [30]
        if args.nlayers == 2:
            if args.dataset in ['reddit', 'flickr']:
                # larger fan-outs selectable via args.option
                if args.option == 0:
                    sizes = [15, 8]
                if args.option == 1:
                    sizes = [20, 10]
                if args.option == 2:
                    sizes = [25, 10]
            else:
                sizes = [10, 5]

        if self.class_dict2 is None:
            print(sizes)
            self.class_dict2 = {}
            for i in range(self.nclass):
                if transductive:
                    # global node ids within the full graph
                    idx_train = np.array(self.idx_train)
                    idx = idx_train[self.labels_train == i]
                else:
                    # positions within the training subgraph
                    idx = np.arange(len(self.labels_train))[self.labels_train == i]
                self.class_dict2[i] = idx

        if self.samplers is None:
            self.samplers = []
            for i in range(self.nclass):
                node_idx = torch.LongTensor(self.class_dict2[i])
                if len(node_idx) == 0:
                    continue

                self.samplers.append(NeighborSampler(adj,
                                     node_idx=node_idx,
                                     sizes=sizes, batch_size=num,
                                     num_workers=8, return_e_id=False,
                                     num_nodes=adj.size(0),
                                     shuffle=True))
        batch = np.random.permutation(self.class_dict2[c])[:num]
        out = self.samplers[c].sample(batch)
        return out
+
137
+
138
class GraphData:
    """Plain container bundling a graph (adjacency, features, labels)
    with its train/val/test index splits."""

    def __init__(self, adj, features, labels, idx_train, idx_val, idx_test):
        # graph tensors and split indices, stored verbatim
        (self.adj, self.features, self.labels) = (adj, features, labels)
        (self.idx_train, self.idx_val, self.idx_test) = (idx_train, idx_val, idx_test)
+
148
+
149
class Data2Pyg:
    """Wrap a dataset's train/val/test DeepRobust graphs as PyG Data
    objects moved onto ``device``."""

    def __init__(self, data, device='cuda', transform=None, **kwargs):
        self.data_train = Dpr2Pyg(data.data_train, transform=transform)[0].to(device)
        self.data_val = Dpr2Pyg(data.data_val, transform=transform)[0].to(device)
        self.data_test = Dpr2Pyg(data.data_test, transform=transform)[0].to(device)
        self.nclass = data.nclass
        self.nfeat = data.nfeat
        self.class_dict = None  # lazy per-class membership cache

    def retrieve_class(self, c, num=256):
        """Return up to ``num`` random training nodes of class ``c``."""
        if self.class_dict is None:
            # cache a boolean membership mask per class on first use
            self.class_dict = {'class_%s' % i: (self.data_train.y == i).cpu().numpy()
                               for i in range(self.nclass)}
        candidates = np.arange(len(self.data_train.y))[self.class_dict['class_%s' % c]]
        return np.random.permutation(candidates)[:num]
+
168
+
169
class Dpr2Pyg(InMemoryDataset):
    """Convert a DeepRobust graph into a PyTorch Geometric InMemoryDataset
    containing a single Data object (old-style collate/slices PyG API)."""

    def __init__(self, dpr_data, transform=None, **kwargs):
        root = 'data/' # dummy root; does not mean anything
        self.dpr_data = dpr_data
        super(Dpr2Pyg, self).__init__(root, transform)
        pyg_data = self.process()
        self.data, self.slices = self.collate([pyg_data])
        # re-assign because the parent constructor may have reset it
        self.transform = transform

    def process(self):
        """Build one PyG Data object from the stored DeepRobust graph."""
        dpr_data = self.dpr_data
        edge_index = torch.LongTensor(dpr_data.adj.nonzero())
        # by default, the features in pyg data is dense
        if sp.issparse(dpr_data.features):
            x = torch.FloatTensor(dpr_data.features.todense()).float()
        else:
            x = torch.FloatTensor(dpr_data.features).float()
        y = torch.LongTensor(dpr_data.labels)
        data = Data(x=x, edge_index=edge_index, y=y)
        # masks intentionally left unset; split indices live on dpr_data
        data.train_mask = None
        data.val_mask = None
        data.test_mask = None
        return data


    def get(self, idx):
        """Slice sample ``idx`` out of the collated storage (old PyG API)."""
        data = self.data.__class__()

        if hasattr(self.data, '__num_nodes__'):
            data.num_nodes = self.data.__num_nodes__[idx]

        for key in self.data.keys:
            item, slices = self.data[key], self.slices[key]
            # take the full extent on every dim except the concat dim,
            # which is restricted to this sample's slice
            s = list(repeat(slice(None), item.dim()))
            s[self.data.__cat_dim__(key, item)] = slice(slices[idx],
                                                        slices[idx + 1])
            data[key] = item[s]
        return data

    @property
    def raw_file_names(self):
        # dummy names; raw files are never read (data is supplied in-memory)
        return ['some_file_1', 'some_file_2', ...]

    @property
    def processed_file_names(self):
        return ['data.pt']

    def _download(self):
        # nothing to download; the graph comes from dpr_data
        pass
+
220
+
requirements.txt CHANGED
@@ -1,4 +1,4 @@
1
- torch
2
  torch-scatter
3
  torch-sparse
4
- torch-geometric
 
 
 
1
  torch-scatter
2
  torch-sparse
3
+ torch-geometric
4
+ rich
src/2.1_lrmc_bilevel.py CHANGED
@@ -1,6 +1,4 @@
1
- # lrmc_bilevel.py
2
  # Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
3
- # Requires: torch, torch_geometric, torch_scatter, torch_sparse
4
  #
5
  # Usage examples:
6
  # python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant baseline
@@ -28,6 +26,8 @@ from torch_sparse import coalesce, spspmm
28
  from torch_geometric.datasets import Planetoid
29
  from torch_geometric.nn import GCNConv
30
 
 
 
31
 
32
  # ---------------------------
33
  # Utilities: edges and seeds
 
 
1
  # Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
 
2
  #
3
  # Usage examples:
4
  # python lrmc_bilevel.py --dataset Cora --seeds /path/to/lrmc_seeds.json --variant baseline
 
26
  from torch_geometric.datasets import Planetoid
27
  from torch_geometric.nn import GCNConv
28
 
29
+ from rich import print
30
+
31
 
32
  # ---------------------------
33
  # Utilities: edges and seeds
src/2.2_lrmc_bilevel.py CHANGED
@@ -1,4 +1,3 @@
1
- # 2.1_lrmc_bilevel.py
2
  # Top-1 LRMC ablation with debug guards so seeds differences are visible.
3
  # Requires: torch, torch_geometric, torch_scatter, torch_sparse
4
 
@@ -16,6 +15,8 @@ from torch_sparse import coalesce, spspmm
16
  from torch_geometric.datasets import Planetoid
17
  from torch_geometric.nn import GCNConv
18
 
 
 
19
 
20
  # ---------------------------
21
  # Utilities: edges and seeds
 
 
1
  # Top-1 LRMC ablation with debug guards so seeds differences are visible.
2
  # Requires: torch, torch_geometric, torch_scatter, torch_sparse
3
 
 
15
  from torch_geometric.datasets import Planetoid
16
  from torch_geometric.nn import GCNConv
17
 
18
+ from rich import print
19
+
20
 
21
  # ---------------------------
22
  # Utilities: edges and seeds
src/2.3_lrmc_bilevel.py CHANGED
@@ -1,4 +1,3 @@
1
- # 2.3_lrmc_bilevel.py
2
  # Top-1 LRMC ablation with: cluster refinement (k-core), gated residual fusion,
3
  # sparsified cluster graph (drop self-loops + per-row top-k), and A + γA² mix.
4
  # Requires: torch, torch_geometric, torch_scatter, torch_sparse
@@ -17,6 +16,8 @@ from torch_sparse import coalesce, spspmm
17
  from torch_geometric.datasets import Planetoid
18
  from torch_geometric.nn import GCNConv
19
 
 
 
20
 
21
  # ---------------------------
22
  # Utilities: edges and seeds
 
 
1
  # Top-1 LRMC ablation with: cluster refinement (k-core), gated residual fusion,
2
  # sparsified cluster graph (drop self-loops + per-row top-k), and A + γA² mix.
3
  # Requires: torch, torch_geometric, torch_scatter, torch_sparse
 
16
  from torch_geometric.datasets import Planetoid
17
  from torch_geometric.nn import GCNConv
18
 
19
+ from rich import print
20
+
21
 
22
  # ---------------------------
23
  # Utilities: edges and seeds
src/2.4_lrmc_bilevel.py CHANGED
@@ -1,4 +1,3 @@
1
- # lrmc_bilevel.py
2
  # Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
3
  # Requires: torch, torch_geometric, torch_scatter, torch_sparse
4
  #
@@ -30,6 +29,8 @@ from torch_geometric.datasets import Planetoid
30
  from torch_geometric.nn import GCNConv
31
  from torch_geometric.utils import subgraph, degree # Added for stability score
32
 
 
 
33
  # ---------------------------
34
  # Utilities: edges and seeds
35
  # ---------------------------
 
 
1
  # Top-1 LRMC ablation: one-cluster pooling vs. plain GCN on Planetoid (e.g., Cora)
2
  # Requires: torch, torch_geometric, torch_scatter, torch_sparse
3
  #
 
29
  from torch_geometric.nn import GCNConv
30
  from torch_geometric.utils import subgraph, degree # Added for stability score
31
 
32
+ from rich import print
33
+
34
  # ---------------------------
35
  # Utilities: edges and seeds
36
  # ---------------------------
src/2.5_lrmc_bilevel.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ L-RMC Anchored GCN vs. Plain GCN (dynamic robustness evaluation)
3
+ ==============================================================
4
+
5
+ This script trains a baseline two‑layer GCN and a new **anchor‑gated** GCN on
6
+ Planetoid citation networks (Cora/Citeseer/Pubmed). The anchor‑gated GCN uses
7
+ the top‑1 L‑RMC cluster (loaded from a provided JSON file) as a *decentralized
8
+ core*. During message passing it blends standard neighborhood aggregation
9
+ (`h_base`) with aggregation restricted to the core (`h_core`) via a per‑node
10
+ gating network. Cross‑boundary edges are optionally down‑weighted by a
11
+ damping factor `γ`.
12
+
13
+ After training on the static graph, the script evaluates *robustness over
14
+ time*. Starting from the original adjacency, it repeatedly performs random
15
+ edge rewires (removes a fraction of existing edges and adds the same number
16
+ of random new edges) and measures test accuracy at each step **without
17
+ retraining**. The area under the accuracy–time curve (AUC‑AT) is reported
18
+ for both the baseline and the anchored model. A higher AUC‑AT indicates
19
+ longer resilience to graph churn.
20
+
21
+ Usage examples::
22
+
23
+ # Train only baseline and report dynamic AUC
24
+ python 2.5_lrmc_bilevel.py --dataset Cora --seeds path/to/lrmc_seeds.json --variant baseline
25
+
26
+ # Train baseline and anchor models, evaluate AUC‑over‑time on 30 steps with 5% rewiring
27
+ python 2.5_lrmc_bilevel.py --dataset Cora --seeds path/to/lrmc_seeds.json --variant anchor \
28
+ --dynamic_steps 30 --flip_fraction 0.05 --gamma 0.8
29
+
30
+ Notes
31
+ -----
32
+ * The seeds JSON must contain an entry ``"clusters"`` with a list of clusters; the
33
+ cluster with maximum (score, size) is chosen as the core.
34
+ * For fairness, both models are trained on the identical training mask and
35
+ evaluated on the same dynamic perturbations.
36
+ * Random rewiring is undirected: an edge (u,v) is treated as the same as (v,u).
37
+ * Cross‑boundary damping and the gating network use only structural
38
+ information; features are left unchanged during perturbations.
39
+ """
40
+
41
+ import argparse
42
+ import json
43
+ import random
44
+ from pathlib import Path
45
+ from typing import Tuple, List, Optional, Set
46
+
47
+ import torch
48
+ import torch.nn as nn
49
+ import torch.nn.functional as F
50
+ from torch import Tensor
51
+ from torch_geometric.datasets import Planetoid
52
+ from torch_geometric.nn import GCNConv
53
+
54
+ from rich import print
55
+
56
+ # -----------------------------------------------------------------------------
57
+ # Utilities for loading LRMC core assignment
58
+ # -----------------------------------------------------------------------------
59
+
60
+ def _pick_top1_cluster(obj: dict) -> List[int]:
61
+ """
62
+ From LRMC JSON with structure {"clusters":[{"seed_nodes":[...],"score":float,...},...]}
63
+ choose the cluster with the highest (score, size) and return its members as
64
+ 0‑indexed integers. If no clusters exist, returns an empty list.
65
+ """
66
+ clusters = obj.get("clusters", [])
67
+ if not clusters:
68
+ return []
69
+ # Choose by highest score, tie‑break by size
70
+ best = max(clusters, key=lambda c: (float(c.get("score", 0.0)), len(c.get("seed_nodes", []))))
71
+ return [nid - 1 for nid in best.get("seed_nodes", [])]
72
+
73
+
74
def load_top1_assignment(seeds_json: str, n_nodes: int) -> Tuple[Tensor, Tensor]:
    """
    Given a path to the LRMC seeds JSON and total number of nodes, returns:

    * core_mask: bool Tensor of shape [N] where True indicates membership in
      the top-1 LRMC cluster.
    * core_nodes: Long Tensor containing the indices of the core nodes.

    Nodes not in the core form the periphery. If the JSON has no clusters,
    the core is empty.

    Ids outside ``[0, n_nodes)`` are discarded. Previously an oversized id
    raised an IndexError and a raw 0 id (which ``_pick_top1_cluster`` maps
    to -1) silently selected node N-1 via negative indexing; the companion
    loader in 2.6_lrmc_summary.py already applied this range filter.
    """
    obj = json.loads(Path(seeds_json).read_text())
    core_list = _pick_top1_cluster(obj)
    # De-duplicate and keep only ids that actually exist in the graph.
    valid = sorted({i for i in core_list if 0 <= i < n_nodes})
    core_nodes = torch.tensor(valid, dtype=torch.long)
    core_mask = torch.zeros(n_nodes, dtype=torch.bool)
    if core_nodes.numel() > 0:
        core_mask[core_nodes] = True
    return core_mask, core_nodes
92
+
93
+
94
+ # -----------------------------------------------------------------------------
95
+ # Baseline GCN: standard two‑layer GCN
96
+ # -----------------------------------------------------------------------------
97
+
98
class GCN2(nn.Module):
    """Vanilla two-layer GCN used as the baseline model."""

    def __init__(self, in_dim: int, hid_dim: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        self.conv1 = GCNConv(in_dim, hid_dim)
        self.conv2 = GCNConv(hid_dim, out_dim)
        self.dropout = dropout

    def forward(self, x: Tensor, edge_index: Tensor, edge_weight: Optional[Tensor] = None) -> Tensor:
        # GCNConv adds self loops itself (add_self_loops=True by default).
        hidden = F.relu(self.conv1(x, edge_index, edge_weight))
        hidden = F.dropout(hidden, p=self.dropout, training=self.training)
        return self.conv2(hidden, edge_index, edge_weight)
113
+
114
+
115
+ # -----------------------------------------------------------------------------
116
+ # Anchor‑gated GCN
117
+ # -----------------------------------------------------------------------------
118
+
119
class AnchorGCN(nn.Module):
    """
    A two‑layer GCN that injects a core‑restricted aggregation channel and
    down‑weights edges crossing the core boundary. After the first GCN layer
    computes base features, a gating network mixes them with features
    aggregated only among core neighbors.

    Parameters
    ----------
    in_dim : int
        Dimensionality of input node features.
    hid_dim : int
        Dimensionality of hidden layer.
    out_dim : int
        Number of output classes.
    core_mask : Tensor[bool]
        Boolean mask indicating which nodes belong to the L‑RMC core.
    gamma : float, optional
        Damping factor for edges that connect core and non‑core nodes.
        Values <1.0 reduce the influence of boundary edges. Default is 1.0
        (no damping).
    dropout : float, optional
        Dropout probability applied after the first layer.
    """

    def __init__(self,
                 in_dim: int,
                 hid_dim: int,
                 out_dim: int,
                 core_mask: Tensor,
                 gamma: float = 1.0,
                 dropout: float = 0.5):
        super().__init__()
        # NOTE(review): core_mask is stored as a plain attribute, not via
        # register_buffer, so model.to(device) will not move it — confirm the
        # model is only used on the same device the mask was created on.
        self.core_mask = core_mask.clone().detach()
        self.gamma = float(gamma)
        self.dropout = dropout

        # Base and core convolutions for the first layer
        # Base conv uses self loops; core conv disables self loops to avoid
        # spurious core contributions on non‑core nodes
        self.base1 = GCNConv(in_dim, hid_dim, add_self_loops=True)
        self.core1 = GCNConv(in_dim, hid_dim, add_self_loops=False)

        # Second layer: standard GCN on mixed features
        self.conv2 = GCNConv(hid_dim, out_dim)

        # Gating network: maps structural features to α ∈ [0,1]
        # Input is the 3-dim structural vector built in
        # _compute_structural_features; Sigmoid keeps the mixing weight in [0,1].
        self.gate = nn.Sequential(
            nn.Linear(3, 16),
            nn.ReLU(),
            nn.Linear(16, 1),
            nn.Sigmoid(),
        )

    def _compute_edge_weights(self, edge_index: Tensor) -> Tensor:
        """
        Given an edge index (two‑row tensor), return a weight tensor of ones
        multiplied by ``gamma`` for edges with exactly one endpoint in the core.
        Self loops (if present) are untouched. Edge weights are 1 for base
        edges and <1 for cross‑boundary edges.
        """
        # gamma >= 1.0 means "no damping": skip the cross-edge bookkeeping.
        if self.gamma >= 1.0:
            return torch.ones(edge_index.size(1), dtype=torch.float32, device=edge_index.device)
        src, dst = edge_index[0], edge_index[1]
        in_core_src = self.core_mask[src]
        in_core_dst = self.core_mask[dst]
        # XOR: exactly one endpoint in the core -> boundary-crossing edge.
        cross = in_core_src ^ in_core_dst
        w = torch.ones(edge_index.size(1), dtype=torch.float32, device=edge_index.device)
        w[cross] *= self.gamma
        return w

    def _compute_structural_features(self, edge_index: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Compute structural features used by the gating network:

        * `in_core` – 1 if node in core, else 0
        * `frac_core_nbrs` – fraction of neighbors that are in the core
        * `is_boundary` – 1 if node has both core and non‑core neighbors

        The features are returned as a tuple of three tensors of shape [N,1].
        Nodes with zero degree get frac_core_nbrs=0 and is_boundary=0.
        """
        N = self.core_mask.size(0)
        device = edge_index.device
        # Degree and core neighbor counts
        src = edge_index[0]
        dst = edge_index[1]
        deg = torch.zeros(N, dtype=torch.float32, device=device)
        core_deg = torch.zeros(N, dtype=torch.float32, device=device)
        # Count contributions of directed edges; duplicates will double‑count but
        # the ratio remains stable if the graph is symmetric.
        deg.index_add_(0, src, torch.ones_like(src, dtype=torch.float32))
        # Count core neighbors: only increment source if destination is core
        core_flags = self.core_mask[dst].float()
        core_deg.index_add_(0, src, core_flags)
        # Avoid division by zero
        frac_core = torch.zeros(N, dtype=torch.float32, device=device)
        nonzero = deg > 0
        frac_core[nonzero] = core_deg[nonzero] / deg[nonzero]
        # Determine boundary: at least one core neighbor AND at least one non‑core neighbor
        has_core = core_deg > 0
        has_non_core = (deg - core_deg) > 0
        is_boundary = (has_core & has_non_core).float()
        in_core = self.core_mask.float()
        return in_core.view(-1, 1), frac_core.view(-1, 1), is_boundary.view(-1, 1)

    def forward(self, x: Tensor, edge_index: Tensor) -> Tensor:
        # Compute dynamic edge weights (for base channels) using damping
        w = self._compute_edge_weights(edge_index)
        # First layer: base aggregation (standard neighbors with self loops)
        h_base = self.base1(x, edge_index, w)
        h_base = F.relu(h_base)

        # First layer: core aggregation (only core neighbors, no self loops)
        # Extract edges where both endpoints are core
        src, dst = edge_index
        mask_core_edges = self.core_mask[src] & self.core_mask[dst]
        ei_core = edge_index[:, mask_core_edges]
        # If no core edges exist, h_core will be zeros
        if ei_core.numel() == 0:
            h_core = torch.zeros_like(h_base)
        else:
            h_core = self.core1(x, ei_core)
            h_core = F.relu(h_core)

        # Structural features for gating
        in_core, frac_core, is_boundary = self._compute_structural_features(edge_index)
        feats = torch.cat([in_core, frac_core, is_boundary], dim=1)
        alpha = self.gate(feats).view(-1)  # shape [N]
        # Force α=0 for nodes with no core neighbors to avoid modifying true periphery.
        # Nodes with frac_core == 0 have zero core neighbors by construction.
        no_core_neighbors = (frac_core.view(-1) == 0)
        alpha = torch.where(no_core_neighbors, torch.zeros_like(alpha), alpha)

        # Mix base and core features; h_final = h_base + α (h_core - h_base)
        # Equivalent to (1-α)*h_base + α*h_core
        h1 = h_base + alpha.unsqueeze(1) * (h_core - h_base)

        h1 = F.dropout(h1, p=self.dropout, training=self.training)
        # Second layer: standard GCN with the same damping weights
        out = self.conv2(h1, edge_index, w)
        return out
261
+
262
+
263
# Retained for completeness; no call sites remain in this file.
def deg_for_division(edge_index: Tensor, num_nodes: int) -> Tensor:
    """Return the float32 out-degree of every node (directed edge count)."""
    sources = edge_index[0]
    counts = torch.bincount(sources, minlength=num_nodes)
    return counts.to(dtype=torch.float32)
269
+
270
+
271
+ # -----------------------------------------------------------------------------
272
+ # Training and evaluation routines
273
+ # -----------------------------------------------------------------------------
274
+
275
@torch.no_grad()
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Fraction of nodes inside ``mask`` whose argmax prediction matches ``y``."""
    predictions = torch.argmax(logits[mask], dim=1)
    hits = predictions.eq(y[mask]).float()
    return hits.mean().item()
280
+
281
+
282
def train_model(model: nn.Module,
                data,
                epochs: int = 200,
                lr: float = 0.01,
                weight_decay: float = 5e-4) -> None:
    """
    Fit ``model`` on ``data`` with Adam, tracking validation accuracy after
    every epoch and restoring the best-validation weights at the end.
    Works for both the baseline and the anchor model.
    """
    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    best_val_acc = 0.0
    best_weights = None
    for _epoch in range(epochs):
        # One optimization step on the training mask.
        model.train()
        optimizer.zero_grad(set_to_none=True)
        out = model(data.x, data.edge_index)
        F.cross_entropy(out[data.train_mask], data.y[data.train_mask]).backward()
        optimizer.step()
        # Snapshot the weights whenever validation accuracy improves.
        model.eval()
        val_out = model(data.x, data.edge_index)
        current = accuracy(val_out, data.y, data.val_mask)
        if current > best_val_acc:
            best_val_acc = current
            best_weights = {name: tensor.detach().clone() for name, tensor in model.state_dict().items()}
    if best_weights is not None:
        model.load_state_dict(best_weights)
    model.eval()
308
+
309
+
310
def evaluate_model(model: nn.Module, data) -> dict:
    """Return the trained model's accuracy on the train, val and test masks."""
    model.eval()
    out = model(data.x, data.edge_index)
    masks = {"train": data.train_mask, "val": data.val_mask, "test": data.test_mask}
    return {split: accuracy(out, data.y, m) for split, m in masks.items()}
319
+
320
+
321
+ # -----------------------------------------------------------------------------
322
+ # Dynamic graph perturbation utilities
323
+ # -----------------------------------------------------------------------------
324
+
325
def undirected_edge_set(edge_index: Tensor) -> Set[Tuple[int, int]]:
    """
    Collapse a directed edge_index into a set of undirected pairs (u, v)
    with u < v. Self loops are dropped.
    """
    pairs: Set[Tuple[int, int]] = set()
    for u, v in zip(edge_index[0].tolist(), edge_index[1].tolist()):
        if u != v:
            pairs.add((min(u, v), max(u, v)))
    return pairs
339
+
340
+
341
def edge_set_to_index(edges: Set[Tuple[int, int]], num_nodes: int) -> Tensor:
    """
    Expand an undirected edge set into a directed edge_index of shape
    [2, 2*|edges|] by emitting both orientations of every pair. No self
    loops are produced; GCNConv inserts its own.
    """
    if not edges:
        return torch.empty(2, 0, dtype=torch.long)
    rows: List[int] = []
    cols: List[int] = []
    for u, v in edges:
        rows += [u, v]
        cols += [v, u]
    return torch.tensor([rows, cols], dtype=torch.long)
356
+
357
+
358
def random_rewire(edges: Set[Tuple[int, int]], num_nodes: int, n_changes: int, rng: random.Random) -> Set[Tuple[int, int]]:
    """
    Return a copy of ``edges`` with ``n_changes`` random deletions followed
    by up to ``n_changes`` random insertions. Inserted edges are undirected,
    never self loops and never duplicates; insertion attempts are capped at
    ``10 * n_changes`` so the loop always terminates.
    """
    rewired = set(edges)
    n_changes = min(n_changes, len(rewired))
    # Deletions: uniform sample without replacement.
    for victim in rng.sample(list(rewired), n_changes):
        rewired.discard(victim)
    # Insertions: rejection-sample endpoint pairs.
    n_added = 0
    n_tried = 0
    while n_added < n_changes and n_tried < n_changes * 10:
        n_tried += 1
        u = rng.randrange(num_nodes)
        v = rng.randrange(num_nodes)
        if u == v:
            continue
        candidate = (u, v) if u < v else (v, u)
        if candidate not in rewired:
            rewired.add(candidate)
            n_added += 1
    return rewired
387
+
388
+
389
def auc_over_time(acc_list: List[float]) -> float:
    """
    Area under an accuracy–time curve via the trapezoidal rule, normalized
    by the number of intervals so a constant accuracy of 1.0 yields 1.0.
    ``acc_list`` holds accuracies at t = 0, 1, ..., T.

    Edge cases: an empty list yields 0.0; a single sample yields that sample
    itself (the previous version raised ZeroDivisionError on length-1 input
    because it divided by ``len(acc_list) - 1``).
    """
    if not acc_list:
        return 0.0
    if len(acc_list) == 1:
        # No interval to integrate over; the degenerate AUC is the lone value.
        return acc_list[0]
    area = sum((acc_list[i] + acc_list[i - 1]) / 2.0 for i in range(1, len(acc_list)))
    return area / (len(acc_list) - 1)
401
+
402
+
403
@torch.no_grad()
def evaluate_dynamic_auc(model: nn.Module,
                         data,
                         core_mask: Tensor,
                         steps: int = 30,
                         flip_fraction: float = 0.05,
                         rng_seed: int = 1234) -> List[float]:
    """
    Evaluate a model's test accuracy over a sequence of random edge
    rewiring steps.

    Parameters
    ----------
    model : nn.Module
        A trained model that accepts (x, edge_index) and returns logits.
    data : Data
        PyG data object with attributes x, y, test_mask. ``data.edge_index``
        provides the initial adjacency.
    core_mask : Tensor[bool]
        Boolean mask indicating core nodes; kept for interface parity, the
        baseline model ignores it.
    steps : int, optional
        Number of rewiring steps; accuracy at t=0 is measured first. Default: 30.
    flip_fraction : float, optional
        Fraction of existing edges removed/added per step. Default: 0.05.
    rng_seed : int, optional
        Seed for the rewiring RNG. Default: 1234.

    Returns
    -------
    List[float]
        ``steps + 1`` test accuracies, including the unperturbed t=0 value.

    The whole routine now runs under ``torch.no_grad()``: previously every
    forward pass built an autograd graph, wasting memory across the 30+
    full-graph inferences even though nothing is trained here.
    """
    base_edges = undirected_edge_set(data.edge_index)
    num_edges = len(base_edges)
    # At least one edge is rewired per step even for tiny graphs.
    n_changes = max(1, int(flip_fraction * num_edges))
    model.eval()
    rng = random.Random(rng_seed)
    # Work on a copy so the caller's edge set is never mutated.
    cur_edges = set(base_edges)
    accuracies: List[float] = []
    # t = 0: accuracy on the unperturbed graph.
    ei = edge_set_to_index(cur_edges, data.num_nodes).to(data.x.device)
    accuracies.append(accuracy(model(data.x, ei), data.y, data.test_mask))
    # Rewiring steps t = 1..steps.
    for _t in range(1, steps + 1):
        cur_edges = random_rewire(cur_edges, data.num_nodes, n_changes, rng)
        ei = edge_set_to_index(cur_edges, data.num_nodes).to(data.x.device)
        accuracies.append(accuracy(model(data.x, ei), data.y, data.test_mask))
    return accuracies
463
+
464
+
465
+ # -----------------------------------------------------------------------------
466
+ # Main entrypoint
467
+ # -----------------------------------------------------------------------------
468
+
469
def main():
    """CLI entry point: train baseline (and optionally anchor) GCN, then report dynamic AUC-AT."""
    parser = argparse.ArgumentParser(description="L‑RMC anchored GCN vs. baseline with dynamic evaluation.")
    parser.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"],
                        help="Planetoid dataset to load.")
    parser.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON (for core extraction).")
    parser.add_argument("--variant", choices=["baseline", "anchor"], default="anchor", help="Which variant to run.")
    parser.add_argument("--hidden", type=int, default=64, help="Hidden dimension.")
    parser.add_argument("--epochs", type=int, default=200, help="Number of training epochs.")
    parser.add_argument("--lr", type=float, default=0.01, help="Learning rate.")
    parser.add_argument("--wd", type=float, default=5e-4, help="Weight decay (L2).")
    parser.add_argument("--dropout", type=float, default=0.5, help="Dropout probability.")
    parser.add_argument("--gamma", type=float, default=1.0, help="Damping factor γ for cross‑boundary edges (anchor only).")
    parser.add_argument("--dynamic_steps", type=int, default=30, help="Number of dynamic rewiring steps for AUC evaluation.")
    parser.add_argument("--flip_fraction", type=float, default=0.05, help="Fraction of edges rewired at each step.")
    parser.add_argument("--seed", type=int, default=42, help="Random seed for PyTorch.")
    args = parser.parse_args()

    # Set seeds (torch for weights/dropout, random for the rewiring RNG)
    torch.manual_seed(args.seed)
    random.seed(args.seed)

    # Load dataset
    dataset = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
    data = dataset[0]
    in_dim = dataset.num_node_features
    out_dim = dataset.num_classes
    num_nodes = data.num_nodes

    # Load core assignment
    core_mask, core_nodes = load_top1_assignment(args.seeds, num_nodes)
    print(f"Loaded core of size {core_nodes.numel()} from {args.seeds}.")

    if args.variant == "baseline":
        # Train baseline only
        baseline = GCN2(in_dim, args.hidden, out_dim, dropout=args.dropout)
        train_model(baseline, data, epochs=args.epochs, lr=args.lr, weight_decay=args.wd)
        res = evaluate_model(baseline, data)
        print(f"Baseline GCN: train={res['train']:.4f} val={res['val']:.4f} test={res['test']:.4f}")
        # Evaluate dynamic AUC
        accs = evaluate_dynamic_auc(baseline, data, core_mask, steps=args.dynamic_steps,
                                    flip_fraction=args.flip_fraction, rng_seed=args.seed)
        auc = auc_over_time(accs)
        print(f"Baseline dynamic AUC‑AT (steps={args.dynamic_steps}, flip={args.flip_fraction}): {auc:.4f}")
        return

    # ----- Train both baseline and anchor variants -----
    # Baseline
    baseline = GCN2(in_dim, args.hidden, out_dim, dropout=args.dropout)
    train_model(baseline, data, epochs=args.epochs, lr=args.lr, weight_decay=args.wd)
    res_base = evaluate_model(baseline, data)
    print(f"Baseline GCN: train={res_base['train']:.4f} val={res_base['val']:.4f} test={res_base['test']:.4f}")
    # Anchor model
    anchor = AnchorGCN(in_dim, args.hidden, out_dim,
                       core_mask=core_mask,
                       gamma=args.gamma,
                       dropout=args.dropout)
    train_model(anchor, data, epochs=args.epochs, lr=args.lr, weight_decay=args.wd)
    res_anchor = evaluate_model(anchor, data)
    print(f"Anchor‑GCN: train={res_anchor['train']:.4f} val={res_anchor['val']:.4f} test={res_anchor['test']:.4f}")
    # Dynamic evaluation — identical rng_seed so both models see the same
    # perturbation sequence, keeping the AUC comparison fair.
    accs_base = evaluate_dynamic_auc(baseline, data, core_mask, steps=args.dynamic_steps,
                                     flip_fraction=args.flip_fraction, rng_seed=args.seed)
    accs_anchor = evaluate_dynamic_auc(anchor, data, core_mask, steps=args.dynamic_steps,
                                       flip_fraction=args.flip_fraction, rng_seed=args.seed)
    auc_base = auc_over_time(accs_base)
    auc_anchor = auc_over_time(accs_anchor)
    print(f"Dynamic AUC‑AT (steps={args.dynamic_steps}, flip={args.flip_fraction}):")
    print(f"  Baseline : {auc_base:.4f}\n  Anchor   : {auc_anchor:.4f}")


if __name__ == "__main__":
    main()
src/2.6_lrmc_summary.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train a GCN on an L‑RMC subgraph and compare to a full‑graph baseline.
3
+
4
+ Modes:
5
+ - core_mode=forward : Train on core subgraph, then forward on full graph (your current approach).
6
+ - core_mode=appnp : Train on core subgraph, then seed logits on core and APPNP‑propagate on full graph.
7
+
8
+ Extras:
9
+ - --expand_core_with_train : Make sure all training labels lie inside the core
10
+ (C' = C ∪ train_idx) for fair train‑time comparison.
11
+ - --warm_ft_epochs N : Optional short finetune on the full graph starting
12
+ from the core model's weights (measure time‑to‑target).
13
+
14
+ It prints:
15
+ - Dataset stats
16
+ - Core size and coverage of train/val/test inside the core
17
+ - Train/Val/Test accuracy for baseline and core model
18
+ - Wall‑clock times
19
+ """
20
+
21
+ import argparse
22
+ import json
23
+ import time
24
+ import random
25
+ from statistics import mean, stdev
26
+ from pathlib import Path
27
+ from typing import Dict
28
+
29
+ import torch
30
+ import torch.nn.functional as F
31
+ from torch import nn, Tensor
32
+ from torch_geometric.datasets import Planetoid
33
+ from torch_geometric.nn import GCNConv, APPNP
34
+ from torch_geometric.utils import subgraph
35
+
36
+ # ------------------------------------------------------------
37
+ # Rich imports
38
+ # ------------------------------------------------------------
39
+ from rich.console import Console
40
+ from rich.table import Table
41
+ from rich.progress import Progress, SpinnerColumn, TimeElapsedColumn
42
+
43
+ # Rich console instance
44
+ console = Console()
45
+
46
+ # ------------------------------------------------------------
47
+ # Utilities
48
+ # ------------------------------------------------------------
49
def load_top1_assignment(seeds_json: str, n_nodes: int) -> torch.Tensor:
    """
    Read the LRMC seeds JSON, expected as
    ``{"clusters": [{"seed_nodes": [...], "score": float, ...}, ...]}``,
    and return a boolean core mask for the cluster with the largest
    (score, size) pair.

    Seed ids are always assumed 1-indexed; ids outside the graph are
    dropped. An all-False mask is returned when no clusters exist.
    """
    payload = json.loads(Path(seeds_json).read_text())
    mask = torch.zeros(n_nodes, dtype=torch.bool)
    clusters = payload.get("clusters", [])
    if not clusters:
        return mask
    winner = max(
        clusters,
        key=lambda c: (float(c.get("score", 0.0)), len(c.get("seed_nodes", []))),
    )
    # 1-indexed -> 0-indexed, deduplicated, clipped to the valid node range.
    kept = sorted({int(x) - 1 for x in winner.get("seed_nodes", []) if 0 < int(x) <= n_nodes})
    if kept:
        mask[torch.tensor(kept, dtype=torch.long)] = True
    return mask
69
+
70
+
71
def coverage_counts(core_mask: torch.Tensor, train_mask: torch.Tensor,
                    val_mask: torch.Tensor, test_mask: torch.Tensor) -> Dict[str, int]:
    """Count the core size and how many train/val/test nodes fall inside it."""
    def _count(mask: torch.Tensor) -> int:
        return int(mask.sum().item())

    return {
        "core_size": _count(core_mask),
        "train_in_core": _count(core_mask & train_mask),
        "val_in_core": _count(core_mask & val_mask),
        "test_in_core": _count(core_mask & test_mask),
    }
79
+
80
+
81
def accuracy(logits: Tensor, y: Tensor, mask: Tensor) -> float:
    """Fraction of nodes in ``mask`` whose argmax prediction matches ``y``."""
    predicted = torch.argmax(logits[mask], dim=1)
    return predicted.eq(y[mask]).float().mean().item()
84
+
85
+
86
def set_seed(seed: int):
    """Seed the python, numpy (if present) and torch RNGs, and pin CuDNN to deterministic mode."""
    random.seed(seed)
    try:
        import numpy as np  # numpy is an optional dependency here
    except Exception:
        pass
    else:
        np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
    # Make CUDA/CuDNN deterministic where applicable
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
100
+
101
+
102
+ # ------------------------------------------------------------
103
+ # Models
104
+ # ------------------------------------------------------------
105
class GCN2(nn.Module):
    """Two-layer GCN baseline."""

    def __init__(self, in_dim: int, hid: int, out_dim: int, dropout: float = 0.5):
        super().__init__()
        self.c1 = GCNConv(in_dim, hid)
        self.c2 = GCNConv(hid, out_dim)
        self.dropout = dropout

    def forward(self, x, ei):
        h = torch.relu(self.c1(x, ei))
        h = F.dropout(h, p=self.dropout, training=self.training)
        return self.c2(h, ei)
118
+
119
+
120
+ # ------------------------------------------------------------
121
+ # Training / evaluation
122
+ # ------------------------------------------------------------
123
@torch.no_grad()
def eval_all(model: nn.Module, data) -> Dict[str, float]:
    """Accuracy of ``model`` on each of the three standard splits."""
    model.eval()
    out = model(data.x, data.edge_index)
    return {
        split: accuracy(out, data.y, getattr(data, f"{split}_mask"))
        for split in ("train", "val", "test")
    }
132
+
133
+
134
def train(model: nn.Module, data, epochs=200, lr=0.01, wd=5e-4, patience=100):
    """
    Train ``model`` with Adam and early stopping on validation accuracy.

    Stops after ``patience`` consecutive non-improving epochs, then restores
    the weights from the best validation epoch and leaves the model in eval
    mode. Progress is rendered with a transient rich spinner.
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=wd)
    # best: best val accuracy so far; bad: consecutive non-improving epochs.
    best, best_state, bad = -1.0, None, 0

    # Optional progress bar
    with Progress(
        SpinnerColumn(),
        "[progress.description]{task.description}",
        TimeElapsedColumn(),
        transient=True,
    ) as progress:
        task = progress.add_task("Training", total=epochs)

        for ep in range(1, epochs + 1):
            model.train()
            opt.zero_grad(set_to_none=True)
            out = model(data.x, data.edge_index)
            loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask])
            loss.backward()
            opt.step()

            # early stop on val
            with torch.no_grad():
                val = accuracy(model(data.x, data.edge_index), data.y, data.val_mask)

            if val > best:
                best, bad = val, 0
                # Snapshot the weights; restored after the loop ends.
                best_state = {k: v.detach().clone() for k, v in model.state_dict().items()}
            else:
                bad += 1
                if bad >= patience:
                    break

            progress.update(task, advance=1, description=f"Epoch {ep} | val={val:.4f}")

    if best_state is not None:
        model.load_state_dict(best_state)
    model.eval()
172
+
173
+
174
def subset_data(data, nodes_idx: torch.Tensor):
    """
    Induce the subgraph on ``nodes_idx`` (nodes relabelled to 0..k-1) and
    slice features, labels and split masks accordingly. Returns a new
    object of the same type as ``data``.
    """
    nodes_idx = nodes_idx.to(torch.long)
    induced_ei, _ = subgraph(nodes_idx, data.edge_index, relabel_nodes=True, num_nodes=data.num_nodes)
    sub = type(data)()
    # Restrict every per-node attribute to the selected nodes.
    for attr in ("x", "y", "train_mask", "val_mask", "test_mask"):
        setattr(sub, attr, getattr(data, attr)[nodes_idx])
    sub.edge_index = induced_ei
    sub.num_nodes = sub.x.size(0)
    return sub
190
+
191
+
192
+ # ------------------------------------------------------------
193
+ # APPNP seeding (Mode B)
194
+ # ------------------------------------------------------------
195
def appnp_seed_propagate(logits_seed: Tensor, edge_index: Tensor, K=10, alpha=0.1) -> Tensor:
    """
    Diffuse seed logits over the full graph with a parameter-free APPNP
    pass. Rows of ``logits_seed`` outside the core are expected to be
    zeros; propagation fills them in.
    """
    propagator = APPNP(K=K, alpha=alpha)  # purely structural; nothing to train
    return propagator(logits_seed, edge_index)
202
+
203
+
204
+ # ------------------------------------------------------------
205
+ # Main
206
+ # ------------------------------------------------------------
207
+ def main():
208
+ p = argparse.ArgumentParser()
209
+ p.add_argument("--dataset", required=True, choices=["Cora", "Citeseer", "Pubmed"])
210
+ p.add_argument("--seeds", required=True, help="Path to LRMC seeds JSON")
211
+ p.add_argument("--hidden", type=int, default=64)
212
+ p.add_argument("--dropout", type=float, default=0.5)
213
+ p.add_argument("--epochs", type=int, default=200)
214
+ p.add_argument("--lr", type=float, default=0.01)
215
+ p.add_argument("--wd", type=float, default=5e-4)
216
+ p.add_argument("--patience", type=int, default=100)
217
+ p.add_argument("--core_mode", choices=["forward", "appnp"], default="forward",
218
+ help="How to evaluate the core model on the full graph.")
219
+ p.add_argument("--alpha", type=float, default=0.1, help="APPNP teleport prob (Mode B).")
220
+ p.add_argument("--K", type=int, default=10, help="APPNP steps (Mode B).")
221
+ p.add_argument("--expand_core_with_train", action="store_true",
222
+ help="Expand LRMC core with all training nodes (C' = C ∪ train_idx).")
223
+ p.add_argument("--warm_ft_epochs", type=int, default=0,
224
+ help="If >0, run a short finetune on the FULL graph starting from the core model.")
225
+ p.add_argument("--warm_ft_lr", type=float, default=0.005)
226
+ p.add_argument("--runs", type=int, default=1,
227
+ help="Number of runs with different seeds to average results.")
228
+ p.add_argument("-o", "--output_json", type=str, default=None,
229
+ help="If set, save all computed metrics and settings to this JSON file.")
230
+ args = p.parse_args()
231
+
232
+ # ------------------------------------------------------------
233
+ # Load data
234
+ # ------------------------------------------------------------
235
+ ds = Planetoid(root=f"./data/{args.dataset}", name=args.dataset)
236
+ data = ds[0]
237
+ n, e = data.num_nodes, data.edge_index.size(1) // 2
238
+
239
+ console.print(f"[bold cyan]Dataset: {args.dataset} | Nodes: {n} | Edges: {e}[/bold cyan]")
240
+
241
+ # Results accumulator for optional JSON output
242
+ results = {
243
+ "args": {
244
+ k: (float(v) if isinstance(v, float) else v)
245
+ for k, v in vars(args).items()
246
+ if k != "output_json"
247
+ },
248
+ "dataset": {
249
+ "name": args.dataset,
250
+ "num_nodes": int(n),
251
+ "num_edges": int(e),
252
+ },
253
+ }
254
+
255
+ def maybe_save_results():
256
+ """Write results to JSON if the user requested it."""
257
+ if not args.output_json:
258
+ return
259
+ out_path = Path(args.output_json)
260
+ try:
261
+ out_path.parent.mkdir(parents=True, exist_ok=True)
262
+ except Exception:
263
+ pass
264
+ with out_path.open("w") as f:
265
+ json.dump(results, f, indent=2)
266
+
267
+ # ------------------------------------------------------------
268
+ # Load LRMC core
269
+ # ------------------------------------------------------------
270
+ core_mask = load_top1_assignment(args.seeds, n)
271
+ if args.expand_core_with_train:
272
+ core_mask = core_mask | data.train_mask
273
+
274
+ C_idx = torch.nonzero(core_mask, as_tuple=False).view(-1)
275
+ frac = 100.0 * C_idx.numel() / n
276
+ cov = coverage_counts(core_mask, data.train_mask, data.val_mask, data.test_mask)
277
+
278
+ console.print(f"[bold green]Loaded LRMC core of size {cov['core_size']} (≈{frac:.2f}% of the graph) from {args.seeds}[/bold green]")
279
+
280
+ # Record core coverage info
281
+ results["core"] = {
282
+ "source": str(args.seeds),
283
+ "expanded_with_train": bool(args.expand_core_with_train),
284
+ "size": int(cov["core_size"]),
285
+ "fraction": float(frac / 100.0),
286
+ "coverage": {
287
+ "train_in_core": int(cov["train_in_core"]),
288
+ "val_in_core": int(cov["val_in_core"]),
289
+ "test_in_core": int(cov["test_in_core"]),
290
+ },
291
+ }
292
+
293
+ # Coverage table
294
+ cov_table = Table(title="LRMC Core Coverage")
295
+ cov_table.add_column("Metric", style="cyan")
296
+ cov_table.add_column("Count", style="magenta")
297
+ cov_table.add_row("Core Size", str(cov["core_size"]))
298
+ cov_table.add_row("Train in Core", str(cov["train_in_core"]))
299
+ cov_table.add_row("Val in Core", str(cov["val_in_core"]))
300
+ cov_table.add_row("Test in Core", str(cov["test_in_core"]))
301
+ console.print(cov_table)
302
+
303
+ # ------------------------------------------------------------
304
+ # Single-run or multi-run execution
305
+ # ------------------------------------------------------------
306
+ if args.runs == 1:
307
+ # ---------------------
308
+ # Baseline (full graph)
309
+ # ---------------------
310
+ set_seed(0)
311
+ t0 = time.perf_counter()
312
+ base = GCN2(in_dim=ds.num_node_features,
313
+ hid=args.hidden,
314
+ out_dim=ds.num_classes,
315
+ dropout=args.dropout)
316
+ train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
317
+ base_metrics = eval_all(base, data)
318
+ t1 = time.perf_counter()
319
+
320
+ console.print("\n[bold]Baseline (trained on full graph):[/bold]")
321
+ base_table = Table(show_header=True, header_style="bold magenta")
322
+ base_table.add_column("Metric", style="cyan")
323
+ base_table.add_column("Value", style="magenta")
324
+ base_table.add_row("Train Accuracy", f"{base_metrics['train']:.4f}")
325
+ base_table.add_row("Validation Accuracy", f"{base_metrics['val']:.4f}")
326
+ base_table.add_row("Test Accuracy", f"{base_metrics['test']:.4f}")
327
+ base_table.add_row("Time (s)", f"{t1 - t0:.2f}")
328
+ console.print(base_table)
329
+
330
+ # Save baseline single-run metrics
331
+ results["single_run"] = {
332
+ "baseline": {
333
+ "train": float(base_metrics["train"]),
334
+ "val": float(base_metrics["val"]),
335
+ "test": float(base_metrics["test"]),
336
+ "time_s": float(t1 - t0),
337
+ }
338
+ }
339
+
340
+ # ---------------------
341
+ # Core model (train on subgraph)
342
+ # ---------------------
343
+ if C_idx.numel() == 0:
344
+ console.print("[bold yellow]LRMC core is empty; skipping core model.[/bold yellow]")
345
+ results["core_empty"] = True
346
+ maybe_save_results()
347
+ return
348
+
349
+ data_C = subset_data(data, C_idx)
350
+ mC = GCN2(in_dim=ds.num_node_features,
351
+ hid=args.hidden,
352
+ out_dim=ds.num_classes,
353
+ dropout=args.dropout)
354
+
355
+ t2 = time.perf_counter()
356
+ train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
357
+ t3 = time.perf_counter()
358
+
359
+ # Evaluate core model on FULL graph
360
+ if args.core_mode == "forward":
361
+ # Mode A: run a standard forward pass on the full graph
362
+ mC.eval()
363
+ logits_full = mC(data.x, data.edge_index)
364
+ else:
365
+ # Mode B: seed logits on core and propagate with APPNP
366
+ mC.eval()
367
+ with torch.no_grad():
368
+ logits_C = mC(data_C.x, data_C.edge_index) # [|C|, num_classes]
369
+ logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
370
+ logits_seed[C_idx] = logits_C
371
+ logits_full = appnp_seed_propagate(logits_seed,
372
+ data.edge_index,
373
+ K=args.K,
374
+ alpha=args.alpha)
375
+
376
+ core_metrics = {
377
+ "train": accuracy(logits_full, data.y, data.train_mask),
378
+ "val": accuracy(logits_full, data.y, data.val_mask),
379
+ "test": accuracy(logits_full, data.y, data.test_mask),
380
+ }
381
+
382
+ console.print("\n[bold]LRMC‑core model (trained on core, evaluated on full graph):[/bold]")
383
+ core_table = Table(show_header=True, header_style="bold magenta")
384
+ core_table.add_column("Metric", style="cyan")
385
+ core_table.add_column("Value", style="magenta")
386
+ core_table.add_row("Train Accuracy", f"{core_metrics['train']:.4f}")
387
+ core_table.add_row("Validation Accuracy", f"{core_metrics['val']:.4f}")
388
+ core_table.add_row("Test Accuracy", f"{core_metrics['test']:.4f}")
389
+ core_table.add_row("Core Training Time (s)", f"{t3 - t2:.2f}")
390
+ speedup = (t1 - t0) / (t3 - t2 + 1e-9)
391
+ core_table.add_row("Speedup vs. Baseline", f"{speedup:.2f}×")
392
+ console.print(core_table)
393
+
394
+ # Save core single-run metrics
395
+ results["single_run"]["core_model"] = {
396
+ "mode": str(args.core_mode),
397
+ "train": float(core_metrics["train"]),
398
+ "val": float(core_metrics["val"]),
399
+ "test": float(core_metrics["test"]),
400
+ "core_train_time_s": float(t3 - t2),
401
+ "speedup_vs_baseline": float(speedup),
402
+ }
403
+
404
+ # ------------------------------------------------------------------------------------
405
+
406
+ console.print("\n[bold]Model Comparison: Baseline vs. L-RMC-core[/bold]")
407
+
408
+ # Create comparison table
409
+ comparison_table = Table(title="Performance Comparison", show_header=True, header_style="bold magenta")
410
+ comparison_table.add_column("Metric", style="cyan")
411
+ comparison_table.add_column("Baseline", style="magenta")
412
+ comparison_table.add_column("L-RMC-core", style="green")
413
+ comparison_table.add_column("Speedup", style="yellow")
414
+
415
+ # Add performance metrics
416
+ for metric in ["train", "val", "test"]:
417
+ comparison_table.add_row(
418
+ f"{metric.capitalize()} Accuracy",
419
+ f"{base_metrics[metric]:.4f}",
420
+ f"{core_metrics[metric]:.4f}",
421
+ "" # Speedup is not applicable for accuracy
422
+ )
423
+
424
+ # Add timing and speedup
425
+ baseline_time = t1 - t0
426
+ core_time = t3 - t2
427
+ speedup = baseline_time / core_time if core_time > 0 else float('inf')
428
+
429
+ comparison_table.add_row(
430
+ "Training Time (s)",
431
+ f"{baseline_time:.2f}",
432
+ f"{core_time:.2f}",
433
+ f"{speedup:.2f}x"
434
+ )
435
+
436
+ comparison_table.add_row(
437
+ "Speedup",
438
+ "1x",
439
+ f"{speedup:.2f}x",
440
+ ""
441
+ )
442
+
443
+ console.print(comparison_table)
444
+
445
+ # Optional warm‑start finetune (single run)
446
+ if args.warm_ft_epochs > 0:
447
+ warm = GCN2(in_dim=ds.num_node_features,
448
+ hid=args.hidden,
449
+ out_dim=ds.num_classes,
450
+ dropout=args.dropout)
451
+ warm.load_state_dict(mC.state_dict())
452
+
453
+ t4 = time.perf_counter()
454
+ train(warm, data,
455
+ epochs=args.warm_ft_epochs,
456
+ lr=args.warm_ft_lr,
457
+ wd=args.wd,
458
+ patience=args.warm_ft_epochs + 1)
459
+ t5 = time.perf_counter()
460
+ warm_metrics = eval_all(warm, data)
461
+
462
+ console.print("\n[bold]Warm‑start finetune (start from core model, train on FULL graph):[/bold]")
463
+ warm_table = Table(show_header=True, header_style="bold magenta")
464
+ warm_table.add_column("Metric", style="cyan")
465
+ warm_table.add_column("Value", style="magenta")
466
+ warm_table.add_row("Train Accuracy", f"{warm_metrics['train']:.4f}")
467
+ warm_table.add_row("Validation Accuracy", f"{warm_metrics['val']:.4f}")
468
+ warm_table.add_row("Test Accuracy", f"{warm_metrics['test']:.4f}")
469
+ warm_table.add_row("Finetune Time (s)", f"{t5 - t4:.2f}")
470
+ warm_table.add_row("Total (core train + warm)", f"{(t3 - t2 + t5 - t4):.2f}s")
471
+ console.print(warm_table)
472
+
473
+ # Save warm single-run metrics
474
+ results["single_run"]["warm_finetune"] = {
475
+ "train": float(warm_metrics["train"]),
476
+ "val": float(warm_metrics["val"]),
477
+ "test": float(warm_metrics["test"]),
478
+ "finetune_time_s": float(t5 - t4),
479
+ "total_time_s": float((t3 - t2) + (t5 - t4)),
480
+ }
481
+
482
+ # Emit results for single-run
483
+ maybe_save_results()
484
+ else:
485
+ # --------------------------------------------------------
486
+ # Multi-run: average metrics across different seeds
487
+ # --------------------------------------------------------
488
+ runs = args.runs
489
+ console.print(f"\n[bold]Running {runs} seeds and averaging results[/bold]")
490
+
491
+ # Storage for metrics across runs
492
+ base_train, base_val, base_test, base_time = [], [], [], []
493
+ core_train, core_val, core_test, core_time = [], [], [], []
494
+ speedups = []
495
+
496
+ warm_train, warm_val, warm_test, warm_time, warm_total_time = [], [], [], [], []
497
+
498
+ data_C = subset_data(data, C_idx) if C_idx.numel() > 0 else None
499
+ results["core_empty"] = data_C is None
500
+
501
+ for r in range(runs):
502
+ set_seed(r)
503
+
504
+ # Baseline
505
+ t0 = time.perf_counter()
506
+ base = GCN2(in_dim=ds.num_node_features,
507
+ hid=args.hidden,
508
+ out_dim=ds.num_classes,
509
+ dropout=args.dropout)
510
+ train(base, data, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
511
+ bm = eval_all(base, data)
512
+ t1 = time.perf_counter()
513
+
514
+ base_train.append(bm["train"]) ; base_val.append(bm["val"]) ; base_test.append(bm["test"]) ; base_time.append(t1 - t0)
515
+
516
+ # Core model
517
+ if data_C is None:
518
+ continue # no core available
519
+
520
+ t2 = time.perf_counter()
521
+ mC = GCN2(in_dim=ds.num_node_features,
522
+ hid=args.hidden,
523
+ out_dim=ds.num_classes,
524
+ dropout=args.dropout)
525
+ train(mC, data_C, epochs=args.epochs, lr=args.lr, wd=args.wd, patience=args.patience)
526
+ t3 = time.perf_counter()
527
+
528
+ if args.core_mode == "forward":
529
+ mC.eval()
530
+ logits_full = mC(data.x, data.edge_index)
531
+ else:
532
+ mC.eval()
533
+ with torch.no_grad():
534
+ logits_C = mC(data_C.x, data_C.edge_index)
535
+ logits_seed = torch.zeros(n, ds.num_classes, device=logits_C.device)
536
+ logits_seed[C_idx] = logits_C
537
+ logits_full = appnp_seed_propagate(logits_seed,
538
+ data.edge_index,
539
+ K=args.K,
540
+ alpha=args.alpha)
541
+
542
+ cm = {
543
+ "train": accuracy(logits_full, data.y, data.train_mask),
544
+ "val": accuracy(logits_full, data.y, data.val_mask),
545
+ "test": accuracy(logits_full, data.y, data.test_mask),
546
+ }
547
+
548
+ core_train.append(cm["train"]) ; core_val.append(cm["val"]) ; core_test.append(cm["test"]) ; core_time.append(t3 - t2)
549
+ speedups.append((t1 - t0) / (t3 - t2 + 1e-9))
550
+
551
+ # Optional warm finetune per run
552
+ if args.warm_ft_epochs > 0:
553
+ warm = GCN2(in_dim=ds.num_node_features,
554
+ hid=args.hidden,
555
+ out_dim=ds.num_classes,
556
+ dropout=args.dropout)
557
+ warm.load_state_dict(mC.state_dict())
558
+
559
+ t4 = time.perf_counter()
560
+ train(warm, data,
561
+ epochs=args.warm_ft_epochs,
562
+ lr=args.warm_ft_lr,
563
+ wd=args.wd,
564
+ patience=args.warm_ft_epochs + 1)
565
+ t5 = time.perf_counter()
566
+ wm = eval_all(warm, data)
567
+ warm_train.append(wm["train"]) ; warm_val.append(wm["val"]) ; warm_test.append(wm["test"]) ; warm_time.append(t5 - t4)
568
+ warm_total_time.append((t3 - t2) + (t5 - t4))
569
+
570
+ # Helper to format mean ± std
571
+ def fmt(values, prec=4):
572
+ if not values:
573
+ return "n/a"
574
+ if len(values) == 1:
575
+ return f"{values[0]:.{prec}f}"
576
+ try:
577
+ return f"{mean(values):.{prec}f} ± {stdev(values):.{prec}f}"
578
+ except Exception:
579
+ m = sum(values) / len(values)
580
+ var = sum((v - m) ** 2 for v in values) / max(1, len(values) - 1)
581
+ return f"{m:.{prec}f} ± {var ** 0.5:.{prec}f}"
582
+
583
+ def stats(values):
584
+ """Return dict with list, mean, std, count for JSON."""
585
+ d = {
586
+ "values": [float(v) for v in values],
587
+ "count": int(len(values)),
588
+ }
589
+ if len(values) >= 1:
590
+ d["mean"] = float(mean(values))
591
+ if len(values) >= 2:
592
+ d["std"] = float(stdev(values))
593
+ else:
594
+ d["std"] = None
595
+ return d
596
+
597
+ # Baseline summary
598
+ console.print("\n[bold]Baseline (averaged over runs):[/bold]")
599
+ base_table = Table(show_header=True, header_style="bold magenta")
600
+ base_table.add_column("Metric", style="cyan")
601
+ base_table.add_column("Mean ± Std", style="magenta")
602
+ base_table.add_row("Train Accuracy", fmt(base_train))
603
+ base_table.add_row("Validation Accuracy", fmt(base_val))
604
+ base_table.add_row("Test Accuracy", fmt(base_test))
605
+ base_table.add_row("Time (s)", fmt(base_time, prec=2))
606
+ console.print(base_table)
607
+
608
+ # Save baseline multi-run summary
609
+ results["multi_run"] = {
610
+ "runs": int(runs),
611
+ "baseline": {
612
+ "train": stats(base_train),
613
+ "val": stats(base_val),
614
+ "test": stats(base_test),
615
+ "time_s": stats(base_time),
616
+ }
617
+ }
618
+
619
+ if data_C is None:
620
+ console.print("[bold yellow]LRMC core is empty; no core runs to average.[/bold yellow]")
621
+ maybe_save_results()
622
+ return
623
+
624
+ # Core summary
625
+ console.print("\n[bold]LRMC‑core (averaged over runs):[/bold]")
626
+ core_table = Table(show_header=True, header_style="bold magenta")
627
+ core_table.add_column("Metric", style="cyan")
628
+ core_table.add_column("Mean ± Std", style="magenta")
629
+ core_table.add_row("Train Accuracy", fmt(core_train))
630
+ core_table.add_row("Validation Accuracy", fmt(core_val))
631
+ core_table.add_row("Test Accuracy", fmt(core_test))
632
+ core_table.add_row("Core Training Time (s)", fmt(core_time, prec=2))
633
+ core_table.add_row("Speedup vs. Baseline", fmt(speedups, prec=2))
634
+ console.print(core_table)
635
+
636
+ # Save core multi-run summary
637
+ results["multi_run"]["core_model"] = {
638
+ "mode": str(args.core_mode),
639
+ "train": stats(core_train),
640
+ "val": stats(core_val),
641
+ "test": stats(core_test),
642
+ "core_train_time_s": stats(core_time),
643
+ "speedup_vs_baseline": stats(speedups),
644
+ }
645
+
646
+ # Comparison summary
647
+ console.print("\n[bold]Model Comparison (averaged): Baseline vs. L-RMC-core[/bold]")
648
+ comparison_table = Table(title="Performance Comparison (Mean ± Std)", show_header=True, header_style="bold magenta")
649
+ comparison_table.add_column("Metric", style="cyan")
650
+ comparison_table.add_column("Baseline", style="magenta")
651
+ comparison_table.add_column("L-RMC-core", style="green")
652
+ comparison_table.add_column("Speedup", style="yellow")
653
+
654
+ for metric, b_vals, c_vals in [
655
+ ("Train Accuracy", base_train, core_train),
656
+ ("Validation Accuracy", base_val, core_val),
657
+ ("Test Accuracy", base_test, core_test),
658
+ ]:
659
+ comparison_table.add_row(metric, fmt(b_vals), fmt(c_vals), "")
660
+
661
+ comparison_table.add_row("Training Time (s)", fmt(base_time, prec=2), fmt(core_time, prec=2), fmt(speedups, prec=2))
662
+ comparison_table.add_row("Speedup", "1x", fmt(speedups, prec=2), "")
663
+ console.print(comparison_table)
664
+
665
+ # Optional warm summary
666
+ if args.warm_ft_epochs > 0 and warm_time:
667
+ console.print("\n[bold]Warm‑start finetune (averaged over runs):[/bold]")
668
+ warm_table = Table(show_header=True, header_style="bold magenta")
669
+ warm_table.add_column("Metric", style="cyan")
670
+ warm_table.add_column("Mean ± Std", style="magenta")
671
+ warm_table.add_row("Train Accuracy", fmt(warm_train))
672
+ warm_table.add_row("Validation Accuracy", fmt(warm_val))
673
+ warm_table.add_row("Test Accuracy", fmt(warm_test))
674
+ warm_table.add_row("Finetune Time (s)", fmt(warm_time, prec=2))
675
+ warm_table.add_row("Total (core train + warm)", fmt(warm_total_time, prec=2))
676
+ console.print(warm_table)
677
+
678
+ # Save warm multi-run summary
679
+ results["multi_run"]["warm_finetune"] = {
680
+ "train": stats(warm_train),
681
+ "val": stats(warm_val),
682
+ "test": stats(warm_test),
683
+ "finetune_time_s": stats(warm_time),
684
+ "total_time_s": stats(warm_total_time),
685
+ }
686
+
687
+ # Emit results for multi-run
688
+ maybe_save_results()
689
+
690
+ if __name__ == "__main__":
691
+ main()
src/2_epsilon_seed_sweep.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import os
3
+ import shutil
4
+ from typing import List, Set, Dict, Tuple
5
+ from decimal import Decimal, getcontext
6
+ from rich import print
7
+
8
+ from generate_lrmc_seeds import build_lrmc_single_graph
9
+
10
def get_seed_nodes(seeds_path: str) -> Set[int]:
    """Collect the union of all seed nodes listed in a seeds JSON file.

    Each cluster entry may store its nodes under 'seed_nodes' or, as a
    fallback, under 'members'. On any read/parse error an empty set is
    returned and the error is reported.
    """
    try:
        with open(seeds_path, 'r') as fh:
            payload = json.load(fh)

        collected: Set[int] = set()
        for cluster in payload.get('clusters', []):
            members = cluster.get('seed_nodes')
            if members is None:
                members = cluster.get('members', [])
            collected.update(members)

        return collected
    except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
        print(f"[red]Error reading {seeds_path}: {e}[/red]")
        return set()
31
+
32
+ def _format_eps_label(val: Decimal) -> str:
33
+ """Return a stable, unique string label for epsilon values.
34
+
35
+ - Use integer string for integral values (e.g., '50000').
36
+ - Otherwise, use a compact decimal without trailing zeros.
37
+ This avoids duplicates like many '1e+04' when step is small.
38
+ """
39
+ # Normalize to remove exponent if integral
40
+ if val == val.to_integral_value():
41
+ return str(val.to_integral_value())
42
+ # Use 'f' then strip trailing zeros/decimal point for uniqueness and readability
43
+ s = format(val, 'f')
44
+ if '.' in s:
45
+ s = s.rstrip('0').rstrip('.')
46
+ return s
47
+
48
def generate_epsilon_range(start: float, end: float, step: float) -> List[str]:
    """Produce the epsilon sweep values as unique, stable string labels.

    Decimal arithmetic avoids float accumulation across repeated addition,
    and consecutive duplicate labels are dropped so each value appears once.

    Raises:
        ValueError: if ``step`` is not strictly positive.
    """
    if step <= 0:
        raise ValueError("epsilon_step must be > 0")

    getcontext().prec = 28
    lo = Decimal(str(start))
    hi = Decimal(str(end))
    inc = Decimal(str(step))

    labels: List[str] = []
    current = lo
    # Tiny slack so the endpoint itself survives decimal rounding.
    upper = hi + Decimal('1e-18')
    while current <= upper:
        # Label formatting (inlined): integer string for integral values,
        # otherwise fixed-point with trailing zeros stripped.
        integral = current.to_integral_value()
        if current == integral:
            label = str(integral)
        else:
            label = format(current, 'f')
            if '.' in label:
                label = label.rstrip('0').rstrip('.')
        if not labels or labels[-1] != label:
            labels.append(label)
        current += inc
    return labels
70
+
71
def run_epsilon_sweep(input_edgelist: str, out_dir: str, levels: int,
                      epsilon_start: float = 1e4, epsilon_end: float = 5e5,
                      epsilon_step: float = 1e4, cleanup_duplicates: bool = True):
    """
    Run LRMC for multiple epsilon values and remove duplicate results.

    Args:
        input_edgelist: Path to input edgelist file
        out_dir: Output directory
        levels: Number of levels to build
        epsilon_start: Starting epsilon value (default: 1e4)
        epsilon_end: Ending epsilon value (default: 5e5)
        epsilon_step: Step size for epsilon (default: 1e4)
        cleanup_duplicates: Whether to remove duplicate seed sets (default: True).
            BUGFIX: this flag was previously accepted but ignored — duplicates
            were always removed. It is now honored.
    """
    print(f"[blue]Starting epsilon sweep from {epsilon_start} to {epsilon_end} with step {epsilon_step}[/blue]")

    # Preflight: check the input edgelist path and recover from the common
    # '.tx' -> '.txt' typo, or resolve a bare filename relative to the CWD.
    if not os.path.isfile(input_edgelist):
        fixed_path = None
        if input_edgelist.endswith('.tx') and os.path.isfile(input_edgelist + 't'):
            fixed_path = input_edgelist + 't'
        elif input_edgelist.endswith('.txt'):
            alt = os.path.join(os.getcwd(), input_edgelist)
            if os.path.isfile(alt):
                fixed_path = alt
        if fixed_path:
            print(f"[yellow]Input edgelist not found at '{input_edgelist}'. Using '{fixed_path}' instead.[/yellow]")
            input_edgelist = fixed_path
        else:
            raise FileNotFoundError(f"Input edgelist not found: '{input_edgelist}'. Did you mean '.txt'?")

    # Generate epsilon values
    epsilons = generate_epsilon_range(epsilon_start, epsilon_end, epsilon_step)
    print(f"[blue]Will test {len(epsilons)} epsilon values: {epsilons}[/blue]")

    # Map each unique (sorted) seed-node tuple to the first epsilon that produced it.
    seen_seed_sets: Dict[Tuple[int, ...], str] = {}

    for epsilon in epsilons:
        print(f"[yellow]Processing epsilon: {epsilon}[/yellow]")
        temp_out_dir = f"{out_dir}_temp_{epsilon}"

        try:
            # Run LRMC into a temporary directory for this epsilon.
            seeds_path = build_lrmc_single_graph(
                input_edgelist=input_edgelist,
                out_dir=temp_out_dir,
                levels=levels,
                epsilon=epsilon
            )

            # Get seed nodes
            seed_nodes = get_seed_nodes(seeds_path)
            seed_nodes_tuple = tuple(sorted(seed_nodes))

            print(f"[green]Epsilon {epsilon}: Found {len(seed_nodes)} unique seed nodes[/green]")

            # Check if this seed set has been seen before.
            if seed_nodes_tuple in seen_seed_sets:
                existing_epsilon = seen_seed_sets[seed_nodes_tuple]
                print(f"[yellow]Duplicate seed set found! Epsilon {epsilon} has same seeds as {existing_epsilon}[/yellow]")
                if cleanup_duplicates:
                    print(f"[yellow]Removing duplicate results for epsilon {epsilon}[/yellow]")
                    if os.path.exists(temp_out_dir):
                        shutil.rmtree(temp_out_dir)
                    continue
                # cleanup disabled: keep the duplicate's results on disk, but do
                # not overwrite the epsilon recorded for this seed set.
            else:
                seen_seed_sets[seed_nodes_tuple] = epsilon

            # Move results to final location
            final_out_dir = f"{out_dir}_epsilon_{epsilon}"
            if os.path.exists(final_out_dir):
                shutil.rmtree(final_out_dir)
            shutil.move(temp_out_dir, final_out_dir)

            # Move seeds_<epsilon>.json into the shared stage0 directory.
            seeds_file = os.path.join(final_out_dir, "stage0", f"seeds_{epsilon}.json")
            if os.path.exists(seeds_file):
                stage0_dir = os.path.join(out_dir, "stage0")
                os.makedirs(stage0_dir, exist_ok=True)
                dest_file = os.path.join(stage0_dir, f"seeds_{epsilon}.json")
                shutil.move(seeds_file, dest_file)
                # BUGFIX: this message previously sat outside the guard that
                # defines stage0_dir (potential NameError when the seeds file is
                # missing) and nested same-type quotes inside an f-string, which
                # is a SyntaxError before Python 3.12.
                print(f"[green]Unique results saved to {dest_file}[/green]")

        except Exception as e:
            print(f"[red]Error processing epsilon {epsilon}: {e}[/red]")
            # Clean up temporary directory if it exists
            if os.path.exists(temp_out_dir):
                shutil.rmtree(temp_out_dir)

    # Print summary
    print("\n[blue]--- Summary ---[/blue]")
    print(f"[blue]Total epsilon values tested: {len(epsilons)}[/blue]")
    print(f"[blue]Unique seed sets found: {len(seen_seed_sets)}[/blue]")
    print(f"[blue]Duplicates removed: {len(epsilons) - len(seen_seed_sets)}[/blue]")

    if seen_seed_sets:
        print("\n[green]Unique epsilon values kept:[/green]")
        for seed_tuple, epsilon in sorted(seen_seed_sets.items()):
            seed_count = len(seed_tuple)
            print(f"  {epsilon}: {seed_count} seed nodes")
181
+
182
def main():
    """Parse command-line options and launch the epsilon sweep."""
    import argparse

    parser = argparse.ArgumentParser(description="Run LRMC epsilon sweep with duplicate removal")
    parser.add_argument('--input_edgelist', type=str, required=True,
                        help='Path to input edgelist file')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='Base output directory (results will be saved as out_dir_epsilon_X)')
    parser.add_argument('--levels', type=int, required=True,
                        help='Number of levels to build')
    parser.add_argument('--epsilon_start', type=float, default=1e4,
                        help='Starting epsilon value (default: 1e4)')
    parser.add_argument('--epsilon_end', type=float, default=5e5,
                        help='Ending epsilon value (default: 5e5)')
    parser.add_argument('--epsilon_step', type=float, default=1e4,
                        help='Epsilon step size (default: 1e4)')
    parser.add_argument('--no_cleanup', action='store_true',
                        help='Do not remove duplicates (keep all results)')

    opts = parser.parse_args()

    run_epsilon_sweep(
        input_edgelist=opts.input_edgelist,
        out_dir=opts.out_dir,
        levels=opts.levels,
        epsilon_start=opts.epsilon_start,
        epsilon_end=opts.epsilon_end,
        epsilon_step=opts.epsilon_step,
        cleanup_duplicates=not opts.no_cleanup,
    )


if __name__ == '__main__':
    main()
src/2_lrmc_bilevel.py CHANGED
@@ -1,6 +1,4 @@
1
- # lrmc_bilevel.py
2
  # Bi-level Node↔Cluster message passing with fixed LRMC seeds
3
- # Requires: torch, torch_geometric, torch_sparse
4
 
5
  import argparse, json, os
6
  from pathlib import Path
@@ -19,6 +17,8 @@ from torch_geometric.loader import DataLoader
19
  from torch_geometric.datasets import Planetoid, TUDataset
20
  from torch_geometric.nn import GCNConv, global_mean_pool
21
 
 
 
22
 
23
  # ---------------------------
24
  # Utilities: edges and seeds
 
 
1
  # Bi-level Node↔Cluster message passing with fixed LRMC seeds
 
2
 
3
  import argparse, json, os
4
  from pathlib import Path
 
17
  from torch_geometric.datasets import Planetoid, TUDataset
18
  from torch_geometric.nn import GCNConv, global_mean_pool
19
 
20
+ from rich import print
21
+
22
 
23
  # ---------------------------
24
  # Utilities: edges and seeds
src/2_random_seed_sweep.sh ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/usr/bin/env bash
# ------------------------------------------------------------
# 2_random_seed_sweep.sh
#
# Sweep generate_lrmc_seeds.py over num_nodes = 1, 101, 201, …,
# 2601, 2701, then run once more for the full graph (2708).
#
# Usage: ./2_random_seed_sweep.sh
# ------------------------------------------------------------

set -euo pipefail  # Abort on errors, unset variables, and pipeline failures

# ------------------------------------------------------------------
# Configuration – adjust these if your file layout changes
# ------------------------------------------------------------------
INPUT_EDGELIST="cora_seeds/edgelist.txt"
OUT_DIR="cora_seeds"
LEVELS=1
BASELINE="random"

# Invoke the seed generator for a single num_nodes value.
run_sweep() {
    python3 generate_lrmc_seeds.py \
        --input_edgelist "$INPUT_EDGELIST" \
        --out_dir "$OUT_DIR" \
        --levels "$LEVELS" \
        --baseline "$BASELINE" \
        --num_nodes "$1"
}

echo "Starting loop over num_nodes = 1, 101, 201, …, 2601, 2701 …"

for NUM_NODES in $(seq 1 100 2708); do
    echo ">>> num_nodes=$NUM_NODES"
    run_sweep "$NUM_NODES"
done

# 2708 itself is never reached by the 100-step sequence, so run it explicitly.
echo ">>> num_nodes=2708 (explicitly added)"
run_sweep 2708

echo "All done!"