|
|
from __future__ import annotations |
|
|
|
|
|
"""configuration_sedd.py |
|
|
==================================== |
|
|
HuggingFace *Transformers* configuration class for the `SEDD` architecture. |
|
|
|
|
|
This mirrors the structure of other community models in 🤗 Transformers so that
|
|
`AutoConfig` can correctly instantiate the model. |
|
|
|
|
|
The default values roughly reproduce the "small" setup shipped in |
|
|
`configs/model/small.yaml` of this repository. |
|
|
""" |
|
|
|
|
|
from typing import Any, Dict |
|
|
|
|
|
from transformers.configuration_utils import PretrainedConfig |
|
|
|
|
|
# `omegaconf` is an optional dependency: it is only needed by
# `SEDDConfig.to_hydra`, so a missing installation must not break plain
# config loading.  The sentinel `None` is checked there before use.
try:
    from omegaconf import OmegaConf
except ImportError:
    OmegaConf = None


# Explicit public API of this module.
__all__ = [
    "SEDDConfig",
]
|
|
|
|
|
|
|
|
class SEDDConfig(PretrainedConfig):
    """HuggingFace configuration for the SEDD score-based model.

    The flat attribute layout mirrors the nested Hydra config used by the
    reference implementation; :meth:`to_hydra` rebuilds that nesting.

    Parameters
    ----------
    tokens:
        Tokenizer vocabulary size (default 50257 — the GPT-2 vocabulary).
    graph_type:
        Token-graph variant; ``"absorb"`` matches the reference implementation.
    model_hidden_size:
        Width of the transformer hidden states.
    model_cond_dim:
        Width of the conditional embedding for the noise level.
    model_length:
        Maximum (fixed) sequence length the model was trained with.
    model_n_blocks:
        Number of *DDiT* blocks in the network.
    model_n_heads:
        Number of attention heads per *DDiT* block.
    model_scale_by_sigma:
        Whether the output logits are scaled by the noise level (see
        ``SEDD.forward``).
    model_dropout:
        Drop-out probability used throughout the network.
    tie_word_embeddings:
        Standard Transformer flag — unused by SEDD itself but required by the
        base class, and kept here so the value is serialised in the resulting
        JSON file.
    """

    # Key under which `AutoConfig` resolves this configuration class.
    model_type: str = "sedd"

    def __init__(
        self,
        *,  # every hyper-parameter is keyword-only
        tokens: int = 50257,
        graph_type: str = "absorb",
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        # Let the base class record the standard flags first.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Store every hyper-parameter verbatim; plain instance attributes are
        # what ends up serialised by the base class.
        for attr_name, attr_value in (
            ("tokens", tokens),
            ("graph_type", graph_type),
            ("model_hidden_size", model_hidden_size),
            ("model_cond_dim", model_cond_dim),
            ("model_length", model_length),
            ("model_n_blocks", model_n_blocks),
            ("model_n_heads", model_n_heads),
            ("model_scale_by_sigma", model_scale_by_sigma),
            ("model_dropout", model_dropout),
        ):
            setattr(self, attr_name, attr_value)

    def to_hydra(self):
        """Rebuild the nested OmegaConf structure expected by the reference
        ``SEDD`` implementation from this *flat* configuration.

        Raises
        ------
        RuntimeError
            If ``omegaconf`` is not installed.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        model_section: Dict[str, Any] = {
            "hidden_size": self.model_hidden_size,
            "cond_dim": self.model_cond_dim,
            "length": self.model_length,
            "n_blocks": self.model_n_blocks,
            "n_heads": self.model_n_heads,
            "scale_by_sigma": self.model_scale_by_sigma,
            "dropout": self.model_dropout,
        }
        return OmegaConf.create(
            {
                "tokens": self.tokens,
                "graph": {"type": self.graph_type},
                "model": model_section,
            }
        )