eeshaAI commited on
Commit
44a7b3e
·
verified ·
1 Parent(s): 2c311a6

Fix: skip VQ-VAE if checkpoint exists, reduce epochs to 3

Browse files
Files changed (1) hide show
  1. train_full_pipeline.py +6 -5
train_full_pipeline.py CHANGED
@@ -46,7 +46,7 @@ PERSIST_DIR = os.path.join(DATA_DIR, "zeeb_checkpoints")
46
  os.makedirs(PERSIST_DIR, exist_ok=True)
47
 
48
  # VQ-VAE training
49
- VQ_VAE_EPOCHS = 5
50
  VQ_VAE_LR = 3e-4
51
  VQ_VAE_BATCH = 8
52
  VQ_VAE_IMG_SIZE = 128
@@ -352,14 +352,15 @@ def train_vq_vae(logger: Logger, state: PipelineState) -> VQVAE:
352
  from torchvision import transforms
353
  from PIL import Image
354
 
355
- # Check if already done
356
- if state.is_done("vq_vae"):
357
- logger.log("VQ-VAE already trained! Loading checkpoint...\n")
358
- ckpt_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
359
  if os.path.exists(ckpt_path):
 
360
  model = VQVAE()
361
  model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=False))
362
  logger.log("Loaded trained VQ-VAE from checkpoint.\n")
 
363
  return model
364
  else:
365
  logger.log("Checkpoint not found, retraining...\n")
 
46
  os.makedirs(PERSIST_DIR, exist_ok=True)
47
 
48
  # VQ-VAE training
49
+ VQ_VAE_EPOCHS = 3
50
  VQ_VAE_LR = 3e-4
51
  VQ_VAE_BATCH = 8
52
  VQ_VAE_IMG_SIZE = 128
 
352
  from torchvision import transforms
353
  from PIL import Image
354
 
355
+ # Check if already done OR if checkpoint exists from previous run
356
+ ckpt_path = os.path.join(PERSIST_DIR, "vq_vae_best.pt")
357
+ if state.is_done("vq_vae") or os.path.exists(ckpt_path):
 
358
  if os.path.exists(ckpt_path):
359
+ logger.log("VQ-VAE checkpoint found! Loading and skipping training.\n")
360
  model = VQVAE()
361
  model.load_state_dict(torch.load(ckpt_path, map_location="cpu", weights_only=False))
362
  logger.log("Loaded trained VQ-VAE from checkpoint.\n")
363
+ state.update(vq_vae_done=True, phase=2)
364
  return model
365
  else:
366
  logger.log("Checkpoint not found, retraining...\n")