Spaces:
Sleeping
Sleeping
| import json | |
| import uuid | |
| import time | |
| import logging | |
| import gradio as gr | |
| from langchain_core.messages import HumanMessage | |
| from langgraph.types import Command | |
| from src.config import settings | |
| from src.agent.nodes import reset_model | |
| from src.db.connection import get_connection | |
# Logger for this UI module. This module only emits records through it; no
# handler or level configuration appears here.
logger = logging.getLogger("cashy.ui")
def list_threads():
    """Return all thread_ids found in the ``checkpoints`` table.

    Threads are sorted by their largest ``checkpoint_id``, descending. The
    intent is "most recent first"; this relies on checkpoint ids sorting
    chronologically (NOTE(review): confirm for the checkpointer in use).

    Returns:
        A list of thread ids. The lookup is best-effort: on any database
        error a warning is logged and an empty list is returned.
    """
    sql = """
        SELECT thread_id, MAX(checkpoint_id) AS latest
        FROM checkpoints
        GROUP BY thread_id
        ORDER BY latest DESC
    """
    try:
        with get_connection() as conn, conn.cursor() as cur:
            cur.execute(sql)
            rows = cur.fetchall()
            # Only the thread_id column is needed; `latest` exists just for ordering.
            return [r[0] for r in rows]
    except Exception as e:
        logger.warning("Could not list threads: %s", e)
        return []
def _message_text(content):
    """Return the displayable text of a LangChain message ``content`` value.

    ``content`` is either a plain string or a list of content blocks. Each block
    is a string or a dict such as ``{"type": "text", "text": ...}``. Only text
    parts are kept, so tool-call blocks are never rendered as raw dicts.
    """
    if isinstance(content, str):
        return content
    parts = []
    for block in content or []:
        if isinstance(block, str):
            parts.append(block)
        elif isinstance(block, dict) and block.get("type") == "text":
            parts.append(block.get("text", ""))
    return "".join(parts)


def load_thread_history(agent, thread_id):
    """Load a thread's messages and convert them to Gradio chatbot format.

    Args:
        agent: object exposing ``get_state(config)``, i.e. the LangGraph agent.
        thread_id: id of the conversation thread to load.

    Returns:
        A list of ``{"role": ..., "content": ...}`` dicts:
        - human messages become ``user`` entries;
        - AI messages with text become ``assistant`` entries (AI turns with
          only tool calls are skipped);
        - tool messages whose JSON payload has a ``chart_path`` become
          ``assistant`` entries with ``{"path": chart_path}`` content.
    """
    config = {"configurable": {"thread_id": thread_id}}
    state = agent.get_state(config)
    history = []
    for msg in state.values.get("messages", []):
        if msg.type == "human":
            history.append({"role": "user", "content": _message_text(msg.content)})
        elif msg.type == "ai":
            text = _message_text(msg.content)
            if text:
                history.append({"role": "assistant", "content": text})
        elif msg.type == "tool":
            try:
                data = json.loads(msg.content)
            except (json.JSONDecodeError, TypeError):
                # Non-JSON tool output has nothing to render as a chart.
                continue
            if isinstance(data, dict) and "chart_path" in data:
                history.append({"role": "assistant", "content": {"path": data["chart_path"]}})
    return history
def get_thread_title(agent, thread_id):
    """Return a short display title for a thread.

    The title is the thread's first user message. Messages longer than
    50 characters are cut to 50 and get a trailing "...". If the state cannot
    be read or contains no user message, the first 12 characters of
    ``thread_id`` are returned instead.
    """
    try:
        snapshot = agent.get_state({"configurable": {"thread_id": thread_id}})
        first_human = next(
            (m for m in snapshot.values.get("messages", []) if m.type == "human"),
            None,
        )
        if first_human is not None:
            text = first_human.content
            if len(text) > 50:
                return text[:50] + "..."
            return text[:50]
    except Exception:
        # Best-effort: any failure falls back to the id-based title.
        pass
    return thread_id[:12]
