Spaces: Running on Zero

from transformers.pipelines.text_generation import Chat
from transformers import TextGenerationPipeline
from typing import Dict


class MyTextGenerationPipeline(TextGenerationPipeline):
    """
    This subclass overrides the preprocess method to add pad_to_multiple_of=8 to tokenizer_kwargs.
    Fix for: "RuntimeError: p.attn_bias_ptr is not correctly aligned"
    https://github.com/google-deepmind/gemma/issues/169
    NOTE: we also need padding="longest", which is set during class instantiation
    """

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        add_special_tokens=None,
        truncation=None,
        padding=None,
        max_length=None,
        continue_final_message=None,
        **generate_kwargs,
    ):
        # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
        tokenizer_kwargs = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,
            "pad_to_multiple_of": 8,
        }
        tokenizer_kwargs = {
            key: value for key, value in tokenizer_kwargs.items() if value is not None
        }

        if isinstance(prompt_text, Chat):
            tokenizer_kwargs.pop("add_special_tokens", None)  # ignore add_special_tokens on chats
            # If the user passes a chat that ends in an assistant message, we treat it as a prefill
            # by default, because very few models support multiple separate, consecutive assistant messages
            if continue_final_message is None:
                continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                add_generation_prompt=not continue_final_message,
                continue_final_message=continue_final_message,
                return_dict=True,
                return_tensors=self.framework,
                **tokenizer_kwargs,
            )
        else:
            inputs = self.tokenizer(
                prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs
            )

        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            else:
                new_tokens = (
                    generate_kwargs.get("max_length", self.generation_config.max_length)
                    - cur_len
                )
                if new_tokens < 0:
                    raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation: the number of desired"
                        " tokens exceeds the model's max length"
                    )

                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs
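
The docstring above notes that `padding="longest"` has to be supplied when the pipeline object is created, so that the tokenizer actually pads and `pad_to_multiple_of=8` can take effect. A minimal usage sketch is shown below; the checkpoint name is a placeholder (a Gemma model is assumed here only because the linked issue concerns Gemma), and `padding="longest"` is passed as a pipeline kwarg, which recent transformers versions route into `preprocess` on every call:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint: any causal LM works, Gemma is assumed only because of the linked issue.
model_id = "google/gemma-2-2b-it"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id, torch_dtype=torch.bfloat16, device_map="auto"
)

# padding="longest" is set at instantiation so it reaches preprocess() on every call,
# where it combines with pad_to_multiple_of=8 to keep input lengths 8-aligned.
pipe = MyTextGenerationPipeline(model=model, tokenizer=tokenizer, padding="longest")

out = pipe("Why is the sky blue?", max_new_tokens=64)
print(out[0]["generated_text"])
```

With `padding="longest"` and `pad_to_multiple_of=8`, the tokenizer rounds the padded length up to the next multiple of 8, which is what avoids the misaligned attention-bias pointer error mentioned in the docstring.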