Spaces:
Runtime error
Runtime error
| # coding:utf-8 | |
| import os | |
| import numpy as np | |
| import cv2 | |
| from typing import Optional | |
| import torch | |
| # from models.transforms import ResizeLongestSide | |
| # from .transforms import ResizeLongestSide | |
| from torchvision import transforms | |
def get_prompt_inp_scatter(scatter_file_):
    """Load a scatter-prompt image from ``scatter_file_`` as stored on disk.

    ``cv2.IMREAD_UNCHANGED`` keeps the file's own channel count and bit depth
    (no colour conversion). As with any ``cv2.imread`` call, an unreadable
    path yields ``None`` rather than an exception.
    """
    return cv2.imread(scatter_file_, cv2.IMREAD_UNCHANGED)
def pre_scatter_prompt(scatter, filp, device):
    """Convert a scatter-prompt image into a normalized tensor on ``device``.

    Args:
        scatter: Image array, e.g. from ``get_prompt_inp_scatter``, in the
            HWC layout that ``transforms.ToTensor`` expects.
        filp: If truthy, mirror the image horizontally (``cv2.flip`` with
            flip code 1) before conversion. This uses the same truthiness
            test as ``get_prompt_inp``'s ``filp`` flag.
        device: Target device passed to ``Tensor.to``.

    Returns:
        The ``ToTensor`` output (CHW float), normalized with the ImageNet
        mean/std below and moved to ``device``.
    """
    if filp:
        scatter = cv2.flip(scatter, 1)
    # NOTE(review): these are the usual 3-channel RGB ImageNet statistics.
    # Images loaded via cv2 are in BGR order, and a single-channel image is
    # unlikely to broadcast against a 3-element mean -- confirm the expected
    # input format.
    img_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    return img_transform(scatter).to(device)
def get_prompt_inp(txt_file_, filp=False, image_width=1024.0):
    """Parse a polygon annotation file into point and box prompts.

    Each non-blank line must start with eight whitespace-separated polygon
    coordinates ``x1 y1 x2 y2 x3 y3 x4 y4`` followed by a class name. Any
    further tokens (e.g. a trailing flag) are ignored. Each polygon is
    reduced to its axis-aligned bounding box and that box's centre.

    Args:
        txt_file_: Path of the annotation text file.
        filp: If truthy, mirror x coordinates horizontally
            (``x -> image_width - x``) to match a horizontally flipped image.
        image_width: Width used for the horizontal mirror. It defaults to
            1024.0, the value that was previously hard-coded.

    Returns:
        ``(points, labels, boxes, None)``, where:

        - ``points`` is a list of ``[x_center, y_center]``.
        - ``labels`` is the list of class-name strings.
        - ``boxes`` is a list of ``[[xmin, ymin], [xmax, ymax]]`` with
          ``xmin <= xmax`` and ``ymin <= ymax`` (also after flipping).
        - The fourth element is always ``None``; no mask prompts are built.

    Raises:
        ValueError: If a non-blank line has fewer than nine tokens or a
            coordinate token is not a number.
    """
    points = []
    labels = []
    boxes = []
    # ``with`` guarantees the file is closed even if parsing raises.
    with open(txt_file_) as f:
        for line_no, line in enumerate(f, start=1):
            # split() with no argument tolerates repeated spaces, tabs and
            # "\r\n" endings, and drops the trailing newline.
            tokens = line.split()
            if not tokens:
                # Skip blank lines, e.g. an empty last line of the file.
                continue
            if len(tokens) < 9:
                raise ValueError(
                    f"{txt_file_}:{line_no}: expected 8 coordinates and a "
                    f"class name, got {len(tokens)} token(s)"
                )
            coords = [float(tok) for tok in tokens[:8]]
            classname = tokens[8]
            xs = coords[0::2]
            ys = coords[1::2]
            xmin, xmax = min(xs), max(xs)
            ymin, ymax = min(ys), max(ys)
            if filp:
                # Mirroring reverses the x order: the old right edge becomes
                # the new left edge. Swap while reflecting so xmin <= xmax.
                xmin, xmax = image_width - xmax, image_width - xmin
            points.append([(xmin + xmax) / 2, (ymin + ymax) / 2])
            labels.append(classname)
            boxes.append([[xmin, ymin], [xmax, ymax]])
    return points, labels, boxes, None
