| from typing import cast | |
| import gradio as gr | |
| import torch | |
| from transformers import BertTokenizerFast, ErnieForCausalLM | |
def load_model():
    """Fetch the chat tokenizer/model pair from the Hugging Face hub.

    Returns:
        A ``(tokenizer, model)`` tuple: a ``BertTokenizerFast`` and an
        ``ErnieForCausalLM``, both loaded from the same checkpoint.
    """
    checkpoint = "wybxc/new-yiri"
    tokenizer = BertTokenizerFast.from_pretrained(checkpoint)
    # from_pretrained is typed loosely; the asserts narrow the static type
    # and fail fast if the checkpoint resolves to an unexpected class.
    assert isinstance(tokenizer, BertTokenizerFast)
    model = ErnieForCausalLM.from_pretrained(checkpoint)
    assert isinstance(model, ErnieForCausalLM)
    return tokenizer, model
def generate(
    tokenizer: BertTokenizerFast,
    model: ErnieForCausalLM,
    input_str: str,
    alpha: float,
    topk: int,
):
    """Generate a chat reply for *input_str* with contrastive search.

    Args:
        tokenizer: tokenizer used to encode the prompt and decode the output.
        model: causal LM whose ``generate`` produces the continuation.
        input_str: the raw user message.
        alpha: contrastive-search degeneration penalty (``penalty_alpha``).
        topk: number of candidate tokens considered per step (``top_k``).

    Returns:
        The decoded continuation after the first ``[SEP]`` token, with all
        spaces stripped (the model emits space-separated word pieces).
    """
    encoded = tokenizer.encode(input_str, return_tensors="pt")
    encoded = cast(torch.Tensor, encoded)
    sequences = model.generate(
        encoded,
        max_new_tokens=100,
        penalty_alpha=alpha,
        top_k=topk,
        early_stopping=True,
        decoder_start_token_id=tokenizer.sep_token_id,
        eos_token_id=tokenizer.sep_token_id,
    )
    # Locate the first [SEP] in the generated sequence; everything from there
    # on is the model's reply.  (Raises ValueError if no [SEP] is present,
    # matching the original unpacking behavior.)
    first_sep, *_ = torch.nonzero(sequences[0] == tokenizer.sep_token_id)
    reply = tokenizer.decode(sequences[0, first_sep:], skip_special_tokens=True)
    return reply.replace(" ", "")
with gr.Blocks() as demo:
    # ---- layout ---------------------------------------------------------
    # NOTE(review): nesting reconstructed from the standard gradio-3.x
    # chatbot layout (source indentation was lost) — confirm against the
    # original rendering.
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot().style(height=500)
            with gr.Row():
                with gr.Column(scale=4):
                    # .style() is typed loosely; the cast keeps the
                    # Textbox type for the event wiring below.
                    msg = cast(
                        gr.Textbox,
                        gr.Textbox(
                            show_label=False,
                            placeholder="Enter text and press enter",
                        ).style(container=False),
                    )
                with gr.Column(scale=1):
                    button = gr.Button("Generate")
                with gr.Column(scale=1):
                    clear = gr.Button("Clear")
        with gr.Column(scale=1):
            # Contrastive-search knobs passed straight into generate().
            alpha = gr.Slider(0, 1, 0.5, step=0.01, label="Penalty Alpha")
            topk = gr.Slider(1, 50, 5, step=1, label="Top K")

    tokenizer, model = load_model()

    # ---- callbacks ------------------------------------------------------
    def user(user_message: str, history: list[list[str]]):
        """Clear the textbox and append the message with a pending reply."""
        return "", [*history, [user_message, None]]

    def bot(history: list[list[str]], alpha: float, topk: int):
        """Fill in the pending reply slot for the latest user message."""
        prompt = history[-1][0]
        history[-1][1] = generate(tokenizer, model, prompt, alpha=alpha, topk=topk)
        return history

    # Pressing Enter in the textbox and clicking "Generate" run the same
    # two-step pipeline: record the message, then stream in the reply.
    for trigger in (msg.submit, button.click):
        trigger(user, inputs=[msg, chatbot], outputs=[msg, chatbot]).then(
            bot, inputs=[chatbot, alpha, topk], outputs=[chatbot]
        )

    clear.click(lambda: None, None, chatbot)
| if __name__ == "__main__": | |
| demo.queue(concurrency_count=3) | |
| demo.launch() | |