# app.py
"""AI Spending Analyser: a Streamlit dashboard over synthetic transactions.

The page offers sidebar filters, headline metrics, an allowance widget,
Plotly charts and a text summary produced either by a hosted chat model or by
a local heuristic.

NOTE(review): several ``st.markdown(..., unsafe_allow_html=True)`` calls below
pass strings that contain little more than newlines. They look like HTML
wrappers whose tags are missing from this copy of the file; the visible text
is reproduced unchanged and should be checked against the original markup.
"""
import html
import json
import os
import re
from datetime import datetime

import pandas as pd
import streamlit as st
from huggingface_hub import InferenceClient

# Single source of truth for the chat model id; it is used for the API call,
# the sidebar "Engine" option and the branch in render_ai_summary.
PUBLICAI_MODEL = "swiss-ai/apertus-8b-instruct"


# ------------------ Setup PublicAI Inference Client ------------------
@st.cache_resource
def get_publicai_client() -> InferenceClient:
    """Build the inference client (cached by Streamlit across reruns).

    Raises:
        ValueError: if no API token is found in the environment.
    """
    # Hugging Face space secret. The error message names 'PUBLICAI' while the
    # lookup used the lowercase spelling; environment variable names are
    # case-sensitive on POSIX, so accept either.
    token = os.getenv("publicai") or os.getenv("PUBLICAI")
    if not token:
        raise ValueError("Missing PublicAI token. Set 'PUBLICAI' env variable in the space secrets.")
    return InferenceClient(api_key=token)


def _call_llama_space(prompt: str) -> str:
    """Send *prompt* as a single user message and return the stripped reply.

    Any failure (missing token, network error, unexpected response) is
    reported as an error string rather than raised, so the page still renders.
    """
    try:
        client = get_publicai_client()
        response = client.chat.completions.create(
            model=PUBLICAI_MODEL,
            messages=[{"role": "user", "content": prompt}],
        )
        return response.choices[0].message["content"].strip()
    except Exception as e:  # deliberate best-effort boundary for the UI
        return f"Error calling PublicAI model: {e}"


# ------------------ Utils ------------------
# Support both package execution (relative import) and running as a script.
try:
    from .utils import (
        generate_synthetic_transactions,
        filter_transactions,
        compute_aggregations,
        build_time_series_chart,
        build_category_bar_chart,
        build_payment_method_pie_chart,
    )
except ImportError:
    from utils import (
        generate_synthetic_transactions,
        filter_transactions,
        compute_aggregations,
        build_time_series_chart,
        build_category_bar_chart,
        build_payment_method_pie_chart,
    )

# ------------------ Streamlit Setup ------------------
st.set_page_config(page_title="AI Spending Analyser", page_icon="💳", layout="wide")


# ------------------ Helpers for Markdown/Spikes ------------------
def strip_markdown(text: str) -> str:
    """Remove bold, italic, inline-code and strikethrough markers from *text*.

    Non-string input is returned unchanged.
    """
    if not isinstance(text, str):
        return text
    text = re.sub(r"(\*\*|__)(.*?)\1", r"\2", text)
    text = re.sub(r"(\*|_)(.*?)\1", r"\2", text)
    text = re.sub(r"`(.*?)`", r"\1", text)
    text = re.sub(r"~~(.*?)~~", r"\1", text)
    return text


def markdown_bold_to_html(text: str) -> str:
    """Convert ``**bold**`` spans in *text* to ``<strong>`` elements."""
    # The replacement previously emitted only the captured text, which dropped
    # the emphasis instead of converting it to HTML.
    return re.sub(r"\*\*(.*?)\*\*", r"<strong>\1</strong>", text)


def sanitize_llm_text_for_users(text: str) -> str:
    """Replace internal names and raw booleans with reader-friendly words.

    ``IsSpike`` (any case) becomes "spike indicator"; the exact words ``True``
    and ``False`` become "yes" and "no". Non-string input is returned as-is.
    """
    if not isinstance(text, str):
        return text
    text = re.sub(r"\bIsSpike\b", "spike indicator", text, flags=re.IGNORECASE)
    text = re.sub(r"\bTrue\b", "yes", text)
    text = re.sub(r"\bFalse\b", "no", text)
    return text


# ------------------ Session State ------------------
def init_session_state() -> None:
    """Seed ``st.session_state`` with data and default filters if missing."""
    if "data" not in st.session_state:
        st.session_state.data = generate_synthetic_transactions(n_rows=900, seed=42)
    if "filters" not in st.session_state:
        min_date = st.session_state.data["Date"].min()
        max_date = st.session_state.data["Date"].max()
        st.session_state.filters = {
            "date_range": (min_date, max_date),
            "categories": [],
            "merchant_query": "",
        }


