import os
import re
import json
import pickle
from urllib.parse import quote

import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI

# ============================================================
# Configuration
# ============================================================

# Directory containing pre-built retrieval artifacts (produced by an
# offline build step and committed alongside this app).
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")            # list of chunk records (dicts with book/section/pages/text)
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")  # pre-tokenized chunk texts for BM25
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")         # dense embeddings, row-aligned with CHUNKS
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")           # build config; must contain "embedding_model"

# Put ONE of these logo files in your Space repo root (same folder as app.py)
LOGO_CANDIDATES = [
    "Brain chat-09.png",
    "brainchat_logo.png.png",
    "Brain Chat Imagen.svg",
    "ebcbb9f5-022f-473a-bf51-7e7974f794b4.png",
]

# Chat model; overridable via the OPENAI_MODEL environment variable.
MODEL_NAME_TEXT = os.getenv("OPENAI_MODEL", "gpt-4o-mini")

# ============================================================
# Globals (lazy loaded)
# ============================================================
BM25 = None         # rank_bm25.BM25Okapi index over tokenized chunks
CHUNKS = None       # list of chunk record dicts
EMBEDDINGS = None   # np.ndarray of chunk embeddings (assumed L2-normalized -- TODO confirm build step)
EMBED_MODEL = None  # SentenceTransformer used to embed queries
CLIENT = None       # OpenAI API client

# ============================================================
# Utilities
# ============================================================

def tokenize(text: str) -> list[str]:
    """Lowercase *text* and split it into word tokens (BM25 tokenization)."""
    return re.findall(r"\w+", text.lower(), flags=re.UNICODE)


def ensure_loaded():
    """Lazily load the retrieval artifacts and the OpenAI client into globals.

    Raises:
        FileNotFoundError: if any build artifact is missing.
        ValueError: if the OPENAI_API_KEY environment variable is not set.
    """
    global BM25, CHUNKS, EMBEDDINGS, EMBED_MODEL, CLIENT
    # CHUNKS doubles as the "already loaded" sentinel for all retrieval state.
    if CHUNKS is None:
        missing = [p for p in [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH] if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                "Missing build files. Make sure you ran the build step and committed brainchat_build/.\n"
                + "\n".join(missing)
            )
        with open(CHUNKS_PATH, "rb") as f:
            CHUNKS = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        EMBEDDINGS = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        BM25 = BM25Okapi(tokenized_chunks)
        # Use the same embedding model the index was built with.
        EMBED_MODEL = SentenceTransformer(cfg["embedding_model"])
    if CLIENT is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing. Add it in your Space Secrets.")
        CLIENT = OpenAI(api_key=api_key)


def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Two-stage hybrid retrieval: BM25 shortlist, then dense re-ranking.

    Args:
        query: free-text user query.
        shortlist_k: number of candidates kept from the BM25 stage.
        final_k: number of chunks returned after dense re-ranking.

    Returns:
        The top *final_k* chunk record dicts.
    """
    ensure_loaded()
    q_tokens = tokenize(query)
    bm25_scores = BM25.get_scores(q_tokens)
    # Stage 1: lexical shortlist (top shortlist_k by BM25 score).
    shortlist_idx = np.argsort(bm25_scores)[::-1][:shortlist_k]
    # Stage 2: cosine re-rank via dot product (embeddings are normalized).
    qvec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    shortlist_emb = EMBEDDINGS[shortlist_idx]
    dense_scores = shortlist_emb @ qvec
    rerank = np.argsort(dense_scores)[::-1][:final_k]
    final_idx = shortlist_idx[rerank]
    return [CHUNKS[int(i)] for i in final_idx]


def build_context(records) -> str:
    """Format retrieved chunk records into a numbered context block for the LLM."""
    blocks = []
    for i, r in enumerate(records, start=1):
        blocks.append(
            f"""[Source {i}]
