import base64
import json
import mimetypes
import os
import pickle
import re
from urllib.parse import quote

import gradio as gr
import numpy as np
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

# =====================================================
# CONFIG
# =====================================================
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")
LOGO_FILE = "logo.png"

OPENAI_MODEL = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

# Populated lazily by ensure_loaded(); None means "not loaded yet".
BM25 = None
CHUNKS = None
EMBEDDINGS = None
EMBED_MODEL = None
CLIENT = None


# =====================================================
# LOADERS
# =====================================================
def tokenize(text: str):
    """Lower-case *text* and return its word tokens (runs of word characters)."""
    return re.findall(r"\w+", text.lower(), flags=re.UNICODE)


def ensure_loaded():
    """Load the retrieval artifacts and the OpenAI client on first use.

    Fills the module-level ``BM25``, ``CHUNKS``, ``EMBEDDINGS``,
    ``EMBED_MODEL`` and ``CLIENT`` globals. Calls after a successful load
    do no work.

    Raises:
        FileNotFoundError: if any of the four build files is missing.
        ValueError: if the ``OPENAI_API_KEY`` environment variable is unset.
    """
    global BM25, CHUNKS, EMBEDDINGS, EMBED_MODEL, CLIENT

    if CHUNKS is None:
        required = [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH]
        missing = [p for p in required if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                "Missing build files:\n" + "\n".join(missing)
            )

        # SECURITY: pickle.load can execute arbitrary code. These files must
        # come from a trusted build step, never from user uploads.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)

        embeddings = np.load(EMBED_PATH)

        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish only after every step succeeded. CHUNKS is the "loaded"
        # sentinel, so it is assigned last: a failure above leaves it None
        # and the next call retries instead of running half-initialised.
        BM25 = bm25
        EMBEDDINGS = embeddings
        EMBED_MODEL = embed_model
        CHUNKS = chunks

    if CLIENT is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        CLIENT = OpenAI(api_key=api_key)


# =====================================================
# RETRIEVAL
# =====================================================
def search_hybrid(query: str, shortlist_k: int = 20, final_k: int = 4):
    """Return the best-matching chunks for *query*.

    BM25 picks the top ``shortlist_k`` chunks; those are then re-ranked by
    the dot product between their stored embeddings and the normalised query
    embedding, and the top ``final_k`` are returned in that order.
    """
    ensure_loaded()

    q_tokens = tokenize(query)
    bm25_scores = BM25.get_scores(q_tokens)
    shortlist_idx = np.argsort(bm25_scores)[::-1][:shortlist_k]

    shortlist_emb = EMBEDDINGS[shortlist_idx]
    qvec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    dense_scores = shortlist_emb @ qvec

    # Indices in `rerank` are positions inside the shortlist, so map them
    # back to positions in CHUNKS.
    rerank = np.argsort(dense_scores)[::-1][:final_k]
    final_idx = shortlist_idx[rerank]
    return [CHUNKS[int(i)] for i in final_idx]


def build_context(records):
    """Format retrieved chunk records as numbered ``[Source N]`` blocks.

    Each record is read with ``.get`` for the keys ``book``,
    ``section_title``, ``page_start``, ``page_end`` and ``text``; missing
    keys render as empty strings.
    """
    blocks = []
    for i, r in enumerate(records, start=1):
        blocks.append(
            f"""[Source {i}]
Book: {r.get('book','')}
Section: {r.get('section_title','')}
Pages: {r.get('page_start','')}-{r.get('page_end','')}
Text:
{r.get('text','')}"""
        )
    return "\n\n".join(blocks)


def make_sources(records):
    """Return one bullet line per distinct (book, section, page range)."""
    seen = set()
    lines = []
    for r in records:
        key = (
            r.get("book"),
            r.get("section_title"),
            r.get("page_start"),
            r.get("page_end"),
        )
        if key in seen:
            continue
        seen.add(key)
        lines.append(
            f"• {r.get('book','')} | {r.get('section_title','')} | pp. {r.get('page_start','')}-{r.get('page_end','')}"
        )
    return "\n".join(lines)


# =====================================================
# PROMPTS
# =====================================================
def language_instruction(language_mode: str) -> str:
    """Return the prompt sentence that fixes the answer language.

    Any value other than "English", "Spanish" or "Bilingual" selects the
    auto-detect instruction.
    """
    if language_mode == "English":
        return "Answer only in English."
    if language_mode == "Spanish":
        return "Answer only in Spanish."
    if language_mode == "Bilingual":
        return "Answer first in English, then provide a Spanish version under the heading 'Español:'."
    return "If the user's message is in Spanish, answer in Spanish; otherwise answer in English."
def choose_quiz_count(user_text: str, selector: str) -> int:
    """Pick the number of quiz questions.

    An explicit selector of "3", "5" or "7" wins. Otherwise the count is
    inferred from keywords in *user_text*: exam-style wording gives 7,
    revision-style wording gives 5, anything else gives 3.
    """
    if selector in {"3", "5", "7"}:
        return int(selector)

    t = user_text.lower()
    if any(k in t for k in ["mock test", "final exam", "exam practice", "full test"]):
        return 7
    if any(k in t for k in ["detailed", "revision", "comprehensive", "study"]):
        return 5
    return 3


def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Build the tutor prompt for a non-quiz answer.

    *mode* selects the teaching style; unknown modes fall back to a plain
    friendly explanation.
    """
    styles = {
        "Explain": "Explain clearly like a friendly tutor using simple language.",
        "Detailed": "Give a more detailed explanation with key points and clinical relevance only when supported by context.",
        "Short Notes": "Write concise revision notes using short bullet points.",
        "Flashcards": "Create 6 flashcards in Q/A format using only the context.",
        "Case-Based": "Create a short clinical case scenario and then explain the concept using the context.",
    }
    style = styles.get(mode, "Explain clearly like a friendly tutor.")

    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.

Rules:
- Use ONLY the provided context.
- If the answer is not supported by the context, say exactly: Not found in the course material.
- Do not invent facts outside the context.
- {language_instruction(language_mode)}

Teaching style:
{style}

Context:
{context}

Student question:
{question}
""".strip()


def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Build the prompt asking the model for a JSON quiz of *n_questions* items."""
    return f"""
You are BrainChat, an interactive tutor.

Rules:
- Use ONLY the provided context.
- Create exactly {n_questions} quiz questions.
- Keep questions short and clear.
- Include a short answer key for each.
- Return VALID JSON only.
- {language_instruction(language_mode)}

Return JSON in this format:
{{
  "title": "short quiz title",
  "questions": [
    {{"q": "question 1", "answer_key": "expected short answer"}},
    {{"q": "question 2", "answer_key": "expected short answer"}}
  ]
}}

Context:
{context}

Topic:
{topic}
""".strip()


def build_quiz_eval_prompt(language_mode: str, quiz_data: dict, user_answers: str) -> str:
    """Build the prompt asking the model to grade *user_answers* against the quiz."""
    quiz_json = json.dumps(quiz_data, ensure_ascii=False)
    return f"""
You are BrainChat, an interactive tutor.

Evaluate the student's answers fairly using the answer keys.
Accept semantically correct answers even if wording differs.
Return VALID JSON only.

Return JSON in this format:
{{
  "score_obtained": 0,
  "score_total": 0,
  "summary": "short overall feedback",
  "results": [
    {{
      "question": "question text",
      "answer_key": "expected answer",
      "student_answer": "student answer",
      "result": "Correct / Partially Correct / Incorrect",
      "feedback": "short explanation"
    }}
  ],
  "improvement_tip": "one short study suggestion"
}}

Quiz:
{quiz_json}

Student answers:
{user_answers}

Language:
{language_instruction(language_mode)}
""".strip()


# =====================================================
# OPENAI
# =====================================================
def oai_text(prompt: str) -> str:
    """Send *prompt* to the chat model and return the stripped text reply.

    Returns an empty string when the reply carries no text content.
    """
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=OPENAI_MODEL,
        temperature=0.2,
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    # The message content can be None; guard so .strip() cannot fail.
    content = resp.choices[0].message.content or ""
    return content.strip()


def oai_json(prompt: str) -> dict:
    """Send *prompt* in JSON mode and return the parsed reply.

    Raises:
        ValueError: if the reply has no content, or is not valid JSON
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=OPENAI_MODEL,
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    if not content:
        raise ValueError("The model returned an empty response.")
    return json.loads(content)


# =====================================================
# LOGO
# =====================================================
def get_logo_data_uri():
    """Return LOGO_FILE as a base64 ``data:`` URI, or None if it is absent."""
    if not os.path.exists(LOGO_FILE):
        return None

    mime_type, _ = mimetypes.guess_type(LOGO_FILE)
    if not mime_type:
        mime_type = "image/png"

    with open(LOGO_FILE, "rb") as f:
        encoded = base64.b64encode(f.read()).decode("utf-8")
    return f"data:{mime_type};base64,{encoded}"


def render_logo():
    """Return the logo HTML: an <img> if the file exists, else a text badge.

    Class names match the ``.bc-logo-img`` / ``.bc-logo-fallback`` rules in CSS.
    """
    data_uri = get_logo_data_uri()
    if data_uri:
        return f'<img class="bc-logo-img" src="{data_uri}" alt="BrainChat logo">'
    return '<div class="bc-logo-fallback">BRAIN<br>CHAT</div>'


