# BudgetBuddy / core/agent.py
# KrishnaGarg's picture
# Deploy BudgetBuddy update
# 724a5ef verified
# Raw
# History Blame Contribute Delete
# 25.2 kB
"""Tool-using spending agent (Best Agent badge): plan + multi-step tool calls.
Two layers, both grounded in core.analytics, both emitting a visible tool trace:
1. A deterministic **router** (`_fast_answer`) maps a clear, single-intent question
straight to the right analytics tool — instant, no model round-trip, no chance
of a hallucinated number. It is the fallback, used when the model agent below
does not return a grounded answer.
2. The primary path is a **ReAct loop** that drives MiniCPM4.1-8B
(on Modal): it picks a tool, reads the result, optionally calls another, then
answers. If the small model never produces a grounded plan, the router above
takes over, and failing that we fall back to the grounded one-shot answer.
Either way the UI gets `trace:[{tool,args,result}]`, so judges can see the tools
the agent used. Everything stays well under the 32B cap (1.3B vision + 8B text).
"""
from __future__ import annotations
import json
import re
from datetime import date
from typing import Any
from core import analytics, chat, inference
MAX_STEPS: int = 2  # model turns per question in the ReAct loop (fewer Modal round-trips -> faster)
AGENT_TOKENS: int = 220  # max_new_tokens for each model turn
_MONTH_NAMES = [
    "January", "February", "March", "April", "May", "June", "July",
    "August", "September", "October", "November", "December",
]
# Lower-case month lookup: full name and 3-letter abbreviation -> month number
# (1-12), plus the common "sept" spelling for September.
_MONTHS = {
    key: number
    for number, name in enumerate(_MONTH_NAMES, start=1)
    for key in (name.lower(), name.lower()[:3])
}
_MONTHS["sept"] = 9
# Synonyms -> canonical category name, so "food"/"petrol"/"cab" resolve.
# Declared per category and flattened in order; _detect_category returns the
# first synonym found, so the listing order below is significant.
_CAT_SYNONYMS = {
    synonym: canonical
    for canonical, synonyms in (
        ("Dining", ("food", "restaurant", "restaurants", "eating out", "dine", "dining")),
        ("Groceries", ("grocery", "groceries")),
        ("Cafe", ("coffee", "cafe", "tea")),
        ("Fuel", ("fuel", "petrol", "diesel", "gas")),
        ("Transport", ("transport", "cab", "taxi", "uber", "ola", "bus", "metro")),
        ("Pharmacy", ("medicine", "medicines", "pharmacy")),
        ("Health", ("health", "doctor", "hospital")),
        ("Clothing", ("clothes", "clothing", "apparel")),
        ("Rent", ("rent",)),
        ("Utilities", ("electricity", "water", "utility", "utilities", "bills")),
        ("Shopping", ("shopping",)),
        ("Telecom", ("phone", "mobile", "recharge")),
        ("Subscriptions", ("subscription", "subscriptions")),
        ("Entertainment", ("entertainment", "movie")),
        ("Travel", ("travel",)),
        ("Education", ("education", "school")),
        ("Fees & Charges", ("fees",)),
    )
    for synonym in synonyms
}
# --------------------------------------------------------------------------- #
# Period filtering + labels
# --------------------------------------------------------------------------- #
def _period_records(records, period: str) -> list[dict]:
    """Return the subset of ``records`` whose date falls inside ``period``.

    Accepted periods (case-insensitive, spaces treated as underscores):
    all/total/overall/alltime/all_time (returns ``records`` itself), a month
    'YYYY-MM', a year 'YYYY', this_year/year, last_year. Anything else means
    last month when it contains "last", otherwise this month; an empty or
    ``None`` period defaults to this month. Records whose date cannot be parsed
    by ``analytics.parse_date`` are dropped from every filtered result.
    """
    key = (period or "this_month").strip().lower().replace(" ", "_")
    if key in ("all", "total", "overall", "alltime", "all_time", ""):
        return records

    def _where(test):
        # Keep records with a parseable date that satisfies ``test``.
        kept = []
        for rec in records:
            when = analytics.parse_date(rec.get("date"))
            if when and test(when):
                kept.append(rec)
        return kept

    today = date.today()

    month_hit = re.match(r"^(\d{4})-(\d{1,2})$", key)
    if month_hit:
        want_year, want_month = int(month_hit.group(1)), int(month_hit.group(2))
        return _where(lambda d: (d.year, d.month) == (want_year, want_month))

    year_hit = re.match(r"^(\d{4})$", key)
    if year_hit:
        want_year = int(year_hit.group(1))
        return _where(lambda d: d.year == want_year)

    if "this_year" in key or key == "year":
        return _where(lambda d: d.year == today.year)
    if "last_year" in key:
        return _where(lambda d: d.year == today.year - 1)

    if "last" in key:
        # Previous calendar month, wrapping January back to last December.
        if today.month == 1:
            want_year, want_month = today.year - 1, 12
        else:
            want_year, want_month = today.year, today.month - 1
    else:
        want_year, want_month = today.year, today.month
    return _where(lambda d: (d.year, d.month) == (want_year, want_month))
def _plabel(period: str) -> str:
    """Human-friendly label for a period, always stated back to the user.

    Normalises ``period`` exactly as ``_period_records`` does (empty/None ->
    this month, spaces -> underscores, unrecognised phrases -> last month if
    they contain "last", else this month), so the label names the window the
    numbers were actually computed over.
    """
    p = (period or "this_month").strip().lower().replace(" ", "_")
    if p in ("all", "total", "overall", "alltime", "all_time", ""):
        return "all time"
    m = re.match(r"^(\d{4})-(\d{1,2})$", p)
    if m:
        month = int(m.group(2))
        # An out-of-range month (e.g. "2026-13") matches no records; echo it
        # back instead of indexing past the table or wrapping "-00" to December.
        if 1 <= month <= 12:
            return f"{_MONTH_NAMES[month - 1]} {m.group(1)}"
        return p
    if re.match(r"^\d{4}$", p):
        return p
    if "this_year" in p or p == "year":
        return "this year"
    if "last_year" in p:
        return "last year"
    if "last" in p:
        return "last month"
    return "this month"
def _cur(records) -> str:
    """Prefix for formatted amounts: the dominant currency plus a space, or "" if none."""
    code = analytics.dominant_currency(records)
    if not code:
        return ""
    return f"{code} "
# --------------------------------------------------------------------------- #
# Tools — each takes (records, **args) and returns a short observation string
# --------------------------------------------------------------------------- #
def _t_total_spend(records, period="this_month", **_):
    """Tool: total amount spent in ``period`` and how many transactions it covers."""
    in_period = _period_records(records, period)
    amount = sum(analytics._num(rec.get("total", 0)) for rec in in_period)
    label = _plabel(period)
    prefix = _cur(records)
    return f"Total spend ({label}): {prefix}{amount:,.2f} over {len(in_period)} transaction(s)."
def _t_category_spend(records, category="", period="this_month", **_):
    """Tool: amount spent in one category during ``period``.

    Case-insensitive matching is tried in order, stopping at the first that
    finds transactions:
      1. exact category name;
      2. the canonical category of a known synonym in ``_CAT_SYNONYMS``
         (e.g. "food" -> "Dining"), so model-chosen everyday words resolve;
      3. substring match of ``category`` inside the category name.
    """
    recs = _period_records(records, period)
    cat = (category or "").strip().lower()

    def _with_category(test):
        # Records whose lower-cased category satisfies ``test``.
        return [r for r in recs if test(analytics._category(r).lower())]

    matched = _with_category(lambda c: c == cat)
    if not matched:
        canonical = _CAT_SYNONYMS.get(cat, "").lower()
        if canonical and canonical != cat:
            matched = _with_category(lambda c: c == canonical)
    if not matched:  # fuzzy contains-match as a last resort
        matched = _with_category(lambda c: bool(cat) and cat in c)
    tot = sum(analytics._num(r.get("total", 0)) for r in matched)
    return f"{category or '?'} spend ({_plabel(period)}): {_cur(records)}{tot:,.2f} over {len(matched)} transaction(s)."
def _t_item_spend(records, query="", period="all", **_):
    """Tool: total of receipt line items whose name contains ``query``.

    Case-insensitive substring match on each line item's ``name``; an empty
    query matches nothing. Amounts come from each item's ``amount`` field.
    """
    in_period = _period_records(records, period)
    needle = (query or "").strip().lower()
    hits = [
        item
        for rec in in_period
        for item in (rec.get("line_items") or [])
        if needle and needle in str(item.get("name", "")).lower()
    ]
    spent = sum((analytics._num(item.get("amount")) for item in hits), 0.0)
    return f"'{query}' ({_plabel(period)}): {_cur(records)}{spent:,.2f} across {len(hits)} item(s)."
def _t_vendor_spend(records, vendor="", period="all", **_):
    """Tool: amount spent at vendors whose name contains ``vendor`` (case-insensitive)."""
    needle = (vendor or "").strip().lower()
    in_period = _period_records(records, period)
    matched = [
        rec for rec in in_period
        if needle and needle in str(rec.get("vendor", "")).lower()
    ]
    spent = sum(analytics._num(rec.get("total", 0)) for rec in matched)
    return f"Spend at '{vendor}' ({_plabel(period)}): {_cur(records)}{spent:,.2f} over {len(matched)} transaction(s)."
def _t_top_categories(records, period="this_month", **_):
    """Tool: the first five rows of ``analytics.spend_by_category`` for ``period``."""
    ranking = analytics.spend_by_category(_period_records(records, period))[:5]
    if not ranking:
        return f"No spending found ({_plabel(period)})."
    prefix = f"Top categories ({_plabel(period)}): "
    parts = [f"{row['category']} {_cur(records)}{row['amount']:,.2f}" for row in ranking]
    return prefix + ", ".join(parts)
# Labels for biggest_expense ranks 1-5; other ranks fall back to "#N biggest".
_ORD = dict(enumerate(("Biggest", "2nd biggest", "3rd biggest", "4th biggest", "5th biggest"), start=1))
def _fmt_txn(r, cur):
    """One-line transaction summary: amount, vendor, date and category."""
    amount = analytics._num(r.get("total"))
    vendor = r.get("vendor") or "(no name)"
    when = r.get("date") or "?"
    return f"{cur}{amount:,.2f} at {vendor} on {when} ({analytics._category(r)})"
def _t_biggest_expense(records, period="all", rank=1, **_):
    """Tool: the ``rank``-th most expensive transaction in ``period`` (1 = biggest).

    ``rank`` is coerced with ``int``; unparseable values mean 1 and values
    below 1 are clamped to 1.
    """
    ranked = sorted(
        _period_records(records, period),
        key=lambda rec: analytics._num(rec.get("total")),
        reverse=True,
    )
    if not ranked:
        return f"No transactions ({_plabel(period)})."
    try:
        position = max(1, int(rank))
    except Exception:
        position = 1
    if position > len(ranked):
        return f"There are only {len(ranked)} transaction(s) ({_plabel(period)}), no #{position}."
    title = _ORD.get(position, f"#{position} biggest")
    chosen = ranked[position - 1]
    return f"{title} expense ({_plabel(period)}): {_fmt_txn(chosen, _cur(records))}."
def _t_smallest_expense(records, period="all", **_):
    """Tool: the cheapest transaction in ``period``, ignoring zero/negative totals."""
    def amount_of(rec):
        return analytics._num(rec.get("total"))

    positive = [rec for rec in _period_records(records, period) if amount_of(rec) > 0]
    if not positive:
        return f"No transactions ({_plabel(period)})."
    cheapest = min(positive, key=amount_of)
    return f"Smallest expense ({_plabel(period)}): {_fmt_txn(cheapest, _cur(records))}."
def _t_average_monthly(records, **_):
    """Tool: mean of the monthly totals from ``analytics.spend_over_time``.

    Only months present in that series count toward the average.
    """
    months = analytics.spend_over_time(records, "Monthly")
    if not months:
        return "No dated transactions yet."
    mean = sum(row["amount"] for row in months) / len(months)
    return f"Average monthly spend: {_cur(records)}{mean:,.2f} across {len(months)} month(s) with spending."
def _t_average_spend(records, period="all", **_):
    """Tool: average amount per transaction in ``period`` (0 when there are none)."""
    in_period = _period_records(records, period)
    count = len(in_period)
    grand = sum(analytics._num(rec.get("total", 0)) for rec in in_period)
    mean = grand / count if count else 0.0
    prefix = _cur(records)
    return (f"Average transaction ({_plabel(period)}): {prefix}{mean:,.2f} "
            f"across {count} transaction(s) totalling {_cur(records)}{grand:,.2f}.")
def _t_count_transactions(records, period="all", **_):
    """Tool: number of transactions in ``period`` together with their total."""
    in_period = _period_records(records, period)
    grand = sum(analytics._num(rec.get("total", 0)) for rec in in_period)
    return f"{len(in_period)} transaction(s) ({_plabel(period)}), totalling {_cur(records)}{grand:,.2f}."
def _t_budget_status(records, budget=0.0, **_):
    """Tool: compare this month's spend with the monthly ``budget``.

    A budget of 0 or less is reported as "not set".
    """
    this_month = _period_records(records, "this_month")
    spent = sum(analytics._num(rec.get("total", 0)) for rec in this_month)
    cur = _cur(records)
    limit = analytics._num(budget)
    if limit <= 0:
        return f"No monthly budget set yet. This month you've spent {cur}{spent:,.2f}."
    head = f"Budget this month: {cur}{limit:,.2f}. Spent {cur}{spent:,.2f}."
    remaining = limit - spent
    if remaining < 0:
        return f"{head} You are {cur}{abs(remaining):,.2f} over budget."
    return f"{head} You have {cur}{remaining:,.2f} left."
def _t_monthly_trend(records, **_):
    """Tool: the last six monthly totals from ``analytics.spend_over_time``."""
    months = analytics.spend_over_time(records, "Monthly")
    if not months:
        return "No dated transactions yet."
    parts = [f"{row['period']}={row['amount']:,.2f}" for row in months[-6:]]
    return "Monthly totals: " + ", ".join(parts)
def _t_recent(records, n=5, **_):
    """Tool: the first ``n`` rows (clamped to 1-12, default 5) of ``analytics.transactions_table``.

    Rows are unpacked as 4-tuples formatted as "<d> <v> <t> (<c>)" — presumably
    (date, vendor, total, category), newest first; confirm against analytics.
    """
    try:
        count = int(n)
    except Exception:
        count = 5
    limit = max(1, min(count, 12))
    table = analytics.transactions_table(records)[:limit]
    if not table:
        return "No transactions yet."
    entries = [f"{d} {v} {t:,.2f} ({c})" for d, v, t, c in table]
    return "Recent: " + "; ".join(entries)
# Tool registry used by the ReAct loop: name the model emits -> implementation.
TOOLS = dict(
    total_spend=_t_total_spend,
    category_spend=_t_category_spend,
    item_spend=_t_item_spend,
    vendor_spend=_t_vendor_spend,
    top_categories=_t_top_categories,
    biggest_expense=_t_biggest_expense,
    smallest_expense=_t_smallest_expense,
    average_spend=_t_average_spend,
    average_monthly=_t_average_monthly,
    count_transactions=_t_count_transactions,
    budget_status=_t_budget_status,
    monthly_trend=_t_monthly_trend,
    recent=_t_recent,
)
# Ordinal words/suffixes -> rank, used by the router for "2nd biggest"-style questions.
_ORDINALS = {
    token: position
    for position, tokens in enumerate(
        (("first", "1st"), ("second", "2nd"), ("third", "3rd"), ("fourth", "4th"), ("fifth", "5th")),
        start=1,
    )
    for token in tokens
}
# System prompt for the ReAct model in _react. It defines the period vocabulary,
# lists every tool registered in TOOLS, and fixes the reply protocol that
# _parse_action parses: "CALL <tool>(<json args>)" or "ANSWER: <text>".
SYSTEM_PROMPT = (
"You are BudgetBuddy's spending agent. Answer the user's question about THEIR "
"spending by planning and calling tools one at a time, then giving a final "
"answer with the exact numbers from the tools.\n"
"period can be: 'this_month', 'last_month', 'this_year', 'last_year', 'all', "
"a SPECIFIC month like '2026-07' (for 'July 2026'), or a year like '2026'.\n"
"Tools:\n"
"- total_spend({\"period\": ...})\n"
"- category_spend({\"category\": <name>, \"period\": ...})\n"
"- item_spend({\"query\": <item name, e.g. 'misal pav'>, \"period\": ...})\n"
"- vendor_spend({\"vendor\": <shop/payee name>, \"period\": ...})\n"
"- top_categories({\"period\": ...})\n"
"- biggest_expense({\"period\": ..., \"rank\": 1}) // rank 2 = 2nd most expensive\n"
"- smallest_expense({\"period\": ...})\n"
"- average_spend({\"period\": ...}) // average per transaction\n"
"- average_monthly({}) // average spend per month\n"
"- count_transactions({\"period\": ...})\n"
"- budget_status({})\n"
"- monthly_trend({})\n"
"- recent({\"n\": 5})\n"
"To call a tool, reply with EXACTLY one line:\n"
"CALL <tool>(<json args>)\n"
"Example: CALL category_spend({\"category\": \"Groceries\", \"period\": \"this_month\"})\n"
"After you get a line starting with TOOL_RESULT, either call another tool or "
"finish with:\nANSWER: <concise answer using the numbers>\n"
"Always finish with an ANSWER line. Do not invent numbers."
)
# Canned reply returned by _greeting_reply for greetings and "what can you do"
# questions; lists example questions the agent handles.
HELP_TEXT = (
"I'm your spending assistant — I read your saved transactions and answer with real numbers. "
"Try things like:\n"
"• \"How much did I spend this month?\" or \"…in July 2026?\"\n"
"• \"What's my top category?\" / \"Where did my money go?\"\n"
"• \"How much on Groceries last month?\"\n"
"• \"What's my biggest expense?\" • \"What's my average spend?\"\n"
"• \"How many transactions this year?\" • \"Am I over budget?\""
)
# --------------------------------------------------------------------------- #
# Deterministic router (fast, reliable, still emits a tool trace)
# --------------------------------------------------------------------------- #
def _parse_period(ql: str):
    """Resolve a time phrase in a lower-cased question to a canonical period.

    Returns "all", "last_month", "this_month", "last_year", "this_year", a
    month "YYYY-MM", a year "YYYY", or None when no time phrase is found.
    A named month without a year means its most recent occurrence (this year
    if it has started, otherwise last year).
    """
    if re.search(r"\b(all[\s-]?time|overall|in total|altogether|ever|so far|to date|lifetime)\b", ql):
        return "all"
    if re.search(r"\b(last|previous|prev)\s+month\b", ql):
        return "last_month"
    if re.search(r"\bthis\s+month\b|\bcurrent\s+month\b|\bthis\s+mth\b", ql):
        return "this_month"
    if re.search(r"\b(last|previous|prev)\s+year\b", ql):
        return "last_year"
    if re.search(r"\bthis\s+year\b|\bcurrent\s+year\b", ql):
        return "this_year"
    ym = re.search(r"\b(20\d{2})-(\d{1,2})\b", ql)
    if ym:
        return f"{ym.group(1)}-{int(ym.group(2)):02d}"
    # "may" followed by i/we/you/be/have/not is the modal verb ("may I see..."),
    # not the month of May.
    mon = re.search(r"\b(jan(?:uary)?|feb(?:ruary)?|mar(?:ch)?|apr(?:il)?|"
                    r"may(?!\s+(?:i|we|you|be|have|not)\b)|jun(?:e)?|"
                    r"jul(?:y)?|aug(?:ust)?|sep(?:t)?(?:ember)?|oct(?:ober)?|nov(?:ember)?|dec(?:ember)?)\b", ql)
    if mon:
        mo = _MONTHS.get(mon.group(1)) or _MONTHS.get(mon.group(1)[:3])
        if mo:
            yr = re.search(r"\b(20\d{2})\b", ql)
            if yr:
                year = int(yr.group(1))
            else:
                today = date.today()
                # "December" asked in March means last December, not a future month.
                year = today.year - 1 if mo > today.month else today.year
            return f"{year}-{mo:02d}"
    yr = re.search(r"\b(20\d{2})\b", ql)
    if yr:
        return yr.group(1)
    return None
def _detect_category(ql: str):
    """Return the category named in ``ql``, or None.

    A whole-word mention of a known category (from core.categorize.CATEGORIES)
    wins; otherwise the first whole-word synonym from _CAT_SYNONYMS is mapped
    to its canonical category.
    """
    from core.categorize import CATEGORIES

    def _mentions(term):
        return re.search(r"\b" + re.escape(term) + r"\b", ql) is not None

    direct = next((cat for cat in CATEGORIES if _mentions(cat.lower())), None)
    if direct is not None:
        return direct
    return next((cat for syn, cat in _CAT_SYNONYMS.items() if _mentions(syn)), None)
def _detect_vendor(ql: str, records):
    """Return a known vendor named in ``ql`` (longest name first), else None.

    Vendor names come from each record's ``vendor`` field; missing/None values
    and names shorter than 3 characters are ignored. Matching compares against
    the lower-cased vendor name, so ``ql`` is expected to be lower-cased (as
    ``_fast_answer`` passes it).
    """
    seen: dict[str, str] = {}
    for r in records:
        # ``or ""`` so a None vendor is skipped rather than becoming "None".
        v = str(r.get("vendor") or "").strip()
        if len(v) >= 3:
            seen[v.lower()] = v
    for low in sorted(seen, key=len, reverse=True):
        # Lookarounds instead of \b so names that start or end with punctuation
        # (e.g. "Dr. Lal.") still match at a word edge.
        if re.search(r"(?<!\w)" + re.escape(low) + r"(?!\w)", ql):
            return seen[low]
    return None
