"""Gradio app that blurs the background of an image while keeping the subject sharp.

The subject/background split comes from the ``mattmdjaga/segformer_b2_clothes``
image-segmentation model; pixels in its "Background" segment are replaced by a
Gaussian-blurred copy of the image, with a feathered mask edge so the transition
is not a hard cut-out.
"""
import io
from typing import Optional, Union

import gradio as gr
import numpy as np
from PIL import Image, ImageFilter
from transformers import pipeline

# Segment label treated as "background". This is the label the script has always
# matched; re-check it against the model's id2label if the model is changed.
BACKGROUND_LABEL = "Background"

# Load the segmentation pipeline once at import time so every request reuses it.
pipe = pipeline("image-segmentation", model="mattmdjaga/segformer_b2_clothes")


def refine_mask(mask, threshold=128, smooth_radius=0.5):
    """Binarize a segmentation mask and feather its edges.

    Args:
        mask: PIL image or array-like mask; values above ``threshold`` are
            treated as "inside" the segment.
        threshold: Cut-off applied before binarizing (strict ``>``).
        smooth_radius: Gaussian blur radius used to soften the binary edge.

    Returns:
        A PIL image whose pixels are 0/255 in flat regions and intermediate
        values along the feathered boundary.
    """
    if isinstance(mask, Image.Image):
        # Normalize to 8-bit grayscale; a "1"-mode mask would otherwise become a
        # bool array and never pass the `> threshold` test.
        mask = mask.convert("L")
    mask_array = np.asarray(mask)
    binary = (mask_array > threshold).astype(np.uint8) * 255
    return Image.fromarray(binary).filter(ImageFilter.GaussianBlur(smooth_radius))


def _find_background_mask(segments):
    """Return the refined mask of the first background segment, or None if absent."""
    for segment in segments:
        if segment["label"] == BACKGROUND_LABEL:
            return refine_mask(segment["mask"])
    return None


def _blur_background_image(image, blur_radius):
    """Blur the background of a PIL image and return the composited RGB PIL image.

    If no background segment is detected, the RGB-converted input is returned
    unchanged.
    """
    # Segmentation, filtering and the (H, W, 3) blend below all assume RGB;
    # palette/grayscale/RGBA inputs would fail or broadcast incorrectly.
    image = image.convert("RGB")

    background_mask = _find_background_mask(pipe(image))
    if background_mask is None:
        return image

    # The pipeline is expected to return masks at the input size; resize
    # defensively so the blend below never mis-broadcasts.
    if background_mask.size != image.size:
        background_mask = background_mask.resize(image.size, Image.BILINEAR)

    # Use the feathered mask as a per-pixel alpha (1.0 = fully blurred) so the
    # smoothing done in refine_mask actually softens the subject's outline.
    alpha = np.asarray(background_mask, dtype=np.float32)[..., None] / 255.0
    original = np.asarray(image, dtype=np.float32)
    blurred = np.asarray(
        image.filter(ImageFilter.GaussianBlur(radius=blur_radius)), dtype=np.float32
    )
    composite = blurred * alpha + original * (1.0 - alpha)
    return Image.fromarray(np.clip(np.rint(composite), 0, 255).astype(np.uint8))


def _to_png_bytes(image):
    """Encode a PIL image as PNG and return the raw bytes."""
    buffer = io.BytesIO()
    image.save(buffer, format="PNG")
    return buffer.getvalue()


def blur_background(image: bytes, blur_radius: int) -> bytes:
    """Blur the background of an encoded image.

    Args:
        image: Encoded image file contents (any format Pillow can open).
        blur_radius: Gaussian blur radius applied to the background.

    Returns:
        PNG-encoded bytes. When no background is detected the original image is
        returned (as PNG) without blurring, so the return type is always bytes.
    """
    with Image.open(io.BytesIO(image)) as source:
        result = _blur_background_image(source, blur_radius)
    return _to_png_bytes(result)


def predict(
    image: Optional[Union[Image.Image, bytes]], blur_radius: int
) -> Optional[Union[Image.Image, bytes]]:
    """Gradio callback: blur the background of ``image``.

    Gradio passes a PIL image (or None when nothing is uploaded) and gets a PIL
    image back; programmatic callers passing encoded bytes still get PNG bytes.
    """
    if image is None:
        return None
    if isinstance(image, (bytes, bytearray)):
        return blur_background(bytes(image), blur_radius)
    return _blur_background_image(image, blur_radius)


def launch_api():
    """Build the Gradio interface around ``predict`` and launch it with a share link."""
    interface = gr.Interface(
        fn=predict,
        inputs=[
            # gr.Image has no "bytes" type (only "numpy", "pil", "filepath").
            gr.Image(type="pil"),
            gr.Slider(1, 50, step=1, label="Blur Intensity"),  # Blur radius
        ],
        outputs=gr.Image(type="pil"),
        title="Background Blur API",
        description="This API blurs the background of an image while preserving the subject.",
    )
    interface.launch(share=True)  # Launch the Gradio interface as an API


# Launch the API (runs on import as well as when executed as a script).
launch_api()