fix bug for generation (empty response and PyTorch exception)

#8
by Natt1e - opened
Files changed (1) hide show
  1. modeling_stable_diffcoder.py +21 -9
modeling_stable_diffcoder.py CHANGED
@@ -137,8 +137,10 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
137
  prompt_length = input_ids.shape[1]
138
  gen_block_list = [block_length for _ in range(gen_blocks)]
139
 
140
- res_block = block_length - (prompt_length % block_length)
141
- if res_block > 0:
 
 
142
  gen_block_list = [res_block] + gen_block_list
143
  gen_block_list[-1] = block_length - res_block
144
  gen_blocks += 1
@@ -156,16 +158,20 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
156
  nfe = 0
157
  final_flag = False
158
  prefill_length = prompt_length // block_length * block_length
 
159
  if prefill_length > 0:
160
  cur_attn_mask = block_diffusion_attention_mask[
161
  ..., :prefill_length, :prefill_length
162
  ]
 
 
163
  self(
164
  x[:, :prefill_length],
165
  past_key_values=past_key_values,
166
  attention_mask=cur_attn_mask,
167
  use_cache=True,
168
- ).past_key_values
 
169
 
170
  for block_id, block_size in enumerate(gen_block_list):
171
  block_start = (
@@ -182,7 +188,7 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
182
  replace_position[:, block_start:block_end] = True
183
 
184
  for token_count in num_transfer_tokens:
185
- if token_count:
186
  nfe += 1
187
  mask_map = x[:, block_start:block_end] == mask_id
188
  attention_mask = block_diffusion_attention_mask[
@@ -205,22 +211,28 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
205
  remasking,
206
  mask_map,
207
  x[:, block_start:block_end],
208
- token_count if threshold is None else None,
209
  threshold,
210
- shift=False,
211
  )
212
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
213
 
214
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
 
 
 
 
215
  if (
216
  eos_id is not None
217
- and (x[:, block_start:block_end] == eos_id).sum() > 0
 
218
  ):
219
  final_flag = True
220
  x = x[:, :block_end]
221
- eos_pos = (x == eos_id).nonzero(as_tuple=True)[1][0].item()
222
  x[0, eos_pos:] = eos_id
223
  break
 
224
  nfe += 1
225
  self(
226
  x[:, block_start:block_end],
@@ -231,7 +243,7 @@ class StableDiffcoderForCausalLM(LlamaForCausalLM):
231
  use_cache=True,
232
  cache_position=replace_position.nonzero(as_tuple=True)[1],
233
  )
234
- break
235
 
236
  if final_flag:
237
  break
 
137
  prompt_length = input_ids.shape[1]
138
  gen_block_list = [block_length for _ in range(gen_blocks)]
139
 
140
+ # Fix 3: Only handle residual blocks if the prompt length is NOT cleanly divisible
141
+ remainder = prompt_length % block_length
142
+ if remainder != 0:
143
+ res_block = block_length - remainder
144
  gen_block_list = [res_block] + gen_block_list
145
  gen_block_list[-1] = block_length - res_block
146
  gen_blocks += 1
 
158
  nfe = 0
159
  final_flag = False
160
  prefill_length = prompt_length // block_length * block_length
161
+
162
  if prefill_length > 0:
163
  cur_attn_mask = block_diffusion_attention_mask[
164
  ..., :prefill_length, :prefill_length
165
  ]
166
+ # Fix 1: Explicitly pass cache_position for newer transformers prefill
167
+ cache_pos = torch.arange(prefill_length, device=x.device)
168
  self(
169
  x[:, :prefill_length],
170
  past_key_values=past_key_values,
171
  attention_mask=cur_attn_mask,
172
  use_cache=True,
173
+ cache_position=cache_pos,
174
+ )
175
 
176
  for block_id, block_size in enumerate(gen_block_list):
177
  block_start = (
 
188
  replace_position[:, block_start:block_end] = True
189
 
190
  for token_count in num_transfer_tokens:
191
+ if token_count > 0:
192
  nfe += 1
193
  mask_map = x[:, block_start:block_end] == mask_id
194
  attention_mask = block_diffusion_attention_mask[
 
211
  remasking,
212
  mask_map,
213
  x[:, block_start:block_end],
214
+ token_count.item() if threshold is None else None,
215
  threshold,
216
+ shift=shift,
217
  )
218
  x[:, block_start:block_end][transfer_map] = x0[transfer_map]
219
 
220
  if (x[:, block_start:block_end] == mask_id).sum() == 0:
221
+
222
+ # Fix 2: Calculate where the generated tokens ACTUALLY start in this block
223
+ gen_start = max(block_start, prompt_length)
224
+
225
  if (
226
  eos_id is not None
227
+ and gen_start < block_end
228
+ and (x[:, gen_start:block_end] == eos_id).sum() > 0
229
  ):
230
  final_flag = True
231
  x = x[:, :block_end]
232
+ eos_pos = (x[:, gen_start:block_end] == eos_id).nonzero(as_tuple=True)[1][0].item() + gen_start
233
  x[0, eos_pos:] = eos_id
234
  break
235
+
236
  nfe += 1
237
  self(
238
  x[:, block_start:block_end],
 
243
  use_cache=True,
244
  cache_position=replace_position.nonzero(as_tuple=True)[1],
245
  )
246
+ break
247
 
248
  if final_flag:
249
  break