HV-Khurdula committed on
Commit
938b38e
·
verified ·
1 Parent(s): 217be81

Update moondream.py

Browse files
Files changed (1) hide show
  1. moondream.py +99 -106
moondream.py CHANGED
@@ -64,39 +64,35 @@ class EncodedImage:
64
  pos: int
65
  caches: List[Tuple[torch.Tensor, torch.Tensor]]
66
 
67
-
68
  class KVCache(nn.Module):
69
-
70
  def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
71
  super().__init__()
72
  cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
73
- self.register_buffer(
74
- "k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
75
- )
76
- self.register_buffer(
77
- "v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype)
78
- )
79
 
80
  def update(self, pos_ids, k, v):
81
  """
82
  Supports:
83
  • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
84
  • 1-step: k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
85
- Scalar: k,v = (B, n_kv_heads, 1, d), pos_ids = ()
86
  Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).
87
  """
88
  kout, vout = self.k_cache, self.v_cache
89
-
 
90
  if not torch.is_tensor(pos_ids):
91
  pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
92
  else:
93
  pos_ids = pos_ids.to(device=k.device, dtype=torch.long)
94
-
95
  if k.dim() != 4 or v.dim() != 4:
96
- raise RuntimeError(f"KV update expects k,v 4D. Got k={tuple(k.shape)} v={tuple(v.shape)}")
 
97
  B, Hkv, q_len, D = k.shape
98
-
99
- # Expand cache to batch B if needed (expand-from-1 allowed)
100
  if kout.size(0) != B:
101
  if kout.size(0) == 1:
102
  self.k_cache = kout.expand(B, -1, -1, -1).clone()
@@ -104,34 +100,31 @@ class KVCache(nn.Module):
104
  kout, vout = self.k_cache, self.v_cache
105
  else:
106
  raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")
107
-
108
- # A) Prefill: pos_ids = (q_len,)
109
  if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
110
  for i in range(B):
111
  kout[i, :, pos_ids, :] = k[i]
112
  vout[i, :, pos_ids, :] = v[i]
113
  return kout, vout
114
-
115
- # B) One-step: q_len == 1 and pos_ids per row: (B,) or (B,1)
116
- if q_len == 1 and pos_ids.numel() == B:
117
  pos_ids = pos_ids.view(B)
118
  for i in range(B):
119
  pi = int(pos_ids[i].item())
120
  kout[i, :, pi, :] = k[i, :, 0, :]
121
  vout[i, :, pi, :] = v[i, :, 0, :]
122
  return kout, vout
123
-
124
- # C) Scalar position for everyone & q_len == 1
125
  if pos_ids.dim() == 0 and q_len == 1:
126
  pi = int(pos_ids.item())
127
  kout[:, :, pi, :] = k[:, :, 0, :]
128
  vout[:, :, pi, :] = v[:, :, 0, :]
129
  return kout, vout
130
-
131
- raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
132
-
133
-
134
 
 
135
 
136
 
137
 
@@ -214,11 +207,12 @@ class MoondreamModel(nn.Module):
214
  head_dim = c.dim // c.n_heads
215
  for blk in self.text.blocks:
216
  device = blk.kv_cache.k_cache.device
217
- dtype = blk.kv_cache.k_cache.dtype
218
- shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
219
  blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
220
  blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
221
 
 
222
 
223
 
224
  def _setup_caches(self):
@@ -575,52 +569,41 @@ class MoondreamModel(nn.Module):
575
  image: Union[Image.Image, EncodedImage],
576
  settings: Optional[ImageEncodingSettings] = None,
577
  ) -> EncodedImage:
578
- # Always start from single-row caches; avoids leftovers from batched runs. DO NOT TOUCH THIS!!!!!!!!!
579
  self._setup_caches()
580
-
581
- if isinstance(image, EncodedImage):
582
- return image
583
- elif not isinstance(image, Image.Image):
584
- raise ValueError("image must be a PIL Image or EncodedImage")
585
-
586
- # Always start from single-row caches to avoid leftovers from batched runs
587
  for blk in self.text.blocks:
588
  if blk.kv_cache.k_cache.size(0) != 1:
589
  blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
590
  blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
591
-
592
-
593
 
594
- lora = (
595
- variant_state_dict(settings["variant"], device=self.device)
596
- if settings is not None and "variant" in settings
597
- else None
598
- )
 
 
599
 
600
  with torch.inference_mode():
