20145032
Browse files
- diffusion.py +9 -6
- environment.yml +346 -0
- frontera_diffusion.sbatch +33 -0
- load_h5.py +4 -3
- phoenix_diffusion.sbatch +4 -4
diffusion.py
CHANGED
|
@@ -30,7 +30,7 @@
|
|
| 30 |
import logging
|
| 31 |
#logging.getLogger("torch").setLevel(logging.ERROR)
|
| 32 |
import warnings
|
| 33 |
-
|
| 34 |
|
| 35 |
from dataclasses import dataclass
|
| 36 |
#import h5py
|
|
@@ -269,11 +269,12 @@ class TrainConfig:
|
|
| 269 |
# dim = 2
|
| 270 |
dim = 3#2
|
| 271 |
stride = (2,4) if dim == 2 else (2,2,2)
|
| 272 |
-
num_image =
|
| 273 |
batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
|
| 274 |
-
n_epoch =
|
| 275 |
HII_DIM = 64
|
| 276 |
num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
|
|
|
|
| 277 |
channel = 1
|
| 278 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 279 |
|
|
@@ -444,6 +445,7 @@ class DDPM21CM:
|
|
| 444 |
idx = "random",#'range',
|
| 445 |
HII_DIM=self.config.HII_DIM,
|
| 446 |
num_redshift=self.config.num_redshift,
|
|
|
|
| 447 |
drop_prob=self.config.drop_prob,
|
| 448 |
dim=self.config.dim,
|
| 449 |
ranges_dict=self.ranges_dict,
|
|
@@ -740,7 +742,7 @@ def generate_samples(rank, world_size, local_world_size, master_addr, master_por
|
|
| 740 |
|
| 741 |
if __name__ == "__main__":
|
| 742 |
parser = argparse.ArgumentParser()
|
| 743 |
-
parser.add_argument("--train", type=
|
| 744 |
#parser.add_argument("--sample", type=int, required=False, help="whether to sample", default=0)
|
| 745 |
parser.add_argument("--resume", type=str, required=False, help="filename of the model to resume", default=False)
|
| 746 |
parser.add_argument("--num_new_img_per_gpu", type=int, required=False, default=4)
|
|
@@ -758,7 +760,8 @@ if __name__ == "__main__":
|
|
| 758 |
config = TrainConfig()
|
| 759 |
config.gradient_accumulation_steps = args.gradient_accumulation_steps
|
| 760 |
############################ training ################################
|
| 761 |
-
if args.train
|
|
|
|
| 762 |
print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(120,'-'))
|
| 763 |
mp.spawn(
|
| 764 |
train,
|
|
@@ -767,7 +770,7 @@ if __name__ == "__main__":
|
|
| 767 |
join=True,
|
| 768 |
)
|
| 769 |
############################ sampling ################################
|
| 770 |
-
if args.
|
| 771 |
num_new_img_per_gpu = args.num_new_img_per_gpu#200#4#200
|
| 772 |
max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
|
| 773 |
#config = TrainConfig()
|
|
|
|
| 30 |
import logging
|
| 31 |
#logging.getLogger("torch").setLevel(logging.ERROR)
|
| 32 |
import warnings
|
| 33 |
+
warnings.filterwarnings("ignore", category=FutureWarning)
|
| 34 |
|
| 35 |
from dataclasses import dataclass
|
| 36 |
#import h5py
|
|
|
|
| 269 |
# dim = 2
|
| 270 |
dim = 3#2
|
| 271 |
stride = (2,4) if dim == 2 else (2,2,2)
|
| 272 |
+
num_image = 3000#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
|
| 273 |
batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
|
| 274 |
+
n_epoch = 2#0#1#50#10#1#50#1#50#5#50#5#50#100#50#100#30#120#5#4# 10#50#20#20#2#5#25 # 120
|
| 275 |
HII_DIM = 64
|
| 276 |
num_redshift = 64#256#512#256#512#256#512#256#512#64#512#64#512#64#256CUDAoom#128#64#512#128#64#512#256#256#64#512#128
|
| 277 |
+
startat = 512-num_redshift
|
| 278 |
channel = 1
|
| 279 |
img_shape = (channel, HII_DIM, num_redshift) if dim == 2 else (channel, HII_DIM, HII_DIM, num_redshift)
|
| 280 |
|
|
|
|
| 445 |
idx = "random",#'range',
|
| 446 |
HII_DIM=self.config.HII_DIM,
|
| 447 |
num_redshift=self.config.num_redshift,
|
| 448 |
+
startat=self.config.startat,
|
| 449 |
drop_prob=self.config.drop_prob,
|
| 450 |
dim=self.config.dim,
|
| 451 |
ranges_dict=self.ranges_dict,
|
|
|
|
| 742 |
|
| 743 |
if __name__ == "__main__":
|
| 744 |
parser = argparse.ArgumentParser()
|
| 745 |
+
parser.add_argument("--train", type=str, required=False, help="whether to train the model", default=False)
|
| 746 |
#parser.add_argument("--sample", type=int, required=False, help="whether to sample", default=0)
|
| 747 |
parser.add_argument("--resume", type=str, required=False, help="filename of the model to resume", default=False)
|
| 748 |
parser.add_argument("--num_new_img_per_gpu", type=int, required=False, default=4)
|
|
|
|
| 760 |
config = TrainConfig()
|
| 761 |
config.gradient_accumulation_steps = args.gradient_accumulation_steps
|
| 762 |
############################ training ################################
|
| 763 |
+
if args.train:
|
| 764 |
+
config.dataset_name = args.train
|
| 765 |
print(f" training, ip_addr = {socket.gethostbyname(socket.gethostname())}, master_addr = {master_addr}, local_world_size = {local_world_size}, world_size = {world_size} ".center(120,'-'))
|
| 766 |
mp.spawn(
|
| 767 |
train,
|
|
|
|
| 770 |
join=True,
|
| 771 |
)
|
| 772 |
############################ sampling ################################
|
| 773 |
+
if args.resume:
|
| 774 |
num_new_img_per_gpu = args.num_new_img_per_gpu#200#4#200
|
| 775 |
max_num_img_per_gpu = args.max_num_img_per_gpu#40#2#20
|
| 776 |
#config = TrainConfig()
|
environment.yml
ADDED
|
@@ -0,0 +1,346 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
name: diffusers
|
| 2 |
+
channels:
|
| 3 |
+
- anaconda
|
| 4 |
+
- fastai
|
| 5 |
+
- conda-forge
|
| 6 |
+
- defaults
|
| 7 |
+
dependencies:
|
| 8 |
+
- _libgcc_mutex=0.1=conda_forge
|
| 9 |
+
- _openmp_mutex=4.5=2_gnu
|
| 10 |
+
- abseil-cpp=20211102.0=h27087fc_1
|
| 11 |
+
- absl-py=2.1.0=pyhd8ed1ab_0
|
| 12 |
+
- accelerate=0.28.0=pyhd8ed1ab_0
|
| 13 |
+
- aiohttp=3.9.3=py39hd1e30aa_1
|
| 14 |
+
- aiosignal=1.3.1=pyhd8ed1ab_0
|
| 15 |
+
- annotated-types=0.7.0=pyhd8ed1ab_0
|
| 16 |
+
- anyio=4.4.0=pyhd8ed1ab_0
|
| 17 |
+
- argon2-cffi=23.1.0=pyhd8ed1ab_0
|
| 18 |
+
- argon2-cffi-bindings=21.2.0=py39hd1e30aa_4
|
| 19 |
+
- arrow=1.3.0=pyhd8ed1ab_0
|
| 20 |
+
- arrow-cpp=14.0.2=h374c478_1
|
| 21 |
+
- asttokens=2.4.1=pyhd8ed1ab_0
|
| 22 |
+
- async-lru=2.0.4=pyhd8ed1ab_0
|
| 23 |
+
- async-timeout=4.0.3=pyhd8ed1ab_0
|
| 24 |
+
- attrs=23.2.0=pyh71513ae_0
|
| 25 |
+
- aws-c-auth=0.6.19=h5eee18b_0
|
| 26 |
+
- aws-c-cal=0.5.20=hdbd6064_0
|
| 27 |
+
- aws-c-common=0.8.5=h5eee18b_0
|
| 28 |
+
- aws-c-compression=0.2.16=h5eee18b_0
|
| 29 |
+
- aws-c-event-stream=0.2.15=h6a678d5_0
|
| 30 |
+
- aws-c-http=0.6.25=h5eee18b_0
|
| 31 |
+
- aws-c-io=0.13.10=h5eee18b_0
|
| 32 |
+
- aws-c-mqtt=0.7.13=h5eee18b_0
|
| 33 |
+
- aws-c-s3=0.1.51=hdbd6064_0
|
| 34 |
+
- aws-c-sdkutils=0.1.6=h5eee18b_0
|
| 35 |
+
- aws-checksums=0.1.13=h5eee18b_0
|
| 36 |
+
- aws-crt-cpp=0.18.16=h6a678d5_0
|
| 37 |
+
- aws-sdk-cpp=1.10.55=h721c034_0
|
| 38 |
+
- babel=2.14.0=pyhd8ed1ab_0
|
| 39 |
+
- backcall=0.2.0=pyh9f0ad1d_0
|
| 40 |
+
- beautifulsoup4=4.12.3=pyha770c72_0
|
| 41 |
+
- blas=1.0=mkl
|
| 42 |
+
- bleach=6.1.0=pyhd8ed1ab_0
|
| 43 |
+
- boost-cpp=1.84.0=h44aadfe_2
|
| 44 |
+
- brotli=1.0.9=h9c3ff4c_4
|
| 45 |
+
- brotli-python=1.0.9=py39h5a03fae_9
|
| 46 |
+
- bzip2=1.0.8=hd590300_5
|
| 47 |
+
- c-ares=1.27.0=hd590300_0
|
| 48 |
+
- ca-certificates=2024.7.4=hbcca054_0
|
| 49 |
+
- cached-property=1.5.2=hd8ed1ab_1
|
| 50 |
+
- cached_property=1.5.2=pyha770c72_1
|
| 51 |
+
- catalogue=2.0.10=py39hf3d152e_0
|
| 52 |
+
- certifi=2024.7.4=pyhd8ed1ab_0
|
| 53 |
+
- cffi=1.16.0=py39h7a31438_0
|
| 54 |
+
- charset-normalizer=3.3.2=pyhd8ed1ab_0
|
| 55 |
+
- click=8.1.7=unix_pyh707e725_0
|
| 56 |
+
- cloudpathlib=0.18.1=pyhd8ed1ab_0
|
| 57 |
+
- colorama=0.4.6=pyhd8ed1ab_0
|
| 58 |
+
- comm=0.2.2=pyhd8ed1ab_0
|
| 59 |
+
- confection=0.1.4=py39h8003fee_0
|
| 60 |
+
- contourpy=1.2.0=py39h7633fee_0
|
| 61 |
+
- cycler=0.12.1=pyhd8ed1ab_0
|
| 62 |
+
- cymem=2.0.8=py39h3d6467e_1
|
| 63 |
+
- cyrus-sasl=2.1.28=h52b45da_1
|
| 64 |
+
- cython-blis=0.7.10=py39h44dd56e_2
|
| 65 |
+
- dataclasses=0.8=pyhc8e2a94_3
|
| 66 |
+
- datasets=2.17.1=pyhd8ed1ab_0
|
| 67 |
+
- dbus=1.13.18=hb2f20db_0
|
| 68 |
+
- debugpy=1.8.1=py39h3d6467e_0
|
| 69 |
+
- decorator=5.1.1=pyhd8ed1ab_0
|
| 70 |
+
- defusedxml=0.7.1=pyhd8ed1ab_0
|
| 71 |
+
- diffusers=0.27.1=pyhd8ed1ab_0
|
| 72 |
+
- dill=0.3.8=pyhd8ed1ab_0
|
| 73 |
+
- entrypoints=0.4=pyhd8ed1ab_0
|
| 74 |
+
- exceptiongroup=1.2.0=pyhd8ed1ab_2
|
| 75 |
+
- executing=2.0.1=pyhd8ed1ab_0
|
| 76 |
+
- expat=2.6.2=h59595ed_0
|
| 77 |
+
- fastai=2.7.15=py_0
|
| 78 |
+
- fastcore=1.5.48=pyhd8ed1ab_0
|
| 79 |
+
- fastdownload=0.0.7=pyhd8ed1ab_0
|
| 80 |
+
- fastprogress=1.0.3=pyhd8ed1ab_0
|
| 81 |
+
- filelock=3.13.3=pyhd8ed1ab_0
|
| 82 |
+
- fontconfig=2.14.2=h14ed4e7_0
|
| 83 |
+
- fonttools=4.50.0=py39hd1e30aa_0
|
| 84 |
+
- fqdn=1.5.1=pyhd8ed1ab_0
|
| 85 |
+
- freetype=2.12.1=h267a509_2
|
| 86 |
+
- frozenlist=1.4.1=py39hd1e30aa_0
|
| 87 |
+
- fsspec=2023.10.0=pyhca7485f_0
|
| 88 |
+
- gflags=2.2.2=he1b5a44_1004
|
| 89 |
+
- glib=2.80.0=hf2295e7_1
|
| 90 |
+
- glib-tools=2.80.0=hde27a5a_1
|
| 91 |
+
- glog=0.5.0=h48cff8f_0
|
| 92 |
+
- gmp=6.3.0=h59595ed_1
|
| 93 |
+
- gmpy2=2.1.2=py39h376b7d2_1
|
| 94 |
+
- grpc-cpp=1.48.2=he1ff14a_1
|
| 95 |
+
- grpcio=1.48.2=py39he1ff14a_1
|
| 96 |
+
- gst-plugins-base=1.14.1=h6a678d5_1
|
| 97 |
+
- gstreamer=1.14.1=h5eee18b_1
|
| 98 |
+
- h11=0.14.0=pyhd8ed1ab_0
|
| 99 |
+
- h2=4.1.0=py39hf3d152e_0
|
| 100 |
+
- h5py=3.10.0=nompi_py39h2c511df_101
|
| 101 |
+
- hdf5=1.14.3=nompi_h4f84152_100
|
| 102 |
+
- hpack=4.0.0=pyh9f0ad1d_0
|
| 103 |
+
- httpcore=1.0.5=pyhd8ed1ab_0
|
| 104 |
+
- httpx=0.27.0=pyhd8ed1ab_0
|
| 105 |
+
- huggingface_hub=0.22.1=pyhd8ed1ab_0
|
| 106 |
+
- hyperframe=6.0.1=pyhd8ed1ab_0
|
| 107 |
+
- icu=73.2=h59595ed_0
|
| 108 |
+
- idna=3.6=pyhd8ed1ab_0
|
| 109 |
+
- importlib-metadata=7.1.0=pyha770c72_0
|
| 110 |
+
- importlib-resources=6.4.0=pyhd8ed1ab_0
|
| 111 |
+
- importlib_metadata=7.1.0=hd8ed1ab_0
|
| 112 |
+
- importlib_resources=6.4.0=pyhd8ed1ab_0
|
| 113 |
+
- intel-openmp=2023.1.0=hdb19cb5_46306
|
| 114 |
+
- ipykernel=6.29.3=pyhd33586a_0
|
| 115 |
+
- ipython=8.15.0=py39h06a4308_0
|
| 116 |
+
- isoduration=20.11.0=pyhd8ed1ab_0
|
| 117 |
+
- jedi=0.19.1=pyhd8ed1ab_0
|
| 118 |
+
- jinja2=3.1.3=pyhd8ed1ab_0
|
| 119 |
+
- joblib=1.4.2=pyhd8ed1ab_0
|
| 120 |
+
- jpeg=9e=h0b41bf4_3
|
| 121 |
+
- json5=0.9.25=pyhd8ed1ab_0
|
| 122 |
+
- jsonpointer=3.0.0=py39hf3d152e_0
|
| 123 |
+
- jsonschema=4.22.0=pyhd8ed1ab_0
|
| 124 |
+
- jsonschema-specifications=2023.12.1=pyhd8ed1ab_0
|
| 125 |
+
- jsonschema-with-format-nongpl=4.22.0=pyhd8ed1ab_0
|
| 126 |
+
- jupyter-lsp=2.2.5=pyhd8ed1ab_0
|
| 127 |
+
- jupyter_client=8.6.1=pyhd8ed1ab_0
|
| 128 |
+
- jupyter_core=5.7.2=py39hf3d152e_0
|
| 129 |
+
- jupyter_events=0.10.0=pyhd8ed1ab_0
|
| 130 |
+
- jupyter_server=2.14.1=pyhd8ed1ab_0
|
| 131 |
+
- jupyter_server_terminals=0.5.3=pyhd8ed1ab_0
|
| 132 |
+
- jupyterlab=4.2.3=pyhd8ed1ab_0
|
| 133 |
+
- jupyterlab_pygments=0.3.0=pyhd8ed1ab_1
|
| 134 |
+
- jupyterlab_server=2.27.2=pyhd8ed1ab_0
|
| 135 |
+
- keyutils=1.6.1=h166bdaf_0
|
| 136 |
+
- kiwisolver=1.4.5=py39h7633fee_1
|
| 137 |
+
- krb5=1.20.1=h81ceb04_0
|
| 138 |
+
- langcodes=3.4.0=pyhd8ed1ab_0
|
| 139 |
+
- language-data=1.2.0=pyhd8ed1ab_0
|
| 140 |
+
- lcms2=2.12=h3be6417_0
|
| 141 |
+
- ld_impl_linux-64=2.40=h41732ed_0
|
| 142 |
+
- lerc=3.0=h295c915_0
|
| 143 |
+
- libaec=1.1.3=h59595ed_0
|
| 144 |
+
- libboost=1.84.0=h8013b2b_2
|
| 145 |
+
- libboost-devel=1.84.0=h00ab1b0_2
|
| 146 |
+
- libboost-headers=1.84.0=ha770c72_2
|
| 147 |
+
- libbrotlicommon=1.0.9=h166bdaf_9
|
| 148 |
+
- libbrotlidec=1.0.9=h166bdaf_9
|
| 149 |
+
- libbrotlienc=1.0.9=h166bdaf_9
|
| 150 |
+
- libclang=14.0.6=default_hc6dbbc7_1
|
| 151 |
+
- libclang13=14.0.6=default_he11475f_1
|
| 152 |
+
- libcups=2.3.3=h36d4200_3
|
| 153 |
+
- libcurl=8.5.0=h251f7ec_0
|
| 154 |
+
- libdeflate=1.17=h5eee18b_1
|
| 155 |
+
- libedit=3.1.20230828=h5eee18b_0
|
| 156 |
+
- libev=4.33=hd590300_2
|
| 157 |
+
- libevent=2.1.10=h28343ad_4
|
| 158 |
+
- libexpat=2.6.2=h59595ed_0
|
| 159 |
+
- libffi=3.4.2=h7f98852_5
|
| 160 |
+
- libgcc-ng=13.2.0=h807b86a_5
|
| 161 |
+
- libgfortran-ng=13.2.0=h69a702a_5
|
| 162 |
+
- libgfortran5=13.2.0=ha4646dd_5
|
| 163 |
+
- libglib=2.80.0=hf2295e7_1
|
| 164 |
+
- libgomp=13.2.0=h807b86a_5
|
| 165 |
+
- libhwloc=2.9.3=default_h554bfaf_1009
|
| 166 |
+
- libiconv=1.17=hd590300_2
|
| 167 |
+
- libllvm14=14.0.6=hcd5def8_4
|
| 168 |
+
- libnghttp2=1.58.0=h47da74e_1
|
| 169 |
+
- libnsl=2.0.1=hd590300_0
|
| 170 |
+
- libpng=1.6.43=h2797004_0
|
| 171 |
+
- libpq=12.17=hdbd6064_0
|
| 172 |
+
- libprotobuf=3.20.3=h3eb15da_0
|
| 173 |
+
- libsodium=1.0.18=h36c2ea0_1
|
| 174 |
+
- libsqlite=3.45.2=h2797004_0
|
| 175 |
+
- libssh2=1.11.0=h0841786_0
|
| 176 |
+
- libstdcxx-ng=13.2.0=h7e041cc_5
|
| 177 |
+
- libthrift=0.15.0=h362ad58_1
|
| 178 |
+
- libtiff=4.5.1=h6a678d5_0
|
| 179 |
+
- libuuid=2.38.1=h0b41bf4_0
|
| 180 |
+
- libwebp-base=1.3.2=hd590300_0
|
| 181 |
+
- libxcb=1.15=h0b41bf4_0
|
| 182 |
+
- libxcrypt=4.4.36=hd590300_1
|
| 183 |
+
- libxkbcommon=1.7.0=h662e7e4_0
|
| 184 |
+
- libxml2=2.12.6=h232c23b_1
|
| 185 |
+
- libzlib=1.2.13=hd590300_5
|
| 186 |
+
- llvm-openmp=18.1.2=h4dfa4b3_0
|
| 187 |
+
- lz4-c=1.9.4=hcb278e6_0
|
| 188 |
+
- marisa-trie=1.1.0=py39h3d6467e_1
|
| 189 |
+
- markdown=3.6=pyhd8ed1ab_0
|
| 190 |
+
- markdown-it-py=3.0.0=pyhd8ed1ab_0
|
| 191 |
+
- markupsafe=2.1.5=py39hd1e30aa_0
|
| 192 |
+
- matplotlib=3.8.3=py39hf3d152e_0
|
| 193 |
+
- matplotlib-base=3.8.3=py39he9076e7_0
|
| 194 |
+
- matplotlib-inline=0.1.6=pyhd8ed1ab_0
|
| 195 |
+
- mdurl=0.1.2=pyhd8ed1ab_0
|
| 196 |
+
- mistune=3.0.2=pyhd8ed1ab_0
|
| 197 |
+
- mkl=2023.1.0=h213fc3f_46344
|
| 198 |
+
- mkl-service=2.4.0=py39h5eee18b_1
|
| 199 |
+
- mkl_fft=1.3.8=py39h5eee18b_0
|
| 200 |
+
- mkl_random=1.2.4=py39hdb19cb5_0
|
| 201 |
+
- mpc=1.3.1=hfe3b2da_0
|
| 202 |
+
- mpfr=4.2.1=h9458935_0
|
| 203 |
+
- mpmath=1.3.0=pyhd8ed1ab_0
|
| 204 |
+
- multidict=6.0.5=py39hd1e30aa_0
|
| 205 |
+
- multiprocess=0.70.16=py39hd1e30aa_0
|
| 206 |
+
- munkres=1.1.4=pyh9f0ad1d_0
|
| 207 |
+
- murmurhash=1.0.10=py39h3d6467e_1
|
| 208 |
+
- mysql=5.7.24=h721c034_2
|
| 209 |
+
- nbclient=0.10.0=pyhd8ed1ab_0
|
| 210 |
+
- nbconvert-core=7.16.4=pyhd8ed1ab_1
|
| 211 |
+
- nbformat=5.10.4=pyhd8ed1ab_0
|
| 212 |
+
- ncurses=6.4.20240210=h59595ed_0
|
| 213 |
+
- nest-asyncio=1.6.0=pyhd8ed1ab_0
|
| 214 |
+
- networkx=3.2.1=pyhd8ed1ab_0
|
| 215 |
+
- ninja=1.11.1=h924138e_0
|
| 216 |
+
- notebook=7.2.1=pyhd8ed1ab_0
|
| 217 |
+
- notebook-shim=0.2.4=pyhd8ed1ab_0
|
| 218 |
+
- numpy=1.26.4=py39h5f9d8c6_0
|
| 219 |
+
- numpy-base=1.26.4=py39hb5e798b_0
|
| 220 |
+
- openjpeg=2.4.0=h3ad879b_0
|
| 221 |
+
- openssl=3.3.1=h4bc722e_2
|
| 222 |
+
- orc=1.7.4=hb3bc3d3_1
|
| 223 |
+
- overrides=7.7.0=pyhd8ed1ab_0
|
| 224 |
+
- packaging=24.0=pyhd8ed1ab_0
|
| 225 |
+
- pandas=1.4.2=py39h1832856_2
|
| 226 |
+
- pandocfilters=1.5.0=pyhd8ed1ab_0
|
| 227 |
+
- parso=0.8.3=pyhd8ed1ab_0
|
| 228 |
+
- pcre2=10.43=hcad00b1_0
|
| 229 |
+
- pexpect=4.9.0=pyhd8ed1ab_0
|
| 230 |
+
- pickleshare=0.7.5=py_1003
|
| 231 |
+
- pillow=10.2.0=py39h5eee18b_0
|
| 232 |
+
- pip=24.0=pyhd8ed1ab_0
|
| 233 |
+
- pkgutil-resolve-name=1.3.10=pyhd8ed1ab_1
|
| 234 |
+
- platformdirs=4.2.0=pyhd8ed1ab_0
|
| 235 |
+
- ply=3.11=py_1
|
| 236 |
+
- preshed=3.0.9=py39h3d6467e_1
|
| 237 |
+
- prometheus_client=0.20.0=pyhd8ed1ab_0
|
| 238 |
+
- prompt-toolkit=3.0.42=pyha770c72_0
|
| 239 |
+
- prompt_toolkit=3.0.42=hd8ed1ab_0
|
| 240 |
+
- protobuf=3.20.3=py39h227be39_1
|
| 241 |
+
- psutil=5.9.8=py39hd1e30aa_0
|
| 242 |
+
- pthread-stubs=0.4=h36c2ea0_1001
|
| 243 |
+
- ptyprocess=0.7.0=pyhd3deb0d_0
|
| 244 |
+
- pure_eval=0.2.2=pyhd8ed1ab_0
|
| 245 |
+
- pyarrow=14.0.2=py39h1eedbd7_0
|
| 246 |
+
- pyarrow-hotfix=0.6=pyhd8ed1ab_0
|
| 247 |
+
- pycparser=2.21=pyhd8ed1ab_0
|
| 248 |
+
- pydantic=2.8.2=pyhd8ed1ab_0
|
| 249 |
+
- pydantic-core=2.20.1=py39h5cde264_0
|
| 250 |
+
- pygments=2.17.2=pyhd8ed1ab_0
|
| 251 |
+
- pyparsing=3.1.2=pyhd8ed1ab_0
|
| 252 |
+
- pyqt=5.15.10=py39h6a678d5_0
|
| 253 |
+
- pyqt5-sip=12.13.0=py39h5eee18b_0
|
| 254 |
+
- pysocks=1.7.1=pyha2e5f31_6
|
| 255 |
+
- python=3.9.19=h0755675_0_cpython
|
| 256 |
+
- python-dateutil=2.9.0=pyhd8ed1ab_0
|
| 257 |
+
- python-fastjsonschema=2.20.0=pyhd8ed1ab_0
|
| 258 |
+
- python-json-logger=2.0.7=pyhd8ed1ab_0
|
| 259 |
+
- python-tzdata=2024.1=pyhd8ed1ab_0
|
| 260 |
+
- python-xxhash=3.4.1=py39hd1e30aa_0
|
| 261 |
+
- python_abi=3.9=4_cp39
|
| 262 |
+
- pytorch=2.2.0=cpu_py39hdc00b08_0
|
| 263 |
+
- pytz=2024.1=pyhd8ed1ab_0
|
| 264 |
+
- pyyaml=6.0.1=py39hd1e30aa_1
|
| 265 |
+
- pyzmq=25.1.2=py39h8c080ef_0
|
| 266 |
+
- qt-main=5.15.2=h53bd1ea_10
|
| 267 |
+
- re2=2022.04.01=h27087fc_0
|
| 268 |
+
- readline=8.2=h8228510_1
|
| 269 |
+
- referencing=0.35.1=pyhd8ed1ab_0
|
| 270 |
+
- regex=2023.12.25=py39hd1e30aa_0
|
| 271 |
+
- requests=2.31.0=pyhd8ed1ab_0
|
| 272 |
+
- rfc3339-validator=0.1.4=pyhd8ed1ab_0
|
| 273 |
+
- rfc3986-validator=0.1.1=pyh9f0ad1d_0
|
| 274 |
+
- rich=13.7.1=pyhd8ed1ab_0
|
| 275 |
+
- rpds-py=0.18.1=py39ha68c5e3_0
|
| 276 |
+
- s2n=1.3.27=hdbd6064_0
|
| 277 |
+
- safetensors=0.4.2=py39h9fdd4d6_0
|
| 278 |
+
- scikit-learn=1.5.1=py39hf7b0125_0
|
| 279 |
+
- scipy=1.11.3=py39h5f9d8c6_0
|
| 280 |
+
- send2trash=1.8.3=pyh0d859eb_0
|
| 281 |
+
- setuptools=69.2.0=pyhd8ed1ab_0
|
| 282 |
+
- shellingham=1.5.4=pyhd8ed1ab_0
|
| 283 |
+
- sip=6.7.12=py39h3d6467e_0
|
| 284 |
+
- six=1.16.0=pyh6c4a22f_0
|
| 285 |
+
- smart-open=7.0.4=hd8ed1ab_0
|
| 286 |
+
- smart_open=7.0.4=pyhd8ed1ab_0
|
| 287 |
+
- snappy=1.1.10=h9fff704_0
|
| 288 |
+
- sniffio=1.3.1=pyhd8ed1ab_0
|
| 289 |
+
- soupsieve=2.5=pyhd8ed1ab_1
|
| 290 |
+
- spacy=3.7.5=py39h95fdab5_0
|
| 291 |
+
- spacy-legacy=3.0.12=pyhd8ed1ab_0
|
| 292 |
+
- spacy-loggers=1.0.5=pyhd8ed1ab_0
|
| 293 |
+
- sqlite=3.45.2=h2c6b66d_0
|
| 294 |
+
- srsly=2.4.8=py39h3d6467e_1
|
| 295 |
+
- stack_data=0.6.2=pyhd8ed1ab_0
|
| 296 |
+
- sympy=1.12=pypyh9d50eac_103
|
| 297 |
+
- tbb=2021.11.0=h00ab1b0_1
|
| 298 |
+
- tensorboard=2.17.0=pyhd8ed1ab_0
|
| 299 |
+
- tensorboard-data-server=0.7.0=py39hd4f0224_1
|
| 300 |
+
- terminado=0.18.1=pyh0d859eb_0
|
| 301 |
+
- thinc=8.2.3=py39he5d7314_0
|
| 302 |
+
- threadpoolctl=3.5.0=pyhc1e730c_0
|
| 303 |
+
- tinycss2=1.3.0=pyhd8ed1ab_0
|
| 304 |
+
- tk=8.6.13=noxft_h4845f30_101
|
| 305 |
+
- tomli=2.0.1=pyhd8ed1ab_0
|
| 306 |
+
- torchvision=0.14.1=cpu_py39hcda3413_0
|
| 307 |
+
- tornado=6.4=py39hd1e30aa_0
|
| 308 |
+
- tqdm=4.66.2=pyhd8ed1ab_0
|
| 309 |
+
- traitlets=5.14.2=pyhd8ed1ab_0
|
| 310 |
+
- typer=0.12.3=pyhd8ed1ab_0
|
| 311 |
+
- typer-slim=0.12.3=pyhd8ed1ab_0
|
| 312 |
+
- typer-slim-standard=0.12.3=hd8ed1ab_0
|
| 313 |
+
- types-python-dateutil=2.9.0.20240316=pyhd8ed1ab_0
|
| 314 |
+
- typing-extensions=4.10.0=hd8ed1ab_0
|
| 315 |
+
- typing_extensions=4.10.0=pyha770c72_0
|
| 316 |
+
- typing_utils=0.1.0=pyhd8ed1ab_0
|
| 317 |
+
- tzdata=2024a=h0c530f3_0
|
| 318 |
+
- unicodedata2=15.1.0=py39hd1e30aa_0
|
| 319 |
+
- uri-template=1.3.0=pyhd8ed1ab_0
|
| 320 |
+
- urllib3=2.2.1=pyhd8ed1ab_0
|
| 321 |
+
- utf8proc=2.6.1=h5eee18b_1
|
| 322 |
+
- wasabi=1.1.2=py39hf3d152e_1
|
| 323 |
+
- wcwidth=0.2.13=pyhd8ed1ab_0
|
| 324 |
+
- weasel=0.4.1=pyhd8ed1ab_1
|
| 325 |
+
- webcolors=24.6.0=pyhd8ed1ab_0
|
| 326 |
+
- webencodings=0.5.1=pyhd8ed1ab_2
|
| 327 |
+
- websocket-client=1.8.0=pyhd8ed1ab_0
|
| 328 |
+
- werkzeug=3.0.1=pyhd8ed1ab_0
|
| 329 |
+
- wheel=0.43.0=pyhd8ed1ab_0
|
| 330 |
+
- wrapt=1.16.0=py39hd1e30aa_0
|
| 331 |
+
- xkeyboard-config=2.41=hd590300_0
|
| 332 |
+
- xorg-kbproto=1.0.7=h7f98852_1002
|
| 333 |
+
- xorg-libx11=1.8.7=h8ee46fc_0
|
| 334 |
+
- xorg-libxau=1.0.11=hd590300_0
|
| 335 |
+
- xorg-libxdmcp=1.1.3=h7f98852_0
|
| 336 |
+
- xorg-xextproto=7.3.0=h0b41bf4_1003
|
| 337 |
+
- xorg-xproto=7.0.31=h7f98852_1007
|
| 338 |
+
- xxhash=0.8.2=hd590300_0
|
| 339 |
+
- xz=5.4.6=h5eee18b_0
|
| 340 |
+
- yaml=0.2.5=h7f98852_2
|
| 341 |
+
- yarl=1.9.4=py39hd1e30aa_0
|
| 342 |
+
- zeromq=4.3.5=h59595ed_1
|
| 343 |
+
- zipp=3.17.0=pyhd8ed1ab_0
|
| 344 |
+
- zlib=1.2.13=hd590300_5
|
| 345 |
+
- zstd=1.5.5=hfc55251_0
|
| 346 |
+
prefix: /storage/home/hcoda1/3/bxia34/.conda/envs/diffusers
|
frontera_diffusion.sbatch
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/bin/bash
|
| 2 |
+
#SBATCH -J diffusion # Job name
|
| 3 |
+
#SBATCH -p rtx-dev
|
| 4 |
+
#SBATCH -N2 # Number of nodes and cores per node required
|
| 5 |
+
#SBATCH --ntasks-per-node=1
|
| 6 |
+
#SBATCH -t 02:00:00 # Duration of the job (Ex: 15 mins)
|
| 7 |
+
#SBATCH -oReport-%j # Combined output and error messages file
|
| 8 |
+
#SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
|
| 9 |
+
|
| 10 |
+
python -c "import torch; print('torch.cuda.is_available() =', torch.cuda.is_available()); print('torch.__version__ =', torch.__version__); print('torch.version.cuda =', torch.version.cuda)"
|
| 11 |
+
pwd
|
| 12 |
+
date
|
| 13 |
+
#module load anaconda3/2022.05 # Load module dependencies
|
| 14 |
+
#module load pytorch
|
| 15 |
+
#conda activate diffusers
|
| 16 |
+
conda env list
|
| 17 |
+
module list
|
| 18 |
+
cat $0
|
| 19 |
+
|
| 20 |
+
MASTER_ADDR=$(scontrol show hostnames $SLURM_JOB_NODELIST | head -n 1)
|
| 21 |
+
MASTER_PORT=$((10000 + RANDOM % 10000)) #12355
|
| 22 |
+
|
| 23 |
+
export MASTER_ADDR=$MASTER_ADDR
|
| 24 |
+
export MASTER_PORT=$MASTER_PORT
|
| 25 |
+
|
| 26 |
+
srun python diffusion.py \
|
| 27 |
+
--train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
|
| 28 |
+
--resume outputs/model-N2000-device_count1-node8-epoch19-19004529 \
|
| 29 |
+
--num_new_img_per_gpu 50 \
|
| 30 |
+
--max_num_img_per_gpu 2 \
|
| 31 |
+
--gradient_accumulation_steps 40 \
|
| 32 |
+
######################################################################################
|
| 33 |
+
|
load_h5.py
CHANGED
|
@@ -44,6 +44,7 @@ class Dataset4h5(Dataset):
|
|
| 44 |
transform=True,
|
| 45 |
ranges_dict=None,
|
| 46 |
num_workers=len(os.sched_getaffinity(0))//torch.cuda.device_count(),
|
|
|
|
| 47 |
# shuffle=False,
|
| 48 |
):
|
| 49 |
super().__init__()
|
|
@@ -59,7 +60,7 @@ class Dataset4h5(Dataset):
|
|
| 59 |
self.dim = dim
|
| 60 |
self.transform = transform
|
| 61 |
self.num_workers = num_workers
|
| 62 |
-
|
| 63 |
# if ranges_dict == None:
|
| 64 |
# ranges_dict = dict(
|
| 65 |
# images = {
|
|
@@ -156,10 +157,10 @@ class Dataset4h5(Dataset):
|
|
| 156 |
with h5py.File(self.dir_name, 'r') as f:
|
| 157 |
images_start = time()
|
| 158 |
if self.dim == 2:
|
| 159 |
-
images = f[self.field][idx,0
|
| 160 |
# images = f[self.field][idx,:self.HII_DIM,:self.HII_DIM,-3][:,None]
|
| 161 |
elif self.dim == 3:
|
| 162 |
-
images = f[self.field][idx
|
| 163 |
images_end = time()
|
| 164 |
# print(f"pid {pid}: images of shape {images.shape} loaded after {load_end-load_start:.3f} s")
|
| 165 |
pid = os.getpid()
|
|
|
|
| 44 |
transform=True,
|
| 45 |
ranges_dict=None,
|
| 46 |
num_workers=len(os.sched_getaffinity(0))//torch.cuda.device_count(),
|
| 47 |
+
startat=0,
|
| 48 |
# shuffle=False,
|
| 49 |
):
|
| 50 |
super().__init__()
|
|
|
|
| 60 |
self.dim = dim
|
| 61 |
self.transform = transform
|
| 62 |
self.num_workers = num_workers
|
| 63 |
+
self.startat = startat
|
| 64 |
# if ranges_dict == None:
|
| 65 |
# ranges_dict = dict(
|
| 66 |
# images = {
|
|
|
|
| 157 |
with h5py.File(self.dir_name, 'r') as f:
|
| 158 |
images_start = time()
|
| 159 |
if self.dim == 2:
|
| 160 |
+
images = f[self.field][idx, 0, :self.HII_DIM, self.startat:self.startat+self.num_redshift][:,None]
|
| 161 |
# images = f[self.field][idx,:self.HII_DIM,:self.HII_DIM,-3][:,None]
|
| 162 |
elif self.dim == 3:
|
| 163 |
+
images = f[self.field][idx, :self.HII_DIM, :self.HII_DIM, self.startat:self.startat+self.num_redshift][:,None]
|
| 164 |
images_end = time()
|
| 165 |
# print(f"pid {pid}: images of shape {images.shape} loaded after {load_end-load_start:.3f} s")
|
| 166 |
pid = os.getpid()
|
phoenix_diffusion.sbatch
CHANGED
|
@@ -2,10 +2,10 @@
|
|
| 2 |
#SBATCH -J diffusion # Job name
|
| 3 |
#SBATCH -A gts-jw254-coda20
|
| 4 |
#SBATCH -qembers
|
| 5 |
-
#SBATCH -
|
| 6 |
#SBATCH --ntasks-per-node=1
|
| 7 |
#SBATCH --mem-per-gpu=16G # Memory per GPU
|
| 8 |
-
#SBATCH -t
|
| 9 |
#SBATCH -oReport-%j # Combined output and error messages file
|
| 10 |
#SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
|
| 11 |
|
|
@@ -29,8 +29,8 @@ export MASTER_ADDR=$MASTER_ADDR
|
|
| 29 |
export MASTER_PORT=$MASTER_PORT
|
| 30 |
|
| 31 |
srun python diffusion.py \
|
| 32 |
-
--train
|
| 33 |
-
--resume outputs/model-N2000-device_count1-node8-epoch19-
|
| 34 |
--num_new_img_per_gpu 50 \
|
| 35 |
--max_num_img_per_gpu 2 \
|
| 36 |
--gradient_accumulation_steps 40 \
|
|
|
|
| 2 |
#SBATCH -J diffusion # Job name
|
| 3 |
#SBATCH -A gts-jw254-coda20
|
| 4 |
#SBATCH -qembers
|
| 5 |
+
#SBATCH -N1 --gpus-per-node=V100:1 -C V100-16GB # Number of nodes and cores per node required
|
| 6 |
#SBATCH --ntasks-per-node=1
|
| 7 |
#SBATCH --mem-per-gpu=16G # Memory per GPU
|
| 8 |
+
#SBATCH -t 02:00:00 # Duration of the job (Ex: 15 mins)
|
| 9 |
#SBATCH -oReport-%j # Combined output and error messages file
|
| 10 |
#SBATCH --mail-type=BEGIN,END,FAIL # Mail preferences
|
| 11 |
|
|
|
|
| 29 |
export MASTER_PORT=$MASTER_PORT
|
| 30 |
|
| 31 |
srun python diffusion.py \
|
| 32 |
+
--train 0 \
|
| 33 |
+
--resume outputs/model-N2000-device_count1-node8-epoch19-19004529 \
|
| 34 |
--num_new_img_per_gpu 50 \
|
| 35 |
--max_num_img_per_gpu 2 \
|
| 36 |
--gradient_accumulation_steps 40 \
|