# BrainChat / app5.py
# Committed by Deevyankar — commit b9eb209 ("Rename app.py to app5.py").
import os
import re
import json
import pickle
from urllib.parse import quote
import numpy as np
import gradio as gr
from rank_bm25 import BM25Okapi
from sentence_transformers import SentenceTransformer
from openai import OpenAI
# ============================================================
# Configuration
# ============================================================
# Directory holding the artefacts written by the offline build step.
BUILD_DIR = "brainchat_build"

# Individual build artefacts, all located inside BUILD_DIR.
CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH = (
    os.path.join(BUILD_DIR, artefact)
    for artefact in ("chunks.pkl", "tokenized_chunks.pkl", "embeddings.npy", "config.json")
)

# Logo files tried in order; put ONE of them in the Space repo root
# (the same folder as this script). The first one found is used.
LOGO_CANDIDATES = [
    "Brain chat-09.png",
    "brainchat_logo.png.png",
    "Brain Chat Imagen.svg",
    "ebcbb9f5-022f-473a-bf51-7e7974f794b4.png",
]

# Chat model; override with the OPENAI_MODEL environment variable.
MODEL_NAME_TEXT = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
# ============================================================
# Globals (lazy loaded)
# ============================================================
# Populated on first use by ensure_loaded(); None means "not loaded yet".
BM25 = CHUNKS = EMBEDDINGS = EMBED_MODEL = CLIENT = None
# ============================================================
# Utilities
# ============================================================
# Runs of Unicode word characters (letters, digits, underscore).
_WORD_PATTERN = re.compile(r"\w+", flags=re.UNICODE)


def tokenize(text: str):
    """Lower-case *text* and split it into a list of word tokens."""
    lowered = text.lower()
    return _WORD_PATTERN.findall(lowered)
def ensure_loaded():
    """Lazily load the retrieval index and the OpenAI client into module globals.

    On the first successful call this:
      * reads the artefacts written by the build step (chunk records,
        tokenized chunks, chunk embeddings, config.json);
      * builds a BM25 index over the tokenized chunks;
      * loads the SentenceTransformer named by ``config["embedding_model"]``.
    The OpenAI client is created on any call where it is still missing.

    All index state is loaded into locals first and published to the globals
    only after every step succeeded. A failure part-way through (e.g. the
    embedding model cannot be downloaded) therefore leaves the globals
    untouched and the next call retries. It can no longer leave ``CHUNKS``
    set while ``BM25``/``EMBED_MODEL`` are still ``None``.

    Raises:
        FileNotFoundError: if any build artefact is missing.
        ValueError: if the artefacts disagree on the number of chunks, or if
            ``OPENAI_API_KEY`` is not set.
    """
    global BM25, CHUNKS, EMBEDDINGS, EMBED_MODEL, CLIENT

    if CHUNKS is None or BM25 is None or EMBEDDINGS is None or EMBED_MODEL is None:
        missing = [p for p in [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH] if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                "Missing build files. Make sure you ran the build step and committed brainchat_build/.\n"
                + "\n".join(missing)
            )

        # NOTE: pickle.load can execute arbitrary code. These files are trusted
        # artefacts from our own build step; never point this at user uploads.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        embeddings = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # Retrieval indexes chunks, BM25 rows and embedding rows by the same
        # position, so stale or mismatched artefacts would silently attach
        # the wrong sources to answers. Fail loudly instead.
        if not (len(chunks) == len(tokenized_chunks) == len(embeddings)):
            raise ValueError(
                "Build artefacts are inconsistent: "
                f"{len(chunks)} chunks, {len(tokenized_chunks)} tokenized chunks, "
                f"{len(embeddings)} embeddings. Re-run the build step."
            )

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish only once everything above has succeeded.
        CHUNKS, BM25, EMBEDDINGS, EMBED_MODEL = chunks, bm25, embeddings, embed_model

    if CLIENT is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing. Add it in your Space Secrets.")
        CLIENT = OpenAI(api_key=api_key)
