File size: 3,357 Bytes
6307f85
 
e76a972
 
6307f85
a3dfe3f
6307f85
e76a972
 
a3dfe3f
6307f85
 
 
 
a3dfe3f
 
 
 
 
 
 
e76a972
a3dfe3f
e76a972
6307f85
 
a3dfe3f
 
 
 
 
 
6307f85
e76a972
6307f85
 
e76a972
 
6307f85
 
 
 
 
e76a972
 
 
 
 
a3dfe3f
 
 
e76a972
 
6307f85
 
e76a972
 
 
 
 
 
6307f85
 
 
 
 
e76a972
6307f85
 
 
a3dfe3f
e76a972
 
6307f85
 
 
e76a972
 
 
a3dfe3f
 
e76a972
 
 
 
a3dfe3f
6307f85
e76a972
 
 
6307f85
 
 
e76a972
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
import os
from typing import List, Dict, Optional

from label_studio_converter import brush
from label_studio_ml.model import LabelStudioMLBase
from uuid import uuid4

from sam_predictor import SAMPredictor

# Which SAM variant to load; overridable through the SAM_CHOICE env var.
SAM_CHOICE = os.getenv("SAM_CHOICE", "MobileSAM")
# Single predictor instance, built once when this module is imported and
# shared by every backend instance.
PREDICTOR = SAMPredictor(SAM_CHOICE)


class SamMLBackend(LabelStudioMLBase):
    """Label Studio ML backend that turns interactive SAM prompts into brush masks.

    Keypoint and rectangle results sent by Label Studio as interactive
    ``context`` are converted into SAM prompts. Every mask returned by the
    module-level ``PREDICTOR`` becomes an RLE-encoded ``brushlabels`` result.
    """

    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Create the backend, making sure ``model_dir`` exists first.

        Args:
            project_id: Label Studio project id, forwarded to the base class.
            label_config: Labeling config XML, forwarded to the base class.
            **kwargs: Accepted for call compatibility; not forwarded.
        """
        # Make sure model_dir always exists, even if the backend package
        # does not initialize it correctly.
        self.model_dir = os.environ.get("MODEL_DIR", "/tmp/mlbackend")
        os.makedirs(self.model_dir, exist_ok=True)
        super().__init__(project_id=project_id, label_config=label_config)

    def setup(self):
        """Store the model version string derived from ``SAM_CHOICE``."""
        # Mark the model as initialized
        self.set("model_version", f"{SAM_CHOICE}-v1")

    @staticmethod
    def _collect_prompts(context_results, image_width, image_height):
        """Translate keypoint/rectangle context results into SAM prompt inputs.

        Label Studio stores coordinates as percentages of the image size, so
        they are converted to absolute (truncated) pixels here. Results of any
        other type are skipped because they carry no point/box to prompt with.
        When several prompts are present, the label of the last supported one
        wins, and the last rectangle is used as the box.

        Args:
            context_results: The ``context["result"]`` list from Label Studio.
            image_width: Original image width in pixels.
            image_height: Original image height in pixels.

        Returns:
            ``(point_coords, point_labels, input_box, selected_label)``.
            ``point_coords`` is a list of ``[x, y]`` pixel pairs.
            ``point_labels`` holds ``int(is_positive)`` for each keypoint.
            ``input_box`` is ``[x1, y1, x2, y2]`` or ``None``.
            ``selected_label`` is the chosen label name, or ``None``.
        """
        point_coords = []
        point_labels = []
        input_box = None
        selected_label = None

        for ctx in context_results:
            ctx_type = ctx.get("type")
            if ctx_type not in ("keypointlabels", "rectanglelabels"):
                # e.g. brush or choice results: no x/y to convert, so skip them
                # instead of failing with a KeyError.
                continue

            selected_label = ctx["value"][ctx_type][0]

            # Percent-of-image -> absolute pixel coordinates.
            x = ctx["value"]["x"] * image_width / 100.0
            y = ctx["value"]["y"] * image_height / 100.0

            if ctx_type == "keypointlabels":
                point_labels.append(int(ctx["is_positive"]))
                point_coords.append([int(x), int(y)])
            else:  # rectanglelabels
                box_width = ctx["value"]["width"] * image_width / 100.0
                box_height = ctx["value"]["height"] * image_height / 100.0
                input_box = [int(x), int(y), int(x + box_width), int(y + box_height)]

        return point_coords, point_labels, input_box, selected_label

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
        """Run SAM on the first task's image using the interactive context prompts.

        Args:
            tasks: Tasks from Label Studio. Only ``tasks[0]`` is used; its
                ``data`` must hold the image reference under ``"image"``.
            context: Interactive-annotation context whose ``"result"`` list
                holds the keypoint / rectangle prompts drawn by the user.
            **kwargs: Ignored.

        Returns:
            A one-element list with the brush results and the model version.
            Returns an empty list when there is no task, no context result, or
            no keypoint/rectangle prompt to feed to SAM.
        """
        # Hard-code these to match your current Label Studio XML:
        # <BrushLabels name="tag" toName="image">
        # <Image name="image" value="$image" ... />
        from_name = "tag"
        to_name = "image"
        value = "image"

        if not tasks or not context or not context.get("result"):
            return []

        image_width = context["result"][0]["original_width"]
        image_height = context["result"][0]["original_height"]

        point_coords, point_labels, input_box, selected_label = self._collect_prompts(
            context["result"], image_width, image_height
        )
        if not point_coords and input_box is None:
            # No point and no box: calling the predictor with no prompt at
            # all cannot produce a meaningful mask.
            return []

        img_path = tasks[0]["data"][value]

        predictor_results = PREDICTOR.predict(
            img_path=img_path,
            point_coords=point_coords or None,
            point_labels=point_labels or None,
            input_box=input_box,
        )

        results = []
        for mask, prob in zip(predictor_results["masks"], predictor_results["probs"]):
            label_id = str(uuid4())[:8]
            # Scale to uint8 before RLE encoding (assumes a 0/1 or boolean
            # mask from SAMPredictor — confirm).
            mask = (mask * 255).astype("uint8")
            rle = brush.mask2rle(mask)

            results.append({
                "id": label_id,
                "from_name": from_name,
                "to_name": to_name,
                "original_width": image_width,
                "original_height": image_height,
                "image_rotation": 0,
                "value": {
                    "format": "rle",
                    "rle": rle,
                    "brushlabels": [selected_label],
                },
                "score": float(prob),
                "type": "brushlabels",
                "readonly": False,
            })

        return [{
            "result": results,
            "model_version": self.get("model_version"),
        }]