Spaces:
Sleeping
Sleeping
fix: handle non-RGB images in BiRefNet (CMYK, grayscale, LA modes)
Always convert to RGB before the transform pipeline to prevent
shape mismatch errors with non-standard image modes.
Fixed in both utils_birefnet.py and trellis2/pipelines/rembg/BiRefNet.py.
trellis2/pipelines/rembg/BiRefNet.py — CHANGED
|
@@ -30,13 +30,18 @@ class BiRefNet:
|
|
| 30 |
|
| 31 |
def __call__(self, image: Image.Image) -> Image.Image:
|
| 32 |
image_size = image.size
|
| 33 |
-
|
|
|
|
|
|
|
|
|
|
| 34 |
# Prediction
|
| 35 |
with torch.no_grad():
|
| 36 |
preds = self.model(input_images)[-1].sigmoid().cpu()
|
| 37 |
pred = preds[0].squeeze()
|
| 38 |
pred_pil = transforms.ToPILImage()(pred)
|
| 39 |
mask = pred_pil.resize(image_size)
|
| 40 |
-
|
| 41 |
-
|
|
|
|
|
|
|
| 42 |
|
|
|
|
| 30 |
|
| 31 |
def __call__(self, image: Image.Image) -> Image.Image:
|
| 32 |
image_size = image.size
|
| 33 |
+
# Always convert to RGB for the transform (handles RGBA, L, LA, CMYK, P, etc.)
|
| 34 |
+
rgb_image = image.convert('RGB')
|
| 35 |
+
|
| 36 |
+
input_images = self.transform_image(rgb_image).unsqueeze(0).to("cuda")
|
| 37 |
# Prediction
|
| 38 |
with torch.no_grad():
|
| 39 |
preds = self.model(input_images)[-1].sigmoid().cpu()
|
| 40 |
pred = preds[0].squeeze()
|
| 41 |
pred_pil = transforms.ToPILImage()(pred)
|
| 42 |
mask = pred_pil.resize(image_size)
|
| 43 |
+
# Convert to RGBA so putalpha works regardless of the original mode
|
| 44 |
+
rgba_image = rgb_image.convert('RGBA')
|
| 45 |
+
rgba_image.putalpha(mask)
|
| 46 |
+
return rgba_image
|
| 47 |
|
utils_birefnet.py — CHANGED
|
@@ -29,16 +29,17 @@ class BiRefNet:
|
|
| 29 |
|
| 30 |
def __call__(self, image: Image.Image) -> Image.Image:
|
| 31 |
image_size = image.size
|
| 32 |
-
#
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
input_images = self.transform_image(image).unsqueeze(0).to("cuda")
|
| 37 |
# Prediction
|
| 38 |
with torch.no_grad():
|
| 39 |
preds = self.model(input_images)[-1].sigmoid().cpu()
|
| 40 |
pred = preds[0].squeeze()
|
| 41 |
pred_pil = transforms.ToPILImage()(pred)
|
| 42 |
mask = pred_pil.resize(image_size)
|
| 43 |
-
|
| 44 |
-
|
|
|
|
|
|
|
|
|
| 29 |
|
| 30 |
def __call__(self, image: Image.Image) -> Image.Image:
|
| 31 |
image_size = image.size
|
| 32 |
+
# Always convert to RGB for the transform (handles RGBA, L, LA, CMYK, P, etc.)
|
| 33 |
+
rgb_image = image.convert('RGB')
|
| 34 |
+
|
| 35 |
+
input_images = self.transform_image(rgb_image).unsqueeze(0).to("cuda")
|
|
|
|
| 36 |
# Prediction
|
| 37 |
with torch.no_grad():
|
| 38 |
preds = self.model(input_images)[-1].sigmoid().cpu()
|
| 39 |
pred = preds[0].squeeze()
|
| 40 |
pred_pil = transforms.ToPILImage()(pred)
|
| 41 |
mask = pred_pil.resize(image_size)
|
| 42 |
+
# Convert to RGBA so putalpha works regardless of the original mode
|
| 43 |
+
rgba_image = rgb_image.convert('RGBA')
|
| 44 |
+
rgba_image.putalpha(mask)
|
| 45 |
+
return rgba_image
|