Spaces:
Sleeping
Sleeping
| import streamlit as st | |
| import numpy as np | |
| from PIL import Image | |
| import torch | |
| import torchvision.transforms as T | |
| import io | |
| # Assuming you have the U2NET model defined somewhere | |
| from model.u2net import U2NET # Replace with your actual import path | |
# Build the U2NET network with 3 input channels (images are converted to RGB
# before inference) and a single output channel, which is later used as the mask.
# Trained weights are loaded into it by the load_model call further down.
u2net = U2NET(out_ch=1, in_ch=3)
def load_model(model, model_path, device):
    """Load a saved state dict from *model_path* into *model* and move it to *device*.

    Args:
        model: Module exposing ``load_state_dict`` and ``to`` (e.g. a
            ``torch.nn.Module``); its parameters are overwritten in place.
        model_path: Path to a checkpoint file holding a state dict for *model*.
        device: Target device such as ``"cpu"``. It is also passed as
            ``map_location`` so tensors saved on another device load here.

    Returns:
        The module after ``.to(device)``.
    """
    # NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    return model.to(device)
# Load the trained U2NET weights for CPU inference.
u2net = load_model(u2net, "u2net.pth", "cpu")

# Preprocessing pipeline: resize to 320x320, convert to a tensor, then normalise
# each channel with the standard ImageNet mean/std values below.
resize_shape = (320, 320)
mean = torch.tensor([0.485, 0.456, 0.406])
std = torch.tensor([0.229, 0.224, 0.225])
transforms = T.Compose(
    [
        T.Resize(resize_shape),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ]
)
def prepare_single_image(image, resize, transforms, device):
    """Turn one image into a model-ready batch of size 1.

    Args:
        image: A PIL image, or a NumPy array that ``Image.fromarray`` accepts.
        resize: Target ``(width, height)`` passed to ``PIL.Image.resize``.
        transforms: Callable mapping the resized PIL image to a tensor
            (for example the module-level ``transforms`` pipeline).
        device: Device the returned tensor is moved to.

    Returns:
        The transformed image with a leading batch dimension of 1, on *device*.
    """
    pil_image = Image.fromarray(image) if isinstance(image, np.ndarray) else image
    resized = pil_image.convert("RGB").resize(resize, resample=Image.BILINEAR)
    tensor = transforms(resized)
    return tensor.unsqueeze(0).to(device)
def prepare_prediction(model, image_batch):
    """Run *model* on *image_batch* for inference and return its first output.

    The model is switched to eval mode and called without gradient tracking.
    Element ``0`` of the model's output is taken (for U2NET presumably the fused
    output map; TODO confirm), moved to the CPU, and has its leading dimension
    removed if that dimension has size 1. The result is returned as a NumPy array.
    """
    model.eval()
    with torch.no_grad():
        outputs = model(image_batch)
        first = outputs[0].cpu().squeeze(0)
    return first.numpy()
def normPRED(predicted_map):
    """Min-max normalise *predicted_map* to the range [0, 1].

    Args:
        predicted_map: Array-like of numbers (converted with ``np.asarray``).

    Returns:
        An array of the same shape in which the minimum maps to 0 and the
        maximum maps to 1. A constant map (max == min) returns all zeros instead
        of the NaNs that the 0/0 division would otherwise produce.

    Raises:
        ValueError: If *predicted_map* is empty (raised by ``np.max``).
    """
    arr = np.asarray(predicted_map)
    ma = np.max(arr)
    mi = np.min(arr)
    span = ma - mi
    if span == 0:
        # Nothing to stretch. Dividing would give 0/0 -> NaN, and callers cast
        # the result to uint8, where NaN has no defined value.
        return np.zeros_like(arr, dtype=float)
    return (arr - mi) / span
