Xsmos commited on
Commit
d254527
·
verified ·
1 Parent(s): 537cb3b
diffusion.py CHANGED
@@ -30,7 +30,7 @@
30
  import logging
31
  #logging.getLogger("torch").setLevel(logging.ERROR)
32
  import warnings
33
- #warnings.filterwarnings("ignore", message=r"^Detected kernel version")
34
 
35
  from dataclasses import dataclass
36
  #import h5py
@@ -269,11 +269,12 @@ class TrainConfig:
269
  # dim = 2
270
  dim = 3#2
271
  stride = (2,4) if dim == 2 else (2,2,2)
272
- num_image = 2000#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
273
  batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
274
- n_epoch = 20#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
275
  HII_DIM = 64
276
  num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
 
277
  channel = 1
278
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
279
 
@@ -444,6 +445,7 @@ class DDPM21CM:
444
  idx = "random",#'range',
445
  HII_DIM=self.config.HII_DIM,
446
  num_redshift=self.config.num_redshift,
 
447
  drop_prob=self.config.drop_prob,
448
  dim=self.config.dim,
449
  ranges_dict=self.ranges_dict,
@@ -740,7 +742,7 @@ def generate_samples(rank, world_size, local_world_size, master_addr, master_por
740
 
741
  if __name__ == "__main__":
742
  parser = argparse.ArgumentParser()
743
- parser.add_argument("--train", type=int, required=False, help="whether to train the model", default=1)
744
  #parser.add_argument("--sample", type=int, required=False, help="whether to sample", default=0)
745
  parser.add_argument("--resume", type=str, required=False, help="filename of the model to resume", default=False)
746
  parser.add_argument("--num_new_img_per_gpu", type=int, required=False, default=4)
@@ -758,7 +760,8 @@ if __name__ == "__main__":
758
  config = TrainConfig()
759
  config.gradient_accumulation_steps = args.gradient_accumulation_steps
760
  ############################ training ################################
761
- if args.train == 1:
 
762
  print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(120,'-'))
763
  mp.spawn(
764
  train,
@@ -767,7 +770,7 @@ if __name__ == "__main__":
767
  join=True,
768
  )
769
  ############################ sampling ################################
770
- if args.train == 0:
771
  num_new_img_per_gpu = args.num_new_img_per_gpu#200#4#200
772
  max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
773
  #config = TrainConfig()
 
30
  import logging
31
  #logging.getLogger("torch").setLevel(logging.ERROR)
32
  import warnings
33
+ warnings.filterwarnings("ignore", category=FutureWarning)
34
 
35
  from dataclasses import dataclass
36
  #import h5py
 
269
  # dim = 2
270
  dim = 3#2
271
  stride = (2,4) if dim == 2 else (2,2,2)
272
+ num_image = 3000#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
273
  batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
274
+ n_epoch = 2#0#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
275
  HII_DIM = 64
276
  num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
277
+ startat = 512-num_redshift
278
  channel = 1
279
  img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
280
 
 
445
  idx = "random",#'range',
446
  HII_DIM=self.config.HII_DIM,
447
  num_redshift=self.config.num_redshift,
448
+ startat=self.config.startat,
449
  drop_prob=self.config.drop_prob,
450
  dim=self.config.dim,
451
  ranges_dict=self.ranges_dict,
 
742
 
743
  if __name__ == "__main__":
744
  parser = argparse.ArgumentParser()
745
+ parser.add_argument("--train", type=str, required=False, help="whether to train the model", default=False)
746
  #parser.add_argument("--sample", type=int, required=False, help="whether to sample", default=0)
747
  parser.add_argument("--resume", type=str, required=False, help="filename of the model to resume", default=False)
748
  parser.add_argument("--num_new_img_per_gpu", type=int, required=False, default=4)
 
760
  config = TrainConfig()
761
  config.gradient_accumulation_steps = args.gradient_accumulation_steps
762
  ############################ training ################################
763
+ if args.train:
764
+ config.dataset_name = args.train
765
  print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(120,'-'))
766
  mp.spawn(
767
  train,
 
770
  join=True,
771
  )
772
  ############################ sampling ################################
773
+ if args.resume:
774
  num_new_img_per_gpu = args.num_new_img_per_gpu#200#4#200
