Deevyankar commited on
Commit
74a7a39
·
verified ·
1 Parent(s): 07e9bdd

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +528 -0
app.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import re
3
+ import json
4
+ import pickle
5
+ from urllib.parse import quote
6
+
7
+ import numpy as np
8
+ import gradio as gr
9
+ from rank_bm25 import BM25Okapi
10
+ from sentence_transformers import SentenceTransformer
11
+ from openai import OpenAI
12
+
# =====================================================
# PATHS
# =====================================================
# Retrieval artifacts produced by an offline build step; loaded read-only
# at runtime by ensure_loaded().
BUILD_DIR = "brainchat_build"
CHUNKS_PATH = os.path.join(BUILD_DIR, "chunks.pkl")            # list of chunk records (book/section/pages/text)
TOKENS_PATH = os.path.join(BUILD_DIR, "tokenized_chunks.pkl")  # pre-tokenized chunks for the BM25 index
EMBED_PATH = os.path.join(BUILD_DIR, "embeddings.npy")         # dense chunk embeddings (indexed alongside CHUNKS)
CONFIG_PATH = os.path.join(BUILD_DIR, "config.json")           # build config; must contain "embedding_model"
LOGO_FILE = "Brain chat-09.png"                                # optional logo shown in the header

# =====================================================
# GLOBALS
# =====================================================
# Lazily initialized singletons, populated once by ensure_loaded().
EMBED_MODEL = None   # SentenceTransformer instance
BM25 = None          # BM25Okapi index over the tokenized chunks
CHUNKS = None        # list of chunk dicts
EMBEDDINGS = None    # numpy array of chunk embeddings
CLIENT = None        # OpenAI API client
31
+
32
+
# =====================================================
# LOADERS
# =====================================================
def tokenize(text: str):
    """Lowercase *text* and split it into word tokens (used for BM25 queries)."""
    lowered = text.lower()
    return [match.group(0) for match in re.finditer(r"\w+", lowered, flags=re.UNICODE)]
38
+
39
+
def ensure_loaded():
    """Lazily load all retrieval artifacts and the OpenAI client.

    Idempotent: the heavy loading only runs on the first call, while the
    module-level CHUNKS / CLIENT globals are still None.

    Raises:
        FileNotFoundError: if any build artifact is missing from BUILD_DIR.
        ValueError: if the OPENAI_API_KEY secret is not set.
    """
    global EMBED_MODEL, BM25, CHUNKS, EMBEDDINGS, CLIENT

    if CHUNKS is None:
        # Fail fast with a clear message if the offline build is incomplete.
        for path in [CHUNKS_PATH, TOKENS_PATH, EMBED_PATH, CONFIG_PATH]:
            if not os.path.exists(path):
                raise FileNotFoundError(f"Missing file: {path}")

        with open(CHUNKS_PATH, "rb") as f:
            CHUNKS = pickle.load(f)

        with open(TOKENS_PATH, "rb") as f:
            tokenized_chunks = pickle.load(f)

        EMBEDDINGS = np.load(EMBED_PATH)

        with open(CONFIG_PATH, "r", encoding="utf-8") as f:
            cfg = json.load(f)

        # The BM25 index is rebuilt from the pre-tokenized chunks; the dense
        # encoder name comes from the build config so query embeddings match
        # the precomputed chunk embeddings.
        BM25 = BM25Okapi(tokenized_chunks)
        EMBED_MODEL = SentenceTransformer(cfg["embedding_model"])

    if CLIENT is None:
        api_key = os.getenv("OPENAI_API_KEY")
        if not api_key:
            raise ValueError("OPENAI_API_KEY is missing in Hugging Face Space Secrets.")
        CLIENT = OpenAI(api_key=api_key)
67
+
68
+
# =====================================================
# RETRIEVAL
# =====================================================
def search_hybrid(query: str, shortlist_k: int = 20, final_k: int = 3):
    """Two-stage retrieval: BM25 shortlist, then dense re-ranking.

    Returns the top *final_k* chunk records for *query*.
    """
    ensure_loaded()

    # Stage 1: lexical shortlist via BM25 over the whole corpus.
    sparse_scores = BM25.get_scores(tokenize(query))
    shortlist = np.argsort(sparse_scores)[::-1][:shortlist_k]

    # Stage 2: re-rank the shortlist by dot product with the query embedding.
    # The query vector is normalized here; assumes the precomputed chunk
    # embeddings are normalized too, so this approximates cosine similarity
    # — TODO confirm against the offline build step.
    query_vec = EMBED_MODEL.encode([query], normalize_embeddings=True).astype("float32")[0]
    similarity = EMBEDDINGS[shortlist] @ query_vec

    winners = shortlist[np.argsort(similarity)[::-1][:final_k]]
    return [CHUNKS[int(idx)] for idx in winners]
