Fix: skip VQ-VAE if checkpoint exists, reduce epochs to 3
Browse files
- train_full_pipeline.py +6 -5
train_full_pipeline.py
CHANGED
|
@@ -46,7 +46,7 @@ PERSIST_DIR = os.path.join(DATA_DIR, "zeeb_checkpoints")
|
|
| 46 |
os.makedirs(PERSIST_DIR, exist_ok=True)
|
| 47 |
|
| 48 |
# VQ-VAE training
|
| 49 |
-
VQ_VAE_EPOCHS =
|
| 50 |
VQ_VAE_LR = 3e-4
|
| 51 |
VQ_VAE_BATCH = 8
|
| 52 |
VQ_VAE_IMG_SIZE = 128
|
|
@@ -352,14 +352,15 @@ def train_vq_vae(logger: Logger, state: PipelineState) -> VQVAE:
|
|
| 352 |
from torchvision import transforms
|
| 353 |
from PIL import Image
|
| 354 |
|
| 355 |
-
# Check if already done
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
ckpt_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
|
| 359 |
if os.path.exists(ckpt_path):
|
|
|
|
| 360 |
model = VQVAE()
|
| 361 |
model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=False))
|
| 362 |
logger.log("Loaded trained VQ-VAE from checkpoint.\n")
|
|
|
|
| 363 |
return model
|
| 364 |
else:
|
| 365 |
logger.log("Checkpoint not found, retraining...\n")
|
|
|
|
| 46 |
os.makedirs(PERSIST_DIR, exist_ok=True)
|
| 47 |
|
| 48 |
# VQ-VAE training
|
| 49 |
+
VQ_VAE_EPOCHS = 3
|
| 50 |
VQ_VAE_LR = 3e-4
|
| 51 |
VQ_VAE_BATCH = 8
|
| 52 |
VQ_VAE_IMG_SIZE = 128
|
|
|
|
| 352 |
from torchvision import transforms
|
| 353 |
from PIL import Image
|
| 354 |
|
| 355 |
+
# Check if already done OR if checkpoint exists from previous run
|
| 356 |
+
ckpt_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
|
| 357 |
+
if state.is_done("vq_vae") or os.path.exists(ckpt_path):
|
|
|
|
| 358 |
if os.path.exists(ckpt_path):
|
| 359 |
+
logger.log("VQ-VAE checkpoint found! Loading and skipping training.\n")
|
| 360 |
model = VQVAE()
|
| 361 |
model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=False))
|
| 362 |
logger.log("Loaded trained VQ-VAE from checkpoint.\n")
|
| 363 |
+
state.update(vq_vae_done=True, phase=2)
|
| 364 |
return model
|
| 365 |
else:
|
| 366 |
logger.log("Checkpoint not found, retraining...\n")
|