Spaces:
Running
Running
| import sys | |
| from pathlib import Path | |
| sys.path.append(str(Path(__file__).resolve().parent.parent.parent)) | |
| import argparse | |
| import json | |
| from collections import UserDict | |
| from pathlib import Path | |
| import numpy as np | |
| import torch | |
| import webdataset as wds | |
| from PIL import Image | |
| from torchvision import transforms | |
| from tqdm import tqdm | |
| from webdataset.autodecode import ImageHandler | |
| from utils.image_processing import CenterCrop | |
print("Loading dinov2")

# DINOv2 preprocessing: square center crop, bicubic resize to 336 px,
# then normalization with the standard ImageNet statistics.
_IMAGENET_MEAN = (0.485, 0.456, 0.406)
_IMAGENET_STD = (0.229, 0.224, 0.225)
augmentation_dinov2 = transforms.Compose(
    [
        CenterCrop(ratio="1:1"),
        transforms.Resize(336, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=_IMAGENET_MEAN, std=_IMAGENET_STD),
    ]
)

device = "cuda" if torch.cuda.is_available() else "cpu"
# Fetch ViT-L/14 with registers from the official DINOv2 hub repo;
# inference only, so switch straight to eval mode on the target device.
dinov2_model = torch.hub.load("facebookresearch/dinov2", "dinov2_vitl14_reg")
dinov2_model.eval()
dinov2_model.to(device)
print(f"Model loaded on {device}")
def dict_collate(batch):
    """Recursively collate a list of samples.

    Dicts are collated key-by-key; "json" payloads and PIL images are kept
    as plain Python lists; everything else falls through to the default
    torch collate.
    """
    first = batch[0]
    if isinstance(first, dict):
        collated = {}
        for key in first.keys():
            values = [sample[key] for sample in batch]
            # "json" entries stay as an uncollated list of objects.
            collated[key] = values if key == "json" else dict_collate(values)
        return collated
    if isinstance(first, Image.Image):
        # PIL images cannot be stacked into a tensor; pass them through.
        return list(batch)
    return torch.utils.data.dataloader.default_collate(batch)
def log_and_continue(exn):
    """Webdataset error handler: ignore the exception and keep iterating.

    Returning True tells the pipeline to skip the offending sample rather
    than abort the whole run.
    """
    # Logging intentionally disabled; re-enable if shard errors need auditing:
    # logging.warning(f"Handling webdataset error ({repr(exn)}). Ignoring.")
    return True
# NOTE(review): re-binds `device` (assigned above as a plain "cuda"/"cpu"
# string) to a torch.device object; both forms are accepted by Module.to /
# Tensor.to, so behavior is unchanged — consider keeping only one assignment.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def add_clip_scores_and_embeddings(src, dest, batch_size=512):
    """Copy a webdataset tar shard, adding a DINOv2 ViT-L/14 (registers)
    embedding to every sample.

    Args:
        src: path to the input ``.tar`` shard.
        dest: path for the output ``.tar`` shard.
        batch_size: number of images embedded per forward pass.
    """
    dataset = wds.DataPipeline(
        wds.SimpleShardList(str(src)),
        wds.split_by_worker,
        wds.tarfile_to_samples(),
        wds.rename(
            __key__="__key__",
            dino_image="jpg",
            image="jpg",
            street_clip="street_clip.npy",
            json="json",
        ),
        # Decode only the copy fed to the model; the raw jpg bytes pass
        # through untouched so we avoid a lossy re-encode on write.
        wds.decode(ImageHandler("pilrgb", ["dino_image"])),
        wds.map_dict(
            dino_image=augmentation_dinov2,
            image=lambda x: x,
            street_clip=lambda x: x,
            json=lambda x: x,
        ),
        wds.to_tuple(
            "__key__",
            "dino_image",
            "street_clip",
            "image",
            "json",
        ),
        wds.batched(batch_size),
    )
    # Batching is done inside the pipeline, so the loader must not re-batch.
    loader = wds.WebLoader(dataset, num_workers=8, batch_size=None)
    with wds.TarWriter(str(dest)) as sink:
        # NOTE(review): `total` is a rough progress estimate (assumes ~10k
        # samples per shard) — confirm against the actual shard size.
        for batch in tqdm(loader, total=10000 // batch_size):
            # `json_meta` renamed from `json`, which shadowed the imported
            # json module inside this function.
            keys, dino_image, street_clip, image, json_meta = batch
            dino_image = dino_image.to(device)
            with torch.no_grad():
                dino_embedding = dinov2_model(dino_image).cpu().numpy()
            for i in range(len(keys)):
                sink.write(
                    {
                        "__key__": keys[i],
                        "jpg": image[i],
                        "street_clip.npy": street_clip[i],
                        "json": json_meta[i],
                        "dinov2_vitl14_registers.npy": dino_embedding[i],
                    }
                )
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--src", help="path to source files")
    parser.add_argument("--dest", help="path to destination files")
    # Parse as int up front instead of deferring the int() conversion.
    parser.add_argument("--shard_id", type=int, help="shard id")
    args = parser.parse_args()

    src = Path(args.src)
    # Sort so the shard_id -> file mapping is deterministic across runs.
    list_of_shards = sorted(src.glob("*.tar"))
    # Path.name is the portable equivalent of str(path).split("/")[-1]
    # (the old form broke on Windows path separators).
    shard = list_of_shards[args.shard_id].name

    dest = Path(args.dest)
    dest.mkdir(exist_ok=True, parents=True)

    batch_size = 256
    print(f"Loading {shard}")
    src_shard = src / shard
    print(f"Processing {src_shard} to {dest / shard}")
    add_clip_scores_and_embeddings(src_shard, dest / shard, batch_size)