import os

# Environment configuration MUST happen before importing mteb / huggingface
# libraries below: HF_ENDPOINT redirects Hugging Face Hub downloads to a
# mirror, and OPENBLAS_NUM_THREADS caps BLAS threading used by CPU-side
# metric computation. Both are read at import/first-use time downstream.
os.environ["HF_ENDPOINT"] = "https://hf-mirror.com"

os.environ["OPENBLAS_NUM_THREADS"] = "32"

import mteb

import torch

import numpy as np

from mteb.encoder_interface import PromptType

from sentence_transformers import SentenceTransformer
|
|
|
# Tasks of the MTEB(eng, v2) benchmark, grouped by task type. The type
# decides which instruction prompt DeweyWrapper.encode() prepends.
_TASKS_BY_TYPE = {
    'Retrieval': (
        'ArguAna', 'CQADupstackGamingRetrieval', 'CQADupstackUnixRetrieval',
        'ClimateFEVERHardNegatives', 'FEVERHardNegatives', 'FiQA2018',
        'HotpotQAHardNegatives', 'SCIDOCS', 'TRECCOVID',
        'Touche2020Retrieval.v3',
    ),
    'Clustering': (
        'ArXivHierarchicalClusteringP2P', 'ArXivHierarchicalClusteringS2S',
        'BiorxivClusteringP2P.v2', 'MedrxivClusteringP2P.v2',
        'MedrxivClusteringS2S.v2', 'StackExchangeClustering.v2',
        'StackExchangeClusteringP2P.v2', 'TwentyNewsgroupsClustering.v2',
    ),
    'Reranking': ('AskUbuntuDupQuestions', 'MindSmallReranking'),
    'STS': (
        'BIOSSES', 'SICK-R', 'STS12', 'STS13', 'STS14', 'STS15',
        'STSBenchmark', 'STS17', 'STS22.v2',
    ),
    'Classification': (
        'Banking77Classification', 'ImdbClassification',
        'MTOPDomainClassification', 'MassiveIntentClassification',
        'MassiveScenarioClassification', 'ToxicConversationsClassification',
        'TweetSentimentExtractionClassification',
        'AmazonCounterfactualClassification',
    ),
    'PairClassification': (
        'SprintDuplicateQuestions', 'TwitterSemEval2015', 'TwitterURLCorpus',
    ),
    'Summarization': ('SummEvalSummarization.v2',),
}

# Flat task-name -> task-type lookup used by DeweyWrapper.encode().
TASK_NAME2TYPE = {
    task: task_type
    for task_type, tasks in _TASKS_BY_TYPE.items()
    for task in tasks
}
|
|
|
# Instruction prompts prepended to inputs before encoding. Retrieval tasks
# are asymmetric (different prompts for queries vs. candidate passages);
# symmetric tasks (STS, reranking, pair classification, summarization)
# share one "semantically similar" prompt.
RETRIEVE_Q_PROMPT = "<|START_INSTRUCTION|>Answer the question<|END_INSTRUCTION|>"

RETRIEVE_P_PROMPT = "<|START_INSTRUCTION|>Candidate document<|END_INSTRUCTION|>"

STS_PROMPT = "<|START_INSTRUCTION|>Generate semantically similar text<|END_INSTRUCTION|>"
|
|
|
def _instruction(text: str) -> str:
    """Wrap a plain-language instruction in the model's instruction tokens."""
    return f"<|START_INSTRUCTION|>{text}<|END_INSTRUCTION|>"


