HV-Khurdula commited on
Commit
50407fb
·
verified ·
1 Parent(s): 2c36d34

Update moondream.py

Browse files

fix: retry fix for batched point generation

Files changed (1) hide show
  1. moondream.py +4 -4
moondream.py CHANGED
@@ -906,8 +906,7 @@ class MoondreamModel(nn.Module):
906
  pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
907
  include_size: bool = True,
908
  max_objects: int = 50,
909
- lora=None
910
- ):
911
  """
912
  Batched decode loop for multi-label detection.
913
 
@@ -948,7 +947,7 @@ class MoondreamModel(nn.Module):
948
  mask[alive, :, pos] = 1
949
  logits, hidden = self._decode_one_tok(
950
  x_emb,
951
- mask.unsqueeze(2), # (B,1,1,max_ctx)
952
  torch.tensor([pos], device=device, dtype=torch.long),
953
  lora,
954
  )
@@ -966,7 +965,7 @@ class MoondreamModel(nn.Module):
966
  mask[alive, :, pos] = 1
967
  logits, hidden = self._decode_one_tok(
968
  y_emb,
969
- mask.unsqueeze(2), # (B,1,1,max_ctx)
970
  torch.tensor([pos], device=device, dtype=torch.long),
971
  lora,
972
  )
@@ -1028,6 +1027,7 @@ class MoondreamModel(nn.Module):
1028
 
1029
 
1030
 
 
1031
  def detect_multi(self, image, objects, settings=None):
1032
  """
1033
  Parallel multi-label detection.
 
906
  pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
907
  include_size: bool = True,
908
  max_objects: int = 50,
909
+ lora=None):
 
910
  """
911
  Batched decode loop for multi-label detection.
912
 
 
947
  mask[alive, :, pos] = 1
948
  logits, hidden = self._decode_one_tok(
949
  x_emb,
950
+ mask.unsqueeze(2), # (B,1,1,max_ctx)
951
  torch.tensor([pos], device=device, dtype=torch.long),
952
  lora,
953
  )
 
965
  mask[alive, :, pos] = 1
966
  logits, hidden = self._decode_one_tok(
967
  y_emb,
968
+ mask.unsqueeze(2), # (B,1,1,max_ctx)
969
  torch.tensor([pos], device=device, dtype=torch.long),
970
  lora,
971
  )
 
1027
 
1028
 
1029
 
1030
+
1031
  def detect_multi(self, image, objects, settings=None):
1032
  """
1033
  Parallel multi-label detection.