lukeafullard committed on
Commit
02623e7
·
verified ·
1 Parent(s): f753191

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +176 -72
src/streamlit_app.py CHANGED
@@ -1,8 +1,9 @@
1
  import streamlit as st
2
  from PIL import Image
3
  import torch
 
4
  from torchvision import transforms
5
- from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
7
  import numpy as np
8
  import gc
@@ -10,33 +11,49 @@ import gc
10
  # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
- # --- 1. MODEL LOADING ---
14
 
15
  @st.cache_resource
16
- def load_rembg_model():
17
- # RMBG-1.4 (Fast & High Quality)
18
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
21
  return model, device
22
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
23
  @st.cache_resource
24
  def load_upscaler(scale=2):
25
  if scale == 4:
26
- # FIXED: Use the 'RealWorld' model for 4x. It exists and handles artifacts better.
27
  model_id = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
28
  else:
29
- # 2x Classical Model
30
  model_id = "caidas/swin2SR-classical-sr-x2-64"
31
-
32
  processor = AutoImageProcessor.from_pretrained(model_id)
33
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
34
  return processor, model
35
 
36
- # --- 2. PROCESSING LOGIC ---
37
 
38
  def find_mask_tensor(output):
39
- """Recursively finds the mask tensor."""
40
  if isinstance(output, torch.Tensor):
41
  if output.dim() == 4 and output.shape[1] == 1: return output
42
  elif output.dim() == 3 and output.shape[0] == 1: return output
@@ -48,112 +65,188 @@ def find_mask_tensor(output):
48
  if found is not None: return found
49
  return None
50
 
51
- def run_swin_inference(image, processor, model):
52
- inputs = processor(image, return_tensors="pt")
53
- with torch.no_grad():
54
- outputs = model(**inputs)
55
- output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
56
- output = np.moveaxis(output, 0, -1)
57
- output = (output * 255.0).round().astype(np.uint8)
58
- return Image.fromarray(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
- def upscale_chunk_logic(image, processor, model):
61
- if image.mode == 'RGBA':
62
- r, g, b, a = image.split()
63
- rgb_image = Image.merge('RGB', (r, g, b))
64
- upscaled_rgb = run_swin_inference(rgb_image, processor, model)
65
- upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
66
- return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
67
- else:
68
- return run_swin_inference(image, processor, model)
69
 
70
- @st.cache_data(show_spinner=False)
71
- def process_background_removal(image_bytes):
72
- """Cached background removal (RMBG-1.4)."""
73
- image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
74
- model, device = load_rembg_model()
75
-
76
  w, h = image.size
77
- transform_image = transforms.Compose([
78
- transforms.Resize((1024, 1024)),
 
79
  transforms.ToTensor(),
80
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
81
  ])
82
- input_images = transform_image(image).unsqueeze(0).to(device)
83
 
84
  with torch.no_grad():
85
- outputs = model(input_images)
86
 
87
  result_tensor = find_mask_tensor(outputs)
88
  if result_tensor is None: result_tensor = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
89
  if not isinstance(result_tensor, torch.Tensor):
90
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
91
 
 
92
  pred = result_tensor.squeeze().cpu()
93
  if pred.max() > 1 or pred.min() < 0: pred = pred.sigmoid()
94
-
 
95
  pred_pil = transforms.ToPILImage()(pred)
96
- mask = pred_pil.resize((w, h))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
97
  image.putalpha(mask)
98
  return image
99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
101
- """
102
- Tiled upscaling with OVERLAP to prevent seams.
103
- """
104
  processor, model = load_upscaler(scale_factor)
105
  w, h = image.size
106
  rows = cols = grid_n
107
-
108
  tile_w = w // cols
109
  tile_h = h // rows
110
-
111
- # Overlap buffer (pixels)
112
  overlap = 32
113
-
114
  full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
115
  total_tiles = rows * cols
116
  count = 0
117
-
118
  for y in range(rows):
119
  for x in range(cols):
120
- # Target Area
121
  target_left = x * tile_w
122
  target_upper = y * tile_h
123
  target_right = w if x == cols - 1 else (x + 1) * tile_w
124
  target_lower = h if y == rows - 1 else (y + 1) * tile_h
125
- target_w = target_right - target_left
126
- target_h = target_lower - target_upper
127
-
128
- # Source Area (with overlap)
129
  source_left = max(0, target_left - overlap)
130
  source_upper = max(0, target_upper - overlap)
131
  source_right = min(w, target_right + overlap)
132
  source_lower = min(h, target_lower + overlap)
133
-
134
  tile = image.crop((source_left, source_upper, source_right, source_lower))
135
  upscaled_tile = upscale_chunk_logic(tile, processor, model)
136
-
137
- # Calculate offsets for cropping the valid center
138
  extra_left = target_left - source_left
139
  extra_upper = target_upper - source_upper
140
-
141
  crop_x = extra_left * scale_factor
142
  crop_y = extra_upper * scale_factor
143
  crop_w = target_w * scale_factor
144
  crop_h = target_h * scale_factor
145
-
146
  clean_tile = upscaled_tile.crop((crop_x, crop_y, crop_x + crop_w, crop_y + crop_h))
147
-
148
  paste_x = target_left * scale_factor
149
  paste_y = target_upper * scale_factor
150
  full_image.paste(clean_tile, (paste_x, paste_y))
151
-
152
  del tile, upscaled_tile, clean_tile
153
  gc.collect()
154
  count += 1
155
- progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles} (with overlap)...")
156
-
157
  return full_image
158
 
159
  def convert_image_to_bytes(img):
@@ -161,16 +254,27 @@ def convert_image_to_bytes(img):
161
  img.save(buf, format="PNG")
162
  return buf.getvalue()
163
 
164
- # --- 3. MAIN APP ---
165
 
166
  def main():
167
- st.title("✨ AI Image Lab: Final Edition")
168
- st.markdown("Features: **RMBG-1.4** | **Swin2SR (Seamless Tiling)** | **Progress Bar**")
169
 
170
  # --- Sidebar ---
171
- st.sidebar.header("1. Background")
172
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
173
 
 
 
 
 
 
 
 
 
 
 
 
174
  st.sidebar.header("2. AI Upscaling")
175
  upscale_mode = st.sidebar.radio("Magnification", ["None", "2x", "4x"])
176
 
@@ -188,25 +292,27 @@ def main():
188
  if uploaded_file is not None:
189
  file_bytes = uploaded_file.getvalue()
190
 
191
- # 1. Background Removal
192
  if remove_bg:
193
- processed_image = process_background_removal(file_bytes)
 
 
194
  else:
195
  processed_image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
196
 
197
- # 2. Upscaling (Manual Caching with Session State)
198
  if upscale_mode != "None":
199
  scale = 4 if "4x" in upscale_mode else 2
200
 
201
- # Cache Key
202
- cache_key = f"{uploaded_file.name}_{remove_bg}_{scale}_{grid_n}_overlap_v4"
203
 
204
  if "upscale_cache" not in st.session_state:
205
  st.session_state.upscale_cache = {}
206
 
207
  if cache_key in st.session_state.upscale_cache:
208
  processed_image = st.session_state.upscale_cache[cache_key]
209
- st.info("✅ Loaded upscaled image from cache (Instant!)")
210
  else:
211
  progress_bar = st.progress(0, text="Initializing AI models...")
212
  processed_image = process_tiled_upscale(processed_image, scale, grid_n, progress_bar)
@@ -223,12 +329,10 @@ def main():
223
  with col1:
224
  st.subheader("Original")
225
  st.image(Image.open(io.BytesIO(file_bytes)), use_container_width=True)
226
- st.caption(f"Size: {Image.open(io.BytesIO(file_bytes)).size}")
227
 
228
  with col2:
229
  st.subheader("Result")
230
  st.image(final_image, use_container_width=True)
231
- st.caption(f"Size: {final_image.size}")
232
 
233
  st.markdown("---")
