# segment-anything / model.py
# CUSL-Jeremy's picture
# Update model.py
# a3dfe3f verified
import os
from typing import List, Dict, Optional
from label_studio_converter import brush
from label_studio_ml.model import LabelStudioMLBase
from uuid import uuid4
from sam_predictor import SAMPredictor
# Predictor selection: read from the SAM_CHOICE environment variable,
# falling back to "MobileSAM" when the variable is not set.
SAM_CHOICE = os.getenv("SAM_CHOICE", "MobileSAM")
# One predictor instance, built at import time from SAM_CHOICE and used by
# SamMLBackend.predict for every request.
PREDICTOR = SAMPredictor(SAM_CHOICE)
class SamMLBackend(LabelStudioMLBase):
    """Label Studio backend that turns interactive prompts into brush masks.

    ``predict`` reads keypoint and rectangle regions from the interactive
    ``context`` of a request, passes them as prompts to the module-level
    ``PREDICTOR`` and returns every mask it produces as an RLE-encoded
    ``brushlabels`` region.
    """

    # Context region types that are usable as prompts; anything else found in
    # ``context["result"]`` is ignored by ``predict``.
    _PROMPT_TYPES = ("keypointlabels", "rectanglelabels")

    def __init__(self, project_id=None, label_config=None, **kwargs):
        """Ensure the model directory exists, then initialise the base class.

        Args:
            project_id: Passed through to ``LabelStudioMLBase.__init__``.
            label_config: Passed through to ``LabelStudioMLBase.__init__``.
            **kwargs: Accepted so callers may pass extra options; they are
                not forwarded to the base class.
        """
        # Make sure model_dir always exists, even if the backend package
        # does not initialize it correctly.
        self.model_dir = os.environ.get("MODEL_DIR", "/tmp/mlbackend")
        os.makedirs(self.model_dir, exist_ok=True)
        super().__init__(project_id=project_id, label_config=label_config)

    def setup(self):
        """Store the model version string (``"<SAM_CHOICE>-v1"``)."""
        # Mark the model as initialized
        self.set("model_version", f"{SAM_CHOICE}-v1")

    def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
        """Return brush-mask predictions for the prompts found in ``context``.

        Args:
            tasks: Label Studio tasks; only the first one is used, and its
                ``data["image"]`` entry is handed to the predictor as the
                image path.
            context: Interactive annotation context. Each entry of
                ``context["result"]`` that is a keypoint or rectangle region
                becomes a prompt; its ``x``/``y``/``width``/``height`` values
                are scaled from percent of the image size to pixels.
            **kwargs: Ignored.

        Returns:
            An empty list when there are no tasks, no context results, or no
            usable prompt. Otherwise a one-element list holding a dict with
            ``"result"`` (one ``brushlabels`` region per mask) and
            ``"model_version"``.
        """
        # Hard-code these to match your current Label Studio XML:
        # <BrushLabels name="tag" toName="image">
        # <Image name="image" value="$image" ... />
        from_name = "tag"
        to_name = "image"
        value = "image"

        # Nothing to segment without a task, or before the user has
        # interacted with the image.
        if not tasks or not context or not context.get("result"):
            return []

        regions = context["result"]
        image_width = regions[0]["original_width"]
        image_height = regions[0]["original_height"]

        point_coords = []
        point_labels = []
        input_box = None
        selected_label = None
        for ctx in regions:
            ctx_type = ctx["type"]
            # Other region types do not carry the x/y prompt geometry read
            # below, so skip them instead of failing with a KeyError.
            if ctx_type not in self._PROMPT_TYPES:
                continue

            region_value = ctx["value"]
            # The label of the last prompt processed is applied to all masks.
            selected_label = region_value[ctx_type][0]
            x = region_value["x"] * image_width / 100.0
            y = region_value["y"] * image_height / 100.0

            if ctx_type == "keypointlabels":
                # A keypoint without an explicit flag is treated as negative
                # (0) rather than raising KeyError.
                point_labels.append(int(ctx.get("is_positive", 0)))
                point_coords.append([int(x), int(y)])
            elif ctx_type == "rectanglelabels":
                box_width = region_value["width"] * image_width / 100.0
                box_height = region_value["height"] * image_height / 100.0
                # Box as [x_min, y_min, x_max, y_max] in pixels; a later
                # rectangle replaces an earlier one.
                input_box = [int(x), int(y), int(x + box_width), int(y + box_height)]

        # No keypoint and no rectangle: there is nothing to prompt with.
        if not point_coords and input_box is None:
            return []

        img_path = tasks[0]["data"][value]
        predictor_results = PREDICTOR.predict(
            img_path=img_path,
            point_coords=point_coords or None,
            point_labels=point_labels or None,
            input_box=input_box,
        )

        results = []
        for mask, prob in zip(predictor_results["masks"], predictor_results["probs"]):
            label_id = str(uuid4())[:8]
            # Scale the mask to 0/255 uint8 before RLE encoding
            # (assumes masks are 0/1 NumPy-like arrays -- TODO confirm).
            mask = (mask * 255).astype("uint8")
            rle = brush.mask2rle(mask)
            results.append({
                "id": label_id,
                "from_name": from_name,
                "to_name": to_name,
                "original_width": image_width,
                "original_height": image_height,
                "image_rotation": 0,
                "value": {
                    "format": "rle",
                    "rle": rle,
                    "brushlabels": [selected_label],
                },
                "score": float(prob),
                "type": "brushlabels",
                "readonly": False,
            })

        return [{
            "result": results,
            "model_version": self.get("model_version"),
        }]