YongganFu committed on
Commit
b4c3108
·
verified ·
1 Parent(s): ad6359b

Upload model

Browse files
Files changed (2) hide show
  1. chat_utils.py +60 -34
  2. modeling_nvrdiff.py +2 -2
chat_utils.py CHANGED
@@ -1,54 +1,75 @@
1
- #!/usr/bin/env python3
2
- # -*- coding: utf-8 -*-
3
-
4
- import os
5
- import sys
6
- import argparse
7
- import random
8
-
9
  import numpy as np
10
  import torch
11
  import torch.nn.functional as F
12
- from transformers import AutoTokenizer
13
 
14
 
15
- def get_transfer_index(
16
- logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False
17
- ):
18
- x0 = torch.argmax(logits, dim=-1) # (B, L)
19
- if remasking == "low_confidence":
20
  p = F.softmax(logits, dim=-1)
21
- x0_p = torch.squeeze(torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1)
22
- elif remasking == "random":
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
24
  else:
25
  raise NotImplementedError(remasking)
26
-
 
27
  if neg_entropy:
 
28
  p = F.softmax(logits, dim=-1)
29
  epsilon = 1e-10
30
  log_probs = torch.log(p + epsilon)
31
- confidence_scores = torch.sum(p * log_probs, dim=-1)
32
  else:
33
  confidence_scores = x0_p
34
-
35
  x0 = torch.where(mask_index, x0, x)
36
- confidence = torch.where(mask_index, confidence_scores, torch.tensor(float("-inf"), device=x0.device))
37
 
38
  transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
39
  if threshold is not None:
40
  num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
41
-
42
  for j in range(confidence.shape[0]):
43
- k = int(num_transfer_tokens[j])
44
- k = max(k, 1)
45
- _, select_index = torch.topk(confidence[j], k=k)
46
  transfer_index[j, select_index] = True
47
  if threshold is not None:
48
- for kk in range(k):
49
- if confidence[j, select_index[kk]] < threshold:
50
- transfer_index[j, select_index[kk]] = False
51
-
52
  return x0, transfer_index
53
 
54
 
@@ -62,20 +83,20 @@ def get_num_transfer_tokens(mask_index, steps: int):
62
  return num_transfer_tokens
63
 
64
 
65
-
66
  @torch.no_grad()
67
  def generate_with_prefix_cache_block_diff(
68
  model,
69
  prompt,
70
  steps=128,
71
  gen_length=128,
72
- block_length=32,
73
  temperature=0.,
74
  remasking='low_confidence',
75
- mask_id=151662,
76
  threshold=None,
77
- shift_logits=True,
78
- neg_entropy=True
 
79
  ):
80
  dream_style=shift_logits
81
  # Initialize the accumulator
