# Blur_Background / app.py
# Committed by: Kaliboy002
# Last commit: 89525d2 "Update app.py" (verified)
import io

import gradio as gr
import numpy as np
from PIL import Image, ImageFilter, ImageOps
from transformers import pipeline
# Segmentation pipeline, built once at import time and reused by every call
# to blur_background(). The model assigns per-pixel labels (one of which is
# "Background") that drive the blur mask.
_SEGMENTATION_MODEL = "mattmdjaga/segformer_b2_clothes"
pipe = pipeline("image-segmentation", model=_SEGMENTATION_MODEL)
# Simplified refine_mask function
def refine_mask(mask):
"""Simplify and smooth the segmentation mask."""
mask_array = np.array(mask)
mask_array = (mask_array > 128).astype(np.uint8) * 255 # Threshold to binary mask
refined_mask = Image.fromarray(mask_array).filter(ImageFilter.GaussianBlur(0.5)) # Smooth edges
return refined_mask
def _to_png_bytes(img):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    img.save(buffer, format="PNG")
    return buffer.getvalue()


# Function to blur the background
def blur_background(image: bytes, blur_radius: int) -> bytes:
    """Blur the region the segmentation model labels "Background".

    Args:
        image: Encoded image file contents (any format PIL can open).
        blur_radius: Radius of the Gaussian blur applied to the background.

    Returns:
        PNG-encoded bytes of the composited image. If the model reports no
        "Background" segment, the (orientation-corrected, RGB) input is
        returned unblurred, still PNG-encoded, so the result is always bytes.
    """
    # Apply EXIF orientation and normalise to RGB. The transformers pipeline
    # orients/converts its input the same way, so the mask it returns lines up
    # pixel-for-pixel with `source`. RGB also avoids palette images (which
    # cannot be Gaussian-blurred) and greyscale images (whose 2-D arrays broke
    # the mask broadcasting in the previous numpy-based compositing).
    source = ImageOps.exif_transpose(Image.open(io.BytesIO(image))).convert("RGB")

    segments = pipe(source)
    raw_mask = next(
        (segment["mask"] for segment in segments if segment["label"] == "Background"),
        None,
    )
    if raw_mask is None:
        # No background detected: return the image unchanged, but as bytes to
        # honour the declared return type (a PIL Image was returned before).
        return _to_png_bytes(source)

    background_mask = refine_mask(raw_mask).convert("L")
    blurred = source.filter(ImageFilter.GaussianBlur(radius=blur_radius))

    # Use the smoothed mask as an alpha channel: 255 -> blurred pixel,
    # 0 -> original pixel, intermediate edge values blend the two. An exact
    # `== 255` test would discard the smoothing refine_mask applies.
    output = Image.composite(blurred, source, background_mask)
    return _to_png_bytes(output)
# Prediction entry point used by the Gradio interface
def predict(image: bytes, blur_radius: int) -> bytes:
    """Return *image* with its background blurred by *blur_radius*.

    Thin delegation to blur_background(); both the input and the returned
    value are encoded image bytes.
    """
    processed = blur_background(image, blur_radius)
    return processed
# API setup
def launch_api():
    """Build the Gradio interface around predict() and launch it.

    ``gr.Image`` only accepts ``type`` values of ``"numpy"``, ``"pil"`` or
    ``"filepath"``; ``"bytes"`` raises ``ValueError`` when the component is
    constructed, and raw bytes are not a valid output value either. The
    interface therefore exchanges a file path (input) and a PIL image
    (output) with Gradio, and the adapter below converts to and from the
    encoded bytes that predict() works with.
    """

    def _predict_from_path(image_path, blur_radius):
        # Adapter: Gradio file path -> bytes -> predict() -> PIL image.
        if image_path is None:
            # Gradio passes None when the user submits without an upload.
            raise gr.Error("Please upload an image.")
        with open(image_path, "rb") as fh:
            image_bytes = fh.read()
        result = predict(image_bytes, blur_radius)
        return Image.open(io.BytesIO(result))

    interface = gr.Interface(
        fn=_predict_from_path,
        inputs=[
            gr.Image(type="filepath"),  # Uploaded image, handed over as a file path
            gr.Slider(1, 50, step=1, label="Blur Intensity"),  # Slider for blur radius
        ],
        outputs=gr.Image(type="pil"),  # Blurred result as a PIL image
        title="Background Blur API",
        description="This API blurs the background of an image while preserving the subject.",
    )
    interface.launch(share=True)  # Launch the Gradio interface as an API
# Launch the API only when this file is run as a script (e.g. `python app.py`),
# so importing the module does not start a server as a side effect.
if __name__ == "__main__":
    launch_api()