HV-Khurdula committed on
Commit
0841a6c
·
verified ·
1 Parent(s): a9ee753

Update moondream.py

Browse files

fix: broken pos rows, prefill processing, and generation

Files changed (1) hide show
  1. moondream.py +84 -86
moondream.py CHANGED
@@ -79,11 +79,23 @@ class KVCache(nn.Module):
79
 
80
  def update(self, pos_ids, k, v):
81
  kout, vout = self.k_cache, self.v_cache
82
- kout[:, :, pos_ids, :] = k
83
- vout[:, :, pos_ids, :] = v
 
 
 
 
 
 
 
 
 
 
 
84
  return kout, vout
85
 
86
 
 
87
  class MoondreamModel(nn.Module):
88
 
89
  def __init__(
@@ -532,6 +544,11 @@ class MoondreamModel(nn.Module):
532
  return image
533
  elif not isinstance(image, Image.Image):
534
  raise ValueError("image must be a PIL Image or EncodedImage")
 
 
 
 
 
535
 
536
  lora = (
537
  variant_state_dict(settings["variant"], device=self.device)
@@ -867,14 +884,8 @@ class MoondreamModel(nn.Module):
867
  b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
868
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
869
 
870
- def _prefill_prompt_batched(
871
- self,
872
- labels,
873
- pos: int,
874
- lora=None,
875
- temperature: float = 0.0,
876
- top_p: float = 0.0,
877
- ):
878
  tpl = self.config.tokenizer.templates["detect"]
879
  if tpl is None:
880
  raise NotImplementedError("Model does not support object detection (no detect template).")
@@ -882,28 +893,27 @@ class MoondreamModel(nn.Module):
882
  rows, lens = [], []
883
  for lab in labels:
884
  ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
885
- rows.append(torch.tensor(ids, device=self.device, dtype=torch.long))
886
- lens.append(len(ids))
887
- B = len(rows)
888
- T = max(lens)
889
  eos = self.config.tokenizer.eos_id
890
 
891
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
892
  for i, ids in enumerate(rows):
893
  prompt_ids[i, : ids.numel()] = ids
894
 
895
- prompt_emb = text_encoder(prompt_ids, self.text) # (B, T, C)
896
  torch._dynamo.mark_dynamic(prompt_emb, 1)
897
 
898
- # 4-D mask is broadcastable to (B, n_heads, T, K)
899
- attn = self.attn_mask
900
- mask = attn[:, :, pos : pos + T, :].expand(B, -1, -1, -1).contiguous()
901
- pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)
902
 
903
- hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B, T, C)
904
- logits_BTV = lm_head(hidden_BTC, self.text) # (B, T, V)
 
905
 
906
- idx = (torch.tensor(lens, device=self.device, dtype=torch.long) - 1).clamp_min(0)
907
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
908
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
909
 
@@ -914,117 +924,105 @@ class MoondreamModel(nn.Module):
914
  probs = self._apply_top_p(probs, top_p)
915
  next_token = torch.multinomial(probs, num_samples=1) # (B,1)
916
 
917
- pos_end = pos + T
918
- return last_hidden, next_token, pos_end # (B,1,C), (B,1), int
919
-
920
-
921
 
922
- def _generate_points_batched(
923
- self,
924
- hidden, # (B,1,C)
925
- next_token, # (B,1)
926
- pos: int, # shared scalar next position
927
- include_size: bool = True,
928
- max_objects: int = 50,
929
- lora=None,
930
- ):
931
  B = hidden.size(0)
932
  device = self.device
933
  out = [[] for _ in range(B)]
934
  eos_id = self.config.tokenizer.eos_id
935
  max_ctx = self.config.text.max_context
936
 
937
- # 4-D mask: (B, 1, q_len=1, kv_len)
938
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
939
- if pos > 0:
940
- mask[:, :, :, :pos] = True
941
- pos_id = torch.tensor([pos], device=device, dtype=torch.long) # (1,)
942
 
943
- alive = torch.ones(B, dtype=torch.bool, device=device)
944
- counts = torch.zeros(B, dtype=torch.int32, device=device)
 
945
 
946
  with torch.inference_mode():
947
  while alive.any() and (counts < max_objects).any():
948
- # --- x coordinate ---
949
- x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
950
- if x_logits.dim() == 3:
951
- x_logits = x_logits.squeeze(1)
952
  x_bin = x_logits.argmax(dim=-1).to(torch.float32)
953
- x_center = x_bin / float(x_logits.size(-1)) # (B,)
954
  x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
955
- x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
956
 
957
- mask[:, :, :, pos_id[0].item()] = True
958
- logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
959
- pos_id += 1
 
 
960
 
961
- # --- y coordinate ---
962
  y_logits = decode_coordinate(hidden, self.region)
963
- if y_logits.dim() == 3:
964
- y_logits = y_logits.squeeze(1)
965
  y_bin = y_logits.argmax(dim=-1).to(torch.float32)
966
- y_center = y_bin / float(y_logits.size(-1)) # (B,)
967
  y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
968
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
969
 
