lukeafullard committed on
Commit
dbc4275
·
verified ·
1 Parent(s): 7670112

Update src/streamlit_app.py

Browse files
Files changed (1) hide show
  1. src/streamlit_app.py +59 -88
src/streamlit_app.py CHANGED
@@ -1,50 +1,62 @@
1
  import streamlit as st
2
  from PIL import Image, ImageEnhance
3
- from rembg import remove
4
- import io
5
  import torch
 
 
 
 
6
  import numpy as np
7
- from transformers import AutoImageProcessor, Swin2SRForImageSuperResolution, pipeline
8
 
9
  # Page Configuration
10
  st.set_page_config(layout="wide", page_title="AI Image Lab")
11
 
12
  # --- Caching AI Models ---
13
- # We use separate functions for 2x and 4x to avoid loading both into memory if not needed.
14
 
15
@st.cache_resource
def load_upscaler_x2():
    """Load the Swin2SR 2x super-resolution checkpoint.

    Decorated with ``st.cache_resource`` so the pair is built once and reused.

    Returns:
        tuple: ``(processor, model)`` for ``caidas/swin2SR-classical-sr-x2-64``.
    """
    checkpoint = "caidas/swin2SR-classical-sr-x2-64"
    # Tuple elements are evaluated left to right: processor first, then model.
    return (
        AutoImageProcessor.from_pretrained(checkpoint),
        Swin2SRForImageSuperResolution.from_pretrained(checkpoint),
    )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
@st.cache_resource
def load_upscaler_x4():
    """Load the Swin2SR 4x super-resolution checkpoint.

    This variant is the heavier of the two upscalers and takes longer to run.
    Decorated with ``st.cache_resource`` so the pair is built once and reused.

    Returns:
        tuple: ``(processor, model)`` for ``caidas/swin2SR-classical-sr-x4-63``.
    """
    checkpoint = "caidas/swin2SR-classical-sr-x4-63"
    # Tuple elements are evaluated left to right: processor first, then model.
    return (
        AutoImageProcessor.from_pretrained(checkpoint),
        Swin2SRForImageSuperResolution.from_pretrained(checkpoint),
    )
31
 
32
@st.cache_resource
def load_depth_pipeline():
    """Build the depth-estimation pipeline backed by ``vinvino02/glpn-nyu``.

    Decorated with ``st.cache_resource`` so the pipeline is built once and
    reused.

    Returns:
        The object produced by ``transformers.pipeline`` for the
        ``depth-estimation`` task.
    """
    return pipeline(task="depth-estimation", model="vinvino02/glpn-nyu")
37
-
38
def ai_upscale(image, processor, model):
    """Upscale ``image`` with a super-resolution model and return a PIL image.

    Args:
        image: Input image handed to ``processor``.
        processor: Callable that turns the image into PyTorch model inputs.
        model: Callable model whose output exposes a ``reconstruction`` tensor.

    Returns:
        PIL.Image.Image: The reconstruction converted to 8-bit pixels.
    """
    model_inputs = processor(image, return_tensors="pt")

    # Inference only: gradients are not needed.
    with torch.no_grad():
        result = model(**model_inputs)

    # Drop singleton dimensions, force float on CPU and clip into [0, 1].
    clipped = result.reconstruction.data.squeeze().float().cpu().clamp_(0, 1)

    # Move the leading axis to the end, as Image.fromarray wants channels last.
    pixels = np.moveaxis(clipped.numpy(), 0, -1)

    # Scale to the 0..255 range and quantise to bytes.
    pixels = (pixels * 255.0).round().astype(np.uint8)
    return Image.fromarray(pixels)
48
 
49
  def convert_image_to_bytes(img):
50
  buf = io.BytesIO()
@@ -52,98 +64,57 @@ def convert_image_to_bytes(img):
52
  return buf.getvalue()
53
 
