myduy committed on
Commit 5e91d83 · verified · 1 Parent(s): a27c346

Update modeling with argmax_decoding support

Files changed (1)
  1. modeling_dlm.py +556 -0
modeling_dlm.py ADDED
import torch
import torch.nn as nn
import torch.distributions as dists
from torch.nn.utils.rnn import pad_sequence
from transformers import PreTrainedModel, AutoModelForMaskedLM, AutoConfig

try:
    from .configuration_dlm import DiscreteDiffusionConfig
except ImportError:
    from configuration_dlm import DiscreteDiffusionConfig

from collections import namedtuple
import math
import numpy as np
from typing import List, Optional, Tuple, Union

decoder_out_t = namedtuple(
    "decoder_out_t",
    ["output_tokens", "output_scores", "output_masks", "non_fixed_sym_masks", "attn", "step", "max_step", "history"],
)

def topk_masking(scores, cutoff_len, stochastic=False, temp=1.0):
    """
    scores: [b, n]
    cutoff_len: [b, 1]
    stochastic: bool, whether to add Gumbel noise to the scores before selecting the top-k
    returns:
        mask: [b, n], True if the token is among the cutoff_len lowest scores, False otherwise
    """
    if stochastic:
        gumbel_noise = -torch.log(-torch.log(torch.rand_like(scores) + 1e-8) + 1e-8)
        _scores = scores + temp * gumbel_noise
    else:
        _scores = scores
    sorted_index = _scores.sort(-1)[0]
    cutoff = sorted_index.gather(dim=-1, index=cutoff_len)  # + 1e-10
    # cutoff_len = k -> select k + 1 tokens
    masking = _scores < cutoff
    return masking

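# Illustrative sketch (not part of the original module): topk_masking marks the
# lowest-scoring positions per row. With one row of four scores and
# cutoff_len = 2, the cutoff value is the third-smallest score (0.5), so the
# two lowest positions are selected:
#
#   >>> s = torch.tensor([[0.1, 0.9, 0.3, 0.5]])
#   >>> topk_masking(s, torch.tensor([[2]]))
#   tensor([[ True, False,  True, False]])
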
class DiscreteDiffusionModel(PreTrainedModel):
    config_class = DiscreteDiffusionConfig
    _keys_to_ignore_on_load_missing = ["fake_layer", "length_trm", "length_predictor", "model.lm_head.decoder.weight"]

    def __init__(self, config: DiscreteDiffusionConfig):
        super().__init__(config)
        self.config = config
        self.args = config  # Alias for compatibility with existing code

        # Initialize the backbone from the (dict) backbone_config
        if config.backbone_config:
            backbone_config_obj = AutoConfig.for_model(**config.backbone_config)
            self.model = AutoModelForMaskedLM.from_config(backbone_config_obj)
        else:
            raise ValueError("backbone_config must be provided in config")

        if config.tie_word_embeddings:
            self.model.lm_head.decoder.weight = self.model.roberta.embeddings.word_embeddings.weight

        self.mask_id = config.mask_token_id
        self.bos_id = config.bos_token_id
        self.eos_id = config.eos_token_id
        self.pad_id = config.pad_token_id

        # LoRA
        if config.lora:
            self.add_fake_layer()

        # Length predictor (optional, as in the original code)
        self.length_trm = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=self.config.hidden_size,
                nhead=self.config.num_attention_heads,
                dim_feedforward=self.config.intermediate_size,
                batch_first=True,
            ),
            num_layers=1,
        )
        self.length_predictor = nn.Sequential(
            nn.Linear(self.config.hidden_size, self.config.intermediate_size),
            nn.Tanh(),
            nn.Linear(self.config.intermediate_size, self.config.max_position_embeddings),
        )

    def add_fake_layer(self):
        self.fake_layer = nn.Parameter(torch.zeros((self.config.hidden_size,)))

    def gradient_checkpointing_enable(self):
        self.model.gradient_checkpointing_enable()

    def _tie_weights(self):
        """Tie the weights between the input embeddings and the output embeddings."""
        if self.config.tie_word_embeddings:
            self._tie_or_clone_weights(
                self.model.lm_head.decoder,
                self.model.roberta.embeddings.word_embeddings,
            )

    def _init_weights(self, module):
        """Initialize the weights - called after loading a checkpoint."""
        super()._init_weights(module)
        # Ensure weights are tied after initialization
        self._tie_weights()

    @property
    def _tied_weights_keys(self):
        """Return the keys of tied weights."""
        if self.config.tie_word_embeddings:
            return ["model.lm_head.decoder.weight"]
        return []

    def q_sample_coupled(self, x_0, t1, t2, maskable_mask):
        # Copied from DiscreteDiffusionBase: coupled absorbing-state noising.
        # Sorts the two timesteps so t1 >= t2, then draws nested masks so that
        # x_t2 is always a less-noised version of x_t1.
        assert self.config.diffusion_type == "absorbing", "we only support absorbing diffusion temporarily"
        t1_eq_t2_mask = (t1 == t2)
        t1, t2 = torch.maximum(t1, t2).float(), torch.minimum(t1, t2).float()

        u = torch.rand_like(x_0, dtype=torch.float)
        t1_mask = (u < (t1 / self.config.num_diffusion_timesteps)[:, None]) & maskable_mask
        x_t1 = x_0.masked_fill(t1_mask, self.mask_id)

        u = torch.rand_like(x_0, dtype=torch.float)
        t2_mask = t1_mask & (u > ((t1 - t2) / t1)[:, None])
        u = torch.rand_like(x_0[t1_eq_t2_mask], dtype=torch.float)
        t2_mask[t1_eq_t2_mask] = (u < (t1[t1_eq_t2_mask] / self.config.num_diffusion_timesteps)[:, None]) & (maskable_mask[t1_eq_t2_mask])
        x_t2 = x_0.masked_fill(t2_mask, self.mask_id)

        return {
            "x_t": torch.cat([x_t1, x_t2], dim=0),
            "t": torch.cat([t1, t2]),
            "mask_mask": torch.cat([t1_mask, t2_mask], dim=0),
        }

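    # Hypothetical training-loop sketch for q_sample_coupled (the caller is not
    # part of this module; x_0, partial_mask, and the loss reduction are
    # assumptions). Two timesteps are drawn per sequence and the masked-LM loss
    # is computed only at corrupted positions:
    #
    #   import torch.nn.functional as F
    #   b = x_0.size(0)
    #   t1 = torch.randint(1, model.config.num_diffusion_timesteps + 1, (b,))
    #   t2 = torch.randint(1, model.config.num_diffusion_timesteps + 1, (b,))
    #   noised = model.q_sample_coupled(x_0, t1, t2, maskable_mask=x_0.ne(model.pad_id))
    #   # x_t is the two noised copies stacked, so targets and masks use [2b, n]
    #   logits = model(noised["x_t"], partial_mask, loss_mask=noised["mask_mask"])
    #   loss = F.cross_entropy(logits, x_0.repeat(2, 1)[noised["mask_mask"]])
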
    def initialize_decode_samples(self, tokens, partial_masks, prefix_masks, oracle_length=False, length_beam=1, mbr=1):
        # Copied from DiscreteDiffusionBase.
        if tokens is None:
            raise NotImplementedError
        if not oracle_length:
            inputs_tokens = tokens.masked_fill(~prefix_masks, self.pad_id)
            src_length = inputs_tokens.ne(self.pad_id).sum(dim=-1)
            inputs_tokens = inputs_tokens[:, :src_length.max()]
            length_logits = self.forward_length(inputs_tokens)
            # Cap the output length: at most 3x the source length and at most 100 tokens
            max_allowed_length = torch.min(
                torch.tensor([100]).to(src_length.device),
                (src_length * 3)[:, None]
            )
            length = torch.min(
                torch.min(
                    length_logits.topk(length_beam, dim=-1).indices + 1,
                    max_allowed_length
                ),
                self.config.max_position_embeddings - 2 - src_length[:, None] - 1
            )
            output_tokens = []
            new_partial_masks = []
            for i, token in enumerate(inputs_tokens):
                for b in range(length_beam):
                    for m in range(mbr):
                        # Output sequence: <source> <mask> ... <mask> <eos>
                        seq = torch.cat([
                            token[:src_length[i]],
                            torch.tensor([self.mask_id] * length[i][b] + [self.eos_id]).to(token)
                        ])
                        output_tokens.append(seq)

                        # Corresponding partial mask: True for fixed (source) positions,
                        # False for generated (mask/eos) positions. We assume
                        # partial_masks[i] covers at least the first src_length[i] tokens.
                        p_mask = torch.cat([
                            partial_masks[i][:src_length[i]],
                            torch.tensor([False] * (length[i][b] + 1)).to(partial_masks)
                        ])
                        new_partial_masks.append(p_mask)

            output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
            # Pad partial masks with True (fixed): padded positions then have
            # ~partial_mask == False, so they are never treated as generated.
            # Downstream filters on tokens.ne(self.pad_id) would drop them anyway,
            # but padding with True keeps them inert in every mask expression.
            partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True)

            output_mask = output_tokens.eq(self.mask_id)
            # non_fixed_sym_masks marks every position the decoder may modify
            # (not pad, not bos, not source). _reparam_decoding relies on this
            # mask to compute its cutoff, so it must be set correctly.
            non_fixed_sym_masks = (
                output_tokens.ne(self.pad_id) &
                output_tokens.ne(self.bos_id) &
                ~partial_masks  # not source tokens
            )
        else:
            output_tokens = torch.stack([token for token in tokens for m in range(mbr)])
            partial_masks = torch.stack([mask for mask in partial_masks for m in range(mbr)])
            prefix_masks = torch.stack([mask for mask in prefix_masks for m in range(mbr)])
            output_mask = (
                output_tokens.ne(self.pad_id) &
                output_tokens.ne(self.bos_id) &
                output_tokens.ne(self.eos_id) &
                ~prefix_masks
            )
            output_tokens = output_tokens.masked_fill(output_mask, self.mask_id)
            non_fixed_sym_masks = output_mask.clone()
        output_scores = torch.zeros_like(output_tokens, dtype=torch.float)

        return partial_masks, decoder_out_t(
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_mask,
            non_fixed_sym_masks=non_fixed_sym_masks,
            attn=None,
            step=0,
            max_step=math.inf,
            history=None,
        )

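    # Example canvas produced above (illustrative) for one source "A B C" with
    # a predicted length of 4, writing the mask token as "_":
    #
    #   output_tokens: A  B  C  _  _  _  _  </s>
    #   partial_masks: T  T  T  F  F  F  F  F
    #
    # The source prefix stays fixed while decoding iteratively fills the masks.
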
    def forward_length(self, input_ids):
        attention_mask = input_ids.ne(self.pad_id).int()
        # Backbone features are frozen for length prediction
        with torch.no_grad():
            _feature = self.model.roberta(input_ids, attention_mask=attention_mask)[0]
        feature = self.length_trm(_feature, src_key_padding_mask=(1 - attention_mask).bool())
        length = attention_mask.sum(dim=-1)
        # Mean-pool over non-padding positions, then predict a distribution over lengths
        pooled_feature = feature.masked_fill((attention_mask == 0)[:, :, None], 0).float().sum(1) / length[:, None]
        length_logits = self.length_predictor(pooled_feature.to(feature))
        return length_logits

    def forward(self, prev_output_tokens, partial_mask, attention_mask=None, loss_mask=None, cache=None):
        input_ids = prev_output_tokens
        if attention_mask is None:
            attention_mask = prev_output_tokens.ne(self.pad_id).int()

        embeddings = self.model.roberta.embeddings.word_embeddings(input_ids)

        # Keep the LoRA fake layer in the graph via a zero-valued contribution
        if hasattr(self, "fake_layer") and self.training:
            self.fake_layer.requires_grad = True
            embeddings = embeddings + self.fake_layer * 0

        if self.config.attention_strategy == "prefix_lm":
            # Prefix-LM attention: source (prefix) rows attend only to other
            # source positions; all other rows keep the plain attention_mask.
            ext_partial_mask = partial_mask.float()
            ext_partial_mask = torch.bmm(ext_partial_mask[:, :, None], ext_partial_mask[:, None, :]).int()
            ext_mask = attention_mask[:, None, :].repeat(1, attention_mask.size(-1), 1)
            ext_mask[partial_mask] = ext_partial_mask[partial_mask]
            outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=ext_mask)[0]
        else:
            outputs = self.model.roberta(inputs_embeds=embeddings, attention_mask=attention_mask)[0]

        # Guard against NaNs from the backbone
        if torch.isnan(outputs).any():
            outputs.masked_fill_(outputs.isnan(), 0)

        outputs = outputs[loss_mask] if loss_mask is not None else outputs
        return self.model.lm_head(outputs)

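    # Worked example of the prefix_lm mask built in forward above (illustrative):
    # for one sequence of length 4 with positions 0-1 as source,
    # partial_mask = [T, T, F, F] and a fully non-pad attention_mask give
    #
    #   ext_mask = [[1, 1, 0, 0],   # source rows attend only to the source
    #               [1, 1, 0, 0],
    #               [1, 1, 1, 1],   # generated rows attend to every position
    #               [1, 1, 1, 1]]
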
    def _reparam_decoding(
        self,
        output_tokens,
        output_scores,
        cur_tokens,
        cur_scores,
        decoding_strategy,
        xt_neq_x0,
        non_special_sym_mask,
        t,
        max_step,
        noise,
    ):
        # decoding_strategy, e.g. "reparam-uncond-deterministic-cosine", is
        # parsed as <name>-<condition>-<topk_mode>-<schedule>
        _, condition, topk_mode, schedule = decoding_strategy.split("-")

        # Fraction of non-fixed tokens to keep noised at this step
        if schedule == "linear":
            rate = 1 - t / max_step
        elif schedule == "cosine":
            rate = np.cos(t / max_step * np.pi * 0.5)
        else:
            raise NotImplementedError

        cutoff_len = (
            non_special_sym_mask.sum(1, keepdim=True).type_as(output_scores) * rate
        ).long()
        # Keep special symbols out of the top-k by giving them a high score
        _scores_for_topk = cur_scores.masked_fill(~non_special_sym_mask, 1000.0)

        if topk_mode.startswith("stochastic"):
            noise_scale = float(topk_mode.replace("stochastic", ""))
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=True, temp=noise_scale * rate)
        elif topk_mode == "deterministic":
            lowest_k_mask = topk_masking(_scores_for_topk, cutoff_len, stochastic=False)
        else:
            raise NotImplementedError

        if condition == "cond":
            not_v1_t = (cur_tokens == output_tokens) & (cur_scores < output_scores) & lowest_k_mask
        elif condition == "uncond":
            not_v1_t = lowest_k_mask
        else:
            raise NotImplementedError

        not_v2_t = lowest_k_mask

        # xt_neq_x0 is True where the canvas still differs from the prediction
        # (i.e. the position is currently noised). Low-scoring positions are
        # pushed (back) to noise ...
        masked_to_noise = (~xt_neq_x0 & not_v1_t) | (xt_neq_x0 & not_v2_t)
        if isinstance(noise, torch.Tensor):
            output_tokens.masked_scatter_(masked_to_noise, noise[masked_to_noise])
        elif isinstance(noise, (int, float)):
            output_tokens.masked_fill_(masked_to_noise, noise)
        else:
            raise NotImplementedError("noise should be either a tensor or a scalar")
        output_scores.masked_fill_(masked_to_noise, -math.inf)

        # ... and high-scoring noised positions are committed to the prediction
        masked_to_x0 = xt_neq_x0 & ~not_v2_t
        output_tokens.masked_scatter_(masked_to_x0, cur_tokens[masked_to_x0])
        output_scores.masked_scatter_(masked_to_x0, cur_scores[masked_to_x0])

        new_xt_neq_x0 = (xt_neq_x0 | not_v1_t) & not_v2_t
        return new_xt_neq_x0

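    # Worked example of the re-masking schedule (illustrative): with
    # max_step = 10 and 50 non-fixed positions, the cosine schedule gives
    # rate = cos(t / 10 * pi / 2), so cutoff_len shrinks over the run:
    #
    #   t = 1 -> rate ~ 0.988 -> re-mask 49 positions
    #   t = 5 -> rate ~ 0.707 -> re-mask 35 positions
    #   t = 9 -> rate ~ 0.156 -> re-mask  7 positions
    #
    # Early steps keep almost everything masked; later steps commit nearly
    # all tokens.
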
    def denoise_step(self, decoder_out, partial_masks, temperature=1.0, strategy="reparam-uncond-deterministic-cosine"):
        output_tokens = decoder_out.output_tokens
        output_scores = decoder_out.output_scores
        prev_step, cur_step = decoder_out.step, decoder_out.step + 1
        max_step = decoder_out.max_step

        logits = self.forward(output_tokens, partial_masks)

        # Never predict the mask token itself
        logits[..., self.mask_id] = -math.inf
        scores = torch.log_softmax(logits, dim=-1)

        if strategy == "cmlm":
            # Get the mask; <bos> and <eos> are ignored here since they are
            # never equal to the mask token.
            output_masks = output_tokens.eq(self.mask_id)
            unmask_prob = 1 / (max_step - prev_step)
            # where to unmask
            changes = torch.rand(output_tokens.shape, device=output_tokens.device) < unmask_prob
            # don't unmask somewhere already unmasked
            changes = torch.bitwise_and(changes, output_masks)

            if getattr(self.config, "argmax_decoding", False):
                output_scores, new_tokens = scores.max(-1)
            else:
                new_tokens = dists.Categorical(logits=scores / temperature).sample()
                output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
            output_tokens[changes] = new_tokens[changes]
        elif strategy == "ar":
            # Left-to-right: unmask the left-most masked position in each row
            output_masks = output_tokens.eq(self.mask_id)
            unmask_indices = (output_tokens.ne(self.mask_id) & output_tokens.ne(self.eos_id) & output_tokens.ne(self.pad_id)).sum(dim=-1)
            indices = torch.arange(output_tokens.size(-1)).expand(output_tokens.shape).to(output_masks.device)
            if getattr(self.config, "argmax_decoding", False):
                output_scores, new_tokens = scores.max(-1)
            else:
                new_tokens = dists.Categorical(logits=scores / temperature).sample()
                output_scores = torch.gather(scores, -1, new_tokens.unsqueeze(-1)).squeeze(-1)
            output_tokens[unmask_indices[:, None] == indices] = new_tokens[unmask_indices[:, None] == indices]
        else:
            if getattr(self.config, "argmax_decoding", False):
                cur_scores, cur_tokens = scores.max(-1)
            else:
                cur_tokens = dists.Categorical(logits=scores / temperature).sample()
                cur_scores = torch.gather(scores, -1, cur_tokens.unsqueeze(-1)).squeeze(-1)
            cur_scores = cur_scores.to(output_scores)

            output_masks = self._reparam_decoding(
                output_tokens=output_tokens,
                output_scores=output_scores,
                cur_tokens=cur_tokens,
                cur_scores=cur_scores,
                decoding_strategy=strategy,
                xt_neq_x0=decoder_out.output_masks,
                non_special_sym_mask=decoder_out.non_fixed_sym_masks,
                t=cur_step,
                max_step=max_step,
                noise=self.mask_id,
            )

        history = (
            decoder_out.history + [output_tokens.clone()]
            if decoder_out.history is not None else None
        )

        return decoder_out._replace(
            step=cur_step,
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_masks,
            history=history,
        )

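    # Summary of the branches above (editorial note): "cmlm" unmasks a random
    # subset of masked positions each step, "ar" unmasks one position
    # left-to-right per step, and any other strategy string (e.g.
    # "reparam-uncond-deterministic-cosine") is routed through
    # _reparam_decoding. In every branch, config.argmax_decoding = True swaps
    # Categorical sampling for greedy scores.max(-1) selection.
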
    @torch.no_grad()
    def generate(
        self,
        input_ids,
        attention_mask=None,
        max_iterations=10,
        strategy="reparam-uncond-deterministic-cosine",
        temperature=1.0,
        return_history=False,
        max_length=128,  # fixed generation length hyperparameter (as in LLaDA)
        **kwargs
    ):
        # Prepare inputs
        src_tokens = input_ids

        if attention_mask is None:
            partial_masks = torch.ones_like(src_tokens).bool()
        else:
            partial_masks = attention_mask.bool()

        prefix_masks = partial_masks

        # Initialize the canvas with a fixed length (LLaDA approach):
        # instead of predicting the length, use max_length as a hyperparameter.
        batch_size = src_tokens.size(0)
        src_length = src_tokens.ne(self.pad_id).sum(dim=-1)

        # Create a fully masked response of fixed length
        output_tokens = []
        new_partial_masks = []

        for i in range(batch_size):
            # Format: <source_without_eos> <mask> ... <mask> <eos>
            src_len = src_length[i].item()
            src_seq = src_tokens[i, :src_len]

            # Remove the trailing EOS from the source if present
            if src_seq[-1] == self.eos_id:
                src_seq = src_seq[:-1]
                src_len -= 1

            seq = torch.cat([
                src_seq,
                torch.full((max_length,), self.mask_id, dtype=src_tokens.dtype, device=src_tokens.device),
                torch.tensor([self.eos_id], dtype=src_tokens.dtype, device=src_tokens.device),
            ])
            output_tokens.append(seq)

            # Mask: True for the source (fixed), False for the generated part
            mask = torch.cat([
                torch.ones(src_len, dtype=torch.bool, device=src_tokens.device),
                torch.zeros(max_length + 1, dtype=torch.bool, device=src_tokens.device),  # +1 for eos
            ])
            new_partial_masks.append(mask)

        output_tokens = pad_sequence(output_tokens, batch_first=True, padding_value=self.pad_id)
        partial_masks = pad_sequence(new_partial_masks, batch_first=True, padding_value=True)

        # Create masks for decoding
        output_mask = output_tokens.eq(self.mask_id)
        non_fixed_sym_masks = (
            output_tokens.ne(self.pad_id) &
            output_tokens.ne(self.bos_id) &
            ~partial_masks  # not source tokens
        )

        output_scores = torch.zeros_like(output_tokens, dtype=torch.float)

        prev_decoder_out = decoder_out_t(
            output_tokens=output_tokens,
            output_scores=output_scores,
            output_masks=output_mask,
            non_fixed_sym_masks=non_fixed_sym_masks,
            attn=None,
            step=0,
            max_step=max_iterations,
            history=None,
        )

        if return_history:
            prev_decoder_out = prev_decoder_out._replace(history=[])

        for step in range(max_iterations):
            prev_decoder_out = self.denoise_step(prev_decoder_out, partial_masks, temperature=temperature, strategy=strategy)

        # Finalize: discard tokens after EOS (LLaDA approach)
        def finalized_hypos(tokens, scores, partial_mask, history=None):
            # Keep the uncut mask for history entries: history tokens are
            # full-length, so cutting the mask at the final tokens' EOS would
            # cause a shape mismatch in the recursive calls below.
            full_partial_mask = partial_mask

            # First, cut at the first EOS
            eos_positions = (tokens == self.eos_id).nonzero(as_tuple=True)[0]
            if len(eos_positions) > 0:
                first_eos = eos_positions[0].item()
                tokens = tokens[:first_eos]  # exclude the EOS itself
                if scores is not None:
                    scores = scores[:first_eos]
                partial_mask = partial_mask[:first_eos]

            # Then keep only generated tokens (not source, not special)
            cutoff = (
                tokens.ne(self.pad_id) &
                tokens.ne(self.bos_id) &
                tokens.ne(self.eos_id) &
                (~partial_mask)  # partial_mask is False for generated tokens
            )
            tokens = tokens[cutoff]
            if scores is None:
                score = None
            else:
                scores = scores[cutoff]
                score = scores.mean().item() if len(scores) > 0 else 0.0
            ret_dict = {
                "tokens": tokens,
                "positional_scores": scores,
                "score": score,
                "alignment": None,
            }
            if history is not None:
                ret_dict["history"] = [
                    finalized_hypos(history_tokens, None, full_partial_mask, history=None)
                    for history_tokens in history
                ]
            return ret_dict

        def score_select(hyps):
            index = np.argmax([hyp["score"] for hyp in hyps])
            return hyps[index]

        output_tokens, output_scores = prev_decoder_out.output_tokens, prev_decoder_out.output_scores

        # Handle history if needed
        if return_history and prev_decoder_out.history is not None:
            full_history = prev_decoder_out.history
            histories = [[full_history[j][i] for j in range(max_iterations)] for i in range(output_tokens.size(0))]
            hyps = []
            for tokens, scores, partial_mask, history in zip(output_tokens, output_scores, partial_masks, histories):
                hyps.append(finalized_hypos(tokens, scores, partial_mask, history))
        else:
            hyps = [
                finalized_hypos(tokens, scores, partial_mask, None)
                for tokens, scores, partial_mask in zip(output_tokens, output_scores, partial_masks)
            ]

        repetition = kwargs.get("mbr", 1) * kwargs.get("length_beam", 1)
        if repetition > 1:
            hyps = [score_select(hyps[i:i + repetition]) for i in range(0, len(hyps), repetition)]

        # The original model.generate returned just tokens, so return the
        # finalized token tensor.
        finalized = pad_sequence([h["tokens"] for h in hyps], batch_first=True, padding_value=self.pad_id)
        return finalized
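
A minimal usage sketch of the argmax_decoding path added in this commit (the repo id, prompt, and auto_map registration are assumptions, not part of the commit):

    from transformers import AutoModel, AutoTokenizer

    # hypothetical repo id; assumes the config maps AutoModel to DiscreteDiffusionModel
    tokenizer = AutoTokenizer.from_pretrained("myduy/dlm-checkpoint")
    model = AutoModel.from_pretrained("myduy/dlm-checkpoint", trust_remote_code=True).eval()
    model.config.argmax_decoding = True  # greedy unmasking instead of categorical sampling

    inputs = tokenizer("a placeholder prompt", return_tensors="pt")
    tokens = model.generate(
        inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_iterations=10,
        strategy="reparam-uncond-deterministic-cosine",
        max_length=64,
    )
    print(tokenizer.batch_decode(tokens, skip_special_tokens=True))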