32257242
Browse files- diffusion.py +11 -9
- perlmutter_diffusion.sbatch +4 -4
- quantify_results.ipynb +2 -2
- tensorboard.ipynb +4 -4
diffusion.py
CHANGED
|
@@ -438,14 +438,13 @@ class DDPM21CM:
|
|
| 438 |
|
| 439 |
del dataset
|
| 440 |
|
| 441 |
-
def transform(self, img):
|
| 442 |
-
|
| 443 |
-
|
| 444 |
-
|
| 445 |
-
|
| 446 |
-
|
| 447 |
-
|
| 448 |
-
img = img.transpose(2,3)
|
| 449 |
return img
|
| 450 |
|
| 451 |
def train(self):
|
|
@@ -496,7 +495,10 @@ class DDPM21CM:
|
|
| 496 |
pbar_train.set_description(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} Epoch {ep}")
|
| 497 |
epoch_start = time()
|
| 498 |
for i, (x, c) in enumerate(self.dataloader):
|
| 499 |
-
|
|
|
|
|
|
|
|
|
|
| 500 |
x = x.to(self.config.device)#.to(self.config.dtype)
|
| 501 |
# autocast forward propogation
|
| 502 |
with autocast(enabled=self.config.autocast):
|
|
|
|
| 438 |
|
| 439 |
del dataset
|
| 440 |
|
| 441 |
+
def transform(self, img, idx):
|
| 442 |
+
#flip along x or y or both
|
| 443 |
+
flip_xy = [i+1 for i in range(2) if getrandbits(1)]
|
| 444 |
+
img[idx] = torch.flip(img[idx], dims=flip_xy)
|
| 445 |
+
# flip diagonally
|
| 446 |
+
if getrandbits(1):
|
| 447 |
+
img[idx] = img[idx].clone().transpose(1,2)
|
|
|
|
| 448 |
return img
|
| 449 |
|
| 450 |
def train(self):
|
|
|
|
| 495 |
pbar_train.set_description(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} Epoch {ep}")
|
| 496 |
epoch_start = time()
|
| 497 |
for i, (x, c) in enumerate(self.dataloader):
|
| 498 |
+
if self.config.dim == 3:
|
| 499 |
+
for idx in range(len(x)):
|
| 500 |
+
x = self.transform(x, idx)
|
| 501 |
+
|
| 502 |
x = x.to(self.config.device)#.to(self.config.dtype)
|
| 503 |
# autocast forward propogation
|
| 504 |
with autocast(enabled=self.config.autocast):
|
perlmutter_diffusion.sbatch
CHANGED
|
@@ -5,7 +5,7 @@
|
|
| 5 |
#SBATCH -q regular #shared
|
| 6 |
#SBATCH -N4
|
| 7 |
#SBATCH --gpus-per-node=4
|
| 8 |
-
#SBATCH -t
|
| 9 |
#SBATCH --ntasks-per-node=1
|
| 10 |
#SBATCH -oReport-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL
|
|
@@ -30,16 +30,16 @@ cat $0
|
|
| 30 |
srun python diffusion.py \
|
| 31 |
--num_image 1600 \
|
| 32 |
--batch_size 2 \
|
| 33 |
-
--n_epoch
|
| 34 |
--channel_mult 0.5 1 2 4 4 8 \
|
| 35 |
--num_new_img_per_gpu 4 \
|
| 36 |
--max_num_img_per_gpu 2 \
|
| 37 |
--gradient_accumulation_steps 10 \
|
| 38 |
--autocast 1 \
|
| 39 |
--use_checkpoint 1 \
|
| 40 |
-
--dropout 0 \
|
| 41 |
--lrate 2e-5 \
|
| 42 |
--train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
|
| 43 |
-
#--resume ./outputs/model-N1600-device_count4-node4-epoch19-
|
| 44 |
|
| 45 |
date
|
|
|
|
| 5 |
#SBATCH -q regular #shared
|
| 6 |
#SBATCH -N4
|
| 7 |
#SBATCH --gpus-per-node=4
|
| 8 |
+
#SBATCH -t 48:00:00
|
| 9 |
#SBATCH --ntasks-per-node=1
|
| 10 |
#SBATCH -oReport-%j
|
| 11 |
#SBATCH --mail-type=BEGIN,END,FAIL
|
|
|
|
| 30 |
srun python diffusion.py \
|
| 31 |
--num_image 1600 \
|
| 32 |
--batch_size 2 \
|
| 33 |
+
--n_epoch 60 \
|
| 34 |
--channel_mult 0.5 1 2 4 4 8 \
|
| 35 |
--num_new_img_per_gpu 4 \
|
| 36 |
--max_num_img_per_gpu 2 \
|
| 37 |
--gradient_accumulation_steps 10 \
|
| 38 |
--autocast 1 \
|
| 39 |
--use_checkpoint 1 \
|
| 40 |
+
--dropout 0.2 \
|
| 41 |
--lrate 2e-5 \
|
| 42 |
--train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
|
| 43 |
+
#--resume ./outputs/model-N1600-device_count4-node4-epoch19-32185426 \
|
| 44 |
|
| 45 |
date
|
quantify_results.ipynb
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
-
size
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:6691f0135f3bb373506e6090511c5527f23e0b4dc780f031bf80ca6d141e32ca
|
| 3 |
+
size 25396988
|
tensorboard.ipynb
CHANGED
|
@@ -23,13 +23,13 @@
|
|
| 23 |
"data": {
|
| 24 |
"text/html": [
|
| 25 |
"\n",
|
| 26 |
-
" <iframe id=\"tensorboard-frame-
|
| 27 |
" </iframe>\n",
|
| 28 |
" <script>\n",
|
| 29 |
" (function() {\n",
|
| 30 |
-
" const frame = document.getElementById(\"tensorboard-frame-
|
| 31 |
" const url = new URL(\"/\", window.location);\n",
|
| 32 |
-
" const port =
|
| 33 |
" if (port) {\n",
|
| 34 |
" url.port = port;\n",
|
| 35 |
" }\n",
|
|
@@ -59,7 +59,7 @@
|
|
| 59 |
{
|
| 60 |
"data": {
|
| 61 |
"text/html": [
|
| 62 |
-
"<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/
|
| 63 |
],
|
| 64 |
"text/plain": [
|
| 65 |
"<IPython.core.display.HTML object>"
|
|
|
|
| 23 |
"data": {
|
| 24 |
"text/html": [
|
| 25 |
"\n",
|
| 26 |
+
" <iframe id=\"tensorboard-frame-b3fe77206bcde3f5\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
|
| 27 |
" </iframe>\n",
|
| 28 |
" <script>\n",
|
| 29 |
" (function() {\n",
|
| 30 |
+
" const frame = document.getElementById(\"tensorboard-frame-b3fe77206bcde3f5\");\n",
|
| 31 |
" const url = new URL(\"/\", window.location);\n",
|
| 32 |
+
" const port = 33553;\n",
|
| 33 |
" if (port) {\n",
|
| 34 |
" url.port = port;\n",
|
| 35 |
" }\n",
|
|
|
|
| 59 |
{
|
| 60 |
"data": {
|
| 61 |
"text/html": [
|
| 62 |
+
"<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/33553/\">https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/33553/</a>"
|
| 63 |
],
|
| 64 |
"text/plain": [
|
| 65 |
"<IPython.core.display.HTML object>"
|