| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
"""Conditional text generation with the auto-regressive models of the library.

Supported ``--model_type`` values are the keys of ``MODEL_CLASSES``: GPT-2, CTRL,
OpenAI GPT, XLNet, XLM, GPT-J, BLOOM, LLaMA and OPT.
"""
|
|
| import argparse |
| import inspect |
| import logging |
|
|
| import torch |
| from accelerate import PartialState |
| from accelerate.utils import set_seed |
|
|
| from transformers import ( |
| AutoTokenizer, |
| BloomForCausalLM, |
| CTRLLMHeadModel, |
| CTRLTokenizer, |
| GenerationMixin, |
| GPT2LMHeadModel, |
| GPT2Tokenizer, |
| GPTJForCausalLM, |
| LlamaForCausalLM, |
| OpenAIGPTLMHeadModel, |
| OpenAIGPTTokenizer, |
| OPTForCausalLM, |
| XLMTokenizer, |
| XLMWithLMHeadModel, |
| XLNetLMHeadModel, |
| XLNetTokenizer, |
| ) |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
|
|
# Script-wide logging: timestamped records at INFO level and above.
logging.basicConfig(
    level=logging.INFO,
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)
# Module-level logger shared by the functions in this script.
logger = logging.getLogger(__name__)
|
|
# Fallback cap on generated tokens, used by ``adjust_length_to_model`` when a negative
# length is requested and the model exposes no positive maximum sequence length.
MAX_LENGTH = 10_000
|
|
# Registry mapping each supported ``--model_type`` to its (model class, tokenizer class) pair.
# Key order matters: the keys are joined, in order, into the ``--model_type`` and
# ``--model_name_or_path`` help texts.
MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
    # These three resolve their tokenizer class from the checkpoint via AutoTokenizer.
    "gptj": (GPTJForCausalLM, AutoTokenizer),
    "bloom": (BloomForCausalLM, AutoTokenizer),
    "llama": (LlamaForCausalLM, AutoTokenizer),
    # OPT reuses the GPT-2 tokenizer class.
    "opt": (OPTForCausalLM, GPT2Tokenizer),
}
|
|
| |
| |
| |
# Default context prepended to the prompt by the XLNet and Transformer-XL preprocessors
# when neither ``--prefix`` nor the deprecated ``--padding_text`` is set.
# NOTE(review): presumably this gives those models more context for short prompts; the
# trailing ``<eod> </s> <eos>`` markers look like end-of-document separators — confirm
# which marker each tokenizer actually recognizes.
PREFIX = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""
|
|
|
|
| |
| |
| |
|
|
|
|
def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    """Sanity-check a CTRL prompt and return it unchanged.

    Emits advisory log messages only: one when the sampling temperature is high, and
    a warning when the prompt's first token is not one of ``tokenizer.control_codes``
    (including the case of an empty prompt).

    Args:
        args: Parsed CLI namespace; only ``args.temperature`` is read.
        _: The model (unused; kept so all preprocessors share one signature).
        tokenizer: Tokenizer exposing ``encode`` and a ``control_codes`` mapping.
        prompt_text: Raw user prompt.

    Returns:
        ``prompt_text``, unchanged.
    """
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    # An empty prompt has no first token: treat it as "no control code" instead of
    # raising IndexError. The message says WARNING!, so log it at warning level.
    if not encoded_prompt or encoded_prompt[0] not in tokenizer.control_codes.values():
        logger.warning(
            "WARNING! You are not starting your generation from a control code so you won't get good results"
        )
    return prompt_text
|
|
|
|
def prepare_xlm_input(args, model, tokenizer, prompt_text):
    """Pick the XLM language id (when the model needs one) and return the prompt untouched.

    A language is needed only when ``model.config`` has a truthy ``use_lang_emb`` and
    defines ``lang2id``. ``--xlm_language`` is used if it names a known language;
    otherwise the user is prompted on stdin until a valid name is typed. The chosen id
    is written to ``model.config.lang_id``.
    """
    config = model.config
    use_lang_emb = getattr(config, "use_lang_emb", False)
    if not use_lang_emb or not hasattr(config, "lang2id"):
        return prompt_text

    known_languages = config.lang2id.keys()
    language = args.xlm_language if args.xlm_language in known_languages else None
    while language not in known_languages:
        language = input("Using XLM. Select language in " + str(list(known_languages)) + " >>> ")

    config.lang_id = config.lang2id[language]
    return prompt_text
