YongganFu commited on
Commit
41d881f
·
verified ·
1 Parent(s): 8c7caf4

Upload model

Browse files
Files changed (2) hide show
  1. chat_utils.py +196 -0
  2. modeling_nvrdiff.py +44 -4
chat_utils.py ADDED
@@ -0,0 +1,196 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import os
5
+ import sys
6
+ import argparse
7
+ import random
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.nn.functional as F
12
+ from transformers import AutoTokenizer
13
+
14
+ sys.path.insert(1, "/lustre/fsw/portfolios/nvr/users/yongganf/adlr-megatron-lm")
15
+ from get_hf_model import get_torchtitan_model_sft # noqa: E402
16
+
17
+
18
+ # --------------------------- Reproducibility ----------------------------------
19
+ def set_seed(seed: int = 42):
20
+ torch.manual_seed(seed)
21
+ random.seed(seed)
22
+ np.random.seed(seed)
23
+ torch.backends.cudnn.deterministic = True
24
+ torch.backends.cudnn.benchmark = False
25
+
26
+
27
+ # -------------------- Diffusion helpers (unchanged logic) --------------------
28
+ def get_transfer_index(
29
+ logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False
30
+ ):
31
+ x0 = torch.argmax(logits, dim=-1) # (B, L)
32
+ if remasking == "low_confidence":
33
+ p = F.softmax(logits, dim=-1)
34
+ x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
35
+ elif remasking == "random":
36
+ x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
37
+ else:
38
+ raise NotImplementedError(remasking)
39
+
40
+ if neg_entropy:
41
+ p = F.softmax(logits, dim=-1)
42
+ epsilon = 1e-10
43
+ log_probs = torch.log(p + epsilon)
44
+ confidence_scores = torch.sum(p * log_probs, dim=-1)
45
+ else:
46
+ confidence_scores = x0_p
47
+
48
+ x0 = torch.where(mask_index, x0, x)
49
+ confidence = torch.where(mask_index, confidence_scores, torch.tensor(float("-inf"), device=x0.device))
50
+
51
+ transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
52
+ if threshold is not None:
53
+ num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
54
+
55
+ for j in range(confidence.shape[0]):
56
+ k = int(num_transfer_tokens[j])
57
+ k = max(k, 1)
58
+ _, select_index = torch.topk(confidence[j], k=k)
59
+ transfer_index[j, select_index] = True
60
+ if threshold is not None:
61
+ for kk in range(k):
62
+ if confidence[j, select_index[kk]] < threshold:
63
+ transfer_index[j, select_index[kk]] = False
64
+
65
+ return x0, transfer_index
66
+
67
+
68
+ def get_num_transfer_tokens(mask_index, steps: int):
69
+ mask_num = mask_index.sum(dim=1, keepdim=True)
70
+ base = mask_num // steps
71
+ remainder = mask_num % steps
72
+ num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base
73
+ for i in range(mask_num.size(0)):
74
+ num_transfer_tokens[i, : int(remainder[i])] += 1
75
+ return num_transfer_tokens
76
+
77
+
78
+
79
+ @torch.no_grad()
80
+ def generate_with_prefix_cache_block_diff(
81
+ model,
82
+ prompt,
83
+ steps=128,
84
+ gen_length=128,
85
+ block_length=32,
86
+ temperature=0.,
87
+ remasking='low_confidence',
88
+ mask_id=151662,
89
+ threshold=None,
90
+ shift_logits=True,
91
+ neg_entropy=True
92
+ ):
93
+ dream_style=shift_logits
94
+ # Initialize the accumulator
95
+ x_accum = prompt.clone()
96
+
97
+ assert gen_length % block_length == 0
98
+ num_blocks = gen_length // block_length
99
+
100
+ assert steps % num_blocks == 0
101
+ steps_per_block = steps // num_blocks
102
+
103
+ nfe = 0
104
+
105
+ # Compute KV cache for the prompt initially
106
+ output = model(prompt, use_cache=True)
107
+ past_key_values = output.past_key_values
108
+
109
+ # For dream_style: store the "next token logit" of the context
110
+ next_logits_context = None
111
+ if dream_style:
112
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
113
+
114
+ for num_block in range(num_blocks):
115
+ # Create a new block with mask tokens (no seeding)
116
+ mask_block = torch.ones(
117
+ (prompt.shape[0], block_length),
118
+ dtype=prompt.dtype,
119
+ device=prompt.device
120
+ ) * mask_id
121
+
122
+ # Append the block of masks
123
+ x_accum = torch.cat([x_accum, mask_block], dim=1)
124
+ current_block_start = prompt.size(1) + num_block * block_length
125
+ block_slice = slice(current_block_start, current_block_start + block_length)
126
+
127
+ # Build the initial mask for this block
128
+ mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
129
+
130
+ schedule_mask = mask_block_idx0
131
+
132
+ num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
133
+
134
+ # Denoise the current block
135
+ for i in range(steps_per_block):
136
+ mask_block_idx = (x_accum[:, block_slice] == mask_id) # (B, Lb)
137
+ if mask_block_idx.sum() == 0:
138
+ break
139
+
140
+ nfe += 1
141
+
142
+ # Forward only the current noisy block using cached context
143
+ logits_block = model(
144
+ x_accum[:, block_slice],
145
+ past_key_values=past_key_values,
146
+ use_cache=False
147
+ ).logits
148
+
149
+ if dream_style:
150
+ # Align logits so that each masked position has a predictor:
151
+ # prepend context-next logit, then use logits_block[:-1]
152
+ if block_length == 1:
153
+ logits_use = next_logits_context # (B, 1, V)
154
+ else:
155
+ logits_use = torch.cat(
156
+ [next_logits_context, logits_block[:, :-1, :]],
157
+ dim=1
158
+ ) # (B, Lb, V)
159
+
160
+ mask_use = mask_block_idx # (B, Lb)
161
+ x_use = x_accum[:, block_slice] # (B, Lb)
162
+
163
+ x0, transfer_idx = get_transfer_index(
164
+ logits_use, temperature, remasking, mask_use, x_use,
165
+ num_transfer_tokens=num_transfer_tokens[:, i],
166
+ threshold=threshold, neg_entropy=neg_entropy
167
+ )
168
+ cur = x_accum[:, block_slice].clone()
169
+ cur[transfer_idx] = x0[transfer_idx]
170
+ x_accum[:, block_slice] = cur
171
+
172
+ else:
173
+ # non-AR (same-position) case
174
+ x0, transfer_idx = get_transfer_index(
175
+ logits_block, temperature, remasking, mask_block_idx,
176
+ x_accum[:, block_slice],
177
+ num_transfer_tokens=num_transfer_tokens[:, i],
178
+ threshold=threshold, neg_entropy=neg_entropy
179
+ )
180
+ cur = x_accum[:, block_slice].clone()
181
+ cur[transfer_idx] = x0[transfer_idx]
182
+ x_accum[:, block_slice] = cur
183
+
184
+ # after block is fully denoised, update KV cache
185
+ output = model(
186
+ x_accum[:, block_slice],
187
+ past_key_values=past_key_values,
188
+ use_cache=True
189
+ )
190
+ past_key_values = output.past_key_values
191
+
192
+ if dream_style and num_block < num_blocks - 1:
193
+ # refresh context-next logit for the next block
194
+ next_logits_context = output.logits[:, -1:, :] # (B, 1, V)
195
+
196
+ return x_accum, nfe
modeling_nvrdiff.py CHANGED
@@ -7,9 +7,6 @@ import torch.nn.functional as F
7
  from torch import nn
