Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import pickle | |
| from urllib.parse import quote | |
| import numpy as np | |
| import gradio as gr | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
| from openai import OpenAI | |
| # ============================================================ | |
| # Configuration | |
| # ============================================================ | |
# Directory that holds the prebuilt retrieval artifacts, and the files in it.
BUILD_DIR = "brainchat_build"
CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH = (
    os.path.join(BUILD_DIR, artifact_name)
    for artifact_name in (
        "chunks.pkl",
        "tokenized_chunks.pkl",
        "embeddings.npy",
        "config.json",
    )
)

# Put ONE of these logo files in your Space repo root (same folder as app.py)
LOGO_CANDIDATES = [
    "Brain chat-09.png",
    "brainchat_logo.png.png",
    "Brain Chat Imagen.svg",
    "ebcbb9f5-022f-473a-bf51-7e7974f794b4.png",
]

# Chat model name; the OPENAI_MODEL environment variable overrides the default.
MODEL_NAME_TEXT = os.environ.get("OPENAI_MODEL", "gpt-4o-mini")
| # ============================================================ | |
| # Globals (lazy loaded) | |
| # ============================================================ | |
# All heavy resources start unset; ensure_loaded() fills them on first use.
BM25 = CHUNKS = EMBEDDINGS = EMBED_MODEL = CLIENT = None
| # ============================================================ | |
| # Utilities | |
| # ============================================================ | |
def tokenize(text: str):
    """Lowercase *text* and return its ``\\w+`` runs as a list of tokens."""
    lowered = text.lower()
    word_pattern = re.compile(r"\w+", flags=re.UNICODE)
    return word_pattern.findall(lowered)
def ensure_loaded():
    """Lazily initialise the retrieval index and the OpenAI client.

    On first use this loads the chunk records, the tokenized chunks, the
    embedding matrix and the build config from ``BUILD_DIR``, builds the BM25
    index and the sentence-embedding model, and creates the OpenAI client.
    Later calls are no-ops for whatever is already initialised.

    Raises:
        FileNotFoundError: if any of the build artifacts is missing.
        ValueError: if the ``OPENAI_API_KEY`` environment variable is not set.
    """
    global BM25, CHUNKS, EMBEDDINGS, EMBED_MODEL, CLIENT

    index_ready = all(
        resource is not None
        for resource in (CHUNKS, BM25, EMBEDDINGS, EMBED_MODEL)
    )
    if not index_ready:
        required = [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH]
        missing = [p for p in required if not os.path.exists(p)]
        if missing:
            raise FileNotFoundError(
                "Missing build files. Make sure you ran the build step and committed brainchat_build/.\n"
                + "\n".join(missing)
            )

        # SECURITY NOTE: pickle.load executes arbitrary code embedded in the
        # file. Only load build artifacts that you produced yourself.
        with open(CHUNKS_PATH, "rb") as f:
            chunks = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        embeddings = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        bm25 = BM25Okapi(tokenized_chunks)
        embed_model = SentenceTransformer(cfg["embedding_model"])

        # Publish everything only after every step succeeded. Setting CHUNKS
        # early would make a later call skip loading after a partial failure
        # and leave BM25 / EMBEDDINGS / EMBED_MODEL as None.
        BM25, EMBEDDINGS, EMBED_MODEL = bm25, embeddings, embed_model
        CHUNKS = chunks

    if CLIENT is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing. Add it in your Space Secrets.")
        CLIENT = OpenAI(api_key=api_key)
