File size: 3,825 Bytes
d554758
950af1e
4e6d0f7
950af1e
d74c67e
 
e50da0f
8fbeea7
 
 
a0d6a06
a0d7fb2
8fbeea7
 
9e721e5
8fbeea7
a0d7fb2
e50da0f
 
6f4b876
8fbeea7
e50da0f
8fbeea7
 
 
a0d6a06
8fbeea7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0d7fb2
9e721e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a0d6a06
e50da0f
8fbeea7
 
 
e50da0f
 
 
9e721e5
 
 
e50da0f
d74c67e
e50da0f
d74c67e
 
 
e50da0f
5f3e8c9
 
 
 
9e721e5
e50da0f
 
 
 
9e721e5
 
e50da0f
 
 
 
 
a0d6a06
e50da0f
950af1e
5f3e8c9
8fbeea7
d74c67e
07d78f3
d74c67e
 
e50da0f
 
 
 
 
 
 
 
d74c67e
8fbeea7
 
d74c67e
8fbeea7
e50da0f
d74c67e
d554758
6d99820
d74c67e
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import gradio as gr
from rembg import remove
from PIL import Image
import numpy as np
import torch
import os
import sys
import requests
from tqdm import tqdm
import subprocess

# Shallow-clone the InSPyReNet sources on first run. This is best effort:
# if git is unavailable or the clone fails, the InSPyReNet import further
# down fails, HAS_INSPYRE becomes False and the app serves Rembg only,
# instead of crashing right here.
if not os.path.exists('InSPyReNet'):
    print("Cloning InSPyReNet repository...")
    try:
        subprocess.run(
            ['git', 'clone', '--depth', '1', 'https://github.com/plemeri/InSPyReNet.git'],
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as e:
        # OSError: git executable not found; CalledProcessError: clone failed.
        print(f"Failed to clone InSPyReNet repository: {e}")

# Make the cloned sources importable. Each insert(0, ...) prepends, so the
# final search order is utils/, lib/, repo root -- all ahead of the script
# directory, which matters because the clone directory shares the module
# name "InSPyReNet" with lib/InSPyReNet.py.
sys.path.insert(0, os.path.abspath('InSPyReNet'))
sys.path.insert(0, os.path.abspath('InSPyReNet/lib'))
sys.path.insert(0, os.path.abspath('InSPyReNet/utils'))

# Download model weights
def download_file(url, filename, timeout=60):
    """Stream *url* to *filename* while showing a tqdm progress bar.

    The body is first written to ``<filename>.part`` and moved into place
    with ``os.replace`` only after the whole response has been received. A
    failed or interrupted download therefore never leaves a truncated file
    that a later ``os.path.exists(filename)`` check would treat as complete.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path (str or path-like).
        timeout: Connect/read timeout in seconds passed to ``requests.get``,
            so a stalled server cannot block startup indefinitely.

    Raises:
        requests.HTTPError: if the server answers with an error status.
        requests.RequestException, OSError: on network or filesystem errors.
    """
    part_path = os.fspath(filename) + ".part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this, an HTML error page would be saved as the weights.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))

            with open(part_path, 'wb') as f, tqdm(
                desc=filename,
                total=total_size,
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=1024):
                    bar.update(f.write(data))
        os.replace(part_path, filename)
    except BaseException:
        # Remove the partial download (also on Ctrl-C), then re-raise.
        try:
            os.remove(part_path)
        except FileNotFoundError:
            pass
        raise

if not os.path.exists('InSPyReNet.pth'):
    print("Downloading model weights...")
    try:
        download_file(
            "https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth",
            "InSPyReNet.pth"
        )
    except (requests.RequestException, OSError) as e:
        # Best effort: without the weights the InSPyReNet load below fails
        # and the app falls back to Rembg rather than refusing to start.
        print(f"Failed to download model weights: {e}")

# InSPyReNet's modules are only importable once the sys.path entries above
# exist. Any failure while importing or building the model is reported and
# merely disables the InSPyReNet option (HAS_INSPYRE stays False); it never
# stops the app.
HAS_INSPYRE = False
try:
    from InSPyReNet import InSPyReNet
    from modules.layers import load_model
    from utils.misc import load_config

    print("Loading model...")
    cfg = load_config('InSPyReNet/configs/InSPyReNet_SwinB.yaml')
    device = torch.device('cpu')
    model = InSPyReNet(cfg)
    model = load_model(model, 'InSPyReNet.pth', device)
    model.eval()
except Exception as e:
    print(f"Failed to load InSPyReNet: {str(e)}")
else:
    HAS_INSPYRE = True

def preprocess(image):
    """Convert an image into a ``(1, 3, H, W)`` float32 tensor.

    Accepts a PIL image (anything with a ``convert`` method) or an
    array-like of shape ``(H, W)``, ``(H, W, 1)``, ``(H, W, 3)`` or
    ``(H, W, 4)``. Inputs that are not plain RGB are reduced to three
    channels (PIL images via ``convert("RGB")``; arrays by replicating a
    single channel or dropping the alpha channel). Without this step the
    3-value mean subtraction below fails for RGBA, grayscale or palette
    images.

    Returns:
        torch.Tensor of shape ``(1, 3, H, W)``, dtype float32, with the
        per-channel mean subtracted (no other scaling is applied).
    """
    if hasattr(image, "convert"):
        # PIL image in RGBA / L / P / ... mode -> plain 3-channel RGB.
        image = image.convert("RGB")

    # np.array (not asarray) always copies, so the in-place subtraction
    # below never modifies a caller-owned float32 array.
    pixels = np.array(image, dtype=np.float32)
    if pixels.ndim == 2:
        pixels = pixels[:, :, np.newaxis]
    if pixels.shape[2] == 1:
        pixels = np.repeat(pixels, 3, axis=2)
    elif pixels.shape[2] == 4:
        pixels = pixels[:, :, :3]

    # NOTE(review): these constants look like the classic Caffe per-channel
    # means, which are usually listed in BGR order, yet they are subtracted
    # from RGB channels here -- confirm against the model's training
    # preprocessing.
    pixels -= np.array([104.00699, 116.66877, 122.67892])
    return torch.from_numpy(pixels.transpose((2, 0, 1))).unsqueeze(0)

def process_with_inspyrenet(image):
    """Predict a hard foreground mask for *image* with the loaded InSPyReNet.

    Returns:
        uint8 numpy array holding 255 where the squeezed model output is
        above 0.5 and 0 everywhere else.

    Raises:
        RuntimeError: if the model did not load when the module was imported.
    """
    if not HAS_INSPYRE:
        raise RuntimeError("InSPyReNet failed to load")

    batch = preprocess(image).to(device)
    with torch.no_grad():
        prediction = model(batch)

    # NOTE(review): assumes the model returns a single probability tensor
    # whose squeezed shape matches the input's H x W -- confirm; if it
    # returns logits, 0.5 is not the intended cut-off.
    foreground = prediction.squeeze().cpu().numpy() > 0.5
    return foreground.astype(np.uint8) * 255

def remove_background(input_image, model_choice="Rembg (U²-Net)"):
    """Cut the foreground out of *input_image*.

    Args:
        input_image: PIL image (what the Gradio input delivers, since it is
            declared with ``type="pil"``) or an image as a numpy array.
        model_choice: ``"InSPyReNet"`` to use InSPyReNet when it loaded;
            any other value uses Rembg (U²-Net).

    Returns:
        ``(cutout, mask)``: the image with an alpha channel, and that alpha
        mask as a PIL image (both backends return the same mask type). On
        missing input or any error, ``(None, None)``.
    """
    if input_image is None:
        # Nothing to process (presumably "Submit" pressed without an image).
        return None, None

    try:
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)

        if model_choice == "InSPyReNet":
            if HAS_INSPYRE:
                try:
                    mask = Image.fromarray(process_with_inspyrenet(input_image))
                    output = input_image.copy()
                    output.putalpha(mask)
                    return output, mask
                except Exception as e:
                    # Degrade the same way as a failed model load: serve the
                    # Rembg result instead of returning nothing.
                    print(f"InSPyReNet inference failed ({e}); falling back to Rembg")
            else:
                print("Falling back to Rembg due to InSPyReNet loading error")

        output = remove(input_image)
        if output.mode == 'RGBA':
            mask = output.split()[-1]
        else:
            # No alpha channel: treat the whole image as foreground.
            mask = Image.new('L', output.size, 255)
        return output, mask

    except Exception as e:
        print(f"Error: {str(e)}")
        return None, None

# Gradio UI: an image plus a model selector go in; remove_background's
# (cut-out, alpha mask) pair comes out as two image panels.
iface = gr.Interface(
    fn=remove_background,
    title="Professional Background Remover",
    description="Choose between Rembg (faster) or InSPyReNet (higher quality)",
    inputs=[
        gr.Image(label="Input Image", type="pil"),
        gr.Radio(
            label="Model Selection",
            choices=["Rembg (U²-Net)", "InSPyReNet"],
            value="Rembg (U²-Net)",
        ),
    ],
    outputs=[
        gr.Image(label="Result", type="pil"),
        gr.Image(label="Mask", type="pil"),
    ],
)

# When run as a script, serve the UI on all network interfaces, port 7860.
if __name__ == "__main__":
    iface.launch(server_port=7860, server_name="0.0.0.0")