arjun10g committed on
Commit
f623ec4
·
verified ·
1 Parent(s): 30d4760

Deploy BankMind

Browse files
Files changed (4) hide show
  1. app.py +6 -2
  2. app/main.py +173 -47
  3. app/query_pipeline.py +172 -3
  4. pipelines/shared/llm.py +32 -0
app.py CHANGED
@@ -15,14 +15,18 @@ sys.path.insert(0, str(ROOT))
15
 
16
  import gradio as gr
17
 
18
- from app.main import build_app
19
 
20
 
21
  if __name__ == "__main__":
22
  demo = build_app()
 
 
 
23
  demo.launch(
24
  server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
25
  server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
26
  show_error=True,
27
- theme=gr.themes.Soft(),
 
28
  )
 
15
 
16
  import gradio as gr
17
 
18
+ from app.main import _BANKY_CSS, _BANKY_THEME, build_app
19
 
20
 
21
  if __name__ == "__main__":
      # Script entry point: build the Blocks app, enable queuing, then serve it.
      demo = build_app()
      # Explicit queue so streaming events + multiple clients don't block each
      # other (and tab switches don't freeze when a chat turn is in flight).
      demo.queue(default_concurrency_limit=4, max_size=16)
      demo.launch(
          # Bind address and port come from the environment; the defaults are
          # all interfaces ("0.0.0.0") on port 7860.
          server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
          server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
          show_error=True,
          # NOTE(review): theme/css are handed to launch() rather than to
          # gr.Blocks(); confirm the pinned Gradio version accepts them here.
          theme=_BANKY_THEME,
          css=_BANKY_CSS,
      )
app/main.py CHANGED
@@ -30,7 +30,7 @@ from app.charts import (
30
  retrieval_stage_2_figure,
31
  retrieval_stage_3_figure,
32
  )
33
- from app.query_pipeline import run_query
34
 
35
 
36
  COMPLIANCE_STRATEGIES = ["regulatory_boundary", "semantic", "hierarchical"]
@@ -253,55 +253,97 @@ def _build_qa_tab(module: str, strategies: list[str], default_strategy: str):
253
  # ---- Chat handlers ----------------------------------------------------
254
 
255
  def _on_send(user_msg, hist_pairs, strat, d, m, r, t, k, fk, gen, mx, chat_msgs):
256
- """Append user msg, run pipeline, append assistant msg.
 
257
 
258
- gr.Chatbot defaults to the messages format on modern Gradio: a list of
259
- {"role": "user"/"assistant", "content": "..."} dicts. hist_pairs is our
260
- parallel (user, assistant) tuple list used by the LLM rewriter.
261
  """
262
  user_msg = (user_msg or "").strip()
263
  if not user_msg:
264
- return (chat_msgs or []), hist_pairs, gr.update(), "", "", "_(empty input)_", "_(no chunks)_", None
265
-
266
- result = run_query(
267
- query=user_msg,
268
- module=module,
269
- chunk_strategy=strat,
270
- embedding_dim=int(d),
271
- retrieval_method=m,
272
- reranker=r,
273
- query_transform=t,
274
- top_k=int(k),
275
- final_k=int(fk),
276
- generate_answer=bool(gen),
277
- chat_history=hist_pairs or [],
278
- max_answer_tokens=int(mx),
279
- )
280
-
281
- assistant_msg = result.answer if result.answer else "_(generation is off in pipeline configuration)_"
282
-
 
 
 
 
 
 
 
 
 
 
 
 
283
  new_chat_msgs = (chat_msgs or []) + [
284
  {"role": "user", "content": user_msg},
285
- {"role": "assistant", "content": assistant_msg},
286
  ]
