Spaces:
Running
on
Zero
Running
on
Zero
Update pipeline_qwenimage_edit.py
Browse files- pipeline_qwenimage_edit.py +14 -16
pipeline_qwenimage_edit.py
CHANGED
|
@@ -276,37 +276,35 @@ class QwenImageEditPipeline(DiffusionPipeline, QwenImageLoraLoaderMixin):
|
|
| 276 |
)
|
| 277 |
texts.append(text)
|
| 278 |
|
| 279 |
-
# Process inputs
|
| 280 |
model_inputs = self.processor(
|
| 281 |
text=texts,
|
| 282 |
images=images,
|
| 283 |
do_resize=False, # already resized
|
| 284 |
padding=True,
|
| 285 |
return_tensors="pt"
|
| 286 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 287 |
|
| 288 |
# template = self.prompt_template_encode
|
| 289 |
drop_idx = self.prompt_template_encode_start_idx
|
| 290 |
-
# txt = [template.format(e) for e in prompt]
|
| 291 |
-
|
| 292 |
-
# model_inputs = self.processor(
|
| 293 |
-
# text=txt,
|
| 294 |
-
# images=image,
|
| 295 |
-
# padding=True,
|
| 296 |
-
# return_tensors="pt",
|
| 297 |
-
# ).to(device)
|
| 298 |
|
| 299 |
outputs = self.text_encoder(
|
| 300 |
-
input_ids=
|
| 301 |
-
attention_mask=
|
| 302 |
-
pixel_values=
|
| 303 |
-
image_grid_thw=
|
| 304 |
output_hidden_states=True,
|
| 305 |
)
|
| 306 |
-
# import pdb; pdb.set_trace()
|
| 307 |
|
| 308 |
hidden_states = outputs.hidden_states[-1]
|
| 309 |
-
split_hidden_states = self._extract_masked_hidden(hidden_states,
|
| 310 |
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
| 311 |
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
| 312 |
max_seq_len = max([e.size(0) for e in split_hidden_states])
|
|
|
|
| 276 |
)
|
| 277 |
texts.append(text)
|
| 278 |
|
| 279 |
+
# Process inputs
|
| 280 |
model_inputs = self.processor(
|
| 281 |
text=texts,
|
| 282 |
images=images,
|
| 283 |
do_resize=False, # already resized
|
| 284 |
padding=True,
|
| 285 |
return_tensors="pt"
|
| 286 |
+
)
|
| 287 |
+
|
| 288 |
+
# 修复:明确将每个张量移动到正确的设备
|
| 289 |
+
# 不依赖 .to(device) 的自动传播
|
| 290 |
+
input_ids = model_inputs.input_ids.to(device)
|
| 291 |
+
attention_mask = model_inputs.attention_mask.to(device)
|
| 292 |
+
pixel_values = model_inputs.pixel_values.to(device=device, dtype=dtype)
|
| 293 |
+
image_grid_thw = model_inputs.image_grid_thw.to(device)
|
| 294 |
|
| 295 |
# template = self.prompt_template_encode
|
| 296 |
drop_idx = self.prompt_template_encode_start_idx
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 297 |
|
| 298 |
outputs = self.text_encoder(
|
| 299 |
+
input_ids=input_ids,
|
| 300 |
+
attention_mask=attention_mask,
|
| 301 |
+
pixel_values=pixel_values,
|
| 302 |
+
image_grid_thw=image_grid_thw,
|
| 303 |
output_hidden_states=True,
|
| 304 |
)
|
|
|
|
| 305 |
|
| 306 |
hidden_states = outputs.hidden_states[-1]
|
| 307 |
+
split_hidden_states = self._extract_masked_hidden(hidden_states, attention_mask)
|
| 308 |
split_hidden_states = [e[drop_idx:] for e in split_hidden_states]
|
| 309 |
attn_mask_list = [torch.ones(e.size(0), dtype=torch.long, device=e.device) for e in split_hidden_states]
|
| 310 |
max_seq_len = max([e.size(0) for e in split_hidden_states])
|