HV-Khurdula committed on
Commit
15e66ee
·
verified ·
1 Parent(s): 17e2272

Update moondream.py

Browse files

feat: make prefill granular.

Files changed (1) hide show
  1. moondream.py +28 -22
moondream.py CHANGED
@@ -943,13 +943,19 @@ class MoondreamModel(nn.Module):
943
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
944
 
945
 
946
- def _prefill_prompt_batched(self, labels, pos: int, lora=None,
947
- temperature: float = 0.0, top_p: float = 0.0):
 
 
 
 
 
 
948
  tpl = self.config.tokenizer.templates["detect"]
949
  if tpl is None:
950
  raise NotImplementedError("Model does not support object detection.")
951
 
952
- # Build each row's token ids (variable length)
953
  rows_ids, lens = [], []
954
  for lab in labels:
955
  ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
@@ -957,44 +963,44 @@ class MoondreamModel(nn.Module):
957
  rows_ids.append(t)
958
  lens.append(t.numel())
959
 
960
- B, T = len(rows_ids), max(lens)
 
961
 
962
- # Embed each row, then LEFT-pad using its own first token embedding (neutral),
963
- # mirroring upstream moondream2 batching strategy.
964
- embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids] # list of (Li, C)
965
  padded = []
966
  for e, L in zip(embs, lens):
967
  pad = T - L
968
  if pad > 0:
969
- e = torch.cat([e[:1].repeat(pad, 1), e], dim=0) # (T, C)
970
  padded.append(e)
971
- prompt_emb = torch.stack(padded, dim=0) # (B, T, C)
972
  torch._dynamo.mark_dynamic(prompt_emb, 1)
973
 
974
- # Standard prefill over the shared image prefix [pos : pos+T]
975
- base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,K)
976
- mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
977
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
 
 
978
 
979
- hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
980
- logits_BTV = lm_head(hidden_BTC, self.text) # (B,T,V)
981
-
982
- # Take the last real token of each row
983
- idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
984
- last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
985
- last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
986
 
987
  if temperature == 0.0:
988
- next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
989
  else:
990
  probs = torch.softmax(last_logits / temperature, dim=-1)
991
  probs = self._apply_top_p(probs, top_p)
992
- next_token = torch.multinomial(probs, num_samples=1) # (B,1)
993
 
994
- pos_end = int(pos + T)
995
  return last_hidden, next_token, pos_end
996
 
997
 
 
998
  def _generate_points_batched(
999
  self,
1000
  hidden, # (B,1,C) - last token hidden state per row
 
943
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
944
 
945
 
946
+ def _prefill_prompt_batched(
947
+ self,
948
+ labels: List[str],
949
+ pos: int,
950
+ lora=None,
951
+ temperature: float = 0.0,
952
+ top_p: float = 0.0,
953
+ ):
954
  tpl = self.config.tokenizer.templates["detect"]
955
  if tpl is None:
956
  raise NotImplementedError("Model does not support object detection.")
957
 
958
+ # 1) Build token ids for each label (variable length)
959
  rows_ids, lens = [], []
960
  for lab in labels:
961
  ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
 
963
  rows_ids.append(t)
964
  lens.append(t.numel())
965
 
966
+ B = len(rows_ids)
967
+ T = max(lens)
968
 
969
+ # 2) Embed then LEFT-pad each row to length T using the row’s first token embedding
970
+ embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids] # list[(Li, C)]
 
971
  padded = []
972
  for e, L in zip(embs, lens):
973
  pad = T - L
974
  if pad > 0:
975
+ e = torch.cat([e[:1].repeat(pad, 1), e], dim=0) # (T, C)
976
  padded.append(e)
977
+ prompt_emb = torch.stack(padded, dim=0) # (B, T, C)
978
  torch._dynamo.mark_dynamic(prompt_emb, 1)
979
 
980
+ # 3) Prefill over the shared image prefix [pos : pos + T)
981
+ base = self.attn_mask[:, :, pos:pos + T, :] # (1, 1, T, K)
982
+ mask = base.expand(B, -1, -1, -1).contiguous() # (B, 1, T, K)
983
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
984
+ hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B, T, C)
985
+ logits_BTV = lm_head(hidden_BTC, self.text) # (B, T, V)
986
 
987
+ # **FIX**: After left-padding, the last real token sits at T-1 for every row.
988
+ last_idx = torch.full((B,), T - 1, device=self.device, dtype=torch.long) # (B,)
989
+ last_hidden = hidden_BTC[torch.arange(B, device=self.device), last_idx][:, None, :] # (B, 1, C)
990
+ last_logits = logits_BTV[torch.arange(B, device=self.device), last_idx] # (B, V)
 
 
 
991
 
992
  if temperature == 0.0:
993
+ next_token = last_logits.argmax(dim=-1, keepdim=True) # (B, 1)
994
  else:
995
  probs = torch.softmax(last_logits / temperature, dim=-1)
996
  probs = self._apply_top_p(probs, top_p)
997
+ next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
998
 
999
+ pos_end = int(pos + T)
1000
  return last_hidden, next_token, pos_end
1001
 
1002
 
1003
+
1004
  def _generate_points_batched(
1005
  self,
1006
  hidden, # (B,1,C) - last token hidden state per row