haiphamcse commited on
Commit
f729117
·
verified ·
1 Parent(s): 549b0f9

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. train_unet.py +23 -1
  2. unet_cifar.yaml +1 -1
train_unet.py CHANGED
@@ -83,6 +83,13 @@ def parse_args() -> argparse.Namespace:
83
  help="YAML with UNet + CFM hyperparameters (default: unet_config.yaml next to this script)",
84
  )
85
 
 
 
 
 
 
 
 
86
  return p.parse_args()
87
 
88
 
@@ -257,13 +264,28 @@ def main() -> None:
257
  [
258
  v2.ToTensor(),
259
  v2.ToDtype(torch.float32, scale=True),
260
- v2.Resize((64,64)),
261
  v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
262
  ]
263
  )
264
  train_dataset = load_training_dataset(args, transforms)
265
  print(f"Dataset: {args.dataset}, size={len(train_dataset)}")
266
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
267
  dummy_dataloader = DataLoader(
268
  train_dataset,
269
  batch_size=args.batch_size,
 
83
  help="YAML with UNet + CFM hyperparameters (default: unet_config.yaml next to this script)",
84
  )
85
 
86
+ p.add_argument(
87
+ "--data-percent",
88
+ type=int,
89
+ default=100,
90
+ choices=[10, 20, 30, 60, 80, 100],
91
+ help="Use only this percentage of the (possibly filtered) training dataset.",
92
+ )
93
  return p.parse_args()
94
 
95
 
 
264
  [
265
  v2.ToTensor(),
266
  v2.ToDtype(torch.float32, scale=True),
267
+ v2.Resize((32,32)),
268
  v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
269
  ]
270
  )
271
  train_dataset = load_training_dataset(args, transforms)
272
  print(f"Dataset: {args.dataset}, size={len(train_dataset)}")
273
 
274
+ orig_len = len(train_dataset)
275
+ if args.data_percent < 100:
276
+ new_len = max(1, int(orig_len * args.data_percent / 100.0))
277
+
278
+ g = torch.Generator()
279
+ g.manual_seed(args.seed)
280
+
281
+ perm = torch.randperm(orig_len, generator=g)
282
+ indices = perm[:new_len].tolist()
283
+ torch.save(perm[:new_len], os.path.join(args.save_dir, "indices.pt"))
284
+ train_dataset = Subset(train_dataset, indices)
285
+ print(f"Subsampled dataset: {args.data_percent}% -> {len(train_dataset)} samples")
286
+ else:
287
+ print(f"Using full dataset: {orig_len} samples")
288
+
289
  dummy_dataloader = DataLoader(
290
  train_dataset,
291
  batch_size=args.batch_size,
unet_cifar.yaml CHANGED
@@ -11,7 +11,7 @@ weight_decay: 0.0
11
  # NeuralODE visualization / sampling
12
  save_ep: 30
13
  inference_steps: 100
14
- vis_batch_size: 4
15
 
16
  # UNet (torchcfm UNetModelWrapper)
17
  num_res_blocks: 2
 
11
  # NeuralODE visualization / sampling
12
  save_ep: 30
13
  inference_steps: 100
14
+ vis_batch_size: 8
15
 
16
  # UNet (torchcfm UNetModelWrapper)
17
  num_res_blocks: 2