# Hugging Face Space snapshot — build status at capture time: Runtime error
# MentalQA – Arabic Mental Health Assistant (chat + classifier)
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification, pipeline
)

# Hugging Face Hub repositories: generative chat model + question-type classifier.
CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"

# Load chat model
# use_fast=False: force the slow (Python/SentencePiece) tokenizer implementation.
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",      # keep the dtype the checkpoint was saved in
    device_map="auto",       # let accelerate place weights across GPU/CPU
    low_cpu_mem_usage=True,  # stream weights in to reduce peak RAM during load
)

# Load classifier
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)
# transformers pipeline device convention: 0 = first CUDA device, -1 = CPU.
device_idx = 0 if torch.cuda.is_available() else -1
clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)

# Map raw classifier labels to the single-letter question-type codes used in prompts.
label_map = {
    "LABEL_0": "A",  # diagnosis
    "LABEL_1": "B",  # treatment
    "LABEL_2": "C",  # anatomy
    "LABEL_3": "D",  # epidemiology
    "LABEL_4": "E",  # lifestyle
    "LABEL_5": "F",  # service provider
    "LABEL_6": "G",  # other
}

# Arabic system instruction sent with every chat request: "You are a smart
# mental-health assistant named MentalQA. Do not mention your name or platform
# unless explicitly asked about your identity."
SYSTEM_MSG = (
    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
    "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
)
def classify_question(text: str, threshold: float = 0.5) -> str:
    """Classify an Arabic question into a type code ("A"–"G").

    Runs the text-classification pipeline and keeps the highest-scoring
    prediction; falls back to "G" (other) when the score is below
    *threshold* or the label is unknown.
    """
    best = max(clf_pipe(text), key=lambda item: item["score"])
    if best["score"] < threshold:
        return "G"
    return label_map.get(best["label"], "G")
def build_prompt(question: str, final_qt: str) -> str:
    """Assemble the full generation prompt.

    Joins the system message, the classified question type (final_QT), the
    user's question, and the Arabic answering instruction ("write one detailed
    paragraph of at least three connected sentences, thinking step by step",
    followed by "Final answer:").
    """
    sections = [
        SYSTEM_MSG,
        f"final_QT: {final_qt}",
        f"سؤال المستخدم:\n{question}",
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\nالإجابة النهائية:\n",
    ]
    return "\n\n".join(sections)
def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    """End-to-end answer generation: classify the question type, build the
    prompt, run the chat model, and return only the newly generated text.
    """
    question_type = classify_question(question, threshold)
    user_prompt = build_prompt(question, question_type)
    # NOTE(review): SYSTEM_MSG appears both inside build_prompt's output and
    # again as the system role below — presumably this duplication matches the
    # fine-tuning format; confirm before changing.
    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": user_prompt},
    ]
    input_ids = chat_tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(chat_model.device)
    output_ids = chat_model.generate(
        input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,  # silence the missing-pad warning
        eos_token_id=chat_tok.eos_token_id,
    )[0]
    # Slice off the prompt tokens so only the model's completion is decoded.
    prompt_len = input_ids.shape[1]
    completion = chat_tok.decode(output_ids[prompt_len:], skip_special_tokens=True)
    return completion.strip()