lukeafullard commited on
Commit
8de6538
·
verified ·
1 Parent(s): 1c883f5

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +85 -118
src/streamlit_app.py CHANGED
@@ -5,13 +5,12 @@ from torchvision import transforms
5
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
7
  import numpy as np
8
- import gc # Garbage collection for memory safety
9
 
10
  # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
- # --- 1. MODEL LOADING (Cached Resource) ---
14
- # Models are loaded once and stay in memory.
15
 
16
  @st.cache_resource
17
  def load_rembg_model():
@@ -30,18 +29,14 @@ def load_upscaler(scale=2):
30
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
31
  return processor, model
32
 
33
- # --- 2. HELPER FUNCTIONS ---
34
 
35
  def find_mask_tensor(output):
36
- """Recursively finds the mask tensor in complex model outputs."""
37
  if isinstance(output, torch.Tensor):
38
- if output.dim() == 4 and output.shape[1] == 1:
39
- return output
40
- elif output.dim() == 3 and output.shape[0] == 1:
41
- return output
42
  return None
43
- if hasattr(output, "logits"):
44
- return find_mask_tensor(output.logits)
45
  elif isinstance(output, (list, tuple)):
46
  for item in output:
47
  found = find_mask_tensor(item)
@@ -49,49 +44,30 @@ def find_mask_tensor(output):
49
  return None
50
 
51
  def run_swin_inference(image, processor, model):
52
- """Atomic inference for a single chunk."""
53
  inputs = processor(image, return_tensors="pt")
54
  with torch.no_grad():
55
  outputs = model(**inputs)
56
-
57
  output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
58
  output = np.moveaxis(output, 0, -1)
59
  output = (output * 255.0).round().astype(np.uint8)
60
  return Image.fromarray(output)
61
 
62
  def upscale_chunk_logic(image, processor, model):
63
- """Handles RGBA vs RGB logic for a single chunk."""
64
  if image.mode == 'RGBA':
65
  r, g, b, a = image.split()
66
  rgb_image = Image.merge('RGB', (r, g, b))
67
  upscaled_rgb = run_swin_inference(rgb_image, processor, model)
68
- # Resize alpha to match new RGB size
69
  upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
70
  return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
71
  else:
72
  return run_swin_inference(image, processor, model)
73
 
74
- def convert_image_to_bytes(img):
75
- buf = io.BytesIO()
76
- img.save(buf, format="PNG")
77
- return buf.getvalue()
78
-
79
- # --- 3. HEAVY OPERATIONS (Cached Data) ---
80
- # These functions cache their results. If inputs (image/settings) don't change,
81
- # they return the previous result instantly without using RAM/CPU.
82
-
83
  @st.cache_data(show_spinner=False)
84
  def process_background_removal(image_bytes):
85
- """
86
- Removes background. Input is bytes to make it hashable for caching.
87
- """
88
- # Re-open image from bytes
89
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
90
-
91
- # Load model
92
  model, device = load_rembg_model()
93
 
94
- # Preprocessing
95
  w, h = image.size
96
  transform_image = transforms.Compose([
97
  transforms.Resize((1024, 1024)),
@@ -100,15 +76,11 @@ def process_background_removal(image_bytes):
100
  ])
101
  input_images = transform_image(image).unsqueeze(0).to(device)
102
 
103
- # Inference
104
  with torch.no_grad():
105
  outputs = model(input_images)
106
 
107
- # Find Mask
108
  result_tensor = find_mask_tensor(outputs)
109
- if result_tensor is None:
110
- result_tensor = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
111
-
112
  if not isinstance(result_tensor, torch.Tensor):
113
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
114
 
@@ -118,82 +90,90 @@ def process_background_removal(image_bytes):
118
  pred_pil = transforms.ToPILImage()(pred)
119
  mask = pred_pil.resize((w, h))
120
  image.putalpha(mask)
121
-
122
  return image
123
 
