LhatMjnk committed on
Commit
14ebdf3
·
verified ·
1 Parent(s): 6fee271

Update inference.py

Browse files
Files changed (1) hide show
  1. inference.py +125 -77
inference.py CHANGED
@@ -1,69 +1,72 @@
1
  # inference.py
2
  import torch
 
 
 
 
 
3
  import numpy as np
4
  from PIL import Image
5
  from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
6
 
7
- id2label = {
8
- '1': 'seagrass',
9
- '2': 'trash',
10
- '3': 'other coral dead',
11
- '4': 'other coral bleached',
12
- '5': 'sand',
13
- '6': 'other coral alive',
14
- '7': 'human',
15
- '8': 'transect tools',
16
- '9': 'fish',
17
- '10': 'algae covered substrate',
18
- '11': 'other animal',
19
- '12': 'unknown hard substrate',
20
- '13': 'background',
21
- '14': 'dark',
22
- '15': 'transect line',
23
- '16': 'massive/meandering bleached',
24
- '17': 'massive/meandering alive',
25
- '18': 'rubble',
26
- '19': 'branching bleached',
27
- '20': 'branching dead',
28
- '21': 'millepora',
29
- '22': 'branching alive',
30
- '23': 'massive/meandering dead',
31
- '24': 'clam',
32
- '25': 'acropora alive',
33
- '26': 'sea cucumber',
34
- '27': 'turbinaria',
35
- '28': 'table acropora alive',
36
- '29': 'sponge',
37
- '30': 'anemone',
38
- '31': 'pocillopora alive',
39
- '32': 'table acropora dead',
40
- '33': 'meandering bleached',
41
- '34': 'stylophora alive',
42
- '35': 'sea urchin',
43
- '36': 'meandering alive',
44
- '37': 'meandering dead',
45
- '38': 'crown of thorn',
46
- '39': 'dead clam'
47
- }
48
-
49
- label2color= {'human': [255, 0, 0], 'background': [29, 162, 216], 'fish': [255, 255, 0], 'sand': [194, 178, 128], 'rubble': [161, 153, 128], 'unknown hard substrate': [125, 125, 125], 'algae covered substrate': [125, 163, 125], 'dark': [31, 31, 31], 'branching bleached': [252, 231, 240], 'branching dead': [123, 50, 86], 'branching alive': [226, 91, 157], 'stylophora alive': [255, 111, 194], 'pocillopora alive': [255, 146, 150], 'acropora alive': [236, 128, 255], 'table acropora alive': [189, 119, 255], 'table acropora dead': [85, 53, 116], 'millepora': [244, 150, 115], 'turbinaria': [228, 255, 119], 'other coral bleached': [250, 224, 225], 'other coral dead': [114, 60, 61], 'other coral alive': [224, 118, 119], 'massive/meandering alive': [236, 150, 21], 'massive/meandering dead': [134, 86, 18], 'massive/meandering bleached': [255, 248, 228], 'meandering alive': [230, 193, 0], 'meandering dead': [119, 100, 14], 'meandering bleached': [251, 243, 216], 'transect line': [0, 255, 0], 'transect tools': [8, 205, 12], 'sea urchin': [0, 142, 255], 'sea cucumber': [0, 231, 255], 'anemone': [0, 255, 189], 'sponge': [240, 80, 80], 'clam': [189, 255, 234], 'other animal': [0, 255, 255], 'trash': [255, 0, 134], 'seagrass': [125, 222, 125], 'crown of thorn': [179, 245, 234], 'dead clam': [89, 155, 134]} # {'seagrass':[R,G,B],...}
50
-
51
- # Helper: build a palette aligned to class indices
52
- # We assume your model outputs class ids in [0..38].
53
- # We map index 0-> id "1", index 1-> id "2", ..., index 38-> id "39".
54
- # If your model uses a *different* order, define a custom `index_to_id` list accordingly.
55
- index_to_id = [str(i) for i in range(1, 40)] # ["1","2",...,"39"]
56
- index_to_name = [id2label[i] for i in index_to_id]
57
-
58
- def make_palette(index_to_name, label2color):
59
- palette = np.zeros((len(index_to_name), 3), dtype=np.uint8)
60
- for k, name in enumerate(index_to_name):
61
- rgb = label2color.get(name, [0, 0, 0])
62
- palette[k] = np.array(rgb, dtype=np.uint8)
63
- return palette
64
 
