| """
|
| Byte Dream Training Pipeline
|
| Complete training system for diffusion models with CPU optimization
|
| """
|
|
|
| import os
|
| import torch
|
| import torch.nn as nn
|
| import torch.nn.functional as F
|
| from torch.utils.data import Dataset, DataLoader
|
| from torchvision import transforms
|
| from PIL import Image
|
| import numpy as np
|
| from tqdm import tqdm
|
| import yaml
|
| import argparse
|
| from pathlib import Path
|
| from typing import Tuple, List, Optional
|
| import gc
|
|
|
|
|
class ImageTextDataset(Dataset):
    """
    Dataset of image-caption pairs stored in a single flat directory.

    Every ``.jpg`` / ``.jpeg`` / ``.png`` file directly inside ``data_dir``
    (extension matched case-insensitively) is one example. Its caption is read
    from a sibling ``.txt`` file with the same stem; when that file is missing
    the image's stem, with underscores replaced by spaces, is used instead.

    Supports random horizontal flips and random/center crops for better
    generalization.
    """

    # Image file extensions picked up from ``data_dir`` (compared lower-cased).
    IMAGE_EXTENSIONS = (".jpg", ".png", ".jpeg")

    def __init__(
        self,
        data_dir: str,
        image_size: int = 512,
        random_flip: bool = True,
        random_crop: bool = False,
        center_crop: bool = True,
    ):
        """
        Args:
            data_dir: Directory containing the images (and optional ``.txt``
                caption files).
            image_size: Side length, in pixels, of the square output images.
            random_flip: Mirror images horizontally with probability 0.5.
            random_crop: Resize the shorter side to ``image_size``, then take a
                random square crop. Takes precedence over ``center_crop``.
            center_crop: Resize the shorter side to ``image_size``, then take
                the central square crop. When neither crop flag is set, the
                image is resized to ``image_size`` x ``image_size`` without
                preserving the aspect ratio.

        Raises:
            FileNotFoundError: ``data_dir`` does not exist.
            ValueError: ``data_dir`` contains no supported image files.
        """
        self.data_dir = Path(data_dir)

        if not self.data_dir.exists():
            raise FileNotFoundError(f"Dataset directory not found: {self.data_dir}\nPlease create the directory and add images, or use --train_data with a valid path.")

        # Match extensions case-insensitively so "photo.JPG" is not silently
        # skipped on case-sensitive file systems; sort for a reproducible order.
        self.image_paths = sorted(
            path
            for path in self.data_dir.glob("*")
            if path.is_file() and path.suffix.lower() in self.IMAGE_EXTENSIONS
        )

        if not self.image_paths:
            raise ValueError(f"No images found in {self.data_dir}\nSupported formats: .jpg, .png, .jpeg")

        self.image_size = image_size
        self.random_flip = random_flip
        self.random_crop = random_crop
        self.center_crop = center_crop

        self.transform = self._get_transform()
        self.captions = self._load_captions()

    def _get_transform(self) -> transforms.Compose:
        """
        Build the image preprocessing pipeline.

        Output tensors are CHW float images normalized from [0, 1] to [-1, 1].
        """
        transforms_list = []

        if self.random_crop or self.center_crop:
            # Scale the shorter side to image_size first. Cropping a raw image
            # would keep only a small central patch of large images, zero-pad
            # small ones, and make RandomCrop raise on images < image_size.
            transforms_list.append(transforms.Resize(self.image_size))
            if self.random_crop:
                transforms_list.append(transforms.RandomCrop(self.image_size))
            else:
                transforms_list.append(transforms.CenterCrop(self.image_size))
        else:
            transforms_list.append(transforms.Resize((self.image_size, self.image_size)))

        if self.random_flip:
            transforms_list.append(transforms.RandomHorizontalFlip(p=0.5))

        transforms_list.extend([
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ])

        return transforms.Compose(transforms_list)

    def _load_captions(self) -> dict:
        """Map ``str(image_path)`` to its caption (from ``.txt`` or the file stem)."""
        captions = {}

        for img_path in self.image_paths:
            caption_path = img_path.with_suffix('.txt')
            if caption_path.is_file():
                captions[str(img_path)] = caption_path.read_text(encoding='utf-8').strip()
            else:
                # Fall back to the file name, e.g. "red_apple.jpg" -> "red apple".
                captions[str(img_path)] = img_path.stem.replace('_', ' ')

        return captions

    def __len__(self) -> int:
        return len(self.image_paths)

    def __getitem__(self, idx: int) -> dict:
        """
        Return the transformed image, its caption and its path.

        If an image cannot be opened, the error is printed and the next index
        is tried instead (wrapping around). Each image is tried at most once,
        so a dataset made only of unreadable files raises instead of recursing
        forever.

        Returns:
            dict with ``pixel_values`` (image tensor), ``input_ids`` (the raw
            caption string) and ``image_path`` (str).

        Raises:
            IndexError: ``idx`` is out of range, as for a list.
            RuntimeError: No image in the dataset could be opened.
        """
        num_images = len(self.image_paths)
        if not -num_images <= idx < num_images:
            raise IndexError(f"Index {idx} out of range for dataset of size {num_images}")

        for offset in range(num_images):
            img_path = self.image_paths[(idx + offset) % num_images]
            try:
                # The context manager closes the file handle PIL keeps open.
                with Image.open(img_path) as img:
                    image = img.convert('RGB')
            except Exception as e:
                print(f"Error loading image {img_path}: {e}")
                continue

            pixel_values = self.transform(image)
            caption = self.captions.get(str(img_path), "")

            return {
                "pixel_values": pixel_values,
                "input_ids": caption,
                "image_path": str(img_path),
            }

        raise RuntimeError(f"Failed to load any image from {self.data_dir}")
