File size: 1,539 Bytes
"""Frozen model architecture and user-tunable inference configuration."""
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from pathlib import Path
@dataclass(frozen=True)
class IRDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json.

    The dataclass is declared ``frozen=True``, so attribute assignment on an
    instance raises ``dataclasses.FrozenInstanceError``. Every field has a
    default, and the declaration order below is the positional-argument order
    of the generated ``__init__``.
    """

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 8
    bottleneck_dim: int = 128
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558

    def save(self, path: str | Path) -> None:
        """Save config as JSON.

        Missing parent directories of ``path`` are created first. The file
        holds every field of this config as a 2-space-indented JSON object
        followed by a trailing newline.

        Args:
            path: Destination file; an existing file is overwritten.

        Raises:
            OSError: If the directory or file cannot be created or written.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(asdict(self), indent=2) + "\n"
        # Pin the encoding: without it the file would be written in the
        # host's locale encoding, whereas JSON files are expected to be UTF-8.
        target.write_text(payload, encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> IRDiffAEConfig:
        """Load config from JSON.

        The top-level JSON object is passed to the constructor as keyword
        arguments, so keys absent from the file fall back to the field
        defaults.

        Args:
            path: JSON file to read, e.g. one produced by ``save``.

        Returns:
            A new config instance built from the file's contents.

        Raises:
            OSError: If the file cannot be read.
            json.JSONDecodeError: If the file is not valid JSON.
            TypeError: If the JSON is not an object or has a key that is not
                a field of this class.
        """
        # Decode as UTF-8 explicitly so reading does not depend on the locale.
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        return cls(**data)
@dataclass
class IRDiffAEInferenceConfig:
    """Inference-time settings a caller may adjust; every field has a default.

    Unlike a frozen dataclass, instances are mutable, so individual fields can
    be reassigned after construction.
    """

    # Number of decoder forward passes (NFE).
    num_steps: int = 1
    # Sampler name: "ddim" or "dpmpp_2m".
    sampler: str = "ddim"
    # Schedule name: "linear" or "cosine".
    schedule: str = "linear"
    # PDG toggle and its strength; disabled by default.
    pdg_enabled: bool = False
    pdg_strength: float = 2.0
    # Optional integer seed; None by default.
    seed: int | None = None
|