def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Retrieve chunk records for *query* in two stages.

    Stage one scores every chunk with BM25 and keeps the ``shortlist_k``
    highest-scoring indices. Stage two embeds the query, takes the dot product
    with the stored embeddings of the shortlisted chunks, and keeps the
    ``final_k`` best. Returns the matching entries of ``CHUNKS``, best first.
    """
    ensure_loaded()

    # Stage 1: lexical shortlist.
    lexical_scores = BM25.get_scores(tokenize(query))
    candidates = np.argsort(lexical_scores)[::-1][:shortlist_k]

    # Stage 2: dense re-ranking restricted to the shortlist.
    encoded = EMBED_MODEL.encode([query], normalize_embeddings=True)
    query_vector = encoded.astype("float32")[0]
    similarity = EMBEDDINGS[candidates] @ query_vector
    best_first = np.argsort(similarity)[::-1][:final_k]

    return [CHUNKS[int(chunk_id)] for chunk_id in candidates[best_first]]
def build_context(records):
    """Format retrieved records as numbered ``[Source N]`` blocks.

    Each block lists the record's book, section title, page range and text
    (missing keys render as empty strings). Blocks are separated by a blank
    line.
    """

    def render(position, record):
        field = record.get
        return "\n".join(
            (
                f"[Source {position}]",
                f"Book: {field('book', '')}",
                f"Section: {field('section_title', '')}",
                f"Pages: {field('page_start', '')}-{field('page_end', '')}",
                "Text:",
                f"{field('text', '')}",
            )
        )

    return "\n\n".join(
        render(position, record)
        for position, record in enumerate(records, start=1)
    )
def make_sources(records):
    """Return one bullet line per distinct (book, section, page range).

    Records that repeat an already-listed combination are dropped; the order
    of first appearance is kept.
    """
    bullets = {}
    for rec in records:
        identity = (
            rec.get("book"),
            rec.get("section_title"),
            rec.get("page_start"),
            rec.get("page_end"),
        )
        if identity not in bullets:
            bullets[identity] = (
                f"• {rec.get('book','')} | {rec.get('section_title','')} | "
                f"pp. {rec.get('page_start','')}-{rec.get('page_end','')}"
            )
    return "\n".join(bullets.values())
def choose_quiz_count(user_text: str, selector: str) -> int:
    """Pick how many quiz questions to generate.

    An explicit selector of "3", "5" or "7" wins. Otherwise the count is
    guessed from keywords in the request: exam-style wording gives 7,
    study/revision wording gives 5, anything else gives 3.
    """
    explicit_choices = {"3", "5", "7"}
    if selector in explicit_choices:
        return int(selector)

    lowered = user_text.lower()
    # Checked in order, so the larger quiz wins when both tiers match.
    keyword_tiers = (
        (7, ("mock test", "final exam", "exam practice", "full test")),
        (5, ("detailed", "revision", "comprehensive", "study")),
    )
    for count, keywords in keyword_tiers:
        for keyword in keywords:
            if keyword in lowered:
                return count
    return 3
def language_instruction(language_mode: str) -> str:
    """Return the prompt sentence that fixes the answer language.

    Any mode other than "English", "Spanish" or "Bilingual" falls back to the
    follow-the-user's-language instruction.
    """
    fixed_modes = {
        "English": "Answer only in English.",
        "Spanish": "Answer only in Spanish.",
        "Bilingual": "Answer first in English, then provide a Spanish version under the heading 'Español:'.",
    }
    follow_user = "If the user writes in Spanish, answer in Spanish; otherwise answer in English."
    return fixed_modes.get(language_mode, follow_user)
def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Assemble the tutoring prompt for the non-quiz modes.

    The prompt contains the grounding rules, the language instruction, the
    teaching style for *mode* (unknown modes use the "Explain" style), the
    retrieved *context* and the student's *question*.
    """
    teaching_styles = {
        "Explain": (
            "Explain clearly like a friendly tutor using simple language. "
            "Use short headings if helpful."
        ),
        "Detailed": (
            "Give a detailed explanation. Include key terms and clinical relevance only if supported by the context."
        ),
        "Short Notes": "Write concise revision notes using bullet points.",
        "Flashcards": "Create 6 flashcards in Q/A format.",
        "Case-Based": (
            "Create a short clinical scenario (2–4 lines) and then explain the underlying concept using the context."
        ),
    }
    language_rule = language_instruction(language_mode)
    style = teaching_styles.get(mode, teaching_styles["Explain"])

    prompt_lines = [
        "You are BrainChat, an interactive neurology and neuroanatomy tutor.",
        "Rules:",
        "- Use ONLY the provided context from the books.",
        "- If the answer is not supported by the context, say exactly:",
        "Not found in the course material.",
        "- Do not invent facts outside the context.",
        f"- {language_rule}",
        "Teaching style:",
        f"{style}",
        "Context:",
        f"{context}",
        "Student question:",
        f"{question}",
    ]
    # Trailing whitespace of the question is trimmed together with the prompt.
    return "\n".join(prompt_lines).strip()
