HV-Khurdula committed on
Commit
23f9bc4
·
verified ·
1 Parent(s): c366261

Update moondream.py

Browse files

fix: EOS padding for batched generation causing generation duplication.

Files changed (1) hide show
  1. moondream.py +75 -81
moondream.py CHANGED
@@ -943,66 +943,63 @@ class MoondreamModel(nn.Module):
943
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
944
 
945
 
946
- def _prefill_prompt_batched(
947
- self,
948
- labels,
949
- pos: int,
950
- lora=None,
951
- temperature: float = 0.0,
952
- top_p: float = 0.0,
953
- ):
954
  tpl = self.config.tokenizer.templates["detect"]
955
  if tpl is None:
956
  raise NotImplementedError("Model does not support object detection.")
957
 
958
- rows, lens = [], []
 
959
  for lab in labels:
960
  ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
961
  t = torch.tensor(ids, device=self.device, dtype=torch.long)
962
- rows.append(t); lens.append(t.numel())
 
963
 
964
- B, T = len(rows), max(lens)
965
- eos = self.config.tokenizer.eos_id
966
 
967
- # Pad with EOS in the tensor, but we will still start generation per-row at its own length
968
- prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
969
- for i, ids in enumerate(rows):
970
- prompt_ids[i, : ids.numel()] = ids
971
-
972
- prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
 
 
 
 
973
  torch._dynamo.mark_dynamic(prompt_emb, 1)
974
 
975
- base = self.attn_mask[:, :, pos : pos + T, :] # (1,1,T,K)
976
- mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
977
-
978
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
979
- hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
980
- logits_BTV = lm_head(hidden_BTC, self.text) # (B,T,V)
981
 
982
- # Gather last real token per row
983
- idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
 
 
 
984
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
985
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
986
 
987
  if temperature == 0.0:
988
- next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
989
  else:
990
  probs = torch.softmax(last_logits / temperature, dim=-1)
991
  probs = self._apply_top_p(probs, top_p)
992
- next_token = torch.multinomial(probs, num_samples=1) # (B,1)
993
 
994
- # Per-row next positions (don’t force them all to pos+T)
995
- pos_vec = (pos + torch.tensor(lens, device=self.device, dtype=torch.long)) # (B,)
996
-
997
- return last_hidden, next_token, pos_vec
998
-
999
 
1000
 
1001
  def _generate_points_batched(
1002
  self,
1003
- hidden, # (B,1,C)
1004
- next_token, # (B,1) (unused for greedy)
1005
- pos_vec, # (B,) next-free position per row
1006
  include_size: bool = True,
1007
  max_objects: int = 50,
1008
  lora=None,
@@ -1012,18 +1009,19 @@ class MoondreamModel(nn.Module):
1012
  device = self.device
1013
  out = [[] for _ in range(B)]
1014
  eos_id = self.config.tokenizer.eos_id
1015
- coord_id = self.config.tokenizer.coord_id
1016
  max_ctx = self.config.text.max_context
1017
 
1018
- # Build per-row masks/positions
1019
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
1020
- pos_ids = pos_vec.clone().view(B, 1) # (B,1)
1021
- for i in range(B):
1022
- p0 = int(pos_ids[i, 0].item())
1023
- if p0 > 0:
1024
- mask[i, 0, 0, :p0] = True
1025
 
 
1026
  def _argmax01(logits: torch.Tensor) -> torch.Tensor:
 
 
1027
  if use_soft_argmax:
1028
  probs = torch.softmax(logits, dim=-1)
1029
  bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
@@ -1031,41 +1029,36 @@ class MoondreamModel(nn.Module):
1031
  idx = logits.argmax(dim=-1).to(torch.float32)
1032
  return idx / float(logits.size(-1) - 1)
1033
 
1034
- def _advance_rows(row_mask: torch.Tensor):
1035
- idx = row_mask.nonzero(as_tuple=False).flatten()
1036
- for i in idx.tolist():
1037
- col = int(pos_ids[i, 0].item())
1038
- mask[i, 0, 0, col] = True
1039
- return idx
1040
-
1041
  alive = torch.ones(B, dtype=torch.bool, device=device)
1042
- counts = torch.zeros(B, dtype=torch.int32, device=device)
1043
 
1044
  with torch.inference_mode():
1045
  while alive.any() and (counts < max_objects).any():
1046
- # -------- x --------
1047
- x_logits = decode_coordinate(hidden, self.region) # (B,1,1024) or (B,1024)
1048
- if x_logits.dim() == 3: x_logits = x_logits.squeeze(1)
1049
- x_center = _argmax01(x_logits) # (B,)
1050
- x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
 
1051
 
1052
- idx = _advance_rows(alive)
 
1053
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1054
- pos_ids[idx, 0] += 1
1055
 
1056
- # -------- y --------
1057
  y_logits = decode_coordinate(hidden, self.region)
1058
- if y_logits.dim() == 3: y_logits = y_logits.squeeze(1)
1059
- y_center = _argmax01(y_logits)
1060
  y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
1061
 
1062
- idx = _advance_rows(alive)
1063
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1064
- pos_ids[idx, 0] += 1
1065
 
1066
  if include_size:
1067
- size_ret = decode_size(hidden, self.region)
1068
- w_logits, h_logits = self._norm_size_logits(size_ret, B) # (B,C)
 
1069
 
1070
  if use_soft_argmax:
1071
  bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
@@ -1075,13 +1068,14 @@ class MoondreamModel(nn.Module):
1075
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1076
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1077
 
 
1078
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1079
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
1080
 
1081
  size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)
