Spaces:
Runtime error
Runtime error
| from transformers import AutoTokenizer, AutoModelForCausalLM | |
| import torch | |
| from collections import defaultdict | |
| import gradio as gr | |
| from optimum.onnxruntime import ORTModelForCausalLM | |
| import itertools | |
| import regex as re | |
| import logging | |
# Special tokens that make up the chat prompt template built by ``format``.
bos_token = "<BOS>"
user_token = "<User>"
bot_token = "<Assistant>"
eos_token = "<EOS>"

# Maximum number of prompt tokens fed to the model; ``response`` keeps only
# the last ``max_context_length`` ids of the encoded prompt.
max_context_length = 750

# Application-wide logging: attach a timestamped stderr handler to the ROOT
# logger (getLogger() with no name) and emit INFO and above.
formatter = logging.Formatter('%(asctime)s %(name)-12s %(levelname)-8s %(message)s')
handler = logging.StreamHandler()
handler.setFormatter(formatter)
logger = logging.getLogger()
logger.addHandler(handler)
logger.setLevel(logging.INFO)
def format(history):
    """Render a flat list of chat turns into the model's prompt string.

    Entries at even positions are user messages, each wrapped as
    ``user_token + message + eos_token``. Entries at odd positions are
    assistant replies, prefixed with ``bot_token`` only. The result starts
    with ``bos_token`` and ends with a trailing ``bot_token`` so the model's
    continuation is the assistant's next reply.

    NOTE: this shadows the builtin ``format``; the name is kept because the
    chat handler calls it.

    Args:
        history: alternating user / assistant messages, user first.

    Returns:
        The assembled prompt string.
    """
    pieces = [bos_token]
    for position, message in enumerate(history):
        is_assistant_turn = position % 2 == 1
        if is_assistant_turn:
            pieces.append(f"{bot_token}{message}")
        else:
            pieces.append(f"{user_token}{message}{eos_token}")
    pieces.append(bot_token)
    return "".join(pieces)
def remove_spaces_between_chinese(text):
    """Remove spurious spaces from decoded model output.

    Presumably this undoes the per-token spacing that the tokenizer's
    ``decode`` inserts (verify against the tokenizer in use). Two kinds of
    space runs are deleted:

    * spaces right after an isolated ASCII letter (one not preceded by
      another letter), when what follows is a letter plus a space, or is the
      last character of a line;
    * spaces right after a Han (CJK) character, whatever follows them.

    ``re`` here is the third-party ``regex`` module, which is needed for the
    ``\\p{Han}`` Unicode property class.
    """
    single_letter_or_han_spaces = r"(?<![a-zA-Z]{2})(?<=[a-zA-Z]{1})[ ]+(?=[a-zA-Z] |.$)|(?<=\p{Han}) +"
    # MULTILINE lets ``$`` in the lookahead match at every line end.
    return re.sub(
        single_letter_or_han_spaces,
        "",
        text,
        count=0,
        flags=re.MULTILINE | re.UNICODE,
    )
