|
|
import gradio as gr |
|
|
from transformers import pipeline |
|
|
from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel |
|
|
from transformers import AutoTokenizer |
|
|
import re |
|
|
import torch |
|
|
import spaces |
|
|
|
|
|
# Hugging Face Hub checkpoint that backs the chat assistant.
model_name = "Naseej/AskMe-Large"

# Special tokens handed to the tokenizer when it is loaded.
_special_tokens = {
    "bos_token": "<|startoftext|>",
    "eos_token": "<|endoftext|>",
    "pad_token": "<|pad|>",
}
tokenizer = AutoTokenizer.from_pretrained(model_name, **_special_tokens)

model = GPT2LMHeadModel.from_pretrained(model_name)
# Make the embedding matrix match the tokenizer's vocabulary size
# (presumably so any tokens added above get embedding rows — confirm).
model.resize_token_embeddings(len(tokenizer))

# Text-generation pipeline wrapping the model and tokenizer loaded above.
generator = pipeline(task="text-generation", model=model, tokenizer=tokenizer)
|
|
|
|
|
|
|
|
def _extract_answer(pred_text):
    """Return the stripped text after the first ``Answer:`` marker.

    Returns an empty string when *pred_text* contains no ``Answer:`` marker.
    """
    _, marker, answer = pred_text.partition("Answer:")
    return answer.strip() if marker else ""


@spaces.GPU(duration=60)
def generate_response(message, history, num_beams=4, temperature=0.99, do_sample=True, top_k=60, top_p=0.9):
    """Generate an answer to *message* with the AskMe model.

    The model is moved to CUDA for the duration of the call and moved back
    to the CPU afterwards, even if generation raises.

    Args:
        message: The user's question; it is embedded in a
            ``Prompt: ...\\nAnswer:`` template.
        history: Chat history from the UI. Unused: the prompt is built from
            *message* alone.
        num_beams, temperature, do_sample, top_k, top_p: Generation settings
            forwarded to the text-generation pipeline. ``num_beams`` and
            ``top_k`` are cast to ``int`` in case the UI delivers floats.

    Returns:
        The generated answer text, or an Arabic fallback message asking the
        user to rephrase when no non-empty answer could be extracted.
    """
    fallback = "لم أستطع توليد إجابة. يرجى إعادة صياغة السؤال."
    prompt = f'Prompt: {message}\nAnswer:'

    generator.model = generator.model.to('cuda')
    # The pipeline moves its inputs to ``generator.device``, which was fixed
    # when the pipeline was built on CPU; keep it in sync with the model so
    # inputs and weights end up on the same device.
    generator.device = generator.model.device
    try:
        pred_text = generator(prompt,
                              pad_token_id=tokenizer.eos_token_id,
                              num_beams=int(num_beams),
                              max_length=1024,
                              min_length=0,
                              temperature=temperature,
                              do_sample=do_sample,
                              top_p=top_p,
                              top_k=int(top_k),
                              repetition_penalty=3.0,
                              no_repeat_ngram_size=3)[0]['generated_text']
    finally:
        # Always release the GPU copy, even when generation fails.
        generator.model = generator.model.to('cpu')
        generator.device = generator.model.device

    return _extract_answer(pred_text) or fallback
|
|
|
|
|
|
|
|
def respond(message, chat_history, num_beams, temperature, do_sample, top_k, top_p):
    """Run one chat turn for the Gradio UI.

    Args:
        message: Text typed by the user.
        chat_history: List of ``(user, bot)`` tuples shown in the chatbot;
            ``None`` is treated as an empty history.
        num_beams, temperature, do_sample, top_k, top_p: Generation settings
            forwarded unchanged to ``generate_response``.

    Returns:
        ``("", history)``: the empty string clears the input textbox, and
        *history* is a new list with the ``(message, answer)`` turn appended.
        Blank messages are ignored: no generation runs and the history is
        returned unchanged.
    """
    history = list(chat_history) if chat_history else []
    if not message or not message.strip():
        # Nothing to answer: skip the (GPU-bound) generation entirely.
        return "", history

    bot_message = generate_response(message, history, num_beams, temperature, do_sample, top_k, top_p)
    history.append((message, bot_message))
    return "", history
|
|
|
|
|
|
|
|
# Page styling: a right-to-left container for the Arabic UI, plus colored,
# rounded, right-aligned bubbles for user and bot messages.
# NOTE(review): `.message.user` / `.message.bot` target Gradio's internal
# chatbot DOM classes — confirm they match the installed Gradio version.
css = """
.gradio-container {direction: rtl;}
.message.user {background-color: #2b5797; color: white; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
.message.bot {background-color: #f0f0f0; color: black; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
"""
|
|
|
|
|
# One-click example prompts; each row fills the message textbox.
_example_prompts = [
    ["اكتب مقال عن الذكاء الصناعي"],
    ["اكتب قصة قصيرة عن النجاح"],
    ["كيف يمكن المحافظة على حياه صحية"],
]

with gr.Blocks(css=css) as demo:
    gr.Markdown("# نظام AskMe - تحدث معي")

    with gr.Row():
        # Main column: conversation transcript, input box and action buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="المحادثة", elem_classes=["chatbot"])
            msg = gr.Textbox(label="اكتب رسالتك هنا", placeholder="اكتب هنا...")
            with gr.Row():
                submit_btn = gr.Button("إرسال", variant="primary")
                clear_btn = gr.Button("مسح المحادثة")

        # Side column: generation settings, collapsed by default.
        with gr.Column(scale=1):
            with gr.Accordion("إعدادات توليد النص", open=False):
                num_beams = gr.Slider(1, 10, value=4, step=1, label="عدد الشعاعات")
                temperature = gr.Slider(0.1, 2.0, value=0.99, step=0.01, label="درجة الحرارة")
                do_sample = gr.Checkbox(value=True, label="النمط الإبداعي")
                top_k = gr.Slider(1, 100, value=60, step=1, label="Top-K")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-P")

    examples = gr.Examples(examples=_example_prompts, inputs=msg)

    # The send button and pressing Enter both run one chat turn with the same
    # inputs and outputs; listeners are registered click-first, then submit.
    _turn_inputs = [msg, chatbot, num_beams, temperature, do_sample, top_k, top_p]
    _turn_outputs = [msg, chatbot]
    for _trigger in (submit_btn.click, msg.submit):
        _trigger(respond, inputs=_turn_inputs, outputs=_turn_outputs)

    # Clearing sets the chatbot value to None, bypassing the request queue.
    clear_btn.click(lambda: None, None, chatbot, queue=False)

demo.launch()