# Spaces:
# Paused
# Paused
import torch

import impact.core as core
from impact.config import MAX_RESOLUTION
import impact.segs_nodes as segs_nodes
import impact.utils as utils
from impact.core import SEG
class SAMDetectorCombined:
    """Detector node that forwards its inputs to ``core.make_sam_mask``.

    The class only declares the node's inputs/outputs; the mask itself is
    produced entirely by ``core.make_sam_mask``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification; every input is required.

        Declared as a classmethod so it can be queried on the class itself
        (``SAMDetectorCombined.INPUT_TYPES()``) without an instance.
        """
        return {
            "required": {
                "sam_model": ("SAM_MODEL", ),
                "segs": ("SEGS", ),
                "image": ("IMAGE", ),
                "detection_hint": (["center-1", "horizontal-2", "vertical-2", "rect-4", "diamond-4", "mask-area",
                                    "mask-points", "mask-point-bbox", "none"],),
                "dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                "mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                "mask_hint_use_negative": (["False", "Small", "Outter"], ),
            }
        }

    # Node metadata: one MASK output, produced by the method named in FUNCTION.
    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, sam_model, segs, image, detection_hint, dilation,
             threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
        """Return a 1-tuple holding the result of ``core.make_sam_mask``.

        All arguments are passed through unchanged and in the same order.
        """
        mask = core.make_sam_mask(sam_model, segs, image, detection_hint, dilation,
                                  threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative)
        return (mask, )
class SAMDetectorSegmented:
    """Detector node that forwards its inputs to ``core.make_sam_mask_segmented``.

    Exposes two MASK outputs named ``combined_mask`` and ``batch_masks``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification; every input is required.

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "sam_model": ("SAM_MODEL", ),
                "segs": ("SEGS", ),
                "image": ("IMAGE", ),
                "detection_hint": (["center-1", "horizontal-2", "vertical-2", "rect-4", "diamond-4", "mask-area",
                                    "mask-points", "mask-point-bbox", "none"],),
                "dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "threshold": ("FLOAT", {"default": 0.93, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),
                "mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
                "mask_hint_use_negative": (["False", "Small", "Outter"], ),
            }
        }

    # Node metadata: two MASK outputs, produced by the method named in FUNCTION.
    RETURN_TYPES = ("MASK", "MASK")
    RETURN_NAMES = ("combined_mask", "batch_masks")
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, sam_model, segs, image, detection_hint, dilation,
             threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative):
        """Return ``(combined_mask, batch_masks)`` from ``core.make_sam_mask_segmented``.

        All arguments are passed through unchanged and in the same order.
        """
        combined_mask, batch_masks = core.make_sam_mask_segmented(
            sam_model, segs, image, detection_hint, dilation,
            threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative)
        return (combined_mask, batch_masks, )
class BboxDetectorForEach:
    """Detector node that runs a bbox detector on one image and returns SEGS.

    The detector result can optionally be narrowed by a comma-separated list
    of labels via ``segs_nodes.SEGSLabelFilter.filter``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required inputs plus an optional hook).

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "bbox_detector": ("BBOX_DETECTOR", ),
                "image": ("IMAGE", ),
                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}),
                "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
                "labels": ("STRING", {"multiline": True, "default": "all", "placeholder": "List the types of segments to be allowed, separated by commas"}),
            },
            "optional": {"detailer_hook": ("DETAILER_HOOK",), }
        }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, bbox_detector, image, threshold, dilation, crop_factor, drop_size, labels=None, detailer_hook=None):
        """Detect segments in a single image and optionally filter them by label.

        Args:
            bbox_detector: object exposing ``detect(image, threshold, dilation,
                crop_factor, drop_size, detailer_hook)``.
            image: image batch; must contain at most one image.
            threshold, dilation, crop_factor, drop_size: forwarded to the detector.
            labels: comma-separated label list; ``None`` or ``''`` disables filtering.
            detailer_hook: forwarded to the detector.

        Returns:
            A 1-tuple containing the (possibly filtered) detector result.

        Raises:
            Exception: if ``image`` holds more than one image.
        """
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: BboxDetectorForEach does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        segs = bbox_detector.detect(image, threshold, dilation, crop_factor, drop_size, detailer_hook)

        # str.split always yields at least one item, so any non-empty string
        # leads to filtering; no further length check is needed.
        if labels is not None and labels != '':
            segs, _ = segs_nodes.SEGSLabelFilter.filter(segs, labels.split(','))

        return (segs, )
