Abhinav Biju committed on
Commit
cc2dc62
·
1 Parent(s): 182e0fa

Add fast/thinking toggle (Fast/Reasoning RAG mode) to chat

Browse files
app.py CHANGED
@@ -382,20 +382,28 @@ def ingest_url_ui(
382
  def send_chat_ui(
383
  notebook_id: str | None,
384
  question: str,
 
385
  history: list[dict[str, str]] | None,
386
  current_username: str,
387
  profile: gr.OAuthProfile | None,
388
  request: gr.Request,
389
- ) -> tuple[list[dict[str, str]], str, str]:
390
  """Send one chat question and append the grounded answer to the chat history."""
391
 
392
  username: str = _resolve_username(profile, request, current_username)
393
  if not notebook_id:
394
- raise gr.Error("Select a notebook before asking a question.")
395
  if not question or not question.strip():
396
- raise gr.Error("Enter a question before sending.")
 
 
 
 
 
 
 
 
397
 
398
- response: ChatResponse = answer_question(username, notebook_id, question.strip())
399
  updated_history: list[dict[str, str]] = list(history or [])
400
  updated_history.append({"role": "user", "content": question.strip()})
401
  updated_history.append(
@@ -404,7 +412,7 @@ def send_chat_ui(
404
  "content": response["content"] + _render_citations(response["citations"]),
405
  }
406
  )
407
- return updated_history, "", f"Question answered with {len(response['citations'])} citations."
408
 
409
 
410
  def _append_artifact_path(current_paths: list[str] | None, artifact: ArtifactRef) -> tuple[list[str], gr.Dropdown]:
@@ -528,9 +536,24 @@ with gr.Blocks(title="NotebookLM Clone") as demo:
528
 
529
  with gr.Column():
530
  gr.Markdown("## Chat")
531
- chat_history = gr.Chatbot(label="Grounded Chat")
532
- question_input = gr.Textbox(label="Question", placeholder="Ask about this notebook")
533
- ask_button = gr.Button("Ask")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
534
 
535
  with gr.Column():
536
  gr.Markdown("## Artifacts")
@@ -580,10 +603,15 @@ with gr.Blocks(title="NotebookLM Clone") as demo:
580
  outputs=[ingest_status, uploaded_docs_state, uploaded_docs_display],
581
  )
582
 
583
- ask_button.click(
 
 
 
 
 
584
  send_chat_ui,
585
- inputs=[notebook_dropdown, question_input, chat_history, username_state],
586
- outputs=[chat_history, question_input, activity_status],
587
  )
588
 
589
  report_button.click(
 
382
  def send_chat_ui(
383
  notebook_id: str | None,
384
  question: str,
385
+ rag_mode: str,
386
  history: list[dict[str, str]] | None,
387
  current_username: str,
388
  profile: gr.OAuthProfile | None,
389
  request: gr.Request,
390
+ ) -> tuple[str, list[dict[str, str]]]:
391
  """Send one chat question and append the grounded answer to the chat history."""
392
 
393
  username: str = _resolve_username(profile, request, current_username)
394
  if not notebook_id:
395
+ raise gr.Error("Select a notebook before sending a message.")
396
  if not question or not question.strip():
397
+ raise gr.Error("Message cannot be empty.")
398
+
399
+ chat_history: list[dict[str, str]] = history or []
400
+ try:
401
+ response: ChatResponse = answer_question(username, notebook_id, question.strip(), rag_mode)
402
+ except Exception as e:
403
+ chat_history.append({"role": "user", "content": question.strip()})
404
+ chat_history.append({"role": "assistant", "content": f"Error: {e}"})
405
+ return "", chat_history
406
 
 
407
  updated_history: list[dict[str, str]] = list(history or [])
408
  updated_history.append({"role": "user", "content": question.strip()})
409
  updated_history.append(
 
412
  "content": response["content"] + _render_citations(response["citations"]),
413
  }
414
  )
415
+ return "", updated_history
416
 
417
 
418
  def _append_artifact_path(current_paths: list[str] | None, artifact: ArtifactRef) -> tuple[list[str], gr.Dropdown]:
 
536
 
537
  with gr.Column():
538
  gr.Markdown("## Chat")
539
+ chat_history = gr.Chatbot(
540
+ elem_id="chat-history",
541
+ show_label=False,
542
+ )
543
+ with gr.Row():
544
+ chat_input = gr.Textbox(
545
+ show_label=False,
546
+ placeholder="Ask a question about your sources...",
547
+ scale=4,
548
+ )
549
+ rag_mode = gr.Radio(
550
+ choices=["Fast", "Reasoning"],
551
+ value="Reasoning",
552
+ label="RAG Mode",
553
+ scale=1,
554
+ interactive=True,
555
+ )
556
+ chat_submit = gr.Button("Send", variant="primary")
557
 
