# Captured from a Hugging Face Spaces page (Space status at capture time: Sleeping).
# app.py
import base64
import io
import os

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from torchvision.transforms import Compose, Normalize, Resize, ToTensor
# Download pre-trained DIS (IS-Net) weights
def download_weights(
    url="https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth",
    path="isnet-general-use.pth",
    *,
    timeout=60,
    force=False,
):
    """Fetch the IS-Net checkpoint to *path* unless a non-empty copy already exists.

    Args:
        url: Location of the checkpoint file.
        path: Local destination file.
        timeout: Seconds passed to ``requests.get`` for connect/read timeouts.
        force: Re-download even if *path* already exists.

    Returns:
        *path*, so callers can chain it into a loader.

    Raises:
        requests.HTTPError: the server answered with an error status (previously
            the error body would have been saved as the weights file).
        requests.RequestException: network failure or timeout.
    """
    # Skip the network entirely on restarts once the weights are cached.
    if not force and os.path.isfile(path) and os.path.getsize(path) > 0:
        return path

    # Stream into a temporary file and rename at the end, so an interrupted
    # download never leaves a truncated checkpoint at *path*.
    tmp_path = f"{path}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        os.replace(tmp_path, path)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return path
# DIS (IS-Net) model definition (simplified)
class ISNet(torch.nn.Module):
    """Placeholder for the DIS IS-Net segmentation network.

    NOTE(review): this is NOT the real IS-Net architecture. It is a single 3x3
    convolution followed by a sigmoid. A checkpoint trained for the full
    network cannot be loaded into it. For the real model, port the architecture
    from https://github.com/xuebinqin/DIS.
    """

    def __init__(self, weights_path="isnet-general-use.pth"):
        """Build the layers and optionally load weights.

        Args:
            weights_path: Checkpoint containing a state dict for this module, or
                ``None`` to keep the randomly initialised layers.

        Raises:
            RuntimeError: the checkpoint's keys/shapes do not match this module.
        """
        super().__init__()
        # In practice, use the full IS-Net architecture from https://github.com/xuebinqin/DIS
        # 3 input channels -> 1 output map; padding=1 keeps H and W for a 3x3 kernel.
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
        if weights_path is not None:
            self._load_weights(weights_path)

    def _load_weights(self, weights_path):
        """Load a state dict from *weights_path* into this module (strict)."""
        # weights_only=True restricts unpickling to tensors and plain containers,
        # so a tampered download cannot execute arbitrary code while loading.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        try:
            self.load_state_dict(state_dict)
        except RuntimeError as err:
            raise RuntimeError(
                f"Checkpoint {weights_path!r} does not match this placeholder ISNet "
                "architecture; port the full IS-Net model from "
                "https://github.com/xuebinqin/DIS to load it."
            ) from err

    def forward(self, x):
        """Return a per-pixel probability map with the same H and W as *x*.

        ``x`` must have 3 channels (the conv's in_channels); assumes the layout
        is (N, 3, H, W). The output is (N, 1, H, W) with values in (0, 1).
        """
        # Simplified forward pass (replace with actual IS-Net forward)
        return torch.sigmoid(self.conv(x))
# Set up the model once, at import time: make sure the checkpoint is on disk,
# build the network, then switch it to inference mode (eval() mutates in place).
download_weights()
model = ISNet()
model.eval()
def remove_background(image):
    """Remove the background of *image* using the module-level ``model``.

    The image is converted to RGB, resized to 1024x1024, normalised and passed
    through ``model``. The predicted map is thresholded at 0.5, resized back to
    the input size, and used as the alpha channel of an RGBA copy.

    Args:
        image: PIL image from the Gradio input, in any mode (RGBA, L, P, ...).
            ``None`` when the user submitted without uploading anything.

    Returns:
        ``"data:image/png;base64,..."`` on success. On failure, a string
        starting with ``"Error: "``.
    """
    if image is None:
        return "Error: no image provided"
    try:
        # The network's first conv expects exactly 3 channels; RGBA, grayscale
        # and palette uploads would otherwise produce 4- or 1-channel tensors.
        rgb = image.convert("RGB")
        transform = Compose([
            Resize((1024, 1024)),
            ToTensor(),
            # NOTE(review): the DIS reference inference script appears to use
            # std 1.0 here, not 0.5 — confirm before loading the real IS-Net.
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        img_tensor = transform(rgb).unsqueeze(0)

        with torch.no_grad():
            # Assumes model returns a single (1, 1, H, W) map in [0, 1], as the
            # placeholder ISNet does. The full IS-Net may return side outputs
            # instead — TODO confirm when swapping it in.
            prediction = model(img_tensor).squeeze().cpu().numpy()

        # Binary foreground mask, scaled back to the original image size.
        mask = Image.fromarray((prediction > 0.5).astype(np.uint8) * 255)
        mask = mask.resize(image.size, Image.LANCZOS)

        img_array = np.array(image.convert("RGBA"))
        img_array[:, :, 3] = np.asarray(mask)
        result = Image.fromarray(img_array)

        buffered = io.BytesIO()
        result.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        # Failures are reported in-band as "Error: ..." strings for the UI.
        return f"Error: {e}"
_DATA_URI_PREFIX = "data:image/png;base64,"


def _remove_background_for_ui(image):
    """Gradio callback: run ``remove_background`` and return a displayable image.

    ``remove_background`` returns a PNG data URI on success and an
    ``"Error: ..."`` string on failure. A ``gr.Image`` output accepts a PIL
    image, array, or file path — not such strings — so the URI is decoded here,
    and failures are raised as ``gr.Error`` so the message appears in the UI.
    """
    result = remove_background(image)
    if not isinstance(result, str) or not result.startswith(_DATA_URI_PREFIX):
        raise gr.Error(str(result))
    png_bytes = base64.b64decode(result[len(_DATA_URI_PREFIX):])
    decoded = Image.open(io.BytesIO(png_bytes))
    decoded.load()  # decode now, rather than lazily after this function returns
    return decoded


# Create Gradio interface
iface = gr.Interface(
    fn=_remove_background_for_ui,
    inputs=gr.Image(type="pil", label="Upload Image"),
    outputs=gr.Image(type="pil", label="Image with Background Removed"),
    title="Background Removal with DIS (IS-Net)",
    description="Upload an image to remove its background using the open-source DIS (IS-Net) model.",
    allow_flagging="never"
)

# Launch the interface (listen on all interfaces, port 7860)
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)