Spaces:
Build error
Build error
| import json | |
| import logging | |
| import os | |
| from typing import Dict, List, Optional, Union | |
| import numpy as np | |
| import torch | |
| import torch.multiprocessing as mp | |
| from peft import PeftModel | |
| from torch import Tensor, device, nn | |
| from tqdm.autonotebook import tqdm, trange | |
| from transformers import ( | |
| AutoModel, | |
| AutoConfig, | |
| PretrainedConfig, | |
| AutoTokenizer, | |
| LlamaConfig, | |
| MistralConfig, | |
| GemmaConfig, | |
| Qwen2Config, | |
| ) | |
| logger = logging.getLogger(__name__) | |
def batch_to_device(batch, target_device: device):
    """
    Move every tensor value of a batch dict onto `target_device`, in place.

    Non-tensor values are left untouched; the same (mutated) dict is returned
    for convenience.
    """
    for key, value in batch.items():
        if isinstance(value, Tensor):
            batch[key] = value.to(target_device)
    return batch
| class LLMEncoder(nn.Module): | |
| def __init__( | |
| self, | |
| model: AutoModel, | |
| tokenizer: AutoTokenizer, | |
| pooling_mode: str = "weighted_mean", | |
| max_length: int = 512, | |
| doc_max_length: int = 400, | |
| skip_instruction: bool = True, | |
| ): | |
| super().__init__() | |
| self.model = model | |
| self.tokenizer = tokenizer | |
| self.pooling_mode = pooling_mode | |
| self.skip_instruction = skip_instruction | |
| self.max_length = max_length | |
| self.doc_max_length = doc_max_length | |
| self.config = model.config | |
| def from_pretrained( | |
| self, | |
| base_model_name_or_path, | |
| peft_model_name_or_path=None, | |
| cache_dir=None, | |
| **kwargs, | |
| ): | |
| """ | |
| Load a pretrained model from a model identifier or path. | |
| Args: | |
| base_model_name_or_path: Model identifier or path to pretrained model. | |
| peft_model_name_or_path: Path to any PEFT models to apply. | |
| Returns: L3Prune model. | |
| """ | |
| # pop out encoder args | |
| keys = ["pooling_mode", "max_length", "doc_max_length", "skip_instruction"] | |
| encoder_args = { | |
| key: kwargs.pop(key, None) for key in keys if kwargs.get(key) is not None | |
| } | |
| tokenizer = AutoTokenizer.from_pretrained(base_model_name_or_path, cache_dir=cache_dir) | |
| tokenizer.pad_token = tokenizer.eos_token | |
| tokenizer.padding_side = "left" | |
| config = AutoConfig.from_pretrained(base_model_name_or_path) | |
| model = AutoModel.from_pretrained(base_model_name_or_path, cache_dir=cache_dir, **kwargs) | |
| if os.path.isdir(base_model_name_or_path) and os.path.exists( | |
| f"{base_model_name_or_path}/config.json" | |
| ): | |
| with open(f"{base_model_name_or_path}/config.json", "r") as fIn: | |
| config_dict = json.load(fIn) | |
| config = PretrainedConfig.from_dict(config_dict) | |
| model.config._name_or_path = config._name_or_path | |
| if peft_model_name_or_path is not None: | |
| model = PeftModel.from_pretrained( | |
| model, | |
| peft_model_name_or_path, | |
| ) | |
| model = model.merge_and_unload() | |
| config = {} | |
| if os.path.exists(f"{base_model_name_or_path}/l3prune_config.json"): | |
| with open(f"{base_model_name_or_path}/l3prune_config.json", "r") as fIn: | |
| l3prune_config = json.load(fIn) | |
| config.update(l3prune_config) | |
| for key, value in encoder_args.items(): | |
| config[key] = value | |
| return self(model=model, tokenizer=tokenizer, **config) | |
| def prune(self, percent_prune=0): | |
| """ | |
| Prune a model to a percentage of layers of the base model. If percent_prune is equal to or greater than 1, | |
| it is taken as the specific layer number to prune to. For example, if percent_prune=0.3, 30% of the layers will be pruned. If | |
| percent_prune=3, the model will be pruned to 3 layers. | |
| """ | |
| # take it as the specific layer number to prune to | |
| if percent_prune >= 1: | |
| new_num_layers = int(percent_prune) | |
| else: | |
| new_num_layers = int(self.model.config.num_hidden_layers * (1 - percent_prune)) | |
| print(f"Pruning to {new_num_layers} layer.") | |
| self.model.layers = self.model.layers[:new_num_layers] | |
| self.model.config.num_hidden_layers = new_num_layers | |
| def prepare_for_tokenization(self, text): | |
| if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B-Instruct": | |
| text = ( | |
| "<|start_header_id|>user<|end_header_id|>\n\n" | |
| + text.strip() | |
| + "<|eot_id|>" | |
| ) | |
| return text | |
| if self.model.config._name_or_path in [ | |
| "mistralai/Mistral-7B-Instruct-v0.2", | |
| "meta-llama/Llama-2-7b-chat-hf", | |
| ]: | |
| text = "[INST] " + text.strip() + " [/INST]" | |
| if self.model.config._name_or_path in [ | |
| "google/gemma-2-9b-it", | |
| ]: | |
| text = "<bos><start_of_turn>user\n" + text.strip() + "<end_of_turn>" | |
| if self.model.config._name_or_path in [ | |
| "Qwen/Qwen2-1.5B-Instruct", | |
| "Qwen/Qwen2-7B-Instruct", | |
| ]: | |
| text = "<|im_start|>user\n" + text.strip() + "<|im_end|>" | |
| if self.pooling_mode == "eos_token": | |
| if self.model.config._name_or_path == "meta-llama/Meta-Llama-3-8B": | |
| text = text.strip() + "<|end_of_text|>" | |
| elif isinstance(self.model.config, LlamaConfig) or isinstance( | |
| self.model.config, MistralConfig | |
| ): | |
| text = text.strip() + " </s>" | |
| elif isinstance(self.model.config, GemmaConfig): | |
| text = text.strip() + "<eos>" | |
| elif isinstance(self.model.config, Qwen2Config): | |
| text = text.strip() + "<|endoftext|>" | |
| return text | |
| def tokenize(self, texts): | |
| texts_2 = [] | |
| original_texts = [] | |
| for text in texts: | |
| t = text.split("!@#$%^&*()") | |
| texts_2.append(t[1] if len(t) > 1 else "") | |
| original_texts.append("".join(t)) | |
| original = self.tokenizer( | |
| original_texts, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=self.max_length, | |
| ) | |
| embed_mask = None | |
| for t_i, t in enumerate(texts_2): | |
| ids = self.tokenizer( | |
| [t], | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=self.max_length, | |
| add_special_tokens=False, | |
| ) | |
| if embed_mask is None: | |
| e_m = torch.zeros_like(original["attention_mask"][t_i]) | |
| if len(ids["input_ids"][0]) > 0: | |
| e_m[-len(ids["input_ids"][0]) :] = torch.ones( | |
| len(ids["input_ids"][0]) | |
| ) | |
| embed_mask = e_m.unsqueeze(0) | |
| else: | |
| e_m = torch.zeros_like(original["attention_mask"][t_i]) | |
| if len(ids["input_ids"][0]) > 0: | |
| e_m[-len(ids["input_ids"][0]) :] = torch.ones( | |
| len(ids["input_ids"][0]) | |
| ) | |
| embed_mask = torch.cat((embed_mask, e_m.unsqueeze(0)), dim=0) | |
| original["embed_mask"] = embed_mask | |
| return original | |
| def _skip_instruction(self, sentence_feature): | |
| assert ( | |
| sentence_feature["attention_mask"].shape | |
| == sentence_feature["embed_mask"].shape | |
| ) | |
| sentence_feature["attention_mask"] = sentence_feature["embed_mask"] | |
| def forward(self, sentence_feature: Dict[str, Tensor]): | |
| embed_mask = None | |
| if "embed_mask" in sentence_feature: | |
| embed_mask = sentence_feature.pop("embed_mask") | |
| reps = self.model(**sentence_feature) | |
| sentence_feature["embed_mask"] = embed_mask | |
| return self.get_pooling(sentence_feature, reps.last_hidden_state) | |
| def get_pooling(self, features, last_hidden_states): # All models padded from left | |
| assert ( | |
| self.tokenizer.padding_side == "left" | |
| ), "Pooling modes are implemented for padding from left." | |
| if self.skip_instruction: | |
| self._skip_instruction(features) | |
| seq_lengths = features["attention_mask"].sum(dim=-1) | |
| if self.pooling_mode == "mean": | |
| return torch.stack( | |
| [ | |
| last_hidden_states[i, -length:, :].mean(dim=0) | |
| for i, length in enumerate(seq_lengths) | |
| ], | |
| dim=0, | |
| ) | |
| elif self.pooling_mode == "weighted_mean": | |
| bs, l, _ = last_hidden_states.shape | |
| complete_weights = torch.zeros(bs, l, device=last_hidden_states.device) | |
| for i, seq_l in enumerate(seq_lengths): | |
| if seq_l > 0: | |
| complete_weights[i, -seq_l:] = torch.arange(seq_l) + 1 | |
| complete_weights[i] /= torch.clamp( | |
| complete_weights[i].sum(), min=1e-9 | |
| ) | |
| return torch.sum(last_hidden_states * complete_weights.unsqueeze(-1), dim=1) | |
| elif self.pooling_mode == "eos_token" or self.pooling_mode == "last_token": | |
| return last_hidden_states[:, -1] | |
| elif self.pooling_mode == "bos_token": | |
| return last_hidden_states[ | |
| features["input_ids"] == self.tokenizer.bos_token_id | |
| ] | |
| else: | |
| raise ValueError(f"{self.pooling_mode} is not implemented yet.") | |
| def _convert_to_str(self, instruction, text): | |
| tokenized_q = self.tokenizer( | |
| text, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=self.max_length, | |
| add_special_tokens=False, | |
| ) | |
| tokenized_q_length = len(tokenized_q["input_ids"][0]) | |
| while tokenized_q_length > self.doc_max_length: | |
| reduction_ratio = self.doc_max_length / tokenized_q_length | |
| reduced_length = int(len(text.split()) * reduction_ratio) | |
| text = " ".join(text.split()[:reduced_length]) | |
| tokenized_q = self.tokenizer( | |
| text, | |
| return_tensors="pt", | |
| padding=True, | |
| truncation=True, | |
| max_length=self.max_length, | |
| add_special_tokens=False, | |
| ) | |
| tokenized_q_length = len(tokenized_q["input_ids"][0]) | |
| return ( | |
| f"{instruction.strip()} !@#$%^&*(){text}" | |
| if instruction | |
| else f"!@#$%^&*(){text}" | |
| ) | |
    def encode(
        self,
        sentences: Union[str, List[str]],
        batch_size: int = 32,
        show_progress_bar: bool = True,
        convert_to_numpy: bool = False,
        convert_to_tensor: bool = False,
        device: Optional[str] = None,
    ):
        """
        Encode a list of sentences to their respective embeddings. The sentences can be a list of strings or a string.
        Args:
            sentences: sentence or sentences to encode.
            batch_size: batch size for turning sentence tokens into embeddings.
            show_progress_bar: whether to show progress bars during encoding steps.
            convert_to_numpy: If true, return numpy arrays instead of torch tensors.
            convert_to_tensor: If true, return torch tensors (default).
            device: torch backend device identifier (e.g., 'cuda', 'cpu','mps' etc.). If not specified,
                the default is to use cuda when available, otherwise cpu. Note that only the choice of 'cuda' supports
                multiprocessing as currently implemented.
        Returns: embeddings of the sentences. Embeddings are detached and always on the CPU (see _encode implementation).
        """
        # NOTE(review): presumably a single (text, id)-style pair — wrapped so
        # the loop below sees a list of pairs; confirm against the MEDI caller.
        if isinstance(sentences[0], str) and isinstance(sentences[-1], int):
            sentences = [sentences]
        # required for MEDI version of MTEB
        if isinstance(sentences[0], str):
            # Plain strings get an empty instruction: [instruction, text].
            sentences = [[""] + [sentence] for sentence in sentences]
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        # Fold each [instruction, text] pair into one marker-separated string.
        concatenated_input_texts = []
        for sentence in sentences:
            assert isinstance(sentence[0], str)
            assert isinstance(sentence[1], str)
            concatenated_input_texts.append(
                self._convert_to_str(sentence[0], sentence[1])
            )
        sentences = concatenated_input_texts
        self.eval()
        if convert_to_tensor:
            convert_to_numpy = False
        # Sort longest-first so each batch has similar lengths (less padding);
        # the original order is restored after encoding.
        length_sorted_idx = np.argsort([-self._text_length(sen) for sen in sentences])
        sentences_sorted = [sentences[idx] for idx in length_sorted_idx]
        all_embeddings = []
        if torch.cuda.device_count() <= 1:
            # This branch also support mps devices
            self.to(device)
            for start_index in trange(
                0,
                len(sentences),
                batch_size,
                desc="Batches",
                disable=not show_progress_bar,
            ):
                sentences_batch = sentences_sorted[
                    start_index : start_index + batch_size
                ]
                embeddings = self._encode(
                    sentences_batch, device=device, convert_to_numpy=convert_to_numpy
                )
                all_embeddings.append(embeddings)
        else:
            # Multi-GPU: fan batches out over a spawn-context process pool;
            # each worker pins itself to a device inside _encode.
            num_proc = torch.cuda.device_count()
            cuda_compatible_multiprocess = mp.get_context("spawn")
            with cuda_compatible_multiprocess.Pool(num_proc) as p:
                sentences_batches = [
                    sentences_sorted[start_index : start_index + batch_size]
                    for start_index in range(0, len(sentences), batch_size)
                ]
                progress_bar = tqdm(
                    total=len(sentences_batches),
                    desc="Batches",
                    disable=not show_progress_bar,
                )
                results = []

                def update(*args):
                    # Progress callback fired as each async batch completes.
                    progress_bar.update()

                for batch in sentences_batches:
                    results.append(
                        p.apply_async(
                            self._encode,
                            args=(batch, None, convert_to_numpy, True),
                            callback=update,
                        )
                    )
                # .get() in submission order keeps embeddings aligned with
                # sentences_sorted regardless of completion order.
                all_embeddings = [result.get() for result in results]
                progress_bar.close()
        all_embeddings = torch.cat(all_embeddings, dim=0)
        # Undo the length sort to restore the caller's sentence order.
        all_embeddings = all_embeddings[np.argsort(length_sorted_idx)]
        all_embeddings = all_embeddings.to(torch.float32)
        if convert_to_numpy:
            all_embeddings = np.asarray([emb.numpy() for emb in all_embeddings])
        return all_embeddings