54
def main():
    """Render the "Transformers Edition" page of the AI Image Lab app.

    The sidebar collects the options (background removal, one AI modifier,
    rotation and contrast). Once an image is uploaded, the selected steps are
    applied in that order, the original and the result are shown side by side,
    and the result is offered as a download.
    """
    st.title("✨ AI Image Lab: Transformers Edition")
    st.markdown("Processing pipeline: **Background Removal** **AI Modifiers** → **Geometry**")

    # --- Sidebar Controls ---
    sidebar = st.sidebar
    sidebar.header("Processing Pipeline")

    # 1. Background
    sidebar.subheader("1. Cleanup")
    remove_bg = sidebar.checkbox("Remove Background (rembg)", value=False)

    # 2. AI Enhancements
    sidebar.subheader("2. AI Enhancements")
    ai_mode = sidebar.radio(
        "Choose AI Modification:",
        ["None", "AI Super-Resolution (2x)", "AI Super-Resolution (4x)", "Depth Estimation"],
    )

    # 3. Geometry & Color
    sidebar.subheader("3. Final Adjustments")
    rotate_angle = sidebar.slider("Rotate", -180, 180, 0, 1)
    contrast_val = sidebar.slider("Contrast", 0.5, 1.5, 1.0, 0.1)

    # Per-mode settings for the two upscalers:
    # (notice function, notice text, model loader, spinner text).
    # 4x gets a warning because it can be quite slow on CPU for large images.
    upscale_modes = {
        "AI Super-Resolution (2x)": (
            st.info,
            "Loading Swin2SR (2x) model... (Fast)",
            load_upscaler_x2,
            "Upscaling (2x)...",
        ),
        "AI Super-Resolution (4x)": (
            st.warning,
            "Loading Swin2SR (4x) model... (This is computationally heavy!)",
            load_upscaler_x4,
            "Upscaling (4x)... please wait",
        ),
    }

    # --- Main Content ---
    uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png", "webp"])

    if uploaded_file is not None:
        image = Image.open(uploaded_file).convert("RGB")
        processed_image = image.copy()

        # --- STEP 1: Background Removal ---
        if remove_bg:
            with st.spinner("Removing background..."):
                processed_image = remove(processed_image)

        # --- STEP 2: AI Enhancements ---
        if ai_mode in upscale_modes:
            announce, notice, loader, busy_text = upscale_modes[ai_mode]
            announce(notice)
            try:
                processor, model = loader()
                with st.spinner(busy_text):
                    processed_image = ai_upscale(processed_image, processor, model)
            except Exception as e:
                # UI boundary: show the failure instead of crashing the page.
                st.error(f"Error loading Upscaler: {e}")
        elif ai_mode == "Depth Estimation":
            st.info("Generating Depth Map...")
            try:
                depth_pipe = load_depth_pipeline()
                with st.spinner("Estimating depth..."):
                    processed_image = depth_pipe(processed_image)["depth"]
            except Exception as e:
                st.error(f"Error loading Depth Model: {e}")

        # --- STEP 3: Geometry/Color ---
        # Rotation
        if rotate_angle != 0:
            processed_image = processed_image.rotate(rotate_angle, expand=True)

        # Contrast
        if contrast_val != 1.0:
            processed_image = ImageEnhance.Contrast(processed_image).enhance(contrast_val)

        # --- Display ---
        left, right = st.columns(2)
        with left:
            st.subheader("Original")
            st.image(image, use_container_width=True)
            st.caption(f"Size: {image.size}")

        with right:
            st.subheader("Result")
            st.image(processed_image, use_container_width=True)
            st.caption(f"Size: {processed_image.size}")

        # --- Download ---
        st.markdown("---")
        st.download_button(
            label="💾 Download Result",
            data=convert_image_to_bytes(processed_image),
            file_name="ai_enhanced_image.png",
            mime="image/png",
        )
149
 
 
1
  import streamlit as st
2
  from PIL import Image, ImageEnhance
 
 
3
  import torch
4
+ import torch.nn.functional as F
5
+ from torchvision import transforms
6
+ from transformers import AutoModelForImageSegmentation
7
+ import io
8
  import numpy as np
 
9
 
10
  # Page Configuration
11
  st.set_page_config(layout="wide", page_title="AI Image Lab")
12
 
13
  # --- Caching AI Models ---
 
14
 
15
@st.cache_resource
def load_birefnet_model():
    """Load the background-removal model and place it on the best device.

    Despite the function name, the checkpoint loaded here is
    ``briaai/RMBG-1.4``. Decorated with ``st.cache_resource`` so the model is
    built once and reused.

    Returns:
        tuple: ``(model, device)`` where ``device`` is CUDA when available and
        CPU otherwise.
    """
    # Prefer the GPU when one is present, otherwise stay on the CPU.
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # trust_remote_code=True lets the checkpoint repository supply the model code.
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4", trust_remote_code=True
    )
    segmenter.to(target)
    return segmenter, target
