Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| import numpy as np | |
| import cv2 | |
| import torch | |
| from PIL import Image | |
| import requests | |
| from io import BytesIO | |
| # Helper functions for each model | |
def load_image(image_path_or_url, timeout=30):
    """Load an image from a local source or an HTTP(S) URL as RGB.

    Args:
        image_path_or_url: a filesystem path or file-like object, or a
            string starting with ``http://`` / ``https://`` that is
            downloaded with ``requests``.
        timeout: seconds to wait for the HTTP download before giving up.

    Returns:
        A PIL image in ``'RGB'`` mode.

    Raises:
        requests.HTTPError: if the download answers with an error status.
        requests.Timeout: if the server does not respond within *timeout*.
    """
    is_url = isinstance(image_path_or_url, str) and image_path_or_url.startswith(
        ('http://', 'https://')
    )
    if is_url:
        response = requests.get(image_path_or_url, timeout=timeout)
        # Fail loudly on 4xx/5xx instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path_or_url
    # convert() returns a fully loaded copy, so the opened image can be closed.
    with Image.open(source) as img:
        return img.convert('RGB')
def prepare_image_for_api(image):
    """Encode an image as PNG into an in-memory, rewound byte stream.

    The stream is positioned at offset 0 so HTTP clients such as
    ``requests`` upload the whole payload.
    """
    png_stream = BytesIO()
    image.save(png_stream, format='PNG')
    png_stream.seek(0)  # rewind: save() leaves the cursor at the end
    return png_stream
def process_bria(image, timeout=60):
    """Get a foreground mask for *image* from the BRIA background-removal API.

    The image is uploaded as a PNG via multipart/form-data. The reply is
    decoded with PIL and reduced to a single-channel ``'L'`` mask: its alpha
    channel when it has transparency (a cutout), otherwise its grayscale
    conversion. This lets it be merged with the other models' masks.

    NOTE(review): the URL below is a placeholder (see the original
    "replace with actual API" remark). A real BRIA endpoint presumably needs
    an API token and may answer with JSON rather than image bytes; confirm
    against the BRIA API docs.

    Args:
        image: PIL image to segment.
        timeout: seconds to wait for the API before giving up.

    Returns:
        A PIL ``'L'`` image to be used as a foreground mask.

    Raises:
        requests.HTTPError: if the API answers with an error status.
        requests.Timeout: if the API does not respond within *timeout*.
    """
    bria_api_url = "https://api.bria.ai/removebg"
    image_bytes = prepare_image_for_api(image)
    files = {'image': ('image.png', image_bytes, 'image/png')}
    response = requests.post(bria_api_url, files=files, timeout=timeout)
    # Surface HTTP errors directly instead of a confusing PIL decode error.
    response.raise_for_status()
    result = Image.open(BytesIO(response.content))
    # A cutout carries the mask in its alpha channel (or palette transparency).
    if 'A' in result.getbands() or 'transparency' in result.info:
        return result.convert('RGBA').getchannel('A')
    return result.convert('L')
# Hugging Face segmentation pipelines keyed by model id, so each model is
# loaded once per process rather than on every request.
_SEGMENTATION_PIPELINES = {}


def _get_segmentation_pipeline(model_id):
    """Return the cached ``image-segmentation`` pipeline for *model_id*."""
    pipe = _SEGMENTATION_PIPELINES.get(model_id)
    if pipe is None:
        # Imported lazily, as before, so transformers is only loaded when a
        # Hugging Face model is actually used.
        from transformers import pipeline
        pipe = pipeline("image-segmentation", model=model_id)
        _SEGMENTATION_PIPELINES[model_id] = pipe
    return pipe


def _segmentation_mask(model_id, image):
    """Run *model_id* on *image* and merge all returned segments into one mask.

    The transformers pipeline returns a list of dicts whose ``"mask"`` entry
    is a PIL image. Those are combined with a per-pixel maximum into a single
    ``'L'`` mask the size of *image*. If there are no segments, the result
    is an all-black mask.
    """
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    width, height = image.size
    merged = np.zeros((height, width), dtype=np.uint8)
    for segment in _get_segmentation_pipeline(model_id)(image):
        seg_mask = segment["mask"].convert('L')
        if seg_mask.size != image.size:
            seg_mask = seg_mask.resize(image.size)
        merged = np.maximum(merged, np.asarray(seg_mask, dtype=np.uint8))
    return Image.fromarray(merged)


def process_inspyrenet(image):
    """Return an ``'L'`` foreground mask for *image* from INSPyReNet."""
    # NOTE(review): confirm these repo ids load as transformers
    # image-segmentation models.
    return _segmentation_mask("mattmdjaga/INSPyReNet", image)


def process_u2net(image):
    """Return an ``'L'`` foreground mask for *image* from U²-Net."""
    return _segmentation_mask("silk-road/u2net", image)