def apply_mask(image, mask):
    """Apply the mask to the original image and return the result with transparent background.

    Args:
        image: The source PIL image (any mode; it is converted to RGB).
        mask: Prediction array whose singleton axes squeeze down to a 2-D map
            with the same (height, width) as ``resize_shape``.

    Returns:
        An RGBA PIL image at ``resize_shape`` resolution. The RGB channels are the
        original colours and alpha is the min-max normalised mask. Pixels that are
        fully transparent have their colour zeroed.

    NOTE(review): the output is at ``resize_shape``, not the input image's own
    size; upsampling the mask to ``image.size`` would preserve resolution and
    aspect ratio.
    """
    # Collapse singleton axes, e.g. (1, H, W) -> (H, W).
    mask = np.squeeze(mask)
    mask = normPRED(mask)
    # Guard the uint8 cast: NaN (e.g. 0/0 on a constant map) or values outside
    # [0, 1] have no meaningful uint8 representation.
    mask = np.clip(np.nan_to_num(mask, nan=0.0), 0.0, 1.0)
    alpha = (mask * 255).astype(np.uint8)

    original_image = image.convert("RGB").resize(resize_shape, resample=Image.BILINEAR)
    rgb = np.asarray(original_image)

    # Use the mask as a straight (non-premultiplied) alpha channel. Compositing
    # against a transparent *black* background also scaled the RGB values by the
    # mask, which darkened semi-transparent edges a second time (dark halo).
    rgba = np.dstack((rgb, alpha))
    # Zero the colour of fully transparent pixels so the removed background
    # cannot be recovered from the saved PNG by discarding the alpha channel.
    rgba[alpha == 0, :3] = 0
    return Image.fromarray(rgba)
# --- Streamlit page layout ---------------------------------------------------
st.title("Image Segmentation with U2NET")

# The sidebar holds the controls: an uploader restricted to PNG/JPEG files.
controls = st.sidebar
controls.title("Controls :gear:")
uploaded_file = controls.file_uploader(
    "Choose an image...",
    type=["png", "jpg", "jpeg"],
)
def _open_input_image(upload):
    """Open *upload*, or the bundled default ``8.jpg`` when it is falsy, and decode it."""
    image = Image.open(upload if upload else "8.jpg")
    # Image.open is lazy; force decoding now so corrupt or truncated data raises
    # here rather than later in the middle of preprocessing.
    image.load()
    return image


def _png_bytes(image):
    """Serialise a PIL image as PNG and return the raw bytes."""
    with io.BytesIO() as buffer:
        image.save(buffer, format="PNG")
        return buffer.getvalue()


# Function to handle image and segmentation display
def fix_image(upload=None):
    """Segment an image with U2NET, show it next to the original, and offer a download.

    Args:
        upload: A file-like object from ``st.file_uploader``. When falsy
            (``None`` by default) the bundled ``8.jpg`` is used instead.

    Input that cannot be read is reported with ``st.error`` instead of crashing
    the script with a traceback. This covers unknown formats, truncated data, a
    missing default file, and images Pillow rejects as decompression bombs.
    """
    try:
        image = _open_input_image(upload)
    except (OSError, Image.DecompressionBombError) as err:
        # OSError covers UnidentifiedImageError, truncated files and FileNotFoundError.
        st.error(f"Could not read the image: {err}")
        return

    # Prepare image for segmentation
    image_batch = prepare_single_image(image, resize_shape, transforms, "cpu")
    prediction_u2net = prepare_prediction(u2net, image_batch)
    masked_image = apply_mask(image, prediction_u2net)

    # Display the original and segmented images side by side
    col1, col2 = st.columns(2)
    with col1:
        st.image(image, caption="Uploaded Image", use_column_width=True)
    with col2:
        st.image(masked_image, caption='Segmented Image', use_column_width=True)

    # Provide download option for segmented image
    st.sidebar.markdown('### Download Segmented Image')
    st.sidebar.download_button(
        label="Download Segmented Image",
        data=_png_bytes(masked_image),
        file_name="segmented_image.png",
        mime="image/png",
    )
# Entry point: fix_image(upload=None) already falls back to the bundled default
# image, so the uploader's value (a file or None) can be passed straight through.
fix_image(upload=uploaded_file)