|
|
|
|
def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    """Return the XLNet prompt with a context prefix prepended.

    The prefix is ``--prefix`` if non-empty, otherwise the deprecated ``--padding_text``
    if non-empty, otherwise the module-level ``PREFIX`` passage.
    """
    return (args.prefix or args.padding_text or PREFIX) + prompt_text
|
|
|
|
def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    """Return the Transformer-XL prompt with a context prefix prepended.

    Precedence: a non-empty ``--prefix``, then a non-empty (deprecated)
    ``--padding_text``, then the module-level ``PREFIX`` passage.
    """
    if args.prefix:
        prefix = args.prefix
    elif args.padding_text:
        prefix = args.padding_text
    else:
        prefix = PREFIX
    return prefix + prompt_text
|
|
|
|
# Optional prompt preprocessors keyed by ``--model_type``. Each is called as
# ``fn(args, model, tokenizer, prompt_text)`` and returns the text to tokenize.
# NOTE(review): "transfo-xl" has no entry in MODEL_CLASSES, so the model-type lookup
# raises KeyError before this entry can ever be used.
PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}
|
|
|
|
def adjust_length_to_model(length, max_sequence_length):
    """Clamp a requested generation length to what the model supports.

    Args:
        length: Requested number of new tokens; a negative value asks for the maximum.
        max_sequence_length: The model's maximum sequence length, or a non-positive
            value when unknown.

    Returns:
        ``max_sequence_length`` when it is positive and ``length`` is negative or larger
        than it; ``MAX_LENGTH`` when ``length`` is negative and no positive maximum is
        known; otherwise ``length`` unchanged.
    """
    has_model_limit = max_sequence_length > 0
    if has_model_limit and (length < 0 or length > max_sequence_length):
        return max_sequence_length
    if length < 0:
        # No usable model limit: fall back to the fixed script-wide cap.
        return MAX_LENGTH
    return length
|
|
|
|
def sparse_model_config(model_config):
    """Extract layer count, head count and per-head size from a model config.

    Architectures name these attributes differently, so each value is read from the
    first alias that is *present* on the config (presence, not truthiness, decides).

    Returns:
        ``(num_layer, num_head, num_embedding_size_per_head)``.

    Raises:
        ValueError: if the hidden size or head count cannot be found, if the head count
            is zero, or if no layer-count attribute exists.
    """

    def first_present(*names):
        # Return the value of the first attribute that exists, else None.
        for name in names:
            if hasattr(model_config, name):
                return getattr(model_config, name)
        return None

    hidden_size = first_present("hidden_size", "n_embed", "n_embd")
    head_count = first_present("num_attention_heads", "n_head")
    if hidden_size is None or head_count is None or head_count == 0:
        raise ValueError("Check the model config")

    head_dim = int(hidden_size / head_count)

    if hasattr(model_config, "n_layer"):
        layer_count = model_config.n_layer
    elif hasattr(model_config, "num_hidden_layers"):
        layer_count = model_config.num_hidden_layers
    else:
        raise ValueError("Number of hidden layers couldn't be determined from the model config")

    return layer_count, head_count, head_dim
|
|
|
|
def generate_past_key_values(model, batch_size, seq_len):
    """Build an uninitialised per-layer ``(key, value)`` cache of length ``seq_len``.

    Tensors are created directly with ``model.dtype`` on ``model.device``. The previous
    ``torch.empty(...).to(dtype).to(device)`` first allocated a float32 CPU tensor, then
    converted and copied it. Contents are uninitialised; callers presumably rely only on
    the shapes (they pass ``seq_len`` 0 or 1) — confirm before reading the values.

    Layouts:
        * ``model.config.model_type == "bloom"``: first tensor
          ``(heads * batch, head_dim, seq_len)``, second ``(heads * batch, seq_len, head_dim)``.
        * otherwise: both ``(batch, heads, seq_len, head_dim)``.

    Returns:
        A tuple with one pair of tensors per layer (fresh tensors for every layer).
    """
    num_block_layers, num_attention_heads, num_embedding_size_per_head = sparse_model_config(model.config)
    if model.config.model_type == "bloom":
        fused_batch = int(num_attention_heads * batch_size)
        key_shape = (fused_batch, num_embedding_size_per_head, seq_len)
        value_shape = (fused_batch, seq_len, num_embedding_size_per_head)
    else:
        key_shape = value_shape = (batch_size, num_attention_heads, seq_len, num_embedding_size_per_head)

    def _empty(shape):
        # Allocate in place with the target dtype/device: no intermediate CPU tensor.
        return torch.empty(shape, dtype=model.dtype, device=model.device)

    return tuple((_empty(key_shape), _empty(value_shape)) for _ in range(num_block_layers))
