|
|
import gradio as gr |
|
|
from rembg import remove |
|
|
from PIL import Image |
|
|
import numpy as np |
|
|
import torch |
|
|
import os |
|
|
import sys |
|
|
import requests |
|
|
from tqdm import tqdm |
|
|
import subprocess |
|
|
|
|
|
|
|
|
# Fetch the InSPyReNet source on first run; its modules are imported further down.
if not os.path.exists('InSPyReNet'):
    print("Cloning InSPyReNet repository...")
    try:
        subprocess.run(
            ['git', 'clone', '--depth', '1', 'https://github.com/plemeri/InSPyReNet.git'],
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as e:
        # Best effort: OSError covers a missing `git` binary, CalledProcessError a
        # non-zero exit. Without the repo the InSPyReNet import below fails and the
        # app uses rembg instead, so warn rather than abort startup.
        print(f"Failed to clone InSPyReNet repository: {e}")

# Make the cloned repository's modules importable. Each insert goes to the front of
# sys.path, so the resulting search order is utils/, lib/, then the repo root.
sys.path.insert(0, os.path.abspath('InSPyReNet'))
sys.path.insert(0, os.path.abspath('InSPyReNet/lib'))
sys.path.insert(0, os.path.abspath('InSPyReNet/utils'))
|
|
|
|
|
|
|
|
def download_file(url, filename, chunk_size=64 * 1024, timeout=60):
    """Download *url* to *filename*, streaming the body with a tqdm progress bar.

    The data is first written to ``filename + '.part'`` and only renamed to
    *filename* once the whole body has been received. The startup code decides
    whether to download by checking whether *filename* exists, so writing straight
    to it would let an interrupted transfer (or an HTTP error page) be mistaken for
    valid weights on every later run.

    Args:
        url: URL to fetch.
        filename: destination path; also used as the progress-bar label.
        chunk_size: bytes per chunk read from the response stream.
        timeout: seconds passed to requests as the connect/read timeout.

    Raises:
        requests.HTTPError: the server answered with an error status.
        requests.RequestException: connection failure or timeout.
        OSError: the destination could not be written or renamed.
    """
    part_path = os.fspath(filename) + '.part'
    with requests.get(url, stream=True, timeout=timeout) as response:
        response.raise_for_status()
        # content-length may be absent; None tells tqdm the total is unknown.
        total_size = int(response.headers.get('content-length', 0)) or None
        try:
            with open(part_path, 'wb') as f, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=chunk_size):
                    bar.update(f.write(data))
        except BaseException:
            # Never leave a half-written file behind, including on Ctrl+C.
            if os.path.exists(part_path):
                os.remove(part_path)
            raise
    # Replacing on the same filesystem: *filename* is either absent or complete.
    os.replace(part_path, filename)
|
|
|
|
|
# Fetch the pretrained weights on first run. A failed download is reported but not
# fatal: the InSPyReNet load below is then expected to fail, and requests for
# InSPyReNet fall back to rembg.
if not os.path.exists('InSPyReNet.pth'):
    print("Downloading model weights...")
    try:
        download_file(
            "https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth",
            "InSPyReNet.pth"
        )
    except (requests.RequestException, OSError) as e:
        print(f"Failed to download model weights: {e}")
|
|
|
|
|
|
|
|
# Optional higher-quality backend: import and initialise InSPyReNet from the cloned
# repository. Any failure is printed and recorded in HAS_INSPYRE, which
# remove_background() checks before choosing this model over rembg.
try:
    # These resolve through the sys.path entries added for the cloned repository.
    from InSPyReNet import InSPyReNet
    from modules.layers import load_model
    from utils.misc import load_config

    print("Loading model...")
    cfg = load_config('InSPyReNet/configs/InSPyReNet_SwinB.yaml')
    device = torch.device('cpu')  # inference is pinned to CPU
    model = InSPyReNet(cfg)
    model = load_model(model, 'InSPyReNet.pth', device)
    model.eval()
except Exception as e:
    print(f"Failed to load InSPyReNet: {str(e)}")
    HAS_INSPYRE = False
else:
    HAS_INSPYRE = True
