chatbox / test_jais.py
anaspro
updatE
51d3416
raw
history blame
2.56 kB
#!/usr/bin/env python3
"""
اختبار مودل Jais - مثل الكود الأصلي
"""
import os
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
def test_jais():
model_path = "inceptionai/jais-family-13b-chat"
# تحميل المودل مثل الكود الأصلي
tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, device_map="auto", trust_remote_code=True)
# الـ prompts الأصلية
prompt_eng = "### Instruction:Your name is 'Jais', and you are named after Jebel Jais, the highest mountain in UAE. You were made by 'Inception' in the UAE. You are a helpful, respectful, and honest assistant. Always answer as helpfully as possible, while being safe. Complete the conversation between [|Human|] and [|AI|]:\n### Input: [|Human|] {Question}\n[|AI|]\n### Response :"
prompt_ar = "### Instruction:اسمك \"جيس\" وسميت على اسم جبل جيس اعلى جبل في الامارات. تم بنائك بواسطة Inception في الإمارات. أنت مساعد مفيد ومحترم وصادق. أجب دائمًا بأكبر قدر ممكن من المساعدة، مع الحفاظ على البقاء أمناً. أكمل المحادثة بين [|Human|] و[|AI|] :\n### Input:[|Human|] {Question}\n[|AI|]\n### Response :"
def get_response(text):
input_ids = tokenizer(text, return_tensors="pt").input_ids
inputs = input_ids.to("cuda" if torch.cuda.is_available() else "cpu")
input_len = inputs.shape[-1]
generate_ids = model.generate(
inputs,
top_p=0.9,
temperature=0.3,
max_length=2048,
min_length=input_len + 4,
repetition_penalty=1.2,
do_sample=True,
)
response = tokenizer.batch_decode(
generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
)[0]
response = response.split("### Response :")[-1]
return response
# اختبار عربي
ques = "ما هي عاصمة الامارات؟"
text = prompt_ar.format_map({'Question': ques})
print("السؤال العربي:", ques)
print("الرد:", get_response(text))
print()
# اختبار إنجليزي
ques = "What is the capital of UAE?"
text = prompt_eng.format_map({'Question': ques})
print("السؤال الإنجليزي:", ques)
print("الرد:", get_response(text))
if __name__ == "__main__":
test_jais()