import os
import re
import json
import html
import pickle
import logging
from urllib.parse import quote

import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI

logger = logging.getLogger(__name__)

# ---------------------------------------------------
# Paths
# ---------------------------------------------------
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")

# Chat model used to generate answers.
CHAT_MODEL = "gpt-4o-mini"

# Exact sentence the system prompt asks the model to emit when the retrieved
# context does not support an answer.
NOT_FOUND_MESSAGE = "Not found in the course material."

# Logo files looked up (in order) relative to the working directory.
LOGO_CANDIDATES = (
    "Brain chat-09.png",
    "brainchat_logo.png",
    "Brain Chat Imagen.svg",
)

# Lazily initialised by ensure_loaded(); all stay None until first use.
EMBED_MODEL = None
BM25 = None
CHUNKS = None
EMBEDDINGS = None
OAI = None

# NOTE(review): presumably the build step tokenized chunks with this same
# pattern; keep the two in sync or BM25 query tokens won't match the index.
_TOKEN_RE = re.compile(r"\w+", flags=re.UNICODE)


# ---------------------------------------------------
# Load resources once
# ---------------------------------------------------
def tokenize(text: str):
    """Lower-case *text* and split it into word tokens for BM25."""
    return _TOKEN_RE.findall(text.lower())


def ensure_loaded():
    """Load the retrieval artifacts and the OpenAI client on first use.

    Reads chunks, tokenized chunks, embeddings and config from BUILD_DIR,
    builds the BM25 index, loads the sentence-transformer named in the
    config, and creates the OpenAI client from OPENAI_API_KEY.

    Raises:
        FileNotFoundError: if any build artifact is missing.
        ValueError: if the artifacts are empty or disagree in length, or if
            OPENAI_API_KEY is not set.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, OAI

    if CHUNKS is None:
        missing = [
            path
            for path in (CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH)
            if not os.path.exists(path)
        ]
        if missing:
            raise FileNotFoundError("Missing build files:\n" + "\n".join(missing))

        # NOTE(security): pickle.load can execute arbitrary code. These files
        # must only ever come from our own build step, never from user uploads.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        embeddings = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        if not chunks:
            raise ValueError(f"{CHUNKS_PATH} contains no chunks.")
        # Retrieval indexes all three artifacts with the same row number.
        if not (len(chunks) == len(tokenized_chunks) == len(embeddings)):
            raise ValueError(
                "Build files are inconsistent: "
                f"{len(chunks)} chunks, {len(tokenized_chunks)} tokenized chunks, "
                f"{len(embeddings)} embeddings."
            )

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish only once everything succeeded. CHUNKS doubles as the
        # "already loaded" flag, so setting it earlier would leave the other
        # globals None forever if a later step raised.
        BM25 = bm25
        EMBEDDINGS = embeddings
        EMBED_MODEL = embed_model
        CHUNKS = chunks

    if OAI is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        OAI = OpenAI(api_key=api_key)


# ---------------------------------------------------
# Hybrid retrieval
# ---------------------------------------------------
def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Return the *final_k* chunk records most relevant to *query*.

    Stage 1 shortlists the top *shortlist_k* chunks by BM25 keyword score.
    Stage 2 re-ranks that shortlist by the dot product between each stored
    chunk embedding and the normalized query embedding.

    If no BM25 score is positive (typically because no query token occurs in
    the corpus, e.g. a question phrased in another language than the books),
    the keyword shortlist would be arbitrary. In that case every chunk is
    re-ranked densely instead.

    Returns:
        List of records from chunks.pkl, best match first.
    """
    ensure_loaded()

    bm25_scores = np.asarray(BM25.get_scores(tokenize(query)))
    if np.any(bm25_scores > 0):
        shortlist_idx = np.argsort(bm25_scores)[::-1][:shortlist_k]
    else:
        shortlist_idx = np.arange(len(CHUNKS))

    qvec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    dense_scores = EMBEDDINGS[shortlist_idx] @ qvec
    rerank_order = np.argsort(dense_scores)[::-1][:final_k]
    return [CHUNKS[int(i)] for i in shortlist_idx[rerank_order]]


def build_context(records):
    """Format retrieved records as numbered source blocks for the LLM prompt."""
    blocks = []
    for i, r in enumerate(records, start=1):
        blocks.append(
            f"[Source {i}]\n"
            f"Book: {r['book']}\n"
            f"Section: {r['section_title']}\n"
            f"Pages: {r['page_start']}-{r['page_end']}\n"
            f"Text: {r['text']}"
        )
    return "\n\n".join(blocks)


def make_sources(records):
    """Return a bullet list of unique (book, section, page range) citations."""
    seen = set()
    lines = []
    for r in records:
        key = (r["book"], r["section_title"], r["page_start"], r["page_end"])
        if key in seen:
            continue
        seen.add(key)
        lines.append(
            f"- {r['book']} | {r['section_title']} | pp. {r['page_start']}-{r['page_end']}"
        )
    return "\n".join(lines)


# ---------------------------------------------------
# Prompt helpers
# ---------------------------------------------------
# Insertion order is also the order shown in the "Tutor Mode" dropdown.
MODE_INSTRUCTIONS = {
    "Explain": (
        "Explain the answer clearly like a supportive tutor. "
        "Use short headings if helpful. Keep it easy to understand."
    ),
    "Detailed": (
        "Give a fuller, more detailed explanation like a tutor teaching a serious student. "
        "Include concept, key points, and clinical relevance when supported by context."
    ),
    "Short Notes": (
        "Answer in concise revision-note format. "
        "Use short bullet points with only the most important facts."
    ),
    "Quiz Me": (
        "Do not immediately give the full answer. "
        "First ask 3 short quiz questions based on the topic. "
        "Then give a brief correct-answer summary."
    ),
    "Flashcards": (
        "Create 6 short flashcards in Q/A format using only the provided context."
    ),
    "Case-Based": (
        "Create a short case-based explanation or clinical vignette, then explain the answer clearly."
    ),
}

# Insertion order is also the order shown in the "Answer Language" dropdown.
LANGUAGE_INSTRUCTIONS = {
    "Auto": (
        "If the user's question is in Spanish, answer in Spanish. "
        "If the user's question is in English, answer in English."
    ),
    "English": "Answer only in English.",
    "Spanish": "Answer only in Spanish.",
    "Bilingual": (
        "Answer first in English, then provide a Spanish version under a heading 'Español:'."
    ),
}

DEFAULT_MODE = "Explain"
DEFAULT_LANGUAGE = "Auto"


def build_system_prompt(mode: str, language_mode: str) -> str:
    """Build the tutor system prompt for a teaching mode and language option.

    Unknown or missing values (e.g. None when a UI example leaves a dropdown
    unset) fall back to DEFAULT_MODE / DEFAULT_LANGUAGE instead of raising
    KeyError.
    """
    mode_text = MODE_INSTRUCTIONS.get(mode, MODE_INSTRUCTIONS[DEFAULT_MODE])
    language_text = LANGUAGE_INSTRUCTIONS.get(
        language_mode, LANGUAGE_INSTRUCTIONS[DEFAULT_LANGUAGE]
    )
    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.

Rules:
- Use only the provided context from the books.
- If the answer is not supported by the context, say exactly: {NOT_FOUND_MESSAGE}
- Be accurate, calm, and student-friendly.
- Do not invent facts outside the provided context.
- If sources are weak or incomplete, be honest.

Teaching mode:
{mode_text}

