import os
import csv
import tempfile
import time
import uuid
from collections import deque
from datetime import datetime
from typing import List, Tuple

import torch
import gradio as gr
import pandas as pd
from datasets import load_dataset
from filelock import FileLock
from huggingface_hub import HfApi
from peft import PeftModel
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    StoppingCriteria,
    StoppingCriteriaList,
)

# =========================
# ⚙️ Config
# =========================
MAX_HISTORY_TURNS = 10      # most recent (teacher, student) turns kept in the prompt / chat view
MAX_PROMPT_TOKENS = 1024    # prompt budget; oldest turns are dropped until the prompt fits
MAX_NEW_TOKENS = 60
LOG_DIR = "logs"
os.makedirs(LOG_DIR, exist_ok=True)
LOCK_PATH = os.path.join(LOG_DIR, ".lock")
HF_TOKEN = os.environ.get("HF_TOKEN")
SPACE_ID = os.environ.get("SPACE_ID")
MODEL_ID = "hparten/prob1_qlora_math_student"

# =========================
# 🔠 Model + Tokenizer
# =========================
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    device_map="auto",
)
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
# Reuse EOS as the padding token so generation always has a pad id available.
tokenizer.pad_token = tokenizer.eos_token
pipe = pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    dtype=torch.float16,
    device_map="auto",
)

# =========================
# 🧩 Strategy Explanations
# =========================
# Maps a lower-cased strategy name to the description injected into the system prompt.
strategy_explanations = {
    "friendly": "You add on from 41 until you get to 84, usually by counting by 10s, 20s, or 40, then ones.",
    "differencing": "You difference the ones or tens separately during any part of your answer.",
    "subtraction": "You turn the problem into a subtraction: 84 minus 41 equals blank to find the missing addend.",
}


# =========================
# 🧠 Build System Prompt
# =========================
def build_system_block(problem_prefix, strategy):
    """Return the system-prompt text for one problem/strategy pair.

    Args:
        problem_prefix: Problem identifier; it is lower-cased and wrapped in
            angle brackets to form the problem tag.
        strategy: Strategy name; it is lower-cased and looked up in
            ``strategy_explanations``. Unknown names get a generic instruction.

    Returns:
        The assembled prompt text with surrounding whitespace stripped.
    """
    problem_text = "41 plus blank equals 84"
    strat_key = strategy.lower()
    strat_expl = strategy_explanations.get(
        strat_key, "Use the named strategy to explain your steps clearly."
    )
    # NOTE(review): the strategy tag and the first/last prompt lines are empty
    # in this copy of the file, and the "inside ... ." sentence looks like it
    # lost two markup tags. Presumably angle-bracket markup was stripped —
    # TODO confirm against the prompt format the model was trained on.
    strategy_tag = ""
    problem_tag = f"<{problem_prefix.lower()}>"
    # NOTE(review): the wording below (including the missing newline after
    # "inside ... ." and the spelling "implicity") is left untouched because
    # the fine-tuned model may depend on the exact prompt text.
    system_text = (
        "\n"
        "You are the student in a math dialogue.\n"
        f"Solving the PROBLEM: {problem_tag} - {problem_text}\n"
        f"Using the STRATEGY: {strategy_tag} — {strat_expl}\n"
        "Return EXACTLY one sentence inside ... ."
        "Do NOT ask questions or include teacher text.\n"
        "Mention the strategy implicity only if natural.\n"
        "\n"
    )
    return system_text.strip()


# =========================
# 🧾 Logging (legacy per-session CSV logger, disabled)
# =========================
# CSV_HEADERS = ["timestamp", "session_id", "username", "strategy", "teacher", "student"]
#
# def _append_csv(path, row):
#     with FileLock(LOCK_PATH):
#         file_exists = os.path.exists(path)
#         with open(path, "a", newline="", encoding="utf-8") as f:
#             w = csv.writer(f)
#             if not file_exists:
#                 w.writerow(CSV_HEADERS)
#             w.writerow(row)
#
# def log_turn(session_id, username, strategy, teacher_msg, student_msg):
#     row = [datetime.now().isoformat(timespec="seconds"), session_id, username, strategy,
#            teacher_msg, student_msg]
#     per_session = os.path.join(LOG_DIR, f"chat_{session_id}.csv")
#     _append_csv(per_session, row)


