'''This evaluation code is borrowed from E5.'''
| import os |
| import json |
| import tqdm |
| import numpy as np |
| import torch |
| import argparse |
|
|
| from datasets import Dataset |
| from typing import List, Dict |
| from functools import partial |
| from transformers import AutoModel, AutoTokenizer, PreTrainedTokenizerFast, BatchEncoding, DataCollatorWithPadding |
| from transformers.modeling_outputs import BaseModelOutput |
| from torch.utils.data import DataLoader |
| from mteb import MTEB, AbsTaskRetrieval, DRESModel |
|
|
| from utils import pool, logger, move_to_cuda |
|
|
| parser = argparse.ArgumentParser(description='evaluation for BEIR benchmark') |
| parser.add_argument('--model-name-or-path', default='bert-base-uncased', |
| type=str, metavar='N', help='which model to use') |
| parser.add_argument('--output-dir', default='tmp-outputs/', |
| type=str, metavar='N', help='output directory') |
| parser.add_argument('--pool-type', default='avg', help='pool type') |
| parser.add_argument('--max-length', default=512, help='max length') |
|
|
| args = parser.parse_args() |
| logger.info('Args: {}'.format(json.dumps(args.__dict__, ensure_ascii=False, indent=4))) |
| assert args.pool_type in ['cls', 'avg'], 'pool_type should be cls or avg' |
| assert args.output_dir, 'output_dir should be set' |
| os.makedirs(args.output_dir, exist_ok=True) |
|
|
|
|
| def _transform_func(tokenizer: PreTrainedTokenizerFast, |
| examples: Dict[str, List]) -> BatchEncoding: |
| return tokenizer(examples['contents'], |
| max_length=int(args.max_length), |
| padding=True, |
| return_token_type_ids=False, |
| truncation=True) |
|
|
|
|
| class RetrievalModel(DRESModel): |
| |
| def __init__(self, **kwargs): |
| self.encoder = AutoModel.from_pretrained(args.model_name_or_path) |
| self.tokenizer = AutoTokenizer.from_pretrained(args.model_name_or_path) |
| self.gpu_count = torch.cuda.device_count() |
| if self.gpu_count > 1: |
| self.encoder = torch.nn.DataParallel(self.encoder) |
|
|
| self.encoder.cuda() |
| self.encoder.eval() |
|
|
| def encode_queries(self, queries: List[str], **kwargs) -> np.ndarray: |
| input_texts = ['查询: {}'.format(q) for q in queries] |
| return self._do_encode(input_texts) |
|
|
| def encode_corpus(self, corpus: List[Dict[str, str]], **kwargs) -> np.ndarray: |
| input_texts = ['{} {}'.format(doc.get('title', ''), doc['text']).strip() for doc in corpus] |
| input_texts = ['结果: {}'.format(t) for t in input_texts] |
| return self._do_encode(input_texts) |
|
|
| @torch.no_grad() |
| def _do_encode(self, input_texts: List[str]) -> np.ndarray: |
| dataset: Dataset = Dataset.from_dict({'contents': input_texts}) |
| dataset.set_transform(partial(_transform_func, self.tokenizer)) |
|
|
| data_collator = DataCollatorWithPadding(self.tokenizer, pad_to_multiple_of=8) |
| batch_size = 128 * self.gpu_count |
| data_loader = DataLoader( |
| dataset, |
| batch_size=batch_size, |
| shuffle=False, |
| drop_last=False, |
| num_workers=4, |
| collate_fn=data_collator, |
| pin_memory=True) |
|
|
| encoded_embeds = [] |
| for batch_dict in tqdm.tqdm(data_loader, desc='encoding', mininterval=10): |
| batch_dict = move_to_cuda(batch_dict) |
|
|
| with torch.cuda.amp.autocast(): |
| outputs: BaseModelOutput = self.encoder(**batch_dict) |
| embeds = pool(outputs.last_hidden_state, batch_dict['attention_mask'], args.pool_type) |
| encoded_embeds.append(embeds.cpu().numpy()) |
|
|
| return np.concatenate(encoded_embeds, axis=0) |
|
|
| TASKS = ["T2Retrieval", "MMarcoRetrieval", "DuRetrieval", "CovidRetrieval", "CmedqaRetrieval", "EcomRetrieval", "MedicalRetrieval", "VideoRetrieval"] |
| def main(): |
| assert AbsTaskRetrieval.is_dres_compatible(RetrievalModel) |
| model = RetrievalModel() |
|
|
| task_names = [t.description["name"] for t in MTEB(tasks=TASKS).tasks] |
| logger.info('Tasks: {}'.format(task_names)) |
|
|
| for task in task_names: |
| logger.info('Processing task: {}'.format(task)) |
| evaluation = MTEB(tasks=[task]) |
| evaluation.run(model, output_folder=args.output_dir, overwrite_results=False) |
|
|
|
|
| if __name__ == '__main__': |
| main() |
|
|