AkashKumarave commited on
Commit
42af69b
·
verified ·
1 Parent(s): 4efd361

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -16
app.py CHANGED
@@ -6,32 +6,54 @@ from PIL import Image
6
  import io
7
  import base64
8
  import requests
 
9
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
10
 
11
  # Download pre-trained DIS (IS-Net) weights
12
  def download_weights():
13
- url = "https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth"
14
- response = requests.get(url)
15
- with open("isnet-general-use.pth", "wb") as f:
16
- f.write(response.content)
 
 
 
 
 
 
 
 
17
 
18
- # DIS (IS-Net) model definition (simplified)
19
  class ISNet(torch.nn.Module):
20
  def __init__(self):
21
  super(ISNet, self).__init__()
22
- # Placeholder for model architecture (simplified)
23
- # In practice, use the full IS-Net architecture from https://github.com/xuebinqin/DIS
24
- self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
25
- # Load actual weights
26
- self.load_state_dict(torch.load("isnet-general-use.pth", map_location="cpu"))
 
 
27
 
28
  def forward(self, x):
29
- # Simplified forward pass (replace with actual IS-Net forward)
30
- return torch.sigmoid(self.conv(x))
 
 
 
 
 
31
 
32
  # Initialize model
33
- download_weights()
34
- model = ISNet().eval()
 
 
 
 
 
 
35
 
36
  def remove_background(image):
37
  """
@@ -40,6 +62,10 @@ def remove_background(image):
40
  Output: Base64-encoded PNG with transparent background
41
  """
42
  try:
 
 
 
 
43
  # Preprocess image
44
  transform = Compose([
45
  Resize((1024, 1024)),
@@ -77,8 +103,8 @@ iface = gr.Interface(
77
  fn=remove_background,
78
  inputs=gr.Image(type="pil", label="Upload Image"),
79
  outputs=gr.Image(type="pil", label="Image with Background Removed"),
80
- title="Background Removal with DIS (IS-Net)",
81
- description="Upload an image to remove its background using the open-source DIS (IS-Net) model.",
82
  allow_flagging="never"
83
  )
84
 
 
6
  import io
7
  import base64
8
  import requests
9
+ import os
10
  from torchvision.transforms import Compose, Resize, ToTensor, Normalize
11
 
12
  # Download pre-trained DIS (IS-Net) weights
13
  def download_weights():
14
+ weights_path = "isnet-general-use.pth"
15
+ if not os.path.exists(weights_path):
16
+ url = "https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth"
17
+ try:
18
+ response = requests.get(url, stream=True)
19
+ response.raise_for_status()
20
+ with open(weights_path, "wb") as f:
21
+ for chunk in response.iter_content(chunk_size=8192):
22
+ f.write(chunk)
23
+ except Exception as e:
24
+ raise Exception(f"Failed to download weights: {str(e)}")
25
+ return weights_path
26
 
27
+ # DIS (IS-Net) model architecture (simplified from https://github.com/xuebinqin/DIS)
28
  class ISNet(torch.nn.Module):
29
  def __init__(self):
30
  super(ISNet, self).__init__()
31
+ # Simplified architecture (for demonstration; replace with full IS-Net)
32
+ # Full architecture: https://github.com/xuebinqin/DIS/blob/main/ISNet.py
33
+ self.conv1 = torch.nn.Conv2d(3, 64, kernel_size=3, padding=1)
34
+ self.pool = torch.nn.MaxPool2d(2, 2)
35
+ self.conv2 = torch.nn.Conv2d(64, 128, kernel_size=3, padding=1)
36
+ self.upconv = torch.nn.ConvTranspose2d(128, 64, kernel_size=2, stride=2)
37
+ self.conv3 = torch.nn.Conv2d(64, 1, kernel_size=3, padding=1)
38
 
39
  def forward(self, x):
40
+ # Simplified forward pass (replace with full IS-Net forward)
41
+ x = torch.relu(self.conv1(x))
42
+ x = self.pool(x)
43
+ x = torch.relu(self.conv2(x))
44
+ x = self.upconv(x)
45
+ x = torch.sigmoid(self.conv3(x))
46
+ return x
47
 
48
  # Initialize model
49
+ try:
50
+ weights_path = download_weights()
51
+ model = ISNet()
52
+ state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
53
+ model.load_state_dict(state_dict)
54
+ model.eval()
55
+ except Exception as e:
56
+ raise Exception(f"Model initialization failed: {str(e)}")
57
 
58
  def remove_background(image):
59
  """
 
62
  Output: Base64-encoded PNG with transparent background
63
  """
64
  try:
65
+ # Ensure image is RGB
66
+ if image.mode != "RGB":
67
+ image = image.convert("RGB")
68
+
69
  # Preprocess image
70
  transform = Compose([
71
  Resize((1024, 1024)),
 
103
  fn=remove_background,
104
  inputs=gr.Image(type="pil", label="Upload Image"),
105
  outputs=gr.Image(type="pil", label="Image with Background Removed"),
106
+ title="DIS Background Removal",
107
+ description="Remove backgrounds from any image using the open-source DIS (IS-Net) model.",
108
  allow_flagging="never"
109
  )
110