import glob
import os

import numpy as np
import open_clip
import PIL.Image
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm


class SingleFolderDataset(torch.utils.data.Dataset):
    """Yields (transformed image, filename) pairs for every file in one directory."""

    def __init__(self, folder_path, transform=None):
        self.folder_path = folder_path
        self.transform = transform
        # Sort for a deterministic ordering across runs.
        self.image_paths = sorted(glob.glob(os.path.join(folder_path, "*")))
        print(f"Found {len(self.image_paths)} images in {folder_path}")

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, index):
        image_path = self.image_paths[index]
        # Convert to RGB so grayscale/RGBA files pass through the CLIP transform.
        image = PIL.Image.open(image_path).convert("RGB")
        if self.transform:
            image = self.transform(image)
        return image, os.path.basename(image_path)


def extract_feats(index_config):
    ai_config = index_config["ai_config"]  # open_clip model architecture name
    weight_path = index_config["weight_path"]
    img_dir = index_config["img_dir"]
    batch_size = 1024

    model, _, transform = open_clip.create_model_and_transforms(ai_config, pretrained=weight_path)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    print("Using device:", device)
    model = model.to(device)
    model.eval()  # inference mode: disable dropout and other training-only behavior

    dataset = SingleFolderDataset(img_dir, transform=transform)
    dl = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4)

    im_ids = []
    im_feats = []
    for patched_tensor, img_id in tqdm(dl, total=len(dl)):
        patched_tensor = patched_tensor.to(device)
        with torch.no_grad():
            out = model.encode_image(patched_tensor)
        # img_id is the batch of filenames collated by the DataLoader.
        im_ids.append(img_id)
        im_feats.append(out.cpu().numpy())

    im_hashes = np.concatenate(im_ids)
    im_feats = np.concatenate(im_feats)
    return im_hashes, im_feats
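

# Usage sketch: the config values below are illustrative assumptions
# (architecture name, checkpoint path, image folder), not values taken
# from the original file.
if __name__ == "__main__":
    index_config = {
        "ai_config": "ViT-B-32",               # an open_clip architecture name (assumed)
        "weight_path": "/path/to/weights.pt",  # placeholder local checkpoint path
        "img_dir": "/path/to/images",          # placeholder image directory
    }
    im_hashes, im_feats = extract_feats(index_config)
    # Persist the parallel arrays of filenames and features for later indexing.
    np.save("im_hashes.npy", im_hashes)
    np.save("im_feats.npy", im_feats)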