Update inference.py

inference.py CHANGED (+125 -77)
@@ -1,69 +1,72 @@
 # inference.py
 import torch
+import torch.nn.functional as F
+
+import json
+import urllib.request
+import cv2
+import spaces  # provides the @spaces.GPU decorator used below
 import numpy as np
 from PIL import Image
 from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
 
-id2label = {
-    '1': …,
-    '2': …,
-    '3': …,
-    '4': 'other coral bleached',
-    '5': 'sand',
-    '6': 'other coral alive',
-    '7': 'human',
-    '8': 'transect tools',
-    '9': 'fish',
-    '10': 'algae covered substrate',
-    '11': 'other animal',
-    '12': 'unknown hard substrate',
-    '13': 'background',
-    '14': 'dark',
-    '15': 'transect line',
-    '16': 'massive/meandering bleached',
-    '17': 'massive/meandering alive',
-    '18': 'rubble',
-    '19': 'branching bleached',
-    '20': 'branching dead',
-    '21': 'millepora',
-    '22': 'branching alive',
-    '23': 'massive/meandering dead',
-    '24': 'clam',
-    '25': 'acropora alive',
-    '26': 'sea cucumber',
-    '27': 'turbinaria',
-    '28': 'table acropora alive',
-    '29': 'sponge',
-    '30': 'anemone',
-    '31': 'pocillopora alive',
-    '32': 'table acropora dead',
-    '33': 'meandering bleached',
-    '34': 'stylophora alive',
-    '35': 'sea urchin',
-    '36': 'meandering alive',
-    '37': 'meandering dead',
-    '38': 'crown of thorn',
-    '39': 'dead clam'
-}
-
-label2color = {'human': [255, 0, 0], 'background': [29, 162, 216], 'fish': [255, 255, 0], 'sand': [194, 178, 128], 'rubble': [161, 153, 128], 'unknown hard substrate': [125, 125, 125], 'algae covered substrate': [125, 163, 125], 'dark': [31, 31, 31], 'branching bleached': [252, 231, 240], 'branching dead': [123, 50, 86], 'branching alive': [226, 91, 157], 'stylophora alive': [255, 111, 194], 'pocillopora alive': [255, 146, 150], 'acropora alive': [236, 128, 255], 'table acropora alive': [189, 119, 255], 'table acropora dead': [85, 53, 116], 'millepora': [244, 150, 115], 'turbinaria': [228, 255, 119], 'other coral bleached': [250, 224, 225], 'other coral dead': [114, 60, 61], 'other coral alive': [224, 118, 119], 'massive/meandering alive': [236, 150, 21], 'massive/meandering dead': [134, 86, 18], 'massive/meandering bleached': [255, 248, 228], 'meandering alive': [230, 193, 0], 'meandering dead': [119, 100, 14], 'meandering bleached': [251, 243, 216], 'transect line': [0, 255, 0], 'transect tools': [8, 205, 12], 'sea urchin': [0, 142, 255], 'sea cucumber': [0, 231, 255], 'anemone': [0, 255, 189], 'sponge': [240, 80, 80], 'clam': [189, 255, 234], 'other animal': [0, 255, 255], 'trash': [255, 0, 134], 'seagrass': [125, 222, 125], 'crown of thorn': [179, 245, 234], 'dead clam': [89, 155, 134]}  # {'seagrass':[R,G,B],...}
-
-# Helper: build a palette aligned to class indices
-# We assume your model outputs class ids in [0..38].
-# We map index 0 -> id "1", index 1 -> id "2", ..., index 38 -> id "39".
-# If your model uses a *different* order, define a custom `index_to_id` list accordingly.
-index_to_id = [str(i) for i in range(1, 40)]  # ["1","2",...,"39"]
-index_to_name = [id2label[i] for i in index_to_id]
-
-def make_palette(index_to_name, label2color):
-    palette = np.zeros((len(index_to_name), 3), dtype=np.uint8)
-    for k, name in enumerate(index_to_name):
-        rgb = label2color.get(name, [0, 0, 0])
-        palette[k] = np.array(rgb, dtype=np.uint8)
-    return palette
+id2label = json.load(urllib.request.urlopen(
+    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/id2label.json"))
+label2color = json.load(urllib.request.urlopen(
+    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/label2color.json"))
 
 # Load model from HF (swap this with your own if you want)
-HF_MODEL_ID = "EPFL-ECEO/segformer-
+HF_MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"
+
+def create_segmentation_overlay(pred, id2label, label2color, image, alpha=0.25):
+    """
+    Colorizes the segmentation prediction and creates an overlay image.
+
+    Args:
+        pred: The segmentation prediction (HxW numpy array of class IDs).
+        id2label: Dictionary mapping class IDs to labels.
+        label2color: Dictionary mapping labels to RGB colors.
+        image: The original PIL Image.
+        alpha: Blend factor (0 = original image only, 1 = mask only).
+
+    Returns:
+        A PIL Image overlaying the colorized segmentation mask on the original image.
+    """
+    H, W = pred.shape
+    rgb = np.zeros((H, W, 3), dtype=np.uint8)
+
+    # Get unique class IDs present in the prediction
+    unique_classes = np.unique(pred)
+
+    # Create a mapping from class ID to color
+    id2color = {int(id): label2color[label] for id, label in id2label.items()}
+
+    # Define a default color for unknown classes (e.g., black)
+    default_color = [0, 0, 0]
+
+    # Iterate through unique class IDs and colorize the image
+    for class_id in unique_classes:
+        # Get the color for the current class ID, use default_color if not found
+        rgb_c = id2color.get(int(class_id), default_color)
+        # Assign the color to the pixels with the current class ID
+        rgb[pred == class_id] = rgb_c
+
+    mask_rgb = Image.fromarray(rgb)
+
+    # Alpha overlay
+    overlay = Image.blend(image.convert("RGBA"), mask_rgb.convert("RGBA"), alpha=alpha)
+
+    return overlay
+
+def resize_image(image, target_size=1024):
+    """
+    Resize the image so that its smaller side equals target_size.
+    """
+    w_img, h_img = image.size  # PIL's size is (width, height)
+    if w_img < h_img:
+        new_w, new_h = target_size, int(h_img * (target_size / w_img))
+    else:
+        new_w, new_h = int(w_img * (target_size / h_img)), target_size
+    resized_img = image.resize((new_w, new_h))  # PIL's resize also takes (width, height)
+    return resized_img
 
 class CoralSegModel:
     def __init__(self, device=None):
@@ -78,26 +81,71 @@ class CoralSegModel:
 
         self.model.eval()
 
-
+    @spaces.GPU
+    def segment_image(self, image, preprocessor, model, crop_size=(1024, 1024), num_classes=40) -> np.ndarray:
+        """
+        Finds an optimal stride based on the image size and aspect ratio to create
+        overlapping sliding windows of size 1024x1024, which are then fed into the model.
+        """
+        h_crop, w_crop = crop_size
+
+        img = torch.Tensor(np.array(resize_image(image, target_size=1024)).transpose(2, 0, 1)).unsqueeze(0)
+        img = img.to(self.device, torch.bfloat16)
+        batch_size, _, h_img, w_img = img.size()
+
+        h_grids = int(np.round(3/2 * h_img / h_crop)) if h_img > h_crop else 1
+        w_grids = int(np.round(3/2 * w_img / w_crop)) if w_img > w_crop else 1
+
+        h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop
+        w_stride = int((w_img - w_crop + w_grids - 1) / (w_grids - 1)) if w_grids > 1 else w_crop
+
+        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
+        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))
 
-
-
+        for h_idx in range(h_grids):
+            for w_idx in range(w_grids):
+                y1 = h_idx * h_stride
+                x1 = w_idx * w_stride
+                y2 = min(y1 + h_crop, h_img)
+                x2 = min(x1 + w_crop, w_img)
+                y1 = max(y2 - h_crop, 0)
+                x1 = max(x2 - w_crop, 0)
+                crop_img = img[:, :, y1:y2, x1:x2]
+                with torch.no_grad():
+                    if preprocessor:
+                        inputs = preprocessor(crop_img, return_tensors="pt", device=self.device)
+                        inputs["pixel_values"] = inputs["pixel_values"].to(self.device, torch.bfloat16)
+                        outputs = model.to(self.device)(**inputs)
+                    else:
+                        # No preprocessor: feed the crop directly as pixel_values
+                        outputs = model.to(self.device)(pixel_values=crop_img.to(self.device, torch.bfloat16))
+
+                resized_logits = F.interpolate(
+                    outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
+                )
+                # Paste the window's logits back at (x1, y1) and track per-pixel coverage
+                preds += F.pad(resized_logits,
+                               (int(x1), int(preds.shape[3] - x2), int(y1),
+                                int(preds.shape[2] - y2)))
+                count_mat[:, :, y1:y2, x1:x2] += 1
+
+        assert (count_mat == 0).sum() == 0, "every pixel must be covered by at least one window"
+        preds = preds / count_mat
+        preds = preds.argmax(dim=1)
+        preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode="nearest")
+        label_pred = preds.squeeze().cpu().numpy()
+        return label_pred
+
+    @spaces.GPU
+    def predict_map_and_overlay(self, frame_bgr: np.ndarray):
         """
-
-
+        Returns:
+            pred: HxW numpy array of class indices
+            overlay: PIL Image (RGBA) of the color mask blended over the original frame
+            rgb: HxWx3 RGB uint8 original frame (for AnnotatedImage base)
         """
-
-        rgb = frame_bgr[:, :, ::-1]
-        pil = Image.fromarray(rgb)
+        rgb = frame_bgr  # frame already arrives in RGB order; no BGR->RGB flip needed
 
-
-
-
-
-
-        )
-        pred = upsampled.argmax(dim=1)[0].detach().cpu().numpy().astype(np.uint8)  # HxW
-
-        color_mask = self.palette[pred]  # HxWx3 (RGB)
-        overlay_rgb = (rgb * (1 - alpha) + color_mask * alpha).astype(np.uint8)
-        return overlay_rgb
+        pil = Image.fromarray(rgb)
+        pred = self.segment_image(pil, self.processor, self.model)
+        overlay_rgb = create_segmentation_overlay(pred, id2label, label2color, pil, 0.5)
+
+        return pred, overlay_rgb, rgb
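To make the sliding-window arithmetic in `segment_image` concrete: the code places roughly `3/2 * image_size / crop_size` windows per axis and spaces them so the last window ends exactly at the image border, guaranteeing full coverage with overlap. A small standalone sketch (illustrative only, not part of the commit), using the same formulas for one axis of a 1536-px-tall resized frame:

import numpy as np

h_img, h_crop = 1536, 1024
h_grids = int(np.round(3/2 * h_img / h_crop)) if h_img > h_crop else 1                      # -> 2
h_stride = int((h_img - h_crop + h_grids - 1) / (h_grids - 1)) if h_grids > 1 else h_crop   # -> 513

# Window start rows after the clamping done inside the loop
starts = [max(min(i * h_stride + h_crop, h_img) - h_crop, 0) for i in range(h_grids)]
print(starts)  # [0, 512]: windows [0, 1024) and [512, 1536) overlap by 512 px

Because every pixel falls inside at least one window (the `assert` on `count_mat` checks exactly this), dividing the summed logits by `count_mat` averages the overlapping predictions before the final argmax.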
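For context, a minimal driver sketch showing how the updated API might be called from the Space's app code. This is hypothetical (not part of this commit): it assumes `CoralSegModel.__init__` (the unchanged lines 72-80) loads `HF_MODEL_ID` into `self.model` and `self.processor`, and that frames are supplied as RGB uint8 arrays; the filename `reef.jpg` is illustrative.

import numpy as np
from PIL import Image
from inference import CoralSegModel

model = CoralSegModel()  # device is picked automatically when device=None

frame = np.array(Image.open("reef.jpg").convert("RGB"))   # HxWx3 uint8, RGB order
pred, overlay, rgb = model.predict_map_and_overlay(frame)

print(pred.shape, np.unique(pred))  # HxW map of Coralscapes class IDs
overlay.save("reef_overlay.png")    # PIL RGBA blend of the color mask over the frame

`pred` and `rgb` are suited to feeding a Gradio AnnotatedImage (the docstring's stated purpose), while `overlay` is ready to display as-is.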