def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Assemble the prompt that asks the model for a quiz as JSON.

    The prompt requests exactly *n_questions* questions with answer keys,
    grounded in *context*, and shows the required JSON layout
    (``title`` plus a ``questions`` list of ``q`` / ``answer_key`` objects).
    """
    json_layout = (
        "{\n"
        '"title": "short quiz title",\n'
        '"questions": [\n'
        '{"q": "question 1", "answer_key": "expected short answer"},\n'
        '{"q": "question 2", "answer_key": "expected short answer"}\n'
        "]\n"
        "}"
    )
    rules = [
        "- Use ONLY the provided context.",
        f"- Create exactly {n_questions} quiz questions.",
        "- Questions should be short, clear, and course-aligned.",
        "- Provide a short answer key per question.",
        "- Return VALID JSON only.",
        f"- {language_instruction(language_mode)}",
    ]
    sections = [
        "You are BrainChat, an interactive tutor.",
        "Rules:",
        *rules,
        "Required JSON format:",
        json_layout,
        "Context:",
        f"{context}",
        "Topic:",
        f"{topic}",
    ]
    # Trailing whitespace of the topic is trimmed together with the prompt.
    return "\n".join(sections).strip()
def build_quiz_evaluation_prompt(language_mode: str, quiz_data: dict, user_answers: str) -> str:
    """Assemble the prompt that asks the model to grade quiz answers as JSON.

    *quiz_data* is embedded as JSON (non-ASCII characters kept as-is) next to
    the student's free-text *user_answers*, followed by the language
    instruction and preceded by the required result layout.
    """
    serialized_quiz = json.dumps(quiz_data, ensure_ascii=False)
    result_layout = "\n".join(
        (
            "{",
            '"score_obtained": 0,',
            '"score_total": 0,',
            '"summary": "short overall feedback",',
            '"results": [',
            "{",
            '"question": "question text",',
            '"answer_key": "expected short answer",',
            '"student_answer": "student answer",',
            '"result": "Correct / Partially Correct / Incorrect",',
            '"feedback": "short explanation"',
            "}",
            "],",
            '"improvement_tip": "one short study suggestion"',
            "}",
        )
    )
    sections = (
        "You are BrainChat, an interactive tutor.",
        "Task:",
        "Evaluate the student's answers fairly against the answer keys.",
        "Accept semantically correct answers even if wording differs.",
        "Return VALID JSON only.",
        "Required JSON format:",
        result_layout,
        "Quiz:",
        f"{serialized_quiz}",
        "Student answers:",
        f"{user_answers}",
        "Language:",
        f"{language_instruction(language_mode)}",
    )
    return "\n".join(sections).strip()
def chat_text(prompt: str) -> str:
    """Send *prompt* to the chat model and return its reply as stripped text.

    Returns an empty string when the model reply carries no text content
    (the API allows ``message.content`` to be ``None``), instead of failing
    with ``AttributeError`` on ``None.strip()``.
    """
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=MODEL_NAME_TEXT,
        temperature=0.2,
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    return (content or "").strip()
def chat_json(prompt: str) -> dict:
    """Send *prompt* to the chat model in JSON mode and return the parsed object.

    Raises:
        ValueError: if the reply has no content, is not valid JSON, or is
            valid JSON but not an object. Callers use ``.get`` on the result,
            so anything other than a dict is rejected here with a clear
            message instead of failing later.
    """
    ensure_loaded()
    resp = CLIENT.chat.completions.create(
        model=MODEL_NAME_TEXT,
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    content = resp.choices[0].message.content
    if not content:
        # json.loads(None) would raise an opaque TypeError.
        raise ValueError("The model returned an empty response instead of JSON.")
    try:
        parsed = json.loads(content)
    except json.JSONDecodeError as err:
        raise ValueError(f"The model returned invalid JSON: {err}") from err
    if not isinstance(parsed, dict):
        raise ValueError("The model returned JSON that is not an object.")
    return parsed
| # ============================================================ | |
| # Logo + Header HTML | |
| # ============================================================ | |
def find_logo_file():
    """Return the first name in ``LOGO_CANDIDATES`` that exists, else ``None``.

    Candidates are checked in list order, relative to the current directory.
    """
    existing = (candidate for candidate in LOGO_CANDIDATES if os.path.exists(candidate))
    return next(existing, None)
def logo_img_tag(size_px: int = 88) -> str:
    """Return HTML for the logo: an ``<img>`` tag, or a text badge fallback.

    When a logo file is found, the tag points at its URL-quoted name under
    ``/gradio_api/file=`` with *size_px* as both width and height. Without a
    logo file a "BRAIN CHAT" placeholder ``<div>`` is returned.
    """
    logo_file = find_logo_file()
    if not logo_file:
        return '<div class="bc-logo-fallback">BRAIN<br>CHAT</div>'

    src = "/gradio_api/file=" + quote(logo_file)
    return (
        f'<img src="{src}" class="bc-logo-img" '
        f'width="{size_px}" height="{size_px}" alt="BrainChat logo" />'
    )
def render_top_banner() -> str:
    """Return the HTML of the page banner: 64px logo, title and subtitle."""
    logo_html = logo_img_tag(64)
    return (
        '<div class="bc-banner">\n'
        '<div class="bc-banner-inner">\n'
        f'<div class="bc-banner-logo">{logo_html}</div>\n'
        '<div class="bc-banner-text">\n'
        '<div class="bc-banner-title">BrainChat</div>\n'
        '<div class="bc-banner-subtitle">Neurology & neuroanatomy tutor (book-based)</div>\n'
        "</div>\n"
        "</div>\n"
        "</div>"
    )
def render_phone_logo() -> str:
    """Return the HTML of the round logo badge shown inside the phone frame."""
    pieces = (
        '<div class="bc-phone-logo">',
        logo_img_tag(84),
        "</div>",
    )
    return "\n".join(pieces)
| # ============================================================ | |
| # Chat logic (with quiz state) | |
| # ============================================================ | |
def _inactive_quiz_state(language_mode):
    """Return a fresh quiz state with no quiz in progress."""
    return {"active": False, "quiz_data": None, "language_mode": language_mode}


def _format_quiz_evaluation(evaluation):
    """Render the model's grading JSON as the markdown shown in the chat."""
    lines = [
        f"**Score:** {evaluation.get('score_obtained', 0)}/{evaluation.get('score_total', 0)}"
    ]
    if evaluation.get("summary"):
        lines.append(f"\n**Overall:** {evaluation['summary']}")
    if evaluation.get("improvement_tip"):
        lines.append(f"\n**Tip:** {evaluation['improvement_tip']}\n")
    results = evaluation.get("results", [])
    if results:
        lines.append("**Question-wise feedback:**")
        for item in results:
            lines.extend(
                [
                    "",
                    f"**Q:** {item.get('question','')}",
                    f"**Your answer:** {item.get('student_answer','')}",
                    f"**Expected:** {item.get('answer_key','')}",
                    f"**Result:** {item.get('result','')}",
                    f"**Feedback:** {item.get('feedback','')}",
                ]
            )
    return "\n".join(lines).strip()


