Spaces:
Runtime error
Runtime error
| import transformers | |
| import gradio as gr | |
| import git | |
| import os | |
# Upgrade pip in the running environment.
os.system("pip install --upgrade pip")

# Load the AraBERT preprocessor.
# The preprocessor is imported from a local checkout of the upstream
# repository. Clone it only when the directory is not already present:
# repeating the clone on a restart is wasted work and can fail when the
# destination already exists.
import git

if not os.path.isdir("arabert"):
    git.Git("arabert").clone("https://github.com/aub-mind/arabert")

from arabert.preprocess import ArabertPreprocessor

arabert_prep = ArabertPreprocessor(model_name="bert-base-arabert", keep_emojis=False)
# Load the tokenizer and the encoder-decoder model from the same checkpoint.
from transformers import AutoTokenizer, EncoderDecoderModel

_CHECKPOINT = "tareknaous/bert2bert-empathetic-response-msa"

tokenizer = AutoTokenizer.from_pretrained(_CHECKPOINT)
model = EncoderDecoderModel.from_pretrained(_CHECKPOINT)
model.eval()  # switch the model to evaluation mode
def generate_response(text, minimum_length, p, temperature):
    """Generate one sampled reply to ``text``.

    The input is cleaned with ``arabert_prep.preprocess``, tokenized, and
    passed to ``model.generate`` with sampling enabled. The generated
    sequence is decoded and then de-segmented with ``arabert_prep.desegment``.

    Args:
        text: The user's message.
        minimum_length: Forwarded to ``generate`` as ``min_length``.
        p: Forwarded to ``generate`` as ``top_p``.
        temperature: Forwarded to ``generate`` as ``temperature``.

    Returns:
        The reply as a string, with tokenizer special tokens removed.
    """
    text_clean = arabert_prep.preprocess(text)
    inputs = tokenizer.encode_plus(text_clean, return_tensors="pt")
    outputs = model.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        do_sample=True,
        min_length=minimum_length,
        top_p=p,
        temperature=temperature,
    )
    # A single input yields a single output sequence. Let the tokenizer drop
    # its own special tokens instead of stringifying the decoded list and
    # scrubbing "[CLS]"/"[SEP]"/quote characters by hand, which depended on
    # repr() formatting and also removed apostrophes from the reply itself.
    response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    return str(arabert_prep.desegment(response))
| # title = 'Empathetic Response Generation in Arabic' | |
| # description = 'This demo is for a BERT2BERT model trained for single-turn open-domain empathetic dialogue response generation in Modern Standard Arabic' | |
# Custom stylesheet handed to gr.Blocks: elements carrying the ``rtlClass``
# class are rendered right-to-left.
css = "\n.rtlClass {direction:rtl !important}\n"
# Chat UI: a right-to-left chatbot with a message box, three sliders that
# control generation, and a button that clears the conversation.
# NOTE(review): which widgets sit in which column is reconstructed from the
# statement order — confirm against the intended layout.
with gr.Blocks(css=css) as demo:
    with gr.Column():
        gr.Markdown("Empathetic Response Generation in Arabic")
        # Styling options are passed to the constructors; the chained
        # ``.style(...)`` helper is deprecated and absent from newer Gradio.
        chatbot = gr.Chatbot(elem_classes="rtlClass", height=400)
        msg = gr.Textbox(
            # Arabic for "send a message" (previously stored as mojibake).
            placeholder="ارسل رسالة",
            show_label=False,
            elem_classes="rtlClass",
            container=False,
        )
    with gr.Column():
        output_slider = gr.Slider(5, 20, step=1, label="Minimum Output Length")
        top_p_slider = gr.Slider(0.7, 1, step=0.1, label="Top-P")
        temperature_slider = gr.Slider(1, 3, step=0.1, label="Temperature")
        clear = gr.Button("Clear Chat")

    def respond(message, chat_history, output_slider, top_p_slider, temperature_slider):
        """Append the model's reply to the chat history.

        Returns an empty string (to clear the message box) and the updated
        history, matching the ``[msg, chatbot]`` outputs wired up below.
        """
        # The clear handler below writes None into the chatbot, so the next
        # submit may receive None instead of a list.
        if chat_history is None:
            chat_history = []
        bot_message = generate_response(
            message, output_slider, top_p_slider, temperature_slider
        )
        chat_history.append((message, bot_message))
        return "", chat_history

    msg.submit(
        respond,
        [msg, chatbot, output_slider, top_p_slider, temperature_slider],
        [msg, chatbot],
    )
    # Reset the conversation by setting the chatbot value to None.
    clear.click(lambda: None, None, chatbot, queue=False)

demo.launch()