Xsmos committed (verified)
Commit 733f17e · Parent(s): 32187db
context_unet.py CHANGED
@@ -127,12 +127,19 @@ class TimestepBlock(ABC, nn.Module):
     """
 
 class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+    def __init__(self, *args, use_checkpoint=False):
+        super().__init__(*args)
+        self.use_checkpoint = use_checkpoint
+
     def forward(self, x, emb, encoder_out=None):
         for layer in self:
             if isinstance(layer, TimestepBlock):
                 x = layer(x, emb)
             elif isinstance(layer, AttentionBlock):
                 x = layer(x, encoder_out)
+            elif self.use_checkpoint and isinstance(layer, tuple(Conv.values())):
+                print(f"TimestepEmbedSequential checkpoint working for layer {type(layer)}")
+                x = checkpoint.checkpoint(layer, x)
             else:
                 x = layer(x)
         return x
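The new branch wraps bare convolution layers in activation (gradient) checkpointing. Below is a minimal sketch of the pattern, assuming `Conv` is a module-level dict mapping dimensionality to conv classes (which the `tuple(Conv.values())` test suggests) and `checkpoint` is `torch.utils.checkpoint`; the registry shown is illustrative, not the repository's actual definition:

import torch
import torch.nn as nn
from torch.utils import checkpoint

# Assumed shape of the Conv registry the diff's isinstance() test refers to.
Conv = {1: nn.Conv1d, 2: nn.Conv2d, 3: nn.Conv3d}

layer = Conv[3](in_channels=8, out_channels=8, kernel_size=3, padding=1)
x = torch.randn(1, 8, 16, 16, 16, requires_grad=True)

# checkpoint() discards the layer's intermediate activations in the
# forward pass and recomputes them during backward, trading extra
# compute for a much smaller activation-memory footprint.
y = checkpoint.checkpoint(layer, x, use_reentrant=False)
y.sum().backward()

Passing `use_reentrant=False` explicitly avoids the deprecation warning recent PyTorch releases emit for the bare two-argument call used in the diff. Note also that the `print` in the new branch fires once per checkpointed layer on every forward pass, so it reads as temporary debug output.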
diffusion.py CHANGED
@@ -241,8 +241,8 @@ class TrainConfig:
     world_size = 1#torch.cuda.device_count()
     # repeat = 2
 
-    dim = 2
-    #dim = 3#2
+    #dim = 2
+    dim = 3#2
     stride = (2,4) if dim == 2 else (2,2,4)
     num_image = 32#0#0#640#320#6400#3000#480#1200#120#3000#300#3000#6000#30#60#6000#1000#2000#20000#15000#7000#25600#3000#10000#1000#10000#5000#2560#800#2560
     batch_size = 1#1#10#50#10#50#20#50#1#2#50#20#2#100 # 10
@@ -296,7 +296,7 @@ class TrainConfig:
     #mixed_precision = "no" #"fp16"
     gradient_accumulation_steps = 1
 
-    pbar_update_step = 20
+    #pbar_update_step = 20
 
     channel_mult = (1,2,2,2,4)
     # date = datetime.datetime.now().strftime("%m%d-%H%M")
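The first hunk flips training from 2-D to 3-D data. A tiny worked example of how the flag propagates, using only the config lines visible in this diff (the rest of TrainConfig is not shown here):

dim = 3
stride = (2, 4) if dim == 2 else (2, 2, 4)
print(stride)  # (2, 2, 4): one downsampling factor per spatial axis

Volumetric inputs also explain the conservative `batch_size = 1` and the checkpointing added in context_unet.py above. The second hunk merely comments out `pbar_update_step`, presumably to silence progress-bar updates in batch jobs.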
perlmutter_diffusion.sbatch CHANGED
@@ -2,10 +2,10 @@
 #SBATCH -A m4717
 #SBATCH -J diffusion
 #SBATCH -C gpu&hbm80g
-#SBATCH -q regular #shared
+#SBATCH -q shared #regular
 #SBATCH -N1
-#SBATCH --gpus-per-node=4
-#SBATCH -t 08:30:00
+#SBATCH --gpus-per-node=1
+#SBATCH -t 00:30:00
 #SBATCH --ntasks-per-node=1
 #SBATCH -oReport-%j
 #SBATCH --mail-type=BEGIN,END,FAIL
@@ -36,16 +36,16 @@ cat $0
 #nvidia-smi
 
 srun python diffusion.py \
-    --num_image 6400 \
-    --batch_size 128 \
-    --n_epoch 200 \
-    --num_new_img_per_gpu 20 \
-    --max_num_img_per_gpu 4 \
-    --channel_mult 0.5 1 2 2 4 8 \
+    --num_image 64 \
+    --batch_size 2 \
+    --n_epoch 50 \
+    --channel_mult 0.5 1 2 4 4 8 \
+    --num_new_img_per_gpu 800 \
+    --max_num_img_per_gpu 100 \
     --gradient_accumulation_steps 1 \
     --autocast 1 \
     --use_checkpoint 1 \
     --train "$SCRATCH/LEN128-DIM64-CUB16-Tvir[4, 6]-zeta[10, 250]-0809-123640.h5" \
-    #--resume ./outputs/model-N6400-device_count4-node1-epoch49-30143433 \
+    #--resume ./outputs/model-N6400-device_count4-node1-epoch199-05185634 \
 
 date
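Taken together, the scheduler edits turn the job from a production run (regular queue, 4 GPUs, 8.5 h) into a short single-GPU job on Perlmutter's shared queue (30 min), which allocates fractions of a node and suits quick tests; the `gpu&hbm80g` constraint still targets the 80 GB GPU nodes. The training run is scaled down to match (64 images, batch size 2, 50 epochs), the per-GPU image limits are raised, `channel_mult` is adjusted, and the commented-out `--resume` path now points at the epoch-199 checkpoint. To submit and monitor: `sbatch perlmutter_diffusion.sbatch`, then `squeue -u $USER`.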
quantify_results.ipynb CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:3618a302e39b53b9514cfc8bdee9a3b8e40e51565fb6ad99b2783ce2b89764cd
-size 14880549
+oid sha256:2f5c609710980f1c8798c5f4732afe3f28bce2a24799b0ef5028f1c9fef85a5d
+size 15711677