Bailan-Alex commited on
Commit
2f3e169
·
verified ·
1 Parent(s): c906e52

Upload folder using huggingface_hub

Browse files
README.md CHANGED
@@ -1,12 +1,6 @@
1
  ---
2
- title: D2F Eval
3
- emoji: 🏃
4
- colorFrom: gray
5
- colorTo: gray
6
  sdk: gradio
7
- sdk_version: 5.49.1
8
- app_file: app.py
9
- pinned: false
10
  ---
11
-
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
1
  ---
2
+ title: D2F-eval
3
+ app_file: generate_llada_demo_block.py
 
 
4
  sdk: gradio
5
+ sdk_version: 5.49.0
 
 
6
  ---
 
 
eval_dream.py ADDED
@@ -0,0 +1,1155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import gc
3
+ import time
4
+ import json
5
+ from datetime import timedelta
6
+ from typing import List, Optional, Tuple, Type, TypeVar, Union
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torch.distributions as dists
10
+ import transformers
11
+ from accelerate import (
12
+ Accelerator,
13
+ InitProcessGroupKwargs,
14
+ )
15
+ from datasets import Dataset
16
+ from packaging import version
17
+ from tqdm import tqdm
18
+ from peft import PeftConfig, PeftModel
19
+ import numpy as np
20
+
21
+ from lm_eval import utils
22
+ from lm_eval.api.instance import Instance
23
+ from lm_eval.api.model import LM
24
+ from lm_eval.api.registry import register_model
25
+ from lm_eval.models.utils import get_dtype
26
+ from lm_eval.__main__ import cli_evaluate
27
+
28
+ eval_logger = logging.getLogger(__name__)
29
+ T = TypeVar("T", bound="LM")
30
+ import random
31
def set_seed(seed):
    """Seed every RNG used here (torch, random, numpy) and force deterministic cuDNN."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)
    # Trade cuDNN autotuning for reproducibility.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
38
+
39
def shift_logits(logits):
    """Shift logits one position to the right along the sequence axis.

    Output position t holds the logits originally at t-1; position 0 is
    filled with the constant 1.0 (no real prediction exists for it).
    """
    head = torch.ones_like(logits[:, :1, :])
    tail = logits[:, :-1, :]
    return torch.cat([head, tail], dim=1)
44
+
45
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    """
    Creates a complete attention mask for the entire sequence with block-based causal attention.

    Prompt rows attend only to the prompt. Each subsequent block of
    `block_size` rows attends to the prompt, to every earlier block, and to
    itself (full attention within the block) — i.e. to every column strictly
    before its own block end. Disallowed positions hold -inf, allowed hold 0.

    Args:
        prompt_length: Length of the prompt (first irregular block)
        max_length: Maximum total sequence length
        block_size: Size of each regular block
        device: Device to create tensor on
        dtype: Data type for the attention mask

    Returns:
        attention_mask: Tensor of shape [1, 1, max_length, max_length]
    """
    # Use the provided dtype or default to bfloat16
    if dtype is None:
        dtype = torch.bfloat16

    # Initialize mask with -inf (no attention)
    attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)

    # Block 0: Prompt (can see itself)
    attention_mask[:, :, :prompt_length, :prompt_length] = 0

    # Calculate the number of regular blocks after prompt
    remaining_length = max_length - prompt_length
    num_blocks = (remaining_length + block_size - 1) // block_size

    # Each regular block sees prompt + all earlier blocks + itself, which is
    # exactly the column range [0, block_end). One slice replaces the original
    # O(num_blocks^2) inner loop with an identical result.
    for b in range(num_blocks):
        block_start = prompt_length + b * block_size
        block_end = min(prompt_length + (b + 1) * block_size, max_length)
        attention_mask[:, :, block_start:block_end, :block_end] = 0

    return attention_mask
91
+
92
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    """
    Extract the relevant portion of attention mask for current forward pass.

    Args:
        full_mask: Complete attention mask [1, 1, max_length, max_length]
        start_pos: Starting position in the full sequence
        input_length: Length of current input sequence
        cache_length: Length of cached sequence

    Returns:
        attention_mask: Extracted mask [1, 1, input_length, cache_length + input_length]
    """
    end_pos = start_pos + input_length

    # Rows: the positions currently being fed through the model.
    rows = full_mask[:, :, start_pos:end_pos, :]

    # Columns: first the cached prefix, then the current input's own span.
    cache_cols = rows[:, :, :, :cache_length]
    input_cols = rows[:, :, :, start_pos:end_pos]

    return torch.cat([cache_cols, input_cols], dim=-1)
120
+
121
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None, dtype=None):
    """Build a per-sample block-causal float attention mask.

    Every row can attend to its sample's prompt; generated positions are split
    into `block_size` blocks where each block attends to all earlier blocks and
    fully to itself. Disallowed positions hold -inf, allowed hold 0.

    Args:
        input_ids: [B, seq_len] token ids (only the shape is used).
        prompt_length: per-sample prompt lengths, indexable as prompt_length[i].
        block_size: size of each generation block.
        device / dtype: placement for the returned mask (dtype defaults to float32).

    Returns:
        attn_mask: Tensor of shape [B, 1, seq_len, seq_len].
    """
    B, seq_len = input_ids.shape
    # Use the provided dtype or default to float32
    if dtype is None:
        dtype = torch.float32
    # Initialize to all -inf
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=dtype, device=device)

    for i in range(B):
        p = prompt_length[i]
        # 1. Prompt part: each token can attend to the entire prompt
        attn_mask[i, :, :, :p] = 0.0

        # 2. Block division: divide into blocks starting from prompt_length.
        # A block sees the prompt, every earlier block, and itself — i.e. every
        # column before its own block end. One slice replaces the original
        # O(num_blocks^2) inner loop with an identical result.
        num_blocks = (seq_len - p + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = p + b * block_size
            block_end = min(block_start + block_size, seq_len)
            attn_mask[i, :, block_start:block_end, :block_end] = 0.0

    return attn_mask
151
+
152
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering: tokens outside the smallest probability mass
    exceeding `top_p` are masked to the dtype's minimum. Returns a new tensor."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)

    # Drop everything past the nucleus, shifted right by one so the first
    # token that crosses the threshold is still kept.
    drop_sorted = cum_probs > top_p
    drop_sorted = torch.roll(drop_sorted, shifts=1, dims=-1)
    drop_sorted[..., 0] = False

    # Scatter the sorted-order decisions back to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    drop.scatter_(-1, sorted_indices, drop_sorted)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)
164
+
165
def top_k_logits(logits, top_k=None):
    """Keep only the `top_k` highest logits; everything below the k-th value is
    masked to the dtype's minimum. Returns a new tensor."""
    top_k = min(top_k, logits.size(-1))  # Safety check
    kth_value = torch.topk(logits, top_k)[0][..., -1, None]
    below_kth = logits < kth_value
    return logits.masked_fill(below_kth, torch.finfo(logits.dtype).min)
171
+
172
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample token ids from logits and compute per-position confidence scores.

    Args:
        logits: [..., vocab] unnormalized scores (last dim is the vocabulary).
        temperature: 0 -> greedy argmax; >0 -> scaled Categorical sampling.
        top_p / top_k: optional nucleus / top-k filtering applied before softmax.
        margin_confidence: if True, confidence = p(top1) - p(top2).
        neg_entropy: if True, confidence = sum(p * log p) (negative entropy).

    Returns:
        (confidence, x0, initial_confidence): sampled ids `x0`, the probability
        of each chosen token `initial_confidence`, and `confidence` per the
        selected strategy (defaults to a copy of `initial_confidence`).
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fall back to greedy if sampling fails (e.g. degenerate probs).
            # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit propagate.
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)

    # Save initial confidence
    confidence = initial_confidence.clone()

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Confidence = top1 - top2; `[..., i]` generalizes beyond 2-D input
        # (behavior unchanged for the original [batch, vocab] case).
        confidence = sorted_probs[..., 0] - sorted_probs[..., 1]

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0, initial_confidence
207
+
208
+ @register_model("dream_lora")
209
+ class DreamLoRA(LM):
210
    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        lora_path: str,
        batch_size: Optional[Union[int, str]] = 1,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        max_new_tokens: Optional[int] = 128,
        max_length: Optional[int] = 2048,  # Updated to match example code
        add_bos_token: Optional[bool] = False,
        nll_type: Optional[str] = "mc",
        log_type: Optional[str] = "ftb",
        mc_num: Optional[int] = 128,
        classifier_free_guidance: Optional[float] = 1.0,
        sampling_eps: Optional[float] = 1e-3,
        diffusion_steps: Optional[int] = 128,
        trust_remote_code: Optional[bool] = True,
        parallelize: Optional[bool] = False,
        autogptq: Optional[Union[bool, str]] = False,
        temperature: Optional[float] = 0.2,  # Updated default
        top_p: Optional[float] = None,  # Updated default
        top_k: Optional[float] = None,
        alg: Optional[str] = "entropy",
        alg_temp: Optional[float] = 0.0,
        escape_until: Optional[bool] = False,
        block_size: Optional[int] = 4,  # Updated to match example code
        mask_token_id: Optional[int] = 151666,  # Added mask_token_id parameter
        block_add_threshold: Optional[float] = 0.5,  # Added block_add_threshold parameter
        decoded_token_threshold: Optional[float] = 0.9,  # annotation fixed: default is a float ratio
        skip_threshold: Optional[float] = 1.0,  # Added skip_threshold parameter
        sampling_strategy: Optional[str] = "default",  # Added sampling_strategy parameter
        save_dir: Optional[str] = None,
        **kwargs,
    ) -> None:
        """lm-eval `LM` wrapper for a Dream diffusion model with a LoRA adapter.

        Sets up device placement (optionally via `accelerate`), loads the base
        model + LoRA through `_create_model_and_tokenizer`, and records the
        generation / loglikelihood hyperparameters used by the other methods.
        `pretrained` must be a model path/name string (asserted below);
        `lora_path` is a PEFT adapter location.
        """
        super().__init__()

        # prepare for parallelism
        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        gpus = torch.cuda.device_count()
        # Very long timeout so slow multi-node startups don't kill the process group.
        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self.accelerator = accelerator

        if "npu" in accelerator.device.type:
            gpus = torch.npu.device_count()

        # using one process with no model parallelism
        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(gpus)]
                + ["mps", "mps:0"]
                + [f"npu:{i}" for i in range(gpus)]
            )
            if device and device in device_list:
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
                if device in ("mps", "mps:0") and version.parse(
                    torch.__version__
                ) < version.parse("2.1"):
                    raise RuntimeError(
                        f"mps requires torch >= 2.1. You have {torch.__version__}"
                    )
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:  # Parallelism managed by accelerate
            if device != "cuda":
                eval_logger.info(
                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = (
                self.accelerator.device
                if hasattr(self, "accelerator")
                else torch.device(device)
            )

        self.batch_size_per_gpu = batch_size
        if isinstance(batch_size, str):
            self.batch_size_per_gpu = int(batch_size)

        # Save LoRA path and block_size
        self.lora_path = lora_path
        self.block_size = block_size
        self.block_add_threshold = block_add_threshold  # decode-progress ratio that triggers appending a new block
        self.skip_threshold = skip_threshold  # confidence needed to commit a token early
        self.sampling_strategy = sampling_strategy  # "default" | "neg_entropy" | "margin_confidence" (see sample_tokens)
        self.decoded_token_threshold = decoded_token_threshold  # decoded ratio marking a block "complete"
        self.save_dir = save_dir

        # Add metric tracking
        self.total_forward_passes = 0
        self.total_generated_tokens = 0
        self.total_prompts = 0
        # Add time and token statistics
        self.total_generation_time = 0.0
        self.total_block_tokens = 0  # Number of blocks * block_size
        self.total_actual_tokens = 0  # Actual generated tokens (excluding EOS)
        self.total_non_eos_tokens = 0  # Total non-EOS tokens in the entire sequence
        self.all_generation_times = []
        self.all_block_tokens = []
        self.all_actual_tokens = []
        self.all_non_eos_tokens = []

        # Save target_dtype for later use
        self.target_dtype = get_dtype(dtype)

        self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # place model onto device requested manually,
                    # if not using HF Accelerate or device_map
                    # or any other option that preloads model onto device
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                        )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    if self.accelerator.is_local_main_process:
                        eval_logger.info(
                            f"Using {gpus} devices with data parallelism"
                        )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
                # NOTE(review): when gpus <= 1, _rank/_world_size are presumably the
                # defaults set by the LM base class __init__ — confirm against lm-eval.
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        self.max_length = max_length
        self.add_bos_token = add_bos_token
        # generation params
        self.max_new_tokens = max_new_tokens
        self.diffusion_steps = diffusion_steps
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alg = alg
        self.alg_temp = alg_temp
        self.escape_until = escape_until
        self.block_size = block_size
        self.mask_token_id = mask_token_id

        # loglikelihood params
        self.nll_type = nll_type
        self.log_type = log_type
        self.mc_num = mc_num
        self.classifier_free_guidance = classifier_free_guidance
        self.sampling_eps = sampling_eps
399
+
400
    @property
    def batch_size(self):
        """Per-device batch size, as required by the lm-eval `LM` interface."""
        return self.batch_size_per_gpu
403
+
404
    @property
    def device(self):
        """torch.device this instance places tensors/model on (set in __init__)."""
        return self._device
407
+
408
    @property
    def rank(self):
        """Process rank for data-parallel evaluation (0 when single-process)."""
        return self._rank
411
+
412
    @property
    def world_size(self):
        """Number of data-parallel processes (1 when single-process)."""
        return self._world_size
415
+
416
    def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
        """Load the Dream base model, wrap it with the LoRA adapter from
        `self.lora_path`, cast/move it, and load the matching tokenizer."""
        # Get correct data type
        from model_cache.dream.model_dream import DreamModel
        from model_cache.dream.configuration_dream import DreamConfig
        target_dtype = get_dtype(dtype)

        # Load base model, using DreamModel and DreamConfig
        model_config = DreamConfig.from_pretrained(pretrained)
        self.model = DreamModel.from_pretrained(
            pretrained,
            config=model_config,
            torch_dtype=target_dtype,
            trust_remote_code=False,
        ).eval()

        # Load LoRA config and model
        # NOTE(review): `config` is otherwise unused; loading it presumably serves
        # to validate lora_path early — confirm before removing.
        config = PeftConfig.from_pretrained(self.lora_path)
        self.model = PeftModel.from_pretrained(self.model, self.lora_path)

        # Only convert data type if target_dtype is not None and not "auto"
        if target_dtype is not None and target_dtype != "auto":
            self.model = self.model.to(target_dtype)

        # Move to specified device
        self.model = self.model.to(self.device)

        self.tokenizer = transformers.AutoTokenizer.from_pretrained(
            pretrained, trust_remote_code=trust_remote_code
        )
445
+
446
    def tok_decode(self, tokens, skip_special_tokens=True):
        """Decode token ids to text via the underlying tokenizer."""
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
448
+
449
    def tok_encode(self, text, add_special_tokens=True):
        """Tokenize `text` and return the `input_ids` tensor ("pt" tensors)."""
        return self.tokenizer(
            text, return_tensors="pt", add_special_tokens=add_special_tokens
        ).input_ids
453
+
454
+ @classmethod
455
+ def create_from_arg_string(
456
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
457
+ ) -> T:
458
+ """
459
+ Creates an instance of the LM class using the given argument string and additional config.
460
+
461
+ Parameters:
462
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
463
+ - additional_config: Optional dictionary containing additional configuration parameters.
464
+
465
+ Returns:
466
+ - Instance of the LM class.
467
+ """
468
+ additional_config = {} if additional_config is None else additional_config
469
+ args = utils.simple_parse_args_string(arg_string)
470
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
471
+ return cls(**args, **args2)
472
+
473
+ def apply_chat_template(
474
+ self, chat_history, add_generation_prompt: bool = True
475
+ ) -> str:
476
+ """
477
+ Method to apply a chat template to a list of chat history between user and model.
478
+ """
479
+ chat_templated = self.tokenizer.apply_chat_template(
480
+ chat_history,
481
+ tokenize=False,
482
+ add_generation_prompt=add_generation_prompt,
483
+ continue_final_message=not add_generation_prompt,
484
+ )
485
+
486
+ return chat_templated
487
+
488
    @property
    def tokenizer_name(self) -> str:
        """Tokenizer identifier with "/" replaced so it is filesystem-safe."""
        return self.tokenizer.name_or_path.replace("/", "__")
491
+
492
+ def _count_non_eos_tokens_before_truncation(self, generated_sequence, prompt_length):
493
+ """
494
+ Unified token counting function: counts non-EOS tokens in the generated sequence (before truncation).
495
+ """
496
+ # Get the generated part (excluding the prompt)
497
+ generated_tokens = generated_sequence[prompt_length:]
498
+ # Count non-EOS tokens
499
+ eos_token_id = self.tokenizer.eos_token_id
500
+ if eos_token_id is not None:
501
+ # If it's a tensor, convert to list for counting
502
+ if hasattr(generated_tokens, 'tolist'):
503
+ generated_tokens_list = generated_tokens.tolist()
504
+ else:
505
+ generated_tokens_list = generated_tokens
506
+ non_eos_count = sum(1 for token in generated_tokens_list if token != eos_token_id)
507
+ else:
508
+ non_eos_count = len(generated_tokens)
509
+ return non_eos_count
510
+
511
    def _generate_batch(self, prompts: List[str]) -> List[str]:
        """Generate one response per prompt (sequentially) via block generation.

        Prompts are optionally prefixed with BOS, tokenized, left-truncated so
        prompt + `max_new_tokens` fits in `max_length`, then passed to
        `_generate_block_single`, which returns EOS-truncated text.
        """
        if self.add_bos_token:
            prompts = [self.tokenizer.bos_token + p for p in prompts]

        responses = []

        # Generate for each prompt individually (block generation usually processes one by one)
        for i, prompt in enumerate(prompts):
            # tokenize
            prompt_ids = self.tokenizer.encode(prompt)
            prompt_tensor = torch.tensor([prompt_ids], device=self.device, dtype=torch.long)

            # Left-truncate over-long prompts so generation still has room.
            if len(prompt_ids) > self.max_length - self.max_new_tokens:
                eval_logger.warning(f"Prompt length {len(prompt_ids)} is larger than {self.max_length-self.max_new_tokens}, cutoff on the left side")
                prompt_tensor = prompt_tensor[:, -(self.max_length-self.max_new_tokens):]

            # Use generate_block_single method to generate, returns EOS-truncated response text
            response = self._generate_block_single(prompt_tensor)
            responses.append(response)

        return responses
532
+
533
+ def _generate_block_single(self, prompt):
534
+ """
535
+ Generates a response for a single prompt using parallel block generation, based on KV cache,
536
+ and using pre-generated attention masks.
537
+ Returns: EOS-truncated response text.
538
+ """
539
+ self.model.eval()
540
+
541
+ mask_id = self.mask_token_id
542
+ block_size = self.block_size
543
+ block_add_threshold = self.block_add_threshold
544
+ skip_threshold = self.skip_threshold
545
+ decoded_token_threshold = self.decoded_token_threshold
546
+
547
+ # Pre-generate full attention mask, using model's data type
548
+ prompt_length = prompt.shape[1]
549
+ full_attention_mask = create_full_block_attention_mask(
550
+ prompt_length=prompt_length,
551
+ max_length=self.max_length,
552
+ block_size=block_size,
553
+ device=self.device,
554
+ dtype=self.target_dtype if self.target_dtype is not None and self.target_dtype != "auto" else torch.bfloat16
555
+ )
556
+
557
+ with torch.inference_mode():
558
+ # Initialization
559
+ x_t = prompt.to(self.device)
560
+
561
+ # Track block states - state can be: 'active', 'to_cache', 'in_cache'
562
+ # Added 'is_complete' field to indicate whether it's a complete state (True) or incomplete (False)
563
+ block_states = {
564
+ 0: {
565
+ 'start_pos': 0,
566
+ 'end_pos': prompt.shape[1],
567
+ 'mask_count': 0,
568
+ 'total_masks': prompt.shape[1],
569
+ 'state': 'to_cache', # prompt ready for caching immediately
570
+ 'is_complete': True, # prompt is always in a complete state
571
+ },
572
+ }
573
+
574
+ # Initialize cache
575
+ past_key_values = None
576
+ last_logits = None
577
+
578
+ current_blocks = 0 # Number of active blocks
579
+ step = 0
580
+ eos_detected = False # EOS detection flag
581
+
582
+ while current_blocks >= 0:
583
+ step += 1
584
+
585
+ # Check if a new block needs to be added
586
+ if len(block_states)-1 < (self.max_new_tokens // block_size) and not eos_detected:
587
+ last_block_id = len(block_states) - 1
588
+ current_progress = (block_states[last_block_id]['total_masks'] -
589
+ block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks']
590
+ if current_progress >= block_add_threshold:
591
+ # Add new block - defaults to incomplete state
592
+ new_block_id = len(block_states)
593
+ new_start_pos = x_t.shape[1]
594
+ x_t = torch.cat([x_t, torch.tensor([[mask_id] * block_size]).to(self.device)], dim=1)
595
+
596
+ block_states[new_block_id] = {
597
+ 'start_pos': new_start_pos,
598
+ 'end_pos': new_start_pos + block_size,
599
+ 'mask_count': block_size,
600
+ 'total_masks': block_size,
601
+ 'state': 'active',
602
+ 'is_complete': False, # New block defaults to incomplete state
603
+ }
604
+ current_blocks += 1
605
+
606
+ # At the beginning of each loop, update block completion states
607
+ self._update_block_completion_states(block_states, decoded_token_threshold)
608
+ # Check if there are still mask tokens
609
+ mask_index = (x_t == mask_id)
610
+ if mask_index.sum() == 0 and current_blocks == 0:
611
+ break
612
+
613
+ # Determine which blocks need to be added to cache
614
+ blocks_to_cache = [bid for bid, state in block_states.items()
615
+ if state['state'] == 'to_cache']
616
+
617
+ # Determine the part to process
618
+ cache_length = 0 if past_key_values is None else past_key_values.get_seq_length()
619
+
620
+ # Determine content to add to cache
621
+ update_kvcache = 0
622
+ if blocks_to_cache:
623
+ # Find the earliest block that needs to be cached
624
+ earliest_block_id = min(blocks_to_cache)
625
+ earliest_pos = block_states[earliest_block_id]['start_pos']
626
+
627
+ # Find the latest block that needs to be cached
628
+ latest_block_id = max(blocks_to_cache)
629
+ latest_pos = block_states[latest_block_id]['end_pos']
630
+
631
+ # Update cache for all blocks within this range
632
+ update_kvcache = latest_pos - earliest_pos
633
+
634
+ # Create input sequence for forward pass
635
+ process_start_pos = cache_length
636
+
637
+ if update_kvcache > 0:
638
+ # Need to update cache - use completed blocks
639
+ earliest_block_to_cache = min(blocks_to_cache)
640
+ input_seq = x_t[:, block_states[earliest_block_to_cache]['start_pos']:]
641
+ process_start_pos = block_states[earliest_block_to_cache]['start_pos']
642
+ else:
643
+ # Only process active blocks
644
+ active_blocks = [bid for bid in block_states.keys() if block_states[bid]['state'] == 'active']
645
+ if active_blocks:
646
+ # Get all active blocks after the cache
647
+ earliest_active_after_cache = float('inf')
648
+ for bid in active_blocks:
649
+ if block_states[bid]['start_pos'] >= cache_length:
650
+ earliest_active_after_cache = min(earliest_active_after_cache, block_states[bid]['start_pos'])
651
+
652
+ if earliest_active_after_cache < float('inf'):
653
+ input_seq = x_t[:, earliest_active_after_cache:]
654
+ process_start_pos = earliest_active_after_cache
655
+ else:
656
+ # No active blocks after cache, this shouldn't happen
657
+ input_seq = x_t[:, cache_length:]
658
+ # If cache length is already equal to or exceeds sequence length, exit
659
+ if cache_length >= x_t.shape[1]:
660
+ print(f"Cache length ({cache_length}) >= sequence length ({x_t.shape[1]}) at step {step}. Exiting generation loop.")
661
+ raise Exception("Cache length >= sequence length")
662
+ else:
663
+ # No active blocks, but might have blocks to cache in next iteration
664
+ break
665
+
666
+ # Check if input_seq is empty
667
+ if input_seq.shape[1] == 0:
668
+ print(f"Warning: input_seq is empty at step {step}. Breaking generation loop.")
669
+ raise Exception("input_seq is empty")
670
+
671
+ # Extract attention mask for current input from the pre-generated full mask
672
+ input_length = input_seq.shape[1]
673
+ attention_mask = extract_attention_mask(
674
+ full_mask=full_attention_mask,
675
+ start_pos=process_start_pos,
676
+ input_length=input_length,
677
+ cache_length=cache_length
678
+ )
679
+
680
+ # Forward pass
681
+ outputs = self.model(
682
+ input_seq,
683
+ attention_mask=attention_mask,
684
+ past_key_values=past_key_values,
685
+ use_cache=True,
686
+ update_kvcache=update_kvcache,
687
+ )
688
+
689
+ # If needed, update cache
690
+ if update_kvcache > 0:
691
+ # Store logits of the last position for next token prediction
692
+ cache_end_idx = update_kvcache - 1
693
+ last_logits = outputs.logits[:, cache_end_idx, :].unsqueeze(1)
694
+
695
+ # Update cache
696
+ past_key_values = outputs.past_key_values
697
+
698
+ # Mark blocks as cached
699
+ for block_id in blocks_to_cache:
700
+ block_states[block_id]['state'] = 'in_cache'
701
+
702
+ # Get correctly shifted logits for prediction
703
+ logits = self._shift_logits(outputs.logits, last_logit=last_logits)
704
+
705
+ # Process mask tokens for each active block
706
+ blocks_to_deactivate = []
707
+
708
+ for block_id in sorted(block_states.keys()):
709
+ if block_states[block_id]['state'] != 'active':
710
+ continue
711
+
712
+ # Get mask positions for this block
713
+ block_start = block_states[block_id]['start_pos']
714
+ block_end = block_states[block_id]['end_pos']
715
+ block_mask_index = mask_index.clone()
716
+ block_mask_index[:, :block_start] = False
717
+ block_mask_index[:, block_end:] = False
718
+
719
+ # If the current block has no masks, skip it
720
+ if block_mask_index.sum() == 0:
721
+ blocks_to_deactivate.append(block_id)
722
+ continue
723
+
724
+ # Calculate relative position for logits
725
+ logit_offset = block_start - process_start_pos
726
+ block_rel_positions = torch.where(block_mask_index[0, block_start:block_end])[0]
727
+
728
+ if block_rel_positions.size(0) > 0:
729
+ # Get logits for masked positions
730
+ block_mask_logits = logits[:, logit_offset + block_rel_positions, :]
731
+
732
+ # Sample tokens
733
+ confidence, x0, initial_confidence = sample_tokens(
734
+ block_mask_logits.squeeze(0),
735
+ self.temperature,
736
+ top_p=self.top_p,
737
+ top_k=self.top_k,
738
+ neg_entropy=(self.sampling_strategy == "neg_entropy"),
739
+ margin_confidence=(self.sampling_strategy == "margin_confidence")
740
+ )
741
+
742
+ # Apply different sampling strategies based on the block's complete/incomplete state
743
+ is_complete = block_states[block_id]['is_complete']
744
+
745
+ if is_complete:
746
+ # Complete state: apply confidence threshold, if no high confidence, select highest
747
+ high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]
748
+
749
+ if len(high_conf_indices) == 0:
750
+ number_transfer_tokens = 1
751
+ _, transfer_index = torch.topk(confidence, number_transfer_tokens)
752
+ else:
753
+ transfer_index = torch.tensor([], device=self.device, dtype=torch.long)
754
+
755
+ # Merge indices
756
+ all_indices = torch.unique(torch.cat([transfer_index, high_conf_indices]))
757
+ else:
758
+ # Incomplete state: only apply confidence threshold, if none exceed, select no tokens
759
+ high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]
760
+ all_indices = high_conf_indices
761
+
762
+ # Update tokens
763
+ if len(all_indices) > 0:
764
+ x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_id
765
+ x0_[all_indices] = x0[all_indices].clone()
766
+
767
+ # Map indices back to original positions
768
+ for i, idx in enumerate(all_indices):
769
+ abs_pos = block_start + block_rel_positions[idx]
770
+ x_t[0, abs_pos] = x0_[idx]
771
+
772
+ # Update block state
773
+ block_states[block_id]['mask_count'] -= len(all_indices)
774
+
775
+ # Check EOS token
776
+ eos_token_id = self.tokenizer.eos_token_id
777
+ if eos_token_id is not None:
778
+ for idx in all_indices:
779
+ if x0[idx].item() == eos_token_id:
780
+ eos_detected = True
781
+ break
782
+
783
+ # If no masks remain in this block, deactivate it
784
+ mask_index = (x_t == mask_id)
785
+ block_mask_index = mask_index.clone()
786
+ block_mask_index[:, :block_start] = False
787
+ block_mask_index[:, block_end:] = False
788
+ if block_mask_index.sum() == 0:
789
+ blocks_to_deactivate.append(block_id)
790
+ continue
791
+
792
+ # Deactivate completed blocks and mark them for caching in the next iteration
793
+ for block_id in blocks_to_deactivate:
794
+ if block_states[block_id]['state'] == 'active':
795
+ # Check if all preceding blocks are already non-active
796
+ can_deactivate = True
797
+ for prev_block_id in range(block_id):
798
+ if prev_block_id in block_states and block_states[prev_block_id]['state'] == 'active':
799
+ can_deactivate = False
800
+ break
801
+
802
+ # Only mark the current block as 'to_cache' if all preceding blocks are non-active
803
+ if can_deactivate:
804
+ block_states[block_id]['state'] = 'to_cache'
805
+ current_blocks -= 1
806
+ # If there are active blocks before, keep current block as active (do nothing)
807
+
808
+ # Safety check
809
+ if step > 10000:
810
+ print(f"WARNING: Hit safety check at step {step}. Exiting generation loop.")
811
+ break
812
+
813
+ # First, calculate non-EOS tokens for the full generated sequence
814
+ generated_sequence = x_t[0, prompt.shape[1]:].tolist()
815
+ non_eos_tokens = self._count_non_eos_tokens_before_truncation(
816
+ x_t[0].tolist(), prompt.shape[1]
817
+ )
818
+
819
+ # Accumulate to total tokens
820
+ if not hasattr(self, 'total_generated_tokens'):
821
+ self.total_generated_tokens = 0
822
+ self.total_generated_tokens += non_eos_tokens
823
+
824
+ # Generate EOS-truncated response text (consistent with other file logic)
825
+ response = self.tokenizer.decode(generated_sequence).split(self.tokenizer.eos_token)[0]
826
+
827
+ return response
828
+
829
+ def _update_block_completion_states(self, block_states, decoded_token_threshold):
830
+ """
831
+ Updates the complete/incomplete state of blocks.
832
+ Iterates through blocks from front to back. If a block's decoded token count
833
+ is greater than the threshold, the next block to its right (if it exists)
834
+ is set to a complete state.
835
+ """
836
+ for block_id in sorted(block_states.keys()):
837
+ # if block_id == 0: # Skip prompt block
838
+ # continue
839
+
840
+ # Calculate decoded tokens for the current block
841
+ decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
842
+ decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
843
+ # If the current block's decoded token count is greater than the threshold,
844
+ # then the next block (if it exists) is set to a complete state.
845
+ # print("decode_ratio",decode_ratio)
846
+ # print("decoded_token_threshold",decoded_token_threshold)
847
+ if decode_ratio >= decoded_token_threshold:
848
+ next_block_id = block_id + 1
849
+ if next_block_id in block_states:
850
+ block_states[next_block_id]['is_complete'] = True
851
+
852
+ def _shift_logits(self, logits, last_logit=None, block_size=None):
853
+ """Shifts logits to the right by one position, for autoregressive generation"""
854
+ # Check if logits are empty
855
+ if logits.shape[1] == 0:
856
+ print("Warning: logits sequence length is 0, returning empty logits")
857
+ raise Exception("logits sequence length is 0")
858
+
859
+ shifted_logits = torch.zeros_like(logits)
860
+ shifted_logits[:, 1:, :] = logits[:, :-1, :]
861
+ if last_logit is not None:
862
+ shifted_logits[:, 0, :] = last_logit
863
+ return shifted_logits
864
+ shifted_logits[:, 0, :] = 1.0
865
+ return shifted_logits
866
+
867
def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
    """Generate a completion for each request and report throughput stats.

    Requests are processed in batches of ``self.batch_size``. Unless
    ``self.escape_until`` is set, each response is truncated at the first
    occurrence of any stop string.

    NOTE(review): the stop strings come from ``gen_args[0]['until']`` for
    the whole batch — this assumes every request in a batch shares the
    same 'until' list; confirm against the caller.
    """
    res = []

    # Running counters for the final throughput report.
    if not hasattr(self, 'total_generated_tokens'):
        self.total_generated_tokens = 0
    num_tokens = 0
    num_nfe = 0  # Number of Forward Evaluations

    pbar = tqdm(
        total=len(requests),
        disable=(disable_tqdm or (self.rank != 0)),
        desc="Running generate_until requests",
    )

    start_time = time.time()

    for batch_idx in range(0, len(requests), self.batch_size):
        batch_requests = requests[batch_idx : batch_idx + self.batch_size]
        contexts, gen_args = zip(*[req.arguments for req in batch_requests])
        responses = self._generate_batch(contexts)
        if not self.escape_until:
            # Cut each response at the first stop sequence, if any.
            for i, r in enumerate(responses):
                for s in gen_args[0]['until']:
                    r = r.split(s)[0]
                responses[i] = r

        res.extend(responses)
        pbar.update(len(contexts))

    end_time = time.time()
    total_time = end_time - start_time

    # Accumulate statistics.
    num_tokens = self.total_generated_tokens
    # NOTE(review): NFE is estimated as diffusion_steps * num_requests,
    # not measured per sample — it overcounts if generation stops early.
    num_nfe = self.diffusion_steps * len(requests)  # Estimate NFE

    # Save final statistics.
    final_stats = {
        'processed_samples': len(requests),
        'total_samples': len(requests),
        'total_tokens': num_tokens,
        'total_nfe': num_nfe,
        'total_time': total_time,
        'tokens_per_second': num_tokens / total_time if total_time > 0 else 0,
        'nfe_per_token': num_nfe / num_tokens if num_tokens > 0 else 0,
        'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
    }

    # Persist responses and statistics, one file pair per rank.
    if self.save_dir is not None:
        import os
        os.makedirs(self.save_dir, exist_ok=True)

        # Save response results.
        save_path = os.path.join(self.save_dir, f'rank_{self.rank}_responses.jsonl')
        with open(save_path, 'w', encoding='utf-8') as f:
            for r in res:
                f.write(json.dumps(r, ensure_ascii=False) + '\n')

        # Save statistics results.
        stats_path = os.path.join(self.save_dir, f'rank_{self.rank}_final_stats.json')
        with open(stats_path, 'w', encoding='utf-8') as f:
            json.dump(final_stats, f, ensure_ascii=False, indent=2)

    # Print final statistics.
    print("\n" + "="*60)
    print("=== Final Statistics ===")
    print("="*60)
    print(f"Processed Samples: {final_stats['processed_samples']}")
    print(f"Total Samples: {final_stats['total_samples']}")
    print(f"Total Tokens: {final_stats['total_tokens']}")
    print(f"Total NFE: {final_stats['total_nfe']}")
    print(f"Total Time: {final_stats['total_time']:.4f}s")
    print(f"Tokens/Second: {final_stats['tokens_per_second']:.2f}")
    print(f"NFE/Token: {final_stats['nfe_per_token']:.4f}")
    print(f"Completion Time: {final_stats['timestamp']}")
    print("="*60)

    return res
947
+
948
+ def _forward_process(self, batch):
949
+ b, l = batch.shape
950
+ # sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
951
+ u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
952
+ indices = torch.arange(b, device=batch.device).float()
953
+ t = (u0 + indices / b) % 1
954
+
955
+ p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
956
+
957
+ p_mask = p_mask[:, None].repeat(1, l)
958
+
959
+ mask_indices = torch.rand((b, l), device=batch.device) < p_mask
960
+ # always unmask bos and eos
961
+ mask_indices[:, 0] = False
962
+ mask_indices[:, -1] = False
963
+
964
+ noisy_batch = torch.where(mask_indices, self.mask_token_id, batch)
965
+ return noisy_batch, p_mask
966
+
967
@torch.no_grad()
def get_logits(self, batch, prompt_index):
    '''
    Forward the model and return BOS-aligned logits for the batch.

    prompt_index : 1D bool tensor, length=batch.shape[1]; marks the
        conditioning (prompt) positions. Only used when classifier-free
        guidance is enabled.
    '''
    if self.classifier_free_guidance > 1.:
        assert len(prompt_index) == batch.shape[1]
        prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
        # Unconditional branch: same batch with the prompt fully masked.
        un_batch = batch.clone()
        un_batch[prompt_index] = self.mask_token_id
        batch = torch.cat([batch, un_batch])

    input = batch

    with torch.amp.autocast('cuda', dtype=torch.bfloat16):
        logits = self.model(input).logits
        # since bos always unmask, the first logits will not be used
        logits = torch.cat([logits[:, :1], logits[:, :-1]], dim=1)

    if self.classifier_free_guidance > 1.:
        logits, un_logits = torch.chunk(logits, 2, dim=0)
        # NOTE(review): the guard checks self.classifier_free_guidance but
        # the combination uses self.cfg — presumably these are the same
        # value; confirm they are kept in sync.
        logits = un_logits + self.cfg * (logits - un_logits)
    return logits[:, :batch.shape[1]]
990
+
991
@torch.no_grad()
def _eval_target_nll_mc(self, prefix, target):
    """Monte-Carlo estimate of the negative log-likelihood.

    Repeatedly corrupts the sequence with the diffusion forward process
    and averages the importance-weighted cross-entropy over the masked
    positions of the scored span.

    NOTE(review): when ``prefix`` is None, ``len(prefix)`` below would
    raise — presumably callers always pass a prefix tensor; confirm.
    """
    if prefix is None:
        seq = target[None, :]
    else:
        seq = torch.concatenate([prefix, target])[None, :]
    seq = seq.repeat((self.batch_size, 1)).to(self.device)

    # prompt_index marks the conditioning span: the prefix for 'ftb'
    # (front-to-back), otherwise the target side.
    if self.log_type == 'ftb':
        prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
    else:
        prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)

    loss_acc = []
    for _ in range(max(self.mc_num // self.batch_size, 1)):
        perturbed_seq = seq.clone()
        perturbed_seq_, p_mask = self._forward_process(seq)
        # Only noise the scored span; the conditioning span stays clean.
        if self.log_type == 'ftb':
            perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
        elif self.log_type == 'btf':
            perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
        elif self.log_type == 'union':
            perturbed_seq = perturbed_seq_
        else:
            raise NotImplementedError(self.log_type)

        mask_indices = perturbed_seq == self.mask_token_id
        logits = self.get_logits(perturbed_seq, prompt_index)
        # Importance weighting: divide each token loss by its masking
        # probability (standard diffusion-LM NLL estimator).
        loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
        loss = loss.sum() / self.batch_size
        loss_acc.append(loss.item())

    return sum(loss_acc) / len(loss_acc)
1026
+
1027
@torch.no_grad()
def _eval_target_nll_ar(self, prefix, target):
    """Exact autoregressive-style NLL via progressive unmasking.

    Builds one corrupted copy of the scored span per position (a
    triangular masking pattern), batches all copies through the model,
    and sums the cross-entropy of each position predicted at the step
    where it is the first remaining mask. For log_type 'ftb' the target
    span is scored given the prefix; for 'btf' the prefix span is scored
    given the target.
    """
    prefix, target = prefix.unsqueeze(0), target.unsqueeze(0)  # 1*l1, 1*l2
    assert self.log_type in ['ftb', 'btf']
    assert self.nll_type in ['ar_ftb', 'ar_btf']

    # prompt_index marks the conditioning span.
    if self.log_type == 'ftb':
        prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
    else:
        prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]

    # One row per scored position.
    if self.log_type == 'ftb':
        perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous()  # l2*l2
    else:
        perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous()  # l1*l1

    # Triangular mask: row i masks position i and everything after it
    # ('ar_ftb') or before it ('ar_btf').
    mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
    if self.nll_type == 'ar_ftb':
        mask_index = torch.triu(mask_index)
    else:
        mask_index = torch.tril(mask_index)
    perturbed_[mask_index] = self.mask_token_id
    # Re-attach the clean conditioning span to every row.
    if self.log_type == 'ftb':
        perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
    else:
        perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)

    # Forward all rows in chunks of batch_size; keep logits on CPU.
    logits_ = []
    num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
    for i in range(num):
        end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
        perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
        perturbed_seq_ = perturbed_seq_.to(self.device)
        if len(perturbed_seq_.shape) == 1:
            perturbed_seq_ = perturbed_seq_.unsqueeze(0)
        logits = self.get_logits(perturbed_seq_, prompt_index)
        logits_.append(logits.cpu())
    logits = torch.cat(logits_, dim=0)

    # Reduce the triangle to its diagonal: row i contributes exactly the
    # prediction for position i.
    temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
    if self.nll_type == 'ar_ftb':
        temp_index = torch.triu(temp_index, diagonal=1)
    else:
        temp_index = torch.tril(temp_index, diagonal=-1)
    mask_index[temp_index] = False
    # Pad the selector with zeros over the conditioning span so it indexes
    # the full concatenated sequence.
    if self.log_type == 'ftb':
        logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
    else:
        logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)

    if self.log_type == 'ftb':
        loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
    else:
        loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
    return loss
1082
+
1083
+ def _encode_pair(self, context, continuation):
1084
+ if self.add_bos_token:
1085
+ context = self.tokenizer.bos_token + context
1086
+
1087
+ n_spaces = len(context) - len(context.rstrip())
1088
+ if n_spaces > 0:
1089
+ continuation = context[-n_spaces:] + continuation
1090
+ context = context[:-n_spaces]
1091
+
1092
+ whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id]
1093
+ context_enc = self.tokenizer.encode(context)
1094
+
1095
+ context_enc_len = len(context_enc)
1096
+ continuation_enc = whole_enc[context_enc_len:]
1097
+
1098
+ # by default truncate on the left
1099
+ cutoff_length = max(len(whole_enc) - self.max_length, 0)
1100
+ if cutoff_length > 0:
1101
+ eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side")
1102
+ context_remain = context_enc_len-cutoff_length
1103
+ if context_remain > 0:
1104
+ context_enc = context_enc[-context_remain:]
1105
+ else:
1106
+ eval_logger.warning(f"All context (prompt) is truncated.")
1107
+ context_enc = ""
1108
+ continuation_enc = whole_enc[-self.max_length:]
1109
+ return context_enc, continuation_enc
1110
+
1111
def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
    """Compute loglikelihood for (prefix, target) requests.

    The NLL estimator is selected by ``self.nll_type``: 'mc' uses the
    Monte-Carlo masked estimator (length-normalized when
    ``self.log_type == 'union'``); 'ar_ftb'/'ar_btf' use the exact
    autoregressive-style evaluation.

    Returns:
        One (loglikelihood, is_greedy) tuple per request. Greedy-decoding
        detection is not implemented, so is_greedy is always 0.0.
    """
    def _tokenize(e):
        # Encode each pair; keep the raw text alongside for debugging.
        prefix, target = self._encode_pair(e["prefix"], e["target"])
        return {
            "prefix_text": e["prefix"],
            "target_text": e["target"],
            "prefix": prefix,
            "target": target,
        }

    # Build the dataset directly from the request args.
    # (Removed a dead `ds = []` that was immediately overwritten.)
    ds = Dataset.from_list(
        [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
    )
    print(ds[0])
    ds = ds.map(_tokenize)
    ds = ds.with_format("torch")

    out = []
    with torch.no_grad():
        for elem in tqdm(ds, desc="Computing likelihood..."):
            prefix = elem["prefix"]
            target = elem["target"]
            # likelihood calculations are modified from
            # https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
            if self.nll_type == 'mc':
                ll = -self._eval_target_nll_mc(prefix, target)
                if self.log_type == 'union':
                    ll = ll / (len(target) + len(prefix))
            elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
                ll = -self._eval_target_nll_ar(prefix, target)
            else:
                raise NotImplementedError(self.nll_type)

            # TODO: greedy decoding
            is_target_greedy_dec = False

            out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
    return out
1148
+
1149
def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
    # Rolling (windowed) loglikelihood is not supported by this wrapper.
    raise NotImplementedError
1151
+
1152
+
1153
if __name__ == "__main__":
    # Fix all RNG seeds for reproducible evaluation, then hand control to
    # the lm-eval-harness CLI entry point.
    set_seed(1234)
    cli_evaluate()
eval_dream.sh ADDED
@@ -0,0 +1,158 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Per-task evaluation configuration. All variables are parallel
# space-separated lists: index i configures task i.
tasks="gsm8k_cot mbpp minerva_math"
nshots="8 3 4"
lengths="256 256 256"
temperatures="0 0 0"
limits="10000 10000 10000"
block_sizes="32 48 64"
block_add_thresholds="0.1 0.1 0.1"
decoded_token_thresholds="0.95 0.95 0.95"
skip_thresholds="0.9 0.9 0.9"
# "none" disables top-p sampling for that task.
top_ps="none none none"
dtypes="bfloat16 bfloat16 bfloat16"
sampling_strategies="default default default"



# HumanEval has its own configuration and is run in a separate loop below.
humaneval_nshots="0"
humaneval_lengths="256"
humaneval_temperatures="0"
humaneval_limits="10000"
humaneval_diffusion_steps="256"
humaneval_block_sizes="32"
humaneval_block_add_thresholds="0.9"
humaneval_decoded_token_thresholds="0.95"
humaneval_skip_thresholds="0.95"
humaneval_top_ps="none"
humaneval_dtypes="bfloat16"
humaneval_sampling_strategies="default"



# Base model plus the LoRA adapters to evaluate on top of it.
base_model=Dream-org/Dream-v0-Base-7B


lora_models=(
    "SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora"
)
37
+
38
+
39
# Split the space-separated configuration strings into bash arrays.
read -ra TASKS_ARRAY <<< "$tasks"
read -ra NSHOTS_ARRAY <<< "$nshots"
read -ra LENGTH_ARRAY <<< "$lengths"
read -ra TEMP_ARRAY <<< "$temperatures"
read -ra LIMITS_ARRAY <<< "$limits"
read -ra BLOCK_SIZES_ARRAY <<< "$block_sizes"
read -ra BLOCK_ADD_THRESHOLDS_ARRAY <<< "$block_add_thresholds"
read -ra DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$decoded_token_thresholds"
read -ra SKIP_THRESHOLDS_ARRAY <<< "$skip_thresholds"
read -ra TOP_PS_ARRAY <<< "$top_ps"
read -ra DTYPES_ARRAY <<< "$dtypes"
read -ra SAMPLING_STRATEGIES_ARRAY <<< "$sampling_strategies"


# Same for the HumanEval-specific configuration.
read -ra HUMANEVAL_NSHOTS_ARRAY <<< "$humaneval_nshots"
read -ra HUMANEVAL_LENGTHS_ARRAY <<< "$humaneval_lengths"
read -ra HUMANEVAL_TEMP_ARRAY <<< "$humaneval_temperatures"
read -ra HUMANEVAL_LIMITS_ARRAY <<< "$humaneval_limits"
read -ra HUMANEVAL_DIFFUSION_STEPS_ARRAY <<< "$humaneval_diffusion_steps"
read -ra HUMANEVAL_BLOCK_SIZES_ARRAY <<< "$humaneval_block_sizes"
read -ra HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY <<< "$humaneval_block_add_thresholds"
read -ra HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$humaneval_decoded_token_thresholds"
read -ra HUMANEVAL_SKIP_THRESHOLDS_ARRAY <<< "$humaneval_skip_thresholds"
read -ra HUMANEVAL_TOP_PS_ARRAY <<< "$humaneval_top_ps"
read -ra HUMANEVAL_DTYPES_ARRAY <<< "$humaneval_dtypes"
read -ra HUMANEVAL_SAMPLING_STRATEGIES_ARRAY <<< "$humaneval_sampling_strategies"
65
+
66
+
67
# Sanity check: every per-task array must have one entry per task.
array_length=${#TASKS_ARRAY[@]}
if [[ ${#NSHOTS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#LENGTH_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#TEMP_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#LIMITS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#BLOCK_SIZES_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#SKIP_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#TOP_PS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#SAMPLING_STRATEGIES_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#DTYPES_ARRAY[@]} -ne $array_length ]]; then
    echo "Error: All configuration arrays must have the same length!"
    echo "Tasks: ${#TASKS_ARRAY[@]}, Nshots: ${#NSHOTS_ARRAY[@]}, Lengths: ${#LENGTH_ARRAY[@]}, Temperatures: ${#TEMP_ARRAY[@]}, Limits: ${#LIMITS_ARRAY[@]}, Block sizes: ${#BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#TOP_PS_ARRAY[@]}, Sampling strategies: ${#SAMPLING_STRATEGIES_ARRAY[@]}, Dtypes: ${#DTYPES_ARRAY[@]}"
    exit 1
fi


# Same sanity check for the HumanEval configuration arrays.
humaneval_array_length=${#HUMANEVAL_NSHOTS_ARRAY[@]}
if [[ ${#HUMANEVAL_LENGTHS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_TEMP_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_LIMITS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_TOP_PS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_DTYPES_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]} -ne $humaneval_array_length ]]; then
    echo "Error: All HumanEval configuration arrays must have the same length!"
    echo "HumanEval Nshots: ${#HUMANEVAL_NSHOTS_ARRAY[@]}, Lengths: ${#HUMANEVAL_LENGTHS_ARRAY[@]}, Temperatures: ${#HUMANEVAL_TEMP_ARRAY[@]}, Limits: ${#HUMANEVAL_LIMITS_ARRAY[@]}, Diffusion steps: ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]}, Block sizes: ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#HUMANEVAL_TOP_PS_ARRAY[@]}, Dtypes: ${#HUMANEVAL_DTYPES_ARRAY[@]}, Sampling strategies: ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]}"
    exit 1
fi
101
+
102
# Allow lm-eval-harness to execute generated code (needed for code tasks).
export HF_ALLOW_CODE_EVAL=1
# Outer loop: evaluate each LoRA adapter on top of the shared base model.
for lora_model in "${lora_models[@]}"; do
    lora_model_name="$lora_model"
    echo "===================================================================="
    echo "Evaluating LoRA model: $lora_model_name"
    echo "===================================================================="



    # --- HumanEval runs (separate config; no --limit flag) ---
    for i in "${!HUMANEVAL_NSHOTS_ARRAY[@]}"; do
        # Output directory name encodes the full configuration.
        output_path="evals_dream${lora_model_name}/humaneval-ns${HUMANEVAL_NSHOTS_ARRAY[$i]}-len${HUMANEVAL_LENGTHS_ARRAY[$i]}-temp${HUMANEVAL_TEMP_ARRAY[$i]}-limit${HUMANEVAL_LIMITS_ARRAY[$i]}-diffsteps${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}-block${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}-thresh${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}-topp${HUMANEVAL_TOP_PS_ARRAY[$i]}-dtype${HUMANEVAL_DTYPES_ARRAY[$i]}-sampling${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}"
        echo "Running HumanEval evaluation $((i+1))/${humaneval_array_length} for $lora_model_name..."
        echo "HumanEval Config: Shots: ${HUMANEVAL_NSHOTS_ARRAY[$i]}, Length: ${HUMANEVAL_LENGTHS_ARRAY[$i]}, Temperature: ${HUMANEVAL_TEMP_ARRAY[$i]}, Limit: ${HUMANEVAL_LIMITS_ARRAY[$i]}, Diffusion Steps: ${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}, Block Size: ${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${HUMANEVAL_TOP_PS_ARRAY[$i]}, Sampling Strategy: ${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${HUMANEVAL_DTYPES_ARRAY[$i]}; Output: $output_path"

        # Omit the top_p argument entirely when it is "none".
        if [[ "${HUMANEVAL_TOP_PS_ARRAY[$i]}" == "none" ]]; then
            humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        else
            humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},top_p=${HUMANEVAL_TOP_PS_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        fi

        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_dream.py --model dream_lora \
            --model_args $humaneval_model_args \
            --tasks humaneval \
            --num_fewshot ${HUMANEVAL_NSHOTS_ARRAY[$i]} \
            --batch_size 1 \
            --output_path $output_path \
            --log_samples \
            --confirm_run_unsafe_code
    done

    ### NOTICE: use postprocess for humaneval
    # python postprocess_code.py {the samples_xxx.jsonl file under output_path}


    # --- Regular task runs (gsm8k_cot, mbpp, minerva_math, ...) ---
    # Note: diffusion_steps is set equal to the generation length here.
    for i in "${!TASKS_ARRAY[@]}"; do
        output_path="evals_dream${lora_model_name}/${TASKS_ARRAY[$i]}-ns${NSHOTS_ARRAY[$i]}-len${LENGTH_ARRAY[$i]}-temp${TEMP_ARRAY[$i]}-limit${LIMITS_ARRAY[$i]}-diffsteps${LENGTH_ARRAY[$i]}-block${BLOCK_SIZES_ARRAY[$i]}-thresh${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${SKIP_THRESHOLDS_ARRAY[$i]}-topp${TOP_PS_ARRAY[$i]}-dtype${DTYPES_ARRAY[$i]}-sampling${SAMPLING_STRATEGIES_ARRAY[$i]}"
        echo "Task: ${TASKS_ARRAY[$i]}, Shots: ${NSHOTS_ARRAY[$i]}, Length: ${LENGTH_ARRAY[$i]}, Temperature: ${TEMP_ARRAY[$i]}, Limit: ${LIMITS_ARRAY[$i]}, Block Size: ${BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${TOP_PS_ARRAY[$i]}, Sampling Strategy: ${SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${DTYPES_ARRAY[$i]}; Output: $output_path"

        # Omit the top_p argument entirely when it is "none".
        if [[ "${TOP_PS_ARRAY[$i]}" == "none" ]]; then
            model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        else
            model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},top_p=${TOP_PS_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        fi

        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_dream.py --model dream_lora \
            --model_args $model_args \
            --tasks ${TASKS_ARRAY[$i]} \
            --limit ${LIMITS_ARRAY[$i]} \
            --num_fewshot ${NSHOTS_ARRAY[$i]} \
            --batch_size 1 \
            --output_path $output_path \
            --log_samples \
            --confirm_run_unsafe_code
    done
done

echo "All evaluations completed!"
eval_dream_d2f_vllm.py ADDED
@@ -0,0 +1,764 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import gc
3
+ import time
4
+ import json
5
+ from datetime import timedelta
6
+ from typing import List, Optional, Tuple, Type, TypeVar, Union
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torch.distributions as dists
10
+ import transformers
11
+ from accelerate import (
12
+ Accelerator,
13
+ InitProcessGroupKwargs,
14
+ )
15
+ from datasets import Dataset
16
+ from packaging import version
17
+ from tqdm import tqdm
18
+ from peft import PeftConfig, PeftModel
19
+ import numpy as np
20
+
21
+ from lm_eval import utils
22
+ from lm_eval.api.instance import Instance
23
+ from lm_eval.api.model import LM
24
+ from lm_eval.api.registry import register_model
25
+ from lm_eval.models.utils import get_dtype
26
+ from lm_eval.__main__ import cli_evaluate
27
+
28
+ eval_logger = logging.getLogger(__name__)
29
+ T = TypeVar("T", bound="LM")
30
+ import random
31
def set_seed(seed):
    """Make runs reproducible: seed all RNGs and pin cuDNN to deterministic mode."""
    for seeder in (torch.manual_seed, random.seed, np.random.seed):
        seeder(seed)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
38
+
39
def shift_logits(logits):
    """Shift logits right by one along dim 1; position 0 is filled with 1.0."""
    out = torch.zeros_like(logits)
    out[:, 0, :] = 1.0
    out[:, 1:, :] = logits[:, :-1, :]
    return out
44
+
45
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    """
    Build a block-causal additive attention mask for the whole sequence.

    The prompt (positions [0, prompt_length)) attends to itself. Each
    subsequent block of ``block_size`` tokens attends to the prompt,
    every earlier block, and itself — i.e. to every position before its
    own block end. Allowed positions are 0, disallowed are -inf.

    Args:
        prompt_length: Length of the prompt (first irregular block)
        max_length: Maximum total sequence length
        block_size: Size of each regular block
        device: Device to create tensor on
        dtype: Data type for the attention mask (defaults to bfloat16)

    Returns:
        attention_mask: Tensor of shape [1, 1, max_length, max_length]
    """
    if dtype is None:
        dtype = torch.bfloat16

    # Initialize mask with -inf (no attention).
    attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)

    # Block 0: prompt attends to itself.
    attention_mask[:, :, :prompt_length, :prompt_length] = 0

    # Number of regular blocks after the prompt (ceiling division).
    remaining_length = max_length - prompt_length
    num_blocks = (remaining_length + block_size - 1) // block_size

    # The prompt, all previous blocks, and the current block occupy the
    # contiguous range [0, block_end), so each block's visibility is a
    # single slice assignment. (The original per-previous-block inner
    # loop performed O(num_blocks**2) slice writes with the same result.)
    for b in range(num_blocks):
        block_start = prompt_length + b * block_size
        block_end = min(block_start + block_size, max_length)
        attention_mask[:, :, block_start:block_end, :block_end] = 0

    return attention_mask
91
+
92
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    """
    Slice the precomputed full attention mask for one forward pass.

    Rows are the `input_length` tokens currently fed in (absolute positions
    `start_pos .. start_pos + input_length`); columns cover the cached
    positions followed by the current input positions.

    NOTE(review): the current-input columns are copied from
    full_mask[..., start_pos:end_pos], which lines up with the output layout
    only when cache_length == start_pos — this appears to be the caller's
    invariant; confirm.

    Returns:
        attention_mask of shape [1, 1, input_length, cache_length + input_length]
    """
    end_pos = start_pos + input_length
    width = cache_length + input_length

    piece = torch.full(
        (1, 1, input_length, width),
        -torch.inf,
        device=full_mask.device,
        dtype=full_mask.dtype,
    )

    rows = full_mask[:, :, start_pos:end_pos, :]
    # Columns for the cached prefix, then columns for the current input.
    piece[:, :, :, :cache_length] = rows[:, :, :, :cache_length]
    piece[:, :, :, cache_length:] = rows[:, :, :, start_pos:end_pos]
    return piece
120
+
121
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None, dtype=None):
    """Build a per-sample additive attention mask (0 = attend, -inf = blocked).

    Every token may attend to its sample's full prompt; positions after the
    prompt are grouped into blocks of `block_size` with full attention inside
    a block and causal attention across blocks.

    Args:
        input_ids: [B, seq_len] token ids (only the shape is used).
        prompt_length: per-sample prompt lengths, indexable as prompt_length[i].
        block_size: size of each regular block after the prompt.
        device, dtype: placement/type of the mask (dtype defaults to float32).

    Returns:
        attn_mask of shape [B, 1, seq_len, seq_len].
    """
    B, seq_len = input_ids.shape
    if dtype is None:
        dtype = torch.float32
    # Start fully blocked; open allowed regions per sample.
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=dtype, device=device)
    for i in range(B):
        p = int(prompt_length[i])
        # 1. Prompt part: every token can attend to the entire prompt.
        attn_mask[i, :, :, :p] = 0.0

        # 2. Block division starting at the prompt boundary. A block attends
        # to the prompt, all earlier blocks, and itself — i.e. to every
        # column up to its own end, so one slice per block suffices
        # (replaces the original O(num_blocks^2) previous-block loop).
        num_blocks = (seq_len - p + block_size - 1) // block_size
        for b in range(num_blocks):
            block_start = p + b * block_size
            block_end = min(block_start + block_size, seq_len)
            attn_mask[i, :, block_start:block_end, :block_end] = 0.0

    return attn_mask
151
+
152
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering.

    Keeps the smallest set of highest-probability tokens whose cumulative
    probability exceeds `top_p`; every other logit is set to the dtype's
    minimum value so it can never be sampled.
    """
    desc_logits, desc_idx = torch.sort(logits, descending=True)
    cum_probs = F.softmax(desc_logits, dim=-1).cumsum(dim=-1)
    drop = cum_probs > top_p
    # Shift right so the first token that crosses the threshold is kept.
    drop[..., 1:] = drop[..., :-1].clone()
    drop[..., 0] = 0

    remove_mask = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    remove_mask.scatter_(-1, desc_idx, drop)
    return logits.masked_fill(remove_mask, torch.finfo(logits.dtype).min)
164
+
165
def top_k_logits(logits, top_k=None):
    """Top-k filtering: keep only the `top_k` largest logits per row.

    Everything below the k-th largest value is set to the dtype's minimum so
    it can never be sampled.
    """
    k = min(top_k, logits.size(-1))  # guard against k larger than the vocab
    kth_value = torch.topk(logits, k)[0][..., -1, None]
    below_kth = logits < kth_value
    return logits.masked_fill(below_kth, torch.finfo(logits.dtype).min)
171
+
172
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Pick one token per row and report a confidence score for it.

    With temperature > 0 the token is drawn from the (optionally top-p /
    top-k filtered) softmax distribution; otherwise the argmax is taken.
    `confidence` defaults to the chosen token's probability, but becomes the
    top1 - top2 margin when `margin_confidence` is set, or the negative
    entropy of the distribution when `neg_entropy` is set (the latter wins
    if both are set). `initial_confidence` is always the chosen token's
    probability.

    Returns:
        (confidence, x0, initial_confidence)
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except:  # fall back to greedy if sampling fails (e.g. degenerate probs)
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)

    confidence = initial_confidence.clone()

    if margin_confidence:
        ranked, _ = torch.sort(probs, dim=-1, descending=True)
        # Gap between the two most likely tokens.
        confidence = ranked[:, 0] - ranked[:, 1]

    if neg_entropy:
        eps = 1e-10
        confidence = torch.sum(probs * torch.log(probs + eps), dim=-1)

    return confidence, x0, initial_confidence
207
+
208
+ @register_model("dream_lora")
209
+ class DreamLoRA(LM):
210
+ def __init__(
211
+ self,
212
+ pretrained: Union[str, transformers.PreTrainedModel],
213
+ lora_path: str,
214
+ batch_size: Optional[Union[int, str]] = 1,
215
+ device: Optional[str] = "cuda",
216
+ dtype: Optional[Union[str, torch.dtype]] = "auto",
217
+ max_new_tokens: Optional[int] = 128,
218
+ max_length: Optional[int] = 2048, # Updated to match example code
219
+ add_bos_token: Optional[bool] = False,
220
+ nll_type: Optional[str] = "mc",
221
+ log_type: Optional[str] = "ftb",
222
+ mc_num: Optional[int] = 128,
223
+ classifier_free_guidance: Optional[float] = 1.0,
224
+ sampling_eps: Optional[float] = 1e-3,
225
+ diffusion_steps: Optional[int] = 128,
226
+ trust_remote_code: Optional[bool] = True,
227
+ parallelize: Optional[bool] = False,
228
+ autogptq: Optional[Union[bool, str]] = False,
229
+ temperature: Optional[float] = 0.2, # Updated default
230
+ top_p: Optional[float] = None, # Updated default
231
+ top_k: Optional[float] = None,
232
+ alg: Optional[str] = "entropy",
233
+ alg_temp: Optional[float] = 0.0,
234
+ escape_until: Optional[bool] = False,
235
+ block_size: Optional[int] = 4, # Updated to match example code
236
+ mask_token_id: Optional[int] = 151666, # Added mask_token_id parameter
237
+ block_add_threshold: Optional[float] = 0.5, # Added block_add_threshold parameter
238
+ decoded_token_threshold: Optional[int] = 0.9, # Added decoded_token_threshold parameter
239
+ skip_threshold: Optional[float] = 1.0, # Added skip_threshold parameter
240
+ sampling_strategy: Optional[str] = "default", # Added sampling_strategy parameter
241
+ save_dir: Optional[str] = None,
242
+ **kwargs,
243
+ ) -> None:
244
+ super().__init__()
245
+
246
+ # prepare for parallelism
247
+ assert isinstance(device, str)
248
+ assert isinstance(pretrained, str)
249
+ assert isinstance(batch_size, (int, str))
250
+
251
+ gpus = torch.cuda.device_count()
252
+ accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
253
+ accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
254
+ if accelerator.num_processes > 1:
255
+ self.accelerator = accelerator
256
+
257
+ if "npu" in accelerator.device.type:
258
+ gpus = torch.npu.device_count()
259
+
260
+ # using one process with no model parallelism
261
+ if not (parallelize or accelerator.num_processes > 1):
262
+ # use user-passed device
263
+ device_list = set(
264
+ ["cuda", "cpu"]
265
+ + [f"cuda:{i}" for i in range(gpus)]
266
+ + ["mps", "mps:0"]
267
+ + [f"npu:{i}" for i in range(gpus)]
268
+ )
269
+ if device and device in device_list:
270
+ self._device = torch.device(device)
271
+ eval_logger.info(f"Using device '{device}'")
272
+ if device in ("mps", "mps:0") and version.parse(
273
+ torch.__version__
274
+ ) < version.parse("2.1"):
275
+ raise RuntimeError(
276
+ f"mps requires torch >= 2.1. You have {torch.__version__}"
277
+ )
278
+ else:
279
+ eval_logger.info("Device not specified")
280
+ eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
281
+ self._device = (
282
+ torch.device("cuda")
283
+ if torch.cuda.is_available()
284
+ else torch.device("cpu")
285
+ )
286
+ else: # Parallelism managed by accelerate
287
+ if device != "cuda":
288
+ eval_logger.info(
289
+ f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
290
+ )
291
+ # TODO: include in warning that `load_in_8bit` etc. affect this too
292
+ self._device = (
293
+ self.accelerator.device
294
+ if hasattr(self, "accelerator")
295
+ else torch.device(device)
296
+ )
297
+
298
+ self.batch_size_per_gpu = batch_size
299
+ if isinstance(batch_size, str):
300
+ self.batch_size_per_gpu = int(batch_size)
301
+
302
+ # Save LoRA path and block_size
303
+ self.lora_path = lora_path
304
+ self.block_size = block_size
305
+ self.block_add_threshold = block_add_threshold # New block_add_threshold attribute
306
+ self.skip_threshold = skip_threshold # New skip_threshold attribute
307
+ self.sampling_strategy = sampling_strategy # Save sampling strategy parameter
308
+ self.decoded_token_threshold = decoded_token_threshold # New decoded_token_threshold attribute
309
+ self.save_dir = save_dir
310
+
311
+ # Add metric tracking
312
+ self.total_forward_passes = 0
313
+ self.total_generated_tokens = 0
314
+ self.total_prompts = 0
315
+ # Add time and token statistics
316
+ self.total_generation_time = 0.0
317
+ self.total_block_tokens = 0 # Number of blocks * block_size
318
+ self.total_actual_tokens = 0 # Actual generated tokens (excluding EOS)
319
+ self.total_non_eos_tokens = 0 # Total non-EOS tokens in the entire sequence
320
+ self.all_generation_times = []
321
+ self.all_block_tokens = []
322
+ self.all_actual_tokens = []
323
+ self.all_non_eos_tokens = []
324
+
325
+ # Save target_dtype for later use
326
+ self.target_dtype = get_dtype(dtype)
327
+
328
+ # if isinstance(pretrained, str):
329
+ # if gpus >= 1 or str(self.device) == "mps":
330
+ # # TODO: can remove this whole snippet except in the mps case, perhaps?
331
+ # if not (parallelize or autogptq or hasattr(self, "accelerator")):
332
+ # # place model onto device requested manually,
333
+ # # if not using HF Accelerate or device_map
334
+ # # or any other option that preloads model onto device
335
+ # try:
336
+ # self.model.to(self.device)
337
+ # except ValueError:
338
+ # eval_logger.debug(
339
+ # "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
340
+ # )
341
+ # # multigpu data-parallel support when launched with accelerate
342
+ # if gpus > 1:
343
+ # if accelerator.num_processes > 1:
344
+ # if parallelize:
345
+ # eval_logger.warning(
346
+ # "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
347
+ # )
348
+ # elif gpus > accelerator.num_processes:
349
+ # eval_logger.warning(
350
+ # "WARNING: The number of total system GPUs does not match the number of spawned processes. "
351
+ # "If you would like to use data parallelism, please launch the script "
352
+ # "with 'accelerate launch *script*'. "
353
+ # f"Current run will proceed with {accelerator.num_processes} devices."
354
+ # )
355
+ # if self.accelerator.is_local_main_process:
356
+ # eval_logger.info(
357
+ # f"Using {gpus} devices with data parallelism"
358
+ # )
359
+
360
+ # self._device = torch.device(f"{accelerator.device}")
361
+ # self.accelerator = accelerator
362
+
363
+ # self._rank = self.accelerator.local_process_index
364
+ # self._world_size = self.accelerator.num_processes
365
+ # else:
366
+ # # if we aren't launching via accelerate, ditch
367
+ # self._rank = 0
368
+ # self._world_size = 1
369
+ # else:
370
+ # # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
371
+ # eval_logger.warning(
372
+ # "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
373
+ # )
374
+ # self._rank = 0
375
+ # self._world_size = 1
376
+
377
+ self.max_length = max_length
378
+ self.add_bos_token = add_bos_token
379
+ # generation params
380
+ self.max_new_tokens = max_new_tokens
381
+ self.diffusion_steps = diffusion_steps
382
+ self.temperature = temperature
383
+ self.top_p = top_p
384
+ self.top_k = top_k
385
+ self.alg = alg
386
+ self.alg_temp = alg_temp
387
+ self.escape_until = escape_until
388
+ self.block_size = block_size
389
+ self.mask_token_id = mask_token_id
390
+
391
+ # loglikelihood params
392
+ self.nll_type = nll_type
393
+ self.log_type = log_type
394
+ self.mc_num = mc_num
395
+ self.classifier_free_guidance = classifier_free_guidance
396
+ self.sampling_eps = sampling_eps
397
+
398
+ self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)
399
+
400
    @property
    def batch_size(self):
        # Per-GPU batch size used to chunk likelihood evaluation.
        return self.batch_size_per_gpu
403
+
404
    @property
    def device(self):
        # torch.device resolved in __init__ (user-requested or accelerator's).
        return self._device
407
+
408
    @property
    def rank(self):
        # Data-parallel process index; used to name per-rank output files.
        # NOTE(review): `_rank` is not assigned anywhere in this class's
        # visible code — presumably the lm-eval `LM` base sets it; verify.
        return self._rank
411
+
412
    @property
    def world_size(self):
        # Total number of data-parallel processes.
        # NOTE(review): `_world_size` is not assigned anywhere in this class's
        # visible code — presumably the lm-eval `LM` base sets it; verify.
        return self._world_size
415
+
416
    def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
        """Instantiate the d2f_vllm engine (Dream diffusion LM + LoRA) and tokenizer.

        Note: `dtype` and `trust_remote_code` are accepted but not forwarded;
        the engine-side settings below are hard-coded.
        """
        from d2f_vllm import LLM, SamplingParams

        self.LLM = LLM(
            pretrained,
            lora_path=self.lora_path,
            use_lora=True,
            model_name="dream",
            model_type="diffusion_lm",
            enforce_eager=True,
            tensor_parallel_size=1,
            gpu_memory_utilization=0.60,
            max_num_batched_tokens=2048,
            max_num_seqs=20,
            max_model_len=1024,
            accept_threshold=self.skip_threshold,
            complete_threshold=self.decoded_token_threshold,
            # NOTE(review): the engine receives the complement of
            # block_add_threshold — confirm this inversion is intended.
            add_new_block_threshold=1-self.block_add_threshold,
            kv_cache_layout="unified"
        )
        self.tokenizer = self.LLM.tokenizer
        # One SamplingParams reused for every generate_until call.
        self.sampling_params = SamplingParams(temperature=self.temperature, max_tokens=self.max_new_tokens)
438
+
439
+
440
+ def tok_decode(self, tokens, skip_special_tokens=True):
441
+ return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
442
+
443
+ def tok_encode(self, text, add_special_tokens=True):
444
+ return self.tokenizer(
445
+ text, return_tensors="pt", add_special_tokens=add_special_tokens
446
+ ).input_ids
447
+
448
+ @classmethod
449
+ def create_from_arg_string(
450
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
451
+ ) -> T:
452
+ """
453
+ Creates an instance of the LM class using the given argument string and additional config.
454
+
455
+ Parameters:
456
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
457
+ - additional_config: Optional dictionary containing additional configuration parameters.
458
+
459
+ Returns:
460
+ - Instance of the LM class.
461
+ """
462
+ additional_config = {} if additional_config is None else additional_config
463
+ args = utils.simple_parse_args_string(arg_string)
464
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
465
+ return cls(**args, **args2)
466
+
467
+ def apply_chat_template(
468
+ self, chat_history, add_generation_prompt: bool = True
469
+ ) -> str:
470
+ """
471
+ Method to apply a chat template to a list of chat history between user and model.
472
+ """
473
+ chat_templated = self.tokenizer.apply_chat_template(
474
+ chat_history,
475
+ tokenize=False,
476
+ add_generation_prompt=add_generation_prompt,
477
+ continue_final_message=not add_generation_prompt,
478
+ )
479
+
480
+ return chat_templated
481
+
482
+ @property
483
+ def tokenizer_name(self) -> str:
484
+ return self.tokenizer.name_or_path.replace("/", "__")
485
+
486
+ def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
487
+ res = []
488
+
489
+ # Initialize statistics counters
490
+ if not hasattr(self, 'total_generated_tokens'):
491
+ self.total_generated_tokens = 0
492
+ num_tokens = 0
493
+ num_nfe = 0 # Number of Forward Evaluations
494
+
495
+ prompts, gen_args = [], []
496
+ print("Preparing prompts...")
497
+ for req in tqdm(requests):
498
+ prompts.append(self.tokenizer.bos_token + req.arguments[0])
499
+ gen_args.append(req.arguments[1])
500
+
501
+ start_time = time.time()
502
+
503
+ outputs = self.LLM.generate(prompts, self.sampling_params)
504
+
505
+ end_time = time.time()
506
+ total_time = end_time - start_time
507
+
508
+ # Accumulate statistics
509
+ res = [output['text'] for output in outputs]
510
+ num_tokens = sum(len(output['token_ids']) for output in outputs)
511
+ num_nfe = sum(output['n_diff_steps'] for output in outputs)
512
+
513
+ # Save final statistics
514
+ final_stats = {
515
+ 'processed_samples': len(requests),
516
+ 'total_samples': len(requests),
517
+ 'total_tokens': num_tokens,
518
+ 'total_nfe': num_nfe,
519
+ 'total_time': total_time,
520
+ 'tokens_per_second': num_tokens / total_time if total_time > 0 else 0,
521
+ 'nfe_per_token': num_nfe / num_tokens if num_tokens > 0 else 0,
522
+ 'timestamp': time.strftime('%Y-%m-%d %H:%M:%S')
523
+ }
524
+
525
+ # Save statistics to file
526
+ if self.save_dir is not None:
527
+ import os
528
+ os.makedirs(self.save_dir, exist_ok=True)
529
+
530
+ # Save response results
531
+ save_path = os.path.join(self.save_dir, f'rank_{self.rank}_responses.jsonl')
532
+ with open(save_path, 'w', encoding='utf-8') as f:
533
+ for r in res:
534
+ f.write(json.dumps(r, ensure_ascii=False) + '\n')
535
+
536
+ # Save statistics results
537
+ stats_path = os.path.join(self.save_dir, f'rank_{self.rank}_final_stats.json')
538
+ with open(stats_path, 'w', encoding='utf-8') as f:
539
+ json.dump(final_stats, f, ensure_ascii=False, indent=2)
540
+
541
+ # Print final statistics
542
+ print("\n" + "="*60)
543
+ print("=== Final Statistics ===")
544
+ print("="*60)
545
+ print(f"Processed Samples: {final_stats['processed_samples']}")
546
+ print(f"Total Samples: {final_stats['total_samples']}")
547
+ print(f"Total Tokens: {final_stats['total_tokens']}")
548
+ print(f"Total NFE: {final_stats['total_nfe']}")
549
+ print(f"Total Time: {final_stats['total_time']:.4f}s")
550
+ print(f"Tokens/Second: {final_stats['tokens_per_second']:.2f}")
551
+ print(f"NFE/Token: {final_stats['nfe_per_token']:.4f}")
552
+ print(f"Completion Time: {final_stats['timestamp']}")
553
+ print("="*60)
554
+
555
+ return res
556
+
557
+ def _forward_process(self, batch):
558
+ b, l = batch.shape
559
+ # sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
560
+ u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
561
+ indices = torch.arange(b, device=batch.device).float()
562
+ t = (u0 + indices / b) % 1
563
+
564
+ p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
565
+
566
+ p_mask = p_mask[:, None].repeat(1, l)
567
+
568
+ mask_indices = torch.rand((b, l), device=batch.device) < p_mask
569
+ # always unmask bos and eos
570
+ mask_indices[:, 0] = False
571
+ mask_indices[:, -1] = False
572
+
573
+ noisy_batch = torch.where(mask_indices, self.mask_token_id, batch)
574
+ return noisy_batch, p_mask
575
+
576
+ @torch.no_grad()
577
+ def get_logits(self, batch, prompt_index):
578
+ '''
579
+ prompt_index : 1D bool tensor, length=batch.shape[1]
580
+ '''
581
+ if self.classifier_free_guidance > 1.:
582
+ assert len(prompt_index) == batch.shape[1]
583
+ prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
584
+ un_batch = batch.clone()
585
+ un_batch[prompt_index] = self.mask_token_id
586
+ batch = torch.cat([batch, un_batch])
587
+
588
+ input = batch
589
+
590
+ with torch.amp.autocast('cuda', dtype=torch.bfloat16):
591
+ logits = self.model(input).logits
592
+ # since bos always unmask, the first logits will not be used
593
+ logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)
594
+
595
+ if self.classifier_free_guidance > 1.:
596
+ logits, un_logits = torch.chunk(logits, 2, dim=0)
597
+ logits = un_logits + self.cfg * (logits - un_logits)
598
+ return logits[:, :batch.shape[1]]
599
+
600
    @torch.no_grad()
    def _eval_target_nll_mc(self, prefix, target):
        """Monte-Carlo estimate of the target's negative log-likelihood.

        Repeatedly noises the (prefix + target) sequence with
        `_forward_process`, scores the masked positions with the model, and
        averages the importance-weighted cross-entropy over ~`mc_num` samples
        processed in chunks of `batch_size`. Which region gets noised depends
        on `log_type`: 'ftb' noises only the target, 'btf' only the prefix,
        'union' the whole sequence.
        """
        if prefix is None:
            seq = target[None, :]
        else:
            seq = torch.concatenate([prefix, target])[None, :]
        # One identical row per batch slot; each row gets independent noise.
        seq = seq.repeat((self.batch_size, 1)).to(self.device)

        # Prompt positions (for CFG inside get_logits): prefix for 'ftb',
        # otherwise the target region.
        if self.log_type == 'ftb':
            prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
        else:
            prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)

        loss_acc = []
        for _ in range(max(self.mc_num // self.batch_size, 1)):
            perturbed_seq = seq.clone()
            perturbed_seq_, p_mask = self._forward_process(seq)
            # Keep the noise only in the region being scored.
            if self.log_type == 'ftb':
                perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
            elif self.log_type == 'btf':
                perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
            elif self.log_type == 'union':
                perturbed_seq = perturbed_seq_
            else:
                raise NotImplementedError(self.log_type)

            mask_indices = perturbed_seq == self.mask_token_id
            logits = self.get_logits(perturbed_seq, prompt_index)
            # Importance-weight each masked token's CE by its mask probability.
            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
            loss = loss.sum() / self.batch_size
            loss_acc.append(loss.item())

        return sum(loss_acc) / len(loss_acc)
635
+
636
    @torch.no_grad()
    def _eval_target_nll_ar(self, prefix, target):
        """Exact autoregressive-style NLL via progressive unmasking.

        Builds one sequence copy per scored position ('ftb' scores the target
        given the prefix, 'btf' the prefix given the target). Copy i masks a
        triangular suffix ('ar_ftb', left-to-right) or prefix ('ar_btf',
        right-to-left) of the scored region, so each copy reveals exactly one
        new position; the losses at those positions sum to the sequence NLL.
        """
        prefix, target = prefix.unsqueeze(0), target.unsqueeze(0)  # 1*l1, 1*l2
        assert self.log_type in ['ftb', 'btf']
        assert self.nll_type in ['ar_ftb', 'ar_btf']

        # Prompt positions for CFG inside get_logits.
        if self.log_type == 'ftb':
            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
        else:
            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]

        # One copy of the scored region per position in it.
        if self.log_type == 'ftb':
            perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous()  # l2*l2
        else:
            perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous()  # l1*l1

        # Triangular mask pattern: row i hides positions i.. (triu) or ..i (tril).
        mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
        if self.nll_type == 'ar_ftb':
            mask_index = torch.triu(mask_index)
        else:
            mask_index = torch.tril(mask_index)
        perturbed_[mask_index] = self.mask_token_id
        # Re-attach the conditioning side unmodified.
        if self.log_type == 'ftb':
            perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
        else:
            perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)

        # Score all copies in batch_size chunks, accumulating logits on CPU.
        logits_ = []
        num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
        for i in range(num):
            end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
            perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
            perturbed_seq_ = perturbed_seq_.to(self.device)
            if len(perturbed_seq_.shape) == 1:
                perturbed_seq_ = perturbed_seq_.unsqueeze(0)
            logits = self.get_logits(perturbed_seq_, prompt_index)
            logits_.append(logits.cpu())
        logits = torch.cat(logits_, dim=0)

        # Narrow the triangular mask to its diagonal: per row, only the single
        # newly-revealed position contributes to the loss.
        temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
        if self.nll_type == 'ar_ftb':
            temp_index = torch.triu(temp_index, diagonal=1)
        else:
            temp_index = torch.tril(temp_index, diagonal=-1)
        mask_index[temp_index] = False
        # Pad with False over the conditioning side to index the full logits.
        if self.log_type == 'ftb':
            logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
        else:
            logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)

        if self.log_type == 'ftb':
            loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
        else:
            loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
        return loss
691
+
692
+ def _encode_pair(self, context, continuation):
693
+ if self.add_bos_token:
694
+ context = self.tokenizer.bos_token + context
695
+
696
+ n_spaces = len(context) - len(context.rstrip())
697
+ if n_spaces > 0:
698
+ continuation = context[-n_spaces:] + continuation
699
+ context = context[:-n_spaces]
700
+
701
+ whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id]
702
+ context_enc = self.tokenizer.encode(context)
703
+
704
+ context_enc_len = len(context_enc)
705
+ continuation_enc = whole_enc[context_enc_len:]
706
+
707
+ # by default truncate on the left
708
+ cutoff_length = max(len(whole_enc) - self.max_length, 0)
709
+ if cutoff_length > 0:
710
+ eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side")
711
+ context_remain = context_enc_len-cutoff_length
712
+ if context_remain > 0:
713
+ context_enc = context_enc[-context_remain:]
714
+ else:
715
+ eval_logger.warning(f"All context (prompt) is truncated.")
716
+ context_enc = ""
717
+ continuation_enc = whole_enc[-self.max_length:]
718
+ return context_enc, continuation_enc
719
+
720
+ def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
721
+ def _tokenize(e):
722
+ prefix, target = self._encode_pair(e["prefix"], e["target"])
723
+ return {
724
+ "prefix_text": e["prefix"],
725
+ "target_text": e["target"],
726
+ "prefix": prefix,
727
+ "target": target,
728
+ }
729
+
730
+ ds = []
731
+ ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
732
+ ds = Dataset.from_list(ds)
733
+ print(ds[0])
734
+ ds = ds.map(_tokenize)
735
+ ds = ds.with_format("torch")
736
+
737
+ out = []
738
+ with torch.no_grad():
739
+ for elem in tqdm(ds, desc="Computing likelihood..."):
740
+ prefix = elem["prefix"]
741
+ target = elem["target"]
742
+ # likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
743
+ if self.nll_type == 'mc':
744
+ ll = -self._eval_target_nll_mc(prefix, target)
745
+ if self.log_type == 'union':
746
+ ll = ll / (len(target) + len(prefix))
747
+ elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
748
+ ll = -self._eval_target_nll_ar(prefix, target)
749
+ else:
750
+ raise NotImplementedError(self.nll_type)
751
+
752
+ # TODO: greedy decoding
753
+ is_target_greedy_dec = False
754
+
755
+ out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
756
+ return out
757
+
758
    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        """Rolling (full-sequence) log-likelihood is not supported by this adapter."""
        raise NotImplementedError
760
+
761
+
762
# Entry point: seed all RNGs for reproducibility, then hand control to the
# lm-eval-harness CLI (which instantiates `dream_lora` via --model_args).
if __name__ == "__main__":
    set_seed(1234)
    cli_evaluate()
eval_dream_d2f_vllm.sh ADDED
@@ -0,0 +1,135 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tasks="gsm8k_cot mbpp minerva_math"
2
+ nshots="8 3 4"
3
+ lengths="256 256 256"
4
+ temperatures="0 0 0"
5
+ limits="10000 10000 10000"
6
+ block_sizes="32 48 64"
7
+ block_add_thresholds="0.1 0.1 0.1"
8
+ decoded_token_thresholds="0.95 0.95 0.95"
9
+ skip_thresholds="0.9 0.9 0.9"
10
+ top_ps="none none none"
11
+ dtypes="bfloat16 bfloat16 bfloat16"
12
+ sampling_strategies="default default default"
13
+
14
+ humaneval_nshots="0"
15
+ humaneval_lengths="256"
16
+ humaneval_temperatures="0"
17
+ humaneval_limits="10000"
18
+ humaneval_diffusion_steps="256"
19
+ humaneval_block_sizes="32"
20
+ humaneval_block_add_thresholds="0.9"
21
+ humaneval_decoded_token_thresholds="0.95"
22
+ humaneval_skip_thresholds="0.95"
23
+ humaneval_top_ps="none"
24
+ humaneval_dtypes="bfloat16"
25
+ humaneval_sampling_strategies="default"
26
+
27
+ base_model=Dream-org/Dream-v0-Base-7B
28
+
29
+ lora_models=(
30
+ "SJTU-Deng-Lab/D2F_Dream_Base_7B_Lora"
31
+ )
32
+
33
+ read -ra TASKS_ARRAY <<< "$tasks"
34
+ read -ra NSHOTS_ARRAY <<< "$nshots"
35
+ read -ra LENGTH_ARRAY <<< "$lengths"
36
+ read -ra TEMP_ARRAY <<< "$temperatures"
37
+ read -ra LIMITS_ARRAY <<< "$limits"
38
+ read -ra BLOCK_SIZES_ARRAY <<< "$block_sizes"
39
+ read -ra BLOCK_ADD_THRESHOLDS_ARRAY <<< "$block_add_thresholds"
40
+ read -ra DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$decoded_token_thresholds"
41
+ read -ra SKIP_THRESHOLDS_ARRAY <<< "$skip_thresholds"
42
+ read -ra TOP_PS_ARRAY <<< "$top_ps"
43
+ read -ra DTYPES_ARRAY <<< "$dtypes"
44
+ read -ra SAMPLING_STRATEGIES_ARRAY <<< "$sampling_strategies"
45
+
46
+ read -ra HUMANEVAL_NSHOTS_ARRAY <<< "$humaneval_nshots"
47
+ read -ra HUMANEVAL_LENGTHS_ARRAY <<< "$humaneval_lengths"
48
+ read -ra HUMANEVAL_TEMP_ARRAY <<< "$humaneval_temperatures"
49
+ read -ra HUMANEVAL_LIMITS_ARRAY <<< "$humaneval_limits"
50
+ read -ra HUMANEVAL_DIFFUSION_STEPS_ARRAY <<< "$humaneval_diffusion_steps"
51
+ read -ra HUMANEVAL_BLOCK_SIZES_ARRAY <<< "$humaneval_block_sizes"
52
+ read -ra HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY <<< "$humaneval_block_add_thresholds"
53
+ read -ra HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$humaneval_decoded_token_thresholds"
54
+ read -ra HUMANEVAL_SKIP_THRESHOLDS_ARRAY <<< "$humaneval_skip_thresholds"
55
+ read -ra HUMANEVAL_TOP_PS_ARRAY <<< "$humaneval_top_ps"
56
+ read -ra HUMANEVAL_DTYPES_ARRAY <<< "$humaneval_dtypes"
57
+ read -ra HUMANEVAL_SAMPLING_STRATEGIES_ARRAY <<< "$humaneval_sampling_strategies"
58
+
59
# Sanity check: the general-task arrays must all be parallel (one entry per run).
array_length=${#TASKS_ARRAY[@]}
for count in "${#NSHOTS_ARRAY[@]}" "${#LENGTH_ARRAY[@]}" "${#TEMP_ARRAY[@]}" \
             "${#LIMITS_ARRAY[@]}" "${#BLOCK_SIZES_ARRAY[@]}" \
             "${#BLOCK_ADD_THRESHOLDS_ARRAY[@]}" "${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]}" \
             "${#SKIP_THRESHOLDS_ARRAY[@]}" "${#TOP_PS_ARRAY[@]}" \
             "${#SAMPLING_STRATEGIES_ARRAY[@]}" "${#DTYPES_ARRAY[@]}"; do
    if [[ $count -ne $array_length ]]; then
        echo "Error: All configuration arrays must have the same length!"
        exit 1
    fi
done

# Same parallelism check for the HumanEval-specific arrays.
humaneval_array_length=${#HUMANEVAL_NSHOTS_ARRAY[@]}
for count in "${#HUMANEVAL_LENGTHS_ARRAY[@]}" "${#HUMANEVAL_TEMP_ARRAY[@]}" \
             "${#HUMANEVAL_LIMITS_ARRAY[@]}" "${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]}" \
             "${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]}" "${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]}" \
             "${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]}" "${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]}" \
             "${#HUMANEVAL_TOP_PS_ARRAY[@]}" "${#HUMANEVAL_DTYPES_ARRAY[@]}" \
             "${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]}"; do
    if [[ $count -ne $humaneval_array_length ]]; then
        echo "Error: All HumanEval configuration arrays must have the same length!"
        exit 1
    fi
done
90
+
91
# Allow lm-eval-harness to execute generated code (required by humaneval).
export HF_ALLOW_CODE_EVAL=1
for lora_model in "${lora_models[@]}"; do
    lora_model_name="$lora_model"
    echo "===================================================================="
    echo "Evaluating LoRA model: $lora_model_name"
    echo "===================================================================="

    # --- HumanEval sweep (explicit diffusion-steps axis, no --limit) ---
    for i in "${!HUMANEVAL_NSHOTS_ARRAY[@]}"; do
        output_path="evals_dream${lora_model_name}/humaneval-ns${HUMANEVAL_NSHOTS_ARRAY[$i]}-len${HUMANEVAL_LENGTHS_ARRAY[$i]}-temp${HUMANEVAL_TEMP_ARRAY[$i]}-limit${HUMANEVAL_LIMITS_ARRAY[$i]}-diffsteps${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}-block${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}-thresh${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}-topp${HUMANEVAL_TOP_PS_ARRAY[$i]}-dtype${HUMANEVAL_DTYPES_ARRAY[$i]}-sampling${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}"
        echo "Running HumanEval evaluation $((i+1))/${humaneval_array_length} for $lora_model_name..."
        # "none" means: omit top_p entirely so the model default applies.
        if [[ "${HUMANEVAL_TOP_PS_ARRAY[$i]}" == "none" ]]; then
            humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        else
            humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},top_p=${HUMANEVAL_TOP_PS_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        fi
        # Fix: expansions are now quoted so args/paths survive word splitting
        # and globbing (previously $humaneval_model_args / $output_path were bare).
        CUDA_VISIBLE_DEVICES=5 accelerate launch --main_process_port 29520 --num_processes 1 eval_dream_d2f_vllm.py --model dream_lora \
            --model_args "$humaneval_model_args" \
            --tasks humaneval \
            --num_fewshot "${HUMANEVAL_NSHOTS_ARRAY[$i]}" \
            --batch_size 1 \
            --output_path "$output_path" \
            --log_samples \
            --confirm_run_unsafe_code
    done

    # --- General task sweep (diffusion_steps deliberately tied to length) ---
    for i in "${!TASKS_ARRAY[@]}"; do
        output_path="evals_dream${lora_model_name}/${TASKS_ARRAY[$i]}-ns${NSHOTS_ARRAY[$i]}-len${LENGTH_ARRAY[$i]}-temp${TEMP_ARRAY[$i]}-limit${LIMITS_ARRAY[$i]}-diffsteps${LENGTH_ARRAY[$i]}-block${BLOCK_SIZES_ARRAY[$i]}-thresh${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${SKIP_THRESHOLDS_ARRAY[$i]}-topp${TOP_PS_ARRAY[$i]}-dtype${DTYPES_ARRAY[$i]}-sampling${SAMPLING_STRATEGIES_ARRAY[$i]}"
        if [[ "${TOP_PS_ARRAY[$i]}" == "none" ]]; then
            model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        else
            model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},top_p=${TOP_PS_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        fi
        CUDA_VISIBLE_DEVICES=5 accelerate launch --main_process_port 29520 --num_processes 1 eval_dream_d2f_vllm.py --model dream_lora \
            --model_args "$model_args" \
            --tasks "${TASKS_ARRAY[$i]}" \
            --limit "${LIMITS_ARRAY[$i]}" \
            --num_fewshot "${NSHOTS_ARRAY[$i]}" \
            --batch_size 1 \
            --output_path "$output_path" \
            --log_samples \
            --confirm_run_unsafe_code
    done
done

echo "All evaluations completed!"
eval_llada.py ADDED
@@ -0,0 +1,1198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import gc
3
+ import json
4
+ import time # Add time module
5
+ from datetime import timedelta
6
+ from typing import List, Optional, Tuple, Type, TypeVar, Union, Dict
7
+ import torch
8
+ import torch.nn.functional as F
9
+ import torch.distributions as dists
10
+ import transformers
11
+ from transformers import AutoTokenizer
12
+ from peft import LoraConfig, get_peft_model
13
+ from accelerate import (
14
+ Accelerator,
15
+ InitProcessGroupKwargs,
16
+ )
17
+ from datasets import Dataset
18
+ from packaging import version
19
+ from tqdm import tqdm
20
+ from peft import PeftConfig, PeftModel
21
+ import numpy as np # Add numpy import
22
+ import os
23
+ import jinja2
24
+
25
+ # Import LLaDA model related modules
26
+ from model_cache.llada.modeling_llada import LLaDAModelLM
27
+ from model_cache.llada.configuration_llada import LLaDAConfig
28
+
29
+ from lm_eval import utils
30
+ from lm_eval.api.instance import Instance
31
+ from lm_eval.api.model import TemplateLM
32
+ from lm_eval.api.registry import register_model
33
+ from lm_eval.models.utils import get_dtype
34
+ from lm_eval.__main__ import cli_evaluate
35
+
36
eval_logger = logging.getLogger(__name__)
T = TypeVar("T", bound="TemplateLM")

import random


def set_seed(seed):
    """Seed the Python, NumPy and PyTorch RNGs and force deterministic cuDNN."""
    np.random.seed(seed)
    random.seed(seed)
    torch.manual_seed(seed)

    # Trade kernel-selection speed for reproducible convolution results.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
47
+
48
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    """Build the full block-causal attention mask for a whole sequence.

    A query position may attend to a key position iff the key's block index
    is <= the query's block index, where the prompt is block 0 and every
    subsequent run of ``block_size`` tokens forms the next block. That gives
    full attention inside a block and causal attention across blocks.

    Args:
        prompt_length: Length of the prompt (the first, irregular block).
        max_length: Total sequence length covered by the mask.
        block_size: Size of each regular block after the prompt.
        device: Device to allocate the mask on.
        dtype: Float dtype for the mask (defaults to bfloat16).

    Returns:
        Tensor of shape [1, 1, max_length, max_length] with 0 where attention
        is allowed and -inf where it is blocked.
    """
    if dtype is None:
        dtype = torch.bfloat16

    positions = torch.arange(max_length, device=device)
    # Block index per position: prompt tokens are block 0, then one block per
    # `block_size` tokens (negative floor-div results are discarded by where()).
    block_idx = torch.where(
        positions < prompt_length,
        torch.zeros_like(positions),
        (positions - prompt_length) // block_size + 1,
    )
    # Row q (query) may see column k (key) iff block_idx[k] <= block_idx[q].
    allowed = block_idx.unsqueeze(1) >= block_idx.unsqueeze(0)

    attention_mask = torch.full((max_length, max_length), -torch.inf, device=device, dtype=dtype)
    attention_mask[allowed] = 0
    return attention_mask.unsqueeze(0).unsqueeze(0)
94
+
95
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    """Slice the precomputed full mask for one incremental forward pass.

    The returned mask has one row per current input position and columns for
    the cached prefix (absolute positions 0..cache_length) followed by the
    current input span itself.

    Args:
        full_mask: Complete attention mask [1, 1, max_length, max_length].
        start_pos: Absolute starting position of the current input.
        input_length: Number of positions being fed this step.
        cache_length: Number of positions already held in the KV cache.

    Returns:
        Mask of shape [1, 1, input_length, cache_length + input_length].
    """
    end_pos = start_pos + input_length
    # Rows for the positions processed this step.
    rows = full_mask[:, :, start_pos:end_pos, :]
    # NOTE(review): positions in (cache_length, start_pos), if any, are not
    # represented in the output -- presumably callers keep start_pos aligned
    # with the cache; verify at the call site.
    cached_cols = rows[:, :, :, :cache_length]
    current_cols = rows[:, :, :, start_pos:end_pos]
    return torch.cat([cached_cols, current_cols], dim=-1)
123
+
124
def build_custom_float_attention_mask(input_ids, prompt_length, block_size, device=None, dtype=None):
    """Build a batched block-causal float attention mask.

    Same block rule as ``create_full_block_attention_mask`` but with a
    per-sequence prompt length: every position may attend to any position
    whose block index (prompt = block 0, then one block per ``block_size``
    tokens) is <= its own.

    Args:
        input_ids: Token ids of shape [B, seq_len] (only the shape is used).
        prompt_length: Indexable per-sequence prompt lengths (one per batch row).
        block_size: Size of each regular block after the prompt.
        device: Device to create the mask on.
        dtype: Float dtype for the mask (defaults to float32).

    Returns:
        Mask of shape [B, 1, seq_len, seq_len]: 0.0 allowed, -inf blocked.
    """
    B, seq_len = input_ids.shape
    if dtype is None:
        dtype = torch.float32
    attn_mask = torch.full((B, 1, seq_len, seq_len), float('-inf'), dtype=dtype, device=device)

    positions = torch.arange(seq_len, device=device)
    for i in range(B):
        pl = prompt_length[i]
        # Block index per position for this sequence.
        blk = torch.where(
            positions < pl,
            torch.zeros_like(positions),
            (positions - pl) // block_size + 1,
        )
        allowed = blk.unsqueeze(1) >= blk.unsqueeze(0)
        attn_mask[i, 0][allowed] = 0.0

    return attn_mask
167
+
168
def top_p_logits(logits, top_p=None):
    """Mask logits outside the nucleus (top-p) set to the dtype minimum.

    Keeps the smallest set of highest-probability tokens whose cumulative
    softmax probability reaches ``top_p`` (the token that crosses the
    threshold is kept).
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cumulative = F.softmax(sorted_logits, dim=-1).cumsum(dim=-1)
    drop_sorted = cumulative > top_p
    # Shift right so the first token crossing the threshold is retained;
    # roll wraps the last flag to slot 0, which we then clear.
    drop_sorted = drop_sorted.roll(1, dims=-1)
    drop_sorted[..., 0] = False

    drop = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, sorted_indices, drop_sorted)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)
180
+
181
def top_k_logits(logits, top_k=None):
    """Mask all but the ``top_k`` highest logits to the dtype minimum."""
    k = min(top_k, logits.size(-1))  # never ask for more than the vocab size
    kth_best = torch.topk(logits, k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_best, torch.finfo(logits.dtype).min)
187
+
188
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Pick one token per position from ``logits`` and score its confidence.

    With ``temperature > 0`` the token is sampled from the (optionally
    top-p / top-k filtered) categorical distribution; otherwise it is the
    argmax. The returned ``confidence`` defaults to the chosen token's
    probability but can be redefined by the flags below.

    Args:
        logits: Logits of shape [..., vocab]. ``margin_confidence`` assumes a
            2D [batch, vocab] input (it indexes ``sorted_probs[:, 0]``).
        temperature: Softmax temperature; <= 0 means greedy decoding.
        top_p: Nucleus filtering threshold, applied when < 1.
        top_k: Top-k filtering, applied when not None.
        margin_confidence: Report (top1 - top2) probability as confidence.
        neg_entropy: Report negative entropy (sum p*log p) as confidence.

    Returns:
        (confidence, x0, initial_confidence): the possibly-redefined
        confidence, the chosen token ids, and the raw probability of the
        chosen token.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fix: was a bare `except:`, which would also swallow
            # KeyboardInterrupt/SystemExit. Fall back to greedy selection
            # when sampling fails (e.g. invalid probabilities).
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)

    # Keep the raw probability of the chosen token before any redefinition.
    confidence = initial_confidence.clone()

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        top1_probs = sorted_probs[:, 0]
        top2_probs = sorted_probs[:, 1]
        # Margin between the two most likely tokens.
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10
        log_probs = torch.log(probs + epsilon)
        # Negative entropy: closer to 0 means a more peaked distribution.
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0, initial_confidence
223
+
224
+ @register_model("dream_lora")
225
+ class DreamLoRA(TemplateLM):
226
    def __init__(
        self,
        pretrained: Union[str, transformers.PreTrainedModel],
        lora_path: str,
        batch_size: Optional[Union[int, str]] = 1,
        device: Optional[str] = "cuda",
        dtype: Optional[Union[str, torch.dtype]] = "auto",
        max_new_tokens: Optional[int] = 128,
        max_length: Optional[int] = 4096,
        add_bos_token: Optional[bool] = False,
        nll_type: Optional[str] = "mc",
        log_type: Optional[str] = "ftb",
        mc_num: Optional[int] = 128,
        classifier_free_guidance: Optional[float] = 1.0,
        sampling_eps: Optional[float] = 1e-3,
        diffusion_steps: Optional[int] = 128,
        trust_remote_code: Optional[bool] = True,
        parallelize: Optional[bool] = False,
        autogptq: Optional[Union[bool, str]] = False,
        temperature: Optional[float] = 0.2,
        top_p: Optional[float] = None,
        top_k: Optional[float] = None,
        alg: Optional[str] = "entropy",
        alg_temp: Optional[float] = 0.0,
        escape_until: Optional[bool] = False,
        block_size: Optional[int] = 4,
        mask_token_id: Optional[int] = 126336,  # presumably the LLaDA mask id -- TODO confirm
        block_add_threshold: Optional[float] = 0.5,
        decoded_token_threshold: Optional[float] = 0.9,
        skip_threshold: Optional[float] = 1.0,
        sampling_strategy: Optional[str] = "default",
        save_dir: Optional[str] = None,
        show_speed: Optional[bool] = True,
        **kwargs,
    ) -> None:
        """Build the LoRA-augmented LLaDA diffusion-LM evaluation wrapper.

        Resolves the torch device (honoring accelerate / parallelize), loads
        the base model plus LoRA adapter via ``_create_model_and_tokenizer``,
        and stores the block-diffusion generation and log-likelihood
        hyper-parameters on the instance.

        Args:
            pretrained: HF path of the base model (must be a str here).
            lora_path: Path of the PEFT/LoRA adapter to attach.
            batch_size: Per-GPU batch size (int or numeric string).
            device: Requested device; overridden under accelerate/parallelize.
            dtype: Torch dtype or "auto".
            max_new_tokens, diffusion_steps, temperature, top_p, top_k, alg,
                alg_temp: generation-time settings.
            max_length: Maximum total sequence length for generation masks.
            block_size, block_add_threshold, decoded_token_threshold,
                skip_threshold, sampling_strategy: block-diffusion controls.
            nll_type, log_type, mc_num, classifier_free_guidance,
                sampling_eps: log-likelihood estimation settings.
            escape_until: When True, `until` stop strings are not applied.
            mask_token_id: Token id used for masked (undecoded) positions.
            save_dir: Optional directory for per-sample outputs.
            show_speed: Whether to report generation speed statistics.
            **kwargs: Ignored extras (kept for CLI arg-string compatibility).
        """
        super().__init__()

        # This wrapper only supports string model paths and explicit devices.
        assert isinstance(device, str)
        assert isinstance(pretrained, str)
        assert isinstance(batch_size, (int, str))

        gpus = torch.cuda.device_count()
        # Very long NCCL timeout so multi-process runs do not die during slow evals.
        accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
        accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
        if accelerator.num_processes > 1:
            self.accelerator = accelerator

        if "npu" in accelerator.device.type:
            gpus = torch.npu.device_count()

        # Single process with no model parallelism: honor the requested device.
        if not (parallelize or accelerator.num_processes > 1):
            # use user-passed device
            device_list = set(
                ["cuda", "cpu"]
                + [f"cuda:{i}" for i in range(gpus)]
                + ["mps", "mps:0"]
                + [f"npu:{i}" for i in range(gpus)]
            )
            if device and device in device_list:
                self._device = torch.device(device)
                eval_logger.info(f"Using device '{device}'")
                if device in ("mps", "mps:0") and version.parse(
                    torch.__version__
                ) < version.parse("2.1"):
                    raise RuntimeError(
                        f"mps requires torch >= 2.1. You have {torch.__version__}"
                    )
            else:
                eval_logger.info("Device not specified")
                eval_logger.info(f"Cuda Available? {torch.cuda.is_available()}")
                self._device = (
                    torch.device("cuda")
                    if torch.cuda.is_available()
                    else torch.device("cpu")
                )
        else:  # Parallelism managed by accelerate
            if device != "cuda":
                eval_logger.info(
                    f"Using `accelerate launch` or `parallelize=True`, device '{device}' will be overridden when placing model."
                )
            # TODO: include in warning that `load_in_8bit` etc. affect this too
            self._device = (
                self.accelerator.device
                if hasattr(self, "accelerator")
                else torch.device(device)
            )

        # Normalize string batch sizes (e.g. "8") to int.
        self.batch_size_per_gpu = batch_size
        if isinstance(batch_size, str):
            self.batch_size_per_gpu = int(batch_size)

        # Block-diffusion decoding controls (consumed by generation code).
        self.lora_path = lora_path
        self.block_size = block_size
        self.block_add_threshold = block_add_threshold
        self.skip_threshold = skip_threshold
        self.sampling_strategy = sampling_strategy
        self.decoded_token_threshold = decoded_token_threshold

        # Resolved torch dtype, reused later when building attention masks.
        self.target_dtype = get_dtype(dtype)

        self._create_model_and_tokenizer(pretrained, dtype, trust_remote_code)

        if isinstance(pretrained, str):
            if gpus >= 1 or str(self.device) == "mps":
                # TODO: can remove this whole snippet except in the mps case, perhaps?
                if not (parallelize or autogptq or hasattr(self, "accelerator")):
                    # place model onto device requested manually,
                    # if not using HF Accelerate or device_map
                    # or any other option that preloads model onto device
                    try:
                        self.model.to(self.device)
                    except ValueError:
                        eval_logger.debug(
                            "Failed to place model onto specified device. This may be because the model is quantized via `bitsandbytes` or `device_map` is provided. If the desired GPU is being used, this message is safe to ignore."
                        )
            # multigpu data-parallel support when launched with accelerate
            if gpus > 1:
                if accelerator.num_processes > 1:
                    if parallelize:
                        eval_logger.warning(
                            "You are both using a HF Accelerate `device_map` (`--model_args parallelize=True`) and launching via `accelerate launch`. This will attempt to do model and data parallelism depending on the resources available."
                        )
                    elif gpus > accelerator.num_processes:
                        eval_logger.warning(
                            "WARNING: The number of total system GPUs does not match the number of spawned processes. "
                            "If you would like to use data parallelism, please launch the script "
                            "with 'accelerate launch *script*'. "
                            f"Current run will proceed with {accelerator.num_processes} devices."
                        )
                    if self.accelerator.is_local_main_process:
                        eval_logger.info(
                            f"Using {gpus} devices with data parallelism"
                        )

                    self._device = torch.device(f"{accelerator.device}")
                    self.accelerator = accelerator

                    self._rank = self.accelerator.local_process_index
                    self._world_size = self.accelerator.num_processes
                else:
                    # if we aren't launching via accelerate, ditch
                    self._rank = 0
                    self._world_size = 1
        else:
            # if a PreTrainedModel was passed into HFLM, we forgo distributed setup.
            eval_logger.warning(
                "Passed an already-initialized model through `pretrained`, assuming single-process call to evaluate() or custom distributed integration"
            )
            self._rank = 0
            self._world_size = 1

        # NOTE(review): when `pretrained` is a str and gpus <= 1, neither
        # self._rank nor self._world_size is assigned above -- verify that
        # the rank/world_size properties are never read in that configuration.
        self.max_length = max_length
        self.add_bos_token = add_bos_token
        # generation params
        self.max_new_tokens = max_new_tokens
        self.diffusion_steps = diffusion_steps
        self.temperature = temperature
        self.top_p = top_p
        self.top_k = top_k
        self.alg = alg
        self.alg_temp = alg_temp
        self.escape_until = escape_until
        self.block_size = block_size  # NOTE(review): duplicate of the assignment above
        self.mask_token_id = mask_token_id

        # loglikelihood params
        self.nll_type = nll_type
        self.log_type = log_type
        self.mc_num = mc_num
        self.classifier_free_guidance = classifier_free_guidance
        self.sampling_eps = sampling_eps

        # Add backend attribute, consistent with LLaDA.py
        self.backend = "causal"

        # Add truncation attribute, consistent with LLaDA.py
        self.truncation = False

        self.save_dir = save_dir
        self.show_speed = show_speed
410
+
411
    @property
    def batch_size(self):
        # Per-GPU batch size (normalized to int in __init__).
        return self.batch_size_per_gpu
414
+
415
    @property
    def eot_token_id(self):
        # We use EOT because end of *text* is more accurate for what we're
        # doing than end of *sentence*; maps to the tokenizer's EOS id.
        return self.tokenizer.eos_token_id
419
+
420
    @property
    def device(self):
        # Torch device resolved during __init__.
        return self._device
423
+
424
    @property
    def rank(self):
        # Process rank for distributed evaluation (set in __init__).
        return self._rank
427
+
428
    @property
    def world_size(self):
        # Number of data-parallel processes (set in __init__).
        return self._world_size
431
+
432
+ def _create_model_and_tokenizer(self, pretrained, dtype, trust_remote_code):
433
+ # Get correct data type
434
+ target_dtype = get_dtype(dtype)
435
+
436
+ # Load LLaDA model and configuration
437
+ config = LLaDAConfig.from_pretrained(pretrained)
438
+ self.model = LLaDAModelLM.from_pretrained(
439
+ pretrained,
440
+ config=config,
441
+ torch_dtype=target_dtype,
442
+ trust_remote_code=False,
443
+ ).eval()
444
+
445
+ # Load LoRA configuration and model
446
+ peft_config = PeftConfig.from_pretrained(self.lora_path)
447
+ self.model = PeftModel.from_pretrained(self.model, self.lora_path)
448
+
449
+ # Convert data type only when target_dtype is not None and not "auto"
450
+ if target_dtype is not None and target_dtype != "auto":
451
+ self.model = self.model.to(target_dtype)
452
+
453
+ # Move to specified device
454
+ self.model = self.model.to(self.device)
455
+
456
+ # Load tokenizer
457
+ self.tokenizer = AutoTokenizer.from_pretrained(
458
+ pretrained, trust_remote_code=trust_remote_code
459
+ )
460
+
461
+ def tok_encode(
462
+ self, string: str, left_truncate_len=None, add_special_tokens=None
463
+ ) -> List[int]:
464
+ """ """
465
+ # default for None - empty dict, use predefined tokenizer param
466
+ # used for all models except for CausalLM or predefined value
467
+ special_tokens_kwargs = {}
468
+
469
+ # by default for CausalLM - false or self.add_bos_token is set
470
+ if add_special_tokens is None:
471
+ if self.backend == "causal":
472
+ special_tokens_kwargs = {
473
+ "add_special_tokens": False or self.add_bos_token
474
+ }
475
+ # otherwise the method explicitly defines the value
476
+ else:
477
+ special_tokens_kwargs = {"add_special_tokens": add_special_tokens}
478
+
479
+ encoding = self.tokenizer.encode(string, **special_tokens_kwargs)
480
+
481
+ # left-truncate the encoded context to be at most `left_truncate_len` tokens long
482
+ if left_truncate_len:
483
+ encoding = encoding[-left_truncate_len:]
484
+ return encoding
485
+
486
+ def tok_batch_encode(
487
+ self,
488
+ strings: List[str],
489
+ padding_side: str = "left",
490
+ left_truncate_len: int = None,
491
+ truncation: bool = False,
492
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
493
+ # encode a batch of strings. converts to tensors and pads automatically, unlike tok_encode.
494
+ old_padding_side = self.tokenizer.padding_side
495
+ self.tokenizer.padding_side = padding_side
496
+
497
+ add_special_tokens = {}
498
+ if self.backend == "causal":
499
+ add_special_tokens = {"add_special_tokens": False or self.add_bos_token}
500
+
501
+ encoding = self.tokenizer(
502
+ strings,
503
+ truncation=truncation,
504
+ padding="longest",
505
+ return_tensors="pt",
506
+ **add_special_tokens,
507
+ )
508
+ if left_truncate_len:
509
+ original_lengths = encoding["input_ids"].size(1)
510
+ if original_lengths > left_truncate_len:
511
+ eval_logger.warn(
512
+ f"Left truncation applied. Original sequence length was {original_lengths}, "
513
+ f"truncating to last {left_truncate_len} tokens. Some content will be lost.",
514
+ )
515
+ encoding["input_ids"] = encoding["input_ids"][:, -left_truncate_len:]
516
+ encoding["attention_mask"] = encoding["attention_mask"][
517
+ :, -left_truncate_len:
518
+ ]
519
+ self.tokenizer.padding_side = old_padding_side
520
+
521
+ return encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device)
522
+
523
    def tok_decode(self, tokens, skip_special_tokens=True):
        """Decode token ids back to text via the underlying tokenizer."""
        return self.tokenizer.decode(tokens, skip_special_tokens=skip_special_tokens)
525
+
526
+
527
+
528
+ def _count_tokens_after_truncation(self, response_text: str, until_terms: List[str] = None) -> int:
529
+ """
530
+ Unified token counting function: calculates the number of non-126081 tokens after truncating the response.
531
+ """
532
+ # Apply truncation based on until parameters
533
+ truncated_text = response_text
534
+ if until_terms and not self.escape_until:
535
+ for term in until_terms:
536
+ if len(term) > 0:
537
+ truncated_text = truncated_text.split(term)[0]
538
+
539
+ # Re-tokenize processed answer and count non-126081 tokens
540
+ generated_answer_ids = torch.tensor(self.tokenizer(truncated_text)["input_ids"])
541
+ return int((generated_answer_ids != 126081).sum())
542
+
543
+ @classmethod
544
+ def create_from_arg_string(
545
+ cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
546
+ ) -> T:
547
+ """
548
+ Creates an instance of the LM class using the given argument string and additional config.
549
+
550
+ Parameters:
551
+ - arg_string: A string containing arguments in the format key1=value1,key2=value2.
552
+ - additional_config: Optional dictionary containing additional configuration parameters.
553
+
554
+ Returns:
555
+ - Instance of the LM class.
556
+ """
557
+ additional_config = {} if additional_config is None else additional_config
558
+ args = utils.simple_parse_args_string(arg_string)
559
+ args2 = {k: v for k, v in additional_config.items() if v is not None}
560
+ return cls(**args, **args2)
561
+
562
+ def apply_chat_template(
563
+ self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
564
+ ) -> str:
565
+ """
566
+ Method to apply a chat template to a list of chat history between user and model.
567
+ """
568
+ try:
569
+ chat_templated = self.tokenizer.apply_chat_template(
570
+ chat_history,
571
+ tokenize=False,
572
+ add_generation_prompt=add_generation_prompt,
573
+ continue_final_message=not add_generation_prompt,
574
+ )
575
+ except jinja2.exceptions.TemplateError:
576
+ eval_logger.warning(
577
+ "Failed to apply chat template. removing the system role in chat history."
578
+ )
579
+ chat_history = [msg for msg in chat_history if msg["role"] != "system"]
580
+ chat_templated = self.tokenizer.apply_chat_template(
581
+ chat_history,
582
+ tokenize=False,
583
+ add_generation_prompt=add_generation_prompt,
584
+ continue_final_message=not add_generation_prompt,
585
+ )
586
+
587
+ return chat_templated
588
+
589
    @property
    def tokenizer_name(self) -> str:
        # Filesystem-safe tokenizer identifier: path slashes replaced by "__".
        return self.tokenizer.name_or_path.replace("/", "__")
592
+
593
    def _generate_block_single(self, prompt):
        """
        Generate a completion for one prompt using parallel block generation
        with a KV cache and pre-generated attention masks.

        Masked blocks are appended progressively; each block moves through the
        states 'active' -> 'to_cache' -> 'in_cache'. Fully decoded prefix
        blocks are promoted into the model KV cache so later forward passes
        only re-run the still-active suffix.

        Parameters:
        - prompt: token tensor of shape (1, prompt_len) on any device.

        Returns: generated_sequence (List[int]) - List of generated token IDs
        (everything after the prompt; may still contain special ids).
        """
        self.model.eval()

        mask_id = self.mask_token_id
        block_size = self.block_size
        block_add_threshold = self.block_add_threshold
        skip_threshold = self.skip_threshold

        # Pre-generate the full attention mask, using the model's data type
        prompt_length = prompt.shape[1]
        full_attention_mask = create_full_block_attention_mask(
            prompt_length=prompt_length,
            max_length=self.max_length,
            block_size=block_size,
            device=self.device,
            dtype=self.target_dtype if self.target_dtype is not None and self.target_dtype != "auto" else torch.bfloat16
        )

        with torch.inference_mode():
            # Initialization
            x_t = prompt.to(self.device)

            # Track block states - states can be: 'active', 'to_cache', 'in_cache'
            # Added 'is_complete' field to indicate whether it's a complete state (True) or incomplete state (False)
            block_states = {
                0: {
                    'start_pos': 0,
                    'end_pos': prompt.shape[1],
                    'mask_count': 0,
                    'total_masks': prompt.shape[1],
                    'state': 'to_cache',  # Prompt is immediately ready for caching
                    'is_complete': True,  # Prompt is always in a complete state
                },
            }

            # Initialize cache
            past_key_values = None

            current_blocks = 0  # Number of active blocks
            step = 0
            eos_detected = False  # EOS detection flag
            cache_length = 0  # Number of positions already committed to the KV cache
            while current_blocks >= 0:
                step += 1

                # Check if a new block needs to be added (until the max_new_tokens
                # budget is reached or an EOS token has been sampled)
                if len(block_states)-1 < (self.max_new_tokens // block_size) and not eos_detected:
                    last_block_id = len(block_states) - 1
                    current_progress = (block_states[last_block_id]['total_masks'] -
                                        block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks']
                    if current_progress >= block_add_threshold:
                        # Add new block of all-mask tokens at the end of the sequence
                        new_block_id = len(block_states)
                        new_start_pos = x_t.shape[1]
                        x_t = torch.cat([x_t, torch.tensor([[mask_id] * block_size]).to(self.device)], dim=1)

                        block_states[new_block_id] = {
                            'start_pos': new_start_pos,
                            'end_pos': new_start_pos + block_size,
                            'mask_count': block_size,
                            'total_masks': block_size,
                            'state': 'active',
                            'is_complete': False,  # New block defaults to an incomplete state
                        }
                        current_blocks += 1

                # At the beginning of each loop, update the block's complete/incomplete states
                self._update_block_completion_states(block_states, self.decoded_token_threshold)
                # Check if there are still mask tokens
                mask_index = (x_t == mask_id)
                if mask_index.sum() == 0 and current_blocks == 0:
                    break

                # Determine which blocks need to be added to the cache
                blocks_to_cache = [bid for bid, state in block_states.items()
                                   if state['state'] == 'to_cache']

                # Determine the part to be processed
                update_kvcache = 0
                if blocks_to_cache:
                    # Find the earliest block to be cached
                    earliest_block_id = min(blocks_to_cache)
                    earliest_pos = block_states[earliest_block_id]['start_pos']

                    # Find the latest block to be cached
                    latest_block_id = max(blocks_to_cache)
                    latest_pos = block_states[latest_block_id]['end_pos']

                    # Update the cache for all blocks within this range
                    update_kvcache = latest_pos - earliest_pos

                # Create input sequence for forward pass
                process_start_pos = cache_length

                if update_kvcache > 0:
                    # Need to update cache - use completed blocks
                    earliest_block_to_cache = min(blocks_to_cache)
                    input_seq = x_t[:, block_states[earliest_block_to_cache]['start_pos']:]
                    process_start_pos = block_states[earliest_block_to_cache]['start_pos']
                else:
                    # Only process active blocks
                    active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active']
                    if active_blocks:
                        # Get all active blocks after caching
                        earliest_active_after_cache = float('inf')
                        for bid in active_blocks:
                            if block_states[bid]['start_pos'] >= cache_length:
                                earliest_active_after_cache = min(earliest_active_after_cache, block_states[bid]['start_pos'])

                        if earliest_active_after_cache < float('inf'):
                            input_seq = x_t[:, earliest_active_after_cache:]
                            process_start_pos = earliest_active_after_cache
                        else:
                            # No active blocks after caching, this should not happen
                            input_seq = x_t[:, cache_length:]
                            # If cache length is already equal to or exceeds sequence length, exit
                            if cache_length >= x_t.shape[1]:
                                print(f"Cache length ({cache_length}) >= sequence length ({x_t.shape[1]}) at step {step}. Exiting generation loop.")
                                raise Exception("Cache length >= sequence length")
                    else:
                        # No active blocks, but blocks might need to be cached in the next iteration
                        break

                # Check if input_seq is empty
                if input_seq.shape[1] == 0:
                    print(f"Warning: input_seq is empty at step {step}. Breaking generation loop.")
                    raise Exception("input_seq is empty")

                # Extract the attention mask for the current input from the pre-generated full mask
                input_length = input_seq.shape[1]
                attention_mask = extract_attention_mask(
                    full_mask=full_attention_mask,
                    start_pos=process_start_pos,
                    input_length=input_length,
                    cache_length=cache_length
                )

                outputs = self.model(
                    input_seq,
                    attention_bias=attention_mask,
                    past_key_values=past_key_values,
                    use_cache=True,
                    update_kvcache=update_kvcache+cache_length,
                )

                # Get current logits - LLaDA model directly uses logits, no shifting needed
                logits = outputs.logits

                # Update cache if needed
                if update_kvcache > 0:
                    # Update cache
                    past_key_values = outputs.past_key_values

                    # Mark blocks as cached
                    for block_id in blocks_to_cache:
                        block_states[block_id]['state'] = 'in_cache'

                # Process mask tokens for each active block
                blocks_to_deactivate = []

                for block_id in sorted(block_states.keys()):
                    if block_states[block_id]['state'] != 'active':
                        continue

                    # Get mask positions for this block
                    block_start = block_states[block_id]['start_pos']
                    block_end = block_states[block_id]['end_pos']
                    block_mask_index = mask_index.clone()
                    block_mask_index[:, :block_start] = False
                    block_mask_index[:, block_end:] = False

                    # Skip if the current block has no masks
                    if block_mask_index.sum() == 0:
                        blocks_to_deactivate.append(block_id)
                        continue

                    # Calculate relative position of logits
                    logit_offset = block_start - process_start_pos
                    block_rel_positions = torch.where(block_mask_index[0, block_start:block_end])[0]

                    if block_rel_positions.size(0) > 0:
                        # Get logits for masked positions
                        block_mask_logits = logits[:, logit_offset + block_rel_positions, :]

                        # Sample tokens
                        confidence, x0, initial_confidence = sample_tokens(
                            block_mask_logits.squeeze(0),
                            self.temperature,
                            top_p=self.top_p,
                            top_k=self.top_k,
                            neg_entropy=(self.sampling_strategy == "neg_entropy"),
                            margin_confidence=(self.sampling_strategy == "margin_confidence")
                        )

                        # Use different sampling strategies based on the block's complete/incomplete state
                        is_complete = block_states[block_id]['is_complete']

                        if is_complete:
                            # Complete state: apply confidence threshold, if no high confidence, select the highest
                            high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]

                            if len(high_conf_indices) == 0:
                                number_transfer_tokens = 1
                                _, transfer_index = torch.topk(confidence, number_transfer_tokens)
                            else:
                                transfer_index = torch.tensor([], device=self.device, dtype=torch.long)

                            # Merge indices
                            all_indices = torch.unique(torch.cat([transfer_index, high_conf_indices]))
                        else:
                            # Incomplete state: only apply confidence threshold, if no tokens exceed the threshold, select none
                            high_conf_indices = torch.where(initial_confidence > skip_threshold)[0]
                            all_indices = high_conf_indices

                        # Update tokens
                        if len(all_indices) > 0:
                            x0_ = torch.zeros_like(x0, device=self.device, dtype=torch.long) + mask_id
                            x0_[all_indices] = x0[all_indices].clone()

                            # Map indices back to original positions
                            for i, idx in enumerate(all_indices):
                                abs_pos = block_start + block_rel_positions[idx]
                                x_t[0, abs_pos] = x0_[idx]

                            # Update block state
                            block_states[block_id]['mask_count'] -= len(all_indices)

                            # Check for EOS token
                            # NOTE(review): 126081 is hard-coded as the EOS id —
                            # presumably the LLaDA eos token; confirm against the tokenizer.
                            eos_token_id = 126081
                            if eos_token_id is not None:
                                for idx in all_indices:
                                    if x0[idx].item() == eos_token_id:
                                        eos_detected = True
                                        break

                    # Deactivate this block if no masks remain
                    mask_index = (x_t == mask_id)
                    block_mask_index = mask_index.clone()
                    block_mask_index[:, :block_start] = False
                    block_mask_index[:, block_end:] = False
                    if block_mask_index.sum() == 0:
                        blocks_to_deactivate.append(block_id)
                        continue

                # Deactivate completed blocks and mark them for caching in the next iteration
                for block_id in blocks_to_deactivate:
                    if block_states[block_id]['state'] == 'active':
                        # Check if all preceding blocks are already in a non-active state
                        can_deactivate = True
                        for prev_block_id in range(block_id):
                            if prev_block_id in block_states and block_states[prev_block_id]['state'] == 'active':
                                can_deactivate = False
                                break

                        # Only mark the current block as 'to_cache' if all preceding blocks are not active
                        if can_deactivate:
                            block_states[block_id]['state'] = 'to_cache'
                            current_blocks -= 1
                        # If there are active preceding blocks, keep the current block in active state (do nothing)

                if update_kvcache > 0:
                    cache_length += update_kvcache
                # Safety check
                if step > 10000:
                    print(f"WARNING: Hit safety check at step {step}. Exiting generation loop.")
                    break

            # NOTE(review): current_text is decoded but never used — looks like leftover debugging.
            current_text = self.tokenizer.decode(x_t[0, prompt.shape[1]:].tolist(),skip_special_tokens=False)

            # Generate final answer
            generated_sequence = x_t[0, prompt.shape[1]:].tolist()

            return generated_sequence
872
+
873
+
874
+
875
    def generate_until(self, requests: List[Instance], disable_tqdm: bool = False):
        """
        Serve lm-eval-harness ``generate_until`` requests one prompt at a time.

        Each request carries (context_string, gen_kwargs). The context is
        tokenized, left-truncated to fit the context window, decoded with
        ``_generate_block_single``, cut at the first ``until`` stop string
        (unless ``escape_until``), and appended to the result list. Optional
        throughput statistics are printed and written under ``save_dir``.
        """
        res = []
        start_time = time.time()

        # Statistics variables
        num_tokens = 0
        num_nfe = 0

        bar = tqdm(total=len(requests), disable=(disable_tqdm or (self.rank != 0)), desc="Running generate_until requests")

        for i, req in enumerate(requests):
            question = req.args[0]
            # print("question:",question)
            # exit()
            gen_kwargs = req.args[1]

            # Process input in LLaDA.py style
            # print("Self.add_bos_token:", self.add_bos_token)
            contexts = [question]
            if self.add_bos_token:
                contexts = [self.tokenizer.bos_token + p for p in contexts]

            # Use the same tokenization method as LLaDA.py
            context_enc, attn_masks = self.tok_batch_encode(
                contexts,
                truncation=self.truncation,
            )

            input_ids = context_enc[0].unsqueeze(0)  # Take the first one and add batch dimension

            # Add length check: keep room for max_new_tokens by cutting the prompt on the left
            if input_ids.shape[1] > self.max_length - self.max_new_tokens:
                eval_logger.warning(f"Prompt length {input_ids.shape[1]} is larger than {self.max_length-self.max_new_tokens}, cutoff on the left side")
                input_ids = input_ids[:, -(self.max_length-self.max_new_tokens):]

            # Generate token IDs
            generated_answer = self._generate_block_single(input_ids)

            # Use tokenizer.batch_decode for decoding, consistent with LLaDA.py
            cont_toks_list = self.tokenizer.batch_decode([generated_answer], skip_special_tokens=True)
            s = cont_toks_list[0]  # Take the first (and only) result

            # Use unified token counting function
            if self.show_speed:
                num_tokens += self._count_tokens_after_truncation(s, gen_kwargs.get("until", []))
                num_nfe += 1  # NFE uses simplified statistics (fixed to 1)

            # Handle until truncation in LLaDA.py style
            if not self.escape_until:
                for term in gen_kwargs.get("until", []):
                    if len(term) > 0:
                        s = s.split(term)[0]

            res.append(s)
            bar.update(1)

        bar.close()

        # Save statistics only at the end (one JSON file per rank)
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)
            final_time = time.time()
            total_time = final_time - start_time

            final_stats = {
                "processed_samples": len(res),
                "total_samples": len(requests),
                "total_tokens": int(num_tokens),
                "total_nfe": int(num_nfe),
                "total_time": total_time,
                "tokens_per_second": float(num_tokens) / total_time if total_time > 0 else 0.0,
                "nfe_per_token": float(num_nfe) / float(num_tokens) if num_tokens > 0 else 0.0,
                "timestamp": final_time
            }
            final_stats_path = os.path.join(self.save_dir, f'rank_{self.rank}_final_stats.json')
            with open(final_stats_path, 'w', encoding='utf-8') as f:
                json.dump(final_stats, f, ensure_ascii=False, indent=2)

        if self.show_speed:
            final_time = time.time()
            total_time = final_time - start_time
            print(f"\n=== Final Statistics ===")
            print(f"Processed samples: {len(res)}")
            print(f"Total tokens: {num_tokens}")
            print(f"Total time: {total_time:.2f} seconds")
            print(f"Throughput: {num_tokens / total_time:.2f} tokens/s")
            print(f"Total NFE: {num_nfe}")

        return res
966
+
967
+ def _forward_process(self, batch):
968
+ b, l = batch.shape
969
+ # sample from U[0, 1] following https://arxiv.org/pdf/2107.00630 I.1
970
+ u0 = torch.rand(1, device=batch.device, dtype=torch.float32)
971
+ indices = torch.arange(b, device=batch.device).float()
972
+ t = (u0 + indices / b) % 1
973
+
974
+ p_mask = (1 - self.sampling_eps) * t + self.sampling_eps
975
+
976
+ p_mask = p_mask[:, None].repeat(1, l)
977
+
978
+ mask_indices = torch.rand((b, l), device=batch.device) < p_mask
979
+ # always unmask bos and eos
980
+ mask_indices[:, 0] = False
981
+ mask_indices[:, -1] = False
982
+
983
+ noisy_batch = torch.where(mask_indices, self.mask_token_id, batch)
984
+ return noisy_batch, p_mask
985
+
986
    @torch.no_grad()
    def get_logits(self, batch, prompt_index):
        '''
        Run the model on a (possibly CFG-doubled) batch and return shifted logits.

        prompt_index : 1D bool tensor, length=batch.shape[1]; marks prompt positions
        so the unconditional copy can mask them for classifier-free guidance.
        '''
        if self.classifier_free_guidance > 1.:
            assert len(prompt_index) == batch.shape[1]
            prompt_index = prompt_index.unsqueeze(0).repeat(batch.shape[0], 1)
            # Unconditional copy: same batch with the prompt positions masked out.
            un_batch = batch.clone()
            un_batch[prompt_index] = self.mask_token_id
            batch = torch.cat([batch, un_batch])

        input = batch

        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            logits = self.model(input).logits
            # since bos always unmask, the first logits will not be used
            logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)

        if self.classifier_free_guidance > 1.:
            # NOTE(review): the gate checks self.classifier_free_guidance but
            # scales with self.cfg — confirm these hold the same value.
            logits, un_logits = torch.chunk(logits, 2, dim=0)
            logits = un_logits + self.cfg * (logits - un_logits)
        return logits[:, :batch.shape[1]]
1009
+
1010
    @torch.no_grad()
    def _eval_target_nll_mc(self, prefix, target):
        """
        Monte-Carlo estimate of the target NLL under the diffusion model.

        Repeats the (prefix + target) sequence batch_size times, perturbs it
        with the forward masking process, and averages the importance-weighted
        cross-entropy over mc_num // batch_size rounds.

        NOTE(review): the ``prefix is None`` branch later calls ``len(prefix)``
        when building prompt_index — that path would raise; confirm whether
        prefix can actually be None here.
        """
        if prefix is None:
            seq = target[None, :]
        else:
            seq = torch.concatenate([prefix, target])[None, :]
        seq = seq.repeat((self.batch_size, 1)).to(self.device)

        # ftb scores the target given the prefix; otherwise the roles are flipped.
        if self.log_type == 'ftb':
            prompt_index = torch.arange(seq.shape[1], device=self.device) < len(prefix)
        else:
            prompt_index = torch.arange(seq.shape[1], device=self.device) >= len(prefix)

        loss_acc = []
        for _ in range(max(self.mc_num // self.batch_size, 1)):
            perturbed_seq = seq.clone()
            perturbed_seq_, p_mask = self._forward_process(seq)
            # Restrict the perturbation to the scored span depending on log_type.
            if self.log_type == 'ftb':
                perturbed_seq[:, -len(target):] = perturbed_seq_[:, -len(target):]
            elif self.log_type == 'btf':
                perturbed_seq[:, :len(prefix)] = perturbed_seq_[:, :len(prefix)]
            elif self.log_type == 'union':
                perturbed_seq = perturbed_seq_
            else:
                raise NotImplementedError(self.log_type)

            mask_indices = perturbed_seq == self.mask_token_id
            logits = self.get_logits(perturbed_seq, prompt_index)
            # Importance weighting: divide each masked token's loss by its masking probability.
            loss = F.cross_entropy(logits[mask_indices], seq[mask_indices], reduction='none') / p_mask[mask_indices]
            loss = loss.sum() / self.batch_size
            loss_acc.append(loss.item())

        return sum(loss_acc) / len(loss_acc)
1045
+
1046
    @torch.no_grad()
    def _eval_target_nll_ar(self, prefix, target):
        """
        Exact autoregressive-style NLL of the scored span.

        Builds one row per scored position, with a triangular masking pattern
        (triu for ar_ftb, tril for ar_btf) so each row reveals one more token,
        then sums the cross-entropy of each position's prediction. Rows are
        evaluated in mini-batches of self.batch_size.
        """
        prefix, target = prefix.unsqueeze(0), target.unsqueeze(0)  # 1*l1, 1*l2
        assert self.log_type in ['ftb', 'btf']
        assert self.nll_type in ['ar_ftb', 'ar_btf']

        if self.log_type == 'ftb':
            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) < prefix.shape[1]
        else:
            prompt_index = torch.arange(prefix.shape[1] + target.shape[1], device=self.device) >= prefix.shape[1]

        # One row per position of the scored span (target for ftb, prefix for btf).
        if self.log_type == 'ftb':
            perturbed_ = target.repeat(target.shape[1], 1).clone().contiguous()  # l2*l2
        else:
            perturbed_ = prefix.repeat(prefix.shape[1], 1).clone().contiguous()  # l1*l1

        # Triangular mask: row i hides positions >= i (ar_ftb) or <= i (ar_btf).
        mask_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
        if self.nll_type == 'ar_ftb':
            mask_index = torch.triu(mask_index)
        else:
            mask_index = torch.tril(mask_index)
        perturbed_[mask_index] = self.mask_token_id
        if self.log_type == 'ftb':
            perturbed_seq = torch.cat([prefix.repeat(perturbed_.shape[0], 1), perturbed_], dim=-1)
        else:
            perturbed_seq = torch.cat([perturbed_, target.repeat(perturbed_.shape[0], 1)], dim=-1)

        # Forward the rows in chunks of batch_size, collecting logits on CPU.
        logits_ = []
        num = len(perturbed_seq) // self.batch_size if len(perturbed_seq) % self.batch_size == 0 else len(perturbed_seq) // self.batch_size + 1
        for i in range(num):
            end = (i + 1) * self.batch_size if (i + 1) * self.batch_size < len(perturbed_seq) else len(perturbed_seq)
            perturbed_seq_ = perturbed_seq[i * self.batch_size: end]
            perturbed_seq_ = perturbed_seq_.to(self.device)
            if len(perturbed_seq_.shape) == 1:
                perturbed_seq_ = perturbed_seq_.unsqueeze(0)
            logits = self.get_logits(perturbed_seq_, prompt_index)
            logits_.append(logits.cpu())
        logits = torch.cat(logits_, dim=0)

        # Keep only the diagonal of the triangular mask: the single position
        # each row is responsible for predicting.
        temp_index = torch.ones((perturbed_.shape[1], perturbed_.shape[1]), dtype=torch.bool)
        if self.nll_type == 'ar_ftb':
            temp_index = torch.triu(temp_index, diagonal=1)
        else:
            temp_index = torch.tril(temp_index, diagonal=-1)
        mask_index[temp_index] = False
        if self.log_type == 'ftb':
            logits_index = torch.cat([torch.zeros((perturbed_.shape[1], prefix.shape[1]), dtype=torch.bool), mask_index], dim=-1)
        else:
            logits_index = torch.cat([mask_index, torch.zeros((perturbed_.shape[1], target.shape[1]), dtype=torch.bool)], dim=-1)

        if self.log_type == 'ftb':
            loss = F.cross_entropy(logits[logits_index], target[0], reduction='sum').cpu().item()
        else:
            loss = F.cross_entropy(logits[logits_index], prefix[0], reduction='sum').cpu().item()
        return loss
1101
+
1102
+ def _encode_pair(self, context, continuation):
1103
+ if self.add_bos_token:
1104
+ context = self.tokenizer.bos_token + context
1105
+
1106
+ n_spaces = len(context) - len(context.rstrip())
1107
+ if n_spaces > 0:
1108
+ continuation = context[-n_spaces:] + continuation
1109
+ context = context[:-n_spaces]
1110
+
1111
+ whole_enc = self.tokenizer.encode(context + continuation) + [self.tokenizer.eos_token_id]
1112
+ context_enc = self.tokenizer.encode(context)
1113
+
1114
+ context_enc_len = len(context_enc)
1115
+ continuation_enc = whole_enc[context_enc_len:]
1116
+
1117
+ # by default truncate on the left
1118
+ cutoff_length = max(len(whole_enc) - self.max_length, 0)
1119
+ if cutoff_length > 0:
1120
+ eval_logger.warning(f"Text length {len(whole_enc)} is larger than {self.max_length}, cutoff on the left side")
1121
+ context_remain = context_enc_len-cutoff_length
1122
+ if context_remain > 0:
1123
+ context_enc = context_enc[-context_remain:]
1124
+ else:
1125
+ eval_logger.warning(f"All context (prompt) is truncated.")
1126
+ context_enc = ""
1127
+ continuation_enc = whole_enc[-self.max_length:]
1128
+ return context_enc, continuation_enc
1129
+
1130
    def loglikelihood(self, requests: List[Instance]) -> List[Tuple[float, bool]]:
        """
        Score lm-eval-harness loglikelihood requests.

        Each request carries (prefix, target) strings; pairs are tokenized via
        ``_encode_pair`` and scored with either the Monte-Carlo estimator
        ('mc') or the exact autoregressive-style estimator ('ar_ftb'/'ar_btf').

        Returns a list of (log-likelihood, is_greedy) tuples; is_greedy is
        currently always 0.0 (greedy check not implemented).
        """
        def _tokenize(e):
            # Attach token ids alongside the raw text for each example.
            prefix, target = self._encode_pair(e["prefix"], e["target"])
            return {
                "prefix_text": e["prefix"],
                "target_text": e["target"],
                "prefix": prefix,
                "target": target,
            }

        ds = []
        ds = [{"prefix": req.args[0], "target": req.args[1]} for req in requests]
        ds = Dataset.from_list(ds)
        # NOTE(review): debug print of the first example — consider removing.
        print(ds[0])
        ds = ds.map(_tokenize)
        ds = ds.with_format("torch")

        out = []
        with torch.no_grad():
            for elem in tqdm(ds, desc="Computing likelihood..."):
                prefix = elem["prefix"]
                target = elem["target"]
                # likelihood calculations are modified from https://github.com/ML-GSAI/SMDM/blob/main/evaluate_diff.py
                if self.nll_type == 'mc':
                    ll = -self._eval_target_nll_mc(prefix, target)
                    if self.log_type == 'union':
                        ll = ll / (len(target) + len(prefix))
                elif self.nll_type == 'ar_ftb' or self.nll_type == 'ar_btf':
                    ll = -self._eval_target_nll_ar(prefix, target)
                else:
                    raise NotImplementedError(self.nll_type)

                # TODO: greedy decoding
                is_target_greedy_dec = False

                out.append((ll, 1.0 if is_target_greedy_dec else 0.0))
        return out
1167
+
1168
    def loglikelihood_rolling(self, requests: List[Instance]) -> List[float]:
        """Rolling (windowed) loglikelihood is not supported by this wrapper."""
        raise NotImplementedError
1170
+
1171
    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
        """Token-level loglikelihood batching is not supported by this wrapper."""
        raise NotImplementedError
1173
+
1174
+
1175
+ def _update_block_completion_states(self, block_states, decoded_token_threshold):
1176
+ """
1177
+ Updates the complete/incomplete state of blocks.
1178
+ Iterates through blocks from front to back. If a block's decoded token count exceeds the threshold, the next block to its right (if it exists) is set to a complete state.
1179
+ """
1180
+ for block_id in sorted(block_states.keys()):
1181
+ # if block_id == 0: # Skip prompt block
1182
+ # continue
1183
+
1184
+ # Calculate decoded tokens for the current block
1185
+ decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
1186
+ decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
1187
+ # If current block's decoded token count exceeds the threshold, the next block (if exists) is set to a complete state
1188
+ # print("decode_ratio",decode_ratio)
1189
+ # print("decoded_token_threshold",decoded_token_threshold)
1190
+ if decode_ratio >= decoded_token_threshold:
1191
+ next_block_id = block_id + 1
1192
+ if next_block_id in block_states:
1193
+ block_states[next_block_id]['is_complete'] = True
1194
+
1195
+
1196
if __name__ == "__main__":
    # Fix the RNG seed for reproducible evaluation runs, then hand control to
    # the evaluation-harness CLI entry point.
    set_seed(1234)
    cli_evaluate()
eval_llada.sh ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash
# Batch evaluation driver for D2F LLaDA LoRA checkpoints.
# Runs HumanEval plus a configurable set of harness tasks for each LoRA model.
# Each per-task setting below is a space-separated list; position i of every
# list configures run i, so all lists must have the same length (checked below).

tasks="gsm8k mbpp minerva_math"
nshots="4 3 0"
lengths="512 512 512"
temperatures="0 0 0"
limits="10000 10000 10000"
block_sizes="64 32 32"
block_add_thresholds="0.7 0.9 0.1"
decoded_token_thresholds="0.95 0.95 0.95"
skip_thresholds="0.9 0.9 0.9"
top_ps="none none none"
dtypes="bfloat16 bfloat16 bfloat16"
sampling_strategies="default default default"

# HumanEval has its own (single-run) configuration lists.
humaneval_nshots="0"
humaneval_lengths="512"
humaneval_temperatures="0"
humaneval_limits="10000"
humaneval_diffusion_steps="512"
humaneval_block_sizes="32"
humaneval_block_add_thresholds="0.1"
humaneval_decoded_token_thresholds="0.95"
humaneval_skip_thresholds="0.9"
humaneval_top_ps="none"
humaneval_dtypes="bfloat16"
humaneval_sampling_strategies="default"

# Base model and the LoRA adapters to evaluate on top of it.
base_model=GSAI-ML/LLaDA-8B-Instruct

lora_models=(
    "SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora"
)

# Split the space-separated config strings into bash arrays.
read -ra TASKS_ARRAY <<< "$tasks"
read -ra NSHOTS_ARRAY <<< "$nshots"
read -ra LENGTH_ARRAY <<< "$lengths"
read -ra TEMP_ARRAY <<< "$temperatures"
read -ra LIMITS_ARRAY <<< "$limits"
read -ra BLOCK_SIZES_ARRAY <<< "$block_sizes"
read -ra BLOCK_ADD_THRESHOLDS_ARRAY <<< "$block_add_thresholds"
read -ra DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$decoded_token_thresholds"
read -ra SKIP_THRESHOLDS_ARRAY <<< "$skip_thresholds"
read -ra TOP_PS_ARRAY <<< "$top_ps"
read -ra DTYPES_ARRAY <<< "$dtypes"
read -ra SAMPLING_STRATEGIES_ARRAY <<< "$sampling_strategies"

read -ra HUMANEVAL_NSHOTS_ARRAY <<< "$humaneval_nshots"
read -ra HUMANEVAL_LENGTHS_ARRAY <<< "$humaneval_lengths"
read -ra HUMANEVAL_TEMP_ARRAY <<< "$humaneval_temperatures"
read -ra HUMANEVAL_LIMITS_ARRAY <<< "$humaneval_limits"
read -ra HUMANEVAL_DIFFUSION_STEPS_ARRAY <<< "$humaneval_diffusion_steps"
read -ra HUMANEVAL_BLOCK_SIZES_ARRAY <<< "$humaneval_block_sizes"
read -ra HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY <<< "$humaneval_block_add_thresholds"
read -ra HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY <<< "$humaneval_decoded_token_thresholds"
read -ra HUMANEVAL_SKIP_THRESHOLDS_ARRAY <<< "$humaneval_skip_thresholds"
read -ra HUMANEVAL_TOP_PS_ARRAY <<< "$humaneval_top_ps"
read -ra HUMANEVAL_DTYPES_ARRAY <<< "$humaneval_dtypes"
read -ra HUMANEVAL_SAMPLING_STRATEGIES_ARRAY <<< "$humaneval_sampling_strategies"

# Sanity check: every per-task list must be the same length as the task list.
array_length=${#TASKS_ARRAY[@]}
if [[ ${#NSHOTS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#LENGTH_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#TEMP_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#LIMITS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#BLOCK_SIZES_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#SKIP_THRESHOLDS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#TOP_PS_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#SAMPLING_STRATEGIES_ARRAY[@]} -ne $array_length ]] || \
   [[ ${#DTYPES_ARRAY[@]} -ne $array_length ]]; then
    echo "Error: All configuration arrays must have the same length!"
    echo "Tasks: ${#TASKS_ARRAY[@]}, Nshots: ${#NSHOTS_ARRAY[@]}, Lengths: ${#LENGTH_ARRAY[@]}, Temperatures: ${#TEMP_ARRAY[@]}, Limits: ${#LIMITS_ARRAY[@]}, Block sizes: ${#BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#TOP_PS_ARRAY[@]}, Sampling strategies: ${#SAMPLING_STRATEGIES_ARRAY[@]}, Dtypes: ${#DTYPES_ARRAY[@]}"
    exit 1
fi

# Same sanity check for the HumanEval configuration lists.
humaneval_array_length=${#HUMANEVAL_NSHOTS_ARRAY[@]}
if [[ ${#HUMANEVAL_LENGTHS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_TEMP_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_LIMITS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_TOP_PS_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_DTYPES_ARRAY[@]} -ne $humaneval_array_length ]] || \
   [[ ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]} -ne $humaneval_array_length ]]; then
    echo "Error: All HumanEval configuration arrays must have the same length!"
    echo "HumanEval Nshots: ${#HUMANEVAL_NSHOTS_ARRAY[@]}, Lengths: ${#HUMANEVAL_LENGTHS_ARRAY[@]}, Temperatures: ${#HUMANEVAL_TEMP_ARRAY[@]}, Limits: ${#HUMANEVAL_LIMITS_ARRAY[@]}, Diffusion steps: ${#HUMANEVAL_DIFFUSION_STEPS_ARRAY[@]}, Block sizes: ${#HUMANEVAL_BLOCK_SIZES_ARRAY[@]}, Block thresholds: ${#HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[@]}, Decoded token thresholds: ${#HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[@]}, Skip thresholds: ${#HUMANEVAL_SKIP_THRESHOLDS_ARRAY[@]}, Top_ps: ${#HUMANEVAL_TOP_PS_ARRAY[@]}, Dtypes: ${#HUMANEVAL_DTYPES_ARRAY[@]}, Sampling strategies: ${#HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[@]}"
    exit 1
fi

# Allow HumanEval's generated-code execution inside the harness.
export HF_ALLOW_CODE_EVAL=1
for lora_model in "${lora_models[@]}"; do
    lora_model_name="$lora_model"
    echo "===================================================================="
    echo "Evaluating LoRA model: $lora_model_name"
    echo "===================================================================="

    # --- HumanEval runs (no chat template; escape_until enabled) ---
    for i in "${!HUMANEVAL_NSHOTS_ARRAY[@]}"; do
        output_path="eval_llada${lora_model_name}/humaneval-ns${HUMANEVAL_NSHOTS_ARRAY[$i]}-len${HUMANEVAL_LENGTHS_ARRAY[$i]}-temp${HUMANEVAL_TEMP_ARRAY[$i]}-limit${HUMANEVAL_LIMITS_ARRAY[$i]}-diffsteps${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}-block${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}-thresh${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}-topp${HUMANEVAL_TOP_PS_ARRAY[$i]}-dtype${HUMANEVAL_DTYPES_ARRAY[$i]}-sampling${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}"
        echo "Running HumanEval evaluation $((i+1))/${humaneval_array_length} for $lora_model_name..."
        echo "HumanEval Config: Shots: ${HUMANEVAL_NSHOTS_ARRAY[$i]}, Length: ${HUMANEVAL_LENGTHS_ARRAY[$i]}, Temperature: ${HUMANEVAL_TEMP_ARRAY[$i]}, Limit: ${HUMANEVAL_LIMITS_ARRAY[$i]}, Diffusion Steps: ${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]}, Block Size: ${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${HUMANEVAL_TOP_PS_ARRAY[$i]}, Sampling Strategy: ${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${HUMANEVAL_DTYPES_ARRAY[$i]}; Output: $output_path"

        # "none" top_p means the sampler's default: omit top_p from model_args.
        if [[ "${HUMANEVAL_TOP_PS_ARRAY[$i]}" == "none" ]]; then
            humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        else
            humaneval_model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${HUMANEVAL_LENGTHS_ARRAY[$i]},diffusion_steps=${HUMANEVAL_DIFFUSION_STEPS_ARRAY[$i]},temperature=${HUMANEVAL_TEMP_ARRAY[$i]},top_p=${HUMANEVAL_TOP_PS_ARRAY[$i]},add_bos_token=true,escape_until=true,block_size=${HUMANEVAL_BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${HUMANEVAL_BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${HUMANEVAL_SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${HUMANEVAL_DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${HUMANEVAL_DTYPES_ARRAY[$i]},sampling_strategy=${HUMANEVAL_SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        fi

        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_llada.py --model dream_lora \
            --model_args $humaneval_model_args \
            --tasks humaneval \
            --num_fewshot ${HUMANEVAL_NSHOTS_ARRAY[$i]} \
            --batch_size 1 \
            --output_path $output_path \
            --log_samples \
            --confirm_run_unsafe_code
    done

    ### NOTICE: use postprocess for humaneval
    # python postprocess_code.py {the samples_xxx.jsonl file under output_path}

    # --- Remaining harness tasks (chat template + multiturn fewshot) ---
    for i in "${!TASKS_ARRAY[@]}"; do
        output_path="eval_llada${lora_model_name}/${TASKS_ARRAY[$i]}-ns${NSHOTS_ARRAY[$i]}-len${LENGTH_ARRAY[$i]}-temp${TEMP_ARRAY[$i]}-limit${LIMITS_ARRAY[$i]}-diffsteps${LENGTH_ARRAY[$i]}-block${BLOCK_SIZES_ARRAY[$i]}-thresh${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}-decodethresh${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}-skip${SKIP_THRESHOLDS_ARRAY[$i]}-topp${TOP_PS_ARRAY[$i]}-dtype${DTYPES_ARRAY[$i]}-sampling${SAMPLING_STRATEGIES_ARRAY[$i]}"
        echo "Task: ${TASKS_ARRAY[$i]}, Shots: ${NSHOTS_ARRAY[$i]}, Length: ${LENGTH_ARRAY[$i]}, Temperature: ${TEMP_ARRAY[$i]}, Limit: ${LIMITS_ARRAY[$i]}, Block Size: ${BLOCK_SIZES_ARRAY[$i]}, Block Add Threshold: ${BLOCK_ADD_THRESHOLDS_ARRAY[$i]}, Decoded Token Threshold: ${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]}, Skip Threshold: ${SKIP_THRESHOLDS_ARRAY[$i]}, Top_p: ${TOP_PS_ARRAY[$i]}, Sampling Strategy: ${SAMPLING_STRATEGIES_ARRAY[$i]}, Dtype: ${DTYPES_ARRAY[$i]}; Output: $output_path"

        # "none" top_p means the sampler's default: omit top_p from model_args.
        if [[ "${TOP_PS_ARRAY[$i]}" == "none" ]]; then
            model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        else
            model_args="pretrained=${base_model},lora_path=${lora_model},max_new_tokens=${LENGTH_ARRAY[$i]},diffusion_steps=${LENGTH_ARRAY[$i]},add_bos_token=true,temperature=${TEMP_ARRAY[$i]},top_p=${TOP_PS_ARRAY[$i]},block_size=${BLOCK_SIZES_ARRAY[$i]},block_add_threshold=${BLOCK_ADD_THRESHOLDS_ARRAY[$i]},skip_threshold=${SKIP_THRESHOLDS_ARRAY[$i]},decoded_token_threshold=${DECODED_TOKEN_THRESHOLDS_ARRAY[$i]},dtype=${DTYPES_ARRAY[$i]},sampling_strategy=${SAMPLING_STRATEGIES_ARRAY[$i]},save_dir=${output_path}"
        fi

        CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 accelerate launch --main_process_port 29520 --num_processes 8 eval_llada.py --model dream_lora \
            --model_args $model_args \
            --tasks ${TASKS_ARRAY[$i]} \
            --limit ${LIMITS_ARRAY[$i]} \
            --num_fewshot ${NSHOTS_ARRAY[$i]} \
            --batch_size 1 \
            --output_path $output_path \
            --log_samples \
            --confirm_run_unsafe_code \
            --apply_chat_template \
            --fewshot_as_multiturn
    done
done

echo "All evaluations completed!"
generate_llada_demo_ar.py ADDED
@@ -0,0 +1,660 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.distributions as dists
4
+ import transformers
5
+ from transformers import AutoTokenizer, AutoModelForCausalLM
6
+ from peft import PeftModel, PeftConfig
7
+ import numpy as np
8
+ import random
9
+ import time
10
+ import os
11
+ from typing import List, Dict, Optional, Tuple, Iterator, Set
12
+ import gradio as gr
13
+ import gc
14
+
15
+ # Suppress some Hugging Face warnings
16
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
17
+
18
+ # Import necessary model classes
19
+ # Assuming these custom classes are in the correct path
20
+ from model_cache.llada.modeling_llada import LLaDAModelLM
21
+ from model_cache.llada.configuration_llada import LLaDAConfig
22
+
23
+ # --- Helper Functions (Unchanged) ---
24
def set_seed(seed):
    """Seed every RNG in use (torch CPU/CUDA, stdlib random, numpy) for reproducibility."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Force deterministic cuDNN kernels; disables autotuning for repeatability.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
27
+
28
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    """Build a (1, 1, max_length, max_length) additive attention bias for
    block-wise diffusion decoding.

    Prompt tokens attend bidirectionally within the prompt; each generated
    block attends to the prompt, all earlier blocks, and itself (a
    block-causal pattern). Allowed positions hold 0, disallowed -inf.

    Args:
        prompt_length: number of prompt tokens at the start of the sequence.
        max_length: total sequence length the mask covers.
        block_size: size of each generated block after the prompt.
        device: torch device for the mask (default: current default device).
        dtype: mask dtype; defaults to torch.bfloat16.

    Returns:
        Additive bias tensor of shape (1, 1, max_length, max_length).
    """
    if dtype is None:
        dtype = torch.bfloat16
    attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
    # Prompt region: full bidirectional attention.
    attention_mask[:, :, :prompt_length, :prompt_length] = 0
    remaining_length = max_length - prompt_length
    num_blocks = (remaining_length + block_size - 1) // block_size  # ceil division
    for b in range(num_blocks):
        block_start = prompt_length + b * block_size
        block_end = min(prompt_length + (b + 1) * block_size, max_length)
        # The prompt, every previous block, and this block form the single
        # contiguous span [0, block_end), so one slice assignment replaces the
        # original per-previous-block inner loop (O(B) instead of O(B^2)).
        attention_mask[:, :, block_start:block_end, :block_end] = 0
    return attention_mask
42
+
43
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    """Slice the precomputed full bias down to the sub-mask for one forward pass.

    The result has `input_length` query rows over `cache_length` cached key
    positions followed by the current input span's own `input_length` keys.
    """
    end_pos = start_pos + input_length
    total_length = cache_length + input_length
    sub_mask = torch.full(
        (1, 1, input_length, total_length),
        -torch.inf,
        device=full_mask.device,
        dtype=full_mask.dtype,
    )
    # Keys [0, cache_length) come from the cached prefix; the remainder is the
    # square self-attention region of the current input span.
    sub_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
    sub_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
    return sub_mask
49
+
50
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering: mask every logit outside the smallest token
    set whose cumulative probability exceeds `top_p`. The top-1 token always
    survives. Masked entries are set to the dtype's minimum value.
    """
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    drop_sorted = cum_probs > top_p
    # Shift right by one so the token that first crosses the threshold is kept.
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = 0
    # Map the drop decisions back from sorted order to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    drop = drop.scatter_(-1, sorted_indices, drop_sorted)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)
60
+
61
def top_k_logits(logits, top_k=None):
    """Keep only the `top_k` largest logits per row; mask the rest to the
    dtype's minimum value. `top_k` is clamped to the vocabulary size.
    """
    k = min(top_k, logits.size(-1))
    # k-th largest value per row, kept broadcastable against `logits`.
    kth_value = torch.topk(logits, k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_value, torch.finfo(logits.dtype).min)
66
+
67
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample one token per position and report a per-position confidence.

    With temperature > 0, samples from the (temperature/top-p/top-k filtered)
    categorical distribution; otherwise takes the argmax.

    Args:
        logits: (positions, vocab) tensor of unnormalized scores.
        temperature: softmax temperature; 0 means greedy.
        top_p: optional nucleus-filtering threshold (applied when < 1).
        top_k: optional top-k filtering.
        margin_confidence: report top1-top2 probability gap as confidence.
        neg_entropy: report negative entropy as confidence (overrides margin).

    Returns:
        (confidence, x0, initial_confidence) where x0 are the sampled token
        ids and initial_confidence is always the sampled token's probability.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)
    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        # BUGFIX: was a bare `except:`, which also swallowed KeyboardInterrupt
        # and SystemExit. Keep the greedy fallback for genuine sampling errors
        # (e.g. invalid probabilities) but only for Exception subclasses.
        except Exception:
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)
    confidence = initial_confidence.clone()
    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Gap between the two most likely tokens.
        confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
    if neg_entropy:
        epsilon = 1e-10
        # Negative entropy: higher (closer to 0) means a peakier distribution.
        confidence = torch.sum(probs * torch.log(probs + epsilon), dim=-1)
    return confidence, x0, initial_confidence
86
+
87
+
88
class D2FInference:
    """Block-wise parallel ("diffusion") decoder for LLaDA with a D2F LoRA adapter.

    Generation appends fixed-size blocks of mask tokens after the prompt and
    repeatedly denoises them. Several blocks may be partially decoded at once
    ("active"); once every earlier block is finished, a block is frozen into
    the KV cache ('to_cache' -> 'in_cache') so later steps only re-process the
    still-active suffix.

    Constructor kwargs are injected as attributes via __dict__.update; the
    code below reads: pretrained_path, lora_path, device, dtype, max_length,
    temperature, top_p, top_k, mask_token_id.
    """

    # CSS injected into the Gradio demo: model headers, the stats card, and
    # tall auto-scrolling output textboxes.
    CSS = """
    .gradio-container {
        font-family: -apple-system, BlinkMacSystemFont, sans-serif;
    }
    .model-header {
        font-size: 1.2em;
        font-weight: bold;
        margin-bottom: 10px;
        padding: 8px;
        border-radius: 5px;
        text-align: center;
    }
    .d2f-header {
        background-color: #DBEAFE;
        color: #1E40AF;
    }
    .llama-header {
        background-color: #FEF3C7;
        color: #92400E;
    }
    .stats-container {
        padding: 15px;
        border: 1px solid #10B981;
        border-radius: 8px;
        background-color: #F0FDF4;
        margin-top: 10px;
        margin-bottom: 20px;
    }
    .output-textbox textarea {
        font-size: 1.5em !important;
        line-height: 1.6 !important;
        height: 70vh !important;
        overflow-y: auto !important;
    }
    """

    def __init__(self, **kwargs):
        """Resolve device/dtype from kwargs, then load model + tokenizer."""
        print("Initializing D2F-LLaDA model...")
        self.device = torch.device(kwargs.get("device", "cuda:3") if torch.cuda.is_available() else "cpu")
        # NOTE(review): this overwrites self.device with the raw string from
        # kwargs (e.g. "cuda:0"); torch accepts device strings, so it works,
        # but the torch.device() above is effectively dead when "device" is
        # passed — confirm intended.
        self.__dict__.update(kwargs)
        if self.dtype == "bfloat16" and torch.cuda.is_bf16_supported(): self.target_dtype = torch.bfloat16
        elif self.dtype == "float16": self.target_dtype = torch.float16
        else: self.target_dtype = torch.float32
        self._setup_model(self.pretrained_path, self.lora_path)
        print("D2F-LLaDA model and tokenizer setup complete.")

    def _setup_model(self, pretrained_path, lora_path):
        """Load the base LLaDA model, wrap it with the D2F LoRA adapter, and
        prepare the tokenizer (pad falls back to eos)."""
        config = LLaDAConfig.from_pretrained(pretrained_path)
        self.model = LLaDAModelLM.from_pretrained(pretrained_path, config=config, torch_dtype=self.target_dtype).eval()
        self.model = PeftModel.from_pretrained(self.model, lora_path)
        self.model = self.model.to(self.device)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
        if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token

    def _apply_chat_template(self, prompt: str) -> str:
        """Wrap a bare user prompt in the tokenizer's chat template."""
        chat_history = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)

    def _update_block_completion_states(self, block_states, decoded_token_threshold):
        """Mark block b+1 'is_complete' once block b has decoded at least
        `decoded_token_threshold` of its tokens (completion cascades in id order)."""
        for block_id in sorted(block_states.keys()):
            decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
            if block_states[block_id]['total_masks'] > 0:
                decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
                if decode_ratio >= decoded_token_threshold:
                    if (next_block_id := block_id + 1) in block_states:
                        block_states[next_block_id]['is_complete'] = True

    @torch.inference_mode()
    def stream(
        self,
        prompt_text: str,
        max_new_tokens: int,
        block_size: int,
        block_add_threshold: float,
        decoded_token_threshold: float,
        skip_threshold: float
    ) -> Iterator[Tuple[str, Optional[Dict]]]:
        """Stream generation for one prompt.

        Yields (partial_text, None) after every denoising step, and finally
        (final_text, stats) where stats has total_time, tokens_generated, and
        tokens_per_second.
        """
        start_time = time.time()

        input_ids = self.tokenizer(self._apply_chat_template(prompt_text), return_tensors="pt").input_ids.to(self.device)
        prompt_length = input_ids.shape[1]

        # Precompute the full block-causal bias once; per-step masks are slices of it.
        full_attention_mask = create_full_block_attention_mask(prompt_length, self.max_length, block_size, self.device, self.target_dtype)
        x_t = input_ids
        # Block 0 is the prompt: fully decoded and immediately cacheable.
        block_states = {0: {'start_pos': 0, 'end_pos': prompt_length, 'mask_count': 0, 'total_masks': prompt_length, 'state': 'to_cache', 'is_complete': True}}
        past_key_values, current_blocks, step, eos_detected, cache_length = None, 0, 0, False, 0

        yield "", None

        tokens_generated = 0

        while True:
            step += 1
            updated_block_ids = set()

            # Phase 1: append a fresh block of masks once the newest block has
            # decoded at least `block_add_threshold` of its tokens.
            if len(block_states) - 1 < (max_new_tokens // block_size) and not eos_detected:
                last_block_id = max(block_states.keys())
                progress_ratio = (block_states[last_block_id]['total_masks'] - block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks'] if block_states[last_block_id]['total_masks'] > 0 else 1.0
                if progress_ratio >= block_add_threshold:
                    new_block_id = last_block_id + 1; new_start_pos = x_t.shape[1]
                    if new_start_pos + block_size <= self.max_length:
                        x_t = torch.cat([x_t, torch.full((1, block_size), self.mask_token_id, device=self.device, dtype=torch.long)], dim=1)
                        block_states[new_block_id] = {'start_pos': new_start_pos, 'end_pos': new_start_pos + block_size, 'mask_count': block_size, 'total_masks': block_size, 'state': 'active', 'is_complete': False}
                        current_blocks += 1

            self._update_block_completion_states(block_states, decoded_token_threshold)
            # Done when nothing is masked and no block remains active.
            if (x_t == self.mask_token_id).sum() == 0 and current_blocks == 0: break

            # Phase 2: pick the input span. If any finished block awaits
            # caching, process from its start so its KV entries get written;
            # otherwise process from the first active, uncached block.
            blocks_to_cache = [bid for bid, state in block_states.items() if state['state'] == 'to_cache']
            update_kvcache = 0
            if blocks_to_cache:
                start_pos, end_pos = block_states[min(blocks_to_cache)]['start_pos'], block_states[max(blocks_to_cache)]['end_pos']
                update_kvcache = end_pos - start_pos; input_seq, process_start_pos = x_t[:, start_pos:], start_pos
            else:
                active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active' and state['start_pos'] >= cache_length]
                if not active_blocks: break
                start_pos = min(block_states[bid]['start_pos'] for bid in active_blocks); input_seq, process_start_pos = x_t[:, start_pos:], start_pos

            if input_seq.shape[1] == 0: break

            attention_mask = extract_attention_mask(full_mask=full_attention_mask,
                                                    start_pos=process_start_pos,
                                                    input_length=input_seq.shape[1],
                                                    cache_length=cache_length)

            # update_kvcache tells the model how far (in absolute positions)
            # the KV cache should extend after this call.
            outputs = self.model(input_seq,
                                 attention_bias=attention_mask,
                                 past_key_values=past_key_values,
                                 use_cache=True,
                                 update_kvcache=update_kvcache + cache_length)

            if update_kvcache > 0:
                past_key_values = outputs.past_key_values
                for bid in blocks_to_cache:
                    block_states[bid]['state'] = 'in_cache'

            # Phase 3: for every active block, commit all masked positions
            # whose sampled-token confidence clears `skip_threshold`.
            blocks_to_deactivate = []
            for block_id, state in block_states.items():
                if state['state'] != 'active':
                    continue

                block_mask_locs = (x_t[0, state['start_pos']:state['end_pos']] == self.mask_token_id).nonzero().squeeze(-1)

                if block_mask_locs.numel() == 0:
                    blocks_to_deactivate.append(block_id)
                    continue

                # Translate absolute positions into offsets within this
                # forward pass's logits.
                logit_offset = state['start_pos'] - process_start_pos
                block_mask_logits = outputs.logits[:, logit_offset + block_mask_locs, :]
                _, x0, initial_confidence = sample_tokens(block_mask_logits.squeeze(0), self.temperature, self.top_p, self.top_k)
                all_indices = (initial_confidence > skip_threshold).nonzero().squeeze(-1)

                # A 'complete' block must make progress every step: if nothing
                # cleared the threshold, force-commit the single best token.
                if state['is_complete'] and all_indices.numel() == 0 and block_mask_logits.numel() > 0:
                    all_indices = torch.tensor([torch.argmax(initial_confidence)], device=self.device)

                if all_indices.numel() > 0:
                    updated_block_ids.add(block_id)
                    positions_to_update = state['start_pos'] + block_mask_locs[all_indices]
                    x_t[0, positions_to_update] = x0[all_indices]
                    state['mask_count'] -= all_indices.numel()
                    tokens_generated += all_indices.numel()

                    if self.tokenizer.eos_token_id in x0[all_indices]:
                        eos_detected = True

                if state['mask_count'] == 0:
                    blocks_to_deactivate.append(block_id)

            # Phase 4: a finished block becomes cacheable only once no earlier
            # block is still active (cache order must match position order).
            for bid in blocks_to_deactivate:
                if block_states[bid]['state'] == 'active' and all(block_states.get(i, {}).get('state') != 'active' for i in range(bid)):
                    block_states[bid]['state'] = 'to_cache'
                    current_blocks -= 1

            if update_kvcache > 0:
                cache_length += update_kvcache

            # Stream the text decoded so far (mask positions dropped).
            generated_ids = x_t[0, prompt_length:]
            valid_ids = generated_ids[generated_ids != self.mask_token_id]
            live_text = self.tokenizer.decode(valid_ids, skip_special_tokens=True)

            yield live_text, None

        total_time = time.time() - start_time
        final_generated_ids = x_t[0, prompt_length:]
        eos_positions = (final_generated_ids == self.tokenizer.eos_token_id).nonzero()

        # Truncate everything after (and including) the first EOS.
        if eos_positions.numel() > 0:
            final_generated_ids = final_generated_ids[:eos_positions[0, 0] + 1]

        final_text = self.tokenizer.decode(final_generated_ids, skip_special_tokens=True)

        # Throughput is measured over the truncated output (EOS included).
        tokens_incl_eos = len(final_generated_ids)
        tokens_per_second = tokens_incl_eos / total_time if total_time > 0 else 0

        stats = {
            "total_time": total_time,
            "tokens_generated": tokens_incl_eos,
            "tokens_per_second": tokens_per_second
        }

        # Release GPU memory held by the cache and the full mask.
        if past_key_values is not None:
            del past_key_values
        del full_attention_mask
        torch.cuda.empty_cache()

        yield final_text, stats
296
+
297
+
298
class LlamaInference:
    """Streaming autoregressive baseline used for speed comparison against D2F."""

    def __init__(self, **kwargs):
        """Resolve the device from kwargs and load the model/tokenizer.

        Config keys (e.g. model_id, device) are injected as attributes via
        __dict__.update, mirroring D2FInference's construction style.
        """
        print("Initializing LLaMA model...")
        self.device = torch.device(kwargs.get("device", "cuda:4") if torch.cuda.is_available() else "cpu")
        self.__dict__.update(kwargs)
        self._setup_model(self.model_id)
        print("LLaMA model and tokenizer setup complete.")

    def _setup_model(self, model_id):
        """Load tokenizer and weights; guarantee eos/pad tokens exist."""
        print(f"Loading LLaMA model {model_id} on {self.device}...")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

        self.model = AutoModelForCausalLM.from_pretrained(
            model_id,
            torch_dtype=torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16,
            device_map=self.device
        ).eval()

        if self.tokenizer.eos_token is None:
            self.tokenizer.eos_token = "</s>"

        if self.tokenizer.pad_token is None:
            self.tokenizer.pad_token = self.tokenizer.eos_token

    def _apply_chat_template(self, prompt):
        """Wrap a bare user prompt in the tokenizer's chat template."""
        chat_history = [{"role": "user", "content": prompt}]
        return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)

    @torch.inference_mode()
    def stream(
        self,
        prompt_text: str,
        max_new_tokens: int,
        temperature: float = 0.0,
        top_p: float = 0.9,
        top_k: int = None
    ) -> Iterator[Tuple[str, str]]:
        """Token-by-token streaming generation.

        Yields (partial_text, None) after each token, then (final_text, stats)
        with total_time / tokens_generated / tokens_per_second.

        BUGFIX: the original passed use_cache=True but never fed
        past_key_values back, so every step re-ran the model over the entire
        sequence (O(n) work per token, O(n^2) per generation). We now keep the
        KV cache and feed only the newly sampled token each step.
        """
        start_time = time.time()

        formatted_prompt = self._apply_chat_template(prompt_text)
        input_ids = self.tokenizer(formatted_prompt, return_tensors="pt").input_ids.to(self.device)
        prompt_length = input_ids.shape[1]

        yield "", None

        tokens_generated = 0
        generated_ids = input_ids.clone()
        past_key_values = None
        next_input = generated_ids  # full prompt on the first step, one token after

        for _ in range(max_new_tokens):
            outputs = self.model(next_input, past_key_values=past_key_values, use_cache=True)
            past_key_values = outputs.past_key_values
            next_token_logits = outputs.logits[:, -1, :]

            if temperature > 0:
                next_token_logits = next_token_logits / temperature
                if top_p is not None and top_p < 1:
                    next_token_logits = top_p_logits(next_token_logits, top_p)
                if top_k is not None:
                    next_token_logits = top_k_logits(next_token_logits, top_k)
                probs = torch.softmax(next_token_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
            else:
                # Greedy decoding when temperature is 0.
                next_token = torch.argmax(next_token_logits, dim=-1, keepdim=True)

            generated_ids = torch.cat([generated_ids, next_token], dim=-1)
            next_input = next_token
            tokens_generated += 1

            if next_token[0, 0].item() == self.tokenizer.eos_token_id:
                break

            generated_text = self.tokenizer.decode(
                generated_ids[0, prompt_length:],
                skip_special_tokens=True
            )

            yield generated_text, None

            del outputs

        total_time = time.time() - start_time
        tokens_per_second = tokens_generated / total_time if total_time > 0 else 0

        final_text = self.tokenizer.decode(generated_ids[0, prompt_length:], skip_special_tokens=True)

        stats = {
            "total_time": total_time,
            "tokens_generated": tokens_generated,
            "tokens_per_second": tokens_per_second
        }

        # Release the cache and working tensors before handing back control.
        del past_key_values
        del generated_ids
        torch.cuda.empty_cache()

        yield final_text, stats
394
+
395
+
396
+ # --- Comparison Helper Functions ---
397
def create_comparison_html(d2f_results, llama_results):
    """Render the side-by-side performance table for the two models.

    Args:
        d2f_results / llama_results: dicts with keys "tokens_generated",
            "total_time" (seconds), and "tokens_per_second".

    Returns:
        An HTML string for a gr.HTML component.
    """
    d_tokens = d2f_results["tokens_generated"]
    d_time = d2f_results["total_time"]
    d_tokens_per_sec = d2f_results["tokens_per_second"]

    a_tokens = llama_results["tokens_generated"]
    a_time = llama_results["total_time"]
    a_tokens_per_sec = llama_results["tokens_per_second"]

    if a_tokens_per_sec > 0:
        speedup = d_tokens_per_sec / a_tokens_per_sec
    else:
        speedup = 0

    # BUGFIX: the original ternary divided by a_time unconditionally in its
    # else-branch (ZeroDivisionError when a_time == 0) and claimed one side was
    # "0.0x faster" on ties. Guard both divisions and fall back to "-".
    if d_time > 0 and d_time < a_time:
        time_comparison = f"D2F-LLaDA is {a_time / d_time:.1f}x faster"
    elif a_time > 0 and a_time < d_time:
        time_comparison = f"LLaMA3 is {d_time / a_time:.1f}x faster"
    else:
        time_comparison = "-"

    if speedup > 1:
        speed_comparison = f"D2F-LLaDA is {speedup:.1f}x faster"
    elif 0 < speedup < 1:
        speed_comparison = f"LLaMA3 is {1 / speedup:.1f}x faster"
    else:
        speed_comparison = "-"

    comparison_html = f"""
    <div class="stats-container" style="background-color: #F9FAFB; border-color: #6366F1;">
        <h3>⚡ Performance Comparison</h3>
        <table style="width:100%; text-align: left; border-collapse: collapse;">
            <tr style="background-color: #EEF2FF;">
                <th style="padding: 8px; border: 1px solid #ddd;">Metric</th>
                <th style="padding: 8px; border: 1px solid #ddd;">D2F-LLaDA-Instruct-8B</th>
                <th style="padding: 8px; border: 1px solid #ddd;">LLaMA3-Instruct-8B</th>
                <th style="padding: 8px; border: 1px solid #ddd;">Difference</th>
            </tr>
            <tr>
                <td style="padding: 8px; border: 1px solid #ddd;">Total tokens</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{d_tokens}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{a_tokens}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">-</td>
            </tr>
            <tr>
                <td style="padding: 8px; border: 1px solid #ddd;">Generation time</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{d_time:.2f}s</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{a_time:.2f}s</td>
                <td style="padding: 8px; border: 1px solid #ddd;">
                    {time_comparison}
                </td>
            </tr>
            <tr>
                <td style="padding: 8px; border: 1px solid #ddd;">Tokens per second</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{d_tokens_per_sec:.2f}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">{a_tokens_per_sec:.2f}</td>
                <td style="padding: 8px; border: 1px solid #ddd;">
                    {speed_comparison}
                </td>
            </tr>
        </table>
    </div>
    """

    return comparison_html
448
+
449
+
450
def create_stats_html(model_name, results):
    """Render a small completion-summary card for one model's run.

    `results` carries "total_time", "tokens_generated", "tokens_per_second".
    """
    total_time = results["total_time"]
    n_tokens = results["tokens_generated"]
    tps = results["tokens_per_second"]
    return f"""
    <div class="stats-container">
        <h3>✓ {model_name} Generation Complete</h3>
        <ul>
            <li><b>Total time:</b> {total_time:.2f} seconds</li>
            <li><b>Tokens generated:</b> {n_tokens}</li>
            <li><b>Tokens per second:</b> {tps:.2f}</li>
        </ul>
    </div>
    """
463
+
464
+
465
+ # --- Main Interface ---
466
+ if __name__ == "__main__":
467
+ os.environ["CUDA_VISIBLE_DEVICES"] = "3,4"
468
+
469
+ torch.cuda.empty_cache()
470
+
471
+ d2f_config = {
472
+ "pretrained_path": "GSAI-ML/LLaDA-8B-Instruct",
473
+ "lora_path": "SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora",
474
+ "device": "cuda:0",
475
+ "dtype": "bfloat16",
476
+ "max_length": 4096,
477
+ "temperature": 0.0,
478
+ "top_p": None,
479
+ "top_k": None,
480
+ "mask_token_id": 126336,
481
+ "sampling_strategy": "default",
482
+ }
483
+
484
+ llama_config = {
485
+ "model_id": "meta-llama/Llama-3.1-8B-Instruct",
486
+ "device": "cuda:1",
487
+ }
488
+
489
+ set_seed(42)
490
+
491
+ d2f_engine = D2FInference(**d2f_config)
492
+ llama_engine = LlamaInference(**llama_config)
493
+
494
+ with gr.Blocks(css=D2FInference.CSS, theme=gr.themes.Soft()) as demo:
495
+ gr.Markdown("# 🚀 D2F-LLaDA vs LLaMA3: Speed Comparison")
496
+
497
+ with gr.Row():
498
+ with gr.Column(scale=1):
499
+ prompt_input = gr.Textbox(
500
+ label="Enter your question",
501
+ placeholder="Example: Natalia sold clips to...",
502
+ lines=5
503
+ )
504
+ generate_button = gr.Button("🚀 Run Speed Comparison", variant="primary")
505
+
506
+ with gr.Accordion("⚙️ D2F-LLaDA Parameter Settings", open=True):
507
+ with gr.Row():
508
+ max_new_tokens_slider = gr.Slider(
509
+ minimum=64, maximum=2048, value=1024, step=64,
510
+ label="Max Tokens to Generate"
511
+ )
512
+ block_size_slider = gr.Slider(
513
+ minimum=16, maximum=128, value=32, step=16,
514
+ label="Block Size"
515
+ )
516
+ with gr.Row():
517
+ block_add_thresh_slider = gr.Slider(
518
+ minimum=0.0, maximum=1.0, value=0.1, step=0.05,
519
+ label="Block Add Threshold"
520
+ )
521
+ decoded_token_thresh_slider = gr.Slider(
522
+ minimum=0.0, maximum=1.0, value=0.5, step=0.05,
523
+ label="Decoding Completion Threshold"
524
+ )
525
+ skip_thresh_slider = gr.Slider(
526
+ minimum=0.0, maximum=1.0, value=0.9, step=0.01,
527
+ label="Skip Threshold"
528
+ )
529
+
530
+ comparison_output = gr.HTML(label="Performance Comparison", elem_id="comparison-container")
531
+
532
+ with gr.Row():
533
+ with gr.Column(scale=1):
534
+ gr.HTML("<div class='model-header d2f-header'>✨ D2F-LLaDA-Instruct-8B (Parallel Decoding)</div>")
535
+ d2f_output = gr.Textbox(
536
+ label="D2F-LLaDA Output",
537
+ interactive=False,
538
+ elem_classes=["output-textbox"]
539
+ )
540
+ d2f_status = gr.HTML(label="D2F-LLaDA Stats")
541
+
542
+ with gr.Column(scale=1):
543
+ gr.HTML("<div class='model-header llama-header'>🔄 LLaMA3-Instruct-8B (Standard)</div>")
544
+ llama_output = gr.Textbox(
545
+ label="LLaMA3 Output",
546
+ interactive=False,
547
+ elem_classes=["output-textbox"]
548
+ )
549
+ llama_status = gr.HTML(label="LLaMA3 Stats")
550
+
551
+ gr.Examples(
552
+ examples=[
553
+ ["Solve the equation x² - 6x + 8 = 0. First, explain what a quadratic equation is and why it can have up to two solutions. Then solve this equation using three different methods: factoring, completing the square, and the quadratic formula. For each method, explain the mathematical reasoning behind it, show all steps in detail, and discuss when this particular method is most useful. Finally, verify your solutions by substituting them back into the original equation.", 1024, 32, 0.1, 0.55, 0.9],
554
+ ["A circular swimming pool has a diameter of 8 meters. Calculate the pool's circumference and area. First, explain the relationship between diameter, radius, circumference, and area of a circle, including the role of π in these formulas. Then perform the calculations using π ≈ 3.14159. Next, estimate how much water (in cubic meters) would be needed to fill this pool if it has a uniform depth of 1.5 meters. Finally, calculate how much it would cost to fill this pool if water costs $2.50 per cubic meter. Show all steps and include appropriate units in your answer.", 1024, 32, 0.1, 0.5, 0.9],
555
+ ["A movie theater offers a loyalty card that costs $15 and gives a 15% discount on all tickets. If a regular movie ticket costs $10, how many tickets would you need to buy to make the loyalty card worthwhile? First, explain the concept of a break-even point. Then set up an equation to find when the total cost with the card equals the total cost without the card. Solve this equation step by step, showing all your work. Finally, interpret your answer in the context of the problem.", 1024, 32, 0.1, 0.5, 0.9],
556
+ ],
557
+ inputs=[
558
+ prompt_input, max_new_tokens_slider, block_size_slider,
559
+ block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
560
+ ],
561
+ label="Examples (Math Problems)"
562
+ )
563
+
564
+ def run_models_streaming(
565
+ prompt_text,
566
+ max_new_tokens,
567
+ block_size,
568
+ block_add_threshold,
569
+ decoded_token_threshold,
570
+ skip_threshold
571
+ ):
572
+ torch.cuda.empty_cache()
573
+
574
+ d2f_generator = d2f_engine.stream(
575
+ prompt_text=prompt_text,
576
+ max_new_tokens=max_new_tokens,
577
+ block_size=block_size,
578
+ block_add_threshold=block_add_threshold,
579
+ decoded_token_threshold=decoded_token_threshold,
580
+ skip_threshold=skip_threshold
581
+ )
582
+
583
+ llama_generator = llama_engine.stream(
584
+ prompt_text=prompt_text,
585
+ max_new_tokens=max_new_tokens
586
+ )
587
+
588
+ d2f_text = ""
589
+ llama_text = ""
590
+ d2f_stats = None
591
+ llama_stats = None
592
+
593
+ yield d2f_text, llama_text, "", "", ""
594
+
595
+ d2f_done = False
596
+ llama_done = False
597
+
598
+ while not (d2f_done and llama_done):
599
+ if not d2f_done:
600
+ try:
601
+ new_d2f_text, new_d2f_stats = next(d2f_generator)
602
+ d2f_text = new_d2f_text
603
+ if new_d2f_stats is not None:
604
+ d2f_stats = new_d2f_stats
605
+ d2f_done = True
606
+ except StopIteration:
607
+ d2f_done = True
608
+
609
+ if not llama_done:
610
+ try:
611
+ new_llama_text, new_llama_stats = next(llama_generator)
612
+ llama_text = new_llama_text
613
+ if new_llama_stats is not None:
614
+ llama_stats = new_llama_stats
615
+ llama_done = True
616
+ except StopIteration:
617
+ llama_done = True
618
+
619
+ d2f_status_html = create_stats_html("D2F-LLaDA", d2f_stats) if d2f_stats else ""
620
+ llama_status_html = create_stats_html("LLaMA3", llama_stats) if llama_stats else ""
621
+
622
+ comparison = ""
623
+ if d2f_done and llama_done and d2f_stats and llama_stats:
624
+ comparison = create_comparison_html(d2f_stats, llama_stats)
625
+
626
+ yield d2f_text, llama_text, d2f_status_html, llama_status_html, comparison
627
+
628
+ # MODIFICATION: Removed the _js parameter from here
629
+ generate_button.click(
630
+ fn=run_models_streaming,
631
+ inputs=[
632
+ prompt_input, max_new_tokens_slider, block_size_slider,
633
+ block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
634
+ ],
635
+ outputs=[
636
+ d2f_output, llama_output,
637
+ d2f_status, llama_status,
638
+ comparison_output
639
+ ]
640
+ )
641
+
642
+ # MODIFICATION: Added a hidden HTML component with a script for auto-scrolling
643
+ # This method is compatible with older Gradio versions.
644
+ gr.HTML(
645
+ """
646
+ <script>
647
+ function_to_run = () => {
648
+ const textboxes = document.querySelectorAll('.output-textbox textarea');
649
+ textboxes.forEach(textbox => {
650
+ textbox.scrollTop = textbox.scrollHeight;
651
+ });
652
+ }
653
+ // Run the function every 250ms to ensure autoscrolling
654
+ setInterval(function_to_run, 250);
655
+ </script>
656
+ """,
657
+ visible=False
658
+ )
659
+
660
+ demo.queue().launch(share=True)
generate_llada_demo_block.py ADDED
@@ -0,0 +1,630 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.distributions as dists
4
+ import transformers
5
+ from transformers import AutoTokenizer
6
+ from peft import PeftModel, PeftConfig
7
+ import numpy as np
8
+ import random
9
+ import time
10
+ import os
11
+ from typing import List, Dict, Optional, Tuple, Iterator, Set
12
+ import gradio as gr
13
+ import ipdb
14
+ # Suppress some Hugging Face warnings
15
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
16
+
17
+ # Import necessary model classes
18
+ from model_cache.llada.modeling_llada import LLaDAModelLM
19
+ from model_cache.llada.configuration_llada import LLaDAConfig
20
+
21
+ # --- Helper Functions (Unchanged) ---
22
def set_seed(seed):
    """Seed every RNG in use (torch CPU/CUDA, stdlib random, numpy) for reproducibility."""
    torch.manual_seed(seed)
    random.seed(seed)
    np.random.seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Deterministic cuDNN kernels; autotuning disabled for repeatability.
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
25
def create_full_block_attention_mask(prompt_length, max_length, block_size, device=None, dtype=None):
    """Build a (1, 1, max_length, max_length) additive attention bias for
    block-wise diffusion decoding: the prompt attends bidirectionally to
    itself, each generated block attends to the prompt, all earlier blocks,
    and itself. Allowed positions hold 0, disallowed -inf.
    """
    if dtype is None:
        dtype = torch.bfloat16
    attention_mask = torch.full((1, 1, max_length, max_length), -torch.inf, device=device, dtype=dtype)
    # Prompt region: full bidirectional attention.
    attention_mask[:, :, :prompt_length, :prompt_length] = 0
    remaining_length = max_length - prompt_length
    num_blocks = (remaining_length + block_size - 1) // block_size  # ceil division
    for b in range(num_blocks):
        block_start = prompt_length + b * block_size
        block_end = min(prompt_length + (b + 1) * block_size, max_length)
        # Prompt + previous blocks + this block form the contiguous span
        # [0, block_end); one slice assignment replaces the original
        # per-previous-block inner loop (O(B) instead of O(B^2)).
        attention_mask[:, :, block_start:block_end, :block_end] = 0
    return attention_mask
39
def extract_attention_mask(full_mask, start_pos, input_length, cache_length):
    """Slice the full bias to the sub-mask for one forward pass:
    `input_length` query rows over `cache_length` cached keys followed by the
    current span's own `input_length` keys.
    """
    end_pos = start_pos + input_length
    total_length = cache_length + input_length
    sub_mask = torch.full(
        (1, 1, input_length, total_length),
        -torch.inf,
        device=full_mask.device,
        dtype=full_mask.dtype,
    )
    # Cached-prefix keys first, then the square self-attention region.
    sub_mask[:, :, :, :cache_length] = full_mask[:, :, start_pos:end_pos, :cache_length]
    sub_mask[:, :, :, cache_length:] = full_mask[:, :, start_pos:end_pos, start_pos:end_pos]
    return sub_mask
45
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering: suppress tokens outside the smallest set
    whose cumulative probability exceeds ``top_p``."""
    sorted_logits, sorted_indices = torch.sort(logits, descending=True)
    cum_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
    drop_sorted = cum_probs > top_p
    # Shift right so the first token that crosses the threshold is kept.
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = 0
    drop = torch.zeros_like(logits, dtype=torch.bool).scatter_(-1, sorted_indices, drop_sorted)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)
55
def top_k_logits(logits, top_k=None):
    """Keep only the ``top_k`` largest logits; the rest drop to dtype-min."""
    top_k = min(top_k, logits.size(-1))  # guard against k > vocab size
    kth_value = torch.topk(logits, top_k)[0][..., -1, None]
    return logits.masked_fill(logits < kth_value, torch.finfo(logits.dtype).min)
60
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample token ids from logits and return per-token confidence scores.

    Args:
        logits: token logits (last dim = vocabulary).
        temperature: 0 means greedy argmax; >0 enables Categorical sampling.
        top_p / top_k: optional nucleus / top-k filtering before sampling.
        margin_confidence: score = gap between top-1 and top-2 probabilities
            (indexing assumes 2-D ``probs``).
        neg_entropy: score = negative entropy of the distribution.

    Returns:
        (confidence, x0, initial_confidence) where ``x0`` are the chosen ids
        and ``initial_confidence`` is always the probability of the chosen
        token, regardless of which confidence variant was requested.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)
    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            initial_confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fall back to greedy if sampling fails (e.g. degenerate probs).
            # Was a bare `except:`, which also swallowed KeyboardInterrupt.
            initial_confidence, x0 = probs.max(dim=-1)
    else:
        initial_confidence, x0 = probs.max(dim=-1)
    confidence = initial_confidence.clone()
    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        confidence = sorted_probs[:, 0] - sorted_probs[:, 1]
    if neg_entropy:
        epsilon = 1e-10  # avoid log(0)
        confidence = torch.sum(probs * torch.log(probs + epsilon), dim=-1)
    return confidence, x0, initial_confidence
79
+
80
+
81
class DreamLoRAInference:
    """Block-wise diffusion inference engine (LLaDA base + D2F LoRA adapter)
    wired to a Gradio demo UI.

    ``CSS`` below is injected verbatim into ``gr.Blocks(css=...)``; it styles
    the token visualization grid, the per-block status box, scrollbars and the
    two-column layout.  It is runtime data — do not reformat its contents.
    """

    CSS = """
    /* Fixed height, scrollable visualization container */
    #viz-container {
        height: 500px;
        overflow-y: auto !important;
        border: 1px solid #E5E7EB;
        border-radius: 8px;
        padding: 10px;
        position: relative;
    }
    .block-container {
        display: inline-block; border: 2px solid transparent; border-radius: 8px;
        padding: 5px; margin: 4px 0; transition: border-color 0.3s, box-shadow 0.3s;
    }
    .block-updating {
        border-color: #FF4500 !important;
        box-shadow: 0 0 8px rgba(255, 69, 0, 0.7);
    }
    .token { padding: 2px 4px; margin: 2px; border-radius: 4px; display: inline-block; line-height: 1.4; font-family: monospace; }
    .token.prompt { background-color: #E5E7EB; color: #4B5563; }
    .token.gen-0 { background-color: #DBEAFE; color: #1E40AF; } /* Blue */
    .token.gen-1 { background-color: #D1FAE5; color: #065F46; } /* Green */
    .token.gen-2 { background-color: #FEF3C7; color: #92400E; } /* Yellow */
    .token.gen-3 { background-color: #FEE2E2; color: #991B1B; } /* Red */
    .token.gen-4 { background-color: #E0E7FF; color: #3730A3; } /* Indigo */
    .token.gen-5 { background-color: #F3E8FF; color: #6B21A8; } /* Purple */
    .token.mask { background-color: #F3F4F6; color: #9CA3AF; border: 1px dashed #D1D5DB; }

    /* Independent status box styles */
    #status-container {
        height: 300px;
        overflow-y: auto !important;
        margin-top: 10px; padding: 15px; border: 1px solid #E5E7EB; border-radius: 8px; background-color: #F9FAFB;
        position: relative;
    }
    #status-container h4 { margin-top: 0; }
    .status-line { font-family: monospace; font-size: 13px; margin-bottom: 5px; margin-top: 5px; padding: 2px 4px; border-radius: 3px;}
    #stats-output { padding: 15px; border: 1px solid #10B981; border-radius: 8px; background-color: #F0FDF4; margin-top: 10px; }

    /* Scroll anchor */
    .scroll-anchor {
        height: 1px;
        width: 100%;
    }

    /* Force scrollbar styles */
    #viz-container::-webkit-scrollbar, #status-container::-webkit-scrollbar {
        width: 10px !important;
        background-color: #f5f5f5 !important;
    }
    #viz-container::-webkit-scrollbar-thumb, #status-container::-webkit-scrollbar-thumb {
        background-color: #888 !important;
        border-radius: 5px !important;
    }
    #viz-container::-webkit-scrollbar-track, #status-container::-webkit-scrollbar-track {
        background-color: #f5f5f5 !important;
        border-radius: 5px !important;
    }

    /* Column height alignment */
    .left-column, .right-column {
        display: flex;
        flex-direction: column;
        height: auto !important;
        min-height: 800px;
    }

    .live-text-container, .viz-status-container {
        display: flex;
        flex-direction: column;
        flex: 1;
        overflow: visible;
    }

    #live-text-output, #stats-output {
        margin-bottom: 20px;
    }

    /* Fix for bottom content being cut off */
    .container {
        padding-bottom: 40px;
    }

    /* Make sure content is fully visible */
    .gradio-container {
        overflow-y: visible !important;
    }

    /* Add padding to bottom of page */
    .footer {
        margin-top: 30px;
        padding-bottom: 30px;
    }
    """
176
+
177
+ def __init__(self, **kwargs):
178
+ print("Initializing DreamLoRAInference...")
179
+ self.device = torch.device(kwargs.get("device", "cuda") if torch.cuda.is_available() else "cpu")
180
+ self.__dict__.update(kwargs)
181
+ if self.dtype == "bfloat16" and torch.cuda.is_bf16_supported(): self.target_dtype = torch.bfloat16
182
+ elif self.dtype == "float16": self.target_dtype = torch.float16
183
+ else: self.target_dtype = torch.float32
184
+ self._setup_model(self.pretrained_path, self.lora_path)
185
+ print("Model and tokenizer setup complete.")
186
+
187
+ def _setup_model(self, pretrained_path, lora_path):
188
+ config = LLaDAConfig.from_pretrained(pretrained_path)
189
+ self.model = LLaDAModelLM.from_pretrained(pretrained_path, config=config, torch_dtype=self.target_dtype).eval()
190
+ self.model = PeftModel.from_pretrained(self.model, lora_path)
191
+ self.model = self.model.to(self.device)
192
+ self.tokenizer = AutoTokenizer.from_pretrained(pretrained_path)
193
+ if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token
194
+
195
+ def _apply_chat_template(self, prompt):
196
+ chat_history = [{"role": "user", "content": prompt}]
197
+ return self.tokenizer.apply_chat_template(chat_history, tokenize=False, add_generation_prompt=True)
198
+
199
+ def _update_block_completion_states(self, block_states, decoded_token_threshold):
200
+ for block_id in sorted(block_states.keys()):
201
+ decoded_tokens = block_states[block_id]['total_masks'] - block_states[block_id]['mask_count']
202
+ if block_states[block_id]['total_masks'] > 0:
203
+ decode_ratio = decoded_tokens / block_states[block_id]['total_masks']
204
+ if decode_ratio >= decoded_token_threshold:
205
+ if (next_block_id := block_id + 1) in block_states:
206
+ block_states[next_block_id]['is_complete'] = True
207
+
208
    # Render visualization part (excluding prompt status info)
    def _render_visualization_html(self, step: int, x_t: torch.Tensor, block_states: Dict, cache_length: int, updated_block_ids: Set[int]) -> str:
        """Render the generated blocks of ``x_t`` as an HTML token grid.

        Masked positions render as '░'; decoded tokens are colour-coded by
        block (gen-0..gen-5, cycling).  Blocks listed in ``updated_block_ids``
        get a highlight border.  The embedded <script> keeps #viz-container
        scrolled to the bottom as new frames arrive.
        """
        # Millisecond timestamp makes element ids unique for every frame.
        timestamp = int(time.time() * 1000)

        html_parts = []
        for block_id in sorted(k for k in block_states.keys() if k > 0): # Only render generated part (block_id > 0)
            state = block_states[block_id]
            container_classes = ["block-container"]
            if block_id in updated_block_ids: container_classes.append("block-updating")
            html_parts.append(f'<div class="{" ".join(container_classes)}" id="block-{block_id}-{timestamp}">')
            block_tokens = x_t[0, state['start_pos']:state['end_pos']]
            for token_id in block_tokens:
                token_id_int = token_id.item()
                token_classes = ["token"]
                if token_id_int == self.mask_token_id:
                    token_str = '░'; token_classes.append("mask")
                else:
                    token_str = self.tokenizer.decode([token_id_int], skip_special_tokens=False)
                    # Escape HTML-special characters so decoded tokens cannot inject markup.
                    token_str = token_str.replace('&', '&amp;').replace('<', '&lt;').replace('>', '&gt;')
                    token_classes.append(f"gen-{(block_id - 1) % 6}")
                html_parts.append(f'<span class="{" ".join(token_classes)}">{token_str}</span>')
            html_parts.append('</div>')

        # Invisible anchor element the JS scrolls into view.
        html_parts.append(f'<div class="scroll-anchor" id="viz-anchor-{timestamp}"></div>')

        # Doubled braces below are f-string escapes for literal JS braces.
        complete_html = f"""
        <div class="viz-content" id="viz-content-{timestamp}">
            {''.join(html_parts)}
        </div>

        <script>
            function executeVizScroll() {{
                const container = document.getElementById('viz-container');
                const anchor = document.getElementById('viz-anchor-{timestamp}');
                if (container && anchor) {{
                    try {{
                        container.scrollTo(0, container.scrollHeight);
                        container.scrollTop = container.scrollHeight;
                        anchor.scrollIntoView({{behavior: 'auto', block: 'end'}});
                    }} catch (e) {{
                        console.error('Scroll error:', e);
                    }}
                }}
            }}

            setTimeout(executeVizScroll, 10);
            setTimeout(executeVizScroll, 50);
            setTimeout(executeVizScroll, 150);
            setTimeout(executeVizScroll, 300);

            try {{
                const vizContent = document.getElementById('viz-content-{timestamp}');
                const vizContainer = document.getElementById('viz-container');

                if (vizContent && vizContainer) {{
                    const resizeObserver = new ResizeObserver(() => {{
                        executeVizScroll();
                    }});
                    resizeObserver.observe(vizContent);

                    const mutationObserver = new MutationObserver(() => {{
                        executeVizScroll();
                    }});
                    mutationObserver.observe(vizContainer, {{
                        childList: true,
                        subtree: true,
                        characterData: true
                    }});
                }}
            }} catch (e) {{
                console.error('Observer error:', e);
            }}
        </script>
        """

        return complete_html
284
+
285
    # Render status box part (only shows generation block information)
    def _render_status_html(self, step: int, block_states: Dict, cache_length: int) -> str:
        """Render a per-block status summary (position, state, fill progress)
        as HTML for the #status-container panel, with auto-scroll JS."""
        # Millisecond timestamp makes element ids unique for every frame.
        timestamp = int(time.time() * 1000)

        html_parts = []
        html_parts.append(f'<h4>Generation Block Status (Step: {step}, Cache Length: {cache_length})</h4>')
        for block_id in [k for k in sorted(block_states.keys()) if k > 0]:
            state = block_states[block_id]
            block_type = f"Block {block_id}"
            # Number of positions already decoded in this block.
            masks_filled = state['total_masks'] - state['mask_count']
            color_class = f"gen-{(block_id - 1) % 6}"
            status_line = f'<b>{block_type.ljust(8)}</b>: Pos=[{str(state["start_pos"]).rjust(4)}:{str(state["end_pos"]).ljust(4)}] | State=\'{state["state"].ljust(8)}\' | Filled={str(masks_filled).rjust(2)}/{state["total_masks"]}'
            html_parts.append(f'<p class="status-line token {color_class}" id="status-line-{block_id}-{timestamp}">{status_line}</p>')

        # Invisible anchor element the JS scrolls into view.
        html_parts.append(f'<div class="scroll-anchor" id="status-anchor-{timestamp}"></div>')

        # Doubled braces below are f-string escapes for literal JS braces.
        complete_html = f"""
        <div class="status-content" id="status-content-{timestamp}">
            {''.join(html_parts)}
        </div>

        <script>
            function executeStatusScroll() {{
                const container = document.getElementById('status-container');
                const anchor = document.getElementById('status-anchor-{timestamp}');
                if (container && anchor) {{
                    try {{
                        container.scrollTo(0, container.scrollHeight);
                        container.scrollTop = container.scrollHeight;
                        anchor.scrollIntoView({{behavior: 'auto', block: 'end'}});
                    }} catch (e) {{
                        console.error('Status scroll error:', e);
                    }}
                }}
            }}

            setTimeout(executeStatusScroll, 10);
            setTimeout(executeStatusScroll, 50);
            setTimeout(executeStatusScroll, 150);
            setTimeout(executeStatusScroll, 300);

            try {{
                const statusContent = document.getElementById('status-content-{timestamp}');
                const statusContainer = document.getElementById('status-container');

                if (statusContent && statusContainer) {{
                    const resizeObserver = new ResizeObserver(() => {{
                        executeStatusScroll();
                    }});
                    resizeObserver.observe(statusContent);

                    const mutationObserver = new MutationObserver(() => {{
                        executeStatusScroll();
                    }});
                    mutationObserver.observe(statusContainer, {{
                        childList: true,
                        subtree: true,
                        characterData: true
                    }});
                }}
            }} catch (e) {{
                console.error('Status observer error:', e);
            }}
        </script>
        """

        return complete_html
352
+
353
+ @torch.inference_mode()
354
+ def stream_and_capture_for_gradio(
355
+ self,
356
+ prompt_text: str,
357
+ max_new_tokens: int,
358
+ block_size: int,
359
+ block_add_threshold: float,
360
+ decoded_token_threshold: float,
361
+ skip_threshold: float
362
+ ) -> Iterator[Tuple[str, List[Tuple[str, str]], str, str, str]]:
363
+
364
+ start_time = time.time()
365
+ captured_frames: List[Tuple[str, str]] = []
366
+
367
+ # Initialization
368
+ ipdb.set_trace()
369
+ input_ids = self.tokenizer(self._apply_chat_template(prompt_text), return_tensors="pt").input_ids.to(self.device)
370
+ prompt_length = input_ids.shape[1]
371
+
372
+ full_attention_mask = create_full_block_attention_mask(prompt_length, self.max_length, block_size, self.device, self.target_dtype)
373
+ x_t = input_ids
374
+ block_states = {0: {'start_pos': 0, 'end_pos': prompt_length, 'mask_count': 0, 'total_masks': prompt_length, 'state': 'to_cache', 'is_complete': True}}
375
+ past_key_values, current_blocks, step, eos_detected, cache_length = None, 0, 0, False, 0
376
+
377
+ # Capture initial state
378
+ initial_viz_html = self._render_visualization_html(0, x_t, block_states, 0, set())
379
+ initial_status_html = self._render_status_html(0, block_states, 0)
380
+ captured_frames.append((initial_viz_html, initial_status_html))
381
+
382
+ yield "", captured_frames, "Initializing generation process...", "Initializing visualization...", "Initializing block status..."
383
+
384
+ # Main generation loop
385
+ while True:
386
+ step += 1
387
+ updated_block_ids: Set[int] = set()
388
+
389
+ if len(block_states) - 1 < (max_new_tokens // block_size) and not eos_detected:
390
+ last_block_id = max(block_states.keys())
391
+ progress = (block_states[last_block_id]['total_masks'] - block_states[last_block_id]['mask_count']) / block_states[last_block_id]['total_masks'] if block_states[last_block_id]['total_masks'] > 0 else 1.0
392
+ if progress >= block_add_threshold:
393
+ new_block_id = last_block_id + 1; new_start_pos = x_t.shape[1]
394
+ if new_start_pos + block_size <= self.max_length:
395
+ x_t = torch.cat([x_t, torch.full((1, block_size), self.mask_token_id, device=self.device, dtype=torch.long)], dim=1)
396
+ block_states[new_block_id] = {'start_pos': new_start_pos, 'end_pos': new_start_pos + block_size, 'mask_count': block_size, 'total_masks': block_size, 'state': 'active', 'is_complete': False}
397
+ current_blocks += 1
398
+
399
+ self._update_block_completion_states(block_states, decoded_token_threshold)
400
+ if (x_t == self.mask_token_id).sum() == 0 and current_blocks == 0: break
401
+
402
+
403
+
404
+ #### D2F-BLOCK ####
405
+ blocks_to_cache = [bid for bid, state in block_states.items() if state['state'] == 'to_cache']
406
+ update_kvcache = 0
407
+ if blocks_to_cache:
408
+ start_pos, end_pos = block_states[min(blocks_to_cache)]['start_pos'], block_states[max(blocks_to_cache)]['end_pos']
409
+ update_kvcache = end_pos - start_pos; input_seq, process_start_pos = x_t[:, start_pos:], start_pos
410
+ else:
411
+ active_blocks = [bid for bid, state in block_states.items() if state['state'] == 'active' and state['start_pos'] >= cache_length]
412
+ if not active_blocks: break
413
+ start_pos = min(block_states[bid]['start_pos'] for bid in active_blocks); input_seq, process_start_pos = x_t[:, start_pos:], start_pos
414
+
415
+ if input_seq.shape[1] == 0: break
416
+
417
+ attention_mask = extract_attention_mask(full_attention_mask, process_start_pos, input_seq.shape[1], cache_length)
418
+ outputs = self.model(input_seq, attention_bias=attention_mask, past_key_values=past_key_values, use_cache=True, update_kvcache=update_kvcache + cache_length)
419
+ if update_kvcache > 0:
420
+ past_key_values = outputs.past_key_values
421
+ for bid in blocks_to_cache: block_states[bid]['state'] = 'in_cache'
422
+
423
+ blocks_to_deactivate = []
424
+ for block_id, state in block_states.items():
425
+ if state['state'] != 'active': continue
426
+ block_mask_locs = (x_t[0, state['start_pos']:state['end_pos']] == self.mask_token_id).nonzero().squeeze(-1)
427
+ if block_mask_locs.numel() == 0:
428
+ blocks_to_deactivate.append(block_id); continue
429
+ logit_offset = state['start_pos'] - process_start_pos
430
+ block_mask_logits = outputs.logits[:, logit_offset + block_mask_locs, :]
431
+ _, x0, initial_confidence = sample_tokens(block_mask_logits.squeeze(0), self.temperature, self.top_p, self.top_k)
432
+ all_indices = (initial_confidence > skip_threshold).nonzero().squeeze(-1)
433
+ if state['is_complete'] and all_indices.numel() == 0 and block_mask_logits.numel() > 0:
434
+ all_indices = torch.tensor([torch.argmax(initial_confidence)], device=self.device)
435
+
436
+ if all_indices.numel() > 0:
437
+ updated_block_ids.add(block_id)
438
+ positions_to_update = state['start_pos'] + block_mask_locs[all_indices]
439
+ x_t[0, positions_to_update] = x0[all_indices]; state['mask_count'] -= all_indices.numel()
440
+ if self.tokenizer.eos_token_id in x0[all_indices]: eos_detected = True
441
+ if state['mask_count'] == 0: blocks_to_deactivate.append(block_id)
442
+
443
+ for bid in blocks_to_deactivate:
444
+ if block_states[bid]['state'] == 'active' and all(block_states.get(i, {}).get('state') != 'active' for i in range(bid)):
445
+ block_states[bid]['state'] = 'to_cache'; current_blocks -= 1
446
+ if update_kvcache > 0: cache_length += update_kvcache
447
+
448
+ #### FlexMDM Cache Update ####
449
+
450
+
451
+
452
+
453
+
454
+ # Capture current step's visualization and status frames
455
+ generated_ids = x_t[0, prompt_length:]
456
+ valid_ids = generated_ids[generated_ids != self.mask_token_id]
457
+ live_text = self.tokenizer.decode(valid_ids, skip_special_tokens=True)
458
+
459
+ current_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, updated_block_ids)
460
+ current_status_html = self._render_status_html(step, block_states, cache_length)
461
+ captured_frames.append((current_viz_html, current_status_html))
462
+
463
+ yield live_text, captured_frames, "Generating...", "Generating...", "Generating..."
464
+
465
+
466
+
467
+ # Final output
468
+ total_time = time.time() - start_time
469
+ final_generated_ids = x_t[0, prompt_length:]
470
+ eos_positions = (final_generated_ids == self.tokenizer.eos_token_id).nonzero()
471
+ if eos_positions.numel() > 0:
472
+ final_generated_ids = final_generated_ids[:eos_positions[0, 0] + 1]
473
+
474
+ final_text = self.tokenizer.decode(final_generated_ids, skip_special_tokens=True)
475
+ final_viz_html = self._render_visualization_html(step, x_t, block_states, cache_length, set())
476
+ final_status_html = self._render_status_html(step, block_states, cache_length)
477
+ captured_frames.append((final_viz_html, final_status_html))
478
+
479
+ tokens_incl_eos = len(final_generated_ids)
480
+ tokens_excl_eos = len(final_generated_ids[final_generated_ids != self.tokenizer.eos_token_id])
481
+ stats_text = f"""
482
+ ### ✅ Generation Complete!
483
+ ---
484
+ - **Total time:** `{total_time:.2f} seconds`
485
+ - **Tokens generated (incl. EOS):** `{tokens_incl_eos}`
486
+ - **Tokens generated (excl. EOS):** `{tokens_excl_eos}`
487
+ - **Tokens per second:** `{(tokens_incl_eos / total_time):.2f}`
488
+ """
489
+
490
+ yield final_text, captured_frames, stats_text, "Generation complete, playback starting soon", "Generation complete, playback starting soon"
491
+
492
+
493
# --- Gradio UI and Event Handlers ---
if __name__ == "__main__":
    # NOTE(review): this runs after `import torch`; CUDA_VISIBLE_DEVICES only
    # takes effect if CUDA has not been initialized yet — confirm ordering.
    os.environ["CUDA_VISIBLE_DEVICES"] = "3"
    config = {
        "pretrained_path": "GSAI-ML/LLaDA-8B-Instruct",
        "lora_path": "SJTU-Deng-Lab/D2F_LLaDA_Instruct_8B_Lora",
        "device": "cuda", "dtype": "bfloat16", "max_length": 4096,
        "temperature": 0.0, "top_p": None, "top_k": None, "mask_token_id": 126336,
        "sampling_strategy": "default",
    }
    set_seed(42)
    inference_engine = DreamLoRAInference(**config)

    # Gradio helper: replay the captured (viz, status) HTML frame pairs with a
    # fixed delay between frames.
    def animate_visualization(html_frames_list: List[Tuple[str, str]], delay: float) -> Iterator[Tuple[str, str]]:
        if not html_frames_list:
            yield "No visualization data captured", "No status data captured"
            return
        for viz_frame, status_frame in html_frames_list:
            yield viz_frame, status_frame
            time.sleep(delay)

    # Global auto-scroll JS injected once into the page.
    auto_scroll_js = """
    <script>
    function globalForceScroll() {
        // Scroll visualization container
        var vizContainer = document.getElementById('viz-container');
        if (vizContainer) {
            vizContainer.scrollTop = vizContainer.scrollHeight;
        }

        // Scroll status container
        var statusContainer = document.getElementById('status-container');
        if (statusContainer) {
            statusContainer.scrollTop = statusContainer.scrollHeight;
        }

        // Scroll all anchors
        var anchors = document.querySelectorAll('.scroll-anchor');
        anchors.forEach(function(anchor) {
            try {
                anchor.scrollIntoView({behavior: 'auto', block: 'end'});
            } catch(e) {}
        });
    }

    // Periodic scrolling
    setInterval(globalForceScroll, 200);

    document.addEventListener('DOMContentLoaded', function() {
        // Monitor content changes
        var observer = new MutationObserver(function(mutations) {
            globalForceScroll();
        });

        observer.observe(document.body, {
            childList: true,
            subtree: true,
            characterData: true
        });

        // Initial scrolling
        setTimeout(globalForceScroll, 100);
        setTimeout(globalForceScroll, 500);
        setTimeout(globalForceScroll, 1000);
    });
    </script>
    """

    with gr.Blocks(css=DreamLoRAInference.CSS, theme=gr.themes.Soft()) as demo:
        html_frames_state = gr.State([])

        gr.Markdown("# ✨ D2F-LLaDA: Real-time Text vs. Slow-motion Visualization")
        gr.Markdown("Left side shows real-time streaming output. Right side plays back the decoding process visualization after generation completes.")

        # Inject global auto-scroll JS
        gr.HTML(auto_scroll_js)

        with gr.Row():
            # --- Left Column ---
            with gr.Column(scale=2, elem_classes=["left-column"]):
                prompt_input = gr.Textbox(label="Enter your question", placeholder="Example: Natalia sold clips to...", lines=5)
                generate_button = gr.Button("🚀 Generate & Visualize", variant="primary")
                with gr.Group(elem_classes=["live-text-container"]):
                    live_text_output = gr.Textbox(label="Real-time Generation Output", interactive=False, lines=25, elem_id="live-text-output")
                    stats_output = gr.Markdown(label="Generation Statistics", elem_id="stats-output")

            # --- Right Column ---
            with gr.Column(scale=3, elem_classes=["right-column"]):
                with gr.Accordion("⚙️ Parameter Settings", open=True):
                    with gr.Row():
                        max_new_tokens_slider = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, label="Max Tokens to Generate")
                        block_size_slider = gr.Slider(minimum=16, maximum=128, value=32, step=16, label="Block Size")
                    with gr.Row():
                        block_add_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Block Add Threshold")
                        decoded_token_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.5, step=0.05, label="Decoding Completion Threshold")
                        skip_thresh_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.9, step=0.01, label="Skip Threshold")
                    delay_slider = gr.Slider(minimum=0.0, maximum=1.0, value=0.1, step=0.05, label="Playback Delay (seconds)", info="Adjust visualization playback speed.")

                with gr.Group(elem_classes=["viz-status-container"]):
                    visualization_output = gr.HTML(label="Generation Process Visualization", elem_id="viz-container")
                    status_output_html = gr.HTML(label="Generation Block Status", elem_id="status-container")

        gr.Examples(
            examples=[
                ["Solve the equation x² - 6x + 8 = 0. First, explain what a quadratic equation is and why it can have up to two solutions. Then solve this equation using three different methods: factoring, completing the square, and the quadratic formula. For each method, explain the mathematical reasoning behind it, show all steps in detail, and discuss when this particular method is most useful. Finally, verify your solutions by substituting them back into the original equation.", 1024, 32, 0.1, 0.55, 0.9, 0.1],

                ["A circular swimming pool has a diameter of 8 meters. Calculate the pool's circumference and area. First, explain the relationship between diameter, radius, circumference, and area of a circle, including the role of π in these formulas. Then perform the calculations using π ≈ 3.14159. Next, estimate how much water (in cubic meters) would be needed to fill this pool if it has a uniform depth of 1.5 meters. Finally, calculate how much it would cost to fill this pool if water costs $2.50 per cubic meter. Show all steps and include appropriate units in your answer.", 1024, 32, 0.1, 0.5, 0.9, 0.1],

                ["A movie theater offers a loyalty card that costs $15 and gives a 15% discount on all tickets. If a regular movie ticket costs $10, how many tickets would you need to buy to make the loyalty card worthwhile? First, explain the concept of a break-even point. Then set up an equation to find when the total cost with the card equals the total cost without the card. Solve this equation step by step, showing all your work. Finally, interpret your answer in the context of the problem.", 1024, 32, 0.1, 0.5, 0.9, 0.1],
            ],
            inputs=[
                prompt_input, max_new_tokens_slider, block_size_slider, block_add_thresh_slider,
                decoded_token_thresh_slider, skip_thresh_slider, delay_slider
            ],
            label="Examples (Math Problems)"
        )

        # --- Event Handling Chain ---
        inputs_list = [
            prompt_input, max_new_tokens_slider, block_size_slider,
            block_add_thresh_slider, decoded_token_thresh_slider, skip_thresh_slider
        ]
        # (Removed a leftover `ipdb.set_trace()` breakpoint here — it blocked
        # the process before the UI could even be constructed.)
        generation_event = generate_button.click(
            fn=inference_engine.stream_and_capture_for_gradio,
            inputs=inputs_list,
            outputs=[live_text_output, html_frames_state, stats_output, visualization_output, status_output_html]
        )

        # After generation finishes, replay the captured frames as an animation.
        generation_event.then(
            fn=animate_visualization,
            inputs=[html_frames_state, delay_slider],
            outputs=[visualization_output, status_output_html]
        )

    demo.queue().launch(share=True)
model_cache/dream/configuration_dream.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ # coding=utf-8
3
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+ """Dream model configuration"""
17
+
18
+ from transformers.configuration_utils import PretrainedConfig
19
+ from transformers.modeling_rope_utils import rope_config_validation
20
+ from transformers.utils import logging
21
+
22
+
23
+ logger = logging.get_logger(__name__)
24
+
25
+
26
class DreamConfig(PretrainedConfig):
    """Configuration class for the Dream diffusion language model."""

    model_type = "Dream"
    keys_to_ignore_at_inference = ["past_key_values"]

    def __init__(
        self,
        vocab_size=151936,
        hidden_size=4096,
        intermediate_size=22016,
        num_hidden_layers=32,
        num_attention_heads=32,
        num_key_value_heads=32,
        hidden_act="silu",
        max_position_embeddings=32768,
        initializer_range=0.02,
        rms_norm_eps=1e-6,
        use_cache=False,  # cache not used in diffusion
        tie_word_embeddings=False,
        rope_theta=10000.0,
        rope_scaling=None,
        use_sliding_window=False,
        sliding_window=4096,
        max_window_layers=28,
        attention_dropout=0.0,
        mask_token_id=151666,
        pad_token_id=151643,
        **kwargs,
    ):
        self.vocab_size = vocab_size
        self.max_position_embeddings = max_position_embeddings
        self.hidden_size = hidden_size
        self.intermediate_size = intermediate_size
        self.num_hidden_layers = num_hidden_layers
        self.num_attention_heads = num_attention_heads
        self.use_sliding_window = use_sliding_window
        # The window size is only meaningful when the feature is enabled.
        self.sliding_window = sliding_window if use_sliding_window else None
        self.max_window_layers = max_window_layers

        # Backward compatibility: older configs may omit the KV-head count,
        # in which case it defaults to the attention-head count (no GQA).
        self.num_key_value_heads = (
            num_attention_heads if num_key_value_heads is None else num_key_value_heads
        )
        self.hidden_act = hidden_act
        self.initializer_range = initializer_range
        self.rms_norm_eps = rms_norm_eps
        self.use_cache = use_cache
        self.rope_theta = rope_theta
        self.rope_scaling = rope_scaling
        self.attention_dropout = attention_dropout
        # Validate the rotary-embedding parameters.  BC: configs that still
        # carry a 'type' field get it mirrored into 'rope_type'.
        if self.rope_scaling is not None and "type" in self.rope_scaling:
            self.rope_scaling["rope_type"] = self.rope_scaling["type"]
        rope_config_validation(self)

        super().__init__(
            tie_word_embeddings=tie_word_embeddings,
            **kwargs,
        )
        # Set after super().__init__ so these always win over kwargs defaults.
        self.mask_token_id = mask_token_id
        self.pad_token_id = pad_token_id
88
+
model_cache/dream/generation_utils.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import warnings
17
+ import copy
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, Optional, Tuple, Union
20
+
21
+ import torch
22
+ import torch.distributions as dists
23
+ from torch.nn import functional as F
24
+ from transformers import __version__
25
+ from transformers.generation.configuration_utils import (
26
+ GenerationConfig
27
+ )
28
+ from transformers.utils import (
29
+ ModelOutput,
30
+ is_torchdynamo_compiling,
31
+ logging,
32
+ )
33
+
34
+ logger = logging.get_logger(__name__)
35
+
36
+
37
def top_p_logits(logits, top_p=None):
    """Nucleus (top-p) filtering: mask every logit outside the smallest set of
    tokens whose cumulative probability exceeds `top_p`.

    Masked positions are set to the dtype's minimum so a later softmax gives
    them (effectively) zero probability. Returns a new tensor; `logits` is not
    modified in place.
    """
    sorted_vals, sorted_idx = torch.sort(logits, descending=True)
    cum_probs = F.softmax(sorted_vals, dim=-1).cumsum(dim=-1)
    drop_sorted = cum_probs > top_p
    # Shift right so the token that first crosses the threshold is kept.
    drop_sorted[..., 1:] = drop_sorted[..., :-1].clone()
    drop_sorted[..., 0] = False
    # Scatter the "drop" flags back from sorted order to vocabulary order.
    drop = torch.zeros_like(logits, dtype=torch.bool, device=logits.device)
    drop = drop.scatter_(-1, sorted_idx, drop_sorted)
    return logits.masked_fill(drop, torch.finfo(logits.dtype).min)
49
+
50
def top_k_logits(logits, top_k=None):
    """Keep only the `top_k` largest logits along the last dim; mask the rest
    to the dtype's minimum. `top_k` is clamped to the vocabulary size."""
    k = min(top_k, logits.size(-1))  # safety: never request more than vocab
    kth_value = torch.topk(logits, k)[0][..., -1, None]
    below_kth = logits < kth_value
    return logits.masked_fill(below_kth, torch.finfo(logits.dtype).min)
56
+
57
+
58
def sample_tokens(logits, temperature=0.0, top_p=None, top_k=None, margin_confidence=False, neg_entropy=False):
    """Sample token ids from `logits` and return a per-position confidence.

    Args:
        logits: unnormalized scores over the vocabulary, shape (..., vocab).
        temperature: > 0 enables stochastic (categorical) sampling after
            dividing the logits by the temperature; 0 means greedy argmax.
        top_p: nucleus-filtering threshold, applied when < 1.
        top_k: top-k filtering, applied when not None.
        margin_confidence: if True, confidence is p(top1) - p(top2)
            instead of the sampled token's probability.
        neg_entropy: if True, confidence is the negative entropy of the
            distribution (always <= 0; higher means more peaked).

    Returns:
        (confidence, x0): confidence scores and sampled token ids, both with
        the leading shape of `logits` minus the vocab dim.
    """
    if temperature > 0:
        logits = logits / temperature
    if top_p is not None and top_p < 1:
        logits = top_p_logits(logits, top_p)
    if top_k is not None:
        logits = top_k_logits(logits, top_k)
    probs = torch.softmax(logits, dim=-1)

    if temperature > 0:
        try:
            x0 = dists.Categorical(probs=probs).sample()
            confidence = torch.gather(probs, -1, x0.unsqueeze(-1)).squeeze(-1)
        except Exception:
            # Fix: was a bare `except:` which also swallowed KeyboardInterrupt/
            # SystemExit. Keep the greedy fallback (used when the filtered
            # distribution is degenerate, e.g. contains NaN/all-zero rows).
            confidence, x0 = probs.max(dim=-1)
    else:
        confidence, x0 = probs.max(dim=-1)

    if margin_confidence:
        sorted_probs, _ = torch.sort(probs, dim=-1, descending=True)
        # Confidence as the gap between the two most likely tokens.
        # Generalized from `[:, 0]` to `[..., 0]` so inputs need not be 2-D
        # (identical result for the original 2-D case).
        top1_probs = sorted_probs[..., 0]
        top2_probs = sorted_probs[..., 1]
        confidence = top1_probs - top2_probs

    if neg_entropy:
        epsilon = 1e-10  # guards log(0)
        log_probs = torch.log(probs + epsilon)
        confidence = torch.sum(probs * log_probs, dim=-1)

    return confidence, x0
91
+
92
+
93
@dataclass
class DreamModelOutput(ModelOutput):
    """Return container for `diffusion_generate` when `return_dict_in_generate`
    is set.

    - sequences: final denoised token ids, shape (batch, max_length).
    - history: per-step snapshots of the sequence, only populated when
      `output_history` is enabled in the generation config.
    """
    sequences: torch.LongTensor = None
    history: Optional[Tuple[torch.FloatTensor]] = None
97
+
98
+
99
class DreamGenerationConfig(GenerationConfig):
    """Generation configuration for Dream's diffusion decoding.

    Subclasses `transformers.GenerationConfig` so it interoperates with the HF
    save/load machinery, but re-declares only the knobs the diffusion sampler
    uses and adds diffusion-specific parameters (`steps`, `eps`, `alg`,
    `alg_temp`) plus a `mask_token_id` special token.

    NOTE(review): the parent `__init__` is deliberately not called; attributes
    are set directly and `validate` is overridden as a no-op below.
    """

    def __init__(self, **kwargs):
        # Sampling parameters shared with standard HF generation.
        self.temperature: float = kwargs.pop("temperature", 0.0)
        self.top_p: Optional[float] = kwargs.pop("top_p", None)
        self.top_k: Optional[int] = kwargs.pop("top_k", None)
        self.max_length = kwargs.pop("max_length", 20)
        self.max_new_tokens = kwargs.pop("max_new_tokens", None)
        # diffusion specific params
        self.eps: float = kwargs.pop("eps", 1e-3)  # final timestep of the linspace schedule
        self.steps: int = kwargs.pop("steps", 512)  # number of denoising iterations
        self.alg: str = kwargs.pop("alg", 'origin')  # 'origin' | 'maskgit_plus' | 'topk_margin' | 'entropy'
        self.alg_temp: Optional[float] = kwargs.pop("alg_temp", None)  # softmax temperature for confidence-based unmask order

        # Parameters that define the output variables of `generate`
        self.num_return_sequences: int = kwargs.pop("num_return_sequences", 1)
        self.return_dict_in_generate: bool = kwargs.pop("return_dict_in_generate", False)
        self.output_history: bool = kwargs.pop("output_history", False)

        # Special tokens that can be used at generation time
        self.mask_token_id = kwargs.pop("mask_token_id", None)
        self.pad_token_id = kwargs.pop("pad_token_id", None)
        self.bos_token_id = kwargs.pop("bos_token_id", None)
        self.eos_token_id = kwargs.pop("eos_token_id", None)

        # Wild card
        self.generation_kwargs = kwargs.pop("generation_kwargs", {})

        # The remaining attributes do not parametrize `.generate()`, but are informative and/or used by the hub
        # interface.
        self._from_model_config = kwargs.pop("_from_model_config", False)
        self._commit_hash = kwargs.pop("_commit_hash", None)
        self.transformers_version = kwargs.pop("transformers_version", __version__)

        # Additional attributes without default values
        if not self._from_model_config:
            # we don't want to copy values from the model config if we're initializing a `GenerationConfig` from a
            # model's default configuration file
            for key, value in kwargs.items():
                try:
                    setattr(self, key, value)
                except AttributeError as err:
                    logger.error(f"Can't set {key} with value {value} for {self}")
                    raise err

        # Validate the values of the attributes
        self.validate(is_init=True)

    def validate(self, is_init=False):
        # NOTE(review): intentionally a no-op — presumably the parent class's
        # validation rejects diffusion-specific settings (e.g. temperature=0.0
        # with sampling knobs set); confirm before re-enabling it.
        pass
148
+
149
class DreamGenerationMixin:
    """Mixin adding diffusion-based decoding (`diffusion_generate`) to a Dream model.

    The hosting model must be callable as `self(x, attention_mask, tok_idx)`
    returning an object with `.logits`, and must expose `self.config`,
    `self.generation_config` and `self.device` (standard `PreTrainedModel`
    attributes).
    """

    @staticmethod
    def _expand_inputs_for_generation(
        expand_size: int = 1,
        input_ids: Optional[torch.LongTensor] = None,
        attention_mask: Optional[torch.LongTensor] = None
    ) -> Tuple[torch.LongTensor, Dict[str, Any]]:
        """Expands tensors from [batch_size, ...] to [batch_size * expand_size, ...]"""
        # Do not call torch.repeat_interleave if expand_size is 1 because it clones
        # the input tensor and thus requires more memory although no change is applied
        if expand_size == 1:
            return input_ids, attention_mask
        if input_ids is not None:
            input_ids = input_ids.repeat_interleave(expand_size, dim=0)
        if attention_mask is not None:
            attention_mask = attention_mask.repeat_interleave(expand_size, dim=0)
        return input_ids, attention_mask

    def _validate_generated_length(self, generation_config, input_ids_length, has_default_max_length):
        """Performs validation related to the resulting generated length"""

        # Can't throw warnings/exceptions during compilation
        if is_torchdynamo_compiling():
            return

        # 1. Max length warnings related to poor parameterization
        if has_default_max_length and generation_config.max_new_tokens is None and generation_config.max_length == 20:
            # 20 is the default max_length of the generation config
            warnings.warn(
                f"Using the model-agnostic default `max_length` (={generation_config.max_length}) to control the "
                "generation length. We recommend setting `max_new_tokens` to control the maximum length of the "
                "generation.",
                UserWarning,
            )
        if input_ids_length >= generation_config.max_length:
            input_ids_string = "input_ids"
            raise ValueError(
                f"Input length of {input_ids_string} is {input_ids_length}, but `max_length` is set to"
                f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
                " increasing `max_length` or, better yet, setting `max_new_tokens`."
            )

    def _prepare_generated_length(
        self,
        generation_config,
        has_default_max_length,
        input_ids_length,
    ):
        """Prepared max and min length in generation configs to avoid clashes between similar attributes"""

        if generation_config.max_new_tokens is not None:
            if not has_default_max_length and generation_config.max_length is not None:
                logger.warning(
                    f"Both `max_new_tokens` (={generation_config.max_new_tokens}) and `max_length`(="
                    f"{generation_config.max_length}) seem to have been set. `max_new_tokens` will take precedence. "
                    "Please refer to the documentation for more information. "
                    "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
                )
            generation_config.max_length = generation_config.max_new_tokens + input_ids_length

        elif has_default_max_length:
            if generation_config.max_length == DreamGenerationConfig().max_length:
                # Default max_length counts *new* tokens, so offset by the prompt
                # length, then clamp to the model's positional capacity.
                generation_config.max_length = generation_config.max_length + input_ids_length
                max_position_embeddings = getattr(self.config, "max_position_embeddings", None)
                if max_position_embeddings is not None:
                    generation_config.max_length = min(generation_config.max_length, max_position_embeddings)

        return generation_config

    def _prepare_generation_config(
        self, generation_config: Optional[DreamGenerationConfig], **kwargs: Dict
    ) -> DreamGenerationConfig:
        """
        Prepares the base generation config, then applies any generation configuration options from kwargs. This
        function handles retrocompatibility with respect to configuration files.
        """
        # priority: `generation_config` argument > `model.generation_config` (the default generation config)
        using_model_generation_config = False
        if generation_config is None:
            generation_config = DreamGenerationConfig.from_model_config(self.config)
            using_model_generation_config = True

        # `torch.compile` can't compile `copy.deepcopy`, arguments in `kwargs` that are part of `generation_config`
        # will mutate the object with `.update`. As such, passing these arguments through `kwargs` is disabled -- an
        # exception will be raised in `_validate_model_kwargs`
        if not is_torchdynamo_compiling():
            generation_config = copy.deepcopy(generation_config)
            _kwargs = generation_config.update(**kwargs)
            # If `generation_config` is provided, let's fallback ALL special tokens to the default values for the model
            if not using_model_generation_config:
                if generation_config.bos_token_id is None:
                    generation_config.bos_token_id = self.generation_config.bos_token_id
                if generation_config.eos_token_id is None:
                    generation_config.eos_token_id = self.generation_config.eos_token_id
                if generation_config.pad_token_id is None:
                    generation_config.pad_token_id = self.generation_config.pad_token_id
                if generation_config.mask_token_id is None:
                    generation_config.mask_token_id = self.generation_config.mask_token_id

        return generation_config

    def _prepare_special_tokens(
        self,
        generation_config: DreamGenerationConfig,
        device: Optional[Union[torch.device, str]] = None,
    ):
        """
        Prepares the special tokens for generation, overwriting the generation config with their processed versions
        converted to tensor.
        Note that `generation_config` is changed in place and stops being serializable after this method is called.
        That is no problem if called within `generate` (`generation_config` is a local copy that doesn't leave the
        function). However, if called outside `generate`, consider creating a copy of `generation_config` first.
        """

        # Convert special tokens to tensors
        def _tensor_or_none(token, device=None):
            if token is None:
                return token

            device = device if device is not None else self.device
            if isinstance(token, torch.Tensor):
                return token.to(device)
            return torch.tensor(token, device=device, dtype=torch.long)

        bos_token_tensor = _tensor_or_none(generation_config.bos_token_id, device=device)
        eos_token_tensor = _tensor_or_none(generation_config.eos_token_id, device=device)
        pad_token_tensor = _tensor_or_none(generation_config.pad_token_id, device=device)
        mask_token_tensor = _tensor_or_none(generation_config.mask_token_id, device=device)

        # We can have more than one eos token. Always treat it as a 1D tensor (when it exists).
        if eos_token_tensor is not None and eos_token_tensor.ndim == 0:
            eos_token_tensor = eos_token_tensor.unsqueeze(0)

        # Set pad token if unset (and there are conditions to do so)
        if pad_token_tensor is None and eos_token_tensor is not None:
            pad_token_tensor = eos_token_tensor[0]
            logger.warning(f"Setting `pad_token_id` to `eos_token_id`:{pad_token_tensor} for open-end generation.")

        # Update generation config with the updated special tokens tensors
        # NOTE: this must be written into a different attribute name than the one holding the original special tokens
        # (in their non-tensor form), in order to enable end-to-end compilation. See
        # https://pytorch.org/docs/stable/torch.compiler_cudagraph_trees.html#limitations
        generation_config._bos_token_tensor = bos_token_tensor
        generation_config._eos_token_tensor = eos_token_tensor
        generation_config._pad_token_tensor = pad_token_tensor
        generation_config._mask_token_tensor = mask_token_tensor

    @torch.no_grad()
    def diffusion_generate(
        self,
        inputs: Optional[torch.Tensor] = None,
        generation_config: Optional[DreamGenerationConfig] = None,
        **kwargs,
    ) -> Union[DreamModelOutput, torch.LongTensor]:
        """Generate sequences by iterative mask-denoising (diffusion decoding).

        `inputs` is the prompt token ids; the sequence is padded to
        `max_length` with `mask_token_id` and progressively unmasked by
        `_sample`. Optional `generation_tokens_hook_func` /
        `generation_logits_hook_func` kwargs let callers intervene at each
        step.
        """
        # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
        generation_config = self._prepare_generation_config(generation_config, **kwargs)
        generation_tokens_hook_func = kwargs.pop("generation_tokens_hook_func", lambda step, x, logits: x)
        generation_logits_hook_func = kwargs.pop("generation_logits_hook_func", lambda step, x, logits: logits)

        # 2. Define model inputs
        assert inputs is not None
        input_ids = inputs
        device = input_ids.device
        attention_mask = kwargs.pop("attention_mask", None)
        self._prepare_special_tokens(generation_config, device=device)

        # 3. Prepare `max_length`.
        input_ids_length = input_ids.shape[-1]
        has_default_max_length = kwargs.get("max_length") is None and generation_config.max_length is not None
        generation_config = self._prepare_generated_length(
            generation_config=generation_config,
            has_default_max_length=has_default_max_length,
            input_ids_length=input_ids_length,
        )

        self._validate_generated_length(generation_config, input_ids_length, has_default_max_length)

        # 4. Check input_ids
        if not is_torchdynamo_compiling() and self.device.type != input_ids.device.type:
            warnings.warn(
                "You are calling .generate() with the `input_ids` being on a device type different"
                f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
                f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
                " Please make sure that you have put `input_ids` to the"
                f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
                " running `.generate()`.",
                UserWarning,
            )
        if (
            hasattr(generation_config, "pad_token_id") and
            torch.any(input_ids == generation_config.pad_token_id) and
            attention_mask is None
        ):
            warnings.warn(
                "Padding was detected but no attention mask is passed here. For correct "
                "generation results, please set `attention_mask` when batch-padding inputs.",
                UserWarning,
            )

        input_ids, attention_mask = self._expand_inputs_for_generation(
            expand_size=generation_config.num_return_sequences,
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        result = self._sample(
            input_ids,
            attention_mask=attention_mask,
            generation_config=generation_config,
            generation_tokens_hook_func=generation_tokens_hook_func,
            generation_logits_hook_func=generation_logits_hook_func
        )
        return result

    def _sample(
        self,
        input_ids: torch.LongTensor,
        attention_mask: Optional[torch.LongTensor],
        generation_config: DreamGenerationConfig,
        generation_tokens_hook_func,
        generation_logits_hook_func
    ) -> Union[DreamModelOutput, torch.LongTensor]:
        """Core denoising loop: run `steps` iterations, each time scoring the
        still-masked positions and committing a subset of them to sampled
        tokens according to `alg`.
        """
        # init values
        output_history = generation_config.output_history
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        mask_token_id = generation_config.mask_token_id
        steps = generation_config.steps
        eps = generation_config.eps
        alg = generation_config.alg
        alg_temp = generation_config.alg_temp
        temperature = generation_config.temperature
        top_p = generation_config.top_p
        top_k = generation_config.top_k

        histories = [] if (return_dict_in_generate and output_history) else None

        # pad input_ids to max_length
        x = F.pad(input_ids, (0, max_length - input_ids.shape[1]), value=mask_token_id)

        if attention_mask is not None and torch.any(attention_mask == 0.0):
            # we do not mask the [MASK] tokens so value = 1.0
            attention_mask = F.pad(attention_mask, (0, max_length - attention_mask.shape[1]), value=1.0)
            # Per-token position index that skips padding (padding gets a dummy 1).
            tok_idx = attention_mask.long().cumsum(-1) - 1
            tok_idx.masked_fill_(attention_mask == 0, 1)
            # attention_mask is of shape [B, N]
            # broadcast to [B, 1, N, N]
            attention_mask = torch.logical_and(
                attention_mask.unsqueeze(1).unsqueeze(-2),
                attention_mask.unsqueeze(1).unsqueeze(-1),
            )
        else:
            tok_idx = None
            attention_mask = "full"

        # Schedule goes from t=1 down to t=eps over `steps` intervals.
        timesteps = torch.linspace(1, eps, steps + 1, device=x.device)

        # this allows user-defined token control of the intermediate steps
        x = generation_tokens_hook_func(None, x, None)
        for i in range(steps):
            mask_index = (x == mask_token_id)
            logits = self(x, attention_mask, tok_idx).logits
            # Shift logits one position right so index i scores the token at
            # position i. NOTE(review): presumably matches a next-token-style
            # training shift — confirm against the training code.
            logits = torch.cat([logits[:,:1], logits[:, :-1]], dim=1)

            # this allows user-defined logits control of the intermediate steps
            logits = generation_logits_hook_func(i, x, logits)

            mask_logits = logits[mask_index]
            t = timesteps[i]
            s = timesteps[i + 1]

            if alg == 'origin':
                # Each masked position is independently unmasked with
                # probability 1 - s/t (everything on the last step).
                p_transfer = 1 - s / t if i < steps - 1 else 1
                x0 = torch.zeros_like(x[mask_index], device=self.device, dtype=torch.long) + mask_token_id
                transfer_index_t_s = torch.rand(*x0.shape, device=self.device) < p_transfer
                _, x0[transfer_index_t_s]= sample_tokens(mask_logits[transfer_index_t_s], temperature=temperature, top_p=top_p, top_k=top_k)
                x[mask_index] = x0.clone()
            else:
                # Confidence-ordered unmasking: score every masked position,
                # then commit the top-`number_transfer_tokens` of them.
                if alg == 'maskgit_plus':
                    confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k)
                elif alg == 'topk_margin':
                    confidence, x0 = sample_tokens(mask_logits, temperature=temperature, top_p=top_p, top_k=top_k, margin_confidence=True)
                elif alg == 'entropy':
                    confidence, x0 = sample_tokens(mask_logits, temperature, top_p=top_p, top_k=top_k, neg_entropy=True)
                else:
                    raise RuntimeError(f"Unknown alg: {alg}")
                # Average masked count per batch row; the schedule decides how
                # many of them to commit this step.
                num_mask_token = mask_index.sum() / mask_index.shape[0]
                number_transfer_tokens = int(num_mask_token * (1 - s / t)) if i < steps - 1 else int(num_mask_token)
                full_confidence = torch.full_like(x, -torch.inf, device=self.device, dtype=logits.dtype)
                full_confidence[mask_index] = confidence
                if number_transfer_tokens > 0:
                    if alg_temp is None or alg_temp == 0:
                        # Deterministic: unmask the highest-confidence positions.
                        _, transfer_index = torch.topk(full_confidence, number_transfer_tokens)
                    else:
                        # Stochastic: sample unmask positions ~ softmax(confidence / alg_temp).
                        full_confidence = full_confidence / alg_temp
                        full_confidence = F.softmax(full_confidence, dim=-1)
                        transfer_index = torch.multinomial(full_confidence, num_samples=number_transfer_tokens)
                    x_ = torch.zeros_like(x, device=self.device, dtype=torch.long) + mask_token_id
                    x_[mask_index] = x0.clone()
                    row_indices = torch.arange(x.size(0), device=self.device).unsqueeze(1).expand_as(transfer_index)
                    x[row_indices,transfer_index] = x_[row_indices,transfer_index]

            # this allows user-defined token control of the intermediate steps
            x = generation_tokens_hook_func(i, x, logits)

            if histories is not None:
                histories.append(x.clone())

        if return_dict_in_generate:
            return DreamModelOutput(
                sequences=x,
                history=histories,
            )
        else:
            return x
model_cache/dream/model_dream.py ADDED
@@ -0,0 +1,1029 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NOTE(review): removed an accidental paste of Hugging Face model-page chrome
+ # (site navigation, file listing, "36.8 kB" metadata) that had been copied in
+ # as comments. Original file: Dream-org/Dream-v0-Instruct-7B, modeling_dream.py.
56
+ # # coding=utf-8
57
+ # Copyright 2024 The Dream team, HKUNLP Group and the HuggingFace Inc. team. All rights reserved.
58
+ #
59
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
60
+ # and OPT and Qwen implementations in this library. It has been modified from its
61
+ # original forms to accommodate minor architectural differences compared
62
+ # to GPT-NeoX and OPT and Qwen used by the Meta AI and Qwen team that trained the model.
63
+ #
64
+ # Licensed under the Apache License, Version 2.0 (the "License");
65
+ # you may not use this file except in compliance with the License.
66
+ # You may obtain a copy of the License at
67
+ #
68
+ # http://www.apache.org/licenses/LICENSE-2.0
69
+ #
70
+ # Unless required by applicable law or agreed to in writing, software
71
+ # distributed under the License is distributed on an "AS IS" BASIS,
72
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
73
+ # See the License for the specific language governing permissions and
74
+ # limitations under the License.
75
+ """PyTorch Dream model."""
76
+ from transformers import Qwen2Model
77
+ from torch.nn.attention.flex_attention import flex_attention
78
+ import math
79
+ from typing import List, Optional, Tuple, Union
80
+ import os
81
+ import torch
82
+ import torch.utils.checkpoint
83
+ from torch import nn
84
+
85
+ from transformers.activations import ACT2FN
86
+ from transformers.cache_utils import Cache, DynamicCache
87
+ from transformers.modeling_outputs import (
88
+ BaseModelOutput,
89
+ MaskedLMOutput,
90
+ BaseModelOutputWithPast,
91
+ CausalLMOutputWithPast
92
+ )
93
+ from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS
94
+ from transformers.modeling_utils import PreTrainedModel
95
+ from transformers.utils import (
96
+ add_start_docstrings,
97
+ add_start_docstrings_to_model_forward,
98
+ is_flash_attn_2_available,
99
+ is_flash_attn_greater_or_equal_2_10,
100
+ logging,
101
+ )
102
+ from transformers import PretrainedConfig
103
+ from model_cache.dream.configuration_dream import DreamConfig
104
+ from model_cache.dream.generation_utils import DreamGenerationMixin, DreamGenerationConfig
105
+ if is_flash_attn_2_available():
106
+ from transformers.modeling_flash_attention_utils import _flash_attention_forward
107
+
108
+
109
+ logger = logging.get_logger(__name__)
110
+
111
+ from transformers import Qwen2ForCausalLM
112
+ _CHECKPOINT_FOR_DOC = "Dream-7B"
113
+ _CONFIG_FOR_DOC = "DreamConfig"
114
+
115
+
116
+ # Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Dream
117
class DreamRMSNorm(nn.Module):
    """Root-mean-square layer norm (T5-style): scales by 1/RMS of the last
    dim and a learned per-feature weight; no mean subtraction, no bias."""

    def __init__(self, hidden_size, eps=1e-6):
        """
        DreamRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        # Compute in float32 for numerical stability, cast back at the end.
        orig_dtype = hidden_states.dtype
        hs = hidden_states.to(torch.float32)
        mean_square = hs.pow(2).mean(dim=-1, keepdim=True)
        normalized = hs * torch.rsqrt(mean_square + self.variance_epsilon)
        return self.weight * normalized.to(orig_dtype)

    def extra_repr(self):
        return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
135
+
136
+
137
+ # Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Dream
138
class DreamRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) table generator.

    Produces (cos, sin) tensors for the given position ids; supports the HF
    rope-scaling variants via `ROPE_INIT_FUNCTIONS` (including dynamic types
    that recompute `inv_freq` as sequences grow).
    """

    def __init__(
        self,
        dim=None,
        max_position_embeddings=2048,
        base=10000,
        device=None,
        scaling_factor=1.0,
        rope_type="default",
        config: Optional[DreamConfig] = None,
    ):
        super().__init__()
        # TODO (joao): remove the `if` below, only used for BC
        self.rope_kwargs = {}
        if config is None:
            # Legacy path: individual args instead of a config object.
            logger.warning_once(
                "`DreamRotaryEmbedding` can now be fully parameterized by passing the model config through the "
                "`config` argument. All other arguments will be removed in v4.46"
            )
            self.rope_kwargs = {
                "rope_type": rope_type,
                "factor": scaling_factor,
                "dim": dim,
                "base": base,
                "max_position_embeddings": max_position_embeddings,
            }
            self.rope_type = rope_type
            self.max_seq_len_cached = max_position_embeddings
            self.original_max_seq_len = max_position_embeddings
        else:
            # BC: "rope_type" was originally "type"
            if config.rope_scaling is not None:
                self.rope_type = config.rope_scaling.get("rope_type", config.rope_scaling.get("type"))
            else:
                self.rope_type = "default"
            self.max_seq_len_cached = config.max_position_embeddings
            self.original_max_seq_len = config.max_position_embeddings

        self.config = config
        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]

        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq

    def reset_parameters(self):
        """Recompute `inv_freq` on the buffer's current device (e.g. after a
        device move or re-initialization)."""
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, self.inv_freq.device, **self.rope_kwargs)
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.original_inv_freq = self.inv_freq


    def _dynamic_frequency_update(self, position_ids, device):
        """
        dynamic RoPE layers should recompute `inv_freq` in the following situations:
        1 - growing beyond the cached sequence length (allow scaling)
        2 - the current sequence length is in the original scale (avoid losing precision with small sequences)
        """
        seq_len = torch.max(position_ids) + 1
        if seq_len > self.max_seq_len_cached:  # growth
            inv_freq, self.attention_scaling = self.rope_init_fn(
                self.config, device, seq_len=seq_len, **self.rope_kwargs
            )
            self.register_buffer("inv_freq", inv_freq, persistent=False)  # TODO joao: may break with compilation
            self.max_seq_len_cached = seq_len

        if seq_len < self.original_max_seq_len and self.max_seq_len_cached > self.original_max_seq_len:  # reset
            self.register_buffer("inv_freq", self.original_inv_freq, persistent=False)
            self.max_seq_len_cached = self.original_max_seq_len

    @torch.no_grad()
    def forward(self, x, position_ids):
        """Return (cos, sin) of shape (batch, seq, dim), cast to `x.dtype`;
        `x` is only used for its device and dtype."""
        if "dynamic" in self.rope_type:
            self._dynamic_frequency_update(position_ids, device=x.device)

        # Core RoPE block
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
        position_ids_expanded = position_ids[:, None, :].float()
        # Force float32 (see https://github.com/huggingface/transformers/pull/29285)
        device_type = x.device.type
        device_type = device_type if isinstance(device_type, str) and device_type != "mps" else "cpu"
        with torch.autocast(device_type=device_type, enabled=False):
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos()
            sin = emb.sin()

        # Advanced RoPE types (e.g. yarn) apply a post-processing scaling factor, equivalent to scaling attention
        cos = cos * self.attention_scaling
        sin = sin * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
229
+
230
+
231
+ # Copied from transformers.models.llama.modeling_llama.rotate_half
232
def rotate_half(x):
    """Rotates half the hidden dims of the input: (a, b) -> (-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=-1)
237
+
238
+
239
+ # Copied from transformers.models.llama.modeling_llama.apply_rotary_pos_emb
240
def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*): Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            Dimension along which `cos`/`sin` are unsqueezed so they broadcast
            against q/k: use 1 for [batch, heads, seq, head_dim] layouts and 2
            for [batch, seq, heads, head_dim] layouts.

    Returns:
        `tuple(torch.Tensor)`: the rotated query and key tensors.
    """
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)

    def _rotate(t):
        # Inlined rotate_half: map halves (a, b) of the last dim to (-b, a).
        half = t.shape[-1] // 2
        return torch.cat((-t[..., half:], t[..., :half]), dim=-1)

    rotated_q = q * cos + _rotate(q) * sin
    rotated_k = k * cos + _rotate(k) * sin
    return rotated_q, rotated_k
264
+
265
+
266
# Copied from transformers.models.mistral.modeling_mistral.MistralMLP with Mistral->Dream
class DreamMLP(nn.Module):
    """Gated feed-forward block: down_proj(act(gate_proj(x)) * up_proj(x))."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, hidden_state):
        """Project up with gating, then back down to the hidden size."""
        gated = self.act_fn(self.gate_proj(hidden_state))
        return self.down_proj(gated * self.up_proj(hidden_state))
279
+
280
+
281
# Copied from transformers.models.llama.modeling_llama.repeat_kv
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    Expand key/value heads for grouped-query attention.

    Equivalent to torch.repeat_interleave(x, dim=1, repeats=n_rep): the input goes
    from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim).
    """
    if n_rep == 1:
        # Nothing to repeat — return the input unchanged (no copy).
        return hidden_states
    batch, kv_heads, slen, head_dim = hidden_states.shape
    expanded = hidden_states[:, :, None, :, :].expand(batch, kv_heads, n_rep, slen, head_dim)
    return expanded.reshape(batch, kv_heads * n_rep, slen, head_dim)
292
+
293
+
294
class DreamAttention(nn.Module):
    """
    Multi-headed attention from 'Attention Is All You Need' paper. Modified to use sliding window attention: Longformer
    and "Generating Long Sequences with Sparse Transformers".

    Note: `is_causal` is False — no causal mask is implied by this module; any masking
    must be supplied through `attention_mask`.
    """

    def __init__(self, config: DreamConfig, layer_idx: Optional[int] = None):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        if layer_idx is None:
            # Without a layer index, `past_key_value.update` cannot address this
            # layer's slot in the cache.
            # Fixed: the original message dropped the word "lead" ("...and will to errors").
            logger.warning_once(
                f"Instantiating {self.__class__.__name__} without passing `layer_idx` is not recommended and will lead "
                "to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                "when creating this class."
            )

        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.num_key_value_heads = config.num_key_value_heads
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        self.is_causal = False
        self.attention_dropout = config.attention_dropout

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=True)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=True)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.rotary_emb = DreamRotaryEmbedding(config=self.config)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Eager (matmul + fp32 softmax) attention forward pass.

        Returns:
            (attn_output, attn_weights or None, past_key_value).
        """
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, seq, heads*dim) -> (bsz, heads, seq, dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32 for a numerically stable softmax
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value
399
+
400
+
401
class DreamSdpaAttention(DreamAttention):
    """
    Dream attention module using torch.nn.functional.scaled_dot_product_attention. This module inherits from
    `DreamAttention` as the weights of the module stays untouched. The only changes are on the forward pass to adapt to
    SDPA API.
    """

    # Adapted from DreamAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        update_kvcache: torch.int32 = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """SDPA-based forward pass; falls back to eager attention when attention
        weights are requested (SDPA cannot return them).

        Note: `update_kvcache` is accepted for signature parity but unused here.
        """
        if output_attentions:
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )

        batch_size, seq_len, _ = hidden_states.size()

        # Project and reshape to (batch, heads, seq, head_dim) in one chain per tensor.
        query_states = (
            self.q_proj(hidden_states).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(batch_size, seq_len, self.num_key_value_heads, self.head_dim)
            .transpose(1, 2)
        )

        if position_embeddings is not None:
            cos, sin = position_embeddings
        else:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}  # Specific to RoPE models
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs
        # with custom attn_mask; see https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=attention_mask,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=False,  # hard coded: this model attends bidirectionally
        )

        attn_output = attn_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
496
class DreamFlexAttention(DreamAttention):
    """
    SDPA-based Dream attention with block-wise control over the KV cache.

    Inherits its weights from `DreamAttention`; only the forward pass differs.
    The extra `update_kvcache` argument selects how `past_key_value` is used:
      * update_kvcache == 0: cached K/V are read and concatenated in front of the
        current ones, but the cache itself is left untouched;
      * update_kvcache > 0: only the first `update_kvcache` positions of the
        current K/V are committed to the cache; the remaining positions are used
        for this call only.
    """

    # Adapted from DreamAttention.forward
    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        update_kvcache: torch.int32 = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        if output_attentions:
            # SDPA cannot return attention weights; fall back to the eager path.
            # NOTE(review): the eager fallback does not receive `update_kvcache`,
            # so it always performs a plain cache update — confirm this is intended.
            # TODO: Improve this warning with e.g. `model.config.attn_implementation = "manual"` once this is implemented.
            logger.warning_once(
                "DreamModel is using DreamSdpaAttention, but `torch.nn.functional.scaled_dot_product_attention` does not support `output_attentions=True`. Falling back to the manual attention implementation, "
                'but specifying the manual implementation will be required from Transformers version v5.0.0 onwards. This warning can be removed using the argument `attn_implementation="eager"` when loading the model.'
            )
            return super().forward(
                hidden_states=hidden_states,
                attention_mask=attention_mask,
                position_ids=position_ids,
                past_key_value=past_key_value,
                output_attentions=output_attentions,
                use_cache=use_cache,
            )
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        # (bsz, seq, heads*dim) -> (bsz, heads, seq, head_dim)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        if position_embeddings is None:
            logger.warning_once(
                "The attention layers in this model are transitioning from computing the RoPE embeddings internally "
                "through `position_ids` (2D tensor with the indexes of the tokens), to using externally computed "
                "`position_embeddings` (Tuple of tensors, containing cos and sin). In v4.46 `position_ids` will be "
                "removed and `position_embeddings` will be mandatory."
            )
            cos, sin = self.rotary_emb(value_states, position_ids)
        else:
            cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            if update_kvcache == 0:
                # Read-only use of the cache: prepend cached K/V for this layer
                # without mutating the cache.
                past_key_states, past_value_states = past_key_value[self.layer_idx]
                key_states=torch.cat([past_key_states, key_states], dim=2)
                value_states=torch.cat([past_value_states, value_states], dim=2)
                # Specific to RoPE models
            else:
                # Commit only the first `update_kvcache` positions to the cache.
                # NOTE(review): cos/sin are sliced along dim 1, which assumes a
                # (bsz, seq_len, rotary_dim) layout — confirm against the rotary module.
                cache_kwargs = {"sin": sin[:,:update_kvcache,:], "cos": cos[:,:update_kvcache,:], "cache_position": cache_position[:update_kvcache]}
                new_key_states, new_value_states = past_key_value.update(key_states[:,:,:update_kvcache, :], value_states[:,:,:update_kvcache, : ], self.layer_idx, cache_kwargs)
                # Attend over (cache incl. the just-committed prefix) + the
                # uncommitted tail of the current block.
                key_states = torch.cat([new_key_states,key_states[:,:,update_kvcache:,:]], dim=2)
                value_states = torch.cat([new_value_states,value_states[:,:,update_kvcache:,:]], dim=2)
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        if attention_mask is not None:  # no matter the length, we just slice it
            # Clone so any downstream in-place edits cannot corrupt the caller's mask.
            atte_mask = attention_mask[:,:, :, : key_states.shape[-2]].clone()

        # SDPA with memory-efficient backend is currently (torch==2.1.2) bugged with non-contiguous inputs with custom attn_mask,
        # Reference: https://github.com/pytorch/pytorch/issues/112577.
        if query_states.device.type == "cuda" and attention_mask is not None:
            query_states = query_states.contiguous()
            key_states = key_states.contiguous()
            value_states = value_states.contiguous()

        attn_output = torch.nn.functional.scaled_dot_product_attention(
            query_states,
            key_states,
            value_states,
            attn_mask=atte_mask if attention_mask is not None else None,
            dropout_p=self.attention_dropout if self.training else 0.0,
            is_causal=False,  # hard coded: masking comes only from `attention_mask`
        )
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        return attn_output, None, past_key_value
635
+
636
class DreamDecoderLayer(nn.Module):
    """One Dream transformer layer: pre-norm self-attention and a gated MLP,
    each wrapped in a residual connection."""

    def __init__(self, config: DreamConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size

        if config.sliding_window and config._attn_implementation != "flash_attention_2":
            logger.warning_once(
                f"Sliding Window Attention is enabled but not implemented for `{config._attn_implementation}`; "
                "unexpected results may be encountered."
            )

        # self.self_attn = Dream_ATTENTION_CLASSES[config._attn_implementation](config, layer_idx)
        # Hard-wired to the KV-cache-aware SDPA variant regardless of
        # config._attn_implementation.
        self.self_attn = DreamFlexAttention(config, layer_idx)

        self.mlp = DreamMLP(config)
        self.input_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        update_kvcache: torch.int32 = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        cache_position: Optional[torch.LongTensor] = None,
        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,  # will become mandatory in v4.46
        **kwargs,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input of shape `(batch, seq_len, embed_dim)`.
            update_kvcache (`torch.int32`, *optional*): number of leading positions to
                commit to the KV cache (0 = read-only cache use); forwarded to the
                attention module.
            attention_mask (`torch.FloatTensor`, *optional*): additive attention mask.
            output_attentions (`bool`, *optional*): whether to also return attention
                weights.
            use_cache (`bool`, *optional*): whether to also return the (possibly
                updated) key/value cache.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached K/V states.
            cache_position (`torch.LongTensor`, *optional*): positions of the input
                tokens in the sequence.
            position_embeddings (`Tuple[torch.FloatTensor, torch.FloatTensor]`,
                *optional*): precomputed (cos, sin) rotary embeddings.
            kwargs: ignored; accepted for FSDP and similar wrappers.
        """
        # Self-attention sub-block (pre-norm residual).
        attn_out, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            attention_mask=attention_mask,
            update_kvcache=update_kvcache,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
            cache_position=cache_position,
            position_embeddings=position_embeddings,
        )
        hidden_states = hidden_states + attn_out

        # Feed-forward sub-block (pre-norm residual).
        hidden_states = hidden_states + self.mlp(self.post_attention_layernorm(hidden_states))

        outputs = (hidden_states,)
        if output_attentions:
            outputs = outputs + (self_attn_weights,)
        if use_cache:
            outputs = outputs + (present_key_value,)
        return outputs
722
+
723
class DreamPreTrainedModel(PreTrainedModel):
    """Base class hooking DreamConfig into the HF PreTrainedModel machinery
    (weight init, cache/SDPA capability flags, checkpoint loading)."""

    config_class = DreamConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["DreamDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _supports_flash_attn_2 = True
    _supports_sdpa = True
    _supports_cache_class = True
    _supports_quantized_cache = True
    _supports_static_cache = True

    def _init_weights(self, module):
        """Normal(0, initializer_range) init for Linear/Embedding weights;
        zero biases and zero the padding embedding row."""
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Optional[Union[str, os.PathLike]],
        *model_args,
        config: Optional[Union[PretrainedConfig, str, os.PathLike]] = None,
        cache_dir: Optional[Union[str, os.PathLike]] = None,
        ignore_mismatched_sizes: bool = False,
        force_download: bool = False,
        local_files_only: bool = False,
        token: Optional[Union[str, bool]] = None,
        revision: str = "main",
        use_safetensors: Optional[bool] = None,
        weights_only: bool = True,
        **kwargs,
    ):
        """Load weights as usual, then re-load the generation config through
        `DreamGenerationConfig` so its extra attributes survive."""
        model = super().from_pretrained(
            pretrained_model_name_or_path,
            *model_args,
            config=config,
            cache_dir=cache_dir,
            ignore_mismatched_sizes=ignore_mismatched_sizes,
            force_download=force_download,
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            use_safetensors=use_safetensors,
            weights_only=weights_only,
            **kwargs,
        )
        # NOTE(Lin): we need to override the generation config because the one
        # loaded in `from_pretrained` does not include all the attributes of
        # DreamGenerationConfig.
        model.generation_config = DreamGenerationConfig.from_pretrained(
            pretrained_model_name_or_path,
            cache_dir=cache_dir,
            force_download=force_download,
            resume_download=kwargs.get("resume_download", None),
            proxies=kwargs.get("proxies", None),
            local_files_only=local_files_only,
            token=token,
            revision=revision,
            subfolder=kwargs.get("subfolder", ""),
            _from_auto=kwargs.get("_from_auto", False),
            _from_pipeline=kwargs.get("_from_pipeline", None),
        )
        return model
798
+
799
class DreamBaseModel(DreamPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DreamDecoderLayer`]
    Args:
        config: DreamConfig
    """

    def __init__(self, config: DreamConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [DreamDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self._attn_implementation = config._attn_implementation
        self.norm = DreamRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # Shared rotary embedding: (cos, sin) are computed once per forward and
        # passed down to every layer.
        self.rotary_emb = DreamRotaryEmbedding(config=config)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        update_kvcache: torch.int32 = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the embedding layer, all decoder layers, and the final norm.

        `update_kvcache` is forwarded to every layer and controls how much of
        the current block is committed to the KV cache (see DreamFlexAttention).
        Returns a `BaseModelOutputWithPast` (or a tuple when `return_dict=False`).
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        if use_cache and past_key_values is None:
            past_key_values = DynamicCache()

        if cache_position is None:
            past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        hidden_states = inputs_embeds

        # create position embeddings to be shared across the decoder layers
        position_embeddings = self.rotary_emb(hidden_states, position_ids)

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if self.gradient_checkpointing and self.training:
                # BUGFIX: `DreamDecoderLayer.forward` takes `update_kvcache` as its
                # second positional parameter. The original call omitted that slot,
                # shifting every following argument by one (attention_mask landed in
                # update_kvcache, position_ids in attention_mask, ...). Pass None
                # explicitly — caching is disabled under gradient checkpointing.
                layer_outputs = self._gradient_checkpointing_func(
                    decoder_layer.__call__,
                    hidden_states,
                    None,  # update_kvcache
                    attention_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                    position_embeddings,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    update_kvcache=update_kvcache,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                    position_embeddings=position_embeddings,
                )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values if use_cache else None,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )
936
+
937
+
938
+ class DreamModel(DreamGenerationMixin, DreamPreTrainedModel):
939
+ _tied_weights_keys = ["lm_head.weight"]
940
+
941
+ def __init__(self, config):
942
+ super().__init__(config)
943
+ self.model = DreamBaseModel(config)
944
+ self.vocab_size = config.vocab_size
945
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
946
+
947
+ # Initialize weights and apply final processing
948
+ self.post_init()
949
+
950
+ def reset_rope_parameters(self):
951
+ self.model.rotary_emb.reset_parameters()
952
+ for layer in self.model.layers:
953
+ layer.self_attn.rotary_emb.reset_parameters()
954
+
955
+ def get_input_embeddings(self):
956
+ return self.model.embed_tokens
957
+
958
+ def set_input_embeddings(self, value):
959
+ self.model.embed_tokens = value
960
+
961
+ def get_output_embeddings(self):
962
+ return self.lm_head
963
+
964
+ def set_output_embeddings(self, new_embeddings):
965
+ self.lm_head = new_embeddings
966
+
967
+ def set_decoder(self, decoder):
968
+ self.model = decoder
969
+
970
+ def get_decoder(self):
971
+ return self.model
972
+
973
+ def forward(
974
+ self,
975
+ input_ids: torch.LongTensor = None,
976
+ attention_mask: Optional[torch.Tensor] = None,
977
+ update_kvcache: torch.int32 = None,
978
+ position_ids: Optional[torch.LongTensor] = None,
979
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
980
+ inputs_embeds: Optional[torch.FloatTensor] = None,
981
+ labels: Optional[torch.LongTensor] = None,
982
+ use_cache: Optional[bool] = None,
983
+ output_attentions: Optional[bool] = None,
984
+ output_hidden_states: Optional[bool] = None,
985
+ return_dict: Optional[bool] = None,
986
+ cache_position: Optional[torch.LongTensor] = None,
987
+ num_logits_to_keep: int = 0,
988
+ **loss_kwargs,
989
+ ) -> Union[Tuple, MaskedLMOutput]:
990
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
991
+ output_hidden_states = (
992
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
993
+ )
994
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
995
+
996
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
997
+ outputs = self.model(
998
+ input_ids=input_ids,
999
+ attention_mask=attention_mask,
1000
+ update_kvcache=update_kvcache,
1001
+ position_ids=position_ids,
1002
+ past_key_values=past_key_values,
1003
+ inputs_embeds=inputs_embeds,
1004
+ use_cache=use_cache,
1005
+ output_attentions=output_attentions,
1006
+ output_hidden_states=output_hidden_states,
1007
+ return_dict=return_dict,
1008
+ cache_position=cache_position,
1009
+ )
1010
+
1011
+ hidden_states = outputs[0]
1012
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1013
+ logits = self.lm_head(hidden_states[:, -num_logits_to_keep:, :])
1014
+
1015
+ loss = None
1016
+ if labels is not None:
1017
+ loss = self.loss_function(logits, labels, self.vocab_size, **loss_kwargs)
1018
+
1019
+ if not return_dict:
1020
+ output = (logits,) + outputs[1:]
1021
+ return (loss,) + output if loss is not None else output
1022
+
1023
+ return CausalLMOutputWithPast(
1024
+ loss=loss,
1025
+ logits=logits,
1026
+ past_key_values=outputs.past_key_values,
1027
+ hidden_states=outputs.hidden_states,
1028
+ attentions=outputs.attentions,
1029
+ )
model_cache/llada/__pycache__/configuration_llada.cpython-310.pyc ADDED
Binary file (6.24 kB). View file
 
model_cache/llada/__pycache__/configuration_llada.cpython-312.pyc ADDED
Binary file (8.26 kB). View file
 
model_cache/llada/__pycache__/modeling_llada.cpython-310.pyc ADDED
Binary file (40.3 kB). View file
 
model_cache/llada/__pycache__/modeling_llada.cpython-312.pyc ADDED
Binary file (72.6 kB). View file
 
model_cache/llada/configuration_llada.py ADDED
@@ -0,0 +1,463 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ LLaDA configuration
3
+ """
4
+ from transformers import AutoConfig, PretrainedConfig
5
+
6
+ from enum import Enum
7
+ from os import PathLike
8
+ from typing import Union
9
+ from dataclasses import asdict, dataclass, field
10
+ from glob import glob
11
+ from pathlib import Path
12
+ from typing import (
13
+ Any,
14
+ Dict,
15
+ Iterable,
16
+ List,
17
+ Optional,
18
+ Tuple,
19
+ Type,
20
+ TypeVar,
21
+ Union,
22
+ cast,
23
+ )
24
+
25
+
26
+ __all__ = [
27
+ "ActivationType",
28
+ "ActivationCheckpointingStrategy",
29
+ "BlockType",
30
+ "LayerNormType",
31
+ "InitFnType",
32
+ "ModelConfig",
33
+ ]
34
+
35
+ PathOrStr = Union[str, PathLike]
36
+
37
+
38
+ class StrEnum(str, Enum):
39
+ """
40
+ This is equivalent to Python's :class:`enum.StrEnum` since version 3.11.
41
+ We include this here for compatibility with older version of Python.
42
+ """
43
+
44
+ def __str__(self) -> str:
45
+ return self.value
46
+
47
+ def __repr__(self) -> str:
48
+ return f"'{str(self)}'"
49
+
50
+
51
+ class LayerNormType(StrEnum):
52
+ default = "default"
53
+ """
54
+ The default LayerNorm implementation, equivalent to PyTorch's built-in version.
55
+ """
56
+
57
+ low_precision = "low_precision"
58
+ """
59
+ A low-precision version of the default LayerNorm.
60
+ """
61
+
62
+ rms = "rms"
63
+ """
64
+ An RMSNorm implementation. When using ``torch.compile`` this is
65
+ probably the fastest implementation.
66
+ """
67
+
68
+ gemma_rms = "gemma_rms"
69
+ """
70
+ An RMSNorm implementation by gemmma. When using ``torch.compile`` this is
71
+ probably the fastest implementation.
72
+ """
73
+
74
+ amd_compatible = "amd_compatible"
75
+ """
76
+ LayerNorm implemented manually to work around an issue with ROCm.
77
+ """
78
+
79
+
80
+ class ActivationType(StrEnum):
81
+ gelu = "gelu"
82
+ relu = "relu"
83
+ silu = "silu"
84
+ swiglu = "swiglu"
85
+
86
+
87
+ class BlockType(StrEnum):
88
+ sequential = "sequential"
89
+ parallel = "parallel"
90
+
91
+ llama = "llama"
92
+ """
93
+ A block similar to the sequential block with slightly different
94
+ implementations of operations like attention to imitate the behavior of Llama.
95
+ """
96
+
97
+
98
+ class InitFnType(StrEnum):
99
+ mitchell = "mitchell"
100
+ """
101
+ The strategy suggested to us by Mitchell Wortsman from UW.
102
+ This uses a truncated normal distribution with an adaptive standard deviation that depends
103
+ on the size of the weights as well as the depth of the layer.
104
+ """
105
+
106
+ normal = "normal"
107
+ """
108
+ All weights are initialized from the same normal distribution.
109
+ """
110
+
111
+ kaiming_normal = "kaiming_normal"
112
+ """
113
+ All weights are initialized with the Kaiming method from a normal distribution.
114
+ Note this currently won't work with FSDP.
115
+ """
116
+
117
+ fan_in = "fan_in"
118
+ """
119
+ "Fan-in variance scaling", i.e. normal with a standard deviation of ``1/sqrt(d_in)`` where ``d_in``
120
+ is the input dimensionality of the kernel.
121
+ """
122
+
123
+ full_megatron = "full_megatron"
124
+ """
125
+ This is what metaseq calls "full megatron init". It is the init used for Llama 2.
126
+ """
127
+
128
+
129
+ @dataclass
130
+ class ModelConfig():
131
+ """
132
+ LLaDA (model) configuration.
133
+ """
134
+
135
+ # Note that the defaults for these attributes are equivalent to the base GPT2 model.
136
+
137
+ d_model: int = 768
138
+ """
139
+ The hidden size of the model.
140
+ """
141
+
142
+ n_heads: int = 12
143
+ """
144
+ The number of self-attention heads.
145
+ """
146
+
147
+ n_kv_heads: Optional[int] = None
148
+ """
149
+ The number of heads to use for keys and values. Defaults to `n_heads`.
150
+ Set this to ``None`` or ``n_heads`` for normal multi-head attention.
151
+ Set this to 1 for multi-query attention.
152
+ Set it to some in-between value for Llama2-style grouped query attention.
153
+ """
154
+
155
+ n_layers: int = 12
156
+ """
157
+ The number of layers/blocks.
158
+ """
159
+
160
+ mlp_ratio: int = 4
161
+ """
162
+ The ratio of the inner MLP dimensionality to ``d_model``.
163
+ This is only used when ``mlp_hidden_size`` is not set.
164
+ """
165
+
166
+ mlp_hidden_size: Optional[int] = None
167
+ """
168
+ Set the exact hidden size for the MLP. Otherwise the inner MLP hidden size will be set to `mlp_ratio * d_model`.
169
+ """
170
+
171
+ activation_type: ActivationType = ActivationType.swiglu
172
+ """
173
+ The activation function to use within the MLP layers.
174
+ """
175
+
176
+ block_type: BlockType = BlockType.sequential
177
+ """
178
+ The transformer block implementation.
179
+ """
180
+
181
+ block_group_size: int = 1
182
+ """
183
+ The number of blocks to group together into a single parent block.
184
+ This has no affect on the number of parameters in the model and is only used to wrap groups
185
+ of blocks together with a single FSDP wrapper during training.
186
+ """
187
+
188
+ alibi: bool = False
189
+ """
190
+ If ``True``, use ALiBi embeddings. Mutually exclusive with ``rope``.
191
+ """
192
+
193
+ alibi_bias_max: float = 8.0
194
+ """
195
+ Maximum absolute value of ALiBi bias.
196
+ """
197
+
198
+ rope: bool = False
199
+ """
200
+ Use rotary positional embeddings (RoPE). Mutually exclusive with ``alibi``.
201
+ """
202
+
203
+ rope_full_precision: bool = True
204
+ """
205
+ If ``True``, apply RoPE embeddings at full precision regardless of the input type. Otherwise,
206
+ apply RoPE at the precision of the input.
207
+ """
208
+
209
+ flash_attention: bool = False
210
+ """
211
+ If ``True``, use ``FlashAttention``.
212
+ """
213
+
214
+ attention_dropout: float = 0.1
215
+ """
216
+ The dropout probability within the attention modules.
217
+ """
218
+
219
+ multi_query_attention: Optional[bool] = None
220
+ """
221
+ Use the Multi-Query formulation of attention used in PaLM. This reduces the number of parameters
222
+ and is more efficient during inference.
223
+ """
224
+
225
+ attention_layer_norm: bool = False
226
+ """
227
+ Apply layer norm to the keys and queries within the attention mechanism.
228
+ This can help stabilize training.
229
+ """
230
+
231
+ residual_dropout: float = 0.1
232
+ """
233
+ The dropout probability for the MLP and attention output within each block.
234
+ """
235
+
236
+ embedding_dropout: float = 0.1
237
+ """
238
+ The dropout probability for embeddings.
239
+ """
240
+
241
+ input_emb_norm: bool = False
242
+ """
243
+ An input hidden_states norm implementation by gemmma.
244
+ """
245
+
246
+ layer_norm_type: LayerNormType = LayerNormType.default
247
+ """
248
+ The layernorm implementation to use.
249
+ """
250
+
251
+ layer_norm_with_affine: bool = True
252
+ """
253
+ Whether to include bias and weight parameters for the layer norms.
254
+ This only affects layer norms that are immediately followed by a linear layer in the forward pass,
255
+ so everything except QK-norms. To turn off affines for QK norms as well, set :attr:`attention_layer_norm_with_affine`
256
+ to ``False``.
257
+ """
258
+
259
+ rms_norm_eps: float = 1e-05
260
+ """
261
+ The rms layernorm eps param.
262
+ """
263
+
264
+ attention_layer_norm_with_affine: bool = True
265
+ """
266
+ Toggle affine transform for the QK norms.
267
+ """
268
+
269
+ max_sequence_length: int = 1024
270
+ """
271
+ The maximum input sequence length supported by the model.
272
+ """
273
+
274
+ rope_theta: float = 10000.0
275
+ """
276
+ The rope base param.
277
+ """
278
+
279
+ include_qkv_bias: Optional[bool] = False
280
+ """
281
+ Whether or not to include bias parameters in qkv linear layers.
282
+ """
283
+
284
+ include_bias: bool = False
285
+ """
286
+ Whether or not to include bias parameters in linear layers.
287
+ In PaLM, they got rid of all bias terms because they found that large
288
+ models tend to have near 0 bias terms anyway.
289
+ """
290
+
291
+ bias_for_layer_norm: Optional[bool] = None
292
+ """
293
+ Whether or not to include bias parameters in layer norm.
294
+ This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
295
+ layer norm.
296
+ When this is None (the default), it inherits the setting from include_bias.
297
+ """
298
+
299
+ scale_logits: bool = False
300
+ """
301
+ If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
302
+ """
303
+
304
+ vocab_size: int = 50257
305
+ """
306
+ Vocabulary size of the model.
307
+ """
308
+
309
+ embedding_size: Optional[int] = 50304
310
+ """
311
+ The number of embeddings, i.e. the number of tokens. If set to ``None`` it will default
312
+ to ``vocab_size``. If ``vocab_size`` is not a multiple of 128, setting this to the
313
+ next multiple of 128 that's greater than ``vocab_size`` can improve throughput
314
+ substantially.
315
+ """
316
+
317
+ weight_tying: bool = True
318
+ """
319
+ Whether to tie output linear weights to the input embedding.
320
+ """
321
+
322
+ eos_token_id: int = 50256
323
+ """
324
+ The ID of the end-of-sentence special token.
325
+ """
326
+
327
+ pad_token_id: int = 50256
328
+ """
329
+ The ID of the token to use for padding. Defaults to the ID of the EOS token.
330
+ """
331
+
332
+ mask_token_id: Optional[int] = 50256
333
+ """
334
+ The ID of the token to use for mask token. Defaults to the ID of the EOS token.
335
+ """
336
+
337
+ init_device: Optional[str] = None
338
+ """
339
+ The torch device to use when initializing the model parameters, e.g. "cpu", "cuda:0", "meta".
340
+ """
341
+
342
+ init_fn: InitFnType = InitFnType.normal
343
+ """
344
+ The weight initialization strategy.
345
+ """
346
+
347
+ init_std: float = 0.02
348
+ """
349
+ The standard deviation to use when initializing weights with a "fixed distribution" ``init_fn``, such
350
+ as "normal".
351
+ """
352
+
353
+ init_cutoff_factor: Optional[float] = None
354
+ """
355
+ A positive factor used to scale the cutoff values when initializing weights with a "fixed distribution" ``init_fn``, such
356
+ as "normal". Setting this to None means values are not cutoff.
357
+ """
358
+
359
+ precision: Optional[str] = None
360
+ """
361
+ Precision used to train/evaluate with. You shouldn't set this directly.
362
+ See :data:`TrainConfig.precision` instead.
363
+ """
364
+
365
+ @property
366
+ def effective_n_kv_heads(self) -> int:
367
+ if self.n_kv_heads is None:
368
+ if self.multi_query_attention is True:
369
+ return 1
370
+ else:
371
+ return self.n_heads
372
+ else:
373
+ if self.multi_query_attention is None:
374
+ return self.n_kv_heads
375
+ if self.multi_query_attention:
376
+ n_kv_heads_should_be = 1
377
+ else:
378
+ n_kv_heads_should_be = self.n_heads
379
+ if self.n_kv_heads == n_kv_heads_should_be:
380
+ return n_kv_heads_should_be
381
+ else:
382
+ raise Exception(
383
+ "You can't set `multi_query_attention` and `n_kv_heads` at the same time."
384
+ )
385
+
386
+ class ActivationCheckpointingStrategy(StrEnum):
387
+ whole_layer = "whole_layer"
388
+ """
389
+ Checkpoint every transformer layer.
390
+ """
391
+
392
+ one_in_two = "one_in_two"
393
+ """
394
+ Checkpoint one in two transformer layers.
395
+ """
396
+
397
+ one_in_three = "one_in_three"
398
+ """
399
+ Checkpoint one in three transformer layers.
400
+ """
401
+
402
+ one_in_four = "one_in_four"
403
+ """
404
+ Checkpoint one in four transformer layers.
405
+ """
406
+
407
+ two_in_three = "two_in_three"
408
+ """
409
+ Checkpoint two out of every three transformer layers.
410
+ """
411
+
412
+ three_in_four = "three_in_four"
413
+ """
414
+ Checkpoint three out of four of every transformer layers.
415
+ """
416
+
417
+ four_in_five = "four_in_five"
418
+ """
419
+ Checkpoint four out of five of every transformer layers.
420
+ """
421
+
422
+ nine_in_ten = "nine_in_ten"
423
+ """
424
+ Checkpoint nine out of ten of every transformer layers.
425
+ """
426
+
427
+ fine_grained = "fine_grained"
428
+ """
429
+ Focus checkpointing on where it is cheap to recompute and saves most memory.
430
+ """
431
+
432
+
433
+ class LLaDAConfig(PretrainedConfig):
434
+ model_type = "llada"
435
+ keys_to_ignore_at_inference = ["past_key_values"] # TODO: confirm
436
+
437
+ def __init__(self, use_cache: bool = False, **kwargs):
438
+ model_config = ModelConfig()
439
+ all_kwargs = model_config.__dict__
440
+ all_kwargs.update(kwargs)
441
+ all_kwargs.update({"use_cache": use_cache})
442
+ all_kwargs.update(
443
+ {
444
+ "architectures": all_kwargs.get("architectures", ["LLaDAModelLM"])
445
+ }
446
+ )
447
+ super().__init__(**all_kwargs)
448
+
449
+ @property
450
+ def num_attention_heads(self):
451
+ return self.n_heads
452
+
453
+ @property
454
+ def num_hidden_layers(self):
455
+ return self.n_layers
456
+
457
+ @property
458
+ def hidden_size(self):
459
+ return self.d_model
460
+
461
+
462
+ # Register the config class so that it is available for transformer pipelines, auto-loading etc.
463
+ AutoConfig.register("llada", LLaDAConfig)
model_cache/llada/modeling_llada.py ADDED
@@ -0,0 +1,1504 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import math
5
+ import sys
6
+ from abc import abstractmethod
7
+ from collections import defaultdict
8
+ from functools import partial
9
+ from typing import (
10
+ Callable,
11
+ Dict,
12
+ Iterable,
13
+ List,
14
+ NamedTuple,
15
+ Optional,
16
+ Sequence,
17
+ Set,
18
+ Tuple,
19
+ cast,
20
+ )
21
+ from dataclasses import fields
22
+ from typing import List, Optional, Tuple, Union
23
+
24
+ import torch
25
+ import torch.backends.cuda
26
+ import torch.nn as nn
27
+ import torch.nn.functional as F
28
+ from torch import einsum
29
+ from transformers import PreTrainedModel
30
+ from transformers.modeling_outputs import CausalLMOutputWithPast
31
+ from transformers.models.auto import AutoModel
32
+ from transformers.cache_utils import Cache
33
+
34
+ from .configuration_llada import (
35
+ LLaDAConfig,
36
+ StrEnum,
37
+ InitFnType,
38
+ ActivationType,
39
+ BlockType,
40
+ LayerNormType,
41
+ ModelConfig,
42
+ ActivationCheckpointingStrategy,
43
+ )
44
+
45
+ if sys.version_info.minor > 8:
46
+ from collections.abc import MutableMapping
47
+ elif sys.version_info.minor == 8:
48
+ from typing import MutableMapping
49
+ else:
50
+ raise SystemExit("This script supports Python 3.8 or higher")
51
+
52
+ __all__ = [
53
+ "LayerNormBase",
54
+ "LayerNorm",
55
+ "RMSLayerNorm",
56
+ "GemmaRMSLayerNorm",
57
+ "RotaryEmbedding",
58
+ "Activation",
59
+ "GELU",
60
+ "ReLU",
61
+ "SwiGLU",
62
+ "LLaDABlock",
63
+ "LLaDASequentialBlock",
64
+ "LLaDAModel",
65
+ "LLaDAOutput",
66
+ "LLaDAGenerateOutput",
67
+ ]
68
+
69
+
70
+ log = logging.getLogger(__name__)
71
+
72
+
73
+ class ModuleType(StrEnum):
74
+ in_module = "in"
75
+ out_module = "out"
76
+ emb = "emb"
77
+ final_out = "final_out"
78
+
79
+
80
+ def init_weights(
81
+ config: ModelConfig,
82
+ module: Union[nn.Linear, nn.Embedding],
83
+ d: Optional[int] = None,
84
+ layer_id: Optional[int] = None,
85
+ std_factor: float = 1.0,
86
+ type_of_module: Optional[ModuleType] = None,
87
+ ) -> None:
88
+ """
89
+ Initialize weights of a linear or embedding module.
90
+
91
+ :param config: The model config.
92
+ :param module: The linear or embedding submodule to initialize.
93
+ :param d: The effective input dimensionality of the weights. This could be smaller than the actual dimensions
94
+ for fused layers.
95
+ :param layer_id: When set, the standard deviation for the "mitchell" method will be adjusted by
96
+ ``1 / sqrt(2 * (layer_id + 1))``.
97
+ """
98
+ d = d if d is not None else config.d_model
99
+ if config.init_fn == InitFnType.normal:
100
+ std = config.init_std * std_factor
101
+ if config.init_cutoff_factor is not None:
102
+ cutoff_value = config.init_cutoff_factor * std
103
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-cutoff_value, b=cutoff_value)
104
+ else:
105
+ nn.init.normal_(module.weight, mean=0.0, std=std)
106
+ elif config.init_fn == InitFnType.mitchell:
107
+ std = std_factor / math.sqrt(d)
108
+ if layer_id is not None:
109
+ std = std / math.sqrt(2 * (layer_id + 1))
110
+ nn.init.trunc_normal_(module.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
111
+ elif config.init_fn == InitFnType.kaiming_normal:
112
+ nn.init.kaiming_normal_(module.weight, nonlinearity="relu")
113
+ elif config.init_fn == InitFnType.fan_in:
114
+ std = std_factor / math.sqrt(d)
115
+ nn.init.normal_(module.weight, mean=0.0, std=std)
116
+ elif config.init_fn == InitFnType.full_megatron:
117
+ if type_of_module is None:
118
+ raise RuntimeError(f"When using the {InitFnType.full_megatron} init, every module must have a type.")
119
+
120
+ cutoff_factor = config.init_cutoff_factor
121
+ if cutoff_factor is None:
122
+ cutoff_factor = 3
123
+
124
+ if type_of_module == ModuleType.in_module:
125
+ # for att_proj (same as QKV), ff_proj
126
+ std = config.init_std
127
+ elif type_of_module == ModuleType.out_module:
128
+ # for attn_out, ff_out
129
+ std = config.init_std / math.sqrt(2.0 * config.n_layers)
130
+ elif type_of_module == ModuleType.emb:
131
+ # positional embeddings (wpe)
132
+ # token embeddings (wte)
133
+ std = config.init_std
134
+ elif type_of_module == ModuleType.final_out:
135
+ # final output (ff_out)
136
+ std = config.d_model**-0.5
137
+ else:
138
+ raise RuntimeError(f"Unknown module type '{type_of_module}'")
139
+ nn.init.trunc_normal_(
140
+ module.weight,
141
+ mean=0.0,
142
+ std=std,
143
+ a=-cutoff_factor * std,
144
+ b=cutoff_factor * std,
145
+ )
146
+ else:
147
+ raise NotImplementedError(config.init_fn)
148
+
149
+ if isinstance(module, nn.Linear):
150
+ if module.bias is not None:
151
+ nn.init.zeros_(module.bias)
152
+
153
+ if config.init_fn == InitFnType.normal and getattr(module, "_is_residual", False):
154
+ with torch.no_grad():
155
+ module.weight.div_(math.sqrt(2 * config.n_layers))
156
+
157
+
158
+ def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
159
+ """
160
+ Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf``
161
+ is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``.
162
+ """
163
+ if check_neg_inf:
164
+ x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min)
165
+ if check_pos_inf:
166
+ x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max)
167
+
168
+
169
+ def activation_checkpoint_function(cfg: ModelConfig):
170
+ preserve_rng_state = (
171
+ (cfg.attention_dropout == 0.0) and (cfg.embedding_dropout == 0.0) and (cfg.residual_dropout == 0.0)
172
+ )
173
+ from torch.utils.checkpoint import checkpoint
174
+
175
+ return partial(
176
+ checkpoint,
177
+ preserve_rng_state=preserve_rng_state,
178
+ use_reentrant=False,
179
+ )
180
+
181
+
182
+ class BufferCache(dict, MutableMapping[str, torch.Tensor]):
183
+ """
184
+ Cache for attention biases and other things that would normally be stored as buffers.
185
+ We avoid using buffers because we've run into various issues doing so with FSDP.
186
+ In general it appears the way FSDP handles buffers is not well-defined.
187
+ It doesn't shard them but apparently it does synchronize them across processes, which we want to avoid
188
+ since (A) it isn't necessary, and (B) we sometimes have `-inf` in these biases which might get turned into
189
+ NaNs when they're synchronized due to casting or some other issue.
190
+ """
191
+
192
+
193
+ def _non_meta_init_device(config: ModelConfig) -> torch.device:
194
+ if config.init_device is not None and config.init_device != "meta":
195
+ return torch.device(config.init_device)
196
+ else:
197
+ return torch.device("cuda" if torch.cuda.is_available() else "cpu")
198
+
199
+
200
+ class Dropout(nn.Dropout):
201
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
202
+ if self.p == 0.0:
203
+ return input
204
+ else:
205
+ return F.dropout(input, self.p, self.training, self.inplace)
206
+
207
+
208
+ class LayerNormBase(nn.Module):
209
+ def __init__(
210
+ self,
211
+ config: ModelConfig,
212
+ *,
213
+ size: Optional[int] = None,
214
+ elementwise_affine: Optional[bool] = True,
215
+ eps: float = 1e-05,
216
+ ):
217
+ super().__init__()
218
+ self.config = config
219
+ self.eps = eps
220
+ self.normalized_shape = (size or config.d_model,)
221
+ if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
222
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
223
+ use_bias = self.config.bias_for_layer_norm
224
+ if use_bias is None:
225
+ use_bias = self.config.include_bias
226
+ if use_bias:
227
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
228
+ else:
229
+ self.register_parameter("bias", None)
230
+ else:
231
+ self.register_parameter("bias", None)
232
+ self.register_parameter("weight", None)
233
+
234
+ @abstractmethod
235
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
236
+ raise NotImplementedError
237
+
238
+ @classmethod
239
+ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
240
+ if config.layer_norm_type == LayerNormType.default:
241
+ return LayerNorm(config, size=size, low_precision=False, **kwargs)
242
+ elif config.layer_norm_type == LayerNormType.low_precision:
243
+ return LayerNorm(config, size=size, low_precision=True, **kwargs)
244
+ elif config.layer_norm_type == LayerNormType.rms:
245
+ return RMSLayerNorm(config, size=size, **kwargs)
246
+ elif config.layer_norm_type == LayerNormType.gemma_rms:
247
+ return GemmaRMSLayerNorm(config, size=size, **kwargs)
248
+ else:
249
+ raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")
250
+
251
+ def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
252
+ # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
253
+ # `is_autocast_cpu_enabled()` for CPU autocast.
254
+ # See https://github.com/pytorch/pytorch/issues/110966.
255
+ if tensor.device.type == "cuda" and torch.is_autocast_enabled():
256
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_gpu_dtype())
257
+ elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
258
+ return tensor.to(dtype=dtype if dtype is not None else torch.get_autocast_cpu_dtype())
259
+ else:
260
+ return tensor
261
+
262
+ def reset_parameters(self):
263
+ if self.weight is not None:
264
+ torch.nn.init.ones_(self.weight) # type: ignore
265
+ if self.bias is not None:
266
+ torch.nn.init.zeros_(self.bias) # type: ignore
267
+
268
+
269
class LayerNorm(LayerNormBase):
    """
    The default :class:`LayerNorm` implementation which can optionally run in low precision.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        low_precision: bool = False,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-05,
    ):
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
        # When True, inputs/params are cast to the autocast dtype before the norm.
        self.low_precision = low_precision

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.low_precision:
            return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)
        # Low-precision path: downcast input and affine params to the autocast
        # dtype, then run layer norm itself with autocast disabled.
        device_type = x.device.type
        x_cast = self._cast_if_autocast_enabled(x)
        weight_cast = self._cast_if_autocast_enabled(self.weight) if self.weight is not None else self.weight
        bias_cast = self._cast_if_autocast_enabled(self.bias) if self.bias is not None else self.bias
        with torch.autocast(enabled=False, device_type=device_type):
            return F.layer_norm(
                x_cast, self.normalized_shape, weight=weight_cast, bias=bias_cast, eps=self.eps
            )
299
+
300
+
301
class RMSLayerNorm(LayerNormBase):
    """
    RMS layer norm, a simplified :class:`LayerNorm` implementation.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        # NOTE: the `eps` argument is intentionally ignored here; the effective
        # epsilon always comes from `config.rms_norm_eps` (same as GemmaRMSLayerNorm).
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize in float32 for numerical stability, then cast back.
        input_dtype = x.dtype
        x_f32 = x.to(torch.float32)
        mean_sq = (x_f32 * x_f32).mean(dim=-1, keepdim=True)
        normed = (x_f32 * torch.rsqrt(mean_sq + self.eps)).to(input_dtype)

        # Optional affine transform (weight/bias come from LayerNormBase).
        if self.weight is None:
            return normed
        if self.bias is None:
            return self.weight * normed
        return self.weight * normed + self.bias
333
+
334
+
335
class GemmaRMSLayerNorm(LayerNormBase):
    """
    Gemma RMS layer norm, a simplified :class:`LayerNorm` implementation.
    """

    def __init__(
        self,
        config: ModelConfig,
        size: Optional[int] = None,
        elementwise_affine: Optional[bool] = None,
        eps: float = 1e-5,
    ):
        # NOTE: the `eps` argument is intentionally ignored here; the effective
        # epsilon always comes from `config.rms_norm_eps`.
        super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=config.rms_norm_eps)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Gemma-style RMS norm: computed in float32 with autocast disabled, and
        # the affine scale is (1 + weight) rather than weight.
        with torch.autocast(enabled=False, device_type=x.device.type):
            input_dtype = x.dtype
            x_f32 = x.to(torch.float32)
            mean_sq = x_f32.pow(2).mean(-1, keepdim=True)
            normed = (x_f32 * torch.rsqrt(mean_sq + self.eps)).to(input_dtype)

            if self.weight is None:
                return normed
            scaled = normed * (1 + self.weight)
            return scaled if self.bias is None else scaled + self.bias
364
+
365
+
366
class RotaryEmbedding(nn.Module):
    """
    [Rotary positional embeddings (RoPE)](https://arxiv.org/abs/2104.09864).

    Sin/cos tables are built once (in full precision) and memoized in the shared
    ``BufferCache``, so all layers reuse the same buffers.
    """

    def __init__(self, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.config = config
        self.__cache = cache
        # Warm up cache.
        self.rope_theta = config.rope_theta
        self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))

    def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return ``(pos_sin, pos_cos)`` tables of shape ``(1, 1, seq_len, head_dim)``.

        Serves from the cache whenever the cached tables cover ``seq_len``;
        otherwise rebuilds them (in float32, with autocast disabled) and caches.
        """
        if (
            (pos_sin := self.__cache.get("rope_pos_sin")) is not None
            and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
            and pos_sin.shape[-2] >= seq_len
            and pos_cos.shape[-2] >= seq_len
        ):
            # Cache hit: migrate to the requested device if needed, then slice to length.
            if pos_sin.device != device:
                pos_sin = pos_sin.to(device)
                self.__cache["rope_pos_sin"] = pos_sin
            if pos_cos.device != device:
                pos_cos = pos_cos.to(device)
                self.__cache["rope_pos_cos"] = pos_cos
            return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]

        # Cache miss: rebuild the full tables in float32.
        with torch.autocast(device.type, enabled=False):
            dim = self.config.d_model // self.config.n_heads
            # Standard RoPE inverse frequencies: theta^(-2i/dim) for i in [0, dim/2).
            inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
            seq = torch.arange(seq_len, device=device, dtype=torch.float)
            freqs = einsum("i , j -> i j", seq, inv_freq)
            # Duplicate the frequency table so it spans the whole head dimension.
            positions = torch.cat((freqs, freqs), dim=-1)
            pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
            self.__cache["rope_pos_sin"] = pos_sin
            self.__cache["rope_pos_cos"] = pos_cos
        return pos_sin, pos_cos

    def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
        """Split the head dim in two halves and rotate: ``(x1, x2) -> (-x2, x1)``."""
        B, nh, T, hs = x.size()
        x = x.view(B, nh, T, 2, hs // 2)
        x1, x2 = x.unbind(dim=-2)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        """Apply the rotary transform to ``t`` using precomputed sin/cos tables."""
        return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)

    def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Rotate queries and keys; outputs are cast back to the input dtypes."""
        if self.config.rope_full_precision:
            q_, k_ = q.float(), k.float()
        else:
            q_, k_ = q, k

        with torch.autocast(q.device.type, enabled=False):
            query_len, key_len = q_.shape[-2], k_.shape[-2]  # could be different if layer_past not None
            pos_sin, pos_cos = self.get_rotary_embedding(key_len, q_.device)
            pos_sin = pos_sin.type_as(q_)
            pos_cos = pos_cos.type_as(q_)
            # Queries only cover the last `query_len` positions of the key sequence
            # (earlier positions live in the KV cache).
            q_ = self.apply_rotary_pos_emb(
                pos_sin[:, :, key_len - query_len : key_len, :],
                pos_cos[:, :, key_len - query_len : key_len, :],
                q_,
            )
            k_ = self.apply_rotary_pos_emb(pos_sin, pos_cos, k_)
        return q_.type_as(q), k_.type_as(k)
432
+
433
+
434
class Activation(nn.Module):
    """Abstract base class for the MLP activation functions used by this model."""

    def __init__(self, config: ModelConfig):
        super().__init__()
        self.config = config

    @abstractmethod
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        raise NotImplementedError

    @property
    @abstractmethod
    def output_multiplier(self) -> float:
        """Ratio of the activation's output width to its input width."""
        raise NotImplementedError

    @classmethod
    def build(cls, config: ModelConfig) -> Activation:
        """Construct the activation selected by ``config.activation_type``."""
        act_type = config.activation_type
        if act_type == ActivationType.gelu:
            return cast(Activation, GELU(approximate="none"))
        if act_type == ActivationType.relu:
            return cast(Activation, ReLU(inplace=False))
        if act_type == ActivationType.silu:
            return cast(Activation, SiLU(inplace=False))
        if act_type == ActivationType.swiglu:
            return SwiGLU(config)
        raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")
460
+
461
+
462
class GELU(nn.GELU):
    """GELU activation; preserves the hidden width (multiplier 1)."""

    @property
    def output_multiplier(self) -> float:
        return 1.0
466
+
467
+
468
class ReLU(nn.ReLU):
    """ReLU activation; preserves the hidden width (multiplier 1)."""

    @property
    def output_multiplier(self) -> float:
        return 1.0
472
+
473
class SiLU(nn.SiLU):
    """SiLU (swish) activation; preserves the hidden width (multiplier 1)."""

    @property
    def output_multiplier(self) -> float:
        return 1.0
477
+
478
class SwiGLU(Activation):
    """SwiGLU: splits the last dim in half and gates one half with SiLU of the other."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        value, gate = x.chunk(2, dim=-1)
        return F.silu(gate) * value

    @property
    def output_multiplier(self) -> float:
        # The chunk halves the hidden dimension.
        return 0.5
486
+
487
+
488
def causal_attention_bias(seq_len: int, device: torch.device) -> torch.FloatTensor:
    """Build an additive causal attention bias of shape ``(1, 1, seq_len, seq_len)``.

    Entries strictly above the diagonal (future positions) are set to the most
    negative finite float so softmax effectively masks them; the rest are 0.
    """
    mask = torch.ones(seq_len, seq_len, device=device, dtype=torch.float).triu_(diagonal=1)
    bias = mask.masked_fill(mask == 1, torch.finfo(mask.dtype).min)
    return bias.view(1, 1, seq_len, seq_len)  # type: ignore
495
+
496
+
497
def get_causal_attention_bias(cache: BufferCache, seq_len: int, device: torch.device) -> torch.Tensor:
    """Fetch the causal attention bias from ``cache``, rebuilding it when absent or too small."""
    cached = cache.get("causal_attention_bias")
    if cached is not None and cached.shape[-1] >= seq_len:
        # Cache hit: just make sure the tensor lives on the right device.
        if cached.device != device:
            cached = cached.to(device)
            cache["causal_attention_bias"] = cached
        return cached
    # Cache miss: rebuild in full precision.
    with torch.autocast(device.type, enabled=False):
        causal_bias = causal_attention_bias(seq_len, device)
        cache["causal_attention_bias"] = causal_bias
    return causal_bias
507
+
508
+
509
def alibi_attention_bias(seq_len: int, config: ModelConfig, device: torch.device) -> torch.FloatTensor:
    """Build the ALiBi attention bias of shape ``(1, n_heads, seq_len, seq_len)``.

    Each entry is ``-|i - j|`` scaled per head by ``2**-m``, so attention decays
    linearly with distance at a head-specific rate.
    """
    positions = torch.arange(1 - seq_len, 1, dtype=torch.float, device=device)

    # shape: (1, 1, seq_len, seq_len) -- negative absolute distance |i - j|.
    distance = -(positions.view(1, 1, 1, seq_len) - positions.view(1, 1, seq_len, 1)).abs()

    # Per-head slope exponents, shape: (n_heads,).
    slopes = torch.arange(1, config.n_heads + 1, dtype=torch.float, device=device)
    slopes = slopes * (config.alibi_bias_max / config.n_heads)

    # shape: (1, n_heads, seq_len, seq_len)
    return distance * (1.0 / (2 ** slopes.view(1, config.n_heads, 1, 1)))  # type: ignore
522
+
523
+
524
class LLaDABlock(nn.Module):
    """
    A base class for transformer block implementations.

    Owns the pieces shared by all block variants: dropout, optional q/k norms,
    the activation, the attention/FF output projections, rotary embeddings, and
    the (optional) flash-attention kernel. Subclasses wire up the input
    projections and the forward pass.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__()
        self.layer_id = layer_id
        self.config = config
        # MLP hidden size: explicit `mlp_hidden_size` wins over `mlp_ratio * d_model`.
        self.hidden_size = (
            config.mlp_hidden_size if config.mlp_hidden_size is not None else config.mlp_ratio * config.d_model
        )
        self.__cache = cache
        assert config.d_model % config.n_heads == 0

        # Set via `set_activation_checkpointing`; None means no fine-grained checkpointing.
        self._activation_checkpoint_fn = None

        # Dropout.
        self.dropout = Dropout(config.residual_dropout)

        # Layer norms.
        self.k_norm: Optional[LayerNormBase] = None
        self.q_norm: Optional[LayerNormBase] = None
        if config.attention_layer_norm:
            self.k_norm = LayerNormBase.build(
                config,
                size=(config.d_model // config.n_heads) * config.effective_n_kv_heads,
                elementwise_affine=config.attention_layer_norm_with_affine,
            )
            self.q_norm = LayerNormBase.build(config, elementwise_affine=config.attention_layer_norm_with_affine)

        # Activation function.
        self.act = Activation.build(config)
        # The activation may shrink the hidden dim (e.g. SwiGLU's multiplier is
        # 0.5); the resulting width must still be an integer.
        assert (self.act.output_multiplier * self.hidden_size) % 1 == 0

        # Attention output projection.
        self.attn_out = nn.Linear(
            config.d_model, config.d_model, bias=config.include_bias, device=config.init_device
        )

        # Feed-forward output projection.
        self.ff_out = nn.Linear(
            int(self.act.output_multiplier * self.hidden_size),
            config.d_model,
            bias=config.include_bias,
            device=config.init_device,
        )
        self.ff_out._is_residual = True  # type: ignore

        # Rotary embeddings.
        if self.config.rope:
            self.rotary_emb = RotaryEmbedding(config, self.__cache)

        # Optional flash-attention kernel; stays None (torch SDPA fallback) when
        # the `flash_attn` package is not installed.
        self.flash_attn_func = None
        if config.flash_attention:
            try:
                from flash_attn import flash_attn_func  # type: ignore

                self.flash_attn_func = flash_attn_func
            except ModuleNotFoundError:
                pass

    def reset_parameters(self):
        """Re-initialize this block's norms and output projection weights."""
        if self.k_norm is not None:
            self.k_norm.reset_parameters()
        if self.q_norm is not None:
            self.q_norm.reset_parameters()
        init_weights(
            self.config,
            self.attn_out,
            d=self.config.d_model,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )
        init_weights(
            self.config,
            self.ff_out,
            d=self.ff_out.in_features,
            layer_id=self.layer_id,
            type_of_module=ModuleType.out_module,
        )

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        # Only fine-grained checkpointing is handled inside the block; coarser
        # strategies are applied by LLaDABlockGroup / the model.
        if strategy == ActivationCheckpointingStrategy.fine_grained:
            self._activation_checkpoint_fn = activation_checkpoint_function(self.config)
        else:
            self._activation_checkpoint_fn = None

    @classmethod
    def _cast_attn_bias(cls, bias: torch.Tensor, input_dtype: torch.dtype) -> torch.Tensor:
        """Cast an attention bias to the autocast (or input) dtype, keeping it finite."""
        target_dtype = input_dtype
        # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
        # `is_autocast_cpu_enabled()` for CPU autocast.
        # See https://github.com/pytorch/pytorch/issues/110966.
        if bias.device.type == "cuda" and torch.is_autocast_enabled():
            target_dtype = torch.get_autocast_gpu_dtype()
        elif bias.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            target_dtype = torch.get_autocast_cpu_dtype()
        if bias.dtype != target_dtype:
            bias = bias.to(target_dtype)
            # Down-casting can turn large negative values into -inf, which would
            # produce NaNs in softmax; clamp them back to the dtype minimum.
            ensure_finite_(bias, check_neg_inf=True, check_pos_inf=False)
        return bias

    def _scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_mask: Optional[torch.Tensor] = None,
        dropout_p: float = 0.0,
        is_causal: bool = False,
    ) -> torch.Tensor:
        """
        Computes scaled dot product attention on query, key and value tensors, using an optional
        attention mask if passed, and applying dropout if a probability greater than 0.0 is specified.

        Inputs are head-first: ``(B, n_heads, T, head_dim)``. Attention is always
        non-causal here (diffusion-style full attention).
        """
        if self.flash_attn_func is not None and attn_mask is None:
            # flash_attn expects (B, T, n_heads, head_dim); transpose in and out.
            r = self.flash_attn_func(
                q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), dropout_p=dropout_p, causal=False
            )
            return r.transpose(1, 2)
        else:
            # torch's sdpa doesn't support GQA, so we're doing this
            assert k.size(1) == v.size(1)
            num_kv_heads = k.size(1)
            num_q_heads = q.size(1)
            if num_q_heads != num_kv_heads:
                # Expand each kv head to match the query head count (GQA).
                assert num_q_heads % num_kv_heads == 0
                k = k.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
                v = v.repeat_interleave(num_q_heads // num_kv_heads, dim=1, output_size=num_q_heads)
            # Modify: MDM set causal to False, and with no attn_mask.
            return F.scaled_dot_product_attention(
                q,
                k,
                v,
                attn_mask=attn_mask,
                dropout_p=dropout_p,
                is_causal=False,
            )

    def attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Full attention step: head reshaping, optional KV cache, RoPE, SDPA, output projection.

        :param q, k, v: projections of shape ``(B, T, ...)`` produced by the subclass.
        :param layer_past: optional cached ``(key, value)`` to prepend along the sequence dim.
        :param use_cache: when True, also return the (possibly extended) ``(k, v)`` pair.
        """
        B, T, C = q.size()  # batch size, sequence length, d_model
        dtype = k.dtype

        # Optionally apply layer norm to keys and queries.
        if self.q_norm is not None and self.k_norm is not None:
            q = self.q_norm(q).to(dtype=dtype)
            k = self.k_norm(k).to(dtype=dtype)

        # Move head forward to be next to the batch dim.
        # shape: (B, nh, T, hs)
        q = q.view(B, T, self.config.n_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        k = k.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)
        # shape: (B, n_kv_h, T, hs)
        v = v.view(B, T, self.config.effective_n_kv_heads, C // self.config.n_heads).transpose(1, 2)

        if layer_past is not None:
            # Prepend cached keys/values along the sequence dimension.
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)

        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None

        if self.config.rope:
            # Apply rotary embeddings.
            q, k = self.rotary_emb(q, k)

        # if attention_bias is not None:
        #     # Resize and cast attention bias.
        #     # The current dtype of the attention bias might not match the dtype that the SDP attn function will
        #     # run in if AMP is enabled, and this can be a problem if some tokens are masked out due to padding
        #     # as down-casting the attention bias to the autocast precision will result in -infs, which will
        #     # cause the SDP attn function to produce NaNs.
        #     attention_bias = self._cast_attn_bias(
        #         attention_bias[:, :, key_len - query_len : key_len, :key_len], dtype
        #     )

        # Get the attention scores.
        # shape: (B, nh, T, hs)
        att = self._scaled_dot_product_attention(
            q,
            k,
            v,
            attn_mask=attention_bias,
            dropout_p=0.0 if not self.training else self.config.attention_dropout,
            is_causal=False,
        )

        # Re-assemble all head outputs side-by-side.
        att = att.transpose(1, 2).contiguous().view(B, T, C)

        # Apply output projection.
        return self.attn_out(att), present

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Run the block; subclasses implement the attention + MLP wiring."""
        raise NotImplementedError

    @classmethod
    def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> LLaDABlock:
        """Factory: construct the block implementation selected by ``config.block_type``."""
        if config.block_type == BlockType.sequential:
            return LLaDASequentialBlock(layer_id, config, cache)
        elif config.block_type == BlockType.llama:
            return LLaDALlamaBlock(layer_id, config, cache)
        else:
            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")
746
+
747
+
748
class LLaDASequentialBlock(LLaDABlock):
    """
    This is a typical transformer block where the output is computed as ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection). Uses a single fused linear for the q/k/v projection.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Layer norms.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        # Attention input projection. Projects x -> (q, k, v) in one matmul.
        head_dim = config.d_model // config.n_heads
        kv_dim = config.effective_n_kv_heads * head_dim
        self.fused_dims = (config.d_model, kv_dim, kv_dim)
        self.att_proj = nn.Linear(
            config.d_model,
            sum(self.fused_dims),
            bias=config.include_bias | config.include_qkv_bias,
            device=config.init_device,
        )
        # Feed-forward input projection.
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    def reset_parameters(self):
        super().reset_parameters()
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()
        # NOTE: the standard deviation for these weights does not depend on the layer.
        for proj in (self.att_proj, self.ff_proj):
            init_weights(
                self.config, proj, d=self.config.d_model, layer_id=None, type_of_module=ModuleType.in_module
            )

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attention + MLP with residuals; ``checkpoint`` wraps sub-calls when fine-grained
        activation checkpointing is enabled.

        q/k/v shapes after the fused projection split:
          - regular attn: each (batch_size, seq_len, d_model)
          - multi-query attn: k, v are (batch_size, seq_len, d_model // n_heads)
          - grouped-query attn: k, v are (batch_size, seq_len, d_model // n_kv_heads)
        """
        checkpoint = self._activation_checkpoint_fn

        # Fused q/k/v projection of the attention-normalized input.
        normed = checkpoint(self.attn_norm, x) if checkpoint is not None else self.attn_norm(x)
        q, k, v = self.att_proj(normed).split(self.fused_dims, dim=-1)

        # Self-attention (optionally checkpointed).
        if checkpoint is not None:
            att, cache = checkpoint(  # type: ignore
                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
            )
        else:
            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

        # Residual connection around attention.
        # shape: (B, T, C)
        x = x + self.dropout(att)

        # Feed-forward with its own residual connection.
        # shape: (batch_size, seq_len, d_model)
        residual = x
        h = checkpoint(self.ff_norm, x) if checkpoint is not None else self.ff_norm(x)  # type: ignore
        h = self.ff_proj(h)
        h = checkpoint(self.act, h) if checkpoint is not None else self.act(h)  # type: ignore
        h = self.dropout(self.ff_out(h))
        return residual + h, cache
836
+
837
+
838
class LLaDALlamaBlock(LLaDABlock):
    """
    A transformer block computing ``MLP(LN(x + Attention(LN(x))))`` (plus another
    skip connection). Similar to `LLaDASequentialBlock` but with separate q/k/v
    projections and a gated (SwiGLU-style) feed-forward, imitating Llama.
    """

    def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
        super().__init__(layer_id, config, cache)
        # Layer norms.
        self.attn_norm = LayerNorm.build(config)
        self.ff_norm = LayerNorm.build(config)
        self.__cache = cache

        # Attention input projections: x -> q, k, v (separate linears).
        head_dim = config.d_model // config.n_heads
        kv_dim = config.effective_n_kv_heads * head_dim
        qkv_bias = config.include_bias | config.include_qkv_bias
        self.q_proj = nn.Linear(config.d_model, config.d_model, bias=qkv_bias, device=config.init_device)
        self.k_proj = nn.Linear(config.d_model, kv_dim, bias=qkv_bias, device=config.init_device)
        self.v_proj = nn.Linear(config.d_model, kv_dim, bias=qkv_bias, device=config.init_device)

        # Feed-forward input projections (gate and up branches of the gated MLP).
        self.ff_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )
        self.up_proj = nn.Linear(
            config.d_model, self.hidden_size, bias=config.include_bias, device=config.init_device
        )

    def reset_parameters(self):
        super().reset_parameters()
        self.attn_norm.reset_parameters()
        self.ff_norm.reset_parameters()
        # NOTE: the standard deviation for these weights does not depend on the layer.
        for proj in (self.q_proj, self.k_proj, self.v_proj, self.ff_proj, self.up_proj):
            init_weights(self.config, proj, d=self.config.d_model, layer_id=None)

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.Tensor] = None,
        layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[Tuple[torch.Tensor, torch.Tensor]]]:
        """Attention + gated MLP with residuals.

        q/k/v shapes:
          - regular attn: each (batch_size, seq_len, d_model)
          - multi-query attn: k, v are (batch_size, seq_len, d_model // n_heads)
          - grouped-query attn: k, v are (batch_size, seq_len, d_model // n_kv_heads)
        """
        checkpoint = self._activation_checkpoint_fn

        # q/k/v projections of the attention-normalized input.
        normed = self.attn_norm(x)
        q = self.q_proj(normed)
        k = self.k_proj(normed)
        v = self.v_proj(normed)

        # Self-attention (optionally checkpointed).
        if checkpoint is not None:
            att, cache = checkpoint(  # type: ignore
                self.attention, q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache
            )
        else:
            att, cache = self.attention(q, k, v, attention_bias, layer_past=layer_past, use_cache=use_cache)

        # Residual connection around attention.
        # shape: (B, T, C)
        x = x + self.dropout(att)

        # Gated feed-forward with its own residual connection.
        # shape: (batch_size, seq_len, d_model)
        residual = x
        h = checkpoint(self.ff_norm, x) if checkpoint is not None else self.ff_norm(x)  # type: ignore
        gate, up = self.ff_proj(h), self.up_proj(h)
        gate = checkpoint(self.act, gate) if checkpoint is not None else self.act(gate)  # type: ignore
        h = self.dropout(self.ff_out(gate * up))
        return residual + h, cache
938
+
939
+
940
class LLaDAOutput(NamedTuple):
    """Structured result of a :class:`LLaDAModel` forward pass."""

    logits: torch.FloatTensor
    """
    A tensor of shape `(batch_size, seq_len, vocab_size)` representing the log probabilities
    for the next token *before* normalization via (log) softmax.
    """

    attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]]
    """
    Attention keys and values from each block.
    """

    hidden_states: Optional[Tuple[torch.Tensor]]
    """
    Hidden states from each block.
    """
956
+
957
+
958
class LLaDAGenerateOutput(NamedTuple):
    """Result of beam-search generation."""

    token_ids: torch.LongTensor
    """
    The generated token IDs, a tensor of shape `(batch_size, beam_size, max_steps)`.
    These do *not* include the original input IDs.
    """

    scores: torch.FloatTensor
    """
    The scores of the generated sequences, a tensor of shape `(batch_size, beam_size)`.
    """
969
+
970
+
971
class LLaDABlockGroup(nn.ModuleList):
    """A group of consecutive transformer blocks, executed in sequence.

    ``layer_offset`` is the global index of this group's first block; it is used
    to decide which blocks get activation checkpointing under periodic strategies.
    """

    def __init__(self, config: ModelConfig, layer_offset: int, modules: Optional[Iterable[nn.Module]] = None):
        super().__init__(modules)
        self.config = config
        self.layer_offset = layer_offset
        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
        self._activation_checkpoint_fn = activation_checkpoint_function(self.config)

    def _should_checkpoint(self, global_block_idx: int) -> bool:
        # Whether the current strategy checkpoints the block at this global index.
        strategy = self.activation_checkpointing_strategy
        if strategy == ActivationCheckpointingStrategy.whole_layer:
            return True
        if strategy == ActivationCheckpointingStrategy.one_in_two:
            return global_block_idx % 2 == 0
        if strategy == ActivationCheckpointingStrategy.one_in_three:
            return global_block_idx % 3 == 0
        if strategy == ActivationCheckpointingStrategy.one_in_four:
            return global_block_idx % 4 == 0
        return False

    def forward(
        self,
        x: torch.Tensor,
        attention_bias: Optional[torch.FloatTensor] = None,
        layers_past: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[List[Tuple[torch.Tensor, torch.Tensor]]]]:
        """Run every block in order, threading ``x`` through and collecting KV caches."""
        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None
        for local_idx, block in enumerate(self):
            layer_past = None if layers_past is None else layers_past[local_idx]
            if self._should_checkpoint(local_idx + self.layer_offset):
                # shape: (batch_size, seq_len, d_model)
                x, cache = self._activation_checkpoint_fn(  # type: ignore
                    block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                )
            else:
                # shape: (batch_size, seq_len, d_model)
                x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
            if attn_key_values is not None:
                assert cache is not None
                attn_key_values.append(cache)
        return x, attn_key_values

    def reset_parameters(self):
        for block in self:
            block.reset_parameters()

    def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
        self.activation_checkpointing_strategy = strategy
        for block in self:
            block.set_activation_checkpointing(strategy)
1025
+
1026
+
1027
+ class LLaDAModel(nn.Module):
1028
    def __init__(self, config: ModelConfig, init_params: bool = True):
        """Build the LLaDA transformer described by ``config``.

        :param config: model hyperparameters.
        :param init_params: when True (and not on the meta device), initialize
            weights eagerly via :meth:`reset_parameters`.
        """
        super().__init__()
        self.config = config
        self.__cache = BufferCache()

        # Validate config.
        if self.config.alibi and self.config.flash_attention:
            raise Exception("ALiBi is currently not supported with FlashAttention")

        if self.config.alibi and self.config.rope:
            raise Exception("ALiBi and RoPE are mutually exclusive")

        if self.config.embedding_size is not None and self.config.embedding_size != self.config.vocab_size:
            if self.config.embedding_size < self.config.vocab_size:
                raise Exception("embedding size should be at least as big as vocab size")
            elif self.config.embedding_size % 128 != 0:
                import warnings

                warnings.warn(
                    "Embedding size is not a multiple of 128! This could hurt throughput performance.", UserWarning
                )

        self.activation_checkpointing_strategy: Optional[ActivationCheckpointingStrategy] = None
        self._activation_checkpoint_fn: Callable = activation_checkpoint_function(self.config)

        if not (
            0 < self.config.block_group_size <= self.config.n_layers
            and self.config.n_layers % self.config.block_group_size == 0
        ):
            raise Exception("n layers must be divisible by block group size")

        torch.backends.cuda.enable_flash_sdp(True)
        torch.backends.cuda.enable_mem_efficient_sdp(False)  # this is super slow so make sure torch won't use it

        # Core modules: token embedding, embedding dropout, final layer norm.
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(
                    config.embedding_size or config.vocab_size, config.d_model, device=config.init_device
                ),
                emb_drop=Dropout(config.embedding_dropout),
                ln_f=LayerNorm.build(config),
            )
        )

        # Transformer blocks, optionally wrapped into fixed-size groups
        # (useful for grouped FSDP wrapping / activation checkpointing).
        blocks = [LLaDABlock.build(i, config, self.__cache) for i in range(config.n_layers)]
        if self.config.block_group_size > 1:
            block_groups = [
                LLaDABlockGroup(config, i, blocks[i : i + config.block_group_size])
                for i in range(0, config.n_layers, config.block_group_size)
            ]
            self.transformer.update({"block_groups": nn.ModuleList(block_groups)})
        else:
            self.transformer.update({"blocks": nn.ModuleList(blocks)})

        # Learned absolute position embeddings only when neither ALiBi nor RoPE
        # supplies positional information.
        if not (self.config.alibi or self.config.rope):
            self.transformer.update(
                {"wpe": nn.Embedding(config.max_sequence_length, config.d_model, device=config.init_device)}
            )
        # Untied output head (with weight tying the wte embedding is reused instead).
        if not config.weight_tying:
            self.transformer.update(
                {
                    "ff_out": nn.Linear(
                        config.d_model,
                        config.embedding_size or config.vocab_size,
                        bias=config.include_bias,
                        device=config.init_device,
                    )
                }
            )
        # When `init_device="meta"` FSDP will call `reset_parameters()` to initialize weights.
        if init_params and self.config.init_device != "meta":
            self.reset_parameters()
        self.__num_fwd_flops: Optional[int] = None

        # Warm up cache.
        if self.config.alibi:
            get_causal_attention_bias(self.__cache, config.max_sequence_length, _non_meta_init_device(config))
            self.get_alibi_attention_bias(config.max_sequence_length, _non_meta_init_device(config))
1106
+
1107
+ def set_activation_checkpointing(self, strategy: Optional[ActivationCheckpointingStrategy]):
1108
+ self.activation_checkpointing_strategy = strategy
1109
+ if self.config.block_group_size != 1:
1110
+ for block_group in self.transformer.block_groups:
1111
+ block_group.set_activation_checkpointing(strategy)
1112
+ else:
1113
+ for block in self.transformer.blocks:
1114
+ block.set_activation_checkpointing(strategy)
1115
+
1116
+ @property
1117
+ def device(self) -> torch.device:
1118
+ device: torch.device = self.transformer.wte.weight.device # type: ignore
1119
+ if device.type == "meta":
1120
+ return _non_meta_init_device(self.config)
1121
+ else:
1122
+ return device
1123
+
1124
    def reset_parameters(self):
        """Initialize all model weights: embeddings, final norm, output head, and every block."""
        log.info("Initializing model parameters...")
        # Top-level embeddings / linear layers.
        init_weights(
            self.config,
            self.transformer.wte,  # type: ignore
            # Larger embedding std when logits are scaled at the output.
            std_factor=(0.5 * math.sqrt(self.config.d_model)) if self.config.scale_logits else 1.0,
            type_of_module=ModuleType.emb,
        )
        if hasattr(self.transformer, "wpe"):
            init_weights(self.config, self.transformer.wpe, type_of_module=ModuleType.emb)  # type: ignore

        # Top-level layer norm.
        self.transformer.ln_f.reset_parameters()  # type: ignore

        # Output weights (only present when weight tying is disabled).
        if hasattr(self.transformer, "ff_out"):
            init_weights(self.config, self.transformer.ff_out, type_of_module=ModuleType.final_out)  # type: ignore

        # Let the blocks handle themselves.
        if self.config.block_group_size == 1:
            for block in self.transformer.blocks:
                block.reset_parameters()
        else:
            for block_group in self.transformer.block_groups:
                block_group.reset_parameters()
1150
+
1151
    def get_alibi_attention_bias(self, seq_len: int, device: torch.device) -> torch.Tensor:
        """Return a cached (or freshly built) ALiBi attention bias covering `seq_len` positions.

        The cached tensor may be longer than `seq_len`; callers slice it to size.
        Note: `self.__cache` is name-mangled, so this attribute is private to the
        enclosing class.
        """
        # Reuse the cached bias when it already covers the requested length.
        if (alibi_bias := self.__cache.get("alibi_attention_bias")) is not None and alibi_bias.shape[
            -1
        ] >= seq_len:
            if alibi_bias.device != device:
                # Migrate the cached tensor (e.g. CPU -> GPU) and keep the moved copy.
                alibi_bias = alibi_bias.to(device)
                self.__cache["alibi_attention_bias"] = alibi_bias
            return alibi_bias
        # Build outside autocast so the bias is computed in full precision.
        with torch.autocast(device.type, enabled=False):
            alibi_bias = alibi_attention_bias(seq_len, self.config, device)
        self.__cache["alibi_attention_bias"] = alibi_bias
        return alibi_bias
1163
+
1164
    def forward(
        self,
        input_ids: torch.LongTensor,
        input_embeddings: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[Sequence[Tuple[torch.Tensor, torch.Tensor]]] = None,
        use_cache: bool = False,
        update_kvcache: bool = False,
        last_logits_only: bool = False,
        output_hidden_states: Optional[bool] = None,
    ) -> LLaDAOutput:
        """
        :param input_ids: A tensor of shape `(batch_size, seq_len)`.
        :param input_embeddings: A tensor of shape `(batch_size, seq_len, d_model)` with input
            embeddings. When provided, it is treated as the output of the input embedding layer.
        :param attention_mask: A tensor of shape `(batch_size, seq_len)` that indicates
            which input IDs are masked. A `1` value in the mask means that
            the corresponding input ID should *not* be ignored. A `0` means
            that the corresponding input ID is masked.

            This has the same meaning as the `attention_mask` in HuggingFace's `transformers`
            library.
        :param attention_bias: A tensor of shape `(batch_size, 1, seq_len, seq_len)`,
            `(1, 1, seq_len, seq_len)`, or `(seq_len, seq_len)`. This is used
            to introduce causal or other biases.

            If the tensor is a bool or byte tensor, a `True` or `1` at `attention_bias[:, :, i, j]`
            indicates that the i-th element in the sequence is allowed to attend to the j-th
            element in the sequence.

            If the tensor is a float tensor, it will just be added to the attention
            scores before the softmax.

            The default is causal, which corresponds to a lower-diagonal byte matrix of ones.
        :param past_key_values: Pre-computed keys and values for each attention block.
            Can be used to speed up sequential decoding. The `input_ids` which have
            their past given to this model should not be passed as `input_ids` as they have already been computed.
        :param use_cache: If `True`, return key and value tensors for each block.
        :param update_kvcache: Doubles as an integer prefix length: when truthy, each returned
            layer cache is truncated to the first `update_kvcache` positions (presumably the
            sequence dimension of the cache tensors — confirm against the block's cache layout).
            When falsy while `use_cache` is set, the incoming `past_key_values` are returned
            unchanged instead of the freshly computed caches.
        :param last_logits_only: If `True`, only compute the logits for the last token of each sequence.
            This can speed up decoding when you only care about the next token.
        :param output_hidden_states: If `True`, also collect the input to every block plus the
            post-final-layer-norm output and return them as `hidden_states`.
        """
        # Basic MDM model config check: this encoder is bidirectional, ALiBi is unsupported
        # and RoPE is required.
        # print(input_ids.dtype)
        assert not self.config.alibi, "Alibi length extrapolation is not supported for MDM."
        assert self.config.rope, "Rope must be used in Llama-Encoder for MDM."
        # assert (past_key_values is None and not use_cache), "The kvcache is not supported for MDM."

        output_hidden_states = output_hidden_states if output_hidden_states is not None else False

        if past_key_values:
            assert len(past_key_values) == self.config.n_layers

        batch_size, seq_len = input_ids.size() if input_embeddings is None else input_embeddings.size()[:2]
        if past_key_values is None:
            past_length = 0
        else:
            # Length already covered by the cache (keys are (..., past_len, head_dim)).
            past_length = past_key_values[0][0].size(-2)

        # Get embeddings of input.
        # shape: (batch_size, seq_len, d_model)
        x = self.transformer.wte(input_ids) if input_embeddings is None else input_embeddings  # type: ignore

        if self.config.input_emb_norm:
            # Scale embeddings by sqrt(d_model), as in the original Transformer.
            x = x * (self.config.d_model**0.5)

        if not (self.config.alibi or self.config.rope):
            # Get positional embeddings.
            # shape: (1, seq_len)
            pos = torch.arange(past_length, past_length + seq_len, dtype=torch.long, device=x.device).unsqueeze(0)
            # shape: (1, seq_len, d_model)
            pos_emb = self.transformer.wpe(pos)  # type: ignore
            x = pos_emb + x

        # Add input + positional embeddings and apply dropout.
        # shape: (batch_size, seq_len, d_model)
        x = self.transformer.emb_drop(x)  # type: ignore

        # Transform the attention mask into what the blocks expect.
        if attention_mask is not None and 0.0 in attention_mask:
            # shape: (batch_size, 1, 1, seq_len)
            attention_mask = attention_mask.to(dtype=torch.float).view(batch_size, -1)[:, None, None, :]
            attention_mask = (1.0 - attention_mask) * torch.finfo(attention_mask.dtype).min
        else:
            # An all-ones mask is a no-op; drop it so the bias merging below is skipped.
            attention_mask = None

        # Merge attention mask with attention bias.
        if (
            attention_bias is not None
            or attention_mask is not None
            or self.config.alibi
            # NOTE (epwalsh): we need to initialize the attn bias in order for attn to work properly
            # with key+value cache. Otherwise `F.scaled_dot_product_attention()` doesn't seem to compute
            # scores correctly.
            or past_key_values is not None
        ):
            if attention_bias is None and self.config.alibi:
                attention_bias = get_causal_attention_bias(
                    self.__cache, past_length + seq_len, x.device
                ) + self.get_alibi_attention_bias(past_length + seq_len, x.device)
            elif attention_bias is None:
                attention_bias = get_causal_attention_bias(self.__cache, past_length + seq_len, x.device)
            elif attention_bias.dtype in (torch.int8, torch.bool):
                # Convert boolean "may attend" masks into additive float biases.
                attention_bias = attention_bias.to(dtype=torch.float)
                attention_bias.masked_fill_(attention_bias == 0.0, torch.finfo(attention_bias.dtype).min)

            # Transform to the right shape and data type.
            mask_len = seq_len
            if attention_mask is not None:
                mask_len = attention_mask.shape[-1]
            elif past_key_values is not None:
                mask_len = past_key_values[0][0].shape[-2] + seq_len
            attention_bias = attention_bias[:, :, :mask_len, :mask_len].to(dtype=torch.float)

            # Add in the masking bias.
            if attention_mask is not None:
                attention_bias = attention_bias + attention_mask
                # Might get -infs after adding attention mask, since dtype.min + dtype.min = -inf.
                # `F.scaled_dot_product_attention()` doesn't handle -inf like you'd expect, instead
                # it can produce NaNs.
                ensure_finite_(attention_bias, check_neg_inf=True, check_pos_inf=False)

        attn_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = [] if use_cache else None

        # decoder layers
        all_hidden_states = []

        # Apply blocks one-by-one.
        if self.config.block_group_size == 1:
            for block_idx, block in enumerate(self.transformer.blocks):
                if output_hidden_states:
                    # add hidden states (input to each block)
                    all_hidden_states.append(x)

                layer_past = None if past_key_values is None else past_key_values[block_idx]
                # Checkpoint every layer, or every 2nd/3rd/4th layer, per the configured strategy.
                if (
                    (self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.whole_layer)
                    or (
                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_two
                        and block_idx % 2 == 0
                    )
                    or (
                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_three
                        and block_idx % 3 == 0
                    )
                    or (
                        self.activation_checkpointing_strategy == ActivationCheckpointingStrategy.one_in_four
                        and block_idx % 4 == 0
                    )
                ):
                    # shape: (batch_size, seq_len, d_model)
                    x, cache = self._activation_checkpoint_fn(
                        block, x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache
                    )
                else:
                    # shape: (batch_size, seq_len, d_model)
                    x, cache = block(x, attention_bias=attention_bias, layer_past=layer_past, use_cache=use_cache)
                if attn_key_values is not None:
                    if update_kvcache:
                        # NOTE(review): `update_kvcache` is used as an int prefix length here —
                        # only the first `update_kvcache` positions of the layer cache are kept.
                        cache = (cache[0][:,:,:update_kvcache],cache[1][:,:,:update_kvcache,:])
                        # print("True")
                    attn_key_values.append(cache)
        else:
            for group_idx, block_group in enumerate(self.transformer.block_groups):
                if output_hidden_states:
                    # add hidden states (input to each block group)
                    all_hidden_states.append(x)

                layers_past = (
                    None
                    if past_key_values is None
                    else past_key_values[
                        group_idx * self.config.block_group_size : (group_idx + 1) * self.config.block_group_size
                    ]
                )
                x, cache = block_group(
                    x, attention_bias=attention_bias, layers_past=layers_past, use_cache=use_cache
                )
                if attn_key_values is not None:
                    assert cache is not None
                    attn_key_values.extend(cache)

        if last_logits_only:
            # shape: (batch_size, 1, d_model)
            x = x[:, -1, :].unsqueeze(1)

        # Apply final layer norm.
        # shape: (batch_size, seq_len or 1, d_model)
        x = self.transformer.ln_f(x)  # type: ignore
        if output_hidden_states:
            # add final hidden state post-final-layernorm, following HuggingFace's convention
            all_hidden_states.append(x)

        # Get logits.
        # shape: (batch_size, seq_len or 1, vocab_size)
        if self.config.weight_tying:
            logits = F.linear(x, self.transformer.wte.weight, None)  # type: ignore
        else:
            logits = self.transformer.ff_out(x)  # type: ignore
        if self.config.scale_logits:
            logits.mul_(1 / math.sqrt(self.config.d_model))
        # Caching was requested but no cache update: hand back the caller's caches unchanged.
        if use_cache == True and update_kvcache == False:
            attn_key_values=past_key_values
        return LLaDAOutput(logits=logits, attn_key_values=attn_key_values, hidden_states=tuple(all_hidden_states) if output_hidden_states else None)  # type: ignore[arg-type]
1369
+
1370
+
1371
def create_model_config_from_pretrained_config(config: LLaDAConfig):
    """Build a `ModelConfig` by copying every `ModelConfig` field off the HF config.

    Assumes the pretrained `LLaDAConfig` carries an attribute for each
    dataclass field of `ModelConfig`.
    """
    field_values = {f.name: getattr(config, f.name) for f in fields(ModelConfig)}
    return ModelConfig(**field_values)
1382
+
1383
+
1384
class LLaDAModelLM(PreTrainedModel):
    """
    Extremely barebones HF model wrapper.

    Delegates all computation to an inner `LLaDAModel` (`self.model`) and adapts
    its inputs/outputs to the HuggingFace `PreTrainedModel` interface.
    """

    config_class = LLaDAConfig
    base_model_prefix = "model"
    # Module classes that must not be split across devices/shards.
    _no_split_modules = ["LLaDABlock", "LLaDASequentialBlock", "LLaDALlamaBlock"]

    def __init__(self, config: LLaDAConfig, model: Optional[LLaDAModel] = None, init_params: bool = False):
        """Wrap an existing `LLaDAModel`, or build one from `config` when none is given."""
        super().__init__(config)

        if not model:
            model_config = create_model_config_from_pretrained_config(config)
            # Initialize model (always on CPU to start with so we don't run out of GPU memory).
            model_config.init_device = "cpu"
            self.model = LLaDAModel(model_config, init_params=init_params)
        else:
            self.model = model

    def forward(
        self,
        # NOTE(review): default None contradicts the non-Optional annotation; callers
        # are expected to always pass `input_ids` or `inputs_embeds`.
        input_ids: torch.LongTensor = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_bias: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        update_kvcache: Optional[bool] = False,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[Cache] = None,  # This is a hack mitigation of an issue in transformers `4.39.x`
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """HF-style forward that delegates to `LLaDAModel.forward`.

        No loss is computed even when `labels` is given (a warning is emitted);
        `output_attentions` is unsupported and raises.
        """
        if use_cache is None:
            use_cache = self.config.use_cache

        if output_attentions:
            raise ValueError("output_attentions is not yet supported in LLaDA")

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model.forward(
            input_ids=input_ids,
            input_embeddings=inputs_embeds,
            attention_mask=attention_mask,
            attention_bias=attention_bias,
            past_key_values=past_key_values,
            use_cache=use_cache,
            update_kvcache=update_kvcache,
            output_hidden_states=output_hidden_states,
        )

        logits = outputs.logits
        hidden_states = outputs.hidden_states

        loss = None
        if labels is not None:
            import warnings
            warnings.warn("Note that for LLaDA, you cannot calculate the loss here.", UserWarning)
        if not return_dict:
            # Legacy tuple return path.
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            logits=logits,
            past_key_values=outputs.attn_key_values,
            hidden_states=hidden_states,
        )

    def can_generate(self) -> bool:
        # Opt in to HuggingFace's `generate()` capability check.
        return True

    def prepare_inputs_for_generation(
        self, input_ids: torch.LongTensor, past_key_values: Optional[List[Tuple]] = None, **kwargs
    ):
        if past_key_values:
            # This is because we want the model to only process the last generated token.
            input_ids = input_ids[:, -1:]
        model_inputs = {"input_ids": input_ids, "past_key_values": past_key_values}

        model_inputs.update(kwargs)
        # Guarantee `use_cache` is always present, defaulting to the config value.
        model_inputs["use_cache"] = kwargs.pop("use_cache", self.config.use_cache)
        return model_inputs

    # TODO: these are required to make the implementation complete.
    # def resize_position_embeddings(self, new_num_position_embeddings: int):
    #     pass
    #
    # def get_position_embeddings(self) -> Union[nn.Embedding, Tuple[nn.Embedding]]:
    #     pass
    #
    # def _reorder_cache(self, past_key_values, beam_idx):
    #     pass

    def get_input_embeddings(self) -> torch.nn.Module:
        return self.model.transformer.wte

    def set_input_embeddings(self, value: torch.nn.Module):
        self.model.transformer.wte = value

    def get_output_embeddings(self):
        # With tied weights the input embedding doubles as the output projection.
        if self.config.weight_tying:
            return self.model.transformer.wte
        else:
            return self.model.transformer.ff_out

    def set_output_embeddings(self, value: torch.nn.Module):
        if self.config.weight_tying:
            self.model.transformer.wte = value
        else:
            self.model.transformer.ff_out = value

    def tie_weights(self):
        # Under weight tying, point the output head at the input embedding module.
        if self.config.weight_tying:
            self.model.transformer.ff_out = self.model.transformer.wte
1502
+
1503
# Register the model so that it is available for transformer pipelines, auto-loading, etc.
# Maps LLaDAConfig -> LLaDAModelLM for `AutoModel.from_pretrained(...)`.
AutoModel.register(LLaDAConfig, LLaDAModelLM)
postprocess_code.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # Modified from Dream repos: https://github.com/HKUNLP/Dream
17
+
18
+ import evaluate as hf_evaluate
19
+ import os
20
+ import sys
21
+ from sanitize import sanitize
22
+
23
# Allow HF's `code_eval` metric to execute the (untrusted) generated code.
os.environ["HF_ALLOW_CODE_EVAL"] = "1"
pass_at_k = hf_evaluate.load("code_eval")

import json


def pass_at_1(references, predictions):
    """pass@1 for one batch of references/candidate lists via HF `code_eval`."""
    scores = pass_at_k.compute(
        references=references,
        predictions=predictions,
        k=[1],
    )
    return scores[0]["pass@1"]


def read_jsonl(file_path):
    """Load a JSONL file into a list of parsed objects."""
    with open(file_path, 'r') as file:
        return [json.loads(line) for line in file]


def write_jsonl(data, file_path):
    """Write a list of objects to `file_path`, one JSON object per line."""
    with open(file_path, 'w') as file:
        file.writelines(json.dumps(item) + '\n' for item in data)


file_path = sys.argv[1]
data = read_jsonl(file_path)

references = [sample['target'] for sample in data]

# Strip the markdown code fence from each response, prepend the task prompt,
# and sanitize down to the entry point's dependency closure.
predictions = [
    [
        sanitize(
            sample['doc']['prompt'] + "\n" + sample['resps'][0][0].split('```python\n', 1)[-1].split('```')[0],
            sample['doc']["entry_point"],
        )
    ]
    for sample in data
]

pass_at_1s = [pass_at_1([ref], [pred]) for ref, pred in zip(references, predictions)]
print(sum(pass_at_1s) / len(pass_at_1s))

cleaned = [
    {"task_id": sample['doc']['task_id'], "completion": pred, "pass_at_1": score}
    for sample, pred, score in zip(data, predictions, pass_at_1s)
]
write_jsonl(cleaned, file_path + '.cleaned')
sanitize.py ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 NVIDIA CORPORATION & AFFILIATES
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ #
15
+ # SPDX-License-Identifier: Apache-2.0
16
+ # Modified from Dream repos: https://github.com/HKUNLP/Dream
17
+
18
+ """Post-processing LLM-generated Python code implemented using tree-sitter."""
19
+
20
+ import os
21
+ import sys
22
+ import pathlib
23
+
24
# Make this file's parent and grandparent directories importable so sibling
# modules resolve when the script is run directly.
ROOT = os.path.dirname(os.path.abspath(__file__))
sys.path.extend([os.path.dirname(ROOT), os.path.dirname(os.path.dirname(ROOT))])
26
+
27
+ import ast
28
+ import traceback
29
+
30
+ from typing import Dict, List, Optional, Set, Tuple
31
+
32
def refine_text(text: str) -> str:
    """Normalize whitespace: spaces for tabs, LF endings, single trailing newline."""
    normalized = text.replace("\t", " ").replace("\r\n", "\n").replace("\r", "\n")
    return normalized.strip() + "\n"
36
+
37
def syntax_check(code, verbose = False):
    """Return True iff `code` parses as valid Python source."""
    try:
        ast.parse(code)
    except (SyntaxError, MemoryError):
        # Optionally surface the parse failure for debugging.
        if verbose:
            traceback.print_exc()
        return False
    return True
45
+
46
def extract_longest_valid_code(text: str) -> str:
    """Return the contiguous span of lines that parses as Python and contains
    the most non-blank lines.

    The scan is quadratic in the number of lines, so input is capped at 100
    lines to bound the cost on pathological outputs.
    """
    lines = text.splitlines()

    if len(lines) > 100:
        lines = lines[:100]

    best_count = 0
    best_snippet = ""
    for start in range(len(lines)):
        for end in range(start, len(lines)):
            snippet = "\n".join(lines[start:end + 1])
            if not syntax_check(snippet):
                continue
            non_blank = sum(1 for ln in lines[start:end + 1] if ln.strip())
            if non_blank > best_count:
                best_count = non_blank
                best_snippet = snippet

    return best_snippet
64
+
65
def get_deps(nodes: List[Tuple[str, ast.AST]]) -> Dict[str, Set[str]]:
    """Map each definition name to the identifiers its AST subtree references.

    Traversal deliberately stops at `Name`/`Attribute` nodes: it records the
    identifier (or attribute name) but does not descend into their children.
    """
    name2deps = {}
    for name, node in nodes:
        referenced = set()
        pending = [node]
        while pending:
            current = pending.pop()
            for child in ast.iter_child_nodes(current):
                if isinstance(child, ast.Name):
                    referenced.add(child.id)
                elif isinstance(child, ast.Attribute):
                    referenced.add(child.attr)
                else:
                    pending.append(child)
        name2deps[name] = referenced
    return name2deps
81
+
82
def get_function_dependency(entrypoint: str, call_graph: Dict[str, Set[str]]) -> Set[str]:
    """Collect every name transitively reachable from `entrypoint` in `call_graph`.

    Breadth-first walk; names absent from the graph contribute no successors.
    """
    visited = set()
    frontier = [entrypoint]
    while frontier:
        name = frontier.pop(0)
        if name in visited:
            continue
        visited.add(name)
        frontier.extend(call_graph.get(name, set()) - visited)
    return visited
93
+
94
def get_definition_name(node: ast.AST) -> Optional[str]:
    """Return the name bound by a def/class, or by a simple single-name assignment.

    Tuple targets and other assignment shapes yield None.
    """
    if isinstance(node, (ast.FunctionDef, ast.ClassDef)):
        return node.name
    if isinstance(node, ast.Assign) and node.targets and isinstance(node.targets[0], ast.Name):
        return node.targets[0].id
    return None
102
+
103
def has_return_statement(node: ast.AST) -> bool:
    """True iff the subtree rooted at `node` contains a `return` statement."""
    for sub in ast.walk(node):
        if isinstance(sub, ast.Return):
            return True
    return False
105
+
106
def sanitize(text: str, entrypoint: Optional[str] = None) -> str:
    """Reduce raw LLM output to imports plus top-level definitions.

    The text is whitespace-normalized, the longest syntactically valid line
    span is extracted, and the surviving module body is filtered: imports are
    always kept; classes, functions-with-a-return, and single-name assignments
    are kept — restricted to `entrypoint`'s transitive dependency closure when
    an entrypoint is given.
    """
    code = extract_longest_valid_code(refine_text(text))
    tree = ast.parse(code)

    imports = []
    definitions = {}
    for node in tree.body:
        if isinstance(node, (ast.Import, ast.ImportFrom)):
            imports.append(node)
        elif isinstance(node, ast.ClassDef):
            definitions[node.name] = ('class', node)
        elif isinstance(node, ast.FunctionDef):
            # Functions without a return statement are assumed to be noise.
            if has_return_statement(node):
                definitions[node.name] = ('function', node)
        elif isinstance(node, ast.Assign):
            name = get_definition_name(node)
            if name:
                definitions[name] = ('variable', node)

    reachable = None
    if entrypoint:
        name2deps = get_deps([(name, node) for name, (_, node) in definitions.items()])
        reachable = get_function_dependency(entrypoint, name2deps)

    pieces = [ast.unparse(node) for node in imports]
    for name, (_, node) in definitions.items():
        if not entrypoint or name in reachable:
            pieces.append(ast.unparse(node))

    return "\n".join(pieces)