class SegmDetectorForEach:
    """Detector node that runs a segm detector on one image and returns SEGS.

    The detector result can optionally be narrowed by a comma-separated list
    of labels via ``segs_nodes.SEGSLabelFilter.filter``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required inputs plus an optional hook).

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "segm_detector": ("SEGM_DETECTOR", ),
                "image": ("IMAGE", ),
                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dilation": ("INT", {"default": 10, "min": -512, "max": 512, "step": 1}),
                "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),
                "labels": ("STRING", {"multiline": True, "default": "all", "placeholder": "List the types of segments to be allowed, separated by commas"}),
            },
            "optional": {"detailer_hook": ("DETAILER_HOOK",), }
        }

    RETURN_TYPES = ("SEGS", )
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, segm_detector, image, threshold, dilation, crop_factor, drop_size, labels=None, detailer_hook=None):
        """Detect segments in a single image and optionally filter them by label.

        Args:
            segm_detector: object exposing ``detect(image, threshold, dilation,
                crop_factor, drop_size, detailer_hook)``.
            image: image batch; must contain at most one image.
            threshold, dilation, crop_factor, drop_size: forwarded to the detector.
            labels: comma-separated label list; ``None`` or ``''`` disables filtering.
            detailer_hook: forwarded to the detector.

        Returns:
            A 1-tuple containing the (possibly filtered) detector result.

        Raises:
            Exception: if ``image`` holds more than one image.
        """
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: SegmDetectorForEach does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        segs = segm_detector.detect(image, threshold, dilation, crop_factor, drop_size, detailer_hook)

        # str.split always yields at least one item, so any non-empty string
        # leads to filtering; no further length check is needed.
        if labels is not None and labels != '':
            segs, _ = segs_nodes.SEGSLabelFilter.filter(segs, labels.split(','))

        return (segs, )
class SegmDetectorCombined:
    """Detector node returning the combined mask of a segm detector."""

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification; every input is required.

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "segm_detector": ("SEGM_DETECTOR", ),
                "image": ("IMAGE", ),
                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
            }
        }

    RETURN_TYPES = ("MASK",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, segm_detector, image, threshold, dilation):
        """Return a 1-tuple with ``segm_detector.detect_combined(image, threshold, dilation)``."""
        mask = segm_detector.detect_combined(image, threshold, dilation)
        return (mask,)
class BboxDetectorCombined(SegmDetectorCombined):
    """Detector node returning the combined mask of a bbox detector.

    Output metadata (``RETURN_TYPES``, ``FUNCTION``, ``CATEGORY``) is inherited
    from ``SegmDetectorCombined``; only the inputs and ``doit`` are overridden.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification; every input is required.

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "bbox_detector": ("BBOX_DETECTOR", ),
                "image": ("IMAGE", ),
                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "dilation": ("INT", {"default": 4, "min": -512, "max": 512, "step": 1}),
            }
        }

    def doit(self, bbox_detector, image, threshold, dilation):
        """Return a 1-tuple with ``bbox_detector.detect_combined(image, threshold, dilation)``."""
        mask = bbox_detector.detect_combined(image, threshold, dilation)
        return (mask,)
