|
|
import numpy as np |
|
|
import cv2 |
|
|
import torch |
|
|
import torchvision.transforms.functional as TF |
|
|
|
|
|
import sys as _sys |
|
|
from keyword import iskeyword as _iskeyword |
|
|
from operator import itemgetter as _itemgetter |
|
|
|
|
|
from segment_anything import SamPredictor |
|
|
|
|
|
from comfy import model_management |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
try: |
|
|
from _collections import _tuplegetter |
|
|
except ImportError: |
|
|
_tuplegetter = lambda index, doc: property(_itemgetter(index), doc=doc) |
|
|
|
|
|
def namedtuple(typename, field_names, *, rename=False, defaults=None, module=None): |
|
|
"""Returns a new subclass of tuple with named fields. |
|
|
|
|
|
>>> Point = namedtuple('Point', ['x', 'y']) |
|
|
>>> Point.__doc__ # docstring for the new class |
|
|
'Point(x, y)' |
|
|
>>> p = Point(11, y=22) # instantiate with positional args or keywords |
|
|
>>> p[0] + p[1] # indexable like a plain tuple |
|
|
33 |
|
|
>>> x, y = p # unpack like a regular tuple |
|
|
>>> x, y |
|
|
(11, 22) |
|
|
>>> p.x + p.y # fields also accessible by name |
|
|
33 |
|
|
>>> d = p._asdict() # convert to a dictionary |
|
|
>>> d['x'] |
|
|
11 |
|
|
>>> Point(**d) # convert from a dictionary |
|
|
Point(x=11, y=22) |
|
|
>>> p._replace(x=100) # _replace() is like str.replace() but targets named fields |
|
|
Point(x=100, y=22) |
|
|
|
|
|
""" |
|
|
|
|
|
|
|
|
|
|
|
if isinstance(field_names, str): |
|
|
field_names = field_names.replace(',', ' ').split() |
|
|
field_names = list(map(str, field_names)) |
|
|
typename = _sys.intern(str(typename)) |
|
|
|
|
|
if rename: |
|
|
seen = set() |
|
|
for index, name in enumerate(field_names): |
|
|
if (not name.isidentifier() |
|
|
or _iskeyword(name) |
|
|
or name.startswith('_') |
|
|
or name in seen): |
|
|
field_names[index] = f'_{index}' |
|
|
seen.add(name) |
|
|
|
|
|
for name in [typename] + field_names: |
|
|
if type(name) is not str: |
|
|
raise TypeError('Type names and field names must be strings') |
|
|
if not name.isidentifier(): |
|
|
raise ValueError('Type names and field names must be valid ' |
|
|
f'identifiers: {name!r}') |
|
|
if _iskeyword(name): |
|
|
raise ValueError('Type names and field names cannot be a ' |
|
|
f'keyword: {name!r}') |
|
|
|
|
|
seen = set() |
|
|
for name in field_names: |
|
|
if name.startswith('_') and not rename: |
|
|
raise ValueError('Field names cannot start with an underscore: ' |
|
|
f'{name!r}') |
|
|
if name in seen: |
|
|
raise ValueError(f'Encountered duplicate field name: {name!r}') |
|
|
seen.add(name) |
|
|
|
|
|
field_defaults = {} |
|
|
if defaults is not None: |
|
|
defaults = tuple(defaults) |
|
|
if len(defaults) > len(field_names): |
|
|
raise TypeError('Got more default values than field names') |
|
|
field_defaults = dict(reversed(list(zip(reversed(field_names), |
|
|
reversed(defaults))))) |
|
|
|
|
|
|
|
|
field_names = tuple(map(_sys.intern, field_names)) |
|
|
num_fields = len(field_names) |
|
|
arg_list = ', '.join(field_names) |
|
|
if num_fields == 1: |
|
|
arg_list += ',' |
|
|
repr_fmt = '(' + ', '.join(f'{name}=%r' for name in field_names) + ')' |
|
|
tuple_new = tuple.__new__ |
|
|
_dict, _tuple, _len, _map, _zip = dict, tuple, len, map, zip |
|
|
|
|
|
|
|
|
|
|
|
namespace = { |
|
|
'_tuple_new': tuple_new, |
|
|
'__builtins__': {}, |
|
|
'__name__': f'namedtuple_{typename}', |
|
|
} |
|
|
code = f'lambda _cls, {arg_list}: _tuple_new(_cls, ({arg_list}))' |
|
|
__new__ = eval(code, namespace) |
|
|
__new__.__name__ = '__new__' |
|
|
__new__.__doc__ = f'Create new instance of {typename}({arg_list})' |
|
|
if defaults is not None: |
|
|
__new__.__defaults__ = defaults |
|
|
|
|
|
@classmethod |
|
|
def _make(cls, iterable): |
|
|
result = tuple_new(cls, iterable) |
|
|
if _len(result) != num_fields: |
|
|
raise TypeError(f'Expected {num_fields} arguments, got {len(result)}') |
|
|
return result |
|
|
|
|
|
_make.__func__.__doc__ = (f'Make a new {typename} object from a sequence ' |
|
|
'or iterable') |
|
|
|
|
|
def _replace(self, /, **kwds): |
|
|
result = self._make(_map(kwds.pop, field_names, self)) |
|
|
if kwds: |
|
|
raise ValueError(f'Got unexpected field names: {list(kwds)!r}') |
|
|
return result |
|
|
|
|
|
_replace.__doc__ = (f'Return a new {typename} object replacing specified ' |
|
|
'fields with new values') |
|
|
|
|
|
def __repr__(self): |
|
|
'Return a nicely formatted representation string' |
|
|
return self.__class__.__name__ + repr_fmt % self |
|
|
|
|
|
def _asdict(self): |
|
|
'Return a new dict which maps field names to their values.' |
|
|
return _dict(_zip(self._fields, self)) |
|
|
|
|
|
def __getnewargs__(self): |
|
|
'Return self as a plain tuple. Used by copy and pickle.' |
|
|
return _tuple(self) |
|
|
|
|
|
|
|
|
for method in ( |
|
|
__new__, |
|
|
_make.__func__, |
|
|
_replace, |
|
|
__repr__, |
|
|
_asdict, |
|
|
__getnewargs__, |
|
|
): |
|
|
method.__qualname__ = f'{typename}.{method.__name__}' |
|
|
|
|
|
|
|
|
|
|
|
class_namespace = { |
|
|
'__doc__': f'{typename}({arg_list})', |
|
|
'__slots__': (), |
|
|
'_fields': field_names, |
|
|
'_field_defaults': field_defaults, |
|
|
'__new__': __new__, |
|
|
'_make': _make, |
|
|
'_replace': _replace, |
|
|
'__repr__': __repr__, |
|
|
'_asdict': _asdict, |
|
|
'__getnewargs__': __getnewargs__, |
|
|
'__match_args__': field_names, |
|
|
} |
|
|
for index, name in enumerate(field_names): |
|
|
doc = _sys.intern(f'Alias for field number {index}') |
|
|
class_namespace[name] = _tuplegetter(index, doc) |
|
|
|
|
|
result = type(typename, (tuple,), class_namespace) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
if module is None: |
|
|
try: |
|
|
module = _sys._getframe(1).f_globals.get('__name__', '__main__') |
|
|
except (AttributeError, ValueError): |
|
|
pass |
|
|
if module is not None: |
|
|
result.__module__ = module |
|
|
|
|
|
return result |
|
|
|
|
|
|
|
|
SEG = namedtuple("SEG", |
|
|
['cropped_image', 'cropped_mask', 'confidence', 'crop_region', 'bbox', 'label', 'control_net_wrapper'], |
|
|
defaults=[None]) |
|
|
|
|
|
def crop_ndarray4(npimg, crop_region): |
|
|
x1 = crop_region[0] |
|
|
y1 = crop_region[1] |
|
|
x2 = crop_region[2] |
|
|
y2 = crop_region[3] |
|
|
|
|
|
cropped = npimg[:, y1:y2, x1:x2, :] |
|
|
|
|
|
return cropped |
|
|
|
|
|
crop_tensor4 = crop_ndarray4 |
|
|
|
|
|
def crop_ndarray2(npimg, crop_region): |
|
|
x1 = crop_region[0] |
|
|
y1 = crop_region[1] |
|
|
x2 = crop_region[2] |
|
|
y2 = crop_region[3] |
|
|
|
|
|
cropped = npimg[y1:y2, x1:x2] |
|
|
|
|
|
return cropped |
|
|
|
|
|
def crop_image(image, crop_region): |
|
|
return crop_tensor4(image, crop_region) |
|
|
|
|
|
def normalize_region(limit, startp, size): |
|
|
if startp < 0: |
|
|
new_endp = min(limit, size) |
|
|
new_startp = 0 |
|
|
elif startp + size > limit: |
|
|
new_startp = max(0, limit - size) |
|
|
new_endp = limit |
|
|
else: |
|
|
new_startp = startp |
|
|
new_endp = min(limit, startp+size) |
|
|
|
|
|
return int(new_startp), int(new_endp) |
|
|
|
|
|
def make_crop_region(w, h, bbox, crop_factor, crop_min_size=None): |
|
|
x1 = bbox[0] |
|
|
y1 = bbox[1] |
|
|
x2 = bbox[2] |
|
|
y2 = bbox[3] |
|
|
|
|
|
bbox_w = x2 - x1 |
|
|
bbox_h = y2 - y1 |
|
|
|
|
|
crop_w = bbox_w * crop_factor |
|
|
crop_h = bbox_h * crop_factor |
|
|
|
|
|
if crop_min_size is not None: |
|
|
crop_w = max(crop_min_size, crop_w) |
|
|
crop_h = max(crop_min_size, crop_h) |
|
|
|
|
|
kernel_x = x1 + bbox_w / 2 |
|
|
kernel_y = y1 + bbox_h / 2 |
|
|
|
|
|
new_x1 = int(kernel_x - crop_w / 2) |
|
|
new_y1 = int(kernel_y - crop_h / 2) |
|
|
|
|
|
|
|
|
new_x1, new_x2 = normalize_region(w, new_x1, crop_w) |
|
|
new_y1, new_y2 = normalize_region(h, new_y1, crop_h) |
|
|
|
|
|
return [new_x1, new_y1, new_x2, new_y2] |
|
|
|
|
|
def create_segmasks(results): |
|
|
bboxs = results[1] |
|
|
segms = results[2] |
|
|
confidence = results[3] |
|
|
|
|
|
results = [] |
|
|
for i in range(len(segms)): |
|
|
item = (bboxs[i], segms[i].astype(np.float32), confidence[i]) |
|
|
results.append(item) |
|
|
return results |
|
|
|
|
|
def dilate_masks(segmasks, dilation_factor, iter=1): |
|
|
if dilation_factor == 0: |
|
|
return segmasks |
|
|
|
|
|
dilated_masks = [] |
|
|
kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) |
|
|
|
|
|
kernel = cv2.UMat(kernel) |
|
|
|
|
|
for i in range(len(segmasks)): |
|
|
cv2_mask = segmasks[i][1] |
|
|
|
|
|
cv2_mask = cv2.UMat(cv2_mask) |
|
|
|
|
|
if dilation_factor > 0: |
|
|
dilated_mask = cv2.dilate(cv2_mask, kernel, iter) |
|
|
else: |
|
|
dilated_mask = cv2.erode(cv2_mask, kernel, iter) |
|
|
|
|
|
dilated_mask = dilated_mask.get() |
|
|
|
|
|
item = (segmasks[i][0], dilated_mask, segmasks[i][2]) |
|
|
dilated_masks.append(item) |
|
|
|
|
|
return dilated_masks |
|
|
|
|
|
def is_same_device(a, b): |
|
|
a_device = torch.device(a) if isinstance(a, str) else a |
|
|
b_device = torch.device(b) if isinstance(b, str) else b |
|
|
return a_device.type == b_device.type and a_device.index == b_device.index |
|
|
|
|
|
class SafeToGPU: |
|
|
def __init__(self, size): |
|
|
self.size = size |
|
|
|
|
|
def to_device(self, obj, device): |
|
|
if is_same_device(device, 'cpu'): |
|
|
obj.to(device) |
|
|
else: |
|
|
|
|
|
current_device = None |
|
|
try: |
|
|
current_device = next(obj.parameters()).device |
|
|
except: |
|
|
pass |
|
|
|
|
|
if current_device is None or is_same_device(current_device, 'cpu'): |
|
|
model_management.free_memory(self.size * 1.3, device) |
|
|
if model_management.get_free_memory(device) > self.size * 1.3: |
|
|
try: |
|
|
print("Moving to GPU...", end=" ") |
|
|
obj.to(device) |
|
|
print("OK") |
|
|
except: |
|
|
print(f"Failed\nWARN: Model not moved to '{device}' [1]") |
|
|
else: |
|
|
print(f"WARN: Model not moved to '{device}' [2]") |
|
|
|
|
|
def center_of_bbox(bbox): |
|
|
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1] |
|
|
return bbox[0] + w/2, bbox[1] + h/2 |
|
|
|
|
|
def sam_predict(predictor, points, plabs, bbox, threshold): |
|
|
point_coords = None if not points else np.array(points) |
|
|
point_labels = None if not plabs else np.array(plabs) |
|
|
|
|
|
box = np.array([bbox]) if bbox is not None else None |
|
|
|
|
|
cur_masks, scores, _ = predictor.predict(point_coords=point_coords, point_labels=point_labels, box=box) |
|
|
|
|
|
total_masks = [] |
|
|
|
|
|
selected = False |
|
|
max_score = 0 |
|
|
max_mask = None |
|
|
for idx in range(len(scores)): |
|
|
if scores[idx] > max_score: |
|
|
max_score = scores[idx] |
|
|
max_mask = cur_masks[idx] |
|
|
|
|
|
if scores[idx] >= threshold: |
|
|
selected = True |
|
|
total_masks.append(cur_masks[idx]) |
|
|
else: |
|
|
pass |
|
|
|
|
|
if not selected and max_mask is not None: |
|
|
total_masks.append(max_mask) |
|
|
|
|
|
return total_masks |
|
|
|
|
|
def make_2d_mask(mask): |
|
|
if len(mask.shape) == 4: |
|
|
return mask.squeeze(0).squeeze(0) |
|
|
|
|
|
elif len(mask.shape) == 3: |
|
|
return mask.squeeze(0) |
|
|
|
|
|
return mask |
|
|
|
|
|
def gen_detection_hints_from_mask_area(x, y, mask, threshold, use_negative): |
|
|
mask = make_2d_mask(mask) |
|
|
|
|
|
points = [] |
|
|
plabs = [] |
|
|
|
|
|
|
|
|
y_step = max(3, int(mask.shape[0] / 20)) |
|
|
x_step = max(3, int(mask.shape[1] / 20)) |
|
|
|
|
|
for i in range(0, len(mask), y_step): |
|
|
for j in range(0, len(mask[i]), x_step): |
|
|
if mask[i][j] > threshold: |
|
|
points.append((x + j, y + i)) |
|
|
plabs.append(1) |
|
|
elif use_negative and mask[i][j] == 0: |
|
|
points.append((x + j, y + i)) |
|
|
plabs.append(0) |
|
|
|
|
|
return points, plabs |
|
|
|
|
|
def gen_negative_hints(w, h, x1, y1, x2, y2): |
|
|
npoints = [] |
|
|
nplabs = [] |
|
|
|
|
|
|
|
|
y_step = max(3, int(w / 20)) |
|
|
x_step = max(3, int(h / 20)) |
|
|
|
|
|
for i in range(10, h - 10, y_step): |
|
|
for j in range(10, w - 10, x_step): |
|
|
if not (x1 - 10 <= j and j <= x2 + 10 and y1 - 10 <= i and i <= y2 + 10): |
|
|
npoints.append((j, i)) |
|
|
nplabs.append(0) |
|
|
|
|
|
return npoints, nplabs |
|
|
|
|
|
def generate_detection_hints(image, seg, center, detection_hint, dilated_bbox, mask_hint_threshold, use_small_negative, |
|
|
mask_hint_use_negative): |
|
|
[x1, y1, x2, y2] = dilated_bbox |
|
|
|
|
|
points = [] |
|
|
plabs = [] |
|
|
if detection_hint == "center-1": |
|
|
points.append(center) |
|
|
plabs = [1] |
|
|
|
|
|
elif detection_hint == "horizontal-2": |
|
|
gap = (x2 - x1) / 3 |
|
|
points.append((x1 + gap, center[1])) |
|
|
points.append((x1 + gap * 2, center[1])) |
|
|
plabs = [1, 1] |
|
|
|
|
|
elif detection_hint == "vertical-2": |
|
|
gap = (y2 - y1) / 3 |
|
|
points.append((center[0], y1 + gap)) |
|
|
points.append((center[0], y1 + gap * 2)) |
|
|
plabs = [1, 1] |
|
|
|
|
|
elif detection_hint == "rect-4": |
|
|
x_gap = (x2 - x1) / 3 |
|
|
y_gap = (y2 - y1) / 3 |
|
|
points.append((x1 + x_gap, center[1])) |
|
|
points.append((x1 + x_gap * 2, center[1])) |
|
|
points.append((center[0], y1 + y_gap)) |
|
|
points.append((center[0], y1 + y_gap * 2)) |
|
|
plabs = [1, 1, 1, 1] |
|
|
|
|
|
elif detection_hint == "diamond-4": |
|
|
x_gap = (x2 - x1) / 3 |
|
|
y_gap = (y2 - y1) / 3 |
|
|
points.append((x1 + x_gap, y1 + y_gap)) |
|
|
points.append((x1 + x_gap * 2, y1 + y_gap)) |
|
|
points.append((x1 + x_gap, y1 + y_gap * 2)) |
|
|
points.append((x1 + x_gap * 2, y1 + y_gap * 2)) |
|
|
plabs = [1, 1, 1, 1] |
|
|
|
|
|
elif detection_hint == "mask-point-bbox": |
|
|
center = center_of_bbox(seg.bbox) |
|
|
points.append(center) |
|
|
plabs = [1] |
|
|
|
|
|
elif detection_hint == "mask-area": |
|
|
points, plabs = gen_detection_hints_from_mask_area(seg.crop_region[0], seg.crop_region[1], |
|
|
seg.cropped_mask, |
|
|
mask_hint_threshold, use_small_negative) |
|
|
|
|
|
if mask_hint_use_negative == "Outter": |
|
|
npoints, nplabs = gen_negative_hints(image.shape[0], image.shape[1], |
|
|
seg.crop_region[0], seg.crop_region[1], |
|
|
seg.crop_region[2], seg.crop_region[3]) |
|
|
|
|
|
points += npoints |
|
|
plabs += nplabs |
|
|
|
|
|
return points, plabs |
|
|
|
|
|
def combine_masks2(masks): |
|
|
if len(masks) == 0: |
|
|
return None |
|
|
else: |
|
|
initial_cv2_mask = np.array(masks[0]).astype(np.uint8) |
|
|
combined_cv2_mask = initial_cv2_mask |
|
|
|
|
|
for i in range(1, len(masks)): |
|
|
cv2_mask = np.array(masks[i]).astype(np.uint8) |
|
|
|
|
|
if combined_cv2_mask.shape == cv2_mask.shape: |
|
|
combined_cv2_mask = cv2.bitwise_or(combined_cv2_mask, cv2_mask) |
|
|
else: |
|
|
|
|
|
pass |
|
|
|
|
|
mask = torch.from_numpy(combined_cv2_mask) |
|
|
return mask |
|
|
|
|
|
def dilate_mask(mask, dilation_factor, iter=1): |
|
|
if dilation_factor == 0: |
|
|
return make_2d_mask(mask) |
|
|
|
|
|
mask = make_2d_mask(mask) |
|
|
|
|
|
kernel = np.ones((abs(dilation_factor), abs(dilation_factor)), np.uint8) |
|
|
|
|
|
mask = cv2.UMat(mask) |
|
|
kernel = cv2.UMat(kernel) |
|
|
|
|
|
if dilation_factor > 0: |
|
|
result = cv2.dilate(mask, kernel, iter) |
|
|
else: |
|
|
result = cv2.erode(mask, kernel, iter) |
|
|
|
|
|
return result.get() |
|
|
|
|
|
def convert_and_stack_masks(masks): |
|
|
if len(masks) == 0: |
|
|
return None |
|
|
|
|
|
mask_tensors = [] |
|
|
for mask in masks: |
|
|
mask_array = np.array(mask, dtype=np.uint8) |
|
|
mask_tensor = torch.from_numpy(mask_array) |
|
|
mask_tensors.append(mask_tensor) |
|
|
|
|
|
stacked_masks = torch.stack(mask_tensors, dim=0) |
|
|
stacked_masks = stacked_masks.unsqueeze(1) |
|
|
|
|
|
return stacked_masks |
|
|
|
|
|
def merge_and_stack_masks(stacked_masks, group_size): |
|
|
if stacked_masks is None: |
|
|
return None |
|
|
|
|
|
num_masks = stacked_masks.size(0) |
|
|
merged_masks = [] |
|
|
|
|
|
for i in range(0, num_masks, group_size): |
|
|
subset_masks = stacked_masks[i:i + group_size] |
|
|
merged_mask = torch.any(subset_masks, dim=0) |
|
|
merged_masks.append(merged_mask) |
|
|
|
|
|
if len(merged_masks) > 0: |
|
|
merged_masks = torch.stack(merged_masks, dim=0) |
|
|
|
|
|
return merged_masks |
|
|
|
|
|
def make_sam_mask_segmented(sam_model, segs, image, detection_hint, dilation, |
|
|
threshold, bbox_expansion, mask_hint_threshold, mask_hint_use_negative): |
|
|
if getattr(sam_model, 'is_auto_mode', False): |
|
|
device = model_management.get_torch_device() |
|
|
|
|
|
sam_model.to(device) |
|
|
|
|
|
try: |
|
|
predictor = SamPredictor(sam_model) |
|
|
image = np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8) |
|
|
predictor.set_image(image, "RGB") |
|
|
|
|
|
total_masks = [] |
|
|
|
|
|
use_small_negative = mask_hint_use_negative == "Small" |
|
|
|
|
|
|
|
|
segs = segs[1] |
|
|
if detection_hint == "mask-points": |
|
|
points = [] |
|
|
plabs = [] |
|
|
|
|
|
for i in range(len(segs)): |
|
|
bbox = segs[i].bbox |
|
|
center = center_of_bbox(bbox) |
|
|
points.append(center) |
|
|
|
|
|
|
|
|
if use_small_negative and bbox[2] - bbox[0] < 10: |
|
|
plabs.append(0) |
|
|
else: |
|
|
plabs.append(1) |
|
|
|
|
|
detected_masks = sam_predict(predictor, points, plabs, None, threshold) |
|
|
total_masks += detected_masks |
|
|
|
|
|
else: |
|
|
for i in range(len(segs)): |
|
|
bbox = segs[i].bbox |
|
|
center = center_of_bbox(bbox) |
|
|
x1 = max(bbox[0] - bbox_expansion, 0) |
|
|
y1 = max(bbox[1] - bbox_expansion, 0) |
|
|
x2 = min(bbox[2] + bbox_expansion, image.shape[1]) |
|
|
y2 = min(bbox[3] + bbox_expansion, image.shape[0]) |
|
|
|
|
|
dilated_bbox = [x1, y1, x2, y2] |
|
|
|
|
|
points, plabs = generate_detection_hints(image, segs[i], center, detection_hint, dilated_bbox, |
|
|
mask_hint_threshold, use_small_negative, |
|
|
mask_hint_use_negative) |
|
|
|
|
|
detected_masks = sam_predict(predictor, points, plabs, dilated_bbox, threshold) |
|
|
|
|
|
total_masks += detected_masks |
|
|
|
|
|
|
|
|
mask = combine_masks2(total_masks) |
|
|
|
|
|
finally: |
|
|
|
|
|
|
|
|
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") |
|
|
sam_model.to(device) |
|
|
|
|
|
pass |
|
|
|
|
|
|
|
|
mask_working_device = model_management.get_torch_device() |
|
|
|
|
|
if mask is not None: |
|
|
mask = mask.float() |
|
|
|
|
|
|
|
|
|
|
|
mask = mask.unsqueeze(0).unsqueeze(0) |
|
|
if dilation > 0: |
|
|
kernel_size = 1 + dilation * 2 |
|
|
mask = torch.nn.functional.max_pool2d(mask, kernel_size=kernel_size, stride=1, padding=dilation) |
|
|
elif dilation < 0: |
|
|
kernel_size = 1 + abs(dilation) * 2 |
|
|
mask = -torch.nn.functional.max_pool2d(-mask, kernel_size=kernel_size, stride=1, padding=abs(dilation)) |
|
|
mask = mask.squeeze(0).squeeze(0) |
|
|
else: |
|
|
|
|
|
height, width, _ = image.shape |
|
|
mask = torch.zeros( |
|
|
(height, width), dtype=torch.float32, device=mask_working_device |
|
|
) |
|
|
|
|
|
stacked_masks = convert_and_stack_masks(total_masks) |
|
|
|
|
|
return (mask, merge_and_stack_masks(stacked_masks, group_size=3)) |
|
|
|
|
|
def tensor2mask(t: torch.Tensor) -> torch.Tensor: |
|
|
size = t.size() |
|
|
if (len(size) < 4): |
|
|
return t |
|
|
if size[3] == 1: |
|
|
return t[:,:,:,0] |
|
|
elif size[3] == 4: |
|
|
|
|
|
if torch.min(t[:, :, :, 3]).item() != 1.: |
|
|
return t[:,:,:,3] |
|
|
return TF.rgb_to_grayscale(tensor2rgb(t).permute(0,3,1,2), num_output_channels=1)[:,0,:,:] |
|
|
|
|
|
def tensor2rgb(t: torch.Tensor) -> torch.Tensor: |
|
|
size = t.size() |
|
|
if (len(size) < 4): |
|
|
return t.unsqueeze(3).repeat(1, 1, 1, 3) |
|
|
if size[3] == 1: |
|
|
return t.repeat(1, 1, 1, 3) |
|
|
elif size[3] == 4: |
|
|
return t[:, :, :, :3] |
|
|
else: |
|
|
return t |
|
|
|
|
|
def tensor2rgba(t: torch.Tensor) -> torch.Tensor: |
|
|
size = t.size() |
|
|
if (len(size) < 4): |
|
|
return t.unsqueeze(3).repeat(1, 1, 1, 4) |
|
|
elif size[3] == 1: |
|
|
return t.repeat(1, 1, 1, 4) |
|
|
elif size[3] == 3: |
|
|
alpha_tensor = t.new_ones((size[0], size[1], size[2], 1)) |
|
|
return torch.cat((t, alpha_tensor), dim=3) |
|
|
else: |
|
|
return t |
|
|
|