HV-Khurdula committed on
Commit
36f4434
·
verified ·
1 Parent(s): fc77069

Update moondream.py

Browse files

fix: tensor mismatch between kv cache and batched generation

Files changed (1) hide show
  1. moondream.py +79 -61
moondream.py CHANGED
@@ -77,50 +77,54 @@ class KVCache(nn.Module):
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
80
- # --- replace the whole method in KVCache ---
81
  def update(self, pos_ids, k, v):
82
  """
83
- Supports both:
84
  • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
85
- • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,)
86
- Writes into caches shaped (B, n_kv_heads, T_max, d).
 
87
  """
88
  kout, vout = self.k_cache, self.v_cache
89
 
90
  if not torch.is_tensor(pos_ids):
91
- # Scalar position for singleton batch (legacy)
92
  kout[:, :, pos_ids, :] = k
93
  vout[:, :, pos_ids, :] = v
94
  return kout, vout
95
 
96
- # Normalize dtype
97
  pos_ids = pos_ids.to(dtype=torch.long, device=k.device)
98
 
99
- # Shapes
100
  if k.dim() != 4 or v.dim() != 4:
101
  raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
102
  B, Hkv, q_len, D = k.shape
103
 
104
- # Ensure cache batch matches B
105
  if kout.size(0) != B:
106
- raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
 
 
 
 
 
107
 
108
- # Case A: PREFILL — per-row write of a whole span of positions
109
  if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
110
  for i in range(B):
111
- kout[i, :, pos_ids, :] = k[i] # (Hkv, q_len, D)
112
  vout[i, :, pos_ids, :] = v[i]
113
  return kout, vout
114
 
115
- # Case B: STEP DECODE — one new position per row (q_len must be 1)
116
- if pos_ids.dim() == 1 and pos_ids.numel() == B and q_len == 1:
 
117
  for i in range(B):
118
- pi = int(pos_ids[i].item())
119
  kout[i, :, pi, :] = k[i, :, 0, :]
120
  vout[i, :, pi, :] = v[i, :, 0, :]
121
  return kout, vout
122
 
123
- # Optional legacy: scalar pos for everyone
124
  if pos_ids.dim() == 0 and q_len == 1:
125
  pi = int(pos_ids.item())
126
  kout[:, :, pi, :] = k[:, :, 0, :]
@@ -135,6 +139,7 @@ class KVCache(nn.Module):
135
 
136
 
137
 
 
138
  class MoondreamModel(nn.Module):
139
 
