| from __future__ import absolute_import |
| from __future__ import division |
| from __future__ import print_function |
|
|
| import os |
| from io import BytesIO |
| import json |
| import logging |
| import base64 |
| import random |
| from typing import Callable, List, Tuple, Union, NamedTuple |
| from PIL import Image |
| from PIL import ImageFile |
| import torch.utils.data as data |
| from .languages.prompt_engineering import prompt_engineering |
| from .tsv_file import TSVFile, CompositeTSVFile |
|
|
| ImageFile.LOAD_TRUNCATED_IMAGES = True |
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class TSVDataset(data.Dataset):
    """Classification dataset backed by TSV files.

    Each row of the underlying TSV is ``(key, label, base64-encoded image)``.
    The label column is either a plain integer (when no ``map_file`` is
    given) or a JSON list of dicts whose first entry carries a ``'class'``
    name that is translated through the mapping loaded from ``map_file``.

    Args:
        tsv_file: a single ``.tsv`` path, a composite-spec path (any other
            extension), or an explicit list of tsv paths.
        transform: optional callable applied to the decoded PIL image.
        map_file: optional tab-separated ``label<TAB>index`` mapping file.
        token_file: SAS token path forwarded to ``CompositeTSVFile``.
        is_train: forwarded to ``CompositeTSVFile``.
        azcopy_path: forwarded to ``CompositeTSVFile``.

    Raises:
        ValueError: if ``tsv_file`` is neither a str nor a list.
    """

    def __init__(self,
                 tsv_file: Union[str, List[str]],
                 transform: Callable = None,
                 map_file: str = None,
                 token_file: str = None,
                 is_train: bool = True,
                 azcopy_path: str = None):
        self.transform = transform
        self._chunk_sizes = None
        self.label2idx = self._load_map(map_file)
        self.class_selector = list(self.label2idx.keys()) if self.label2idx else None

        if isinstance(tsv_file, str) and os.path.splitext(tsv_file)[1] == '.tsv':
            # A single plain tsv file.
            self.tsv_file = TSVFile(
                tsv_file, class_selector=self.class_selector
            )
        elif isinstance(tsv_file, (str, list)):
            # A composite-spec path or an explicit list of tsv files; both
            # resolve to the same CompositeTSVFile construction (previously
            # duplicated in two branches).
            self.tsv_file = CompositeTSVFile(
                tsv_file,
                class_selector=self.class_selector,
                is_train=is_train,
                sas_token_path=token_file,
                azcopy_path=azcopy_path
            )
            self._chunk_sizes = self.tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames")

        logger.debug('=> {}\titems: {}'.format(tsv_file, len(self.tsv_file)))

    def fetch_blob(self, idx):
        # Pre-fetch the idx-th tsv chunk from blob storage.
        # NOTE(review): only meaningful for CompositeTSVFile backends — a
        # plain TSVFile has no file_list/blob_storage; confirm with callers.
        image_tsv = self.tsv_file.file_list[idx]
        self.tsv_file.blob_storage.fetch_blob(image_tsv)

    def num_classes(self):
        # NOTE(review): assumes a map_file was provided (class_selector is
        # None otherwise, and len(None) raises TypeError).
        return len(self.class_selector)

    def get_chunk_sizes(self):
        # Chunk sizes are only populated for composite backends; None otherwise.
        return self._chunk_sizes

    def get_class_boundaries(self):
        return self.tsv_file.get_class_boundaries()

    def get_filenames(self):
        """Return the key of every row, in row order."""
        return [
            self.tsv_file.get_key(i)
            for i in range(self.tsv_file.num_rows())
        ]

    def _load_map(self, map_file: str):
        """Load a ``label -> index`` dict from a tab-separated file.

        Returns None when no map file is given. Blank or malformed lines
        (fewer than two tab-separated fields) are skipped instead of
        raising IndexError.
        """
        if not map_file:
            return None

        label2idx = {}
        with open(map_file) as f:
            for line in f:
                items = line.strip().split('\t')
                if len(items) < 2:
                    # Skip blank/malformed lines rather than crash.
                    continue
                label2idx[items[0]] = int(items[1])

        return label2idx

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """Return ``(image, label)`` for the row at ``index``."""
        items = self.tsv_file[index]
        _, target, img = self._decode_data(items)

        if self.transform:
            img = self.transform(img)

        return img, target

    def _decode_data(self, items: Tuple[str, str, str]):
        """Decode one tsv row into ``(key, int label, RGB PIL image)``."""
        key = items[0]
        label = self._get_label(items[1])
        image = Image.open(BytesIO(base64.b64decode(items[2]))).convert('RGB')

        return key, label, image

    def _get_label(self, item: str):
        """Return the int label; JSON rows are mapped through ``label2idx``."""
        if not self.label2idx:
            return int(item)

        js = json.loads(item)
        return self.label2idx[js[0]['class']]

    def __len__(self):
        return len(self.tsv_file)
|
|
|
|
# Per-source metadata record for supervised TSV data: the source name, the
# number of classes that source contributes, and the task it belongs to.
TSVMeta = NamedTuple(
    'TSVMeta',
    [('source', str), ('num_classes', int), ('task', str)],
)
|
|
|
|
class TSVImageTextDatasetV2(data.Dataset):
    """
    This class is intended for encapsulating Image/Text pair data for contrastive learning described in
    the following paper,
    "Learning Transferable Visual Models From Natural Language Supervision" (a.k.a CLIP)
    V2: support image text pairs and supervised classification data

    Args:
        image_tsv_file: tsv path(s) whose rows are ``(key, base64 image)``.
        text_tsv_file: tsv path(s) whose rows are ``(key, JSON payload)``,
            aligned row-for-row with ``image_tsv_file``.
        transform: optional callable applied to the decoded PIL image.
        tokenize: optional tokenizer callable; when given it must accept
            ``padding``/``truncation``/``max_length``/``return_tensors``
            keyword arguments and return a dict with ``input_ids`` and
            ``attention_mask`` tensors.
        context_length: max token length passed to the tokenizer.
        num_captions: number of captions to keep when a row carries a list.
        text_format: only ``'json'`` is supported by ``_decode_text``.
        is_train, sas_token_path, azcopy_path: forwarded to CompositeTSVFile.
        metas: optional list of TSVMeta-like records (need ``source`` and
            ``num_classes`` attributes) used to offset classification labels
            per source.
        prompt_engineering: wrap class names with prompt templates when True.
        concat_queries: prepend a row's ``queries`` strings to its text.

    Raises:
        ValueError: if the image/text inputs are not both str or both list,
            or are str paths without a ``.tsv`` extension.
    """

    def __init__(self,
                 image_tsv_file: Union[str, List[str]],
                 text_tsv_file: Union[str, List[str]],
                 transform: Callable = None,
                 tokenize: Callable = None,
                 context_length: int = 77,
                 num_captions: int = 1,
                 text_format: str = 'txt',
                 is_train: bool = True,
                 sas_token_path: str = None,
                 azcopy_path: str = None,
                 metas: List[NamedTuple] = None,
                 prompt_engineering=True,
                 concat_queries=False):
        self.transform = transform
        self.tokenize = tokenize
        self._chunk_sizes = None
        self.context_length = context_length
        self.num_captions = num_captions
        self.text_format = text_format
        self.tsv_file_list = []
        self.metas = metas
        self.label_offsets = self.build_label_offsets()
        self.prompt_engineering = prompt_engineering
        self.concat_queries = concat_queries

        if isinstance(image_tsv_file, str) and isinstance(text_tsv_file, str):
            if (
                os.path.splitext(image_tsv_file)[1].lower() == '.tsv'
                and os.path.splitext(text_tsv_file)[1].lower() == '.tsv'
            ):
                self.tsv_file_list.append((image_tsv_file, text_tsv_file))
                self.image_tsv_file = TSVFile(
                    image_tsv_file, if_generate_lineidx=True
                )
                self.text_tsv_file = TSVFile(
                    text_tsv_file, if_generate_lineidx=True
                )
            else:
                raise ValueError("Invalid input! Please check the tsv filenames.")
        elif (
            isinstance(image_tsv_file, list)
            and isinstance(text_tsv_file, list)
        ):
            assert len(image_tsv_file) == len(text_tsv_file), \
                "Inconsistent number of Image/Text tsv files!"
            # BUGFIX: pairs used to be stored as (text, image) here, the
            # reverse of the single-file branch above; store (image, text)
            # consistently.
            self.tsv_file_list = list(zip(image_tsv_file, text_tsv_file))
            self.image_tsv_file = CompositeTSVFile(
                image_tsv_file,
                is_train=is_train,
                sas_token_path=sas_token_path,
                azcopy_path=azcopy_path
            )
            self.text_tsv_file = CompositeTSVFile(
                text_tsv_file,
                is_train=is_train,
                sas_token_path=sas_token_path,
                azcopy_path=azcopy_path
            )
            self._chunk_sizes = self.image_tsv_file.get_chunk_size()
        else:
            raise ValueError("Invalid input! Please check the tsv filenames.")

        assert len(self.image_tsv_file) == len(self.text_tsv_file), \
            "Inconsistent size of Image/Text ({}/{}) data!".format(
                len(self.image_tsv_file), len(self.text_tsv_file)
            )

    def build_label_offsets(self):
        """Map each meta source name to its starting label offset.

        Offsets start at 1 so that supervised class ids never collide with
        the label 0 used for caption/tag rows. Returns None when no metas
        were given.
        """
        if self.metas is None:
            return None

        label_offsets = {}
        offset = 1
        for meta in self.metas:
            # Was a pair of bare print() calls; demoted to debug logging.
            logger.debug('meta: %s, offsets so far: %s', meta, label_offsets)
            label_offsets[meta.source] = offset
            offset += meta.num_classes

        return label_offsets

    def fetch_blob(self, idx):
        # Pre-fetch the idx-th image/text tsv chunks from blob storage.
        # NOTE(review): only meaningful for the composite (list) backend;
        # plain TSVFile has no file_list/blob_storage.
        image_tsv = self.image_tsv_file.file_list[idx]
        text_tsv = self.text_tsv_file.file_list[idx]
        self.image_tsv_file.blob_storage.fetch_blob(image_tsv)
        self.text_tsv_file.blob_storage.fetch_blob(text_tsv)

    def get_chunk_sizes(self):
        # Chunk sizes are only populated for the composite backend; None otherwise.
        return self._chunk_sizes

    def __getitem__(self, index: Union[int, Tuple[int, int]]):
        """Return ``(image, tokens, label)`` for the row at ``index``.

        ``index is None`` is treated as a missing-row sentinel (presumably
        produced by a sampler — TODO confirm) and yields empty tensors.
        """
        if index is None:
            import torch
            return torch.tensor([], dtype=torch.float32), \
                torch.tensor([], dtype=torch.int64), \
                torch.tensor([], dtype=torch.int64)

        items_image = self.image_tsv_file[index]
        items_text = self.text_tsv_file[index]

        assert items_text[0] == items_image[0], \
            'keys do not match for image and text {} vs {}'.format(
                items_text[0], items_image[0]
            )

        _, img = self._decode_image(items_image)
        _, txt, label = self._decode_text(items_text)

        if self.transform:
            img = self.transform(img)

        if self.tokenize:
            tokens = self.tokenize(
                txt, padding='max_length', truncation=True,
                max_length=self.context_length, return_tensors='pt'
            )
            # The tokenizer returns (1, L) tensors; drop the batch dim in
            # place so default collation stacks rows correctly.
            tokens['input_ids'].squeeze_()
            tokens['attention_mask'].squeeze_()
        else:
            # BUGFIX: the old code unconditionally indexed
            # tokens['input_ids'], which raised TypeError when no tokenizer
            # was configured and tokens was the raw string.
            tokens = txt

        return img, tokens, label

    def _decode_image(self, items: Tuple[str, str]):
        """Decode one image row into ``(key, RGB PIL image)``."""
        key = items[0]
        image = Image.open(BytesIO(base64.b64decode(items[1]))).convert('RGB')

        return key, image

    def _decode_text(self, items: Tuple[str, Union[str, dict]]):
        """Decode one text row into ``(key, text, label)``.

        The JSON payload carries one of ``captions``, ``tags``, or a
        ``classification`` record; label is 0 except for classification
        rows, whose ``class_id`` may additionally be shifted by the
        per-source label offset.

        Raises:
            ValueError: if ``text_format`` is not ``'json'`` or the
                ``captions`` field is neither str nor list.
        """
        key = items[0]
        text = ''

        if self.text_format != 'json':
            raise ValueError('Only support json format')

        try:
            js = json.loads(items[1])
        except Exception as e:
            js = {}
            logger.info("JSON parsing error on: " + items[1])
            logger.info(str(e))

            # Best-effort salvage: use the first double-quoted substring of
            # the raw payload as the caption.
            # BUGFIX: the old code located the closing quote in a slice but
            # indexed the full string with that (slice-relative) position,
            # so the salvaged caption was always empty.
            sstr = items[1].find("\"")
            estr = items[1].find("\"", sstr + 1) if sstr >= 0 else -1
            if estr >= 0:
                text = items[1][sstr + 1:estr]
            if len(text) < 2:
                text = "A picture showing some content."

        label = 0

        if 'captions' in js:
            captions = js['captions']
            if isinstance(captions, list):
                if self.num_captions == 1:
                    text = random.choice(captions)
                else:
                    text = captions
                    if len(captions) > self.num_captions:
                        text = captions[:self.num_captions]
            elif isinstance(captions, str):
                text = captions
            else:
                raise ValueError('captions should be str or list')
            label = 0
        elif 'tags' in js:
            text = prompt_engineering(js['tags'])
            label = 0
        elif 'task' in js and js['task'] == 'classification':
            if self.prompt_engineering:
                text = prompt_engineering(js['class_name'])
            else:
                text = js['class_name']
            label = js['class_id']

            # Shift per-source class ids into the global label space; the
            # lookup lives inside this branch because only classification
            # rows are guaranteed to carry a 'source' field.
            if self.label_offsets is not None \
                    and js['source'] in self.label_offsets:
                label += self.label_offsets[js['source']]

        if self.concat_queries and js.get('queries'):
            # Matches the original concatenation byte-for-byte:
            # "<q1> <q2> ... , <text>"
            text = ' '.join(js['queries']) + ' , ' + text

        return key, text, label

    def __len__(self):
        return len(self.image_tsv_file)
|
|