def pre_prompt(points=None, boxes=None, masks=None, device=None, *, scale=16.0):
    """Convert optional point, box and mask prompts into float tensors.

    Point and box coordinates are divided by ``scale``. It defaults to 16.0,
    the value that was previously hard-coded. NOTE(review): presumably this
    maps image-pixel coordinates onto a 1/16-resolution feature map --
    confirm. Masks are converted but not rescaled. Any prompt passed as
    ``None`` is returned unchanged as ``None``.

    Args:
        points: Array-like of point coordinates, or ``None``.
        boxes: Array-like of box corner coordinates, or ``None``.
        masks: Array-like mask prompt, or ``None``.
        device: Forwarded to ``torch.as_tensor`` for every created tensor.
        scale: Keyword-only divisor applied to point and box coordinates.

    Returns:
        ``(points_torch, boxes_torch, masks_torch)``.
    """
    # ``is not None`` instead of ``!= None``: for a NumPy array ``!= None`` is
    # element-wise, and using the result in ``if`` raises an ambiguity error.
    points_torch = points
    if points is not None:
        points_torch = torch.as_tensor(points, dtype=torch.float, device=device) / scale
    boxes_torch = boxes
    if boxes is not None:
        boxes_torch = torch.as_tensor(boxes, dtype=torch.float, device=device) / scale
    masks_torch = masks
    if masks is not None:
        masks_torch = torch.as_tensor(masks, dtype=torch.float, device=device)
    return points_torch, boxes_torch, masks_torch
| # def pre_prompt( | |
| # point_coords: Optional[np.ndarray] = None, | |
| # point_labels: Optional[np.ndarray] = None, | |
| # box: Optional[np.ndarray] = None, | |
| # mask_input: Optional[np.ndarray] = None, | |
| # device=None, | |
| # original_size = [1024, 1024] | |
| # ): | |
| # | |
| # transform = ResizeLongestSide(1024) | |
| # # Transform input prompts | |
| # coords_torch, labels_torch, box_torch, mask_input_torch = None, None, None, None | |
| # if point_coords is not None: | |
| # assert ( | |
| # point_labels is not None | |
| # ), "point_labels must be supplied if point_coords is supplied." | |
| # point_coords = transform.apply_coords(point_coords, original_size) | |
| # coords_torch = torch.as_tensor(point_coords, dtype=torch.float, device=device) | |
| # labels_torch = torch.as_tensor(point_labels, dtype=torch.int, device=device) | |
| # coords_torch, labels_torch = coords_torch[None, :, :], labels_torch[None, :] | |
| # if box is not None: | |
| # box = transform.apply_boxes(box, original_size) | |
| # box_torch = torch.as_tensor(box, dtype=torch.float, device=device) | |
| # box_torch = box_torch[None, :] | |
| # if mask_input is not None: | |
| # mask_input_torch = torch.as_tensor(mask_input, dtype=torch.float, device=device) | |
| # mask_input_torch = mask_input_torch[None, :, :, :] | |
| # | |
| # return coords_torch, labels_torch, box_torch, mask_input_torch | |
if __name__ == '__main__':
    # Smoke test: parse one prompt file from the iSAID training split and
    # print the resulting prompts.
    txt_dir = './ISAID/train/trainprompt/sub_labelTxt/'
    # os.listdir order is arbitrary; sort so the "first" file is reproducible.
    txt_list = sorted(os.listdir(txt_dir))
    if not txt_list:
        raise SystemExit(f'No prompt files found in {txt_dir}')
    txt_file_0 = os.path.join(txt_dir, txt_list[0])
    # Pass the flip flag explicitly; this check uses the unflipped annotation.
    points, labels, boxes, masks = get_prompt_inp(txt_file_0, False)
    print(points)
    print(labels)
    print(boxes)
    # Scale boxes down by 16, the same divisor pre_prompt applies by default.
    boxes_torch = torch.as_tensor(boxes, dtype=torch.float)
    boxes_torch = boxes_torch / 16.0
    print(boxes_torch, boxes_torch.shape)