File size: 3,825 Bytes
d554758 950af1e 4e6d0f7 950af1e d74c67e e50da0f 8fbeea7 a0d6a06 a0d7fb2 8fbeea7 9e721e5 8fbeea7 a0d7fb2 e50da0f 6f4b876 8fbeea7 e50da0f 8fbeea7 a0d6a06 8fbeea7 a0d7fb2 9e721e5 a0d6a06 e50da0f 8fbeea7 e50da0f 9e721e5 e50da0f d74c67e e50da0f d74c67e e50da0f 5f3e8c9 9e721e5 e50da0f 9e721e5 e50da0f a0d6a06 e50da0f 950af1e 5f3e8c9 8fbeea7 d74c67e 07d78f3 d74c67e e50da0f d74c67e 8fbeea7 d74c67e 8fbeea7 e50da0f d74c67e d554758 6d99820 d74c67e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 |
import gradio as gr
from rembg import remove
from PIL import Image
import numpy as np
import torch
import os
import sys
import requests
from tqdm import tqdm
import subprocess
# Clone the InSPyReNet repository if it is not already present. Failures are
# reported but not fatal: the model-loading try/except further down then
# disables InSPyReNet and remove_background() falls back to Rembg.
if not os.path.exists('InSPyReNet'):
    print("Cloning InSPyReNet repository...")
    try:
        clone_result = subprocess.run(
            ['git', 'clone', '--depth', '1', 'https://github.com/plemeri/InSPyReNet.git'],
            check=False,
        )
    except OSError as clone_error:
        # e.g. FileNotFoundError when the `git` executable is not installed.
        print(f"Failed to run git clone: {clone_error}")
    else:
        if clone_result.returncode != 0:
            print(f"git clone exited with status {clone_result.returncode}")
# Make the cloned repo importable. Each insert goes to the front of sys.path,
# so the resulting search order is InSPyReNet/utils, InSPyReNet/lib, then
# the repo root.
sys.path.insert(0, os.path.abspath('InSPyReNet'))
sys.path.insert(0, os.path.abspath('InSPyReNet/lib'))
sys.path.insert(0, os.path.abspath('InSPyReNet/utils'))
# Download model weights
def download_file(url, filename, chunk_size=1024, timeout=60):
    """Stream ``url`` to ``filename`` while showing a tqdm progress bar.

    The body is first written to ``<filename>.part`` and moved into place only
    after the whole response has been received. An interrupted or failed
    download therefore never leaves a truncated file at ``filename``, which
    would otherwise be mistaken for a finished download by existence checks.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path; replaced via ``os.replace`` on success.
        chunk_size: Bytes requested per ``iter_content`` iteration.
        timeout: Timeout in seconds passed to ``requests.get``.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On other network failures.
        OSError: If the file cannot be written or moved into place.
    """
    tmp_filename = os.fspath(filename) + '.part'
    # The context manager closes the streamed connection even on error.
    with requests.get(url, stream=True, timeout=timeout) as response:
        # Without this check an error page (e.g. a 404 body) would be saved
        # as the target file.
        response.raise_for_status()
        total_size = int(response.headers.get('content-length', 0))
        try:
            with open(tmp_filename, 'wb') as f, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    bar.update(f.write(data))
        except BaseException:
            # Drop the partial temp file so a retry starts clean.
            if os.path.exists(tmp_filename):
                os.remove(tmp_filename)
            raise
    os.replace(tmp_filename, filename)
# Fetch the pretrained weights once; later runs reuse the local file. A failed
# download is reported instead of crashing the app at startup. Model loading
# below presumably then fails inside its own try/except, and remove_background()
# falls back to Rembg.
if not os.path.exists('InSPyReNet.pth'):
    print("Downloading model weights...")
    try:
        download_file(
            "https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth",
            "InSPyReNet.pth"
        )
    except (requests.RequestException, OSError) as download_error:
        print(f"Failed to download model weights: {download_error}")
        # Never leave a partial file behind: its mere existence would skip the
        # download on every later start.
        try:
            os.remove('InSPyReNet.pth')
        except FileNotFoundError:
            pass
# Import after setting up environment
# Optional InSPyReNet backend: import the cloned repo's modules and build the
# model. Any exception here (missing repo, missing weights, API mismatch) is
# printed and sets HAS_INSPYRE = False, so the app keeps running with Rembg.
# `device` and `model` are only defined when this block succeeds; callers
# must check HAS_INSPYRE before using them (process_with_inspyrenet does).
try:
    # These names come from the cloned repo via the sys.path entries added
    # above; they cannot be imported before that setup runs.
    from InSPyReNet import InSPyReNet
    from modules.layers import load_model
    from utils.misc import load_config
    # Initialize model
    print("Loading model...")
    cfg = load_config('InSPyReNet/configs/InSPyReNet_SwinB.yaml')
    # Inference is pinned to the CPU.
    device = torch.device('cpu')
    # NOTE(review): the `InSPyReNet(cfg)` constructor and
    # `load_model(model, path, device)` signatures belong to the external
    # repo and are not visible here — confirm they match upstream. If they
    # don't, the except branch below silently disables this backend.
    model = InSPyReNet(cfg)
    model = load_model(model, 'InSPyReNet.pth', device)
    # Presumably a torch nn.Module: eval() switches it to inference mode.
    model.eval()
    HAS_INSPYRE = True
