# ImageProcessing / src/streamlit_app.py
# lukeafullard's picture
# Update src/streamlit_app.py
# 17ddc19 verified
# raw
# history blame
# 7.54 kB
import streamlit as st
from PIL import Image, ImageEnhance
import torch
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
import io
import numpy as np
# Page Configuration
# Use the wide layout so the side-by-side "Original" / "Result" columns
# rendered by main() get the full browser width; page_title sets the tab title.
st.set_page_config(layout="wide", page_title="AI Image Lab")
# --- 1. MODEL LOADING (Cached) ---
@st.cache_resource
def load_rembg_model():
    """Load the RMBG-1.4 background-removal model and pick its device.

    Wrapped in ``st.cache_resource`` so the model object is shared across
    Streamlit reruns instead of being reloaded every time.

    Returns:
        tuple: ``(model, device)`` where ``device`` is CUDA when available,
        otherwise CPU, and ``model`` has already been moved onto it.
    """
    # Prefer the GPU when one is visible to torch; otherwise stay on the CPU.
    target = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    # NOTE(review): trust_remote_code=True runs Python code shipped inside the
    # model repository. Consider pinning a revision so that code cannot change
    # underneath the app.
    segmenter = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4",
        trust_remote_code=True,
    )
    segmenter.to(target)
    return segmenter, target
@st.cache_resource
def load_upscaler(scale=2):
    """Load the Swin2SR image processor and super-resolution model.

    Args:
        scale: Magnification factor. ``4`` selects the x4 checkpoint; any
            other value selects the x2 checkpoint.

    Returns:
        tuple: ``(processor, model)`` for the selected checkpoint.
    """
    model_id = (
        "caidas/swin2SR-classical-sr-x4-63"
        if scale == 4
        else "caidas/swin2SR-classical-sr-x2-64"
    )
    # Processor first, then weights: both come from the same checkpoint id.
    image_processor = AutoImageProcessor.from_pretrained(model_id)
    sr_model = Swin2SRForImageSuperResolution.from_pretrained(model_id)
    return image_processor, sr_model
# --- 2. PROCESSING FUNCTIONS ---
def find_mask_tensor(output):
    """Depth-first search of a model output for a single-channel tensor.

    A tensor qualifies when it is 4-D with ``shape[1] == 1`` (batch, 1, H, W)
    or 3-D with ``shape[0] == 1`` (1, H, W). Any other tensor (for example a
    multi-channel feature map) is rejected.

    Containers are explored in this order of precedence:
      * objects exposing ``items`` - every entry of ``values()`` is searched;
      * objects exposing ``logits`` - only ``logits`` is searched;
      * lists and tuples - every element is searched.

    Args:
        output: A tensor, or an arbitrarily nested structure of the kinds above.

    Returns:
        The first qualifying tensor encountered, or ``None`` if there is none.
    """
    if isinstance(output, torch.Tensor):
        rank = output.dim()
        single_channel = (rank == 4 and output.shape[1] == 1) or (
            rank == 3 and output.shape[0] == 1
        )
        return output if single_channel else None

    # Work out which children (if any) need to be searched.
    if hasattr(output, "items"):
        children = output.values()
    elif hasattr(output, "logits"):
        return find_mask_tensor(output.logits)
    elif isinstance(output, (list, tuple)):
        children = output
    else:
        return None

    for child in children:
        hit = find_mask_tensor(child)
        if hit is not None:
            return hit
    return None
def safe_rembg_inference(model, image, device):
    """Predict a foreground mask for ``image`` and return it as the alpha channel.

    The image is resized to 1024x1024, normalized and passed through ``model``.
    The model output is searched with ``find_mask_tensor`` for a single-channel
    tensor; that tensor is turned into a grayscale mask, resized back to the
    original image size and attached as the alpha channel.

    Args:
        model: Callable segmentation model taking a ``[1, 3, 1024, 1024]``
            float tensor.
        image: PIL image to process. It is NOT modified; a copy is returned.
        device: Torch device the input batch is moved to before inference.

    Returns:
        A new PIL image: a copy of ``image`` whose alpha channel is the
        predicted mask.

    Raises:
        TypeError: If no tensor can be extracted from the model output.
        ValueError: If the extracted tensor does not reduce to a 2-D mask.
    """
    w, h = image.size

    # Normalize below is configured for exactly three channels, so feed the
    # model an RGB view even if the caller passed RGBA / L / P input.
    model_input = image if image.mode == "RGB" else image.convert("RGB")

    # NOTE(review): these are ImageNet statistics. The RMBG-1.4 reference
    # preprocessing appears to use mean 0.5 / std 1.0 - confirm against the
    # model card before changing, since it alters the predicted mask.
    transform_image = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    input_images = transform_image(model_input).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(input_images)

    result_tensor = find_mask_tensor(outputs)
    if result_tensor is None:
        # Fallback: no single-channel tensor was found, so peel up to two
        # levels of list/tuple nesting and take whatever tensor sits first.
        candidate = outputs
        for _ in range(2):
            if isinstance(candidate, (list, tuple)) and candidate:
                candidate = candidate[0]
        if not isinstance(candidate, torch.Tensor):
            raise TypeError(
                "Could not find a mask tensor in the model output "
                f"(got {type(outputs).__name__})"
            )
        result_tensor = candidate

    pred = result_tensor.squeeze().cpu()
    if pred.dim() != 2:
        # Anything other than H x W cannot be used as an alpha channel; fail
        # here with a clear message instead of deep inside PIL.
        raise ValueError(
            f"Expected a 2-D mask after squeeze, got shape {tuple(pred.shape)}"
        )

    # The model may return logits or probabilities; values outside [0, 1]
    # are treated as logits and squashed with a sigmoid.
    if pred.max() > 1 or pred.min() < 0:
        pred = pred.sigmoid()

    mask = transforms.ToPILImage()(pred).resize((w, h))

    # Work on a copy so the caller's image is left untouched.
    result = image.copy()
    result.putalpha(mask)
    return result
def ai_upscale(image, processor, model):
    """Upscale ``image`` with Swin2SR, carrying transparency along if present.

    For ``RGBA`` input the colour bands go through the model while the alpha
    band is resized with Lanczos resampling to the model's output size, and
    the two are merged back together. Every other mode is handed to the model
    unchanged.

    Args:
        image: PIL image to enlarge.
        processor: Image processor matching ``model``.
        model: Swin2SR super-resolution model.

    Returns:
        The upscaled PIL image (``RGBA`` when the input was ``RGBA``).
    """
    if image.mode != 'RGBA':
        return run_swin_inference(image, processor, model)

    red, green, blue, alpha = image.split()
    colour_only = Image.merge('RGB', (red, green, blue))
    enlarged = run_swin_inference(colour_only, processor, model)

    # The model only handles colour; bring alpha to the same size separately.
    enlarged_alpha = alpha.resize(enlarged.size, Image.Resampling.LANCZOS)
    return Image.merge('RGBA', enlarged.split() + (enlarged_alpha,))
def run_swin_inference(image, processor, model):
    """Run one Swin2SR forward pass and return the result as a PIL image.

    Args:
        image: Input image (a PIL image in this app) accepted by ``processor``.
        processor: Image processor that builds the model inputs.
        model: Super-resolution model whose output exposes ``reconstruction``.

    Returns:
        PIL image built from the reconstruction, clamped to ``[0, 1]`` and
        converted to 8-bit. When the model's upscale factor and the input size
        are known, the result is cropped to exactly ``scale * input size``.
    """
    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Drop only the batch axis: a bare squeeze() would also remove a spatial
    # axis of length 1 and scramble the layout. clamp() (not clamp_) avoids
    # modifying the model output in place.
    pixels = outputs.reconstruction.squeeze(0).float().cpu().clamp(0, 1).numpy()
    pixels = np.moveaxis(pixels, 0, -1)  # channels-first -> channels-last

    # The Swin2SR image processor pads the bottom/right of the input up to a
    # multiple of its pad size, and the reconstruction keeps the upscaled
    # padding. Crop it so the output is exactly scale x the input; slicing is
    # a no-op if the output is already that size or smaller.
    scale = getattr(getattr(model, "config", None), "upscale", None)
    size = getattr(image, "size", None)
    if isinstance(scale, int) and isinstance(size, tuple) and len(size) == 2:
        src_w, src_h = size
        pixels = pixels[: src_h * scale, : src_w * scale]

    pixels = (pixels * 255.0).round().astype(np.uint8)
    return Image.fromarray(pixels)
def convert_image_to_bytes(img):
    """Encode ``img`` as PNG into memory and return the raw bytes.

    Args:
        img: Object with a PIL-style ``save(fp, format=...)`` method.

    Returns:
        bytes: The PNG-encoded payload.
    """
    with io.BytesIO() as stream:
        img.save(stream, format="PNG")
        return stream.getvalue()
# --- 3. MAIN APP ---
def main():
    """Render the Streamlit page and run the selected processing steps.

    The sidebar collects three options (background removal, upscaling factor,
    rotation angle). Once an image is uploaded, the enabled steps are applied
    in that order to a copy of it, the original and the result are shown side
    by side, and a PNG download button is offered for the result. A failure in
    background removal or upscaling is reported with ``st.error`` and the
    remaining steps still run on whatever image is current.
    """
    st.title("✨ AI Image Lab: Robust Edition")
    st.markdown("Features: **RMBG-1.4 (Pure PyTorch)** | **Swin2SR (Upscaling)** | **Geometry**")

    # --- Sidebar controls ---
    sidebar = st.sidebar
    sidebar.header("1. Background")
    remove_bg = sidebar.checkbox("Remove Background", value=False)
    sidebar.header("2. AI Upscaling")
    upscale_mode = sidebar.radio("Magnification", ["None", "2x (Fast)", "4x (Slow)"])
    sidebar.header("3. Geometry")
    rotate_angle = sidebar.slider("Rotate", -180, 180, 0, 1)

    # --- Upload ---
    uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
    if uploaded_file is None:
        # Nothing to process until the user provides an image.
        return

    original = Image.open(uploaded_file).convert("RGB")
    result = original.copy()

    # Step 1: background removal
    if remove_bg:
        st.info("Loading RMBG Model...")
        try:
            bg_model, device = load_rembg_model()
            with st.spinner("Removing background..."):
                result = safe_rembg_inference(bg_model, result, device)
        except Exception as e:
            st.error(f"Background Removal Failed: {e}")

    # Step 2: super-resolution
    if upscale_mode != "None":
        scale = 2
        if "4x" in upscale_mode:
            scale = 4
        st.info(f"Loading Swin2SR x{scale} Model...")
        try:
            processor, upscaler = load_upscaler(scale)
            with st.spinner(f"Upscaling x{scale}..."):
                result = ai_upscale(result, processor, upscaler)
        except Exception as e:
            st.error(f"Upscaling Failed: {e}")

    # Step 3: rotation (expand=True grows the canvas so corners are kept)
    if rotate_angle != 0:
        result = result.rotate(rotate_angle, expand=True)

    # --- Side-by-side display ---
    panels = (("Original", original), ("Result", result))
    for column, (heading, picture) in zip(st.columns(2), panels):
        with column:
            st.subheader(heading)
            st.image(picture, use_container_width=True)
            st.caption(f"Size: {picture.size}")

    # --- Download ---
    st.markdown("---")
    png_payload = convert_image_to_bytes(result)
    st.download_button(
        label="💾 Download Result (PNG)",
        data=png_payload,
        file_name="processed_image.png",
        mime="image/png",
    )
# Build the page only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()