970
- mask[:, :, :, pos_id[0].item()] = True
971
- logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
972
- pos_id += 1
 
973
 
974
  if include_size:
975
- size_logits = decode_size(hidden, self.region)
976
- # Support both tuple-of-tensors and flattened (2, -1) forms
977
- if isinstance(size_logits, (tuple, list)):
978
- w_logits = size_logits[0]
979
- h_logits = size_logits[1]
980
- if w_logits.dim() == 3: # (B,1,1024)
981
- w_logits = w_logits.squeeze(1)
982
- h_logits = h_logits.squeeze(1)
983
- else:
984
- # size_logits shape: (2, B * size_bins) — reshape it back.
985
- size_logits = size_logits.reshape(2, B, -1)
986
- w_logits, h_logits = size_logits[0], size_logits[1] # (B, size_bins)
987
 
988
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
989
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
990
- # inverse of log-scale mapping used by Moondream
991
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
992
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
993
-
994
- size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
995
  size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
996
 
997
  for i in range(B):
998
- if alive[i]:
999
- out[i].append({
1000
- "x_min": (x_center[i] - w[i] / 2).item(),
1001
- "y_min": (y_center[i] - h[i] / 2).item(),
1002
- "x_max": (x_center[i] + w[i] / 2).item(),
1003
- "y_max": (y_center[i] + h[i] / 2).item(),
1004
- })
1005
 
1006
- mask[:, :, :, pos_id[0].item()] = True
1007
- logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
1008
- pos_id += 1
 
1009
  next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1010
  else:
1011
  for i in range(B):
1012
  if alive[i]:
1013
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1014
 
1015
- mask[:, :, :, pos_id[0].item()] = True
1016
- logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
1017
- pos_id += 1
 
1018
  next_tok = logits.argmax(dim=-1).squeeze(-1)
1019
 
1020
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1021
- counts = counts + (~finished_now & alive).to(counts.dtype)
1022
  alive &= ~finished_now
1023
 
1024
  return out
1025
 
1026
 
1027
 
 
1028
  def detect_multi(self, image, objects, settings=None):
