Spaces:
Sleeping
Sleeping
Update inferencer.py
Browse files- inferencer.py +7 -0
inferencer.py
CHANGED
|
@@ -186,6 +186,13 @@ class UniPicV2Inferencer:
|
|
| 186 |
attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
|
| 187 |
|
| 188 |
# Get input embeddings
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 189 |
inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
|
| 190 |
|
| 191 |
# Ensure meta queries are on correct device
|
|
|
|
| 186 |
attention_mask = torch.cat([attention_mask, pad_mask], dim=1)
|
| 187 |
|
| 188 |
# Get input embeddings
|
| 189 |
+
|
| 190 |
+
# 获取 embedding 权重所在设备
|
| 191 |
+
embed_device = self.lmm.get_input_embeddings().weight.device
|
| 192 |
+
|
| 193 |
+
# 确保 input_ids 在同一设备
|
| 194 |
+
input_ids = input_ids.to(embed_device)
|
| 195 |
+
|
| 196 |
inputs_embeds = self.lmm.get_input_embeddings()(input_ids)
|
| 197 |
|
| 198 |
# Ensure meta queries are on correct device
|