Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import torch | |
| from PIL import Image | |
| import io | |
| import numpy as np | |
| from briarmbg import BriaRMBG | |
| from torchvision.transforms.functional import normalize | |
| # Reuse the functions from your CLI script | |
def convert_to_jpg(image, image_name):
    """Open an uploaded image and flatten PNG transparency onto white.

    Despite the name, nothing is JPEG-encoded: a PIL image is returned.
    For ``.png`` names the result is always RGB, with transparent pixels
    composited onto a white background. Any other file is returned exactly
    as ``Image.open`` produced it, in whatever mode PIL reports.

    Args:
        image: Anything ``Image.open`` accepts (path or binary file object).
        image_name: Original file name; only its extension is inspected.

    Returns:
        A ``PIL.Image.Image``.
    """
    img = Image.open(image)
    if not image_name.lower().endswith('.png'):
        return img

    has_alpha = img.mode in ('RGBA', 'LA', 'PA') or (
        img.mode == 'P' and 'transparency' in img.info
    )
    if not has_alpha:
        return img.convert("RGB")

    # Normalise every alpha-bearing mode to RGBA first. A palette ('P') image
    # has a single band, so the old `img.split()[1]` raised IndexError for it.
    rgba = img.convert("RGBA")
    bg = Image.new("RGB", rgba.size, (255, 255, 255))
    bg.paste(rgba, mask=rgba.getchannel("A"))
    return bg
def resize_image(image, size=(1024, 1024)):
    """Return an RGB copy of *image* resampled bilinearly to *size*.

    Args:
        image: A PIL image in any mode; it is converted to RGB first.
        size: Target size passed to ``Image.resize`` as ``(width, height)``.
    """
    rgb_copy = image.convert('RGB')
    return rgb_copy.resize(size, Image.BILINEAR)
def _rmbg_input_tensor(image):
    """Build the normalised (1, 3, H, W) float input tensor from a PIL image."""
    resized = resize_image(image)  # RGB, 1024x1024 by default
    pixels = torch.tensor(np.array(resized), dtype=torch.float32).permute(2, 0, 1)
    pixels = torch.unsqueeze(pixels, 0) / 255.0
    # Same mean/std as the original pipeline: shifts [0, 1] to [-0.5, 0.5].
    return normalize(pixels, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])


def _rmbg_mask_to_pil(raw_mask, size_hw, original_size):
    """Rescale a raw model mask to [0, 255] and return it as a PIL image of *original_size*."""
    mask = torch.squeeze(
        torch.nn.functional.interpolate(raw_mask, size=size_hw, mode='bilinear'), 0
    )
    lo, hi = torch.min(mask), torch.max(mask)
    span = hi - lo
    if span > 0:
        mask = (mask - lo) / span
    else:
        # Constant mask: min-max scaling would divide by zero and yield NaNs.
        # Keep the raw values, clamped to the valid alpha range, instead.
        mask = torch.clamp(mask, 0.0, 1.0)
    mask_u8 = (mask * 255).cpu().numpy().astype(np.uint8)
    return Image.fromarray(np.squeeze(mask_u8)).resize(original_size, Image.BILINEAR)


def remove_background(model, image):
    """Cut the foreground out of *image* using the segmentation *model*.

    The image is resized and normalised, passed through *model*, and the
    first output (``result[0][0]``) is used as an alpha mask. The mask is
    rescaled to the original image size and the original pixels are pasted
    through it onto a fully transparent canvas.

    Args:
        model: The loaded network (as returned by ``load_model``). It is moved
            to CUDA when CUDA is available.
        image: The source PIL image.

    Returns:
        A new RGBA PIL image of the same size as *image*.
    """
    original_size = image.size
    im_tensor = _rmbg_input_tensor(image)
    if torch.cuda.is_available():
        im_tensor = im_tensor.cuda()
        model = model.cuda()

    with torch.no_grad():
        result = model(im_tensor)
        # torch expects (H, W); taking it from the tensor avoids PIL's (W, H) order.
        mask = _rmbg_mask_to_pil(result[0][0], im_tensor.shape[2:], original_size)

    new_im = Image.new("RGBA", original_size, (0, 0, 0, 0))
    new_im.paste(image, mask=mask)
    return new_im
| # Load the model | |
@st.cache_resource
def load_model():
    """Load the BriaRMBG network from ``model.pth`` in eval mode.

    Streamlit re-executes the script on every widget interaction.
    ``st.cache_resource`` makes the weights load once and the same network
    object be reused on later reruns, instead of re-reading the checkpoint
    from disk each time.

    Returns:
        The network, on CUDA when available and otherwise on CPU.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    net = BriaRMBG()
    net.load_state_dict(torch.load("model.pth", map_location=device))
    net.to(device)
    net.eval()
    return net
| # Streamlit app | |
def main():
    """Render the background-removal Streamlit page.

    The page has four steps:
    1. The user uploads a JPG/PNG.
    2. The original image is shown.
    3. Clicking "Remove Background" runs the model.
    4. The cut-out is shown with a PNG download button.

    ``st.button`` is only True on the rerun triggered by its own click. The
    encoded result is therefore kept in ``st.session_state`` so that it, and
    the download button, stay visible after later reruns (e.g. the one caused
    by clicking "Download Image").
    """
    st.title("Background Removal App")

    model = load_model()

    uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "jpeg", "png"])
    if uploaded_file is None:
        return

    try:
        image = convert_to_jpg(uploaded_file, uploaded_file.name)
    except OSError as exc:
        # PIL signals undecodable files with UnidentifiedImageError (an
        # OSError subclass); report it in the page rather than crash.
        st.error(f"Could not open the uploaded image: {exc}")
        return
    st.image(image, caption="Original Image", use_column_width=True)

    # Forget a result that was computed for a different upload.
    file_key = (uploaded_file.name, uploaded_file.size)
    if st.session_state.get("rmbg_source") != file_key:
        st.session_state.pop("rmbg_png", None)
        st.session_state["rmbg_source"] = file_key

    if st.button("Remove Background"):
        result = remove_background(model, image)
        buf = io.BytesIO()
        result.save(buf, format="PNG")
        st.session_state["rmbg_png"] = buf.getvalue()

    byte_im = st.session_state.get("rmbg_png")
    if byte_im is not None:
        st.image(byte_im, caption="Image with Background Removed", use_column_width=True)
        st.download_button(
            label="Download Image",
            data=byte_im,
            file_name="background_removed.png",
            mime="image/png",
        )
if __name__ == "__main__":
    # Build the page only when this file is run as the main script, not on import.
    main()