HV-Khurdula committed on
Commit
b00c890
·
verified ·
1 Parent(s): cd2accb

Update moondream.py

Browse files

fix: detection corruption

Files changed (1) hide show
  1. moondream.py +53 -173
moondream.py CHANGED
@@ -76,23 +76,21 @@ class KVCache(nn.Module):
76
  Supports:
77
  • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
78
  • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
79
- • Legacy: k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar
80
  Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
81
  """
82
- kout, vout = self.k_cache, self.v_cache
83
-
84
- # Normalize pos_ids
85
  if not torch.is_tensor(pos_ids):
86
  pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
87
  else:
88
  pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
89
-
90
  if k.dim() != 4 or v.dim() != 4:
91
- raise RuntimeError(f"KV update expects 4D k,v. Got k={tuple(k.shape)} v={tuple(v.shape)}")
92
-
93
  B, Hkv, q_len, D = k.shape
94
-
95
- # Ensure cache batch matches B (expand-from-1 allowed)
 
96
  if kout.size(0) != B:
97
  if kout.size(0) == 1:
98
  self.k_cache = kout.expand(B, -1, -1, -1).clone()
@@ -100,36 +98,37 @@ class KVCache(nn.Module):
100
  kout, vout = self.k_cache, self.v_cache
101
  else:
102
  raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
103
-
104
- # Case A: PREFILL — vector of length q_len (same for all B rows)
105
  if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
106
  for i in range(B):
107
  kout[i, :, pos_ids, :] = k[i]
108
  vout[i, :, pos_ids, :] = v[i]
109
  return kout, vout
110
-
111
- # Case B: 1-STEP q_len == 1 with (B,) or (B,1) per-row positions
112
- if q_len == 1 and (pos_ids.numel() == B):
113
  pos_ids = pos_ids.view(B)
114
  for i in range(B):
115
  pi = int(pos_ids[i].item())
116
  kout[i, :, pi, :] = k[i, :, 0, :]
117
  vout[i, :, pi, :] = v[i, :, 0, :]
118
  return kout, vout
119
-
120
- # Case C: scalar + 1-step
121
- if pos_ids.dim() == 0 and q_len == 1:
122
  pi = int(pos_ids.item())
123
  kout[:, :, pi, :] = k[:, :, 0, :]
124
  vout[:, :, pi, :] = v[:, :, 0, :]
125
  return kout, vout
126
-
127
  raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
128
 
129
 
130
 
131
 
132
 
 
133
  class MoondreamModel(nn.Module):
134
 
135
  def __init__(
@@ -201,7 +200,6 @@ class MoondreamModel(nn.Module):
201
  if setup_caches:
202
  self._setup_caches()
203
 
204
-
205
  def _reset_kv_caches(self, batch_size: int = 1):
206
  c = self.config.text
207
  head_dim = c.dim // c.n_heads
@@ -569,13 +567,13 @@ class MoondreamModel(nn.Module):
569
  image: Union[Image.Image, EncodedImage],
570
  settings: Optional[ImageEncodingSettings] = None,
571
  ) -> EncodedImage:
572
- # Always start from single-row caches; avoids leftovers from batched runs
573
- self._setup_caches()
574
- for blk in self.text.blocks:
575
  if blk.kv_cache.k_cache.size(0) != 1:
576
  blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
577
  blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
578
-
579
  if isinstance(image, EncodedImage):
580
  return image
581
  if not isinstance(image, Image.Image):
@@ -908,6 +906,7 @@ class MoondreamModel(nn.Module):
908
 
909
 
910
 
 
911
  def _prefill_prompt_batched(self, labels, pos: int, lora=None,
912
  temperature: float = 0.0, top_p: float = 0.0):
913
  tpl = self.config.tokenizer.templates["detect"]
@@ -926,42 +925,39 @@ class MoondreamModel(nn.Module):
926
  for i, ids in enumerate(rows):
927
  prompt_ids[i, : ids.numel()] = ids
928
 
929
- prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
930
  torch._dynamo.mark_dynamic(prompt_emb, 1)
931
 
932
  base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,K)
933
- mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
934
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
935
 
936
- hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
937
- logits_BTV = lm_head(hidden_BTC, self.text) # (B,T,V)
938
 
939
- idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
940
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
941
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
942
 
943
  if temperature == 0.0:
944
- next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
945
  else:
946
  probs = torch.softmax(last_logits / temperature, dim=-1)
947
  probs = self._apply_top_p(probs, top_p)
948
- next_token = torch.multinomial(probs, num_samples=1) # (B,1)
949
 
950
- pos_end = int(pos + T) # shared next-slot
951
- return last_hidden, next_token, pos_end
952
-
953
-
954
 
955
 
956
  def _generate_points_batched(
957
  self,
958
  hidden, # (B,1,C)
959
- next_token, # (B,1) (kept for API compatibility)
960
- pos, # int or Tensor; normalized below
961
  include_size: bool = True,
962
  max_objects: int = 50,
963
  lora=None,
964
- use_soft_argmax: bool = True, # reduces bbox jitter
965
  ):
966
  B = hidden.size(0)
967
  device = self.device
@@ -969,156 +965,40 @@ class MoondreamModel(nn.Module):
969
  eos_id = self.config.tokenizer.eos_id
970
  max_ctx = self.config.text.max_context
971
 
972
- # Normalize pos to a scalar int
973
- if torch.is_tensor(pos):
974
- pos = int(pos.max().item())
975
-
976
- # 4-D mask: (B,1,1,K) and per-row pos ids (B,1)
977
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
978
  if pos > 0:
979
  mask[:, :, :, :pos] = True
980
  pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
981
 
982
- alive = torch.ones(B, dtype=torch.bool, device=device)
983
  counts = torch.zeros(B, dtype=torch.int32, device=device)
984
 
985
- def _argmax01(logits: torch.Tensor) -> torch.Tensor:
986
- """
987
- logits: (..., bins) -> (B,) in [0,1]
988
- Accepts (B,1,bins), (B,bins), or (bins,)
989
- """
990
- # Canonicalize to (B,bins)
991
- if logits.dim() == 3: # (B,1,bins)
992
- logits = logits.squeeze(1)
993
- elif logits.dim() == 1: # (bins,)
994
- logits = logits.unsqueeze(0)
995
- # If batch accidentally collapsed to 1, expand to B so downstream indexing is safe.
996
- if logits.size(0) == 1 and B > 1:
997
- logits = logits.expand(B, -1)
998
-
999
  if use_soft_argmax:
1000
- probs = torch.softmax(logits, dim=-1)
1001
- bins = torch.arange(probs.size(-1), device=probs.device, dtype=torch.float32)
1002
- expbin = (probs * bins).sum(dim=-1)
1003
- return expbin / float(probs.size(-1) - 1)
1004
- else:
1005
- idx = logits.argmax(dim=-1).to(torch.float32)
1006
- return idx / float(logits.size(-1) - 1)
1007
-
1008
- def _ensure_b(vec: torch.Tensor) -> torch.Tensor:
1009
- """
1010
- Make sure 1D tensors are length-B for safe indexing.
1011
- Accepts scalar/(), (1,), (B,), returns (B,)
1012
- """
1013
- if vec.dim() == 0:
1014
- return vec.repeat(B)
1015
- if vec.dim() == 1 and vec.numel() == 1 and B > 1:
1016
- return vec.repeat(B)
1017
- if vec.dim() == 1 and vec.numel() == B:
1018
- return vec
1019
- raise RuntimeError(f"Expected (B,) vec, got shape {tuple(vec.shape)} for B={B}")
1020
 
1021
  with torch.inference_mode():
1022
  while alive.any() and (counts < max_objects).any():
1023
- # ---- x ------------------------------------------------------
1024
- x_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b) or (b,)
1025
- x_center = _argmax01(x_logits) # (B,)
1026
- x_center = _ensure_b(x_center) # force len B
1027
- x_in = x_center.to(dtype=hidden.dtype).unsqueeze(-1) # (B,1)
1028
- x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
1029
-
1030
  mask[alive, :, :, pos] = True
1031
- logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1032
- pos_ids[alive, 0] += 1
1033
- pos += 1
1034
 
1035
- # ---- y ------------------------------------------------------
1036
  y_logits = decode_coordinate(hidden, self.region)
1037
- y_center = _argmax01(y_logits) # (B,)
1038
- y_center = _ensure_b(y_center)
1039
- y_in = y_center.to(dtype=hidden.dtype).unsqueeze(-1)
1040
- y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1041
-
1042
- mask[alive, :, :, pos] = True
1043
- logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1044
- pos_ids[alive, 0] += 1
1045
- pos += 1
1046
-
1047
- if include_size:
1048
- # ---- size ----------------------------------------------
1049
- size_logits = decode_size(hidden, self.region) # tuple: (w_logits, h_logits)
1050
- w_logits, h_logits = size_logits
1051
-
1052
- # Canonicalize to (B,bins); expand if batch collapsed
1053
- if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
1054
- if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
1055
- if w_logits.dim() == 1: w_logits = w_logits.unsqueeze(0)
1056
- if h_logits.dim() == 1: h_logits = h_logits.unsqueeze(0)
1057
- if w_logits.size(0) == 1 and B > 1: w_logits = w_logits.expand(B, -1)
1058
- if h_logits.size(0) == 1 and B > 1: h_logits = h_logits.expand(B, -1)
1059
-
1060
- if use_soft_argmax:
1061
- w_probs = torch.softmax(w_logits, dim=-1)
1062
- h_probs = torch.softmax(h_logits, dim=-1)
1063
- bins_w = torch.arange(w_probs.size(-1), device=device, dtype=torch.float32)
1064
- bins_h = torch.arange(h_probs.size(-1), device=device, dtype=torch.float32)
1065
- w_bin = (w_probs * bins_w).sum(dim=-1) # (B,)
1066
- h_bin = (h_probs * bins_h).sum(dim=-1) # (B,)
1067
- else:
1068
- w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1069
- h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1070
-
1071
- # bins -> size (inverse log scale), robust to bins != 1024
1072
- w_den = float(w_logits.size(-1) - 1)
1073
- h_den = float(h_logits.size(-1) - 1)
1074
- w = torch.pow(2.0, (w_bin / w_den) * 10.0 - 10.0)
1075
- h = torch.pow(2.0, (h_bin / h_den) * 10.0 - 10.0)
1076
-
1077
- # enforce (B,)
1078
- w = _ensure_b(w); h = _ensure_b(h)
1079
-
1080
- size_in = torch.stack([w, h], dim=1).to(dtype=hidden.dtype) # (B,2)
1081
- size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
1082
-
1083
- # record boxes only for alive rows
1084
- for i in range(B):
1085
- if not alive[i]:
1086
- continue
1087
- xl = (x_center[i] - w[i] / 2).item()
1088
- xr = (x_center[i] + w[i] / 2).item()
1089
- yt = (y_center[i] - h[i] / 2).item()
1090
- yb = (y_center[i] + h[i] / 2).item()
1091
- out[i].append({
1092
- "x_min": max(0.0, min(1.0, xl)),
1093
- "y_min": max(0.0, min(1.0, yt)),
1094
- "x_max": max(0.0, min(1.0, xr)),
1095
- "y_max": max(0.0, min(1.0, yb)),
1096
- })
1097
-
1098
- mask[alive, :, :, pos] = True
1099
- logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1100
- pos_ids[alive, 0] += 1
1101
- pos += 1
1102
- next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1103
- else:
1104
- for i in range(B):
1105
- if alive[i]:
1106
- out[i].append({"x": float(x_center[i]), "y": float(y_center[i])})
1107
- mask[alive, :, :, pos] = True
1108
- logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1109
- pos_ids[alive, 0] += 1
1110
- pos += 1
1111
- next_tok = logits.argmax(dim=-1).squeeze(-1)
1112
-
1113
- # stop only rows that hit eos (or reached max objects)
1114
- finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1115
- counts = counts + ((~finished_now) & alive).to(counts.dtype)
1116
- alive &= ~finished_now
1117
-
1118
- return out
1119
-
1120
-
1121
-
1122
 
1123
 
1124
 
 
76
  Supports:
77
  • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
78
  • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
79
+ • Legacy: k,v = (B, n_kv_heads, 1, d), pos_ids = scalar
80
  Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
81
  """
 
 
 
