lukeafullard committed on
Commit
24cbb2c
·
verified ·
1 Parent(s): cc358c5

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +45 -22
src/streamlit_app.py CHANGED
@@ -32,7 +32,6 @@ def load_birefnet_model():
32
  @st.cache_resource
33
  def load_vitmatte_model():
34
  """Option 3: The Refiner (Matting)"""
35
- # VitMatte requires a rough mask first (we use RMBG for that)
36
  processor = AutoImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
37
  model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")
38
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
@@ -51,6 +50,12 @@ def load_upscaler(scale=2):
51
 
52
  # --- 2. HELPER FUNCTIONS ---
53
 
 
 
 
 
 
 
54
  def find_mask_tensor(output):
55
  """Recursively finds the mask tensor in complex model outputs."""
56
  if isinstance(output, torch.Tensor):
@@ -65,19 +70,13 @@ def find_mask_tensor(output):
65
  return None
66
 
67
  def generate_trimap(mask_tensor, erode_kernel_size=10, dilate_kernel_size=10):
68
- """
69
- Generates a trimap (Foreground, Background, Unknown) from a binary mask.
70
- Values: 1=FG, 0=BG, 0.5=Unknown (Edge)
71
- """
72
  if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0)
73
 
74
  erode_k = erode_kernel_size
75
  dilate_k = dilate_kernel_size
76
 
