Seas0 committed on
Commit
658a484
·
verified ·
1 Parent(s): d0808fa

Update to support transformers v5.3.0

Browse files
Files changed (2) hide show
  1. config.json +30 -30
  2. modeling_stable_diffcoder.py +298 -0
config.json CHANGED
@@ -1,31 +1,31 @@
1
  {
2
- "architectures": [
3
- "LlamaForCausalLM"
4
- ],
5
- "auto_map": {
6
- "AutoModelForCausalLM": "modeling_seed_diffcoder.SeedDiffcoderForCausalLM"
7
- },
8
- "attention_bias": false,
9
- "attention_dropout": 0.1,
10
- "bos_token_id": 0,
11
- "eos_token_id": 2,
12
- "hidden_act": "silu",
13
- "hidden_size": 4096,
14
- "initializer_range": 0.009882118,
15
- "intermediate_size": 14336,
16
- "layer_norm_eps": null,
17
- "max_position_embeddings": 8192,
18
- "mlp_bias": false,
19
- "model_type": "llama",
20
- "num_attention_heads": 32,
21
- "num_hidden_layers": 32,
22
- "num_key_value_heads": 8,
23
- "resid_pdrop": 0.1,
24
- "rms_norm_eps": 1e-06,
25
- "rope_theta": 500000.0,
26
- "tie_word_embeddings": false,
27
- "torch_dtype": "bfloat16",
28
- "transformers_version": "4.46.2",
29
- "use_cache": true,
30
- "vocab_size": 155136
31
- }
 