1082
 
1083
- # record boxes only for rows still alive
1084
- for i in alive.nonzero(as_tuple=False).flatten().tolist():
1085
  xl = (x_center[i] - w[i] / 2).item()
1086
  xr = (x_center[i] + w[i] / 2).item()
1087
  yt = (y_center[i] - h[i] / 2).item()
@@ -1093,34 +1087,34 @@ class MoondreamModel(nn.Module):
1093
  "y_max": max(0.0, min(1.0, yb)),
1094
  })
1095
 
1096
- idx = _advance_rows(alive)
1097
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1098
- pos_ids[idx, 0] += 1
 
1099
  next_tok = logits.argmax(dim=-1)
 
 
1100
  else:
1101
- for i in alive.nonzero(as_tuple=False).flatten().tolist():
 
1102
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1103
- idx = _advance_rows(alive)
1104
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1105
- pos_ids[idx, 0] += 1
1106
  next_tok = logits.argmax(dim=-1)
 
 
1107
 
1108
- # normalize next_tok to (B,)
1109
- while next_tok.dim() > 1:
1110
- next_tok = next_tok.squeeze(-1)
1111
-
1112
- # we added exactly one object/point to all alive rows
1113
- counts[alive] += 1
1114
 
1115
- # GRAMMAR STOP: only continue if the model asks to start another coord;
1116
- # otherwise stop row (covers EOS or any non-coord token).
1117
- continue_mask = (next_tok == coord_id)
1118
- finished_now = (~continue_mask) | (counts >= max_objects)
1119
  alive &= ~finished_now
1120
 
1121
  return out
1122
 
1123
 
 
1124
  def detect_multi(self, image, objects, settings=None):
1125
  if self.config.tokenizer.templates["detect"] is None:
1126
  raise NotImplementedError("Model does not support object detection.")
 
943
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
944
 
945
 
