File size: 24,417 Bytes
2701365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1294c2d
2701365
1294c2d
 
2701365
 
 
1294c2d
 
2701365
 
a042485
1294c2d
 
51c0848
a042485
 
2701365
 
1294c2d
2701365
51c0848
 
 
 
2701365
 
 
 
 
1294c2d
2701365
 
 
 
 
 
 
 
 
1294c2d
2701365
1294c2d
51c0848
1294c2d
51c0848
 
1294c2d
 
 
 
 
 
 
 
 
 
 
 
 
2701365
 
 
 
 
 
 
 
 
 
 
2da2fd5
 
 
 
 
 
 
 
 
 
 
 
 
2701365
 
a042485
 
 
 
 
2701365
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1294c2d
2701365
 
1294c2d
 
2701365
1294c2d
 
2701365
 
 
 
 
 
 
 
 
1294c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51c0848
 
1294c2d
51c0848
 
 
 
 
 
1294c2d
 
51c0848
 
 
1294c2d
51c0848
1294c2d
51c0848
1294c2d
51c0848
1294c2d
51c0848
1294c2d
51c0848
 
1294c2d
 
51c0848
 
 
1294c2d
 
 
 
 
 
51c0848
 
 
 
1294c2d
 
 
 
 
 
51c0848
1294c2d
 
51c0848
1294c2d
51c0848
1294c2d
 
51c0848
 
 
 
 
 
 
 
 
 
 
 
 
634376a
 
 
 
 
 
 
 
 
 
 
 
51c0848
 
 
 
1294c2d
 
 
51c0848
 
1294c2d
51c0848
 
 
 
 
 
1294c2d
51c0848
 
 
 
1294c2d
51c0848
1294c2d
 
51c0848
1294c2d
51c0848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
634376a
51c0848
 
 
1294c2d
 
 
 
 
 
 
 
 
 
 
 
51c0848
1294c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2701365
 
51c0848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2701365
 
 
 
 
 
 
 
 
 
 
 
 
 
a042485
 
51c0848
 
 
 
 
 
a042485
 
2701365
 
 
a042485
 
 
2701365
 
 
 
 
 
51c0848
1294c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51c0848
 
 
 
 
 
1294c2d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
51c0848
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2701365
 
 
 
 
 
51c0848
1294c2d
51c0848
 
1294c2d
 
 
2701365
 
 
 
 
1294c2d
2701365
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
"""
src/workflow/graph.py
ReAct (Reasoning + Acting) workflow using LangGraph.

Pattern:
    1. Code calls LLM with the user query + available tools
    2. LLM reasons and returns a tool call (or a final answer)
    3. Code executes the chosen tool and appends the result to messages
    4. Code calls LLM again β€” LLM sees the result and reasons further
    5. Repeat until LLM returns a final answer (no tool call)

This means the LLM can call MULTIPLE agents in one turn:
    "What's the news on my portfolio stocks?"
    β†’ LLM calls analyze_portfolio  (gets tickers)
    β†’ LLM calls get_financial_news (gets news for those tickers)
    β†’ LLM writes one combined answer

Memory: MemorySaver persists the full state (messages, goal, savings,
risk profile) per thread_id across conversation turns.

Usage:
    from src.workflow.graph import invoke

    r = invoke("I want $2M in 20 years. I have $100K.", thread_id="u1")
    r = invoke("I'm aggressive with risk.",              thread_id="u1")
    r = invoke("Actually I have $200K saved.",           thread_id="u1")
    r = invoke("What about 401k and NVDA news?",         thread_id="u1")
    print(r["answer"])
"""

import re
from typing import Annotated

from pydantic import BaseModel

from langgraph.graph import StateGraph, START, END
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import ToolNode, tools_condition
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage, ToolMessage, trim_messages

from src.core.llm import load_llm
from src.utils.logger import get_logger
from src.workflow.state import FinnieState
from src.workflow.tools import TOOLS
from src.workflow.prompts import _system_prompt, _synth_prompt

log = get_logger(__name__)


# ── Nodes ─────────────────────────────────────────────────────────────────────

