"""Label Studio ML backend that turns interactive prompts into SAM brush masks.

Keypoint and rectangle regions received in the prediction ``context`` are
converted to pixel-space prompts, passed to a shared ``SAMPredictor`` and the
returned masks are sent back as RLE-encoded ``brushlabels`` regions.
"""
import os
from typing import List, Dict, Optional
from uuid import uuid4

from label_studio_converter import brush
from label_studio_ml.model import LabelStudioMLBase

from sam_predictor import SAMPredictor

# Which SAM variant to load; overridable through the environment.
SAM_CHOICE = os.environ.get("SAM_CHOICE", "MobileSAM")
# Built once at import time and shared by every backend instance.
PREDICTOR = SAMPredictor(SAM_CHOICE)

# Region types that can be turned into a predictor prompt.
_PROMPT_TYPES = ("keypointlabels", "rectanglelabels")


class SamMLBackend(LabelStudioMLBase):
    """ML backend producing brush masks from keypoint / rectangle prompts."""

    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Create the backend and make sure its model directory exists.

        Args:
            project_id: Forwarded to ``LabelStudioMLBase``.
            label_config: Forwarded to ``LabelStudioMLBase``.
            **kwargs: Accepted for caller compatibility; not forwarded to
                the base class.
        """
        # Make sure model_dir always exists, even if the backend package
        # does not initialize it correctly.
        self.model_dir = os.environ.get("MODEL_DIR", "/tmp/mlbackend")
        os.makedirs(self.model_dir, exist_ok=True)
        super().__init__(project_id=project_id, label_config=label_config)

    def setup(self):
        """Record the model version string derived from ``SAM_CHOICE``."""
        # Mark the model as initialized
        self.set("model_version", f"{SAM_CHOICE}-v1")

    def predict(
        self,
        tasks: List[Dict],
        context: Optional[Dict] = None,
        **kwargs,
    ) -> List[Dict]:
        """Return brush-mask predictions for the prompts found in ``context``.

        Only the first task's image is used. Prompt coordinates in
        ``context["result"]`` are treated as percentages of the image size and
        converted to integer pixel coordinates.

        Args:
            tasks: Task dicts; ``tasks[0]["data"]["image"]`` is the image
                path handed to the predictor.
            context: Dict whose ``"result"`` list holds the prompt regions
                (``keypointlabels`` and/or ``rectanglelabels``).
            **kwargs: Unused.

        Returns:
            A one-element list ``[{"result": [...], "model_version": ...}]``,
            or an empty list when there are no tasks, no context, or no
            usable prompt in the context.

        Raises:
            KeyError: If a keypoint or rectangle entry lacks an expected key
                (for example ``is_positive`` on a keypoint).
        """
        # Hard-code these to match your current Label Studio XML:
        # <BrushLabels name="tag" toName="image"> and <Image name="image" value="$image"/>
        # NOTE(review): the tag names above are inferred from the values
        # below — confirm against the project's labeling config.
        from_name = "tag"
        to_name = "image"
        value = "image"

        # Nothing to predict on, or no interaction has been sent yet.
        if not tasks or not context or not context.get("result"):
            return []

        regions = context["result"]
        image_width = regions[0]["original_width"]
        image_height = regions[0]["original_height"]

        point_coords = []
        point_labels = []
        input_box = None
        selected_label = None
        for ctx in regions:
            ctx_type = ctx["type"]
            if ctx_type not in _PROMPT_TYPES:
                # Other region types carry no point/box geometry to prompt
                # with; reading "x"/"y" from them would raise KeyError.
                continue

            ctx_value = ctx["value"]
            # The label of the last prompt is applied to every output mask.
            selected_label = ctx_value[ctx_type][0]
            # Percent-of-image -> pixel coordinates.
            x = ctx_value["x"] * image_width / 100.0
            y = ctx_value["y"] * image_height / 100.0

            if ctx_type == "keypointlabels":
                # int(bool) yields 1 for a positive click, 0 for a negative one.
                point_labels.append(int(ctx["is_positive"]))
                point_coords.append([int(x), int(y)])
            else:  # "rectanglelabels"
                box_width = ctx_value["width"] * image_width / 100.0
                box_height = ctx_value["height"] * image_height / 100.0
                # Box as [x_min, y_min, x_max, y_max]; a later rectangle
                # replaces an earlier one.
                input_box = [
                    int(x),
                    int(y),
                    int(x + box_width),
                    int(y + box_height),
                ]

        # Without any point or box there is nothing to segment.
        if not point_coords and input_box is None:
            return []

        img_path = tasks[0]["data"][value]

        predictor_results = PREDICTOR.predict(
            img_path=img_path,
            point_coords=point_coords or None,
            point_labels=point_labels or None,
            input_box=input_box,
        )

        results = []
        for mask, prob in zip(
            predictor_results["masks"], predictor_results["probs"]
        ):
            # Short random id for the region.
            label_id = str(uuid4())[:8]
            # Scale the mask to 0/255 uint8 before RLE encoding.
            # NOTE(review): assumes the predictor returns 0/1 (or boolean)
            # masks — confirm against SAMPredictor.
            mask = (mask * 255).astype("uint8")
            rle = brush.mask2rle(mask)

            results.append({
                "id": label_id,
                "from_name": from_name,
                "to_name": to_name,
                "original_width": image_width,
                "original_height": image_height,
                "image_rotation": 0,
                "value": {
                    "format": "rle",
                    "rle": rle,
                    "brushlabels": [selected_label],
                },
                "score": float(prob),
                "type": "brushlabels",
                "readonly": False,
            })

        return [{
            "result": results,
            "model_version": self.get("model_version"),
        }]