946
def _prefill_prompt_batched(self, labels, pos: int, lora=None,
                            temperature: float = 0.0, top_p: float = 0.0):
    """Prefill the KV cache with one "detect" prompt per label, batched.

    Each label is rendered through the detect template, embedded, and
    LEFT-padded to a common length T with its own first-token embedding
    (a neutral pad, mirroring the upstream moondream2 batching strategy).
    Because rows are left-padded, every row's last *real* token sits at
    column T-1 and all rows share the same next-free cache position.

    Args:
        labels: iterable of object-label strings to detect.
        pos: first free KV-cache position (end of the shared image prefix).
        lora: optional LoRA weights forwarded to the prefill pass.
        temperature: 0.0 for greedy next-token selection, otherwise
            softmax sampling at this temperature.
        top_p: nucleus-sampling threshold (only used when temperature > 0).

    Returns:
        (last_hidden, next_token, pos_end):
            last_hidden: (B, 1, C) hidden state of each row's final token.
            next_token:  (B, 1) sampled/greedy first generated token id.
            pos_end:     int, next free cache position (= pos + T for all
                         rows, since padding also occupies cache slots).

    Raises:
        NotImplementedError: if the tokenizer has no "detect" template.
    """
    tpl = self.config.tokenizer.templates["detect"]
    if tpl is None:
        raise NotImplementedError("Model does not support object detection.")

    # Build each row's token ids (variable length).
    rows_ids, lens = [], []
    for lab in labels:
        ids = tpl["prefix"] + self.tokenizer.encode(" " + lab).ids + tpl["suffix"]
        t = torch.tensor(ids, device=self.device, dtype=torch.long)
        rows_ids.append(t)
        lens.append(t.numel())

    B, T = len(rows_ids), max(lens)

    # Embed each row, then LEFT-pad using its own first token embedding
    # (neutral), mirroring upstream moondream2 batching strategy.
    embs = [text_encoder(t.unsqueeze(0), self.text)[0] for t in rows_ids]  # list of (Li, C)
    padded = []
    for e, L in zip(embs, lens):
        pad = T - L
        if pad > 0:
            e = torch.cat([e[:1].repeat(pad, 1), e], dim=0)  # (T, C)
        padded.append(e)
    prompt_emb = torch.stack(padded, dim=0)  # (B, T, C)
    torch._dynamo.mark_dynamic(prompt_emb, 1)

    # Standard prefill over the shared image prefix [pos : pos+T].
    base = self.attn_mask[:, :, pos:pos + T, :]       # (1,1,T,K)
    mask = base.expand(B, -1, -1, -1).contiguous()    # (B,1,T,K)
    pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long)  # (T,)

    hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora)  # (B,T,C)
    logits_BTV = lm_head(hidden_BTC, self.text)                  # (B,T,V)

    # FIX: rows are LEFT-padded, so the last real token of EVERY row is at
    # column T-1. The previous gather at index lens[i]-1 was only correct
    # for right padding and read hidden states from inside the pad region
    # for any row shorter than T.
    last_hidden = hidden_BTC[:, -1:, :]  # (B,1,C)
    last_logits = logits_BTV[:, -1]      # (B,V)

    if temperature == 0.0:
        next_token = last_logits.argmax(dim=-1, keepdim=True)  # (B,1)
    else:
        probs = torch.softmax(last_logits / temperature, dim=-1)
        probs = self._apply_top_p(probs, top_p)
        next_token = torch.multinomial(probs, num_samples=1)   # (B,1)

    # With left padding every row ends at the same cache column.
    pos_end = int(pos + T)
    return last_hidden, next_token, pos_end
 
 
 
996
 
997
 