775
  max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
776
  #config = TrainConfig()
environment.yml ADDED
@@ -0,0 +1,346 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ name: diffusers
2
+ channels:
3
+ - anaconda
4
+ - fastai
5
+ - conda-forge
6
+ - defaults
7
+ dependencies:
8
+ - _libgcc_mutex=0.1=conda_forge
9
+ - _openmp_mutex=4.5=2_gnu
10
+ - abseil-cpp=20211102.0=h27087fc_1
11
+ - absl-py=2.1.0=pyhd8ed1ab_0
12
+ - accelerate=0.28.0=pyhd8ed1ab_0
13
+ - aiohttp=3.9.3=py39hd1e30aa_1
14
+ - aiosignal=1.3.1=pyhd8ed1ab_0
15
+ - annotated-types=0.7.0=pyhd8ed1ab_0
16
+ - anyio=4.4.0=pyhd8ed1ab_0
17
+ - argon2-cffi=23.1.0=pyhd8ed1ab_0
18
+ - argon2-cffi-bindings=21.2.0=py39hd1e30aa_4
19
+ - arrow=1.3.0=pyhd8ed1ab_0
20
+ - arrow-cpp=14.0.2=h374c478_1
21
+ - asttokens=2.4.1=pyhd8ed1ab_0
22
+ - async-lru=2.0.4=pyhd8ed1ab_0
23
+ - async-timeout=4.0.3=pyhd8ed1ab_0
24
+ - attrs=23.2.0=pyh71513ae_0
25
+ - aws-c-auth=0.6.19=h5eee18b_0
26
+ - aws-c-cal=0.5.20=hdbd6064_0
27
+ - aws-c-common=0.8.5=h5eee18b_0
28
+ - aws-c-compression=0.2.16=h5eee18b_0
29
+ - aws-c-event-stream=0.2.15=h6a678d5_0
30
+ - aws-c-http=0.6.25=h5eee18b_0
31
+ - aws-c-io=0.13.10=h5eee18b_0
32
+ - aws-c-mqtt=0.7.13=h5eee18b_0
33
+ - aws-c-s3=0.1.51=hdbd6064_0
34
+ - aws-c-sdkutils=0.1.6=h5eee18b_0
35
+ - aws-checksums=0.1.13=h5eee18b_0
36
+ - aws-crt-cpp=0.18.16=h6a678d5_0
37
+ - aws-sdk-cpp=1.10.55=h721c034_0
38
+ - babel=2.14.0=pyhd8ed1ab_0
39
+ - backcall=0.2.0=pyh9f0ad1d_0
40
+ - beautifulsoup4=4.12.3=pyha770c72_0
41
+ - blas=1.0=mkl
42
+ - bleach=6.1.0=pyhd8ed1ab_0
43
+ - boost-cpp=1.84.0=h44aadfe_2
44
+ - brotli=1.0.9=h9c3ff4c_4
45
+ - brotli-python=1.0.9=py39h5a03fae_9
46
+ - bzip2=1.0.8=hd590300_5
47
+ - c-ares=1.27.0=hd590300_0
48
+ - ca-certificates=2024.7.4=hbcca054_0
49
+ - cached-property=1.5.2=hd8ed1ab_1
50
+ - cached_property=1.5.2=pyha770c72_1
51
+ - catalogue=2.0.10=py39hf3d152e_0
52
+ - certifi=2024.7.4=pyhd8ed1ab_0
53
+ - cffi=1.16.0=py39h7a31438_0
54
+ - charset-normalizer=3.3.2=pyhd8ed1ab_0
55
+ - click=8.1.7=unix_pyh707e725_0
56
+ - cloudpathlib=0.18.1=pyhd8ed1ab_0
57
+ - colorama=0.4.6=pyhd8ed1ab_0
58
+ - comm=0.2.2=pyhd8ed1ab_0
59
+ - confection=0.1.4=py39h8003fee_0
60
+ - contourpy=1.2.0=py39h7633fee_0
61
+ - cycler=0.12.1=pyhd8ed1ab_0
62
+ - cymem=2.0.8=py39h3d6467e_1
63
+ - cyrus-sasl=2.1.28=h52b45da_1
64
+ - cython-blis=0.7.10=py39h44dd56e_2
65
+ - dataclasses=0.8=pyhc8e2a94_3
66
+ - datasets=2.17.1=pyhd8ed1ab_0
67
+ - dbus=1.13.18=hb2f20db_0
68
+ - debugpy=1.8.1=py39h3d6467e_0
69
+ - decorator=5.1.1=pyhd8ed1ab_0
70
+ - defusedxml=0.7.1=pyhd8ed1ab_0
71
+ - diffusers=0.27.1=pyhd8ed1ab_0
72
+ - dill=0.3.8=pyhd8ed1ab_0
73
+ - entrypoints=0.4=pyhd8ed1ab_0
74
+ - exceptiongroup=1.2.0=pyhd8ed1ab_2
75
+ - executing=2.0.1=pyhd8ed1ab_0
76
+ - expat=2.6.2=h59595ed_0
77
+ - fastai=2.7.15=py_0
78
+ - fastcore=1.5.48=pyhd8ed1ab_0
79
+ - fastdownload=0.0.7=pyhd8ed1ab_0
80
+ - fastprogress=1.0.3=pyhd8ed1ab_0
81
+ - filelock=3.13.3=pyhd8ed1ab_0
82
+ - fontconfig=2.14.2=h14ed4e7_0
83
+ - fonttools=4.50.0=py39hd1e30aa_0
84
+ - fqdn=1.5.1=pyhd8ed1ab_0
85
+ - freetype=2.12.1=h267a509_2
86
+ - frozenlist=1.4.1=py39hd1e30aa_0
87
+ - fsspec=2023.10.0=pyhca7485f_0
88
+ - gflags=2.2.2=he1b5a44_1004
89
+ - glib=2.80.0=hf2295e7_1
90
+ - glib-tools=2.80.0=hde27a5a_1
91
+ - glog=0.5.0=h48cff8f_0
92
+ - gmp=6.3.0=h59595ed_1
93
+ - gmpy2=2.1.2=py39h376b7d2_1
94
+ - grpc-cpp=1.48.2=he1ff14a_1
95
+ - grpcio=1.48.2=py39he1ff14a_1
96
+ - gst-plugins-base=1.14.1=h6a678d5_1
97
+ - gstreamer=1.14.1=h5eee18b_1
98
+ - h11=0.14.0=pyhd8ed1ab_0
99
+ - h2=4.1.0=py39hf3d152e_0
100
+ - h5py=3.10.0=nompi_py39h2c511df_101
101
+ - hdf5=1.14.3=nompi_h4f84152_100
102
+ - hpack=4.0.0=pyh9f0ad1d_0
103
+ - httpcore=1.0.5=pyhd8ed1ab_0
104
+ - httpx=0.27.0=pyhd8ed1ab_0
105
+ - huggingface_hub=0.22.1=pyhd8ed1ab_0
106
+ - hyperframe=6.0.1=pyhd8ed1ab_0
107
+ - icu=73.2=h59595ed_0
108
+ - idna=3.6=pyhd8ed1ab_0
109
+ - importlib-metadata=7.1.0=pyha770c72_0
110
+ - importlib-resources=6.4.0=pyhd8ed1ab_0
111
+ - importlib_metadata=7.1.0=hd8ed1ab_0
112
+ - importlib_resources=6.4.0=pyhd8ed1ab_0
113
+ - intel-openmp=2023.1.0=hdb19cb5_46306
114
+ - ipykernel=6.29.3=pyhd33586a_0
115
+ - ipython=8.15.0=py39h06a4308_0
116
+ - isoduration=20.11.0=pyhd8ed1ab_0
117
+ - jedi=0.19.1=pyhd8ed1ab_0
118
+ - jinja2=3.1.3=pyhd8ed1ab_0
119
+ - joblib=1.4.2=pyhd8ed1ab_0
120
+ - jpeg=9e=h0b41bf4_3
121
+ - json5=0.9.25=pyhd8ed1ab_0
122
+ - jsonpointer=3.0.0=py39hf3d152e_0
123
+ - jsonschema=4.22.0=pyhd8ed1ab_0
124
+ - jsonschema-specifications=2023.12.1=pyhd8ed1ab_0
125
+ - jsonschema-with-format-nongpl=4.22.0=pyhd8ed1ab_0
126
+ - jupyter-lsp=2.2.5=pyhd8ed1ab_0
127
+ - jupyter_client=8.6.1=pyhd8ed1ab_0
128
+ - jupyter_core=5.7.2=py39hf3d152e_0
129
+ - jupyter_events=0.10.0=pyhd8ed1ab_0
130
+ - jupyter_server=2.14.1=pyhd8ed1ab_0
131
+ - jupyter_server_terminals=0.5.3=pyhd8ed1ab_0
132
+ - jupyterlab=4.2.3=pyhd8ed1ab_0
133
+ - jupyterlab_pygments=0.3.0=pyhd8ed1ab_1
134
+ - jupyterlab_server=2.27.2=pyhd8ed1ab_0
135
+ - keyutils=1.6.1=h166bdaf_0
136
+ - kiwisolver=1.4.5=py39h7633fee_1
137
+ - krb5=1.20.1=h81ceb04_0
138
+ - langcodes=3.4.0=pyhd8ed1ab_0
139
+ - language-data=1.2.0=pyhd8ed1ab_0
140
+ - lcms2=2.12=h3be6417_0
141
+ - ld_impl_linux-64=2.40=h41732ed_0
142
+ - lerc=3.0=h295c915_0
143
+ - libaec=1.1.3=h59595ed_0
144
+ - libboost=1.84.0=h8013b2b_2
145
+ - libboost-devel=1.84.0=h00ab1b0_2
146
+ - libboost-headers=1.84.0=ha770c72_2
147
+ - libbrotlicommon=1.0.9=h166bdaf_9
148
+ - libbrotlidec=1.0.9=h166bdaf_9
149
+ - libbrotlienc=1.0.9=h166bdaf_9
150
+ - libclang=14.0.6=default_hc6dbbc7_1
151
+ - libclang13=14.0.6=default_he11475f_1
152
+ - libcups=2.3.3=h36d4200_3
153
+ - libcurl=8.5.0=h251f7ec_0
154
+ - libdeflate=1.17=h5eee18b_1
155
+ - libedit=3.1.20230828=h5eee18b_0
156
+ - libev=4.33=hd590300_2
157
+ - libevent=2.1.10=h28343ad_4
158
+ - libexpat=2.6.2=h59595ed_0
159
+ - libffi=3.4.2=h7f98852_5
160
+ - libgcc-ng=13.2.0=h807b86a_5
161
+ - libgfortran-ng=13.2.0=h69a702a_5
162
+ - libgfortran5=13.2.0=ha4646dd_5
163
+ - libglib=2.80.0=hf2295e7_1
164
+ - libgomp=13.2.0=h807b86a_5
165
+ - libhwloc=2.9.3=default_h554bfaf_1009
166
+ - libiconv=1.17=hd590300_2
167
+ - libllvm14=14.0.6=hcd5def8_4
168
+ - libnghttp2=1.58.0=h47da74e_1
169
+ - libnsl=2.0.1=hd590300_0
170
+ - libpng=1.6.43=h2797004_0
171
+ - libpq=12.17=hdbd6064_0
172
+ - libprotobuf=3.20.3=h3eb15da_0
173
+ - libsodium=1.0.18=h36c2ea0_1
174
+ - libsqlite=3.45.2=h2797004_0
175
+ - libssh2=1.11.0=h0841786_0
176
+ - libstdcxx-ng=13.2.0=h7e041cc_5
177
+ - libthrift=0.15.0=h362ad58_1
178
+ - libtiff=4.5.1=h6a678d5_0
179
+ - libuuid=2.38.1=h0b41bf4_0
180
+ - libwebp-base=1.3.2=hd590300_0
181
+ - libxcb=1.15=h0b41bf4_0
182
+ - libxcrypt=4.4.36=hd590300_1
183
+ - libxkbcommon=1.7.0=h662e7e4_0
184
+ - libxml2=2.12.6=h232c23b_1
185
+ - libzlib=1.2.13=hd590300_5
186
+ - llvm-openmp=18.1.2=h4dfa4b3_0
187
+ - lz4-c=1.9.4=hcb278e6_0
188
+ - marisa-trie=1.1.0=py39h3d6467e_1
189
+ - markdown=3.6=pyhd8ed1ab_0
190
+ - markdown-it-py=3.0.0=pyhd8ed1ab_0
191
+ - markupsafe=2.1.5=py39hd1e30aa_0
192
+ - matplotlib=3.8.3=py39hf3d152e_0
193
+ - matplotlib-base=3.8.3=py39he9076e7_0
194
+ - matplotlib-inline=0.1.6=pyhd8ed1ab_0
195
+ - mdurl=0.1.2=pyhd8ed1ab_0
196
+ - mistune=3.0.2=pyhd8ed1ab_0
197
+ - mkl=2023.1.0=h213fc3f_46344
198
+ - mkl-service=2.4.0=py39h5eee18b_1
199
+ - mkl_fft=1.3.8=py39h5eee18b_0
200
+ - mkl_random=1.2.4=py39hdb19cb5_0
201
+ - mpc=1.3.1=hfe3b2da_0
202
+ - mpfr=4.2.1=h9458935_0
203
+ - mpmath=1.3.0=pyhd8ed1ab_0
204
+ - multidict=6.0.5=py39hd1e30aa_0
205
+ - multiprocess=0.70.16=py39hd1e30aa_0
206
+ - munkres=1.1.4=pyh9f0ad1d_0
207
+ - murmurhash=1.0.10=py39h3d6467e_1
208
+ - mysql=5.7.24=h721c034_2
209
+ - nbclient=0.10.0=pyhd8ed1ab_0
210
+ - nbconvert-core=7.16.4=pyhd8ed1ab_1
211
+ - nbformat=5.10.4=pyhd8ed1ab_0
212
+ - ncurses=6.4.20240210=h59595ed_0
213
+ - nest-asyncio=1.6.0=pyhd8ed1ab_0
214
+ - networkx=3.2.1=pyhd8ed1ab_0
215
+ - ninja=1.11.1=h924138e_0
216
+ - notebook=7.2.1=pyhd8ed1ab_0
217
+ - notebook-shim=0.2.4=pyhd8ed1ab_0
218
+ - numpy=1.26.4=py39h5f9d8c6_0
219
+ - numpy-base=1.26.4=py39hb5e798b_0
220
+ - openjpeg=2.4.0=h3ad879b_0
221
+ - openssl=3.3.1=h4bc722e_2
222
+ - orc=1.7.4=hb3bc3d3_1
223
+ - overrides=7.7.0=pyhd8ed1ab_0
224
+ - packaging=24.0=pyhd8ed1ab_0
225
+ - pandas=1.4.2=py39h1832856_2
226
+ - pandocfilters=1.5.0=pyhd8ed1ab_0
227
+ - parso=0.8.3=pyhd8ed1ab_0
228
+ - pcre2=10.43=hcad00b1_0
229
+ - pexpect=4.9.0=pyhd8ed1ab_0
230
+ - pickleshare=0.7.5=py_1003
231
+ - pillow=10.2.0=py39h5eee18b_0
232
+ - pip=24.0=pyhd8ed1ab_0
233
+ - pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
234
+ - platformdirs=4.2.0=pyhd8ed1ab_0
235
+ - ply=3.11=py_1
236
+ - preshed=3.0.9=py39h3d6467e_1
237
+ - prometheus_client=0.20.0=pyhd8ed1ab_0
238
+ - prompt-toolkit=3.0.42=pyha770c72_0
239
+ - prompt_toolkit=3.0.42=hd8ed1ab_0
240
+ - protobuf=3.20.3=py39h227be39_1
241
+ - psutil=5.9.8=py39hd1e30aa_0
242
+ - pthread-stubs=0.4=h36c2ea0_1001
243
+ - ptyprocess=0.7.0=pyhd3deb0d_0
244
+ - pure_eval=0.2.2=pyhd8ed1ab_0
245
+ - pyarrow=14.0.2=py39h1eedbd7_0
246
+ - pyarrow-hotfix=0.6=pyhd8ed1ab_0
247
+ - pycparser=2.21=pyhd8ed1ab_0
248
+ - pydantic=2.8.2=pyhd8ed1ab_0
249
+ - pydantic-core=2.20.1=py39h5cde264_0
250
+ - pygments=2.17.2=pyhd8ed1ab_0
251
+ - pyparsing=3.1.2=pyhd8ed1ab_0
252
+ - pyqt=5.15.10=py39h6a678d5_0
253
+ - pyqt5-sip=12.13.0=py39h5eee18b_0
254
+ - pysocks=1.7.1=pyha2e5f31_6
255
+ - python=3.9.19=h0755675_0_cpython
256
+ - python-dateutil=2.9.0=pyhd8ed1ab_0
257
+ - python-fastjsonschema=2.20.0=pyhd8ed1ab_0
258
+ - python-json-logger=2.0.7=pyhd8ed1ab_0
259
+ - python-tzdata=2024.1=pyhd8ed1ab_0
260
+ - python-xxhash=3.4.1=py39hd1e30aa_0
261
+ - python_abi=3.9=4_cp39
262
+ - pytorch=2.2.0=cpu_py39hdc00b08_0
263
+ - pytz=2024.1=pyhd8ed1ab_0
264
+ - pyyaml=6.0.1=py39hd1e30aa_1
265
+ - pyzmq=25.1.2=py39h8c080ef_0
266
+ - qt-main=5.15.2=h53bd1ea_10
267
+ - re2=2022.04.01=h27087fc_0
268
+ - readline=8.2=h8228510_1
269
+ - referencing=0.35.1=pyhd8ed1ab_0
270
+ - regex=2023.12.25=py39hd1e30aa_0
271
+ - requests=2.31.0=pyhd8ed1ab_0
272
+ - rfc3339-validator=0.1.4=pyhd8ed1ab_0
273
+ - rfc3986-validator=0.1.1=pyh9f0ad1d_0
274
+ - rich=13.7.1=pyhd8ed1ab_0
275
+ - rpds-py=0.18.1=py39ha68c5e3_0
276
+ - s2n=1.3.27=hdbd6064_0
277
+ - safetensors=0.4.2=py39h9fdd4d6_0
278
+ - scikit-learn=1.5.1=py39hf7b0125_0
279
+ - scipy=1.11.3=py39h5f9d8c6_0
280
+ - send2trash=1.8.3=pyh0d859eb_0
281
+ - setuptools=69.2.0=pyhd8ed1ab_0
282
+ - shellingham=1.5.4=pyhd8ed1ab_0
283
+ - sip=6.7.12=py39h3d6467e_0
284
+ - six=1.16.0=pyh6c4a22f_0
285
+ - smart-open=7.0.4=hd8ed1ab_0
286
+ - smart_open=7.0.4=pyhd8ed1ab_0
287
+ - snappy=1.1.10=h9fff704_0
288
+ - sniffio=1.3.1=pyhd8ed1ab_0
289
+ - soupsieve=2.5=pyhd8ed1ab_1
290
+ - spacy=3.7.5=py39h95fdab5_0
291
+ - spacy-legacy=3.0.12=pyhd8ed1ab_0
292
+ - spacy-loggers=1.0.5=pyhd8ed1ab_0
293
+ - sqlite=3.45.2=h2c6b66d_0
294
+ - srsly=2.4.8=py39h3d6467e_1
295
+ - stack_data=0.6.2=pyhd8ed1ab_0
296
+ - sympy=1.12=pypyh9d50eac_103
297
+ - tbb=2021.11.0=h00ab1b0_1
298
+ - tensorboard=2.17.0=pyhd8ed1ab_0
299
+ - tensorboard-data-server=0.7.0=py39hd4f0224_1
300
+ - terminado=0.18.1=pyh0d859eb_0
301
+ - thinc=8.2.3=py39he5d7314_0
302
+ - threadpoolctl=3.5.0=pyhc1e730c_0
303
+ - tinycss2=1.3.0=pyhd8ed1ab_0
304
+ - tk=8.6.13=noxft_h4845f30_101
305
+ - tomli=2.0.1=pyhd8ed1ab_0
306
+ - torchvision=0.14.1=cpu_py39hcda3413_0
307
+ - tornado=6.4=py39hd1e30aa_0
308
+ - tqdm=4.66.2=pyhd8ed1ab_0
309
+ - traitlets=5.14.2=pyhd8ed1ab_0
310
+ - typer=0.12.3=pyhd8ed1ab_0
311
+ - typer-slim=0.12.3=pyhd8ed1ab_0
312
+ - typer-slim-standard=0.12.3=hd8ed1ab_0
313
+ - types-python-dateutil=2.9.0.20240316=pyhd8ed1ab_0
314
+ - typing-extensions=4.10.0=hd8ed1ab_0
315
+ - typing_extensions=4.10.0=pyha770c72_0
316
+ - typing_utils=0.1.0=pyhd8ed1ab_0
317
+ - tzdata=2024a=h0c530f3_0
318
+ - unicodedata2=15.1.0=py39hd1e30aa_0
319
+ - uri-template=1.3.0=pyhd8ed1ab_0
320
+ - urllib3=2.2.1=pyhd8ed1ab_0
321
+ - utf8proc=2.6.1=h5eee18b_1
322
+ - wasabi=1.1.2=py39hf3d152e_1
323
+ - wcwidth=0.2.13=pyhd8ed1ab_0
324
+ - weasel=0.4.1=pyhd8ed1ab_1
325
+ - webcolors=24.6.0=pyhd8ed1ab_0
326
+ - webencodings=0.5.1=pyhd8ed1ab_2
327
+ - websocket-client=1.8.0=pyhd8ed1ab_0
328
+ - werkzeug=3.0.1=pyhd8ed1ab_0
329
+ - wheel=0.43.0=pyhd8ed1ab_0
330
+ - wrapt=1.16.0=py39hd1e30aa_0
331
+ - xkeyboard-config=2.41=hd590300_0
332
+ - xorg-kbproto=1.0.7=h7f98852_1002
333
+ - xorg-libx11=1.8.7=h8ee46fc_0
334
+ - xorg-libxau=1.0.11=hd590300_0
335
+ - xorg-libxdmcp=1.1.3=h7f98852_0
336
+ - xorg-xextproto=7.3.0=h0b41bf4_1003
337
+ - xorg-xproto=7.0.31=h7f98852_1007
338
+ - xxhash=0.8.2=hd590300_0
339
+ - xz=5.4.6=h5eee18b_0
340
+ - yaml=0.2.5=h7f98852_2
341
+ - yarl=1.9.4=py39hd1e30aa_0
342
+ - zeromq=4.3.5=h59595ed_1
343
+ - zipp=3.17.0=pyhd8ed1ab_0
344
+ - zlib=1.2.13=hd590300_5
345
+ - zstd=1.5.5=hfc55251_0
346
+ prefix: /storage/home/hcoda1/3/bxia34/.conda/envs/diffusers
frontera_diffusion.sbatch ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/bin/bash
2
+ #SBATCH -J diffusion # Job name
3
+ #SBATCH -p rtx-dev
4
+ #SBATCH -N2 # Number of nodes and cores per node required
5
+ #SBATCH --ntasks-per-node=1
6
+ #SBATCH -t 02:00:00 # Duration of the job (Ex: 15 mins)
7
+ #SBATCH -oReport-%j # Combined output and error messages file
8
+ #SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
9
+
10
+ python -c "import torch; print('torch.cuda.is_available() =', torch.cuda.is_available()); print('torch.__version__ =', torch.__version__); print('torch.version.cuda =', torch.version.cuda)"
11
+ pwd
12
+ date
13
+ #module load anaconda3/2022.05 # Load module dependencies
14
+ #module load pytorch
15
+ #conda activate diffusers
16
+ conda env list
17
+ module list
18
+ cat $0
19
+
20
+ MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
21
+ MASTER_PORT=$((10000 + RANDOM % 10000)) #12355
22
+
23
+ export MASTER_ADDR=$MASTER_ADDR
24
+ export MASTER_PORT=$MASTER_PORT
25
+
26
+ srun python diffusion.py \
27
+ --train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
28
+ --resume outputs/model-N2000-device_count1-node8-epoch19-19004529 \
29
+ --num_new_img_per_gpu 50 \
30
+ --max_num_img_per_gpu 2 \
31
+ --gradient_accumulation_steps 40 \
32
+ ######################################################################################
33
+
load_h5.py CHANGED
@@ -44,6 +44,7 @@ class Dataset4h5(Dataset):
44
  transform=True,
