lukeafullard commited on
Commit
17ddc19
·
verified ·
1 Parent(s): 5b003bf

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +57 -43
src/streamlit_app.py CHANGED
@@ -1,7 +1,6 @@
1
  import streamlit as st
2
  from PIL import Image, ImageEnhance
3
  import torch
4
- import torch.nn.functional as F
5
  from torchvision import transforms
6
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
7
  import io
@@ -15,6 +14,7 @@ st.set_page_config(layout="wide", page_title="AI Image Lab")
15
  @st.cache_resource
16
  def load_rembg_model():
17
  """Loads RMBG-1.4 for Background Removal."""
 
18
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
@@ -34,9 +34,42 @@ def load_upscaler(scale=2):
34
 
35
  # --- 2. PROCESSING FUNCTIONS ---
36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
37
  def safe_rembg_inference(model, image, device):
38
  """
39
- Robust inference for RMBG-1.4 that finds the correct mask tensor.
40
  """
41
  w, h = image.size
42
 
@@ -52,34 +85,31 @@ def safe_rembg_inference(model, image, device):
52
  with torch.no_grad():
53
  outputs = model(input_images)
54
 
55
- # --- FIX START ---
56
- result_tensor = None
57
-
58
- # Priority 1: Check for explicit 'logits' attribute (Standard Hugging Face)
59
- if hasattr(outputs, "logits"):
60
- result_tensor = outputs.logits
61
 
62
- # Priority 2: Iterate through list/tuple to find the 1-channel mask
63
- elif isinstance(outputs, (list, tuple)):
64
- for tensor in outputs:
65
- # We are looking for shape [Batch, 1, Height, Width]
66
- if isinstance(tensor, torch.Tensor) and tensor.dim() == 4 and tensor.shape[1] == 1:
67
- result_tensor = tensor
68
- break
69
-
70
- # Fallback: If no 1-channel tensor found, take the first element
71
- if result_tensor is None:
72
  result_tensor = outputs[0]
73
-
74
- # Priority 3: It's already a tensor
75
- else:
76
- result_tensor = outputs
77
- # --- FIX END ---
78
 
79
  # Post-processing
80
- # Squeeze removes batch dim (1, 1, 1024, 1024) -> (1024, 1024)
81
- pred = result_tensor.squeeze().sigmoid().cpu()
 
 
 
 
 
82
 
 
 
 
 
 
83
  # Convert mask to PIL
84
  pred_pil = transforms.ToPILImage()(pred)
85
  mask = pred_pil.resize((w, h))
@@ -89,30 +119,16 @@ def safe_rembg_inference(model, image, device):
89
  return image
90
 
91
  def ai_upscale(image, processor, model):
92
- """
93
- Upscales RGB image using Swin2SR.
94
- Note: Swin2SR only works on RGB. If image is RGBA, we must handle Alpha separately.
95
- """
96
- # 1. Handle Alpha Channel (if exists)
97
  if image.mode == 'RGBA':
98
- # Split RGB and Alpha
99
  r, g, b, a = image.split()
100
  rgb_image = Image.merge('RGB', (r, g, b))
101
-
102
- # Upscale RGB using AI
103
  upscaled_rgb = run_swin_inference(rgb_image, processor, model)
104
-
105
- # Upscale Alpha using standard interpolation (AI models don't predict alpha)
106
- # We resize alpha to match the new RGB size
107
  upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
108
-
109
- # Recombine
110
  return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
111
  else:
112
  return run_swin_inference(image, processor, model)
113
 
114
  def run_swin_inference(image, processor, model):
115
- """Helper to run the actual Swin2SR inference on an RGB image."""
116
  inputs = processor(image, return_tensors="pt")
117
  with torch.no_grad():
118
  outputs = model(**inputs)
@@ -131,7 +147,7 @@ def convert_image_to_bytes(img):
131
 
132
  def main():
133
  st.title("✨ AI Image Lab: Robust Edition")
134
- st.markdown("Features: **RMBG-1.4 (No ONNX)** | **Swin2SR (Upscaling)** | **Geometry**")
135
 
136
  # --- Sidebar ---
137
  st.sidebar.header("1. Background")
@@ -148,11 +164,9 @@ def main():
148
 
149
  if uploaded_file is not None:
150
  image = Image.open(uploaded_file).convert("RGB")
151
-
152
- # Create a working copy
153
  processed_image = image.copy()
154
 
155
- # 1. Remove Background (Do this first so we have the mask)
156
  if remove_bg:
157
  st.info("Loading RMBG Model...")
158
  try:
 
1
  import streamlit as st
2
  from PIL import Image, ImageEnhance
3
  import torch
 
4
  from torchvision import transforms
5
  from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
6
  import io
 
14
  @st.cache_resource
15
  def load_rembg_model():
16
  """Loads RMBG-1.4 for Background Removal."""
