"""Gradio app that removes image backgrounds with Rembg (U²-Net) or InSPyReNet.

On startup the script clones the InSPyReNet repository and downloads its
weights if they are missing, then tries to load the model. Any failure in that
setup is reported and the app keeps working with Rembg only.
"""
import os
import subprocess
import sys
import traceback

import gradio as gr
import numpy as np
import requests
import torch
from PIL import Image
from rembg import remove
from tqdm import tqdm

# Clone repository if not present. A failure here (e.g. `git` missing or no
# network) must not stop the app: the InSPyReNet import below will then fail
# and the app falls back to Rembg.
if not os.path.exists('InSPyReNet'):
    print("Cloning InSPyReNet repository...")
    try:
        subprocess.run(
            ['git', 'clone', '--depth', '1', 'https://github.com/plemeri/InSPyReNet.git'],
            check=True,
        )
    except (OSError, subprocess.CalledProcessError) as e:
        print(f"Failed to clone InSPyReNet repository: {e}")

# Set up correct Python paths so the repository's modules can be imported.
sys.path.insert(0, os.path.abspath('InSPyReNet'))
sys.path.insert(0, os.path.abspath('InSPyReNet/lib'))
sys.path.insert(0, os.path.abspath('InSPyReNet/utils'))


def download_file(url, filename, timeout=60):
    """Stream ``url`` to ``filename`` with a progress bar.

    The data is first written to ``<filename>.part`` and only renamed onto
    ``filename`` once the whole body has been received. An interrupted or
    failed download therefore never leaves a truncated file behind. That
    matters because callers skip the download whenever ``filename`` exists.

    Args:
        url: HTTP(S) URL to fetch.
        filename: Destination path.
        timeout: Connect/read timeout in seconds passed to ``requests``.

    Raises:
        requests.RequestException: On network errors or a non-2xx status.
        OSError: If the file cannot be written.
    """
    tmp_path = f"{filename}.part"
    try:
        with requests.get(url, stream=True, timeout=timeout) as response:
            # Without this an HTML error page (e.g. 404) would be saved as weights.
            response.raise_for_status()
            total_size = int(response.headers.get('content-length', 0))
            with open(tmp_path, 'wb') as f, tqdm(
                desc=filename,
                total=total_size or None,  # None => unknown size, not "0 bytes"
                unit='iB',
                unit_scale=True,
                unit_divisor=1024,
            ) as bar:
                for data in response.iter_content(chunk_size=64 * 1024):
                    bar.update(f.write(data))
        # Atomic on the same filesystem: the final path only ever holds a complete file.
        os.replace(tmp_path, filename)
    except BaseException:
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise


if not os.path.exists('InSPyReNet.pth'):
    print("Downloading model weights...")
    try:
        download_file(
            "https://github.com/plemeri/InSPyReNet/releases/download/v1.0/InSPyReNet.pth",
            "InSPyReNet.pth"
        )
    except (requests.RequestException, OSError) as e:
        # Keep the app usable with Rembg; model loading below will fail and be reported.
        print(f"Failed to download model weights: {e}")

# Import after setting up environment.
# NOTE(review): these module/attribute names are assumed to exist in the cloned
# repository; this file does not verify them. Any failure leaves HAS_INSPYRE False.
try:
    from InSPyReNet import InSPyReNet
    from modules.layers import load_model
    from utils.misc import load_config

    # Initialize model
    print("Loading model...")
    cfg = load_config('InSPyReNet/configs/InSPyReNet_SwinB.yaml')
    device = torch.device('cpu')
    model = InSPyReNet(cfg)
    model = load_model(model, 'InSPyReNet.pth', device)
    model.eval()
    HAS_INSPYRE = True
except Exception as e:
    print(f"Failed to load InSPyReNet: {str(e)}")
    HAS_INSPYRE = False

# Per-channel means subtracted from the 0-255 pixel values.
# NOTE(review): these are the Caffe-style BGR means, applied here to RGB
# channels; confirm against the preprocessing the InSPyReNet weights expect.
_CHANNEL_MEAN = np.array([104.00699, 116.66877, 122.67892], dtype=np.float32)