558
  with gr.Column():
559
  gr.Markdown("## Artifacts")
 
603
  outputs=[ingest_status, uploaded_docs_state, uploaded_docs_display],
604
  )
605
 
606
+ chat_submit.click(
607
+ send_chat_ui,
608
+ inputs=[notebook_dropdown, chat_input, rag_mode, chat_history, username_state],
609
+ outputs=[chat_input, chat_history],
610
+ )
611
+ chat_input.submit(
612
  send_chat_ui,
613
+ inputs=[notebook_dropdown, chat_input, rag_mode, chat_history, username_state],
614
+ outputs=[chat_input, chat_history],
615
  )
616
 
617
  report_button.click(
src/notebooklm_clone/chat.py CHANGED
@@ -240,7 +240,7 @@ def _generate_answer(question: str, context: str) -> str:
240
  raise ChatGenerationError("Chat model returned an empty response.")
241
 
242
 
243
- def answer_question(username: str, notebook_id: str, question: str) -> ChatResponse:
244
  """Answer a notebook question using retrieved chunks and inline citations.
245
 
246
  Spec references:
@@ -270,6 +270,7 @@ def answer_question(username: str, notebook_id: str, question: str) -> ChatRespo
270
  notebook_id=notebook_id,
271
  query=normalized_question,
272
  k=_RETRIEVAL_K,
 
273
  )
274
 
275
  if not retrieved_chunks:
 
240
  raise ChatGenerationError("Chat model returned an empty response.")
241
 
242
 
243
+ def answer_question(username: str, notebook_id: str, question: str, rag_mode: str = "Reasoning") -> ChatResponse:
244
  """Answer a notebook question using retrieved chunks and inline citations.
245
 
246
  Spec references:
 
270
  notebook_id=notebook_id,
271
  query=normalized_question,
272
  k=_RETRIEVAL_K,
273
+ rag_mode=rag_mode,
274
  )
275
 
276
  if not retrieved_chunks:
src/notebooklm_clone/retrieval.py CHANGED
@@ -418,6 +418,7 @@ def retrieve(
418
  notebook_id: str,
419
  query: str,
420
  k: int,
 
421
  ) -> list[RetrievalResult]:
422
  """Retrieve top notebook chunks with hybrid scoring, query expansion, and reranking.
423
 
@@ -459,7 +460,7 @@ def retrieve(
459
  }
460
 
461
  # Query expansion: generate alt phrasings and merge scores
462
- queries: list[str] = _expand_query(query)
463
  bm25_raw, vector_raw = _multi_query_scores(
464
  chunk_documents, collection, queries, len(ids)
465
  )
@@ -515,10 +516,13 @@ def retrieve(
515
 
516
  ranked_results.sort(key=lambda item: (-item["score"], item["chunk_id"]))
517
 
518
- # Rerank only top-N candidates to control latency (default: 10)
519
- _rerank_n: int = int(os.getenv("NOTEBOOKLM_RERANK_TOP_N", "10"))
520
- rerank_pool: list[RetrievalResult] = ranked_results[:_rerank_n]
521
- result: list[RetrievalResult] = _rerank(query, rerank_pool, k)
 
 
 
522
 
523
  _log_retrieval(username, notebook_id, "success", started_at)
524
  return result
 
418
  notebook_id: str,
419
  query: str,
420
  k: int,
421
+ rag_mode: str = "Reasoning",
422
  ) -> list[RetrievalResult]:
423
  """Retrieve top notebook chunks with hybrid scoring, query expansion, and reranking.
424
 
 
460
  }
461
 
462
  # Query expansion: generate alt phrasings and merge scores
463
+ queries: list[str] = _expand_query(query) if rag_mode == "Reasoning" else [query]
464
  bm25_raw, vector_raw = _multi_query_scores(
465
  chunk_documents, collection, queries, len(ids)
466
  )
 
516
 
517
  ranked_results.sort(key=lambda item: (-item["score"], item["chunk_id"]))
518
 
519
+ if rag_mode == "Fast":
520
+ result: list[RetrievalResult] = ranked_results[:k]
521
+ else:
522
+ # Rerank only top-N candidates to control latency (default: 10)
523
+ _rerank_n: int = int(os.getenv("NOTEBOOKLM_RERANK_TOP_N", "10"))
524
+ rerank_pool: list[RetrievalResult] = ranked_results[:_rerank_n]
525
+ result: list[RetrievalResult] = _rerank(query, rerank_pool, k)
526
 
527
  _log_retrieval(username, notebook_id, "success", started_at)
528
  return result