# MentalQA – Arabic Mental Health Assistant (chat + classifier)
import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification, pipeline
)
CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"
# Load chat model
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)
# Load classifier
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)
device_idx = 0 if torch.cuda.is_available() else -1
clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)
# Map raw classifier labels to MentalQA question-type codes
label_map = {
    "LABEL_0": "A",  # تشخيص (diagnosis)
    "LABEL_1": "B",  # علاج (treatment)
    "LABEL_2": "C",  # تشريح (anatomy and physiology)
    "LABEL_3": "D",  # وبائيات (epidemiology)
    "LABEL_4": "E",  # نمط حياة (lifestyle)
    "LABEL_5": "F",  # مقدم خدمة (service provider)
    "LABEL_6": "G",  # أخرى (other)
}
# System prompt (Arabic): "You are an intelligent mental-health assistant named
# MentalQA. Do not mention your name or the platform you work for unless you are
# explicitly asked about your identity."
SYSTEM_MSG = (
    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
    "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
)
def classify_question(text: str, threshold: float = 0.5) -> str:
    pred = max(clf_pipe(text), key=lambda x: x["score"])
    return label_map.get(pred["label"], "G") if pred["score"] >= threshold else "G"
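# Example (hypothetical question, not from the dataset; the returned code depends
# on the fine-tuned checkpoint): classify_question("ما هي أعراض الاكتئاب؟")
# ("What are the symptoms of depression?") would typically map to "A" (diagnosis),
# and falls back to "G" whenever no class clears the confidence threshold.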
def build_prompt(question: str, final_qt: str) -> str:
    # Prompt body (Arabic): the user's question followed by the instruction
    # "Write one detailed paragraph of at least three connected sentences,
    # after thinking step by step. Final answer:"
    return (
        f"{SYSTEM_MSG}\n\n"
        f"final_QT: {final_qt}\n\n"
        f"سؤال المستخدم:\n{question}\n\n"
        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكر خطوة بخطوة.\n"
        "الإجابة النهائية:\n"
    )
def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    final_qt = classify_question(question, threshold)
    prompt = build_prompt(question, final_qt)
    chat_input = chat_tok.apply_chat_template(
        [{"role": "system", "content": SYSTEM_MSG},
         {"role": "user", "content": prompt}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(chat_model.device)
    gen_output = chat_model.generate(
        chat_input,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,
        eos_token_id=chat_tok.eos_token_id,
    )[0]
    # Decode only the newly generated tokens, skipping the prompt portion
    answer = chat_tok.decode(gen_output[chat_input.shape[1]:], skip_special_tokens=True)
    return answer.strip()
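# Minimal usage sketch (assumes both checkpoints load successfully; the sample
# question below is illustrative only, not taken from the MentalQA dataset):
if __name__ == "__main__":
    sample_question = "أعاني من قلق مستمر وصعوبة في النوم، ماذا أفعل؟"
    # "I suffer from constant anxiety and trouble sleeping; what should I do?"
    print("Question type:", classify_question(sample_question))
    print("Answer:", generate_mentalqa_answer(sample_question))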