Spaces: Running on Zero
import logging
import uuid
from typing import List, Union

import gradio as gr
import spaces  # ZeroGPU asks for `spaces` to be imported before torch/CUDA packages
import torch
from loadimg import load_img  # Your helper to load from URL or file
from PIL import Image
from torchvision import transforms
from transformers import AutoModelForImageSegmentation
# Allow float32 matmuls to use faster reduced-precision kernels (e.g. TF32) when available.
torch.set_float32_matmul_precision("high")

# BiRefNet segmentation model pulled from the Hugging Face Hub.
# NOTE(review): trust_remote_code runs code from the model repo; consider
# pinning `revision=` to a known-good commit.
birefnet = AutoModelForImageSegmentation.from_pretrained(
    "ZhengPeng7/BiRefNet",
    trust_remote_code=True,
)

# Use the GPU when torch can see one, otherwise run on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
birefnet.to(device)

# Model preprocessing: fixed 1024x1024 resize, PIL -> tensor, then
# per-channel normalization with the standard ImageNet mean/std.
transform_image = transforms.Compose(
    [
        transforms.Resize((1024, 1024)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)
@spaces.GPU  # on ZeroGPU hardware a GPU is only attached while this runs
def process(image: Image.Image, threshold: int = 127) -> Image.Image:
    """Replace the background of *image* with solid white.

    The image is segmented with BiRefNet. The predicted foreground mask is
    resized back to the input size and binarized, and pixels outside the
    mask are filled with white.

    Args:
        image: Input picture; converted to RGB first so that the 3-channel
            normalization and the RGB composite below always line up.
        threshold: Mask cut-off on the 0-255 scale. Mask values strictly
            above it keep the original pixel; all others become white.

    Returns:
        A new RGB image the same size as *image*.
    """
    # Non-RGB inputs (RGBA, L, P, ...) would break Normalize and composite.
    image = image.convert("RGB")
    image_size = image.size

    input_tensor = transform_image(image).unsqueeze(0).to(device)
    with torch.no_grad():
        # The last model output is used as the mask logits; sigmoid maps them to [0, 1].
        preds = birefnet(input_tensor)[-1].sigmoid().cpu()
    pred = preds[0].squeeze()

    # Float tensor in [0, 1] -> 8-bit grayscale mask at the original resolution.
    mask = transforms.ToPILImage()(pred).resize(image_size).convert("L")
    # Hard edges: every mask pixel becomes fully foreground or fully background.
    binary_mask = mask.point(lambda p: 255 if p > threshold else 0)

    white_bg = Image.new("RGB", image_size, (255, 255, 255))
    return Image.composite(image, white_bg, binary_mask)
_logger = logging.getLogger(__name__)


def _save_processed(im: "Image.Image") -> str:
    """Remove the background of *im* and save it as a uniquely named PNG.

    The file is written relative to the current working directory, and that
    relative filename is returned.
    """
    processed = process(im.convert("RGB"))
    filename = f"output_{uuid.uuid4().hex[:8]}.png"
    processed.save(filename)
    return filename


def handler(image=None, image_url=None, batch_urls=None) -> Union[str, List[str], None]:
    """Gradio/MCP entry point: white-fill the background of one or more images.

    Inputs are checked in priority order, and the first non-empty one wins:
    the uploaded ``image``, then ``image_url``, then the comma-separated
    ``batch_urls``.

    Returns:
        The saved PNG filename for a single image or URL. For a batch, a list
        of filenames; URLs that fail are logged and skipped. None when no
        usable input was given, every batch URL failed, or processing raised.
    """
    # NOTE(review): failures are logged rather than raised (e.g. as gr.Error),
    # so the UI only sees an empty result; `show_error=True` cannot surface them.
    try:
        if image is not None:
            return _save_processed(image)

        # Strip so a whitespace-only URL box counts as empty instead of being
        # handed to load_img.
        url = image_url.strip() if image_url else ""
        if url:
            return _save_processed(load_img(url, output_type="pil"))

        if batch_urls:
            urls = [u.strip() for u in batch_urls.split(",") if u.strip()]
            results: List[str] = []
            for batch_url in urls:
                try:
                    results.append(_save_processed(load_img(batch_url, output_type="pil")))
                except Exception:
                    # Best effort: one bad URL must not sink the whole batch.
                    _logger.exception("Error with %s", batch_url)
            return results or None
    except Exception:
        _logger.exception("General error")
    return None
def _build_demo() -> gr.Interface:
    """Assemble the Gradio interface exposing `handler`.

    Inputs are, in order: an uploaded image, a single image URL, and a
    comma-separated batch of URLs. The output lists the saved file(s).
    """
    upload_box = gr.Image(label="Upload Image", type="pil")
    url_box = gr.Textbox(label="Paste Image URL")
    batch_box = gr.Textbox(label="Comma-separated Image URLs (Batch)")
    files_out = gr.File(label="Output File(s)", file_count="multiple")
    return gr.Interface(
        fn=handler,
        inputs=[upload_box, url_box, batch_box],
        outputs=files_out,
        title="Background Remover (White Fill)",
        description="Upload an image, paste a URL, or send a batch of URLs to remove the background and replace it with white.",
    )


demo = _build_demo()

if __name__ == "__main__":
    # show_error forwards exceptions to the UI; mcp_server also exposes the app as an MCP tool.
    demo.launch(show_error=True, mcp_server=True)