ZhengPeng7 commited on
Commit
a7ab477
·
1 Parent(s): aa6bac1

Add the rgba2rgb preprocessing for RGBA inputs.

Browse files
Files changed (2) hide show
  1. app.py +41 -1
  2. app_local.py +41 -1
app.py CHANGED
@@ -29,6 +29,44 @@ torch.jit.script = lambda f: f
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32
  ## CPU version refinement
33
  def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
34
  if isinstance(image, Image.Image):
@@ -112,7 +150,7 @@ def refine_foreground(image, mask, r=90, device='cuda'):
112
  mask = mask.unsqueeze(0)
113
 
114
  estimated_foreground = FB_blur_fusion_foreground_estimator_gpu_2(image, mask, r=r)
115
-
116
  estimated_foreground = estimated_foreground.squeeze()
117
  estimated_foreground = (estimated_foreground.mul(255.0)).to(torch.uint8)
118
  estimated_foreground = estimated_foreground.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
@@ -215,6 +253,8 @@ def predict(images, resolution, weights_file):
215
  image_ori = Image.open(image_data)
216
  else:
217
  image_ori = Image.fromarray(image_src)
 
 
218
 
219
  image = image_ori.convert('RGB')
220
  # Preprocess the image
 
29
  device = "cuda" if torch.cuda.is_available() else "cpu"
30
 
31
 
32
+ def rgba2rgb(img):
33
+ """
34
+ Convert RGBA image to RGB with white background.
35
+ Supports both PIL.Image and numpy.ndarray.
36
+ """
37
+
38
+ # 1. Handle PIL Image
39
+ if isinstance(img, Image.Image):
40
+ img = img.convert("RGBA")
41
+ bg = Image.new("RGBA", img.size, (255, 255, 255))
42
+ return Image.alpha_composite(bg, img).convert("RGB")
43
+
44
+ # 2. Handle Numpy Array (OpenCV)
45
+ elif isinstance(img, np.ndarray):
46
+ # Grayscale to RGB
47
+ if img.ndim == 2:
48
+ return np.stack([img] * 3, axis=-1)
49
+
50
+ # Already 3 channels
51
+ if img.shape[2] == 3:
52
+ return img
53
+
54
+ # RGBA to RGB (blending with white)
55
+ elif img.shape[2] == 4:
56
+ # Normalize alpha to 0-1 and keep shape (H, W, 1)
57
+ alpha = img[..., 3:4].astype(float) / 255.0
58
+ foreground = img[..., :3].astype(float)
59
+ background = 255.0
60
+
61
+ # Blend formula: source * alpha + bg * (1 - alpha)
62
+ out = foreground * alpha + background * (1.0 - alpha)
63
+
64
+ return out.clip(0, 255).astype(np.uint8)
65
+
66
+ else:
67
+ raise TypeError(f"Unsupported type: {type(img)}")
68
+
69
+
70
  ## CPU version refinement
71
  def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
72
  if isinstance(image, Image.Image):
 
150
  mask = mask.unsqueeze(0)
151
 
152
  estimated_foreground = FB_blur_fusion_foreground_estimator_gpu_2(image, mask, r=r)
153
+
154
  estimated_foreground = estimated_foreground.squeeze()
155
  estimated_foreground = (estimated_foreground.mul(255.0)).to(torch.uint8)
156
  estimated_foreground = estimated_foreground.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
 
253
  image_ori = Image.open(image_data)
254
  else:
255
  image_ori = Image.fromarray(image_src)
256
+ if image_ori.mode == 'RGBA':
257
+ image_ori = rgba2rgb(image_ori)
258
 
259
  image = image_ori.convert('RGB')
260
  # Preprocess the image
app_local.py CHANGED
@@ -25,6 +25,44 @@ torch.set_float32_matmul_precision('high')
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  ## CPU version refinement
29
  def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
30
  if isinstance(image, Image.Image):
@@ -108,7 +146,7 @@ def refine_foreground(image, mask, r=90, device='cuda'):
108
  mask = mask.unsqueeze(0)
109
 
110
  estimated_foreground = FB_blur_fusion_foreground_estimator_gpu_2(image, mask, r=r)
111
-
112
  estimated_foreground = estimated_foreground.squeeze()
113
  estimated_foreground = (estimated_foreground.mul(255.0)).to(torch.uint8)
114
  estimated_foreground = estimated_foreground.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
@@ -210,6 +248,8 @@ def predict(images, resolution, weights_file):
210
  image_ori = Image.open(image_data)
211
  else:
212
  image_ori = Image.fromarray(image_src)
 
 
213
 
214
  image = image_ori.convert('RGB')
215
  # Preprocess the image
 
25
  device = "cuda" if torch.cuda.is_available() else "cpu"
26
 
27
 
28
+ def rgba2rgb(img):
29
+ """
30
+ Convert RGBA image to RGB with white background.
31
+ Supports both PIL.Image and numpy.ndarray.
32
+ """
33
+
34
+ # 1. Handle PIL Image
35
+ if isinstance(img, Image.Image):
36
+ img = img.convert("RGBA")
37
+ bg = Image.new("RGBA", img.size, (255, 255, 255))
38
+ return Image.alpha_composite(bg, img).convert("RGB")
39
+
40
+ # 2. Handle Numpy Array (OpenCV)
41
+ elif isinstance(img, np.ndarray):
42
+ # Grayscale to RGB
43
+ if img.ndim == 2:
44
+ return np.stack([img] * 3, axis=-1)
45
+
46
+ # Already 3 channels
47
+ if img.shape[2] == 3:
48
+ return img
49
+
50
+ # RGBA to RGB (blending with white)
51
+ elif img.shape[2] == 4:
52
+ # Normalize alpha to 0-1 and keep shape (H, W, 1)
53
+ alpha = img[..., 3:4].astype(float) / 255.0
54
+ foreground = img[..., :3].astype(float)
55
+ background = 255.0
56
+
57
+ # Blend formula: source * alpha + bg * (1 - alpha)
58
+ out = foreground * alpha + background * (1.0 - alpha)
59
+
60
+ return out.clip(0, 255).astype(np.uint8)
61
+
62
+ else:
63
+ raise TypeError(f"Unsupported type: {type(img)}")
64
+
65
+
66
  ## CPU version refinement
67
  def FB_blur_fusion_foreground_estimator_cpu(image, FG, B, alpha, r=90):
68
  if isinstance(image, Image.Image):
 
146
  mask = mask.unsqueeze(0)
147
 
148
  estimated_foreground = FB_blur_fusion_foreground_estimator_gpu_2(image, mask, r=r)
149
+
150
  estimated_foreground = estimated_foreground.squeeze()
151
  estimated_foreground = (estimated_foreground.mul(255.0)).to(torch.uint8)
152
  estimated_foreground = estimated_foreground.permute(1, 2, 0).contiguous().cpu().numpy().astype(np.uint8)
 
248
  image_ori = Image.open(image_data)
249
  else:
250
  image_ori = Image.fromarray(image_src)
251
+ if image_ori.mode == 'RGBA':
252
+ image_ori = rgba2rgb(image_ori)
253
 
254
  image = image_ori.convert('RGB')
255
  # Preprocess the image