287
- new_pairs = (hist_pairs or []) + [(user_msg, assistant_msg)]
288
-
289
- config_line = (
290
- f"`{result.config_summary}` · query_id=`{result.query_id or '—'}`"
291
- )
292
- if result.rewritten_query and result.rewritten_query != user_msg:
293
- config_line += f"\n\n_Follow-up rewritten as:_ `{result.rewritten_query}`"
294
-
295
- return (
296
- new_chat_msgs,
297
- new_pairs,
298
- gr.update(value=""), # clear input
299
- _format_timings(result.timings),
300
- config_line,
301
- _format_guardrails(result.guardrail_report),
302
- _format_chunks(result.chunks),
303
- result,
304
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
305
 
306
  def _on_clear():
307
  return [], [], None, "", "", "_(send a message to see guardrails)_", "_(send a message to see retrieved passages)_"
@@ -349,11 +391,94 @@ def _build_perf_tab(module: str):
349
  # Build the app
350
  # =============================================================================
351
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
352
  def build_app() -> gr.Blocks:
353
  with gr.Blocks(title="BankMind") as demo:
354
- gr.Markdown(
355
- "# 🏦 BankMind\n"
356
- "_Multi-domain RAG for financial intelligence: compliance and credit._"
 
 
 
 
357
  )
358
  with gr.Tabs():
359
  with gr.Tab("⚖️ Compliance Q&A"):
@@ -392,5 +517,6 @@ def build_app() -> gr.Blocks:
392
 
393
  if __name__ == "__main__":
394
  app = build_app()
 
395
  app.launch(server_name="127.0.0.1", server_port=7860, show_error=True,
396
- theme=gr.themes.Soft())
 
30
  retrieval_stage_2_figure,
31
  retrieval_stage_3_figure,
32
  )
33
+ from app.query_pipeline import run_query, run_query_stream
34
 
35
 
36
  COMPLIANCE_STRATEGIES = ["regulatory_boundary", "semantic", "hierarchical"]
 
253
  # ---- Chat handlers ----------------------------------------------------
254
 
255
  def _on_send(user_msg, hist_pairs, strat, d, m, r, t, k, fk, gen, mx, chat_msgs):
256
+ """Streaming chat handler. Yields progressive Gradio updates as tokens
257
+ stream in, so the UI never freezes during generation.
258
 
259
+ When `gen` is False (Generate answer toggled off), runs sync retrieve
260
+ only no LLM call, free.
 
261
  """
262
  user_msg = (user_msg or "").strip()
263
  if not user_msg:
264
+ yield (chat_msgs or [], hist_pairs, gr.update(), "", "",
265
+ "_(empty input)_", "_(no chunks)_", None)
266
+ return
267
+
268
+ # Free path: no generation. Falls back to the non-streaming run_query.
269
+ if not gen:
270
+ result = run_query(
271
+ query=user_msg, module=module, chunk_strategy=strat,
272
+ embedding_dim=int(d), retrieval_method=m, reranker=r,
273
+ query_transform=t, top_k=int(k), final_k=int(fk),
274
+ generate_answer=False, chat_history=hist_pairs or [],
275
+ max_answer_tokens=int(mx),
276
+ )
277
+ assistant_msg = "_(generation is off in pipeline configuration)_"
278
+ new_chat_msgs = (chat_msgs or []) + [
279
+ {"role": "user", "content": user_msg},
280
+ {"role": "assistant", "content": assistant_msg},
281
+ ]
282
+ new_pairs = (hist_pairs or []) + [(user_msg, assistant_msg)]
283
+ cfg = f"`{result.config_summary}` · query_id=`{result.query_id or '—'}`"
284
+ if result.rewritten_query and result.rewritten_query != user_msg:
285
+ cfg += f"\n\n_Follow-up rewritten as:_ `{result.rewritten_query}`"
286
+ yield (new_chat_msgs, new_pairs, gr.update(value=""),
287
+ _format_timings(result.timings), cfg,
288
+ _format_guardrails(result.guardrail_report),
289
+ _format_chunks(result.chunks), result)
290
+ return
291
+
292
+ # Streaming path. Show user message + an empty assistant placeholder
293
+ # immediately, then progressively fill the assistant message as tokens
294
+ # arrive.
295
  new_chat_msgs = (chat_msgs or []) + [
296
  {"role": "user", "content": user_msg},
297
+ {"role": "assistant", "content": "…"},
298
  ]
