Spaces:
Paused
Paused
| import os | |
| import logging | |
| from pydantic import BaseModel | |
| from typing import Optional, List, Dict, ClassVar | |
| from ultralytics import YOLO | |
| from label_studio_ml.model import LabelStudioMLBase | |
| from label_studio_ml.utils import DATA_UNDEFINED_NAME | |
| from label_studio_sdk.label_interface.control_tags import ControlTag | |
| from label_studio_sdk.label_interface import LabelInterface | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# When enabled, control models display intermediate images with matplotlib.
DEBUG_PLOT = os.environ.get("DEBUG_PLOT", "false").lower() in ("1", "true")

# Fallback score threshold, used when a control tag does not set its own.
MODEL_SCORE_THRESHOLD = float(os.environ.get("MODEL_SCORE_THRESHOLD", 0.5))

# Model files live in "<parent of this file's directory>/models" unless
# the MODEL_ROOT environment variable points somewhere else.
DEFAULT_MODEL_ROOT = os.path.join(
    os.path.dirname(os.path.dirname(__file__)),
    "models",
)
MODEL_ROOT = os.environ.get("MODEL_ROOT", DEFAULT_MODEL_ROOT)
# Import-time side effect: make sure the model directory exists.
os.makedirs(MODEL_ROOT, exist_ok=True)

# If true, a control tag in the labeling config may supply its own
# `model_path` attribute instead of the class default.
ALLOW_CUSTOM_MODEL_PATH = os.environ.get(
    "ALLOW_CUSTOM_MODEL_PATH", "true"
).lower() in ("1", "true")

# Process-wide cache of loaded YOLO models, keyed by model path.
_model_cache = dict()
def get_bool(attr, attr_name, default="false"):
    """Interpret an attribute value as a boolean flag.

    Args:
        attr: Mapping-like object (anything with ``.get``) holding attributes.
        attr_name: Key to look up in ``attr``.
        default: Value used when ``attr_name`` is absent.

    Returns:
        bool: True when the value, lower-cased, is "1", "true" or "yes";
        False otherwise.
    """
    raw_value = attr.get(attr_name, default)
    # Coerce to str first so non-string values (True, 1, None) are handled
    # instead of failing on the missing `.lower()` method.
    return str(raw_value).lower() in ("1", "true", "yes")