88
+
89
+
def build_context(records):
    """Format retrieved chunk records into a numbered context string for prompts."""
    sections = []
    for idx, rec in enumerate(records, start=1):
        block_lines = [
            f"[Source {idx}]",
            f"Book: {rec['book']}",
            f"Section: {rec['section_title']}",
            f"Pages: {rec['page_start']}-{rec['page_end']}",
            "Text:",
            f"{rec['text']}",
        ]
        sections.append("\n".join(block_lines))
    return "\n\n".join(sections)
102
+
103
+
def make_sources(records):
    """Render a deduplicated bullet list of (book, section, pages) citations."""
    emitted = set()
    bullets = []
    for rec in records:
        identity = (rec["book"], rec["section_title"], rec["page_start"], rec["page_end"])
        if identity not in emitted:
            emitted.add(identity)
            bullets.append(
                f"• {rec['book']} | {rec['section_title']} | pp. {rec['page_start']}-{rec['page_end']}"
            )
    return "\n".join(bullets)
116
+
117
+
# =====================================================
# PROMPTS
# =====================================================
def language_instruction(language_mode: str) -> str:
    """Return the prompt clause that pins the answer language for *language_mode*."""
    fixed_modes = {
        "English": "Answer only in English.",
        "Spanish": "Answer only in Spanish.",
        "Bilingual": "Answer first in English, then provide a Spanish version under the heading 'Español:'.",
    }
    # Any other selector (e.g. "Auto") mirrors the language of the user's message.
    auto_rule = (
        "If the user's message is in Spanish, answer in Spanish. "
        "If the user's message is in English, answer in English."
    )
    return fixed_modes.get(language_mode, auto_rule)
132
+
133
+
def choose_quiz_count(user_text: str, selector: str) -> int:
    """Pick the number of quiz questions from the dropdown or from the request text.

    An explicit numeric selector wins; otherwise keyword heuristics on the
    user's message choose 7 (exam-style), 5 (study/revision) or 3 (default).
    """
    if selector in {"3", "5", "7"}:
        return int(selector)

    lowered = user_text.lower()
    exam_hints = ("mock test", "final exam", "exam practice", "full test")
    study_hints = ("detailed", "revision", "comprehensive", "study")

    if any(hint in lowered for hint in exam_hints):
        return 7
    if any(hint in lowered for hint in study_hints):
        return 5
    return 3
