# pipeline2 / app.py
# Janeka's picture
# Update app.py
# 9e721e5 verified
import gradio as gr
from rembg import remove
from PIL import Image
import numpy as np
import torch
import os
import sys
import requests
from tqdm import tqdm
import subprocess
# Clone repository if not present.
# A failed clone is reported but not fatal: the InSPyReNet import further down
# is wrapped in try/except, and the app falls back to Rembg without it.
if not os.path.exists('InSPyReNet'):
    print("Cloning InSPyReNet repository...")
    try:
        subprocess.run(
            ['git', 'clone', '--depth', '1', 'https://github.com/plemeri/InSPyReNet.git'],
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as exc:
        # OSError covers a missing `git` executable (FileNotFoundError);
        # CalledProcessError covers a non-zero exit (e.g. no network).
        print(f"Failed to clone InSPyReNet repository: {exc}")
# Set up correct Python paths: make the repository root, lib/ and utils/
# importable. Each insert goes to the front, so the final search order is
# utils, lib, root.
sys.path.insert(0, os.path.abspath('InSPyReNet'))
sys.path.insert(0, os.path.abspath('InSPyReNet/lib'))
sys.path.insert(0, os.path.abspath('InSPyReNet/utils'))
# Download model weights
def download_file(url, filename, chunk_size=64 * 1024, timeout=60):
    """Stream *url* to *filename* while showing a tqdm progress bar.

    The body is first written to ``<filename>.part`` and renamed onto
    *filename* only after the whole response has been received. A failed or
    interrupted download therefore never leaves a truncated file at
    *filename*. That matters because the caller uses the file's mere
    existence to decide whether to download again.

    Args:
        url: Address to fetch.
        filename: Destination path; also used as the progress-bar label.
        chunk_size: Bytes requested per read from the response stream.
        timeout: Seconds passed to ``requests.get`` as its timeout.

    Raises:
        requests.HTTPError: The server answered with an error status.
        requests.RequestException: Any other network failure.
    """
    partial_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this check an HTML error page (e.g. a 404) would be
            # saved under *filename* as if it were the real file.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(partial_path, 'wb') as f, tqdm(
                desc=filename,
                total=total_size or None,  # None tells tqdm the size is unknown
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    bar.update(f.write(data))
        # Publish the file only once it is complete.
        os.replace(partial_path, filename)
    finally:
        # The partial file only survives to here if the download failed.
        if os.path.exists(partial_path):
            os.remove(partial_path)


if not os.path.exists('InSPyReNet.pth'):
    print("Downloading model weights...")
    try:
        download_file(
            "https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth",
            "InSPyReNet.pth"
        )
    except requests.RequestException as e:
        # Non-fatal: without weights the InSPyReNet load fails and the app
        # uses Rembg instead.
        print(f"Failed to download model weights: {e}")
# Import after setting up environment (relies on the sys.path entries added
# for the cloned repository). Any failure here is non-fatal: HAS_INSPYRE is
# set to False and remove_background falls back to Rembg.
device = torch.device('cpu')
model = None  # replaced only when every loading step below succeeds
try:
    from InSPyReNet import InSPyReNet
    from modules.layers import load_model
    from utils.misc import load_config
    # Initialize model
    print("Loading model...")
    cfg = load_config('InSPyReNet/configs/InSPyReNet_SwinB.yaml')
    net = InSPyReNet(cfg)
    net = load_model(net, 'InSPyReNet.pth', device)
    net.eval()
    model = net
    HAS_INSPYRE = True
except Exception as e:
    import traceback
    print(f"Failed to load InSPyReNet: {str(e)}")
    # The message alone rarely says which import or loading step failed.
    traceback.print_exc()
    HAS_INSPYRE = False
def preprocess(image):
    """Convert an image into a mean-subtracted ``(1, 3, H, W)`` float32 tensor.

    Args:
        image: A PIL image in any mode (converted to RGB), or a numpy array of
            shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or ``(H, W, 4)``.
            Single-channel input is replicated to three channels and a fourth
            (alpha) channel is dropped. This guarantees the 3-element mean
            subtraction and the channel-first transpose always see exactly
            three channels. The input is never modified.

    Returns:
        torch.Tensor of shape ``(1, 3, H, W)`` and dtype float32.
    """
    if isinstance(image, np.ndarray):
        array = image.astype(np.float32)  # astype copies; caller's array untouched
    elif isinstance(image, Image.Image):
        # Normalize RGBA / L / P / CMYK / ... to plain RGB.
        array = np.array(image.convert('RGB'), dtype=np.float32)
    else:
        array = np.array(image).astype(np.float32)

    if array.ndim == 2:
        array = np.repeat(array[:, :, np.newaxis], 3, axis=2)
    elif array.ndim == 3 and array.shape[2] == 1:
        array = np.repeat(array, 3, axis=2)
    elif array.ndim == 3 and array.shape[2] == 4:
        array = array[:, :, :3]

    # NOTE(review): these constants look like the BGR-ordered Caffe/VGG channel
    # means, while the array here is in RGB order. Confirm against the model's
    # training preprocessing.
    array -= np.array([104.00699, 116.66877, 122.67892])
    array = array.transpose((2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(array).unsqueeze(0)  # add batch dimension
def process_with_inspyrenet(image):
    """Predict a binary foreground mask for *image* with the loaded InSPyReNet model.

    Args:
        image: Input image; passed unchanged to ``preprocess``.

    Returns:
        numpy uint8 array: 255 where the squeezed model output is above 0.5,
        0 elsewhere.

    Raises:
        RuntimeError: If the model failed to load at startup (``HAS_INSPYRE``
            is False).
    """
    if HAS_INSPYRE:
        batch = preprocess(image).to(device)
        with torch.no_grad():
            scores = model(batch)
        # NOTE(review): the 0.5 cut-off assumes the model emits probabilities
        # in [0, 1] rather than logits; confirm.
        foreground = scores.squeeze().cpu().numpy() > 0.5
        return foreground.astype(np.uint8) * 255
    raise RuntimeError("InSPyReNet failed to load")
def remove_background(input_image, model_choice="Rembg (U²-Net)"):
    """Cut the background out of *input_image*.

    Args:
        input_image: PIL image or numpy array (wrapped with
            ``Image.fromarray``). ``None``, e.g. when the form is submitted
            without an upload, returns ``(None, None)`` immediately.
        model_choice: ``"InSPyReNet"`` uses InSPyReNet when it loaded at
            startup; any other value uses Rembg. If InSPyReNet is selected but
            failed to load, or raises during inference, Rembg is used instead.

    Returns:
        ``(output, mask)``. *output* is the image with an alpha channel.
        *mask* is that alpha channel as an ``'L'`` image, or a fully opaque
        ``'L'`` image when *output* has no alpha band. Returns
        ``(None, None)`` if processing raised; the error and traceback are
        printed.
    """
    import traceback  # local import keeps this function self-contained

    if input_image is None:
        return None, None
    try:
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)

        output = None
        if model_choice == "InSPyReNet":
            if HAS_INSPYRE:
                try:
                    mask = Image.fromarray(process_with_inspyrenet(input_image))
                    if mask.size != input_image.size:
                        # putalpha() requires equal sizes, and nothing in this
                        # file pins the model's output resolution (TODO confirm).
                        mask = mask.resize(input_image.size, Image.NEAREST)
                    output = input_image.copy()
                    output.putalpha(mask)
                except Exception:
                    traceback.print_exc()
                    print("Falling back to Rembg due to InSPyReNet inference error")
                    output = None  # discard a half-built result
            else:
                print("Falling back to Rembg due to InSPyReNet loading error")
        if output is None:
            output = remove(input_image)

        # Read the alpha band from any alpha-carrying mode (RGBA, LA, PA).
        # Checking only for 'RGBA' would return an all-white mask for
        # grayscale input, whose putalpha() result is 'LA'.
        if 'A' in output.getbands():
            mask = output.getchannel('A')
        else:
            mask = Image.new('L', output.size, 255)
        return output, mask
    except Exception as e:
        print(f"Error: {str(e)}")
        traceback.print_exc()
        return None, None
_MODEL_CHOICES = ["Rembg (U²-Net)", "InSPyReNet"]

# The two inputs (image + model picker) feed remove_background's two
# parameters, and its (output, mask) pair fills the two output images in order.
iface = gr.Interface(
    fn=remove_background,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            choices=_MODEL_CHOICES,
            value=_MODEL_CHOICES[0],
            label="Model Selection",
        ),
    ],
    outputs=[gr.Image(type="pil", label=caption) for caption in ("Result", "Mask")],
    title="Professional Background Remover",
    description="Choose between Rembg (faster) or InSPyReNet (higher quality)",
)

if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0) at port 7860.
    iface.launch(server_name="0.0.0.0", server_port=7860)