Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| from typing import List, Optional | |
| from transformers import BertTokenizer, BartForConditionalGeneration | |
# Hugging Face Hub id of the fine-tuned checkpoint; used for the page title and
# for loading both the tokenizer and the model so the three can't drift apart.
model_name = "HIT-TMG/dialogue-bart-large-chinese-DuSinc"
title = model_name
description = """
This is a fine-tuned version of HIT-TMG/dialogue-bart-large-chinese on the DuSinc dataset.
However, it currently has only chit-chat ability, without knowledge grounding, since we haven't introduced a knowledge retrieval interface yet.\n
See more details in the model card at https://huggingface.co/HIT-TMG/dialogue-bart-large-chinese-DuSinc . \n\n
Besides starting the conversation from scratch, you can also input the whole dialogue history utterance by utterance, separated by '[SEP]'. \n
"""
tokenizer = BertTokenizer.from_pretrained(model_name)
model = BartForConditionalGeneration.from_pretrained(model_name)
# The encoded history is oldest-first, so truncating from the left discards the
# oldest tokens and keeps the most recent turns when the input is too long.
tokenizer.truncation_side = 'left'
# Token budget for the encoder input passed to the tokenizer in chat_func.
max_length = 512
# One row per example; each row supplies only the textbox input (the state input
# gets no example value). The second row shows a multi-turn history pasted at once.
examples = [
    ["你有什么爱好吗"],
    ["你好。[SEP]嘿嘿你好,请问你最近在忙什么呢?[SEP]我最近养了一只狗狗,我在训练它呢。"]
]
def _pair_turns(history: List[str]) -> list:
    """Group a flat, oldest-first utterance list into ``(user, bot)`` pairs for the chatbot.

    The list produced by ``chat_func`` always ends with a bot reply, so speaker roles
    alternate backwards from the end. An odd-length list therefore starts with a bot
    turn, which is displayed with an empty user message.
    """
    padded = [""] * (len(history) % 2) + history
    return list(zip(padded[0::2], padded[1::2]))


def chat_func(input_utterance: str, history: Optional[List[str]] = None):
    """Generate one bot reply and return the updated conversation.

    Args:
        input_utterance: Text entered by the user. It may contain several utterances
            joined by ``tokenizer.sep_token`` (the demo description tells users to
            use '[SEP]'), which allows pasting a whole dialogue history at once.
        history: Flat list of previous utterances, oldest first, as kept in the Gradio
            "state" component; ``None`` on the first call. It is not modified.

    Returns:
        A tuple ``(display_utterances, history)``. ``display_utterances`` is a list of
        ``(user, bot)`` pairs for the chatbot output. ``history`` is a new flat list
        that includes the generated reply, to be stored back into the state.
    """
    # Build a fresh list instead of extending the caller's list in place.
    history = list(history) if history is not None else []
    # Strip whitespace around each separator-delimited piece and skip empty pieces
    # (e.g. from a trailing or doubled separator). An empty utterance would feed a
    # blank turn to the model and shift the user/bot alternation in _pair_turns.
    history.extend(
        utterance.strip()
        for utterance in input_utterance.split(tokenizer.sep_token)
        if utterance.strip()
    )

    history_str = "[history] " + tokenizer.sep_token.join(history)
    # tokenizer.truncation_side is 'left', so over-long input keeps the newest turns.
    # NOTE(review): left truncation trims from the very start of the string, so the
    # "[history] " prefix is also cut once the input exceeds max_length. Consider
    # dropping whole old utterances instead if the model relies on that prefix.
    encoded = tokenizer(history_str,
                        return_tensors='pt',
                        truncation=True,
                        max_length=max_length,
                        )
    output_ids = model.generate(encoded.input_ids,
                                attention_mask=encoded.attention_mask,
                                max_new_tokens=30,
                                top_k=32,
                                num_beams=4,
                                repetition_penalty=1.2,
                                no_repeat_ngram_size=4)[0]
    response = tokenizer.decode(output_ids, skip_special_tokens=True)
    history.append(response)
    return _pair_turns(history), history
# The textbox carries the user's utterance(s); the "state" component round-trips the
# flat utterance history between calls (chat_func receives it and returns it).
input_components = [gr.Textbox(lines=1, placeholder="Input current utterance"), "state"]
# chat_func returns (chatbot pairs, updated history), matching this order.
output_components = ["chatbot", "state"]

demo = gr.Interface(
    fn=chat_func,
    inputs=input_components,
    outputs=output_components,
    examples=examples,
    title=title,
    description=description,
)

if __name__ == "__main__":
    demo.launch()