Spaces:
Paused
Paused
File size: 7,642 Bytes
3f7dd83 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 | import os
import logging
from pydantic import BaseModel
from typing import Optional, List, Dict, ClassVar
from ultralytics import YOLO
from label_studio_ml.model import LabelStudioMLBase
from label_studio_ml.utils import DATA_UNDEFINED_NAME
from label_studio_sdk.label_interface.control_tags import ControlTag
from label_studio_sdk.label_interface import LabelInterface
# Debug switch: when enabled, ControlModel.debug_plot displays images via matplotlib.
DEBUG_PLOT = os.environ.get("DEBUG_PLOT", "false").lower() in ("1", "true")

# Fallback score threshold, used when a control tag sets neither
# `model_score_threshold` nor `score_threshold`.
MODEL_SCORE_THRESHOLD = float(os.environ.get("MODEL_SCORE_THRESHOLD", 0.5))

# Directory holding model files: `<parent of this file's directory>/models`
# by default, overridable with the MODEL_ROOT environment variable.
DEFAULT_MODEL_ROOT = os.path.join(
    os.path.dirname(os.path.dirname(__file__)), "models"
)
MODEL_ROOT = os.environ.get("MODEL_ROOT", DEFAULT_MODEL_ROOT)
# Import-time side effect: make sure the model directory exists.
os.makedirs(MODEL_ROOT, exist_ok=True)

# When true, a control tag's `model_path` attribute in the labeling config
# may override the control model's default model file.
ALLOW_CUSTOM_MODEL_PATH = os.environ.get(
    "ALLOW_CUSTOM_MODEL_PATH", "true"
).lower() in ("1", "true")

# Process-wide cache of loaded YOLO models, keyed by model path
# (populated by ControlModel.get_cached_model).
_model_cache = {}

logger = logging.getLogger(__name__)
def get_bool(attr, attr_name, default="false"):
    """Interpret a (control-tag) attribute as a boolean flag.

    Args:
        attr: Mapping of attribute names to values, e.g. ``control.attr``.
        attr_name: Name of the attribute to read.
        default: Value used when the attribute is absent or ``None``.
            May be a string ("true"/"false"/...) or any value whose ``str()``
            form is one of the accepted spellings (e.g. ``True``).

    Returns:
        True if the value, stringified, stripped of surrounding whitespace
        and lower-cased, is "1", "true" or "yes"; False otherwise.
    """
    value = attr.get(attr_name)
    # Treat an explicit None like a missing attribute instead of crashing on
    # `None.lower()`.
    if value is None:
        value = default
    # str() tolerates non-string values (e.g. bool defaults); strip() tolerates
    # stray whitespace in hand-written XML attributes like model_skip=" true".
    return str(value).strip().lower() in ("1", "true", "yes")
