hin123123 commited on
Commit
97cc9f8
·
verified ·
1 Parent(s): 42f231d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -15
app.py CHANGED
@@ -2,8 +2,10 @@ import gradio as gr
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
- MODEL_ID = "hin123123/gemma2-2b-it-slp-merged" # merged repo from Colab
 
6
 
 
7
  tokenizer = AutoTokenizer.from_pretrained(
8
  MODEL_ID,
9
  use_fast=True,
@@ -24,17 +26,39 @@ model = AutoModelForCausalLM.from_pretrained(
24
  )
25
  model.eval()
26
 
 
27
  SYSTEM_PROMPT = (
28
- "You are a speech-language pathology assistant. "
29
- "You analyze child speech production errors and respond ONLY with JSON. "
30
- "Use concise, accurate outputs."
31
  )
32
 
33
- def run(user_text, max_new_tokens=256, temperature=0.0, top_p=1.0, repetition_penalty=1.05):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  messages = [
35
  {"role": "system", "content": SYSTEM_PROMPT},
36
  {"role": "user", "content": user_text},
37
  ]
 
38
  prompt = tokenizer.apply_chat_template(
39
  messages,
40
  tokenize=False,
@@ -56,27 +80,31 @@ def run(user_text, max_new_tokens=256, temperature=0.0, top_p=1.0, repetition_pe
56
  )
57
 
58
  gen = out_ids[0, inputs["input_ids"].shape[1]:]
59
- return tokenizer.decode(gen, skip_special_tokens=True).strip()
 
 
60
 
61
  demo = gr.Interface(
62
  fn=run,
63
  inputs=[
64
  gr.Textbox(
65
- label="User text / Case JSON",
66
- lines=4,
67
- value=(
68
- "Instructions: Classify or reflect the user's spoken attempt using Substitution, "
69
- "Omission, or Addition, and respond with JSON only.\n\n"
70
- 'Case JSON: {"target": "mop", "ipa_target": "/mɑp/", "attempt": "mo", "ipa_attempt": "/mɑ/"}'
71
- ),
72
  ),
73
  gr.Slider(8, 1024, 256, step=1, label="max_new_tokens"),
74
  gr.Slider(0, 1, 0.0, step=0.05, label="temperature"),
75
  gr.Slider(0.1, 1.0, 1.0, step=0.05, label="top_p"),
76
  gr.Slider(1.0, 1.5, 1.05, step=0.01, label="repetition_penalty"),
77
  ],
78
- outputs=gr.Textbox(label="Model output"),
79
- title="Gemma-2-2B-IT SLP JSON API (Merged)",
 
 
 
 
 
 
80
  api_name="run",
81
  )
82
 
 
2
  import torch
3
  from transformers import AutoTokenizer, AutoModelForCausalLM
4
 
5
+ # ===== Model =====
6
+ MODEL_ID = "hin123123/gemma2-2b-it-slp-merged" # merged model you created
7
 
8
+ # ===== Tokenizer & Model Load =====
9
  tokenizer = AutoTokenizer.from_pretrained(
10
  MODEL_ID,
11
  use_fast=True,
 
26
  )
27
  model.eval()
28
 
29
+ # ===== Prompts =====
30
  SYSTEM_PROMPT = (
31
+ "You are an articulation/phonology error expert SLP assistant. "
32
+ "You only respond with valid JSON, never explanations."
 
33
  )
34
 
35
+ TRAIN_INSTRUCTION = (
36
+ "Instructions: Classify or reflect the user's spoken attempt using Substitution, Omission, or Addition. "
37
+ "Include subtype and return JSON with keys: disorder, category, subtype, target, attempt, ipa_target, ipa_attempt, "
38
+ "correct_rate(pcc), severity_level, evidence, suggestions. "
39
+ "If present, also include task scaffolding keys: task_type, times_read, marks, daily_max_marks, task_label, "
40
+ "sentence_target, sentence_attempt. Return JSON only.\n"
41
+ )
42
+
43
+
44
+ def run(case_json: str,
45
+ max_new_tokens: int = 256,
46
+ temperature: float = 0.0,
47
+ top_p: float = 1.0,
48
+ repetition_penalty: float = 1.05):
49
+
50
+ case_json = case_json.strip()
51
+ if not case_json:
52
+ return "{}", # empty JSON
53
+
54
+ # Build user text exactly like training
55
+ user_text = TRAIN_INSTRUCTION + "Case JSON:\n" + case_json
56
+
57
  messages = [
58
  {"role": "system", "content": SYSTEM_PROMPT},
59
  {"role": "user", "content": user_text},
60
  ]
61
+
62
  prompt = tokenizer.apply_chat_template(
63
  messages,
64
  tokenize=False,
 
80
  )
81
 
82
  gen = out_ids[0, inputs["input_ids"].shape[1]:]
83
+ text = tokenizer.decode(gen, skip_special_tokens=True).strip()
84
+ return text
85
+
86
 
87
  demo = gr.Interface(
88
  fn=run,
89
  inputs=[
90
  gr.Textbox(
91
+ label="Case JSON",
92
+ lines=6,
93
+ value='{"target": "recording", "ipa_target": "/ɹəˈkɔɹdɪŋ/", "attempt": "wecording", "ipa_attempt": "/wəˈkɔɹdɪŋ/"}',
 
 
 
 
94
  ),
95
  gr.Slider(8, 1024, 256, step=1, label="max_new_tokens"),
96
  gr.Slider(0, 1, 0.0, step=0.05, label="temperature"),
97
  gr.Slider(0.1, 1.0, 1.0, step=0.05, label="top_p"),
98
  gr.Slider(1.0, 1.5, 1.05, step=0.01, label="repetition_penalty"),
99
  ],
100
+ outputs=gr.Textbox(label="Model output (JSON expected)"),
101
+ title="Gemma-2-2B-IT SLP JSON API (Merged, 283k dataset)",
102
+ description=(
103
+ "Paste a single case as JSON (target, attempt, ipa_target, ipa_attempt, etc.).\n"
104
+ "The model was fine-tuned to output JSON with keys: "
105
+ "disorder, category, subtype, target, attempt, ipa_target, ipa_attempt, "
106
+ "correct_rate(pcc), severity_level, evidence, suggestions, and optional task_* keys."
107
+ ),
108
  api_name="run",
109
  )
110