lukeafullard committed on
Commit
a533ab1
·
verified ·
1 Parent(s): d21ed78

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +96 -68
src/streamlit_app.py CHANGED
@@ -1,12 +1,12 @@
1
  import streamlit as st
2
- from PIL import Image, ImageEnhance
3
  import torch
4
  from torchvision import transforms
5
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
7
  import numpy as np
8
 
9
- # Page Configuration g
10
  st.set_page_config(layout="wide", page_title="AI Image Lab")
11
 
12
  # --- 1. MODEL LOADING (Cached) ---
@@ -14,7 +14,6 @@ st.set_page_config(layout="wide", page_title="AI Image Lab")
14
  @st.cache_resource
15
  def load_rembg_model():
16
  """Loads RMBG-1.4 for Background Removal."""
17
- # We use 'briaai/RMBG-1.4'
18
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
@@ -35,45 +34,30 @@ def load_upscaler(scale=2):
35
  # --- 2. PROCESSING FUNCTIONS ---
36
 
37
  def find_mask_tensor(output):
38
- """
39
- Recursively searches any nested structure (list, tuple, dict, object)
40
- to find the first Tensor that looks like a mask (1 channel).
41
- """
42
- # 1. If it's a Tensor, check if it's the mask we want
43
  if isinstance(output, torch.Tensor):
44
- # We look for shape [Batch, 1, H, W] or [1, H, W]
45
- # It must have 1 channel (index 1 for 4D, index 0 for 3D)
46
  if output.dim() == 4 and output.shape[1] == 1:
47
  return output
48
  elif output.dim() == 3 and output.shape[0] == 1:
49
  return output
50
- # If it has > 1 channels (e.g. 64), it's a feature map, ignore it.
51
  return None
52
-
53
- # 2. If it's a Dict/ModelOutput (like .logits), check values
54
- if hasattr(output, "items"):
55
- for val in output.values():
56
- found = find_mask_tensor(val)
57
- if found is not None: return found
58
- # Special case for Hugging Face model outputs with attributes
59
- elif hasattr(output, "logits"):
60
  return find_mask_tensor(output.logits)
61
-
62
- # 3. If it's a List or Tuple, iterate through elements
63
  elif isinstance(output, (list, tuple)):
64
  for item in output:
65
  found = find_mask_tensor(item)
66
  if found is not None: return found
 
 
 
 
67
 
68
  return None
69
 
70
  def safe_rembg_inference(model, image, device):
71
- """
72
- Robust inference for RMBG-1.4 using Deep Search.
73
- """
74
  w, h = image.size
75
-
76
- # Preprocessing
77
  transform_image = transforms.Compose([
78
  transforms.Resize((1024, 1024)),
79
  transforms.ToTensor(),
@@ -81,44 +65,37 @@ def safe_rembg_inference(model, image, device):
81
  ])
82
  input_images = transform_image(image).unsqueeze(0).to(device)
83
 
84
- # Inference
85
  with torch.no_grad():
86
  outputs = model(input_images)
87
 
88
- # --- DEEP SEARCH FOR MASK ---
89
  result_tensor = find_mask_tensor(outputs)
90
-
91
  if result_tensor is None:
92
- # Fallback: If deep search failed, try just grabbing the first tensor found
93
- # (Even if dimensions look weird, it's better than crashing)
94
- if isinstance(outputs, (list, tuple)):
95
- result_tensor = outputs[0]
96
- else:
97
- result_tensor = outputs
98
-
99
- # Post-processing
100
- # Ensure it's a tensor before operations
101
  if not isinstance(result_tensor, torch.Tensor):
102
- # If we still have a list here, we take the first element blindly
103
- if isinstance(result_tensor, (list, tuple)):
104
- result_tensor = result_tensor[0]
105
 
106
  pred = result_tensor.squeeze().cpu()
107
-
108
- # Sometimes output is already sigmoid, sometimes logits.
109
- # If values are > 1 or < 0, apply sigmoid.
110
- if pred.max() > 1 or pred.min() < 0:
111
- pred = pred.sigmoid()
112
 
