"""Gradio chat UI for Cashy, the AI financial advisor.

Builds the chat panel, the reference/settings sidebar (LLM provider, API key,
model name) and, in personal mode, a chat-history picker whose thread ids come
from the ``checkpoints`` table.
"""

import json
import uuid
import time
import logging

import gradio as gr
from langchain_core.messages import HumanMessage
from langgraph.types import Command

from src.config import settings
from src.agent.nodes import reset_model
from src.db.connection import get_connection

logger = logging.getLogger("cashy.ui")

# Chat titles longer than this are cut and suffixed with "..." in the dropdown.
TITLE_MAX_LEN = 50

# Replies that approve a pending write operation. They are compared after
# stripping whitespace and trailing "."/"!" and lowercasing; any other reply
# rejects the operation.
APPROVAL_REPLIES = ("yes", "approve", "confirm", "y")


def _message_text(content):
    """Return the plain text of a LangChain message ``content`` value.

    ``content`` is usually a string, but some providers return a list of
    content blocks (plain strings or dicts such as
    ``{"type": "text", "text": ...}``). Only the text parts are kept, so the
    result can always be sliced, concatenated and rendered by Gradio.
    """
    if content is None:
        return ""
    if isinstance(content, str):
        return content
    if isinstance(content, list):
        parts = []
        for block in content:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text") or "")
        return "".join(parts)
    return str(content)


def _truncate_title(text, max_len=TITLE_MAX_LEN):
    """Cut ``text`` to ``max_len`` characters, appending "..." when shortened."""
    return text[:max_len] + "..." if len(text) > max_len else text


def _chart_path(msg):
    """Return ``chart_path`` from a tool message's JSON payload, else ``None``.

    Non-tool messages, non-JSON content and payloads without a chart all
    yield ``None``.
    """
    if getattr(msg, "type", None) != "tool":
        return None
    try:
        data = json.loads(msg.content)
    except (json.JSONDecodeError, TypeError):
        return None
    if isinstance(data, dict) and data.get("chart_path") is not None:
        return data["chart_path"]
    return None


def _is_approval(reply):
    """Return True if the user's reply approves a pending write operation."""
    normalized = reply.strip().rstrip(".!").strip().lower()
    return normalized in APPROVAL_REPLIES


def list_threads():
    """Get all thread_ids from the checkpoints table, most recent first.

    Ordering uses ``MAX(checkpoint_id)`` per thread. This assumes checkpoint
    ids sort chronologically; TODO confirm against the checkpointer in use.

    Returns:
        A list of thread ids, or ``[]`` if the query fails (logged as a warning).
    """
    try:
        with get_connection() as conn:
            with conn.cursor() as cur:
                cur.execute(
                    """
                    SELECT thread_id, MAX(checkpoint_id) AS latest
                    FROM checkpoints
                    GROUP BY thread_id
                    ORDER BY latest DESC
                    """
                )
                return [row[0] for row in cur.fetchall()]
    except Exception as e:
        # Best effort: an unreachable database just means an empty history list.
        logger.warning("Could not list threads: %s", e)
        return []


def load_thread_history(agent, thread_id):
    """Load messages from a thread and convert to Gradio chatbot format.

    Human messages become ``user`` entries. AI messages with text become
    ``assistant`` entries; pure tool-call turns are skipped. Tool messages
    carrying a ``chart_path`` become assistant image entries
    (``{"path": ...}``).
    """
    config = {"configurable": {"thread_id": thread_id}}
    state = agent.get_state(config)

    history = []
    for msg in state.values.get("messages", []):
        if msg.type == "human":
            history.append({"role": "user", "content": _message_text(msg.content)})
        elif msg.type == "ai":
            text = _message_text(msg.content)
            if text:
                history.append({"role": "assistant", "content": text})
        else:
            path = _chart_path(msg)
            if path is not None:
                history.append({"role": "assistant", "content": {"path": path}})
    return history


