xiazhi commited on
Commit
5b4caff
·
verified ·
1 Parent(s): c1698d6

Upload folder using huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_diffusionvl_qwen2_5.py +6 -3
modeling_diffusionvl_qwen2_5.py CHANGED
@@ -1057,7 +1057,9 @@ class DiffusionVL_Qwen2_5_ForConditionalGeneration(DiffusionVL_Qwen2_5_PreTraine
1057
  cur_pos_ids = position_ids[:, block_start:block_end]
1058
 
1059
  for step in range(steps + 1):
1060
- is_mask = torch.all(torch.abs(cur_block_embeds - mask_embed) < 1e-5, dim=-1)
 
 
1061
  if not is_mask.any():
1062
  if use_kv_cache:
1063
  _ = self.model(
@@ -1113,8 +1115,9 @@ class DiffusionVL_Qwen2_5_ForConditionalGeneration(DiffusionVL_Qwen2_5_PreTraine
1113
  x0_embeds = self.get_input_embeddings()(x0).to(output_device)
1114
  cur_block_embeds = torch.where(transfer_mask.unsqueeze(-1), x0_embeds, cur_block_embeds)
1115
 
1116
- x_embeds[:, block_start:block_end] = cur_block_embeds
1117
- x_ids[:, block_start:block_end] = cur_block_ids
 
1118
 
1119
  # EOS check: stop generation if EOS token is generated
1120
  if block_end > prompt_len:
 
1057
  cur_pos_ids = position_ids[:, block_start:block_end]
1058
 
1059
  for step in range(steps + 1):
1060
+ # Ensure mask_embed is on same device as cur_block_embeds (for device_map="auto")
1061
+ mask_embed_local = mask_embed.to(cur_block_embeds.device)
1062
+ is_mask = torch.all(torch.abs(cur_block_embeds - mask_embed_local) < 1e-5, dim=-1)
1063
  if not is_mask.any():
1064
  if use_kv_cache:
1065
  _ = self.model(
 
1115
  x0_embeds = self.get_input_embeddings()(x0).to(output_device)
1116
  cur_block_embeds = torch.where(transfer_mask.unsqueeze(-1), x0_embeds, cur_block_embeds)
1117
 
1118
+ # Move back to original device before assignment (for device_map="auto")
1119
+ x_embeds[:, block_start:block_end] = cur_block_embeds.to(x_embeds.device)
1120
+ x_ids[:, block_start:block_end] = cur_block_ids.to(x_ids.device)
1121
 
1122
  # EOS check: stop generation if EOS token is generated
1123
  if block_end > prompt_len: