"""Central hyperparameter config for DDIM face-generation project. Resolution stages (64 -> 128 -> 256) share the same model architecture; each stage just resizes input images and adjusts batch size. Use Config.for_stage() to materialize a stage-specific config. """ from __future__ import annotations import os from dataclasses import dataclass, field, asdict from typing import Tuple import torch def _pick_device() -> str: if torch.backends.mps.is_available(): return "mps" if torch.cuda.is_available(): return "cuda" return "cpu" @dataclass class Config: # ---- paths --------------------------------------------------------- project_root: str = "/Volumes/Projects/DDIM_image_Generation" data_dir: str = "/Volumes/Projects/DDIM_image_Generation/celeba_hq_256" ckpt_dir: str = "/Volumes/Projects/DDIM_image_Generation/minidiffusion/checkpoints" sample_dir: str = "/Volumes/Projects/DDIM_image_Generation/minidiffusion/samples" # ---- model --------------------------------------------------------- image_size: int = 64 # current training stage in_channels: int = 3 base_channels: int = 128 channel_mults: Tuple[int, ...] = (1, 2, 4, 8) # -> [128, 256, 512, 1024] num_res_blocks: int = 2 # resolutions (in pixels) at which self-attention is applied. Two attn # blocks at the bottleneck (8x8) and one at 32x32 are encoded by listing # 8 twice and 32 once; the U-Net checks membership. attn_resolutions: Tuple[int, ...] = (8, 8, 32) dropout: float = 0.1 time_embed_dim: int = 512 # 4 * base_channels # ---- diffusion ----------------------------------------------------- timesteps: int = 1000 beta_start: float = 1e-4 beta_end: float = 2e-2 beta_schedule: str = "linear" # linear | cosine ddim_steps: int = 50 ddim_eta: float = 0.0 # 0 = deterministic DDIM # ---- training ------------------------------------------------------ batch_size: int = 32 # overridden per stage num_workers: int = 4 lr: float = 2e-4 weight_decay: float = 0.0 ema_decay: float = 0.9999 grad_clip: float = 1.0 epochs: int = 100 log_every: int = 50 # steps sample_every_epochs: int = 5 # log a sample grid to W&B ckpt_every_epochs: int = 1 seed: int = 42 # ---- runtime ------------------------------------------------------- device: str = field(default_factory=_pick_device) mixed_precision: bool = False # MPS autocast still flaky use_wandb: bool = True wandb_project: str = "minidiffusion-celebahq" run_name: str = "stage-64" # ---- helpers ------------------------------------------------------- @classmethod def for_stage(cls, image_size: int, **overrides) -> "Config": """Return a config tuned for a given resolution stage. Channel counts are tuned for a 24GB Mac Mini (Apple Silicon, MPS). Smaller stages use a smaller backbone so the warm-up trains quickly; the 256-stage uses the full [128,256,512,1024] config from the spec. """ if image_size == 64: # ~30M params — fast to iterate on, fits easily in MPS memory stage = dict(image_size=64, batch_size=32, run_name="stage-64", base_channels=64, channel_mults=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=(8, 8, 16), time_embed_dim=256) elif image_size == 128: # ~80M params stage = dict(image_size=128, batch_size=16, run_name="stage-128", base_channels=96, channel_mults=(1, 2, 4, 4), num_res_blocks=2, attn_resolutions=(8, 8, 32), time_embed_dim=384) elif image_size == 256: # ~245M params — the full spec, only run overnight stage = dict(image_size=256, batch_size=4, run_name="stage-256", base_channels=128, channel_mults=(1, 2, 4, 8), num_res_blocks=2, attn_resolutions=(8, 8, 32), time_embed_dim=512) else: raise ValueError(f"Unsupported image_size {image_size}") stage.update(overrides) return cls(**stage) def to_dict(self) -> dict: return asdict(self) def get_default_config() -> Config: cfg = Config() os.makedirs(cfg.ckpt_dir, exist_ok=True) os.makedirs(cfg.sample_dir, exist_ok=True) return cfg if __name__ == "__main__": cfg = get_default_config() print("device:", cfg.device) for stage in (64, 128, 256): s = Config.for_stage(stage) print(f"stage {stage}: bs={s.batch_size} attn={s.attn_resolutions} run={s.run_name}")