except Exception as e:
    # Deliberately broad: this is a startup boundary whose only job is to
    # decide whether the optional backend is available.
    print(f"Failed to load InSPyReNet: {str(e)}")
    HAS_INSPYRE = False
def preprocess(image, mean=(104.00699, 116.66877, 122.67892)):
    """Convert an image into a mean-subtracted float32 tensor of shape (1, 3, H, W).

    Args:
        image: A PIL image, or an array-like of shape (H, W), (H, W, 3) or
            (H, W, 4). Objects with a ``convert`` method (PIL images) are
            converted to RGB first, so RGBA, palette and grayscale uploads
            are accepted. 2-D arrays are replicated to three channels, and
            4-channel arrays have their last (alpha) channel dropped.
        mean: Per-channel values subtracted along the last (channel) axis.

    Returns:
        torch.Tensor with dtype float32 and shape (1, 3, H, W).
    """
    if hasattr(image, 'convert'):
        # PNG uploads are often RGBA or palette images. The 3-element mean
        # below only broadcasts against exactly three channels.
        image = image.convert('RGB')
    array = np.array(image).astype(np.float32)
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    elif array.ndim == 3 and array.shape[-1] == 4:
        array = array[..., :3]
    # NOTE(review): these default values look like the Caffe/VGG means in BGR
    # order, while the channels here are RGB. Confirm against the
    # normalization the model was trained with.
    array -= np.asarray(mean, dtype=np.float64)
    # HWC -> CHW, then add the batch dimension.
    return torch.from_numpy(array.transpose((2, 0, 1))).unsqueeze(0)
def process_with_inspyrenet(image):
    """Segment ``image`` with the module-level InSPyReNet model.

    Args:
        image: Anything ``preprocess`` accepts (a PIL image in this app).

    Returns:
        numpy uint8 array holding 255 where the squeezed model output is
        greater than 0.5 and 0 everywhere else.

    Raises:
        RuntimeError: If the model could not be loaded at startup.
    """
    if HAS_INSPYRE:
        batch = preprocess(image).to(device)
        # Inference only, so no autograd graph is needed.
        with torch.no_grad():
            prediction = model(batch)
        # NOTE(review): assumes the model returns a single tensor that
        # squeezes to (H, W) — confirm against the upstream model.
        scores = prediction.squeeze().cpu().numpy()
        return np.where(scores > 0.5, 255, 0).astype(np.uint8)
    raise RuntimeError("InSPyReNet failed to load")
def remove_background(input_image, model_choice="Rembg (U²-Net)"):
    """Cut the foreground out of ``input_image`` with the chosen backend.

    Args:
        input_image: PIL image or numpy array (arrays are converted via
            ``Image.fromarray``). ``None`` yields ``(None, None)``.
        model_choice: ``"InSPyReNet"`` selects InSPyReNet when it loaded
            successfully. Any other value, or a failed InSPyReNet load, uses
            rembg.

    Returns:
        ``(output, mask)``: the image with an alpha channel applied, and the
        mask as a PIL image, in both branches. On any error the
        exception is printed and ``(None, None)`` is returned.
    """
    if input_image is None:
        # Nothing to process (e.g. no image was provided).
        return None, None
    try:
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)
        if model_choice == "InSPyReNet" and HAS_INSPYRE:
            # Wrap the ndarray mask as a PIL image so both branches return
            # the same mask type.
            mask = Image.fromarray(process_with_inspyrenet(input_image))
            if mask.size != input_image.size:
                # putalpha() requires matching sizes. Guard against the
                # model's output resolution differing from the input.
                mask = mask.resize(input_image.size, Image.BILINEAR)
            output = input_image.copy()
            output.putalpha(mask)
        else:
            # Reaching here with "InSPyReNet" selected means it failed to load.
            if model_choice == "InSPyReNet":
                print("Falling back to Rembg due to InSPyReNet loading error")
            output = remove(input_image)
            if output.mode == 'RGBA':
                mask = output.split()[-1]
            else:
                # No alpha channel: treat the whole image as foreground.
                mask = Image.new('L', output.size, 255)
        return output, mask
    except Exception as e:
        # UI boundary: report the failure and return empty outputs rather
        # than propagating.
        print(f"Error: {e}")
        return None, None
# Gradio UI: an image plus a model selector in; the cut-out and its mask out.
# The radio labels are the exact strings remove_background() compares against.
# Components are created in the same order as they appear in the layout.
_input_image = gr.Image(type="pil", label="Input Image")
_model_selector = gr.Radio(
    choices=["Rembg (U²-Net)", "InSPyReNet"],
    value="Rembg (U²-Net)",
    label="Model Selection",
)
_result_image = gr.Image(type="pil", label="Result")
_mask_image = gr.Image(type="pil", label="Mask")
iface = gr.Interface(
    fn=remove_background,
    inputs=[_input_image, _model_selector],
    outputs=[_result_image, _mask_image],
    title="Professional Background Remover",
    description="Choose between Rembg (faster) or InSPyReNet (higher quality)",
)
# When run as a script, serve the UI on all network interfaces on port 7860.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)