Book: {r.get('book','')}
Section: {r.get('section_title','')}
Pages: {r.get('page_start','')}-{r.get('page_end','')}
Text: {r.get('text','')}"""
        )
    return "\n\n".join(blocks)


def make_sources(records) -> str:
    """Render a deduplicated, bulleted citation list for the retrieved records."""
    seen = set()
    lines = []
    for r in records:
        # Deduplicate on (book, section, page range) so repeated chunks from
        # the same section produce a single citation line.
        key = (r.get("book"), r.get("section_title"), r.get("page_start"), r.get("page_end"))
        if key in seen:
            continue
        seen.add(key)
        lines.append(
            f"• {r.get('book','')} | {r.get('section_title','')} | pp. {r.get('page_start','')}-{r.get('page_end','')}"
        )
    return "\n".join(lines)


def choose_quiz_count(user_text: str, selector: str) -> int:
    """Decide how many quiz questions to generate.

    An explicit dropdown selection ("3"/"5"/"7") wins; otherwise keywords in
    the user's message bump the count, defaulting to 3.
    """
    if selector in {"3", "5", "7"}:
        return int(selector)
    t = user_text.lower()
    if any(k in t for k in ["mock test", "final exam", "exam practice", "full test"]):
        return 7
    if any(k in t for k in ["detailed", "revision", "comprehensive", "study"]):
        return 5
    return 3


def language_instruction(language_mode: str) -> str:
    """Return the prompt sentence enforcing the selected answer language.

    Unknown modes (including "Auto") fall through to language auto-detection.
    """
    if language_mode == "English":
        return "Answer only in English."
    if language_mode == "Spanish":
        return "Answer only in Spanish."
    if language_mode == "Bilingual":
        return "Answer first in English, then provide a Spanish version under the heading 'Español:'."
    return "If the user writes in Spanish, answer in Spanish; otherwise answer in English."


def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Compose the grounded tutoring prompt for non-quiz modes.

    NOTE(review): line-break placement inside this template was lost in
    extraction and has been reconstructed -- confirm against the original.
    """
    mode_map = {
        "Explain": (
            "Explain clearly like a friendly tutor using simple language. "
            "Use short headings if helpful."
        ),
        "Detailed": (
            "Give a detailed explanation. Include key terms and clinical relevance only if supported by the context."
        ),
        "Short Notes": "Write concise revision notes using bullet points.",
        "Flashcards": "Create 6 flashcards in Q/A format.",
        "Case-Based": (
            "Create a short clinical scenario (2–4 lines) and then explain the underlying concept using the context."
        ),
    }
    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.

Rules:
- Use ONLY the provided context from the books.
- If the answer is not supported by the context, say exactly: Not found in the course material.
- Do not invent facts outside the context.
- {language_instruction(language_mode)}

Teaching style:
{mode_map.get(mode, mode_map['Explain'])}

Context:
{context}

Student question:
{question}
""".strip()


def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Compose the prompt that asks the LLM to emit a quiz as strict JSON.

    NOTE(review): line-break placement inside this template was lost in
    extraction and has been reconstructed -- confirm against the original.
    """
    return f"""
You are BrainChat, an interactive tutor.

Rules:
- Use ONLY the provided context.
- Create exactly {n_questions} quiz questions.
- Questions should be short, clear, and course-aligned.
- Provide a short answer key per question.
- Return VALID JSON only.
- {language_instruction(language_mode)}

Required JSON format:
{{
  "title": "short quiz title",
  "questions": [
    {{"q": "question 1", "answer_key": "expected short answer"}},
    {{"q": "question 2", "answer_key": "expected short answer"}}
  ]
}}

Context:
{context}

Topic:
{topic}
""".strip()


def build_quiz_evaluation_prompt(language_mode: str, quiz_data: dict, user_answers: str) -> str:
    """Compose the prompt that grades the student's quiz answers as strict JSON.

    NOTE(review): line-break placement inside this template was lost in
    extraction and has been reconstructed -- confirm against the original.
    """
    quiz_json = json.dumps(quiz_data, ensure_ascii=False)
    return f"""
You are BrainChat, an interactive tutor.

Task:
Evaluate the student's answers fairly against the answer keys.
Accept semantically correct answers even if wording differs.
Return VALID JSON only.

Required JSON format:
{{
  "score_obtained": 0,
  "score_total": 0,
  "summary": "short overall feedback",
  "results": [
    {{
      "question": "question text",
      "answer_key": "expected short answer",
      "student_answer": "student answer",
      "result": "Correct / Partially Correct / Incorrect",
      "feedback": "short explanation"
    }}
  ],
  "improvement_tip": "one short study suggestion"
}}

Quiz:
{quiz_json}

Student answers:
{user_answers}

Language: {language_instruction(language_mode)}
""".strip()


