lukeafullard committed on
Commit
1c883f5
·
verified ·
1 Parent(s): a533ab1

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +139 -100
src/streamlit_app.py CHANGED
@@ -5,15 +5,16 @@ from torchvision import transforms
5
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
7
  import numpy as np
 
8
 
9
  # Page Configuration
10
  st.set_page_config(layout="wide", page_title="AI Image Lab")
11
 
12
- # --- 1. MODEL LOADING (Cached) ---
 
13
 
14
  @st.cache_resource
15
  def load_rembg_model():
16
- """Loads RMBG-1.4 for Background Removal."""
17
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
@@ -21,17 +22,15 @@ def load_rembg_model():
21
 
22
  @st.cache_resource
23
  def load_upscaler(scale=2):
24
- """Loads Swin2SR for Super-Resolution (2x or 4x)."""
25
  if scale == 4:
26
  model_id = "caidas/swin2SR-classical-sr-x4-63"
27
  else:
28
  model_id = "caidas/swin2SR-classical-sr-x2-64"
29
-
30
  processor = AutoImageProcessor.from_pretrained(model_id)
31
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
32
  return processor, model
33
 
34
- # --- 2. PROCESSING FUNCTIONS ---
35
 
36
  def find_mask_tensor(output):
37
  """Recursively finds the mask tensor in complex model outputs."""
@@ -41,22 +40,58 @@ def find_mask_tensor(output):
41
  elif output.dim() == 3 and output.shape[0] == 1:
42
  return output
43
  return None
44
-
45
  if hasattr(output, "logits"):
46
  return find_mask_tensor(output.logits)
47
  elif isinstance(output, (list, tuple)):
48
  for item in output:
49
  found = find_mask_tensor(item)
50
  if found is not None: return found
51
- elif hasattr(output, "items"):
52
- for val in output.values():
53
- found = find_mask_tensor(val)
54
- if found is not None: return found
55
-
56
  return None
57
 
58
- def safe_rembg_inference(model, image, device):
59
- """Robust background removal inference."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
60
  w, h = image.size
61
  transform_image = transforms.Compose([
62
  transforms.Resize((1024, 1024)),
@@ -65,13 +100,15 @@ def safe_rembg_inference(model, image, device):
65
  ])
66
  input_images = transform_image(image).unsqueeze(0).to(device)
67
 
 
68
  with torch.no_grad():
69
  outputs = model(input_images)
70
 
 
71
  result_tensor = find_mask_tensor(outputs)
72
  if result_tensor is None:
73
  result_tensor = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
74
-
75
  if not isinstance(result_tensor, torch.Tensor):
76
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
77
 
@@ -81,143 +118,145 @@ def safe_rembg_inference(model, image, device):
81
  pred_pil = transforms.ToPILImage()(pred)
82
  mask = pred_pil.resize((w, h))
83
  image.putalpha(mask)
84
- return image
85
-
86
- def run_swin_inference(image, processor, model):
87
- """Atomic inference for a single image/tile."""
88
- inputs = processor(image, return_tensors="pt")
89
- with torch.no_grad():
90
- outputs = model(**inputs)
91
 
92
- output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
93
- output = np.moveaxis(output, 0, -1)
94
- output = (output * 255.0).round().astype(np.uint8)
95
- return Image.fromarray(output)
96
-
97
- def upscale_image_logic(image, processor, model):
98
- """Handles RGBA vs RGB logic for a single chunk."""
99
- if image.mode == 'RGBA':
100
- r, g, b, a = image.split()
101
- rgb_image = Image.merge('RGB', (r, g, b))
102
- upscaled_rgb = run_swin_inference(rgb_image, processor, model)
103
- upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
104
- return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
105
- else:
106
- return run_swin_inference(image, processor, model)
107
 
108
- def tiled_upscale(image, processor, model, scale_factor, progress_bar):
109
  """