|
|
|
|
def prepare_jit_inputs(inputs, model, tokenizer):
    """Tokenize sample texts into example inputs for ``torch.jit.trace``.

    When ``model.config.use_cache`` is truthy, a dummy one-position ``past_key_values``
    is added. The attention mask is then extended on the left with a zero column so its
    length covers that extra past position.

    Args:
        inputs: List of sample strings, one per batch row.
        model: Model whose config, dtype and device shape the dummy cache.
        tokenizer: Tokenizer providing ``batch_encode_plus``.

    Returns:
        The tokenizer's encoding moved to ``model.device``, extended as described above.
    """
    batch_size = len(inputs)
    dummy_input = tokenizer.batch_encode_plus(inputs, return_tensors="pt").to(model.device)
    if model.config.use_cache:
        dummy_input["past_key_values"] = generate_past_key_values(model, batch_size, 1)
        attention_mask = dummy_input["attention_mask"]
        # Build the padding column directly with the mask's dtype on the model's device
        # instead of allocating float32 on the CPU and converting/copying afterwards.
        past_column = torch.zeros(attention_mask.shape[0], 1, dtype=attention_mask.dtype, device=model.device)
        dummy_input["attention_mask"] = torch.cat([past_column, attention_mask], dim=-1)
    return dummy_input
|
|
|
|
class _ModelFallbackWrapper(GenerationMixin):
    """Send model calls to a TorchScript-traced module; delegate everything else.

    Calling the wrapper runs ``_optimized`` (the traced module) and repackages its tuple
    output. Attributes not found on the wrapper are read from ``_default`` (the eager
    model), and the generation hooks below forward to it explicitly.
    """

    __slots__ = ("_optimized", "_default")

    def __init__(self, optimized, default):
        self._optimized = optimized
        self._default = default

    def __call__(self, *args, **kwargs):
        """Run the traced model on keyword inputs and return a ``CausalLMOutputWithPast``.

        Positional ``args`` are ignored. ``position_ids`` and every ``None`` or boolean
        kwarg are dropped, presumably because the traced graph only accepts the tensor
        inputs it was traced with. When caching is enabled and no past is supplied, an
        empty (length-0) past is generated so the traced signature is satisfied.
        """
        # ``.get``: the key may be missing entirely, not only set to None.
        if kwargs.get("past_key_values") is None and self._default.config.use_cache:
            kwargs["past_key_values"] = generate_past_key_values(
                self._default, kwargs["input_ids"].shape[0], 0
            )
        kwargs.pop("position_ids", None)
        kwargs = {key: value for key, value in kwargs.items() if value is not None and not isinstance(value, bool)}
        outputs = self._optimized(**kwargs)
        lm_logits = outputs[0]
        # Without a cache the traced model may return only the logits.
        past_key_values = outputs[1] if len(outputs) > 1 else None
        return CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=past_key_values,
            hidden_states=None,
            attentions=None,
        )

    def __getattr__(self, item):
        """Delegate attributes not found on the wrapper to the eager model."""
        # __getattr__ runs only after normal lookup fails. If a slot is unset (e.g. an
        # instance created via __new__ by copy/pickle), ``self._default`` would re-enter
        # this method forever; fail fast with AttributeError instead.
        if item in ("_optimized", "_default"):
            raise AttributeError(item)
        return getattr(self._default, item)

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, inputs_embeds=None, use_cache=None, **kwargs
    ):
        """Forward to the eager model's ``prepare_inputs_for_generation``."""
        return self._default.prepare_inputs_for_generation(
            input_ids, past_key_values=past_key_values, inputs_embeds=inputs_embeds, use_cache=use_cache, **kwargs
        )

    def _reorder_cache(
        self, past_key_values: tuple[tuple[torch.Tensor]], beam_idx: torch.Tensor
    ) -> tuple[tuple[torch.Tensor]]:
        """
        This function is used to re-order the `past_key_values` cache if [`~PretrainedModel.beam_search`] or
        [`~PretrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
        beam_idx at every generation step.
        """
        return self._default._reorder_cache(past_key_values, beam_idx)
