"""configuration_sedd.py
====================================

HuggingFace *Transformers* configuration class for the `SEDD` architecture.

This mirrors the structure of other community models in 🤗 Transformers so
that `AutoConfig` can correctly instantiate the model.

The default values roughly reproduce the "small" setup shipped in
`configs/model/small.yaml` of this repository.
"""

# NOTE: the module docstring above must precede this `__future__` import.
# A string literal placed after any statement is a plain expression, not a
# docstring, and would leave the module's `__doc__` set to `None`.
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict

from transformers.configuration_utils import PretrainedConfig

try:
    # `omegaconf` is an explicit dependency of the original SEDD implementation.
    from omegaconf import OmegaConf  # type: ignore
except ImportError:  # pragma: no cover – users might wish to load a config without installing omegaconf
    OmegaConf = None  # type: ignore

if TYPE_CHECKING:  # pragma: no cover – only needed for the `to_hydra` annotation
    from omegaconf import DictConfig

__all__ = [
    "SEDDConfig",
]


class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD score-based model.

    Parameters
    ----------
    tokens : int
        Size of the tokenizer vocabulary (default: 50257 – GPT-2 vocab).
    graph_type : str
        Type of token graph to use ("absorb" matches the reference
        implementation).
    model_hidden_size : int
        Dimension of the transformer hidden states.
    model_cond_dim : int
        Dimension of the conditional embedding for the noise level.
    model_length : int
        Maximum (fixed) sequence length the model was trained with.
    model_n_blocks : int
        Number of *DDiT* blocks in the network.
    model_n_heads : int
        Number of attention heads per *DDiT* block.
    model_scale_by_sigma : bool
        Whether to scale the output logits by the noise level
        (see `SEDD.forward`).
    model_dropout : float
        Dropout probability used throughout the network.
    tie_word_embeddings : bool
        Standard Transformers flag – not used by SEDD but required by the
        base class. It must be present so that the value is serialised in the
        resulting JSON file.
    **kwargs
        Any further keyword arguments are forwarded unchanged to
        `PretrainedConfig`.
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        # Graph section
        graph_type: str = "absorb",
        # Model section
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # Miscellaneous / HF specific
        tie_word_embeddings: bool = False,
        **kwargs: Any,
    ) -> None:
        # Hand `tie_word_embeddings` and all remaining HF-level kwargs to the
        # base class, which owns those attributes and their serialisation.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Keep attributes *flat* – matching the style used by most HF models.
        # -------------------------------------------------------------------
        self.tokens = tokens

        self.graph_type = graph_type

        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    # ------------------------------------------------------------------
    # Compatibility helpers
    # ------------------------------------------------------------------
    def to_hydra(self) -> DictConfig:
        """Convert this *flat* configuration to the nested OmegaConf structure
        expected by the reference `SEDD` implementation.

        Returns
        -------
        omegaconf.DictConfig
            A config of the form ``{"tokens": ..., "graph": {"type": ...},
            "model": {...}}``. The ``model`` section holds the ``model_*``
            attributes with their ``model_`` prefix stripped.

        Raises
        ------
        RuntimeError
            If `omegaconf` is not installed.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        nested: Dict[str, Any] = {
            "tokens": self.tokens,
            "graph": {
                "type": self.graph_type,
            },
            "model": {
                "hidden_size": self.model_hidden_size,
                "cond_dim": self.model_cond_dim,
                "length": self.model_length,
                "n_blocks": self.model_n_blocks,
                "n_heads": self.model_n_heads,
                "scale_by_sigma": self.model_scale_by_sigma,
                "dropout": self.model_dropout,
            },
        }
        return OmegaConf.create(nested)