# BrainChat — Gradio RAG chat app (Hugging Face Space).
| import os | |
| import re | |
| import html | |
| import json | |
| import pickle | |
| from urllib.parse import quote | |
| import numpy as np | |
| import gradio as gr | |
| from rank_bm25 import BM25Okapi | |
| from sentence_transformers import SentenceTransformer | |
| from openai import OpenAI | |
# Offline build artifacts produced by the indexing step.
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")  # chunk records (book/section_title/pages/text)
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")  # pre-tokenized chunks for BM25
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")  # dense chunk embedding matrix
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")  # build config (embedding model name)
# Lazily-initialized globals; populated by ensure_loaded() on first request.
EMBED_MODEL = None  # SentenceTransformer used to embed queries
BM25 = None  # BM25Okapi lexical index
CHUNKS = None  # list of chunk records
EMBEDDINGS = None  # numpy array, one row per chunk
OAI = None  # OpenAI client
def tokenize(text: str):
    """Lowercase *text* and split it into Unicode word tokens."""
    return [m.group(0) for m in re.finditer(r"\w+", text.lower(), flags=re.UNICODE)]
def ensure_loaded():
    """Lazily initialize the module-level retrieval state and OpenAI client.

    Loads the prebuilt artifacts (chunks, BM25 tokens, embeddings, config)
    from BUILD_DIR on first call; subsequent calls are no-ops.

    Raises:
        FileNotFoundError: if any build artifact is missing.
        ValueError: if the OPENAI_API_KEY environment variable is unset.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, OAI
    if CHUNKS is None:
        # Check every artifact up front so the error names the exact missing file.
        if not os.path.exists(CHUNKS_PATH):
            raise FileNotFoundError("Missing brainchat_build/chunks.pkl")
        if not os.path.exists(TOKENS_PATH):
            raise FileNotFoundError("Missing brainchat_build/tokenized_chunks.pkl")
        if not os.path.exists(EMBED_PATH):
            raise FileNotFoundError("Missing brainchat_build/embeddings.npy")
        if not os.path.exists(CONFIG_PATH):
            raise FileNotFoundError("Missing brainchat_build/config.json")
        # NOTE(review): pickle is unsafe on untrusted data; acceptable here only
        # because these are our own build artifacts shipped with the Space.
        with open(CHUNKS_PATH, "rb") as f:
            CHUNKS = pickle.load(f)
        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)
        EMBEDDINGS = np.load(EMBED_PATH)
        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)
        # BM25 index is rebuilt in memory from the pre-tokenized chunks.
        BM25 = BM25Okapi(tokenized_chunks)
        # Must be the same model used at build time so query and chunk vectors match.
        EMBED_MODEL = SentenceTransformer(cfg["embedding_model"])
    if OAI is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        OAI = OpenAI(api_key=api_key)
def search_hybrid(query: str, shortlist_k: int = 30, final_k: int = 5):
    """Two-stage hybrid retrieval: BM25 shortlist, then dense re-ranking.

    Args:
        query: user question text.
        shortlist_k: size of the BM25 lexical shortlist.
        final_k: number of chunks returned after dense re-ranking.

    Returns:
        The top ``final_k`` chunk records, best first.
    """
    ensure_loaded()
    # Stage 1: lexical shortlist by BM25 score.
    lexical = BM25.get_scores(tokenize(query))
    shortlist = np.argsort(lexical)[::-1][:shortlist_k]
    # Stage 2: re-rank the shortlist by dot product with the query embedding
    # (embeddings are normalized, so this is cosine similarity).
    qvec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    sims = EMBEDDINGS[shortlist] @ qvec
    winners = shortlist[np.argsort(sims)[::-1][:final_k]]
    return [CHUNKS[int(i)] for i in winners]
def build_context(records):
    """Render retrieved chunk records as numbered [Source i] context blocks."""
    return "\n\n".join(
        f"[Source {n}]\n"
        f"Book: {rec['book']}\n"
        f"Section: {rec['section_title']}\n"
        f"Pages: {rec['page_start']}-{rec['page_end']}\n"
        f"Text:\n"
        f"{rec['text']}"
        for n, rec in enumerate(records, start=1)
    )
def make_sources(records):
    """Format a deduplicated, order-preserving bullet list of citations."""
    bullets = []
    emitted = set()
    for rec in records:
        key = (rec["book"], rec["section_title"], rec["page_start"], rec["page_end"])
        if key not in emitted:
            emitted.add(key)
            bullets.append(
                f"- {rec['book']} | {rec['section_title']} | pp. {rec['page_start']}-{rec['page_end']}"
            )
    return "\n".join(bullets)
def answer_question(message: str, history, show_sources: bool):
    """Gradio chat handler: retrieve context, query the LLM, format the reply.

    Args:
        message: the user's question.
        history: chat history (supplied by ChatInterface; unused).
        show_sources: when True, append a deduplicated source list.

    Returns:
        The model's answer (or an error string — errors are surfaced in-chat).
    """
    if not (message and message.strip()):
        return "Please type a question."
    try:
        records = search_hybrid(message, shortlist_k=30, final_k=5)
        context = build_context(records)
        system_prompt = """You are BrainChat, a neurology and neuroanatomy tutor.
Rules:
- Answer only from the provided context.
- If the answer is not supported by the context, say exactly:
Not found in the course material.
- Keep the answer clear and concise unless the user asks for more detail.
- If the question is in Spanish, answer in Spanish.
- If the question is in English, answer in English.
"""
        user_prompt = f"""Context:
{context}
Question:
{message}
"""
        chat_messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ]
        # Low temperature keeps answers grounded in the retrieved context.
        resp = OAI.chat.completions.create(
            model="gpt-4o-mini",
            temperature=0.2,
            messages=chat_messages,
        )
        answer = resp.choices[0].message.content.strip()
        if not show_sources:
            return answer
        return answer + "\n\n---\nSources used:\n" + make_sources(records)
    except Exception as e:
        # Boundary handler: surface any failure as chat text instead of crashing the UI.
        return f"Error: {str(e)}"
def detect_logo_url():
    """Return a Gradio file-serving URL for the first logo file found, else None."""
    for candidate in ("Brain chat-09.png", "brainchat_logo.png", "Brain Chat Imagen.svg"):
        if os.path.exists(candidate):
            # quote() percent-encodes spaces so the URL is valid.
            return f"/gradio_api/file={quote(candidate)}"
    return None
def top_html():
    """Build the phone-frame header HTML (logo, app title, subtitle)."""
    url = detect_logo_url()
    if url is None:
        # Fallback badge when no logo file ships with the Space.
        logo = '<div style="width:110px;height:110px;border-radius:50%;background:#efe85a;display:flex;align-items:center;justify-content:center;font-weight:bold;">BRAIN<br>CHAT</div>'
    else:
        logo = f'<img src="{url}" style="width:110px;height:110px;object-fit:contain;border-radius:50%;">'
    return f"""
<div style="
max-width:430px;
margin:18px auto 0 auto;
border:16px solid black;
border-radius:42px;
background:linear-gradient(180deg,#e8c7d4 0%,#a55ca2 48%,#2b0c46 100%);
padding:72px 18px 18px 18px;
box-sizing:border-box;
position:relative;">
<div style="position:absolute;top:0;left:50%;transform:translateX(-50%);width:170px;height:30px;background:black;border-bottom-left-radius:20px;border-bottom-right-radius:20px;"></div>
<div style="display:flex;justify-content:center;margin-bottom:18px;">{logo}</div>
<div style="text-align:center;color:white;font-size:28px;font-weight:700;margin-bottom:8px;">BrainChat</div>
<div style="text-align:center;color:white;opacity:0.9;margin-bottom:10px;">Ask questions from all your uploaded neurology books</div>
</div>
"""
# Page-level CSS: grey backdrop behind the phone frame; hide the Gradio footer.
CUSTOM_CSS = """
body, .gradio-container {
background:#dcdcdc !important;
}
footer {display:none !important;}
"""
with gr.Blocks(css=CUSTOM_CSS) as demo:
    # Phone-frame header with logo and title.
    gr.HTML(top_html())
    # Toggle controlling whether answers end with a "Sources used" list.
    show_sources = gr.Checkbox(value=True, label="Show sources")
    # ChatInterface calls fn(message, history, *additional_inputs).
    gr.ChatInterface(
        fn=answer_question,
        additional_inputs=[show_sources],
        title=None,
        description=None,
        textbox=gr.Textbox(placeholder="Ask a question...", lines=1),
    )
if __name__ == "__main__":
    # allowed_paths is required for Gradio to serve the local logo image via the
    # /gradio_api/file=... URLs built by detect_logo_url(); without it Gradio
    # refuses to expose local files and the header logo would 403.
    demo.launch(allowed_paths=[os.getcwd()])