|
|
|
|
|
import streamlit as st |
|
|
from PIL import Image |
|
|
import torch |
|
|
import torchvision.transforms as transforms |
|
|
import numpy as np |
|
|
import subprocess |
|
|
|
|
|
|
|
|
# Import the model package, bootstrapping it from GitHub the first time the
# app runs in an environment where it is not installed yet.
try:
    import pytorch_hub_examples
except ModuleNotFoundError:
    import importlib
    import os
    import sys

    _repo_dir = "pytorch_hub_examples"

    # Skip the clone when a previous (possibly interrupted) run already left
    # the checkout behind; `git clone` refuses to write into an existing,
    # non-empty directory.
    if not os.path.isdir(_repo_dir):
        # check=True turns a failed clone into an immediate, descriptive
        # CalledProcessError instead of a confusing import error later on.
        subprocess.run(
            ["git", "clone", "https://github.com/facebookresearch/pytorch_hub_examples.git"],
            check=True,
        )

    # `cd` and `&&` are shell syntax, not an executable and its arguments, so
    # they cannot be passed in an argv list. Run the installer with `cwd`
    # pointing at the checkout, and use the interpreter that is running this
    # script so the package lands in the same environment.
    subprocess.run(
        [sys.executable, "setup.py", "install"],
        cwd=_repo_dir,
        check=True,
    )

    # The package was installed after the interpreter started; drop the
    # import system's cached directory listings so the retry can find it.
    importlib.invalidate_caches()
    import pytorch_hub_examples
|
|
|
|
|
|
|
|
def _load_model():
    """Build the pretrained network and switch it to inference mode."""
    net = pytorch_hub_examples.u2net(pretrained=True)
    net.eval()
    return net


# Streamlit re-executes this script from the top on every widget interaction,
# so an uncached load would rebuild the model on each rerun. Cache it as a
# shared resource when the running Streamlit version provides
# `st.cache_resource`; otherwise fall back to a plain load.
_cache_resource = getattr(st, "cache_resource", None)
if _cache_resource is not None:
    _load_model = _cache_resource(_load_model)

model = _load_model()
|
|
|
|
|
|
|
|
# Preprocessing applied to an image before it is handed to the model:
#   1. resize to a fixed 320x320 input,
#   2. convert the PIL image to a tensor,
#   3. normalise the three channels with the given per-channel mean / std
#      (these values match the widely used ImageNet channel statistics).
_preprocessing_steps = [
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocessing_steps)
|
|
|
|
|
def remove_background(image, threshold=0):
    """Cut the foreground out of *image* using the module-level model.

    The image is preprocessed with the module-level ``transform``, passed
    through ``model``, and the first model output is converted into an 8-bit
    mask. Pixels whose mask value is strictly greater than *threshold* keep
    their original colour; every other pixel becomes transparent white.

    Args:
        image: PIL image to process.
        threshold: Mask cut-off in the 0-255 range. The default of 0 keeps
            the original behaviour (any non-zero mask value is foreground).

    Returns:
        A new RGBA PIL image with the same size as *image*, or ``None`` when
        any step fails (the error is reported to the page via ``st.error``).
    """
    try:
        # The normalisation step is defined for three channels, so force RGB
        # even if a caller passes a greyscale or RGBA image.
        batch = transform(image.convert("RGB")).unsqueeze(0)

        with torch.no_grad():
            out = model(batch)
            mask = torch.sigmoid(out[0])

        # Assumes the selected output squeezes down to a single 2-D map;
        # Image.fromarray raises otherwise and the error is reported below.
        mask = mask.squeeze().cpu().numpy()
        mask = (mask * 255).astype(np.uint8)
        mask = Image.fromarray(mask).convert("L")

        # The mask comes back at the preprocessing resolution, not at the
        # size of the uploaded image. Scale it to the image size so it lines
        # up with the image pixels (indexing it with image coordinates would
        # otherwise run out of bounds for larger images).
        mask = mask.resize(image.size, Image.BILINEAR)

        # Binarise: 255 selects the original pixel, 0 selects the backdrop.
        keep = mask.point(lambda value: 255 if value > threshold else 0)

        rgba = image.convert("RGBA")
        backdrop = Image.new("RGBA", rgba.size, (255, 255, 255, 0))

        # One composite call replaces the per-pixel getpixel/putpixel loop.
        return Image.composite(rgba, backdrop, keep)
    except Exception as e:
        # UI boundary: show the problem on the page instead of crashing the app.
        st.error(f"Error: {e}")
        return None
|
|
|
|
|
# ---------------------------------------------------------------------------
# Page layout: title, uploader, preview of the upload, and a button that runs
# the background removal and shows the result.
# ---------------------------------------------------------------------------
st.title("Background Remover")

uploaded_file = st.file_uploader(
    "Choose an image...",
    type=["jpg", "jpeg", "png"],
)

if uploaded_file is not None:
    # Decode the upload and normalise it to RGB before previewing it.
    image = Image.open(uploaded_file)
    image = image.convert("RGB")
    st.image(image, caption="Uploaded Image", use_column_width=True)

    if st.button("Remove Background"):
        with st.spinner("Removing background..."):
            result_image = remove_background(image)
            # remove_background returns None after reporting a failure, in
            # which case there is nothing to display.
            if result_image is not None:
                st.image(
                    result_image,
                    caption="Background Removed",
                    use_column_width=True,
                )