def chat_text(prompt: str) -> str:
    """Send *prompt* to the chat model and return the plain-text reply."""
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=MODEL_NAME_TEXT,
        temperature=0.2,
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    return resp.choices[0].message.content.strip()


def chat_json(prompt: str) -> dict:
    """Send *prompt* to the chat model in JSON mode and return the parsed dict.

    Raises:
        json.JSONDecodeError: if the model output is not valid JSON despite
        the response_format constraint.
    """
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=MODEL_NAME_TEXT,
        temperature=0.2,
        # Ask the API to constrain output to a JSON object.
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    return json.loads(resp.choices[0].message.content)
# ============================================================
# Logo + Header HTML
# ============================================================

def find_logo_file():
    """Return the first logo filename from LOGO_CANDIDATES that exists, or None."""
    for name in LOGO_CANDIDATES:
        if os.path.exists(name):
            return name
    return None


def logo_img_tag(size_px: int = 88) -> str:
    """Return an <img> tag for the logo, or a circular text badge fallback.

    NOTE(review): the original inline markup was lost in extraction; this is
    a reconstruction consistent with the .bc-logo-img / .bc-logo-fallback CSS
    classes -- confirm the visuals against the deployed app.
    """
    logo_file = find_logo_file()
    if logo_file:
        # Gradio serves local files under /gradio_api/file=...; the file must
        # be listed in allowed_paths at launch time (see __main__ below).
        url = f"/gradio_api/file={quote(logo_file)}"
        return (
            f'<img src="{url}" class="bc-logo-img" alt="BrainChat logo" '
            f'width="{size_px}" height="{size_px}" />'
        )
    return '<div class="bc-logo-fallback">BRAIN<br>CHAT</div>'


def render_top_banner() -> str:
    """HTML for the gradient banner shown above the settings accordion.

    NOTE(review): markup reconstructed from the .bc-banner* CSS classes --
    confirm layout against the deployed app.
    """
    return f"""
<div class="bc-banner">
  <div class="bc-banner-inner">
    <div class="bc-banner-logo">{logo_img_tag(64)}</div>
    <div>
      <div class="bc-banner-title">BrainChat</div>
      <div class="bc-banner-subtitle">Neurology &amp; neuroanatomy tutor (book-based)</div>
    </div>
  </div>
</div>
""".strip()


def render_phone_logo() -> str:
    """HTML for the floating circular logo pinned to the top of the phone mock.

    NOTE(review): markup reconstructed from the .bc-phone-logo CSS class.
    """
    return f"""
<div class="bc-phone-logo">{logo_img_tag(84)}</div>
""".strip()


# ============================================================
# Chat logic (with quiz state)
# ============================================================

def respond(message, history, mode, language_mode, quiz_count_mode, show_sources, quiz_state):
    """Main chat handler: quiz evaluation, quiz generation, or tutoring.

    Args:
        message: raw user input from the textbox.
        history: chat history in openai-style message dicts (requires
            gr.Chatbot(type="messages")).
        mode: tutor mode dropdown value.
        language_mode: answer-language dropdown value.
        quiz_count_mode: quiz-question-count dropdown value ("Auto"/"3"/"5"/"7").
        show_sources: whether to append citations to answers.
        quiz_state: dict {"active", "quiz_data", "language_mode"} tracking an
            in-flight quiz awaiting the student's answers.

    Returns:
        (cleared textbox value, updated history, updated quiz_state).
    """
    if history is None:
        history = []
    if quiz_state is None:
        quiz_state = {"active": False, "quiz_data": None, "language_mode": "Auto"}

    user_text = (message or "").strip()
    if not user_text:
        return "", history, quiz_state

    try:
        history = history + [{"role": "user", "content": user_text}]

        # ---- Quiz evaluation step: a quiz is pending, so this message is
        # the student's answer sheet. Grade it and deactivate the quiz.
        if quiz_state.get("active", False):
            evaluation_prompt = build_quiz_evaluation_prompt(
                # Use the language the quiz was generated in, not the current
                # dropdown value, so grading matches the quiz.
                quiz_state.get("language_mode", language_mode),
                quiz_state.get("quiz_data", {}),
                user_text,
            )
            evaluation = chat_json(evaluation_prompt)

            lines = []
            lines.append(f"**Score:** {evaluation.get('score_obtained', 0)}/{evaluation.get('score_total', 0)}")
            if evaluation.get("summary"):
                lines.append(f"\n**Overall:** {evaluation['summary']}")
            if evaluation.get("improvement_tip"):
                lines.append(f"\n**Tip:** {evaluation['improvement_tip']}\n")
            results = evaluation.get("results", [])
            if results:
                lines.append("**Question-wise feedback:**")
                for item in results:
                    lines.append("")
                    lines.append(f"**Q:** {item.get('question','')}")
                    lines.append(f"**Your answer:** {item.get('student_answer','')}")
                    lines.append(f"**Expected:** {item.get('answer_key','')}")
                    lines.append(f"**Result:** {item.get('result','')}")
                    lines.append(f"**Feedback:** {item.get('feedback','')}")

            assistant_text = "\n".join(lines).strip()
            history = history + [{"role": "assistant", "content": assistant_text}]
            quiz_state = {"active": False, "quiz_data": None, "language_mode": language_mode}
            return "", history, quiz_state

        # ---- Normal retrieval: ground every response in the book chunks.
        records = search_hybrid(user_text, shortlist_k=30, final_k=5)
        context = build_context(records)

        # ---- Quiz generation: produce a quiz and arm quiz_state so the
        # NEXT message is treated as the answer sheet.
        if mode == "Quiz Me":
            n_questions = choose_quiz_count(user_text, quiz_count_mode)
            quiz_prompt = build_quiz_generation_prompt(language_mode, user_text, context, n_questions)
            quiz_data = chat_json(quiz_prompt)

            lines = []
            lines.append(f"**{quiz_data.get('title','Quiz')}**")
            lines.append(f"\n**Total questions:** {len(quiz_data.get('questions', []))}\n")
            lines.append("Reply in ONE message with numbered answers, like:")
            lines.append("1. ...")
            lines.append("2. ...\n")
            for i, q in enumerate(quiz_data.get("questions", []), start=1):
                lines.append(f"**Q{i}.** {q.get('q','')}")
            if show_sources:
                lines.append("\n\n**Sources used to create this quiz:**")
                lines.append(make_sources(records))

            assistant_text = "\n".join(lines).strip()
            history = history + [{"role": "assistant", "content": assistant_text}]
            quiz_state = {"active": True, "quiz_data": quiz_data, "language_mode": language_mode}
            return "", history, quiz_state

        # ---- Other tutor modes: single grounded answer.
        tutor_prompt = build_tutor_prompt(mode, language_mode, user_text, context)
        answer = chat_text(tutor_prompt)
        if show_sources:
            answer = (answer or "").strip() + "\n\n**Sources:**\n" + make_sources(records)
        history = history + [{"role": "assistant", "content": answer.strip()}]
        return "", history, quiz_state

    except Exception as e:
        # Surface failures in-chat rather than crashing the UI, and reset
        # any pending quiz so the user isn't stuck in evaluation mode.
        history = history + [{"role": "assistant", "content": f"Error: {str(e)}"}]
        quiz_state = {"active": False, "quiz_data": None, "language_mode": language_mode}
        return "", history, quiz_state


def clear_all():
    """Reset the textbox, chat history, and quiz state."""
    return "", [], {"active": False, "quiz_data": None, "language_mode": "Auto"}


# ============================================================
# CSS (Instagram-style phone mock)
# ============================================================
CSS = r"""
:root{
  --bc-page-bg: #dcdcdc;
  --bc-grad-top: #E8C7D4;
  --bc-grad-mid: #A55CA2;
  --bc-grad-bot: #2B0C46;
  --bc-yellow: #FFF34A;
  --bc-bot-bubble: #FAF7B4;
  --bc-user-bubble: #FFFFFF;
  --bc-ink: #141414;
}
body, .gradio-container{
  background: var(--bc-page-bg) !important;
  font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;
}
footer{ display:none !important; }

/* Banner */
#bc_banner{ max-width: 980px; margin: 18px auto 8px auto; }
.bc-banner{
  background: linear-gradient(180deg, var(--bc-grad-top) 0%, var(--bc-grad-mid) 52%, var(--bc-grad-bot) 100%);
  border-radius: 26px;
  padding: 14px 16px;
  box-shadow: 0 10px 26px rgba(0,0,0,.12);
}
.bc-banner-inner{ display:flex; align-items:center; gap: 12px; color: white; }
.bc-banner-title{ font-size: 20px; font-weight: 800; line-height:1.1; }
.bc-banner-subtitle{ font-size: 13px; opacity:.92; margin-top:2px; }
.bc-banner-logo .bc-logo-img{
  border-radius: 999px;
  background: var(--bc-yellow);
  padding: 6px;
  display:block;
}
.bc-logo-fallback{
  width: 64px; height: 64px;
  border-radius: 999px;
  background: var(--bc-yellow);
  display:flex; align-items:center; justify-content:center;
  color: #111; font-weight: 900; font-size: 12px; text-align:center;
}

/* Settings */
#bc_settings{ max-width: 980px; margin: 0 auto 10px auto; }
#bc_settings .label{ font-weight: 700; }

/* Phone */
#bc_phone{
  max-width: 420px;
  margin: 0 auto 18px auto;
  border-radius: 38px;
  background: linear-gradient(180deg, var(--bc-grad-top) 0%, var(--bc-grad-mid) 45%, var(--bc-grad-bot) 100%);
  box-shadow: 0 18px 40px rgba(0,0,0,.18);
  border: 1px solid rgba(255,255,255,.22);
  padding: 14px 14px 12px 14px;
  position: relative;
}

/* Floating logo in phone */
#bc_phone_logo{
  position: absolute;
  top: 12px;
  left: 50%;
  transform: translateX(-50%);
  z-index: 10;
}
.bc-phone-logo{
  width: 92px; height: 92px;
  border-radius: 999px;
  background: var(--bc-yellow);
  display:flex; align-items:center; justify-content:center;
  box-shadow: 0 10px 22px rgba(0,0,0,.18);
}
.bc-phone-logo .bc-logo-img{
  width: 84px !important;
  height: 84px !important;
  object-fit: contain;
}

/* Push chat down under logo */
#bc_chatbot{ margin-top: 92px; }

/* Chatbot transparent */
#bc_chatbot, #bc_chatbot > div{
  background: transparent !important;
  border: none !important;
  box-shadow: none !important;
}
#bc_chatbot .toolbar{ display:none !important; }

/* Bubble styling via internal testid markers */
#bc_chatbot button[data-testid="user"],
#bc_chatbot button[data-testid="bot"]{
  max-width: 82%;
  border-radius: 18px !important;
  padding: 12px 14px !important;
  color: var(--bc-ink) !important;
  box-shadow: 0 8px 18px rgba(0,0,0,.10);
  border: 0 !important;
  line-height: 1.35;
  font-size: 14px;
}

/* User bubble white */
#bc_chatbot button[data-testid="user"]{ background: var(--bc-user-bubble) !important; }

/* Bot bubble pale yellow */
#bc_chatbot button[data-testid="bot"]{ background: var(--bc-bot-bubble) !important; }

/* Bubble tails */
#bc_chatbot button[data-testid="user"]::after{
  content:""; position:absolute;
  right:-7px; bottom: 12px;
  width:0; height:0;
  border-left: 10px solid var(--bc-user-bubble);
  border-top: 8px solid transparent;
  border-bottom: 8px solid transparent;
}
#bc_chatbot button[data-testid="bot"]::before{
  content:""; position:absolute;
  left:-7px; bottom: 12px;
  width:0; height:0;
  border-right: 10px solid var(--bc-bot-bubble);
  border-top: 8px solid transparent;
  border-bottom: 8px solid transparent;
}

/* Input bar */
#bc_input_row{
  margin-top: 10px;
  background: rgba(255,243,74,.96);
  border-radius: 999px;
  padding: 10px 10px;
  box-shadow: 0 10px 22px rgba(0,0,0,.14);
  align-items: center;
}
#bc_plus{
  width: 34px; height: 34px;
  border-radius: 999px;
  display:flex; align-items:center; justify-content:center;
  font-weight: 900;
  color: var(--bc-grad-bot);
  background: rgba(255,255,255,.35);
  user-select: none;
}
#bc_msg textarea{
  background: rgba(255,255,255,.35) !important;
  border-radius: 999px !important;
  border: none !important;
  padding: 10px 12px !important;
  color: var(--bc-grad-bot) !important;
  box-shadow: none !important;
}
#bc_send{
  min-width: 42px !important;
  height: 38px !important;
  border-radius: 999px !important;
  border: none !important;
  background: rgba(255,255,255,.35) !important;
  color: var(--bc-grad-bot) !important;
  font-size: 18px !important;
  font-weight: 900 !important;
}
#bc_send:hover{ background: rgba(255,255,255,.55) !important; }

/* Clear */
#bc_clear{
  max-width: 420px;
  margin: 10px auto 0 auto;
  border-radius: 14px !important;
}

@media (max-width: 480px){
  #bc_phone{ max-width: 95vw; }
  #bc_chatbot button[data-testid="user"],
  #bc_chatbot button[data-testid="bot"]{
    max-width: 88%;
    font-size: 14px;
  }
}
"""

