import bisect
from collections import defaultdict
from typing import List

import nltk
import numpy as np
import tiktoken
import torch
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer

from abs_compressor import AbstractCompressor

encoding = tiktoken.encoding_for_model("gpt-3.5-turbo")
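

# LongLLMLingua compresses long prompts in three stages: (1) rank whole
# documents against the question and keep those that fit a coarse token
# budget, (2) optionally filter sentences by question-conditioned
# perplexity, and (3) iteratively drop individual tokens whose perplexity
# under a small scoring LM falls below an estimated threshold.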
class LongLLMLinguaCompressor(AbstractCompressor):
    def __init__(
        self,
        model_name: str = "meta-llama/Llama-2-7b-chat-hf",
        device_map: str = "cuda",
        use_auth_token: bool = False,
        open_api_config: dict = None,
    ):
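        # The scoring LM is loaded eagerly here; the embedding / reranker
        # models behind the alternative rank_method options are loaded
        # lazily inside get_rank_results on first use.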
        self.load_model(model_name, device_map, use_auth_token)
        self.retrieval_model = None
        self.retrieval_model_name = None
        self.open_api_config = open_api_config or {}
        self.cache_bos_num = 10

    def load_model(
        self, model_name: str, device_map: str = "cuda", use_auth_token: bool = False
    ):
        config = AutoConfig.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)
        tokenizer.padding_side = "left"
        tokenizer.pad_token_id = (
            config.pad_token_id if config.pad_token_id else tokenizer.eos_token_id
        )
        self.device = (
            device_map if any(key in device_map for key in ["cuda", "cpu"]) else "cuda"
        )
        if "cuda" in device_map or "cpu" in device_map:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype="auto" if device_map == "cuda" else torch.float32,
                config=config,
                ignore_mismatched_sizes=True,
                trust_remote_code=True,
                use_auth_token=use_auth_token,
            ).to(device_map)
        else:
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                device_map=device_map,
                torch_dtype="auto",
                pad_token_id=tokenizer.pad_token_id,
                offload_folder="/tmp/offload",
                offload_state_dict=True,
                cache_dir="/tmp/cache",
                use_auth_token=use_auth_token,
                trust_remote_code=True,
            )
        self.tokenizer = tokenizer
        self.model = model
        self.context_idxs = []
        self.max_position_embeddings = config.max_position_embeddings

    def get_ppl(
        self,
        text: str,
        granularity: str = "sentence",
        input_ids=None,
        attention_mask=None,
        past_key_values=None,
        return_kv=False,
        end=None,
        condition_mode: str = "none",
        condition_pos_id: int = 0,
    ):
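        # Per-token cross-entropy of the text under the scoring LM.
        # granularity="sentence" returns the mean loss, "token" the per-token
        # vector. condition_mode slices the loss before/after
        # condition_pos_id, which implements the question-conditioned
        # perplexity used for ranking.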
        if input_ids is None:
            tokenized_text = self.tokenizer(text, return_tensors="pt")
            input_ids = tokenized_text["input_ids"].to(self.device)
            attention_mask = tokenized_text["attention_mask"].to(self.device)
        if past_key_values is not None:
            past_length = past_key_values[0][0].shape[2]
        else:
            past_length = 0
        if end is None:
            end = input_ids.shape[1]
        end = min(end, past_length + self.max_position_embeddings)
        with torch.no_grad():
            response = self.model(
                input_ids[:, past_length:end],
                attention_mask=attention_mask[:, :end],
                past_key_values=past_key_values,
                use_cache=True,
            )
        past_key_values = response.past_key_values

        shift_logits = response.logits[..., :-1, :].contiguous()
        shift_labels = input_ids[..., past_length + 1 : end].contiguous()
        # Score only positions that are real tokens under the attention mask.
        active = (attention_mask[:, past_length:end] == 1)[..., :-1].view(-1)
        active_logits = shift_logits.view(-1, shift_logits.size(-1))[active]
        active_labels = shift_labels.view(-1)[active]
        loss_fct = torch.nn.CrossEntropyLoss(reduction="none")
        loss = loss_fct(active_logits, active_labels)
        if condition_mode == "before":
            loss = loss[:condition_pos_id]
        elif condition_mode == "after":
            loss = loss[condition_pos_id:]
        res = loss.mean() if granularity == "sentence" else loss
        return (res, past_key_values) if return_kv else res

    def __call__(self, *args, **kwargs):
        return self.compress(*args, **kwargs)

    def compress(
        self,
        context: List[str],
        instruction: str = "",
        question: str = " ",
        ratio: float = 0.5,
        target_token: float = -1,
        iterative_size: int = 200,
        force_context_ids: List[int] = None,
        force_context_number: int = None,
        use_sentence_level_filter: bool = False,
        use_context_level_filter: bool = True,
        use_token_level_filter: bool = True,
        keep_split: bool = False,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        context_budget: str = "+100",
        token_budget_ratio: float = 1.4,
        condition_in_question: str = "none",
        reorder_context: str = "original",
        dynamic_context_compression_ratio: float = 0.0,
        condition_compare: bool = False,
        add_instruction: bool = False,
        rank_method: str = "longllmlingua",
        concate_question: bool = True,
    ):
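        # Three-stage pipeline: (1) control_context_budget ranks whole
        # documents and keeps those within the coarse budget, (2)
        # control_sentence_budget optionally filters sentences by conditional
        # perplexity, (3) iterative_compress_prompt drops individual
        # low-information tokens.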
        if isinstance(context, str):
            context = [context]
        assert not (
            rank_method == "longllmlingua" and not question
        ), "LongLLMLingua requires a question to be set."
        if condition_compare and "_condition" not in condition_in_question:
            condition_in_question += "_condition"
        if rank_method == "longllmlingua":
            if condition_in_question == "none":
                condition_in_question = "after"
        elif rank_method == "llmlingua":
            condition_in_question = (
                "none"
                if "_condition" not in condition_in_question
                else "none_condition"
            )
        origin_tokens = len(
            encoding.encode("\n\n".join([instruction] + context + [question]).strip())
        )
        context_tokens_length = [self.get_token_length(c) for c in context]
        instruction_tokens_length, question_tokens_length = self.get_token_length(
            instruction
        ), self.get_token_length(question)
        if target_token == -1:
            target_token = (
                (
                    instruction_tokens_length
                    + question_tokens_length
                    + sum(context_tokens_length)
                )
                * (1 - ratio)
                - instruction_tokens_length
                - (question_tokens_length if concate_question else 0)
            )
        condition_flag = "_condition" in condition_in_question
        condition_in_question = condition_in_question.replace("_condition", "")

        if len(context) > 1 and use_context_level_filter:
            context, dynamic_ratio = self.control_context_budget(
                context,
                context_tokens_length,
                target_token,
                force_context_ids,
                force_context_number,
                question,
                condition_in_question,
                reorder_context=reorder_context,
                dynamic_context_compression_ratio=dynamic_context_compression_ratio,
                rank_method=rank_method,
                context_budget=context_budget,
            )
        else:
            dynamic_ratio = [0.0] * len(context)

        if use_sentence_level_filter:
            context = self.control_sentence_budget(
                context,
                target_token,
                keep_first_sentence=keep_first_sentence,
                keep_last_sentence=keep_last_sentence,
                keep_sentence_number=keep_sentence_number,
                high_priority_bonus=high_priority_bonus,
                token_budget_ratio=token_budget_ratio,
                question=question,
                condition_in_question=condition_in_question,
                rank_method=rank_method,
            )

        if condition_flag:
            if add_instruction:
                context = [question + "\n\n" + instruction] + context
                start = self.get_token_length(question + "\n\n" + instruction) + 2
            else:
                context = [question] + context
                start = self.get_token_length(question) + 2
        else:
            start = 0

        if use_token_level_filter:
            context = self.iterative_compress_prompt(
                context,
                target_token,
                iterative_size=iterative_size,
                keep_split=keep_split,
                start=start,
                dynamic_ratio=dynamic_ratio,
                condition_compare=condition_compare,
            )
            compressed_prompt = (
                self.tokenizer.batch_decode(context[0])[0]
                .replace("<s> ", "")
                .replace("<s>", "")
            )
        else:
            compressed_prompt = "\n\n".join(context)

        if instruction:
            compressed_prompt = instruction + "\n\n" + compressed_prompt
        if question and concate_question:
            compressed_prompt = compressed_prompt + "\n\n" + question

        compressed_tokens = len(encoding.encode(compressed_prompt))
        # Rough dollar saving, assuming $0.06 per 1K tokens.
        saving = (origin_tokens - compressed_tokens) * 0.06 / 1000
        return {
            "compressed_prompt": compressed_prompt,
            "origin_tokens": origin_tokens,
            "compressed_tokens": compressed_tokens,
            "ratio": compressed_tokens / origin_tokens,
            "saving": saving,
        }

    def get_token_length(self, text: str, add_special_tokens: bool = True):
        return len(
            self.tokenizer(text, add_special_tokens=add_special_tokens).input_ids
        )

    def get_condition_ppl(
        self,
        text: str,
        question: str,
        condition_in_question: str = "none",
        granularity: str = "sentence",
    ):
        if condition_in_question == "none":
            return self.get_ppl(text, granularity=granularity)
        elif condition_in_question == "before":
            return self.get_ppl(
                question + text,
                granularity=granularity,
                condition_mode="after",
                condition_pos_id=self.get_token_length(question) - 1,
            )
        elif condition_in_question == "after":
            return self.get_ppl(
                text + question,
                granularity=granularity,
                condition_mode="after",
                condition_pos_id=self.get_token_length(text) - 1,
            )

    def get_dynamic_compression_ratio(
        self,
        context: list,
        target_token: float,
        iterative_size: int,
        dynamic_ratio: list,
        start: int,
    ):
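        # Splits the concatenated context into iterative_size-token windows
        # and returns, per window, a list of (token_count, keep_ratio) pairs:
        # the global budget fraction tau shifted by each document's
        # dynamic_ratio offset from the context-level ranking.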
        def get_ratio(base: float, delta: float):
            return max(min(1, base + delta), 0)

        context_length = [self.get_token_length(ii, False) + 2 for ii in context]
        if start:
            context_length = context_length[1:]
        tau = target_token / (sum(context_length) + 1)
        res, idx, last, last_target = [], 0, 1, []
        while idx < len(context_length):
            if last + context_length[idx] >= iterative_size:
                last_target.append(
                    (iterative_size - last, get_ratio(tau, dynamic_ratio[idx]))
                )
                res.append(last_target)
                last = last + context_length[idx] - iterative_size
                if last > iterative_size:
                    k = last // iterative_size
                    res.extend(
                        [[(iterative_size, get_ratio(tau, dynamic_ratio[idx]))]] * k
                    )
                    last -= k * iterative_size

                last_target = (
                    [(last, get_ratio(tau, dynamic_ratio[idx]))] if last else []
                )
            else:
                last += context_length[idx]
                last_target.append(
                    (context_length[idx], get_ratio(tau, dynamic_ratio[idx]))
                )
            idx += 1
        if last_target:
            res.append(last_target)
        return res

    def control_context_budget(
        self,
        context: List[str],
        context_tokens_length: List[int],
        target_token: float,
        force_context_ids: List[int] = None,
        force_context_number: int = None,
        question: str = "",
        condition_in_question: str = "none",
        reorder_context: str = "original",
        dynamic_context_compression_ratio: float = 0.0,
        rank_method: str = "longllmlingua",
        context_budget: str = "+100",
    ):
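        # Ranks documents with rank_method, then greedily keeps the
        # top-ranked ones until the budget is spent; context_budget is a
        # suffix expression (e.g. "+100") applied to target_token.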
        if force_context_ids is not None:
            # Callers unpack (context, dynamic_ratio), so return both.
            return (
                [context[ii] for ii in force_context_ids],
                [0.0] * len(force_context_ids),
            )
        demostrations_sort = self.get_rank_results(
            context,
            question,
            rank_method,
            condition_in_question,
            context_tokens_length,
        )

        if target_token < 0:
            target_token = 100
        # context_budget is a suffix expression such as "+100" or "*1.5".
        target_token = eval("target_token" + context_budget)
        res = []
        used = force_context_ids if force_context_ids is not None else []

        self.context_idxs.append([x for idx, (x, _) in enumerate(demostrations_sort)])
        for idx, _ in demostrations_sort:
            if idx >= len(context_tokens_length):
                continue
            target_token -= context_tokens_length[idx]
            if idx not in used:
                used.append(idx)
            if target_token < 0 or (
                force_context_number is not None and len(res) >= force_context_number
            ):
                break
        original_used = used
        if reorder_context == "original":
            used = sorted(used)
        elif reorder_context == "two_stage":
            l, r = [_ for idx, _ in enumerate(used) if idx % 2 == 0], [
                _ for idx, _ in enumerate(used) if idx % 2 == 1
            ]
            used = l + r[::-1]

        if dynamic_context_compression_ratio > 0:
            N = len(used)
            if condition_in_question:
                rank = [
                    i
                    for i, _ in self.get_rank_results(
                        context,
                        question,
                        "longllmlingua",
                        "after",
                        context_tokens_length,
                    )
                ]
                used = sorted(used, key=lambda x: rank.index(x))
            dynamic_ratio = [
                i * (abs(dynamic_context_compression_ratio) / (N - 1)) if N > 1 else 0
                for i in range(-(N - 1), N, 2)
            ][::-1]
            dynamic_ratio_map = {i: j for i, j in zip(original_used, dynamic_ratio)}
            dynamic_ratio = [dynamic_ratio_map[i] for i in used]
        else:
            dynamic_ratio = [0.0] * len(used)

        res = [context[idx] for idx in used if idx < len(context)]
        return res, dynamic_ratio

    def control_sentence_budget(
        self,
        context: List[str],
        target_token: float,
        keep_first_sentence: int = 0,
        keep_last_sentence: int = 0,
        keep_sentence_number: int = 0,
        high_priority_bonus: int = 100,
        token_budget_ratio: float = 1.4,
        question: str = "",
        condition_in_question: str = "none",
        rank_method: str = "longllmlingua",
    ):
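        # Splits each document into sentences with nltk (requires the "punkt"
        # data), scores them by (conditional) perplexity, and keeps the
        # best-scoring ones in document order until the budget, inflated by
        # token_budget_ratio, is spent.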
        def keep_sentence(dem_idx: int, sent_keep: int):
            idxs = sorted(dem_g[dem_idx], key=lambda x: sentence_ppl[x])[:sent_keep]
            for idx in idxs:
                sentence_ppl[idx] += high_priority_bonus

        sentences = [nltk.sent_tokenize(c) for c in context]
        dem_g, s2de, idx = defaultdict(set), defaultdict(int), 0
        for idx_d, s in enumerate(sentences):
            for _ in s:
                dem_g[idx_d].add(idx)
                s2de[idx] = idx_d
                idx += 1

        context_sentences = [s for ii in sentences for s in ii]
        sentence_tokens_length = [
            self.get_token_length(sentence) for sentence in context_sentences
        ]
        N = len(context_sentences)
        flags = list(range(len(context_sentences)))
        if len(sentence_tokens_length) == 1:
            return context
        if rank_method == "longllmlingua":
            sentence_ppl = [
                self.get_condition_ppl(sentence, question, condition_in_question)
                .cpu()
                .numpy()
                .item()
                for sentence in context_sentences
            ]
            if keep_first_sentence:
                sentence_ppl[:keep_first_sentence] = [
                    ii + high_priority_bonus
                    for ii in sentence_ppl[:keep_first_sentence]
                ]
            if keep_last_sentence:
                sentence_ppl[-keep_last_sentence:] = [
                    ii + high_priority_bonus
                    for ii in sentence_ppl[-keep_last_sentence:]
                ]
            if keep_sentence_number:
                for dem_idx in range(len(sentences)):
                    keep_sentence(dem_idx, keep_sentence_number)
            sort_direct = -1 if condition_in_question == "none" else 1
            sent_sort = sorted(
                enumerate(sentence_ppl), key=lambda x: sort_direct * x[1]
            )
        else:
            sent_sort = self.get_rank_results(
                context_sentences,
                question,
                rank_method,
                condition_in_question,
                [0] * len(context_sentences),
            )

        sentence_flags = [False] * N
        if target_token < 0:
            target_token = 100
        target_token *= token_budget_ratio
        res = []
        for idx, _ in sent_sort:
            idx = flags[idx]
            target_token -= sentence_tokens_length[idx]
            sentence_flags[idx] = True
            if target_token < 0:
                break
        idx = 0
        res = []
        for s in sentences:
            tmp = [jj for ii, jj in enumerate(s) if sentence_flags[idx + ii]]
            res.append("\n".join(tmp))
            idx += len(s)
        return res

    def get_compressed_input(
        self,
        loss,
        input_ids,
        attention_mask,
        end=200,
        iterative_size=200,
        threshold=0.5,
        keep_flag=None,
        split_token_id: int = 13,
        start: int = 0,
        self_loss=None,
        self_input_ids=None,
        self_attention_mask=None,
    ):
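        # Keeps the tokens in the current window whose loss exceeds the
        # threshold; tokens outside the window are always kept. When
        # self_loss is given (condition_compare), the contrastive score
        # self_loss - loss is thresholded instead.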
        if self_loss is not None:
            need_idx = torch.concat(
                [
                    loss[:start] > 0,
                    self_loss[: loss[start:].shape[0]] - loss[start:] > threshold,
                    loss[:1] > 0,
                ]
            )
        else:
            need_idx = torch.concat([loss > threshold, loss[:1] > 0])
        need_idx[end:] = 1
        need_idx[: end - iterative_size] = 1
        loss = loss[need_idx[:-1]]
        if self_loss is not None:
            if need_idx.shape[0] < self_loss.shape[0] + start + 1:
                need_idx = torch.cat(
                    [
                        need_idx,
                        torch.ones(
                            self_loss.shape[0] - need_idx.shape[0] + start + 1,
                            dtype=torch.bool,
                        ).to(need_idx.device),
                    ]
                )
            self_loss = self_loss[need_idx[start:-1]]

        if need_idx.shape[0] < input_ids.shape[1]:
            need_idx = torch.cat(
                [
                    need_idx,
                    torch.ones(
                        input_ids.shape[1] - need_idx.shape[0], dtype=torch.bool
                    ).to(need_idx.device),
                ]
            )
        elif need_idx.shape[0] > input_ids.shape[1]:
            need_idx = need_idx[: input_ids.shape[1]]

        if keep_flag is not None:
            need_idx[keep_flag == 1] = 1
        last = -1
        if keep_flag is not None:
            for ii in range(end - iterative_size, end):
                if need_idx[ii] != 1:
                    continue
                now = input_ids[0][ii].detach().cpu().item()
                if (
                    now == split_token_id
                    and last == split_token_id
                    and keep_flag[ii].detach().cpu().item() == 0
                ):
                    need_idx[ii] = 0
                else:
                    last = now
        compressed_input_ids = input_ids[attention_mask == 1][need_idx].unsqueeze(0)
        compressed_attention_mask = attention_mask[attention_mask == 1][
            need_idx
        ].unsqueeze(0)

        if self_loss is not None:
            self_compressed_input_ids = self_input_ids[self_attention_mask == 1][
                need_idx[start:]
            ].unsqueeze(0)
            self_compressed_attention_mask = self_attention_mask[
                self_attention_mask == 1
            ][need_idx[start:]].unsqueeze(0)
        else:
            self_compressed_input_ids, self_compressed_attention_mask = None, None
        if keep_flag is not None:
            if len(keep_flag) > len(need_idx):
                keep_flag = torch.cat(
                    [
                        keep_flag[:start],
                        keep_flag[start : len(need_idx) + start][need_idx],
                        keep_flag[start + len(need_idx) :],
                    ]
                )
            else:
                keep_flag = keep_flag[need_idx]
        end -= (need_idx[:end] == 0).sum()
        return (
            compressed_input_ids,
            compressed_attention_mask,
            keep_flag,
            end,
            loss,
            self_loss,
            self_compressed_input_ids,
            self_compressed_attention_mask,
        )

    def get_estimate_threshold_base_distribution(
        self, ppl, ratio: float, condition_flag: bool = False
    ):
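        # Picks the loss value at the ratio-quantile of the distribution, so
        # that roughly `ratio` of the tokens clear the threshold; 10000 acts
        # as a padding sentinel and is filtered out.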
        ppl = ppl[ppl != 10000]
        target_token = max(0, min(len(ppl) - 1, int(len(ppl) * ratio) - 1))
        return (
            ppl.sort(descending=not condition_flag)
            .values[target_token]
            .detach()
            .cpu()
            .item()
        )

    def iterative_compress_prompt(
        self,
        context: List[str],
        target_token: float,
        iterative_size: int = 200,
        keep_split: bool = False,
        split_token_id: int = 13,
        start: int = 0,
        dynamic_ratio: list = None,
        condition_compare: bool = False,
    ):
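        # Scores the prompt window by window with a shared KV cache. When a
        # window would exceed max_position_embeddings, cached positions just
        # after the first cache_bos_num entries are evicted, and the
        # corresponding tokens are set aside (pop_*) and re-attached at the
        # end.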
        iterative_ratios = self.get_dynamic_compression_ratio(
            context, target_token, iterative_size, dynamic_ratio, start
        )
        context = "\n\n".join(context)
        tokenized_text = self.tokenizer(context, return_tensors="pt")
        input_ids = tokenized_text["input_ids"].to(self.device)
        attention_mask = tokenized_text["attention_mask"].to(self.device)

        N = (attention_mask == 1).sum()
        compressed_input_ids, compressed_attention_mask = input_ids, attention_mask
        if condition_compare:
            self_input_ids, self_attention_mask = (
                input_ids[:, start:],
                attention_mask[:, start:],
            )
            self_compressed_input_ids, self_compressed_attention_mask = (
                self_input_ids,
                self_attention_mask,
            )

        end = min(iterative_size + start, compressed_input_ids.shape[1])
        threshold, keep_flag = None, None
        if keep_split:
            input_ids_numpy = input_ids.cpu().detach().numpy()[0]
            N = len(input_ids_numpy)
            keep_flag = [
                int(
                    (
                        ii > 0
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii - 1] == split_token_id
                    )
                    or (
                        ii < N - 1
                        and input_ids_numpy[ii] == split_token_id
                        and input_ids_numpy[ii + 1] == split_token_id
                    )
                )
                for ii in range(N)
            ]
            keep_flag = torch.tensor(keep_flag).to(self.device)
        past_key_values, past_loss, ready_end = None, None, 0
        self_past_key_values, self_past_loss, self_ready_end = None, None, 0
        pop_compressed_input_ids, pop_self_compressed_input_ids = None, None
        idx = 0
        while end <= compressed_input_ids.shape[1]:
            if end > self.max_position_embeddings and past_key_values is not None:
                # Evict e cached positions just after the first cache_bos_num
                # entries so the window fits the model's context length.
                e, s = end - self.max_position_embeddings, self.cache_bos_num
                if pop_compressed_input_ids is None:
                    pop_compressed_input_ids = compressed_input_ids[:, :e]
                else:
                    pop_compressed_input_ids = torch.cat(
                        [pop_compressed_input_ids, compressed_input_ids[:, :e]], dim=-1
                    )
                compressed_input_ids = compressed_input_ids[:, e:]
                compressed_attention_mask = compressed_attention_mask[:, e:]
                past_key_values = [
                    [
                        torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
                        torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
                    ]
                    for k, v in past_key_values
                ]
                end, ready_end = end - e, ready_end - e
                if condition_compare:
                    self_ready_end -= e
                    if pop_self_compressed_input_ids is None:
                        pop_self_compressed_input_ids = self_compressed_input_ids[:, :e]
                    else:
                        pop_self_compressed_input_ids = torch.cat(
                            [
                                pop_self_compressed_input_ids,
                                self_compressed_input_ids[:, :e],
                            ],
                            dim=-1,
                        )
                    self_compressed_input_ids = self_compressed_input_ids[:, e:]
                    self_compressed_attention_mask = self_compressed_attention_mask[
                        :, e:
                    ]
                    self_past_key_values = [
                        [
                            torch.cat([k[..., :s, :], k[..., s + e :, :]], dim=-2),
                            torch.cat([v[..., :s, :], v[..., s + e :, :]], dim=-2),
                        ]
                        for k, v in self_past_key_values
                    ]

            loss, past_key_values = self.get_ppl(
                "",
                "token",
                compressed_input_ids,
                compressed_attention_mask,
                past_key_values=past_key_values,
                return_kv=True,
                end=end if idx else None,
            )
            if past_loss is not None:
                if end - 1 > len(past_loss):
                    past_loss = torch.cat(
                        [past_loss, torch.zeros_like(loss)[: end - 1 - len(past_loss)]]
                    )
                past_loss[ready_end : end - 1] = loss
                loss = past_loss
            else:
                past_loss = loss
            if idx:
                past_key_values = [
                    [k[:, :, : end - iterative_size], v[:, :, : end - iterative_size]]
                    for k, v in past_key_values
                ]
            else:
                past_key_values = None

            if condition_compare:
                self_loss, self_past_key_values = self.get_ppl(
                    "",
                    "token",
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                    past_key_values=self_past_key_values,
                    return_kv=True,
                    end=end - start if idx else None,
                )
                if self_past_loss is not None:
                    if end - start - 1 > len(self_past_loss):
                        self_past_loss = torch.cat(
                            [
                                self_past_loss,
                                torch.zeros_like(self_loss)[
                                    : end - 1 - start - len(self_past_loss)
                                ],
                            ]
                        )
                    self_past_loss[self_ready_end : end - start - 1] = self_loss
                    self_loss = self_past_loss
                else:
                    self_past_loss = self_loss
                if idx:
                    self_past_key_values = [
                        [
                            k[:, :, : end - iterative_size - start],
                            v[:, :, : end - iterative_size - start],
                        ]
                        for k, v in self_past_key_values
                    ]
                else:
                    self_past_key_values = None

            self_ready_end = (
                end - start - iterative_size if not (start and idx == 0) else 0
            )
            ready_end = end - iterative_size if not (start and idx == 0) else 0

            for delta_end, ratio in iterative_ratios[idx]:
                loss = past_loss
                if condition_compare:
                    self_loss = self_past_loss
                    threshold = self.get_estimate_threshold_base_distribution(
                        self_loss[: loss[start:].shape[0]] - loss[start:], ratio, False
                    )
                else:
                    threshold = self.get_estimate_threshold_base_distribution(
                        loss, ratio, False
                    )

                (
                    compressed_input_ids,
                    compressed_attention_mask,
                    keep_flag,
                    end,
                    past_loss,
                    self_past_loss,
                    self_compressed_input_ids,
                    self_compressed_attention_mask,
                ) = self.get_compressed_input(
                    loss,
                    compressed_input_ids,
                    compressed_attention_mask,
                    end - iterative_size + delta_end,
                    iterative_size=delta_end,
                    threshold=threshold,
                    keep_flag=keep_flag,
                    split_token_id=split_token_id,
                    start=start,
                    self_loss=self_loss if condition_compare else None,
                    self_input_ids=self_compressed_input_ids
                    if condition_compare
                    else None,
                    self_attention_mask=self_compressed_attention_mask
                    if condition_compare
                    else None,
                )
            end += iterative_size
            idx += 1
        if pop_compressed_input_ids is not None:
            compressed_input_ids = torch.cat(
                [pop_compressed_input_ids, compressed_input_ids], dim=-1
            )
        return compressed_input_ids[:, start:], compressed_attention_mask[:, start:]

    def recover(
        self,
        original_prompt: str,
        compressed_prompt: str,
        response: str,
    ):
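        # Best-effort restoration: spans of the response that were copied
        # verbatim from the compressed prompt are mapped back to the closest
        # matching token span of the original prompt and replaced with its
        # full text.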
        def match_from_compressed(response_word):
            response_input_ids = self.tokenizer(
                response_word, add_special_tokens=False
            )["input_ids"]
            response_set, response_c = set(response_input_ids), defaultdict(list)
            for idx in range(M):
                if original_input_ids[idx] in response_set:
                    response_c[original_input_ids[idx]].append(idx)
            res, res_min, res_c = None, float("inf"), 1
            n = len(response_input_ids)
            for l in response_c[response_input_ids[0]]:
                x, y, c = 0, l, 1
                for x in range(1, n):
                    idx = bisect.bisect_right(response_c[response_input_ids[x]], y)
                    if (
                        idx >= len(response_c[response_input_ids[x]])
                        or response_c[response_input_ids[x]][idx] - y > 10
                    ):
                        continue
                    c += 1
                    y = response_c[response_input_ids[x]][idx]
                if c > res_c:
                    res_c = c
                    res_min = y - l + 1
                    res = (l, y + 1)
                elif c == res_c and y - l + 1 < res_min:
                    res_min = y - l + 1
                    res = (l, y + 1)

            if res is None:
                return response_word
            return self.tokenizer.decode(original_input_ids[res[0] : res[1]])

        response_words = response.split(" ")

        original_input_ids = self.tokenizer(original_prompt, add_special_tokens=False)[
            "input_ids"
        ]
        N, M = len(response_words), len(original_input_ids)
        recovered_response_words = []
        l = 0
        while l < N:
            if response_words[l] not in compressed_prompt:
                recovered_response_words.append(response_words[l])
                l += 1
                continue
            r = l
            while (
                r + 1 < N and " ".join(response_words[l : r + 2]) in compressed_prompt
            ):
                r += 1

            match_words = match_from_compressed(" ".join(response_words[l : r + 1]))
            recovered_response_words.append(match_words)
            l = r + 1
        return " ".join(recovered_response_words)

    def get_rank_results(
        self,
        context: list,
        question: str,
        rank_method: str,
        condition_in_question: str,
        context_tokens_length: list,
    ):
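        # Dispatch table over rankers: BM25, gzip compression distance,
        # sentence-transformer embeddings, the OpenAI / Voyage / Cohere APIs,
        # BGE variants (bi-encoder, reranker, llm-embedder), Jina embeddings,
        # or the default question-conditioned perplexity ("longllmlingua").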
        def get_distance_bm25(corpus, query):
            from rank_bm25 import BM25Okapi

            tokenized_corpus = [doc.split(" ") for doc in corpus]
            bm25 = BM25Okapi(tokenized_corpus)
            tokenized_query = query.split(" ")
            doc_scores = bm25.get_scores(tokenized_query)
            idx = [(ii, 0) for ii in (-doc_scores).argsort()]
            return idx

        def get_distance_gzip(corpus, query):
            import gzip

            def get_score(x, y):
                cx, cy = len(gzip.compress(x.encode())), len(gzip.compress(y.encode()))
                cxy = len(gzip.compress(f"{x} {y}".encode()))
                return (cxy - min(cx, cy)) / max(cx, cy)

            doc_scores = [get_score(doc, query) for doc in corpus]
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_sentbert(corpus, query):
            from sentence_transformers import SentenceTransformer, util

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                self.retrieval_model = SentenceTransformer("multi-qa-mpnet-base-dot-v1")
                self.retrieval_model_name = rank_method
            doc_embeds = self.retrieval_model.encode(corpus)
            query = self.retrieval_model.encode(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_openai(corpus, query):
            import openai
            from sentence_transformers import util

            openai.api_key = self.open_api_config.get("api_key", "")
            openai.api_base = self.open_api_config.get(
                "api_base", "https://api.openai.com/v1"
            )
            openai.api_type = self.open_api_config.get("api_type", "open_ai")
            openai.api_version = self.open_api_config.get("api_version", "2023-05-15")
            engine = self.open_api_config.get("engine", "text-embedding-ada-002")

            def get_embed(text):
                # The Embedding API returns the vectors under the "data" key.
                return openai.Embedding.create(
                    input=[text.replace("\n", " ")], engine=engine
                )["data"][0]["embedding"]

            doc_embeds = [get_embed(i) for i in corpus]
            query = get_embed(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_sentbert_bge(corpus, query):
            from sentence_transformers import SentenceTransformer, util

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                self.retrieval_model = SentenceTransformer("BAAI/bge-large-en-v1.5")
                self.retrieval_model_name = rank_method
            doc_embeds = self.retrieval_model.encode(
                [i for i in corpus], normalize_embeddings=True
            )
            query = self.retrieval_model.encode(query, normalize_embeddings=True)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_bge_ranker(corpus, query):
            from transformers import AutoModelForSequenceClassification, AutoTokenizer

            pairs = [[i, query] for i in corpus]
            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                tokenizer = AutoTokenizer.from_pretrained("BAAI/bge-reranker-large")
                model = (
                    AutoModelForSequenceClassification.from_pretrained(
                        "BAAI/bge-reranker-large"
                    )
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = [tokenizer, model]
                self.retrieval_model_name = rank_method
            with torch.no_grad():
                inputs = self.retrieval_model[0](
                    pairs,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                scores = (
                    self.retrieval_model[1](**inputs, return_dict=True)
                    .logits.view(
                        -1,
                    )
                    .float()
                )
            idx = [(ii, 0) for ii in np.argsort(-scores.cpu())]
            return idx

        def get_distance_bge_llmembedder(corpus, query):
            from transformers import AutoModel, AutoTokenizer

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                tokenizer = AutoTokenizer.from_pretrained("BAAI/llm-embedder")
                model = (
                    AutoModel.from_pretrained("BAAI/llm-embedder")
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = [tokenizer, model]
                self.retrieval_model_name = rank_method

            instruction_qa_query = (
                "Represent this query for retrieving relevant documents: "
            )
            instruction_qa_key = "Represent this document for retrieval: "
            queries = [instruction_qa_query + query for _ in corpus]
            keys = [instruction_qa_key + key for key in corpus]
            with torch.no_grad():
                query_inputs = self.retrieval_model[0](
                    queries,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                key_inputs = self.retrieval_model[0](
                    keys,
                    padding=True,
                    truncation=True,
                    return_tensors="pt",
                    max_length=512,
                ).to(self.device)
                query_outputs = self.retrieval_model[1](**query_inputs)
                key_outputs = self.retrieval_model[1](**key_inputs)
                # CLS pooling: take the hidden state of the first token.
                query_embeddings = query_outputs.last_hidden_state[:, 0]
                key_embeddings = key_outputs.last_hidden_state[:, 0]
                # L2-normalize so the dot product is cosine similarity.
                query_embeddings = torch.nn.functional.normalize(
                    query_embeddings, p=2, dim=1
                )
                key_embeddings = torch.nn.functional.normalize(
                    key_embeddings, p=2, dim=1
                )
                similarity = query_embeddings @ key_embeddings.T
            idx = [(ii, 0) for ii in np.argsort(-similarity[0].cpu())]
            return idx

        def get_distance_jinza(corpus, query):
            from numpy.linalg import norm
            from transformers import AutoModel

            def cos_sim(a, b):
                return (a @ b.T) / (norm(a) * norm(b))

            if self.retrieval_model is None or self.retrieval_model_name != rank_method:
                model = (
                    AutoModel.from_pretrained(
                        "jinaai/jina-embeddings-v2-base-en", trust_remote_code=True
                    )
                    .eval()
                    .to(self.device)
                )
                self.retrieval_model = model
                self.retrieval_model_name = rank_method

            doc_embeds = self.retrieval_model.encode(corpus)
            query = self.retrieval_model.encode(query)
            doc_scores = cos_sim(doc_embeds, query)
            idx = [(ii, 0) for ii in np.argsort(-doc_scores)]
            return idx

        def get_distance_voyageai(corpus, query):
            import voyageai
            from sentence_transformers import util

            voyageai.api_key = self.open_api_config.get("voyageai_api_key", "")

            def get_embed(text):
                return voyageai.get_embedding(text, model="voyage-01")

            doc_embeds = [get_embed(i) for i in corpus]
            query = get_embed(query)
            doc_scores = -util.dot_score(doc_embeds, query).cpu().numpy().reshape(-1)
            idx = [(ii, 0) for ii in np.argsort(doc_scores)]
            return idx

        def get_distance_cohere(corpus, query):
            import cohere

            api_key = self.open_api_config.get("cohere_api_key", "")
            co = cohere.Client(api_key)
            results = co.rerank(
                model="rerank-english-v2.0", query=query, documents=corpus, top_n=20
            )
            c_map = {jj: ii for ii, jj in enumerate(corpus)}
            doc_rank = [c_map[ii.document["text"]] for ii in results]
            idx = [(ii, 0) for ii in doc_rank]
            return idx

        def get_distance_longllmlingua(corpus, query):
            context_ppl = [
                self.get_condition_ppl(
                    d,
                    query
                    + " We can get the answer to this question in the given documents.",
                    condition_in_question,
                )
                # Length-penalty term, currently disabled by the * 0 factor.
                - dl * 2 / 250 * 0
                for d, dl in zip(corpus, context_tokens_length)
            ]
            sort_direct = -1 if condition_in_question == "none" else 1
            ys = sorted(enumerate(context_ppl), key=lambda x: sort_direct * x[1])
            return ys

        method = None
        if rank_method == "bm25":
            method = get_distance_bm25
        elif rank_method == "gzip":
            method = get_distance_gzip
        elif rank_method == "sentbert":
            method = get_distance_sentbert
        elif rank_method == "openai":
            method = get_distance_openai
        elif rank_method in ["longllmlingua", "llmlingua"]:
            method = get_distance_longllmlingua
        elif rank_method == "bge":
            method = get_distance_sentbert_bge
        elif rank_method == "bge_reranker":
            method = get_distance_bge_ranker
        elif rank_method == "bge_llmembedder":
            method = get_distance_bge_llmembedder
        elif rank_method == "jinza":
            method = get_distance_jinza
        elif rank_method == "voyageai":
            method = get_distance_voyageai
        elif rank_method == "cohere":
            method = get_distance_cohere
        if method is None:
            raise ValueError(f"Unknown rank_method: {rank_method}")
        return method(context, question)
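

if __name__ == "__main__":
    # Minimal usage sketch, not part of the compressor itself. It assumes a
    # CUDA device, the nltk "punkt" tokenizer data, and access to the gated
    # meta-llama/Llama-2-7b-chat-hf weights; the documents and question are
    # illustrative placeholders.
    compressor = LongLLMLinguaCompressor(device_map="cuda")
    result = compressor.compress(
        context=[
            "Document 1: ...",
            "Document 2: ...",
        ],
        question="Which document mentions the topic of interest?",
        ratio=0.5,
        rank_method="longllmlingua",
    )
    print(result["compressed_tokens"], "/", result["origin_tokens"])
    print(result["compressed_prompt"])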