# mdiffae-v1 / m_diffae/config.py
# Uploaded by data-archetype using huggingface_hub (commit 128cb34, verified)
"""Frozen model architecture and user-tunable inference configuration."""
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from pathlib import Path
@dataclass(frozen=True)
class MDiffAEConfig:
    """Frozen model architecture config. Stored alongside weights as config.json.

    Instances are immutable (``frozen=True``); round-trip through JSON with
    :meth:`save` / :meth:`load`.
    """

    in_channels: int = 3            # input channels (3 presumably RGB)
    patch_size: int = 16            # spatial patch size used for tokenization
    model_dim: int = 896            # main model width
    encoder_depth: int = 4          # number of encoder blocks
    decoder_depth: int = 4          # number of decoder blocks
    bottleneck_dim: int = 64        # semantic bottleneck dimensionality
    mlp_ratio: float = 4.0          # MLP hidden width as a multiple of model_dim
    depthwise_kernel_size: int = 7  # kernel size of the depthwise convolution
    adaln_low_rank_rank: int = 128  # rank of the low-rank AdaLN projection
    # VP diffusion schedule endpoints
    logsnr_min: float = -10.0
    logsnr_max: float = 10.0
    # Pixel-space noise std for VP diffusion initialization
    pixel_noise_std: float = 0.558
    # Token mask ratio for PDG (fraction of spatial tokens replaced with mask_feature)
    pdg_mask_ratio: float = 0.75

    def save(self, path: str | Path) -> None:
        """Write this config to *path* as pretty-printed JSON.

        Parent directories are created if missing; an existing file is
        overwritten.
        """
        p = Path(path)
        p.parent.mkdir(parents=True, exist_ok=True)
        p.write_text(json.dumps(asdict(self), indent=2) + "\n")

    @classmethod
    def load(cls, path: str | Path) -> MDiffAEConfig:
        """Load a config from a JSON file written by :meth:`save`.

        Unknown keys in the JSON are ignored so that files written by newer
        code versions (with extra fields) still load; keys absent from the
        file fall back to the dataclass defaults.
        """
        data = json.loads(Path(path).read_text())
        known = cls.__dataclass_fields__
        return cls(**{k: v for k, v in data.items() if k in known})
@dataclass
class MDiffAEInferenceConfig:
    """User-tunable inference parameters with sensible defaults.

    PDG is very sensitive in mDiffAE — use small strengths (1.05–1.2).

    Raises:
        ValueError: if constructed with an unsupported ``sampler`` or
            ``schedule``, ``num_steps`` < 1, or a non-positive
            ``pdg_strength``.
    """

    # Supported option strings (class-level constants, not dataclass fields).
    _SAMPLERS = ("ddim", "dpmpp_2m")
    _SCHEDULES = ("linear", "cosine")

    num_steps: int = 1       # decoder forward passes (NFE)
    sampler: str = "ddim"    # "ddim" or "dpmpp_2m"
    schedule: str = "linear" # "linear" or "cosine"
    pdg_enabled: bool = False
    pdg_strength: float = 1.1
    seed: int | None = None

    def __post_init__(self) -> None:
        """Fail fast on invalid options instead of erroring mid-inference."""
        if self.sampler not in self._SAMPLERS:
            raise ValueError(f"sampler must be one of {self._SAMPLERS}, got {self.sampler!r}")
        if self.schedule not in self._SCHEDULES:
            raise ValueError(f"schedule must be one of {self._SCHEDULES}, got {self.schedule!r}")
        if self.num_steps < 1:
            raise ValueError(f"num_steps must be >= 1, got {self.num_steps}")
        if self.pdg_strength <= 0:
            raise ValueError(f"pdg_strength must be positive, got {self.pdg_strength}")