# ------------------ Header ------------------
def render_header() -> None:
    """Render the page title block."""
    # NOTE(review): the "^" line is present in the visible source; it may be
    # residue of lost markup — confirm before changing.
    st.markdown("\n^\nAI Spending Analyser\n", unsafe_allow_html=True)


def render_assistant_banner() -> None:
    """Placeholder; intentionally renders nothing."""
    return


def render_chat_fab() -> None:
    """Placeholder; intentionally renders nothing."""
    return


# ------------------ Sidebar ------------------
def render_sidebar(df: pd.DataFrame) -> None:
    """Draw the sidebar controls and store their values in session state.

    Args:
        df: Transactions frame; its ``Date`` and ``Category`` columns bound
            the date pickers and populate the category selector.
    """
    st.sidebar.header("Filters")
    min_d, max_d = df["Date"].min(), df["Date"].max()
    col1, col2 = st.sidebar.columns(2)
    with col1:
        from_date = st.date_input(
            "From", value=min_d.date(), min_value=min_d.date(), max_value=max_d.date()
        )
    with col2:
        to_date = st.date_input(
            "To", value=max_d.date(), min_value=min_d.date(), max_value=max_d.date()
        )
    if from_date > to_date:
        # Fall back to the full range rather than filtering on an empty window.
        st.sidebar.error("From date cannot be after To date")
        from_date, to_date = min_d.date(), max_d.date()

    categories = st.sidebar.multiselect(
        "Category", options=sorted(df["Category"].unique()), default=[]
    )
    merchant_query = st.sidebar.text_input(
        "Merchant search", value="", placeholder="Type a merchant name…"
    )

    st.sidebar.divider()
    st.sidebar.header("AI")
    summary_mode = st.sidebar.radio(
        "Summary", options=["Concise", "Detailed"], index=0, horizontal=True
    )
    engine = st.sidebar.selectbox("Engine", options=[PUBLICAI_MODEL], index=0)

    st.sidebar.divider()
    st.sidebar.header("Anomalies & Highlights")
    show_spikes = st.sidebar.toggle("Show spike markers", value=True)
    large_tx_threshold = st.sidebar.slider(
        "Large transaction threshold (£)", 50, 1000, 250, step=25
    )

    if st.sidebar.button("Regenerate"):
        # No seed here, unlike the initial load, so each click yields new data.
        st.session_state.data = generate_synthetic_transactions(n_rows=900)

    st.session_state.filters = {
        # Widen the picked dates to cover both whole days.
        "date_range": (
            datetime.combine(from_date, datetime.min.time()),
            datetime.combine(to_date, datetime.max.time()),
        ),
        "categories": categories,
        "merchant_query": merchant_query.strip(),
        "summary_mode": summary_mode,
        "engine": engine,
        "show_spikes": show_spikes,
        "large_tx_threshold": large_tx_threshold,
    }


# ------------------ Metrics ------------------
def render_metrics(agg: dict) -> None:
    """Show four headline figures from *agg* in a row of columns.

    Reads ``total_spend``, ``avg_monthly_spend`` and the ``Amount`` of
    ``max_transaction`` / ``min_transaction``.
    """
    labels = ["Total Value", "Avg Monthly", "Max Transaction", "Min Transaction"]
    values = [
        agg["total_spend"],
        agg["avg_monthly_spend"],
        agg["max_transaction"]["Amount"],
        agg["min_transaction"]["Amount"],
    ]
    for col, label, value in zip(st.columns(4), labels, values):
        gbp_whole, gbp_pence = f"{value:,.2f}".split(".")
        col.markdown(f"\n{label}\n£{gbp_whole}.{gbp_pence}\n", unsafe_allow_html=True)


# ------------------ ISA Widget ------------------
def render_isa_widget(current_spend: float, allowance: float) -> None:
    """Render used vs. remaining allowance.

    Args:
        current_spend: Amount counted against the allowance.
        allowance: Total allowance; usage is capped at this value.
    """
    used = min(current_spend, allowance)
    remaining = max(allowance - used, 0)
    # NOTE(review): `percent` is not referenced by the visible markup; it is
    # presumably consumed by a progress bar missing from this copy — confirm.
    percent = 0 if allowance <= 0 else int((used / allowance) * 100)

    st.markdown("\n", unsafe_allow_html=True)
    st.subheader("ISA allowance")
    st.markdown("\n", unsafe_allow_html=True)
    col1, col2 = st.columns(2)
    with col1:
        used_whole, used_pence = f"{used:,.2f}".split(".")
        st.markdown(f"\nUSED\n£{used_whole}.{used_pence}\n", unsafe_allow_html=True)
    with col2:
        rem_whole, rem_pence = f"{remaining:,.2f}".split(".")
        st.markdown(f"\nREMAINING\n£{rem_whole}.{rem_pence}\n", unsafe_allow_html=True)
    st.markdown("\n", unsafe_allow_html=True)


