# ComfyUI-RMBG
# This custom node for ComfyUI provides object removal using the Big-Lama model.
#
# Reference: https://github.com/advimman/lama
#
# This integration script is released under the GPL-3.0 License.
# When using or modifying this code, please respect both the original model licenses
# and this integration's license terms.
#
# Source: https://github.com/AILab-AI/ComfyUI-RMBG
import os
import torch
import numpy as np
from PIL import Image, ImageOps, ImageFilter
import folder_paths
from comfy.model_management import get_torch_device
from torchvision import transforms
from huggingface_hub import hf_hub_download
import shutil
import gc
def tensor2pil(image):
    """Convert an image tensor to an 8-bit PIL image.

    All singleton dimensions are squeezed away first (so a batch of one
    becomes a single image). Values are scaled by 255 and clipped to
    [0, 255] before the uint8 cast; inputs are presumably floats in [0, 1].
    """
    pixels = image.cpu().numpy().squeeze()
    scaled = np.clip(pixels * 255.0, 0, 255)
    return Image.fromarray(scaled.astype(np.uint8))
def pil2tensor(image):
    """Convert a PIL image (or array-like) to a float32 tensor in [0, 1].

    A leading batch axis of size 1 is always added.
    """
    normalized = np.array(image, dtype=np.float32) / 255.0
    tensor = torch.from_numpy(normalized)
    return tensor.unsqueeze(0)
def pil2comfy(image):
    """Convert a PIL image (or array-like) to a float32 tensor in [0, 1] for ComfyUI.

    A 3-D result (H, W, C) gets a leading batch axis of size 1; any other
    rank (e.g. a 2-D single-channel image) is returned unchanged.
    """
    tensor = torch.from_numpy(np.array(image, dtype=np.float32) / 255.0)
    return tensor.unsqueeze(0) if tensor.dim() == 3 else tensor
def pad_image(image, is_mask=False, modulo=8):
    """Pad ``image`` on the right and bottom so both sides are multiples of ``modulo``.

    The original pixels are pasted at the top-left corner, so cropping the
    box ``(0, 0, w, h)`` recovers the unpadded image.

    Args:
        image: PIL image to pad.
        is_mask: kept for backward compatibility. The padding is filled with
            0 for both images and masks, which makes the fill explicit instead
            of relying on ``Image.new(color=None)``, whose pixels Pillow
            documents as uninitialised.
        modulo: required divisor for width and height (default 8).

    Returns:
        A new PIL image with the same mode and padded dimensions.

    Raises:
        ValueError: if ``modulo`` is smaller than 1.
    """
    if modulo < 1:
        raise ValueError("modulo must be a positive integer")
    w, h = image.size
    # (-n) % m is the smallest amount that rounds n up to a multiple of m.
    padded_size = (w + (-w) % modulo, h + (-h) % modulo)
    padded = Image.new(image.mode, padded_size, color=0)
    padded.paste(image, (0, 0))
    return padded
def cropimage(image, w, h):
    """Return the top-left ``w`` x ``h`` region of ``image``.

    Used to strip the right/bottom padding added before inference.
    """
    top_left_box = (0, 0, w, h)
    return image.crop(top_left_box)
# Torch device chosen by ComfyUI's model management, resolved once at import.
DEVICE = get_torch_device()

