File size: 2,746 Bytes
74ef6d5
 
de736d7
74ef6d5
 
 
 
dc00d03
de736d7
dc00d03
 
74ef6d5
dc00d03
 
 
 
de736d7
dc00d03
 
 
74ef6d5
dc00d03
 
 
 
de736d7
dc00d03
 
74ef6d5
 
 
 
 
 
 
dc00d03
 
 
de736d7
dc00d03
 
 
74ef6d5
de736d7
74ef6d5
dc00d03
74ef6d5
de736d7
74ef6d5
 
de736d7
 
dc00d03
 
 
74ef6d5
 
 
 
 
 
 
dc00d03
 
 
 
74ef6d5
 
de736d7
dc00d03
 
 
 
 
 
 
 
 
74ef6d5
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
# MentalQA โ€“ Arabic Mental Health Assistant (chat + classifier)

import torch
from transformers import (
    AutoTokenizer, AutoModelForCausalLM,
    AutoModelForSequenceClassification, pipeline
)

# Hugging Face repo IDs for the two models this script wires together:
# a causal-LM chat model and a question-type classifier.
CHAT_REPO = "yasser-alharbi/MentalQA"
CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"

# Load chat model.
# NOTE: these module-level loads download/instantiate the models at import
# time (network + significant memory). device_map="auto" lets accelerate
# place weights on GPU when available; torch_dtype="auto" keeps the dtype
# stored in the checkpoint.
chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
chat_model = AutoModelForCausalLM.from_pretrained(
    CHAT_REPO,
    torch_dtype="auto",
    device_map="auto",
    low_cpu_mem_usage=True,
)

# Load classifier (kept on the default device until wrapped in the pipeline).
clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)

# transformers pipelines take a device index: 0 = first CUDA GPU, -1 = CPU.
device_idx = 0 if torch.cuda.is_available() else -1
clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)

# Map the classifier's raw output labels to single-letter question-type codes.
label_map = {
    "LABEL_0": "A",  # diagnosis
    "LABEL_1": "B",  # treatment
    "LABEL_2": "C",  # anatomy
    "LABEL_3": "D",  # epidemiology
    "LABEL_4": "E",  # lifestyle
    "LABEL_5": "F",  # service provider
    "LABEL_6": "G",  # other
}

# Arabic system persona for the chat model ("You are a mental-health
# assistant named MentalQA; don't reveal your name/platform unless asked").
SYSTEM_MSG = (
    "ุฃู†ุช ู…ุณุงุนุฏ ุฐูƒูŠ ู„ู„ุตุญุฉ ุงู„ู†ูุณูŠุฉ ุงุณู…ู‡ MentalQA. "
    "ู„ุง ุชุฐูƒุฑ ุงุณู…ูƒ ุฃูˆ ู…ู†ุตุฉ ุนู…ู„ูƒ ุฅู„ุง ุฅุฐุง ุณูุฆู„ุช ุตุฑุงุญุฉู‹ ุนู† ู‡ูˆูŠุชูƒ."
)

def classify_question(text: str, threshold: float = 0.5) -> str:
    """Classify an Arabic question into a single-letter question-type code.

    Runs ``text`` through the classification pipeline, picks the
    highest-scoring prediction, and translates its raw label
    (``LABEL_0``..``LABEL_6``) via ``label_map``. Falls back to ``"G"``
    (other) when the top score is below ``threshold`` or the label is
    not in the map.
    """
    top = None
    for candidate in clf_pipe(text):
        # Keep the first candidate with the strictly highest score
        # (same tie-breaking as max()).
        if top is None or candidate["score"] > top["score"]:
            top = candidate

    if top["score"] < threshold:
        return "G"
    return label_map.get(top["label"], "G")

def build_prompt(question: str, final_qt: str) -> str:
    """Assemble the full user-turn prompt sent to the chat model.

    The prompt stacks the system persona, the predicted question-type
    code (``final_QT``), the user's question, and an Arabic instruction
    asking for one detailed paragraph of at least three connected
    sentences after step-by-step thinking, ending with a "final answer"
    marker line.
    """
    segments = [
        SYSTEM_MSG,
        "\n\n",
        "final_QT: ",
        final_qt,
        "\n\n",
        "ุณุคุงู„ ุงู„ู…ุณุชุฎุฏู…:\n",
        question,
        "\n\n",
        "ุงูƒุชุจ ูู‚ุฑุฉ ูˆุงุญุฏุฉ ู…ูุตู‘ู„ุฉ ู„ุง ุชู‚ู„ ุนู† ุซู„ุงุซ ุฌู…ู„ ู…ุชุฑุงุจุทุฉุŒ ุจุนุฏ ุฃู† ุชููƒู‘ุฑ ุฎุทูˆุฉ ุจุฎุทูˆุฉ.\n",
        "ุงู„ุฅุฌุงุจุฉ ุงู„ู†ู‡ุงุฆูŠุฉ:\n",
    ]
    return "".join(segments)

def generate_mentalqa_answer(question: str, threshold: float = 0.5) -> str:
    """End-to-end pipeline: classify the question, then generate an answer.

    Steps: predict the question-type code with ``classify_question``,
    build the prompt, render it through the tokenizer's chat template,
    sample a completion from the chat model, and decode only the newly
    generated tokens (stripped of special tokens and whitespace).
    """
    category = classify_question(question, threshold)

    # NOTE(review): SYSTEM_MSG appears both as the system turn and inside
    # the user prompt (build_prompt embeds it) — presumably intentional
    # prompt engineering; confirm.
    messages = [
        {"role": "system", "content": SYSTEM_MSG},
        {"role": "user", "content": build_prompt(question, category)},
    ]

    input_ids = chat_tok.apply_chat_template(
        messages,
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(chat_model.device)

    # Sampling config: moderate temperature/top-p plus repetition controls;
    # pad/eos pinned to the tokenizer's EOS so generation stops cleanly.
    generated = chat_model.generate(
        input_ids,
        max_new_tokens=128,
        do_sample=True,
        temperature=0.6,
        top_p=0.95,
        repetition_penalty=1.15,
        no_repeat_ngram_size=2,
        pad_token_id=chat_tok.eos_token_id,
        eos_token_id=chat_tok.eos_token_id,
    )

    # Drop the prompt tokens; keep only the model's continuation.
    prompt_len = input_ids.shape[1]
    continuation = generated[0][prompt_len:]
    return chat_tok.decode(continuation, skip_special_tokens=True).strip()