File size: 6,086 Bytes
c6dfc69
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import re
import PIL.Image
import matplotlib.pyplot as plt
import numpy
import torch
import pandas
import torchvision


class Visual(torch.utils.data.Dataset):
    """Base dataset that loads one video clip (frames + masks) as stacked tensors.

    The class itself defines neither ``__len__`` nor ``__getitem__``; it only
    provides the clip loader and prompt helpers (presumably for subclasses —
    confirm against the rest of the project).

    Args:
        augmentation: Callable applied to the data. It is called as
            ``augmentation(frames, labels)`` for the ``'train'`` split (whole
            clip at once) and as ``augmentation(frame, label, split=split)``
            per frame for the ``'test'`` split. Its outputs must be tensors,
            since they are mask-indexed and passed to ``torch.stack``.
        directory_path: Stored as-is; not used by the methods in this class.
        split: Split name; ``'train'`` and ``'test'`` trigger augmentation.
        image_size: Side length that normalised box coordinates are scaled to
            in :meth:`transform_coords`.
        image_embedding_size: Side length of the image embedding; the
            low-resolution mask prompt is ``4 * image_embedding_size`` wide.
    """

    def __init__(self, augmentation, directory_path, split, image_size, image_embedding_size):
        super().__init__()
        self.augment = augmentation
        self.directory_path = directory_path
        self.split = split
        self.image_size = image_size
        self.embedding_size = image_embedding_size

    @staticmethod
    def _frame_index(path):
        """Return the integer frame number at the end of a file name.

        ``'clip/frames/foo_0012.jpg'`` -> ``12``. The number is parsed as a
        whole, so names without zero padding still sort numerically
        (``foo_2`` before ``foo_10``).

        Raises:
            ValueError: If the trailing ``_``-separated token is not an integer.
        """
        stem = os.path.splitext(os.path.basename(path))[0]
        return int(stem.rsplit('_', 1)[-1])

    @classmethod
    def _sorted_files(cls, directory):
        """List the files of ``directory`` as full paths, ordered by frame number."""
        paths = [os.path.join(directory, name) for name in os.listdir(directory)]
        paths.sort(key=cls._frame_index)
        return paths

    def load_data(self, file_prefix):
        """Load every frame and mask of one clip.

        Frames are read from ``<file_prefix>/frames`` and masks from
        ``<file_prefix>/labels_rgb`` (converted to 8-bit greyscale), both
        ordered by the frame number in the file name.

        Args:
            file_prefix: Directory of a single clip.

        Returns:
            A tuple ``(images, labels, prompts)`` where ``images`` and
            ``labels`` are the per-frame tensors stacked along dim 0 and
            ``prompts`` is ``{'label_index': BoolTensor}``.

        Raises:
            ValueError: If a file name does not end in an integer frame number.
            IndexError: If there are fewer masks than frames.
        """
        frame_paths = self._sorted_files(os.path.join(file_prefix, 'frames'))
        label_paths = self._sorted_files(os.path.join(file_prefix, 'labels_rgb'))

        # NOTE: an earlier revision also appended SAM2 pseudo-labels
        # ('v1s_sam2_pseudo_labels') or blank masks for unlabeled training
        # frames; that path is disabled, so every frame needs a mask file here.
        frame = [PIL.Image.open(path) for path in frame_paths]
        label = [PIL.Image.open(path).convert('L') for path in label_paths]

        # Only the first entry is flagged.
        # NOTE(review): the length is fixed at five regardless of the number of
        # frames in the clip — confirm that clips always have five frames.
        label_idx = torch.tensor([1] + [0] * 4, dtype=torch.bool)

        if self.split == 'train':
            # Training augmentation is applied to the whole clip at once.
            frame, label = self.augment(frame, label)

        image_batch = []
        label_batch = []
        for i, curr_frame in enumerate(frame):
            curr_label = label[i]
            if self.split == 'test':
                # Test-time transform is applied frame by frame.
                curr_frame, curr_label = self.augment(curr_frame, curr_label, split=self.split)
            # Binarise in place: every positive value becomes foreground (1).
            # Per the original note, masks loaded in 'L' mode were not already
            # binary for v1s/v1m, hence the hard-coded thresholding.
            curr_label[curr_label > 0.] = 1.
            image_batch.append(curr_frame)
            label_batch.append(curr_label)

        # NOTE: box / mask prompts from `receive_other_prompts` are currently
        # not attached; only the label index is returned.
        prompts = {'label_index': label_idx}
        return torch.stack(image_batch, dim=0), torch.stack(label_batch, dim=0), prompts

    def receive_other_prompts(self, y_):
        """Derive a bounding-box prompt and a low-resolution mask prompt from a mask.

        Args:
            y_: Mask tensor whose last two dimensions are (height, width);
                values ``> 0`` count as foreground. (Assumed to be a single
                mask such as ``(1, H, W)`` — TODO confirm with callers.)

        Returns:
            ``(bbox_coord, low_mask)``. ``bbox_coord`` is a float tensor of
            shape ``(4,)`` holding ``(x_min, y_min, x_max, y_max)`` scaled by
            :meth:`transform_coords`; ``low_mask`` is ``y_`` resized with
            nearest-neighbour interpolation to ``4 * embedding_size`` per side.
            For a mask without any foreground both are filled with NaN
            (``low_mask`` then has shape ``(1, S, S)``).
        """
        low_res = self.embedding_size * 4

        # Guard on the same predicate the box computation uses, so that an
        # all-foreground mask still gets a box and a mask without positive
        # values never reaches `min`/`max` on an empty tensor.
        if not torch.any(y_ > 0):
            # Pure background: signal "no prompt" with NaNs.
            bbox_coord = torch.full([4], float('nan'), dtype=torch.float)
            low_mask = torch.full([1, low_res, low_res], float('nan'), dtype=torch.float)
            return bbox_coord, low_mask

        # `torch.where` yields indices in dimension order (..., row, col);
        # reversing puts x (col) in column 0 and y (row) in column 1.
        points_foreground = torch.stack(torch.where(y_ > 0)[::-1], dim=0).transpose(1, 0)
        xs, ys = points_foreground[:, 0], points_foreground[:, 1]

        # Box prompt: left-top corner followed by right-bottom corner.
        bbox_coord = torch.stack([xs.min(), ys.min(), xs.max(), ys.max()]).float()
        bbox_coord = self.transform_coords(bbox_coord, orig_hw=y_.shape[-2:])

        # Mask prompt.
        low_mask = torchvision.transforms.functional.resize(
            y_.clone(),
            [low_res, low_res],
            interpolation=torchvision.transforms.InterpolationMode.NEAREST,
        )
        return bbox_coord, low_mask

    def transform_coords(self, coords: torch.Tensor, orig_hw=None) -> torch.Tensor:
        """Rescale one box from original-image pixels to the model input size.

        The original note says this targets SAM's input resolution
        (1024, 1024); the actual target side length is ``self.image_size``.

        Args:
            coords: Tensor with exactly four values ``(x0, y0, x1, y1)`` in
                absolute pixel coordinates of the original image. Integer
                tensors are promoted to float so the normalisation does not
                truncate.
            orig_hw: ``(height, width)`` of the original image. Required
                despite the ``None`` default.

        Returns:
            A new tensor of shape ``(4,)``: each x divided by the width and
            each y by the height, then multiplied by ``self.image_size``.
            The input tensor is not modified.

        Raises:
            TypeError: If ``orig_hw`` is not given.
        """
        if orig_hw is None:
            raise TypeError("transform_coords() requires orig_hw=(height, width)")
        h, w = orig_hw

        coords = coords.clone().reshape(-1, 2, 2)
        if not coords.is_floating_point():
            coords = coords.float()
        coords[..., 0] = coords[..., 0] / w  # x -> [0, 1]
        coords[..., 1] = coords[..., 1] / h  # y -> [0, 1]
        coords = coords * self.image_size  # scale to the model input size
        return coords.reshape(4)