|
|
from subprocess import call |
|
|
import os, json |
|
|
|
|
|
from torchvision.datasets import VisionDataset |
|
|
from PIL import Image |
|
|
|
|
|
|
|
|
# URL template for the original Karpathy-split COCO annotation JSON
# (formatted with a split name; see ORIGINAL_ANNOTATION_FILE_NAME below).
GITHUB_MAIN_ORIGINAL_ANNOTATION_PATH = 'https://github.com/mehdidc/retrieval_annotations/releases/download/1.0.0/coco_{}_karpathy.json'

# Base URL (note the trailing slash) of the XTD10 cross-lingual test data,
# from which the image index and per-language caption files are fetched.
GITHUB_MAIN_PATH = 'https://raw.githubusercontent.com/adobe-research/Cross-lingual-Test-Dataset-XTD10/main/XTD10/'

# Language codes for which an XTD10 caption file exists.
SUPPORTED_LANGUAGES = ['es', 'it', 'ko', 'pl', 'ru', 'tr', 'zh', 'en']

# Local name for the combined multilingual index file.
IMAGE_INDEX_FILE = 'mscoco-multilingual_index.json'

# Remote file listing the test image names, one per line.
IMAGE_INDEX_FILE_DOWNLOAD_NAME = 'test_image_names.txt'

# Remote caption file template (formatted with a language code).
CAPTIONS_FILE_DOWNLOAD_NAME = 'test_1kcaptions_{}.txt'

# Local annotation JSON template (formatted with a language code) produced by
# create_annotation_file() and consumed by Multilingual_MSCOCO.
CAPTIONS_FILE_NAME = 'multilingual_mscoco_captions-{}.json'

# Local name template for the original Karpathy annotation download.
ORIGINAL_ANNOTATION_FILE_NAME = 'coco_{}_karpathy.json'
|
|
|
|
|
|
|
|
|
|
|
class Multilingual_MSCOCO(VisionDataset):
    """Multilingual MS-COCO image/caption retrieval dataset.

    Reads a pre-built annotation JSON (see ``create_annotation_file``) whose
    ``image_paths`` entries are paths relative to ``root`` and whose
    ``annotations`` entries are captions in one target language.
    """

    def __init__(self, root, ann_file, transform=None, target_transform=None):
        """
        Args:
            root: Directory containing the ``train2014``/``val2014`` folders.
            ann_file: Path to the annotation JSON produced by
                ``create_annotation_file``; ``~`` is expanded.
            transform: Optional callable applied to the loaded PIL image.
            target_transform: Optional callable applied to the caption list.
        """
        super().__init__(root, transform=transform, target_transform=target_transform)
        self.ann_file = os.path.expanduser(ann_file)
        # Open the *expanded* path: the original opened the raw argument,
        # which made the expanduser() call above ineffective for '~' paths.
        with open(self.ann_file, 'r') as fp:
            data = json.load(fp)

        # Pair each relative image path with its caption.
        self.data = list(zip(data['image_paths'], data['annotations']))

    def __getitem__(self, index):
        """Return ``(image, [caption])`` for the sample at ``index``."""
        img_path, caption = self.data[index]

        img = Image.open(os.path.join(self.root, img_path)).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)

        # A single caption, wrapped in a list so the target shape matches
        # datasets that provide multiple captions per image.
        target = [caption, ]
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self) -> int:
        return len(self.data)
|
|
|
|
|
|
|
|
def _get_downloadable_file(filename, download_url, is_json=True): |
|
|
if (os.path.exists(filename) == False): |
|
|
print("Downloading", download_url) |
|
|
call("wget {} -O {}".format(download_url, filename), shell=True) |
|
|
with open(filename, 'r') as fp: |
|
|
if (is_json): |
|
|
return json.load(fp) |
|
|
return [line.strip() for line in fp.readlines()] |
|
|
|
|
|
|
|
|
def create_annotation_file(root, lang_code):
    """Build the per-language annotation JSON used by ``Multilingual_MSCOCO``.

    Downloads the XTD10 image index and the captions for ``lang_code``, keeps
    only the (image, caption) pairs whose image file exists under ``root``,
    and writes the result to ``<root>/multilingual_mscoco_captions-<lang>.json``.

    Args:
        root: Directory containing the ``train2014``/``val2014`` folders.
        lang_code: Language code; expected to be one of SUPPORTED_LANGUAGES.
    """
    print("Downloading multilingual_ms_coco index file")
    # Build URLs by concatenation (GITHUB_MAIN_PATH ends with '/');
    # os.path.join would produce backslashes on Windows and corrupt the URL.
    download_path = GITHUB_MAIN_PATH + IMAGE_INDEX_FILE_DOWNLOAD_NAME
    target_images = _get_downloadable_file("multilingual_coco_images.txt", download_path, False)

    print("Downloading multilingual_ms_coco captions:", lang_code)
    download_path = GITHUB_MAIN_PATH + CAPTIONS_FILE_DOWNLOAD_NAME.format(lang_code)
    target_captions = _get_downloadable_file('raw_multilingual_coco_captions_{}.txt'.format(lang_code), download_path, False)

    number_of_missing_images = 0
    valid_images, valid_annotations, valid_indicies = [], [], []
    for i, (img, txt) in enumerate(zip(target_images, target_captions)):
        # Karpathy-split image names encode whether the file lives in the
        # val2014 or train2014 folder.
        root_split = 'val2014' if 'val' in img else 'train2014'
        filename_with_root_split = "{}/{}".format(root_split, img)

        # Skip pairs whose image is absent on disk. Check relative to `root`
        # and with the condition negated: the original tested the bare
        # relative path with the test inverted, which silently dropped every
        # image that actually existed.
        if not os.path.exists(os.path.join(root, filename_with_root_split)):
            print("Missing image file", img)
            number_of_missing_images += 1
            continue

        valid_images.append(filename_with_root_split)
        valid_annotations.append(txt)
        valid_indicies.append(i)

    if number_of_missing_images > 0:
        print("*** WARNING *** missing {} files.".format(number_of_missing_images))

    # NOTE: the 'indicies' key spelling is part of the existing on-disk
    # format; kept as-is for compatibility with existing annotation files.
    with open(os.path.join(root, CAPTIONS_FILE_NAME.format(lang_code)), 'w') as fp:
        json.dump({'image_paths': valid_images, 'annotations': valid_annotations, 'indicies': valid_indicies}, fp)
|
|
|