| |
|
|
| from functools import partial |
| import numpy as np |
| import os |
| import time |
| import torch |
| from torch.utils.data import BatchSampler, DataLoader, SequentialSampler, Subset |
| from torch.utils.data._utils.collate import default_collate |
| from tqdm import tqdm |
|
|
| from megatron.training import get_args, get_tokenizer, print_rank_0 |
| from megatron import core |
| from megatron.training.arguments import core_transformer_config_from_args |
| from megatron.core.datasets.retro.utils import get_blocks_by_rank |
| from megatron.core.enums import ModelType |
| from megatron.core.pipeline_parallel import get_forward_backward_func |
| from megatron.legacy.model import BertModel |
| from megatron.training.training import setup_model_and_optimizer |
| from pretrain_bert import model_provider, get_batch, loss_func, forward_step |
|
|
| from .dataset import BertEmbeddingDataset |
| from .external_libs import h5py |
| from .huggingface import HuggingfaceEmbedder |
|
|
|
|
def collate_batch(samples):
    """Collate samples of various lengths.

    This collate function handles samples with various sequence lengths, by
    padding 'text' arrays with pad_id, and other arrays with 0. Only numpy
    array values are padded, and only along their first axis (the axis that
    ``len()`` measures); non-array values are handed to ``default_collate``
    unchanged.

    Args:
        samples: list of dicts; the keys of ``samples[0]`` are used for all.

    Returns:
        Whatever ``default_collate`` produces for the padded samples.
    """

    keys = list(samples[0].keys())
    tokenizer = get_tokenizer()

    # Longest first-axis length per key, over array-valued entries only.
    max_length_map = {}
    for sample in samples:
        for key in keys:
            value = sample[key]
            if isinstance(value, np.ndarray):
                max_length_map[key] = max(max_length_map.get(key, 0), len(value))

    def _pad(key, value):
        """Right-pad an array value to the batch max length; pass others through."""
        if not isinstance(value, np.ndarray):
            return value
        # Pad only the leading axis. A bare (0, n) pad width would be
        # broadcast by np.pad to every axis of a multi-dimensional array.
        pad_width = [(0, max_length_map[key] - len(value))]
        pad_width += [(0, 0)] * (value.ndim - 1)
        return np.pad(
            value,
            pad_width,
            mode="constant",
            constant_values=tokenizer.pad_id if key == "text" else 0,
        )

    padded_samples = [
        {key: _pad(key, sample[key]) for key in keys}
        for sample in samples
    ]

    return default_collate(padded_samples)
|
|
|
|
def get_data_loader(dataset, batch_size):
    """Wrap the whole ``dataset`` in a sequential, batched data loader.

    Batches of ``batch_size`` samples are drawn in index order (no
    shuffling); the last batch may be shorter because ``drop_last`` is
    False. Samples are merged with ``collate_batch``, and the worker count
    comes from the global ``args.num_workers``.
    """
    num_workers = get_args().num_workers

    # In-order batching, keeping the trailing partial batch.
    in_order_batches = BatchSampler(
        SequentialSampler(dataset),
        batch_size,
        False,
    )

    return DataLoader(
        dataset,
        batch_sampler=in_order_batches,
        num_workers=num_workers,
        pin_memory=True,
        collate_fn=collate_batch,
    )
|
|
|
|
def embed_data_loader(models, data_loader, tag):
    '''Feed every batch of ``data_loader`` through ``models[0]``; stack the outputs.

    Args:
        models: list of model chunks. All are switched to eval mode, but only
            the first one is passed to ``forward_step``.
        data_loader: loader whose batches ``forward_step`` pulls from an
            iterator over it.
        tag: optional label appended to the progress-bar description.

    Returns:
        numpy array formed by concatenating, along axis 0, the first element
        of each ``forward_step`` result.
    '''

    args = get_args()
    assert (args.tensor_model_parallel_size == 1
            and args.pipeline_model_parallel_size == 1), \
        "since we call forward_step directly, only tp == pp == 1 allowed."

    # forward_step consumes batches from this iterator itself.
    batch_iter = iter(data_loader)

    for model in models:
        model.eval()

    num_batches = len(data_loader)
    description = " embed" if tag is None else " embed / '%s'" % tag
    hide_progress = torch.distributed.get_rank() != 0

    chunks = []
    progress = tqdm(
        range(num_batches),
        description,
        miniters=num_batches // 10,
        disable=hide_progress,
    )
    for _ in progress:
        with torch.no_grad():
            output = forward_step(batch_iter, models[0])
        chunks.append(output[0].detach().cpu().numpy())

    return np.concatenate(chunks, axis=0)
|
|
|
|
class TextDataset(torch.utils.data.Dataset):
    '''Map-style dataset over an in-memory list of strings.

    Item ``i`` is returned as the dict ``{"text": texts[i]}``.
    '''

    def __init__(self, texts):
        # Only a plain list whose every element is a str is accepted.
        assert isinstance(texts, list)
        assert all(isinstance(text, str) for text in texts)
        self.texts = texts

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, i):
        return dict(text=self.texts[i])
|
|
|
|
class BertEmbedder:
    '''Compute Bert embeddings, from a text dataset.

    The Megatron BERT model is always built via ``setup_model_and_optimizer``.
    When ``embedder_type == "huggingface"``, embedding calls are delegated to
    a ``HuggingfaceEmbedder`` instead.
    '''

    def __init__(self, batch_size, max_bert_seq_length, embedder_type, warmup=True):
        '''Build the embedder.

        Args:
            batch_size: number of texts per embedding batch.
            max_bert_seq_length: maximum sequence length handed to
                ``BertEmbeddingDataset`` / ``HuggingfaceEmbedder``.
            embedder_type: either "megatron" or "huggingface".
            warmup: if True, embed a few throwaway texts before returning.

        Raises:
            ValueError: if ``embedder_type`` is unsupported. This is checked
                before the model is built. ValueError subclasses the
                ``Exception`` previously raised, so existing handlers still
                catch it.
        '''

        # Reject a bad embedder type before the expensive model setup.
        if embedder_type not in ("megatron", "huggingface"):
            raise ValueError("specialize for embedder type '%s'." % embedder_type)

        args = get_args()

        assert args.output_bert_embeddings

        # The optimizer and LR scheduler returned here are not used by this class.
        self.models, _, _ = setup_model_and_optimizer(
            model_provider, ModelType.encoder_or_decoder)
        self.batch_size = batch_size
        self.max_bert_seq_length = max_bert_seq_length

        # None selects the Megatron path in embed_text_dataset().
        if embedder_type == "huggingface":
            self.huggingface_embedder = HuggingfaceEmbedder(batch_size,
                                                            max_bert_seq_length)
        else:
            self.huggingface_embedder = None

        # Warm up: embed a single text three times, then a small batch three times.
        if warmup:
            warmup_dataset = TextDataset([
                "great fleas have lesser fleas, upon their backs to bite’em,",
                "and lesser fleas have lesser fleas, and so, ad infinitum,",
                "and those great fleas, themselves, in turn have greater fleas to go on,",
                "while those again have greater still, and greater still, and so on.",
            ])
            print_rank_0("bert / warmup single.")
            for _ in range(3):
                self.embed_text("hi, bert.")
            print_rank_0("bert / warmup batch.")
            for _ in range(3):
                self.embed_text_dataset(warmup_dataset)

    def embed_text_dataset(self, text_dataset, tag=None):
        '''Embed a text dataset.

        Args:
            text_dataset: indexable dataset; presumably items shaped like
                ``TextDataset``'s ``{"text": str}`` (e.g. a ``TextDataset`` or
                a ``Subset`` of one) — TODO confirm against BertEmbeddingDataset.
            tag: optional progress-bar label (used on the Megatron path only).

        Returns:
            Megatron path: numpy array of embeddings, in dataset order.
            Huggingface path: whatever ``HuggingfaceEmbedder.embed_text_dataset``
            returns.
        '''
        if self.huggingface_embedder is not None:
            return self.huggingface_embedder.embed_text_dataset(text_dataset)

        bert_dataset = BertEmbeddingDataset(text_dataset,
                                            self.max_bert_seq_length)
        data_loader = get_data_loader(bert_dataset, self.batch_size)
        return embed_data_loader(self.models, data_loader, tag)

    def embed_text(self, text):
        '''Embed a single text string.

        Primarily used for on-the-fly embeddings, particularly during
        analysis or debugging. For large scale, use 'embed_text_dataset()'.

        Returns:
            The first (only) row of ``embed_text_dataset`` for a one-item dataset.
        '''
        text_ds = TextDataset([text])
        return self.embed_text_dataset(text_ds)[0]
|
|
|
|
class DiskDataParallelBertEmbedder:
    '''Process embeddings in blocks & save to disk.

    Block assignment is delegated to ``get_blocks_by_rank``. Each block this
    rank handles is embedded with the wrapped ``BertEmbedder`` and written to
    its own HDF5 file as the dataset "data".
    '''

    def __init__(self, embedder, block_size, embedding_dim=1024):
        '''
        Args:
            embedder: the ``BertEmbedder`` used to compute embeddings.
            block_size: number of samples per block file.
            embedding_dim: expected second dimension of "data" in existing
                block files. It is checked by the ``validate`` callback handed
                to ``get_blocks_by_rank``. Defaults to 1024, the value that was
                previously hard-coded.
        '''
        assert isinstance(embedder, BertEmbedder)
        self.embedder = embedder
        self.block_size = block_size
        self.embedding_dim = embedding_dim

    def embed_text_blocks(self, name, dirname, text_dataset,
                          missing_embedding_blocks):
        '''Process a text dataset in blocks.

        Args:
            name: label used in progress messages.
            dirname: not used here (each block carries its own output path);
                kept for interface compatibility.
            text_dataset: full dataset; each block embeds the slice
                ``range(*block_info["range"])``.
            missing_embedding_blocks: list whose entries are block-info dicts
                (keys "path" and "range") or ``None``. A ``None`` entry means
                there is no work for this rank.

        ``torch.distributed.barrier()`` is entered once per list entry,
        including ``None`` entries. Every rank must therefore receive a list
        of the same length, or the collective will hang.
        '''
        for block_index, block_info in enumerate(missing_embedding_blocks):

            # None: nothing to embed on this rank, but still join the barrier
            # below so all ranks advance in lockstep.
            if block_info is not None:

                print_rank_0("embed '%s' block %d / %d ... %s." % (
                    name,
                    block_index,
                    len(missing_embedding_blocks),
                    block_info["path"],
                ))

                sub_dataset = Subset(text_dataset, range(*block_info["range"]))
                embeddings = self.embedder.embed_text_dataset(sub_dataset)

                # Context manager closes the file even if the write raises.
                with h5py.File(block_info["path"], "w") as f:
                    f.create_dataset("data", data=embeddings)

            print_rank_0(" > waiting for other ranks to finish block.")
            torch.distributed.barrier()

    def embed_text_dataset(self, name, dirname, text_dataset):
        '''Embed a text dataset into block files under ``dirname``.

        Args:
            name: label used in progress messages.
            dirname: directory holding the block files; created if missing.
            text_dataset: dataset to embed.
        '''
        os.makedirs(dirname, exist_ok=True)

        expected_dim = self.embedding_dim

        # Check handed to get_blocks_by_rank for existing block files.
        def validate(f):
            assert f["data"].shape[1] == expected_dim

        blocks = get_blocks_by_rank(
            dirname,
            len(text_dataset),
            self.block_size,
            validate=validate)

        # All ranks sync here before any of them starts embedding blocks.
        torch.distributed.barrier()

        self.embed_text_blocks(name, dirname, text_dataset, blocks.missing)
|
|