Xsmos committed on
Commit
237a5ab
·
verified ·
1 Parent(s): 5caea8b
diffusion.py CHANGED
@@ -438,14 +438,13 @@ class DDPM21CM:
438
 
439
  del dataset
440
 
441
- def transform(self, img):
442
- if self.config.dim == 3:
443
- #flip along x or y or both
444
- flip_xy = [i+2 for i in range(2) if getrandbits(1)]
445
- img = torch.flip(img, dims=flip_xy)
446
- # flip diagonally
447
- if getrandbits(1):
448
- img = img.transpose(2,3)
449
  return img
450
 
451
  def train(self):
@@ -496,7 +495,10 @@ class DDPM21CM:
496
  pbar_train.set_description(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} Epoch {ep}")
497
  epoch_start = time()
498
  for i, (x, c) in enumerate(self.dataloader):
499
- x = self.transform(x)
 
 
 
500
  x = x.to(self.config.device)#.to(self.config.dtype)
501
  # autocast forward propogation
502
  with autocast(enabled=self.config.autocast):
 
438
 
439
  del dataset
440
 
441
+ def transform(self, img, idx):
442
+ #flip along x or y or both
443
+ flip_xy = [i+1 for i in range(2) if getrandbits(1)]
444
+ img[idx] = torch.flip(img[idx], dims=flip_xy)
445
+ # flip diagonally
446
+ if getrandbits(1):
447
+ img[idx] = img[idx].clone().transpose(1,2)
 
448
  return img
449
 
450
  def train(self):
 
495
  pbar_train.set_description(f"{socket.gethostbyname(socket.gethostname())} cuda:{torch.cuda.current_device()}/{self.config.global_rank} Epoch {ep}")
496
  epoch_start = time()
497
  for i, (x, c) in enumerate(self.dataloader):
498
+ if self.config.dim == 3:
499
+ for idx in range(len(x)):
500
+ x = self.transform(x, idx)
501
+
502
  x = x.to(self.config.device)#.to(self.config.dtype)
503
  # autocast forward propogation
504
  with autocast(enabled=self.config.autocast):
perlmutter_diffusion.sbatch CHANGED
@@ -5,7 +5,7 @@
5
  #SBATCH -q regular #shared
6
  #SBATCH -N4
7
  #SBATCH --gpus-per-node=4
8
- #SBATCH -t 16:00:00
9
  #SBATCH --ntasks-per-node=1
10
  #SBATCH -oReport-%j
11
  #SBATCH --mail-type=BEGIN,END,FAIL
@@ -30,16 +30,16 @@ cat $0
30
  srun python diffusion.py \
31
  --num_image 1600 \
32
  --batch_size 2 \
33
- --n_epoch 20 \
34
  --channel_mult 0.5 1 2 4 4 8 \
35
  --num_new_img_per_gpu 4 \
36
  --max_num_img_per_gpu 2 \
37
  --gradient_accumulation_steps 10 \
38
  --autocast 1 \
39
  --use_checkpoint 1 \
40
- --dropout 0 \
41
  --lrate 2e-5 \
42
  --train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
43
- #--resume ./outputs/model-N1600-device_count4-node4-epoch19-32096018 \
44
 
45
  date
 
5
  #SBATCH -q regular #shared
6
  #SBATCH -N4
7
  #SBATCH --gpus-per-node=4
8
+ #SBATCH -t 48:00:00
9
  #SBATCH --ntasks-per-node=1
10
  #SBATCH -oReport-%j
11
  #SBATCH --mail-type=BEGIN,END,FAIL
 
30
  srun python diffusion.py \
31
  --num_image 1600 \
32
  --batch_size 2 \
33
+ --n_epoch 60 \
34
  --channel_mult 0.5 1 2 4 4 8 \
35
  --num_new_img_per_gpu 4 \
36
  --max_num_img_per_gpu 2 \
37
  --gradient_accumulation_steps 10 \
38
  --autocast 1 \
39
  --use_checkpoint 1 \
40
+ --dropout 0.2 \
41
  --lrate 2e-5 \
42
  --train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
43
+ #--resume ./outputs/model-N1600-device_count4-node4-epoch19-32185426 \
44
 
45
  date
quantify_results.ipynb CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:076520e2ea10edaa431fab43004103361a160e2b900aa59b3baa114b0aaa5773
3
- size 24213875
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6691f0135f3bb373506e6090511c5527f23e0b4dc780f031bf80ca6d141e32ca
3
+ size 25396988
tensorboard.ipynb CHANGED
@@ -23,13 +23,13 @@
23
  "data": {
24
  "text/html": [
25
  "\n",
26
- " <iframe id=\"tensorboard-frame-54a74258cbb72d6\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
27
  " </iframe>\n",
28
  " <script>\n",
29
  " (function() {\n",
30
- " const frame = document.getElementById(\"tensorboard-frame-54a74258cbb72d6\");\n",
31
  " const url = new URL(\"/\", window.location);\n",
32
- " const port = 42029;\n",
33
  " if (port) {\n",
34
  " url.port = port;\n",
35
  " }\n",
@@ -59,7 +59,7 @@
59
  {
60
  "data": {
61
  "text/html": [
62
- "<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/42029/\">https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/42029/</a>"
63
  ],
64
  "text/plain": [
65
  "<IPython.core.display.HTML object>"
 
23
  "data": {
24
  "text/html": [
25
  "\n",
26
+ " <iframe id=\"tensorboard-frame-b3fe77206bcde3f5\" width=\"100%\" height=\"800\" frameborder=\"0\">\n",
27
  " </iframe>\n",
28
  " <script>\n",
29
  " (function() {\n",
30
+ " const frame = document.getElementById(\"tensorboard-frame-b3fe77206bcde3f5\");\n",
31
  " const url = new URL(\"/\", window.location);\n",
32
+ " const port = 33553;\n",
33
  " if (port) {\n",
34
  " url.port = port;\n",
35
  " }\n",
 
59
  {
60
  "data": {
61
  "text/html": [
62
+ "<a href=\"https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/33553/\">https://jupyter.nersc.gov/user/binxia/perlmutter-login-node-base/proxy/33553/</a>"
63
  ],
64
  "text/plain": [
65
  "<IPython.core.display.HTML object>"