27
+
28
+ # NOTE: The Swin2SR upscalers and the depth-estimation pipeline from the previous version were removed in this revision; re-add them here if needed.
29
+
30
def _extract_mask_batch(outputs):
    """Pick the foreground-probability batch out of the model's raw output.

    Args:
        outputs: Value returned by the segmentation model for one batch; it
            is indexed like a sequence.

    Returns:
        torch.Tensor: CPU tensor of mask values in ``[0, 1]`` for the batch.
    """
    candidate = outputs[-1]
    if torch.is_tensor(candidate):
        # The last element is a tensor of logits: squash it to probabilities.
        return candidate.sigmoid().cpu()

    # NOTE(review): the RMBG-1.4 remote code appears to return nested lists
    # whose first entry holds already-activated maps, finest first. Confirm
    # against the model card. Calling ``.sigmoid()`` on a list would raise
    # AttributeError, so descend to the first tensor instead.
    candidate = outputs[0]
    while not torch.is_tensor(candidate):
        candidate = candidate[0]
    # Clamp so the later float-to-byte conversion cannot wrap around.
    return candidate.clamp(0, 1).cpu()


def remove_background_torch(image, model, device):
    """Cut the background out of ``image`` using a segmentation model.

    The input image is left untouched; a copy with the predicted mask as its
    alpha channel is returned.

    Args:
        image: PIL image to process.
        model: Segmentation model, already placed on ``device``.
        device: Torch device the input batch is moved to.

    Returns:
        PIL.Image.Image: Copy of ``image`` whose alpha channel is the
        predicted foreground mask, at the original dimensions.
    """
    # 1. Prepare input
    w, h = image.size

    # Normalize below uses three channels, so make sure the input is RGB.
    rgb_input = image if image.mode == "RGB" else image.convert("RGB")

    # NOTE(review): ImageNet statistics are kept from the original code.
    # Confirm them against the RMBG-1.4 model card, which may prescribe a
    # different normalisation.
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    input_images = transform_image(rgb_input).unsqueeze(0).to(device)

    # 2. Inference (no gradients needed)
    with torch.no_grad():
        preds = _extract_mask_batch(model(input_images))

    # 3. Post-process mask: first item of the batch, singleton dims dropped
    pred = preds[0].squeeze()

    # Convert mask to PIL and resize back to original dimensions
    mask = transforms.ToPILImage()(pred).resize((w, h))

    # 4. Apply mask to a copy so the caller's image is not mutated
    result = image.copy()
    result.putalpha(mask)
    return result
 
60
 
61
  def convert_image_to_bytes(img):
62
  buf = io.BytesIO()
 
64
  return buf.getvalue()
65
 
66
def main():
    """Render the "Pure PyTorch Edition" page of the AI Image Lab app.

    The sidebar offers background removal and a rotation slider. Once an
    image is uploaded, the selected steps are applied, the original and the
    result are shown side by side, and the result is offered as a download.
    """
    st.title("✨ AI Image Lab: Pure PyTorch Edition")
    st.markdown("Processing pipeline: **RMBG-1.4 (No ONNX)**")

    # --- Sidebar Controls ---
    sidebar = st.sidebar
    sidebar.header("Processing Pipeline")

    remove_bg = sidebar.checkbox("Remove Background (RMBG-1.4)", value=False)

    sidebar.subheader("Final Adjustments")
    rotate_angle = sidebar.slider("Rotate", -180, 180, 0, 1)

    # --- Main Content ---
    uploaded_file = st.file_uploader("Upload an image...", type=["jpg", "jpeg", "png", "webp"])

    if uploaded_file is not None:
        # The upload is normalised to RGB before any processing.
        image = Image.open(uploaded_file).convert("RGB")
        processed_image = image.copy()

        # --- STEP 1: Background Removal ---
        if remove_bg:
            st.info("Loading RMBG-1.4 Model (First run will download ~170MB)...")
            try:
                # Cached loader: the model is only built once.
                model, device = load_birefnet_model()

                with st.spinner("Removing background using PyTorch..."):
                    processed_image = remove_background_torch(processed_image, model, device)
            except Exception as e:
                # UI boundary: show the failure instead of crashing the page.
                st.error(f"Error during background removal: {e}")

        # --- STEP 2: Geometry/Color ---
        if rotate_angle != 0:
            processed_image = processed_image.rotate(rotate_angle, expand=True)

        # --- Display ---
        left, right = st.columns(2)
        with left:
            st.subheader("Original")
            st.image(image, use_container_width=True)

        with right:
            st.subheader("Result")
            st.image(processed_image, use_container_width=True)

        # --- Download ---
        st.markdown("---")
        st.download_button(
            label="💾 Download Result",
            data=convert_image_to_bytes(processed_image),
            file_name="nobg_image.png",
            mime="image/png",
        )
  )
120