# MentalQA – Arabic Mental Health Assistant (chat + classifier)

import torch
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    AutoModelForSequenceClassification,
    pipeline,
)

CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"

# Load chat model
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Load classifier
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)
device_idx = 0 if torch.cuda.is_available() else -1
clf_pipe = pipeline(
    "text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx
)

# Map raw classifier labels to MentalQA question-type codes.
label_map = {
    "LABEL_0": "A",  # Diagnosis
    "LABEL_1": "B",  # Treatment
    "LABEL_2": "C",  # Anatomy
    "LABEL_3": "D",  # Epidemiology
    "LABEL_4": "E",  # Lifestyle
    "LABEL_5": "F",  # Service provider
    "LABEL_6": "G",  # Other
}

# English gloss: "You are an intelligent mental-health assistant named MentalQA.
# Do not mention your name or platform unless explicitly asked about your identity."
SYSTEM_MSG = (
    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
    "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
)


def classify_question(text: str, threshold: float = 0.5) -> str:
    """Return the question-type code, falling back to "G" (Other) below the threshold."""
    pred = max(clf_pipe(text), key=lambda x: x["score"])
    return label_map.get(pred["label"], "G") if pred["score"] >= threshold else "G"


def build_prompt(question: str, final_qt: str) -> str:
    # The Arabic instructions ask the model to think step by step, then write a
    # single detailed paragraph of at least three connected sentences.
    # Note: SYSTEM_MSG also appears here, in addition to the system turn added
    # by the chat template below.
    return (
        f"{SYSTEM_MSG}\n\n"
        f"final_QT: {final_qt}\n\n"
        f"سؤال المستخدم:\n{question}\n\n"  # "User question:"
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\n"
        "الإجابة النهائية:\n"  # "Final answer:"
    )


def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    # Classify the question, build the typed prompt, then generate and decode.
    final_qt = classify_question(question, threshold)
    prompt = build_prompt(question, final_qt)
    chat_input = chat_tok.apply_chat_template(
        [
            {"role": "system", "content": SYSTEM_MSG},
            {"role": "user", "content": prompt},
        ],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(chat_model.device)
    gen_output = chat_model.generate(
        chat_input,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,
        eos_token_id=chat_tok.eos_token_id,
    )[0]
    # Decode only the newly generated continuation, skipping the prompt tokens.
    answer = chat_tok.decode(gen_output[chat_input.shape[1]:], skip_special_tokens=True)
    return answer.strip()
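
# A minimal usage sketch, assuming the models above loaded successfully. The
# sample question is an illustrative placeholder, not taken from the MentalQA
# dataset; English gloss: "How do I deal with persistent insomnia and anxiety
# at night?"
if __name__ == "__main__":
    question = "كيف أتعامل مع الأرق المستمر والقلق في الليل؟"
    qt = classify_question(question)
    print(f"Predicted question type: {qt}")
    print(generate_mentalqa_answer(question))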