Spaces:
Sleeping
Sleeping
| import os | |
| import re | |
| import json | |
| import pickle | |
| from urllib.parse import quote | |
| import numpy as np | |
| import gradio as gr | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
| from openai import OpenAI | |
# =====================================================
# PATHS
# =====================================================
# Artifacts produced by the offline index-build step; all live under BUILD_DIR.
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")            # chunk records (book/section/pages/text)
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")  # pre-tokenized chunks for BM25
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")         # dense embeddings aligned with CHUNKS
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")           # build config (e.g. embedding model name)
LOGO_FILE = "Brain chat-09.png"                                # optional header logo image
# =====================================================
# GLOBALS
# =====================================================
# Lazily populated by ensure_loaded(); None means "not loaded yet".
EMBED_MODEL = None   # SentenceTransformer instance used for dense reranking
BM25 = None          # BM25Okapi index over the tokenized chunks
CHUNKS = None        # chunk records loaded from CHUNKS_PATH
EMBEDDINGS = None    # numpy array of chunk embeddings (rows align with CHUNKS)
CLIENT = None        # OpenAI API client
| # ===================================================== | |
| # LOADERS | |
| # ===================================================== | |
def tokenize(text: str):
    """Lowercase *text* and return its Unicode word tokens as a list."""
    word_pattern = re.compile(r"\w+", re.UNICODE)
    return [match.group(0) for match in word_pattern.finditer(text.lower())]
def ensure_loaded():
    """Lazily initialize the retrieval artifacts and the OpenAI client.

    Idempotent: safe to call before every request. Populates the module
    globals CHUNKS, BM25, EMBEDDINGS, EMBED_MODEL on first call, and CLIENT
    independently (so a missing API key does not block retrieval loading).

    Raises:
        FileNotFoundError: if any build artifact under BUILD_DIR is missing.
        ValueError: if the OPENAI_API_KEY environment variable is not set.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, CLIENT
    # CHUNKS doubles as the "already loaded" flag for all retrieval globals.
    if CHUNKS is None:
        # Fail fast with a clear message if the offline build step never ran.
        for path in [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH]:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")
        with open(CHUNKS_PATH, "rb") as f:
            CHUNKS = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        EMBEDDINGS = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        BM25 = BM25Okapi(tokenized_chunks)
        # Must be the same model that produced EMBEDDINGS at build time.
        EMBED_MODEL = SentenceTransformer(cfg["embedding_model"])
    if CLIENT is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        CLIENT = OpenAI(api_key=api_key)
| # ===================================================== | |
| # RETRIEVAL | |
| # ===================================================== | |
def search_hybrid(query: str, shortlist_k: int = 20, final_k: int = 3):
    """Two-stage hybrid retrieval over the book chunks.

    Stage 1: BM25 lexical scoring shortlists `shortlist_k` candidates.
    Stage 2: dense cosine similarity (normalized embeddings) reranks the
    shortlist and keeps the top `final_k` chunk records.
    """
    ensure_loaded()
    sparse_scores = BM25.get_scores(tokenize(query))
    candidates = np.argsort(sparse_scores)[::-1][:shortlist_k]
    # Encode the query with the same normalized-embedding setup as the index,
    # so the dot product below is cosine similarity.
    query_vec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    similarity = EMBEDDINGS[candidates] @ query_vec
    winners = candidates[np.argsort(similarity)[::-1][:final_k]]
    return [CHUNKS[int(idx)] for idx in winners]
def build_context(records):
    """Format retrieved chunk records as numbered [Source i] text blocks.

    Each record must provide 'book', 'section_title', 'page_start',
    'page_end', and 'text' keys; blocks are separated by a blank line.
    """
    parts = []
    for number, record in enumerate(records, start=1):
        header = (
            f"[Source {number}]\n"
            f"Book: {record['book']}\n"
            f"Section: {record['section_title']}\n"
            f"Pages: {record['page_start']}-{record['page_end']}\n"
            "Text:\n"
        )
        parts.append(header + record["text"])
    return "\n\n".join(parts)
def make_sources(records):
    """Render a deduplicated bullet list of (book, section, pages) citations.

    Duplicates are detected by the full (book, section, page range) tuple;
    first occurrence wins, original order is preserved.
    """
    emitted = set()
    bullets = []
    for record in records:
        identity = (record["book"], record["section_title"], record["page_start"], record["page_end"])
        if identity not in emitted:
            emitted.add(identity)
            bullets.append(
                f"• {record['book']} | {record['section_title']} | pp. {record['page_start']}-{record['page_end']}"
            )
    return "\n".join(bullets)
| # ===================================================== | |
| # PROMPTS | |
| # ===================================================== | |
def language_instruction(language_mode: str) -> str:
    """Return the answer-language directive for the given UI language mode.

    Known modes map to fixed instructions; anything else (including "Auto")
    falls back to mirroring the language of the user's message.
    """
    fixed_rules = {
        "English": "Answer only in English.",
        "Spanish": "Answer only in Spanish.",
        "Bilingual": "Answer first in English, then provide a Spanish version under the heading 'Español:'.",
    }
    auto_rule = (
        "If the user's message is in Spanish, answer in Spanish. "
        "If the user's message is in English, answer in English."
    )
    return fixed_rules.get(language_mode, auto_rule)
def choose_quiz_count(user_text: str, selector: str) -> int:
    """Decide how many quiz questions to generate.

    An explicit dropdown choice ("3"/"5"/"7") wins; otherwise the user's
    wording selects 7 (exam-style keywords), 5 (study keywords), or 3.
    """
    if selector in ("3", "5", "7"):
        return int(selector)
    lowered = user_text.lower()
    exam_keywords = ("mock test", "final exam", "exam practice", "full test")
    study_keywords = ("detailed", "revision", "comprehensive", "study")
    if any(keyword in lowered for keyword in exam_keywords):
        return 7
    if any(keyword in lowered for keyword in study_keywords):
        return 5
    return 3
def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Build the grounded-tutoring prompt for the non-quiz modes.

    Args:
        mode: tutor mode selected in the UI ("Explain", "Detailed",
            "Short Notes", "Flashcards", "Case-Based").
        language_mode: answer-language selector passed to language_instruction().
        question: the user's message.
        context: formatted retrieval context from build_context().

    Returns:
        The full prompt string, grounded strictly in the provided context.
    """
    mode_map = {
        "Explain": (
            "Explain clearly like a friendly tutor using simple language. "
            "Use short headings if useful."
        ),
        "Detailed": (
            "Give a fuller and more detailed explanation. Include concept, key points, and clinical relevance when supported by context."
        ),
        "Short Notes": (
            "Answer in concise revision-note format using short bullet points."
        ),
        "Flashcards": (
            "Create 6 flashcards in Q/A format using only the provided context."
        ),
        "Case-Based": (
            "Create a short clinical scenario and explain it clearly using the provided context."
        )
    }
    # Fall back to the "Explain" style instead of raising KeyError when an
    # unmapped mode slips through (e.g. "Quiz Me" or a newly added dropdown
    # option that was not registered here).
    teaching_style = mode_map.get(mode, mode_map["Explain"])
    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.