65
  # Load model from HF (swap this with your own if you want)
66
- HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  class CoralSegModel:
69
  def __init__(self, device=None):
@@ -78,26 +81,71 @@ class CoralSegModel:
78
 
79
  self.model.eval()
80
 
81
- self.palette = make_palette(index_to_name, label2color)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
82
 
83
- @torch.inference_mode()
84
- def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
85
  """
86
- frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
87
- returns: np.ndarray HxWx3 in BGR (overlay)
 
 
88
  """
89
- # Convert BGR -> RGB PIL
90
- rgb = frame_bgr[:, :, ::-1]
91
- pil = Image.fromarray(rgb)
92
 
93
- inputs = self.processor(images=pil, return_tensors="pt", device=self.device).to(self.device, torch.bfloat16)
94
- outputs = self.model(**inputs)
95
- logits = outputs.logits # [B, C, h, w]
96
- upsampled = torch.nn.functional.interpolate(
97
- logits, size=pil.size[::-1], mode="bilinear", align_corners=False
98
- )
99
- pred = upsampled.argmax(dim=1)[0].detach().cpu().numpy().astype(np.uint8) # HxW
100
-
101
- color_mask = self.palette[pred] # HxWx3 (RGB)
102
- overlay_rgb = (rgb * (1 - alpha) + color_mask * alpha).astype(np.uint8)
103
- return overlay_rgb
 
1
  # inference.py
2
  import torch
3
+ import torch.nn.functional as F
4
+
5
+ import json
6
+ import urllib.request
7
+ import cv2
8
  import numpy as np
9
  from PIL import Image
10
  from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
11
 