|
|
|
|
|
class LatentDiffusionTrainer:
    """
    Trainer for latent diffusion models.

    Only the UNet is optimized. The VAE and the text encoder are frozen and
    used under ``torch.no_grad`` as feature extractors. Implements gradient
    accumulation, gradient clipping, CUDA mixed precision and checkpoint
    save/resume.
    """

    def __init__(
        self,
        unet: nn.Module,
        vae: nn.Module,
        text_encoder: nn.Module,
        scheduler,
        config: dict,
        device: str = "cpu",
    ):
        """
        Args:
            unet: Denoising network; the only module whose weights are trained.
            vae: Autoencoder used (frozen) to map images to latents.
            text_encoder: Module (frozen) called as
                ``text_encoder(texts, device=...)`` to embed captions.
            scheduler: Noise scheduler exposing ``num_train_timesteps`` and
                ``add_noise(latents, noise, timesteps)``.
            config: Parsed configuration; hyper-parameters are read from
                ``config['training']``.
            device: Torch device string, e.g. ``"cpu"`` or ``"cuda"``.
        """
        self.unet = unet
        self.vae = vae
        self.text_encoder = text_encoder
        self.scheduler = scheduler
        self.config = config
        self.device = torch.device(device)

        train_config = config['training']
        self.epochs = train_config['epochs']
        self.batch_size = train_config['batch_size']
        self.learning_rate = train_config['learning_rate']
        # Guard against 0 / negative values, which would break the modulo test.
        self.gradient_accumulation_steps = max(1, int(train_config['gradient_accumulation_steps']))
        self.max_grad_norm = train_config['max_grad_norm']

        # Mixed precision only takes effect on CUDA devices. "bf16" selects
        # bfloat16; any other value except "no" selects float16.
        self.mixed_precision = train_config['mixed_precision']
        self.use_amp = self.mixed_precision != "no"
        self.amp_dtype = torch.bfloat16 if self.mixed_precision == "bf16" else torch.float16
        self.use_cuda_amp = self.use_amp and self.device.type == "cuda"

        self.output_dir = Path(train_config['output_dir'])
        self.logging_dir = Path(train_config['logging_dir'])
        self.output_dir.mkdir(parents=True, exist_ok=True)
        self.logging_dir.mkdir(parents=True, exist_ok=True)

        # Move and freeze the models before building the optimizer, so the
        # optimizer is created over the parameters on their final device.
        self._prepare_models()

        self.optimizer = torch.optim.AdamW(
            self.unet.parameters(),
            lr=self.learning_rate,
            betas=(0.9, 0.999),
            weight_decay=1e-2,
            eps=1e-08,
        )

        self.lr_scheduler = self._create_lr_scheduler()

        # Loss scaling is only needed for float16 autocast on a CUDA device.
        # A CUDA scaler must not be created when training on CPU.
        self.scaler = (
            torch.cuda.amp.GradScaler()
            if self.use_cuda_amp and self.amp_dtype == torch.float16
            else None
        )

    def _prepare_models(self):
        """Move all models to the device, freeze VAE/text encoder, set UNet to train mode."""
        print(f"Preparing models on {self.device}...")

        self.vae.to(self.device)
        self.text_encoder.to(self.device)
        self.unet.to(self.device)

        self.vae.eval()
        for param in self.vae.parameters():
            param.requires_grad = False

        # The text encoder may wrap its network in a `.model` attribute.
        if hasattr(self.text_encoder, 'model'):
            self.text_encoder.model.eval()
            for param in self.text_encoder.model.parameters():
                param.requires_grad = False

        self.unet.train()

    def _create_lr_scheduler(self):
        """
        Create the learning-rate scheduler from ``config['training']``.

        Supported ``lr_scheduler`` values:
            * ``"constant_with_warmup"``: LR rises linearly from 0 to the base
              LR over ``lr_warmup_steps`` optimizer steps, then stays constant.
            * ``"linear"``: LR rises linearly from 0.1x to 1x the base LR over
              ``lr_warmup_steps`` steps, then stays constant (no decay).
            * anything else: constant LR.
        """
        sched_config = self.config['training']
        schedule_name = sched_config['lr_scheduler']

        if schedule_name == "constant_with_warmup":
            warmup_steps = max(0, int(sched_config['lr_warmup_steps']))

            # ConstantLR(factor=1.0) used previously never warmed up at all.
            def warmup_factor(current_step: int) -> float:
                if current_step < warmup_steps:
                    return current_step / warmup_steps
                return 1.0

            return torch.optim.lr_scheduler.LambdaLR(self.optimizer, warmup_factor)

        if schedule_name == "linear":
            return torch.optim.lr_scheduler.LinearLR(
                self.optimizer,
                start_factor=0.1,
                end_factor=1.0,
                total_iters=sched_config['lr_warmup_steps'],
            )

        return torch.optim.lr_scheduler.ConstantLR(self.optimizer, factor=1.0)

    def encode_images(self, images: torch.Tensor) -> torch.Tensor:
        """
        Encode images to scaled latents without tracking gradients.

        NOTE(review): keeps only the first 4 channels of ``vae.encode``'s
        output; presumably the VAE returns concatenated distribution moments
        and this selects the mean — confirm against the VAE implementation.
        """
        with torch.no_grad():
            latents = self.vae.encode(images)
            latents = latents[:, :4]
            # 0.18215 is the latent scale factor used by Stable Diffusion's
            # VAE; assumed to apply to this VAE as well — TODO confirm.
            latents = latents * 0.18215
        return latents

    def encode_text(self, texts: List[str]) -> torch.Tensor:
        """Encode caption strings to embeddings without tracking gradients."""
        with torch.no_grad():
            text_embeddings = self.text_encoder(texts, device=self.device)
        return text_embeddings

    def compute_loss(
        self,
        latents: torch.Tensor,
        text_embeddings: torch.Tensor,
    ) -> torch.Tensor:
        """
        Compute the noise-prediction (epsilon) MSE diffusion loss.

        Args:
            latents: Latent representations of images (batch first).
            text_embeddings: Text embeddings used as UNet cross-attention context.

        Returns:
            Scalar mean-squared error between the UNet output and the noise.
        """
        batch_size = latents.shape[0]

        # One uniformly sampled timestep per example.
        timesteps = torch.randint(
            0,
            self.scheduler.num_train_timesteps,
            (batch_size,),
            device=self.device,
        ).long()

        noise = torch.randn_like(latents)
        noisy_latents = self.scheduler.add_noise(latents, noise, timesteps)

        model_output = self.unet(
            sample=noisy_latents,
            timestep=timesteps,
            encoder_hidden_states=text_embeddings,
        )

        return F.mse_loss(model_output, noise, reduction="mean")

    def train_step(
        self,
        batch: dict,
    ) -> float:
        """
        Run forward and backward passes for one batch (no optimizer step).

        Gradients are scaled by ``1 / gradient_accumulation_steps`` so they
        average correctly across accumulated batches.

        Args:
            batch: Collated dataset batch with ``pixel_values`` and
                ``input_ids`` (caption strings).

        Returns:
            The unscaled loss value for this batch.
        """
        pixel_values = batch["pixel_values"].to(self.device)
        input_ids = batch["input_ids"]

        latents = self.encode_images(pixel_values)
        text_embeddings = self.encode_text(input_ids)

        if self.use_cuda_amp:
            with torch.autocast(device_type="cuda", dtype=self.amp_dtype):
                loss = self.compute_loss(latents, text_embeddings)
        else:
            loss = self.compute_loss(latents, text_embeddings)

        scaled_loss = loss / self.gradient_accumulation_steps
        if self.scaler is not None:
            self.scaler.scale(scaled_loss).backward()
        else:
            scaled_loss.backward()

        return loss.item()

    def _optimizer_step(self):
        """Clip gradients, step optimizer (via the scaler if any) and LR scheduler, then clear grads."""
        if self.scaler is not None:
            # Unscale first so clipping sees the true gradient magnitudes.
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(self.unet.parameters(), self.max_grad_norm)
            self.scaler.step(self.optimizer)
            self.scaler.update()
        else:
            torch.nn.utils.clip_grad_norm_(self.unet.parameters(), self.max_grad_norm)
            self.optimizer.step()

        if self.lr_scheduler is not None:
            self.lr_scheduler.step()

        self.optimizer.zero_grad()

    @staticmethod
    def _resolve_checkpoint_path(checkpoint_path) -> Path:
        """Accept a checkpoint directory (as written by ``save_checkpoint``) or a file path."""
        path = Path(checkpoint_path)
        if path.is_dir():
            path = path / "pytorch_model.bin"
        return path

    def _load_checkpoint(self, checkpoint_path: str) -> Tuple[int, int]:
        """
        Restore UNet/optimizer/scheduler/scaler state from a checkpoint.

        Returns:
            ``(start_epoch, global_step)`` to continue training from.
        """
        path = self._resolve_checkpoint_path(checkpoint_path)
        print(f"Resuming from checkpoint: {path}")
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        checkpoint = torch.load(path, map_location=self.device)

        self.unet.load_state_dict(checkpoint['unet_state_dict'])
        self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

        scheduler_state = checkpoint.get('scheduler_state_dict')
        if scheduler_state and self.lr_scheduler is not None:
            try:
                self.lr_scheduler.load_state_dict(scheduler_state)
            except KeyError:
                # e.g. a checkpoint written by a different scheduler class.
                print("Warning: LR scheduler state in checkpoint is incompatible; starting a fresh schedule.")

        scaler_state = checkpoint.get('scaler_state_dict')
        if scaler_state and self.scaler is not None:
            self.scaler.load_state_dict(scaler_state)

        # Older checkpoints lack 'next_epoch'; fall back to re-running 'epoch'.
        start_epoch = checkpoint.get('next_epoch', checkpoint['epoch'])
        return start_epoch, checkpoint['step']

    def save_checkpoint(self, epoch: int, step: int, next_epoch: Optional[int] = None):
        """
        Save UNet, optimizer, LR-scheduler and grad-scaler state plus the config.

        Args:
            epoch: 0-based epoch the checkpoint was taken in.
            step: Number of optimizer steps taken so far.
            next_epoch: Epoch to resume from. Defaults to ``epoch``, which
                re-runs the in-progress epoch. Pass ``epoch + 1`` for a
                checkpoint taken after the epoch finished.
        """
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}-{step}"
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        torch.save({
            'epoch': epoch,
            'step': step,
            'next_epoch': epoch if next_epoch is None else next_epoch,
            'unet_state_dict': self.unet.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.lr_scheduler.state_dict() if self.lr_scheduler is not None else None,
            'scaler_state_dict': self.scaler.state_dict() if self.scaler is not None else None,
        }, checkpoint_dir / "pytorch_model.bin")

        with open(checkpoint_dir / "config.yaml", 'w', encoding='utf-8') as f:
            yaml.dump(self.config, f)

        print(f"Checkpoint saved to {checkpoint_dir}")

    def train(self, resume_from_checkpoint: Optional[str] = None):
        """
        Main training loop.

        ``global_step`` counts optimizer steps, not batches. Logging every
        ``log_every_n_steps`` and checkpointing every ``save_every_n_steps``
        (default 1000) are keyed on it. A checkpoint is also written at the
        end of every epoch.

        Args:
            resume_from_checkpoint: Checkpoint directory or ``pytorch_model.bin``
                path to resume from. Resuming a mid-epoch checkpoint re-runs
                that epoch from its start.
        """
        train_config = self.config['training']

        dataset = ImageTextDataset(
            data_dir=train_config['dataset_path'],
            image_size=train_config.get('image_size', 512),
            random_flip=train_config['random_flip'],
            random_crop=train_config['random_crop'],
            center_crop=train_config['center_crop'],
        )

        dataloader = DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=0,
            pin_memory=False,
        )

        start_epoch = 0
        global_step = 0
        if resume_from_checkpoint:
            start_epoch, global_step = self._load_checkpoint(resume_from_checkpoint)

        num_batches = len(dataloader)
        # Ceil division: a trailing partial accumulation group still updates.
        updates_per_epoch = -(-num_batches // self.gradient_accumulation_steps)
        total_steps = updates_per_epoch * self.epochs
        log_every = train_config['log_every_n_steps']
        save_every = train_config.get('save_every_n_steps', 1000)

        print(f"Starting training for {self.epochs} epochs...")
        print(f"Total steps: {total_steps}")
        print(f"Batch size: {self.batch_size}")
        print(f"Mixed precision: {self.mixed_precision}")

        self.optimizer.zero_grad()

        for epoch in range(start_epoch, self.epochs):
            self.unet.train()
            progress_bar = tqdm(dataloader, desc=f"Epoch {epoch+1}/{self.epochs}")

            epoch_loss = 0.0
            num_steps = 0

            for step, batch in enumerate(progress_bar):
                loss = self.train_step(batch)
                epoch_loss += loss
                num_steps += 1

                avg_loss = epoch_loss / num_steps
                progress_bar.set_postfix({"loss": f"{avg_loss:.4f}"})

                # Also step on the last batch so leftover accumulated
                # gradients do not leak into the next epoch.
                is_last_batch = (step + 1) == num_batches
                if (step + 1) % self.gradient_accumulation_steps != 0 and not is_last_batch:
                    continue

                self._optimizer_step()
                global_step += 1

                if log_every and global_step % log_every == 0:
                    print(f"\nStep {global_step}: Loss = {avg_loss:.4f}")

                if save_every and global_step % save_every == 0:
                    self.save_checkpoint(epoch, global_step)

            avg_epoch_loss = epoch_loss / max(num_steps, 1)
            print(f"\nEpoch {epoch+1} completed. Average loss: {avg_epoch_loss:.4f}")

            # This epoch is finished: resuming from here starts the next one.
            self.save_checkpoint(epoch, global_step, next_epoch=epoch + 1)

            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

        print("\nTraining completed!")
        self.save_final_model()

    def save_final_model(self):
        """Save the final UNet weights together with the config."""
        final_dir = self.output_dir / "final"
        final_dir.mkdir(parents=True, exist_ok=True)

        torch.save({
            'unet_state_dict': self.unet.state_dict(),
            'config': self.config,
        }, final_dir / "unet_pytorch_model.bin")

        print(f"Final model saved to {final_dir}")
|
|
|
|
|
def main():
    """
    Command-line entry point: parse arguments, build the model components
    from the YAML config and run training.

    ``--train_data`` and ``--output_dir`` override the corresponding entries
    in the config's ``training`` section.

    Raises:
        ValueError: The config file is empty or has no ``training`` mapping.
    """
    parser = argparse.ArgumentParser(description="Train Byte Dream diffusion model")
    parser.add_argument("--config", type=str, default="config.yaml", help="Path to config file")
    parser.add_argument("--train_data", type=str, default="./dataset", help="Path to training data (default: ./dataset)")
    parser.add_argument("--output_dir", type=str, default="./models/bytedream", help="Output directory")
    parser.add_argument("--resume", type=str, default=None, help="Resume from checkpoint")
    parser.add_argument("--device", type=str, default="cpu", help="Device to train on")

    args = parser.parse_args()

    # Explicit UTF-8 so the config parses identically on every platform; the
    # default text encoding is locale-dependent (e.g. cp1252 on Windows).
    with open(args.config, 'r', encoding='utf-8') as f:
        config = yaml.safe_load(f)

    # safe_load returns None for an empty file; fail with a clear message
    # instead of a TypeError/KeyError further down.
    if not isinstance(config, dict) or not isinstance(config.get('training'), dict):
        raise ValueError(f"Config file {args.config} must contain a 'training' section")

    config['training']['dataset_path'] = args.train_data
    config['training']['output_dir'] = args.output_dir

    from bytedream.model import create_unet, create_vae, create_text_encoder
    from bytedream.scheduler import create_scheduler

    print("Creating model components...")
    unet = create_unet(config)
    vae = create_vae(config)
    text_encoder = create_text_encoder(config)
    scheduler = create_scheduler(config)

    total_params = sum(p.numel() for p in unet.parameters())
    print(f"UNet parameters: {total_params:,}")

    trainer = LatentDiffusionTrainer(
        unet=unet,
        vae=vae,
        text_encoder=text_encoder,
        scheduler=scheduler,
        config=config,
        device=args.device,
    )

    trainer.train(resume_from_checkpoint=args.resume)
|
|
|
|
|
if __name__ == "__main__":
    # Run training only when this file is executed as a script, not on import.
    main()
|
|
|