AkashKumarave commited on
Commit
6779002
·
verified ·
1 Parent(s): c658d28

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +87 -0
app.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # app.py
2
+ import gradio as gr
3
+ import torch
4
+ import numpy as np
5
+ from PIL import Image
6
+ import io
7
+ import base64
8
+ import requests
9
+ from torchvision.transforms import Compose, Resize, ToTensor, Normalize
10
+
11
+ # Download pre-trained DIS (IS-Net) weights
12
+ def download_weights():
13
+ url = "https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth"
14
+ response = requests.get(url)
15
+ with open("isnet-general-use.pth", "wb") as f:
16
+ f.write(response.content)
17
+
18
+ # DIS (IS-Net) model definition (simplified)
19
+ class ISNet(torch.nn.Module):
20
+ def __init__(self):
21
+ super(ISNet, self).__init__()
22
+ # Placeholder for model architecture (simplified)
23
+ # In practice, use the full IS-Net architecture from https://github.com/xuebinqin/DIS
24
+ self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
25
+ # Load actual weights
26
+ self.load_state_dict(torch.load("isnet-general-use.pth", map_location="cpu"))
27
+
28
+ def forward(self, x):
29
+ # Simplified forward pass (replace with actual IS-Net forward)
30
+ return torch.sigmoid(self.conv(x))
31
+
32
+ # Initialize model
33
+ download_weights()
34
+ model = ISNet().eval()
35
+
36
+ def remove_background(image):
37
+ """
38
+ Remove background using DIS (IS-Net).
39
+ Input: PIL Image
40
+ Output: Base64-encoded PNG with transparent background
41
+ """
42
+ try:
43
+ # Preprocess image
44
+ transform = Compose([
45
+ Resize((1024, 1024)),
46
+ ToTensor(),
47
+ Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
48
+ ])
49
+ img_tensor = transform(image).unsqueeze(0)
50
+
51
+ # Run inference
52
+ with torch.no_grad():
53
+ mask = model(img_tensor).squeeze().cpu().numpy()
54
+
55
+ # Post-process mask
56
+ mask = (mask > 0.5).astype(np.uint8) * 255
57
+ mask = Image.fromarray(mask).resize(image.size, Image.LANCZOS)
58
+
59
+ # Apply mask
60
+ img_rgba = image.convert("RGBA")
61
+ img_array = np.array(img_rgba)
62
+ img_array[:, :, 3] = mask
63
+ result = Image.fromarray(img_array)
64
+
65
+ # Save to bytes buffer
66
+ buffered = io.BytesIO()
67
+ result.save(buffered, format="PNG")
68
+
69
+ # Encode as base64
70
+ img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
71
+ return f"data:image/png;base64,{img_str}"
72
+ except Exception as e:
73
+ return f"Error: {str(e)}"
74
+
75
+ # Create Gradio interface
76
+ iface = gr.Interface(
77
+ fn=remove_background,
78
+ inputs=gr.Image(type="pil", label="Upload Image"),
79
+ outputs=gr.Image(type="pil", label="Image with Background Removed"),
80
+ title="Background Removal with DIS (IS-Net)",
81
+ description="Upload an image to remove its background using the open-source DIS (IS-Net) model.",
82
+ allow_flagging="never"
83
+ )
84
+
85
+ # Launch the interface
86
+ if __name__ == "__main__":
87
+ iface.launch(server_name="0.0.0.0", server_port=7860)