# "AAPL: 100" or "MSFT - 200": a 2-5 letter upper-case ticker, a ':' or '-'
# separator, then a share count (integer or decimal).
# Group 1 = ticker, group 2 = share count.
_TICKER_RE = re.compile(r'\b([A-Z]{2,5})\s*[:\-]\s*(\d+(?:\.\d+)?)\b')
# "100 AAPL" or "1000 QQQ" -- number-first format.
# Group 1 = share count, group 2 = ticker (note the reversed group order).
_TICKER_RE_REVERSE = re.compile(r'\b(\d+(?:\.\d+)?)\s+([A-Z]{2,5})\b')


# Self-reported age such as "i am 45", "i'm 45 years old" or "im 45", matched
# against the lower-cased message. The apostrophe class also accepts the
# typographic right single quote (U+2019) that phone and macOS keyboards
# substitute automatically, so "I’m 45" is recognised too.
_AGE_RE = re.compile(r"\b(?:i am|i['\u2019]m|im)\s+(\d{1,3})\s*(?:years?\s*old)?\b")


def param_extractor_node(state: FinnieState) -> dict:
    """
    Pull structured parameters out of the latest human message.

    Runs once at the start of each turn, before any LLM call.

    Extracts:
      - ``portfolio_holdings``: {ticker: share count}, parsed from either
        "TICKER: N" or "N TICKER" phrasing. It is stored only when at least two
        tickers are found, to avoid false positives.
      - ``age``: an integer in [18, 100] taken from phrases like "I'm 45".

    Args:
        state: Current graph state; only ``messages`` is read.

    Returns:
        A partial state update containing whichever of the keys above were
        found, or ``{}`` when there is no human message or nothing matched.
    """
    last_human = next(
        (m for m in reversed(state.get("messages", []))
         if hasattr(m, "type") and m.type == "human"),
        None,
    )
    if not last_human:
        return {}

    text = str(last_human.content)
    updates: dict = {}

    # Portfolio holdings: try "TICKER: N" first, then fall back to "N TICKER".
    holdings = {m.group(1): float(m.group(2)) for m in _TICKER_RE.finditer(text)}
    if len(holdings) < 2:
        holdings = {m.group(2): float(m.group(1)) for m in _TICKER_RE_REVERSE.finditer(text)}
    if len(holdings) >= 2:          # require at least 2 tickers to avoid false positives
        updates["portfolio_holdings"] = holdings
        log.info("ParamExtractor | portfolio=%s", list(holdings.keys()))

    age_m = _AGE_RE.search(text.lower())
    if age_m:
        candidate = int(age_m.group(1))
        if 18 <= candidate <= 100:  # reject numbers implausible as an adult's age
            updates["age"] = candidate
            log.info("ParamExtractor | age=%d", candidate)

    # Holdings were already logged above; report only the remaining fields,
    # and only when there are any.
    others = {k: v for k, v in updates.items() if k != "portfolio_holdings"}
    if others:
        log.info("ParamExtractor | extracted=%s", others)
    return updates


def agent_node(state: FinnieState) -> dict:
    """
    Core ReAct node: the LLM reasons about the query and decides what to do.

      - Returns a tool call     -> ToolNode executes it and the loop continues
      - Returns a plain message -> the conversation turn is complete

    Args:
        state: Current graph state. ``messages`` is read, and the whole state
               is passed to ``_system_prompt``.

    Returns:
        ``{"messages": [response]}`` -- the LLM reply to append to the history.
    """
    llm_with_tools = load_llm().bind_tools(TOOLS)
    history = state["messages"]

    # Keep only the last 20 messages to avoid unbounded context growth.
    # Key user context (goal, savings, risk) is expected to be carried by
    # _system_prompt(state) -- NOTE(review): confirm in prompts.py -- so trimming
    # old turns should not lose critical state.
    recent = trim_messages(
        history,
        max_tokens=20,
        token_counter=len,      # count by message count, not tokens
        strategy="last",
        start_on="human",       # never start mid-tool-call
        include_system=False,
    )
    if not recent and history:
        # If the current turn alone exceeds the window (a long tool loop), no
        # human message falls inside it and trim_messages can return []. The
        # LLM would then see only the system prompt, without the question, so
        # fall back to the whole current turn: the latest human message onward.
        last_human_idx = max(
            (i for i, m in enumerate(history) if getattr(m, "type", None) == "human"),
            default=0,
        )
        recent = list(history[last_human_idx:])

    messages = [SystemMessage(content=_system_prompt(state))] + recent
    log.debug("LLM | sending %d messages (trimmed from %d)", len(messages), len(state["messages"]) + 1)
    response = llm_with_tools.invoke(messages)

    if response.tool_calls:
        log.info("LLM | β†’ tool_calls: %s", [tc["name"] for tc in response.tool_calls])
    else:
        log.info("LLM | β†’ final answer (%d chars)", len(str(response.content)))

    return {"messages": [response]}