45
  ranges_dict=None,
46
  num_workers=len(os.sched_getaffinity(0))//torch.cuda.device_count(),
 
47
  # shuffle=False,
48
  ):
49
  super().__init__()
@@ -59,7 +60,7 @@ class Dataset4h5(Dataset):
59
  self.dim = dim
60
  self.transform = transform
61
  self.num_workers = num_workers
62
-
63
  # if ranges_dict == None:
64
  # ranges_dict = dict(
65
  # images = {
@@ -156,10 +157,10 @@ class Dataset4h5(Dataset):
156
  with h5py.File(self.dir_name, 'r') as f:
157
  images_start = time()
158
  if self.dim == 2:
159
- images = f[self.field][idx,0,:self.HII_DIM,-self.num_redshift:][:,None]
160
  # images = f[self.field][idx,:self.HII_DIM,:self.HII_DIM,-3][:,None]
161
  elif self.dim == 3:
162
- images = f[self.field][idx,:self.HII_DIM,:self.HII_DIM,-self.num_redshift:][:,None]
163
  images_end = time()
164
  # print(f"pid {pid}: images of shape {images.shape} loaded after {load_end-load_start:.3f} s")
165
  pid = os.getpid()
 
44
  transform=True,
45
  ranges_dict=None,
46
  num_workers=len(os.sched_getaffinity(0))//torch.cuda.device_count(),