|
|
|
|
|
def preprocess(image, mean=(104.00699, 116.66877, 122.67892)):
    """Convert an image into a mean-subtracted (1, 3, H, W) float32 tensor.

    Args:
        image: PIL image or numpy array. PIL images are converted to RGB first, so
            palette ('P'), grayscale ('L') and RGBA inputs all yield three colour
            channels. For arrays, a 2-D or single-channel input is replicated to
            three channels and a fourth (alpha) channel is dropped.
        mean: per-channel values subtracted from the raw 0-255 pixel values, in the
            image's channel order. The default is the constant this app has always
            used.

    Returns:
        torch.Tensor of shape (1, 3, H, W) and dtype float32.
    """
    # NOTE(review): the default means look like the classic Caffe BGR means, yet they
    # are applied to RGB data with no /255 scaling or std division. Confirm this
    # matches the normalisation the InSPyReNet checkpoint expects.
    if hasattr(image, "convert"):
        image = image.convert("RGB")
    # np.array(..., dtype=...) always copies, so the in-place ops below never
    # modify the caller's data.
    arr = np.array(image, dtype=np.float32)
    if arr.ndim == 2:
        arr = arr[:, :, None]
    if arr.shape[2] == 1:
        arr = np.repeat(arr, 3, axis=2)
    elif arr.shape[2] == 4:
        arr = arr[:, :, :3]
    arr -= np.asarray(mean)
    # HWC -> CHW, then add a batch dimension.
    return torch.from_numpy(arr.transpose(2, 0, 1)).unsqueeze(0)
|
|
|
|
|
def process_with_inspyrenet(image):
    """Predict a binary foreground mask for *image* with the loaded InSPyReNet model.

    Args:
        image: input accepted by preprocess() (remove_background passes a PIL image).

    Returns:
        numpy uint8 array holding 255 where the squeezed model output is above 0.5
        and 0 elsewhere.

    Raises:
        RuntimeError: if InSPyReNet failed to load at startup.
    """
    if HAS_INSPYRE:
        batch = preprocess(image).to(device)
        with torch.no_grad():
            prediction = model(batch)
        # NOTE(review): assumes the model returns a single tensor of scores in
        # [0, 1] at the input resolution. Confirm against the InSPyReNet code.
        scores = prediction.squeeze().cpu().numpy()
        return np.where(scores > 0.5, 255, 0).astype(np.uint8)
    raise RuntimeError("InSPyReNet failed to load")
|
|
|
|
|
def remove_background(input_image, model_choice="Rembg (U²-Net)"):
    """Remove the background from an image using the selected model.

    Args:
        input_image: PIL image or numpy array from the Gradio input. It may be None
            when the user submits without an image.
        model_choice: "InSPyReNet" selects InSPyReNet when it loaded successfully.
            Any other value (default "Rembg (U²-Net)") uses rembg. If InSPyReNet was
            requested but failed to load, rembg is used instead.

    Returns:
        (output, mask): the cut-out image and its single-channel ('L') mask, both
        as PIL images, or (None, None) when there is no input or processing fails.
    """
    if input_image is None:
        return None, None
    try:
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)

        if model_choice == "InSPyReNet" and HAS_INSPYRE:
            return _remove_with_inspyrenet(input_image)

        if model_choice == "InSPyReNet":
            print("Falling back to Rembg due to InSPyReNet loading error")
        return _remove_with_rembg(input_image)
    except Exception as e:
        # UI boundary: report the failure (with traceback) and show empty outputs.
        import traceback
        print(f"Error: {str(e)}")
        traceback.print_exc()
        return None, None


def _remove_with_inspyrenet(image):
    """Apply the InSPyReNet mask to *image* as its alpha channel; return (image, mask)."""
    mask = Image.fromarray(process_with_inspyrenet(image))
    if mask.size != image.size:
        # putalpha() requires matching sizes; guard against a prediction made at a
        # different resolution than the input.
        mask = mask.resize(image.size, Image.BILINEAR)
    output = image.copy()
    output.putalpha(mask)
    return output, mask


def _remove_with_rembg(image):
    """Cut out *image* with rembg; return (output, alpha channel or an opaque mask)."""
    output = remove(image)
    if output.mode == 'RGBA':
        mask = output.split()[-1]
    else:
        mask = Image.new('L', output.size, 255)
    return output, mask
|
|
|
|
|
# Gradio UI: an image plus a model selector go in; the cut-out image and its mask
# come out. remove_background receives the two inputs positionally in list order.
iface = gr.Interface(
    title="Professional Background Remover",
    description="Choose between Rembg (faster) or InSPyReNet (higher quality)",
    fn=remove_background,
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Radio(
            label="Model Selection",
            choices=["Rembg (U²-Net)", "InSPyReNet"],
            value="Rembg (U²-Net)",
        ),
    ],
    outputs=[
        gr.Image(label="Result", type="pil"),
        gr.Image(label="Mask", type="pil"),
    ],
)
|
|
|
|
|
if __name__ == "__main__":
    # Defaults: listen on all interfaces (e.g. inside a container) on port 7860.
    # Passing explicit values to launch() would otherwise shadow the
    # GRADIO_SERVER_NAME / GRADIO_SERVER_PORT environment variables, so honour them
    # here.
    iface.launch(
        server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
        server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
    )