lukeafullard commited on
Commit
f753191
·
verified ·
1 Parent(s): 948b624

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +15 -23
src/streamlit_app.py CHANGED
@@ -7,13 +7,14 @@ import io
7
  import numpy as np
8
  import gc
9
 
10
- # Page Configuration d
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
  # --- 1. MODEL LOADING ---
14
 
15
  @st.cache_resource
16
  def load_rembg_model():
 
17
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
@@ -22,9 +23,12 @@ def load_rembg_model():
22
  @st.cache_resource
23
  def load_upscaler(scale=2):
24
  if scale == 4:
25
- model_id = "caidas/swin2SR-classical-sr-x4-63"
 
26
  else:
 
27
  model_id = "caidas/swin2SR-classical-sr-x2-64"
 
28
  processor = AutoImageProcessor.from_pretrained(model_id)
29
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
30
  return processor, model
@@ -32,6 +36,7 @@ def load_upscaler(scale=2):
32
  # --- 2. PROCESSING LOGIC ---
33
 
34
  def find_mask_tensor(output):
 
35
  if isinstance(output, torch.Tensor):
36
  if output.dim() == 4 and output.shape[1] == 1: return output
37
  elif output.dim() == 3 and output.shape[0] == 1: return output
@@ -64,7 +69,7 @@ def upscale_chunk_logic(image, processor, model):
64
 
65
  @st.cache_data(show_spinner=False)
66
  def process_background_removal(image_bytes):
67
- """Cached background removal."""
68
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
69
  model, device = load_rembg_model()
70
 
@@ -94,17 +99,16 @@ def process_background_removal(image_bytes):
94
 
95
  def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
96
  """
97
- Tiled upscaling with OVERLAP to prevent edge artifacts.
98
  """
99
  processor, model = load_upscaler(scale_factor)
100
  w, h = image.size
101
  rows = cols = grid_n
102
 
103
- # Base tile size (without overlap)
104
  tile_w = w // cols
105
  tile_h = h // rows
106
 
107
- # Overlap buffer (pixels) - lets the AI see context
108
  overlap = 32
109
 
110
  full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
@@ -113,52 +117,40 @@ def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
113
 
114
  for y in range(rows):
115
  for x in range(cols):
116
- # 1. Define the "Target" area (where this tile goes in the original)
117
  target_left = x * tile_w
118
  target_upper = y * tile_h
119
- # Handle edge pixels for the last column/row
120
  target_right = w if x == cols - 1 else (x + 1) * tile_w
121
  target_lower = h if y == rows - 1 else (y + 1) * tile_h
122
-
123
  target_w = target_right - target_left
124
  target_h = target_lower - target_upper
125
 
126
- # 2. Define the "Source" area (Target + Overlap)
127
- # We expand the box outwards by 'overlap' px, but keep it within image bounds
128
  source_left = max(0, target_left - overlap)
129
  source_upper = max(0, target_upper - overlap)
130
  source_right = min(w, target_right + overlap)
131
  source_lower = min(h, target_lower + overlap)
132
 
133
- # Crop the padded tile
134
  tile = image.crop((source_left, source_upper, source_right, source_lower))
135
-
136
- # 3. Upscale the Padded Tile
137
  upscaled_tile = upscale_chunk_logic(tile, processor, model)
138
 
139
- # 4. Crop the "Valid" center from the upscaled tile
140
- # Calculate how much extra we took on the Left and Top (in original scale)
141
  extra_left = target_left - source_left
142
  extra_upper = target_upper - source_upper
143
 
144
- # Convert these offsets to the new Upscaled Scale
145
  crop_x = extra_left * scale_factor
146
  crop_y = extra_upper * scale_factor
147
  crop_w = target_w * scale_factor
148
  crop_h = target_h * scale_factor
149
 
150
- # Perform the final crop to remove the overlap borders
151
  clean_tile = upscaled_tile.crop((crop_x, crop_y, crop_x + crop_w, crop_y + crop_h))
152
 
153
- # 5. Paste the clean tile
154
  paste_x = target_left * scale_factor
155
  paste_y = target_upper * scale_factor
156
  full_image.paste(clean_tile, (paste_x, paste_y))
157
 
158
- # Cleanup
159
  del tile, upscaled_tile, clean_tile
160
  gc.collect()
161
-
162
  count += 1
163
  progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles} (with overlap)...")
164
 
@@ -172,7 +164,7 @@ def convert_image_to_bytes(img):
172
  # --- 3. MAIN APP ---
173
 
174
  def main():
175
- st.title("✨ AI Image Lab: Seamless Edition")
176
  st.markdown("Features: **RMBG-1.4** | **Swin2SR (Seamless Tiling)** | **Progress Bar**")
