"""Frozen model architecture and user-tunable inference configuration."""

from __future__ import annotations

import json
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Literal


@dataclass(frozen=True)
class IRDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json.

    Instances are immutable (``frozen=True``). Use :meth:`save` and
    :meth:`load` to round-trip the config through a JSON file.
    """

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 8
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558

    def save(self, path: str | Path) -> None:
        """Save config as JSON.

        Missing parent directories are created. The file contains every
        field as a JSON object, indented by 2 spaces, followed by a
        trailing newline.

        Args:
            path: Destination file path.
        """
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        # JSON interchange is UTF-8 (RFC 8259); never depend on the
        # platform/locale default encoding.
        p.write_text(json.dumps(asdict(self), indent=2) + "\n", encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> IRDiffAEConfig:
        """Load config from JSON.

        Args:
            path: Path to a JSON file, typically one written by :meth:`save`.

        Returns:
            A new config built from the file's top-level key/value pairs.

        Raises:
            FileNotFoundError: If ``path`` does not exist.
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the JSON is not an object, or contains keys that
                are not fields of this class. For example, a config written
                by a different version of the architecture could trigger this.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(**data)


@dataclass
class IRDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    The ``Literal`` annotations document the accepted string values for
    static type checkers. They are not enforced at runtime.
    """

    num_steps: int = 1  # decoder forward passes (NFE)
    sampler: Literal["ddim", "dpmpp_2m"] = "ddim"
    schedule: Literal["linear", "cosine"] = "linear"
    pdg_enabled: bool = False
    # NOTE(review): presumably only consulted when pdg_enabled is True;
    # confirm against the sampler implementation.
    pdg_strength: float = 2.0
    # NOTE(review): None presumably means "do not seed the RNG"; verify at
    # the call site.
    seed: int | None = None