Update moondream.py
Browse files- moondream.py +2 -2
moondream.py
CHANGED
|
@@ -609,8 +609,8 @@ class MoondreamModel(nn.Module):
|
|
| 609 |
)
|
| 610 |
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 611 |
attn = self.attn_mask # (1,1,Tmax,Tmax)
|
| 612 |
-
mask =
|
| 613 |
-
pos_ids = torch.arange(
|
| 614 |
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
| 615 |
|
| 616 |
|
|
|
|
| 609 |
)
|
| 610 |
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 611 |
attn = self.attn_mask # (1,1,Tmax,Tmax)
|
| 612 |
+
mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
|
| 613 |
+
pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
|
| 614 |
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
| 615 |
|
| 616 |
|