# AskMe / app.py — Hugging Face Space chat app for the Naseej/AskMe-Large model.
# (Space page metadata preserved from the original paste: author mobarmg,
#  commit e157ec5, verified.)
import gradio as gr
from transformers import pipeline
from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel
from transformers import AutoTokenizer
import re
import torch
import spaces # Import the spaces module for ZeroGPU
# --- Model / tokenizer / pipeline setup (runs once at Space startup) ---
model_name = "Naseej/AskMe-Large"
# The special tokens must match those used during fine-tuning; the pad token
# is added on top of the pretrained vocabulary, which grows the vocab size.
tokenizer = AutoTokenizer.from_pretrained(model_name, bos_token='<|startoftext|>',
                                          eos_token='<|endoftext|>', pad_token='<|pad|>')
# AraGPT2 "Grover"-style GPT-2 head from the arabert package.
model = GPT2LMHeadModel.from_pretrained(model_name)
# Resize the embedding matrix to account for the added special tokens.
model.resize_token_embeddings(len(tokenizer))
# For ZeroGPU, we'll move the model to CUDA inside the decorated function
# Create the generator pipeline without specifying device
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# ZeroGPU-decorated function for text generation
# ZeroGPU-decorated function for text generation
@spaces.GPU(duration=60)  # Set duration based on your needs
def generate_response(message, history, num_beams=4, temperature=0.99, do_sample=True, top_k=60, top_p=0.9):
    """Generate an Arabic answer for *message* using the AskMe model.

    Parameters
    ----------
    message : str
        The user's question/prompt.
    history : list
        Chat history; accepted for interface compatibility but not used to
        condition generation — each prompt is answered independently.
    num_beams, temperature, do_sample, top_k, top_p :
        Generation hyper-parameters forwarded to the text-generation pipeline.

    Returns
    -------
    str
        The text after the last "Answer:" marker in the model output, or an
        Arabic fallback message when no such section could be extracted.
    """
    # Move the model to the GPU inside the @spaces.GPU context. Guard on
    # availability so the function also works on CPU-only hosts instead of
    # crashing with "Torch not compiled with CUDA" / invalid device errors.
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    generator.model = generator.model.to(device)
    prompt = f'Prompt: {message}\nAnswer:'
    try:
        pred_text = generator(prompt,
                              pad_token_id=tokenizer.eos_token_id,
                              num_beams=int(num_beams),
                              max_length=1024,
                              min_length=0,
                              temperature=temperature,
                              do_sample=do_sample,
                              top_p=top_p,
                              top_k=int(top_k),
                              repetition_penalty=3.0,
                              no_repeat_ngram_size=3)[0]['generated_text']
        # Keep only what follows the last "Answer:" marker (re.S lets '.'
        # span newlines). findall()[-1] raises IndexError when absent.
        pred_sentiment = re.findall("Answer:(.*)", pred_text, re.S)[-1].strip()
    except IndexError:
        # Model produced no "Answer:" section — return a fallback message
        # ("I could not generate an answer. Please rephrase the question.").
        pred_sentiment = "لم أستطع توليد إجابة. يرجى إعادة صياغة السؤال."
    finally:
        # Always hand the model back to CPU to free GPU memory, even when
        # generation raised (the original leaked the GPU copy on errors).
        generator.model = generator.model.to('cpu')
    return pred_sentiment
# Properly format the chat message handler
def respond(message, chat_history, num_beams, temperature, do_sample, top_k, top_p):
    """Gradio chat callback: answer *message*, record the turn, clear the box.

    Appends a ``(user, bot)`` tuple to *chat_history* in place and returns an
    empty string (resets the textbox) together with the updated history.
    """
    bot_reply = generate_response(message, chat_history, num_beams,
                                  temperature, do_sample, top_k, top_p)
    chat_history.append((message, bot_reply))
    return "", chat_history
# CSS for RTL support and styling
# CSS for RTL support and styling
# Right-to-left layout for the Arabic UI, plus chat-bubble styling for
# user (blue) and bot (grey) messages.
css = """
.gradio-container {direction: rtl;}
.message.user {background-color: #2b5797; color: white; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
.message.bot {background-color: #f0f0f0; color: black; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
"""
# Build the Gradio UI: an RTL Arabic chat interface with a side panel of
# generation hyper-parameter controls, then launch the app.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# نظام AskMe - تحدث معي")

    with gr.Row():
        # Main column: conversation view, message box, and action buttons.
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="المحادثة", elem_classes=["chatbot"])
            msg = gr.Textbox(label="اكتب رسالتك هنا", placeholder="اكتب هنا...")
            with gr.Row():
                submit_btn = gr.Button("إرسال", variant="primary")
                clear_btn = gr.Button("مسح المحادثة")

        # Side column: collapsible generation settings (closed by default).
        with gr.Column(scale=1):
            with gr.Accordion("إعدادات توليد النص", open=False):
                num_beams = gr.Slider(1, 10, value=4, step=1, label="عدد الشعاعات")
                temperature = gr.Slider(0.1, 2.0, value=0.99, step=0.01, label="درجة الحرارة")
                do_sample = gr.Checkbox(value=True, label="النمط الإبداعي")
                top_k = gr.Slider(1, 100, value=60, step=1, label="Top-K")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-P")

    # Clickable example prompts that populate the message textbox.
    examples = gr.Examples(
        examples=[
            ["اكتب مقال عن الذكاء الصناعي"],
            ["اكتب قصة قصيرة عن النجاح"],
            ["كيف يمكن المحافظة على حياه صحية"]
        ],
        inputs=msg
    )

    # Both the send button and pressing Enter in the textbox trigger the
    # same handler with the same inputs/outputs.
    handler_inputs = [msg, chatbot, num_beams, temperature, do_sample, top_k, top_p]
    handler_outputs = [msg, chatbot]
    submit_btn.click(respond, inputs=handler_inputs, outputs=handler_outputs)
    msg.submit(respond, inputs=handler_inputs, outputs=handler_outputs)

    # Clearing sets the chatbot component to None (an empty conversation).
    clear_btn.click(lambda: None, None, chatbot, queue=False)

demo.launch()