hparten commited on
Commit
d4a01e1
·
1 Parent(s): aa02b62
Files changed (2) hide show
  1. app.py +219 -64
  2. requirements.txt +8 -0
app.py CHANGED
@@ -1,70 +1,225 @@
 
 
 
 
 
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient
3
-
4
-
5
- def respond(
6
- message,
7
- history: list[dict[str, str]],
8
- system_message,
9
- max_tokens,
10
- temperature,
11
- top_p,
12
- hf_token: gr.OAuthToken,
13
- ):
14
- """
15
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
16
- """
17
- client = InferenceClient(token=hf_token.token, model="openai/gpt-oss-20b")
18
-
19
- messages = [{"role": "system", "content": system_message}]
20
-
21
- messages.extend(history)
22
-
23
- messages.append({"role": "user", "content": message})
24
-
25
- response = ""
26
-
27
- for message in client.chat_completion(
28
- messages,
29
- max_tokens=max_tokens,
30
- stream=True,
31
- temperature=temperature,
32
- top_p=top_p,
33
- ):
34
- choices = message.choices
35
- token = ""
36
- if len(choices) and choices[0].delta.content:
37
- token = choices[0].delta.content
38
-
39
- response += token
40
- yield response
41
-
42
-
43
- """
44
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
45
- """
46
- chatbot = gr.ChatInterface(
47
- respond,
48
- type="messages",
49
- additional_inputs=[
50
- gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
51
- gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
52
- gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
53
- gr.Slider(
54
- minimum=0.1,
55
- maximum=1.0,
56
- value=0.95,
57
- step=0.05,
58
- label="Top-p (nucleus sampling)",
59
- ),
60
- ],
61
  )
62
 
63
- with gr.Blocks() as demo:
64
- with gr.Sidebar():
65
- gr.LoginButton()
66
- chatbot.render()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
 
 
 
 
 
68
 
69
  if __name__ == "__main__":
70
- demo.launch()
 
 
1
import os
import csv
import uuid
from datetime import datetime
import torch
import gradio as gr
from filelock import FileLock
from huggingface_hub import HfApi
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
from peft import PeftModel

# =========================
# ⚙️ Config
# =========================
MAX_HISTORY_TURNS = 10    # max teacher/student exchanges kept when building the prompt
MAX_PROMPT_TOKENS = 1024  # encoded-prompt cap; older turns are dropped until it fits
MAX_NEW_TOKENS = 60       # cap on tokens generated per student reply

# Local scratch directory for per-session CSV chat logs (ephemeral on Spaces).
LOG_DIR = "/tmp/chat_logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOCK_PATH = os.path.join(LOG_DIR, ".lock")  # single lock file shared by all sessions

# Token for pushing logs to the private dataset repo; None if the secret is unset
# (uploads will then fail and be logged as warnings).
HF_TOKEN = os.environ.get("HF_TOKEN")
PRIVATE_LOG_REPO = "hparten/math_chat_logs" # Private dataset repo
HF_API = HfApi()

MODEL_ID = "hparten/prob1_qlora_math_student"

# =========================
# 🔠 Model + Tokenizer
# =========================
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token  # model ships without a dedicated pad token

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    # NOTE(review): older transformers releases take `torch_dtype=` here, not
    # `dtype=` — confirm the pinned transformers version accepts this kwarg.
    dtype=torch.float16,
    device_map="auto",
)
43
 
44
# =========================
# 🧩 Strategy Explanations
# =========================
strategy_explanations = {
    "friendly": "You add on from 41 until you get to 84, usually by counting by 10s, 20s, or 40, then ones.",
    "differencing": "You difference the ones or tens separately during any part of your answer.",
    "subtraction": "You turn the problem into a subtraction: 84 minus 41 equals blank to find the missing addend.",
}

# =========================
# 🧠 System Prompt
# =========================
def build_system_block(problem_prefix, strategy):
    """Assemble the <system> header that frames the model as a math student.

    problem_prefix: label used for the problem tag (e.g. "Problem_1").
    strategy: key into strategy_explanations (matched case-insensitively);
        unknown strategies fall back to a generic instruction.
    Returns the system-prompt string with surrounding whitespace stripped.
    """
    problem_text = "41 plus blank equals 84"
    key = strategy.lower()
    explanation = strategy_explanations.get(
        key, "Use the named strategy to explain your steps clearly."
    )

    header_lines = [
        "<system>",
        "You are the student in a math dialogue.",
        f"PROBLEM: <{problem_prefix.lower()}> - {problem_text}",
        f"STRATEGY: <strategy_{key}> — {explanation}",
        "When you answer, think step by step, like a student explaining their work out loud.",
        "Keep your answers short and natural—1 sentence. Let the teacher ask follow-up questions.",
        "Reply exactly to the teacher questions using <student> ... </student>. Never include any teacher text in your answer.",
        "</system>",
    ]
    return "\n".join(header_lines).strip()
74
+
75
# =========================
# 🧾 Logging (Private Upload)
# =========================
CSV_HEADERS = ["timestamp", "session_id", "username", "strategy", "teacher", "student"]

