| import argparse |
| import pickle |
| import os |
|
|
| import torch |
| from torch import nn |
| from torch.nn import functional as F |
| from torch.utils.data import DataLoader |
| from torchvision import transforms |
| from torchvision.models import inception_v3, Inception3 |
| import numpy as np |
| from tqdm import tqdm |
|
|
| from inception import InceptionV3 |
| from torchvision.datasets import ImageFolder |
|
|
class Inception3Feature(Inception3):
    """Torchvision ``Inception3`` whose forward pass returns pooled features.

    ``forward`` runs the convolutional trunk up to and including a final
    average pool and returns the result. No layers after that pool are
    called.
    """

    def forward(self, x):
        """Return a ``(batch, channels)`` feature tensor for the image batch ``x``.

        Inputs whose spatial size (dims 2 and 3) is not 299x299 are first
        bilinearly resized to 299x299.

        NOTE(review): unlike torchvision's own ``Inception3.forward``, this
        override never applies ``transform_input``. Callers must therefore
        pass ``x`` already in whatever range the loaded weights expect —
        confirm with callers.
        """
        height, width = x.shape[2], x.shape[3]
        if (height, width) != (299, 299):
            x = F.interpolate(x, size=(299, 299), mode='bilinear', align_corners=True)

        # First stem stage: three convolutions, then a stride-2 max pool.
        for layer in (self.Conv2d_1a_3x3, self.Conv2d_2a_3x3, self.Conv2d_2b_3x3):
            x = layer(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Second stem stage: two convolutions, then another stride-2 max pool.
        for layer in (self.Conv2d_3b_1x1, self.Conv2d_4a_3x3):
            x = layer(x)
        x = F.max_pool2d(x, kernel_size=3, stride=2)

        # Mixed (inception) blocks; order matters, each feeds the next.
        mixed_blocks = (
            self.Mixed_5b, self.Mixed_5c, self.Mixed_5d,
            self.Mixed_6a, self.Mixed_6b, self.Mixed_6c, self.Mixed_6d, self.Mixed_6e,
            self.Mixed_7a, self.Mixed_7b, self.Mixed_7c,
        )
        for block in mixed_blocks:
            x = block(x)

        # kernel_size=8 pool; the view below relies on this leaving a 1x1
        # spatial map so that only (batch, channels) remain.
        x = F.avg_pool2d(x, kernel_size=8)
        batch, channels = x.shape[0], x.shape[1]
        return x.view(batch, channels)
|
|
|
|
def load_patched_inception_v3():
    """Build and return the Inception feature extractor used by this script.

    Constructs the project-local ``InceptionV3`` (imported from ``inception``)
    with output block index 3 and ``normalize_input=False``. Despite the name,
    no patching of the model happens here.

    The ``__main__`` section maps images to [-1, 1] with
    ``transforms.Normalize([0.5]*3, [0.5]*3)`` after ``ToTensor``. That is
    presumably why input normalization is disabled here.

    NOTE(review): block index 3 is presumably the final pooled-feature block
    of ``InceptionV3`` — confirm against ``inception.py``.
    """
    output_blocks = [3]
    model = InceptionV3(output_blocks, normalize_input=False)
    return model
|
|
|
|
@torch.no_grad()
def extract_features(loader, inception, device):
    """Run every batch from ``loader`` through ``inception`` and stack the outputs.

    Gradient tracking is disabled for the whole call.

    Args:
        loader: iterable yielding ``(images, labels)`` pairs. Labels are ignored.
            A tqdm progress bar is shown while iterating.
        inception: model whose output is indexable. Element 0 is taken as the
            per-batch feature tensor.
        device: device each image batch is moved to before the forward pass.

    Returns:
        A CPU tensor with one flattened feature row per image, concatenated
        along dim 0 in loader order. An empty ``loader`` makes ``torch.cat``
        raise, because there is nothing to concatenate.
    """
    per_batch = [
        inception(images.to(device))[0].view(images.shape[0], -1).cpu()
        for images, _labels in tqdm(loader)
    ]
    return torch.cat(per_batch, 0)
|
|
|
|
if __name__ == '__main__':
    # Script entry point.
    # 1. Load the first --n_sample images of an ImageFolder dataset.
    # 2. Extract Inception features for them.
    # 3. Pickle the feature mean and covariance to inception_<name>.pkl in the
    #    current directory.
    from torch.utils.data import Subset  # only needed by this entry point

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    parser = argparse.ArgumentParser(
        description='Calculate Inception v3 features for datasets'
    )
    parser.add_argument('--size', type=int, default=256)
    parser.add_argument('--batch', default=64, type=int, help='batch size')
    parser.add_argument('--n_sample', type=int, default=50000)
    parser.add_argument('--flip', action='store_true')
    parser.add_argument(
        'path', metavar='PATH', help='path to dataset directory (ImageFolder layout)'
    )

    args = parser.parse_args()

    # np.cov needs at least two observations. Smaller values used to produce
    # NaN/empty statistics silently instead of failing.
    if args.n_sample < 2:
        parser.error('--n_sample must be at least 2')

    inception = load_patched_inception_v3().eval().to(device)

    # ToTensor gives [0, 1]; Normalize(0.5, 0.5) then maps images to [-1, 1].
    transform = transforms.Compose(
        [
            transforms.Resize((args.size, args.size)),
            transforms.RandomHorizontalFlip(p=0.5 if args.flip else 0),
            transforms.ToTensor(),
            transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
        ]
    )

    dset = ImageFolder(args.path, transform)
    # The loader is not shuffled, so the first n_sample indices are exactly the
    # images the old "extract everything, then slice" approach kept. The rest
    # of the dataset is never forwarded through the network.
    dset = Subset(dset, range(min(args.n_sample, len(dset))))
    loader = DataLoader(dset, batch_size=args.batch, num_workers=4)

    features = extract_features(loader, inception, device).numpy()

    print(f'extracted {features.shape[0]} features')

    mean = np.mean(features, 0)
    cov = np.cov(features, rowvar=False)

    # normpath strips a trailing separator. Otherwise 'data/ffhq/' has an
    # empty basename and the output would be written to 'inception_.pkl'.
    name = os.path.splitext(os.path.basename(os.path.normpath(args.path)))[0]

    print({'mean': mean.mean(), 'cov': cov.mean()})
    with open(f'inception_{name}.pkl', 'wb') as f:
        pickle.dump({'mean': mean, 'cov': cov, 'size': args.size, 'path': args.path}, f)
|
|