NLPAlhuzali committed on
Commit de736d7 · verified · 1 Parent(s): 51438c7

Update models/space_b.py

Files changed (1):
  1. models/space_b.py  +23 -117
models/space_b.py CHANGED
@@ -1,123 +1,58 @@
-# app.py ────────────────────────────────────────────────────────────
-"""
-MentalQA demo Space
-Loads:
-  • yasser-alharbi/MentalQA (ALLaM-7B-based chat model)
-  • yasser-alharbi/MentalQA-Classification (final_QT intent classifier)
-and exposes an Arabic RTL Gradio interface.
-"""
+import torch
+from transformers import (AutoTokenizer, AutoModelForCausalLM,
+                          AutoModelForSequenceClassification, pipeline)
 
-
-import torch, gradio as gr
-from transformers import (AutoTokenizer,
-                          AutoModelForCausalLM,
-                          AutoModelForSequenceClassification,
-                          pipeline)
-
-# ─────────────────── HF repos ───────────────────────────────────────
-CHAT_REPO = "yasser-alharbi/MentalQA"
+CHAT_REPO = "yasser-alharbi/MentalQA"
 CLASSIFIER_REPO = "yasser-alharbi/MentalQA-Classification"
 
-# ─────────────────── Load chat model ────────────────────────────────
 chat_tok = AutoTokenizer.from_pretrained(CHAT_REPO, use_fast=False)
 chat_model = AutoModelForCausalLM.from_pretrained(
     CHAT_REPO,
     torch_dtype="auto",
-    device_map="auto",   # works for CPU or GPU Space
+    device_map="auto",
     low_cpu_mem_usage=True,
 )
 
-# ─────────────────── Load classifier ────────────────────────────────
 clf_tok = AutoTokenizer.from_pretrained(CLASSIFIER_REPO)
 clf_model = AutoModelForSequenceClassification.from_pretrained(CLASSIFIER_REPO)
 
 device_idx = 0 if torch.cuda.is_available() else -1
-clf_pipe = pipeline("text-classification",
-                    model=clf_model,
-                    tokenizer=clf_tok,
-                    device=device_idx)
-
+clf_pipe = pipeline("text-classification", model=clf_model, tokenizer=clf_tok, device=device_idx)
 
 label_map = {
-    "LABEL_0": "A",
-    "LABEL_1": "B",
-    "LABEL_2": "C",
-    "LABEL_3": "D",
-    "LABEL_4": "E",
-    "LABEL_5": "F",
-    "LABEL_6": "G",
+    "LABEL_0": "A", "LABEL_1": "B", "LABEL_2": "C",
+    "LABEL_3": "D", "LABEL_4": "E", "LABEL_5": "F", "LABEL_6": "G"
 }
 
-# ─────────────────── Prompt helpers ─────────────────────────────────
 SYSTEM_MSG = (
-    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA"
+    "أنت مساعد ذكي للصحة النفسية اسمه MentalQA. "
     "لا تذكر اسمك أو منصة عملك إلا إذا سُئلت صراحةً عن هويتك."
-    "بالإضافة إلى ذلك:\n"
-    "عندما يحييك أحد بتحية عربية:\n"
-    "  - السلام عليكم => وعليكم السلام\n"
-    "  - صباح الخير => صباح النور\n"
-    "  - مساء الخير => مساء النور\n\n"
 )
 
-def build_prompt_arabic(question, final_qt_list):
-    qt_str = ", ".join(final_qt_list)
-
-    prompt = (
-        # ── Core rules ──────────────────────────────────────────
-        "أجب باللغة العربية استنادًا إلى القواعد التالية:\n"
-        "1) هذه ليست استشارة طبية بديلة؛ قدّم إرشادات عامة وتمهيدية.\n"
-        "2) لا تستخدم أسماء شخصية أو تدّعي ملكية.\n"
-        "3) إذا كان السؤال خارج الصحة النفسية، قل: 'عذراً، ولكن هذا السؤال خارج نطاق قدرتي.'\n"
-        "4) استرشد بقيم final_QT (A تشخيص، B علاج، C تشريح، D وبائيات، "
-        "E نمط حياة، F خيارات مقدم الخدمة، G أخرى).\n"
-        "5) إذا كانت حالة المريض حرجة، أبدِ تعاطفك أولاً ثم وجّه النصيحة.\n"
-        "6) إذا احتاج المريض لتوجيه مباشر، قل: 'قد يفيد التواصل مع مختص نفسي أو مستشار موثوق.'\n\n"
-
-        # ── Few-shot exemplar WITH reasoning ─────────────────────
-        "مثال توضيحي للإجابة المفصّلة مع خطوات التفكير:\n"
-        "سؤال: أشعر بإرهاقٍ مستمر ولا أستطيع التركيز، ماذا أفعل؟\n"
-        "التفكير خطوة بخطوة:\n"
-        "1) تحديد ما إذا كان الإرهاق جسدياً أم نفسياً.\n"
-        "2) فحص نمط النوم والعادات اليومية.\n"
-        "3) التفكير في عوامل الضغط والرعاية الذاتية.\n"
-        "4) وضع خطة من نصائح تدريجية سهلة التطبيق.\n"
-        "الإجابة النهائية:\n"
-        "قد يرتبط الإرهاق بعدم انتظام النوم أو بضغوطٍ نفسية متراكمة. "
-        "من المهم أولاً مراجعة نمط حياتك: اضبط مواعيد نوم ثابتة، وابتعد عن المنبّهات قبل النوم بساعتين. "
-        "مارس المشي الخفيف أو تمارين الاسترخاء يوميّاً لتخفيف التوتر. "
-        "إذا استمر الإرهاق أكثر من أسبوعين رغم هذه التغييرات، فكر في زيارة طبيب لفحص فيتامين د ووظائف الغدة الدرقية. "
-        "دوّن مشاعرك في مفكرة يومية لتفريغ القلق وتشخيص الأسباب بدقة.\n"
-        "—\n\n"
-
-        # ── User section ─────────────────────────────────────────
-        f"final_QT: {qt_str}\n\n"
-        "سؤال المستخدم:\n"
-        f"{question}\n\n"
-
-        # ── Final directive ──────────────────────────────────────
-        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة، \n"
+def classify_question(text: str, thr: float = 0.5) -> str:
+    pred = max(clf_pipe(text), key=lambda x: x["score"])
+    return label_map.get(pred["label"], pred["label"]) if pred["score"] >= thr else "G"
+
+def build_prompt(question: str, tag: str) -> str:
+    return (
+        f"{SYSTEM_MSG}\n\nfinal_QT: {tag}\n\n"
+        f"سؤال المستخدم:\n{question}\n\n"
+        "اكتب فقرة واحدة مفصّلة لا تقل عن ثلاث جمل مترابطة، بعد أن تفكّر خطوة بخطوة.\n"
         "الإجابة النهائية:\n"
     )
-    return prompt
-
-def classify_question(text: str, thr: float = 0.5):
-    pred = max(clf_pipe(text), key=lambda x: x["score"])
-    return label_map.get(pred["label"], pred["label"]) if pred["score"] >= thr else "G"
 
-def chat_generate(prompt: str, max_new_tokens: int = 128):
+def generate_mentalqa_answer(question: str) -> str:
+    tag = classify_question(question)
+    prompt = build_prompt(question, tag)
     chat_ids = chat_tok.apply_chat_template(
-        [{"role": "system", "content": SYSTEM_MSG},
-         {"role": "user", "content": prompt}],
+        [{"role": "system", "content": SYSTEM_MSG}, {"role": "user", "content": prompt}],
         add_generation_prompt=True,
         return_tensors="pt"
     ).to(chat_model.device)
 
     gen_ids = chat_model.generate(
         chat_ids,
-        max_new_tokens=max_new_tokens,
+        max_new_tokens=128,
         do_sample=True,
         temperature=0.6,
         top_p=0.95,
@@ -128,33 +63,4 @@ def chat_generate(prompt: str, max_new_tokens: int = 128):
     )[0]
 
     answer_ids = gen_ids[chat_ids.shape[1]:]
-    return chat_tok.decode(answer_ids,
-                           skip_special_tokens=True,
-                           clean_up_tokenization_spaces=True).strip()
-
-def get_mentalqa_answer(question: str, thr: float = 0.5):
-    tag = classify_question(question, thr)
-    prompt = build_prompt_arabic(question, tag)
-    return chat_generate(prompt)
-
-# ─────────────────── Gradio UI ───────────────────────────────────────
-CSS = """
-#container{max-width:640px;margin:1.5rem auto;}
-#question_box label,#answer_box label,
-#question_box textarea,#answer_box textarea{
-    direction:rtl;text-align:right;
-}
-"""
-
-with gr.Blocks(css=CSS, theme="soft") as demo:
-    gr.Markdown("<h2 style='text-align:center;'>🧠 MentalQA – مساعد الصحة النفسية</h2>"
-                "<p style='text-align:center;'>اكتب سؤالك النفسي باللغة العربية وسيجيبك النموذج.</p>")
-    with gr.Group(elem_id="container"):
-        q = gr.Textbox(lines=3, placeholder="اكتب سؤالك هنا...", label="سؤال:", elem_id="question_box")
-        a = gr.Textbox(lines=5, label="الإجابة:", elem_id="answer_box")
-        btn = gr.Button("إرسال")
-    btn.click(get_mentalqa_answer, inputs=q, outputs=a)
-    q.submit(get_mentalqa_answer, inputs=q, outputs=a)
-
-if __name__ == "__main__":
-    demo.launch()
+    return chat_tok.decode(answer_ids, skip_special_tokens=True).strip()
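
With the Gradio UI removed by this commit, space_b.py now exposes plain functions rather than launching a demo: classify_question() maps a question to a final_QT label (A-G, falling back to "G" below the 0.5 confidence threshold), and generate_mentalqa_answer() builds the Arabic prompt and samples an answer. The system message introduces the model as a mental-health assistant named MentalQA, and the final directive asks for one detailed paragraph of at least three connected sentences after step-by-step thinking. A minimal usage sketch follows; the import path models.space_b and the sample question are illustrative assumptions, not part of the commit:

# Minimal usage sketch -- assumptions: this file is importable as
# models.space_b, and importing it downloads/loads both Hugging Face
# checkpoints, so the first call is slow. The sample question is hypothetical.
from models.space_b import classify_question, generate_mentalqa_answer

# "I feel constant anxiety and cannot sleep, what should I do?"
question = "أشعر بقلق مستمر ولا أستطيع النوم، ماذا أفعل؟"

tag = classify_question(question)            # final_QT label "A".."G"; "G" when confidence < 0.5
answer = generate_mentalqa_answer(question)  # classify -> build Arabic prompt -> sample up to 128 new tokens
print(f"final_QT={tag}\n{answer}")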