| """
|
| Copyright (c) 2022, salesforce.com, inc.
|
| All rights reserved.
|
| SPDX-License-Identifier: BSD-3-Clause
|
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
|
| """
|
|
|
| import logging
|
| import os
|
| import shutil
|
| import warnings
|
|
|
| import lavis.common.utils as utils
|
| import torch.distributed as dist
|
| from lavis.common.dist_utils import is_dist_avail_and_initialized, is_main_process
|
| from lavis.common.registry import registry
|
| from lavis.datasets.data_utils import extract_archive
|
| from lavis.processors.base_processor import BaseProcessor
|
| from omegaconf import OmegaConf
|
| from torchvision.datasets.utils import download_url
|
|
|
|
|
class BaseDatasetBuilder:
    """Build per-split datasets (train/val/test) for a dataset described by a config.

    ``train_dataset_cls`` / ``eval_dataset_cls`` are ``None`` on this class and
    must be set (presumably by a subclass) before :meth:`build` is called.
    ``DATASET_CONFIG_DICT``, used by :meth:`default_config_path`, is not
    defined here either — presumably each subclass provides it.
    """

    # Dataset classes instantiated by build(): the first for the "train"
    # split, the second for "val" / "test".
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        """Load the dataset config and set up placeholder processors.

        Args:
            cfg: ``None`` to load the config file returned by
                :meth:`default_config_path`, a ``str`` path to a config file,
                or an already-loaded config object (used as-is).
        """
        super().__init__()

        if cfg is None:
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            self.config = cfg

        # Key into ``config.build_info`` holding the visual-input storage info.
        self.data_type = self.config.data_type

        # Placeholder processors; build_processors() replaces them with the
        # ones declared in the config, if any.
        self.vis_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

        # Extra named processors from the optional "kw_processor" section.
        self.kw_processors = {}

    def build_datasets(self):
        """Prepare data on the main process, then build the per-split datasets.

        Only the main process runs :meth:`_download_data`. When
        torch.distributed is initialized, every rank waits at a barrier so no
        rank starts building before the main process has finished.

        Returns:
            The dict returned by :meth:`build` (split name -> dataset).
        """
        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        logging.info("Building datasets...")
        datasets = self.build()

        return datasets

    def build_processors(self):
        """Replace the placeholder processors with ones built from the config.

        Reads the optional ``vis_processor`` and ``text_processor`` sections
        (each with ``train`` / ``eval`` sub-configs) and the optional
        ``kw_processor`` mapping. A missing ``train`` / ``eval`` sub-config
        inside a present section yields ``None`` for that entry.
        """
        vis_proc_cfg = self.config.get("vis_processor")
        txt_proc_cfg = self.config.get("text_processor")

        if vis_proc_cfg is not None:
            vis_train_cfg = vis_proc_cfg.get("train")
            vis_eval_cfg = vis_proc_cfg.get("eval")

            self.vis_processors["train"] = self._build_proc_from_cfg(vis_train_cfg)
            self.vis_processors["eval"] = self._build_proc_from_cfg(vis_eval_cfg)

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

        kw_proc_cfg = self.config.get("kw_processor")
        if kw_proc_cfg is not None:
            for name, cfg in kw_proc_cfg.items():
                self.kw_processors[name] = self._build_proc_from_cfg(cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        """Instantiate the processor registered under ``cfg.name``, or return None if ``cfg`` is None."""
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of the config file registered for ``type``.

        The parameter name shadows the builtin ``type``; it is kept for
        backward compatibility with keyword callers.
        """
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        """Fetch annotation files and check the visual-input location."""
        self._download_ann()
        self._download_vis()

    @staticmethod
    def _pair_ann_sources(urls, storage_paths):
        """Pair each annotation source (URL or local file) with its storage path.

        Either argument may be a single string or a sequence. Returns an
        empty list when ``urls`` is None (nothing to download or copy).

        Raises:
            AssertionError: if the number of sources and storage paths differ.
        """
        if urls is None:
            return []
        if isinstance(urls, str):
            urls = [urls]
        if isinstance(storage_paths, str):
            storage_paths = [storage_paths]

        assert len(urls) == len(storage_paths), (
            "Got {} annotation sources but {} storage paths.".format(
                len(urls), len(storage_paths)
            )
        )
        return list(zip(urls, storage_paths))

    def _download_ann(self):
        """
        Download or copy annotation files if necessary.

        For every split under ``build_info.annotations``, each entry of ``url``
        is paired with the corresponding entry of ``storage``:

        - A relative ``storage`` path is joined onto the registry's
          ``cache_root``; an absolute one is used as-is. Missing parent
          directories are created.
        - If the source is an existing local file, it is copied to the storage
          path unless the destination already exists.
        - Otherwise the source is treated as a URL and downloaded with
          torchvision's ``download_url``. ``storage`` must then be a file path;
          an existing directory raises ``ValueError``.

        Splits without a ``url`` are skipped: their ``storage`` files are
        expected to be present already.
        """
        anns = self.config.build_info.annotations

        cache_root = registry.get_path("cache_root")

        for split in anns.keys():
            info = anns[split]
            sources = self._pair_ann_sources(info.get("url", None), info.storage)

            for url_or_filename, storage_path in sources:
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                os.makedirs(dirname, exist_ok=True)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                    continue

                if os.path.isdir(storage_path):
                    raise ValueError(
                        "Expecting storage_path to be a file path, got directory {}".format(
                            storage_path
                        )
                    )

                filename = os.path.basename(storage_path)
                download_url(url=url_or_filename, root=dirname, filename=filename)

    def _download_vis(self):
        """Warn if the configured visual-input storage path does not exist.

        Nothing is downloaded here; the user is pointed at the download
        instructions instead.
        """
        storage_path = self.config.build_info.get(self.data_type).storage
        storage_path = utils.get_cache_path(storage_path)

        if not os.path.exists(storage_path):
            warnings.warn(
                f"""
                The specified path {storage_path} for visual inputs does not exist.
                Please provide a correct path to the visual inputs or
                refer to datasets/download_scripts/README.md for downloading instructions.
                """
            )

    def build(self):
        """
        Create one dataset per split ("train", "val", "test") found in the annotations.

        Other annotation keys are ignored. The "train" split uses
        ``train_dataset_cls`` and the train processors; "val" and "test" use
        ``eval_dataset_cls`` and the eval processors. Relative annotation and
        visual paths are resolved via ``utils.get_cache_path``.

        build() can be dataset-specific. Override to customize.

        Returns:
            dict mapping split name to dataset instance.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations
        vis_info = build_info.get(self.data_type)

        datasets = dict()
        for split in ann_info.keys():
            if split not in ("train", "val", "test"):
                continue

            is_train = split == "train"

            vis_processor = (
                self.vis_processors["train"]
                if is_train
                else self.vis_processors["eval"]
            )
            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]
            ann_paths = [
                path if os.path.isabs(path) else utils.get_cache_path(path)
                for path in ann_paths
            ]

            vis_path = vis_info.storage
            if not os.path.isabs(vis_path):
                vis_path = utils.get_cache_path(vis_path)

            if not os.path.exists(vis_path):
                warnings.warn("storage path {} does not exist.".format(vis_path))

            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                vis_processor=vis_processor,
                text_processor=text_processor,
                ann_paths=ann_paths,
                vis_root=vis_path,
            )

        return datasets
|
|
|
|
|
class ProteinDatasetBuilder:
    """Build per-split text-only datasets (train/val/test) from a dataset config.

    Unlike a vision-language builder, this one has no visual processors and
    no visual storage: datasets are built from annotation files and a text
    processor only. ``train_dataset_cls`` / ``eval_dataset_cls`` are ``None``
    here and must be set (presumably by a subclass) before :meth:`build`.
    ``DATASET_CONFIG_DICT``, used by :meth:`default_config_path`, is not
    defined on this class — presumably each subclass provides it.
    """

    # Dataset classes instantiated by build(): the first for the "train"
    # split, the second for "val" / "test".
    train_dataset_cls, eval_dataset_cls = None, None

    def __init__(self, cfg=None):
        """Load the dataset config and set up placeholder text processors.

        Args:
            cfg: ``None`` to load the config file returned by
                :meth:`default_config_path`, a ``str`` path to a config file,
                or an already-loaded config object (used as-is).
        """
        super().__init__()

        if cfg is None:
            self.config = load_dataset_config(self.default_config_path())
        elif isinstance(cfg, str):
            self.config = load_dataset_config(cfg)
        else:
            self.config = cfg

        # Read from the config for parity with other builders; not used by
        # build() in this class.
        self.data_type = self.config.data_type

        # Placeholder processors; build_processors() replaces them with the
        # ones declared in the config, if any.
        self.text_processors = {"train": BaseProcessor(), "eval": BaseProcessor()}

        # Extra named processors from the optional "kw_processor" section.
        self.kw_processors = {}

    def build_datasets(self):
        """Prepare data on the main process, then build the per-split datasets.

        Only the main process runs :meth:`_download_data`. When
        torch.distributed is initialized, every rank waits at a barrier so no
        rank starts building before the main process has finished.

        Returns:
            The dict returned by :meth:`build` (split name -> dataset).
        """
        if is_main_process():
            self._download_data()

        if is_dist_avail_and_initialized():
            dist.barrier()

        logging.info("Building datasets...")
        datasets = self.build()

        return datasets

    def build_processors(self):
        """Replace the placeholder text processors with ones built from the config.

        Reads the optional ``text_processor`` section (with ``train`` /
        ``eval`` sub-configs) and the optional ``kw_processor`` mapping. A
        missing ``train`` / ``eval`` sub-config inside a present section
        yields ``None`` for that entry.
        """
        txt_proc_cfg = self.config.get("text_processor")

        if txt_proc_cfg is not None:
            txt_train_cfg = txt_proc_cfg.get("train")
            txt_eval_cfg = txt_proc_cfg.get("eval")

            self.text_processors["train"] = self._build_proc_from_cfg(txt_train_cfg)
            self.text_processors["eval"] = self._build_proc_from_cfg(txt_eval_cfg)

        kw_proc_cfg = self.config.get("kw_processor")
        if kw_proc_cfg is not None:
            for name, cfg in kw_proc_cfg.items():
                self.kw_processors[name] = self._build_proc_from_cfg(cfg)

    @staticmethod
    def _build_proc_from_cfg(cfg):
        """Instantiate the processor registered under ``cfg.name``, or return None if ``cfg`` is None."""
        return (
            registry.get_processor_class(cfg.name).from_config(cfg)
            if cfg is not None
            else None
        )

    @classmethod
    def default_config_path(cls, type="default"):
        """Return the absolute path of the config file registered for ``type``.

        The parameter name shadows the builtin ``type``; it is kept for
        backward compatibility with keyword callers.
        """
        return utils.get_abs_path(cls.DATASET_CONFIG_DICT[type])

    def _download_data(self):
        """Fetch annotation files (there is no visual data to check)."""
        self._download_ann()

    @staticmethod
    def _pair_ann_sources(urls, storage_paths):
        """Pair each annotation source (URL or local file) with its storage path.

        Either argument may be a single string or a sequence. Returns an
        empty list when ``urls`` is None (nothing to download or copy).

        Raises:
            AssertionError: if the number of sources and storage paths differ.
        """
        if urls is None:
            return []
        if isinstance(urls, str):
            urls = [urls]
        if isinstance(storage_paths, str):
            storage_paths = [storage_paths]

        assert len(urls) == len(storage_paths), (
            "Got {} annotation sources but {} storage paths.".format(
                len(urls), len(storage_paths)
            )
        )
        return list(zip(urls, storage_paths))

    def _download_ann(self):
        """
        Download or copy annotation files if necessary.

        For every split under ``build_info.annotations``, each entry of ``url``
        is paired with the corresponding entry of ``storage``:

        - A relative ``storage`` path is joined onto the registry's
          ``cache_root``; an absolute one is used as-is. Missing parent
          directories are created.
        - If the source is an existing local file, it is copied to the storage
          path unless the destination already exists.
        - Otherwise the source is treated as a URL and downloaded with
          torchvision's ``download_url``. ``storage`` must then be a file path;
          an existing directory raises ``ValueError``.

        Splits without a ``url`` are skipped: their ``storage`` files are
        expected to be present already.
        """
        anns = self.config.build_info.annotations

        cache_root = registry.get_path("cache_root")

        for split in anns.keys():
            info = anns[split]
            sources = self._pair_ann_sources(info.get("url", None), info.storage)

            for url_or_filename, storage_path in sources:
                if not os.path.isabs(storage_path):
                    storage_path = os.path.join(cache_root, storage_path)

                dirname = os.path.dirname(storage_path)
                os.makedirs(dirname, exist_ok=True)

                if os.path.isfile(url_or_filename):
                    src, dst = url_or_filename, storage_path
                    if not os.path.exists(dst):
                        shutil.copyfile(src=src, dst=dst)
                    else:
                        logging.info("Using existing file {}.".format(dst))
                    continue

                if os.path.isdir(storage_path):
                    raise ValueError(
                        "Expecting storage_path to be a file path, got directory {}".format(
                            storage_path
                        )
                    )

                filename = os.path.basename(storage_path)
                download_url(url=url_or_filename, root=dirname, filename=filename)

    def build(self):
        """
        Create one dataset per split ("train", "val", "test") found in the annotations.

        Other annotation keys are ignored. The "train" split uses
        ``train_dataset_cls`` and the train text processor; "val" and "test"
        use ``eval_dataset_cls`` and the eval text processor. Relative
        annotation paths are resolved via ``utils.get_cache_path``.

        build() can be dataset-specific. Override to customize.

        Returns:
            dict mapping split name to dataset instance.
        """
        self.build_processors()

        build_info = self.config.build_info

        ann_info = build_info.annotations

        datasets = dict()
        for split in ann_info.keys():
            if split not in ("train", "val", "test"):
                continue

            is_train = split == "train"

            text_processor = (
                self.text_processors["train"]
                if is_train
                else self.text_processors["eval"]
            )

            ann_paths = ann_info.get(split).storage
            if isinstance(ann_paths, str):
                ann_paths = [ann_paths]
            ann_paths = [
                path if os.path.isabs(path) else utils.get_cache_path(path)
                for path in ann_paths
            ]

            dataset_cls = self.train_dataset_cls if is_train else self.eval_dataset_cls
            datasets[split] = dataset_cls(
                text_processor=text_processor,
                ann_paths=ann_paths,
            )

        return datasets
|
|
|
|
|
def load_dataset_config(cfg_path):
    """Load a config file and return the config of the first entry under ``datasets``.

    Args:
        cfg_path: Path to a config file readable by ``OmegaConf.load``; it must
            have a top-level ``datasets`` mapping.

    Returns:
        The config node stored under the first key of ``datasets``.
    """
    all_datasets = OmegaConf.load(cfg_path).datasets
    dataset_names = list(all_datasets.keys())
    first_name = dataset_names[0]
    return all_datasets[first_name]
|
|
|