# fcxfcx's picture
# Upload 2446 files
# 1327f34 verified
# Copyright 2025 The Scenic Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Global file to store local knowledge base."""
import functools
from absl import logging
from flax.core.frozen_dict import FrozenDict
import jax
import numpy as np
import tqdm
def static_encode_knowledge(knowledge_batch, train_state, *, flax_model):
    """Run the model's knowledge encoder on one batch of KB entries.

    This function is wrapped with `jax.pmap` by `KnowledgeBase.set_encode_fn`.

    Args:
      knowledge_batch: Mapping for a single batch; the 'knowledge_tokens' and
        'image' entries are read.
      train_state: Training state whose `params` and `model_state` are merged
        into the variables handed to the model.
      flax_model: Flax model providing `apply` and `encode_knowledge`.

    Returns:
      The first three of the five outputs of `encode_knowledge`: the key
      (single embedding), the compressed value (list of embeddings) and the
      compressed mask.
    """
    model_variables = {'params': train_state.params, **train_state.model_state}
    tokens = knowledge_batch['knowledge_tokens']
    images = knowledge_batch['image']
    # The encoder yields five outputs; the trailing two are not needed here.
    encoded = flax_model.apply(
        model_variables,
        retr_texts=tokens,
        retr_images=images,
        bsz=images.shape[0],
        train=False,
        random_drop_image=False,
        method=flax_model.encode_knowledge,
    )
    key_emb, value_emb, value_mask, _, _ = encoded
    return key_emb, value_emb, value_mask
class KnowledgeBase:
    """Local Knowledge Base stored in CPU.

    Raw knowledge entries (images and text tokens) are loaded once with
    `initialize`. `update_memory` then re-encodes all of them with the current
    model and stores the resulting key/value embeddings next to the raw data.

    `memory` keeps arrays with a leading local-device axis followed by a
    per-shard entry axis; `memory_flatten` keeps the same arrays with those two
    axes merged into one.
    """

    def __init__(self):
        self.memory = {}
        self.memory_flatten = {}
        # name -> (per-entry shape, dtype) pairs used to build retrieval specs.
        self.specs = {}
        self.n_data_per_shard = -1
        self.n_data = -1
        self.ret_specs = None
        # Filled in by `update_memory`.
        self.local_ret_specs = None
        self.ret_top_specs = None
        # Filled in by `set_encode_fn`.
        self.encode_knowledge_pmap = None
        self.n_kb_dataset = 1
        self.n_local_device = 1
        self.k = 1

    def set_encode_fn(self, flax_model):
        """Build the pmapped encoder used by `update_memory`.

        Args:
          flax_model: Flax model forwarded to `static_encode_knowledge`.
        """
        encode_fn = functools.partial(
            static_encode_knowledge, flax_model=flax_model)
        # Only the knowledge batch is donated. The train state is passed to the
        # encoder again for every chunk in `update_memory` and is returned to
        # the caller afterwards, so its buffers must stay valid.
        self.encode_knowledge_pmap = jax.pmap(
            encode_fn,
            axis_name='batch',
            donate_argnums=(0,),
        )

    def initialize(self, kb_datasets):
        """Load sharded dataset into CPU.

        Args:
          kb_datasets: Mapping from dataset name to a dataset object exposing
            `meta_data` and `train_iter`. Batches are concatenated along
            axis 1, and axis 0 is taken as the local device axis.

        Raises:
          ValueError: If no batch could be loaded from `kb_datasets`.
        """
        if not kb_datasets:
            raise ValueError('kb_datasets must contain at least one dataset.')
        memory_image = []
        memory_text = []
        memory_idxs = []
        self.n_kb_dataset = len(kb_datasets)
        logging.info('Start Loading sharded dataset into CPU.')
        for idx, dataset_name in enumerate(kb_datasets):
            logging.info(dataset_name)
            dataset = kb_datasets[dataset_name]
            n_iter = int(dataset.meta_data['example_per_shard'] //
                         dataset.meta_data['batch_size'])
            for _ in tqdm.tqdm(range(n_iter)):
                kb_batch = next(dataset.train_iter)
                memory_image.append(np.asarray(kb_batch['image']))
                memory_text.append(np.asarray(kb_batch['knowledge_tokens']))
                # Tag every entry with the index of the dataset it came from.
                memory_idxs.append(
                    np.full(kb_batch['image'].shape[:2], idx, dtype='int16'))
                del kb_batch
        if not memory_image:
            raise ValueError('No knowledge batches were loaded from kb_datasets.')
        self.memory['image'] = np.concatenate(memory_image, axis=1)
        self.memory['text'] = np.concatenate(memory_text, axis=1)
        self.n_local_device = self.memory['image'].shape[0]
        # Flattened dataset indices, replicated once per local device.
        self.memory_flatten['idxs'] = np.repeat(
            np.reshape(np.concatenate(memory_idxs, axis=1), (1, -1)),
            self.n_local_device,
            axis=0)
        self.n_data_per_shard = self.memory['image'].shape[1]
        self.n_data = self.n_data_per_shard * self.n_local_device
        # The specs are read from the last dataset that was iterated.
        # NOTE(review): assumes every dataset shares the same specs - confirm.
        self.specs = {
            'image': dataset.meta_data['image_spec'],
            'text': dataset.meta_data['knowledge_spec']
        }

    def _make_ret_specs(self, batch, n_emb, n_data):
        """Build the [embedding specs, raw-data specs] pair for retrieval.

        Args:
          batch: Leading batch dimension of every spec.
          n_emb: Second dimension of the 'keys' and 'values' specs.
          n_data: Second dimension of the 'image' and 'text_tokens' specs.

        Returns:
          A two-element list of dicts of `jax.ShapeDtypeStruct`.
        """

        def spec(name, count):
            entry = self.specs[name]
            return jax.ShapeDtypeStruct(
                shape=(batch, count) + tuple(entry[0]), dtype=entry[1])

        return [{
            'keys': spec('keys', n_emb),
            'values': spec('values', n_emb),
        }, {
            'image': spec('image', n_data),
            'text_tokens': spec('text', n_data),
        }]

    def update_memory(self, pmap_train_state, bsz, retr_k, data_k,
                      axis_index_groups):
        """Function to update stale embedding as memory.

        Args:
          pmap_train_state: pmaped train state.
          bsz: Global batch size.
          retr_k: number of returned data for retrieval.
          data_k: number of returned data for ranking.
          axis_index_groups: axis groups to gather data, or None for a single
            group.

        Returns:
          updated train_state

        Raises:
          RuntimeError: If `set_encode_fn` has not been called.
          ValueError: If there is no knowledge data to encode.
        """
        if self.encode_knowledge_pmap is None:
            raise RuntimeError(
                'set_encode_fn must be called before update_memory.')
        per_bsz = bsz // jax.device_count()
        if axis_index_groups is None:
            # No grouping is handled as one single group.
            n_groups = 1
            per_shard_bsz = bsz
        else:
            n_groups = len(axis_index_groups)
            per_shard_bsz = bsz // len(axis_index_groups[0])
        logging.info('update memory!!!')
        logging.info(per_bsz)
        memory_key = []
        memory_val = []
        # Encoding runs with 4x the per-device train batch size.
        eval_per_bsz = per_bsz * 4
        n_steps = int(np.ceil(self.n_data_per_shard / eval_per_bsz))
        for idx in range(n_steps):
            window = slice(idx * eval_per_bsz, (idx + 1) * eval_per_bsz)
            kb_batch = {
                'knowledge_tokens': self.memory['text'][:, window],
                'image': self.memory['image'][:, window],
            }
            keys_head, compressed_val, _ = self.encode_knowledge_pmap(
                kb_batch, pmap_train_state)
            memory_key.append(np.asarray(keys_head))
            memory_val.append(np.asarray(compressed_val))
            del kb_batch
        if not memory_key:
            raise ValueError(
                'No knowledge entries were encoded; call initialize() with '
                'non-empty data before update_memory().')
        # Drop the stale embeddings before the new arrays are materialised.
        for kw in ('keys', 'values'):
            self.memory.pop(kw, None)
            self.memory_flatten.pop(kw, None)
        self.memory['keys'] = np.concatenate(memory_key, axis=1)
        self.memory['values'] = np.concatenate(memory_val, axis=1)
        for kw in ('keys', 'values', 'image', 'text'):
            mem = self.memory[kw]
            # Merge the device axis and the per-shard axis into one.
            self.memory_flatten[kw] = mem.reshape(
                (mem.shape[0] * mem.shape[1],) + mem.shape[2:])
        self.specs['keys'] = (keys_head.shape[2:], keys_head.dtype.name)
        self.specs['values'] = (compressed_val.shape[2:],
                                compressed_val.dtype.name)
        self.local_ret_specs = self._make_ret_specs(
            per_bsz, retr_k - data_k, data_k)
        self.k = int(np.ceil(retr_k / n_groups) + 1)
        self.ret_top_specs = self._make_ret_specs(
            per_shard_bsz, self.k, self.k)
        logging.info(self.local_ret_specs)
        logging.info(self.ret_top_specs)
        logging.info(self.memory['keys'].shape)
        new_model_state = pmap_train_state.model_state.unfreeze()
        memory_state = new_model_state['memory']
        # Release any previous arrays before the fresh ones are stored.
        memory_state.pop('keys', None)
        memory_state.pop('idxs', None)
        memory_state['keys'] = self.memory['keys']
        memory_state['idxs'] = self.memory_flatten['idxs']
        pmap_train_state = pmap_train_state.replace(
            model_state=FrozenDict(new_model_state))
        logging.info('finish update memory!!!')
        return pmap_train_state
def retrieve_memory(args):
    """Fetch stored values and raw entries from one device shard of `kb`.

    Args:
      args: Pair `(device_id, indexs)`. `device_id` selects along the leading
        axis of the arrays in `kb.memory`; `indexs` selects entries within
        that shard.

    Returns:
      A two-element list: a dict with the selected 'values', and a dict with
      the selected 'image' and 'text_tokens'.
    """
    device_id, indexs = args
    store = kb.memory
    encoded = {'values': store['values'][device_id][indexs]}
    raw = {
        'image': store['image'][device_id][indexs],
        'text_tokens': store['text'][device_id][indexs],
    }
    return [encoded, raw]
def local_retrieve_memory(args):
    """Fetch embeddings and raw entries from the flattened memory of `kb`.

    Args:
      args: Pair `(global_data_ids, global_memory_ids)`. The memory ids index
        the flattened 'keys' and 'values'; the data ids index the flattened
        'image' and 'text'.

    Returns:
      A two-element list: a dict with the selected 'keys' and 'values', and a
      dict with the selected 'image' and 'text_tokens'.
    """
    global_data_ids, global_memory_ids = args
    flat = kb.memory_flatten
    encoded = {
        'keys': flat['keys'][global_memory_ids],
        'values': flat['values'][global_memory_ids],
    }
    raw = {
        'image': flat['image'][global_data_ids],
        'text_tokens': flat['text'][global_data_ids],
    }
    return [encoded, raw]
def retrieve_top_memory(args):
    """Fetch embeddings and raw entries for one set of flattened indices.

    Args:
      args: Indices applied to every array in `kb.memory_flatten`.

    Returns:
      A two-element list: a dict with the selected 'keys' and 'values', and a
      dict with the selected 'image' and 'text_tokens'.
    """
    top1_ids = args
    flat = kb.memory_flatten
    encoded = {name: flat[name][top1_ids] for name in ('keys', 'values')}
    raw = {
        'image': flat['image'][top1_ids],
        'text_tokens': flat['text'][top1_ids],
    }
    return [encoded, raw]
# Module-level KnowledgeBase instance; the retrieve_* functions in this file
# read their data from it.
kb = KnowledgeBase()