import logging
from control_models.base import ControlModel, get_bool
from typing import List, Dict
logger = logging.getLogger(__name__)
class KeypointLabelsModel(ControlModel):
    """Control model that maps YOLO pose predictions to a KeyPointLabels tag.

    Each detection is turned into one ``keypointlabels`` region per mapped
    keypoint and, optionally, one hidden ``rectanglelabels`` region that acts
    as the parent of those keypoints.

    Settings read from the control tag attributes in ``__init__``:

    * ``model_add_bboxes`` (default ``"true"``) -> ``add_bboxes``
    * ``model_point_size`` (default ``1``) -> ``point_size``
    * ``model_point_threshold`` (default ``0``) -> ``point_threshold``

    ``self.model``, ``self.control``, ``self.from_name``, ``self.to_name``,
    ``self.label_map`` and ``self.model_score_threshold`` are not defined in
    this class; they are expected to come from ``ControlModel``.
    """

    # Control tag name this class is matched against in `is_control_matched`.
    type = "KeyPointLabels"
    # Adjust the model path to your keypoint detection model.
    model_path = "yolov8n-pose.pt"
    # Emit a parent bounding box region for every detection.
    add_bboxes: bool = True
    # Visual point size; divided by the image width to get a percentage.
    point_size: float = 1
    # Keypoints whose confidence is below this value are dropped.
    point_threshold: float = 0
    # "<predicted_values>::<model_index>" -> Label Studio label value.
    point_map: Dict = {}

    def __init__(self, **data):
        """Initialise the base model, then read the keypoint-specific settings."""
        super().__init__(**data)
        attrs = self.control.attr
        self.add_bboxes = get_bool(attrs, "model_add_bboxes", "true")
        self.point_size = float(attrs.get("model_point_size", 1))
        self.point_threshold = float(attrs.get("model_point_threshold", 0))
        self.point_map = self.build_point_mapping()

    @classmethod
    def is_control_matched(cls, control) -> bool:
        """Return True when `control` is a KeyPointLabels tag attached to an Image.

        A control without any object tags is reported as not matched instead
        of raising ``IndexError``.
        """
        # Check object tag type
        if not control.objects or control.objects[0].tag != "Image":
            return False
        return control.tag == cls.type

    def build_point_mapping(self) -> Dict[str, str]:
        """Build a mapping between model points and Label Studio labels.

        Every label tag that defines both ``predicted_values`` and
        ``model_index`` contributes one entry, e.g. ``{"person::0": "nose"}``.
        A tag that defines only one of the two attributes is reported with a
        warning and skipped; an empty result is reported as an error.

        Returns:
            Dict keyed by ``"<predicted_values>::<model_index>"`` whose values
            are the label values from the labeling config.
        """
        mapping = {}
        for value, label_tag in self.control.labels_attrs.items():
            model_name = label_tag.attr.get("predicted_values")
            model_index = label_tag.attr.get("model_index")
            if model_name and model_index:
                mapping[f"{model_name}::{model_index}"] = value
            elif model_name:
                logger.warning(
                    f"`model_index` is not provided for Label tag: {label_tag}"
                )
            elif model_index:
                logger.warning(
                    f"`predicted_values` is not provided for Label tag: {label_tag}"
                )
        if not mapping:
            logger.error(
                f"No point to label mapping found for control tag: {self.control}"
            )
        return mapping

    def predict_regions(self, path) -> List[Dict]:
        """Run the model on `path` and convert its output to regions."""
        results = self.model.predict(path)
        return self.create_keypoints(results, path)

    def create_keypoints(self, results, path) -> List[Dict]:
        """Convert the first prediction result into Label Studio regions.

        Args:
            results: Sequence returned by ``self.model.predict``; only
                ``results[0]`` is used.
            path: Input that was passed to the model; used for debug logging only.

        Returns:
            List of region dicts: an optional hidden parent box followed by its
            keypoints, for every detection that passes the score threshold and
            has a label mapping.
        """
        logger.debug(f"create_keypoints: {self.from_name}")
        keypoints_data = results[0].keypoints  # Get keypoints from the first frame
        bbox_data = results[0].boxes
        image_width = results[0].orig_shape[1]
        model_names = self.model.names
        # `conf` is None when the model outputs keypoints without a confidence
        # channel; in that case per-point thresholding is not possible.
        all_point_confs = keypoints_data.conf
        debug_enabled = logger.isEnabledFor(logging.DEBUG)

        regions = []
        for bbox_index in range(keypoints_data.shape[0]):  # Iterate over bboxes
            bbox_conf = bbox_data.conf[bbox_index]
            # Convert normalized keypoints to percentages
            point_xyn = keypoints_data.xyn[bbox_index] * 100
            model_label = model_names[int(bbox_data.cls[bbox_index])]

            # Formatting every keypoint is costly, so only do it when the
            # message will actually be emitted.
            if debug_enabled:
                point_logs = "\n".join(
                    [
                        f' model_index="{i}", xy={xyn}'
                        for i, xyn in enumerate(point_xyn)
                    ]
                )
                logger.debug(
                    "----------------------\n"
                    f"task id > {path}\n"
                    f"type: {self.control}\n"
                    f"model label > {model_label}\n"
                    f"keypoints >\n{point_logs}\n"
                    f"confidences > {bbox_conf}\n"
                )

            # bbox score is too low
            if bbox_conf < self.model_score_threshold:
                continue

            # There is no mapping between model label and LS label
            if model_label not in self.label_map:
                continue

            # Add parent bbox that contains all keypoints
            if self.add_bboxes:
                regions.append(
                    self.create_bounding_box(
                        bbox_conf, bbox_data, bbox_index, model_label
                    )
                )

            point_confs = (
                None if all_point_confs is None else all_point_confs[bbox_index]
            )
            for point_index, xyn in enumerate(point_xyn):
                if point_confs is None:
                    # No per-point confidence available: keep the point and
                    # report the confidence of its bounding box instead.
                    point_conf = bbox_conf
                else:
                    point_conf = point_confs[point_index]
                    if point_conf < self.point_threshold:
                        continue

                x, y = xyn.tolist()
                index_name = f"{model_label}::{point_index}"
                if index_name not in self.point_map:
                    logger.warning(
                        f"Point {index_name} not found in point map, "
                        f"you have to define it in the labeling config, e.g.:\n"
                        f'<Label value="your_label" '
                        f'predicted_values="{model_label}" '
                        f'model_index="{point_index}" />'
                    )
                    continue
                point_label = self.point_map[index_name]

                # Add new region with keypoint
                region = {
                    "from_name": self.from_name,
                    "to_name": self.to_name,
                    "type": "keypointlabels",
                    "value": {
                        # point label
                        "keypointlabels": [point_label],
                        # point width, just visual styling
                        "width": self.point_size / image_width * 100,
                        "x": x,
                        "y": y,
                    },
                    # Group keypoints by bbox index
                    "meta": {"text": [f"bbox-{bbox_index}"]},
                    "score": float(point_conf),
                }
                # If bboxes are used, group keypoints by bbox
                if self.add_bboxes:
                    region["parentID"] = f"bbox-{bbox_index}"
                regions.append(region)
        return regions

    def create_bounding_box(self, bbox_conf, bbox_data, bbox_index, model_label) -> Dict:
        """Build the hidden parent rectangle region for one detection.

        Args:
            bbox_conf: Confidence of the detection; stored as the region score.
            bbox_data: Boxes object of the prediction result; ``xywhn`` is read.
            bbox_index: Index of the detection inside `bbox_data`.
            model_label: Model class name, used as the rectangle label.

        Returns:
            A ``rectanglelabels`` region dict with id ``bbox-<bbox_index>``.
        """
        # xywhn is treated as a normalized centre point plus size; the region
        # needs the top-left corner, in percent.
        x, y, w, h = bbox_data.xywhn[bbox_index].tolist()
        return {
            "id": f"bbox-{bbox_index}",
            "from_name": self.from_name + "_bbox",
            "to_name": self.to_name,
            "type": "rectanglelabels",
            "value": {
                "rectanglelabels": [model_label],
                "x": (x - w / 2) * 100,
                "y": (y - h / 2) * 100,
                "width": w * 100,
                "height": h * 100,
            },
            # Group keypoints by bbox index
            "meta": {"text": [f"bbox-{bbox_index}"]},
            "score": float(bbox_conf),
            "hidden": True,
        }
# Pre-load and cache the default model at startup: this statement runs when the
# module is imported and passes the class-level default weights path
# (`model_path`) to `get_cached_model`.
# NOTE(review): `get_cached_model` is defined outside this file, presumably on
# `ControlModel`; confirm it caches by path so later lookups reuse this load.
KeypointLabelsModel.get_cached_model(KeypointLabelsModel.model_path)