Spaces:
Sleeping
Sleeping
File size: 6,053 Bytes
5311f6e 14ebdf3 36ab196 14ebdf3 5311f6e c95ad8c 5311f6e 14ebdf3 6fee271 5311f6e 14ebdf3 5311f6e 6fee271 c95ad8c 6fee271 5311f6e 14ebdf3 5311f6e 14ebdf3 5311f6e 14ebdf3 5311f6e 14ebdf3 5311f6e 14ebdf3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 |
# inference.py
import torch
import torch.nn.functional as F
import spaces
import json
import urllib.request
import cv2
import numpy as np
from PIL import Image
from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
def _load_json_from_url(url, timeout=30):
    """Download ``url`` and return its body parsed as JSON.

    The HTTP response is used as a context manager so the connection is closed
    as soon as parsing finishes, and a timeout keeps a stalled connection from
    blocking module import indefinitely.

    Args:
        url: Address of the JSON document to fetch.
        timeout: Seconds to wait on the connection before giving up.

    Returns:
        The decoded JSON value.
    """
    with urllib.request.urlopen(url, timeout=timeout) as response:
        return json.load(response)


# Lookup tables published with the Coralscapes dataset, fetched once at import
# time: class id -> label name, and label name -> color.
id2label = _load_json_from_url(
    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/id2label.json")
label2color = _load_json_from_url(
    "https://huggingface.co/datasets/EPFL-ECEO/coralscapes/resolve/main/label2color.json")

# Load model from HF (swap this with your own if you want)
HF_MODEL_ID = "EPFL-ECEO/segformer-b5-finetuned-coralscapes-1024-1024"
def create_segmentation_overlay(pred, id2label, label2color, image, alpha=0.25):
    """
    Colorizes the segmentation prediction and creates an overlay image.

    Args:
        pred: The segmentation prediction, a 2-D (H, W) numpy array of class IDs.
        id2label: Dictionary mapping class IDs to labels. Keys may be ints or
            numeric strings (as produced by JSON); they are normalized with int().
        label2color: Dictionary mapping labels to colors.
        image: The original PIL Image. It is blended with the mask, so it is
            expected to have the same width and height as ``pred``.
        alpha: Blend factor passed to ``Image.blend``; 0.0 returns the original
            image, 1.0 returns only the color mask.

    Returns:
        A PIL Image (mode "RGBA") representing the overlay of the original image
        and the colorized segmentation mask.
    """
    height, width = pred.shape
    rgb = np.zeros((height, width, 3), dtype=np.uint8)

    # Fallback color (black) for class IDs or labels that have no color entry.
    default_color = [0, 0, 0]

    # Class ID -> color. A label missing from label2color falls back to the
    # default color instead of aborting the whole overlay with a KeyError.
    id2color = {
        int(class_id): label2color.get(label, default_color)
        for class_id, label in id2label.items()
    }

    # Only paint the classes that actually occur in the prediction.
    for class_id in np.unique(pred):
        rgb[pred == class_id] = id2color.get(int(class_id), default_color)

    mask_rgb = Image.fromarray(rgb)

    # Alpha-blend the color mask over the original image.
    overlay = Image.blend(image.convert("RGBA"), mask_rgb.convert("RGBA"), alpha=alpha)
    return overlay
def resize_image(image, target_size=1024):
    """
    Resize the image so that its smaller side equals ``target_size`` while
    keeping the aspect ratio.

    Args:
        image: PIL Image (anything exposing ``size`` as (width, height) and
            ``resize((width, height))``).
        target_size: Desired length in pixels of the smaller side.

    Returns:
        The resized image, as returned by ``image.resize``.
    """
    # PIL reports size as (width, height).
    width, height = image.size

    # Scale the longer side with exact integer floor division. Going through a
    # float ratio can land just below the true value (e.g. 49 * (1024 / 49) ==
    # 1023.9999999999999) and truncate to one pixel too few.
    if width < height:
        new_width = target_size
        new_height = int(height * target_size // width)
    else:
        new_width = int(width * target_size // height)
        new_height = target_size

    return image.resize((new_width, new_height))
class CoralSegModel:
    """Wrapper around the SegFormer checkpoint named by ``HF_MODEL_ID``.

    Loads the image processor and the model once, and exposes sliding-window
    segmentation plus a convenience method that also renders a color overlay.
    """

    def __init__(self, device=None):
        """Load the processor and the bfloat16 model and put the model in eval mode.

        Args:
            device: Torch device to run on. Defaults to "cuda" when
                ``torch.cuda.is_available()`` is true, otherwise "cpu".
        """
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)
        self.model = SegformerForSemanticSegmentation.from_pretrained(
            HF_MODEL_ID,
            dtype=torch.bfloat16
        ).to(self.device)
        self.model.eval()

    @spaces.GPU
    def segment_image(self, image, preprocessor, model, crop_size = (1024, 1024), num_classes = 40) -> np.ndarray:
        """
        Finds an optimal stride based on the image size and aspect ratio to create
        overlapping sliding windows of size 1024x1024 which are then fed into the model.

        The image is first resized so that its smaller side is 1024 pixels. The
        logits of all windows are summed, divided by the number of windows that
        cover each pixel, and the arg-max class map is resized back to the
        original image resolution with nearest-neighbour interpolation.

        Args:
            image: Input PIL Image.
            preprocessor: Image processor applied to every crop. When falsy, the
                raw crop tensor is passed to the model as ``pixel_values``.
            model: Segmentation model; called with keyword inputs and expected
                to return an object with a ``logits`` attribute.
            crop_size: (height, width) of the sliding window.
            num_classes: Number of logit channels produced by the model. The
                result is stored as uint8, so class indices must fit in 0..255.

        Returns:
            (H, W) uint8 numpy array of class indices at the original image size.
        """
        h_crop, w_crop = crop_size

        # Force three channels so grayscale / RGBA inputs survive the HWC -> CHW transpose.
        resized = resize_image(image.convert("RGB"), target_size=1024)
        img = torch.Tensor(np.array(resized).transpose(2, 0, 1)).unsqueeze(0)
        img = img.to(self.device, torch.bfloat16)
        batch_size, _, h_img, w_img = img.size()

        # Use about 1.5x as many windows per axis as would tile it without
        # overlap, then derive the stride (rounded up) that spreads them evenly.
        h_grids = int(np.round(3/2*h_img/h_crop)) if h_img > h_crop else 1
        w_grids = int(np.round(3/2*w_img/w_crop)) if w_img > w_crop else 1
        h_stride = int((h_img - h_crop + h_grids -1)/(h_grids -1)) if h_grids > 1 else h_crop
        w_stride = int((w_img - w_crop + w_grids -1)/(w_grids -1)) if w_grids > 1 else w_crop

        preds = img.new_zeros((batch_size, num_classes, h_img, w_img))
        count_mat = img.new_zeros((batch_size, 1, h_img, w_img))

        # Moving the model is loop-invariant, so do it once instead of per window.
        model = model.to(self.device)

        for h_idx in range(h_grids):
            for w_idx in range(w_grids):
                y1 = h_idx * h_stride
                x1 = w_idx * w_stride
                # Clamp the window to the image, then shift its origin back so
                # the last window in each row/column keeps the full crop size.
                y2 = min(y1 + h_crop, h_img)
                x2 = min(x1 + w_crop, w_img)
                y1 = max(y2 - h_crop, 0)
                x1 = max(x2 - w_crop, 0)
                crop_img = img[:, :, y1:y2, x1:x2]
                with torch.no_grad():
                    if preprocessor:
                        inputs = preprocessor(crop_img, return_tensors = "pt", device=self.device)
                        inputs["pixel_values"] = inputs["pixel_values"].to(self.device, torch.bfloat16)
                    else:
                        # The model is called with **inputs, so a bare tensor
                        # must be wrapped in a mapping.
                        inputs = {"pixel_values": crop_img.to(self.device, torch.bfloat16)}
                    outputs = model(**inputs)

                # Bring the logits back to the crop resolution.
                resized_logits = F.interpolate(
                    outputs.logits[0].unsqueeze(dim=0), size=crop_img.shape[-2:], mode="bilinear", align_corners=False
                )
                # Accumulate into the window's region directly; equivalent to
                # zero-padding to full size and adding, without the temporary.
                preds[:, :, y1:y2, x1:x2] += resized_logits
                count_mat[:, :, y1:y2, x1:x2] += 1

        # Every pixel must be covered by at least one window before averaging.
        assert (count_mat == 0).sum() == 0, "sliding windows left pixels uncovered"
        preds = preds / count_mat
        preds = preds.argmax(dim=1)
        # PIL size is (width, height); reversed it gives the (H, W) target size.
        preds = F.interpolate(preds.unsqueeze(0).type(torch.uint8), size=image.size[::-1], mode='nearest')
        # Index the two leading singleton axes explicitly so that an image with
        # a side of length 1 still yields a 2-D map.
        label_pred = preds[0, 0].cpu().numpy()
        return label_pred

    @spaces.GPU
    def predict_map_and_overlay(self, frame_bgr: np.ndarray):
        """
        Segment a frame and render the class-color overlay.

        NOTE(review): despite the parameter name, no BGR -> RGB conversion is
        performed here; the array is handed to PIL unchanged and therefore
        treated as RGB. Confirm against the caller which channel order it sends.

        Args:
            frame_bgr: HxWx3 uint8 image array.

        Returns:
            pred_map: HxW uint8 array with class indices in [0..C-1]
            overlay: PIL Image (mode "RGBA"), the color mask blended over the
                original frame with alpha 0.5
            rgb: the input array, returned unchanged (for AnnotatedImage base)
        """
        rgb = frame_bgr
        pil = Image.fromarray(rgb)
        pred = self.segment_image(pil, self.processor, self.model)
        overlay_rgb = create_segmentation_overlay(pred, id2label, label2color, pil, 0.5)
        return pred, overlay_rgb, rgb
|