HV-Khurdula committed on
Commit
217be81
·
verified ·
1 Parent(s): 1168c80

Update moondream.py

Browse files

fix: batched generation

Files changed (1) hide show
  1. moondream.py +55 -32
moondream.py CHANGED
@@ -976,7 +976,7 @@ class MoondreamModel(nn.Module):
976
  def _generate_points_batched(
977
  self,
978
  hidden, # (B,1,C)
979
- next_token, # (B,1)
980
  pos, # int or Tensor; normalized below
981
  include_size: bool = True,
982
  max_objects: int = 50,
@@ -989,10 +989,11 @@ class MoondreamModel(nn.Module):
989
  eos_id = self.config.tokenizer.eos_id
990
  max_ctx = self.config.text.max_context
991
 
 
992
  if torch.is_tensor(pos):
993
  pos = int(pos.max().item())
994
 
995
- # SDPA mask and position ids
996
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
997
  if pos > 0:
998
  mask[:, :, :, :pos] = True
@@ -1001,36 +1002,44 @@ class MoondreamModel(nn.Module):
1001
  alive = torch.ones(B, dtype=torch.bool, device=device)
1002
  counts = torch.zeros(B, dtype=torch.int32, device=device)
1003
 
1004
- def _argmax01(logits):
1005
- # logits: (B, bins) -> normalized index in [0,1]
 
 
 
 
 
 
 
 
 
1006
  if use_soft_argmax:
1007
  probs = torch.softmax(logits, dim=-1)
1008
- bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
1009
- idx = (probs * bins).sum(dim=-1) / (probs.size(-1) - 1)
1010
- return idx
 
1011
  else:
1012
  idx = logits.argmax(dim=-1).to(torch.float32)
1013
  return idx / float(logits.size(-1) - 1)
1014
 
1015
  with torch.inference_mode():
1016
  while alive.any() and (counts < max_objects).any():
1017
- # x
1018
- x_logits = decode_coordinate(hidden, self.region)
1019
- if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
1020
- x_center = _argmax01(x_logits)
1021
- x_in = x_center.to(dtype=x_logits.dtype).unsqueeze(-1)
1022
- x_emb = encode_coordinate(x_in, self.region).unsqueeze(1)
1023
 
1024
  mask[alive, :, :, pos] = True
1025
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1026
  pos_ids[alive, 0] += 1
1027
  pos += 1
1028
 
1029
- # y
1030
- y_logits = decode_coordinate(hidden, self.region)
1031
- if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
1032
- y_center = _argmax01(y_logits)
1033
- y_in = y_center.to(dtype=y_logits.dtype).unsqueeze(-1)
1034
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1035
 
1036
  mask[alive, :, :, pos] = True
@@ -1039,27 +1048,40 @@ class MoondreamModel(nn.Module):
1039
  pos += 1
1040
 
1041
  if include_size:
1042
- # size
1043
- size_logits = decode_size(hidden, self.region)
1044
- w_logits = size_logits[0].squeeze(1)
1045
- h_logits = size_logits[1].squeeze(1)
 
 
 
 
 
 
1046
  if use_soft_argmax:
1047
- w_bin = (torch.softmax(w_logits, dim=-1) *
1048
- torch.arange(w_logits.size(-1), device=device)).sum(dim=-1)
1049
- h_bin = (torch.softmax(h_logits, dim=-1) *
1050
- torch.arange(h_logits.size(-1), device=device)).sum(dim=-1)
 
 
1051
  else:
1052
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1053
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1054
 
1055
- w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1056
- h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
 
 
 
1057
 
1058
- size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype)
1059
- size_emb = encode_size(size_in, self.region).unsqueeze(1)
1060
 
 
1061
  for i in range(B):
1062
- if not alive[i]: continue
 
1063
  xl = (x_center[i] - w[i] / 2).item()
1064
  xr = (x_center[i] + w[i] / 2).item()
1065
  yt = (y_center[i] - h[i] / 2).item()
@@ -1075,7 +1097,7 @@ class MoondreamModel(nn.Module):
1075
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1076
  pos_ids[alive, 0] += 1
1077
  pos += 1
1078
- next_tok = logits.argmax(dim=-1).squeeze(-1)
1079
  else:
1080
  for i in range(B):
1081
  if alive[i]:
@@ -1086,6 +1108,7 @@ class MoondreamModel(nn.Module):
1086
  pos += 1
1087
  next_tok = logits.argmax(dim=-1).squeeze(-1)
1088
 
 
1089
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1090
  counts = counts + ((~finished_now) & alive).to(counts.dtype)
1091
  alive &= ~finished_now
 
