Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import uuid | |
| import base64 | |
| import numpy as np | |
| import onnxruntime as ort | |
| import cv2 | |
| from PIL import Image | |
| from torchvision.transforms.functional import normalize | |
| import torch.nn.functional as F | |
| from typing import Union, List | |
| from io import BytesIO | |
| from huggingface_hub import hf_hub_download | |
# ---- Config ----
# Model input resolution as [height, width]; preprocess_input resizes every image to this.
INPUT_SIZE = [1200, 1800]  # (H, W)

# ---- Load ONNX model ----
model_path = hf_hub_download(repo_id="Trendyol/background-removal", filename="model.onnx")

# Prefer CUDA, but only request providers this onnxruntime build actually offers;
# CPU is always kept as the last resort.
_preferred_providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
_available_providers = set(ort.get_available_providers())
providers = [p for p in _preferred_providers if p in _available_providers] or ["CPUExecutionProvider"]
try:
    ort_sess = ort.InferenceSession(model_path, providers=providers)
except Exception as err:
    if providers == ["CPUExecutionProvider"]:
        # Nothing left to fall back to: surface the real error.
        raise
    # A provider may be listed as available yet fail to initialise at runtime.
    # Fall back to CPU, but report why instead of hiding the failure.
    print(f"ONNX session with {providers} failed ({err!r}); falling back to CPUExecutionProvider")
    ort_sess = ort.InferenceSession(model_path, providers=["CPUExecutionProvider"])
| # ---- Utils from Trendyol ---- | |
def keep_large_components(a: np.ndarray) -> np.ndarray:
    """Suppress small blobs in an 8-bit mask, keeping only large regions.

    Pixels above 25 are binarised and split into 4-connected components.
    A component is kept when its area exceeds 50000 px. That limit is scaled
    from the ``INPUT_SIZE`` resolution to the size of ``a``. The union of the
    kept components is dilated (9x9 ellipse, 2 iterations) so edges survive.
    It is then ANDed with ``a`` to zero out everything else. If no component
    is large enough, the values of ``a`` are returned unchanged.

    Args:
        a: uint8 mask of shape (H, W) or (H, W, 1).

    Returns:
        uint8 mask of shape (H, W, 1), regardless of which components were kept.
    """
    h, w = a.shape[:2]
    dilate_kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (9, 9))
    a_mask = (a > 25).astype(np.uint8) * 255

    _, label_ids, stats, _ = cv2.connectedComponentsWithStats(a_mask, 4, cv2.CV_32S)

    # The limit is defined at model resolution; scale it to this image's area.
    area_limit = 50000 * (h * w) / (INPUT_SIZE[1] * INPUT_SIZE[0])
    # Label 0 is the background: skip it, then shift indices back to label ids.
    keep = np.flatnonzero(stats[1:, cv2.CC_STAT_AREA] > area_limit) + 1

    if keep.size > 0:
        # Union of all kept components in one vectorised pass.
        final_mask = np.isin(label_ids, keep).astype(np.uint8) * 255
        # Dilate so soft edges around the kept regions are not cut off.
        final_mask = cv2.dilate(final_mask, dilate_kernel, iterations=2)
        a = cv2.bitwise_and(a.reshape(h, w), final_mask)

    # Always return (H, W, 1) so the shape does not depend on the branch taken.
    return a.reshape(h, w, 1)
def preprocess_input(im: np.ndarray) -> torch.Tensor:
    """Convert an image array into the normalised NCHW tensor fed to the model.

    The image is resized (bilinear) to ``INPUT_SIZE`` and cast back to uint8.
    It is then scaled to [0, 1] and normalised with mean 0.5 / std 1.0 per
    channel.

    Args:
        im: image array of shape (H, W), (H, W, 1), (H, W, 3) or (H, W, 4).
            Values are presumably in [0, 255] (TODO confirm for non-uint8
            callers). A fourth (alpha) channel is dropped. Single-channel
            input is replicated to three channels.

    Returns:
        float32 tensor of shape (1, 3, INPUT_SIZE[0], INPUT_SIZE[1]).
    """
    if im.ndim < 3:
        im = im[:, :, np.newaxis]
    if im.shape[2] == 1:
        # Grayscale: replicate so the 3-channel normalisation below applies
        # (a 1-channel tensor cannot be normalised with 3-element mean/std).
        im = np.repeat(im, 3, axis=2)
    elif im.shape[2] == 4:
        im = im[:, :, :3]  # drop alpha
    im_tensor = torch.tensor(im, dtype=torch.float32).permute(2, 0, 1).unsqueeze(0)
    # F.upsample is deprecated; interpolate with align_corners=False is the same operation.
    im_tensor = F.interpolate(im_tensor, size=INPUT_SIZE, mode="bilinear", align_corners=False)
    # Truncating cast to uint8 is kept exactly as in the original utils.
    im_tensor = im_tensor.type(torch.uint8)
    image = torch.divide(im_tensor, 255.0)
    image = normalize(image, [0.5, 0.5, 0.5], [1.0, 1.0, 1.0])
    return image
