LhatMjnk commited on
Commit
c95ad8c
·
verified ·
1 Parent(s): ae30280

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +3 -3
inference.py CHANGED
@@ -2,7 +2,7 @@
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
- from transformers import SegformerImageProcessor, SegformerForSemanticSegmentation
6
 
7
  # Load model from HF (swap this with your own if you want)
8
  HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
@@ -10,7 +10,7 @@ HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
10
  class CoralSegModel:
11
  def __init__(self, device=None):
12
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
13
- self.processor = SegformerImageProcessor.from_pretrained(HF_MODEL_ID)
14
  self.model = SegformerForSemanticSegmentation.from_pretrained(HF_MODEL_ID).to(self.device)
15
  self.model.eval()
16
 
@@ -30,7 +30,7 @@ class CoralSegModel:
30
  rgb = frame_bgr[:, :, ::-1]
31
  pil = Image.fromarray(rgb)
32
 
33
- inputs = self.processor(images=pil, return_tensors="pt").to(self.device)
34
  outputs = self.model(**inputs)
35
  logits = outputs.logits # [B, C, h, w]
36
  upsampled = torch.nn.functional.interpolate(
 
2
  import torch
3
  import numpy as np
4
  from PIL import Image
5
+ from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
6
 
7
  # Load model from HF (swap this with your own if you want)
8
  HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
 
10
  class CoralSegModel:
11
  def __init__(self, device=None):
12
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
13
+ self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)
14
  self.model = SegformerForSemanticSegmentation.from_pretrained(HF_MODEL_ID).to(self.device)
15
  self.model.eval()
16
 
 
30
  rgb = frame_bgr[:, :, ::-1]
31
  pil = Image.fromarray(rgb)
32
 
33
+ inputs = self.processor(images=pil, return_tensors="pt", device=self.device)
34
  outputs = self.model(**inputs)
35
  logits = outputs.logits # [B, C, h, w]
36
  upsampled = torch.nn.functional.interpolate(