Rules:
- Use only the provided context from the books.
- If the answer is not supported by the context, say exactly:
Not found in the course material.
- Be accurate and student-friendly.
- Do not invent facts outside the context.
- {language_instruction(language_mode)}
Teaching style:
{teaching_style}
Context:
{context}
Question:
{question}
""".strip()
def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Build the prompt that asks the model for a JSON quiz on *topic*.

    The required output schema is a JSON object with a "title" and a
    "questions" list of {"q", "answer_key"} items.
    """
    lang_rule = language_instruction(language_mode)
    return f"""
You are BrainChat, an interactive tutor.
Rules:
- Use only the provided context.
- Create exactly {n_questions} quiz questions.
- Questions should be short and clear.
- Also create a short answer key.
- Return valid JSON only.
- {lang_rule}
Required JSON format:
{{
"title": "short quiz title",
"questions": [
{{"q": "question 1", "answer_key": "expected short answer"}},
{{"q": "question 2", "answer_key": "expected short answer"}}
]
}}
Context:
{context}
Topic:
{topic}
""".strip()
def build_quiz_evaluation_prompt(language_mode: str, quiz_data: dict, user_answers: str) -> str:
    """Build the prompt that grades *user_answers* against the quiz's key.

    The model is asked to return a JSON object with the total score, a
    summary, and per-question results/feedback.
    """
    lang_rule = language_instruction(language_mode)
    serialized_quiz = json.dumps(quiz_data, ensure_ascii=False)
    return f"""
You are BrainChat, an interactive tutor.
Evaluate the student's answers fairly using the quiz answer key.
Give:
- total score
- per-question feedback
- one short improvement suggestion
Rules:
- Accept semantically correct answers even if wording differs.
- Return valid JSON only.
- {lang_rule}
Required JSON format:
{{
"score_obtained": 0,
"score_total": 0,
"summary": "short overall feedback",
"results": [
{{
"question": "question text",
"student_answer": "student answer",
"result": "Correct / Partially Correct / Incorrect",
"feedback": "short explanation"
}}
]
}}
Quiz data:
{serialized_quiz}
Student answers:
{user_answers}
""".strip()
| # ===================================================== | |
| # OPENAI HELPERS | |
| # ===================================================== | |
def chat_text(prompt: str) -> str:
    """Send *prompt* to the chat model and return the reply as stripped text.

    NOTE(review): assumes ensure_loaded() has already initialized CLIENT;
    calling this first would fail on CLIENT being None — confirm all call
    sites go through answer_question().
    """
    resp = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,  # low temperature for consistent, factual tutoring
        messages=[
            {"role": "system", "content": "You are a helpful educational assistant."},
            {"role": "user", "content": prompt},
        ],
    )
    return resp.choices[0].message.content.strip()
