|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| import os
|
| import torch
|
| import numpy as np
|
| from PIL import Image, ImageOps, ImageFilter
|
| import folder_paths
|
| from comfy.model_management import get_torch_device
|
| from torchvision import transforms
|
| from huggingface_hub import hf_hub_download
|
| import shutil
|
| import gc
|
|
|
| def tensor2pil(image):
|
| return Image.fromarray(np.clip(255. * image.cpu().numpy().squeeze(), 0, 255).astype(np.uint8))
|
|
|
| def pil2tensor(image):
|
| return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
|
|
|
| def pil2comfy(image):
|
| img_tensor = torch.from_numpy(np.array(image).astype(np.float32) / 255.0)
|
| if len(img_tensor.shape) == 3:
|
| img_tensor = img_tensor.unsqueeze(0)
|
| return img_tensor
|
|
|
| def pad_image(image, is_mask=False):
|
| w, h = image.size
|
| if w % 8 != 0:
|
| w = w + (8 - w % 8)
|
| if h % 8 != 0:
|
| h = h + (8 - h % 8)
|
|
|
| fill_color = 0 if is_mask else None
|
| padded = Image.new(image.mode, (w, h), color=fill_color)
|
| padded.paste(image, (0, 0))
|
| return padded
|
|
|
| def cropimage(image, w, h):
|
| return image.crop((0, 0, w, h))
|
|
|
| DEVICE = get_torch_device()
|
| folder_paths.add_model_folder_path("rmbg", os.path.join(folder_paths.models_dir, "RMBG"))
|
|
|
| class AILab_LamaRemover:
|
| @classmethod
|
| def INPUT_TYPES(s):
|
| tooltips = {
|
| "images": "Input images to be processed",
|
| "masks": "Masks defining areas to be removed (white=remove)",
|
| "removal_strength": "Strength of the removal effect (higher values increase the effect area)",
|
| "edge_smoothness": "Controls edge smoothness (higher values create smoother transitions)"
|
| }
|
|
|
| return {
|
| "required": {
|
| "images": ("IMAGE", {"tooltip": tooltips["images"]}),
|
| "masks": ("MASK", {"tooltip": tooltips["masks"]}),
|
| "removal_strength": ("INT", {"default": 230, "min": 0, "max": 255, "step": 1, "display": "slider", "tooltip": tooltips["removal_strength"]}),
|
| "edge_smoothness": ("INT", {"default": 8, "min": 0, "max": 20, "step": 1, "display": "slider", "tooltip": tooltips["edge_smoothness"]}),
|
| },
|
| }
|
|
|
| CATEGORY = "🧪AILab/🧽RMBG"
|
| RETURN_NAMES = ("images",)
|
| RETURN_TYPES = ("IMAGE",)
|
| FUNCTION = "remove_object"
|
|
|
| def __init__(self):
|
| self.model = None
|
| self.device = DEVICE
|
| self.cache_dir = os.path.join(folder_paths.models_dir, "RMBG", "Lama")
|
| self.model_path = os.path.join(self.cache_dir, "big-lama.pt")
|
| self.to_pil = transforms.ToPILImage()
|
|
|
| def load_model(self):
|
| if self.model is not None:
|
| return
|
|
|
| if not os.path.exists(self.model_path):
|
| self.download_model()
|
|
|
| try:
|
| self.model = torch.jit.load(self.model_path, map_location=self.device)
|
| except Exception as e:
|
| print(f"Can't use comfy device: {str(e)}")
|
| self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
| self.model = torch.jit.load(self.model_path, map_location=self.device)
|
|
|
| self.model.eval()
|
| self.model.to(self.device)
|
|
|
| def download_model(self):
|
| print("Downloading Big-Lama model...")
|
| os.makedirs(self.cache_dir, exist_ok=True)
|
|
|
| try:
|
| downloaded_path = hf_hub_download(
|
| repo_id="1038lab/Lama",
|
| filename="big-lama.pt",
|
| local_dir=self.cache_dir,
|
| local_dir_use_symlinks=False
|
| )
|
|
|
| if os.path.dirname(downloaded_path) != self.cache_dir:
|
| shutil.move(downloaded_path, self.model_path)
|
|
|
| print("Big-Lama model downloaded successfully")
|
| except Exception as e:
|
| raise RuntimeError(f"Error downloading Big-Lama model: {str(e)}")
|
|
|
| def process_with_model(self, img_tensor, mask_tensor):
|
| with torch.inference_mode():
|
| img_tensor = img_tensor.to(self.device)
|
| mask_tensor = mask_tensor.to(self.device)
|
|
|
| result = self.model(img_tensor, mask_tensor)
|
| result_cpu = result[0].cpu()
|
|
|
| del img_tensor
|
| del mask_tensor
|
| if torch.cuda.is_available():
|
| torch.cuda.empty_cache()
|
|
|
| return result_cpu
|
|
|
| def remove_object(self, images, masks, removal_strength, edge_smoothness):
|
| try:
|
| self.load_model()
|
| results = []
|
|
|
| for image, mask in zip(images, masks):
|
| ori_image = tensor2pil(image)
|
| w, h = ori_image.size
|
| p_image = pad_image(ori_image)
|
|
|
| mask_np = mask.cpu().numpy()
|
| mask_pil = Image.fromarray((mask_np * 255).astype(np.uint8))
|
| p_mask = pad_image(mask_pil, is_mask=True)
|
|
|
| if p_mask.size != p_image.size:
|
| try:
|
| p_mask = p_mask.resize(p_image.size, Image.LANCZOS)
|
| except AttributeError:
|
| p_mask = p_mask.resize(p_image.size, Image.ANTIALIAS)
|
|
|
| p_mask = ImageOps.invert(p_mask)
|
| p_mask = p_mask.filter(ImageFilter.GaussianBlur(radius=edge_smoothness))
|
| gray = p_mask.point(lambda x: 0 if x > removal_strength else 255)
|
|
|
| img_tensor = torch.FloatTensor(np.array(p_image)).permute(2, 0, 1).unsqueeze(0) / 255.0
|
| mask_tensor = torch.FloatTensor(np.array(gray)).unsqueeze(0).unsqueeze(0) / 255.0
|
|
|
| result = self.process_with_model(img_tensor, mask_tensor)
|
| result_img = self.to_pil(result.squeeze())
|
|
|
| if result_img.width > w or result_img.height > h:
|
| result_img = cropimage(result_img, w, h)
|
|
|
| result_tensor = pil2comfy(result_img)
|
| results.append(result_tensor)
|
|
|
| del result
|
| gc.collect()
|
|
|
| return (torch.cat(results, dim=0),)
|
|
|
| except Exception as e:
|
| import traceback
|
| print(traceback.format_exc())
|
| raise RuntimeError(f"Error in object removal: {str(e)}")
|
| finally:
|
| if torch.cuda.is_available():
|
| torch.cuda.empty_cache()
|
|
|
| NODE_CLASS_MAPPINGS = {
|
| "AILab_LamaRemover": AILab_LamaRemover,
|
| }
|
|
|
| NODE_DISPLAY_NAME_MAPPINGS = {
|
| "AILab_LamaRemover": "Lama Remover (RMBG)",
|
| } |