File size: 27,394 Bytes
fb42d3e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 |
import enum
import itertools
import types
from collections.abc import Mapping
from typing import Any, overload

from ..generation import GenerationConfig
from ..utils import ModelOutput, add_end_docstrings, is_tf_available, is_torch_available
from .base import Pipeline, build_pipeline_init_args


if is_torch_available():
    import torch

    from ..models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
    from .pt_utils import KeyDataset

if is_tf_available():
    import tensorflow as tf

    from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
# Type alias for one conversation: an ordered list of message dicts, each carrying (at least) a "role" and a
# "content" entry, e.g. [{"role": "user", "content": "Hello"}].
ChatType = list[dict[str, str]]
class ReturnType(enum.Enum):
TENSORS = 0
NEW_TEXT = 1
FULL_TEXT = 2
class Chat:
    """This class is intended to just be used internally in this pipeline and not exposed to users. We convert chats
    to this format because the rest of the pipeline code tends to assume that lists of messages are
    actually a batch of samples rather than messages in the same conversation."""

    def __init__(self, messages: list[dict]):
        """Validate and store a single conversation.

        Args:
            messages: An iterable of message mappings, each with at least a "role" and a "content" key. It is
                materialized into a list so that one-shot iterables (such as generators) can be validated and then
                still be indexed afterwards; the pipeline reads `messages[-1]` to detect assistant prefills.

        Raises:
            ValueError: If the chat contains no messages, or if a message is not a mapping or lacks a
                "role" or "content" key.
        """
        # Materialize first: validating a generator in-place would otherwise exhaust it and leave
        # `self.messages` unusable.
        messages = list(messages)
        if not messages:
            raise ValueError("When passing a chat as input, it must contain at least one message.")
        for message in messages:
            # The Mapping check matters: `"role" in some_string` is a substring test and would wrongly pass.
            if not (isinstance(message, Mapping) and "role" in message and "content" in message):
                raise ValueError("When passing chat dicts as input, each dict must have a 'role' and 'content' key.")
        self.messages = messages
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class TextGenerationPipeline(Pipeline):
    """
    Language generation pipeline using any `ModelWithLMHead` or `ModelForCausalLM`. This pipeline predicts the words
    that will follow a specified text prompt. When the underlying model is a conversational model, it can also accept
    one or more chats, in which case the pipeline will operate in chat mode and will continue the chat(s) by adding
    its response(s). Each chat takes the form of a list of dicts, where each dict contains "role" and "content" keys.

    Unless the model you're using explicitly sets these generation parameters in its configuration files
    (`generation_config.json`), the following default values will be used:
    - max_new_tokens: 256
    - do_sample: True
    - temperature: 0.7

    Examples:

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="openai-community/gpt2")
    >>> generator("I can't believe you did such a ", do_sample=False)
    [{'generated_text': "I can't believe you did such a icky thing to me. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I'm so sorry. I"}]

    >>> # These parameters will return suggestions, and only the newly created text making it easier for prompting suggestions.
    >>> outputs = generator("My tart needs some", num_return_sequences=4, return_full_text=False)
    ```

    ```python
    >>> from transformers import pipeline

    >>> generator = pipeline(model="HuggingFaceH4/zephyr-7b-beta")
    >>> # Zephyr-beta is a conversational model, so let's pass it a chat instead of a single string
    >>> generator([{"role": "user", "content": "What is the capital of France? Answer in one word."}], do_sample=False, max_new_tokens=2)
    [{'generated_text': [{'role': 'user', 'content': 'What is the capital of France? Answer in one word.'}, {'role': 'assistant', 'content': 'Paris'}]}]
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial). You can pass text
    generation parameters to this pipeline to control stopping criteria, decoding strategy, and more. Learn more about
    text generation parameters in [Text generation strategies](../generation_strategies) and [Text
    generation](text_generation).

    This language generation pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"text-generation"`.

    The models that this pipeline can use are models that have been trained with an autoregressive language modeling
    objective. See the list of available [text completion models](https://huggingface.co/models?filter=text-generation)
    and the list of [conversational models](https://huggingface.co/models?other=conversational)
    on [huggingface.co/models].
    """

    # Prefix text to help Transformer-XL and XLNet with short prompts as proposed by Aman Rusia
    # in https://github.com/rusiaaman/XLNet-gen#methodology
    # and https://medium.com/@amanrusia/xlnet-speaks-comparison-to-gpt-2-ea1a4e9ba39e
    XL_PREFIX = """
In 1991, the remains of Russian Tsar Nicholas II and his family (except for Alexei and Maria) are discovered. The
voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the remainder of the story. 1883 Western
Siberia, a young Grigori Rasputin is asked by his father and a group of men to perform magic. Rasputin has a vision
and denounces one of the men as a horse thief. Although his father initially slaps him for making such an
accusation, Rasputin watches as the man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous, with people, even a bishop,
begging for his blessing. <eod> </s> <eos>
"""

    _pipeline_calls_generate = True
    _load_processor = False
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = True

    # Make sure the docstring is updated when the default generation config is changed
    _default_generation_config = GenerationConfig(
        max_new_tokens=256,
        do_sample=True,  # free-form text generation often uses sampling
        temperature=0.7,
    )

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.check_model_type(
            TF_MODEL_FOR_CAUSAL_LM_MAPPING_NAMES if self.framework == "tf" else MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
        )
        if "prefix" not in self._preprocess_params:
            # This is very specific. The logic is quite complex and needs to be done as a "default".
            # It also defines both some preprocess_kwargs and generate_kwargs, which is why we cannot put them in
            # their respective methods.
            # NOTE(review): `self.prefix` is presumably populated by the base `Pipeline` (e.g. from the model
            # config) — confirm there.
            prefix = self.prefix
            if prefix is None and self.model.__class__.__name__ in [
                "XLNetLMHeadModel",
                "TransfoXLLMHeadModel",
                "TFXLNetLMHeadModel",
                "TFTransfoXLLMHeadModel",
            ]:
                # For XLNet and TransformerXL we add an article to the prompt to give more state to the model.
                prefix = self.XL_PREFIX
            if prefix is not None:
                # Recalculate some generate_kwargs linked to prefix (notably `prefix_length`).
                preprocess_params, forward_params, _ = self._sanitize_parameters(prefix=prefix, **self._forward_params)
                self._preprocess_params = {**self._preprocess_params, **preprocess_params}
                self._forward_params = {**self._forward_params, **forward_params}

    def _sanitize_parameters(
        self,
        return_full_text=None,
        return_tensors=None,
        return_text=None,
        return_type=None,
        clean_up_tokenization_spaces=None,
        prefix=None,
        handle_long_generation=None,
        stop_sequence=None,
        truncation=None,
        max_length=None,
        continue_final_message=None,
        skip_special_tokens=None,
        tokenizer_encode_kwargs=None,
        **generate_kwargs,
    ):
        """Route call-time keyword arguments to the preprocess, forward and postprocess stages.

        Returns:
            A `(preprocess_params, forward_params, postprocess_params)` tuple of dicts. Unrecognized keyword
            arguments are treated as `generate` kwargs: they are forwarded to `_forward` and also copied into the
            preprocess params (the "hole" strategy reads `max_new_tokens` there).

        Raises:
            ValueError: For an unknown `handle_long_generation` value, or mutually exclusive return options.
        """
        # preprocess kwargs
        preprocess_params = {}
        add_special_tokens = False
        if "add_special_tokens" in generate_kwargs:
            add_special_tokens = preprocess_params["add_special_tokens"] = generate_kwargs.pop("add_special_tokens")
        if "padding" in generate_kwargs:
            preprocess_params["padding"] = generate_kwargs.pop("padding")
        if truncation is not None:
            preprocess_params["truncation"] = truncation
        if max_length is not None:
            # `max_length` is both a tokenizer argument (prompt truncation) and a `generate` length limit
            preprocess_params["max_length"] = max_length
            generate_kwargs["max_length"] = max_length

        if prefix is not None:
            preprocess_params["prefix"] = prefix
        if prefix:
            # Count the prefix tokens so `_forward` can extend the length limits accordingly
            prefix_inputs = self.tokenizer(
                prefix, padding=False, add_special_tokens=add_special_tokens, return_tensors=self.framework
            )
            generate_kwargs["prefix_length"] = prefix_inputs["input_ids"].shape[-1]

        if handle_long_generation is not None:
            if handle_long_generation != "hole":
                raise ValueError(
                    f"{handle_long_generation} is not a valid value for `handle_long_generation` parameter expected"
                    " [None, 'hole']"
                )
            preprocess_params["handle_long_generation"] = handle_long_generation

        if continue_final_message is not None:
            preprocess_params["continue_final_message"] = continue_final_message

        if tokenizer_encode_kwargs is not None:
            preprocess_params["tokenizer_encode_kwargs"] = tokenizer_encode_kwargs

        preprocess_params.update(generate_kwargs)

        # forward kwargs
        if stop_sequence is not None:
            # NOTE(review): a multi-token stop sequence becomes a list of eos ids; `generate` presumably stops on
            # ANY of them rather than on the full sequence — confirm this is the intended semantics.
            stop_sequence_ids = self.tokenizer.encode(stop_sequence, add_special_tokens=False)
            generate_kwargs["eos_token_id"] = stop_sequence_ids

        forward_params = generate_kwargs
        if self.assistant_model is not None:
            forward_params["assistant_model"] = self.assistant_model
        if self.assistant_tokenizer is not None:
            forward_params["tokenizer"] = self.tokenizer
            forward_params["assistant_tokenizer"] = self.assistant_tokenizer

        # postprocess kwargs
        postprocess_params = {}
        if return_full_text is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_full_text`")
            if return_tensors is not None:
                raise ValueError("`return_full_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.FULL_TEXT if return_full_text else ReturnType.NEW_TEXT
        if return_tensors is not None and return_type is None:
            if return_text is not None:
                raise ValueError("`return_text` is mutually exclusive with `return_tensors`")
            return_type = ReturnType.TENSORS
        if return_type is not None:
            postprocess_params["return_type"] = return_type
        if clean_up_tokenization_spaces is not None:
            postprocess_params["clean_up_tokenization_spaces"] = clean_up_tokenization_spaces
        if continue_final_message is not None:
            postprocess_params["continue_final_message"] = continue_final_message
        if skip_special_tokens is not None:
            postprocess_params["skip_special_tokens"] = skip_special_tokens

        return preprocess_params, forward_params, postprocess_params

    # overriding _parse_and_tokenize to allow for unusual language-modeling tokenizer arguments
    def _parse_and_tokenize(self, *args, **kwargs):
        """
        Parse arguments and tokenize
        """
        # Parse arguments
        if self.model.__class__.__name__ == "TransfoXLLMHeadModel":
            kwargs.update({"add_space_before_punct_symbol": True})

        return super()._parse_and_tokenize(*args, **kwargs)

    @overload
    def __call__(self, text_inputs: str, **kwargs: Any) -> list[dict[str, str]]: ...

    @overload
    def __call__(self, text_inputs: list[str], **kwargs: Any) -> list[list[dict[str, str]]]: ...

    @overload
    def __call__(self, text_inputs: ChatType, **kwargs: Any) -> list[dict[str, ChatType]]: ...

    @overload
    def __call__(self, text_inputs: list[ChatType], **kwargs: Any) -> list[list[dict[str, ChatType]]]: ...

    def __call__(self, text_inputs, **kwargs):
        """
        Complete the prompt(s) given as inputs.

        Args:
            text_inputs (`str`, `list[str]`, `list[dict[str, str]]`, or `list[list[dict[str, str]]]`):
                One or several prompts (or one list of prompts) to complete. If strings or a list of string are
                passed, this pipeline will continue each prompt. Alternatively, a "chat", in the form of a list
                of dicts with "role" and "content" keys, can be passed, or a list of such chats. When chats are passed,
                the model's chat template will be used to format them before passing them to the model.
            return_tensors (`bool`, *optional*, defaults to `False`):
                Returns the tensors of predictions (as token indices) in the outputs. If set to
                `True`, the decoded text is not returned.
            return_text (`bool`, *optional*):
                Returns the decoded texts in the outputs.
            return_full_text (`bool`, *optional*, defaults to `True`):
                If set to `False` only added text is returned, otherwise the full text is returned. Cannot be
                specified at the same time as `return_text`.
            clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
                Whether or not to clean up the potential extra spaces in the text output.
            continue_final_message (`bool`, *optional*):
                This indicates that you want the model to continue the last message in the input chat rather than
                starting a new one, allowing you to "prefill" its response. By default this is `True` when the final
                message in the input chat has the `assistant` role and `False` otherwise, but you can manually
                override that behaviour by setting this flag.
            prefix (`str`, *optional*):
                Prefix added to prompt.
            handle_long_generation (`str`, *optional*):
                By default, this pipeline does not handle long generation (ones that exceed in one form or the other
                the model maximum length). There is no perfect way to address this (more info
                :https://github.com/huggingface/transformers/issues/14033#issuecomment-948385227). This provides common
                strategies to work around that problem depending on your use case.

                - `None` : default strategy where nothing in particular happens
                - `"hole"`: Truncates left of input, and leaves a gap wide enough to let generation happen (might
                  truncate a lot of the prompt and not suitable when generation exceed the model capacity)
            tokenizer_encode_kwargs (`dict`, *optional*):
                Additional keyword arguments to pass along to the encoding step of the tokenizer. If the text input is
                a chat, it is passed to `apply_chat_template`. Otherwise, it is passed to `__call__`.
            generate_kwargs (`dict`, *optional*):
                Additional keyword arguments to pass along to the generate method of the model (see the generate method
                corresponding to your framework [here](./text_generation)).

        Return:
            A list or a list of lists of `dict`: Returns one of the following dictionaries (cannot return a combination
            of both `generated_text` and `generated_token_ids`):

            - **generated_text** (`str`, present when `return_text=True`) -- The generated text.
            - **generated_token_ids** (`torch.Tensor` or `tf.Tensor`, present when `return_tensors=True`) -- The token
              ids of the generated text.
        """
        container_types = (
            (list, tuple, types.GeneratorType, KeyDataset)
            if is_torch_available()
            else (list, tuple, types.GeneratorType)
        )
        if isinstance(text_inputs, container_types):
            if isinstance(text_inputs, types.GeneratorType):
                # Peek at the first item without losing it: `tee` yields two independent iterators, and re-wrapping
                # the untouched one in a generator expression keeps `text_inputs` a `GeneratorType` downstream.
                text_inputs, peek = itertools.tee(text_inputs)
                text_inputs = (x for x in text_inputs)
                missing = object()
                first_item = next(peek, missing)
                if first_item is missing:
                    # Empty generator: nothing to inspect, let the base class decide what to return
                    return super().__call__(text_inputs, **kwargs)
            else:
                if len(text_inputs) == 0:
                    # Empty batch: nothing to inspect, let the base class decide what to return
                    return super().__call__(text_inputs, **kwargs)
                first_item = text_inputs[0]

            # A list-of-dicts (or list of those) is chat mode rather than a batch of independent prompts
            if isinstance(first_item, dict):
                return super().__call__(Chat(text_inputs), **kwargs)
            if isinstance(first_item, (list, tuple)):
                chats = (Chat(chat) for chat in text_inputs)  # wrap each conversation lazily
                if isinstance(text_inputs, types.GeneratorType):
                    return super().__call__(chats, **kwargs)
                return super().__call__(list(chats), **kwargs)
        return super().__call__(text_inputs, **kwargs)

    def preprocess(
        self,
        prompt_text,
        prefix="",
        handle_long_generation=None,
        add_special_tokens=None,
        truncation=None,
        padding=None,
        max_length=None,
        continue_final_message=None,
        tokenizer_encode_kwargs=None,
        **generate_kwargs,
    ):
        """Tokenize one prompt (string, with `prefix` prepended) or one `Chat` (via the chat template).

        The returned encoding also carries the original `prompt_text` under the "prompt_text" key. With
        `handle_long_generation="hole"`, the input is left-truncated so that prompt + new tokens fit within
        `tokenizer.model_max_length`.

        Raises:
            ValueError: With the "hole" strategy, when the number of new tokens cannot be inferred (negative) or
                when it alone exceeds the model's maximum length.
        """
        # Only set non-None tokenizer kwargs, so as to rely on the tokenizer's defaults
        tokenizer_kwargs = {
            "add_special_tokens": add_special_tokens,
            "truncation": truncation,
            "padding": padding,
            "max_length": max_length,  # NOTE: `max_length` is also a `generate` arg. Use `tokenizer_encode_kwargs` to avoid a name clash
        }
        tokenizer_kwargs = {key: value for key, value in tokenizer_kwargs.items() if value is not None}
        tokenizer_kwargs.update(tokenizer_encode_kwargs or {})

        if isinstance(prompt_text, Chat):
            tokenizer_kwargs.pop("add_special_tokens", None)  # ignore add_special_tokens on chats
            # If the user passes a chat that ends in an assistant message, we treat it as a prefill by default
            # because very few models support multiple separate, consecutive assistant messages
            if continue_final_message is None:
                continue_final_message = prompt_text.messages[-1]["role"] == "assistant"
            inputs = self.tokenizer.apply_chat_template(
                prompt_text.messages,
                add_generation_prompt=not continue_final_message,
                continue_final_message=continue_final_message,
                return_dict=True,
                return_tensors=self.framework,
                **tokenizer_kwargs,
            )
        else:
            inputs = self.tokenizer(prefix + prompt_text, return_tensors=self.framework, **tokenizer_kwargs)

        inputs["prompt_text"] = prompt_text

        if handle_long_generation == "hole":
            cur_len = inputs["input_ids"].shape[-1]
            # The config `_forward` will actually use: a call-time `generation_config` wins over the pipeline's.
            generation_config = generate_kwargs.get("generation_config", self.generation_config)
            # Explicit call-time limits first, then the config's `max_new_tokens`, then its `max_length`.
            # `max_length` is a named parameter of this method, so it never appears in `generate_kwargs`.
            if "max_new_tokens" in generate_kwargs:
                new_tokens = generate_kwargs["max_new_tokens"]
            elif max_length is None and generation_config.max_new_tokens is not None:
                new_tokens = generation_config.max_new_tokens
            else:
                target_length = max_length if max_length is not None else generation_config.max_length
                new_tokens = target_length - cur_len
            if new_tokens < 0:
                raise ValueError("We cannot infer how many new tokens are expected")
            if cur_len + new_tokens > self.tokenizer.model_max_length:
                keep_length = self.tokenizer.model_max_length - new_tokens
                if keep_length <= 0:
                    raise ValueError(
                        "We cannot use `hole` to handle this generation the number of desired tokens exceeds the"
                        " models max length"
                    )

                # Keep only the rightmost `keep_length` tokens of the prompt
                inputs["input_ids"] = inputs["input_ids"][:, -keep_length:]
                if "attention_mask" in inputs:
                    inputs["attention_mask"] = inputs["attention_mask"][:, -keep_length:]

        return inputs

    def _forward(self, model_inputs, **generate_kwargs):
        """Run `model.generate` on one preprocessed batch.

        Returns:
            A dict with "generated_sequence" reshaped to `(in_b, out_b // in_b, ...)`, the "input_ids" that were fed
            to `generate` (`None` for an empty prompt), the "prompt_text", and — when `generate` returned a
            `ModelOutput` with extra fields — those fields under "additional_outputs".
        """
        input_ids = model_inputs["input_ids"]
        attention_mask = model_inputs.get("attention_mask", None)
        # Allow empty prompts
        if input_ids.shape[1] == 0:
            input_ids = None
            attention_mask = None
            in_b = 1
        else:
            in_b = input_ids.shape[0]
        prompt_text = model_inputs.pop("prompt_text")

        # If there is a prefix, we may need to adjust the generation length. `generate_kwargs` is a fresh dict for
        # this call, so the pipeline's stored forward params are not modified.
        prefix_length = generate_kwargs.pop("prefix_length", 0)
        if prefix_length > 0:
            self._extend_lengths_for_prefix(generate_kwargs, prefix_length)

        # User-defined `generation_config` passed to the pipeline call take precedence
        if "generation_config" not in generate_kwargs:
            generate_kwargs["generation_config"] = self.generation_config

        output = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs)

        if isinstance(output, ModelOutput):
            generated_sequence = output.sequences
            other_outputs = {k: v for k, v in output.items() if k not in {"sequences", "past_key_values"}}
            self._regroup_additional_outputs(other_outputs, in_b, generated_sequence.shape[0])
        else:
            generated_sequence = output
            other_outputs = {}

        # Group the returned sequences by input prompt: (in_b, num_return_sequences, seq_len)
        out_b = generated_sequence.shape[0]
        if self.framework == "pt":
            generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:])
        elif self.framework == "tf":
            generated_sequence = tf.reshape(generated_sequence, (in_b, out_b // in_b, *generated_sequence.shape[1:]))

        model_outputs = {
            "generated_sequence": generated_sequence,
            "input_ids": input_ids,
            "prompt_text": prompt_text,
        }
        if other_outputs:
            model_outputs.update({"additional_outputs": other_outputs})
        return model_outputs

    def _extend_lengths_for_prefix(self, generate_kwargs, prefix_length):
        """Grow absolute length limits in `generate_kwargs` (in place) by `prefix_length` prefix tokens.

        Limits expressed as new-token counts (`max_new_tokens` / `min_new_tokens`, from the call or from the config
        that will be used) already exclude the prompt, so in that case the absolute limit is left alone.
        """
        # The config that will be handed to `generate`: call-time one if provided, else the pipeline's.
        generation_config = generate_kwargs.get("generation_config", self.generation_config)

        has_max_new_tokens = "max_new_tokens" in generate_kwargs or generation_config.max_new_tokens is not None
        if not has_max_new_tokens:
            max_length = generate_kwargs.get("max_length") or generation_config.max_length
            generate_kwargs["max_length"] = max_length + prefix_length

        has_min_new_tokens = "min_new_tokens" in generate_kwargs or generation_config.min_new_tokens is not None
        if not has_min_new_tokens and "min_length" in generate_kwargs:
            generate_kwargs["min_length"] += prefix_length

    @staticmethod
    def _is_batched_tensor_tuple(value, tensor_type, out_b):
        """Whether `value` is a non-empty tuple of `tensor_type` whose first element has leading dimension `out_b`."""
        return (
            isinstance(value, tuple)
            and len(value) > 0
            and isinstance(value[0], tensor_type)
            and len(value[0].shape) > 0
            and value[0].shape[0] == out_b
        )

    def _regroup_additional_outputs(self, other_outputs, in_b, out_b):
        """Reshape extra `generate` outputs in place.

        - Tensors whose leading dimension is `out_b` become `(in_b, out_b // in_b, ...)`.
        - Non-empty tuples of tensors whose leading dimension is `out_b` are stacked and their first two axes
          swapped, giving `(out_b, len(tuple), ...)`.

        Anything else (e.g. nested tuples) is left untouched.
        """
        for key, value in other_outputs.items():
            if self.framework == "pt":
                if isinstance(value, torch.Tensor) and value.ndim > 0 and value.shape[0] == out_b:
                    other_outputs[key] = value.reshape(in_b, out_b // in_b, *value.shape[1:])
                elif self._is_batched_tensor_tuple(value, torch.Tensor, out_b):
                    other_outputs[key] = torch.stack(value).swapaxes(0, 1)
            elif self.framework == "tf":
                if isinstance(value, tf.Tensor) and len(value.shape) > 0 and value.shape[0] == out_b:
                    other_outputs[key] = tf.reshape(value, (in_b, out_b // in_b, *value.shape[1:]))
                elif self._is_batched_tensor_tuple(value, tf.Tensor, out_b):
                    # `tf.Tensor` has no `swapaxes` method: swap the first two axes with an explicit permutation
                    stacked = tf.stack(value)
                    perm = [1, 0, *range(2, len(stacked.shape))]
                    other_outputs[key] = tf.transpose(stacked, perm=perm)

    def postprocess(
        self,
        model_outputs,
        return_type=ReturnType.FULL_TEXT,
        clean_up_tokenization_spaces=True,
        continue_final_message=None,
        skip_special_tokens=None,
    ):
        """Turn the generated token ids of one input into a list of records (one per returned sequence).

        Each record has either "generated_token_ids" (`ReturnType.TENSORS`) or "generated_text" (decoded; for
        `ReturnType.FULL_TEXT` the prompt — or, in chat mode, the chat — is re-attached). Additional outputs whose
        leading dimension matches the number of sequences are split and attached per record.

        Raises:
            ValueError: If `return_type` is not a `ReturnType` member.
        """
        generated_sequence = model_outputs["generated_sequence"][0]
        input_ids = model_outputs["input_ids"]
        prompt_text = model_outputs["prompt_text"]
        generated_sequence = generated_sequence.numpy().tolist()
        num_sequences = len(generated_sequence)

        # Split per-sequence additional outputs. `Tensor.tolist()` also handles dtypes NumPy lacks (e.g. bfloat16).
        split_keys = {}
        for key, value in model_outputs.get("additional_outputs", {}).items():
            if self.framework == "pt":
                if isinstance(value, torch.Tensor) and value.ndim > 0 and value.shape[0] == num_sequences:
                    split_keys[key] = value.tolist()
            elif self.framework == "tf":
                if isinstance(value, tf.Tensor) and len(value.shape) > 0 and value.shape[0] == num_sequences:
                    split_keys[key] = value.numpy().tolist()

        skip_special_tokens = True if skip_special_tokens is None else skip_special_tokens
        decode_kwargs = {
            "skip_special_tokens": skip_special_tokens,
            "clean_up_tokenization_spaces": clean_up_tokenization_spaces,
        }
        returns_text = return_type in {ReturnType.NEW_TEXT, ReturnType.FULL_TEXT}

        # Character length of the decoded model input (which includes any `prefix` prepended in `preprocess`);
        # it is the same for every returned sequence, so decode it once.
        prompt_length = 0
        if returns_text and input_ids is not None:
            prompt_length = len(self.tokenizer.decode(input_ids[0], **decode_kwargs))

        if return_type == ReturnType.FULL_TEXT and isinstance(prompt_text, Chat) and continue_final_message is None:
            # If the user passes a chat ending in an assistant message, we treat it as a prefill by
            # default because very few models support multiple separate, consecutive assistant messages
            continue_final_message = prompt_text.messages[-1]["role"] == "assistant"

        records = []
        for idx, sequence in enumerate(generated_sequence):
            if return_type == ReturnType.TENSORS:
                record = {"generated_token_ids": sequence}
            elif returns_text:
                text = self.tokenizer.decode(sequence, **decode_kwargs)
                # Strip the decoded model input (prompt plus any prefix) so only newly generated text remains
                all_text = text[prompt_length:]
                if return_type == ReturnType.FULL_TEXT:
                    if isinstance(prompt_text, str):
                        all_text = prompt_text + all_text
                    elif isinstance(prompt_text, Chat):
                        if continue_final_message:
                            # With assistant prefill, concat onto the end of the last message
                            all_text = list(prompt_text.messages)[:-1] + [
                                {
                                    "role": prompt_text.messages[-1]["role"],
                                    "content": prompt_text.messages[-1]["content"] + all_text,
                                }
                            ]
                        else:
                            # When we're not starting from a prefill, the output is a new assistant message
                            all_text = list(prompt_text.messages) + [{"role": "assistant", "content": all_text}]
                record = {"generated_text": all_text}
            else:
                raise ValueError(f"Unsupported `return_type`: {return_type!r}. Expected a `ReturnType` member.")
            for key, values in split_keys.items():
                record[key] = values[idx]
            records.append(record)

        return records
|