HV-Khurdula committed on
Commit
75468c3
verified
1 Parent(s): f9b6e6a

Update moondream.py

Browse files

fix: scaled_dot_product_attention expands a 3-D mask into (B, n_heads, q_len, kv_len) and bombs out.

Files changed (1) hide show
  1. moondream.py +38 -27
moondream.py CHANGED
@@ -26,6 +26,8 @@ from .region import decode_coordinate, encode_coordinate, decode_size, encode_si
26
  from .text import text_encoder, lm_head
27
  from typing import Optional, List, Union
28
  from .lora import variant_state_dict
 
 
29
 
30
  ImageEncodingSettings = TypedDict(
31
  "ImageEncodingSettings",
@@ -911,24 +913,27 @@ class MoondreamModel(nn.Module):
911
 
912
  def _generate_points_batched(
913
  self,
914
- hidden, # (B,1,C)
915
- next_token, # (B,1)
916
- pos: int, # shared scalar next position
917
  include_size: bool = True,
918
  max_objects: int = 50,
919
  lora=None,
920
  ):
921
  """
922
  Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
923
- for all rows in the batch simultaneously. Returns list-of-lists of dicts, len B.
 
924
  """
 
 
925
  B = hidden.size(0)
926
  device = self.device
927
  out = [[] for _ in range(B)]
928
  eos_id = self.config.tokenizer.eos_id
929
  max_ctx = self.config.text.max_context
930
 
931
- # 4-D mask: (B, 1, q_len=1, kv_len)
932
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
933
  if pos > 0:
934
  mask[:, :, :, :pos] = True
@@ -939,29 +944,29 @@ class MoondreamModel(nn.Module):
939
 
940
  with torch.inference_mode():
941
  while alive.any() and (counts < max_objects).any():
942
- # --- x coordinate ---
943
- x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
944
  if x_logits.dim() == 3:
945
- x_logits = x_logits.squeeze(1)
946
- x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
947
- x_center = x_bin / float(x_logits.size(-1)) # (B,)
948
- x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
949
- x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
950
 
951
- # advance attention one step
952
  mask[:, :, :, pos] = True
953
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
954
  pos += 1
955
  pos_id[0] = pos
956
 
957
- # --- y coordinate ---
958
- y_logits = decode_coordinate(hidden, self.region)
959
  if y_logits.dim() == 3:
960
  y_logits = y_logits.squeeze(1)
961
- y_bin = y_logits.argmax(dim=-1).to(torch.float32)
962
- y_center = y_bin / float(y_logits.size(-1)) # (B,)
963
- y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
964
- y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
965
 
966
  mask[:, :, :, pos] = True
967
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
@@ -969,18 +974,23 @@ class MoondreamModel(nn.Module):
969
  pos_id[0] = pos
970
 
971
  if include_size:
972
- # --- size ---
973
- size_logits = decode_size(hidden, self.region)
974
- w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1)
 
 
 
975
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
976
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
977
- # bins -> size in [0,1] (inverse of log-scale mapping)
 
978
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
979
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
980
- size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
 
981
  size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
982
 
983
- # record boxes
984
  for i in range(B):
985
  if alive[i]:
986
  out[i].append({
@@ -990,21 +1000,22 @@ class MoondreamModel(nn.Module):
990
  "y_max": (y_center[i] + h[i] / 2).item(),
991
  })
992
 
 
993
  mask[:, :, :, pos] = True
994
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
995
  pos += 1
996
  pos_id[0] = pos
997
  next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
998
  else:
 
999
  for i in range(B):
1000
  if alive[i]:
1001
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1002
-
1003
  mask[:, :, :, pos] = True
1004
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
1005
  pos += 1
1006
  pos_id[0] = pos
1007
- next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1008
 
1009
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1010
  counts = counts + (~finished_now & alive).to(counts.dtype)
 
26
  from .text import text_encoder, lm_head
