"""BrainChat: a Gradio tutor app answering neurology questions from prebuilt book chunks.

Retrieval is hybrid: BM25 shortlists candidate chunks lexically, then a
sentence-transformer embedding reranks that shortlist. Answers and quizzes are
generated with the OpenAI chat API from the retrieved context only.
"""
import json
import os
import pickle
import re
from urllib.parse import quote

import gradio as gr
import numpy as np
from openai import OpenAI
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer

BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")
LOGO_FILE = "Brain chat-09.png"

# Chat model used for both tutor answers and quiz generation.
CHAT_MODEL = "gpt-4o-mini"

# Lazily-initialised shared state, populated by ensure_loaded().
EMBED_MODEL = None
BM25 = None
CHUNKS = None
EMBEDDINGS = None
OAI = None

_TOKEN_RE = re.compile(r"\w+", flags=re.UNICODE)


def tokenize(text: str) -> list:
    """Lower-case *text* and split it into Unicode word tokens."""
    return _TOKEN_RE.findall(text.lower())


def ensure_loaded():
    """Load the retrieval artifacts and the OpenAI client on first use.

    Raises:
        FileNotFoundError: if any build artifact in BUILD_DIR is missing.
        ValueError: if the artifacts disagree in length, or OPENAI_API_KEY is unset.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, OAI

    if CHUNKS is None:
        for path in [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH]:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")

        # NOTE(security): pickle.load can execute arbitrary code. These files
        # must come from our own build step, never from user-supplied data.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        embeddings = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # search_hybrid indexes chunks, BM25 rows and embedding rows by the
        # same position, so all three artifacts must line up.
        if not (len(chunks) == len(tokenized_chunks) == len(embeddings)):
            raise ValueError(
                "Build artifacts are out of sync: "
                f"{len(chunks)} chunks, {len(tokenized_chunks)} token lists, "
                f"{len(embeddings)} embeddings."
            )

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish the globals only after every step succeeded. CHUNKS doubles
        # as the "already loaded" flag, so it is set last. Otherwise a failure
        # above (e.g. the embedding model download) would leave a
        # half-initialised state that later calls would never retry.
        BM25 = bm25
        EMBEDDINGS = embeddings
        EMBED_MODEL = embed_model
        CHUNKS = chunks

    if OAI is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        OAI = OpenAI(api_key=api_key)


def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Return up to *final_k* chunk records relevant to *query*.

    BM25 selects the top *shortlist_k* chunks lexically; that shortlist is then
    reordered by the dot product between each chunk's stored embedding and the
    normalised query embedding (highest first).
    """
    ensure_loaded()
    bm25_scores = BM25.get_scores(tokenize(query))
    shortlist_idx = np.argsort(bm25_scores)[::-1][:shortlist_k]

    qvec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    # This is cosine similarity only if the stored embeddings were normalised
    # at build time. Presumably true — verify against the build script.
    dense_scores = EMBEDDINGS[shortlist_idx] @ qvec
    rerank_order = np.argsort(dense_scores)[::-1][:final_k]
    return [CHUNKS[int(i)] for i in shortlist_idx[rerank_order]]


def build_context(records):
    """Format retrieved chunk records as numbered "[Source i]" blocks for the prompt."""
    blocks = []
    for i, r in enumerate(records, start=1):
        blocks.append(
            f"""[Source {i}]
Book: {r['book']}
Section: {r['section_title']}
Pages: {r['page_start']}-{r['page_end']}
Text:
{r['text']}"""
        )
    return "\n\n".join(blocks)


def make_sources(records):
    """Return a bullet list of unique (book, section, page range) citations, in order."""
    seen = set()
    lines = []
    for r in records:
        key = (r["book"], r["section_title"], r["page_start"], r["page_end"])
        if key in seen:
            continue
        seen.add(key)
        lines.append(
            f"• {r['book']} | {r['section_title']} | pp. {r['page_start']}-{r['page_end']}"
        )
    return "\n".join(lines)


def choose_quiz_count(user_text: str, selector: str) -> int:
    """Pick the number of quiz questions.

    An explicit "3"/"5"/"7" selector wins. Otherwise keywords in the user's
    text choose 7 (exam-style), 5 (revision-style) or the default 3.
    """
    if selector in {"3", "5", "7"}:
        return int(selector)
    t = user_text.lower()
    if any(k in t for k in ["mock test", "final exam", "exam practice", "full test"]):
        return 7
    if any(k in t for k in ["detailed", "revision", "comprehensive", "study"]):
        return 5
    return 3


def language_instruction(language_mode: str) -> str:
    """Return the prompt rule describing which language(s) to answer in."""
    if language_mode == "English":
        return "Answer only in English."
    if language_mode == "Spanish":
        return "Answer only in Spanish."
    if language_mode == "Bilingual":
        return "Answer first in English, then provide a Spanish version under the heading 'Español:'."
    # "Auto" (or anything else): mirror the user's language.
    return (
        "If the user's message is in Spanish, answer in Spanish. "
        "If the user's message is in English, answer in English."
    )


def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Build the grounded tutor prompt for a non-quiz *mode*.

    Unknown modes fall back to the "Explain" teaching style instead of raising.
    """
    mode_map = {
        "Explain": "Explain clearly like a friendly tutor using simple language.",
        "Detailed": "Give a fuller and more detailed explanation.",
        "Short Notes": "Answer in concise revision-note format using bullets.",
        "Flashcards": "Create 6 flashcards in Q/A format.",
        "Case-Based": "Create a short clinical scenario and explain it clearly.",
    }
    style = mode_map.get(mode, mode_map["Explain"])
    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.

Rules:
- Use only the provided context from the books.
- If the answer is not supported by the context, say exactly: Not found in the course material.
- Be accurate and student-friendly.
- {language_instruction(language_mode)}

Teaching style:
{style}

Context:
{context}

Question:
{question}
""".strip()


def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Build a prompt asking the model for *n_questions* quiz items as a JSON object."""
    return f"""
You are BrainChat, an interactive tutor.

Rules:
- Use only the provided context.
- Create exactly {n_questions} quiz questions.
- Questions should be short and clear.
- Also create a short answer key.
- Return valid JSON only.
- {language_instruction(language_mode)}

Required JSON format:
{{
  "title": "short quiz title",
  "questions": [
    {{"q": "question 1", "answer_key": "expected short answer"}},
    {{"q": "question 2", "answer_key": "expected short answer"}}
  ]
}}

Context:
{context}

Topic:
{topic}
""".strip()


def chat_text(prompt: str) -> str:
    """Send *prompt* to the chat model and return its stripped text reply.

    Returns "" when the API returns no message content (e.g. a refusal).
    """
    resp = OAI.chat.completions.create(
        model=CHAT_MODEL,
        temperature=0.2,
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    return (content or "").strip()


def chat_json(prompt: str) -> dict:
    """Send *prompt* in JSON mode and return the parsed object.

    Returns {} when the API returns no message content.

    Raises:
        json.JSONDecodeError: if the returned content is not valid JSON
            (e.g. a truncated reply).
    """
    resp = OAI.chat.completions.create(
        model=CHAT_MODEL,
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    if not content:
        return {}
    return json.loads(content)


def detect_logo_url():
    """Return a Gradio file URL for LOGO_FILE if it exists, else None."""
    if os.path.exists(LOGO_FILE):
        return f"/gradio_api/file={quote(LOGO_FILE)}"
    return None


def render_header():
    """Build the HTML banner (logo or text fallback, title, subtitle) styled by CSS."""
    logo_url = detect_logo_url()
    if logo_url:
        logo_html = f'<img src="{logo_url}" alt="BrainChat Logo" class="hero-logo">'
    else:
        # Text fallback when the logo image is not shipped next to the app.
        logo_html = '<div class="hero-logo-text">BRAIN<br>CHAT</div>'
    return f"""
<div class="hero-card">
  <div class="hero-inner">
    {logo_html}
    <div class="hero-title">BrainChat</div>
    <div class="hero-subtitle">Interactive neurology and neuroanatomy tutor based on your uploaded books</div>
  </div>
</div>
"""


def answer_question(message, history, mode, language_mode, quiz_count_mode, show_sources):
    """Gradio ChatInterface handler: retrieve context, then answer or generate a quiz.

    *history* is supplied by ChatInterface but not used; each turn is answered
    from retrieval alone.
    """
    if not message or not message.strip():
        return "Please type a topic or question."

    ensure_loaded()
    user_text = message.strip()
    records = search_hybrid(user_text, shortlist_k=30, final_k=5)
    context = build_context(records)

    if mode == "Quiz Me":
        n_questions = choose_quiz_count(user_text, quiz_count_mode)
        prompt = build_quiz_generation_prompt(language_mode, user_text, context, n_questions)
        try:
            quiz_data = chat_json(prompt)
        except json.JSONDecodeError:
            quiz_data = {}
        if not isinstance(quiz_data, dict):
            quiz_data = {}

        # The model's JSON is untrusted: keep only well-formed question entries
        # instead of crashing on a missing "questions" or "q" key.
        raw_questions = quiz_data.get("questions")
        if not isinstance(raw_questions, list):
            raw_questions = []
        questions = [
            q for q in raw_questions
            if isinstance(q, dict) and str(q.get("q", "")).strip()
        ]
        if not questions:
            return "Sorry, I could not generate a quiz for that topic. Please try rephrasing it."

        title = quiz_data.get("title") or "Quiz"
        lines = [f"**{title}**", "", f"**Total questions: {len(questions)}**", ""]
        for i, q in enumerate(questions, start=1):
            lines.append(f"**Q{i}.** {q['q']}")
            lines.append("")
        lines.extend([
            "Reply with your answers in one message, for example:",
            "1. ...",
            "2. ...",
            "",
            "This version generates quiz questions only. Evaluation can be added next.",
        ])
        if show_sources:
            lines.append("\n---\n**Topic sources used to create the quiz:**")
            lines.append(make_sources(records))
        return "\n".join(lines)

    prompt = build_tutor_prompt(mode, language_mode, user_text, context)
    answer = chat_text(prompt)
    if show_sources:
        answer += "\n\n---\n**Sources used:**\n" + make_sources(records)
    return answer


CSS = """
body, .gradio-container {
    background: #dcdcdc !important;
    font-family: Arial, Helvetica, sans-serif !important;
}
footer { display: none !important; }
.hero-card {
    max-width: 860px;
    margin: 18px auto 14px auto;
    border-radius: 28px;
    background: linear-gradient(180deg, #e8c7d4 0%, #a55ca2 48%, #2b0c46 100%);
    padding: 22px 22px 18px 22px;
}
.hero-inner { text-align: center; }
.hero-logo {
    display: block;
    max-width: 220px;
    height: auto;
    margin: 0 auto;
}
.hero-logo-text {
    color: white;
    font-size: 28px;
    font-weight: 800;
    line-height: 1.1;
    letter-spacing: 2px;
}
.hero-title {
    color: white;
    font-size: 34px;
    font-weight: 800;
    margin-top: 6px;
}
.hero-subtitle {
    color: white;
    opacity: 0.92;
    font-size: 16px;
    margin-top: 6px;
}
"""

with gr.Blocks(css=CSS) as demo:
    gr.HTML(render_header())

    with gr.Row():
        mode = gr.Dropdown(
            choices=["Explain", "Detailed", "Short Notes", "Quiz Me", "Flashcards", "Case-Based"],
            value="Explain",
            label="Tutor Mode",
        )
        language_mode = gr.Dropdown(
            choices=["Auto", "English", "Spanish", "Bilingual"],
            value="Auto",
            label="Answer Language",
        )

    with gr.Row():
        quiz_count_mode = gr.Dropdown(
            choices=["Auto", "3", "5", "7"],
            value="Auto",
            label="Quiz Questions",
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")

    gr.Markdown("""
**How to use**
- Choose a **Tutor Mode**
- Then type a topic or question
- For **Quiz Me**, type a topic such as: `cranial nerves`
- For **Flashcards**, type a topic such as: `hippocampus`
""")

    gr.ChatInterface(
        fn=answer_question,
        additional_inputs=[mode, language_mode, quiz_count_mode, show_sources],
        title=None,
        description=None,
        textbox=gr.Textbox(
            placeholder="Ask a question or type a topic...",
            lines=1,
        ),
    )


if __name__ == "__main__":
    # Gradio restricts which local files it serves. Listing the logo in
    # allowed_paths lets the URL built by detect_logo_url() resolve.
    demo.launch(allowed_paths=[LOGO_FILE] if os.path.exists(LOGO_FILE) else None)