601
  img_emb = self._run_vision_encoder(image)
602
- bos_emb = text_encoder(
603
- torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text
604
- )
605
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
606
- attn = self.attn_mask # (1,1,Tmax,Tmax)
607
- mask = self.attn_mask[:, :, 0 : inputs_embeds.size(1), :]
608
- pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long)
609
  self._prefill(inputs_embeds, mask, pos_ids, lora)
610
-
611
-
612
 
613
  return EncodedImage(
614
  pos=inputs_embeds.size(1),
615
  caches=[
616
  (
617
- b.kv_cache.k_cache[:, :, : inputs_embeds.size(1), :].clone(),
618
- b.kv_cache.v_cache[:, :, : inputs_embeds.size(1), :].clone(),
619
  )
620
  for b in self.text.blocks
621
  ],
622
  )
623
 
 
624
  def query(
625
  self,
626
  image: Optional[Union[Image.Image, EncodedImage]] = None,
@@ -913,22 +896,18 @@ class MoondreamModel(nn.Module):
913
 
914
 
915
  def _load_encoded_image_batched(self, encoded_image, batch_size: int):
916
- """
917
- Clone single-image KV caches into a batch-B cache so we can decode B labels in parallel.
918
- """
919
  for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
920
  T = k.size(2)
921
- # Allocate new [B, n_kv_heads, T_max, head_dim] caches if needed
922
  if b.kv_cache.k_cache.size(0) != batch_size:
923
  new_k = b.kv_cache.k_cache.new_zeros((batch_size,) + b.kv_cache.k_cache.shape[1:])
924
  new_v = b.kv_cache.v_cache.new_zeros((batch_size,) + b.kv_cache.v_cache.shape[1:])
925
  b.kv_cache.k_cache = new_k
926
  b.kv_cache.v_cache = new_v
927
- # Copy current prefix from the encoded image into all B rows
928
  b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
929
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
930
 
931
 
 
932
  def _prefill_prompt_batched(self, labels, pos: int, lora=None,
933
  temperature: float = 0.0, top_p: float = 0.0):
934
  tpl = self.config.tokenizer.templates["detect"]
@@ -945,34 +924,35 @@ class MoondreamModel(nn.Module):
945
 
946
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
947
  for i, ids in enumerate(rows):
948
- prompt_ids[i, :ids.numel()] = ids
949
 
950
- prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
951
  torch._dynamo.mark_dynamic(prompt_emb, 1)
952
 
953
- base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,kv_len)
954
- mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,kv_len)
955
-
956
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
957
- hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
958
- logits_BTV = lm_head(hidden_BTC, self.text)
959
 
960
- idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
 
 
 
961
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
962
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
963
 
964
  if temperature == 0.0:
965
- next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
966
  else:
967
  probs = torch.softmax(last_logits / temperature, dim=-1)
968
  probs = self._apply_top_p(probs, top_p)
969
- next_token = torch.multinomial(probs, num_samples=1) # (B,1)
970
 
971
- pos_end = int(pos + T) # shared scalar end position
972
  return last_hidden, next_token, pos_end
973
 
974
 
975
 
 
976
  def _generate_points_batched(
977
  self,
978
  hidden, # (B,1,C)
@@ -989,11 +969,11 @@ class MoondreamModel(nn.Module):
989
  eos_id = self.config.tokenizer.eos_id
990
  max_ctx = self.config.text.max_context
991
 
992
- # Normalize pos to a scalar int (supports int, (1,), (B,), (B,1))
993
  if torch.is_tensor(pos):
994
  pos = int(pos.max().item())
995
 
996
- # 4-D mask: (B, 1, q_len=1, kv_len) + per-row position ids (B,1)
997
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
998
  if pos > 0:
999
  mask[:, :, :, :pos] = True
@@ -1004,32 +984,48 @@ class MoondreamModel(nn.Module):
1004
 
1005
  def _argmax01(logits: torch.Tensor) -> torch.Tensor:
1006
  """
1007
- logits: (..., bins) -> normalized index in [0,1] per row
1008
- Accepts (B,1,bins), (B,bins), or (bins,).
1009
  """
1010
- # Canonicalize to (B, bins)
1011
- if logits.dim() == 3: # (B,1,bins)
1012
  logits = logits.squeeze(1)
1013
- elif logits.dim() == 1: # (bins,) -> (1,bins)
1014
  logits = logits.unsqueeze(0)
 
 
 
1015
 
1016
  if use_soft_argmax:
1017
  probs = torch.softmax(logits, dim=-1)
1018
- bins_idx = torch.arange(probs.size(-1), device=probs.device, dtype=torch.float32)
1019
- # expected-bin (0..bins-1) -> normalize by (bins-1) to [0,1]
1020
- expbin = (probs * bins_idx).sum(dim=-1)
1021
  return expbin / float(probs.size(-1) - 1)
1022
  else:
1023
  idx = logits.argmax(dim=-1).to(torch.float32)
1024
  return idx / float(logits.size(-1) - 1)
1025
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1026
  with torch.inference_mode():
1027
  while alive.any() and (counts < max_objects).any():
1028
  # ---- x ------------------------------------------------------
1029
- x_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b)
1030
- x_center = _argmax01(x_logits) # (B,)
1031
- x_in = x_center.to(dtype=x_logits.dtype if torch.is_tensor(x_logits) else hidden.dtype).unsqueeze(-1)
1032
- x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
 
