Spaces:
Running on Zero
Running on Zero
| """Tool-using spending agent (Best Agent badge): plan + multi-step tool calls. | |
| Two layers, both grounded in core.analytics, both emitting a visible tool trace: | |
| 1. A deterministic **router** (`_fast_answer`) maps a clear, single-intent question | |
| straight to the right analytics tool — instant, no model round-trip, no chance | |
| of a hallucinated number. `run` uses it as the fallback when the ReAct loop | |
| does not produce a grounded result. | |
| 2. A **ReAct loop** drives MiniCPM4.1-8B (on Modal) and is what `run` tries first: | |
| it picks a tool, reads the result, optionally calls another, then answers. | |
| If the small model never produces a clean plan and the router finds no match, | |
| we fall back to the grounded one-shot answer. | |
| Either way the UI gets `trace:[{tool,args,result}]`, so judges can see the tools | |
| the agent used. Everything stays well under the 32B cap (1.3B vision + 8B text). | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import re | |
| from datetime import date | |
| from typing import Any | |
| from core import analytics, chat, inference | |
MAX_STEPS = 2  # fewer Modal round-trips -> faster
AGENT_TOKENS = 220  # passed as max_new_tokens for every agent generation
_MONTH_NAMES = ["January", "February", "March", "April", "May", "June", "July",
                "August", "September", "October", "November", "December"]
# Lowercase month lookup -> month number (1..12). Holds each full name, its
# 3-letter prefix, and the extra abbreviation "sept".
_MONTHS = {}
for _i, _full in enumerate(_MONTH_NAMES, start=1):
    _MONTHS[_full.lower()] = _i
    _MONTHS[_full.lower()[:3]] = _i
_MONTHS["sept"] = 9
# Synonyms -> canonical category name, so "food"/"petrol"/"cab" resolve.
# Keys are lowercase words or phrases as they may appear in a question; values
# are the category names they map to. Identity entries (e.g. "dining" ->
# "Dining") are listed too.
_CAT_SYNONYMS = {
    "food": "Dining", "restaurant": "Dining", "restaurants": "Dining",
    "eating out": "Dining", "dine": "Dining", "dining": "Dining",
    "grocery": "Groceries", "groceries": "Groceries",
    "coffee": "Cafe", "cafe": "Cafe", "tea": "Cafe",
    "fuel": "Fuel", "petrol": "Fuel", "diesel": "Fuel", "gas": "Fuel",
    "transport": "Transport", "cab": "Transport", "taxi": "Transport",
    "uber": "Transport", "ola": "Transport", "bus": "Transport", "metro": "Transport",
    "medicine": "Pharmacy", "medicines": "Pharmacy", "pharmacy": "Pharmacy",
    "health": "Health", "doctor": "Health", "hospital": "Health",
    "clothes": "Clothing", "clothing": "Clothing", "apparel": "Clothing",
    "rent": "Rent", "electricity": "Utilities", "water": "Utilities",
    "utility": "Utilities", "utilities": "Utilities", "bills": "Utilities",
    "shopping": "Shopping", "phone": "Telecom", "mobile": "Telecom",
    "recharge": "Telecom", "subscription": "Subscriptions",
    "subscriptions": "Subscriptions", "entertainment": "Entertainment",
    "movie": "Entertainment", "travel": "Travel", "education": "Education",
    "school": "Education", "fees": "Fees & Charges",
}
| # --------------------------------------------------------------------------- # | |
| # Period filtering + labels | |
| # --------------------------------------------------------------------------- # | |
def _period_records(records, period: str) -> list[dict]:
    """Filter records by a flexible period.

    Accepted periods: all | this_month | last_month | this_year | last_year |
    a specific month 'YYYY-MM' | a specific year 'YYYY'. Spaces are treated as
    underscores and case is ignored. A falsy period means "this_month"; any
    other unrecognised value is treated as last month if it contains "last",
    otherwise as this month.

    The period is coerced with ``str()`` first, so a numeric value coming from
    model-produced JSON (e.g. ``2026``) is handled like ``"2026"`` instead of
    raising. For the "all" family the input list itself is returned (not a
    copy); records whose date does not parse to a truthy value are dropped for
    every other period.
    """
    period = str(period or "this_month").strip().lower().replace(" ", "_")
    if period in ("all", "total", "overall", "alltime", "all_time", ""):
        return records
    today = date.today()

    def keep(pred):
        # Keep only records with a parseable date that satisfies `pred`.
        kept = []
        for rec in records:
            parsed = analytics.parse_date(rec.get("date"))
            if parsed and pred(parsed):
                kept.append(rec)
        return kept

    mo_match = re.match(r"^(\d{4})-(\d{1,2})$", period)
    if mo_match:
        y, mo = int(mo_match.group(1)), int(mo_match.group(2))
        return keep(lambda d: (d.year, d.month) == (y, mo))
    yr_match = re.match(r"^(\d{4})$", period)
    if yr_match:
        y = int(yr_match.group(1))
        return keep(lambda d: d.year == y)
    if "this_year" in period or period == "year":
        return keep(lambda d: d.year == today.year)
    if "last_year" in period:
        return keep(lambda d: d.year == today.year - 1)
    if "last" in period:  # last month (January rolls back to December)
        y, m = (today.year, today.month - 1) if today.month > 1 else (today.year - 1, 12)
    else:  # this month
        y, m = today.year, today.month
    return keep(lambda d: (d.year, d.month) == (y, m))
def _plabel(period: str) -> str:
    """Human-friendly label for a period, always stated back to the user.

    Normalisation mirrors `_period_records`: the value is stringified,
    lower-cased, and spaces become underscores. A falsy period is labelled
    "this month" because that is how `_period_records` filters it. 'YYYY-MM'
    becomes e.g. "July 2026" only when the month is 1..12; an out-of-range
    month is echoed back unchanged rather than mislabelled or raising. Anything
    unrecognised is returned with underscores turned into spaces.
    """
    p = str(period or "this_month").strip().lower().replace(" ", "_")
    fixed = {"all": "all time", "this_month": "this month", "last_month": "last month",
             "this_year": "this year", "last_year": "last year", "year": "this year"}
    if p in fixed:
        return fixed[p]
    m = re.match(r"^(\d{4})-(\d{1,2})$", p)
    if m:
        month = int(m.group(2))
        if 1 <= month <= 12:
            return f"{_MONTH_NAMES[month - 1]} {m.group(1)}"
        return p  # month 0 would otherwise index -1 and read "December"
    if re.match(r"^\d{4}$", p):
        return p
    return p.replace("_", " ")
def _cur(records) -> str:
    """Return the currency prefix for amounts: '<currency> ' or '' when none.

    The currency comes from `analytics.dominant_currency(records)`; a falsy
    result yields an empty prefix.
    """
    currency = analytics.dominant_currency(records)
    if not currency:
        return ""
    return f"{currency} "
| # --------------------------------------------------------------------------- # | |
| # Tools — each takes (records, **args) and returns a short observation string | |
| # --------------------------------------------------------------------------- # | |
def _t_total_spend(records, period="this_month", **_):
    """Tool: summed `total` and transaction count for `period`.

    Extra keyword arguments are accepted and ignored so a model-supplied
    argument dict can be splatted in safely.
    """
    scoped = _period_records(records, period)
    amount = sum(analytics._num(rec.get("total", 0)) for rec in scoped)
    label = _plabel(period)
    currency = _cur(records)
    return f"Total spend ({label}): {currency}{amount:,.2f} over {len(scoped)} transaction(s)."
def _t_category_spend(records, category="", period="this_month", **_):
    """Tool: summed `total` for one category within `period`.

    Matching is attempted in order and stops at the first that finds records:
      1. exact, case-insensitive category name;
      2. the canonical name from `_CAT_SYNONYMS` (so "food" finds "Dining");
      3. substring match of the requested text inside the category name.
    The category is stringified first so a non-string argument from the model
    cannot raise. The label in the reply echoes the caller's `category`.
    """
    recs = _period_records(records, period)
    cat = str(category or "").strip().lower()

    def with_category(name):
        return [r for r in recs if analytics._category(r).lower() == name]

    matched = with_category(cat)
    if not matched:
        canonical = _CAT_SYNONYMS.get(cat, "").lower()
        if canonical and canonical != cat:
            matched = with_category(canonical)
    if not matched and cat:  # last resort: fuzzy contains-match
        matched = [r for r in recs if cat in analytics._category(r).lower()]
    tot = sum(analytics._num(r.get("total", 0)) for r in matched)
    return f"{category or '?'} spend ({_plabel(period)}): {_cur(records)}{tot:,.2f} over {len(matched)} transaction(s)."
def _t_item_spend(records, query="", period="all", **_):
    """Tool: summed `amount` of line items whose name contains `query`.

    The match is a case-insensitive substring test against each line item's
    `name`; an empty query matches nothing. Records without `line_items` are
    skipped.
    """
    needle = (query or "").strip().lower()
    total = 0.0
    hits = 0
    for rec in _period_records(records, period):
        items = rec.get("line_items") or []
        for item in items:
            if not needle:
                continue
            if needle not in str(item.get("name", "")).lower():
                continue
            total += analytics._num(item.get("amount"))
            hits += 1
    return f"'{query}' ({_plabel(period)}): {_cur(records)}{total:,.2f} across {hits} item(s)."
def _t_vendor_spend(records, vendor="", period="all", **_):
    """Tool: summed `total` for records whose vendor contains `vendor`.

    The match is a case-insensitive substring test; an empty vendor matches
    nothing. A missing or None record vendor is treated as an empty string (not
    the text "None"), and the requested vendor is stringified so a non-string
    argument from the model cannot raise.
    """
    recs = _period_records(records, period)
    needle = str(vendor or "").strip().lower()
    matched = [r for r in recs
               if needle and needle in str(r.get("vendor") or "").lower()]
    tot = sum(analytics._num(r.get("total", 0)) for r in matched)
    return f"Spend at '{vendor}' ({_plabel(period)}): {_cur(records)}{tot:,.2f} over {len(matched)} transaction(s)."
def _t_top_categories(records, period="this_month", **_):
    """Tool: the first five rows of `analytics.spend_by_category` for `period`.

    Each row is expected to expose `category` and `amount` keys; row ordering
    is whatever the analytics helper returns.
    """
    scoped = _period_records(records, period)
    top = analytics.spend_by_category(scoped)[:5]
    if not top:
        return f"No spending found ({_plabel(period)})."
    header = f"Top categories ({_plabel(period)}): "
    parts = [f"{row['category']} {_cur(records)}{row['amount']:,.2f}" for row in top]
    return header + ", ".join(parts)
# Rank -> label prefix used by `_t_biggest_expense`; other ranks get "#N biggest".
_ORD = {
    1: "Biggest",
    2: "2nd biggest",
    3: "3rd biggest",
    4: "4th biggest",
    5: "5th biggest",
}


def _fmt_txn(r, cur):
    """Render one record as '<cur><total> at <vendor> on <date> (<category>)'.

    A falsy vendor is shown as '(no name)' and a falsy date as '?'.
    """
    amount = analytics._num(r.get("total"))
    vendor = r.get("vendor") or "(no name)"
    when = r.get("date") or "?"
    category = analytics._category(r)
    return f"{cur}{amount:,.2f} at {vendor} on {when} ({category})"
def _t_biggest_expense(records, period="all", rank=1, **_):
    """Tool: the `rank`-th largest transaction (by `total`) within `period`.

    `rank` is coerced to an int of at least 1; anything that cannot be coerced
    falls back to 1. A rank beyond the number of transactions yields an
    explanatory message instead of a transaction.
    """
    ranked = sorted(
        _period_records(records, period),
        key=lambda rec: analytics._num(rec.get("total")),
        reverse=True,
    )
    if not ranked:
        return f"No transactions ({_plabel(period)})."
    try:
        rank = int(rank)
        if rank < 1:
            rank = 1
    except Exception:
        rank = 1
    if rank > len(ranked):
        return f"There are only {len(ranked)} transaction(s) ({_plabel(period)}), no #{rank}."
    if rank in _ORD:
        label = _ORD[rank]
    else:
        label = f"#{rank} biggest"
    chosen = ranked[rank - 1]
    return f"{label} expense ({_plabel(period)}): {_fmt_txn(chosen, _cur(records))}."
def _t_smallest_expense(records, period="all", **_):
    """Tool: the transaction with the smallest positive `total` in `period`.

    Records whose total is not greater than zero are ignored.
    """
    def amount(rec):
        return analytics._num(rec.get("total"))

    positive = [rec for rec in _period_records(records, period) if amount(rec) > 0]
    if not positive:
        return f"No transactions ({_plabel(period)})."
    cheapest = min(positive, key=amount)
    return f"Smallest expense ({_plabel(period)}): {_fmt_txn(cheapest, _cur(records))}."
def _t_average_monthly(records, **_):
    """Tool: mean of the monthly amounts from `analytics.spend_over_time`.

    The average is taken over the rows that helper returns; no period filter
    is applied.
    """
    monthly = analytics.spend_over_time(records, "Monthly")
    if not monthly:
        return "No dated transactions yet."
    months = len(monthly)
    mean = sum(row["amount"] for row in monthly) / months
    return f"Average monthly spend: {_cur(records)}{mean:,.2f} across {months} month(s) with spending."
def _t_average_spend(records, period="all", **_):
    """Tool: mean `total` per transaction within `period` (0.00 when empty)."""
    scoped = _period_records(records, period)
    count = len(scoped)
    total = sum(analytics._num(rec.get("total", 0)) for rec in scoped)
    if count:
        mean = total / count
    else:
        mean = 0.0
    label = _plabel(period)
    currency = _cur(records)
    return (f"Average transaction ({label}): {currency}{mean:,.2f} "
            f"across {count} transaction(s) totalling {currency}{total:,.2f}.")
def _t_count_transactions(records, period="all", **_):
    """Tool: number of transactions in `period` and their summed `total`."""
    scoped = _period_records(records, period)
    amount = sum(analytics._num(rec.get("total", 0)) for rec in scoped)
    count = len(scoped)
    return f"{count} transaction(s) ({_plabel(period)}), totalling {_cur(records)}{amount:,.2f}."
def _t_budget_status(records, budget=0.0, **_):
    """Tool: compare this month's spend with the monthly `budget`.

    A budget of zero or less is reported as "not set" together with the spend
    so far; otherwise the reply states the amount left or the overshoot.
    """
    this_month = _period_records(records, "this_month")
    spent = sum(analytics._num(rec.get("total", 0)) for rec in this_month)
    cur = _cur(records)
    budget = analytics._num(budget)
    if budget <= 0:
        return f"No monthly budget set yet. This month you've spent {cur}{spent:,.2f}."
    left = budget - spent
    if left < 0:
        return f"Budget this month: {cur}{budget:,.2f}. Spent {cur}{spent:,.2f}. You are {cur}{abs(left):,.2f} over budget."
    return f"Budget this month: {cur}{budget:,.2f}. Spent {cur}{spent:,.2f}. You have {cur}{left:,.2f} left."
def _t_monthly_trend(records, **_):
    """Tool: the last six rows of the monthly series as 'period=amount' pairs.

    Rows come from `analytics.spend_over_time(records, "Monthly")`; amounts
    are printed without a currency prefix.
    """
    series = analytics.spend_over_time(records, "Monthly")
    if not series:
        return "No dated transactions yet."
    latest = series[-6:]
    pairs = [f"{row['period']}={row['amount']:,.2f}" for row in latest]
    return "Monthly totals: " + ", ".join(pairs)
def _t_recent(records, n=5, **_):
    """Tool: the first `n` rows of `analytics.transactions_table`.

    `n` is coerced to int (falling back to 5) and clamped to 1..12. Each row is
    unpacked as four fields and printed as '<1st> <2nd> <3rd:,.2f> (<4th>)'.
    """
    try:
        n = int(n)
    except Exception:
        n = 5
    limit = min(n, 12)
    if limit < 1:
        limit = 1
    rows = analytics.transactions_table(records)[:limit]
    if not rows:
        return "No transactions yet."
    entries = [f"{d} {v} {t:,.2f} ({c})" for d, v, t, c in rows]
    return "Recent: " + "; ".join(entries)
# Registry of tool name -> implementation. The names are the ones advertised
# in SYSTEM_PROMPT, accepted by `_parse_action`, and dispatched by `_run_tool`.
# Every implementation is called as fn(records, **args) and returns a string.
TOOLS = {
    "total_spend": _t_total_spend,
    "category_spend": _t_category_spend,
    "item_spend": _t_item_spend,
    "vendor_spend": _t_vendor_spend,
    "top_categories": _t_top_categories,
    "biggest_expense": _t_biggest_expense,
    "smallest_expense": _t_smallest_expense,
    "average_spend": _t_average_spend,
    "average_monthly": _t_average_monthly,
    "count_transactions": _t_count_transactions,
    "budget_status": _t_budget_status,
    "monthly_trend": _t_monthly_trend,
    "recent": _t_recent,
}
# Ordinal words/abbreviations -> rank, used by `_fast_answer` to read
# "second biggest", "3rd largest", etc. Other numeric ordinals ("7th") are
# parsed from their digits there.
_ORDINALS = {"first": 1, "1st": 1, "second": 2, "2nd": 2, "third": 3, "3rd": 3,
             "fourth": 4, "4th": 4, "fifth": 5, "5th": 5}
# System message for the ReAct loop in `_react`. It defines the protocol that
# `_parse_action` understands: one "CALL <tool>(<json args>)" line per step and
# a final "ANSWER: ..." line. The tool list must stay in sync with TOOLS.
SYSTEM_PROMPT = (
    "You are BudgetBuddy's spending agent. Answer the user's question about THEIR "
    "spending by planning and calling tools one at a time, then giving a final "
    "answer with the exact numbers from the tools.\n"
    "period can be: 'this_month', 'last_month', 'this_year', 'last_year', 'all', "
    "a SPECIFIC month like '2026-07' (for 'July 2026'), or a year like '2026'.\n"
    "Tools:\n"
    "- total_spend({\"period\": ...})\n"
    "- category_spend({\"category\": <name>, \"period\": ...})\n"
    "- item_spend({\"query\": <item name, e.g. 'misal pav'>, \"period\": ...})\n"
    "- vendor_spend({\"vendor\": <shop/payee name>, \"period\": ...})\n"
    "- top_categories({\"period\": ...})\n"
    "- biggest_expense({\"period\": ..., \"rank\": 1}) // rank 2 = 2nd most expensive\n"
    "- smallest_expense({\"period\": ...})\n"
    "- average_spend({\"period\": ...}) // average per transaction\n"
    "- average_monthly({}) // average spend per month\n"
    "- count_transactions({\"period\": ...})\n"
    "- budget_status({})\n"
    "- monthly_trend({})\n"
    "- recent({\"n\": 5})\n"
    "To call a tool, reply with EXACTLY one line:\n"
    "CALL <tool>(<json args>)\n"
    "Example: CALL category_spend({\"category\": \"Groceries\", \"period\": \"this_month\"})\n"
    "After you get a line starting with TOOL_RESULT, either call another tool or "
    "finish with:\nANSWER: <concise answer using the numbers>\n"
    "Always finish with an ANSWER line. Do not invent numbers."
)
# Canned reply returned by `_greeting_reply` for greetings and "what can you
# do"-style questions; no model call is involved.
HELP_TEXT = (
    "I'm your spending assistant — I read your saved transactions and answer with real numbers. "
    "Try things like:\n"
    "• \"How much did I spend this month?\" or \"…in July 2026?\"\n"
    "• \"What's my top category?\" / \"Where did my money go?\"\n"
    "• \"How much on Groceries last month?\"\n"
    "• \"What's my biggest expense?\" • \"What's my average spend?\"\n"
    "• \"How many transactions this year?\" • \"Am I over budget?\""
)
| # --------------------------------------------------------------------------- # | |
| # Deterministic router (fast, reliable, still emits a tool trace) | |
| # --------------------------------------------------------------------------- # | |
def _parse_period(ql: str):
    """Resolve a time phrase to a canonical period, or None if unspecified.

    `ql` is the lower-cased question. Explicit periods are checked first
    (last/this month, last/this year, 'YYYY-MM', a month name with optional
    year, a bare year) and the generic all-time phrases last, so that
    "so far this month" or "in total in July" keep their explicit period
    instead of collapsing to "all". A month name without a year is taken to be
    in the current year.
    """
    if re.search(r"\b(last|previous|prev)\s+month\b", ql):
        return "last_month"
    if re.search(r"\bthis\s+month\b|\bcurrent\s+month\b|\bthis\s+mth\b", ql):
        return "this_month"
    if re.search(r"\b(last|previous|prev)\s+year\b", ql):
        return "last_year"
    if re.search(r"\bthis\s+year\b|\bcurrent\s+year\b", ql):
        return "this_year"
    ym = re.search(r"\b(20\d{2})-(\d{1,2})\b", ql)
    if ym:
        return f"{ym.group(1)}-{int(ym.group(2)):02d}"
    yr = re.search(r"\b(20\d{2})\b", ql)
    mon = re.search(r"\b(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|may|jun(?:e)?|"
                    r"jul(?:y)?|aug(?:ust)?|sep(?:t)?(?:ember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\b", ql)
    if mon:
        # Try the matched text as-is, then its 3-letter prefix.
        mo = _MONTHS.get(mon.group(1)) or _MONTHS.get(mon.group(1)[:3])
        if mo:
            year = int(yr.group(1)) if yr else date.today().year
            return f"{year}-{mo:02d}"
    if yr:
        return yr.group(1)
    # Only when nothing more specific was named do generic phrases mean "all".
    if re.search(r"\b(all[\s-]?time|overall|in total|altogether|ever|so far|to date|lifetime)\b", ql):
        return "all"
    return None
def _detect_category(ql: str):
    """Return the category named in the question, or None.

    Canonical names from `core.categorize.CATEGORIES` are tried first (whole
    word, lower-cased), then the keys of `_CAT_SYNONYMS`; the first hit wins.
    """
    from core.categorize import CATEGORIES

    def mentioned(term):
        return re.search(r"\b" + re.escape(term) + r"\b", ql)

    direct = next((name for name in CATEGORIES if mentioned(name.lower())), None)
    if direct is not None:
        return direct
    return next((target for alias, target in _CAT_SYNONYMS.items() if mentioned(alias)), None)
def _detect_vendor(ql: str, records):
    """Match a known vendor name appearing in the question (longest first).

    Vendor names are collected from `records` (names shorter than 3 characters
    are ignored) and compared case-insensitively against the lower-cased
    question `ql`. Returns the vendor as spelled in the records, or None.

    A missing/None vendor is skipped rather than being registered as the text
    "None". Names are anchored with "not adjacent to a word character"
    lookarounds instead of ``\\b`` so that vendors starting or ending with
    punctuation (e.g. "Cafe (Main)") can still match.
    """
    seen = {}
    for r in records:
        v = str(r.get("vendor") or "").strip()
        if len(v) >= 3:
            seen[v.lower()] = v
    for low in sorted(seen, key=len, reverse=True):
        if re.search(r"(?<!\w)" + re.escape(low) + r"(?!\w)", ql):
            return seen[low]
    return None
def _greeting_reply(question: str):
    """Instant reply for greetings / 'what can you do' — no model call.

    Returns HELP_TEXT when the whole question is a greeting, asks what the
    assistant does, or is a bare help keyword; otherwise None.
    """
    bare = question.lower().strip()
    padded = " " + bare + " "
    is_greeting = re.fullmatch(
        r"(hi|hello|hey|yo|hiya|howdy|sup|namaste|hi there|hey there)[\s!.,]*", bare)
    asks_capabilities = re.search(
        r"\bwhat can you (do|help)\b|\bhow do you work\b|\bwho are you\b|\bwhat do you do\b", padded)
    wants_menu = bare in ("help", "?", "??", "menu", "options")
    if is_greeting or asks_capabilities or wants_menu:
        return HELP_TEXT
    return None
def _fast_answer(question: str, records, budget):
    """Deterministic fallback: map a clear single-intent question to the exact
    tool, else None. Used only when the LLM agent can't produce a grounded plan.

    Returns ``(observation, trace)`` where ``trace`` is a one-element list of
    ``{"tool", "args", "result"}``, or None when no rule matches. Rules are
    tried in a fixed order and the first match wins; when the question names
    no period, period-aware tools are run over "all".

    "Average" questions go to `average_monthly` only for per-month phrasing
    ("monthly", "per month", ...) or a bare "month" with no explicit period.
    "average spend this month" therefore uses `average_spend` for that period
    rather than the period-less monthly average.
    """
    ql = " " + question.lower().strip() + " "
    period = _parse_period(ql)
    P = period or "all"

    def fire(tool, args, fn):
        # Run the tool and package its observation with a one-step trace.
        obs = fn()
        return obs, [{"tool": tool, "args": args, "result": obs}]

    if re.search(r"\bbudget\b|\bover\s?spend|under\s?spend|can i (afford|spend)\b", ql):
        return fire("budget_status", {}, lambda: _t_budget_status(records, budget=budget))
    if re.search(r"\bhow many\b.*\b(transaction|bill|receipt|purchase|entr|payment)|number of (transaction|bill|receipt|purchase)", ql):
        return fire("count_transactions", {"period": P}, lambda: _t_count_transactions(records, period=P))
    if re.search(r"\b(smallest|cheapest|least expensive|lowest|min(?:imum)?)\b", ql) and "categor" not in ql:
        return fire("smallest_expense", {"period": P}, lambda: _t_smallest_expense(records, period=P))
    if re.search(r"\b(biggest|largest|highest|most expensive|priciest|max(?:imum)?|single largest|top spend|dearest)\b", ql) and "categor" not in ql:
        rank = 1
        om = re.search(r"\b(first|second|third|fourth|fifth|\d+(?:st|nd|rd|th))\b", ql)
        if om:
            g = om.group(1)
            # Named ordinals come from the table; "7th" etc. from their digits.
            rank = _ORDINALS.get(g) or int(re.sub(r"\D", "", g) or 1)
        return fire("biggest_expense", {"period": P, "rank": rank},
                    lambda: _t_biggest_expense(records, period=P, rank=rank))
    if re.search(r"\b(average|avg|mean|typical|on average)\b", ql):
        per_month = re.search(r"\b(monthly|per month|a month|each month)\b", ql)
        # A bare "month" only implies the monthly average when it is not part
        # of an explicit period such as "this month" / "last month".
        bare_month = period is None and re.search(r"\bmonth\b", ql)
        if per_month or bare_month:
            return fire("average_monthly", {}, lambda: _t_average_monthly(records))
        return fire("average_spend", {"period": P}, lambda: _t_average_spend(records, period=P))
    if re.search(r"\b(trend|over time|by month|each month|month by month|monthly|per month|month[\s-]?wise)\b", ql):
        return fire("monthly_trend", {}, lambda: _t_monthly_trend(records))
    if re.search(r"\b(recent|latest|last few|show me|list).*(transaction|spend|purchase|bought|buy)|what did i (buy|spend on) (recently|lately)", ql):
        return fire("recent", {"n": 5}, lambda: _t_recent(records, n=5))
    if re.search(r"\b(top categor|biggest categor|main categor|where.*(money|spend|spent).*(go|went)|"
                 r"spend most|most on|breakdown|by category|categor(y|ies))\b", ql):
        return fire("top_categories", {"period": P}, lambda: _t_top_categories(records, period=P))
    cat = _detect_category(ql)
    if cat and re.search(r"\b(spend|spent|spending|cost|how much|total|pay|paid)\b", ql):
        return fire("category_spend", {"category": cat, "period": P},
                    lambda: _t_category_spend(records, category=cat, period=P))
    ven = _detect_vendor(ql, records)
    if ven and re.search(r"\b(spend|spent|how much|total|pay|paid|at|from)\b", ql):
        return fire("vendor_spend", {"vendor": ven, "period": P},
                    lambda: _t_vendor_spend(records, vendor=ven, period=P))
    if re.search(r"\b(how much|total|spend|spent|spending)\b", ql):
        return fire("total_spend", {"period": P}, lambda: _t_total_spend(records, period=P))
    return None
| # --------------------------------------------------------------------------- # | |
| # Parsing the model's action (ReAct loop) | |
| # --------------------------------------------------------------------------- # | |
# "CALL <name>(<args>)": group 1 is the tool name, group 2 the raw argument
# text. `.*` is greedy and DOTALL, so group 2 runs to the LAST ")" in the text.
_CALL_RE = re.compile(r"CALL\s+(\w+)\s*\((.*)\)", re.S)
# A known tool name followed by "(...)" without the CALL keyword; the argument
# capture here is non-greedy and stops at the first ")".
_BARE_RE = re.compile(r"\b(" + "|".join(TOOLS) + r")\b\s*\((.*?)\)", re.S)
# Period keywords recognised as a bare positional token by `_loose_args`.
_PERIODS = ("this_month", "last_month", "this_year", "last_year", "all", "year")


def _loose_args(raw: str) -> dict:
    """Best-effort parse of a tool-call argument string into a kwargs dict.

    Valid JSON objects are returned as-is. Any other input is read token by
    token (a JSON list/scalar contributes its element(s); otherwise the text is
    split on commas):
      * a period keyword, 'YYYY-MM' or a 4-digit year  -> ``period``
      * any other all-digit token                      -> ``n``
      * ``key: value`` / ``key=value``                 -> that key
      * anything else                                  -> ``category`` (first wins)

    Always returns a dict. Valid non-object JSON such as ``"this_month"`` or
    ``5`` is interpreted through the token rules instead of being returned raw.
    """
    raw = raw.strip().strip("()")
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        tokens = raw.split(",")
    else:
        if isinstance(parsed, dict):
            return parsed
        if parsed is None:
            return {}
        values = parsed if isinstance(parsed, list) else [parsed]
        tokens = [str(v) for v in values]
    args: dict[str, Any] = {}
    for tok in tokens:
        tok = tok.strip().strip('"\'')
        if not tok:
            continue
        low = tok.lower().replace(" ", "_")
        if low in _PERIODS or re.match(r"^\d{4}-\d{1,2}$", low):
            args["period"] = low
        elif re.match(r"^(19|20)\d{2}$", tok):  # a 4-digit year is a period, not a count
            args["period"] = tok
        elif tok.isdigit():
            args["n"] = int(tok)
        elif ":" in tok or "=" in tok:
            parts = re.split(r"[:=]", tok, maxsplit=1)
            if len(parts) == 2:
                args[parts[0].strip().strip('"\'')] = parts[1].strip().strip('"\'')
        else:
            args.setdefault("category", tok)
    return args
def _parse_action(text: str) -> dict:
    """Classify one model reply as an answer, a tool call, or unknown.

    Returns one of:
      * ``{"type": "answer", "text": ...}`` when the reply contains "ANSWER:";
      * ``{"type": "call", "name": ..., "args": {...}}`` for a call to a tool
        in TOOLS, written either as "CALL name(...)" or as a bare "name(...)";
      * ``{"type": "unknown", "text": ...}`` otherwise.

    Arguments that start with "{" are decoded as the leading JSON object only,
    so prose the model appended after the call (which the greedy CALL pattern
    may capture) does not corrupt them. Anything else goes through
    `_loose_args`. ``args`` is always a dict.
    """
    t = (text or "").strip()
    m = re.search(r"ANSWER:\s*(.+)", t, re.S)
    if m:
        return {"type": "answer", "text": m.group(1).strip()}
    # Prefer the explicit CALL form, but still accept a bare call when the
    # CALL form is absent or names something that is not a tool.
    for pattern in (_CALL_RE, _BARE_RE):
        m = pattern.search(t)
        if not (m and m.group(1) in TOOLS):
            continue
        raw = m.group(2).strip()
        args = None
        if raw.startswith("{"):
            try:
                args, _end = json.JSONDecoder().raw_decode(raw)
            except ValueError:
                args = None
        if not isinstance(args, dict):
            try:
                args = _loose_args(raw)
            except Exception:
                args = {}
        return {"type": "call", "name": m.group(1),
                "args": args if isinstance(args, dict) else {}}
    return {"type": "unknown", "text": t}
def _run_tool(name: str, records, args: dict, budget: float = 0.0) -> str:
    """Invoke the tool registered under `name` and return its observation.

    `args` is copied before use. For `budget_status` the caller's `budget` is
    supplied unless the arguments already contain one. Any exception (unknown
    tool, bad arguments, failure inside the tool) is turned into a
    "(tool error: ...)" string instead of propagating.
    """
    try:
        kwargs = dict(args or {})
        if name == "budget_status" and "budget" not in kwargs:
            kwargs["budget"] = budget
        handler = TOOLS[name]
        return handler(records, **kwargs)
    except Exception as e:  # pragma: no cover
        return f"(tool error: {e})"
| # --------------------------------------------------------------------------- # | |
| # Agent entry point | |
| # --------------------------------------------------------------------------- # | |
def _react(question: str, records, budget: float, max_steps: int):
    """The real agent: the 8B plans, calls a tool, reads the result, optionally
    calls another, then answers. Returns {reply, trace} ONLY when the answer is
    grounded by at least one tool call; otherwise None so the caller can fall back
    (we never surface an ungrounded, possibly hallucinated answer).

    A tool call that failed (its observation is a "(tool error: ...)" string
    from `_run_tool`) is still recorded in the trace and shown to the model,
    but it does not count as grounding and is never used as the reply. If
    every call failed the function returns None so the caller's deterministic
    fallback can answer instead.

    At most `max_steps` model generations are made. An unparseable reply ends
    the loop early.
    """
    transcript = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Today is {date.today().isoformat()}.\nQuestion: {question}\n"
                                    "Plan and call the first tool (use a specific 'YYYY-MM' period for a named month)."},
    ]
    trace: list[dict] = []
    last_obs = None  # most recent SUCCESSFUL tool observation
    for _ in range(max_steps):
        out = inference.text_generate(transcript, max_new_tokens=AGENT_TOKENS)
        action = _parse_action(out)
        if action["type"] == "answer":
            # Only trust an answer the model reached AFTER a tool actually worked.
            if last_obs is not None and action["text"]:
                return {"reply": action["text"], "trace": trace}
            return None
        if action["type"] != "call":
            break
        obs = _run_tool(action["name"], records, action["args"], budget)
        trace.append({"tool": action["name"], "args": action["args"], "result": obs})
        if not str(obs).startswith("(tool error:"):
            last_obs = obs
        transcript.append({"role": "assistant", "content": out})
        transcript.append({
            "role": "user",
            "content": f"TOOL_RESULT: {obs}\nNow either CALL another tool or give "
                       "ANSWER: <one sentence using the exact numbers above>.",
        })
    # Tools ran but no clean ANSWER line — return the last exact tool result
    # (still grounded, still correct) with the trace.
    return {"reply": last_obs, "trace": trace} if last_obs is not None else None
def run(question: str, records: list[dict], max_steps: int = MAX_STEPS,
        budget: float = 0.0) -> dict:
    """Answer a spending question. Returns {reply, trace:[{tool,args,result}]}.

    Order of attempts: canned replies for an empty question, an empty record
    list, or a greeting; then the model-planned ReAct agent; then the
    deterministic tool router; finally the one-shot `chat.answer`. A failure
    in the agent or the router is printed and the next layer is tried.
    """
    question = (question or "").strip()
    if not question:
        return {"reply": "Ask me anything about your spending.", "trace": []}
    if not records:
        return {"reply": "You have no saved transactions yet — add a few and ask again.",
                "trace": []}

    canned = _greeting_reply(question)
    if canned:
        return {"reply": canned, "trace": []}

    # PRIMARY: the AI agent plans and calls the tools.
    try:
        planned = _react(question, records, budget, max_steps)
    except Exception as e:  # pragma: no cover - model/runtime dependent
        print(f"[agent] react failed: {e}")
        planned = None
    if planned is not None:
        return planned

    # FALLBACK 1: deterministic router — guarantees a correct answer.
    try:
        routed = _fast_answer(question, records, budget)
        if routed is not None:
            return {"reply": routed[0], "trace": routed[1]}
    except Exception as e:  # pragma: no cover
        print(f"[agent] fallback router failed: {e}")

    # FALLBACK 2: grounded one-shot answer.
    return {"reply": chat.answer(question, records), "trace": []}