144
+
145
+
def build_tutor_prompt(mode: str, language_mode: str, question: str, context: str) -> str:
    """Build the RAG tutoring prompt used for every non-quiz mode.

    Args:
        mode: Tutor mode selected in the UI (e.g. "Explain", "Flashcards").
        language_mode: Answer-language selector, forwarded to language_instruction().
        question: The user's message.
        context: Retrieved book excerpts (from build_context()).

    Returns:
        The complete prompt string, stripped of surrounding whitespace.
    """
    mode_map = {
        "Explain": "Explain clearly like a friendly tutor using simple language.",
        "Detailed": "Give a fuller and more detailed explanation.",
        "Short Notes": "Answer in concise revision-note format using bullets.",
        "Flashcards": "Create 6 flashcards in Q/A format.",
        "Case-Based": "Create a short clinical scenario and explain it clearly."
    }
    # Fall back to the "Explain" style instead of raising KeyError if an
    # unknown mode ever reaches this function (e.g. a new UI option wired in
    # before this map is updated).
    teaching_style = mode_map.get(mode, mode_map["Explain"])

    return f"""
You are BrainChat, an interactive neurology and neuroanatomy tutor.

Rules:
- Use only the provided context from the books.
- If the answer is not supported by the context, say exactly:
Not found in the course material.
- Be accurate and student-friendly.
- {language_instruction(language_mode)}

Teaching style:
{teaching_style}

Context:
{context}

Question:
{question}
""".strip()
174
+
175
+
def build_quiz_generation_prompt(language_mode: str, topic: str, context: str, n_questions: int) -> str:
    """Build the JSON-mode prompt asking the model to generate a quiz.

    The model must reply with the JSON schema embedded in the prompt
    ({"title": ..., "questions": [{"q": ..., "answer_key": ...}]}); the
    reply is parsed by chat_json() and later fed back to the evaluation
    prompt as the answer key.

    Args:
        language_mode: Answer-language selector, forwarded to language_instruction().
        topic: The user's message, used as the quiz topic.
        context: Retrieved book excerpts (from build_context()).
        n_questions: Exact number of questions to request.
    """
    return f"""
You are BrainChat, an interactive tutor.

Rules:
- Use only the provided context.
- Create exactly {n_questions} quiz questions.
- Questions should be short and clear.
- Also create a short answer key.
- Return valid JSON only.
- {language_instruction(language_mode)}

Required JSON format:
{{
"title": "short quiz title",
"questions": [
{{"q": "question 1", "answer_key": "expected short answer"}},
{{"q": "question 2", "answer_key": "expected short answer"}}
]
}}

Context:
{context}

Topic:
{topic}
""".strip()
203
+
204
+
def build_quiz_evaluation_prompt(language_mode: str, quiz_data: dict, user_answers: str) -> str:
    """Build the JSON-mode prompt that grades the student's quiz answers.

    Args:
        language_mode: Answer-language selector, forwarded to language_instruction().
        quiz_data: Quiz dict previously produced by the generation prompt
            (contains the questions and their answer keys).
        user_answers: The student's free-text reply, typically numbered answers.

    The model must reply with the JSON schema embedded in the prompt
    (score, summary, per-question results), parsed by chat_json().
    """
    # Serialize the quiz (questions + answer key) so the grader sees it verbatim.
    quiz_json = json.dumps(quiz_data, ensure_ascii=False)
    return f"""
You are BrainChat, an interactive tutor.

Evaluate the student's answers fairly using the quiz answer key.
Give:
- total score
- per-question feedback
- one short improvement suggestion

Rules:
- Accept semantically correct answers even if wording differs.
- Return valid JSON only.
- {language_instruction(language_mode)}

Required JSON format:
{{
"score_obtained": 0,
"score_total": 0,
"summary": "short overall feedback",
"results": [
{{
"question": "question text",
"student_answer": "student answer",
"result": "Correct / Partially Correct / Incorrect",
"feedback": "short explanation"
}}
]
}}

Quiz data:
{quiz_json}

Student answers:
{user_answers}
""".strip()
242
+
243
+
def chat_text(prompt: str) -> str:
    """Send *prompt* to the chat model and return its plain-text reply, stripped."""
    messages = [
        {"role": "system", "content": "You are a helpful educational assistant."},
        {"role": "user", "content": prompt},
    ]
    response = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        messages=messages,
    )
    return response.choices[0].message.content.strip()
254
+
255
+
def chat_json(prompt: str) -> dict:
    """Send *prompt* to the chat model in JSON mode and return the parsed reply.

    Raises json.JSONDecodeError if the model response is not valid JSON.
    """
    conversation = [
        {"role": "system", "content": "Return only valid JSON."},
        {"role": "user", "content": prompt},
    ]
    response = CLIENT.chat.completions.create(
        model="gpt-4o-mini",
        temperature=0.2,
        response_format={"type": "json_object"},
        messages=conversation,
    )
    return json.loads(response.choices[0].message.content)
267
+
268
+
# =====================================================
# UI HELPERS
# =====================================================
def detect_logo_url():
    """Return the Gradio-served URL for the logo file, or None if the file is absent."""
    if not os.path.exists(LOGO_FILE):
        return None
    # Gradio serves local files under /gradio_api/file=; the filename contains
    # a space, so it must be percent-encoded.
    return f"/gradio_api/file={quote(LOGO_FILE)}"
