Deploy BankMind
Browse files- app.py +6 -2
- app/main.py +173 -47
- app/query_pipeline.py +172 -3
- pipelines/shared/llm.py +32 -0
app.py
CHANGED
|
@@ -15,14 +15,18 @@ sys.path.insert(0, str(ROOT))
|
|
| 15 |
|
| 16 |
import gradio as gr
|
| 17 |
|
| 18 |
-
from app.main import build_app
|
| 19 |
|
| 20 |
|
| 21 |
if __name__ == "__main__":
|
| 22 |
demo = build_app()
|
|
|
|
|
|
|
|
|
|
| 23 |
demo.launch(
|
| 24 |
server_name=os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0"),
|
| 25 |
server_port=int(os.environ.get("GRADIO_SERVER_PORT", "7860")),
|
| 26 |
show_error=True,
|
| 27 |
-
theme=
|
|
|
|
| 28 |
)
|
|
|
|
| 15 |
|
| 16 |
import gradio as gr
|
| 17 |
|
| 18 |
+
from app.main import _BANKY_CSS, _BANKY_THEME, build_app
|
| 19 |
|
| 20 |
|
| 21 |
if __name__ == "__main__":
    demo = build_app()

    # Turn the request queue on explicitly so streaming events and multiple
    # clients don't block each other, and so switching tabs doesn't freeze
    # while a chat turn is still in flight.
    demo.queue(default_concurrency_limit=4, max_size=16)

    # Bind address and port are overridable through the environment.
    host = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
    port = int(os.environ.get("GRADIO_SERVER_PORT", "7860"))

    demo.launch(
        server_name=host,
        server_port=port,
        show_error=True,
        theme=_BANKY_THEME,
        css=_BANKY_CSS,
    )
|
app/main.py
CHANGED
|
@@ -30,7 +30,7 @@ from app.charts import (
|
|
| 30 |
retrieval_stage_2_figure,
|
| 31 |
retrieval_stage_3_figure,
|
| 32 |
)
|
| 33 |
-
from app.query_pipeline import run_query
|
| 34 |
|
| 35 |
|
| 36 |
COMPLIANCE_STRATEGIES = ["regulatory_boundary", "semantic", "hierarchical"]
|
|
@@ -253,55 +253,97 @@ def _build_qa_tab(module: str, strategies: list[str], default_strategy: str):
|
|
| 253 |
# ---- Chat handlers ----------------------------------------------------
|
| 254 |
|
| 255 |
def _on_send(user_msg, hist_pairs, strat, d, m, r, t, k, fk, gen, mx, chat_msgs):
|
| 256 |
-
"""
|
|
|
|
| 257 |
|
| 258 |
-
|
| 259 |
-
|
| 260 |
-
parallel (user, assistant) tuple list used by the LLM rewriter.
|
| 261 |
"""
|
| 262 |
user_msg = (user_msg or "").strip()
|
| 263 |
if not user_msg:
|
| 264 |
-
|
| 265 |
-
|
| 266 |
-
|
| 267 |
-
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
| 275 |
-
|
| 276 |
-
|
| 277 |
-
|
| 278 |
-
|
| 279 |
-
|
| 280 |
-
|
| 281 |
-
|
| 282 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 283 |
new_chat_msgs = (chat_msgs or []) + [
|
| 284 |
{"role": "user", "content": user_msg},
|
| 285 |
-
{"role": "assistant", "content":
|
| 286 |
]
|
| 287 |
-
|
| 288 |
-
|
| 289 |
-
|
| 290 |
-
|
| 291 |
-
|
| 292 |
-
|
| 293 |
-
|
| 294 |
-
|
| 295 |
-
|
| 296 |
-
|
| 297 |
-
|
| 298 |
-
|
| 299 |
-
|
| 300 |
-
|
| 301 |
-
|
| 302 |
-
|
| 303 |
-
|
| 304 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 305 |
|
| 306 |
def _on_clear():
|
| 307 |
return [], [], None, "", "", "_(send a message to see guardrails)_", "_(send a message to see retrieved passages)_"
|
|
@@ -349,11 +391,94 @@ def _build_perf_tab(module: str):
|
|
| 349 |
# Build the app
|
| 350 |
# =============================================================================
|
| 351 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 352 |
def build_app() -> gr.Blocks:
|
| 353 |
with gr.Blocks(title="BankMind") as demo:
|
| 354 |
-
gr.
|
| 355 |
-
"
|
| 356 |
-
"
|
|
|
|
|
|
|
|
|
|
|
|
|
| 357 |
)
|
| 358 |
with gr.Tabs():
|
| 359 |
with gr.Tab("⚖️ Compliance Q&A"):
|
|
@@ -392,5 +517,6 @@ def build_app() -> gr.Blocks:
|
|
| 392 |
|
| 393 |
if __name__ == "__main__":
|
| 394 |
app = build_app()
|
|
|
|
| 395 |
app.launch(server_name="127.0.0.1", server_port=7860, show_error=True,
|
| 396 |
-
theme=
|
|
|
|
| 30 |
retrieval_stage_2_figure,
|
| 31 |
retrieval_stage_3_figure,
|
| 32 |
)
|
| 33 |
+
from app.query_pipeline import run_query, run_query_stream
|
| 34 |
|
| 35 |
|
| 36 |
COMPLIANCE_STRATEGIES = ["regulatory_boundary", "semantic", "hierarchical"]
|
|
|
|
| 253 |
# ---- Chat handlers ----------------------------------------------------
|
| 254 |
|
| 255 |
def _on_send(user_msg, hist_pairs, strat, d, m, r, t, k, fk, gen, mx, chat_msgs):
    """Streaming chat handler: yield progressive UI updates for one chat turn.

    Every yield is an 8-tuple, in this order: chat messages (list of
    ``{"role", "content"}`` dicts), the parallel ``(user, assistant)`` history
    pairs, an update for the input box (``value=""`` clears it), timings
    markdown, config markdown, guardrails markdown, chunks markdown, and the
    result object (``None`` until a final result exists).

    Paths:
      * Empty input: one yield with placeholder text, nothing is run.
      * ``gen`` falsy (generation toggled off): a single non-streaming
        ``run_query(..., generate_answer=False)`` call, so no LLM answer is
        generated; the assistant bubble shows a fixed notice.
      * Otherwise: events from ``run_query_stream`` are consumed --
        ``"setup"`` fills timings/config/chunks, ``"token"`` overwrites the
        last assistant message with the cumulative text, ``"done"`` emits the
        final answer, guardrail report and updated history pairs.

    ``module`` is taken from the enclosing scope.
    """
    user_msg = (user_msg or "").strip()
    if not user_msg:
        yield (chat_msgs or [], hist_pairs, gr.update(), "", "",
               "_(empty input)_", "_(no chunks)_", None)
        return

    history = hist_pairs or []

    def _pipeline_kwargs():
        # Built lazily at each call site so the int() conversions happen at
        # the same point in the flow as the pipeline call itself.
        return dict(
            query=user_msg, module=module, chunk_strategy=strat,
            embedding_dim=int(d), retrieval_method=m, reranker=r,
            query_transform=t, top_k=int(k), final_k=int(fk),
            chat_history=history, max_answer_tokens=int(mx),
        )

    def _config_markdown(res, *, with_query_id):
        # One place for the config line shown under the chat; the query id is
        # only available on final results, not on the "setup" snapshot.
        text = f"`{res.config_summary}`"
        if with_query_id:
            text += f" · query_id=`{res.query_id or '—'}`"
        if res.rewritten_query and res.rewritten_query != user_msg:
            text += f"\n\n_Follow-up rewritten as:_ `{res.rewritten_query}`"
        return text

    user_turn = {"role": "user", "content": user_msg}

    # Free path: no generation. Falls back to the non-streaming run_query.
    if not gen:
        result = run_query(generate_answer=False, **_pipeline_kwargs())
        assistant_msg = "_(generation is off in pipeline configuration)_"
        new_chat_msgs = (chat_msgs or []) + [
            user_turn,
            {"role": "assistant", "content": assistant_msg},
        ]
        new_pairs = history + [(user_msg, assistant_msg)]
        yield (new_chat_msgs, new_pairs, gr.update(value=""),
               _format_timings(result.timings),
               _config_markdown(result, with_query_id=True),
               _format_guardrails(result.guardrail_report),
               _format_chunks(result.chunks), result)
        return

    # Streaming path. Show the user message plus an assistant placeholder
    # immediately, then progressively fill the assistant message as tokens
    # arrive.
    new_chat_msgs = (chat_msgs or []) + [
        user_turn,
        {"role": "assistant", "content": "…"},
    ]
    # Initial yield: clear input, lock in the user message, show "thinking".
    yield (new_chat_msgs, history, gr.update(value=""),
           "_…retrieving…_", "", "_(running guardrails after generation)_",
           "_(loading)_", None)

    accumulated = ""
    for event_type, payload in run_query_stream(**_pipeline_kwargs()):
        if event_type == "setup":
            # Retrieval and rerank are finished: chunks and partial timings
            # can be shown before any answer text exists.
            yield (new_chat_msgs, history, gr.update(),
                   _format_timings(payload.timings),
                   _config_markdown(payload, with_query_id=False),
                   "_(running guardrails after generation)_",
                   _format_chunks(payload.chunks),
                   None)
        elif event_type == "token":
            # Payload is the cumulative answer so far; replace, don't append.
            accumulated = payload
            new_chat_msgs[-1] = {"role": "assistant", "content": accumulated or "…"}
            yield (new_chat_msgs, history, gr.update(),
                   gr.update(), gr.update(), gr.update(), gr.update(), None)
        elif event_type == "done":
            final_result = payload
            final_answer = final_result.answer or accumulated or "_(no answer)_"
            new_chat_msgs[-1] = {"role": "assistant", "content": final_answer}
            new_pairs = history + [(user_msg, final_answer)]
            yield (new_chat_msgs, new_pairs, gr.update(),
                   _format_timings(final_result.timings),
                   _config_markdown(final_result, with_query_id=True),
                   _format_guardrails(final_result.guardrail_report),
                   _format_chunks(final_result.chunks),
                   final_result)
            return
