Spaces:
Sleeping
Sleeping
| """ | |
| train.py — Training entry point. | |
| Usage: | |
| # Phase 1: Train VAE | |
| python scripts/train.py --phase vae | |
| # Phase 2: Train diffusion U-Net | |
| python scripts/train.py --phase diffusion --vae-checkpoint checkpoints/vae_final.pt | |
| # Resume training | |
| python scripts/train.py --phase diffusion --resume checkpoints/diffusion_step_50000.pt | |
| """ | |
| import argparse | |
| import os | |
| import sys | |
| # Add project root to path | |
| sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| import torch | |
| from omegaconf import OmegaConf | |
| from data.dataset import create_dataloader | |
| from training.trainer import Trainer | |
def main():
    """Parse CLI arguments, build the config, and run the selected training phase.

    The effective config is ``config/default.yaml`` merged with the file given
    by ``--config``, converted to a plain container with interpolations
    resolved. A non-zero ``--max-samples`` then overrides
    ``config["data"]["max_samples"]``.

    For ``--phase diffusion`` without ``--vae-checkpoint``, the function falls
    back to ``vae_final.pt`` inside ``config["paths"]["checkpoint_dir"]``. If
    that file does not exist, an error is printed and the process exits with
    status 1 before any dataloader or trainer is created.

    Raises:
        SystemExit: on argument errors (raised by argparse) or when no VAE
            checkpoint can be found for the diffusion phase.
    """
    parser = argparse.ArgumentParser(description="Train the image editing model")
    parser.add_argument("--config", type=str, default="config/default.yaml", help="Config file path")
    parser.add_argument("--phase", type=str, required=True, choices=["vae", "diffusion"], help="Training phase")
    parser.add_argument("--vae-checkpoint", type=str, default=None, help="VAE checkpoint for diffusion phase")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--max-samples", type=int, default=None, help="Limit dataset size (for testing)")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()

    # Load config (default + optional override). The override file is merged
    # on top of the defaults, then flattened to plain dicts/lists.
    base_config = OmegaConf.load("config/default.yaml")
    override_config = OmegaConf.load(args.config)
    config = OmegaConf.to_container(OmegaConf.merge(base_config, override_config), resolve=True)
    # Truthiness check: an omitted flag (None) and an explicit 0 both leave the
    # configured value untouched.
    if args.max_samples:
        config["data"]["max_samples"] = args.max_samples

    print(f"Device: {args.device}")
    print(f"Phase: {args.phase}")
    print(f"Config: {args.config}")
    if args.device == "cuda" and torch.cuda.is_available():
        print(f"GPU: {torch.cuda.get_device_name(0)}")
        # The device-properties field is `total_memory` (bytes); `total_mem`
        # does not exist and raised AttributeError here.
        print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

    # Resolve the VAE checkpoint up front so a missing checkpoint is reported
    # before the dataloader and trainer are built.
    vae_checkpoint = args.vae_checkpoint
    if args.phase == "diffusion" and not vae_checkpoint:
        # Look for default VAE checkpoint
        default_vae = os.path.join(config["paths"]["checkpoint_dir"], "vae_final.pt")
        if not os.path.exists(default_vae):
            print("ERROR: --vae-checkpoint required for diffusion phase")
            print("Train the VAE first: python scripts/train.py --phase vae")
            sys.exit(1)
        vae_checkpoint = default_vae

    # Create dataloader
    dataloader = create_dataloader(config, is_train=True, phase=args.phase, split="train")
    print(f"Dataset size: {len(dataloader.dataset)}")

    # Create trainer and dispatch to the requested phase
    trainer = Trainer(config, device=args.device)
    if args.phase == "vae":
        trainer.train_vae(dataloader, resume_from=args.resume)
    elif args.phase == "diffusion":
        trainer.train_diffusion(dataloader, vae_checkpoint, resume_from=args.resume)
# Start training only when this file is executed as a script; importing it
# as a module must not trigger a run.
if __name__ == "__main__":
    main()