def chat_json(prompt: str) -> dict:
    """Send *prompt* to the chat model and parse the reply as JSON.

    Uses response_format json_object so the API constrains output to valid
    JSON; json.loads can still raise if the reply is malformed.

    NOTE(review): like chat_text(), assumes CLIENT was set by ensure_loaded().
    """
    resp = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=[
            {"role": "system", "content": "Return only valid JSON."},
            {"role": "user", "content": prompt},
        ],
    )
    return json.loads(resp.choices[0].message.content)
| # ===================================================== | |
| # HTML RENDERING | |
| # ===================================================== | |
def md_to_html(text: str) -> str:
    """Convert a minimal markdown subset to HTML for the chat bubbles.

    Escapes the HTML special characters &, <, > first, then renders
    **bold** spans as <strong> and newlines as <br>.

    Bug fixed: the escape targets had decayed to identity replacements
    ("&" -> "&", "<" -> "<", ...), so model/user text reached gr.HTML
    unescaped — broken rendering and an HTML-injection hole.
    """
    # Escape & first so the entities produced for < and > are not re-escaped.
    safe = (
        text.replace("&", "&amp;")
        .replace("<", "&lt;")
        .replace(">", "&gt;")
    )
    safe = re.sub(r"\*\*(.+?)\*\*", r"<strong>\1</strong>", safe)
    safe = safe.replace("\n", "<br>")
    return safe
def render_chat(history):
    """Render the message history as chat-bubble HTML.

    Returns a centered placeholder when history is empty or None; otherwise
    one right-aligned bubble per user message and one left-aligned bubble
    per assistant message, in order.
    """
    if not history:
        return """
<div class="empty-chat">
<div class="empty-chat-text">
Ask a question, choose a tutor mode, or start a quiz.
</div>
</div>
"""
    rows = []
    for entry in history:
        body = md_to_html(entry["content"])
        if entry["role"] == "user":
            row_cls, bubble_cls = "user-row", "user-bubble"
        else:
            row_cls, bubble_cls = "bot-row", "bot-bubble"
        rows.append(
            f'<div class="msg-row {row_cls}"><div class="msg-bubble {bubble_cls}">{body}</div></div>'
        )
    return f'<div class="chat-wrap">{"".join(rows)}</div>'
def detect_logo_url():
    """Return the Gradio static-file URL for the logo, or None if absent.

    NOTE(review): the /gradio_api/file= route only serves files Gradio is
    allowed to expose (e.g. launch(allowed_paths=[...])) — confirm deployment.
    """
    if not os.path.exists(LOGO_FILE):
        return None
    return f"/gradio_api/file={quote(LOGO_FILE)}"
