import torch
from PIL import Image
import os
from torch.utils.data import Dataset
import torch.distributed as dist


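def ddp_setup(rank: int, world_size: int):
    """Join the default NCCL process group on a single node and bind this process to GPU ``rank``."""
    # All processes rendezvous on the local machine at a fixed port.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "12355"
    dist.init_process_group(backend="nccl", rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)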


# class BlipCaptionedWrapper(nn.Module):
#     def __init__(self, model, num_of_captions=1):
#         super().__init__()
#         self.model = model
#         self.num_of_captions = num_of_captions
#
#     def forward(self, image_inputs):
#         return self.model.generate({"image": image_inputs}, use_nucleus_sampling=True, num_captions=self.num_of_captions)

class PrepareImageForBlip:
    """Callable transform that applies the given BLIP image processor to one image."""

    def __init__(self, processor):
        self.processor = processor

    def __call__(self, image):
        return self.processor(image)


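class ImageDataset(Dataset):
    """Load images from ``root_dir`` and return ``(image, filename)`` pairs."""

    def __init__(self, root_dir, transform=None, path_list=None):
        self.root_dir = root_dir
        # Use the explicit file list when one is given (even an empty one);
        # otherwise fall back to every entry in root_dir.
        if path_list is not None:
            self.image_list = path_list
        else:
            self.image_list = os.listdir(root_dir)
        self.transform = transform

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        filename = self.image_list[idx]
        image_path = os.path.join(self.root_dir, filename)
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(image_path) as img:
            # Convert to grayscale, then back to three channels, so every image
            # reaches the transform as a gray RGB image.
            img = img.convert("L").convert("RGB")

        if self.transform:
            img = self.transform(img)

        return img, filename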


class ImageDatasetFromImageList(Dataset):
    """Wrap an in-memory list of already prepared images as a Dataset."""

    def __init__(self, image_list):
        self.image_list = image_list

    def __len__(self):
        return len(self.image_list)

    def __getitem__(self, idx):
        return self.image_list[idx]