pbcong commited on
Commit
19ed37d
·
verified ·
1 Parent(s): 0812c72

Upload folder using huggingface_hub

Browse files
Files changed (5) hide show
  1. README.md +1 -0
  2. config.json +61 -0
  3. configuration_sedd.py +122 -0
  4. pytorch_model.bin +3 -0
  5. sedd_wrapper.py +289 -0
README.md ADDED
@@ -0,0 +1 @@
 
 
1
+ Score Entropy Discrete Diffusion (SEDD) medium model, for use with the inference code at https://github.com/louaaron/Score-Entropy-Discrete-Diffusion. The accompanying paper is available at https://arxiv.org/abs/2310.16834.
config.json ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "ngpus": 8,
3
+ "tokens": 50257,
4
+ "training": {
5
+ "batch_size": 512,
6
+ "accum": 2,
7
+ "n_iters": 1300001,
8
+ "snapshot_freq": 50000,
9
+ "log_freq": 50,
10
+ "eval_freq": 100,
11
+ "snapshot_freq_for_preemption": 10000,
12
+ "weight": "standard",
13
+ "snapshot_sampling": true,
14
+ "ema": 0.9999
15
+ },
16
+ "data": {
17
+ "train": "openwebtext",
18
+ "valid": "wikitext103",
19
+ "cache_dir": "data"
20
+ },
21
+ "graph": {
22
+ "type": "absorb"
23
+ },
24
+ "noise": {
25
+ "type": "loglinear",
26
+ "sigma_min": 0.0001,
27
+ "sigma_max": 20
28
+ },
29
+ "sampling": {
30
+ "predictor": "euler",
31
+ "steps": 128,
32
+ "noise_removal": true
33
+ },
34
+ "eval": {
35
+ "batch_size": 512,
36
+ "perplexity": true,
37
+ "perplexity_batch_size": 32
38
+ },
39
+ "optim": {
40
+ "weight_decay": 0,
41
+ "optimizer": "AdamW",
42
+ "lr": 0.0003,
43
+ "beta1": 0.9,
44
+ "beta2": 0.999,
45
+ "eps": 1e-08,
46
+ "warmup": 2500,
47
+ "grad_clip": 1.0
48
+ },
49
+ "model": {
50
+ "name": "medium",
51
+ "type": "ddit",
52
+ "hidden_size": 1024,
53
+ "cond_dim": 128,
54
+ "length": 1024,
55
+ "n_blocks": 24,
56
+ "n_heads": 16,
57
+ "scale_by_sigma": true,
58
+ "dropout": 0.1
59
+ },
60
+ "work_dir": "absorb_medium"
61
+ }
configuration_sedd.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """configuration_sedd.py
4
+ ====================================
5
+ HuggingFace *Transformers* configuration class for the `SEDD` architecture.
6
+
7
+ This mirrors the structure of other community models in 🤗 Transformers so that
8
+ `AutoConfig` can correctly instantiate the model.
9
+
10
+ The default values roughly reproduce the "small" setup shipped in
11
+ `configs/model/small.yaml` of this repository.
12
+ """
13
+
14
+ from typing import Any, Dict
15
+
16
+ from transformers.configuration_utils import PretrainedConfig
17
+
18
+ try:
19
+ # `omegaconf` is an explicit dependency of the original SEDD implementation.
20
+ from omegaconf import OmegaConf # type: ignore
21
+ except ImportError: # pragma: no cover – users might wish to load a config without installing omegaconf
22
+ OmegaConf = None # type: ignore
23
+
24
+ __all__ = [
25
+ "SEDDConfig",
26
+ ]
27
+
28
+
29
class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD score-based model.

    The defaults approximate the "small" setup shipped in
    ``configs/model/small.yaml`` of the reference repository.

    Parameters
    ----------
    tokens:
        Tokenizer vocabulary size (default: 50257 – GPT-2 vocab).
    graph_type:
        Token-graph variant; ``"absorb"`` matches the reference implementation.
    model_hidden_size:
        Transformer hidden-state dimension.
    model_cond_dim:
        Dimension of the noise-level conditioning embedding.
    model_length:
        Fixed maximum sequence length the model was trained with.
    model_n_blocks:
        Number of *DDiT* blocks in the network.
    model_n_heads:
        Attention heads per *DDiT* block.
    model_scale_by_sigma:
        Whether output logits are scaled by the noise level (see
        ``SEDD.forward``).
    model_dropout:
        Dropout probability used throughout the network.
    tie_word_embeddings:
        Not used by SEDD itself, but forwarded to the base class so the flag
        is serialised into the resulting JSON file.
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        # Graph section
        graph_type: str = "absorb",
        # Model section
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # Miscellaneous / HF specific
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        # `tie_word_embeddings` belongs to the base class, which validates its
        # keyword-only signature.
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Attributes are kept *flat*, matching the style of most HF configs.
        self.tokens = tokens
        self.graph_type = graph_type
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    # ------------------------------------------------------------------
    # Compatibility helpers
    # ------------------------------------------------------------------

    def to_hydra(self):
        """Build the nested OmegaConf structure expected by the reference
        ``SEDD`` implementation from this *flat* configuration.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        model_section: Dict[str, Any] = {
            "hidden_size": self.model_hidden_size,
            "cond_dim": self.model_cond_dim,
            "length": self.model_length,
            "n_blocks": self.model_n_blocks,
            "n_heads": self.model_n_heads,
            "scale_by_sigma": self.model_scale_by_sigma,
            "dropout": self.model_dropout,
        }
        return OmegaConf.create(
            {
                "tokens": self.tokens,
                "graph": {"type": self.graph_type},
                "model": model_section,
            }
        )
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d93bb0dd1013295a4865848ea546ee3763a5be036cf55ea407e898c0a7a82a33
3
+ size 1698000441
sedd_wrapper.py ADDED
@@ -0,0 +1,289 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ """sedd_wrapper.py
4
+ =========================================
5
+ This module provides a minimal HuggingFace-compatible wrapper around the
6
+ `SEDD` architecture that is implemented in :pyfile:`model/transformer.py`.
7
+
8
+ The wrapper closely follows the design used in the Aero implementation that
9
+ lives in this code-base (see :pyfile:`configuration_aero.py` and
10
+ :pyfile:`modeling_aero.py`). Concretely we expose three public objects:
11
+
12
+ * ``SEDDConfig`` A :class:`transformers.PretrainedConfig` subclass that
13
+ stores the hyper-parameters needed to instantiate a ``SEDD`` model.
14
+ * ``SEDDModel`` A :class:`transformers.PreTrainedModel` subclass that
15
+ internally contains an instance of the original ``SEDD`` network and maps
16
+ from ``input_ids`` + ``sigma`` to the vocabulary logits.
17
+ * ``SEDDOutput`` A thin :class:`transformers.modeling_outputs.ModelOutput`
18
+ dataclass that mirrors the usual "logits / loss" structure.
19
+
20
+ With this wrapper a trained model checkpoint can be pushed to / loaded from
21
+ 🤗 Hub via ``SEDDModel.push_to_hub`` / ``SEDDModel.from_pretrained`` the same
22
+ way as any other ``transformers`` model.
23
+ """
24
+
25
+ from dataclasses import dataclass
26
+ from typing import Optional, Tuple, List, Dict, Any, Union
27
+
28
+ import torch
29
+ from torch import nn
30
+ from transformers.configuration_utils import PretrainedConfig
31
+ from transformers.modeling_outputs import ModelOutput
32
+ from transformers.modeling_utils import PreTrainedModel
33
+ from transformers.utils import logging
34
+
35
+ # Original SEDD implementation
36
+ from model.transformer import SEDD as _OrigSEDD
37
+
38
+ try:
39
+ from omegaconf import OmegaConf
40
+ except ImportError: # pragma: no cover – omegaconf is an explicit dependency of SEDD
41
+ OmegaConf = None # type: ignore
42
+
43
+ logger = logging.get_logger(__name__)
44
+
45
+ ###############################################################################
46
+ # Configuration #
47
+ ###############################################################################
48
+
49
+
50
class SEDDConfig(PretrainedConfig):
    """Configuration class for the SEDD architecture.

    The defaults *roughly* reproduce the "small" configuration shipped in
    ``configs/model/small.yaml``. Keys present in the original Hydra config
    but not needed for instantiation (e.g. *training* hyper-parameters) are
    deliberately omitted; users may still persist them as *extra* fields in
    the underlying JSON if desired.
    """

    model_type: str = "sedd"

    def __init__(
        self,
        *,
        tokens: int = 50257,
        # graph section
        graph_type: str = "absorb",
        # model section (mirrors configs/model/*.yaml)
        model_hidden_size: int = 768,
        model_cond_dim: int = 128,
        model_length: int = 1024,
        model_n_blocks: int = 12,
        model_n_heads: int = 12,
        model_scale_by_sigma: bool = True,
        model_dropout: float = 0.10,
        # miscellaneous
        tie_word_embeddings: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(tie_word_embeddings=tie_word_embeddings, **kwargs)

        # Flat top-level attributes, for simplicity.
        self.tokens = tokens
        self.graph_type = graph_type

        # Model hyper-parameters.
        self.model_hidden_size = model_hidden_size
        self.model_cond_dim = model_cond_dim
        self.model_length = model_length
        self.model_n_blocks = model_n_blocks
        self.model_n_heads = model_n_heads
        self.model_scale_by_sigma = model_scale_by_sigma
        self.model_dropout = model_dropout

    # ---------------------------------------------------------------------
    # Serialization helpers – bridge to the original Hydra config structure
    # that the reference implementation expects.
    # ---------------------------------------------------------------------

    def to_hydra(self):
        """Convert this *flat* config into the nested OmegaConf structure
        consumed by the reference ``SEDD`` implementation.
        """
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to build a Hydra config")

        model_section: Dict[str, Any] = {
            "hidden_size": self.model_hidden_size,
            "cond_dim": self.model_cond_dim,
            "length": self.model_length,
            "n_blocks": self.model_n_blocks,
            "n_heads": self.model_n_heads,
            "scale_by_sigma": self.model_scale_by_sigma,
            "dropout": self.model_dropout,
        }
        return OmegaConf.create(
            {
                "tokens": self.tokens,
                "graph": {"type": self.graph_type},
                "model": model_section,
            }
        )
124
+
125
+ ###############################################################################
126
+ # Output container #
127
+ ###############################################################################
128
+
129
+
130
@dataclass
class SEDDOutput(ModelOutput):
    """Standard output container for :class:`SEDDModel`.

    Attributes
    ----------
    loss:
        Optional scalar, present only when ``labels`` were provided.
    logits:
        Raw vocabulary logits of shape
        ``(batch_size, sequence_length, vocab_size)``.
    """

    loss: Optional[torch.FloatTensor] = None
    logits: Optional[torch.FloatTensor] = None
145
+
146
+ ###############################################################################
147
+ # Model #
148
+ ###############################################################################
149
+
150
+
151
class SEDDModel(PreTrainedModel):
    """HuggingFace *Transformers* wrapper around the original ``SEDD`` model.

    The wrapped network lives in ``self.score_model``; this class maps the
    HF-style ``(input_ids, sigma)`` call convention onto it, optionally
    computes a cross-entropy loss, and knows how to load legacy SEDD
    checkpoints.
    """

    config_class = SEDDConfig
    base_model_prefix = "score_model"
    _no_split_modules: List[str] = [
        "DDiTBlock",  # ensure these blocks are not split when using FSDP/TP
    ]

    def __init__(self, config: SEDDConfig):
        super().__init__(config)

        # ------------------------------------------------------------------
        # The reference SEDD implementation expects a nested Hydra/OmegaConf
        # config, so convert the flat HF config first.
        # ------------------------------------------------------------------
        if OmegaConf is None:
            raise RuntimeError("`omegaconf` is required to instantiate SEDD")

        hydra_cfg = config.to_hydra()
        self.score_model = _OrigSEDD(hydra_cfg)

        # Standard HF post-init (weight init, device/dtype bookkeeping).
        self.post_init()

    # ------------------------------------------------------------------
    # Forward pass
    # ------------------------------------------------------------------

    def forward(
        self,
        input_ids: torch.LongTensor,
        sigma: torch.FloatTensor,
        labels: Optional[torch.LongTensor] = None,
        **kwargs: Any,
    ) -> Union[SEDDOutput, Tuple]:
        """Run a forward pass.

        Parameters
        ----------
        input_ids:
            Token indices of shape ``(batch_size, seq_len)``.
        sigma:
            Noise level ("time-step") of shape ``(batch_size,)``.
        labels:
            *Optional* label tensor used to compute a cross-entropy training
            loss. If provided the returned :class:`SEDDOutput` will contain a
            ``loss`` field.

        Returns
        -------
        :class:`SEDDOutput`, or a plain tuple when ``config.return_dict`` is
        falsy.
        """
        logits = self.score_model(indices=input_ids, sigma=sigma)

        loss: Optional[torch.Tensor] = None
        if labels is not None:
            # Standard CE loss over the last dimension (vocab).
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, logits.size(-1)), labels.view(-1))

        if not self.config.return_dict:
            output: Tuple[Any, ...] = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SEDDOutput(loss=loss, logits=logits)

    # ------------------------------------------------------------------
    # Weight loading helpers – delegate to the *original* SEDD layout so that
    # checkpoints trained with the previous implementation can be re-used.
    # ------------------------------------------------------------------

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: str,
        *model_args: Any,
        **kwargs: Any,
    ) -> "SEDDModel":
        """Overrides the default method to allow loading legacy SEDD checkpoints
        whose weights are saved via ``torch.save({'model': state_dict, ...})``
        under ``checkpoints-meta/checkpoint.pth``.
        """
        try:
            # First try the regular *transformers* loading routine – this will
            # succeed if the repository follows the standard file-naming
            # conventions (i.e. contains a ``pytorch_model.bin`` / safetensors).
            return super().from_pretrained(
                pretrained_model_name_or_path, *model_args, **kwargs
            )
        except (EnvironmentError, RuntimeError) as e:
            logger.info(
                "Falling back to legacy SEDD checkpoint format because standard "
                "loading raised: %s", e,
            )

        # ----------------------------------------------------------
        # 1. Load config the usual way so we get a `SEDDConfig` instance.
        # ----------------------------------------------------------
        config = kwargs.pop("config", None) or SEDDConfig.from_pretrained(
            pretrained_model_name_or_path
        )
        model = cls(config, *model_args, **kwargs)

        # ----------------------------------------------------------
        # 2. Attempt to locate the legacy *.pth* checkpoint and load it.
        # ----------------------------------------------------------
        import os

        checkpoint_path = os.path.join(
            pretrained_model_name_or_path, "checkpoints-meta", "checkpoint.pth"
        )
        if not os.path.isfile(checkpoint_path):
            raise FileNotFoundError(
                "Could not find legacy SEDD checkpoint at " f"{checkpoint_path}"
            )

        # NOTE(security): `torch.load` unpickles arbitrary objects – only load
        # checkpoints that come from a trusted source.
        ckpt = torch.load(checkpoint_path, map_location="cpu")
        state_dict = ckpt.get("model", ckpt)

        # Legacy checkpoints store the raw SEDD weights (possibly under a DDP
        # "module." prefix), while this wrapper nests the network under
        # ``score_model.``. Remap the keys accordingly – without this every
        # key is reported missing/unexpected by ``strict=False`` and NO
        # weights are actually loaded.
        prefix = cls.base_model_prefix + "."
        remapped: Dict[str, torch.Tensor] = {}
        for key, value in state_dict.items():
            if key.startswith("module."):
                key = key[len("module."):]
            if not key.startswith(prefix):
                key = prefix + key
            remapped[key] = value

        missing, unexpected = model.load_state_dict(remapped, strict=False)
        if missing:
            logger.warning("Missing keys when loading SEDD weights: %s", missing)
        if unexpected:
            logger.warning(
                "Unexpected keys when loading SEDD weights: %s", unexpected
            )
        return model
280
+
281
+ ###############################################################################
282
+ # Public API #
283
+ ###############################################################################
284
+
285
+ __all__ = [
286
+ "SEDDConfig",
287
+ "SEDDModel",
288
+ "SEDDOutput",
289
+ ]