Spaces:
Sleeping
Sleeping
| import os | |
| from typing import List, Dict, Optional | |
| from label_studio_converter import brush | |
| from label_studio_ml.model import LabelStudioMLBase | |
| from uuid import uuid4 | |
| from sam_predictor import SAMPredictor | |
# SAM variant to load; override via the SAM_CHOICE environment variable.
SAM_CHOICE = os.getenv("SAM_CHOICE", "MobileSAM")
# One shared predictor, constructed at import time so every request reuses it
# (presumably this is where the model weights are loaded — confirm in sam_predictor).
PREDICTOR = SAMPredictor(SAM_CHOICE)
class SamMLBackend(LabelStudioMLBase):
    """Label Studio ML backend that segments images with SAM from interactive prompts.

    Keypoint and rectangle regions arriving in the interactive ``context`` are
    converted to pixel-space prompts, sent to the module-level ``PREDICTOR``,
    and every returned mask is encoded as an RLE ``brushlabels`` result.
    """

    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Create the backend, making sure ``model_dir`` exists beforehand.

        :param project_id: Label Studio project id, forwarded to the base class.
        :param label_config: labeling config XML, forwarded to the base class.
        :param kwargs: accepted for compatibility; not forwarded anywhere.
        """
        # Make sure model_dir always exists, even if the backend package
        # does not initialize it correctly.
        self.model_dir = os.environ.get("MODEL_DIR", "/tmp/mlbackend")
        os.makedirs(self.model_dir, exist_ok=True)
        super().__init__(project_id=project_id, label_config=label_config)

    def setup(self):
        """Store the model version string ``"<SAM_CHOICE>-v1"`` on the backend.

        NOTE(review): presumably invoked by LabelStudioMLBase during
        initialisation — confirm against the installed label_studio_ml version.
        """
        # Mark the model as initialized
        self.set("model_version", f"{SAM_CHOICE}-v1")

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
        """Predict brush masks for the first task from interactive prompts.

        :param tasks: Label Studio tasks; only ``tasks[0]`` is used, and its
            ``data["image"]`` value is passed to the predictor as ``img_path``.
        :param context: interactive-annotation context whose ``"result"`` list
            holds ``keypointlabels`` / ``rectanglelabels`` regions.
        :param kwargs: ignored.
        :return: a one-element list with the brush results and model version,
            or ``[]`` when there are no tasks, no context, or no usable prompt.
        """
        # Hard-code these to match your current Label Studio XML:
        #   <BrushLabels name="tag" toName="image">
        #   <Image name="image" value="$image" ... />
        from_name = "tag"
        to_name = "image"
        value = "image"

        if not tasks or not context or not context.get("result"):
            return []

        regions = context["result"]
        image_width = regions[0]["original_width"]
        image_height = regions[0]["original_height"]

        point_coords, point_labels, input_box, selected_label = self._collect_prompts(
            regions, image_width, image_height
        )
        if not point_coords and input_box is None:
            # No keypoint or rectangle to prompt with; don't run SAM unprompted.
            return []

        predictor_results = PREDICTOR.predict(
            img_path=tasks[0]["data"][value],
            point_coords=point_coords or None,
            point_labels=point_labels or None,
            input_box=input_box,
        )

        results = [
            self._mask_to_result(
                mask, prob, from_name, to_name, image_width, image_height, selected_label
            )
            for mask, prob in zip(predictor_results["masks"], predictor_results["probs"])
        ]
        return [{
            "result": results,
            "model_version": self.get("model_version"),
        }]

    @staticmethod
    def _collect_prompts(regions, image_width, image_height):
        """Convert keypoint/rectangle regions into pixel-space SAM prompts.

        Returns ``(point_coords, point_labels, input_box, selected_label)``.
        Regions of any other type are skipped. The label comes from the last
        supported region, and if several rectangles are given only the last
        one is kept as ``input_box``.
        """
        point_coords = []
        point_labels = []
        input_box = None
        selected_label = None
        for region in regions:
            region_type = region.get("type")
            if region_type not in ("keypointlabels", "rectanglelabels"):
                # Other region types provide no point or box usable as a prompt.
                continue
            region_value = region["value"]
            selected_label = region_value[region_type][0]
            # Label Studio stores coordinates as percentages of the image size.
            x = region_value["x"] * image_width / 100.0
            y = region_value["y"] * image_height / 100.0
            if region_type == "keypointlabels":
                # SAM point labels: 1 = foreground, 0 = background. A point
                # without the flag is treated as positive.
                # NOTE(review): upstream SAM example defaults to 0 — confirm.
                point_labels.append(int(region.get("is_positive", True)))
                point_coords.append([int(x), int(y)])
            else:
                box_width = region_value["width"] * image_width / 100.0
                box_height = region_value["height"] * image_height / 100.0
                # Box as [x_min, y_min, x_max, y_max] in pixels.
                input_box = [int(x), int(y), int(x + box_width), int(y + box_height)]
        return point_coords, point_labels, input_box, selected_label

    @staticmethod
    def _mask_to_result(mask, prob, from_name, to_name, image_width, image_height, label):
        """Encode one predicted mask as a Label Studio RLE ``brushlabels`` result."""
        # Scale the mask to 0..255 and cast to uint8 before RLE encoding.
        rle = brush.mask2rle((mask * 255).astype("uint8"))
        return {
            "id": str(uuid4())[:8],
            "from_name": from_name,
            "to_name": to_name,
            "original_width": image_width,
            "original_height": image_height,
            "image_rotation": 0,
            "value": {
                "format": "rle",
                "rle": rle,
                "brushlabels": [label],
            },
            "score": float(prob),
            "type": "brushlabels",
            "readonly": False,
        }