113
- # Convert mask to PIL
114
  pred_pil = transforms.ToPILImage()(pred)
115
  mask = pred_pil.resize((w, h))
116
-
117
- # Apply mask
118
  image.putalpha(mask)
119
  return image
120
 
121
- def ai_upscale(image, processor, model):
 
 
 
 
 
 
 
 
 
 
 
 
122
  if image.mode == 'RGBA':
123
  r, g, b, a = image.split()
124
  rgb_image = Image.merge('RGB', (r, g, b))
@@ -128,15 +105,46 @@ def ai_upscale(image, processor, model):
128
  else:
129
  return run_swin_inference(image, processor, model)
130
 
131
- def run_swin_inference(image, processor, model):
132
- inputs = processor(image, return_tensors="pt")
133
- with torch.no_grad():
134
- outputs = model(**inputs)
 
 
135
 
136
- output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
137
- output = np.moveaxis(output, 0, -1)
138
- output = (output * 255.0).round().astype(np.uint8)
139
- return Image.fromarray(output)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
140
 
141
  def convert_image_to_bytes(img):
142
  buf = io.BytesIO()
@@ -146,15 +154,15 @@ def convert_image_to_bytes(img):
146
  # --- 3. MAIN APP ---
147
 
148
  def main():
149
- st.title("✨ AI Image Lab: Robust Edition")
150
- st.markdown("Features: **RMBG-1.4 (Pure PyTorch)** | **Swin2SR (Upscaling)** | **Geometry**")
151
 
152
  # --- Sidebar ---
153
  st.sidebar.header("1. Background")
154
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
155
 
156
  st.sidebar.header("2. AI Upscaling")
157
- upscale_mode = st.sidebar.radio("Magnification", ["None", "2x (Fast)", "4x (Slow)"])
158
 
159
  st.sidebar.header("3. Geometry")
160
  rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
@@ -179,13 +187,33 @@ def main():
179
  # 2. Upscaling
180
  if upscale_mode != "None":
181
  scale = 4 if "4x" in upscale_mode else 2
