HV-Khurdula commited on
Commit
b5aefdb
·
verified ·
1 Parent(s): 01b09b7

Update moondream.py

Browse files
Files changed (1) hide show
  1. moondream.py +5 -2
moondream.py CHANGED
@@ -608,9 +608,12 @@ class MoondreamModel(nn.Module):
608
  torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
609
  )
610
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
611
- mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
612
- pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
 
613
  self._prefill(inputs_embeds, mask, pos_ids, lora)
 
 
614
 
615
  return EncodedImage(
616
  pos=inputs_embeds.size(1),
 
608
  torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
609
  )
610
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
611
+ attn = self.attn_mask # (1,1,Tmax,Tmax)
612
+ mask = attn[:, :, pos:pos+T, :].expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
613
+ pos_ids = torch.arange(pos, pos+T, device=self.device, dtype=torch.long)
614
  self._prefill(inputs_embeds, mask, pos_ids, lora)
615
+
616
+
617
 
618
  return EncodedImage(
619
  pos=inputs_embeds.size(1),