File size: 4,455 Bytes
e729b65 d3f1805 e729b65 b768edc e157ec5 e729b65 f3eb4d5 f4753f3 e729b65 f4753f3 e157ec5 e729b65 e157ec5 f3eb4d5 e157ec5 f3eb4d5 e729b65 f3eb4d5 e729b65 f3eb4d5 e729b65 e157ec5 e729b65 73aeb88 f3eb4d5 e729b65 f3eb4d5 73aeb88 f3eb4d5 73aeb88 f3eb4d5 73aeb88 f3eb4d5 73aeb88 f3eb4d5 73aeb88 f3eb4d5 e729b65 e157ec5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 |
import gradio as gr
from transformers import pipeline
from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel
from transformers import AutoTokenizer
import re
import torch
import spaces # Import the spaces module for ZeroGPU
# --- Model/tokenizer setup (runs once at import time; downloads weights on first run) ---
model_name = "Naseej/AskMe-Large"
# Register explicit BOS/EOS/PAD special tokens; they are not part of the base vocab.
tokenizer = AutoTokenizer.from_pretrained(model_name, bos_token='<|startoftext|>',
                                          eos_token='<|endoftext|>', pad_token='<|pad|>')
model = GPT2LMHeadModel.from_pretrained(model_name)
# The special tokens above grew the vocabulary, so the embedding matrix must be resized.
model.resize_token_embeddings(len(tokenizer))
# For ZeroGPU, we'll move the model to CUDA inside the decorated function
# Create the generator pipeline without specifying device
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)
# ZeroGPU-decorated function for text generation
@spaces.GPU(duration=60)  # seconds of GPU allocation granted per call
def generate_response(message, history, num_beams=4, temperature=0.99, do_sample=True, top_k=60, top_p=0.9):
    """Generate an Arabic answer for `message` with the AskMe GPT-2 pipeline.

    Args:
        message: user prompt text, embedded into a "Prompt: ...\\nAnswer:" template.
        history: chat history (accepted for interface compatibility; unused here).
        num_beams, temperature, do_sample, top_k, top_p: decoding controls
            forwarded to the transformers text-generation pipeline.

    Returns:
        The text after the last "Answer:" marker in the generated output, or an
        Arabic fallback message when no answer could be extracted.
    """
    # Move model to CUDA inside the decorated function — ZeroGPU only grants
    # the GPU for the duration of this call.
    generator.model = generator.model.to('cuda')
    try:
        prompt = f'Prompt: {message}\nAnswer:'
        pred_text = generator(prompt,
                              pad_token_id=tokenizer.eos_token_id,
                              num_beams=int(num_beams),
                              max_length=1024,
                              min_length=0,
                              temperature=temperature,
                              do_sample=do_sample,
                              top_p=top_p,
                              top_k=int(top_k),
                              repetition_penalty=3.0,
                              no_repeat_ngram_size=3)[0]['generated_text']
        try:
            # Take everything after the LAST "Answer:" marker; re.S lets the
            # capture span newlines. [-1] raises IndexError when no match.
            pred_sentiment = re.findall("Answer:(.*)", pred_text, re.S)[-1]
        except IndexError:
            # No "Answer:" marker in the generated text — return a fallback.
            pred_sentiment = "لم أستطع توليد إجابة. يرجى إعادة صياغة السؤال."
    finally:
        # Always move the model back to CPU to free GPU memory, even if
        # generation raised — otherwise the model is stranded on CUDA.
        generator.model = generator.model.to('cpu')
    return pred_sentiment
# Chat event handler: produce the bot reply and update the conversation state.
def respond(message, chat_history, num_beams, temperature, do_sample, top_k, top_p):
    """Handle one chat turn.

    Generates a reply for `message`, appends the (user, bot) pair to
    `chat_history` in place, and returns ("", chat_history) so the input
    textbox is cleared and the chatbot component is refreshed.
    """
    reply = generate_response(
        message, chat_history, num_beams, temperature, do_sample, top_k, top_p
    )
    chat_history.append((message, reply))
    return "", chat_history
# CSS for RTL (right-to-left Arabic) layout and chat-bubble styling,
# injected into the Gradio Blocks app below.
css = """
.gradio-container {direction: rtl;}
.message.user {background-color: #2b5797; color: white; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
.message.bot {background-color: #f0f0f0; color: black; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
"""
# Build the UI: a chat column (chatbot + textbox + buttons) next to a
# collapsible panel of generation settings, plus example prompts.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# نظام AskMe - تحدث معي")
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="المحادثة", elem_classes=["chatbot"])
            msg = gr.Textbox(label="اكتب رسالتك هنا", placeholder="اكتب هنا...")
            with gr.Row():
                submit_btn = gr.Button("إرسال", variant="primary")
                clear_btn = gr.Button("مسح المحادثة")
        with gr.Column(scale=1):
            # Decoding controls; defaults mirror generate_response's defaults.
            with gr.Accordion("إعدادات توليد النص", open=False):
                num_beams = gr.Slider(1, 10, value=4, step=1, label="عدد الشعاعات")
                temperature = gr.Slider(0.1, 2.0, value=0.99, step=0.01, label="درجة الحرارة")
                do_sample = gr.Checkbox(value=True, label="النمط الإبداعي")
                top_k = gr.Slider(1, 100, value=60, step=1, label="Top-K")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-P")
    # Example prompts
    examples = gr.Examples(
        examples=[
            ["اكتب مقال عن الذكاء الصناعي"],
            ["اكتب قصة قصيرة عن النجاح"],
            ["كيف يمكن المحافظة على حياه صحية"]
        ],
        inputs=msg
    )
    # Set up event handlers: both the button and pressing Enter submit.
    submit_btn.click(
        respond,
        inputs=[msg, chatbot, num_beams, temperature, do_sample, top_k, top_p],
        outputs=[msg, chatbot]
    )
    msg.submit(
        respond,
        inputs=[msg, chatbot, num_beams, temperature, do_sample, top_k, top_p],
        outputs=[msg, chatbot]
    )
    # Clearing replaces the chatbot value with None (empties the history).
    clear_btn.click(lambda: None, None, chatbot, queue=False)
demo.launch()