def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Return the chunk records most relevant to *query* (at most ``final_k``).

    Two-stage retrieval:
      1. BM25 over the tokenized query picks a lexical shortlist of the
         ``shortlist_k`` best-scoring chunks.
      2. The shortlist is re-ranked by the dot product between the stored
         chunk embeddings and the normalized query embedding. This is cosine
         similarity, presumably, if the build step also normalized the
         chunk embeddings; TODO confirm. The top ``final_k`` are kept.

    If BM25 finds no lexical match at all (no score > 0), the "shortlist"
    would just be arbitrary indices. Typical causes are a query with no word
    in common with the books, such as a Spanish question over English
    material, or a query with no word characters. In that case the dense
    re-rank runs over the whole corpus instead, so semantic matches can
    still be found.
    """
    ensure_loaded()

    bm25_scores = np.asarray(BM25.get_scores(tokenize(query)))
    if np.any(bm25_scores > 0):
        candidate_idx = np.argsort(bm25_scores)[::-1][:shortlist_k]
    else:
        # No lexical signal: let the embedding model search every chunk.
        candidate_idx = np.arange(EMBEDDINGS.shape[0])

    qvec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    dense_scores = EMBEDDINGS[candidate_idx] @ qvec
    best = np.argsort(dense_scores)[::-1][:final_k]
    return [CHUNKS[int(i)] for i in candidate_idx[best]]
def build_context(records):
    """Render retrieved chunk records as numbered "[Source N]" blocks.

    Each block lists the book, section title, page range and chunk text.
    Missing fields render as empty strings. Blocks are separated by a blank
    line.
    """

    def _render(position, rec):
        return (
            f"[Source {position}]\n"
            f"Book: {rec.get('book','')}\n"
            f"Section: {rec.get('section_title','')}\n"
            f"Pages: {rec.get('page_start','')}-{rec.get('page_end','')}\n"
            "Text:\n"
            f"{rec.get('text','')}"
        )

    return "\n\n".join(_render(n, rec) for n, rec in enumerate(records, start=1))
def make_sources(records):
    """Return one bullet line per distinct (book, section, pages) citation.

    Order follows first appearance in *records*. A record is a duplicate when
    its raw book / section_title / page_start / page_end values (missing keys
    count as None) match an earlier record.
    """
    citation_fields = ("book", "section_title", "page_start", "page_end")
    already_listed = set()
    bullets = []
    for rec in records:
        citation = tuple(rec.get(field) for field in citation_fields)
        if citation not in already_listed:
            already_listed.add(citation)
            bullets.append(
                f"• {rec.get('book','')} | {rec.get('section_title','')} | pp. {rec.get('page_start','')}-{rec.get('page_end','')}"
            )
    return "\n".join(bullets)
def choose_quiz_count(user_text: str, selector: str) -> int:
    """Decide how many quiz questions to generate.

    An explicit selector ("3", "5" or "7") wins. Otherwise the count is
    inferred from keywords in *user_text*: exam-style requests get 7,
    study/revision requests get 5, and anything else gets 3.
    """
    if selector in {"3", "5", "7"}:
        return int(selector)

    lowered = user_text.lower()
    keyword_tiers = (
        (("mock test", "final exam", "exam practice", "full test"), 7),
        (("detailed", "revision", "comprehensive", "study"), 5),
    )
    for keywords, count in keyword_tiers:
        if any(kw in lowered for kw in keywords):
            return count
    return 3
def language_instruction(language_mode: str) -> str:
    """Return the prompt sentence that fixes the answer language.

    "English", "Spanish" and "Bilingual" map to fixed instructions. Any other
    value (including "Auto") lets the model mirror the user's language.
    """
    fixed_modes = (
        ("English", "Answer only in English."),
        ("Spanish", "Answer only in Spanish."),
        ("Bilingual", "Answer first in English, then provide a Spanish version under the heading 'Español:'."),
    )
    for name, instruction in fixed_modes:
        if language_mode == name:
            return instruction
    return "If the user writes in Spanish, answer in Spanish; otherwise answer in English."
def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Build the grounded tutoring prompt for the non-quiz modes.

    The prompt restricts the model to *context*, fixes the "Not found in the
    course material." fallback and the answer language, and adds a
    teaching-style line chosen by *mode*. Unknown modes use "Explain".
    Leading/trailing whitespace of the result is stripped.
    """
    styles = {
        "Explain": (
            "Explain clearly like a friendly tutor using simple language. "
            "Use short headings if helpful."
        ),
        "Detailed": (
            "Give a detailed explanation. Include key terms and clinical relevance only if supported by the context."
        ),
        "Short Notes": "Write concise revision notes using bullet points.",
        "Flashcards": "Create 6 flashcards in Q/A format.",
        "Case-Based": (
            "Create a short clinical scenario (2–4 lines) and then explain the underlying concept using the context."
        ),
    }
    language_rule = language_instruction(language_mode)
    teaching_style = styles.get(mode, styles["Explain"])

    prompt_lines = [
        "You are BrainChat, an interactive neurology and neuroanatomy tutor.",
        "Rules:",
        "- Use ONLY the provided context from the books.",
        "- If the answer is not supported by the context, say exactly:",
        "Not found in the course material.",
        "- Do not invent facts outside the context.",
        f"- {language_rule}",
        "Teaching style:",
        f"{teaching_style}",
        "Context:",
        f"{context}",
        "Student question:",
        f"{question}",
    ]
    return "\n".join(prompt_lines).strip()