def render_header():
    """Build the hero-card header HTML, with the logo image or a text fallback."""
    url = detect_logo_url()
    if url is None:
        # No logo file on disk: draw a simple circular text badge instead.
        logo_html = """
<div style="
width:120px;height:120px;border-radius:50%;
background:#efe85a;display:flex;align-items:center;justify-content:center;
font-weight:700;text-align:center;margin:0 auto;">
BRAIN<br>CHAT
</div>
"""
    else:
        logo_html = f"""
<img src="{url}" alt="BrainChat Logo"
style="width:120px;height:120px;object-fit:contain;display:block;margin:0 auto;">
"""
    return f"""
<div class="hero-card">
<div class="hero-inner">
<div class="hero-logo">{logo_html}</div>
<div class="hero-title">BrainChat</div>
<div class="hero-subtitle">
Interactive neurology and neuroanatomy tutor based on your uploaded books
</div>
</div>
</div>
"""
| # ===================================================== | |
| # MAIN LOGIC | |
| # ===================================================== | |
def answer_question(message, history, mode, language_mode, quiz_count_mode, show_sources, quiz_state):
    """Main chat handler wired to both the Send button and textbox submit.

    Flow: if a quiz is pending (quiz_state["active"]), the incoming message
    is graded as the student's answers; otherwise retrieval runs and either
    a new quiz is generated ("Quiz Me") or a tutoring answer is produced.

    Returns:
        (history, chat_html, quiz_state, "") — the trailing empty string
        clears the input textbox via the fourth output binding.
    """
    if history is None:
        history = []
    if quiz_state is None:
        quiz_state = {
            "active": False,
            "topic": None,
            "quiz_data": None,
            "language_mode": "Auto"
        }
    if not message or not message.strip():
        # Ignore empty submissions; re-render the current state unchanged.
        return history, render_chat(history), quiz_state, ""
    try:
        ensure_loaded()
        user_text = message.strip()
        history = history + [{"role": "user", "content": user_text}]
        # -------------------------------
        # QUIZ EVALUATION
        # -------------------------------
        # A pending quiz means this message holds the student's answers:
        # grade them and reset the quiz state.
        if quiz_state.get("active", False):
            evaluation_prompt = build_quiz_evaluation_prompt(
                quiz_state["language_mode"],
                quiz_state["quiz_data"],
                user_text
            )
            evaluation = chat_json(evaluation_prompt)
            lines = []
            lines.append(f"**Score:** {evaluation['score_obtained']}/{evaluation['score_total']}")
            lines.append("")
            lines.append(f"**Overall feedback:** {evaluation['summary']}")
            lines.append("")
            lines.append("**Question-wise evaluation:**")
            for item in evaluation["results"]:
                lines.append("")
                lines.append(f"**Q:** {item['question']}")
                lines.append(f"**Your answer:** {item['student_answer']}")
                lines.append(f"**Result:** {item['result']}")
                lines.append(f"**Feedback:** {item['feedback']}")
            final_answer = "\n".join(lines)
            history = history + [{"role": "assistant", "content": final_answer}]
            # Quiz round complete — back to the inactive state.
            quiz_state = {
                "active": False,
                "topic": None,
                "quiz_data": None,
                "language_mode": language_mode
            }
            return history, render_chat(history), quiz_state, ""
        # -------------------------------
        # NORMAL RETRIEVAL
        # -------------------------------
        records = search_hybrid(user_text, shortlist_k=20, final_k=3)
        context = build_context(records)
        # -------------------------------
        # QUIZ GENERATION
        # -------------------------------
        if mode == "Quiz Me":
            n_questions = choose_quiz_count(user_text, quiz_count_mode)
            prompt = build_quiz_generation_prompt(language_mode, user_text, context, n_questions)
            quiz_data = chat_json(prompt)
            lines = []
            lines.append(f"**{quiz_data.get('title', 'Quiz')}**")
            lines.append("")
            lines.append("Please answer the following questions in one message.")
            lines.append("You can reply in numbered format, for example:")
            lines.append("1. ...")
            lines.append("2. ...")
            lines.append("")
            lines.append(f"**Total questions: {len(quiz_data['questions'])}**")
            lines.append("")
            for i, q in enumerate(quiz_data["questions"], start=1):
                lines.append(f"**Q{i}.** {q['q']}")
            if show_sources:
                lines.append("\n---\n**Topic sources used to create the quiz:**")
                lines.append(make_sources(records))
            assistant_text = "\n".join(lines)
            history = history + [{"role": "assistant", "content": assistant_text}]
            # Arm the quiz so the NEXT user message gets evaluated.
            quiz_state = {
                "active": True,
                "topic": user_text,
                "quiz_data": quiz_data,
                "language_mode": language_mode
            }
            return history, render_chat(history), quiz_state, ""
        # -------------------------------
        # OTHER MODES
        # -------------------------------
        prompt = build_tutor_prompt(mode, language_mode, user_text, context)
        answer = chat_text(prompt)
        if show_sources:
            answer += "\n\n---\n**Sources used:**\n" + make_sources(records)
        history = history + [{"role": "assistant", "content": answer}]
        return history, render_chat(history), quiz_state, ""
    except Exception as e:
        # Surface any failure (missing artifacts, API errors, bad JSON) as an
        # assistant message instead of crashing the UI; drop any pending quiz.
        history = history + [{"role": "assistant", "content": f"Error: {str(e)}"}]
        quiz_state["active"] = False
        return history, render_chat(history), quiz_state, ""
def clear_all():
    """Reset history, rendered chat, quiz state, and input box to initial values."""
    fresh_quiz = {
        "active": False,
        "topic": None,
        "quiz_data": None,
        "language_mode": "Auto"
    }
    return [], render_chat([]), fresh_quiz, ""