77
- # Dilation (Max Pooling)
78
  dilated = F.max_pool2d(mask_tensor, kernel_size=dilate_k, stride=1, padding=dilate_k//2)
79
-
80
- # Erosion (Negative Max Pooling)
81
  eroded = -F.max_pool2d(-mask_tensor, kernel_size=erode_k, stride=1, padding=erode_k//2)
82
 
83
  trimap = torch.full_like(mask_tensor, 0.5)
@@ -116,28 +115,42 @@ def inference_segmentation(model, image, device, resolution=1024):
116
 
117
  def inference_vitmatte(image, device):
118
  """
119
- Runs pipeline: RMBG (Rough Mask) -> Trimap -> VitMatte (Refined Mask)
 
120
  """
121
- # 1. Get Rough Mask using RMBG (Fast)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
122
  rmbg_model, _ = load_rmbg_model()
123
- rough_mask_pil = inference_segmentation(rmbg_model, image, device, resolution=1024)
124
 
125
- # 2. Create Trimap (Tensor)
126
  mask_tensor = transforms.ToTensor()(rough_mask_pil).to(device)
127
  trimap_tensor = generate_trimap(mask_tensor, erode_kernel_size=25, dilate_kernel_size=25)
128
 
129
- # --- FIX START ---
130
- # 3. Convert Trimap Tensor to PIL Image
131
- # VitMatte Processor crashes on raw tensors. It wants a PIL Image.
132
- # We take the tensor (0.0 to 1.0), move to CPU, and convert to PIL (0 to 255)
133
  trimap_pil = transforms.ToPILImage()(trimap_tensor.squeeze().cpu())
134
 
135
  # 4. VitMatte Inference
136
  processor, model, _ = load_vitmatte_model()
137
 
138
- # Pass PIL images for both
139
- inputs = processor(images=image, trimaps=trimap_pil, return_tensors="pt").to(device)
140
- # --- FIX END ---
141
 
142
  with torch.no_grad():
143
  outputs = model(**inputs)
@@ -145,12 +158,18 @@ def inference_vitmatte(image, device):
145
  alphas = outputs.alphas
146
  alpha_np = alphas.squeeze().cpu().numpy()
147
  alpha_pil = Image.fromarray((alpha_np * 255).astype("uint8"), mode="L")
148
- alpha_pil = alpha_pil.resize(image.size, resample=Image.LANCZOS)
149
 
 
 
 
 
 
 
150
  return alpha_pil
151
 
152
  @st.cache_data(show_spinner=False)
153
  def process_background_removal(image_bytes, method="RMBG-1.4"):
 
154
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
155
 
156
  if method == "RMBG-1.4":
@@ -159,6 +178,7 @@ def process_background_removal(image_bytes, method="RMBG-1.4"):
159
 
160
  elif method == "BiRefNet (Heavy)":
161
  model, device = load_birefnet_model()
 
162
  mask = inference_segmentation(model, image, device, resolution=1024)
163
 
164
  elif method == "VitMatte (Refiner)":
@@ -192,6 +212,7 @@ def upscale_chunk_logic(image, processor, model):
192
  return run_swin_inference(image, processor, model)
193
 
194
  def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
 
195
  processor, model = load_upscaler(scale_factor)
196
  w, h = image.size
197
  rows = cols = grid_n
@@ -226,7 +247,7 @@ def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
226
  paste_y = target_upper * scale_factor
227
  full_image.paste(clean_tile, (paste_x, paste_y))
228
  del tile, upscaled_tile, clean_tile
229
- gc.collect()
230
  count += 1
231
  progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles}...")
232
  return full_image
@@ -283,7 +304,7 @@ def main():
283
  # 2. Upscaling
284
  if upscale_mode != "None":
285
  scale = 4 if "4x" in upscale_mode else 2
286
- cache_key = f"{uploaded_file.name}_{bg_model}_{scale}_{grid_n}_v6"
287
 
288
  if "upscale_cache" not in st.session_state:
289
  st.session_state.upscale_cache = {}
@@ -306,10 +327,12 @@ def main():
306
  col1, col2 = st.columns(2)
307
  with col1:
308
  st.subheader("Original")
 
309
  st.image(Image.open(io.BytesIO(file_bytes)), use_container_width=True)
310
 
311
  with col2:
312
  st.subheader("Result")
 
313
  st.image(final_image, use_container_width=True)
314
 
315
  st.markdown("---")
 
32
  @st.cache_resource
33
  def load_vitmatte_model():
34
  """Option 3: The Refiner (Matting)"""
 
35
  processor = AutoImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
36
  model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")
37
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
50
 
51
  # --- 2. HELPER FUNCTIONS ---
52
 
53
def cleanup_memory():
    """Forcibly clears memory: Python garbage plus the CUDA cache when a GPU exists."""
    # Run the cyclic garbage collector first so unreferenced tensors are reclaimed.
    gc.collect()
    # Only touch the CUDA allocator when a device is actually available.
    if not torch.cuda.is_available():
        return
    torch.cuda.empty_cache()
59
  def find_mask_tensor(output):
60
  """Recursively finds the mask tensor in complex model outputs."""
61
  if isinstance(output, torch.Tensor):
 
70
  return None
71
 
72
  def generate_trimap(mask_tensor, erode_kernel_size=10, dilate_kernel_size=10):
73
+ """Generates a trimap (Foreground, Background, Unknown) from a binary mask."""
 
 
 
74
  if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0)
75
 
76
  erode_k = erode_kernel_size
77
  dilate_k = dilate_kernel_size
