Spaces:
Runtime error
Runtime error
| import cv2 | |
| import gradio as gr | |
| import os | |
| import requests | |
| from PIL import Image | |
| import numpy as np | |
| import torch | |
| from torch.autograd import Variable | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
# Automatically download required files on first start-up.

# Seconds to wait for the server (connect / between bytes) before giving up, so a
# stalled download fails loudly instead of hanging start-up forever.
DOWNLOAD_TIMEOUT = 60


def _download(url, dest, label):
    """Download ``url`` to the file ``dest``.

    The body is streamed into ``dest + ".part"`` and only moved onto ``dest`` once
    it has been fully written. An interrupted download therefore never leaves a
    truncated ``dest`` behind. Such a file would otherwise pass the
    ``os.path.exists`` checks below and break every later start.

    Args:
        url: source URL.
        dest: destination file path.
        label: name used in the failure message.

    Raises:
        requests.RequestException: on connection errors, timeouts or an HTTP error
            status (after printing a message).
    """
    tmp_path = dest + ".part"
    try:
        with requests.get(url, stream=True, timeout=DOWNLOAD_TIMEOUT) as response:
            response.raise_for_status()
            with open(tmp_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip keep-alive chunks
                        f.write(chunk)
        # Atomic rename: dest either does not exist or is complete.
        os.replace(tmp_path, dest)
    except requests.RequestException as e:
        print(f"Failed to download {label}: {e}")
        raise
    finally:
        # Clean up a partial download (no-op after a successful os.replace).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)


# 1. data_loader_cache.py and 2. models.py from the DIS GitHub repository.
_DIS_RAW_BASE = "https://raw.githubusercontent.com/xuebinqin/DIS/main/DIS/IS-Net/"
for _name in ("data_loader_cache.py", "models.py"):
    if not os.path.exists(_name):
        print(f"Downloading {_name}...")
        _download(_DIS_RAW_BASE + _name, _name, _name)

# 3. isnet.pth from Hugging Face Git LFS (direct URL from screenshot)
os.makedirs("saved_models", exist_ok=True)
isnet_path = "saved_models/isnet.pth"
if not os.path.exists(isnet_path):
    print("Downloading isnet.pth from Hugging Face Git LFS...")
    # NOTE(review): this raw cdn-lfs object URL is not a stable public endpoint.
    # Its object id also looks truncated (63 hex chars, whereas LFS ids are
    # usually 64-char SHA-256). Verify it resolves, or point ISNET_URL at a
    # working download URL.
    lfs_url = os.environ.get(
        "ISNET_URL",
        "https://cdn-lfs.huggingface.co/repos/e0/a8/e0a889743a78391b48db7c4c0b4de1963ee320cb10934c75a32481dc5af9c61/e0a889743a78391b48db7c4c0b4de1963ee320cb10934c75a32481dc5af9c61?download=true",
    )
    _download(lfs_url, isnet_path, "isnet.pth")
| # Project imports | |
| from data_loader_cache import normalize, im_reader, im_preprocess | |
| from models import * | |
# Helpers
# All tensors and the model are placed on the CPU.
device = "cpu"
class GOSNormalize(object):
    """Callable transform that applies ``normalize(image, mean, std)``.

    ``normalize`` is imported from ``data_loader_cache``.

    Args:
        mean: per-channel means (default: 0.485, 0.456, 0.406).
        std: per-channel standard deviations (default: 0.229, 0.224, 0.225).
    """

    def __init__(self, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)):
        # Copy into fresh lists. Immutable defaults avoid the shared-mutable-default
        # pitfall, and copying keeps this instance independent of any list the
        # caller passed in and might later mutate.
        self.mean = list(mean)
        self.std = list(std)

    def __call__(self, image):
        """Return ``image`` normalized with this instance's mean and std."""
        return normalize(image, self.mean, self.std)
# Input preprocessing: a single GOSNormalize step, subtracting 0.5 per channel
# with a unit divisor.
_input_mean = [0.5] * 3
_input_std = [1.0] * 3
transform = transforms.Compose([GOSNormalize(mean=_input_mean, std=_input_std)])
def load_image(im_path, hypar):
    """Load an image file and build a single-item model input batch.

    Args:
        im_path: path to the image file.
        hypar: settings dict. Only ``hypar["cache_size"]`` is read; it is
            passed to ``im_preprocess``.

    Returns:
        ``(image_batch, shape_batch)``:
        - the preprocessed, ``/255``-scaled and ``transform``-ed image with a
          leading batch dimension of 1;
        - the shape reported by ``im_preprocess``, as a tensor, also with a
          leading batch dimension.
    """
    # Decode with PIL and force 3-channel RGB, so RGBA / palette / greyscale
    # uploads match the three-channel normalization in ``transform``. This also
    # matches the ``.convert("RGB")`` image the mask is later composited onto.
    # NOTE(review): this replaces im_reader(). It assumes im_preprocess accepts
    # an HxWx3 uint8 array, as im_reader produced for plain RGB files — confirm.
    with Image.open(im_path) as src:
        im = np.array(src.convert("RGB"))
    im, im_shp = im_preprocess(im, hypar["cache_size"])
    im = torch.divide(im, 255.0)
    shape = torch.from_numpy(np.array(im_shp))
    return transform(im).unsqueeze(0), shape.unsqueeze(0)