def process_u2net_human(image):
    """Return an ``'L'`` foreground mask for *image* from U²-Net Human Seg."""
    return _segmentation_mask("mattmdjaga/u2net-human-seg", image)
def process_isnet_general(image):
    """Placeholder for ISNet (general-use) background segmentation.

    The real ISNet model is not wired up yet. Until it is, return an
    all-black ``'L'`` mask the same size as *image*. Model masks are merged
    with a per-pixel maximum, so black (0) contributes nothing. A white
    placeholder would mark every pixel as foreground and cancel the other
    models' background removal.
    """
    # TODO: replace with real ISNet-General-Use inference.
    return Image.new('L', image.size, 0)
def process_isnet_anime(image):
    """Placeholder for ISNet (anime) background segmentation.

    The real ISNet-Anime model is not wired up yet. Until it is, return an
    all-black ``'L'`` mask the same size as *image*. Model masks are merged
    with a per-pixel maximum, so black (0) contributes nothing. A white
    placeholder would mark every pixel as foreground and cancel the other
    models' background removal.
    """
    # TODO: replace with real ISNet-Anime inference.
    return Image.new('L', image.size, 0)
def combine_masks(masks):
    """Merge several foreground masks into one via a per-pixel maximum.

    A pixel is foreground in the result if any input marks it so (union).

    Args:
        masks: iterable of masks, each a PIL image or NumPy array holding
            values in the 0-255 range.
            - 2-D inputs are used as-is.
            - 3-D inputs with 2 or 4 channels (``LA`` / ``RGBA`` cutouts)
              contribute their last (alpha) channel.
            - Any other 3-D input contributes its per-pixel maximum across
              channels.

    Returns:
        ``uint8`` NumPy array of shape ``(H, W)``.

    Raises:
        ValueError: if *masks* is empty or the masks differ in shape.
    """
    combined = None
    for mask in masks:
        arr = np.asarray(mask)
        if arr.ndim == 3:
            # Cutout images carry the mask in their alpha channel.
            arr = arr[..., -1] if arr.shape[-1] in (2, 4) else arr.max(axis=-1)
        arr = arr.astype(np.float64)
        if combined is None:
            combined = arr
        elif arr.shape != combined.shape:
            raise ValueError(
                "mask shape %s does not match %s" % (arr.shape, combined.shape)
            )
        else:
            combined = np.maximum(combined, arr)
    if combined is None:
        raise ValueError("combine_masks() requires at least one mask")
    # The max of integer-valued inputs is exact, so there is no /255 * 255
    # round-trip (whose truncating cast could drop a level). clip guards
    # the uint8 cast.
    return np.clip(combined, 0, 255).astype(np.uint8)
def process_image(image):
    """Remove the background of *image* using an ensemble of models.

    Every model produces a foreground mask for the same image. The masks are
    merged with ``combine_masks`` and pixels outside the merged mask are
    replaced with black.

    Args:
        image: the uploaded picture, either a NumPy array or a PIL image.
            ``None`` (submit without an upload) is passed through.

    Returns:
        A PIL ``'RGB'`` image with the background blacked out, or ``None``
        when no image was given.
    """
    if image is None:
        return None
    if isinstance(image, np.ndarray):
        image = Image.fromarray(image)
    # Image.composite needs both images in the same mode; uploads may be
    # RGBA, palette or grayscale.
    image = image.convert('RGB')

    # Run every model in sequence on the same input.
    masks = [
        process_bria(image),
        process_inspyrenet(image),
        process_u2net(image),
        process_u2net_human(image),
        process_isnet_general(image),
        process_isnet_anime(image),
    ]
    combined_mask = Image.fromarray(combine_masks(masks))
    # Image.composite requires a '1'/'L'/'RGBA' mask of the image's size.
    if combined_mask.mode != 'L':
        combined_mask = combined_mask.convert('L')
    if combined_mask.size != image.size:
        combined_mask = combined_mask.resize(image.size)

    background = Image.new('RGB', image.size, (0, 0, 0))
    return Image.composite(image, background, combined_mask)
# Gradio interface: one image in, the background-removed image out.
iface = gr.Interface(
    fn=process_image,
    inputs=gr.Image(),
    outputs=gr.Image(),
    title="Multi-Model Background Removal",
    description="Combines BRIA, INSPyReNet, U²-Net, U²-Net Human Seg, ISNet-General-Use, and ISNet-Anime for superior background removal"
)

# Launch only when run as a script, so importing this module (tests, reload
# tooling) does not start a server.
if __name__ == "__main__":
    # share=True requests a public share link when run locally.
    # NOTE(review): Hugging Face Spaces are already public and Gradio warns
    # that share=True is unsupported there; confirm it is still wanted.
    iface.launch(share=True)