Spaces:
Runtime error
Runtime error
| import os | |
# Install the bundled MultiScaleDeformableAttention wheel before the imports
# that follow (presumably needed by one of the model packages below — confirm).
# `python -m pip` with the running interpreter guarantees the wheel lands in
# this environment, not in whichever `pip` happens to be first on PATH; an
# argument list avoids going through a shell. The install stays best-effort
# (as before), but a failure is now reported instead of silently ignored.
import subprocess
import sys
if subprocess.run(
    [sys.executable, "-m", "pip", "install", "./MultiScaleDeformableAttention-1.0-py3-none-any.whl"]
).returncode != 0:
    print("WARNING: failed to install MultiScaleDeformableAttention wheel", file=sys.stderr)
| import gradio as gr | |
| from huggingface_hub import hf_hub_download | |
| import numpy as np | |
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| import cv2 | |
| from PIL import Image | |
| import torch.nn as nn | |
| from torch.autograd import Variable | |
| from torchvision import transforms | |
| import torch.nn.functional as F | |
| import gdown | |
| import os | |
| from io import BytesIO | |
| from IS_Net.data_loader import normalize, im_reader, im_preprocess | |
| from IS_Net.models.isnet import ISNetGTEncoder, ISNetDIS | |
| from SAM.segment_anything import sam_model_registry, SamPredictor | |
# Run on the GPU when PyTorch can see one, otherwise on the CPU. Kept as a
# plain string because other code compares it with == 'cuda'.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
def show_gray_images(images, m=8, alpha=3):
    """Display a stack of grayscale images in a grid with ``m`` columns.

    Args:
        images: 3-D array indexed as ``images[i]`` -> one (h, w) image.
        m: number of grid columns.
        alpha: figure-size scale; each grid cell is ``2 * alpha`` inches square.

    Does nothing when ``images`` is empty.
    """
    n, _, _ = images.shape  # unpacking also enforces a 3-D input
    if n == 0:
        # plt.subplots cannot create a grid with zero rows.
        return
    num_rows = (n + m - 1) // m
    # squeeze=False keeps ``axes`` a 2-D array even for a single row or a
    # single column, so one indexing scheme works for every grid shape
    # (the old code crashed for 1x1 grids and for partially filled single rows).
    _, axes = plt.subplots(num_rows, m, figsize=(m * 2 * alpha, num_rows * 2 * alpha), squeeze=False)
    plt.subplots_adjust(wspace=0.05, hspace=0.05)
    # axes.flat walks the grid row by row, so position == i * m + j.
    for idx, ax in enumerate(axes.flat):
        if idx < n:
            ax.imshow(images[idx], cmap='gray')
        # Hide frames on unused trailing cells as well as on filled ones.
        ax.axis('off')
    plt.show()
