Spaces:
Paused
Paused
File size: 3,405 Bytes
3f7dd83 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 | import logging
from control_models.base import ControlModel, get_bool
from typing import List, Dict
from label_studio_sdk.label_interface.control_tags import ControlTag
logger = logging.getLogger(__name__)
def is_obb(control: ControlTag) -> bool:
"""Check if the model should use oriented bounding boxes (OBB)
based on the control tag attribute `model_obb` from the labeling config.
"""
return get_bool(control.attr, "model_obb", "false")
class RectangleLabelsModel(ControlModel):
"""
Class representing a RectangleLabels (bounding boxes) control tag for YOLO model.
"""
type = "RectangleLabels"
model_path = "yolov8m.pt"
@classmethod
def is_control_matched(cls, control) -> bool:
# check object tag type
if control.objects[0].tag != "Image":
return False
if is_obb(control):
return False
return control.tag == cls.type
def predict_regions(self, path) -> List[Dict]:
results = self.model.predict(path)
self.debug_plot(results[0].plot())
# oriented bounding boxes are detected, but it should be processed by RectangleLabelsObbModel
if results[0].obb is not None and results[0].boxes is None:
raise ValueError(
"Oriented bounding boxes are detected in the YOLO model results. "
'However, `model_obb="true"` is not set at the RectangleLabels tag '
"in the labeling config."
)
# simple bounding boxes without rotation
return self.create_rectangles(results, path)
def create_rectangles(self, results, path):
"""Simple bounding boxes without rotation"""
logger.debug(f"create_rectangles: {self.from_name}")
data = results[0].boxes # take bboxes from the first frame
model_names = self.model.names
regions = []
for i in range(data.shape[0]): # iterate over items
score = float(data.conf[i]) # tensor => float
x, y, w, h = data.xywhn[i].tolist()
model_label = model_names[int(data.cls[i])]
logger.debug(
"----------------------\n"
f"task id > {path}\n"
f"type: {self.control}\n"
f"x, y, w, h > {x, y, w, h}\n"
f"model label > {model_label}\n"
f"score > {score}\n"
)
# bbox score is too low
if score < self.model_score_threshold:
continue
# there is no mapping between model label and LS label
if model_label not in self.label_map:
continue
output_label = self.label_map[model_label]
# add new region with rectangle
region = {
"from_name": self.from_name,
"to_name": self.to_name,
"type": "rectanglelabels",
"value": {
"rectanglelabels": [output_label],
"x": (x - w / 2) * 100,
"y": (y - h / 2) * 100,
"width": w * 100,
"height": h * 100,
},
"score": score,
}
regions.append(region)
return regions
# pre-load and cache default model at startup
RectangleLabelsModel.get_cached_model(RectangleLabelsModel.model_path)
|