# Task-specific instruction prompts for Classification and Clustering tasks;
# all other task types use the shared prompts defined above.
TASK_NAME2PROMPT = {
    # Classification tasks.
    "Banking77Classification": _instruction("Classify text into intents"),
    "ImdbClassification": _instruction("Classify text into sentiment"),
    "MTOPDomainClassification": _instruction("Classify text into intent domain"),
    "MassiveIntentClassification": _instruction("Classify text into user intents"),
    "MassiveScenarioClassification": _instruction("Classify text into user scenarios"),
    "ToxicConversationsClassification": _instruction("Classify text into toxic or not toxic"),
    "TweetSentimentExtractionClassification": _instruction("Classify text into positive, negative, or neutral sentiment"),
    "AmazonCounterfactualClassification": _instruction("Classify text into counterfactual or not-counterfactual"),
    # Clustering tasks.
    "ArXivHierarchicalClusteringP2P": _instruction("Output main and secondary category of Arxiv papers based on the titles and abstracts"),
    "ArXivHierarchicalClusteringS2S": _instruction("Output main and secondary category of Arxiv papers based on the titles"),
    "BiorxivClusteringP2P.v2": _instruction("Output main category of Biorxiv papers based on the titles and abstracts"),
    "MedrxivClusteringP2P.v2": _instruction("Output main category of Medrxiv papers based on the titles and abstracts"),
    "MedrxivClusteringS2S.v2": _instruction("Output main category of Medrxiv papers based on the titles"),
    "StackExchangeClustering.v2": _instruction("Output topic or theme of StackExchange posts based on the titles"),
    "StackExchangeClusteringP2P.v2": _instruction("Output topic or theme of StackExchange posts based on the given paragraphs"),
    "TwentyNewsgroupsClustering.v2": _instruction("Output topic or theme of news articles"),
}
|
|
|
|
|
class DeweyWrapper:
    """MTEB-compatible encoder wrapping the dewey_en_beta SentenceTransformer.

    Loads the model in bfloat16 on CUDA, starts a multi-process encoding
    pool once at construction, and selects the task-appropriate instruction
    prompt for every ``encode`` call.
    """

    def __init__(self, model_dir, max_seq_length: int = 1536, batch_size: int = 8):
        """Load the model and start the multi-process encoding pool.

        Args:
            model_dir: Local path or Hub id of the model.
            max_seq_length: Token truncation length applied to all inputs.
            batch_size: Per-process encoding batch size.
        """
        # NOTE(review): flash_attention_2 needs a compatible GPU and the
        # flash-attn package installed — confirm on the eval machine.
        self.model = SentenceTransformer(
            model_dir,
            trust_remote_code=True,
            model_kwargs={
                "torch_dtype": torch.bfloat16,
                "attn_implementation": "flash_attention_2",
            },
            config_kwargs={"single_vector_type": "cls_add_mean"},
        ).cuda().bfloat16().eval()
        self.model.max_seq_length = max_seq_length
        # The pool is started once and reused for every encode() call; it is
        # never torn down here, so reuse a single wrapper per process.
        self.pool = self.model.start_multi_process_pool()
        self.batch_size = batch_size

    def encode(
        self,
        sentences: list[str],
        task_name: str,
        prompt_type: PromptType | None = None,
        **kwargs,
    ) -> np.ndarray:
        """Encode ``sentences`` with the instruction prompt for ``task_name``.

        Args:
            sentences: Texts to embed.
            task_name: MTEB task name; must be a key of TASK_NAME2TYPE (and,
                for Classification/Clustering tasks, of TASK_NAME2PROMPT).
            prompt_type: Query/passage role for asymmetric (Retrieval) tasks.
            **kwargs: Ignored; accepted for MTEB interface compatibility.

        Returns:
            float32 array of L2-normalized embeddings, one row per sentence.

        Raises:
            KeyError: If ``task_name`` is unknown.
        """
        task_type = TASK_NAME2TYPE[task_name]
        if task_type == "Retrieval":
            # Asymmetric task: queries and passages get different prompts.
            # Bug fix: the original dereferenced prompt_type.value and raised
            # AttributeError when prompt_type was None despite the declared
            # default; fall back to the passage prompt in that case.
            if prompt_type is not None and prompt_type.value == "query":
                prompt = RETRIEVE_Q_PROMPT
            else:
                prompt = RETRIEVE_P_PROMPT
        elif task_type in ("STS", "PairClassification", "Summarization", "Reranking"):
            # Symmetric tasks share one prompt.
            prompt = STS_PROMPT
        else:
            # Classification / Clustering use a task-specific prompt.
            prompt = TASK_NAME2PROMPT[task_name]
        return self.model.encode_multi_process(
            sentences=sentences,
            pool=self.pool,
            show_progress_bar=True,
            batch_size=self.batch_size,
            normalize_embeddings=True,
            prompt=prompt,
            precision="float32",
        )
|
|
|
|
|
if __name__ == "__main__":
    max_seq_length = 1536
    batch_size = 8  # bug fix: was misspelled "batch_szie"
    model_dir_or_name = "infgrad/dewey_en_beta"
    # Per-model results folder; with overwrite_results=False, tasks that
    # already have result files here are skipped on re-runs.
    output_folder = "./mteb_eng_results/dewey_en_beta"
    model = DeweyWrapper(model_dir_or_name, max_seq_length=max_seq_length, batch_size=batch_size)

    # Run the full MTEB(eng, v2) benchmark suite against the wrapper.
    tasks = list(mteb.get_benchmark("MTEB(eng, v2)"))
    evaluation = mteb.MTEB(tasks=tasks)
    evaluation.run(model, output_folder=output_folder, verbosity=2, overwrite_results=False)
|
|
|