# BrainChat / app3.py
# Hosting-page header captured along with the file (not part of the program):
#   Deevyankar's picture
#   Rename app.py to app3.py
#   07e9bdd verified
import os
import re
import json
import pickle
from urllib.parse import quote
import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# Directory containing the prebuilt retrieval artifacts read by ensure_loaded().
BUILD_DIR = "brainchat_build"
# Pickled chunk records; search_hybrid() returns entries of this collection.
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")
# Pickled pre-tokenized chunks, used to build the BM25 index.
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")
# NumPy array of chunk embeddings, indexed with the BM25 shortlist indices.
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")
# JSON config; ensure_loaded() reads its "embedding_model" key.
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")
# Logo image looked up relative to the working directory by detect_logo_url().
LOGO_FILE = "Brain chat-09.png"
# Lazily initialised shared state; everything below is populated by
# ensure_loaded() on first use and stays None until then.
EMBED_MODEL = None
BM25 = None
CHUNKS = None
EMBEDDINGS = None
OAI = None
def tokenize(text: str):
    """Lowercase ``text`` and return its word tokens as a list.

    A token is a maximal run of regex word characters (letters, digits,
    underscore); everything else only separates tokens.
    """
    lowered = text.lower()
    word_pattern = re.compile(r"\w+", flags=re.UNICODE)
    return word_pattern.findall(lowered)
def ensure_loaded():
    """Lazily initialise the retrieval state and the OpenAI client.

    On first use, loads the chunk records, the pre-tokenized chunks, the
    embedding matrix and the JSON config from ``BUILD_DIR``, builds the BM25
    index and the sentence-embedding model, and stores them in the module
    globals. Independently, creates the OpenAI client from the
    ``OPENAI_API_KEY`` environment variable if it does not exist yet.

    The retrieval globals are only published after *every* loading step has
    succeeded, so a failure part-way through (missing config key, model
    download error, ...) leaves the module untouched and the next call
    retries from scratch instead of running with half-initialised state.

    Raises:
        FileNotFoundError: If any of the required build artifacts is missing.
        ValueError: If ``OPENAI_API_KEY`` is not set.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, OAI

    if CHUNKS is None:
        for path in (CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH):
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")

        # NOTE(security): pickle.load can execute arbitrary code. These files
        # must only ever come from the trusted build step, never from users.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)

        embeddings = np.load(EMBED_PATH)

        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish atomically; CHUNKS goes last because it is the sentinel
        # checked above to decide whether loading is still required.
        EMBEDDINGS = embeddings
        BM25 = bm25
        EMBED_MODEL = embed_model
        CHUNKS = chunks

    if OAI is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        OAI = OpenAI(api_key=api_key)
def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Retrieve chunk records for ``query`` with a two-stage hybrid search.

    Stage 1 scores every chunk with BM25 over the tokenized query and keeps
    the ``shortlist_k`` highest-scoring indices. Stage 2 embeds the query
    (normalised, float32), takes the dot product with the stored embeddings
    of the shortlisted chunks and keeps the ``final_k`` best.

    Returns:
        A list of entries from ``CHUNKS``, best match first.
    """
    ensure_loaded()

    # Stage 1: lexical shortlist.
    lexical_scores = BM25.get_scores(tokenize(query))
    candidates = np.argsort(lexical_scores)[::-1][:shortlist_k]

    # Stage 2: dense rerank restricted to the shortlist.
    query_vec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    similarity = EMBEDDINGS[candidates] @ query_vec
    best_first = np.argsort(similarity)[::-1][:final_k]

    return [CHUNKS[int(idx)] for idx in candidates[best_first]]
def build_context(records):
    """Render retrieved records as numbered source blocks for the prompt.

    Each record contributes a ``[Source N]`` block listing its book, section
    title, page range and text; blocks are separated by one blank line.
    Numbering starts at 1 and follows the order of ``records``.
    """
    rendered = [
        f"[Source {number}]\n"
        f"Book: {rec['book']}\n"
        f"Section: {rec['section_title']}\n"
        f"Pages: {rec['page_start']}-{rec['page_end']}\n"
        "Text:\n"
        f"{rec['text']}"
        for number, rec in enumerate(records, start=1)
    ]
    return "\n\n".join(rendered)
