fcxfcx's picture
Upload 2446 files
1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Dataset and Loader for VQA dataset."""
import functools
from typing import Optional
from absl import logging
from flax import jax_utils
import jax.numpy as jnp
import ml_collections
from scenic.dataset_lib import dataset_utils
from scenic.dataset_lib import datasets
from scenic.dataset_lib.big_transfer import bit
from scenic.dataset_lib.big_transfer import builder
from scenic.dataset_lib.big_transfer import registry
from scenic.dataset_lib import web_image_text_dataset
from scenic.projects.knowledge_visual_language.data import data_utils
import tensorflow as tf
# import jax
# Not referenced anywhere else in the visible part of this module.
OUTPUT_MAX_LENGTH = 64
# Side length passed to the `resize` preprocess op; also the default
# `image_size` in the dataset config.
IMAGE_SIZE = 224
# `max_num_tokens` used when tokenizing the question.
QUESTION_LENGTH = 64
# `max_num_tokens` used when tokenizing "answer" and "answers".
ANSWER_LENGTH = 32
# `max_num_tokens` used when tokenizing "top_answers" into "retr_texts".
KNOWLEDGE_MAX_LENGTH = 320
# NOTE(review): presumably the number of question/answer pairs per example,
# mirroring the hard-coded 5 in get_vqa_pair's reshape -- confirm. It is not
# referenced in the visible part of this module.
n_qa = 5
@registry.Registry.register('preprocess_ops.get_vqa_pair', 'function')
def get_vqa_pair():
  """Builds the preprocess op that extracts question/answer fields for VQA."""

  def get_vqa_pair_fn(data):
    """Adds 'question', 'answers', 'answer' and 'top_answers' to `data`.

    Args:
      data: Example dict holding a 'question/answers' entry with
        'question_text' and 'answers' fields.

    Returns:
      The same dict, updated in place with the four new keys.
    """
    qa_features = data['question/answers']
    # Answers are laid out as 5 rows; the row width is inferred by reshape.
    answer_table = tf.reshape(qa_features['answers'], [5, -1])

    data['question'] = qa_features['question_text']
    data['answers'] = answer_table
    # First entry of every row serves as the decoding target.
    data['answer'] = answer_table[:, 0]
    # Each row is joined into a single comma-separated string.
    data['top_answers'] = tf.strings.reduce_join(
        answer_table, separator=', ', axis=-1
    )
    return data

  return get_vqa_pair_fn
def map_vqa_split(batch):
  """Replaces 'answer' with shifted decoder input/target token arrays.

  The 'answer' entry is removed from `batch`. The decoder input is every token
  except the last along the final axis, and the decoder target is every token
  except the first.

  Args:
    batch: Dict with an 'answer' entry that supports `[..., a:b]` indexing.

  Returns:
    The same `batch` object, modified in place.
  """
  answer_tokens = batch.pop('answer')
  batch.update(
      decoder_input_tokens=answer_tokens[..., :-1],
      decoder_target_tokens=answer_tokens[..., 1:],
  )
  return batch
def get_default_dataset_config(runlocal=False):
  """Gets default configs for the VQA dataset.

  Args:
    runlocal: bool; When true, a small shuffle buffer (50) is used instead of
      the default one (10000).

  Returns:
    An ml_collections.ConfigDict with the default dataset settings.
  """
  # Preprocess ops in application order; they are joined with '|' below.
  preprocess_ops = (
      'decode',
      f'resize(resize_size={IMAGE_SIZE})',
      'value_range(-1,1)',
      'get_vqa_pair',
      (
          f'list_t5_tokenize(max_num_tokens={KNOWLEDGE_MAX_LENGTH},'
          'inkey="top_answers", outkey="retr_texts",'
          f' prompt="{data_utils.KNOWLEDGE_PREFIX}")'
      ),
      (
          f'list_t5_tokenize(max_num_tokens={ANSWER_LENGTH},'
          ' inkey="answer",outkey="answer")'
      ),
      (
          f'multi_t5_tokenize(max_num_tokens={ANSWER_LENGTH},'
          'inkey="answers",outkey="answers")'
      ),
      (
          f'list_t5_tokenize(max_num_tokens={QUESTION_LENGTH},'
          ' inkey="question", outkey="question",'
          f' prompt="{data_utils.VQA_PREFIX}")'
      ),
      'keep("image", "question", "answer", "retr_texts", "answers")',
  )

  cfg = ml_collections.ConfigDict()
  cfg.dataset = 'vqa'
  # Add path to your data here:
  cfg.dataset_dir = ''
  cfg.train_split = 'train+validation[5000:]'
  cfg.question_max_num_tokens = QUESTION_LENGTH
  cfg.answer_max_num_tokens = ANSWER_LENGTH
  cfg.image_size = IMAGE_SIZE
  cfg.pp_train = '|'.join(preprocess_ops)
  # Validation reuses the training preprocessing on the first 5000
  # validation examples; the rest of that split is part of train_split.
  cfg.val_split = [('val', cfg.dataset, 'validation[:5000]', cfg.pp_train)]
  cfg.shuffle_buffer_size = 50 if runlocal else 10000
  cfg.val_cache = 'loaded'  # Unfortunately, "batched" gets us OOM.
  cfg.vocab_size = data_utils.VOCAB_SIZE_T5
  cfg.prefetch_to_device = 2
  return cfg
@datasets.add_dataset('vqa')
def get_dataset(
    *,
    batch_size,
    eval_batch_size,
    num_shards,
    dtype_str='float32',
    shuffle_seed=None,
    rng=None,
    dataset_configs=None,
    dataset_service_address: Optional[str] = None,
):
  """Returns generators for the VQA train and validation sets.

  Args:
    batch_size: int; Determines the train batch size.
    eval_batch_size: int; Determines the evaluation batch size.
    num_shards: int; Number of shards --> batch shape: [num_shards, bs, ...].
    dtype_str: Data type of the image (e.g. 'float32').
    shuffle_seed: int; Must be None when `dataset_service_address` is set;
      otherwise it is not used by this loader.
    rng: JAX rng key. Discarded by this loader.
    dataset_configs: dict; Dataset specific configurations. They are merged
      on top of the values from `get_default_dataset_config`.
    dataset_service_address: If set, will distribute the training dataset using
      the given tf.data service at the given address.

  Returns:
    A dataset_utils.Dataset() which includes a train_iter, a valid_iter
    (a single iterator when `val_split` is a string, otherwise a dict mapping
    each eval name to its iterator), no test_iter (None), and a dict of
    meta_data.

  Raises:
    ValueError: If `dataset_service_address` is set while `shuffle_seed` is
      not None.
  """
  default_dataset_config = get_default_dataset_config(runlocal=False)
  if dataset_configs:
    default_dataset_config.update(dataset_configs)
  dataset_configs = default_dataset_config
  del rng
  assert dataset_configs is not None

  data_dir = dataset_configs.get('dataset_dir')

  logging.info('Loading train split of the %s', dataset_configs.dataset)

  def pp_fn(x, how):
    """Applies the preprocess spec `how` and renames keys for the model."""
    pp = builder.get_preprocess_fn(how, remove_tpu_dtypes=False)
    example = pp(x)
    return {
        'encoder_input_image': example['image'],
        'encoder_input_tokens': example['question'],
        'answer': example['answer'],
        'answers': example['answers'],
        'retr_texts': example['retr_texts'],
    }

  # E.g. for testing with TAP.
  shuffle_buffer_size = (
      1000 if num_shards == 1 else dataset_configs.shuffle_buffer_size
  )

  # NOTE(review): the train pipeline takes its cache mode from `val_cache`
  # while the eval pipeline below hard-codes 'batched' -- confirm this is
  # intended.
  train_ds = data_utils.get_data(
      dataset=dataset_configs.dataset,
      split=dataset_configs.train_split,
      data_dir=data_dir,
      batch_size=batch_size,
      preprocess_fn=functools.partial(pp_fn, how=dataset_configs.pp_train),
      shuffle_buffer_size=shuffle_buffer_size,
      prefetch=dataset_configs.get('prefetch_to_host', 2),
      cache=dataset_configs.val_cache,
      ignore_errors=True,
  )

  if dataset_service_address:
    if shuffle_seed is not None:
      raise ValueError(
          'Using dataset service with a random seed causes each '
          'worker to produce exactly the same data. Add '
          'config.shuffle_seed = None to your config if you '
          'want to run with dataset service.'
      )
    logging.info('Using the tf.data service at %s', dataset_service_address)
    assert shuffle_buffer_size is not None
    train_ds = dataset_utils.distribute(train_ds, dataset_service_address)

  n_train_ex = dataset_utils.get_num_examples(
      dataset_configs.dataset,
      dataset_configs.train_split,
      data_dir=data_dir,
  )

  maybe_pad_batches_train = functools.partial(
      dataset_utils.maybe_pad_batch,
      inputs_key='encoder_input_image',
      train=True,
      batch_size=batch_size,
  )
  shard_batches = functools.partial(dataset_utils.shard, n_devices=num_shards)

  train_iter = iter(train_ds)
  train_iter = map(map_vqa_split, train_iter)
  train_iter = map(dataset_utils.tf_to_numpy, train_iter)
  train_iter = map(data_utils.sample_retr_image, train_iter)
  train_iter = map(maybe_pad_batches_train, train_iter)
  if num_shards > 0:
    train_iter = map(shard_batches, train_iter)
    if dataset_configs.prefetch_to_device:
      train_iter = jax_utils.prefetch_to_device(
          train_iter, dataset_configs.prefetch_to_device
      )

  logging.info('Loading validation split of the %s', dataset_configs.dataset)
  maybe_pad_batches_eval = functools.partial(
      dataset_utils.maybe_pad_batch,
      inputs_key='encoder_input_image',
      train=False,
      batch_size=eval_batch_size,
  )

  def _get_eval_iter(dataset, split, pp_eval):
    """Builds the eval iterator for one (dataset, split, preprocess) spec."""
    val_ds = data_utils.get_data(
        dataset=dataset,
        split=split,
        data_dir=data_dir,
        batch_size=eval_batch_size,
        preprocess_fn=functools.partial(pp_fn, how=pp_eval),
        cache='batched',
        repeat_after_batching=True,
        drop_remainder=False,
    )
    valid_iter = iter(val_ds)
    valid_iter = map(map_vqa_split, valid_iter)
    valid_iter = map(bit.tf_to_numpy, valid_iter)
    valid_iter = map(data_utils.sample_retr_image, valid_iter)
    valid_iter = map(maybe_pad_batches_eval, valid_iter)
    if num_shards > 0:
      valid_iter = map(shard_batches, valid_iter)
      if dataset_configs.prefetch_to_device:
        valid_iter = jax_utils.prefetch_to_device(
            valid_iter, dataset_configs.prefetch_to_device
        )
    return valid_iter

  if isinstance(dataset_configs.val_split, str):
    # The default config defines no `pp_eval`; fall back to the training
    # preprocessing (as the default tuple-style val_split does) instead of
    # failing with an AttributeError.
    pp_eval = dataset_configs.get('pp_eval', dataset_configs.pp_train)
    valid_iter = _get_eval_iter(
        dataset_configs.dataset,
        dataset_configs.val_split,
        pp_eval,
    )
    n_eval_ex = dataset_utils.get_num_examples(
        dataset_configs.dataset,
        dataset_configs.val_split,
        data_dir=data_dir,
    )
  else:
    # Each spec is a (name, dataset, split, preprocess string) tuple.
    valid_iter, n_eval_ex = {}, {}
    for eval_spec in dataset_configs.val_split:
      name, dataset, split, pp_eval = eval_spec
      valid_iter[name] = _get_eval_iter(dataset, split, pp_eval)
      n_eval_ex[name] = dataset_utils.get_num_examples(
          dataset, split, data_dir=data_dir
      )

  meta_data = {'num_train_examples': n_train_ex, 'num_eval_examples': n_eval_ex}
  if dataset_configs.get('extra_meta_data'):
    for k, v in dataset_configs.extra_meta_data.items():
      meta_data[k] = v

  image_shape = (-1, dataset_configs.image_size, dataset_configs.image_size, 3)
  prefix_shape = (-1, QUESTION_LENGTH)
  input_shape = (-1, ANSWER_LENGTH)
  retr_texts_shape = (-1, KNOWLEDGE_MAX_LENGTH + data_utils.PROMPT_LENGTH)
  retr_image_shape = (
      -1,
      dataset_configs.image_size,
      dataset_configs.image_size,
      3,
  )

  meta_data['encoder_input_image_spec'] = (image_shape, getattr(jnp, dtype_str))
  meta_data['encoder_input_tokens_spec'] = (prefix_shape, jnp.int16)
  meta_data['decoder_input_tokens_spec'] = (input_shape, jnp.int16)
  meta_data['decoder_target_tokens_spec'] = (input_shape, jnp.int16)
  meta_data['retr_texts_spec'] = (retr_texts_shape, jnp.int16)
  meta_data['retr_images_spec'] = (retr_image_shape, getattr(jnp, dtype_str))
  return dataset_utils.Dataset(train_iter, valid_iter, None, meta_data)