|
| 347 |
|
| 348 |
def _on_clear():
|
| 349 |
return [], [], None, "", "", "_(send a message to see guardrails)_", "_(send a message to see retrieved passages)_"
|
|
|
|
| 391 |
# Build the app
|
| 392 |
# =============================================================================
|
| 393 |
|
| 394 |
+
# Dark navy/amber theme. Every colour token below is applied to both its
# regular and its "_dark" variant, so the app renders the same palette in
# either colour mode.
_BANKY_THEME = gr.themes.Base(
    primary_hue=gr.themes.colors.amber,
    secondary_hue=gr.themes.colors.slate,
    neutral_hue=gr.themes.colors.slate,
    font=[gr.themes.GoogleFont("Inter"), "ui-sans-serif", "system-ui", "sans-serif"],
    font_mono=[gr.themes.GoogleFont("JetBrains Mono"), "ui-monospace", "monospace"],
).set(
    **{
        f"{token}{variant}": colour
        for token, colour in {
            "body_background_fill": "#0b1220",
            "body_text_color": "#e5e7eb",
            "background_fill_primary": "#111827",
            "background_fill_secondary": "#0f172a",
            "block_background_fill": "#0f172a",
            "block_border_color": "#1f2937",
            "block_label_background_fill": "#0b1220",
            "block_label_text_color": "#fbbf24",
            "border_color_accent": "#fbbf24",
            "border_color_primary": "#1f2937",
            "button_primary_background_fill": "#fbbf24",
            "button_primary_background_fill_hover": "#f59e0b",
            "button_primary_text_color": "#0b1220",
            "button_secondary_background_fill": "#1f2937",
            "button_secondary_text_color": "#e5e7eb",
            "input_background_fill": "#0b1220",
            "input_border_color": "#1f2937",
            "input_border_color_focus": "#fbbf24",
            "color_accent_soft": "#1f2937",
        }.items()
        for variant in ("", "_dark")
    }
)
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
_BANKY_CSS = """
|
| 443 |
+
.gradio-container { max-width: 1280px !important; }
|
| 444 |
+
h1, h2, h3, h4 { letter-spacing: -0.01em; font-weight: 600; }
|
| 445 |
+
h1 { font-size: 1.75rem !important; }
|
| 446 |
+
.tabitem { padding-top: 0.5rem; }
|
| 447 |
+
|
| 448 |
+
/* Tighter accordion headers */
|
| 449 |
+
.label-wrap > span { font-weight: 600; letter-spacing: 0.01em; }
|
| 450 |
+
|
| 451 |
+
/* Markdown body text */
|
| 452 |
+
.prose { line-height: 1.55; }
|
| 453 |
+
.prose code, .prose pre { background: #0b1220 !important; border: 1px solid #1f2937; border-radius: 6px; }
|
| 454 |
+
.prose table { font-size: 0.92em; }
|
| 455 |
+
.prose th { background: #0b1220 !important; color: #fbbf24 !important; font-weight: 600; }
|
| 456 |
+
.prose td { border-color: #1f2937 !important; }
|
| 457 |
+
|
| 458 |
+
/* Chatbot polish */
|
| 459 |
+
.message.user { background: #1e293b !important; }
|
| 460 |
+
.message.bot, .message.assistant { background: #0f172a !important; border: 1px solid #1f2937; }
|
| 461 |
+
|
| 462 |
+
/* Subtle gold accent on the title bar */
|
| 463 |
+
#title-banner {
|
| 464 |
+
border-left: 3px solid #fbbf24;
|
| 465 |
+
padding-left: 0.85rem;
|
| 466 |
+
margin: 0.25rem 0 1rem 0;
|
| 467 |
+
}
|
| 468 |
+
#title-banner h1 { margin: 0; font-size: 1.5rem !important; }
|
| 469 |
+
#title-banner .tagline { color: #94a3b8; font-size: 0.95rem; margin-top: 0.15rem; }
|
| 470 |
+
"""
|
| 471 |
+
|
| 472 |
+
|
| 473 |
def build_app() -> gr.Blocks:
|
| 474 |
with gr.Blocks(title="BankMind") as demo:
|
| 475 |
+
gr.HTML(
|
| 476 |
+
"""
|
| 477 |
+
<div id="title-banner">
|
| 478 |
+
<h1>🏦 BankMind</h1>
|
| 479 |
+
<div class="tagline">Multi-domain RAG for financial intelligence: regulatory compliance and credit risk.</div>
|
| 480 |
+
</div>
|
| 481 |
+
"""
|
| 482 |
)
|
| 483 |
with gr.Tabs():
|
| 484 |
with gr.Tab("⚖️ Compliance Q&A"):
|
|
|
|
| 517 |
|
| 518 |
if __name__ == "__main__":
    # Local run: loopback only, fixed port, with the shared theme and CSS.
    launch_options = {
        "server_name": "127.0.0.1",
        "server_port": 7860,
        "show_error": True,
        "theme": _BANKY_THEME,
        "css": _BANKY_CSS,
    }
    app = build_app()
    app.queue(default_concurrency_limit=4, max_size=16)
    app.launch(**launch_options)
