import os
import tempfile
import threading
import time
import uuid
from datetime import datetime
from typing import List

import gradio as gr
import pandas as pd
import torch
from datasets import load_dataset
from huggingface_hub import HfApi
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
)

MAX_HISTORY_TURNS = 10    # most recent (teacher, student) turns kept in the prompt
MAX_PROMPT_TOKENS = 1024  # token budget for the assembled prompt
MAX_NEW_TOKENS = 60       # cap on generated student tokens per turn

HF_TOKEN = os.environ.get("HF_TOKEN")

MODEL_ID = "hparten/prob1_qlora_math_student"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
tokenizer.pad_token = tokenizer.eos_token

pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    dtype=torch.float16,
    device_map="auto",
)

strategy_explanations = {
    "friendly": "You add on from 41 until you get to 84, usually by counting by 10s, 20s, or 40, then ones.",
    "differencing": "You difference the ones or tens separately during any part of your answer.",
    "subtraction": "You turn the problem into a subtraction: 84 minus 41 equals blank to find the missing addend.",
}
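
# NOTE: these keys are expected to match the lowercased Dropdown choices in the UI
# ("friendly", "differencing", "subtraction"); an unrecognized strategy falls back
# to a generic instruction inside build_system_block().
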
def build_system_block(problem_prefix, strategy):
    problem_text = "41 plus blank equals 84"
    strat_key = strategy.lower()
    strat_expl = strategy_explanations.get(strat_key, "Use the named strategy to explain your steps clearly.")
    strategy_tag = f"<strategy_{strat_key}>"
    problem_tag = f"<{problem_prefix.lower()}>"

    system_text = (
        f"<system>\n"
        f"You are the student in a math dialogue.\n"
        f"Solving the PROBLEM: {problem_tag} - {problem_text}\n"
        f"Using the STRATEGY: {strategy_tag} - {strat_expl}\n"
        f"Return EXACTLY one sentence inside <student> ... </student>.\n"
        f"Do NOT ask questions or include teacher text.\n"
        f"Mention the strategy implicitly only if natural.\n"
        f"</system>\n"
    )
    return system_text.strip()
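
# Illustrative output of build_system_block("Problem_1", "friendly"), whitespace
# approximate:
#   <system>
#   You are the student in a math dialogue.
#   Solving the PROBLEM: <problem_1> - 41 plus blank equals 84
#   Using the STRATEGY: <strategy_friendly> - You add on from 41 until you get to 84, ...
#   Return EXACTLY one sentence inside <student> ... </student>.
#   Do NOT ask questions or include teacher text.
#   Mention the strategy implicitly only if natural.
#   </system>
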
def build_prompt(strategy, history, teacher_question, tokenizer, problem_prefix="Problem_1"):
    base_system_prompt = build_system_block(problem_prefix, strategy)

    turns = []
    for tq, sa in history[-MAX_HISTORY_TURNS:]:
        turns.append(f"<teacher> {tq} </teacher> <student> {sa} </student>")

    def assemble():
        convo_block = " ".join(turns)
        return f"{base_system_prompt}\n{convo_block} <teacher> {teacher_question} </teacher>\n<student>".strip()

    full_prompt = assemble()

    # Drop the oldest turns until the prompt fits the token budget; the rebuilt
    # prompt keeps the same layout and still ends with the open <student> tag.
    while len(tokenizer.encode(full_prompt, add_special_tokens=False)) > MAX_PROMPT_TOKENS and turns:
        turns.pop(0)
        full_prompt = assemble()

    return full_prompt
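
# Illustrative prompt shape with one prior turn in the history (whitespace
# approximate; generation continues from the trailing open <student> tag):
#   <system> ... </system>
#   <teacher> q1 </teacher> <student> a1 </student> <teacher> q2 </teacher>
#   <student>
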
def make_bad_words_ids(tokenizer, words: List[str]) -> List[List[int]]:
    """Safely constructs bad_words_ids for special and normal tokens."""
    out = []
    for w in words:
        if w in tokenizer.all_special_tokens:
            tid = tokenizer.convert_tokens_to_ids(w)
            if tid != tokenizer.unk_token_id:
                out.append([tid])
        else:
            toks = tokenizer.encode(w, add_special_tokens=False)
            if toks:
                out.append(toks)
    return out
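
# Sketch of the expected output (the ids shown are placeholders; real values are
# tokenizer-specific):
#   make_bad_words_ids(tokenizer, ["Teacher:"]) -> [[1234, 5678]]
# Each inner list is a token-id sequence that generation is forbidden to produce.
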
bad_words_ids = make_bad_words_ids(
    tokenizer,
    ["<teacher>", "</teacher>", "<system>", "</system>", "Teacher:", "teacher:"]
)

# Stop generation at the closing student tag; fall back to the regular EOS token
# if "</student>" is not a single token in this tokenizer's vocabulary.
eos_id = tokenizer.convert_tokens_to_ids("</student>")
if eos_id is None or eos_id == tokenizer.unk_token_id:
    eos_id = tokenizer.eos_token_id

HF_DATASET_REPO = "hparten/math_chatbot_logs"

api = HfApi()
session_logs = {}    # session_id -> list of per-turn log rows
last_activity = {}   # session_id -> unix timestamp of the last turn

def add_turn_to_memory(session_id, username, strategy, teacher_msg, student_msg):
    """Store one turn in memory."""
    row = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "session_id": session_id,
        "username": username,
        "strategy": strategy,
        "teacher": teacher_msg,
        "student": student_msg,
    }
    session_logs.setdefault(session_id, []).append(row)
    update_activity(session_id)


def update_activity(session_id):
    last_activity[session_id] = time.time()

def flush_session_to_hub(session_id):
    """Upload session logs to the Hugging Face dataset as a single Parquet file."""
    if session_id not in session_logs or not session_logs[session_id]:
        print(f"[flush] No logs found for session {session_id}")
        return

    df = pd.DataFrame(session_logs[session_id])

    # On the first flush the dataset may not exist yet; start fresh in that case.
    try:
        ds = load_dataset(HF_DATASET_REPO, split="train", token=HF_TOKEN)
        existing = ds.to_pandas()
        combined = pd.concat([existing, df], ignore_index=True)
    except Exception:
        combined = df

    with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".parquet") as tmp:
        combined.to_parquet(tmp.name, index=False)
        tmp_path = tmp.name

    api.upload_file(
        path_or_fileobj=tmp_path,
        path_in_repo="chat_logs.parquet",
        repo_id=HF_DATASET_REPO,
        repo_type="dataset",
        token=HF_TOKEN,
    )
    os.remove(tmp_path)
    # Drop the session from memory only after a successful upload, so a failed
    # upload does not lose the turns.
    del session_logs[session_id]
    print(f"[flush] Uploaded session {session_id} to HF dataset.")
def generate_response(teacher_question, username, history, session_id, strategy):
    prompt = build_prompt(strategy, history, teacher_question, tokenizer)

    out = pipe(
        prompt,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.4,
        top_p=0.9,
        repetition_penalty=1.05,
        no_repeat_ngram_size=6,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=eos_id,
        bad_words_ids=bad_words_ids,
        return_full_text=False,
    )

    out_text = out[0]["generated_text"]
    if "<student>" in out_text and "</student>" in out_text:
        student_reply = out_text.split("<student>", 1)[1].split("</student>", 1)[0].strip()
    else:
        student_reply = out_text.strip()

    # Keep only the first sentence, since the system prompt asks for exactly one.
    first_sentence = student_reply.split(".")[0].strip()
    if first_sentence:
        student_reply = first_sentence + "."

    history.append((teacher_question, student_reply))
    add_turn_to_memory(session_id, username, strategy, teacher_question, student_reply)

    return student_reply, history
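
# Illustrative call (assumes the model above loaded successfully; the session id
# and name here are placeholders):
#   reply, history = generate_response("How did you start?", "test user", [], "abc123", "friendly")
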
INACTIVITY_LIMIT = 600  # seconds of silence before a session is auto-flushed


def check_inactivity_loop():
    while True:
        now = time.time()
        inactive = [sid for sid, ts in last_activity.items() if now - ts > INACTIVITY_LIMIT]
        for sid in inactive:
            try:
                flush_session_to_hub(sid)
                del last_activity[sid]
            except Exception as e:
                print(f"[auto-flush-error] {sid}: {e}")
        time.sleep(60)


threading.Thread(target=check_inactivity_loop, daemon=True).start()

def history_to_messages(history):
    """Convert (teacher, student) tuples into the messages format the Chatbot uses."""
    msgs = []
    for t, s in history[-MAX_HISTORY_TURNS:]:
        msgs.append({"role": "user", "content": t})
        msgs.append({"role": "assistant", "content": s})
    return msgs


def on_send(teacher_question, username, strategy_choice, history, session_id):
    if not session_id:
        session_id = uuid.uuid4().hex[:12]
    if history is None:
        history = []
    if not username.strip():
        gr.Warning("Please enter your name before starting the chat.")
        return history_to_messages(history), history, "", session_id
    if not strategy_choice:
        gr.Warning("Please choose a student strategy before sending.")
        return history_to_messages(history), history, "", session_id
    if not teacher_question.strip():
        gr.Warning("Please type a question for the student before sending.")
        return history_to_messages(history), history, "", session_id

    student_reply, history = generate_response(
        teacher_question.strip(),
        username.strip(),
        history,
        session_id,
        strategy_choice.lower(),
    )

    return history_to_messages(history), history, "", session_id

def on_reset(chat, history, teacher_q, session_id):
    """Flush the current session before resetting."""
    if session_id:
        try:
            flush_session_to_hub(session_id)
            print(f"[manual flush] Uploaded session {session_id} to HF dataset.")
        except Exception as e:
            print(f"[manual flush error] {session_id}: {e}")

    return [], [], "", uuid.uuid4().hex[:12]

with gr.Blocks(title="Elementary Math Student Chatbot") as demo:
    gr.Markdown("## 🧮 Practice Eliciting Student Thinking (Prototype)")
    gr.Markdown(
        "You are an elementary math teacher exploring a student's reasoning for **41 + ___ = 84**."
        "\nAsk questions and see how the student explains their thinking."
    )

    with gr.Row():
        username = gr.Textbox(label="👤 Your Name (first last)", placeholder="Enter your name...")
        # No preselected strategy: the previous default ("Choose one") was not a
        # valid choice, so it would have been sent to the model verbatim.
        strategy_choice = gr.Dropdown(
            ["friendly", "differencing", "subtraction"],
            value=None,
            label="🧩 Student Strategy"
        )
        reset_btn = gr.Button("🔄 Start Over", variant="secondary")

    teacher_q = gr.Textbox(label="👩‍🏫 Teacher Question", placeholder="Ask the student a question…")
    chat = gr.Chatbot(label="💬 Chat", type="messages")
    state_history = gr.State([])
    state_session = gr.State("")

    send = gr.Button("Send", variant="primary")

    send.click(
        on_send,
        inputs=[teacher_q, username, strategy_choice, state_history, state_session],
        outputs=[chat, state_history, teacher_q, state_session],
    )

    reset_btn.click(
        on_reset,
        inputs=[chat, state_history, teacher_q, state_session],
        outputs=[chat, state_history, teacher_q, state_session],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)
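
# To run locally (assumes torch, transformers, datasets, pandas, gradio, and
# huggingface_hub are installed, and HF_TOKEN is set if log uploads are wanted):
#   python app.py   # or whatever this file is saved as
# Port 7860 matches the Hugging Face Spaces default.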