import os
import re
import json
import html
import pickle
from urllib.parse import quote
import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# ---------------------------------------------------
# Paths
# ---------------------------------------------------
# Directory holding the prebuilt retrieval artifacts read by ensure_loaded().
BUILD_DIR = "brainchat_build"

# Artifact locations, each resolved relative to BUILD_DIR.
CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH = (
    os.path.join(BUILD_DIR, artifact)
    for artifact in (
        "chunks.pkl",
        "tokenized_chunks.pkl",
        "embeddings.npy",
        "config.json",
    )
)

# Module-level singletons; they stay None until ensure_loaded() populates them.
EMBED_MODEL = BM25 = CHUNKS = EMBEDDINGS = OAI = None
# ---------------------------------------------------
# Load resources once
# ---------------------------------------------------
def tokenize(text: str):
    """Return the lowercase word tokens of ``text``.

    The text is lowercased first; every maximal run of word characters
    (letters, digits, underscore, Unicode-aware) then becomes one token.
    Punctuation and whitespace act as separators and are dropped.
    """
    lowered = text.lower()
    return [hit.group(0) for hit in re.finditer(r"\w+", lowered, flags=re.UNICODE)]
def ensure_loaded():
    """Load the retrieval artifacts and the OpenAI client on first use.

    On the first call this reads the chunk records, the tokenized chunks,
    the embedding matrix and the JSON config from the build directory, builds
    the BM25 index and instantiates the sentence-transformer named by
    ``cfg["embedding_model"]``. Afterwards it creates the OpenAI client from
    the ``OPENAI_API_KEY`` environment variable. Later calls are no-ops for
    whatever has already been initialised.

    The retrieval globals are assigned only after every loading step has
    succeeded, so a failure part-way through (corrupt file, model download
    error, ...) leaves them all ``None`` and the next call retries from
    scratch instead of running with a half-initialised state.

    Raises:
        FileNotFoundError: if any of the build artifacts is missing.
        ValueError: if ``OPENAI_API_KEY`` is not set.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, OAI

    if CHUNKS is None or BM25 is None or EMBEDDINGS is None or EMBED_MODEL is None:
        required = (CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH)
        missing = [path for path in required if not os.path.exists(path)]
        if missing:
            raise FileNotFoundError(
                "Missing build files:\n" + "\n".join(missing)
            )

        # NOTE(security): pickle.load executes arbitrary code from the file;
        # these artifacts must only ever come from a trusted build.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        embeddings = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish everything together, and only once nothing above has raised.
        CHUNKS, EMBEDDINGS, BM25, EMBED_MODEL = chunks, embeddings, bm25, embed_model

    if OAI is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        OAI = OpenAI(api_key=api_key)
# ---------------------------------------------------
# Hybrid retrieval
# ---------------------------------------------------
def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Retrieve chunk records for ``query`` with a two-stage hybrid search.

    Stage 1 scores every chunk with BM25 over the tokenized query and keeps
    the ``shortlist_k`` highest-scoring indices. Stage 2 encodes the query
    with the embedding model (normalized), takes the dot product against the
    shortlisted rows of ``EMBEDDINGS`` and keeps the ``final_k`` best.

    Returns:
        A list of entries from ``CHUNKS``, best dense score first.
    """
    ensure_loaded()

    # Stage 1: lexical shortlist, highest BM25 score first.
    lexical_scores = BM25.get_scores(tokenize(query))
    by_lexical_desc = np.argsort(lexical_scores)[::-1]
    shortlist_idx = by_lexical_desc[:shortlist_k]
    candidate_vectors = EMBEDDINGS[shortlist_idx]

    # Stage 2: dense rerank of the shortlist.
    # NOTE(review): the dot product equals cosine similarity only if the stored
    # embeddings were normalized at build time - confirm against the build script.
    encoded_query = EMBED_MODEL.encode([query], normalize_embeddings=True)
    qvec = encoded_query.astype("float32")[0]
    dense_scores = np.matmul(candidate_vectors, qvec)
    by_dense_desc = np.argsort(dense_scores)[::-1]
    winners = shortlist_idx[by_dense_desc[:final_k]]

    return [CHUNKS[int(position)] for position in winners]
def build_context(records):
    """Render retrieved records as numbered source blocks for the prompt.

    Each record must provide the keys ``book``, ``section_title``,
    ``page_start``, ``page_end`` and ``text``. Blocks are numbered from 1 in
    input order and separated by a blank line.

    Returns:
        The concatenated blocks, or an empty string when ``records`` is empty.
    """

    def render(position, record):
        # Read the fields in the same order the block prints them.
        book = record["book"]
        section = record["section_title"]
        first_page = record["page_start"]
        last_page = record["page_end"]
        body = record["text"]
        return "\n".join(
            (
                f"[Source {position}]",
                f"Book: {book}",
                f"Section: {section}",
                f"Pages: {first_page}-{last_page}",
                "Text:",
                f"{body}",
            )
        )

    return "\n\n".join(
        render(position, record) for position, record in enumerate(records, start=1)
    )
def make_sources(records):
    """Build a de-duplicated bullet list of the sources behind ``records``.

    Two records count as the same source when their book, section title and
    page range all match; only the first occurrence is listed, in input order.

    Returns:
        One ``- book | section | pp. start-end`` line per distinct source,
        joined with newlines (empty string for no records).
    """
    # dict preserves insertion order, so it doubles as an ordered "seen" set.
    citations = {}
    for rec in records:
        book, section = rec["book"], rec["section_title"]
        first_page, last_page = rec["page_start"], rec["page_end"]
        identity = (book, section, first_page, last_page)
        if identity not in citations:
            citations[identity] = f"- {book} | {section} | pp. {first_page}-{last_page}"
    return "\n".join(citations.values())
# ---------------------------------------------------
# Prompt helpers
# ---------------------------------------------------
def build_system_prompt(mode: str, language_mode: str) -> str:
    """Compose the system prompt for the chosen teaching and language modes.

    Args:
        mode: Teaching style; one of "Explain", "Detailed", "Short Notes",
            "Quiz Me", "Flashcards" or "Case-Based".
        language_mode: Output language policy; one of "Auto", "English",
            "Spanish" or "Bilingual".

    Returns:
        The prompt text with surrounding whitespace stripped.

    Raises:
        KeyError: if ``mode`` or ``language_mode`` is not one of the values above.
    """
    mode_map = {
        "Explain": (
            "Explain the answer clearly like a supportive tutor. "
            "Use short headings if helpful. Keep it easy to understand."
        ),
        "Detailed": (
            "Give a fuller, more detailed explanation like a tutor teaching a serious student. "
            "Include concept, key points, and clinical relevance when supported by context."
        ),
        "Short Notes": (
            "Answer in concise revision-note format. "
            "Use short bullet points with only the most important facts."
        ),
        "Quiz Me": (
            "Do not immediately give the full answer. "
            "First ask 3 short quiz questions based on the topic. "
            "Then give a brief correct-answer summary."
        ),
        "Flashcards": (
            "Create 6 short flashcards in Q/A format using only the provided context."
        ),
        "Case-Based": (
            "Create a short case-based explanation or clinical vignette, then explain the answer clearly."
        ),
    }
    language_map = {
        "Auto": (
            "If the user's question is in Spanish, answer in Spanish. "
            "If the user's question is in English, answer in English."
        ),
        "English": "Answer only in English.",
        "Spanish": "Answer only in Spanish.",
        "Bilingual": (
            # Heading was previously mis-encoded as 'EspaƱol:'; the model would
            # have been asked to reproduce the garbled text verbatim.
            "Answer first in English, then provide a Spanish version under a heading 'Español:'."
        ),
    }
    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.
Rules:
- Use only the provided context from the books.
- If the answer is not supported by the context, say exactly:
Not found in the course material.
- Be accurate, calm, and student-friendly.
- Do not invent facts outside the provided context.
- If sources are weak or incomplete, be honest.
Teaching mode:
{mode_map[mode]}
Language behavior:
{language_map[language_mode]}
""".strip()
# ---------------------------------------------------
# Main answer function
# ---------------------------------------------------
def answer_question(message: str, history, mode: str, language_mode: str, show_sources: bool):
    """Answer a user question from the retrieved course material.

    Retrieves the top chunks for ``message``, builds the system and user
    prompts, asks the chat model, and optionally appends the list of sources.

    Args:
        message: The user's question.
        history: Chat history supplied by the caller; not used here.
        mode: Teaching mode key passed to ``build_system_prompt``.
        language_mode: Language mode key passed to ``build_system_prompt``.
        show_sources: When true, append a "Sources used" section.

    Returns:
        The answer text, a prompt to type a question when ``message`` is
        blank, or an ``"Error: ..."`` string if anything fails. This function
        never raises; it is the boundary between the pipeline and the UI.
    """
    if not message or not message.strip():
        return "Please type a question."
    try:
        records = search_hybrid(message, shortlist_k=30, final_k=5)
        context = build_context(records)
        system_prompt = build_system_prompt(mode, language_mode)
        user_prompt = f"""Context:
{context}
Question:
{message}
"""
        resp = OAI.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0.2,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt},
            ],
        )
        # The API may return content=None (e.g. on a refusal); calling .strip()
        # on it would surface as an opaque "'NoneType' object ..." error.
        answer = (resp.choices[0].message.content or "").strip()
        if not answer:
            return "The model returned an empty response. Please try again."
        if show_sources:
            answer += "\n\n---\nSources used:\n" + make_sources(records)
        return answer
    except Exception as e:
        # Deliberately broad: report the failure in the chat instead of crashing the UI.
        return f"Error: {e}"
# ---------------------------------------------------
# UI helpers
# ---------------------------------------------------
def detect_logo_url():
    """Return the served URL of the first logo file found, or ``None``.

    The candidate filenames are checked in order, relative to the current
    working directory. The first one that exists is URL-quoted and returned
    under the ``/gradio_api/file=`` route.
    """
    logo_files = (
        "Brain chat-09.png",
        "brainchat_logo.png",
        "Brain Chat Imagen.svg",
    )
    found = next((name for name in logo_files if os.path.exists(name)), None)
    if found is None:
        return None
    return f"/gradio_api/file={quote(found)}"
def header_html():
logo_url = detect_logo_url()
if logo_url:
logo = f''
else:
logo = '