Spaces:
Running
Running
Update src/streamlit_app.py
Browse files- src/streamlit_app.py +18 -40
src/streamlit_app.py
CHANGED
|
@@ -24,7 +24,6 @@ def load_rmbg_model():
|
|
| 24 |
@st.cache_resource
|
| 25 |
def load_birefnet_model():
|
| 26 |
"""Option 2: The Heavyweight Generalist"""
|
| 27 |
-
# This requires 'timm' installed
|
| 28 |
model = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
|
| 29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
model.to(device)
|
|
@@ -67,33 +66,21 @@ def find_mask_tensor(output):
|
|
| 67 |
|
| 68 |
def generate_trimap(mask_tensor, erode_kernel_size=10, dilate_kernel_size=10):
|
| 69 |
"""
|
| 70 |
-
Generates a trimap (Foreground, Background, Unknown) from a binary mask
|
| 71 |
-
using Pure PyTorch (No OpenCV required).
|
| 72 |
Values: 1=FG, 0=BG, 0.5=Unknown (Edge)
|
| 73 |
"""
|
| 74 |
-
# Ensure mask is Bx1xHxW
|
| 75 |
if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0)
|
| 76 |
|
| 77 |
-
# Create kernels
|
| 78 |
erode_k = erode_kernel_size
|
| 79 |
dilate_k = dilate_kernel_size
|
| 80 |
|
| 81 |
-
# Dilation (Max Pooling)
|
| 82 |
-
# We pad to keep size same
|
| 83 |
dilated = F.max_pool2d(mask_tensor, kernel_size=dilate_k, stride=1, padding=dilate_k//2)
|
| 84 |
|
| 85 |
-
# Erosion (Negative Max Pooling)
|
| 86 |
eroded = -F.max_pool2d(-mask_tensor, kernel_size=erode_k, stride=1, padding=erode_k//2)
|
| 87 |
|
| 88 |
-
# Trimap construction
|
| 89 |
-
# Pixels that are 1 in eroded are definitely FG (1.0)
|
| 90 |
-
# Pixels that are 0 in dilated are definitely BG (0.0)
|
| 91 |
-
# Everything else is the "Unknown" zone (0.5)
|
| 92 |
-
|
| 93 |
-
# Start with Unknown (0.5)
|
| 94 |
trimap = torch.full_like(mask_tensor, 0.5)
|
| 95 |
-
|
| 96 |
-
# Set definites
|
| 97 |
trimap[eroded > 0.5] = 1.0
|
| 98 |
trimap[dilated < 0.5] = 0.0
|
| 99 |
|
|
@@ -120,11 +107,9 @@ def inference_segmentation(model, image, device, resolution=1024):
|
|
| 120 |
if not isinstance(result_tensor, torch.Tensor):
|
| 121 |
if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
|
| 122 |
|
| 123 |
-
# Get binary-ish mask (logits or sigmoid)
|
| 124 |
pred = result_tensor.squeeze().cpu()
|
| 125 |
if pred.max() > 1 or pred.min() < 0: pred = pred.sigmoid()
|
| 126 |
|
| 127 |
-
# Resize back to original
|
| 128 |
pred_pil = transforms.ToPILImage()(pred)
|
| 129 |
mask = pred_pil.resize((w, h), resample=Image.LANCZOS)
|
| 130 |
return mask
|
|
@@ -134,36 +119,36 @@ def inference_vitmatte(image, device):
|
|
| 134 |
Runs pipeline: RMBG (Rough Mask) -> Trimap -> VitMatte (Refined Mask)
|
| 135 |
"""
|
| 136 |
# 1. Get Rough Mask using RMBG (Fast)
|
| 137 |
-
rmbg_model, _ = load_rmbg_model()
|
| 138 |
rough_mask_pil = inference_segmentation(rmbg_model, image, device, resolution=1024)
|
| 139 |
|
| 140 |
-
# 2. Create Trimap
|
| 141 |
-
# Convert PIL mask to Tensor
|
| 142 |
mask_tensor = transforms.ToTensor()(rough_mask_pil).to(device)
|
| 143 |
-
# Generate trimap (1=FG, 0=BG, 0.5=Unknown)
|
| 144 |
trimap_tensor = generate_trimap(mask_tensor, erode_kernel_size=25, dilate_kernel_size=25)
|
| 145 |
|
| 146 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
processor, model, _ = load_vitmatte_model()
|
| 148 |
|
| 149 |
-
#
|
| 150 |
-
inputs = processor(images=image, trimaps=
|
|
|
|
| 151 |
|
| 152 |
with torch.no_grad():
|
| 153 |
outputs = model(**inputs)
|
| 154 |
|
| 155 |
-
# Output is the refined alphas
|
| 156 |
alphas = outputs.alphas
|
| 157 |
-
|
| 158 |
-
# 4. Post-process
|
| 159 |
-
# Extract alpha, resize to original
|
| 160 |
alpha_np = alphas.squeeze().cpu().numpy()
|
| 161 |
alpha_pil = Image.fromarray((alpha_np * 255).astype("uint8"), mode="L")
|
| 162 |
alpha_pil = alpha_pil.resize(image.size, resample=Image.LANCZOS)
|
| 163 |
|
| 164 |
return alpha_pil
|
| 165 |
|
| 166 |
-
|
| 167 |
@st.cache_data(show_spinner=False)
|
| 168 |
def process_background_removal(image_bytes, method="RMBG-1.4"):
|
| 169 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
|
@@ -177,19 +162,16 @@ def process_background_removal(image_bytes, method="RMBG-1.4"):
|
|
| 177 |
mask = inference_segmentation(model, image, device, resolution=1024)
|
| 178 |
|
| 179 |
elif method == "VitMatte (Refiner)":
|
| 180 |
-
# VitMatte needs GPU ideally, works on CPU but slow
|
| 181 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 182 |
mask = inference_vitmatte(image, device)
|
| 183 |
|
| 184 |
else:
|
| 185 |
-
# Fallback
|
| 186 |
return image
|
| 187 |
|
| 188 |
-
# Apply mask
|
| 189 |
image.putalpha(mask)
|
| 190 |
return image
|
| 191 |
|
| 192 |
-
# --- Upscaling Logic
|
| 193 |
def run_swin_inference(image, processor, model):
|
| 194 |
inputs = processor(image, return_tensors="pt")
|
| 195 |
with torch.no_grad():
|
|
@@ -264,13 +246,12 @@ def main():
|
|
| 264 |
st.sidebar.header("1. Background Removal")
|
| 265 |
remove_bg = st.sidebar.checkbox("Remove Background", value=False)
|
| 266 |
|
| 267 |
-
# NEW: Model Selector
|
| 268 |
if remove_bg:
|
| 269 |
bg_model = st.sidebar.selectbox(
|
| 270 |
"Select AI Model",
|
| 271 |
["RMBG-1.4", "BiRefNet (Heavy)", "VitMatte (Refiner)"],
|
| 272 |
index=0,
|
| 273 |
-
help="RMBG: Fast
|
| 274 |
)
|
| 275 |
else:
|
| 276 |
bg_model = "None"
|
|
@@ -294,7 +275,6 @@ def main():
|
|
| 294 |
|
| 295 |
# 1. Background
|
| 296 |
if remove_bg:
|
| 297 |
-
# We add the model name to the spinner text so user knows what's happening
|
| 298 |
with st.spinner(f"Removing background using {bg_model}..."):
|
| 299 |
processed_image = process_background_removal(file_bytes, bg_model)
|
| 300 |
else:
|
|
@@ -303,9 +283,7 @@ def main():
|
|
| 303 |
# 2. Upscaling
|
| 304 |
if upscale_mode != "None":
|
| 305 |
scale = 4 if "4x" in upscale_mode else 2
|
| 306 |
-
|
| 307 |
-
# Cache Key includes model name now
|
| 308 |
-
cache_key = f"{uploaded_file.name}_{bg_model}_{scale}_{grid_n}_v5"
|
| 309 |
|
| 310 |
if "upscale_cache" not in st.session_state:
|
| 311 |
st.session_state.upscale_cache = {}
|
|
|
|
| 24 |
@st.cache_resource
|
| 25 |
def load_birefnet_model():
|
| 26 |
"""Option 2: The Heavyweight Generalist"""
|
|
|
|
| 27 |
model = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
|
| 28 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 29 |
model.to(device)
|
|
|
|
| 66 |
|
| 67 |
def generate_trimap(mask_tensor, erode_kernel_size=10, dilate_kernel_size=10):
|
| 68 |
"""
|
| 69 |
+
Generates a trimap (Foreground, Background, Unknown) from a binary mask.
|
|
|
|
| 70 |
Values: 1=FG, 0=BG, 0.5=Unknown (Edge)
|
| 71 |
"""
|
|
|
|
| 72 |
if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0)
|
| 73 |
|
|
|
|
| 74 |
erode_k = erode_kernel_size
|
| 75 |
dilate_k = dilate_kernel_size
|
| 76 |
|
| 77 |
+
# Dilation (Max Pooling)
|
|
|
|
| 78 |
dilated = F.max_pool2d(mask_tensor, kernel_size=dilate_k, stride=1, padding=dilate_k//2)
|
| 79 |
|
| 80 |
+
# Erosion (Negative Max Pooling)
|
| 81 |
eroded = -F.max_pool2d(-mask_tensor, kernel_size=erode_k, stride=1, padding=erode_k//2)
|
| 82 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 83 |
trimap = torch.full_like(mask_tensor, 0.5)
|
|
|
|
|
|
|
| 84 |
trimap[eroded > 0.5] = 1.0
|
| 85 |
trimap[dilated < 0.5] = 0.0
|
| 86 |
|
|
|
|
| 107 |
if not isinstance(result_tensor, torch.Tensor):
|
| 108 |
if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
|
| 109 |
|
|
|
|
| 110 |
pred = result_tensor.squeeze().cpu()
|
| 111 |
if pred.max() > 1 or pred.min() < 0: pred = pred.sigmoid()
|
| 112 |
|
|
|
|
| 113 |
pred_pil = transforms.ToPILImage()(pred)
|
| 114 |
mask = pred_pil.resize((w, h), resample=Image.LANCZOS)
|
| 115 |
return mask
|
|
|
|
| 119 |
Runs pipeline: RMBG (Rough Mask) -> Trimap -> VitMatte (Refined Mask)
|
| 120 |
"""
|
| 121 |
# 1. Get Rough Mask using RMBG (Fast)
|
| 122 |
+
rmbg_model, _ = load_rmbg_model()
|
| 123 |
rough_mask_pil = inference_segmentation(rmbg_model, image, device, resolution=1024)
|
| 124 |
|
| 125 |
+
# 2. Create Trimap (Tensor)
|
|
|
|
| 126 |
mask_tensor = transforms.ToTensor()(rough_mask_pil).to(device)
|
|
|
|
| 127 |
trimap_tensor = generate_trimap(mask_tensor, erode_kernel_size=25, dilate_kernel_size=25)
|
| 128 |
|
| 129 |
+
# --- FIX START ---
|
| 130 |
+
# 3. Convert Trimap Tensor to PIL Image
|
| 131 |
+
# VitMatte Processor crashes on raw tensors. It wants a PIL Image.
|
| 132 |
+
# We take the tensor (0.0 to 1.0), move to CPU, and convert to PIL (0 to 255)
|
| 133 |
+
trimap_pil = transforms.ToPILImage()(trimap_tensor.squeeze().cpu())
|
| 134 |
+
|
| 135 |
+
# 4. VitMatte Inference
|
| 136 |
processor, model, _ = load_vitmatte_model()
|
| 137 |
|
| 138 |
+
# Pass PIL images for both
|
| 139 |
+
inputs = processor(images=image, trimaps=trimap_pil, return_tensors="pt").to(device)
|
| 140 |
+
# --- FIX END ---
|
| 141 |
|
| 142 |
with torch.no_grad():
|
| 143 |
outputs = model(**inputs)
|
| 144 |
|
|
|
|
| 145 |
alphas = outputs.alphas
|
|
|
|
|
|
|
|
|
|
| 146 |
alpha_np = alphas.squeeze().cpu().numpy()
|
| 147 |
alpha_pil = Image.fromarray((alpha_np * 255).astype("uint8"), mode="L")
|
| 148 |
alpha_pil = alpha_pil.resize(image.size, resample=Image.LANCZOS)
|
| 149 |
|
| 150 |
return alpha_pil
|
| 151 |
|
|
|
|
| 152 |
@st.cache_data(show_spinner=False)
|
| 153 |
def process_background_removal(image_bytes, method="RMBG-1.4"):
|
| 154 |
image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
|
|
|
|
| 162 |
mask = inference_segmentation(model, image, device, resolution=1024)
|
| 163 |
|
| 164 |
elif method == "VitMatte (Refiner)":
|
|
|
|
| 165 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 166 |
mask = inference_vitmatte(image, device)
|
| 167 |
|
| 168 |
else:
|
|
|
|
| 169 |
return image
|
| 170 |
|
|
|
|
| 171 |
image.putalpha(mask)
|
| 172 |
return image
|
| 173 |
|
| 174 |
+
# --- Upscaling Logic ---
|
| 175 |
def run_swin_inference(image, processor, model):
|
| 176 |
inputs = processor(image, return_tensors="pt")
|
| 177 |
with torch.no_grad():
|
|
|
|
| 246 |
st.sidebar.header("1. Background Removal")
|
| 247 |
remove_bg = st.sidebar.checkbox("Remove Background", value=False)
|
| 248 |
|
|
|
|
| 249 |
if remove_bg:
|
| 250 |
bg_model = st.sidebar.selectbox(
|
| 251 |
"Select AI Model",
|
| 252 |
["RMBG-1.4", "BiRefNet (Heavy)", "VitMatte (Refiner)"],
|
| 253 |
index=0,
|
| 254 |
+
help="RMBG: Fast.\nBiRefNet: Better.\nVitMatte: Best for hair/transparency."
|
| 255 |
)
|
| 256 |
else:
|
| 257 |
bg_model = "None"
|
|
|
|
| 275 |
|
| 276 |
# 1. Background
|
| 277 |
if remove_bg:
|
|
|
|
| 278 |
with st.spinner(f"Removing background using {bg_model}..."):
|
| 279 |
processed_image = process_background_removal(file_bytes, bg_model)
|
| 280 |
else:
|
|
|
|
| 283 |
# 2. Upscaling
|
| 284 |
if upscale_mode != "None":
|
| 285 |
scale = 4 if "4x" in upscale_mode else 2
|
| 286 |
+
cache_key = f"{uploaded_file.name}_{bg_model}_{scale}_{grid_n}_v6"
|
|
|
|
|
|
|
| 287 |
|
| 288 |
if "upscale_cache" not in st.session_state:
|
| 289 |
st.session_state.upscale_cache = {}
|