# sedd-medium / configuration_sedd.py
# Uploaded by pbcong via huggingface_hub (revision 19ed37d, verified).
# NOTE(review): the four lines above were Hub web-page residue, not Python
# source; they have been converted to comments so the module can be imported.
from __future__ import annotations
"""configuration_sedd.py
====================================
HuggingFace *Transformers* configuration class for the `SEDD` architecture.
This mirrors the structure of other community models in πŸ€— Transformers so that
`AutoConfig` can correctly instantiate the model.
The default values roughly reproduce the "small" setup shipped in
`configs/model/small.yaml` of this repository.
"""
from typing import Any, Dict
from transformers.configuration_utils import PretrainedConfig
try:
# `omegaconf` is an explicit dependency of the original SEDD implementation.
from omegaconf import OmegaConf # type: ignore
except ImportError: # pragma: no cover – users might wish to load a config without installing omegaconf
OmegaConf = None # type: ignore
__all__ = [
"SEDDConfig",
]
class SEDDConfig(PretrainedConfig):
    """Configuration for the SEDD score-based discrete diffusion model.

    The attributes are stored *flat* (one per keyword argument) so they
    serialise cleanly into the standard ``config.json`` format; use
    :meth:`to_hydra` to recover the nested layout expected by the reference
    SEDD code base.

    Parameters
    ----------
    tokens:
        Size of the tokenizer vocabulary (default: 50257 – GPT-2 vocab).
    graph_type:
        Type of token graph to use ("absorb" matches the reference implementation).
    model_hidden_size:
        Dimension of the transformer hidden states.
    model_cond_dim:
        Dimension of the conditional embedding for the noise level.
    model_length:
        Maximum (fixed) sequence length the model was trained with.
    model_n_blocks:
        Number of *DDiT* blocks in the network.
    model_n_heads:
        Number of attention heads per *DDiT* block.
    model_scale_by_sigma:
        Whether to scale the output logits by the noise level (see
        `SEDD.forward`).
    model_dropout:
        Drop-out probability used throughout the network.
    tie_word_embeddings:
        Standard Transformer flag – not used by SEDD but required by the base
        class. Must be present so that the value is serialised in the resulting
        JSON file.
    """

    # Registered model identifier used by `AutoConfig`.
    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        # Graph section
        graph_type: str = "absorb",
        # Model section
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # Miscellaneous / HF specific
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        # `tie_word_embeddings` is forwarded to `PretrainedConfig`, which
        # owns that flag and serialises it for us.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Store everything flat, matching the convention of most HF configs.
        self.tokens = tokens
        self.graph_type = graph_type
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    # ------------------------------------------------------------------
    # Compatibility helpers
    # ------------------------------------------------------------------
    def to_hydra(self):
        """Return the nested OmegaConf structure expected by the reference
        `SEDD` implementation, built from this flat configuration.
        """
        # Guard: omegaconf is an optional import at module level.
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        graph_section: Dict[str, Any] = {"type": self.graph_type}
        model_section: Dict[str, Any] = {
            "hidden_size": self.model_hidden_size,
            "cond_dim": self.model_cond_dim,
            "length": self.model_length,
            "n_blocks": self.model_n_blocks,
            "n_heads": self.model_n_heads,
            "scale_by_sigma": self.model_scale_by_sigma,
            "dropout": self.model_dropout,
        }
        nested: Dict[str, Any] = {
            "tokens": self.tokens,
            "graph": graph_section,
            "model": model_section,
        }
        return OmegaConf.create(nested)