File size: 3,202 Bytes
ca2a3d8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from PIL import Image
import numpy as np
from transformers import CLIPSegProcessor, CLIPSegForImageSegmentation
import folder_paths
import os, requests

def get_path():
    """Return the base directory under which CLIPSeg model folders live.

    When the host registered a "clipseg" entry in
    ``folder_paths.folder_names_and_paths``, element ``[0][0]`` of that entry
    is used (presumably the first directory of its path list). Otherwise this
    falls back to a ``models`` directory beside this source file.
    """
    registry = folder_paths.folder_names_and_paths
    if "clipseg" not in registry:
        # Jank backup path if you're not running properly in Swarm
        this_dir = os.path.dirname(os.path.realpath(__file__))
        return this_dir + "/models"
    return registry["clipseg"][0][0]


# Manual download of the model from a safetensors conversion.
# Done manually to guarantee it's only a safetensors file ever and not a pickle
def download_model(path, urlbase):
    """Fetch any missing CLIPSeg model files from ``urlbase`` into ``path``.

    Only the fixed list of files below is ever requested, so the weights always
    come from ``model.safetensors`` and never from a pickle.

    Every file is checked individually on every call, so a previously
    interrupted download is completed next time instead of leaving a broken
    folder. Each file is streamed into a temporary ``.part`` file and only
    renamed to its final name after the whole response was received with a
    successful HTTP status; a failed download therefore never leaves a
    truncated or error-page file that would later look "present".

    Args:
        path: Target directory; created if missing. A trailing path
            separator is optional.
        urlbase: URL prefix; each file name is appended directly, so it
            should end with ``/``.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException: on connection failures or timeouts.
    """
    os.makedirs(path, exist_ok=True)
    for file in ("config.json", "merges.txt", "model.safetensors", "preprocessor_config.json", "special_tokens_map.json", "tokenizer_config.json", "vocab.json"):
        filepath = os.path.join(path, file)
        if os.path.exists(filepath):
            continue
        tmp_path = filepath + ".part"
        print(f"[SwarmClipSeg] Downloading '{file}'...")
        try:
            # Stream to disk so the (large) weights file is never held fully in RAM.
            with requests.get(f"{urlbase}{file}", stream=True, timeout=60) as response:
                response.raise_for_status()
                with open(tmp_path, "wb") as f:
                    for chunk in response.iter_content(chunk_size=1024 * 1024):
                        f.write(chunk)
            # Atomic rename: the final name only appears once the file is complete.
            os.replace(tmp_path, filepath)
        finally:
            # After a successful replace this is a no-op; on failure it drops the partial file.
            if os.path.exists(tmp_path):
                os.remove(tmp_path)


class SwarmClipSeg:
    """ComfyUI node that masks the parts of each image matching a text prompt, using CLIPSeg."""

    @classmethod
    def INPUT_TYPES(s):
        """Declare the node inputs: an image batch, the text to look for, and a mask threshold."""
        return {
            "required": {
                "images": ("IMAGE",),
                "match_text": ("STRING", {"multiline": True, "tooltip": "A short description (a few words) to describe something within the image to find and mask."}),
                "threshold": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 1.0, "step":0.01, "round": False, "tooltip": "Threshold to apply to the mask, higher values will make the mask more strict. Without sufficient thresholding, CLIPSeg may include random stray content around the edges."}),
            }
        }

    CATEGORY = "SwarmUI/masks"
    RETURN_TYPES = ("MASK",)
    FUNCTION = "seg"
    DESCRIPTION = "Segment an image using CLIPSeg, creating a mask of what part of an image appears to match the given text."

    @staticmethod
    def _to_pil(image):
        """Convert one image tensor (values presumably in [0, 1]) to an 8-bit PIL image."""
        pixels = 255.0 * image.cpu().numpy()
        return Image.fromarray(np.clip(pixels, 0, 255).astype(np.uint8))

    @staticmethod
    def _postprocess_mask(logits, threshold, height, width):
        """Turn CLIPSeg logits for a single image into a ``(1, height, width)`` mask.

        Applies a sigmoid, zeroes every value ``<= threshold``, shifts the
        minimum to 0, rescales so the peak is 1 (unless the mask is all zero),
        then bilinearly resizes to ``(height, width)``.

        Args:
            logits: 2-D ``(H, W)`` or 3-D ``(1, H, W)`` logits for one image.
            threshold: Cut-off applied after the sigmoid.
            height: Output mask height.
            width: Output mask width.
        """
        mask = torch.nn.functional.threshold(logits.sigmoid(), threshold, 0)
        mask = mask - mask.min()
        peak = mask.max()
        if peak > 0:
            mask = mask / peak
        # interpolate() wants (N, C, H, W); pad leading dims as needed.
        while mask.ndim < 4:
            mask = mask.unsqueeze(0)
        return torch.nn.functional.interpolate(mask, size=(height, width), mode="bilinear").squeeze(0)

    def seg(self, images, match_text, threshold):
        """Segment every image in ``images`` against ``match_text``.

        Args:
            images: Image batch; dims 1 and 2 are used as the output mask
                height/width (presumably ``(batch, height, width, channels)``
                floats in [0, 1] — the ComfyUI IMAGE convention; confirm).
            match_text: Short text describing what to find.
            threshold: Post-sigmoid cut-off; higher values give stricter masks.

        Returns:
            A 1-tuple holding a ``(batch, height, width)`` mask tensor, one
            mask per input image.
        """
        height, width = images.shape[1], images.shape[2]
        if images.shape[0] == 0:
            # Nothing to segment: skip model download/load entirely.
            return (torch.zeros((0, height, width)),)
        # TODO: Cache the model in RAM in some way?
        path = get_path() + "/clipseg-rd64-refined-fp16-safetensors/"
        download_model(path, "https://huggingface.co/mcmonkey/clipseg-rd64-refined-fp16/resolve/main/")
        processor = CLIPSegProcessor.from_pretrained(path)
        model = CLIPSegForImageSegmentation.from_pretrained(path)
        masks = []
        with torch.no_grad():
            # One image per forward pass: each image is paired with the same single
            # text prompt, and the model is loaded only once for the whole batch.
            for image in images:
                inputs = processor(text=match_text, images=self._to_pil(image), return_tensors="pt", padding=True)
                logits = model(**inputs)[0]
                masks.append(self._postprocess_mask(logits, threshold, height, width))
        return (torch.cat(masks, dim=0),)

# Registry consumed by the node host: node type name -> implementing class.
NODE_CLASS_MAPPINGS = dict(SwarmClipSeg=SwarmClipSeg)