234
  st.download_button(
 
1
  import streamlit as st
2
  from PIL import Image
3
  import torch
4
+ import torch.nn.functional as F
5
  from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution, VitMatteForImageMatting
7
  import io
8
  import numpy as np
9
  import gc
 
11
  # Page Configuration
12
  st.set_page_config(layout="wide", page_title="AI Image Lab")
13
 
14
+ # --- 1. MODEL LOADING (Cached) ---
15
 
16
  @st.cache_resource
17
+ def load_rmbg_model():
18
+ """Option 1: The Lightweight Specialist"""
19
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
20
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
21
  model.to(device)
22
  return model, device
23
 
24
+ @st.cache_resource
25
+ def load_birefnet_model():
26
+ """Option 2: The Heavyweight Generalist"""
27
+ # This requires 'timm' installed
28
+ model = AutoModelForImageSegmentation.from_pretrained("ZhengPeng7/BiRefNet", trust_remote_code=True)
29
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
30
+ model.to(device)
31
+ return model, device
32
+
33
+ @st.cache_resource
34
+ def load_vitmatte_model():
35
+ """Option 3: The Refiner (Matting)"""
36
+ # VitMatte requires a rough mask first (we use RMBG for that)
37
+ processor = AutoImageProcessor.from_pretrained("hustvl/vitmatte-small-composition-1k")
38
+ model = VitMatteForImageMatting.from_pretrained("hustvl/vitmatte-small-composition-1k")
39
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
40
+ model.to(device)
41
+ return processor, model, device
42
+
43
  @st.cache_resource
44
  def load_upscaler(scale=2):
45
  if scale == 4:
 
46
  model_id = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
47
  else:
 
48
  model_id = "caidas/swin2SR-classical-sr-x2-64"
 
49
  processor = AutoImageProcessor.from_pretrained(model_id)
50
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
51
  return processor, model
52
 
53
+ # --- 2. HELPER FUNCTIONS ---
54
 
55
  def find_mask_tensor(output):
56
+ """Recursively finds the mask tensor in complex model outputs."""
57
  if isinstance(output, torch.Tensor):
58
  if output.dim() == 4 and output.shape[1] == 1: return output
59
  elif output.dim() == 3 and output.shape[0] == 1: return output
 
65
  if found is not None: return found
66
  return None
67
 
68
+ def generate_trimap(mask_tensor, erode_kernel_size=10, dilate_kernel_size=10):
69
+ """
70
+ Generates a trimap (Foreground, Background, Unknown) from a binary mask
71
+ using Pure PyTorch (No OpenCV required).
72
+ Values: 1=FG, 0=BG, 0.5=Unknown (Edge)
73
+ """
74
+ # Ensure mask is Bx1xHxW
75
+ if mask_tensor.dim() == 3: mask_tensor = mask_tensor.unsqueeze(0)
76
+
77
+ # Create kernels
78
+ erode_k = erode_kernel_size
79
+ dilate_k = dilate_kernel_size
80
+
81
+ # Dilation (Max Pooling) - Expands the white area
82
+ # We pad to keep size same
83
+ dilated = F.max_pool2d(mask_tensor, kernel_size=dilate_k, stride=1, padding=dilate_k//2)
84
+
85
+ # Erosion (Negative Max Pooling) - Shrinks the white area
86
+ eroded = -F.max_pool2d(-mask_tensor, kernel_size=erode_k, stride=1, padding=erode_k//2)
87
+
88
+ # Trimap construction
89
+ # Pixels that are 1 in eroded are definitely FG (1.0)
90
+ # Pixels that are 0 in dilated are definitely BG (0.0)
91
+ # Everything else is the "Unknown" zone (0.5)
92
+
93
+ # Start with Unknown (0.5)
94
+ trimap = torch.full_like(mask_tensor, 0.5)
95
+
96
+ # Set definites
97
+ trimap[eroded > 0.5] = 1.0
98
+ trimap[dilated < 0.5] = 0.0
99
+
100
+ return trimap
101
 
102
+ # --- 3. INFERENCE LOGIC ---
 
 
 
 
 
 
 
 
103
 
104
+ def inference_segmentation(model, image, device, resolution=1024):
105
+ """Generic inference for RMBG and BiRefNet."""
 
 
 
 
106
  w, h = image.size
107
+
108
+ transform = transforms.Compose([
109
+ transforms.Resize((resolution, resolution)),
110
  transforms.ToTensor(),
111
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
112
  ])
113
+ input_tensor = transform(image).unsqueeze(0).to(device)
114
 
115
  with torch.no_grad():
116
+ outputs = model(input_tensor)
117
 
118
  result_tensor = find_mask_tensor(outputs)
119
  if result_tensor is None: result_tensor = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
120
  if not isinstance(result_tensor, torch.Tensor):
121
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
122
 
123
+ # Get binary-ish mask (logits or sigmoid)
124
  pred = result_tensor.squeeze().cpu()
125
  if pred.max() > 1 or pred.min() < 0: pred = pred.sigmoid()
126
+
127
+ # Resize back to original
128
  pred_pil = transforms.ToPILImage()(pred)
129
+ mask = pred_pil.resize((w, h), resample=Image.LANCZOS)
130
+ return mask
131
+
132
+ def inference_vitmatte(image, device):
133
+ """
134
+ Runs pipeline: RMBG (Rough Mask) -> Trimap -> VitMatte (Refined Mask)
135
+ """
136
+ # 1. Get Rough Mask using RMBG (Fast)
137
+ rmbg_model, _ = load_rmbg_model() # Re-use loaded model
138
+ rough_mask_pil = inference_segmentation(rmbg_model, image, device, resolution=1024)
139
+
140
+ # 2. Create Trimap
141
+ # Convert PIL mask to Tensor
142
+ mask_tensor = transforms.ToTensor()(rough_mask_pil).to(device)
143
+ # Generate trimap (1=FG, 0=BG, 0.5=Unknown)
144
+ trimap_tensor = generate_trimap(mask_tensor, erode_kernel_size=25, dilate_kernel_size=25)
145
+
146
+ # 3. VitMatte Inference
147
+ processor, model, _ = load_vitmatte_model()
148
+
149
+ # VitMatte expects inputs: pixel_values (image) and mask_labels (trimap)
150
+ inputs = processor(images=image, trimaps=trimap_tensor, return_tensors="pt").to(device)
151
+
152
+ with torch.no_grad():
153
+ outputs = model(**inputs)
154
+
155
+ # Output is the refined alphas
156
+ alphas = outputs.alphas
157
+
158
+ # 4. Post-process
159
+ # Extract alpha, resize to original
160
+ alpha_np = alphas.squeeze().cpu().numpy()
161
+ alpha_pil = Image.fromarray((alpha_np * 255).astype("uint8"), mode="L")
162
+ alpha_pil = alpha_pil.resize(image.size, resample=Image.LANCZOS)
163
+
164
+ return alpha_pil
165
+
166
+
167
+ @st.cache_data(show_spinner=False)
168
+ def process_background_removal(image_bytes, method="RMBG-1.4"):
169
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
170
+
171
+ if method == "RMBG-1.4":
172
+ model, device = load_rmbg_model()
173
+ mask = inference_segmentation(model, image, device)
174
+
175
+ elif method == "BiRefNet (Heavy)":
176
+ model, device = load_birefnet_model()
177
+ mask = inference_segmentation(model, image, device, resolution=1024)
178
+
179
+ elif method == "VitMatte (Refiner)":
180
+ # VitMatte needs GPU ideally, works on CPU but slow
181
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
182
+ mask = inference_vitmatte(image, device)
183
+
184
+ else:
185
+ # Fallback
186
+ return image
187
+
188
+ # Apply mask
189
  image.putalpha(mask)
190
  return image
191
 
192
+ # --- Upscaling Logic (Same as before) ---
193
+ def run_swin_inference(image, processor, model):
194
+ inputs = processor(image, return_tensors="pt")
195
+ with torch.no_grad():
196
+ outputs = model(**inputs)
197
+ output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
198
+ output = np.moveaxis(output, 0, -1)
199
+ output = (output * 255.0).round().astype(np.uint8)
200
+ return Image.fromarray(output)
201
+
202
+ def upscale_chunk_logic(image, processor, model):
203
+ if image.mode == 'RGBA':
204
+ r, g, b, a = image.split()
205
+ rgb_image = Image.merge('RGB', (r, g, b))
206
+ upscaled_rgb = run_swin_inference(rgb_image, processor, model)
207
+ upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
208
+ return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
209
+ else:
210
+ return run_swin_inference(image, processor, model)
211
+
212
  def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
 
 
 
213
  processor, model = load_upscaler(scale_factor)
214
  w, h = image.size
215
  rows = cols = grid_n
 
216
  tile_w = w // cols
217
  tile_h = h // rows
 
 
218
  overlap = 32
 
219
  full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
220
  total_tiles = rows * cols
221
  count = 0
 
222
  for y in range(rows):
223
  for x in range(cols):
 
224
  target_left = x * tile_w
225
  target_upper = y * tile_h
226
  target_right = w if x == cols - 1 else (x + 1) * tile_w
227
  target_lower = h if y == rows - 1 else (y + 1) * tile_h
 
 
 
 
228
  source_left = max(0, target_left - overlap)
229
  source_upper = max(0, target_upper - overlap)
230
  source_right = min(w, target_right + overlap)
231
  source_lower = min(h, target_lower + overlap)
 
232
  tile = image.crop((source_left, source_upper, source_right, source_lower))
233
  upscaled_tile = upscale_chunk_logic(tile, processor, model)
234
+ target_w = target_right - target_left
235
+ target_h = target_lower - target_upper
236
  extra_left = target_left - source_left
237
  extra_upper = target_upper - source_upper
 
238
  crop_x = extra_left * scale_factor
239
  crop_y = extra_upper * scale_factor
240
  crop_w = target_w * scale_factor
241
  crop_h = target_h * scale_factor
 
242
  clean_tile = upscaled_tile.crop((crop_x, crop_y, crop_x + crop_w, crop_y + crop_h))
 
243
  paste_x = target_left * scale_factor
244
  paste_y = target_upper * scale_factor
245
  full_image.paste(clean_tile, (paste_x, paste_y))
 
246
  del tile, upscaled_tile, clean_tile
247
  gc.collect()
248
  count += 1
249
+ progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles}...")
 
250
  return full_image
251
 
252
  def convert_image_to_bytes(img):
 
254
  img.save(buf, format="PNG")
255
  return buf.getvalue()
256
 
257
+ # --- 4. MAIN APP ---
258
 
259
  def main():
260
+ st.title("✨ AI Image Lab: Ultimate Edition")
261
+ st.markdown("Features: **Multi-Model Background** | **Swin2SR** | **Progress Bar**")
262
 
263
  # --- Sidebar ---
264
+ st.sidebar.header("1. Background Removal")
265
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
266
 
267
+ # NEW: Model Selector
268
+ if remove_bg:
269
+ bg_model = st.sidebar.selectbox(
270
+ "Select AI Model",
271
+ ["RMBG-1.4", "BiRefNet (Heavy)", "VitMatte (Refiner)"],
272
+ index=0,
273
+ help="RMBG: Fast, Standard Quality.\nBiRefNet: Slower, Better Edges.\nVitMatte: Slowest, Best for Hair/Transparency."
274
+ )
275
+ else:
276
+ bg_model = "None"
277
+
278
  st.sidebar.header("2. AI Upscaling")
279
  upscale_mode = st.sidebar.radio("Magnification", ["None", "2x", "4x"])
280
 
 
292
  if uploaded_file is not None:
293
  file_bytes = uploaded_file.getvalue()
294
 
295
+ # 1. Background
296
  if remove_bg:
297
+ # We add the model name to the spinner text so user knows what's happening
298
+ with st.spinner(f"Removing background using {bg_model}..."):
299
+ processed_image = process_background_removal(file_bytes, bg_model)
300
  else:
301
  processed_image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
302
 
303
+ # 2. Upscaling
304
  if upscale_mode != "None":
305
  scale = 4 if "4x" in upscale_mode else 2
306
 
307
+ # Cache Key includes model name now
308
+ cache_key = f"{uploaded_file.name}_{bg_model}_{scale}_{grid_n}_v5"
309
 
310
  if "upscale_cache" not in st.session_state:
311
  st.session_state.upscale_cache = {}
312
 
313
  if cache_key in st.session_state.upscale_cache:
314
  processed_image = st.session_state.upscale_cache[cache_key]
315
+ st.info("✅ Loaded upscaled image from cache")
316
  else:
317
  progress_bar = st.progress(0, text="Initializing AI models...")
318
  processed_image = process_tiled_upscale(processed_image, scale, grid_n, progress_bar)
 
329
  with col1:
330
  st.subheader("Original")
331
  st.image(Image.open(io.BytesIO(file_bytes)), use_container_width=True)
 
332
 
333
  with col2:
334
  st.subheader("Result")
335
  st.image(final_image, use_container_width=True)
 
336
 
337
  st.markdown("---")
338
  st.download_button(