47
+ startat=0,
48
  # shuffle=False,
49
  ):
50
  super().__init__()
 
60
  self.dim = dim
61
  self.transform = transform
62
  self.num_workers = num_workers
63
+ self.startat = startat
64
  # if ranges_dict == None:
65
  # ranges_dict = dict(
66
  # images = {
 
157
  with h5py.File(self.dir_name, 'r') as f:
158
  images_start = time()
159
  if self.dim == 2:
160
+ images = f[self.field][idx, 0, :self.HII_DIM, self.startat:self.startat+self.num_redshift][:,None]
161
  # images = f[self.field][idx,:self.HII_DIM,:self.HII_DIM,-3][:,None]
162
  elif self.dim == 3:
163
+ images = f[self.field][idx, :self.HII_DIM, :self.HII_DIM, self.startat:self.startat+self.num_redshift][:,None]
164
  images_end = time()
165
  # print(f"pid {pid}: images of shape {images.shape} loaded after {load_end-load_start:.3f} s")
166
  pid = os.getpid()
phoenix_diffusion.sbatch CHANGED
@@ -2,10 +2,10 @@
2
  #SBATCH -J diffusion # Job name
3
  #SBATCH -A gts-jw254-coda20
4
  #SBATCH -qembers
5
- #SBATCH -N8 --gpus-per-node=V100:1 -C V100-16GB # Number of nodes and cores per node required
6
  #SBATCH --ntasks-per-node=1