def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Build the prompt that asks the model for a JSON quiz on *topic*.

    The model is told to use only *context*, to write exactly *n_questions*
    short questions with answer keys, and to reply with JSON shaped like the
    embedded example (``title`` plus a ``questions`` list of ``q`` /
    ``answer_key`` objects). Leading/trailing whitespace of the result is
    stripped.
    """
    example_json = "\n".join(
        [
            "{",
            '"title": "short quiz title",',
            '"questions": [',
            '{"q": "question 1", "answer_key": "expected short answer"},',
            '{"q": "question 2", "answer_key": "expected short answer"}',
            "]",
            "}",
        ]
    )
    prompt_lines = [
        "You are BrainChat, an interactive tutor.",
        "Rules:",
        "- Use ONLY the provided context.",
        f"- Create exactly {n_questions} quiz questions.",
        "- Questions should be short, clear, and course-aligned.",
        "- Provide a short answer key per question.",
        "- Return VALID JSON only.",
        f"- {language_instruction(language_mode)}",
        "Required JSON format:",
        example_json,
        "Context:",
        f"{context}",
        "Topic:",
        f"{topic}",
    ]
    return "\n".join(prompt_lines).strip()
def build_quiz_evaluation_prompt(language_mode: str, quiz_data: dict, user_answers: str) -> str:
    """Build the prompt that asks the model to grade the student's quiz answers.

    The quiz (questions plus answer keys) is embedded as JSON, followed by
    the student's free-text answers. The model must reply with a JSON object
    holding the score, a per-question breakdown and a study tip.

    The scoring rule is stated explicitly: 1 / 0.5 / 0 points for Correct /
    Partially Correct / Incorrect, and score_total equal to the number of
    questions. Without it the model had to guess the scale, and could simply
    echo the ``0``/``0`` placeholders shown in the example JSON.

    Args:
        language_mode: answer-language setting (see language_instruction).
        quiz_data: the generated quiz dict. ``None`` is treated as ``{}``.
        user_answers: the student's reply, as typed.
    """
    quiz_data = quiz_data or {}
    quiz_json = json.dumps(quiz_data, ensure_ascii=False)
    questions = quiz_data.get("questions") or []
    n_questions = len(questions) if isinstance(questions, list) else 0
    return f"""
