LhatMjnk committed on
Commit
6fee271
·
verified ·
1 Parent(s): 2e8fd20

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +69 -10
inference.py CHANGED
@@ -4,23 +4,83 @@ import numpy as np
4
  from PIL import Image
5
  from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  # Load model from HF (swap this with your own if you want)
8
  HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
9
 
10
  class CoralSegModel:
11
  def __init__(self, device=None):
12
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
 
13
  self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)
14
- self.model = SegformerForSemanticSegmentation.from_pretrained(HF_MODEL_ID).to(self.device)
 
 
 
 
 
15
  self.model.eval()
16
 
17
- # Build a simple color palette for masks (fallback if none provided)
18
- # 0..N-1 colors - here random-ish but stable
19
- num_classes = self.model.config.id2label and len(self.model.config.id2label) or 40
20
- rng = np.random.RandomState(0)
21
- self.palette = (rng.randint(0, 255, size=(num_classes, 3))).astype(np.uint8)
22
 
23
- @spaces.GPU
24
  def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
25
  """
26
  frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
@@ -30,7 +90,7 @@ class CoralSegModel:
30
  rgb = frame_bgr[:, :, ::-1]
31
  pil = Image.fromarray(rgb)
32
 
33
- inputs = self.processor(images=pil, return_tensors="pt", device=self.device).to(device)
34
  outputs = self.model(**inputs)
35
  logits = outputs.logits # [B, C, h, w]
36
  upsampled = torch.nn.functional.interpolate(
@@ -40,5 +100,4 @@ class CoralSegModel:
40
 
41
  color_mask = self.palette[pred] # HxWx3 (RGB)
42
  overlay_rgb = (rgb * (1 - alpha) + color_mask * alpha).astype(np.uint8)
43
- overlay_bgr = overlay_rgb[:, :, ::-1]
44
- return overlay_bgr
 
4
  from PIL import Image
5
  from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
6
 
7
+ id2label = {
8
+ '1': 'seagrass',
9
+ '2': 'trash',
10
+ '3': 'other coral dead',
11
+ '4': 'other coral bleached',
12
+ '5': 'sand',
13
+ '6': 'other coral alive',
14
+ '7': 'human',
15
+ '8': 'transect tools',
16
+ '9': 'fish',
17
+ '10': 'algae covered substrate',
18
+ '11': 'other animal',
19
+ '12': 'unknown hard substrate',
20
+ '13': 'background',
21
+ '14': 'dark',
22
+ '15': 'transect line',
23
+ '16': 'massive/meandering bleached',
24
+ '17': 'massive/meandering alive',
25
+ '18': 'rubble',
26
+ '19': 'branching bleached',
27
+ '20': 'branching dead',
28
+ '21': 'millepora',
29
+ '22': 'branching alive',
30
+ '23': 'massive/meandering dead',
31
+ '24': 'clam',
32
+ '25': 'acropora alive',
33
+ '26': 'sea cucumber',
34
+ '27': 'turbinaria',
35
+ '28': 'table acropora alive',
36
+ '29': 'sponge',
37
+ '30': 'anemone',
38
+ '31': 'pocillopora alive',
39
+ '32': 'table acropora dead',
40
+ '33': 'meandering bleached',
41
+ '34': 'stylophora alive',
42
+ '35': 'sea urchin',
43
+ '36': 'meandering alive',
44
+ '37': 'meandering dead',
45
+ '38': 'crown of thorn',
46
+ '39': 'dead clam'
47
+ }
48
+
49
+ label2color= {'human': [255, 0, 0], 'background': [29, 162, 216], 'fish': [255, 255, 0], 'sand': [194, 178, 128], 'rubble': [161, 153, 128], 'unknown hard substrate': [125, 125, 125], 'algae covered substrate': [125, 163, 125], 'dark': [31, 31, 31], 'branching bleached': [252, 231, 240], 'branching dead': [123, 50, 86], 'branching alive': [226, 91, 157], 'stylophora alive': [255, 111, 194], 'pocillopora alive': [255, 146, 150], 'acropora alive': [236, 128, 255], 'table acropora alive': [189, 119, 255], 'table acropora dead': [85, 53, 116], 'millepora': [244, 150, 115], 'turbinaria': [228, 255, 119], 'other coral bleached': [250, 224, 225], 'other coral dead': [114, 60, 61], 'other coral alive': [224, 118, 119], 'massive/meandering alive': [236, 150, 21], 'massive/meandering dead': [134, 86, 18], 'massive/meandering bleached': [255, 248, 228], 'meandering alive': [230, 193, 0], 'meandering dead': [119, 100, 14], 'meandering bleached': [251, 243, 216], 'transect line': [0, 255, 0], 'transect tools': [8, 205, 12], 'sea urchin': [0, 142, 255], 'sea cucumber': [0, 231, 255], 'anemone': [0, 255, 189], 'sponge': [240, 80, 80], 'clam': [189, 255, 234], 'other animal': [0, 255, 255], 'trash': [255, 0, 134], 'seagrass': [125, 222, 125], 'crown of thorn': [179, 245, 234], 'dead clam': [89, 155, 134]} # {'seagrass':[R,G,B],...}
50
+
51
# Helper: build a palette aligned to class indices.
# We assume the model outputs class ids in [0..N-1], where N = len(id2label)
# (39 here), while `id2label` is keyed by 1-based string ids.
# We map index 0 -> id "1", index 1 -> id "2", ..., index 38 -> id "39".
# If your model uses a *different* order, define a custom `index_to_id` list accordingly.
# The order is derived from the table's own keys (sorted numerically) instead of a
# hard-coded range, so adding or removing a class cannot leave the two out of sync.
index_to_id = sorted(id2label, key=int)  # ["1", "2", ..., "39"]
index_to_name = [id2label[class_id] for class_id in index_to_id]
57
+
58
def make_palette(index_to_name, label2color, default_color=(0, 0, 0)):
    """Build a per-class RGB lookup table aligned to class indices.

    Row ``k`` of the result holds the colour of the class named by the
    ``k``-th entry of ``index_to_name``, so an integer class map can be
    coloured by indexing the palette with it.

    Args:
        index_to_name: Iterable of class names; position ``k`` is the name
            of class index ``k``.
        label2color: Mapping from class name to a 3-element ``[R, G, B]``
            sequence.
        default_color: Colour used for names that are missing from
            ``label2color``. Defaults to black, which matches the previous
            hard-coded fallback.

    Returns:
        ``np.ndarray`` of shape ``(number of names, 3)`` and dtype ``uint8``.
    """
    # Materialise once so generators/iterators are accepted as well as lists.
    names = list(index_to_name)
    palette = np.zeros((len(names), 3), dtype=np.uint8)
    for k, name in enumerate(names):
        palette[k] = np.asarray(label2color.get(name, default_color), dtype=np.uint8)
    return palette
64
+
65
  # Load model from HF (swap this with your own if you want)
66
  HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
67
 
68
  class CoralSegModel:
69
  def __init__(self, device=None):
70
  self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
71
+
72
  self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)
73
+
74
+ self.model = SegformerForSemanticSegmentation.from_pretrained(
75
+ HF_MODEL_ID,
76
+ dtype=torch.bfloat16
77
+ ).to(self.device)
78
+
79
  self.model.eval()
80
 
81
+ self.palette = make_palette(index_to_name, label2color)
 
 
 
 
82
 
83
+ @torch.inference_mode()
84
  def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
85
  """
86
  frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
 
90
  rgb = frame_bgr[:, :, ::-1]
91
  pil = Image.fromarray(rgb)
92
 
93
+ inputs = self.processor(images=pil, return_tensors="pt", device=self.device).to(self.device, torch.bfloat16)
94
  outputs = self.model(**inputs)
95
  logits = outputs.logits # [B, C, h, w]
96
  upsampled = torch.nn.functional.interpolate(
 
100
 
101
  color_mask = self.palette[pred] # HxWx3 (RGB)
102
  overlay_rgb = (rgb * (1 - alpha) + color_mask * alpha).astype(np.uint8)
103
+ return overlay_rgb