# ------------------ Charts ------------------
def render_charts(filtered_df: pd.DataFrame, agg: dict, template: str, show_spikes: bool) -> None:
    """Render the trend, category and payment-method charts in three tabs.

    Args:
        filtered_df: Transactions to plot in the trend chart.
        agg: Aggregations providing ``spikes``, ``spend_per_category`` and
            ``spend_per_payment``.
        template: Plotly template name passed to the chart builders.
        show_spikes: When false, no spike overlay is passed to the trend chart.
    """
    t1, t2, t3 = st.tabs(["Trend", "By Category", "Payment Methods"])
    # NOTE(review): unused in the visible markup; presumably a CSS class for a
    # wrapper element missing from this copy — confirm.
    chart_style = "chart-card"

    with t1:
        fig = build_time_series_chart(
            filtered_df,
            template=template,
            spike_overlay=agg["spikes"] if show_spikes else None,
        )
        st.markdown("\n", unsafe_allow_html=True)
        st.plotly_chart(fig, use_container_width=True, responsive=True)
        st.markdown("\n", unsafe_allow_html=True)

    with t2:
        st.caption("Tip: Select categories in the sidebar to compare their total spend.")
        fig = build_category_bar_chart(agg["spend_per_category"], template=template)
        st.markdown("\n", unsafe_allow_html=True)
        st.plotly_chart(fig, use_container_width=True, responsive=True)
        st.markdown("\n", unsafe_allow_html=True)

    with t3:
        fig = build_payment_method_pie_chart(agg["spend_per_payment"], template=template)
        st.markdown("\n", unsafe_allow_html=True)
        st.plotly_chart(fig, use_container_width=True, responsive=True)
        st.markdown("\n", unsafe_allow_html=True)


# ------------------ Heuristic AI Summary ------------------
def heuristic_summary(agg: dict, mode: str) -> str:
    """Build a plain-text summary from *agg* without calling a model.

    Args:
        agg: Aggregations; missing keys fall back to zero / empty values.
        mode: Accepted for signature parity with the model path; not used.

    Returns:
        A paragraph covering totals, the largest spike (if any) and up to the
        eight highest-spend categories.
    """
    total = agg.get("total_spend", 0)
    avg_month = agg.get("avg_monthly_spend", 0)
    spend_per_category = agg.get("spend_per_category", {})
    df_spikes = agg.get("spikes")
    if df_spikes is None:
        df_spikes = pd.DataFrame()

    days = st.session_state.filters.get("date_range", (None, None))
    num_days = (days[1] - days[0]).days + 1 if all(days) else "unknown"

    if df_spikes.empty:
        spike_text = "No unusual spending spikes were detected during this period."
    else:
        try:
            top_spike = df_spikes.loc[df_spikes["Amount"].idxmax()]
            spike_date = pd.to_datetime(top_spike["Date"]).strftime("%Y-%m-%d")
            spike_amt = top_spike["Amount"]
            spike_text = (
                f"A noticeable spending spike occurred on {spike_date}, "
                f"reaching £{spike_amt:,.2f}, which exceeded the 28-day rolling average."
            )
        except Exception:
            # Best-effort: an unexpected spikes layout degrades to generic text.
            spike_text = (
                "Some days showed higher-than-normal spending compared to "
                "the 28-day rolling average."
            )

    text = (
        f"Over a period of {num_days} days, the total spend was £{total:,.2f} "
        f"with an average monthly spend of £{avg_month:,.2f}. {spike_text} "
        "Here’s a breakdown of the spending patterns:\n\n"
    )
    if spend_per_category:
        ranked = sorted(spend_per_category.items(), key=lambda item: item[1], reverse=True)
        text += "".join(
            f"{i}. {cat}: £{val:,.2f} ({(val / total * 100 if total else 0):.1f}% of total spend). "
            for i, (cat, val) in enumerate(ranked[:8], 1)
        )
    text += " Overall, this summary helps identify key categories and spending trends over the selected period."
    return text


