Raghava Pulugu
Clean deployment
cad10d9
Raw
History Blame Contribute Delete
2.99 kB
"""
train.py — Training entry point.
Usage:
# Phase 1: Train VAE
python scripts/train.py --phase vae
# Phase 2: Train diffusion U-Net
python scripts/train.py --phase diffusion --vae-checkpoint checkpoints/vae_final.pt
# Resume training
python scripts/train.py --phase diffusion --resume checkpoints/diffusion_step_50000.pt
"""
import argparse
import os
import sys
# Add project root to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
from omegaconf import OmegaConf
from data.dataset import create_dataloader
from training.trainer import Trainer
# Base configuration that every run starts from; ``--config`` is merged on top.
# Resolved relative to the current working directory (run from the project root).
DEFAULT_CONFIG = "config/default.yaml"


def _parse_args():
    """Build the CLI parser, parse ``sys.argv`` and validate the arguments."""
    parser = argparse.ArgumentParser(description="Train the image editing model")
    parser.add_argument("--config", type=str, default=DEFAULT_CONFIG, help="Config file path")
    parser.add_argument("--phase", type=str, required=True, choices=["vae", "diffusion"], help="Training phase")
    parser.add_argument("--vae-checkpoint", type=str, default=None, help="VAE checkpoint for diffusion phase")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--max-samples", type=int, default=None, help="Limit dataset size (for testing)")
    parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu")
    args = parser.parse_args()
    # Reject 0 / negative limits explicitly instead of silently treating a
    # falsy value as "no limit".
    if args.max_samples is not None and args.max_samples <= 0:
        parser.error("--max-samples must be a positive integer")
    return args


def _load_config(config_path, max_samples):
    """Merge ``config_path`` over the default config and return it as a plain dict.

    Args:
        config_path: YAML file whose values override ``DEFAULT_CONFIG``.
        max_samples: optional dataset-size limit written to ``config["data"]["max_samples"]``.
    """
    # The default config is always loaded first so an override file only has
    # to contain the keys it changes.
    base_config = OmegaConf.load(DEFAULT_CONFIG)
    override_config = OmegaConf.load(config_path)
    config = OmegaConf.to_container(OmegaConf.merge(base_config, override_config), resolve=True)
    if max_samples is not None:
        config["data"]["max_samples"] = max_samples
    return config


def _print_environment(args):
    """Print the run settings and, when training on an available CUDA device, the GPU used."""
    print(f"Device: {args.device}")
    print(f"Phase: {args.phase}")
    print(f"Config: {args.config}")
    # Accept indexed devices such as "cuda:1", not only the bare "cuda".
    if args.device.startswith("cuda") and torch.cuda.is_available():
        device = torch.device(args.device)
        # A bare "cuda" device has no index; it refers to the current CUDA device.
        gpu_index = device.index if device.index is not None else torch.cuda.current_device()
        props = torch.cuda.get_device_properties(gpu_index)
        print(f"GPU: {props.name}")
        # The device-properties field is `total_memory` (bytes); there is no `total_mem`.
        print(f"VRAM: {props.total_memory / 1e9:.1f} GB")


def _resolve_vae_checkpoint(vae_checkpoint, config):
    """Return the VAE checkpoint path for the diffusion phase, exiting with status 1 if none exists.

    Args:
        vae_checkpoint: value of ``--vae-checkpoint``; used as-is when non-empty.
        config: merged config; ``config["paths"]["checkpoint_dir"]`` is read only
            when no explicit checkpoint was given.
    """
    if vae_checkpoint:
        return vae_checkpoint
    # Fall back to <checkpoint_dir>/vae_final.pt, the VAE checkpoint name shown
    # in the usage notes at the top of this script.
    default_vae = os.path.join(config["paths"]["checkpoint_dir"], "vae_final.pt")
    if os.path.exists(default_vae):
        return default_vae
    print(
        f"ERROR: --vae-checkpoint required for diffusion phase (no checkpoint found at {default_vae})",
        file=sys.stderr,
    )
    print("Train the VAE first: python scripts/train.py --phase vae", file=sys.stderr)
    sys.exit(1)


def main():
    """Parse CLI arguments, build the dataloader and Trainer, and run the requested training phase."""
    args = _parse_args()
    config = _load_config(args.config, args.max_samples)
    _print_environment(args)

    # Resolve the VAE checkpoint before the (potentially slow) dataloader and
    # Trainer construction so a missing checkpoint fails immediately.
    vae_checkpoint = None
    if args.phase == "diffusion":
        vae_checkpoint = _resolve_vae_checkpoint(args.vae_checkpoint, config)

    dataloader = create_dataloader(config, is_train=True, phase=args.phase, split="train")
    print(f"Dataset size: {len(dataloader.dataset)}")

    trainer = Trainer(config, device=args.device)
    if args.phase == "vae":
        trainer.train_vae(dataloader, resume_from=args.resume)
    else:  # "diffusion": argparse `choices` rules out any other phase
        trainer.train_diffusion(dataloader, vae_checkpoint, resume_from=args.resume)
if __name__ == "__main__":
    # Start training only when run as a script; importing this module has no side effects.
    main()