| | """Frozen model architecture and user-tunable inference configuration.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import json |
| | from dataclasses import asdict, dataclass |
| | from pathlib import Path |
| |
|
| |
|
@dataclass(frozen=True)
class MDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json.

    Instances are immutable (``frozen=True``) and every field has a default,
    so ``MDiffAEConfig()`` is a valid configuration. ``save`` and ``load``
    round-trip all fields through a JSON object keyed by field name.
    """

    in_channels: int = 3
    patch_size: int = 16
    model_dim: int = 896
    encoder_depth: int = 4
    decoder_depth: int = 4
    bottleneck_dim: int = 64
    mlp_ratio: float = 4.0
    depthwise_kernel_size: int = 7
    adaln_low_rank_rank: int = 128

    # Lower/upper logSNR bounds (per the field names; the consumer of these
    # values is not visible in this file).
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0

    pixel_noise_std: float = 0.558

    pdg_mask_ratio: float = 0.75

    def save(self, path: str | Path) -> None:
        """Save config as JSON.

        Missing parent directories are created. The file holds the dataclass
        fields as a 2-space-indented JSON object followed by a newline.

        Args:
            path: Destination file path.
        """
        target = Path(path)
        target.parent.mkdir(parents=True, exist_ok=True)
        payload = json.dumps(asdict(self), indent=2) + "\n"
        # Pin the encoding so the file does not depend on the platform locale.
        target.write_text(payload, encoding="utf-8")

    @classmethod
    def load(cls, path: str | Path) -> MDiffAEConfig:
        """Load config from JSON.

        Args:
            path: Path to a JSON file whose top-level value is an object
                mapping field names to values.

        Returns:
            A new config built from the stored values; fields absent from the
            file keep their defaults.

        Raises:
            TypeError: If the top-level JSON value is not an object, or if it
                contains a key that is not a field of this class.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        if not isinstance(data, dict):
            # Same exception type ``cls(**data)`` would raise, but with a
            # message that points at the offending file.
            raise TypeError(
                f"{path}: expected a JSON object at top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
| |
|
| |
|
@dataclass
class MDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG is very sensitive in mDiffAE — use small strengths (1.05–1.2).

    Unlike a frozen dataclass, instances are mutable, so individual settings
    can be adjusted after construction.
    """

    num_steps: int = 1
    # NOTE(review): the set of accepted sampler/schedule names is defined by
    # whatever consumes this config, which is not visible in this file.
    sampler: str = "ddim"
    schedule: str = "linear"
    pdg_enabled: bool = False
    pdg_strength: float = 1.1
    # No seed is set by default.
    seed: int | None = None
| |
|