import os
import re
import json
import pickle
from urllib.parse import quote
import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# =====================================================
# PATHS
# =====================================================
# Artifacts produced by the offline indexing/build step.
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")  # pickled chunk records (dicts with book/section/pages/text)
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")  # pre-tokenized chunks for the BM25 index
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")  # dense chunk embeddings, row-aligned with CHUNKS
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")  # build config (holds "embedding_model" name)
LOGO_FILE = "Brain chat-09.png"  # UI logo asset
# =====================================================
# GLOBALS
# =====================================================
# Lazily populated by ensure_loaded() on first use; None means "not loaded yet".
EMBED_MODEL = None  # SentenceTransformer used to embed queries
BM25 = None  # BM25Okapi index over tokenized chunks
CHUNKS = None  # chunk records, indexable by integer position
EMBEDDINGS = None  # numpy array of chunk embeddings
CLIENT = None  # OpenAI API client
# =====================================================
# LOADERS
# =====================================================
def tokenize(text: str):
    """Split *text* into lowercase word tokens (Unicode-aware)."""
    word_pattern = re.compile(r"\w+", re.UNICODE)
    return word_pattern.findall(text.lower())
def ensure_loaded():
    """Lazily initialize the retrieval artifacts and the OpenAI client.

    Populates the module-level globals on first call; later calls are no-ops.

    Raises:
        FileNotFoundError: if any build artifact file is absent.
        ValueError: if the OPENAI_API_KEY environment variable is not set.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, CLIENT
    if CHUNKS is None:
        for required in (CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH):
            if not os.path.exists(required):
                raise FileNotFoundError(f"Missing file: {required}")
        # NOTE(review): pickle.load is only safe because these artifacts are
        # produced by our own build step, not user-supplied — confirm.
        with open(CHUNKS_PATH, "rb") as handle:
            CHUNKS = pickle.load(handle)
        with open(TOKENS_PATH, "rb") as handle:
            tokenized = pickle.load(handle)
        EMBEDDINGS = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as handle:
            config = json.load(handle)
        BM25 = BM25Okapi(tokenized)
        EMBED_MODEL = SentenceTransformer(config["embedding_model"])
    if CLIENT is None:
        key = os.getenv("OPENAI_API_KEY")
        if not key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        CLIENT = OpenAI(api_key=key)
# =====================================================
# RETRIEVAL
# =====================================================
def search_hybrid(query: str, shortlist_k: int = 20, final_k: int = 3):
    """Two-stage retrieval: BM25 shortlist, then dense re-ranking.

    Args:
        query: free-text user query.
        shortlist_k: number of BM25 candidates passed to dense re-ranking.
        final_k: number of chunk records returned.

    Returns:
        The top *final_k* chunk records, best first.
    """
    ensure_loaded()
    sparse_scores = BM25.get_scores(tokenize(query))
    # Best BM25 hits first.
    candidates = np.argsort(sparse_scores)[::-1][:shortlist_k]
    query_vec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    # Cosine similarity via dot product (embeddings are normalized).
    similarity = EMBEDDINGS[candidates] @ query_vec
    winners = candidates[np.argsort(similarity)[::-1][:final_k]]
    return [CHUNKS[int(idx)] for idx in winners]
def build_context(records):
    """Format retrieved chunk records into numbered context blocks for the LLM."""

    def render(number, rec):
        # One citation block per record; layout must stay stable for the prompt.
        return (
            f"[Source {number}]\n"
            f"Book: {rec['book']}\n"
            f"Section: {rec['section_title']}\n"
            f"Pages: {rec['page_start']}-{rec['page_end']}\n"
            "Text:\n"
            f"{rec['text']}"
        )

    return "\n\n".join(render(n, rec) for n, rec in enumerate(records, start=1))
def make_sources(records):
    """Build a deduplicated bullet list of citations from chunk records."""
    bullets = []
    seen_keys = set()
    for rec in records:
        key = (rec["book"], rec["section_title"], rec["page_start"], rec["page_end"])
        if key not in seen_keys:
            seen_keys.add(key)
            bullets.append(
                f"• {rec['book']} | {rec['section_title']} | pp. {rec['page_start']}-{rec['page_end']}"
            )
    return "\n".join(bullets)
# =====================================================
# PROMPTS
# =====================================================
def language_instruction(language_mode: str) -> str:
    """Return the language directive embedded in prompts for *language_mode*.

    Any unrecognized mode falls back to mirror-the-user behavior.
    """
    directives = {
        "English": "Answer only in English.",
        "Spanish": "Answer only in Spanish.",
        "Bilingual": "Answer first in English, then provide a Spanish version under the heading 'Español:'.",
    }
    fallback = (
        "If the user's message is in Spanish, answer in Spanish. "
        "If the user's message is in English, answer in English."
    )
    return directives.get(language_mode, fallback)
def choose_quiz_count(user_text: str, selector: str) -> int:
    """Pick a quiz length: an explicit selector wins, else infer from *user_text*."""
    if selector in ("3", "5", "7"):
        return int(selector)
    lowered = user_text.lower()
    # Exam-style requests get the longest quiz; study/revision a medium one.
    exam_phrases = ("mock test", "final exam", "exam practice", "full test")
    study_phrases = ("detailed", "revision", "comprehensive", "study")
    if any(phrase in lowered for phrase in exam_phrases):
        return 7
    if any(phrase in lowered for phrase in study_phrases):
        return 5
    return 3
def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Assemble the tutoring prompt for the given *mode* and *language_mode*.

    Raises KeyError if *mode* is not a known teaching style.
    """
    teaching_styles = {
        "Explain": (
            "Explain clearly like a friendly tutor using simple language. "
            "Use short headings if useful."
        ),
        "Detailed": (
            "Give a fuller and more detailed explanation. Include concept, key points, and clinical relevance when supported by context."
        ),
        "Short Notes": (
            "Answer in concise revision-note format using short bullet points."
        ),
        "Flashcards": (
            "Create 6 flashcards in Q/A format using only the provided context."
        ),
        "Case-Based": (
            "Create a short clinical scenario and explain it clearly using the provided context."
        ),
    }
    style = teaching_styles[mode]
    lang_rule = language_instruction(language_mode)
    # Template lines are left-flush on purpose: they are the literal prompt text.
    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.
Rules:
- Use only the provided context from the books.
- If the answer is not supported by the context, say exactly:
Not found in the course material.
- Be accurate and student-friendly.
- Do not invent facts outside the context.
- {lang_rule}
Teaching style:
{style}
Context:
{context}
Question:
{question}
""".strip()
def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Build the prompt asking the model to emit a JSON quiz on *topic*.

    The doubled braces in the template are literal JSON braces in the f-string.
    """
    lang_rule = language_instruction(language_mode)
    return f"""
You are BrainChat, an interactive tutor.
Rules:
- Use only the provided context.
- Create exactly {n_questions} quiz questions.
- Questions should be short and clear.
- Also create a short answer key.
- Return valid JSON only.
- {lang_rule}
Required JSON format:
{{
"title": "short quiz title",
"questions": [
{{"q": "question 1", "answer_key": "expected short answer"}},
{{"q": "question 2", "answer_key": "expected short answer"}}
]
}}
Context:
{context}
Topic:
{topic}
""".strip()
def build_quiz_evaluation_prompt(language_mode: str, quiz_data: dict, user_answers: str) -> str:
    """Build the prompt asking the model to grade *user_answers* against *quiz_data*.

    *quiz_data* is serialized to JSON (non-ASCII preserved) and embedded verbatim.
    """
    lang_rule = language_instruction(language_mode)
    serialized_quiz = json.dumps(quiz_data, ensure_ascii=False)
    return f"""
You are BrainChat, an interactive tutor.
Evaluate the student's answers fairly using the quiz answer key.
Give:
- total score
- per-question feedback
- one short improvement suggestion
Rules:
- Accept semantically correct answers even if wording differs.
- Return valid JSON only.
- {lang_rule}
Required JSON format:
{{
"score_obtained": 0,
"score_total": 0,
"summary": "short overall feedback",
"results": [
{{
"question": "question text",
"student_answer": "student answer",
"result": "Correct / Partially Correct / Incorrect",
"feedback": "short explanation"
}}
]
}}
Quiz data:
{serialized_quiz}
Student answers:
{user_answers}
""".strip()
# =====================================================
# OPENAI HELPERS
# =====================================================
def chat_text(prompt: str) -> str:
    """Send *prompt* to gpt-4o-mini and return the stripped text reply.

    Requires ensure_loaded() to have initialized CLIENT.
    """
    conversation = [
        {"role": "system", "content": "You are a helpful educational assistant."},
        {"role": "user", "content": prompt},
    ]
    response = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        messages=conversation,
    )
    return response.choices[0].message.content.strip()
def chat_json(prompt: str) -> dict:
    """Send *prompt* to gpt-4o-mini in JSON mode and return the parsed object.

    Raises json.JSONDecodeError if the model reply is not valid JSON.
    Requires ensure_loaded() to have initialized CLIENT.
    """
    conversation = [
        {"role": "system", "content": "Return only valid JSON."},
        {"role": "user", "content": prompt},
    ]
    response = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=conversation,
    )
    return json.loads(response.choices[0].message.content)
# =====================================================
# HTML RENDERING
# =====================================================
def md_to_html(text: str) -> str:
    """Convert a minimal Markdown subset to HTML-safe markup.

    Escapes HTML metacharacters, then turns ``**bold**`` spans into <b> tags
    and newlines into <br> tags.

    Bug fixed: the escape chain was a series of identity no-op replaces
    ("&" -> "&", "<" -> "<", ">" -> ">" — entity-decoding corruption of
    &amp;/&lt;/&gt;), the bold regex substituted bare ``\\1`` (stripping the
    markers without emitting a tag), and the newline replace swapped "\\n"
    for a literal newline. The function produced no HTML at all.
    """
    safe = (
        text.replace("&", "&amp;")  # must run first so later entities aren't double-escaped
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    # Translate bold AFTER escaping so user-supplied tags stay inert.
    safe = re.sub(r"\*\*(.+?)\*\*", r"<b>\1</b>", safe)
    return safe.replace("\n", "<br>")
def render_chat(history):
if not history:
return """