124
- def process_tiled_upscale(image, scale_factor, grid_n, progress_bar=None):
125
  """
126
- Splits image into n*n tiles, upscales each, and merges.
127
- This function is NOT cached directly because it uses a progress bar (UI element).
128
- We wrap the logic inside the main loop or a separate cached function if needed.
129
  """
130
- # Load Model
131
  processor, model = load_upscaler(scale_factor)
132
-
133
  w, h = image.size
134
- rows = grid_n
135
- cols = grid_n
136
 
137
- # Calculate tile sizes
138
  tile_w = w // cols
139
  tile_h = h // rows
140
 
141
- # Create large canvas
 
 
142
  full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
143
  total_tiles = rows * cols
144
  count = 0
145
 
146
  for y in range(rows):
147
  for x in range(cols):
148
- # 1. Crop
149
- left = x * tile_w
150
- upper = y * tile_h
151
- # Handle edge pixels (ensure last tile takes remainder)
152
- right = w if x == cols - 1 else (x + 1) * tile_w
153
- lower = h if y == rows - 1 else (y + 1) * tile_h
 
 
 
 
 
 
 
 
 
 
154
 
155
- tile = image.crop((left, upper, right, lower))
 
156
 
157
- # 2. Upscale
158
  upscaled_tile = upscale_chunk_logic(tile, processor, model)
159
 
160
- # 3. Paste
161
- paste_x = left * scale_factor
162
- paste_y = upper * scale_factor
163
- full_image.paste(upscaled_tile, (paste_x, paste_y))
164
 
165
- # 4. Memory Cleanup (Crucial for 16Gi limit)
166
- del tile
167
- del upscaled_tile
168
- gc.collect()
169
- if torch.cuda.is_available():
170
- torch.cuda.empty_cache()
 
 
 
 
 
 
 
 
 
 
 
171
 
172
- # 5. Update UI
173
  count += 1
174
- if progress_bar:
175
- progress_bar.progress(count / total_tiles, text=f"Processing Tile {count}/{total_tiles}...")
176
 
177
  return full_image
178
 
179
- # Wrapper for caching the upscale result (without progress bar args)
180
- @st.cache_data(show_spinner=False)
181
- def cached_upscale_wrapper(image_bytes, scale_factor, grid_n):
182
- """
183
- This wrapper allows us to cache the upscale result.
184
- We convert PIL->Bytes->PIL inside to ensure Streamlit can hash the input.
185
- """
186
- image = Image.open(io.BytesIO(image_bytes))
187
- # We cannot pass the progress bar to a cached function,
188
- # so we run it without the bar or handle the bar outside.
189
- # For caching purposes, we run it 'quietly'.
190
- return process_tiled_upscale(image, scale_factor, grid_n, progress_bar=None)
191
 
192
- # --- 4. MAIN APP ---
193
 
194
  def main():
195
- st.title("✨ AI Image Lab: Memory Safe")
196
- st.markdown("Features: **RMBG-1.4** | **Swin2SR (Tiled)** | **Smart Caching**")
197
 
198
  # --- Sidebar ---
199
  st.sidebar.header("1. Background")
@@ -202,16 +182,8 @@ def main():
202
  st.sidebar.header("2. AI Upscaling")
203
  upscale_mode = st.sidebar.radio("Magnification", ["None", "2x", "4x"])
204
 
205
- # Grid Slider for Memory Safety
206
  if upscale_mode != "None":
207
- grid_n = st.sidebar.slider(
208
- "Grid Split (Memory Saver)",
209
- min_value=2,
210
- max_value=8,
211
- value=4,
212
- help="Higher = Less RAM used, but slightly slower. If crashing, increase this!"
213
- )
214
- st.sidebar.info(f"Splitting image into {grid_n*grid_n} pieces.")
215
  else:
216
  grid_n = 2
217
 
@@ -222,38 +194,34 @@ def main():
222
  uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
223
 
