Upload folder using huggingface_hub
Browse files
- train_unet.py +23 -1
- unet_cifar.yaml +1 -1
train_unet.py
CHANGED
|
@@ -83,6 +83,13 @@ def parse_args() -> argparse.Namespace:
|
|
| 83 |
help="YAML with UNet + CFM hyperparameters (default: unet_config.yaml next to this script)",
|
| 84 |
)
|
| 85 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 86 |
return p.parse_args()
|
| 87 |
|
| 88 |
|
|
@@ -257,13 +264,28 @@ def main() -> None:
|
|
| 257 |
[
|
| 258 |
v2.ToTensor(),
|
| 259 |
v2.ToDtype(torch.float32, scale=True),
|
| 260 |
-
v2.Resize((
|
| 261 |
v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
|
| 262 |
]
|
| 263 |
)
|
| 264 |
train_dataset = load_training_dataset(args, transforms)
|
| 265 |
print(f"Dataset: {args.dataset}, size={len(train_dataset)}")
|
| 266 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 267 |
dummy_dataloader = DataLoader(
|
| 268 |
train_dataset,
|
| 269 |
batch_size=args.batch_size,
|
|
|
|
| 83 |
help="YAML with UNet + CFM hyperparameters (default: unet_config.yaml next to this script)",
|
| 84 |
)
|
| 85 |
|
| 86 |
+
p.add_argument(
|
| 87 |
+
"--data-percent",
|
| 88 |
+
type=int,
|
| 89 |
+
default=100,
|
| 90 |
+
choices=[10, 20, 30, 60, 80, 100],
|
| 91 |
+
help="Use only this percentage of the (possibly filtered) training dataset.",
|
| 92 |
+
)
|
| 93 |
return p.parse_args()
|
| 94 |
|
| 95 |
|
|
|
|
| 264 |
[
|
| 265 |
v2.ToTensor(),
|
| 266 |
v2.ToDtype(torch.float32, scale=True),
|
| 267 |
+
v2.Resize((32,32)),
|
| 268 |
v2.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0]),
|
| 269 |
]
|
| 270 |
)
|
| 271 |
train_dataset = load_training_dataset(args, transforms)
|
| 272 |
print(f"Dataset: {args.dataset}, size={len(train_dataset)}")
|
| 273 |
|
| 274 |
+
orig_len = len(train_dataset)
|
| 275 |
+
if args.data_percent < 100:
|
| 276 |
+
new_len = max(1, int(orig_len * args.data_percent / 100.0))
|
| 277 |
+
|
| 278 |
+
g = torch.Generator()
|
| 279 |
+
g.manual_seed(args.seed)
|
| 280 |
+
|
| 281 |
+
perm = torch.randperm(orig_len, generator=g)
|
| 282 |
+
indices = perm[:new_len].tolist()
|
| 283 |
+
torch.save(perm[:new_len], os.path.join(args.save_dir, "indices.pt"))
|
| 284 |
+
train_dataset = Subset(train_dataset, indices)
|
| 285 |
+
print(f"Subsampled dataset: {args.data_percent}% -> {len(train_dataset)} samples")
|
| 286 |
+
else:
|
| 287 |
+
print(f"Using full dataset: {orig_len} samples")
|
| 288 |
+
|
| 289 |
dummy_dataloader = DataLoader(
|
| 290 |
train_dataset,
|
| 291 |
batch_size=args.batch_size,
|
unet_cifar.yaml
CHANGED
|
@@ -11,7 +11,7 @@ weight_decay: 0.0
|
|
| 11 |
# NeuralODE visualization / sampling
|
| 12 |
save_ep: 30
|
| 13 |
inference_steps: 100
|
| 14 |
-
vis_batch_size:
|
| 15 |
|
| 16 |
# UNet (torchcfm UNetModelWrapper)
|
| 17 |
num_res_blocks: 2
|
|
|
|
| 11 |
# NeuralODE visualization / sampling
|
| 12 |
save_ep: 30
|
| 13 |
inference_steps: 100
|
| 14 |
+
vis_batch_size: 8
|
| 15 |
|
| 16 |
# UNet (torchcfm UNetModelWrapper)
|
| 17 |
num_res_blocks: 2
|