lukeafullard commited on
Commit
169cdb3
·
verified ·
1 Parent(s): dbc9d93

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +115 -47
src/streamlit_app.py CHANGED
@@ -3,100 +3,166 @@ from PIL import Image, ImageEnhance
3
  import torch
4
  import torch.nn.functional as F
5
  from torchvision import transforms
6
- from transformers import AutoModelForImageSegmentation
7
  import io
8
  import numpy as np
9
 
10
  # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
- # --- Caching AI Models ---
14
 
15
  @st.cache_resource
16
- def load_birefnet_model():
17
- """
18
- Loads the RMBG-1.4 model for Background Removal (Pure PyTorch).
19
- """
20
- # We use 'briaai/RMBG-1.4' which is SOTA for background removal
21
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
22
-
23
- # Move to GPU if available, otherwise CPU
24
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
25
  model.to(device)
26
  return model, device
27
 
28
- # ... (Previous Upscalers kept for reference, you can re-add them if you wish) ...
29
-
30
- def remove_background_torch(image, model, device):
 
 
 
 
 
 
 
 
 
 
 
 
31
  """
32
- Runs background removal using RMBG-1.4 on PyTorch.
33
  """
34
- # 1. Prepare input
35
  w, h = image.size
36
 
37
- # The model expects specific normalization and size
38
  transform_image = transforms.Compose([
39
  transforms.Resize((1024, 1024)),
40
  transforms.ToTensor(),
41
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
42
  ])
43
-
44
  input_images = transform_image(image).unsqueeze(0).to(device)
45
 
46
- # 2. Inference
47
  with torch.no_grad():
48
- preds = model(input_images)[-1].sigmoid().cpu()
49
 
50
- # 3. Post-process mask
51
- pred = preds[0].squeeze()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
52
 
53
- # Convert mask to PIL and resize back to original dimensions
54
  pred_pil = transforms.ToPILImage()(pred)
55
  mask = pred_pil.resize((w, h))
56
 
57
- # 4. Apply mask to original image
58
  image.putalpha(mask)
59
  return image
60
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  def convert_image_to_bytes(img):
62
  buf = io.BytesIO()
63
  img.save(buf, format="PNG")
64
  return buf.getvalue()
65
 
 
 
66
  def main():
67
- st.title("✨ AI Image Lab: Pure PyTorch Edition")
68
- st.markdown("Processing pipeline: **RMBG-1.4 (No ONNX)**")
69
- st.write("XSRF:", st.get_option("server.enableXsrfProtection"))
70
 
71
- # --- Sidebar Controls ---
72
- st.sidebar.header("Processing Pipeline")
 
73
 
74
- remove_bg = st.sidebar.checkbox("Remove Background (RMBG-1.4)", value=False)
 
75
 
76
- st.sidebar.subheader("Final Adjustments")
77
  rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
78
 
79
- # --- Main Content ---
80
- uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png", "webp"])
81
 
82
  if uploaded_file is not None:
83
- # Important: RMBG model works best if we ensure RGB mode
84
  image = Image.open(uploaded_file).convert("RGB")
 
 
85
  processed_image = image.copy()
86
-
87
- # --- STEP 1: Background Removal ---
88
  if remove_bg:
89
- st.info("Loading RMBG-1.4 Model (First run will download ~170MB)...")
90
  try:
91
- # Load Model
92
- model, device = load_birefnet_model()
93
-
94
- with st.spinner("Removing background using PyTorch..."):
95
- processed_image = remove_background_torch(processed_image, model, device)
96
  except Exception as e:
97
- st.error(f"Error during background removal: {e}")
98
 
99
- # --- STEP 2: Geometry/Color ---
 
 
 
 
 
 
 
 
 
 
 
100
  if rotate_angle != 0:
101
  processed_image = processed_image.rotate(rotate_angle, expand=True)
102
 
@@ -105,18 +171,20 @@ def main():
105
  with col1:
106
  st.subheader("Original")
107
  st.image(image, use_container_width=True)
108
-
 
109
  with col2:
110
  st.subheader("Result")
111
  st.image(processed_image, use_container_width=True)
 
112
 
113
  # --- Download ---
114
  st.markdown("---")
115
- btn = st.download_button(
116
- label="💾 Download Result",
117
  data=convert_image_to_bytes(processed_image),
118
- file_name="nobg_image.png",
119
- mime="image/png",
120
  )
121
 
122
  if __name__ == "__main__":
 
3
  import torch
4
  import torch.nn.functional as F
5
  from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
7
  import io
8
  import numpy as np
9
 
10
  # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
+ # --- 1. MODEL LOADING (Cached) ---
14
 
15
  @st.cache_resource
16
+ def load_rembg_model():
17
+ """Loads RMBG-1.4 for Background Removal."""
 
 
 
18
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
 
 
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
21
  return model, device
22
 
23
+ @st.cache_resource
24
+ def load_upscaler(scale=2):
25
+ """Loads Swin2SR for Super-Resolution (2x or 4x)."""
26
+ if scale == 4:
27
+ model_id = "caidas/swin2SR-classical-sr-x4-63"
28
+ else:
29
+ model_id = "caidas/swin2SR-classical-sr-x2-64"
30
+
31
+ processor = AutoImageProcessor.from_pretrained(model_id)
32
+ model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
33
+ return processor, model
34
+
35
+ # --- 2. PROCESSING FUNCTIONS ---
36
+
37
+ def safe_rembg_inference(model, image, device):
38
  """
39
+ Robust inference for RMBG-1.4 that handles different output formats.
40
  """
 
41
  w, h = image.size
42
 
43
+ # Preprocessing
44
  transform_image = transforms.Compose([
45
  transforms.Resize((1024, 1024)),
46
  transforms.ToTensor(),
47
  transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
48
  ])
 