def gradio(model, tokenizer):
    """Build the Gradio chat UI around *model* / *tokenizer* and launch it.

    This call blocks: it ends with ``demo.queue().launch()``.

    Args:
        model: a causal LM exposing ``generate`` (the caller passes an
            ``ORTModelForCausalLM``).
        tokenizer: the matching tokenizer, used to encode the prompt and
            decode the generated ids.
    """

    def response(
        user_input,
        chat_history,
        top_k,
        top_p,
        temperature,
        repetition_penalty,
        no_repeat_ngram_size,
    ):
        """Generate the assistant's reply to *user_input*.

        Args:
            user_input: the new user message.
            chat_history: earlier turns as ``[user, assistant]`` pairs
                (assumes Gradio's tuple-style chat history; TODO confirm for
                the installed Gradio version).
            top_k, top_p, temperature, repetition_penalty,
            no_repeat_ngram_size: sampling settings from the sliders.

        Returns:
            The decoded reply text, with the prompt tokens stripped off.
        """
        # Flatten [[u1, b1], [u2, b2], ...] into [u1, b1, u2, b2, ...] so that
        # even indices are user turns, which is the layout ``format`` expects.
        history = list(itertools.chain.from_iterable(chat_history))
        history.append(user_input)
        prompt = format(history)
        # Keep only the most recent max_context_length tokens. The slice cuts
        # from the left, so the oldest part of the conversation is dropped.
        input_ids = tokenizer.encode(
            prompt,
            return_tensors="pt",
            add_special_tokens=False,
        )[:, -max_context_length:]
        prompt_length = input_ids.shape[1]
        output_ids = model.generate(
            input_ids,
            pad_token_id=tokenizer.pad_token_id,
            max_new_tokens=250,
            num_beams=1,  # with cpu
            # Slider values may arrive as floats. The top-k and
            # no-repeat-ngram logits processors in transformers require ints.
            top_k=int(top_k),
            top_p=top_p,
            no_repeat_ngram_size=int(no_repeat_ngram_size),
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            # early_stopping is deliberately omitted. It only affects beam
            # search, and with num_beams=1 it merely triggers a warning.
            do_sample=True,
        )
        # generate() returns prompt + continuation; keep the continuation only.
        output = output_ids[0][prompt_length:]
        generated = remove_spaces_between_chinese(
            tokenizer.decode(
                output,
                skip_special_tokens=True,
                clean_up_tokenization_spaces=True,
            )
        )
        logger.info("%s%s", prompt, generated)
        return generated

    bot = gr.Chatbot(
        show_copy_button=True,
        show_share_button=True,
        # NOTE(review): Gradio treats a str height as a raw CSS value, so a
        # unitless "2000" may be ignored by the browser, whereas an int means
        # pixels. Confirm the intended height before changing it.
        height="2000",
    )
    with gr.Blocks() as demo:
        gr.Markdown("GPT2 chatbot | Powered by nlp-greyfoss")
        with gr.Accordion("Parameters in generation", open=False):
            with gr.Row():
                top_k = gr.Slider(
                    2.0,
                    100.0,
                    label="top_k",
                    step=1,
                    value=50,
                    info="Limit the number of candidate tokens considered during decoding.",
                )
                top_p = gr.Slider(
                    0.1,
                    1.0,
                    label="top_p",
                    value=0.9,
                    info="Control the diversity of the output by selecting tokens with cumulative probabilities up to the Top-P threshold.",
                )
                temperature = gr.Slider(
                    0.1,
                    2.0,
                    label="temperature",
                    value=0.9,
                    info="Control the randomness of the generated text. A higher temperature results in more diverse and unpredictable outputs, while a lower temperature produces more conservative and coherent text.",
                )
                repetition_penalty = gr.Slider(
                    0.1,
                    2.0,
                    label="repetition_penalty",
                    value=1.2,
                    info="Discourage the model from generating repetitive tokens in a sequence.",
                )
                no_repeat_ngram_size = gr.Slider(
                    0,
                    100,
                    label="no_repeat_ngram_size",
                    step=1,
                    value=5,
                    info="Prevent the model from generating sequences of n consecutive tokens that have already been generated in the context. ",
                )
        # NOTE(review): retry_btn / undo_btn / clear_btn are accepted by
        # gr.ChatInterface in Gradio 4.x but were removed in Gradio 5. This
        # code needs a 4.x pin; confirm the Space's requirements.
        gr.ChatInterface(
            response,
            chatbot=bot,
            additional_inputs=[
                top_k,
                top_p,
                temperature,
                repetition_penalty,
                no_repeat_ngram_size,
            ],
            retry_btn="🔄 Regenerate",
            undo_btn="↩️ Remove last turn",
            clear_btn="➕ New conversation",
            # Each example row is the message followed by one value per
            # additional input, in the same order as additional_inputs.
            examples=[
                ["写一篇介绍人工智能的文章。", 30, 0.9, 0.95, 1.2, 5],
                ["给我讲一个笑话。", 50, 0.8, 0.9, 1.2, 6],
                ["Can you describe spring in English?", 50, 0.9, 1.0, 1, 5],
            ],
        )
    demo.queue().launch()
# Hub id of the fine-tuned Chinese GPT-2 chat checkpoint.
MODEL_ID = "greyfoss/gpt2-chatbot-chinese"

# Load the tokenizer, then load the model through Optimum. export=True asks
# Optimum to convert the checkpoint to ONNX at load time. Finally, start the
# (blocking) Gradio app.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = ORTModelForCausalLM.from_pretrained(MODEL_ID, export=True)
gradio(model, tokenizer)