299
+ # Initial yield: clear input, lock in the user message, show "thinking".
300
+ yield (new_chat_msgs, hist_pairs or [], gr.update(value=""),
301
+ "_…retrieving…_", "", "_(running guardrails after generation)_",
302
+ "_(loading)_", None)
303
+
304
+ last_setup_result = None
305
+ accumulated = ""
306
+ for event_type, payload in run_query_stream(
307
+ query=user_msg, module=module, chunk_strategy=strat,
308
+ embedding_dim=int(d), retrieval_method=m, reranker=r,
309
+ query_transform=t, top_k=int(k), final_k=int(fk),
310
+ chat_history=hist_pairs or [], max_answer_tokens=int(mx),
311
+ ):
312
+ if event_type == "setup":
313
+ last_setup_result = payload
314
+ cfg = (
315
+ f"`{payload.config_summary}`"
316
+ + (f"\n\n_Follow-up rewritten as:_ `{payload.rewritten_query}`"
317
+ if payload.rewritten_query and payload.rewritten_query != user_msg
318
+ else "")
319
+ )
320
+ yield (new_chat_msgs, hist_pairs or [], gr.update(),
321
+ _format_timings(payload.timings),
322
+ cfg,
323
+ "_(running guardrails after generation)_",
324
+ _format_chunks(payload.chunks),
325
+ None)
326
+ elif event_type == "token":
327
+ accumulated = payload
328
+ # Update the LAST assistant message in place
329
+ new_chat_msgs[-1] = {"role": "assistant", "content": accumulated or "…"}
330
+ yield (new_chat_msgs, hist_pairs or [], gr.update(),
331
+ gr.update(), gr.update(), gr.update(), gr.update(), None)
332
+ elif event_type == "done":
333
+ final_result = payload
334
+ final_answer = final_result.answer or accumulated or "_(no answer)_"
335
+ new_chat_msgs[-1] = {"role": "assistant", "content": final_answer}
336
+ new_pairs = (hist_pairs or []) + [(user_msg, final_answer)]
337
+ cfg = f"`{final_result.config_summary}` · query_id=`{final_result.query_id or '—'}`"
338
+ if final_result.rewritten_query and final_result.rewritten_query != user_msg:
339
+ cfg += f"\n\n_Follow-up rewritten as:_ `{final_result.rewritten_query}`"
340
+ yield (new_chat_msgs, new_pairs, gr.update(),
341
+ _format_timings(final_result.timings),
342
+ cfg,
343
+ _format_guardrails(final_result.guardrail_report),
344
+ _format_chunks(final_result.chunks),
345
+ final_result)
346
+ return
347
 
348
  def _on_clear():
      """Return the reset values used when the chat is cleared.

      Seven values: two empty lists, ``None``, two empty strings, and the two
      placeholder markdown strings for the guardrails and passages panels.
      NOTE(review): the mapping to output components is wired outside this
      view; confirm the order against the `.click(...)` registration.
      """
      return [], [], None, "", "", "_(send a message to see guardrails)_", "_(send a message to see retrieved passages)_"
 
391
  # Build the app
392
  # =============================================================================
393
 
