File size: 1,806 Bytes
9ed01de
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import glob
import os
import PIL
import PIL.Image
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
import open_clip


class SingleFolderDataset(torch.utils.data.Dataset):
    """Map-style dataset over the files directly inside one folder.

    Each item is ``(image, file_name)``:
    - ``image`` is the decoded PIL image, passed through ``transform`` when one is given.
    - ``file_name`` is the file's basename, returned as the image's id.

    Args:
        folder_path: Directory whose immediate files are loaded. Sub-directories
            are skipped. No extension filter is applied, so every regular file
            is expected to be an image PIL can open.
        transform: Optional callable applied to each loaded PIL image.
    """

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        # Sort so item order (and therefore the order of anything computed
        # from this dataset) is reproducible; raw glob order depends on the
        # filesystem. Directories are dropped because PIL cannot open them.
        self.image_paths = sorted(
            path
            for path in glob.glob(os.path.join(folder_path, "*"))
            if os.path.isfile(path)
        )
        print('Found {} images in {}'.format(len(self.image_paths), folder_path))

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        # PIL.Image.open is lazy and keeps the file handle open until the
        # image is loaded/closed. Force the decode inside a context manager
        # so the handle is always released, even when no transform is set.
        with PIL.Image.open(image_path) as img:
            img.load()
            if self.transform:
                image = self.transform(img)
            else:
                # Detach from the file-backed object before it is closed.
                image = img.copy()
        return image, os.path.basename(image_path)


def extract_feats(index_config):
    """Encode every image in a folder with an open_clip image encoder.

    Args:
        index_config: Dict-like config with keys
            ``'a1_config'``: model name passed to
                ``open_clip.create_model_and_transforms``.
            ``'weight_path'``: value for its ``pretrained`` argument.
            ``'img_dir'``: folder whose files are encoded.
            ``'batch_size'`` (optional, default 1024): DataLoader batch size.
            ``'num_workers'`` (optional, default 4): DataLoader worker count.

    Returns:
        ``(im_ids, im_feats)``:
        - ``im_ids`` is a NumPy array of image file basenames.
        - ``im_feats`` holds the ``model.encode_image`` outputs concatenated
          along the batch axis, in the same order as ``im_ids``.

    Raises:
        ValueError: if ``img_dir`` contains no files.
    """
    model_name = index_config['a1_config']
    weight_path = index_config['weight_path']
    img_dir = index_config['img_dir']
    batch_size = index_config.get('batch_size', 1024)
    num_workers = index_config.get('num_workers', 4)

    model, _, transform = open_clip.create_model_and_transforms(model_name, pretrained=weight_path)

    device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
    print('Using device:', device)
    # open_clip returns the model in train mode; eval() disables dropout /
    # stochastic depth so the extracted embeddings are deterministic.
    model = model.to(device).eval()

    dataset = SingleFolderDataset(img_dir, transform=transform)
    if len(dataset) == 0:
        # np.concatenate on empty lists would fail later with an opaque error.
        raise ValueError('No images found in {}'.format(img_dir))

    dl = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)

    im_ids = []
    im_feats = []

    with torch.no_grad():
        # Passing the DataLoader directly lets tqdm show the total batch count.
        for images, batch_ids in tqdm(dl):
            out = model.encode_image(images.to(device))
            im_ids.append(batch_ids)
            im_feats.append(out.cpu().numpy())

    return np.concatenate(im_ids), np.concatenate(im_feats)