HV-Khurdula committed on
Commit
1ca98b0
·
verified ·
1 Parent(s): cdfd7db

Update moondream.py

Browse files
Files changed (1) hide show
  1. moondream.py +90 -70
moondream.py CHANGED
@@ -895,81 +895,96 @@ class MoondreamModel(nn.Module):
895
  probs = self._apply_top_p(probs, top_p)
896
  next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
897
 
898
- pos_vec = torch.tensor([pos], device=self.device, dtype=torch.long).repeat(B) + torch.tensor(lens, device=self.device)
899
 
900
  return last_hidden, next_token, pos_vec # (B,1,C), (B,1), (B,)
901
 
902
- def _generate_points_batched(self, hidden, next_token, pos_vec, include_size: bool = True, max_objects: int = 50, lora=None):
 
 
 
 
 
 
 
 
903
  """
904
- Vectorized version of _generate_points() that decodes x -> y -> size -> next-token
905
- for all rows in the batch simultaneously.
 
 
906
  Returns: list-of-lists of dicts, length B.
907
  """
908
-
909
  B = hidden.size(0)
910
  device = self.device
911
  out = [[] for _ in range(B)]
912
  eos_id = self.config.tokenizer.eos_id
913
-
914
- # Per-row attention/masking state
 
 
 
 
915
  max_ctx = self.config.text.max_context
916
  mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
917
- for i in range(B):
918
- mask[i, :, : int(pos_vec[i].item())] = 1
919
- pos_ids = pos_vec.clone()
920
-
921
- alive = torch.ones(B, dtype=torch.bool, device=device)
922
  counts = torch.zeros(B, dtype=torch.int32, device=device)
923
-
924
  with torch.inference_mode():
925
  while alive.any() and (counts < max_objects).any():
926
- # --- x coordinate (from current hidden) ---
927
- x_logits = decode_coordinate(hidden, self.region) # (B, 1, 1024) or (B, 1024)
928
  if x_logits.dim() == 3:
929
- x_logits = x_logits.squeeze(1) # (B, 1024)
930
- x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
931
  x_center = x_bin / float(x_logits.size(-1)) # (B,)
932
- x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B, 1)
933
  x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
934
-
935
-
936
- # step: decode to get hidden for y
937
- for i in range(B):
938
- if alive[i]:
939
- mask[i, :, pos_ids[i]] = 1
940
- logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
941
- pos_ids = pos_ids + alive.to(torch.long)
942
-
 
 
943
  # --- y coordinate ---
944
  y_logits = decode_coordinate(hidden, self.region)
945
  if y_logits.dim() == 3:
946
- y_logits = y_logits.squeeze(1) # (B, 1024)
947
- y_bin = y_logits.argmax(dim=-1).to(torch.float32)
948
- y_center = y_bin / float(y_logits.size(-1)) # (B,)
949
- y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B, 1)
950
- y_emb = encode_coordinate(y_input, self.region).unsqueeze(1)
951
-
952
-
953
- # step: decode to get hidden for size (or eos)
954
- for i in range(B):
955
- if alive[i]:
956
- mask[i, :, pos_ids[i]] = 1
957
- logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
958
- pos_ids = pos_ids + alive.to(torch.long)
959
-
 
 
960
  if include_size:
961
- # --- size logits (batched) ---
962
- size_logits = decode_size(hidden, self.region) # tuple/list [w_logits, h_logits] shaped (B,1,1024)
963
- w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1) # (B,1024), (B,1024)
964
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
965
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
966
- # Convert from log-scale bin to size in [0,1]
967
- w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
968
- h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
969
- size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B, 2)
970
- size_emb = encode_size(size_input, self.region).unsqueeze(1)
971
-
972
- # Commit boxes for alive rows
973
  for i in range(B):
974
  if not alive[i]:
975
  continue
@@ -979,35 +994,40 @@ class MoondreamModel(nn.Module):
979
  "x_max": (x_center[i] + w[i] / 2).item(),