976
  def _generate_points_batched(
977
  self,
978
  hidden, # (B,1,C)
979
+ next_token, # (B,1) (kept for API compatibility)
980
  pos, # int or Tensor; normalized below
981
  include_size: bool = True,
982
  max_objects: int = 50,
 
989
  eos_id = self.config.tokenizer.eos_id
990
  max_ctx = self.config.text.max_context
991
 
992
+ # Normalize pos to a scalar int (supports int, (1,), (B,), (B,1))
993
  if torch.is_tensor(pos):
994
  pos = int(pos.max().item())
995
 
996
+ # 4-D mask: (B, 1, q_len=1, kv_len) + per-row position ids (B,1)
997
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
998
  if pos > 0:
999
  mask[:, :, :, :pos] = True
 
1002
  alive = torch.ones(B, dtype=torch.bool, device=device)
1003
  counts = torch.zeros(B, dtype=torch.int32, device=device)
1004
 
1005
+ def _argmax01(logits: torch.Tensor) -> torch.Tensor:
1006
+ """
1007
+ logits: (..., bins) -> normalized index in [0,1] per row
1008
+ Accepts (B,1,bins), (B,bins), or (bins,).
1009
+ """
1010
+ # Canonicalize to (B, bins)
1011
+ if logits.dim() == 3: # (B,1,bins)
1012
+ logits = logits.squeeze(1)
1013
+ elif logits.dim() == 1: # (bins,) -> (1,bins)
1014
+ logits = logits.unsqueeze(0)
1015
+
1016
  if use_soft_argmax:
1017
  probs = torch.softmax(logits, dim=-1)
1018
+ bins_idx = torch.arange(probs.size(-1), device=probs.device, dtype=torch.float32)
1019
+ # expected-bin (0..bins-1) -> normalize by (bins-1) to [0,1]
1020
+ expbin = (probs * bins_idx).sum(dim=-1)
1021
+ return expbin / float(probs.size(-1) - 1)
1022
  else:
1023
  idx = logits.argmax(dim=-1).to(torch.float32)
1024
  return idx / float(logits.size(-1) - 1)
1025
 
1026
  with torch.inference_mode():
1027
  while alive.any() and (counts < max_objects).any():
1028
+ # ---- x ------------------------------------------------------
1029
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b)
1030
+ x_center = _argmax01(x_logits) # (B,)
1031
+ x_in = x_center.to(dtype=x_logits.dtype if torch.is_tensor(x_logits) else hidden.dtype).unsqueeze(-1)
1032
+ x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
 
1033
 
1034
  mask[alive, :, :, pos] = True
1035
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1036
  pos_ids[alive, 0] += 1
1037
  pos += 1
1038
 
1039
+ # ---- y ------------------------------------------------------
1040
+ y_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b)
1041
+ y_center = _argmax01(y_logits) # (B,)
1042
+ y_in = y_center.to(dtype=y_logits.dtype if torch.is_tensor(y_logits) else hidden.dtype).unsqueeze(-1)
 
1043
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1044
 
1045
  mask[alive, :, :, pos] = True
 
1048
  pos += 1
1049
 
1050
  if include_size:
1051
+ # ---- size ----------------------------------------------
1052
+ size_logits = decode_size(hidden, self.region) # tuple of (w_logits, h_logits)
1053
+ w_logits, h_logits = size_logits
1054
+
1055
+ # Canonicalize to (B, bins) for both
1056
+ if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
1057
+ if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
1058
+ if w_logits.dim() == 1: w_logits = w_logits.unsqueeze(0)
1059
+ if h_logits.dim() == 1: h_logits = h_logits.unsqueeze(0)
1060
+
1061
  if use_soft_argmax:
1062
+ w_probs = torch.softmax(w_logits, dim=-1)
1063
+ h_probs = torch.softmax(h_logits, dim=-1)
1064
+ w_bins_idx = torch.arange(w_probs.size(-1), device=device, dtype=torch.float32)
1065
+ h_bins_idx = torch.arange(h_probs.size(-1), device=device, dtype=torch.float32)
1066
+ w_bin = (w_probs * w_bins_idx).sum(dim=-1) # (B,)
1067
+ h_bin = (h_probs * h_bins_idx).sum(dim=-1) # (B,)
1068
  else:
1069
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1070
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1071
 
1072
+ # bins -> size (inverse log scale), robust to bins != 1024
1073
+ w_den = float(w_logits.size(-1) - 1)
1074
+ h_den = float(h_logits.size(-1) - 1)
1075
+ w = torch.pow(2.0, (w_bin / w_den) * 10.0 - 10.0)
1076
+ h = torch.pow(2.0, (h_bin / h_den) * 10.0 - 10.0)
1077
 
1078
+ size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
1079
+ size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
1080
 
1081
+ # record boxes only for alive rows
1082
  for i in range(B):
1083
+ if not alive[i]:
1084
+ continue
1085
  xl = (x_center[i] - w[i] / 2).item()
1086
  xr = (x_center[i] + w[i] / 2).item()
1087
  yt = (y_center[i] - h[i] / 2).item()
 
1097
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1098
  pos_ids[alive, 0] += 1
1099
  pos += 1
1100
+ next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1101
  else:
1102
  for i in range(B):
1103
  if alive[i]:
 
1108
  pos += 1
1109
  next_tok = logits.argmax(dim=-1).squeeze(-1)
1110
 
1111
+ # stop only rows that hit eos (or reached max objects)
1112
  finished_now = (next_tok == eos_id) | (counts >= max_objects - 1)
1113
  counts = counts + ((~finished_now) & alive).to(counts.dtype)
1114
  alive &= ~finished_now