82
  if not torch.is_tensor(pos_ids):
83
  pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
84
  else:
85
  pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
86
+
87
  if k.dim() != 4 or v.dim() != 4:
88
+ raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
89
+
90
  B, Hkv, q_len, D = k.shape
91
+ kout, vout = self.k_cache, self.v_cache
92
+
93
+ # Expand caches from B=1 lazily if needed
94
  if kout.size(0) != B:
95
  if kout.size(0) == 1:
96
  self.k_cache = kout.expand(B, -1, -1, -1).clone()
 
98
  kout, vout = self.k_cache, self.v_cache
99
  else:
100
  raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
101
+
102
+ # Case A: prefill (same positions for every row)
103
  if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
104
  for i in range(B):
105
  kout[i, :, pos_ids, :] = k[i]
106
  vout[i, :, pos_ids, :] = v[i]
107
  return kout, vout
108
+
109
+ # Case B: single step with per-row position (B,) or (B,1)
110
+ if q_len == 1 and pos_ids.numel() == B:
111
  pos_ids = pos_ids.view(B)
112
  for i in range(B):
113
  pi = int(pos_ids[i].item())
114
  kout[i, :, pi, :] = k[i, :, 0, :]
115
  vout[i, :, pi, :] = v[i, :, 0, :]