980
  "y_max": (y_center[i] + h[i] / 2).item(),
981
  })
982
-
983
  # step: decode "next token" to decide continuation
984
- for i in range(B):
985
- if alive[i]:
986
- mask[i, :, pos_ids[i]] = 1
987
- logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
988
- pos_ids = pos_ids + alive.to(torch.long)
 
 
 
989
  next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
990
  else:
991
  # Points mode (no size)
992
- for i in range(B):
993
- if not alive[i]:
994
- continue
995
- out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
996
- # step: decode next token from y_emb
997
  for i in range(B):
998
  if alive[i]:
999
- mask[i, :, pos_ids[i]] = 1
1000
- logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1001
- pos_ids = pos_ids + alive.to(torch.long)
1002
- next_tok = logits.argmax(dim=-1).squeeze(-1)
1003
-
1004
- # Update which rows are done and count
 
 
 
 
 
 
1005
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1006
  counts = counts + (~finished_now & alive).to(counts.dtype)
1007
  alive &= ~finished_now
1008
-
1009
  return out
1010
 
 
1011
  def detect_multi(self, image, objects, settings=None):
1012
  """
1013
  Parallel multi-label detection.
@@ -1053,7 +1073,7 @@ class MoondreamModel(nn.Module):
1053
  d["label"] = lab
1054
  res[lab] = lst
1055
  return {"objects": res}
1056
- # === END: Batched multi-label detection additions ===
1057
  def _detect_gaze(
1058
  self,
1059
  image: EncodedImage,
 
895
  probs = self._apply_top_p(probs, top_p)
896
  next_token = torch.multinomial(probs, num_samples=1) # (B, 1)
897
 
898
+ pos_vec = torch.full((B,), pos + T, device=self.device, dtype=torch.long)
899
 
900
  return last_hidden, next_token, pos_vec # (B,1,C), (B,1), (B,)
901
 
902
+ def _generate_points_batched(
903
+ self,
904
+ hidden: torch.Tensor, # (B, 1, C) last hidden per row from prefill
905
+ next_token: torch.Tensor, # (B, 1) not used directly (kept for parity)
906
+ pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
907
+ include_size: bool = True,
908
+ max_objects: int = 50,
909
+ lora=None,
910
+ ):
911
  """
912
+ Batched decode loop for multi-label detection.
913
+ - Uses a *shared* scalar position id per step (q_len = 1), as expected by RoPE.
914
+ - Maintains a per-row attention mask and 'alive' flags.
915
+ - Feeds coord encoders with (B,1) tensors; size encoder with (B,2).
916
  Returns: list-of-lists of dicts, length B.