394
# Application theme: gr.themes.Base with amber/slate hues and explicit colour
# overrides. Every override is given twice (plain and `_dark`) with the same
# value, so both variants are configured identically.
_BANKY_THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.amber,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
).set(
    # Page and panel backgrounds / text.
    body_background_fill="#0b1220",
    body_background_fill_dark="#0b1220",
    body_text_color="#e5e7eb",
    body_text_color_dark="#e5e7eb",
    background_fill_primary="#111827",
    background_fill_primary_dark="#111827",
    background_fill_secondary="#0f172a",
    background_fill_secondary_dark="#0f172a",
    # Blocks and their labels.
    block_background_fill="#0f172a",
    block_background_fill_dark="#0f172a",
    block_border_color="#1f2937",
    block_border_color_dark="#1f2937",
    block_label_background_fill="#0b1220",
    block_label_background_fill_dark="#0b1220",
    block_label_text_color="#fbbf24",
    block_label_text_color_dark="#fbbf24",
    # Borders.
    border_color_accent="#fbbf24",
    border_color_accent_dark="#fbbf24",
    border_color_primary="#1f2937",
    border_color_primary_dark="#1f2937",
    # Buttons.
    button_primary_background_fill="#fbbf24",
    button_primary_background_fill_dark="#fbbf24",
    button_primary_background_fill_hover="#f59e0b",
    button_primary_background_fill_hover_dark="#f59e0b",
    button_primary_text_color="#0b1220",
    button_primary_text_color_dark="#0b1220",
    button_secondary_background_fill="#1f2937",
    button_secondary_background_fill_dark="#1f2937",
    button_secondary_text_color="#e5e7eb",
    button_secondary_text_color_dark="#e5e7eb",
    # Inputs.
    input_background_fill="#0b1220",
    input_background_fill_dark="#0b1220",
    input_border_color="#1f2937",
    input_border_color_dark="#1f2937",
    input_border_color_focus="#fbbf24",
    input_border_color_focus_dark="#fbbf24",
    color_accent_soft="#1f2937",
    color_accent_soft_dark="#1f2937",
)
440
+
441
+
442
+ _BANKY_CSS = """
443
+ .gradio-container { max-width: 1280px !important; }
444
+ h1, h2, h3, h4 { letter-spacing: -0.01em; font-weight: 600; }
445
+ h1 { font-size: 1.75rem !important; }
446
+ .tabitem { padding-top: 0.5rem; }
447
+
448
+ /* Tighter accordion headers */
449
+ .label-wrap > span { font-weight: 600; letter-spacing: 0.01em; }
450
+
451
+ /* Markdown body text */
452
+ .prose { line-height: 1.55; }
453
+ .prose code, .prose pre { background: #0b1220 !important; border: 1px solid #1f2937; border-radius: 6px; }
454
+ .prose table { font-size: 0.92em; }
455
+ .prose th { background: #0b1220 !important; color: #fbbf24 !important; font-weight: 600; }
456
+ .prose td { border-color: #1f2937 !important; }
457
+
458
+ /* Chatbot polish */
459
+ .message.user { background: #1e293b !important; }
460
+ .message.bot, .message.assistant { background: #0f172a !important; border: 1px solid #1f2937; }
461
+
462
+ /* Subtle gold accent on the title bar */
463
+ #title-banner {
464
+ border-left: 3px solid #fbbf24;
465
+ padding-left: 0.85rem;
466
+ margin: 0.25rem 0 1rem 0;
467
+ }
468
+ #title-banner h1 { margin: 0; font-size: 1.5rem !important; }
469
+ #title-banner .tagline { color: #94a3b8; font-size: 0.95rem; margin-top: 0.15rem; }
470
+ """
471
+
472
+
473
  def build_app() -> gr.Blocks:
474
  with gr.Blocks(title="BankMind") as demo:
475
+ gr.HTML(
476
+ """
477
+ <div id="title-banner">
478
+ <h1>🏦 BankMind</h1>
479
+ <div class="tagline">Multi-domain RAG for financial intelligence: regulatory compliance and credit risk.</div>
480
+ </div>
481
+ """
482
  )
483
  with gr.Tabs():
484
  with gr.Tab("⚖️ Compliance Q&A"):
 
517
 
518
  if __name__ == "__main__":
      # Local entry point: serves on 127.0.0.1:7860 with queuing enabled.
      app = build_app()
      app.queue(default_concurrency_limit=4, max_size=16)
      # NOTE(review): theme/css are handed to launch() rather than to
      # gr.Blocks(); confirm the pinned Gradio version accepts them here.
      app.launch(server_name="127.0.0.1", server_port=7860, show_error=True,
                 theme=_BANKY_THEME, css=_BANKY_CSS)
app/query_pipeline.py CHANGED
@@ -19,7 +19,7 @@ from typing import Optional
19
 
20
  from pipelines.shared.fusion import convex_combination, hierarchical, rrf
21
  from pipelines.shared.guardrails import GuardrailReport, check as run_guardrails
22
- from pipelines.shared.llm import claude_text
23
  from pipelines.shared.query_logger import chunk_for_log, log_query
24
  from pipelines.shared.query_transformer import apply_transform
25
  from pipelines.shared.reranker import rerank
@@ -123,7 +123,9 @@ def _format_history(history: list[tuple[str, str]], *, max_turns: int = 6) -> st
123
  def _rewrite_followup(query: str, history: list[tuple[str, str]]) -> str:
124
  """Rewrite a possibly-ambiguous follow-up into a standalone query.
125
 
126
- No-op when history is empty. On any LLM error, returns the original query.
 
 
127
  """
128
  if not history:
129
  return query
@@ -131,10 +133,10 @@ def _rewrite_followup(query: str, history: list[tuple[str, str]]) -> str:
131
  rewritten = claude_text(
132
  _REWRITE_USER.format(history=_format_history(history), query=query),
133
  system=_REWRITE_SYSTEM,
 
134
  max_tokens=200,
135
  temperature=0.0,
136
  )