116
  return kout, vout
117
+
118
+ # Case C: scalar position for everyone
119
+ if q_len == 1 and pos_ids.dim() == 0:
120
  pi = int(pos_ids.item())
121
  kout[:, :, pi, :] = k[:, :, 0, :]
122
  vout[:, :, pi, :] = v[:, :, 0, :]
123
  return kout, vout
124
+
125
  raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
126
 
127
 
128
 
129
 
130
 
131
+
132
  class MoondreamModel(nn.Module):
133
 
134
  def __init__(
 
200
  if setup_caches:
201
  self._setup_caches()
202
 
 
203
  def _reset_kv_caches(self, batch_size: int = 1):
204
  c = self.config.text
205
  head_dim = c.dim // c.n_heads
 
567
  image: Union[Image.Image, EncodedImage],
568
  settings: Optional[ImageEncodingSettings] = None,
569
  ) -> EncodedImage:
570
+ # Top of encode_image(), just after type checks:
571
+ self._setup_caches() # re-create caches
572
+ for blk in self.text.blocks: # force B=1 for encode
573
  if blk.kv_cache.k_cache.size(0) != 1:
574
  blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
575
  blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
576
+
577
  if isinstance(image, EncodedImage):
578
  return image
579
  if not isinstance(image, Image.Image):
 
