HV-Khurdula commited on
Commit
18ff09d
·
verified ·
1 Parent(s): b5aefdb

Update moondream.py

Browse files
Files changed (1) hide show
  1. moondream.py +2 -2
moondream.py CHANGED
@@ -609,8 +609,8 @@ class MoondreamModel(nn.Module):
609
  )
610
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
611
  attn = self.attn_mask # (1,1,Tmax,Tmax)
612
- mask = attn[:, :, pos:pos+T, :].expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
613
- pos_ids = torch.arange(pos, pos+T, device=self.device, dtype=torch.long)
614
  self._prefill(inputs_embeds, mask, pos_ids, lora)
615
 
616
 
 
609
  )
610
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
611
  attn = self.attn_mask # (1,1,Tmax,Tmax)
612
+ mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
613
+ pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
614
  self._prefill(inputs_embeds, mask, pos_ids, lora)
615
 
616