lukeafullard committed on
Commit
cc358c5
·
verified ·
1 Parent(s): a3af8ea

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +18 -40
src/streamlit_app.py CHANGED
@@ -24,7 +24,6 @@ def load_rmbg_model():
24
  @st.cache_resource
25
  def load_birefnet_model():
26
  """Option 2: The Heavyweight Generalist"""
27
- # This requires 'timm' installed
28
  model = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
29
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
  model.to(device)
@@ -67,33 +66,21 @@ def find_mask_tensor(output):
67
 
68
  def generate_trimap(mask_tensor, erode_kernel_size=10, dilate_kernel_size=10):
69
  """
70
- Generates a trimap (Foreground, Background, Unknown) from a binary mask
71
- using Pure PyTorch (No OpenCV required).
72
  Values: 1=FG, 0=BG, 0.5=Unknown (Edge)
73
  """
74
- # Ensure mask is Bx1xHxW
75
  if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0)
76
 
77
- # Create kernels
78
  erode_k = erode_kernel_size
79
  dilate_k = dilate_kernel_size
80
 
81
- # Dilation (Max Pooling) - Expands the white area
82
- # We pad to keep size same
83
  dilated = F.max_pool2d(mask_tensor, kernel_size=dilate_k, stride=1, padding=dilate_k//2)
84
 
85
- # Erosion (Negative Max Pooling) - Shrinks the white area
86
  eroded = -F.max_pool2d(-mask_tensor, kernel_size=erode_k, stride=1, padding=erode_k//2)
87
 
88
- # Trimap construction
89
- # Pixels that are 1 in eroded are definitely FG (1.0)
90
- # Pixels that are 0 in dilated are definitely BG (0.0)
91
- # Everything else is the "Unknown" zone (0.5)
92
-
93
- # Start with Unknown (0.5)
94
  trimap = torch.full_like(mask_tensor, 0.5)
95
-
96
- # Set definites
97
  trimap[eroded > 0.5] = 1.0
98
  trimap[dilated < 0.5] = 0.0
99
 
@@ -120,11 +107,9 @@ def inference_segmentation(model, image, device, resolution=1024):
120
  if not isinstance(result_tensor, torch.Tensor):
121
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
122
 
123
- # Get binary-ish mask (logits or sigmoid)
124
  pred = result_tensor.squeeze().cpu()
125
  if pred.max() > 1 or pred.min() < 0: pred = pred.sigmoid()
126
 
127
- # Resize back to original
128
  pred_pil = transforms.ToPILImage()(pred)
129
  mask = pred_pil.resize((w, h), resample=Image.LANCZOS)
130
  return mask
@@ -134,36 +119,36 @@ def inference_vitmatte(image, device):
134
  Runs pipeline: RMBG (Rough Mask) -> Trimap -> VitMatte (Refined Mask)
135
  """
136
  # 1. Get Rough Mask using RMBG (Fast)
137
- rmbg_model, _ = load_rmbg_model() # Re-use loaded model
138
  rough_mask_pil = inference_segmentation(rmbg_model, image, device, resolution=1024)
139
 
140
- # 2. Create Trimap
141
- # Convert PIL mask to Tensor
142
  mask_tensor = transforms.ToTensor()(rough_mask_pil).to(device)
143
- # Generate trimap (1=FG, 0=BG, 0.5=Unknown)
144
  trimap_tensor = generate_trimap(mask_tensor, erode_kernel_size=25, dilate_kernel_size=25)
145
 
146
- # 3. VitMatte Inference
 
 
 
 
 
 
147
  processor, model, _ = load_vitmatte_model()
148
 
149
- # VitMatte expects inputs: pixel_values (image) and mask_labels (trimap)
150
- inputs = processor(images=image, trimaps=trimap_tensor, return_tensors="pt").to(device)
 
151
 
152
  with torch.no_grad():
153
  outputs = model(**inputs)
154
 
155
- # Output is the refined alphas
156
  alphas = outputs.alphas
157
-
158
- # 4. Post-process
159
- # Extract alpha, resize to original
160
  alpha_np = alphas.squeeze().cpu().numpy()
161
  alpha_pil = Image.fromarray((alpha_np * 255).astype("uint8"), mode="L")
162
  alpha_pil = alpha_pil.resize(image.size, resample=Image.LANCZOS)
163
 
164
  return alpha_pil
165
 
166
-
167
  @st.cache_data(show_spinner=False)
168
  def process_background_removal(image_bytes, method="RMBG-1.4"):
169
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
@@ -177,19 +162,16 @@ def process_background_removal(image_bytes, method="RMBG-1.4"):
177
  mask = inference_segmentation(model, image, device, resolution=1024)
178
 
179
  elif method == "VitMatte (Refiner)":
180
- # VitMatte needs GPU ideally, works on CPU but slow
181
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
182
  mask = inference_vitmatte(image, device)
183
 
184
  else:
185
- # Fallback
186
  return image
187
 
188
- # Apply mask
189
  image.putalpha(mask)
190
  return image
191
 
192
- # --- Upscaling Logic (Same as before) ---
193
  def run_swin_inference(image, processor, model):
194
  inputs = processor(image, return_tensors="pt")
195
  with torch.no_grad():
@@ -264,13 +246,12 @@ def main():
264
  st.sidebar.header("1. Background Removal")
265
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
266
 
267
- # NEW: Model Selector
268
  if remove_bg:
269
  bg_model = st.sidebar.selectbox(
270
  "Select AI Model",
271
  ["RMBG-1.4", "BiRefNet (Heavy)", "VitMatte (Refiner)"],
272
  index=0,
273
- help="RMBG: Fast, Standard Quality.\nBiRefNet: Slower, Better Edges.\nVitMatte: Slowest, Best for Hair/Transparency."
274
  )
275
  else:
276
  bg_model = "None"
@@ -294,7 +275,6 @@ def main():
294
 
295
  # 1. Background
296
  if remove_bg:
297
- # We add the model name to the spinner text so user knows what's happening
298
  with st.spinner(f"Removing background using {bg_model}..."):
299
  processed_image = process_background_removal(file_bytes, bg_model)
300
  else:
@@ -303,9 +283,7 @@ def main():
303
  # 2. Upscaling
304
  if upscale_mode != "None":
305
  scale = 4 if "4x" in upscale_mode else 2
306
-
307
- # Cache Key includes model name now
308
- cache_key = f"{uploaded_file.name}_{bg_model}_{scale}_{grid_n}_v5"
309
 
310
  if "upscale_cache" not in st.session_state:
311
  st.session_state.upscale_cache = {}
 
24
  @st.cache_resource
25
  def load_birefnet_model():
26
  """Option 2: The Heavyweight Generalist"""
 
27
  model = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
28
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
29
  model.to(device)
 
66
 
67
  def generate_trimap(mask_tensor, erode_kernel_size=10, dilate_kernel_size=10):
68
  """
69
+ Generates a trimap (Foreground, Background, Unknown) from a binary mask.
 
70
  Values: 1=FG, 0=BG, 0.5=Unknown (Edge)
71
  """
 
72
  if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0)
73
 
 
74
  erode_k = erode_kernel_size
75
  dilate_k = dilate_kernel_size
76
 
77
+ # Dilation (Max Pooling)
 
78
  dilated = F.max_pool2d(mask_tensor, kernel_size=dilate_k, stride=1, padding=dilate_k//2)
79
 
80
+ # Erosion (Negative Max Pooling)
81
  eroded = -F.max_pool2d(-mask_tensor, kernel_size=erode_k, stride=1, padding=erode_k//2)
82
 
 
 
 
 
 
 
83
  trimap = torch.full_like(mask_tensor, 0.5)
 
 
84
  trimap[eroded > 0.5] = 1.0
85
  trimap[dilated < 0.5] = 0.0
86
 
 
107
  if not isinstance(result_tensor, torch.Tensor):
108
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
109
 
 
110
  pred = result_tensor.squeeze().cpu()
111
  if pred.max() > 1 or pred.min() < 0: pred = pred.sigmoid()
112
 
 
113
  pred_pil = transforms.ToPILImage()(pred)
114
  mask = pred_pil.resize((w, h), resample=Image.LANCZOS)
115
  return mask
 
119
  Runs pipeline: RMBG (Rough Mask) -> Trimap -> VitMatte (Refined Mask)
120
  """
121
  # 1. Get Rough Mask using RMBG (Fast)
122
+ rmbg_model, _ = load_rmbg_model()
123
  rough_mask_pil = inference_segmentation(rmbg_model, image, device, resolution=1024)
124
 
125
+ # 2. Create Trimap (Tensor)
 
126
  mask_tensor = transforms.ToTensor()(rough_mask_pil).to(device)
 
127
  trimap_tensor = generate_trimap(mask_tensor, erode_kernel_size=25, dilate_kernel_size=25)
128
 
129
+ # --- FIX START ---
130
+ # 3. Convert Trimap Tensor to PIL Image
131
+ # VitMatte Processor crashes on raw tensors. It wants a PIL Image.
132
+ # We take the tensor (0.0 to 1.0), move to CPU, and convert to PIL (0 to 255)
133
+ trimap_pil = transforms.ToPILImage()(trimap_tensor.squeeze().cpu())
134
+
135
+ # 4. VitMatte Inference
136
  processor, model, _ = load_vitmatte_model()
137
 
138
+ # Pass PIL images for both
139
+ inputs = processor(images=image, trimaps=trimap_pil, return_tensors="pt").to(device)
140
+ # --- FIX END ---
141
 
142
  with torch.no_grad():
143
  outputs = model(**inputs)
144
 
 
145
  alphas = outputs.alphas
 
 
 
146
  alpha_np = alphas.squeeze().cpu().numpy()
147
  alpha_pil = Image.fromarray((alpha_np * 255).astype("uint8"), mode="L")
148
  alpha_pil = alpha_pil.resize(image.size, resample=Image.LANCZOS)
149
 
150
  return alpha_pil
151
 
 
152
  @st.cache_data(show_spinner=False)
153
  def process_background_removal(image_bytes, method="RMBG-1.4"):
154
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
162
  mask = inference_segmentation(model, image, device, resolution=1024)
163
 
164
  elif method == "VitMatte (Refiner)":
 
165
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
166
  mask = inference_vitmatte(image, device)
167
 
168
  else:
 
169
  return image
170
 
 
171
  image.putalpha(mask)
172
  return image
173
 
174
+ # --- Upscaling Logic ---
175
  def run_swin_inference(image, processor, model):
176
  inputs = processor(image, return_tensors="pt")
177
  with torch.no_grad():
 
246
  st.sidebar.header("1. Background Removal")
247
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
248
 
 
249
  if remove_bg:
250
  bg_model = st.sidebar.selectbox(
251
  "Select AI Model",
252
  ["RMBG-1.4", "BiRefNet (Heavy)", "VitMatte (Refiner)"],
253
  index=0,
254
+ help="RMBG: Fast.\nBiRefNet: Better.\nVitMatte: Best for hair/transparency."
255
  )
256
  else:
257
  bg_model = "None"
 
275
 
276
  # 1. Background
277
  if remove_bg:
 
278
  with st.spinner(f"Removing background using {bg_model}..."):
279
  processed_image = process_background_removal(file_bytes, bg_model)
280
  else:
 
283
  # 2. Upscaling
284
  if upscale_mode != "None":
285
  scale = 4 if "4x" in upscale_mode else 2
286
+ cache_key = f"{uploaded_file.name}_{bg_model}_{scale}_{grid_n}_v6"
 
 
287
 
288
  if "upscale_cache" not in st.session_state:
289
  st.session_state.upscale_cache = {}