File size: 3,405 Bytes
3f7dd83
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import logging

from control_models.base import ControlModel, get_bool
from typing import List, Dict
from label_studio_sdk.label_interface.control_tags import ControlTag


logger = logging.getLogger(__name__)


def is_obb(control: ControlTag) -> bool:
    """Tell whether *control* asks for oriented bounding boxes (OBB).

    The decision is taken from the ``model_obb`` attribute of the control tag
    in the labeling config, interpreted by ``get_bool`` with ``"false"`` as
    the fallback value when the attribute is not set.
    """
    flag_name = "model_obb"
    tag_attributes = control.attr
    use_obb = get_bool(tag_attributes, flag_name, "false")
    return use_obb


class RectangleLabelsModel(ControlModel):
    """
    Class representing a RectangleLabels (bounding boxes) control tag for YOLO model.

    Only axis-aligned boxes are produced here. Controls that set
    ``model_obb="true"`` are rejected by ``is_control_matched``; per the
    comment in ``predict_regions`` those are meant for an OBB-specific model.
    """

    type = "RectangleLabels"
    model_path = "yolov8m.pt"

    @classmethod
    def is_control_matched(cls, control) -> bool:
        """Return True if *control* should be served by this model class.

        A control matches when its tag is ``RectangleLabels``, its first
        connected object tag is ``Image``, and it does not request oriented
        bounding boxes via the ``model_obb`` attribute.
        """
        # Cheapest test first: unrelated tag types are rejected without
        # touching their object list or attributes.
        if control.tag != cls.type:
            return False
        # check object tag type; an empty object list would otherwise raise
        # IndexError on objects[0]
        if not control.objects or control.objects[0].tag != "Image":
            return False
        return not is_obb(control)

    def predict_regions(self, path) -> List[Dict]:
        """Run the YOLO model on *path* and convert its detections to regions.

        Args:
            path: input passed straight to ``self.model.predict``.

        Returns:
            The list of region dicts built by ``create_rectangles``.

        Raises:
            ValueError: if the first result has oriented boxes (``obb``) but
                no plain ``boxes`` - i.e. the control lacks ``model_obb="true"``.
        """
        results = self.model.predict(path)
        first = results[0]
        self.debug_plot(first.plot())

        # oriented bounding boxes are detected, but it should be processed by RectangleLabelsObbModel
        if first.obb is not None and first.boxes is None:
            raise ValueError(
                "Oriented bounding boxes are detected in the YOLO model results. "
                'However, `model_obb="true"` is not set at the RectangleLabels tag '
                "in the labeling config."
            )

        # simple bounding boxes without rotation
        return self.create_rectangles(results, path)

    def create_rectangles(self, results, path):
        """Convert the axis-aligned boxes of ``results[0]`` into LS regions.

        Args:
            results: sequence returned by ``self.model.predict``; only the
                first element is used.
            path: the task input; used only in debug log messages.

        Returns:
            A list of ``rectanglelabels`` region dicts. Coordinates are the
            model's normalized values scaled by 100, with (x, y) converted from
            box center to top-left corner. Boxes scoring below
            ``self.model_score_threshold`` or whose model label is missing
            from ``self.label_map`` are skipped.

        Raises:
            ValueError: if ``results[0].boxes`` is None, instead of failing
                later with an opaque AttributeError.
        """
        logger.debug("create_rectangles: %s", self.from_name)
        data = results[0].boxes  # take bboxes from the first frame
        if data is None:
            raise ValueError(
                "No bounding boxes found in the YOLO model results "
                "for the RectangleLabels tag."
            )
        model_names = self.model.names
        regions = []

        # Convert each tensor to Python values once instead of slicing the
        # tensors per box; values match float()/int() on the single elements.
        scores = data.conf.tolist()
        boxes_xywhn = data.xywhn.tolist()
        class_ids = data.cls.tolist()

        for score, (x, y, w, h), class_id in zip(scores, boxes_xywhn, class_ids):
            score = float(score)
            model_label = model_names[int(class_id)]

            # Lazy %-style args: the message is only formatted when DEBUG is on.
            logger.debug(
                "----------------------\n"
                "task id > %s\n"
                "type: %s\n"
                "x, y, w, h > %s\n"
                "model label > %s\n"
                "score > %s\n",
                path,
                self.control,
                (x, y, w, h),
                model_label,
                score,
            )

            # bbox score is too low
            if score < self.model_score_threshold:
                continue

            # there is no mapping between model label and LS label
            if model_label not in self.label_map:
                continue
            output_label = self.label_map[model_label]

            # add new region with rectangle; (x, y) is the box center in
            # normalized units, so shift by half the size to get top-left
            region = {
                "from_name": self.from_name,
                "to_name": self.to_name,
                "type": "rectanglelabels",
                "value": {
                    "rectanglelabels": [output_label],
                    "x": (x - w / 2) * 100,
                    "y": (y - h / 2) * 100,
                    "width": w * 100,
                    "height": h * 100,
                },
                "score": score,
            }
            regions.append(region)
        return regions


# Pre-load and cache the default model (``RectangleLabelsModel.model_path``)
# at startup. This statement runs as a side effect of importing this module;
# presumably it lets the first prediction reuse the cached model instead of
# loading it on demand - confirm against ControlModel.get_cached_model.
RectangleLabelsModel.get_cached_model(RectangleLabelsModel.model_path)