# ── Graph builder ─────────────────────────────────────────────────────────────

def build_graph():
    """
    Build and compile the Finnie ReAct graph.

    Flow per turn:

        START -> param_extractor -> agent <----------+
                                      |              |
                                      +--> tools ----+   (reply has tool calls)
                                      |
                                      +--> END           (plain final answer)

    The branch out of ``agent`` is decided by LangGraph's prebuilt
    ``tools_condition``. The graph is compiled against the module-level
    ``_MEMORY`` checkpointer, which the fan-out graph also uses.

    Returns:
        The compiled ReAct graph.
    """
    builder = StateGraph(FinnieState)
    builder.add_node("param_extractor", param_extractor_node)
    builder.add_node("agent", agent_node)
    builder.add_node("tools", ToolNode(TOOLS))

    # Fixed edges: extract params first, let the LLM reason, and after any
    # tool run return to the LLM so it can reason over the results.
    fixed_edges = (
        (START, "param_extractor"),
        ("param_extractor", "agent"),
        ("tools", "agent"),
    )
    for src, dst in fixed_edges:
        builder.add_edge(src, dst)

    # ReAct loop exit: tool call -> "tools"; final answer -> END.
    builder.add_conditional_edges("agent", tools_condition)

    return builder.compile(checkpointer=_MEMORY)


# ── Shared memory: both graphs compile against the same MemorySaver, so
#    switching between ReAct and fan-out mid-conversation never loses state ──

_MEMORY = MemorySaver()
_graph  = None


def _get_graph():
    """Return the ReAct graph, compiling it on first use and caching it module-wide."""
    global _graph
    _graph = _graph if _graph is not None else build_graph()
    return _graph


# ── Smart parallel fan-out graph ──────────────────────────────────────────────
#
# Unlike the ReAct graph (LLM picks tools one at a time, sequentially),
# this graph first routes to only the RELEVANT agents, then runs them in
# parallel:
#
#   START β†’ param_extractor β†’ smart_fanout β†’ tools β†’ synth β†’ END
#
# smart_fanout_node: asks LLM which subset of tools is needed β†’ AIMessage
# ToolNode:          runs only those tools in parallel β†’ ToolMessages
# synth_node:        LLM synthesizes all results into one answer

# One-line descriptions used by the router β€” kept short so the routing call
# is cheap and the LLM focuses on intent rather than implementation detail.
_TOOL_DESCRIPTIONS = {
    "answer_finance_question": "General financial education: what is X, how does Y work, compound interest, ETFs, bonds, index funds",
    "analyze_portfolio":       "Analyse the user's portfolio holdings: allocation, diversification score, sector breakdown",
    "get_market_data":         "Real-time price, P/E ratio, market cap and analysis for specific stock tickers",
    "plan_financial_goal":     "Savings / retirement goal planning: can I reach $X in Y years, monthly contributions needed",
    "get_financial_news":      "Recent news headlines for specific stocks or tickers",
    "get_tax_education":       "Tax on investments: capital gains, IRA/Roth IRA, 401k, HSA, tax-loss harvesting",
}