49
  input_images = transform_image(image).unsqueeze(0).to(device)
50
 
51
+ # Inference
52
  with torch.no_grad():
53
+ outputs = model(input_images)
54
 
55
+ # FIX: Handle List vs Tuple vs Tensor output
56
+ # BiRefNet usually returns a list/tuple of tensors.
57
+ # The output we want is usually the LAST element or the FIRST depending on version.
58
+ # We check if 'outputs' is a sequence (list/tuple) and grab the tensor.
59
+ if isinstance(outputs, (list, tuple)):
60
+ # We assume the last element is the high-res prediction for RMBG-1.4
61
+ result_tensor = outputs[-1]
62
+
63
+ # Double check: if the result is still a list (nested), grab the first item
64
+ if isinstance(result_tensor, (list, tuple)):
65
+ result_tensor = result_tensor[0]
66
+ else:
67
+ result_tensor = outputs
68
+
69
+ # Post-processing
70
+ pred = result_tensor.sigmoid().cpu()[0].squeeze()
71
 
72
+ # Convert mask to PIL
73
  pred_pil = transforms.ToPILImage()(pred)
74
  mask = pred_pil.resize((w, h))
75
 
76
+ # Apply mask
77
  image.putalpha(mask)
78
  return image
79
 
80
+ def ai_upscale(image, processor, model):
81
+ """
82
+ Upscales RGB image using Swin2SR.
83
+ Note: Swin2SR only works on RGB. If image is RGBA, we must handle Alpha separately.
84
+ """
85
+ # 1. Handle Alpha Channel (if exists)
86
+ if image.mode == 'RGBA':
87
+ # Split RGB and Alpha
88
+ r, g, b, a = image.split()
89
+ rgb_image = Image.merge('RGB', (r, g, b))
90
+
91
+ # Upscale RGB using AI
92
+ upscaled_rgb = run_swin_inference(rgb_image, processor, model)
93
+
94
+ # Upscale Alpha using standard interpolation (AI models don't predict alpha)
95
+ # We resize alpha to match the new RGB size
96
+ upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
97
+
98
+ # Recombine
99
+ return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
100
+ else:
101
+ return run_swin_inference(image, processor, model)
102
+
103
+ def run_swin_inference(image, processor, model):
104
+ """Helper to run the actual Swin2SR inference on an RGB image."""
105
+ inputs = processor(image, return_tensors="pt")
106
+ with torch.no_grad():
107
+ outputs = model(**inputs)
108
+
109
+ output = outputs.reconstruction.data.squeeze().float().cpu().clamp_(0, 1).numpy()
110
+ output = np.moveaxis(output, 0, -1)
111
+ output = (output * 255.0).round().astype(np.uint8)
112
+ return Image.fromarray(output)
113
+
114
  def convert_image_to_bytes(img):
115
  buf = io.BytesIO()
116
  img.save(buf, format="PNG")
117
  return buf.getvalue()
118
 
119
+ # --- 3. MAIN APP ---
120
+
121
  def main():
122
+ st.title("✨ AI Image Lab: Robust Edition")
123
+ st.markdown("Features: **RMBG-1.4 (No ONNX)** | **Swin2SR (Upscaling)** | **Geometry**")
 
124
 
125
+ # --- Sidebar ---
126
+ st.sidebar.header("1. Background")
127
+ remove_bg = st.sidebar.checkbox("Remove Background", value=False)
128
 
129
+ st.sidebar.header("2. AI Upscaling")
130
+ upscale_mode = st.sidebar.radio("Magnification", ["None", "2x (Fast)", "4x (Slow)"])
131
 
132
+ st.sidebar.header("3. Geometry")
133
  rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)
134
 
135
+ # --- Main ---
136
+ uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
137
 
138
  if uploaded_file is not None:
 
139
  image = Image.open(uploaded_file).convert("RGB")
140
+
141
+ # Create a working copy
142
  processed_image = image.copy()
143
+
144
+ # 1. Remove Background (Do this first so we have the mask)
145
  if remove_bg:
146
+ st.info("Loading RMBG Model...")
147
  try:
148
+ bg_model, device = load_rembg_model()
149
+ with st.spinner("Removing background..."):
150
+ processed_image = safe_rembg_inference(bg_model, processed_image, device)
 
 
151
  except Exception as e:
152
+ st.error(f"Background Removal Failed: {e}")
153
 
154
+ # 2. Upscaling
155
+ if upscale_mode != "None":
156
+ scale = 4 if "4x" in upscale_mode else 2
157
+ st.info(f"Loading Swin2SR x{scale} Model...")
158
+ try:
159
+ processor, upscaler = load_upscaler(scale)
160
+ with st.spinner(f"Upscaling x{scale}..."):
161
+ processed_image = ai_upscale(processed_image, processor, upscaler)
162
+ except Exception as e:
163
+ st.error(f"Upscaling Failed: {e}")
164
+
165
+ # 3. Rotation
166
  if rotate_angle != 0:
167
  processed_image = processed_image.rotate(rotate_angle, expand=True)
168
 
 
171
  with col1:
172
  st.subheader("Original")
173
  st.image(image, use_container_width=True)
174
+ st.caption(f"Size: {image.size}")
175
+
176
  with col2:
177
  st.subheader("Result")
178
  st.image(processed_image, use_container_width=True)
179
+ st.caption(f"Size: {processed_image.size}")
180
 
181
  # --- Download ---
182
  st.markdown("---")
183
+ st.download_button(
184
+ label="💾 Download Result (PNG)",
185
  data=convert_image_to_bytes(processed_image),
186
+ file_name="processed_image.png",
187
+ mime="image/png"
188
  )
189
 
190
  if __name__ == "__main__":