saliacoel committed on
Commit
d68f6df
·
verified ·
1 Parent(s): 2a5d451

Upload AILab_SAM3Segment.py

Browse files
Files changed (1) hide show
  1. AILab_SAM3Segment.py +459 -0
AILab_SAM3Segment.py ADDED
@@ -0,0 +1,459 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from contextlib import nullcontext
4
+ from pathlib import Path
5
+
6
+ import numpy as np
7
+ import torch
8
+ from PIL import Image, ImageFilter
9
+ from torch.hub import download_url_to_file
10
+
11
+ import folder_paths
12
+ import comfy.model_management
13
+
14
+ from AILab_ImageMaskTools import pil2tensor, tensor2pil
15
+
16
# Directory that contains this file; also used as the base of the fallback
# models directory in get_or_download_model_file().
CURRENT_DIR = os.path.dirname(__file__)

# "sam3" directory located next to this file.  It is prepended to sys.path
# (only once) before the sam3 imports below are attempted.
# NOTE(review): presumably a bundled copy of the SAM3 sources - confirm the
# on-disk layout that makes `import sam3` resolvable.
SAM3_LOCAL_DIR = os.path.join(CURRENT_DIR, "sam3")
if SAM3_LOCAL_DIR not in sys.path:
    sys.path.insert(0, SAM3_LOCAL_DIR)

# BPE vocabulary file later passed to build_sam3_image_model(bpe_path=...).
# Importing this module fails immediately when the file is missing.
SAM3_BPE_PATH = os.path.join(SAM3_LOCAL_DIR, "assets", "bpe_simple_vocab_16e6.txt.gz")
if not os.path.isfile(SAM3_BPE_PATH):
    raise RuntimeError("SAM3 assets missing; ensure sam3/assets/bpe_simple_vocab_16e6.txt.gz exists.")

# These imports must stay below the sys.path manipulation above (hence E402).
from sam3.model_builder import build_sam3_image_model  # noqa: E402
from sam3.model.sam3_image_processor import Sam3Processor  # noqa: E402

# Download URL and local file name of the default checkpoint.
_DEFAULT_PT_ENTRY = {
    "model_url": "https://huggingface.co/saliacoel/x/resolve/main/sam3.pt",
    "filename": "sam3.pt",
}

# Registry of selectable models: its keys populate the "sam3_model" dropdown
# and its values are consumed by SAM3Segment._load_processor().  The entry is
# a copy so that edits to it do not alter _DEFAULT_PT_ENTRY.
SAM3_MODELS = {
    "sam3": _DEFAULT_PT_ENTRY.copy(),
}
36
+
37
+
38
def get_sam3_pt_models():
    """Return ``{"sam3": entry}`` describing the ``.pt`` checkpoint to use.

    Resolution order:

    1. the ``"sam3"`` entry of ``SAM3_MODELS`` when its filename ends in ``.pt``;
    2. otherwise, scanning ``SAM3_MODELS`` in order, the first entry that
       either already has a ``.pt`` filename (returned as is) or whose key
       contains ``"sam3"`` (returned as a copy with URL and filename replaced
       by the defaults);
    3. otherwise a copy of ``_DEFAULT_PT_ENTRY``.
    """

    def _has_pt_filename(info):
        return info.get("filename", "").endswith(".pt")

    primary = SAM3_MODELS.get("sam3")
    if primary and _has_pt_filename(primary):
        return {"sam3": primary}

    # No usable primary entry: look for anything that can be upgraded to PT naming.
    for name, info in SAM3_MODELS.items():
        if _has_pt_filename(info):
            return {"sam3": info}
        if "sam3" in name and info:
            upgraded = info.copy()
            upgraded["model_url"] = _DEFAULT_PT_ENTRY["model_url"]
            upgraded["filename"] = _DEFAULT_PT_ENTRY["filename"]
            return {"sam3": upgraded}

    return {"sam3": _DEFAULT_PT_ENTRY.copy()}
53
+
54
+
55
def process_mask(mask_image, invert_output=False, mask_blur=0, mask_offset=0):
    """Post-process a mask image: optional inversion, blur, then grow/shrink.

    Args:
        mask_image: PIL image holding the mask.
        invert_output: when true, every pixel value ``v`` becomes ``255 - v``.
        mask_blur: Gaussian blur radius; applied only when greater than 0.
        mask_offset: positive values apply a max filter, negative values a
            min filter; 0 leaves the mask untouched.

    Returns:
        The processed mask (the very same object when no step applies).
    """
    processed = mask_image

    if invert_output:
        processed = Image.fromarray(255 - np.array(processed))

    if mask_blur > 0:
        processed = processed.filter(ImageFilter.GaussianBlur(radius=mask_blur))

    if mask_offset != 0:
        passes = abs(mask_offset)
        rank_filter = ImageFilter.MaxFilter if mask_offset > 0 else ImageFilter.MinFilter
        kernel = passes * 2 + 1
        # NOTE(review): a kernel of 2*|offset|+1 is applied |offset| times, so
        # the effect compounds well beyond |offset| pixels - confirm intended.
        for _ in range(passes):
            processed = processed.filter(rank_filter(kernel))

    return processed
67
+
68
+
69
def _parse_hex_rgb(color):
    """Convert a hex colour string into an ``(r, g, b)`` tuple of ints.

    Accepts ``RRGGBB`` and ``RRGGBBAA`` (alpha ignored) as well as the
    ``RGB`` / ``RGBA`` shorthand, with or without a leading ``#`` and
    surrounding whitespace.

    Raises:
        ValueError: if the value is not one of the accepted forms.
    """
    digits = str(color).strip().lstrip("#")
    if len(digits) in (3, 4):
        # Shorthand: every digit is doubled ("abc" -> "aabbcc").
        digits = "".join(ch * 2 for ch in digits)
    well_formed = len(digits) in (6, 8) and all(
        ch in "0123456789abcdefABCDEF" for ch in digits
    )
    if not well_formed:
        raise ValueError(
            f"Invalid background_color {color!r}; expected a hex colour such as '#222222'."
        )
    return tuple(int(digits[i:i + 2], 16) for i in (0, 2, 4))


def apply_background_color(image, mask_image, background="Alpha", background_color="#222222"):
    """Cut ``image`` out with ``mask_image`` and optionally flatten it on a colour.

    Args:
        image: source PIL image.
        mask_image: PIL mask; converted to mode ``"L"`` and used as alpha.
        background: ``"Color"`` composites the cut-out over a solid colour and
            returns an RGB image; any other value returns the RGBA cut-out.
        background_color: hex colour used only when ``background == "Color"``.

    Returns:
        An RGB image for ``"Color"``, otherwise an RGBA image.

    Raises:
        ValueError: if ``background == "Color"`` and ``background_color`` is
            not a valid hex colour.  Previously such values were either
            mis-parsed silently (e.g. 5 digits) or failed with an opaque
            ``int()`` error (e.g. ``"#abc"``).
    """
    # Validate the colour first so bad input fails before any image work is done.
    fill_rgb = _parse_hex_rgb(background_color) if background == "Color" else None

    # convert() returns a new image, so the caller's image is never modified.
    rgba_image = image.convert("RGBA")
    rgba_image.putalpha(mask_image.convert("L"))
    if fill_rgb is None:
        return rgba_image

    bg_image = Image.new("RGBA", image.size, fill_rgb + (255,))
    return Image.alpha_composite(bg_image, rgba_image).convert("RGB")
79
+
80
+
81
def get_or_download_model_file(filename, url):
    """Return a local path for ``filename``, downloading it from ``url`` if absent.

    The file is first looked up through ``folder_paths.get_full_path("sam3", ...)``
    (when that helper exists).  Otherwise the path
    ``<models_dir>/sam3/<filename>`` is used, where ``models_dir`` comes from
    ``folder_paths`` or falls back to a ``models`` directory next to this file;
    the directory is created as needed and the file is downloaded when nothing
    exists at that path yet.
    """
    if hasattr(folder_paths, "get_full_path"):
        registered = folder_paths.get_full_path("sam3", filename)
        if registered and os.path.isfile(registered):
            return registered

    fallback_root = os.path.join(CURRENT_DIR, "models")
    target_dir = os.path.join(getattr(folder_paths, "models_dir", fallback_root), "sam3")
    os.makedirs(target_dir, exist_ok=True)

    target = os.path.join(target_dir, filename)
    if os.path.exists(target):
        return target

    print(f"Downloading {filename} from {url} ...")
    download_url_to_file(url, target)
    return target
95
+
96
+
97
def _resolve_device(user_choice):
    """Translate the node's device dropdown value into a ``torch.device``.

    ``"CPU"`` always yields the CPU device.  ``"GPU"`` yields ``cuda`` but
    raises ``RuntimeError`` when the device reported by ComfyUI is not CUDA.
    Any other value returns the device reported by ComfyUI unchanged.
    """
    detected = comfy.model_management.get_torch_device()

    if user_choice == "CPU":
        return torch.device("cpu")
    if user_choice != "GPU":
        return detected

    # Explicit GPU request: only honoured when ComfyUI itself runs on CUDA.
    if detected.type == "cuda":
        return torch.device("cuda")
    raise RuntimeError("GPU unavailable")
106
+
107
+
108
class SAM3Segment:
    """ComfyUI node that segments images with SAM3 from a text prompt.

    For every image of the batch the masks returned by the SAM3 processor are
    merged into one mask, post-processed (invert / blur / offset) and used to
    cut the subject out of the image.

    When SAM3 returns no mask the node outputs an empty mask; with
    ``invert_output`` enabled that empty mask is inverted as well, so the
    result is consistent with the path where masks were found.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs (widget types, defaults and ranges)."""
        return {
            "required": {
                "image": ("IMAGE",),
                "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "Describe the concept"}),
                "sam3_model": (list(SAM3_MODELS.keys()), {"default": "sam3"}),
                "device": (["Auto", "CPU", "GPU"], {"default": "Auto"}),
                "confidence_threshold": ("FLOAT", {"default": 0.5, "min": 0.05, "max": 0.95, "step": 0.01}),
            },
            "optional": {
                "mask_blur": ("INT", {"default": 0, "min": 0, "max": 64, "step": 1}),
                "mask_offset": ("INT", {"default": 0, "min": -64, "max": 64, "step": 1}),
                "invert_output": ("BOOLEAN", {"default": False}),
                "unload_model": ("BOOLEAN", {"default": False}),
                "background": (["Alpha", "Color"], {"default": "Alpha"}),
                "background_color": ("COLORCODE", {"default": "#222222"}),
            },
        }

    RETURN_TYPES = ("IMAGE", "MASK", "IMAGE")
    RETURN_NAMES = ("IMAGE", "MASK", "MASK_IMAGE")
    FUNCTION = "segment"
    CATEGORY = "🧪AILab/🧽RMBG"

    def __init__(self):
        # Loaded processors keyed by (model name, "cuda" | "cpu").
        self.processor_cache = {}

    @staticmethod
    def _cache_key(model_choice, torch_device):
        """Return the ``processor_cache`` key for a model on ``torch_device``."""
        # Every non-CUDA device is treated as "cpu".
        return (model_choice, "cuda" if torch_device.type == "cuda" else "cpu")

    def _load_processor(self, model_choice, device_choice):
        """Return ``(processor, torch_device)``, building the model on first use.

        The checkpoint is located (or downloaded) through
        ``get_or_download_model_file`` and the resulting ``Sam3Processor`` is
        cached per (model, device) pair.
        """
        torch_device = _resolve_device(device_choice)
        cache_key = self._cache_key(model_choice, torch_device)
        device_str = cache_key[1]
        if cache_key not in self.processor_cache:
            model_info = SAM3_MODELS[model_choice]
            ckpt_path = get_or_download_model_file(model_info["filename"], model_info["model_url"])
            model = build_sam3_image_model(
                bpe_path=SAM3_BPE_PATH,
                device=device_str,
                eval_mode=True,
                checkpoint_path=ckpt_path,
                load_from_HF=False,
                enable_segmentation=True,
                enable_inst_interactivity=False,
            )
            self.processor_cache[cache_key] = Sam3Processor(model, device=device_str)
        return self.processor_cache[cache_key], torch_device

    @staticmethod
    def _compose(img_pil, mask_image, background, background_color):
        """Build the node outputs for one image from its final mask.

        Returns ``(result_image, mask_tensor, mask_rgb)``: the cut-out PIL
        image (RGBA for ``"Alpha"``, RGB otherwise), the mask scaled to
        0..1 with a leading batch axis, and that mask expanded to three
        identical channels in the last axis.
        """
        result_image = apply_background_color(img_pil, mask_image, background, background_color)
        result_image = result_image.convert("RGBA" if background == "Alpha" else "RGB")
        mask_tensor = torch.from_numpy(np.array(mask_image).astype(np.float32) / 255.0).unsqueeze(0)
        mask_rgb = (
            mask_tensor.reshape((-1, 1, mask_image.height, mask_image.width))
            .movedim(1, -1)
            .expand(-1, -1, -1, 3)
        )
        return result_image, mask_tensor, mask_rgb

    def _empty_result(self, img_pil, background, background_color, invert=False):
        """Return the outputs used when SAM3 finds nothing in ``img_pil``.

        The mask is uniformly 0, or uniformly 255 when ``invert`` is true
        (the complement of "nothing selected").  Blur and offset are skipped
        because the mask is uniform.
        """
        mask_image = Image.new("L", img_pil.size, 255 if invert else 0)
        return self._compose(img_pil, mask_image, background, background_color)

    def _run_single(self, processor, img_tensor, prompt, confidence, mask_blur, mask_offset, invert, background, background_color):
        """Segment one image tensor and return ``(result_image, mask, mask_rgb)``."""
        img_pil = tensor2pil(img_tensor)
        # A blank prompt falls back to the generic concept "object".
        text = prompt.strip() or "object"
        state = processor.set_image(img_pil)
        processor.reset_all_prompts(state)
        processor.set_confidence_threshold(confidence, state)
        state = processor.set_text_prompt(text, state)

        masks = state.get("masks")
        if masks is None or masks.numel() == 0:
            # Honour invert_output even when nothing was detected.
            return self._empty_result(img_pil, background, background_color, invert=invert)

        masks = masks.float().to("cpu")
        if masks.ndim == 4:
            masks = masks.squeeze(1)
        # Union of all returned masks: per-pixel maximum over the first axis.
        combined = masks.amax(dim=0)
        mask_np = (combined.clamp(0, 1).numpy() * 255).astype(np.uint8)
        mask_image = Image.fromarray(mask_np, mode="L")
        mask_image = process_mask(mask_image, invert, mask_blur, mask_offset)
        return self._compose(img_pil, mask_image, background, background_color)

    def segment(self, image, prompt, sam3_model, device, confidence_threshold=0.5, mask_blur=0, mask_offset=0, invert_output=False, unload_model=False, background="Alpha", background_color="#222222"):
        """Node entry point: segment every image of the batch.

        Returns:
            ``(images, masks, mask_images)`` - the per-image results
            concatenated along the first axis.

        Raises:
            RuntimeError: from ``_resolve_device`` when ``device == "GPU"``
                but no CUDA device is available.
        """
        if image.ndim == 3:
            # Single image without batch axis.
            image = image.unsqueeze(0)

        processor, torch_device = self._load_processor(sam3_model, device)
        autocast_device = comfy.model_management.get_autocast_device(torch_device)
        autocast_enabled = torch_device.type == "cuda" and not comfy.model_management.is_device_mps(torch_device)
        ctx = torch.autocast(autocast_device, dtype=torch.bfloat16) if autocast_enabled else nullcontext()

        result_images, result_masks, result_mask_images = [], [], []

        with ctx:
            for tensor_img in image:
                img_pil, mask_tensor, mask_rgb = self._run_single(
                    processor,
                    tensor_img,
                    prompt,
                    confidence_threshold,
                    mask_blur,
                    mask_offset,
                    invert_output,
                    background,
                    background_color,
                )
                result_images.append(pil2tensor(img_pil))
                result_masks.append(mask_tensor)
                result_mask_images.append(mask_rgb)

        if unload_model:
            # Drop the cached processor and release cached CUDA memory.
            self.processor_cache.pop(self._cache_key(sam3_model, torch_device), None)
            if torch_device.type == "cuda":
                torch.cuda.empty_cache()

        return torch.cat(result_images, dim=0), torch.cat(result_masks, dim=0), torch.cat(result_mask_images, dim=0)
233
+
234
+
235
+ # ======================================================================================
236
+ # NEW FUSED NODE: Salia_ezpz_gated_Duo2 -> SAM3Segment (hardcoded) -> apply_segment_4
237
+ # ======================================================================================
238
+
239
+ def _fallback_list_asset_pngs():
240
+ """
241
+ Best-effort dropdown helper for both Salia_ezpz_gated_Duo2 and apply_segment_4.
242
+ Tries to find a nearby 'assets/images' directory by walking upwards from this file.
243
+ Returns relative posix paths (supports subfolders). If none found, returns placeholder.
244
+ """
245
+ here = Path(__file__).resolve()
246
+ images_dir = None
247
+ for parent in [here.parent] + list(here.parents)[:12]:
248
+ cand = parent / "assets" / "images"
249
+ if cand.is_dir():
250
+ images_dir = cand
251
+ break
252
+
253
+ if images_dir is None:
254
+ return ["<no pngs found>"]
255
+
256
+ files = []
257
+ for p in images_dir.rglob("*.png"):
258
+ if p.is_file():
259
+ files.append(p.relative_to(images_dir).as_posix())
260
+ files.sort()
261
+ return files or ["<no pngs found>"]
262
+
263
+
264
+ def _safe_get_choices_from_node(node_name: str, input_key: str):
265
+ """
266
+ Try to mirror the exact dropdown options of another loaded node.
267
+ Returns None on failure.
268
+ """
269
+ try:
270
+ import nodes # comfy core module where custom nodes are registered
271
+
272
+ node_cls = nodes.NODE_CLASS_MAPPINGS.get(node_name)
273
+ if node_cls is None:
274
+ return None
275
+
276
+ in_types = node_cls.INPUT_TYPES()
277
+ req = in_types.get("required", {})
278
+ field = req.get(input_key)
279
+
280
+ # field is typically like: (choices, config_dict)
281
+ if isinstance(field, tuple) and len(field) > 0:
282
+ choices = field[0]
283
+ if isinstance(choices, (list, tuple)) and len(choices) > 0:
284
+ return list(choices)
285
+ except Exception:
286
+ return None
287
+ return None
288
+
289
+
290
class SAM3Segment_Salia:
    """
    Single node that chains three existing nodes.

    An empty trigger_string short-circuits everything and the input image is
    returned untouched. Otherwise:

      1) Salia_ezpz_gated_Duo2(image, ...) -> (image, image_cropped)
      2) SAM3Segment(image_cropped, prompt) -> (seg_image, seg_mask, _)
         with fixed settings:
             sam3_model="sam3", device="GPU", confidence_threshold=0.50,
             mask_blur=0, mask_offset=0, invert_output=False,
             unload_model=False, background="Alpha"
      3) apply_segment_4(mask=seg_mask, img=seg_image, canvas=input image,
                         x=X_coord, y=Y_coord)

    Output: Final_Image
    """

    CATEGORY = "image/salia"
    RETURN_TYPES = ("IMAGE",)
    RETURN_NAMES = ("Final_Image",)
    FUNCTION = "run"

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node inputs; asset dropdowns mirror the wrapped nodes."""
        # Prefer the dropdown options of the wrapped nodes, else scan the disk.
        assets_salia = _safe_get_choices_from_node("Salia_ezpz_gated_Duo2", "asset_image") or _fallback_list_asset_pngs()
        assets_apply = _safe_get_choices_from_node("apply_segment_4", "image") or _fallback_list_asset_pngs()

        upscale_choices = ["1", "2", "4", "6", "8", "10", "12", "14", "16"]

        required = {
            "image": ("IMAGE",),
            "trigger_string": ("STRING", {"default": ""}),
            "X_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
            "Y_coord": ("INT", {"default": 0, "min": 0, "max": 16384, "step": 1}),
            # Three prompts in total; only "prompt" is handed to SAM3.
            "positive_prompt": ("STRING", {"default": "", "multiline": True}),
            "negative_prompt": ("STRING", {"default": "", "multiline": True}),
            "prompt": ("STRING", {"default": "", "multiline": True, "placeholder": "SAM3 prompt"}),
            # asset_image feeds Salia_ezpz_gated_Duo2,
            # apply_asset_image feeds apply_segment_4.
            "asset_image": (assets_salia, {}),
            "apply_asset_image": (assets_apply, {}),
            # First pass of Salia_ezpz_gated_Duo2.
            "square_size_1": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
            "upscale_factor_1": (upscale_choices, {"default": "4"}),
            "denoise_1": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
            # Second pass of Salia_ezpz_gated_Duo2.
            "square_size_2": ("INT", {"default": 384, "min": 8, "max": 8192, "step": 1}),
            "upscale_factor_2": (upscale_choices, {"default": "4"}),
            "denoise_2": ("FLOAT", {"default": 0.35, "min": 0.00, "max": 1.00, "step": 0.01}),
        }
        return {"required": required}

    def __init__(self):
        # One SAM3Segment is kept per instance so its processor_cache is reused.
        self._sam3 = SAM3Segment()
        # The wrapped nodes are created on first use inside run().
        self._salia_node = None
        self._apply_node = None

    @staticmethod
    def _require_node_instance(node_name: str):
        """Instantiate the registered node ``node_name`` or raise RuntimeError."""
        import nodes  # comfy core module where custom nodes are registered

        node_cls = nodes.NODE_CLASS_MAPPINGS.get(node_name)
        if node_cls is not None:
            return node_cls()
        raise RuntimeError(
            f"Required node '{node_name}' was not found in nodes.NODE_CLASS_MAPPINGS. "
            f"Make sure its custom-node file is installed and loaded."
        )

    def run(
        self,
        image,
        trigger_string="",
        X_coord=0,
        Y_coord=0,
        positive_prompt="",
        negative_prompt="",
        prompt="",
        asset_image="",
        apply_asset_image="",
        square_size_1=384,
        upscale_factor_1="4",
        denoise_1=0.35,
        square_size_2=384,
        upscale_factor_2="4",
        denoise_2=0.35,
    ):
        """Execute the fused pipeline and return ``(final_image,)``."""
        # An exactly empty trigger string disables the node entirely.
        if trigger_string == "":
            return (image,)

        # Create the wrapped nodes the first time they are needed.
        if self._salia_node is None:
            self._salia_node = self._require_node_instance("Salia_ezpz_gated_Duo2")
        if self._apply_node is None:
            self._apply_node = self._require_node_instance("apply_segment_4")

        # Step 1: pre-node. Only its second output (the crop) is used below.
        pre_node = self._salia_node
        pre_entry = getattr(pre_node, getattr(pre_node, "FUNCTION", "run"))
        _unused_image, image_cropped = pre_entry(
            image=image,
            trigger_string=trigger_string,
            X_coord=int(X_coord),
            Y_coord=int(Y_coord),
            positive_prompt=str(positive_prompt),
            negative_prompt=str(negative_prompt),
            asset_image=str(asset_image),
            square_size_1=int(square_size_1),
            upscale_factor_1=str(upscale_factor_1),
            denoise_1=float(denoise_1),
            square_size_2=int(square_size_2),
            upscale_factor_2=str(upscale_factor_2),
            denoise_2=float(denoise_2),
        )

        # Step 2: SAM3 on the cropped image, always with the fixed settings.
        seg_image, seg_mask, _mask_image = self._sam3.segment(
            image=image_cropped,
            prompt=str(prompt),
            sam3_model="sam3",
            device="GPU",
            confidence_threshold=0.50,
            mask_blur=0,
            mask_offset=0,
            invert_output=False,
            unload_model=False,
            background="Alpha",
            background_color="#222222",
        )

        # Step 3: post-node, pasting onto the original (uncropped) input image.
        post_node = self._apply_node
        post_entry = getattr(post_node, getattr(post_node, "FUNCTION", "run"))
        (final_image,) = post_entry(
            mask=seg_mask,
            image=str(apply_asset_image),
            img=seg_image,
            canvas=image,
            x=int(X_coord),
            y=int(Y_coord),
        )

        return (final_image,)
449
+
450
+
451
# Node identifier -> implementing class for the two nodes defined in this file.
# NOTE(review): presumably picked up by ComfyUI's custom-node loader and merged
# into nodes.NODE_CLASS_MAPPINGS - confirm against the loader.
NODE_CLASS_MAPPINGS = {
    "SAM3Segment": SAM3Segment,
    "SAM3Segment_Salia": SAM3Segment_Salia,
}

# Node identifier -> human-readable title; keys match NODE_CLASS_MAPPINGS.
NODE_DISPLAY_NAME_MAPPINGS = {
    "SAM3Segment": "SAM3 Segmentation (RMBG)",
    "SAM3Segment_Salia": "SAM3Segment_Salia (EZPZ + SAM3 + apply_segment_4)",
}