|
app/query_pipeline.py
CHANGED
|
@@ -19,7 +19,7 @@ from typing import Optional
|
|
| 19 |
|
| 20 |
from pipelines.shared.fusion import convex_combination, hierarchical, rrf
|
| 21 |
from pipelines.shared.guardrails import GuardrailReport, check as run_guardrails
|
| 22 |
-
from pipelines.shared.llm import claude_text
|
| 23 |
from pipelines.shared.query_logger import chunk_for_log, log_query
|
| 24 |
from pipelines.shared.query_transformer import apply_transform
|
| 25 |
from pipelines.shared.reranker import rerank
|
|
@@ -123,7 +123,9 @@ def _format_history(history: list[tuple[str, str]], *, max_turns: int = 6) -> st
|
|
| 123 |
def _rewrite_followup(query: str, history: list[tuple[str, str]]) -> str:
|
| 124 |
"""Rewrite a possibly-ambiguous follow-up into a standalone query.
|
| 125 |
|
| 126 |
-
|
|
|
|
|
|
|
| 127 |
"""
|
| 128 |
if not history:
|
| 129 |
return query
|
|
@@ -131,10 +133,10 @@ def _rewrite_followup(query: str, history: list[tuple[str, str]]) -> str:
|
|
| 131 |
rewritten = claude_text(
|
| 132 |
_REWRITE_USER.format(history=_format_history(history), query=query),
|
| 133 |
system=_REWRITE_SYSTEM,
|
|
|
|
| 134 |
max_tokens=200,
|
| 135 |
temperature=0.0,
|
| 136 |
)
|
| 137 |
-
# Defensive: strip surrounding quotes / leading/trailing whitespace
|
| 138 |
rewritten = rewritten.strip().strip('"').strip("'").strip()
|
| 139 |
return rewritten or query
|
| 140 |
except Exception:
|
|
@@ -305,3 +307,170 @@ def run_query(
|
|
| 305 |
query_id=qid,
|
| 306 |
rewritten_query=rewritten_query,
|
| 307 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
|
| 20 |
from pipelines.shared.fusion import convex_combination, hierarchical, rrf
|
| 21 |
from pipelines.shared.guardrails import GuardrailReport, check as run_guardrails
|
| 22 |
+
from pipelines.shared.llm import FAST_MODEL, claude_text, claude_text_stream
|
| 23 |
from pipelines.shared.query_logger import chunk_for_log, log_query
|
| 24 |
from pipelines.shared.query_transformer import apply_transform
|
| 25 |
from pipelines.shared.reranker import rerank
|
|
|
|
| 123 |
def _rewrite_followup(query: str, history: list[tuple[str, str]]) -> str:
|
| 124 |
"""Rewrite a possibly-ambiguous follow-up into a standalone query.
|
| 125 |
|
| 126 |
+
Uses the fast/cheap model (Haiku by default) since paraphrasing is a
|
| 127 |
+
utility task and Sonnet is overkill. No-op when history is empty.
|
| 128 |
+
On any LLM error, returns the original query.
|
| 129 |
"""
|
| 130 |
if not history:
|
| 131 |
return query
|
|
|
|
| 133 |
rewritten = claude_text(
|
| 134 |
_REWRITE_USER.format(history=_format_history(history), query=query),
|
| 135 |
system=_REWRITE_SYSTEM,
|
| 136 |
+
model=FAST_MODEL,
|
| 137 |
max_tokens=200,
|
| 138 |
temperature=0.0,
|
| 139 |
)
|
|
|
|
| 140 |
rewritten = rewritten.strip().strip('"').strip("'").strip()
|
| 141 |
return rewritten or query
|
| 142 |
except Exception:
|
|
|
|
| 307 |
query_id=qid,
|
| 308 |
rewritten_query=rewritten_query,
|
| 309 |
)
|
| 310 |
+
|
| 311 |
+
|
| 312 |
+
# === Streaming variant for the chat UI ======================================
|
| 313 |
+
# Yields three kinds of events:
|
| 314 |
+
# ("setup", QueryResult) -> retrieval + rerank done, generation starts
|
| 315 |
+
# ("token", str) -> cumulative answer text (after each delta)
|
| 316 |
+
# ("done", QueryResult) -> final guardrail-checked + logged result
|
| 317 |
+
# This lets the UI render the chunks panel and timings as soon as retrieval is
|
| 318 |
+
# done, then stream tokens into the chat, then apply guardrails when generation
|
| 319 |
+
# completes.
|
| 320 |
+
|
| 321 |
+
def run_query_stream(
    *,
    query: str,
    module: str,
    chunk_strategy: str,
    embedding_dim: int,
    retrieval_method: str,
    reranker: str,
    query_transform: str,
    top_k: int = 10,
    final_k: int = 5,
    chat_history: Optional[list[tuple[str, str]]] = None,
    max_answer_tokens: int = 900,
):
    """Run the query pipeline as a generator of ``(event_type, payload)`` pairs.

    Events, in order:
      * ``("setup", QueryResult)`` -- retrieval and rerank are done; the
        result carries the selected chunks and the timings so far, with
        ``answer=None``.
      * ``("token", str)`` -- cumulative answer text, once per streamed delta.
      * ``("done", QueryResult)`` -- final result with answer, guardrail
        report and the id returned by ``log_query``.

    A blank ``query`` produces a single ``"done"`` event with an empty result
    and nothing else.

    Raises:
        ValueError: if ``retrieval_method`` is not one of ``dense``, ``bm25``,
            ``splade``, ``hybrid_rrf``, ``hybrid_convex``, ``hybrid_hier``.
    """
    if not query.strip():
        yield ("done", QueryResult(answer=None, chunks=[], timings={},
                                   transformed_queries=[], config_summary="(empty query)"))
        return

    def _elapsed_ms(started: float) -> float:
        return (time.perf_counter() - started) * 1000

    timings: dict[str, float] = {}
    retr = _retriever()
    history = chat_history or []

    # 1) Follow-up rewrite: only attempted when there is prior conversation.
    rewritten_query = None
    retrieval_query = query
    if history:
        started = time.perf_counter()
        rewritten_query = _rewrite_followup(query, history)
        retrieval_query = rewritten_query
        timings["rewrite_ms"] = _elapsed_ms(started)

    # 2) Optional pre-retrieval transform. Any failure degrades to using the
    #    retrieval query as-is, with the error recorded in the extras.
    started = time.perf_counter()
    try:
        tr = apply_transform(
            query_transform,
            retrieval_query,
            module=module,
            retriever=retr,
            chunk_strategy=chunk_strategy,
            embedding_dim=embedding_dim,
        )
    except Exception as exc:
        from pipelines.shared.query_transformer import TransformResult
        tr = TransformResult(
            queries=[retrieval_query],
            transform_name="none-fallback",
            extras={"error": str(exc)},
        )
    timings["transform_ms"] = _elapsed_ms(started)

    # Transformed queries are only reported when a transform was requested.
    report_transforms = query_transform != "none"

    def _retrieve_one(q: str) -> list[ScoredChunk]:
        common = {"query": q, "module": module, "chunk_strategy": chunk_strategy}
        if retrieval_method == "dense":
            return retr.search(**common, mode="dense",
                               embedding_dim=embedding_dim, top_k=top_k)
        if retrieval_method in ("bm25", "splade"):
            return retr.search(**common, mode="sparse",
                               sparse_name=retrieval_method, top_k=top_k)
        if retrieval_method == "hybrid_rrf":
            return retr.search(**common, mode="hybrid",
                               embedding_dim=embedding_dim, top_k=top_k)
        if retrieval_method in ("hybrid_convex", "hybrid_hier"):
            # Both fusion variants start from the same two 50-candidate lists.
            dense_hits, sparse_hits, _ = retr.search_separate_channels(
                **common, embedding_dim=embedding_dim, top_k=50,
            )
            if retrieval_method == "hybrid_convex":
                return convex_combination(dense_hits, sparse_hits, alpha=0.7, top_k=top_k)
            return hierarchical(q, dense_hits, sparse_hits, top_k=top_k)
        raise ValueError(f"unknown retrieval_method: {retrieval_method}")

    # 3) Retrieve: a single query is used directly, several are fused via RRF.
    started = time.perf_counter()
    hit_lists = [_retrieve_one(q) for q in tr.queries]
    retrieved = hit_lists[0] if len(hit_lists) == 1 else rrf(hit_lists, top_k=top_k)
    timings["retrieve_ms"] = _elapsed_ms(started)

    # 4) Rerank, falling back to plain truncation if the reranker fails.
    started = time.perf_counter()
    if reranker == "none":
        top = retrieved[:final_k]
    else:
        try:
            top = rerank(retrieval_query, retrieved, name=reranker, top_n=final_k)
        except Exception as exc:
            top = retrieved[:final_k]
            timings["reranker_error"] = str(exc)
    timings["rerank_ms"] = _elapsed_ms(started)

    # Setup event: chunks are known, timings up to this point are known.
    config_summary = (
        f"module={module} strategy={chunk_strategy} dim={embedding_dim} "
        f"retrieval={retrieval_method} reranker={reranker} transform={query_transform} "
        f"top_k={top_k} final_k={final_k} chat_turns={len(history)}"
    )
    setup_result = QueryResult(
        answer=None,
        chunks=top,
        timings=dict(timings),  # snapshot; later stages keep adding to `timings`
        transformed_queries=tr.queries if report_transforms else [],
        config_summary=config_summary,
        rewritten_query=rewritten_query,
    )
    yield ("setup", setup_result)

    # 5) Build the prompt. Note it uses the user's original `query`, not the
    #    rewritten one; each passage body is capped at 1800 characters.
    role = "compliance officer" if module == "compliance" else "credit analyst"
    passage_blocks = []
    for idx, chunk in enumerate(top, start=1):
        header = (
            f"[{idx}] (doc: {chunk.payload.get('doc_id','?')}, "
            f"section: {chunk.payload.get('section_title','')})"
        )
        passage_blocks.append(f"{header}\n{chunk.content[:1800]}")
    passages = "\n\n".join(passage_blocks)

    if not history:
        user_prompt = _NO_HISTORY_TEMPLATE.format(query=query, passages=passages)
    else:
        user_prompt = _USER_TURN_TEMPLATE.format(
            history=_format_history(history), query=query, passages=passages,
        )

    # 6) Stream generation; every event carries the full text so far.
    started = time.perf_counter()
    accumulated = ""
    for accumulated in claude_text_stream(
        user_prompt,
        system=_SYSTEM_PROMPT.format(role=role),
        max_tokens=max_answer_tokens,
        temperature=0.0,
    ):
        yield ("token", accumulated)
    timings["generate_ms"] = _elapsed_ms(started)

    # 7) Guardrails + log
    started = time.perf_counter()
    guardrail_report = run_guardrails(module, accumulated, top, query)
    timings["guardrails_ms"] = _elapsed_ms(started)
    # Only the "*_ms" entries are summed; "reranker_error" holds a string.
    timings["total_ms"] = sum(
        value for name, value in timings.items() if name.endswith("_ms")
    )

    config_dict = {
        "module": module,
        "chunk_strategy": chunk_strategy,
        "embedding_dim": embedding_dim,
        "retrieval_method": retrieval_method,
        "reranker": reranker,
        "query_transform": query_transform,
        "top_k": top_k,
        "final_k": final_k,
        "generate_answer": True,
        "chat_turns": len(history),
        "rewritten_query": rewritten_query,
        "streaming": True,
    }
    qid = log_query(
        query=query,
        config=config_dict,
        timings=timings,
        transformed_queries=tr.queries if report_transforms else [],
        top_chunks=[chunk_for_log(c) for c in top],
        answer=accumulated,
        guardrail_report=guardrail_report.to_dict(),
    )

    yield ("done", QueryResult(
        answer=accumulated,
        chunks=top,
        timings=timings,
        transformed_queries=tr.queries if report_transforms else [],
        config_summary=setup_result.config_summary,
        guardrail_report=guardrail_report,
        query_id=qid,
        rewritten_query=rewritten_query,
    ))
