Spaces:
Configuration error
Configuration error
| """ | |
| # Copyright (c) 2022, salesforce.com, inc. | |
| # All rights reserved. | |
| # SPDX-License-Identifier: BSD-3-Clause | |
| # For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| from PIL import Image | |
| import requests | |
| import torch | |
| import os | |
| from lavis.common.registry import registry | |
| from lavis.processors import * | |
| from lavis.models import * | |
| from lavis.common.utils import build_default_model | |
# Select the compute device: CUDA when PyTorch reports it as available,
# otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
def load_demo_image():
    """Download the BLIP demo picture and return it as an RGB PIL image.

    Returns:
        The decoded demo image, converted to mode "RGB".

    Raises:
        requests.HTTPError: if the server responds with an error status.
        requests.RequestException: on connection problems or when the
            request exceeds the timeout.
    """
    img_url = (
        "https://storage.googleapis.com/sfr-vision-language-research/BLIP/demo.jpg"
    )
    # A timeout keeps the script from hanging forever on a stalled connection;
    # the context manager releases the streamed connection when we are done.
    with requests.get(img_url, stream=True, timeout=30) as response:
        # Fail here on HTTP errors instead of letting PIL choke on an error
        # page with an unrelated "cannot identify image" message.
        response.raise_for_status()
        # convert() fully decodes the image, so all bytes are read from the
        # stream before the response is closed.
        raw_image = Image.open(response.raw).convert("RGB")
    return raw_image
def read_img(filepath):
    """Open the image stored at ``filepath`` and return it converted to RGB."""
    return Image.open(filepath).convert("RGB")
# Model: a BLIP feature extractor built with the checkpoint URL below passed
# as its ``pretrained`` argument.
model_url = "https://storage.googleapis.com/sfr-vision-language-research/BLIP/models/model_base.pth"
feature_extractor = BlipFeatureExtractor(pretrained=model_url)
# Place the model on the selected device, then call eval() on it
# (inference only; this script never trains).
feature_extractor = feature_extractor.to(device)
feature_extractor.eval()

# Preprocessors: captions use the default caption processor; images use the
# eval-time image processor configured with image_size=224.
text_processor = BlipCaptionProcessor()
vis_processor = BlipImageEvalProcessor(image_size=224)
# Directory whose images are processed. To process the validation split,
# point this at "/export/home/.cache/lavis/coco/images/val2014" instead.
file_root = "/export/home/.cache/lavis/coco/images/train2014"

# NOTE: despite the name, these are bare directory entries (file names);
# they are joined with ``file_root`` when each image is read.
filepaths = os.listdir(file_root)
print(len(filepaths))

# Number of images sent through the model per forward pass.
bsz = 256
# Placeholder text passed to the feature extractor alongside the images.
caption = "dummy"

# Accumulators: the batch currently being assembled, and the final mapping
# from image file name to its extracted feature.
images_in_batch, filepaths_in_batch = [], []
path2feat = {}
def _flush_batch(batch_images, batch_filepaths):
    """Run one batch through the feature extractor and store the results.

    Args:
        batch_images: non-empty list of preprocessed image tensors, each
            carrying a leading batch dimension of size 1.
        batch_filepaths: file paths aligned index-for-index with
            ``batch_images``.

    Side effects:
        Adds one entry per image to the module-level ``path2feat`` dict,
        keyed by the file's base name, and prints a progress line.
    """
    batch = torch.cat(batch_images, dim=0).to(device)
    with torch.no_grad():
        # Keep only index 0 along the second dimension of the model output
        # (presumably the [CLS] token embedding -- TODO confirm).
        image_features = feature_extractor(
            batch, caption, mode="image", normalized=True
        )[:, 0]
    for filepath, image_feat in zip(batch_filepaths, image_features):
        path2feat[os.path.basename(filepath)] = image_feat.detach().cpu()
    print(len(path2feat), image_features.shape)


# Load every file, then flush whenever the batch is full. Loading first and
# flushing afterwards ensures the file that lands on a batch boundary is still
# processed (previously every ``bsz``-th file was skipped).
for filename in filepaths:
    filepath = os.path.join(file_root, filename)
    image = vis_processor(read_img(filepath)).unsqueeze(0)
    images_in_batch.append(image)
    filepaths_in_batch.append(filepath)

    if len(images_in_batch) == bsz:
        _flush_batch(images_in_batch, filepaths_in_batch)
        images_in_batch = []
        filepaths_in_batch = []

# Process the trailing partial batch so the last ``len(filepaths) % bsz``
# images are not missing from the saved features.
if images_in_batch:
    _flush_batch(images_in_batch, filepaths_in_batch)
    images_in_batch = []
    filepaths_in_batch = []

torch.save(path2feat, "path2feat_coco_train2014.pth")