140
  def __init__(
@@ -968,93 +973,105 @@ class MoondreamModel(nn.Module):
968
 
969
  return last_hidden, next_token, pos_vec
970
 
971
- def _generate_points_batched(self, hidden, next_token, pos_vec,
972
- include_size: bool = True, max_objects: int = 50, lora=None):
 
 
 
 
 
 
 
 
 
 
 
 
973
  B = hidden.size(0)
974
  device = self.device
975
  out = [[] for _ in range(B)]
976
  eos_id = self.config.tokenizer.eos_id
977
  max_ctx = self.config.text.max_context
978
 
979
- # 4-D mask: (B,1,1,kv_len) and fill historical prefix per row
980
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
981
- for i in range(B):
982
- p = int(pos_vec[i].item())
983
- if p > 0: mask[i, 0, 0, :p] = True
 
984
 
985
- pos_ids = pos_vec.clone() # (B,)
986
- alive = torch.ones(B, dtype=torch.bool, device=device)
987
- counts = torch.zeros(B, dtype=torch.int32, device=device)
988
 
989
  with torch.inference_mode():
990
  while alive.any() and (counts < max_objects).any():
991
- # --- x ---
992
  x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
993
- if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
994
- x_bin = x_logits.argmax(dim=-1).to(torch.float32)
995
- x_center = x_bin / float(x_logits.size(-1)) # (B,)
996
  x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
997
  x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
998
 
999
- # advance one position per *alive* row
1000
- for i in range(B):
1001
- if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
1002
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1003
- pos_ids = pos_ids + alive.to(torch.long)
 
1004
 
1005
- # --- y ---
1006
  y_logits = decode_coordinate(hidden, self.region)
1007
- if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
1008
- y_bin = y_logits.argmax(dim=-1).to(torch.float32)
1009
- y_center = y_bin / float(y_logits.size(-1))
1010
- y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
1011
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1012
 
1013
- for i in range(B):
1014
- if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
1015
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1016
- pos_ids = pos_ids + alive.to(torch.long)
 
1017
 
1018
  if include_size:
1019
- size_logits = decode_size(hidden, self.region) # Expect [(B,1,1024),(B,1,1024)] or (tuple)
1020
- # be robust to either rank
1021
- w_logits = size_logits[0].squeeze(1) if size_logits[0].dim() == 3 else size_logits[0]
1022
- h_logits = size_logits[1].squeeze(1) if size_logits[1].dim() == 3 else size_logits[1]
1023
-
1024
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1025
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1026
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1027
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
1028
- size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
1029
  size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
1030
 
 
1031
  for i in range(B):
1032
- if not alive[i]: continue
1033
- out[i].append({
1034
- "x_min": (x_center[i] - w[i] / 2).item(),
1035
- "y_min": (y_center[i] - h[i] / 2).item(),
1036
- "x_max": (x_center[i] + w[i] / 2).item(),
1037
- "y_max": (y_center[i] + h[i] / 2).item(),
1038
- })
1039
 
1040
- for i in range(B):
1041
- if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
1042
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1043
- pos_ids = pos_ids + alive.to(torch.long)
 
1044
  next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1045
  else:
1046
  for i in range(B):
1047
  if alive[i]:
1048
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1049
-
1050
- for i in range(B):
1051
- if alive[i]: mask[i, 0, 0, int(pos_ids[i].item())] = True
1052
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1053
- pos_ids = pos_ids + alive.to(torch.long)
 
1054
  next_tok = logits.argmax(dim=-1).squeeze(-1)
1055
 
1056
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1057
- counts = counts + ((~finished_now) & alive).to(counts.dtype)
1058
  alive &= ~finished_now
1059
 
1060
  return out
@@ -1062,6 +1079,7 @@ class MoondreamModel(nn.Module):
1062
 
1063
 
1064
 
 
1065
  def detect_multi(self, image, objects, settings=None):
1066
  """
1067
  Parallel multi-label detection.
 
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
 
80
  def update(self, pos_ids, k, v):
81
  """
82
+ Supports:
83
  • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
84
+ • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,) or (B,1)
85
+ Legacy: k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar int
86
+ Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
87
  """
88
  kout, vout = self.k_cache, self.v_cache
89
 
90
  if not torch.is_tensor(pos_ids):
91
+ # Scalar legacy path
92
  kout[:, :, pos_ids, :] = k
93
  vout[:, :, pos_ids, :] = v
94
  return kout, vout
95
 
 
96
  pos_ids = pos_ids.to(dtype=torch.long, device=k.device)
97
 
 
98
  if k.dim() != 4 or v.dim() != 4:
99
  raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
100
  B, Hkv, q_len, D = k.shape
101
 
102
+ # Make sure cache batch matches B (expand-from-1 is ok, otherwise error)
103
  if kout.size(0) != B:
104
+ if kout.size(0) == 1:
105
+ self.k_cache = kout.expand(B, -1, -1, -1).clone()
106
+ self.v_cache = vout.expand(B, -1, -1, -1).clone()
107
+ kout, vout = self.k_cache, self.v_cache
108
+ else:
109
+ raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
110
 
111
+ # Case A: PREFILL — pos_ids indexes a contiguous range per row
112
  if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
113
  for i in range(B):
114
+ kout[i, :, pos_ids, :] = k[i] # (Hkv, q_len, D)
115
  vout[i, :, pos_ids, :] = v[i]
116
  return kout, vout
117
 
118
+ # Case B: STEP — q_len == 1 and one position per row
119
+ if q_len == 1 and pos_ids.numel() == B:
120
+ pos_ids_flat = pos_ids.view(-1) # handle (B,1) or (B,)
121
  for i in range(B):
122
+ pi = int(pos_ids_flat[i].item())
123
  kout[i, :, pi, :] = k[i, :, 0, :]
124
  vout[i, :, pi, :] = v[i, :, 0, :]
125
  return kout, vout
126
 
127
+ # Case C: scalar for everyone
128
  if pos_ids.dim() == 0 and q_len == 1:
129
  pi = int(pos_ids.item())
130
  kout[:, :, pi, :] = k[:, :, 0, :]
 
139
 
140
 
141
 
142
+
143
  class MoondreamModel(nn.Module):
144
 
145
  def __init__(
 
973
 
974
  return last_hidden, next_token, pos_vec
975
 
976
+ # In class MoondreamModel, replace the whole method:
977
+ def _generate_points_batched(
978
+ self,
979
+ hidden, # (B,1,C)
980
+ next_token, # (B,1) (not used when temperature=0, but ok)
981
+ pos: int, # shared scalar next position
982
+ include_size: bool = True,
983
+ max_objects: int = 50,
984
+ lora=None,
985
+ ):
986
+ """
987
+ Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
988
+ for all rows in the batch simultaneously. Returns list-of-lists of dicts, len B.
989
+ """
990
  B = hidden.size(0)
991
  device = self.device
992
  out = [[] for _ in range(B)]
993
  eos_id = self.config.tokenizer.eos_id
994
  max_ctx = self.config.text.max_context
995
 
996
+ # 4-D mask: (B, 1, q_len=1, kv_len)
997
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
998
+ if pos > 0:
999
+ mask[:, :, :, :pos] = True
1000
+ # IMPORTANT: position_ids must be (B, 1) for rotary; KVCache.update accepts (B,1) too
1001
+ pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
1002
 
1003
+ alive = torch.ones(B, dtype=torch.bool, device=device)
1004
+ counts = torch.zeros(B, dtype=torch.int32, device=device)
 
1005
 
1006
  with torch.inference_mode():
1007
  while alive.any() and (counts < max_objects).any():
1008
+ # --- x coordinate ---
1009
  x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
1010
+ if x_logits.dim() == 3:
1011
+ x_logits = x_logits.squeeze(1)
1012
+ x_center = x_logits.argmax(dim=-1).to(torch.float32) / float(x_logits.size(-1)) # (B,)
1013
  x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
1014
  x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
1015
 
1016
+ # advance attention one step
1017
+ mask[:, :, :, pos] = True
 
1018
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1019
+ pos += 1
1020
+ pos_ids[:, 0] = pos
1021
 
1022
+ # --- y coordinate ---
1023
  y_logits = decode_coordinate(hidden, self.region)
1024
+ if y_logits.dim() == 3:
1025
+ y_logits = y_logits.squeeze(1)
1026
+ y_center = y_logits.argmax(dim=-1).to(torch.float32) / float(y_logits.size(-1))
1027
+ y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
1028
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1029
 
1030
+ mask[:, :, :, pos] = True
 
1031
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1032
+ pos += 1
1033
+ pos_ids[:, 0] = pos
1034
 
1035
  if include_size:
1036
+ # --- size ---
1037
+ size_logits = decode_size(hidden, self.region) # tuple/list [w_logits, h_logits]
1038
+ # Support both (B,1,1024) and (B,1024)
1039
+ w_logits = size_logits[0].squeeze(1)
1040
+ h_logits = size_logits[1].squeeze(1)
1041
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1042
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1043
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1044
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
1045
+ size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
1046
  size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
1047
 
1048
+ # record boxes
1049
  for i in range(B):
1050
+ if alive[i]:
1051
+ out[i].append({
1052
+ "x_min": (x_center[i] - w[i] / 2).item(),
1053
+ "y_min": (y_center[i] - h[i] / 2).item(),
1054
+ "x_max": (x_center[i] + w[i] / 2).item(),
1055
+ "y_max": (y_center[i] + h[i] / 2).item(),
1056
+ })
1057
 
1058
+ mask[:, :, :, pos] = True
 
1059
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1060
+ pos += 1
1061
+ pos_ids[:, 0] = pos
1062
  next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1063
  else:
1064
  for i in range(B):
1065
  if alive[i]:
1066
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1067
+ mask[:, :, :, pos] = True
 
 
1068
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1069
+ pos += 1
1070
+ pos_ids[:, 0] = pos
1071
  next_tok = logits.argmax(dim=-1).squeeze(-1)
1072
 
1073
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1074
+ counts = counts + (~finished_now & alive).to(counts.dtype)
1075
  alive &= ~finished_now
1076
 
1077
  return out
 
1079
 
1080
 
1081
 
1082
+
1083
  def detect_multi(self, image, objects, settings=None):
1084
  """
1085
  Parallel multi-label detection.