224
  if uploaded_file is not None:
225
- # Load Original
226
- file_bytes = uploaded_file.getvalue() # Keep raw bytes for caching references
227
- image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
228
 
229
- # --- PIPELINE START ---
230
-
231
- # Step 1: Background Removal (Cached)
232
  if remove_bg:
233
- with st.spinner("Removing background..."):
234
- # We pass bytes to the cached function
235
- processed_image = process_background_removal(file_bytes)
236
  else:
237
- processed_image = image
238
 
239
- # Step 2: Upscaling (Cached manually or via wrapper)
240
  if upscale_mode != "None":
241
  scale = 4 if "4x" in upscale_mode else 2
242
 
243
- # Convert current stage to bytes for cache key
244
- current_stage_bytes = convert_image_to_bytes(processed_image)
245
 
246
- # Check if we should use the cached wrapper or run with progress bar
247
- # To preserve the "Progress Bar" experience while still caching, we can:
248
- # Check if it's already in cache? Streamlit doesn't expose `is_cached`.
249
- # We will use the cached wrapper. The downside: the first run won't show the detailed tile progress
250
- # inside the cached function, just the spinner.
251
 
252
- with st.spinner(f"Upscaling x{scale} ({grid_n*grid_n} tiles)..."):
253
- processed_image = cached_upscale_wrapper(current_stage_bytes, scale, grid_n)
254
-
255
- # Step 3: Geometry (Fast - No Caching needed, applied on top)
256
- # This runs every time you move the slider, but Step 1 & 2 use cache, so it's instant.
 
 
 
 
 
257
  final_image = processed_image.copy()
258
  if rotate_angle != 0:
259
  final_image = final_image.rotate(rotate_angle, expand=True)
@@ -262,15 +230,14 @@ def main():
262
  col1, col2 = st.columns(2)
263
  with col1:
264
  st.subheader("Original")
265
- st.image(image, use_container_width=True)
266
- st.caption(f"Size: {image.size}")
267
 
268
  with col2:
269
  st.subheader("Result")
270
  st.image(final_image, use_container_width=True)
271
  st.caption(f"Size: {final_image.size}")
272
 
273
- # --- Download ---
274
  st.markdown("---")
