# NOTE(review): removed paste/extraction artifacts that preceded the code
# ("Spaces:" and two duplicated "Runtime error" lines) — they were not part
# of the original source and made the file unparseable.
| import os | |
| import tqdm | |
| import torch | |
| from contextlib import nullcontext | |
| from torch.utils.data import DataLoader | |
| from functools import partial | |
| from datasets import load_dataset | |
| from typing import Dict, List | |
| from transformers.file_utils import PaddingStrategy | |
| from transformers import ( | |
| AutoTokenizer, | |
| PreTrainedTokenizerFast, | |
| DataCollatorWithPadding, | |
| HfArgumentParser, | |
| BatchEncoding | |
| ) | |
| from config import Arguments | |
| from logger_config import logger | |
| from utils import move_to_cuda | |
| from models import BiencoderModelForInference, BiencoderOutput | |
# Parse command-line flags into the project-level Arguments dataclass once,
# at import time, so every worker function can read the shared `args`.
parser = HfArgumentParser((Arguments,))
args: Arguments = parser.parse_args_into_dataclasses()[0]
def _psg_transform_func(tokenizer: PreTrainedTokenizerFast,
                        examples: Dict[str, List]) -> BatchEncoding:
    """Tokenize a batch of passages as (title, contents) text pairs.

    Padding is deferred to the collator (DO_NOT_PAD here); sequences are
    truncated to args.p_max_len.
    """
    encoded: BatchEncoding = tokenizer(
        examples['title'],
        text_pair=examples['contents'],
        max_length=args.p_max_len,
        padding=PaddingStrategy.DO_NOT_PAD,
        truncation=True)
    # for co-Condenser reproduction purpose only
    if args.model_name_or_path.startswith('Luyu/'):
        del encoded['token_type_ids']
    return encoded
def _worker_encode_passages(gpu_idx: int):
    """Encode this GPU's contiguous shard of passages and save the embeddings.

    Intended to be launched via torch.multiprocessing.spawn, one process per
    GPU. Each worker loads the full json dataset from args.encode_in_path,
    takes the shard selected by gpu_idx, encodes it batch-by-batch, and writes
    tensors of at most args.encode_shard_size embeddings each to
    '{args.encode_save_dir}/shard_{gpu_idx}_{shard_idx}'.

    :param gpu_idx: index of the CUDA device (and dataset shard) this worker owns.
    """
    def _get_out_path(shard_idx: int = 0) -> str:
        # Output file name is keyed by both the worker and the shard counter.
        return '{}/shard_{}_{}'.format(args.encode_save_dir, gpu_idx, shard_idx)

    def _save_shard(embeds: List[torch.Tensor], shard_idx: int):
        # Concatenate the buffered per-batch embeddings and persist as one tensor.
        out_path = _get_out_path(shard_idx)
        concat_embeds = torch.cat(embeds, dim=0)
        logger.info('GPU {} save {} embeds to {}'.format(gpu_idx, concat_embeds.shape[0], out_path))
        torch.save(concat_embeds, out_path)

    # Ensure the output directory exists; torch.save does not create it.
    os.makedirs(args.encode_save_dir, exist_ok=True)
    if os.path.exists(_get_out_path(0)):
        logger.error('{} already exists, will skip encoding'.format(_get_out_path(0)))
        return

    dataset = load_dataset('json', data_files=args.encode_in_path)['train']
    if args.dry_run:
        # Clamp so dry runs work even on datasets smaller than 4096 rows.
        dataset = dataset.select(range(min(4096, len(dataset))))
    dataset = dataset.shard(num_shards=torch.cuda.device_count(),
                            index=gpu_idx,
                            contiguous=True)
    logger.info('GPU {} needs to process {} examples'.format(gpu_idx, len(dataset)))
    torch.cuda.set_device(gpu_idx)

    tokenizer: PreTrainedTokenizerFast = AutoTokenizer.from_pretrained(args.model_name_or_path)
    model: BiencoderModelForInference = BiencoderModelForInference.build(args)
    model.eval()
    model.cuda()

    dataset.set_transform(partial(_psg_transform_func, tokenizer))
    # Pad to a multiple of 8 under fp16 for tensor-core-friendly shapes.
    data_collator = DataCollatorWithPadding(tokenizer, pad_to_multiple_of=8 if args.fp16 else None)
    data_loader = DataLoader(
        dataset,
        batch_size=args.encode_batch_size,
        shuffle=False,
        drop_last=False,
        num_workers=args.dataloader_num_workers,
        collate_fn=data_collator,
        pin_memory=True)

    num_encoded_docs, encoded_embeds, cur_shard_idx = 0, [], 0
    for batch_dict in tqdm.tqdm(data_loader, desc='passage encoding', mininterval=8):
        batch_dict = move_to_cuda(batch_dict)
        # no_grad: this is pure inference — without it each forward retains an
        # autograd graph referenced by p_reps, leaking memory across the loop.
        with torch.no_grad(), (torch.cuda.amp.autocast() if args.fp16 else nullcontext()):
            outputs: BiencoderOutput = model(query=None, passage=batch_dict)
        encoded_embeds.append(outputs.p_reps.cpu())
        num_encoded_docs += outputs.p_reps.shape[0]

        if num_encoded_docs >= args.encode_shard_size:
            _save_shard(encoded_embeds, cur_shard_idx)
            cur_shard_idx += 1
            num_encoded_docs = 0
            encoded_embeds.clear()

    # Flush any leftover embeddings that did not fill a complete shard.
    if num_encoded_docs > 0:
        _save_shard(encoded_embeds, cur_shard_idx)

    logger.info('Done computing score for worker {}'.format(gpu_idx))
def _batch_encode_passages():
    """Fan passage encoding out across all visible GPUs, one process per device.

    Logs the parsed arguments, bails out early if no GPU is present, then
    spawns one _worker_encode_passages process per CUDA device.
    """
    logger.info('Args={}'.format(str(args)))
    n_gpus = torch.cuda.device_count()
    if n_gpus == 0:
        logger.error('No gpu available')
        return

    logger.info('Use {} gpus'.format(n_gpus))
    torch.multiprocessing.spawn(_worker_encode_passages, args=(), nprocs=n_gpus)
    logger.info('Done batch encode passages')
# Script entry point: run the multi-GPU passage encoder.
if __name__ == '__main__':
    _batch_encode_passages()