@@ -114,7 +135,12 @@ def generate_with_prefix_cache_block_diff(
114
  # Build the initial mask for this block
115
  mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
116
 
117
- schedule_mask = mask_block_idx0
 
 
 
 
 
118
 
119
  num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
120
 
 
 
 
 
 
 
 
 
 
1
  import numpy as np
2
  import torch
3
  import torch.nn.functional as F
 
4
 
5
 
6
def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, neg_entropy=False):
    """Choose which masked positions to commit in the current denoising step.

    Args:
        logits: Per-position vocabulary scores, indexed as (B, L, V).
        temperature: Accepted for interface compatibility; not used here
            (candidate tokens are always the arg-max of ``logits``).
        remasking: Confidence strategy, one of ``'low_confidence'`` (probability
            of the arg-max token), ``'top_p_margin'`` (top-1 minus top-2
            probability, min-max scaled over the masked positions of each row)
            or ``'random'`` (uniform noise).
        mask_index: Boolean (B, L) tensor, True where ``x`` is still masked.
        x: Current (B, L) token ids; kept wherever ``mask_index`` is False.
        num_transfer_tokens: Per-row number of tokens to commit; indexed as
            ``num_transfer_tokens[j]``. Ignored when ``threshold`` is given.
        threshold: If not None, every masked position becomes a candidate and
            only candidates whose confidence is not below ``threshold`` are
            committed. The single most confident candidate of each row is
            always committed so that every step makes progress.
        neg_entropy: If True, the confidence is the negative entropy of the
            predictive distribution instead of the ``remasking`` score.
            ``remasking`` is still validated (and ``'random'`` still draws
            from the RNG).

    Returns:
        Tuple ``(x0, transfer_index)``: ``x0`` holds the arg-max prediction at
        masked positions and ``x`` elsewhere; ``transfer_index`` is a boolean
        (B, L) tensor marking the positions to commit.

    Raises:
        NotImplementedError: If ``remasking`` is not a supported strategy.
    """
    x0 = torch.argmax(logits, dim=-1)  # (B, L)

    # The softmax is shared between the remasking score and the entropy score
    # so it is evaluated at most once per call.
    p = None
    if remasking == 'low_confidence':
        p = F.softmax(logits, dim=-1)
        x0_p = torch.gather(p, dim=-1, index=x0.unsqueeze(-1)).squeeze(-1)  # (B, L)
    elif remasking == 'top_p_margin':
        p = F.softmax(logits, dim=-1)
        x0_p = _normalized_top2_margin(p, mask_index)
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)

    if neg_entropy:
        if p is None:
            p = F.softmax(logits, dim=-1)
        epsilon = 1e-10  # keeps log() finite for zero-probability tokens
        confidence_scores = torch.sum(p * torch.log(p + epsilon), dim=-1)
    else:
        confidence_scores = x0_p

    x0 = torch.where(mask_index, x0, x)
    # Already-decoded positions must never win the top-k selection.
    confidence = confidence_scores.masked_fill(~mask_index, float('-inf'))

    transfer_index = torch.zeros_like(x0, dtype=torch.bool)
    if threshold is not None:
        # In threshold mode every masked position is a candidate.
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)

    for j in range(confidence.shape[0]):
        k = int(num_transfer_tokens[j])  # torch.topk documents k as an int
        _, select_index = torch.topk(confidence[j], k=k)
        transfer_index[j, select_index] = True
        if threshold is not None and k > 1:
            # select_index is sorted by descending confidence; element 0 (the
            # best candidate) is kept unconditionally, the others are kept
            # unless they fall below the threshold.
            rest = select_index[1:]
            transfer_index[j, rest] = ~(confidence[j, rest] < threshold)

    return x0, transfer_index


def _normalized_top2_margin(p, mask_index):
    """Return the top-1 minus top-2 probability margin, min-max scaled per row.

    Scaling statistics are taken over the masked positions of each row only.
    Masked positions of a row whose margins are all equal get 1; positions
    that are not masked get 0.
    """
    top2 = torch.topk(p, k=2, dim=-1).values  # (B, L, 2)
    margin = top2[..., 0] - top2[..., 1]  # (B, L)

    row_min = margin.masked_fill(~mask_index, float('inf')).amin(dim=1, keepdim=True)   # (B, 1)
    row_max = margin.masked_fill(~mask_index, float('-inf')).amax(dim=1, keepdim=True)  # (B, 1)
    denom = row_max - row_min
    has_spread = denom > 0

    scaled = (margin - row_min) / (denom + 1e-12)
    scaled = torch.where(has_spread, scaled, torch.ones_like(margin))
    return torch.where(mask_index, scaled, torch.zeros_like(margin))
74
 
75
 
 
83
  return num_transfer_tokens
84
 
85
 
 
86
  @torch.no_grad()
87
  def generate_with_prefix_cache_block_diff(
88
  model,
89
  prompt,
90
  steps=128,
91
  gen_length=128,
92
+ block_length=128,
93
  temperature=0.,
94
  remasking='low_confidence',
95
+ mask_id=126336,
96
  threshold=None,
97
+ factor=None,
98
+ shift_logits=False,
99
+ neg_entropy=False,
100
  ):
101
  dream_style=shift_logits
102
  # Initialize the accumulator
 
135
  # Build the initial mask for this block
136
  mask_block_idx0 = (x_accum[:, block_slice] == mask_id) # (B, Lb)
137
 
138
+ # Precompute the transfer schedule for this block
139
+ if dream_style:
140
+ # still denoise *all* positions (0..Lb-1), since none are seeded
141
+ schedule_mask = mask_block_idx0
142
+ else:
143
+ schedule_mask = mask_block_idx0
144
 
145
  num_transfer_tokens = get_num_transfer_tokens(schedule_mask, steps_per_block) # (B, steps)
146
 
modeling_nvrdiff.py CHANGED
@@ -546,7 +546,7 @@ class DiffEncoderModel(Qwen3PreTrainedModel, GenerationMixin):
546
  mask_id=self.mask_token_id,
547
  threshold=threshold,
548
  shift_logits=True,
549
- neg_entropy=True,
550
  )
551
-
552
  return out_ids, nfe
 
546
  mask_id=self.mask_token_id,
547
  threshold=threshold,
548
  shift_logits=True,
549
+ neg_entropy=False,
550
  )
551
+
552
  return out_ids, nfe