# ls-yolo-backend / control_models/brush_labels.py
# Author: davanstrien (HF Staff)
# Commit 3f7dd83 (verified): Initial: HumanSignal yolo example patched for HF Spaces
import logging
from typing import List, Dict
from uuid import uuid4
import cv2
from label_studio_sdk.converter.brush import mask2rle
from control_models.base import ControlModel
# Module-level logger named after this module; used for debug tracing below.
logger: logging.Logger = logging.getLogger(__name__)
class BrushLabelsModel(ControlModel):
    """
    Class representing a BrushLabels control tag for YOLO model.

    Runs a YOLO instance-segmentation model on an image and converts each
    sufficiently confident detection whose class is present in
    ``self.label_map`` into a Label Studio ``brushlabels`` region with an
    RLE-encoded mask.
    """

    type = "BrushLabels"
    model_path = "yolov8n-seg.pt"

    @classmethod
    def is_control_matched(cls, control) -> bool:
        """Return True when ``control`` is a BrushLabels tag bound to an Image.

        Args:
            control: The Label Studio control tag being considered.

        Returns:
            True if the first object the control points at is an ``Image``
            tag and the control's own tag is ``BrushLabels``.
        """
        # check object tag type (a control with no objects cannot match)
        if not control.objects or control.objects[0].tag != "Image":
            return False
        return control.tag == cls.type

    def predict_regions(self, path) -> List[Dict]:
        """Run the segmentation model on ``path`` and return brush regions."""
        results = self.model.predict(path)
        return self.create_brush(results, path)

    @staticmethod
    def _mask_to_original(mask, orig_shape):
        """Map one model-resolution instance mask onto the original image.

        Args:
            mask: 2-D float array holding a single instance mask.
            orig_shape: ``(height, width)`` of the original image.

        Returns:
            A ``uint8`` array of shape ``(height, width)`` with values 0/255.
        """
        orig_h, orig_w = orig_shape[:2]
        in_h, in_w = mask.shape[:2]
        if (in_h, in_w) != (orig_h, orig_w):
            # NOTE(review): assumes the mask is at the letterboxed inference
            # size (aspect-preserving scale + symmetric padding), which is
            # what ultralytics returns unless retina_masks=True. The padding
            # must be cropped before resizing, otherwise the mask ends up
            # shifted/stretched relative to the image.
            gain = min(in_h / orig_h, in_w / orig_w)
            pad_x = (in_w - orig_w * gain) / 2
            pad_y = (in_h - orig_h * gain) / 2
            top, left = int(pad_y), int(pad_x)
            bottom, right = int(in_h - pad_y), int(in_w - pad_x)
            mask = mask[top:bottom, left:right]
        # cv2.resize takes dsize as (width, height); it does not accept
        # float16, so normalise the dtype (no copy if already float32).
        mask = cv2.resize(mask.astype("float32", copy=False), (orig_w, orig_h))
        # The source mask is binary, so 0.5 is the unbiased cut after
        # bilinear upsampling (``> 0`` would dilate every edge).
        return (mask > 0.5).astype("uint8") * 255

    def create_brush(self, results, path):
        """
        Convert YOLO segmentation output into Label Studio brush regions.

        Args:
            results: The sequence returned by ``self.model.predict``; only
                ``results[0]`` is read.
            path: The image path/URL; used only in debug logging.

        Returns:
            A list of ``brushlabels`` region dicts, one per detection whose
            score is at least ``self.model_score_threshold`` and whose model
            class name is a key of ``self.label_map``. Empty if none qualify.
        """
        logger.debug("create_brush: %s", self.from_name)
        result = results[0]
        data = result.masks
        # Nothing detected: ultralytics leaves masks as None in that case,
        # and there is nothing to convert.
        if data is None:
            return []

        model_names = self.model.names
        height, width = data.orig_shape
        # Bring tensors to host memory once; ``.numpy()`` fails on CUDA.
        masks = data.data.cpu().numpy()
        scores = result.boxes.conf.tolist()
        class_ids = result.boxes.cls.tolist()

        regions = []
        for mask_data, raw_score, class_id in zip(masks, scores, class_ids):
            score = float(raw_score)
            model_label = model_names[int(class_id)]
            # Filter before the comparatively expensive resize + RLE encoding.
            if score < self.model_score_threshold or model_label not in self.label_map:
                logger.debug("skipping %s (score %s)", model_label, score)
                continue
            output_label = self.label_map[model_label]

            mask = self._mask_to_original(mask_data, (height, width))
            rle = mask2rle(mask)
            # Lazy %-style args: the (potentially long) RLE list is only
            # stringified when debug logging is actually enabled.
            logger.debug(
                "----------------------\n"
                "task id > %s\n"
                "type: %s\n"
                "rle > %s\n"
                "model label > %s\n"
                "score > %s\n",
                path,
                self.control,
                rle,
                model_label,
                score,
            )
            regions.append(
                {
                    "id": str(uuid4())[:9],
                    "from_name": self.from_name,
                    "to_name": self.to_name,
                    "original_width": width,
                    "original_height": height,
                    "image_rotation": 0,
                    "value": {
                        "format": "rle",
                        "rle": rle,
                        "brushlabels": [output_label],
                    },
                    "score": score,
                    "type": "brushlabels",
                }
            )
        return regions
# Load the default segmentation weights into the model cache when this
# module is imported, rather than on the first prediction request.
BrushLabelsModel.get_cached_model(
    BrushLabelsModel.model_path
)