| import os,sys |
| import folder_paths |
|
|
| from PIL import Image |
| import importlib.util |
|
|
| import comfy.utils |
| import numpy as np |
| import torch |
|
|
| from huggingface_hub import hf_hub_download |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torchvision.transforms.functional import normalize |
| |
class REBNCONV(nn.Module):
    """3x3 convolution followed by batch normalisation and an in-place ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        dirate: Dilation rate of the convolution; the padding is set to the
            same value.
        stride: Stride of the convolution.
    """

    def __init__(self, in_ch=3, out_ch=3, dirate=1, stride=1):
        super().__init__()

        # The attribute names double as state-dict keys, so they stay as-is.
        self.conv_s1 = nn.Conv2d(
            in_ch,
            out_ch,
            3,
            padding=1 * dirate,
            dilation=1 * dirate,
            stride=stride,
        )
        self.bn_s1 = nn.BatchNorm2d(out_ch)
        self.relu_s1 = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        features = self.conv_s1(x)
        features = self.bn_s1(features)
        return self.relu_s1(features)
|
|
| |
def _upsample_like(src, tar):
    """Bilinearly resize ``src`` to the spatial size of ``tar``.

    The target size is ``tar.shape[2:]``, i.e. every dimension after the
    first two.
    """
    target_size = tar.shape[2:]
    return F.interpolate(src, size=target_size, mode='bilinear')