def make_sources(records):
    """Return a bulleted, de-duplicated citation list for ``records``.

    Records sharing the same (book, section title, first page, last page)
    are listed once, in order of first appearance, one bullet per line.
    """
    # dict.fromkeys keeps the first occurrence of each key in insertion order.
    unique_refs = dict.fromkeys(
        (rec["book"], rec["section_title"], rec["page_start"], rec["page_end"])
        for rec in records
    )
    return "\n".join(
        f"• {book} | {section} | pp. {first_page}-{last_page}"
        for book, section, first_page, last_page in unique_refs
    )
def choose_quiz_count(user_text: str, selector: str) -> int:
    """Decide how many quiz questions to generate.

    An explicit selector of "3", "5" or "7" wins. Otherwise the count is
    inferred from keywords in ``user_text`` (case-insensitive substring
    match): exam-style wording gives 7, study/revision wording gives 5, and
    anything else gives 3.
    """
    if selector in {"3", "5", "7"}:
        return int(selector)

    lowered = user_text.lower()
    # Checked in order, so the larger count wins when both tiers match.
    keyword_tiers = (
        (7, ("mock test", "final exam", "exam practice", "full test")),
        (5, ("detailed", "revision", "comprehensive", "study")),
    )
    for count, keywords in keyword_tiers:
        for keyword in keywords:
            if keyword in lowered:
                return count
    return 3
def language_instruction(language_mode: str) -> str:
    """Return the prompt sentence that fixes the answer language.

    "English", "Spanish" and "Bilingual" map to dedicated instructions; any
    other value yields the instruction to mirror the user's language.
    """
    fixed_instructions = {
        "English": "Answer only in English.",
        "Spanish": "Answer only in Spanish.",
        "Bilingual": "Answer first in English, then provide a Spanish version under the heading 'Español:'.",
    }
    mirror_user_language = (
        "If the user's message is in Spanish, answer in Spanish. "
        "If the user's message is in English, answer in English."
    )
    return fixed_instructions.get(language_mode, mirror_user_language)
def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Assemble the tutoring prompt sent to the chat model.

    Args:
        mode: Teaching style key; one of "Explain", "Detailed",
            "Short Notes", "Flashcards" or "Case-Based".
        language_mode: Passed to language_instruction() for the language rule.
        question: The user's question.
        context: Source text the model is told to restrict itself to.

    Returns:
        The prompt with surrounding whitespace stripped.

    Raises:
        KeyError: If ``mode`` is not one of the supported teaching styles.
    """
    teaching_styles = {
        "Explain": "Explain clearly like a friendly tutor using simple language.",
        "Detailed": "Give a fuller and more detailed explanation.",
        "Short Notes": "Answer in concise revision-note format using bullets.",
        "Flashcards": "Create 6 flashcards in Q/A format.",
        "Case-Based": "Create a short clinical scenario and explain it clearly.",
    }
    language_rule = language_instruction(language_mode)
    teaching_style = teaching_styles[mode]

    prompt_lines = [
        "You are BrainChat, an interactive neurology and neuroanatomy tutor.",
        "Rules:",
        "- Use only the provided context from the books.",
        "- If the answer is not supported by the context, say exactly:",
        "Not found in the course material.",
        "- Be accurate and student-friendly.",
        f"- {language_rule}",
        "Teaching style:",
        f"{teaching_style}",
        "Context:",
        f"{context}",
        "Question:",
        f"{question}",
    ]
    return "\n".join(prompt_lines).strip()
def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Assemble the prompt that asks the model for a quiz as JSON.

    The prompt requests exactly ``n_questions`` questions plus a short answer
    key, shows the expected JSON layout (``title`` and a ``questions`` list
    of ``q`` / ``answer_key`` objects), and appends the context and topic.

    Returns:
        The prompt with surrounding whitespace stripped.
    """
    language_rule = language_instruction(language_mode)

    prompt_lines = [
        "You are BrainChat, an interactive tutor.",
        "Rules:",
        "- Use only the provided context.",
        f"- Create exactly {n_questions} quiz questions.",
        "- Questions should be short and clear.",
        "- Also create a short answer key.",
        "- Return valid JSON only.",
        f"- {language_rule}",
        "Required JSON format:",
        "{",
        '"title": "short quiz title",',
        '"questions": [',
        '{"q": "question 1", "answer_key": "expected short answer"},',
        '{"q": "question 2", "answer_key": "expected short answer"}',
        "]",
        "}",
        "Context:",
        f"{context}",
        "Topic:",
        f"{topic}",
    ]
    return "\n".join(prompt_lines).strip()