# ============================================================
# UI
# ============================================================
# FIX: css must be passed to gr.Blocks(...); Blocks.launch() has no css kwarg.
with gr.Blocks(css=CSS) as demo:
    quiz_state = gr.State({"active": False, "quiz_data": None, "language_mode": "Auto"})

    gr.HTML(render_top_banner(), elem_id="bc_banner")

    with gr.Accordion("Settings", open=False, elem_id="bc_settings"):
        mode = gr.Dropdown(
            choices=["Explain", "Detailed", "Short Notes", "Flashcards", "Case-Based", "Quiz Me"],
            value="Explain",
            label="Tutor Mode",
        )
        language_mode = gr.Dropdown(
            choices=["Auto", "English", "Spanish", "Bilingual"],
            value="Auto",
            label="Answer Language",
        )
        quiz_count_mode = gr.Dropdown(
            choices=["Auto", "3", "5", "7"],
            value="Auto",
            label="Quiz Questions",
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")

    with gr.Group(elem_id="bc_phone"):
        gr.HTML(render_phone_logo(), elem_id="bc_phone_logo")

        chatbot = gr.Chatbot(
            value=[],
            elem_id="bc_chatbot",
            height=560,
            layout="bubble",
            container=False,
            show_label=False,
            autoscroll=True,
            buttons=[],
            # FIX: respond() builds openai-style {"role", "content"} dicts, so
            # the chatbot must use the "messages" format (the legacy tuple
            # default would reject/garble them).
            type="messages",
            placeholder="Ask a question or type a topic…",
        )

        with gr.Row(elem_id="bc_input_row"):
            # NOTE(review): original inner markup lost in extraction; a bare
            # "+" styled by the #bc_plus CSS rule is the reconstruction.
            gr.HTML("+", elem_id="bc_plus")
            msg = gr.Textbox(
                placeholder="Type a message…",
                show_label=False,
                container=False,
                scale=8,
                elem_id="bc_msg",
            )
            send_btn = gr.Button("➤", elem_id="bc_send", scale=1)

    clear_btn = gr.Button("Clear chat", elem_id="bc_clear")

    msg.submit(
        respond,
        inputs=[msg, chatbot, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[msg, chatbot, quiz_state],
    )
    send_btn.click(
        respond,
        inputs=[msg, chatbot, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[msg, chatbot, quiz_state],
    )
    clear_btn.click(
        clear_all,
        inputs=None,
        outputs=[msg, chatbot, quiz_state],
        queue=False,
    )

if __name__ == "__main__":
    # FIX: Gradio 4+ blocks arbitrary local-file requests; the logo served
    # via /gradio_api/file=... must be whitelisted in allowed_paths.
    _logo = find_logo_file()
    demo.launch(allowed_paths=[_logo] if _logo else None)