def preprocess(image):
    """Convert an image into a ``(1, 3, H, W)`` float32 tensor for the model.

    PIL images of any mode (RGBA, L, P, ...) are converted to RGB first.
    NumPy inputs are normalised to three channels: grayscale is replicated and
    a fourth (alpha) channel is dropped. This prevents the per-channel mean
    subtraction from failing with a broadcasting error.
    """
    if isinstance(image, Image.Image):
        image = image.convert('RGB')
    array = np.array(image, dtype=np.float32)  # copy, so in-place ops are safe
    if array.ndim == 2:
        array = np.stack([array] * 3, axis=-1)
    elif array.shape[-1] == 4:
        array = array[..., :3]
    array -= _CHANNEL_MEAN
    array = array.transpose((2, 0, 1))  # HWC -> CHW
    return torch.from_numpy(array).unsqueeze(0)


def process_with_inspyrenet(image):
    """Run InSPyReNet on ``image`` and return a binary uint8 mask (0 or 255).

    The mask is resized (nearest-neighbour, so it stays binary) to the input's
    width and height if the model output has a different size. Otherwise
    ``putalpha`` on the original image would fail.

    Raises:
        RuntimeError: If the model was not loaded at startup.
    """
    if not HAS_INSPYRE:
        raise RuntimeError("InSPyReNet failed to load")
    image_tensor = preprocess(image).to(device)
    with torch.no_grad():
        pred = model(image_tensor)
    # NOTE(review): assumes `pred` is a tensor of probabilities in [0, 1];
    # if the model returns logits, 0.5 is not the right threshold.
    mask = (pred.squeeze().cpu().numpy() > 0.5).astype(np.uint8) * 255

    if isinstance(image, Image.Image):
        width, height = image.size
    else:
        height, width = np.asarray(image).shape[:2]
    if mask.shape != (height, width):
        mask = np.array(Image.fromarray(mask).resize((width, height), Image.NEAREST))
    return mask


def remove_background(input_image, model_choice="Rembg (U²-Net)"):
    """Remove the background from ``input_image`` with the chosen model.

    Args:
        input_image: PIL image or NumPy array (``None`` when nothing is uploaded).
        model_choice: ``"Rembg (U²-Net)"`` or ``"InSPyReNet"``. InSPyReNet
            silently falls back to Rembg when it failed to load.

    Returns:
        ``(output, mask)``: the RGBA cut-out and its ``L``-mode mask, both PIL
        images. Returns ``(None, None)`` on missing input or any processing
        error; the error is printed.
    """
    if input_image is None:
        return None, None
    try:
        if isinstance(input_image, np.ndarray):
            input_image = Image.fromarray(input_image)

        if model_choice == "InSPyReNet" and HAS_INSPYRE:
            mask = Image.fromarray(process_with_inspyrenet(input_image))
            # convert() returns a new image, so the caller's image is untouched.
            output = input_image.convert('RGB')
            output.putalpha(mask)
            return output, mask

        if model_choice == "InSPyReNet":
            print("Falling back to Rembg due to InSPyReNet loading error")
        output = remove(input_image)
        if output.mode == 'RGBA':
            mask = output.split()[-1]
        else:
            mask = Image.new('L', output.size, 255)
        return output, mask
    except Exception as e:
        print(f"Error: {str(e)}")
        traceback.print_exc()
        return None, None


iface = gr.Interface(
    fn=remove_background,
    inputs=[
        gr.Image(type="pil", label="Input Image"),
        gr.Radio(
            choices=["Rembg (U²-Net)", "InSPyReNet"],
            value="Rembg (U²-Net)",
            label="Model Selection"
        )
    ],
    outputs=[
        gr.Image(type="pil", label="Result"),
        gr.Image(type="pil", label="Mask")
    ],
    title="Professional Background Remover",
    description="Choose between Rembg (faster) or InSPyReNet (higher quality)"
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)