# Structured-output schema for the routing call in _select_tools: the LLM
# answers with the list of tool names it wants to run. This is documented with
# comments rather than a docstring because pydantic turns a class docstring into
# the JSON-schema "description", which would change what the routing LLM sees.
class _ToolSelection(BaseModel):
    tools: list[str]  # requested tool names; _select_tools drops names not in TOOLS


def _enrich_query(state: FinnieState, query: str) -> str:
    """
    Inject remembered portfolio tickers when the user says 'my portfolio' without listing them.

    Args:
        state: Graph state; ``portfolio_holdings`` ({ticker: shares}) is read.
        query: The latest user message.

    Returns:
        ``query`` unchanged in any of these cases:
          - no holdings are remembered;
          - the query already contains a "TICKER: N" pair;
          - the query does not refer to the user's holdings/portfolio/stocks.
        Otherwise, the query prefixed with a line such as
        "Portfolio: AAPL: 100, MSFT: 2.5" and a blank line.
    """
    holdings = state.get("portfolio_holdings") or {}
    if not holdings or _TICKER_RE.search(query):
        return query
    if not any(p in query.lower() for p in ["my holding", "my portfolio", "my stock", "my top"]):
        return query
    # Share counts are stored as floats. Print whole numbers without a trailing
    # ".0", but keep fractional shares intact instead of truncating (0.5 -> 0).
    ticker_list = ", ".join(
        f"{t}: {int(s) if float(s).is_integer() else s}" for t, s in holdings.items()
    )
    return f"Portfolio: {ticker_list}\n\nQuestion: {query}"


def _build_ctx_note(state: FinnieState) -> str:
    """
    Describe the conversation context already stored in *state* for the router.

    Includes each of goal, timeline, savings, a non-default risk profile, and
    portfolio tickers that is present. Returns "" when none are set; otherwise
    returns a sentence that starts with a newline so it can be appended directly
    after the query.
    """
    pieces: list[str] = []

    goal = state.get("goal_amount")
    if goal:
        pieces.append(f"goal ${goal:,.0f}")

    years = state.get("time_horizon_years")
    if years:
        pieces.append(f"timeline {years:.0f} yr")

    savings = state.get("current_savings")
    if savings is not None:  # $0 saved is still meaningful context
        pieces.append(f"savings ${savings:,.0f}")

    risk = state.get("risk_profile")
    if risk and risk != "moderate":  # "moderate" is the first-turn default
        pieces.append(f"risk {risk}")

    held = state.get("portfolio_holdings")
    if held:
        pieces.append(f"portfolio {list(held.keys())}")

    if not pieces:
        return ""
    return f"\nConversation context already established: {', '.join(pieces)}."


def _select_tools(query: str, ctx_note: str) -> list[str]:
    """
    Ask the routing LLM (structured output) which tools this query needs.

    The prompt asks for exactly two tools, but the count is not enforced here.

    Args:
        query:    The (possibly enriched) user query.
        ctx_note: Already-established conversation context, appended to the
                  query in the prompt (may be "").

    Returns:
        Valid tool names in the order the LLM listed them, with duplicates
        removed, or ``["answer_finance_question"]`` when the LLM returns
        nothing usable.
    """
    tool_list = "\n".join(f"- {name}: {desc}" for name, desc in _TOOL_DESCRIPTIONS.items())
    routing_prompt = (
        "Select the tools needed to give a complete, well-rounded answer to this query.\n\n"
        f"Available tools:\n{tool_list}\n\n"
        f"Query: {query}{ctx_note}\n\n"
        "Rules (apply the FIRST matching rule and stop — do not stack rules):\n"
        "- News, headlines, or recent events for a specific stock or ticker → get_financial_news + get_market_data\n"
        "- General advice, tips, or education → answer_finance_question + plan_financial_goal (if goal in context, else + get_tax_education)\n"
        "- Any 'explain', 'what is', 'how does', 'what are' question (NOT news/prices) → answer_finance_question + get_tax_education\n"
        "- Context has goal_amount + timeline and message adds savings/contribution/risk → plan_financial_goal + get_tax_education\n"
        "- Portfolio questions (holdings, sectors, allocation, P/E, rate sensitivity) → analyze_portfolio + get_market_data\n"
        "- Retirement / savings goal questions → plan_financial_goal + get_tax_education\n"
        "- 'Is my allocation right for my age?' → analyze_portfolio + answer_finance_question\n"
        "- Rate hike / interest rate vulnerability → analyze_portfolio + answer_finance_question\n"
        "- Stock news or market events → get_financial_news + get_market_data\n"
        "- Tax questions (selling, gains, IRA, 401k) → get_tax_education + answer_finance_question\n"
        "- 52-week high, dividends, P/E for a specific stock → get_market_data + answer_finance_question\n"
        "- ALWAYS select exactly 2 tools.\n"
    )
    valid_names = {t.name for t in TOOLS}
    selection = load_llm().with_structured_output(_ToolSelection).invoke(
        [HumanMessage(content=routing_prompt)]
    )
    # Structured output can come back as None (e.g. when the model makes no
    # tool call). Treat that as "no selection" and use the fallback below
    # instead of raising AttributeError.
    proposed = selection.tools if selection is not None else []
    # Deduplicate while keeping order: a repeated name would later produce two
    # tool calls sharing the same "call_<name>" id.
    selected = [name for name in dict.fromkeys(proposed) if name in valid_names]
    return selected or ["answer_finance_question"]