def chat_text(prompt: str) -> str:
    """Send ``prompt`` to the chat model and return its reply as plain text.

    Uses the module-level ``OAI`` client, which must already have been
    created (see ensure_loaded()).

    Returns:
        The reply with surrounding whitespace stripped, or an empty string
        when the model returned no text content.
    """
    resp = OAI.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    # The API may return None as content (e.g. a refusal); calling .strip()
    # on it would raise AttributeError, so treat it as an empty reply.
    content = resp.choices[0].message.content
    return (content or "").strip()
def chat_json(prompt: str) -> dict:
    """Send ``prompt`` to the chat model in JSON mode and parse the reply.

    Uses the module-level ``OAI`` client, which must already have been
    created (see ensure_loaded()).

    Returns:
        The decoded JSON reply.

    Raises:
        ValueError: If the model returned no content, or content that is not
            valid JSON (``json.JSONDecodeError`` is a ``ValueError``).
    """
    resp = OAI.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    # content may be None; json.loads(None) would fail with an opaque
    # TypeError, so report the real problem in the same exception family
    # as a JSON decoding failure.
    if not content:
        raise ValueError("The model returned an empty response instead of JSON.")
    return json.loads(content)
def detect_logo_url():
    """Return the URL used to reference the logo image, or None.

    None is returned when ``LOGO_FILE`` does not exist; otherwise the file
    name is percent-encoded and appended to the ``/gradio_api/file=`` prefix.
    """
    if not os.path.exists(LOGO_FILE):
        return None
    encoded_name = quote(LOGO_FILE)
    return f"/gradio_api/file={encoded_name}"
def render_header():
    """Return the HTML for the hero banner shown at the top of the page.

    The banner shows the logo image when detect_logo_url() yields a URL and a
    styled "BRAIN CHAT" text badge otherwise, followed by the app title and
    subtitle.
    """
    logo_url = detect_logo_url()
    if not logo_url:
        # No logo file available: fall back to a round text badge.
        logo_html = '<div style="width:120px;height:120px;border-radius:50%;background:#efe85a;display:flex;align-items:center;justify-content:center;font-weight:700;text-align:center;margin:0 auto;">BRAIN<br>CHAT</div>'
    else:
        logo_html = f'<img src="{logo_url}" alt="BrainChat Logo" style="width:120px;height:120px;object-fit:contain;display:block;margin:0 auto;">'

    # Empty first/last entries reproduce the leading and trailing newline.
    markup_lines = [
        "",
        '<div class="hero-card">',
        '<div class="hero-inner">',
        f'<div class="hero-logo">{logo_html}</div>',
        '<div class="hero-title">BrainChat</div>',
        '<div class="hero-subtitle">',
        "Interactive neurology and neuroanatomy tutor based on your uploaded books",
        "</div>",
        "</div>",
        "</div>",
        "",
    ]
    return "\n".join(markup_lines)