110
- Splits image into a 2x2 grid, upscales each tile, and updates progress bar.
 
 
111
  """
112
- rows, cols = 2, 2 # Split into 4 tiles
 
 
113
  w, h = image.size
 
 
114
 
115
  # Calculate tile sizes
116
  tile_w = w // cols
117
  tile_h = h // rows
118
 
 
119
  full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
120
  total_tiles = rows * cols
121
  count = 0
122
 
123
  for y in range(rows):
124
  for x in range(cols):
125
- # Define crop box
126
  left = x * tile_w
127
  upper = y * tile_h
128
- # Ensure the last tile takes the remaining pixels (fixes rounding errors)
129
  right = w if x == cols - 1 else (x + 1) * tile_w
130
  lower = h if y == rows - 1 else (y + 1) * tile_h
131
 
132
- # Crop
133
  tile = image.crop((left, upper, right, lower))
134
 
135
- # Upscale the tile
136
- upscaled_tile = upscale_image_logic(tile, processor, model)
137
 
138
- # Paste into new canvas
139
  paste_x = left * scale_factor
140
  paste_y = upper * scale_factor
141
  full_image.paste(upscaled_tile, (paste_x, paste_y))
142
 
143
- # Update Progress
144
- count += 1
145
- progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles}...")
 
 
 
146
 
 
 
 
 
 
147
  return full_image
148
 
149
- def convert_image_to_bytes(img):
150
- buf = io.BytesIO()
151
- img.save(buf, format="PNG")
152
- return buf.getvalue()
 
 
 
 
 
 
 
 
153
 
154
- # --- 3. MAIN APP ---
155
 
156
  def main():
157
- st.title("✨ AI Image Lab: Tiled Edition")
158
- st.markdown("Features: **RMBG-1.4** | **Swin2SR (Tiled)** | **Geometry**")
159
 
160
  # --- Sidebar ---
161
  st.sidebar.header("1. Background")
162
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
163
 
164
  st.sidebar.header("2. AI Upscaling")
165
- upscale_mode = st.sidebar.radio("Magnification", ["None", "2x (Fast)", "4x (Slow - Tiled)"])
166
 
 
 
 
 
 
 
 
 
 
 
 
 
 
167
  st.sidebar.header("3. Geometry")
168
  rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
169
 
170
- # --- Main ---
171
  uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
172
 
173
  if uploaded_file is not None:
174
- image = Image.open(uploaded_file).convert("RGB")
175
- processed_image = image.copy()
 
176
 
177
- # 1. Background
 
 
178
  if remove_bg:
179
- st.info("Loading RMBG Model...")
180
- try:
181
- bg_model, device = load_rembg_model()
182
- with st.spinner("Removing background..."):
183
- processed_image = safe_rembg_inference(bg_model, processed_image, device)
184
- except Exception as e:
185
- st.error(f"Background Removal Failed: {e}")
186
-
187
- # 2. Upscaling
188
  if upscale_mode != "None":
189
  scale = 4 if "4x" in upscale_mode else 2
190
 
191
- # If 4x, use the Progress Bar + Tiling method
192
- if scale == 4:
193
- st.info(f"Loading Swin2SR x{scale} Model...")
194
- try:
195
- processor, upscaler = load_upscaler(scale)
196
-
197
- # Create Progress Bar
198
- my_bar = st.progress(0, text="Starting Tiled Upscaling...")
199
-
200
- processed_image = tiled_upscale(processed_image, processor, upscaler, scale, my_bar)
201
-
202
- # Clear bar when done
203
- my_bar.empty()
204
-
205
- except Exception as e:
206
- st.error(f"Upscaling Failed: {e}")
207
 
208
- # If 2x, keep it simple (it's fast enough)
209
- else:
210
- st.info(f"Loading Swin2SR x{scale} Model...")
211
- try:
212
- processor, upscaler = load_upscaler(scale)
213
- with st.spinner("Upscaling (2x)..."):
214
- processed_image = upscale_image_logic(processed_image, processor, upscaler)
215
- except Exception as e:
216
- st.error(f"Upscaling Failed: {e}")
217
-
218
- # 3. Rotation
 
219
  if rotate_angle != 0:
220
- processed_image = processed_image.rotate(rotate_angle, expand=True)
221
 
222
  # --- Display ---
223
  col1, col2 = st.columns(2)
@@ -228,14 +267,14 @@ def main():
228
 
229
  with col2:
230
  st.subheader("Result")
231
- st.image(processed_image, use_container_width=True)
232
- st.caption(f"Size: {processed_image.size}")
233
 
234
  # --- Download ---
235
  st.markdown("---")
236
  st.download_button(
237
  label="💾 Download Result (PNG)",
238
- data=convert_image_to_bytes(processed_image),
239
  file_name="processed_image.png",
240
  mime="image/png"
241
  )
 
5
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
7
  import numpy as np
8
+ import gc # Garbage collection for memory safety
9
 
10
  # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
+ # --- 1. MODEL LOADING (Cached Resource) ---
14
+ # Models are loaded once and stay in memory.
15
 
16
  @st.cache_resource
17
  def load_rembg_model():
 
18
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
 
22
 
23
  @st.cache_resource
24
  def load_upscaler(scale=2):
 
25
  if scale == 4:
26
  model_id = "caidas/swin2SR-classical-sr-x4-63"
27
  else:
28
  model_id = "caidas/swin2SR-classical-sr-x2-64"
 
29
  processor = AutoImageProcessor.from_pretrained(model_id)
30
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
31
  return processor, model
32
 
33
+ # --- 2. HELPER FUNCTIONS ---
34
 
35
  def find_mask_tensor(output):
36
  """Recursively finds the mask tensor in complex model outputs."""
 
40
  elif output.dim() == 3 and output.shape[0] == 1:
41
  return output
42
  return None
 
43
  if hasattr(output, "logits"):
44
  return find_mask_tensor(output.logits)
45
  elif isinstance(output, (list, tuple)):
46
  for item in output:
47
  found = find_mask_tensor(item)
48
  if found is not None: return found
 
 
 
 
 
49
  return None
50
 
51
+ def run_swin_inference(image, processor, model):
52
+ """Atomic inference for a single chunk."""
53
+ inputs = processor(image, return_tensors="pt")
54
+ with torch.no_grad():
55
+ outputs = model(**inputs)
56
+
57
+ output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
58
+ output = np.moveaxis(output, 0, -1)
59
+ output = (output * 255.0).round().astype(np.uint8)
60
+ return Image.fromarray(output)
61
+
62
+ def upscale_chunk_logic(image, processor, model):
63
+ """Handles RGBA vs RGB logic for a single chunk."""
64
+ if image.mode == 'RGBA':
65
+ r, g, b, a = image.split()
66
+ rgb_image = Image.merge('RGB', (r, g, b))
67
+ upscaled_rgb = run_swin_inference(rgb_image, processor, model)
68
+ # Resize alpha to match new RGB size
69
+ upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
70
+ return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
71
+ else:
72
+ return run_swin_inference(image, processor, model)
73
+
74
+ def convert_image_to_bytes(img):
75
+ buf = io.BytesIO()
76
+ img.save(buf, format="PNG")
77
+ return buf.getvalue()
78
+
79
+ # --- 3. HEAVY OPERATIONS (Cached Data) ---
80
+ # These functions cache their results. If inputs (image/settings) don't change,
81
+ # they return the previous result instantly without using RAM/CPU.
82
+
83
+ @st.cache_data(show_spinner=False)
84
+ def process_background_removal(image_bytes):
85
+ """
86
+ Removes background. Input is bytes to make it hashable for caching.
87
+ """
88
+ # Re-open image from bytes
89
+ image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
90
+
91
+ # Load model
92
+ model, device = load_rembg_model()
93
+
94
+ # Preprocessing
95
  w, h = image.size
96
  transform_image = transforms.Compose([
97
  transforms.Resize((1024, 1024)),
 
100
  ])
101
  input_images = transform_image(image).unsqueeze(0).to(device)
102
 
103
+ # Inference
104
  with torch.no_grad():
105
  outputs = model(input_images)
106
 
107
+ # Find Mask
108
  result_tensor = find_mask_tensor(outputs)
109
  if result_tensor is None:
110
  result_tensor = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
111
+
112
  if not isinstance(result_tensor, torch.Tensor):
113
  if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
114
 
 
118
  pred_pil = transforms.ToPILImage()(pred)
119
  mask = pred_pil.resize((w, h))
120
  image.putalpha(mask)
 
 
 
 
 
 
 
121
 
122
+ return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
123
 
124
+ def process_tiled_upscale(image, scale_factor, grid_n, progress_bar=None):
125
  """
