"""Background removal using a pretrained U-2-Net model.

On import this module makes sure the weights file exists locally
(downloading it if necessary) and loads the network once into the
module-level ``model`` object, which ``remove_background`` then reuses.
"""

import os

import numpy as np
import requests
import torch
import torchvision.transforms as transforms
from PIL import Image

from model.u2net import U2NET

MODEL_DIR = "saved_models/u2net"
MODEL_PATH = os.path.join(MODEL_DIR, "u2net.pth")
MODEL_URL = "https://huggingface.co/flashingtt/U-2-Net/resolve/main/u2net.pth"

# (connect, read) timeouts in seconds so a stalled server cannot hang import.
_DOWNLOAD_TIMEOUT = (10, 60)

# Built once instead of on every call: resize to 320x320, convert to a
# tensor, then normalize with the standard ImageNet channel mean/std.
_TRANSFORM = transforms.Compose([
    transforms.Resize((320, 320)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])


def download_model():
    """Download the weights to ``MODEL_PATH`` unless the file already exists.

    The payload is streamed to a temporary ``.part`` file and only renamed
    to ``MODEL_PATH`` once the transfer has completed, so an interrupted or
    failed download never leaves a truncated file that a later run would
    mistake for valid weights.

    Raises:
        requests.RequestException: on connection errors, timeouts, or a
            non-2xx HTTP status.
    """
    if os.path.exists(MODEL_PATH):
        return

    os.makedirs(MODEL_DIR, exist_ok=True)
    print("Downloading model...")
    tmp_path = MODEL_PATH + ".part"
    try:
        with requests.get(MODEL_URL, stream=True,
                          timeout=_DOWNLOAD_TIMEOUT) as response:
            # Without this an HTML error page would be saved as the model.
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    f.write(chunk)
        # Atomic on the same filesystem: the final path appears fully written.
        os.replace(tmp_path, MODEL_PATH)
    finally:
        # Still present only if something failed before the rename.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Model downloaded.")


download_model()


def load_model():
    """Build a ``U2NET(3, 1)`` network, load the saved weights, return it.

    The weights are mapped to the CPU and the network is switched to eval
    mode before being returned.
    """
    # presumably (in_channels=3, out_channels=1) -- confirm against U2NET.
    net = U2NET(3, 1)
    # NOTE(review): torch.load unpickles a file fetched from the network;
    # consider weights_only=True once the minimum torch version supports it.
    net.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
    net.eval()
    return net


model = load_model()


def preprocess(image):
    """Convert a PIL image into a normalized ``1 x 3 x 320 x 320`` tensor.

    Args:
        image: a PIL image in any mode; non-RGB images are converted to RGB
            first because the normalization statistics are 3-channel.

    Returns:
        A float tensor with a leading batch dimension of size 1.
    """
    if image.mode != "RGB":
        image = image.convert("RGB")
    return _TRANSFORM(image).unsqueeze(0)


def postprocess(mask, original_size):
    """Turn a network output tensor into an 8-bit grayscale PIL mask.

    The values are min-max scaled to ``[0, 1]``, mapped to ``0..255`` and
    bilinearly resized to ``original_size``.

    Args:
        mask: tensor that squeezes down to a 2-D map.
        original_size: ``(width, height)`` of the target image.

    Returns:
        A PIL image of size ``original_size`` built from a uint8 array.
    """
    mask = mask.squeeze().detach().cpu().numpy()
    lo = mask.min()
    span = mask.max() - lo
    if span > 0:
        mask = (mask - lo) / span
    else:
        # Constant map: min-max scaling would divide by zero and yield NaN.
        # Keep the predicted level instead (assumes the output is already a
        # probability in [0, 1] -- TODO confirm against U2NET).
        mask = np.clip(mask, 0.0, 1.0)
    mask = Image.fromarray((mask * 255).astype(np.uint8))
    return mask.resize(original_size, Image.BILINEAR)


def remove_background(image):
    """Return an RGBA copy of ``image`` whose alpha is the predicted mask.

    Args:
        image: input PIL image. It is not modified; any alpha channel it
            already has is replaced by the predicted mask in the result.

    Returns:
        A new RGBA PIL image of the same size as ``image``.
    """
    input_tensor = preprocess(image)
    with torch.no_grad():
        # Only the first output of the network is used.
        d1, *_ = model(input_tensor)
    mask = postprocess(d1, image.size)

    result = image.convert("RGBA")
    # Replaces the alpha band in native code instead of a per-pixel loop.
    result.putalpha(mask)
    return result