def postprocess_output(result: np.ndarray, orig_im_shape) -> np.ndarray:
    """Turn one raw model output into an 8-bit mask at the original image size.

    Args:
        result: model output for a single image, expected as a single-channel
            (1, h, w) array — presumably what ``ort_sess.run(...)[0][0]``
            yields; TODO confirm against the model's output spec.
        orig_im_shape: target (height, width) of the mask.

    Returns:
        uint8 array of shape (H, W, 1), min-max scaled to [0, 255] and
        cleaned by ``keep_large_components``.
    """
    # Cast to float32 so resizing and min-max scaling don't depend on the
    # model's output dtype.
    pred = torch.from_numpy(np.asarray(result, dtype=np.float32)).unsqueeze(0)  # (1, 1, h, w)
    # F.upsample is deprecated; interpolate with align_corners=False is the same operation.
    pred = F.interpolate(pred, size=tuple(orig_im_shape), mode="bilinear", align_corners=False)
    pred = pred.squeeze(0)  # (1, H, W)
    # Min-max scale to [0, 1]; the epsilon avoids division by zero on a flat map.
    lo, hi = pred.min(), pred.max()
    pred = (pred - lo) / (hi - lo + 1e-8)
    a = (pred * 255).permute(1, 2, 0).numpy().astype(np.uint8)
    return keep_large_components(a)
| # ---- Core processing ---- | |
def process(image: Image.Image) -> Image.Image:
    """Remove the background of ``image`` and composite the subject onto white.

    Args:
        image: input PIL image in any mode; it is converted to RGB once.

    Returns:
        RGB PIL image of the same size. Pixels whose predicted mask value is
        <= 25 are replaced with white; all others keep the original RGB.
    """
    # Convert once and reuse for both inference and compositing.
    rgb = image.convert("RGB")
    np_img = np.array(rgb)

    # Preprocess
    img_tensor = preprocess_input(np_img)

    # Inference: take the first output, then the first batch item. If the
    # output is (1, 1, H, W) as assumed, this leaves (1, H, W) — TODO confirm.
    inputs = {ort_sess.get_inputs()[0].name: img_tensor.numpy()}
    prediction = ort_sess.run(None, inputs)[0][0]

    # Postprocess to an (H, W, 1) uint8 mask at the original resolution.
    alpha = postprocess_output(prediction, (np_img.shape[0], np_img.shape[1]))

    # Binarise the mask and composite the subject onto a white background.
    mask = Image.fromarray(alpha.squeeze(-1)).convert("L")
    binary_mask = mask.point(lambda p: 255 if p > 25 else 0)
    white_bg = Image.new("RGB", rgb.size, (255, 255, 255))
    return Image.composite(rgb, white_bg, binary_mask)
| # ---- Gradio handler ---- | |
def handler(image=None) -> Union[str, None]:
    """Gradio callback: remove the background and save the result as a PNG.

    Returns the PNG's file name (relative to the current working directory),
    or ``None`` when no image was uploaded.
    """
    # Guard clause: nothing uploaded, nothing to do.
    if image is None:
        return None
    cleaned = process(image)
    # NOTE(review): nothing in this file deletes these outputs, so they
    # accumulate in the working directory over the app's lifetime.
    png_name = f"output_{uuid.uuid4().hex[:8]}.png"
    cleaned.save(png_name)
    return png_name
# ---- Gradio UI ----
# Input: a PIL image (handler passes it to process()); output: the saved PNG file.
_upload_input = gr.Image(label="Upload Image", type="pil")
_png_output = gr.File(label="Output File")

demo = gr.Interface(
    fn=handler,
    inputs=_upload_input,
    outputs=_png_output,
    title="Background Remover (Trendyol)",
    description="Upload an image to remove the background with the Trendyol ONNX model. Background is replaced with white.",
)

if __name__ == "__main__":
    # show_error=True: presumably asks Gradio to report handler exceptions to the user.
    demo.launch(show_error=True)