class ControlModel(BaseModel):
    """
    Represents a control tag in Label Studio, which is associated with a specific type of labeling task
    and is used to generate predictions using a YOLO model.

    Concrete subclasses are expected to set the ``type`` and ``model_path`` class
    variables and to implement ``is_control_matched`` and ``predict_regions``
    (both raise NotImplementedError here).

    Attributes:
        type (str): Class-level type of the control, e.g., RectangleLabels, Choices, etc.
        control (ControlTag): The actual control element from the Label Studio configuration.
        from_name (str): The name of the control tag, used to link the control to the data.
        to_name (str): The first name in ``control.to_name``, i.e. the object tag this control points at.
        value (str): The value name from the object that this control operates on, e.g., an image or
            text field; used as the key into ``task["data"]`` by ``get_path``.
        model (YOLO): The model instance used to generate predictions for this control.
        model_path (str): Class-level default model file; joined onto MODEL_ROOT when loaded.
        model_score_threshold (float): Threshold for prediction scores; subclasses are expected to
            ignore predictions below this value.
        label_map (Optional[Dict[str, str]]): A mapping of model labels to Label Studio labels.
        label_studio_ml_backend (LabelStudioMLBase): Backend this control belongs to; used by
            ``get_path`` to fetch task media that is not a local file.
        project_id (Optional[str]): Set from ``mlbackend.project_id`` by ``create``.
    """

    type: ClassVar[str]
    control: ControlTag
    from_name: str
    to_name: str
    value: str
    model: YOLO
    model_path: ClassVar[str]
    model_score_threshold: float = 0.5
    label_map: Optional[Dict[str, str]] = {}
    label_studio_ml_backend: LabelStudioMLBase
    project_id: Optional[str] = None

    @classmethod
    def is_control_matched(cls, control) -> bool:
        """Check if the control tag matches the model type.

        Args:
            control (ControlTag): The control tag from the Label Studio Interface.

        Raises:
            NotImplementedError: Always, in the base class; subclasses must override.
        """
        raise NotImplementedError("This method should be overridden in derived classes")

    @staticmethod
    def get_from_name_for_label_map(
        label_interface: LabelInterface, target_name: str
    ) -> str:
        """Get the 'from_name' attribute for the label map building.

        The base implementation returns ``target_name`` unchanged; subclasses may
        override it when the label map must be built from a different control.
        """
        return target_name

    @classmethod
    def create(
        cls, mlbackend: LabelStudioMLBase, control: ControlTag
    ) -> Optional["ControlModel"]:
        """Factory method to create an instance of a specific control model class.

        Args:
            mlbackend (LabelStudioMLBase): The ML backend instance.
            control (ControlTag): The control tag from the Label Studio Interface.

        Returns:
            An instance of ``cls`` bound to ``control``, or ``None`` when the
            control tag carries ``model_skip="true"``.
        """
        from_name = control.name
        to_name = control.to_name[0]
        value = control.objects[0].value_name

        # if skip is true, don't process this control
        if get_bool(control.attr, "model_skip", "false"):
            logger.info(
                f"Skipping control tag '{control.tag}' with name '{from_name}', model_skip=true found"
            )
            return None

        # Threshold precedence, e.g. <RectangleLabels model_score_threshold="0.5">:
        # `model_score_threshold` attr > legacy `score_threshold` attr (not
        # recommended) > MODEL_SCORE_THRESHOLD. Empty-string attributes fall
        # through because of the `or` chain.
        model_score_threshold = float(
            control.attr.get("model_score_threshold")
            or control.attr.get("score_threshold")
            or MODEL_SCORE_THRESHOLD
        )

        # A per-control `model_path` attribute overrides the class default, but
        # only when ALLOW_CUSTOM_MODEL_PATH is enabled.
        # NOTE(review): load_yolo_model uses os.path.join(MODEL_ROOT, ...), which
        # discards MODEL_ROOT for absolute paths, so the labeling config can point
        # at any file. Consider restricting this if configs are untrusted.
        model_path = (
            ALLOW_CUSTOM_MODEL_PATH and control.attr.get("model_path")
        ) or cls.model_path
        model = cls.get_cached_model(model_path)
        model_names = model.names.values()  # class names from the model

        # from_name for label mapping can differ from control.name (e.g. VideoRectangle)
        label_map_from_name = cls.get_from_name_for_label_map(
            mlbackend.label_interface, from_name
        )
        label_map = mlbackend.build_label_map(label_map_from_name, model_names)

        return cls(
            control=control,
            from_name=from_name,
            to_name=to_name,
            value=value,
            model=model,
            model_score_threshold=model_score_threshold,
            label_map=label_map,
            label_studio_ml_backend=mlbackend,
            project_id=mlbackend.project_id,
        )

    @classmethod
    def load_yolo_model(cls, filename) -> YOLO:
        """Load a YOLO model from ``os.path.join(MODEL_ROOT, filename)``.

        If ``filename`` is absolute, ``os.path.join`` returns it unchanged, so
        MODEL_ROOT is bypassed.
        """
        path = os.path.join(MODEL_ROOT, filename)
        logger.info(f"Loading yolo model: {path}")
        model = YOLO(path)
        logger.info(f"Model {path} names:\n{model.names}")
        return model

    @classmethod
    def get_cached_model(cls, path: str) -> YOLO:
        """Return the model for ``path``, loading it on first use.

        Models are memoized in the module-level ``_model_cache`` dict keyed by the
        path string, so every control model in the process shares one instance
        per path. There is no locking around the check-then-load.
        """
        if path not in _model_cache:
            _model_cache[path] = cls.load_yolo_model(path)
        return _model_cache[path]

    def debug_plot(self, image):
        """Display ``image`` with matplotlib when DEBUG_PLOT is enabled; otherwise do nothing.

        The last axis is reversed before display (``image[..., ::-1]``). This
        presumably converts BGR to RGB; confirm the channel order against callers.
        """
        if not DEBUG_PLOT:
            return
        import matplotlib.pyplot as plt

        fig = plt.figure(figsize=(10, 10))
        try:
            plt.imshow(image[..., ::-1])
            plt.axis("off")
            plt.title(self.type)
            plt.show()
        finally:
            # Always release the figure. With a non-interactive backend
            # plt.show() does nothing, so without this every call would leak
            # an open figure.
            plt.close(fig)

    def predict_regions(self, path) -> List[Dict]:
        """Predict regions in the image using the YOLO model.

        Args:
            path (str): Path to the file with media

        Raises:
            NotImplementedError: Always, in the base class; subclasses must override.
        """
        raise NotImplementedError("This method should be overridden in derived classes")

    def fit(self, event, data, **kwargs):
        """Fit the model.

        Training is not implemented in the base class: this logs a warning and
        returns False.
        """
        logger.warning("The fit method is not implemented for this control model")
        return False

    def get_path(self, task):
        """Resolve a local file path for the media referenced by ``task``.

        The media reference is ``task["data"][self.value]``, falling back to
        ``task["data"][DATA_UNDEFINED_NAME]`` when the first is missing or falsy.
        An existing local path is returned as-is; otherwise the backend's
        ``get_local_path`` resolves it (e.g. by downloading from Label Studio).

        Raises:
            ValueError: If no reference is found, or it is not a string.
        """
        task_path = task["data"].get(self.value) or task["data"].get(
            DATA_UNDEFINED_NAME
        )
        if task_path is None:
            raise ValueError(
                f"Can't load path using key '{self.value}' from task {task}"
            )
        if not isinstance(task_path, str):
            raise ValueError(f"Path should be a string, but got {task_path}")

        # try path as local file or try to load it from Label Studio instance/download via http
        path = (
            task_path
            if os.path.exists(task_path)
            else self.label_studio_ml_backend.get_local_path(
                task_path, task_id=task.get("id")
            )
        )
        logger.debug(f"load_image: {task_path} => {path}")
        return path

    def __str__(self):
        """Return a string with full representation of the control tag."""
        return (
            f"{self.type} from_name={self.from_name}, "
            f"label_map={self.label_map}, model_score_threshold={self.model_score_threshold}"
        )

    class Config:
        # Allow field types that are not pydantic models (e.g. YOLO, LabelStudioMLBase).
        arbitrary_types_allowed = True
        # Disable pydantic's protected namespaces so fields starting with "model_"
        # (e.g. model_score_threshold) are accepted without warnings. Pydantic
        # matches string entries as plain prefixes, not regexes, so an empty
        # tuple is the explicit way to turn the check off.
        protected_namespaces = ()
|