LhatMjnk commited on
Commit
d85f92c
·
verified ·
1 Parent(s): c95ad8c

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +2 -2
inference.py CHANGED
@@ -20,7 +20,7 @@ class CoralSegModel:
20
  rng = np.random.RandomState(0)
21
  self.palette = (rng.randint(0, 255, size=(num_classes, 3))).astype(np.uint8)
22
 
23
- @torch.inference_mode()
24
  def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
25
  """
26
  frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
@@ -30,7 +30,7 @@ class CoralSegModel:
30
  rgb = frame_bgr[:, :, ::-1]
31
  pil = Image.fromarray(rgb)
32
 
33
- inputs = self.processor(images=pil, return_tensors="pt", device=self.device)
34
  outputs = self.model(**inputs)
35
  logits = outputs.logits # [B, C, h, w]
36
  upsampled = torch.nn.functional.interpolate(
 
20
  rng = np.random.RandomState(0)
21
  self.palette = (rng.randint(0, 255, size=(num_classes, 3))).astype(np.uint8)
22
 
23
+ @spaces.GPU
24
  def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
25
  """
26
  frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
 
30
  rgb = frame_bgr[:, :, ::-1]
31
  pil = Image.fromarray(rgb)
32
 
33
+ inputs = self.processor(images=pil, return_tensors="pt", device=self.device).to(device)
34
  outputs = self.model(**inputs)
35
  logits = outputs.logits # [B, C, h, w]
36
  upsampled = torch.nn.functional.interpolate(