276
+
277
+
def render_header():
    """Return the hero-banner HTML for the top of the page.

    Uses the logo image when it exists; otherwise falls back to a styled
    yellow text badge so the header still renders sensibly.
    """
    logo_url = detect_logo_url()
    if logo_url:
        logo_html = f'<img src="{logo_url}" alt="BrainChat Logo" style="width:120px;height:120px;object-fit:contain;display:block;margin:0 auto;">'
    else:
        # Fallback badge: yellow circle with the app name, same footprint as the logo.
        logo_html = '<div style="width:120px;height:120px;border-radius:50%;background:#efe85a;display:flex;align-items:center;justify-content:center;font-weight:700;text-align:center;margin:0 auto;">BRAIN<br>CHAT</div>'

    # CSS classes (hero-card, hero-title, ...) are defined in the module-level CSS.
    return f"""
<div class="hero-card">
<div class="hero-inner">
<div class="hero-logo">{logo_html}</div>
<div class="hero-title">BrainChat</div>
<div class="hero-subtitle">
Interactive neurology and neuroanatomy tutor based on your uploaded books
</div>
</div>
</div>
"""
296
+
297
+
# =====================================================
# MAIN CHAT FUNCTION
# =====================================================
def answer_question(message, history, mode, language_mode, quiz_count_mode, show_sources, quiz_state):
    """Handle one chat turn and return (history, quiz_state, "").

    The trailing empty string clears the input textbox. There are three
    branches, in priority order:
      1. quiz_state["active"] — the previous turn asked a quiz, so this
         message is treated as the student's answers and graded;
      2. mode == "Quiz Me" — generate a quiz on the message's topic and
         arm quiz_state so the NEXT message is graded;
      3. otherwise — normal retrieval-augmented tutoring.

    Any exception is caught at this UI boundary and surfaced as an
    assistant "Error: ..." message instead of crashing the app.
    """
    # Normalize inputs: gr.State / chat history may arrive as None.
    if history is None:
        history = []
    if quiz_state is None:
        quiz_state = {
            "active": False,
            "topic": None,
            "quiz_data": None,
            "language_mode": "Auto"
        }

    # Ignore empty/whitespace-only submissions without touching state.
    if not message or not message.strip():
        return history, quiz_state, ""

    try:
        ensure_loaded()
        user_text = message.strip()

        # Add user message first
        history.append({"role": "user", "content": user_text})

        # -----------------------------------------
        # QUIZ EVALUATION STEP
        # -----------------------------------------
        # A quiz was asked on the previous turn: grade this message against
        # the stored quiz data, using the language chosen when the quiz was
        # generated (not the current dropdown value).
        if quiz_state.get("active", False):
            evaluation_prompt = build_quiz_evaluation_prompt(
                quiz_state["language_mode"],
                quiz_state["quiz_data"],
                user_text
            )
            evaluation = chat_json(evaluation_prompt)

            # Render the structured evaluation as markdown.
            lines = []
            lines.append(f"**Score:** {evaluation['score_obtained']}/{evaluation['score_total']}")
            lines.append("")
            lines.append(f"**Overall feedback:** {evaluation['summary']}")
            lines.append("")
            lines.append("**Question-wise evaluation:**")

            for item in evaluation["results"]:
                lines.append("")
                lines.append(f"**Q:** {item['question']}")
                lines.append(f"**Your answer:** {item['student_answer']}")
                lines.append(f"**Result:** {item['result']}")
                lines.append(f"**Feedback:** {item['feedback']}")

            final_answer = "\n".join(lines)
            history.append({"role": "assistant", "content": final_answer})

            # Quiz round is complete: disarm the state for the next turn.
            quiz_state = {
                "active": False,
                "topic": None,
                "quiz_data": None,
                "language_mode": language_mode
            }

            return history, quiz_state, ""

        # -----------------------------------------
        # NORMAL RETRIEVAL
        # -----------------------------------------
        records = search_hybrid(user_text, shortlist_k=20, final_k=3)
        context = build_context(records)

        # -----------------------------------------
        # QUIZ GENERATION
        # -----------------------------------------
        if mode == "Quiz Me":
            n_questions = choose_quiz_count(user_text, quiz_count_mode)
            prompt = build_quiz_generation_prompt(language_mode, user_text, context, n_questions)
            quiz_data = chat_json(prompt)

            lines = []
            lines.append(f"**{quiz_data.get('title', 'Quiz')}**")
            lines.append("")
            lines.append("Please answer the following questions in one message.")
            lines.append("You can reply in numbered format, for example:")
            lines.append("1. ...")
            lines.append("2. ...")
            lines.append("")
            lines.append(f"**Total questions: {len(quiz_data['questions'])}**")
            lines.append("")

            for i, q in enumerate(quiz_data["questions"], start=1):
                lines.append(f"**Q{i}.** {q['q']}")

            if show_sources:
                lines.append("\n---\n**Topic sources used to create the quiz:**")
                lines.append(make_sources(records))

            assistant_text = "\n".join(lines)
            history.append({"role": "assistant", "content": assistant_text})

            # Arm the state so the student's NEXT message is evaluated.
            quiz_state = {
                "active": True,
                "topic": user_text,
                "quiz_data": quiz_data,
                "language_mode": language_mode
            }

            return history, quiz_state, ""

        # -----------------------------------------
        # OTHER MODES
        # -----------------------------------------
        prompt = build_tutor_prompt(mode, language_mode, user_text, context)
        answer = chat_text(prompt)

        if show_sources:
            answer += "\n\n---\n**Sources used:**\n" + make_sources(records)

        history.append({"role": "assistant", "content": answer})
        return history, quiz_state, ""

    except Exception as e:
        # UI boundary: show the error in chat and disarm any pending quiz so
        # the next message is not accidentally graded.
        history.append({"role": "assistant", "content": f"Error: {str(e)}"})
        quiz_state["active"] = False
        return history, quiz_state, ""