def get_thread_choices(agent, min_user_messages=2):
    """Build dropdown choices as ``(title, thread_id)`` tuples.

    Only threads with at least ``min_user_messages`` user messages are
    included, which filters out orphan single-exchange threads. A thread
    always needs at least one user message, because its title is taken from
    the first one: that message, truncated to 50 characters plus "..." when
    longer.

    Args:
        agent: object exposing ``get_state(config)``, i.e. the LangGraph agent.
        min_user_messages: minimum number of human messages a thread needs.

    Returns:
        A list of ``(title, thread_id)`` tuples in the order produced by
        ``list_threads()``. Threads whose state cannot be read are skipped
        with a logged warning.
    """
    choices = []
    for tid in list_threads():
        config = {"configurable": {"thread_id": tid}}
        try:
            state = agent.get_state(config)
            user_msgs = [m for m in state.values.get("messages", []) if m.type == "human"]
            if not user_msgs or len(user_msgs) < min_user_messages:
                continue
            first = user_msgs[0].content
            title = first[:50] + "..." if len(first) > 50 else first[:50]
        except Exception as e:
            # Best-effort listing: one unreadable thread must not break the
            # dropdown, but the failure should still be visible in the logs.
            logger.warning("Skipping thread %s in history list: %s", tid[:8], e)
            continue
        choices.append((title, tid))
    return choices
def _format_amount(value):
    """Format a monetary amount as ``$1,234.50`` (or ``-$12.00`` when negative).

    Numeric strings are converted first. Values that cannot be formatted as
    a number are returned as ``str(value)`` instead of raising.
    """
    if isinstance(value, str):
        try:
            value = float(value)
        except ValueError:
            return value
    try:
        formatted = f"${abs(value):,.2f}"
        negative = value < 0
    except (TypeError, ValueError):
        return str(value)
    return f"-{formatted}" if negative else formatted


def _format_cell(key, value):
    """Render one table value; ``amount`` keys get currency formatting.

    Line breaks are collapsed and ``|`` is escaped so a value cannot break the
    Markdown table row it is placed in.
    """
    text = _format_amount(value) if key == "amount" else str(value)
    return " ".join(text.splitlines()).replace("|", "\\|")


def _append_table(lines, rows, value_header, heading=None):
    """Append an optional heading, a two-column Markdown table and a blank line."""
    if heading:
        lines.append(heading)
    lines.append(f"| Field | {value_header} |")
    # The separator dash count matches the header cell width, e.g. "|-------|-------|".
    lines.append("|-------|" + "-" * (len(value_header) + 2) + "|")
    for key, value in rows.items():
        label = key.replace("_", " ").title()
        lines.append(f"| {label} | {_format_cell(key, value)} |")
    lines.append("")


def format_confirmation(interrupt_data):
    """Format an interrupt payload as a user-friendly Markdown confirmation.

    Args:
        interrupt_data: dict with an optional ``action`` name and optional
            ``details``, ``changes`` and ``current`` mappings. Each non-empty
            mapping is rendered as its own table, in that order.

    Returns:
        Markdown text ending with the instruction to reply yes/no.
    """
    action = str(interrupt_data.get("action") or "unknown")
    action_labels = {
        "create_transaction": "Create Transaction",
        "update_transaction": "Update Transaction",
        "delete_transaction": "Delete Transaction",
    }
    label = action_labels.get(action, action.replace("_", " ").title())
    lines = [f"**Confirm: {label}**\n"]

    details = interrupt_data.get("details")
    if details:
        _append_table(lines, details, "Value")
    # Update operations carry the new values and the current ones.
    changes = interrupt_data.get("changes")
    if changes:
        _append_table(lines, changes, "New Value", heading="**Changes:**\n")
    current = interrupt_data.get("current")
    if current:
        _append_table(lines, current, "Value", heading="**Current values:**\n")

    lines.append("Reply **yes** to confirm or **no** to cancel.")
    return "\n".join(lines)