906
 
907
 
908
 
909
+
910
  def _prefill_prompt_batched(self, labels, pos: int, lora=None,
911
  temperature: float = 0.0, top_p: float = 0.0):
912
  tpl = self.config.tokenizer.templates["detect"]
 
925
  for i, ids in enumerate(rows):
926
  prompt_ids[i, : ids.numel()] = ids
927
 
928
+ prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
929
  torch._dynamo.mark_dynamic(prompt_emb, 1)
930
 
931
  base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,K)
932
+ mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
933
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
934
 
935
+ hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
936
+ logits_BTV = lm_head(hidden_BTC, self.text)
937
 
938
+ idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
939
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
940
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
941
 
942
  if temperature == 0.0:
943
+ next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
944
  else:
945
  probs = torch.softmax(last_logits / temperature, dim=-1)
946
  probs = self._apply_top_p(probs, top_p)
947
+ next_token = torch.multinomial(probs, num_samples=1) # (B,1)
948
 
949
+ return last_hidden, next_token, int(pos + T)
 
 
 
950
 
951
 
952
  def _generate_points_batched(
953
  self,
954
  hidden, # (B,1,C)
955
+ next_token, # (B,1) (not used with greedy coords; kept for API)
956
+ pos, # int, next free KV slot
957
  include_size: bool = True,
958
  max_objects: int = 50,
959
  lora=None,
960
+ use_soft_argmax: bool = True,
961
  ):
962
  B = hidden.size(0)
963
  device = self.device
 
965
  eos_id = self.config.tokenizer.eos_id
966
  max_ctx = self.config.text.max_context
967
 
968
+ # mask & position ids
 
 
 
 
969
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
970
  if pos > 0:
971
  mask[:, :, :, :pos] = True
972
  pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
973
 
974
+ alive = torch.ones(B, dtype=torch.bool, device=device)
975
  counts = torch.zeros(B, dtype=torch.int32, device=device)
976
 
977
+ def _argmax01(logits_2d):
978
+ # logits_2d: (B, bins)
 
 
 
 
 
 
 
 
 
 
 
 
979
  if use_soft_argmax:
980
+ probs = torch.softmax(logits_2d, dim=-1)
981
+ bins = torch.arange(probs.size(-1), device=logits_2d.device, dtype=torch.float32)
982
+ val = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
983
+ return val # in [0,1]
984
+ idx = logits_2d.argmax(dim=-1).to(torch.float32)
985
+ return idx / float(logits_2d.size(-1) - 1)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
986
 
987
  with torch.inference_mode():
988
  while alive.any() and (counts < max_objects).any():
989
+ # x
990
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
991
+ if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
992
+ x_center = _argmax01(x_logits) # (B,)
993
+ x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1),
994
+ self.region).unsqueeze(1) # (B,1,C)
 
995
  mask[alive, :, :, pos] = True
996
+ _, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
997
+ pos_ids[alive, 0] += 1; pos += 1
 
998
 
999
+ # y
1000
  y_logits = decode_coordinate(hidden, self.region)
1001
+ if y_logits.dim
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1002
 
1003
 
1004