| import torch | |
| from typing import List, Union | |
# Process-wide cache of already-loaded embedding models, shared by the
# functions in this module so each model is instantiated at most once.
# Keys are model-name strings (the PEFT model name for LLM2Vec, the model
# name for GritLM); values are the loaded model objects.
loaded_llm_models = {}
def get_llm2vec_embeddings(text: Union[str, List[str]],
                           model_name: str = 'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp',
                           peft_model_name: str = 'McGill-NLP/LLM2Vec-Meta-Llama-3-8B-Instruct-mntp-unsup-simcse',
                           instruction: str = '',
                           device: str = 'cuda',
                           norm: bool = True) -> torch.Tensor:
    """
    Get LLM2Vec embeddings for the given text.

    The loaded model is cached in the module-level ``loaded_llm_models``
    dict under ``peft_model_name``, so repeated calls reuse the same model.

    Args:
        text (Union[str, List[str]]): The input text to be embedded. A single
            string is treated as a one-element list.
        model_name (str): The base model passed to ``LLM2Vec.from_pretrained``.
        peft_model_name (str): The PEFT model passed as
            ``peft_model_name_or_path``; also used as the cache key.
        instruction (str): Optional instruction. When non-empty, every input
            is encoded as the pair ``[instruction, text]``.
        device (str): Passed as ``device_map`` when the model is first
            loaded. It has no effect on a cache hit.
        norm (bool): If True, L2-normalize each embedding along dim 1.

    Returns:
        torch.Tensor: The embedding(s) of the input text(s), reshaped to
        ``(len(text), -1)``.

    Raises:
        ImportError: If the ``llm2vec`` package is not installed.
    """
    try:
        from llm2vec import LLM2Vec
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("Please install the llm2vec package using `pip install llm2vec`.") from err

    # NOTE(review): the cache key is the PEFT model name only, so a later call
    # with a different `model_name` or `device` but the same `peft_model_name`
    # gets the previously loaded model back.
    l2v = loaded_llm_models.get(peft_model_name)
    if l2v is None:
        l2v = LLM2Vec.from_pretrained(
            model_name,
            peft_model_name_or_path=peft_model_name,
            device_map=device,
            torch_dtype=torch.bfloat16,
        )
        loaded_llm_models[peft_model_name] = l2v

    if isinstance(text, str):
        text = [text]
    if instruction:
        # Pair each input with the instruction before encoding.
        text = [[instruction, t] for t in text]

    # The whole input is encoded as a single batch.
    embeddings = l2v.encode(text, batch_size=len(text))
    if norm:
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=1)
    return embeddings.view(len(text), -1)
def get_gritlm_embeddings(text: Union[str, List[str]],
                          model_name: str = 'GritLM/GritLM-7B',
                          instruction: str = '',
                          device: str = 'cuda'
                          ) -> torch.Tensor:
    """
    Get GritLM embeddings for the given text.

    The loaded model is cached in the module-level ``loaded_llm_models``
    dict under ``model_name``, so repeated calls reuse the same model.

    Args:
        text (Union[str, List[str]]): The input text to be embedded. A single
            string is treated as a one-element list.
        model_name (str): The model to use for embedding.
        instruction (str): The instruction to be used for GritLM. It is
            wrapped in the ``<|user|>`` / ``<|embed|>`` prompt format; an
            empty instruction yields the bare ``<|embed|>`` prefix.
        device (str): Accepted for signature parity with the other embedding
            helpers. This function does not currently use it.

    Returns:
        torch.Tensor: The embedding(s) of the input text(s), reshaped to
        ``(len(text), -1)``.

    Raises:
        ImportError: If the ``gritlm`` package is not installed.
    """
    try:
        from gritlm import GritLM
    except ImportError as err:
        # Chain the original error so the real import failure stays visible.
        raise ImportError("Please install the gritlm package using `pip install gritlm`.") from err

    def gritlm_instruction(instruction):
        # Build the embedding prompt prefix expected by the encode call.
        return "<|user|>\n" + instruction + "\n<|embed|>\n" if instruction else "<|embed|>\n"

    gritlm_model = loaded_llm_models.get(model_name)
    if gritlm_model is None:
        gritlm_model = GritLM(model_name, torch_dtype=torch.bfloat16)
        loaded_llm_models[model_name] = gritlm_model

    if isinstance(text, str):
        text = [text]

    embeddings = gritlm_model.encode(text, instruction=gritlm_instruction(instruction))
    # The encode result is converted from a NumPy array to a tensor.
    embeddings = torch.from_numpy(embeddings)
    return embeddings.view(len(text), -1)