177
 
178
  # --- Sidebar ---
@@ -207,7 +199,7 @@ def main():
207
  scale = 4 if "4x" in upscale_mode else 2
208
 
209
  # Cache Key
210
- cache_key = f"{uploaded_file.name}_{remove_bg}_{scale}_{grid_n}_overlap"
211
 
212
  if "upscale_cache" not in st.session_state:
213
  st.session_state.upscale_cache = {}
 
7
  import numpy as np
8
  import gc
9
 
10
+ # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
  # --- 1. MODEL LOADING ---
14
 
15
  @st.cache_resource
16
  def load_rembg_model():
17
+ # RMBG-1.4 (Fast & High Quality)
18
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
 
23
  @st.cache_resource
24
  def load_upscaler(scale=2):
25
  if scale == 4:
26
+ # FIXED: Use the 'RealWorld' model for 4x. It exists and handles artifacts better.
27
+ model_id = "caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr"
28
  else:
29
+ # 2x Classical Model
30
  model_id = "caidas/swin2SR-classical-sr-x2-64"
31
+
32
  processor = AutoImageProcessor.from_pretrained(model_id)
33
  model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
34
  return processor, model
 
36
  # --- 2. PROCESSING LOGIC ---
37
 
38
  def find_mask_tensor(output):
39
+ """Recursively finds the mask tensor."""
40
  if isinstance(output, torch.Tensor):
41
  if output.dim() == 4 and output.shape[1] == 1: return output
42
  elif output.dim() == 3 and output.shape[0] == 1: return output
 
69
 
70
  @st.cache_data(show_spinner=False)
71
  def process_background_removal(image_bytes):
72
+ """Cached background removal (RMBG-1.4)."""
73
  image = Image.open(io.BytesIO(image_bytes)).convert("RGB")
74
  model, device = load_rembg_model()
75
 
 
99
 
100
  def process_tiled_upscale(image, scale_factor, grid_n, progress_bar):
101
  """
102
+ Tiled upscaling with OVERLAP to prevent seams.
103
  """
104
  processor, model = load_upscaler(scale_factor)
105
  w, h = image.size
106
  rows = cols = grid_n
107
 
 
108
  tile_w = w // cols
109
  tile_h = h // rows
110
 
111
+ # Overlap buffer (pixels)
112
  overlap = 32
113
 
114
  full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
 
117
 
118
  for y in range(rows):
119
  for x in range(cols):
120
+ # Target Area
121
  target_left = x * tile_w
122
  target_upper = y * tile_h
 
123
  target_right = w if x == cols - 1 else (x + 1) * tile_w
124
  target_lower = h if y == rows - 1 else (y + 1) * tile_h
 
125
  target_w = target_right - target_left
126
  target_h = target_lower - target_upper
127
 
128
+ # Source Area (with overlap)
 
129
  source_left = max(0, target_left - overlap)
130
  source_upper = max(0, target_upper - overlap)
131
  source_right = min(w, target_right + overlap)
132
  source_lower = min(h, target_lower + overlap)
133
 
 
134
  tile = image.crop((source_left, source_upper, source_right, source_lower))
 
 
135
  upscaled_tile = upscale_chunk_logic(tile, processor, model)
136
 
137
+ # Calculate offsets for cropping the valid center
 
138
  extra_left = target_left - source_left
139
  extra_upper = target_upper - source_upper
140
 
 
141
  crop_x = extra_left * scale_factor
142
  crop_y = extra_upper * scale_factor
143
  crop_w = target_w * scale_factor
144
  crop_h = target_h * scale_factor
145
 
 
146
  clean_tile = upscaled_tile.crop((crop_x, crop_y, crop_x + crop_w, crop_y + crop_h))
147
 
 
148
  paste_x = target_left * scale_factor
149
  paste_y = target_upper * scale_factor
150
  full_image.paste(clean_tile, (paste_x, paste_y))
151
 
 
152
  del tile, upscaled_tile, clean_tile
153
  gc.collect()
 
154
  count += 1
155
  progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles} (with overlap)...")
156
 
 
164
  # --- 3. MAIN APP ---
165
 
166
  def main():
167
+ st.title("✨ AI Image Lab: Final Edition")
168
  st.markdown("Features: **RMBG-1.4** | **Swin2SR (Seamless Tiling)** | **Progress Bar**")
169
 
170
  # --- Sidebar ---
 
199
  scale = 4 if "4x" in upscale_mode else 2
200
 
201
  # Cache Key
202
+ cache_key = f"{uploaded_file.name}_{remove_bg}_{scale}_{grid_n}_overlap_v4"
203
 
204
  if "upscale_cache" not in st.session_state:
205
  st.session_state.upscale_cache = {}