You are BrainChat, an interactive tutor.
Task:
Evaluate the student's answers fairly against the answer keys.
Accept semantically correct answers even if wording differs.
Grade every quiz question, in order, with exactly one entry in "results" per question; a missing answer is Incorrect.
Scoring: Correct = 1 point, Partially Correct = 0.5 points, Incorrect = 0 points.
"score_total" must be the number of quiz questions ({n_questions}); "score_obtained" is the sum of the points awarded.
Return VALID JSON only.
Required JSON format:
{{
"score_obtained": 0,
"score_total": 0,
"summary": "short overall feedback",
"results": [
{{
"question": "question text",
"answer_key": "expected short answer",
"student_answer": "student answer",
"result": "Correct / Partially Correct / Incorrect",
"feedback": "short explanation"
}}
],
"improvement_tip": "one short study suggestion"
}}
Quiz:
{quiz_json}
Student answers:
{user_answers}
Language:
{language_instruction(language_mode)}
""".strip()
def chat_text(prompt: str) -> str:
    """Send *prompt* to the chat model and return the reply text, stripped.

    Uses MODEL_NAME_TEXT at a low temperature (0.2) with a generic
    educational system message. ``message.content`` can be ``None`` (e.g.
    on a refusal or a content-filter stop). That case now returns ``""``
    instead of raising AttributeError on ``None.strip()``.
    """
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=MODEL_NAME_TEXT,
        temperature=0.2,
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    return (content or "").strip()
def chat_json(prompt: str) -> dict:
    """Send *prompt* to the chat model in JSON mode and return the parsed object.

    Raises:
        ValueError: if the model returns no content (``message.content`` can
            be ``None``, e.g. on a refusal), or returns JSON that is not an
            object. Callers use ``.get`` on the result, so a list or scalar
            would otherwise fail later with a confusing AttributeError.
        json.JSONDecodeError: if the content is not valid JSON (e.g. the
            reply was cut off). This is a ValueError subclass.
    """
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=MODEL_NAME_TEXT,
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    if not content:
        raise ValueError("The model returned an empty response; expected a JSON object.")
    data = json.loads(content)
    if not isinstance(data, dict):
        raise ValueError(f"The model returned a JSON {type(data).__name__}; expected a JSON object.")
    return data
# ============================================================
# Logo + Header HTML
# ============================================================
def find_logo_file():
    """Return the first LOGO_CANDIDATES entry that exists on disk, else None."""
    return next(
        (candidate for candidate in LOGO_CANDIDATES if os.path.exists(candidate)),
        None,
    )
def logo_img_tag(size_px: int = 88) -> str:
    """Return an <img> tag for the logo, or a text placeholder if none exists.

    The image is referenced through Gradio's ``/gradio_api/file=`` route,
    with the file name URL-quoted. *size_px* sets the width and height
    attributes.
    """
    logo_file = find_logo_file()
    if not logo_file:
        return '<div class="bc-logo-fallback">BRAIN<br>CHAT</div>'
    src = "/gradio_api/file=" + quote(logo_file)
    return (
        f'<img src="{src}" class="bc-logo-img" width="{size_px}" '
        f'height="{size_px}" alt="BrainChat logo" />'
    )
def render_top_banner() -> str:
    """Return the HTML for the gradient header banner (64px logo, title, subtitle)."""
    markup = [
        '<div class="bc-banner">',
        '<div class="bc-banner-inner">',
        f'<div class="bc-banner-logo">{logo_img_tag(64)}</div>',
        '<div class="bc-banner-text">',
        '<div class="bc-banner-title">BrainChat</div>',
        '<div class="bc-banner-subtitle">Neurology & neuroanatomy tutor (book-based)</div>',
        "</div>",
        "</div>",
        "</div>",
    ]
    return "\n".join(markup).strip()
def render_phone_logo() -> str:
    """Return the HTML for the round logo floating at the top of the phone mock."""
    markup = [
        '<div class="bc-phone-logo">',
        f"{logo_img_tag(84)}",
        "</div>",
    ]
    return "\n".join(markup).strip()
# ============================================================
# Chat logic (with quiz state)
# ============================================================
def _fresh_quiz_state(language_mode="Auto"):
    """Return a quiz state with no pending quiz."""
    return {"active": False, "quiz_data": None, "language_mode": language_mode}


def _assistant_message(text):
    """Wrap *text* as an assistant chat message."""
    return {"role": "assistant", "content": text}


def _format_quiz_evaluation(evaluation):
    """Render the evaluator's JSON verdict as Markdown for the chat."""
    lines = [f"**Score:** {evaluation.get('score_obtained', 0)}/{evaluation.get('score_total', 0)}"]
    if evaluation.get("summary"):
        lines.append(f"\n**Overall:** {evaluation['summary']}")
    if evaluation.get("improvement_tip"):
        lines.append(f"\n**Tip:** {evaluation['improvement_tip']}\n")
    # The model's JSON is untrusted: only render result entries that are objects.
    results = [item for item in (evaluation.get("results") or []) if isinstance(item, dict)]
    if results:
        lines.append("**Question-wise feedback:**")
        for item in results:
            lines.extend(
                [
                    "",
                    f"**Q:** {item.get('question','')}",
                    f"**Your answer:** {item.get('student_answer','')}",
                    f"**Expected:** {item.get('answer_key','')}",
                    f"**Result:** {item.get('result','')}",
                    f"**Feedback:** {item.get('feedback','')}",
                ]
            )
    return "\n".join(lines).strip()


def _usable_quiz_questions(quiz_data):
    """Return the generated questions that are objects with a non-empty "q"."""
    return [q for q in (quiz_data.get("questions") or []) if isinstance(q, dict) and q.get("q")]