|
pipelines/shared/llm.py
CHANGED
|
@@ -29,6 +29,8 @@ from dotenv import load_dotenv
|
|
| 29 |
load_dotenv()
|
| 30 |
|
| 31 |
DEFAULT_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6")
|
|
|
|
|
|
|
| 32 |
DEFAULT_MAX_TOKENS = 1024
|
| 33 |
|
| 34 |
|
|
@@ -84,6 +86,36 @@ def claude_text(
|
|
| 84 |
return "".join(parts).strip()
|
| 85 |
|
| 86 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 87 |
_JSON_FENCE = re.compile(r"```(?:json)?\s*([\s\S]*?)```")
|
| 88 |
|
| 89 |
|
|
|
|
| 29 |
load_dotenv()
|
| 30 |
|
| 31 |
DEFAULT_MODEL = os.environ.get("CLAUDE_MODEL", "claude-sonnet-4-6")
|
| 32 |
+
# Light, fast model for utility tasks like follow-up rewriting
|
| 33 |
+
FAST_MODEL = os.environ.get("CLAUDE_FAST_MODEL", "claude-haiku-4-5-20251001")
|
| 34 |
DEFAULT_MAX_TOKENS = 1024
|
| 35 |
|
| 36 |
|
|
|
|
| 86 |
return "".join(parts).strip()
|
| 87 |
|
| 88 |
|
| 89 |
+
def claude_text_stream(
    prompt: str,
    *,
    system: str = "",
    model: str = DEFAULT_MODEL,
    max_tokens: int = DEFAULT_MAX_TOKENS,
    temperature: float = 0.0,
):
    """Generator yielding the response text as it arrives.

    Each yield is the *cumulative* output so far (suitable for piping straight
    into Gradio's Chatbot streaming), never an individual delta.

    Args:
        prompt: Sent as the single user message.
        system: System prompt; an empty value falls back to
            "You are a helpful assistant.".
        model: Model identifier passed to the client.
        max_tokens: Upper bound on generated tokens.
        temperature: Sampling temperature.

    Yields:
        str: The text accumulated so far. This generator does not raise on
        failure: any ``Exception`` is turned into one final yield containing
        a ``_(generation failed: ...)_`` notice. If some text had already
        streamed, that text is kept and the notice is appended after a blank
        line, so the yield stays cumulative and consumers that keep only the
        latest value don't lose the partial answer.
    """
    # Initialised outside the try so the except branch can always read it,
    # including when client creation itself fails.
    accumulated = ""
    try:
        client = _get_client()
        with client.messages.stream(
            model=model,
            max_tokens=max_tokens,
            temperature=temperature,
            system=system or "You are a helpful assistant.",
            messages=[{"role": "user", "content": prompt}],
        ) as stream:
            for delta in stream.text_stream:
                accumulated += delta
                yield accumulated
    except Exception as e:
        notice = f"_(generation failed: {type(e).__name__}: {e})_"
        yield f"{accumulated}\n\n{notice}" if accumulated else notice
|
| 117 |
+
|
| 118 |
+
|
| 119 |
_JSON_FENCE = re.compile(r"```(?:json)?\s*([\s\S]*?)```")
|
| 120 |
|
| 121 |
|