fix bounding box detection not accounting for batch dimension
Browse files
- modeling_moondream3.py +7 -4
modeling_moondream3.py
CHANGED
|
@@ -1147,12 +1147,15 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
|
|
| 1147 |
torch.stack([w, h],dim=-1).to(size_logits.dtype)
|
| 1148 |
)
|
| 1149 |
).unsqueeze(1)
|
|
|
|
|
|
|
| 1150 |
bbox = [
|
| 1151 |
-
x_center
|
| 1152 |
-
y_center
|
| 1153 |
-
x_center
|
| 1154 |
-
y_center
|
| 1155 |
]
|
|
|
|
| 1156 |
bbox = bbox * (batch_mask).unsqueeze(1)
|
| 1157 |
pos_ids += 1
|
| 1158 |
cache_pos = cache_pos + 1
|
|
|
|
| 1147 |
torch.stack([w, h],dim=-1).to(size_logits.dtype)
|
| 1148 |
)
|
| 1149 |
).unsqueeze(1)
|
| 1150 |
+
x_center = x_center.squeeze(1)
|
| 1151 |
+
y_center = y_center.squeeze(1)
|
| 1152 |
bbox = [
|
| 1153 |
+
x_center - w / 2,
|
| 1154 |
+
y_center - h / 2,
|
| 1155 |
+
x_center + w / 2,
|
| 1156 |
+
y_center + h / 2,
|
| 1157 |
]
|
| 1158 |
+
bbox = torch.stack(bbox, dim=1)
|
| 1159 |
bbox = bbox * (batch_mask).unsqueeze(1)
|
| 1160 |
pos_ids += 1
|
| 1161 |
cache_pos = cache_pos + 1
|