OrlandoHugBot committed on
Commit
87f5c9f
·
verified ·
1 Parent(s): 1e77965

Update pipeline_qwenimage_edit.py

Browse files
Files changed (1) hide show
  1. pipeline_qwenimage_edit.py +14 -16
pipeline_qwenimage_edit.py CHANGED
@@ -276,37 +276,35 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
276
  )
277
  texts.append(text)
278
 
279
- # Process inputs - 修复:使用 text_encoder 的实际设备
280
  model_inputs = self.processor(
281
  text=texts,
282
  images=images,
283
  do_resize=False, # already resized
284
  padding=True,
285
  return_tensors="pt"
286
- ).to(device)
 
 
 
 
 
 
 
287
 
288
  # template = self.prompt_template_encode
289
  drop_idx = self.prompt_template_encode_start_idx
290
- # txt = [template.format(e) for e in prompt]
291
-
292
- # model_inputs = self.processor(
293
- # text=txt,
294
- # images=image,
295
- # padding=True,
296
- # return_tensors="pt",
297
- # ).to(device)
298
 
299
  outputs = self.text_encoder(
300
- input_ids=model_inputs.input_ids,
301
- attention_mask=model_inputs.attention_mask,
302
- pixel_values=model_inputs.pixel_values,
303
- image_grid_thw=model_inputs.image_grid_thw,
304
  output_hidden_states=True,
305
  )
306
- # import pdb; pdb.set_trace()
307
 
308
  hidden_states = outputs.hidden_states[-1]
309
- split_hidden_states = self._extract_masked_hidden(hidden_states, model_inputs.attention_mask)
310
  split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
311
  attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
312
  max_seq_len = max([e.size(0) for e in split_hidden_states])
 
276
  )
277
  texts.append(text)
278
 
279
+ # Process inputs
280
  model_inputs = self.processor(
281
  text=texts,
282
  images=images,
283
  do_resize=False, # already resized
284
  padding=True,
285
  return_tensors="pt"
286
+ )
287
+
288
+ # Fix: explicitly move each tensor to the correct device
289
+ # Do not rely on .to(device) propagating automatically
290
+ input_ids = model_inputs.input_ids.to(device)
291
+ attention_mask = model_inputs.attention_mask.to(device)
292
+ pixel_values = model_inputs.pixel_values.to(device=device, dtype=dtype)
293
+ image_grid_thw = model_inputs.image_grid_thw.to(device)
294
 
295
  # template = self.prompt_template_encode
296
  drop_idx = self.prompt_template_encode_start_idx
 
 
 
 
 
 
 
 
297
 
298
  outputs = self.text_encoder(
299
+ input_ids=input_ids,
300
+ attention_mask=attention_mask,
301
+ pixel_values=pixel_values,
302
+ image_grid_thw=image_grid_thw,
303
  output_hidden_states=True,
304
  )
 
305
 
306
  hidden_states = outputs.hidden_states[-1]
307
+ split_hidden_states = self._extract_masked_hidden(hidden_states, attention_mask)
308
  split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
309
  attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
310
  max_seq_len = max([e.size(0) for e in split_hidden_states])