12
# Class-id -> label-name and label-name -> [R, G, B] color mappings, fetched
# from the CoralScapes dataset repo on the Hugging Face Hub.
# NOTE(review): this performs blocking network I/O at import time, with no
# timeout and no error handling — importing this module fails when offline.
# Consider caching these JSON files locally or wrapping with a fallback.
id2label = json.load(urllib.request.urlopen(
    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/id2label.json"))
label2color = json.load(urllib.request.urlopen(
    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/label2color.json"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
 
17
# Hugging Face Hub id of the segmentation checkpoint to load
# (swap this with your own fine-tuned model if you want).
HF_MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"
19
+
20
def create_segmentation_overlay(pred, id2label, label2color, image, alpha=0.25):
    """
    Colorize a segmentation prediction and blend it over the original image.

    Args:
        pred: HxW numpy array of integer class IDs.
        id2label: dict mapping class IDs (keys convertible to int) to label names.
        label2color: dict mapping label names to [R, G, B] colors in 0-255.
        image: the original PIL Image (same HxW as ``pred``).
        alpha: blend weight of the color mask in [0, 1]; 0 keeps only the
            original image, 1 keeps only the mask.

    Returns:
        A PIL RGBA Image: the original image alpha-blended with the colorized
        segmentation mask.
    """
    H, W = pred.shape
    rgb = np.zeros((H, W, 3), dtype=np.uint8)

    # Default color (black) for class IDs without a label, or labels
    # without a color entry.
    default_color = [0, 0, 0]

    # Map class ID -> color. Tolerate labels missing from label2color:
    # the previous version raised KeyError on any label without a color.
    id2color = {int(id_): label2color.get(label, default_color)
                for id_, label in id2label.items()}

    # Colorize only the classes actually present in the prediction.
    for class_id in np.unique(pred):
        rgb[pred == class_id] = id2color.get(int(class_id), default_color)

    mask_rgb = Image.fromarray(rgb)

    # Alpha overlay of the mask over the original image.
    return Image.blend(image.convert("RGBA"), mask_rgb.convert("RGBA"), alpha=alpha)
58
+
59
def resize_image(image, target_size=1024):
    """
    Resize ``image`` so its smaller side equals ``target_size``,
    preserving the aspect ratio.
    """
    # PIL's Image.size is (width, height); keep that order end to end.
    width, height = image.size
    scale_by_width = width < height
    if scale_by_width:
        new_width = target_size
        new_height = int(height * (target_size / width))
    else:
        new_width = int(width * (target_size / height))
        new_height = target_size
    return image.resize((new_width, new_height))
70
 
71
  class CoralSegModel:
72
  def __init__(self, device=None):
 
81
 
82
  self.model.eval()
83
 
84
    # NOTE(review): @spaces.GPU needs `import spaces` (HF Spaces), which is not
    # visible among this file's imports — confirm it is imported elsewhere.
    @spaces.GPU
    def segment_image(self, image, preprocessor, model, crop_size = (1024, 1024), num_classes = 40) -> np.ndarray:
        """
        Sliding-window semantic segmentation of one image.

        Finds an optimal stride based on the image size and aspect ratio to
        create overlapping sliding windows of size ``crop_size`` which are then
        fed into the model; overlapping logits are averaged.

        Args:
            image: PIL Image; internally resized so its smaller side is 1024.
            preprocessor: HF image processor; if falsy, raw pixel crops are
                fed to the model instead.
            model: segmentation model producing ``outputs.logits``.
            crop_size: (height, width) of each sliding window.
            num_classes: channel count of the logits accumulator.

        Returns:
            HxW numpy array of class indices at the ORIGINAL image resolution.
        """
        h_crop, w_crop = crop_size

        # HWC uint8 image -> 1xCxHxW tensor on the target device in bfloat16.
        img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0)
        img = img.to(self.device, torch.bfloat16)
        batch_size, _, h_img, w_img = img.size()

        # Windows per axis: ~1.5x the non-overlapping count, so neighbouring
        # crops overlap by roughly one third.
        h_grids = int(np.round(3/2*h_img/h_crop)) if h_img > h_crop else 1
        w_grids = int(np.round(3/2*w_img/w_crop)) if w_img > w_crop else 1

        # Stride chosen so the last window ends at the image border.
        h_stride = int((h_img - h_crop + h_grids -1)/(h_grids -1)) if h_grids > 1 else h_crop
        w_stride = int((w_img - w_crop + w_grids -1)/(w_grids -1)) if w_grids > 1 else w_crop

        # Accumulators: summed logits and per-pixel window coverage counts.
        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                # Shift the window back so it is always exactly crop-sized,
                # even at the right/bottom edges.
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                with torch.no_grad():
                    if(preprocessor):
                        inputs = preprocessor(crop_img, return_tensors = "pt", device=self.device)
                        inputs["pixel_values"] = inputs["pixel_values"].to(self.device, torch.bfloat16)
                    else:
                        # NOTE(review): this branch leaves `inputs` a raw
                        # tensor, which `model(**inputs)` below cannot
                        # **-unpack — confirm the no-preprocessor path is
                        # ever exercised.
                        inputs = crop_img.to(self.device, torch.bfloat16)
                    outputs = model.to(self.device)(**inputs)

                # Upsample crop logits back to crop resolution, then paste
                # them into the full-size accumulator via zero padding.
                resized_logits = F.interpolate(
                    outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
                )
                preds += F.pad(resized_logits,
                               (int(x1), int(preds.shape[3] - x2), int(y1),
                                int(preds.shape[2] - y2)))
                count_mat[:, :, y1:y2, x1:x2] += 1

        # Every pixel must be covered by at least one window.
        assert (count_mat == 0).sum() == 0
        # Average logits where windows overlapped, then take the argmax map.
        preds = preds / count_mat
        preds = preds.argmax(dim=1)
        # Nearest-neighbour resize of the label map back to the original
        # image size (PIL .size is (W, H), hence the [::-1]).
        preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode='nearest')
        label_pred = preds.squeeze().cpu().numpy()
        return label_pred
136
+
137
    @spaces.GPU
    def predict_map_and_overlay(self, frame_bgr: np.ndarray):
        """
        Run segmentation on one frame and build a colorized overlay.

        Args:
            frame_bgr: HxWx3 uint8 frame. NOTE(review): despite the name, the
                array is used as-is — no BGR->RGB channel swap happens here.
                Confirm the caller already supplies RGB.

        Returns:
            pred_map: HxW array with class indices.
            overlay: blended image from create_segmentation_overlay — per that
                helper this is a PIL RGBA Image, not an HxWx3 uint8 array.
            rgb: the input frame, unchanged (e.g. as an AnnotatedImage base).
        """
        # Used directly; no channel reordering is performed.
        rgb = frame_bgr

        pil = Image.fromarray(rgb)
        # Full sliding-window inference with the model's own processor.
        pred = self.segment_image(pil, self.processor, self.model)
        # Blend the colorized mask at alpha=0.5 (stronger than the helper's
        # 0.25 default).
        overlay_rgb = create_segmentation_overlay(pred, id2label, label2color, pil, 0.5)

        return pred, overlay_rgb, rgb