HV-Khurdula committed on
Commit
0949437
·
verified ·
1 Parent(s): 18ff09d

Update moondream.py

Browse files
Files changed (1) hide show
  1. moondream.py +44 -61
moondream.py CHANGED
@@ -77,18 +77,16 @@ class KVCache(nn.Module):
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
80
- # In class KVCache, REPLACE the whole update() with this:
81
  def update(self, pos_ids, k, v):
82
  """
83
  Supports:
84
  • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
85
  • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
86
- Legacy: k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar
87
  Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
88
  """
89
  kout, vout = self.k_cache, self.v_cache
90
 
91
- # Normalize pos_ids
92
  if not torch.is_tensor(pos_ids):
93
  pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
94
  else:
@@ -98,7 +96,7 @@ class KVCache(nn.Module):
98
  raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
99
  B, Hkv, q_len, D = k.shape
100
 
101
- # Ensure cache batch matches B (expand-from-1 allowed)
102
  if kout.size(0) != B:
103
  if kout.size(0) == 1:
104
  self.k_cache = kout.expand(B, -1, -1, -1).clone()
@@ -107,14 +105,14 @@ class KVCache(nn.Module):
107
  else:
108
  raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
109
 
110
- # Case A: PREFILL vector of length q_len (same for all B rows)
111
  if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
112
  for i in range(B):
113
- kout[i, :, pos_ids, :] = k[i] # (Hkv, q_len, D)
114
  vout[i, :, pos_ids, :] = v[i]
115
  return kout, vout
116
 
117
- # Case B: 1-STEP q_len == 1 with (B,) or (B,1) per-row positions
118
  if q_len == 1 and pos_ids.numel() == B:
119
  pos_ids = pos_ids.view(B)
120
  for i in range(B):
@@ -123,16 +121,15 @@ class KVCache(nn.Module):
123
  vout[i, :, pi, :] = v[i, :, 0, :]
124
  return kout, vout
125
 
126
- # Case C: scalar for everyone & q_len == 1
127
  if pos_ids.dim() == 0 and q_len == 1:
128
  pi = int(pos_ids.item())
129
  kout[:, :, pi, :] = k[:, :, 0, :]
130
  vout[:, :, pi, :] = v[:, :, 0, :]
131
  return kout, vout
132
 
133
- raise RuntimeError(
134
- f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}"
135
- )
136
 
137
 
138
 
@@ -213,10 +210,6 @@ class MoondreamModel(nn.Module):
213
 
214
 
215
  def _reset_kv_caches(self, batch_size: int = 1):
216
- """
217
- Recreate KV caches with the requested batch size so subsequent calls
218
- (e.g., encode_image) start from a consistent shape.
219
- """
220
  c = self.config.text
221
  head_dim = c.dim // c.n_heads
222
  for blk in self.text.blocks:
@@ -225,6 +218,7 @@ class MoondreamModel(nn.Module):
225
  shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
226
  blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
227
  blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
 
228
 
229
 
230
  def _setup_caches(self):
@@ -589,12 +583,13 @@ class MoondreamModel(nn.Module):
589
  elif not isinstance(image, Image.Image):
590
  raise ValueError("image must be a PIL Image or EncodedImage")
591
 
592
- # At the VERY TOP of encode_image(), right after the type checks:
593
  for blk in self.text.blocks:
594
  if blk.kv_cache.k_cache.size(0) != 1:
595
  blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
596
  blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
597
 
 
598
 
599
  lora = (
600
  variant_state_dict(settings["variant"], device=self.device)
@@ -933,6 +928,7 @@ class MoondreamModel(nn.Module):
933
  b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
934
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
935
 
 
936
  def _prefill_prompt_batched(self, labels, pos: int, lora=None,
937
  temperature: float = 0.0, top_p: float = 0.0):
938
  tpl = self.config.tokenizer.templates["detect"]
@@ -949,12 +945,11 @@ class MoondreamModel(nn.Module):
949
 
950
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
951
  for i, ids in enumerate(rows):
952
- prompt_ids[i, : ids.numel()] = ids
953
 
954
  prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
955
  torch._dynamo.mark_dynamic(prompt_emb, 1)
956
 
957
- # 4-D mask: (B,1,T,kv_len) for SDPA
958
  base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,kv_len)
959
  mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,kv_len)
960
 
@@ -967,29 +962,26 @@ class MoondreamModel(nn.Module):
967
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
968
 
