import os
from typing import Dict, List

import datasets
import torch
from beir import util
from beir.datasets.data_loader import GenericDataLoader as BeirDataLoader
from beir.retrieval.evaluation import EvaluateRetrieval
from llm2vec import LLM2Vec

dataset = "arguana"
instruction = "Given a claim, find documents that refute the claim: "

print("Loading dataset...")
url = (
    f"https://public.ukp.informatik.tu-darmstadt.de/thakur/BEIR/datasets/{dataset}.zip"
)
download_path = os.path.join(datasets.config.HF_DATASETS_CACHE, "BeIR")
data_path = util.download_and_unzip(url, download_path)
corpus, queries, relevant_docs = BeirDataLoader(data_folder=data_path).load(
    split="test"
)

batch_size = 8

print("Loading model...")
model = LLM2Vec.from_pretrained(
    "McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp",
    peft_model_name_or_path="McGill-NLP/LLM2Vec-Mistral-7B-Instruct-v2-mntp-supervised",
    device_map="cuda" if torch.cuda.is_available() else "cpu",
    torch_dtype=torch.bfloat16,
)


def append_instruction(instruction, sentences):
    """Pair every sentence with an instruction, the input format used for LLM2Vec.encode."""
    new_sentences = []
    for s in sentences:
        new_sentences.append([instruction, s, 0])
    return new_sentences


def cos_sim(a: torch.Tensor, b: torch.Tensor):
    """Cosine similarity between every row of `a` and every row of `b`."""
    if not isinstance(a, torch.Tensor):
        a = torch.tensor(a)
    if not isinstance(b, torch.Tensor):
        b = torch.tensor(b)
    if len(a.shape) == 1:
        a = a.unsqueeze(0)
    if len(b.shape) == 1:
        b = b.unsqueeze(0)
    a_norm = torch.nn.functional.normalize(a, p=2, dim=1)
    b_norm = torch.nn.functional.normalize(b, p=2, dim=1)
    return torch.mm(a_norm, b_norm.transpose(0, 1))


def encode_queries(queries: List[str], batch_size: int, **kwargs):
    # Queries carry the task instruction; corpus documents (below) do not.
    new_sentences = append_instruction(instruction, queries)
    return model.encode(new_sentences, batch_size=batch_size, **kwargs)


def encode_corpus(corpus: List[Dict[str, str]], batch_size: int, **kwargs):
    # The corpus may arrive as a list of {"title", "text"} dicts or as a
    # single columnar dict; both are flattened to "title text" strings.
    if type(corpus) is dict:
        sentences = [
            (
                (corpus["title"][i] + " " + corpus["text"][i]).strip()
                if "title" in corpus
                else corpus["text"][i].strip()
            )
            for i in range(len(corpus["text"]))
        ]
    else:
        sentences = [
            (
                (doc["title"] + " " + doc["text"]).strip()
                if "title" in doc
                else doc["text"].strip()
            )
            for doc in corpus
        ]
    new_sentences = append_instruction("", sentences)
    return model.encode(new_sentences, batch_size=batch_size, **kwargs)


print("Encoding Queries...")
query_ids = list(queries.keys())
results = {qid: {} for qid in query_ids}
queries = [queries[qid] for qid in query_ids]
query_embeddings = encode_queries(
    queries, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True
)

# Longest documents first: similar lengths batch together (less padding) and
# any out-of-memory failure surfaces early, mirroring BEIR's exact search.
print("Sorting Corpus by document length (Longest first)...")
corpus_ids = sorted(
    corpus,
    key=lambda k: len(corpus[k].get("title", "") + corpus[k].get("text", "")),
    reverse=True,
)
corpus = [corpus[cid] for cid in corpus_ids]

print("Encoding Corpus ... Warning: This might take a while!")
corpus_embeddings = encode_corpus(
    corpus, batch_size=batch_size, show_progress_bar=True, convert_to_tensor=True
)

print("Scoring Function: {} ({})".format("Cosine Similarity", "cos_sim"))
cos_scores = cos_sim(query_embeddings, corpus_embeddings)
cos_scores[torch.isnan(cos_scores)] = -1

# Keep the top-k hits per query. top_k + 1 leaves room to drop a hit whose
# id equals the query id (in ArguAna, queries also appear in the corpus).
top_k = 1000
cos_scores_top_k_values, cos_scores_top_k_idx = torch.topk(
    cos_scores, min(top_k + 1, len(cos_scores[0])), dim=1, largest=True, sorted=False
)
cos_scores_top_k_values = cos_scores_top_k_values.cpu().tolist()
cos_scores_top_k_idx = cos_scores_top_k_idx.cpu().tolist()

for query_itr in range(len(query_embeddings)):
    query_id = query_ids[query_itr]
    for sub_corpus_id, score in zip(
        cos_scores_top_k_idx[query_itr], cos_scores_top_k_values[query_itr]
    ):
        corpus_id = corpus_ids[sub_corpus_id]
        if corpus_id != query_id:
            results[query_id][corpus_id] = score

# EvaluateRetrieval scores the results dict against the qrels; the default
# k_values are (1, 3, 5, 10, 100, 1000).
retriever = EvaluateRetrieval(model, score_function="cos_sim")
ndcg, _map, recall, precision = retriever.evaluate(
    relevant_docs, results, retriever.k_values
)
mrr = retriever.evaluate_custom(relevant_docs, results, retriever.k_values, "mrr")

# Rename "NDCG@10"-style keys to "ndcg_at_10"-style keys.
scores = {
    **{f"ndcg_at_{k.split('@')[1]}": v for (k, v) in ndcg.items()},
    **{f"map_at_{k.split('@')[1]}": v for (k, v) in _map.items()},
    **{f"recall_at_{k.split('@')[1]}": v for (k, v) in recall.items()},
    **{f"precision_at_{k.split('@')[1]}": v for (k, v) in precision.items()},
    **{f"mrr_at_{k.split('@')[1]}": v for (k, v) in mrr.items()},
}
print(scores)

"""
Expected output:
{'ndcg_at_1': 0.32788, 'ndcg_at_3': 0.47534, 'ndcg_at_5': 0.52296,
 'ndcg_at_10': 0.57505, 'ndcg_at_100': 0.6076, 'ndcg_at_1000': 0.60801,
 'map_at_1': 0.32788, 'map_at_3': 0.43883, 'map_at_5': 0.46518,
 'map_at_10': 0.48675, 'map_at_100': 0.49506, 'map_at_1000': 0.49509,
 'recall_at_1': 0.32788, 'recall_at_3': 0.58108, 'recall_at_5': 0.69701,
 'recall_at_10': 0.85775, 'recall_at_100': 0.9936, 'recall_at_1000': 0.99644,
 'precision_at_1': 0.32788, 'precision_at_3': 0.19369, 'precision_at_5': 0.1394,
 'precision_at_10': 0.08578, 'precision_at_100': 0.00994, 'precision_at_1000': 0.001,
 'mrr_at_1': 0.33357, 'mrr_at_3': 0.44085, 'mrr_at_5': 0.46745,
 'mrr_at_10': 0.4888, 'mrr_at_100': 0.49718, 'mrr_at_1000': 0.49721}
"""
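
# A rough way to run this script (assumptions: `pip install llm2vec beir datasets`
# and a GPU with enough memory for Mistral-7B in bfloat16; the file name below is
# illustrative, not from the original):
#
#   python evaluate_llm2vec_beir.py
#
# To evaluate a different BEIR dataset, change `dataset` (e.g. "scifact") and
# supply an `instruction` that matches that dataset's task.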