_MARKET_INTENT = {"get_market_data", "get_financial_news", "analyze_portfolio"}


def _apply_goal_override(state: FinnieState, selected: list[str]) -> list[str]:
    """
    Add ``plan_financial_goal`` to follow-up turns while a goal is active.

    The override applies when *state* has a goal amount or time horizon and the
    router picked neither the planner nor any market-oriented tool. In that case
    the result is the planner plus at most one of the originally selected tools.
    Otherwise *selected* is returned unchanged.
    """
    goal_active = state.get("goal_amount") or state.get("time_horizon_years")
    if not goal_active:
        return selected

    already_covered = "plan_financial_goal" in selected or not _MARKET_INTENT.isdisjoint(selected)
    if already_covered:
        return selected

    log.info("SmartFanOut | injected plan_financial_goal for active goal context")
    # No market tool (news included) can be in `selected` at this point, so the
    # news filter is purely defensive.
    companions = [name for name in selected if name != "get_financial_news"]
    return ["plan_financial_goal", *companions[:1]]


def _apply_portfolio_override(selected: list[str], holdings: dict) -> list[str]:
    """Ensure get_market_data is always paired with analyze_portfolio.

    The routing LLM sometimes picks answer_finance_question as the companion,
    which skips live price data. Force the correct pairing whenever holdings exist.
    """
    if "analyze_portfolio" in selected and "get_market_data" not in selected and holdings:
        log.info("SmartFanOut | injected get_market_data alongside analyze_portfolio")
        return ["analyze_portfolio", "get_market_data"]
    return selected


def _build_tool_calls(query: str, selected: list[str], holdings: dict) -> list[dict]:
    """Build tool_calls, expanding get_market_data to per-ticker calls for comparison/portfolio queries."""
    q = query.lower()
    is_comparison        = any(t in q for t in ["p/e", "pe ratio", "price-to-earnings", "compare", "versus", "vs"])
    is_portfolio_analysis = "analyze_portfolio" in selected and bool(holdings)

    if (is_comparison or is_portfolio_analysis) and "get_market_data" in selected and holdings:
        # Cap at 5 tickers to keep parallel calls manageable; sort by value (shares) descending
        top_n = sorted(holdings.items(), key=lambda x: x[1], reverse=True)[:5]
        if is_comparison:
            per_ticker = [
                {"name": "get_market_data", "args": {"query": f"P/E ratio and valuation for {t}"},
                 "id": f"call_md_{t}", "type": "tool_call"}
                for t, _ in top_n
            ] + [{"name": "get_market_data", "args": {"query": "S&P 500 SPY average P/E ratio valuation"},
                  "id": "call_md_SPY", "type": "tool_call"}]
        else:
            per_ticker = [
                {"name": "get_market_data", "args": {"query": f"current price and analysis for {t}"},
                 "id": f"call_md_{t}", "type": "tool_call"}
                for t, _ in top_n
            ]
        return [
            {"name": name, "args": {"query": query}, "id": f"call_{name}", "type": "tool_call"}
            for name in selected if name != "get_market_data"
        ] + per_ticker

    return [
        {"name": name, "args": {"query": query}, "id": f"call_{name}", "type": "tool_call"}
        for name in selected
    ]


