File size: 4,309 Bytes
19ed37d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
from __future__ import annotations

"""configuration_sedd.py
====================================
HuggingFace *Transformers* configuration class for the `SEDD` architecture.

This mirrors the structure of other community models in 🤗 Transformers so that
`AutoConfig` can correctly instantiate the model.

The default values roughly reproduce the "small" setup shipped in
`configs/model/small.yaml` of this repository.
"""

from typing import Any, Dict

from transformers.configuration_utils import PretrainedConfig

try:
    # `omegaconf` is an explicit dependency of the original SEDD implementation.
    from omegaconf import OmegaConf  # type: ignore
except ImportError:  # pragma: no cover – users might wish to load a config without installing omegaconf
    OmegaConf = None  # type: ignore

__all__ = [
    "SEDDConfig",
]


class SEDDConfig(PretrainedConfig):
    """HF configuration for the SEDD score-based model.

    Attributes are stored *flat* (``model_hidden_size`` rather than a nested
    ``model.hidden_size``) so the standard JSON serialisation of
    :class:`~transformers.PretrainedConfig` round-trips cleanly.  Use
    :meth:`to_hydra` to recover the nested OmegaConf layout the reference
    SEDD implementation expects.

    Parameters
    ----------
    tokens:
        Vocabulary size (default 50257 — the GPT-2 vocabulary).
    graph_type:
        Token-graph variant; ``"absorb"`` matches the reference implementation.
    model_hidden_size:
        Transformer hidden-state dimension.
    model_cond_dim:
        Dimension of the noise-level conditioning embedding.
    model_length:
        Fixed maximum sequence length the model was trained with.
    model_n_blocks:
        Number of DDiT transformer blocks in the network.
    model_n_heads:
        Number of attention heads per DDiT block.
    model_scale_by_sigma:
        Whether output logits are scaled by the noise level (see
        ``SEDD.forward``).
    model_dropout:
        Dropout probability applied throughout the network.
    tie_word_embeddings:
        Not used by SEDD itself, but accepted (and therefore serialised to
        JSON) because the base class expects it.
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        # Graph section
        graph_type: str = "absorb",
        # Model section
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # Miscellaneous / HF specific
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        # `tie_word_embeddings` is forwarded explicitly because
        # `PretrainedConfig` validates its keyword-only signature.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Flat attribute layout, mirroring most HF model configs.
        # Top-level / graph settings first, then the model section.
        self.tokens = tokens
        self.graph_type = graph_type

        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    # ------------------------------------------------------------------
    # Compatibility helpers
    # ------------------------------------------------------------------

    def to_hydra(self):
        """Return the nested OmegaConf view of this flat configuration.

        The structure matches what the reference `SEDD` implementation
        expects (top-level ``tokens`` plus ``graph`` and ``model`` sections).

        Raises
        ------
        RuntimeError
            If ``omegaconf`` is not installed.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        graph_section: Dict[str, Any] = {"type": self.graph_type}
        model_section: Dict[str, Any] = {
            "hidden_size": self.model_hidden_size,
            "cond_dim": self.model_cond_dim,
            "length": self.model_length,
            "n_blocks": self.model_n_blocks,
            "n_heads": self.model_n_heads,
            "scale_by_sigma": self.model_scale_by_sigma,
            "dropout": self.model_dropout,
        }
        return OmegaConf.create(
            {
                "tokens": self.tokens,
                "graph": graph_section,
                "model": model_section,
            }
        )