class ControlModel(BaseModel):
    """
    Represents a control tag in Label Studio, which is associated with a specific type of labeling task
    and is used to generate predictions using a YOLO model.

    Attributes:
        type (str): Type of the control, e.g., RectangleLabels, Choices, etc. (class-level).
        control (ControlTag): The actual control element from the Label Studio configuration.
        from_name (str): The name of the control tag, used to link the control to the data.
        to_name (str): The name of the data field that this control is associated with.
        value (str): The value name from the object that this control operates on, e.g., an image or text field.
        model (YOLO): The model instance used to generate predictions for this control.
        model_path (str): Path to the YOLO model file (class-level default).
        model_score_threshold (float): Threshold for prediction scores; predictions below this value will be ignored.
        label_map (Optional[Dict[str, str]]): A mapping of model labels to Label Studio labels.
        label_studio_ml_backend (LabelStudioMLBase): The ML backend that owns this control model.
        project_id (Optional[str]): Project id copied from the ML backend, if any.
    """

    type: ClassVar[str]
    control: ControlTag
    from_name: str
    to_name: str
    value: str
    model: YOLO
    model_path: ClassVar[str]
    model_score_threshold: float = 0.5
    label_map: Optional[Dict[str, str]] = {}
    label_studio_ml_backend: LabelStudioMLBase
    project_id: Optional[str] = None

    def __init__(self, **data):
        super().__init__(**data)

    @classmethod
    def is_control_matched(cls, control) -> bool:
        """Check if the control tag matches the model type.

        Args:
            control (ControlTag): The control tag from the Label Studio Interface.

        Raises:
            NotImplementedError: Always; derived classes must override this.
        """
        raise NotImplementedError("This method should be overridden in derived classes")

    @staticmethod
    def get_from_name_for_label_map(
        label_interface: LabelInterface, target_name: str
    ) -> str:
        """Get the 'from_name' attribute for the label map building.

        The base implementation returns ``target_name`` unchanged and does not
        use ``label_interface``.
        """
        return target_name

    @classmethod
    def create(cls, mlbackend: LabelStudioMLBase, control: ControlTag):
        """Factory method to create an instance of a specific control model class.

        Args:
            mlbackend (LabelStudioMLBase): The ML backend instance.
            control (ControlTag): The control tag from the Label Studio Interface.

        Returns:
            An instance of ``cls``, or None when the control tag has
            ``model_skip`` set to a true value.
        """
        from_name = control.name
        to_name = control.to_name[0]
        value = control.objects[0].value_name

        # if skip is true, don't process this control
        if get_bool(control.attr, "model_skip", "false"):
            logger.info(
                f"Skipping control tag '{control.tag}' with name '{from_name}', model_skip=true found"
            )
            return None

        # read threshold attribute from the control tag, e.g.: <RectangleLabels model_score_threshold="0.5">
        model_score_threshold = float(
            control.attr.get("model_score_threshold")
            or control.attr.get(
                "score_threshold"
            )  # not recommended option, use `model_score_threshold`
            or MODEL_SCORE_THRESHOLD
        )

        # read `model_path` attribute from the control tag; it is honoured only
        # when custom paths are allowed, otherwise the class default is used
        custom_model_path = (
            control.attr.get("model_path") if ALLOW_CUSTOM_MODEL_PATH else None
        )
        model_path = custom_model_path or cls.model_path

        model = cls.get_cached_model(model_path)
        model_names = model.names.values()  # class names from the model

        # from_name for label mapping can be differed from control.name (e.g. VideoRectangle)
        label_map_from_name = cls.get_from_name_for_label_map(
            mlbackend.label_interface, from_name
        )
        label_map = mlbackend.build_label_map(label_map_from_name, model_names)

        return cls(
            control=control,
            from_name=from_name,
            to_name=to_name,
            value=value,
            model=model,
            model_score_threshold=model_score_threshold,
            label_map=label_map,
            label_studio_ml_backend=mlbackend,
            project_id=mlbackend.project_id,
        )

    @classmethod
    def load_yolo_model(cls, filename) -> YOLO:
        """Load YOLO model from the file.

        Args:
            filename: Model file name, joined onto ``MODEL_ROOT``.
        """
        path = os.path.join(MODEL_ROOT, filename)
        logger.info(f"Loading yolo model: {path}")
        model = YOLO(path)
        logger.info(f"Model {path} names:\n{model.names}")
        return model

    @classmethod
    def get_cached_model(cls, path: str) -> YOLO:
        """Return the model for ``path``, loading it on first use.

        Loaded models are kept in the module-level ``_model_cache`` so the same
        path is loaded only once per process.
        """
        if path not in _model_cache:
            _model_cache[path] = cls.load_yolo_model(path)
        return _model_cache[path]

    def debug_plot(self, image):
        """Show ``image`` in a matplotlib figure when ``DEBUG_PLOT`` is enabled.

        Does nothing when ``DEBUG_PLOT`` is false.
        """
        if not DEBUG_PLOT:
            return

        # imported lazily so matplotlib is needed only when debugging
        import matplotlib.pyplot as plt

        plt.figure(figsize=(10, 10))
        # reverses the last axis; presumably BGR -> RGB channel order (TODO confirm)
        plt.imshow(image[..., ::-1])
        plt.axis("off")
        plt.title(self.type)
        plt.show()

    def predict_regions(self, path) -> List[Dict]:
        """Predict regions in the image using the YOLO model.

        Args:
            path (str): Path to the file with media

        Raises:
            NotImplementedError: Always; derived classes must override this.
        """
        raise NotImplementedError("This method should be overridden in derived classes")

    def fit(self, event, data, **kwargs):
        """Fit the model.

        The base implementation only logs a warning and returns False.
        """
        logger.warning("The fit method is not implemented for this control model")
        return False

    def get_path(self, task):
        """Resolve the media path for a task.

        Args:
            task (dict): Task with a "data" dict; the path is read from the key
                ``self.value``, falling back to ``DATA_UNDEFINED_NAME``.

        Returns:
            The task path itself when it exists on the local filesystem,
            otherwise whatever the ML backend's ``get_local_path`` returns.

        Raises:
            ValueError: If no path is found in the task data, or it is not a string.
        """
        task_path = task["data"].get(self.value) or task["data"].get(
            DATA_UNDEFINED_NAME
        )
        if task_path is None:
            raise ValueError(
                f"Can't load path using key '{self.value}' from task {task}"
            )
        if not isinstance(task_path, str):
            raise ValueError(f"Path should be a string, but got {task_path}")

        # try path as local file or try to load it from Label Studio instance/download via http
        path = (
            task_path
            if os.path.exists(task_path)
            else self.label_studio_ml_backend.get_local_path(task_path, task_id=task.get("id"))
        )
        logger.debug(f"load_image: {task_path} => {path}")
        return path

    def __str__(self):
        """Return a string with full representation of the control tag."""
        return (
            f"{self.type} from_name={self.from_name}, "
            f"label_map={self.label_map}, model_score_threshold={self.model_score_threshold}"
        )

    class Config:
        # non-pydantic field types (ControlTag, YOLO, LabelStudioMLBase) are used above
        arbitrary_types_allowed = True
        protected_namespaces = ("__.*__", "_.*")  # Excludes 'model_'