def smart_fanout_node(state: FinnieState) -> dict:
    """
    Route the latest user message to its relevant tools.

    Emits one AIMessage with no content whose ``tool_calls`` ToolNode then
    executes in parallel. The pipeline is:
      1. enrich the query with remembered holdings;
      2. ask the router LLM;
      3. apply the goal and portfolio overrides;
      4. expand to concrete tool calls.
    """
    human_msgs = [
        m for m in state.get("messages", [])
        if hasattr(m, "type") and m.type == "human"
    ]
    raw_query = str(human_msgs[-1].content) if human_msgs else ""
    holdings = state.get("portfolio_holdings") or {}

    query = _enrich_query(state, raw_query)
    ctx_note = _build_ctx_note(state)
    log.info("SmartFanOut | routing query=%r ctx=%s", query[:120], ctx_note[:80])

    routed = _select_tools(query, ctx_note)
    with_goal = _apply_goal_override(state, routed)
    selected = _apply_portfolio_override(with_goal, holdings)
    log.info("SmartFanOut | selected=%s | query=%r", selected, query[:60])

    tool_calls = _build_tool_calls(query, selected, holdings)
    return {"messages": [AIMessage(content="", tool_calls=tool_calls)]}


def synth_node(state: FinnieState) -> dict:
    """
    Merge the parallel tool results into a single final answer.

    Sends the synthesis system prompt plus the last 30 messages (counted per
    message and starting at a human message) to the LLM. Returns its reply as
    a ``{"messages": [...]}`` update.
    """
    window = trim_messages(
        state["messages"],
        strategy="last",
        token_counter=len,      # one "token" per message
        max_tokens=30,
        include_system=False,
        start_on="human",
    )
    prompt = [SystemMessage(content=_synth_prompt(state)), *window]
    log.debug("Synth | sending %d messages", len(prompt))
    reply = load_llm().invoke(prompt)
    log.info("Synth | answer_len=%d", len(str(reply.content)))
    return {"messages": [reply]}


def build_all_graph():
    """
    Build and compile the smart parallel fan-out graph.

    A straight pipeline, START -> param_extractor -> smart_fanout -> tools ->
    synth -> END, compiled against the shared ``_MEMORY`` checkpointer.
    """
    builder = StateGraph(FinnieState)

    nodes = (
        ("param_extractor", param_extractor_node),
        ("smart_fanout", smart_fanout_node),
        ("tools", ToolNode(TOOLS)),
        ("synth", synth_node),
    )
    for node_name, node in nodes:
        builder.add_node(node_name, node)

    # Wire each stage to the next one.
    pipeline = (START, "param_extractor", "smart_fanout", "tools", "synth", END)
    for src, dst in zip(pipeline, pipeline[1:]):
        builder.add_edge(src, dst)

    return builder.compile(checkpointer=_MEMORY)


_all_graph = None


def _get_all_graph():
    """Return the fan-out graph, building it on the first call and caching it module-wide."""
    global _all_graph
    _all_graph = _all_graph if _all_graph is not None else build_all_graph()
    return _all_graph


# ── Public API ────────────────────────────────────────────────────────────────

# Default values injected only on the first turn of a thread.
# Subsequent turns must NOT include these or LangGraph's LastValue channel
# will overwrite persisted checkpoint values (e.g. reset "aggressive" β†’ "moderate").
_FIRST_TURN_DEFAULTS = {
    "risk_profile":        "moderate",
    "current_savings":     None,
    "goal_amount":         None,
    "time_horizon_years":  None,
    "annual_contribution": None,
    "portfolio_holdings":  None,
    "portfolio_value":     None,
    "age":                 None,
}