def _format_quiz(title, questions, records, show_sources):
    """Render a generated quiz (and optionally its sources) as chat markdown."""
    lines = [
        f"**{title}**",
        f"\n**Total questions:** {len(questions)}\n",
        "Reply in ONE message with numbered answers, like:",
        "1. ...",
        "2. ...\n",
    ]
    lines.extend(
        f"**Q{number}.** {question.get('q','')}"
        for number, question in enumerate(questions, start=1)
    )
    if show_sources:
        lines.append("\n\n**Sources used to create this quiz:**")
        lines.append(make_sources(records))
    return "\n".join(lines).strip()


def respond(message, history, mode, language_mode, quiz_count_mode, show_sources, quiz_state):
    """Handle one chat turn and return ``(textbox_value, history, quiz_state)``.

    Behaviour by situation:
    - Blank message: nothing changes.
    - A quiz is active: the message is graded against the stored quiz and the
      quiz is closed.
    - Mode "Quiz Me": a quiz is generated from retrieved context and becomes
      active, but only if the model produced at least one usable question.
    - Any other mode: a tutor answer is generated from retrieved context.

    Any exception is reported to the user as an "Error: ..." chat message and
    the quiz state is reset, so the UI never gets stuck in a broken quiz.
    The returned textbox value is always "" (clears the input).
    """
    if history is None:
        history = []
    if quiz_state is None:
        quiz_state = _inactive_quiz_state("Auto")

    user_text = (message or "").strip()
    if not user_text:
        return "", history, quiz_state

    try:
        # Build new lists instead of mutating the history passed in.
        history = history + [{"role": "user", "content": user_text}]

        if quiz_state.get("active", False):
            # Quiz evaluation step: grade in the language the quiz was made in.
            evaluation = chat_json(
                build_quiz_evaluation_prompt(
                    quiz_state.get("language_mode", language_mode),
                    quiz_state.get("quiz_data", {}),
                    user_text,
                )
            )
            reply = _format_quiz_evaluation(evaluation)
            quiz_state = _inactive_quiz_state(language_mode)
        else:
            # Normal retrieval
            records = search_hybrid(user_text, shortlist_k=30, final_k=5)
            context = build_context(records)

            if mode == "Quiz Me":
                n_questions = choose_quiz_count(user_text, quiz_count_mode)
                quiz_data = chat_json(
                    build_quiz_generation_prompt(language_mode, user_text, context, n_questions)
                )
                questions = [
                    q
                    for q in (quiz_data.get("questions") or [])
                    if isinstance(q, dict) and q.get("q")
                ]
                if questions:
                    # Store exactly the questions that were shown to the user.
                    quiz_data = {**quiz_data, "questions": questions}
                    reply = _format_quiz(
                        quiz_data.get("title", "Quiz"), questions, records, show_sources
                    )
                    quiz_state = {
                        "active": True,
                        "quiz_data": quiz_data,
                        "language_mode": language_mode,
                    }
                else:
                    # Do not activate an empty quiz: the next message would be
                    # "graded" against nothing.
                    reply = (
                        "I could not create a quiz from the course material for this topic. "
                        "Please try another topic."
                    )
            else:
                # Other modes
                reply = (chat_text(build_tutor_prompt(mode, language_mode, user_text, context)) or "").strip()
                if show_sources:
                    reply = reply + "\n\n**Sources:**\n" + make_sources(records)
                reply = reply.strip()

        history = history + [{"role": "assistant", "content": reply}]
        return "", history, quiz_state
    except Exception as e:
        # Top-level UI boundary: show the failure in the chat instead of crashing.
        history = history + [{"role": "assistant", "content": f"Error: {str(e)}"}]
        quiz_state = _inactive_quiz_state(language_mode)
        return "", history, quiz_state