7
  #SBATCH --mem-per-gpu=16G # Memory per core
8
- #SBATCH -t 08:00:00 # Duration of the job (Ex: 15 mins)
9
  #SBATCH -oReport-%j # Combined output and error messages file
10
  #SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
11
 
@@ -29,8 +29,8 @@ export MASTER_ADDR=$MASTER_ADDR
29
  export MASTER_PORT=$MASTER_PORT
30
 
31
  srun python diffusion.py \
32
- --train 1 \
33
- --resume outputs/model-N2000-device_count1-node8-epoch19-18001622 \
34
  --num_new_img_per_gpu 50 \
35
  --max_num_img_per_gpu 2 \
36
  --gradient_accumulation_steps 40 \
 
2
  #SBATCH -J diffusion # Job name
3
  #SBATCH -A gts-jw254-coda20
4
  #SBATCH -qembers
5
+ #SBATCH -N1 --gpus-per-node=V100:1 -C V100-16GB # Number of nodes and cores per node required
6
  #SBATCH --ntasks-per-node=1
7
  #SBATCH --mem-per-gpu=16G # Memory per core
8
+ #SBATCH -t 02:00:00 # Duration of the job (Ex: 15 mins)
9
  #SBATCH -oReport-%j # Combined output and error messages file
10
  #SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
11
 
 
29
  export MASTER_PORT=$MASTER_PORT
30
 
31
  srun python diffusion.py \
32
+ --train 0 \
33
+ --resume outputs/model-N2000-device_count1-node8-epoch19-19004529 \
34
  --num_new_img_per_gpu 50 \
35
  --max_num_img_per_gpu 2 \
36
  --gradient_accumulation_steps 40 \