Language behavior:
{language_text}
""".strip()


# ---------------------------------------------------
# Main answer function
# ---------------------------------------------------
def answer_question(message: str, history, mode: str, language_mode: str, show_sources: bool):
    """Answer one chat message using retrieved book context.

    Args:
        message: The user's question.
        history: Chat history supplied by gr.ChatInterface (not used).
        mode: Key of MODE_INSTRUCTIONS.
        language_mode: Key of LANGUAGE_INSTRUCTIONS.
        show_sources: Whether to append the list of retrieved sources.

    Returns:
        The answer text. Failures are returned as an "Error: ..." string so
        the chat UI keeps working.
    """
    if not message or not message.strip():
        return "Please type a question."

    try:
        records = search_hybrid(message, shortlist_k=30, final_k=5)
        user_prompt = f"Context:\n{build_context(records)}\n\nQuestion:\n{message}\n"
        resp = OAI.chat.completions.create(
            model=CHAT_MODEL,
            temperature=0.2,
            messages=[
                {"role": "system", "content": build_system_prompt(mode, language_mode)},
                {"role": "user", "content": user_prompt},
            ],
        )
        # The API documents message.content as nullable, so guard before .strip().
        answer = (resp.choices[0].message.content or "").strip()
    except Exception as e:  # UI boundary: report instead of crashing the chat.
        logger.exception("BrainChat failed to answer a question")
        return f"Error: {e}"

    if not answer:
        return "Error: the model returned an empty response."

    # Don't label retrieved chunks as "used" when the model reported it found nothing.
    if show_sources and records and answer != NOT_FOUND_MESSAGE:
        answer += "\n\n---\nSources used:\n" + make_sources(records)
    return answer


# ---------------------------------------------------
# UI helpers
# ---------------------------------------------------
def find_logo_path():
    """Return the first file in LOGO_CANDIDATES that exists, or None."""
    return next((name for name in LOGO_CANDIDATES if os.path.exists(name)), None)


def detect_logo_url():
    """Return a Gradio file URL for the logo, or None if no logo file exists."""
    name = find_logo_path()
    if name is None:
        return None
    return f"/gradio_api/file={quote(name)}"


def header_html():
    """Return the HTML header: a logo (or a text badge) plus title and subtitle."""
    logo_url = detect_logo_url()
    if logo_url:
        logo = (
            f'<img src="{html.escape(logo_url, quote=True)}" alt="BrainChat logo" '
            'style="width:72px;height:72px;object-fit:contain;">'
        )
    else:
        # Text badge used when no logo file is present.
        logo = (
            '<div style="width:72px;height:72px;border-radius:50%;background:#111;'
            'color:#fff;display:flex;flex-direction:column;align-items:center;'
            'justify-content:center;font-weight:700;font-size:13px;line-height:1.1;">'
            "<span>BRAIN</span><span>CHAT</span></div>"
        )
    return f"""
<div style="display:flex;align-items:center;gap:16px;padding:12px 4px;">
  {logo}
  <div>
    <div style="font-size:28px;font-weight:700;">BrainChat</div>
    <div style="font-size:15px;color:#444;">Interactive neurology and neuroanatomy tutor built from your books</div>
  </div>
</div>
"""


CSS = """
body, .gradio-container {
    background: #dcdcdc !important;
}
footer {
    display: none !important;
}
"""

# ---------------------------------------------------
# App
# ---------------------------------------------------
with gr.Blocks(css=CSS) as demo:
    gr.HTML(header_html())

    with gr.Row():
        mode = gr.Dropdown(
            choices=list(MODE_INSTRUCTIONS),
            value=DEFAULT_MODE,
            label="Tutor Mode",
        )
        language_mode = gr.Dropdown(
            choices=list(LANGUAGE_INSTRUCTIONS),
            value=DEFAULT_LANGUAGE,
            label="Answer Language",
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")

    gr.ChatInterface(
        fn=answer_question,
        additional_inputs=[mode, language_mode, show_sources],
        title=None,
        description="Ask questions from all uploaded neurology and neuroanatomy books.",
        examples=[
            ["Explain the function of the cerebellum."],
            ["Give short notes on basal ganglia."],
            ["Quiz me on cranial nerves."],
            ["Create flashcards on hippocampus."],
            ["Explain multiple sclerosis in Spanish."],
        ],
        textbox=gr.Textbox(placeholder="Ask a question...", lines=1),
    )


if __name__ == "__main__":
    _logo_path = find_logo_path()
    # Gradio versions that use the /gradio_api/file= route serve only files
    # from explicitly allowed paths, so the logo must be allow-listed.
    demo.launch(allowed_paths=[os.path.abspath(_logo_path)] if _logo_path else None)