Update moondream.py
Browse filesfix: retry fix for batched point generation
- moondream.py +4 -4
moondream.py
CHANGED
|
@@ -906,8 +906,7 @@ class MoondreamModel(nn.Module):
|
|
| 906 |
pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
|
| 907 |
include_size: bool = True,
|
| 908 |
max_objects: int = 50,
|
| 909 |
-
lora=None
|
| 910 |
-
):
|
| 911 |
"""
|
| 912 |
Batched decode loop for multi-label detection.
|
| 913 |
|
|
@@ -948,7 +947,7 @@ class MoondreamModel(nn.Module):
|
|
| 948 |
mask[alive, :, pos] = 1
|
| 949 |
logits, hidden = self._decode_one_tok(
|
| 950 |
x_emb,
|
| 951 |
-
mask.unsqueeze(2), # (B,1,1,max_ctx)
|
| 952 |
torch.tensor([pos], device=device, dtype=torch.long),
|
| 953 |
lora,
|
| 954 |
)
|
|
@@ -966,7 +965,7 @@ class MoondreamModel(nn.Module):
|
|
| 966 |
mask[alive, :, pos] = 1
|
| 967 |
logits, hidden = self._decode_one_tok(
|
| 968 |
y_emb,
|
| 969 |
-
mask.unsqueeze(2), # (B,1,1,max_ctx)
|
| 970 |
torch.tensor([pos], device=device, dtype=torch.long),
|
| 971 |
lora,
|
| 972 |
)
|
|
@@ -1028,6 +1027,7 @@ class MoondreamModel(nn.Module):
|
|
| 1028 |
|
| 1029 |
|
| 1030 |
|
|
|
|
| 1031 |
def detect_multi(self, image, objects, settings=None):
|
| 1032 |
"""
|
| 1033 |
Parallel multi-label detection.
|
|
|
|
| 906 |
pos_vec: torch.Tensor, # (B,) next write pos per row after prefill
|
| 907 |
include_size: bool = True,
|
| 908 |
max_objects: int = 50,
|
| 909 |
+
lora=None):
|
|
|
|
| 910 |
"""
|
| 911 |
Batched decode loop for multi-label detection.
|
| 912 |
|
|
|
|
| 947 |
mask[alive, :, pos] = 1
|
| 948 |
logits, hidden = self._decode_one_tok(
|
| 949 |
x_emb,
|
| 950 |
+
mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
|
| 951 |
torch.tensor([pos], device=device, dtype=torch.long),
|
| 952 |
lora,
|
| 953 |
)
|
|
|
|
| 965 |
mask[alive, :, pos] = 1
|
| 966 |
logits, hidden = self._decode_one_tok(
|
| 967 |
y_emb,
|
| 968 |
+
mask.unsqueeze(2), # (B,1,1,max_ctx) ✅
|
| 969 |
torch.tensor([pos], device=device, dtype=torch.long),
|
| 970 |
lora,
|
| 971 |
)
|
|
|
|
| 1027 |
|
| 1028 |
|
| 1029 |
|
| 1030 |
+
|
| 1031 |
def detect_multi(self, image, objects, settings=None):
|
| 1032 |
"""
|
| 1033 |
Parallel multi-label detection.
|