998
  def _generate_points_batched(
999
  self,
1000
+ hidden, # (B,1,C) - last token hidden state per row
1001
+ next_token, # (B,1) - unused for greedy loop; kept for API
1002
+ pos, # int - first free position in cache
1003
  include_size: bool = True,
1004
  max_objects: int = 50,
1005
  lora=None,
 
1009
  device = self.device
1010
  out = [[] for _ in range(B)]
1011
  eos_id = self.config.tokenizer.eos_id
 
1012
  max_ctx = self.config.text.max_context
1013
 
1014
+ # 4D mask: (B,1,1,K); we advance per-row
1015
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
1016
+ p0 = int(pos)
1017
+ if p0 > 0:
1018
+ mask[:, :, :, :p0] = True
1019
+ pos_ids = torch.full((B, 1), p0, device=device, dtype=torch.long)
 
1020
 
1021
+ # helper: logits -> normalized [0..1] coordinate (soft-argmax for stability)
1022
  def _argmax01(logits: torch.Tensor) -> torch.Tensor:
1023
+ if logits.dim() == 3:
1024
+ logits = logits.squeeze(1) # (B, bins)
1025
  if use_soft_argmax:
1026
  probs = torch.softmax(logits, dim=-1)
1027
  bins = torch.arange(probs.size(-1), device=logits.device, dtype=torch.float32)
 
1029
  idx = logits.argmax(dim=-1).to(torch.float32)
1030
  return idx / float(logits.size(-1) - 1)
1031
 
 
 
 
 
 
 
 
1032
  alive = torch.ones(B, dtype=torch.bool, device=device)
1033
+ counts = torch.zeros(B, dtype=torch.int32, device=device)
1034
 
1035
  with torch.inference_mode():
1036
  while alive.any() and (counts < max_objects).any():
1037
+ alive_idx = alive.nonzero(as_tuple=False).squeeze(1)
1038
+
1039
+ # ---------- x ----------
1040
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,bins) or (B,bins)
1041
+ x_center = _argmax01(x_logits) # (B,)
1042
+ x_emb = encode_coordinate(x_center.to(dtype=x_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1) # (B,1,C)
1043
 
1044
+ # advance one token for each alive row (per-row column)
1045
+ mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
1046
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
1047
+ pos_ids[alive_idx, 0] += 1
1048
 
1049
+ # ---------- y ----------
1050
  y_logits = decode_coordinate(hidden, self.region)
1051
+ y_center = _argmax01(y_logits) # (B,)
 
1052
  y_emb = encode_coordinate(y_center.to(dtype=y_logits.dtype).unsqueeze(-1), self.region).unsqueeze(1)
1053
 
1054
+ mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
1055
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1056
+ pos_ids[alive_idx, 0] += 1
1057
 
1058
  if include_size:
1059
+ # ---------- size (w,h) ----------
1060
+ size_ret = decode_size(hidden, self.region) # (...,2,bins)
1061
+ w_logits, h_logits = self._norm_size_logits(size_ret, B) # each (B,bins)
1062
 
1063
  if use_soft_argmax:
1064
  bins = torch.arange(w_logits.size(-1), device=device, dtype=torch.float32)
 
1068
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1069
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
1070
 
1071
+ # inverse log scale (md2)
1072
  w = torch.pow(2.0, (w_bin / 1023.0) * 10.0 - 10.0)
1073
  h = torch.pow(2.0, (h_bin / 1023.0) * 10.0 - 10.0)
1074
 
1075
  size_emb = encode_size(torch.stack([w, h], dim=1).to(dtype=w_logits.dtype), self.region).unsqueeze(1)
1076
 
1077
+ # write outputs only for alive rows
1078
+ for i in alive_idx.tolist():
1079
  xl = (x_center[i] - w[i] / 2).item()
1080
  xr = (x_center[i] + w[i] / 2).item()
1081
  yt = (y_center[i] - h[i] / 2).item()
 
1087
  "y_max": max(0.0, min(1.0, yb)),
1088
  })
1089
 
1090
+ mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
1091
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1092
+ pos_ids[alive_idx, 0] += 1
1093
+
1094
  next_tok = logits.argmax(dim=-1)
1095
+ if next_tok.dim() == 3: next_tok = next_tok.squeeze(-1).squeeze(-1)
1096
+ if next_tok.dim() == 2: next_tok = next_tok.squeeze(1)
1097
  else:
1098
+ # points only
1099
+ for i in alive_idx.tolist():
1100
  out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1101
+ mask[alive_idx, 0, 0, pos_ids[alive_idx, 0]] = True
1102
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1103
+ pos_ids[alive_idx, 0] += 1
1104
  next_tok = logits.argmax(dim=-1)
1105
+ if next_tok.dim() == 3: next_tok = next_tok.squeeze(-1).squeeze(-1)
1106
+ if next_tok.dim() == 2: next_tok = next_tok.squeeze(1)
1107
 
1108
+ counts[alive] += 1 # we produced one object/point for each alive row
 
 
 
 
 
1109
 
1110
+ # stop rows that hit eos OR reached quota
1111
+ finished_now = (next_tok == eos_id) | (counts >= max_objects)
 
 
1112
  alive &= ~finished_now
1113
 
1114
  return out
1115
 
1116
 
1117
+
1118
  def detect_multi(self, image, objects, settings=None):
1119
  if self.config.tokenizer.templates["detect"] is None:
1120
  raise NotImplementedError("Model does not support object detection.")