| |
| |
| |
| |
| |
| |
|
|
| import os |
| import json |
| import random |
| import torch |
| import glob |
| from collections import defaultdict, Counter |
| from torchvision import transforms |
| from torchvision.datasets.folder import default_loader |
| from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_INCEPTION_MEAN, IMAGENET_INCEPTION_STD |
| from timm.data.transforms import RandomResizedCropAndInterpolation |
| from timm.data import create_transform |
|
|
| import utils |
| from glossary import normalize_word |
| from randaug import RandomAugment |
|
|
|
|
| class BaseDataset(torch.utils.data.Dataset): |
| def __init__( |
| self, data_path, split, transform, |
| tokenizer, num_max_bpe_tokens, task=None, |
| ): |
| index_files = self.get_index_files(split, task=task) |
| self.tokenizer = tokenizer |
| self.num_max_bpe_tokens = num_max_bpe_tokens |
| self.data_path = data_path |
| items = [] |
| self.index_files = index_files |
|
|
| offset = 0 |
| for _index_file in index_files: |
| index_file = os.path.join(data_path, _index_file) |
| with open(index_file, mode="r", encoding="utf-8") as reader: |
| for line in reader: |
| data = json.loads(line) |
| items.append(data) |
| print("Load %d image-text pairs from %s. " % (len(items) - offset, index_file)) |
| offset = len(items) |
| self.items = items |
| self.bos_token_id = tokenizer.bos_token_id |
| self.eos_token_id = tokenizer.eos_token_id |
| self.pad_token_id = tokenizer.pad_token_id |
| self.loader = default_loader |
| self.transform = transform |
| self.split = split |
|
|
| @staticmethod |
| def get_index_files(split): |
| raise NotImplementedError() |
|
|
| def _get_image(self, image_path: str): |
| image_path = os.path.join(self.data_path, image_path) |
| image = self.loader(image_path) |
| return self.transform(image) |
|
|
| def _get_text_segment(self, text_segment, max_len=None): |
| if isinstance(text_segment, str): |
| tokens = self.tokenizer.tokenize(text_segment) |
| else: |
| tokens = text_segment[:] |
| if len(tokens) == 0: |
| raise RuntimeError("The text segment should contains at least one tokens!") |
| if max_len is None: |
| max_len = self.num_max_bpe_tokens |
|
|
| if len(tokens) > max_len - 2: |
| tokens = tokens[:max_len - 2] |
|
|
| tokens = [self.bos_token_id] + tokens[:] + [self.eos_token_id] |
| num_tokens = len(tokens) |
| padding_mask = [0] * num_tokens + [1] * (max_len - num_tokens) |
| return tokens + [self.pad_token_id] * (max_len - num_tokens), padding_mask, num_tokens |
|
|
| def _get_image_text_example(self, index: int, data: dict): |
| item = self.items[index] |
| img_path = item["image_path"] |
| img = self._get_image(img_path) |
| data["image"] = img |
|
|
| text_segment = item["text_segment"] |
| language_tokens, padding_mask, _ = self._get_text_segment(text_segment) |
| data["language_tokens"] = language_tokens |
| data["padding_mask"] = padding_mask |
|
|
| def __getitem__(self, index: int): |
| data = dict() |
| self._get_image_text_example(index, data) |
| return data |
|
|
| def __len__(self) -> int: |
| return len(self.items) |
|
|
| def __repr__(self) -> str: |
| head = "Dataset " + self.__class__.__name__ |
| body = '{' + "\n Number of items: %s," % self.__len__() |
| body += "\n data root = %s," % self.data_path |
| body += "\n split = %s," % self.split |
| body += "\n dataset index files = %s" % str(self.index_files) |
| body += "\n num max bpe tokens = %s" % self.num_max_bpe_tokens |
| body += "\n transforms = [" |
| for t in self.transform.transforms: |
| body += "\n %s" % str(t) |
| body += "\n ]" |
| body += "\n}" |
|
|
| return head + body |
|
|
|
|
| def _write_data_into_jsonl(items, jsonl_file): |
| with open(jsonl_file, mode="w", encoding="utf-8") as writer: |
| for data in items: |
| writer.write(json.dumps(data, indent=None)) |
| writer.write('\n') |
| print("Write %s with %d items !" % (jsonl_file, len(items))) |
|
|
|
|
| def _make_retrieval_coco_karpathy_dataset_index( |
| data_path, |
| tokenizer, |
| split=("train", "restval"), |
| split_name="train", |
| ): |
| coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json") |
| items = [] |
| image_counter = set() |
| print("read %s" % coco_karpathy_split_json_file) |
| with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader: |
| data = json.loads(reader.read()) |
| for item in data["images"]: |
| if item["split"] in split: |
| image_path = os.path.join(item["filepath"], item["filename"]) |
| for sent in item["sentences"]: |
| tokens = tokenizer.tokenize(sent["raw"]) |
| token_ids = tokenizer.convert_tokens_to_ids(tokens) |
| items.append({ |
| "image_path": image_path, |
| "text_segment": token_ids, |
| "image_id": len(image_counter), |
| }) |
| if image_path not in image_counter: |
| image_counter.add(image_path) |
| print("Find %d images and %d image-text pairs for karpathy dataset %s split !" % \ |
| (len(image_counter), len(items), split_name)) |
| index_file = os.path.join(data_path, "coco_retrieval.%s.jsonl" % split_name) |
| _write_data_into_jsonl(items, index_file) |
| pass |
|
|
|
|
| def _make_captioning_coco_karpathy_dataset_index( |
| data_path, |
| tokenizer, |
| split=("train", "restval"), |
| split_name="train", |
| ): |
| coco_karpathy_split_json_file = os.path.join(data_path, "dataset_coco.json") |
| items = [] |
| image_counter = set() |
| print("read %s" % coco_karpathy_split_json_file) |
| with open(coco_karpathy_split_json_file, mode="r", encoding="utf-8") as reader: |
| data = json.loads(reader.read()) |
| for item in data["images"]: |
| if item["split"] in split: |
| image_path = os.path.join(item["filepath"], item["filename"]) |
| if item["split"] in ["train", "restval"]: |
| for sent in item["sentences"]: |
| tokens = tokenizer.tokenize(sent["raw"]) |
| token_ids = tokenizer.convert_tokens_to_ids(tokens) |
| items.append({ |
| "image_path": image_path, |
| "text_segment": token_ids, |
| "image_id": item["cocoid"], |
| }) |
| else: |
| items.append({ |
| "image_path": image_path, |
| "text_segment": None, |
| "image_id": item["cocoid"], |
| }) |
| if image_path not in image_counter: |
| image_counter.add(image_path) |
| print("Find %d images and %d image-text pairs for karpathy dataset %s split !" % \ |
| (len(image_counter), len(items), split_name)) |
| index_file = os.path.join(data_path, "coco_captioning.%s.jsonl" % split_name) |
| _write_data_into_jsonl(items, index_file) |
| pass |
|
|
|
|
| def _make_nocaps_dataset_index( |
| data_path, |
| split="val", |
| ): |
| if split == "val": |
| json_file = "nocaps_val_4500_captions.json" |
| elif split == "test": |
| json_file = "nocaps_test_image_info.json" |
| nocaps_split_json_file = os.path.join(data_path, json_file) |
| items = [] |
| image_counter = set() |
| print("read %s" % nocaps_split_json_file) |
| with open(nocaps_split_json_file, mode="r", encoding="utf-8") as reader: |
| data = json.loads(reader.read()) |
| for item in data["images"]: |
| image_path = os.path.join(split, item["file_name"]) |
| items.append({ |
| "image_path": image_path, |
| "text_segment": None, |
| "image_id": item["id"], |
| }) |
|
|
| if image_path not in image_counter: |
| image_counter.add(image_path) |
|
|
| print("Find %d images and %d image-text pairs for nocaps dataset %s split !" % \ |
| (len(image_counter), len(items), split)) |
| index_file = os.path.join(data_path, "nocaps.%s.jsonl" % split) |
| _write_data_into_jsonl(items, index_file) |
|
|
|
|
| class NLVR2Dataset(BaseDataset): |
| @staticmethod |
| def get_index_files(split, task=None): |
| if split == "train": |
| return ("nlvr2.train.index.jsonl", ) |
| elif split == "val": |
| return ("nlvr2.dev.index.jsonl", ) |
| elif split == "test": |
| return ("nlvr2.test-P.index.jsonl", ) |
| else: |
| raise RuntimeError("split %s is not found!" % split) |
|
|
| def __getitem__(self, index: int): |
| data = super().__getitem__(index) |
| item = self.items[index] |
| img_path = item["image2_path"] |
| img = self._get_image(img_path) |
| data["image2"] = img |
| data["label"] = self.items[index]["label"] |
| return data |
|
|
| @staticmethod |
| def __preprocess_json(preifx, json_file, tokenizer, index_file): |
| items = [] |
| with open(json_file, mode="r", encoding="utf-8") as reader: |
| for line in reader: |
| data = json.loads(line) |
| path = os.path.join(preifx, str(data["directory"])) if "directory" in data else preifx |
| path = os.path.join(path, "-".join(data["identifier"].split("-")[:-1])) |
| tokens = tokenizer.tokenize(data["sentence"]) |
| token_ids = tokenizer.convert_tokens_to_ids(tokens) |
| items.append({ |
| "image_path": path + "-img0.png", |
| "image2_path": path + "-img1.png", |
| "text_segment": token_ids, |
| "label": 1 if data["label"] == "True" else 0, |
| "identifier": data["identifier"], |
| }) |
| _write_data_into_jsonl(items, index_file) |
|
|
| @classmethod |
| def make_dataset_index(cls, data_path, tokenizer, nlvr_repo_path): |
| cls.__preprocess_json( |
| preifx="images/train", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/train.json"), |
| tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("train")[0]), |
| ) |
| cls.__preprocess_json( |
| preifx="dev", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/dev.json"), |
| tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("val")[0]), |
| ) |
| cls.__preprocess_json( |
| preifx="test1", json_file=os.path.join(nlvr_repo_path, "nlvr2/data/test1.json"), |
| tokenizer=tokenizer, index_file=os.path.join(data_path, cls.get_index_files("test")[0]), |
| ) |
|
|
|
|
| class ImageNetDataset(BaseDataset): |
| @staticmethod |
| def get_index_files(split, task=None): |
| if split == "train": |
| return ("imagenet.train.index.jsonl", ) |
| elif split == "val": |
| return ("imagenet.val.index.jsonl", ) |
| elif split == "test": |
| return ("imagenet.val.index.jsonl", ) |
| else: |
| raise RuntimeError("split %s is not found!" % split) |
|
|
| def __getitem__(self, index: int): |
| data = dict() |
| item = self.items[index] |
| img_path = item["image_path"] |
| img = self._get_image(img_path) |
| data["image"] = img |
| data["label"] = item["label"] |
| return data |
| |
| @staticmethod |
| def _find_classes(dir): |
| """ |
| Finds the class folders in a dataset. |
| Args: |
| dir (string): Root directory path. |
| Returns: |
| tuple: (classes, class_to_idx) where classes are relative to (dir), and class_to_idx is a dictionary. |
| Ensures: |
| No class is a subdirectory of another. |
| """ |
| classes = [d.name for d in os.scandir(dir) if d.is_dir()] |
| classes.sort() |
| class_to_idx = {cls_name: i for i, cls_name in enumerate(classes)} |
| return classes, class_to_idx |
|
|
| @staticmethod |
| def _make_imagenet_index(data_path, index_path, data_path_prefix, class_to_idx, split): |
| items = [] |
| index_file = os.path.join(index_path, f"imagenet.{split}.index.jsonl") |
| for target_class in sorted(class_to_idx.keys()): |
| class_index = class_to_idx[target_class] |
| target_dir = os.path.join(data_path, target_class) |
| if not os.path.isdir(target_dir): |
| continue |
| for root, _, fnames in sorted(os.walk(target_dir, followlinks=True)): |
| for fname in sorted(fnames): |
| path = os.path.join(root, fname) |
| path = path.replace(data_path_prefix, "") |
| items.append({ |
| "image_path": path, |
| "label": class_index, |
| }) |
|
|
| _write_data_into_jsonl(items, index_file) |
|
|
| @classmethod |
| def make_dataset_index(cls, train_data_path, val_data_path, index_path): |
| data_path_prefix = train_data_path[:[x[0]==x[1] for x in zip(train_data_path, val_data_path)].index(0)] |
| classes, class_to_idx = cls._find_classes(train_data_path) |
| cls._make_imagenet_index( |
| data_path=train_data_path, index_path=index_path, data_path_prefix=data_path_prefix, |
| class_to_idx=class_to_idx, split="train", |
| ) |
| cls._make_imagenet_index( |
| data_path=val_data_path, index_path=index_path, data_path_prefix=data_path_prefix, |
| class_to_idx=class_to_idx, split="val", |
| ) |
|
|
|
|
| class VQAv2Dataset(BaseDataset): |
| def __init__(self, data_path, **kwargs): |
| super().__init__(data_path=data_path, **kwargs) |
| ans2label_file = os.path.join(data_path, "answer2label.txt") |
| ans2label = {} |
| label2ans = [] |
| with open(ans2label_file, mode="r", encoding="utf-8") as reader: |
| for i, line in enumerate(reader): |
| data = json.loads(line) |
| ans = data["answer"] |
| label = data["label"] |
| label = int(label) |
| assert label == i |
| ans2label[ans] = i |
| label2ans.append(ans) |
| |
| self.ans2label = ans2label |
| self.label2ans = label2ans |
|
|
| @staticmethod |
| def get_index_files(split, task=None): |
| if split == "train": |
| return ("vqa.train.jsonl", "vqa.trainable_val.jsonl") |
| elif split == "val": |
| return ("vqa.rest_val.jsonl", ) |
| elif split == "test": |
| return ("vqa.test.jsonl", ) |
| elif split == "test-dev": |
| return ("vqa.test-dev.jsonl", ) |
| else: |
| raise RuntimeError("split %s is not found!" % split) |
|
|
| def __getitem__(self, index: int): |
| data = super().__getitem__(index) |
| if "labels" in self.items[index] and len(self.items[index]["labels"]) > 0: |
| labels = [0.] * len(self.label2ans) |
| for l, s in zip(self.items[index]["labels"], self.items[index]["scores"]): |
| labels[l] = s |
| data["labels"] = torch.FloatTensor(labels) |
| else: |
| data["qid"] = self.items[index]["qid"] |
| return data |
|
|
| @staticmethod |
| def get_score(occurences): |
| if occurences == 0: |
| return 0.0 |
| elif occurences == 1: |
| return 0.3 |
| elif occurences == 2: |
| return 0.6 |
| elif occurences == 3: |
| return 0.9 |
| else: |
| return 1.0 |
|
|
| @classmethod |
| def make_dataset_index(cls, data_path, tokenizer, annotation_data_path): |
| with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_train2014_questions.json"), "r") as fp: |
| questions_train2014 = json.load(fp)["questions"] |
| with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_val2014_questions.json"), "r") as fp: |
| questions_val2014 = json.load(fp)["questions"] |
| with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test2015_questions.json"), "r") as fp: |
| questions_test2015 = json.load(fp)["questions"] |
| with open(os.path.join(annotation_data_path, "v2_OpenEnded_mscoco_test-dev2015_questions.json"), "r") as fp: |
| questions_test_dev2015 = json.load(fp)["questions"] |
|
|
| with open(os.path.join(annotation_data_path, "v2_mscoco_train2014_annotations.json"), "r") as fp: |
| annotations_train2014 = json.load(fp)["annotations"] |
| with open(os.path.join(annotation_data_path, "v2_mscoco_val2014_annotations.json"), "r") as fp: |
| annotations_val2014 = json.load(fp)["annotations"] |
|
|
| annotations = dict() |
|
|
| for split, questions in zip( |
| ["train", "val", "test", "test-dev"], |
| [questions_train2014, questions_val2014, questions_test2015, questions_test_dev2015], |
| ): |
| _annot = defaultdict(dict) |
| for q in questions: |
| question_text = q["question"] |
| tokens = tokenizer.tokenize(question_text) |
| token_ids = tokenizer.convert_tokens_to_ids(tokens) |
|
|
| assert q["question_id"] not in _annot[q["image_id"]] |
| _annot[q["image_id"]][q["question_id"]] = { |
| "question": question_text, |
| "token_ids": token_ids, |
| } |
|
|
| annotations[split] = _annot |
|
|
| all_major_answers = list() |
|
|
| for split, annots in zip( |
| ["train", "val"], [annotations_train2014, annotations_val2014], |
| ): |
| |
| for q in annots: |
| all_major_answers.append(q["multiple_choice_answer"]) |
|
|
| all_major_answers = [normalize_word(word) for word in all_major_answers] |
| counter = {k: v for k, v in Counter(all_major_answers).items() if v >= 9} |
| ans2label = {k: i for i, k in enumerate(counter.keys())} |
| label2ans = list(counter.keys()) |
|
|
| for split, annots in zip( |
| ["train", "val"], [annotations_train2014, annotations_val2014], |
| ): |
| _annot = annotations[split] |
| for q in annots: |
| answers = q["answers"] |
| answer_count = {} |
| for answer in answers: |
| answer_ = answer["answer"] |
| answer_count[answer_] = answer_count.get(answer_, 0) + 1 |
|
|
| labels = [] |
| scores = [] |
| for answer in answer_count: |
| if answer not in ans2label: |
| continue |
| labels.append(ans2label[answer]) |
| score = cls.get_score(answer_count[answer]) |
| scores.append(score) |
|
|
| assert "labels" not in _annot[q["image_id"]][q["question_id"]] |
| assert "question" in _annot[q["image_id"]][q["question_id"]] |
| _annot[q["image_id"]][q["question_id"]]["labels"] = labels |
| _annot[q["image_id"]][q["question_id"]]["scores"] = scores |
|
|
| for split in ["train", "val"]: |
| filtered_annot = dict() |
| for ik, iv in annotations[split].items(): |
| new_q = dict() |
| for qk, qv in iv.items(): |
| if len(qv["labels"]) != 0: |
| new_q[qk] = qv |
| if len(new_q) != 0: |
| filtered_annot[ik] = new_q |
| annotations[split] = filtered_annot |
|
|
| split2items = {} |
| for split in ["train", "val", "test", "test-dev"]: |
| annot = annotations[split] |
| split_name = { |
| "train": "train2014", |
| "val": "val2014", |
| "test": "test2015", |
| "test-dev": "test2015", |
| }[split] |
| paths = list(glob.glob(f"{data_path}/{split_name}/*.jpg")) |
| random.shuffle(paths) |
| annot_paths = [path for path in paths \ |
| if int(path.split("/")[-1].split("_")[-1][:-4]) in annot] |
|
|
| if len(paths) == len(annot_paths): |
| print("all images have caption annotations") |
| else: |
| print("not all images have caption annotations") |
| print(len(paths), len(annot_paths), len(annot)) |
|
|
| items = [] |
| for path in annot_paths: |
| iid = int(path.split("/")[-1].split("_")[-1][:-4]) |
| _annot = annotations[split][iid] |
| for qid in _annot: |
| q = _annot[qid] |
| if split in ["train", "val"]: |
| labels = q["labels"] |
| scores = q["scores"] |
| else: |
| labels, scores = [], [] |
|
|
| items.append({ |
| "image_path": os.path.join(split_name, path.split('/')[-1]), |
| "text_segment": q["token_ids"], |
| "labels": labels, |
| "scores": scores, |
| "qid": qid, |
| }) |
| split2items[split] = items |
|
|
| _write_data_into_jsonl(items=items, jsonl_file=os.path.join(data_path, "vqa.%s.jsonl" % split)) |
|
|
| |
| val_image2items = defaultdict(list) |
| for item in split2items["val"]: |
| val_image2items[item["image_path"]].append(item) |
| |
| print("Contains %d image and %d pairs for val set!" % (len(val_image2items), len(split2items["val"]))) |
|
|
| val_images = list(val_image2items.keys()) |
| random.shuffle(val_images) |
| trainable_val = [] |
| rest_val = [] |
| for i, image_id in enumerate(val_images): |
| if i < 1000: |
| rest_val += val_image2items[image_id] |
| else: |
| trainable_val += val_image2items[image_id] |
| |
| _write_data_into_jsonl(items=trainable_val, jsonl_file=os.path.join(data_path, "vqa.trainable_val.jsonl")) |
| _write_data_into_jsonl(items=rest_val, jsonl_file=os.path.join(data_path, "vqa.rest_val.jsonl")) |
|
|
| with open(os.path.join(data_path, "answer2label.txt"), mode="w", encoding="utf-8") as writer: |
| for ans in ans2label: |
| to_json = { |
| "answer": ans, |
| "label": ans2label[ans] |
| } |
| writer.write("%s\n" % json.dumps(to_json)) |
|
|
|
|
| class RetrievalDataset(BaseDataset): |
| @staticmethod |
| def get_index_files(split, task=None): |
| if split == "train": |
| return (f"{task}.train.jsonl", ) |
| elif split == "val": |
| return (f"{task}.val.jsonl", ) |
| elif split == "test": |
| return (f"{task}.test.jsonl", ) |
| else: |
| raise RuntimeError("split %s is not found!" % split) |
|
|
| def __getitem__(self, index: int): |
| data = super().__getitem__(index) |
| data["image_id"] = self.items[index]["image_id"] |
| return data |
|
|
| @staticmethod |
| def make_flickr30k_dataset_index(data_path, tokenizer, karpathy_path): |
|
|
| with open(os.path.join(karpathy_path, "dataset_flickr30k.json"), "r") as reader: |
| captions = json.loads(reader.read()) |
|
|
| captions = captions["images"] |
| split2items = defaultdict(list) |
| split2images = defaultdict(set) |
|
|
| for each_item in captions: |
| image_path = os.path.join("flickr30k-images", each_item["filename"]) |
| split = each_item["split"] |
|
|
| for text_segment in each_item["sentences"]: |
| tokens = tokenizer.tokenize(text_segment["raw"]) |
| token_ids = tokenizer.convert_tokens_to_ids(tokens) |
|
|
| split2items[split].append({ |
| "image_path": image_path, |
| "text_segment": token_ids, |
| "image_id": len(split2images[split]), |
| }) |
|
|
| assert each_item["filename"] not in split2images[split] |
| split2images[split].add(each_item["filename"]) |
|
|
| for split in split2items: |
| print("%d images and %d image-text pairs!" % (len(split2images[split]), len(split2items[split]))) |
| _write_data_into_jsonl(split2items[split], os.path.join(data_path, "flickr30k.%s.jsonl" % split)) |
|
|
| @staticmethod |
| def make_coco_dataset_index(data_path, tokenizer): |
| _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train") |
| _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val") |
| _make_retrieval_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test") |
|
|
|
|
| class CaptioningDataset(BaseDataset): |
|
|
| def __init__(self, data_path, split, transform, |
| tokenizer, num_max_bpe_tokens, task, mask_prob): |
| super().__init__( |
| data_path=data_path, split=split, |
| transform=transform, tokenizer=tokenizer, |
| num_max_bpe_tokens=num_max_bpe_tokens, task=task, |
| ) |
| self.mask_token_id = tokenizer.mask_token_id |
| self.language_vocab_size = tokenizer.vocab_size |
| self.mask_prob = mask_prob |
|
|
| @staticmethod |
| def get_index_files(split, task=None): |
| if split == "train": |
| return ("coco_captioning.train.jsonl", ) |
| elif split == "val": |
| return (f"{task}.val.jsonl", ) |
| elif split == "test": |
| return (f"{task}.test.jsonl", ) |
| else: |
| raise RuntimeError("split %s is not found!" % split) |
|
|
| def _get_mask_token(self, token): |
| p = random.random() |
| if p < 0.8: |
| return self.mask_token_id |
| elif p < 0.9: |
| return token |
| else: |
| return random.randint(3, self.language_vocab_size - 1) |
|
|
| def _masking_on_text_tokens(self, tokens, num_tokens, mask_prob): |
| bool_masked_pos = [0] * len(tokens) |
| to_mask = min(int(num_tokens * mask_prob + 0.5), num_tokens - 1) |
| to_mask = max(to_mask, 1) |
| num_masked_tokens = 0 |
| while num_masked_tokens < to_mask: |
| i = random.randint(1, num_tokens - 1) |
| if bool_masked_pos[i] == 0: |
| bool_masked_pos[i] = 1 |
| tokens[i] = self._get_mask_token(tokens[i]) |
| num_masked_tokens += 1 |
|
|
| return tokens, bool_masked_pos |
|
|
| def __getitem__(self, index: int): |
| data = dict() |
| item = self.items[index] |
| img_path = item["image_path"] |
| img = self._get_image(img_path) |
| data["image"] = img |
| data["image_id"] = item["image_id"] |
|
|
| text_segment = item["text_segment"] |
| if text_segment is not None: |
| language_tokens, padding_mask, num_tokens = self._get_text_segment(text_segment) |
| masked_tokens = language_tokens[:] |
| masked_tokens, language_masked_pos = \ |
| self._masking_on_text_tokens(masked_tokens, num_tokens, self.mask_prob) |
| data["language_tokens"] = language_tokens |
| data["masked_tokens"] = masked_tokens |
| data["language_masked_pos"] = language_masked_pos |
| data["padding_mask"] = padding_mask |
| return data |
|
|
| @staticmethod |
| def make_coco_captioning_dataset_index(data_path, tokenizer): |
| _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("train", "restval"), split_name="train") |
| _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("val", ), split_name="val") |
| _make_captioning_coco_karpathy_dataset_index(data_path, tokenizer, split=("test", ), split_name="test") |
|
|
| @staticmethod |
| def make_nocaps_captioning_dataset_index(data_path): |
| _make_nocaps_dataset_index(data_path, split="val") |
| _make_nocaps_dataset_index(data_path, split="test") |
|
|
|
|
| task2dataset = { |
| "nlvr2": NLVR2Dataset, |
| "vqav2": VQAv2Dataset, |
| "flickr30k": RetrievalDataset, |
| "coco_retrieval": RetrievalDataset, |
| "coco_captioning": CaptioningDataset, |
| "nocaps": CaptioningDataset, |
| "imagenet": ImageNetDataset, |
| } |
|
|
|
|
| def create_dataloader(dataset, is_train, batch_size, num_workers, pin_mem, dist_eval=False): |
| if is_train or dist_eval: |
| num_tasks = utils.get_world_size() |
| global_rank = utils.get_rank() |
|
|
| if not is_train and dist_eval and len(dataset) % num_tasks != 0: |
| print('Warning: Enabling distributed evaluation with an eval dataset not divisible by process number. ' |
| 'This will slightly alter validation results as extra duplicate entries are added to achieve ' |
| 'equal num of samples per-process.') |
|
|
| sampler = torch.utils.data.DistributedSampler( |
| dataset, num_replicas=num_tasks, rank=global_rank, shuffle=is_train |
| ) |
| else: |
| sampler = torch.utils.data.SequentialSampler(dataset) |
| |
| return torch.utils.data.DataLoader( |
| dataset, sampler=sampler, |
| batch_size=batch_size, |
| num_workers=num_workers, |
| pin_memory=pin_mem, |
| drop_last=is_train, |
| collate_fn=utils.merge_batch_tensors_by_dict_key, |
| ) |
|
|
|
|
| def build_transform(is_train, args): |
| if args.task in ["imagenet"]: |
| return build_imagenet_transform(is_train, args) |
|
|
| if is_train: |
| t = [ |
| RandomResizedCropAndInterpolation(args.input_size, scale=(0.5, 1.0), interpolation=args.train_interpolation), |
| transforms.RandomHorizontalFlip(), |
| ] |
| if args.randaug: |
| t.append( |
| RandomAugment( |
| 2, 7, isPIL=True, |
| augs=[ |
| 'Identity','AutoContrast','Equalize','Brightness','Sharpness', |
| 'ShearX', 'ShearY', 'TranslateX', 'TranslateY', 'Rotate', |
| ])) |
| t += [ |
| transforms.ToTensor(), |
| transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD), |
| ] |
| t = transforms.Compose(t) |
| else: |
| t = transforms.Compose([ |
| transforms.Resize((args.input_size, args.input_size), interpolation=3), |
| transforms.ToTensor(), |
| transforms.Normalize(mean=IMAGENET_INCEPTION_MEAN, std=IMAGENET_INCEPTION_STD) |
| ]) |
|
|
| return t |
|
|
|
|
| def build_imagenet_transform(is_train, args): |
| resize_im = args.input_size > 32 |
| if is_train: |
| |
| transform = create_transform( |
| input_size=args.input_size, |
| is_training=True, |
| color_jitter=args.color_jitter, |
| auto_augment=args.aa, |
| interpolation=args.train_interpolation, |
| re_prob=args.reprob, |
| re_mode=args.remode, |
| re_count=args.recount, |
| mean=IMAGENET_DEFAULT_MEAN, |
| std=IMAGENET_DEFAULT_STD, |
| ) |
| if not resize_im: |
| |
| |
| transform.transforms[0] = transforms.RandomCrop( |
| args.input_size, padding=4) |
| return transform |
|
|
| t = [] |
| if resize_im: |
| if args.crop_pct is None: |
| args.crop_pct = 1.0 |
| size = int(args.input_size / args.crop_pct) |
| t.append( |
| transforms.Resize(size, interpolation=3), |
| ) |
| t.append(transforms.CenterCrop(args.input_size)) |
|
|
| t.append(transforms.ToTensor()) |
| t.append(transforms.Normalize(mean=IMAGENET_DEFAULT_MEAN, std=IMAGENET_DEFAULT_STD)) |
| return transforms.Compose(t) |
|
|
|
|
| def get_sentencepiece_model_for_beit3(args): |
| from transformers import XLMRobertaTokenizer |
| return XLMRobertaTokenizer(args.sentencepiece_model) |
|
|
|
|
| def create_dataset_by_split(args, split, is_train=True): |
| transform = build_transform(is_train=is_train, args=args) |
| dataset_class = task2dataset[args.task] |
| tokenizer = get_sentencepiece_model_for_beit3(args) |
|
|
| opt_kwargs = {} |
| if args.task in ["coco_captioning", "nocaps"]: |
| opt_kwargs["mask_prob"] = args.captioning_mask_prob |
|
|
| dataset = dataset_class( |
| data_path=args.data_path, split=split, |
| transform=transform, tokenizer=tokenizer, |
| num_max_bpe_tokens=args.num_max_bpe_tokens, |
| task=args.task, **opt_kwargs, |
| ) |
| if is_train: |
| batch_size = args.batch_size |
| elif hasattr(args, "eval_batch_size") and args.eval_batch_size is not None: |
| batch_size = args.eval_batch_size |
| else: |
| batch_size = int(args.batch_size * 1.5) |
|
|
| return create_dataloader( |
| dataset, is_train=is_train, batch_size=batch_size, |
| num_workers=args.num_workers, pin_mem=args.pin_mem, dist_eval=args.dist_eval, |
| ) |
|
|
|
|
| def create_downstream_dataset(args, is_eval=False): |
| if is_eval: |
| return create_dataset_by_split(args, split="test", is_train=False) |
| else: |
| return \ |
| create_dataset_by_split(args, split="train", is_train=True), \ |
| create_dataset_by_split(args, split="val", is_train=True) |
|
|