# Hugging Face Spaces page metadata (scrape residue): Space status "Sleeping";
# original file size: 3,357 bytes.
# (Scrape residue removed: commit-hash gutter 6307f85 / e76a972 / a3dfe3f and line numbers.)
import os
from typing import List, Dict, Optional
from label_studio_converter import brush
from label_studio_ml.model import LabelStudioMLBase
from uuid import uuid4
from sam_predictor import SAMPredictor
# Which SAM variant to load; defaults to the lightweight MobileSAM model.
SAM_CHOICE = os.getenv("SAM_CHOICE", "MobileSAM")
# Single shared predictor instance, constructed once at import time so that
# model weights are loaded only once per process.
PREDICTOR = SAMPredictor(SAM_CHOICE)
class SamMLBackend(LabelStudioMLBase):
    """Label Studio ML backend serving Segment Anything (SAM) predictions.

    Interactive keypoint and rectangle annotations from the Label Studio UI
    are converted into SAM prompts; the predicted masks are returned as
    RLE-encoded brush-label results.
    """

    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Create the backend and guarantee a writable model directory.

        ``model_dir`` is set *before* ``super().__init__`` because the
        backend package does not always initialize it correctly.
        """
        self.model_dir = os.environ.get("MODEL_DIR", "/tmp/mlbackend")
        os.makedirs(self.model_dir, exist_ok=True)
        super().__init__(project_id=project_id, label_config=label_config)

    def setup(self):
        """Record the model version string, marking the model as initialized."""
        self.set("model_version", f"{SAM_CHOICE}-v1")

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
        """Predict brush masks for the first task using the interactive context.

        Args:
            tasks: Label Studio tasks; only ``tasks[0]`` is used.
            context: Interactive annotation context whose ``result`` entries
                (keypoints and/or a rectangle) become the SAM prompts.

        Returns:
            A one-element list containing the prediction dict (``result`` and
            ``model_version`` keys), or ``[]`` when there is nothing to predict.
        """
        # Hard-coded to match the current Label Studio XML config:
        #   <BrushLabels name="tag" toName="image">
        #   <Image name="image" value="$image" ... />
        from_name = "tag"
        to_name = "image"
        value = "image"

        # Interactive pre-annotation needs both a task and context results.
        # (Original code guarded context but would IndexError on empty tasks.)
        if not tasks or not context or not context.get("result"):
            return []

        image_width = context["result"][0]["original_width"]
        image_height = context["result"][0]["original_height"]

        point_coords, point_labels, input_box, selected_label = self._extract_prompts(
            context["result"], image_width, image_height
        )

        img_path = tasks[0]["data"][value]
        predictor_results = PREDICTOR.predict(
            img_path=img_path,
            point_coords=point_coords or None,
            point_labels=point_labels or None,
            input_box=input_box,
        )

        return [{
            "result": self._build_results(
                predictor_results, from_name, to_name,
                image_width, image_height, selected_label,
            ),
            "model_version": self.get("model_version"),
        }]

    @staticmethod
    def _extract_prompts(ctx_results, image_width, image_height):
        """Convert percent-based context annotations into pixel-space SAM prompts.

        Returns ``(point_coords, point_labels, input_box, selected_label)``.
        Context entries without a matching label list are skipped instead of
        raising ``KeyError``.
        """
        point_coords = []
        point_labels = []
        input_box = None
        selected_label = None
        for ctx in ctx_results:
            ctx_type = ctx["type"]
            labels = ctx["value"].get(ctx_type)
            if not labels:
                # Not a labeled keypoint/rectangle result; ignore it.
                continue
            selected_label = labels[0]
            # Label Studio stores coordinates as percentages of image size.
            x = ctx["value"]["x"] * image_width / 100.0
            y = ctx["value"]["y"] * image_height / 100.0
            if ctx_type == "keypointlabels":
                point_labels.append(int(ctx["is_positive"]))
                point_coords.append([int(x), int(y)])
            elif ctx_type == "rectanglelabels":
                box_width = ctx["value"]["width"] * image_width / 100.0
                box_height = ctx["value"]["height"] * image_height / 100.0
                # SAM expects the box as [x_min, y_min, x_max, y_max].
                input_box = [int(x), int(y), int(x + box_width), int(y + box_height)]
        return point_coords, point_labels, input_box, selected_label

    @staticmethod
    def _build_results(predictor_results, from_name, to_name,
                       image_width, image_height, selected_label):
        """Encode predicted masks as RLE brush-label result dicts."""
        results = []
        for mask, prob in zip(predictor_results["masks"], predictor_results["probs"]):
            # A short random id is enough to distinguish regions within a task.
            label_id = str(uuid4())[:8]
            # brush.mask2rle expects a uint8 image scaled to [0, 255].
            mask = (mask * 255).astype("uint8")
            rle = brush.mask2rle(mask)
            results.append({
                "id": label_id,
                "from_name": from_name,
                "to_name": to_name,
                "original_width": image_width,
                "original_height": image_height,
                "image_rotation": 0,
                "value": {
                    "format": "rle",
                    "rle": rle,
                    "brushlabels": [selected_label],
                },
                "score": float(prob),
                "type": "brushlabels",
                "readonly": False,
            })
        return results