from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
from abs_compressor import AbstractCompressor


class KiSCompressor(AbstractCompressor):
    def __init__(self, DEVICE: str = 'cpu', model_dir: str = 'philippelaban/keep_it_simple'):
        self.DEVICE = DEVICE
        # '<|endoftext|>' is the model's EOS token; it is reused as the pad token below.
        self.tokenizer = AutoTokenizer.from_pretrained(model_dir, padding_side='right', pad_token='<|endoftext|>')
        self.tokenizer.pad_token = self.tokenizer.eos_token
        self.tokenizer.padding_side = 'right'
        self.kis_model = AutoModelForCausalLM.from_pretrained(model_dir)
        self.kis_model.to(self.DEVICE)

    def compress(self, original_prompt: str, ratio: float = 0.5, max_length: int = 150, num_beams: int = 4, do_sample: bool = True, num_return_sequences: int = 1, target_index: int = 0) -> dict:
        # `ratio` is part of the shared compressor interface; this compressor does not use it.
        # Token counts are measured with the GPT tokenizer expected from AbstractCompressor.
        original_tokens = len(self.gpt_tokenizer.encode(original_prompt))

        # Keep It Simple generates the simplification after the source paragraph plus a BOS marker.
        start_id = self.tokenizer.bos_token_id
        tokenized_paragraph = [(self.tokenizer.encode(text=original_prompt) + [start_id])]
        input_ids = torch.LongTensor(tokenized_paragraph)
        if self.DEVICE == 'cuda':
            input_ids = input_ids.type(torch.cuda.LongTensor)
        output_ids = self.kis_model.generate(input_ids, max_length=max_length, num_beams=num_beams, do_sample=do_sample,
                                             num_return_sequences=num_return_sequences,
                                             pad_token_id=self.tokenizer.eos_token_id)
        # Keep only the newly generated tokens, decode them, and strip any EOS markers.
        output_ids = output_ids[:, input_ids.shape[1]:]
        output = self.tokenizer.batch_decode(output_ids)
        output = [o.replace(self.tokenizer.eos_token, "") for o in output]
        compressed_prompt = output[target_index]
        compressed_tokens = len(self.gpt_tokenizer.encode(compressed_prompt))

        result = {
            'compressed_prompt': compressed_prompt,
            'ratio': compressed_tokens / original_tokens,
            'original_tokens': original_tokens,
            'compressed_tokens': compressed_tokens,
        }

        return result
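

# Minimal usage sketch (assumptions: the 'philippelaban/keep_it_simple' checkpoint is
# available, and AbstractCompressor supplies the `gpt_tokenizer` used for token counting).
if __name__ == '__main__':
    compressor = KiSCompressor(DEVICE='cpu')
    demo_prompt = ("A shooting star is not a star. It is a small piece of rock or dust "
                   "that heats up and glows as it falls through the Earth's atmosphere.")
    result = compressor.compress(demo_prompt, max_length=150, num_beams=4, do_sample=False)
    print(result['compressed_prompt'])
    print('compression ratio:', result['ratio'])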