17
+ # We use 'briaai/RMBG-1.4'
18
  model = AutoModelForImageSegmentation.from_pretrained("briaai/RMBG-1.4", trust_remote_code=True)
19
  device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
  model.to(device)
 
34
 
35
  # --- 2. PROCESSING FUNCTIONS ---
36
 
37
+ def find_mask_tensor(output):
38
+ """
39
+ Recursively searches any nested structure (list, tuple, dict, object)
40
+ to find the first Tensor that looks like a mask (1 channel).
41
+ """
42
+ # 1. If it's a Tensor, check if it's the mask we want
43
+ if isinstance(output, torch.Tensor):
44
+ # We look for shape [Batch, 1, H, W] or [1, H, W]
45
+ # It must have 1 channel (index 1 for 4D, index 0 for 3D)
46
+ if output.dim() == 4 and output.shape[1] == 1:
47
+ return output
48
+ elif output.dim() == 3 and output.shape[0] == 1:
49
+ return output
50
+ # If it has > 1 channels (e.g. 64), it's a feature map, ignore it.
51
+ return None
52
+
53
+ # 2. If it's a Dict/ModelOutput (like .logits), check values
54
+ if hasattr(output, "items"):
55
+ for val in output.values():
56
+ found = find_mask_tensor(val)
57
+ if found is not None: return found
58
+ # Special case for Hugging Face model outputs with attributes
59
+ elif hasattr(output, "logits"):
60
+ return find_mask_tensor(output.logits)
61
+
62
+ # 3. If it's a List or Tuple, iterate through elements
63
+ elif isinstance(output, (list, tuple)):
64
+ for item in output:
65
+ found = find_mask_tensor(item)
66
+ if found is not None: return found
67
+
68
+ return None
69
+
70
  def safe_rembg_inference(model, image, device):
71
  """
72
+ Robust inference for RMBG-1.4 using Deep Search.
73
  """
74
  w, h = image.size
75
 
 
85
  with torch.no_grad():
86
  outputs = model(input_images)
87
 
88
+ # --- DEEP SEARCH FOR MASK ---
89
+ result_tensor = find_mask_tensor(outputs)
 
 
 
 
90
 
91
+ if result_tensor is None:
92
+ # Fallback: If deep search failed, try just grabbing the first tensor found
93
+ # (Even if dimensions look weird, it's better than crashing)
94
+ if isinstance(outputs, (list, tuple)):
 
 
 
 
 
 
95
  result_tensor = outputs[0]
96
+ else:
97
+ result_tensor = outputs
 
 
 
98
 
99
  # Post-processing
100
+ # Ensure it's a tensor before operations
101
+ if not isinstance(result_tensor, torch.Tensor):
102
+ # If we still have a list here, we take the first element blindly
103
+ if isinstance(result_tensor, (list, tuple)):
104
+ result_tensor = result_tensor[0]
105
+
106
+ pred = result_tensor.squeeze().cpu()
107
 
108
+ # Sometimes output is already sigmoid, sometimes logits.
109
+ # If values are > 1 or < 0, apply sigmoid.
110
+ if pred.max() > 1 or pred.min() < 0:
111
+ pred = pred.sigmoid()
112
+
113
  # Convert mask to PIL
114
  pred_pil = transforms.ToPILImage()(pred)
115
  mask = pred_pil.resize((w, h))
 
119
  return image
120
 
121
  def ai_upscale(image, processor, model):
 
 
 
 
 
122
  if image.mode == 'RGBA':
 
123
  r, g, b, a = image.split()
124
  rgb_image = Image.merge('RGB', (r, g, b))
 
 
125
  upscaled_rgb = run_swin_inference(rgb_image, processor, model)
 
 
 
126
  upscaled_a = a.resize(upscaled_rgb.size, Image.Resampling.LANCZOS)
 
 
127
  return Image.merge('RGBA', (*upscaled_rgb.split(), upscaled_a))
128
  else:
129
  return run_swin_inference(image, processor, model)
130
 
131
  def run_swin_inference(image, processor, model):
 
132
  inputs = processor(image, return_tensors="pt")
133
  with torch.no_grad():
134
  outputs = model(**inputs)
 
147
 
148
  def main():
149
  st.title("✨ AI Image Lab: Robust Edition")
150
+ st.markdown("Features: **RMBG-1.4 (Pure PyTorch)** | **Swin2SR (Upscaling)** | **Geometry**")
151
 
152
  # --- Sidebar ---
153
  st.sidebar.header("1. Background")
 
164
 
165
  if uploaded_file is not None:
166
  image = Image.open(uploaded_file).convert("RGB")
 
 
167
  processed_image = image.copy()
168
 
169
+ # 1. Background
170
  if remove_bg:
171
  st.info("Loading RMBG Model...")
172
  try: