# Source: ImageProcessing / src/streamlit_app.py (Hugging Face file view)
# Last commit: "Update src/streamlit_app.py" (dbc4275, verified), by lukeafullard
# File size: 3.91 kB
import streamlit as st
from PIL import Image, ImageEnhance
import torch
import torch.nn.functional as F
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
import io
import numpy as np
# Page setup: wide layout and browser-tab title.
# Kept at module top: Streamlit's docs ask for set_page_config to run before
# other Streamlit commands.
st.set_page_config(
    page_title="AI Image Lab",
    layout="wide",
)
# --- Caching AI Models ---
@st.cache_resource
def load_birefnet_model():
    """Load the RMBG-1.4 background-removal model and pick its device.

    Despite the function name, the checkpoint loaded here is
    ``briaai/RMBG-1.4``; the name is kept so existing callers keep working.
    ``st.cache_resource`` caches the returned objects, so Streamlit reruns
    reuse the already-loaded model instead of loading it again.

    Returns:
        tuple: ``(model, device)``. ``model`` has been moved to ``device``,
        which is CUDA when available and CPU otherwise.
    """
    # NOTE(security): trust_remote_code=True executes Python code shipped in
    # the model repository. Consider pinning a reviewed `revision=` before
    # deploying this somewhere sensitive.
    model = AutoModelForImageSegmentation.from_pretrained(
        "briaai/RMBG-1.4",
        trust_remote_code=True,
    )

    # Move to GPU if available, otherwise stay on CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # The app only runs inference, so make evaluation mode explicit rather
    # than relying on the loader's default.
    model.eval()
    return model, device
# ... (Previous Upscalers kept for reference, you can re-add them if you wish) ...
def _foreground_probabilities(outputs):
    """Return the foreground probability map from a raw model output."""
    final = outputs[-1]
    if torch.is_tensor(final):
        # The last entry holds logits: the layout the original code assumed.
        return final.sigmoid()
    # The last entry is itself a sequence, so calling `.sigmoid()` on it
    # would raise AttributeError.
    # NOTE(review): the published RMBG-1.4 modelling code appears to return
    # ([sigmoid maps...], [feature maps...]); in that layout outputs[0][0] is
    # the already-activated full-resolution map. Confirm against the checkpoint.
    return outputs[0][0]


def remove_background_torch(image, model, device):
    """Cut out the foreground of ``image`` using a segmentation model.

    Args:
        image: PIL image to process. It is not modified.
        model: Callable segmentation model taking a ``(1, 3, 1024, 1024)``
            float tensor.
        device: Torch device the input tensor is moved to before inference.

    Returns:
        A copy of ``image`` whose alpha channel is the predicted foreground
        mask, resized back to the original dimensions.
    """
    # 1. Prepare input
    width, height = image.size

    # NOTE(review): these are the ImageNet statistics; the RMBG-1.4 model card
    # preprocesses with mean 0.5 / std 1.0. Confirm which one the checkpoint
    # expects before changing them.
    preprocess = transforms.Compose([
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])

    # Normalize is given three channel statistics, so always feed the model an
    # RGB view; RGBA / L / P inputs would otherwise fail with a shape error.
    model_input = preprocess(image.convert("RGB")).unsqueeze(0).to(device)

    # 2. Inference
    with torch.no_grad():
        probabilities = _foreground_probabilities(model(model_input)).cpu()

    # 3. Post-process: first batch item, drop singleton dims, back to PIL at
    #    the original resolution.
    mask_tensor = probabilities[0].squeeze()
    mask = transforms.ToPILImage()(mask_tensor).resize((width, height))

    # 4. Apply the mask to a copy so the caller's image is left untouched.
    result = image.copy()
    result.putalpha(mask)
    return result
def convert_image_to_bytes(img):
    """Serialize an image to PNG and return the encoded bytes.

    Args:
        img: Object exposing a PIL-style ``save(fp, format=...)`` method.

    Returns:
        bytes: The PNG-encoded image data.
    """
    # The buffer is only needed until its contents have been copied out.
    with io.BytesIO() as buffer:
        img.save(buffer, format="PNG")
        return buffer.getvalue()
def main():
    """Render the app page.

    Flow: sidebar controls, image upload, optional background removal,
    optional rotation, side-by-side preview, and a PNG download button.
    """
    st.title("✨ AI Image Lab: Pure PyTorch Edition")
    st.markdown("Processing pipeline: **RMBG-1.4 (No ONNX)**")

    # --- Sidebar Controls ---
    st.sidebar.header("Processing Pipeline")
    remove_bg = st.sidebar.checkbox("Remove Background (RMBG-1.4)", value=False)

    st.sidebar.subheader("Final Adjustments")
    rotate_angle = st.sidebar.slider("Rotate", -180, 180, 0, 1)

    # --- Main Content ---
    uploaded_file = st.file_uploader(
        "Upload an image...", type=["jpg", "jpeg", "png", "webp"]
    )
    if uploaded_file is None:
        # Nothing to process until the user uploads a file.
        return

    # The pipeline works on RGB; convert() also forces the decode, so corrupt
    # or unsupported files surface here instead of later in the pipeline.
    try:
        image = Image.open(uploaded_file).convert("RGB")
    except OSError as exc:
        st.error(f"Could not read the uploaded image: {exc}")
        return
    processed_image = image.copy()

    # --- STEP 1: Background Removal ---
    if remove_bg:
        st.info("Loading RMBG-1.4 Model (First run will download ~170MB)...")
        try:
            model, device = load_birefnet_model()
            with st.spinner("Removing background using PyTorch..."):
                processed_image = remove_background_torch(
                    processed_image, model, device
                )
        except Exception as e:
            # UI boundary: report the failure and carry on with the
            # unprocessed image rather than aborting the whole page.
            st.error(f"Error during background removal: {e}")

    # --- STEP 2: Geometry/Color ---
    if rotate_angle != 0:
        # expand=True grows the canvas so rotated corners are not cropped.
        processed_image = processed_image.rotate(rotate_angle, expand=True)

    # --- Display ---
    col1, col2 = st.columns(2)
    with col1:
        st.subheader("Original")
        st.image(image, use_container_width=True)
    with col2:
        st.subheader("Result")
        st.image(processed_image, use_container_width=True)

    # --- Download ---
    st.markdown("---")
    st.download_button(
        label="💾 Download Result",
        data=convert_image_to_bytes(processed_image),
        file_name="nobg_image.png",
        mime="image/png",
    )
# Run the app only when executed as a script, not when imported.
if __name__ == "__main__":
    main()