Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
|
@@ -6,32 +6,54 @@ from PIL import Image
|
|
| 6 |
import io
|
| 7 |
import base64
|
| 8 |
import requests
|
|
|
|
| 9 |
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
|
| 10 |
|
| 11 |
# Download pre-trained DIS (IS-Net) weights
|
| 12 |
def download_weights():
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
-
# DIS (IS-Net) model
|
| 19 |
class ISNet(torch.nn.Module):
|
| 20 |
def __init__(self):
|
| 21 |
super(ISNet, self).__init__()
|
| 22 |
-
#
|
| 23 |
-
#
|
| 24 |
-
self.
|
| 25 |
-
|
| 26 |
-
self.
|
|
|
|
|
|
|
| 27 |
|
| 28 |
def forward(self, x):
|
| 29 |
-
# Simplified forward pass (replace with
|
| 30 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 31 |
|
| 32 |
# Initialize model
|
| 33 |
-
|
| 34 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 35 |
|
| 36 |
def remove_background(image):
|
| 37 |
"""
|
|
@@ -40,6 +62,10 @@ def remove_background(image):
|
|
| 40 |
Output: Base64-encoded PNG with transparent background
|
| 41 |
"""
|
| 42 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 43 |
# Preprocess image
|
| 44 |
transform = Compose([
|
| 45 |
Resize((1024, 1024)),
|
|
@@ -77,8 +103,8 @@ iface = gr.Interface(
|
|
| 77 |
fn=remove_background,
|
| 78 |
inputs=gr.Image(type="pil", label="Upload Image"),
|
| 79 |
outputs=gr.Image(type="pil", label="Image with Background Removed"),
|
| 80 |
-
title="Background Removal
|
| 81 |
-
description="
|
| 82 |
allow_flagging="never"
|
| 83 |
)
|
| 84 |
|
|
|
|
| 6 |
import io
|
| 7 |
import base64
|
| 8 |
import requests
|
| 9 |
+
import os
|
| 10 |
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
|
| 11 |
|
| 12 |
# Download pre-trained DIS (IS-Net) weights
|
| 13 |
def download_weights():
|
| 14 |
+
weights_path = "isnet-general-use.pth"
|
| 15 |
+
if not os.path.exists(weights_path):
|
| 16 |
+
url = "https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth"
|
| 17 |
+
try:
|
| 18 |
+
response = requests.get(url, stream=True)
|
| 19 |
+
response.raise_for_status()
|
| 20 |
+
with open(weights_path, "wb") as f:
|
| 21 |
+
for chunk in response.iter_content(chunk_size=8192):
|
| 22 |
+
f.write(chunk)
|
| 23 |
+
except Exception as e:
|
| 24 |
+
raise Exception(f"Failed to download weights: {str(e)}")
|
| 25 |
+
return weights_path
|
| 26 |
|
| 27 |
+
# DIS (IS-Net) model architecture (simplified from https://github.com/xuebinqin/DIS)
|
| 28 |
class ISNet(torch.nn.Module):
|
| 29 |
def __init__(self):
|
| 30 |
super(ISNet, self).__init__()
|
| 31 |
+
# Simplified architecture (for demonstration; replace with full IS-Net)
|
| 32 |
+
# Full architecture: https://github.com/xuebinqin/DIS/blob/main/ISNet.py
|
| 33 |
+
self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
|
| 34 |
+
self.pool = torch.nn.MaxPool2d(2, 2)
|
| 35 |
+
self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
|
| 36 |
+
self.upconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
|
| 37 |
+
self.conv3 = torch.nn.Conv2d(64, 1, kernel_size=3, padding=1)
|
| 38 |
|
| 39 |
def forward(self, x):
|
| 40 |
+
# Simplified forward pass (replace with full IS-Net forward)
|
| 41 |
+
x = torch.relu(self.conv1(x))
|
| 42 |
+
x = self.pool(x)
|
| 43 |
+
x = torch.relu(self.conv2(x))
|
| 44 |
+
x = self.upconv(x)
|
| 45 |
+
x = torch.sigmoid(self.conv3(x))
|
| 46 |
+
return x
|
| 47 |
|
| 48 |
# Initialize model
|
| 49 |
+
try:
|
| 50 |
+
weights_path = download_weights()
|
| 51 |
+
model = ISNet()
|
| 52 |
+
state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
|
| 53 |
+
model.load_state_dict(state_dict)
|
| 54 |
+
model.eval()
|
| 55 |
+
except Exception as e:
|
| 56 |
+
raise Exception(f"Model initialization failed: {str(e)}")
|
| 57 |
|
| 58 |
def remove_background(image):
|
| 59 |
"""
|
|
|
|
| 62 |
Output: Base64-encoded PNG with transparent background
|
| 63 |
"""
|
| 64 |
try:
|
| 65 |
+
# Ensure image is RGB
|
| 66 |
+
if image.mode != "RGB":
|
| 67 |
+
image = image.convert("RGB")
|
| 68 |
+
|
| 69 |
# Preprocess image
|
| 70 |
transform = Compose([
|
| 71 |
Resize((1024, 1024)),
|
|
|
|
| 103 |
fn=remove_background,
|
| 104 |
inputs=gr.Image(type="pil", label="Upload Image"),
|
| 105 |
outputs=gr.Image(type="pil", label="Image with Background Removed"),
|
| 106 |
+
title="DIS Background Removal",
|
| 107 |
+
description="Remove backgrounds from any image using the open-source DIS (IS-Net) model.",
|
| 108 |
allow_flagging="never"
|
| 109 |
)
|
| 110 |
|