def _greeting_reply(question: str):
    """Return HELP_TEXT for greetings, capability questions or 'help', else None.

    Answered without any model call.
    """
    text = question.lower().strip()
    padded = f" {text} "
    is_greeting = re.fullmatch(
        r"(hi|hello|hey|yo|hiya|howdy|sup|namaste|hi there|hey there)[\s!.,]*", text) is not None
    asks_capabilities = re.search(
        r"\bwhat can you (do|help)\b|\bhow do you work\b|\bwho are you\b|\bwhat do you do\b",
        padded) is not None
    wants_menu = text in ("help", "?", "??", "menu", "options")
    return HELP_TEXT if (is_greeting or asks_capabilities or wants_menu) else None
def _fast_answer(question: str, records, budget):
    """Deterministic fallback router: map a clear single-intent question to one tool.

    ``run`` uses it only when the LLM agent can't produce a grounded answer.
    Rules are tried in priority order and the first match runs its tool; an
    unspecified period defaults to all time.

    Returns:
        ``(observation, trace)`` where ``trace`` is a one-element list of
        ``{"tool", "args", "result"}``, or None when no rule matches.
    """
    ql = " " + question.lower().strip() + " "
    period = _parse_period(ql)
    P = period or "all"

    def fire(tool, args, fn):
        # Run the tool and wrap its observation in a single-step trace.
        obs = fn()
        return obs, [{"tool": tool, "args": args, "result": obs}]

    if re.search(r"\bbudget\b|\bover\s?spend|under\s?spend|can i (afford|spend)\b", ql):
        return fire("budget_status", {}, lambda: _t_budget_status(records, budget=budget))
    if re.search(r"\bhow many\b.*\b(transaction|bill|receipt|purchase|entr|payment)|number of (transaction|bill|receipt|purchase)", ql):
        return fire("count_transactions", {"period": P}, lambda: _t_count_transactions(records, period=P))
    if re.search(r"\b(smallest|cheapest|least expensive|lowest|min(?:imum)?)\b", ql) and "categor" not in ql:
        return fire("smallest_expense", {"period": P}, lambda: _t_smallest_expense(records, period=P))
    if re.search(r"\b(biggest|largest|highest|most expensive|priciest|max(?:imum)?|single largest|top spend|dearest)\b", ql) and "categor" not in ql:
        rank = 1
        om = re.search(r"\b(first|second|third|fourth|fifth|\d+(?:st|nd|rd|th))\b", ql)
        if om:
            g = om.group(1)
            rank = _ORDINALS.get(g) or int(re.sub(r"\D", "", g) or 1)
        return fire("biggest_expense", {"period": P, "rank": rank},
                    lambda: _t_biggest_expense(records, period=P, rank=rank))
    if re.search(r"\b(average|avg|mean|typical|on average)\b", ql):
        # Only per-month phrasing means "average per month". A bare "month" is
        # usually the period ("average spend this month"), which average_spend honours.
        if re.search(r"\b(monthly|per month|a month|each month|every month)\b|\b(?:average|avg)\s+month\b", ql):
            return fire("average_monthly", {}, lambda: _t_average_monthly(records))
        return fire("average_spend", {"period": P}, lambda: _t_average_spend(records, period=P))
    if re.search(r"\b(trend|over time|by month|each month|month by month|monthly|per month|month[\s-]?wise)\b", ql):
        return fire("monthly_trend", {}, lambda: _t_monthly_trend(records))
    # Generic "show me"/"list" only counts as a recent-transactions request when it
    # asks for transactions/purchases; otherwise e.g. "show me my grocery spending"
    # would be misrouted here instead of to category_spend.
    if re.search(r"\b(recent|latest|last few).*(transaction|spend|purchase|bought|buy)"
                 r"|\b(show me|list)\b.*(transaction|purchase|receipt|bought)"
                 r"|what did i (buy|spend on) (recently|lately)", ql):
        return fire("recent", {"n": 5}, lambda: _t_recent(records, n=5))
    if re.search(r"\b(top categor|biggest categor|main categor|where.*(money|spend|spent).*(go|went)|"
                 r"spend most|most on|breakdown|by category|categor(y|ies))\b", ql):
        return fire("top_categories", {"period": P}, lambda: _t_top_categories(records, period=P))
    cat = _detect_category(ql)
    if cat and re.search(r"\b(spend|spent|spending|cost|how much|total|pay|paid)\b", ql):
        return fire("category_spend", {"category": cat, "period": P},
                    lambda: _t_category_spend(records, category=cat, period=P))
    ven = _detect_vendor(ql, records)
    if ven and re.search(r"\b(spend|spent|how much|total|pay|paid|at|from)\b", ql):
        return fire("vendor_spend", {"vendor": ven, "period": P},
                    lambda: _t_vendor_spend(records, vendor=ven, period=P))
    if re.search(r"\b(how much|total|spend|spent|spending)\b", ql):
        return fire("total_spend", {"period": P}, lambda: _t_total_spend(records, period=P))
    return None
