ImageProcessing / src /streamlit_app.py
lukeafullard's picture
Update src/streamlit_app.py
169cdb3 verified
raw
history blame
6.66 kB
import streamlit as st
from PIL import Image, ImageEnhance
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import AutoModelForImageSegmentation, AutoImageProcessor, Swin2SRForImageSuperResolution
import io
import numpy as np
# Global Streamlit page setup: wide layout and the app's page title.
st.set_page_config(page_title="AI Image Lab", layout="wide")
# --- 1. MODEL LOADING (Cached) ---
@st.cache_resource
def load_rembg_model():
    """Load the RMBG-1.4 segmentation model for background removal.

    Returns:
        A ``(model, device)`` tuple. The model has been moved to ``device``,
        which is CUDA when ``torch.cuda.is_available()`` and the CPU otherwise.
    """
    device_name = "cuda" if torch.cuda.is_available() else "cpu"
    target_device = torch.device(device_name)

    segmenter = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4",
        trust_remote_code=True,
    )
    segmenter.to(target_device)
    return segmenter, target_device
@st.cache_resource
def load_upscaler(scale=2):
    """Load the Swin2SR processor and model for super-resolution.

    Args:
        scale: Magnification factor. ``4`` selects the x4 checkpoint; any
            other value selects the x2 checkpoint.

    Returns:
        A ``(processor, model)`` tuple loaded from the same checkpoint.
    """
    checkpoint = (
        "caidas/swin2SR-classical-sr-x4-63"
        if scale == 4
        else "caidas/swin2SR-classical-sr-x2-64"
    )
    image_processor = AutoImageProcessor.from_pretrained(checkpoint)
    sr_model = Swin2SRForImageSuperResolution.from_pretrained(checkpoint)
    return image_processor, sr_model
# --- 2. PROCESSING FUNCTIONS ---
def safe_rembg_inference(model, image, device):
    """
    Remove the background of ``image`` with RMBG-1.4.

    Args:
        model: Segmentation model, e.g. the one returned by
            ``load_rembg_model``.
        image: PIL image to process. It is left unmodified.
        device: Torch device the model lives on; the input batch is moved
            there before inference.

    Returns:
        A new RGBA PIL image with the same size as ``image`` whose alpha
        channel is the predicted foreground mask.
    """
    # Work on an RGB copy: the three-channel normalisation below requires
    # RGB input, and putalpha() must not mutate the caller's image.
    result = image.convert("RGB")
    width, height = result.size

    # Preprocessing follows the RMBG-1.4 model card: resize to 1024x1024,
    # scale to [0, 1], then normalise with mean 0.5 and std 1.0.
    preprocess = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [1.0, 1.0, 1.0]),
    ])
    batch = preprocess(result).unsqueeze(0).to(device)

    with torch.no_grad():
        outputs = model(batch)

    is_sequence = isinstance(outputs, (list, tuple))
    if is_sequence and len(outputs) > 0 and isinstance(outputs[0], (list, tuple)):
        # RMBG-1.4 layout: (list of sigmoid-activated masks, list of feature
        # maps). The mask is the first entry of the FIRST list; the last
        # list holds multi-channel features, not a mask.
        prediction = outputs[0][0]
        already_activated = True
    else:
        # Fallback for models that return raw logits, either as a single
        # tensor or as a flat sequence whose last element is the final map.
        prediction = outputs[-1] if is_sequence else outputs
        if isinstance(prediction, (list, tuple)):
            prediction = prediction[0]
        already_activated = False

    prediction = prediction.detach().float().cpu()
    if not already_activated:
        prediction = prediction.sigmoid()
    prediction = prediction[0].squeeze()

    if already_activated:
        # Min-max stretch as in the model card's post-processing; skipped
        # for a constant map to avoid dividing by zero.
        low, high = prediction.min(), prediction.max()
        if high > low:
            prediction = (prediction - low) / (high - low)

    # Convert the mask to PIL and bring it back to the original resolution.
    mask = transforms.ToPILImage()(prediction).resize((width, height))

    result.putalpha(mask)
    return result
def ai_upscale(image, processor, model):
    """
    Upscale a PIL image with Swin2SR, keeping transparency when present.

    Swin2SR is run on RGB data only. For an RGBA input the colour bands are
    upscaled by the model and the alpha band is resized with Lanczos
    interpolation to the same output size before the bands are merged again.
    Images in any other mode are passed to the model unchanged.
    """
    # Guard clause: nothing special to do without an alpha channel.
    if image.mode != 'RGBA':
        return run_swin_inference(image, processor, model)

    bands = image.split()
    colour_only = Image.merge('RGB', bands[:3])
    alpha_band = bands[3]

    # The model upscales colour; alpha follows by plain interpolation.
    colour_upscaled = run_swin_inference(colour_only, processor, model)
    alpha_upscaled = alpha_band.resize(colour_upscaled.size, Image.Resampling.LANCZOS)

    return Image.merge('RGBA', (*colour_upscaled.split(), alpha_upscaled))
def run_swin_inference(image, processor, model):
    """Run Swin2SR on an RGB PIL image and return the upscaled PIL image.

    Args:
        image: RGB PIL image to upscale.
        processor: Image processor matching ``model``.
        model: Swin2SR super-resolution model.

    Returns:
        An RGB PIL image of size ``(width * scale, height * scale)``.
    """
    src_width, src_height = image.size

    inputs = processor(image, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)

    # Take the single batch item explicitly; squeeze() would also drop a
    # spatial axis of length one. Result layout is (channels, height, width).
    upscaled = outputs.reconstruction[0].float().cpu().clamp(0, 1).numpy()

    # The Swin2SR image processor pads the bottom/right edges of the input
    # to a multiple of its window size, so the reconstruction is larger than
    # scale * original and ends in mirrored border pixels. Crop it back.
    scale = getattr(getattr(model, "config", None), "upscale", None)
    if not scale:
        # Fall back to the ratio between output and (padded) input width.
        scale = upscaled.shape[-1] // inputs["pixel_values"].shape[-1]
    upscaled = upscaled[:, : src_height * scale, : src_width * scale]

    # Channels-last uint8 array for PIL.
    upscaled = np.moveaxis(upscaled, 0, -1)
    upscaled = (upscaled * 255.0).round().astype(np.uint8)
    return Image.fromarray(np.ascontiguousarray(upscaled))
def convert_image_to_bytes(img):
    """Encode ``img`` as PNG in memory and return the resulting bytes."""
    with io.BytesIO() as png_buffer:
        img.save(png_buffer, format="PNG")
        return png_buffer.getvalue()
# --- 3. MAIN APP ---
def main():
    """Render the Streamlit UI and run the selected processing steps.

    The sidebar selects background removal, AI upscaling and rotation. Once
    an image is uploaded the enabled steps run in that order, the original
    and the result are shown side by side, and the result is offered as a
    PNG download.
    """
    # Imported locally so the fix below does not depend on the file-level
    # import list.
    from PIL import ImageOps

    st.title("✨ AI Image Lab: Robust Edition")
    st.markdown("Features: **RMBG-1.4 (No ONNX)** | **Swin2SR (Upscaling)** | **Geometry**")

    # --- Sidebar ---
    st.sidebar.header("1. Background")
    remove_bg = st.sidebar.checkbox("Remove Background", value=False)

    st.sidebar.header("2. AI Upscaling")
    upscale_mode = st.sidebar.radio("Magnification", ["None", "2x (Fast)", "4x (Slow)"])

    st.sidebar.header("3. Geometry")
    rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)

    # --- Main ---
    uploaded_file = st.file_uploader("Upload Image", type=["png", "jpg", "jpeg", "webp"])
    if uploaded_file is None:
        return

    try:
        # Honour the EXIF orientation tag so camera/phone photos are shown
        # upright, then normalise to RGB for the models.
        image = ImageOps.exif_transpose(Image.open(uploaded_file)).convert("RGB")
    except (OSError, ValueError) as e:
        # Corrupt or unsupported uploads get a UI error, not a traceback.
        st.error(f"Could not read the uploaded image: {e}")
        return

    # Create a working copy so the original stays available for display.
    processed_image = image.copy()

    # 1. Remove Background (done first so later steps carry the alpha mask)
    if remove_bg:
        st.info("Loading RMBG Model...")
        try:
            bg_model, device = load_rembg_model()
            with st.spinner("Removing background..."):
                processed_image = safe_rembg_inference(bg_model, processed_image, device)
        except Exception as e:
            # UI boundary: report the failure and continue with the
            # unprocessed image.
            st.error(f"Background Removal Failed: {e}")

    # 2. Upscaling
    if upscale_mode != "None":
        scale = 4 if "4x" in upscale_mode else 2
        st.info(f"Loading Swin2SR x{scale} Model...")
        try:
            processor, upscaler = load_upscaler(scale)
            with st.spinner(f"Upscaling x{scale}..."):
                processed_image = ai_upscale(processed_image, processor, upscaler)
        except Exception as e:
            st.error(f"Upscaling Failed: {e}")

    # 3. Rotation (expand=True grows the canvas so corners are not cropped)
    if rotate_angle != 0:
        processed_image = processed_image.rotate(rotate_angle, expand=True)

    # --- Display ---
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original")
        st.image(image, use_container_width=True)
        st.caption(f"Size: {image.size}")
    with col2:
        st.subheader("Result")
        st.image(processed_image, use_container_width=True)
        st.caption(f"Size: {processed_image.size}")

    # --- Download ---
    st.markdown("---")
    st.download_button(
        label="💾 Download Result (PNG)",
        data=convert_image_to_bytes(processed_image),
        file_name="processed_image.png",
        mime="image/png",
    )
# Entry point: build the app only when this file is run as a script.
if __name__ == "__main__":
    main()