## Generating Chinese poetry by topic
```python
"""Generate a classical Chinese poem conditioned on a topic prompt."""
from transformers import AutoModelWithLMHead, BertTokenizer

import torch

# Load the topic-conditioned poetry model and its tokenizer.
tokenizer = BertTokenizer.from_pretrained("gaochangkuan/model_dir")
model = AutoModelWithLMHead.from_pretrained("gaochangkuan/model_dir")

# Generation hyper-parameters.
prompt = '田园躬耕'       # topic text; '<s>' is added exactly once during encoding below
length = 84               # maximum generated sequence length (tokens)
stop_token = '</s>'       # decoded output is truncated at the first occurrence
temperature = 1.2
repetition_penalty = 1.3
k = 30                    # top-k sampling cutoff
p = 0.95                  # nucleus (top-p) sampling cutoff
seed = 2020
no_cuda = False

# FIX: honor no_cuda and fall back to CPU when CUDA is unavailable
# (the original hard-coded 'cuda' and never consulted either).
device = 'cuda' if torch.cuda.is_available() and not no_cuda else 'cpu'
# FIX: seed was defined but never applied — set it for reproducible sampling.
torch.manual_seed(seed)
# FIX: the model was never moved to the device, so a CUDA input tensor
# paired with a CPU model crashed inside generate().
model.to(device)
model.eval()

prompt_text = prompt if prompt else input("Model prompt >>> ")

# FIX: the original prompt string already began with '<s>' and the encode
# step prepended a second one, feeding a doubled BOS token to the model.
encoded_prompt = tokenizer.encode(
    '<s>' + prompt_text + '<sep>',
    add_special_tokens=False,
    return_tensors="pt",
).to(device)

output_sequences = model.generate(
    input_ids=encoded_prompt,
    max_length=length,
    min_length=10,
    do_sample=True,
    early_stopping=True,
    num_beams=10,
    temperature=temperature,
    top_k=k,
    top_p=p,
    repetition_penalty=repetition_penalty,
    bad_words_ids=None,
    bos_token_id=tokenizer.bos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    eos_token_id=tokenizer.eos_token_id,
    length_penalty=1.2,
    no_repeat_ngram_size=2,
    num_return_sequences=1,
    attention_mask=None,
    decoder_start_token_id=tokenizer.bos_token_id,
)

generated_sequence = output_sequences[0].tolist()
text = tokenizer.decode(generated_sequence)

# FIX: str.find returns -1 when the stop token is absent, which silently
# chopped off the final character; truncate only on an actual match.
if stop_token:
    stop_index = text.find(stop_token)
    if stop_index != -1:
        text = text[:stop_index]

# ''.join(text) on a str was a no-op — strip spaces and special markers directly.
print(text.replace(' ', '').replace('<pad>', '').replace('<s>', ''))
```