419
+
420
+
def clear_all():
    """Reset the app: empty chat history, fresh quiz state, cleared textbox."""
    fresh_state = {
        "active": False,
        "topic": None,
        "quiz_data": None,
        "language_mode": "Auto",
    }
    return [], fresh_state, ""
429
+
430
+
# Page-level styling injected into gr.Blocks(css=...): grey page background,
# gradient hero card for the header (classes used by render_header), and a
# hidden Gradio footer.
CSS = """
body, .gradio-container {
    background: #dcdcdc !important;
    font-family: Arial, Helvetica, sans-serif !important;
}
footer { display: none !important; }
.hero-card {
    max-width: 860px;
    margin: 18px auto 14px auto;
    border-radius: 28px;
    background: linear-gradient(180deg, #e8c7d4 0%, #a55ca2 48%, #2b0c46 100%);
    padding: 22px 22px 18px 22px;
}
.hero-inner { text-align: center; }
.hero-title {
    color: white;
    font-size: 34px;
    font-weight: 800;
    margin-top: 6px;
}
.hero-subtitle {
    color: white;
    opacity: 0.92;
    font-size: 16px;
    margin-top: 6px;
}
"""
458
+
459
+
with gr.Blocks(css=CSS) as demo:
    # Per-session quiz state (answer_question arms/disarms it between turns).
    quiz_state = gr.State({
        "active": False,
        "topic": None,
        "quiz_data": None,
        "language_mode": "Auto"
    })

    gr.HTML(render_header())

    # Row 1: tutoring mode and answer language.
    with gr.Row():
        mode = gr.Dropdown(
            choices=["Explain", "Detailed", "Short Notes", "Quiz Me", "Flashcards", "Case-Based"],
            value="Explain",
            label="Tutor Mode"
        )
        language_mode = gr.Dropdown(
            choices=["Auto", "English", "Spanish", "Bilingual"],
            value="Auto",
            label="Answer Language"
        )

    # Row 2: quiz length selector and source-citation toggle.
    with gr.Row():
        quiz_count_mode = gr.Dropdown(
            choices=["Auto", "3", "5", "7"],
            value="Auto",
            label="Quiz Questions"
        )
        show_sources = gr.Checkbox(value=True, label="Show Sources")

    gr.Markdown("""
**How to use**
- Choose a **Tutor Mode**
- Then type a topic or question
- For **Quiz Me**, type a topic such as: `cranial nerves`
- The system will ask questions, and your **next message will be evaluated automatically**
""")

    # type="messages" keeps history as [{"role": ..., "content": ...}] dicts,
    # matching what answer_question appends.
    chatbot = gr.Chatbot(height=520, type="messages")
    msg = gr.Textbox(
        placeholder="Ask a question or type a topic...",
        lines=1,
        show_label=False
    )

    with gr.Row():
        send_btn = gr.Button("Send")
        clear_btn = gr.Button("Clear Chat")

    # Enter in the textbox and the Send button run the same handler; the
    # third output ("") clears the textbox after each turn.
    msg.submit(
        answer_question,
        inputs=[msg, chatbot, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[chatbot, quiz_state, msg]
    )

    send_btn.click(
        answer_question,
        inputs=[msg, chatbot, mode, language_mode, quiz_count_mode, show_sources, quiz_state],
        outputs=[chatbot, quiz_state, msg]
    )

    clear_btn.click(
        clear_all,
        inputs=[],
        outputs=[chatbot, quiz_state, msg]
    )
526
+
if __name__ == "__main__":
    # Start the Gradio server (entry point when the Space runs app.py).
    demo.launch()