'''
Self-prompting strategy.

INPUT:
    predictor: the initialized SAM predictor
    rendered_mask_score: rendered mask score map, H*W*1
    num_prompt: number of prompts to generate
    index_matrix: the matrix containing the 3D index of the rendered view, H*W*3
OUTPUT: a list of prompts
'''
| import os | |
| import torch | |
| import math | |
| import numpy as np | |
def to8b(x):
    """Convert a float array in [0, 1] to a uint8 array in [0, 255], clipping out-of-range values."""
    return (255 * np.clip(x, 0, 1)).astype(np.uint8)
def mask_to_prompt(predictor, rendered_mask_score, index_matrix, num_prompts = 3):
    """Self-prompting: derive up to `num_prompts` foreground point prompts from a rendered mask score map.

    Args:
        predictor: initialized SAM predictor (only queried when num_prompts > 1).
        rendered_mask_score: (H, W, 1) tensor of per-pixel mask scores.
        index_matrix: (H, W, 3) tensor holding the 3D index of the rendered view.
        num_prompts: maximum number of point prompts to produce.

    Returns:
        (prompt_points, input_label): an (N, 2) array of (x, y) points and an
        (N,) array of ones (foreground labels). Both are empty (N == 0) when
        no pixel has a positive score.
    """
    h, w, _ = rendered_mask_score.shape
    flat_score = rendered_mask_score.view(-1)
    print("tmp min:", flat_score.min(), "tmp max:", flat_score.max())
    # Fix: compute topk once (the original called torch.topk twice, once for
    # the value and once for the index). The original also multiplied by an
    # all-ones tensor, which is an identity op and has been removed.
    topk_v, topk_p = torch.topk(flat_score, k = 1)
    topk_v, topk_p = topk_v.cpu(), topk_p.cpu()
    if topk_v <= 0:
        # No positive score anywhere: nothing to prompt with.
        print("No prompt is available")
        return np.zeros((0, 2)), np.ones((0))
    prompt_points = []
    # Flat index -> (x, y) image coordinates.
    prompt_points.append([topk_p[0] % w, topk_p[0] // w])
    print((topk_p[0] % w).item(), (topk_p[0] // w).item(), h, w)
    tmp_mask = rendered_mask_score.clone().detach()
    # Estimate an object radius from the positive-score area (to8b clips
    # negatives to 0), then suppress a square of half that radius around each
    # chosen prompt so subsequent prompts spread out.
    area = to8b(tmp_mask.cpu().numpy()).sum() / 255
    r = np.sqrt(area / math.pi)
    masked_r = max(int(r) // 2, 2)
    pre_tmp_mask_score = None
    for _ in range(num_prompts - 1):
        input_label = np.ones(len(prompt_points))
        previous_masks, previous_scores, previous_logits = predictor.predict(
            point_coords=np.array(prompt_points),
            point_labels=input_label,
            multimask_output=False,
        )
        # Clamp the suppression square around the last prompt to the image
        # bounds, then mask it out so the next prompt lands elsewhere.
        left = max(prompt_points[-1][0] - masked_r, 0)
        right = min(prompt_points[-1][0] + masked_r, w - 1)
        top = max(prompt_points[-1][1] - masked_r, 0)
        bot = min(prompt_points[-1][1] + masked_r, h - 1)
        tmp_mask[top:bot + 1, left:right + 1, :] = -1e5
        # Dilate SAM's predicted mask with a 25x25 max-pool and take the best
        # rendered score inside it as the scale of the distance penalty.
        previous_mask_tensor = torch.tensor(previous_masks[0])
        previous_mask_tensor = previous_mask_tensor.unsqueeze(0).unsqueeze(0).float()
        previous_mask_tensor = torch.nn.functional.max_pool2d(previous_mask_tensor, 25, stride = 1, padding = 12)
        previous_mask_tensor = previous_mask_tensor.squeeze(0).permute([1, 2, 0])
        previous_max_score = torch.max(rendered_mask_score[previous_mask_tensor > 0])
        # Per-pixel 3D "distance" to the last prompt: normalized y, x plus the
        # third channel of index_matrix at the prompt location.
        previous_point_index = torch.zeros_like(index_matrix)
        previous_point_index[:, :, 0] = prompt_points[-1][1] / h
        previous_point_index[:, :, 1] = prompt_points[-1][0] / w
        previous_point_index[:, :, 2] = index_matrix[int(prompt_points[-1][1]), int(prompt_points[-1][0]), 2]
        distance_matrix = torch.sqrt(((index_matrix - previous_point_index) ** 2).sum(-1))
        # NOTE(review): if distance_matrix is constant, max()-min() is zero and
        # this divides by zero — original had the same issue; confirm inputs.
        distance_matrix = (distance_matrix.unsqueeze(-1) - distance_matrix.min()) / (distance_matrix.max() - distance_matrix.min())
        cur_tmp_mask = tmp_mask - distance_matrix * max(previous_max_score, 0)
        if pre_tmp_mask_score is None:
            pre_tmp_mask_score = cur_tmp_mask
        else:
            # Keep the element-wise maximum of penalized scores across rounds.
            better = pre_tmp_mask_score < cur_tmp_mask
            pre_tmp_mask_score[better] = cur_tmp_mask[better]
        # Suppressed regions stay suppressed regardless of the penalty term.
        pre_tmp_mask_score[tmp_mask == -1e5] = -1e5
        tmp_val_point = pre_tmp_mask_score.view(-1).max(dim = 0)
        if tmp_val_point[0] <= 0:
            print("There are", len(prompt_points), "prompts")
            break
        prompt_points.append([int(tmp_val_point[1].cpu() % w), int(tmp_val_point[1].cpu() // w)])
    prompt_points = np.array(prompt_points)
    input_label = np.ones(len(prompt_points))
    return prompt_points, input_label
| from groundingdino.util.inference import load_model, load_image, predict, annotate | |
| import cv2 | |
| import groundingdino.datasets.transforms as T | |
| from torchvision.ops import box_convert | |
| from PIL import Image | |
def image_transform(image) -> torch.Tensor:
    """Preprocess a PIL image for GroundingDINO: resize to 800 (max 1333), convert to tensor, ImageNet-normalize."""
    preprocess = T.Compose([
        T.RandomResize([800], max_size=1333),
        T.ToTensor(),
        T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    ])
    tensor_image, _ = preprocess(image, None)
    return tensor_image
def grounding_dino_prompt(image, text, box_threshold=0.35, text_threshold=0.25):
    """Run GroundingDINO on an image with a text caption and return detected boxes.

    Args:
        image: (H, W, C) numpy image (uint8), converted to PIL internally.
        text: text prompt/caption describing the target object(s).
        box_threshold: minimum box confidence (default matches the original 0.35).
        text_threshold: minimum text-match confidence (default matches the original 0.25).

    Returns:
        (N, 4) numpy array of boxes in pixel xyxy format.
    """
    image_tensor = image_transform(Image.fromarray(image))
    # Fix: the original reloaded the model from disk on every call; cache it
    # on the function so repeated prompts reuse one loaded model.
    if not hasattr(grounding_dino_prompt, "_model"):
        model_root = './dependencies/GroundingDINO'
        grounding_dino_prompt._model = load_model(
            os.path.join(model_root, "groundingdino/config/GroundingDINO_SwinT_OGC.py"),
            os.path.join(model_root, "weights/groundingdino_swint_ogc.pth"),
        )
    boxes, logits, phrases = predict(
        model=grounding_dino_prompt._model,
        image=image_tensor,
        caption=text,
        box_threshold=box_threshold,
        text_threshold=text_threshold
    )
    h, w, _ = image.shape
    print("boxes device", boxes.device)
    # Scale normalized cxcywh boxes up to pixel units, then convert to corners.
    boxes = boxes * torch.Tensor([w, h, w, h]).to(boxes.device)
    xyxy = box_convert(boxes=boxes, in_fmt="cxcywh", out_fmt="xyxy").numpy()
    print(xyxy)
    return xyxy
'''
# new prompt strategy: bbox based
# cannot be applied to 360
try:
    prompt = seg_m_for_prompt[:, :, no]
    prompt = prompt > 0
    box_prompt = masks_to_boxes(prompt.unsqueeze(0))
    width = box_prompt[0, 2] - box_prompt[0, 0]
    height = box_prompt[0, 3] - box_prompt[0, 1]
    box_prompt[0, 0] -= 0.05 * width
    box_prompt[0, 2] += 0.05 * width
    box_prompt[0, 1] -= 0.05 * height
    box_prompt[0, 3] += 0.05 * height
    # print(box_prompt)
    transformed_boxes = predictor.transform.apply_boxes_torch(box_prompt, image.shape[:2])
    masks, _, _ = predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )
    masks = masks.float()
except:
    continue
'''
'''
# mask based
H, W, _ = prompt.shape
target_size = RLS.get_preprocess_shape(H, W, 256)
prompt = torch.nn.functional.interpolate(torch.tensor(prompt).float().unsqueeze(0).permute([0, 3, 1, 2]), target_size, mode = 'bilinear')
h, w = prompt.shape[-2:]
padh = 256 - h
padw = 256 - w
prompt = F.pad(prompt, (0, padw, 0, padh))
prompt = (prompt / 255) * 40 - 20
print(prompt.shape)
prompt = prompt.squeeze(1).cpu().numpy()
masks, scores, logits = predictor.predict(
    point_coords=None,
    point_labels=None,
    mask_input=prompt,
    multimask_output=False,
)
'''