Update moondream.py
Browse files - moondream.py +5 -2
moondream.py
CHANGED
|
@@ -608,9 +608,12 @@ class MoondreamModel(nn.Module):
|
|
| 608 |
torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
|
| 609 |
)
|
| 610 |
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 611 |
-
|
| 612 |
-
|
|
|
|
| 613 |
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
|
|
|
|
|
|
| 614 |
|
| 615 |
return EncodedImage(
|
| 616 |
pos=inputs_embeds.size(1),
|
|
|
|
| 608 |
torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
|
| 609 |
)
|
| 610 |
inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
|
| 611 |
+
attn = self.attn_mask # (1,1,Tmax,Tmax)
|
| 612 |
+
mask = attn[:, :, pos:pos+T, :].expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
|
| 613 |
+
pos_ids = torch.arange(pos, pos+T, device=self.device, dtype=torch.long)
|
| 614 |
self._prefill(inputs_embeds, mask, pos_ids, lora)
|
| 615 |
+
|
| 616 |
+
|
| 617 |
|
| 618 |
return EncodedImage(
|
| 619 |
pos=inputs_embeds.size(1),
|