# Markdown shown as the first assistant message in demo mode. Blank lines
# separate paragraphs and lists. Without them, Markdown merges the paragraphs
# and folds "**Ready to go** ..." into the last bullet as a continuation line.
WELCOME_MESSAGE_DEMO = """\
Hi! I'm **Cashy**, your AI financial advisor.

I'm connected to a demo database with **4 months of financial data** for a US-based freelance web developer:

- **11 accounts** — Chase, PayPal, Stripe, Wise, Marcus, Fidelity, credit cards, and cash
- **233 transactions** — client invoices, business expenses, personal spending, transfers
- **20 budgets** — monthly spending limits across 35 categories

**Ready to go** with the free tier, or switch to your own LLM provider in the sidebar.

Ask me anything about your finances. Here are some ideas:

1. **"What accounts do I have?"** — See all accounts and balances
2. **"How much did I spend this month?"** — Spending breakdown by category
3. **"How much did I earn from clients in January?"** — Income tracking
4. **"Am I over budget on anything?"** — Budget vs. actual comparison
5. **"Show me my last 10 transactions"** — Recent transaction history
"""

# Markdown shown as the first assistant message in personal mode.
WELCOME_MESSAGE_PERSONAL = """\
Hi! I'm **Cashy**, your AI financial advisor.

I'm connected to your personal financial database. Ask me anything about your accounts, transactions, spending, or budgets.

Here are some things I can help with:

1. **"What accounts do I have?"** — See all accounts and balances
2. **"How much did I spend this month?"** — Spending breakdown by category
3. **"Show me my last 10 transactions"** — Recent transaction history
4. **"Am I over budget on anything?"** — Budget vs. actual comparison
5. **"Show me a chart of my spending"** — Visual spending analysis
"""

# Provider ids offered in the sidebar dropdown; each must have an entry in DEFAULT_MODELS.
PROVIDERS = ["free-tier", "openai", "anthropic", "google", "huggingface"]

# Model shown (and presumably used) when no custom model name is set for a provider.
DEFAULT_MODELS = {
    "free-tier": "Qwen/Qwen2.5-7B-Instruct",
    "openai": "gpt-5-mini",
    "anthropic": "claude-sonnet-4-20250514",
    "google": "gemini-2.5-flash",
    "huggingface": "meta-llama/Llama-3.3-70B-Instruct",
}