275
  st.download_button(
276
  label="💾 Download Result (PNG)",
 
5
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
7
  import numpy as np
8
+ import gc
9
 
10
  # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
+ # --- 1. MODEL LOADING ---
 
14
 
15
  @st.cache_resource
16
  def load_rembg_model():
 
29
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
30
  return processor, model
31
 
32
+ # --- 2. PROCESSING LOGIC ---
33
 
34
  def find_mask_tensor(output):
 
35
  if isinstance(output, torch.Tensor):
36
+ if output.dim() == 4 and output.shape[1] == 1: return output
37
+ elif output.dim() == 3 and output.shape[0] == 1: return output
 
 
38
  return None
39
+ if hasattr(output, "logits"): return find_mask_tensor(output.logits)
 
40
  elif isinstance(output, (list, tuple)):
41
  for item in output:
42
  found = find_mask_tensor(item)
 
44
  return None
45
 
46
def run_swin_inference(image, processor, model):
    """Run one Swin2SR super-resolution pass and return the result as a PIL image.

    The model's reconstruction tensor (C, H, W, floats in [0, 1] after clamping)
    is converted to an H x W x C uint8 array before wrapping in a PIL Image.
    """
    model_inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        result = model(**model_inputs)
    # Tensor -> numpy: drop batch dim, clamp to valid range, move channels last.
    arr = result.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
    arr = np.moveaxis(arr, 0, -1)
    arr = (arr * 255.0).round().astype(np.uint8)
    return Image.fromarray(arr)
54
 
55
def upscale_chunk_logic(image, processor, model):
    """Upscale a single tile, preserving an alpha channel when present.

    Swin2SR operates on RGB data only, so for RGBA input the alpha band is
    split off, the RGB bands are super-resolved, and the alpha band is simply
    resized (LANCZOS) to the new dimensions and re-attached.
    """
    if image.mode != 'RGBA':
        return run_swin_inference(image, processor, model)
    red, green, blue, alpha = image.split()
    upscaled = run_swin_inference(Image.merge('RGB', (red, green, blue)), processor, model)
    alpha_resized = alpha.resize(upscaled.size, Image.Resampling.LANCZOS)
    return Image.merge('RGBA', (*upscaled.split(), alpha_resized))
64
 
 
 
 
 
 
 
 
 
 
65
  @st.cache_data(show_spinner=False)
66
  def process_background_removal(image_bytes):
67
+ """Cached background removal."""
 
 
 
68
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
 
 
69
  model, device = load_rembg_model()
70
 
 
71
  w, h = image.size
72
  transform_image = transforms.Compose([
73
  transforms.Resize((1024, 1024)),
 
76
  ])
77
  input_images = transform_image(image).unsqueeze(0).to(device)
78
 
 
79
  with torch.no_grad():
80
  outputs = model(input_images)
81
 
 
82
  result_tensor = find_mask_tensor(outputs)
83
+ if result_tensor is None: result_tensor = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
 
 
84
  if not isinstance(result_tensor, torch.Tensor):
85
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
86
 
 
90
  pred_pil = transforms.ToPILImage()(pred)
91
  mask = pred_pil.resize((w, h))
92
  image.putalpha(mask)
 
93
  return image
94
 
95
def process_tiled_upscale(image, scale_factor, grid_n, progress_bar=None):
    """
    Tiled upscaling with OVERLAP to prevent edge artifacts.

    The image is split into grid_n x grid_n tiles. Each tile is cropped with
    an extra `overlap`-pixel border (clamped to the image bounds) so the model
    sees surrounding context, upscaled, and then the overlap border is cropped
    away before pasting — eliminating visible seams between tiles.

    Args:
        image: Source PIL image (RGB or RGBA).
        scale_factor: Upscale multiplier; must match the loaded Swin2SR model.
        grid_n: Tiles per axis; higher values lower peak RAM usage.
        progress_bar: Optional Streamlit progress bar, updated per tile.
            Defaults to None so the function is safe to call headlessly.

    Returns:
        A new PIL image of size (w * scale_factor, h * scale_factor).
    """
    processor, model = load_upscaler(scale_factor)
    w, h = image.size
    rows = cols = grid_n

    # Base tile size (without overlap).
    tile_w = w // cols
    tile_h = h // rows

    # Overlap buffer (pixels) - lets the AI see context around each tile.
    overlap = 32

    full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
    total_tiles = rows * cols
    count = 0

    for y in range(rows):
        for x in range(cols):
            # 1. Define the "Target" area (where this tile goes in the original).
            target_left = x * tile_w
            target_upper = y * tile_h
            # The last column/row absorbs any remainder pixels.
            target_right = w if x == cols - 1 else (x + 1) * tile_w
            target_lower = h if y == rows - 1 else (y + 1) * tile_h

            target_w = target_right - target_left
            target_h = target_lower - target_upper

            # 2. Define the "Source" area (Target + Overlap), clamped to bounds.
            source_left = max(0, target_left - overlap)
            source_upper = max(0, target_upper - overlap)
            source_right = min(w, target_right + overlap)
            source_lower = min(h, target_lower + overlap)

            # Crop the padded tile.
            tile = image.crop((source_left, source_upper, source_right, source_lower))

            # 3. Upscale the padded tile.
            upscaled_tile = upscale_chunk_logic(tile, processor, model)

            # 4. Crop the "valid" center from the upscaled tile.
            # Extra pixels taken on the left/top, in original-scale units.
            extra_left = target_left - source_left
            extra_upper = target_upper - source_upper

            # Convert offsets and sizes to the upscaled coordinate space.
            crop_x = extra_left * scale_factor
            crop_y = extra_upper * scale_factor
            crop_w = target_w * scale_factor
            crop_h = target_h * scale_factor

            # Remove the overlap borders so tiles butt together seamlessly.
            clean_tile = upscaled_tile.crop((crop_x, crop_y, crop_x + crop_w, crop_y + crop_h))

            # 5. Paste the clean tile into the output canvas.
            paste_x = target_left * scale_factor
            paste_y = target_upper * scale_factor
            full_image.paste(clean_tile, (paste_x, paste_y))

            # Memory cleanup between tiles (crucial under tight RAM limits);
            # also release cached GPU memory when running on CUDA.
            del tile, upscaled_tile, clean_tile
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            count += 1
            if progress_bar is not None:
                progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles} (with overlap)...")

    return full_image