def invoke(query: str, thread_id: str = "default", _initial: dict | None = None) -> dict:
    """
    Run one conversational turn through the Finnie ReAct workflow.

    Args:
        query:     The user's message.
        thread_id: Session ID -- all turns with the same ID share memory.
        _initial:  Internal. A pre-built graph input (``chat`` passes one).
                   When given it is used as-is and ``query`` is only logged.

    Returns:
        {
            "answer":    str,   content of the last message (the LLM's final answer)
            "messages":  list,  full updated message history
        }
    """
    cfg = {"configurable": {"thread_id": thread_id}}
    log.info("Invoke | thread=%s | query=%r", thread_id[:8], query[:80])

    graph_input = _initial
    if graph_input is None:
        graph_input = {"messages": [HumanMessage(content=query)]}
        # Seed defaults only for a brand-new thread; resending them later would
        # clobber values already saved in the checkpoint.
        if not _get_graph().checkpointer.get(cfg):
            graph_input.update(_FIRST_TURN_DEFAULTS)

    history = _get_graph().invoke(graph_input, config=cfg)["messages"]

    # The final message of the run is the LLM's answer.
    answer = str(history[-1].content)
    log.info("Invoke | thread=%s | messages=%d | answer_len=%d",
             thread_id[:8], len(history), len(answer))
    return {"answer": answer, "messages": history}


def invoke_all(query: str, thread_id: str = "default", _initial: dict | None = None) -> dict:
    """
    Run one conversational turn through the smart parallel fan-out workflow.

    The LLM first selects only the relevant agents; the graph then runs them in
    parallel and synthesizes the results into one answer.

    Args:
        query:     The user's message.
        thread_id: Session ID -- all turns with the same ID share memory.
        _initial:  Internal. A pre-built graph input (``chat`` passes one).
                   When given it is used as-is and ``query`` is only logged.

    Returns:
        {
            "answer":      str,   final synthesized answer; every "$" not
                                  already preceded by a backslash is escaped
            "messages":    list,  full updated message history
            "agents_used": list,  names of the tools that ran in this turn
        }
    """
    config = {"configurable": {"thread_id": thread_id}}
    log.info("InvokeAll | thread=%s | query=%r", thread_id[:8], query[:80])

    if _initial is not None:
        initial = _initial
    else:
        initial = {"messages": [HumanMessage(content=query)]}
        # Seed defaults only for a brand-new thread (see _FIRST_TURN_DEFAULTS).
        if not _get_all_graph().checkpointer.get(config):
            initial.update(_FIRST_TURN_DEFAULTS)

    result = _get_all_graph().invoke(initial, config=config)
    msgs = result["messages"]
    last = msgs[-1]

    # Collect tools called in THIS turn only (after the last HumanMessage).
    last_human_idx = next(
        (i for i in range(len(msgs) - 1, -1, -1)
         if hasattr(msgs[i], "type") and msgs[i].type == "human"),
        0,
    )
    agents_used = [
        m.name for m in msgs[last_human_idx:]
        if isinstance(m, ToolMessage) and hasattr(m, "name")
    ]

    # Escape bare $ so Streamlit/MathJax doesn't treat currency as LaTeX math.
    answer = re.sub(r"(?<!\\)\$", r"\\$", str(last.content))

    log.info("InvokeAll | thread=%s | agents=%s | answer_len=%d",
             thread_id[:8], agents_used, len(answer))
    return {
        "answer":      answer,
        "messages":    msgs,
        "agents_used": agents_used,
    }


# ── Queries that need sequential reasoning (tool B uses tool A's output) ──────
#
# For these, the ReAct graph is used so the LLM can chain tool calls.
# Everything else goes through the faster parallel fan-out graph.
#
# chat() checks each entry as a plain substring of the lower-cased query. So
# entries must be lower-case, and short ones such as "falls", "drops" or
# "dividend" also match inside longer words. Entry order does not affect routing.