1
  {
2
+ "architectures": [
3
+ "StableDiffcoderForCausalLM"
4
+ ],
5
+ "auto_map": {
6
+ "AutoModelForCausalLM": "modeling_stable_diffcoder.StableDiffcoderForCausalLM"
7
+ },
8
+ "attention_bias": false,
9
+ "attention_dropout": 0.1,
10
+ "bos_token_id": 0,
11
+ "eos_token_id": 2,
12
+ "hidden_act": "silu",
13
+ "hidden_size": 4096,
14
+ "initializer_range": 0.009882118,
15
+ "intermediate_size": 14336,
16
+ "layer_norm_eps": null,
17
+ "max_position_embeddings": 8192,
18
+ "mlp_bias": false,
19
+ "model_type": "llama",
20
+ "num_attention_heads": 32,
21
+ "num_hidden_layers": 32,
22
+ "num_key_value_heads": 8,
23
+ "resid_pdrop": 0.1,
24
+ "rms_norm_eps": 1e-06,
25
+ "rope_theta": 500000.0,
26
+ "tie_word_embeddings": false,
27
+ "torch_dtype": "bfloat16",
28
+ "transformers_version": "5.3.0",
29
+ "use_cache": true,
30
+ "vocab_size": 155136
31
+ }
modeling_stable_diffcoder.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
2
+ # SPDX-License-Identifier: MIT
3
+
4
+ import numpy as np
5
+ import torch
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM, DynamicCache
9
+ from transformers.models.llama.modeling_llama import LlamaForCausalLM
10
+ from transformers.generation.utils import GenerationConfig
11
+
12
+
13
class StableDiffcoderForCausalLM(LlamaForCausalLM):
    """Block-wise diffusion decoder on a Llama backbone.

    Generation proceeds block by block: within each block, masked positions
    are iteratively denoised (unmasked) over several sampling steps, while a
    block-causal additive attention mask plus a KV cache keep earlier,
    completed blocks frozen. Only batch size 1 is supported.
    """

    def _get_num_transfer_tokens(self, mask_map, steps):
        """Split one block's masked-token budget evenly across ``steps``.

        Args:
            mask_map: bool tensor marking still-masked positions in the block.
            steps: number of denoising steps allotted to this block.

        Returns:
            LongTensor of shape ``(steps,)`` summing to ``mask_map.sum()``;
            the first ``remainder`` steps carry one extra token.
        """
        # Only bs == 1 is supported for now.
        mask_num = mask_map.sum().long().item()
        base, remainder = divmod(mask_num, steps)
        num_transfer_tokens = torch.full(
            (steps,), fill_value=base, device=mask_map.device, dtype=torch.long
        )
        num_transfer_tokens[:remainder] += 1
        return num_transfer_tokens

    def _make_block_causal_mask(
        self, seq_len, block_size=2, device=None, dtype=torch.bfloat16
    ):
        """Build an additive block-causal attention mask of shape (1, 1, L, L).

        Tokens attend bidirectionally inside their own block and causally to
        all earlier blocks. Allowed positions hold 1.0 and blocked positions
        -inf (the uniform +1.0 offset on allowed positions is harmless: it
        shifts all attention logits equally and cancels in the softmax).
        """
        # ceil(seq_len / block_size)
        num_blocks = (seq_len + block_size - 1) // block_size
        # global_mask = block_wise_causal_mask ⊗ per_block_all_ones
        block_mask = torch.tril(
            torch.ones((num_blocks, num_blocks), dtype=torch.bool, device=device)
        )
        local_block = torch.ones(
            (block_size, block_size), dtype=torch.bool, device=device
        )
        mask = torch.kron(block_mask, local_block)[:seq_len, :seq_len]
        # [x] [ ] [ ] [ ]
        # [x] [x] [ ] [ ]
        # [x] [x] [x] [ ]
        # [x] [x] [x] [x]

        # TODO: remove this -inf additive masking in favor of a boolean mask
        # once the model path accepts one.
        attention_mask = mask.float()
        attention_mask.masked_fill_(~mask, -torch.inf)
        return attention_mask.unsqueeze(0).unsqueeze(0).to(dtype)

    def _get_transfer_index(
        self,
        logits,
        temperature,
        remasking,
        mask_index,
        x,
        num_transfer_token,
        threshold=None,
        shift=False,
    ):
        """Choose which masked positions to unmask in this denoising step.

        Args:
            logits: (b, l, vocab) model logits for the current block.
            temperature: 0 means greedy argmax; otherwise Gumbel-noise sampling.
            remasking: confidence policy, "low_confidence" or "random".
            mask_index: bool (b, l) map of still-masked positions.
            x: (b, l) current token ids of the block.
            num_transfer_token: how many tokens to unmask this step
                (ignored when ``threshold`` is given).
            threshold: optional confidence cutoff; positions below it stay
                masked, but the single most confident token always transfers.
            shift: if True, treat logits as next-token predictions and
                realign them one position to the right.

        Returns:
            (x0, transfer_map): proposed token ids and the bool map of
            positions to actually commit.
        """

        def add_gumbel_noise(logits, temperature):
            # Gumbel-max trick in float64 for numerical stability:
            # argmax(exp(logits) / (-log u)^T) samples the tempered softmax.
            if temperature == 0:
                return logits
            logits = logits.to(torch.float64)
            noise = torch.rand_like(logits, dtype=torch.float64)
            return logits.exp() / (-torch.log(noise)) ** temperature

        x0 = torch.argmax(add_gumbel_noise(logits, temperature), dim=-1)  # b, l
        if shift:
            # Logits at position i predict token i+1: shift both right by one.
            x0 = torch.cat([x[:, :1], x0[:, :-1]], dim=-1)
            pad = torch.zeros_like(logits[:, :1])
            logits = torch.cat([pad, logits[:, :-1]], dim=1)

        if remasking == "low_confidence":
            p = F.softmax(logits.to(torch.float64), dim=-1)
            x0_p = torch.squeeze(
                torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1
            )  # b, l
        elif remasking == "random":
            x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
        else:
            raise NotImplementedError(remasking)

        # Keep already-committed tokens; score only still-masked positions.
        x0 = torch.where(mask_index, x0, x)
        confidence = torch.where(mask_index, x0_p, -np.inf)

        transfer_map = torch.zeros_like(x0, dtype=torch.bool)
        if threshold is not None:
            # Threshold mode considers every masked position.
            # FIX: use a plain int — the original kept a (1, 1) tensor here,
            # which is fragile as torch.topk's `k` and as a range() bound.
            num_transfer_token = int(mask_index.sum().item())
        _, select_index = torch.topk(confidence[0], k=num_transfer_token)
        transfer_map[0, select_index] = True
        if threshold is not None:
            # Drop selections below the cutoff, keeping the top-1 regardless
            # so each step makes progress.
            for k in range(1, num_transfer_token):
                if confidence[0, select_index[k]] < threshold:
                    transfer_map[0, select_index[k]] = False
        return x0, transfer_map

    @torch.no_grad()
    def generate_block(
        self,
        input_ids: torch.LongTensor,
        steps=128,
        gen_length=128,
        block_length=4,
        temperature=0.0,
        remasking="low_confidence",
        tokenizer=None,
        mask_id=5,
        threshold=0.95,
        shift=False,
        eos_id=None,
    ):
        """Run block-wise diffusion generation (batch size 1 only).

        Appends ``gen_length`` mask tokens to the prompt, then denoises them
        block by block with ``steps // num_blocks`` steps per block, reusing
        a KV cache for finished blocks.

        Returns:
            (x, nfe): generated ids (prompt included; truncated and padded at
            EOS when ``eos_id`` is given) and the number of forward passes.
        """
        # Initialize the canvas: prompt followed by gen_length mask tokens.
        x = torch.cat(
            [
                input_ids,
                torch.full(
                    (input_ids.shape[0], gen_length),
                    mask_id,
                    dtype=torch.long,
                    device=input_ids.device,
                ),
            ],
            dim=1,
        )

        # check the validity of block count
        assert gen_length % block_length == 0, (
            "gen_length must be divisible by block_length"
        )
        gen_blocks = gen_length // block_length

        # check the validity of sampling steps
        assert steps % gen_blocks == 0, (
            "steps must be divisible by the number of generation blocks"
        )
        steps = steps // gen_blocks

        # check bs == 1
        assert x.shape[0] == 1, (
            "Only batch size of 1 is supported for block-wise generation currently."
        )

        # construct block lengths
        prompt_length = input_ids.shape[1]
        gen_block_list = [block_length] * gen_blocks

        # If the prompt end is not block-aligned, let the first generated
        # block fill the partial prompt block and shrink the last block so
        # the total generated length stays gen_length.
        # FIX: the original computed block_length - (prompt_length %
        # block_length), which is block_length (never 0) for aligned prompts
        # and created a spurious zero-size trailing block that later drove a
        # forward pass on an empty sequence.
        res_block = (block_length - prompt_length % block_length) % block_length
        if res_block > 0:
            gen_block_list = [res_block] + gen_block_list
            gen_block_list[-1] = block_length - res_block
            gen_blocks += 1
        # cumulative block lengths (prefix sums, relative to prompt end)
        cum_block = [sum(gen_block_list[: i + 1]) for i in range(len(gen_block_list))]

        # make block-wise causal diffusion attention mask over the full span
        block_diffusion_attention_mask = self._make_block_causal_mask(
            prompt_length + gen_length,
            block_length,
            self.device,
            dtype=torch.bfloat16,
        )

        # TODO: better cache initialization method
        past_key_values = DynamicCache()

        nfe = 0
        final_flag = False
        # Prefill the KV cache with the block-aligned prefix of the prompt.
        prefill_length = prompt_length // block_length * block_length
        if prefill_length > 0:
            cur_attn_mask = block_diffusion_attention_mask[
                :, :, :prefill_length, :prefill_length
            ]
            # FIX: dropped the dangling `.past_key_values` access — the cache
            # is updated in place; this call is made for its side effect only.
            self(
                x[:, :prefill_length],
                past_key_values=past_key_values,
                attention_mask=cur_attn_mask,
                use_cache=True,
            )

        # iterative block-wise generation
        for block_id, block_size in enumerate(gen_block_list):
            block_start = (
                prompt_length + cum_block[block_id - 1]
                if block_id > 0
                else prefill_length
            )
            block_end = prompt_length + cum_block[block_id]

            block_mask_map = x[:, block_start:block_end] == mask_id
            # sampling noise schedule: per-step unmasking budget
            num_transfer_tokens = self._get_num_transfer_tokens(block_mask_map, steps)

            replace_position = torch.zeros_like(x, dtype=torch.bool)
            replace_position[:, block_start:block_end] = True

            for token_count in num_transfer_tokens:
                if token_count:
                    nfe += 1
                    mask_map = x[:, block_start:block_end] == mask_id
                    attention_mask = block_diffusion_attention_mask[
                        ..., block_start:block_end, :block_end
                    ]
                    output = self(
                        x[:, block_start:block_end],
                        attention_mask=attention_mask,
                        past_key_values=past_key_values,
                        use_cache=True,
                        cache_position=replace_position.nonzero(as_tuple=True)[1],
                    )
                    logits = output.logits

                    # The block is still partially masked: crop its cache
                    # entries so the next step recomputes them.
                    # IMPORTANT: check the correctness
                    past_key_values.crop(block_start)

                    # unmask based on policy of logits
                    x0, transfer_map = self._get_transfer_index(
                        logits,
                        temperature,
                        remasking,
                        mask_map,
                        x[:, block_start:block_end],
                        token_count if threshold is None else None,
                        threshold,
                        shift=False,
                    )
                    x[:, block_start:block_end][transfer_map] = x0[transfer_map]

                if (x[:, block_start:block_end] == mask_id).sum() == 0:
                    # Block fully decoded: stop early on EOS, else commit the
                    # finished block into the KV cache.
                    if (
                        eos_id is not None
                        and (x[:, block_start:block_end] == eos_id).sum() > 0
                    ):
                        final_flag = True
                        x = x[:, :block_end]
                        # pad everything after the first eos with eos_id
                        eos_pos = (x == eos_id).nonzero(as_tuple=True)[1][0].item()
                        x[0, eos_pos + 1:] = eos_id
                        break
                    nfe += 1
                    # one extra forward to write the block into the cache
                    self(
                        x[:, block_start:block_end],
                        attention_mask=block_diffusion_attention_mask[
                            ..., block_start:block_end, :block_end
                        ],
                        past_key_values=past_key_values,
                        use_cache=True,
                        cache_position=replace_position.nonzero(as_tuple=True)[1],
                    )
                    break

            if final_flag:
                break

        return x, nfe

    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        generation_config: GenerationConfig = None,
        **kwargs,
    ):
        """HF-style entry point; delegates to :meth:`generate_block`.

        Extra keyword arguments are forwarded to ``generate_block``.

        Raises:
            ValueError: if ``input_ids`` is not provided.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        if generation_config is None:
            # NOTE(review): resolved but currently unused by generate_block —
            # confirm whether sampling params should be read from it.
            generation_config = self.generation_config

        output_ids, _nfe = self.generate_block(
            input_ids=input_ids,
            **kwargs,
        )

        return output_ids