Spaces:
Running
Running
import os
from argparse import ArgumentParser

import torch
from jax import numpy as jnp
from torchvision import datasets, transforms
from torchvision.transforms import CenterCrop, Normalize, Resize, ToTensor
from torchvision.transforms.functional import InterpolationMode
from tqdm import tqdm
from transformers import AutoTokenizer

import utils
from modeling_hybrid_clip import FlaxHybridCLIP
def parse_args(argv=None):
    """Parse the command-line arguments of the feature pre-computation script.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is what the script entry point uses.

    Returns:
        An ``argparse.Namespace`` with ``in_dir``, ``out_file``,
        ``batch_size`` and ``num_workers``.
    """
    parser = ArgumentParser(
        description="Pre-compute image features for a directory of images."
    )
    parser.add_argument("in_dir", help="directory handed to utils.CustomDataSet")
    parser.add_argument(
        "out_file", help="file name to write under static/features/"
    )
    # The two options below default to the values that used to be hard-coded,
    # so existing invocations behave exactly as before.
    parser.add_argument(
        "--batch-size", type=int, default=256, help="DataLoader batch size"
    )
    parser.add_argument(
        "--num-workers",
        type=int,
        default=16,
        help="number of DataLoader worker processes",
    )
    return parser.parse_args(argv)


def main(argv=None):
    """Compute image features for ``in_dir`` and save them to disk.

    Loads the ``clip-italian/clip-italian`` checkpoint, builds the image
    preprocessing pipeline from the model's configured image size, runs
    ``utils.precompute_image_features`` over the dataset and stores the result
    with ``jnp.save`` at ``static/features/<out_file>``.

    Args:
        argv: Optional argument list forwarded to :func:`parse_args`.
    """
    args = parse_args(argv)

    out_path = f"static/features/{args.out_file}"
    # Create the destination before the expensive work so a missing directory
    # cannot make the final save fail after all features were computed.
    os.makedirs(os.path.dirname(out_path), exist_ok=True)

    model = FlaxHybridCLIP.from_pretrained("clip-italian/clip-italian")
    # NOTE: the tokenizer that used to be loaded here was never used by this
    # script (only images are processed), so it is no longer instantiated.

    image_size = model.config.vision_config.image_size
    val_preprocess = transforms.Compose(
        [
            Resize([image_size], interpolation=InterpolationMode.BICUBIC),
            CenterCrop(image_size),
            ToTensor(),
            # Per-channel mean and std handed to Normalize; presumably the
            # statistics the vision model was trained with -- confirm before
            # changing them.
            Normalize(
                (0.48145466, 0.4578275, 0.40821073),
                (0.26862954, 0.26130258, 0.27577711),
            ),
        ]
    )

    dataset = utils.CustomDataSet(args.in_dir, transform=val_preprocess)
    # shuffle=False and drop_last=False: every sample is visited once, in
    # dataset order. Presumably the saved features are meant to line up with
    # the dataset's image order -- verify against precompute_image_features.
    loader = torch.utils.data.DataLoader(
        dataset,
        batch_size=args.batch_size,
        shuffle=False,
        num_workers=args.num_workers,
        drop_last=False,
    )

    image_features = utils.precompute_image_features(model, loader)
    jnp.save(out_path, image_features)


if __name__ == "__main__":
    main()