# =========================
# 🧩 Prompt builder
# =========================
def build_prompt(strategy, history, teacher_question, tokenizer, problem_prefix="Problem_1"):
    """Assemble the generation prompt: system block, recent turns, new question.

    Only the last ``MAX_HISTORY_TURNS`` turns are considered. While the prompt
    encodes to more than ``MAX_PROMPT_TOKENS`` tokens, the oldest remaining
    turn is dropped and the prompt is rebuilt. The system block and the new
    question are never dropped, so the result can still exceed the budget once
    no turns are left.

    Args:
        strategy: Strategy name forwarded to ``build_system_block``.
        history: Sequence of ``(teacher_question, student_answer)`` pairs.
        teacher_question: The question to be answered now.
        tokenizer: Object exposing ``encode(text, add_special_tokens=False)``.
        problem_prefix: Problem identifier forwarded to ``build_system_block``.

    Returns:
        The prompt string with surrounding whitespace stripped.
    """
    base_system_prompt = build_system_block(problem_prefix, strategy)
    # NOTE(review): the turn and question templates contain only spaces around
    # the values; role markup appears to be missing here as well — TODO confirm.
    turns = deque(f" {tq} {sa} " for tq, sa in history[-MAX_HISTORY_TURNS:])

    def assemble():
        # Single source of truth for the layout, so a truncated prompt has the
        # same shape as an untruncated one. The loop used to rebuild the prompt
        # without the newline that separates the system block from the turns.
        return base_system_prompt + "\n" + " ".join(turns) + f" {teacher_question} \n"

    full_prompt = assemble()
    while turns and len(tokenizer.encode(full_prompt, add_special_tokens=False)) > MAX_PROMPT_TOKENS:
        turns.popleft()  # drop the oldest turn first
        full_prompt = assemble()
    return full_prompt.strip()


# =========================================================
# ❌ Banned Tokens (prevents teacher drift)
# =========================================================
def make_bad_words_ids(tokenizer, words: List[str]) -> List[List[int]]:
    """Build a ``bad_words_ids`` list from special tokens and ordinary strings.

    A word listed in ``tokenizer.all_special_tokens`` becomes a single-id
    entry, unless it resolves to the unknown-token id. Any other word is
    encoded without special tokens and contributes its full id sequence.
    Words that encode to nothing are skipped.
    """
    out = []
    for word in words:
        if word in tokenizer.all_special_tokens:
            token_id = tokenizer.convert_tokens_to_ids(word)
            if token_id != tokenizer.unk_token_id:
                out.append([token_id])
        else:
            token_ids = tokenizer.encode(word, add_special_tokens=False)
            if token_ids:
                out.append(token_ids)
    return out


# NOTE(review): the first four entries and the EOS token below are empty
# strings in this copy of the file; they presumably named markup tags that
# were stripped — TODO restore the intended tokens.
bad_words_ids = make_bad_words_ids(
    tokenizer,
    ["", "", "", "", "Teacher:", "teacher:"],
)
eos_id = tokenizer.convert_tokens_to_ids("")

# =========================
# ☁️ In-Memory Logging + HF Upload
# =========================
HF_DATASET_REPO = "hparten/math_chatbot_logs"
api = HfApi()
session_logs = {}   # session_id -> list of turns
last_activity = {}  # session_id -> timestamp


def add_turn_to_memory(session_id, username, strategy, teacher_msg, student_msg):
    """Append one teacher/student exchange to the session's in-memory log.

    Also refreshes the session's last-activity timestamp.
    """
    row = {
        "timestamp": datetime.now().isoformat(timespec="seconds"),
        "session_id": session_id,
        "username": username,
        "strategy": strategy,
        "teacher": teacher_msg,
        "student": student_msg,
    }
    session_logs.setdefault(session_id, []).append(row)
    update_activity(session_id)


def update_activity(session_id):
    """Record the current time as the session's last activity."""
    last_activity[session_id] = time.time()