917
  """
 
918
  B = hidden.size(0)
919
  device = self.device
920
  out = [[] for _ in range(B)]
921
  eos_id = self.config.tokenizer.eos_id
922
+
923
+ # --- Shared write position (scalar) consistent with RoPE q_len=1 ---
924
+ # We align rows by padding; using the maximum ensures all KV rows can decode in lockstep.
925
+ pos = int(pos_vec.max().item())
926
+
927
+ # Per-row attention mask (1 = visible). Mark everything up to 'pos' as visible.
928
  max_ctx = self.config.text.max_context
929
  mask = torch.zeros(B, 1, max_ctx, device=device, dtype=torch.bool)
930
+ mask[:, :, :pos] = 1
931
+
932
+ alive = torch.ones(B, dtype=torch.bool, device=device)
 
 
933
  counts = torch.zeros(B, dtype=torch.int32, device=device)
934
+
935
  with torch.inference_mode():
936
  while alive.any() and (counts < max_objects).any():
937
+ # --- x coordinate ---
938
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
939
  if x_logits.dim() == 3:
940
+ x_logits = x_logits.squeeze(1) # (B,1024)
941
+ x_bin = x_logits.argmax(dim=-1).to(torch.float32) # (B,)
942
  x_center = x_bin / float(x_logits.size(-1)) # (B,)
943
+ x_input = x_center.to(dtype=x_logits.dtype).unsqueeze(-1) # (B,1)
944
  x_emb = encode_coordinate(x_input, self.region).unsqueeze(1) # (B,1,C)
945
+
946
+ # step: decode hidden for y (advance shared pos)
947
+ mask[:, :, pos] = 1
948
+ logits, hidden = self._decode_one_tok(
949
+ x_emb,
950
+ mask,
951
+ torch.tensor([pos], device=device, dtype=torch.long), # length-1 (q_len=1)
952
+ lora,
953
+ )
954
+ pos += 1
955
+
956
  # --- y coordinate ---
957
  y_logits = decode_coordinate(hidden, self.region)
958
  if y_logits.dim() == 3:
959
+ y_logits = y_logits.squeeze(1)
960
+ y_bin = y_logits.argmax(dim=-1).to(torch.float32)
961
+ y_center = y_bin / float(y_logits.size(-1)) # (B,)
962
+ y_input = y_center.to(dtype=y_logits.dtype).unsqueeze(-1) # (B,1)
963
+ y_emb = encode_coordinate(y_input, self.region).unsqueeze(1) # (B,1,C)
964
+
965
+ # step: decode hidden for size / eos (advance shared pos)
966
+ mask[:, :, pos] = 1
967
+ logits, hidden = self._decode_one_tok(
968
+ y_emb,
969
+ mask,
970
+ torch.tensor([pos], device=device, dtype=torch.long),
971
+ lora,
972
+ )
973
+ pos += 1
974
+
975
  if include_size:
976
+ # --- size (batched) ---
977
+ size_logits = decode_size(hidden, self.region) # ([B,1,1024],[B,1,1024])
978
+ w_logits, h_logits = size_logits[0].squeeze(1), size_logits[1].squeeze(1) # (B,1024)
979
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
980
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
981
+ # Convert log-scale bins -> sizes in [0,1]
982
+ w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0) # (B,)
983
+ h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0) # (B,)
984
+ size_input = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
985
+ size_emb = encode_size(size_input, self.region).unsqueeze(1) # (B,1,C)
986
+
987
+ # Record boxes for alive rows
988
  for i in range(B):
989
  if not alive[i]:
990
  continue
 
994
  "x_max": (x_center[i] + w[i] / 2).item(),
995
  "y_max": (y_center[i] + h[i] / 2).item(),
996
  })
997
+
998
  # step: decode "next token" to decide continuation
999
+ mask[:, :, pos] = 1
1000
+ logits, hidden = self._decode_one_tok(
1001
+ size_emb,
1002
+ mask,
1003
+ torch.tensor([pos], device=device, dtype=torch.long),
1004
+ lora,
1005
+ )
1006
+ pos += 1
1007
  next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1008
  else:
1009
  # Points mode (no size)
 
 
 
 
 
1010
  for i in range(B):
1011
  if alive[i]:
1012
+ out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1013
+ mask[:, :, pos] = 1
1014
+ logits, hidden = self._decode_one_tok(
1015
+ y_emb,
1016
+ mask,
1017
+ torch.tensor([pos], device=device, dtype=torch.long),
1018
+ lora,
1019
+ )
1020
+ pos += 1
1021
+ next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1022
+
1023
+ # Update finished/alive bookkeeping
1024
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1025
  counts = counts + (~finished_now & alive).to(counts.dtype)
1026
  alive &= ~finished_now
1027
+
1028
  return out
1029
 
1030
+
1031
  def detect_multi(self, image, objects, settings=None):
1032
  """
1033
  Parallel multi-label detection.
 
1073
  d["label"] = lab
1074
  res[lab] = lst
1075
  return {"objects": res}
1076
+
1077
  def _detect_gaze(
1078
  self,
1079
  image: EncodedImage,