"""Frozen model architecture and user-tunable inference configuration.""" from __future__ import annotations import json from dataclasses import asdict, dataclass from pathlib import Path @dataclass(frozen=True) class MDiffAEConfig: """Frozen model architecture config. Stored alongside weights as config.json.""" in_channels: int = 3 patch_size: int = 16 model_dim: int = 896 encoder_depth: int = 4 decoder_depth: int = 4 bottleneck_dim: int = 64 mlp_ratio: float = 4.0 depthwise_kernel_size: int = 7 adaln_low_rank_rank: int = 128 # VP diffusion schedule endpoints logsnr_min: float = -10.0 logsnr_max: float = 10.0 # Pixel-space noise std for VP diffusion initialization pixel_noise_std: float = 0.558 # Token mask ratio for PDG (fraction of spatial tokens replaced with mask_feature) pdg_mask_ratio: float = 0.75 def save(self, path: str | Path) -> None: """Save config as JSON.""" p = Path(path) p.parent.mkdir(parents=True, exist_ok=True) p.write_text(json.dumps(asdict(self), indent=2) + "\n") @classmethod def load(cls, path: str | Path) -> MDiffAEConfig: """Load config from JSON.""" data = json.loads(Path(path).read_text()) return cls(**data) @dataclass class MDiffAEInferenceConfig: """User-tunable inference parameters with sensible defaults. PDG is very sensitive in mDiffAE — use small strengths (1.05–1.2). """ num_steps: int = 1 # decoder forward passes (NFE) sampler: str = "ddim" # "ddim" or "dpmpp_2m" schedule: str = "linear" # "linear" or "cosine" pdg_enabled: bool = False pdg_strength: float = 1.1 seed: int | None = None