166
 
167
def convert_image_to_bytes(img):
    """Serialize a PIL image to PNG-encoded bytes (e.g. for download buttons)."""
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return buffer.getvalue()
 
 
 
 
 
 
 
 
171
 
172
+ # --- 3. MAIN APP ---
173
 
174
  def main():
175
+ st.title("✨ AI Image Lab: Seamless Edition")
176
+ st.markdown("Features: **RMBG-1.4** | **Swin2SR (Seamless Tiling)** | **Progress Bar**")
177
 
178
  # --- Sidebar ---
179
  st.sidebar.header("1. Background")
 
182
  st.sidebar.header("2. AI Upscaling")
183
  upscale_mode = st.sidebar.radio("Magnification", ["None", "2x", "4x"])
184
 
 
185
  if upscale_mode != "None":
186
+ grid_n = st.sidebar.slider("Grid Split", 2, 8, 4, help="Higher = Safer RAM usage")
 
 
 
 
 
 
 
187
  else:
188
  grid_n = 2
189
 
 
194
  uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
195
 
196
  if uploaded_file is not None:
197
+ file_bytes = uploaded_file.getvalue()
 
 
198
 
199
+ # 1. Background Removal
 
 
200
  if remove_bg:
201
+ processed_image = process_background_removal(file_bytes)
 
 
202
  else:
203
+ processed_image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
204
 
205
+ # 2. Upscaling (Manual Caching with Session State)
206
  if upscale_mode != "None":
207
  scale = 4 if "4x" in upscale_mode else 2
208
 
209
+ # Cache Key
210
+ cache_key = f"{uploaded_file.name}_{remove_bg}_{scale}_{grid_n}_overlap"
211
 
212
+ if "upscale_cache" not in st.session_state:
213
+ st.session_state.upscale_cache = {}
 
 
 
214
 
215
+ if cache_key in st.session_state.upscale_cache:
216
+ processed_image = st.session_state.upscale_cache[cache_key]
217
+ st.info("✅ Loaded upscaled image from cache (Instant!)")
218
+ else:
219
+ progress_bar = st.progress(0, text="Initializing AI models...")
220
+ processed_image = process_tiled_upscale(processed_image, scale, grid_n, progress_bar)
221
+ progress_bar.empty()
222
+ st.session_state.upscale_cache[cache_key] = processed_image
223
+
224
+ # 3. Geometry
225
  final_image = processed_image.copy()
226
  if rotate_angle != 0:
227
  final_image = final_image.rotate(rotate_angle, expand=True)
 
230
  col1, col2 = st.columns(2)
231
  with col1:
232
  st.subheader("Original")
233
+ st.image(Image.open(io.BytesIO(file_bytes)), use_container_width=True)
234
+ st.caption(f"Size: {Image.open(io.BytesIO(file_bytes)).size}")
235
 
236
  with col2:
237
  st.subheader("Result")
238
  st.image(final_image, use_container_width=True)
239
  st.caption(f"Size: {final_image.size}")
240
 
 
241
  st.markdown("---")
242
  st.download_button(
243
  label="💾 Download Result (PNG)",