Spaces:
Sleeping
Sleeping
Update inference.py
Browse files — inference.py (+69 lines, −10 lines)
inference.py
CHANGED
|
@@ -4,23 +4,83 @@ import numpy as np
|
|
| 4 |
from PIL import Image
|
| 5 |
from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
|
| 6 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 7 |
# Load model from HF (swap this with your own if you want)
|
| 8 |
HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
|
| 9 |
|
| 10 |
class CoralSegModel:
|
| 11 |
def __init__(self, device=None):
|
| 12 |
self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
| 13 |
self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)
|
| 14 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 15 |
self.model.eval()
|
| 16 |
|
| 17 |
-
|
| 18 |
-
# 0..N-1 colors - here random-ish but stable
|
| 19 |
-
num_classes = self.model.config.id2label and len(self.model.config.id2label) or 40
|
| 20 |
-
rng = np.random.RandomState(0)
|
| 21 |
-
self.palette = (rng.randint(0, 255, size=(num_classes, 3))).astype(np.uint8)
|
| 22 |
|
| 23 |
-
@
|
| 24 |
def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
|
| 25 |
"""
|
| 26 |
frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
|
|
@@ -30,7 +90,7 @@ class CoralSegModel:
|
|
| 30 |
rgb = frame_bgr[:, :, ::-1]
|
| 31 |
pil = Image.fromarray(rgb)
|
| 32 |
|
| 33 |
-
inputs = self.processor(images=pil, return_tensors="pt", device=self.device).to(device)
|
| 34 |
outputs = self.model(**inputs)
|
| 35 |
logits = outputs.logits # [B, C, h, w]
|
| 36 |
upsampled = torch.nn.functional.interpolate(
|
|
@@ -40,5 +100,4 @@ class CoralSegModel:
|
|
| 40 |
|
| 41 |
color_mask = self.palette[pred] # HxWx3 (RGB)
|
| 42 |
overlay_rgb = (rgb * (1 - alpha) + color_mask * alpha).astype(np.uint8)
|
| 43 |
-
|
| 44 |
-
return overlay_bgr
|
|
|
|
| 4 |
from PIL import Image
|
| 5 |
from transformers import SegformerImageProcessorFast, SegformerForSemanticSegmentation
|
| 6 |
|
| 7 |
+
# Coralscapes class-id -> human-readable label.
# Keys are the dataset's string class ids "1".."39" (id "0" is not used);
# `index_to_id` below maps the model's 0-based output channels onto these ids.
id2label = {
    '1': 'seagrass',
    '2': 'trash',
    '3': 'other coral dead',
    '4': 'other coral bleached',
    '5': 'sand',
    '6': 'other coral alive',
    '7': 'human',
    '8': 'transect tools',
    '9': 'fish',
    '10': 'algae covered substrate',
    '11': 'other animal',
    '12': 'unknown hard substrate',
    '13': 'background',
    '14': 'dark',
    '15': 'transect line',
    '16': 'massive/meandering bleached',
    '17': 'massive/meandering alive',
    '18': 'rubble',
    '19': 'branching bleached',
    '20': 'branching dead',
    '21': 'millepora',
    '22': 'branching alive',
    '23': 'massive/meandering dead',
    '24': 'clam',
    '25': 'acropora alive',
    '26': 'sea cucumber',
    '27': 'turbinaria',
    '28': 'table acropora alive',
    '29': 'sponge',
    '30': 'anemone',
    '31': 'pocillopora alive',
    '32': 'table acropora dead',
    '33': 'meandering bleached',
    '34': 'stylophora alive',
    '35': 'sea urchin',
    '36': 'meandering alive',
    '37': 'meandering dead',
    '38': 'crown of thorn',
    '39': 'dead clam',
}
|
| 48 |
+
|
| 49 |
+
# Label name -> RGB color used when painting the segmentation overlay.
# One entry per Coralscapes class; values are [R, G, B] in 0..255.
label2color = {
    'human': [255, 0, 0],
    'background': [29, 162, 216],
    'fish': [255, 255, 0],
    'sand': [194, 178, 128],
    'rubble': [161, 153, 128],
    'unknown hard substrate': [125, 125, 125],
    'algae covered substrate': [125, 163, 125],
    'dark': [31, 31, 31],
    'branching bleached': [252, 231, 240],
    'branching dead': [123, 50, 86],
    'branching alive': [226, 91, 157],
    'stylophora alive': [255, 111, 194],
    'pocillopora alive': [255, 146, 150],
    'acropora alive': [236, 128, 255],
    'table acropora alive': [189, 119, 255],
    'table acropora dead': [85, 53, 116],
    'millepora': [244, 150, 115],
    'turbinaria': [228, 255, 119],
    'other coral bleached': [250, 224, 225],
    'other coral dead': [114, 60, 61],
    'other coral alive': [224, 118, 119],
    'massive/meandering alive': [236, 150, 21],
    'massive/meandering dead': [134, 86, 18],
    'massive/meandering bleached': [255, 248, 228],
    'meandering alive': [230, 193, 0],
    'meandering dead': [119, 100, 14],
    'meandering bleached': [251, 243, 216],
    'transect line': [0, 255, 0],
    'transect tools': [8, 205, 12],
    'sea urchin': [0, 142, 255],
    'sea cucumber': [0, 231, 255],
    'anemone': [0, 255, 189],
    'sponge': [240, 80, 80],
    'clam': [189, 255, 234],
    'other animal': [0, 255, 255],
    'trash': [255, 0, 134],
    'seagrass': [125, 222, 125],
    'crown of thorn': [179, 245, 234],
    'dead clam': [89, 155, 134],
}
|
| 50 |
+
|
| 51 |
+
# Align the model's output channel indices with the dataset class ids.
# We assume the model emits class indices 0..38 and that channel k
# corresponds to dataset id str(k + 1): 0 -> "1", 1 -> "2", ..., 38 -> "39".
# If your checkpoint orders classes differently, supply a custom
# `index_to_id` list with the matching ordering.
index_to_id = list(map(str, range(1, 40)))  # ["1", "2", ..., "39"]
index_to_name = [id2label[class_id] for class_id in index_to_id]
|
| 57 |
+
|
| 58 |
+
def make_palette(index_to_name, label2color):
    """Build a (num_classes, 3) uint8 RGB lookup table.

    Row k holds the overlay color for class index k, looked up by class
    name in ``label2color``; names without a color entry stay black.
    """
    palette = np.zeros((len(index_to_name), 3), dtype=np.uint8)
    for idx, class_name in enumerate(index_to_name):
        # Assigning into the uint8 row casts the [R, G, B] list for us.
        palette[idx] = label2color.get(class_name, [0, 0, 0])
    return palette
|
| 64 |
+
|
| 65 |
# Load model from HF (swap this with your own if you want)
# Hub id of a SegFormer-B2 checkpoint fine-tuned on the Coralscapes
# dataset at 1024x1024 — used for both the image processor and the model.
HF_MODEL_ID = "EPFL-ECEO/segformer-b2-finetuned-coralscapes-1024-1024"
|
| 67 |
|
| 68 |
class CoralSegModel:
    """Wraps the Coralscapes SegFormer checkpoint: processor, model, and an
    RGB palette whose row k colors predicted class index k."""

    def __init__(self, device=None):
        """Load the pretrained processor and model onto `device`.

        device: explicit torch device string; when falsy, picks "cuda" if
        available, else "cpu".
        """
        fallback = "cuda" if torch.cuda.is_available() else "cpu"
        self.device = device or fallback

        self.processor = SegformerImageProcessorFast.from_pretrained(HF_MODEL_ID)

        # bfloat16 weights halve memory; inference-only, so eval mode.
        model = SegformerForSemanticSegmentation.from_pretrained(HF_MODEL_ID, dtype=torch.bfloat16)
        self.model = model.to(self.device)
        self.model.eval()

        # Per-class RGB colors aligned with the model's output indices.
        self.palette = make_palette(index_to_name, label2color)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 82 |
|
| 83 |
+
@torch.inference_mode()
|
| 84 |
def predict_overlay(self, frame_bgr: np.ndarray, alpha: float = 0.45) -> np.ndarray:
|
| 85 |
"""
|
| 86 |
frame_bgr: np.ndarray HxWx3 in BGR (as read by OpenCV)
|
|
|
|
| 90 |
rgb = frame_bgr[:, :, ::-1]
|
| 91 |
pil = Image.fromarray(rgb)
|
| 92 |
|
| 93 |
+
inputs = self.processor(images=pil, return_tensors="pt", device=self.device).to(self.device, torch.bfloat16)
|
| 94 |
outputs = self.model(**inputs)
|
| 95 |
logits = outputs.logits # [B, C, h, w]
|
| 96 |
upsampled = torch.nn.functional.interpolate(
|
|
|
|
| 100 |
|
| 101 |
color_mask = self.palette[pred] # HxWx3 (RGB)
|
| 102 |
overlay_rgb = (rgb * (1 - alpha) + color_mask * alpha).astype(np.uint8)
|
| 103 |
+
return overlay_rgb
|
|
|