| # ===================================================== | |
| # CSS | |
| # ===================================================== | |
# App-wide stylesheet: hero header gradient, white chat panel, and the
# left/right message-bubble layout used by render_chat().
CSS = """
body, .gradio-container {
background: #dcdcdc !important;
font-family: Arial, Helvetica, sans-serif !important;
}
footer { display: none !important; }
.hero-card {
max-width: 900px;
margin: 18px auto 14px auto;
border-radius: 28px;
background: linear-gradient(180deg, #e8c7d4 0%, #a55ca2 48%, #2b0c46 100%);
padding: 22px 22px 18px 22px;
}
.hero-inner { text-align: center; }
.hero-title {
color: white;
font-size: 34px;
font-weight: 800;
margin-top: 6px;
}
.hero-subtitle {
color: white;
opacity: 0.92;
font-size: 16px;
margin-top: 6px;
}
.chat-panel {
max-width: 900px;
margin: 0 auto;
background: white;
border-radius: 22px;
padding: 16px;
min-height: 420px;
box-shadow: 0 6px 18px rgba(0,0,0,0.08);
}
.chat-wrap {
display: flex;
flex-direction: column;
gap: 14px;
}
.msg-row {
display: flex;
width: 100%;
}
.user-row {
justify-content: flex-end;
}
.bot-row {
justify-content: flex-start;
}
.msg-bubble {
max-width: 80%;
padding: 14px 16px;
border-radius: 18px;
line-height: 1.5;
font-size: 15px;
word-wrap: break-word;
}
.user-bubble {
background: #e9d8ff;
color: #111;
border-bottom-right-radius: 6px;
}
.bot-bubble {
background: #f7f3a1;
color: #111;
border-bottom-left-radius: 6px;
}
.empty-chat {
display: flex;
justify-content: center;
align-items: center;
min-height: 360px;
}
.empty-chat-text {
color: #777;
font-size: 16px;
text-align: center;
}
.controls-wrap {
max-width: 900px;
margin: 0 auto;
}
"""
| # ===================================================== | |
| # UI | |
| # ===================================================== | |
# Bug fixed: custom CSS must be passed to gr.Blocks(css=...); Blocks.launch()
# has no `css` parameter, so the original `demo.launch(css=CSS)` failed (or
# silently dropped the styling, depending on gradio version).
with gr.Blocks(css=CSS) as demo:
    # Per-session state: message history and the pending-quiz record.
    history_state = gr.State([])
    quiz_state = gr.State({
        "active": False,
        "topic": None,
        "quiz_data": None,
        "language_mode": "Auto"
    })
    gr.HTML(render_header())
    with gr.Row(elem_classes="controls-wrap"):
        mode = gr.Dropdown(
            choices=["Explain", "Detailed", "Short Notes", "Flashcards", "Case-Based", "Quiz Me"],
            value="Explain",
            label="Tutor Mode"
        )
        language_mode = gr.Dropdown(
            choices=["Auto", "English", "Spanish", "Bilingual"],
            value="Auto",
            label="Answer Language"
        )
    with gr.Row(elem_classes="controls-wrap"):
        quiz_count_mode = gr.Dropdown(
            choices=["Auto", "3", "5", "7"],
            value="Auto",
            label="Quiz Questions"
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")
    gr.Markdown("""
**How to use**
- Choose a **Tutor Mode**
- Then type a topic or question
- For **Quiz Me**, type a topic such as: `cranial nerves`
- The system will ask questions, and your **next message will be evaluated automatically**
""")
    chat_html = gr.HTML(render_chat([]), elem_classes="chat-panel")
    with gr.Row(elem_classes="controls-wrap"):
        msg = gr.Textbox(
            placeholder="Ask a question or type a topic...",
            lines=1,
            show_label=False,
            scale=8
        )
        send_btn = gr.Button("Send", scale=1)
    with gr.Row(elem_classes="controls-wrap"):
        clear_btn = gr.Button("Clear Chat")
    # Enter key and the Send button share the same handler and wiring.
    msg.submit(
        answer_question,
        inputs=[msg, history_state, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[history_state, chat_html, quiz_state, msg]
    )
    send_btn.click(
        answer_question,
        inputs=[msg, history_state, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[history_state, chat_html, quiz_state, msg]
    )
    clear_btn.click(
        clear_all,
        inputs=[],
        outputs=[history_state, chat_html, quiz_state, msg]
    )
if __name__ == "__main__":
    # allowed_paths lets Gradio serve the logo file referenced by the
    # /gradio_api/file= URL that detect_logo_url() builds.
    demo.launch(allowed_paths=[LOGO_FILE])