# --------------------------------------------------------------------------- #
# Parsing the model's action (ReAct loop)
# --------------------------------------------------------------------------- #
# "CALL <tool>(<args>)": the greedy (.*) with re.S captures up to the LAST ")"
# in the model turn, so text after the call can end up in group(2).
_CALL_RE = re.compile(r"CALL\s+(\w+)\s*\((.*)\)", re.S)
# Fallback for a bare "<tool>(...)" without the CALL keyword; only registered
# tool names match, and the non-greedy args stop at the first ")".
_BARE_RE = re.compile(r"\b(" + "|".join(TOOLS) + r")\b\s*\((.*?)\)", re.S)
# Period keywords that _loose_args accepts as a positional argument.
_PERIODS = ("this_month", "last_month", "this_year", "last_year", "all", "year")
def _loose_args(raw: str) -> dict:
    """Best-effort parse of a tool call's argument text into a kwargs dict.

    A JSON object is returned as-is. JSON null/bool/float/list yields {}.
    Anything else (including a bare JSON string like '"last_month"' or an
    integer like '3') is split on commas, and each token is handled as follows:
      * a period keyword, 'YYYY-MM' or a 19xx/20xx year -> ``period``;
      * another bare integer -> ``n``;
      * ``key: value`` / ``key=value`` -> that key (value kept as a string);
      * the first other bare word -> ``category``.
    """
    raw = raw.strip().strip("()")
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (ValueError, RecursionError):
        parsed = raw  # not JSON at all: use the token parser below
    if isinstance(parsed, dict):
        return parsed
    if isinstance(parsed, bool) or not isinstance(parsed, (str, int)):
        return {}
    # JSON strings/ints fall through here, so a quoted "last_month" or a bare 3
    # become real args instead of being discarded as non-dict JSON.
    args: dict[str, Any] = {}
    for tok in raw.split(","):
        tok = tok.strip().strip('"\'')
        if not tok:
            continue
        low = tok.lower().replace(" ", "_")
        if low in _PERIODS or re.match(r"^\d{4}-\d{1,2}$", low):
            args["period"] = low
        elif re.match(r"^(19|20)\d{2}$", tok):  # a 4-digit year is a period, not a count
            args["period"] = tok
        elif tok.isdigit():
            args["n"] = int(tok)
        elif ":" in tok or "=" in tok:
            parts = re.split(r"[:=]", tok, maxsplit=1)
            if len(parts) == 2:
                args[parts[0].strip().strip('"\'')] = parts[1].strip().strip('"\'')
        else:
            args.setdefault("category", tok)
    return args
def _parse_action(text: str) -> dict:
    """Classify one model turn as an answer, a tool call, or unknown.

    Returns:
        {"type": "answer", "text": ...} when "ANSWER:" appears (it wins over
        any CALL in the same turn); {"type": "call", "name": ..., "args": {...}}
        for a call to a registered tool; otherwise {"type": "unknown", "text": ...}.
    """
    t = (text or "").strip()
    m = re.search(r"ANSWER:\s*(.+)", t, re.S)
    if m:
        return {"type": "answer", "text": m.group(1).strip()}
    m = _CALL_RE.search(t) or _BARE_RE.search(t)
    if m and m.group(1) in TOOLS:
        raw = m.group(2).strip()
        if raw.startswith("{"):
            try:
                # raw_decode reads just the leading JSON object, ignoring any text the
                # greedy _CALL_RE captured after it (e.g. '{...}) then I (will) answer').
                args, _end = json.JSONDecoder().raw_decode(raw)
            except Exception:
                args = _loose_args(raw)
        else:
            args = _loose_args(raw)
        return {"type": "call", "name": m.group(1),
                "args": args if isinstance(args, dict) else {}}
    return {"type": "unknown", "text": t}