def flush_session_to_hub(session_id):
    """Upload a session's buffered turns to the Hugging Face dataset repo.

    The rows are appended to whatever ``load_dataset`` returns for the repo's
    ``train`` split and the combined table is uploaded as
    ``chat_logs.parquet``. If the existing data cannot be loaded, only this
    session's rows are uploaded.

    The rows are removed from ``session_logs`` up front. If anything after
    that fails, they are put back so a later retry can still upload them, and
    the exception is re-raised. The temporary Parquet file is always removed.

    Raises:
        Exception: whatever the Parquet write or the upload raises.
    """
    rows = session_logs.pop(session_id, None)
    if not rows:
        print(f"[flush] No logs found for session {session_id}")
        return

    df = pd.DataFrame(rows)
    tmp_path = None
    try:
        try:
            ds = load_dataset(HF_DATASET_REPO, split="train", token=HF_TOKEN)
            combined = pd.concat([ds.to_pandas(), df], ignore_index=True)
        except Exception as e:
            # Best effort, as before: fall back to uploading this session only.
            # NOTE(review): after a transient load failure this replaces the
            # remote file with just this session's rows — confirm that is
            # acceptable.
            print(f"[flush] Could not load existing dataset ({e}); uploading this session only.")
            combined = df

        # Reserve a path, close the handle, then write: writing through a
        # second handle while the first is still open fails on some platforms.
        with tempfile.NamedTemporaryFile("wb", delete=False, suffix=".parquet") as tmp:
            tmp_path = tmp.name
        combined.to_parquet(tmp_path, index=False)

        api.upload_file(
            path_or_fileobj=tmp_path,
            path_in_repo="chat_logs.parquet",
            repo_id=HF_DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN,
        )
    except Exception:
        # Put the rows back, ahead of any turns logged in the meantime.
        pending = session_logs.setdefault(session_id, [])
        pending[:0] = rows
        raise
    finally:
        if tmp_path is not None and os.path.exists(tmp_path):
            os.remove(tmp_path)

    print(f"[flush] Uploaded session {session_id} to HF dataset.")


# =========================
# 🤖 Generation
# =========================
import re

# NOTE(review): both markers are empty strings in this copy of the file; they
# presumably held the opening/closing markup that wraps the student's sentence
# and were stripped — TODO restore the real tags. While they are empty, tag
# extraction is skipped and the whole generated text is used.
STUDENT_OPEN_TAG = ""
STUDENT_CLOSE_TAG = ""

# A sentence ends at '.', '!' or '?' followed by whitespace or end of text,
# so a decimal such as "4.5" is not treated as a sentence boundary.
_SENTENCE_END = re.compile(r"[.!?](?=\s|$)")


def _extract_student_reply(out_text):
    """Return the text between the student tags, or the whole text stripped."""
    tags_usable = bool(STUDENT_OPEN_TAG) and bool(STUDENT_CLOSE_TAG)
    if tags_usable and STUDENT_OPEN_TAG in out_text and STUDENT_CLOSE_TAG in out_text:
        after_open = out_text.split(STUDENT_OPEN_TAG, 1)[1]
        return after_open.split(STUDENT_CLOSE_TAG, 1)[0].strip()
    return out_text.strip()


def _first_sentence(text):
    """Cut text after its first sentence, adding a period if none ends it."""
    text = text.strip()
    match = _SENTENCE_END.search(text)
    if match:
        return text[: match.end()]
    return text + "."


def generate_response(teacher_question, username, history, session_id, strategy):
    """Generate the student's one-sentence reply to a teacher question.

    Builds the prompt from the strategy and prior turns, samples a completion
    from the text-generation pipeline, reduces it to a single sentence,
    appends the exchange to ``history`` in place and records it in the
    in-memory session log.

    Args:
        teacher_question: The teacher's question for this turn.
        username: Name stored with the logged turn.
        history: List of ``(teacher_question, student_reply)`` pairs; mutated.
        session_id: Key under which the turn is logged.
        strategy: Strategy name used to build the prompt and stored in the log.

    Returns:
        A ``(student_reply, history)`` tuple.
    """
    prompt = build_prompt(strategy, history, teacher_question, tokenizer)
    out = pipe(
        prompt,
        max_new_tokens=MAX_NEW_TOKENS,
        do_sample=True,
        temperature=0.4,
        top_p=0.9,
        repetition_penalty=1.05,
        no_repeat_ngram_size=6,
        pad_token_id=tokenizer.pad_token_id or tokenizer.eos_token_id,
        eos_token_id=eos_id,
        bad_words_ids=bad_words_ids,
        return_full_text=False,
    )
    out_text = out[0]["generated_text"]

    # Splitting on an empty tag raises ValueError, so extraction is guarded.
    student_reply = _extract_student_reply(out_text)

    # Force single-sentence cleanup
    student_reply = _first_sentence(student_reply)

    history.append((teacher_question, student_reply))
    add_turn_to_memory(session_id, username, strategy, teacher_question, student_reply)
    return student_reply, history


# =========================
# Inactivity flush
# =========================
import threading
import time
import uuid

INACTIVITY_LIMIT = 600  # 10 minutes


def check_inactivity_loop():
    """Once a minute, flush sessions idle for longer than INACTIVITY_LIMIT.

    A session's activity entry is removed only after its flush returns without
    raising. A failed flush is reported and attempted again on the next pass.
    Runs forever; intended for a daemon thread.
    """
    while True:
        now = time.time()
        # Snapshot the items: request handlers add entries from other threads,
        # and iterating the live dict could raise "dictionary changed size
        # during iteration" and end this thread.
        activity_snapshot = list(last_activity.items())
        inactive = [sid for sid, ts in activity_snapshot if now - ts > INACTIVITY_LIMIT]
        for sid in inactive:
            try:
                flush_session_to_hub(sid)
                last_activity.pop(sid, None)
            except Exception as e:
                print(f"[auto-flush-error] {sid}: {e}")
        time.sleep(60)