|
|
|
|
| |
class RSU7(nn.Module):
    """Encoder/decoder block with a residual connection (seven conv levels).

    The input is projected to ``out_ch`` channels, passed through six
    REBNCONV levels (the first five each followed by a 2x max-pool) and a
    dilated seventh level, then decoded by upsampling and concatenating the
    matching encoder output. The decoder output is added to the input
    projection.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
        img_size: Accepted for signature compatibility; not used here.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3, img_size=512):
        super().__init__()

        self.in_ch = in_ch
        self.mid_ch = mid_ch
        self.out_ch = out_ch

        # Input projection; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder: conv level followed by a 2x down-sampling pool.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool5 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Deepest level uses a dilated convolution instead of pooling again.
        self.rebnconv7 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder: each level consumes the concatenation of two feature maps.
        self.rebnconv6d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoded features added to the input projection."""
        # The unpacking only succeeds for a 4-D input; the sizes are unused.
        _batch, _channels, _height, _width = x.shape

        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        enc6 = self.rebnconv6(self.pool5(enc5))
        bottom = self.rebnconv7(enc6)

        # The first decoder level works at the resolution of enc6 already.
        dec = self.rebnconv6d(torch.cat((bottom, enc6), 1))

        # Remaining levels: upsample to the skip's size, concatenate, convolve.
        for level, skip in (
            (self.rebnconv5d, enc5),
            (self.rebnconv4d, enc4),
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        ):
            dec = level(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
|
|
|
| |
class RSU6(nn.Module):
    """Encoder/decoder block with a residual connection (six conv levels).

    Same layout as the deeper variant but with four pooled encoder levels,
    one un-pooled level and a dilated bottom level.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool4 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom level.
        self.rebnconv6 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv5d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoded features added to the input projection."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        enc5 = self.rebnconv5(self.pool4(enc4))
        bottom = self.rebnconv6(enc5)

        # The first decoder level works at the resolution of enc5 already.
        dec = self.rebnconv5d(torch.cat((bottom, enc5), 1))

        # Remaining levels: upsample to the skip's size, concatenate, convolve.
        for level, skip in (
            (self.rebnconv4d, enc4),
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        ):
            dec = level(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
|
| |
class RSU5(nn.Module):
    """Encoder/decoder block with a residual connection (five conv levels).

    Three pooled encoder levels, one un-pooled level and a dilated bottom
    level, followed by a decoder that upsamples and concatenates the matching
    encoder output.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool3 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom level.
        self.rebnconv5 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv4d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoded features added to the input projection."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        enc4 = self.rebnconv4(self.pool3(enc3))
        bottom = self.rebnconv5(enc4)

        # The first decoder level works at the resolution of enc4 already.
        dec = self.rebnconv4d(torch.cat((bottom, enc4), 1))

        # Remaining levels: upsample to the skip's size, concatenate, convolve.
        for level, skip in (
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        ):
            dec = level(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
|
| |
class RSU4(nn.Module):
    """Encoder/decoder block with a residual connection (four conv levels).

    Two pooled encoder levels, one un-pooled level and a dilated bottom
    level, followed by a decoder that upsamples and concatenates the matching
    encoder output.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.pool1 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=1)
        self.pool2 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=1)

        # Dilated bottom level.
        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=2)

        # Decoder.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=1)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoded features added to the input projection."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(self.pool1(enc1))
        enc3 = self.rebnconv3(self.pool2(enc2))
        bottom = self.rebnconv4(enc3)

        # The first decoder level works at the resolution of enc3 already.
        dec = self.rebnconv3d(torch.cat((bottom, enc3), 1))

        # Remaining levels: upsample to the skip's size, concatenate, convolve.
        for level, skip in (
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        ):
            dec = level(torch.cat((_upsample_like(dec, skip), skip), 1))

        return dec + residual
|
|
| |
class RSU4F(nn.Module):
    """Four-level residual block that uses dilation instead of pooling.

    The encoder levels use dilation rates 1, 2, 4 and 8 and there is no
    pooling or upsampling; the decoder concatenates each result with the
    matching encoder output. The decoder output is added to the input
    projection.

    Args:
        in_ch: Number of input channels.
        mid_ch: Number of channels used inside the block.
        out_ch: Number of output channels.
    """

    def __init__(self, in_ch=3, mid_ch=12, out_ch=3):
        super().__init__()

        # Input projection; its output is also the residual branch.
        self.rebnconvin = REBNCONV(in_ch, out_ch, dirate=1)

        # Encoder with growing dilation.
        self.rebnconv1 = REBNCONV(out_ch, mid_ch, dirate=1)
        self.rebnconv2 = REBNCONV(mid_ch, mid_ch, dirate=2)
        self.rebnconv3 = REBNCONV(mid_ch, mid_ch, dirate=4)

        self.rebnconv4 = REBNCONV(mid_ch, mid_ch, dirate=8)

        # Decoder with shrinking dilation.
        self.rebnconv3d = REBNCONV(mid_ch * 2, mid_ch, dirate=4)
        self.rebnconv2d = REBNCONV(mid_ch * 2, mid_ch, dirate=2)
        self.rebnconv1d = REBNCONV(mid_ch * 2, out_ch, dirate=1)

    def forward(self, x):
        """Return the decoded features added to the input projection."""
        residual = self.rebnconvin(x)

        enc1 = self.rebnconv1(residual)
        enc2 = self.rebnconv2(enc1)
        enc3 = self.rebnconv3(enc2)
        dec = self.rebnconv4(enc3)

        # Each decoder level sees the previous result joined with its skip.
        for level, skip in (
            (self.rebnconv3d, enc3),
            (self.rebnconv2d, enc2),
            (self.rebnconv1d, enc1),
        ):
            dec = level(torch.cat((dec, skip), 1))

        return dec + residual
|
|
|
|
class myrebnconv(nn.Module):
    """Configurable convolution followed by batch norm and an in-place ReLU.

    Args:
        in_ch: Number of input channels.
        out_ch: Number of output channels.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Convolution padding.
        dilation: Convolution dilation.
        groups: Number of convolution groups.
    """

    def __init__(self, in_ch=3, out_ch=1, kernel_size=3, stride=1,
                 padding=1, dilation=1, groups=1):
        super().__init__()

        conv_options = dict(
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
        )
        # Attribute names double as state-dict keys, so they stay as-is.
        self.conv = nn.Conv2d(in_ch, out_ch, **conv_options)
        self.bn = nn.BatchNorm2d(out_ch)
        self.rl = nn.ReLU(inplace=True)

    def forward(self, x):
        """Run ``x`` through conv -> batch norm -> ReLU and return the result."""
        normalised = self.bn(self.conv(x))
        return self.rl(normalised)
|
|
|
|
class BriaRMBG(nn.Module):
    """Segmentation network assembled from nested RSU encoder/decoder blocks.

    Args:
        in_ch: Number of channels of the input image tensor.
        out_ch: Number of channels produced by each side-output convolution.
    """

    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()

        # Strided stem convolution.
        self.conv_in = nn.Conv2d(in_ch, 64, 3, stride=2, padding=1)
        # Registered for completeness; forward() does not call it.
        self.pool_in = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        # Encoder stages, each followed by a 2x max-pool (except the last).
        self.stage1 = RSU7(64, 32, 64)
        self.pool12 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage2 = RSU6(64, 32, 128)
        self.pool23 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage3 = RSU5(128, 64, 256)
        self.pool34 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage4 = RSU4(256, 128, 512)
        self.pool45 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage5 = RSU4F(512, 256, 512)
        self.pool56 = nn.MaxPool2d(2, stride=2, ceil_mode=True)

        self.stage6 = RSU4F(512, 256, 512)

        # Decoder stages; inputs are concatenations of two feature maps.
        self.stage5d = RSU4F(1024, 256, 512)
        self.stage4d = RSU4(1024, 128, 256)
        self.stage3d = RSU5(512, 64, 128)
        self.stage2d = RSU6(256, 32, 64)
        self.stage1d = RSU7(128, 16, 64)

        # One side-output convolution per decoder level.
        self.side1 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side2 = nn.Conv2d(64, out_ch, 3, padding=1)
        self.side3 = nn.Conv2d(128, out_ch, 3, padding=1)
        self.side4 = nn.Conv2d(256, out_ch, 3, padding=1)
        self.side5 = nn.Conv2d(512, out_ch, 3, padding=1)
        self.side6 = nn.Conv2d(512, out_ch, 3, padding=1)

    def forward(self, x):
        """Run the network.

        Returns:
            A pair ``(side_outputs, features)``: six sigmoid-activated side
            outputs resized to the spatial size of ``x``, and the six feature
            maps they were computed from (finest level first).
        """
        stem = self.conv_in(x)

        # Encoder.
        enc1 = self.stage1(stem)
        enc2 = self.stage2(self.pool12(enc1))
        enc3 = self.stage3(self.pool23(enc2))
        enc4 = self.stage4(self.pool34(enc3))
        enc5 = self.stage5(self.pool45(enc4))
        enc6 = self.stage6(self.pool56(enc5))

        # Decoder: upsample to the skip's size, concatenate, run the stage.
        dec5 = self.stage5d(torch.cat((_upsample_like(enc6, enc5), enc5), 1))
        dec4 = self.stage4d(torch.cat((_upsample_like(dec5, enc4), enc4), 1))
        dec3 = self.stage3d(torch.cat((_upsample_like(dec4, enc3), enc3), 1))
        dec2 = self.stage2d(torch.cat((_upsample_like(dec3, enc2), enc2), 1))
        dec1 = self.stage1d(torch.cat((_upsample_like(dec2, enc1), enc1), 1))

        features = [dec1, dec2, dec3, dec4, dec5, enc6]
        heads = (
            self.side1,
            self.side2,
            self.side3,
            self.side4,
            self.side5,
            self.side6,
        )
        # Side outputs are brought to the input resolution before the sigmoid.
        side_outputs = [
            F.sigmoid(_upsample_like(head(feature), x))
            for head, feature in zip(heads, features)
        ]

        return side_outputs, features
| |
|
|
|
|
|
|
def get_U2NET_model_path():
    """Return the directory that holds the rembg model files.

    Uses the first entry of ``folder_paths.get_folder_paths('rembg')`` and
    falls back to ``<folder_paths.models_dir>/rembg`` when that lookup fails
    (it raises, or the returned sequence is empty).

    Returns:
        The model directory path.
    """
    try:
        return folder_paths.get_folder_paths('rembg')[0]
    except Exception:
        # Only ordinary errors trigger the fallback; KeyboardInterrupt and
        # SystemExit propagate instead of being swallowed by a bare except.
        return os.path.join(folder_paths.models_dir, "rembg")
| |
|
|
# Resolve the model directory once at import time and publish it through the
# U2NET_HOME environment variable (presumably read by rembg — confirm).
U2NET_HOME = get_U2NET_model_path()
os.environ["U2NET_HOME"] = U2NET_HOME

# Flipped to True further down once ``rembg`` has been imported successfully.
# A ``global`` declaration at module level is a no-op, so none is needed.
_available = False
|
|
|
|
def get_rembg_models(path):
    """Collect the extension-less names of all files below a directory.

    The directory is walked recursively; files whose name starts with a dot
    are skipped.

    Args:
        path: Directory to scan.

    Returns:
        List of file names with their final extension removed.
    """
    return [
        os.path.splitext(os.path.basename(entry))[0]
        for _dirpath, _dirnames, entries in os.walk(path)
        for entry in entries
        if not entry.startswith('.')
    ]
|
|
|
|
def is_installed(package):
    """Report whether ``package`` can be located by the import system.

    Returns ``False`` when ``importlib.util.find_spec`` finds nothing or
    raises ``ModuleNotFoundError``; ``True`` otherwise.
    """
    try:
        found = importlib.util.find_spec(package) is not None
    except ModuleNotFoundError:
        found = False
    return found
|
|
|
|
# Import rembg, installing it with pip on the fly when it is missing.
# ``_available`` records whether ``new_session`` / ``remove`` were imported.
try:
    if not is_installed('rembg'):
        import subprocess

        print('#pip install rembg[gpu]')

        result = subprocess.run(
            [sys.executable, '-s', '-m', 'pip', 'install', 'rembg[gpu]'],
            capture_output=True,
            text=True,
        )

        if result.returncode == 0:
            print("#install success")
            # The package appeared after the interpreter started: drop the
            # import finders' caches so the fresh install can be located.
            importlib.invalidate_caches()
            from rembg import new_session, remove
            _available = True
        else:
            print("#install error")
            # Surface pip's own diagnostics instead of discarding them.
            if result.stderr:
                print(result.stderr)
    else:
        from rembg import new_session, remove
        _available = True
except Exception as err:
    # Best effort: the module still loads without rembg, but the reason is
    # reported instead of being silently swallowed.
    print(f"#rembg unavailable: {err!r}")
    _available = False
|
|
|
|
def run_briarmbg(images=None):
    """Remove the background of PIL images with the BriaRMBG network.

    The weights are read from ``<U2NET_HOME>/briarmbg.pth``; when that file
    is missing, ``model.pth`` is downloaded from the ``briaai/RMBG-1.4``
    repository and renamed. Every image is resized to 1024x1024 for the
    network and the predicted mask is resized back to the image's own size.

    Args:
        images: Iterable of PIL images. ``None`` or an empty sequence yields
            three empty lists without loading the model.

    Returns:
        Tuple ``(masks, rgba_images, rgb_images)`` of equally long lists:
        the ``L`` mode masks, the inputs with the mask as alpha channel, and
        the cut-outs pasted on a black RGB background.
    """
    if images is None:
        images = []
    if isinstance(images, (list, tuple)) and not images:
        # Nothing to process: skip the download and the model load.
        return ([], [], [])

    weights_path = os.path.join(U2NET_HOME, 'briarmbg.pth')
    if not os.path.exists(weights_path):
        downloaded = hf_hub_download("briaai/RMBG-1.4",
                                     local_dir=U2NET_HOME,
                                     filename='model.pth',
                                     local_dir_use_symlinks=False,
                                     endpoint='https://hf-mirror.com')
        os.rename(downloaded, weights_path)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    net = BriaRMBG()
    # NOTE(review): torch.load unpickles the file, so the checkpoint must
    # come from a trusted source.
    net.load_state_dict(torch.load(weights_path, map_location="cpu"))
    net = net.to(device)
    net.eval()

    model_input_size = (1024, 1024)
    masks = []
    rgba_images = []
    rgb_images = []

    # Inference only: without no_grad() every forward pass would record an
    # autograd graph and hold on to its activations.
    with torch.no_grad():
        for orig_image in images:
            w, h = orig_image.size

            image = orig_image.convert('RGB')
            image = image.resize(model_input_size, Image.BILINEAR)

            im_tensor = torch.tensor(np.array(image), dtype=torch.float32).permute(2, 0, 1)
            im_tensor = torch.unsqueeze(im_tensor, 0)
            im_tensor = torch.divide(im_tensor, 255.0)
            im_tensor = normalize(im_tensor, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
            im_tensor = im_tensor.to(device)

            # First side output of the network, resized to the input size.
            result = net(im_tensor)
            result = torch.squeeze(
                F.interpolate(result[0][0], size=(h, w), mode='bilinear'), 0)

            # Min-max stretch; a constant map is left untouched because the
            # stretch would divide by zero (the sigmoid output is already
            # within [0, 1]).
            ma = torch.max(result)
            mi = torch.min(result)
            span = ma - mi
            if span > 0:
                result = (result - mi) / span

            im_array = (result * 255).cpu().numpy().astype(np.uint8)
            # Index the leading axis instead of np.squeeze so that masks
            # that are one pixel wide or high keep both dimensions.
            mask = Image.fromarray(im_array[0]).convert('L')
            masks.append(mask)

            # Input with the mask as alpha channel.
            image_rgba = orig_image.convert("RGBA")
            image_rgba.putalpha(mask)
            rgba_images.append(image_rgba)

            # Cut-out pasted on a black background.
            rgb_image = Image.new("RGB", image_rgba.size, (0, 0, 0))
            rgb_image.paste(image_rgba, mask=image_rgba.split()[3])
            rgb_images.append(rgb_image)

    return (masks, rgba_images, rgb_images)
|
|
|
|
def run_rembg(model_name= "unet",images=[],callback=None):
    """Remove the background of PIL images with a rembg session.

    Args:
        model_name: Model name passed to ``new_session``.
            NOTE(review): the default "unet" does not look like a rembg
            model name ("u2net"?) — confirm.
        images: Iterable of PIL images (never modified here).
        callback: Optional progress object; ``update(1)`` is called once per
            input image.

    Returns:
        Tuple ``(masks, rgba_images, rgb_images)``: the ``L`` mode masks, the
        inputs with the mask as alpha channel, and the cut-outs pasted on a
        black RGB background. For ``"u2net_cloth_seg"`` one entry is produced
        per mask slice, so the lists can be longer than ``images``.
    """
    session = new_session(model_name)
    masks, rgba_images, rgb_images = [], [], []

    def collect(source, alpha_mask):
        # Record the mask, the RGBA cut-out and the cut-out on black.
        masks.append(alpha_mask)

        cutout = source.convert("RGBA")
        cutout.putalpha(alpha_mask)
        rgba_images.append(cutout)

        on_black = Image.new("RGB", cutout.size, (0, 0, 0))
        on_black.paste(cutout, mask=cutout.split()[3])
        rgb_images.append(on_black)

    for img in images:
        mask = remove(img, session=session, only_mask=True, post_process_mask=True)

        if model_name == "u2net_cloth_seg":
            # The mask is treated as several image-sized masks stacked
            # vertically (presumably one per clothing class — confirm).
            mask_width, stacked_height = mask.size
            for index in range(stacked_height // img.height):
                box = (0, index * img.height, mask_width, (index + 1) * img.height)
                collect(img, mask.crop(box).convert('L'))
        else:
            collect(img, mask.convert('L'))

        if callback:
            callback.update(1)

    return (masks, rgba_images, rgb_images)
|
|
|
|
| |
def tensor2pil(image):
    """Convert a tensor into an 8-bit PIL image.

    The tensor is moved to the CPU, singleton dimensions are squeezed away,
    and the values are scaled by 255, clipped to [0, 255] and cast to uint8.
    """
    array = image.cpu().numpy().squeeze()
    pixels = np.clip(255. * array, 0, 255).astype(np.uint8)
    return Image.fromarray(pixels)
|
|
| |
def pil2tensor(image):
    """Convert an image to a float32 tensor with a new leading axis.

    The pixel values are divided by 255.
    """
    scaled = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(scaled).unsqueeze(0)
|
|
|
|
class RembgNode_:
    """Node that removes image backgrounds via rembg or the BriaRMBG model."""

    # Snapshot of the module-level flag, taken when the class is created.
    available = _available

    @classmethod
    def INPUT_TYPES(s):
        """Describe the inputs: an image and one of the model names found on disk."""
        required = {
            "image": ("IMAGE",),
            "model_name": (get_rembg_models(U2NET_HOME),),
        }
        return {"required": required}

    RETURN_TYPES = ("MASK", "IMAGE", "RGBA",)
    RETURN_NAMES = ("masks", "images", "RGBAs")

    FUNCTION = "run"

    CATEGORY = "♾️Mixlab/Mask"
    OUTPUT_NODE = True
    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True, True, True,)

    def run(self, image, model_name):
        """Remove the background of every image.

        Args:
            image: Sequence of image batches; every element of every batch is
                converted with ``tensor2pil``.
            model_name: Sequence whose first element is the model to use.

        Returns:
            Tuple ``(masks, rgb_images, rgba_images)`` of tensor lists.
        """
        selected = model_name[0]

        # Flatten the nested batches into a single list of PIL images.
        pil_images = [tensor2pil(frame) for batch in image for frame in batch]

        if selected == 'briarmbg':
            masks, rgba_images, rgb_images = run_briarmbg(pil_images)
        else:
            progress = comfy.utils.ProgressBar(len(pil_images))
            masks, rgba_images, rgb_images = run_rembg(selected, pil_images, progress)

        mask_tensors = [pil2tensor(item) for item in masks]
        rgba_tensors = [pil2tensor(item) for item in rgba_images]
        rgb_tensors = [pil2tensor(item) for item in rgb_images]

        return (mask_tensors, rgb_tensors, rgba_tensors,)