# Register <models_dir>/RMBG under the "rmbg" key so ComfyUI's folder_paths
# can resolve model files stored there.
folder_paths.add_model_folder_path(
    "rmbg",
    os.path.join(folder_paths.models_dir, "RMBG"),
)
class AILab_LamaRemover:
    """ComfyUI node that erases masked regions from images with Big-Lama.

    The TorchScript model ``big-lama.pt`` lives in
    ``<models_dir>/RMBG/Lama`` and is downloaded and loaded lazily on the
    first call to :meth:`remove_object`.
    """

    @classmethod
    def INPUT_TYPES(cls):
        """Declare the node's inputs (types, widget ranges, tooltips) for ComfyUI."""
        tooltips = {
            "images": "Input images to be processed",
            "masks": "Masks defining areas to be removed (white=remove)",
            "removal_strength": "Strength of the removal effect (higher values increase the effect area)",
            "edge_smoothness": "Controls edge smoothness (higher values create smoother transitions)"
        }
        return {
            "required": {
                "images": ("IMAGE", {"tooltip": tooltips["images"]}),
                "masks": ("MASK", {"tooltip": tooltips["masks"]}),
                "removal_strength": ("INT", {"default": 230, "min": 0, "max": 255, "step": 1, "display": "slider", "tooltip": tooltips["removal_strength"]}),
                "edge_smoothness": ("INT", {"default": 8, "min": 0, "max": 20, "step": 1, "display": "slider", "tooltip": tooltips["edge_smoothness"]}),
            },
        }

    CATEGORY = "🧪AILab/🧽RMBG"
    RETURN_NAMES = ("images",)
    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "remove_object"

    def __init__(self):
        self.model = None  # TorchScript module; populated by load_model()
        self.device = DEVICE
        self.cache_dir = os.path.join(folder_paths.models_dir, "RMBG", "Lama")
        self.model_path = os.path.join(self.cache_dir, "big-lama.pt")
        self.to_pil = transforms.ToPILImage()

    def load_model(self):
        """Load the TorchScript model once, downloading it first if missing.

        If loading onto the ComfyUI-selected device fails, retry on CUDA when
        available (otherwise CPU) and keep that as ``self.device``.
        """
        if self.model is not None:
            return
        if not os.path.exists(self.model_path):
            self.download_model()
        try:
            self.model = torch.jit.load(self.model_path, map_location=self.device)
        except Exception as e:
            print(f"Can't use comfy device: {str(e)}")
            self.device = "cuda" if torch.cuda.is_available() else "cpu"
            self.model = torch.jit.load(self.model_path, map_location=self.device)
        self.model.eval()
        self.model.to(self.device)

    def download_model(self):
        """Download ``big-lama.pt`` from the ``1038lab/Lama`` Hugging Face repo.

        Raises:
            RuntimeError: if the download fails (the original error is chained).
        """
        print("Downloading Big-Lama model...")
        os.makedirs(self.cache_dir, exist_ok=True)
        try:
            downloaded_path = hf_hub_download(
                repo_id="1038lab/Lama",
                filename="big-lama.pt",
                local_dir=self.cache_dir,
                local_dir_use_symlinks=False,
            )
            # Only move when the hub returned a different file than model_path;
            # compare normalized absolute paths so equivalent spellings match.
            if os.path.abspath(downloaded_path) != os.path.abspath(self.model_path):
                shutil.move(downloaded_path, self.model_path)
            print("Big-Lama model downloaded successfully")
        except Exception as e:
            raise RuntimeError(f"Error downloading Big-Lama model: {str(e)}") from e

    def process_with_model(self, img_tensor, mask_tensor):
        """Run the model on one image/mask pair and return batch item 0 on the CPU."""
        with torch.inference_mode():
            img_tensor = img_tensor.to(self.device)
            mask_tensor = mask_tensor.to(self.device)
            result = self.model(img_tensor, mask_tensor)
            result_cpu = result[0].cpu()
            del img_tensor
            del mask_tensor
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            return result_cpu

    @staticmethod
    def _select_mask(masks, index):
        """Return the 2-D mask to use for image number ``index``.

        A 2-D ``masks`` tensor is treated as one mask. When fewer masks than
        images are given, the last mask is reused, so every image is
        processed instead of being silently dropped.
        """
        if masks.dim() == 2:
            masks = masks.unsqueeze(0)
        return masks[min(index, masks.shape[0] - 1)]

    def _prepare_mask(self, mask, size, removal_strength, edge_smoothness):
        """Build the padded binary mask fed to the model (255 = inpaint, 0 = keep).

        ``mask`` uses white for "remove". It is resized to ``size`` (the
        unpadded image size) *before* padding, so it stays aligned with the
        image even when the resolutions differ. It is then inverted, blurred
        by ``edge_smoothness``, and thresholded: blurred values
        <= ``removal_strength`` become 255.
        """
        mask_np = np.clip(mask.cpu().numpy() * 255.0, 0, 255).astype(np.uint8)
        mask_pil = Image.fromarray(mask_np)
        if mask_pil.size != size:
            try:
                mask_pil = mask_pil.resize(size, Image.LANCZOS)
            except AttributeError:
                mask_pil = mask_pil.resize(size, Image.ANTIALIAS)
        p_mask = pad_image(mask_pil, is_mask=True)
        p_mask = ImageOps.invert(p_mask)
        p_mask = p_mask.filter(ImageFilter.GaussianBlur(radius=edge_smoothness))
        return p_mask.point(lambda x: 0 if x > removal_strength else 255)

    def remove_object(self, images, masks, removal_strength, edge_smoothness):
        """Inpaint the masked regions of every image in ``images``.

        Returns:
            A 1-tuple holding the concatenated result batch, one output per
            input image.

        Raises:
            RuntimeError: wrapping any failure (original error chained).
        """
        try:
            self.load_model()
            results = []
            for index, image in enumerate(images):
                # The model input is built as 3 channels; convert RGBA/L inputs.
                ori_image = tensor2pil(image).convert("RGB")
                w, h = ori_image.size
                p_image = pad_image(ori_image)
                gray = self._prepare_mask(
                    self._select_mask(masks, index), (w, h), removal_strength, edge_smoothness
                )
                img_tensor = torch.from_numpy(np.array(p_image, dtype=np.float32)).permute(2, 0, 1).unsqueeze(0) / 255.0
                mask_tensor = torch.from_numpy(np.array(gray, dtype=np.float32)).unsqueeze(0).unsqueeze(0) / 255.0
                result = self.process_with_model(img_tensor, mask_tensor)
                # ToPILImage converts floats via mul(255).byte() without clamping,
                # so out-of-range model outputs would wrap into wrong colors.
                result_img = self.to_pil(result.squeeze().clamp(0.0, 1.0))
                if result_img.width > w or result_img.height > h:
                    result_img = cropimage(result_img, w, h)
                results.append(pil2comfy(result_img))
                del result
                gc.collect()
            return (torch.cat(results, dim=0),)
        except Exception as e:
            import traceback
            print(traceback.format_exc())
            raise RuntimeError(f"Error in object removal: {str(e)}") from e
        finally:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
# Node registration tables, presumably read by ComfyUI's custom-node loader:
# node id -> implementing class, and node id -> label shown in the UI.
NODE_CLASS_MAPPINGS = {"AILab_LamaRemover": AILab_LamaRemover}
NODE_DISPLAY_NAME_MAPPINGS = {"AILab_LamaRemover": "Lama Remover (RMBG)"}