""" train.py — Training entry point. Usage: # Phase 1: Train VAE python scripts/train.py --phase vae # Phase 2: Train diffusion U-Net python scripts/train.py --phase diffusion --vae-checkpoint checkpoints/vae_final.pt # Resume training python scripts/train.py --phase diffusion --resume checkpoints/diffusion_step_50000.pt """ import argparse import os import sys # Add project root to path sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) import torch from omegaconf import OmegaConf from data.dataset import create_dataloader from training.trainer import Trainer def main(): parser = argparse.ArgumentParser(description="Train the image editing model") parser.add_argument("--config", type=str, default="config/default.yaml", help="Config file path") parser.add_argument("--phase", type=str, required=True, choices=["vae", "diffusion"], help="Training phase") parser.add_argument("--vae-checkpoint", type=str, default=None, help="VAE checkpoint for diffusion phase") parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint") parser.add_argument("--max-samples", type=int, default=None, help="Limit dataset size (for testing)") parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu") args = parser.parse_args() # Load config (default + optional override) base_config = OmegaConf.load("config/default.yaml") override_config = OmegaConf.load(args.config) config = OmegaConf.to_container(OmegaConf.merge(base_config, override_config), resolve=True) if args.max_samples: config["data"]["max_samples"] = args.max_samples print(f"Device: {args.device}") print(f"Phase: {args.phase}") print(f"Config: {args.config}") if args.device == "cuda" and torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name(0)}") print(f"VRAM: {torch.cuda.get_device_properties(0).total_mem / 1e9:.1f} GB") # Create dataloader dataloader = create_dataloader(config, is_train=True, phase=args.phase, split="train") print(f"Dataset size: {len(dataloader.dataset)}") # Create trainer trainer = Trainer(config, device=args.device) if args.phase == "vae": trainer.train_vae(dataloader, resume_from=args.resume) elif args.phase == "diffusion": if not args.vae_checkpoint: # Look for default VAE checkpoint default_vae = os.path.join(config["paths"]["checkpoint_dir"], "vae_final.pt") if os.path.exists(default_vae): args.vae_checkpoint = default_vae else: print("ERROR: --vae-checkpoint required for diffusion phase") print("Train the VAE first: python scripts/train.py --phase vae") sys.exit(1) trainer.train_diffusion(dataloader, args.vae_checkpoint, resume_from=args.resume) if __name__ == "__main__": main()