|
|
|
|
def _parse_args():
    """Define this script's command-line interface and parse ``sys.argv``."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0)
    parser.add_argument("--p", type=float, default=0.9)

    parser.add_argument("--prefix", type=str, default="", help="Text added prior to input.")
    parser.add_argument("--padding_text", type=str, default="", help="Deprecated, the use of `--prefix` is preferred.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument(
        "--use_cpu",
        action="store_true",
        help="Whether or not to use cpu. If set to False, we will use gpu/npu or mps device if available",
    )
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    parser.add_argument(
        "--fp16",
        action="store_true",
        help="Whether to use 16-bit (mixed) precision instead of 32-bit",
    )
    parser.add_argument("--jit", action="store_true", help="Whether or not to use jit trace to accelerate inference")
    return parser.parse_args()


def _truncate_at_stop_token(text, stop_token):
    """Return ``text`` cut just before the first occurrence of ``stop_token``.

    ``text`` is returned unchanged when ``stop_token`` is empty/None or does not occur.
    Slicing directly with ``str.find``'s result would drop the last character in the
    not-found case, because ``find`` returns -1.
    """
    if not stop_token:
        return text
    stop_index = text.find(stop_token)
    return text if stop_index == -1 else text[:stop_index]


def main():
    """Generate continuations for one prompt, print them, and return them.

    Returns:
        A list with one string per returned sequence: the user's prompt followed by
        the generated continuation, cut at ``--stop_token`` when one is given.

    Raises:
        KeyError: if ``--model_type`` is not a key of ``MODEL_CLASSES``.
    """
    args = _parse_args()

    # Device selection is delegated to accelerate; --use_cpu forces the CPU.
    distributed_state = PartialState(cpu=args.use_cpu)

    logger.warning(f"device: {distributed_state.device}, 16-bits inference: {args.fp16}")

    if args.seed is not None:
        set_seed(args.seed)

    # Resolve the model/tokenizer classes for the requested architecture.
    args.model_type = args.model_type.lower()
    try:
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError as err:
        raise KeyError(
            f"the model {args.model_type} you specified is not supported. You are welcome to add it and open a PR :)"
        ) from err

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    if tokenizer.pad_token is None:
        # Reuse EOS as the padding token when the tokenizer defines none.
        tokenizer.pad_token = tokenizer.eos_token
    model = model_class.from_pretrained(args.model_name_or_path)

    model.to(distributed_state.device)
    if args.fp16:
        model.half()

    max_seq_length = getattr(model.config, "max_position_embeddings", 0)
    args.length = adjust_length_to_model(args.length, max_sequence_length=max_seq_length)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Some architectures rewrite or check the prompt before encoding; the others just get
    # the optional --prefix (or deprecated --padding_text) prepended.
    prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
    if prepare_input is not None:
        text_to_encode = prepare_input(args, model, tokenizer, prompt_text)
    else:
        text_to_encode = (args.prefix or args.padding_text) + prompt_text
    encoded_prompt = tokenizer.encode(text_to_encode, add_special_tokens=False, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(distributed_state.device)

    # An empty prompt is passed as None so ``generate`` chooses its own starting input.
    input_ids = None if encoded_prompt.size()[-1] == 0 else encoded_prompt

    if args.jit:
        jit_input_texts = ["enable jit"]
        jit_inputs = prepare_jit_inputs(jit_input_texts, model, tokenizer)
        torch._C._jit_set_texpr_fuser_enabled(False)
        model.config.return_dict = False
        if hasattr(model, "forward"):
            sig = inspect.signature(model.forward)
        else:
            sig = inspect.signature(model.__call__)
        # torch.jit.trace takes positional example inputs: order them by the forward signature.
        jit_inputs = tuple(jit_inputs[key] for key in sig.parameters if jit_inputs.get(key, None) is not None)
        traced_model = torch.jit.trace(model, jit_inputs, strict=False)
        traced_model = torch.jit.freeze(traced_model.eval())
        # Two warm-up runs before generation; presumably lets TorchScript specialise the graph.
        traced_model(*jit_inputs)
        traced_model(*jit_inputs)

        model = _ModelFallbackWrapper(traced_model, model)

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )

    # Remove the batch size dimension when returning multiple sequences.
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    # The decoded (possibly prefixed) prompt is identical for every sequence: decode it once.
    decoded_prompt = tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)

    generated_sequences = []
    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print(f"=== GENERATED SEQUENCE {generated_sequence_idx + 1} ===")
        text = tokenizer.decode(generated_sequence.tolist(), clean_up_tokenization_spaces=True)

        # Strip the encoded prompt (including any prefix) first, then apply the stop token
        # to the continuation only, so a stop token inside the prompt cannot discard it.
        continuation = _truncate_at_stop_token(text[len(decoded_prompt) :], args.stop_token)

        # Re-attach the user's original prompt (without any prefix/padding text).
        total_sequence = prompt_text + continuation

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences
|
|
|
|
# Script entry point: run generation only when executed directly, not when imported.
if __name__ == "__main__":
    main()
|
|