CUSL-Jeremy commited on
Commit
a3dfe3f
·
verified ·
1 Parent(s): 7008c03

Update model.py

Browse files
Files changed (1) hide show
  1. model.py +23 -23
model.py CHANGED
@@ -1,24 +1,35 @@
1
  import os
2
  from typing import List, Dict, Optional
3
- from uuid import uuid4
4
 
5
  from label_studio_converter import brush
6
  from label_studio_ml.model import LabelStudioMLBase
 
7
 
8
  from sam_predictor import SAMPredictor
9
 
10
- SAM_CHOICE = os.getenv("SAM_CHOICE", "MobileSAM")
11
  PREDICTOR = SAMPredictor(SAM_CHOICE)
12
 
13
 
14
  class SamMLBackend(LabelStudioMLBase):
 
 
 
 
 
 
 
15
  def setup(self):
16
- # Explicitly mark the model as initialized for the current backend API
17
  self.set("model_version", f"{SAM_CHOICE}-v1")
18
 
19
  def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
20
- # Newer backend API: parsed config is exposed via self.label_interface
21
- from_name, to_name, value = self.label_interface.get_first_tag_occurence("BrushLabels", "Image")
 
 
 
 
22
 
23
  if not context or not context.get("result"):
24
  return []
@@ -32,11 +43,12 @@ class SamMLBackend(LabelStudioMLBase):
32
  selected_label = None
33
 
34
  for ctx in context["result"]:
35
- x = ctx["value"]["x"] * image_width / 100.0
36
- y = ctx["value"]["y"] * image_height / 100.0
37
  ctx_type = ctx["type"]
38
  selected_label = ctx["value"][ctx_type][0]
39
 
 
 
 
40
  if ctx_type == "keypointlabels":
41
  point_labels.append(int(ctx["is_positive"]))
42
  point_coords.append([int(x), int(y)])
@@ -55,20 +67,8 @@ class SamMLBackend(LabelStudioMLBase):
55
  input_box=input_box,
56
  )
57
 
58
- return self.get_results(
59
- masks=predictor_results["masks"],
60
- probs=predictor_results["probs"],
61
- width=image_width,
62
- height=image_height,
63
- from_name=from_name,
64
- to_name=to_name,
65
- label=selected_label,
66
- )
67
-
68
- def get_results(self, masks, probs, width, height, from_name, to_name, label):
69
  results = []
70
-
71
- for mask, prob in zip(masks, probs):
72
  label_id = str(uuid4())[:8]
73
  mask = (mask * 255).astype("uint8")
74
  rle = brush.mask2rle(mask)
@@ -77,13 +77,13 @@ class SamMLBackend(LabelStudioMLBase):
77
  "id": label_id,
78
  "from_name": from_name,
79
  "to_name": to_name,
80
- "original_width": width,
81
- "original_height": height,
82
  "image_rotation": 0,
83
  "value": {
84
  "format": "rle",
85
  "rle": rle,
86
- "brushlabels": [label],
87
  },
88
  "score": float(prob),
89
  "type": "brushlabels",
 
1
  import os
2
  from typing import List, Dict, Optional
 
3
 
4
  from label_studio_converter import brush
5
  from label_studio_ml.model import LabelStudioMLBase
6
+ from uuid import uuid4
7
 
8
  from sam_predictor import SAMPredictor
9
 
10
+ SAM_CHOICE = os.environ.get("SAM_CHOICE", "MobileSAM")
11
  PREDICTOR = SAMPredictor(SAM_CHOICE)
12
 
13
 
14
  class SamMLBackend(LabelStudioMLBase):
15
+ def __init__(self, project_id=None, label_config=None, **kwargs):
16
+ # Make sure model_dir always exists, even if the backend package
17
+ # does not initialize it correctly.
18
+ self.model_dir = os.environ.get("MODEL_DIR", "/tmp/mlbackend")
19
+ os.makedirs(self.model_dir, exist_ok=True)
20
+ super().__init__(project_id=project_id, label_config=label_config)
21
+
22
  def setup(self):
23
+ # Mark the model as initialized
24
  self.set("model_version", f"{SAM_CHOICE}-v1")
25
 
26
  def predict(self, tasks: List[Dict], context: Optional[Dict] = None, **kwargs) -> List[Dict]:
27
+ # Hard-code these to match your current Label Studio XML:
28
+ # <BrushLabels name="tag" toName="image">
29
+ # <Image name="image" value="$image" ... />
30
+ from_name = "tag"
31
+ to_name = "image"
32
+ value = "image"
33
 
34
  if not context or not context.get("result"):
35
  return []
 
43
  selected_label = None
44
 
45
  for ctx in context["result"]:
 
 
46
  ctx_type = ctx["type"]
47
  selected_label = ctx["value"][ctx_type][0]
48
 
49
+ x = ctx["value"]["x"] * image_width / 100.0
50
+ y = ctx["value"]["y"] * image_height / 100.0
51
+
52
  if ctx_type == "keypointlabels":
53
  point_labels.append(int(ctx["is_positive"]))
54
  point_coords.append([int(x), int(y)])
 
67
  input_box=input_box,
68
  )
69
 
 
 
 
 
 
 
 
 
 
 
 
70
  results = []
71
+ for mask, prob in zip(predictor_results["masks"], predictor_results["probs"]):
 
72
  label_id = str(uuid4())[:8]
73
  mask = (mask * 255).astype("uint8")
74
  rle = brush.mask2rle(mask)
 
77
  "id": label_id,
78
  "from_name": from_name,
79
  "to_name": to_name,
80
+ "original_width": image_width,
81
+ "original_height": image_height,
82
  "image_rotation": 0,
83
  "value": {
84
  "format": "rle",
85
  "rle": rle,
86
+ "brushlabels": [selected_label],
87
  },
88
  "score": float(prob),
89
  "type": "brushlabels",