| | """ Conditional text generation with the auto-regressive models of the library (GPT/GPT-2/CTRL/Transformer-XL/XLNet) |
| | """ |

import argparse
import logging

import numpy as np
import torch

from transformers import (
    CTRLLMHeadModel,
    CTRLTokenizer,
    GPT2LMHeadModel,
    GPT2Tokenizer,
    OpenAIGPTLMHeadModel,
    OpenAIGPTTokenizer,
    TransfoXLLMHeadModel,
    TransfoXLTokenizer,
    XLMTokenizer,
    XLMWithLMHeadModel,
    XLNetLMHeadModel,
    XLNetTokenizer,
)


logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
)
logger = logging.getLogger(__name__)

MAX_LENGTH = 10000  # Hardcoded fallback cap to avoid an effectively unbounded generation length

MODEL_CLASSES = {
    "gpt2": (GPT2LMHeadModel, GPT2Tokenizer),
    "ctrl": (CTRLLMHeadModel, CTRLTokenizer),
    "openai-gpt": (OpenAIGPTLMHeadModel, OpenAIGPTTokenizer),
    "xlnet": (XLNetLMHeadModel, XLNetTokenizer),
    "transfo-xl": (TransfoXLLMHeadModel, TransfoXLTokenizer),
    "xlm": (XLMWithLMHeadModel, XLMTokenizer),
}
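
# Supporting another architecture (hypothetical names below) is a matter of
# registering its model/tokenizer pair here and, if it needs special prompt
# handling, adding an entry to PREPROCESSING_FUNCTIONS further down:
#   MODEL_CLASSES["my-model"] = (MyLMHeadModel, MyTokenizer)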

# Padding text to help Transformer-XL and XLNet with short prompts, as proposed
# by Aman Rusia in https://github.com/rusiaaman/XLNet-gen#methodology
PADDING_TEXT = """In 1991, the remains of Russian Tsar Nicholas II and his family
(except for Alexei and Maria) are discovered.
The voice of Nicholas's young son, Tsarevich Alexei Nikolaevich, narrates the
remainder of the story. 1883 Western Siberia,
a young Grigori Rasputin is asked by his father and a group of men to perform magic.
Rasputin has a vision and denounces one of the men as a horse thief. Although his
father initially slaps him for making such an accusation, Rasputin watches as the
man is chased outside and beaten. Twenty years later, Rasputin sees a vision of
the Virgin Mary, prompting him to become a priest. Rasputin quickly becomes famous,
with people, even a bishop, begging for his blessing. <eod> </s> <eos>"""


def set_seed(args):
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    if args.n_gpu > 0:
        torch.cuda.manual_seed_all(args.seed)


#
# Functions to prepare models' input
#


def prepare_ctrl_input(args, _, tokenizer, prompt_text):
    if args.temperature > 0.7:
        logger.info("CTRL typically works better with lower temperatures (and lower top_k).")

    encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=False)
    if not any(encoded_prompt[0] == x for x in tokenizer.control_codes.values()):
        logger.info("WARNING! You are not starting your generation from a control code, so you won't get good results.")
    return prompt_text
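
# Illustrative CTRL prompt: generation should start from one of the control
# codes the tokenizer knows, e.g. "Links", "Books", or "Wikipedia":
#   --prompt="Links A new study shows that" --temperature=0.2 --repetition_penalty=1.2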


def prepare_xlm_input(args, model, tokenizer, prompt_text):
    # Set the language for checkpoints that use language embeddings, either
    # from the --xlm_language flag or interactively
    use_lang_emb = hasattr(model.config, "use_lang_emb") and model.config.use_lang_emb
    if hasattr(model.config, "lang2id") and use_lang_emb:
        available_languages = model.config.lang2id.keys()
        if args.xlm_language in available_languages:
            language = args.xlm_language
        else:
            language = None
            while language not in available_languages:
                language = input("Using XLM. Select language in " + str(list(available_languages)) + " >>> ")

        model.config.lang_id = model.config.lang2id[language]

    return prompt_text


def prepare_xlnet_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text


def prepare_transfoxl_input(args, _, tokenizer, prompt_text):
    prompt_text = (args.padding_text if args.padding_text else PADDING_TEXT) + prompt_text
    return prompt_text


PREPROCESSING_FUNCTIONS = {
    "ctrl": prepare_ctrl_input,
    "xlm": prepare_xlm_input,
    "xlnet": prepare_xlnet_input,
    "transfo-xl": prepare_transfoxl_input,
}


def adjust_length_to_model(length, max_sequence_length):
    if length < 0 and max_sequence_length > 0:
        length = max_sequence_length
    elif 0 < max_sequence_length < length:
        length = max_sequence_length  # Never generate past the model's maximum sequence length
    elif length < 0:
        length = MAX_LENGTH  # Avoid an effectively infinite loop when the model reports no limit
    return length
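
# Illustrative behavior, assuming a GPT-2 checkpoint whose
# max_position_embeddings is 1024:
#   adjust_length_to_model(-1, 1024)   -> 1024        (no length requested)
#   adjust_length_to_model(2000, 1024) -> 1024        (clamped to the model limit)
#   adjust_length_to_model(-1, 0)      -> MAX_LENGTH  (model reports no limit)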


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_type",
        default=None,
        type=str,
        required=True,
        help="Model type selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )
    parser.add_argument(
        "--model_name_or_path",
        default=None,
        type=str,
        required=True,
        help="Path to pre-trained model or shortcut name selected in the list: " + ", ".join(MODEL_CLASSES.keys()),
    )

    parser.add_argument("--prompt", type=str, default="")
    parser.add_argument("--length", type=int, default=20)
    parser.add_argument("--stop_token", type=str, default=None, help="Token at which text generation is stopped")

    parser.add_argument(
        "--temperature",
        type=float,
        default=1.0,
        help="temperature of 1.0 has no effect, lower values tend toward greedy sampling",
    )
    parser.add_argument(
        "--repetition_penalty", type=float, default=1.0, help="primarily useful for CTRL model; in that case, use 1.2"
    )
    parser.add_argument("--k", type=int, default=0, help="top-k filtering; 0 disables it")
    parser.add_argument("--p", type=float, default=0.9, help="nucleus (top-p) sampling threshold")

    parser.add_argument("--padding_text", type=str, default="", help="Padding text for Transfo-XL and XLNet.")
    parser.add_argument("--xlm_language", type=str, default="", help="Optional language when used with the XLM model.")

    parser.add_argument("--seed", type=int, default=42, help="random seed for initialization")
    parser.add_argument("--no_cuda", action="store_true", help="Avoid using CUDA when available")
    parser.add_argument("--num_return_sequences", type=int, default=1, help="The number of samples to generate.")
    args = parser.parse_args()

    args.device = torch.device("cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    args.n_gpu = 0 if args.no_cuda else torch.cuda.device_count()

    set_seed(args)

    # Initialize the model and tokenizer
    try:
        args.model_type = args.model_type.lower()
        model_class, tokenizer_class = MODEL_CLASSES[args.model_type]
    except KeyError:
        raise KeyError(
            "the model type {} you specified is not supported. You are welcome to add it and open a PR :)".format(
                args.model_type
            )
        )

    tokenizer = tokenizer_class.from_pretrained(args.model_name_or_path)
    model = model_class.from_pretrained(args.model_name_or_path)
    model.to(args.device)

    args.length = adjust_length_to_model(args.length, max_sequence_length=model.config.max_position_embeddings)
    logger.info(args)

    prompt_text = args.prompt if args.prompt else input("Model prompt >>> ")

    # Different models need different input formatting and/or extra arguments
    requires_preprocessing = args.model_type in PREPROCESSING_FUNCTIONS.keys()
    if requires_preprocessing:
        prepare_input = PREPROCESSING_FUNCTIONS.get(args.model_type)
        preprocessed_prompt_text = prepare_input(args, model, tokenizer, prompt_text)
        encoded_prompt = tokenizer.encode(
            preprocessed_prompt_text, add_special_tokens=False, return_tensors="pt", add_space_before_punct_symbol=True
        )
    else:
        encoded_prompt = tokenizer.encode(prompt_text, add_special_tokens=True, return_tensors="pt")
    encoded_prompt = encoded_prompt.to(args.device)

    if encoded_prompt.size()[-1] == 0:
        input_ids = None
    else:
        input_ids = encoded_prompt

    output_sequences = model.generate(
        input_ids=input_ids,
        max_length=args.length + len(encoded_prompt[0]),
        temperature=args.temperature,
        top_k=args.k,
        top_p=args.p,
        repetition_penalty=args.repetition_penalty,
        do_sample=True,
        num_return_sequences=args.num_return_sequences,
    )
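
    # Note on the sampling defaults: with do_sample=True, top_k=0 disables top-k
    # filtering, so the default top_p=0.9 alone performs nucleus sampling
    # (drawing from the smallest token set whose cumulative probability
    # exceeds 0.9).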

    # Remove the batch dimension when returning multiple sequences
    if len(output_sequences.shape) > 2:
        output_sequences.squeeze_()

    generated_sequences = []

    for generated_sequence_idx, generated_sequence in enumerate(output_sequences):
        print("=== GENERATED SEQUENCE {} ===".format(generated_sequence_idx + 1))
        generated_sequence = generated_sequence.tolist()

        # Decode text
        text = tokenizer.decode(generated_sequence, clean_up_tokenization_spaces=True)

        # Remove all text after the stop token
        text = text[: text.find(args.stop_token) if args.stop_token else None]

        # Add the prompt at the beginning of the sequence and remove the excess
        # text that was used only for pre-processing
        total_sequence = (
            prompt_text + text[len(tokenizer.decode(encoded_prompt[0], clean_up_tokenization_spaces=True)) :]
        )

        generated_sequences.append(total_sequence)
        print(total_sequence)

    return generated_sequences


if __name__ == "__main__":
    main()