def build_model(hypar, device):
    """Prepare the network from ``hypar`` for inference.

    Moves ``hypar["model"]`` to ``device``. If ``hypar["restore_model"]`` is truthy,
    it loads the state dict stored at ``hypar["model_path"]/hypar["restore_model"]``.
    The module is switched to eval mode and returned; it is the same object, not
    a copy.
    """
    model = hypar["model"]
    model.to(device)

    checkpoint_name = hypar["restore_model"]
    if checkpoint_name:
        weights_file = os.path.join(hypar["model_path"], checkpoint_name)
        # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
        state = torch.load(weights_file, map_location=device)
        model.load_state_dict(state)

    model.eval()
    return model
def predict(net, inputs_val, shapes_val, hypar, device):
    """Run ``net`` on one input batch and return a min-max scaled uint8 mask.

    Args:
        net: model. Its output is indexed as ``net(x)[0][0][0, :, :, :]``; the
            first sample of the first returned map is used.
        inputs_val: input batch tensor; cast to float32 and moved to ``device``.
        shapes_val: target sizes. ``shapes_val[0]`` gives the output
            ``(rows, cols)`` the prediction is bilinearly resized to.
        hypar: unused; kept for interface compatibility.
        device: device the input is moved to.

    Returns:
        The squeezed prediction as an ``np.uint8`` array scaled to 0..255. A
        constant prediction (no dynamic range) yields an all-zero mask instead of
        an undefined NaN-to-uint8 cast.
    """
    net.eval()
    inputs_val = inputs_val.type(torch.FloatTensor).to(device)
    with torch.no_grad():
        ds_val = net(inputs_val)[0]
        pred_val = ds_val[0][0, :, :, :]
        target_size = (int(shapes_val[0][0]), int(shapes_val[0][1]))
        # F.interpolate replaces the deprecated F.upsample. align_corners=False is
        # the default bilinear behavior; passing it explicitly pins it.
        pred_val = torch.squeeze(
            F.interpolate(
                torch.unsqueeze(pred_val, 0),
                target_size,
                mode="bilinear",
                align_corners=False,
            )
        )
        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        value_range = ma - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range
        else:
            # Avoid 0/0 -> NaN for a constant prediction.
            pred_val = torch.zeros_like(pred_val)
    return (pred_val.cpu().numpy() * 255).astype(np.uint8)
# Set Parameters
_SIDE = 512
hypar = dict(
    model_path="saved_models",   # directory holding the checkpoint (read by build_model)
    restore_model="isnet.pth",   # checkpoint file name inside model_path (read by build_model)
    cache_size=[_SIDE, _SIDE],   # size argument passed to im_preprocess by load_image
    input_size=[_SIDE, _SIDE],   # not read by any code visible here
    crop_size=[_SIDE, _SIDE],    # not read by any code visible here
    model=ISNetDIS(),
)

# Build Model: the ready-to-use, eval-mode network shared by every request.
net = build_model(hypar, device)
def inference(image):
    """Remove the background from one uploaded image.

    Args:
        image: file path of the uploaded image (the input component uses
            ``type='filepath'``). Gradio passes ``None`` when nothing was uploaded.

    Returns:
        ``[cutout, mask]``:
        - ``cutout`` is the RGB image with the predicted mask as its alpha
          channel (RGBA);
        - ``mask`` is the predicted mask as a greyscale ("L") image.

    Raises:
        gr.Error: if no image was provided. Gradio shows the message to the user.
    """
    if image is None:
        raise gr.Error("Please upload an image first.")

    image_tensor, orig_size = load_image(image, hypar)
    mask = predict(net, image_tensor, orig_size, hypar, device)
    pil_mask = Image.fromarray(mask).convert("L")

    # The context manager releases the file handle. convert() returns an
    # independent in-memory image, so it remains usable after the file is closed.
    with Image.open(image) as src:
        im_rgba = src.convert("RGB")

    # putalpha requires matching sizes. The mask comes from a separate decode in
    # load_image, so resize defensively if the two ever disagree.
    if pil_mask.size != im_rgba.size:
        pil_mask = pil_mask.resize(im_rgba.size, Image.BILINEAR)
    im_rgba.putalpha(pil_mask)
    return [im_rgba, pil_mask]
title = "Dichotomous Image Segmentation"
description = "Upload an image to remove its background."

# Build the Interface first and launch it separately, so `interface` refers to the
# Interface object itself rather than to whatever launch() returns.
interface = gr.Interface(
    fn=inference,
    inputs=gr.Image(type='filepath'),
    outputs=[gr.Image(type='filepath', format="png"), gr.Image(type='filepath', format="png")],
    title=title,
    description=description,
    flagging_mode="never",
    cache_mode="lazy",
)
interface.launch()