# Inference-provider choices shown only when the "huggingface" provider is selected.
HF_INFERENCE_PROVIDERS = [
    "cerebras",
    "cohere",
    "featherless-ai",
    "fireworks-ai",
    "groq",
    "hf-inference",
    "hyperbolic",
    "nebius",
    "novita",
    "nscale",
    "ovhcloud",
    "sambanova",
    "scaleway",
    "together",
]
def create_ui(agent):
    """Build the Cashy Gradio chat UI with a compact reference sidebar.

    Args:
        agent: the compiled LangGraph agent. It is driven through
            ``agent.invoke`` and ``agent.get_state`` with a
            ``{"configurable": {"thread_id": ...}}`` config.

    Returns:
        ``(demo, theme)``: the ``gr.Blocks`` app and the theme chosen for the
        current mode. The theme is not passed to ``gr.Blocks`` here;
        presumably the caller applies it at launch time (confirm against the
        caller).
    """
    current_provider = settings.resolved_provider or "openai"
    has_provider = settings.resolved_provider is not None
    is_free = current_provider == "free-tier"
    is_demo = settings.app_mode == "demo"
    mode_label = "Demo" if is_demo else "Personal"
    welcome_text = WELCOME_MESSAGE_DEMO if is_demo else WELCOME_MESSAGE_PERSONAL
    FREE_TIER_DISCLAIMER = (
        "\n\n---\n*Free tier uses a lightweight open-source model. "
        "For better results, switch to OpenAI, Anthropic, or Google in the sidebar.*"
    )

    def _text_of(content):
        """Return the text of a message ``content`` (a str or a list of content blocks).

        List items may be strings or dicts like ``{"type": "text", "text": ...}``.
        Non-text blocks such as tool calls are dropped.
        """
        if isinstance(content, str):
            return content
        parts = []
        for block in content or []:
            if isinstance(block, str):
                parts.append(block)
            elif isinstance(block, dict) and block.get("type") == "text":
                parts.append(block.get("text", ""))
        return "".join(parts)

    def _chart_paths(messages):
        """Collect ``chart_path`` values from tool messages whose content is a JSON object."""
        paths = []
        for m in messages:
            if getattr(m, "type", None) != "tool":
                continue
            try:
                data = json.loads(m.content)
            except (json.JSONDecodeError, TypeError):
                continue
            if isinstance(data, dict) and "chart_path" in data:
                paths.append(data["chart_path"])
        return paths

    def respond(message, history, thread_id, pending):
        """Run one chat turn against the agent.

        If ``pending`` is true, the message answers an outstanding confirmation.
        The graph is resumed with ``{"approved": ...}``. Otherwise the message
        is sent as a new user message.

        Returns:
            ``("", history, thread_id, pending)``: the cleared textbox, the
            updated history, the unchanged thread id, and whether a
            confirmation is now outstanding.
        """
        config = {"configurable": {"thread_id": thread_id}}
        logger.info(">>> User [thread=%s]: %s", thread_id[:8], message)
        start = time.time()
        try:
            # Count existing messages so we only scan new ones for charts
            prev_count = len(agent.get_state(config).values.get("messages", []))
            if pending:
                # --- Resume from interrupt (user confirming/rejecting) ---
                approved = message.strip().lower() in ("yes", "approve", "confirm", "y")
                logger.info("Interrupt response: %s", "approved" if approved else "rejected")
                result = agent.invoke(Command(resume={"approved": approved}), config)
            else:
                result = agent.invoke(
                    {"messages": [HumanMessage(content=message)]},
                    config,
                )

            # --- Interrupt: a write operation needs confirmation. This can also
            # happen right after a resume, when the agent moves on to another
            # write; the UI must then stay in the pending state.
            if "__interrupt__" in result:
                interrupt_data = result["__interrupt__"][0].value
                confirmation_msg = format_confirmation(interrupt_data)
                elapsed = time.time() - start
                logger.info("<<< Interrupt [%.1fs]: %s", elapsed, interrupt_data.get("action", "unknown"))
                history.append({"role": "user", "content": message})
                history.append({"role": "assistant", "content": confirmation_msg})
                return "", history, thread_id, True

            # --- Normal response (no interrupt) ---
            response = _text_of(result["messages"][-1].content)
            elapsed = time.time() - start
            logger.info("<<< Response [%.1fs]: %s", elapsed, response[:120])
            # Scan only NEW messages for chart images (skip prior history)
            chart_paths = _chart_paths(result["messages"][prev_count:])
            if settings.resolved_provider == "free-tier":
                response += FREE_TIER_DISCLAIMER
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": response})
            history.extend({"role": "assistant", "content": {"path": p}} for p in chart_paths)
            return "", history, thread_id, False
        except Exception as e:
            logger.exception("<<< Error: %s", e)
            error_str = str(e).lower()
            if "ssl" in error_str or ("connection" in error_str and "closed" in error_str):
                msg = (
                    "The database connection was lost (the cloud database likely went to sleep). "
                    "Please try again in a few seconds -- it should reconnect automatically. "
                    "If the issue persists, restart the Space from Settings."
                )
            else:
                msg = f"**Error:** {e}"
            history.append({"role": "user", "content": message})
            history.append({"role": "assistant", "content": msg})
            # NOTE(review): pending is cleared even if the failure happened while
            # resuming, so the graph may still be waiting on an interrupt.
            return "", history, thread_id, False

    def switch_provider(provider):
        """Switch the active LLM provider and update the BYOK sidebar widgets.

        Clears any custom model name and calls ``reset_model()``. Returns updates
        for (status, API key box, model name box, HF provider dropdown, save button).
        """
        settings.llm_provider = provider
        settings.model_name = ""  # reset to default for new provider
        reset_model()
        model = DEFAULT_MODELS.get(provider, "default")
        is_free = provider == "free-tier"
        is_hf = provider == "huggingface"
        show_byok = not is_free  # free-tier hides API key, model, HF provider
        logger.info("Provider switched to: %s (%s)", provider, model)
        if is_free:
            status = f"Using **Free Tier** ({model}) -- no API key needed"
        else:
            status = f"Switched to **{provider.capitalize()}** ({model})"
        return (
            status,
            gr.update(visible=show_byok, value=""),
            gr.update(visible=show_byok, placeholder=f"Default: {model}", value=""),
            gr.update(visible=is_hf),
            gr.update(visible=show_byok),
        )

    def set_api_key(provider, api_key, model_name, hf_provider):
        """Store the API key for ``provider`` on ``settings`` and make it active.

        Also applies the optional model name and, for huggingface, the inference
        provider. Returns a Markdown status line.
        """
        key = api_key.strip()
        if not key:
            return "No key entered."
        key_fields = {
            "openai": "openai_api_key",
            "anthropic": "anthropic_api_key",
            "google": "google_api_key",
            "huggingface": "hf_token",
        }
        field = key_fields.get(provider)
        if not field:
            return f"Unknown provider: {provider}"
        setattr(settings, field, key)
        settings.llm_provider = provider
        if model_name.strip():
            settings.model_name = model_name.strip()
        if provider == "huggingface" and hf_provider:
            settings.hf_inference_provider = hf_provider
        reset_model()
        model = settings.model_name or DEFAULT_MODELS.get(provider, "default")
        logger.info("API key set for provider: %s (%s)", provider, model)
        return f"API key saved. Using **{provider.capitalize()}** ({model})."

    def set_model(provider, model_name):
        """Set a custom model name (blank means the provider default); returns status text."""
        name = model_name.strip()
        settings.model_name = name
        reset_model()
        model = name or DEFAULT_MODELS.get(provider, "default")
        logger.info("Model changed to: %s", model)
        return f"Model set to **{model}**."

    def set_hf_provider(hf_provider):
        """Set the Hugging Face inference provider; returns status text."""
        settings.hf_inference_provider = hf_provider
        reset_model()
        logger.info("HF inference provider changed to: %s", hf_provider)
        return f"Inference provider set to **{hf_provider}**."

    welcome = [{"role": "assistant", "content": welcome_text}]
    if is_demo:
        theme = gr.themes.Glass(primary_hue="indigo")
    else:
        theme = gr.themes.Default()

    with gr.Blocks(title="Cashy - AI Financial Advisor") as demo:
        gr.Markdown("# Cashy — AI Financial Advisor")
        # A callable value gives every browser session its own fresh thread id.
        session_thread_id = gr.State(value=lambda: str(uuid.uuid4()))
        pending_interrupt = gr.State(value=False)

        with gr.Row():
            with gr.Column(scale=3):
                chatbot = gr.Chatbot(
                    value=welcome,
                    height=600,
                    buttons=["copy"],
                )
                with gr.Row():
                    msg = gr.Textbox(
                        placeholder="Ask about your finances...",
                        show_label=False,
                        scale=9,
                    )
                    submit_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Column(scale=1, min_width=250):
                new_chat_btn = gr.Button("+ New Chat", variant="secondary")
                if not is_demo:
                    with gr.Accordion("Chat History", open=False):
                        thread_dropdown = gr.Dropdown(
                            choices=[],
                            label="Previous chats",
                            interactive=True,
                        )
                        load_btn = gr.Button("Load Chat")
                gr.Markdown(f"**Mode:** {mode_label}")
                gr.Markdown("---")
                if is_demo:
                    gr.Markdown(
                        "**Demo Data** · Oct 2025 – Jan 2026 · USD\n\n"
                        "11 accounts · 233 transactions · 20 budgets"
                    )
                    gr.Markdown("---")
                gr.Markdown(
                    "**Capabilities**\n\n"
                    "- Check account balances\n"
                    "- Analyze spending by category\n"
                    "- Search transaction history\n"
                    "- Compare budgets vs. actual\n"
                    "- Create, update, delete transactions\n"
                    "- Run custom SQL queries"
                )
                gr.Markdown("---")
                if is_demo:
                    gr.Markdown(
                        "**Try asking**\n\n"
                        '*"What\'s the balance on Chase Business?"*\n\n'
                        '*"Am I over budget on anything?"*\n\n'
                        '*"I need a $1,500 laptop -- can I afford it?"*\n\n'
                        '*"Show me a pie chart of my spending"*\n\n'
                        '*"Chart my budget vs actual for January"*'
                    )
                else:
                    gr.Markdown(
                        "**Try asking**\n\n"
                        '*"What accounts do I have?"*\n\n'
                        '*"How much did I spend this month?"*\n\n'
                        '*"Show me a chart of my spending"*'
                    )
                gr.Markdown("---")
                provider_dropdown = gr.Dropdown(
                    choices=PROVIDERS,
                    value=current_provider,
                    label="LLM Provider",
                )
                with gr.Row():
                    api_key_input = gr.Textbox(
                        label="API Key",
                        placeholder="Paste your API key here...",
                        type="password",
                        scale=4,
                        visible=not is_free,
                    )
                    save_key_btn = gr.Button("Save", variant="primary", scale=1, visible=not is_free)
                model_name_input = gr.Textbox(
                    label="Model Name (optional)",
                    placeholder=f"Default: {DEFAULT_MODELS.get(current_provider, '')}",
                    value="",
                    visible=not is_free,
                )
                hf_provider_dropdown = gr.Dropdown(
                    choices=HF_INFERENCE_PROVIDERS,
                    value=settings.hf_inference_provider,
                    label="Inference Provider",
                    visible=current_provider == "huggingface",
                )
                if is_free:
                    status_text = f"Using **Free Tier** ({DEFAULT_MODELS['free-tier']}) -- no API key needed"
                elif has_provider:
                    # Report an explicitly configured model, as set_api_key() does.
                    active_model = getattr(settings, "model_name", "") or DEFAULT_MODELS.get(
                        current_provider, "default"
                    )
                    status_text = f"Using **{current_provider.capitalize()}** ({active_model})"
                else:
                    status_text = "No API key configured -- select a provider and enter one above"
                provider_status = gr.Markdown(status_text)

        # --- Event handlers ---
        # Chat events also carry the pending_interrupt state.
        chat_inputs = [msg, chatbot, session_thread_id, pending_interrupt]
        chat_outputs = [msg, chatbot, session_thread_id, pending_interrupt]

        if is_demo:
            def new_chat_demo():
                """Start a fresh thread and reset the chat to the welcome message."""
                new_id = str(uuid.uuid4())
                logger.info("New chat started [thread=%s]", new_id[:8])
                return new_id, welcome, "", False

            msg.submit(respond, chat_inputs, chat_outputs)
            submit_btn.click(respond, chat_inputs, chat_outputs)
            new_chat_btn.click(
                new_chat_demo, [], [session_thread_id, chatbot, msg, pending_interrupt]
            )
        else:
            def new_chat():
                """Start a fresh thread, reset the chat and refresh the history dropdown."""
                new_id = str(uuid.uuid4())
                choices = get_thread_choices(agent)
                logger.info("New chat started [thread=%s]", new_id[:8])
                return new_id, welcome, "", gr.update(choices=choices), False

            def load_thread(selected_thread_id):
                """Switch the session to a saved thread and show its history."""
                if not selected_thread_id:
                    return gr.update(), gr.update(), False
                history = load_thread_history(agent, selected_thread_id)
                logger.info("Loaded thread %s (%d messages)", selected_thread_id[:8], len(history))
                return selected_thread_id, history, False

            def refresh_threads():
                """Rebuild the history dropdown choices."""
                choices = get_thread_choices(agent)
                return gr.update(choices=choices)

            msg.submit(respond, chat_inputs, chat_outputs).then(
                refresh_threads, [], [thread_dropdown]
            )
            submit_btn.click(respond, chat_inputs, chat_outputs).then(
                refresh_threads, [], [thread_dropdown]
            )
            new_chat_btn.click(
                new_chat, [], [session_thread_id, chatbot, msg, thread_dropdown, pending_interrupt]
            )
            load_btn.click(
                load_thread, [thread_dropdown], [session_thread_id, chatbot, pending_interrupt]
            )

        provider_dropdown.change(
            switch_provider, [provider_dropdown],
            [provider_status, api_key_input, model_name_input, hf_provider_dropdown, save_key_btn],
        )
        api_key_inputs = [provider_dropdown, api_key_input, model_name_input, hf_provider_dropdown]
        api_key_input.submit(set_api_key, api_key_inputs, [provider_status])
        save_key_btn.click(set_api_key, api_key_inputs, [provider_status])
        model_name_input.submit(set_model, [provider_dropdown, model_name_input], [provider_status])
        hf_provider_dropdown.change(set_hf_provider, [hf_provider_dropdown], [provider_status])

        # Populate thread list on page load (personal mode only)
        if not is_demo:
            demo.load(refresh_threads, [], [thread_dropdown])

    return demo, theme