1029
  """
1030
  Parallel multi-label detection.
 
79
 
80
  def update(self, pos_ids, k, v):
81
  kout, vout = self.k_cache, self.v_cache
82
+ # pos_ids: scalar (int or 0-D) OR LongTensor[B]
83
+ if not torch.is_tensor(pos_ids) or pos_ids.ndim == 0:
84
+ # singleton batch
85
+ kout[:, :, pos_ids, :] = k
86
+ vout[:, :, pos_ids, :] = v
87
+ else:
88
+ # batched: write each row into its own position
89
+ B = k.size(0)
90
+ # Safe, explicit per-row scatter (B is usually small)
91
+ for i in range(B):
92
+ pi = int(pos_ids[i].item())
93
+ kout[i, :, pi, :] = k[i]
94
+ vout[i, :, pi, :] = v[i]
95
  return kout, vout
96
 
97
 
98
+
99
  class MoondreamModel(nn.Module):
100
 
101
  def __init__(
 
544
  return image
545
  elif not isinstance(image, Image.Image):
546
  raise ValueError("image must be a PIL Image or EncodedImage")
547
+
548
+ for blk in self.text.blocks:
549
+ if blk.kv_cache.k_cache.size(0) != 1:
550
+ blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
551
+ blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
552
 
553
  lora = (
554
  variant_state_dict(settings["variant"], device=self.device)
 
884
  b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
885
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
886
 
887
+ def _prefill_prompt_batched(self, labels, pos: int, lora=None,
888
+ temperature: float = 0.0, top_p: float = 0.0):
 
 
 
 
 
 
889
  tpl = self.config.tokenizer.templates["detect"]
890
  if tpl is None:
891
  raise NotImplementedError("Model does not support object detection (no detect template).")
 
893
  rows, lens = [], []
894
  for lab in labels:
895
  ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
896
+ t = torch.tensor(ids, device=self.device, dtype=torch.long)
897
+ rows.append(t); lens.append(t.numel())
898
+ B = len(rows); T = max(lens)
 
899
  eos = self.config.tokenizer.eos_id
900
 
901
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
902
  for i, ids in enumerate(rows):
903
  prompt_ids[i, : ids.numel()] = ids
904
 
905
+ prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
906
  torch._dynamo.mark_dynamic(prompt_emb, 1)
907
 
908
+ # 4-D mask: (B,1,T,kv_len) for SDPA
909
+ base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,kv_len)
910
+ mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,kv_len)
 
911
 
912
+ pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
913
+ hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
914
+ logits_BTV = lm_head(hidden_BTC, self.text)
915
 
916
+ idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
917
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
918
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
919
 
 
924
  probs = self._apply_top_p(probs, top_p)
925
  next_token = torch.multinomial(probs, num_samples=1) # (B,1)
926
 
927
+ # CRITICAL: per-row next position
928
+ pos_vec = torch.tensor(lens, device=self.device, dtype=torch.long) + pos # (B,)
929
+
930
+ return last_hidden, next_token, pos_vec
931
 
932
    def _generate_points_batched(self, hidden, next_token, pos_vec,
                                 include_size: bool = True, max_objects: int = 50, lora=None):
        """Autoregressively decode object points/boxes for a batch of labels.

        Args:
            hidden: (B, 1, C) last hidden state per row from the batched prefill.
            next_token: (B, 1) sampled token ids (currently unused in the loop body).
            pos_vec: (B,) LongTensor of each row's next free KV-cache position.
            include_size: if True, also decode width/height and emit boxes;
                otherwise emit center points only.
            max_objects: hard cap on objects decoded per row.
            lora: optional LoRA weights forwarded to the decode step.

        Returns:
            list of length B; element i is a list of dicts — box dicts with
            x_min/y_min/x_max/y_max when include_size, else {"x", "y"} points.
            All values are normalized to [0, 1) — assumes the coordinate head
            bins span the unit interval; TODO confirm against decode_coordinate.
        """
        B = hidden.size(0)
        device = self.device
        out = [[] for _ in range(B)]
        eos_id = self.config.tokenizer.eos_id
        max_ctx = self.config.text.max_context

        # 4-D attention mask (B, 1, 1, kv_len); each row may attend to its own
        # historical prefix only, since rows have different prompt lengths.
        mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
        for i in range(B):
            p = int(pos_vec[i].item())
            if p > 0: mask[i, 0, 0, :p] = True

        pos_ids = pos_vec.clone()  # (B,) per-row write positions, advanced as we go
        alive = torch.ones(B, dtype=torch.bool, device=device)   # rows still decoding
        counts = torch.zeros(B, dtype=torch.int32, device=device)  # objects emitted per row

        with torch.inference_mode():
            # NOTE(review): dead rows keep running through the decoder (their
            # outputs are just ignored); only `alive` rows advance pos/mask.
            while alive.any() and (counts < max_objects).any():
                # --- x coordinate ---
                x_logits = decode_coordinate(hidden, self.region)  # (B,1,1024) or (B,1024)
                if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
                x_bin = x_logits.argmax(dim=-1).to(torch.float32)
                x_center = x_bin / float(x_logits.size(-1))  # (B,) in [0, 1)
                x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)  # (B,1)
                x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)  # (B,1,C)

                # Open each alive row's current slot, decode, then advance one
                # position per *alive* row only.
                for i in range(B):
                    if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
                logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
                pos_ids = pos_ids + alive.to(torch.long)

                # --- y coordinate ---
                y_logits = decode_coordinate(hidden, self.region)
                if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
                y_bin = y_logits.argmax(dim=-1).to(torch.float32)
                y_center = y_bin / float(y_logits.size(-1))
                y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
                y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)

                for i in range(B):
                    if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
                logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
                pos_ids = pos_ids + alive.to(torch.long)

                if include_size:
                    # decode_size may return a tuple/list of (B,1,1024) or (B,1024)
                    # tensors — be robust to either rank.
                    size_logits = decode_size(hidden, self.region)
                    w_logits = size_logits[0].squeeze(1) if size_logits[0].dim() == 3 else size_logits[0]
                    h_logits = size_logits[1].squeeze(1) if size_logits[1].dim() == 3 else size_logits[1]

                    w_bin = w_logits.argmax(dim=-1).to(torch.float32)
                    h_bin = h_logits.argmax(dim=-1).to(torch.float32)
                    # Inverse of the log2 size mapping: bin 1023 -> 1.0, bin 0 -> 2^-10.
                    w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
                    h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
                    size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)  # (B,2)
                    size_emb = encode_size(size_in, self.region).unsqueeze(1)  # (B,1,C)

                    for i in range(B):
                        if not alive[i]: continue
                        out[i].append({
                            "x_min": (x_center[i] - w[i] / 2).item(),
                            "y_min": (y_center[i] - h[i] / 2).item(),
                            "x_max": (x_center[i] + w[i] / 2).item(),
                            "y_max": (y_center[i] + h[i] / 2).item(),
                        })

                    for i in range(B):
                        if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
                    logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
                    pos_ids = pos_ids + alive.to(torch.long)
                    next_tok = logits.argmax(dim=-1).squeeze(-1)  # (B,)
                else:
                    for i in range(B):
                        if alive[i]:
                            out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})

                    # NOTE(review): y_emb was already fed to _decode_one_tok above,
                    # so this path decodes the same embedding at two consecutive
                    # positions — looks like the `logits` from the earlier y_emb
                    # decode should be reused instead. Verify against the
                    # single-sample point-generation path before changing.
                    for i in range(B):
                        if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
                    logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
                    pos_ids = pos_ids + alive.to(torch.long)
                    next_tok = logits.argmax(dim=-1).squeeze(-1)

                # A row finishes on EOS or when it is about to hit the object cap;
                # count the object only for rows that were alive and not finishing.
                finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
                counts = counts + ((~finished_now) & alive).to(counts.dtype)
                alive &= ~finished_now

        return out
1022
 
1023
 
1024
 
1025
+
1026
  def detect_multi(self, image, objects, settings=None):
1027
  """
1028
  Parallel multi-label detection.