137
- # Defensive: strip surrounding quotes / leading/trailing whitespace
138
  rewritten = rewritten.strip().strip('"').strip("'").strip()
139
  return rewritten or query
140
  except Exception:
@@ -305,3 +307,170 @@ def run_query(
305
  query_id=qid,
306
  rewritten_query=rewritten_query,
307
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
19
 
20
  from pipelines.shared.fusion import convex_combination, hierarchical, rrf
21
  from pipelines.shared.guardrails import GuardrailReport, check as run_guardrails
22
+ from pipelines.shared.llm import FAST_MODEL, claude_text, claude_text_stream
23
  from pipelines.shared.query_logger import chunk_for_log, log_query
24
  from pipelines.shared.query_transformer import apply_transform
25
  from pipelines.shared.reranker import rerank
 
123
  def _rewrite_followup(query: str, history: list[tuple[str, str]]) -> str:
124
  """Rewrite a possibly-ambiguous follow-up into a standalone query.
125
 
126
+ Uses the fast/cheap model (Haiku by default) since paraphrasing is a
127
+ utility task and Sonnet is overkill. No-op when history is empty.
128
+ On any LLM error, returns the original query.
129
  """
130
  if not history:
131
  return query
 
133
  rewritten = claude_text(
134
  _REWRITE_USER.format(history=_format_history(history), query=query),
135
  system=_REWRITE_SYSTEM,
136
+ model=FAST_MODEL,
137
  max_tokens=200,
138
  temperature=0.0,
139
  )
 
140
  rewritten = rewritten.strip().strip('"').strip("'").strip()
141
  return rewritten or query
142
  except Exception:
 
307
  query_id=qid,
308
  rewritten_query=rewritten_query,
309
  )