def _run_tool(name: str, records, args: dict, budget: float = 0.0) -> str:
    """Invoke tool ``name`` from TOOLS with ``args``; never raises.

    ``budget`` is injected for budget_status unless the args already carry one.
    Any failure (unknown tool, bad arguments, tool exception) comes back as a
    "(tool error: ...)" observation string.
    """
    try:
        kwargs = dict(args or {})
        if name == "budget_status" and "budget" not in kwargs:
            kwargs["budget"] = budget
        tool = TOOLS[name]
        return tool(records, **kwargs)
    except Exception as exc:  # pragma: no cover
        return f"(tool error: {exc})"
# --------------------------------------------------------------------------- #
# Agent entry point
# --------------------------------------------------------------------------- #
def _react(question: str, records, budget: float, max_steps: int):
    """Run the model-driven ReAct loop; return {reply, trace} only when grounded.

    The 8B model plans, calls a tool, reads the result, optionally calls
    another, then answers, for at most ``max_steps`` model turns. A reply is
    returned only when at least one tool call produced a real result; a
    "(tool error: ...)" observation from ``_run_tool`` is recorded in the trace
    but does not count as grounding. Otherwise None is returned so the caller
    can fall back (we never surface an ungrounded, possibly hallucinated answer).
    """
    tool_error_prefix = "(tool error:"  # format produced by _run_tool on failure
    transcript = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Today is {date.today().isoformat()}.\nQuestion: {question}\n"
         "Plan and call the first tool (use a specific 'YYYY-MM' period for a named month)."},
    ]
    trace: list[dict] = []
    last_good_obs = None  # most recent successful tool observation
    for _ in range(max_steps):
        out = inference.text_generate(transcript, max_new_tokens=AGENT_TOKENS)
        action = _parse_action(out)
        if action["type"] == "answer":
            # Only trust an answer reached AFTER a tool actually succeeded.
            grounded = last_good_obs is not None and action["text"]
            return {"reply": action["text"], "trace": trace} if grounded else None
        if action["type"] != "call":
            break
        obs = _run_tool(action["name"], records, action["args"], budget)
        trace.append({"tool": action["name"], "args": action["args"], "result": obs})
        if isinstance(obs, str) and not obs.startswith(tool_error_prefix):
            last_good_obs = obs
        transcript.append({"role": "assistant", "content": out})
        transcript.append({
            "role": "user",
            "content": f"TOOL_RESULT: {obs}\nNow either CALL another tool or give "
                       "ANSWER: <one sentence using the exact numbers above>.",
        })
    # No clean ANSWER line: fall back to the last successful tool result
    # (still grounded, still exact) with the trace; tool errors are never surfaced.
    return {"reply": last_good_obs, "trace": trace} if last_good_obs is not None else None
def run(question: str, records: list[dict], max_steps: int = MAX_STEPS,
        budget: float = 0.0) -> dict:
    """Answer a spending question; returns {reply, trace:[{tool,args,result}]}.

    Resolution order: empty question / no records -> canned reply; greeting ->
    HELP_TEXT; then the 8B-planned ReAct agent (primary); then the
    deterministic tool router (fallback when the model can't ground an
    answer); finally chat.answer as the last resort. Failures in the agent or
    router are printed and skipped rather than raised.
    """
    question = (question or "").strip()
    if not question:
        return {"reply": "Ask me anything about your spending.", "trace": []}
    if not records:
        return {"reply": "You have no saved transactions yet — add a few and ask again.",
                "trace": []}

    greeting = _greeting_reply(question)
    if greeting:
        return {"reply": greeting, "trace": []}

    # 1) Primary: the model plans and calls the tools.
    try:
        agentic = _react(question, records, budget, max_steps)
    except Exception as exc:  # pragma: no cover - model/runtime dependent
        print(f"[agent] react failed: {exc}")
    else:
        if agentic is not None:
            return agentic

    # 2) Fallback: deterministic router.
    try:
        routed = _fast_answer(question, records, budget)
    except Exception as exc:  # pragma: no cover
        print(f"[agent] fallback router failed: {exc}")
    else:
        if routed is not None:
            return {"reply": routed[0], "trace": routed[1]}

    # 3) Last resort: grounded one-shot answer.
    return {"reply": chat.answer(question, records), "trace": []}