78
 
 
79
  dilated = F.max_pool2d(mask_tensor, kernel_size=dilate_k, stride=1, padding=dilate_k//2)
 
 
80
  eroded = -F.max_pool2d(-mask_tensor, kernel_size=erode_k, stride=1, padding=erode_k//2)
81
 
82
  trimap = torch.full_like(mask_tensor, 0.5)
 
115
 
116
def inference_vitmatte(image, device):
    """
    Runs pipeline: RMBG (Rough Mask) -> Trimap -> VitMatte (Refined Mask).
    Includes memory safety downscaling.

    Args:
        image: PIL RGB image to matte.
        device: torch.device the models run on.

    Returns:
        PIL "L" (grayscale) alpha matte, resized to match ``image.size``.
    """
    cleanup_memory()  # Clear RAM before starting

    original_size = image.size

    # --- MEMORY SAFETY CHECK ---
    # If image is too large, downscale it for VitMatte processing.
    # 1536px is a sweet spot: good detail, safe RAM usage (~4-6GB peak)
    max_dim = 1536
    if max(image.size) > max_dim:
        scale_ratio = max_dim / max(image.size)
        new_w = int(image.size[0] * scale_ratio)
        new_h = int(image.size[1] * scale_ratio)
        # Create a smaller copy for processing; the original is left untouched.
        processing_image = image.resize((new_w, new_h), Image.LANCZOS)
    else:
        processing_image = image

    # 1. Get Rough Mask using RMBG
    rmbg_model, _ = load_rmbg_model()
    rough_mask_pil = inference_segmentation(rmbg_model, processing_image, device, resolution=1024)

    # 2. Create Trimap
    mask_tensor = transforms.ToTensor()(rough_mask_pil).to(device)
    trimap_tensor = generate_trimap(mask_tensor, erode_kernel_size=25, dilate_kernel_size=25)

    # 3. Convert Trimap to PIL (Required for Processor)
    trimap_pil = transforms.ToPILImage()(trimap_tensor.squeeze().cpu())

    # 4. VitMatte Inference
    processor, model, _ = load_vitmatte_model()

    # Pass PIL images (the processor rejects raw tensors)
    inputs = processor(images=processing_image, trimaps=trimap_pil, return_tensors="pt").to(device)

    with torch.no_grad():
        outputs = model(**inputs)

    alphas = outputs.alphas
    alpha_np = alphas.squeeze().cpu().numpy()
    alpha_pil = Image.fromarray((alpha_np * 255).astype("uint8"), mode="L")

    # 5. Restore Resolution
    # FIX: compare the alpha's own size, not just whether we downscaled.
    # The VitMatte processor pads inputs for divisibility, so the predicted
    # alpha can differ from `original_size` even when no downscaling happened;
    # guarding on `alpha_pil.size` covers both causes of mismatch.
    if alpha_pil.size != original_size:
        alpha_pil = alpha_pil.resize(original_size, resample=Image.LANCZOS)

    cleanup_memory()  # Cleanup after finish
    return alpha_pil
169
 
170
  @st.cache_data(show_spinner=False)
171
  def process_background_removal(image_bytes, method="RMBG-1.4"):
172
+ cleanup_memory() # Ensure clean state
173
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
174
 
175
  if method == "RMBG-1.4":
 
178
 
179
  elif method == "BiRefNet (Heavy)":
180
  model, device = load_birefnet_model()
181
+ # BiRefNet handles 1024 internally well, generally safe on memory
182
  mask = inference_segmentation(model, image, device, resolution=1024)
183
 
184
  elif method == "VitMatte (Refiner)":
 
212
  return run_swin_inference(image, processor, model)
213
 
214
  def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
215
+ cleanup_memory()
216
  processor, model = load_upscaler(scale_factor)
217
  w, h = image.size
218
  rows = cols = grid_n
 
247
  paste_y = target_upper * scale_factor
248
  full_image.paste(clean_tile, (paste_x, paste_y))
249
  del tile, upscaled_tile, clean_tile
250
+ cleanup_memory()
251
  count += 1
252
  progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles}...")
253
  return full_image
 
304
  # 2. Upscaling
305
  if upscale_mode != "None":
306
  scale = 4 if "4x" in upscale_mode else 2
307
+ cache_key = f"{uploaded_file.name}_{bg_model}_{scale}_{grid_n}_v7"
308
 
309
  if "upscale_cache" not in st.session_state:
310
  st.session_state.upscale_cache = {}
 
327
  col1, col2 = st.columns(2)
328
  with col1:
329
  st.subheader("Original")
330
+ # Fixed deprecation warning for use_container_width
331
  st.image(Image.open(io.BytesIO(file_bytes)), use_container_width=True)
332
 
333
  with col2:
334
  st.subheader("Result")
335
+ # Fixed deprecation warning for use_container_width
336
  st.image(final_image, use_container_width=True)
337
 
338
  st.markdown("---")