def log_turn(session_id, username, strategy, teacher_msg, student_msg):
    """Append one teacher/student exchange to the session CSV, then mirror the
    file to the private HF dataset repo.

    The upload is best-effort: any failure (missing token, network, repo
    permissions) is printed and swallowed so the chat keeps working.
    """
    path = os.path.join(LOG_DIR, f"chat_{session_id}.csv")

    with FileLock(LOCK_PATH):
        # BUG FIX: check existence *inside* the lock. The original checked
        # before acquiring it, so two concurrent writers could both decide
        # the file was new and each write a header row.
        write_header = not os.path.exists(path)
        with open(path, "a", newline="", encoding="utf-8") as f:
            writer = csv.writer(f)
            if write_header:
                writer.writerow(CSV_HEADERS)
            writer.writerow([
                datetime.now().isoformat(timespec="seconds"),
                session_id,
                username,
                strategy,
                teacher_msg,
                student_msg,
            ])

    # --- Try uploading to private dataset repo ---
    try:
        HF_API.upload_file(
            path_or_fileobj=path,
            path_in_repo=os.path.basename(path),
            repo_id=PRIVATE_LOG_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
        print(f"✅ Uploaded log to private dataset: {PRIVATE_LOG_REPO}")
    except Exception as e:
        print(f"⚠️ Could not push log: {e}")
110
+
111
# =========================
# 🧩 Prompt Builder
# =========================
def build_prompt(strategy, history, teacher_question, tokenizer, problem_prefix="Problem_1"):
    """Build the generation prompt: system block + recent dialogue + new question.

    strategy: strategy key forwarded to build_system_block.
    history: list of (teacher, student) tuples; only the last
        MAX_HISTORY_TURNS are considered.
    teacher_question: the new teacher turn to append.
    tokenizer: used only to measure encoded prompt length.
    Returns the stripped prompt string, guaranteed (while any turns remain)
    to encode to at most MAX_PROMPT_TOKENS tokens.
    """
    system_block = build_system_block(problem_prefix, strategy)
    turns = [
        f"<teacher> {tq} </teacher> <student> {sa} </student>"
        for tq, sa in history[-MAX_HISTORY_TURNS:]
    ]

    def assemble(kept_turns):
        # BUG FIX: one consistent layout for the first attempt AND every
        # retry — the original dropped the "\n" separator (and the trailing
        # newline) when it rebuilt the prompt inside the trim loop, so
        # trimmed prompts were formatted differently from untrimmed ones.
        return (
            system_block + "\n" + " ".join(kept_turns)
            + f"<teacher> {teacher_question} </teacher>"
        ).strip()

    prompt = assemble(turns)
    # Drop the oldest exchanges first until the prompt fits the token budget.
    while turns and len(tokenizer.encode(prompt, add_special_tokens=False)) > MAX_PROMPT_TOKENS:
        turns.pop(0)
        prompt = assemble(turns)

    return prompt
126
+
127
# =========================
# 🤖 Generation
# =========================
def generate_response(teacher_question, username, history, session_id, strategy):
    """Run one turn: prompt the model, extract the <student> reply, log the
    exchange, and return (student_reply, updated_history)."""
    prompt = build_prompt(strategy, history, teacher_question, tokenizer)

    generations = pipe(
        prompt,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.5,
        top_p=0.9,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        return_full_text=False,
    )
    raw_text = generations[0]["generated_text"]

    # Keep only the text inside the first <student>...</student> pair;
    # fall back to the whole generation when the tags are absent.
    student_reply = raw_text.strip()
    if "<student>" in raw_text and "</student>" in raw_text:
        inner = raw_text.split("<student>", 1)[1]
        student_reply = inner.split("</student>", 1)[0].strip()

    history.append((teacher_question, student_reply))
    log_turn(session_id, username, strategy, teacher_question, student_reply)
    return student_reply, history
152
+
153
# =========================
# 🖥 Gradio UI
# =========================
def _history_to_messages(history):
    # Convert (teacher, student) tuples into the role/content dict list that
    # a Chatbot component with type="messages" expects.
    msgs = []
    for teacher_turn, student_turn in history[-MAX_HISTORY_TURNS:]:
        msgs.append({"role": "user", "content": teacher_turn})
        msgs.append({"role": "assistant", "content": student_turn})
    return msgs

def on_send(teacher_question, username, strategy_choice, history, session_id):
    """Handle the Send click.

    Returns (chat_messages, history_state, cleared_textbox, session_id) in the
    order wired to the click's outputs. Validation failures warn and leave the
    conversation unchanged.
    """
    if not session_id:
        session_id = uuid.uuid4().hex[:12]
    if history is None:
        history = []
    if not username.strip():
        gr.Warning("Please enter your name before starting the chat.")
        # BUG FIX: the original returned the raw tuple history here, which a
        # type="messages" Chatbot cannot render once a conversation exists.
        return _history_to_messages(history), history, "", session_id
    if not teacher_question.strip():
        gr.Warning("Please type a question for the student before sending.")
        return _history_to_messages(history), history, "", session_id

    _, history = generate_response(
        teacher_question.strip(),
        username.strip(),
        history,
        session_id,
        strategy_choice.lower(),
    )

    return _history_to_messages(history), history, "", session_id
182
+
183
def on_reset():
    """Clear the chat display, the history state, and the question box, and
    start a fresh logging session id."""
    fresh_session = uuid.uuid4().hex[:12]
    return [], [], "", fresh_session
185
+
186
# =========================
# 🚀 Gradio App
# =========================
with gr.Blocks(title="Elementary Math Student Chatbot") as demo:
    gr.Markdown("## 🧮 Practice Eliciting Student Thinking (Prototype)")
    gr.Markdown(
        "You are an elementary math teacher exploring a student's reasoning for **41 + ___ = 84**.\n"
        "Ask questions and see how the student explains their thinking."
    )

    with gr.Row():
        username = gr.Textbox(label="👤 Your Name", placeholder="Enter your name...")
        strategy_choice = gr.Dropdown(
            ["friendly", "differencing", "subtraction"],
            value="friendly",
            label="🧩 Student Strategy",
        )
        reset_btn = gr.Button("🔄 Start Over", variant="secondary")

    teacher_q = gr.Textbox(label="👩‍🏫 Teacher Question", placeholder="Ask the student a question…")
    chat = gr.Chatbot(label="💬 Chat", type="messages")  # expects role/content dict list
    state_history = gr.State([])  # per-browser-session list of (teacher, student) tuples
    state_session = gr.State("")  # lazily assigned a 12-hex session id on first send
    send = gr.Button("Send", variant="primary")

    # outputs order must match on_send's return: (chat, history, textbox, session id)
    send.click(
        on_send,
        inputs=[teacher_q, username, strategy_choice, state_history, state_session],
        outputs=[chat, state_history, teacher_q, state_session],
    )

    reset_btn.click(
        on_reset,
        inputs=[],
        outputs=[chat, state_history, teacher_q, state_session],
    )

if __name__ == "__main__":
    demo.queue()  # serialize concurrent requests through Gradio's queue
    demo.launch()
requirements.txt ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ torch>=2.1.0
2
+ transformers>=4.44.0
3
+ peft>=0.10.0
4
+ accelerate>=0.30.0
5
+ bitsandbytes>=0.43.0
6
+ huggingface_hub>=0.23.0
7
+ gradio>=4.40.0
8
+ filelock>=3.12.0