import os
import re
import html
import json
import pickle
from urllib.parse import quote
import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# Directory holding the prebuilt retrieval artifacts, and the files inside it.
BUILD_DIR = "brainchat_build"
CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH = (
    os.path.join(BUILD_DIR, artifact)
    for artifact in (
        "chunks.pkl",
        "tokenized_chunks.pkl",
        "embeddings.npy",
        "config.json",
    )
)

# Lazily initialised singletons; they stay None until ensure_loaded() runs.
EMBED_MODEL = BM25 = CHUNKS = EMBEDDINGS = OAI = None
def tokenize(text: str):
    """Lowercase *text* and return its runs of word characters as a list."""
    word_pattern = re.compile(r"\w+", re.UNICODE)
    lowered = text.lower()
    return word_pattern.findall(lowered)
def ensure_loaded():
    """Lazily load the retrieval artifacts and create the OpenAI client.

    On the first successful call this reads the chunk records, the tokenized
    chunks, the embedding matrix and the build config, builds the BM25 index
    and loads the sentence-embedding model named by the config's
    ``"embedding_model"`` key. Once everything is set, later calls do nothing.

    The module-level globals are published only after every artifact has
    loaded. A failure part-way through therefore leaves them all ``None`` and
    the next call retries from scratch instead of running half-initialised.

    Raises:
        FileNotFoundError: If any of the four build artifacts is missing.
        ValueError: If the ``OPENAI_API_KEY`` environment variable is unset
            or empty.

    Errors raised while unpickling, parsing the config or loading the model
    are not caught here and propagate to the caller.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, OAI

    if CHUNKS is None:
        # Check every artifact up front so the error names the missing file.
        required_artifacts = (
            (CHUNKS_PATH, "Missing brainchat_build/chunks.pkl"),
            (TOKENS_PATH, "Missing brainchat_build/tokenized_chunks.pkl"),
            (EMBED_PATH, "Missing brainchat_build/embeddings.npy"),
            (CONFIG_PATH, "Missing brainchat_build/config.json"),
        )
        for artifact_path, missing_message in required_artifacts:
            if not os.path.exists(artifact_path):
                raise FileNotFoundError(missing_message)

        # NOTE(security): pickle.load executes arbitrary code from the file.
        # This is only acceptable while the build directory contains files
        # produced by our own build step; never point it at untrusted data.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)

        embeddings = np.load(EMBED_PATH)

        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish only after every step succeeded. CHUNKS goes last because it
        # is the "already loaded" sentinel checked at the top of this branch.
        EMBED_MODEL = embed_model
        BM25 = bm25
        EMBEDDINGS = embeddings
        CHUNKS = chunks

    if OAI is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        OAI = OpenAI(api_key=api_key)
def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Retrieve chunks for *query*: BM25 shortlist, then embedding rerank.

    The BM25 index scores every chunk against the tokenized query and the
    ``shortlist_k`` highest-scoring chunks are kept. Those candidates are
    re-scored by the dot product between their stored embeddings and the
    normalized query embedding, and the ``final_k`` best are returned.

    Returns:
        A list of entries from ``CHUNKS``, highest dense score first.
    """
    ensure_loaded()

    # Stage 1: lexical shortlist.
    lexical_scores = BM25.get_scores(tokenize(query))
    candidates = np.argsort(lexical_scores)[::-1][:shortlist_k]

    # Stage 2: dense rerank restricted to the shortlisted rows.
    encoded = EMBED_MODEL.encode([query], normalize_embeddings=True)
    query_vec = encoded.astype("float32")[0]
    similarity = EMBEDDINGS[candidates] @ query_vec
    best_first = np.argsort(similarity)[::-1][:final_k]

    return [CHUNKS[int(chunk_id)] for chunk_id in candidates[best_first]]
def build_context(records):
    """Render retrieved records as numbered source blocks for the prompt.

    Each record becomes a block headed ``[Source N]`` (numbered from 1) that
    lists its book, section title, page range and text. Blocks are separated
    by one blank line; an empty ``records`` yields an empty string.
    """
    template = (
        "[Source {number}]\n"
        "Book: {book}\n"
        "Section: {section}\n"
        "Pages: {first_page}-{last_page}\n"
        "Text:\n"
        "{text}"
    )
    return "\n\n".join(
        template.format(
            number=number,
            book=record["book"],
            section=record["section_title"],
            first_page=record["page_start"],
            last_page=record["page_end"],
            text=record["text"],
        )
        for number, record in enumerate(records, start=1)
    )
def make_sources(records):
    """Return a bullet list of the distinct sources among *records*.

    Records sharing the same book, section title and page range are listed
    once, at the position of their first occurrence. Lines are newline-joined.
    """
    # dict.fromkeys keeps first-seen order, giving an ordered de-duplication.
    distinct = dict.fromkeys(
        (rec["book"], rec["section_title"], rec["page_start"], rec["page_end"])
        for rec in records
    )
    return "\n".join(
        f"- {book} | {section} | pp. {first_page}-{last_page}"
        for book, section, first_page, last_page in distinct
    )
# System instructions sent with every chat completion request.
_SYSTEM_PROMPT = (
    "You are BrainChat, a neurology and neuroanatomy tutor.\n"
    "Rules:\n"
    "- Answer only from the provided context.\n"
    "- If the answer is not supported by the context, say exactly:\n"
    "Not found in the course material.\n"
    "- Keep the answer clear and concise unless the user asks for more detail.\n"
    "- If the question is in Spanish, answer in Spanish.\n"
    "- If the question is in English, answer in English.\n"
)


def answer_question(message: str, history, show_sources: bool):
    """Answer a user question from retrieved course material.

    Retrieves the top chunks for *message*, sends them as context to the chat
    model together with the question, and returns the model's reply.

    Args:
        message: The user's question. Empty or whitespace-only input is
            rejected with a prompt to type a question.
        history: Prior conversation; accepted for the caller's signature but
            not used by this function.
        show_sources: When true, append a "Sources used" list of the distinct
            retrieved sources to the answer.

    Returns:
        The answer text. Any exception raised during retrieval or the API
        call is reported as an ``"Error: ..."`` string rather than raised.
    """
    if not message or not message.strip():
        return "Please type a question."

    try:
        records = search_hybrid(message, shortlist_k=30, final_k=5)
        context = build_context(records)
        user_prompt = f"Context:\n{context}\nQuestion:\n{message}\n"

        resp = OAI.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0.2,
            messages=[
                {"role": "system", "content": _SYSTEM_PROMPT},
                {"role": "user", "content": user_prompt},
            ],
        )

        # The message content is optional in the API response, so guard
        # against None instead of failing with an opaque AttributeError.
        content = resp.choices[0].message.content
        answer = (content or "").strip()
        if not answer:
            answer = "The model returned an empty answer. Please try again."

        if show_sources:
            sources = make_sources(records)
            # Skip the footer entirely when there is nothing to list.
            if sources:
                answer += "\n\n---\nSources used:\n" + sources
        return answer
    except Exception as e:
        # UI boundary: report the failure in the chat instead of crashing it.
        return f"Error: {str(e)}"
def detect_logo_url():
    """Return the Gradio file URL of the first logo file that exists.

    The known logo filenames are checked in priority order, relative to the
    current working directory. The first existing one is URL-quoted and
    returned behind the ``/gradio_api/file=`` prefix; ``None`` if none exist.
    """
    logo_names = (
        "Brain chat-09.png",
        "brainchat_logo.png",
        "Brain Chat Imagen.svg",
    )
    found = next((name for name in logo_names if os.path.exists(name)), None)
    if found is None:
        return None
    return f"/gradio_api/file={quote(found)}"
def top_html():
logo_url = detect_logo_url()
if logo_url:
logo = f''
else:
logo = '