126
+ Splits image into n*n tiles, upscales each, and merges.
127
+ This function is NOT cached directly because it uses a progress bar (UI element).
128
+ We wrap the logic inside the main loop or a separate cached function if needed.
129
  """
130
+ # Load Model
131
+ processor, model = load_upscaler(scale_factor)
132
+
133
  w, h = image.size
134
+ rows = grid_n
135
+ cols = grid_n
136
 
137
  # Calculate tile sizes
138
  tile_w = w // cols
139
  tile_h = h // rows
140
 
141
+ # Create large canvas
142
  full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
143
  total_tiles = rows * cols
144
  count = 0
145
 
146
  for y in range(rows):
147
  for x in range(cols):
148
+ # 1. Crop
149
  left = x * tile_w
150
  upper = y * tile_h
151
+ # Handle edge pixels (ensure last tile takes remainder)
152
  right = w if x == cols - 1 else (x + 1) * tile_w
153
  lower = h if y == rows - 1 else (y + 1) * tile_h
154
 
 
155
  tile = image.crop((left, upper, right, lower))
156
 
157
+ # 2. Upscale
158
+ upscaled_tile = upscale_chunk_logic(tile, processor, model)
159
 
160
+ # 3. Paste
161
  paste_x = left * scale_factor
162
  paste_y = upper * scale_factor
163
  full_image.paste(upscaled_tile, (paste_x, paste_y))
164
 
165
+ # 4. Memory Cleanup (Crucial for 16Gi limit)
166
+ del tile
167
+ del upscaled_tile
168
+ gc.collect()
169
+ if torch.cuda.is_available():
170
+ torch.cuda.empty_cache()
171
 
172
+ # 5. Update UI
173
+ count += 1
174
+ if progress_bar:
175
+ progress_bar.progress(count / total_tiles, text=f"Processing Tile {count}/{total_tiles}...")
176
+
177
  return full_image
178
 
179
+ # Wrapper for caching the upscale result (without progress bar args)
180
+ @st.cache_data(show_spinner=False)
181
+ def cached_upscale_wrapper(image_bytes, scale_factor, grid_n):
182
+ """
183
+ This wrapper allows us to cache the upscale result.
184
+ We convert PIL->Bytes->PIL inside to ensure Streamlit can hash the input.
185
+ """
186
+ image = Image.open(io.BytesIO(image_bytes))
187
+ # We cannot pass the progress bar to a cached function,
188
+ # so we run it without the bar or handle the bar outside.
189
+ # For caching purposes, we run it 'quietly'.
190
+ return process_tiled_upscale(image, scale_factor, grid_n, progress_bar=None)
191
 
192
+ # --- 4. MAIN APP ---
193
 
194
  def main():
195
+ st.title("✨ AI Image Lab: Memory Safe")
196
+ st.markdown("Features: **RMBG-1.4** | **Swin2SR (Tiled)** | **Smart Caching**")
197
 
198
  # --- Sidebar ---
199
  st.sidebar.header("1. Background")
200
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
201
 
202
  st.sidebar.header("2. AI Upscaling")
203
+ upscale_mode = st.sidebar.radio("Magnification", ["None", "2x", "4x"])
204
 
205
+ # Grid Slider for Memory Safety
206
+ if upscale_mode != "None":
207
+ grid_n = st.sidebar.slider(
208
+ "Grid Split (Memory Saver)",
209
+ min_value=2,
210
+ max_value=8,
211
+ value=4,
212
+ help="Higher = Less RAM used, but slightly slower. If crashing, increase this!"
213
+ )
214
+ st.sidebar.info(f"Splitting image into {grid_n*grid_n} pieces.")
215
+ else:
216
+ grid_n = 2
217
+
218
  st.sidebar.header("3. Geometry")
219
  rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
220
 
221
+ # --- Main Logic ---
222
  uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
223
 
224
  if uploaded_file is not None:
225
+ # Load Original
226
+ file_bytes = uploaded_file.getvalue() # Keep raw bytes for caching references
227
+ image = Image.open(io.BytesIO(file_bytes)).convert("RGB")
228
 
229
+ # --- PIPELINE START ---
230
+
231
+ # Step 1: Background Removal (Cached)
232
  if remove_bg:
233
+ with st.spinner("Removing background..."):
234
+ # We pass bytes to the cached function
235
+ processed_image = process_background_removal(file_bytes)
236
+ else:
237
+ processed_image = image
238
+
239
+ # Step 2: Upscaling (Cached manually or via wrapper)
 
 
240
  if upscale_mode != "None":
241
  scale = 4 if "4x" in upscale_mode else 2
242
 
243
+ # Convert current stage to bytes for cache key
244
+ current_stage_bytes = convert_image_to_bytes(processed_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
 
246
+ # Check if we should use the cached wrapper or run with progress bar
247
+ # To preserve the "Progress Bar" experience while still caching, we can:
248
+ # Check if it's already in cache? Streamlit doesn't expose `is_cached`.
249
+ # We will use the cached wrapper. The downside: the first run won't show the detailed tile progress
250
+ # inside the cached function, just the spinner.
251
+
252
+ with st.spinner(f"Upscaling x{scale} ({grid_n*grid_n} tiles)..."):
253
+ processed_image = cached_upscale_wrapper(current_stage_bytes, scale, grid_n)
254
+
255
+ # Step 3: Geometry (Fast - No Caching needed, applied on top)
256
+ # This runs every time you move the slider, but Step 1 & 2 use cache, so it's instant.
257
+ final_image = processed_image.copy()
258
  if rotate_angle != 0:
259
+ final_image = final_image.rotate(rotate_angle, expand=True)
260
 
261
  # --- Display ---
262
  col1, col2 = st.columns(2)
 
267
 
268
  with col2:
269
  st.subheader("Result")
270
+ st.image(final_image, use_container_width=True)
271
+ st.caption(f"Size: {final_image.size}")
272
 
273
  # --- Download ---
274
  st.markdown("---")
275
  st.download_button(
276
  label="💾 Download Result (PNG)",
277
+ data=convert_image_to_bytes(final_image),
278
  file_name="processed_image.png",
279
  mime="image/png"
280
  )