310
+
311
+
312
+ # === Streaming variant for the chat UI ======================================
313
+ # Yields three kinds of events:
314
+ # ("setup", QueryResult) -> retrieval + rerank done, generation starts
315
+ # ("token", str) -> cumulative answer text (after each delta)
316
+ # ("done", QueryResult) -> final guardrail-checked + logged result
317
+ # This lets the UI render the chunks panel and timings as soon as retrieval is
318
+ # done, then stream tokens into the chat, then apply guardrails when generation
319
+ # completes.
320
+
321
def run_query_stream(
    *,
    query: str,
    module: str,
    chunk_strategy: str,
    embedding_dim: int,
    retrieval_method: str,
    reranker: str,
    query_transform: str,
    top_k: int = 10,
    final_k: int = 5,
    chat_history: Optional[list[tuple[str, str]]] = None,
    max_answer_tokens: int = 900,
):
    """Streaming variant of the query pipeline; a generator of (event, payload).

    Events, in order:
      ("setup", QueryResult) -- once, after rewrite/transform/retrieve/rerank;
                                `answer` is None and `timings` is a snapshot.
      ("token", str)         -- repeatedly; payload is the cumulative answer
                                text so far, not a delta.
      ("done", QueryResult)  -- once, after guardrails and logging; carries the
                                final answer, guardrail report and query id.

    A blank `query` yields a single ("done", ...) with an empty result and
    returns without touching the retriever.

    Raises:
        ValueError: if `retrieval_method` is not one of the handled names.
    """
    if not query.strip():
        yield ("done", QueryResult(answer=None, chunks=[], timings={},
                                   transformed_queries=[], config_summary="(empty query)"))
        return

    timings: dict[str, float] = {}
    retr = _retriever()
    history = chat_history or []

    # 1) Follow-up rewrite (Haiku, fast)
    # Only attempted when there is prior history; otherwise retrieval uses the
    # query as typed and `rewritten_query` stays None.
    rewritten_query = None
    retrieval_query = query
    if history:
        t0 = time.perf_counter()
        rewritten_query = _rewrite_followup(query, history)
        retrieval_query = rewritten_query
        timings["rewrite_ms"] = (time.perf_counter() - t0) * 1000

    # 2) Optional pre-retrieval transform
    # Any failure degrades to a single-query "none-fallback" transform with
    # the error text kept in `extras`.
    t0 = time.perf_counter()
    try:
        tr = apply_transform(
            query_transform, retrieval_query, module=module,
            retriever=retr, chunk_strategy=chunk_strategy,
            embedding_dim=embedding_dim,
        )
    except Exception as e:
        from pipelines.shared.query_transformer import TransformResult
        tr = TransformResult(queries=[retrieval_query], transform_name="none-fallback",
                             extras={"error": str(e)})
    timings["transform_ms"] = (time.perf_counter() - t0) * 1000

    def _retrieve_one(q: str) -> list[ScoredChunk]:
        # Dispatch a single query string to the retriever according to
        # `retrieval_method`; the two custom-fusion hybrids fetch 50 candidates
        # per channel before fusing down to `top_k`.
        if retrieval_method == "dense":
            return retr.search(query=q, module=module, chunk_strategy=chunk_strategy,
                               mode="dense", embedding_dim=embedding_dim, top_k=top_k)
        if retrieval_method in ("bm25", "splade"):
            return retr.search(query=q, module=module, chunk_strategy=chunk_strategy,
                               mode="sparse", sparse_name=retrieval_method, top_k=top_k)
        if retrieval_method == "hybrid_rrf":
            return retr.search(query=q, module=module, chunk_strategy=chunk_strategy,
                               mode="hybrid", embedding_dim=embedding_dim, top_k=top_k)
        if retrieval_method == "hybrid_convex":
            d, s, _ = retr.search_separate_channels(
                query=q, module=module, chunk_strategy=chunk_strategy,
                embedding_dim=embedding_dim, top_k=50,
            )
            return convex_combination(d, s, alpha=0.7, top_k=top_k)
        if retrieval_method == "hybrid_hier":
            d, s, _ = retr.search_separate_channels(
                query=q, module=module, chunk_strategy=chunk_strategy,
                embedding_dim=embedding_dim, top_k=50,
            )
            return hierarchical(q, d, s, top_k=top_k)
        raise ValueError(f"unknown retrieval_method: {retrieval_method}")

    # 3) Retrieve; when the transform produced several queries, their result
    # lists are fused with rrf.
    t0 = time.perf_counter()
    if len(tr.queries) == 1:
        retrieved = _retrieve_one(tr.queries[0])
    else:
        retrieved = rrf([_retrieve_one(q) for q in tr.queries], top_k=top_k)
    timings["retrieve_ms"] = (time.perf_counter() - t0) * 1000

    # 4) Rerank down to `final_k`; on reranker failure fall back to plain
    # truncation.
    t0 = time.perf_counter()
    if reranker != "none":
        try:
            top = rerank(retrieval_query, retrieved, name=reranker, top_n=final_k)
        except Exception as e:
            top = retrieved[:final_k]
            # NOTE(review): stores a str in a dict annotated dict[str, float];
            # it is skipped by the "_ms" sum below, but confirm downstream
            # formatters tolerate a non-numeric value.
            timings["reranker_error"] = str(e)
    else:
        top = retrieved[:final_k]
    timings["rerank_ms"] = (time.perf_counter() - t0) * 1000

    # Setup event: chunks are known, timings up to this point are known.
    # `dict(timings)` snapshots the timings so later additions do not mutate
    # the object already handed to the consumer.
    setup_result = QueryResult(
        answer=None,
        chunks=top,
        timings=dict(timings),
        transformed_queries=tr.queries if query_transform != "none" else [],
        config_summary=(
            f"module={module} strategy={chunk_strategy} dim={embedding_dim} "
            f"retrieval={retrieval_method} reranker={reranker} transform={query_transform} "
            f"top_k={top_k} final_k={final_k} chat_turns={len(history)}"
        ),
        rewritten_query=rewritten_query,
    )
    yield ("setup", setup_result)

    # Stream generation
    # The prompt uses the user's original `query` (not the rewritten one) and
    # at most the first 1800 characters of each passage.
    role = "compliance officer" if module == "compliance" else "credit analyst"
    passages = "\n\n".join(
        f"[{i+1}] (doc: {c.payload.get('doc_id','?')}, section: {c.payload.get('section_title','')})\n{c.content[:1800]}"
        for i, c in enumerate(top)
    )
    if history:
        user_prompt = _USER_TURN_TEMPLATE.format(
            history=_format_history(history), query=query, passages=passages,
        )
    else:
        user_prompt = _NO_HISTORY_TEMPLATE.format(query=query, passages=passages)

    t0 = time.perf_counter()
    accumulated = ""
    for partial in claude_text_stream(
        user_prompt,
        system=_SYSTEM_PROMPT.format(role=role),
        max_tokens=max_answer_tokens,
        temperature=0.0,
    ):
        # Each `partial` replaces the buffer rather than being appended to it.
        accumulated = partial
        yield ("token", accumulated)
    timings["generate_ms"] = (time.perf_counter() - t0) * 1000

    # Guardrails + log
    t0 = time.perf_counter()
    guardrail_report = run_guardrails(module, accumulated, top, query)
    timings["guardrails_ms"] = (time.perf_counter() - t0) * 1000
    # Total covers only the keys ending in "_ms".
    timings["total_ms"] = sum(v for k, v in timings.items() if k.endswith("_ms"))

    config_dict = {
        "module": module, "chunk_strategy": chunk_strategy,
        "embedding_dim": embedding_dim, "retrieval_method": retrieval_method,
        "reranker": reranker, "query_transform": query_transform,
        "top_k": top_k, "final_k": final_k, "generate_answer": True,
        "chat_turns": len(history), "rewritten_query": rewritten_query,
        "streaming": True,
    }
    qid = log_query(
        query=query, config=config_dict, timings=timings,
        transformed_queries=tr.queries if query_transform != "none" else [],
        top_chunks=[chunk_for_log(c) for c in top], answer=accumulated,
        guardrail_report=guardrail_report.to_dict(),
    )

    yield ("done", QueryResult(
        answer=accumulated, chunks=top, timings=timings,
        transformed_queries=tr.queries if query_transform != "none" else [],
        config_summary=setup_result.config_summary,
        guardrail_report=guardrail_report, query_id=qid,
        rewritten_query=rewritten_query,
    ))
