# File size: 2,859 Bytes
# 6779002
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
# app.py
import gradio as gr
import torch
import numpy as np
from PIL import Image
import io
import base64
import requests
from torchvision.transforms import Compose, Resize, ToTensor, Normalize

# Download pre-trained DIS (IS-Net) weights
def download_weights(
    url="https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth",
    dest="isnet-general-use.pth",
    timeout=60,
):
    """Fetch the IS-Net checkpoint to ``dest`` unless it is already present.

    Args:
        url: Location of the ``.pth`` checkpoint.
        dest: Local file path the checkpoint is written to.
        timeout: Seconds passed to ``requests.get`` as its timeout, so a stalled
            server cannot hang application start-up forever.

    Returns:
        ``dest`` (the path of the checkpoint on disk).

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status, so an
            error page is never saved as if it were model weights.
    """
    import os

    # Skip the download on restarts once the checkpoint is already on disk.
    if os.path.exists(dest):
        return dest

    response = requests.get(url, stream=True, timeout=timeout)
    response.raise_for_status()
    # Write to a temporary sibling and rename only after the body has been fully
    # written, so an interrupted download never leaves a truncated file that the
    # existence check above would then accept.
    tmp_path = f"{dest}.part"
    with response, open(tmp_path, "wb") as f:
        for chunk in response.iter_content(chunk_size=1 << 20):
            f.write(chunk)
    os.replace(tmp_path, dest)
    return dest

# DIS (IS-Net) model definition (simplified)
class ISNet(torch.nn.Module):
    """Stand-in for the DIS IS-Net segmentation network.

    NOTE(review): this is a one-layer placeholder (a single 3->1 conv), not the
    real IS-Net from https://github.com/xuebinqin/DIS. A checkpoint trained for
    the real architecture will almost certainly not match these parameters, so
    the strict ``load_state_dict`` below is expected to fail until the full
    architecture is ported.

    Args:
        weights_path: Checkpoint loaded on construction. Pass ``None`` to skip
            loading (e.g. for tests, or when weights are loaded later).
    """

    def __init__(self, weights_path="isnet-general-use.pth"):
        super().__init__()
        # Placeholder for model architecture (simplified).
        # In practice, use the full IS-Net architecture from https://github.com/xuebinqin/DIS
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
        if weights_path is not None:
            self._load_weights(weights_path)

    def _load_weights(self, weights_path):
        """Load a state dict from ``weights_path`` (mapped to CPU) into this module.

        Raises:
            RuntimeError: if the checkpoint's keys/shapes do not match this module.
        """
        # SECURITY NOTE: torch.load unpickles the file. Only load checkpoints from
        # trusted sources (consider weights_only=True on torch >= 1.13).
        state_dict = torch.load(weights_path, map_location="cpu")
        try:
            self.load_state_dict(state_dict)
        except RuntimeError as err:
            raise RuntimeError(
                f"Checkpoint {weights_path!r} does not match the placeholder "
                "ISNet architecture; port the full IS-Net model from "
                "https://github.com/xuebinqin/DIS before loading real weights."
            ) from err

    def forward(self, x):
        """Return a per-pixel foreground probability map in [0, 1].

        Args:
            x: Image batch for a 3-in-channel conv, presumably (batch, 3, H, W).

        Returns:
            Tensor of shape (batch, 1, H, W); padding=1 with a 3x3 kernel keeps H, W.
        """
        # Simplified forward pass (replace with actual IS-Net forward)
        return torch.sigmoid(self.conv(x))

# Initialize model: make sure the checkpoint is on disk, then build the network
# and switch it to inference mode (eval() returns the same module instance).
download_weights()
model = ISNet()
model.eval()

def remove_background(image):
    """
    Remove background using DIS (IS-Net).

    Input: PIL Image in any mode; it is converted to RGB before inference
        because the model's first layer takes 3 channels.
    Output: Base64-encoded PNG data URI ("data:image/png;base64,...") whose
        alpha channel is the thresholded model mask, or an "Error: ..." string
        if anything fails.
    """
    if image is None:
        # Gradio passes None when the user submits without uploading.
        return "Error: no image provided"
    try:
        # Preprocess image. RGBA / grayscale / palette uploads are converted to
        # RGB first; otherwise ToTensor yields 4 or 1 channels and both the
        # 3-element Normalize and the 3-in-channel model fail.
        transform = Compose([
            Resize((1024, 1024)),
            ToTensor(),
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        img_tensor = transform(image.convert("RGB")).unsqueeze(0)

        # Run inference
        with torch.no_grad():
            mask = model(img_tensor).squeeze().cpu().numpy()

        # Post-process mask: binarize at 0.5, then scale back to the input size.
        mask = (mask > 0.5).astype(np.uint8) * 255
        mask = Image.fromarray(mask).resize(image.size, Image.LANCZOS)

        # Apply mask as the alpha channel of the original image.
        img_array = np.array(image.convert("RGBA"))
        img_array[:, :, 3] = np.asarray(mask)
        result = Image.fromarray(img_array)

        # Serialize to PNG and wrap as a base64 data URI.
        buffered = io.BytesIO()
        result.save(buffered, format="PNG")
        img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
        return f"data:image/png;base64,{img_str}"
    except Exception as e:
        # Top-level UI boundary: report any failure as text instead of crashing.
        # NOTE(review): the Gradio output component is a gr.Image; confirm it can
        # render a data-URI string (and this error string) as intended.
        return f"Error: {str(e)}"

# Create Gradio interface: one uploaded image in, the cut-out image out,
# with flagging disabled.
# NOTE(review): remove_background returns a base64 data-URI string (or an
# "Error: ..." string), while the output component is a gr.Image; confirm the
# installed Gradio version renders such strings.
iface = gr.Interface(
    remove_background,
    gr.Image(type="pil", label="Upload Image"),
    gr.Image(type="pil", label="Image with Background Removed"),
    allow_flagging="never",
    title="Background Removal with DIS (IS-Net)",
    description="Upload an image to remove its background using the open-source DIS (IS-Net) model.",
)

# Start the web server only when executed as a script, not on import;
# listen on all network interfaces, port 7860.
if __name__ == "__main__":
    iface.launch(server_port=7860, server_name="0.0.0.0")