def clear_all():
    """Return reset values for the textbox, the chat history and the quiz state."""
    empty_message = ""
    empty_history = []
    idle_quiz = {"active": False, "quiz_data": None, "language_mode": "Auto"}
    return empty_message, empty_history, idle_quiz
| # ============================================================ | |
| # CSS (Instagram-style phone mock) | |
| # ============================================================ | |
# Stylesheet for the phone-style chat mock-up.
# NOTE: the bubble selectors rely on Gradio's internal data-testid markers and
# may need updating when Gradio changes its chatbot markup.
CSS = r"""
:root{
  --bc-page-bg: #dcdcdc;
  --bc-grad-top: #E8C7D4;
  --bc-grad-mid: #A55CA2;
  --bc-grad-bot: #2B0C46;
  --bc-yellow: #FFF34A;
  --bc-bot-bubble: #FAF7B4;
  --bc-user-bubble: #FFFFFF;
  --bc-ink: #141414;
}
body, .gradio-container{
  background: var(--bc-page-bg) !important;
  font-family: ui-sans-serif, system-ui, -apple-system, Segoe UI, Roboto, Helvetica, Arial;
}
footer{ display:none !important; }
/* Banner */
#bc_banner{ max-width: 980px; margin: 18px auto 8px auto; }
.bc-banner{
  background: linear-gradient(180deg, var(--bc-grad-top) 0%, var(--bc-grad-mid) 52%, var(--bc-grad-bot) 100%);
  border-radius: 26px;
  padding: 14px 16px;
  box-shadow: 0 10px 26px rgba(0,0,0,.12);
}
.bc-banner-inner{ display:flex; align-items:center; gap: 12px; color: white; }
.bc-banner-title{ font-size: 20px; font-weight: 800; line-height:1.1; }
.bc-banner-subtitle{ font-size: 13px; opacity:.92; margin-top:2px; }
.bc-banner-logo .bc-logo-img{ border-radius: 999px; background: var(--bc-yellow); padding: 6px; display:block; }
.bc-logo-fallback{
  width: 64px; height: 64px;
  border-radius: 999px;
  background: var(--bc-yellow);
  display:flex; align-items:center; justify-content:center;
  color: #111; font-weight: 900; font-size: 12px; text-align:center;
}
/* Settings */
#bc_settings{ max-width: 980px; margin: 0 auto 10px auto; }
#bc_settings .label{ font-weight: 700; }
/* Phone */
#bc_phone{
  max-width: 420px;
  margin: 0 auto 18px auto;
  border-radius: 38px;
  background: linear-gradient(180deg, var(--bc-grad-top) 0%, var(--bc-grad-mid) 45%, var(--bc-grad-bot) 100%);
  box-shadow: 0 18px 40px rgba(0,0,0,.18);
  border: 1px solid rgba(255,255,255,.22);
  padding: 14px 14px 12px 14px;
  position: relative;
}
/* Floating logo in phone */
#bc_phone_logo{
  position: absolute;
  top: 12px;
  left: 50%;
  transform: translateX(-50%);
  z-index: 10;
}
.bc-phone-logo{
  width: 92px; height: 92px;
  border-radius: 999px;
  background: var(--bc-yellow);
  display:flex; align-items:center; justify-content:center;
  box-shadow: 0 10px 22px rgba(0,0,0,.18);
}
.bc-phone-logo .bc-logo-img{
  width: 84px !important; height: 84px !important; object-fit: contain;
}
/* Push chat down under logo */
#bc_chatbot{ margin-top: 92px; }
/* Chatbot transparent */
#bc_chatbot, #bc_chatbot > div{
  background: transparent !important;
  border: none !important;
  box-shadow: none !important;
}
#bc_chatbot .toolbar{ display:none !important; }
/* Bubble styling via internal testid markers */
#bc_chatbot button[data-testid="user"],
#bc_chatbot button[data-testid="bot"]{
  /* Containing block for the absolutely positioned bubble tails below */
  position: relative;
  max-width: 82%;
  border-radius: 18px !important;
  padding: 12px 14px !important;
  color: var(--bc-ink) !important;
  box-shadow: 0 8px 18px rgba(0,0,0,.10);
  border: 0 !important;
  line-height: 1.35;
  font-size: 14px;
}
/* User bubble white */
#bc_chatbot button[data-testid="user"]{
  background: var(--bc-user-bubble) !important;
}
/* Bot bubble pale yellow */
#bc_chatbot button[data-testid="bot"]{
  background: var(--bc-bot-bubble) !important;
}
/* Bubble tails */
#bc_chatbot button[data-testid="user"]::after{
  content:"";
  position:absolute;
  right:-7px;
  bottom: 12px;
  width:0; height:0;
  border-left: 10px solid var(--bc-user-bubble);
  border-top: 8px solid transparent;
  border-bottom: 8px solid transparent;
}
#bc_chatbot button[data-testid="bot"]::before{
  content:"";
  position:absolute;
  left:-7px;
  bottom: 12px;
  width:0; height:0;
  border-right: 10px solid var(--bc-bot-bubble);
  border-top: 8px solid transparent;
  border-bottom: 8px solid transparent;
}
/* Input bar */
#bc_input_row{
  margin-top: 10px;
  background: rgba(255,243,74,.96);
  border-radius: 999px;
  padding: 10px 10px;
  box-shadow: 0 10px 22px rgba(0,0,0,.14);
  align-items: center;
}
#bc_plus{
  width: 34px; height: 34px;
  border-radius: 999px;
  display:flex;
  align-items:center;
  justify-content:center;
  font-weight: 900;
  color: var(--bc-grad-bot);
  background: rgba(255,255,255,.35);
  user-select: none;
}
#bc_msg textarea{
  background: rgba(255,255,255,.35) !important;
  border-radius: 999px !important;
  border: none !important;
  padding: 10px 12px !important;
  color: var(--bc-grad-bot) !important;
  box-shadow: none !important;
}
#bc_send{
  min-width: 42px !important;
  height: 38px !important;
  border-radius: 999px !important;
  border: none !important;
  background: rgba(255,255,255,.35) !important;
  color: var(--bc-grad-bot) !important;
  font-size: 18px !important;
  font-weight: 900 !important;
}
#bc_send:hover{ background: rgba(255,255,255,.55) !important; }
/* Clear */
#bc_clear{
  max-width: 420px;
  margin: 10px auto 0 auto;
  border-radius: 14px !important;
}
@media (max-width: 480px){
  #bc_phone{ max-width: 95vw; }
  #bc_chatbot button[data-testid="user"],
  #bc_chatbot button[data-testid="bot"]{
    max-width: 88%;
    font-size: 14px;
  }
}
"""
| # ============================================================ | |
| # UI | |
| # ============================================================ | |
# Component nesting below is reconstructed from the element ids used in CSS
# (settings accordion, phone frame containing logo/chat/input row, clear
# button outside the phone frame).
with gr.Blocks() as demo:
    quiz_state = gr.State({"active": False, "quiz_data": None, "language_mode": "Auto"})

    gr.HTML(render_top_banner(), elem_id="bc_banner")

    with gr.Accordion("Settings", open=False, elem_id="bc_settings"):
        mode = gr.Dropdown(
            choices=["Explain", "Detailed", "Short Notes", "Flashcards", "Case-Based", "Quiz Me"],
            value="Explain",
            label="Tutor Mode",
        )
        language_mode = gr.Dropdown(
            choices=["Auto", "English", "Spanish", "Bilingual"],
            value="Auto",
            label="Answer Language",
        )
        quiz_count_mode = gr.Dropdown(
            choices=["Auto", "3", "5", "7"],
            value="Auto",
            label="Quiz Questions",
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")

    with gr.Group(elem_id="bc_phone"):
        gr.HTML(render_phone_logo(), elem_id="bc_phone_logo")
        chatbot = gr.Chatbot(
            value=[],
            elem_id="bc_chatbot",
            height=560,
            layout="bubble",
            container=False,
            show_label=False,
            autoscroll=True,
            buttons=[],
            placeholder="Ask a question or type a topic…",
        )
        with gr.Row(elem_id="bc_input_row"):
            gr.HTML("<div>+</div>", elem_id="bc_plus")
            msg = gr.Textbox(
                placeholder="Type a message…",
                show_label=False,
                container=False,
                scale=8,
                elem_id="bc_msg",
            )
            send_btn = gr.Button("➤", elem_id="bc_send", scale=1)

    clear_btn = gr.Button("Clear chat", elem_id="bc_clear")

    # Pressing Enter and clicking the send button run the same handler with
    # the same wiring, so the component lists are declared once.
    turn_inputs = [msg, chatbot, mode, language_mode, quiz_count_mode, show_sources, quiz_state]
    turn_outputs = [msg, chatbot, quiz_state]
    for register in (msg.submit, send_btn.click):
        register(respond, inputs=turn_inputs, outputs=turn_outputs)

    clear_btn.click(
        clear_all,
        inputs=None,
        outputs=turn_outputs,
        queue=False,
    )
if __name__ == "__main__":
    # logo_img_tag() references the logo through Gradio's file route, and
    # Gradio only serves files it has been explicitly allowed to expose.
    # Whitelist just the logo file (if one exists) so the image can load.
    _logo_file = find_logo_file()
    demo.launch(
        css=CSS,
        allowed_paths=[os.path.abspath(_logo_file)] if _logo_file else None,
    )