Zhen Ye committed on
Commit
f78d96f
·
1 Parent(s): 06e44d3

Fix BGR to RGB conversion for DETR and GroundingDino inference

Browse files
models/detectors/detr.py CHANGED
@@ -44,7 +44,11 @@ class DetrDetector(ObjectDetector):
44
  )
45
 
46
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
47
- inputs = self.processor(images=frame, return_tensors="pt")
 
 
 
 
48
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
49
  with torch.no_grad():
50
  outputs = self.model(**inputs)
@@ -57,7 +61,11 @@ class DetrDetector(ObjectDetector):
57
  return self._parse_single_result(processed)
58
 
59
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
60
- inputs = self.processor(images=frames, return_tensors="pt", padding=True)
 
 
 
 
61
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
62
 
63
  with torch.no_grad():
 
44
  )
45
 
46
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
47
+ # OpenCV frames are BGR, model expects RGB
48
+ import cv2
49
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
50
+
51
+ inputs = self.processor(images=frame_rgb, return_tensors="pt")
52
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
53
  with torch.no_grad():
54
  outputs = self.model(**inputs)
 
61
  return self._parse_single_result(processed)
62
 
63
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
64
+ # OpenCV frames are BGR, model expects RGB
65
+ import cv2
66
+ frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames]
67
+
68
+ inputs = self.processor(images=frames_rgb, return_tensors="pt", padding=True)
69
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
70
 
71
  with torch.no_grad():
models/detectors/grounding_dino.py CHANGED
@@ -74,8 +74,12 @@ class GroundingDinoDetector(ObjectDetector):
74
  )
75
 
76
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
 
 
 
 
77
  prompt = self._build_prompt(queries)
78
- inputs = self.processor(images=frame, text=prompt, return_tensors="pt")
79
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
80
  with torch.no_grad():
81
  outputs = self.model(**inputs)
@@ -84,9 +88,13 @@ class GroundingDinoDetector(ObjectDetector):
84
  return self._parse_single_result(processed_list[0])
85
 
86
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
 
 
 
 
87
  prompt = self._build_prompt(queries)
88
  # Same prompt for all frames in batch
89
- inputs = self.processor(images=frames, text=[prompt]*len(frames), return_tensors="pt", padding=True)
90
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
91
 
92
  with torch.no_grad():
 
74
  )
75
 
76
  def predict(self, frame: np.ndarray, queries: Sequence[str]) -> DetectionResult:
77
+ # OpenCV frames are BGR, model expects RGB
78
+ import cv2
79
+ frame_rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
80
+
81
  prompt = self._build_prompt(queries)
82
+ inputs = self.processor(images=frame_rgb, text=prompt, return_tensors="pt")
83
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
84
  with torch.no_grad():
85
  outputs = self.model(**inputs)
 
88
  return self._parse_single_result(processed_list[0])
89
 
90
  def predict_batch(self, frames: Sequence[np.ndarray], queries: Sequence[str]) -> Sequence[DetectionResult]:
91
+ # OpenCV frames are BGR, model expects RGB
92
+ import cv2
93
+ frames_rgb = [cv2.cvtColor(f, cv2.COLOR_BGR2RGB) for f in frames]
94
+
95
  prompt = self._build_prompt(queries)
96
  # Same prompt for all frames in batch
97
+ inputs = self.processor(images=frames_rgb, text=[prompt]*len(frames), return_tensors="pt", padding=True)
98
  inputs = {key: value.to(self.device) for key, value in inputs.items()}
99
 
100
  with torch.no_grad():