def answer_question(message, history, mode, language_mode, quiz_count_mode, show_sources):
    """Produce the tutor's reply for one chat turn.

    Retrieves the most relevant chunks for the message, then either asks the
    model for a quiz ("Quiz Me" mode) or for a tutoring answer in the
    selected teaching style.

    Args:
        message: Text typed by the user.
        history: Conversation history supplied by the chat interface; not
            used by this function.
        mode: Selected tutor mode.
        language_mode: Selected answer language.
        quiz_count_mode: "Auto" or an explicit question count as a string.
        show_sources: Whether to append the list of sources to the reply.

    Returns:
        The reply as a Markdown string.
    """
    if not message or not message.strip():
        return "Please type a topic or question."

    ensure_loaded()
    user_text = message.strip()
    records = search_hybrid(user_text, shortlist_k=30, final_k=5)
    context = build_context(records)

    if mode != "Quiz Me":
        prompt = build_tutor_prompt(mode, language_mode, user_text, context)
        answer = chat_text(prompt)
        if show_sources:
            answer += "\n\n---\n**Sources used:**\n" + make_sources(records)
        return answer

    n_questions = choose_quiz_count(user_text, quiz_count_mode)
    prompt = build_quiz_generation_prompt(language_mode, user_text, context, n_questions)
    quiz_data = chat_json(prompt)

    # The model's JSON is not guaranteed to follow the requested layout, so
    # validate it instead of indexing blindly (which raised KeyError/TypeError
    # inside the chat callback).
    if not isinstance(quiz_data, dict):
        quiz_data = {}
    raw_questions = quiz_data.get("questions")
    if not isinstance(raw_questions, list):
        raw_questions = []
    questions = [
        item["q"]
        for item in raw_questions
        if isinstance(item, dict) and item.get("q")
    ]
    if not questions:
        return (
            "Sorry, I could not generate quiz questions for this topic. "
            "Please try again or rephrase the topic."
        )

    title = quiz_data.get("title") or "Quiz"
    lines = [
        f"**{title}**",
        "",
        f"**Total questions: {len(questions)}**",
        "",
    ]
    for number, question_text in enumerate(questions, start=1):
        lines.append(f"**Q{number}.** {question_text}")
        lines.append("")
    lines.extend(
        [
            "Reply with your answers in one message, for example:",
            "1. ...",
            "2. ...",
            "",
            "This version generates quiz questions only. Evaluation can be added next.",
        ]
    )
    if show_sources:
        lines.append("\n---\n**Topic sources used to create the quiz:**")
        lines.append(make_sources(records))
    return "\n".join(lines)
# Custom stylesheet passed to gr.Blocks: page background and font, hidden
# footer, and the "hero" banner classes used by render_header().
CSS = """
body, .gradio-container {
background: #dcdcdc !important;
font-family: Arial, Helvetica, sans-serif !important;
}
footer { display: none !important; }
.hero-card {
max-width: 860px;
margin: 18px auto 14px auto;
border-radius: 28px;
background: linear-gradient(180deg, #e8c7d4 0%, #a55ca2 48%, #2b0c46 100%);
padding: 22px 22px 18px 22px;
}
.hero-inner { text-align: center; }
.hero-title {
color: white;
font-size: 34px;
font-weight: 800;
margin-top: 6px;
}
.hero-subtitle {
color: white;
opacity: 0.92;
font-size: 16px;
margin-top: 6px;
}
"""
# Page layout: hero banner, two rows of option controls, usage notes, and
# the chat interface that forwards the controls to answer_question().
with gr.Blocks(css=CSS) as demo:
    gr.HTML(render_header())

    # Row 1: how to answer and in which language.
    with gr.Row():
        mode = gr.Dropdown(
            label="Tutor Mode",
            value="Explain",
            choices=["Explain", "Detailed", "Short Notes", "Quiz Me", "Flashcards", "Case-Based"],
        )
        language_mode = gr.Dropdown(
            label="Answer Language",
            value="Auto",
            choices=["Auto", "English", "Spanish", "Bilingual"],
        )

    # Row 2: quiz size and whether to list the sources.
    with gr.Row():
        quiz_count_mode = gr.Dropdown(
            label="Quiz Questions",
            value="Auto",
            choices=["Auto", "3", "5", "7"],
        )
        show_sources = gr.Checkbox(label="Show Sources", value=True)

    gr.Markdown(
        "\n"
        "**How to use**\n"
        "- Choose a **Tutor Mode**\n"
        "- Then type a topic or question\n"
        "- For **Quiz Me**, type a topic such as: `cranial nerves`\n"
        "- For **Flashcards**, type a topic such as: `hippocampus`\n"
    )

    # The controls above are passed to answer_question() after
    # (message, history), in this order.
    gr.ChatInterface(
        fn=answer_question,
        textbox=gr.Textbox(
            lines=1,
            placeholder="Ask a question or type a topic...",
        ),
        additional_inputs=[mode, language_mode, quiz_count_mode, show_sources],
        title=None,
        description=None,
    )
if __name__ == "__main__":
    # detect_logo_url() references the logo through Gradio's file route, so
    # the file has to be explicitly allowed or the image request may be
    # refused. Only that single file is exposed.
    demo.launch(allowed_paths=[os.path.abspath(LOGO_FILE)])