def _format_quiz(quiz_data, questions, records, show_sources):
    """Render a generated quiz (plus optional sources) as Markdown for the chat."""
    lines = [
        f"**{quiz_data.get('title','Quiz')}**",
        f"\n**Total questions:** {len(questions)}\n",
        "Reply in ONE message with numbered answers, like:",
        "1. ...",
        "2. ...\n",
    ]
    lines.extend(f"**Q{i}.** {q.get('q','')}" for i, q in enumerate(questions, start=1))
    if show_sources:
        lines.append("\n\n**Sources used to create this quiz:**")
        lines.append(make_sources(records))
    return "\n".join(lines).strip()


def respond(message, history, mode, language_mode, quiz_count_mode, show_sources, quiz_state):
    """Handle one chat turn and return ``("", new_history, new_quiz_state)``.

    Flow:
      * Blank input: nothing happens; history and state come back unchanged.
      * A quiz is pending (``quiz_state["active"]``): the message is graded
        as the answers to that quiz, and the quiz is then cleared.
      * Mode "Quiz Me": retrieve context, generate a quiz and arm quiz mode.
        Quiz mode is armed only if the quiz has at least one usable question,
        so the next message is never "graded" against an empty quiz.
      * Any other mode: answer from the retrieved context with the tutor
        prompt.

    Any exception is shown to the user as an "Error: ..." assistant message,
    and the quiz state is reset.
    """
    if history is None:
        history = []
    if quiz_state is None:
        quiz_state = _fresh_quiz_state()

    user_text = (message or "").strip()
    if not user_text:
        return "", history, quiz_state

    try:
        history = history + [{"role": "user", "content": user_text}]

        # Quiz evaluation step
        if quiz_state.get("active", False):
            evaluation_prompt = build_quiz_evaluation_prompt(
                quiz_state.get("language_mode", language_mode),
                quiz_state.get("quiz_data", {}),
                user_text,
            )
            evaluation = chat_json(evaluation_prompt)
            history = history + [_assistant_message(_format_quiz_evaluation(evaluation))]
            return "", history, _fresh_quiz_state(language_mode)

        # Normal retrieval
        records = search_hybrid(user_text, shortlist_k=30, final_k=5)
        context = build_context(records)

        # Quiz generation
        if mode == "Quiz Me":
            n_questions = choose_quiz_count(user_text, quiz_count_mode)
            quiz_prompt = build_quiz_generation_prompt(language_mode, user_text, context, n_questions)
            quiz_data = chat_json(quiz_prompt)
            questions = _usable_quiz_questions(quiz_data)
            if not questions:
                # Leave quiz mode off: there is nothing to answer.
                reply = (
                    "I couldn't create a quiz from the course material for that topic. "
                    "Please try rephrasing or choosing another topic."
                )
                history = history + [_assistant_message(reply)]
                return "", history, _fresh_quiz_state(language_mode)
            # Keep exactly the questions the student sees for the grading step.
            quiz_data["questions"] = questions
            assistant_text = _format_quiz(quiz_data, questions, records, show_sources)
            history = history + [_assistant_message(assistant_text)]
            quiz_state = {"active": True, "quiz_data": quiz_data, "language_mode": language_mode}
            return "", history, quiz_state

        # Other modes
        tutor_prompt = build_tutor_prompt(mode, language_mode, user_text, context)
        answer = chat_text(tutor_prompt)
        if show_sources:
            answer = (answer or "").strip() + "\n\n**Sources:**\n" + make_sources(records)
        history = history + [_assistant_message(answer.strip())]
        return "", history, quiz_state

    except Exception as e:  # top-level UI boundary: surface the error in the chat
        history = history + [_assistant_message(f"Error: {str(e)}")]
        return "", history, _fresh_quiz_state(language_mode)
def clear_all():
    """Reset the UI: empty textbox, empty chat, no pending quiz."""
    empty_state = dict(active=False, quiz_data=None, language_mode="Auto")
    return "", [], empty_state
