ho22joshua commited on
Commit
3676572
·
1 Parent(s): 5ca4f82

added setup scripts

Browse files
root_gnn_dgl/README.md ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # root_gnn_dgl
2
+
3
+ ## Environment Setup
4
+
5
+ The environment dependencies for this project are listed in `environment.yml`. Follow the steps below to set up the environment:
6
+
7
+ ### Step 1: Install Conda
8
+ If you don’t already have Conda installed, install either Miniconda (lightweight) or Anaconda (full version):
9
+
10
+ - **Miniconda**: Download and install from [https://docs.conda.io/en/latest/miniconda.html](https://docs.conda.io/en/latest/miniconda.html).
11
+ - **Anaconda**: Download and install from [https://www.anaconda.com/products/distribution](https://www.anaconda.com/products/distribution).
12
+
13
+ ### Step 2: Clone the Repository
14
+ Clone this repository to your local machine:
15
+ ```bash
16
+ git lfs install
17
+ git clone https://huggingface.co/HWresearch/GNN4Colliders
18
+ ```
19
+ If you want to clone without large files - just their pointers
20
+ ```bash
21
+ GIT_LFS_SKIP_SMUDGE=1 git clone https://huggingface.co/HWresearch/GNN4Colliders
22
+ ```
23
+
24
+ ### Step 3: Create the Conda Environment
25
+ Use the `environment.yml` file to create the Conda environment:
26
+ ```bash
27
+ conda env create -f setup/environment.yml -n <environment_name>
28
+ ```
29
+
30
+ ### Step 4: Activate the Environment
31
+ Activate the newly created environment:
32
+ ```bash
33
+ conda activate <environment_name>
34
+ ```
35
+ Replace <environment_name> with the name of the environment specified in Step 4.
36
+
37
+ ### Step 5: Test the Environment
38
+ Run the `setup/test_setup.py` script to confirm that all packages needed for training are properly set up.
39
+ ```bash
40
+ python setup/test_setup.py
41
+ ```
root_gnn_dgl/models/__pycache__/GCN.cpython-38.pyc DELETED
Binary file (57 kB)
 
root_gnn_dgl/models/__pycache__/loss.cpython-38.pyc DELETED
Binary file (11.4 kB)
 
root_gnn_dgl/scripts/training_script.py CHANGED
@@ -29,8 +29,6 @@ import torch.multiprocessing as mp
29
  from torch.utils.data.distributed import DistributedSampler
30
  from torch.nn.parallel import DistributedDataParallel as DDP
31
 
32
- print("import time: {:.4f} s".format(time.time() - start_time))
33
-
34
  def mem():
35
  print(f'Current memory usage: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024} GB')
36
 
 
29
  from torch.utils.data.distributed import DistributedSampler
30
  from torch.nn.parallel import DistributedDataParallel as DDP
31
 
 
 
32
  def mem():
33
  print(f'Current memory usage: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1024 / 1024} GB')
34
 
root_gnn_dgl/setup/environment.yml ADDED
@@ -0,0 +1,391 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: dgl
2
+ channels:
3
+ - pytorch
4
+ - dglteam/label/cu118
5
+ - nvidia
6
+ - conda-forge
7
+ - defaults
8
+ dependencies:
9
+ - _libgcc_mutex=0.1
10
+ - _openmp_mutex=4.5
11
+ - _sysroot_linux-64_curr_repodata_hack=3
12
+ - afterimage=1.21
13
+ - anyio=3.7.1
14
+ - appdirs=1.4.4
15
+ - argon2-cffi=21.3.0
16
+ - argon2-cffi-bindings=21.2.0
17
+ - arrow=1.2.3
18
+ - asttokens=2.2.1
19
+ - async-lru=2.0.4
20
+ - atk-1.0=2.38.0
21
+ - attrs=23.1.0
22
+ - awkward-pandas=2023.8.0
23
+ - aws-c-auth=0.7.0
24
+ - aws-c-cal=0.6.0
25
+ - aws-c-common=0.8.23
26
+ - aws-c-compression=0.2.17
27
+ - aws-c-event-stream=0.3.1
28
+ - aws-c-http=0.7.11
29
+ - aws-c-io=0.13.28
30
+ - aws-c-mqtt=0.8.14
31
+ - aws-c-s3=0.3.13
32
+ - aws-c-sdkutils=0.1.11
33
+ - aws-checksums=0.1.16
34
+ - aws-crt-cpp=0.20.3
35
+ - aws-sdk-cpp=1.10.57
36
+ - babel=2.12.1
37
+ - backcall=0.2.0
38
+ - backports=1.0
39
+ - backports.functools_lru_cache=1.6.5
40
+ - beautifulsoup4=4.12.2
41
+ - binutils=2.38
42
+ - binutils_impl_linux-64=2.38
43
+ - binutils_linux-64=2.38.0
44
+ - blas=1.0
45
+ - bleach=6.0.0
46
+ - brotlipy=0.7.0
47
+ - bzip2=1.0.8
48
+ - c-ares=1.19.1
49
+ - c-compiler=1.5.2
50
+ - ca-certificates=2025.4.26
51
+ - cached-property=1.5.2
52
+ - cached_property=1.5.2
53
+ - cairo=1.16.0
54
+ - certifi=2024.8.30
55
+ - cffi=1.15.1
56
+ - cfitsio=4.2.0
57
+ - charset-normalizer=2.0.4
58
+ - comm=0.1.4
59
+ - compilers=1.5.2
60
+ - cryptography=41.0.2
61
+ - cuda-cudart=11.8.89
62
+ - cuda-cupti=11.8.87
63
+ - cuda-libraries=11.8.0
64
+ - cuda-nvrtc=11.8.89
65
+ - cuda-nvtx=11.8.86
66
+ - cuda-runtime=11.8.0
67
+ - cxx-compiler=1.5.2
68
+ - davix=0.8.4
69
+ - debugpy=1.6.8
70
+ - decorator=5.1.1
71
+ - defusedxml=0.7.1
72
+ - dgl=1.1.1.cu118
73
+ - entrypoints=0.4
74
+ - exceptiongroup=1.1.3
75
+ - executing=1.2.0
76
+ - expat=2.5.0
77
+ - ffmpeg=4.3
78
+ - fftw=3.3.10
79
+ - filelock=3.9.0
80
+ - flit-core=3.9.0
81
+ - font-ttf-dejavu-sans-mono=2.37
82
+ - font-ttf-inconsolata=3.000
83
+ - font-ttf-source-code-pro=2.038
84
+ - font-ttf-ubuntu=0.83
85
+ - fontconfig=2.14.2
86
+ - fonts-conda-ecosystem=1
87
+ - fonts-conda-forge=1
88
+ - fortran-compiler=1.5.2
89
+ - fqdn=1.5.1
90
+ - freetype=2.12.1
91
+ - fribidi=1.0.10
92
+ - ftgl=2.4.0
93
+ - gcc=11.2.0
94
+ - gcc_impl_linux-64=11.2.0
95
+ - gcc_linux-64=11.2.0
96
+ - gdk-pixbuf=2.42.8
97
+ - gettext=0.21.1
98
+ - gflags=2.2.2
99
+ - gfortran=11.2.0
100
+ - gfortran_impl_linux-64=11.2.0
101
+ - gfortran_linux-64=11.2.0
102
+ - giflib=5.2.1
103
+ - gl2ps=1.4.2
104
+ - glew=2.1.0
105
+ - glog=0.6.0
106
+ - gmp=6.2.1
107
+ - gmpy2=2.1.2
108
+ - gnutls=3.6.15
109
+ - graphite2=1.3.13
110
+ - graphviz=6.0.2
111
+ - gsl=2.7
112
+ - gsoap=2.8.123
113
+ - gtk2=2.24.33
114
+ - gts=0.7.6
115
+ - gxx=11.2.0
116
+ - gxx_impl_linux-64=11.2.0
117
+ - gxx_linux-64=11.2.0
118
+ - harfbuzz=7.3.0
119
+ - icu=72.1
120
+ - idna=3.4
121
+ - importlib-metadata=6.8.0
122
+ - importlib-resources=6.0.1
123
+ - importlib_metadata=6.8.0
124
+ - importlib_resources=6.0.1
125
+ - intel-openmp=2023.1.0
126
+ - ipykernel=6.25.1
127
+ - ipyparallel=8.6.1
128
+ - ipython=8.12.2
129
+ - isoduration=20.11.0
130
+ - jedi=0.19.0
131
+ - jinja2=3.1.2
132
+ - jpeg=9e
133
+ - json5=0.9.14
134
+ - jsonpointer=2.0
135
+ - jsonschema=4.19.0
136
+ - jsonschema-specifications=2023.7.1
137
+ - jsonschema-with-format-nongpl=4.19.0
138
+ - jupyter-lsp=2.2.0
139
+ - jupyter_client=8.3.0
140
+ - jupyter_core=5.3.0
141
+ - jupyter_events=0.7.0
142
+ - jupyter_server=2.7.0
143
+ - jupyter_server_terminals=0.4.4
144
+ - jupyterlab=4.0.5
145
+ - jupyterlab_pygments=0.2.2
146
+ - jupyterlab_server=2.24.0
147
+ - kernel-headers_linux-64=3.10.0
148
+ - keyutils=1.6.1
149
+ - krb5=1.20.1
150
+ - lame=3.100
151
+ - lcms2=2.12
152
+ - ld_impl_linux-64=2.38
153
+ - lerc=3.0
154
+ - libabseil=20230125.3
155
+ - libarrow=12.0.1
156
+ - libblas=3.9.0
157
+ - libbrotlicommon=1.0.9
158
+ - libbrotlidec=1.0.9
159
+ - libbrotlienc=1.0.9
160
+ - libcblas=3.9.0
161
+ - libcrc32c=1.1.2
162
+ - libcublas=11.11.3.6
163
+ - libcufft=10.9.0.58
164
+ - libcufile=1.7.1.12
165
+ - libcurand=10.3.3.129
166
+ - libcurl=8.1.2
167
+ - libcusolver=11.4.1.48
168
+ - libcusparse=11.7.5.86
169
+ - libcxx=15.0.7
170
+ - libcxxabi=15.0.7
171
+ - libdeflate=1.12
172
+ - libedit=3.1.20191231
173
+ - libev=4.33
174
+ - libevent=2.1.12
175
+ - libexpat=2.5.0
176
+ - libffi=3.4.4
177
+ - libgcc-devel_linux-64=11.2.0
178
+ - libgcc-ng=13.1.0
179
+ - libgd=2.3.3
180
+ - libgfortran-ng=11.2.0
181
+ - libgfortran5=11.2.0
182
+ - libglib=2.76.4
183
+ - libglu=9.0.0
184
+ - libgomp=13.1.0
185
+ - libgoogle-cloud=2.12.0
186
+ - libgrpc=1.56.2
187
+ - libiconv=1.17
188
+ - libidn2=2.3.4
189
+ - libllvm13=13.0.1
190
+ - libllvm14=14.0.6
191
+ - libnghttp2=1.52.0
192
+ - libnpp=11.8.0.86
193
+ - libnsl=2.0.0
194
+ - libnuma=2.0.18
195
+ - libnvjpeg=11.9.0.86
196
+ - libpng=1.6.39
197
+ - libprotobuf=4.23.3
198
+ - librsvg=2.54.4
199
+ - libsodium=1.0.18
200
+ - libsqlite=3.42.0
201
+ - libssh2=1.11.0
202
+ - libstdcxx-devel_linux-64=11.2.0
203
+ - libstdcxx-ng=13.1.0
204
+ - libtasn1=4.19.0
205
+ - libthrift=0.18.1
206
+ - libtiff=4.4.0
207
+ - libtool=2.4.7
208
+ - libunistring=0.9.10
209
+ - libutf8proc=2.8.0
210
+ - libuuid=2.38.1
211
+ - libwebp=1.2.4
212
+ - libwebp-base=1.2.4
213
+ - libxcb=1.15
214
+ - libxml2=2.10.4
215
+ - libzlib=1.2.13
216
+ - llvmlite=0.40.1
217
+ - lz4-c=1.9.4
218
+ - markupsafe=2.1.1
219
+ - matplotlib-inline=0.1.6
220
+ - metakernel=0.29.5
221
+ - mistune=3.0.0
222
+ - mkl=2023.1.0
223
+ - mkl-service=2.4.0
224
+ - mkl_fft=1.3.6
225
+ - mkl_random=1.2.2
226
+ - mpc=1.1.0
227
+ - mpfr=4.0.2
228
+ - mpmath=1.3.0
229
+ - nbclient=0.8.0
230
+ - nbconvert-core=7.7.3
231
+ - nbformat=5.9.2
232
+ - ncurses=6.4
233
+ - nest-asyncio=1.5.6
234
+ - nettle=3.7.3
235
+ - networkx=3.1
236
+ - nlohmann_json=3.11.2
237
+ - notebook=7.0.2
238
+ - notebook-shim=0.2.3
239
+ - numba=0.57.1
240
+ - numpy=1.24.3
241
+ - numpy-base=1.24.3
242
+ - openh264=2.1.1
243
+ - openssl=3.3.1
244
+ - orc=1.9.0
245
+ - overrides=7.4.0
246
+ - packaging=23.0
247
+ - pandas=2.0.3
248
+ - pandocfilters=1.5.0
249
+ - pango=1.50.14
250
+ - parso=0.8.3
251
+ - pcre=8.45
252
+ - pcre2=10.40
253
+ - pexpect=4.8.0
254
+ - pickleshare=0.7.5
255
+ - pillow=9.4.0
256
+ - pip=23.2.1
257
+ - pixman=0.40.0
258
+ - pkgutil-resolve-name=1.3.10
259
+ - platformdirs=2.6.0
260
+ - pooch=1.4.0
261
+ - portalocker=2.7.0
262
+ - prometheus_client=0.17.1
263
+ - prompt-toolkit=3.0.39
264
+ - prompt_toolkit=3.0.39
265
+ - psutil=5.9.0
266
+ - pthread-stubs=0.4
267
+ - ptyprocess=0.7.0
268
+ - pure_eval=0.2.2
269
+ - pyarrow=12.0.1
270
+ - pycparser=2.21
271
+ - pygments=2.16.1
272
+ - pyopenssl=23.2.0
273
+ - pysocks=1.7.1
274
+ - pythia8=8.309
275
+ - python=3.8.17
276
+ - python-dateutil=2.8.2
277
+ - python-fastjsonschema=2.18.0
278
+ - python-json-logger=2.0.7
279
+ - python-tzdata=2024.2
280
+ - python_abi=3.8
281
+ - pytorch=2.0.1
282
+ - pytorch-cuda=11.8
283
+ - pytorch-mutex=1.0
284
+ - pytz=2023.3
285
+ - pyyaml=6.0
286
+ - pyzmq=25.1.1
287
+ - rdma-core=28.9
288
+ - re2=2023.03.02
289
+ - readline=8.2
290
+ - referencing=0.30.2
291
+ - requests=2.31.0
292
+ - rfc3339-validator=0.1.4
293
+ - rfc3986-validator=0.1.1
294
+ - root=6.28.0
295
+ - root_base=6.28.0
296
+ - rpds-py=0.9.2
297
+ - s2n=1.3.46
298
+ - scipy=1.10.1
299
+ - scitokens-cpp=0.7.3
300
+ - send2trash=1.8.2
301
+ - setuptools=68.0.0
302
+ - six=1.16.0
303
+ - snappy=1.1.10
304
+ - sniffio=1.3.0
305
+ - soupsieve=2.3.2.post1
306
+ - sqlite=3.41.2
307
+ - stack_data=0.6.2
308
+ - sympy=1.11.1
309
+ - sysroot_linux-64=2.17
310
+ - tbb=2021.8.0
311
+ - terminado=0.17.1
312
+ - tinycss2=1.2.1
313
+ - tk=8.6.12
314
+ - tomli=2.0.1
315
+ - torchaudio=2.0.2
316
+ - torchtriton=2.0.0
317
+ - torchvision=0.15.2
318
+ - tornado=6.3.2
319
+ - tqdm=4.65.0
320
+ - traitlets=5.9.0
321
+ - typing_extensions=4.12.2
322
+ - typing_utils=0.1.0
323
+ - ucx=1.14.1
324
+ - uri-template=1.3.0
325
+ - urllib3=1.26.16
326
+ - vdt=0.4.3
327
+ - vector-classes=1.4.3
328
+ - wcwidth=0.2.6
329
+ - webcolors=1.13
330
+ - webencodings=0.5.1
331
+ - websocket-client=1.6.1
332
+ - wheel=0.38.4
333
+ - xorg-fixesproto=5.0
334
+ - xorg-kbproto=1.0.7
335
+ - xorg-libice=1.1.1
336
+ - xorg-libsm=1.2.4
337
+ - xorg-libx11=1.8.6
338
+ - xorg-libxau=1.0.11
339
+ - xorg-libxcursor=1.2.0
340
+ - xorg-libxdmcp=1.1.3
341
+ - xorg-libxext=1.3.4
342
+ - xorg-libxfixes=5.0.3
343
+ - xorg-libxft=2.3.8
344
+ - xorg-libxpm=3.5.16
345
+ - xorg-libxrender=0.9.11
346
+ - xorg-libxt=1.3.0
347
+ - xorg-renderproto=0.11.1
348
+ - xorg-xextproto=7.3.0
349
+ - xorg-xproto=7.0.31
350
+ - xrootd=5.5.4
351
+ - xxhash=0.8.1
352
+ - xz=5.2.6
353
+ - yaml=0.2.5
354
+ - zeromq=4.3.4
355
+ - zipp=3.16.2
356
+ - zlib=1.2.13
357
+ - zstd=1.5.2
358
+ - pip:
359
+ - awkward==2.6.4
360
+ - awkward-cpp==33
361
+ - contourpy==1.1.0
362
+ - cramjam==2.8.3
363
+ - cycler==0.11.0
364
+ - fonttools==4.42.0
365
+ - fsspec==2024.3.1
366
+ - h5py==3.9.0
367
+ - pip-install==1.3.5
368
+ - joblib==1.3.2
369
+ - kiwisolver==1.4.4
370
+ - matplotlib==3.7.2
371
+ - nvidia-cublas-cu12==12.1.3.1
372
+ - nvidia-cuda-cupti-cu12==12.1.105
373
+ - nvidia-cuda-nvrtc-cu12==12.1.105
374
+ - nvidia-cuda-runtime-cu12==12.1.105
375
+ - nvidia-cudnn-cu12==8.9.2.26
376
+ - nvidia-cufft-cu12==11.0.2.54
377
+ - nvidia-curand-cu12==10.3.2.106
378
+ - nvidia-cusolver-cu12==11.4.5.107
379
+ - nvidia-cusparse-cu12==12.1.0.106
380
+ - nvidia-nccl-cu12==2.20.5
381
+ - nvidia-nvjitlink-cu12==12.4.127
382
+ - nvidia-nvtx-cu12==12.1.105
383
+ - pyparsing==3.0.9
384
+ - scikit-learn==1.3.0
385
+ - threadpoolctl==3.2.0
386
+ - torch==2.3.0
387
+ - triton==2.3.0
388
+ - typing-extensions==4.11.0
389
+ - tzdata==2024.1
390
+ - uproot==5.3.7
391
+ prefix: /global/homes/j/joshuaho/.conda/envs/dgl
root_gnn_dgl/setup/test_setup.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import importlib.util
3
+ import sys
4
+
5
+ def test_imports(directories):
6
+ """
7
+ Test importing all Python files in the specified directories.
8
+
9
+ Parameters:
10
+ - directories: List of directory paths to test.
11
+ """
12
+ print("Testing Conda environment...")
13
+
14
+ for directory in directories:
15
+ print(f"\nChecking directory: {directory}")
16
+
17
+ # Check if the directory exists
18
+ if not os.path.isdir(directory):
19
+ print(f"Directory not found: {directory}")
20
+ continue
21
+
22
+ # Iterate through all files in the directory
23
+ for filename in os.listdir(directory):
24
+ # Only consider Python files
25
+ if filename.endswith(".py"):
26
+ filepath = os.path.join(directory, filename)
27
+ module_name = os.path.splitext(filename)[0] # Remove .py extension
28
+
29
+ try:
30
+ # Dynamically import the module
31
+ spec = importlib.util.spec_from_file_location(module_name, filepath)
32
+ module = importlib.util.module_from_spec(spec)
33
+ spec.loader.exec_module(module)
34
+ print(f"Successfully imported: {filepath}")
35
+ except Exception as e:
36
+ # Print the file and the error message if import fails
37
+ print(f"Failed to import: {filepath}")
38
+ print(f"Error: {e}")
39
+
40
+ if __name__ == "__main__":
41
+ # Automatically append the current directory to sys.path
42
+ current_directory = os.getcwd()
43
+ sys.path.append(current_directory)
44
+ print(f"Current directory added to sys.path: {current_directory}")
45
+
46
+ # List of directories to check
47
+ directories = ["scripts", "root_gnn_base", "models"]
48
+ test_imports(directories)