import gradio as gr
from transformers import pipeline
from arabert.aragpt2.grover.modeling_gpt2 import GPT2LMHeadModel
from transformers import AutoTokenizer
import re
import torch
import spaces  # Import the spaces module for ZeroGPU

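# Load the AskMe model and tokenizer. The model uses the AraGPT2 (Grover-based) GPT-2 implementation,
# so we register the special tokens it expects and resize the embeddings to match the tokenizer.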
model_name = "Naseej/AskMe-Large"
tokenizer = AutoTokenizer.from_pretrained(model_name, bos_token='<|startoftext|>',
                                         eos_token='<|endoftext|>', pad_token='<|pad|>')
model = GPT2LMHeadModel.from_pretrained(model_name)
model.resize_token_embeddings(len(tokenizer))

# For ZeroGPU, we'll move the model to CUDA inside the decorated function
# Create the generator pipeline without specifying device
generator = pipeline("text-generation", model=model, tokenizer=tokenizer)

# ZeroGPU-decorated function for text generation
@spaces.GPU(duration=60)  # Set duration based on your needs
def generate_response(message, history, num_beams=4, temperature=0.99, do_sample=True, top_k=60, top_p=0.9):
    # Move model to CUDA inside the decorated function
    generator.model = generator.model.to('cuda')
    
    prompt = f'Prompt: {message}\nAnswer:'
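    # Beam search combined with sampling; repetition_penalty and no_repeat_ngram_size curb repeated phrases.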
    pred_text = generator(prompt,
                         pad_token_id=tokenizer.eos_token_id,
                         num_beams=int(num_beams),
                         max_length=1024,
                         min_length=0,
                         temperature=temperature,
                         do_sample=do_sample,
                         top_p=top_p,
                         top_k=int(top_k),
                         repetition_penalty=3.0,
                         no_repeat_ngram_size=3)[0]['generated_text']
    try:
        # Keep only the text that follows the last "Answer:" marker
        pred_answer = re.findall("Answer:(.*)", pred_text, re.S)[-1]
    except IndexError:
        # "I could not generate an answer. Please rephrase the question."
        pred_answer = "لم أستطع توليد إجابة. يرجى إعادة صياغة السؤال."

    # Move model back to CPU to free GPU memory
    generator.model = generator.model.to('cpu')
    return pred_answer

# Chat handler: generate a reply and append the (user, bot) pair to the chat history
def respond(message, chat_history, num_beams, temperature, do_sample, top_k, top_p):
    bot_message = generate_response(message, chat_history, num_beams, temperature, do_sample, top_k, top_p)
    chat_history.append((message, bot_message))
    return "", chat_history

# CSS for RTL support and styling
css = """
.gradio-container {direction: rtl;}
.message.user {background-color: #2b5797; color: white; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
.message.bot {background-color: #f0f0f0; color: black; border-radius: 20px; padding: 8px 12px; margin-bottom: 8px; text-align: right;}
"""

with gr.Blocks(css=css) as demo:
    gr.Markdown("# نظام AskMe - تحدث معي")  # Title: "AskMe System - Talk to Me"
    
    with gr.Row():
        with gr.Column(scale=3):
            chatbot = gr.Chatbot(label="المحادثة", elem_classes=["chatbot"])  # "Conversation"
            msg = gr.Textbox(label="اكتب رسالتك هنا", placeholder="اكتب هنا...")  # "Type your message here" / "Type here..."

            with gr.Row():
                submit_btn = gr.Button("إرسال", variant="primary")  # "Send"
                clear_btn = gr.Button("مسح المحادثة")  # "Clear conversation"
        
        with gr.Column(scale=1):
            with gr.Accordion("إعدادات توليد النص", open=False):  # "Text generation settings"
                num_beams = gr.Slider(1, 10, value=4, step=1, label="عدد الشعاعات")  # "Number of beams"
                temperature = gr.Slider(0.1, 2.0, value=0.99, step=0.01, label="درجة الحرارة")  # "Temperature"
                do_sample = gr.Checkbox(value=True, label="النمط الإبداعي")  # "Creative mode" (enables sampling)
                top_k = gr.Slider(1, 100, value=60, step=1, label="Top-K")
                top_p = gr.Slider(0.0, 1.0, value=0.9, step=0.01, label="Top-P")
    
    # Example prompts: "Write an article about artificial intelligence",
    # "Write a short story about success", "How can one maintain a healthy lifestyle?"
    examples = gr.Examples(
        examples=[
            ["اكتب مقال عن الذكاء الصناعي"],
            ["اكتب قصة قصيرة عن النجاح"],
            ["كيف يمكن المحافظة على حياة صحية"]
        ],
        inputs=msg
    )
    
    # Set up event handlers
    submit_btn.click(
        respond,
        inputs=[msg, chatbot, num_beams, temperature, do_sample, top_k, top_p],
        outputs=[msg, chatbot]
    )
    
    msg.submit(
        respond,
        inputs=[msg, chatbot, num_beams, temperature, do_sample, top_k, top_p],
        outputs=[msg, chatbot]
    )
    
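    # Returning None resets the Chatbot component, clearing the conversation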
    clear_btn.click(lambda: None, None, chatbot, queue=False)

demo.launch()