27
  from typing import Optional, List, Union
28
  from .lora import variant_state_dict
29
+ from .layers import mlp
30
+
31
 
32
  ImageEncodingSettings = TypedDict(
33
  "ImageEncodingSettings",
 
913
 
914
  def _generate_points_batched(
915
  self,
916
+ hidden, # (B,1,C) last hidden after prefill (per label row)
917
+ next_token, # (B,1) (kept for parity; not used when temperature=0)
918
+ pos: int, # shared scalar next position for all rows
919
  include_size: bool = True,
920
  max_objects: int = 50,
921
  lora=None,
922
  ):
923
  """
924
  Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
925
+ for all rows in the batch simultaneously. Returns list-of-lists of dicts (len B).
926
+ Batch-safe: uses 4-D masks and avoids region.decode_size() (which flattens batch).
927
  """
928
+ import torch
929
+
930
  B = hidden.size(0)
931
  device = self.device
932
  out = [[] for _ in range(B)]
933
  eos_id = self.config.tokenizer.eos_id
934
  max_ctx = self.config.text.max_context
935
 
936
+ # 4-D mask: (B, 1, q_len=1, kv_len), True means "visible" to match model's convention
937
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
938
  if pos > 0:
939
  mask[:, :, :, :pos] = True
 
944
 
945
  with torch.inference_mode():
946
  while alive.any() and (counts < max_objects).any():
947
+ # --- x coordinate (batched) ---
948
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
949
  if x_logits.dim() == 3:
950
+ x_logits = x_logits.squeeze(1) # (B,1024)
951
+ x_bin = x_logits.argmax(dim=-1).to(torch.float32)
952
+ x_center = x_bin / float(x_logits.size(-1)) # (B,)
953
+ x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
954
+ x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
955
 
956
+ # advance one token
957
  mask[:, :, :, pos] = True
958
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_id, lora)
959
  pos += 1
960
  pos_id[0] = pos
961
 
962
+ # --- y coordinate (batched) ---
963
+ y_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
964
  if y_logits.dim() == 3:
965
  y_logits = y_logits.squeeze(1)
966
+ y_bin = y_logits.argmax(dim=-1).to(torch.float32)
967
+ y_center = y_bin / float(y_logits.size(-1))
968
+ y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
969
+ y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
970
 
971
  mask[:, :, :, pos] = True
972
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
 
974
  pos_id[0] = pos
975
 
976
  if include_size:
977
+ # ---- size (batched, *without* region.decode_size which flattens batch) ----
978
+ # size_out_dim is 2*1024 (W then H). mlp() preserves (B,1,·).
979
+ size_logits = mlp(hidden, self.region["size_decoder"]).squeeze(1) # (B, 2048)
980
+ half = size_logits.size(-1) // 2
981
+ w_logits, h_logits = size_logits[:, :half], size_logits[:, half:] # (B,1024),(B,1024)
982
+
983
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
984
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
985
+
986
+ # inverse log-scale mapping used by the repo
987
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
988
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
989
+
990
+ size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
991
  size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
992
 
993
+ # commit boxes
994
  for i in range(B):
995
  if alive[i]:
996
  out[i].append({
 
1000
  "y_max": (y_center[i] + h[i] / 2).item(),
1001
  })
1002
 
1003
+ # decide continuation
1004
  mask[:, :, :, pos] = True
1005
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_id, lora)
1006
  pos += 1
1007
  pos_id[0] = pos
1008
  next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1009
  else:
1010
+ # points mode
1011
  for i in range(B):
1012
  if alive[i]:
1013
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
 
1014
  mask[:, :, :, pos] = True
1015
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_id, lora)
1016
  pos += 1
1017
  pos_id[0] = pos
1018
+ next_tok = logits.argmax(dim=-1).squeeze(-1)
1019
 
1020
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1021
  counts = counts + (~finished_now & alive).to(counts.dtype)