threading.Thread(target=check_inactivity_loop, daemon=True).start()


# =========================
# 🖥 Gradio UI
# =========================
def _history_to_messages(history):
    """Convert (teacher, student) pairs to the Chatbot's role/content dicts."""
    msgs = []
    for teacher_text, student_text in history[-MAX_HISTORY_TURNS:]:
        msgs.append({"role": "user", "content": teacher_text})
        msgs.append({"role": "assistant", "content": student_text})
    return msgs


def on_send(teacher_question, username, strategy_choice, history, session_id):
    """Handle the Send button: validate inputs, generate a reply, refresh the UI.

    Args:
        teacher_question: Text of the teacher-question box.
        username: Text of the name box.
        strategy_choice: Selected dropdown value, or None when nothing is selected.
        history: List of ``(teacher, student)`` pairs held in state, or None.
        session_id: Current session id; a new 12-hex-character id is created
            when it is empty.

    Returns:
        ``(chat_messages, history, "", session_id)`` — the empty string clears
        the teacher-question box.
    """
    if not session_id:
        session_id = uuid.uuid4().hex[:12]
    if history is None:
        history = []

    # Treat a missing value like an empty one instead of failing on .strip().
    username = (username or "").strip()
    teacher_question = (teacher_question or "").strip()

    # The chat component is created with type="messages", so every return path
    # must hand it role/content dicts, never the raw tuple history.
    if not username:
        gr.Warning("Please enter your name before starting the chat.")
        return _history_to_messages(history), history, "", session_id
    if not strategy_choice:
        gr.Warning("Please choose a student strategy before sending.")
        return _history_to_messages(history), history, "", session_id
    if not teacher_question:
        gr.Warning("Please type a question for the student before sending.")
        return _history_to_messages(history), history, "", session_id

    student_reply, history = generate_response(
        teacher_question, username, history, session_id, strategy_choice.lower()
    )
    return _history_to_messages(history), history, "", session_id


def on_reset(chat, history, teacher_q, session_id):
    """Flush the current session before resetting.

    A failed flush is reported but does not block the reset.

    Returns:
        Empty chat, empty history, empty question box and a fresh session id.
    """
    if session_id:
        try:
            flush_session_to_hub(session_id)
            # Nothing left to auto-flush for this session.
            last_activity.pop(session_id, None)
            print(f"[manual flush] Uploaded session {session_id} to HF dataset.")
        except Exception as e:
            print(f"[manual flush error] {session_id}: {e}")
    return [], [], "", uuid.uuid4().hex[:12]


# =========================
# 🚀 Gradio App
# =========================
# NOTE(review): the original indentation was lost, so which components sit
# inside the Row is a best guess — TODO confirm the intended layout.
with gr.Blocks(title="Elementary Math Student Chatbot") as demo:
    gr.Markdown("## 🧮 Practice Eliciting Student Thinking (Prototype)")
    gr.Markdown(
        "You are an elementary math teacher exploring a student's reasoning for **41 + ___ = 84**."
        "\nAsk questions and see how the student explains their thinking."
    )
    with gr.Row():
        username = gr.Textbox(label="👤 Your Name (first last)", placeholder="Enter your name...")
        # The old default "Choose one" was not one of the choices. Start with
        # nothing selected and show the hint as helper text instead.
        strategy_choice = gr.Dropdown(
            ["friendly", "differencing", "subtraction"],
            value=None,
            label="🧩 Student Strategy",
            info="Choose one",
        )
        reset_btn = gr.Button("🔄 Start Over", variant="secondary")
    teacher_q = gr.Textbox(label="👩‍🏫 Teacher Question", placeholder="Ask the student a question…")
    chat = gr.Chatbot(label="💬 Chat", type="messages")
    state_history = gr.State([])
    state_session = gr.State("")
    send = gr.Button("Send", variant="primary")

    send.click(
        on_send,
        inputs=[teacher_q, username, strategy_choice, state_history, state_session],
        outputs=[chat, state_history, teacher_q, state_session],
    )
    reset_btn.click(
        on_reset,
        inputs=[chat, state_history, teacher_q, state_session],
        outputs=[chat, state_history, teacher_q, state_session],
    )

if __name__ == "__main__":
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860, ssr_mode=False)