# ============================================================
# CSS (Instagram-style phone mock)
# ============================================================
# Stylesheet passed to demo.launch(css=...). Selectors target the elem_id
# values assigned in the UI section (bc_banner, bc_settings, bc_phone,
# bc_phone_logo, bc_chatbot, bc_input_row, bc_plus, bc_msg, bc_send,
# bc_clear) and the HTML classes emitted by the banner/logo helpers.
# The chat-bubble rules rely on Gradio-internal markup
# (button[data-testid="user"/"bot"], .toolbar).
# NOTE(review): those internal selectors may change between Gradio
# releases; re-check them after upgrading.
CSS = r"""
:root{
--bc-page-bg: #dcdcdc;
--bc-grad-top: #E8C7D4;
--bc-grad-mid: #A55CA2;
--bc-grad-bot: #2B0C46;
--bc-yellow: #FFF34A;
--bc-bot-bubble: #FAF7B4;
--bc-user-bubble: #FFFFFF;
--bc-ink: #141414;
}
body, .gradio-container{
background: var(--bc-page-bg) !important;
font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;
}
footer{ display:none !important; }
/* Banner */
#bc_banner{ max-width: 980px; margin: 18px auto 8px auto; }
.bc-banner{
background: linear-gradient(180deg, var(--bc-grad-top) 0%, var(--bc-grad-mid) 52%, var(--bc-grad-bot) 100%);
border-radius: 26px;
padding: 14px 16px;
box-shadow: 0 10px 26px rgba(0,0,0,.12);
}
.bc-banner-inner{ display:flex; align-items:center; gap: 12px; color: white; }
.bc-banner-title{ font-size: 20px; font-weight: 800; line-height:1.1; }
.bc-banner-subtitle{ font-size: 13px; opacity:.92; margin-top:2px; }
.bc-banner-logo .bc-logo-img{ border-radius: 999px; background: var(--bc-yellow); padding: 6px; display:block; }
.bc-logo-fallback{
width: 64px; height: 64px;
border-radius: 999px;
background: var(--bc-yellow);
display:flex; align-items:center; justify-content:center;
color: #111; font-weight: 900; font-size: 12px; text-align:center;
}
/* Settings */
#bc_settings{ max-width: 980px; margin: 0 auto 10px auto; }
#bc_settings .label{ font-weight: 700; }
/* Phone */
#bc_phone{
max-width: 420px;
margin: 0 auto 18px auto;
border-radius: 38px;
background: linear-gradient(180deg, var(--bc-grad-top) 0%, var(--bc-grad-mid) 45%, var(--bc-grad-bot) 100%);
box-shadow: 0 18px 40px rgba(0,0,0,.18);
border: 1px solid rgba(255,255,255,.22);
padding: 14px 14px 12px 14px;
position: relative;
}
/* Floating logo in phone */
#bc_phone_logo{
position: absolute;
top: 12px;
left: 50%;
transform: translateX(-50%);
z-index: 10;
}
.bc-phone-logo{
width: 92px; height: 92px;
border-radius: 999px;
background: var(--bc-yellow);
display:flex; align-items:center; justify-content:center;
box-shadow: 0 10px 22px rgba(0,0,0,.18);
}
.bc-phone-logo .bc-logo-img{
width: 84px !important; height: 84px !important; object-fit: contain;
}
/* Push chat down under logo */
#bc_chatbot{ margin-top: 92px; }
/* Chatbot transparent */
#bc_chatbot, #bc_chatbot > div{
background: transparent !important;
border: none !important;
box-shadow: none !important;
}
#bc_chatbot .toolbar{ display:none !important; }
/* Bubble styling via internal testid markers */
#bc_chatbot button[data-testid="user"],
#bc_chatbot button[data-testid="bot"]{
max-width: 82%;
border-radius: 18px !important;
padding: 12px 14px !important;
color: var(--bc-ink) !important;
box-shadow: 0 8px 18px rgba(0,0,0,.10);
border: 0 !important;
line-height: 1.35;
font-size: 14px;
}
/* User bubble white */
#bc_chatbot button[data-testid="user"]{
background: var(--bc-user-bubble) !important;
}
/* Bot bubble pale yellow */
#bc_chatbot button[data-testid="bot"]{
background: var(--bc-bot-bubble) !important;
}
/* Bubble tails */
#bc_chatbot button[data-testid="user"]::after{
content:"";
position:absolute;
right:-7px;
bottom: 12px;
width:0; height:0;
border-left: 10px solid var(--bc-user-bubble);
border-top: 8px solid transparent;
border-bottom: 8px solid transparent;
}
#bc_chatbot button[data-testid="bot"]::before{
content:"";
position:absolute;
left:-7px;
bottom: 12px;
width:0; height:0;
border-right: 10px solid var(--bc-bot-bubble);
border-top: 8px solid transparent;
border-bottom: 8px solid transparent;
}
/* Input bar */
#bc_input_row{
margin-top: 10px;
background: rgba(255,243,74,.96);
border-radius: 999px;
padding: 10px 10px;
box-shadow: 0 10px 22px rgba(0,0,0,.14);
align-items: center;
}
#bc_plus{
width: 34px; height: 34px;
border-radius: 999px;
display:flex;
align-items:center;
justify-content:center;
font-weight: 900;
color: var(--bc-grad-bot);
background: rgba(255,255,255,.35);
user-select: none;
}
#bc_msg textarea{
background: rgba(255,255,255,.35) !important;
border-radius: 999px !important;
border: none !important;
padding: 10px 12px !important;
color: var(--bc-grad-bot) !important;
box-shadow: none !important;
}
#bc_send{
min-width: 42px !important;
height: 38px !important;
border-radius: 999px !important;
border: none !important;
background: rgba(255,255,255,.35) !important;
color: var(--bc-grad-bot) !important;
font-size: 18px !important;
font-weight: 900 !important;
}
#bc_send:hover{ background: rgba(255,255,255,.55) !important; }
/* Clear */
#bc_clear{
max-width: 420px;
margin: 10px auto 0 auto;
border-radius: 14px !important;
}
@media (max-width: 480px){
#bc_phone{ max-width: 95vw; }
#bc_chatbot button[data-testid="user"],
#bc_chatbot button[data-testid="bot"]{
max-width: 88%;
font-size: 14px;
}
}
"""
# ============================================================
# UI
# ============================================================
with gr.Blocks() as demo:
    # Per-session quiz bookkeeping, passed through respond() and clear_all().
    quiz_state = gr.State({"active": False, "quiz_data": None, "language_mode": "Auto"})

    gr.HTML(render_top_banner(), elem_id="bc_banner")

    # Tutor settings, collapsed by default.
    with gr.Accordion("Settings", open=False, elem_id="bc_settings"):
        mode = gr.Dropdown(
            choices=["Explain", "Detailed", "Short Notes", "Flashcards", "Case-Based", "Quiz Me"],
            value="Explain",
            label="Tutor Mode",
        )
        language_mode = gr.Dropdown(
            choices=["Auto", "English", "Spanish", "Bilingual"],
            value="Auto",
            label="Answer Language",
        )
        quiz_count_mode = gr.Dropdown(
            choices=["Auto", "3", "5", "7"],
            value="Auto",
            label="Quiz Questions",
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")

    # Phone mock: floating logo, chat transcript, then the input bar.
    with gr.Group(elem_id="bc_phone"):
        gr.HTML(render_phone_logo(), elem_id="bc_phone_logo")
        chatbot = gr.Chatbot(
            value=[],
            elem_id="bc_chatbot",
            height=560,
            layout="bubble",
            container=False,
            show_label=False,
            autoscroll=True,
            buttons=[],
            placeholder="Ask a question or type a topic…",
        )
        with gr.Row(elem_id="bc_input_row"):
            gr.HTML("<div>+</div>", elem_id="bc_plus")
            msg = gr.Textbox(
                placeholder="Type a message…",
                show_label=False,
                container=False,
                scale=8,
                elem_id="bc_msg",
            )
            send_btn = gr.Button("➤", elem_id="bc_send", scale=1)

    clear_btn = gr.Button("Clear chat", elem_id="bc_clear")

    # Pressing Enter and clicking the send button run the same chat handler
    # (registered in that order: submit first, then click).
    respond_inputs = [msg, chatbot, mode, language_mode, quiz_count_mode, show_sources, quiz_state]
    respond_outputs = [msg, chatbot, quiz_state]
    for trigger in (msg.submit, send_btn.click):
        trigger(respond, inputs=respond_inputs, outputs=respond_outputs)

    # Clearing runs outside the queue so it takes effect immediately.
    clear_btn.click(
        clear_all,
        inputs=None,
        outputs=[msg, chatbot, quiz_state],
        queue=False,
    )
if __name__ == "__main__":
    # logo_img_tag() points the <img> at Gradio's /gradio_api/file= route.
    # Gradio 5+ only serves local files that are explicitly allow-listed, so
    # without this the logo request is refused and the image never renders.
    # NOTE(review): harmless on versions that already serve the working dir.
    logo_path = find_logo_file()
    demo.launch(
        css=CSS,
        allowed_paths=[os.path.abspath(logo_path)] if logo_path else None,
    )