Spaces:
Sleeping
Sleeping
Update inferencer.py
Browse files- inferencer.py +7 -1
inferencer.py
CHANGED
|
@@ -212,7 +212,13 @@ class UniPicV2Inferencer:
|
|
| 212 |
# Forward through LMM
|
| 213 |
if hasattr(self.lmm.model, "rope_deltas"):
|
| 214 |
self.lmm.model.rope_deltas = None
|
| 215 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 216 |
outputs = self.lmm.model(
|
| 217 |
inputs_embeds=inputs_embeds.to(self.device),
|
| 218 |
attention_mask=attention_mask.to(self.device),
|
|
|
|
| 212 |
# Forward through LMM
|
| 213 |
if hasattr(self.lmm.model, "rope_deltas"):
|
| 214 |
self.lmm.model.rope_deltas = None
|
| 215 |
+
|
| 216 |
+
model_device = self.lmm.model.embed_tokens.weight.device
|
| 217 |
+
# 强制将所有 tensor 输入搬到这个设备
|
| 218 |
+
for k, v in inputs.items():
|
| 219 |
+
if isinstance(v, torch.Tensor):
|
| 220 |
+
inputs[k] = v.to(model_device)
|
| 221 |
+
|
| 222 |
outputs = self.lmm.model(
|
| 223 |
inputs_embeds=inputs_embeds.to(self.device),
|
| 224 |
attention_mask=attention_mask.to(self.device),
|