NyxKrage commited on
Commit
1101625
·
verified ·
1 Parent(s): 2fa4720

fix bounding box detection not accounting for batch dimension

Browse files
Files changed (1) hide show
  1. modeling_moondream3.py +7 -4
modeling_moondream3.py CHANGED
@@ -1147,12 +1147,15 @@ class Moondream3ForConditionalGeneration(Moondream3PreTrainedModel, GenerationMi
1147
  torch.stack([w, h],dim=-1).to(size_logits.dtype)
1148
  )
1149
  ).unsqueeze(1)
 
 
1150
  bbox = [
1151
- x_center.item() - w.item() / 2,
1152
- y_center.item() - h.item() / 2,
1153
- x_center.item() + w.item() / 2,
1154
- y_center.item() + h.item() / 2,
1155
  ]
 
1156
  bbox = bbox * (batch_mask).unsqueeze(1)
1157
  pos_ids += 1
1158
  cache_pos = cache_pos + 1
 
1147
  torch.stack([w, h],dim=-1).to(size_logits.dtype)
1148
  )
1149
  ).unsqueeze(1)
1150
+ x_center = x_center.squeeze(1)
1151
+ y_center = y_center.squeeze(1)
1152
  bbox = [
1153
+ x_center - w / 2,
1154
+ y_center - h / 2,
1155
+ x_center + w / 2,
1156
+ y_center + h / 2,
1157
  ]
1158
+ bbox = torch.stack(bbox, dim=1)
1159
  bbox = bbox * (batch_mask).unsqueeze(1)
1160
  pos_ids += 1
1161
  cache_pos = cache_pos + 1