def get_thread_title(agent, thread_id):
    """Extract first user message as thread title (truncated to 50 chars).

    Falls back to the first 12 characters of ``thread_id`` when the state
    cannot be loaded or the thread has no user message.
    """
    config = {"configurable": {"thread_id": thread_id}}
    try:
        state = agent.get_state(config)
        for msg in state.values.get("messages", []):
            if msg.type == "human":
                return _truncate_title(_message_text(msg.content))
    except Exception as e:
        logger.debug("Could not load title for thread %s: %s", thread_id, e)
    return thread_id[:12]


def get_thread_choices(agent, min_user_messages=2):
    """Build dropdown choices as (title, thread_id) tuples.

    Only includes threads with at least ``min_user_messages`` user messages,
    filtering out orphan single-exchange threads. Threads whose state cannot
    be loaded are skipped.
    """
    choices = []
    for tid in list_threads():
        config = {"configurable": {"thread_id": tid}}
        try:
            state = agent.get_state(config)
            messages = state.values.get("messages", [])
            user_msgs = [m for m in messages if m.type == "human"]
            if len(user_msgs) < min_user_messages:
                continue
            if user_msgs:
                title = _truncate_title(_message_text(user_msgs[0].content))
            else:
                # Only reachable with min_user_messages=0.
                title = tid[:12]
            choices.append((title, tid))
        except Exception as e:
            logger.debug("Skipping thread %s: %s", tid, e)
            continue
    return choices


def _format_cell(key, value):
    """Render one confirmation-table value as a markdown-safe cell.

    ``amount`` values are shown as ``$1,234.50`` when numeric (numeric strings
    included). Anything else, including a non-numeric amount, is shown via
    ``str``. Pipes and newlines are neutralised so they cannot break the table
    row.
    """
    text = None
    if key == "amount":
        try:
            text = f"${value:,.2f}"
        except (TypeError, ValueError):
            try:
                text = f"${float(value):,.2f}"
            except (TypeError, ValueError):
                text = None
    if text is None:
        text = str(value)
    return text.replace("|", "\\|").replace("\n", " ")


def _append_table(lines, fields, value_header="Value", heading=None):
    """Append a two-column ``Field | value_header`` markdown table to ``lines``.

    Nothing is appended when ``fields`` is empty. ``heading`` (if given) is
    emitted first, followed by a blank line. The table is followed by an
    empty line.
    """
    if not fields:
        return
    if heading:
        lines.append(f"{heading}\n")
    lines.append(f"| Field | {value_header} |")
    # Separator width matches the header cell, e.g. "-------" for " Value ".
    lines.append(f"|-------|{'-' * (len(value_header) + 2)}|")
    for key, value in fields.items():
        display_key = str(key).replace("_", " ").title()
        lines.append(f"| {display_key} | {_format_cell(key, value)} |")
    lines.append("")


def format_confirmation(interrupt_data):
    """Format an interrupt payload as a user-friendly confirmation message.

    Reads ``action`` plus the optional ``details``, ``changes`` (new values
    for updates) and ``current`` (existing values for updates) dicts. Each
    non-empty dict is rendered as a markdown table, and the message ends by
    asking the user to reply yes/no.
    """
    action = interrupt_data.get("action", "unknown")

    # Build a readable action label
    action_labels = {
        "create_transaction": "Create Transaction",
        "update_transaction": "Update Transaction",
        "delete_transaction": "Delete Transaction",
    }
    label = action_labels.get(action, str(action).replace("_", " ").title())

    lines = [f"**Confirm: {label}**\n"]
    _append_table(lines, interrupt_data.get("details") or {})
    _append_table(
        lines,
        interrupt_data.get("changes") or {},
        value_header="New Value",
        heading="**Changes:**",
    )
    _append_table(
        lines,
        interrupt_data.get("current") or {},
        heading="**Current values:**",
    )
    lines.append("Reply **yes** to confirm or **no** to cancel.")
    return "\n".join(lines)


WELCOME_MESSAGE_DEMO = """\
Hi! I'm **Cashy**, your AI financial advisor.

I'm connected to a demo database with **4 months of financial data** for a US-based freelance web developer:

- **11 accounts** — Chase, PayPal, Stripe, Wise, Marcus, Fidelity, credit cards, and cash
- **233 transactions** — client invoices, business expenses, personal spending, transfers
- **20 budgets** — monthly spending limits across 35 categories

**Ready to go** with the free tier, or switch to your own LLM provider in the sidebar.

Ask me anything about your finances. Here are some ideas:

1. **"What accounts do I have?"** — See all accounts and balances
2. **"How much did I spend this month?"** — Spending breakdown by category
3. **"How much did I earn from clients in January?"** — Income tracking
4. **"Am I over budget on anything?"** — Budget vs. actual comparison
5. **"Show me my last 10 transactions"** — Recent transaction history
"""

WELCOME_MESSAGE_PERSONAL = """\
Hi! I'm **Cashy**, your AI financial advisor.

I'm connected to your personal financial database. Ask me anything about your accounts, transactions, spending, or budgets.

Here are some things I can help with:

1. **"What accounts do I have?"** — See all accounts and balances
2. **"How much did I spend this month?"** — Spending breakdown by category
3. **"Show me my last 10 transactions"** — Recent transaction history
4. **"Am I over budget on anything?"** — Budget vs. actual comparison
5. **"Show me a chart of my spending"** — Visual spending analysis
"""

PROVIDERS = ["free-tier", "openai", "anthropic", "google", "huggingface"]

DEFAULT_MODELS = {
    "free-tier": "Qwen/Qwen2.5-7B-Instruct",
    "openai": "gpt-5-mini",
    "anthropic": "claude-sonnet-4-20250514",
    "google": "gemini-2.5-flash",
    "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
}

HF_INFERENCE_PROVIDERS = [
    "cerebras",
    "cohere",
    "featherless-ai",
    "fireworks-ai",
    "groq",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "nscale",
    "ovhcloud",
    "sambanova",
    "scaleway",
    "together",
]


def create_ui(agent):
    """Create the Gradio chat UI with compact reference sidebar.

    Args:
        agent: The LangGraph agent. It is called with ``invoke`` (new user
            turns, and ``Command(resume=...)`` for confirmation replies) and
            with ``get_state`` (reading thread messages).

    Returns:
        ``(demo, theme)``: the ``gr.Blocks`` app and the theme chosen for the
        current app mode. The theme is not applied here; presumably the caller
        passes it to ``launch`` — confirm at the call site.
    """
    current_provider = settings.resolved_provider or "openai"
    has_provider = settings.resolved_provider is not None
    is_free = current_provider == "free-tier"
    is_demo = settings.app_mode == "demo"
    mode_label = "Demo" if is_demo else "Personal"
    welcome_text = WELCOME_MESSAGE_DEMO if is_demo else WELCOME_MESSAGE_PERSONAL

    FREE_TIER_DISCLAIMER = (
        "\n\n---\n*Free tier uses a lightweight open-source model. "
        "For better results, switch to OpenAI, Anthropic, or Google in the sidebar.*"
    )

    def render_result(result, message, history, thread_id, prev_count, start):
        """Append one finished agent turn to ``history``; return the chat outputs.

        Handles both outcomes of ``agent.invoke``: an interrupt (a write
        operation awaiting yes/no, so ``pending`` becomes True) or a normal
        reply plus any charts produced by tool messages added this turn.
        All formatting happens before ``history`` is touched, so a failure
        here never leaves a half-appended turn behind.
        """
        # --- Interrupt: a write operation needs confirmation ---
        if "__interrupt__" in result:
            interrupt_data = result["__interrupt__"][0].value
            confirmation_msg = format_confirmation(interrupt_data)
            elapsed = time.time() - start
            logger.info(
                "<<< Interrupt [%.1fs]: %s",
                elapsed,
                interrupt_data.get("action", "unknown"),
            )
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": confirmation_msg})
            return "", history, thread_id, True

        # --- Normal response ---
        response = _message_text(result["messages"][-1].content)
        elapsed = time.time() - start
        logger.info("<<< Response [%.1fs]: %s", elapsed, response[:120])

        # Scan only NEW messages for chart images (skip prior history).
        new_messages = result["messages"][prev_count:]
        chart_paths = [p for p in map(_chart_path, new_messages) if p is not None]

        if settings.resolved_provider == "free-tier":
            response += FREE_TIER_DISCLAIMER

        history.append({"role": "user", "content": message})
        history.append({"role": "assistant", "content": response})
        history.extend(
            {"role": "assistant", "content": {"path": path}} for path in chart_paths
        )
        return "", history, thread_id, False

    def respond(message, history, thread_id, pending):
        """Chat handler: send a user turn, or resume a pending confirmation.

        Returns ``(textbox_value, history, thread_id, pending)``.
        """
        config = {"configurable": {"thread_id": thread_id}}
        logger.info(">>> User [thread=%s]: %s", thread_id[:8], message)
        start = time.time()

        try:
            # Count existing messages so only this turn's messages are scanned
            # for charts.
            state = agent.get_state(config)
            prev_count = len(state.values.get("messages", []))

            if pending:
                # Resume from interrupt (user confirming/rejecting).
                approved = _is_approval(message)
                logger.info(
                    "Interrupt response: %s", "approved" if approved else "rejected"
                )
                result = agent.invoke(Command(resume={"approved": approved}), config)
            else:
                result = agent.invoke(
                    {"messages": [HumanMessage(content=message)]},
                    config,
                )

            # Shared rendering: a resumed run can hit another interrupt too.
            return render_result(result, message, history, thread_id, prev_count, start)

        except Exception as e:
            logger.exception("<<< Error: %s", e)
            error_str = str(e).lower()
            if "ssl" in error_str or ("connection" in error_str and "closed" in error_str):
                error_msg = (
                    "The database connection was lost (the cloud database likely went to sleep). "
                    "Please try again in a few seconds -- it should reconnect automatically. "
                    "If the issue persists, restart the Space from Settings."
                )
            else:
                error_msg = f"**Error:** {e}"
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": error_msg})
            # Keep the pending flag. If resuming a confirmation failed, the
            # thread is presumably still waiting for yes/no, so the user's retry
            # must again be treated as a confirmation reply.
            return "", history, thread_id, pending

    def switch_provider(provider):
        """Switch the active LLM provider and update the sidebar widgets.

        Returns updates for the status, API key, model name, HF provider and
        save-button widgets.
        """
        settings.llm_provider = provider
        settings.model_name = ""  # reset to default for new provider
        reset_model()
        model = DEFAULT_MODELS.get(provider, "default")
        is_free = provider == "free-tier"
        is_hf = provider == "huggingface"
        show_byok = not is_free  # free-tier hides API key, model, HF provider
        logger.info("Provider switched to: %s (%s)", provider, model)
        if is_free:
            status = f"Using **Free Tier** ({model}) -- no API key needed"
        else:
            status = f"Switched to **{provider.capitalize()}** ({model})"
        return (
            status,
            gr.update(visible=show_byok, value=""),
            gr.update(visible=show_byok, placeholder=f"Default: {model}", value=""),
            gr.update(visible=is_hf),
            gr.update(visible=show_byok),
        )

    def set_api_key(provider, api_key, model_name, hf_provider):
        """Store the API key (plus optional model / HF provider); return status."""
        key = api_key.strip()
        if not key:
            return "No key entered."
        key_fields = {
            "openai": "openai_api_key",
            "anthropic": "anthropic_api_key",
            "google": "google_api_key",
            "huggingface": "hf_token",
        }
        field = key_fields.get(provider)
        if not field:
            return f"Unknown provider: {provider}"
        setattr(settings, field, key)
        settings.llm_provider = provider
        if model_name.strip():
            settings.model_name = model_name.strip()
        if provider == "huggingface" and hf_provider:
            settings.hf_inference_provider = hf_provider
        reset_model()
        model = settings.model_name or DEFAULT_MODELS.get(provider, "default")
        logger.info("API key set for provider: %s (%s)", provider, model)
        return f"API key saved. Using **{provider.capitalize()}** ({model})."

    def set_model(provider, model_name):
        """Override the model name (empty = provider default); return status."""
        name = model_name.strip()
        settings.model_name = name
        reset_model()
        model = name or DEFAULT_MODELS.get(provider, "default")
        logger.info("Model changed to: %s", model)
        return f"Model set to **{model}**."

    def set_hf_provider(hf_provider):
        """Select the Hugging Face inference provider; return status."""
        settings.hf_inference_provider = hf_provider
        reset_model()
        logger.info("HF inference provider changed to: %s", hf_provider)
        return f"Inference provider set to **{hf_provider}**."

    welcome = [{"role": "assistant", "content": welcome_text}]

    if is_demo:
        theme = gr.themes.Glass(primary_hue="indigo")
    else:
        theme = gr.themes.Default()

    with gr.Blocks(title="Cashy - AI Financial Advisor") as demo:
        gr.Markdown("# Cashy — AI Financial Advisor")

        # Per-session state: the LangGraph thread id and whether the thread is
        # waiting for a yes/no confirmation.
        session_thread_id = gr.State(value=lambda: str(uuid.uuid4()))
        pending_interrupt = gr.State(value=False)

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    value=welcome,
                    height=600,
                    buttons=["copy"],
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask about your finances...",
                        show_label=False,
                        scale=9,
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Column(scale=1, min_width=250):
                new_chat_btn = gr.Button("+ New Chat", variant="secondary")

                if not is_demo:
                    with gr.Accordion("Chat History", open=False):
                        thread_dropdown = gr.Dropdown(
                            choices=[],
                            label="Previous chats",
                            interactive=True,
                        )
                        load_btn = gr.Button("Load Chat")

                gr.Markdown(f"**Mode:** {mode_label}")
                gr.Markdown("---")

                if is_demo:
                    gr.Markdown(
                        "**Demo Data** · Oct 2025 – Jan 2026 · USD\n\n"
                        "11 accounts · 233 transactions · 20 budgets"
                    )
                    gr.Markdown("---")

                gr.Markdown(
                    "**Capabilities**\n\n"
                    "- Check account balances\n"
                    "- Analyze spending by category\n"
                    "- Search transaction history\n"
                    "- Compare budgets vs. actual\n"
                    "- Create, update, delete transactions\n"
                    "- Run custom SQL queries"
                )
                gr.Markdown("---")

                if is_demo:
                    gr.Markdown(
                        "**Try asking**\n\n"
                        '*"What\'s the balance on Chase Business?"*\n\n'
                        '*"Am I over budget on anything?"*\n\n'
                        '*"I need a $1,500 laptop -- can I afford it?"*\n\n'
                        '*"Show me a pie chart of my spending"*\n\n'
                        '*"Chart my budget vs actual for January"*'
                    )
                else:
                    gr.Markdown(
                        "**Try asking**\n\n"
                        '*"What accounts do I have?"*\n\n'
                        '*"How much did I spend this month?"*\n\n'
                        '*"Show me a chart of my spending"*'
                    )
                gr.Markdown("---")

                provider_dropdown = gr.Dropdown(
                    choices=PROVIDERS,
                    value=current_provider,
                    label="LLM Provider",
                )
                with gr.Row():
                    api_key_input = gr.Textbox(
                        label="API Key",
                        placeholder="Paste your API key here...",
                        type="password",
                        scale=4,
                        visible=not is_free,
                    )
                    save_key_btn = gr.Button(
                        "Save", variant="primary", scale=1, visible=not is_free
                    )
                model_name_input = gr.Textbox(
                    label="Model Name (optional)",
                    placeholder=f"Default: {DEFAULT_MODELS.get(current_provider, '')}",
                    value="",
                    visible=not is_free,
                )
                hf_provider_dropdown = gr.Dropdown(
                    choices=HF_INFERENCE_PROVIDERS,
                    value=settings.hf_inference_provider,
                    label="Inference Provider",
                    visible=current_provider == "huggingface",
                )

                if is_free:
                    status_text = (
                        f"Using **Free Tier** ({DEFAULT_MODELS['free-tier']}) -- no API key needed"
                    )
                elif has_provider:
                    # Show a configured model override, matching set_api_key.
                    active_model = settings.model_name or DEFAULT_MODELS.get(
                        current_provider, "default"
                    )
                    status_text = (
                        f"Using **{current_provider.capitalize()}** ({active_model})"
                    )
                else:
                    status_text = (
                        "No API key configured -- select a provider and enter one above"
                    )
                provider_status = gr.Markdown(status_text)

        # --- Event handlers ---
        # Chat events (include pending_interrupt state)
        chat_inputs = [msg, chatbot, session_thread_id, pending_interrupt]
        chat_outputs = [msg, chatbot, session_thread_id, pending_interrupt]

        if is_demo:

            def new_chat_demo():
                """Start a fresh thread and reset the chat to the welcome message."""
                new_id = str(uuid.uuid4())
                logger.info("New chat started [thread=%s]", new_id[:8])
                # Copy so later in-place edits of the history never touch `welcome`.
                return new_id, list(welcome), "", False

            msg.submit(respond, chat_inputs, chat_outputs)
            submit_btn.click(respond, chat_inputs, chat_outputs)
            new_chat_btn.click(
                new_chat_demo, [], [session_thread_id, chatbot, msg, pending_interrupt]
            )
        else:

            def new_chat():
                """Start a fresh thread, reset the chat and refresh the history list."""
                new_id = str(uuid.uuid4())
                choices = get_thread_choices(agent)
                logger.info("New chat started [thread=%s]", new_id[:8])
                # Copy so later in-place edits of the history never touch `welcome`.
                return new_id, list(welcome), "", gr.update(choices=choices), False

            def load_thread(selected_thread_id):
                """Switch the session to a previously saved thread."""
                # NOTE(review): pending is always reset to False here, even if the
                # loaded thread is still waiting on a confirmation.
                if not selected_thread_id:
                    return gr.update(), gr.update(), False
                history = load_thread_history(agent, selected_thread_id)
                logger.info(
                    "Loaded thread %s (%d messages)", selected_thread_id[:8], len(history)
                )
                return selected_thread_id, history, False

            def refresh_threads():
                """Re-query saved threads for the history dropdown."""
                choices = get_thread_choices(agent)
                return gr.update(choices=choices)

            msg.submit(respond, chat_inputs, chat_outputs).then(
                refresh_threads, [], [thread_dropdown]
            )
            submit_btn.click(respond, chat_inputs, chat_outputs).then(
                refresh_threads, [], [thread_dropdown]
            )
            new_chat_btn.click(
                new_chat,
                [],
                [session_thread_id, chatbot, msg, thread_dropdown, pending_interrupt],
            )
            load_btn.click(
                load_thread,
                [thread_dropdown],
                [session_thread_id, chatbot, pending_interrupt],
            )

        provider_dropdown.change(
            switch_provider,
            [provider_dropdown],
            [provider_status, api_key_input, model_name_input, hf_provider_dropdown, save_key_btn],
        )
        api_key_inputs = [provider_dropdown, api_key_input, model_name_input, hf_provider_dropdown]
        api_key_input.submit(set_api_key, api_key_inputs, [provider_status])
        save_key_btn.click(set_api_key, api_key_inputs, [provider_status])
        model_name_input.submit(
            set_model, [provider_dropdown, model_name_input], [provider_status]
        )
        hf_provider_dropdown.change(set_hf_provider, [hf_provider_dropdown], [provider_status])

        # Populate thread list on page load (personal mode only)
        if not is_demo:
            demo.load(refresh_threads, [], [thread_dropdown])

    return demo, theme