182
- st.info(f"Loading Swin2SR x{scale} Model...")
183
- try:
184
- processor, upscaler = load_upscaler(scale)
185
- with st.spinner(f"Upscaling x{scale}..."):
186
- processed_image = ai_upscale(processed_image, processor, upscaler)
187
- except Exception as e:
188
- st.error(f"Upscaling Failed: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
189
 
190
  # 3. Rotation
191
  if rotate_angle != 0:
 
1
  import streamlit as st
2
+ from PIL import Image
3
  import torch
4
  from torchvision import transforms
5
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
7
  import numpy as np
8
 
9
+ # Page Configuration
10
  st.set_page_config(layout="wide", page_title="AI Image Lab")
11
 
12
  # --- 1. MODEL LOADING (Cached) ---
 
14
  @st.cache_resource
15
  def load_rembg_model():
16
  """Loads RMBG-1.4 for Background Removal."""
 
17
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
18
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
19
  model.to(device)
 
34
  # --- 2. PROCESSING FUNCTIONS ---
35
 
36
  def find_mask_tensor(output):
37
+ """Recursively finds the mask tensor in complex model outputs."""
 
 
 
 
38
  if isinstance(output, torch.Tensor):
 
 
39
  if output.dim() == 4 and output.shape[1] == 1:
40
  return output
41
  elif output.dim() == 3 and output.shape[0] == 1:
42
  return output
 
43
  return None
44
+
45
+ if hasattr(output, "logits"):
 
 
 
 
 
 
46
  return find_mask_tensor(output.logits)
 
 
47
  elif isinstance(output, (list, tuple)):
48
  for item in output:
49
  found = find_mask_tensor(item)
50
  if found is not None: return found
51
+ elif hasattr(output, "items"):
52
+ for val in output.values():
53
+ found = find_mask_tensor(val)
54
+ if found is not None: return found
55
 
56
  return None
57
 
58
  def safe_rembg_inference(model, image, device):
59
+ """Robust background removal inference."""
 
 
60
  w, h = image.size
 
 
61
  transform_image = transforms.Compose([
62
  transforms.Resize((1024, 1024)),
63
  transforms.ToTensor(),
 
65
  ])
66
  input_images = transform_image(image).unsqueeze(0).to(device)
67
 
 
68
  with torch.no_grad():
69
  outputs = model(input_images)
70
 
 
71
  result_tensor = find_mask_tensor(outputs)
 
72
  if result_tensor is None:
73
+ result_tensor = outputs[0] if isinstance(outputs, (list, tuple)) else outputs
74
+
 
 
 
 
 
 
 
75
  if not isinstance(result_tensor, torch.Tensor):
76
+ if isinstance(result_tensor, (list, tuple)): result_tensor = result_tensor[0]
 
 
77
 
78
  pred = result_tensor.squeeze().cpu()
79
+ if pred.max() > 1 or pred.min() < 0: pred = pred.sigmoid()
 
 
 
 
80
 
 
81
  pred_pil = transforms.ToPILImage()(pred)
82
  mask = pred_pil.resize((w, h))
 
 
83
  image.putalpha(mask)
84
  return image
85
 
86
def run_swin_inference(image, processor, model):
    """Run one super-resolution forward pass on a single image or tile.

    Args:
        image: PIL image to upscale. Non-RGB modes are converted to RGB first,
            so any alpha channel must be handled by the caller.
        processor: Callable turning the image into model inputs; invoked as
            ``processor(image, return_tensors="pt")``.
        model: Model whose output exposes a ``reconstruction`` tensor.

    Returns:
        A new RGB ``PIL.Image`` built from the clamped, 8-bit reconstruction.
    """
    # The network consumes 3-channel input; palette/greyscale/CMYK images
    # would otherwise reach it with the wrong channel count.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(image, return_tensors="pt")

    # Keep inputs on the same device as the weights.
    # NOTE(review): the model loader is not visible here; if it leaves the
    # model on CPU this move is a no-op.
    device = getattr(model, "device", None)
    if device is not None:
        inputs = {name: tensor.to(device) for name, tensor in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs)

    # Drop only the batch axis; a bare squeeze() could also collapse a
    # size-1 spatial or channel axis.
    # assumes reconstruction is (batch, channels, H, W) — TODO confirm
    pixels = outputs.reconstruction.data.squeeze(0).float().cpu().clamp_(0, 1).numpy()
    pixels = np.moveaxis(pixels, 0, -1)  # channels-first -> channels-last for PIL
    pixels = (pixels * 255.0).round().astype(np.uint8)
    return Image.fromarray(pixels)
96
+
97
+ def upscale_image_logic(image, processor, model):
98
+ """Handles RGBA vs RGB logic for a single chunk."""
99
  if image.mode == 'RGBA':
100
  r, g, b, a = image.split()
101
  rgb_image = Image.merge('RGB', (r, g, b))
 
105
  else:
106
  return run_swin_inference(image, processor, model)
107
 
108
def tiled_upscale(image, processor, model, scale_factor, progress_bar, rows=2, cols=2):
    """Upscale an image tile by tile, reporting progress after each tile.

    The image is cut into a ``rows`` x ``cols`` grid (2x2 by default). Each
    tile is upscaled with ``upscale_image_logic`` and pasted into a canvas of
    size ``(w * scale_factor, h * scale_factor)``.

    Args:
        image: Source PIL image.
        processor: Forwarded unchanged to ``upscale_image_logic``.
        model: Forwarded unchanged to ``upscale_image_logic``.
        scale_factor: Integer magnification used to size the canvas and to
            place each upscaled tile.
        progress_bar: Object with a ``progress(value, text=...)`` method, or
            ``None`` to skip progress reporting.
        rows: Number of tile rows (default 2).
        cols: Number of tile columns (default 2).

    Returns:
        A new PIL image in the source image's mode containing the stitched
        result.
    """
    w, h = image.size

    # Never request more tiles along an axis than there are pixels, otherwise
    # integer division yields zero-size tiles.
    rows = max(1, min(rows, h))
    cols = max(1, min(cols, w))
    tile_w = w // cols
    tile_h = h // rows

    full_image = Image.new(image.mode, (w * scale_factor, h * scale_factor))
    total_tiles = rows * cols
    count = 0

    for y in range(rows):
        for x in range(cols):
            left = x * tile_w
            upper = y * tile_h
            # The last tile on each axis absorbs the remainder of the division.
            right = w if x == cols - 1 else left + tile_w
            lower = h if y == rows - 1 else upper + tile_h

            tile = image.crop((left, upper, right, lower))
            upscaled_tile = upscale_image_logic(tile, processor, model)

            # NOTE(review): the Swin2SR processor may pad its input (see its
            # do_pad / pad_size options), making the output larger than
            # tile * scale_factor. Trim explicitly rather than relying on
            # later pastes and canvas clipping to hide the overflow.
            expected_w = (right - left) * scale_factor
            expected_h = (lower - upper) * scale_factor
            if upscaled_tile.width > expected_w or upscaled_tile.height > expected_h:
                upscaled_tile = upscaled_tile.crop((
                    0,
                    0,
                    min(upscaled_tile.width, expected_w),
                    min(upscaled_tile.height, expected_h),
                ))

            full_image.paste(upscaled_tile, (left * scale_factor, upper * scale_factor))

            count += 1
            if progress_bar is not None:
                progress_bar.progress(count / total_tiles, text=f"Upscaling Tile {count}/{total_tiles}...")

    return full_image
148
 
149
  def convert_image_to_bytes(img):
150
  buf = io.BytesIO()
 
154
  # --- 3. MAIN APP ---
155
 
156
  def main():
157
+ st.title("✨ AI Image Lab: Tiled Edition")
158
+ st.markdown("Features: **RMBG-1.4** | **Swin2SR (Tiled)** | **Geometry**")
159
 
160
  # --- Sidebar ---
161
  st.sidebar.header("1. Background")
162
  remove_bg = st.sidebar.checkbox("Remove Background", value=False)
163
 
164
  st.sidebar.header("2. AI Upscaling")
165
+ upscale_mode = st.sidebar.radio("Magnification", ["None", "2x (Fast)", "4x (Slow - Tiled)"])
166
 
167
  st.sidebar.header("3. Geometry")
168
  rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
 
187
  # 2. Upscaling
188
  if upscale_mode != "None":
189
  scale = 4 if "4x" in upscale_mode else 2
190
+
191
+ # If 4x, use the Progress Bar + Tiling method
192
+ if scale == 4:
193
+ st.info(f"Loading Swin2SR x{scale} Model...")
194
+ try:
195
+ processor, upscaler = load_upscaler(scale)
196
+
197
+ # Create Progress Bar
198
+ my_bar = st.progress(0, text="Starting Tiled Upscaling...")
199
+
200
+ processed_image = tiled_upscale(processed_image, processor, upscaler, scale, my_bar)
201
+
202
+ # Clear bar when done
203
+ my_bar.empty()
204
+
205
+ except Exception as e:
206
+ st.error(f"Upscaling Failed: {e}")
207
+
208
+ # If 2x, keep it simple (it's fast enough)
209
+ else:
210
+ st.info(f"Loading Swin2SR x{scale} Model...")
211
+ try:
212
+ processor, upscaler = load_upscaler(scale)
213
+ with st.spinner("Upscaling (2x)..."):
214
+ processed_image = upscale_image_logic(processed_image, processor, upscaler)
215
+ except Exception as e:
216
+ st.error(f"Upscaling Failed: {e}")
217
 
218
  # 3. Rotation
219
  if rotate_angle != 0: