"""BrainChat — a Gradio RAG tutor for neurology/neuroanatomy.

Pipeline: BM25 shortlist -> dense re-rank (SentenceTransformer) -> GPT-4o-mini
prompting in several tutor modes, plus a two-turn quiz (generate, then evaluate).

NOTE(review): the HTML fragments in ``md_to_html``, ``render_chat`` and
``render_header`` were reconstructed from the CSS class names defined in
``CSS`` — the original markup was lost in extraction. Verify visually.
"""

import os
import re
import json
import pickle
from urllib.parse import quote

import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI

# =====================================================
# PATHS
# =====================================================
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")
LOGO_FILE = "Brain chat-09.png"

# =====================================================
# GLOBALS (lazily populated by ensure_loaded)
# =====================================================
EMBED_MODEL = None
BM25 = None
CHUNKS = None
EMBEDDINGS = None
CLIENT = None


# =====================================================
# LOADERS
# =====================================================
def tokenize(text: str):
    """Lowercase word tokenizer used for both indexing and querying BM25."""
    return re.findall(r"\w+", text.lower(), flags=re.UNICODE)


def ensure_loaded():
    """Lazily load the retrieval index, embeddings and the OpenAI client.

    Raises:
        FileNotFoundError: if any prebuilt artifact is missing.
        ValueError: if OPENAI_API_KEY is not set.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, CLIENT

    if CHUNKS is None:
        for path in [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH]:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")

        # NOTE(review): pickle.load on build artifacts — only safe because
        # these files are produced by our own build step, never user-supplied.
        with open(CHUNKS_PATH, "rb") as f:
            CHUNKS = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        EMBEDDINGS = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        BM25 = BM25Okapi(tokenized_chunks)
        # Must be the same model that produced EMBED_PATH, hence read from config.
        EMBED_MODEL = SentenceTransformer(cfg["embedding_model"])

    if CLIENT is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        CLIENT = OpenAI(api_key=api_key)


# =====================================================
# RETRIEVAL
# =====================================================
def search_hybrid(query: str, shortlist_k: int = 20, final_k: int = 3):
    """Hybrid retrieval: BM25 shortlist of ``shortlist_k`` chunks, then dense
    cosine re-ranking down to ``final_k`` chunk records."""
    ensure_loaded()

    query_tokens = tokenize(query)
    bm25_scores = BM25.get_scores(query_tokens)
    shortlist_idx = np.argsort(bm25_scores)[::-1][:shortlist_k]

    shortlist_embeddings = EMBEDDINGS[shortlist_idx]
    # normalize_embeddings=True -> dot product == cosine similarity
    # (assumes EMBEDDINGS were stored normalized too — TODO confirm in build step).
    qvec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    dense_scores = shortlist_embeddings @ qvec

    rerank_order = np.argsort(dense_scores)[::-1][:final_k]
    final_idx = shortlist_idx[rerank_order]
    return [CHUNKS[int(i)] for i in final_idx]


def build_context(records):
    """Format retrieved chunk records into numbered [Source i] blocks for the prompt."""
    blocks = []
    for i, r in enumerate(records, start=1):
        blocks.append(
            f"""[Source {i}]
Book: {r['book']}
Section: {r['section_title']}
Pages: {r['page_start']}-{r['page_end']}
Text: {r['text']}"""
        )
    return "\n\n".join(blocks)


def make_sources(records):
    """Render a deduplicated bullet list of (book, section, pages) citations."""
    seen = set()
    lines = []
    for r in records:
        key = (r["book"], r["section_title"], r["page_start"], r["page_end"])
        if key in seen:
            continue
        seen.add(key)
        lines.append(
            f"• {r['book']} | {r['section_title']} | pp. {r['page_start']}-{r['page_end']}"
        )
    return "\n".join(lines)


# =====================================================
# PROMPTS
# =====================================================
def language_instruction(language_mode: str) -> str:
    """Map the UI language selector to an instruction sentence for the model."""
    if language_mode == "English":
        return "Answer only in English."
    if language_mode == "Spanish":
        return "Answer only in Spanish."
    if language_mode == "Bilingual":
        return "Answer first in English, then provide a Spanish version under the heading 'Español:'."
    # "Auto" (and anything unrecognized): mirror the user's language.
    return (
        "If the user's message is in Spanish, answer in Spanish. "
        "If the user's message is in English, answer in English."
    )


def choose_quiz_count(user_text: str, selector: str) -> int:
    """Pick the quiz length: explicit selector wins; otherwise infer from keywords."""
    if selector in {"3", "5", "7"}:
        return int(selector)
    t = user_text.lower()
    if any(k in t for k in ["mock test", "final exam", "exam practice", "full test"]):
        return 7
    if any(k in t for k in ["detailed", "revision", "comprehensive", "study"]):
        return 5
    return 3


def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Build the grounded tutoring prompt for all non-quiz modes."""
    mode_map = {
        "Explain": (
            "Explain clearly like a friendly tutor using simple language. "
            "Use short headings if useful."
        ),
        "Detailed": (
            "Give a fuller and more detailed explanation. Include concept, "
            "key points, and clinical relevance when supported by context."
        ),
        "Short Notes": (
            "Answer in concise revision-note format using short bullet points."
        ),
        "Flashcards": (
            "Create 6 flashcards in Q/A format using only the provided context."
        ),
        "Case-Based": (
            "Create a short clinical scenario and explain it clearly using the provided context."
        ),
    }
    # Fall back to "Explain" rather than raising KeyError on an unknown mode.
    style = mode_map.get(mode, mode_map["Explain"])
    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.

Rules:
- Use only the provided context from the books.
- If the answer is not supported by the context, say exactly: Not found in the course material.
- Be accurate and student-friendly.
- Do not invent facts outside the context.
- {language_instruction(language_mode)}

Teaching style:
{style}

Context:
{context}

Question:
{question}
""".strip()


def build_quiz_generation_prompt(language_mode: str, topic: str, context: str,
                                 n_questions: int) -> str:
    """Build the JSON-only prompt that generates a quiz with an answer key."""
    return f"""
You are BrainChat, an interactive tutor.

Rules:
- Use only the provided context.
- Create exactly {n_questions} quiz questions.
- Questions should be short and clear.
- Also create a short answer key.
- Return valid JSON only.
- {language_instruction(language_mode)}

Required JSON format:
{{
  "title": "short quiz title",
  "questions": [
    {{"q": "question 1", "answer_key": "expected short answer"}},
    {{"q": "question 2", "answer_key": "expected short answer"}}
  ]
}}

Context:
{context}

Topic:
{topic}
""".strip()


def build_quiz_evaluation_prompt(language_mode: str, quiz_data: dict,
                                 user_answers: str) -> str:
    """Build the JSON-only prompt that grades the student's answers."""
    quiz_json = json.dumps(quiz_data, ensure_ascii=False)
    return f"""
You are BrainChat, an interactive tutor.

Evaluate the student's answers fairly using the quiz answer key.

Give:
- total score
- per-question feedback
- one short improvement suggestion

Rules:
- Accept semantically correct answers even if wording differs.
- Return valid JSON only.
- {language_instruction(language_mode)}

Required JSON format:
{{
  "score_obtained": 0,
  "score_total": 0,
  "summary": "short overall feedback",
  "results": [
    {{
      "question": "question text",
      "student_answer": "student answer",
      "result": "Correct / Partially Correct / Incorrect",
      "feedback": "short explanation"
    }}
  ]
}}

Quiz data:
{quiz_json}

Student answers:
{user_answers}
""".strip()


# =====================================================
# OPENAI HELPERS
# =====================================================
def chat_text(prompt: str) -> str:
    """Send one prompt to gpt-4o-mini and return the plain-text answer."""
    resp = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    return resp.choices[0].message.content.strip()


def chat_json(prompt: str) -> dict:
    """Send one prompt in JSON mode and parse the response as a dict."""
    resp = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    return json.loads(resp.choices[0].message.content)


# =====================================================
# HTML RENDERING
# =====================================================
def md_to_html(text: str) -> str:
    """Escape raw HTML, then convert **bold** and newlines to HTML.

    BUGFIX: the escaping replaces were no-ops (e.g. replace("&", "&") — the
    entities had been lost); restored to real &amp;/&lt;/&gt; escaping so user
    text cannot inject markup into the chat panel.
    """
    safe = (
        text.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    safe = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", safe)
    safe = safe.replace("\n", "<br>")
    return safe


def render_chat(history):
    """Render the chat history as the HTML bubble layout styled by CSS.

    NOTE(review): div markup reconstructed from the CSS classes
    (.chat-wrap/.msg-row/.msg-bubble/.empty-chat) — original tags were lost.
    """
    if not history:
        return """
        <div class="empty-chat">
          <div class="empty-chat-text">Ask a question, choose a tutor mode, or start a quiz.</div>
        </div>
        """

    rows = []
    for msg in history:
        role = msg["role"]
        content = md_to_html(msg["content"])
        if role == "user":
            rows.append(
                f'<div class="msg-row user-row">'
                f'<div class="msg-bubble user-bubble">{content}</div></div>'
            )
        else:
            rows.append(
                f'<div class="msg-row bot-row">'
                f'<div class="msg-bubble bot-bubble">{content}</div></div>'
            )
    return f'<div class="chat-wrap">{"".join(rows)}</div>'


def detect_logo_url():
    """Return the Gradio static-file URL for the logo, or None if absent."""
    if os.path.exists(LOGO_FILE):
        # Requires LOGO_FILE to be in launch(allowed_paths=...) to be served.
        return f"/gradio_api/file={quote(LOGO_FILE)}"
    return None


def render_header():
    """Render the hero banner; falls back to text if the logo file is missing.

    NOTE(review): markup reconstructed from the .hero-* CSS classes.
    """
    logo_url = detect_logo_url()
    if logo_url:
        logo_html = f'<img src="{logo_url}" alt="BrainChat Logo" style="height:96px;">'
    else:
        logo_html = '<div class="hero-title">BRAIN<br>CHAT</div>'
    return f"""
    <div class="hero-card">
      <div class="hero-inner">
        {logo_html}
        <div class="hero-title">BrainChat</div>
        <div class="hero-subtitle">Interactive neurology and neuroanatomy tutor based on your uploaded books</div>
      </div>
    </div>
    """


# =====================================================
# MAIN LOGIC
# =====================================================
def answer_question(message, history, mode, language_mode, quiz_count_mode,
                    show_sources, quiz_state):
    """Main chat handler.

    Routes each message to one of three paths: (1) evaluate a pending quiz,
    (2) generate a new quiz ("Quiz Me" mode), or (3) answer via a tutor prompt.

    Returns:
        (history, chat_html, quiz_state, textbox_value) — textbox cleared on return.
    """
    if history is None:
        history = []
    if quiz_state is None:
        quiz_state = {
            "active": False,
            "topic": None,
            "quiz_data": None,
            "language_mode": "Auto",
        }
    if not message or not message.strip():
        return history, render_chat(history), quiz_state, ""

    try:
        ensure_loaded()
        user_text = message.strip()
        history = history + [{"role": "user", "content": user_text}]

        # -------------------------------
        # QUIZ EVALUATION — the previous turn produced a quiz, so this
        # message is treated as the student's answer sheet.
        # -------------------------------
        if quiz_state.get("active", False):
            evaluation_prompt = build_quiz_evaluation_prompt(
                quiz_state["language_mode"],
                quiz_state["quiz_data"],
                user_text,
            )
            evaluation = chat_json(evaluation_prompt)

            lines = [
                f"**Score:** {evaluation['score_obtained']}/{evaluation['score_total']}",
                "",
                f"**Overall feedback:** {evaluation['summary']}",
                "",
                "**Question-wise evaluation:**",
            ]
            for item in evaluation["results"]:
                lines.append("")
                lines.append(f"**Q:** {item['question']}")
                lines.append(f"**Your answer:** {item['student_answer']}")
                lines.append(f"**Result:** {item['result']}")
                lines.append(f"**Feedback:** {item['feedback']}")

            final_answer = "\n".join(lines)
            history = history + [{"role": "assistant", "content": final_answer}]
            # Quiz is finished — deactivate it.
            quiz_state = {
                "active": False,
                "topic": None,
                "quiz_data": None,
                "language_mode": language_mode,
            }
            return history, render_chat(history), quiz_state, ""

        # -------------------------------
        # NORMAL RETRIEVAL
        # -------------------------------
        records = search_hybrid(user_text, shortlist_k=20, final_k=3)
        context = build_context(records)

        # -------------------------------
        # QUIZ GENERATION
        # -------------------------------
        if mode == "Quiz Me":
            n_questions = choose_quiz_count(user_text, quiz_count_mode)
            prompt = build_quiz_generation_prompt(
                language_mode, user_text, context, n_questions
            )
            quiz_data = chat_json(prompt)

            lines = [
                f"**{quiz_data.get('title', 'Quiz')}**",
                "",
                "Please answer the following questions in one message.",
                "You can reply in numbered format, for example:",
                "1. ...",
                "2. ...",
                "",
                f"**Total questions: {len(quiz_data['questions'])}**",
                "",
            ]
            for i, q in enumerate(quiz_data["questions"], start=1):
                lines.append(f"**Q{i}.** {q['q']}")
            if show_sources:
                lines.append("\n---\n**Topic sources used to create the quiz:**")
                lines.append(make_sources(records))

            assistant_text = "\n".join(lines)
            history = history + [{"role": "assistant", "content": assistant_text}]
            # Arm the quiz: the student's NEXT message will be evaluated.
            quiz_state = {
                "active": True,
                "topic": user_text,
                "quiz_data": quiz_data,
                "language_mode": language_mode,
            }
            return history, render_chat(history), quiz_state, ""

        # -------------------------------
        # OTHER MODES
        # -------------------------------
        prompt = build_tutor_prompt(mode, language_mode, user_text, context)
        answer = chat_text(prompt)
        if show_sources:
            answer += "\n\n---\n**Sources used:**\n" + make_sources(records)
        history = history + [{"role": "assistant", "content": answer}]
        return history, render_chat(history), quiz_state, ""

    except Exception as e:
        # Surface the failure in-chat and drop any pending quiz so the user
        # isn't stuck in evaluation mode.
        history = history + [{"role": "assistant", "content": f"Error: {str(e)}"}]
        quiz_state["active"] = False
        return history, render_chat(history), quiz_state, ""


def clear_all():
    """Reset chat history, quiz state, the chat panel, and the input box."""
    empty_history = []
    empty_quiz = {
        "active": False,
        "topic": None,
        "quiz_data": None,
        "language_mode": "Auto",
    }
    return empty_history, render_chat(empty_history), empty_quiz, ""


# =====================================================
# CSS
# =====================================================
CSS = """
body, .gradio-container {
    background: #dcdcdc !important;
    font-family: Arial, Helvetica, sans-serif !important;
}
footer { display: none !important; }
.hero-card {
    max-width: 900px;
    margin: 18px auto 14px auto;
    border-radius: 28px;
    background: linear-gradient(180deg, #e8c7d4 0%, #a55ca2 48%, #2b0c46 100%);
    padding: 22px 22px 18px 22px;
}
.hero-inner { text-align: center; }
.hero-title { color: white; font-size: 34px; font-weight: 800; margin-top: 6px; }
.hero-subtitle { color: white; opacity: 0.92; font-size: 16px; margin-top: 6px; }
.chat-panel {
    max-width: 900px;
    margin: 0 auto;
    background: white;
    border-radius: 22px;
    padding: 16px;
    min-height: 420px;
    box-shadow: 0 6px 18px rgba(0,0,0,0.08);
}
.chat-wrap { display: flex; flex-direction: column; gap: 14px; }
.msg-row { display: flex; width: 100%; }
.user-row { justify-content: flex-end; }
.bot-row { justify-content: flex-start; }
.msg-bubble {
    max-width: 80%;
    padding: 14px 16px;
    border-radius: 18px;
    line-height: 1.5;
    font-size: 15px;
    word-wrap: break-word;
}
.user-bubble { background: #e9d8ff; color: #111; border-bottom-right-radius: 6px; }
.bot-bubble { background: #f7f3a1; color: #111; border-bottom-left-radius: 6px; }
.empty-chat { display: flex; justify-content: center; align-items: center; min-height: 360px; }
.empty-chat-text { color: #777; font-size: 16px; text-align: center; }
.controls-wrap { max-width: 900px; margin: 0 auto; }
"""

# =====================================================
# UI
# =====================================================
# BUGFIX: custom CSS must be passed to gr.Blocks(), not demo.launch() —
# launch() has no `css` parameter.
with gr.Blocks(css=CSS) as demo:
    history_state = gr.State([])
    quiz_state = gr.State({
        "active": False,
        "topic": None,
        "quiz_data": None,
        "language_mode": "Auto",
    })

    gr.HTML(render_header())

    with gr.Row(elem_classes="controls-wrap"):
        mode = gr.Dropdown(
            choices=["Explain", "Detailed", "Short Notes", "Flashcards",
                     "Case-Based", "Quiz Me"],
            value="Explain",
            label="Tutor Mode",
        )
        language_mode = gr.Dropdown(
            choices=["Auto", "English", "Spanish", "Bilingual"],
            value="Auto",
            label="Answer Language",
        )

    with gr.Row(elem_classes="controls-wrap"):
        quiz_count_mode = gr.Dropdown(
            choices=["Auto", "3", "5", "7"],
            value="Auto",
            label="Quiz Questions",
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")

    gr.Markdown("""
**How to use**
- Choose a **Tutor Mode**
- Then type a topic or question
- For **Quiz Me**, type a topic such as: `cranial nerves`
- The system will ask questions, and your **next message will be evaluated automatically**
""")

    chat_html = gr.HTML(render_chat([]), elem_classes="chat-panel")

    with gr.Row(elem_classes="controls-wrap"):
        msg = gr.Textbox(
            placeholder="Ask a question or type a topic...",
            lines=1,
            show_label=False,
            scale=8,
        )
        send_btn = gr.Button("Send", scale=1)

    with gr.Row(elem_classes="controls-wrap"):
        clear_btn = gr.Button("Clear Chat")

    msg.submit(
        answer_question,
        inputs=[msg, history_state, mode, language_mode, quiz_count_mode,
                show_sources, quiz_state],
        outputs=[history_state, chat_html, quiz_state, msg],
    )
    send_btn.click(
        answer_question,
        inputs=[msg, history_state, mode, language_mode, quiz_count_mode,
                show_sources, quiz_state],
        outputs=[history_state, chat_html, quiz_state, msg],
    )
    clear_btn.click(
        clear_all,
        inputs=[],
        outputs=[history_state, chat_html, quiz_state, msg],
    )

if __name__ == "__main__":
    # allowed_paths lets Gradio serve the logo via /gradio_api/file=...
    demo.launch(allowed_paths=[LOGO_FILE])