class SimpleDetectorForEach:
    """Detector node: bbox detection optionally refined by SAM or a segm detector.

    The shared logic lives in the static ``detect`` so that other nodes can
    call ``SimpleDetectorForEach.detect(...)`` without creating an instance.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required and optional inputs).

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "bbox_detector": ("BBOX_DETECTOR", ),
                "image": ("IMAGE", ),

                "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),

                "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),

                "sub_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "sub_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "sub_bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),

                "sam_mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "post_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "sam_model_opt": ("SAM_MODEL", ),
                "segm_detector_opt": ("SEGM_DETECTOR", ),
            }
        }

    RETURN_TYPES = ("SEGS",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    @staticmethod
    def detect(bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
               sub_threshold, sub_dilation, sub_bbox_expansion,
               sam_mask_hint_threshold, post_dilation=0, sam_model_opt=None, segm_detector_opt=None,
               detailer_hook=None):
        """Detect segments in one image and refine them with an optional sub-detector.

        Steps:
            1. ``bbox_detector.detect`` produces the initial segs.
            2. If ``sam_model_opt`` is given, the segs are AND-ed with the mask
               from ``core.make_sam_mask`` (hint ``"center-1"``); otherwise, if
               ``segm_detector_opt`` is given, with the combined mask of its
               detections. The SAM model takes precedence when both are set.
            3. The result is passed through ``core.dilate_segs`` with
               ``post_dilation``.

        Returns:
            A 1-tuple containing the resulting segs.

        Raises:
            Exception: if ``image`` holds more than one image.
        """
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: SimpleDetectorForEach does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        segs = bbox_detector.detect(image, bbox_threshold, bbox_dilation, crop_factor, drop_size, detailer_hook=detailer_hook)

        refine_mask = None
        if sam_model_opt is not None:
            refine_mask = core.make_sam_mask(sam_model_opt, segs, image, "center-1", sub_dilation,
                                             sub_threshold, sub_bbox_expansion, sam_mask_hint_threshold, False)
        elif segm_detector_opt is not None:
            segm_segs = segm_detector_opt.detect(image, sub_threshold, sub_dilation, crop_factor, drop_size, detailer_hook=detailer_hook)
            refine_mask = core.segs_to_combined_mask(segm_segs)

        # refine_mask stays None only when neither sub-detector was supplied.
        if refine_mask is not None:
            segs = core.segs_bitwise_and_mask(segs, refine_mask)

        segs = core.dilate_segs(segs, post_dilation)

        return (segs,)

    def doit(self, bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
             sub_threshold, sub_dilation, sub_bbox_expansion,
             sam_mask_hint_threshold, post_dilation=0, sam_model_opt=None, segm_detector_opt=None):
        """Node entry point; forwards to ``detect`` without a detailer hook."""
        return SimpleDetectorForEach.detect(bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
                                            sub_threshold, sub_dilation, sub_bbox_expansion,
                                            sam_mask_hint_threshold, post_dilation=post_dilation,
                                            sam_model_opt=sam_model_opt, segm_detector_opt=segm_detector_opt)
class SimpleDetectorForEachPipe:
    """Variant of the simple detector node that takes its detectors from a detailer pipe.

    The bbox detector, optional segm detector, optional SAM model and the
    detailer hook are unpacked from ``detailer_pipe`` and handed to
    ``SimpleDetectorForEach.detect``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required and optional inputs).

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "detailer_pipe": ("DETAILER_PIPE", ),
                "image": ("IMAGE", ),

                "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),

                "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),

                "sub_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "sub_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
                "sub_bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),

                "sam_mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "post_dilation": ("INT", {"default": 0, "min": -512, "max": 512, "step": 1}),
            }
        }

    RETURN_TYPES = ("SEGS",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    def doit(self, detailer_pipe, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
             sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold, post_dilation=0):
        """Unpack the detailer pipe and run ``SimpleDetectorForEach.detect``.

        Args:
            detailer_pipe: 14-element sequence; elements 7-10 (1-based) are the
                bbox detector, optional segm detector, optional SAM model and
                detailer hook. The remaining elements are not used here.
            image: image batch; must contain at most one image.

        Returns:
            The 1-tuple returned by ``SimpleDetectorForEach.detect``.

        Raises:
            Exception: if ``image`` holds more than one image.
        """
        if len(image) > 1:
            raise Exception('[Impact Pack] ERROR: SimpleDetectorForEach does not allow image batches.\nPlease refer to https://github.com/ltdrdata/ComfyUI-extension-tutorials/blob/Main/ComfyUI-Impact-Pack/tutorial/batching-detailer.md for more information.')

        # Full unpacking is kept on purpose: it documents the pipe layout and
        # fails loudly if the pipe does not have exactly 14 elements.
        (model, clip, vae, positive, negative, wildcard,
         bbox_detector, segm_detector_opt, sam_model_opt, detailer_hook,
         refiner_model, refiner_clip, refiner_positive, refiner_negative) = detailer_pipe

        return SimpleDetectorForEach.detect(bbox_detector, image, bbox_threshold, bbox_dilation, crop_factor, drop_size,
                                            sub_threshold, sub_dilation, sub_bbox_expansion,
                                            sam_mask_hint_threshold, post_dilation=post_dilation,
                                            sam_model_opt=sam_model_opt, segm_detector_opt=segm_detector_opt,
                                            detailer_hook=detailer_hook)
class SimpleDetectorForAnimateDiff:
    """Detector node that builds SEGS for a sequence of frames.

    Each frame is detected independently; the per-frame results are then
    combined according to ``masking_mode`` and ``segs_pivot``.
    """

    @classmethod
    def INPUT_TYPES(s):
        """Return the input specification (required and optional inputs).

        Declared as a classmethod so it can be queried on the class itself
        without an instance.
        """
        return {
            "required": {
                "bbox_detector": ("BBOX_DETECTOR", ),
                "image_frames": ("IMAGE", ),

                "bbox_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "bbox_dilation": ("INT", {"default": 0, "min": -255, "max": 255, "step": 1}),

                "crop_factor": ("FLOAT", {"default": 3.0, "min": 1.0, "max": 100, "step": 0.1}),
                "drop_size": ("INT", {"min": 1, "max": MAX_RESOLUTION, "step": 1, "default": 10}),

                "sub_threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step": 0.01}),
                "sub_dilation": ("INT", {"default": 0, "min": -255, "max": 255, "step": 1}),
                "sub_bbox_expansion": ("INT", {"default": 0, "min": 0, "max": 1000, "step": 1}),

                "sam_mask_hint_threshold": ("FLOAT", {"default": 0.7, "min": 0.0, "max": 1.0, "step": 0.01}),
            },
            "optional": {
                "masking_mode": (["Pivot SEGS", "Combine neighboring frames", "Don't combine"],),
                "segs_pivot": (["Combined mask", "1st frame mask"],),
                "sam_model_opt": ("SAM_MODEL", ),
                "segm_detector_opt": ("SEGM_DETECTOR", ),
            }
        }

    RETURN_TYPES = ("SEGS",)
    FUNCTION = "doit"

    CATEGORY = "ImpactPack/Detector"

    @staticmethod
    def detect(bbox_detector, image_frames, bbox_threshold, bbox_dilation, crop_factor, drop_size,
               sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold,
               masking_mode="Pivot SEGS", segs_pivot="Combined mask", sam_model_opt=None, segm_detector_opt=None):
        """Detect segments on every frame and combine them into one SEGS result.

        Args:
            bbox_detector: object exposing ``detect(image, threshold, dilation,
                crop_factor, drop_size)``.
            image_frames: tensor of frames; dimensions 1 and 2 are read as the
                mask height and width, and iterating yields one frame at a time.
            masking_mode: ``"Pivot SEGS"`` returns the pivot SEGS as is;
                ``"Combine neighboring frames"`` gives each pivot seg one mask
                per frame, built from the union of that frame's mask with its
                previous and next frame and multiplied by the pivot mask; any
                other value (``"Don't combine"``) gives each pivot seg one
                unmodified crop per frame.
            segs_pivot: ``"1st frame mask"`` uses the first frame's SEGS as
                pivot; any other value uses SEGS rebuilt from the union of all
                masks of all frames.
            sam_model_opt / segm_detector_opt: optional refinement of every
                frame's segs; the SAM model takes precedence when both are set.

        Returns:
            A 1-tuple containing the resulting SEGS.
        """
        h = image_frames.shape[1]
        w = image_frames.shape[2]

        # gather segs for all frames
        segs_by_frames = []
        for image in image_frames:
            image = image.unsqueeze(0)
            segs = bbox_detector.detect(image, bbox_threshold, bbox_dilation, crop_factor, drop_size)

            if sam_model_opt is not None:
                mask = core.make_sam_mask(sam_model_opt, segs, image, "center-1", sub_dilation,
                                          sub_threshold, sub_bbox_expansion, sam_mask_hint_threshold, False)
                segs = core.segs_bitwise_and_mask(segs, mask)
            elif segm_detector_opt is not None:
                segm_segs = segm_detector_opt.detect(image, sub_threshold, sub_dilation, crop_factor, drop_size)
                mask = core.segs_to_combined_mask(segm_segs)
                segs = core.segs_bitwise_and_mask(segs, mask)

            segs_by_frames.append(segs)

        def get_empty_mask():
            return torch.zeros((h, w), dtype=torch.float32, device="cpu")

        def union_masks(masks):
            # Bitwise OR is done on uint8 copies, then the result is brought
            # back to float32 and binarized with a 0.1 threshold.
            merged = (masks[0] * 255).to(torch.uint8)
            for mask in masks[1:]:
                merged |= (mask * 255).to(torch.uint8)
            merged = (merged / 255.0).to(torch.float32)
            return utils.to_binary_mask(merged, 0.1)[0]

        def get_masked_frames():
            masks_by_frame = []
            for segs in segs_by_frames:
                masks_in_frame = segs_nodes.SEGSToMaskList().doit(segs)[0]
                # A frame without any mask contributes an all-zero mask
                # instead of failing on masks_in_frame[0].
                if len(masks_in_frame) == 0:
                    masks_by_frame.append(get_empty_mask())
                else:
                    masks_by_frame.append(union_masks(masks_in_frame))
            return masks_by_frame

        def get_merged_neighboring_mask(masks_by_frame):
            if len(masks_by_frame) <= 1:
                return masks_by_frame

            last = len(masks_by_frame) - 1
            result = []
            for i, cur in enumerate(masks_by_frame):
                # The first frame has no predecessor and the last frame has no
                # successor; an empty mask stands in for the missing neighbor.
                prv = masks_by_frame[i - 1] if i > 0 else get_empty_mask()
                nxt = masks_by_frame[i + 1] if i < last else get_empty_mask()
                result.append(union_masks([cur, prv, nxt]))
            return result

        def get_whole_merged_mask():
            all_masks = []
            for segs in segs_by_frames:
                all_masks += segs_nodes.SEGSToMaskList().doit(segs)[0]

            if len(all_masks) == 0:
                return get_empty_mask()
            return union_masks(all_masks)

        def get_pivot_segs():
            if segs_pivot == "1st frame mask":
                # Return the whole SEGS of the first frame: every consumer
                # indexes the pivot as (pivot[0], pivot[1]) or returns it as
                # the node's SEGS output.
                return segs_by_frames[0]
            else:
                merged_mask = get_whole_merged_mask()
                return segs_nodes.MaskToSEGS().doit(merged_mask, False, crop_factor, False, drop_size, contour_fill=True)[0]

        def stack_crops(seg, crops):
            # One mask per frame along dim 0; a single all-zero mask is used
            # when there is no frame to take a crop from.
            if crops:
                return torch.cat(crops, dim=0)
            return torch.zeros(seg.cropped_mask.shape, dtype=torch.float32, device="cpu").unsqueeze(0)

        def rebuild_seg(seg, cropped_mask):
            return SEG(seg.cropped_image, cropped_mask, seg.confidence, seg.crop_region, seg.bbox, seg.label, seg.control_net_wrapper)

        def get_merged_neighboring_segs():
            pivot_segs = get_pivot_segs()
            masks_by_frame = get_merged_neighboring_mask(get_masked_frames())

            new_segs = []
            for seg in pivot_segs[1]:
                pivot_mask = torch.from_numpy(seg.cropped_mask)
                x1, y1, x2, y2 = seg.crop_region
                crops = [(mask[y1:y2, x1:x2] * pivot_mask).unsqueeze(0) for mask in masks_by_frame]
                new_segs.append(rebuild_seg(seg, stack_crops(seg, crops)))

            return pivot_segs[0], new_segs

        def get_separated_segs():
            pivot_segs = get_pivot_segs()
            masks_by_frame = get_masked_frames()

            new_segs = []
            for seg in pivot_segs[1]:
                x1, y1, x2, y2 = seg.crop_region
                # Each crop gets a leading frame dimension so the crops can be
                # concatenated along dim 0.
                crops = [mask[y1:y2, x1:x2].unsqueeze(0) for mask in masks_by_frame]
                new_segs.append(rebuild_seg(seg, stack_crops(seg, crops)))

            return pivot_segs[0], new_segs

        # create result mask
        if masking_mode == "Pivot SEGS":
            return (get_pivot_segs(), )
        elif masking_mode == "Combine neighboring frames":
            return (get_merged_neighboring_segs(), )
        else:  # elif masking_mode == "Don't combine":
            return (get_separated_segs(), )

    def doit(self, bbox_detector, image_frames, bbox_threshold, bbox_dilation, crop_factor, drop_size,
             sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold,
             masking_mode="Pivot SEGS", segs_pivot="Combined mask", sam_model_opt=None, segm_detector_opt=None):
        """Node entry point; forwards every argument to ``detect``."""
        return SimpleDetectorForAnimateDiff.detect(bbox_detector, image_frames, bbox_threshold, bbox_dilation, crop_factor, drop_size,
                                                   sub_threshold, sub_dilation, sub_bbox_expansion, sam_mask_hint_threshold,
                                                   masking_mode, segs_pivot, sam_model_opt, segm_detector_opt)