969
  if temperature == 0.0:
970
- next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
971
  else:
972
  probs = torch.softmax(last_logits / temperature, dim=-1)
973
  probs = self._apply_top_p(probs, top_p)
974
- next_token = torch.multinomial(probs, num_samples=1) # (B,1)
975
 
976
- # CRITICAL: per-row next position
977
- pos_vec = torch.tensor(lens, device=self.device, dtype=torch.long) + pos # (B,)
978
-
979
- # At the end of _prefill_prompt_batched(), return a Python int:
980
- pos_end = int((pos + T))
981
  return last_hidden, next_token, pos_end
982
 
983
 
 
984
  def _generate_points_batched(
985
  self,
986
  hidden, # (B,1,C)
987
- next_token, # (B,1) (unused in greedy, but OK)
988
  pos, # int or Tensor; normalized below
989
  include_size: bool = True,
990
  max_objects: int = 50,
991
  lora=None,
992
- use_soft_argmax: bool = True, # NEW: reduces jitter/hallucinations
993
  ):
994
  B = hidden.size(0)
995
  device = self.device
@@ -997,55 +989,48 @@ class MoondreamModel(nn.Module):
997
  eos_id = self.config.tokenizer.eos_id
998
  max_ctx = self.config.text.max_context
999
 
1000
- # Normalize pos to a scalar int (supports int, (1,), (B,), (B,1))
1001
  if torch.is_tensor(pos):
1002
- pos = int(pos.max().item()) # safe upper bound; we manage per-row with pos_ids/alive
1003
 
1004
- # 4-D mask: (B, 1, q_len=1, kv_len)
1005
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
1006
  if pos > 0:
1007
  mask[:, :, :, :pos] = True
1008
-
1009
- # position_ids must be (B,1)
1010
  pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
1011
 
1012
  alive = torch.ones(B, dtype=torch.bool, device=device)
1013
  counts = torch.zeros(B, dtype=torch.int32, device=device)
1014
 
1015
- # helpers ---------------------------------------------------------
1016
  def _argmax01(logits):
1017
- # logits: (B, bins)
1018
  if use_soft_argmax:
1019
  probs = torch.softmax(logits, dim=-1)
1020
  bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
1021
  idx = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
1022
- return idx # 0..1
1023
  else:
1024
  idx = logits.argmax(dim=-1).to(torch.float32)
1025
  return idx / float(logits.size(-1) - 1)
1026
 
1027
  with torch.inference_mode():
1028
  while alive.any() and (counts < max_objects).any():
1029
- # --- x ---------------------------------------------------
1030
- x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
1031
- if x_logits.dim() == 3:
1032
- x_logits = x_logits.squeeze(1)
1033
- x_center = _argmax01(x_logits) # (B,) in [0,1]
1034
- x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
1035
- x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
1036
 
1037
- # advance attention one step FOR ALIVE ROWS ONLY
1038
  mask[alive, :, :, pos] = True
1039
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1040
  pos_ids[alive, 0] += 1
1041
- pos += 1 # scalar next free slot
1042
 
1043
- # --- y ---------------------------------------------------
1044
  y_logits = decode_coordinate(hidden, self.region)
1045
- if y_logits.dim() == 3:
1046
- y_logits = y_logits.squeeze(1)
1047
- y_center = _argmax01(y_logits) # (B,)
1048
- y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
1049
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1050
 
1051
  mask[alive, :, :, pos] = True
@@ -1054,32 +1039,31 @@ class MoondreamModel(nn.Module):
1054
  pos += 1
1055
 
1056
  if include_size:
1057
- # --- size --------------------------------------------
1058
  size_logits = decode_size(hidden, self.region)
1059
  w_logits = size_logits[0].squeeze(1)
1060
  h_logits = size_logits[1].squeeze(1)
1061
  if use_soft_argmax:
1062
- # convert expected-bin -> size (same mapping as paper/code)
1063
- w_bin = (torch.softmax(w_logits, dim=-1) * torch.arange(w_logits.size(-1), device=device)).sum(dim=-1)
1064
- h_bin = (torch.softmax(h_logits, dim=-1) * torch.arange(h_logits.size(-1), device=device)).sum(dim=-1)
 
1065
  else:
1066
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1067
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
 
1068
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1069
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
1070
 
1071
- size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
1072
- size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
1073
 
1074
- # record boxes only for alive rows
1075
  for i in range(B):