_SEQUENTIAL_QUERIES = (
    "p/e", "pe ratio", "price-to-earnings",
    # "vs " keeps its trailing space, so "vs." or a query ending in "vs" does not match.
    "compare", "versus", "vs ",
    "rate hike", "rate sensitive", "interest rate", "vulnerable",
    "52-week", "52 week",
    "dividend",
    "at a loss", "trading at a loss", "tax benefit", "tax loss",
    "allocation", "too aggressive", "too conservative", "for my age",
    "which of my", "most vulnerable", "most exposed",
    "falls", "drops", "impact on my portfolio",
    "shares of", "retire after selling", "selling my shares",
    "if i sell", "when i sell", "want to sell", "going to sell",
    "sell all", "sell my", "planning to sell",
    "how much tax", "tax will i", "tax do i", "tax on selling",
    # Withdrawal / decumulation -- need sequential tool chaining
    "withdraw", "withdrawal", "drawdown", "draw down",
    "how long will", "how long would", "how long can",
    "live off", "live on my", "retirement income",
    "how much can i take", "how much can i spend",
    "recalculate", "calculate with",
    # Retirement + portfolio goal — analyze_portfolio must run first to feed portfolio_value
    # into plan_financial_goal; parallel fanout cannot do this ordering
    "retire with", "retire in", "want to retire",
    "save for retirement", "retirement goal",
)


def chat(query: str, thread_id: str = "default") -> dict:
    """
    Handle one conversational turn through the appropriate graph.

    A query containing any ``_SEQUENTIAL_QUERIES`` marker (one tool's output
    must feed the next) goes to the ReAct graph. Every other query goes to the
    parallel fan-out graph. Exactly one graph runs per turn.

    Returns:
        {
            "answer":      str,   final answer, $ signs escaped for Streamlit
            "messages":    list,  full message history
            "agents_used": list,  tool names called this turn
        }
    """
    config = {"configurable": {"thread_id": thread_id}}
    brand_new_thread = not _MEMORY.get(config)
    seed: dict = {"messages": [HumanMessage(content=query)]}
    if brand_new_thread:
        seed.update(_FIRST_TURN_DEFAULTS)

    lowered = query.lower()
    needs_chaining = any(marker in lowered for marker in _SEQUENTIAL_QUERIES)
    if not needs_chaining:
        return invoke_all(query, thread_id=thread_id, _initial=seed)

    result = invoke(query, thread_id=thread_id, _initial=seed)
    history = result["messages"]
    # Tool messages at or after the latest human message belong to this turn.
    turn_start = next(
        (idx for idx in reversed(range(len(history)))
         if hasattr(history[idx], "type") and history[idx].type == "human"),
        0,
    )
    tools_this_turn = [
        msg.name for msg in history[turn_start:]
        if isinstance(msg, ToolMessage) and hasattr(msg, "name")
    ]
    escaped_answer = re.sub(r"(?<!\\)\$", r"\\$", result["answer"])
    return {**result, "answer": escaped_answer, "agents_used": tools_this_turn}


# ── Smoke test ────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Manual smoke test: a three-turn conversation on a single thread. It
    # exercises portfolio/goal extraction, risk-profile memory across turns,
    # and a follow-up question that relies on the earlier context.
    tid = "demo"

    turns = [
        "I have 1000 AAPL, 500 MSFT, 300 GOOGL, 200 TSLA, and 800 NVDA stocks.\n"
        "I want to retire in 20 years with $2 million and I currently have $100K saved.\n"
        "Add the value of my portfolio to my savings and tell me if I'm on track to reach my goal.\n"
        "Also, what's the news on these stocks?",
        "I'm aggressive with risk.",
        "can I actually reach my $2M goal, and can you explain how compound interest works?",
    ]
    max_chars = 10000  # truncate very long answers in the console output

    for q in turns:
        print(f"\nUser  : {q}")
        r = invoke(q, thread_id=tid)
        answer = r["answer"]
        # Only show an ellipsis when the answer was actually cut off.
        suffix = "..." if len(answer) > max_chars else ""
        print(f"Finnie: {answer[:max_chars]}{suffix}")
        print("-" * 60)