1033
 
1034
  mask[alive, :, :, pos] = True
1035
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
@@ -1037,9 +1033,10 @@ class MoondreamModel(nn.Module):
1037
  pos += 1
1038
 
1039
  # ---- y ------------------------------------------------------
1040
- y_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b)
1041
- y_center = _argmax01(y_logits) # (B,)
1042
- y_in = y_center.to(dtype=y_logits.dtype if torch.is_tensor(y_logits) else hidden.dtype).unsqueeze(-1)
 
1043
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1044
 
1045
  mask[alive, :, :, pos] = True
@@ -1049,22 +1046,24 @@ class MoondreamModel(nn.Module):
1049
 
1050
  if include_size:
1051
  # ---- size ----------------------------------------------
1052
- size_logits = decode_size(hidden, self.region) # tuple of (w_logits, h_logits)
1053
  w_logits, h_logits = size_logits
1054
 
1055
- # Canonicalize to (B, bins) for both
1056
  if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
1057
  if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
1058
  if w_logits.dim() == 1: w_logits = w_logits.unsqueeze(0)
1059
  if h_logits.dim() == 1: h_logits = h_logits.unsqueeze(0)
 
 
1060
 
1061
  if use_soft_argmax:
1062
  w_probs = torch.softmax(w_logits, dim=-1)
1063
  h_probs = torch.softmax(h_logits, dim=-1)
1064
- w_bins_idx = torch.arange(w_probs.size(-1), device=device, dtype=torch.float32)
1065
- h_bins_idx = torch.arange(h_probs.size(-1), device=device, dtype=torch.float32)
1066
- w_bin = (w_probs * w_bins_idx).sum(dim=-1) # (B,)
1067
- h_bin = (h_probs * h_bins_idx).sum(dim=-1) # (B,)
1068
  else:
1069
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1070
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
@@ -1075,8 +1074,11 @@ class MoondreamModel(nn.Module):
1075
  w = torch.pow(2.0, (w_bin / w_den) * 10.0 - 10.0)
1076
  h = torch.pow(2.0, (h_bin / h_den) * 10.0 - 10.0)
1077
 
1078
- size_in = torch.stack([w, h], dim=1).to(dtype=w_logits.dtype) # (B,2)
1079
- size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
 
 
 
1080
 
1081
  # record boxes only for alive rows
1082
  for i in range(B):
@@ -1097,11 +1099,11 @@ class MoondreamModel(nn.Module):
1097
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1098
  pos_ids[alive, 0] += 1
1099
  pos += 1
1100
- next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1101
  else:
1102
  for i in range(B):
1103
  if alive[i]:
1104
- out[i].append({"x": x_center[i].item(), "y": y_center[i].item()})
1105
  mask[alive, :, :, pos] = True
1106
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1107
  pos_ids[alive, 0] += 1
@@ -1120,16 +1122,8 @@ class MoondreamModel(nn.Module):
1120
 
1121
 
1122
 
 
1123
  def detect_multi(self, image, objects, settings=None):
1124
- """
1125
- Parallel multi-label detection.
1126
- Args:
1127
- image: PIL.Image or EncodedImage
1128
- objects: list[str], e.g. ["person", "car"]
1129
- settings: Optional[ObjectSamplingSettings], honors "max_objects" and "variant"
1130
- Returns:
1131
- {"objects": {label: [box_dict, ...]}}
1132
- """
1133
  if self.config.tokenizer.templates["detect"] is None:
1134
  raise NotImplementedError("Model does not support object detection.")
1135
  settings = settings or {}
@@ -1160,9 +1154,8 @@ class MoondreamModel(nn.Module):
1160
  d["label"] = lab
1161
  res[lab] = lst
1162
 
1163
- # IMPORTANT: restore caches to B=1 so future calls (e.g., encode_image) are safe.
1164
  self._reset_kv_caches(1)
1165
-
1166
  return {"objects": res}
1167
 
1168
 
 
64
  pos: int
65
  caches: List[Tuple[torch.Tensor, torch.Tensor]]
66
 
 
67
class KVCache(nn.Module):
    """Per-layer key/value cache for incremental decoding.

    Buffers are registered with batch size 1 and lazily expanded (by cloning)
    to the batch size of the incoming keys/values on the first batched update.
    """

    def __init__(self, n_heads, n_kv_heads, max_context, dim, device, dtype):
        super().__init__()
        # One slot per position up to max_context; head_dim = dim // n_heads.
        cache_shape = (1, n_kv_heads, max_context, dim // n_heads)
        self.register_buffer("k_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))
        self.register_buffer("v_cache", torch.zeros(*cache_shape, device=device, dtype=dtype))

    def update(self, pos_ids, k, v):
        """Write new keys/values into the cache and return the full buffers.

        Supports:
          • Prefill: k,v = (B, n_kv_heads, q_len, d), pos_ids = (q_len,)
          • 1-step:  k,v = (B, n_kv_heads, 1, d), pos_ids = (B,1) or (B,)
          • Legacy:  k,v = (1, n_kv_heads, q_len, d), pos_ids = scalar

        Writes into self.k_cache/self.v_cache shaped (B, n_kv_heads, T_max, d).

        Returns:
            (k_cache, v_cache) after the write.

        Raises:
            RuntimeError: on non-4D k/v, an unexpandable batch mismatch, or an
            unsupported (k shape, pos_ids shape) combination.
        """
        kout, vout = self.k_cache, self.v_cache

        # Normalize pos_ids to a long tensor on k's device.
        if not torch.is_tensor(pos_ids):
            pos_ids = torch.tensor(pos_ids, device=k.device, dtype=torch.long)
        else:
            pos_ids = pos_ids.to(device=k.device, dtype=torch.long)

        if k.dim() != 4 or v.dim() != 4:
            raise RuntimeError(f"KV update expects 4D k,v. Got k={tuple(k.shape)} v={tuple(v.shape)}")

        B, Hkv, q_len, D = k.shape

        # Ensure cache batch matches B (expand-from-1 allowed; clone so rows
        # become independent writable storage).
        if kout.size(0) != B:
            if kout.size(0) == 1:
                self.k_cache = kout.expand(B, -1, -1, -1).clone()
                self.v_cache = vout.expand(B, -1, -1, -1).clone()
                kout, vout = self.k_cache, self.v_cache
            else:
                raise RuntimeError(f"KV cache batch mismatch: cache.B={kout.size(0)} vs k.B={B}")

        # Case A: PREFILL — one shared position vector of length q_len.
        # Vectorized: a single advanced index on the time dim broadcasts over
        # the batch dim, replacing the original per-row Python loop.
        if pos_ids.dim() == 1 and pos_ids.numel() == q_len:
            kout[:, :, pos_ids, :] = k
            vout[:, :, pos_ids, :] = v
            return kout, vout

        # Case B: 1-STEP — q_len == 1 with per-row positions (B,) or (B,1).
        # Vectorized scatter: row i writes position pos_ids[i], without the
        # original per-row .item() loop.
        if q_len == 1 and pos_ids.numel() == B:
            rows = torch.arange(B, device=k.device)
            idx = pos_ids.view(B)
            # Non-adjacent advanced indices move the indexed dims first,
            # giving a (B, Hkv, D) target that matches k[:, :, 0, :].
            kout[rows, :, idx, :] = k[:, :, 0, :]
            vout[rows, :, idx, :] = v[:, :, 0, :]
            return kout, vout

        # Case C: single scalar position shared by every row, q_len == 1.
        if pos_ids.dim() == 0 and q_len == 1:
            pi = int(pos_ids.item())
            kout[:, :, pi, :] = k[:, :, 0, :]
            vout[:, :, pi, :] = v[:, :, 0, :]
            return kout, vout

        raise RuntimeError(f"Unsupported KV update combo: k={tuple(k.shape)}, pos_ids={tuple(pos_ids.shape)}")
128
 
129
 
130
 
 
207
  head_dim = c.dim // c.n_heads
208
  for blk in self.text.blocks:
209
  device = blk.kv_cache.k_cache.device
210
+ dtype = blk.kv_cache.k_cache.dtype
211
+ shape = (batch_size, c.n_kv_heads, c.max_context, head_dim)
212
  blk.kv_cache.k_cache = torch.zeros(shape, device=device, dtype=dtype)
213
  blk.kv_cache.v_cache = torch.zeros(shape, device=device, dtype=dtype)
214
 
215
+
216
 
217
 
218
  def _setup_caches(self):
 
569
  image: Union[Image.Image, EncodedImage],
570
  settings: Optional[ImageEncodingSettings] = None,
571
  ) -> EncodedImage:
572
+ # Always start from single-row caches; avoids leftovers from batched runs
573
  self._setup_caches()
 
 
 
 
 
 
 
574
  for blk in self.text.blocks:
575
  if blk.kv_cache.k_cache.size(0) != 1:
576
  blk.kv_cache.k_cache = blk.kv_cache.k_cache[:1].contiguous()
577
  blk.kv_cache.v_cache = blk.kv_cache.v_cache[:1].contiguous()
 
 
578
 
579
+ if isinstance(image, EncodedImage):
580
+ return image
581
+ if not isinstance(image, Image.Image):
582
+ raise ValueError("image must be a PIL Image or EncodedImage")
583
+
584
+ lora = (variant_state_dict(settings["variant"], device=self.device)
585
+ if settings is not None and "variant" in settings else None)
586
 
587
  with torch.inference_mode():
588
  img_emb = self._run_vision_encoder(image)
589
+ bos_emb = text_encoder(torch.tensor([[self.config.tokenizer.bos_id]], device=self.device), self.text)
 
 
590
  inputs_embeds = torch.cat([bos_emb, img_emb[None]], dim=1)
591
+ mask = self.attn_mask[:, :, :inputs_embeds.size(1), :]
592
+ pos_ids = torch.arange(inputs_embeds.size(1), dtype=torch.long, device=self.device)
 
593
  self._prefill(inputs_embeds, mask, pos_ids, lora)
 
 
594
 
595
  return EncodedImage(
596
  pos=inputs_embeds.size(1),
597
  caches=[
598
  (
599
+ b.kv_cache.k_cache[:, :, :inputs_embeds.size(1), :].clone(),
600
+ b.kv_cache.v_cache[:, :, :inputs_embeds.size(1), :].clone(),
601
  )
602
  for b in self.text.blocks
603
  ],
604
  )
605
 
606
+
607
  def query(
608
  self,
609
  image: Optional[Union[Image.Image, EncodedImage]] = None,
 
896
 
897
 
898
  def _load_encoded_image_batched(self, encoded_image, batch_size: int):
 
 
 
899
  for b, (k, v) in zip(self.text.blocks, encoded_image.caches):
900
  T = k.size(2)
 
901
  if b.kv_cache.k_cache.size(0) != batch_size:
902
  new_k = b.kv_cache.k_cache.new_zeros((batch_size,) + b.kv_cache.k_cache.shape[1:])
903
  new_v = b.kv_cache.v_cache.new_zeros((batch_size,) + b.kv_cache.v_cache.shape[1:])
904
  b.kv_cache.k_cache = new_k
905
  b.kv_cache.v_cache = new_v
 
906
  b.kv_cache.k_cache[:, :, :T, :] = k.expand(batch_size, -1, -1, -1)
907
  b.kv_cache.v_cache[:, :, :T, :] = v.expand(batch_size, -1, -1, -1)
908
 
909
 
910
+
911
  def _prefill_prompt_batched(self, labels, pos: int, lora=None,
912
  temperature: float = 0.0, top_p: float = 0.0):
913
  tpl = self.config.tokenizer.templates["detect"]
 
924
 
925
  prompt_ids = torch.full((B, T), eos, device=self.device, dtype=torch.long)
926
  for i, ids in enumerate(rows):
927
+ prompt_ids[i, : ids.numel()] = ids
928
 
929
+ prompt_emb = text_encoder(prompt_ids, self.text) # (B,T,C)
930
  torch._dynamo.mark_dynamic(prompt_emb, 1)
931
 
932
+ base = self.attn_mask[:, :, pos:pos+T, :] # (1,1,T,K)
933
+ mask = base.expand(B, -1, -1, -1).contiguous() # (B,1,T,K)
 
934
  pos_ids = torch.arange(pos, pos + T, device=self.device, dtype=torch.long) # (T,)
 
 
935
 
936
+ hidden_BTC = self._prefill(prompt_emb, mask, pos_ids, lora) # (B,T,C)
937
+ logits_BTV = lm_head(hidden_BTC, self.text) # (B,T,V)
938
+
939
+ idx = (torch.tensor(lens, device=self.device) - 1).clamp_min(0) # (B,)
940
  last_hidden = hidden_BTC[torch.arange(B, device=self.device), idx][:, None, :] # (B,1,C)
941
  last_logits = logits_BTV[torch.arange(B, device=self.device), idx] # (B,V)
942
 
943
  if temperature == 0.0:
944
+ next_token = last_logits.argmax(dim=-1, keepdim=True) # (B,1)
945
  else:
946
  probs = torch.softmax(last_logits / temperature, dim=-1)
947
  probs = self._apply_top_p(probs, top_p)
948
+ next_token = torch.multinomial(probs, num_samples=1) # (B,1)
949
 
950
+ pos_end = int(pos + T) # shared next-slot
951
  return last_hidden, next_token, pos_end
952
 
953
 
954
 
955
+
956
  def _generate_points_batched(
957
  self,
958
  hidden, # (B,1,C)
 
969
  eos_id = self.config.tokenizer.eos_id
970
  max_ctx = self.config.text.max_context
971
 
972
+ # Normalize pos to a scalar int
973
  if torch.is_tensor(pos):
974
  pos = int(pos.max().item())
975
 
976
+ # 4-D mask: (B,1,1,K) and per-row pos ids (B,1)
977
  mask = torch.zeros(B, 1, 1, max_ctx, device=device, dtype=torch.bool)
978
  if pos > 0:
979
  mask[:, :, :, :pos] = True
 
984
 
985
  def _argmax01(logits: torch.Tensor) -> torch.Tensor:
986
  """
987
+ logits: (..., bins) -> (B,) in [0,1]
988
+ Accepts (B,1,bins), (B,bins), or (bins,)
989
  """
990
+ # Canonicalize to (B,bins)
991
+ if logits.dim() == 3: # (B,1,bins)
992
  logits = logits.squeeze(1)
993
+ elif logits.dim() == 1: # (bins,)
994
  logits = logits.unsqueeze(0)
995
+ # If batch accidentally collapsed to 1, expand to B so downstream indexing is safe.
996
+ if logits.size(0) == 1 and B > 1:
997
+ logits = logits.expand(B, -1)
998
 
999
  if use_soft_argmax:
1000
  probs = torch.softmax(logits, dim=-1)
1001
+ bins = torch.arange(probs.size(-1), device=probs.device, dtype=torch.float32)
1002
+ expbin = (probs * bins).sum(dim=-1)
 
1003
  return expbin / float(probs.size(-1) - 1)
1004
  else:
1005
  idx = logits.argmax(dim=-1).to(torch.float32)
1006
  return idx / float(logits.size(-1) - 1)
1007
 
1008
+ def _ensure_b(vec: torch.Tensor) -> torch.Tensor:
1009
+ """
1010
+ Make sure 1D tensors are length-B for safe indexing.
1011
+ Accepts scalar/(), (1,), (B,), returns (B,)
1012
+ """
1013
+ if vec.dim() == 0:
1014
+ return vec.repeat(B)
1015
+ if vec.dim() == 1 and vec.numel() == 1 and B > 1:
1016
+ return vec.repeat(B)
1017
+ if vec.dim() == 1 and vec.numel() == B:
1018
+ return vec
1019
+ raise RuntimeError(f"Expected (B,) vec, got shape {tuple(vec.shape)} for B={B}")
1020
+
1021
  with torch.inference_mode():
1022
  while alive.any() and (counts < max_objects).any():
1023
  # ---- x ------------------------------------------------------
1024
+ x_logits = decode_coordinate(hidden, self.region) # (B,1,b) or (B,b) or (b,)
1025
+ x_center = _argmax01(x_logits) # (B,)
1026
+ x_center = _ensure_b(x_center) # force len B
1027
+ x_in = x_center.to(dtype=hidden.dtype).unsqueeze(-1) # (B,1)
1028
+ x_emb = encode_coordinate(x_in, self.region).unsqueeze(1) # (B,1,C)
1029
 
1030
  mask[alive, :, :, pos] = True
1031
  logits, hidden = self._decode_one_tok(x_emb, mask, pos_ids, lora)
 
1033
  pos += 1
1034
 
1035
  # ---- y ------------------------------------------------------
1036
+ y_logits = decode_coordinate(hidden, self.region)
1037
+ y_center = _argmax01(y_logits) # (B,)
1038
+ y_center = _ensure_b(y_center)
1039
+ y_in = y_center.to(dtype=hidden.dtype).unsqueeze(-1)
1040
  y_emb = encode_coordinate(y_in, self.region).unsqueeze(1)
1041
 
1042
  mask[alive, :, :, pos] = True
 
1046
 
1047
  if include_size:
1048
  # ---- size ----------------------------------------------
1049
+ size_logits = decode_size(hidden, self.region) # tuple: (w_logits, h_logits)
1050
  w_logits, h_logits = size_logits
1051
 
1052
+ # Canonicalize to (B,bins); expand if batch collapsed
1053
  if w_logits.dim() == 3: w_logits = w_logits.squeeze(1)
1054
  if h_logits.dim() == 3: h_logits = h_logits.squeeze(1)
1055
  if w_logits.dim() == 1: w_logits = w_logits.unsqueeze(0)
1056
  if h_logits.dim() == 1: h_logits = h_logits.unsqueeze(0)
1057
+ if w_logits.size(0) == 1 and B > 1: w_logits = w_logits.expand(B, -1)
1058
+ if h_logits.size(0) == 1 and B > 1: h_logits = h_logits.expand(B, -1)
1059
 
1060
  if use_soft_argmax:
1061
  w_probs = torch.softmax(w_logits, dim=-1)
1062
  h_probs = torch.softmax(h_logits, dim=-1)
1063
+ bins_w = torch.arange(w_probs.size(-1), device=device, dtype=torch.float32)
1064
+ bins_h = torch.arange(h_probs.size(-1), device=device, dtype=torch.float32)
1065
+ w_bin = (w_probs * bins_w).sum(dim=-1) # (B,)
1066
+ h_bin = (h_probs * bins_h).sum(dim=-1) # (B,)
1067
  else:
1068
  w_bin = w_logits.argmax(dim=-1).to(torch.float32)
1069
  h_bin = h_logits.argmax(dim=-1).to(torch.float32)
 
1074
  w = torch.pow(2.0, (w_bin / w_den) * 10.0 - 10.0)
1075
  h = torch.pow(2.0, (h_bin / h_den) * 10.0 - 10.0)
1076
 
1077
+ # enforce (B,)
1078
+ w = _ensure_b(w); h = _ensure_b(h)
1079
+
1080
+ size_in = torch.stack([w, h], dim=1).to(dtype=hidden.dtype) # (B,2)
1081
+ size_emb = encode_size(size_in, self.region).unsqueeze(1) # (B,1,C)
1082
 
1083
  # record boxes only for alive rows
1084
  for i in range(B):
 
1099
  logits, hidden = self._decode_one_tok(size_emb, mask, pos_ids, lora)
1100
  pos_ids[alive, 0] += 1
1101
  pos += 1
1102
+ next_tok = logits.argmax(dim=-1).squeeze(-1) # (B,)
1103
  else:
1104
  for i in range(B):
1105
  if alive[i]:
1106
+ out[i].append({"x": float(x_center[i]), "y": float(y_center[i])})
1107
  mask[alive, :, :, pos] = True
1108
  logits, hidden = self._decode_one_tok(y_emb, mask, pos_ids, lora)
1109
  pos_ids[alive, 0] += 1
 
1122
 
1123
 
1124
 
1125
+
1126
  def detect_multi(self, image, objects, settings=None):
 
 
 
 
 
 
 
 
 
1127
  if self.config.tokenizer.templates["detect"] is None:
1128
  raise NotImplementedError("Model does not support object detection.")
1129
  settings = settings or {}
 
1154
  d["label"] = lab
1155
  res[lab] = lst
1156
 
1157
+ # IMPORTANT: restore caches to B=1 so future calls are safe
1158
  self._reset_kv_caches(1)
 
1159
  return {"objects": res}
1160
 
1161