Zhen Ye committed on
Commit
af0f84f
·
1 Parent(s): 8bc4370

fix: replaced batch with sequential

Browse files
Files changed (1) hide show
  1. models/detectors/grounding_dino.py +4 -46
models/detectors/grounding_dino.py CHANGED
@@ -103,49 +103,7 @@ class GroundingDinoDetector(ObjectDetector):
103
  return self._parse_single_result(processed_list[0])
104
 
105
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
106
- if not frames:
107
- return []
108
-
109
- import cv2
110
- from types import SimpleNamespace
111
-
112
- frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames]
113
- prompt = self._build_prompt(queries)
114
-
115
- # 1. Preprocess each frame individually (avoids batch processor issues)
116
- individual_inputs = []
117
- for frame in frames_rgb:
118
- inp = self.processor(images=frame, text=prompt, return_tensors="pt")
119
- individual_inputs.append(inp)
120
-
121
- # 2. Stack into batch for GPU forward pass
122
- # All frames are from the same video (same resolution), so tensor shapes match.
123
- # If they don't (edge case), fall back to sequential predict().
124
- batch_keys = list(individual_inputs[0].keys())
125
- try:
126
- batch_inputs = {}
127
- for key in batch_keys:
128
- batch_inputs[key] = torch.cat(
129
- [inp[key] for inp in individual_inputs], dim=0
130
- ).to(self.device)
131
- except RuntimeError:
132
- # Shape mismatch (different resolutions) — fall back to sequential
133
- return [self.predict(f, queries) for f in frames]
134
-
135
- # 3. Batched forward pass (GPU-efficient)
136
- with torch.no_grad():
137
- outputs = self.model(**batch_inputs)
138
-
139
- # 4. Per-frame post-processing using individual (non-batched) input_ids
140
- single_input_ids = individual_inputs[0]["input_ids"].to(self.device)
141
- results = []
142
- for i in range(len(frames)):
143
- frame_outputs = SimpleNamespace(
144
- logits=outputs.logits[i : i + 1],
145
- pred_boxes=outputs.pred_boxes[i : i + 1],
146
- )
147
- target_sizes = torch.tensor([frames[i].shape[:2]], device=self.device)
148
- processed = self._post_process(frame_outputs, single_input_ids, target_sizes)
149
- results.append(self._parse_single_result(processed[0]))
150
-
151
- return results
 
103
  return self._parse_single_result(processed_list[0])
104
 
105
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
106
+ # Grounding DINO's forward pass produces degraded/zero logits at batch_size > 1
107
+ # (known HF issue #32206, #34346). Fall back to sequential single-frame inference
108
+ # which is the only path proven to produce correct detections.
109
+ return [self.predict(f, queries) for f in frames]