saliacoel commited on
Commit
be302a4
·
verified ·
1 Parent(s): 7a69086

Upload Inspyrenet_Rembg2.py

Browse files
Files changed (1) hide show
  1. Inspyrenet_Rembg2.py +197 -0
Inspyrenet_Rembg2.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from PIL import Image
2
+ import os
3
+ import urllib.request
4
+
5
+ import torch
6
+ import numpy as np
7
+ from transparent_background import Remover
8
+ from tqdm import tqdm
9
+
10
+
11
+ CKPT_PATH = "/root/.transparent-background/ckpt_base.pth"
12
+ CKPT_URL = "https://huggingface.co/saliacoel/x/resolve/main/ckpt_base.pth"
13
+
14
+
15
+ def _ensure_ckpt_base():
16
+ """
17
+ 1) Check /root/.transparent-background/ckpt_base.pth
18
+ - if exists: do nothing
19
+ - else: download from CKPT_URL
20
+ """
21
+ try:
22
+ if os.path.isfile(CKPT_PATH) and os.path.getsize(CKPT_PATH) > 0:
23
+ return
24
+ except Exception:
25
+ # If getsize fails for any reason, fall through to download attempt.
26
+ pass
27
+
28
+ os.makedirs(os.path.dirname(CKPT_PATH), exist_ok=True)
29
+ tmp_path = CKPT_PATH + ".tmp"
30
+
31
+ try:
32
+ with urllib.request.urlopen(CKPT_URL) as resp:
33
+ total = resp.headers.get("Content-Length")
34
+ total = int(total) if total is not None else None
35
+
36
+ with open(tmp_path, "wb") as f:
37
+ if total:
38
+ with tqdm(
39
+ total=total,
40
+ unit="B",
41
+ unit_scale=True,
42
+ desc="Downloading ckpt_base.pth",
43
+ ) as pbar:
44
+ while True:
45
+ chunk = resp.read(1024 * 1024)
46
+ if not chunk:
47
+ break
48
+ f.write(chunk)
49
+ pbar.update(len(chunk))
50
+ else:
51
+ while True:
52
+ chunk = resp.read(1024 * 1024)
53
+ if not chunk:
54
+ break
55
+ f.write(chunk)
56
+
57
+ os.replace(tmp_path, CKPT_PATH)
58
+
59
+ finally:
60
+ # Clean up partial download if something went wrong
61
+ if os.path.isfile(tmp_path):
62
+ try:
63
+ os.remove(tmp_path)
64
+ except Exception:
65
+ pass
66
+
67
+
68
+ # Tensor to PIL
69
+ def tensor2pil(image: torch.Tensor) -> Image.Image:
70
+ arr = image.detach().cpu().numpy()
71
+ # Handle accidental singleton batch dim
72
+ if arr.ndim == 4 and arr.shape[0] == 1:
73
+ arr = arr[0]
74
+ arr = np.clip(255.0 * arr, 0, 255).astype(np.uint8)
75
+ return Image.fromarray(arr)
76
+
77
+
78
+ # Convert PIL to Tensor
79
+ def pil2tensor(image: Image.Image) -> torch.Tensor:
80
+ return torch.from_numpy(np.array(image).astype(np.float32) / 255.0).unsqueeze(0)
81
+
82
+
83
+ def _rgba_to_rgb_on_white(pil_img: Image.Image) -> Image.Image:
84
+ """
85
+ 5) If input is RGBA:
86
+ - alpha composite over WHITE background
87
+ - convert to RGB (drop alpha)
88
+ If input is RGB:
89
+ - carry on
90
+ """
91
+ if pil_img.mode == "RGBA":
92
+ bg = Image.new("RGBA", pil_img.size, (255, 255, 255, 255))
93
+ composited = Image.alpha_composite(bg, pil_img)
94
+ return composited.convert("RGB")
95
+
96
+ if pil_img.mode != "RGB":
97
+ return pil_img.convert("RGB")
98
+
99
+ return pil_img
100
+
101
+
102
+ class InspyrenetRembg2:
103
+ """
104
+ Original node kept (unchanged behavior/output), except it now ensures ckpt exists.
105
+ """
106
+ def __init__(self):
107
+ pass
108
+
109
+ @classmethod
110
+ def INPUT_TYPES(s):
111
+ return {
112
+ "required": {
113
+ "image": ("IMAGE",),
114
+ "torchscript_jit": (["default", "on"],)
115
+ },
116
+ }
117
+
118
+ RETURN_TYPES = ("IMAGE", "MASK")
119
+ FUNCTION = "remove_background"
120
+ CATEGORY = "image"
121
+
122
+ def remove_background(self, image, torchscript_jit):
123
+ _ensure_ckpt_base()
124
+
125
+ if (torchscript_jit == "default"):
126
+ remover = Remover()
127
+ else:
128
+ remover = Remover(jit=True)
129
+
130
+ img_list = []
131
+ for img in tqdm(image, "Inspyrenet Rembg"):
132
+ mid = remover.process(tensor2pil(img), type='rgba')
133
+ out = pil2tensor(mid)
134
+ img_list.append(out)
135
+
136
+ img_stack = torch.cat(img_list, dim=0)
137
+ mask = img_stack[:, :, :, 3]
138
+ return (img_stack, mask)
139
+
140
+
141
+ class InspyrenetRembg3:
142
+ """
143
+ New node per requested changes:
144
+ - ensures ckpt_base.pth exists (downloads if missing)
145
+ - torchscript_jit hardcoded to "default" (no input, no JIT)
146
+ - NO MASK output (IMAGE only)
147
+ - if input is RGBA: composite over white, convert to RGB, then run remover
148
+ - output remains RGBA (type='rgba')
149
+ """
150
+ def __init__(self):
151
+ pass
152
+
153
+ @classmethod
154
+ def INPUT_TYPES(s):
155
+ return {
156
+ "required": {
157
+ "image": ("IMAGE",),
158
+ },
159
+ }
160
+
161
+ RETURN_TYPES = ("IMAGE",)
162
+ FUNCTION = "remove_background"
163
+ CATEGORY = "image"
164
+
165
+ def remove_background(self, image):
166
+ _ensure_ckpt_base()
167
+
168
+ # 3) hardcode torchscript_jit == "default"
169
+ remover = Remover()
170
+
171
+ img_list = []
172
+ for img in tqdm(image, "Inspyrenet Rembg3"):
173
+ pil_in = tensor2pil(img)
174
+
175
+ # 5) normalize input to RGB for the model:
176
+ # - if RGBA -> alpha composite on white -> RGB
177
+ # - if RGB -> keep
178
+ pil_rgb = _rgba_to_rgb_on_white(pil_in)
179
+
180
+ # do functionality as usual, output RGBA
181
+ mid = remover.process(pil_rgb, type="rgba")
182
+ out = pil2tensor(mid)
183
+ img_list.append(out)
184
+
185
+ img_stack = torch.cat(img_list, dim=0)
186
+ return (img_stack,)
187
+
188
+
189
+ NODE_CLASS_MAPPINGS = {
190
+ "InspyrenetRembg2": InspyrenetRembg2,
191
+ "InspyrenetRembg3": InspyrenetRembg3,
192
+ }
193
+
194
+ NODE_DISPLAY_NAME_MAPPINGS = {
195
+ "InspyrenetRembg2": "Inspyrenet Rembg2",
196
+ "InspyrenetRembg3": "Inspyrenet Rembg3",
197
+ }