8
  from transformers.modeling_outputs import CausalLMOutputWithPast
9
 
10
- from .modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel, Qwen3Attention, apply_rotary_pos_emb, repeat_kv
11
- from .configuration_nvrdiff import NVRDiffConfig
12
-
13
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
14
 
15
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
@@ -24,6 +21,10 @@ from transformers.generation import GenerationMixin
24
 
25
  import math
26
 
 
 
 
 
27
  # @torch.compile(dynamic=True, mode="reduce-overhead")
28
  # @torch.compile(mode="default")
29
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
@@ -532,4 +533,43 @@ class DiffEncoderModel(Qwen3PreTrainedModel, GenerationMixin):
532
  hidden_states=None,
533
  attentions=None,
534
  )
535
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  from torch import nn
8
  from transformers.modeling_outputs import CausalLMOutputWithPast
9
 
 
 
 
10
  from torch.nn.attention.flex_attention import flex_attention, create_block_mask
11
 
12
  from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
 
21
 
22
  import math
23
 
24
+ from .modeling_qwen3 import Qwen3Model, Qwen3PreTrainedModel, Qwen3Attention, apply_rotary_pos_emb, repeat_kv
25
+ from .configuration_nvrdiff import NVRDiffConfig
26
+ from .chat_utils import generate_with_prefix_cache_block_diff
27
+
28
  # @torch.compile(dynamic=True, mode="reduce-overhead")
29
  # @torch.compile(mode="default")
30
  # @torch.compile(fullgraph=True, mode="reduce-overhead", dynamic=False)
 
533
  hidden_states=None,
534
  attentions=None,
535
  )
536
+
537
+
538
+ def chat(self, tokenizer, max_new_tokens, steps, block_length, threshold):
539
+ print("Stateless chat (type 'exit' to quit)")
540
+ print("------------------------------------")
541
+
542
+ try:
543
+ while True:
544
+ user_input = input("User: ").strip()
545
+ if user_input.lower() in {"exit", "quit", "q"}:
546
+ print("Conversation ended.")
547
+ break
548
+
549
+ prompt_ids = tokenizer(
550
+ user_input,return_tensors='pt'
551
+ ).input_ids.to(device='cuda')
552
+
553
+ out_ids, nfe = generate_with_prefix_cache_block_diff(
554
+ model=self,
555
+ prompt=prompt_ids,
556
+ gen_length=max_new_tokens,
557
+ steps=steps,
558
+ block_length=block_length,
559
+ remasking="low_confidence",
560
+ mask_id=self.mask_token_id,
561
+ threshold=threshold,
562
+ shift_logits=True,
563
+ neg_entropy=True,
564
+ )
565
+
566
+ generated_tokens = out_ids[:, prompt_ids.shape[1]:]
567
+ tokenized_out = tokenizer.batch_decode(
568
+ generated_tokens,
569
+ skip_special_tokens=True
570
+ )[0]
571
+ print(f"Model: {tokenized_out}")
572
+ print(f"[nfe={nfe}]")
573
+
574
+ except KeyboardInterrupt:
575
+ print("\n[info] interrupted by user (Ctrl-C).")