# ------------------ AI Summary ------------------
def render_ai_summary(agg: dict, mode: str, engine: str) -> None:
    """Render the summary section, using the model or the local heuristic.

    Args:
        agg: Aggregations to summarise.
        mode: "Concise" or "Detailed"; lower-cased into the prompt.
        engine: Selected engine id; the hosted model id triggers the model
            path, anything else the heuristic.
    """
    st.subheader("AI Summary")
    placeholder = st.empty()
    placeholder.markdown("\nGenerating summary…\n", unsafe_allow_html=True)

    if engine == PUBLICAI_MODEL:
        # default=str stringifies values json cannot encode natively.
        payload = json.dumps(
            {
                "total_spend": agg.get("total_spend"),
                "avg_monthly_spend": agg.get("avg_monthly_spend"),
                "top_categories": agg.get("spend_per_category"),
                "spikes": agg.get("spikes"),
            },
            default=str,
        )
        prompt = (
            f"Generate a {mode.lower()} human-readable spending summary in clear sentences. "
            "Use only the pound symbol (£) for all amounts. Ensure proper spacing between words and numbers and correct date formatting (YYYY-MM-DD). "
            "Avoid exposing internal column names or raw boolean values. "
            f"Data: {payload}"
        )
        raw_text = _call_llama_space(prompt)
        raw_text = raw_text.replace("$", "£")
        # Reformat compact YYYYMMDD dates as YYYY-MM-DD.
        raw_text = re.sub(r"\b(\d{4})(\d{2})(\d{2})\b", r"\1-\2-\3", raw_text)
        # Drop non-ASCII characters but keep "£": the previous pattern also
        # removed the pound signs that had just been substituted in.
        raw_text = re.sub(r"[^\x00-\x7F£]+", "", raw_text)
        raw_text = sanitize_llm_text_for_users(raw_text)
        # Separate digits glued to letters; "£" is excluded from the
        # letter-then-digit rule so "£100" is not split into "£ 100".
        raw_text = re.sub(r"(\d)([A-Za-z£])", r"\1 \2", raw_text)
        raw_text = re.sub(r"([A-Za-z])(\d)", r"\1 \2", raw_text)
        clean_body = strip_markdown(raw_text).strip()
    else:
        clean_body = heuristic_summary(agg, mode).strip()
        clean_body = sanitize_llm_text_for_users(clean_body)

    header_md = (
        f"The total spending over the period was **£{agg.get('total_spend', 0):,.2f}**, "
        f"averaging **£{agg.get('avg_monthly_spend', 0):,.2f} per month**. "
    )
    # The body can contain model output and is rendered with
    # unsafe_allow_html=True, so escape it before adding our own tags.
    combined = header_md + " " + html.escape(clean_body, quote=False)
    html_safe = markdown_bold_to_html(combined)
    placeholder.markdown(f"\n{html_safe}\n", unsafe_allow_html=True)


# ------------------ Main ------------------
def main() -> None:
    """Page entry point: filters, metrics, charts, summary and export."""
    init_session_state()
    st.markdown(" ", unsafe_allow_html=True)
    render_header()
    render_assistant_banner()
    render_chat_fab()
    render_sidebar(st.session_state.data)

    filters = st.session_state.filters
    filtered = filter_transactions(
        st.session_state.data,
        date_range=filters["date_range"],
        categories=filters["categories"],
        merchant_query=filters["merchant_query"],
    )
    if filtered.empty:
        st.info("No data for selected filters. Adjust filters to see insights.")
        return

    agg = compute_aggregations(filtered)
    st.markdown("\n", unsafe_allow_html=True)
    render_metrics(agg)
    st.markdown("\n", unsafe_allow_html=True)

    with st.expander("Allowance widget"):
        allowance = st.number_input("Annual allowance (£)", min_value=0, value=20000, step=500)
        render_isa_widget(current_spend=float(agg["total_spend"]), allowance=float(allowance))

    template = "plotly_dark"
    render_charts(filtered, agg, template, show_spikes=filters["show_spikes"])
    render_ai_summary(agg, mode=filters["summary_mode"], engine=filters["engine"])

    threshold = filters["large_tx_threshold"]
    large_df = filtered[filtered["Amount"] >= threshold].sort_values("Amount", ascending=False)
    with st.expander(f"Show large transactions (≥ £{threshold}) [{len(large_df)}]"):
        st.dataframe(large_df, use_container_width=True, hide_index=True)

    st.divider()
    col1, col2 = st.columns([2, 1])
    with col1:
        st.caption("Download filtered data")
        csv = filtered.to_csv(index=False).encode("utf-8")
        st.download_button(
            "Download CSV", csv, file_name="transactions_filtered.csv", mime="text/csv"
        )
    with col2:
        st.caption("Dataset size")
        st.write(f"{len(filtered):,} rows")


if __name__ == "__main__":
    main()