# =====================================================
# CHAT HTML
# =====================================================
def format_text(text: str) -> str:
    """Convert chat text into HTML that is safe to embed.

    Escapes ``&``, ``<`` and ``>``, turns ``**bold**`` spans into
    ``<strong>`` elements and turns newlines into ``<br>``.
    """
    # Escape "&" first so the entities produced below are not re-escaped.
    safe = (
        text.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    safe = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", safe)
    safe = safe.replace("\n", "<br>")
    return safe


def render_chat(history):
    """Render *history* (a list of ``{"role", "content"}`` dicts) as HTML.

    The result is always wrapped in the ``bc-chat-shell`` container, so the
    panel keeps its styling on the first render and on every later update.
    An empty history renders a placeholder message.
    """
    if not history:
        return (
            '<div class="bc-chat-shell">'
            '<div class="bc-empty">'
            '<div class="bc-empty-text">Ask a question and choose a tutor mode above.</div>'
            "</div>"
            "</div>"
        )

    rows = []
    for item in history:
        role = item["role"]
        content = format_text(item["content"])
        if role == "user":
            rows.append(
                f'<div class="bc-row bc-user-row"><div class="bc-bubble bc-user-bubble">{content}</div></div>'
            )
        else:
            rows.append(
                f'<div class="bc-row bc-bot-row"><div class="bc-bubble bc-bot-bubble">{content}</div></div>'
            )

    body = "".join(rows)
    return f'<div class="bc-chat-shell"><div class="bc-chat-wrap">{body}</div></div>'


# =====================================================
# MAIN LOGIC
# =====================================================
def _new_quiz_state(language_mode="Auto"):
    """Return a fresh "no quiz in progress" state dict."""
    return {"active": False, "quiz_data": None, "language_mode": language_mode}


def _format_evaluation(evaluation):
    """Turn the model's quiz-evaluation JSON into the chat message text."""
    lines = [
        f"**Score:** {evaluation.get('score_obtained', 0)}/{evaluation.get('score_total', 0)}"
    ]
    if evaluation.get("summary"):
        lines.append(f"\n**Overall:** {evaluation['summary']}")
    if evaluation.get("improvement_tip"):
        lines.append(f"\n**Tip:** {evaluation['improvement_tip']}\n")

    results = evaluation.get("results") or []
    if results:
        lines.append("**Question-wise feedback:**")
        for item in results:
            lines.append("")
            lines.append(f"**Q:** {item.get('question','')}")
            lines.append(f"**Your answer:** {item.get('student_answer','')}")
            lines.append(f"**Expected:** {item.get('answer_key','')}")
            lines.append(f"**Result:** {item.get('result','')}")
            lines.append(f"**Feedback:** {item.get('feedback','')}")
    return "\n".join(lines).strip()


def _format_quiz(quiz_data, records, show_sources):
    """Turn generated quiz JSON into the chat message shown to the student."""
    questions = quiz_data.get("questions") or []
    lines = [
        f"**{quiz_data.get('title', 'Quiz')}**",
        f"\n**Total questions:** {len(questions)}\n",
        "Reply in one message using numbered answers.",
        "Example: 1. ... 2. ...\n",
    ]
    for i, q in enumerate(questions, start=1):
        lines.append(f"**Q{i}.** {q.get('q','')}")

    if show_sources:
        lines.append("\n\n**Sources used to create this quiz:**")
        lines.append(make_sources(records))
    return "\n".join(lines).strip()


def respond(user_msg, history, mode, language_mode, quiz_count_mode, show_sources, quiz_state):
    """Handle one submitted message.

    Behaviour depends on state:
    - blank message: nothing changes;
    - a quiz is active: the message is graded as the student's answers and
      the quiz is closed;
    - mode is "Quiz Me": a quiz is generated from retrieved context and the
      next message will be graded;
    - otherwise: a tutor answer is generated from retrieved context.

    Returns:
        ``(textbox_value, history, chat_html, quiz_state)``, matching the
        outputs wired up in the UI. The textbox value is always "".
    """
    history = history or []
    quiz_state = quiz_state or _new_quiz_state()

    text = (user_msg or "").strip()
    if not text:
        return "", history, render_chat(history), quiz_state

    try:
        # Build a new list rather than mutating the one held in gr.State.
        history = history + [{"role": "user", "content": text}]

        if quiz_state.get("active", False):
            evaluation = oai_json(
                build_quiz_eval_prompt(
                    quiz_state.get("language_mode", language_mode),
                    quiz_state.get("quiz_data", {}),
                    text,
                )
            )
            history = history + [{"role": "assistant", "content": _format_evaluation(evaluation)}]
            quiz_state = _new_quiz_state(language_mode)
            return "", history, render_chat(history), quiz_state

        records = search_hybrid(text, shortlist_k=20, final_k=4)
        context = build_context(records)

        if mode == "Quiz Me":
            n_questions = choose_quiz_count(text, quiz_count_mode)
            quiz_data = oai_json(
                build_quiz_generation_prompt(language_mode, text, context, n_questions)
            )
            history = history + [
                {"role": "assistant", "content": _format_quiz(quiz_data, records, show_sources)}
            ]
            quiz_state = {
                # Only wait for answers when there is something to answer.
                "active": bool(quiz_data.get("questions")),
                "quiz_data": quiz_data,
                "language_mode": language_mode,
            }
            return "", history, render_chat(history), quiz_state

        answer = oai_text(build_tutor_prompt(mode, language_mode, text, context))
        if show_sources:
            answer = answer.strip() + "\n\n**Sources:**\n" + make_sources(records)

        history = history + [{"role": "assistant", "content": answer.strip()}]
        return "", history, render_chat(history), quiz_state

    except Exception as e:
        # UI boundary: show the failure in the chat instead of breaking the
        # event handler, and drop any quiz in progress.
        history = history + [{"role": "assistant", "content": f"Error: {str(e)}"}]
        quiz_state = _new_quiz_state(language_mode)
        return "", history, render_chat(history), quiz_state


def clear_all():
    """Reset the textbox, history, chat panel and quiz state."""
    empty_history = []
    empty_quiz = _new_quiz_state()
    return "", empty_history, render_chat(empty_history), empty_quiz


# =====================================================
# CSS
# =====================================================
CSS = """
:root{ --page-bg: #d9d9dd; --panel-bg: #555765; --chat-bg: #4a4c59; --grad-top: #e8c7d4; --grad-mid: #a55ca2; --grad-bot: #5a2d77; --accent: #f4eb4b; --accent-soft: #f5ef9a; --user-bubble: #ffffff; --bot-bubble: #f5efad; --text-dark: #241336; --text-dark-strong: #170c25; --text-light: #ffffff; --shadow: rgba(30,20,50,0.18); }
html, body, .gradio-container{ background: var(--page-bg) !important; font-family: Arial, Helvetica, sans-serif !important; }
footer{display:none !important;}
#bc_app{ max-width: 1000px; margin: 18px auto; }
.bc-settings{ background: linear-gradient(180deg, #4a4c59 0%, #434552 100%); border-radius: 20px; padding: 14px; box-shadow: 0 10px 24px var(--shadow); margin-bottom: 14px; }
.bc-howto{ margin-top: 8px; padding: 12px 14px; border-radius: 14px; background: rgba(255,255,255,0.10); color: white; font-size: 14px; line-height: 1.5; }
.bc-phone{ position: relative; background: linear-gradient(180deg, var(--grad-top) 0%, var(--grad-mid) 48%, var(--grad-bot) 100%); border-radius: 30px; padding: 92px 14px 14px 14px; box-shadow: 0 16px 34px var(--shadow); min-height: 620px; }
.bc-logo-holder{ position: absolute; top: 16px; left: 50%; transform: translateX(-50%); width: 104px; height: 104px; border-radius: 999px; background: var(--accent); display: flex; align-items: center; justify-content: center; box-shadow: 0 10px 22px rgba(0,0,0,0.18); }
.bc-logo-img{ width: 88px; height: 88px; object-fit: contain; display:block; }
.bc-logo-fallback{ width: 88px; height: 88px; border-radius: 999px; display:flex; align-items:center; justify-content:center; text-align:center; font-size: 13px; font-weight: 900; color: var(--text-dark-strong); background: rgba(255,255,255,0.40); line-height: 1.05; }
.bc-chat-shell{ background: rgba(74,76,89,0.92); border-radius: 20px; padding: 16px; min-height: 460px; box-shadow: inset 0 1px 0 rgba(255,255,255,0.06); }
.bc-chat-wrap{ display: flex; flex-direction: column; gap: 14px; max-height: 460px; overflow-y: auto; padding-right: 4px; }
.bc-chat-wrap::-webkit-scrollbar{ width: 8px; }
.bc-chat-wrap::-webkit-scrollbar-thumb{ background: rgba(255,255,255,0.28); border-radius: 999px; }
.bc-row{ display:flex; width:100%; }
.bc-user-row{ justify-content: flex-start; }
.bc-bot-row{ justify-content: flex-end; }
.bc-bubble{ max-width: 80%; padding: 15px 18px; border-radius: 22px; line-height: 1.6; font-size: 15px; box-shadow: 0 10px 18px rgba(0,0,0,0.10); word-wrap: break-word; font-weight: 500; }
.bc-user-bubble{ background: var(--user-bubble); color: var(--text-dark-strong) !important; border-bottom-left-radius: 8px; }
.bc-bot-bubble{ background: var(--bot-bubble); color: var(--text-dark-strong) !important; border-bottom-right-radius: 8px; }
.bc-bubble strong{ color: var(--text-dark-strong) !important; }
.bc-empty{ display:flex; justify-content:center; align-items:center; min-height: 400px; }
.bc-empty-text{ color: white; text-align:center; opacity: 0.96; font-size: 16px; }
.bc-input-bar{ margin-top: 12px; background: var(--accent); border-radius: 999px; padding: 8px 10px; display:flex; align-items:center; gap: 10px; box-shadow: 0 10px 22px rgba(0,0,0,0.14); }
.bc-plus{ width: 38px; height: 38px; border-radius: 999px; background: rgba(255,255,255,0.34); display:flex; align-items:center; justify-content:center; font-size: 30px; font-weight: 900; color: var(--text-dark-strong); user-select:none; }
#bc_msg textarea{ background: rgba(255,255,255,0.42) !important; border: none !important; box-shadow: none !important; border-radius: 999px !important; color: var(--text-dark-strong) !important; padding: 11px 14px !important; min-height: 42px !important; }
#bc_msg textarea::placeholder{ color: rgba(34,23,53,0.72) !important; }
#bc_send button{ min-width: 48px !important; height: 42px !important; border-radius: 999px !important; border: none !important; background: rgba(255,255,255,0.34) !important; color: var(--text-dark-strong) !important; font-size: 20px !important; font-weight: 900 !important; box-shadow: none !important; }
#bc_send button:hover{ background: rgba(255,255,255,0.52) !important; }
#bc_clear button{ border-radius: 14px !important; }
@media (max-width: 768px){ #bc_app{ max-width: 96vw; } .bc-bubble{ max-width: 88%; } }
"""

# =====================================================
# UI
# =====================================================
with gr.Blocks() as demo:
    history_state = gr.State([])
    quiz_state = gr.State(_new_quiz_state())

    with gr.Column(elem_id="bc_app"):
        with gr.Group(elem_classes="bc-settings"):
            with gr.Row():
                mode = gr.Dropdown(
                    choices=["Explain", "Detailed", "Short Notes", "Flashcards", "Case-Based", "Quiz Me"],
                    value="Explain",
                    label="Tutor Mode",
                )
                language_mode = gr.Dropdown(
                    choices=["Auto", "English", "Spanish", "Bilingual"],
                    value="Auto",
                    label="Answer Language",
                )
            with gr.Row():
                quiz_count_mode = gr.Dropdown(
                    choices=["Auto", "3", "5", "7"],
                    value="Auto",
                    label="Quiz Questions",
                )
                show_sources = gr.Checkbox(value=True, label="Show Sources")

            gr.HTML(
                '<div class="bc-howto">'
                "<strong>How to use</strong><br>"
                "1. Choose a tutor mode such as Explain, Detailed, Flashcards, or Quiz Me.<br>"
                "2. Type your topic or question in the message box below.<br>"
                "3. For Quiz Me, the next message you send will be evaluated automatically.<br>"
                "4. Turn on Show Sources if you want references from the books."
                "</div>"
            )

        with gr.Group(elem_classes="bc-phone"):
            gr.HTML(f'<div class="bc-logo-holder">{render_logo()}</div>')

            # render_chat() supplies the bc-chat-shell wrapper itself.
            chat_html = gr.HTML(render_chat([]))

            with gr.Row(elem_classes="bc-input-bar"):
                gr.HTML('<div class="bc-plus">+</div>')
                msg = gr.Textbox(
                    placeholder="Type a message...",
                    show_label=False,
                    container=False,
                    scale=8,
                    elem_id="bc_msg",
                )
                send_btn = gr.Button("➤", elem_id="bc_send", scale=1)

        clear_btn = gr.Button("Clear Chat", elem_id="bc_clear")

    # Pressing Enter and clicking the send button run the same handler.
    msg.submit(
        respond,
        inputs=[msg, history_state, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[msg, history_state, chat_html, quiz_state],
    )
    send_btn.click(
        respond,
        inputs=[msg, history_state, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[msg, history_state, chat_html, quiz_state],
    )
    clear_btn.click(
        clear_all,
        inputs=[],
        outputs=[msg, history_state, chat_html, quiz_state],
        queue=False,
    )

if __name__ == "__main__":
    demo.queue()
    # NOTE(review): css is passed to launch() here; some Gradio releases take
    # it on gr.Blocks(css=...) instead. Confirm against the pinned version.
    demo.launch(css=CSS)