| def save(self, output_path, merge_before_save=False, save_config=True): | |
| if merge_before_save and isinstance(self.model, PeftModel): | |
| self.model = self.model.merge_and_unload() | |
| if hasattr(self.model, "_hf_peft_config_loaded"): | |
| self.model._hf_peft_config_loaded = False | |
| self.model.save_pretrained(output_path) | |
| self.tokenizer.save_pretrained(output_path) | |
| l3prune_config = { | |
| "pooling_mode": self.pooling_mode, | |
| "max_length": self.max_length, | |
| "doc_max_length": self.doc_max_length, | |
| "skip_instruction": self.skip_instruction, | |
| } | |
| if save_config: | |
| os.makedirs(output_path, exist_ok=True) | |
| with open(f"{output_path}/l3prune_config.json", "w") as fOut: | |
| json.dump(l3prune_config, fOut, indent=4) | |
| def _encode( | |
| self, | |
| sentences_batch, | |
| device: Optional[str] = None, | |
| convert_to_numpy: bool = False, | |
| multiprocessing=False, | |
| ): | |
| if multiprocessing: | |
| # multiprocessing only supports CUDA devices at this time, so we ignore the value of device | |
| # and use cuda:rank for the device | |
| rank = mp.current_process()._identity[0] | |
| if device is None and torch.cuda.is_available(): | |
| device = f"cuda:{rank % torch.cuda.device_count()}" | |
| self.to(device) | |
| features = self.tokenize( | |
| [self.prepare_for_tokenization(sentence) for sentence in sentences_batch] | |
| ) | |
| features = batch_to_device(features, device) | |
| with torch.no_grad(): | |
| embeddings = self.forward(features) | |
| embeddings = embeddings.detach() | |
| embeddings = embeddings.cpu() | |
| return embeddings | |
| def _text_length(self, text: Union[List[int], List[List[int]]]): | |
| """ | |
| Help function to get the length for the input text. Text can be either a string (which means a single text) | |
| a list of ints (which means a single tokenized text), or a tuple of list of ints | |
| (representing several text inputs to the model). | |
| """ | |
| if ( | |
| isinstance(text, str) | |
| or (isinstance(text, list) and isinstance(text[0], int)) | |
| or len(text) == 0 | |
| ): # Single text, list of ints, or empty | |
| return len(text) | |
| if isinstance(text, dict): # {key: value} case | |
| return len(next(iter(text.values()))) | |
| elif not hasattr(text, "__len__"): # Object has no len() method | |
| return 1 | |
| else: | |
| return sum([len(t) for t in text]) | |
| def resize_token_embeddings( | |
| self, | |
| new_num_tokens: Optional[int] = None, | |
| pad_to_multiple_of: Optional[int] = None, | |
| ) -> nn.Embedding: | |
| return self.model.resize_token_embeddings( | |
| new_num_tokens=new_num_tokens, pad_to_multiple_of=pad_to_multiple_of | |
| ) | |
| def gradient_checkpointing_enable(self, gradient_checkpointing_kwargs=None): | |
| self.model.gradient_checkpointing_enable( | |
| gradient_checkpointing_kwargs=gradient_checkpointing_kwargs | |
| ) | |