pipelines/shared/llm.py CHANGED
@@ -29,6 +29,8 @@ from dotenv import load_dotenv
29
  load_dotenv()
30
 
31
  DEFAULT_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6")
 
 
32
  DEFAULT_MAX_TOKENS = 1024
33
 
34
 
@@ -84,6 +86,36 @@ def claude_text(
84
  return "".join(parts).strip()
85
 
86
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
87
  _JSON_FENCE = re.compile(r"```(?:json)?\s*([\s\S]*?)```")
88
 
89
 
 
29
  load_dotenv()
30
 
31
  DEFAULT_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6")
32
+ # Light, fast model for utility tasks like follow-up rewriting
33
+ FAST_MODEL = os.environ.get("CLAUDE_FAST_MODEL", "claude-haiku-4-5-20251001")
34
  DEFAULT_MAX_TOKENS = 1024
35
 
36
 
 
86
  return "".join(parts).strip()
87
 
88
 
89
def claude_text_stream(
    prompt: str,
    *,
    system: str = "",
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = 0.0,
):
    """Generator yielding the response text as it arrives.

    Each yield is the *cumulative* output so far (suitable for piping straight
    into Gradio's Chatbot streaming), so consumers should replace their buffer
    with every value rather than append to it.

    Args:
        prompt: The single user message sent to the model.
        system: System prompt; an empty value falls back to a generic one.
        model: Model identifier passed through to the client.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.

    Yields:
        The accumulated response text after each streamed delta. On error the
        final yield is a failure note; if some text had already been streamed
        the note is appended to that text instead of replacing it, so a
        mid-stream failure does not erase what the user has already seen.
    """
    # Bound before the try so the handler can always see the partial text.
    accumulated = ""
    try:
        client = _get_client()
        with client.messages.stream(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system or "You are a helpful assistant.",
            messages=[{"role": "user", "content": prompt}],
        ) as stream:
            for delta in stream.text_stream:
                accumulated += delta
                yield accumulated
    except Exception as e:
        # Deliberately best-effort: failures are reported in-band as text so
        # the streaming consumer does not have to handle exceptions.
        failure = f"_(generation failed: {type(e).__name__}: {e})_"
        yield f"{accumulated}\n\n{failure}" if accumulated else failure
+
118
+
119
  _JSON_FENCE = re.compile(r"```(?:json)?\s*([\s\S]*?)```")
120
 
121