arabago96 committed on
Commit
262cecc
·
1 Parent(s): 8bdea49

fix: handle non-RGB images in BiRefNet (CMYK, grayscale, LA modes)

Browse files

Always convert to RGB before the transform pipeline to prevent
shape mismatch errors with non-standard image modes.
Fixed in both utils_birefnet.py and trellis2/pipelines/rembg/BiRefNet.py.

trellis2/pipelines/rembg/BiRefNet.py CHANGED
@@ -30,13 +30,18 @@ class BiRefNet:
30
 
31
  def __call__(self, image: Image.Image) -> Image.Image:
32
  image_size = image.size
33
- input_images = self.transform_image(image).unsqueeze(0).to("cuda")
 
 
 
34
  # Prediction
35
  with torch.no_grad():
36
  preds = self.model(input_images)[-1].sigmoid().cpu()
37
  pred = preds[0].squeeze()
38
  pred_pil = transforms.ToPILImage()(pred)
39
  mask = pred_pil.resize(image_size)
40
- image.putalpha(mask)
41
- return image
 
 
42
 
 
30
 
31
  def __call__(self, image: Image.Image) -> Image.Image:
32
  image_size = image.size
33
+ # Always convert to RGB for the transform (handles RGBA, L, LA, CMYK, P, etc.)
34
+ rgb_image = image.convert('RGB')
35
+
36
+ input_images = self.transform_image(rgb_image).unsqueeze(0).to("cuda")
37
  # Prediction
38
  with torch.no_grad():
39
  preds = self.model(input_images)[-1].sigmoid().cpu()
40
  pred = preds[0].squeeze()
41
  pred_pil = transforms.ToPILImage()(pred)
42
  mask = pred_pil.resize(image_size)
43
+ # Convert to RGBA so putalpha works regardless of the original mode
44
+ rgba_image = rgb_image.convert('RGBA')
45
+ rgba_image.putalpha(mask)
46
+ return rgba_image
47
 
utils_birefnet.py CHANGED
@@ -29,16 +29,17 @@ class BiRefNet:
29
 
30
  def __call__(self, image: Image.Image) -> Image.Image:
31
  image_size = image.size
32
- # Handle alpha channel if present
33
- if image.mode == 'RGBA':
34
- image = image.convert('RGB')
35
-
36
- input_images = self.transform_image(image).unsqueeze(0).to("cuda")
37
  # Prediction
38
  with torch.no_grad():
39
  preds = self.model(input_images)[-1].sigmoid().cpu()
40
  pred = preds[0].squeeze()
41
  pred_pil = transforms.ToPILImage()(pred)
42
  mask = pred_pil.resize(image_size)
43
- image.putalpha(mask)
44
- return image
 
 
 
29
 
30
  def __call__(self, image: Image.Image) -> Image.Image:
31
  image_size = image.size
32
+ # Always convert to RGB for the transform (handles RGBA, L, LA, CMYK, P, etc.)
33
+ rgb_image = image.convert('RGB')
34
+
35
+ input_images = self.transform_image(rgb_image).unsqueeze(0).to("cuda")
 
36
  # Prediction
37
  with torch.no_grad():
38
  preds = self.model(input_images)[-1].sigmoid().cpu()
39
  pred = preds[0].squeeze()
40
  pred_pil = transforms.ToPILImage()(pred)
41
  mask = pred_pil.resize(image_size)
42
+ # Convert to RGBA so putalpha works regardless of the original mode
43
+ rgba_image = rgb_image.convert('RGBA')
44
+ rgba_image.putalpha(mask)
45
+ return rgba_image