1076
- if not alive[i]:
1077
- continue
1078
  xl = (x_center[i] - w[i] / 2).item()
1079
  xr = (x_center[i] + w[i] / 2).item()
1080
  yt = (y_center[i] - h[i] / 2).item()
1081
  yb = (y_center[i] + h[i] / 2).item()
1082
- # clamp for safety
1083
  out[i].append({
1084
  "x_min": max(0.0, min(1.0, xl)),
1085
  "y_min": max(0.0, min(1.0, yt)),
@@ -1091,7 +1075,7 @@ class MoondreamModel(nn.Module):
1091
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1092
  pos_ids[alive, 0] += 1
1093
  pos += 1
1094
- next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1095
  else:
1096
  for i in range(B):
1097
  if alive[i]:
@@ -1102,7 +1086,6 @@ class MoondreamModel(nn.Module):
1102
  pos += 1
1103
  next_tok = logits.argmax(dim=-1).squeeze(-1)
1104
 
1105
- # stop only rows that hit eos (or reached max objects)
1106
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1107
  counts = counts + ((~finished_now) & alive).to(counts.dtype)
1108
  alive &= ~finished_now
 
77
  "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
  )
79
 
 
80
  def update(self, pos_ids, k, v):
81
  """
82
  Supports:
83
  • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
84
  • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
85
+ Scalar: k,v = (B, n_kv_heads, 1, d), pos_ids = ()
86
  Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
87
  """
88
  kout, vout = self.k_cache, self.v_cache
89
 
 
90
  if not torch.is_tensor(pos_ids):
91
  pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
92
  else:
 
96
  raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
97
  B, Hkv, q_len, D = k.shape
98
 
99
+ # Expand cache to batch B if needed (expand-from-1 allowed)
100
  if kout.size(0) != B:
101
  if kout.size(0) == 1:
102
  self.k_cache = kout.expand(B, -1, -1, -1).clone()
 
105
  else:
106
  raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
107
 
108
+ # A) Prefill: pos_ids = (q_len,)
109
  if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
110
  for i in range(B):
111
+ kout[i, :, pos_ids, :] = k[i]
112
  vout[i, :, pos_ids, :] = v[i]
113
  return kout, vout
114
 
115
+ # B) One-step: q_len == 1 and pos_ids per row: (B,) or (B,1)
116
  if q_len == 1 and pos_ids.numel() == B:
117
  pos_ids = pos_ids.view(B)
118
  for i in range(B):
 
121
  vout[i, :, pi, :] = v[i, :, 0, :]
122
  return kout, vout
123
 
124
+ # C) Scalar position for everyone & q_len == 1
125
  if pos_ids.dim() == 0 and q_len == 1:
126
  pi = int(pos_ids.item())
127
  kout[:, :, pi, :] = k[:, :, 0, :]
128
  vout[:, :, pi, :] = v[:, :, 0, :]
129
  return kout, vout
130
 
131
+ raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
132
+
 
133
 
134
 
135
 
 
210
 
211
 
212
  def _reset_kv_caches(self, batch_size: int = 1):
 
 
 
 
213
  c = self.config.text
214
  head_dim = c.dim // c.n_heads
215
  for blk in self.text.blocks:
 
218
  shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
219
  blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
220
  blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
221
+
222
 
223
 
224
  def _setup_caches(self):
 
583
  elif not isinstance(image, Image.Image):
584
  raise ValueError("image must be a PIL Image or EncodedImage")
585
 
586
+ # Always start from single-row caches to avoid leftovers from batched runs
587
  for blk in self.text.blocks:
588
  if blk.kv_cache.k_cache.size(0) != 1:
589
  blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
590
  blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
591
 
592
+
593
 
594
  lora = (
595
  variant_state_dict(settings["variant"], device=self.device)
 
928
  b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
929
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
930
 
931
+
932
  def _prefill_prompt_batched(self, labels, pos: int, lora=None,
933
  temperature: float = 0.0, top_p: float = 0.0):
934
  tpl = self.config.tokenizer.templates["detect"]
 
945
 
946
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
947
  for i, ids in enumerate(rows):
948
+ prompt_ids[i, :ids.numel()] = ids
949
 
950
  prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
951
  torch._dynamo.mark_dynamic(prompt_emb, 1)
952
 
 
953
  base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,kv_len)
954
  mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,kv_len)
955
 
 
962
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
963
 
964
  if temperature == 0.0:
965
+ next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
966
  else:
967
  probs = torch.softmax(last_logits / temperature, dim=-1)
968
  probs = self._apply_top_p(probs, top_p)
969
+ next_token = torch.multinomial(probs, num_samples=1) # (B,1)
970
 
971
+ pos_end = int(pos + T) # shared scalar end position
 
 
 
 
972
  return last_hidden, next_token, pos_end
973
 
974
 
975
+
976
  def _generate_points_batched(
977
  self,
978
  hidden, # (B,1,C)
979
+ next_token, # (B,1)
980
  pos, # int or Tensor; normalized below
981
  include_size: bool = True,
982
  max_objects: int = 50,
983
  lora=None,
984
+ use_soft_argmax: bool = True, # reduces bbox jitter
985
  ):
986
  B = hidden.size(0)
987
  device = self.device
 
989
  eos_id = self.config.tokenizer.eos_id
990
  max_ctx = self.config.text.max_context
991
 
 
992
  if torch.is_tensor(pos):
993
+ pos = int(pos.max().item())
994
 
995
+ # SDPA mask and position ids
996
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
997
  if pos > 0:
998
  mask[:, :, :, :pos] = True
 
 
999
  pos_ids = torch.full((B, 1), pos, device=device, dtype=torch.long)
1000
 
1001
  alive = torch.ones(B, dtype=torch.bool, device=device)
1002
  counts = torch.zeros(B, dtype=torch.int32, device=device)
1003
 
 
1004
  def _argmax01(logits):
1005
+ # logits: (B, bins) -> normalized index in [0,1]
1006
  if use_soft_argmax:
1007
  probs = torch.softmax(logits, dim=-1)
1008
  bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
1009
  idx = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
1010
+ return idx
1011
  else:
1012
  idx = logits.argmax(dim=-1).to(torch.float32)
1013
  return idx / float(logits.size(-1) - 1)
1014
 
1015
  with torch.inference_mode():
1016
  while alive.any() and (counts < max_objects).any():
1017
+ # x
1018
+ x_logits = decode_coordinate(hidden, self.region)
1019
+ if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
1020
+ x_center = _argmax01(x_logits)
1021
+ x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)
1022
+ x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)
 
1023
 
 
1024
  mask[alive, :, :, pos] = True
1025
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1026
  pos_ids[alive, 0] += 1
1027
+ pos += 1
1028
 
1029
+ # y
1030
  y_logits = decode_coordinate(hidden, self.region)
1031
+ if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
1032
+ y_center = _argmax01(y_logits)
1033
+ y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
 
1034
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1035
 
1036
  mask[alive, :, :, pos] = True
 
1039
  pos += 1
1040
 
1041
  if include_size:
1042
+ # size
1043
  size_logits = decode_size(hidden, self.region)
1044
  w_logits = size_logits[0].squeeze(1)
1045
  h_logits = size_logits[1].squeeze(1)
1046
  if use_soft_argmax:
1047
+ w_bin = (torch.softmax(w_logits, dim=-1) *
1048
+ torch.arange(w_logits.size(-1), device=device)).sum(dim=-1)
1049
+ h_bin = (torch.softmax(h_logits, dim=-1) *
1050
+ torch.arange(h_logits.size(-1), device=device)).sum(dim=-1)
1051
  else:
1052
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1053
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1054
+
1055
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1056
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
1057
 
1058
+ size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)
1059
+ size_emb = encode_size(size_in, self.region).unsqueeze(1)
1060
 
 
1061
  for i in range(B):
1062
+ if not alive[i]: continue
 
1063
  xl = (x_center[i] - w[i] / 2).item()
1064
  xr = (x_center[i] + w[i] / 2).item()
1065
  yt = (y_center[i] - h[i] / 2).item()
1066
  yb = (y_center[i] + h[i] / 2).item()
 
1067
  out[i].append({
1068
  "x_min": max(0.0, min(1.0, xl)),
1069
  "y_min": max(0.0, min(1.0, yt)),
 
1075
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1076
  pos_ids[alive, 0] += 1
1077
  pos += 1
1078
+ next_tok = logits.argmax(dim=-1).squeeze(-1)
1079
  else:
1080
  for i in range(B):
1081
  if alive[i]:
 
1086
  pos += 1
1087
  next_tok = logits.argmax(dim=-1).squeeze(-1)
1088
 
 
1089
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1090
  counts = counts + ((~finished_now) & alive).to(counts.dtype)
1091
  alive &= ~finished_now