def show_mask(mask, ax, random_color=False):
    """Overlay ``mask`` on ``ax`` as a translucent RGBA layer.

    The mask's last two dimensions are taken as (height, width); each pixel is
    multiplied by an RGBA colour with alpha 0.6 — a fixed blue by default, or a
    random RGB when ``random_color`` is true.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([30/255, 144/255, 255/255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)
def show_points(coords, labels, ax, marker_size=375):
    """Draw prompt points on ``ax`` as star markers.

    Points with label 1 are drawn in green, points with label 0 in red; the
    green set is drawn first. ``coords`` is indexed as ``coords[:, 0]`` (x)
    and ``coords[:, 1]`` (y).
    """
    marker_style = dict(marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    for wanted_label, colour in ((1, 'green'), (0, 'red')):
        selected = coords[labels == wanted_label]
        ax.scatter(selected[:, 0], selected[:, 1], color=colour, **marker_style)
def show_box(box, ax):
    """Outline the box ``[x0, y0, x1, y1]`` on ``ax`` with a green, unfilled rectangle."""
    x_min, y_min, x_max, y_max = box[0], box[1], box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0,0,0,0),
        lw=2,
    )
    ax.add_patch(outline)
# Segment Anything (ViT-L) produces the coarse masks that DIS-SAM refines.
# For offline use, point sam_checkpoint at a local copy of
# sam_vit_l_0b3195.pth instead of downloading it from the Hub.
model_type = "vit_l"
sam_checkpoint = hf_hub_download(repo_id="andzhang01/segment_anything", filename="sam_vit_l_0b3195.pth")
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint, device=device)
sam.to(device=device)
predictor = SamPredictor(sam)
class GOSNormalize:
    """Callable transform that normalizes a tensor with per-channel mean/std.

    Delegates to ``normalize(image, mean, std)``; presumably the usual
    ``(x[c] - mean[c]) / std[c]`` per channel — confirm against
    ``IS_Net.data_loader.normalize``.
    """

    def __init__(self, mean=[0.485,0.456,0.406,0], std=[0.229,0.224,0.225,1.0]):
        # Stored as given; the default lists are shared between instances,
        # so callers should not mutate ``self.mean``/``self.std`` in place.
        self.mean, self.std = mean, std

    def __call__(self, image):
        """Return ``image`` normalized with the stored mean and std."""
        return normalize(image, self.mean, self.std)
# Input normalization for the 5-channel DIS-SAM network: the three image
# channels use mean 0.5, while the mask and box channels (4th and 5th) use
# mean 0 / std 1, i.e. they are left as-is (assuming (x - mean) / std semantics).
transform = transforms.Compose([
    GOSNormalize(mean=[0.5] * 3 + [0, 0], std=[1.0] * 5),
])
def build_model(hypar, device):
    """Prepare ``hypar["model"]`` for inference and return it.

    Args:
        hypar: configuration dict. Reads
            ``"model"`` - the ``nn.Module`` to prepare (modified in place),
            ``"model_digit"`` - ``"half"`` converts the model to float16,
            ``"restore_model"`` - checkpoint file name, or ``""`` to skip loading,
            ``"model_path"`` - directory of the checkpoint (read only when
            ``restore_model`` is non-empty).
        device: device the model is moved to and the checkpoint is mapped to.

    Returns:
        The same module, on ``device`` and in eval mode.
    """
    net = hypar["model"]
    if hypar["model_digit"] == "half":
        net.half()
        # Keep BatchNorm layers in float32 (presumably for numerical stability
        # of the running statistics in half precision).
        for layer in net.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.float()
    net.to(device)
    if hypar["restore_model"] != "":
        checkpoint_path = os.path.join(hypar["model_path"], hypar["restore_model"])
        # SECURITY NOTE: torch.load unpickles the file; only load trusted checkpoints.
        # load_state_dict copies into the existing (already moved) parameters,
        # so no second .to(device) is needed afterwards.
        net.load_state_dict(torch.load(checkpoint_path, map_location=device))
    net.eval()
    return net
def get_box(input_box, size):
    """Rasterize a box into a float tensor of shape ``size``.

    ``input_box`` is ``[x0, y0, x1, y1]``; rows ``y0:y1`` and columns
    ``x0:x1`` (end-exclusive slicing) are set to 255 (white), everything
    else is 0.
    """
    x_start, y_start, x_stop, y_stop = input_box[0], input_box[1], input_box[2], input_box[3]
    # Start from an all-zero (black) canvas, then paint the box region white.
    canvas = torch.zeros(size)
    canvas[y_start:y_stop, x_start:x_stop] = 255
    return canvas
def get_box_from_mask(gt):
    """Return a filled bounding-box image around the foreground of ``gt``.

    Args:
        gt: 2-D mask, anything ``np.array`` accepts (bool, 0/1, 0/255, ...);
            pixels with value > 0 are foreground.

    Returns:
        float32 tensor with ``gt``'s shape: 255 inside the tight bounding box
        of the foreground (including its last row and column), 0 elsewhere.
        All zeros when ``gt`` has no foreground pixel.
    """
    foreground = torch.from_numpy(np.array(gt)).float() > 0
    box = torch.zeros(foreground.shape, dtype=torch.float32)
    rows, cols = torch.where(foreground)
    if rows.numel() == 0:
        # Empty mask: min/max of an empty tensor would raise.
        return box
    top, bottom = int(rows.min()), int(rows.max())
    left, right = int(cols.min()), int(cols.max())
    # +1 because max() returns the last foreground index while slicing is
    # end-exclusive; without it a one-row/one-column mask gives an empty box.
    box[top:bottom + 1, left:right + 1] = 255
    return box
def predict_one(net, image, mask, box, transforms, hypar, device):
    """Refine a coarse mask with the DIS-SAM network and return a binary mask.

    The image, the coarse mask and the box image are stacked channel-wise,
    resized to ``hypar["input_size"]``, scaled to [0, 1], normalized with
    ``transforms`` and fed to ``net``. The first output is resized back to
    the input resolution, min-max normalized and binarized with Otsu's method.

    Args:
        net: network; ``net(x)[0][0]`` is used as the prediction
            (presumably shaped (1, 1, H', W') — confirm against ISNetDIS).
        image: (H, W, C) array-like. C is expected to be 3 so that, with the
            mask and box, the input has the 5 channels the network is built
            for (assumption — confirm).
        mask: (H, W) coarse mask; a mask whose maximum is 1 (bool or 0/1) is
            rescaled to 0/255.
        box: (H, W) box image with values 0/255 (e.g. from get_box_from_mask).
        transforms: callable applied to the scaled input tensor.
        hypar: dict; reads ``"input_size"`` and ``"model_digit"``
            (``"full"`` -> float32 input, anything else -> float16).
        device: device inference runs on.

    Returns:
        (H, W) uint8 numpy array with values 0 or 255.
    """
    with torch.no_grad():
        image = torch.from_numpy(np.array(image))
        mask = torch.from_numpy(np.array(mask))
        box = torch.from_numpy(np.array(box))
        if mask.max() == 1:
            # Bring a binary mask to the same 0-255 range as the image.
            mask = mask.type(torch.float32) * 255.0

        # (H, W, C+2) -> (1, C+2, H, W). Force float so bilinear resizing
        # works regardless of the input dtypes (uint8/bool would not resize).
        inputs_val_v = torch.cat([image, mask[..., None], box[..., None]], dim=2).float()
        inputs_val_v = inputs_val_v.permute(2, 0, 1)[None, ...]
        shapes_val = inputs_val_v.shape[-2:]
        inputs_val_v = F.interpolate(inputs_val_v, hypar["input_size"], mode='bilinear')

        # Bilinear resizing blurs the box edges; re-binarize the box channel.
        # box_channel is a view, so this edits inputs_val_v in place.
        box_channel = inputs_val_v[0][-1]
        box_channel[box_channel > 127] = 255
        box_channel[box_channel <= 127] = 0

        inputs_val_v = inputs_val_v.divide(255.0)
        net.eval()
        if hypar["model_digit"] == "full":
            inputs_val_v = inputs_val_v.type(torch.FloatTensor)
        else:
            inputs_val_v = inputs_val_v.type(torch.HalfTensor)
        inputs_val_v = transforms(inputs_val_v.to(device))

        ds_val = net(inputs_val_v)[0][0]
        # Recover the prediction at the original image resolution.
        pred_val = F.interpolate(ds_val, shapes_val, mode='bilinear')[0][0]

        ma = torch.max(pred_val)
        mi = torch.min(pred_val)
        value_range = ma - mi
        if value_range > 0:
            pred_val = (pred_val - mi) / value_range  # now in [0, 1]
        else:
            # Constant prediction: 0/0 would produce NaN, which turns into
            # arbitrary values after the uint8 cast below.
            pred_val = torch.zeros_like(pred_val)
        if device == 'cuda':
            torch.cuda.empty_cache()

        refined_mask = (pred_val.detach().cpu().numpy() * 255).astype(np.uint8)
        # With THRESH_OTSU the threshold is chosen automatically; the 0 passed
        # here is ignored. Output pixels are 0 or 255.
        _, binary = cv2.threshold(refined_mask, 0, 255, cv2.THRESH_OTSU)
        return binary
# Inference parameters for the DIS-SAM refinement network. build_model loads
# the checkpoint from model_path/restore_model, i.e. the file downloaded here.
dis_model_path = hf_hub_download(repo_id="jwlarocque/DIS-SAM", filename="DIS-SAM-checkpoint.pth")
hypar = {
    "model_path": os.path.dirname(dis_model_path),
    "restore_model": os.path.basename(dis_model_path),
    # "full" -> float32 inputs in predict_one; "half" would convert the model to float16.
    "model_digit": "full",
    # Spatial size the stacked input is resized to before the forward pass.
    "input_size": [1024, 1024],
    # Image (3) + coarse mask (1) + box (1) input channels.
    "model": ISNetDIS(in_ch=5),
}
net = build_model(hypar, device)
def bbox_from_str(bbox_str: str):
    """Parse an ``"x1,y1,x2,y2"`` string into a 4-element integer numpy array.

    Returns None for an empty/None input, a wrong number of comma-separated
    fields, or any field that is not an integer. Surrounding whitespace (on
    the whole string and on each field) is tolerated.
    """
    if not bbox_str:
        return None
    fields = bbox_str.strip().split(",")
    if len(fields) != 4:
        return None
    try:
        coords = list(map(int, fields))
    except ValueError:
        return None
    return np.array(coords)
def predict(input_img: np.ndarray, bbox_str: str):
    """Segment ``input_img`` with SAM from a box prompt, then refine it with DIS-SAM.

    Args:
        input_img: image array from the Gradio image component, indexed as
            (height, width, ...); None when the user submitted without an image.
        bbox_str: ``"x1,y1,x2,y2"`` pixel box. Empty or unparsable input falls
            back to a box covering the whole image.

    Returns:
        ``(sam_mask, refined_mask)``: two (H, W) uint8 arrays with values 0/255.

    Raises:
        gr.Error: if no image was provided (shown to the user by Gradio).
    """
    if input_img is None:
        # Without this guard SAM fails with an opaque internal error.
        raise gr.Error("Please upload or capture an image first.")
    predictor.set_image(input_img)
    input_box = bbox_from_str(bbox_str)
    if input_box is None:
        height, width = input_img.shape[:2]
        input_box = np.array([0, 0, width, height])
    # NOTE(review): point_labels without point_coords is ignored by upstream
    # segment_anything's SamPredictor.predict; kept in case the vendored copy
    # differs. Also, of the three masks returned, the first is used
    # regardless of `scores` — SAM's docs suggest picking the highest-scoring
    # one; confirm which the DIS-SAM model expects before changing.
    masks, scores, logits = predictor.predict(
        box=input_box,
        point_labels=np.array([1]),
        multimask_output=True,
    )
    sam_mask = masks[0]
    sam_box = get_box_from_mask(sam_mask)
    refined_mask = predict_one(net, input_img, sam_mask, sam_box, transform, hypar, device)
    mask_gray = (sam_mask * 255).astype(np.uint8)
    refined_mask_gray = refined_mask.astype(np.uint8)
    return mask_gray, refined_mask_gray
# Gradio UI: an image plus an optional "x1,y1,x2,y2" box prompt in; the raw
# SAM mask and the DIS-SAM refined mask out (both as grayscale images).
gradio_app = gr.Interface(
    fn=predict,
    inputs=[
        gr.Image(label="Select Image", sources=['upload', 'webcam'], type="numpy"),
        gr.Textbox(label="Bounding Box Prompt (pixels)", placeholder="x1,y1,x2,y2"),
    ],
    outputs=[
        gr.Image(label="SAM Mask", type="numpy", image_mode="L"),
        gr.Image(label="DIS-SAM Mask", type="numpy", image_mode="L"),
    ],
    title="DIS-SAM",
    # An empty box string uses the whole image as the prompt.
    examples=[
        ["./images/wire_shelf.jpg", "20,100,480,660"],
        ["./images/radio_telescope.jpg", "1130,320,4000,2920"],
        ["./images/bridge.jpg", ""],
        ["./images/tree.jpg", "70,110,2290,1800"],
        ["./images/bicycle.jpg", "135,235,2425,1580"],
        ["./images/capybara.jpg", "630,440,2060,1650"],
        ["./images/capybara.jpg", "1050,173,1550,618"],
    ],
)

if __name__ == "__main__":
    gradio_app.launch()