ak / app.py
AkashKumarave's picture
Create app.py
6779002 verified
raw
history blame
2.86 kB
# app.py
import gradio as gr
import torch
import numpy as np
from PIL import Image
import io
import base64
import requests
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
# Download pre-trained DIS (IS-Net) weights
def download_weights(
    url="https://github.com/xuebinqin/DIS/releases/download/v1.0/isnet-general-use.pth",
    dest="isnet-general-use.pth",
    timeout=60,
):
    """Download the IS-Net checkpoint to *dest* unless it is already present.

    The response is streamed to ``<dest>.part`` and renamed into place only
    after the whole body has been written. An interrupted download therefore
    never leaves a truncated checkpoint that a later run would mistake for a
    complete one.

    Args:
        url: Location of the checkpoint. NOTE(review): confirm this release
            asset actually exists; the URL is not verified here.
        dest: Local path the checkpoint is written to.
        timeout: Seconds ``requests`` waits to connect and between received
            bytes before giving up, so application startup cannot hang forever.

    Returns:
        *dest*, the path of the existing or freshly downloaded checkpoint.

    Raises:
        requests.HTTPError: if the server answers with a 4xx/5xx status,
            instead of silently saving the error page as the checkpoint.
        requests.RequestException: on connection errors or timeouts.
    """
    import os

    # Avoid re-downloading on every restart when a non-empty file is there.
    if os.path.exists(dest) and os.path.getsize(dest) > 0:
        return dest
    partial = f"{dest}.part"
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        with open(partial, "wb") as f:
            # Stream in 1 MiB chunks rather than buffering the whole file.
            for chunk in response.iter_content(chunk_size=1 << 20):
                f.write(chunk)
    # Atomic rename: *dest* only ever appears once it is complete.
    os.replace(partial, dest)
    return dest
# DIS (IS-Net) model definition (simplified)
class ISNet(torch.nn.Module):
    """Placeholder segmentation network standing in for DIS / IS-Net.

    WARNING: this is NOT the real IS-Net architecture. It is a single 3x3
    convolution (3 input channels -> 1 output channel) followed by a sigmoid.
    The full model lives at https://github.com/xuebinqin/DIS and has to be
    ported before a real IS-Net checkpoint can be loaded into it.

    Args:
        weights_path: Checkpoint (a state dict) to load strictly into this
            module. The default, "isnet-general-use.pth", matches the file name
            ``download_weights()`` writes to by default. ``None`` skips loading
            and keeps PyTorch's random initialization.

    Raises:
        RuntimeError: if the checkpoint's parameter names/shapes do not match
            this module. This is expected for the official IS-Net checkpoint
            until the full architecture is ported.
    """

    def __init__(self, weights_path="isnet-general-use.pth"):
        super().__init__()
        # TODO: replace with the full IS-Net (ISNetDIS) architecture.
        self.conv = torch.nn.Conv2d(3, 1, kernel_size=3, padding=1)
        if weights_path is not None:
            self._load_weights(weights_path)

    def _load_weights(self, weights_path):
        """Strictly load the state dict stored at *weights_path* into self."""
        # weights_only=True refuses arbitrary pickled objects, so a tampered
        # or unexpected download cannot run code while being deserialized.
        state_dict = torch.load(weights_path, map_location="cpu", weights_only=True)
        try:
            self.load_state_dict(state_dict)
        except RuntimeError as err:
            # The stock error lists every missing/unexpected key; say why instead.
            raise RuntimeError(
                f"Checkpoint {weights_path!r} does not match the simplified "
                "ISNet placeholder architecture; port the full IS-Net model "
                "from https://github.com/xuebinqin/DIS to use these weights."
            ) from err

    def forward(self, x):
        """Return per-pixel foreground probabilities in [0, 1].

        Args:
            x: Float tensor, typically shaped ``(N, 3, H, W)``.

        Returns:
            Tensor shaped ``(N, 1, H, W)``. The padding of 1 preserves H and W,
            and the sigmoid maps values into [0, 1].
        """
        return torch.sigmoid(self.conv(x))
# Fetch the checkpoint before constructing the model, because ISNet() reads
# it from disk in its constructor. A single instance is then reused by every
# call to remove_background().
download_weights()
model = ISNet()
model.eval()  # inference mode: sets training=False across the module tree
def remove_background(image, threshold=0.5):
    """Remove the background from *image* using the module-level ``model``.

    Args:
        image: PIL image from the Gradio ``gr.Image(type="pil")`` input, or
            ``None`` when the user submits without uploading anything.
        threshold: Model output above which a pixel is kept as foreground
            (fully opaque). Every other pixel becomes fully transparent.

    Returns:
        An RGBA ``PIL.Image.Image`` the same size as *image*, whose alpha
        channel is the resized binary foreground mask, or ``None`` when no
        image was given. Returning an image object (rather than a base64
        data-URI string) is what lets the ``gr.Image`` output component
        render the result.

    Raises:
        gr.Error: wrapping any failure during preprocessing, inference or
            post-processing, so Gradio shows the message in the UI instead of
            trying to render an error string as an image.
    """
    if image is None:
        return None
    try:
        transform = Compose([
            Resize((1024, 1024)),
            ToTensor(),
            # NOTE(review): the reference DIS inference code is believed to
            # normalize with mean=0.5, std=1.0. Revisit this once the real
            # IS-Net architecture is ported.
            Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
        ])
        # Force 3 channels: the model expects RGB input (see ISNet), and RGBA,
        # grayscale or palette images would otherwise fail inside the model.
        img_tensor = transform(image.convert("RGB")).unsqueeze(0)
        with torch.no_grad():
            probs = model(img_tensor).squeeze().cpu().numpy()
        # Binarize at model resolution, then scale the mask back to the input size.
        mask = (probs > threshold).astype(np.uint8) * 255
        alpha = Image.fromarray(mask).resize(image.size, Image.LANCZOS)
        result = image.convert("RGBA")
        result.putalpha(alpha)  # replace the alpha band with the mask
        return result
    except Exception as e:
        # Top-level UI boundary: surface the failure to the user via Gradio.
        raise gr.Error(f"Background removal failed: {e}") from e
# Gradio UI: one uploaded PIL image in, one PIL image out, flagging disabled.
_upload = gr.Image(type="pil", label="Upload Image")
_cutout = gr.Image(type="pil", label="Image with Background Removed")
iface = gr.Interface(
    fn=remove_background,
    inputs=_upload,
    outputs=_cutout,
    title="Background Removal with DIS (IS-Net)",
    description=(
        "Upload an image to remove its background using the open-source "
        "DIS (IS-Net) model."
    ),
    allow_flagging="never",
)
# Serve the UI only when this file is run as a script, not when imported.
if __name__ == "__main__":
    # "0.0.0.0" binds every network interface, so the app is reachable from
    # other machines and not just from localhost.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface.launch(**launch_options)