| import os,sys |
| import folder_paths |
|
|
| from PIL import Image |
| import importlib.util |
| import numpy as np |
|
|
| import torch |
|
|
# Tracks whether simple_lama_inpainting could be imported; flipped to True
# below once the package is confirmed present (or freshly installed).
# NOTE: `global` is a no-op at module scope, so the original `global
# _available` statement is dropped.
_available = False
|
|
def is_installed(package):
    """Return True when *package* can be located by the import machinery.

    Uses importlib.util.find_spec so the package is probed without actually
    being imported (no side effects from module top-level code).
    """
    try:
        found = importlib.util.find_spec(package)
    except ModuleNotFoundError:
        # find_spec raises when a parent package in a dotted name is missing.
        return False
    return found is not None
|
|
# Auto-install simple_lama_inpainting when it is missing.  The pip call is
# gated on torch >= 2.1, which the package requires.
if not is_installed('simple_lama_inpainting'):
    import subprocess
    from packaging import version

    if version.parse(torch.__version__) >= version.parse('2.1'):
        print('#pip install simple_lama_inpainting')

        # `-s` keeps user site-packages off sys.path so pip resolves against
        # the current (embedded) environment only.
        result = subprocess.run(
            [sys.executable, '-s', '-m', 'pip', 'install', 'simple_lama_inpainting'],
            capture_output=True,
            text=True,
        )

        if result.returncode == 0:
            print("#install success")
            from simple_lama_inpainting import SimpleLama
            _available = True
        else:
            print("#install error")
    else:
        print('#pls check your torch version >= 2.1')
else:
    from simple_lama_inpainting import SimpleLama
    _available = True
|
|
|
|
def get_lama_path():
    """Return the directory that holds LaMa model weights.

    Prefers a 'lama' folder registered with folder_paths; falls back to
    <models_dir>/lama when none is configured (or the lookup fails).
    """
    try:
        return folder_paths.get_folder_paths('lama')[0]
    except Exception:
        # Narrowed from a bare `except:` so KeyboardInterrupt/SystemExit
        # are no longer swallowed; covers missing key and empty list alike.
        return os.path.join(folder_paths.models_dir, "lama")
|
|
# Point SimpleLama at the local torchscript weights through the LAMA_MODEL
# environment variable; an empty value tells it the model is unavailable.
llma_model_path = os.path.join(get_lama_path(), "big-lama.pt")
if os.path.exists(llma_model_path):
    os.environ['LAMA_MODEL'] = llma_model_path
else:
    os.environ['LAMA_MODEL'] = ''
    print(f"## lama torchscript model not found: {llma_model_path},pls download from https://github.com/enesmsahin/simple-lama-inpainting/releases/download/v0.1.0/big-lama.pt")
|
|
| |
def tensor2pil(image):
    """Convert a torch image tensor with values in [0, 1] to a PIL Image.

    Singleton dimensions (e.g. a leading batch axis) are squeezed away
    before conversion.
    """
    arr = image.cpu().numpy().squeeze()
    arr = np.clip(arr * 255., 0, 255).astype(np.uint8)
    return Image.fromarray(arr)
|
|
| |
def pil2tensor(image):
    """Convert a PIL Image (or any array-like) to a float32 tensor.

    Pixel values are scaled from [0, 255] into [0, 1] and a leading batch
    dimension is added.
    """
    arr = np.array(image).astype(np.float32) / 255.0
    return torch.from_numpy(arr).unsqueeze(0)
|
|
|
|
| |
|
|
| |
| |
|
|
| |
| |
|
|
| |
| |
|
|
|
|
class LaMaInpainting:
    """ComfyUI node: fill masked regions of images with the LaMa model."""
    global _available
    # Exposed so the node registry can skip this node when the backing
    # package failed to import/install.
    available = _available

    @classmethod
    def INPUT_TYPES(s):
        return {"required": {
            "image": ("IMAGE",),
            "mask": ("MASK",),
        },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "run"
    CATEGORY = "♾️Mixlab/Image"

    INPUT_IS_LIST = True
    OUTPUT_IS_LIST = (True,)

    # Module-level cache of the loaded model so repeated runs reuse it
    # instead of re-loading the weights every execution.
    global simple_lama
    simple_lama = None

    def run(self, image, mask):
        """Inpaint each image with its paired mask; returns a list of tensors.

        `image` and `mask` are parallel lists (INPUT_IS_LIST is True).
        """
        global simple_lama

        result = []
        if simple_lama is None:  # was `== None`; identity test is the idiom
            simple_lama = SimpleLama()
        else:
            # The cached model was parked on the CPU at the end of the last
            # run; move it back to the fastest available device.
            simple_lama.model.to("cuda" if torch.cuda.is_available() else "cpu")

        for im_t, ma_t in zip(image, mask):
            im = tensor2pil(im_t)
            # LaMa expects a single-channel (grayscale) mask.
            ma = tensor2pil(ma_t).convert('L')
            res = simple_lama(im, ma)
            result.append(pil2tensor(res))

        # Free VRAM between runs; the model stays cached on the CPU.
        if simple_lama.device == 'cuda':
            simple_lama.model.to('cpu')

        return (result,)